Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jupyter_client/connect.py: 24%
288 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-03 06:10 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-03 06:10 +0000
1"""Utilities for connecting to jupyter kernels
3The :class:`ConnectionFileMixin` class in this module encapsulates the logic
4related to writing and reading connections files.
5"""
6# Copyright (c) Jupyter Development Team.
7# Distributed under the terms of the Modified BSD License.
8import errno
9import glob
10import json
11import os
12import socket
13import stat
14import tempfile
15import warnings
16from getpass import getpass
17from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
19import zmq
20from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write
21from traitlets import Bool, CaselessStrEnum, Instance, Integer, Type, Unicode, observe
22from traitlets.config import LoggingConfigurable, SingletonConfigurable
24from .localinterfaces import localhost
25from .utils import _filefind
# Value types that may appear in a kernel connection-info mapping.
_ConnectionInfoValue = Union[int, str, bytes]
# A kernel's connection info: string keys mapped to int, str or bytes values.
KernelConnectionInfo = Dict[str, _ConnectionInfoValue]
def write_connection_file(
    fname: Optional[str] = None,
    shell_port: int = 0,
    iopub_port: int = 0,
    stdin_port: int = 0,
    hb_port: int = 0,
    control_port: int = 0,
    ip: str = "",
    key: bytes = b"",
    transport: str = "tcp",
    signature_scheme: str = "hmac-sha256",
    kernel_name: str = "",
) -> Tuple[str, KernelConnectionInfo]:
    """Generates a JSON config file, including the selection of random ports.

    Parameters
    ----------

    fname : unicode, optional
        The path to the file to write. If empty or not given, a new temporary
        ``.json`` file is created with :func:`tempfile.mkstemp` and used instead.

    shell_port : int, optional
        The port to use for ROUTER (shell) channel.

    iopub_port : int, optional
        The port to use for the SUB channel.

    stdin_port : int, optional
        The port to use for the ROUTER (raw input) channel.

    control_port : int, optional
        The port to use for the ROUTER (control) channel.

    hb_port : int, optional
        The port to use for the heartbeat REP channel.

        For every port argument, a value <= 0 means "choose one automatically".

    ip : str, optional
        The ip address the kernel will bind to. Defaults to ``localhost()``.

    key : bytes, optional
        The Session key used for message authentication.
        It is decoded to ``str`` before being written to the file.

    transport : str, optional
        With ``"tcp"`` (the default), free ports are found by binding temporary
        sockets on `ip`. With any other transport, ports are small integers
        ``N`` chosen so that no path ``"<ip>-<N>"`` already exists.

    signature_scheme : str, optional
        The scheme used for message authentication.
        This has the form 'digest-hash', where 'digest'
        is the scheme used for digests, and 'hash' is the name of the hash function
        used by the digest scheme.
        Currently, 'hmac' is the only supported digest scheme,
        and 'sha256' is the default hash function.

    kernel_name : str, optional
        The name of the kernel currently connected to.

    Returns
    -------
    (fname, cfg) : tuple
        The path of the written file and the connection-info dict it contains.
    """
    if not ip:
        ip = localhost()
    # default to temporary connector file
    if not fname:
        fd, fname = tempfile.mkstemp(".json")
        os.close(fd)

    # Find open ports as necessary.

    ports: List[int] = []
    ports_needed = (
        int(shell_port <= 0)
        + int(iopub_port <= 0)
        + int(stdin_port <= 0)
        + int(control_port <= 0)
        + int(hb_port <= 0)
    )
    if transport == "tcp":
        sockets: List[socket.socket] = []
        try:
            for _ in range(ports_needed):
                sock = socket.socket()
                # Track the socket before configuring it, so the finally clause
                # closes it (and all earlier ones) even if setsockopt/bind raises.
                sockets.append(sock)
                # struct.pack('ii', (0,0)) is 8 null bytes
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                sock.bind((ip, 0))
            # Read every assigned port while all sockets are still bound, so the
            # OS should not hand out the same ephemeral port twice in this call.
            ports = [sock.getsockname()[1] for sock in sockets]
        finally:
            for sock in sockets:
                sock.close()
    else:
        N = 1
        for _ in range(ports_needed):
            while os.path.exists(f"{ip}-{N}"):
                N += 1
            ports.append(N)
            N += 1
    # Hand out the discovered ports in a fixed order to the unset channels.
    if shell_port <= 0:
        shell_port = ports.pop(0)
    if iopub_port <= 0:
        iopub_port = ports.pop(0)
    if stdin_port <= 0:
        stdin_port = ports.pop(0)
    if control_port <= 0:
        control_port = ports.pop(0)
    if hb_port <= 0:
        hb_port = ports.pop(0)

    cfg: KernelConnectionInfo = {
        "shell_port": shell_port,
        "iopub_port": iopub_port,
        "stdin_port": stdin_port,
        "control_port": control_port,
        "hb_port": hb_port,
    }
    cfg["ip"] = ip
    cfg["key"] = key.decode()
    cfg["transport"] = transport
    cfg["signature_scheme"] = signature_scheme
    cfg["kernel_name"] = kernel_name

    # Only ever write this file as user read/writeable
    # This would otherwise introduce a vulnerability as a file has secrets
    # which would let others execute arbitrary code as you
    with secure_write(fname) as f:
        f.write(json.dumps(cfg, indent=2))

    if hasattr(stat, "S_ISVTX"):
        # set the sticky bit on the parent directory of the file
        # to ensure only owner can remove it
        runtime_dir = os.path.dirname(fname)
        if runtime_dir:
            permissions = os.stat(runtime_dir).st_mode
            new_permissions = permissions | stat.S_ISVTX
            if new_permissions != permissions:
                try:
                    os.chmod(runtime_dir, new_permissions)
                except OSError as e:
                    # Setting the sticky bit is best effort: every OSError is
                    # ignored here, EPERM being the expected case.
                    if e.errno == errno.EPERM:
                        # suppress permission errors setting sticky bit on runtime_dir,
                        # which we may not own.
                        pass
    return fname, cfg
def find_connection_file(
    filename: str = "kernel-*.json",
    path: Optional[Union[str, List[str]]] = None,
    profile: Optional[str] = None,
) -> str:
    """find a connection file, and return its absolute path.

    `filename` is first looked up by its exact name in the search path
    (via ``_filefind``). If that fails, it is interpreted as a fileglob, and
    among all matches in the search path the file with the latest access
    time is used.

    Parameters
    ----------
    filename : str
        The connection file or fileglob to search for. A name containing no
        ``*`` is matched as a substring (glob ``*<filename>*``) when it is not
        found by its exact name.
    path : str or list of strs [optional]
        Paths in which to search for connection files. Defaults to the current
        working directory followed by the Jupyter runtime directory.
    profile : str [optional]
        Ignored, since Jupyter has no profiles. Passing a value emits a warning.

    Returns
    -------
    str : The absolute path of the connection file.

    Raises
    ------
    OSError
        If no file in the search path matches `filename`.
    """
    if profile is not None:
        warnings.warn(
            "Jupyter has no profiles. profile=%s has been ignored." % profile, stacklevel=2
        )
    if path is None:
        path = [".", jupyter_runtime_dir()]
    if isinstance(path, str):
        path = [path]

    try:
        # first, try explicit name
        return _filefind(filename, path)
    except OSError:
        pass

    # not found by full name

    if "*" in filename:
        # given as a glob already
        pat = filename
    else:
        # accept any substring match
        pat = "*%s*" % filename

    matches = []
    for p in path:
        matches.extend(glob.glob(os.path.join(p, pat)))

    matches = [os.path.abspath(m) for m in matches]
    if not matches:
        msg = f"Could not find {filename!r} in {path!r}"
        raise OSError(msg)
    elif len(matches) == 1:
        return matches[0]
    else:
        # get most recent match, by access time:
        return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
def tunnel_to_kernel(
    connection_info: Union[str, KernelConnectionInfo],
    sshserver: str,
    sshkey: Optional[str] = None,
) -> Tuple[Any, ...]:
    """tunnel connections to a kernel via ssh

    This will open five SSH tunnels from localhost on this machine to the
    ports associated with the kernel. They can be either direct
    localhost-localhost tunnels, or if an intermediate server is necessary,
    the kernel must be listening on a public IP.

    If passwordless ssh to `sshserver` is not possible, the user is prompted
    once for a password with :func:`getpass.getpass`, and that password is
    used for all five tunnels.

    Parameters
    ----------
    connection_info : dict or str (path)
        Either a connection dict, or the path to a JSON connection file
    sshserver : str
        The ssh server to use to tunnel to the kernel. Can be a full
        `user@server:port` string. ssh config aliases are respected.
    sshkey : str [optional]
        Path to file containing ssh key to use for authentication.
        Only necessary if your ssh config does not already associate
        a keyfile with the host.

    Returns
    -------

    (shell, iopub, stdin, hb, control) : ints
        The five ports on localhost that have been forwarded to the kernel.
    """
    from .ssh import tunnel

    if isinstance(connection_info, str):
        # it's a path, unpack it
        with open(connection_info) as f:
            connection_info = json.load(f)

    cf = cast(Dict[str, Any], connection_info)

    lports = tunnel.select_random_ports(5)
    # Remote ports, in the same order as the documented return tuple.
    rports = (
        cf["shell_port"],
        cf["iopub_port"],
        cf["stdin_port"],
        cf["hb_port"],
        cf["control_port"],
    )

    remote_ip = cf["ip"]

    if tunnel.try_passwordless_ssh(sshserver, sshkey):
        password: Union[bool, str] = False
    else:
        password = getpass("SSH Password for %s: " % sshserver)

    for lp, rp in zip(lports, rports):
        tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)

    return tuple(lports)
# -----------------------------------------------------------------------------
# Mixin for classes that work with connection files
# -----------------------------------------------------------------------------

# ZeroMQ socket type created for each kernel channel by
# ConnectionFileMixin._create_connected_socket.
channel_socket_types = dict(
    hb=zmq.REQ,
    shell=zmq.DEALER,
    iopub=zmq.SUB,
    stdin=zmq.DEALER,
    control=zmq.DEALER,
)
# Names of the per-channel port attributes, in a fixed channel order.
port_names = ["shell_port", "stdin_port", "iopub_port", "hb_port", "control_port"]
class ConnectionFileMixin(LoggingConfigurable):
    """Mixin for configurable classes that work with connection files"""

    data_dir: Union[str, Unicode] = Unicode()

    def _data_dir_default(self):
        """Default for ``data_dir``: the Jupyter data directory."""
        return jupyter_data_dir()

    # The addresses for the communication channels
    connection_file = Unicode(
        "",
        config=True,
        help="""JSON file in which to store connection info [default: kernel-<pid>.json]

    This file will contain the IP, ports, and authentication key needed to connect
    clients to this kernel. By default, this file will be created in the security dir
    of the current profile, but can be specified by absolute path.
    """,
    )
    # True once this object has written ``connection_file`` itself; only then
    # does cleanup_connection_file() remove it.
    _connection_file_written = Bool(False)

    transport = CaselessStrEnum(["tcp", "ipc"], default_value="tcp", config=True)
    kernel_name: Union[str, Unicode] = Unicode()

    context = Instance(zmq.Context)

    ip = Unicode(
        config=True,
        help="""Set the kernel\'s IP address [default localhost].
If the IP address is something other than localhost, then
Consoles on other machines will be able to connect
to the Kernel, so be careful!""",
    )

    def _ip_default(self):
        """Default for ``ip``.

        For the ``ipc`` transport this is a path prefix for the ipc files
        (derived from ``connection_file`` when one is set). For ``tcp`` it is
        the local host address.
        """
        if self.transport == "ipc":
            if self.connection_file:
                return os.path.splitext(self.connection_file)[0] + "-ipc"
            else:
                return "kernel-ipc"
        else:
            return localhost()

    @observe("ip")
    def _ip_changed(self, change):
        """Normalize the wildcard address ``*`` to ``0.0.0.0``."""
        if change["new"] == "*":
            self.ip = "0.0.0.0"  # noqa

    # protected traits

    hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
    shell_port = Integer(0, config=True, help="set the shell (ROUTER) port [default: random]")
    iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]")
    stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]")
    control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]")

    # names of the ports with random assignment
    _random_port_names: Optional[List[str]] = None

    @property
    def ports(self) -> List[int]:
        """The current port numbers, in the order given by ``port_names``."""
        return [getattr(self, name) for name in port_names]

    # The Session to use for communication with the kernel.
    session = Instance("jupyter_client.session.Session")

    def _session_default(self):
        """Default for ``session``: a new Session whose parent is this object."""
        from .session import Session

        return Session(parent=self)

    # --------------------------------------------------------------------------
    # Connection and ipc file management
    # --------------------------------------------------------------------------

    def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
        """Return the connection info as a dict

        Parameters
        ----------
        session : bool [default: False]
            If True, a clone of our session object is included in the connection
            info under ``"session"``.
            If False (default), the configuration parameters of our session object
            (``signature_scheme`` and ``key``) are included instead.

        Returns
        -------
        connect_info : dict
            dictionary of connection information.
        """
        info = {
            "transport": self.transport,
            "ip": self.ip,
            "shell_port": self.shell_port,
            "iopub_port": self.iopub_port,
            "stdin_port": self.stdin_port,
            "hb_port": self.hb_port,
            "control_port": self.control_port,
        }
        if session:
            # add *clone* of my session,
            # so that state such as digest_history is not shared.
            info["session"] = self.session.clone()
        else:
            # add session info
            info.update(
                {
                    "signature_scheme": self.session.signature_scheme,
                    "key": self.session.key,
                }
            )
        return info

    # factory for blocking clients
    blocking_class = Type(klass=object, default_value="jupyter_client.BlockingKernelClient")

    def blocking_client(self):
        """Make a blocking client connected to my kernel"""
        info = self.get_connection_info()
        bc = self.blocking_class(parent=self)
        bc.load_connection_info(info)
        return bc

    def cleanup_connection_file(self) -> None:
        """Cleanup connection file *if we wrote it*

        Will not raise if the connection file was already removed somehow.
        """
        if self._connection_file_written:
            # cleanup connection files on full shutdown of kernel we started
            self._connection_file_written = False
            try:
                os.remove(self.connection_file)
            except (OSError, AttributeError):
                pass

    def cleanup_ipc_files(self) -> None:
        """Cleanup ipc files if we wrote them."""
        if self.transport != "ipc":
            return
        for port in self.ports:
            # Same "<ip>-<port>" naming that _make_url uses for ipc endpoints.
            ipcfile = "%s-%i" % (self.ip, port)
            try:
                os.remove(ipcfile)
            except OSError:
                pass

    def _record_random_port_names(self) -> None:
        """Records which of the ports are randomly assigned.

        Records on first invocation, if the transport is tcp.
        Does nothing on later invocations."""

        if self.transport != "tcp":
            return
        if self._random_port_names is not None:
            return

        self._random_port_names = []
        for name in port_names:
            if getattr(self, name) <= 0:
                self._random_port_names.append(name)

    def cleanup_random_ports(self) -> None:
        """Forgets randomly assigned port numbers and cleans up the connection file.

        Does nothing if no port numbers have been randomly assigned.
        In particular, does nothing unless the transport is tcp.
        """

        if not self._random_port_names:
            return

        for name in self._random_port_names:
            setattr(self, name, 0)

        self.cleanup_connection_file()

    def write_connection_file(self) -> None:
        """Write connection info to JSON dict in self.connection_file."""
        if self._connection_file_written and os.path.exists(self.connection_file):
            return

        self.connection_file, cfg = write_connection_file(
            self.connection_file,
            transport=self.transport,
            ip=self.ip,
            key=self.session.key,
            stdin_port=self.stdin_port,
            iopub_port=self.iopub_port,
            shell_port=self.shell_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
            signature_scheme=self.session.signature_scheme,
            kernel_name=self.kernel_name,
        )
        # write_connection_file also sets default ports:
        # record which ports were unset *before* copying the chosen ones back.
        self._record_random_port_names()
        for name in port_names:
            setattr(self, name, cfg[name])

        self._connection_file_written = True

    def load_connection_file(self, connection_file: Optional[str] = None) -> None:
        """Load connection info from JSON dict in self.connection_file.

        Parameters
        ----------
        connection_file: unicode, optional
            Path to connection file to load.
            If unspecified, use self.connection_file
        """
        if connection_file is None:
            connection_file = self.connection_file
        self.log.debug("Loading connection file %s", connection_file)
        with open(connection_file) as f:
            info = json.load(f)
        self.load_connection_info(info)

    def load_connection_info(self, info: KernelConnectionInfo) -> None:
        """Load connection info from a dict containing connection info.

        Typically this data comes from a connection file
        and is called by load_connection_file.

        Parameters
        ----------
        info: dict
            Dictionary containing connection_info.
            See the connection_file spec for details.

        Raises
        ------
        TypeError
            If ``info["key"]`` is present but is neither ``str`` nor ``bytes``.
        """
        self.transport = info.get("transport", self.transport)
        self.ip = info.get("ip", self._ip_default())

        self._record_random_port_names()
        for name in port_names:
            if getattr(self, name) == 0 and name in info:
                # not overridden by config or cl_args
                setattr(self, name, info[name])

        if "key" in info:
            key = info["key"]
            if isinstance(key, str):
                key = key.encode()
            # An explicit check (unlike ``assert``) is not stripped under ``python -O``.
            if not isinstance(key, bytes):
                msg = f"Connection key must be str or bytes, not {type(key).__name__}"
                raise TypeError(msg)

            self.session.key = key
        if "signature_scheme" in info:
            self.session.signature_scheme = info["signature_scheme"]

    def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
        """Reconciles the connection information returned from the Provisioner.

        Because some provisioners (like derivations of LocalProvisioner) may have already
        written the connection file, this method needs to ensure that, if the connection
        file exists, its contents match that of what was returned by the provisioner. If
        the file does exist and its contents do not match, the file will be replaced with
        the provisioner information (which is considered the truth).

        If the file does not exist, the connection information in 'info' is loaded into the
        KernelManager and written to the file.

        Raises
        ------
        ValueError
            If, after reconciliation, this object's connection info still differs
            from `info`.
        """
        # Prevent over-writing a file that has already been written with the same
        # info. This is to prevent a race condition where the process has
        # already been launched but has not yet read the connection file - as is
        # the case with LocalProvisioners.
        file_exists: bool = False
        if os.path.exists(self.connection_file):
            with open(self.connection_file) as f:
                file_info = json.load(f)
            # Prior to the following comparison, we need to adjust the value of "key" to
            # be bytes, otherwise the comparison below will fail. A file without a
            # string key is simply treated as a mismatch rather than crashing.
            file_key = file_info.get("key")
            if isinstance(file_key, str):
                file_info["key"] = file_key.encode()
            if not self._equal_connections(info, file_info):
                os.remove(self.connection_file)  # Contents mismatch - remove the file
                self._connection_file_written = False
            else:
                file_exists = True

        if not file_exists:
            # Load the connection info and write out file, clearing existing
            # port-based attributes so they will be reloaded
            for name in port_names:
                setattr(self, name, 0)
            self.load_connection_info(info)
            self.write_connection_file()

        # Ensure what is in KernelManager is what we expect.
        km_info = self.get_connection_info()
        if not self._equal_connections(info, km_info):
            msg = (
                "KernelManager's connection information already exists and does not match "
                "the expected values returned from provisioner!"
            )
            raise ValueError(msg)

    @staticmethod
    def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo) -> bool:
        """Compares pertinent keys of connection info data. Returns True if equivalent, False otherwise."""

        pertinent_keys = [
            "key",
            "ip",
            "stdin_port",
            "iopub_port",
            "shell_port",
            "control_port",
            "hb_port",
            "transport",
            "signature_scheme",
        ]

        return all(conn1.get(key) == conn2.get(key) for key in pertinent_keys)

    # --------------------------------------------------------------------------
    # Creating connected sockets
    # --------------------------------------------------------------------------

    def _make_url(self, channel: str) -> str:
        """Make a ZeroMQ URL for a given channel."""
        transport = self.transport
        ip = self.ip
        port = getattr(self, "%s_port" % channel)

        if transport == "tcp":
            return "tcp://%s:%i" % (ip, port)
        else:
            return f"{transport}://{ip}-{port}"

    def _create_connected_socket(
        self, channel: str, identity: Optional[bytes] = None
    ) -> zmq.sugar.socket.Socket:
        """Create a zmq Socket and connect it to the kernel."""
        url = self._make_url(channel)
        socket_type = channel_socket_types[channel]
        self.log.debug("Connecting to: %s", url)
        sock = self.context.socket(socket_type)
        # set linger to 1s to prevent hangs at exit
        sock.linger = 1000
        if identity:
            sock.identity = identity
        sock.connect(url)
        return sock

    def connect_iopub(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the IOPub channel"""
        sock = self._create_connected_socket("iopub", identity=identity)
        # Subscribe to every topic (empty prefix).
        sock.setsockopt(zmq.SUBSCRIBE, b"")
        return sock

    def connect_shell(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Shell channel"""
        return self._create_connected_socket("shell", identity=identity)

    def connect_stdin(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the StdIn channel"""
        return self._create_connected_socket("stdin", identity=identity)

    def connect_hb(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Heartbeat channel"""
        return self._create_connected_socket("hb", identity=identity)

    def connect_control(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Control channel"""
        return self._create_connected_socket("control", identity=identity)
class LocalPortCache(SingletonConfigurable):
    """
    Used to keep track of local ports in order to prevent race conditions that
    can occur between port acquisition and usage by the kernel. All locally-
    provisioned kernels should use this mechanism to limit the possibility of
    race conditions. Note that this does not preclude other applications from
    acquiring a cached but unused port, thereby re-introducing the issue this
    class is attempting to resolve (minimize).
    See: https://github.com/jupyter/jupyter_client/issues/487
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Ports handed out by find_available_port() and not yet returned.
        self.currently_used_ports: Set[int] = set()

    def find_available_port(self, ip: str) -> int:
        """Return a free TCP port on `ip` that this cache has not already handed out.

        The port is recorded as in use until it is passed to return_port().
        Errors from creating or binding the probe socket (e.g. an `ip` that is
        not a local address) propagate to the caller.
        """
        while True:
            # The context manager closes the probe socket even if setsockopt or
            # bind raises, so a failed probe does not leak a file descriptor.
            with socket.socket() as tmp_sock:
                tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                tmp_sock.bind((ip, 0))
                port = tmp_sock.getsockname()[1]

            # This is a workaround for https://github.com/jupyter/jupyter_client/issues/487
            # We prevent two kernels to have the same ports.
            if port not in self.currently_used_ports:
                self.currently_used_ports.add(port)
                return port

    def return_port(self, port: int) -> None:
        """Release `port` so it may be handed out again; unknown ports are ignored."""
        self.currently_used_ports.discard(port)  # Tolerate uncached ports
# Names exported by ``from jupyter_client.connect import *``.
__all__ = [
    "write_connection_file",
    "find_connection_file",
    "tunnel_to_kernel",
    "KernelConnectionInfo",
    "LocalPortCache",
]