
1""" 

2A threaded shared-memory scheduler 

3 

4See local.py 

5""" 

6 

7from __future__ import annotations 

8 

9import atexit 

10import multiprocessing.pool 

11import sys 

12import threading 

13from collections import defaultdict 

14from collections.abc import Mapping, Sequence 

15from concurrent.futures import Executor, ThreadPoolExecutor 

16from threading import Lock, current_thread 

17 

18from dask import config 

19from dask.local import MultiprocessingPoolExecutor, get_async 

20from dask.system import CPU_COUNT 

21from dask.typing import Key 

22 

23 

24def _thread_get_id(): 

25 return current_thread().ident 

26 

27 

28main_thread = current_thread() 

29default_pool: Executor | None = None 

30pools: defaultdict[threading.Thread, dict[int, Executor]] = defaultdict(dict) 

31pools_lock = Lock() 
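# ``pools`` caches one executor per (owning thread, num_workers) pair so that
# repeated or nested ``get`` calls reuse threads instead of respawning them;
# ``pools_lock`` guards the cache because ``get`` may run from many threads.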


def pack_exception(e, dumps):
    # ``dumps`` is accepted for interface compatibility with other dask
    # schedulers; threads share memory, so the exception and its traceback
    # are returned as-is rather than serialized.
    return e, sys.exc_info()[2]


def get(
    dsk: Mapping,
    keys: Sequence[Key] | Key,
    cache=None,
    num_workers=None,
    pool=None,
    **kwargs,
):
    """Threaded cached implementation of dask.get

    Parameters
    ----------

    dsk: dict
        A dask dictionary specifying a workflow
    keys: key or list of keys
        Keys corresponding to desired data
    cache: dict-like (optional)
        Temporary storage of results
    num_workers: int (optional)
        The number of threads to use in the ThreadPool that will
        actually execute tasks
    pool: Executor or multiprocessing.pool.Pool (optional)
        A pool of workers to run the tasks in; a ``multiprocessing.pool.Pool``
        is wrapped in a ``MultiprocessingPoolExecutor``

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> dsk = {'x': 1, 'y': 2, 'z': (inc, 'x'), 'w': (add, 'z', 'y')}
    >>> get(dsk, 'w')
    4
    >>> get(dsk, ['w', 'y'])
    (4, 2)
    """
    global default_pool
    pool = pool or config.get("pool", None)
    num_workers = num_workers or config.get("num_workers", None)
    thread = current_thread()

    with pools_lock:
        if pool is None:
            if num_workers is None and thread is main_thread:
                # Default case: lazily create a single process-wide pool
                # sized to the CPU count and reuse it on every call.
                if default_pool is None:
                    default_pool = ThreadPoolExecutor(CPU_COUNT)
                    atexit.register(default_pool.shutdown)
                pool = default_pool
            elif thread in pools and num_workers in pools[thread]:
                # Reuse the executor cached for this thread and size.
                pool = pools[thread][num_workers]
            else:
                pool = ThreadPoolExecutor(num_workers)
                atexit.register(pool.shutdown)
                pools[thread][num_workers] = pool
        elif isinstance(pool, multiprocessing.pool.Pool):
            # Adapt a raw multiprocessing pool to the Executor interface.
            pool = MultiprocessingPoolExecutor(pool)

    results = get_async(
        pool.submit,
        pool._max_workers,
        dsk,
        keys,
        cache=cache,
        get_id=_thread_get_id,
        pack_exception=pack_exception,
        **kwargs,
    )

    # Clean up pools associated with dead threads
    with pools_lock:
        active_threads = set(threading.enumerate())
        if thread is not main_thread:
            for t in list(pools):
                if t not in active_threads:
                    for p in pools.pop(t).values():
                        p.shutdown()

    return results
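

# A minimal usage sketch (illustrative only, not part of the module): shows
# passing an explicit executor via ``pool``, which bypasses the cached-executor
# machinery above. Assumes this file is run directly as a script.
if __name__ == "__main__":
    from operator import add as _add

    with ThreadPoolExecutor(2) as executor:
        example_dsk = {"x": 1, "y": 2, "z": (_add, "x", "y")}
        print(get(example_dsk, "z", pool=executor))  # prints 3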