Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/zmq/eventloop/zmqstream.py: 26%
284 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-01 06:54 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-01 06:54 +0000
1#
2# Copyright 2009 Facebook
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may
5# not use this file except in compliance with the License. You may obtain
6# a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations
14# under the License.
16"""A utility class for event-based messaging on a zmq socket using tornado.
18.. seealso::
20 - :mod:`zmq.asyncio`
21 - :mod:`zmq.eventloop.future`
22"""
24import asyncio
25import pickle
26import warnings
27from queue import Queue
28from typing import (
29 Any,
30 Awaitable,
31 Callable,
32 List,
33 Optional,
34 Sequence,
35 Union,
36 cast,
37 overload,
38)
40from tornado.ioloop import IOLoop
41from tornado.log import gen_log
43import zmq
44import zmq._future
45from zmq import POLLIN, POLLOUT
46from zmq._typing import Literal
47from zmq.utils import jsonapi
class ZMQStream:
    """A utility class to register callbacks when a zmq socket sends and receives

    For use with tornado IOLoop.

    There are three main methods

    Methods:

    * **on_recv(callback, copy=True):**
        register a callback to be run every time the socket has something to receive
    * **on_send(callback):**
        register a callback to be run every time you call send
    * **send_multipart(self, msg, flags=0, copy=True, callback=None):**
        perform a send that will trigger the callback
        if callback is passed, on_send is also called.

    There are also send(), send_string(), send_json() and send_pyobj().

    Two other methods for deactivating the callbacks:

    * **stop_on_recv():**
        turn off the recv callback
    * **stop_on_send():**
        turn off the send callback

    which simply call ``on_<evt>(None)``.

    Several socket methods (bind, bind_to_random_port, connect and the
    setsockopt/getsockopt family) are also provided by direct-linking the
    underlying socket's methods, e.g.

    >>> stream.bind is stream.socket.bind
    True


    .. versionadded:: 25

        send/recv callbacks can be coroutines.

    .. versionchanged:: 25

        ZMQStreams only support base zmq.Socket classes (this has always been true, but not enforced).
        If ZMQStreams are created with e.g. async Socket subclasses,
        a RuntimeWarning will be shown,
        and the socket cast back to the default zmq.Socket
        before connecting events.

        Previously, using async sockets (or any zmq.Socket subclass) would result in undefined behavior for the
        arguments passed to callback functions.
        Now, the callback functions reliably get the return value of the base `zmq.Socket` send/recv_multipart methods
        (the list of message frames).
    """

    # The wrapped socket; set to None by close().
    socket: zmq.Socket
    # Loop the socket is registered with (see _init_io_state).
    io_loop: IOLoop
    # Private poller used only by flush() for non-blocking polls.
    poller: zmq.Poller
    # Pending outgoing (msg, send_multipart kwargs) pairs.
    _send_queue: Queue
    _recv_callback: Optional[Callable]
    _send_callback: Optional[Callable]
    _close_callback: Optional[Callable]
    # Bitmask of POLLIN/POLLOUT the stream is currently interested in.
    _state: int = 0
    # True while flush() has already handled events this loop iteration;
    # _handle_recv/_handle_send return early while it is set.
    _flushed: bool = False
    # `copy` argument passed to recv_multipart.
    _recv_copy: bool = False
    # Raw FD of the socket, kept so close() can unregister after the socket
    # itself was closed behind our back.
    _fd: int

    def __init__(self, socket: "zmq.Socket", io_loop: Optional[IOLoop] = None):
        """Wrap `socket` and register it with `io_loop`.

        Parameters
        ----------
        socket : zmq.Socket
            The socket to wrap. Async socket subclasses trigger a
            RuntimeWarning and are re-wrapped as a base ``zmq.Socket``
            shadowing the same underlying socket.
        io_loop : IOLoop, optional
            The loop to register with; defaults to ``IOLoop.current()``.
        """
        if isinstance(socket, zmq._future._AsyncSocket):
            warnings.warn(
                f"""ZMQStream only supports the base zmq.Socket class.

                Use zmq.Socket(shadow=other_socket)
                or `ctx.socket(zmq.{socket._type_name}, socket_class=zmq.Socket)`
                to create a base zmq.Socket object,
                no matter what other kind of socket your Context creates.
                """,
                RuntimeWarning,
                stacklevel=2,
            )
            # shadow back to base zmq.Socket,
            # otherwise callbacks like `on_recv` will get the wrong types.
            socket = zmq.Socket(shadow=socket)
        self.socket = socket

        # IOLoop.current() is deprecated if called outside the event loop,
        # so pass `io_loop` explicitly when constructing a stream from
        # outside a running loop.
        self.io_loop = io_loop or IOLoop.current()
        self.poller = zmq.Poller()
        self._fd = cast(int, self.socket.FD)

        self._send_queue = Queue()
        self._recv_callback = None
        self._send_callback = None
        self._close_callback = None
        self._recv_copy = False
        self._flushed = False

        self._state = 0
        self._init_io_state()

        # shortcircuit some socket methods
        self.bind = self.socket.bind
        self.bind_to_random_port = self.socket.bind_to_random_port
        self.connect = self.socket.connect
        self.setsockopt = self.socket.setsockopt
        self.getsockopt = self.socket.getsockopt
        self.setsockopt_string = self.socket.setsockopt_string
        self.getsockopt_string = self.socket.getsockopt_string
        self.setsockopt_unicode = self.socket.setsockopt_unicode
        self.getsockopt_unicode = self.socket.getsockopt_unicode

    def stop_on_recv(self):
        """Disable callback and automatic receiving."""
        return self.on_recv(None)

    def stop_on_send(self):
        """Disable callback on sending."""
        return self.on_send(None)

    def stop_on_err(self):
        """DEPRECATED, does nothing"""
        # Logger.warn is a deprecated alias of Logger.warning.
        gen_log.warning("on_err does nothing, and will be removed")

    def on_err(self, callback: Callable):
        """DEPRECATED, does nothing"""
        gen_log.warning("on_err does nothing, and will be removed")

    @overload
    def on_recv(
        self,
        callback: Callable[[List[bytes]], Any],
    ) -> None:
        ...

    @overload
    def on_recv(
        self,
        callback: Callable[[List[bytes]], Any],
        copy: Literal[True],
    ) -> None:
        ...

    @overload
    def on_recv(
        self,
        callback: Callable[[List[zmq.Frame]], Any],
        copy: Literal[False],
    ) -> None:
        ...

    @overload
    def on_recv(
        self,
        callback: Optional[
            Union[
                Callable[[List[zmq.Frame]], Any],
                Callable[[List[bytes]], Any],
            ]
        ],
        copy: bool = ...,
    ):
        ...

    def on_recv(
        self,
        callback: Optional[
            Union[
                Callable[[List[zmq.Frame]], Any],
                Callable[[List[bytes]], Any],
            ]
        ],
        copy: bool = True,
    ) -> None:
        """Register a callback for when a message is ready to recv.

        There can be only one callback registered at a time, so each
        call to `on_recv` replaces previously registered callbacks.

        on_recv(None) disables recv event polling.

        Use on_recv_stream(callback) instead, to register a callback that will receive
        both this ZMQStream and the message, instead of just the message.

        Parameters
        ----------

        callback : callable
            callback must take exactly one argument, which will be a
            list, as returned by socket.recv_multipart()
            if callback is None, recv callbacks are disabled.
        copy : bool
            copy is passed directly to recv, so if copy is False,
            callback will receive zmq.Frame objects. If copy is True,
            then callback will receive bytes objects.

        Returns : None

        Raises
        ------
        OSError
            If the stream is already closed.
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._recv_callback = callback
        self._recv_copy = copy
        if callback is None:
            self._drop_io_state(zmq.POLLIN)
        else:
            self._add_io_state(zmq.POLLIN)

    @overload
    def on_recv_stream(
        self,
        callback: Callable[["ZMQStream", List[bytes]], Any],
    ) -> None:
        ...

    @overload
    def on_recv_stream(
        self,
        callback: Callable[["ZMQStream", List[bytes]], Any],
        copy: Literal[True],
    ) -> None:
        ...

    @overload
    def on_recv_stream(
        self,
        callback: Callable[["ZMQStream", List[zmq.Frame]], Any],
        copy: Literal[False],
    ) -> None:
        ...

    @overload
    def on_recv_stream(
        self,
        callback: Optional[
            Union[
                Callable[["ZMQStream", List[zmq.Frame]], Any],
                Callable[["ZMQStream", List[bytes]], Any],
            ]
        ],
        copy: bool = ...,
    ):
        ...

    def on_recv_stream(
        self,
        callback: Optional[
            Union[
                Callable[["ZMQStream", List[zmq.Frame]], Any],
                Callable[["ZMQStream", List[bytes]], Any],
            ]
        ],
        copy: bool = True,
    ):
        """Same as on_recv, but callback will get this stream as first argument

        callback must take exactly two arguments, as it will be called as::

            callback(stream, msg)

        Useful when a single callback should be used with multiple streams.
        Passing None disables recv callbacks, like ``on_recv(None)``.
        """
        if callback is None:
            self.stop_on_recv()
        else:

            def stream_callback(msg):
                return callback(self, msg)

            self.on_recv(stream_callback, copy=copy)

    def on_send(
        self,
        callback: Optional[
            Callable[[Sequence[Any], Optional[zmq.MessageTracker]], Any]
        ],
    ):
        """Register a callback to be called on each send

        There will be two arguments::

            callback(msg, status)

        * `msg` will be the list of sendable objects that was just sent
        * `status` will be the return result of socket.send_multipart(msg) -
          MessageTracker or None. If the send raised a zmq.ZMQError,
          `status` is that exception instead.

        Non-copying sends return a MessageTracker object whose
        `done` attribute will be True when the send is complete.
        This allows users to track when an object is safe to write to
        again.

        The second argument will always be None if copy=True
        on the send.

        Use on_send_stream(callback) to register a callback that will be passed
        this ZMQStream as the first argument, in addition to the other two.

        on_send(None) disables send callbacks.

        Parameters
        ----------

        callback : callable
            callback must take exactly two arguments, which will be
            the message being sent (always a list),
            and the return result of socket.send_multipart(msg) -
            MessageTracker or None.

            if callback is None, send callbacks are disabled.

        Raises
        ------
        OSError
            If the stream is already closed.
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._send_callback = callback

    def on_send_stream(
        self,
        callback: Optional[
            Callable[["ZMQStream", Sequence[Any], Optional[zmq.MessageTracker]], Any]
        ],
    ):
        """Same as on_send, but callback will get this stream as first argument

        Callback will be passed three arguments::

            callback(stream, msg, status)

        Useful when a single callback should be used with multiple streams.
        Passing None disables send callbacks, like ``on_send(None)``.
        """
        if callback is None:
            self.stop_on_send()
        else:
            self.on_send(lambda msg, status: callback(self, msg, status))

    def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs):
        """Send a message, optionally also register a new callback for sends.
        See zmq.socket.send for details.
        """
        return self.send_multipart(
            [msg], flags=flags, copy=copy, track=track, callback=callback, **kwargs
        )

    def send_multipart(
        self,
        msg: Sequence[Any],
        flags: int = 0,
        copy: bool = True,
        track: bool = False,
        callback: Optional[Callable] = None,
        **kwargs: Any,
    ) -> None:
        """Send a multipart message, optionally also register a new callback for sends.
        See zmq.socket.send_multipart for details.

        The message is queued and actually sent from the IOLoop once the
        socket is writable. If `callback` is given it replaces the current
        send callback (see on_send).
        """
        kwargs.update(dict(flags=flags, copy=copy, track=track))
        self._send_queue.put((msg, kwargs))
        callback = callback or self._send_callback
        if callback is not None:
            self.on_send(callback)
        else:
            # noop callback
            self.on_send(lambda *args: None)
        self._add_io_state(zmq.POLLOUT)

    def send_string(
        self,
        u: str,
        flags: int = 0,
        encoding: str = 'utf-8',
        callback: Optional[Callable] = None,
        **kwargs: Any,
    ):
        """Send a unicode message with an encoding.
        See zmq.socket.send_unicode for details.

        Raises
        ------
        TypeError
            If `u` is not a str.
        """
        if not isinstance(u, str):
            raise TypeError("unicode/str objects only")
        return self.send(u.encode(encoding), flags=flags, callback=callback, **kwargs)

    send_unicode = send_string

    def send_json(
        self,
        obj: Any,
        flags: int = 0,
        callback: Optional[Callable] = None,
        **kwargs: Any,
    ):
        """Send json-serialized version of an object.
        See zmq.socket.send_json for details.
        """
        msg = jsonapi.dumps(obj)
        return self.send(msg, flags=flags, callback=callback, **kwargs)

    def send_pyobj(
        self,
        obj: Any,
        flags: int = 0,
        protocol: int = -1,
        callback: Optional[Callable] = None,
        **kwargs: Any,
    ):
        """Send a Python object as a message using pickle to serialize.

        See zmq.socket.send_pyobj for details.
        """
        msg = pickle.dumps(obj, protocol)
        return self.send(msg, flags=flags, callback=callback, **kwargs)

    def _finish_flush(self):
        """callback for unsetting _flushed flag."""
        self._flushed = False

    def flush(self, flag: int = zmq.POLLIN | zmq.POLLOUT, limit: Optional[int] = None):
        """Flush pending messages.

        This method safely handles all pending incoming and/or outgoing messages,
        bypassing the inner loop, passing them to the registered callbacks.

        A limit can be specified, to prevent blocking under high load.

        flush will return the first time ANY of these conditions are met:
            * No more events matching the flag are pending.
            * the total number of events handled reaches the limit.

        Note that if ``flag & POLLIN != 0``, recv events will be flushed even if no callback
        is registered, unlike normal IOLoop operation. This allows flush to be
        used to remove *and ignore* incoming messages.

        Parameters
        ----------
        flag : int, default=POLLIN|POLLOUT
            0MQ poll flags.
            If flag & POLLIN, recv events will be flushed.
            If flag & POLLOUT, send events will be flushed.
            Both flags can be set at once, which is the default.
        limit : None or int, optional
            The maximum number of messages to send or receive.
            Both send and recv count against this limit.
            A falsy limit (None or 0) means no limit.

        Returns
        -------
        int : count of events handled (both send and recv)

        Raises
        ------
        OSError
            If the stream is already closed.
        """
        self._check_closed()
        # unset self._flushed, so callbacks will execute, in case flush has
        # already been called this iteration
        already_flushed = self._flushed
        self._flushed = False
        # initialize counters
        count = 0

        def update_flag():
            """Update the poll flag, to prevent registering POLLOUT events
            if we don't have pending sends."""
            # Reads the *current* value of the enclosing `flag`, which is
            # narrowed each time it is reassigned below.
            return (flag & zmq.POLLIN) | (self.sending() and flag & zmq.POLLOUT)

        flag = update_flag()
        if not flag:
            # nothing to do
            return 0
        self.poller.register(self.socket, flag)
        events = self.poller.poll(0)
        while events and (not limit or count < limit):
            _, event = events[0]
            if event & POLLIN:  # receiving
                self._handle_recv()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break
            if event & POLLOUT and self.sending():
                self._handle_send()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break

            flag = update_flag()
            if flag:
                self.poller.register(self.socket, flag)
                events = self.poller.poll(0)
            else:
                events = []
        if count:  # only bypass loop if we actually flushed something
            # skip send/recv callbacks this iteration
            self._flushed = True
            # reregister them at the end of the loop
            if not already_flushed:  # don't need to do it again
                self.io_loop.add_callback(self._finish_flush)
        elif already_flushed:
            self._flushed = True

        # update ioloop poll state, which may have changed
        self._rebuild_io_state()
        return count

    def set_close_callback(self, callback: Optional[Callable]):
        """Call the given callback when the stream is closed."""
        self._close_callback = callback

    def close(self, linger: Optional[int] = None) -> None:
        """Close this stream.

        Unregisters the socket from the IOLoop, closes it with `linger`
        (unless it was already closed elsewhere), and then runs the close
        callback, if any. Calling close() on a closed stream does nothing.
        """
        if self.socket is not None:
            if self.socket.closed:
                # fallback on raw fd for closed sockets
                # hopefully this happened promptly after close,
                # otherwise somebody else may have the FD
                warnings.warn(
                    "Unregistering FD %s after closing socket. "
                    "This could result in unregistering handlers for the wrong socket. "
                    "Please use stream.close() instead of closing the socket directly."
                    % self._fd,
                    stacklevel=2,
                )
                self.io_loop.remove_handler(self._fd)
            else:
                self.io_loop.remove_handler(self.socket)
                self.socket.close(linger)
            self.socket = None  # type: ignore
            if self._close_callback:
                self._run_callback(self._close_callback)

    def receiving(self) -> bool:
        """Returns True if we are currently receiving from the stream."""
        return self._recv_callback is not None

    def sending(self) -> bool:
        """Returns True if we are currently sending to the stream."""
        return not self._send_queue.empty()

    def closed(self) -> bool:
        """Return True if this stream is closed.

        If the underlying socket was closed without going through close(),
        the stream is cleaned up (via close()) before returning True.
        """
        if self.socket is None:
            return True
        if self.socket.closed:
            # underlying socket has been closed, but not by us!
            # trigger our cleanup
            self.close()
            return True
        return False

    def _run_callback(self, callback, *args, **kwargs):
        """Wrap running callbacks in try/except to allow us to
        close our socket.

        Synchronous exceptions are logged and re-raised. If the callback
        returns an awaitable, it is scheduled with ``asyncio.ensure_future``
        and any exception it raises later is logged.
        """
        try:
            f = callback(*args, **kwargs)
            if isinstance(f, Awaitable):
                f = asyncio.ensure_future(f)
            else:
                f = None
        except Exception:
            gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True)
            # Re-raise the exception so that IOLoop.handle_callback_exception
            # can see it and log the error
            raise

        if f is not None:
            # handle async callbacks
            def _log_error(f):
                if f.cancelled():
                    # Cancellation is deliberate, not an error. Since Python
                    # 3.8 CancelledError derives from BaseException, so
                    # f.result() would otherwise raise straight past the
                    # `except Exception` below.
                    return
                try:
                    f.result()
                except Exception:
                    gen_log.error(
                        "Uncaught exception in ZMQStream callback", exc_info=True
                    )

            f.add_done_callback(_log_error)

    def _handle_events(self, fd, events):
        """This method is the actual handler for IOLoop, that gets called whenever
        an event on my socket is posted. It dispatches to _handle_recv, etc.

        `fd` and `events` are ignored; the pending events are read from the
        socket's EVENTS option instead.
        """
        if not self.socket:
            gen_log.warning("Got events for closed stream %s", self)
            return
        try:
            zmq_events = self.socket.EVENTS
        except zmq.ContextTerminated:
            gen_log.warning("Got events for stream %s after terminating context", self)
            # trigger close check, this will unregister callbacks
            self.closed()
            return
        except zmq.ZMQError as e:
            # run close check
            # shadow sockets may have been closed elsewhere,
            # which should show up as ENOTSOCK here
            if self.closed():
                gen_log.warning(
                    "Got events for stream %s attached to closed socket: %s", self, e
                )
            else:
                gen_log.error("Error getting events for %s: %s", self, e)
            return
        try:
            # dispatch events:
            if zmq_events & zmq.POLLIN and self.receiving():
                self._handle_recv()
                if not self.socket:
                    return
            if zmq_events & zmq.POLLOUT and self.sending():
                self._handle_send()
                if not self.socket:
                    return

            # rebuild the poll state
            self._rebuild_io_state()
        except Exception:
            gen_log.error("Uncaught exception in zmqstream callback", exc_info=True)
            raise

    def _handle_recv(self):
        """Handle a recv event."""
        if self._flushed:
            return
        try:
            msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # state changed since poll event
                pass
            else:
                raise
        else:
            if self._recv_callback:
                callback = self._recv_callback
                self._run_callback(callback, msg)

    def _handle_send(self):
        """Handle a send event."""
        if self._flushed:
            return
        if not self.sending():
            gen_log.error("Shouldn't have handled a send event")
            return

        msg, kwargs = self._send_queue.get()
        try:
            status = self.socket.send_multipart(msg, **kwargs)
        except zmq.ZMQError as e:
            gen_log.error("SEND Error: %s", e)
            # the send callback receives the exception as its status
            status = e
        if self._send_callback:
            callback = self._send_callback
            self._run_callback(callback, msg, status)

    def _check_closed(self):
        """Raise OSError if this stream has been closed."""
        if not self.socket:
            raise OSError("Stream is closed")

    def _rebuild_io_state(self):
        """rebuild io state based on self.sending() and receiving()"""
        if self.socket is None:
            return
        state = 0
        if self.receiving():
            state |= zmq.POLLIN
        if self.sending():
            state |= zmq.POLLOUT

        self._state = state
        self._update_handler(state)

    def _add_io_state(self, state):
        """Add io_state to poller."""
        self._state = self._state | state
        self._update_handler(self._state)

    def _drop_io_state(self, state):
        """Stop poller from watching an io_state."""
        self._state = self._state & (~state)
        self._update_handler(self._state)

    def _update_handler(self, state):
        """Update IOLoop handler with state.

        The IOLoop registration itself is made once in _init_io_state; this
        only schedules _handle_events when events in `state` are already
        pending on the socket.
        """
        if self.socket is None:
            return

        if state & self.socket.events:
            # events still exist that haven't been processed
            # explicitly schedule handling to avoid missing events due to edge-triggered FDs
            self.io_loop.add_callback(lambda: self._handle_events(self.socket, 0))

    def _init_io_state(self):
        """initialize the ioloop event handler"""
        self.io_loop.add_handler(self.socket, self._handle_events, self.io_loop.READ)