Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/zmq/_future.py: 19%

395 statements  

« prev     ^ index     » next       coverage.py v7.3.3, created at 2023-12-15 06:13 +0000

1"""Future-returning APIs for coroutines.""" 

2 

3# Copyright (c) PyZMQ Developers. 

4# Distributed under the terms of the Modified BSD License. 

5 

6import warnings 

7from asyncio import Future 

8from collections import deque 

9from itertools import chain 

10from typing import ( 

11 Any, 

12 Awaitable, 

13 Callable, 

14 Dict, 

15 List, 

16 NamedTuple, 

17 Optional, 

18 Tuple, 

19 Type, 

20 TypeVar, 

21 Union, 

22 cast, 

23 overload, 

24) 

25 

26import zmq as _zmq 

27from zmq import EVENTS, POLLIN, POLLOUT 

28from zmq._typing import Literal 

29 

30 

31class _FutureEvent(NamedTuple): 

32 future: Future 

33 kind: str 

34 kwargs: Dict 

35 msg: Any 

36 timer: Any 

37 

38 

39# These are incomplete classes and need a Mixin for compatibility with an eventloop 

40# defining the following attributes: 

41# 

42# _Future 

43# _READ 

44# _WRITE 

45# _default_loop() 

46 

47 

48class _Async: 

49 """Mixin for common async logic""" 

50 

51 _current_loop: Any = None 

52 _Future: Type[Future] 

53 

54 def _get_loop(self) -> Any: 

55 """Get event loop 

56 

57 Notice if event loop has changed, 

58 and register init_io_state on activation of a new event loop 

59 """ 

60 if self._current_loop is None: 

61 self._current_loop = self._default_loop() 

62 self._init_io_state(self._current_loop) 

63 return self._current_loop 

64 current_loop = self._default_loop() 

65 if current_loop is not self._current_loop: 

66 # warn? This means a socket is being used in multiple loops! 

67 self._current_loop = current_loop 

68 self._init_io_state(current_loop) 

69 return current_loop 

70 

71 def _default_loop(self) -> Any: 

72 raise NotImplementedError("Must be implemented in a subclass") 

73 

74 def _init_io_state(self, loop=None) -> None: 

75 pass 

76 

77 

class _AsyncPoller(_Async, _zmq.Poller):
    """Poller whose ``poll`` hands back a Future rather than blocking."""

    _socket_class: Type["_AsyncSocket"]
    _READ: int
    _WRITE: int
    raw_sockets: List[Any]

    def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None:
        """Arrange for ``f`` to be called on ``evt`` for a non-zmq socket.

        Event-loop specific; must be overridden.
        """
        raise NotImplementedError()

    def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None:
        """Undo ``_watch_raw_socket`` for the given sockets.

        Event-loop specific; must be overridden.
        """
        raise NotImplementedError()

    def poll(self, timeout=-1) -> Awaitable[List[Tuple[Any, int]]]:  # type: ignore
        """Return a Future for the result of polling the registered sockets.

        ``timeout`` is in milliseconds (it is scaled by 1e-3 for
        ``loop.call_later``).  ``0`` performs a single immediate poll; a
        positive value also resolves the Future once that delay has passed;
        for ``None`` or a negative value no timer is scheduled.
        """
        result_future = self._Future()
        if timeout == 0:
            # nothing to wait for: one non-blocking poll, right now
            try:
                ready = super().poll(0)
            except Exception as exc:
                result_future.set_exception(exc)
            else:
                result_future.set_result(ready)
            return result_future

        loop = self._get_loop()

        # resolved as soon as any registered socket reports an event
        watcher = self._Future()

        # non-zmq sockets handed to the event loop for this poll
        watched_raw: List[Any] = []

        def wake(*args):
            # shared by raw-socket callbacks and the timeout timer
            if not watcher.done():
                watcher.set_result(None)

        watcher.add_done_callback(
            lambda _f: self._unwatch_raw_sockets(loop, *watched_raw)
        )

        for sock, mask in self.sockets:
            if not isinstance(sock, _zmq.Socket):
                # raw socket / file descriptor: let the event loop watch it
                watched_raw.append(sock)
                loop_events = 0
                if mask & _zmq.POLLIN:
                    loop_events |= self._READ
                if mask & _zmq.POLLOUT:
                    loop_events |= self._WRITE
                self._watch_raw_socket(loop, sock, loop_events, wake)
                continue

            if not isinstance(sock, self._socket_class):
                # a blocking zmq.Socket: wrap it in the async socket class
                sock = self._socket_class.from_socket(sock)
            if mask & _zmq.POLLIN:
                sock._add_recv_event('poll', future=watcher)
            if mask & _zmq.POLLOUT:
                sock._add_send_event('poll', future=watcher)

        def on_poll_ready(_f):
            if result_future.done():
                return
            if watcher.cancelled():
                try:
                    result_future.cancel()
                except RuntimeError:
                    # RuntimeError may be raised during teardown
                    pass
                return
            error = watcher.exception()
            if error:
                result_future.set_exception(error)
                return
            # the watcher only signals readiness; collect the actual events
            try:
                ready = super(_AsyncPoller, self).poll(0)
            except Exception as exc:
                result_future.set_exception(exc)
            else:
                result_future.set_result(ready)

        watcher.add_done_callback(on_poll_ready)

        if timeout is not None and timeout > 0:
            # resolve the watcher when the poll timeout expires
            timeout_handle = loop.call_later(1e-3 * timeout, wake)

            def cancel_timeout(_f):
                if hasattr(timeout_handle, 'cancel'):
                    timeout_handle.cancel()
                else:
                    loop.remove_timeout(timeout_handle)

            result_future.add_done_callback(cancel_timeout)

        def cancel_watcher(_f):
            if not watcher.done():
                watcher.cancel()

        result_future.add_done_callback(cancel_watcher)

        return result_future

185 

186 

187class _NoTimer: 

188 @staticmethod 

189 def cancel(): 

190 pass 

191 

192 

193T = TypeVar("T", bound="_AsyncSocket") 

194 

195 

196class _AsyncSocket(_Async, _zmq.Socket[Future]): 

197 # Warning: these class variables exist only so that super().__setattr__ can be called.

198 # They are overridden at instance initialization and are not shared across the whole class.

199 _recv_futures = None 

200 _send_futures = None 

201 _state = 0 

202 _shadow_sock: "_zmq.Socket" 

203 _poller_class = _AsyncPoller 

204 _fd = None 

205 

def __init__(
    self,
    context=None,
    socket_type=-1,
    io_loop=None,
    _from_socket: Optional["_zmq.Socket"] = None,
    **kwargs,
) -> None:
    """Create an async socket, either new or shadowing an existing socket.

    A ``zmq.Socket`` given as ``context`` is treated like ``_from_socket``.
    ``io_loop`` is deprecated: it only triggers a DeprecationWarning.
    """
    if isinstance(context, _zmq.Socket):
        # a socket in the context slot means "wrap this socket"
        _from_socket = context
        context = None

    if _from_socket is None:
        super().__init__(context, socket_type, **kwargs)  # type: ignore
        self._shadow_sock = _zmq.Socket.shadow(self.underlying)
    else:
        super().__init__(shadow=_from_socket.underlying)  # type: ignore
        self._shadow_sock = _from_socket

    if io_loop is not None:
        warnings.warn(
            f"{self.__class__.__name__}(io_loop) argument is deprecated in pyzmq 22.2."
            " The currently active loop will always be used.",
            DeprecationWarning,
            stacklevel=3,
        )

    # per-instance queues of pending _FutureEvent records
    self._recv_futures = deque()
    self._send_futures = deque()
    # nothing is being watched yet
    self._state = 0
    self._fd = self._shadow_sock.FD

234 

@classmethod
def from_socket(cls: Type[T], socket: "_zmq.Socket", io_loop: Any = None) -> T:
    """Build an instance of this class around an already existing Socket."""
    wrapped = cls(_from_socket=socket, io_loop=io_loop)
    return wrapped

239 

def close(self, linger: Optional[int] = None) -> None:
    # NOTE: no docstring here, it is copied from zmq.Socket.close below
    needs_cleanup = not self.closed and self._fd is not None
    if needs_cleanup:
        # snapshot both queues: cancelling a future mutates them via callbacks
        pending: List[_FutureEvent] = [
            *(self._recv_futures or []),
            *(self._send_futures or []),
        ]
        for event in pending:
            if event.future.done():
                continue
            try:
                event.future.cancel()
            except RuntimeError:
                # RuntimeError may be raised during teardown
                pass
        self._clear_io_state()
    super().close(linger=linger)

close.__doc__ = _zmq.Socket.close.__doc__

256 

def get(self, key):
    # NOTE: no docstring here, it is copied from zmq.Socket.get below
    value = super().get(key)
    if key == EVENTS:
        # hand over the events just read so they are not fetched a second time
        self._schedule_remaining_events(value)
    return value

get.__doc__ = _zmq.Socket.get.__doc__

264 

@overload  # type: ignore
def recv_multipart(
    self, flags: int = 0, *, track: bool = False
) -> Awaitable[List[bytes]]:
    ...

@overload
def recv_multipart(
    self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[List[bytes]]:
    ...

@overload
def recv_multipart(
    self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[List[_zmq.Frame]]:  # type: ignore
    ...

@overload
def recv_multipart(
    self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
    ...

def recv_multipart(
    self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
    """Receive all frames of one multipart zmq message.

    Returns a Future that resolves to the multipart message.
    """
    options = {'flags': flags, 'copy': copy, 'track': track}
    return self._add_recv_event('recv_multipart', options)

299 

def recv(  # type: ignore
    self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[Union[bytes, _zmq.Frame]]:
    """Receive one zmq frame.

    Returns a Future that resolves to the received frame.

    Using recv_multipart instead is recommended.
    """
    options = {'flags': flags, 'copy': copy, 'track': track}
    return self._add_recv_event('recv', options)

310 

def send_multipart(  # type: ignore
    self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs
) -> Awaitable[Optional[_zmq.MessageTracker]]:
    """Send all frames of one multipart zmq message.

    Returns a Future that resolves once the message has been sent.
    """
    # the explicit parameters win over same-named entries in **kwargs
    kwargs.update(flags=flags, copy=copy, track=track)
    return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs)

322 

def send(  # type: ignore
    self,
    data: Any,
    flags: int = 0,
    copy: bool = True,
    track: bool = False,
    **kwargs: Any,
) -> Awaitable[Optional[_zmq.MessageTracker]]:
    """Send a single zmq frame.

    Returns a Future that resolves when sending is complete.

    Recommend using send_multipart instead.
    """
    # flags/copy/track travel with any extra keyword arguments in one dict;
    # a single update replaces the former duplicated assignments
    kwargs.update(flags=flags, copy=copy, track=track)
    return self._add_send_event('send', msg=data, kwargs=kwargs)

342 

343 def _deserialize(self, recvd, load): 

344 """Deserialize with Futures""" 

345 f = self._Future() 

346 

347 def _chain(_): 

348 """Chain result through serialization to recvd""" 

349 if f.done(): 

350 return 

351 if recvd.exception(): 

352 f.set_exception(recvd.exception()) 

353 else: 

354 buf = recvd.result() 

355 try: 

356 loaded = load(buf) 

357 except Exception as e: 

358 f.set_exception(e) 

359 else: 

360 f.set_result(loaded) 

361 

362 recvd.add_done_callback(_chain) 

363 

364 def _chain_cancel(_): 

365 """Chain cancellation from f to recvd""" 

366 if recvd.done(): 

367 return 

368 if f.cancelled(): 

369 recvd.cancel() 

370 

371 f.add_done_callback(_chain_cancel) 

372 

373 return f 

374 

def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]:  # type: ignore
    """Poll this socket for events.

    Returns a Future whose result is the event mask reported for this
    socket, or 0 when the poll result has no entry for it.

    Raises ZMQError(ENOTSUP) if the socket is closed.
    """
    if self.closed:
        raise _zmq.ZMQError(_zmq.ENOTSUP)

    poller = self._poller_class()
    poller.register(self, flags)
    poll_future = cast(Future, poller.poll(timeout))

    future = self._Future()

    def unwrap_result(f):
        # f is poll_future; translate its outcome onto the returned future
        if future.done():
            return
        if poll_future.cancelled():
            try:
                future.cancel()
            except RuntimeError:
                # RuntimeError may be raised during teardown
                pass
            return
        error = f.exception()
        if error:
            future.set_exception(error)
            return
        events_by_socket = dict(poll_future.result())
        future.set_result(events_by_socket.get(self, 0))

    if not poll_future.done():
        poll_future.add_done_callback(unwrap_result)
    else:
        # already resolved: hook up the result right away
        unwrap_result(poll_future)

    def cancel_poll(future):
        """Cancel the underlying poll once the returned future is done"""
        if poll_future.done():
            return
        try:
            poll_future.cancel()
        except RuntimeError:
            # RuntimeError may be raised during teardown
            pass

    future.add_done_callback(cancel_poll)

    return future

424 

425 # overrides only necessary for updated types 

def recv_string(self, *args, **kwargs) -> Awaitable[str]:  # type: ignore
    """Delegate to the parent ``recv_string``; overridden only for the return type."""
    pending = super().recv_string(*args, **kwargs)  # type: ignore
    return pending

428 

def send_string(self, s: str, flags: int = 0, encoding: str = 'utf-8') -> Awaitable[None]:  # type: ignore
    """Delegate to the parent ``send_string``; overridden only for the return type."""
    pending = super().send_string(s, flags=flags, encoding=encoding)  # type: ignore
    return pending

431 

432 def _add_timeout(self, future, timeout): 

433 """Add a timeout for a send or recv Future""" 

434 

435 def future_timeout(): 

436 if future.done(): 

437 # future already resolved, do nothing 

438 return 

439 

440 # raise EAGAIN 

441 future.set_exception(_zmq.Again()) 

442 

443 return self._call_later(timeout, future_timeout) 

444 

445 def _call_later(self, delay, callback): 

446 """Schedule a function to be called later 

447 

448 Override for different IOLoop implementations 

449 

450 Tornado and asyncio happen to both have ioloop.call_later 

451 with the same signature. 

452 """ 

453 return self._get_loop().call_later(delay, callback) 

454 

455 @staticmethod 

456 def _remove_finished_future(future, event_list): 

457 """Make sure that futures are removed from the event list when they resolve 

458 

459 Avoids delaying cleanup until the next send/recv event, 

460 which may never come. 

461 """ 

462 for f_idx, event in enumerate(event_list): 

463 if event.future is future: 

464 break 

465 else: 

466 return 

467 

468 # "future" instance is shared between sockets, but each socket has its own event list. 

469 event_list.remove(event_list[f_idx]) 

470 

def _add_recv_event(self, kind, kwargs=None, future=None):
    """Queue a recv-side event of the given kind and return its Future.

    A 'recv*' request carrying the DONTWAIT flag is not queued: the shadow
    socket is called right away and the Future is returned already resolved.
    """
    pending = future or self._Future()
    if kind.startswith('recv') and kwargs.get('flags', 0) & _zmq.DONTWAIT:
        # non-blocking request: answer it directly from the shadow socket
        recv = getattr(self._shadow_sock, kind)
        try:
            received = recv(**kwargs)
        except Exception as exc:
            pending.set_exception(exc)
        else:
            pending.set_result(received)
        return pending

    timer = _NoTimer
    if hasattr(_zmq, 'RCVTIMEO'):
        timeout_ms = self._shadow_sock.rcvtimeo
        if timeout_ms >= 0:
            # the timer only fires later, after the event has been queued below
            timer = self._add_timeout(pending, timeout_ms * 1e-3)

    self._recv_futures.append(
        _FutureEvent(pending, kind, kwargs, msg=None, timer=timer)
    )

    # do not let the Future linger in _recv_futures once it is done
    pending.add_done_callback(
        lambda f: self._remove_finished_future(f, self._recv_futures)
    )

    if self._shadow_sock.get(EVENTS) & POLLIN:
        # something is readable already: recv immediately
        self._handle_recv()
    if self._recv_futures:
        self._add_io_state(POLLIN)
    return pending

506 

def _add_send_event(self, kind, msg=None, kwargs=None, future=None):
    """Queue a send-side event of the given kind and return its Future.

    When no other send is waiting, a 'send'/'send_multipart' request is
    first attempted with DONTWAIT.  It is queued only if that attempt
    raises Again and the caller did not request DONTWAIT itself.
    """
    pending = future or self._Future()
    if kind in ('send', 'send_multipart') and not self._send_futures:
        requested_flags = kwargs.get('flags', 0)
        nowait_kwargs = dict(kwargs, flags=requested_flags | _zmq.DONTWAIT)

        send = getattr(self._shadow_sock, kind)
        # becomes True only when the immediate attempt could not settle the Future
        must_queue = False
        try:
            sent = send(msg, **nowait_kwargs)
        except _zmq.Again as again:
            if requested_flags & _zmq.DONTWAIT:
                pending.set_exception(again)
            else:
                # EAGAIN without DONTWAIT requested: fall through to the async send
                must_queue = True
        except Exception as exc:
            pending.set_exception(exc)
        else:
            pending.set_result(sent)

        if not must_queue:
            # settled already; wake receivers, if any are waiting
            if self._recv_futures:
                self._schedule_remaining_events()
            return pending

    timer = _NoTimer
    if hasattr(_zmq, 'SNDTIMEO'):
        timeout_ms = self._shadow_sock.get(_zmq.SNDTIMEO)
        if timeout_ms >= 0:
            # the timer only fires later, after the event has been queued below
            timer = self._add_timeout(pending, timeout_ms * 1e-3)

    self._send_futures.append(
        _FutureEvent(pending, kind, kwargs=kwargs, msg=msg, timer=timer)
    )
    # do not let the Future linger in _send_futures once it is done
    pending.add_done_callback(
        lambda f: self._remove_finished_future(f, self._send_futures)
    )

    self._add_io_state(POLLOUT)
    return pending

562 

def _handle_recv(self):
    """Resolve the oldest pending recv-side event, if the socket is readable."""
    if not self._shadow_sock.get(EVENTS) & POLLIN:
        # the state may have changed between the trigger and this callback
        return

    event = None
    while self._recv_futures:
        candidate = self._recv_futures.popleft()
        # futures that are already done (e.g. cancelled) are skipped
        if not candidate.future.done():
            event = candidate
            break

    if not self._recv_futures:
        self._drop_io_state(POLLIN)

    if event is None:
        return

    future, kind, kwargs, _, timer = event
    timer.cancel()

    if kind == 'poll':
        # a poll event only signals readiness, nothing is received
        future.set_result(None)
        return
    if kind == 'recv_multipart':
        recv = self._shadow_sock.recv_multipart
    elif kind == 'recv':
        recv = self._shadow_sock.recv
    else:
        raise ValueError("Unhandled recv event type: %r" % kind)

    kwargs['flags'] |= _zmq.DONTWAIT
    try:
        result = recv(**kwargs)
    except Exception as exc:
        future.set_exception(exc)
    else:
        future.set_result(result)

603 

def _handle_send(self):
    """Resolve the oldest pending send-side event, if the socket is writable."""
    if not self._shadow_sock.get(EVENTS) & POLLOUT:
        # the state may have changed between the trigger and this callback
        return

    event = None
    while self._send_futures:
        candidate = self._send_futures.popleft()
        # futures that are already done (e.g. cancelled) are skipped
        if not candidate.future.done():
            event = candidate
            break

    if not self._send_futures:
        self._drop_io_state(POLLOUT)

    if event is None:
        return

    future, kind, kwargs, msg, timer = event
    timer.cancel()

    if kind == 'poll':
        # a poll event only signals readiness, nothing is sent
        future.set_result(None)
        return
    if kind == 'send_multipart':
        send = self._shadow_sock.send_multipart
    elif kind == 'send':
        send = self._shadow_sock.send
    else:
        raise ValueError("Unhandled send event type: %r" % kind)

    kwargs['flags'] |= _zmq.DONTWAIT
    try:
        result = send(msg, **kwargs)
    except Exception as exc:
        future.set_exception(exc)
    else:
        future.set_result(result)

643 

644 # event masking from ZMQStream 

645 def _handle_events(self, fd=0, events=0): 

646 """Dispatch IO events to _handle_recv, etc.""" 

647 if self._shadow_sock.closed: 

648 return 

649 

650 zmq_events = self._shadow_sock.get(EVENTS) 

651 if zmq_events & _zmq.POLLIN: 

652 self._handle_recv() 

653 if zmq_events & _zmq.POLLOUT: 

654 self._handle_send() 

655 self._schedule_remaining_events() 

656 

657 def _schedule_remaining_events(self, events=None): 

658 """Schedule a call to handle_events next loop iteration 

659 

660 If there are still events to handle. 

661 """ 

662 # edge-triggered handling 

663 # allow passing events in, in case this is triggered by retrieving events, 

664 # so we don't have to retrieve it twice. 

665 if self._state == 0: 

666 # not watching for anything, nothing to schedule 

667 return 

668 if events is None: 

669 events = self._shadow_sock.get(EVENTS) 

670 if events & self._state: 

671 self._call_later(0, self._handle_events) 

672 

673 def _add_io_state(self, state): 

674 """Add io_state to poller.""" 

675 if self._state != state: 

676 state = self._state = self._state | state 

677 self._update_handler(self._state) 

678 

679 def _drop_io_state(self, state): 

680 """Stop poller from watching an io_state.""" 

681 if self._state & state: 

682 self._state = self._state & (~state) 

683 self._update_handler(self._state) 

684 

685 def _update_handler(self, state): 

686 """Update IOLoop handler with state. 

687 

688 zmq FD is always read-only. 

689 """ 

690 # ensure loop is registered and init_io has been called 

691 # if there are any events to watch for 

692 if state: 

693 self._get_loop() 

694 self._schedule_remaining_events() 

695 

696 def _init_io_state(self, loop=None): 

697 """initialize the ioloop event handler""" 

698 if loop is None: 

699 loop = self._get_loop() 

700 loop.add_handler(self._shadow_sock, self._handle_events, self._READ) 

701 self._call_later(0, self._handle_events) 

702 

703 def _clear_io_state(self): 

704 """unregister the ioloop event handler 

705 

706 called once during close 

707 """ 

708 fd = self._shadow_sock 

709 if self._shadow_sock.closed: 

710 fd = self._fd 

711 if self._current_loop is not None: 

712 self._current_loop.remove_handler(fd)