Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/utils_test.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

96 statements  

1from __future__ import annotations 

2 

3import contextlib 

4import importlib 

5import time 

6from typing import TYPE_CHECKING 

7 

8if TYPE_CHECKING: 

9 from dask.highlevelgraph import HighLevelGraph, Layer 

10 

11 

def inc(x):
    """Return ``x`` plus one."""
    incremented = x + 1
    return incremented

14 

15 

def dec(x):
    """Return ``x`` minus one."""
    decremented = x - 1
    return decremented

18 

19 

def add(x, y):
    """Return the sum ``x + y``."""
    total = x + y
    return total

22 

23 

def slowadd(a, b, delay=0.1):
    """Sleep for ``delay`` seconds, then return ``a + b``."""
    time.sleep(delay)
    total = a + b
    return total

27 

28 

class GetFunctionTestMixin:
    """
    Reusable test suite for implementations of the ``get`` function.

    ``GetFunctionTestMixin`` can be imported and used to test foreign
    implementations of the ``get`` function specification. It aims to enforce
    all known expectations of ``get`` functions.

    To use the class, inherit from it and override the ``get`` function. For
    example:

    > from dask.utils_test import GetFunctionTestMixin
    > class TestCustomGet(GetFunctionTestMixin):
    >     get = staticmethod(myget)

    Note that the foreign ``myget`` function has to be explicitly decorated as
    a staticmethod.
    """

    def _assert_get_raises_keyerror(self, dsk, key):
        """Check that ``self.get(dsk, key)`` raises ``KeyError``.

        If ``get`` returns normally instead, raise ``AssertionError`` with a
        message naming the ``get`` function and the value it returned. The
        error is raised explicitly (not via ``assert``) so the check is not
        stripped under ``python -O``.
        """
        try:
            result = self.get(dsk, key)
        except KeyError:
            return
        # Callables such as functools.partial have no ``__name__``; fall back
        # to their repr rather than crashing while reporting the failure.
        name = getattr(self.get, "__name__", repr(self.get))
        # Use f-strings only: passing the message through ``str.format`` after
        # interpolating ``result`` breaks when its repr contains braces.
        raise AssertionError(
            f"Expected `{name}` with badkey to raise KeyError.\n"
            f"Obtained '{result}' instead."
        )

    def test_get(self):
        d = {":x": 1, ":y": (inc, ":x"), ":z": (add, ":x", ":y")}

        assert self.get(d, ":x") == 1
        assert self.get(d, ":y") == 2
        assert self.get(d, ":z") == 3

    def test_badkey(self):
        d = {":x": 1, ":y": (inc, ":x"), ":z": (add, ":x", ":y")}
        self._assert_get_raises_keyerror(d, "badkey")

    def test_nested_badkey(self):
        d = {"x": 1, "y": 2, "z": (sum, ["x", "y"])}
        self._assert_get_raises_keyerror(d, [["badkey"], "y"])

    def test_data_not_in_dict_is_ok(self):
        d = {"x": 1, "y": (add, "x", 10)}
        assert self.get(d, "y") == 11

    def test_get_with_list(self):
        d = {"x": 1, "y": 2, "z": (sum, ["x", "y"])}

        assert self.get(d, ["x", "y"]) == (1, 2)
        assert self.get(d, "z") == 3

    def test_get_with_list_top_level(self):
        d = {
            "a": [1, 2, 3],
            "b": "a",
            "c": [1, (inc, 1)],
            "d": [(sum, "a")],
            "e": ["a", "b"],
            "f": [[[(sum, "a"), "c"], (sum, "b")], 2],
        }
        assert self.get(d, "a") == [1, 2, 3]
        assert self.get(d, "b") == [1, 2, 3]
        assert self.get(d, "c") == [1, 2]
        assert self.get(d, "d") == [6]
        assert self.get(d, "e") == [[1, 2, 3], [1, 2, 3]]
        assert self.get(d, "f") == [[[6, [1, 2]], 6], 2]

    def test_get_with_nested_list(self):
        d = {"x": 1, "y": 2, "z": (sum, ["x", "y"])}

        assert self.get(d, [["x"], "y"]) == ((1,), 2)
        assert self.get(d, "z") == 3

    def test_get_works_with_unhashables_in_values(self):
        def f(x, y):
            return x + len(y)

        # The set literal is an unhashable argument embedded in the task.
        d = {"x": 1, "y": (f, "x", {1})}

        assert self.get(d, "y") == 2

    def test_nested_tasks(self):
        d = {"x": 1, "y": (inc, "x"), "z": (add, (inc, "x"), "y")}

        assert self.get(d, "z") == 4

    def test_get_stack_limit(self):
        # A 10000-deep linear chain x0 -> x1 -> ... -> x10000.
        d = {f"x{i + 1}": (inc, f"x{i}") for i in range(10000)}
        d["x0"] = 0
        assert self.get(d, "x10000") == 10000

    def test_with_HighLevelGraph(self):
        from dask.highlevelgraph import HighLevelGraph

        layers = {"a": {"x": 1, "y": (inc, "x")}, "b": {"z": (add, (inc, "x"), "y")}}
        dependencies = {"a": (), "b": {"a"}}
        graph = HighLevelGraph(layers, dependencies)
        assert self.get(graph, "z") == 4

131 

132 

def import_or_none(name):
    """Import and return the module called ``name``, or ``None`` on failure.

    Only ``ImportError`` (including ``ModuleNotFoundError``) and
    ``AttributeError`` raised by the import are swallowed; anything else
    propagates.
    """
    with contextlib.suppress(ImportError, AttributeError):
        return importlib.import_module(name)
    return None

139 

140 

def hlg_layer(hlg: HighLevelGraph, prefix: str) -> Layer:
    """Return the first layer of ``hlg`` whose name begins with ``prefix``.

    Layers are scanned in ``hlg.layers`` iteration order. If none match,
    raise ``KeyError`` whose message lists every layer name.
    """
    not_found = object()
    matching = (
        layer for layer_name, layer in hlg.layers.items()
        if layer_name.startswith(prefix)
    )
    first = next(matching, not_found)
    if first is not_found:
        raise KeyError(f"No layer starts with {prefix!r}: {list(hlg.layers)}")
    return first

147 

148 

def hlg_layer_topological(hlg: HighLevelGraph, i: int) -> Layer:
    """Return the layer at index ``i`` in ``hlg``'s topological layer order."""
    ordered_names = hlg._toposort_layers()
    return hlg.layers[ordered_names[i]]

152 

153 

154@contextlib.contextmanager 

155def _check_warning(condition: bool, category: type[Warning], message: str): 

156 """Conditionally check if a warning is raised""" 

157 if condition: 

158 import pytest 

159 

160 with pytest.warns(category, match=message) as ctx: 

161 yield ctx 

162 else: 

163 with contextlib.nullcontext() as ctx: 

164 yield ctx