Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/grpc/aio/_channel.py: 42%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

191 statements  

1# Copyright 2019 gRPC authors. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14"""Invocation-side implementation of gRPC Asyncio Python.""" 

15 

16import asyncio 

17import sys 

18from typing import Any, Iterable, List, Optional, Sequence 

19 

20import grpc 

21from grpc import _common 

22from grpc import _compression 

23from grpc import _grpcio_metadata 

24from grpc._cython import cygrpc 

25 

26from . import _base_call 

27from . import _base_channel 

28from ._call import StreamStreamCall 

29from ._call import StreamUnaryCall 

30from ._call import UnaryStreamCall 

31from ._call import UnaryUnaryCall 

32from ._interceptor import ClientInterceptor 

33from ._interceptor import InterceptedStreamStreamCall 

34from ._interceptor import InterceptedStreamUnaryCall 

35from ._interceptor import InterceptedUnaryStreamCall 

36from ._interceptor import InterceptedUnaryUnaryCall 

37from ._interceptor import StreamStreamClientInterceptor 

38from ._interceptor import StreamUnaryClientInterceptor 

39from ._interceptor import UnaryStreamClientInterceptor 

40from ._interceptor import UnaryUnaryClientInterceptor 

41from ._metadata import Metadata 

42from ._typing import ChannelArgumentType 

43from ._typing import DeserializingFunction 

44from ._typing import MetadataType 

45from ._typing import RequestIterableType 

46from ._typing import RequestType 

47from ._typing import ResponseType 

48from ._typing import SerializingFunction 

49from ._utils import _timeout_to_deadline 

50 

# Value of the primary user-agent channel argument attached to every channel
# (see _augment_channel_arguments); embeds the installed grpcio version.
_USER_AGENT = f"grpc-python-asyncio/{_grpcio_metadata.__version__}"

52 

53if sys.version_info[1] < 7: 

54 

55 def _all_tasks() -> Iterable[asyncio.Task]: 

56 return asyncio.Task.all_tasks() # pylint: disable=no-member 

57 

58else: 

59 

60 def _all_tasks() -> Iterable[asyncio.Task]: 

61 return asyncio.all_tasks() 

62 

63 

def _augment_channel_arguments(
    base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
):
    """Append the gRPC-managed channel arguments to the caller's options.

    Args:
      base_options: Channel arguments supplied by the caller.
      compression: Optional channel-wide compression; turned into channel
        option(s) by ``_compression.create_channel_option``.

    Returns:
      A tuple made of ``base_options``, then the compression option(s), then
      the primary user-agent argument carrying ``_USER_AGENT``.
    """
    trailing_options = _compression.create_channel_option(compression) + (
        (cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),
    )
    return tuple(base_options) + trailing_options

81 

82 

class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores the attributes shared by the
    unary-unary, unary-stream, stream-unary and stream-stream multi-callables.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: Optional[SerializingFunction]
    _response_deserializer: Optional[DeserializingFunction]
    _interceptors: Optional[Sequence[ClientInterceptor]]
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store the per-method configuration used when creating calls.

        Args:
          channel: The Cython channel the calls are issued on.
          method: The method name, already encoded to bytes.
          request_serializer: Optional function used to serialize requests.
          response_deserializer: Optional function used to deserialize
            responses.
          interceptors: Client interceptors to apply. An empty or ``None``
            value means calls are created without interception.
          references: Objects this callable keeps a reference to. The Channel
            factory methods pass ``[channel]``, presumably so the Python
            Channel (whose ``__del__`` closes the Cython channel) stays alive
            while this callable exists.
          loop: The event loop the calls run on.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
        metadata: Optional[MetadataType] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Metadata:
        """Build the final metadata to be used for the current call.

        Args:
          metadata: Caller-supplied metadata. A falsy value becomes an empty
            ``Metadata``. A non-``Metadata`` ``Sequence`` is converted with
            ``Metadata.from_tuple``.
          compression: Optional per-call compression. When set, the metadata is
            augmented through ``_compression.augment_metadata``.

        Returns:
          The metadata to send with the call.
        """
        metadata = metadata or Metadata()
        if not isinstance(metadata, Metadata) and isinstance(
            metadata, Sequence
        ):
            metadata = Metadata.from_tuple(tuple(metadata))
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression)
            )
        return metadata

135 

136 

class UnaryUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
):
    """Asyncio multi-callable for unary-request, unary-response methods."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
        """Start a unary-unary call and return its call object.

        With interceptors configured, an ``InterceptedUnaryUnaryCall`` is
        returned; otherwise a plain ``UnaryUnaryCall``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        trailing_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            # The intercepted call is handed the timeout unconverted.
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, *trailing_args
            )
        # The plain call receives the deadline derived from the timeout.
        return UnaryUnaryCall(
            request, _timeout_to_deadline(timeout), *trailing_args
        )

180 

181 

class UnaryStreamMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
):
    """Asyncio multi-callable for unary-request, streaming-response methods."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
        """Start a unary-stream call and return its call object.

        With interceptors configured, an ``InterceptedUnaryStreamCall`` is
        returned; otherwise a plain ``UnaryStreamCall``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        trailing_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            # The intercepted call is handed the timeout unconverted.
            return InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, *trailing_args
            )
        # The plain call receives the deadline derived from the timeout.
        return UnaryStreamCall(
            request, _timeout_to_deadline(timeout), *trailing_args
        )

226 

227 

class StreamUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
):
    """Asyncio multi-callable for streaming-request, unary-response methods."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamUnaryCall:
        """Start a stream-unary call and return its call object.

        With interceptors configured, an ``InterceptedStreamUnaryCall`` is
        returned; otherwise a plain ``StreamUnaryCall``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        trailing_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            # The intercepted call is handed the timeout unconverted.
            return InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, timeout, *trailing_args
            )
        # The plain call receives the deadline derived from the timeout.
        return StreamUnaryCall(
            request_iterator, _timeout_to_deadline(timeout), *trailing_args
        )

271 

272 

class StreamStreamMultiCallable(
    _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
):
    """Asyncio multi-callable for streaming-request, streaming-response methods."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamStreamCall:
        """Start a stream-stream call and return its call object.

        With interceptors configured, an ``InterceptedStreamStreamCall`` is
        returned; otherwise a plain ``StreamStreamCall``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        trailing_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            # The intercepted call is handed the timeout unconverted.
            return InterceptedStreamStreamCall(
                self._interceptors, request_iterator, timeout, *trailing_args
            )
        # The plain call receives the deadline derived from the timeout.
        return StreamStreamCall(
            request_iterator, _timeout_to_deadline(timeout), *trailing_args
        )

316 

317 

class Channel(_base_channel.Channel):
    """Asyncio implementation of a gRPC client channel.

    Wraps a ``cygrpc.AioChannel`` and hands out multi-callables for the four
    RPC shapes. Usable as an async context manager; exiting the context closes
    the channel with no grace period.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(
        self,
        target: str,
        options: ChannelArgumentType,
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
        interceptors: Optional[Sequence[ClientInterceptor]],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel. Each interceptor
            is registered only under the first client interceptor type it is
            an instance of, checked in the order unary-unary, unary-stream,
            stream-unary, stream-stream.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            client interceptor types.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor)
                        + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
                        + "{} or ".format(UnaryStreamClientInterceptor.__name__)
                        + "{} or ".format(StreamUnaryClientInterceptor.__name__)
                        + "{}. ".format(StreamStreamClientInterceptor.__name__)
                    )

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression),
            credentials,
            self._loop,
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Leaving the context closes the channel without a grace period.
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Close the channel, optionally letting in-flight calls finish first.

        Marks the Cython channel as closing so no new calls are accepted. It
        then finds the calls belonging to this channel by inspecting the top
        stack frame of every task returned by ``_all_tasks()``. When ``grace``
        is truthy, it waits up to ``grace`` seconds for those calls' tasks.
        Every found call is then cancelled and the Cython channel is closed.
        Returns immediately if the channel is already closed.

        Raises:
          cygrpc.InternalError: If a task's frame holds a ``_base_call.Call``
            that has neither a ``_channel`` nor a ``_cython_call`` attribute.
          AttributeError: Re-raised from ``Task.get_stack`` unless the message
            mentions "frame" (see the note below).
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if "frame" in str(attribute_error):
                    continue
                else:
                    raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get("self")
            # Explicitly check for a non-null candidate instead of the more pythonic 'if candidate:'
            # because doing 'if candidate:' assumes that the coroutine implements '__bool__' which
            # might not always be the case.
            if candidate is not None:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, "_channel"):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, "_cython_call"):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        error_msg = f"Unrecognized call object: {candidate}"
                        raise cygrpc.InternalError(error_msg)

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Close the channel.

        Args:
          grace: Seconds to wait for this channel's active calls before
            cancelling them. ``None`` (the default) or ``0`` cancels them
            without waiting.
        """
        await self._close(grace)

    def __del__(self):
        # ``_channel`` may be missing if ``__init__`` raised before creating it
        # (e.g. the ValueError for an unsupported interceptor).
        if hasattr(self, "_channel"):
            if not self._channel.closed():
                self._channel.close()

    def get_state(
        self, try_to_connect: bool = False
    ) -> grpc.ChannelConnectivity:
        """Return the channel's current connectivity state.

        Args:
          try_to_connect: Forwarded to the Cython channel's
            ``check_connectivity_state``; presumably asks an idle channel to
            start connecting.
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Wait until the connectivity state differs from the given one.

        Args:
          last_observed_state: The connectivity state the caller last
            observed, e.g. as returned by ``get_state``.
        """
        # Keep the await OUTSIDE the ``assert``: ``python -O`` strips assert
        # statements entirely, which would also drop the wait itself and turn
        # ``channel_ready`` into a busy loop.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None
        )
        assert state_changed

    async def channel_ready(self) -> None:
        """Wait until the channel reaches READY, requesting a connection each time."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    # TODO(xuanwn): Implement this method after we have
    # observability for Asyncio.
    def _get_registered_call_handle(self, method: str) -> int:
        # NOTE: placeholder - currently returns None despite the annotation.
        pass

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryUnaryMultiCallable:
        """Create a multi-callable for a unary-unary method.

        Args:
          method: The method name; encoded with ``_common.encode``.
          request_serializer: Optional request serializer.
          response_deserializer: Optional response deserializer.
          _registered_method: Currently ignored (see TODO above).

        Returns:
          A UnaryUnaryMultiCallable using this channel's unary-unary
          interceptors.
        """
        return UnaryUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryStreamMultiCallable:
        """Create a multi-callable for a unary-stream method.

        Args:
          method: The method name; encoded with ``_common.encode``.
          request_serializer: Optional request serializer.
          response_deserializer: Optional response deserializer.
          _registered_method: Currently ignored (see TODO above).

        Returns:
          A UnaryStreamMultiCallable using this channel's unary-stream
          interceptors.
        """
        return UnaryStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_stream_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamUnaryMultiCallable:
        """Create a multi-callable for a stream-unary method.

        Args:
          method: The method name; encoded with ``_common.encode``.
          request_serializer: Optional request serializer.
          response_deserializer: Optional response deserializer.
          _registered_method: Currently ignored (see TODO above).

        Returns:
          A StreamUnaryMultiCallable using this channel's stream-unary
          interceptors.
        """
        return StreamUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamStreamMultiCallable:
        """Create a multi-callable for a stream-stream method.

        Args:
          method: The method name; encoded with ``_common.encode``.
          request_serializer: Optional request serializer.
          response_deserializer: Optional response deserializer.
          _registered_method: Currently ignored (see TODO above).

        Returns:
          A StreamStreamMultiCallable using this channel's stream-stream
          interceptors.
        """
        return StreamStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_stream_interceptors,
            [self],
            self._loop,
        )

569 

570 

def insecure_channel(
    target: str,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. ``None`` is treated
        as no options.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      A Channel created without channel credentials.
    """
    channel_options = options if options is not None else ()
    return Channel(
        target,
        channel_options,
        None,  # no channel credentials
        compression,
        interceptors,
    )

598 

599 

def secure_channel(
    target: str,
    credentials: grpc.ChannelCredentials,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance; its underlying
        ``_credentials`` object is what the Channel receives.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel. ``None`` is treated
        as no options.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    channel_options = options if options is not None else ()
    core_credentials = credentials._credentials
    return Channel(
        target,
        channel_options,
        core_credentials,
        compression,
        interceptors,
    )