1"""Utilities for connecting to jupyter kernels
2
The :class:`ConnectionFileMixin` class in this module encapsulates the logic
related to writing and reading connection files.
5"""
6# Copyright (c) Jupyter Development Team.
7# Distributed under the terms of the Modified BSD License.
8from __future__ import annotations
9
10import errno
11import glob
12import json
13import os
14import socket
15import stat
16import tempfile
17import warnings
18from getpass import getpass
19from typing import TYPE_CHECKING, Any, Dict, Union, cast
20
21import zmq
22from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write
23from traitlets import Bool, CaselessStrEnum, Instance, Integer, Type, Unicode, observe
24from traitlets.config import LoggingConfigurable, SingletonConfigurable
25
26from .localinterfaces import localhost
27from .utils import _filefind
28
29if TYPE_CHECKING:
30 from jupyter_client import BlockingKernelClient
31
32 from .session import Session
33
# Type alias for a kernel's connection info: a dict mapping field names such as
# "shell_port", "ip", "key", "transport" and "signature_scheme" to their values.
# Ports are ints and most other fields are str. "key" may be bytes (the in-memory
# Session key, as returned by ConnectionFileMixin.get_connection_info) or str
# (as written to / read from a JSON connection file).
KernelConnectionInfo = Dict[str, Union[int, str, bytes]]
36
37
def write_connection_file(
    fname: str | None = None,
    shell_port: int = 0,
    iopub_port: int = 0,
    stdin_port: int = 0,
    hb_port: int = 0,
    control_port: int = 0,
    ip: str = "",
    key: bytes | str = b"",
    transport: str = "tcp",
    signature_scheme: str = "hmac-sha256",
    kernel_name: str = "",
    **kwargs: Any,
) -> tuple[str, KernelConnectionInfo]:
    """Generate a JSON connection file, selecting random ports where needed.

    Parameters
    ----------

    fname : str, optional
        The path to the file to write. If empty or None, a new temporary
        ``.json`` file is created with :func:`tempfile.mkstemp` and used.

    shell_port : int, optional
        The port to use for the ROUTER (shell) channel.
        A value <= 0 means "pick one automatically" (likewise for all ports).

    iopub_port : int, optional
        The port to use for the IOPub (PUB) channel.

    stdin_port : int, optional
        The port to use for the ROUTER (raw input) channel.

    control_port : int, optional
        The port to use for the ROUTER (control) channel.

    hb_port : int, optional
        The port to use for the heartbeat REP channel.

    ip : str, optional
        The ip address the kernel will bind to. Defaults to ``localhost()``.
        For non-tcp transports this is used as a file-path prefix.

    key : bytes or str, optional
        The Session key used for message authentication. Bytes are decoded
        to str before being stored in the JSON file.

    transport : str, optional
        The zmq transport, ``"tcp"`` by default. For any other transport
        (e.g. ``"ipc"``) the "ports" are integer suffixes N chosen so that
        ``f"{ip}-{N}"`` does not exist yet.

    signature_scheme : str, optional
        The scheme used for message authentication.
        This has the form 'digest-hash', where 'digest'
        is the scheme used for digests, and 'hash' is the name of the hash function
        used by the digest scheme.
        Currently, 'hmac' is the only supported digest scheme,
        and 'sha256' is the default hash function.

    kernel_name : str, optional
        The name of the kernel currently connected to.

    **kwargs
        Extra entries merged into the connection info last, so they can
        override any of the fields above.

    Returns
    -------
    (fname, cfg) : tuple
        The path that was written and the connection-info dict written to it.
    """
    if not ip:
        ip = localhost()
    # default to temporary connector file
    if not fname:
        fd, fname = tempfile.mkstemp(".json")
        os.close(fd)

    # Find open ports as necessary.
    ports_needed = sum(
        1 for port in (shell_port, iopub_port, stdin_port, control_port, hb_port) if port <= 0
    )
    ports: list[int] = []
    if transport == "tcp":
        sockets: list[socket.socket] = []
        try:
            # Keep every socket bound until all ports have been read, so the
            # OS cannot hand out the same port twice within this call.
            for _ in range(ports_needed):
                sock = socket.socket()
                # Register before configuring/binding so it is closed even if
                # bind() fails (e.g. an unavailable ip).
                sockets.append(sock)
                # struct.pack('ii', (0,0)) is 8 null bytes
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                sock.bind((ip, 0))
            ports = [sock.getsockname()[1] for sock in sockets]
        finally:
            for sock in sockets:
                sock.close()
    else:
        n = 1
        for _ in range(ports_needed):
            while os.path.exists(f"{ip}-{n}"):
                n += 1
            ports.append(n)
            n += 1

    # Hand out the free ports in a fixed channel order.
    free_ports = iter(ports)
    if shell_port <= 0:
        shell_port = next(free_ports)
    if iopub_port <= 0:
        iopub_port = next(free_ports)
    if stdin_port <= 0:
        stdin_port = next(free_ports)
    if control_port <= 0:
        control_port = next(free_ports)
    if hb_port <= 0:
        hb_port = next(free_ports)

    cfg: KernelConnectionInfo = {
        "shell_port": shell_port,
        "iopub_port": iopub_port,
        "stdin_port": stdin_port,
        "control_port": control_port,
        "hb_port": hb_port,
    }
    cfg["ip"] = ip
    # JSON cannot hold bytes; accept an already-decoded str key as well.
    cfg["key"] = key.decode() if isinstance(key, bytes) else key
    cfg["transport"] = transport
    cfg["signature_scheme"] = signature_scheme
    cfg["kernel_name"] = kernel_name
    cfg.update(kwargs)

    # Only ever write this file as user read/writeable
    # This would otherwise introduce a vulnerability as a file has secrets
    # which would let others execute arbitrary code as you
    with secure_write(fname) as f:
        f.write(json.dumps(cfg, indent=2))

    if hasattr(stat, "S_ISVTX"):
        # set the sticky bit on the parent directory of the file
        # to ensure only owner can remove it
        runtime_dir = os.path.dirname(fname)
        if runtime_dir:
            permissions = os.stat(runtime_dir).st_mode
            new_permissions = permissions | stat.S_ISVTX
            if new_permissions != permissions:
                try:
                    os.chmod(runtime_dir, new_permissions)
                except OSError as e:
                    if e.errno == errno.EPERM:
                        # suppress permission errors setting sticky bit on runtime_dir,
                        # which we may not own.
                        pass
                    # NOTE(review): other OSErrors also fall through here and are
                    # suppressed; setting the sticky bit is best-effort only.
    return fname, cfg
175
176
def find_connection_file(
    filename: str = "kernel-*.json",
    path: str | list[str] | None = None,
    profile: str | None = None,
) -> str:
    """Find a connection file and return its absolute path.

    The current working directory and optional search path
    will be searched for the file if it is not given by absolute path.

    If the argument does not match an existing file, it will be interpreted as a
    fileglob (or, if it contains no ``*``, as a substring to match), and the
    matching file with the latest access time will be used.

    Parameters
    ----------
    filename : str
        The connection file or fileglob to search for.
    path : str or list of strs[optional]
        Paths in which to search for connection files.
        Defaults to the current directory and the Jupyter runtime dir.
    profile : str, optional
        Ignored; accepted for backward compatibility. Passing it emits a warning.

    Returns
    -------
    str : The absolute path of the connection file.

    Raises
    ------
    OSError
        If no matching file is found.
    """
    if profile is not None:
        warnings.warn(
            f"Jupyter has no profiles. profile={profile} has been ignored.", stacklevel=2
        )
    if path is None:
        path = [".", jupyter_runtime_dir()]
    if isinstance(path, str):
        path = [path]

    try:
        # first, try explicit name
        return _filefind(filename, path)
    except OSError:
        pass

    # not found by full name: use it as a glob if it already is one,
    # otherwise accept any substring match
    pat = filename if "*" in filename else f"*{filename}*"

    matches = [os.path.abspath(m) for p in path for m in glob.glob(os.path.join(p, pat))]
    if not matches:
        msg = f"Could not find {filename!r} in {path!r}"
        raise OSError(msg)
    if len(matches) == 1:
        return matches[0]
    # Get the most recent match by access time. Scanning in reverse makes
    # ties resolve to the match found last.
    return max(reversed(matches), key=lambda f: os.stat(f).st_atime)
239
240
def tunnel_to_kernel(
    connection_info: str | KernelConnectionInfo,
    sshserver: str,
    sshkey: str | None = None,
) -> tuple[Any, ...]:
    """Forward a kernel's five channel ports to this machine over SSH.

    One SSH tunnel is opened per channel, from a randomly chosen local port to
    the corresponding kernel port. The tunnels may be direct localhost-to-localhost
    tunnels; if an intermediate server is necessary, the kernel must be
    listening on a public IP.

    Parameters
    ----------
    connection_info : dict or str (path)
        A connection dict, or the path to a JSON connection file.
    sshserver : str
        The ssh server to use to tunnel to the kernel. Can be a full
        `user@server:port` string. ssh config aliases are respected.
    sshkey : str [optional]
        Path to file containing ssh key to use for authentication.
        Only necessary if your ssh config does not already associate
        a keyfile with the host.

    Returns
    -------

    (shell, iopub, stdin, hb, control) : ints
        The five ports on localhost that have been forwarded to the kernel.
    """
    from .ssh import tunnel

    if isinstance(connection_info, str):
        # A string is taken to be the path of a JSON connection file.
        with open(connection_info) as fh:
            connection_info = json.load(fh)

    info = cast(Dict[str, Any], connection_info)

    local_ports = tunnel.select_random_ports(5)
    remote_ports = tuple(
        info[f"{channel}_port"] for channel in ("shell", "iopub", "stdin", "hb", "control")
    )
    remote_ip = info["ip"]

    # Only prompt for a password when key-based login is not possible.
    password: bool | str = (
        False
        if tunnel.try_passwordless_ssh(sshserver, sshkey)
        else getpass("SSH Password for %s: " % sshserver)
    )

    for local, remote in zip(local_ports, remote_ports):
        tunnel.ssh_tunnel(local, remote, sshserver, remote_ip, sshkey, password)

    return tuple(local_ports)
300
301
302# -----------------------------------------------------------------------------
303# Mixin for classes that work with connection files
304# -----------------------------------------------------------------------------
305
# zmq socket type used on the client side of each kernel channel.
channel_socket_types = dict(
    hb=zmq.REQ,
    shell=zmq.DEALER,
    iopub=zmq.SUB,
    stdin=zmq.DEALER,
    control=zmq.DEALER,
)
313
# Trait/connection-info names of every channel port, in a fixed order.
port_names = [f"{name}_port" for name in ("shell", "stdin", "iopub", "hb", "control")]
315
316
class ConnectionFileMixin(LoggingConfigurable):
    """Mixin for configurable classes that work with connection files.

    Instances carry the transport, IP address, per-channel port numbers and the
    Session used to talk to a kernel. The mixin can write that information to a
    JSON connection file, load it back, reconcile it with information returned
    by a provisioner, and create zmq sockets connected to each kernel channel.
    """

    data_dir: str | Unicode = Unicode()

    def _data_dir_default(self) -> str:
        return jupyter_data_dir()

    # The addresses for the communication channels
    connection_file = Unicode(
        "",
        config=True,
        help="""JSON file in which to store connection info [default: kernel-<pid>.json]

    This file will contain the IP, ports, and authentication key needed to connect
    clients to this kernel. By default, this file will be created in the security dir
    of the current profile, but can be specified by absolute path.
    """,
    )
    # True once this object wrote `connection_file` itself; only then does
    # cleanup_connection_file() remove it.
    _connection_file_written = Bool(False)

    transport = CaselessStrEnum(["tcp", "ipc"], default_value="tcp", config=True)
    kernel_name: str | Unicode = Unicode()

    context = Instance(zmq.Context)

    ip = Unicode(
        config=True,
        help="""Set the kernel\'s IP address [default localhost].
        If the IP address is something other than localhost, then
        Consoles on other machines will be able to connect
        to the Kernel, so be careful!""",
    )

    def _ip_default(self) -> str:
        """Default `ip`: a path prefix for ipc transport, otherwise localhost()."""
        if self.transport == "ipc":
            if self.connection_file:
                return os.path.splitext(self.connection_file)[0] + "-ipc"
            else:
                return "kernel-ipc"
        else:
            return localhost()

    @observe("ip")
    def _ip_changed(self, change: Any) -> None:
        """Normalize the wildcard address ``*`` to ``0.0.0.0``."""
        if change["new"] == "*":
            self.ip = "0.0.0.0"  # noqa

    # protected traits

    hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
    shell_port = Integer(0, config=True, help="set the shell (ROUTER) port [default: random]")
    iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]")
    stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]")
    control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]")

    # names of the ports with random assignment
    _random_port_names: list[str] | None = None

    @property
    def ports(self) -> list[int]:
        """Current port numbers, in ``port_names`` order."""
        return [getattr(self, name) for name in port_names]

    # The Session to use for communication with the kernel.
    session = Instance("jupyter_client.session.Session")

    def _session_default(self) -> Session:
        from .session import Session

        return Session(parent=self)

    # --------------------------------------------------------------------------
    # Connection and ipc file management
    # --------------------------------------------------------------------------

    def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
        """Return the connection info as a dict

        Parameters
        ----------
        session : bool [default: False]
            If True, a clone of our session object is included in the connection info.
            If False (default), the configuration parameters of our session object
            (signature_scheme and key) are included instead of the object itself.

        Returns
        -------
        connect_info : dict
            dictionary of connection information.
        """
        info = {
            "transport": self.transport,
            "ip": self.ip,
            "shell_port": self.shell_port,
            "iopub_port": self.iopub_port,
            "stdin_port": self.stdin_port,
            "hb_port": self.hb_port,
            "control_port": self.control_port,
        }
        if session:
            # add *clone* of my session,
            # so that state such as digest_history is not shared.
            info["session"] = self.session.clone()
        else:
            # add session info
            info.update(
                {
                    "signature_scheme": self.session.signature_scheme,
                    "key": self.session.key,
                }
            )
        return info

    # factory for blocking clients
    blocking_class = Type(klass=object, default_value="jupyter_client.BlockingKernelClient")

    def blocking_client(self) -> BlockingKernelClient:
        """Make a blocking client connected to my kernel"""
        info = self.get_connection_info()
        bc = self.blocking_class(parent=self)  # type:ignore[operator]
        bc.load_connection_info(info)
        return bc

    def cleanup_connection_file(self) -> None:
        """Cleanup connection file *if we wrote it*

        Will not raise if the connection file was already removed somehow.
        """
        if self._connection_file_written:
            # cleanup connection files on full shutdown of kernel we started
            self._connection_file_written = False
            try:
                os.remove(self.connection_file)
            except (OSError, AttributeError):
                pass

    def cleanup_ipc_files(self) -> None:
        """Cleanup ipc files if we wrote them.

        Removes ``<ip>-<port>`` for every port; missing files are ignored.
        Does nothing unless the transport is ipc.
        """
        if self.transport != "ipc":
            return
        for port in self.ports:
            ipcfile = "%s-%i" % (self.ip, port)
            try:
                os.remove(ipcfile)
            except OSError:
                pass

    def _record_random_port_names(self) -> None:
        """Records which of the ports are randomly assigned.

        Records on first invocation, if the transport is tcp.
        Does nothing on later invocations."""

        if self.transport != "tcp":
            return
        if self._random_port_names is not None:
            return

        # A port still at 0 (not set by config or the command line) will be
        # filled in automatically, so it counts as random.
        self._random_port_names = []
        for name in port_names:
            if getattr(self, name) <= 0:
                self._random_port_names.append(name)

    def cleanup_random_ports(self) -> None:
        """Forgets randomly assigned port numbers and cleans up the connection file.

        Does nothing if no port numbers have been randomly assigned.
        In particular, does nothing unless the transport is tcp.
        """

        if not self._random_port_names:
            return

        for name in self._random_port_names:
            setattr(self, name, 0)

        self.cleanup_connection_file()

    def write_connection_file(self, **kwargs: Any) -> None:
        """Write connection info to JSON dict in self.connection_file.

        Does nothing if this object already wrote the file and it still exists.
        Ports chosen by the module-level ``write_connection_file`` are stored
        back on this object. Extra keyword arguments are forwarded to it.
        """
        if self._connection_file_written and os.path.exists(self.connection_file):
            return

        self.connection_file, cfg = write_connection_file(
            self.connection_file,
            transport=self.transport,
            ip=self.ip,
            key=self.session.key,
            stdin_port=self.stdin_port,
            iopub_port=self.iopub_port,
            shell_port=self.shell_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
            signature_scheme=self.session.signature_scheme,
            kernel_name=self.kernel_name,
            **kwargs,
        )
        # write_connection_file also sets default ports:
        # record which were random *before* overwriting the zero values.
        self._record_random_port_names()
        for name in port_names:
            setattr(self, name, cfg[name])

        self._connection_file_written = True

    def load_connection_file(self, connection_file: str | None = None) -> None:
        """Load connection info from JSON dict in self.connection_file.

        Parameters
        ----------
        connection_file: unicode, optional
            Path to connection file to load.
            If unspecified, use self.connection_file
        """
        if connection_file is None:
            connection_file = self.connection_file
        self.log.debug("Loading connection file %s", connection_file)
        with open(connection_file) as f:
            info = json.load(f)
        self.load_connection_info(info)

    def load_connection_info(self, info: KernelConnectionInfo) -> None:
        """Load connection info from a dict containing connection info.

        Typically this data comes from a connection file
        and is called by load_connection_file.

        Ports already set to a non-zero value (e.g. by config or command-line
        arguments) are not overwritten. A str "key" is encoded to bytes.

        Parameters
        ----------
        info: dict
            Dictionary containing connection_info.
            See the connection_file spec for details.
        """
        self.transport = info.get("transport", self.transport)
        self.ip = info.get("ip", self._ip_default())  # type:ignore[assignment]

        self._record_random_port_names()
        for name in port_names:
            if getattr(self, name) == 0 and name in info:
                # not overridden by config or cl_args
                setattr(self, name, info[name])

        if "key" in info:
            key = info["key"]
            if isinstance(key, str):
                key = key.encode()
            assert isinstance(key, bytes)

            self.session.key = key
        if "signature_scheme" in info:
            self.session.signature_scheme = info["signature_scheme"]

    def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
        """Reconciles the connection information returned from the Provisioner.

        Because some provisioners (like derivations of LocalProvisioner) may have already
        written the connection file, this method needs to ensure that, if the connection
        file exists, its contents match that of what was returned by the provisioner. If
        the file does exist and its contents do not match, the file will be replaced with
        the provisioner information (which is considered the truth).

        If the file does not exist, the connection information in 'info' is loaded into the
        KernelManager and written to the file.

        Raises
        ------
        ValueError
            If, afterwards, this object's connection info still differs from `info`.
        """
        # Prevent over-writing a file that has already been written with the same
        # info. This is to prevent a race condition where the process has
        # already been launched but has not yet read the connection file - as is
        # the case with LocalProvisioners.
        file_exists: bool = False
        if os.path.exists(self.connection_file):
            with open(self.connection_file) as f:
                file_info = json.load(f)
            # Prior to the following comparison, we need to adjust the value of "key" to
            # be bytes, otherwise the comparison below will fail. A missing or
            # non-string key is compared unchanged rather than raising.
            file_key = file_info.get("key")
            if isinstance(file_key, str):
                file_info["key"] = file_key.encode()
            if not self._equal_connections(info, file_info):
                os.remove(self.connection_file)  # Contents mismatch - remove the file
                self._connection_file_written = False
            else:
                file_exists = True

        if not file_exists:
            # Load the connection info and write out file, clearing existing
            # port-based attributes so they will be reloaded
            for name in port_names:
                setattr(self, name, 0)
            self.load_connection_info(info)
            self.write_connection_file()

        # Ensure what is in KernelManager is what we expect.
        km_info = self.get_connection_info()
        if not self._equal_connections(info, km_info):
            msg = (
                "KernelManager's connection information already exists and does not match "
                "the expected values returned from provisioner!"
            )
            raise ValueError(msg)

    @staticmethod
    def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo) -> bool:
        """Compares pertinent keys of connection info data. Returns True if equivalent, False otherwise."""

        pertinent_keys = [
            "key",
            "ip",
            "stdin_port",
            "iopub_port",
            "shell_port",
            "control_port",
            "hb_port",
            "transport",
            "signature_scheme",
        ]

        return all(conn1.get(key) == conn2.get(key) for key in pertinent_keys)

    # --------------------------------------------------------------------------
    # Creating connected sockets
    # --------------------------------------------------------------------------

    def _make_url(self, channel: str) -> str:
        """Make a ZeroMQ URL for a given channel.

        ``tcp://<ip>:<port>`` for tcp, ``<transport>://<ip>-<port>`` otherwise.
        """
        transport = self.transport
        ip = self.ip
        port = getattr(self, "%s_port" % channel)

        if transport == "tcp":
            return "tcp://%s:%i" % (ip, port)
        else:
            return f"{transport}://{ip}-{port}"

    def _create_connected_socket(
        self, channel: str, identity: bytes | None = None
    ) -> zmq.sugar.socket.Socket:
        """Create a zmq Socket and connect it to the kernel.

        The socket type comes from ``channel_socket_types``; `identity` is
        applied only if it is non-empty.
        """
        url = self._make_url(channel)
        socket_type = channel_socket_types[channel]
        self.log.debug("Connecting to: %s", url)
        sock = self.context.socket(socket_type)
        # set linger to 1s to prevent hangs at exit
        sock.linger = 1000
        if identity:
            sock.identity = identity
        sock.connect(url)
        return sock

    def connect_iopub(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the IOPub channel, subscribed to all topics"""
        sock = self._create_connected_socket("iopub", identity=identity)
        sock.setsockopt(zmq.SUBSCRIBE, b"")
        return sock

    def connect_shell(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Shell channel"""
        return self._create_connected_socket("shell", identity=identity)

    def connect_stdin(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the StdIn channel"""
        return self._create_connected_socket("stdin", identity=identity)

    def connect_hb(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Heartbeat channel"""
        return self._create_connected_socket("hb", identity=identity)

    def connect_control(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Control channel"""
        return self._create_connected_socket("control", identity=identity)
683
684
class LocalPortCache(SingletonConfigurable):
    """
    Used to keep track of local ports in order to prevent race conditions that
    can occur between port acquisition and usage by the kernel. All locally-
    provisioned kernels should use this mechanism to limit the possibility of
    race conditions. Note that this does not preclude other applications from
    acquiring a cached but unused port, thereby re-introducing the issue this
    class is attempting to resolve (minimize).
    See: https://github.com/jupyter/jupyter_client/issues/487
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Ports handed out by find_available_port() and not yet returned.
        self.currently_used_ports: set[int] = set()

    def find_available_port(self, ip: str) -> int:
        """Return an OS-assigned free TCP port on `ip` not already handed out by this cache.

        The port is added to ``currently_used_ports`` until ``return_port`` is called.
        """
        while True:
            # The context manager closes the probe socket even if bind() fails.
            with socket.socket() as tmp_sock:
                tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                tmp_sock.bind((ip, 0))
                port = tmp_sock.getsockname()[1]

            # This is a workaround for https://github.com/jupyter/jupyter_client/issues/487
            # We prevent two kernels to have the same ports.
            if port not in self.currently_used_ports:
                self.currently_used_ports.add(port)
                return port

    def return_port(self, port: int) -> None:
        """Release `port` back to the pool. Ports not in the cache are ignored."""
        self.currently_used_ports.discard(port)  # Tolerate uncached ports
717
718
# Public API of this module: the names exported by a star-import.
__all__ = [
    "write_connection_file",
    "find_connection_file",
    "tunnel_to_kernel",
    "KernelConnectionInfo",
    "LocalPortCache",
]