Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/zmq/_future.py: 18%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

402 statements  

1"""Future-returning APIs for coroutines.""" 

2 

3# Copyright (c) PyZMQ Developers. 

4# Distributed under the terms of the Modified BSD License. 

5from __future__ import annotations 

6 

7import warnings 

8from asyncio import Future 

9from collections import deque 

10from functools import partial 

11from itertools import chain 

12from typing import ( 

13 Any, 

14 Awaitable, 

15 Callable, 

16 NamedTuple, 

17 TypeVar, 

18 cast, 

19) 

20 

21import zmq as _zmq 

22from zmq import EVENTS, POLLIN, POLLOUT 

23 

24 

25class _FutureEvent(NamedTuple): 

26 future: Future 

27 kind: str 

28 args: tuple 

29 kwargs: dict 

30 msg: Any 

31 timer: Any 

32 

33 

34# These are incomplete classes and need a Mixin for compatibility with an eventloop 

35# defining the following attributes: 

36# 

37# _Future 

38# _READ 

39# _WRITE 

40# _default_loop() 

41 

42 

43class _Async: 

44 """Mixin for common async logic""" 

45 

46 _current_loop: Any = None 

47 _Future: type[Future] 

48 

49 def _get_loop(self) -> Any: 

50 """Get event loop 

51 

52 Notice if event loop has changed, 

53 and register init_io_state on activation of a new event loop 

54 """ 

55 if self._current_loop is None: 

56 self._current_loop = self._default_loop() 

57 self._init_io_state(self._current_loop) 

58 return self._current_loop 

59 current_loop = self._default_loop() 

60 if current_loop is not self._current_loop: 

61 # warn? This means a socket is being used in multiple loops! 

62 self._current_loop = current_loop 

63 self._init_io_state(current_loop) 

64 return current_loop 

65 

66 def _default_loop(self) -> Any: 

67 raise NotImplementedError("Must be implemented in a subclass") 

68 

69 def _init_io_state(self, loop=None) -> None: 

70 pass 

71 

72 

class _AsyncPoller(_Async, _zmq.Poller):
    """Poller that returns a Future on poll, instead of blocking."""

    _socket_class: type[_AsyncSocket]
    _READ: int
    _WRITE: int
    raw_sockets: list[Any]

    def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None:
        """Schedule callback for a raw socket"""
        raise NotImplementedError()

    def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None:
        """Unschedule callback for a raw socket"""
        raise NotImplementedError()

    def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]:  # type: ignore
        """Return a Future for a poll event

        With ``timeout == 0`` the blocking poll(0) result is returned in an
        already-resolved Future. Otherwise a single "watcher" Future is shared
        by every registered socket; once it fires (an event, or the timeout
        given in milliseconds when positive), poll(0) is run and its result
        resolves the returned Future. Done-ing the returned Future
        (e.g. cancellation) cancels the watcher.
        """
        outcome = self._Future()

        if timeout == 0:
            # non-blocking request: answer synchronously
            try:
                ready = super().poll(0)
            except Exception as exc:
                outcome.set_exception(exc)
            else:
                outcome.set_result(ready)
            return outcome

        loop = self._get_loop()

        # resolved as soon as any event is available on any registered socket
        watcher = self._Future()
        raw_sockets: list[Any] = []
        wrapped_sockets: list[_AsyncSocket] = []

        def wake_raw(*args):
            if not watcher.done():
                watcher.set_result(None)

        watcher.add_done_callback(
            lambda _: self._unwatch_raw_sockets(loop, *raw_sockets)
        )

        for sock, mask in self.sockets:
            if not isinstance(sock, _zmq.Socket):
                # raw file descriptor / object: watch it on the loop directly
                raw_sockets.append(sock)
                evt = (self._READ if mask & _zmq.POLLIN else 0) | (
                    self._WRITE if mask & _zmq.POLLOUT else 0
                )
                self._watch_raw_socket(loop, sock, evt, wake_raw)
                continue
            if not isinstance(sock, self._socket_class):
                # it's a blocking zmq.Socket, wrap it in async
                sock = self._socket_class.from_socket(sock)
                wrapped_sockets.append(sock)
            if mask & _zmq.POLLIN:
                sock._add_recv_event('poll', future=watcher)
            if mask & _zmq.POLLOUT:
                sock._add_send_event('poll', future=watcher)

        def on_poll_ready(_):
            if outcome.done():
                return
            if watcher.cancelled():
                try:
                    outcome.cancel()
                except RuntimeError:
                    # RuntimeError may be called during teardown
                    pass
                return
            failure = watcher.exception()
            if failure:
                outcome.set_exception(failure)
                return
            # zero-arg super() is unavailable inside this nested function
            try:
                ready = super(_AsyncPoller, self).poll(0)
            except Exception as exc:
                outcome.set_exception(exc)
            else:
                outcome.set_result(ready)

        watcher.add_done_callback(on_poll_ready)

        if wrapped_sockets:

            def _clear_wrapper_io(_):
                for wrapped in wrapped_sockets:
                    wrapped._clear_io_state()

            watcher.add_done_callback(_clear_wrapper_io)

        if timeout is not None and timeout > 0:
            # on timeout, wake the watcher exactly like an event would
            timeout_handle = loop.call_later(1e-3 * timeout, wake_raw)

            def cancel_timeout(_):
                # tornado-style loops return handles without .cancel()
                if hasattr(timeout_handle, 'cancel'):
                    timeout_handle.cancel()
                else:
                    loop.remove_timeout(timeout_handle)

            outcome.add_done_callback(cancel_timeout)

        def cancel_watcher(_):
            if not watcher.done():
                watcher.cancel()

        outcome.add_done_callback(cancel_watcher)

        return outcome

189 return future 

190 

191 

192class _NoTimer: 

193 @staticmethod 

194 def cancel(): 

195 pass 

196 

197 

# Self-type for classmethods such as `from_socket`, so that calling it on a
# subclass is typed as returning that subclass.
T = TypeVar("T", bound="_AsyncSocket")

199 

200 

class _AsyncSocket(_Async, _zmq.Socket[Future]):
    """zmq Socket whose send/recv/poll methods return Futures instead of blocking.

    Pending requests are queued as ``_FutureEvent`` records in
    ``_recv_futures`` / ``_send_futures`` and serviced by ``_handle_events``
    when the shadow socket reports POLLIN / POLLOUT.
    """

    # Warning : these class variables are only here to allow to call super().__setattr__.
    # They are overridden at instance initialization and not shared in the whole class
    _recv_futures = None
    _send_futures = None
    _state = 0
    _shadow_sock: _zmq.Socket
    _poller_class = _AsyncPoller
    _fd = None

    def __init__(
        self,
        context=None,
        socket_type=-1,
        io_loop=None,
        _from_socket: _zmq.Socket | None = None,
        **kwargs,
    ) -> None:
        # allow passing an existing Socket as the first positional argument
        if isinstance(context, _zmq.Socket):
            context, _from_socket = (None, context)
        if _from_socket is not None:
            super().__init__(shadow=_from_socket.underlying)  # type: ignore
            self._shadow_sock = _from_socket
        else:
            super().__init__(context, socket_type, **kwargs)  # type: ignore
            self._shadow_sock = _zmq.Socket.shadow(self.underlying)

        if io_loop is not None:
            warnings.warn(
                f"{self.__class__.__name__}(io_loop) argument is deprecated in pyzmq 22.2."
                " The currently active loop will always be used.",
                DeprecationWarning,
                stacklevel=3,
            )
        self._recv_futures = deque()
        self._send_futures = deque()
        self._state = 0
        # remember the FD so the loop handler can still be removed after close
        self._fd = self._shadow_sock.FD

    @classmethod
    def from_socket(cls: type[T], socket: _zmq.Socket, io_loop: Any = None) -> T:
        """Create an async socket from an existing Socket"""
        return cls(_from_socket=socket, io_loop=io_loop)

    def close(self, linger: int | None = None) -> None:
        if not self.closed and self._fd is not None:
            # cancel every still-pending send/recv/poll request
            event_list: list[_FutureEvent] = list(
                chain(self._recv_futures or [], self._send_futures or [])
            )
            for event in event_list:
                if not event.future.done():
                    try:
                        event.future.cancel()
                    except RuntimeError:
                        # RuntimeError may be called during teardown
                        pass
            self._clear_io_state()
        super().close(linger=linger)

    close.__doc__ = _zmq.Socket.close.__doc__

    def get(self, key):
        result = super().get(key)
        if key == EVENTS:
            # reading EVENTS can reset the edge-triggered FD; reuse the value
            # to schedule any handling that is now due
            self._schedule_remaining_events(result)
        return result

    get.__doc__ = _zmq.Socket.get.__doc__

    def recv_multipart(
        self, flags: int = 0, copy: bool = True, track: bool = False
    ) -> Awaitable[list[bytes] | list[_zmq.Frame]]:
        """Receive a complete multipart zmq message.

        Returns a Future whose result will be a multipart message.
        """
        return self._add_recv_event(
            'recv_multipart', kwargs=dict(flags=flags, copy=copy, track=track)
        )

    def recv(  # type: ignore
        self, flags: int = 0, copy: bool = True, track: bool = False
    ) -> Awaitable[bytes | _zmq.Frame]:
        """Receive a single zmq frame.

        Returns a Future, whose result will be the received frame.

        Recommend using recv_multipart instead.
        """
        return self._add_recv_event(
            'recv', kwargs=dict(flags=flags, copy=copy, track=track)
        )

    def recv_into(  # type: ignore
        self, buf, /, *, nbytes: int = 0, flags: int = 0
    ) -> Awaitable[int]:
        """Receive a single zmq frame into a pre-allocated buffer.

        Returns a Future, whose result will be the number of bytes received.
        """
        return self._add_recv_event(
            'recv_into', args=(buf,), kwargs=dict(nbytes=nbytes, flags=flags)
        )

    def send_multipart(  # type: ignore
        self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs
    ) -> Awaitable[_zmq.MessageTracker | None]:
        """Send a complete multipart zmq message.

        Returns a Future that resolves when sending is complete.
        """
        kwargs['flags'] = flags
        kwargs['copy'] = copy
        kwargs['track'] = track
        return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs)

    def send(  # type: ignore
        self,
        data: Any,
        flags: int = 0,
        copy: bool = True,
        track: bool = False,
        **kwargs: Any,
    ) -> Awaitable[_zmq.MessageTracker | None]:
        """Send a single zmq frame.

        Returns a Future that resolves when sending is complete.

        Recommend using send_multipart instead.
        """
        kwargs['flags'] = flags
        kwargs['copy'] = copy
        kwargs['track'] = track
        return self._add_send_event('send', msg=data, kwargs=kwargs)

    def _deserialize(self, recvd, load):
        """Deserialize with Futures

        Returns a new Future resolved with ``load(recvd.result())``.
        Exceptions from ``recvd`` or from ``load`` are set on the new Future;
        cancellation is propagated in both directions while the other side is
        still pending.
        """
        f = self._Future()

        def _chain(_):
            """Chain result through serialization to recvd"""
            if f.done():
                # chained future may be cancelled, which means nobody is going to get this result
                # if it's an error, that's no big deal (probably zmq.Again),
                # but if it's a successful recv, this is a dropped message!
                if not recvd.cancelled() and recvd.exception() is None:
                    warnings.warn(
                        # is there a useful stacklevel?
                        # ideally, it would point to where `f.cancel()` was called
                        f"Future {f} completed while awaiting {recvd}. A message has been dropped!",
                        RuntimeWarning,
                    )
                return
            if recvd.cancelled():
                # e.g. the socket was closed while waiting. Calling
                # recvd.exception() here would raise CancelledError inside this
                # callback and leave f pending forever; propagate instead.
                try:
                    f.cancel()
                except RuntimeError:
                    # RuntimeError may be called during teardown
                    pass
                return
            if recvd.exception():
                f.set_exception(recvd.exception())
            else:
                buf = recvd.result()
                try:
                    loaded = load(buf)
                except Exception as e:
                    f.set_exception(e)
                else:
                    f.set_result(loaded)

        recvd.add_done_callback(_chain)

        def _chain_cancel(_):
            """Chain cancellation from f to recvd"""
            if recvd.done():
                return
            if f.cancelled():
                recvd.cancel()

        f.add_done_callback(_chain_cancel)

        return f

    def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]:  # type: ignore
        """poll the socket for events

        returns a Future for the poll results.
        The result is the event mask reported for this socket (0 if none).

        Raises ZMQError(ENOTSUP) if the socket is closed.
        """

        if self.closed:
            raise _zmq.ZMQError(_zmq.ENOTSUP)

        p = self._poller_class()
        p.register(self, flags)
        poll_future = cast(Future, p.poll(timeout))

        future = self._Future()

        def unwrap_result(f):
            if future.done():
                return
            if poll_future.cancelled():
                try:
                    future.cancel()
                except RuntimeError:
                    # RuntimeError may be called during teardown
                    pass
                return
            if f.exception():
                future.set_exception(poll_future.exception())
            else:
                evts = dict(poll_future.result())
                future.set_result(evts.get(self, 0))

        if poll_future.done():
            # hook up result if already done
            unwrap_result(poll_future)
        else:
            poll_future.add_done_callback(unwrap_result)

        def cancel_poll(future):
            """Cancel underlying poll if request has been cancelled"""
            if not poll_future.done():
                try:
                    poll_future.cancel()
                except RuntimeError:
                    # RuntimeError may be called during teardown
                    pass

        future.add_done_callback(cancel_poll)

        return future

    def _add_timeout(self, future, timeout):
        """Add a timeout for a send or recv Future

        ``timeout`` is in seconds. When it elapses and ``future`` is still
        pending, ``future`` fails with zmq.Again. Returns the loop's timer handle.
        """

        def future_timeout():
            if future.done():
                # future already resolved, do nothing
                return

            # raise EAGAIN
            future.set_exception(_zmq.Again())

        return self._call_later(timeout, future_timeout)

    def _call_later(self, delay, callback):
        """Schedule a function to be called later

        Override for different IOLoop implementations

        Tornado and asyncio happen to both have ioloop.call_later
        with the same signature.
        """
        return self._get_loop().call_later(delay, callback)

    @staticmethod
    def _remove_finished_future(future, event_list, event=None):
        """Make sure that futures are removed from the event list when they resolve

        Avoids delaying cleanup until the next send/recv event,
        which may never come.
        """
        # "future" instance is shared between sockets, but each socket has its own event list.
        if not event_list:
            return
        # only unconsumed events (e.g. cancelled calls)
        # will be present when this happens
        try:
            event_list.remove(event)
        except ValueError:
            # usually this will have been removed by being consumed
            return

    def _add_recv_event(
        self,
        kind: str,
        *,
        args: tuple | None = None,
        kwargs: dict[str, Any] | None = None,
        future: Future | None = None,
    ) -> Future:
        """Add a recv event, returning the corresponding Future

        ``kind`` is 'poll' or the name of a shadow-socket recv method.
        If ``future`` is given (e.g. a Poller's shared watcher) it is resolved
        instead of a newly created Future.
        """
        f = future if future is not None else self._Future()
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}
        if kind.startswith('recv') and kwargs.get('flags', 0) & _zmq.DONTWAIT:
            # short-circuit non-blocking calls
            recv = getattr(self._shadow_sock, kind)
            try:
                r = recv(*args, **kwargs)
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(r)
            return f

        timer = _NoTimer
        if hasattr(_zmq, 'RCVTIMEO'):
            timeout_ms = self._shadow_sock.rcvtimeo
            if timeout_ms >= 0:
                timer = self._add_timeout(f, timeout_ms * 1e-3)

        # Queue the request. If it is later resolved some other way (timeout,
        # cancellation), the done-callback registered below removes it from
        # the queue so it does not leak.
        _future_event = _FutureEvent(
            f, kind, args=args, kwargs=kwargs, msg=None, timer=timer
        )
        self._recv_futures.append(_future_event)

        if self._shadow_sock.get(EVENTS) & POLLIN:
            # recv immediately, if we can
            self._handle_recv()
        if self._recv_futures and _future_event in self._recv_futures:
            # Don't let the Future sit in _recv_events after it's done
            # no need to register this if we've already been handled
            # (i.e. immediately-resolved recv)
            f.add_done_callback(
                partial(
                    self._remove_finished_future,
                    event_list=self._recv_futures,
                    event=_future_event,
                )
            )
            self._add_io_state(POLLIN)
        return f

    def _add_send_event(self, kind, msg=None, kwargs=None, future=None):
        """Add a send event, returning the corresponding Future

        ``kind`` is 'poll', 'send' or 'send_multipart'. If ``future`` is given
        (e.g. a Poller's shared watcher) it is resolved instead of a new Future.
        """
        f = future if future is not None else self._Future()
        if kwargs is None:
            kwargs = {}
        # attempt send with DONTWAIT if no futures are waiting
        # short-circuit for sends that will resolve immediately
        # only call if no send Futures are waiting
        if kind in ('send', 'send_multipart') and not self._send_futures:
            flags = kwargs.get('flags', 0)
            nowait_kwargs = kwargs.copy()
            nowait_kwargs['flags'] = flags | _zmq.DONTWAIT

            # short-circuit non-blocking calls
            send = getattr(self._shadow_sock, kind)
            # track if the send resolved or not
            # (EAGAIN if DONTWAIT is not set should proceed with async send)
            finish_early = True
            try:
                r = send(msg, **nowait_kwargs)
            except _zmq.Again as e:
                if flags & _zmq.DONTWAIT:
                    f.set_exception(e)
                else:
                    # EAGAIN raised and DONTWAIT not requested,
                    # proceed with async send
                    finish_early = False
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(r)

            if finish_early:
                # short-circuit resolved, return finished Future
                # schedule wake for recv if there are any receivers waiting
                if self._recv_futures:
                    self._schedule_remaining_events()
                return f

        timer = _NoTimer
        if hasattr(_zmq, 'SNDTIMEO'):
            timeout_ms = self._shadow_sock.get(_zmq.SNDTIMEO)
            if timeout_ms >= 0:
                timer = self._add_timeout(f, timeout_ms * 1e-3)

        # Queue the request. If it is later resolved some other way (timeout,
        # cancellation), the done-callback below removes it from
        # _send_futures so it does not leak.
        _future_event = _FutureEvent(
            f, kind, args=(), kwargs=kwargs, msg=msg, timer=timer
        )
        self._send_futures.append(_future_event)
        # Don't let the Future sit in _send_futures after it's done
        f.add_done_callback(
            partial(
                self._remove_finished_future,
                event_list=self._send_futures,
                event=_future_event,
            )
        )

        self._add_io_state(POLLOUT)
        return f

    def _handle_recv(self):
        """Handle recv events

        Services the oldest pending recv request (skipping ones already done)
        with a DONTWAIT call on the shadow socket.
        """
        if not self._shadow_sock.get(EVENTS) & POLLIN:
            # event triggered, but state may have been changed between trigger and callback
            return
        f = None
        while self._recv_futures:
            f, kind, args, kwargs, _, timer = self._recv_futures.popleft()
            # skip any cancelled futures
            if f.done():
                f = None
            else:
                break

        if not self._recv_futures:
            self._drop_io_state(POLLIN)

        if f is None:
            return

        timer.cancel()

        if kind == 'poll':
            # on poll event, just signal ready, nothing else.
            f.set_result(None)
            return
        elif kind == 'recv_multipart':
            recv = self._shadow_sock.recv_multipart
        elif kind == 'recv':
            recv = self._shadow_sock.recv
        elif kind == 'recv_into':
            recv = self._shadow_sock.recv_into
        else:
            raise ValueError(f"Unhandled recv event type: {kind!r}")

        kwargs['flags'] |= _zmq.DONTWAIT
        try:
            result = recv(*args, **kwargs)
        except Exception as e:
            f.set_exception(e)
        else:
            f.set_result(result)

    def _handle_send(self):
        """Handle send events

        Services the oldest pending send request (skipping ones already done)
        with a DONTWAIT call on the shadow socket.
        """
        if not self._shadow_sock.get(EVENTS) & POLLOUT:
            # event triggered, but state may have been changed between trigger and callback
            return
        f = None
        while self._send_futures:
            f, kind, args, kwargs, msg, timer = self._send_futures.popleft()
            # skip any cancelled futures
            if f.done():
                f = None
            else:
                break

        if not self._send_futures:
            self._drop_io_state(POLLOUT)

        if f is None:
            return

        timer.cancel()

        if kind == 'poll':
            # on poll event, just signal ready, nothing else.
            f.set_result(None)
            return
        elif kind == 'send_multipart':
            send = self._shadow_sock.send_multipart
        elif kind == 'send':
            send = self._shadow_sock.send
        else:
            raise ValueError(f"Unhandled send event type: {kind!r}")

        kwargs['flags'] |= _zmq.DONTWAIT
        try:
            result = send(msg, **kwargs)
        except Exception as e:
            f.set_exception(e)
        else:
            f.set_result(result)

    # event masking from ZMQStream
    def _handle_events(self, fd=0, events=0):
        """Dispatch IO events to _handle_recv, etc."""
        if self._shadow_sock.closed:
            return

        zmq_events = self._shadow_sock.get(EVENTS)
        if zmq_events & _zmq.POLLIN:
            self._handle_recv()
        if zmq_events & _zmq.POLLOUT:
            self._handle_send()
        self._schedule_remaining_events()

    def _schedule_remaining_events(self, events=None):
        """Schedule a call to handle_events next loop iteration

        If there are still events to handle.
        """
        # edge-triggered handling
        # allow passing events in, in case this is triggered by retrieving events,
        # so we don't have to retrieve it twice.
        if self._state == 0:
            # not watching for anything, nothing to schedule
            return
        if events is None:
            events = self._shadow_sock.get(EVENTS)
        if events & self._state:
            self._call_later(0, self._handle_events)

    def _add_io_state(self, state):
        """Add io_state to poller."""
        self._state |= state
        self._update_handler(self._state)

    def _drop_io_state(self, state):
        """Stop poller from watching an io_state."""
        self._state &= ~state
        self._update_handler(self._state)

    def _update_handler(self, state):
        """Update IOLoop handler with state.

        zmq FD is always read-only.
        """
        # ensure loop is registered and init_io has been called
        # if there are any events to watch for
        if state:
            self._get_loop()
        self._schedule_remaining_events()

    def _init_io_state(self, loop=None):
        """initialize the ioloop event handler"""
        if loop is None:
            loop = self._get_loop()
        loop.add_handler(self._shadow_sock, self._handle_events, self._READ)
        # process anything that is already pending on the new loop
        self._call_later(0, self._handle_events)

    def _clear_io_state(self):
        """unregister the ioloop event handler

        called once during close
        """
        fd = self._shadow_sock
        if self._shadow_sock.closed:
            # the socket object can no longer be used; fall back to the saved FD
            fd = self._fd
        if self._current_loop is not None:
            self._current_loop.remove_handler(fd)