1"""Utilities for connecting to jupyter kernels
2
The :class:`ConnectionFileMixin` class in this module encapsulates the logic
related to writing and reading connection files.
5"""
6
7# Copyright (c) Jupyter Development Team.
8# Distributed under the terms of the Modified BSD License.
9from __future__ import annotations
10
11import errno
12import glob
13import json
14import os
15import socket
16import stat
17import tempfile
18import warnings
19from getpass import getpass
20from typing import TYPE_CHECKING, Any, Union, cast
21
22import zmq
23from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write
24from traitlets import Bool, CaselessStrEnum, Instance, Integer, Type, Unicode, observe
25from traitlets.config import LoggingConfigurable, SingletonConfigurable
26
27from .localinterfaces import localhost
28from .utils import _filefind
29
30if TYPE_CHECKING:
31 from jupyter_client import BlockingKernelClient
32
33 from .session import Session
34
# Custom type for kernel connection info: a mapping from string keys
# (e.g. "shell_port", "ip", "key", "transport") to int, str, or bytes values.
KernelConnectionInfo = dict[str, Union[int, str, bytes]]
37
38
def write_connection_file(
    fname: str | None = None,
    shell_port: int = 0,
    iopub_port: int = 0,
    stdin_port: int = 0,
    hb_port: int = 0,
    control_port: int = 0,
    ip: str = "",
    key: bytes = b"",
    transport: str = "tcp",
    signature_scheme: str = "hmac-sha256",
    kernel_name: str = "",
    **kwargs: Any,
) -> tuple[str, KernelConnectionInfo]:
    """Generates a JSON config file, including the selection of random ports.

    Parameters
    ----------

    fname : unicode, optional
        The path to the file to write.  If empty or None, a temporary
        ``.json`` file is created with :func:`tempfile.mkstemp` and used.

    shell_port : int, optional
        The port to use for ROUTER (shell) channel.

    iopub_port : int, optional
        The port to use for the SUB channel.

    stdin_port : int, optional
        The port to use for the ROUTER (raw input) channel.

    control_port : int, optional
        The port to use for the ROUTER (control) channel.

    hb_port : int, optional
        The port to use for the heartbeat REP channel.

    ip : str, optional
        The ip address the kernel will bind to.  If empty, ``localhost()``
        is used.

    key : bytes, optional
        The Session key used for message authentication.  It is decoded to
        ``str`` before being stored in the config.

    transport : str, optional
        With ``"tcp"`` (the default), missing ports are chosen by binding
        temporary sockets to port 0 on ``ip``.  With any other value, the
        "ports" are the first integers ``N >= 1`` for which the path
        ``f"{ip}-{N}"`` does not exist.

    signature_scheme : str, optional
        The scheme used for message authentication.
        This has the form 'digest-hash', where 'digest'
        is the scheme used for digests, and 'hash' is the name of the hash function
        used by the digest scheme.
        Currently, 'hmac' is the only supported digest scheme,
        and 'sha256' is the default hash function.

    kernel_name : str, optional
        The name of the kernel currently connected to.

    **kwargs
        Extra entries merged into the config after the standard keys
        (so they may override them).

    Returns
    -------
    (fname, cfg) : tuple
        The path that was written and the config dict that was serialized
        into it.

    Notes
    -----
    Any port argument that is ``<= 0`` is replaced by a selected port.
    """
    if not ip:
        ip = localhost()
    # default to temporary connector file
    if not fname:
        fd, fname = tempfile.mkstemp(".json")
        os.close(fd)

    # Find open ports as necessary.

    ports: list[int] = []
    sockets: list[socket.socket] = []
    ports_needed = (
        int(shell_port <= 0)
        + int(iopub_port <= 0)
        + int(stdin_port <= 0)
        + int(control_port <= 0)
        + int(hb_port <= 0)
    )
    if transport == "tcp":
        try:
            # All sockets are bound before any of them is closed
            # (presumably so that the selected ports are distinct).
            for _ in range(ports_needed):
                sock = socket.socket()
                # Track the socket before configuring it, so it is closed
                # below even if setsockopt/bind raises.
                sockets.append(sock)
                # struct.pack('ii', (0,0)) is 8 null bytes
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                sock.bind((ip, 0))
            ports.extend(sock.getsockname()[1] for sock in sockets)
        finally:
            for sock in sockets:
                sock.close()
    else:
        N = 1
        for _ in range(ports_needed):
            while os.path.exists(f"{ip}-{N!s}"):
                N += 1
            ports.append(N)
            N += 1
    # Hand the selected ports out in a fixed order to the channels that
    # were not given an explicit port.
    if shell_port <= 0:
        shell_port = ports.pop(0)
    if iopub_port <= 0:
        iopub_port = ports.pop(0)
    if stdin_port <= 0:
        stdin_port = ports.pop(0)
    if control_port <= 0:
        control_port = ports.pop(0)
    if hb_port <= 0:
        hb_port = ports.pop(0)

    cfg: KernelConnectionInfo = {
        "shell_port": shell_port,
        "iopub_port": iopub_port,
        "stdin_port": stdin_port,
        "control_port": control_port,
        "hb_port": hb_port,
    }
    cfg["ip"] = ip
    cfg["key"] = key.decode()
    cfg["transport"] = transport
    cfg["signature_scheme"] = signature_scheme
    cfg["kernel_name"] = kernel_name
    cfg.update(kwargs)

    # Only ever write this file as user read/writeable
    # This would otherwise introduce a vulnerability as a file has secrets
    # which would let others execute arbitrary code as you
    with secure_write(fname) as f:
        f.write(json.dumps(cfg, indent=2))

    if hasattr(stat, "S_ISVTX"):
        # set the sticky bit on the parent directory of the file
        # to ensure only owner can remove it
        runtime_dir = os.path.dirname(fname)
        if runtime_dir:
            permissions = os.stat(runtime_dir).st_mode
            new_permissions = permissions | stat.S_ISVTX
            if new_permissions != permissions:
                try:
                    os.chmod(runtime_dir, new_permissions)
                except OSError as e:
                    # NOTE: setting the sticky bit is best-effort; no OSError
                    # raised by chmod is propagated from here.
                    if e.errno == errno.EPERM:
                        # suppress permission errors setting sticky bit on runtime_dir,
                        # which we may not own.
                        pass
    return fname, cfg
176
177
def find_connection_file(
    filename: str = "kernel-*.json",
    path: str | list[str] | None = None,
    profile: str | None = None,
) -> str:
    """Locate a connection file and return its path.

    The name is first looked up as given via ``_filefind`` in the search
    path.  If that fails, it is treated as a file glob (a name without
    ``*`` is wrapped as ``*name*``) and matched inside each search
    directory; among several matches, the one with the latest access time
    is chosen.

    Parameters
    ----------
    filename : str
        The connection file or fileglob to search for.
    path : str or list of strs[optional]
        Paths in which to search for connection files.  Defaults to the
        current working directory and the Jupyter runtime directory.
    profile : str [optional]
        Ignored; passing a value only triggers a warning.

    Returns
    -------
    str : The path of the connection file (absolute for glob matches).

    Raises
    ------
    OSError
        If neither the exact name nor the glob matches any file.
    """
    if profile is not None:
        warnings.warn(
            "Jupyter has no profiles. profile=%s has been ignored." % profile, stacklevel=2
        )

    # Normalize the search path to a list of directories.
    if path is None:
        path = [".", jupyter_runtime_dir()]
    elif isinstance(path, str):
        path = [path]

    # An exact name wins over any pattern matching.
    try:
        return _filefind(filename, path)
    except OSError:
        pass

    # Fall back to globbing; a name without a wildcard becomes a
    # substring match.
    pattern = filename if "*" in filename else "*%s*" % filename

    candidates = [
        os.path.abspath(hit)
        for directory in path
        for hit in glob.glob(os.path.join(directory, pattern))
    ]

    if not candidates:
        msg = f"Could not find {filename!r} in {path!r}"
        raise OSError(msg)
    if len(candidates) == 1:
        return candidates[0]

    # Several candidates: order them by access time and take the newest.
    by_access_time = sorted(candidates, key=lambda f: os.stat(f).st_atime)
    return by_access_time[-1]
240
241
def tunnel_to_kernel(
    connection_info: str | KernelConnectionInfo,
    sshserver: str,
    sshkey: str | None = None,
) -> tuple[Any, ...]:
    """Forward a kernel's ports to localhost through ssh.

    One SSH tunnel is opened per kernel channel, from a randomly selected
    local port to the matching kernel port.  They can be either direct
    localhost-localhost tunnels, or if an intermediate server is necessary,
    the kernel must be listening on a public IP.

    Parameters
    ----------
    connection_info : dict or str (path)
        Either a connection dict, or the path to a JSON connection file
    sshserver : str
        The ssh server to use to tunnel to the kernel. Can be a full
        `user@server:port` string. ssh config aliases are respected.
    sshkey : str [optional]
        Path to file containing ssh key to use for authentication.
        Only necessary if your ssh config does not already associate
        a keyfile with the host.

    Returns
    -------

    (shell, iopub, stdin, hb, control) : ints
        The five ports on localhost that have been forwarded to the kernel.
    """
    from .ssh import tunnel

    if isinstance(connection_info, str):
        # A string is the path of a JSON connection file: read it.
        with open(connection_info) as f:
            connection_info = json.load(f)

    cf = cast(dict[str, Any], connection_info)

    local_ports = tunnel.select_random_ports(5)

    # Remote ports, in the same order as the documented return value.
    remote_port_keys = (
        "shell_port",
        "iopub_port",
        "stdin_port",
        "hb_port",
        "control_port",
    )
    remote_ports = tuple(cf[port_key] for port_key in remote_port_keys)

    remote_ip = cf["ip"]

    # Only prompt for a password when passwordless login is not available.
    password: bool | str = (
        False
        if tunnel.try_passwordless_ssh(sshserver, sshkey)
        else getpass("SSH Password for %s: " % sshserver)
    )

    for local_port, remote_port in zip(local_ports, remote_ports, strict=False):
        tunnel.ssh_tunnel(local_port, remote_port, sshserver, remote_ip, sshkey, password)

    return tuple(local_ports)
301
302
303# -----------------------------------------------------------------------------
304# Mixin for classes that work with connection files
305# -----------------------------------------------------------------------------
306
# zmq socket type to create for each channel name; looked up by
# ConnectionFileMixin._create_connected_socket when connecting to a kernel.
channel_socket_types = {
    "hb": zmq.REQ,
    "shell": zmq.DEALER,
    "iopub": zmq.SUB,
    "stdin": zmq.DEALER,
    "control": zmq.DEALER,
}

# Names of the per-channel port attributes / connection-info keys
# ("shell_port", "stdin_port", "iopub_port", "hb_port", "control_port").
port_names = ["%s_port" % channel for channel in ("shell", "stdin", "iopub", "hb", "control")]
316
317
class ConnectionFileMixin(LoggingConfigurable):
    """Mixin for configurable classes that work with connection files"""

    data_dir: str | Unicode = Unicode()

    def _data_dir_default(self) -> str:
        """Default for ``data_dir``: the value of ``jupyter_data_dir()``."""
        return jupyter_data_dir()

    # The addresses for the communication channels
    connection_file = Unicode(
        "",
        config=True,
        help="""JSON file in which to store connection info [default: kernel-<pid>.json]

    This file will contain the IP, ports, and authentication key needed to connect
    clients to this kernel. By default, this file will be created in the security dir
    of the current profile, but can be specified by absolute path.
    """,
    )
    # True once write_connection_file() has written the file; cleared again
    # by cleanup_connection_file() and on a mismatch during reconciliation.
    _connection_file_written = Bool(False)

    transport = CaselessStrEnum(["tcp", "ipc"], default_value="tcp", config=True)
    kernel_name: str | Unicode = Unicode()

    context = Instance(zmq.Context)

    ip = Unicode(
        config=True,
        help="""Set the kernel\'s IP address [default localhost].
        If the IP address is something other than localhost, then
        Consoles on other machines will be able to connect
        to the Kernel, so be careful!""",
    )

    def _ip_default(self) -> str:
        """Default for ``ip``.

        For the ipc transport this is a path prefix: the connection file
        name without its extension plus ``"-ipc"``, or ``"kernel-ipc"``
        when no connection file is set.  Otherwise it is ``localhost()``.
        """
        if self.transport == "ipc":
            if self.connection_file:
                return os.path.splitext(self.connection_file)[0] + "-ipc"
            else:
                return "kernel-ipc"
        else:
            return localhost()

    @observe("ip")
    def _ip_changed(self, change: Any) -> None:
        """Replace the wildcard ``"*"`` with ``"0.0.0.0"`` when ``ip`` is set."""
        if change["new"] == "*":
            self.ip = "0.0.0.0"  # noqa

    # protected traits

    hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
    shell_port = Integer(0, config=True, help="set the shell (ROUTER) port [default: random]")
    iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]")
    stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]")
    control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]")

    # names of the ports with random assignment
    _random_port_names: list[str] | None = None

    @property
    def ports(self) -> list[int]:
        """The current port values, in the order given by ``port_names``."""
        return [getattr(self, name) for name in port_names]

    # The Session to use for communication with the kernel.
    session = Instance("jupyter_client.session.Session")

    def _session_default(self) -> Session:
        """Default for ``session``: a new ``Session`` with this object as parent."""
        from .session import Session

        return Session(parent=self)

    # --------------------------------------------------------------------------
    # Connection and ipc file management
    # --------------------------------------------------------------------------

    def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
        """Return the connection info as a dict

        Parameters
        ----------
        session : bool [default: False]
            If True, a clone of our session object will be included in the
            connection info under the ``"session"`` key.
            If False (default), the configuration parameters of our session object
            (``signature_scheme`` and ``key``) will be included,
            rather than the session object itself.

        Returns
        -------
        connect_info : dict
            dictionary of connection information.
        """
        info = {
            "transport": self.transport,
            "ip": self.ip,
            "shell_port": self.shell_port,
            "iopub_port": self.iopub_port,
            "stdin_port": self.stdin_port,
            "hb_port": self.hb_port,
            "control_port": self.control_port,
        }
        if session:
            # add *clone* of my session,
            # so that state such as digest_history is not shared.
            info["session"] = self.session.clone()
        else:
            # add session info
            info.update(
                {
                    "signature_scheme": self.session.signature_scheme,
                    "key": self.session.key,
                }
            )
        return info

    # factory for blocking clients
    blocking_class = Type(klass=object, default_value="jupyter_client.BlockingKernelClient")

    def blocking_client(self) -> BlockingKernelClient:
        """Make a blocking client connected to my kernel"""
        info = self.get_connection_info()
        bc = self.blocking_class(parent=self)  # type:ignore[operator]
        bc.load_connection_info(info)
        return bc

    def cleanup_connection_file(self) -> None:
        """Cleanup connection file *if we wrote it*

        Will not raise if the connection file was already removed somehow.
        """
        if self._connection_file_written:
            # cleanup connection files on full shutdown of kernel we started
            self._connection_file_written = False
            try:
                os.remove(self.connection_file)
            except (OSError, AttributeError):
                pass

    def cleanup_ipc_files(self) -> None:
        """Cleanup ipc files if we wrote them.

        Only acts for the ipc transport; removes ``"<ip>-<port>"`` for every
        port and ignores files that cannot be removed.
        """
        if self.transport != "ipc":
            return
        for port in self.ports:
            ipcfile = "%s-%i" % (self.ip, port)
            try:
                os.remove(ipcfile)
            except OSError:
                pass

    def _record_random_port_names(self) -> None:
        """Records which of the ports are randomly assigned.

        Records on first invocation, if the transport is tcp.
        Does nothing on later invocations."""

        if self.transport != "tcp":
            return
        if self._random_port_names is not None:
            return

        # A port whose current value is <= 0 has not been set explicitly.
        self._random_port_names = []
        for name in port_names:
            if getattr(self, name) <= 0:
                self._random_port_names.append(name)

    def cleanup_random_ports(self) -> None:
        """Forgets randomly assigned port numbers and cleans up the connection file.

        Does nothing if no port numbers have been randomly assigned.
        In particular, does nothing unless the transport is tcp.
        """

        if not self._random_port_names:
            return

        for name in self._random_port_names:
            setattr(self, name, 0)

        self.cleanup_connection_file()

    def write_connection_file(self, **kwargs: Any) -> None:
        """Write connection info to JSON dict in self.connection_file.

        Does nothing if this object already wrote the file and it still
        exists.  Extra keyword arguments are forwarded to the module-level
        ``write_connection_file`` function.
        """
        if self._connection_file_written and os.path.exists(self.connection_file):
            return

        self.connection_file, cfg = write_connection_file(
            self.connection_file,
            transport=self.transport,
            ip=self.ip,
            key=self.session.key,
            stdin_port=self.stdin_port,
            iopub_port=self.iopub_port,
            shell_port=self.shell_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
            signature_scheme=self.session.signature_scheme,
            kernel_name=self.kernel_name,
            **kwargs,
        )
        # write_connection_file also sets default ports:
        # (record the random ones *before* overwriting the zero values)
        self._record_random_port_names()
        for name in port_names:
            setattr(self, name, cfg[name])

        self._connection_file_written = True

    def load_connection_file(self, connection_file: str | None = None) -> None:
        """Load connection info from JSON dict in self.connection_file.

        Parameters
        ----------
        connection_file: unicode, optional
            Path to connection file to load.
            If unspecified, use self.connection_file
        """
        if connection_file is None:
            connection_file = self.connection_file
        self.log.debug("Loading connection file %s", connection_file)
        with open(connection_file) as f:
            info = json.load(f)
        self.load_connection_info(info)

    def load_connection_info(self, info: KernelConnectionInfo) -> None:
        """Load connection info from a dict containing connection info.

        Typically this data comes from a connection file
        and is called by load_connection_file.

        Parameters
        ----------
        info: dict
            Dictionary containing connection_info.
            See the connection_file spec for details.
        """
        self.transport = info.get("transport", self.transport)
        self.ip = info.get("ip", self._ip_default())  # type:ignore[assignment]

        self._record_random_port_names()
        for name in port_names:
            if getattr(self, name) == 0 and name in info:
                # not overridden by config or cl_args
                setattr(self, name, info[name])

        if "key" in info:
            key = info["key"]
            # The session key is stored as bytes; a str key is encoded first.
            if isinstance(key, str):
                key = key.encode()
            assert isinstance(key, bytes)

            self.session.key = key
        if "signature_scheme" in info:
            self.session.signature_scheme = info["signature_scheme"]

    def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
        """Reconciles the connection information returned from the Provisioner.

        Because some provisioners (like derivations of LocalProvisioner) may have already
        written the connection file, this method needs to ensure that, if the connection
        file exists, its contents match that of what was returned by the provisioner. If
        the file does exist and its contents do not match, the file will be replaced with
        the provisioner information (which is considered the truth).

        If the file does not exist, the connection information in 'info' is loaded into the
        KernelManager and written to the file.

        Raises
        ------
        ValueError
            If, after reconciliation, this object's connection info still
            differs from ``info``.
        """
        # Prevent over-writing a file that has already been written with the same
        # info.  This is to prevent a race condition where the process has
        # already been launched but has not yet read the connection file - as is
        # the case with LocalProvisioners.
        file_exists: bool = False
        if os.path.exists(self.connection_file):
            with open(self.connection_file) as f:
                file_info = json.load(f)
            # Prior to the following comparison, we need to adjust the value of "key" to
            # be bytes, otherwise the comparison below will fail.
            # A missing or non-string key is left untouched, so that such a
            # file is handled as a content mismatch instead of raising here.
            file_key = file_info.get("key")
            if isinstance(file_key, str):
                file_info["key"] = file_key.encode()
            if not self._equal_connections(info, file_info):
                os.remove(self.connection_file)  # Contents mismatch - remove the file
                self._connection_file_written = False
            else:
                file_exists = True

        if not file_exists:
            # Load the connection info and write out file, clearing existing
            # port-based attributes so they will be reloaded
            for name in port_names:
                setattr(self, name, 0)
            self.load_connection_info(info)
            self.write_connection_file()

        # Ensure what is in KernelManager is what we expect.
        km_info = self.get_connection_info()
        if not self._equal_connections(info, km_info):
            msg = (
                "KernelManager's connection information already exists and does not match "
                "the expected values returned from provisioner!"
            )
            raise ValueError(msg)

    @staticmethod
    def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo) -> bool:
        """Compares pertinent keys of connection info data. Returns True if equivalent, False otherwise.

        A key that is absent from both dicts compares equal (both ``None``).
        """

        pertinent_keys = [
            "key",
            "ip",
            "stdin_port",
            "iopub_port",
            "shell_port",
            "control_port",
            "hb_port",
            "transport",
            "signature_scheme",
        ]

        return all(conn1.get(key) == conn2.get(key) for key in pertinent_keys)

    # --------------------------------------------------------------------------
    # Creating connected sockets
    # --------------------------------------------------------------------------

    def _make_url(self, channel: str) -> str:
        """Make a ZeroMQ URL for a given channel.

        tcp URLs have the form ``tcp://<ip>:<port>``; for any other
        transport the form is ``<transport>://<ip>-<port>``.
        """
        transport = self.transport
        ip = self.ip
        port = getattr(self, "%s_port" % channel)

        if transport == "tcp":
            return "tcp://%s:%i" % (ip, port)
        else:
            return f"{transport}://{ip}-{port}"

    def _create_connected_socket(
        self, channel: str, identity: bytes | None = None
    ) -> zmq.sugar.socket.Socket:
        """Create a zmq Socket and connect it to the kernel.

        The socket type is taken from ``channel_socket_types``; ``identity``
        is applied to the socket only when it is non-empty.
        """
        url = self._make_url(channel)
        socket_type = channel_socket_types[channel]
        self.log.debug("Connecting to: %s", url)
        sock = self.context.socket(socket_type)
        # set linger to 1s to prevent hangs at exit
        sock.linger = 1000
        if identity:
            sock.identity = identity
        sock.connect(url)
        return sock

    def connect_iopub(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the IOPub channel"""
        sock = self._create_connected_socket("iopub", identity=identity)
        # An empty prefix subscribes to every message.
        sock.setsockopt(zmq.SUBSCRIBE, b"")
        return sock

    def connect_shell(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Shell channel"""
        return self._create_connected_socket("shell", identity=identity)

    def connect_stdin(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the StdIn channel"""
        return self._create_connected_socket("stdin", identity=identity)

    def connect_hb(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Heartbeat channel"""
        return self._create_connected_socket("hb", identity=identity)

    def connect_control(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Control channel"""
        return self._create_connected_socket("control", identity=identity)
684
685
class LocalPortCache(SingletonConfigurable):
    """
    Used to keep track of local ports in order to prevent race conditions that
    can occur between port acquisition and usage by the kernel.  All locally-
    provisioned kernels should use this mechanism to limit the possibility of
    race conditions.  Note that this does not preclude other applications from
    acquiring a cached but unused port, thereby re-introducing the issue this
    class is attempting to resolve (minimize).
    See: https://github.com/jupyter/jupyter_client/issues/487
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the cache with an empty set of in-use ports."""
        super().__init__(**kwargs)
        self.currently_used_ports: set[int] = set()

    def find_available_port(self, ip: str) -> int:
        """Return a port on ``ip`` that this cache has not handed out yet.

        A temporary socket is bound to port 0 on ``ip`` to obtain a port
        number; the search repeats until the number is not already in
        ``currently_used_ports``.  The returned port is recorded there
        until given back with :meth:`return_port`.
        """
        while True:
            # The context manager closes the probe socket even when
            # setsockopt/bind raises, so no descriptor is leaked.
            with socket.socket() as tmp_sock:
                tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                tmp_sock.bind((ip, 0))
                port = tmp_sock.getsockname()[1]

            # This is a workaround for https://github.com/jupyter/jupyter_client/issues/487
            # We prevent two kernels to have the same ports.
            if port not in self.currently_used_ports:
                self.currently_used_ports.add(port)
                return port

    def return_port(self, port: int) -> None:
        """Mark ``port`` as no longer in use; unknown ports are ignored."""
        # discard (rather than remove) tolerates uncached ports
        self.currently_used_ports.discard(port)
718
719
# Names exported by ``from <this module> import *``.
# NOTE(review): ConnectionFileMixin is defined in this module but is not
# listed here — confirm that the omission is intentional.
__all__ = [
    "KernelConnectionInfo",
    "LocalPortCache",
    "find_connection_file",
    "tunnel_to_kernel",
    "write_connection_file",
]