Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/zmq/_future.py: 21%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Future-returning APIs for coroutines."""
3# Copyright (c) PyZMQ Developers.
4# Distributed under the terms of the Modified BSD License.
5from __future__ import annotations
7import warnings
8from asyncio import Future
9from collections import deque
10from functools import partial
11from itertools import chain
12from typing import Any, Awaitable, Callable, NamedTuple, TypeVar, cast, overload
14import zmq as _zmq
15from zmq import EVENTS, POLLIN, POLLOUT
16from zmq._typing import Literal
# Bookkeeping record for one queued request on an _AsyncSocket.
# Fields, in unpacking order:
#   future - Future returned to the caller (or a poll watcher shared with a poller)
#   kind   - 'recv', 'recv_multipart', 'send', 'send_multipart' or 'poll'
#   kwargs - keyword arguments for the underlying blocking call (None for 'poll')
#   msg    - payload for send events; None for recv and poll events
#   timer  - timeout handle (or the _NoTimer class) cancelled when the event is handled
_FutureEvent = NamedTuple(
    "_FutureEvent",
    [
        ("future", Future),
        ("kind", str),
        ("kwargs", dict),
        ("msg", Any),
        ("timer", Any),
    ],
)
# These classes are incomplete: each must be combined with an event-loop-specific
# mixin that defines the following attributes:
#
# _Future
# _READ
# _WRITE
# _default_loop()
class _Async:
    """Mixin holding the event-loop bookkeeping shared by async sockets and pollers.

    Concrete classes must provide ``_default_loop`` (normally via an
    event-loop-specific mixin) and may override ``_init_io_state``.
    """

    # Loop this object is currently bound to; None until first use.
    _current_loop: Any = None
    # Future class to instantiate; supplied by the event-loop mixin.
    _Future: type[Future]

    def _get_loop(self) -> Any:
        """Return the active event loop, rebinding to it when it changes.

        On first use, or whenever ``_default_loop()`` returns a different
        loop than the one last seen, the new loop is recorded in
        ``_current_loop`` and then passed to ``_init_io_state``.
        """
        active = self._default_loop()
        if self._current_loop is None or active is not self._current_loop:
            # A change after first use means this object is being used from
            # more than one event loop.
            self._current_loop = active
            self._init_io_state(active)
        return active

    def _default_loop(self) -> Any:
        raise NotImplementedError("Must be implemented in a subclass")

    def _init_io_state(self, loop=None) -> None:
        """Hook run when a new loop is bound; no-op by default."""
        pass
class _AsyncPoller(_Async, _zmq.Poller):
    """Poller that returns a Future on poll, instead of blocking.

    Registered zmq sockets are watched through the async socket machinery.
    Any other registered object (a "raw" socket/FD) is watched through the
    event-loop hooks ``_watch_raw_socket`` / ``_unwatch_raw_sockets``, which
    subclasses must implement.
    """

    # Async socket class used to wrap plain (blocking) zmq.Sockets during poll().
    _socket_class: type[_AsyncSocket]
    # Event-loop flags for "readable" / "writable", used for raw sockets.
    _READ: int
    _WRITE: int
    # NOTE(review): annotated but not assigned in the code shown; poll() uses
    # its own local ``raw_sockets`` list instead.
    raw_sockets: list[Any]

    def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None:
        """Schedule callback for a raw socket.

        Subclasses must arrange for ``f`` to be called when ``socket`` has any
        of the ``evt`` (``_READ``/``_WRITE``) events on ``loop``.
        """
        raise NotImplementedError()

    def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None:
        """Unschedule callbacks previously registered for raw sockets."""
        raise NotImplementedError()

    def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]:  # type: ignore
        """Return a Future for a poll event.

        The Future resolves to the result of a non-blocking ``zmq.Poller.poll(0)``
        taken once any registered socket becomes ready or the timeout expires.

        Args:
            timeout: milliseconds to wait. ``0`` polls immediately without
                waiting; ``None`` or a negative value waits with no timeout.

        Cancelling the returned Future stops watching all registered sockets.
        """
        future = self._Future()
        if timeout == 0:
            # Fast path: a plain non-blocking poll, delivered through a Future.
            try:
                result = super().poll(0)
            except Exception as e:
                future.set_exception(e)
            else:
                future.set_result(result)
            return future

        loop = self._get_loop()

        # register Future to be called as soon as any event is available on any socket
        watcher = self._Future()

        # watch raw sockets:
        raw_sockets: list[Any] = []

        def wake_raw(*args):
            # Any activity on a raw socket just wakes the watcher.
            if not watcher.done():
                watcher.set_result(None)

        # Whatever resolves the watcher, stop watching the raw sockets.
        # (raw_sockets is read when the callback runs, so it sees the filled list.)
        watcher.add_done_callback(
            lambda f: self._unwatch_raw_sockets(loop, *raw_sockets)
        )

        # Async sockets created here only to wrap blocking zmq.Sockets.
        wrapped_sockets: list[_AsyncSocket] = []

        def _clear_wrapper_io(f):
            # Unregister the temporary wrappers' handlers from the loop.
            for s in wrapped_sockets:
                s._clear_io_state()

        for socket, mask in self.sockets:
            if isinstance(socket, _zmq.Socket):
                if not isinstance(socket, self._socket_class):
                    # it's a blocking zmq.Socket, wrap it in async
                    socket = self._socket_class.from_socket(socket)
                    wrapped_sockets.append(socket)
                # Queue 'poll' events that resolve the shared watcher on readiness.
                if mask & _zmq.POLLIN:
                    socket._add_recv_event('poll', future=watcher)
                if mask & _zmq.POLLOUT:
                    socket._add_send_event('poll', future=watcher)
            else:
                # Non-zmq object: translate the zmq mask to event-loop flags.
                raw_sockets.append(socket)
                evt = 0
                if mask & _zmq.POLLIN:
                    evt |= self._READ
                if mask & _zmq.POLLOUT:
                    evt |= self._WRITE
                self._watch_raw_socket(loop, socket, evt, wake_raw)

        def on_poll_ready(f):
            # Translate the watcher's outcome into the caller's Future.
            if future.done():
                return
            if watcher.cancelled():
                try:
                    future.cancel()
                except RuntimeError:
                    # RuntimeError may be called during teardown
                    pass
                return
            if watcher.exception():
                future.set_exception(watcher.exception())
            else:
                # Something is ready (or we timed out): take a real snapshot.
                # Explicit two-argument super(): zero-argument super() cannot
                # be used inside this nested function.
                try:
                    result = super(_AsyncPoller, self).poll(0)
                except Exception as e:
                    future.set_exception(e)
                else:
                    future.set_result(result)

        watcher.add_done_callback(on_poll_ready)

        if wrapped_sockets:
            watcher.add_done_callback(_clear_wrapper_io)

        if timeout is not None and timeout > 0:
            # schedule cancel to fire on poll timeout, if any
            def trigger_timeout():
                # A timeout resolves the watcher normally, so on_poll_ready
                # still returns a (possibly empty) poll(0) result.
                if not watcher.done():
                    watcher.set_result(None)

            # timeout is in milliseconds; call_later takes seconds.
            timeout_handle = loop.call_later(1e-3 * timeout, trigger_timeout)

            def cancel_timeout(f):
                # Handles exposing .cancel() are cancelled directly; otherwise
                # fall back to the loop's remove_timeout().
                if hasattr(timeout_handle, 'cancel'):
                    timeout_handle.cancel()
                else:
                    loop.remove_timeout(timeout_handle)

            future.add_done_callback(cancel_timeout)

        def cancel_watcher(f):
            # Caller's Future finished (possibly cancelled): stop the watcher.
            if not watcher.done():
                watcher.cancel()

        future.add_done_callback(cancel_watcher)

        return future
class _NoTimer:
    """Null-object stand-in for a timeout handle.

    Used (as the class itself, never instantiated) when no send/recv timeout
    is configured, so event handlers can call ``timer.cancel()`` unconditionally.
    """

    cancel = staticmethod(lambda: None)
# Type variable bound to _AsyncSocket so that alternate constructors such as
# ``from_socket`` are typed as returning the subclass they are called on.
T = TypeVar("T", bound="_AsyncSocket")
class _AsyncSocket(_Async, _zmq.Socket[Future]):
    # Warning: these class variables exist only so that super().__setattr__ can be
    # called. They are overridden with per-instance values in __init__ and are
    # NOT shared across sockets.
    _recv_futures = None  # becomes a deque of pending recv/poll _FutureEvents
    _send_futures = None  # becomes a deque of pending send/poll _FutureEvents
    _state = 0  # bitmask of POLLIN/POLLOUT currently being watched
    _shadow_sock: _zmq.Socket  # blocking Socket sharing this socket's underlying handle
    _poller_class = _AsyncPoller  # poller type used by poll()
    _fd = None  # shadow socket FD, cached so the loop handler can be removed after close
204 def __init__(
205 self,
206 context=None,
207 socket_type=-1,
208 io_loop=None,
209 _from_socket: _zmq.Socket | None = None,
210 **kwargs,
211 ) -> None:
212 if isinstance(context, _zmq.Socket):
213 context, _from_socket = (None, context)
214 if _from_socket is not None:
215 super().__init__(shadow=_from_socket.underlying) # type: ignore
216 self._shadow_sock = _from_socket
217 else:
218 super().__init__(context, socket_type, **kwargs) # type: ignore
219 self._shadow_sock = _zmq.Socket.shadow(self.underlying)
221 if io_loop is not None:
222 warnings.warn(
223 f"{self.__class__.__name__}(io_loop) argument is deprecated in pyzmq 22.2."
224 " The currently active loop will always be used.",
225 DeprecationWarning,
226 stacklevel=3,
227 )
228 self._recv_futures = deque()
229 self._send_futures = deque()
230 self._state = 0
231 self._fd = self._shadow_sock.FD
233 @classmethod
234 def from_socket(cls: type[T], socket: _zmq.Socket, io_loop: Any = None) -> T:
235 """Create an async socket from an existing Socket"""
236 return cls(_from_socket=socket, io_loop=io_loop)
238 def close(self, linger: int | None = None) -> None:
239 if not self.closed and self._fd is not None:
240 event_list: list[_FutureEvent] = list(
241 chain(self._recv_futures or [], self._send_futures or [])
242 )
243 for event in event_list:
244 if not event.future.done():
245 try:
246 event.future.cancel()
247 except RuntimeError:
248 # RuntimeError may be called during teardown
249 pass
250 self._clear_io_state()
251 super().close(linger=linger)
253 close.__doc__ = _zmq.Socket.close.__doc__
255 def get(self, key):
256 result = super().get(key)
257 if key == EVENTS:
258 self._schedule_remaining_events(result)
259 return result
261 get.__doc__ = _zmq.Socket.get.__doc__
263 @overload # type: ignore
264 def recv_multipart(
265 self, flags: int = 0, *, track: bool = False
266 ) -> Awaitable[list[bytes]]: ...
268 @overload
269 def recv_multipart(
270 self, flags: int = 0, *, copy: Literal[True], track: bool = False
271 ) -> Awaitable[list[bytes]]: ...
273 @overload
274 def recv_multipart(
275 self, flags: int = 0, *, copy: Literal[False], track: bool = False
276 ) -> Awaitable[list[_zmq.Frame]]: # type: ignore
277 ...
279 @overload
280 def recv_multipart(
281 self, flags: int = 0, copy: bool = True, track: bool = False
282 ) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ...
284 def recv_multipart(
285 self, flags: int = 0, copy: bool = True, track: bool = False
286 ) -> Awaitable[list[bytes] | list[_zmq.Frame]]:
287 """Receive a complete multipart zmq message.
289 Returns a Future whose result will be a multipart message.
290 """
291 return self._add_recv_event(
292 'recv_multipart', dict(flags=flags, copy=copy, track=track)
293 )
295 @overload # type: ignore
296 def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ...
298 @overload
299 def recv(
300 self, flags: int = 0, *, copy: Literal[True], track: bool = False
301 ) -> Awaitable[bytes]: ...
303 @overload
304 def recv(
305 self, flags: int = 0, *, copy: Literal[False], track: bool = False
306 ) -> Awaitable[_zmq.Frame]: ...
308 def recv( # type: ignore
309 self, flags: int = 0, copy: bool = True, track: bool = False
310 ) -> Awaitable[bytes | _zmq.Frame]:
311 """Receive a single zmq frame.
313 Returns a Future, whose result will be the received frame.
315 Recommend using recv_multipart instead.
316 """
317 return self._add_recv_event('recv', dict(flags=flags, copy=copy, track=track))
319 def send_multipart( # type: ignore
320 self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs
321 ) -> Awaitable[_zmq.MessageTracker | None]:
322 """Send a complete multipart zmq message.
324 Returns a Future that resolves when sending is complete.
325 """
326 kwargs['flags'] = flags
327 kwargs['copy'] = copy
328 kwargs['track'] = track
329 return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs)
331 def send( # type: ignore
332 self,
333 data: Any,
334 flags: int = 0,
335 copy: bool = True,
336 track: bool = False,
337 **kwargs: Any,
338 ) -> Awaitable[_zmq.MessageTracker | None]:
339 """Send a single zmq frame.
341 Returns a Future that resolves when sending is complete.
343 Recommend using send_multipart instead.
344 """
345 kwargs['flags'] = flags
346 kwargs['copy'] = copy
347 kwargs['track'] = track
348 kwargs.update(dict(flags=flags, copy=copy, track=track))
349 return self._add_send_event('send', msg=data, kwargs=kwargs)
351 def _deserialize(self, recvd, load):
352 """Deserialize with Futures"""
353 f = self._Future()
355 def _chain(_):
356 """Chain result through serialization to recvd"""
357 if f.done():
358 # chained future may be cancelled, which means nobody is going to get this result
359 # if it's an error, that's no big deal (probably zmq.Again),
360 # but if it's a successful recv, this is a dropped message!
361 if not recvd.cancelled() and recvd.exception() is None:
362 warnings.warn(
363 # is there a useful stacklevel?
364 # ideally, it would point to where `f.cancel()` was called
365 f"Future {f} completed while awaiting {recvd}. A message has been dropped!",
366 RuntimeWarning,
367 )
368 return
369 if recvd.exception():
370 f.set_exception(recvd.exception())
371 else:
372 buf = recvd.result()
373 try:
374 loaded = load(buf)
375 except Exception as e:
376 f.set_exception(e)
377 else:
378 f.set_result(loaded)
380 recvd.add_done_callback(_chain)
382 def _chain_cancel(_):
383 """Chain cancellation from f to recvd"""
384 if recvd.done():
385 return
386 if f.cancelled():
387 recvd.cancel()
389 f.add_done_callback(_chain_cancel)
391 return f
393 def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]: # type: ignore
394 """poll the socket for events
396 returns a Future for the poll results.
397 """
399 if self.closed:
400 raise _zmq.ZMQError(_zmq.ENOTSUP)
402 p = self._poller_class()
403 p.register(self, flags)
404 poll_future = cast(Future, p.poll(timeout))
406 future = self._Future()
408 def unwrap_result(f):
409 if future.done():
410 return
411 if poll_future.cancelled():
412 try:
413 future.cancel()
414 except RuntimeError:
415 # RuntimeError may be called during teardown
416 pass
417 return
418 if f.exception():
419 future.set_exception(poll_future.exception())
420 else:
421 evts = dict(poll_future.result())
422 future.set_result(evts.get(self, 0))
424 if poll_future.done():
425 # hook up result if already done
426 unwrap_result(poll_future)
427 else:
428 poll_future.add_done_callback(unwrap_result)
430 def cancel_poll(future):
431 """Cancel underlying poll if request has been cancelled"""
432 if not poll_future.done():
433 try:
434 poll_future.cancel()
435 except RuntimeError:
436 # RuntimeError may be called during teardown
437 pass
439 future.add_done_callback(cancel_poll)
441 return future
443 # overrides only necessary for updated types
444 def recv_string(self, *args, **kwargs) -> Awaitable[str]: # type: ignore
445 return super().recv_string(*args, **kwargs) # type: ignore
447 def send_string( # type: ignore
448 self, s: str, flags: int = 0, encoding: str = 'utf-8'
449 ) -> Awaitable[None]:
450 return super().send_string(s, flags=flags, encoding=encoding) # type: ignore
452 def _add_timeout(self, future, timeout):
453 """Add a timeout for a send or recv Future"""
455 def future_timeout():
456 if future.done():
457 # future already resolved, do nothing
458 return
460 # raise EAGAIN
461 future.set_exception(_zmq.Again())
463 return self._call_later(timeout, future_timeout)
465 def _call_later(self, delay, callback):
466 """Schedule a function to be called later
468 Override for different IOLoop implementations
470 Tornado and asyncio happen to both have ioloop.call_later
471 with the same signature.
472 """
473 return self._get_loop().call_later(delay, callback)
475 @staticmethod
476 def _remove_finished_future(future, event_list, event=None):
477 """Make sure that futures are removed from the event list when they resolve
479 Avoids delaying cleanup until the next send/recv event,
480 which may never come.
481 """
482 # "future" instance is shared between sockets, but each socket has its own event list.
483 if not event_list:
484 return
485 # only unconsumed events (e.g. cancelled calls)
486 # will be present when this happens
487 try:
488 event_list.remove(event)
489 except ValueError:
490 # usually this will have been removed by being consumed
491 return
    def _add_recv_event(self, kind, kwargs=None, future=None):
        """Add a recv event, returning the corresponding Future.

        Args:
            kind: 'recv', 'recv_multipart' or 'poll'.
            kwargs: keyword arguments for the shadow socket's recv method.
                Required for 'recv*' kinds (read with ``.get``); 'poll' passes None.
            future: existing Future to resolve (a poller's shared watcher);
                a new one is created when omitted.
        """
        f = future or self._Future()
        if kind.startswith('recv') and kwargs.get('flags', 0) & _zmq.DONTWAIT:
            # short-circuit non-blocking calls: try once, synchronously
            recv = getattr(self._shadow_sock, kind)
            try:
                r = recv(**kwargs)
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(r)
            return f

        # Honour RCVTIMEO (milliseconds; negative means no timeout).
        timer = _NoTimer
        if hasattr(_zmq, 'RCVTIMEO'):
            timeout_ms = self._shadow_sock.rcvtimeo
            if timeout_ms >= 0:
                timer = self._add_timeout(f, timeout_ms * 1e-3)

        # The timer above is only scheduled (via _call_later), so it cannot fire
        # before this event is queued. On expiry it fails the Future with
        # zmq.Again; the done-callback below then removes the queued event.
        _future_event = _FutureEvent(f, kind, kwargs, msg=None, timer=timer)
        self._recv_futures.append(_future_event)

        if self._shadow_sock.get(EVENTS) & POLLIN:
            # recv immediately, if we can
            self._handle_recv()
        if self._recv_futures and _future_event in self._recv_futures:
            # Don't let the Future sit in _recv_events after it's done
            # no need to register this if we've already been handled
            # (i.e. immediately-resolved recv)
            f.add_done_callback(
                partial(
                    self._remove_finished_future,
                    event_list=self._recv_futures,
                    event=_future_event,
                )
            )
            # still pending: start watching for readability
            self._add_io_state(POLLIN)
        return f
    def _add_send_event(self, kind, msg=None, kwargs=None, future=None):
        """Add a send event, returning the corresponding Future.

        Args:
            kind: 'send', 'send_multipart' or 'poll'.
            msg: payload to send (None for 'poll').
            kwargs: keyword arguments for the shadow socket's send method.
                Required for 'send*' kinds; 'poll' passes None.
            future: existing Future to resolve (a poller's shared watcher);
                a new one is created when omitted.
        """
        f = future or self._Future()
        # attempt send with DONTWAIT if no futures are waiting
        # short-circuit for sends that will resolve immediately
        # only call if no send Futures are waiting (preserves message order)
        if kind in ('send', 'send_multipart') and not self._send_futures:
            flags = kwargs.get('flags', 0)
            nowait_kwargs = kwargs.copy()
            nowait_kwargs['flags'] = flags | _zmq.DONTWAIT

            # short-circuit non-blocking calls
            send = getattr(self._shadow_sock, kind)
            # track if the send resolved or not
            # (EAGAIN without a caller-requested DONTWAIT means: fall through
            # and queue the send)
            finish_early = True
            try:
                r = send(msg, **nowait_kwargs)
            except _zmq.Again as e:
                if flags & _zmq.DONTWAIT:
                    # caller asked for non-blocking: report EAGAIN as-is
                    f.set_exception(e)
                else:
                    # EAGAIN raised and DONTWAIT not requested,
                    # proceed with async send
                    finish_early = False
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(r)

            if finish_early:
                # short-circuit resolved, return finished Future
                # schedule wake for recv if there are any receivers waiting
                # (presumably because a send may make the socket readable)
                if self._recv_futures:
                    self._schedule_remaining_events()
                return f

        # Honour SNDTIMEO (milliseconds; negative means no timeout).
        timer = _NoTimer
        if hasattr(_zmq, 'SNDTIMEO'):
            timeout_ms = self._shadow_sock.get(_zmq.SNDTIMEO)
            if timeout_ms >= 0:
                timer = self._add_timeout(f, timeout_ms * 1e-3)

        # The timer above is only scheduled (via _call_later), so it cannot fire
        # before this event is queued in _send_futures. On expiry it fails the
        # Future with zmq.Again; the done-callback below removes the event.
        _future_event = _FutureEvent(f, kind, kwargs=kwargs, msg=msg, timer=timer)
        self._send_futures.append(_future_event)
        # Don't let the Future sit in _send_futures after it's done
        f.add_done_callback(
            partial(
                self._remove_finished_future,
                event_list=self._send_futures,
                event=_future_event,
            )
        )

        self._add_io_state(POLLOUT)
        return f
594 def _handle_recv(self):
595 """Handle recv events"""
596 if not self._shadow_sock.get(EVENTS) & POLLIN:
597 # event triggered, but state may have been changed between trigger and callback
598 return
599 f = None
600 while self._recv_futures:
601 f, kind, kwargs, _, timer = self._recv_futures.popleft()
602 # skip any cancelled futures
603 if f.done():
604 f = None
605 else:
606 break
608 if not self._recv_futures:
609 self._drop_io_state(POLLIN)
611 if f is None:
612 return
614 timer.cancel()
616 if kind == 'poll':
617 # on poll event, just signal ready, nothing else.
618 f.set_result(None)
619 return
620 elif kind == 'recv_multipart':
621 recv = self._shadow_sock.recv_multipart
622 elif kind == 'recv':
623 recv = self._shadow_sock.recv
624 else:
625 raise ValueError("Unhandled recv event type: %r" % kind)
627 kwargs['flags'] |= _zmq.DONTWAIT
628 try:
629 result = recv(**kwargs)
630 except Exception as e:
631 f.set_exception(e)
632 else:
633 f.set_result(result)
635 def _handle_send(self):
636 if not self._shadow_sock.get(EVENTS) & POLLOUT:
637 # event triggered, but state may have been changed between trigger and callback
638 return
639 f = None
640 while self._send_futures:
641 f, kind, kwargs, msg, timer = self._send_futures.popleft()
642 # skip any cancelled futures
643 if f.done():
644 f = None
645 else:
646 break
648 if not self._send_futures:
649 self._drop_io_state(POLLOUT)
651 if f is None:
652 return
654 timer.cancel()
656 if kind == 'poll':
657 # on poll event, just signal ready, nothing else.
658 f.set_result(None)
659 return
660 elif kind == 'send_multipart':
661 send = self._shadow_sock.send_multipart
662 elif kind == 'send':
663 send = self._shadow_sock.send
664 else:
665 raise ValueError("Unhandled send event type: %r" % kind)
667 kwargs['flags'] |= _zmq.DONTWAIT
668 try:
669 result = send(msg, **kwargs)
670 except Exception as e:
671 f.set_exception(e)
672 else:
673 f.set_result(result)
675 # event masking from ZMQStream
676 def _handle_events(self, fd=0, events=0):
677 """Dispatch IO events to _handle_recv, etc."""
678 if self._shadow_sock.closed:
679 return
681 zmq_events = self._shadow_sock.get(EVENTS)
682 if zmq_events & _zmq.POLLIN:
683 self._handle_recv()
684 if zmq_events & _zmq.POLLOUT:
685 self._handle_send()
686 self._schedule_remaining_events()
688 def _schedule_remaining_events(self, events=None):
689 """Schedule a call to handle_events next loop iteration
691 If there are still events to handle.
692 """
693 # edge-triggered handling
694 # allow passing events in, in case this is triggered by retrieving events,
695 # so we don't have to retrieve it twice.
696 if self._state == 0:
697 # not watching for anything, nothing to schedule
698 return
699 if events is None:
700 events = self._shadow_sock.get(EVENTS)
701 if events & self._state:
702 self._call_later(0, self._handle_events)
704 def _add_io_state(self, state):
705 """Add io_state to poller."""
706 if self._state != state:
707 state = self._state = self._state | state
708 self._update_handler(self._state)
710 def _drop_io_state(self, state):
711 """Stop poller from watching an io_state."""
712 if self._state & state:
713 self._state = self._state & (~state)
714 self._update_handler(self._state)
716 def _update_handler(self, state):
717 """Update IOLoop handler with state.
719 zmq FD is always read-only.
720 """
721 # ensure loop is registered and init_io has been called
722 # if there are any events to watch for
723 if state:
724 self._get_loop()
725 self._schedule_remaining_events()
727 def _init_io_state(self, loop=None):
728 """initialize the ioloop event handler"""
729 if loop is None:
730 loop = self._get_loop()
731 loop.add_handler(self._shadow_sock, self._handle_events, self._READ)
732 self._call_later(0, self._handle_events)
734 def _clear_io_state(self):
735 """unregister the ioloop event handler
737 called once during close
738 """
739 fd = self._shadow_sock
740 if self._shadow_sock.closed:
741 fd = self._fd
742 if self._current_loop is not None:
743 self._current_loop.remove_handler(fd)