Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/grpc/aio/_channel.py: 42%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

190 statements  

1# Copyright 2019 gRPC authors. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14"""Invocation-side implementation of gRPC Asyncio Python.""" 

15 

16import asyncio 

17import sys 

18from typing import Any, Iterable, List, Optional, Sequence 

19 

20import grpc 

21from grpc import _common 

22from grpc import _compression 

23from grpc import _grpcio_metadata 

24from grpc._cython import cygrpc 

25 

26from . import _base_call 

27from . import _base_channel 

28from ._call import StreamStreamCall 

29from ._call import StreamUnaryCall 

30from ._call import UnaryStreamCall 

31from ._call import UnaryUnaryCall 

32from ._interceptor import ClientInterceptor 

33from ._interceptor import InterceptedStreamStreamCall 

34from ._interceptor import InterceptedStreamUnaryCall 

35from ._interceptor import InterceptedUnaryStreamCall 

36from ._interceptor import InterceptedUnaryUnaryCall 

37from ._interceptor import StreamStreamClientInterceptor 

38from ._interceptor import StreamUnaryClientInterceptor 

39from ._interceptor import UnaryStreamClientInterceptor 

40from ._interceptor import UnaryUnaryClientInterceptor 

41from ._metadata import Metadata 

42from ._typing import ChannelArgumentType 

43from ._typing import DeserializingFunction 

44from ._typing import MetadataType 

45from ._typing import RequestIterableType 

46from ._typing import RequestType 

47from ._typing import ResponseType 

48from ._typing import SerializingFunction 

49from ._utils import _timeout_to_deadline 

50 

# User-agent string advertised by every asyncio channel created in this module.
_USER_AGENT = f"grpc-python-asyncio/{_grpcio_metadata.__version__}"

52 

53if sys.version_info[1] < 7: 

54 

55 def _all_tasks() -> Iterable[asyncio.Task]: 

56 return asyncio.Task.all_tasks() # pylint: disable=no-member 

57 

58else: 

59 

60 def _all_tasks() -> Iterable[asyncio.Task]: 

61 return asyncio.all_tasks() 

62 

63 

def _augment_channel_arguments(
    base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
):
    """Extend the caller's channel arguments with compression and user agent.

    Args:
      base_options: The channel arguments supplied by the caller.
      compression: An optional compression algorithm, turned into channel
        argument(s) by ``_compression.create_channel_option``.

    Returns:
      A tuple holding the caller's options, then the compression option(s),
      then the ``primary_user_agent_string`` argument set to ``_USER_AGENT``.
    """
    compression_options = _compression.create_channel_option(compression)
    user_agent_option = (
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    )
    return tuple(base_options) + compression_options + (user_agent_option,)

81 

82 

class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores the attributes shared by the
    unary-unary, unary-stream, stream-unary and stream-stream multi-callables.
    """

    # Duplicate `_loop` annotation removed; each attribute is declared once.
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store the configuration used by every call made through this object.

        Args:
          channel: The Cython channel the RPCs are issued on.
          method: The method name, already encoded to bytes.
          request_serializer: Function used to serialize requests.
          response_deserializer: Function used to deserialize responses.
          interceptors: Client interceptors for this kind of RPC; subclasses
            bypass interception when this is empty or None.
          references: Objects held by this multi-callable. The Channel in this
            module passes ``[self]``, presumably to keep the channel alive
            while the callable is in use.
          loop: The event loop the calls are bound to.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
        metadata: Optional[MetadataType] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Metadata:
        """Based on the provided values for <metadata> or <compression> initialise the final
        metadata, as it should be used for the current call.

        A falsy ``metadata`` becomes an empty ``Metadata``. A plain tuple is
        converted with ``Metadata.from_tuple``; any other value is passed
        through unchanged. When ``compression`` is set, the metadata is
        augmented by ``_compression.augment_metadata`` and re-wrapped in a
        ``Metadata`` instance.
        """
        metadata = metadata or Metadata()
        if not isinstance(metadata, Metadata) and isinstance(metadata, tuple):
            metadata = Metadata.from_tuple(metadata)
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression)
            )
        return metadata

133 

134 

class UnaryUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
):
    """Asyncio multi-callable for unary-request / unary-response RPCs."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
        """Start a unary-unary RPC and return its call object.

        With interceptors configured, an ``InterceptedUnaryUnaryCall`` is
        returned and receives the relative ``timeout``. Otherwise a plain
        ``UnaryUnaryCall`` is returned and receives the absolute deadline
        computed by ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        if self._interceptors:
            return InterceptedUnaryUnaryCall(
                self._interceptors,
                request,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        return UnaryUnaryCall(
            request,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )

178 

179 

class UnaryStreamMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
):
    """Asyncio multi-callable for unary-request / streaming-response RPCs."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
        """Start a unary-stream RPC and return its call object.

        With interceptors configured, an ``InterceptedUnaryStreamCall`` is
        returned and receives the relative ``timeout``. Otherwise a plain
        ``UnaryStreamCall`` is returned and receives the absolute deadline
        computed by ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        if self._interceptors:
            return InterceptedUnaryStreamCall(
                self._interceptors,
                request,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        return UnaryStreamCall(
            request,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )

224 

225 

class StreamUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
):
    """Asyncio multi-callable for streaming-request / unary-response RPCs."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamUnaryCall:
        """Start a stream-unary RPC and return its call object.

        With interceptors configured, an ``InterceptedStreamUnaryCall`` is
        returned and receives the relative ``timeout``. Otherwise a plain
        ``StreamUnaryCall`` is returned and receives the absolute deadline
        computed by ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        if self._interceptors:
            return InterceptedStreamUnaryCall(
                self._interceptors,
                request_iterator,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        return StreamUnaryCall(
            request_iterator,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )

269 

270 

class StreamStreamMultiCallable(
    _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
):
    """Asyncio multi-callable for bidirectional-streaming RPCs."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamStreamCall:
        """Start a stream-stream RPC and return its call object.

        With interceptors configured, an ``InterceptedStreamStreamCall`` is
        returned and receives the relative ``timeout``. Otherwise a plain
        ``StreamStreamCall`` is returned and receives the absolute deadline
        computed by ``_timeout_to_deadline``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        if self._interceptors:
            return InterceptedStreamStreamCall(
                self._interceptors,
                request_iterator,
                timeout,
                call_metadata,
                credentials,
                wait_for_ready,
                self._channel,
                self._method,
                self._request_serializer,
                self._response_deserializer,
                self._loop,
            )
        return StreamStreamCall(
            request_iterator,
            _timeout_to_deadline(timeout),
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )

314 

315 

class Channel(_base_channel.Channel):
    """Asyncio client channel backed by a ``cygrpc.AioChannel``.

    Interceptors given to the constructor are grouped by interceptor type and
    handed to the matching kind of multi-callable.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(
        self,
        target: str,
        options: ChannelArgumentType,
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
        interceptors: Optional[Sequence[ClientInterceptor]],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            supported client interceptor types.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            # An interceptor implementing several interfaces is registered only
            # for the first matching type in this if/elif chain.
            for interceptor in interceptors:
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor)
                        + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
                        + "{} or ".format(UnaryStreamClientInterceptor.__name__)
                        + "{} or ".format(StreamUnaryClientInterceptor.__name__)
                        + "{}. ".format(StreamStreamClientInterceptor.__name__)
                    )

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression),
            credentials,
            self._loop,
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # No grace period: in-flight calls on this channel are cancelled.
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Stop accepting calls, cancel this channel's calls, then close it.

        Args:
          grace: If truthy, wait up to this many seconds for the tasks running
            this channel's calls to finish before cancelling the calls.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if "frame" in str(attribute_error):
                    continue
                else:
                    raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`: the top frame's `self` is the
            # call object when the task is running one of its methods.
            frame = stack[0]
            candidate = frame.f_locals.get("self")
            # Explicitly check for a non-null candidate instead of the more pythonic 'if candidate:'
            # because doing 'if candidate:' assumes that the coroutine implements '__bool__' which
            # might not always be the case.
            if candidate is not None:
                if isinstance(candidate, _base_call.Call):
                    # Skip calls that belong to a different channel.
                    if hasattr(candidate, "_channel"):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, "_cython_call"):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        raise cygrpc.InternalError(
                            f"Unrecognized call object: {candidate}"
                        )

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables, so wait on their tasks.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Close the channel.

        Args:
          grace: Optional number of seconds to wait for in-flight calls on this
            channel to finish before they are cancelled. None or 0 cancels
            them immediately.
        """
        await self._close(grace)

    def __del__(self):
        # `_channel` may be missing if __init__ raised before assigning it
        # (e.g. the ValueError for an unsupported interceptor).
        if hasattr(self, "_channel"):
            if not self._channel.closed():
                self._channel.close()

    def get_state(
        self, try_to_connect: bool = False
    ) -> grpc.ChannelConnectivity:
        """Return the channel's current connectivity state.

        Args:
          try_to_connect: Forwarded to the Cython channel's
            ``check_connectivity_state``.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Wait for the connectivity state to move away from a given state.

        Args:
          last_observed_state: The state last seen by the caller; its integer
            code (``value[0]``) is passed to the Cython watcher with no
            deadline.
        """
        # Keep the await outside the assert: assert statements are stripped
        # under `python -O`, which would skip the watch entirely and make
        # channel_ready() spin without ever yielding to the event loop.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None
        )
        assert state_changed

    async def channel_ready(self) -> None:
        """Wait until the channel reaches READY, asking it to connect."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    # TODO(xuanwn): Implement this method after we have
    # observability for Asyncio.
    def _get_registered_call_handle(self, method: str) -> int:
        # Currently a stub: implicitly returns None despite the annotation.
        pass

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryUnaryMultiCallable:
        """Create a UnaryUnaryMultiCallable for ``method`` on this channel."""
        return UnaryUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryStreamMultiCallable:
        """Create a UnaryStreamMultiCallable for ``method`` on this channel."""
        return UnaryStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_stream_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamUnaryMultiCallable:
        """Create a StreamUnaryMultiCallable for ``method`` on this channel."""
        return StreamUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamStreamMultiCallable:
        """Create a StreamStreamMultiCallable for ``method`` on this channel."""
        return StreamStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_stream_interceptors,
            [self],
            self._loop,
        )

568 

569 

def insecure_channel(
    target: str,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. None is treated as
        an empty tuple.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      A Channel created without credentials.
    """
    channel_options = options if options is not None else ()
    return Channel(target, channel_options, None, compression, interceptors)

597 

598 

def secure_channel(
    target: str,
    credentials: grpc.ChannelCredentials,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance; its underlying
        ``_credentials`` object is what the Channel receives.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. None is treated as
        an empty tuple.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    channel_options = () if options is None else options
    core_credentials = credentials._credentials
    return Channel(
        target, channel_options, core_credentials, compression, interceptors
    )