1"""Kernel connection helpers."""
2
3import json
4import struct
5from typing import Any
6
7from jupyter_client.session import Session
8from tornado.websocket import WebSocketHandler
9from traitlets import Float, Instance, Unicode, default
10from traitlets.config import LoggingConfigurable
11
12try:
13 from jupyter_client.jsonutil import json_default
14except ImportError:
15 from jupyter_client.jsonutil import date_default as json_default
16
17from jupyter_client.jsonutil import extract_dates
18
19from jupyter_server.transutils import _i18n
20
21from .abc import KernelWebsocketConnectionABC
22
23
def serialize_binary_message(msg):
    """Serialize a message as a binary blob.

    Header (all integers are big-endian unsigned 32-bit):

        4 bytes: number of msg parts (nbufs) -- the JSON-encoded message
                 itself plus each entry of ``msg["buffers"]``
        4 * nbufs bytes: byte offset at which each part starts

    The parts follow the header: the JSON message first, then the buffers in
    order. Offsets are from the start of the blob, including the header.

    Parameters
    ----------
    msg : dict
        The message. It must contain a ``"buffers"`` key holding an iterable
        of bytes-like objects; every other key is JSON-encoded.

    Returns
    -------
    bytes
        The message serialized to bytes.

    Raises
    ------
    KeyError
        If ``msg`` has no ``"buffers"`` key.
    """
    # Don't modify msg or its buffer list in place.
    msg = msg.copy()
    buffers = list(msg.pop("buffers"))
    bmsg = json.dumps(msg, default=json_default).encode("utf8")
    parts = [bmsg, *buffers]
    nbufs = len(parts)
    # The first part starts right after the header: one count word plus one
    # offset word per part.
    offsets = [4 * (nbufs + 1)]
    for part in parts[:-1]:
        # Offsets must be in bytes. len() of a memoryview counts *items*, which
        # differs from its byte size for multi-byte formats (e.g. float32).
        offsets.append(offsets[-1] + memoryview(part).nbytes)
    header = struct.pack("!" + "I" * (nbufs + 1), nbufs, *offsets)
    return b"".join([header, *parts])
51
52
def deserialize_binary_message(bmsg):
    """Deserialize a message from a binary blob.

    This is the inverse of ``serialize_binary_message``.

    Header (all integers are big-endian unsigned 32-bit):

        4 bytes: number of msg parts (nbufs)
        4 * nbufs bytes: byte offset at which each part starts

    Offsets are from the start of the blob, including the header. The first
    part is the UTF-8 JSON message; the remaining parts are binary buffers.

    Parameters
    ----------
    bmsg : bytes
        The serialized message.

    Returns
    -------
    dict
        The message dictionary, with ``header`` and ``parent_header`` passed
        through ``extract_dates`` and the binary parts stored under
        ``"buffers"``.
    """
    # Read the count as unsigned, matching the "!I" format the serializer writes.
    nbufs = struct.unpack("!I", bmsg[:4])[0]
    offsets = list(struct.unpack("!" + "I" * nbufs, bmsg[4 : 4 * (nbufs + 1)]))
    # A trailing None makes the last slice run to the end of the blob.
    offsets.append(None)
    bufs = [bmsg[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]
    msg = json.loads(bufs[0].decode("utf8"))
    msg["header"] = extract_dates(msg["header"])
    msg["parent_header"] = extract_dates(msg["parent_header"])
    msg["buffers"] = bufs[1:]
    return msg
78
79
def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None):
    """Serialize a message using the v1 protocol.

    Layout (all integers are 8-byte little-endian):

        8 bytes: number of offsets that follow (len(msg_list) + 2)
        8 bytes per offset: start of the channel name, start of each message
                            part, and end of the last part
        the UTF-8 channel name, then each message part in order

    Offsets are from the start of the serialized message, including the header.

    Parameters
    ----------
    msg_or_list : dict or list of bytes-like
        If ``pack`` is given, a message dict whose ``header``,
        ``parent_header``, ``metadata`` and ``content`` entries are packed in
        that order (other keys are not serialized). Otherwise, an
        already-serialized list of bytes-like parts.
    channel : str
        Name of the channel the message belongs to.
    pack : callable, optional
        Function turning each message section into bytes.

    Returns
    -------
    bytes
        The serialized websocket message.
    """
    if pack:
        msg_list = [
            pack(msg_or_list[key])
            for key in ("header", "parent_header", "metadata", "content")
        ]
    else:
        msg_list = msg_or_list
    channel_bytes = channel.encode("utf-8")
    # Header size: the offset-count word plus one word for each offset
    # (channel start, one start per part, and the end of the last part).
    offsets = [8 * (1 + 1 + len(msg_list) + 1)]
    offsets.append(offsets[-1] + len(channel_bytes))
    for part in msg_list:
        # Offsets must be in bytes. len() of a memoryview counts *items*, which
        # differs from its byte size for multi-byte formats.
        offsets.append(offsets[-1] + memoryview(part).nbytes)
    header = [len(offsets).to_bytes(8, byteorder="little")]
    header.extend(offset.to_bytes(8, byteorder="little") for offset in offsets)
    return b"".join([*header, channel_bytes, *msg_list])
101
102
def deserialize_msg_from_ws_v1(ws_msg):
    """Deserialize a message using the v1 protocol.

    The first 8-byte little-endian word gives the number of offsets. That many
    8-byte little-endian offsets follow, delimiting the channel name and then
    each message part.

    Returns
    -------
    tuple
        ``(channel, msg_list)``: the UTF-8 decoded channel name and the raw
        message parts as slices of ``ws_msg``.
    """
    count = int.from_bytes(ws_msg[:8], "little")
    bounds = []
    for index in range(1, count + 1):
        start = 8 * index
        bounds.append(int.from_bytes(ws_msg[start : start + 8], "little"))
    channel = ws_msg[bounds[0] : bounds[1]].decode("utf-8")
    # Consecutive boundaries after the channel delimit the message parts.
    parts = [ws_msg[lo:hi] for lo, hi in zip(bounds[1:-1], bounds[2:])]
    return channel, parts
112
113
class BaseKernelWebsocketConnection(LoggingConfigurable):
    """A configurable base class for connecting Kernel WebSockets to ZMQ sockets.

    Subclasses are expected to override ``connect``, ``disconnect``,
    ``handle_incoming_message`` and ``handle_outgoing_message``. The base
    implementations only raise ``NotImplementedError``.

    This object's ``parent`` is treated as the kernel manager, and the kernel
    manager's ``parent`` is treated as the multi kernel manager.
    """

    # Configurable choice of websocket message protocol; None lets the
    # front-end's capabilities decide (see the help text).
    kernel_ws_protocol = Unicode(
        None,
        allow_none=True,
        config=True,
        help=_i18n(
            "Preferred kernel message protocol over websocket to use (default: None). "
            "If an empty string is passed, select the legacy protocol. If None, "
            "the selected protocol will depend on what the front-end supports "
            "(usually the most recent protocol supported by the back-end and the "
            "front-end)."
        ),
    )

    @property
    def kernel_manager(self):
        """The kernel manager (this object's ``parent``)."""
        return self.parent

    @property
    def multi_kernel_manager(self):
        """The multi kernel manager (the kernel manager's ``parent``)."""
        return self.kernel_manager.parent

    @property
    def kernel_id(self):
        """The kernel id, as reported by the kernel manager."""
        return self.kernel_manager.kernel_id

    @property
    def session_id(self):
        """The session id (the ``session`` attribute of ``self.session``)."""
        return self.session.session

    # Unless set explicitly, defaults to the multi kernel manager's value
    # (see _default_kernel_info_timeout).
    kernel_info_timeout = Float()

    @default("kernel_info_timeout")
    def _default_kernel_info_timeout(self):
        """Use the multi kernel manager's ``kernel_info_timeout``."""
        return self.multi_kernel_manager.kernel_info_timeout

    # Configurable Session; unless set, built from this object's config
    # (see _default_session).
    session = Instance(klass=Session, config=True)

    @default("session")
    def _default_session(self):
        """Create a ``Session`` from this object's config."""
        return Session(config=self.config)

    # The tornado WebSocketHandler associated with this connection; no default
    # is provided here, so it is presumably assigned by whoever creates the
    # connection (TODO confirm against callers).
    websocket_handler = Instance(WebSocketHandler)

    async def connect(self):
        """Handle a connect. Must be implemented by subclasses."""
        raise NotImplementedError

    async def disconnect(self):
        """Handle a disconnect. Must be implemented by subclasses."""
        raise NotImplementedError

    def handle_incoming_message(self, incoming_msg: str) -> None:
        """Handle an incoming message. Must be implemented by subclasses."""
        raise NotImplementedError

    def handle_outgoing_message(self, stream: str, outgoing_msg: list[Any]) -> None:
        """Handle an outgoing message. Must be implemented by subclasses."""
        raise NotImplementedError
179
180
# Register the base class as a virtual subclass of the connection ABC, so that
# isinstance()/issubclass() checks against KernelWebsocketConnectionABC accept
# it without inheriting from that ABC. This assumes the ABC uses abc.ABCMeta,
# which provides .register().
KernelWebsocketConnectionABC.register(BaseKernelWebsocketConnection)