Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jupyter_client/channels.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

181 statements  

1"""Base classes to manage a Client's interaction with a running kernel""" 

2 

3# Copyright (c) Jupyter Development Team. 

4# Distributed under the terms of the Modified BSD License. 

5import asyncio 

6import atexit 

7import time 

8import typing as t 

9from queue import Empty 

10from threading import Event, Thread 

11 

12import zmq.asyncio 

13from jupyter_core.utils import ensure_async 

14 

15from ._version import protocol_version_info 

16from .channelsabc import HBChannelABC 

17from .session import Session 

18 

19# import ZMQError in top-level namespace, to avoid ugly attribute-error messages 

20# during garbage collection of threads at exit 

21 

22# ----------------------------------------------------------------------------- 

23# Constants and exceptions 

24# ----------------------------------------------------------------------------- 

25 

# Leading component of the protocol version tuple imported from ._version.
major_protocol_version = protocol_version_info[0]

27 

28 

class InvalidPortNumber(Exception):  # noqa
    """Raised when a channel is given a port number that cannot be used."""

33 

34 

class HBChannel(Thread):
    """The heartbeat channel which monitors the kernel heartbeat.

    The monitor runs in its own daemon thread. Every ``time_to_dead`` seconds
    it sends a ``ping`` on a REQ socket and checks whether a reply arrived;
    when none did, :meth:`call_handlers` is invoked and the socket is rebuilt.

    The channel is created un-paused; use :meth:`pause` and :meth:`unpause`
    to suspend and resume monitoring.
    """

    session = None
    socket = None
    address = None
    _exiting = False

    # Seconds to wait for a heartbeat reply before reporting a failure.
    time_to_dead: float = 1.0
    _running = None
    _pause = None
    _beating = None

    def __init__(
        self,
        context: zmq.Context | None = None,
        session: Session | None = None,
        address: t.Union[t.Tuple[str, int], str] = "",
    ) -> None:
        """Create the heartbeat monitor thread.

        Parameters
        ----------
        context : :class:`zmq.Context`
            The ZMQ context to use.
        session : :class:`session.Session`
            The session to use.
        address : zmq url
            Either an ``(ip, port)`` tuple, which is turned into a
            ``tcp://ip:port`` url, or an already formatted url string.

        Raises
        ------
        InvalidPortNumber
            If ``address`` is a tuple whose port is 0.
        """
        super().__init__()
        self.daemon = True

        self.context = context
        self.session = session
        if isinstance(address, tuple):
            if address[1] == 0:
                message = "The port number for a channel cannot be 0."
                raise InvalidPortNumber(message)
            address_str = "tcp://%s:%i" % address
        else:
            address_str = address
        self.address = address_str

        # running is False until `.start()` is called
        self._running = False
        # set by stop() to cut any pending wait short
        self._exit = Event()
        # don't start paused
        self._pause = False
        self.poller = zmq.Poller()

    @staticmethod
    @atexit.register
    def _notice_exit() -> None:
        # Class definitions can be torn down during interpreter shutdown.
        # We only need to set _exiting flag if this hasn't happened.
        if HBChannel is not None:
            HBChannel._exiting = True

    def _create_socket(self) -> None:
        """(Re)create the REQ socket, connect it and register it for polling."""
        if self.socket is not None:
            # close previous socket, before opening a new one
            self.poller.unregister(self.socket)  # type:ignore[unreachable]
            self.socket.close()
        assert self.context is not None
        self.socket = self.context.socket(zmq.REQ)
        self.socket.linger = 1000
        assert self.address is not None
        self.socket.connect(self.address)

        self.poller.register(self.socket, zmq.POLLIN)

    async def _async_run(self) -> None:
        """The thread's main activity. Call start() instead."""
        self._create_socket()
        self._running = True
        self._beating = True
        assert self.socket is not None

        while self._running:
            if self._pause:
                # just sleep, and skip the rest of the loop
                self._exit.wait(self.time_to_dead)
                continue

            # no need to catch EFSM here, because the previous event was
            # either a recv or connect, which cannot be followed by EFSM)
            await ensure_async(self.socket.send(b"ping"))
            request_time = time.time()
            # Wait until timeout (returns early once stop() sets the event)
            self._exit.wait(self.time_to_dead)
            # poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
            self._beating = bool(self.poller.poll(0))
            if self._beating:
                # the poll above guarantees we have something to recv
                await ensure_async(self.socket.recv())
            elif self._running:
                # nothing was received within the time limit, signal heart failure
                since_last_heartbeat = time.time() - request_time
                self.call_handlers(since_last_heartbeat)
                # and close/reopen the socket, because the REQ/REP cycle has been broken
                self._create_socket()

    def run(self) -> None:
        """Run the heartbeat loop on a fresh event loop owned by this thread."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._async_run())
        finally:
            loop.close()

    def pause(self) -> None:
        """Pause the heartbeat."""
        self._pause = True

    def unpause(self) -> None:
        """Unpause the heartbeat."""
        self._pause = False

    def is_beating(self) -> bool:
        """Is the heartbeat running and responsive (and not paused)."""
        # bool() because _beating is None until the thread has started running.
        return bool(self.is_alive() and not self._pause and self._beating)

    def stop(self) -> None:
        """Stop the channel's event loop and join its thread.

        Safe to call on a channel that was never started: the join is
        skipped in that case and the socket (if any) is still closed.
        """
        self._running = False
        self._exit.set()
        # Thread.join() raises RuntimeError on a thread that was never
        # started; a thread that already finished has nothing to wait for.
        if self.is_alive():
            self.join()
        self.close()

    def close(self) -> None:
        """Close the heartbeat socket, if one is open."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                # best-effort cleanup: a failure to close must not propagate
                pass
            self.socket = None

    def call_handlers(self, since_last_heartbeat: float) -> None:
        """Hook invoked when a heartbeat reply was not received in time.

        It is called from the heartbeat thread, so subclasses overriding it
        must take care that application level handlers end up being run in
        the application thread.

        Parameters
        ----------
        since_last_heartbeat : float
            Seconds elapsed since the unanswered ping was sent.
        """
        pass

195 

196 

# Declare HBChannel to HBChannelABC via its register() hook
# (presumably an abc virtual-subclass registration; see .channelsabc).
HBChannelABC.register(HBChannel)

198 

199 

class ZMQSocketChannel:
    """A ZMQ socket wrapper"""

    def __init__(self, socket: zmq.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations
        """
        super().__init__()

        self.socket: zmq.Socket | None = socket
        self.session = session

    def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """Receive one multipart message and deserialize it with the session."""
        assert self.socket is not None
        msg = self.socket.recv_multipart(**kwargs)
        _ident, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    def get_msg(self, timeout: float | None = None) -> t.Dict[str, t.Any]:
        """Gets a message if there is one that is ready.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for a message. ``None`` is passed through to the
            socket poll unchanged (pyzmq treats that as "wait indefinitely").

        Raises
        ------
        queue.Empty
            If the poll reports that nothing is ready.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        if not self.socket.poll(timeout_ms):
            raise Empty
        return self._recv()

    def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:
        """Get all messages that are currently ready."""
        msgs = []
        while True:
            try:
                # timeout=0 polls without waiting; the default (None) would
                # block on the socket as soon as the queue is drained.
                msgs.append(self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    def msg_ready(self) -> bool:
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(self.socket.poll(timeout=0))

    def close(self) -> None:
        """Close the socket channel."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                # best-effort cleanup: a failure to close must not propagate
                pass
            self.socket = None

    stop = close

    def is_alive(self) -> bool:
        """Test whether the channel is alive."""
        return self.socket is not None

    def send(self, msg: t.Dict[str, t.Any]) -> None:
        """Pass a message to the ZMQ socket to send"""
        assert self.socket is not None
        self.session.send(self.socket, msg)

    def start(self) -> None:
        """Start the socket channel."""
        pass

275 

276 

class AsyncZMQSocketChannel(ZMQSocketChannel):
    """A ZMQ socket in an async API"""

    socket: zmq.asyncio.Socket

    def __init__(self, socket: zmq.asyncio.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.asyncio.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations

        Raises
        ------
        ValueError
            If ``socket`` is not a :class:`zmq.asyncio.Socket`.
        """
        if not isinstance(socket, zmq.asyncio.Socket):
            msg = "Socket must be asyncio"  # type:ignore[unreachable]
            raise ValueError(msg)
        super().__init__(socket, session)

    async def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:  # type:ignore[override]
        """Await one multipart message and deserialize it with the session."""
        assert self.socket is not None
        msg = await self.socket.recv_multipart(**kwargs)
        _, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    async def get_msg(  # type:ignore[override]
        self, timeout: float | None = None
    ) -> t.Dict[str, t.Any]:
        """Gets a message if there is one that is ready.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for a message. ``None`` is passed through to the
            socket poll unchanged (pyzmq treats that as "wait indefinitely").

        Raises
        ------
        queue.Empty
            If the poll reports that nothing is ready.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        if not await self.socket.poll(timeout_ms):
            raise Empty
        return await self._recv()

    async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:  # type:ignore[override]
        """Get all messages that are currently ready."""
        msgs = []
        while True:
            try:
                # timeout=0 polls without waiting; the default (None) would
                # never resolve once the queue is drained.
                msgs.append(await self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    async def msg_ready(self) -> bool:  # type:ignore[override]
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(await self.socket.poll(timeout=0))