Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/zmq/_future.py: 20%
388 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-01 06:54 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-01 06:54 +0000
1"""Future-returning APIs for coroutines."""
3# Copyright (c) PyZMQ Developers.
4# Distributed under the terms of the Modified BSD License.
6import warnings
7from asyncio import Future
8from collections import deque
9from itertools import chain
10from typing import (
11 Any,
12 Awaitable,
13 Callable,
14 Dict,
15 List,
16 NamedTuple,
17 Optional,
18 Tuple,
19 Type,
20 TypeVar,
21 Union,
22 cast,
23 overload,
24)
26import zmq as _zmq
27from zmq import EVENTS, POLLIN, POLLOUT
28from zmq._typing import Literal
31class _FutureEvent(NamedTuple):
32 future: Future
33 kind: str
34 kwargs: Dict
35 msg: Any
36 timer: Any
# These classes are incomplete: they need a mixin that provides compatibility
# with an event loop by defining the following attributes:
#
# _Future
# _READ
# _WRITE
# _default_loop()
48class _Async:
49 """Mixin for common async logic"""
51 _current_loop: Any = None
52 _Future: Type[Future]
54 def _get_loop(self) -> Any:
55 """Get event loop
57 Notice if event loop has changed,
58 and register init_io_state on activation of a new event loop
59 """
60 if self._current_loop is None:
61 self._current_loop = self._default_loop()
62 self._init_io_state(self._current_loop)
63 return self._current_loop
64 current_loop = self._default_loop()
65 if current_loop is not self._current_loop:
66 # warn? This means a socket is being used in multiple loops!
67 self._current_loop = current_loop
68 self._init_io_state(current_loop)
69 return current_loop
71 def _default_loop(self) -> Any:
72 raise NotImplementedError("Must be implemented in a subclass")
74 def _init_io_state(self, loop=None) -> None:
75 pass
class _AsyncPoller(_Async, _zmq.Poller):
    """Poller whose ``poll`` returns a Future rather than blocking."""

    _socket_class: Type["_AsyncSocket"]
    _READ: int
    _WRITE: int
    raw_sockets: List[Any]

    def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None:
        """Schedule callback for a raw socket"""
        raise NotImplementedError()

    def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None:
        """Unschedule callback for a raw socket"""
        raise NotImplementedError()

    def poll(self, timeout=-1) -> Awaitable[List[Tuple[Any, int]]]:  # type: ignore
        """Return a Future for a poll event.

        ``timeout == 0`` performs a single non-blocking poll and returns an
        already-resolved Future. A positive ``timeout`` is scaled by ``1e-3``
        and passed to ``loop.call_later``; when it fires, the Future resolves
        with whatever a non-blocking poll reports at that moment.
        """
        result_future = self._Future()

        def finish_with_poll():
            # resolve result_future from one non-blocking poll of the base Poller
            try:
                ready = super(_AsyncPoller, self).poll(0)
            except Exception as exc:
                result_future.set_exception(exc)
            else:
                result_future.set_result(ready)

        if timeout == 0:
            finish_with_poll()
            return result_future

        loop = self._get_loop()

        # `watcher` fires as soon as any event is available on any socket
        watcher = self._Future()

        # non-zmq sockets that have to be watched through the event loop
        raw_sockets: List[Any] = []

        def wake_watcher(*_args):
            if not watcher.done():
                watcher.set_result(None)

        watcher.add_done_callback(
            lambda _: self._unwatch_raw_sockets(loop, *raw_sockets)
        )

        for sock, mask in self.sockets:
            if not isinstance(sock, _zmq.Socket):
                # raw socket: translate the zmq mask into loop read/write flags
                raw_sockets.append(sock)
                raw_events = 0
                if mask & _zmq.POLLIN:
                    raw_events |= self._READ
                if mask & _zmq.POLLOUT:
                    raw_events |= self._WRITE
                self._watch_raw_socket(loop, sock, raw_events, wake_watcher)
                continue

            if not isinstance(sock, self._socket_class):
                # it's a blocking zmq.Socket, wrap it in async
                sock = self._socket_class.from_socket(sock)
            if mask & _zmq.POLLIN:
                sock._add_recv_event('poll', future=watcher)
            if mask & _zmq.POLLOUT:
                sock._add_send_event('poll', future=watcher)

        def on_poll_ready(_):
            if result_future.done():
                return
            if watcher.cancelled():
                try:
                    result_future.cancel()
                except RuntimeError:
                    # RuntimeError may be called during teardown
                    pass
                return
            failure = watcher.exception()
            if failure:
                result_future.set_exception(failure)
            else:
                finish_with_poll()

        watcher.add_done_callback(on_poll_ready)

        if timeout is not None and timeout > 0:
            # wake the watcher when the poll timeout elapses
            timeout_handle = loop.call_later(1e-3 * timeout, wake_watcher)

            def cancel_timeout(_):
                if hasattr(timeout_handle, 'cancel'):
                    timeout_handle.cancel()
                else:
                    loop.remove_timeout(timeout_handle)

            result_future.add_done_callback(cancel_timeout)

        def cancel_watcher(_):
            if not watcher.done():
                watcher.cancel()

        result_future.add_done_callback(cancel_watcher)

        return result_future
187class _NoTimer:
188 @staticmethod
189 def cancel():
190 pass
193T = TypeVar("T", bound="_AsyncSocket")
class _AsyncSocket(_Async, _zmq.Socket[Future]):
    """Socket whose send, recv, and poll methods return Futures.

    Blocking calls are queued as ``_FutureEvent`` entries and carried out on a
    plain "shadow" socket when the event loop reports the socket as ready.
    """

    # Warning : these class variables are only here to allow to call super().__setattr__.
    # They be overridden at instance initialization and not shared in the whole class
    _recv_futures = None
    _send_futures = None
    _state = 0
    _shadow_sock: "_zmq.Socket"
    _poller_class = _AsyncPoller
    _fd = None

    def __init__(
        self,
        context=None,
        socket_type=-1,
        io_loop=None,
        _from_socket: Optional["_zmq.Socket"] = None,
        **kwargs,
    ) -> None:
        """Create a new async socket, or wrap an existing socket.

        ``context`` may be an existing ``zmq.Socket``, in which case it is
        treated as ``_from_socket``. ``io_loop`` is deprecated and only
        triggers a DeprecationWarning.
        """
        if isinstance(context, _zmq.Socket):
            context, _from_socket = (None, context)
        if _from_socket is not None:
            super().__init__(shadow=_from_socket.underlying)  # type: ignore
            self._shadow_sock = _from_socket
        else:
            super().__init__(context, socket_type, **kwargs)  # type: ignore
            self._shadow_sock = _zmq.Socket.shadow(self.underlying)

        if io_loop is not None:
            warnings.warn(
                f"{self.__class__.__name__}(io_loop) argument is deprecated in pyzmq 22.2."
                " The currently active loop will always be used.",
                DeprecationWarning,
                stacklevel=3,
            )
        self._recv_futures = deque()
        self._send_futures = deque()
        self._state = 0
        self._fd = self._shadow_sock.FD

    @classmethod
    def from_socket(cls: Type[T], socket: "_zmq.Socket", io_loop: Any = None) -> T:
        """Create an async socket from an existing Socket"""
        return cls(_from_socket=socket, io_loop=io_loop)

    # docstring is copied from zmq.Socket.close below
    def close(self, linger: Optional[int] = None) -> None:
        if not self.closed and self._fd is not None:
            # cancel everything still waiting on this socket before
            # unregistering it from the event loop
            event_list: List[_FutureEvent] = list(
                chain(self._recv_futures or [], self._send_futures or [])
            )
            for event in event_list:
                if not event.future.done():
                    try:
                        event.future.cancel()
                    except RuntimeError:
                        # RuntimeError may be called during teardown
                        pass
            self._clear_io_state()
        super().close(linger=linger)

    close.__doc__ = _zmq.Socket.close.__doc__

    # docstring is copied from zmq.Socket.get below
    def get(self, key):
        result = super().get(key)
        if key == EVENTS:
            # reading EVENTS may have consumed the edge trigger:
            # reschedule handling of whatever is still pending
            self._schedule_remaining_events(result)
        return result

    get.__doc__ = _zmq.Socket.get.__doc__

    @overload  # type: ignore
    def recv_multipart(
        self, flags: int = 0, *, track: bool = False
    ) -> Awaitable[List[bytes]]:
        ...

    @overload
    def recv_multipart(
        self, flags: int = 0, *, copy: Literal[True], track: bool = False
    ) -> Awaitable[List[bytes]]:
        ...

    @overload
    def recv_multipart(
        self, flags: int = 0, *, copy: Literal[False], track: bool = False
    ) -> Awaitable[List[_zmq.Frame]]:  # type: ignore
        ...

    @overload
    def recv_multipart(
        self, flags: int = 0, copy: bool = True, track: bool = False
    ) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
        ...

    def recv_multipart(
        self, flags: int = 0, copy: bool = True, track: bool = False
    ) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
        """Receive a complete multipart zmq message.

        Returns a Future whose result will be a multipart message.
        """
        return self._add_recv_event(
            'recv_multipart', dict(flags=flags, copy=copy, track=track)
        )

    def recv(  # type: ignore
        self, flags: int = 0, copy: bool = True, track: bool = False
    ) -> Awaitable[Union[bytes, _zmq.Frame]]:
        """Receive a single zmq frame.

        Returns a Future, whose result will be the received frame.

        Recommend using recv_multipart instead.
        """
        return self._add_recv_event('recv', dict(flags=flags, copy=copy, track=track))

    def send_multipart(  # type: ignore
        self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs
    ) -> Awaitable[Optional[_zmq.MessageTracker]]:
        """Send a complete multipart zmq message.

        Returns a Future that resolves when sending is complete.
        """
        kwargs['flags'] = flags
        kwargs['copy'] = copy
        kwargs['track'] = track
        return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs)

    def send(  # type: ignore
        self,
        data: Any,
        flags: int = 0,
        copy: bool = True,
        track: bool = False,
        **kwargs: Any,
    ) -> Awaitable[Optional[_zmq.MessageTracker]]:
        """Send a single zmq frame.

        Returns a Future that resolves when sending is complete.

        Recommend using send_multipart instead.
        """
        kwargs['flags'] = flags
        kwargs['copy'] = copy
        kwargs['track'] = track
        return self._add_send_event('send', msg=data, kwargs=kwargs)

    def _deserialize(self, recvd, load):
        """Deserialize with Futures

        Returns a new Future that resolves to ``load(recvd.result())``.
        Errors and cancellation of ``recvd`` are forwarded to the returned
        Future, and cancelling the returned Future cancels ``recvd``.
        """
        f = self._Future()

        def _chain(_):
            """Chain result through serialization to recvd"""
            if f.done():
                return
            if recvd.cancelled():
                # exception() and result() raise CancelledError on a cancelled
                # Future; forward the cancellation so f does not stay pending
                try:
                    f.cancel()
                except RuntimeError:
                    # RuntimeError may be called during teardown
                    pass
                return
            if recvd.exception():
                f.set_exception(recvd.exception())
            else:
                buf = recvd.result()
                try:
                    loaded = load(buf)
                except Exception as e:
                    f.set_exception(e)
                else:
                    f.set_result(loaded)

        recvd.add_done_callback(_chain)

        def _chain_cancel(_):
            """Chain cancellation from f to recvd"""
            if recvd.done():
                return
            if f.cancelled():
                recvd.cancel()

        f.add_done_callback(_chain_cancel)

        return f

    def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]:  # type: ignore
        """poll the socket for events

        returns a Future for the poll results.
        """

        if self.closed:
            raise _zmq.ZMQError(_zmq.ENOTSUP)

        p = self._poller_class()
        p.register(self, flags)
        f = cast(Future, p.poll(timeout))

        future = self._Future()

        def unwrap_result(f):
            if future.done():
                return
            if f.cancelled():
                try:
                    future.cancel()
                except RuntimeError:
                    # RuntimeError may be called during teardown
                    pass
                return
            if f.exception():
                future.set_exception(f.exception())
            else:
                # the poller yields (socket, events) pairs; report only the
                # event mask for this socket, 0 if it is not listed
                evts = dict(f.result())
                future.set_result(evts.get(self, 0))

        if f.done():
            # the poller's Future is already resolved (e.g. timeout=0):
            # unwrap the result right away instead of waiting for a callback
            unwrap_result(f)
        else:
            f.add_done_callback(unwrap_result)
        return future

    # overrides only necessary for updated types
    def recv_string(self, *args, **kwargs) -> Awaitable[str]:  # type: ignore
        return super().recv_string(*args, **kwargs)  # type: ignore

    def send_string(self, s: str, flags: int = 0, encoding: str = 'utf-8') -> Awaitable[None]:  # type: ignore
        return super().send_string(s, flags=flags, encoding=encoding)  # type: ignore

    def _add_timeout(self, future, timeout):
        """Add a timeout for a send or recv Future

        Returns the handle produced by ``_call_later``.
        """

        def future_timeout():
            if future.done():
                # future already resolved, do nothing
                return

            # raise EAGAIN
            future.set_exception(_zmq.Again())

        return self._call_later(timeout, future_timeout)

    def _call_later(self, delay, callback):
        """Schedule a function to be called later

        Override for different IOLoop implementations

        Tornado and asyncio happen to both have ioloop.call_later
        with the same signature.
        """
        return self._get_loop().call_later(delay, callback)

    @staticmethod
    def _remove_finished_future(future, event_list):
        """Make sure that futures are removed from the event list when they resolve

        Avoids delaying cleanup until the next send/recv event,
        which may never come.
        """
        # "future" instance is shared between sockets, but each socket has its own event list.
        for f_idx, event in enumerate(event_list):
            if event.future is future:
                # the entry was matched by identity, so delete it by position
                # rather than rescanning the list by equality with remove()
                del event_list[f_idx]
                return

    def _add_recv_event(self, kind, kwargs=None, future=None):
        """Add a recv event, returning the corresponding Future"""
        f = future or self._Future()
        if kind.startswith('recv') and kwargs.get('flags', 0) & _zmq.DONTWAIT:
            # short-circuit non-blocking calls
            recv = getattr(self._shadow_sock, kind)
            try:
                r = recv(**kwargs)
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(r)
            return f

        timer = _NoTimer
        if hasattr(_zmq, 'RCVTIMEO'):
            timeout_ms = self._shadow_sock.rcvtimeo
            if timeout_ms >= 0:
                timer = self._add_timeout(f, timeout_ms * 1e-3)

        # queue the event together with its timer; when the timeout fires the
        # Future resolves and the done-callback below removes the entry again
        self._recv_futures.append(_FutureEvent(f, kind, kwargs, msg=None, timer=timer))

        # Don't let the Future sit in _recv_futures after it's done
        f.add_done_callback(
            lambda f: self._remove_finished_future(f, self._recv_futures)
        )

        if self._shadow_sock.get(EVENTS) & POLLIN:
            # recv immediately, if we can
            self._handle_recv()
        if self._recv_futures:
            self._add_io_state(POLLIN)
        return f

    def _add_send_event(self, kind, msg=None, kwargs=None, future=None):
        """Add a send event, returning the corresponding Future"""
        f = future or self._Future()
        # attempt send with DONTWAIT if no futures are waiting
        # short-circuit for sends that will resolve immediately
        # only call if no send Futures are waiting
        if kind in ('send', 'send_multipart') and not self._send_futures:
            flags = kwargs.get('flags', 0)
            nowait_kwargs = kwargs.copy()
            nowait_kwargs['flags'] = flags | _zmq.DONTWAIT

            # short-circuit non-blocking calls
            send = getattr(self._shadow_sock, kind)
            # track if the send resolved or not
            # (EAGAIN without DONTWAIT requested falls through to the async send)
            finish_early = True
            try:
                r = send(msg, **nowait_kwargs)
            except _zmq.Again as e:
                if flags & _zmq.DONTWAIT:
                    f.set_exception(e)
                else:
                    # EAGAIN raised and DONTWAIT not requested,
                    # proceed with async send
                    finish_early = False
            except Exception as e:
                f.set_exception(e)
            else:
                f.set_result(r)

            if finish_early:
                # short-circuit resolved, return finished Future
                # schedule wake for recv if there are any receivers waiting
                if self._recv_futures:
                    self._schedule_remaining_events()
                return f

        timer = _NoTimer
        if hasattr(_zmq, 'SNDTIMEO'):
            timeout_ms = self._shadow_sock.get(_zmq.SNDTIMEO)
            if timeout_ms >= 0:
                timer = self._add_timeout(f, timeout_ms * 1e-3)

        # queue the event together with its timer; when the timeout fires the
        # Future resolves and the done-callback below removes the entry again
        self._send_futures.append(
            _FutureEvent(f, kind, kwargs=kwargs, msg=msg, timer=timer)
        )
        # Don't let the Future sit in _send_futures after it's done
        f.add_done_callback(
            lambda f: self._remove_finished_future(f, self._send_futures)
        )

        self._add_io_state(POLLOUT)
        return f

    def _handle_recv(self):
        """Handle recv events"""
        if not self._shadow_sock.get(EVENTS) & POLLIN:
            # event triggered, but state may have been changed between trigger and callback
            return
        f = None
        while self._recv_futures:
            f, kind, kwargs, _, timer = self._recv_futures.popleft()
            # skip any cancelled futures
            if f.done():
                f = None
            else:
                break

        if not self._recv_futures:
            self._drop_io_state(POLLIN)

        if f is None:
            return

        timer.cancel()

        if kind == 'poll':
            # on poll event, just signal ready, nothing else.
            f.set_result(None)
            return
        elif kind == 'recv_multipart':
            recv = self._shadow_sock.recv_multipart
        elif kind == 'recv':
            recv = self._shadow_sock.recv
        else:
            raise ValueError("Unhandled recv event type: %r" % kind)

        # the socket reported readable, so the call must not block
        kwargs['flags'] |= _zmq.DONTWAIT
        try:
            result = recv(**kwargs)
        except Exception as e:
            f.set_exception(e)
        else:
            f.set_result(result)

    def _handle_send(self):
        """Handle send events"""
        if not self._shadow_sock.get(EVENTS) & POLLOUT:
            # event triggered, but state may have been changed between trigger and callback
            return
        f = None
        while self._send_futures:
            f, kind, kwargs, msg, timer = self._send_futures.popleft()
            # skip any cancelled futures
            if f.done():
                f = None
            else:
                break

        if not self._send_futures:
            self._drop_io_state(POLLOUT)

        if f is None:
            return

        timer.cancel()

        if kind == 'poll':
            # on poll event, just signal ready, nothing else.
            f.set_result(None)
            return
        elif kind == 'send_multipart':
            send = self._shadow_sock.send_multipart
        elif kind == 'send':
            send = self._shadow_sock.send
        else:
            raise ValueError("Unhandled send event type: %r" % kind)

        # the socket reported writable, so the call must not block
        kwargs['flags'] |= _zmq.DONTWAIT
        try:
            result = send(msg, **kwargs)
        except Exception as e:
            f.set_exception(e)
        else:
            f.set_result(result)

    # event masking from ZMQStream
    def _handle_events(self, fd=0, events=0):
        """Dispatch IO events to _handle_recv, etc."""
        if self._shadow_sock.closed:
            return

        zmq_events = self._shadow_sock.get(EVENTS)
        if zmq_events & _zmq.POLLIN:
            self._handle_recv()
        if zmq_events & _zmq.POLLOUT:
            self._handle_send()
        self._schedule_remaining_events()

    def _schedule_remaining_events(self, events=None):
        """Schedule a call to handle_events next loop iteration

        If there are still events to handle.
        """
        # edge-triggered handling
        # allow passing events in, in case this is triggered by retrieving events,
        # so we don't have to retrieve it twice.
        if self._state == 0:
            # not watching for anything, nothing to schedule
            return
        if events is None:
            events = self._shadow_sock.get(EVENTS)
        if events & self._state:
            self._call_later(0, self._handle_events)

    def _add_io_state(self, state):
        """Add io_state to poller."""
        if self._state != state:
            state = self._state = self._state | state
        self._update_handler(self._state)

    def _drop_io_state(self, state):
        """Stop poller from watching an io_state."""
        if self._state & state:
            self._state = self._state & (~state)
        self._update_handler(self._state)

    def _update_handler(self, state):
        """Update IOLoop handler with state.

        zmq FD is always read-only.
        """
        # ensure loop is registered and init_io has been called
        # if there are any events to watch for
        if state:
            self._get_loop()
        self._schedule_remaining_events()

    def _init_io_state(self, loop=None):
        """initialize the ioloop event handler"""
        if loop is None:
            loop = self._get_loop()
        loop.add_handler(self._shadow_sock, self._handle_events, self._READ)
        self._call_later(0, self._handle_events)

    def _clear_io_state(self):
        """unregister the ioloop event handler

        called once during close
        """
        fd = self._shadow_sock
        if self._shadow_sock.closed:
            # fall back to the FD value saved in __init__
            fd = self._fd
        if self._current_loop is not None:
            self._current_loop.remove_handler(fd)