Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/utils_test.py: 28%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import contextlib
4import importlib
5import time
6from typing import TYPE_CHECKING
8if TYPE_CHECKING:
9 from dask.highlevelgraph import HighLevelGraph, Layer
def inc(x):
    """Return ``x`` incremented by one (``x + 1``)."""
    incremented = x + 1
    return incremented
def dec(x):
    """Return ``x`` decremented by one (``x - 1``)."""
    decremented = x - 1
    return decremented
def add(x, y):
    """Return the sum ``x + y`` of the two arguments."""
    total = x + y
    return total
def slowadd(a, b, delay=0.1):
    """Sleep for ``delay`` seconds, then return ``a + b``.

    Useful in tests that need a task with a noticeable runtime.
    """
    time.sleep(delay)
    total = a + b
    return total
class GetFunctionTestMixin:
    """
    The GetFunctionTestMixin class can be imported and used to test foreign
    implementations of the `get` function specification. It aims to enforce all
    known expectations of `get` functions.

    To use the class, inherit from it and override the `get` function. For
    example:

    > from dask.utils_test import GetFunctionTestMixin
    > class TestCustomGet(GetFunctionTestMixin):
          get = staticmethod(myget)

    Note that the foreign `myget` function has to be explicitly decorated as a
    staticmethod.
    """

    def _assert_get_raises_keyerror(self, dsk, keys):
        """Assert that ``self.get(dsk, keys)`` raises ``KeyError``.

        Parameters
        ----------
        dsk : graph passed straight to ``self.get``.
        keys : key (or nested list of keys) to request; expected to be missing.

        Raises
        ------
        AssertionError
            If ``self.get`` returns normally instead of raising ``KeyError``.
        """
        try:
            result = self.get(dsk, keys)
        except KeyError:
            return
        # Some callables (e.g. functools.partial) have no ``__name__``.
        name = getattr(self.get, "__name__", repr(self.get))
        # Build the message in one f-string: running ``str.format`` over text
        # that already embeds ``result`` breaks when its repr contains braces
        # (dicts, sets, ...). Raise explicitly so the check survives ``-O``.
        raise AssertionError(
            f"Expected `{name}` with badkey to raise KeyError.\n"
            f"Obtained '{result}' instead."
        )

    def test_get(self):
        d = {":x": 1, ":y": (inc, ":x"), ":z": (add, ":x", ":y")}

        assert self.get(d, ":x") == 1
        assert self.get(d, ":y") == 2
        assert self.get(d, ":z") == 3

    def test_badkey(self):
        d = {":x": 1, ":y": (inc, ":x"), ":z": (add, ":x", ":y")}
        self._assert_get_raises_keyerror(d, "badkey")

    def test_nested_badkey(self):
        d = {"x": 1, "y": 2, "z": (sum, ["x", "y"])}
        self._assert_get_raises_keyerror(d, [["badkey"], "y"])

    def test_data_not_in_dict_is_ok(self):
        d = {"x": 1, "y": (add, "x", 10)}
        assert self.get(d, "y") == 11

    def test_get_with_list(self):
        d = {"x": 1, "y": 2, "z": (sum, ["x", "y"])}

        assert self.get(d, ["x", "y"]) == (1, 2)
        assert self.get(d, "z") == 3

    def test_get_with_list_top_level(self):
        d = {
            "a": [1, 2, 3],
            "b": "a",
            "c": [1, (inc, 1)],
            "d": [(sum, "a")],
            "e": ["a", "b"],
            "f": [[[(sum, "a"), "c"], (sum, "b")], 2],
        }
        assert self.get(d, "a") == [1, 2, 3]
        assert self.get(d, "b") == [1, 2, 3]
        assert self.get(d, "c") == [1, 2]
        assert self.get(d, "d") == [6]
        assert self.get(d, "e") == [[1, 2, 3], [1, 2, 3]]
        assert self.get(d, "f") == [[[6, [1, 2]], 6], 2]

    def test_get_with_nested_list(self):
        d = {"x": 1, "y": 2, "z": (sum, ["x", "y"])}

        assert self.get(d, [["x"], "y"]) == ((1,), 2)
        assert self.get(d, "z") == 3

    def test_get_works_with_unhashables_in_values(self):
        def f(x, y):
            return x + len(y)

        # The set literal is an unhashable task argument.
        d = {"x": 1, "y": (f, "x", {1})}

        assert self.get(d, "y") == 2

    def test_nested_tasks(self):
        d = {"x": 1, "y": (inc, "x"), "z": (add, (inc, "x"), "y")}

        assert self.get(d, "z") == 4

    def test_get_stack_limit(self):
        # A 10000-long dependency chain: recursive schedulers would overflow
        # the interpreter's default recursion limit.
        d = {f"x{i + 1}": (inc, f"x{i}") for i in range(10000)}
        d["x0"] = 0
        assert self.get(d, "x10000") == 10000

    def test_with_HighLevelGraph(self):
        from dask.highlevelgraph import HighLevelGraph

        layers = {"a": {"x": 1, "y": (inc, "x")}, "b": {"z": (add, (inc, "x"), "y")}}
        dependencies = {"a": (), "b": {"a"}}
        graph = HighLevelGraph(layers, dependencies)
        assert self.get(graph, "z") == 4
def import_or_none(name):
    """Try to import the module called ``name``.

    Returns the imported module, or ``None`` if importing it raises
    ``ImportError`` or ``AttributeError``.
    """
    try:
        module = importlib.import_module(name)
    except (ImportError, AttributeError):
        module = None
    return module
def hlg_layer(hlg: HighLevelGraph, prefix: str) -> Layer:
    """Return the first layer of ``hlg`` whose name begins with ``prefix``.

    Layers are scanned in ``hlg.layers`` iteration order. A ``KeyError``
    listing every layer name is raised when no name matches.
    """
    for name, layer in hlg.layers.items():
        if name.startswith(prefix):
            break
    else:
        raise KeyError(f"No layer starts with {prefix!r}: {list(hlg.layers)}")
    return layer
def hlg_layer_topological(hlg: HighLevelGraph, i: int) -> Layer:
    """Return the layer at index ``i`` of ``hlg``'s topological layer order.

    The order comes from ``hlg._toposort_layers()``; ``i`` indexes that list.
    """
    ordered_names = hlg._toposort_layers()
    return hlg.layers[ordered_names[i]]
154@contextlib.contextmanager
155def _check_warning(condition: bool, category: type[Warning], message: str):
156 """Conditionally check if a warning is raised"""
157 if condition:
158 import pytest
160 with pytest.warns(category, match=message) as ctx:
161 yield ctx
162 else:
163 with contextlib.nullcontext() as ctx:
164 yield ctx