1#
2# Copyright 2009 Facebook
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may
5# not use this file except in compliance with the License. You may obtain
6# a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations
14# under the License.
15"""A utility class for event-based messaging on a zmq socket using tornado.
16
17.. seealso::
18
19 - :mod:`zmq.asyncio`
20 - :mod:`zmq.eventloop.future`
21"""
22
23from __future__ import annotations
24
25import asyncio
26import pickle
27import warnings
28from queue import Queue
29from typing import Any, Awaitable, Callable, Sequence, cast, overload
30
31from tornado.ioloop import IOLoop
32from tornado.log import gen_log
33
34import zmq
35import zmq._future
36from zmq import POLLIN, POLLOUT
37from zmq._typing import Literal
38from zmq.utils import jsonapi
39
40
class ZMQStream:
    """A utility class to register callbacks when a zmq socket sends and receives

    For use with tornado IOLoop.

    There are three main methods

    Methods:

    * **on_recv(callback, copy=True):**
        register a callback to be run every time the socket has something to receive
    * **on_send(callback):**
        register a callback to be run every time you call send
    * **send_multipart(self, msg, flags=0, copy=True, track=False, callback=None):**
        perform a send that will trigger the callback
        if callback is passed, on_send is also called.

    There are also send(), send_string(), send_json(), send_pyobj()

    Two other methods for deactivating the callbacks:

    * **stop_on_recv():**
        turn off the recv callback
    * **stop_on_send():**
        turn off the send callback

    which simply call ``on_<evt>(None)``.

    The entire socket interface, excluding direct recv methods, is also
    provided, primarily through direct-linking the methods.
    e.g.

    >>> stream.bind == stream.socket.bind
    True


    .. versionadded:: 25

        send/recv callbacks can be coroutines.

    .. versionchanged:: 25

        ZMQStreams only support base zmq.Socket classes (this has always been true, but not enforced).
        If ZMQStreams are created with e.g. async Socket subclasses,
        a RuntimeWarning will be shown,
        and the socket cast back to the default zmq.Socket
        before connecting events.

        Previously, using async sockets (or any zmq.Socket subclass) would result in undefined behavior for the
        arguments passed to callback functions.
        Now, the callback functions reliably get the return value of the base `zmq.Socket` send/recv_multipart methods
        (the list of message frames).
    """

    # the wrapped socket; set to None once the stream is closed
    socket: zmq.Socket
    # the tornado IOLoop that dispatches events for this stream
    io_loop: IOLoop
    # private poller used by flush() to check for pending events
    poller: zmq.Poller
    # pending (msg, kwargs) pairs waiting to be sent
    _send_queue: Queue
    _recv_callback: Callable | None
    _send_callback: Callable | None
    _close_callback: Callable | None
    # bitmask of POLLIN/POLLOUT events we are currently interested in
    _state: int = 0
    # True while callbacks are bypassed because flush() already ran them
    _flushed: bool = False
    # `copy` argument used for recv_multipart in _handle_recv
    _recv_copy: bool = False
    # raw FD of the socket, kept so the handler can be removed after
    # the socket has been closed behind our back
    _fd: int

    def __init__(self, socket: zmq.Socket, io_loop: IOLoop | None = None):
        """Wrap `socket` and register it with the IOLoop.

        Parameters
        ----------
        socket : zmq.Socket
            The socket to wrap. Async sockets are shadowed back to a base
            zmq.Socket, with a RuntimeWarning.
        io_loop : IOLoop, optional
            The loop to register with. Defaults to ``IOLoop.current()``.
        """
        if isinstance(socket, zmq._future._AsyncSocket):
            warnings.warn(
                f"""ZMQStream only supports the base zmq.Socket class.

                Use zmq.Socket(shadow=other_socket)
                or `ctx.socket(zmq.{socket._type_name}, socket_class=zmq.Socket)`
                to create a base zmq.Socket object,
                no matter what other kind of socket your Context creates.
                """,
                RuntimeWarning,
                stacklevel=2,
            )
            # shadow back to base zmq.Socket,
            # otherwise callbacks like `on_recv` will get the wrong types.
            socket = zmq.Socket(shadow=socket)
        self.socket = socket

        # IOLoop.current() is deprecated if called outside the event loop;
        # it is only used as a fallback when no io_loop is given.
        self.io_loop = io_loop or IOLoop.current()
        self.poller = zmq.Poller()
        self._fd = cast(int, self.socket.FD)

        self._send_queue = Queue()
        self._recv_callback = None
        self._send_callback = None
        self._close_callback = None
        self._recv_copy = False
        self._flushed = False

        self._state = 0
        self._init_io_state()

        # shortcircuit some socket methods
        self.bind = self.socket.bind
        self.bind_to_random_port = self.socket.bind_to_random_port
        self.connect = self.socket.connect
        self.setsockopt = self.socket.setsockopt
        self.getsockopt = self.socket.getsockopt
        self.setsockopt_string = self.socket.setsockopt_string
        self.getsockopt_string = self.socket.getsockopt_string
        self.setsockopt_unicode = self.socket.setsockopt_unicode
        self.getsockopt_unicode = self.socket.getsockopt_unicode

    def stop_on_recv(self):
        """Disable callback and automatic receiving."""
        return self.on_recv(None)

    def stop_on_send(self):
        """Disable callback on sending."""
        return self.on_send(None)

    def stop_on_err(self):
        """DEPRECATED, does nothing"""
        # Logger.warn is a deprecated alias; use Logger.warning
        gen_log.warning("on_err does nothing, and will be removed")

    def on_err(self, callback: Callable):
        """DEPRECATED, does nothing"""
        # Logger.warn is a deprecated alias; use Logger.warning
        gen_log.warning("on_err does nothing, and will be removed")

    @overload
    def on_recv(
        self,
        callback: Callable[[list[bytes]], Any],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: Callable[[list[bytes]], Any],
        copy: Literal[True],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: Callable[[list[zmq.Frame]], Any],
        copy: Literal[False],
    ) -> None: ...

    @overload
    def on_recv(
        self,
        callback: (
            Callable[[list[zmq.Frame]], Any] | Callable[[list[bytes]], Any] | None
        ),
        copy: bool = ...,
    ) -> None: ...

    def on_recv(
        self,
        callback: (
            Callable[[list[zmq.Frame]], Any] | Callable[[list[bytes]], Any] | None
        ),
        copy: bool = True,
    ) -> None:
        """Register a callback for when a message is ready to recv.

        There can be only one callback registered at a time, so each
        call to `on_recv` replaces previously registered callbacks.

        on_recv(None) disables recv event polling.

        Use on_recv_stream(callback) instead, to register a callback that will receive
        both this ZMQStream and the message, instead of just the message.

        Parameters
        ----------

        callback : callable or None
            callback must take exactly one argument, which will be a
            list, as returned by socket.recv_multipart()
            if callback is None, recv callbacks are disabled.
        copy : bool
            copy is passed directly to recv, so if copy is False,
            callback will receive zmq.Frame objects. If copy is True,
            then callback will receive bytes objects.

        Returns : None

        Raises
        ------
        OSError
            if the stream is closed.
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._recv_callback = callback
        self._recv_copy = copy
        if callback is None:
            self._drop_io_state(zmq.POLLIN)
        else:
            self._add_io_state(zmq.POLLIN)

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[bytes]], Any],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[bytes]], Any],
        copy: Literal[True],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: Callable[[ZMQStream, list[zmq.Frame]], Any],
        copy: Literal[False],
    ) -> None: ...

    @overload
    def on_recv_stream(
        self,
        callback: (
            Callable[[ZMQStream, list[zmq.Frame]], Any]
            | Callable[[ZMQStream, list[bytes]], Any]
            | None
        ),
        copy: bool = ...,
    ) -> None: ...

    def on_recv_stream(
        self,
        callback: (
            Callable[[ZMQStream, list[zmq.Frame]], Any]
            | Callable[[ZMQStream, list[bytes]], Any]
            | None
        ),
        copy: bool = True,
    ) -> None:
        """Same as on_recv, but callback will get this stream as first argument

        callback must take exactly two arguments, as it will be called as::

            callback(stream, msg)

        Useful when a single callback should be used with multiple streams.

        on_recv_stream(None) is the same as stop_on_recv().
        """
        if callback is None:
            self.stop_on_recv()
        else:

            def stream_callback(msg):
                return callback(self, msg)

            self.on_recv(stream_callback, copy=copy)

    def on_send(
        self,
        callback: Callable[[Sequence[Any], zmq.MessageTracker | None], Any] | None,
    ):
        """Register a callback to be called on each send

        There will be two arguments::

            callback(msg, status)

        * `msg` will be the list of sendable objects that was just sent
        * `status` will be the return result of socket.send_multipart(msg) -
          MessageTracker or None.

        Non-copying sends return a MessageTracker object whose
        `done` attribute will be True when the send is complete.
        This allows users to track when an object is safe to write to
        again.

        The second argument will always be None if copy=True
        on the send.

        Use on_send_stream(callback) to register a callback that will be passed
        this ZMQStream as the first argument, in addition to the other two.

        on_send(None) disables the send callback.

        Parameters
        ----------

        callback : callable or None
            callback must take exactly two arguments, which will be
            the message being sent (always a list),
            and the return result of socket.send_multipart(msg) -
            MessageTracker or None.

            if callback is None, send callbacks are disabled.

        Raises
        ------
        OSError
            if the stream is closed.
        """

        self._check_closed()
        assert callback is None or callable(callback)
        self._send_callback = callback

    def on_send_stream(
        self,
        callback: (
            Callable[[ZMQStream, Sequence[Any], zmq.MessageTracker | None], Any] | None
        ),
    ):
        """Same as on_send, but callback will get this stream as first argument

        Callback will be passed three arguments::

            callback(stream, msg, status)

        Useful when a single callback should be used with multiple streams.

        on_send_stream(None) is the same as stop_on_send().
        """
        if callback is None:
            self.stop_on_send()
        else:
            self.on_send(lambda msg, status: callback(self, msg, status))

    def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs):
        """Send a message, optionally also register a new callback for sends.
        See zmq.socket.send for details.

        The message is wrapped in a one-element list and queued
        via send_multipart.
        """
        return self.send_multipart(
            [msg], flags=flags, copy=copy, track=track, callback=callback, **kwargs
        )

    def send_multipart(
        self,
        msg: Sequence[Any],
        flags: int = 0,
        copy: bool = True,
        track: bool = False,
        callback: Callable | None = None,
        **kwargs: Any,
    ) -> None:
        """Send a multipart message, optionally also register a new callback for sends.
        See zmq.socket.send_multipart for details.

        The message is not sent immediately: it is put on the send queue
        and POLLOUT handling is enabled, so the actual
        socket.send_multipart call happens in _handle_send.

        Raises
        ------
        OSError
            if the stream is closed.
        """
        # fail before queueing, so a closed stream never accumulates messages
        self._check_closed()
        kwargs.update(flags=flags, copy=copy, track=track)
        self._send_queue.put((msg, kwargs))
        callback = callback or self._send_callback
        if callback is not None:
            self.on_send(callback)
        else:
            # noop callback
            self.on_send(lambda *args: None)
        self._add_io_state(zmq.POLLOUT)

    def send_string(
        self,
        u: str,
        flags: int = 0,
        encoding: str = 'utf-8',
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send a unicode message with an encoding.
        See zmq.socket.send_string for details.

        Raises
        ------
        TypeError
            if `u` is not a str.
        """
        if not isinstance(u, str):
            raise TypeError("unicode/str objects only")
        return self.send(u.encode(encoding), flags=flags, callback=callback, **kwargs)

    send_unicode = send_string

    def send_json(
        self,
        obj: Any,
        flags: int = 0,
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send json-serialized version of an object.
        See zmq.socket.send_json for details.
        """
        msg = jsonapi.dumps(obj)
        return self.send(msg, flags=flags, callback=callback, **kwargs)

    def send_pyobj(
        self,
        obj: Any,
        flags: int = 0,
        protocol: int = -1,
        callback: Callable | None = None,
        **kwargs: Any,
    ):
        """Send a Python object as a message using pickle to serialize.

        See zmq.socket.send_pyobj for details.
        """
        msg = pickle.dumps(obj, protocol)
        return self.send(msg, flags, callback=callback, **kwargs)

    def _finish_flush(self):
        """callback for unsetting _flushed flag."""
        self._flushed = False

    def flush(
        self, flag: int = zmq.POLLIN | zmq.POLLOUT, limit: int | None = None
    ) -> int:
        """Flush pending messages.

        This method safely handles all pending incoming and/or outgoing messages,
        bypassing the inner loop, passing them to the registered callbacks.

        A limit can be specified, to prevent blocking under high load.

        flush will return the first time ANY of these conditions are met:
            * No more events matching the flag are pending.
            * the total number of events handled reaches the limit.

        Note that if ``flag & POLLIN != 0``, recv events will be flushed even if no callback
        is registered, unlike normal IOLoop operation. This allows flush to be
        used to remove *and ignore* incoming messages.

        Parameters
        ----------
        flag : int
            default=POLLIN|POLLOUT
            0MQ poll flags.
            If flag & POLLIN, recv events will be flushed.
            If flag & POLLOUT, send events will be flushed.
            Both flags can be set at once, which is the default.
        limit : None or int, optional
            The maximum number of messages to send or receive.
            Both send and recv count against this limit.
            A falsy limit (None or 0) means no limit.

        Returns
        -------
        int :
            count of events handled (both send and recv)

        Raises
        ------
        OSError
            if the stream is closed.
        """
        self._check_closed()

        def update_flag():
            """Update the poll flag, to prevent registering POLLOUT events
            if we don't have pending sends."""
            return flag & zmq.POLLIN | (self.sending() and flag & zmq.POLLOUT)

        # NOTE: flag is narrowed on every update, so a bit that has been
        # dropped (e.g. POLLOUT when nothing was queued) is not re-added
        # later during this call.
        flag = update_flag()
        if not flag:
            # nothing to do: return before touching self._flushed, so a
            # bypass set by an earlier flush this iteration stays in effect
            return 0

        # unset self._flushed, so callbacks will execute, in case flush has
        # already been called this iteration
        already_flushed = self._flushed
        self._flushed = False
        # initialize counters
        count = 0

        self.poller.register(self.socket, flag)
        events = self.poller.poll(0)
        while events and (not limit or count < limit):
            s, event = events[0]
            if event & POLLIN:  # receiving
                self._handle_recv()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break
            if event & POLLOUT and self.sending():
                self._handle_send()
                count += 1
                if self.socket is None:
                    # break if socket was closed during callback
                    break

            flag = update_flag()
            if flag:
                self.poller.register(self.socket, flag)
                events = self.poller.poll(0)
            else:
                events = []
        if count:  # only bypass loop if we actually flushed something
            # skip send/recv callbacks this iteration
            self._flushed = True
            # reregister them at the end of the loop
            if not already_flushed:  # don't need to do it again
                self.io_loop.add_callback(self._finish_flush)
        elif already_flushed:
            self._flushed = True

        # update ioloop poll state, which may have changed
        self._rebuild_io_state()
        return count

    def set_close_callback(self, callback: Callable | None):
        """Call the given callback when the stream is closed."""
        self._close_callback = callback

    def close(self, linger: int | None = None) -> None:
        """Close this stream.

        Removes the IOLoop handler, closes the socket (unless it is already
        closed), drops the reference to it, and runs the close callback
        if one is set. Does nothing if the stream is already closed.

        Parameters
        ----------
        linger : int, optional
            passed to socket.close()
        """
        if self.socket is not None:
            if self.socket.closed:
                # fallback on raw fd for closed sockets
                # hopefully this happened promptly after close,
                # otherwise somebody else may have the FD
                warnings.warn(
                    "Unregistering FD %s after closing socket. "
                    "This could result in unregistering handlers for the wrong socket. "
                    "Please use stream.close() instead of closing the socket directly."
                    % self._fd,
                    stacklevel=2,
                )
                self.io_loop.remove_handler(self._fd)
            else:
                self.io_loop.remove_handler(self.socket)
                self.socket.close(linger)
            self.socket = None  # type: ignore
            if self._close_callback:
                self._run_callback(self._close_callback)

    def receiving(self) -> bool:
        """Returns True if we are currently receiving from the stream."""
        return self._recv_callback is not None

    def sending(self) -> bool:
        """Returns True if we are currently sending to the stream."""
        return not self._send_queue.empty()

    def closed(self) -> bool:
        """Returns True if the stream is closed.

        If the underlying socket was closed without closing the stream,
        this also closes the stream.
        """
        if self.socket is None:
            return True
        if self.socket.closed:
            # underlying socket has been closed, but not by us!
            # trigger our cleanup
            self.close()
            return True
        return False

    def _run_callback(self, callback, *args, **kwargs):
        """Wrap running callbacks in try/except to allow us to
        close our socket.

        Exceptions raised synchronously by the callback are logged and
        re-raised. If the callback returns an awaitable, it is scheduled with
        asyncio.ensure_future and its exception, if any, is only logged.
        """
        try:
            f = callback(*args, **kwargs)
            if isinstance(f, Awaitable):
                f = asyncio.ensure_future(f)
            else:
                f = None
        except Exception:
            gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True)
            # Re-raise the exception so that IOLoop.handle_callback_exception
            # can see it and log the error
            raise

        if f is not None:
            # handle async callbacks
            def _log_error(f):
                try:
                    f.result()
                except Exception:
                    gen_log.error(
                        "Uncaught exception in ZMQStream callback", exc_info=True
                    )

            f.add_done_callback(_log_error)

    def _handle_events(self, fd, events):
        """This method is the actual handler for IOLoop, that gets called whenever
        an event on my socket is posted. It dispatches to _handle_recv, etc.

        The `fd` and `events` arguments are not used: the pending events
        are read from the socket itself.
        """
        if not self.socket:
            gen_log.warning("Got events for closed stream %s", self)
            return
        try:
            zmq_events = self.socket.EVENTS
        except zmq.ContextTerminated:
            gen_log.warning("Got events for stream %s after terminating context", self)
            # trigger close check, this will unregister callbacks
            self.closed()
            return
        except zmq.ZMQError as e:
            # run close check
            # shadow sockets may have been closed elsewhere,
            # which should show up as ENOTSOCK here
            if self.closed():
                gen_log.warning(
                    "Got events for stream %s attached to closed socket: %s", self, e
                )
            else:
                gen_log.error("Error getting events for %s: %s", self, e)
            return
        try:
            # dispatch events:
            if zmq_events & zmq.POLLIN and self.receiving():
                self._handle_recv()
                if not self.socket:
                    # stream was closed during the callback
                    return
            if zmq_events & zmq.POLLOUT and self.sending():
                self._handle_send()
                if not self.socket:
                    # stream was closed during the callback
                    return

            # rebuild the poll state
            self._rebuild_io_state()
        except Exception:
            gen_log.error("Uncaught exception in zmqstream callback", exc_info=True)
            raise

    def _handle_recv(self):
        """Handle a recv event.

        Receives one multipart message without blocking and passes it to the
        recv callback, if one is registered. Does nothing while callbacks
        are bypassed by flush().
        """
        if self._flushed:
            return
        try:
            msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # state changed since poll event
                return
            raise
        if self._recv_callback:
            callback = self._recv_callback
            self._run_callback(callback, msg)

    def _handle_send(self):
        """Handle a send event.

        Takes one message off the send queue, sends it, and passes the
        message and the send result to the send callback. If the send raises
        ZMQError, the error is logged and passed to the callback as the
        status. Does nothing while callbacks are bypassed by flush().
        """
        if self._flushed:
            return
        if not self.sending():
            gen_log.error("Shouldn't have handled a send event")
            return

        msg, kwargs = self._send_queue.get()
        try:
            status = self.socket.send_multipart(msg, **kwargs)
        except zmq.ZMQError as e:
            gen_log.error("SEND Error: %s", e)
            status = e
        if self._send_callback:
            callback = self._send_callback
            self._run_callback(callback, msg, status)

    def _check_closed(self):
        """Raise OSError if the stream has been closed."""
        if not self.socket:
            raise OSError("Stream is closed")

    def _rebuild_io_state(self):
        """rebuild io state based on self.sending() and receiving()"""
        if self.socket is None:
            return
        state = 0
        if self.receiving():
            state |= zmq.POLLIN
        if self.sending():
            state |= zmq.POLLOUT

        self._state = state
        self._update_handler(state)

    def _add_io_state(self, state):
        """Add io_state to poller."""
        self._state = self._state | state
        self._update_handler(self._state)

    def _drop_io_state(self, state):
        """Stop poller from watching an io_state."""
        self._state = self._state & (~state)
        self._update_handler(self._state)

    def _update_handler(self, state):
        """Update IOLoop handler with state.

        If the socket already has pending events matching `state`,
        schedule a call to _handle_events on the IOLoop.
        """
        if self.socket is None:
            return

        if state & self.socket.events:
            # events still exist that haven't been processed
            # explicitly schedule handling to avoid missing events due to edge-triggered FDs
            self.io_loop.add_callback(lambda: self._handle_events(self.socket, 0))

    def _init_io_state(self):
        """initialize the ioloop event handler"""
        self.io_loop.add_handler(self.socket, self._handle_events, self.io_loop.READ)