Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/zmq/_future.py: 18%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Future-returning APIs for coroutines."""
3# Copyright (c) PyZMQ Developers.
4# Distributed under the terms of the Modified BSD License.
5from __future__ import annotations
7import warnings
8from asyncio import Future
9from collections import deque
10from functools import partial
11from itertools import chain
12from typing import (
13 Any,
14 Awaitable,
15 Callable,
16 NamedTuple,
17 TypeVar,
18 cast,
19)
21import zmq as _zmq
22from zmq import EVENTS, POLLIN, POLLOUT
class _FutureEvent(NamedTuple):
    """A send, recv, or poll request queued until the socket is ready for it."""

    # Future that is resolved (or failed) once the request has been handled
    future: Future
    # request name: a socket method name such as 'recv' or 'send_multipart', or 'poll'
    kind: str
    # positional arguments forwarded to the socket method (used by recv events)
    args: tuple
    # keyword arguments forwarded to the socket method
    kwargs: dict
    # payload of a send event; None for recv events
    msg: Any
    # timeout handle exposing cancel(); _NoTimer when no timeout was scheduled
    timer: Any
34# These are incomplete classes and need a Mixin for compatibility with an eventloop
35# defining the following attributes:
36#
37# _Future
38# _READ
39# _WRITE
40# _default_loop()
43class _Async:
44 """Mixin for common async logic"""
46 _current_loop: Any = None
47 _Future: type[Future]
49 def _get_loop(self) -> Any:
50 """Get event loop
52 Notice if event loop has changed,
53 and register init_io_state on activation of a new event loop
54 """
55 if self._current_loop is None:
56 self._current_loop = self._default_loop()
57 self._init_io_state(self._current_loop)
58 return self._current_loop
59 current_loop = self._default_loop()
60 if current_loop is not self._current_loop:
61 # warn? This means a socket is being used in multiple loops!
62 self._current_loop = current_loop
63 self._init_io_state(current_loop)
64 return current_loop
66 def _default_loop(self) -> Any:
67 raise NotImplementedError("Must be implemented in a subclass")
69 def _init_io_state(self, loop=None) -> None:
70 pass
class _AsyncPoller(_Async, _zmq.Poller):
    """Poller whose ``poll`` hands back a Future rather than blocking."""

    _socket_class: type[_AsyncSocket]
    _READ: int
    _WRITE: int
    raw_sockets: list[Any]

    def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None:
        """Arrange for ``f`` to be called on events of a raw socket (subclass hook)."""
        raise NotImplementedError()

    def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None:
        """Remove the callbacks installed for raw sockets (subclass hook)."""
        raise NotImplementedError()

    def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]:  # type: ignore
        """Return a Future that resolves to the poll result.

        ``timeout`` is in milliseconds. ``0`` polls once immediately; a
        positive value wakes the poll after that delay; ``None`` or a negative
        value schedules no timeout.
        """
        future = self._Future()

        def resolve_from_poller():
            # a zero-timeout poll of the underlying Poller reports current state
            try:
                ready = super(_AsyncPoller, self).poll(0)
            except Exception as e:
                future.set_exception(e)
            else:
                future.set_result(ready)

        if timeout == 0:
            resolve_from_poller()
            return future

        loop = self._get_loop()

        # resolved as soon as any registered socket reports an event
        watcher = self._Future()

        def wake_watcher(*_ignored):
            if not watcher.done():
                watcher.set_result(None)

        # raw (non-zmq) sockets are watched through the event loop directly
        raw_sockets: list[Any] = []
        watcher.add_done_callback(
            lambda _f: self._unwatch_raw_sockets(loop, *raw_sockets)
        )

        # blocking zmq sockets that had to be wrapped for this poll
        wrapped_sockets: list[_AsyncSocket] = []

        for sock, mask in self.sockets:
            if not isinstance(sock, _zmq.Socket):
                raw_sockets.append(sock)
                evt = 0
                if mask & _zmq.POLLIN:
                    evt |= self._READ
                if mask & _zmq.POLLOUT:
                    evt |= self._WRITE
                self._watch_raw_socket(loop, sock, evt, wake_watcher)
                continue

            if not isinstance(sock, self._socket_class):
                # it's a blocking zmq.Socket, wrap it in async
                sock = self._socket_class.from_socket(sock)
                wrapped_sockets.append(sock)
            if mask & _zmq.POLLIN:
                sock._add_recv_event('poll', future=watcher)
            if mask & _zmq.POLLOUT:
                sock._add_send_event('poll', future=watcher)

        def on_poll_ready(_f):
            if future.done():
                return
            if watcher.cancelled():
                try:
                    future.cancel()
                except RuntimeError:
                    # RuntimeError may be raised during teardown
                    pass
                return
            failure = watcher.exception()
            if failure:
                future.set_exception(failure)
            else:
                resolve_from_poller()

        watcher.add_done_callback(on_poll_ready)

        if wrapped_sockets:

            def clear_wrapper_io(_f):
                for wrapper in wrapped_sockets:
                    wrapper._clear_io_state()

            watcher.add_done_callback(clear_wrapper_io)

        if timeout is not None and timeout > 0:
            # wake the watcher when the poll timeout (milliseconds) elapses
            timeout_handle = loop.call_later(1e-3 * timeout, wake_watcher)

            def cancel_timeout(_f):
                if not hasattr(timeout_handle, 'cancel'):
                    loop.remove_timeout(timeout_handle)
                else:
                    timeout_handle.cancel()

            future.add_done_callback(cancel_timeout)

        def cancel_watcher(_f):
            if not watcher.done():
                watcher.cancel()

        future.add_done_callback(cancel_watcher)

        return future
192class _NoTimer:
193 @staticmethod
194 def cancel():
195 pass
# Type variable bound to _AsyncSocket; used by the `from_socket` classmethod so
# that its return type is the class it was called on.
T = TypeVar("T", bound="_AsyncSocket")
201class _AsyncSocket(_Async, _zmq.Socket[Future]):
202 # Warning: these class variables exist only so that super().__setattr__ can be called.
203 # They are overridden at instance initialization and are not shared across the whole class.
204 _recv_futures = None
205 _send_futures = None
206 _state = 0
207 _shadow_sock: _zmq.Socket
208 _poller_class = _AsyncPoller
209 _fd = None
def __init__(
    self,
    context=None,
    socket_type=-1,
    io_loop=None,
    _from_socket: _zmq.Socket | None = None,
    **kwargs,
) -> None:
    """Create an async socket, either new or wrapping an existing Socket.

    A Socket passed as ``context`` or as ``_from_socket`` is wrapped:
    this object shadows its underlying socket. Otherwise a new socket is
    created from ``context`` and ``socket_type``. ``io_loop`` is deprecated
    and only triggers a warning.
    """
    if isinstance(context, _zmq.Socket):
        # a Socket given in the context position is the socket to wrap
        _from_socket = context
        context = None

    if _from_socket is None:
        super().__init__(context, socket_type, **kwargs)  # type: ignore
        self._shadow_sock = _zmq.Socket.shadow(self.underlying)
    else:
        super().__init__(shadow=_from_socket.underlying)  # type: ignore
        self._shadow_sock = _from_socket

    if io_loop is not None:
        warnings.warn(
            f"{self.__class__.__name__}(io_loop) argument is deprecated in pyzmq 22.2."
            " The currently active loop will always be used.",
            DeprecationWarning,
            stacklevel=3,
        )

    # per-instance state replacing the class-level placeholders
    self._recv_futures = deque()
    self._send_futures = deque()
    self._state = 0
    self._fd = self._shadow_sock.FD
@classmethod
def from_socket(cls: type[T], socket: _zmq.Socket, io_loop: Any = None) -> T:
    """Wrap an existing Socket in an instance of this async socket class."""
    wrapper = cls(_from_socket=socket, io_loop=io_loop)
    return wrapper
def close(self, linger: int | None = None) -> None:
    # No docstring here: __doc__ is copied from zmq.Socket.close just below.
    still_open = not self.closed
    if still_open and self._fd is not None:
        # cancel every request that is still waiting on this socket
        pending: list[_FutureEvent] = [
            *(self._recv_futures or []),
            *(self._send_futures or []),
        ]
        for queued in pending:
            waiter = queued.future
            if waiter.done():
                continue
            try:
                waiter.cancel()
            except RuntimeError:
                # RuntimeError may be raised during teardown
                pass
        self._clear_io_state()
    super().close(linger=linger)


close.__doc__ = _zmq.Socket.close.__doc__
def get(self, key):
    # No docstring here: __doc__ is copied from zmq.Socket.get just below.
    value = super().get(key)
    if key == EVENTS:
        # reading EVENTS may reveal pending work; reuse the value just read
        self._schedule_remaining_events(value)
    return value


get.__doc__ = _zmq.Socket.get.__doc__
def recv_multipart(
    self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[list[bytes] | list[_zmq.Frame]]:
    """Receive a complete multipart zmq message.

    Returns a Future whose result will be a multipart message.
    """
    options = {'flags': flags, 'copy': copy, 'track': track}
    return self._add_recv_event('recv_multipart', kwargs=options)
def recv(  # type: ignore
    self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[bytes | _zmq.Frame]:
    """Receive a single zmq frame.

    Returns a Future, whose result will be the received frame.

    Recommend using recv_multipart instead.
    """
    options = {'flags': flags, 'copy': copy, 'track': track}
    return self._add_recv_event('recv', kwargs=options)
def recv_into(  # type: ignore
    self, buf, /, *, nbytes: int = 0, flags: int = 0
) -> Awaitable[int]:
    """Receive a single zmq frame into a pre-allocated buffer.

    Returns a Future, whose result will be the number of bytes received.
    """
    options = {'nbytes': nbytes, 'flags': flags}
    return self._add_recv_event('recv_into', args=(buf,), kwargs=options)
def send_multipart(  # type: ignore
    self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs
) -> Awaitable[_zmq.MessageTracker | None]:
    """Send a complete multipart zmq message.

    Returns a Future that resolves when sending is complete.
    """
    # fold the named options into kwargs so the queued event carries them
    kwargs.update(flags=flags, copy=copy, track=track)
    return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs)
def send(  # type: ignore
    self,
    data: Any,
    flags: int = 0,
    copy: bool = True,
    track: bool = False,
    **kwargs: Any,
) -> Awaitable[_zmq.MessageTracker | None]:
    """Send a single zmq frame.

    ``flags``, ``copy`` and ``track`` are merged into ``kwargs`` and the
    whole set is handed to ``_add_send_event`` together with ``data``.

    Returns a Future that resolves when sending is complete.

    Recommend using send_multipart instead.
    """
    # merge the named options exactly once; the queued event carries them
    kwargs.update(flags=flags, copy=copy, track=track)
    return self._add_send_event('send', msg=data, kwargs=kwargs)
337 def _deserialize(self, recvd, load):
338 """Deserialize with Futures"""
339 f = self._Future()
341 def _chain(_):
342 """Chain result through serialization to recvd"""
343 if f.done():
344 # chained future may be cancelled, which means nobody is going to get this result
345 # if it's an error, that's no big deal (probably zmq.Again),
346 # but if it's a successful recv, this is a dropped message!
347 if not recvd.cancelled() and recvd.exception() is None:
348 warnings.warn(
349 # is there a useful stacklevel?
350 # ideally, it would point to where `f.cancel()` was called
351 f"Future {f} completed while awaiting {recvd}. A message has been dropped!",
352 RuntimeWarning,
353 )
354 return
355 if recvd.exception():
356 f.set_exception(recvd.exception())
357 else:
358 buf = recvd.result()
359 try:
360 loaded = load(buf)
361 except Exception as e:
362 f.set_exception(e)
363 else:
364 f.set_result(loaded)
366 recvd.add_done_callback(_chain)
368 def _chain_cancel(_):
369 """Chain cancellation from f to recvd"""
370 if recvd.done():
371 return
372 if f.cancelled():
373 recvd.cancel()
375 f.add_done_callback(_chain_cancel)
377 return f
def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]:  # type: ignore
    """Poll this socket for events.

    Registers the socket with a fresh poller and returns a Future whose
    result is the event mask reported for this socket (0 if none).
    Raises ZMQError(ENOTSUP) if the socket is closed.
    """
    if self.closed:
        raise _zmq.ZMQError(_zmq.ENOTSUP)

    poller = self._poller_class()
    poller.register(self, flags)
    poll_future = cast(Future, poller.poll(timeout))

    future = self._Future()

    def unwrap_result(finished):
        if future.done():
            return
        if poll_future.cancelled():
            try:
                future.cancel()
            except RuntimeError:
                # RuntimeError may be raised during teardown
                pass
            return
        if finished.exception():
            future.set_exception(poll_future.exception())
            return
        # the poller reports (socket, mask) pairs; pick out our own mask
        ready = dict(poll_future.result())
        future.set_result(ready.get(self, 0))

    if not poll_future.done():
        poll_future.add_done_callback(unwrap_result)
    else:
        # hook up result if already done
        unwrap_result(poll_future)

    def cancel_poll(_future):
        """Cancel underlying poll if request has been cancelled"""
        if poll_future.done():
            return
        try:
            poll_future.cancel()
        except RuntimeError:
            # RuntimeError may be raised during teardown
            pass

    future.add_done_callback(cancel_poll)

    return future
def _add_timeout(self, future, timeout):
    """Schedule a timeout for a send or recv Future.

    After ``timeout`` (passed as the delay to ``_call_later``) the Future
    fails with ``zmq.Again`` unless it is already done. Returns whatever
    ``_call_later`` returns.
    """

    def future_timeout():
        if not future.done():
            # raise EAGAIN
            future.set_exception(_zmq.Again())

    return self._call_later(timeout, future_timeout)
442 def _call_later(self, delay, callback):
443 """Schedule a function to be called later
445 Override for different IOLoop implementations
447 Tornado and asyncio happen to both have ioloop.call_later
448 with the same signature.
449 """
450 return self._get_loop().call_later(delay, callback)
452 @staticmethod
453 def _remove_finished_future(future, event_list, event=None):
454 """Make sure that futures are removed from the event list when they resolve
456 Avoids delaying cleanup until the next send/recv event,
457 which may never come.
458 """
459 # "future" instance is shared between sockets, but each socket has its own event list.
460 if not event_list:
461 return
462 # only unconsumed events (e.g. cancelled calls)
463 # will be present when this happens
464 try:
465 event_list.remove(event)
466 except ValueError:
467 # usually this will have been removed by being consumed
468 return
def _add_recv_event(
    self,
    kind: str,
    *,
    args: tuple | None = None,
    kwargs: dict[str, Any] | None = None,
    future: Future | None = None,
) -> Future:
    """Queue a recv-side request and return its Future.

    A ``recv*`` request carrying ``DONTWAIT`` is performed right away on the
    shadow socket. Any other request is queued in ``_recv_futures`` with an
    optional RCVTIMEO timer, and handled immediately if the socket is
    already readable.
    """
    f = future or self._Future()
    args = () if args is None else args
    kwargs = {} if kwargs is None else kwargs

    if kind.startswith('recv') and kwargs.get('flags', 0) & _zmq.DONTWAIT:
        # short-circuit non-blocking calls
        recv = getattr(self._shadow_sock, kind)
        try:
            received = recv(*args, **kwargs)
        except Exception as e:
            f.set_exception(e)
        else:
            f.set_result(received)
        return f

    timer = _NoTimer
    if hasattr(_zmq, 'RCVTIMEO'):
        timeout_ms = self._shadow_sock.rcvtimeo
        if timeout_ms >= 0:
            timer = self._add_timeout(f, timeout_ms * 1e-3)

    queued = _FutureEvent(f, kind, args=args, kwargs=kwargs, msg=None, timer=timer)
    self._recv_futures.append(queued)

    if self._shadow_sock.get(EVENTS) & POLLIN:
        # recv immediately, if we can
        self._handle_recv()

    if self._recv_futures and queued in self._recv_futures:
        # Still waiting: make sure the event leaves the queue once its
        # Future is done. Not needed when the request was already handled
        # (i.e. immediately-resolved recv).
        cleanup = partial(
            self._remove_finished_future,
            event_list=self._recv_futures,
            event=queued,
        )
        f.add_done_callback(cleanup)
        self._add_io_state(POLLIN)
    return f
def _add_send_event(self, kind, msg=None, kwargs=None, future=None):
    """Add a send event, returning the corresponding Future

    ``kind`` is 'send', 'send_multipart' or 'poll'. When no send is already
    queued, a 'send'/'send_multipart' is first attempted with DONTWAIT; if
    that resolves (or fails with anything other than EAGAIN on a call that
    did not itself request DONTWAIT) the finished Future is returned.
    Otherwise the request is queued in ``_send_futures`` with an optional
    SNDTIMEO timer and POLLOUT is added to the watched state.
    """
    f = future or self._Future()
    if kwargs is None:
        # the default must be usable: the send path below calls
        # kwargs.get / kwargs.copy (mirrors _add_recv_event)
        kwargs = {}
    # attempt send with DONTWAIT if no futures are waiting
    # short-circuit for sends that will resolve immediately
    # only call if no send Futures are waiting
    if kind in ('send', 'send_multipart') and not self._send_futures:
        flags = kwargs.get('flags', 0)
        nowait_kwargs = kwargs.copy()
        nowait_kwargs['flags'] = flags | _zmq.DONTWAIT

        # short-circuit non-blocking calls
        send = getattr(self._shadow_sock, kind)
        # track if the send resolved or not
        # (EAGAIN if DONTWAIT is not set should proceed with the async send)
        finish_early = True
        try:
            r = send(msg, **nowait_kwargs)
        except _zmq.Again as e:
            if flags & _zmq.DONTWAIT:
                f.set_exception(e)
            else:
                # EAGAIN raised and DONTWAIT not requested,
                # proceed with async send
                finish_early = False
        except Exception as e:
            f.set_exception(e)
        else:
            f.set_result(r)

        if finish_early:
            # short-circuit resolved, return finished Future
            # schedule wake for recv if there are any receivers waiting
            if self._recv_futures:
                self._schedule_remaining_events()
            return f

    timer = _NoTimer
    if hasattr(_zmq, 'SNDTIMEO'):
        timeout_ms = self._shadow_sock.get(_zmq.SNDTIMEO)
        if timeout_ms >= 0:
            timer = self._add_timeout(f, timeout_ms * 1e-3)

    # queue the event; the done-callback registered below removes it from
    # _send_futures once the Future resolves (e.g. on timeout or cancel)
    _future_event = _FutureEvent(
        f, kind, args=(), kwargs=kwargs, msg=msg, timer=timer
    )
    self._send_futures.append(_future_event)
    # Don't let the Future sit in _send_futures after it's done
    f.add_done_callback(
        partial(
            self._remove_finished_future,
            event_list=self._send_futures,
            event=_future_event,
        )
    )

    self._add_io_state(POLLOUT)
    return f
def _handle_recv(self):
    """Handle recv events

    Takes the first queued recv event whose Future is not already done and
    resolves it: a 'poll' event is simply signalled, any other kind performs
    the matching non-blocking recv on the shadow socket and stores the
    result or exception on the Future.

    Raises ValueError for an unknown event kind.
    """
    if not self._shadow_sock.get(EVENTS) & POLLIN:
        # event triggered, but state may have been changed between trigger and callback
        return
    f = None
    while self._recv_futures:
        f, kind, args, kwargs, _, timer = self._recv_futures.popleft()
        # skip any cancelled futures
        if f.done():
            f = None
        else:
            break

    if not self._recv_futures:
        # nothing left waiting, stop watching for readability
        self._drop_io_state(POLLIN)

    if f is None:
        return

    timer.cancel()

    if kind == 'poll':
        # on poll event, just signal ready, nothing else.
        f.set_result(None)
        return
    elif kind == 'recv_multipart':
        recv = self._shadow_sock.recv_multipart
    elif kind == 'recv':
        recv = self._shadow_sock.recv
    elif kind == 'recv_into':
        recv = self._shadow_sock.recv_into
    else:
        raise ValueError(f"Unhandled recv event type: {kind!r}")

    # The event was queued without a guarantee that 'flags' is present
    # (_add_recv_event reads it with a default); a KeyError here would
    # strand the already-dequeued Future, so default to 0.
    kwargs['flags'] = kwargs.get('flags', 0) | _zmq.DONTWAIT
    try:
        result = recv(*args, **kwargs)
    except Exception as e:
        f.set_exception(e)
    else:
        f.set_result(result)
def _handle_send(self):
    """Handle send events

    Takes the first queued send event whose Future is not already done and
    resolves it: a 'poll' event is simply signalled, any other kind performs
    the matching non-blocking send on the shadow socket and stores the
    result or exception on the Future.

    Raises ValueError for an unknown event kind.
    """
    if not self._shadow_sock.get(EVENTS) & POLLOUT:
        # event triggered, but state may have been changed between trigger and callback
        return
    f = None
    while self._send_futures:
        f, kind, args, kwargs, msg, timer = self._send_futures.popleft()
        # skip any cancelled futures
        if f.done():
            f = None
        else:
            break

    if not self._send_futures:
        # nothing left waiting, stop watching for writability
        self._drop_io_state(POLLOUT)

    if f is None:
        return

    timer.cancel()

    if kind == 'poll':
        # on poll event, just signal ready, nothing else.
        f.set_result(None)
        return
    elif kind == 'send_multipart':
        send = self._shadow_sock.send_multipart
    elif kind == 'send':
        send = self._shadow_sock.send
    else:
        raise ValueError(f"Unhandled send event type: {kind!r}")

    # The event has already been dequeued; a KeyError on a missing 'flags'
    # entry would strand its Future, so default to 0.
    kwargs['flags'] = kwargs.get('flags', 0) | _zmq.DONTWAIT
    try:
        result = send(msg, **kwargs)
    except Exception as e:
        f.set_exception(e)
    else:
        f.set_result(result)
669 # event masking from ZMQStream
def _handle_events(self, fd=0, events=0):
    """Route the socket's current zmq events to the recv/send handlers."""
    if self._shadow_sock.closed:
        return

    pending = self._shadow_sock.get(EVENTS)
    readable = pending & _zmq.POLLIN
    writable = pending & _zmq.POLLOUT
    if readable:
        self._handle_recv()
    if writable:
        self._handle_send()
    # more events may remain after handling one of each
    self._schedule_remaining_events()
682 def _schedule_remaining_events(self, events=None):
683 """Schedule a call to handle_events next loop iteration
685 If there are still events to handle.
686 """
687 # edge-triggered handling
688 # allow passing events in, in case this is triggered by retrieving events,
689 # so we don't have to retrieve it twice.
690 if self._state == 0:
691 # not watching for anything, nothing to schedule
692 return
693 if events is None:
694 events = self._shadow_sock.get(EVENTS)
695 if events & self._state:
696 self._call_later(0, self._handle_events)
698 def _add_io_state(self, state):
699 """Add io_state to poller."""
700 if self._state != state:
701 state = self._state = self._state | state
702 self._update_handler(self._state)
704 def _drop_io_state(self, state):
705 """Stop poller from watching an io_state."""
706 if self._state & state:
707 self._state = self._state & (~state)
708 self._update_handler(self._state)
710 def _update_handler(self, state):
711 """Update IOLoop handler with state.
713 zmq FD is always read-only.
714 """
715 # ensure loop is registered and init_io has been called
716 # if there are any events to watch for
717 if state:
718 self._get_loop()
719 self._schedule_remaining_events()
721 def _init_io_state(self, loop=None):
722 """initialize the ioloop event handler"""
723 if loop is None:
724 loop = self._get_loop()
725 loop.add_handler(self._shadow_sock, self._handle_events, self._READ)
726 self._call_later(0, self._handle_events)
728 def _clear_io_state(self):
729 """unregister the ioloop event handler
731 called once during close
732 """
733 fd = self._shadow_sock
734 if self._shadow_sock.closed:
735 fd = self._fd
736 if self._current_loop is not None:
737 self._current_loop.remove_handler(fd)