Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/threaded.py: 44%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2A threaded shared-memory scheduler
4See local.py
5"""
7from __future__ import annotations
9import atexit
10import contextvars
11import multiprocessing.pool
12import sys
13import threading
14from collections import defaultdict
15from collections.abc import Mapping, Sequence
16from concurrent.futures import Executor, ThreadPoolExecutor
17from threading import Lock, current_thread
19from dask import config
20from dask.local import MultiprocessingPoolExecutor, get_async
21from dask.system import CPU_COUNT
22from dask.typing import Key
25def _thread_get_id():
26 return current_thread().ident
# Module-level scheduler state: executors are cached and reused across calls.
main_thread = current_thread()  # thread that executed this module's import
# Lazily-created executor shared by main-thread calls that pass neither
# ``pool`` nor ``num_workers`` (see ``get`` below).
default_pool: Executor | None = None
# Per-thread cache of executors, keyed by the requested worker count.
pools: defaultdict[threading.Thread, dict[int, Executor]] = defaultdict(dict)
# Guards ``default_pool`` and ``pools`` against concurrent ``get`` calls.
pools_lock = Lock()
35def pack_exception(e, dumps):
36 return e, sys.exc_info()[2]
class ContextAwareThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor subclass that propagates contextvars on submit.

    A vanilla ThreadPoolExecutor propagates contextvars automatically only on
    CPython 3.14t (noGIL), and only for variables set before the worker
    threads start. This subclass instead snapshots the submitting thread's
    context on every ``submit`` call, so propagation works on all Python
    interpreters and also when the variables are set after the worker threads
    are already warm.

    This also affects catching warnings, which on 3.14t use contextvars.

    See Also
    --------
    https://docs.python.org/3/using/cmdline.html#envvar-PYTHON_THREAD_INHERIT_CONTEXT
    https://docs.python.org/3/using/cmdline.html#envvar-PYTHON_CONTEXT_AWARE_WARNINGS
    """

    def submit(self, fn, /, *args, **kwargs):
        # Capture the caller's context now; the worker replays fn inside it.
        snapshot = contextvars.copy_context()
        return super().submit(snapshot.run, fn, *args, **kwargs)
def get(
    dsk: Mapping,
    keys: Sequence[Key] | Key,
    cache=None,
    num_workers=None,
    pool=None,
    **kwargs,
):
    """Threaded cached implementation of dask.get

    Parameters
    ----------

    dsk: dict
        A dask dictionary specifying a workflow
    keys: key or list of keys
        Keys corresponding to desired data
    num_workers: integer of thread count
        The number of threads to use in the ThreadPool that will actually execute tasks
    cache: dict-like (optional)
        Temporary storage of results
    pool: Executor or multiprocessing.pool.Pool (optional)
        Pool on which to execute tasks; defaults to the ``pool`` config value,
        else to a cached module-level thread pool (see below).

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> dsk = {'x': 1, 'y': 2, 'z': (inc, 'x'), 'w': (add, 'z', 'y')}
    >>> get(dsk, 'w')
    4
    >>> get(dsk, ['w', 'y'])
    (4, 2)
    """
    global default_pool
    # Explicit arguments win; otherwise fall back to dask config values.
    pool = pool or config.get("pool", None)
    num_workers = num_workers or config.get("num_workers", None)
    thread = current_thread()

    # Select (or lazily create) the executor under the module-level lock.
    with pools_lock:
        if pool is None:
            if num_workers is None and thread is main_thread:
                # Main thread, default sizing: share one process-wide pool.
                if default_pool is None:
                    default_pool = ContextAwareThreadPoolExecutor(CPU_COUNT)
                    atexit.register(default_pool.shutdown)
                pool = default_pool
            elif thread in pools and num_workers in pools[thread]:
                # Reuse the pool previously created for this thread/size pair.
                pool = pools[thread][num_workers]
            else:
                # First request from this thread at this size: create and cache.
                pool = ContextAwareThreadPoolExecutor(num_workers)
                atexit.register(pool.shutdown)
                pools[thread][num_workers] = pool
        elif isinstance(pool, multiprocessing.pool.Pool):
            # Adapt a raw multiprocessing pool to the Executor interface.
            pool = MultiprocessingPoolExecutor(pool)

    # Delegate actual scheduling/execution to the shared async machinery.
    results = get_async(
        pool.submit,
        pool._max_workers,
        dsk,
        keys,
        cache=cache,
        get_id=_thread_get_id,
        pack_exception=pack_exception,
        **kwargs,
    )

    # Cleanup pools associated to dead threads
    with pools_lock:
        active_threads = set(threading.enumerate())
        if thread is not main_thread:
            for t in list(pools):
                if t not in active_threads:
                    for p in pools.pop(t).values():
                        p.shutdown()

    return results