1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import collections
16import logging
17import threading
18from typing import Callable, Optional, Type, Union
19
20import grpc
21from grpc import _common
22from grpc._cython import cygrpc
23from grpc._typing import MetadataType
24
25_LOGGER = logging.getLogger(__name__)
26
27
class _AuthMetadataContext(
    collections.namedtuple(
        "AuthMetadataContext", ("service_url", "method_name")
    ),
    grpc.AuthMetadataContext,
):
    """Immutable (service_url, method_name) pair given to a metadata plugin.

    Combines a two-field named tuple with the public
    ``grpc.AuthMetadataContext`` base so instances expose ``service_url``
    and ``method_name`` as attributes.
    """
39
40
41class _CallbackState:
42 def __init__(self):
43 self.lock = threading.Lock()
44 self.called = False
45 self.exception = None
46
47
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
    """Callback object handed to an auth metadata plugin to report its result.

    Wraps a lower-level ``callback`` and consults a shared ``_CallbackState``
    so that a second invocation, or an invocation made after an exception has
    been recorded in that state, is rejected with ``RuntimeError``.
    """

    _state: _CallbackState
    _callback: Callable

    def __init__(self, state: _CallbackState, callback: Callable):
        """Bind this callback to its shared state and underlying callable.

        Args:
          state: Bookkeeping object whose ``lock``, ``called`` and
            ``exception`` attributes are consulted on every invocation.
          callback: Callable invoked as
            ``callback(metadata, status_code, error_details)``.
        """
        self._state = state
        self._callback = callback

    def __call__(
        self, metadata: MetadataType, error: Optional[Type[BaseException]]
    ) -> None:
        """Deliver the plugin's outcome to the underlying callback.

        Args:
          metadata: Metadata to forward when ``error`` is None.
          error: None on success; otherwise its ``str()`` is encoded and
            forwarded as the error details with an internal status code.

        Raises:
          RuntimeError: If this callback was already invoked, or if an
            exception has already been recorded in the shared state.
        """
        with self._state.lock:
            if self._state.exception is not None:
                # The f-prefix and the trailing space on the first fragment
                # are required: without them the message contains the raw
                # placeholder text and the two fragments run together.
                error_msg = (
                    "AuthMetadataPluginCallback "
                    f'raised exception "{self._state.exception}"!'
                )
                raise RuntimeError(error_msg)
            if self._state.called:
                error_msg = (
                    "AuthMetadataPluginCallback invoked more than once!"
                )
                raise RuntimeError(error_msg)
            self._state.called = True
        # The underlying callback is invoked after the lock is released.
        if error is None:
            self._callback(metadata, cygrpc.StatusCode.ok, None)
        else:
            self._callback(
                None, cygrpc.StatusCode.internal, _common.encode(str(error))
            )
79
80
class _Plugin:
    """Callable adapter around a ``grpc.AuthMetadataPlugin``.

    Invoked as ``plugin(service_url, method_name, callback)``; it builds the
    context and callback objects the wrapped plugin expects and reports an
    internal error through ``callback`` if the plugin raises.
    """

    _metadata_plugin: grpc.AuthMetadataPlugin

    def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
        """Store the plugin and snapshot the current context, if possible.

        Args:
          metadata_plugin: The user-supplied plugin to wrap.
        """
        self._metadata_plugin = metadata_plugin

        try:
            import contextvars  # pylint: disable=wrong-import-position
        except ImportError:
            # Support versions predating contextvars.
            self._stored_ctx = None
        else:
            # The plugin may be invoked on a thread created by Core, which
            # will not have the context propagated. This context is stored
            # and installed in the thread invoking the plugin.
            # NOTE(review): _stored_ctx is not read anywhere in this class;
            # presumably whoever invokes the plugin consumes it -- confirm.
            self._stored_ctx = contextvars.copy_context()

    def __call__(
        self,
        service_url: Union[str, bytes],
        method_name: Union[str, bytes],
        callback: Callable,
    ) -> None:
        """Run the wrapped plugin for one call.

        Args:
          service_url: Service URL, decoded before being given to the plugin.
          method_name: Method name, decoded before being given to the plugin.
          callback: Callable invoked as
            ``callback(metadata, status_code, error_details)``.
        """
        auth_context = _AuthMetadataContext(
            _common.decode(service_url), _common.decode(method_name)
        )
        state = _CallbackState()
        try:
            self._metadata_plugin(
                auth_context, _AuthMetadataPluginCallback(state, callback)
            )
        except Exception as plugin_error:  # pylint: disable=broad-except
            _LOGGER.exception(
                'AuthMetadataPluginCallback "%s" raised exception!',
                self._metadata_plugin,
            )
            # Record the failure and check, under the lock, whether the
            # plugin had already delivered a result before raising.
            with state.lock:
                state.exception = plugin_error
                already_reported = state.called
            if not already_reported:
                callback(
                    None,
                    cygrpc.StatusCode.internal,
                    _common.encode(str(plugin_error)),
                )
125
126
def metadata_plugin_call_credentials(
    metadata_plugin: grpc.AuthMetadataPlugin, name: Optional[str]
) -> grpc.CallCredentials:
    """Create call credentials backed by an auth metadata plugin.

    Args:
      metadata_plugin: The plugin to wrap.
      name: Name for the credentials. When None, the plugin's ``__name__``
        is used, falling back to the name of its class.

    Returns:
      A ``grpc.CallCredentials`` wrapping
      ``cygrpc.MetadataPluginCallCredentials`` for the wrapped plugin.
    """
    if name is not None:
        effective_name = name
    else:
        # Function-like plugins carry __name__; other instances fall back
        # to their class name.
        effective_name = getattr(
            metadata_plugin, "__name__", metadata_plugin.__class__.__name__
        )
    plugin_credentials = cygrpc.MetadataPluginCallCredentials(
        _Plugin(metadata_plugin), _common.encode(effective_name)
    )
    return grpc.CallCredentials(plugin_credentials)