1"""Tracker for zero-copy messages with 0MQ."""
2
3# Copyright (C) PyZMQ Developers
4# Distributed under the terms of the Modified BSD License.
5
6from __future__ import annotations
7
8import time
9from threading import Event
10
11from zmq.backend import Frame
12from zmq.error import NotDone
13
14
class MessageTracker:
    """A class for tracking if 0MQ is done using one or more messages.

    When you send a 0MQ message, it is not sent immediately. The 0MQ IO thread
    sends the message at some later time. Often you want to know when 0MQ has
    actually sent the message though. This is complicated by the fact that
    a single 0MQ message can be sent multiple times using different sockets.
    This class allows you to track all of the 0MQ usages of a message.

    Parameters
    ----------
    towatch : Event, MessageTracker, zmq.Frame
        The objects to track. This class can track the low-level
        Events used by the Message class, other MessageTrackers or
        actual Messages.
    """

    # Low-level events; every one must be set before this tracker is done.
    events: set[Event]
    # Nested trackers (including those taken from tracked Frames); every one
    # must report done before this tracker is done.
    peers: set[MessageTracker]

    def __init__(self, *towatch: MessageTracker | Event | Frame):
        """Create a message tracker to track a set of messages.

        Parameters
        ----------
        *towatch : Event, MessageTracker or Frame instances
            The objects to track. This class can track the low-level
            Events used by the Message class, other MessageTrackers or
            actual Messages.

        Raises
        ------
        ValueError
            If a Frame is given whose ``tracker`` is falsy (not tracked).
        TypeError
            If an object of any other type is given.
        """
        self.events = set()
        self.peers = set()
        for obj in towatch:
            if isinstance(obj, Event):
                self.events.add(obj)
            elif isinstance(obj, MessageTracker):
                self.peers.add(obj)
            elif isinstance(obj, Frame):
                # A frame without a tracker cannot be watched.
                if not obj.tracker:
                    raise ValueError("Not a tracked message")
                self.peers.add(obj.tracker)
            else:
                raise TypeError(f"Require Events or Message Frames, not {type(obj)}")

    @property
    def done(self) -> bool:
        """Is 0MQ completely done with the message(s) being tracked?

        True when every watched Event is set and every peer tracker is
        itself done; vacuously True when nothing is tracked.
        """
        return all(evt.is_set() for evt in self.events) and all(
            peer.done for peer in self.peers
        )

    def wait(self, timeout: float | int | None = -1) -> None:
        """Wait for 0MQ to be done with the message or until `timeout`.

        Parameters
        ----------
        timeout : float, int or None
            default: -1, which means wait forever.
            Maximum time in (s) to wait before raising NotDone.
            Any negative value, ``None`` or ``False`` also means wait forever.

        Returns
        -------
        None
            if done before `timeout`

        Raises
        ------
        NotDone
            if `timeout` reached before I am done.
        """
        deadline: float | None
        if timeout is None or timeout is False or timeout < 0:
            # No deadline: block until every tracked object is done.
            deadline = None
        else:
            # A monotonic clock keeps wall-clock adjustments from shortening
            # or stretching the wait.
            deadline = time.monotonic() + timeout

        def time_left() -> float | None:
            # Clamp at zero instead of bailing out early: once the budget is
            # spent, objects that are already done must still be accepted
            # (a zero-timeout wait just polls their state).
            if deadline is None:
                return None
            return max(0.0, deadline - time.monotonic())

        for evt in self.events:
            # Event.wait returns the flag's state; False means it timed out.
            if not evt.wait(timeout=time_left()):
                raise NotDone

        for peer in self.peers:
            left = time_left()
            # A peer tracker interprets -1 as "wait forever"; it raises
            # NotDone itself if it is not done within `left` seconds.
            peer.wait(timeout=-1 if left is None else left)
112
113
# A tracker constructed with nothing to watch: it has no events and no peers,
# so ``done`` is True and ``wait`` returns immediately.
_FINISHED_TRACKER = MessageTracker()

__all__ = [
    "MessageTracker",
    "_FINISHED_TRACKER",
]