1"""Session object for building, serializing, sending, and receiving messages.
2
3The Session object supports serialization, HMAC signatures,
4and metadata on messages.
5
6Also defined here are utilities for working with Sessions:
7* A SessionFactory to be used as a base class for configurables that work with
8Sessions.
9* A Message object for convenience that allows attribute-access to the msg dict.
10"""
11
12# Copyright (c) Jupyter Development Team.
13# Distributed under the terms of the Modified BSD License.
14from __future__ import annotations
15
16import hashlib
17import hmac
18import json
19import logging
20import os
21import pickle
22import pprint
23import random
24import typing as t
25import warnings
26from binascii import b2a_hex
27from datetime import datetime, timezone
28from hmac import compare_digest
29
30# We are using compare_digest to limit the surface of timing attacks
31import zmq.asyncio
32from tornado.ioloop import IOLoop
33from traitlets import (
34 Any,
35 Bool,
36 CBytes,
37 CUnicode,
38 Dict,
39 DottedObjectName,
40 Instance,
41 Integer,
42 Set,
43 TraitError,
44 Unicode,
45 observe,
46)
47from traitlets.config.configurable import Configurable, LoggingConfigurable
48from traitlets.log import get_logger
49from traitlets.utils.importstring import import_item
50from zmq.eventloop.zmqstream import ZMQStream
51
52from ._version import protocol_version
53from .adapter import adapt
54from .jsonutil import extract_dates, json_clean, json_default, squash_dates
55
56PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
57
58utc = timezone.utc
59
60# -----------------------------------------------------------------------------
61# utility functions
62# -----------------------------------------------------------------------------
63
64
65def squash_unicode(obj: t.Any) -> t.Any:
66 """coerce unicode back to bytestrings."""
67 if isinstance(obj, dict):
68 for key in list(obj.keys()):
69 obj[key] = squash_unicode(obj[key])
70 if isinstance(key, str):
71 obj[squash_unicode(key)] = obj.pop(key)
72 elif isinstance(obj, list):
73 for i, v in enumerate(obj):
74 obj[i] = squash_unicode(v)
75 elif isinstance(obj, str):
76 obj = obj.encode("utf8")
77 return obj
78
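# Illustrative sketch of the coercion above (comment only, not executed):
#
#   squash_unicode({"a": ["hi"]})   # -> {b"a": [b"hi"]}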

# -----------------------------------------------------------------------------
# globals and defaults
# -----------------------------------------------------------------------------

# default values for the thresholds:
MAX_ITEMS = 64
MAX_BYTES = 1024

# ISO8601-ify datetime objects
# allow unicode
# disallow nan, because it's not actually valid JSON


def json_packer(obj: t.Any) -> bytes:
    """Convert a JSON object to bytes."""
    try:
        return json.dumps(
            obj,
            default=json_default,
            ensure_ascii=False,
            allow_nan=False,
        ).encode("utf8", errors="surrogateescape")
    except (TypeError, ValueError) as e:
        # Fallback to trying to clean the json before serializing
        packed = json.dumps(
            json_clean(obj),
            default=json_default,
            ensure_ascii=False,
            allow_nan=False,
        ).encode("utf8", errors="surrogateescape")

        warnings.warn(
            f"Message serialization failed with:\n{e}\n"
            "Supporting this message is deprecated in jupyter-client 7, please make "
            "sure your message is JSON-compliant",
            stacklevel=2,
        )

        return packed


def json_unpacker(s: str | bytes) -> t.Any:
    """Convert JSON bytes or a string to an object."""
    if isinstance(s, bytes):
        s = s.decode("utf8", "replace")
    return json.loads(s)

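# Illustrative sketch (comment only): the default JSON packer/unpacker
# round-trip plain message dicts, e.g.
#
#   packed = json_packer({"a": [1, "hi"]})      # -> b'{"a": [1, "hi"]}'
#   assert json_unpacker(packed) == {"a": [1, "hi"]}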

def pickle_packer(o: t.Any) -> bytes:
    """Pack an object using the pickle module."""
    return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)


pickle_unpacker = pickle.loads

default_packer = json_packer
default_unpacker = json_unpacker

DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()

# -----------------------------------------------------------------------------
# Mixin tools for apps that use Sessions
# -----------------------------------------------------------------------------


def new_id() -> str:
    """Generate a new random id.

    Avoids problematic runtime import in stdlib uuid on Python 2.

    Returns
    -------

    id string (16 random bytes as hex-encoded text, chunks separated by '-')
    """
    buf = os.urandom(16)
    return "-".join(b2a_hex(x).decode("ascii") for x in (buf[:4], buf[4:]))


def new_id_bytes() -> bytes:
    """Return new_id as ascii bytes"""
    return new_id().encode("ascii")


session_aliases = {
    "ident": "Session.session",
    "user": "Session.username",
    "keyfile": "Session.keyfile",
}

session_flags = {
    "secure": (
        {"Session": {"key": new_id_bytes(), "keyfile": ""}},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """,
    ),
    "no-secure": (
        {"Session": {"key": b"", "keyfile": ""}},
        """Don't authenticate messages.""",
    ),
}

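# Sketch (hypothetical application class): these dicts are meant to be merged
# into a traitlets Application's ``aliases``/``flags`` so that, e.g.,
# ``--keyfile=...`` configures ``Session.keyfile``:
#
#   class MyKernelApp(Application):
#       aliases = {**session_aliases}
#       flags = {**session_flags}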

def default_secure(cfg: t.Any) -> None:  # pragma: no cover
    """Set the default behavior for a config environment to be secure.

    If Session.key/keyfile have not been set, set Session.key to
    a new random UUID.
    """
    warnings.warn("default_secure is deprecated", DeprecationWarning, stacklevel=2)
    if "Session" in cfg and ("key" in cfg.Session or "keyfile" in cfg.Session):
        return
    # key/keyfile not specified, generate new UUID:
    cfg.Session.key = new_id_bytes()


def utcnow() -> datetime:
    """Return timezone-aware UTC timestamp"""
    return datetime.now(utc)


# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------


class SessionFactory(LoggingConfigurable):
    """The base class for configurables that have a Session, Context, logger,
    and IOLoop.
    """

    logname = Unicode("")

    @observe("logname")
    def _logname_changed(self, change: t.Any) -> None:
        self.log = logging.getLogger(change["new"])

    # not configurable:
    context = Instance("zmq.Context")

    def _context_default(self) -> zmq.Context:
        return zmq.Context()

    session = Instance("jupyter_client.session.Session", allow_none=True)

    loop = Instance("tornado.ioloop.IOLoop")

    def _loop_default(self) -> IOLoop:
        return IOLoop.current()

    def __init__(self, **kwargs: t.Any) -> None:
        """Initialize a session factory."""
        super().__init__(**kwargs)

        if self.session is None:
            # construct the session
            self.session = Session(**kwargs)


class Message:
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj)."""

    def __init__(self, msg_dict: dict[str, t.Any]) -> None:
        """Initialize a message."""
        dct = self.__dict__
        for k, v in dict(msg_dict).items():
            if isinstance(v, dict):
                v = Message(v)  # noqa
            dct[k] = v

    # Having this iterator lets dict(msg_obj) work out of the box.
    def __iter__(self) -> t.ItemsView[str, t.Any]:
        return iter(self.__dict__.items())  # type:ignore[return-value]

    def __repr__(self) -> str:
        return repr(self.__dict__)

    def __str__(self) -> str:
        return pprint.pformat(self.__dict__)

    def __contains__(self, k: object) -> bool:
        return k in self.__dict__

    def __getitem__(self, k: str) -> t.Any:
        return self.__dict__[k]

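# Example (sketch): attribute access over a nested message dict:
#
#   m = Message({"header": {"msg_type": "status"}, "content": {}})
#   m.header.msg_type        # -> "status"
#   "content" in m           # -> True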

def msg_header(
    msg_id: str, msg_type: str, username: str, session: Session | str
) -> dict[str, t.Any]:
    """Create a new message header"""
    date = utcnow()
    version = protocol_version
    return locals()

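# The resulting header is a flat dict built from the function's locals,
# roughly (sketch; actual values vary):
#
#   {"msg_id": msg_id, "msg_type": msg_type, "username": username,
#    "session": session, "date": utcnow(), "version": protocol_version}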

def extract_header(msg_or_header: dict[str, t.Any]) -> dict[str, t.Any]:
    """Given a message or header, return the header."""
    if not msg_or_header:
        return {}
    try:
        # See if msg_or_header is the entire message.
        h = msg_or_header["header"]
    except KeyError:
        try:
            # See if msg_or_header is just the header
            h = msg_or_header["msg_id"]
        except KeyError:
            raise
        else:
            h = msg_or_header
    if not isinstance(h, dict):
        h = dict(h)
    return h


class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects. Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts. If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : unicode (must be ascii)
        the ID of this Session object. The default is to generate a new UUID.
    username : unicode
        username added to message headers. The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature. If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key. If this is set, `key` will be initialized
        to the contents of the file.

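    Examples
    --------

    A minimal sketch (illustrative values only)::

        session = Session(key=b"my-shared-secret")
        msg = session.msg("execute_request", content={"code": "1+1"})
        wire_parts = session.serialize(msg)
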
    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    check_pid = Bool(
        True,
        config=True,
        help="""Whether to check PID to protect against calls after fork.

        This check can be disabled if fork-safety is handled elsewhere.
        """,
    )

    packer = DottedObjectName(
        "json",
        config=True,
        help="""The name of the packer for serializing messages.
        Should be one of 'json', 'pickle', or an import name
        for a custom callable serializer.""",
    )

    @observe("packer")
    def _packer_changed(self, change: t.Any) -> None:
        new = change["new"]
        if new.lower() == "json":
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == "pickle":
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName(
        "json",
        config=True,
        help="""The name of the unpacker for deserializing messages.
        Only used with custom functions for `packer`.""",
    )

    @observe("unpacker")
    def _unpacker_changed(self, change: t.Any) -> None:
        new = change["new"]
        if new.lower() == "json":
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == "pickle":
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode("", config=True, help="""The UUID identifying this session.""")

    def _session_default(self) -> str:
        u = new_id()
        self.bsession = u.encode("ascii")
        return u

    @observe("session")
    def _session_changed(self, change: t.Any) -> None:
        self.bsession = self.session.encode("ascii")

    # bsession is the session as bytes
    bsession = CBytes(b"")

    username = Unicode(
        os.environ.get("USER", "username"),
        help="""Username for the Session. Default is your system username.""",
        config=True,
    )

    metadata = Dict(
        {},
        config=True,
        help="Metadata dictionary, which serves as the default top-level metadata dict for each "
        "message.",
    )

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(config=True, help="""execution key, for signing messages.""")

    def _key_default(self) -> bytes:
        return new_id_bytes()

    @observe("key")
    def _key_changed(self, change: t.Any) -> None:
        self._new_auth()

    signature_scheme = Unicode(
        "hmac-sha256",
        config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""",
    )

    @observe("signature_scheme")
    def _signature_scheme_changed(self, change: t.Any) -> None:
        new = change["new"]
        if not new.startswith("hmac-"):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split("-", 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError as e:
            raise TraitError("hashlib has no such attribute: %s" % hash_name) from e
        self._new_auth()

    digest_mod = Any()

    def _digest_mod_default(self) -> t.Callable:
        return hashlib.sha256

    auth = Instance(hmac.HMAC, allow_none=True)

    def _new_auth(self) -> None:
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

    digest_history = Set()
    digest_history_size = Integer(
        2**16,
        config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """,
    )

    keyfile = Unicode("", config=True, help="""path to file containing execution key.""")

    @observe("keyfile")
    def _keyfile_changed(self, change: t.Any) -> None:
        with open(change["new"], "rb") as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer)  # the actual packer function

    @observe("pack")
    def _pack_changed(self, change: t.Any) -> None:
        new = change["new"]
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    @observe("unpack")
    def _unpack_changed(self, change: t.Any) -> None:
        # the unpacker is not validated against the packer; it only needs to be callable
        new = change["new"]
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    # thresholds:
    copy_threshold = Integer(
        2**16,
        config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.",
    )
    buffer_threshold = Integer(
        MAX_BYTES,
        config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid "
        "pickling.",
    )
    item_threshold = Integer(
        MAX_ITEMS,
        config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """,
    )

    def __init__(self, **kwargs: t.Any) -> None:
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts. If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object. The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers. The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature. If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key. If this is set, `key` will be
            initialized to the contents of the file.
        """
        super().__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session  # noqa
        self.pid = os.getpid()
        self._new_auth()
        if not self.key:
            get_logger().warning(
                "Message signing is disabled. This is insecure and not recommended!"
            )

    def clone(self) -> Session:
        """Create a copy of this Session

        Useful when connecting multiple times to a given kernel.
        This prevents a shared digest_history warning about duplicate digests
        due to multiple connections to IOPub in the same process.

        .. versionadded:: 5.1
        """
        # make a copy
        new_session = type(self)()
        for name in self.traits():
            setattr(new_session, name, getattr(self, name))
        # fork digest_history
        new_session.digest_history = set()
        new_session.digest_history.update(self.digest_history)
        return new_session

    message_count = 0

    @property
    def msg_id(self) -> str:
        message_number = self.message_count
        self.message_count += 1
        return f"{self.session}_{os.getpid()}_{message_number}"

    def _check_packers(self) -> None:
        """check packers for datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg_list = {"a": [1, "hi"]}
        try:
            packed = pack(msg_list)
        except Exception as e:
            msg = f"packer '{self.packer}' could not serialize a simple message: {e}"
            raise ValueError(msg) from e

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" % type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            assert unpacked == msg_list
        except Exception as e:
            msg = (
                f"unpacker '{self.unpacker}' could not handle output from packer"
                f" '{self.packer}': {e}"
            )
            raise ValueError(msg) from e

        # check datetime support
        msg_datetime = {"t": utcnow()}
        try:
            unpacked = unpack(pack(msg_datetime))
            if isinstance(unpacked["t"], datetime):
                msg = "Shouldn't deserialize to datetime"
                raise ValueError(msg)
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type: str) -> dict[str, t.Any]:
        """Create a header for a message type."""
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(
        self,
        msg_type: str,
        content: dict | None = None,
        parent: dict[str, t.Any] | None = None,
        header: dict[str, t.Any] | None = None,
        metadata: dict[str, t.Any] | None = None,
    ) -> dict[str, t.Any]:
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/deserialize methods convert this nested message dict to the wire
        format, which is a list of message parts.
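
        A sketch of the result::

            {"header": {...}, "msg_id": "...", "msg_type": msg_type,
             "parent_header": {}, "content": {}, "metadata": {}}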
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg["header"] = header
        msg["msg_id"] = header["msg_id"]
        msg["msg_type"] = header["msg_type"]
        msg["parent_header"] = {} if parent is None else extract_header(parent)
        msg["content"] = {} if content is None else content
        msg["metadata"] = self.metadata.copy()
        if metadata is not None:
            msg["metadata"].update(metadata)
        return msg

    def sign(self, msg_list: list) -> bytes:
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header, p_parent, p_metadata, p_content] part of the message list.
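
        A manual verification sketch (assuming the default hmac-sha256 scheme
        and a non-empty ``key``)::

            mac = hmac.HMAC(session.key, digestmod=hashlib.sha256)
            for frame in msg_list:
                mac.update(frame)
            assert mac.hexdigest().encode() == session.sign(msg_list)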
        """
        if self.auth is None:
            return b""
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return h.hexdigest().encode()

    def serialize(
        self,
        msg: dict[str, t.Any],
        ident: list[bytes] | bytes | None = None,
    ) -> list[bytes]:
        """Serialize the message components to bytes.

        This is roughly the inverse of deserialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
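
        A round-trip sketch (no idents, no buffers)::

            msg = session.msg("status", content={"execution_state": "idle"})
            parts = session.serialize(msg)
            idents, frames = session.feed_identities(parts)
            assert session.deserialize(frames)["content"] == {"execution_state": "idle"}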
        """
        content = msg.get("content", {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, str):
            # should be bytes, but JSON often spits out unicode
            content = content.encode("utf8")
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [
            self.pack(msg["header"]),
            self.pack(msg["parent_header"]),
            self.pack(msg["metadata"]),
            content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(
        self,
        stream: zmq.sugar.socket.Socket | ZMQStream | None,
        msg_or_type: dict[str, t.Any] | str,
        content: dict[str, t.Any] | None = None,
        parent: dict[str, t.Any] | None = None,
        ident: bytes | list[bytes] | None = None,
        buffers: list[bytes | memoryview[bytes]] | None = None,
        track: bool = False,
        header: dict[str, t.Any] | None = None,
        metadata: dict[str, t.Any] | None = None,
    ) -> dict[str, t.Any] | None:
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

            [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent, p_metadata,
             p_content, buffer1, buffer2, ...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track. Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message.
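
        A usage sketch (``shell_socket`` is a hypothetical connected socket)::

            msg = session.send(shell_socket, "execute_request", content={"code": "1+1"})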
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(stream, zmq.asyncio.Socket):
            assert stream is not None
            stream = zmq.Socket.shadow(stream.underlying)

        if isinstance(msg_or_type, Message | dict):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get("buffers", [])
        else:
            msg = self.msg(
                msg_or_type,
                content=content,
                parent=parent,
                header=header,
                metadata=metadata,
            )
        if self.check_pid and os.getpid() != self.pid:
            get_logger().warning("WARNING: attempted to send message from fork\n%s", msg)
            return None
        buffers = [] if buffers is None else buffers
        for idx, buf in enumerate(buffers):
            if isinstance(buf, memoryview):
                view = buf
            else:
                try:
                    # check to see if buf supports the buffer protocol.
                    view = memoryview(buf)
                except TypeError as e:
                    emsg = "Buffer objects must support the buffer protocol."
                    raise TypeError(emsg) from e
            if not view.contiguous:
                # zmq requires memoryviews to be contiguous
                raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))

        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)  # type: ignore[arg-type]
        longest = max([len(s) for s in to_send])
        copy = longest < self.copy_threshold

        if stream and buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        elif stream:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)
        else:
            tracker = DONE

        if self.debug:
            pprint.pprint(msg)  # noqa
            pprint.pprint(to_send)  # noqa
            pprint.pprint(buffers)  # noqa

        msg["tracker"] = tracker

        return msg

    def send_raw(
        self,
        stream: zmq.sugar.socket.Socket,
        msg_list: list,
        flags: int = 0,
        copy: bool = True,
        ident: bytes | list[bytes] | None = None,
    ) -> None:
        """Send a raw message via ident path.

        This method is used to send an already-serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        # Don't include buffers in signature (per spec).
        to_send.append(self.sign(msg_list[0:4]))
        to_send.extend(msg_list)
        if isinstance(stream, zmq.asyncio.Socket):
            stream = zmq.Socket.shadow(stream.underlying)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(
        self,
        socket: zmq.sugar.socket.Socket,
        mode: int = zmq.NOBLOCK,
        content: bool = True,
        copy: bool = True,
    ) -> tuple[list[bytes] | None, dict[str, t.Any] | None]:
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
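
        A usage sketch (blocking receive on a hypothetical connected ``socket``)::

            idents, msg = session.recv(socket, mode=0)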
        """
        if isinstance(socket, ZMQStream):  # type:ignore[unreachable]
            socket = socket.socket  # type:ignore[unreachable]
        if isinstance(socket, zmq.asyncio.Socket):
            socket = zmq.Socket.shadow(socket.underlying)

        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.deserialize(msg_list, content=content, copy=copy)
        except Exception as e:
            # TODO: handle it
            raise e

    def feed_identities(
        self, msg_list: list[bytes] | list[zmq.Message], copy: bool = True
    ) -> tuple[list[bytes], list[bytes] | list[zmq.Message]]:
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC, p_header, p_parent, p_metadata, p_content, buffer1, buffer2, ...]
            and should be unpackable/deserializable via self.deserialize at this
            point.
        """
        if copy:
            msg_list = t.cast(t.List[bytes], msg_list)
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx + 1 :]
        else:
            msg_list = t.cast(t.List[zmq.Message], msg_list)
            failed = True
            for idx, m in enumerate(msg_list):  # noqa
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                msg = "DELIM not in msg_list"
                raise ValueError(msg)
            idents, msg_list = msg_list[:idx], msg_list[idx + 1 :]
            return [bytes(m.bytes) for m in idents], msg_list

    def _add_digest(self, signature: bytes) -> None:
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self) -> None:
        """cull the digest history

        Removes a randomly selected 10% of the digest history
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        to_cull = random.sample(tuple(sorted(self.digest_history)), n_to_cull)
        self.digest_history.difference_update(to_cull)

    def deserialize(
        self,
        msg_list: list[bytes] | list[zmq.Message],
        content: bool = True,
        copy: bool = True,
    ) -> dict[str, t.Any]:
        """Deserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list contains bytes (True) or the non-copying Message
            objects in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers]. The buffers are returned as memoryviews.
        """
        minlen = 5
        message = {}
        if not copy:
            # pyzmq didn't copy the first parts of the message, so we'll do it
            msg_list = t.cast(t.List[zmq.Message], msg_list)
            msg_list_beginning = [bytes(msg.bytes) for msg in msg_list[:minlen]]
            msg_list = t.cast(t.List[bytes], msg_list)
            msg_list = msg_list_beginning + msg_list[minlen:]
        msg_list = t.cast(t.List[bytes], msg_list)
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                msg = "Unsigned Message"
                raise ValueError(msg)
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            if content:
                # Only store signature if we are unpacking content, don't store if just peeking.
                self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                msg = "Invalid Signature: %r" % signature
                raise ValueError(msg)
        if not len(msg_list) >= minlen:
            msg = "malformed message, must have at least %i elements" % minlen
            raise TypeError(msg)
        header = self.unpack(msg_list[1])
        message["header"] = extract_dates(header)
        message["msg_id"] = header["msg_id"]
        message["msg_type"] = header["msg_type"]
        message["parent_header"] = extract_dates(self.unpack(msg_list[2]))
        message["metadata"] = self.unpack(msg_list[3])
        if content:
            message["content"] = self.unpack(msg_list[4])
        else:
            message["content"] = msg_list[4]
        buffers = [memoryview(b) for b in msg_list[5:]]
        if buffers and buffers[0].shape is None:
            # force copy to workaround pyzmq #646
            msg_list = t.cast(t.List[zmq.Message], msg_list)
            buffers = [memoryview(bytes(b.bytes)) for b in msg_list[5:]]
        message["buffers"] = buffers
        if self.debug:
            pprint.pprint(message)  # noqa
        # adapt to the current version
        return adapt(message)

    def unserialize(self, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]:
        """**DEPRECATED** Use deserialize instead."""
        # pragma: no cover
        warnings.warn(
            "Session.unserialize is deprecated. Use Session.deserialize.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.deserialize(*args, **kwargs)