Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jupyter_client/session.py: 40%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

445 statements  

1"""Session object for building, serializing, sending, and receiving messages. 

2 

3The Session object supports serialization, HMAC signatures, 

4and metadata on messages. 

5 

6Also defined here are utilities for working with Sessions: 

7* A SessionFactory to be used as a base class for configurables that work with 

8Sessions. 

9* A Message object for convenience that allows attribute-access to the msg dict. 

10""" 

11 

12# Copyright (c) Jupyter Development Team. 

13# Distributed under the terms of the Modified BSD License. 

14from __future__ import annotations 

15 

16import hashlib 

17import hmac 

18import json 

19import logging 

20import os 

21import pickle 

22import pprint 

23import random 

24import typing as t 

25import warnings 

26from binascii import b2a_hex 

27from datetime import datetime, timezone 

28from hmac import compare_digest 

29 

30# We are using compare_digest to limit the surface of timing attacks 

31import zmq.asyncio 

32from tornado.ioloop import IOLoop 

33from traitlets import ( 

34 Any, 

35 Bool, 

36 CBytes, 

37 CUnicode, 

38 Dict, 

39 DottedObjectName, 

40 Instance, 

41 Integer, 

42 Set, 

43 TraitError, 

44 Unicode, 

45 observe, 

46) 

47from traitlets.config.configurable import Configurable, LoggingConfigurable 

48from traitlets.log import get_logger 

49from traitlets.utils.importstring import import_item 

50from zmq.eventloop.zmqstream import ZMQStream 

51 

52from ._version import protocol_version 

53from .adapter import adapt 

54from .jsonutil import extract_dates, json_clean, json_default, squash_dates 

55 

# Pickle protocol used by pickle_packer below.
PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL

# shorthand alias used by utcnow()
utc = timezone.utc

59 

60# ----------------------------------------------------------------------------- 

61# utility functions 

62# ----------------------------------------------------------------------------- 

63 

64 

def squash_unicode(obj: t.Any) -> t.Any:
    """Recursively coerce str values (and str dict keys) back to utf8 bytes.

    Dicts and lists are modified in place; strings are returned re-encoded.
    """
    if isinstance(obj, str):
        return obj.encode("utf8")
    if isinstance(obj, list):
        obj[:] = [squash_unicode(item) for item in obj]
        return obj
    if isinstance(obj, dict):
        for key in list(obj):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, str):
                # re-key under the bytes form of the key
                obj[squash_unicode(key)] = obj.pop(key)
    return obj

78 

79 

80# ----------------------------------------------------------------------------- 

81# globals and defaults 

82# ----------------------------------------------------------------------------- 

83 

# default values for the thresholds
# (see Session.item_threshold and Session.buffer_threshold):
MAX_ITEMS = 64
MAX_BYTES = 1024

87 

88# ISO8601-ify datetime objects 

89# allow unicode 

90# disallow nan, because it's not actually valid JSON 

91 

92 

def json_packer(obj: t.Any) -> bytes:
    """Convert a json object to bytes.

    Strict JSON first; on failure, the object is passed through json_clean()
    and a deprecation warning is emitted.
    """

    def _encode(o: t.Any) -> bytes:
        # ISO8601-ify datetime via json_default; disallow nan (invalid JSON)
        return json.dumps(
            o,
            default=json_default,
            ensure_ascii=False,
            allow_nan=False,
        ).encode("utf8", errors="surrogateescape")

    try:
        return _encode(obj)
    except (TypeError, ValueError) as e:
        # Fallback to trying to clean the json before serializing
        packed = _encode(json_clean(obj))

        warnings.warn(
            f"Message serialization failed with:\n{e}\n"
            "Supporting this message is deprecated in jupyter-client 7, please make "
            "sure your message is JSON-compliant",
            stacklevel=2,
        )

        return packed

119 

120 

def json_unpacker(s: str | bytes) -> t.Any:
    """Convert json bytes or text into a Python object."""
    text = s.decode("utf8", "replace") if isinstance(s, bytes) else s
    return json.loads(text)

126 

127 

def pickle_packer(o: t.Any) -> bytes:
    """Pack an object using the pickle module."""
    # squash_dates (from .jsonutil) pre-processes the object before pickling;
    # presumably it converts datetimes — TODO(review): confirm against jsonutil.
    return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)

131 

132 

# inverse of pickle_packer; NOTE: unpickling untrusted data is unsafe,
# callers control which unpacker is configured.
pickle_unpacker = pickle.loads

# packers used when a Session is constructed without explicit pack/unpack
default_packer = json_packer
default_unpacker = json_unpacker

# wire-format delimiter separating ZMQ routing identities from message frames
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()

141 

142# ----------------------------------------------------------------------------- 

143# Mixin tools for apps that use Sessions 

144# ----------------------------------------------------------------------------- 

145 

146 

def new_id() -> str:
    """Generate a new random id.

    Returns
    -------
    id string (16 random bytes as hex-encoded text, chunks separated by '-')
    """
    raw = os.urandom(16)
    chunks = (raw[:4], raw[4:])
    return "-".join(b2a_hex(chunk).decode("ascii") for chunk in chunks)

159 

160 

def new_id_bytes() -> bytes:
    """Return new_id as ascii bytes"""
    return bytes(new_id(), "ascii")

164 

165 

# command-line aliases mapping short option names onto Session traits
session_aliases = {
    "ident": "Session.session",
    "user": "Session.username",
    "keyfile": "Session.keyfile",
}

# command-line flags toggling message authentication; note the "secure"
# key material is generated once, at import time
session_flags = {
    "secure": (
        {"Session": {"key": new_id_bytes(), "keyfile": ""}},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """,
    ),
    "no-secure": (
        {"Session": {"key": b"", "keyfile": ""}},
        """Don't authenticate messages.""",
    ),
}

184 

185 

def default_secure(cfg: t.Any) -> None:  # pragma: no cover
    """Set the default behavior for a config environment to be secure.

    If Session.key/keyfile have not been set, set Session.key to
    a new random UUID.
    """
    warnings.warn("default_secure is deprecated", DeprecationWarning, stacklevel=2)
    already_configured = "Session" in cfg and (
        "key" in cfg.Session or "keyfile" in cfg.Session
    )
    if not already_configured:
        # key/keyfile not specified, generate new UUID:
        cfg.Session.key = new_id_bytes()

197 

198 

def utcnow() -> datetime:
    """Return a timezone-aware datetime for the current UTC instant."""
    return datetime.now(timezone.utc)

202 

203 

204# ----------------------------------------------------------------------------- 

205# Classes 

206# ----------------------------------------------------------------------------- 

207 

208 

class SessionFactory(LoggingConfigurable):
    """The Base class for configurables that have a Session, Context, logger,
    and IOLoop.
    """

    # name of the stdlib logger to use; changing it swaps self.log
    logname = Unicode("")

    @observe("logname")
    def _logname_changed(self, change: t.Any) -> None:
        self.log = logging.getLogger(change["new"])

    # not configurable:
    context = Instance("zmq.Context")

    def _context_default(self) -> zmq.Context:
        # a fresh zmq Context per factory by default
        return zmq.Context()

    session = Instance("jupyter_client.session.Session", allow_none=True)

    loop = Instance("tornado.ioloop.IOLoop")

    def _loop_default(self) -> IOLoop:
        return IOLoop.current()

    def __init__(self, **kwargs: t.Any) -> None:
        """Initialize a session factory."""
        super().__init__(**kwargs)

        if self.session is None:
            # construct the session, forwarding the same kwargs so Session
            # picks up its own configurables
            self.session = Session(**kwargs)

240 

241 

class Message:
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj)."""

    def __init__(self, msg_dict: dict[str, t.Any]) -> None:
        """Initialize a message, recursively wrapping nested dicts."""
        for key, value in dict(msg_dict).items():
            self.__dict__[key] = Message(value) if isinstance(value, dict) else value

    def __iter__(self) -> t.ItemsView[str, t.Any]:
        # Yielding (key, value) pairs lets dict(msg_obj) work out of the box.
        return iter(self.__dict__.items())  # type:ignore[return-value]

    def __repr__(self) -> str:
        return repr(self.__dict__)

    def __str__(self) -> str:
        return pprint.pformat(self.__dict__)

    def __contains__(self, k: object) -> bool:
        return k in self.__dict__

    def __getitem__(self, k: str) -> t.Any:
        return self.__dict__[k]

271 

272 

def msg_header(
    msg_id: str, msg_type: str, username: str, session: Session | str
) -> dict[str, t.Any]:
    """Create a new message header"""
    # explicit dict instead of the locals() trick; same keys, same values
    return {
        "msg_id": msg_id,
        "msg_type": msg_type,
        "username": username,
        "session": session,
        "date": utcnow(),
        "version": protocol_version,
    }

280 

281 

def extract_header(msg_or_header: dict[str, t.Any]) -> dict[str, t.Any]:
    """Given a message or header, return the header."""
    if not msg_or_header:
        return {}
    try:
        # full message: header lives under the "header" key
        header = msg_or_header["header"]
    except KeyError:
        # no "header" key: treat the argument as the header itself, but
        # require a msg_id so unrecognized input still raises KeyError
        msg_or_header["msg_id"]
        header = msg_or_header
    return header if isinstance(header, dict) else dict(header)

300 

301 

302class Session(Configurable): 

303 """Object for handling serialization and sending of messages. 

304 

305 The Session object handles building messages and sending them 

306 with ZMQ sockets or ZMQStream objects. Objects can communicate with each 

307 other over the network via Session objects, and only need to work with the 

308 dict-based IPython message spec. The Session will handle 

309 serialization/deserialization, security, and metadata. 

310 

311 Sessions support configurable serialization via packer/unpacker traits, 

312 and signing with HMAC digests via the key/keyfile traits. 

313 

314 Parameters 

315 ---------- 

316 

317 debug : bool 

318 whether to trigger extra debugging statements 

319 packer/unpacker : str : 'json', 'pickle' or import_string 

320 importstrings for methods to serialize message parts. If just 

321 'json' or 'pickle', predefined JSON and pickle packers will be used. 

322 Otherwise, the entire importstring must be used. 

323 

324 The functions must accept at least valid JSON input, and output *bytes*. 

325 

326 For example, to use msgpack: 

327 packer = 'msgpack.packb', unpacker='msgpack.unpackb' 

328 pack/unpack : callables 

329 You can also set the pack/unpack callables for serialization directly. 

330 session : bytes 

331 the ID of this Session object. The default is to generate a new UUID. 

332 username : unicode 

333 username added to message headers. The default is to ask the OS. 

334 key : bytes 

335 The key used to initialize an HMAC signature. If unset, messages 

336 will not be signed or checked. 

337 keyfile : filepath 

338 The file containing a key. If this is set, `key` will be initialized 

339 to the contents of the file. 

340 

341 """ 

342 

    # when True, send() pretty-prints each message and its wire frames
    debug = Bool(False, config=True, help="""Debug output in the Session""")

    check_pid = Bool(
        True,
        config=True,
        help="""Whether to check PID to protect against calls after fork.

        This check can be disabled if fork-safety is handled elsewhere.
        """,
    )

    # name resolved to the actual `pack` callable by _packer_changed below
    packer = DottedObjectName(
        "json",
        config=True,
        help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""",
    )

361 

362 @observe("packer") 

363 def _packer_changed(self, change: t.Any) -> None: 

364 new = change["new"] 

365 if new.lower() == "json": 

366 self.pack = json_packer 

367 self.unpack = json_unpacker 

368 self.unpacker = new 

369 elif new.lower() == "pickle": 

370 self.pack = pickle_packer 

371 self.unpack = pickle_unpacker 

372 self.unpacker = new 

373 else: 

374 self.pack = import_item(str(new)) 

375 

    # name resolved to the actual `unpack` callable by _unpacker_changed below
    unpacker = DottedObjectName(
        "json",
        config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""",
    )

382 

383 @observe("unpacker") 

384 def _unpacker_changed(self, change: t.Any) -> None: 

385 new = change["new"] 

386 if new.lower() == "json": 

387 self.pack = json_packer 

388 self.unpack = json_unpacker 

389 self.packer = new 

390 elif new.lower() == "pickle": 

391 self.pack = pickle_packer 

392 self.unpack = pickle_unpacker 

393 self.packer = new 

394 else: 

395 self.unpack = import_item(str(new)) 

396 

    session = CUnicode("", config=True, help="""The UUID identifying this session.""")

    def _session_default(self) -> str:
        # generate a session id lazily; keep the bytes form in sync
        u = new_id()
        self.bsession = u.encode("ascii")
        return u

    @observe("session")
    def _session_changed(self, change: t.Any) -> None:
        # keep the bytes form in sync with the unicode trait
        self.bsession = self.session.encode("ascii")

    # bsession is the session as bytes
    bsession = CBytes(b"")

410 

    username = Unicode(
        os.environ.get("USER", "username"),
        help="""Username for the Session. Default is your system username.""",
        config=True,
    )

    metadata = Dict(
        {},
        config=True,
        help="Metadata dictionary, which serves as the default top-level metadata dict for each "
        "message.",
    )

    # protocol version to adapt outgoing messages to in send();
    # if 0, no adapting to do.
    adapt_version = Integer(0)

426 

427 # message signature related traits: 

428 

    key = CBytes(config=True, help="""execution key, for signing messages.""")

    def _key_default(self) -> bytes:
        # default to a fresh random key, so signing is on by default
        return new_id_bytes()

    @observe("key")
    def _key_changed(self, change: t.Any) -> None:
        # rebuild the HMAC object whenever the key changes
        self._new_auth()

437 

    signature_scheme = Unicode(
        "hmac-sha256",
        config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""",
    )

    @observe("signature_scheme")
    def _signature_scheme_changed(self, change: t.Any) -> None:
        # validate the 'hmac-HASH' form and resolve HASH against hashlib
        new = change["new"]
        if not new.startswith("hmac-"):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split("-", 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError as e:
            raise TraitError("hashlib has no such attribute: %s" % hash_name) from e
        self._new_auth()

    # hash constructor used for HMAC signing (e.g. hashlib.sha256)
    digest_mod = Any()

    def _digest_mod_default(self) -> t.Callable:
        return hashlib.sha256

461 

    # live HMAC template; None when signing is disabled (empty key)
    auth = Instance(hmac.HMAC, allow_none=True)

    def _new_auth(self) -> None:
        # (re)build the HMAC template from the current key and digest,
        # or disable signing entirely when the key is empty
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

469 

    # signatures already seen, for replay protection (see _add_digest)
    digest_history = Set()
    digest_history_size = Integer(
        2**16,
        config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """,
    )

    keyfile = Unicode("", config=True, help="""path to file containing execution key.""")

    @observe("keyfile")
    def _keyfile_changed(self, change: t.Any) -> None:
        # load the signing key from the file whenever keyfile is set
        with open(change["new"], "rb") as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

489 

490 # serialization traits: 

491 

    pack = Any(default_packer)  # the actual packer function

    @observe("pack")
    def _pack_changed(self, change: t.Any) -> None:
        # packer must at least be callable; output type is checked later
        # by _check_packers
        new = change["new"]
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    @observe("unpack")
    def _unpack_changed(self, change: t.Any) -> None:
        # unpacker is not checked - it is assumed to be
        # compatible with the configured packer (see _check_packers)
        new = change["new"]
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

508 

    # thresholds:
    copy_threshold = Integer(
        2**16,
        config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.",
    )
    buffer_threshold = Integer(
        MAX_BYTES,
        config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid "
        "pickling.",
    )
    item_threshold = Integer(
        MAX_ITEMS,
        config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """,
    )

528 

    def __init__(self, **kwargs: t.Any) -> None:
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts. If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object. The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers. The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature. If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key. If this is set, `key` will be
            initialized to the contents of the file.
        """
        super().__init__(**kwargs)
        self._check_packers()
        # pre-packed empty dict, reused by serialize() when content is None
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session  # noqa
        # remember the creating process so send() can detect forks
        self.pid = os.getpid()
        self._new_auth()
        if not self.key:
            get_logger().warning(
                "Message signing is disabled. This is insecure and not recommended!"
            )

580 

581 def clone(self) -> Session: 

582 """Create a copy of this Session 

583 

584 Useful when connecting multiple times to a given kernel. 

585 This prevents a shared digest_history warning about duplicate digests 

586 due to multiple connections to IOPub in the same process. 

587 

588 .. versionadded:: 5.1 

589 """ 

590 # make a copy 

591 new_session = type(self)() 

592 for name in self.traits(): 

593 setattr(new_session, name, getattr(self, name)) 

594 # fork digest_history 

595 new_session.digest_history = set() 

596 new_session.digest_history.update(self.digest_history) 

597 return new_session 

598 

599 message_count = 0 

600 

601 @property 

602 def msg_id(self) -> str: 

603 message_number = self.message_count 

604 self.message_count += 1 

605 return f"{self.session}_{os.getpid()}_{message_number}" 

606 

    def _check_packers(self) -> None:
        """check packers for datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg_list = {"a": [1, "hi"]}
        try:
            packed = pack(msg_list)
        except Exception as e:
            msg = f"packer '{self.packer}' could not serialize a simple message: {e}"
            raise ValueError(msg) from e

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" % type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            assert unpacked == msg_list
        except Exception as e:
            msg = (
                f"unpacker '{self.unpacker}' could not handle output from packer"
                f" '{self.packer}': {e}"
            )
            raise ValueError(msg) from e

        # check datetime support: a packer that either fails on datetimes or
        # round-trips them as datetime objects gets wrapped with squash_dates
        msg_datetime = {"t": utcnow()}
        try:
            unpacked = unpack(pack(msg_datetime))
            if isinstance(unpacked["t"], datetime):
                msg = "Shouldn't deserialize to datetime"
                raise ValueError(msg)
        except Exception:
            # wrap the configured pack/unpack so datetimes are squashed first
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

645 

    def msg_header(self, msg_type: str) -> dict[str, t.Any]:
        """Create a header for a message type."""
        # delegates to the module-level msg_header with a fresh msg_id
        return msg_header(self.msg_id, msg_type, self.username, self.session)

649 

650 def msg( 

651 self, 

652 msg_type: str, 

653 content: dict | None = None, 

654 parent: dict[str, t.Any] | None = None, 

655 header: dict[str, t.Any] | None = None, 

656 metadata: dict[str, t.Any] | None = None, 

657 ) -> dict[str, t.Any]: 

658 """Return the nested message dict. 

659 

660 This format is different from what is sent over the wire. The 

661 serialize/deserialize methods converts this nested message dict to the wire 

662 format, which is a list of message parts. 

663 """ 

664 msg = {} 

665 header = self.msg_header(msg_type) if header is None else header 

666 msg["header"] = header 

667 msg["msg_id"] = header["msg_id"] 

668 msg["msg_type"] = header["msg_type"] 

669 msg["parent_header"] = {} if parent is None else extract_header(parent) 

670 msg["content"] = {} if content is None else content 

671 msg["metadata"] = self.metadata.copy() 

672 if metadata is not None: 

673 msg["metadata"].update(metadata) 

674 return msg 

675 

676 def sign(self, msg_list: list) -> bytes: 

677 """Sign a message with HMAC digest. If no auth, return b''. 

678 

679 Parameters 

680 ---------- 

681 msg_list : list 

682 The [p_header,p_parent,p_content] part of the message list. 

683 """ 

684 if self.auth is None: 

685 return b"" 

686 h = self.auth.copy() 

687 for m in msg_list: 

688 h.update(m) 

689 return h.hexdigest().encode() 

690 

    def serialize(
        self,
        msg: dict[str, t.Any],
        ident: list[bytes] | bytes | None = None,
    ) -> list[bytes]:
        """Serialize the message components to bytes.

        This is roughly the inverse of deserialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The next message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get("content", {})
        if content is None:
            # reuse the pre-packed empty dict from __init__
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, str):
            # should be bytes, but JSON often spits out unicode
            content = content.encode("utf8")
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        # the four frames covered by the HMAC signature
        real_message = [
            self.pack(msg["header"]),
            self.pack(msg["parent_header"]),
            self.pack(msg["metadata"]),
            content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

754 

    def send(
        self,
        stream: zmq.sugar.socket.Socket | ZMQStream | None,
        msg_or_type: dict[str, t.Any] | str,
        content: dict[str, t.Any] | None = None,
        parent: dict[str, t.Any] | None = None,
        ident: bytes | list[bytes] | None = None,
        buffers: list[bytes | memoryview[bytes]] | None = None,
        track: bool = False,
        header: dict[str, t.Any] | None = None,
        metadata: dict[str, t.Any] | None = None,
    ) -> dict[str, t.Any] | None:
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_to_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track. Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(stream, zmq.asyncio.Socket):
            assert stream is not None
            # unwrap the asyncio socket to its blocking counterpart
            stream = zmq.Socket.shadow(stream.underlying)

        if isinstance(msg_or_type, Message | dict):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get("buffers", [])
        else:
            msg = self.msg(
                msg_or_type,
                content=content,
                parent=parent,
                header=header,
                metadata=metadata,
            )
        if self.check_pid and os.getpid() != self.pid:
            # refuse to send from a forked child (see self.pid / check_pid)
            get_logger().warning("WARNING: attempted to send message from fork\n%s", msg)
            return None
        buffers = [] if buffers is None else buffers
        for idx, buf in enumerate(buffers):
            if isinstance(buf, memoryview):
                view = buf
            else:
                try:
                    # check to see if buf supports the buffer protocol.
                    view = memoryview(buf)
                except TypeError as e:
                    emsg = "Buffer objects must support the buffer protocol."
                    raise TypeError(emsg) from e
            if not view.contiguous:
                # zmq requires memoryviews to be contiguous
                raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))

        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)  # type: ignore[arg-type]
        # small messages are copied; large ones are sent zero-copy
        longest = max([len(s) for s in to_send])
        copy = longest < self.copy_threshold

        if stream and buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        elif stream:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)
        else:
            tracker = DONE

        if self.debug:
            pprint.pprint(msg)  # noqa
            pprint.pprint(to_send)  # noqa
            pprint.pprint(buffers)  # noqa

        msg["tracker"] = tracker

        return msg

874 

875 def send_raw( 

876 self, 

877 stream: zmq.sugar.socket.Socket, 

878 msg_list: list, 

879 flags: int = 0, 

880 copy: bool = True, 

881 ident: bytes | list[bytes] | None = None, 

882 ) -> None: 

883 """Send a raw message via ident path. 

884 

885 This method is used to send a already serialized message. 

886 

887 Parameters 

888 ---------- 

889 stream : ZMQStream or Socket 

890 The ZMQ stream or socket to use for sending the message. 

891 msg_list : list 

892 The serialized list of messages to send. This only includes the 

893 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of 

894 the message. 

895 ident : ident or list 

896 A single ident or a list of idents to use in sending. 

897 """ 

898 to_send = [] 

899 if isinstance(ident, bytes): 

900 ident = [ident] 

901 if ident is not None: 

902 to_send.extend(ident) 

903 

904 to_send.append(DELIM) 

905 # Don't include buffers in signature (per spec). 

906 to_send.append(self.sign(msg_list[0:4])) 

907 to_send.extend(msg_list) 

908 if isinstance(stream, zmq.asyncio.Socket): 

909 stream = zmq.Socket.shadow(stream.underlying) 

910 stream.send_multipart(to_send, flags, copy=copy) 

911 

    def recv(
        self,
        socket: zmq.sugar.socket.Socket,
        mode: int = zmq.NOBLOCK,
        content: bool = True,
        copy: bool = True,
    ) -> tuple[list[bytes] | None, dict[str, t.Any] | None]:
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):  # type:ignore[unreachable]
            # accept a ZMQStream by unwrapping its underlying socket
            socket = socket.socket  # type:ignore[unreachable]
        if isinstance(socket, zmq.asyncio.Socket):
            socket = zmq.Socket.shadow(socket.underlying)

        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.deserialize(msg_list, content=content, copy=copy)
        except Exception as e:
            # TODO: handle it
            raise e

954 

def feed_identities(
    self, msg_list: list[bytes] | list[zmq.Message], copy: bool = True
) -> tuple[list[bytes], list[bytes] | list[zmq.Message]]:
    """Split the ZMQ routing identities from the rest of the message.

    Frames before the DELIM sentinel are routing identities; the frames
    after it form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...]
    and should be unpackable via self.deserialize. This is easily broken
    by setting an IDENT to DELIM, but that would be silly.

    Parameters
    ----------
    msg_list : a list of Message or bytes objects
        The message to be split.
    copy : bool
        flag determining whether the arguments are bytes or Messages

    Returns
    -------
    (idents, msg_list) : two lists
        idents is always a list of bytes, each a ZMQ identity.
        msg_list keeps the incoming frame type (bytes or zmq.Message).
    """
    if copy:
        # bytes frames: list.index finds the delimiter directly
        msg_list = t.cast(t.List[bytes], msg_list)
        split_at = msg_list.index(DELIM)
        return msg_list[:split_at], msg_list[split_at + 1 :]

    # zmq.Message frames: compare each frame's payload against DELIM
    msg_list = t.cast(t.List[zmq.Message], msg_list)
    split_at = -1
    for pos, frame in enumerate(msg_list):
        if frame.bytes == DELIM:
            split_at = pos
            break
    if split_at < 0:
        msg = "DELIM not in msg_list"
        raise ValueError(msg)
    ident_frames = msg_list[:split_at]
    remainder = msg_list[split_at + 1 :]
    # identities are returned as copied bytes; the remainder stays zero-copy
    return [bytes(frame.bytes) for frame in ident_frames], remainder

996 

997 def _add_digest(self, signature: bytes) -> None: 

998 """add a digest to history to protect against replay attacks""" 

999 if self.digest_history_size == 0: 

1000 # no history, never add digests 

1001 return 

1002 

1003 self.digest_history.add(signature) 

1004 if len(self.digest_history) > self.digest_history_size: 

1005 # threshold reached, cull 10% 

1006 self._cull_digest_history() 

1007 

1008 def _cull_digest_history(self) -> None: 

1009 """cull the digest history 

1010 

1011 Removes a randomly selected 10% of the digest history 

1012 """ 

1013 current = len(self.digest_history) 

1014 n_to_cull = max(int(current // 10), current - self.digest_history_size) 

1015 if n_to_cull >= current: 

1016 self.digest_history = set() 

1017 return 

1018 to_cull = random.sample(tuple(sorted(self.digest_history)), n_to_cull) 

1019 self.digest_history.difference_update(to_cull) 

1020 

def deserialize(
    self,
    msg_list: list[bytes] | list[zmq.Message],
    content: bool = True,
    copy: bool = True,
) -> dict[str, t.Any]:
    """Unserialize a msg_list to a nested message dict.

    This is roughly the inverse of serialize. The serialize/deserialize
    methods work with full message lists, whereas pack/unpack work with
    the individual message parts in the message list.

    Parameters
    ----------
    msg_list : list of bytes or Message objects
        The list of message parts of the form [HMAC,p_header,p_parent,
        p_metadata,p_content,buffer1,buffer2,...].
    content : bool (True)
        Whether to unpack the content dict (True), or leave it packed
        (False).
    copy : bool (True)
        Whether msg_list contains bytes (True) or the non-copying Message
        objects in each place (False).

    Returns
    -------
    msg : dict
        The nested message dict with top-level keys [header, parent_header,
        content, buffers]. The buffers are returned as memoryviews.

    Raises
    ------
    ValueError
        If the message is unsigned, has a replayed (duplicate) signature,
        or fails HMAC verification (only when ``self.auth`` is set).
    TypeError
        If fewer than 5 message parts are present.
    """
    # the wire format requires 5 frames: HMAC + header/parent/metadata/content
    minlen = 5
    message = {}
    if not copy:
        # pyzmq didn't copy the first parts of the message, so we'll do it
        msg_list = t.cast(t.List[zmq.Message], msg_list)
        msg_list_beginning = [bytes(msg.bytes) for msg in msg_list[:minlen]]
        msg_list = t.cast(t.List[bytes], msg_list)
        # buffers (frames 5+) are left as zero-copy Message objects
        msg_list = msg_list_beginning + msg_list[minlen:]
    msg_list = t.cast(t.List[bytes], msg_list)
    if self.auth is not None:
        # HMAC verification: frame 0 is the signature over frames 1-4
        signature = msg_list[0]
        if not signature:
            msg = "Unsigned Message"
            raise ValueError(msg)
        # reject replayed signatures (see _add_digest / digest_history)
        if signature in self.digest_history:
            raise ValueError("Duplicate Signature: %r" % signature)
        if content:
            # Only store signature if we are unpacking content, don't store if just peeking.
            self._add_digest(signature)
        check = self.sign(msg_list[1:5])
        # compare_digest: constant-time comparison to limit timing attacks
        if not compare_digest(signature, check):
            msg = "Invalid Signature: %r" % signature
            raise ValueError(msg)
    if not len(msg_list) >= minlen:
        msg = "malformed message, must have at least %i elements" % minlen
        raise TypeError(msg)
    header = self.unpack(msg_list[1])
    # extract_dates converts ISO8601 strings back into datetime objects
    message["header"] = extract_dates(header)
    message["msg_id"] = header["msg_id"]
    message["msg_type"] = header["msg_type"]
    message["parent_header"] = extract_dates(self.unpack(msg_list[2]))
    message["metadata"] = self.unpack(msg_list[3])
    if content:
        message["content"] = self.unpack(msg_list[4])
    else:
        # leave the content frame packed for cheap routing/peeking
        message["content"] = msg_list[4]
    buffers = [memoryview(b) for b in msg_list[5:]]
    if buffers and buffers[0].shape is None:
        # force copy to workaround pyzmq #646
        msg_list = t.cast(t.List[zmq.Message], msg_list)
        buffers = [memoryview(bytes(b.bytes)) for b in msg_list[5:]]
    message["buffers"] = buffers
    if self.debug:
        pprint.pprint(message)  # noqa
    # adapt to the current version
    return adapt(message)

1097 

def unserialize(self, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]:
    """**DEPRECATED** Use deserialize instead.

    Thin forwarding shim: emits a DeprecationWarning and delegates
    directly to :meth:`deserialize`.
    """
    # pragma: no cover
    deprecation_message = "Session.unserialize is deprecated. Use Session.deserialize."
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    return self.deserialize(*args, **kwargs)