1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import enum
12import hashlib
13import heapq
14import sys
15from collections import OrderedDict, abc
16from functools import lru_cache
17from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Type, TypeVar, Union
18
19from hypothesis.errors import InvalidArgument
20from hypothesis.internal.compat import int_from_bytes
21from hypothesis.internal.floats import next_up
22
23if TYPE_CHECKING:
24 from hypothesis.internal.conjecture.data import ConjectureData
25
26
27LABEL_MASK = 2**64 - 1
28
29
30def calc_label_from_name(name: str) -> int:
31 hashed = hashlib.sha384(name.encode()).digest()
32 return int_from_bytes(hashed[:8])
33
34
35def calc_label_from_cls(cls: type) -> int:
36 return calc_label_from_name(cls.__qualname__)
37
38
39def combine_labels(*labels: int) -> int:
40 label = 0
41 for l in labels:
42 label = (label << 1) & LABEL_MASK
43 label ^= l
44 return label
45
46
47SAMPLE_IN_SAMPLER_LABEL = calc_label_from_name("a sample() in Sampler")
48ONE_FROM_MANY_LABEL = calc_label_from_name("one more from many()")
49
50
51T = TypeVar("T")
52
53
54def check_sample(
55 values: Union[Type[enum.Enum], Sequence[T]], strategy_name: str
56) -> Sequence[T]:
57 if "numpy" in sys.modules and isinstance(values, sys.modules["numpy"].ndarray):
58 if values.ndim != 1:
59 raise InvalidArgument(
60 "Only one-dimensional arrays are supported for sampling, "
61 f"and the given value has {values.ndim} dimensions (shape "
62 f"{values.shape}). This array would give samples of array slices "
63 "instead of elements! Use np.ravel(values) to convert "
64 "to a one-dimensional array, or tuple(values) if you "
65 "want to sample slices."
66 )
67 elif not isinstance(values, (OrderedDict, abc.Sequence, enum.EnumMeta)):
68 raise InvalidArgument(
69 f"Cannot sample from {values!r}, not an ordered collection. "
70 f"Hypothesis goes to some length to ensure that the {strategy_name} "
71 "strategy has stable results between runs. To replay a saved "
72 "example, the sampled values must have the same iteration order "
73 "on every run - ruling out sets, dicts, etc due to hash "
74 "randomization. Most cases can simply use `sorted(values)`, but "
75 "mixed types or special values such as math.nan require careful "
76 "handling - and note that when simplifying an example, "
77 "Hypothesis treats earlier values as simpler."
78 )
79 if isinstance(values, range):
80 return values
81 return tuple(values)
82
83
84class Sampler:
85 """Sampler based on Vose's algorithm for the alias method. See
86 http://www.keithschwarz.com/darts-dice-coins/ for a good explanation.
87
88 The general idea is that we store a table of triples (base, alternate, p).
89 base. We then pick a triple uniformly at random, and choose its alternate
90 value with probability p and else choose its base value. The triples are
91 chosen so that the resulting mixture has the right distribution.
92
93 We maintain the following invariants to try to produce good shrinks:
94
95 1. The table is in lexicographic (base, alternate) order, so that choosing
96 an earlier value in the list always lowers (or at least leaves
97 unchanged) the value.
98 2. base[i] < alternate[i], so that shrinking the draw always results in
99 shrinking the chosen element.
100 """
101
102 table: List[Tuple[int, int, float]] # (base_idx, alt_idx, alt_chance)
103
104 def __init__(self, weights: Sequence[float], *, observe: bool = True):
105 self.observe = observe
106
107 n = len(weights)
108 table: "list[list[int | float | None]]" = [[i, None, None] for i in range(n)]
109 total = sum(weights)
110 num_type = type(total)
111
112 zero = num_type(0) # type: ignore
113 one = num_type(1) # type: ignore
114
115 small: "List[int]" = []
116 large: "List[int]" = []
117
118 probabilities = [w / total for w in weights]
119 scaled_probabilities: "List[float]" = []
120
121 for i, alternate_chance in enumerate(probabilities):
122 scaled = alternate_chance * n
123 scaled_probabilities.append(scaled)
124 if scaled == 1:
125 table[i][2] = zero
126 elif scaled < 1:
127 small.append(i)
128 else:
129 large.append(i)
130 heapq.heapify(small)
131 heapq.heapify(large)
132
133 while small and large:
134 lo = heapq.heappop(small)
135 hi = heapq.heappop(large)
136
137 assert lo != hi
138 assert scaled_probabilities[hi] > one
139 assert table[lo][1] is None
140 table[lo][1] = hi
141 table[lo][2] = one - scaled_probabilities[lo]
142 scaled_probabilities[hi] = (
143 scaled_probabilities[hi] + scaled_probabilities[lo]
144 ) - one
145
146 if scaled_probabilities[hi] < 1:
147 heapq.heappush(small, hi)
148 elif scaled_probabilities[hi] == 1:
149 table[hi][2] = zero
150 else:
151 heapq.heappush(large, hi)
152 while large:
153 table[large.pop()][2] = zero
154 while small:
155 table[small.pop()][2] = zero
156
157 self.table: "List[Tuple[int, int, float]]" = []
158 for base, alternate, alternate_chance in table: # type: ignore
159 assert isinstance(base, int)
160 assert isinstance(alternate, int) or alternate is None
161 if alternate is None:
162 self.table.append((base, base, alternate_chance))
163 elif alternate < base:
164 self.table.append((alternate, base, one - alternate_chance))
165 else:
166 self.table.append((base, alternate, alternate_chance))
167 self.table.sort()
168
169 def sample(
170 self,
171 data: "ConjectureData",
172 *,
173 forced: Optional[int] = None,
174 fake_forced: bool = False,
175 ) -> int:
176 if self.observe:
177 data.start_example(SAMPLE_IN_SAMPLER_LABEL)
178 forced_choice = ( # pragma: no branch # https://github.com/nedbat/coveragepy/issues/1617
179 None
180 if forced is None
181 else next(
182 (base, alternate, alternate_chance)
183 for (base, alternate, alternate_chance) in self.table
184 if forced == base or (forced == alternate and alternate_chance > 0)
185 )
186 )
187 base, alternate, alternate_chance = data.choice(
188 self.table,
189 forced=forced_choice,
190 fake_forced=fake_forced,
191 observe=self.observe,
192 )
193 forced_use_alternate = None
194 if forced is not None:
195 # we maintain this invariant when picking forced_choice above.
196 # This song and dance about alternate_chance > 0 is to avoid forcing
197 # e.g. draw_boolean(p=0, forced=True), which is an error.
198 forced_use_alternate = forced == alternate and alternate_chance > 0
199 assert forced == base or forced_use_alternate
200
201 use_alternate = data.draw_boolean(
202 alternate_chance,
203 forced=forced_use_alternate,
204 fake_forced=fake_forced,
205 observe=self.observe,
206 )
207 if self.observe:
208 data.stop_example()
209 if use_alternate:
210 assert forced is None or alternate == forced, (forced, alternate)
211 return alternate
212 else:
213 assert forced is None or base == forced, (forced, base)
214 return base
215
216
217INT_SIZES = (8, 16, 32, 64, 128)
218INT_SIZES_SAMPLER = Sampler((4.0, 8.0, 1.0, 1.0, 0.5), observe=False)
219
220
221class many:
222 """Utility class for collections. Bundles up the logic we use for "should I
223 keep drawing more values?" and handles starting and stopping examples in
224 the right place.
225
226 Intended usage is something like:
227
228 elements = many(data, ...)
229 while elements.more():
230 add_stuff_to_result()
231 """
232
233 def __init__(
234 self,
235 data: "ConjectureData",
236 min_size: int,
237 max_size: Union[int, float],
238 average_size: Union[int, float],
239 *,
240 forced: Optional[int] = None,
241 fake_forced: bool = False,
242 observe: bool = True,
243 ) -> None:
244 assert 0 <= min_size <= average_size <= max_size
245 assert forced is None or min_size <= forced <= max_size
246 self.min_size = min_size
247 self.max_size = max_size
248 self.data = data
249 self.forced_size = forced
250 self.fake_forced = fake_forced
251 self.p_continue = _calc_p_continue(average_size - min_size, max_size - min_size)
252 self.count = 0
253 self.rejections = 0
254 self.drawn = False
255 self.force_stop = False
256 self.rejected = False
257 self.observe = observe
258
259 def stop_example(self):
260 if self.observe:
261 self.data.stop_example()
262
263 def start_example(self, label):
264 if self.observe:
265 self.data.start_example(label)
266
267 def more(self) -> bool:
268 """Should I draw another element to add to the collection?"""
269 if self.drawn:
270 self.stop_example()
271
272 self.drawn = True
273 self.rejected = False
274
275 self.start_example(ONE_FROM_MANY_LABEL)
276 if self.min_size == self.max_size:
277 # if we have to hit an exact size, draw unconditionally until that
278 # point, and no further.
279 should_continue = self.count < self.min_size
280 else:
281 forced_result = None
282 if self.force_stop:
283 # if our size is forced, we can't reject in a way that would
284 # cause us to differ from the forced size.
285 assert self.forced_size is None or self.count == self.forced_size
286 forced_result = False
287 elif self.count < self.min_size:
288 forced_result = True
289 elif self.count >= self.max_size:
290 forced_result = False
291 elif self.forced_size is not None:
292 forced_result = self.count < self.forced_size
293 should_continue = self.data.draw_boolean(
294 self.p_continue,
295 forced=forced_result,
296 fake_forced=self.fake_forced,
297 observe=self.observe,
298 )
299
300 if should_continue:
301 self.count += 1
302 return True
303 else:
304 self.stop_example()
305 return False
306
307 def reject(self, why: Optional[str] = None) -> None:
308 """Reject the last example (i.e. don't count it towards our budget of
309 elements because it's not going to go in the final collection)."""
310 assert self.count > 0
311 self.count -= 1
312 self.rejections += 1
313 self.rejected = True
314 # We set a minimum number of rejections before we give up to avoid
315 # failing too fast when we reject the first draw.
316 if self.rejections > max(3, 2 * self.count):
317 if self.count < self.min_size:
318 self.data.mark_invalid(why)
319 else:
320 self.force_stop = True
321
322
323SMALLEST_POSITIVE_FLOAT: float = next_up(0.0) or sys.float_info.min
324
325
326@lru_cache
327def _calc_p_continue(desired_avg: float, max_size: int) -> float:
328 """Return the p_continue which will generate the desired average size."""
329 assert desired_avg <= max_size, (desired_avg, max_size)
330 if desired_avg == max_size:
331 return 1.0
332 p_continue = 1 - 1.0 / (1 + desired_avg)
333 if p_continue == 0 or max_size == float("inf"):
334 assert 0 <= p_continue < 1, p_continue
335 return p_continue
336 assert 0 < p_continue < 1, p_continue
337 # For small max_size, the infinite-series p_continue is a poor approximation,
338 # and while we can't solve the polynomial a few rounds of iteration quickly
339 # gets us a good approximate solution in almost all cases (sometimes exact!).
340 while _p_continue_to_avg(p_continue, max_size) > desired_avg:
341 # This is impossible over the reals, but *can* happen with floats.
342 p_continue -= 0.0001
343 # If we've reached zero or gone negative, we want to break out of this loop,
344 # and do so even if we're on a system with the unsafe denormals-are-zero flag.
345 # We make that an explicit error in st.floats(), but here we'd prefer to
346 # just get somewhat worse precision on collection lengths.
347 if p_continue < SMALLEST_POSITIVE_FLOAT:
348 p_continue = SMALLEST_POSITIVE_FLOAT
349 break
350 # Let's binary-search our way to a better estimate! We tried fancier options
351 # like gradient descent, but this is numerically stable and works better.
352 hi = 1.0
353 while desired_avg - _p_continue_to_avg(p_continue, max_size) > 0.01:
354 assert 0 < p_continue < hi, (p_continue, hi)
355 mid = (p_continue + hi) / 2
356 if _p_continue_to_avg(mid, max_size) <= desired_avg:
357 p_continue = mid
358 else:
359 hi = mid
360 assert 0 < p_continue < 1, p_continue
361 assert _p_continue_to_avg(p_continue, max_size) <= desired_avg
362 return p_continue
363
364
365def _p_continue_to_avg(p_continue: float, max_size: int) -> float:
366 """Return the average_size generated by this p_continue and max_size."""
367 if p_continue >= 1:
368 return max_size
369 return (1.0 / (1 - p_continue) - 1) * (1 - p_continue**max_size)