Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/jupyter_client/connect.py: 24%
288 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-10 06:20 +0000
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-10 06:20 +0000
1"""Utilities for connecting to jupyter kernels
3The :class:`ConnectionFileMixin` class in this module encapsulates the logic
4related to writing and reading connections files.
5"""
6# Copyright (c) Jupyter Development Team.
7# Distributed under the terms of the Modified BSD License.
8import errno
9import glob
10import json
11import os
12import socket
13import stat
14import tempfile
15import warnings
16from getpass import getpass
17from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
19import zmq
20from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write
21from traitlets import Bool, CaselessStrEnum, Instance, Integer, Type, Unicode, observe
22from traitlets.config import LoggingConfigurable, SingletonConfigurable
24from .localinterfaces import localhost
25from .utils import _filefind
# Type alias for kernel connection info: a dict with str keys whose values
# are ints, strs, or bytes.
KernelConnectionInfo = Dict[str, Union[int, str, bytes]]
def write_connection_file(
    fname: Optional[str] = None,
    shell_port: int = 0,
    iopub_port: int = 0,
    stdin_port: int = 0,
    hb_port: int = 0,
    control_port: int = 0,
    ip: str = "",
    key: bytes = b"",
    transport: str = "tcp",
    signature_scheme: str = "hmac-sha256",
    kernel_name: str = "",
) -> Tuple[str, KernelConnectionInfo]:
    """Generates a JSON config file, including the selection of random ports.

    Parameters
    ----------

    fname : unicode
        The path to the file to write. If empty or None, a temporary
        ``.json`` file is created and used instead.

    shell_port : int, optional
        The port to use for ROUTER (shell) channel.

    iopub_port : int, optional
        The port to use for the SUB channel.

    stdin_port : int, optional
        The port to use for the ROUTER (raw input) channel.

    control_port : int, optional
        The port to use for the ROUTER (control) channel.

    hb_port : int, optional
        The port to use for the heartbeat REP channel.

    ip : str, optional
        The ip address the kernel will bind to. If empty, ``localhost()``
        is used.

    key : bytes, optional
        The Session key used for message authentication. It is decoded to
        ``str`` before being written to the file.

    transport : str, optional
        The transport to use. For ``"tcp"``, each port given as ``<= 0`` is
        replaced by a free port found by binding a temporary socket on `ip`.
        For any other value, each such port is replaced by the lowest
        positive integer ``N`` for which the path ``f"{ip}-{N}"`` does not
        exist yet.

    signature_scheme : str, optional
        The scheme used for message authentication.
        This has the form 'digest-hash', where 'digest'
        is the scheme used for digests, and 'hash' is the name of the hash function
        used by the digest scheme.
        Currently, 'hmac' is the only supported digest scheme,
        and 'sha256' is the default hash function.

    kernel_name : str, optional
        The name of the kernel currently connected to.

    Returns
    -------
    (fname, cfg) : tuple
        The path of the file that was written and the connection info dict
        that was serialized into it.
    """
    if not ip:
        ip = localhost()
    # default to temporary connector file
    if not fname:
        fd, fname = tempfile.mkstemp(".json")
        os.close(fd)

    # Find open ports as necessary.

    ports: List[int] = []
    ports_needed = sum(
        int(port <= 0) for port in (shell_port, iopub_port, stdin_port, control_port, hb_port)
    )
    if transport == "tcp":
        sockets: List[socket.socket] = []
        try:
            # Bind every socket before closing any of them, so that all the
            # ports handed out are distinct.
            for _ in range(ports_needed):
                sock = socket.socket()
                # Track the socket before configuring it so that it is
                # closed even if setsockopt() or bind() raises.
                sockets.append(sock)
                # struct.pack('ii', (0,0)) is 8 null bytes
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                sock.bind((ip, 0))
            ports = [sock.getsockname()[1] for sock in sockets]
        finally:
            for sock in sockets:
                sock.close()
    else:
        N = 1
        for _ in range(ports_needed):
            while os.path.exists(f"{ip}-{N}"):
                N += 1
            ports.append(N)
            N += 1

    # Hand out the discovered ports in a fixed order to the unset channels.
    if shell_port <= 0:
        shell_port = ports.pop(0)
    if iopub_port <= 0:
        iopub_port = ports.pop(0)
    if stdin_port <= 0:
        stdin_port = ports.pop(0)
    if control_port <= 0:
        control_port = ports.pop(0)
    if hb_port <= 0:
        hb_port = ports.pop(0)

    cfg: KernelConnectionInfo = {
        "shell_port": shell_port,
        "iopub_port": iopub_port,
        "stdin_port": stdin_port,
        "control_port": control_port,
        "hb_port": hb_port,
    }
    cfg["ip"] = ip
    # JSON cannot hold bytes, so the key is stored as text.
    cfg["key"] = key.decode()
    cfg["transport"] = transport
    cfg["signature_scheme"] = signature_scheme
    cfg["kernel_name"] = kernel_name

    # Only ever write this file as user read/writeable
    # This would otherwise introduce a vulnerability as a file has secrets
    # which would let others execute arbitrary code as you
    with secure_write(fname) as f:
        f.write(json.dumps(cfg, indent=2))

    if hasattr(stat, "S_ISVTX"):
        # set the sticky bit on the parent directory of the file
        # to ensure only owner can remove it
        runtime_dir = os.path.dirname(fname)
        if runtime_dir:
            permissions = os.stat(runtime_dir).st_mode
            new_permissions = permissions | stat.S_ISVTX
            if new_permissions != permissions:
                try:
                    os.chmod(runtime_dir, new_permissions)
                except OSError as e:
                    # Best effort: no OSError from chmod is re-raised here.
                    if e.errno == errno.EPERM:
                        # suppress permission errors setting sticky bit on runtime_dir,
                        # which we may not own.
                        pass
    return fname, cfg
def find_connection_file(
    filename: str = "kernel-*.json",
    path: Optional[Union[str, List[str]]] = None,
    profile: Optional[str] = None,
) -> str:
    """Locate a connection file and return its absolute path.

    The file is first looked up by its exact name via ``_filefind`` in the
    search path. If that fails, `filename` is treated as a fileglob (wrapped
    in ``*`` on both sides when it contains no ``*`` itself) and matched in
    every directory of the search path; when several files match, the one
    with the latest access time is returned.

    Parameters
    ----------
    filename : str
        The connection file or fileglob to search for.
    path : str or list of strs[optional]
        Paths in which to search for connection files. Defaults to the
        current working directory and the Jupyter runtime directory.
    profile : str [optional]
        Ignored; a warning is emitted when it is not None.

    Returns
    -------
    str : The absolute path of the connection file.

    Raises
    ------
    OSError
        If nothing matches `filename` in the search path.
    """
    if profile is not None:
        warnings.warn("Jupyter has no profiles. profile=%s has been ignored." % profile)

    if path is None:
        path = [".", jupyter_runtime_dir()]
    if isinstance(path, str):
        path = [path]

    # First attempt: the name exactly as given.
    try:
        return _filefind(filename, path)
    except OSError:
        pass

    # Fall back to globbing; a name without a wildcard matches as a substring.
    pat = filename if "*" in filename else "*%s*" % filename

    matches = [
        os.path.abspath(hit)
        for directory in path
        for hit in glob.glob(os.path.join(directory, pat))
    ]

    if not matches:
        msg = f"Could not find {filename!r} in {path!r}"
        raise OSError(msg)
    if len(matches) == 1:
        return matches[0]

    # Several candidates: the most recently accessed one wins.
    matches.sort(key=lambda f: os.stat(f).st_atime)
    return matches[-1]
def tunnel_to_kernel(
    connection_info: Union[str, KernelConnectionInfo],
    sshserver: str,
    sshkey: Optional[str] = None,
) -> Tuple[Any, ...]:
    """Open ssh tunnels to all five channels of a kernel.

    Five local ports are selected and each one is forwarded over ssh to the
    matching kernel port. The tunnels can be either direct
    localhost-localhost tunnels, or if an intermediate server is necessary,
    the kernel must be listening on a public IP.

    Parameters
    ----------
    connection_info : dict or str (path)
        Either a connection dict, or the path to a JSON connection file
    sshserver : str
        The ssh server to use to tunnel to the kernel. Can be a full
        `user@server:port` string. ssh config aliases are respected.
    sshkey : str [optional]
        Path to file containing ssh key to use for authentication.
        Only necessary if your ssh config does not already associate
        a keyfile with the host.

    Returns
    -------

    (shell, iopub, stdin, hb, control) : ints
        The five ports on localhost that have been forwarded to the kernel.
    """
    from .ssh import tunnel

    if isinstance(connection_info, str):
        # A path was given: read the connection dict from that file.
        with open(connection_info) as f:
            connection_info = json.load(f)

    cf = cast(Dict[str, Any], connection_info)

    lports = tunnel.select_random_ports(5)
    rports = tuple(
        cf[f"{channel}_port"] for channel in ("shell", "iopub", "stdin", "hb", "control")
    )

    remote_ip = cf["ip"]

    # Only prompt for a password when passwordless login is not possible.
    password: Union[bool, str] = False
    if not tunnel.try_passwordless_ssh(sshserver, sshkey):
        password = getpass("SSH Password for %s: " % sshserver)

    for local_port, remote_port in zip(lports, rports):
        tunnel.ssh_tunnel(local_port, remote_port, sshserver, remote_ip, sshkey, password)

    return tuple(lports)
291# -----------------------------------------------------------------------------
292# Mixin for classes that work with connection files
293# -----------------------------------------------------------------------------
# Maps each channel name to the zmq socket type used when creating a
# socket for that channel.
channel_socket_types = {
    "hb": zmq.REQ,
    "shell": zmq.DEALER,
    "iopub": zmq.SUB,
    "stdin": zmq.DEALER,
    "control": zmq.DEALER,
}
# Names of the per-channel port attributes, e.g. "shell_port".
port_names = [f"{channel}_port" for channel in ("shell", "stdin", "iopub", "hb", "control")]
class ConnectionFileMixin(LoggingConfigurable):
    """Mixin for configurable classes that read and write connection files."""

    data_dir: Union[str, Unicode] = Unicode()

    def _data_dir_default(self):
        return jupyter_data_dir()

    # The addresses for the communication channels
    connection_file = Unicode(
        "",
        config=True,
        help="""JSON file in which to store connection info [default: kernel-<pid>.json]

    This file will contain the IP, ports, and authentication key needed to connect
    clients to this kernel. By default, this file will be created in the security dir
    of the current profile, but can be specified by absolute path.
    """,
    )
    _connection_file_written = Bool(False)

    transport = CaselessStrEnum(["tcp", "ipc"], default_value="tcp", config=True)
    kernel_name: Union[str, Unicode] = Unicode()

    context = Instance(zmq.Context)

    ip = Unicode(
        config=True,
        help="""Set the kernel\'s IP address [default localhost].
        If the IP address is something other than localhost, then
        Consoles on other machines will be able to connect
        to the Kernel, so be careful!""",
    )

    def _ip_default(self):
        if self.transport != "ipc":
            return localhost()
        # ipc: derive the address from the connection file name when there is one
        if self.connection_file:
            return os.path.splitext(self.connection_file)[0] + "-ipc"
        return "kernel-ipc"

    @observe("ip")
    def _ip_changed(self, change):
        # "*" is accepted as shorthand for all interfaces
        if change["new"] == "*":
            self.ip = "0.0.0.0"  # noqa

    # protected traits

    hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
    shell_port = Integer(0, config=True, help="set the shell (ROUTER) port [default: random]")
    iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]")
    stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]")
    control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]")

    # names of the ports with random assignment
    _random_port_names: Optional[List[str]] = None

    @property
    def ports(self) -> List[int]:
        """Current port values, in the order given by ``port_names``."""
        return [getattr(self, port_name) for port_name in port_names]

    # The Session to use for communication with the kernel.
    session = Instance("jupyter_client.session.Session")

    def _session_default(self):
        from .session import Session

        return Session(parent=self)

    # --------------------------------------------------------------------------
    # Connection and ipc file management
    # --------------------------------------------------------------------------

    def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
        """Build a dict describing how to connect to the kernel.

        Parameters
        ----------
        session : bool [default: False]
            If True, a clone of our session object is included in the
            connection info under ``"session"``.
            If False (default), the session's signature scheme and key are
            included instead of the session object itself.

        Returns
        -------
        connect_info : dict
            dictionary of connection information.
        """
        info = dict(
            transport=self.transport,
            ip=self.ip,
            shell_port=self.shell_port,
            iopub_port=self.iopub_port,
            stdin_port=self.stdin_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
        )
        if session:
            # hand out a *clone* of the session,
            # so that state such as digest_history is not shared.
            info["session"] = self.session.clone()
        else:
            # only the session's authentication parameters
            info["signature_scheme"] = self.session.signature_scheme
            info["key"] = self.session.key
        return info

    # factory for blocking clients
    blocking_class = Type(klass=object, default_value="jupyter_client.BlockingKernelClient")

    def blocking_client(self):
        """Create a blocking client loaded with this object's connection info."""
        connection = self.get_connection_info()
        client = self.blocking_class(parent=self)
        client.load_connection_info(connection)
        return client

    def cleanup_connection_file(self) -> None:
        """Remove the connection file, but only *if we wrote it*.

        Does not raise if the connection file was already removed somehow.
        """
        if not self._connection_file_written:
            return
        # cleanup connection files on full shutdown of kernel we started
        self._connection_file_written = False
        try:
            os.remove(self.connection_file)
        except (OSError, AttributeError):
            pass

    def cleanup_ipc_files(self) -> None:
        """Remove the per-port ipc files; a no-op unless the transport is ipc."""
        if self.transport != "ipc":
            return
        for port in self.ports:
            ipcfile = "%s-%i" % (self.ip, port)
            try:
                os.remove(ipcfile)
            except OSError:
                pass

    def _record_random_port_names(self) -> None:
        """Remember which ports are left to random assignment.

        Only the first invocation with a tcp transport records anything;
        ports whose current value is ``<= 0`` are the ones recorded.
        Later invocations do nothing.
        """
        if self.transport != "tcp" or self._random_port_names is not None:
            return

        self._random_port_names = [
            port_name for port_name in port_names if getattr(self, port_name) <= 0
        ]

    def cleanup_random_ports(self) -> None:
        """Reset randomly assigned ports to 0 and clean up the connection file.

        Does nothing if no port names have been recorded as randomly
        assigned, which is always the case unless the transport is tcp.
        """
        if not self._random_port_names:
            return

        for port_name in self._random_port_names:
            setattr(self, port_name, 0)

        self.cleanup_connection_file()

    def write_connection_file(self) -> None:
        """Write connection info as JSON to self.connection_file.

        Skipped when the file was already written by this object and still
        exists. Afterwards the port traits hold the values that were written.
        """
        if self._connection_file_written and os.path.exists(self.connection_file):
            return

        written_path, cfg = write_connection_file(
            self.connection_file,
            transport=self.transport,
            ip=self.ip,
            key=self.session.key,
            stdin_port=self.stdin_port,
            iopub_port=self.iopub_port,
            shell_port=self.shell_port,
            hb_port=self.hb_port,
            control_port=self.control_port,
            signature_scheme=self.session.signature_scheme,
            kernel_name=self.kernel_name,
        )
        self.connection_file = written_path
        # write_connection_file also sets default ports:
        self._record_random_port_names()
        for port_name in port_names:
            setattr(self, port_name, cfg[port_name])

        self._connection_file_written = True

    def load_connection_file(self, connection_file: Optional[str] = None) -> None:
        """Read a JSON connection file and load its contents.

        Parameters
        ----------
        connection_file: unicode, optional
            Path to connection file to load.
            If unspecified, use self.connection_file
        """
        target = self.connection_file if connection_file is None else connection_file
        self.log.debug("Loading connection file %s", target)
        with open(target) as f:
            info = json.load(f)
        self.load_connection_info(info)

    def load_connection_info(self, info: KernelConnectionInfo) -> None:
        """Apply the settings found in a connection info dict.

        Typically this data comes from a connection file
        and is called by load_connection_file.

        Parameters
        ----------
        info: dict
            Dictionary containing connection_info.
            See the connection_file spec for details.
        """
        self.transport = info.get("transport", self.transport)
        self.ip = info.get("ip", self._ip_default())

        self._record_random_port_names()
        for port_name in port_names:
            if getattr(self, port_name) == 0 and port_name in info:
                # still unset, i.e. not overridden by config or cl_args
                setattr(self, port_name, info[port_name])

        if "key" in info:
            session_key = info["key"]
            if isinstance(session_key, str):
                session_key = session_key.encode()
            assert isinstance(session_key, bytes)

            self.session.key = session_key
        if "signature_scheme" in info:
            self.session.signature_scheme = info["signature_scheme"]

    def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
        """Reconcile the connection information returned from the Provisioner.

        Some provisioners (like derivations of LocalProvisioner) may already
        have written the connection file. If the file exists and its contents
        match `info`, it is left alone. If it exists but differs, it is
        removed and replaced with the provisioner information (which is
        considered the truth).

        If the file does not exist, the connection information in 'info' is
        loaded into the KernelManager and written to the file.

        Raises ValueError if, after all that, this object's own connection
        info still differs from `info`.
        """
        # Do not overwrite a file that already holds the same info. This
        # avoids a race where the process has already been launched but has
        # not yet read the connection file - as is the case with
        # LocalProvisioners.
        file_matches: bool = False
        if os.path.exists(self.connection_file):
            with open(self.connection_file) as f:
                on_disk = json.load(f)
            # The file stores "key" as text; convert it to bytes first,
            # otherwise the comparison below will fail.
            on_disk["key"] = on_disk["key"].encode()
            if self._equal_connections(info, on_disk):
                file_matches = True
            else:
                os.remove(self.connection_file)  # Contents mismatch - remove the file
                self._connection_file_written = False

        if not file_matches:
            # Zero the port attributes so load_connection_info reloads them,
            # then load the info and write the file out.
            for port_name in port_names:
                setattr(self, port_name, 0)
            self.load_connection_info(info)
            self.write_connection_file()

        # Ensure what is in KernelManager is what we expect.
        current = self.get_connection_info()
        if not self._equal_connections(info, current):
            msg = (
                "KernelManager's connection information already exists and does not match "
                "the expected values returned from provisioner!"
            )
            raise ValueError(msg)

    @staticmethod
    def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo) -> bool:
        """Return True if both dicts agree on every pertinent connection key."""

        pertinent_keys = (
            "key",
            "ip",
            "stdin_port",
            "iopub_port",
            "shell_port",
            "control_port",
            "hb_port",
            "transport",
            "signature_scheme",
        )

        return all(conn1.get(name) == conn2.get(name) for name in pertinent_keys)

    # --------------------------------------------------------------------------
    # Creating connected sockets
    # --------------------------------------------------------------------------

    def _make_url(self, channel: str) -> str:
        """Build the ZeroMQ URL of the given channel."""
        transport = self.transport
        ip = self.ip
        port = getattr(self, "%s_port" % channel)

        if transport != "tcp":
            return f"{transport}://{ip}-{port}"
        return "tcp://%s:%i" % (ip, port)

    def _create_connected_socket(
        self, channel: str, identity: Optional[bytes] = None
    ) -> zmq.sugar.socket.Socket:
        """Create a zmq Socket for `channel` and connect it to the kernel."""
        url = self._make_url(channel)
        kind = channel_socket_types[channel]
        self.log.debug("Connecting to: %s", url)
        sock = self.context.socket(kind)
        # set linger to 1s to prevent hangs at exit
        sock.linger = 1000
        if identity:
            sock.identity = identity
        sock.connect(url)
        return sock

    def connect_iopub(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the IOPub channel, subscribed with an empty prefix"""
        iopub = self._create_connected_socket("iopub", identity=identity)
        iopub.setsockopt(zmq.SUBSCRIBE, b"")
        return iopub

    def connect_shell(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Shell channel"""
        return self._create_connected_socket("shell", identity=identity)

    def connect_stdin(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the StdIn channel"""
        return self._create_connected_socket("stdin", identity=identity)

    def connect_hb(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Heartbeat channel"""
        return self._create_connected_socket("hb", identity=identity)

    def connect_control(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket:
        """return zmq Socket connected to the Control channel"""
        return self._create_connected_socket("control", identity=identity)
class LocalPortCache(SingletonConfigurable):
    """
    Used to keep track of local ports in order to prevent race conditions that
    can occur between port acquisition and usage by the kernel.  All locally-
    provisioned kernels should use this mechanism to limit the possibility of
    race conditions.  Note that this does not preclude other applications from
    acquiring a cached but unused port, thereby re-introducing the issue this
    class is attempting to resolve (minimize).
    See: https://github.com/jupyter/jupyter_client/issues/487
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Ports handed out by find_available_port and not yet returned.
        self.currently_used_ports: Set[int] = set()

    def find_available_port(self, ip: str) -> int:
        """Return a port on `ip` that is not already in the cache.

        A temporary socket is bound to port 0 on `ip` to obtain a port
        number. If that port is already cached, another one is tried;
        otherwise it is added to the cache and returned.
        """
        while True:
            # The context manager closes the socket even if setsockopt(),
            # bind() or getsockname() raises, so no descriptor is leaked.
            with socket.socket() as tmp_sock:
                # struct.pack('ii', (0,0)) is 8 null bytes
                tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
                tmp_sock.bind((ip, 0))
                port = tmp_sock.getsockname()[1]

            # This is a workaround for https://github.com/jupyter/jupyter_client/issues/487
            # We prevent two kernels to have the same ports.
            if port not in self.currently_used_ports:
                self.currently_used_ports.add(port)
                return port

    def return_port(self, port: int) -> None:
        """Drop `port` from the cache; ports that are not cached are ignored."""
        self.currently_used_ports.discard(port)
# Names exported by a star-import of this module.
__all__ = [
    "write_connection_file",
    "find_connection_file",
    "tunnel_to_kernel",
    "KernelConnectionInfo",
    "LocalPortCache",
]