1"""Implements an async kernel client"""
2# Copyright (c) Jupyter Development Team.
3# Distributed under the terms of the Modified BSD License.
from __future__ import annotations

import functools
import typing as t

import zmq.asyncio
from traitlets import Instance, Type

from ..channels import AsyncZMQSocketChannel, HBChannel
from ..client import KernelClient, reqrep
13
14
def wrapped(meth: t.Callable, channel: str) -> t.Callable:
    """Wrap a request method so it can optionally collect its reply.

    The returned function sends the request by calling ``meth`` and then,
    depending on the ``reply`` keyword, either returns the message id
    straight away or hands it to ``self._recv_reply`` to fetch the reply.

    Parameters
    ----------
    meth : callable
        Request method taking ``self`` first and returning the ``msg_id``
        of the request it sent.
    channel : str
        Channel name forwarded to ``_recv_reply`` when ``reply=True``.

    Returns
    -------
    callable
        A function carrying ``meth``'s name/qualname/module (via
        :func:`functools.wraps`) that accepts two extra keyword-only
        options, ``reply`` (default ``False``) and ``timeout`` (default
        ``None``), and forwards all other arguments to ``meth``.
    """

    @functools.wraps(meth)
    def _(self: AsyncKernelClient, *args: t.Any, **kwargs: t.Any) -> t.Any:
        # Remove the wrapper-only options so they are not forwarded to meth.
        reply = kwargs.pop("reply", False)
        timeout = kwargs.pop("timeout", None)
        msg_id = meth(self, *args, **kwargs)
        if not reply:
            # Returned as-is, without awaiting: callers get the id directly.
            return msg_id
        # NOTE(review): on AsyncKernelClient ``_recv_reply`` is bound to
        # ``KernelClient._async_recv_reply``, so this presumably returns an
        # awaitable that the caller must ``await`` — confirm in client.py.
        return self._recv_reply(msg_id, timeout=timeout, channel=channel)

    return _
27
28
class AsyncKernelClient(KernelClient):
    """Kernel client exposing asyncio-based channel and request APIs.

    The ``get_shell_msg``, ``get_iopub_msg``, ``get_stdin_msg`` and
    ``get_control_msg`` methods wait for the next message on their channel
    and raise :exc:`queue.Empty` if none arrives within ``timeout`` seconds.

    The request methods (``execute``, ``history``, ``complete``,
    ``is_complete``, ``inspect``, ``kernel_info``, ``comm_info`` and
    ``shutdown``) are produced by ``reqrep``/``wrapped``: by default they
    return the request's ``msg_id`` directly; with ``reply=True`` they return
    the result of ``_recv_reply`` (presumably an awaitable, since it is bound
    to ``KernelClient._async_recv_reply`` — confirm before relying on it).
    """

    # ---- ZMQ context -------------------------------------------------------

    # This client works with an asyncio-flavoured ZMQ context.
    context = Instance(zmq.asyncio.Context)  # type:ignore[arg-type]

    def _context_default(self) -> zmq.asyncio.Context:
        """Default for ``context``: mark it as client-created, then build it."""
        self._created_context = True
        return zmq.asyncio.Context()

    # ---- Channel classes ---------------------------------------------------

    # Heartbeat uses its own channel type; all other channels are async
    # ZMQ socket channels.
    hb_channel_class = Type(HBChannel)  # type:ignore[arg-type]
    shell_channel_class = Type(AsyncZMQSocketChannel)  # type:ignore[arg-type]
    iopub_channel_class = Type(AsyncZMQSocketChannel)  # type:ignore[arg-type]
    stdin_channel_class = Type(AsyncZMQSocketChannel)  # type:ignore[arg-type]
    control_channel_class = Type(AsyncZMQSocketChannel)  # type:ignore[arg-type]

    # ---- Async implementations reused from KernelClient --------------------

    # Per-channel message getters.
    get_shell_msg = KernelClient._async_get_shell_msg
    get_iopub_msg = KernelClient._async_get_iopub_msg
    get_stdin_msg = KernelClient._async_get_stdin_msg
    get_control_msg = KernelClient._async_get_control_msg

    # Kernel liveness / readiness and interactive execution.
    wait_for_ready = KernelClient._async_wait_for_ready
    is_alive = KernelClient._async_is_alive
    execute_interactive = KernelClient._async_execute_interactive

    # Called by the ``wrapped`` request helpers when ``reply=True``.
    _recv_reply = KernelClient._async_recv_reply

    # ---- Request/reply methods ---------------------------------------------

    # Requests whose replies come back on the shell channel.
    execute = reqrep(wrapped, KernelClient.execute)
    history = reqrep(wrapped, KernelClient.history)
    complete = reqrep(wrapped, KernelClient.complete)
    is_complete = reqrep(wrapped, KernelClient.is_complete)
    inspect = reqrep(wrapped, KernelClient.inspect)
    kernel_info = reqrep(wrapped, KernelClient.kernel_info)
    comm_info = reqrep(wrapped, KernelClient.comm_info)

    # The shutdown reply comes back on the control channel instead.
    shutdown = reqrep(wrapped, KernelClient.shutdown, channel="control")