Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jupyter_client/session.py: 28%
443 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-01 06:54 +0000
1"""Session object for building, serializing, sending, and receiving messages.
3The Session object supports serialization, HMAC signatures,
4and metadata on messages.
6Also defined here are utilities for working with Sessions:
7* A SessionFactory to be used as a base class for configurables that work with
8Sessions.
9* A Message object for convenience that allows attribute-access to the msg dict.
10"""
11# Copyright (c) Jupyter Development Team.
12# Distributed under the terms of the Modified BSD License.
13import hashlib
14import hmac
15import json
16import logging
17import os
18import pickle
19import pprint
20import random
21import typing as t
22import warnings
23from binascii import b2a_hex
24from datetime import datetime, timezone
25from hmac import compare_digest
27# We are using compare_digest to limit the surface of timing attacks
28from typing import Optional, Union
30import zmq.asyncio
31from tornado.ioloop import IOLoop
32from traitlets import (
33 Any,
34 Bool,
35 CBytes,
36 CUnicode,
37 Dict,
38 DottedObjectName,
39 Instance,
40 Integer,
41 Set,
42 TraitError,
43 Unicode,
44 observe,
45)
46from traitlets.config.configurable import Configurable, LoggingConfigurable
47from traitlets.log import get_logger
48from traitlets.utils.importstring import import_item
49from zmq.eventloop.zmqstream import ZMQStream
51from ._version import protocol_version
52from .adapter import adapt
53from .jsonutil import extract_dates, json_clean, json_default, squash_dates
# pickle protocol used by pickle_packer
PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL

# shorthand for the UTC timezone
utc = timezone.utc
59# -----------------------------------------------------------------------------
60# utility functions
61# -----------------------------------------------------------------------------
def squash_unicode(obj):
    """Recursively coerce unicode strings back to utf8 bytestrings.

    Mutates dicts and lists in place; str keys and values are replaced by
    their utf8-encoded bytes equivalents.
    """
    if isinstance(obj, dict):
        for key in list(obj):
            value = squash_unicode(obj[key])
            if isinstance(key, str):
                # re-key with the encoded form of the key
                del obj[key]
                obj[key.encode("utf8")] = value
            else:
                obj[key] = value
    elif isinstance(obj, list):
        obj[:] = [squash_unicode(item) for item in obj]
    elif isinstance(obj, str):
        obj = obj.encode("utf8")
    return obj
79# -----------------------------------------------------------------------------
80# globals and defaults
81# -----------------------------------------------------------------------------
# default values for the thresholds:
MAX_ITEMS = 64  # default item_threshold (max container length to introspect)
MAX_BYTES = 1024  # default buffer_threshold (bytes before buffer extraction)
87# ISO8601-ify datetime objects
88# allow unicode
89# disallow nan, because it's not actually valid JSON
def json_packer(obj):
    """Convert a json object to a bytes.

    Disallows NaN (not valid JSON) and keeps non-ascii characters; falls
    back to ``json_clean`` with a deprecation warning when *obj* is not
    directly serializable.
    """

    def _dump(o):
        return json.dumps(
            o,
            default=json_default,
            ensure_ascii=False,
            allow_nan=False,
        )

    try:
        serialized = _dump(obj)
    except (TypeError, ValueError) as e:
        # Fallback to trying to clean the json before serializing
        serialized = _dump(json_clean(obj))

        warnings.warn(
            f"Message serialization failed with:\n{e}\n"
            "Supporting this message is deprecated in jupyter-client 7, please make "
            "sure your message is JSON-compliant",
            stacklevel=2,
        )

    return serialized.encode("utf8", errors="surrogateescape")
def json_unpacker(s):
    """Convert a json bytes or string to an object."""
    text = s.decode("utf8", "replace") if isinstance(s, bytes) else s
    return json.loads(text)
def pickle_packer(o):
    """Pack an object using the pickle module."""
    # ISO8601-ify datetimes first so round-trips match the JSON packer
    squashed = squash_dates(o)
    return pickle.dumps(squashed, PICKLE_PROTOCOL)
# unpickling needs no date handling, so pickle.loads is used directly
pickle_unpacker = pickle.loads

# JSON is the default wire format
default_packer = json_packer
default_unpacker = json_unpacker

# frame separating ZMQ routing identities from the message payload
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()
141# -----------------------------------------------------------------------------
142# Mixin tools for apps that use Sessions
143# -----------------------------------------------------------------------------
def new_id() -> str:
    """Generate a new random id.

    Avoids problematic runtime import in stdlib uuid on Python 2.

    Returns
    -------

    id string (16 random bytes as hex-encoded text, chunks separated by '-')
    """
    raw = os.urandom(16)
    return "-".join(chunk.hex() for chunk in (raw[:4], raw[4:]))
def new_id_bytes() -> bytes:
    """Return new_id as ascii bytes"""
    return bytes(new_id(), "ascii")
# command-line aliases mapping short option names onto Session traits
session_aliases = {
    "ident": "Session.session",
    "user": "Session.username",
    "keyfile": "Session.keyfile",
}

# command-line flags toggling message authentication on or off
session_flags = {
    "secure": (
        {"Session": {"key": new_id_bytes(), "keyfile": ""}},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """,
    ),
    "no-secure": (
        {"Session": {"key": b"", "keyfile": ""}},
        """Don't authenticate messages.""",
    ),
}
def default_secure(cfg: t.Any) -> None:  # pragma: no cover
    """Set the default behavior for a config environment to be secure.

    If Session.key/keyfile have not been set, set Session.key to
    a new random UUID.
    """
    warnings.warn("default_secure is deprecated", DeprecationWarning, stacklevel=2)
    if "Session" in cfg:
        if "key" in cfg.Session or "keyfile" in cfg.Session:
            # an explicit key configuration already exists; leave it alone
            return
    # key/keyfile not specified, generate new UUID:
    cfg.Session.key = new_id_bytes()
def utcnow() -> datetime:
    """Return timezone-aware UTC timestamp"""
    # datetime.utcnow() is deprecated (Python 3.12+) and returns a naive
    # datetime that then had tzinfo patched in; datetime.now(timezone.utc)
    # produces the same instant, already timezone-aware.
    return datetime.now(timezone.utc)
203# -----------------------------------------------------------------------------
204# Classes
205# -----------------------------------------------------------------------------
class SessionFactory(LoggingConfigurable):
    """The Base class for configurables that have a Session, Context, logger,
    and IOLoop.
    """

    # name of the logger to use; re-fetched whenever it changes
    logname = Unicode("")

    @observe("logname")  # type:ignore[misc]
    def _logname_changed(self, change: t.Any) -> None:
        # swap in the logger matching the new name
        self.log = logging.getLogger(change["new"])

    # not configurable:
    context = Instance("zmq.Context")

    def _context_default(self) -> zmq.Context:
        # lazily create a private ZMQ context
        return zmq.Context()

    session = Instance("jupyter_client.session.Session", allow_none=True)

    loop = Instance("tornado.ioloop.IOLoop")

    def _loop_default(self):
        # default to the IOLoop of the current thread
        return IOLoop.current()

    def __init__(self, **kwargs):
        """Initialize a session factory."""
        super().__init__(**kwargs)

        if self.session is None:
            # construct the session
            self.session = Session(**kwargs)
class Message:
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj)."""

    def __init__(self, msg_dict: t.Dict[str, t.Any]) -> None:
        """Initialize a message."""
        attrs = self.__dict__
        for key, value in dict(msg_dict).items():
            # nested dicts become nested Messages, for attribute access
            attrs[key] = Message(value) if isinstance(value, dict) else value

    # Having this iterator lets dict(msg_obj) work out of the box.
    def __iter__(self) -> t.ItemsView[str, t.Any]:
        return iter(self.__dict__.items())  # type:ignore[return-value]

    def __repr__(self) -> str:
        return repr(self.__dict__)

    def __str__(self) -> str:
        return pprint.pformat(self.__dict__)

    def __contains__(self, k: object) -> bool:
        return k in self.__dict__

    def __getitem__(self, k: str) -> t.Any:
        return self.__dict__[k]
def msg_header(msg_id: str, msg_type: str, username: str, session: "Session") -> t.Dict[str, t.Any]:
    """Create a new message header"""
    # explicit dict in the same key order the old locals() trick produced
    return {
        "msg_id": msg_id,
        "msg_type": msg_type,
        "username": username,
        "session": session,
        "date": utcnow(),
        "version": protocol_version,
    }
def extract_header(msg_or_header: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
    """Given a message or header, return the header."""
    if not msg_or_header:
        return {}
    try:
        # See if msg_or_header is the entire message.
        header = msg_or_header["header"]
    except KeyError:
        # No "header" key: assume we were handed a bare header, which must
        # at least carry a msg_id (raises KeyError otherwise).
        msg_or_header["msg_id"]
        header = msg_or_header
    return header if isinstance(header, dict) else dict(header)
299class Session(Configurable):
300 """Object for handling serialization and sending of messages.
302 The Session object handles building messages and sending them
303 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
304 other over the network via Session objects, and only need to work with the
305 dict-based IPython message spec. The Session will handle
306 serialization/deserialization, security, and metadata.
308 Sessions support configurable serialization via packer/unpacker traits,
309 and signing with HMAC digests via the key/keyfile traits.
311 Parameters
312 ----------
314 debug : bool
315 whether to trigger extra debugging statements
316 packer/unpacker : str : 'json', 'pickle' or import_string
317 importstrings for methods to serialize message parts. If just
318 'json' or 'pickle', predefined JSON and pickle packers will be used.
319 Otherwise, the entire importstring must be used.
321 The functions must accept at least valid JSON input, and output *bytes*.
323 For example, to use msgpack:
324 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
325 pack/unpack : callables
326 You can also set the pack/unpack callables for serialization directly.
327 session : bytes
328 the ID of this Session object. The default is to generate a new UUID.
329 username : unicode
330 username added to message headers. The default is to ask the OS.
331 key : bytes
332 The key used to initialize an HMAC signature. If unset, messages
333 will not be signed or checked.
334 keyfile : filepath
335 The file containing a key. If this is set, `key` will be initialized
336 to the contents of the file.
338 """
    # emit verbose debugging output when sending/receiving
    debug = Bool(False, config=True, help="""Debug output in the Session""")

    check_pid = Bool(
        True,
        config=True,
        help="""Whether to check PID to protect against calls after fork.

        This check can be disabled if fork-safety is handled elsewhere.
        """,
    )

    packer = DottedObjectName(
        "json",
        config=True,
        help="""The name of the packer for serializing messages.
        Should be one of 'json', 'pickle', or an import name
        for a custom callable serializer.""",
    )
359 @observe("packer")
360 def _packer_changed(self, change):
361 new = change["new"]
362 if new.lower() == "json":
363 self.pack = json_packer
364 self.unpack = json_unpacker
365 self.unpacker = new
366 elif new.lower() == "pickle":
367 self.pack = pickle_packer
368 self.unpack = pickle_unpacker
369 self.unpacker = new
370 else:
371 self.pack = import_item(str(new))
    # mirror of `packer`; only consulted for custom serializers
    unpacker = DottedObjectName(
        "json",
        config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""",
    )
380 @observe("unpacker")
381 def _unpacker_changed(self, change):
382 new = change["new"]
383 if new.lower() == "json":
384 self.pack = json_packer
385 self.unpack = json_unpacker
386 self.packer = new
387 elif new.lower() == "pickle":
388 self.pack = pickle_packer
389 self.unpack = pickle_unpacker
390 self.packer = new
391 else:
392 self.unpack = import_item(str(new))
    # the UUID identifying this session (lazily generated on first access)
    session = CUnicode("", config=True, help="""The UUID identifying this session.""")

    def _session_default(self) -> str:
        # generate a fresh id and cache its bytes form
        u = new_id()
        self.bsession = u.encode("ascii")
        return u

    @observe("session")
    def _session_changed(self, change):
        # keep the bytes form in sync with the unicode trait
        self.bsession = self.session.encode("ascii")

    # bsession is the session as bytes
    bsession = CBytes(b"")
    username = Unicode(
        os.environ.get("USER", "username"),
        help="""Username for the Session. Default is your system username.""",
        config=True,
    )

    metadata = Dict(
        {},
        config=True,
        help="Metadata dictionary, which serves as the default top-level metadata dict for each "
        "message.",
    )

    # protocol version to adapt outgoing messages to; if 0, no adapting to do.
    adapt_version = Integer(0)
    # message signature related traits:

    key = CBytes(config=True, help="""execution key, for signing messages.""")

    def _key_default(self) -> bytes:
        # default to a fresh random key, so signing is enabled out of the box
        return new_id_bytes()

    @observe("key")
    def _key_changed(self, change):
        # rebuild the HMAC object whenever the key changes
        self._new_auth()
    # e.g. 'hmac-sha256'; the suffix names a hashlib constructor
    signature_scheme = Unicode(
        "hmac-sha256",
        config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""",
    )
442 @observe("signature_scheme")
443 def _signature_scheme_changed(self, change):
444 new = change["new"]
445 if not new.startswith("hmac-"):
446 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
447 hash_name = new.split("-", 1)[1]
448 try:
449 self.digest_mod = getattr(hashlib, hash_name)
450 except AttributeError as e:
451 raise TraitError("hashlib has no such attribute: %s" % hash_name) from e
452 self._new_auth()
    # hash constructor from hashlib, derived from signature_scheme
    digest_mod = Any()

    def _digest_mod_default(self) -> t.Callable:
        return hashlib.sha256

    # the live HMAC object, or None when signing is disabled (empty key)
    auth = Instance(hmac.HMAC, allow_none=True)
461 def _new_auth(self) -> None:
462 if self.key:
463 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
464 else:
465 self.auth = None
    # signatures already seen, kept for replay protection
    digest_history = Set()
    digest_history_size = Integer(
        2**16,
        config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """,
    )

    keyfile = Unicode("", config=True, help="""path to file containing execution key.""")
479 @observe("keyfile")
480 def _keyfile_changed(self, change):
481 with open(change["new"], "rb") as f:
482 self.key = f.read().strip()
    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer)  # the actual packer function
491 @observe("pack")
492 def _pack_changed(self, change):
493 new = change["new"]
494 if not callable(new):
495 raise TypeError("packer must be callable, not %s" % type(new))
    unpack = Any(default_unpacker)  # the actual unpacker function
499 @observe("unpack")
500 def _unpack_changed(self, change):
501 # unpacker is not checked - it is assumed to be
502 new = change["new"]
503 if not callable(new):
504 raise TypeError("unpacker must be callable, not %s" % type(new))
    # thresholds:
    copy_threshold = Integer(
        2**16,
        config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.",
    )
    buffer_threshold = Integer(
        MAX_BYTES,
        config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid "
        "pickling.",
    )
    item_threshold = Integer(
        MAX_ITEMS,
        config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """,
    )
    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts. If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object. The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers. The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature. If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key. If this is set, `key` will be
            initialized to the contents of the file.
        """
        super().__init__(**kwargs)
        self._check_packers()
        # pre-packed empty dict, reused for messages with no content
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session  # noqa
        # remember the creating process, to detect sends from forks later
        self.pid = os.getpid()
        self._new_auth()
        if not self.key:
            get_logger().warning(
                "Message signing is disabled. This is insecure and not recommended!"
            )
578 def clone(self) -> "Session":
579 """Create a copy of this Session
581 Useful when connecting multiple times to a given kernel.
582 This prevents a shared digest_history warning about duplicate digests
583 due to multiple connections to IOPub in the same process.
585 .. versionadded:: 5.1
586 """
587 # make a copy
588 new_session = type(self)()
589 for name in self.traits():
590 setattr(new_session, name, getattr(self, name))
591 # fork digest_history
592 new_session.digest_history = set()
593 new_session.digest_history.update(self.digest_history)
594 return new_session
596 message_count = 0
598 @property
599 def msg_id(self) -> str:
600 message_number = self.message_count
601 self.message_count += 1
602 return f"{self.session}_{os.getpid()}_{message_number}"
    def _check_packers(self) -> None:
        """check packers for datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg_list = {"a": [1, "hi"]}
        try:
            packed = pack(msg_list)
        except Exception as e:
            msg = f"packer '{self.packer}' could not serialize a simple message: {e}"
            raise ValueError(msg) from e

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" % type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            assert unpacked == msg_list
        except Exception as e:
            msg = (
                f"unpacker '{self.unpacker}' could not handle output from packer"
                f" '{self.packer}': {e}"
            )
            raise ValueError(msg) from e

        # check datetime support
        msg_datetime = {"t": utcnow()}
        try:
            unpacked = unpack(pack(msg_datetime))
            if isinstance(unpacked["t"], datetime):
                # NOTE: raising here is deliberate -- the ValueError falls
                # through to the except below, which installs the wrappers.
                msg = "Shouldn't deserialize to datetime"
                raise ValueError(msg)
        except Exception:
            # packer can't round-trip datetimes cleanly: wrap pack/unpack so
            # datetimes are squashed to ISO8601 strings before packing
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)
    def msg_header(self, msg_type: str) -> t.Dict[str, t.Any]:
        """Create a header for a message type."""
        # delegates to the module-level msg_header with a fresh msg_id
        return msg_header(self.msg_id, msg_type, self.username, self.session)
647 def msg(
648 self,
649 msg_type: str,
650 content: t.Optional[t.Dict] = None,
651 parent: t.Optional[t.Dict[str, t.Any]] = None,
652 header: t.Optional[t.Dict[str, t.Any]] = None,
653 metadata: t.Optional[t.Dict[str, t.Any]] = None,
654 ) -> t.Dict[str, t.Any]:
655 """Return the nested message dict.
657 This format is different from what is sent over the wire. The
658 serialize/deserialize methods converts this nested message dict to the wire
659 format, which is a list of message parts.
660 """
661 msg = {}
662 header = self.msg_header(msg_type) if header is None else header
663 msg["header"] = header
664 msg["msg_id"] = header["msg_id"]
665 msg["msg_type"] = header["msg_type"]
666 msg["parent_header"] = {} if parent is None else extract_header(parent)
667 msg["content"] = {} if content is None else content
668 msg["metadata"] = self.metadata.copy()
669 if metadata is not None:
670 msg["metadata"].update(metadata)
671 return msg
673 def sign(self, msg_list: t.List) -> bytes:
674 """Sign a message with HMAC digest. If no auth, return b''.
676 Parameters
677 ----------
678 msg_list : list
679 The [p_header,p_parent,p_content] part of the message list.
680 """
681 if self.auth is None:
682 return b""
683 h = self.auth.copy()
684 for m in msg_list:
685 h.update(m)
686 return h.hexdigest().encode()
688 def serialize(
689 self,
690 msg: t.Dict[str, t.Any],
691 ident: t.Optional[t.Union[t.List[bytes], bytes]] = None,
692 ) -> t.List[bytes]:
693 """Serialize the message components to bytes.
695 This is roughly the inverse of deserialize. The serialize/deserialize
696 methods work with full message lists, whereas pack/unpack work with
697 the individual message parts in the message list.
699 Parameters
700 ----------
701 msg : dict or Message
702 The next message dict as returned by the self.msg method.
704 Returns
705 -------
706 msg_list : list
707 The list of bytes objects to be sent with the format::
709 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
710 p_metadata, p_content, buffer1, buffer2, ...]
712 In this list, the ``p_*`` entities are the packed or serialized
713 versions, so if JSON is used, these are utf8 encoded JSON strings.
714 """
715 content = msg.get("content", {})
716 if content is None:
717 content = self.none
718 elif isinstance(content, dict):
719 content = self.pack(content)
720 elif isinstance(content, bytes):
721 # content is already packed, as in a relayed message
722 pass
723 elif isinstance(content, str):
724 # should be bytes, but JSON often spits out unicode
725 content = content.encode("utf8")
726 else:
727 raise TypeError("Content incorrect type: %s" % type(content))
729 real_message = [
730 self.pack(msg["header"]),
731 self.pack(msg["parent_header"]),
732 self.pack(msg["metadata"]),
733 content,
734 ]
736 to_send = []
738 if isinstance(ident, list):
739 # accept list of idents
740 to_send.extend(ident)
741 elif ident is not None:
742 to_send.append(ident)
743 to_send.append(DELIM)
745 signature = self.sign(real_message)
746 to_send.append(signature)
748 to_send.extend(real_message)
750 return to_send
752 def send(
753 self,
754 stream: Optional[Union[zmq.sugar.socket.Socket, ZMQStream]],
755 msg_or_type: t.Union[t.Dict[str, t.Any], str],
756 content: t.Optional[t.Dict[str, t.Any]] = None,
757 parent: t.Optional[t.Dict[str, t.Any]] = None,
758 ident: t.Optional[t.Union[bytes, t.List[bytes]]] = None,
759 buffers: t.Optional[t.List[bytes]] = None,
760 track: bool = False,
761 header: t.Optional[t.Dict[str, t.Any]] = None,
762 metadata: t.Optional[t.Dict[str, t.Any]] = None,
763 ) -> t.Optional[t.Dict[str, t.Any]]:
764 """Build and send a message via stream or socket.
766 The message format used by this function internally is as follows:
768 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
769 buffer1,buffer2,...]
771 The serialize/deserialize methods convert the nested message dict into this
772 format.
774 Parameters
775 ----------
777 stream : zmq.Socket or ZMQStream
778 The socket-like object used to send the data.
779 msg_or_type : str or Message/dict
780 Normally, msg_or_type will be a msg_type unless a message is being
781 sent more than once. If a header is supplied, this can be set to
782 None and the msg_type will be pulled from the header.
784 content : dict or None
785 The content of the message (ignored if msg_or_type is a message).
786 header : dict or None
787 The header dict for the message (ignored if msg_to_type is a message).
788 parent : Message or dict or None
789 The parent or parent header describing the parent of this message
790 (ignored if msg_or_type is a message).
791 ident : bytes or list of bytes
792 The zmq.IDENTITY routing path.
793 metadata : dict or None
794 The metadata describing the message
795 buffers : list or None
796 The already-serialized buffers to be appended to the message.
797 track : bool
798 Whether to track. Only for use with Sockets, because ZMQStream
799 objects cannot track messages.
802 Returns
803 -------
804 msg : dict
805 The constructed message.
806 """
807 if not isinstance(stream, zmq.Socket):
808 # ZMQStreams and dummy sockets do not support tracking.
809 track = False
811 if isinstance(stream, zmq.asyncio.Socket):
812 assert stream is not None
813 stream = zmq.Socket.shadow(stream.underlying)
815 if isinstance(msg_or_type, (Message, dict)):
816 # We got a Message or message dict, not a msg_type so don't
817 # build a new Message.
818 msg = msg_or_type
819 buffers = buffers or msg.get("buffers", [])
820 else:
821 msg = self.msg(
822 msg_or_type,
823 content=content,
824 parent=parent,
825 header=header,
826 metadata=metadata,
827 )
828 if self.check_pid and os.getpid() != self.pid:
829 get_logger().warning("WARNING: attempted to send message from fork\n%s", msg)
830 return None
831 buffers = [] if buffers is None else buffers
832 for idx, buf in enumerate(buffers):
833 if isinstance(buf, memoryview):
834 view = buf
835 else:
836 try:
837 # check to see if buf supports the buffer protocol.
838 view = memoryview(buf)
839 except TypeError as e:
840 emsg = "Buffer objects must support the buffer protocol."
841 raise TypeError(emsg) from e
842 # memoryview.contiguous is new in 3.3,
843 # just skip the check on Python 2
844 if hasattr(view, "contiguous") and not view.contiguous:
845 # zmq requires memoryviews to be contiguous
846 raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))
848 if self.adapt_version:
849 msg = adapt(msg, self.adapt_version)
850 to_send = self.serialize(msg, ident)
851 to_send.extend(buffers)
852 longest = max([len(s) for s in to_send])
853 copy = longest < self.copy_threshold
855 if stream and buffers and track and not copy:
856 # only really track when we are doing zero-copy buffers
857 tracker = stream.send_multipart(to_send, copy=False, track=True)
858 elif stream:
859 # use dummy tracker, which will be done immediately
860 tracker = DONE
861 stream.send_multipart(to_send, copy=copy)
863 if self.debug:
864 pprint.pprint(msg) # noqa
865 pprint.pprint(to_send) # noqa
866 pprint.pprint(buffers) # noqa
868 msg["tracker"] = tracker
870 return msg
872 def send_raw(
873 self,
874 stream: zmq.sugar.socket.Socket,
875 msg_list: t.List,
876 flags: int = 0,
877 copy: bool = True,
878 ident: t.Optional[t.Union[bytes, t.List[bytes]]] = None,
879 ) -> None:
880 """Send a raw message via ident path.
882 This method is used to send a already serialized message.
884 Parameters
885 ----------
886 stream : ZMQStream or Socket
887 The ZMQ stream or socket to use for sending the message.
888 msg_list : list
889 The serialized list of messages to send. This only includes the
890 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
891 the message.
892 ident : ident or list
893 A single ident or a list of idents to use in sending.
894 """
895 to_send = []
896 if isinstance(ident, bytes):
897 ident = [ident]
898 if ident is not None:
899 to_send.extend(ident)
901 to_send.append(DELIM)
902 # Don't include buffers in signature (per spec).
903 to_send.append(self.sign(msg_list[0:4]))
904 to_send.extend(msg_list)
905 if isinstance(stream, zmq.asyncio.Socket):
906 stream = zmq.Socket.shadow(stream.underlying)
907 stream.send_multipart(to_send, flags, copy=copy)
909 def recv(
910 self,
911 socket: zmq.sugar.socket.Socket,
912 mode: int = zmq.NOBLOCK,
913 content: bool = True,
914 copy: bool = True,
915 ) -> t.Tuple[t.Optional[t.List[bytes]], t.Optional[t.Dict[str, t.Any]]]:
916 """Receive and unpack a message.
918 Parameters
919 ----------
920 socket : ZMQStream or Socket
921 The socket or stream to use in receiving.
923 Returns
924 -------
925 [idents], msg
926 [idents] is a list of idents and msg is a nested message dict of
927 same format as self.msg returns.
928 """
929 if isinstance(socket, ZMQStream):
930 socket = socket.socket
931 if isinstance(socket, zmq.asyncio.Socket):
932 socket = zmq.Socket.shadow(socket.underlying)
934 try:
935 msg_list = socket.recv_multipart(mode, copy=copy)
936 except zmq.ZMQError as e:
937 if e.errno == zmq.EAGAIN:
938 # We can convert EAGAIN to None as we know in this case
939 # recv_multipart won't return None.
940 return None, None
941 else:
942 raise
943 # split multipart message into identity list and message dict
944 # invalid large messages can cause very expensive string comparisons
945 idents, msg_list = self.feed_identities(msg_list, copy)
946 try:
947 return idents, self.deserialize(msg_list, content=content, copy=copy)
948 except Exception as e:
949 # TODO: handle it
950 raise e
952 def feed_identities(
953 self, msg_list: t.Union[t.List[bytes], t.List[zmq.Message]], copy: bool = True
954 ) -> t.Tuple[t.List[bytes], t.Union[t.List[bytes], t.List[zmq.Message]]]:
955 """Split the identities from the rest of the message.
957 Feed until DELIM is reached, then return the prefix as idents and
958 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
959 but that would be silly.
961 Parameters
962 ----------
963 msg_list : a list of Message or bytes objects
964 The message to be split.
965 copy : bool
966 flag determining whether the arguments are bytes or Messages
968 Returns
969 -------
970 (idents, msg_list) : two lists
971 idents will always be a list of bytes, each of which is a ZMQ
972 identity. msg_list will be a list of bytes or zmq.Messages of the
973 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
974 should be unpackable/unserializable via self.deserialize at this
975 point.
976 """
977 if copy:
978 msg_list = t.cast(t.List[bytes], msg_list)
979 idx = msg_list.index(DELIM)
980 return msg_list[:idx], msg_list[idx + 1 :]
981 else:
982 msg_list = t.cast(t.List[zmq.Message], msg_list)
983 failed = True
984 for idx, m in enumerate(msg_list): # noqa
985 if m.bytes == DELIM:
986 failed = False
987 break
988 if failed:
989 msg = "DELIM not in msg_list"
990 raise ValueError(msg)
991 idents, msg_list = msg_list[:idx], msg_list[idx + 1 :]
992 return [bytes(m.bytes) for m in idents], msg_list
994 def _add_digest(self, signature: bytes) -> None:
995 """add a digest to history to protect against replay attacks"""
996 if self.digest_history_size == 0:
997 # no history, never add digests
998 return
1000 self.digest_history.add(signature)
1001 if len(self.digest_history) > self.digest_history_size:
1002 # threshold reached, cull 10%
1003 self._cull_digest_history()
1005 def _cull_digest_history(self) -> None:
1006 """cull the digest history
1008 Removes a randomly selected 10% of the digest history
1009 """
1010 current = len(self.digest_history)
1011 n_to_cull = max(int(current // 10), current - self.digest_history_size)
1012 if n_to_cull >= current:
1013 self.digest_history = set()
1014 return
1015 to_cull = random.sample(tuple(sorted(self.digest_history)), n_to_cull)
1016 self.digest_history.difference_update(to_cull)
    def deserialize(
        self,
        msg_list: t.Union[t.List[bytes], t.List[zmq.Message]],
        content: bool = True,
        copy: bool = True,
    ) -> t.Dict[str, t.Any]:
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list contains bytes (True) or the non-copying Message
            objects in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers]. The buffers are returned as memoryviews.
        """
        minlen = 5
        message = {}
        if not copy:
            # pyzmq didn't copy the first parts of the message, so we'll do it
            msg_list = t.cast(t.List[zmq.Message], msg_list)
            msg_list_beginning = [bytes(msg.bytes) for msg in msg_list[:minlen]]
            msg_list = t.cast(t.List[bytes], msg_list)
            msg_list = msg_list_beginning + msg_list[minlen:]
        msg_list = t.cast(t.List[bytes], msg_list)
        if self.auth is not None:
            # signing is enabled: frame 0 must hold a valid HMAC of frames 1-4
            signature = msg_list[0]
            if not signature:
                msg = "Unsigned Message"
                raise ValueError(msg)
            # replay protection: reject signatures we have already accepted
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            if content:
                # Only store signature if we are unpacking content, don't store if just peeking.
                self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            # compare_digest limits the surface of timing attacks
            if not compare_digest(signature, check):
                msg = "Invalid Signature: %r" % signature
                raise ValueError(msg)
        if not len(msg_list) >= minlen:
            msg = "malformed message, must have at least %i elements" % minlen
            raise TypeError(msg)
        header = self.unpack(msg_list[1])
        message["header"] = extract_dates(header)
        message["msg_id"] = header["msg_id"]
        message["msg_type"] = header["msg_type"]
        message["parent_header"] = extract_dates(self.unpack(msg_list[2]))
        message["metadata"] = self.unpack(msg_list[3])
        if content:
            message["content"] = self.unpack(msg_list[4])
        else:
            # leave content packed, e.g. for cheap relaying
            message["content"] = msg_list[4]
        buffers = [memoryview(b) for b in msg_list[5:]]
        if buffers and buffers[0].shape is None:
            # force copy to workaround pyzmq #646
            msg_list = t.cast(t.List[zmq.Message], msg_list)
            buffers = [memoryview(bytes(b.bytes)) for b in msg_list[5:]]
        message["buffers"] = buffers
        if self.debug:
            pprint.pprint(message)  # noqa
        # adapt to the current version
        return adapt(message)
    def unserialize(self, *args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """**DEPRECATED** Use deserialize instead."""
        # pragma: no cover
        # warn, then forward everything to the modern name
        warnings.warn(
            "Session.unserialize is deprecated. Use Session.deserialize.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.deserialize(*args, **kwargs)