Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jupyter_client/connect.py: 24%
289 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-01 06:54 +0000
1"""Utilities for connecting to jupyter kernels
3The :class:`ConnectionFileMixin` class in this module encapsulates the logic
4related to writing and reading connections files.
5"""
6# Copyright (c) Jupyter Development Team.
7# Distributed under the terms of the Modified BSD License.
8import errno
9import glob
10import json
11import os
12import socket
13import stat
14import tempfile
15import warnings
16from getpass import getpass
17from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
19import zmq
20from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write
21from traitlets import Bool, CaselessStrEnum, Instance, Integer, Type, Unicode, observe
22from traitlets.config import LoggingConfigurable, SingletonConfigurable
24from .localinterfaces import localhost
25from .utils import _filefind
# Define custom type for kernel connection info: a dict with the channel ports
# (ints) plus "ip", "transport", "signature_scheme", "kernel_name" (strs) and
# "key" -- a str when read from / written to a JSON connection file, bytes when
# taken from a Session. Extra keys may also be present (e.g. kwargs passed to
# write_connection_file, or a Session object under "session" from
# ConnectionFileMixin.get_connection_info(session=True)), so treat this alias as
# a loose description rather than a strict schema.
KernelConnectionInfo = Dict[str, Union[int, str, bytes]]
def _select_free_ports(ip: str, transport: str, count: int) -> List[int]:
    """Return ``count`` port numbers that are currently unused for ``transport`` on ``ip``.

    For ``tcp`` the OS chooses each port: a throwaway socket is bound to port 0 on
    ``ip``. All probe sockets are held open until every port has been chosen, so the
    OS cannot return the same port twice. They are always closed before returning,
    even if one of the binds fails.

    For any other transport (e.g. ``ipc``), "ports" are small integers ``N`` starting
    at 1 such that no file named ``"{ip}-{N}"`` exists yet.
    """
    ports: List[int] = []
    if transport == "tcp":
        probes: List[socket.socket] = []
        try:
            for _ in range(count):
                probe = socket.socket()
                # Track the socket before configuring it so it is closed on failure too.
                probes.append(probe)
                # struct.pack('ii', (0,0)) is 8 null bytes (a zeroed struct linger)
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                probe.bind((ip, 0))
            ports = [probe.getsockname()[1] for probe in probes]
        finally:
            for probe in probes:
                probe.close()
    else:
        n = 1
        for _ in range(count):
            while os.path.exists(f"{ip}-{n}"):
                n += 1
            ports.append(n)
            n += 1
    return ports


def write_connection_file(
    fname: Optional[str] = None,
    shell_port: int = 0,
    iopub_port: int = 0,
    stdin_port: int = 0,
    hb_port: int = 0,
    control_port: int = 0,
    ip: str = "",
    key: Union[bytes, str] = b"",
    transport: str = "tcp",
    signature_scheme: str = "hmac-sha256",
    kernel_name: str = "",
    **kwargs: Any,
) -> Tuple[str, KernelConnectionInfo]:
    """Generates a JSON config file, including the selection of random ports.

    Any port given as ``<= 0`` is replaced by a freshly selected free port.

    Parameters
    ----------

    fname : str, optional
        The path to the file to write. If not given, a new temporary ``.json``
        file is created and used.

    shell_port : int, optional
        The port to use for ROUTER (shell) channel.

    iopub_port : int, optional
        The port to use for the IOPub channel.

    stdin_port : int, optional
        The port to use for the ROUTER (raw input) channel.

    control_port : int, optional
        The port to use for the ROUTER (control) channel.

    hb_port : int, optional
        The port to use for the heartbeat REP channel.

    ip : str, optional
        The ip address the kernel will bind to. Defaults to localhost.

    key : bytes or str, optional
        The Session key used for message authentication. Bytes are decoded
        (UTF-8) before being written; a str is written as-is.

    transport : str, optional
        ``"tcp"`` (default) or another zmq transport such as ``"ipc"``.

    signature_scheme : str, optional
        The scheme used for message authentication.
        This has the form 'digest-hash', where 'digest'
        is the scheme used for digests, and 'hash' is the name of the hash function
        used by the digest scheme.
        Currently, 'hmac' is the only supported digest scheme,
        and 'sha256' is the default hash function.

    kernel_name : str, optional
        The name of the kernel currently connected to.

    **kwargs
        Extra entries added to the written connection info (they may override
        the keys above).

    Returns
    -------
    (fname, cfg) : tuple
        The path that was written and the connection info dict serialized to it.
    """
    if not ip:
        ip = localhost()
    # default to temporary connector file
    if not fname:
        fd, fname = tempfile.mkstemp(".json")
        os.close(fd)

    # Find open ports as necessary, assigning them in the order
    # shell, iopub, stdin, control, hb.
    ports_needed = sum(
        port <= 0 for port in (shell_port, iopub_port, stdin_port, control_port, hb_port)
    )
    free_ports = iter(_select_free_ports(ip, transport, ports_needed))
    if shell_port <= 0:
        shell_port = next(free_ports)
    if iopub_port <= 0:
        iopub_port = next(free_ports)
    if stdin_port <= 0:
        stdin_port = next(free_ports)
    if control_port <= 0:
        control_port = next(free_ports)
    if hb_port <= 0:
        hb_port = next(free_ports)

    cfg: KernelConnectionInfo = {
        "shell_port": shell_port,
        "iopub_port": iopub_port,
        "stdin_port": stdin_port,
        "control_port": control_port,
        "hb_port": hb_port,
    }
    cfg["ip"] = ip
    cfg["key"] = key.decode() if isinstance(key, bytes) else key
    cfg["transport"] = transport
    cfg["signature_scheme"] = signature_scheme
    cfg["kernel_name"] = kernel_name
    cfg.update(kwargs)

    # Only ever write this file as user read/writeable
    # This would otherwise introduce a vulnerability as a file has secrets
    # which would let others execute arbitrary code as you
    with secure_write(fname) as f:
        f.write(json.dumps(cfg, indent=2))

    if hasattr(stat, "S_ISVTX"):
        # set the sticky bit on the parent directory of the file
        # to ensure only owner can remove it
        runtime_dir = os.path.dirname(fname)
        if runtime_dir:
            permissions = os.stat(runtime_dir).st_mode
            new_permissions = permissions | stat.S_ISVTX
            if new_permissions != permissions:
                try:
                    os.chmod(runtime_dir, new_permissions)
                except OSError:
                    # Best effort. EPERM is the expected failure when we do not own
                    # runtime_dir; any other OSError is ignored as well, so a
                    # connection file that was written successfully is still returned.
                    pass
    return fname, cfg
def find_connection_file(
    filename: str = "kernel-*.json",
    path: Optional[Union[str, List[str]]] = None,
    profile: Optional[str] = None,
) -> str:
    """Find a connection file, and return its absolute path.

    The current working directory and optional search path
    will be searched for the file if it is not given by absolute path.

    If the argument does not match an existing file, it is interpreted as a
    fileglob. A filename containing no ``*`` is matched as the substring glob
    ``*filename*``. Among the matches found in the search path, the file with
    the latest access time is used.

    Parameters
    ----------
    filename : str
        The connection file or fileglob to search for.
    path : str or list of strs [optional]
        Paths in which to search for connection files.
        Defaults to the current directory and the Jupyter runtime directory.
    profile : str [optional]
        Ignored (Jupyter has no profiles); passing it emits a warning.

    Returns
    -------
    str : The absolute path of the connection file.

    Raises
    ------
    OSError
        If no matching file can be found.
    """
    if profile is not None:
        warnings.warn(
            "Jupyter has no profiles. profile=%s has been ignored." % profile, stacklevel=2
        )
    if path is None:
        path = [".", jupyter_runtime_dir()]
    if isinstance(path, str):
        path = [path]

    try:
        # first, try explicit name
        return _filefind(filename, path)
    except OSError:
        pass

    # not found by full name: use it as a glob, or as a substring match
    pat = filename if "*" in filename else "*%s*" % filename

    # Absolute paths in search order, without duplicates. The same file can be
    # reached through several search dirs, e.g. when cwd is the runtime dir, or
    # with an absolute glob, which os.path.join returns unchanged for every dir.
    matches = list(
        dict.fromkeys(
            os.path.abspath(match)
            for directory in path
            for match in glob.glob(os.path.join(directory, pat))
        )
    )
    not_found = f"Could not find {filename!r} in {path!r}"
    if not matches:
        raise OSError(not_found)
    if len(matches) == 1:
        return matches[0]

    # Several candidates: pick the most recently accessed one. A file can vanish
    # between glob() and stat() (e.g. a kernel shutting down), so skip those
    # instead of crashing with FileNotFoundError.
    atimes: Dict[str, float] = {}
    for match in matches:
        try:
            atimes[match] = os.stat(match).st_atime
        except FileNotFoundError:
            continue
    if not atimes:
        raise OSError(not_found)
    # max() keeps the first maximum it sees. Scanning in reverse makes ties
    # resolve to the last match in search order, as a stable sort plus [-1] did.
    return max(reversed(list(atimes)), key=atimes.__getitem__)
def tunnel_to_kernel(
    connection_info: Union[str, KernelConnectionInfo],
    sshserver: str,
    sshkey: Optional[str] = None,
) -> Tuple[Any, ...]:
    """Forward a kernel's five channel ports to random local ports over SSH.

    One SSH tunnel is opened per channel, from a randomly chosen port on this
    machine to the kernel's corresponding port at the kernel's ``ip``. They can be
    direct localhost-to-localhost tunnels. If an intermediate server is necessary,
    the kernel must be listening on a public IP.

    Parameters
    ----------
    connection_info : dict or str (path)
        Either a connection dict, or the path to a JSON connection file
    sshserver : str
        The ssh server to use to tunnel to the kernel. Can be a full
        `user@server:port` string. ssh config aliases are respected.
    sshkey : str [optional]
        Path to file containing ssh key to use for authentication.
        Only necessary if your ssh config does not already associate
        a keyfile with the host.

    Returns
    -------

    (shell, iopub, stdin, hb, control) : ints
        The five ports on localhost that have been forwarded to the kernel.
    """
    from .ssh import tunnel

    if isinstance(connection_info, str):
        # A string is taken to be the path of a JSON connection file.
        with open(connection_info) as fh:
            connection_info = json.load(fh)

    info = cast(Dict[str, Any], connection_info)

    local_ports = tunnel.select_random_ports(5)
    # Remote ports in the same order as the documented return tuple.
    remote_ports = tuple(
        info[port_key]
        for port_key in ("shell_port", "iopub_port", "stdin_port", "hb_port", "control_port")
    )
    remote_ip = info["ip"]

    # False means "no password needed"; otherwise prompt the user once.
    password: Union[bool, str] = (
        False
        if tunnel.try_passwordless_ssh(sshserver, sshkey)
        else getpass("SSH Password for %s: " % sshserver)
    )

    for local_port, remote_port in zip(local_ports, remote_ports):
        tunnel.ssh_tunnel(local_port, remote_port, sshserver, remote_ip, sshkey, password)

    return tuple(local_ports)
# -----------------------------------------------------------------------------
# Mixin for classes that work with connection files
# -----------------------------------------------------------------------------

# zmq socket type opened on the client side for each kernel channel, looked up by
# channel name when ConnectionFileMixin creates a connected socket.
channel_socket_types = dict(
    hb=zmq.REQ,
    shell=zmq.DEALER,
    iopub=zmq.SUB,
    stdin=zmq.DEALER,
    control=zmq.DEALER,
)
# Names of the five port traits/config keys, in a fixed order (shell, stdin, iopub, hb, control).
port_names = [f"{channel}_port" for channel in ("shell", "stdin", "iopub", "hb", "control")]
class ConnectionFileMixin(LoggingConfigurable):
    """Mixin for configurable classes that work with connection files.

    Holds the transport, IP, the five channel ports and the Session used to talk
    to a kernel. It can write and load them as a JSON connection file, and it can
    create zmq sockets connected to each channel.
    """

    data_dir: Union[str, Unicode] = Unicode()

    def _data_dir_default(self):
        """Default ``data_dir`` to the Jupyter data directory."""
        return jupyter_data_dir()

    # The addresses for the communication channels
    connection_file = Unicode(
        "",
        config=True,
        help="""JSON file in which to store connection info [default: kernel-<pid>.json]

    This file will contain the IP, ports, and authentication key needed to connect
    clients to this kernel. By default, this file will be created in the security dir
    of the current profile, but can be specified by absolute path.
    """,
    )
    # True once this object wrote ``connection_file`` itself; only then does
    # cleanup_connection_file() delete it.
    _connection_file_written = Bool(False)

    transport = CaselessStrEnum(["tcp", "ipc"], default_value="tcp", config=True)
    kernel_name: Union[str, Unicode] = Unicode()

    context = Instance(zmq.Context)

    ip = Unicode(
        config=True,
        help="""Set the kernel\'s IP address [default localhost].
If the IP address is something other than localhost, then
Consoles on other machines will be able to connect
to the Kernel, so be careful!""",
    )

    def _ip_default(self):
        """Default ``ip``: a file-path prefix for ipc, otherwise localhost.

        With the ipc transport, ``ip`` is used as a filename prefix: URLs are
        ``ipc://{ip}-{port}`` and cleanup_ipc_files removes ``{ip}-{port}``.
        The prefix is derived from the connection file name when one is set.
        """
        if self.transport == "ipc":
            if self.connection_file:
                return os.path.splitext(self.connection_file)[0] + "-ipc"
            else:
                return "kernel-ipc"
        else:
            return localhost()

    @observe("ip")
    def _ip_changed(self, change):
        """Normalize the wildcard address ``"*"`` to ``"0.0.0.0"``."""
        if change["new"] == "*":
            self.ip = "0.0.0.0"  # noqa

    # protected traits

    hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
    shell_port = Integer(0, config=True, help="set the shell (ROUTER) port [default: random]")
    iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]")
    stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]")
    control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]")

    # names of the ports with random assignment
    _random_port_names: Optional[List[str]] = None

    @property
    def ports(self) -> List[int]:
        """Current values of the five port traits, in ``port_names`` order."""
        return [getattr(self, name) for name in port_names]

    # The Session to use for communication with the kernel.
    session = Instance("jupyter_client.session.Session")

    def _session_default(self):
        """Create a default Session whose parent is this object."""
        from .session import Session

        return Session(parent=self)

    # --------------------------------------------------------------------------
    # Connection and ipc file management
    # --------------------------------------------------------------------------

    def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
        """Return the connection info as a dict

        Parameters
        ----------
        session : bool [default: False]
            If True, a clone of our session object is included in the connection
            info under ``"session"``. If False (default), the session's
            ``signature_scheme`` and ``key`` are included instead of the session
            object itself.

        Returns
        -------
        connect_info : dict
            dictionary of connection information.
        """
        info = {
            "transport": self.transport,
            "ip": self.ip,
            "shell_port": self.shell_port,
            "iopub_port": self.iopub_port,
            "stdin_port": self.stdin_port,
            "hb_port": self.hb_port,
            "control_port": self.control_port,
        }
        if session:
            # add *clone* of my session,
            # so that state such as digest_history is not shared.
            info["session"] = self.session.clone()
        else:
            # add session info
            info.update(
                {
                    "signature_scheme": self.session.signature_scheme,
                    "key": self.session.key,
                }
            )
        return info

    # factory for blocking clients
    blocking_class = Type(klass=object, default_value="jupyter_client.BlockingKernelClient")

    def blocking_client(self):
        """Make a blocking client connected to my kernel"""
        info = self.get_connection_info()
        bc = self.blocking_class(parent=self)
        bc.load_connection_info(info)
        return bc

    def cleanup_connection_file(self) -> None:
        """Cleanup connection file *if we wrote it*

        Will not raise if the connection file was already removed somehow.
        """
        if self._connection_file_written:
            # cleanup connection files on full shutdown of kernel we started
            self._connection_file_written = False
            try:
                os.remove(self.connection_file)
            except (OSError, AttributeError):
                pass

    def cleanup_ipc_files(self) -> None:
        """Cleanup ipc files if we wrote them."""
        if self.transport != "ipc":
            return
        for port in self.ports:
            ipcfile = "%s-%i" % (self.ip, port)
            try:
                os.remove(ipcfile)
            except OSError:
                pass

    def _record_random_port_names(self) -> None:
        """Records which of the ports are randomly assigned.

        Records on first invocation, if the transport is tcp.
        Does nothing on later invocations."""

        if self.transport != "tcp":
            return
        if self._random_port_names is not None:
            return

        self._random_port_names = []
        for name in port_names:
            if getattr(self, name) <= 0:
                self._random_port_names.append(name)

    def cleanup_random_ports(self) -> None:
        """Forgets randomly assigned port numbers and cleans up the connection file.

        Does nothing if no port numbers have been randomly assigned.
        In particular, does nothing unless the transport is tcp.
        """

        if not self._random_port_names:
            return

        for name in self._random_port_names:
            setattr(self, name, 0)

        self.cleanup_connection_file()

    def write_connection_file(self, **kwargs: Any) -> None:
        """Write connection info to JSON dict in self.connection_file.

        Does nothing if this object already wrote the file and it still exists.
        Ports still at 0 are filled in with the ports chosen while writing, and
        extra keyword arguments are stored in the file as additional keys.
        """
        if self._connection_file_written and os.path.exists(self.connection_file):
            return

        self.connection_file, cfg = write_connection_file(
            self.connection_file,
            transport=self.transport,
            ip=self.ip,
            key=self.session.key,
            stdin_port=self.stdin_port,
            iopub_port=self.iopub_port,
            shell_port=self.shell_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
            signature_scheme=self.session.signature_scheme,
            kernel_name=self.kernel_name,
            **kwargs,
        )
        # write_connection_file also sets default ports:
        self._record_random_port_names()
        for name in port_names:
            setattr(self, name, cfg[name])

        self._connection_file_written = True

    def load_connection_file(self, connection_file: Optional[str] = None) -> None:
        """Load connection info from JSON dict in self.connection_file.

        Parameters
        ----------
        connection_file: unicode, optional
            Path to connection file to load.
            If unspecified, use self.connection_file
        """
        if connection_file is None:
            connection_file = self.connection_file
        self.log.debug("Loading connection file %s", connection_file)
        with open(connection_file) as f:
            info = json.load(f)
        self.load_connection_info(info)

    def load_connection_info(self, info: KernelConnectionInfo) -> None:
        """Load connection info from a dict containing connection info.

        Typically this data comes from a connection file
        and is called by load_connection_file.

        Ports already set to a non-zero value (by config or command line) are
        kept; only ports still at 0 are taken from ``info``.

        Parameters
        ----------
        info: dict
            Dictionary containing connection_info.
            See the connection_file spec for details.
        """
        self.transport = info.get("transport", self.transport)
        self.ip = info.get("ip", self._ip_default())

        self._record_random_port_names()
        for name in port_names:
            if getattr(self, name) == 0 and name in info:
                # not overridden by config or cl_args
                setattr(self, name, info[name])

        if "key" in info:
            key = info["key"]
            # keys read from JSON files are str; the Session expects bytes
            if isinstance(key, str):
                key = key.encode()
            assert isinstance(key, bytes)

            self.session.key = key
        if "signature_scheme" in info:
            self.session.signature_scheme = info["signature_scheme"]

    def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
        """Reconciles the connection information returned from the Provisioner.

        Because some provisioners (like derivations of LocalProvisioner) may have already
        written the connection file, this method needs to ensure that, if the connection
        file exists, its contents match that of what was returned by the provisioner. If
        the file does exist and its contents do not match, the file will be replaced with
        the provisioner information (which is considered the truth).

        If the file does not exist, the connection information in 'info' is loaded into the
        KernelManager and written to the file.

        Raises
        ------
        ValueError
            If, after reconciliation, this object's connection info still does not
            match ``info``.
        """
        # Prevent over-writing a file that has already been written with the same
        # info. This is to prevent a race condition where the process has
        # already been launched but has not yet read the connection file - as is
        # the case with LocalProvisioners.
        file_exists: bool = False
        if os.path.exists(self.connection_file):
            with open(self.connection_file) as f:
                file_info = json.load(f)
            # Prior to the following comparison, we need to adjust the value of "key" to
            # be bytes, otherwise the comparison below will fail. A file without a "key"
            # simply compares as a mismatch (and is replaced) rather than raising KeyError.
            file_key = file_info.get("key")
            if isinstance(file_key, str):
                file_info["key"] = file_key.encode()
            if not self._equal_connections(info, file_info):
                os.remove(self.connection_file)  # Contents mismatch - remove the file
                self._connection_file_written = False
            else:
                file_exists = True

        if not file_exists:
            # Load the connection info and write out file, clearing existing
            # port-based attributes so they will be reloaded
            for name in port_names:
                setattr(self, name, 0)
            self.load_connection_info(info)
            self.write_connection_file()

        # Ensure what is in KernelManager is what we expect.
        km_info = self.get_connection_info()
        if not self._equal_connections(info, km_info):
            msg = (
                "KernelManager's connection information already exists and does not match "
                "the expected values returned from provisioner!"
            )
            raise ValueError(msg)

    @staticmethod
    def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo) -> bool:
        """Compares pertinent keys of connection info data. Returns True if equivalent, False otherwise."""

        pertinent_keys = [
            "key",
            "ip",
            "stdin_port",
            "iopub_port",
            "shell_port",
            "control_port",
            "hb_port",
            "transport",
            "signature_scheme",
        ]

        return all(conn1.get(key) == conn2.get(key) for key in pertinent_keys)

    # --------------------------------------------------------------------------
    # Creating connected sockets
    # --------------------------------------------------------------------------

    def _make_url(self, channel: str) -> str:
        """Make a ZeroMQ URL for a given channel."""
        transport = self.transport
        ip = self.ip
        port = getattr(self, "%s_port" % channel)

        if transport == "tcp":
            return "tcp://%s:%i" % (ip, port)
        else:
            return f"{transport}://{ip}-{port}"

    def _create_connected_socket(
        self, channel: str, identity: Optional[bytes] = None
    ) -> zmq.sugar.socket.Socket:
        """Create a zmq Socket and connect it to the kernel.

        The socket is closed again if configuring or connecting it fails, so a
        failed attempt does not leak it.
        """
        url = self._make_url(channel)
        socket_type = channel_socket_types[channel]
        self.log.debug("Connecting to: %s", url)
        sock = self.context.socket(socket_type)
        try:
            # set linger to 1s to prevent hangs at exit
            sock.linger = 1000
            if identity:
                sock.identity = identity
            sock.connect(url)
        except Exception:
            sock.close()
            raise
        return sock

    def connect_iopub(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the IOPub channel"""
        sock = self._create_connected_socket("iopub", identity=identity)
        # subscribe to all topics
        sock.setsockopt(zmq.SUBSCRIBE, b"")
        return sock

    def connect_shell(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Shell channel"""
        return self._create_connected_socket("shell", identity=identity)

    def connect_stdin(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the StdIn channel"""
        return self._create_connected_socket("stdin", identity=identity)

    def connect_hb(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Heartbeat channel"""
        return self._create_connected_socket("hb", identity=identity)

    def connect_control(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Control channel"""
        return self._create_connected_socket("control", identity=identity)
class LocalPortCache(SingletonConfigurable):
    """
    Used to keep track of local ports in order to prevent race conditions that
    can occur between port acquisition and usage by the kernel. All locally-
    provisioned kernels should use this mechanism to limit the possibility of
    race conditions. Note that this does not preclude other applications from
    acquiring a cached but unused port, thereby re-introducing the issue this
    class is attempting to resolve (minimize).
    See: https://github.com/jupyter/jupyter_client/issues/487
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Ports handed out by find_available_port and not yet returned.
        self.currently_used_ports: Set[int] = set()

    def find_available_port(self, ip: str) -> int:
        """Return a port on ``ip`` that the OS reports free and this cache has not handed out.

        The port is recorded as in use until :meth:`return_port` is called with it.
        """
        while True:
            # Let the OS pick a free port. The ``with`` block closes the probe socket
            # even if setsockopt/bind raises, so a failed probe does not leak it.
            with socket.socket() as tmp_sock:
                tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                tmp_sock.bind((ip, 0))
                port = tmp_sock.getsockname()[1]

            # This is a workaround for https://github.com/jupyter/jupyter_client/issues/487
            # We prevent two kernels to have the same ports.
            if port not in self.currently_used_ports:
                self.currently_used_ports.add(port)
                return port

    def return_port(self, port: int) -> None:
        """Release ``port`` so it can be handed out again; unknown ports are ignored."""
        self.currently_used_ports.discard(port)
# Names exported by ``from jupyter_client.connect import *``.
__all__ = [
    "write_connection_file", "find_connection_file", "tunnel_to_kernel",
    "KernelConnectionInfo", "LocalPortCache",
]