1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import enum
12import hashlib
13import heapq
14import math
15import sys
16from collections import OrderedDict, abc
17from collections.abc import Callable, Sequence
18from functools import lru_cache
19from types import FunctionType
20from typing import TYPE_CHECKING, TypeVar
21
22from hypothesis.errors import InvalidArgument
23from hypothesis.internal.compat import int_from_bytes
24from hypothesis.internal.floats import next_up
25from hypothesis.internal.lambda_sources import _function_key
26
27if TYPE_CHECKING:
28 from hypothesis.internal.conjecture.data import ConjectureData
29
30
# All 64 low bits set; combine_labels() uses this to keep its shifted
# accumulator within 64 bits.
LABEL_MASK = (1 << 64) - 1
32
33
def calc_label_from_name(name: str) -> int:
    """Derive an integer label from ``name``.

    The name is encoded to bytes and hashed with SHA-384; the first eight
    bytes of the digest are then converted to an integer by
    ``int_from_bytes``.
    """
    digest = hashlib.sha384(name.encode()).digest()
    leading_bytes = digest[:8]
    return int_from_bytes(leading_bytes)
37
38
def calc_label_from_callable(f: Callable) -> int:
    """Compute a label for the callable ``f``.

    * Plain functions are labelled from the hash of
      ``_function_key(f, ignore_name=True)``.
    * Classes are labelled via ``calc_label_from_cls``.
    * Anything else is labelled from its own hash, falling back to the label
      of its type if hashing raises.
    """
    if isinstance(f, FunctionType):
        function_key = _function_key(f, ignore_name=True)
        return calc_label_from_hash(function_key)

    if isinstance(f, type):
        return calc_label_from_cls(f)

    # Anything that reaches this point is probably an instance defining
    # __call__.
    try:
        return calc_label_from_hash(f)
    except Exception:
        # Not hashable, so label it by its class instead.
        return calc_label_from_cls(type(f))
51
52
def calc_label_from_cls(cls: type) -> int:
    """Return the label computed from the qualified name of ``cls``."""
    qualified_name = cls.__qualname__
    return calc_label_from_name(qualified_name)
55
56
def calc_label_from_hash(obj: object) -> int:
    """Return the label computed from the string form of ``hash(obj)``.

    Any exception raised by ``hash(obj)`` (e.g. for an unhashable object)
    propagates to the caller.
    """
    hash_text = str(hash(obj))
    return calc_label_from_name(hash_text)
59
60
def combine_labels(*labels: int) -> int:
    """Fold any number of labels into a single label.

    Starting from zero, for each label in order the running value is shifted
    left by one bit, masked with ``LABEL_MASK``, and then XORed with that
    label. With no labels at all the result is 0.
    """
    combined = 0
    for current in labels:
        combined = ((combined << 1) & LABEL_MASK) ^ current
    return combined
67
68
# Span labels: SAMPLE_IN_SAMPLER_LABEL is passed to data.start_span() by
# Sampler.sample(), and ONE_FROM_MANY_LABEL is the label many.more() uses for
# the span it opens around each element.
SAMPLE_IN_SAMPLER_LABEL = calc_label_from_name("a sample() in Sampler")
ONE_FROM_MANY_LABEL = calc_label_from_name("one more from many()")
71
72
# Generic type variable used in the annotations of identity() and
# check_sample().
T = TypeVar("T")
74
75
def identity(v: T) -> T:
    """Return ``v`` unchanged."""
    return v
78
79
def check_sample(
    values: type[enum.Enum] | Sequence[T], strategy_name: str
) -> Sequence[T]:
    """Validate ``values`` as something we can sample from, and normalise it.

    Args:
        values: The candidate collection. Accepted inputs are numpy arrays
            (only if numpy has already been imported, and only if they are
            one-dimensional), ``OrderedDict`` instances, sequences, and enum
            classes.
        strategy_name: Name of the calling strategy, interpolated into the
            error message for unordered collections.

    Returns:
        ``values`` itself if it is a ``range``, otherwise ``tuple(values)``.

    Raises:
        InvalidArgument: If ``values`` is a numpy array with a number of
            dimensions other than one, or is not one of the accepted ordered
            collection types.
    """
    # Look numpy up in sys.modules rather than importing it, so that it is
    # only consulted when it has already been imported.  Using .get() with an
    # explicit None check also copes with ``sys.modules["numpy"] = None``
    # (the documented way to block an import), which would otherwise blow up
    # with an AttributeError on ``None.ndarray``.
    numpy = sys.modules.get("numpy")
    if numpy is not None and isinstance(values, numpy.ndarray):
        if values.ndim != 1:
            raise InvalidArgument(
                "Only one-dimensional arrays are supported for sampling, "
                f"and the given value has {values.ndim} dimensions (shape "
                f"{values.shape}). This array would give samples of array slices "
                "instead of elements! Use np.ravel(values) to convert "
                "to a one-dimensional array, or tuple(values) if you "
                "want to sample slices."
            )
    elif not isinstance(values, (OrderedDict, abc.Sequence, enum.EnumMeta)):
        raise InvalidArgument(
            f"Cannot sample from {values!r} because it is not an ordered collection. "
            f"Hypothesis goes to some length to ensure that the {strategy_name} "
            "strategy has stable results between runs. To replay a saved "
            "example, the sampled values must have the same iteration order "
            "on every run - ruling out sets, dicts, etc due to hash "
            "randomization. Most cases can simply use `sorted(values)`, but "
            "mixed types or special values such as math.nan require careful "
            "handling - and note that when simplifying an example, "
            "Hypothesis treats earlier values as simpler."
        )
    if isinstance(values, range):
        # Ranges are returned as-is rather than being materialised as a tuple.
        # Pyright is unhappy with every way I've tried to type-annotate this
        # function, so fine, we'll just ignore the analysis error.
        return values  # type: ignore
    return tuple(values)
110
111
@lru_cache(64)
def compute_sampler_table(weights: tuple[float, ...]) -> list[tuple[int, int, float]]:
    """Build the alias-method table for the given ``weights``.

    Returns a sorted list with one ``(base, alternate, alternate_chance)``
    triple per weight, where ``base <= alternate`` in every triple. Rows that
    never use an alternate are emitted as ``(base, base, zero)``.

    Arithmetic is carried out in the numeric type of ``sum(weights)``, and
    results are memoised (up to 64 distinct weight tuples) by ``lru_cache``.
    """
    size = len(weights)
    total = sum(weights)
    num_type = type(total)

    zero = num_type(0)  # type: ignore
    one = num_type(1)  # type: ignore

    # Each weight's probability rescaled by the number of entries, so that an
    # entry with exactly its "fair share" of the mass has the value 1.
    scaled: list[float] = [(weight / total) * size for weight in weights]

    # Mutable working rows of [base, alternate, alternate_chance].
    rows: list[list[int | float | None]] = [
        [index, None, None] for index in range(size)
    ]

    small: list[int] = []
    large: list[int] = []
    for index, value in enumerate(scaled):
        if value == 1:
            rows[index][2] = zero
        elif value < 1:
            small.append(index)
        else:
            large.append(index)
    heapq.heapify(small)
    heapq.heapify(large)

    # Repeatedly pair the lowest under-full index with the lowest over-full
    # one: the over-full entry becomes the alternate that tops up the
    # under-full entry's row, and gives up the corresponding mass.
    while small and large:
        lo = heapq.heappop(small)
        hi = heapq.heappop(large)

        assert lo != hi
        assert scaled[hi] > one
        assert rows[lo][1] is None
        rows[lo][1] = hi
        rows[lo][2] = one - scaled[lo]
        scaled[hi] = (scaled[hi] + scaled[lo]) - one

        remaining = scaled[hi]
        if remaining < 1:
            heapq.heappush(small, hi)
        elif remaining == 1:
            rows[hi][2] = zero
        else:
            heapq.heappush(large, hi)

    # Anything still unpaired on either side never picks an alternate.
    for leftover in large + small:
        rows[leftover][2] = zero

    result: list[tuple[int, int, float]] = []
    for base, alternate, chance in rows:
        assert isinstance(base, int)
        assert isinstance(alternate, int) or alternate is None
        assert chance is not None
        if alternate is None:
            entry = (base, base, chance)
        elif alternate < base:
            # Swap so the smaller index is the base, flipping the chance to
            # keep the row's distribution unchanged.
            entry = (alternate, base, one - chance)
        else:
            entry = (base, alternate, chance)
        result.append(entry)
    result.sort()
    return result
177
178
class Sampler:
    """Sampler based on Vose's algorithm for the alias method. See
    http://www.keithschwarz.com/darts-dice-coins/ for a good explanation.

    We store a table of triples (base, alternate, p). To draw a sample we
    pick one triple uniformly at random, then return its alternate with
    probability p and its base otherwise. The triples are constructed so that
    the resulting mixture has the right distribution.

    Two invariants are maintained to try to produce good shrinks:

    1. The table is in lexicographic (base, alternate) order, so that choosing
       an earlier row always lowers (or at least leaves unchanged) the value.
    2. base[i] <= alternate[i], so that shrinking the draw towards the base
       always results in shrinking the chosen element.
    """

    table: list[tuple[int, int, float]]  # (base_idx, alt_idx, alt_chance)

    def __init__(self, weights: Sequence[float], *, observe: bool = True):
        """Build the sampling table for ``weights``.

        ``observe`` controls whether ``sample`` opens a span and is forwarded
        to the draws it makes.
        """
        self.observe = observe
        self.table = compute_sampler_table(tuple(weights))

    def sample(
        self,
        data: "ConjectureData",
        *,
        forced: int | None = None,
    ) -> int:
        """Draw an index from ``data`` according to the configured weights.

        If ``forced`` is given, both underlying draws are forced so that the
        returned index equals ``forced``. ``next`` raises ``StopIteration``
        if no row of the table can produce that value.
        """
        if self.observe:
            data.start_span(SAMPLE_IN_SAMPLER_LABEL)

        forced_choice = None
        if forced is not None:
            # First row able to produce the forced value, either as its base
            # or as an alternate that has a non-zero chance of being used.
            forced_choice = next(  # pragma: no branch # https://github.com/nedbat/coveragepy/issues/1617
                (row_base, row_alternate, row_chance)
                for (row_base, row_alternate, row_chance) in self.table
                if forced == row_base or (forced == row_alternate and row_chance > 0)
            )

        base, alternate, alternate_chance = data.choice(
            self.table,
            forced=forced_choice,
            observe=self.observe,
        )

        forced_use_alternate = None
        if forced is not None:
            # we maintain this invariant when picking forced_choice above.
            # This song and dance about alternate_chance > 0 is to avoid forcing
            # e.g. draw_boolean(p=0, forced=True), which is an error.
            forced_use_alternate = forced == alternate and alternate_chance > 0
            assert forced == base or forced_use_alternate

        use_alternate = data.draw_boolean(
            alternate_chance,
            forced=forced_use_alternate,
            observe=self.observe,
        )
        if self.observe:
            data.stop_span()

        chosen = alternate if use_alternate else base
        assert forced is None or chosen == forced, (forced, chosen)
        return chosen
246
247
248class many:
249 """Utility class for collections. Bundles up the logic we use for "should I
250 keep drawing more values?" and handles starting and stopping examples in
251 the right place.
252
253 Intended usage is something like:
254
255 elements = many(data, ...)
256 while elements.more():
257 add_stuff_to_result()
258 """
259
260 def __init__(
261 self,
262 data: "ConjectureData",
263 min_size: int,
264 max_size: int | float,
265 average_size: int | float,
266 *,
267 forced: int | None = None,
268 observe: bool = True,
269 ) -> None:
270 assert 0 <= min_size <= average_size <= max_size
271 assert forced is None or min_size <= forced <= max_size
272 self.min_size = min_size
273 self.max_size = max_size
274 self.data = data
275 self.forced_size = forced
276 self.p_continue = _calc_p_continue(average_size - min_size, max_size - min_size)
277 self.count = 0
278 self.rejections = 0
279 self.drawn = False
280 self.force_stop = False
281 self.rejected = False
282 self.observe = observe
283
284 def stop_span(self):
285 if self.observe:
286 self.data.stop_span()
287
288 def start_span(self, label):
289 if self.observe:
290 self.data.start_span(label)
291
292 def more(self) -> bool:
293 """Should I draw another element to add to the collection?"""
294 if self.drawn:
295 self.stop_span()
296
297 self.drawn = True
298 self.rejected = False
299
300 self.start_span(ONE_FROM_MANY_LABEL)
301 if self.min_size == self.max_size:
302 # if we have to hit an exact size, draw unconditionally until that
303 # point, and no further.
304 should_continue = self.count < self.min_size
305 else:
306 forced_result = None
307 if self.force_stop:
308 # if our size is forced, we can't reject in a way that would
309 # cause us to differ from the forced size.
310 assert self.forced_size is None or self.count == self.forced_size
311 forced_result = False
312 elif self.count < self.min_size:
313 forced_result = True
314 elif self.count >= self.max_size:
315 forced_result = False
316 elif self.forced_size is not None:
317 forced_result = self.count < self.forced_size
318 should_continue = self.data.draw_boolean(
319 self.p_continue,
320 forced=forced_result,
321 observe=self.observe,
322 )
323
324 if should_continue:
325 self.count += 1
326 return True
327 else:
328 self.stop_span()
329 return False
330
331 def reject(self, why: str | None = None) -> None:
332 """Reject the last example (i.e. don't count it towards our budget of
333 elements because it's not going to go in the final collection)."""
334 assert self.count > 0
335 self.count -= 1
336 self.rejections += 1
337 self.rejected = True
338 # We set a minimum number of rejections before we give up to avoid
339 # failing too fast when we reject the first draw.
340 if self.rejections > max(3, 2 * self.count):
341 if self.count < self.min_size:
342 self.data.mark_invalid(why)
343 else:
344 self.force_stop = True
345
346
# Lower bound that _calc_p_continue() clamps p_continue to. Taken from
# next_up(0.0), unless that result is falsy, in which case sys.float_info.min
# is used instead.
# NOTE(review): the fallback presumably covers systems that flush subnormals
# to zero (see the comment in _calc_p_continue) - confirm.
SMALLEST_POSITIVE_FLOAT: float = next_up(0.0) or sys.float_info.min
348
349
350@lru_cache
351def _calc_p_continue(desired_avg: float, max_size: int | float) -> float:
352 """Return the p_continue which will generate the desired average size."""
353 assert desired_avg <= max_size, (desired_avg, max_size)
354 if desired_avg == max_size:
355 return 1.0
356 p_continue = 1 - 1.0 / (1 + desired_avg)
357 if p_continue == 0 or max_size == math.inf:
358 assert 0 <= p_continue < 1, p_continue
359 return p_continue
360 assert 0 < p_continue < 1, p_continue
361 # For small max_size, the infinite-series p_continue is a poor approximation,
362 # and while we can't solve the polynomial a few rounds of iteration quickly
363 # gets us a good approximate solution in almost all cases (sometimes exact!).
364 while _p_continue_to_avg(p_continue, max_size) > desired_avg:
365 # This is impossible over the reals, but *can* happen with floats.
366 p_continue -= 0.0001
367 # If we've reached zero or gone negative, we want to break out of this loop,
368 # and do so even if we're on a system with the unsafe denormals-are-zero flag.
369 # We make that an explicit error in st.floats(), but here we'd prefer to
370 # just get somewhat worse precision on collection lengths.
371 if p_continue < SMALLEST_POSITIVE_FLOAT:
372 p_continue = SMALLEST_POSITIVE_FLOAT
373 break
374 # Let's binary-search our way to a better estimate! We tried fancier options
375 # like gradient descent, but this is numerically stable and works better.
376 hi = 1.0
377 while desired_avg - _p_continue_to_avg(p_continue, max_size) > 0.01:
378 assert 0 < p_continue < hi, (p_continue, hi)
379 mid = (p_continue + hi) / 2
380 if _p_continue_to_avg(mid, max_size) <= desired_avg:
381 p_continue = mid
382 else:
383 hi = mid
384 assert 0 < p_continue < 1, p_continue
385 assert _p_continue_to_avg(p_continue, max_size) <= desired_avg
386 return p_continue
387
388
389def _p_continue_to_avg(p_continue: float, max_size: int | float) -> float:
390 """Return the average_size generated by this p_continue and max_size."""
391 if p_continue >= 1:
392 return max_size
393 return (1.0 / (1 - p_continue) - 1) * (1 - p_continue**max_size)