1"""A kernel manager for multiple kernels"""
2
3# Copyright (c) Jupyter Development Team.
4# Distributed under the terms of the Modified BSD License.
5from __future__ import annotations
6
7import asyncio
8import json
9import os
10import socket
11import typing as t
12import uuid
13from functools import wraps
14from pathlib import Path
15
16import zmq
17from traitlets import Any, Bool, Dict, DottedObjectName, Instance, Unicode, default, observe
18from traitlets.config.configurable import LoggingConfigurable
19from traitlets.utils.importstring import import_item
20
21from .connect import KernelConnectionInfo
22from .kernelspec import NATIVE_KERNEL_NAME, KernelSpecManager
23from .manager import KernelManager
24from .utils import ensure_async, run_sync, utcnow
25
26
class DuplicateKernelError(Exception):
    """Raised when a kernel is started with an id that is already in use."""
29
30
def kernel_method(f: t.Callable) -> t.Callable:
    """Decorate a manager method so it proxies to the kernel manager with that id.

    The returned wrapper looks up the kernel manager for ``kernel_id`` via
    ``self.get_kernel``, calls the same-named method on it with the remaining
    arguments, then runs the decorated body (typically only logging) and
    finally returns the kernel manager's result.
    """

    @wraps(f)
    def proxy(
        self: t.Any, kernel_id: str, *args: t.Any, **kwargs: t.Any
    ) -> t.Callable | t.Awaitable:
        # Resolve the per-kernel manager and invoke its method first ...
        target = getattr(self.get_kernel(kernel_id), f.__name__)
        result = target(*args, **kwargs)
        # ... then run whatever the decorated method itself defines
        # (e.g. a log message); its return value is discarded.
        f(self, kernel_id, *args, **kwargs)
        return result

    return proxy
50
51
class MultiKernelManager(LoggingConfigurable):
    """A class for managing multiple kernels.

    Kernel managers are tracked by kernel id in ``_kernels``; start and
    shutdown operations that are still in flight are tracked in
    ``_pending_kernels``.  When ``external_connection_dir`` is set,
    connection files found in that directory are surfaced as kernels that
    this manager does not own (``owns_kernel=False``).
    """

    default_kernel_name = Unicode(
        NATIVE_KERNEL_NAME, help="The name of the default kernel to start"
    ).tag(config=True)

    kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)

    kernel_manager_class = DottedObjectName(
        "jupyter_client.ioloop.IOLoopKernelManager",
        help="""The kernel manager class. This is configurable to allow
        subclassing of the KernelManager for customized behavior.
        """,
    ).tag(config=True)

    @observe("kernel_manager_class")
    def _kernel_manager_class_changed(self, change: t.Any) -> None:
        """Rebuild the factory whenever ``kernel_manager_class`` changes."""
        self.kernel_manager_factory = self._create_kernel_manager_factory()

    kernel_manager_factory = Any(help="this is kernel_manager_class after import")

    @default("kernel_manager_factory")
    def _kernel_manager_factory_default(self) -> t.Callable:
        """Build the initial factory from ``kernel_manager_class``."""
        return self._create_kernel_manager_factory()

    def _create_kernel_manager_factory(self) -> t.Callable:
        """Import ``kernel_manager_class`` and wrap it in a factory function.

        The returned callable forwards its arguments to the imported
        constructor and, when ``shared_context`` is enabled, supplies this
        manager's zmq context as the default ``context`` keyword.
        """
        kernel_manager_ctor = import_item(self.kernel_manager_class)

        def create_kernel_manager(*args: t.Any, **kwargs: t.Any) -> KernelManager:
            if self.shared_context:
                if self.context.closed:
                    # recreate context if closed
                    self.context = self._context_default()
                # An explicitly passed ``context`` wins over the shared one.
                kwargs.setdefault("context", self.context)
            return kernel_manager_ctor(*args, **kwargs)

        return create_kernel_manager

    shared_context = Bool(
        True,
        help="Share a single zmq.Context to talk to all my kernels",
    ).tag(config=True)

    context = Instance("zmq.Context")

    # True once this manager has created its own context (see
    # ``_context_default``); ``__del__`` only destroys contexts it created.
    _created_context = Bool(False)

    # kernel_id -> task/future of an in-flight start or shutdown
    _pending_kernels = Dict()

    @property
    def _starting_kernels(self) -> dict:
        """A shim for backwards compatibility."""
        return self._pending_kernels

    @default("context")
    def _context_default(self) -> zmq.Context:
        """Create a new zmq context and record that this manager created it."""
        self._created_context = True
        return zmq.Context()

    connection_dir = Unicode("")
    external_connection_dir = Unicode(None, allow_none=True)

    # kernel_id -> kernel manager
    _kernels = Dict()

    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
        """Initialize the manager and the external connection-file registry."""
        super().__init__(*args, **kwargs)
        # Maps ids of externally discovered kernels to their connection file.
        self.kernel_id_to_connection_file: dict[str, Path] = {}

    def __del__(self) -> None:
        """Handle garbage collection.  Destroy context if applicable."""
        if self._created_context and self.context and not self.context.closed:
            if self.log:
                self.log.debug("Destroying zmq context for %s", self)
            self.context.destroy(linger=1000)
        # Chain to a parent ``__del__`` only if one exists.
        try:
            super_del = super().__del__  # type:ignore[misc]
        except AttributeError:
            pass
        else:
            super_del()

    def list_kernel_ids(self) -> list[str]:
        """Return a list of the kernel ids of the active kernels.

        When ``external_connection_dir`` is set and is a directory, the set
        of externally managed kernels is first synchronized with the
        connection files currently present there: kernels whose file has
        disappeared are dropped, and files that look like connection files
        (JSON objects with ``kernel_name`` and ``key``) are registered as new
        kernels not owned by this manager.
        """
        if self.external_connection_dir is not None:
            external_connection_dir = Path(self.external_connection_dir)
            if external_connection_dir.is_dir():
                connection_files = [p for p in external_connection_dir.iterdir() if p.is_file()]
                present_files = set(connection_files)

                # remove kernels (whose connection file has disappeared) from our list
                for kernel_id, connection_file in list(self.kernel_id_to_connection_file.items()):
                    if connection_file not in present_files:
                        del self.kernel_id_to_connection_file[kernel_id]
                        # The kernel may already have been dropped through
                        # remove_kernel(), which does not touch
                        # kernel_id_to_connection_file; do not fail on that.
                        self._kernels.pop(kernel_id, None)

                # add kernels (whose connection file appeared) to our list
                known_files = set(self.kernel_id_to_connection_file.values())
                for connection_file in connection_files:
                    if connection_file in known_files:
                        continue
                    try:
                        connection_info: KernelConnectionInfo = json.loads(
                            connection_file.read_text()
                        )
                    except Exception:  # noqa: S112
                        # Best effort: unreadable or non-JSON files are ignored.
                        continue
                    self.log.debug("Loading connection file %s", connection_file)
                    # Valid JSON is not necessarily an object (e.g. a bare
                    # number), so check the type before testing for keys.
                    if not (
                        isinstance(connection_info, dict)
                        and "kernel_name" in connection_info
                        and "key" in connection_info
                    ):
                        continue
                    # it looks like a connection file
                    kernel_id = self.new_kernel_id()
                    self.kernel_id_to_connection_file[kernel_id] = connection_file
                    km = self.kernel_manager_factory(
                        parent=self,
                        log=self.log,
                        owns_kernel=False,
                    )
                    km.load_connection_info(connection_info)
                    km.last_activity = utcnow()
                    km.execution_state = "idle"
                    km.connections = 1
                    km.kernel_id = kernel_id
                    km.kernel_name = connection_info["kernel_name"]
                    # Mark the manager as ready right away.
                    km.ready.set_result(None)

                    self._kernels[kernel_id] = km

        # Create a copy so we can iterate over kernels in operations
        # that delete keys.
        return list(self._kernels.keys())

    def __len__(self) -> int:
        """Return the number of running kernels."""
        return len(self.list_kernel_ids())

    def __contains__(self, kernel_id: str) -> bool:
        """Return whether a kernel with this id is currently registered."""
        return kernel_id in self._kernels

    def pre_start_kernel(
        self, kernel_name: str | None, kwargs: t.Any
    ) -> tuple[KernelManager, str, str]:
        """Create (but do not start) the kernel manager for a new kernel.

        Parameters
        ==========
        kernel_name : str or None
            Name of the kernel to start; ``default_kernel_name`` is used
            when this is None.
        kwargs : dict
            Keyword arguments of the start request.  This dict is mutated:
            a ``kernel_id`` entry, if present, is removed and used as the id.

        Returns
        =======
        (km, kernel_name, kernel_id) : tuple
            The new kernel manager, the resolved kernel name and the id.

        Raises
        ======
        DuplicateKernelError
            If a kernel with the chosen id is already registered.
        """
        # kwargs should be mutable, passing it as a dict argument.
        # NOTE: new_kernel_id() is evaluated even when the caller supplies a
        # kernel_id (the default argument is computed eagerly); this is kept
        # as-is because subclasses may override new_kernel_id().
        kernel_id = kwargs.pop("kernel_id", self.new_kernel_id(**kwargs))
        if kernel_id in self:
            raise DuplicateKernelError(f"Kernel already exists: {kernel_id}")

        if kernel_name is None:
            kernel_name = self.default_kernel_name
        # kernel_manager_factory is the constructor for the KernelManager
        # subclass we are using. It can be configured as any Configurable,
        # including things like its transport and ip.
        constructor_kwargs = {}
        if self.kernel_spec_manager:
            constructor_kwargs["kernel_spec_manager"] = self.kernel_spec_manager
        km = self.kernel_manager_factory(
            connection_file=os.path.join(self.connection_dir, f"kernel-{kernel_id}.json"),
            parent=self,
            log=self.log,
            kernel_name=kernel_name,
            **constructor_kwargs,
        )
        return km, kernel_name, kernel_id

    def update_env(self, *, kernel_id: str, env: t.Dict[str, str]) -> None:
        """
        Allow to update the environment of the given kernel.

        Forward the update env request to the corresponding kernel.
        Unknown kernel ids are silently ignored.

        .. version-added: 8.5
        """
        if kernel_id in self:
            self._kernels[kernel_id].update_env(env=env)

    async def _add_kernel_when_ready(
        self, kernel_id: str, km: KernelManager, kernel_awaitable: t.Awaitable
    ) -> None:
        """Await a kernel start, then register the kernel manager.

        On success the manager is stored in ``_kernels`` and the pending
        entry is cleared.  Exceptions are logged, not re-raised, and leave
        the pending entry in place.
        """
        try:
            await kernel_awaitable
            self._kernels[kernel_id] = km
            self._pending_kernels.pop(kernel_id, None)
        except Exception as e:
            self.log.exception(e)

    async def _remove_kernel_when_ready(
        self, kernel_id: str, kernel_awaitable: t.Awaitable
    ) -> None:
        """Await a kernel shutdown, then unregister the kernel.

        On success the kernel is removed from ``_kernels`` and the pending
        entry is cleared.  Exceptions are logged, not re-raised.
        """
        try:
            await kernel_awaitable
            self.remove_kernel(kernel_id)
            self._pending_kernels.pop(kernel_id, None)
        except Exception as e:
            self.log.exception(e)

    def _using_pending_kernels(self) -> bool:
        """Returns a boolean; a clearer method for determining if
        this multikernelmanager is using pending kernels or not
        """
        # Classes that do not define ``use_pending_kernels`` count as False.
        return getattr(self, "use_pending_kernels", False)

    async def _async_start_kernel(self, *, kernel_name: str | None = None, **kwargs: t.Any) -> str:
        """Start a new kernel.

        The caller can pick a kernel_id by passing one in as a keyword arg,
        otherwise one will be generated using new_kernel_id().

        The kernel ID for the newly started kernel is returned.
        """
        km, kernel_name, kernel_id = self.pre_start_kernel(kernel_name, kwargs)
        if not isinstance(km, KernelManager):
            # kernel_manager_class holds the dotted class name, which is what
            # the message is meant to show.
            self.log.warning(  # type:ignore[unreachable]
                "Kernel manager class (%s) is not an instance of 'KernelManager'!",
                self.kernel_manager_class,
            )
        kwargs["kernel_id"] = kernel_id  # Make kernel_id available to manager and provisioner

        starter = ensure_async(km.start_kernel(**kwargs))
        task = asyncio.create_task(self._add_kernel_when_ready(kernel_id, km, starter))
        self._pending_kernels[kernel_id] = task
        # Handling a Pending Kernel
        if self._using_pending_kernels():
            # If using pending kernels, do not block
            # on the kernel start.
            self._kernels[kernel_id] = km
        else:
            await task
            # raise an exception if one occurred during kernel startup.
            if km.ready.exception():
                raise km.ready.exception()  # type: ignore[misc]

        return kernel_id

    start_kernel = run_sync(_async_start_kernel)

    async def _async_shutdown_kernel(
        self,
        kernel_id: str,
        now: bool | None = False,
        restart: bool | None = False,
    ) -> None:
        """Shutdown a kernel by its kernel uuid.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel to shutdown.
        now : bool
            Should the kernel be shutdown forcibly using a signal.
        restart : bool
            Will the kernel be restarted?
        """
        self.log.info("Kernel shutdown: %s", kernel_id)
        # If the kernel is still starting, wait for it to be ready.
        if kernel_id in self._pending_kernels:
            task = self._pending_kernels[kernel_id]
            try:
                await task
                km = self.get_kernel(kernel_id)
                await t.cast(asyncio.Future, km.ready)
            except asyncio.CancelledError:
                pass
            except Exception:
                # Any other failure while waiting (including the kernel not
                # being registered): drop the kernel and stop here.
                self.remove_kernel(kernel_id)
                return
        km = self.get_kernel(kernel_id)
        # If a pending kernel raised an exception, remove it.
        if not km.ready.cancelled() and km.ready.exception():
            self.remove_kernel(kernel_id)
            return
        stopper = ensure_async(km.shutdown_kernel(now, restart))
        fut = asyncio.ensure_future(self._remove_kernel_when_ready(kernel_id, stopper))
        self._pending_kernels[kernel_id] = fut
        # Await the kernel if not using pending kernels.
        if not self._using_pending_kernels():
            await fut
            # raise an exception if one occurred during kernel shutdown.
            if km.ready.exception():
                raise km.ready.exception()  # type: ignore[misc]

    shutdown_kernel = run_sync(_async_shutdown_kernel)

    @kernel_method
    def request_shutdown(self, kernel_id: str, restart: bool | None = False) -> None:
        """Ask a kernel to shut down by its kernel uuid"""

    @kernel_method
    def finish_shutdown(
        self,
        kernel_id: str,
        waittime: float | None = None,
        pollinterval: float | None = 0.1,
    ) -> None:
        """Wait for a kernel to finish shutting down, and kill it if it doesn't"""
        self.log.info("Kernel shutdown: %s", kernel_id)

    @kernel_method
    def cleanup_resources(self, kernel_id: str, restart: bool = False) -> None:
        """Clean up a kernel's resources"""

    def remove_kernel(self, kernel_id: str) -> KernelManager | None:
        """remove a kernel from our mapping.

        Mainly so that a kernel can be removed if it is already dead,
        without having to call shutdown_kernel.

        The kernel object is returned, or `None` if not found.
        """
        return self._kernels.pop(kernel_id, None)

    async def _async_shutdown_all(self, now: bool = False) -> None:
        """Shutdown all kernels.

        Parameters
        ==========
        now : bool
            Passed through to the shutdown of every kernel.
        """
        kids = self.list_kernel_ids()
        kids += list(self._pending_kernels)
        # Snapshot the managers before shutdown starts removing them.
        kms = list(self._kernels.values())
        # set() de-duplicates ids that are both registered and pending.
        futs = [self._async_shutdown_kernel(kid, now=now) for kid in set(kids)]
        await asyncio.gather(*futs)
        # If using pending kernels, the kernels will not have been fully shut down.
        if self._using_pending_kernels():
            for km in kms:
                try:
                    await km.ready
                except asyncio.CancelledError:
                    # The pending entry may already have been cleared by
                    # _remove_kernel_when_ready; only cancel what is left.
                    pending = self._pending_kernels.get(km.kernel_id)
                    if pending is not None:
                        pending.cancel()
                except Exception:
                    # Will have been logged in _add_kernel_when_ready
                    pass

    shutdown_all = run_sync(_async_shutdown_all)

    def interrupt_kernel(self, kernel_id: str) -> None:
        """Interrupt (SIGINT) the kernel by its uuid.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel to interrupt.

        Raises
        ======
        RuntimeError
            If the kernel's ``ready`` future is not done yet.
        """
        kernel = self.get_kernel(kernel_id)
        if not kernel.ready.done():
            msg = "Kernel is in a pending state. Cannot interrupt."
            raise RuntimeError(msg)
        out = kernel.interrupt_kernel()
        self.log.info("Kernel interrupted: %s", kernel_id)
        return out

    @kernel_method
    def signal_kernel(self, kernel_id: str, signum: int) -> None:
        """Sends a signal to the kernel by its uuid.

        Note that since only SIGTERM is supported on Windows, this function
        is only useful on Unix systems.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel to signal.
        signum : int
            Signal number to send kernel.
        """
        self.log.info("Signaled Kernel %s with %s", kernel_id, signum)

    async def _async_restart_kernel(self, kernel_id: str, now: bool = False) -> None:
        """Restart a kernel by its uuid, keeping the same ports.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel to restart.
        now : bool, optional
            If True, the kernel is forcefully restarted *immediately*, without
            having a chance to do any cleanup action.  Otherwise the kernel is
            given 1s to clean up before a forceful restart is issued.

            In all cases the kernel is restarted, the only difference is whether
            it is given a chance to perform a clean shutdown or not.

        Raises
        ======
        RuntimeError
            If pending kernels are in use and the kernel is not ready yet.
        """
        kernel = self.get_kernel(kernel_id)
        if self._using_pending_kernels() and not kernel.ready.done():
            msg = "Kernel is in a pending state. Cannot restart."
            raise RuntimeError(msg)
        await ensure_async(kernel.restart_kernel(now=now))
        self.log.info("Kernel restarted: %s", kernel_id)

    restart_kernel = run_sync(_async_restart_kernel)

    @kernel_method
    def is_alive(self, kernel_id: str) -> bool:  # type:ignore[empty-body]
        """Is the kernel alive.

        This calls KernelManager.is_alive() which calls Popen.poll on the
        actual kernel subprocess.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel.
        """

    def _check_kernel_id(self, kernel_id: str) -> None:
        """check that a kernel id is valid

        Raises
        ======
        KeyError
            If no kernel with this id is registered.
        """
        if kernel_id not in self:
            raise KeyError(f"Kernel with id not found: {kernel_id}")

    def get_kernel(self, kernel_id: str) -> KernelManager:
        """Get the single KernelManager object for a kernel by its uuid.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel.

        Raises
        ======
        KeyError
            If no kernel with this id is registered.
        """
        self._check_kernel_id(kernel_id)
        return self._kernels[kernel_id]

    @kernel_method
    def add_restart_callback(
        self, kernel_id: str, callback: t.Callable, event: str = "restart"
    ) -> None:
        """add a callback for the KernelRestarter"""

    @kernel_method
    def remove_restart_callback(
        self, kernel_id: str, callback: t.Callable, event: str = "restart"
    ) -> None:
        """remove a callback for the KernelRestarter"""

    @kernel_method
    def get_connection_info(self, kernel_id: str) -> dict[str, t.Any]:  # type:ignore[empty-body]
        """Return a dictionary of connection data for a kernel.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel.

        Returns
        =======
        connection_dict : dict
            A dict of the information needed to connect to a kernel.
            This includes the ip address and the integer port
            numbers of the different channels (stdin_port, iopub_port,
            shell_port, hb_port).
        """

    @kernel_method
    def connect_iopub(  # type:ignore[empty-body]
        self, kernel_id: str, identity: bytes | None = None
    ) -> socket.socket:
        """Return a zmq Socket connected to the iopub channel.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel
        identity : bytes (optional)
            The zmq identity of the socket

        Returns
        =======
        stream : zmq Socket or ZMQStream
        """

    @kernel_method
    def connect_shell(  # type:ignore[empty-body]
        self, kernel_id: str, identity: bytes | None = None
    ) -> socket.socket:
        """Return a zmq Socket connected to the shell channel.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel
        identity : bytes (optional)
            The zmq identity of the socket

        Returns
        =======
        stream : zmq Socket or ZMQStream
        """

    @kernel_method
    def connect_control(  # type:ignore[empty-body]
        self, kernel_id: str, identity: bytes | None = None
    ) -> socket.socket:
        """Return a zmq Socket connected to the control channel.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel
        identity : bytes (optional)
            The zmq identity of the socket

        Returns
        =======
        stream : zmq Socket or ZMQStream
        """

    @kernel_method
    def connect_stdin(  # type:ignore[empty-body]
        self, kernel_id: str, identity: bytes | None = None
    ) -> socket.socket:
        """Return a zmq Socket connected to the stdin channel.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel
        identity : bytes (optional)
            The zmq identity of the socket

        Returns
        =======
        stream : zmq Socket or ZMQStream
        """

    @kernel_method
    def connect_hb(  # type:ignore[empty-body]
        self, kernel_id: str, identity: bytes | None = None
    ) -> socket.socket:
        """Return a zmq Socket connected to the hb channel.

        Parameters
        ==========
        kernel_id : uuid
            The id of the kernel
        identity : bytes (optional)
            The zmq identity of the socket

        Returns
        =======
        stream : zmq Socket or ZMQStream
        """

    def new_kernel_id(self, **kwargs: t.Any) -> str:
        """
        Returns the id to associate with the kernel for this request. Subclasses may override
        this method to substitute other sources of kernel ids.
        :param kwargs: ignored by this implementation
        :return: string-ized version 4 uuid
        """
        return str(uuid.uuid4())
596
597
class AsyncMultiKernelManager(MultiKernelManager):
    """A MultiKernelManager whose kernel lifecycle methods are coroutines.

    ``start_kernel``, ``restart_kernel``, ``shutdown_kernel`` and
    ``shutdown_all`` are bound directly to the base class' ``_async_*``
    implementations instead of their ``run_sync`` wrappers.
    """

    kernel_manager_class = DottedObjectName(
        "jupyter_client.ioloop.AsyncIOLoopKernelManager",
        help="""The kernel manager class. This is configurable to allow
        subclassing of the AsyncKernelManager for customized behavior.
        """,
    ).tag(config=True)

    use_pending_kernels = Bool(
        False,
        help="""Whether to make kernels available before the process has started. The
        kernel has a `.ready` future which can be awaited before connecting""",
    ).tag(config=True)

    context = Instance("zmq.asyncio.Context")

    @default("context")
    def _context_default(self) -> zmq.asyncio.Context:
        """Create a new asyncio zmq context and record that this manager created it."""
        # ``import zmq`` alone does not guarantee the ``zmq.asyncio``
        # submodule is loaded; import it explicitly instead of relying on
        # another module having done so.
        import zmq.asyncio

        self._created_context = True
        return zmq.asyncio.Context()

    start_kernel: t.Callable[..., t.Awaitable] = MultiKernelManager._async_start_kernel  # type:ignore[assignment]
    restart_kernel: t.Callable[..., t.Awaitable] = MultiKernelManager._async_restart_kernel  # type:ignore[assignment]
    shutdown_kernel: t.Callable[..., t.Awaitable] = MultiKernelManager._async_shutdown_kernel  # type:ignore[assignment]
    shutdown_all: t.Callable[..., t.Awaitable] = MultiKernelManager._async_shutdown_all  # type:ignore[assignment]