1"""Session object for building, serializing, sending, and receiving messages.
2
3The Session object supports serialization, HMAC signatures,
4and metadata on messages.
5
6Also defined here are utilities for working with Sessions:
* A SessionFactory to be used as a base class for configurables that work with
  Sessions.
9* A Message object for convenience that allows attribute-access to the msg dict.
10"""
11# Copyright (c) Jupyter Development Team.
12# Distributed under the terms of the Modified BSD License.
13from __future__ import annotations
14
15import hashlib
16import hmac
17import json
18import logging
19import os
20import pickle
21import pprint
22import random
23import typing as t
24import warnings
25from binascii import b2a_hex
26from datetime import datetime, timezone
# We are using compare_digest to limit the surface of timing attacks
from hmac import compare_digest

import zmq.asyncio
31from tornado.ioloop import IOLoop
32from traitlets import (
33 Any,
34 Bool,
35 CBytes,
36 CUnicode,
37 Dict,
38 DottedObjectName,
39 Instance,
40 Integer,
41 Set,
42 TraitError,
43 Unicode,
44 observe,
45)
46from traitlets.config.configurable import Configurable, LoggingConfigurable
47from traitlets.log import get_logger
48from traitlets.utils.importstring import import_item
49from zmq.eventloop.zmqstream import ZMQStream
50
51from ._version import protocol_version
52from .adapter import adapt
53from .jsonutil import extract_dates, json_clean, json_default, squash_dates
54
55PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
56
57utc = timezone.utc
58
59# -----------------------------------------------------------------------------
60# utility functions
61# -----------------------------------------------------------------------------
62
63
64def squash_unicode(obj: t.Any) -> t.Any:
65 """coerce unicode back to bytestrings."""
66 if isinstance(obj, dict):
67 for key in list(obj.keys()):
68 obj[key] = squash_unicode(obj[key])
69 if isinstance(key, str):
70 obj[squash_unicode(key)] = obj.pop(key)
71 elif isinstance(obj, list):
72 for i, v in enumerate(obj):
73 obj[i] = squash_unicode(v)
74 elif isinstance(obj, str):
75 obj = obj.encode("utf8")
76 return obj
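
# Illustrative sketch (added example, not part of the original module):
# squash_unicode recursively encodes str keys and values to utf-8 bytes, e.g.
# >>> squash_unicode({"a": ["hi", 1]})
# {b'a': [b'hi', 1]}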
77
78
79# -----------------------------------------------------------------------------
80# globals and defaults
81# -----------------------------------------------------------------------------
82
83# default values for the thresholds:
84MAX_ITEMS = 64
85MAX_BYTES = 1024
86
# Defaults for the JSON packer below:
# - ISO8601-ify datetime objects
# - allow unicode
# - disallow NaN, because it's not actually valid JSON
90
91
92def json_packer(obj: t.Any) -> bytes:
93 """Convert a json object to a bytes."""
94 try:
95 return json.dumps(
96 obj,
97 default=json_default,
98 ensure_ascii=False,
99 allow_nan=False,
100 ).encode("utf8", errors="surrogateescape")
101 except (TypeError, ValueError) as e:
102 # Fallback to trying to clean the json before serializing
103 packed = json.dumps(
104 json_clean(obj),
105 default=json_default,
106 ensure_ascii=False,
107 allow_nan=False,
108 ).encode("utf8", errors="surrogateescape")
109
110 warnings.warn(
111 f"Message serialization failed with:\n{e}\n"
112 "Supporting this message is deprecated in jupyter-client 7, please make "
113 "sure your message is JSON-compliant",
114 stacklevel=2,
115 )
116
117 return packed
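
# Illustrative sketch (added example, not part of the original module): the
# default packer emits utf-8 JSON bytes, and json_unpacker (below) inverts it.
# >>> json_packer({"a": 1})
# b'{"a": 1}'
# >>> json_unpacker(json_packer({"a": 1}))
# {'a': 1}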
118
119
120def json_unpacker(s: str | bytes) -> t.Any:
121 """Convert a json bytes or string to an object."""
122 if isinstance(s, bytes):
123 s = s.decode("utf8", "replace")
124 return json.loads(s)
125
126
127def pickle_packer(o: t.Any) -> bytes:
128 """Pack an object using the pickle module."""
129 return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
130
131
132pickle_unpacker = pickle.loads
133
134default_packer = json_packer
135default_unpacker = json_unpacker
136
137DELIM = b"<IDS|MSG>"
138# singleton dummy tracker, which will always report as done
139DONE = zmq.MessageTracker()
140
141# -----------------------------------------------------------------------------
142# Mixin tools for apps that use Sessions
143# -----------------------------------------------------------------------------
144
145
146def new_id() -> str:
147 """Generate a new random id.
148
149 Avoids problematic runtime import in stdlib uuid on Python 2.
150
151 Returns
152 -------
153
154 id string (16 random bytes as hex-encoded text, chunks separated by '-')
155 """
156 buf = os.urandom(16)
157 return "-".join(b2a_hex(x).decode("ascii") for x in (buf[:4], buf[4:]))
158
159
160def new_id_bytes() -> bytes:
161 """Return new_id as ascii bytes"""
162 return new_id().encode("ascii")
163
164
165session_aliases = {
166 "ident": "Session.session",
167 "user": "Session.username",
168 "keyfile": "Session.keyfile",
169}
170
171session_flags = {
172 "secure": (
173 {"Session": {"key": new_id_bytes(), "keyfile": ""}},
174 """Use HMAC digests for authentication of messages.
175 Setting this flag will generate a new UUID to use as the HMAC key.
176 """,
177 ),
178 "no-secure": (
179 {"Session": {"key": b"", "keyfile": ""}},
180 """Don't authenticate messages.""",
181 ),
182}
183
184
185def default_secure(cfg: t.Any) -> None: # pragma: no cover
186 """Set the default behavior for a config environment to be secure.
187
188 If Session.key/keyfile have not been set, set Session.key to
189 a new random UUID.
190 """
191 warnings.warn("default_secure is deprecated", DeprecationWarning, stacklevel=2)
192 if "Session" in cfg and ("key" in cfg.Session or "keyfile" in cfg.Session):
193 return
194 # key/keyfile not specified, generate new UUID:
195 cfg.Session.key = new_id_bytes()
196
197
198def utcnow() -> datetime:
199 """Return timezone-aware UTC timestamp"""
200 return datetime.now(utc)
201
202
203# -----------------------------------------------------------------------------
204# Classes
205# -----------------------------------------------------------------------------
206
207
208class SessionFactory(LoggingConfigurable):
209 """The Base class for configurables that have a Session, Context, logger,
210 and IOLoop.
211 """
212
213 logname = Unicode("")
214
215 @observe("logname")
216 def _logname_changed(self, change: t.Any) -> None:
217 self.log = logging.getLogger(change["new"])
218
219 # not configurable:
220 context = Instance("zmq.Context")
221
222 def _context_default(self) -> zmq.Context:
223 return zmq.Context()
224
225 session = Instance("jupyter_client.session.Session", allow_none=True)
226
227 loop = Instance("tornado.ioloop.IOLoop")
228
229 def _loop_default(self) -> IOLoop:
230 return IOLoop.current()
231
232 def __init__(self, **kwargs: t.Any) -> None:
233 """Initialize a session factory."""
234 super().__init__(**kwargs)
235
236 if self.session is None:
237 # construct the session
238 self.session = Session(**kwargs)
239
240
241class Message:
242 """A simple message object that maps dict keys to attributes.
243
244 A Message can be created from a dict and a dict from a Message instance
245 simply by calling dict(msg_obj)."""
246
247 def __init__(self, msg_dict: dict[str, t.Any]) -> None:
248 """Initialize a message."""
249 dct = self.__dict__
250 for k, v in dict(msg_dict).items():
251 if isinstance(v, dict):
252 v = Message(v) # noqa
253 dct[k] = v
254
255 # Having this iterator lets dict(msg_obj) work out of the box.
256 def __iter__(self) -> t.ItemsView[str, t.Any]:
257 return iter(self.__dict__.items()) # type:ignore[return-value]
258
259 def __repr__(self) -> str:
260 return repr(self.__dict__)
261
262 def __str__(self) -> str:
263 return pprint.pformat(self.__dict__)
264
265 def __contains__(self, k: object) -> bool:
266 return k in self.__dict__
267
268 def __getitem__(self, k: str) -> t.Any:
269 return self.__dict__[k]
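
# Illustrative sketch (added example, not part of the original module): nested
# dicts become nested Messages, and dict() converts back via __iter__.
# >>> m = Message({"header": {"msg_type": "execute_request"}, "content": {}})
# >>> m.header.msg_type
# 'execute_request'
# >>> sorted(dict(m))
# ['content', 'header']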
270
271
272def msg_header(
273 msg_id: str, msg_type: str, username: str, session: Session | str
274) -> dict[str, t.Any]:
275 """Create a new message header"""
276 date = utcnow()
277 version = protocol_version
278 return locals()
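
# Illustrative sketch (added example, not part of the original module): the
# returned header is just the captured locals, so its keys are fixed.
# >>> sorted(msg_header(new_id(), "kernel_info_request", "user", new_id()))
# ['date', 'msg_id', 'msg_type', 'session', 'username', 'version']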
279
280
281def extract_header(msg_or_header: dict[str, t.Any]) -> dict[str, t.Any]:
282 """Given a message or header, return the header."""
283 if not msg_or_header:
284 return {}
285 try:
286 # See if msg_or_header is the entire message.
287 h = msg_or_header["header"]
288 except KeyError:
289 try:
290 # See if msg_or_header is just the header
291 h = msg_or_header["msg_id"]
292 except KeyError:
293 raise
294 else:
295 h = msg_or_header
296 if not isinstance(h, dict):
297 h = dict(h)
298 return h
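
# Illustrative sketch (added example, not part of the original module):
# extract_header accepts either a full message or a bare header dict.
# >>> h = {"msg_id": "abc", "msg_type": "status"}
# >>> extract_header({"header": h}) == extract_header(h) == h
# True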
299
300
301class Session(Configurable):
302 """Object for handling serialization and sending of messages.
303
304 The Session object handles building messages and sending them
305 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
306 other over the network via Session objects, and only need to work with the
307 dict-based IPython message spec. The Session will handle
308 serialization/deserialization, security, and metadata.
309
310 Sessions support configurable serialization via packer/unpacker traits,
311 and signing with HMAC digests via the key/keyfile traits.
312
313 Parameters
314 ----------
315
316 debug : bool
317 whether to trigger extra debugging statements
318 packer/unpacker : str : 'json', 'pickle' or import_string
319 importstrings for methods to serialize message parts. If just
320 'json' or 'pickle', predefined JSON and pickle packers will be used.
321 Otherwise, the entire importstring must be used.
322
323 The functions must accept at least valid JSON input, and output *bytes*.
324
325 For example, to use msgpack:
326 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
327 pack/unpack : callables
328 You can also set the pack/unpack callables for serialization directly.
329 session : bytes
330 the ID of this Session object. The default is to generate a new UUID.
331 username : unicode
332 username added to message headers. The default is to ask the OS.
333 key : bytes
334 The key used to initialize an HMAC signature. If unset, messages
335 will not be signed or checked.
336 keyfile : filepath
337 The file containing a key. If this is set, `key` will be initialized
338 to the contents of the file.
339
340 """
341
342 debug = Bool(False, config=True, help="""Debug output in the Session""")
343
344 check_pid = Bool(
345 True,
346 config=True,
347 help="""Whether to check PID to protect against calls after fork.
348
349 This check can be disabled if fork-safety is handled elsewhere.
350 """,
351 )
352
353 packer = DottedObjectName(
354 "json",
355 config=True,
356 help="""The name of the packer for serializing messages.
357 Should be one of 'json', 'pickle', or an import name
358 for a custom callable serializer.""",
359 )
360
361 @observe("packer")
362 def _packer_changed(self, change: t.Any) -> None:
363 new = change["new"]
364 if new.lower() == "json":
365 self.pack = json_packer
366 self.unpack = json_unpacker
367 self.unpacker = new
368 elif new.lower() == "pickle":
369 self.pack = pickle_packer
370 self.unpack = pickle_unpacker
371 self.unpacker = new
372 else:
373 self.pack = import_item(str(new))
374
375 unpacker = DottedObjectName(
376 "json",
377 config=True,
378 help="""The name of the unpacker for unserializing messages.
379 Only used with custom functions for `packer`.""",
380 )
381
382 @observe("unpacker")
383 def _unpacker_changed(self, change: t.Any) -> None:
384 new = change["new"]
385 if new.lower() == "json":
386 self.pack = json_packer
387 self.unpack = json_unpacker
388 self.packer = new
389 elif new.lower() == "pickle":
390 self.pack = pickle_packer
391 self.unpack = pickle_unpacker
392 self.packer = new
393 else:
394 self.unpack = import_item(str(new))
395
396 session = CUnicode("", config=True, help="""The UUID identifying this session.""")
397
398 def _session_default(self) -> str:
399 u = new_id()
400 self.bsession = u.encode("ascii")
401 return u
402
403 @observe("session")
404 def _session_changed(self, change: t.Any) -> None:
405 self.bsession = self.session.encode("ascii")
406
407 # bsession is the session as bytes
408 bsession = CBytes(b"")
409
410 username = Unicode(
411 os.environ.get("USER", "username"),
412 help="""Username for the Session. Default is your system username.""",
413 config=True,
414 )
415
416 metadata = Dict(
417 {},
418 config=True,
419 help="Metadata dictionary, which serves as the default top-level metadata dict for each "
420 "message.",
421 )
422
423 # if 0, no adapting to do.
424 adapt_version = Integer(0)
425
426 # message signature related traits:
427
428 key = CBytes(config=True, help="""execution key, for signing messages.""")
429
430 def _key_default(self) -> bytes:
431 return new_id_bytes()
432
433 @observe("key")
434 def _key_changed(self, change: t.Any) -> None:
435 self._new_auth()
436
437 signature_scheme = Unicode(
438 "hmac-sha256",
439 config=True,
440 help="""The digest scheme used to construct the message signatures.
441 Must have the form 'hmac-HASH'.""",
442 )
443
444 @observe("signature_scheme")
445 def _signature_scheme_changed(self, change: t.Any) -> None:
446 new = change["new"]
447 if not new.startswith("hmac-"):
448 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
449 hash_name = new.split("-", 1)[1]
450 try:
451 self.digest_mod = getattr(hashlib, hash_name)
452 except AttributeError as e:
453 raise TraitError("hashlib has no such attribute: %s" % hash_name) from e
454 self._new_auth()
455
456 digest_mod = Any()
457
458 def _digest_mod_default(self) -> t.Callable:
459 return hashlib.sha256
460
461 auth = Instance(hmac.HMAC, allow_none=True)
462
463 def _new_auth(self) -> None:
464 if self.key:
465 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
466 else:
467 self.auth = None
468
469 digest_history = Set()
470 digest_history_size = Integer(
471 2**16,
472 config=True,
473 help="""The maximum number of digests to remember.
474
475 The digest history will be culled when it exceeds this value.
476 """,
477 )
478
479 keyfile = Unicode("", config=True, help="""path to file containing execution key.""")
480
481 @observe("keyfile")
482 def _keyfile_changed(self, change: t.Any) -> None:
483 with open(change["new"], "rb") as f:
484 self.key = f.read().strip()
485
486 # for protecting against sends from forks
487 pid = Integer()
488
489 # serialization traits:
490
491 pack = Any(default_packer) # the actual packer function
492
493 @observe("pack")
494 def _pack_changed(self, change: t.Any) -> None:
495 new = change["new"]
496 if not callable(new):
497 raise TypeError("packer must be callable, not %s" % type(new))
498
    unpack = Any(default_unpacker)  # the actual unpacker function
500
501 @observe("unpack")
502 def _unpack_changed(self, change: t.Any) -> None:
        # like pack, only check that the unpacker is callable
504 new = change["new"]
505 if not callable(new):
506 raise TypeError("unpacker must be callable, not %s" % type(new))
507
508 # thresholds:
509 copy_threshold = Integer(
510 2**16,
511 config=True,
512 help="Threshold (in bytes) beyond which a buffer should be sent without copying.",
513 )
514 buffer_threshold = Integer(
515 MAX_BYTES,
516 config=True,
517 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid "
518 "pickling.",
519 )
520 item_threshold = Integer(
521 MAX_ITEMS,
522 config=True,
523 help="""The maximum number of items for a container to be introspected for custom serialization.
524 Containers larger than this are pickled outright.
525 """,
526 )
527
528 def __init__(self, **kwargs: t.Any) -> None:
529 """create a Session object
530
531 Parameters
532 ----------
533
534 debug : bool
535 whether to trigger extra debugging statements
536 packer/unpacker : str : 'json', 'pickle' or import_string
537 importstrings for methods to serialize message parts. If just
538 'json' or 'pickle', predefined JSON and pickle packers will be used.
539 Otherwise, the entire importstring must be used.
540
541 The functions must accept at least valid JSON input, and output
542 *bytes*.
543
544 For example, to use msgpack:
545 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
546 pack/unpack : callables
547 You can also set the pack/unpack callables for serialization
548 directly.
549 session : unicode (must be ascii)
550 the ID of this Session object. The default is to generate a new
551 UUID.
552 bsession : bytes
553 The session as bytes
554 username : unicode
555 username added to message headers. The default is to ask the OS.
556 key : bytes
557 The key used to initialize an HMAC signature. If unset, messages
558 will not be signed or checked.
559 signature_scheme : str
560 The message digest scheme. Currently must be of the form 'hmac-HASH',
561 where 'HASH' is a hashing function available in Python's hashlib.
562 The default is 'hmac-sha256'.
563 This is ignored if 'key' is empty.
564 keyfile : filepath
565 The file containing a key. If this is set, `key` will be
566 initialized to the contents of the file.
567 """
568 super().__init__(**kwargs)
569 self._check_packers()
        self.none = self.pack({})  # packed empty dict, used when content is None
571 # ensure self._session_default() if necessary, so bsession is defined:
572 self.session # noqa
573 self.pid = os.getpid()
574 self._new_auth()
575 if not self.key:
576 get_logger().warning(
577 "Message signing is disabled. This is insecure and not recommended!"
578 )
579
580 def clone(self) -> Session:
581 """Create a copy of this Session
582
583 Useful when connecting multiple times to a given kernel.
584 This prevents a shared digest_history warning about duplicate digests
585 due to multiple connections to IOPub in the same process.
586
587 .. versionadded:: 5.1
588 """
589 # make a copy
590 new_session = type(self)()
591 for name in self.traits():
592 setattr(new_session, name, getattr(self, name))
593 # fork digest_history
594 new_session.digest_history = set()
595 new_session.digest_history.update(self.digest_history)
596 return new_session
597
598 message_count = 0
599
600 @property
    def msg_id(self) -> str:
        """Return a new session-unique message id of the form session_pid_counter."""
        message_number = self.message_count
        self.message_count += 1
        return f"{self.session}_{os.getpid()}_{message_number}"
605
606 def _check_packers(self) -> None:
607 """check packers for datetime support."""
608 pack = self.pack
609 unpack = self.unpack
610
611 # check simple serialization
612 msg_list = {"a": [1, "hi"]}
613 try:
614 packed = pack(msg_list)
615 except Exception as e:
616 msg = f"packer '{self.packer}' could not serialize a simple message: {e}"
617 raise ValueError(msg) from e
618
619 # ensure packed message is bytes
620 if not isinstance(packed, bytes):
621 raise ValueError("message packed to %r, but bytes are required" % type(packed))
622
623 # check that unpack is pack's inverse
624 try:
625 unpacked = unpack(packed)
626 assert unpacked == msg_list
627 except Exception as e:
628 msg = (
629 f"unpacker '{self.unpacker}' could not handle output from packer"
630 f" '{self.packer}': {e}"
631 )
632 raise ValueError(msg) from e
633
634 # check datetime support
635 msg_datetime = {"t": utcnow()}
636 try:
637 unpacked = unpack(pack(msg_datetime))
638 if isinstance(unpacked["t"], datetime):
639 msg = "Shouldn't deserialize to datetime"
640 raise ValueError(msg)
641 except Exception:
642 self.pack = lambda o: pack(squash_dates(o))
643 self.unpack = lambda s: unpack(s)
644
645 def msg_header(self, msg_type: str) -> dict[str, t.Any]:
646 """Create a header for a message type."""
647 return msg_header(self.msg_id, msg_type, self.username, self.session)
648
649 def msg(
650 self,
651 msg_type: str,
652 content: dict | None = None,
653 parent: dict[str, t.Any] | None = None,
654 header: dict[str, t.Any] | None = None,
655 metadata: dict[str, t.Any] | None = None,
656 ) -> dict[str, t.Any]:
657 """Return the nested message dict.
658
659 This format is different from what is sent over the wire. The
        serialize/deserialize methods convert this nested message dict to the wire
661 format, which is a list of message parts.
662 """
663 msg = {}
664 header = self.msg_header(msg_type) if header is None else header
665 msg["header"] = header
666 msg["msg_id"] = header["msg_id"]
667 msg["msg_type"] = header["msg_type"]
668 msg["parent_header"] = {} if parent is None else extract_header(parent)
669 msg["content"] = {} if content is None else content
670 msg["metadata"] = self.metadata.copy()
671 if metadata is not None:
672 msg["metadata"].update(metadata)
673 return msg
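
    # Illustrative sketch (added example, not part of the original module): a
    # freshly built message has the standard top-level keys.
    # >>> s = Session(key=b"secret")
    # >>> sorted(s.msg("kernel_info_request"))
    # ['content', 'header', 'metadata', 'msg_id', 'msg_type', 'parent_header']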
674
675 def sign(self, msg_list: list) -> bytes:
676 """Sign a message with HMAC digest. If no auth, return b''.
677
678 Parameters
679 ----------
680 msg_list : list
681 The [p_header,p_parent,p_content] part of the message list.
682 """
683 if self.auth is None:
684 return b""
685 h = self.auth.copy()
686 for m in msg_list:
687 h.update(m)
688 return h.hexdigest().encode()
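
    # Illustrative sketch (hypothetical key, not part of the original module):
    # the signature is the hex HMAC over the packed parts, concatenated in order.
    # >>> s = Session(key=b"secret", signature_scheme="hmac-sha256")
    # >>> expected = hmac.new(b"secret", b"hpmc", hashlib.sha256).hexdigest().encode()
    # >>> s.sign([b"h", b"p", b"m", b"c"]) == expected
    # True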
689
690 def serialize(
691 self,
692 msg: dict[str, t.Any],
693 ident: list[bytes] | bytes | None = None,
694 ) -> list[bytes]:
695 """Serialize the message components to bytes.
696
697 This is roughly the inverse of deserialize. The serialize/deserialize
698 methods work with full message lists, whereas pack/unpack work with
699 the individual message parts in the message list.
700
701 Parameters
702 ----------
703 msg : dict or Message
            The nested message dict as returned by the self.msg method.
705
706 Returns
707 -------
708 msg_list : list
709 The list of bytes objects to be sent with the format::
710
711 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
712 p_metadata, p_content, buffer1, buffer2, ...]
713
714 In this list, the ``p_*`` entities are the packed or serialized
715 versions, so if JSON is used, these are utf8 encoded JSON strings.
716 """
717 content = msg.get("content", {})
718 if content is None:
719 content = self.none
720 elif isinstance(content, dict):
721 content = self.pack(content)
722 elif isinstance(content, bytes):
723 # content is already packed, as in a relayed message
724 pass
725 elif isinstance(content, str):
726 # should be bytes, but JSON often spits out unicode
727 content = content.encode("utf8")
728 else:
729 raise TypeError("Content incorrect type: %s" % type(content))
730
731 real_message = [
732 self.pack(msg["header"]),
733 self.pack(msg["parent_header"]),
734 self.pack(msg["metadata"]),
735 content,
736 ]
737
738 to_send = []
739
740 if isinstance(ident, list):
741 # accept list of idents
742 to_send.extend(ident)
743 elif ident is not None:
744 to_send.append(ident)
745 to_send.append(DELIM)
746
747 signature = self.sign(real_message)
748 to_send.append(signature)
749
750 to_send.extend(real_message)
751
752 return to_send
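
    # Illustrative sketch (added example, not part of the original module): the
    # wire list is [idents..., DELIM, signature, header, parent, metadata, content].
    # >>> s = Session(key=b"secret")
    # >>> parts = s.serialize(s.msg("kernel_info_request"), ident=b"routing-id")
    # >>> parts[0], parts[1], len(parts)
    # (b'routing-id', b'<IDS|MSG>', 7)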
753
754 def send(
755 self,
756 stream: zmq.sugar.socket.Socket | ZMQStream | None,
757 msg_or_type: dict[str, t.Any] | str,
758 content: dict[str, t.Any] | None = None,
759 parent: dict[str, t.Any] | None = None,
760 ident: bytes | list[bytes] | None = None,
761 buffers: list[bytes] | None = None,
762 track: bool = False,
763 header: dict[str, t.Any] | None = None,
764 metadata: dict[str, t.Any] | None = None,
765 ) -> dict[str, t.Any] | None:
766 """Build and send a message via stream or socket.
767
768 The message format used by this function internally is as follows:
769
        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
         buffer1,buffer2,...]
772
773 The serialize/deserialize methods convert the nested message dict into this
774 format.
775
776 Parameters
777 ----------
778
779 stream : zmq.Socket or ZMQStream
780 The socket-like object used to send the data.
781 msg_or_type : str or Message/dict
782 Normally, msg_or_type will be a msg_type unless a message is being
783 sent more than once. If a header is supplied, this can be set to
784 None and the msg_type will be pulled from the header.
785
786 content : dict or None
787 The content of the message (ignored if msg_or_type is a message).
788 header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
790 parent : Message or dict or None
791 The parent or parent header describing the parent of this message
792 (ignored if msg_or_type is a message).
793 ident : bytes or list of bytes
794 The zmq.IDENTITY routing path.
795 metadata : dict or None
796 The metadata describing the message
797 buffers : list or None
798 The already-serialized buffers to be appended to the message.
799 track : bool
800 Whether to track. Only for use with Sockets, because ZMQStream
801 objects cannot track messages.
802
803
804 Returns
805 -------
806 msg : dict
807 The constructed message.
808 """
809 if not isinstance(stream, zmq.Socket):
810 # ZMQStreams and dummy sockets do not support tracking.
811 track = False
812
813 if isinstance(stream, zmq.asyncio.Socket):
814 assert stream is not None # type:ignore[unreachable]
815 stream = zmq.Socket.shadow(stream.underlying)
816
817 if isinstance(msg_or_type, (Message, dict)):
818 # We got a Message or message dict, not a msg_type so don't
819 # build a new Message.
820 msg = msg_or_type
821 buffers = buffers or msg.get("buffers", [])
822 else:
823 msg = self.msg(
824 msg_or_type,
825 content=content,
826 parent=parent,
827 header=header,
828 metadata=metadata,
829 )
830 if self.check_pid and os.getpid() != self.pid:
831 get_logger().warning("WARNING: attempted to send message from fork\n%s", msg)
832 return None
833 buffers = [] if buffers is None else buffers
834 for idx, buf in enumerate(buffers):
835 if isinstance(buf, memoryview):
836 view = buf
837 else:
838 try:
839 # check to see if buf supports the buffer protocol.
840 view = memoryview(buf)
841 except TypeError as e:
842 emsg = "Buffer objects must support the buffer protocol."
843 raise TypeError(emsg) from e
            # memoryview.contiguous is available on all supported Pythons;
            # the hasattr check is kept as a defensive guard.
846 if hasattr(view, "contiguous") and not view.contiguous:
847 # zmq requires memoryviews to be contiguous
848 raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))
849
850 if self.adapt_version:
851 msg = adapt(msg, self.adapt_version)
852 to_send = self.serialize(msg, ident)
853 to_send.extend(buffers)
854 longest = max([len(s) for s in to_send])
855 copy = longest < self.copy_threshold
856
857 if stream and buffers and track and not copy:
858 # only really track when we are doing zero-copy buffers
859 tracker = stream.send_multipart(to_send, copy=False, track=True)
860 elif stream:
861 # use dummy tracker, which will be done immediately
862 tracker = DONE
863 stream.send_multipart(to_send, copy=copy)
864 else:
865 tracker = DONE
866
867 if self.debug:
868 pprint.pprint(msg) # noqa
869 pprint.pprint(to_send) # noqa
870 pprint.pprint(buffers) # noqa
871
872 msg["tracker"] = tracker
873
874 return msg
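
    # Illustrative sketch (hypothetical sockets, not part of the original module):
    # a round trip over an inproc PAIR pair, using a blocking recv (mode=0).
    #
    #   ctx = zmq.Context.instance()
    #   a, b = ctx.socket(zmq.PAIR), ctx.socket(zmq.PAIR)
    #   a.bind("inproc://example"); b.connect("inproc://example")
    #   s = Session(key=b"secret")
    #   s.send(a, "kernel_info_request", content={})
    #   idents, msg = s.recv(b, mode=0)
    #   assert msg is not None and msg["msg_type"] == "kernel_info_request"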
875
876 def send_raw(
877 self,
878 stream: zmq.sugar.socket.Socket,
879 msg_list: list,
880 flags: int = 0,
881 copy: bool = True,
882 ident: bytes | list[bytes] | None = None,
883 ) -> None:
884 """Send a raw message via ident path.
885
        This method is used to send an already-serialized message.
887
888 Parameters
889 ----------
890 stream : ZMQStream or Socket
891 The ZMQ stream or socket to use for sending the message.
892 msg_list : list
            The list of serialized message parts to send. This only includes the
894 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
895 the message.
896 ident : ident or list
897 A single ident or a list of idents to use in sending.
898 """
899 to_send = []
900 if isinstance(ident, bytes):
901 ident = [ident]
902 if ident is not None:
903 to_send.extend(ident)
904
905 to_send.append(DELIM)
906 # Don't include buffers in signature (per spec).
907 to_send.append(self.sign(msg_list[0:4]))
908 to_send.extend(msg_list)
909 if isinstance(stream, zmq.asyncio.Socket):
910 stream = zmq.Socket.shadow(stream.underlying)
911 stream.send_multipart(to_send, flags, copy=copy)
912
913 def recv(
914 self,
915 socket: zmq.sugar.socket.Socket,
916 mode: int = zmq.NOBLOCK,
917 content: bool = True,
918 copy: bool = True,
919 ) -> tuple[list[bytes] | None, dict[str, t.Any] | None]:
920 """Receive and unpack a message.
921
922 Parameters
923 ----------
924 socket : ZMQStream or Socket
925 The socket or stream to use in receiving.
926
927 Returns
928 -------
929 [idents], msg
930 [idents] is a list of idents and msg is a nested message dict of
931 same format as self.msg returns.
932 """
933 if isinstance(socket, ZMQStream): # type:ignore[unreachable]
934 socket = socket.socket # type:ignore[unreachable]
935 if isinstance(socket, zmq.asyncio.Socket):
936 socket = zmq.Socket.shadow(socket.underlying)
937
938 try:
939 msg_list = socket.recv_multipart(mode, copy=copy)
940 except zmq.ZMQError as e:
941 if e.errno == zmq.EAGAIN:
942 # We can convert EAGAIN to None as we know in this case
943 # recv_multipart won't return None.
944 return None, None
945 else:
946 raise
947 # split multipart message into identity list and message dict
948 # invalid large messages can cause very expensive string comparisons
949 idents, msg_list = self.feed_identities(msg_list, copy)
950 try:
951 return idents, self.deserialize(msg_list, content=content, copy=copy)
952 except Exception as e:
953 # TODO: handle it
954 raise e
955
956 def feed_identities(
957 self, msg_list: list[bytes] | list[zmq.Message], copy: bool = True
958 ) -> tuple[list[bytes], list[bytes] | list[zmq.Message]]:
959 """Split the identities from the rest of the message.
960
961 Feed until DELIM is reached, then return the prefix as idents and
962 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
963 but that would be silly.
964
965 Parameters
966 ----------
967 msg_list : a list of Message or bytes objects
968 The message to be split.
969 copy : bool
970 flag determining whether the arguments are bytes or Messages
971
972 Returns
973 -------
974 (idents, msg_list) : two lists
975 idents will always be a list of bytes, each of which is a ZMQ
976 identity. msg_list will be a list of bytes or zmq.Messages of the
977 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
978 should be unpackable/unserializable via self.deserialize at this
979 point.
980 """
981 if copy:
982 msg_list = t.cast(t.List[bytes], msg_list)
983 idx = msg_list.index(DELIM)
984 return msg_list[:idx], msg_list[idx + 1 :]
985 else:
986 msg_list = t.cast(t.List[zmq.Message], msg_list)
987 failed = True
988 for idx, m in enumerate(msg_list): # noqa
989 if m.bytes == DELIM:
990 failed = False
991 break
992 if failed:
993 msg = "DELIM not in msg_list"
994 raise ValueError(msg)
995 idents, msg_list = msg_list[:idx], msg_list[idx + 1 :]
996 return [bytes(m.bytes) for m in idents], msg_list
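
    # Illustrative sketch (added example, not part of the original module):
    # everything before DELIM is routing idents, everything after is the message.
    # >>> s = Session(key=b"secret")
    # >>> s.feed_identities([b"id1", b"id2", DELIM, b"sig", b"h", b"p", b"m", b"c"])
    # ([b'id1', b'id2'], [b'sig', b'h', b'p', b'm', b'c'])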
997
998 def _add_digest(self, signature: bytes) -> None:
999 """add a digest to history to protect against replay attacks"""
1000 if self.digest_history_size == 0:
1001 # no history, never add digests
1002 return
1003
1004 self.digest_history.add(signature)
1005 if len(self.digest_history) > self.digest_history_size:
1006 # threshold reached, cull 10%
1007 self._cull_digest_history()
1008
1009 def _cull_digest_history(self) -> None:
1010 """cull the digest history
1011
1012 Removes a randomly selected 10% of the digest history
1013 """
1014 current = len(self.digest_history)
1015 n_to_cull = max(int(current // 10), current - self.digest_history_size)
1016 if n_to_cull >= current:
1017 self.digest_history = set()
1018 return
1019 to_cull = random.sample(tuple(sorted(self.digest_history)), n_to_cull)
1020 self.digest_history.difference_update(to_cull)
1021
1022 def deserialize(
1023 self,
1024 msg_list: list[bytes] | list[zmq.Message],
1025 content: bool = True,
1026 copy: bool = True,
1027 ) -> dict[str, t.Any]:
1028 """Unserialize a msg_list to a nested message dict.
1029
1030 This is roughly the inverse of serialize. The serialize/deserialize
1031 methods work with full message lists, whereas pack/unpack work with
1032 the individual message parts in the message list.
1033
1034 Parameters
1035 ----------
1036 msg_list : list of bytes or Message objects
1037 The list of message parts of the form [HMAC,p_header,p_parent,
1038 p_metadata,p_content,buffer1,buffer2,...].
1039 content : bool (True)
1040 Whether to unpack the content dict (True), or leave it packed
1041 (False).
1042 copy : bool (True)
1043 Whether msg_list contains bytes (True) or the non-copying Message
1044 objects in each place (False).
1045
1046 Returns
1047 -------
1048 msg : dict
1049 The nested message dict with top-level keys [header, parent_header,
1050 content, buffers]. The buffers are returned as memoryviews.
1051 """
1052 minlen = 5
1053 message = {}
1054 if not copy:
1055 # pyzmq didn't copy the first parts of the message, so we'll do it
1056 msg_list = t.cast(t.List[zmq.Message], msg_list)
1057 msg_list_beginning = [bytes(msg.bytes) for msg in msg_list[:minlen]]
1058 msg_list = t.cast(t.List[bytes], msg_list)
1059 msg_list = msg_list_beginning + msg_list[minlen:]
1060 msg_list = t.cast(t.List[bytes], msg_list)
1061 if self.auth is not None:
1062 signature = msg_list[0]
1063 if not signature:
1064 msg = "Unsigned Message"
1065 raise ValueError(msg)
1066 if signature in self.digest_history:
1067 raise ValueError("Duplicate Signature: %r" % signature)
1068 if content:
1069 # Only store signature if we are unpacking content, don't store if just peeking.
1070 self._add_digest(signature)
1071 check = self.sign(msg_list[1:5])
1072 if not compare_digest(signature, check):
1073 msg = "Invalid Signature: %r" % signature
1074 raise ValueError(msg)
1075 if not len(msg_list) >= minlen:
1076 msg = "malformed message, must have at least %i elements" % minlen
1077 raise TypeError(msg)
1078 header = self.unpack(msg_list[1])
1079 message["header"] = extract_dates(header)
1080 message["msg_id"] = header["msg_id"]
1081 message["msg_type"] = header["msg_type"]
1082 message["parent_header"] = extract_dates(self.unpack(msg_list[2]))
1083 message["metadata"] = self.unpack(msg_list[3])
1084 if content:
1085 message["content"] = self.unpack(msg_list[4])
1086 else:
1087 message["content"] = msg_list[4]
1088 buffers = [memoryview(b) for b in msg_list[5:]]
1089 if buffers and buffers[0].shape is None:
1090 # force copy to workaround pyzmq #646
1091 msg_list = t.cast(t.List[zmq.Message], msg_list)
1092 buffers = [memoryview(bytes(b.bytes)) for b in msg_list[5:]]
1093 message["buffers"] = buffers
1094 if self.debug:
1095 pprint.pprint(message) # noqa
1096 # adapt to the current version
1097 return adapt(message)
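
    # Illustrative sketch (added example, not part of the original module): a
    # serialize -> feed_identities -> deserialize round trip within one Session.
    # >>> s = Session(key=b"secret")
    # >>> idents, parts = s.feed_identities(s.serialize(s.msg("kernel_info_request")))
    # >>> s.deserialize(parts)["msg_type"]
    # 'kernel_info_request'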
1098
1099 def unserialize(self, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]:
1100 """**DEPRECATED** Use deserialize instead."""
1101 # pragma: no cover
1102 warnings.warn(
1103 "Session.unserialize is deprecated. Use Session.deserialize.",
1104 DeprecationWarning,
1105 stacklevel=2,
1106 )
1107 return self.deserialize(*args, **kwargs)