1"""Kernel Provisioner Classes"""
2
3# Copyright (c) Jupyter Development Team.
4# Distributed under the terms of the Modified BSD License.
5import asyncio
6import os
7import pathlib
8import signal
9import sys
10from typing import TYPE_CHECKING, Any
11
12from ..connect import KernelConnectionInfo, LocalPortCache
13from ..launcher import launch_kernel
14from ..localinterfaces import is_local_ip, local_ips
15from .provisioner_base import KernelProvisionerBase
16
17
class LocalProvisioner(KernelProvisionerBase):
    """
    :class:`LocalProvisioner` is a concrete class of ABC :py:class:`KernelProvisionerBase`
    and is the out-of-box default implementation used when no kernel provisioner is
    specified in the kernel specification (``kernel.json``). It provides functional
    parity to existing applications by launching the kernel locally and using
    :class:`subprocess.Popen` to manage its lifecycle.

    This class is intended to be subclassed for customizing local kernel environments
    and serve as a reference implementation for other custom provisioners.
    """

    # Process handle returned by ``launch_kernel``; None until launch and reset to
    # None once ``wait()`` has reaped the process.
    process = None
    # Not referenced within this class. NOTE(review): presumably retained for
    # subclasses / backward compatibility -- confirm before removing.
    _exit_future = None
    # Kernel process id and process-group id, captured in ``launch_kernel`` or
    # restored by ``load_provisioner_info``.
    pid = None
    pgid = None
    # Persisted via get/load_provisioner_info; not assigned at launch by this class.
    ip = None
    # True while this instance holds ports reserved from ``LocalPortCache``.
    ports_cached = False
    # Working directory of the launched kernel; used by ``resolve_path``.
    cwd = None

    @property
    def has_process(self) -> bool:
        """Return True while this provisioner holds a kernel process handle."""
        return self.process is not None

    async def poll(self) -> int | None:
        """Poll the kernel process without blocking.

        Returns ``None`` while the process is still running and its exit code
        once it has exited. Returns ``0`` when no process is held.
        """
        ret = 0
        if self.process:
            ret = self.process.poll()  # type:ignore[unreachable]
        return ret

    async def wait(self) -> int | None:
        """Wait for the kernel process to exit, then release it.

        Once the process has exited it is reaped, its stdio pipes are closed and
        the handle is dropped so ``has_process`` returns False. This coroutine
        does not time out on its own; callers are responsible for bounding it.

        Returns the process exit code, or ``0`` when no process is held.
        """
        ret = 0
        if self.process:
            # Poll at 100ms intervals, yielding to the event loop in between, so
            # a still-running kernel never blocks the loop.
            while await self.poll() is None:  # type:ignore[unreachable]
                await asyncio.sleep(0.1)

            # The process has exited, so the blocking wait() returns promptly and
            # completes its cleanup.
            ret = self.process.wait()
            # Make sure all the fds get closed.
            for attr in ("stdout", "stderr", "stdin"):
                fid = getattr(self.process, attr)
                if fid:
                    fid.close()
            self.process = None  # allow has_process to now return False
        return ret

    async def send_signal(self, signum: int) -> None:
        """Send ``signum`` to the kernel's process group, falling back to the process.

        Signalling the process group usually reaches the kernel and any
        subprocesses it spawned. Since only SIGTERM is supported on Windows, a
        SIGINT request there is delivered through ``win_interrupt`` instead.

        Exceptions raised by the final per-process ``send_signal`` propagate to
        the caller (``kill()`` and ``terminate()`` catch ``OSError``).
        """
        if self.process:
            if signum == signal.SIGINT and sys.platform == "win32":  # type:ignore[unreachable]
                from ..win_interrupt import send_interrupt

                # NOTE(review): assumes the launcher attached
                # ``win32_interrupt_event`` to the process on Windows -- confirm.
                send_interrupt(self.process.win32_interrupt_event)
                return

            # Prefer process-group over process
            if self.pgid and hasattr(os, "killpg"):
                try:
                    os.killpg(self.pgid, signum)
                    return
                except OSError:
                    pass  # We'll retry sending the signal to only the process below

            # If we're here, send the signal to the process and let caller handle exceptions
            self.process.send_signal(signum)
            return

    async def kill(self, restart: bool = False) -> None:
        """Forcibly kill the kernel process.

        Where the platform defines SIGKILL, it is sent via ``send_signal`` so the
        whole process group is targeted. If that raises ``OSError`` (or SIGKILL
        is unavailable), ``Popen.kill()`` is used instead. An ``OSError`` from
        that fallback is passed to ``_tolerate_no_process``.

        ``restart`` is accepted for interface compatibility and is not used here.
        """
        if self.process:
            if hasattr(signal, "SIGKILL"):  # type:ignore[unreachable]
                # If available, give preference to signalling the process-group over `kill()`.
                try:
                    await self.send_signal(signal.SIGKILL)
                    return
                except OSError:
                    pass
            try:
                self.process.kill()
            except OSError as e:
                LocalProvisioner._tolerate_no_process(e)

    async def terminate(self, restart: bool = False) -> None:
        """Ask the kernel process to terminate.

        Where the platform defines SIGTERM, it is sent via ``send_signal`` so the
        whole process group is targeted. If that raises ``OSError`` (or SIGTERM
        is unavailable), ``Popen.terminate()`` is used instead. An ``OSError``
        from that fallback is passed to ``_tolerate_no_process``.

        ``restart`` is accepted for interface compatibility and is not used here.
        """
        if self.process:
            if hasattr(signal, "SIGTERM"):  # type:ignore[unreachable]
                # If available, give preference to signalling the process group over `terminate()`.
                try:
                    await self.send_signal(signal.SIGTERM)
                    return
                except OSError:
                    pass
            try:
                self.process.terminate()
            except OSError as e:
                LocalProvisioner._tolerate_no_process(e)

    @staticmethod
    def _tolerate_no_process(os_error: OSError) -> None:
        """Ignore ``os_error`` if it means the process is already gone.

        Raises ValueError for any other error.
        """
        # In Windows, we will get an Access Denied error if the process
        # has already terminated. Ignore it.
        if sys.platform == "win32":
            if os_error.winerror != 5:
                err_message = f"Invalid Error, expecting error number to be 5, got {os_error}"
                raise ValueError(err_message)

        # On Unix, we may get an ESRCH error (or ProcessLookupError instance) if
        # the process has already terminated. Ignore it.
        else:
            from errno import ESRCH

            if not isinstance(os_error, ProcessLookupError) or os_error.errno != ESRCH:
                err_message = (
                    f"Invalid Error, expecting ProcessLookupError or ESRCH, got {os_error}"
                )
                raise ValueError(err_message)

    async def cleanup(self, restart: bool = False) -> None:
        """Release resources held by this provisioner.

        When ports were reserved from ``LocalPortCache`` and this is not a
        restart, they are returned to the cache. On restart they are kept:
        ``ports_cached`` stays True, so ``pre_launch`` does not reserve new ones.
        """
        if self.ports_cached and not restart:
            # provisioner is about to be destroyed, return cached ports
            lpc = LocalPortCache.instance()
            ports = (
                self.connection_info["shell_port"],
                self.connection_info["iopub_port"],
                self.connection_info["stdin_port"],
                self.connection_info["hb_port"],
                self.connection_info["control_port"],
            )
            for port in ports:
                if TYPE_CHECKING:
                    assert isinstance(port, int)
                lpc.return_port(port)
            # These ports no longer belong to this instance. Clearing the flag
            # stops a repeated cleanup() from releasing ports another kernel may
            # have reserved since. It also makes a later pre_launch() reserve
            # fresh ports.
            self.ports_cached = False

    async def pre_launch(self, **kwargs: Any) -> dict[str, Any]:
        """Perform any steps in preparation for kernel process launch.

        This includes applying additional substitutions to the kernel launch command and env.
        It also includes preparation of launch parameters. With a parent kernel
        manager, this validates its IP and reserves cached ports if requested.
        It then writes the connection file, captures ``connection_info`` and
        formats the command.

        Returns the updated kwargs.

        Raises RuntimeError if the parent uses the ``tcp`` transport with an IP
        that is not a local interface.
        """
        # Consumed here in both branches; it must not be forwarded to Popen.
        extra_arguments = kwargs.pop("extra_arguments", [])

        # This should be considered temporary until a better division of labor can be defined.
        km = self.parent
        if km:
            if km.transport == "tcp" and not is_local_ip(km.ip):
                msg = (
                    "Can only launch a kernel on a local interface. "
                    f"This one is not: {km.ip}. "
                    "Make sure that the '*_address' attributes are "
                    "configured properly. "
                    f"Currently valid addresses are: {local_ips()}"
                )
                raise RuntimeError(msg)

            # write connection file / get default ports
            # TODO - change when handshake pattern is adopted
            if km.cache_ports and not self.ports_cached:
                lpc = LocalPortCache.instance()
                km.shell_port = lpc.find_available_port(km.ip)
                km.iopub_port = lpc.find_available_port(km.ip)
                km.stdin_port = lpc.find_available_port(km.ip)
                km.hb_port = lpc.find_available_port(km.ip)
                km.control_port = lpc.find_available_port(km.ip)
                self.ports_cached = True

            # An explicit env=None is treated like a missing env instead of
            # failing on None.get().
            env = kwargs.get("env")
            if env is not None:
                jupyter_session = env.get("JPY_SESSION_NAME", "")
                km.write_connection_file(jupyter_session=jupyter_session)
            else:
                km.write_connection_file()
            self.connection_info = km.get_connection_info()

            # build the Popen cmd
            kernel_cmd = km.format_kernel_cmd(
                extra_arguments=extra_arguments
            )  # This needs to remain here for b/c
        else:
            kernel_cmd = self.kernel_spec.argv + extra_arguments

        return await super().pre_launch(cmd=kernel_cmd, **kwargs)

    async def launch_kernel(self, cmd: list[str], **kwargs: Any) -> KernelConnectionInfo:
        """Launch a kernel with a command.

        Keyword arguments Popen does not tolerate are removed first. This records
        the process id, the process-group id (where ``os.getpgid`` exists) and
        the working directory. Returns the ``connection_info`` prepared in
        ``pre_launch``.
        """
        scrubbed_kwargs = LocalProvisioner._scrub_kwargs(kwargs)
        self.process = launch_kernel(cmd, **scrubbed_kwargs)
        pgid = None
        if hasattr(os, "getpgid"):
            try:
                pgid = os.getpgid(self.process.pid)
            except OSError:
                # e.g. the process already exited; send_signal() then targets
                # the process directly since pgid stays None.
                pass

        self.pid = self.process.pid
        self.pgid = pgid
        cwd = kwargs.get("cwd")
        # Without a cwd the kernel presumably starts in our current directory
        # (Popen semantics for cwd=None). Record that directory so resolve_path
        # anchors relative paths, and only call Path.cwd() when it is needed.
        self.cwd = cwd if cwd is not None else pathlib.Path.cwd()
        return self.connection_info

    def resolve_path(self, path_str: str) -> str | None:
        """Resolve ``path_str`` to an existing file path in POSIX form.

        ``~`` is expanded. A relative path is resolved against the kernel's
        working directory when one is known. Returns None if the path does not
        exist.
        """
        path = pathlib.Path(path_str).expanduser()
        if not path.is_absolute() and self.cwd:
            path = (pathlib.Path(self.cwd) / path).resolve()
        if path.exists():
            return path.as_posix()
        return None

    @staticmethod
    def _scrub_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of ``kwargs`` without keyword arguments that Popen does not tolerate."""
        keywords_to_scrub: list[str] = ["extra_arguments", "kernel_id"]
        scrubbed_kwargs = kwargs.copy()
        for kw in keywords_to_scrub:
            scrubbed_kwargs.pop(kw, None)
        return scrubbed_kwargs

    async def get_provisioner_info(self) -> dict[str, Any]:
        """Captures the base information necessary for persistence relative to this instance.

        Extends the base-class info with ``pid``, ``pgid`` and ``ip``.
        """
        provisioner_info = await super().get_provisioner_info()
        provisioner_info.update({"pid": self.pid, "pgid": self.pgid, "ip": self.ip})
        return provisioner_info

    async def load_provisioner_info(self, provisioner_info: dict[str, Any]) -> None:
        """Loads the base information necessary for persistence relative to this instance.

        Expects the ``pid``, ``pgid`` and ``ip`` keys written by
        ``get_provisioner_info``. A missing key raises KeyError.
        """
        await super().load_provisioner_info(provisioner_info)
        self.pid = provisioner_info["pid"]
        self.pgid = provisioner_info["pgid"]
        self.ip = provisioner_info["ip"]