1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import enum
12import hashlib
13import heapq
14import math
15import sys
16from collections import OrderedDict, abc
17from collections.abc import Callable, Sequence
18from functools import lru_cache
19from types import FunctionType
20from typing import TYPE_CHECKING, TypeVar
21
22from hypothesis.errors import InvalidArgument
23from hypothesis.internal.compat import int_from_bytes
24from hypothesis.internal.floats import next_up
25from hypothesis.internal.lambda_sources import _function_key
26
27if TYPE_CHECKING:
28 from hypothesis.internal.conjecture.data import ConjectureData
29
30
# Labels are kept to 64 bits; combine_labels() masks with this after shifting.
LABEL_MASK = (1 << 64) - 1
32
33
def calc_label_from_name(name: str) -> int:
    """Derive a label from ``name``.

    The label is the integer decoded (via ``int_from_bytes``) from the first
    eight bytes of the SHA-384 digest of ``name``'s default encoding.
    """
    digest = hashlib.sha384(name.encode()).digest()
    leading_bytes = digest[:8]
    return int_from_bytes(leading_bytes)
37
38
def calc_label_from_callable(f: Callable) -> int:
    """Return a label identifying the callable ``f``.

    - Plain functions are labelled by the hash of
      ``_function_key(f, ignore_name=True)`` (presumably a key that ignores the
      function's name; see ``lambda_sources`` to confirm).
    - Classes are labelled by their ``__qualname__``.
    - Any other callable (e.g. an instance defining ``__call__``) is labelled
      by its hash. If computing that label raises, the label of its class is
      used instead.
    """
    if isinstance(f, FunctionType):
        return calc_label_from_hash(_function_key(f, ignore_name=True))
    if isinstance(f, type):
        return calc_label_from_cls(f)
    # Fall back to the class when the object cannot be hashed.
    try:
        label = calc_label_from_hash(f)
    except Exception:
        label = calc_label_from_cls(type(f))
    return label
51
52
def calc_label_from_cls(cls: type) -> int:
    """Label a class by its qualified name (``cls.__qualname__``)."""
    qualified_name = cls.__qualname__
    return calc_label_from_name(qualified_name)
55
56
def calc_label_from_hash(obj: object) -> int:
    """Label ``obj`` by the decimal text of ``hash(obj)``.

    Propagates whatever ``hash(obj)`` raises (``TypeError`` for unhashable
    objects). Note that ``hash()`` of ``str``/``bytes`` is salted per process,
    so labels derived from such hashes may differ between interpreter runs.
    """
    hash_text = str(hash(obj))
    return calc_label_from_name(hash_text)
59
60
def combine_labels(*labels: int) -> int:
    """Fold ``labels`` into a single label.

    For each label in order, the running value is shifted left by one bit,
    truncated with ``LABEL_MASK``, and then XORed with that label. With no
    arguments the result is 0.
    """
    combined = 0
    for part in labels:
        combined = ((combined << 1) & LABEL_MASK) ^ part
    return combined
67
68
# Span labels opened by Sampler.sample() and many.more(), respectively.
SAMPLE_IN_SAMPLER_LABEL = calc_label_from_name(
    "a sample() in Sampler"
)
ONE_FROM_MANY_LABEL = calc_label_from_name(
    "one more from many()"
)
71
72
# Generic element type shared by the helpers in this module.
T = TypeVar("T")


def identity(v: T) -> T:
    """Return the argument itself, unchanged."""
    return v
78
79
def fisher_yates_shuffle(data: "ConjectureData", ls: list[T]) -> None:
    """Permute ``ls`` in place, taking every choice from ``data``.

    This is a reversed Fisher-Yates shuffle. Each position is swapped with
    itself or with some later position, chosen by
    ``data.draw_integer(position, len(ls) - 1)``. Drawing the lower bound
    means "leave this element where it is", so shrinking the draws moves the
    shuffle back towards the original order. The last position is skipped
    because its only possible swap is with itself.
    """
    upper = len(ls) - 1
    for position in range(upper):
        other = data.draw_integer(position, upper)
        ls[position], ls[other] = ls[other], ls[position]
91
92
def check_sample(
    values: type[enum.Enum] | Sequence[T], strategy_name: str
) -> Sequence[T]:
    """Validate ``values`` for a sampling strategy and return it as a sequence.

    Accepted inputs are one-dimensional numpy arrays (checked only if numpy
    is already imported), ``OrderedDict`` instances, ``collections.abc.Sequence``
    instances and ``Enum`` classes, all of which iterate in a stable order.
    A ``range`` is returned as-is. Anything else that is accepted is copied
    into a tuple, so an ``OrderedDict`` yields its keys.

    ``strategy_name`` is only used to build the error message.

    Raises ``InvalidArgument`` for arrays with more than one dimension and
    for unordered collections such as sets and plain dicts.
    """
    is_numpy_array = "numpy" in sys.modules and isinstance(
        values, sys.modules["numpy"].ndarray
    )
    if is_numpy_array:
        if values.ndim != 1:
            raise InvalidArgument(
                "Only one-dimensional arrays are supported for sampling, "
                f"and the given value has {values.ndim} dimensions (shape "
                f"{values.shape}). This array would give samples of array slices "
                "instead of elements! Use np.ravel(values) to convert "
                "to a one-dimensional array, or tuple(values) if you "
                "want to sample slices."
            )
    elif not isinstance(values, (OrderedDict, abc.Sequence, enum.EnumMeta)):
        raise InvalidArgument(
            f"Cannot sample from {values!r} because it is not an ordered collection. "
            f"Hypothesis goes to some length to ensure that the {strategy_name} "
            "strategy has stable results between runs. To replay a saved "
            "example, the sampled values must have the same iteration order "
            "on every run - ruling out sets, dicts, etc due to hash "
            "randomization. Most cases can simply use `sorted(values)`, but "
            "mixed types or special values such as math.nan require careful "
            "handling - and note that when simplifying an example, "
            "Hypothesis treats earlier values as simpler."
        )

    if not isinstance(values, range):
        return tuple(values)
    # A range is already an immutable, ordered sequence, so it is returned
    # as-is. Pyright cannot be convinced by any annotation of this function
    # that this is fine, so the analysis error is ignored.
    return values  # type: ignore
123
124
@lru_cache(64)
def compute_sampler_table(weights: tuple[float, ...]) -> list[tuple[int, int, float]]:
    """Build the alias table used by ``Sampler`` (Vose's alias method).

    ``weights`` are relative: each is divided by their total, so they need
    not sum to one. A non-empty input whose weights sum to zero therefore
    raises ``ZeroDivisionError``; an empty input returns ``[]``.

    Returns one ``(base, alternate, alternate_chance)`` triple per index,
    normalised so that ``base <= alternate`` and sorted lexicographically.
    Triples for slots that never received an alias have
    ``alternate == base`` and an ``alternate_chance`` of zero.

    Results are memoised by ``lru_cache``, so equal ``weights`` return the
    very same list object. Callers must not mutate it.
    """
    n = len(weights)
    # Working table of [base, alternate, alternate_chance] per index. The
    # alternate stays None for slots that never receive an alias.
    table: list[list[int | float | None]] = [[i, None, None] for i in range(n)]
    total = sum(weights)
    # Build 0 and 1 in the numeric type of the total, so non-float weights
    # (presumably e.g. Fraction, for exact arithmetic) keep that type in the
    # constants used below.
    num_type = type(total)

    zero = num_type(0)  # type: ignore
    one = num_type(1)  # type: ignore

    # Indices whose scaled probability is below / above one.
    small: list[int] = []
    large: list[int] = []

    probabilities = [w / total for w in weights]
    scaled_probabilities: list[float] = []

    for i, alternate_chance in enumerate(probabilities):
        # Scale by n so that a slot holding exactly the uniform share (1/n)
        # has scaled probability 1.
        scaled = alternate_chance * n
        scaled_probabilities.append(scaled)
        if scaled == 1:
            # Exactly full: this slot always yields its own base.
            table[i][2] = zero
        elif scaled < 1:
            small.append(i)
        else:
            large.append(i)
    # Using heaps makes the pairing deterministic: the lowest remaining index
    # of each kind is always paired first.
    heapq.heapify(small)
    heapq.heapify(large)

    while small and large:
        lo = heapq.heappop(small)
        hi = heapq.heappop(large)

        assert lo != hi
        assert scaled_probabilities[hi] > one
        assert table[lo][1] is None
        # Fill lo's deficit with mass taken from hi: slot lo yields hi with
        # probability 1 - scaled[lo].
        table[lo][1] = hi
        table[lo][2] = one - scaled_probabilities[lo]
        scaled_probabilities[hi] = (
            scaled_probabilities[hi] + scaled_probabilities[lo]
        ) - one

        # Re-file hi according to the mass it has left.
        if scaled_probabilities[hi] < 1:
            heapq.heappush(small, hi)
        elif scaled_probabilities[hi] == 1:
            table[hi][2] = zero
        else:
            heapq.heappush(large, hi)
    # In exact arithmetic every leftover would have scaled probability
    # exactly 1. Leftovers come from rounding, so treat them as full slots
    # that never use an alternate.
    while large:
        table[large.pop()][2] = zero
    while small:
        table[small.pop()][2] = zero

    new_table: list[tuple[int, int, float]] = []
    for base, alternate, alternate_chance in table:
        assert isinstance(base, int)
        assert isinstance(alternate, int) or alternate is None
        assert alternate_chance is not None
        if alternate is None:
            # No alias: only the base can ever be chosen.
            new_table.append((base, base, alternate_chance))
        elif alternate < base:
            # Swap so that base < alternate. The chance of picking the (new)
            # alternate is the complement of the old one.
            new_table.append((alternate, base, one - alternate_chance))
        else:
            new_table.append((base, alternate, alternate_chance))
    # Lexicographic (base, alternate) order, as the shrinking invariants
    # documented on Sampler require.
    new_table.sort()
    return new_table
190
191
class Sampler:
    """Sampler based on Vose's algorithm for the alias method. See
    http://www.keithschwarz.com/darts-dice-coins/ for a good explanation.

    The general idea is that we store a table of triples (base, alternate, p).
    We then pick a triple at random, and choose its alternate value with
    probability p and else choose its base value. The triples are chosen so
    that the resulting mixture has the right distribution.

    We maintain the following invariants to try to produce good shrinks:

    1. The table is in lexicographic (base, alternate) order, so that choosing
       an earlier value in the list always lowers (or at least leaves
       unchanged) the value.
    2. base[i] <= alternate[i] (equality only for triples with no alias,
       whose p is zero), so that shrinking the draw always results in
       shrinking (or keeping) the chosen element.
    """

    # Alias table from compute_sampler_table(): one triple per index.
    table: list[tuple[int, int, float]]  # (base_idx, alt_idx, alt_chance)

    def __init__(self, weights: Sequence[float], *, observe: bool = True):
        # When observe is False, sample() opens no span and passes
        # observe=False through to the data draws.
        self.observe = observe
        # compute_sampler_table is lru_cached, so this list may be shared with
        # other Samplers built from equal weights; it must not be mutated.
        self.table = compute_sampler_table(tuple(weights))

    def sample(
        self,
        data: "ConjectureData",
        *,
        forced: int | None = None,
    ) -> int:
        """Draw an index according to the weights, using ``data``.

        This makes two draws: ``data.choice`` over the table's triples, then
        ``data.draw_boolean(alternate_chance)`` to decide between that
        triple's base and alternate.

        :param forced: if given, both draws are forced so that this index is
            returned.
        :return: the chosen index (a base or alternate from the table).
        """
        if self.observe:
            data.start_span(SAMPLE_IN_SAMPLER_LABEL)
        # When forcing, pick the first triple that can produce `forced`:
        # either as its base, or as its alternate when that alternate has a
        # non-zero chance.
        # NOTE(review): next() has no default, so a `forced` value that no
        # triple can produce raises StopIteration here, not a clear error.
        forced_choice = (  # pragma: no branch # https://github.com/nedbat/coveragepy/issues/1617
            None
            if forced is None
            else next(
                (base, alternate, alternate_chance)
                for (base, alternate, alternate_chance) in self.table
                if forced == base or (forced == alternate and alternate_chance > 0)
            )
        )
        base, alternate, alternate_chance = data.choice(
            self.table,
            forced=forced_choice,
            observe=self.observe,
        )
        forced_use_alternate = None
        if forced is not None:
            # we maintain this invariant when picking forced_choice above.
            # This song and dance about alternate_chance > 0 is to avoid forcing
            # e.g. draw_boolean(p=0, forced=True), which is an error.
            forced_use_alternate = forced == alternate and alternate_chance > 0
            assert forced == base or forced_use_alternate

        use_alternate = data.draw_boolean(
            alternate_chance,
            forced=forced_use_alternate,
            observe=self.observe,
        )
        if self.observe:
            data.stop_span()
        # The assertions check that a forced value was actually honoured.
        if use_alternate:
            assert forced is None or alternate == forced, (forced, alternate)
            return alternate
        else:
            assert forced is None or base == forced, (forced, base)
            return base
259
260
class many:
    """Helper for drawing variable-length collections.

    It decides, one element at a time, whether the collection should keep
    growing. It also wraps each element's draw in a ``ONE_FROM_MANY_LABEL``
    span when ``observe`` is true.

    Intended usage is something like::

        elements = many(data, ...)
        while elements.more():
            add_stuff_to_result()
    """

    def __init__(
        self,
        data: "ConjectureData",
        min_size: int,
        max_size: int | float,
        average_size: int | float,
        *,
        forced: int | None = None,
        observe: bool = True,
    ) -> None:
        assert 0 <= min_size <= average_size <= max_size
        assert forced is None or min_size <= forced <= max_size
        self.data = data
        self.min_size = min_size
        self.max_size = max_size
        self.forced_size = forced
        self.observe = observe
        # Probability of drawing one more element beyond min_size, chosen so
        # the expected number of extra elements approximates
        # average_size - min_size, capped at max_size - min_size.
        self.p_continue = _calc_p_continue(average_size - min_size, max_size - min_size)
        # Bookkeeping shared by more() and reject().
        self.count = 0
        self.rejections = 0
        self.drawn = False
        self.force_stop = False
        self.rejected = False

    def stop_span(self):
        """Close the current span on ``data``, unless observation is off."""
        if not self.observe:
            return
        self.data.stop_span()

    def start_span(self, label):
        """Open a span with ``label`` on ``data``, unless observation is off."""
        if not self.observe:
            return
        self.data.start_span(label)

    def more(self) -> bool:
        """Should I draw another element to add to the collection?

        Closes the span opened for the previous element, if any, and opens a
        new one. If the answer is True the element is counted and its span
        stays open; otherwise the new span is closed immediately.
        """
        if self.drawn:
            self.stop_span()
        self.drawn = True
        self.rejected = False

        self.start_span(ONE_FROM_MANY_LABEL)
        if self.min_size == self.max_size:
            # Fixed size: draw unconditionally until that size, then stop.
            should_continue = self.count < self.min_size
        else:
            should_continue = self.data.draw_boolean(
                self.p_continue,
                forced=self._forced_continue(),
                observe=self.observe,
            )

        if not should_continue:
            self.stop_span()
            return False
        self.count += 1
        return True

    def _forced_continue(self) -> bool | None:
        """Return the forced value for the next continue-draw, or None if free."""
        if self.force_stop:
            # A forced size must never be contradicted by a rejection-driven stop.
            assert self.forced_size is None or self.count == self.forced_size
            return False
        if self.count < self.min_size:
            return True
        if self.count >= self.max_size:
            return False
        if self.forced_size is not None:
            return self.count < self.forced_size
        return None

    def reject(self, why: str | None = None) -> None:
        """Reject the last example (i.e. don't count it towards our budget of
        elements because it's not going to go in the final collection).

        After too many rejections, the draw is marked invalid if min_size has
        not been reached; otherwise the collection is forced to stop growing.
        """
        assert self.count > 0
        self.count -= 1
        self.rejections += 1
        self.rejected = True
        # Allow a few rejections before giving up, so that rejecting the very
        # first draw doesn't immediately end things.
        if self.rejections <= max(3, 2 * self.count):
            return
        if self.count < self.min_size:
            self.data.mark_invalid(why)
        else:
            self.force_stop = True
358
359
# Smallest positive float as seen by next_up(0.0). If that yields 0.0 (e.g.
# under a denormals-are-zero FPU mode), fall back to the smallest normal float.
SMALLEST_POSITIVE_FLOAT: float = (
    next_up(0.0) or sys.float_info.min
)
361
362
@lru_cache
def _calc_p_continue(desired_avg: float, max_size: int | float) -> float:
    """Return the p_continue which will generate the desired average size.

    ``p_continue`` is the per-step probability of drawing one more element.
    ``_p_continue_to_avg(p_continue, max_size)`` gives the resulting expected
    number of elements. ``many`` passes both arguments relative to its
    min_size.

    For a finite ``max_size`` (and ``desired_avg < max_size``), the result
    satisfies ``_p_continue_to_avg(result, max_size) <= desired_avg``, within
    roughly 0.01 below it. Results are memoised by ``lru_cache``.
    """
    assert desired_avg <= max_size, (desired_avg, max_size)
    # Must always reach max_size: always continue.
    if desired_avg == max_size:
        return 1.0
    # Closed form for an uncapped run: a geometric run with continuation
    # probability p has mean p / (1 - p), and this p makes that equal
    # desired_avg.
    p_continue = 1 - 1.0 / (1 + desired_avg)
    # With no cap (or p == 0, e.g. desired_avg == 0) no truncation
    # correction is needed, so the closed form is used directly.
    if p_continue == 0 or max_size == math.inf:
        assert 0 <= p_continue < 1, p_continue
        return p_continue
    assert 0 < p_continue < 1, p_continue
    # For small max_size, the infinite-series p_continue is a poor approximation,
    # and while we can't solve the polynomial a few rounds of iteration quickly
    # gets us a good approximate solution in almost all cases (sometimes exact!).
    while _p_continue_to_avg(p_continue, max_size) > desired_avg:
        # This is impossible over the reals, but *can* happen with floats.
        p_continue -= 0.0001
        # If we've reached zero or gone negative, we want to break out of this loop,
        # and do so even if we're on a system with the unsafe denormals-are-zero flag.
        # We make that an explicit error in st.floats(), but here we'd prefer to
        # just get somewhat worse precision on collection lengths.
        if p_continue < SMALLEST_POSITIVE_FLOAT:
            p_continue = SMALLEST_POSITIVE_FLOAT
            break
    # Let's binary-search our way to a better estimate! We tried fancier options
    # like gradient descent, but this is numerically stable and works better.
    # Invariant: avg(p_continue) <= desired_avg < avg(hi), since avg(1.0) is
    # max_size, which exceeds desired_avg here.
    hi = 1.0
    while desired_avg - _p_continue_to_avg(p_continue, max_size) > 0.01:
        assert 0 < p_continue < hi, (p_continue, hi)
        mid = (p_continue + hi) / 2
        if _p_continue_to_avg(mid, max_size) <= desired_avg:
            p_continue = mid
        else:
            hi = mid
    assert 0 < p_continue < 1, p_continue
    assert _p_continue_to_avg(p_continue, max_size) <= desired_avg
    return p_continue
400
401
402def _p_continue_to_avg(p_continue: float, max_size: int | float) -> float:
403 """Return the average_size generated by this p_continue and max_size."""
404 if p_continue >= 1:
405 return max_size
406 return (1.0 / (1 - p_continue) - 1) * (1 - p_continue**max_size)