Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/zmq/_future.py: 19%
395 statements
« prev ^ index » next coverage.py v7.3.3, created at 2023-12-15 06:13 +0000
« prev ^ index » next coverage.py v7.3.3, created at 2023-12-15 06:13 +0000
1"""Future-returning APIs for coroutines."""
3# Copyright (c) PyZMQ Developers.
4# Distributed under the terms of the Modified BSD License.
6import warnings
7from asyncio import Future
8from collections import deque
9from itertools import chain
10from typing import (
11 Any,
12 Awaitable,
13 Callable,
14 Dict,
15 List,
16 NamedTuple,
17 Optional,
18 Tuple,
19 Type,
20 TypeVar,
21 Union,
22 cast,
23 overload,
24)
26import zmq as _zmq
27from zmq import EVENTS, POLLIN, POLLOUT
28from zmq._typing import Literal
31class _FutureEvent(NamedTuple):
32 future: Future
33 kind: str
34 kwargs: Dict
35 msg: Any
36 timer: Any
39# These are incomplete classes and need a Mixin for compatibility with an eventloop
40# defining the following attributes:
41#
42# _Future
43# _READ
44# _WRITE
45# _default_loop()
48class _Async:
49 """Mixin for common async logic"""
51 _current_loop: Any = None
52 _Future: Type[Future]
54 def _get_loop(self) -> Any:
55 """Get event loop
57 Notice if event loop has changed,
58 and register init_io_state on activation of a new event loop
59 """
60 if self._current_loop is None:
61 self._current_loop = self._default_loop()
62 self._init_io_state(self._current_loop)
63 return self._current_loop
64 current_loop = self._default_loop()
65 if current_loop is not self._current_loop:
66 # warn? This means a socket is being used in multiple loops!
67 self._current_loop = current_loop
68 self._init_io_state(current_loop)
69 return current_loop
71 def _default_loop(self) -> Any:
72 raise NotImplementedError("Must be implemented in a subclass")
74 def _init_io_state(self, loop=None) -> None:
75 pass
class _AsyncPoller(_Async, _zmq.Poller):
    """Poller that returns a Future on poll, instead of blocking.

    zmq sockets registered with it are watched through their async socket
    class; any other registered objects ("raw sockets") are watched through
    the event-loop-specific ``_watch_raw_socket`` hook.
    """

    # async socket class used to wrap plain (blocking) zmq.Socket instances
    _socket_class: Type["_AsyncSocket"]
    # event-loop flag values for read/write readiness on raw sockets
    _READ: int
    _WRITE: int
    raw_sockets: List[Any]

    def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None:
        """Schedule callback for a raw socket

        Must be implemented by the event-loop mixin: call ``f`` when ``socket``
        is ready for the ``_READ``/``_WRITE`` flags in ``evt``.
        """
        raise NotImplementedError()

    def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None:
        """Unschedule callback for a raw socket

        Must be implemented by the event-loop mixin.
        """
        raise NotImplementedError()

    def poll(self, timeout=-1) -> Awaitable[List[Tuple[Any, int]]]:  # type: ignore
        """Return a Future for a poll event

        timeout is in milliseconds. 0 polls synchronously and returns an
        already-resolved Future; a positive value resolves the Future with the
        result of a non-blocking poll once it elapses; None or a negative
        value waits until some registered socket becomes ready.

        The Future resolves to the list returned by ``zmq.Poller.poll(0)``.
        Cancelling it stops the internal watches.
        """
        future = self._Future()
        if timeout == 0:
            # nothing to wait for: do a non-blocking poll right away
            try:
                result = super().poll(0)
            except Exception as e:
                future.set_exception(e)
            else:
                future.set_result(result)
            return future

        loop = self._get_loop()

        # register Future to be called as soon as any event is available on any socket
        watcher = self._Future()

        # watch raw sockets:
        raw_sockets: List[Any] = []

        def wake_raw(*args):
            # readiness callback for raw sockets: wake the watcher once
            if not watcher.done():
                watcher.set_result(None)

        # registered before on_poll_ready, so raw-socket watches are removed
        # first whenever the watcher finishes (result, exception or cancel)
        watcher.add_done_callback(
            lambda f: self._unwatch_raw_sockets(loop, *raw_sockets)
        )

        for socket, mask in self.sockets:
            if isinstance(socket, _zmq.Socket):
                if not isinstance(socket, self._socket_class):
                    # it's a blocking zmq.Socket, wrap it in async
                    socket = self._socket_class.from_socket(socket)
                # the same watcher Future is shared by every registered socket;
                # whichever socket becomes ready first resolves it
                if mask & _zmq.POLLIN:
                    socket._add_recv_event('poll', future=watcher)
                if mask & _zmq.POLLOUT:
                    socket._add_send_event('poll', future=watcher)
            else:
                raw_sockets.append(socket)
                # translate zmq poll flags into the event loop's flags
                evt = 0
                if mask & _zmq.POLLIN:
                    evt |= self._READ
                if mask & _zmq.POLLOUT:
                    evt |= self._WRITE
                self._watch_raw_socket(loop, socket, evt, wake_raw)

        def on_poll_ready(f):
            # Propagate the watcher's outcome to the caller's Future.
            if future.done():
                return
            if watcher.cancelled():
                try:
                    future.cancel()
                except RuntimeError:
                    # RuntimeError may be raised during teardown
                    pass
                return
            if watcher.exception():
                # e.g. a socket-level timeout from _add_timeout on the shared watcher
                future.set_exception(watcher.exception())
            else:
                # something is ready (or the timeout elapsed): collect the
                # actual events with a non-blocking poll
                try:
                    result = super(_AsyncPoller, self).poll(0)
                except Exception as e:
                    future.set_exception(e)
                else:
                    future.set_result(result)

        watcher.add_done_callback(on_poll_ready)

        if timeout is not None and timeout > 0:
            # on poll timeout, wake the watcher so on_poll_ready reports
            # whatever a non-blocking poll finds at that moment
            def trigger_timeout():
                if not watcher.done():
                    watcher.set_result(None)

            timeout_handle = loop.call_later(1e-3 * timeout, trigger_timeout)

            def cancel_timeout(f):
                # handles without .cancel() are removed via loop.remove_timeout
                if hasattr(timeout_handle, 'cancel'):
                    timeout_handle.cancel()
                else:
                    loop.remove_timeout(timeout_handle)

            future.add_done_callback(cancel_timeout)

        def cancel_watcher(f):
            # if the caller's Future finishes first (e.g. it was cancelled),
            # stop waiting on the sockets
            if not watcher.done():
                watcher.cancel()

        future.add_done_callback(cancel_watcher)

        return future
187class _NoTimer:
188 @staticmethod
189 def cancel():
190 pass
# TypeVar for alternate constructors (e.g. ``_AsyncSocket.from_socket``) so they
# are typed as returning the concrete subclass they are called on.
T = TypeVar("T", bound="_AsyncSocket")
class _AsyncSocket(_Async, _zmq.Socket[Future]):
    """zmq Socket whose send/recv/poll methods return Futures instead of blocking.

    Pending operations are queued as ``_FutureEvent`` records in
    ``_recv_futures`` / ``_send_futures`` and completed by ``_handle_events``
    when the shadow socket's ``EVENTS`` report readiness. Like ``_Async``, it
    relies on an event-loop mixin for ``_Future``, ``_READ``, ``_WRITE`` and
    ``_default_loop()``.
    """

    # Warning: these class variables only exist so that super().__setattr__ can be
    # called before __init__ has run. They are overridden per instance in __init__
    # and are not shared across the class.
    _recv_futures = None
    _send_futures = None
    _state = 0
    _shadow_sock: "_zmq.Socket"
    _poller_class = _AsyncPoller
    _fd = None

    def __init__(
        self,
        context=None,
        socket_type=-1,
        io_loop=None,
        _from_socket: Optional["_zmq.Socket"] = None,
        **kwargs,
    ) -> None:
        """Create a new async socket, or wrap an existing one.

        Args:
            context: zmq Context to create the socket in. A ``zmq.Socket``
                passed here is treated as ``_from_socket``.
            socket_type: zmq socket type for a newly created socket.
            io_loop: deprecated; only triggers a DeprecationWarning.
            _from_socket: existing socket to wrap instead of creating one.
            **kwargs: forwarded to ``zmq.Socket.__init__`` for new sockets.
        """
        if isinstance(context, _zmq.Socket):
            context, _from_socket = (None, context)
        if _from_socket is not None:
            super().__init__(shadow=_from_socket.underlying)  # type: ignore
            self._shadow_sock = _from_socket
        else:
            super().__init__(context, socket_type, **kwargs)  # type: ignore
            # plain zmq.Socket shadowing the same underlying socket; the
            # handlers below use it for the actual non-blocking calls
            self._shadow_sock = _zmq.Socket.shadow(self.underlying)

        if io_loop is not None:
            warnings.warn(
                f"{self.__class__.__name__}(io_loop) argument is deprecated in pyzmq 22.2."
                " The currently active loop will always be used.",
                DeprecationWarning,
                stacklevel=3,
            )
        self._recv_futures = deque()
        self._send_futures = deque()
        self._state = 0
        # cached so _clear_io_state can still unregister after the socket is closed
        self._fd = self._shadow_sock.FD

    @classmethod
    def from_socket(cls: Type[T], socket: "_zmq.Socket", io_loop: Any = None) -> T:
        """Create an async socket from an existing Socket"""
        return cls(_from_socket=socket, io_loop=io_loop)

    def close(self, linger: Optional[int] = None) -> None:
        # Cancel every pending Future, unregister from the event loop, then
        # close the socket. (Docstring is copied from zmq.Socket.close below.)
        if not self.closed and self._fd is not None:
            # snapshot first: cancelling runs done-callbacks that may remove
            # entries from the queues being iterated
            event_list: List[_FutureEvent] = list(
                chain(self._recv_futures or [], self._send_futures or [])
            )
            for event in event_list:
                if not event.future.done():
                    try:
                        event.future.cancel()
                    except RuntimeError:
                        # RuntimeError may be raised during teardown
                        pass
            self._clear_io_state()
        super().close(linger=linger)

    close.__doc__ = _zmq.Socket.close.__doc__

    def get(self, key):
        result = super().get(key)
        if key == EVENTS:
            # the socket's readiness notifications are handled edge-triggered,
            # so after anyone reads EVENTS, make sure already-ready work is scheduled
            self._schedule_remaining_events(result)
        return result

    get.__doc__ = _zmq.Socket.get.__doc__

    @overload  # type: ignore
    def recv_multipart(
        self, flags: int = 0, *, track: bool = False
    ) -> Awaitable[List[bytes]]:
        ...

    @overload
    def recv_multipart(
        self, flags: int = 0, *, copy: Literal[True], track: bool = False
    ) -> Awaitable[List[bytes]]:
        ...

    @overload
    def recv_multipart(
        self, flags: int = 0, *, copy: Literal[False], track: bool = False
    ) -> Awaitable[List[_zmq.Frame]]:  # type: ignore
        ...

    @overload
    def recv_multipart(
        self, flags: int = 0, copy: bool = True, track: bool = False
    ) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
        ...

    def recv_multipart(
        self, flags: int = 0, copy: bool = True, track: bool = False
    ) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
        """Receive a complete multipart zmq message.

        Returns a Future whose result will be a multipart message.
        """
        return self._add_recv_event(
            'recv_multipart', dict(flags=flags, copy=copy, track=track)
        )

    def recv(  # type: ignore
        self, flags: int = 0, copy: bool = True, track: bool = False
    ) -> Awaitable[Union[bytes, _zmq.Frame]]:
        """Receive a single zmq frame.

        Returns a Future, whose result will be the received frame.

        Recommend using recv_multipart instead.
        """
        return self._add_recv_event('recv', dict(flags=flags, copy=copy, track=track))

    def send_multipart(  # type: ignore
        self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs
    ) -> Awaitable[Optional[_zmq.MessageTracker]]:
        """Send a complete multipart zmq message.

        Returns a Future that resolves when sending is complete.
        """
        kwargs['flags'] = flags
        kwargs['copy'] = copy
        kwargs['track'] = track
        return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs)

    def send(  # type: ignore
        self,
        data: Any,
        flags: int = 0,
        copy: bool = True,
        track: bool = False,
        **kwargs: Any,
    ) -> Awaitable[Optional[_zmq.MessageTracker]]:
        """Send a single zmq frame.

        Returns a Future that resolves when sending is complete.

        Recommend using send_multipart instead.
        """
        # (the former extra kwargs.update() re-set these same three keys)
        kwargs['flags'] = flags
        kwargs['copy'] = copy
        kwargs['track'] = track
        return self._add_send_event('send', msg=data, kwargs=kwargs)

    def _deserialize(self, recvd, load):
        """Deserialize with Futures

        Returns a new Future resolving to ``load(recvd.result())`` (or to the
        exception from ``recvd`` or ``load``). Cancelling the returned Future
        cancels ``recvd`` if it is still pending.
        """
        f = self._Future()

        def _chain(_):
            """Chain result through serialization to recvd"""
            if f.done():
                return
            if recvd.exception():
                f.set_exception(recvd.exception())
            else:
                buf = recvd.result()
                try:
                    loaded = load(buf)
                except Exception as e:
                    f.set_exception(e)
                else:
                    f.set_result(loaded)

        recvd.add_done_callback(_chain)

        def _chain_cancel(_):
            """Chain cancellation from f to recvd"""
            if recvd.done():
                return
            if f.cancelled():
                recvd.cancel()

        f.add_done_callback(_chain_cancel)

        return f

    def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]:  # type: ignore
        """poll the socket for events

        returns a Future for the poll results: the event mask reported for
        this socket (0 if none). Raises ZMQError(ENOTSUP) if the socket is closed.
        """

        if self.closed:
            raise _zmq.ZMQError(_zmq.ENOTSUP)

        p = self._poller_class()
        p.register(self, flags)
        poll_future = cast(Future, p.poll(timeout))

        future = self._Future()

        def unwrap_result(f):
            # translate the poller's [(socket, events)] result into this socket's mask
            if future.done():
                return
            if poll_future.cancelled():
                try:
                    future.cancel()
                except RuntimeError:
                    # RuntimeError may be raised during teardown
                    pass
                return
            if f.exception():
                future.set_exception(poll_future.exception())
            else:
                evts = dict(poll_future.result())
                future.set_result(evts.get(self, 0))

        if poll_future.done():
            # hook up result if already done
            unwrap_result(poll_future)
        else:
            poll_future.add_done_callback(unwrap_result)

        def cancel_poll(future):
            """Cancel underlying poll if request has been cancelled"""
            if not poll_future.done():
                try:
                    poll_future.cancel()
                except RuntimeError:
                    # RuntimeError may be raised during teardown
                    pass

        future.add_done_callback(cancel_poll)

        return future

    # overrides only necessary for updated types
    def recv_string(self, *args, **kwargs) -> Awaitable[str]:  # type: ignore
        return super().recv_string(*args, **kwargs)  # type: ignore

    def send_string(self, s: str, flags: int = 0, encoding: str = 'utf-8') -> Awaitable[None]:  # type: ignore
        return super().send_string(s, flags=flags, encoding=encoding)  # type: ignore

    def _add_timeout(self, future, timeout):
        """Add a timeout for a send or recv Future

        After ``timeout`` seconds, fail ``future`` with ``zmq.Again`` unless it
        has already resolved. Returns the handle from ``_call_later``.
        """

        def future_timeout():
            if future.done():
                # future already resolved, do nothing
                return

            # raise EAGAIN
            future.set_exception(_zmq.Again())

        return self._call_later(timeout, future_timeout)

    def _call_later(self, delay, callback):
        """Schedule a function to be called later

        Override for different IOLoop implementations

        Tornado and asyncio happen to both have ioloop.call_later
        with the same signature.
        """
        return self._get_loop().call_later(delay, callback)

    @staticmethod
    def _remove_finished_future(future, event_list):
        """Make sure that futures are removed from the event list when they resolve

        Avoids delaying cleanup until the next send/recv event,
        which may never come.
        """
        for f_idx, event in enumerate(event_list):
            if event.future is future:
                break
        else:
            return

        # "future" instance is shared between sockets, but each socket has its own event list.
        # Delete by position: avoids a second O(n) scan and any reliance on tuple equality.
        del event_list[f_idx]

    def _add_recv_event(self, kind, kwargs=None, future=None):
        """Add a recv event, returning the corresponding Future

        ``kind`` is 'recv', 'recv_multipart' or 'poll'. If ``future`` is given,
        it is resolved instead of a new one (the poller shares one Future
        across sockets).
        """
        f = future or self._Future()
        if kind.startswith('recv') and kwargs.get('flags', 0) & _zmq.DONTWAIT:
            # short-circuit non-blocking calls
            recv = getattr(self._shadow_sock, kind)
            try:
                r = recv(**kwargs)
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(r)
            return f

        timer = _NoTimer
        if hasattr(_zmq, 'RCVTIMEO'):
            timeout_ms = self._shadow_sock.rcvtimeo
            if timeout_ms >= 0:
                timer = self._add_timeout(f, timeout_ms * 1e-3)

        # The timeout is scheduled via call_later, so it cannot fire before the
        # event is queued here; when it fires, the done-callback below removes it.
        self._recv_futures.append(_FutureEvent(f, kind, kwargs, msg=None, timer=timer))

        # Don't let the Future sit in _recv_futures after it's done
        f.add_done_callback(
            lambda f: self._remove_finished_future(f, self._recv_futures)
        )

        if self._shadow_sock.get(EVENTS) & POLLIN:
            # recv immediately, if we can
            self._handle_recv()
        if self._recv_futures:
            self._add_io_state(POLLIN)
        return f

    def _add_send_event(self, kind, msg=None, kwargs=None, future=None):
        """Add a send event, returning the corresponding Future

        ``kind`` is 'send', 'send_multipart' or 'poll'. Sends are first
        attempted immediately with DONTWAIT when nothing is queued ahead of
        them, preserving message order.
        """
        f = future or self._Future()
        # attempt send with DONTWAIT if no futures are waiting
        # short-circuit for sends that will resolve immediately
        # only call if no send Futures are waiting
        if kind in ('send', 'send_multipart') and not self._send_futures:
            flags = kwargs.get('flags', 0)
            nowait_kwargs = kwargs.copy()
            nowait_kwargs['flags'] = flags | _zmq.DONTWAIT

            # short-circuit non-blocking calls
            send = getattr(self._shadow_sock, kind)
            # track if the send resolved or not
            # (EAGAIN without a caller-requested DONTWAIT falls through to an async send)
            finish_early = True
            try:
                r = send(msg, **nowait_kwargs)
            except _zmq.Again as e:
                if flags & _zmq.DONTWAIT:
                    f.set_exception(e)
                else:
                    # EAGAIN raised and DONTWAIT not requested,
                    # proceed with async send
                    finish_early = False
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(r)

            if finish_early:
                # short-circuit resolved, return finished Future
                # schedule wake for recv if there are any receivers waiting
                if self._recv_futures:
                    self._schedule_remaining_events()
                return f

        timer = _NoTimer
        if hasattr(_zmq, 'SNDTIMEO'):
            timeout_ms = self._shadow_sock.get(_zmq.SNDTIMEO)
            if timeout_ms >= 0:
                timer = self._add_timeout(f, timeout_ms * 1e-3)

        # The timeout is scheduled via call_later, so it cannot fire before the
        # event is queued here; when it fires, the done-callback below removes it
        # from _send_futures.
        self._send_futures.append(
            _FutureEvent(f, kind, kwargs=kwargs, msg=msg, timer=timer)
        )
        # Don't let the Future sit in _send_futures after it's done
        f.add_done_callback(
            lambda f: self._remove_finished_future(f, self._send_futures)
        )

        self._add_io_state(POLLOUT)
        return f

    def _handle_recv(self):
        """Handle recv events

        Completes the oldest still-pending recv/poll Future, if the socket is
        readable. Handles at most one event per call.
        """
        if not self._shadow_sock.get(EVENTS) & POLLIN:
            # event triggered, but state may have been changed between trigger and callback
            return
        f = None
        while self._recv_futures:
            f, kind, kwargs, _, timer = self._recv_futures.popleft()
            # skip any cancelled futures
            if f.done():
                f = None
            else:
                break

        if not self._recv_futures:
            self._drop_io_state(POLLIN)

        if f is None:
            return

        timer.cancel()

        if kind == 'poll':
            # on poll event, just signal ready, nothing else.
            f.set_result(None)
            return
        elif kind == 'recv_multipart':
            recv = self._shadow_sock.recv_multipart
        elif kind == 'recv':
            recv = self._shadow_sock.recv
        else:
            raise ValueError("Unhandled recv event type: %r" % kind)

        kwargs['flags'] |= _zmq.DONTWAIT
        try:
            result = recv(**kwargs)
        except Exception as e:
            f.set_exception(e)
        else:
            f.set_result(result)

    def _handle_send(self):
        """Handle send events

        Completes the oldest still-pending send/poll Future, if the socket is
        writable. Handles at most one event per call.
        """
        if not self._shadow_sock.get(EVENTS) & POLLOUT:
            # event triggered, but state may have been changed between trigger and callback
            return
        f = None
        while self._send_futures:
            f, kind, kwargs, msg, timer = self._send_futures.popleft()
            # skip any cancelled futures
            if f.done():
                f = None
            else:
                break

        if not self._send_futures:
            self._drop_io_state(POLLOUT)

        if f is None:
            return

        timer.cancel()

        if kind == 'poll':
            # on poll event, just signal ready, nothing else.
            f.set_result(None)
            return
        elif kind == 'send_multipart':
            send = self._shadow_sock.send_multipart
        elif kind == 'send':
            send = self._shadow_sock.send
        else:
            raise ValueError("Unhandled send event type: %r" % kind)

        kwargs['flags'] |= _zmq.DONTWAIT
        try:
            result = send(msg, **kwargs)
        except Exception as e:
            f.set_exception(e)
        else:
            f.set_result(result)

    # event masking from ZMQStream
    def _handle_events(self, fd=0, events=0):
        """Dispatch IO events to _handle_recv, etc.

        ``fd`` and ``events`` are accepted for event-loop callback
        compatibility but ignored; readiness is re-read from ``EVENTS``.
        """
        if self._shadow_sock.closed:
            return

        zmq_events = self._shadow_sock.get(EVENTS)
        if zmq_events & _zmq.POLLIN:
            self._handle_recv()
        if zmq_events & _zmq.POLLOUT:
            self._handle_send()
        self._schedule_remaining_events()

    def _schedule_remaining_events(self, events=None):
        """Schedule a call to handle_events next loop iteration

        If there are still events to handle.
        """
        # edge-triggered handling
        # allow passing events in, in case this is triggered by retrieving events,
        # so we don't have to retrieve it twice.
        if self._state == 0:
            # not watching for anything, nothing to schedule
            return
        if events is None:
            events = self._shadow_sock.get(EVENTS)
        if events & self._state:
            self._call_later(0, self._handle_events)

    def _add_io_state(self, state):
        """Add io_state to poller.

        ``state`` is a POLLIN/POLLOUT mask OR-ed into the watched state.
        """
        # NOTE(review): this re-runs _update_handler even when every bit of
        # `state` is already being watched (e.g. _state == POLLIN|POLLOUT);
        # the extra rescheduling is kept as-is.
        if self._state != state:
            self._state |= state
            self._update_handler(self._state)

    def _drop_io_state(self, state):
        """Stop poller from watching an io_state."""
        if self._state & state:
            self._state = self._state & (~state)
            self._update_handler(self._state)

    def _update_handler(self, state):
        """Update IOLoop handler with state.

        zmq FD is always read-only.
        """
        # ensure loop is registered and init_io has been called
        # if there are any events to watch for
        if state:
            self._get_loop()
        self._schedule_remaining_events()

    def _init_io_state(self, loop=None):
        """initialize the ioloop event handler

        Registers ``_handle_events`` for read readiness on the shadow socket
        and schedules one immediate pass to pick up already-pending events.
        """
        if loop is None:
            loop = self._get_loop()
        loop.add_handler(self._shadow_sock, self._handle_events, self._READ)
        self._call_later(0, self._handle_events)

    def _clear_io_state(self):
        """unregister the ioloop event handler

        called once during close
        """
        fd = self._shadow_sock
        if self._shadow_sock.closed:
            # a closed socket can no longer be used as the handler key;
            # fall back to the FD cached in __init__
            fd = self._fd
        if self._current_loop is not None:
            self._current_loop.remove_handler(fd)