Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/jupyter_client/session.py: 30%


444 statements  

1"""Session object for building, serializing, sending, and receiving messages. 

2 

3The Session object supports serialization, HMAC signatures, 

4and metadata on messages. 

5 

6Also defined here are utilities for working with Sessions: 

7* A SessionFactory to be used as a base class for configurables that work with 

8Sessions. 

9* A Message object for convenience that allows attribute-access to the msg dict. 

10""" 

11# Copyright (c) Jupyter Development Team. 

12# Distributed under the terms of the Modified BSD License. 

13from __future__ import annotations 

14 

15import hashlib 

16import hmac 

17import json 

18import logging 

19import os 

20import pickle 

21import pprint 

22import random 

23import typing as t 

24import warnings 

25from binascii import b2a_hex 

26from datetime import datetime, timezone 

27from hmac import compare_digest 

28 

29# We are using compare_digest to limit the surface of timing attacks 

30import zmq.asyncio 

31from tornado.ioloop import IOLoop 

32from traitlets import ( 

33 Any, 

34 Bool, 

35 CBytes, 

36 CUnicode, 

37 Dict, 

38 DottedObjectName, 

39 Instance, 

40 Integer, 

41 Set, 

42 TraitError, 

43 Unicode, 

44 observe, 

45) 

46from traitlets.config.configurable import Configurable, LoggingConfigurable 

47from traitlets.log import get_logger 

48from traitlets.utils.importstring import import_item 

49from zmq.eventloop.zmqstream import ZMQStream 

50 

51from ._version import protocol_version 

52from .adapter import adapt 

53from .jsonutil import extract_dates, json_clean, json_default, squash_dates 

54 

55PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL 

56 

57utc = timezone.utc 

58 

59# ----------------------------------------------------------------------------- 

60# utility functions 

61# ----------------------------------------------------------------------------- 

62 

63 

64def squash_unicode(obj: t.Any) -> t.Any: 

65 """coerce unicode back to bytestrings.""" 

66 if isinstance(obj, dict): 

67 for key in list(obj.keys()): 

68 obj[key] = squash_unicode(obj[key]) 

69 if isinstance(key, str): 

70 obj[squash_unicode(key)] = obj.pop(key) 

71 elif isinstance(obj, list): 

72 for i, v in enumerate(obj): 

73 obj[i] = squash_unicode(v) 

74 elif isinstance(obj, str): 

75 obj = obj.encode("utf8") 

76 return obj 
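
# Illustrative example (sketch, not part of jupyter_client): squash_unicode
# mutates containers in place, encoding str keys and values to UTF-8 bytes.
def _squash_unicode_example() -> None:
    data = {"msg": ["hi", {"nested": "ok"}]}
    assert squash_unicode(data) == {b"msg": [b"hi", {b"nested": b"ok"}]}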

77 

78 

79# ----------------------------------------------------------------------------- 

80# globals and defaults 

81# ----------------------------------------------------------------------------- 

82 

83# default values for the thresholds: 

84MAX_ITEMS = 64 

85MAX_BYTES = 1024 

86 

87# ISO8601-ify datetime objects 

88# allow unicode 

89# disallow nan, because it's not actually valid JSON 

90 

91 

92def json_packer(obj: t.Any) -> bytes: 

93 """Convert a json object to a bytes.""" 

94 try: 

95 return json.dumps( 

96 obj, 

97 default=json_default, 

98 ensure_ascii=False, 

99 allow_nan=False, 

100 ).encode("utf8", errors="surrogateescape") 

101 except (TypeError, ValueError) as e: 

102 # Fallback to trying to clean the json before serializing 

103 packed = json.dumps( 

104 json_clean(obj), 

105 default=json_default, 

106 ensure_ascii=False, 

107 allow_nan=False, 

108 ).encode("utf8", errors="surrogateescape") 

109 

110 warnings.warn( 

111 f"Message serialization failed with:\n{e}\n" 

112 "Supporting this message is deprecated in jupyter-client 7, please make " 

113 "sure your message is JSON-compliant", 

114 stacklevel=2, 

115 ) 

116 

117 return packed 

118 

119 

120def json_unpacker(s: str | bytes) -> t.Any: 

121 """Convert a json bytes or string to an object.""" 

122 if isinstance(s, bytes): 

123 s = s.decode("utf8", "replace") 

124 return json.loads(s) 
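
# Illustrative example (sketch, not part of jupyter_client): json_packer emits
# UTF-8 bytes and json_unpacker is its inverse for JSON-compliant input;
# non-compliant objects go through the json_clean fallback above and raise the
# deprecation warning.
def _json_roundtrip_example() -> None:
    msg = {"a": [1, "hi"], "ok": True}
    packed = json_packer(msg)
    assert isinstance(packed, bytes)
    assert json_unpacker(packed) == msg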

125 

126 

127def pickle_packer(o: t.Any) -> bytes: 

128 """Pack an object using the pickle module.""" 

129 return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL) 

130 

131 

132pickle_unpacker = pickle.loads 

133 

134default_packer = json_packer 

135default_unpacker = json_unpacker 

136 

137DELIM = b"<IDS|MSG>" 

138# singleton dummy tracker, which will always report as done 

139DONE = zmq.MessageTracker() 

140 

141# ----------------------------------------------------------------------------- 

142# Mixin tools for apps that use Sessions 

143# ----------------------------------------------------------------------------- 

144 

145 

146def new_id() -> str: 

147 """Generate a new random id. 

148 

149 Avoids problematic runtime import in stdlib uuid on Python 2. 

150 

151 Returns 

152 ------- 

153 

154 id string (16 random bytes as hex-encoded text, chunks separated by '-') 

155 """ 

156 buf = os.urandom(16) 

157 return "-".join(b2a_hex(x).decode("ascii") for x in (buf[:4], buf[4:])) 

158 

159 

160def new_id_bytes() -> bytes: 

161 """Return new_id as ascii bytes""" 

162 return new_id().encode("ascii") 

163 

164 

165session_aliases = { 

166 "ident": "Session.session", 

167 "user": "Session.username", 

168 "keyfile": "Session.keyfile", 

169} 

170 

171session_flags = { 

172 "secure": ( 

173 {"Session": {"key": new_id_bytes(), "keyfile": ""}}, 

174 """Use HMAC digests for authentication of messages. 

175 Setting this flag will generate a new UUID to use as the HMAC key. 

176 """, 

177 ), 

178 "no-secure": ( 

179 {"Session": {"key": b"", "keyfile": ""}}, 

180 """Don't authenticate messages.""", 

181 ), 

182} 

183 

184 

185def default_secure(cfg: t.Any) -> None: # pragma: no cover 

186 """Set the default behavior for a config environment to be secure. 

187 

188 If Session.key/keyfile have not been set, set Session.key to 

189 a new random UUID. 

190 """ 

191 warnings.warn("default_secure is deprecated", DeprecationWarning, stacklevel=2) 

192 if "Session" in cfg and ("key" in cfg.Session or "keyfile" in cfg.Session): 

193 return 

194 # key/keyfile not specified, generate new UUID: 

195 cfg.Session.key = new_id_bytes() 

196 

197 

198def utcnow() -> datetime: 

199 """Return timezone-aware UTC timestamp""" 

200 return datetime.now(utc) 

201 

202 

203# ----------------------------------------------------------------------------- 

204# Classes 

205# ----------------------------------------------------------------------------- 

206 

207 

208class SessionFactory(LoggingConfigurable): 

209 """The Base class for configurables that have a Session, Context, logger, 

210 and IOLoop. 

211 """ 

212 

213 logname = Unicode("") 

214 

215 @observe("logname") 

216 def _logname_changed(self, change: t.Any) -> None: 

217 self.log = logging.getLogger(change["new"]) 

218 

219 # not configurable: 

220 context = Instance("zmq.Context") 

221 

222 def _context_default(self) -> zmq.Context: 

223 return zmq.Context() 

224 

225 session = Instance("jupyter_client.session.Session", allow_none=True) 

226 

227 loop = Instance("tornado.ioloop.IOLoop") 

228 

229 def _loop_default(self) -> IOLoop: 

230 return IOLoop.current() 

231 

232 def __init__(self, **kwargs: t.Any) -> None: 

233 """Initialize a session factory.""" 

234 super().__init__(**kwargs) 

235 

236 if self.session is None: 

237 # construct the session 

238 self.session = Session(**kwargs) 

239 

240 

241class Message: 

242 """A simple message object that maps dict keys to attributes. 

243 

244 A Message can be created from a dict and a dict from a Message instance 

245 simply by calling dict(msg_obj).""" 

246 

247 def __init__(self, msg_dict: dict[str, t.Any]) -> None: 

248 """Initialize a message.""" 

249 dct = self.__dict__ 

250 for k, v in dict(msg_dict).items(): 

251 if isinstance(v, dict): 

252 v = Message(v) # noqa 

253 dct[k] = v 

254 

255 # Having this iterator lets dict(msg_obj) work out of the box. 

256 def __iter__(self) -> t.ItemsView[str, t.Any]: 

257 return iter(self.__dict__.items()) # type:ignore[return-value] 

258 

259 def __repr__(self) -> str: 

260 return repr(self.__dict__) 

261 

262 def __str__(self) -> str: 

263 return pprint.pformat(self.__dict__) 

264 

265 def __contains__(self, k: object) -> bool: 

266 return k in self.__dict__ 

267 

268 def __getitem__(self, k: str) -> t.Any: 

269 return self.__dict__[k] 

270 

271 

272def msg_header( 

273 msg_id: str, msg_type: str, username: str, session: Session | str 

274) -> dict[str, t.Any]: 

275 """Create a new message header""" 

276 date = utcnow() 

277 version = protocol_version 

278 return locals() 
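
# Illustrative example (sketch, not part of jupyter_client): the header returned
# above is just the function's local namespace, i.e. a dict with the keys
# msg_id, msg_type, username, session, date and version.
def _msg_header_example() -> None:
    header = msg_header(new_id(), "kernel_info_request", "user", new_id())
    assert set(header) == {"msg_id", "msg_type", "username", "session", "date", "version"}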

279 

280 

281def extract_header(msg_or_header: dict[str, t.Any]) -> dict[str, t.Any]: 

282 """Given a message or header, return the header.""" 

283 if not msg_or_header: 

284 return {} 

285 try: 

286 # See if msg_or_header is the entire message. 

287 h = msg_or_header["header"] 

288 except KeyError: 

289 try: 

290 # See if msg_or_header is just the header 

291 h = msg_or_header["msg_id"] 

292 except KeyError: 

293 raise 

294 else: 

295 h = msg_or_header 

296 if not isinstance(h, dict): 

297 h = dict(h) 

298 return h 

299 

300 

301class Session(Configurable): 

302 """Object for handling serialization and sending of messages. 

303 

304 The Session object handles building messages and sending them 

305 with ZMQ sockets or ZMQStream objects. Objects can communicate with each 

306 other over the network via Session objects, and only need to work with the 

307 dict-based IPython message spec. The Session will handle 

308 serialization/deserialization, security, and metadata. 

309 

310 Sessions support configurable serialization via packer/unpacker traits, 

311 and signing with HMAC digests via the key/keyfile traits. 

312 

313 Parameters 

314 ---------- 

315 

316 debug : bool 

317 whether to trigger extra debugging statements 

318 packer/unpacker : str : 'json', 'pickle' or import_string 

319 importstrings for methods to serialize message parts. If just 

320 'json' or 'pickle', predefined JSON and pickle packers will be used. 

321 Otherwise, the entire importstring must be used. 

322 

323 The functions must accept at least valid JSON input, and output *bytes*. 

324 

325 For example, to use msgpack: 

326 packer = 'msgpack.packb', unpacker='msgpack.unpackb' 

327 pack/unpack : callables 

328 You can also set the pack/unpack callables for serialization directly. 

329 session : bytes 

330 the ID of this Session object. The default is to generate a new UUID. 

331 username : unicode 

332 username added to message headers. The default is to ask the OS. 

333 key : bytes 

334 The key used to initialize an HMAC signature. If unset, messages 

335 will not be signed or checked. 

336 keyfile : filepath 

337 The file containing a key. If this is set, `key` will be initialized 

338 to the contents of the file. 

339 

340 """ 

341 

342 debug = Bool(False, config=True, help="""Debug output in the Session""") 

343 

344 check_pid = Bool( 

345 True, 

346 config=True, 

347 help="""Whether to check PID to protect against calls after fork. 

348 

349 This check can be disabled if fork-safety is handled elsewhere. 

350 """, 

351 ) 

352 

353 packer = DottedObjectName( 

354 "json", 

355 config=True, 

356 help="""The name of the packer for serializing messages. 

357 Should be one of 'json', 'pickle', or an import name 

358 for a custom callable serializer.""", 

359 ) 

360 

361 @observe("packer") 

362 def _packer_changed(self, change: t.Any) -> None: 

363 new = change["new"] 

364 if new.lower() == "json": 

365 self.pack = json_packer 

366 self.unpack = json_unpacker 

367 self.unpacker = new 

368 elif new.lower() == "pickle": 

369 self.pack = pickle_packer 

370 self.unpack = pickle_unpacker 

371 self.unpacker = new 

372 else: 

373 self.pack = import_item(str(new)) 

374 

375 unpacker = DottedObjectName( 

376 "json", 

377 config=True, 

378 help="""The name of the unpacker for unserializing messages. 

379 Only used with custom functions for `packer`.""", 

380 ) 

381 

382 @observe("unpacker") 

383 def _unpacker_changed(self, change: t.Any) -> None: 

384 new = change["new"] 

385 if new.lower() == "json": 

386 self.pack = json_packer 

387 self.unpack = json_unpacker 

388 self.packer = new 

389 elif new.lower() == "pickle": 

390 self.pack = pickle_packer 

391 self.unpack = pickle_unpacker 

392 self.packer = new 

393 else: 

394 self.unpack = import_item(str(new)) 

395 

396 session = CUnicode("", config=True, help="""The UUID identifying this session.""") 

397 

398 def _session_default(self) -> str: 

399 u = new_id() 

400 self.bsession = u.encode("ascii") 

401 return u 

402 

403 @observe("session") 

404 def _session_changed(self, change: t.Any) -> None: 

405 self.bsession = self.session.encode("ascii") 

406 

407 # bsession is the session as bytes 

408 bsession = CBytes(b"") 

409 

410 username = Unicode( 

411 os.environ.get("USER", "username"), 

412 help="""Username for the Session. Default is your system username.""", 

413 config=True, 

414 ) 

415 

416 metadata = Dict( 

417 {}, 

418 config=True, 

419 help="Metadata dictionary, which serves as the default top-level metadata dict for each " 

420 "message.", 

421 ) 

422 

423 # if 0, no adapting to do. 

424 adapt_version = Integer(0) 

425 

426 # message signature related traits: 

427 

428 key = CBytes(config=True, help="""execution key, for signing messages.""") 

429 

430 def _key_default(self) -> bytes: 

431 return new_id_bytes() 

432 

433 @observe("key") 

434 def _key_changed(self, change: t.Any) -> None: 

435 self._new_auth() 

436 

437 signature_scheme = Unicode( 

438 "hmac-sha256", 

439 config=True, 

440 help="""The digest scheme used to construct the message signatures. 

441 Must have the form 'hmac-HASH'.""", 

442 ) 

443 

444 @observe("signature_scheme") 

445 def _signature_scheme_changed(self, change: t.Any) -> None: 

446 new = change["new"] 

447 if not new.startswith("hmac-"): 

448 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new) 

449 hash_name = new.split("-", 1)[1] 

450 try: 

451 self.digest_mod = getattr(hashlib, hash_name) 

452 except AttributeError as e: 

453 raise TraitError("hashlib has no such attribute: %s" % hash_name) from e 

454 self._new_auth() 

455 

456 digest_mod = Any() 

457 

458 def _digest_mod_default(self) -> t.Callable: 

459 return hashlib.sha256 

460 

461 auth = Instance(hmac.HMAC, allow_none=True) 

462 

463 def _new_auth(self) -> None: 

464 if self.key: 

465 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod) 

466 else: 

467 self.auth = None 

468 

469 digest_history = Set() 

470 digest_history_size = Integer( 

471 2**16, 

472 config=True, 

473 help="""The maximum number of digests to remember. 

474 

475 The digest history will be culled when it exceeds this value. 

476 """, 

477 ) 

478 

479 keyfile = Unicode("", config=True, help="""path to file containing execution key.""") 

480 

481 @observe("keyfile") 

482 def _keyfile_changed(self, change: t.Any) -> None: 

483 with open(change["new"], "rb") as f: 

484 self.key = f.read().strip() 

485 

486 # for protecting against sends from forks 

487 pid = Integer() 

488 

489 # serialization traits: 

490 

491 pack = Any(default_packer) # the actual packer function 

492 

493 @observe("pack") 

494 def _pack_changed(self, change: t.Any) -> None: 

495 new = change["new"] 

496 if not callable(new): 

497 raise TypeError("packer must be callable, not %s" % type(new)) 

498 

499 unpack = Any(default_unpacker) # the actual unpacker function 

500 

501 @observe("unpack") 

502 def _unpack_changed(self, change: t.Any) -> None: 

503 # unpacker is not checked - it is assumed to be 

504 new = change["new"] 

505 if not callable(new): 

506 raise TypeError("unpacker must be callable, not %s" % type(new)) 

507 

508 # thresholds: 

509 copy_threshold = Integer( 

510 2**16, 

511 config=True, 

512 help="Threshold (in bytes) beyond which a buffer should be sent without copying.", 

513 ) 

514 buffer_threshold = Integer( 

515 MAX_BYTES, 

516 config=True, 

517 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid " 

518 "pickling.", 

519 ) 

520 item_threshold = Integer( 

521 MAX_ITEMS, 

522 config=True, 

523 help="""The maximum number of items for a container to be introspected for custom serialization. 

524 Containers larger than this are pickled outright. 

525 """, 

526 ) 

527 

528 def __init__(self, **kwargs: t.Any) -> None: 

529 """create a Session object 

530 

531 Parameters 

532 ---------- 

533 

534 debug : bool 

535 whether to trigger extra debugging statements 

536 packer/unpacker : str : 'json', 'pickle' or import_string 

537 importstrings for methods to serialize message parts. If just 

538 'json' or 'pickle', predefined JSON and pickle packers will be used. 

539 Otherwise, the entire importstring must be used. 

540 

541 The functions must accept at least valid JSON input, and output 

542 *bytes*. 

543 

544 For example, to use msgpack: 

545 packer = 'msgpack.packb', unpacker='msgpack.unpackb' 

546 pack/unpack : callables 

547 You can also set the pack/unpack callables for serialization 

548 directly. 

549 session : unicode (must be ascii) 

550 the ID of this Session object. The default is to generate a new 

551 UUID. 

552 bsession : bytes 

553 The session as bytes 

554 username : unicode 

555 username added to message headers. The default is to ask the OS. 

556 key : bytes 

557 The key used to initialize an HMAC signature. If unset, messages 

558 will not be signed or checked. 

559 signature_scheme : str 

560 The message digest scheme. Currently must be of the form 'hmac-HASH', 

561 where 'HASH' is a hashing function available in Python's hashlib. 

562 The default is 'hmac-sha256'. 

563 This is ignored if 'key' is empty. 

564 keyfile : filepath 

565 The file containing a key. If this is set, `key` will be 

566 initialized to the contents of the file. 

567 """ 

568 super().__init__(**kwargs) 

569 self._check_packers() 

570 self.none = self.pack({}) 

571 # ensure self._session_default() if necessary, so bsession is defined: 

572 self.session # noqa 

573 self.pid = os.getpid() 

574 self._new_auth() 

575 if not self.key: 

576 get_logger().warning( 

577 "Message signing is disabled. This is insecure and not recommended!" 

578 ) 

579 

580 def clone(self) -> Session: 

581 """Create a copy of this Session 

582 

583 Useful when connecting multiple times to a given kernel. 

584 This prevents a shared digest_history warning about duplicate digests 

585 due to multiple connections to IOPub in the same process. 

586 

587 .. versionadded:: 5.1 

588 """ 

589 # make a copy 

590 new_session = type(self)() 

591 for name in self.traits(): 

592 setattr(new_session, name, getattr(self, name)) 

593 # fork digest_history 

594 new_session.digest_history = set() 

595 new_session.digest_history.update(self.digest_history) 

596 return new_session 

597 

598 message_count = 0 

599 

600 @property 

601 def msg_id(self) -> str: 

602 message_number = self.message_count 

603 self.message_count += 1 

604 return f"{self.session}_{os.getpid()}_{message_number}" 

605 

606 def _check_packers(self) -> None: 

607 """check packers for datetime support.""" 

608 pack = self.pack 

609 unpack = self.unpack 

610 

611 # check simple serialization 

612 msg_list = {"a": [1, "hi"]} 

613 try: 

614 packed = pack(msg_list) 

615 except Exception as e: 

616 msg = f"packer '{self.packer}' could not serialize a simple message: {e}" 

617 raise ValueError(msg) from e 

618 

619 # ensure packed message is bytes 

620 if not isinstance(packed, bytes): 

621 raise ValueError("message packed to %r, but bytes are required" % type(packed)) 

622 

623 # check that unpack is pack's inverse 

624 try: 

625 unpacked = unpack(packed) 

626 assert unpacked == msg_list 

627 except Exception as e: 

628 msg = ( 

629 f"unpacker '{self.unpacker}' could not handle output from packer" 

630 f" '{self.packer}': {e}" 

631 ) 

632 raise ValueError(msg) from e 

633 

634 # check datetime support 

635 msg_datetime = {"t": utcnow()} 

636 try: 

637 unpacked = unpack(pack(msg_datetime)) 

638 if isinstance(unpacked["t"], datetime): 

639 msg = "Shouldn't deserialize to datetime" 

640 raise ValueError(msg) 

641 except Exception: 

642 self.pack = lambda o: pack(squash_dates(o)) 

643 self.unpack = lambda s: unpack(s) 

644 

645 def msg_header(self, msg_type: str) -> dict[str, t.Any]: 

646 """Create a header for a message type.""" 

647 return msg_header(self.msg_id, msg_type, self.username, self.session) 

648 

649 def msg( 

650 self, 

651 msg_type: str, 

652 content: dict | None = None, 

653 parent: dict[str, t.Any] | None = None, 

654 header: dict[str, t.Any] | None = None, 

655 metadata: dict[str, t.Any] | None = None, 

656 ) -> dict[str, t.Any]: 

657 """Return the nested message dict. 

658 

659 This format is different from what is sent over the wire. The 

660 serialize/deserialize methods converts this nested message dict to the wire 

661 format, which is a list of message parts. 

662 """ 

663 msg = {} 

664 header = self.msg_header(msg_type) if header is None else header 

665 msg["header"] = header 

666 msg["msg_id"] = header["msg_id"] 

667 msg["msg_type"] = header["msg_type"] 

668 msg["parent_header"] = {} if parent is None else extract_header(parent) 

669 msg["content"] = {} if content is None else content 

670 msg["metadata"] = self.metadata.copy() 

671 if metadata is not None: 

672 msg["metadata"].update(metadata) 

673 return msg 
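
# Illustrative sketch (not part of jupyter_client): the nested dict built above,
# e.g. for
#
#     m = session.msg("kernel_info_request", content={})
#
# has the keys header, msg_id, msg_type, parent_header, content and metadata,
# with m["header"] carrying msg_id, msg_type, username, session, date and version.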

674 

675 def sign(self, msg_list: list) -> bytes: 

676 """Sign a message with HMAC digest. If no auth, return b''. 

677 

678 Parameters 

679 ---------- 

680 msg_list : list 

681 The [p_header,p_parent,p_content] part of the message list. 

682 """ 

683 if self.auth is None: 

684 return b"" 

685 h = self.auth.copy() 

686 for m in msg_list: 

687 h.update(m) 

688 return h.hexdigest().encode() 
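
# Illustrative sketch (not part of jupyter_client): the signature is the hex HMAC
# over the four packed frames, which a receiver can recompute and verify; key and
# the p_* frames below are placeholders.
#
#     import hashlib, hmac
#     mac = hmac.new(key, digestmod=hashlib.sha256)
#     for frame in (p_header, p_parent, p_metadata, p_content):
#         mac.update(frame)
#     assert hmac.compare_digest(mac.hexdigest().encode(), signature)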

689 

690 def serialize( 

691 self, 

692 msg: dict[str, t.Any], 

693 ident: list[bytes] | bytes | None = None, 

694 ) -> list[bytes]: 

695 """Serialize the message components to bytes. 

696 

697 This is roughly the inverse of deserialize. The serialize/deserialize 

698 methods work with full message lists, whereas pack/unpack work with 

699 the individual message parts in the message list. 

700 

701 Parameters 

702 ---------- 

703 msg : dict or Message 

704 The next message dict as returned by the self.msg method. 

705 

706 Returns 

707 ------- 

708 msg_list : list 

709 The list of bytes objects to be sent with the format:: 

710 

711 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent, 

712 p_metadata, p_content, buffer1, buffer2, ...] 

713 

714 In this list, the ``p_*`` entities are the packed or serialized 

715 versions, so if JSON is used, these are utf8 encoded JSON strings. 

716 """ 

717 content = msg.get("content", {}) 

718 if content is None: 

719 content = self.none 

720 elif isinstance(content, dict): 

721 content = self.pack(content) 

722 elif isinstance(content, bytes): 

723 # content is already packed, as in a relayed message 

724 pass 

725 elif isinstance(content, str): 

726 # should be bytes, but JSON often spits out unicode 

727 content = content.encode("utf8") 

728 else: 

729 raise TypeError("Content incorrect type: %s" % type(content)) 

730 

731 real_message = [ 

732 self.pack(msg["header"]), 

733 self.pack(msg["parent_header"]), 

734 self.pack(msg["metadata"]), 

735 content, 

736 ] 

737 

738 to_send = [] 

739 

740 if isinstance(ident, list): 

741 # accept list of idents 

742 to_send.extend(ident) 

743 elif ident is not None: 

744 to_send.append(ident) 

745 to_send.append(DELIM) 

746 

747 signature = self.sign(real_message) 

748 to_send.append(signature) 

749 

750 to_send.extend(real_message) 

751 

752 return to_send 
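
# Illustrative sketch (not part of jupyter_client): a message serialized with a
# single routing ident comes out as
#
#     [b"client-ident", DELIM, signature, p_header, p_parent, p_metadata, p_content]
#
# where each p_* frame is the packed (by default UTF-8 JSON) dict.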

753 

754 def send( 

755 self, 

756 stream: zmq.sugar.socket.Socket | ZMQStream | None, 

757 msg_or_type: dict[str, t.Any] | str, 

758 content: dict[str, t.Any] | None = None, 

759 parent: dict[str, t.Any] | None = None, 

760 ident: bytes | list[bytes] | None = None, 

761 buffers: list[bytes] | None = None, 

762 track: bool = False, 

763 header: dict[str, t.Any] | None = None, 

764 metadata: dict[str, t.Any] | None = None, 

765 ) -> dict[str, t.Any] | None: 

766 """Build and send a message via stream or socket. 

767 

768 The message format used by this function internally is as follows: 

769 

770 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content, 

771 buffer1,buffer2,...] 

772 

773 The serialize/deserialize methods convert the nested message dict into this 

774 format. 

775 

776 Parameters 

777 ---------- 

778 

779 stream : zmq.Socket or ZMQStream 

780 The socket-like object used to send the data. 

781 msg_or_type : str or Message/dict 

782 Normally, msg_or_type will be a msg_type unless a message is being 

783 sent more than once. If a header is supplied, this can be set to 

784 None and the msg_type will be pulled from the header. 

785 

786 content : dict or None 

787 The content of the message (ignored if msg_or_type is a message). 

788 header : dict or None 

789 The header dict for the message (ignored if msg_or_type is a message). 

790 parent : Message or dict or None 

791 The parent or parent header describing the parent of this message 

792 (ignored if msg_or_type is a message). 

793 ident : bytes or list of bytes 

794 The zmq.IDENTITY routing path. 

795 metadata : dict or None 

796 The metadata describing the message 

797 buffers : list or None 

798 The already-serialized buffers to be appended to the message. 

799 track : bool 

800 Whether to track. Only for use with Sockets, because ZMQStream 

801 objects cannot track messages. 

802 

803 

804 Returns 

805 ------- 

806 msg : dict 

807 The constructed message. 

808 """ 

809 if not isinstance(stream, zmq.Socket): 

810 # ZMQStreams and dummy sockets do not support tracking. 

811 track = False 

812 

813 if isinstance(stream, zmq.asyncio.Socket): 

814 assert stream is not None # type:ignore[unreachable] 

815 stream = zmq.Socket.shadow(stream.underlying) 

816 

817 if isinstance(msg_or_type, (Message, dict)): 

818 # We got a Message or message dict, not a msg_type so don't 

819 # build a new Message. 

820 msg = msg_or_type 

821 buffers = buffers or msg.get("buffers", []) 

822 else: 

823 msg = self.msg( 

824 msg_or_type, 

825 content=content, 

826 parent=parent, 

827 header=header, 

828 metadata=metadata, 

829 ) 

830 if self.check_pid and os.getpid() != self.pid: 

831 get_logger().warning("WARNING: attempted to send message from fork\n%s", msg) 

832 return None 

833 buffers = [] if buffers is None else buffers 

834 for idx, buf in enumerate(buffers): 

835 if isinstance(buf, memoryview): 

836 view = buf 

837 else: 

838 try: 

839 # check to see if buf supports the buffer protocol. 

840 view = memoryview(buf) 

841 except TypeError as e: 

842 emsg = "Buffer objects must support the buffer protocol." 

843 raise TypeError(emsg) from e 

844 # memoryview.contiguous is new in 3.3, 

845 # just skip the check on Python 2 

846 if hasattr(view, "contiguous") and not view.contiguous: 

847 # zmq requires memoryviews to be contiguous 

848 raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf)) 

849 

850 if self.adapt_version: 

851 msg = adapt(msg, self.adapt_version) 

852 to_send = self.serialize(msg, ident) 

853 to_send.extend(buffers) 

854 longest = max([len(s) for s in to_send]) 

855 copy = longest < self.copy_threshold 

856 

857 if stream and buffers and track and not copy: 

858 # only really track when we are doing zero-copy buffers 

859 tracker = stream.send_multipart(to_send, copy=False, track=True) 

860 elif stream: 

861 # use dummy tracker, which will be done immediately 

862 tracker = DONE 

863 stream.send_multipart(to_send, copy=copy) 

864 else: 

865 tracker = DONE 

866 

867 if self.debug: 

868 pprint.pprint(msg) # noqa 

869 pprint.pprint(to_send) # noqa 

870 pprint.pprint(buffers) # noqa 

871 

872 msg["tracker"] = tracker 

873 

874 return msg 
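
# Illustrative sketch (not part of jupyter_client): typical use on a zmq socket;
# the socket setup and key below are placeholders.
#
#     import zmq
#     ctx = zmq.Context.instance()
#     sock = ctx.socket(zmq.PAIR)
#     sock.bind("tcp://127.0.0.1:5555")
#     session = Session(key=b"secret")
#     sent = session.send(sock, "status", content={"execution_state": "idle"})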

875 

876 def send_raw( 

877 self, 

878 stream: zmq.sugar.socket.Socket, 

879 msg_list: list, 

880 flags: int = 0, 

881 copy: bool = True, 

882 ident: bytes | list[bytes] | None = None, 

883 ) -> None: 

884 """Send a raw message via ident path. 

885 

886 This method is used to send an already serialized message. 

887 

888 Parameters 

889 ---------- 

890 stream : ZMQStream or Socket 

891 The ZMQ stream or socket to use for sending the message. 

892 msg_list : list 

893 The serialized list of messages to send. This only includes the 

894 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of 

895 the message. 

896 ident : ident or list 

897 A single ident or a list of idents to use in sending. 

898 """ 

899 to_send = [] 

900 if isinstance(ident, bytes): 

901 ident = [ident] 

902 if ident is not None: 

903 to_send.extend(ident) 

904 

905 to_send.append(DELIM) 

906 # Don't include buffers in signature (per spec). 

907 to_send.append(self.sign(msg_list[0:4])) 

908 to_send.extend(msg_list) 

909 if isinstance(stream, zmq.asyncio.Socket): 

910 stream = zmq.Socket.shadow(stream.underlying) 

911 stream.send_multipart(to_send, flags, copy=copy) 

912 

913 def recv( 

914 self, 

915 socket: zmq.sugar.socket.Socket, 

916 mode: int = zmq.NOBLOCK, 

917 content: bool = True, 

918 copy: bool = True, 

919 ) -> tuple[list[bytes] | None, dict[str, t.Any] | None]: 

920 """Receive and unpack a message. 

921 

922 Parameters 

923 ---------- 

924 socket : ZMQStream or Socket 

925 The socket or stream to use in receiving. 

926 

927 Returns 

928 ------- 

929 [idents], msg 

930 [idents] is a list of idents and msg is a nested message dict of 

931 same format as self.msg returns. 

932 """ 

933 if isinstance(socket, ZMQStream): # type:ignore[unreachable] 

934 socket = socket.socket # type:ignore[unreachable] 

935 if isinstance(socket, zmq.asyncio.Socket): 

936 socket = zmq.Socket.shadow(socket.underlying) 

937 

938 try: 

939 msg_list = socket.recv_multipart(mode, copy=copy) 

940 except zmq.ZMQError as e: 

941 if e.errno == zmq.EAGAIN: 

942 # We can convert EAGAIN to None as we know in this case 

943 # recv_multipart won't return None. 

944 return None, None 

945 else: 

946 raise 

947 # split multipart message into identity list and message dict 

948 # invalid large messages can cause very expensive string comparisons 

949 idents, msg_list = self.feed_identities(msg_list, copy) 

950 try: 

951 return idents, self.deserialize(msg_list, content=content, copy=copy) 

952 except Exception as e: 

953 # TODO: handle it 

954 raise e 

955 

956 def feed_identities( 

957 self, msg_list: list[bytes] | list[zmq.Message], copy: bool = True 

958 ) -> tuple[list[bytes], list[bytes] | list[zmq.Message]]: 

959 """Split the identities from the rest of the message. 

960 

961 Feed until DELIM is reached, then return the prefix as idents and 

962 remainder as msg_list. This is easily broken by setting an IDENT to DELIM, 

963 but that would be silly. 

964 

965 Parameters 

966 ---------- 

967 msg_list : a list of Message or bytes objects 

968 The message to be split. 

969 copy : bool 

970 flag determining whether the arguments are bytes or Messages 

971 

972 Returns 

973 ------- 

974 (idents, msg_list) : two lists 

975 idents will always be a list of bytes, each of which is a ZMQ 

976 identity. msg_list will be a list of bytes or zmq.Messages of the 

977 form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] and 

978 should be unpackable/unserializable via self.deserialize at this 

979 point. 

980 """ 

981 if copy: 

982 msg_list = t.cast(t.List[bytes], msg_list) 

983 idx = msg_list.index(DELIM) 

984 return msg_list[:idx], msg_list[idx + 1 :] 

985 else: 

986 msg_list = t.cast(t.List[zmq.Message], msg_list) 

987 failed = True 

988 for idx, m in enumerate(msg_list): # noqa 

989 if m.bytes == DELIM: 

990 failed = False 

991 break 

992 if failed: 

993 msg = "DELIM not in msg_list" 

994 raise ValueError(msg) 

995 idents, msg_list = msg_list[:idx], msg_list[idx + 1 :] 

996 return [bytes(m.bytes) for m in idents], msg_list 
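
# Illustrative sketch (not part of jupyter_client): everything before the delimiter
# is routing information, everything after is the signed message; sig, h, p, m and c
# below are placeholder frames.
#
#     idents, rest = session.feed_identities([b"router-id", DELIM, sig, h, p, m, c])
#     # idents == [b"router-id"]; rest == [sig, h, p, m, c]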

997 

998 def _add_digest(self, signature: bytes) -> None: 

999 """add a digest to history to protect against replay attacks""" 

1000 if self.digest_history_size == 0: 

1001 # no history, never add digests 

1002 return 

1003 

1004 self.digest_history.add(signature) 

1005 if len(self.digest_history) > self.digest_history_size: 

1006 # threshold reached, cull 10% 

1007 self._cull_digest_history() 

1008 

1009 def _cull_digest_history(self) -> None: 

1010 """cull the digest history 

1011 

1012 Removes a randomly selected 10% of the digest history 

1013 """ 

1014 current = len(self.digest_history) 

1015 n_to_cull = max(int(current // 10), current - self.digest_history_size) 

1016 if n_to_cull >= current: 

1017 self.digest_history = set() 

1018 return 

1019 to_cull = random.sample(tuple(sorted(self.digest_history)), n_to_cull) 

1020 self.digest_history.difference_update(to_cull) 

1021 

1022 def deserialize( 

1023 self, 

1024 msg_list: list[bytes] | list[zmq.Message], 

1025 content: bool = True, 

1026 copy: bool = True, 

1027 ) -> dict[str, t.Any]: 

1028 """Unserialize a msg_list to a nested message dict. 

1029 

1030 This is roughly the inverse of serialize. The serialize/deserialize 

1031 methods work with full message lists, whereas pack/unpack work with 

1032 the individual message parts in the message list. 

1033 

1034 Parameters 

1035 ---------- 

1036 msg_list : list of bytes or Message objects 

1037 The list of message parts of the form [HMAC,p_header,p_parent, 

1038 p_metadata,p_content,buffer1,buffer2,...]. 

1039 content : bool (True) 

1040 Whether to unpack the content dict (True), or leave it packed 

1041 (False). 

1042 copy : bool (True) 

1043 Whether msg_list contains bytes (True) or the non-copying Message 

1044 objects in each place (False). 

1045 

1046 Returns 

1047 ------- 

1048 msg : dict 

1049 The nested message dict with top-level keys [header, parent_header, metadata, 

1050 content, buffers]. The buffers are returned as memoryviews. 

1051 """ 

1052 minlen = 5 

1053 message = {} 

1054 if not copy: 

1055 # pyzmq didn't copy the first parts of the message, so we'll do it 

1056 msg_list = t.cast(t.List[zmq.Message], msg_list) 

1057 msg_list_beginning = [bytes(msg.bytes) for msg in msg_list[:minlen]] 

1058 msg_list = t.cast(t.List[bytes], msg_list) 

1059 msg_list = msg_list_beginning + msg_list[minlen:] 

1060 msg_list = t.cast(t.List[bytes], msg_list) 

1061 if self.auth is not None: 

1062 signature = msg_list[0] 

1063 if not signature: 

1064 msg = "Unsigned Message" 

1065 raise ValueError(msg) 

1066 if signature in self.digest_history: 

1067 raise ValueError("Duplicate Signature: %r" % signature) 

1068 if content: 

1069 # Only store signature if we are unpacking content, don't store if just peeking. 

1070 self._add_digest(signature) 

1071 check = self.sign(msg_list[1:5]) 

1072 if not compare_digest(signature, check): 

1073 msg = "Invalid Signature: %r" % signature 

1074 raise ValueError(msg) 

1075 if not len(msg_list) >= minlen: 

1076 msg = "malformed message, must have at least %i elements" % minlen 

1077 raise TypeError(msg) 

1078 header = self.unpack(msg_list[1]) 

1079 message["header"] = extract_dates(header) 

1080 message["msg_id"] = header["msg_id"] 

1081 message["msg_type"] = header["msg_type"] 

1082 message["parent_header"] = extract_dates(self.unpack(msg_list[2])) 

1083 message["metadata"] = self.unpack(msg_list[3]) 

1084 if content: 

1085 message["content"] = self.unpack(msg_list[4]) 

1086 else: 

1087 message["content"] = msg_list[4] 

1088 buffers = [memoryview(b) for b in msg_list[5:]] 

1089 if buffers and buffers[0].shape is None: 

1090 # force copy to workaround pyzmq #646 

1091 msg_list = t.cast(t.List[zmq.Message], msg_list) 

1092 buffers = [memoryview(bytes(b.bytes)) for b in msg_list[5:]] 

1093 message["buffers"] = buffers 

1094 if self.debug: 

1095 pprint.pprint(message) # noqa 

1096 # adapt to the current version 

1097 return adapt(message) 
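
# Illustrative sketch (not part of jupyter_client): a full round trip through the
# wire format with no routing idents.
#
#     session = Session(key=b"secret")
#     wire = session.serialize(session.msg("kernel_info_request", content={}))
#     idents, frames = session.feed_identities(wire)
#     msg = session.deserialize(frames)
#     assert msg["msg_type"] == "kernel_info_request" and msg["content"] == {}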

1098 

1099 def unserialize(self, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]: 

1100 """**DEPRECATED** Use deserialize instead.""" 

1101 # pragma: no cover 

1102 warnings.warn( 

1103 "Session.unserialize is deprecated. Use Session.deserialize.", 

1104 DeprecationWarning, 

1105 stacklevel=2, 

1106 ) 

1107 return self.deserialize(*args, **kwargs)