1# Derived from iostream.py from tornado 1.0, Copyright 2009 Facebook
2# Used under Apache License Version 2.0
3#
4# Modifications are Copyright (C) PyZMQ Developers
5# Distributed under the terms of the Modified BSD License.
6"""A utility class for event-based messaging on a zmq socket using tornado.
7
8.. seealso::
9
10 - :mod:`zmq.asyncio`
11 - :mod:`zmq.eventloop.future`
12"""
13
14from __future__ import annotations
15
16import asyncio
17import pickle
18import warnings
19from queue import Queue
20from typing import Any, Awaitable, Callable, Literal, Sequence, cast, overload
21
22from tornado.ioloop import IOLoop
23from tornado.log import gen_log
24
25import zmq
26import zmq._future
27from zmq import POLLIN, POLLOUT
28from zmq.utils import jsonapi
29
30
class ZMQStream:
    """A utility class to register callbacks when a zmq socket sends and receives

    For use with tornado IOLoop.

    There are three main methods

    Methods:

    * **on_recv(callback, copy=True):**
        register a callback to be run every time the socket has something to receive
    * **on_send(callback):**
        register a callback to be run every time you call send
    * **send_multipart(msg, flags=0, copy=True, track=False, callback=None):**
        queue a send that will trigger the send callback;
        if callback is passed, it is registered via on_send,
        replacing any previously registered send callback.

    There are also send(), send_string(), send_json() and send_pyobj().

    Two other methods deactivate the callbacks:

    * **stop_on_recv():**
        turn off the recv callback
    * **stop_on_send():**
        turn off the send callback

    which simply call ``on_<evt>(None)``.

    A subset of the socket interface (bind, bind_to_random_port, connect,
    and the get/setsockopt family) is also provided by direct-linking
    the underlying socket's methods.
    e.g.

    >>> stream.bind is stream.socket.bind
    True


    .. versionadded:: 25

        send/recv callbacks can be coroutines.

    .. versionchanged:: 25

        ZMQStreams only support base zmq.Socket classes (this has always been true, but not enforced).
        If ZMQStreams are created with e.g. async Socket subclasses,
        a RuntimeWarning will be shown,
        and the socket cast back to the default zmq.Socket
        before connecting events.

        Previously, using async sockets (or any zmq.Socket subclass) would result in undefined behavior for the
        arguments passed to callback functions.
        Now, the callback functions reliably get the return value of the base `zmq.Socket` send/recv_multipart methods
        (the list of message frames).
    """

    socket: zmq.Socket
    io_loop: IOLoop
    poller: zmq.Poller
    _send_queue: Queue
    _recv_callback: Callable | None
    _send_callback: Callable | None
    _close_callback: Callable | None
    # bitmask of zmq.POLLIN / zmq.POLLOUT currently of interest
    _state: int = 0
    # True while flush() has already handled events for this loop iteration
    _flushed: bool = False
    _recv_copy: bool = False
    # raw FD, kept so close() can unregister even if the socket was closed elsewhere
    _fd: int

    def __init__(self, socket: zmq.Socket, io_loop: IOLoop | None = None):
        """Wrap ``socket`` and register it with an IOLoop.

        Parameters
        ----------
        socket : zmq.Socket
            The socket to wrap. Async socket subclasses trigger a
            RuntimeWarning and are shadowed back to a base ``zmq.Socket``.
        io_loop : IOLoop, optional
            The loop to register with. Defaults to ``IOLoop.current()``.
        """
        if isinstance(socket, zmq._future._AsyncSocket):
            warnings.warn(
                f"""ZMQStream only supports the base zmq.Socket class.

                Use zmq.Socket(shadow=other_socket)
                or `ctx.socket(zmq.{socket._type_name}, socket_class=zmq.Socket)`
                to create a base zmq.Socket object,
                no matter what other kind of socket your Context creates.
                """,
                RuntimeWarning,
                stacklevel=2,
            )
            # shadow back to base zmq.Socket,
            # otherwise callbacks like `on_recv` will get the wrong types.
            socket = zmq.Socket(shadow=socket)
        self.socket = socket

        # IOLoop.current() is deprecated if called outside a running event loop,
        # so pass io_loop explicitly when creating a stream outside one.
        self.io_loop = io_loop or IOLoop.current()
        self.poller = zmq.Poller()
        self._fd = cast(int, self.socket.FD)

        self._send_queue = Queue()
        self._recv_callback = None
        self._send_callback = None
        self._close_callback = None
        self._recv_copy = False
        self._flushed = False

        self._state = 0
        self._init_io_state()

        # shortcircuit some socket methods
        self.bind = self.socket.bind
        self.bind_to_random_port = self.socket.bind_to_random_port
        self.connect = self.socket.connect
        self.setsockopt = self.socket.setsockopt
        self.getsockopt = self.socket.getsockopt
        self.setsockopt_string = self.socket.setsockopt_string
        self.getsockopt_string = self.socket.getsockopt_string
        self.setsockopt_unicode = self.socket.setsockopt_unicode
        self.getsockopt_unicode = self.socket.getsockopt_unicode

    def stop_on_recv(self):
        """Disable callback and automatic receiving."""
        return self.on_recv(None)

    def stop_on_send(self):
        """Disable callback on sending."""
        return self.on_send(None)

    def stop_on_err(self):
        """DEPRECATED, does nothing"""
        # Logger.warn is a deprecated alias of Logger.warning
        gen_log.warning("on_err does nothing, and will be removed")

    def on_err(self, callback: Callable):
        """DEPRECATED, does nothing"""
        gen_log.warning("on_err does nothing, and will be removed")

    @overload
    def on_recv(
        self,
        callback: Callable[[list[bytes]], Any],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: Callable[[list[bytes]], Any],
        copy: Literal[True],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: Callable[[list[zmq.Frame]], Any],
        copy: Literal[False],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: Callable[[list[zmq.Frame]], Any] | Callable[[list[bytes]], Any],
        copy: bool = ...,
    ): ...

    def on_recv(
        self,
        callback: Callable[[list[zmq.Frame]], Any] | Callable[[list[bytes]], Any],
        copy: bool = True,
    ) -> None:
        """Register a callback for when a message is ready to recv.

        There can be only one callback registered at a time, so each
        call to `on_recv` replaces previously registered callbacks.

        on_recv(None) disables recv event polling.

        Use on_recv_stream(callback) instead, to register a callback that will receive
        both this ZMQStream and the message, instead of just the message.

        Parameters
        ----------

        callback : callable
            callback must take exactly one argument, which will be a
            list, as returned by socket.recv_multipart()
            if callback is None, recv callbacks are disabled.
        copy : bool
            copy is passed directly to recv, so if copy is False,
            callback will receive zmq.Frame objects. If copy is True,
            then callback will receive bytes objects.

        Raises
        ------
        OSError
            If the stream is closed.

        Returns : None
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._recv_callback = callback
        self._recv_copy = copy
        if callback is None:
            self._drop_io_state(zmq.POLLIN)
        else:
            self._add_io_state(zmq.POLLIN)

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[bytes]], Any],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[bytes]], Any],
        copy: Literal[True],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[zmq.Frame]], Any],
        copy: Literal[False],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: (
            Callable[[ZMQStream, list[zmq.Frame]], Any]
            | Callable[[ZMQStream, list[bytes]], Any]
        ),
        copy: bool = ...,
    ): ...

    def on_recv_stream(
        self,
        callback: (
            Callable[[ZMQStream, list[zmq.Frame]], Any]
            | Callable[[ZMQStream, list[bytes]], Any]
        ),
        copy: bool = True,
    ):
        """Same as on_recv, but callback will get this stream as first argument

        callback must take exactly two arguments, as it will be called as::

            callback(stream, msg)

        Useful when a single callback should be used with multiple streams.
        Passing None disables recv callbacks, like stop_on_recv().
        """
        if callback is None:
            self.stop_on_recv()
        else:

            def stream_callback(msg):
                return callback(self, msg)

            self.on_recv(stream_callback, copy=copy)

    def on_send(
        self, callback: Callable[[Sequence[Any], zmq.MessageTracker | None], Any]
    ):
        """Register a callback to be called on each send

        There will be two arguments::

            callback(msg, status)

        * `msg` will be the list of sendable objects that was just sent
        * `status` will be the return result of socket.send_multipart(msg) -
          MessageTracker or None.

        Non-copying sends return a MessageTracker object whose
        `done` attribute will be True when the send is complete.
        This allows users to track when an object is safe to write to
        again.

        The second argument will always be None if copy=True
        on the send. If the send raised a zmq.ZMQError,
        the exception object is passed as `status` instead.

        Use on_send_stream(callback) to register a callback that will be passed
        this ZMQStream as the first argument, in addition to the other two.

        on_send(None) disables send callbacks; it does not change event polling.

        Parameters
        ----------

        callback : callable
            callback must take exactly two arguments, which will be
            the message being sent (always a list),
            and the return result of socket.send_multipart(msg) -
            MessageTracker or None.

            if callback is None, send callbacks are disabled.

        Raises
        ------
        OSError
            If the stream is closed.
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._send_callback = callback

    def on_send_stream(
        self,
        callback: Callable[[ZMQStream, Sequence[Any], zmq.MessageTracker | None], Any],
    ):
        """Same as on_send, but callback will get this stream as first argument

        Callback will be passed three arguments::

            callback(stream, msg, status)

        Useful when a single callback should be used with multiple streams.
        Passing None disables send callbacks, like stop_on_send().
        """
        if callback is None:
            self.stop_on_send()
        else:
            self.on_send(lambda msg, status: callback(self, msg, status))

    def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs):
        """Send a message, optionally also register a new callback for sends.
        See zmq.socket.send for details.
        """
        return self.send_multipart(
            [msg], flags=flags, copy=copy, track=track, callback=callback, **kwargs
        )

    def send_multipart(
        self,
        msg: Sequence[Any],
        flags: int = 0,
        copy: bool = True,
        track: bool = False,
        callback: Callable | None = None,
        **kwargs: Any,
    ) -> None:
        """Send a multipart message, optionally also register a new callback for sends.
        See zmq.socket.send_multipart for details.

        The message is queued and actually sent when the socket is writable.
        If ``callback`` is given it replaces the registered send callback.

        Raises
        ------
        OSError
            If the stream is closed. Nothing is queued in that case.
        """
        # check before enqueuing, so a closed stream doesn't retain the message
        self._check_closed()
        kwargs.update(flags=flags, copy=copy, track=track)
        self._send_queue.put((msg, kwargs))
        callback = callback or self._send_callback
        if callback is not None:
            self.on_send(callback)
        else:
            # noop callback
            self.on_send(lambda *args: None)
        self._add_io_state(zmq.POLLOUT)

    def send_string(
        self,
        u: str,
        flags: int = 0,
        encoding: str = 'utf-8',
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send a unicode message with an encoding.
        See zmq.socket.send_unicode for details.

        Raises
        ------
        TypeError
            If ``u`` is not a str.
        """
        if not isinstance(u, str):
            raise TypeError("unicode/str objects only")
        return self.send(u.encode(encoding), flags=flags, callback=callback, **kwargs)

    send_unicode = send_string

    def send_json(
        self,
        obj: Any,
        flags: int = 0,
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send json-serialized version of an object.
        See zmq.socket.send_json for details.
        """
        msg = jsonapi.dumps(obj)
        return self.send(msg, flags=flags, callback=callback, **kwargs)

    def send_pyobj(
        self,
        obj: Any,
        flags: int = 0,
        protocol: int = -1,
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send a Python object as a message using pickle to serialize.

        See zmq.socket.send_pyobj for details.
        """
        msg = pickle.dumps(obj, protocol)
        return self.send(msg, flags, callback=callback, **kwargs)

    def _finish_flush(self):
        """callback for unsetting _flushed flag."""
        self._flushed = False

    def flush(self, flag: int = zmq.POLLIN | zmq.POLLOUT, limit: int | None = None):
        """Flush pending messages.

        This method safely handles all pending incoming and/or outgoing messages,
        bypassing the inner loop, passing them to the registered callbacks.

        A limit can be specified, to prevent blocking under high load.

        flush will return the first time ANY of these conditions are met:
            * No more events matching the flag are pending.
            * the total number of events handled reaches the limit.
            * the stream was closed by a callback.

        Note that if ``flag & POLLIN`` is set, recv events will be flushed even if no callback
        is registered, unlike normal IOLoop operation. This allows flush to be
        used to remove *and ignore* incoming messages.

        Parameters
        ----------
        flag : int
            default=POLLIN|POLLOUT
            0MQ poll flags.
            If flag & POLLIN, recv events will be flushed.
            If flag & POLLOUT, send events will be flushed.
            Both flags can be set at once, which is the default.
        limit : None or int, optional
            The maximum number of messages to send or receive.
            Both send and recv count against this limit.

        Returns
        -------
        int :
            count of events handled (both send and recv)
        """
        self._check_closed()
        # unset self._flushed, so callbacks will execute, in case flush has
        # already been called this iteration
        already_flushed = self._flushed
        self._flushed = False
        # initialize counters
        count = 0
        # keep the caller's request separate from the narrowed poll flag,
        # so sends queued by a recv callback mid-flush are still flushed
        requested = flag

        def update_flag():
            """Update the poll flag, to prevent registering POLLOUT events
            if we don't have pending sends."""
            return requested & zmq.POLLIN | (
                self.sending() and requested & zmq.POLLOUT
            )

        flag = update_flag()
        if not flag:
            # nothing to do
            return 0
        self.poller.register(self.socket, flag)
        events = self.poller.poll(0)
        while events and (not limit or count < limit):
            _, event = events[0]
            if event & POLLIN:  # receiving
                self._handle_recv()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break
            if event & POLLOUT and self.sending():
                self._handle_send()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break

            flag = update_flag()
            if flag:
                self.poller.register(self.socket, flag)
                events = self.poller.poll(0)
            else:
                events = []
        if count:  # only bypass loop if we actually flushed something
            # skip send/recv callbacks this iteration
            self._flushed = True
            # reregister them at the end of the loop
            if not already_flushed:  # don't need to do it again
                self.io_loop.add_callback(self._finish_flush)
        elif already_flushed:
            self._flushed = True

        # update ioloop poll state, which may have changed
        self._rebuild_io_state()
        return count

    def set_close_callback(self, callback: Callable | None):
        """Call the given callback when the stream is closed."""
        self._close_callback = callback

    def close(self, linger: int | None = None) -> None:
        """Close this stream.

        Unregisters the IOLoop handler, closes the socket with ``linger``,
        and runs the close callback, if any. Calling it on an already
        closed stream does nothing.
        """
        if self.socket is not None:
            if self.socket.closed:
                # fallback on raw fd for closed sockets
                # hopefully this happened promptly after close,
                # otherwise somebody else may have the FD
                warnings.warn(
                    f"Unregistering FD {self._fd} after closing socket. "
                    "This could result in unregistering handlers for the wrong socket. "
                    "Please use stream.close() instead of closing the socket directly.",
                    stacklevel=2,
                )
                self.io_loop.remove_handler(self._fd)
            else:
                self.io_loop.remove_handler(self.socket)
                self.socket.close(linger)
            self.socket = None  # type: ignore
            if self._close_callback:
                self._run_callback(self._close_callback)

    def receiving(self) -> bool:
        """Returns True if we are currently receiving from the stream."""
        return self._recv_callback is not None

    def sending(self) -> bool:
        """Returns True if we are currently sending to the stream."""
        return not self._send_queue.empty()

    def closed(self) -> bool:
        """Return True if the stream is closed.

        If the underlying socket was closed elsewhere, this also closes
        the stream (unregistering its handler) before returning True.
        """
        if self.socket is None:
            return True
        if self.socket.closed:
            # underlying socket has been closed, but not by us!
            # trigger our cleanup
            self.close()
            return True
        return False

    def _run_callback(self, callback, *args, **kwargs):
        """Wrap running callbacks in try/except to allow us to
        close our socket.

        Awaitable results (e.g. from coroutine callbacks) are scheduled
        with asyncio.ensure_future, and their exceptions are logged.
        """
        try:
            f = callback(*args, **kwargs)
            if isinstance(f, Awaitable):
                f = asyncio.ensure_future(f)
            else:
                f = None
        except Exception:
            gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True)
            # Re-raise the exception so that IOLoop.handle_callback_exception
            # can see it and log the error
            raise

        if f is not None:
            # handle async callbacks
            def _log_error(f):
                # CancelledError is a BaseException; cancellation is not an error
                if f.cancelled():
                    return
                try:
                    f.result()
                except Exception:
                    gen_log.error(
                        "Uncaught exception in ZMQStream callback", exc_info=True
                    )

            f.add_done_callback(_log_error)

    def _handle_events(self, fd, events):
        """This method is the actual handler for IOLoop, that gets called whenever
        an event on my socket is posted. It dispatches to _handle_recv, etc.

        The ``events`` argument is ignored; the socket's EVENTS option is
        read instead.
        """
        if not self.socket:
            gen_log.warning("Got events for closed stream %s", self)
            return
        try:
            zmq_events = self.socket.EVENTS
        except zmq.ContextTerminated:
            gen_log.warning("Got events for stream %s after terminating context", self)
            # trigger close check, this will unregister callbacks
            self.closed()
            return
        except zmq.ZMQError as e:
            # run close check
            # shadow sockets may have been closed elsewhere,
            # which should show up as ENOTSOCK here
            if self.closed():
                gen_log.warning(
                    "Got events for stream %s attached to closed socket: %s", self, e
                )
            else:
                gen_log.error("Error getting events for %s: %s", self, e)
            return
        try:
            # dispatch events:
            if zmq_events & zmq.POLLIN and self.receiving():
                self._handle_recv()
                if not self.socket:
                    return
            if zmq_events & zmq.POLLOUT and self.sending():
                self._handle_send()
                if not self.socket:
                    return

            # rebuild the poll state
            self._rebuild_io_state()
        except Exception:
            gen_log.error("Uncaught exception in zmqstream callback", exc_info=True)
            raise

    def _handle_recv(self):
        """Handle a recv event.

        Receives one multipart message without blocking and passes it to the
        recv callback, if one is registered. Skipped while flush() has
        already run this iteration.
        """
        if self._flushed:
            return
        try:
            msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # state changed since poll event
                pass
            else:
                raise
        else:
            if self._recv_callback:
                callback = self._recv_callback
                self._run_callback(callback, msg)

    def _handle_send(self):
        """Handle a send event.

        Sends the next queued message and passes it, with the send result
        (or the ZMQError raised), to the send callback.
        """
        if self._flushed:
            return
        if not self.sending():
            gen_log.error("Shouldn't have handled a send event")
            return

        msg, kwargs = self._send_queue.get()
        try:
            status = self.socket.send_multipart(msg, **kwargs)
        except zmq.ZMQError as e:
            gen_log.error("SEND Error: %s", e)
            status = e
        if self._send_callback:
            callback = self._send_callback
            self._run_callback(callback, msg, status)

    def _check_closed(self):
        """Raise OSError if the stream has been closed."""
        if not self.socket:
            raise OSError("Stream is closed")

    def _rebuild_io_state(self):
        """rebuild io state based on self.sending() and receiving()"""
        if self.socket is None:
            return
        state = 0
        if self.receiving():
            state |= zmq.POLLIN
        if self.sending():
            state |= zmq.POLLOUT

        self._state = state
        self._update_handler(state)

    def _add_io_state(self, state):
        """Add io_state to poller."""
        self._state = self._state | state
        self._update_handler(self._state)

    def _drop_io_state(self, state):
        """Stop poller from watching an io_state."""
        self._state = self._state & (~state)
        self._update_handler(self._state)

    def _update_handler(self, state):
        """Schedule event handling if events in ``state`` are already pending.

        The IOLoop handler registration itself is not changed here.
        """
        if self.socket is None:
            return

        if state & self.socket.events:
            # events still exist that haven't been processed
            # explicitly schedule handling to avoid missing events due to edge-triggered FDs
            self.io_loop.add_callback(lambda: self._handle_events(self.socket, 0))

    def _init_io_state(self):
        """initialize the ioloop event handler"""
        self.io_loop.add_handler(self.socket, self._handle_events, self.io_loop.READ)