# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import enum
import hashlib
import heapq
import math
import sys
from collections import OrderedDict, abc
from collections.abc import Sequence
from functools import lru_cache
from types import FunctionType
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union

from hypothesis.errors import InvalidArgument
from hypothesis.internal.compat import int_from_bytes
from hypothesis.internal.floats import next_up
from hypothesis.internal.lambda_sources import _function_key

if TYPE_CHECKING:
    from hypothesis.internal.conjecture.data import ConjectureData


LABEL_MASK = 2**64 - 1


def calc_label_from_name(name: str) -> int:
    hashed = hashlib.sha384(name.encode()).digest()
    return int_from_bytes(hashed[:8])


def calc_label_from_callable(f: Callable) -> int:
    if isinstance(f, FunctionType):
        return calc_label_from_hash(_function_key(f, ignore_name=True))
    elif isinstance(f, type):
        return calc_label_from_cls(f)
    else:
        # probably an instance defining __call__
        try:
            return calc_label_from_hash(f)
        except Exception:
            # not hashable
            return calc_label_from_cls(type(f))


def calc_label_from_cls(cls: type) -> int:
    return calc_label_from_name(cls.__qualname__)


def calc_label_from_hash(obj: object) -> int:
    return calc_label_from_name(str(hash(obj)))


def combine_labels(*labels: int) -> int:
    label = 0
    for l in labels:
        label = (label << 1) & LABEL_MASK
        label ^= l
    return label
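
# Note that combine_labels is a shift-and-xor fold: for two labels,
# combine_labels(a, b) == ((a << 1) ^ b) & LABEL_MASK, so the result is
# order-sensitive and always stays within the 64-bit label space.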


SAMPLE_IN_SAMPLER_LABEL = calc_label_from_name("a sample() in Sampler")
ONE_FROM_MANY_LABEL = calc_label_from_name("one more from many()")


T = TypeVar("T")


def identity(v: T) -> T:
    return v


def check_sample(
    values: Union[type[enum.Enum], Sequence[T]], strategy_name: str
) -> Sequence[T]:
    if "numpy" in sys.modules and isinstance(values, sys.modules["numpy"].ndarray):
        if values.ndim != 1:
            raise InvalidArgument(
                "Only one-dimensional arrays are supported for sampling, "
                f"and the given value has {values.ndim} dimensions (shape "
                f"{values.shape}). This array would give samples of array slices "
                "instead of elements! Use np.ravel(values) to convert "
                "to a one-dimensional array, or tuple(values) if you "
                "want to sample slices."
            )
    elif not isinstance(values, (OrderedDict, abc.Sequence, enum.EnumMeta)):
        raise InvalidArgument(
            f"Cannot sample from {values!r} because it is not an ordered collection. "
            f"Hypothesis goes to some length to ensure that the {strategy_name} "
            "strategy has stable results between runs. To replay a saved "
            "example, the sampled values must have the same iteration order "
            "on every run - ruling out sets, dicts, etc due to hash "
            "randomization. Most cases can simply use `sorted(values)`, but "
            "mixed types or special values such as math.nan require careful "
            "handling - and note that when simplifying an example, "
            "Hypothesis treats earlier values as simpler."
        )
    if isinstance(values, range):
        # Pyright is unhappy with every way I've tried to type-annotate this
        # function, so fine, we'll just ignore the analysis error.
        return values  # type: ignore
    return tuple(values)
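
# For example, check_sample(range(10), "sampled_from") returns the range
# unchanged, while check_sample({1, 2, 3}, "sampled_from") raises
# InvalidArgument: set iteration order varies under hash randomization, so
# replaying a saved example could sample different values.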


@lru_cache(64)
def compute_sampler_table(weights: tuple[float, ...]) -> list[tuple[int, int, float]]:
    n = len(weights)
    table: list[list[int | float | None]] = [[i, None, None] for i in range(n)]
    total = sum(weights)
    num_type = type(total)

    zero = num_type(0)  # type: ignore
    one = num_type(1)  # type: ignore

    small: list[int] = []
    large: list[int] = []

    probabilities = [w / total for w in weights]
    scaled_probabilities: list[float] = []

    for i, alternate_chance in enumerate(probabilities):
        scaled = alternate_chance * n
        scaled_probabilities.append(scaled)
        if scaled == 1:
            table[i][2] = zero
        elif scaled < 1:
            small.append(i)
        else:
            large.append(i)
    heapq.heapify(small)
    heapq.heapify(large)

    while small and large:
        lo = heapq.heappop(small)
        hi = heapq.heappop(large)

        assert lo != hi
        assert scaled_probabilities[hi] > one
        assert table[lo][1] is None
        table[lo][1] = hi
        table[lo][2] = one - scaled_probabilities[lo]
        scaled_probabilities[hi] = (
            scaled_probabilities[hi] + scaled_probabilities[lo]
        ) - one

        if scaled_probabilities[hi] < 1:
            heapq.heappush(small, hi)
        elif scaled_probabilities[hi] == 1:
            table[hi][2] = zero
        else:
            heapq.heappush(large, hi)
    while large:
        table[large.pop()][2] = zero
    while small:
        table[small.pop()][2] = zero

    new_table: list[tuple[int, int, float]] = []
    for base, alternate, alternate_chance in table:
        assert isinstance(base, int)
        assert isinstance(alternate, int) or alternate is None
        assert alternate_chance is not None
        if alternate is None:
            new_table.append((base, base, alternate_chance))
        elif alternate < base:
            new_table.append((alternate, base, one - alternate_chance))
        else:
            new_table.append((base, alternate, alternate_chance))
    new_table.sort()
    return new_table
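
# As a worked example: compute_sampler_table((1.0, 2.0, 1.0)) scales the
# probabilities [0.25, 0.5, 0.25] by n=3 to [0.75, 1.5, 0.75] and yields
# [(0, 1, 0.25), (1, 1, 0.0), (1, 2, 0.75)].  Picking a row uniformly and
# taking the alternate with the stated chance recovers the distribution:
# P(0) = 0.75/3 = 0.25, P(1) = (0.25 + 1.0 + 0.25)/3 = 0.5, and
# P(2) = 0.75/3 = 0.25.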


class Sampler:
    """Sampler based on Vose's algorithm for the alias method. See
    http://www.keithschwarz.com/darts-dice-coins/ for a good explanation.

    The general idea is that we store a table of triples (base, alternate, p).
    We then pick a triple uniformly at random, and choose its alternate value
    with probability p and its base value otherwise. The triples are chosen
    so that the resulting mixture has the right distribution.

    We maintain the following invariants to try to produce good shrinks:

    1. The table is in lexicographic (base, alternate) order, so that choosing
       an earlier value in the list always lowers (or at least leaves
       unchanged) the value.
    2. base[i] <= alternate[i], so that shrinking the draw always results in
       shrinking the chosen element.
    """

    table: list[tuple[int, int, float]]  # (base_idx, alt_idx, alt_chance)

    def __init__(self, weights: Sequence[float], *, observe: bool = True):
        self.observe = observe
        self.table = compute_sampler_table(tuple(weights))

    def sample(
        self,
        data: "ConjectureData",
        *,
        forced: Optional[int] = None,
    ) -> int:
        if self.observe:
            data.start_span(SAMPLE_IN_SAMPLER_LABEL)
        forced_choice = (  # pragma: no branch # https://github.com/nedbat/coveragepy/issues/1617
            None
            if forced is None
            else next(
                (base, alternate, alternate_chance)
                for (base, alternate, alternate_chance) in self.table
                if forced == base or (forced == alternate and alternate_chance > 0)
            )
        )
        base, alternate, alternate_chance = data.choice(
            self.table,
            forced=forced_choice,
            observe=self.observe,
        )
        forced_use_alternate = None
        if forced is not None:
            # we maintain this invariant when picking forced_choice above.
            # This song and dance about alternate_chance > 0 is to avoid forcing
            # e.g. draw_boolean(p=0, forced=True), which is an error.
            forced_use_alternate = forced == alternate and alternate_chance > 0
            assert forced == base or forced_use_alternate

        use_alternate = data.draw_boolean(
            alternate_chance,
            forced=forced_use_alternate,
            observe=self.observe,
        )
        if self.observe:
            data.stop_span()
        if use_alternate:
            assert forced is None or alternate == forced, (forced, alternate)
            return alternate
        else:
            assert forced is None or base == forced, (forced, base)
            return base


INT_SIZES = (8, 16, 32, 64, 128)
INT_SIZES_SAMPLER = Sampler((4.0, 8.0, 1.0, 1.0, 0.5), observe=False)
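# With these weights (summing to 14.5), a draw is 16-bit with probability
# 8/14.5 and 8-bit with probability 4/14.5, so sampled sizes are heavily
# biased towards small integers while 128-bit values remain possible.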


class many:
    """Utility class for collections. Bundles up the logic we use for "should I
    keep drawing more values?" and handles starting and stopping examples in
    the right place.

    Intended usage is something like:

    elements = many(data, ...)
    while elements.more():
        add_stuff_to_result()
    """

    def __init__(
        self,
        data: "ConjectureData",
        min_size: int,
        max_size: Union[int, float],
        average_size: Union[int, float],
        *,
        forced: Optional[int] = None,
        observe: bool = True,
    ) -> None:
        assert 0 <= min_size <= average_size <= max_size
        assert forced is None or min_size <= forced <= max_size
        self.min_size = min_size
        self.max_size = max_size
        self.data = data
        self.forced_size = forced
        self.p_continue = _calc_p_continue(average_size - min_size, max_size - min_size)
        self.count = 0
        self.rejections = 0
        self.drawn = False
        self.force_stop = False
        self.rejected = False
        self.observe = observe

    def stop_span(self):
        if self.observe:
            self.data.stop_span()

    def start_span(self, label):
        if self.observe:
            self.data.start_span(label)

    def more(self) -> bool:
        """Should I draw another element to add to the collection?"""
        if self.drawn:
            self.stop_span()

        self.drawn = True
        self.rejected = False

        self.start_span(ONE_FROM_MANY_LABEL)
        if self.min_size == self.max_size:
            # if we have to hit an exact size, draw unconditionally until that
            # point, and no further.
            should_continue = self.count < self.min_size
        else:
            forced_result = None
            if self.force_stop:
                # if our size is forced, we can't reject in a way that would
                # cause us to differ from the forced size.
                assert self.forced_size is None or self.count == self.forced_size
                forced_result = False
            elif self.count < self.min_size:
                forced_result = True
            elif self.count >= self.max_size:
                forced_result = False
            elif self.forced_size is not None:
                forced_result = self.count < self.forced_size
            should_continue = self.data.draw_boolean(
                self.p_continue,
                forced=forced_result,
                observe=self.observe,
            )

        if should_continue:
            self.count += 1
            return True
        else:
            self.stop_span()
            return False

    def reject(self, why: Optional[str] = None) -> None:
        """Reject the last example (i.e. don't count it towards our budget of
        elements because it's not going to go in the final collection)."""
        assert self.count > 0
        self.count -= 1
        self.rejections += 1
        self.rejected = True
        # We set a minimum number of rejections before we give up to avoid
        # failing too fast when we reject the first draw.
        if self.rejections > max(3, 2 * self.count):
            if self.count < self.min_size:
                self.data.mark_invalid(why)
            else:
                self.force_stop = True


SMALLEST_POSITIVE_FLOAT: float = next_up(0.0) or sys.float_info.min


@lru_cache
def _calc_p_continue(desired_avg: float, max_size: Union[int, float]) -> float:
    """Return the p_continue which will generate the desired average size."""
    assert desired_avg <= max_size, (desired_avg, max_size)
    if desired_avg == max_size:
        return 1.0
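    # Initial estimate from the unbounded geometric series: continuing with
    # probability p forever gives an expected p / (1 - p) draws, and solving
    # p / (1 - p) == desired_avg for p gives the estimate below.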
    p_continue = 1 - 1.0 / (1 + desired_avg)
    if p_continue == 0 or max_size == math.inf:
        assert 0 <= p_continue < 1, p_continue
        return p_continue
    assert 0 < p_continue < 1, p_continue
    # For small max_size, the infinite-series p_continue is a poor approximation,
    # and while we can't solve the polynomial, a few rounds of iteration quickly
    # get us a good approximate solution in almost all cases (sometimes exact!).
    while _p_continue_to_avg(p_continue, max_size) > desired_avg:
        # This is impossible over the reals, but *can* happen with floats.
        p_continue -= 0.0001
        # If we've reached zero or gone negative, we want to break out of this loop,
        # and do so even if we're on a system with the unsafe denormals-are-zero flag.
        # We make that an explicit error in st.floats(), but here we'd prefer to
        # just get somewhat worse precision on collection lengths.
        if p_continue < SMALLEST_POSITIVE_FLOAT:
            p_continue = SMALLEST_POSITIVE_FLOAT
            break
    # Let's binary-search our way to a better estimate! We tried fancier options
    # like gradient descent, but this is numerically stable and works better.
    hi = 1.0
    while desired_avg - _p_continue_to_avg(p_continue, max_size) > 0.01:
        assert 0 < p_continue < hi, (p_continue, hi)
        mid = (p_continue + hi) / 2
        if _p_continue_to_avg(mid, max_size) <= desired_avg:
            p_continue = mid
        else:
            hi = mid
    assert 0 < p_continue < 1, p_continue
    assert _p_continue_to_avg(p_continue, max_size) <= desired_avg
    return p_continue
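
# By construction, _calc_p_continue underestimates slightly if anything: the
# binary search exits once the resulting average is within 0.01 below
# desired_avg, e.g. _p_continue_to_avg(_calc_p_continue(5, 10), 10) lands in
# [4.99, 5.0].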


def _p_continue_to_avg(p_continue: float, max_size: Union[int, float]) -> float:
    """Return the average_size generated by this p_continue and max_size."""
    if p_continue >= 1:
        return max_size
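    # Expected size of the geometric process truncated at max_size:
    # sum(p**k for k in range(1, max_size + 1)) == p * (1 - p**max_size) / (1 - p),
    # which is the closed form returned below.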
    return (1.0 / (1 - p_continue) - 1) * (1 - p_continue**max_size)