1"""Base classes to manage a Client's interaction with a running kernel"""
2
3# Copyright (c) Jupyter Development Team.
4# Distributed under the terms of the Modified BSD License.
5import asyncio
6import atexit
7import time
8import typing as t
9from queue import Empty
10from threading import Event, Thread
11
12import zmq.asyncio
13from jupyter_core.utils import ensure_async
14
15from ._version import protocol_version_info
16from .channelsabc import HBChannelABC
17from .session import Session
18
19# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
20# during garbage collection of threads at exit
21
22# -----------------------------------------------------------------------------
23# Constants and exceptions
24# -----------------------------------------------------------------------------
25
# First element of the protocol version tuple imported from ._version.
major_protocol_version = protocol_version_info[0]
27
28
class InvalidPortNumber(Exception):  # noqa
    """Signal that a channel was given a port number it cannot use."""
33
34
class HBChannel(Thread):
    """The heartbeat channel which monitors the kernel heartbeat.

    The channel is a daemon thread that repeatedly sends ``b"ping"`` on a REQ
    socket and waits up to :attr:`time_to_dead` seconds for a reply.  When no
    reply arrives in time, :meth:`call_handlers` is invoked and the socket is
    recreated.

    A freshly constructed channel is *not* paused (``_pause`` is ``False``
    after ``__init__``); use :meth:`pause` and :meth:`unpause` to suspend and
    resume the pinging while the thread keeps running.
    """

    session = None
    socket = None
    address = None
    # Set to True by the atexit hook once the interpreter starts shutting down.
    _exiting = False

    # Seconds to wait for a heartbeat reply before signalling a failure.
    time_to_dead: float = 1.0
    _running = None
    _pause = None
    _beating = None

    def __init__(
        self,
        context: zmq.Context | None = None,
        session: Session | None = None,
        address: t.Union[t.Tuple[str, int], str] = "",
    ) -> None:
        """Create the heartbeat monitor thread.

        Parameters
        ----------
        context : :class:`zmq.Context`
            The ZMQ context to use.
        session : :class:`session.Session`
            The session to use.
        address : zmq url
            Either an ``(ip, port)`` tuple, which is formatted as
            ``tcp://ip:port``, or a string that is used as the url verbatim.

        Raises
        ------
        InvalidPortNumber
            If ``address`` is a tuple whose port is ``0``.
        """
        super().__init__()
        self.daemon = True

        self.context = context
        self.session = session
        if isinstance(address, tuple):
            if address[1] == 0:
                message = "The port number for a channel cannot be 0."
                raise InvalidPortNumber(message)
            address_str = "tcp://%s:%i" % address
        else:
            address_str = address
        self.address = address_str

        # running is False until `.start()` is called
        self._running = False
        # Event used both to interrupt the timed waits and to signal shutdown.
        self._exit = Event()
        # don't start paused
        self._pause = False
        self.poller = zmq.Poller()

    @staticmethod
    @atexit.register
    def _notice_exit() -> None:
        """Record on the class that the interpreter is exiting."""
        # Class definitions can be torn down during interpreter shutdown.
        # We only need to set _exiting flag if this hasn't happened.
        if HBChannel is not None:
            HBChannel._exiting = True

    def _create_socket(self) -> None:
        """(Re)create the REQ socket, connect it and register it for polling."""
        if self.socket is not None:
            # close previous socket, before opening a new one
            self.poller.unregister(self.socket)  # type:ignore[unreachable]
            self.socket.close()
        assert self.context is not None
        self.socket = self.context.socket(zmq.REQ)
        self.socket.linger = 1000
        assert self.address is not None
        self.socket.connect(self.address)

        self.poller.register(self.socket, zmq.POLLIN)

    async def _async_run(self) -> None:
        """The thread's main activity.  Call start() instead."""
        self._create_socket()
        self._running = True
        self._beating = True
        assert self.socket is not None

        while self._running:
            if self._pause:
                # just sleep, and skip the rest of the loop
                self._exit.wait(self.time_to_dead)
                continue

            # no need to catch EFSM here, because the previous event was
            # either a recv or connect, which cannot be followed by EFSM)
            await ensure_async(self.socket.send(b"ping"))
            # Monotonic clock: the elapsed time must not be affected by
            # wall-clock adjustments while we wait.
            request_time = time.monotonic()
            # Wait until timeout (or until stop() sets the exit event)
            self._exit.wait(self.time_to_dead)
            # poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
            self._beating = bool(self.poller.poll(0))
            if self._beating:
                # the poll above guarantees we have something to recv
                await ensure_async(self.socket.recv())
            elif self._running:
                # nothing was received within the time limit, signal heart failure
                self.call_handlers(time.monotonic() - request_time)
                # and close/reopen the socket, because the REQ/REP cycle has been broken
                self._create_socket()

    def run(self) -> None:
        """Run the heartbeat thread on a private asyncio event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._async_run())
        finally:
            loop.close()

    def pause(self) -> None:
        """Pause the heartbeat."""
        self._pause = True

    def unpause(self) -> None:
        """Unpause the heartbeat."""
        self._pause = False

    def is_beating(self) -> bool:
        """Is the heartbeat running and responsive (and not paused)."""
        return bool(self.is_alive() and not self._pause and self._beating)

    def stop(self) -> None:
        """Stop the channel's event loop, join its thread and close the socket.

        Safe to call on a channel whose thread was never started: in that
        case there is nothing to join and only the cleanup is performed.
        """
        self._running = False
        self._exit.set()
        # Thread.join() raises RuntimeError on a thread that was never
        # started, which would also skip the close() below.
        if self.is_alive():
            self.join()
        self.close()

    def close(self) -> None:
        """Close the heartbeat socket, ignoring errors raised while closing."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                # best-effort cleanup: the socket is dropped either way
                pass
            self.socket = None

    def call_handlers(self, since_last_heartbeat: float) -> None:
        """Hook invoked from the heartbeat thread when a heartbeat is missed.

        Called by the run loop when no reply to a ping was received within
        :attr:`time_to_dead`.  Subclasses should override this method to react
        to the failure.  Because it is called in the heartbeat thread, some
        logic must be done to ensure that application level handlers are
        called in the application thread.

        Parameters
        ----------
        since_last_heartbeat : float
            Seconds elapsed since the unanswered ping was sent.
        """
        pass
195
196
# Register HBChannel with HBChannelABC (presumably an abc.ABC, which would
# make HBChannel a virtual subclass of that interface - confirm in channelsabc).
HBChannelABC.register(HBChannel)
198
199
class ZMQSocketChannel:
    """A ZMQ socket wrapper"""

    def __init__(self, socket: zmq.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations
        """
        super().__init__()

        self.socket: zmq.Socket | None = socket
        self.session = session

    def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """Receive one multipart message and deserialize it with the session."""
        assert self.socket is not None
        msg = self.socket.recv_multipart(**kwargs)
        _ident, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    def get_msg(self, timeout: float | None = None) -> t.Dict[str, t.Any]:
        """Gets a message if there is one that is ready.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for a message.  ``None`` is passed through to the
            socket's ``poll`` unchanged, which in pyzmq means wait
            indefinitely; ``0`` returns immediately.

        Raises
        ------
        queue.Empty
            If no message became ready within the timeout.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        ready = self.socket.poll(timeout_ms)
        if ready:
            res = self._recv()
            return res
        else:
            raise Empty

    def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:
        """Get all messages that are currently ready.

        Never waits for new messages: the queue is drained with a zero
        timeout and the call returns as soon as nothing more is ready.
        """
        msgs = []
        while True:
            try:
                # timeout=0 is essential: the default (None) would make the
                # underlying poll wait for a message that may never come.
                msgs.append(self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    def msg_ready(self) -> bool:
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(self.socket.poll(timeout=0))

    def close(self) -> None:
        """Close the socket channel, ignoring errors raised while closing."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                # best-effort cleanup: the socket is dropped either way
                pass
            self.socket = None

    # stop() is an alias of close()
    stop = close

    def is_alive(self) -> bool:
        """Test whether the channel is alive (i.e. it still holds a socket)."""
        return self.socket is not None

    def send(self, msg: t.Dict[str, t.Any]) -> None:
        """Pass a message to the ZMQ socket to send"""
        assert self.socket is not None
        self.session.send(self.socket, msg)

    def start(self) -> None:
        """Start the socket channel (a no-op for this implementation)."""
        pass
275
276
class AsyncZMQSocketChannel(ZMQSocketChannel):
    """A ZMQ socket in an async API"""

    socket: zmq.asyncio.Socket

    def __init__(self, socket: zmq.asyncio.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.asyncio.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations

        Raises
        ------
        ValueError
            If ``socket`` is not a :class:`zmq.asyncio.Socket`.
        """
        if not isinstance(socket, zmq.asyncio.Socket):
            msg = "Socket must be asyncio"  # type:ignore[unreachable]
            raise ValueError(msg)
        super().__init__(socket, session)

    async def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:  # type:ignore[override]
        """Await one multipart message and deserialize it with the session."""
        assert self.socket is not None
        msg = await self.socket.recv_multipart(**kwargs)
        _, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    async def get_msg(  # type:ignore[override]
        self, timeout: float | None = None
    ) -> t.Dict[str, t.Any]:
        """Gets a message if there is one that is ready.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for a message.  ``None`` is passed through to the
            socket's ``poll`` unchanged, which in pyzmq means wait
            indefinitely; ``0`` returns immediately.

        Raises
        ------
        queue.Empty
            If no message became ready within the timeout.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        ready = await self.socket.poll(timeout_ms)
        if ready:
            res = await self._recv()
            return res
        else:
            raise Empty

    async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:  # type:ignore[override]
        """Get all messages that are currently ready.

        Never waits for new messages: the queue is drained with a zero
        timeout and the call returns as soon as nothing more is ready.
        """
        msgs = []
        while True:
            try:
                # timeout=0 is essential: the default (None) would make the
                # underlying poll wait for a message that may never come.
                msgs.append(await self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    async def msg_ready(self) -> bool:  # type:ignore[override]
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(await self.socket.poll(timeout=0))