1"""Implements an async kernel client"""
2
3# Copyright (c) Jupyter Development Team.
4# Distributed under the terms of the Modified BSD License.
5from __future__ import annotations
6
7import typing as t
8
9import zmq.asyncio
10from traitlets import Instance, Type
11
12from ..channels import AsyncZMQSocketChannel, HBChannel
13from ..client import KernelClient, reqrep
14
15
def wrapped(meth: t.Callable, channel: str) -> t.Callable:
    """Turn a request-sending method into one that can also wait for its reply.

    Parameters
    ----------
    meth : callable
        Method called as ``meth(self, *args, **kwargs)``; its return value is
        treated as the msg_id of the request it sent.
    channel : str
        Channel name handed to ``self._recv_reply`` when a reply is requested.

    Returns
    -------
    callable
        A function ``_(self, *args, **kwargs)`` that consumes two extra
        keyword arguments before delegating to ``meth``:

        ``reply`` (default ``False``)
            If truthy, return ``self._recv_reply(msg_id, timeout=timeout,
            channel=channel)`` instead of the msg_id.
        ``timeout`` (default ``None``)
            Forwarded to ``self._recv_reply``; unused when ``reply`` is falsy.
    """

    def _(self: AsyncKernelClient, *args: t.Any, **kwargs: t.Any) -> t.Any:
        # Remove the wrapper-only options first so ``meth`` never receives them.
        want_reply = kwargs.pop("reply", False)
        wait_timeout = kwargs.pop("timeout", None)

        request_id = meth(self, *args, **kwargs)

        if want_reply:
            # NOTE(review): on AsyncKernelClient ``_recv_reply`` is bound to
            # KernelClient._async_recv_reply, which presumably returns an
            # awaitable rather than the reply itself -- confirm.
            return self._recv_reply(request_id, timeout=wait_timeout, channel=channel)
        return request_id

    return _
28
29
class AsyncKernelClient(KernelClient):
    """Kernel client whose channels run on an asyncio-aware ZMQ context.

    ``get_[channel]_msg()`` methods wait for and return messages on channels,
    raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds.
    """

    # --------------------------------------------------------------------------
    # ZMQ context and channel classes
    # --------------------------------------------------------------------------

    context = Instance(zmq.asyncio.Context)  # type:ignore[assignment]

    def _context_default(self) -> zmq.asyncio.Context:
        # Named per the traitlets ``_<trait>_default`` convention, so this
        # supplies the default for ``context``. The flag is raised before the
        # context is built; it presumably tells the base class that this
        # client owns the context (NOTE(review): confirm in KernelClient).
        self._created_context = True
        fresh_context = zmq.asyncio.Context()
        return fresh_context

    # Socket-backed channels use the asyncio-aware channel class; the
    # heartbeat channel keeps HBChannel.
    shell_channel_class = Type(AsyncZMQSocketChannel)  # type:ignore[assignment]
    iopub_channel_class = Type(AsyncZMQSocketChannel)  # type:ignore[assignment]
    stdin_channel_class = Type(AsyncZMQSocketChannel)  # type:ignore[assignment]
    control_channel_class = Type(AsyncZMQSocketChannel)  # type:ignore[assignment]
    hb_channel_class = Type(HBChannel)  # type:ignore[assignment]

    # --------------------------------------------------------------------------
    # Channel proxy methods (the base class's ``_async_*`` implementations)
    # --------------------------------------------------------------------------

    get_shell_msg = KernelClient._async_get_shell_msg
    get_iopub_msg = KernelClient._async_get_iopub_msg
    get_stdin_msg = KernelClient._async_get_stdin_msg
    get_control_msg = KernelClient._async_get_control_msg

    wait_for_ready = KernelClient._async_wait_for_ready
    is_alive = KernelClient._async_is_alive
    execute_interactive = KernelClient._async_execute_interactive

    # --------------------------------------------------------------------------
    # Request/reply methods
    # --------------------------------------------------------------------------

    # Used by the ``wrapped`` helpers below when a caller passes ``reply=True``.
    _recv_reply = KernelClient._async_recv_reply

    # No ``channel`` argument: replies to these arrive on the shell channel.
    execute = reqrep(wrapped, KernelClient.execute)
    history = reqrep(wrapped, KernelClient.history)
    complete = reqrep(wrapped, KernelClient.complete)
    is_complete = reqrep(wrapped, KernelClient.is_complete)
    inspect = reqrep(wrapped, KernelClient.inspect)
    kernel_info = reqrep(wrapped, KernelClient.kernel_info)
    comm_info = reqrep(wrapped, KernelClient.comm_info)

    # Shutdown replies arrive on the control channel instead.
    shutdown = reqrep(wrapped, KernelClient.shutdown, channel="control")