Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/jupyter_client/channels.py: 31%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

179 statements  

1"""Base classes to manage a Client's interaction with a running kernel""" 

2# Copyright (c) Jupyter Development Team. 

3# Distributed under the terms of the Modified BSD License. 

4import asyncio 

5import atexit 

6import time 

7import typing as t 

8from queue import Empty 

9from threading import Event, Thread 

10 

11import zmq.asyncio 

12from jupyter_core.utils import ensure_async 

13 

14from ._version import protocol_version_info 

15from .channelsabc import HBChannelABC 

16from .session import Session 

17 

# NOTE: historically ZMQError was imported into the top-level namespace here,
# to avoid ugly attribute-error messages during garbage collection of threads
# at exit; no such import is currently present in the imports above.

20 

21# ----------------------------------------------------------------------------- 

22# Constants and exceptions 

23# ----------------------------------------------------------------------------- 

24 

# First element of ``protocol_version_info`` (imported from ``._version``);
# presumably the major number of the messaging-protocol version — confirm
# against ``_version``.
major_protocol_version = protocol_version_info[0]

26 

27 

class InvalidPortNumber(Exception):  # noqa
    """Signal that a channel was configured with an unusable port number.

    ``HBChannel`` raises this when its ``(ip, port)`` address has port 0.
    """

32 

33 

class HBChannel(Thread):
    """The heartbeat channel which monitors the kernel heartbeat.

    The channel runs as a daemon thread. While not paused, it repeatedly sends
    ``b"ping"`` on a REQ socket, waits ``time_to_dead`` seconds, then checks
    whether a reply arrived. If none did, :meth:`call_handlers` is invoked with
    the elapsed time and the socket is recreated.

    Note that the channel is *not* paused when created (``__init__`` sets
    ``_pause = False``). Callers (e.g. a kernel manager) may pause and un-pause
    it as appropriate via :meth:`pause` / :meth:`unpause`.
    """

    session = None
    socket = None
    address = None
    # Set to True at interpreter exit by ``_notice_exit`` (registered with atexit).
    _exiting = False

    # Seconds to wait for a reply to each ping before declaring heart failure;
    # also the pause between consecutive pings.
    time_to_dead: float = 1.0
    _running = None
    _pause = None
    _beating = None

    def __init__(
        self,
        context: t.Optional[zmq.Context] = None,
        session: t.Optional[Session] = None,
        address: t.Union[t.Tuple[str, int], str] = "",
    ) -> None:
        """Create the heartbeat monitor thread.

        Parameters
        ----------
        context : :class:`zmq.Context`
            The ZMQ context to use.
        session : :class:`session.Session`
            The session to use.
        address : zmq url or (ip, port) tuple
            Either a ready-made url string, or a standard (ip, port) tuple that
            the kernel is listening on (converted to ``tcp://ip:port``).

        Raises
        ------
        InvalidPortNumber
            If ``address`` is a tuple whose port is 0.
        """
        super().__init__()
        self.daemon = True

        self.context = context
        self.session = session
        if isinstance(address, tuple):
            if address[1] == 0:
                message = "The port number for a channel cannot be 0."
                raise InvalidPortNumber(message)
            address_str = "tcp://%s:%i" % address
        else:
            address_str = address
        self.address = address_str

        # running is False until `.start()` is called
        self._running = False
        # Set by stop() to interrupt the timed waits in the run loop.
        self._exit = Event()
        # don't start paused
        self._pause = False
        self.poller = zmq.Poller()

    @staticmethod
    @atexit.register
    def _notice_exit() -> None:
        # Class definitions can be torn down during interpreter shutdown.
        # We only need to set _exiting flag if this hasn't happened.
        if HBChannel is not None:
            HBChannel._exiting = True

    def _create_socket(self) -> None:
        """(Re)create the REQ socket, connect it and register it with the poller."""
        if self.socket is not None:
            # close previous socket, before opening a new one
            self.poller.unregister(self.socket)  # type:ignore[unreachable]
            self.socket.close()
        assert self.context is not None
        self.socket = self.context.socket(zmq.REQ)
        self.socket.linger = 1000
        assert self.address is not None
        self.socket.connect(self.address)

        self.poller.register(self.socket, zmq.POLLIN)

    async def _async_run(self) -> None:
        """The thread's main activity. Call start() instead."""
        self._create_socket()
        self._running = True
        self._beating = True
        assert self.socket is not None

        while self._running:
            if self._pause:
                # Just sleep (stop() can interrupt via _exit) and skip the rest.
                self._exit.wait(self.time_to_dead)
                continue

            # No need to catch EFSM here, because the previous event was
            # either a recv or connect, which cannot be followed by EFSM.
            await ensure_async(self.socket.send(b"ping"))
            # Use a monotonic clock: this measures an interval and must not be
            # skewed by wall-clock adjustments (NTP, manual changes).
            request_time = time.monotonic()
            # Wait until timeout (stop() sets _exit to wake us early).
            self._exit.wait(self.time_to_dead)
            # poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
            self._beating = bool(self.poller.poll(0))
            if self._beating:
                # the poll above guarantees we have something to recv
                await ensure_async(self.socket.recv())
            elif self._running:
                # Nothing was received within the time limit: signal heart failure.
                since_last_heartbeat = time.monotonic() - request_time
                self.call_handlers(since_last_heartbeat)
                # Close/reopen the socket, because the REQ/REP cycle has been broken.
                self._create_socket()

    def run(self) -> None:
        """Run the heartbeat thread on a fresh, thread-private event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._async_run())
        finally:
            # Close the loop even if the heartbeat loop raised, so its
            # resources are not leaked.
            loop.close()

    def pause(self) -> None:
        """Pause the heartbeat."""
        self._pause = True

    def unpause(self) -> None:
        """Unpause the heartbeat."""
        self._pause = False

    def is_beating(self) -> bool:
        """Is the heartbeat running and responsive (and not paused)."""
        return bool(self.is_alive() and not self._pause and self._beating)

    def stop(self) -> None:
        """Stop the channel's event loop and join its thread."""
        self._running = False
        self._exit.set()
        self.join()
        self.close()

    def close(self) -> None:
        """Close the heartbeat socket (best-effort; errors are ignored)."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                # Deliberately best-effort: closing during teardown may fail.
                pass
            self.socket = None

    def call_handlers(self, since_last_heartbeat: float) -> None:
        """Handle a missed heartbeat.

        Called from the heartbeat thread when no reply to a ping arrived within
        ``time_to_dead`` seconds. ``since_last_heartbeat`` is the number of
        seconds elapsed since that ping was sent.

        Subclasses should override this method to react to heart failure. It
        is important to remember that this method is called in the heartbeat
        thread, so some logic must be done to ensure that application-level
        handlers are called in the application thread.
        """
        pass

192 

193 

# Declare HBChannel as a virtual subclass of HBChannelABC (from .channelsabc),
# so isinstance/issubclass checks against the ABC accept it without HBChannel
# inheriting from it. Assumes HBChannelABC is an ABCMeta class, as the use of
# ``.register`` suggests.
HBChannelABC.register(HBChannel)

195 

196 

class ZMQSocketChannel:
    """A ZMQ socket wrapper"""

    def __init__(self, socket: zmq.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations
        """
        super().__init__()

        # Set to None by close(); is_alive() reports on this.
        self.socket: t.Optional[zmq.Socket] = socket
        self.session = session

    def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """Receive one multipart message and deserialize it with the session.

        ``kwargs`` are forwarded to ``recv_multipart``. The identity prefix
        split off by ``feed_identities`` is discarded.
        """
        assert self.socket is not None
        msg = self.socket.recv_multipart(**kwargs)
        _, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    def get_msg(self, timeout: t.Optional[float] = None) -> t.Dict[str, t.Any]:
        """Get a message, waiting up to ``timeout`` seconds for one to be ready.

        Parameters
        ----------
        timeout : float or None
            Seconds to wait. ``0`` checks without waiting; ``None`` (the
            default) passes no timeout to ``poll``, which waits until a
            message is ready.

        Raises
        ------
        queue.Empty
            If no message became ready within ``timeout``.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        ready = self.socket.poll(timeout_ms)
        if ready:
            res = self._recv()
            return res
        else:
            raise Empty

    def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:
        """Get all messages that are currently ready, without blocking."""
        msgs = []
        while True:
            try:
                # timeout=0: only drain messages already queued. With the
                # default timeout=None, poll() would wait for the *next*
                # message, so this loop would never end on an idle socket.
                msgs.append(self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    def msg_ready(self) -> bool:
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(self.socket.poll(timeout=0))

    def close(self) -> None:
        """Close the socket channel (best-effort; errors are ignored)."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                # Deliberately best-effort: closing during teardown may fail.
                pass
            self.socket = None

    stop = close

    def is_alive(self) -> bool:
        """Test whether the channel is alive (i.e. its socket is not closed)."""
        return self.socket is not None

    def send(self, msg: t.Dict[str, t.Any]) -> None:
        """Pass a message to the ZMQ socket to send"""
        assert self.socket is not None
        self.session.send(self.socket, msg)

    def start(self) -> None:
        """Start the socket channel (no-op; the socket is usable immediately)."""
        pass

272 

273 

class AsyncZMQSocketChannel(ZMQSocketChannel):
    """A ZMQ socket in an async API"""

    socket: zmq.asyncio.Socket

    def __init__(self, socket: zmq.asyncio.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.asyncio.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations

        Raises
        ------
        ValueError
            If ``socket`` is not a ``zmq.asyncio.Socket``.
        """
        if not isinstance(socket, zmq.asyncio.Socket):
            msg = "Socket must be asyncio"  # type:ignore[unreachable]
            raise ValueError(msg)
        super().__init__(socket, session)

    async def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:  # type:ignore[override]
        """Await one multipart message and deserialize it with the session.

        ``kwargs`` are forwarded to ``recv_multipart``. The identity prefix
        split off by ``feed_identities`` is discarded.
        """
        assert self.socket is not None
        msg = await self.socket.recv_multipart(**kwargs)
        _, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    async def get_msg(  # type:ignore[override]
        self, timeout: t.Optional[float] = None
    ) -> t.Dict[str, t.Any]:
        """Get a message, waiting up to ``timeout`` seconds for one to be ready.

        Parameters
        ----------
        timeout : float or None
            Seconds to wait. ``0`` checks without waiting; ``None`` (the
            default) passes no timeout to ``poll``, which waits until a
            message is ready.

        Raises
        ------
        queue.Empty
            If no message became ready within ``timeout``.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        ready = await self.socket.poll(timeout_ms)
        if ready:
            res = await self._recv()
            return res
        else:
            raise Empty

    async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:  # type:ignore[override]
        """Get all messages that are currently ready, without waiting."""
        msgs = []
        while True:
            try:
                # timeout=0: only drain messages already queued. With the
                # default timeout=None, poll() would wait for the *next*
                # message, so this loop would never end on an idle socket.
                msgs.append(await self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    async def msg_ready(self) -> bool:  # type:ignore[override]
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(await self.socket.poll(timeout=0))