1"""Base classes to manage a Client's interaction with a running kernel"""
2# Copyright (c) Jupyter Development Team.
3# Distributed under the terms of the Modified BSD License.
4import asyncio
5import atexit
6import time
7import typing as t
8from queue import Empty
9from threading import Event, Thread
10
11import zmq.asyncio
12from jupyter_core.utils import ensure_async
13
14from ._version import protocol_version_info
15from .channelsabc import HBChannelABC
16from .session import Session
17
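# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
# during garbage collection of threads at exit
# NOTE(review): no ZMQError import is present in this module's import block, so
# this comment appears stale -- either restore the import or drop the comment.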
20
21# -----------------------------------------------------------------------------
22# Constants and exceptions
23# -----------------------------------------------------------------------------
24
# Major component of the protocol version, i.e. the first element of
# ``protocol_version_info``.
major_protocol_version = protocol_version_info[0]
26
27
class InvalidPortNumber(Exception):  # noqa
    """Raised when a channel is configured with an unusable port number.

    For example, ``HBChannel`` raises it when given an ``(ip, 0)`` address.
    """
32
33
class HBChannel(Thread):
    """The heartbeat channel which monitors the kernel heartbeat.

    The channel is *not* paused when it starts; use :meth:`pause` and
    :meth:`unpause` to suspend and resume pinging. As long as you start this
    channel, the kernel manager will ensure that it is paused and un-paused
    as appropriate.
    """

    session = None
    socket = None
    address = None
    _exiting = False

    time_to_dead: float = 1.0
    _running = None
    _pause = None
    _beating = None

    def __init__(
        self,
        context: t.Optional[zmq.Context] = None,
        session: t.Optional[Session] = None,
        address: t.Union[t.Tuple[str, int], str] = "",
    ) -> None:
        """Create the heartbeat monitor thread.

        Parameters
        ----------
        context : :class:`zmq.Context`
            The ZMQ context to use.
        session : :class:`session.Session`
            The session to use.
        address : zmq url or (ip, port) tuple
            Either a ready-made zmq url string, or a standard (ip, port)
            tuple that the kernel is listening on, which is formatted as
            ``tcp://ip:port``.

        Raises
        ------
        InvalidPortNumber
            If ``address`` is a tuple whose port is 0.
        """
        super().__init__()
        # Daemon thread, so it does not keep the interpreter alive at exit.
        self.daemon = True

        self.context = context
        self.session = session
        if isinstance(address, tuple):
            if address[1] == 0:
                message = "The port number for a channel cannot be 0."
                raise InvalidPortNumber(message)
            address_str = "tcp://%s:%i" % address
        else:
            address_str = address
        self.address = address_str

        # running is False until `.start()` is called
        self._running = False
        # Set by stop(). It interrupts the waits in the run loop and, since it
        # is never cleared, also acts as a sticky "stop requested" flag.
        self._exit = Event()
        # don't start paused
        self._pause = False
        self.poller = zmq.Poller()

    @staticmethod
    @atexit.register
    def _notice_exit() -> None:
        # Class definitions can be torn down during interpreter shutdown.
        # We only need to set _exiting flag if this hasn't happened.
        if HBChannel is not None:
            HBChannel._exiting = True

    def _create_socket(self) -> None:
        """(Re)create the REQ socket, connect it and register it with the poller."""
        if self.socket is not None:
            # close previous socket, before opening a new one
            self.poller.unregister(self.socket)  # type:ignore[unreachable]
            self.socket.close()
        assert self.context is not None
        self.socket = self.context.socket(zmq.REQ)
        self.socket.linger = 1000
        assert self.address is not None
        self.socket.connect(self.address)

        self.poller.register(self.socket, zmq.POLLIN)

    async def _async_run(self) -> None:
        """The thread's main activity. Call start() instead."""
        self._create_socket()
        self._running = True
        self._beating = True
        assert self.socket is not None

        # ``_exit`` is checked alongside ``_running`` because stop() may run
        # before this coroutine reaches the assignment above. In that case
        # ``_running`` has been overwritten with True, but ``_exit`` still
        # records the stop request. Without this check the loop would spin
        # forever (every ``_exit.wait`` returns at once) and stop()'s join()
        # would hang.
        while self._running and not self._exit.is_set():
            if self._pause:
                # just sleep, and skip the rest of the loop
                self._exit.wait(self.time_to_dead)
                continue

            # no need to catch EFSM here, because the previous event was
            # either a recv or connect, which cannot be followed by EFSM)
            await ensure_async(self.socket.send(b"ping"))
            request_time = time.time()
            # Wait until timeout (returns early if stop() sets _exit)
            self._exit.wait(self.time_to_dead)
            # poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
            self._beating = bool(self.poller.poll(0))
            if self._beating:
                # the poll above guarantees we have something to recv
                await ensure_async(self.socket.recv())
            elif self._running and not self._exit.is_set():
                # nothing was received within the time limit, signal heart failure
                since_last_heartbeat = time.time() - request_time
                self.call_handlers(since_last_heartbeat)
                # and close/reopen the socket, because the REQ/REP cycle has been broken
                self._create_socket()

    def run(self) -> None:
        """Run the heartbeat thread on its own private event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._async_run())
        finally:
            # Release the loop even if the heartbeat loop raised.
            loop.close()

    def pause(self) -> None:
        """Pause the heartbeat."""
        self._pause = True

    def unpause(self) -> None:
        """Unpause the heartbeat."""
        self._pause = False

    def is_beating(self) -> bool:
        """Is the heartbeat running and responsive (and not paused)."""
        return bool(self.is_alive() and not self._pause and self._beating)

    def stop(self) -> None:
        """Stop the channel's event loop and join its thread."""
        self._running = False
        self._exit.set()
        self.join()
        self.close()

    def close(self) -> None:
        """Close the heartbeat socket (best effort; close errors are ignored)."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                pass
            self.socket = None

    def call_handlers(self, since_last_heartbeat: float) -> None:
        """Hook called from the heartbeat thread when a heartbeat is missed.

        ``since_last_heartbeat`` is the number of seconds elapsed since the
        unanswered ping was sent. The default implementation does nothing.
        Subclasses should override it. Because it runs in the heartbeat
        thread, overrides must take care that application-level handlers are
        called in the application thread.
        """
        pass
192
193
# Register HBChannel as a virtual subclass of HBChannelABC (presumably an
# ``abc.ABC`` -- confirm in channelsabc), so isinstance/issubclass checks
# against the interface accept HBChannel without it inheriting from it.
HBChannelABC.register(HBChannel)
195
196
class ZMQSocketChannel:
    """A ZMQ socket wrapper"""

    def __init__(self, socket: zmq.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations
        """
        super().__init__()

        self.socket: t.Optional[zmq.Socket] = socket
        self.session = session

    def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """Receive one multipart message and deserialize it via the session."""
        assert self.socket is not None
        msg = self.socket.recv_multipart(**kwargs)
        # The identity prefix is not needed by callers; only the payload is kept.
        _, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    def get_msg(self, timeout: t.Optional[float] = None) -> t.Dict[str, t.Any]:
        """Gets a message if there is one that is ready.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for a message. ``None`` is passed to
            ``socket.poll`` unchanged, which for pyzmq means wait
            indefinitely. ``0`` only returns an already-queued message.

        Raises
        ------
        queue.Empty
            If no message became ready within ``timeout``.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        ready = self.socket.poll(timeout_ms)
        if ready:
            res = self._recv()
            return res
        else:
            raise Empty

    def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:
        """Get all messages that are currently ready (never blocks)."""
        msgs = []
        while True:
            try:
                # timeout=0: only drain what is already queued. The default
                # (None) would make poll() wait forever once the queue is
                # empty, so this method would never return.
                msgs.append(self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    def msg_ready(self) -> bool:
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(self.socket.poll(timeout=0))

    def close(self) -> None:
        """Close the socket channel (best effort; close errors are ignored)."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                pass
            self.socket = None

    stop = close

    def is_alive(self) -> bool:
        """Test whether the channel is alive (i.e. its socket is still open)."""
        return self.socket is not None

    def send(self, msg: t.Dict[str, t.Any]) -> None:
        """Pass a message to the ZMQ socket to send"""
        assert self.socket is not None
        self.session.send(self.socket, msg)

    def start(self) -> None:
        """Start the socket channel (no-op: there is no background thread)."""
        pass
272
273
class AsyncZMQSocketChannel(ZMQSocketChannel):
    """A ZMQ socket in an async API"""

    socket: zmq.asyncio.Socket

    def __init__(self, socket: zmq.asyncio.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.asyncio.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations

        Raises
        ------
        ValueError
            If ``socket`` is not a :class:`zmq.asyncio.Socket`.
        """
        if not isinstance(socket, zmq.asyncio.Socket):
            msg = "Socket must be asyncio"  # type:ignore[unreachable]
            raise ValueError(msg)
        super().__init__(socket, session)

    async def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:  # type:ignore[override]
        """Receive one multipart message and deserialize it via the session."""
        assert self.socket is not None
        msg = await self.socket.recv_multipart(**kwargs)
        _, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    async def get_msg(  # type:ignore[override]
        self, timeout: t.Optional[float] = None
    ) -> t.Dict[str, t.Any]:
        """Gets a message if there is one that is ready.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for a message. ``None`` is passed to
            ``socket.poll`` unchanged, which for pyzmq means wait
            indefinitely. ``0`` only returns an already-queued message.

        Raises
        ------
        queue.Empty
            If no message became ready within ``timeout``.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        ready = await self.socket.poll(timeout_ms)
        if ready:
            res = await self._recv()
            return res
        else:
            raise Empty

    async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:  # type:ignore[override]
        """Get all messages that are currently ready (never waits for new ones)."""
        msgs = []
        while True:
            try:
                # timeout=0: only drain what is already queued. The default
                # (None) would make poll() wait forever once the queue is
                # empty, so this coroutine would never complete.
                msgs.append(await self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    async def msg_ready(self) -> bool:  # type:ignore[override]
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(await self.socket.poll(timeout=0))