Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/grpc/aio/_channel.py: 42%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

189 statements  

1# Copyright 2019 gRPC authors. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14"""Invocation-side implementation of gRPC Asyncio Python.""" 

15 

16import asyncio 

17import sys 

18from typing import Any, Iterable, List, Optional, Sequence 

19 

20import grpc 

21from grpc import _common 

22from grpc import _compression 

23from grpc import _grpcio_metadata 

24from grpc._cython import cygrpc 

25 

26from . import _base_call 

27from . import _base_channel 

28from ._call import StreamStreamCall 

29from ._call import StreamUnaryCall 

30from ._call import UnaryStreamCall 

31from ._call import UnaryUnaryCall 

32from ._interceptor import ClientInterceptor 

33from ._interceptor import InterceptedStreamStreamCall 

34from ._interceptor import InterceptedStreamUnaryCall 

35from ._interceptor import InterceptedUnaryStreamCall 

36from ._interceptor import InterceptedUnaryUnaryCall 

37from ._interceptor import StreamStreamClientInterceptor 

38from ._interceptor import StreamUnaryClientInterceptor 

39from ._interceptor import UnaryStreamClientInterceptor 

40from ._interceptor import UnaryUnaryClientInterceptor 

41from ._metadata import Metadata 

42from ._typing import ChannelArgumentType 

43from ._typing import DeserializingFunction 

44from ._typing import MetadataType 

45from ._typing import RequestIterableType 

46from ._typing import RequestType 

47from ._typing import ResponseType 

48from ._typing import SerializingFunction 

49from ._utils import _timeout_to_deadline 

50 

51_USER_AGENT = "grpc-python-asyncio/{}".format(_grpcio_metadata.__version__) 

52 

# Compare the whole version tuple rather than only the minor component, so the
# legacy branch is selected strictly for interpreters older than 3.7 and never
# for a later major version whose minor number happens to be below 7.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return the tasks reported by ``asyncio.Task.all_tasks()``.

        Used on interpreters older than 3.7, which lack the module-level
        ``asyncio.all_tasks`` helper.
        """
        return asyncio.Task.all_tasks()  # pylint: disable=no-member

else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return the tasks reported by ``asyncio.all_tasks()``."""
        return asyncio.all_tasks()

62 

63 

def _augment_channel_arguments(
    base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
):
    """Extend the caller's channel options with library-supplied arguments.

    Args:
      base_options: Channel arguments supplied by the caller.
      compression: Optional compression setting, handed to
        ``_compression.create_channel_option``.

    Returns:
      A tuple made of ``base_options``, followed by whatever
      ``_compression.create_channel_option`` produced, followed by the
      primary user-agent argument carrying ``_USER_AGENT``.
    """
    compression_entries = _compression.create_channel_option(compression)
    user_agent_entry = (
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    )
    return tuple(base_options) + compression_entries + (user_agent_entry,)

81 

82 

class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """

    # Each attribute is declared exactly once; the values are assigned in
    # __init__ from the constructor arguments of the same name.
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: Optional[SerializingFunction]
    _response_deserializer: Optional[DeserializingFunction]
    _interceptors: Optional[Sequence[ClientInterceptor]]
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: Optional[SerializingFunction],
        response_deserializer: Optional[DeserializingFunction],
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store the arguments on the instance without further processing.

        Args:
          channel: The Cython channel handed to every call that is created.
          method: The encoded method name.
          request_serializer: Optional serializer forwarded to calls.
          response_deserializer: Optional deserializer forwarded to calls.
          interceptors: Optional interceptors; when this is falsy the
            subclasses build plain (non-intercepted) calls.
          references: Objects stored on the instance; presumably held only
            to keep them alive -- confirm against the callers.
          loop: The event loop forwarded to calls.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
        metadata: Optional[MetadataType] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Metadata:
        """Based on the provided values for <metadata> or <compression> initialise the final
        metadata, as it should be used for the current call.

        Args:
          metadata: Optional caller metadata. A falsy value (``None`` or an
            empty container) is replaced by an empty ``Metadata``.
          compression: Optional compression setting. When truthy the metadata
            is rebuilt from ``_compression.augment_metadata``.

        Returns:
          The metadata to use for the call.
        """
        metadata = metadata or Metadata()
        # Plain sequences (for example tuples of key/value pairs) are
        # converted; values that are neither Metadata nor a Sequence are
        # passed through unchanged.
        if not isinstance(metadata, Metadata) and isinstance(
            metadata, Sequence
        ):
            metadata = Metadata.from_tuple(tuple(metadata))
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression)
            )
        return metadata

135 

136 

class UnaryUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
):
    # Builds the call object for a unary-request / unary-response RPC.
    # `#` comments are used instead of docstrings so that documentation
    # presumably inherited from the base class is not shadowed.
    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
        call_metadata = self._init_metadata(metadata, compression)

        # Positional arguments shared, in this order, by both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )

        if self._interceptors:
            # The intercepted call receives the raw timeout value.
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, *shared_args
            )

        # The plain call receives the deadline derived from the timeout.
        return UnaryUnaryCall(
            request, _timeout_to_deadline(timeout), *shared_args
        )

180 

181 

class UnaryStreamMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
):
    # Builds the call object for a unary-request / streaming-response RPC.
    # `#` comments are used instead of docstrings so that documentation
    # presumably inherited from the base class is not shadowed.
    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
        call_metadata = self._init_metadata(metadata, compression)

        # Positional arguments shared, in this order, by both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )

        if self._interceptors:
            # The intercepted call receives the raw timeout value.
            return InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, *shared_args
            )

        # The plain call receives the deadline derived from the timeout.
        return UnaryStreamCall(
            request, _timeout_to_deadline(timeout), *shared_args
        )

226 

227 

class StreamUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
):
    # Builds the call object for a streaming-request / unary-response RPC.
    # `#` comments are used instead of docstrings so that documentation
    # presumably inherited from the base class is not shadowed.
    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamUnaryCall:
        call_metadata = self._init_metadata(metadata, compression)

        # Positional arguments shared, in this order, by both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )

        if self._interceptors:
            # The intercepted call receives the raw timeout value.
            return InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, timeout, *shared_args
            )

        # The plain call receives the deadline derived from the timeout.
        return StreamUnaryCall(
            request_iterator, _timeout_to_deadline(timeout), *shared_args
        )

271 

272 

class StreamStreamMultiCallable(
    _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
):
    # Builds the call object for a streaming-request / streaming-response RPC.
    # `#` comments are used instead of docstrings so that documentation
    # presumably inherited from the base class is not shadowed.
    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamStreamCall:
        call_metadata = self._init_metadata(metadata, compression)

        # Positional arguments shared, in this order, by both call flavours.
        shared_args = (
            call_metadata,
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )

        if self._interceptors:
            # The intercepted call receives the raw timeout value.
            return InterceptedStreamStreamCall(
                self._interceptors, request_iterator, timeout, *shared_args
            )

        # The plain call receives the deadline derived from the timeout.
        return StreamStreamCall(
            request_iterator, _timeout_to_deadline(timeout), *shared_args
        )

316 

317 

class Channel(_base_channel.Channel):
    # Asyncio client channel: wraps a cygrpc.AioChannel, keeps the
    # interceptors grouped by RPC arity, and hands out multi-callables bound
    # to the wrapped channel and its event loop.
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(
        self,
        target: str,
        options: ChannelArgumentType,
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
        interceptors: Optional[Sequence[ClientInterceptor]],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            supported interceptor classes.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            # Each interceptor is filed under the FIRST matching class only.
            for interceptor in interceptors:
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(  # noqa: TRY004
                        "Interceptor {} must be ".format(interceptor)
                        + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
                        + "{} or ".format(UnaryStreamClientInterceptor.__name__)
                        + "{} or ".format(StreamUnaryClientInterceptor.__name__)
                        + "{}. ".format(StreamStreamClientInterceptor.__name__)
                    )

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression),
            credentials,
            self._loop,
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Leaving the `async with` block closes without a grace period.
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Stop accepting calls, cancel this channel's calls and close it.

        Does nothing when the underlying channel is already closed.

        Args:
          grace: When truthy, used as the timeout of ``asyncio.wait`` on the
            tasks running this channel's calls before they are cancelled.

        Raises:
          cygrpc.InternalError: If a task's ``self`` is a Call object that
            exposes neither ``_channel`` nor ``_cython_call``.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz): drop this hack after 3.8 deprecation
                if "frame" in str(attribute_error):
                    continue
                raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get("self")
            # Explicitly check for a non-null candidate instead of the more pythonic 'if candidate:'
            # because doing 'if candidate:' assumes that the coroutine implements '__bool__' which
            # might not always be the case.
            if candidate is not None and isinstance(candidate, _base_call.Call):
                if hasattr(candidate, "_channel"):
                    # For intercepted Call object
                    if candidate._channel is not self._channel:
                        continue
                elif hasattr(candidate, "_cython_call"):
                    # For normal Call object
                    if candidate._cython_call._channel is not self._channel:
                        continue
                else:
                    # Unidentified Call object
                    error_msg = f"Unrecognized call object: {candidate}"
                    raise cygrpc.InternalError(error_msg)

                calls.append(candidate)
                call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        await self._close(grace)

    def __del__(self):
        # `_channel` may be missing when __init__ raised before assigning it.
        if hasattr(self, "_channel") and not self._channel.closed():
            self._channel.close()

    def get_state(
        self, try_to_connect: bool = False
    ) -> grpc.ChannelConnectivity:
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        # Await OUTSIDE the assert: `python -O` strips assert statements, and
        # with the await inside the assert the wait itself would be skipped,
        # turning channel_ready() into a loop that never yields to the watch.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None
        )
        assert state_changed

    async def channel_ready(self) -> None:
        # Keep requesting a connection and waiting for a state change until
        # the channel reports READY.
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    # TODO(xuanwn): Implement this method after we have
    # observability for Asyncio.
    def _get_registered_call_handle(self, method: str) -> int:
        pass

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryUnaryMultiCallable:
        return UnaryUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> UnaryStreamMultiCallable:
        return UnaryStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_stream_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamUnaryMultiCallable:
        return StreamUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_unary_interceptors,
            [self],
            self._loop,
        )

    # TODO(xuanwn): Implement _registered_method after we have
    # observability for Asyncio.
    # pylint: disable=arguments-differ,unused-argument
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> StreamStreamMultiCallable:
        return StreamStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_stream_interceptors,
            [self],
            self._loop,
        )

566 

567 

def insecure_channel(
    target: str,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Build an asynchronous Channel that uses no channel credentials.

    Args:
      target: Address of the server.
      options: Optional key-value pairs (:term:`channel_arguments` in gRPC
        Core runtime) configuring the channel; ``None`` is treated as an
        empty tuple.
      compression: Optional compression method applied for the lifetime of
        the channel.
      interceptors: Optional sequence of interceptors run for every call made
        through the channel.

    Returns:
      A Channel.
    """
    channel_options = options if options is not None else ()
    return Channel(target, channel_options, None, compression, interceptors)

595 

596 

def secure_channel(
    target: str,
    credentials: grpc.ChannelCredentials,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Build an asynchronous Channel secured with the given credentials.

    Args:
      target: Address of the server.
      credentials: A ChannelCredentials instance; its ``_credentials``
        attribute is what gets passed to the Channel.
      options: Optional key-value pairs (:term:`channel_arguments` in gRPC
        Core runtime) configuring the channel; ``None`` is treated as an
        empty tuple.
      compression: Optional compression method applied for the lifetime of
        the channel.
      interceptors: Optional sequence of interceptors run for every call made
        through the channel.

    Returns:
      An aio.Channel.
    """
    channel_options = options if options is not None else ()
    return Channel(
        target,
        channel_options,
        credentials._credentials,
        compression,
        interceptors,
    )