1# Copyright (c) Jupyter Development Team.
2# Distributed under the terms of the Modified BSD License.
3from __future__ import annotations
4
import asyncio
import atexit
import errno
import functools
import inspect
import sys
import threading
import warnings
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from pathlib import Path
from types import FrameType
from typing import Any, TypeVar, cast
17
18
def ensure_dir_exists(path: str | Path, mode: int = 0o777) -> None:
    """Create the directory ``path`` (and any missing parents) if it is absent.

    Creation races with other processes are tolerated: if the directory
    appears between our check and our ``mkdir``, that is not an error.
    The effective permissions are ``mode`` filtered by the current umask.

    Raises
    ------
    OSError
        If creation fails for a reason other than "already exists", or if
        ``path`` exists but is not a directory.
    """
    target = Path(path)
    try:
        target.mkdir(parents=True, mode=mode)
    except OSError as err:
        # EEXIST is benign (someone else got there first); anything else is real.
        if err.errno != errno.EEXIST:
            raise
    if target.is_dir():
        return
    raise OSError(f"{path!r} exists but is not a directory")
34
35
def _get_frame(level: int) -> FrameType | None:
    """Return the stack frame ``level`` steps above the caller of this helper.

    ``level=0`` is the caller's own frame. ``sys._getframe`` is preferred
    because it is far cheaper than ``inspect.stack``, but it is an
    implementation detail that not every Python provides, so fall back to
    ``inspect.stack`` when it is missing.
    """
    # One extra step so the count starts at our caller, not at this helper.
    depth = level + 1
    getframe = getattr(sys, "_getframe", None)
    if getframe is None:
        return inspect.stack(context=0)[depth].frame
    return getframe(depth)
47
48
# This function is from https://github.com/python/cpython/issues/67998
# (https://bugs.python.org/file39550/deprecated_module_stacklevel.diff) and
# calculates the appropriate stacklevel for deprecations so that the warning
# targets the caller, no matter how many internal stack frames we have added
# in the process. For example, when a deprecation warning is issued from an
# ``__init__`` method, the appropriate stacklevel changes depending on how deep
# the inheritance hierarchy is.
def _external_stacklevel(internal: list[str]) -> int:
    """Find the stacklevel of the first frame that doesn't contain any of the given internal strings

    The depth will be 1 at minimum in order to start checking at the caller of
    the function that called this utility method.

    Each entry of ``internal`` is matched as a substring of a frame's
    filename after both are normalized to the platform's path separator.
    A trailing separator on an entry (e.g. ``"jupyter_core/"``) is kept, so
    the entry matches only that directory name and not a longer name that
    merely starts with it (e.g. ``"jupyter_core_extras/"``).
    """
    # Get the level of my caller's caller
    level = 2
    frame = _get_frame(level)

    # Normalize the path separators. ``Path`` silently drops a trailing
    # separator, which would turn a directory marker into a bare prefix, so
    # restore it when the caller supplied one.
    sep = str(Path("/"))  # "/" on POSIX, "\\" on Windows
    normalized_internal = []
    for entry in internal:
        normalized = str(Path(entry))
        if entry.endswith(("/", sep)) and not normalized.endswith(sep):
            normalized += sep
        normalized_internal.append(normalized)

    # climb the stack frames while we see internal frames
    while frame:
        # Normalize each frame's filename once, not once per internal entry.
        filename = str(Path(frame.f_code.co_filename))
        if not any(s in filename for s in normalized_internal):
            break
        level += 1
        frame = frame.f_back

    # Return the stack level from the perspective of whoever called us (i.e., one level up)
    return level - 1
76
77
def deprecation(message: str, internal: str | list[str] = "jupyter_core/") -> None:
    """Emit a ``DeprecationWarning`` pointing at the first non-internal caller.

    Parameters
    ----------
    message : str
        Text of the warning.
    internal : str or list of str
        Substring(s) which, when found in a stack frame's filename, mark that
        frame as internal. Internal frames are skipped when deciding where the
        warning points. Overriding this is useful when, for example, our
        internal code is known to call out to another library.
    """
    internal_list: list[str] = internal if not isinstance(internal, str) else [internal]

    # Called directly from this frame: the helper reports a level relative to
    # its own caller, i.e. relative to this function.
    external_level = _external_stacklevel(internal_list)

    # warnings.warn treats stacklevel=1 as this function itself, so shift by
    # one to land on the external frame.
    warnings.warn(message, DeprecationWarning, stacklevel=external_level + 1)
93
94
# Generic type variable for the value an awaitable produces; used in the
# signatures of ``run_sync`` and ``ensure_async``.
T = TypeVar("T")
96
97
class _TaskRunner:
    """A task runner that runs an asyncio event loop on a background thread.

    The loop and its daemon thread are created lazily on the first call to
    :meth:`run` and reused afterwards. An ``atexit`` hook asks the loop to
    stop when the interpreter shuts down.
    """

    def __init__(self) -> None:
        self.__io_loop: asyncio.AbstractEventLoop | None = None
        self.__runner_thread: threading.Thread | None = None
        # Guards the lazy creation of the loop and thread in ``run``.
        self.__lock = threading.Lock()
        atexit.register(self._close)

    def _close(self) -> None:
        """Ask the background loop to stop; best effort and safe to repeat."""
        loop = self.__io_loop
        if loop is None or loop.is_closed():
            return
        # loop.stop() is not thread-safe: called from another thread it only
        # sets a flag and does not wake a loop blocked waiting for I/O, so the
        # stop has to be scheduled onto the loop's own thread.
        try:
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop was closed between the check above and this call.
            pass

    def _runner(self) -> None:
        """Thread target: run the loop until it is stopped, then close it."""
        loop = self.__io_loop
        assert loop is not None
        try:
            loop.run_forever()
        finally:
            loop.close()

    def run(self, coro: Any) -> Any:
        """Synchronously run a coroutine on a background thread.

        Blocks until the coroutine finishes and returns its result; an
        exception raised by the coroutine is re-raised in the caller.
        """
        with self.__lock:
            if self.__io_loop is None:
                name = f"{threading.current_thread().name} - runner"
                self.__io_loop = asyncio.new_event_loop()
                self.__runner_thread = threading.Thread(target=self._runner, daemon=True, name=name)
                self.__runner_thread.start()
            loop = self.__io_loop
        # Wait outside the lock so callers sharing this runner are not
        # serialized behind each other's coroutines.
        fut = asyncio.run_coroutine_threadsafe(coro, loop)
        return fut.result(None)
129
130
# Background task runners used by ``run_sync`` when an event loop is already
# running in the calling thread, keyed by the calling thread's name.
_runner_map: dict[str, _TaskRunner] = {}
# Event loop cached by ``ensure_event_loop`` for the current context.
_loop: ContextVar[asyncio.AbstractEventLoop | None] = ContextVar("_loop", default=None)
133
134
def run_sync(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]:
    """Wraps coroutine in a function that blocks until it has executed.

    Parameters
    ----------
    coro : coroutine-function
        The coroutine-function to be executed.

    Returns
    -------
    result :
        Whatever the coroutine-function returns.

    Notes
    -----
    The wrapper carries ``coro``'s name, qualname, module and docstring, and
    exposes ``coro`` as ``__wrapped__``. If no event loop is running in the
    calling thread, the coroutine runs to completion on this thread's loop
    (from ``ensure_event_loop``). If a loop is already running here it cannot
    be re-entered, so the coroutine is handed to a background ``_TaskRunner``
    shared by all callers with the same thread name.
    """

    assert inspect.iscoroutinefunction(coro)

    @functools.wraps(coro)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        name = threading.current_thread().name
        inner = coro(*args, **kwargs)

        # Probe with try/except/else so the probe's RuntimeError is never
        # chained onto exceptions raised while running the coroutine.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            loop_running = False
        else:
            loop_running = True

        if not loop_running:
            # No loop running, run the loop for this thread.
            loop = ensure_event_loop()
            return loop.run_until_complete(inner)

        # Loop is currently running in this thread,
        # use a task runner.
        if name not in _runner_map:
            _runner_map[name] = _TaskRunner()
        return _runner_map[name].run(inner)

    return wrapped
175
176
def ensure_event_loop(prefer_selector_loop: bool = False) -> asyncio.AbstractEventLoop:
    """Return the event loop for the current context, creating one if needed.

    Parameters
    ----------
    prefer_selector_loop : bool
        On Windows, when a new loop has to be created, create a
        selector-based loop instead of the default (proactor) loop.
        Ignored on other platforms.

    Returns
    -------
    loop : asyncio.AbstractEventLoop
        The loop cached for this context if it is still open. Otherwise the
        currently running loop, or, if none is running, a newly created loop
        that is also installed with ``asyncio.set_event_loop``. The result is
        cached for later calls in the same context.
    """
    # Get the loop cached for this context (thread), or find/create one.
    loop = _loop.get()
    if loop is not None and not loop.is_closed():
        return loop
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        if sys.platform == "win32" and prefer_selector_loop:
            # On Windows, asyncio.SelectorEventLoop is the loop class that
            # WindowsSelectorEventLoopPolicy().new_event_loop() builds.
            # Constructing it directly avoids the event-loop policy API,
            # which is deprecated since Python 3.14.
            loop = asyncio.SelectorEventLoop()
        else:
            loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    _loop.set(loop)
    return loop
202
203
async def ensure_async(obj: Awaitable[T] | T) -> T:
    """Await ``obj`` if it is awaitable; otherwise return it unchanged.

    Meant to be applied to the return value of a call whose target may be
    either synchronous or asynchronous.

    If ``obj`` is a coroutine that was already awaited elsewhere, its result
    can no longer be recovered; in that case the (exhausted) coroutine object
    itself is returned instead of raising. Any other ``RuntimeError`` from
    awaiting propagates.
    """
    if not inspect.isawaitable(obj):
        return obj
    try:
        return await cast(Awaitable[T], obj)
    except RuntimeError as err:
        if str(err) != "cannot reuse already awaited coroutine":
            raise
    # Note: this hands back the coroutine object, not its result.
    return cast(T, obj)