Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/api_core/grpc_helpers_async.py: 38%

126 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:45 +0000

1# Copyright 2020 Google LLC 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""AsyncIO helpers for :mod:`grpc` supporting 3.7+. 

16 

Refer to the more detailed docstrings in grpc_helpers.py for the functions
below. This module implements the same surface with AsyncIO semantics.

19""" 

20 

21import asyncio 

22import functools 

23 

24from typing import Generic, Iterator, AsyncGenerator, TypeVar 

25 

26import grpc 

27from grpc import aio 

28 

29from google.api_core import exceptions, grpc_helpers 

30 

# Type variable standing for the protobuf response message type that an
# RPC produces; used to parameterize the generic call wrappers below.
P = TypeVar("P")

33 

34# NOTE(lidiz) Alternatively, we can hack "__getattribute__" to perform 

35# automatic patching for us. But that means the overhead of creating an 

36# extra Python function spreads to every single send and receive. 

37 

38 

class _WrappedCall(aio.Call):
    """Delegating base class for the async call wrappers in this module.

    Every ``aio.Call`` method is forwarded to the inner call supplied via
    :meth:`with_call`. Of these, only :meth:`wait_for_connection` translates
    ``grpc.RpcError`` into a :mod:`google.api_core.exceptions` error.
    """

    def __init__(self):
        # The real call is attached afterwards through ``with_call``.
        self._call = None

    def with_call(self, call):
        """Store ``call`` as the delegate and return ``self`` for chaining.

        Kept separate from ``__init__`` so the constructor takes no arguments.
        """
        self._call = call
        return self

    async def initial_metadata(self):
        """Await and return the inner call's initial metadata."""
        inner = self._call
        return await inner.initial_metadata()

    async def trailing_metadata(self):
        """Await and return the inner call's trailing metadata."""
        inner = self._call
        return await inner.trailing_metadata()

    async def code(self):
        """Await and return the inner call's status code."""
        inner = self._call
        return await inner.code()

    async def details(self):
        """Await and return the inner call's status details."""
        inner = self._call
        return await inner.details()

    def cancelled(self):
        """Return the inner call's ``cancelled()`` result."""
        inner = self._call
        return inner.cancelled()

    def done(self):
        """Return the inner call's ``done()`` result."""
        inner = self._call
        return inner.done()

    def time_remaining(self):
        """Return the inner call's ``time_remaining()`` result."""
        inner = self._call
        return inner.time_remaining()

    def cancel(self):
        """Cancel the inner call and return its ``cancel()`` result."""
        inner = self._call
        return inner.cancel()

    def add_done_callback(self, callback):
        """Register ``callback`` on the inner call; returns ``None``."""
        inner = self._call
        inner.add_done_callback(callback)

    async def wait_for_connection(self):
        """Wait for the inner call to connect.

        Raises:
            google.api_core.exceptions.GoogleAPICallError: mapped from any
                ``grpc.RpcError`` raised while waiting.
        """
        try:
            await self._call.wait_for_connection()
        except grpc.RpcError as err:
            raise exceptions.from_grpc_error(err) from err

80 

81 

class _WrappedUnaryResponseMixin(Generic[P], _WrappedCall):
    """Mixin that makes a wrapped call awaitable for its single response."""

    def __await__(self) -> Iterator[P]:
        # Drive the inner call's awaitable and hand back its result;
        # gRPC failures surfacing along the way become api_core exceptions.
        try:
            return (yield from self._call.__await__())
        except grpc.RpcError as err:
            raise exceptions.from_grpc_error(err) from err

89 

90 

class _WrappedStreamResponseMixin(Generic[P], _WrappedCall):
    """Mixin exposing the streamed responses of a wrapped call.

    Responses can be consumed one at a time with :meth:`read` or with
    ``async for``. In both cases ``grpc.RpcError`` is translated into the
    matching :mod:`google.api_core.exceptions` error.
    """

    def __init__(self):
        # Run the base initializer so ``_call`` exists (as ``None``) before
        # ``with_call`` attaches the real call, as for the other wrappers.
        super().__init__()
        # Created lazily by ``__aiter__`` and then reused, so every
        # ``async for`` over this object shares one underlying iteration.
        self._wrapped_async_generator = None

    async def read(self) -> P:
        """Read the next response message from the inner call.

        Raises:
            google.api_core.exceptions.GoogleAPICallError: mapped from any
                ``grpc.RpcError`` raised by the inner ``read()``.
        """
        try:
            return await self._call.read()
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    async def _wrapped_aiter(self) -> AsyncGenerator[P, None]:
        """Re-yield the inner call's responses, mapping gRPC errors."""
        try:
            # NOTE(lidiz) coverage doesn't understand the exception raised from
            # __anext__ method. It is covered by test case:
            # test_wrap_stream_errors_aiter_non_rpc_error
            async for response in self._call:  # pragma: no branch
                yield response
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    def __aiter__(self) -> AsyncGenerator[P, None]:
        """Return the shared error-mapping async generator, creating it once."""
        if self._wrapped_async_generator is None:
            self._wrapped_async_generator = self._wrapped_aiter()
        return self._wrapped_async_generator

115 

116 

class _WrappedStreamRequestMixin(_WrappedCall):
    """Mixin forwarding request-streaming operations to the inner call.

    Both operations convert ``grpc.RpcError`` into the corresponding
    :mod:`google.api_core.exceptions` error.
    """

    async def write(self, request):
        """Send ``request`` on the inner call's request stream."""
        try:
            await self._call.write(request)
        except grpc.RpcError as err:
            raise exceptions.from_grpc_error(err) from err

    async def done_writing(self):
        """Tell the inner call that no more requests will be written."""
        try:
            await self._call.done_writing()
        except grpc.RpcError as err:
            raise exceptions.from_grpc_error(err) from err

129 

130 

# NOTE(lidiz) One concrete class per RPC shape keeps each wrapper's public
# surface exact -- e.g. no __aiter__ on unary-unary calls and no __await__
# on stream-stream calls.
class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin[P], aio.UnaryUnaryCall):
    """Unary-unary call wrapper whose awaited result maps gRPC errors."""

136 

137 

class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin[P], aio.UnaryStreamCall):
    """Unary-stream call wrapper whose response stream maps gRPC errors."""

140 

141 

class _WrappedStreamUnaryCall(
    _WrappedUnaryResponseMixin[P], _WrappedStreamRequestMixin, aio.StreamUnaryCall
):
    """Stream-unary call wrapper; writes and the awaited result map gRPC errors."""

146 

147 

class _WrappedStreamStreamCall(
    _WrappedStreamRequestMixin, _WrappedStreamResponseMixin[P], aio.StreamStreamCall
):
    """Stream-stream call wrapper; both request and response streams map gRPC errors."""

152 

153 

# Public alias: the type returned by async GAPIC calls that stream responses.
GrpcAsyncStream = _WrappedStreamResponseMixin[P]
# Public alias: the type returned by async GAPIC calls with a single response.
AwaitableGrpcCall = _WrappedUnaryResponseMixin[P]

158 

159 

def _wrap_unary_errors(callable_):
    """Wrap a unary-unary async multicallable so its calls map gRPC errors.

    Args:
        callable_: The unary-unary multicallable. Its name is patched with
            ``grpc_helpers._patch_callable_name`` before wrapping.

    Returns:
        Callable: A function with the same call signature that returns a
        ``_WrappedUnaryUnaryCall`` around the underlying call.
    """
    grpc_helpers._patch_callable_name(callable_)

    def error_remapped_callable(*args, **kwargs):
        inner_call = callable_(*args, **kwargs)
        wrapper = _WrappedUnaryUnaryCall()
        return wrapper.with_call(inner_call)

    return functools.wraps(callable_)(error_remapped_callable)

170 

171 

def _wrap_stream_errors(callable_):
    """Wrap a streaming async multicallable so its calls map gRPC errors.

    The returned coroutine function wraps the underlying call in the wrapper
    matching its RPC shape and waits for the connection before returning it.

    Args:
        callable_: A unary-stream, stream-unary or stream-stream
            multicallable. Its name is patched with
            ``grpc_helpers._patch_callable_name`` before wrapping.

    Returns:
        Callable: An ``async`` function with the same call signature.

    Raises (from the returned coroutine):
        TypeError: if the multicallable produced an unrecognized call type.
    """
    grpc_helpers._patch_callable_name(callable_)

    @functools.wraps(callable_)
    async def error_remapped_callable(*args, **kwargs):
        raw_call = callable_(*args, **kwargs)

        # Checked in order; the first matching call type decides the wrapper.
        wrapper_types = (
            (aio.UnaryStreamCall, _WrappedUnaryStreamCall),
            (aio.StreamUnaryCall, _WrappedStreamUnaryCall),
            (aio.StreamStreamCall, _WrappedStreamStreamCall),
        )
        for call_type, wrapper_type in wrapper_types:
            if isinstance(raw_call, call_type):
                wrapped = wrapper_type().with_call(raw_call)
                break
        else:
            raise TypeError("Unexpected type of call %s" % type(raw_call))

        await wrapped.wait_for_connection()
        return wrapped

    return error_remapped_callable

193 

194 

def wrap_errors(callable_):
    """Wrap a gRPC async callable so that ``grpc.RpcError`` becomes a friendly error.

    Errors raised by the gRPC callable are mapped to the appropriate
    :class:`google.api_core.exceptions.GoogleAPICallError` subclasses. The
    original :class:`grpc.RpcError` (which is usually also a ``grpc.Call``)
    is available from the ``response`` property on the mapped exception,
    which is useful for extracting metadata from the original error.

    Args:
        callable_ (Callable): A gRPC async multicallable. Unary-unary
            multicallables get a synchronous wrapper that returns an
            awaitable call; every other kind gets an ``async`` wrapper.

    Returns:
        Callable: The wrapped gRPC callable.
    """
    is_unary_unary = isinstance(callable_, aio.UnaryUnaryMultiCallable)
    wrapper = _wrap_unary_errors if is_unary_unary else _wrap_stream_errors
    return wrapper(callable_)

214 

215 

def create_channel(
    target,
    credentials=None,
    scopes=None,
    ssl_credentials=None,
    credentials_file=None,
    quota_project_id=None,
    default_scopes=None,
    default_host=None,
    compression=None,
    **kwargs
):
    """Create an AsyncIO secure channel with credentials.

    Args:
        target (str): The target service address in the format 'hostname:port'.
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, this function attempts to ascertain the credentials
            from the environment using :func:`google.auth.default`.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. These are only used when credentials are not specified
            and are passed to :func:`google.auth.default`.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. Use this to specify different certificates.
        credentials_file (str): A file with credentials that can be loaded
            with :func:`google.auth.load_credentials_from_file`. Mutually
            exclusive with ``credentials``.
        quota_project_id (str): An optional project to use for billing and quota.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use ``scopes`` for user-defined scopes.
        default_host (str): The default endpoint, e.g. "pubsub.googleapis.com".
        compression (grpc.Compression): An optional value indicating the
            compression method to use over the lifetime of the channel.
        kwargs: Additional keyword arguments passed to :func:`aio.secure_channel`.

    Returns:
        aio.Channel: The created channel.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and ``credentials_file`` are passed.
    """
    credential_options = dict(
        credentials=credentials,
        credentials_file=credentials_file,
        scopes=scopes,
        default_scopes=default_scopes,
        ssl_credentials=ssl_credentials,
        quota_project_id=quota_project_id,
        default_host=default_host,
    )
    channel_credentials = grpc_helpers._create_composite_credentials(
        **credential_options
    )

    return aio.secure_channel(
        target, channel_credentials, compression=compression, **kwargs
    )

271 

272 

class FakeUnaryUnaryCall(_WrappedUnaryUnaryCall):
    """Fake implementation for unary-unary RPCs.

    A stand-in whose await yields a preset response message: pass the
    intended response to the constructor, and awaiting the object returns
    exactly that message.

    Construction does not touch the event loop. The fake can therefore be
    built from synchronous code, where ``asyncio.get_event_loop()`` is
    deprecated (and rejected by newer Python versions) when no loop is set.
    """

    def __init__(self, response=object()):
        self.response = response
        # Snapshot the awaited result at construction time. Later
        # reassignment of ``self.response`` does not change what ``await``
        # returns, matching the original pre-resolved future.
        self._result = response

    def __await__(self):
        # ``yield from ()`` only makes this a generator, as the await
        # protocol requires. It never suspends, so awaiting completes
        # immediately with the stored response.
        yield from ()
        return self._result

289 

290 

class FakeStreamUnaryCall(_WrappedStreamUnaryCall):
    """Fake implementation for stream-unary RPCs.

    A stand-in whose await yields a preset response message: pass the
    intended response to the constructor, and awaiting the object returns
    exactly that message. ``wait_for_connection`` is a no-op.

    Construction does not touch the event loop. The fake can therefore be
    built from synchronous code, where ``asyncio.get_event_loop()`` is
    deprecated (and rejected by newer Python versions) when no loop is set.
    """

    def __init__(self, response=object()):
        self.response = response
        # Snapshot the awaited result at construction time. Later
        # reassignment of ``self.response`` does not change what ``await``
        # returns, matching the original pre-resolved future.
        self._result = response

    def __await__(self):
        # ``yield from ()`` only makes this a generator, as the await
        # protocol requires. It never suspends, so awaiting completes
        # immediately with the stored response.
        yield from ()
        return self._result

    async def wait_for_connection(self):
        # Nothing to connect to in a fake; succeed immediately.
        pass