Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/threaded.py: 44%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

54 statements  

1""" 

2A threaded shared-memory scheduler 

3 

4See local.py 

5""" 

6 

7from __future__ import annotations 

8 

9import atexit 

10import contextvars 

11import multiprocessing.pool 

12import sys 

13import threading 

14from collections import defaultdict 

15from collections.abc import Mapping, Sequence 

16from concurrent.futures import Executor, ThreadPoolExecutor 

17from threading import Lock, current_thread 

18 

19from dask import config 

20from dask.local import MultiprocessingPoolExecutor, get_async 

21from dask.system import CPU_COUNT 

22from dask.typing import Key 

23 

24 

25def _thread_get_id(): 

26 return current_thread().ident 

27 

28 

# Thread that imported this module; only this thread uses the shared default pool.
main_thread = current_thread()
# Lazily created, CPU_COUNT-sized executor shared by all main-thread calls to
# ``get`` that do not request an explicit pool or worker count.
default_pool: Executor | None = None
# Per-thread executors, keyed first by the calling thread and then by the
# requested worker count; entries for dead threads are reaped in ``get``.
pools: defaultdict[threading.Thread, dict[int, Executor]] = defaultdict(dict)
# Guards all reads/writes of ``default_pool`` and ``pools``.
pools_lock = Lock()

33 

34 

def pack_exception(e, dumps):
    """Bundle exception *e* with the traceback currently being handled.

    The traceback comes from ``sys.exc_info()`` (``None`` outside an active
    ``except`` block).  *dumps* is accepted for interface compatibility with
    other schedulers' ``pack_exception`` hooks and is not used here.
    """
    _, _, traceback = sys.exc_info()
    return e, traceback

37 

38 

class ContextAwareThreadPoolExecutor(ThreadPoolExecutor):
    """A ThreadPoolExecutor whose workers see the submitter's contextvars.

    A plain ThreadPoolExecutor only propagates contextvars automatically on
    CPython 3.14t (noGIL), and even there only for variables set before the
    worker threads start.  This subclass snapshots the submitting thread's
    context on every ``submit`` call, so propagation works on every Python
    interpreter and also reaches worker threads that are already warm when a
    variable is set.

    This also affects catching warnings, which on 3.14t use contextvars.

    See Also
    --------
    https://docs.python.org/3/using/cmdline.html#envvar-PYTHON_THREAD_INHERIT_CONTEXT
    https://docs.python.org/3/using/cmdline.html#envvar-PYTHON_CONTEXT_AWARE_WARNINGS
    """

    def submit(self, fn, /, *args, **kwargs):
        # Capture the caller's context now and run the task inside that copy.
        snapshot = contextvars.copy_context()
        return super().submit(snapshot.run, fn, *args, **kwargs)

60 

61 

def get(
    dsk: Mapping,
    keys: Sequence[Key] | Key,
    cache=None,
    num_workers=None,
    pool=None,
    **kwargs,
):
    """Threaded cached implementation of dask.get

    Parameters
    ----------

    dsk: dict
        A dask dictionary specifying a workflow
    keys: key or list of keys
        Keys corresponding to desired data
    num_workers: integer of thread count
        The number of threads to use in the ThreadPool that will actually execute tasks
    cache: dict-like (optional)
        Temporary storage of results
    pool: Executor or multiprocessing.pool.Pool (optional)
        Pool to run the tasks in; falls back to the ``pool`` config value,
        and finally to a process-wide default thread pool.

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> dsk = {'x': 1, 'y': 2, 'z': (inc, 'x'), 'w': (add, 'z', 'y')}
    >>> get(dsk, 'w')
    4
    >>> get(dsk, ['w', 'y'])
    (4, 2)
    """
    global default_pool
    # Explicit arguments take precedence over the dask config values.
    pool = pool or config.get("pool", None)
    num_workers = num_workers or config.get("num_workers", None)
    thread = current_thread()

    with pools_lock:
        if pool is None:
            if num_workers is None and thread is main_thread:
                # Common case: lazily create one shared CPU_COUNT-sized pool,
                # reused across all main-thread calls.
                if default_pool is None:
                    default_pool = ContextAwareThreadPoolExecutor(CPU_COUNT)
                    atexit.register(default_pool.shutdown)
                pool = default_pool
            elif thread in pools and num_workers in pools[thread]:
                # Reuse the pool previously created for this thread/size pair.
                pool = pools[thread][num_workers]
            else:
                # num_workers may be None here (non-main thread with no
                # explicit count); ThreadPoolExecutor then picks its default.
                pool = ContextAwareThreadPoolExecutor(num_workers)
                atexit.register(pool.shutdown)
                pools[thread][num_workers] = pool
        elif isinstance(pool, multiprocessing.pool.Pool):
            # Adapt a raw multiprocessing pool to the Executor interface
            # expected by get_async.
            pool = MultiprocessingPoolExecutor(pool)

    results = get_async(
        pool.submit,
        pool._max_workers,
        dsk,
        keys,
        cache=cache,
        get_id=_thread_get_id,
        pack_exception=pack_exception,
        **kwargs,
    )

    # Cleanup pools associated to dead threads
    with pools_lock:
        active_threads = set(threading.enumerate())
        if thread is not main_thread:
            for t in list(pools):
                if t not in active_threads:
                    for p in pools.pop(t).values():
                        p.shutdown()

    return results