1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import os
12import sys
13import threading
14import warnings
15from collections import abc, defaultdict
16from collections.abc import Callable, Sequence
17from functools import lru_cache
18from random import shuffle
19from threading import RLock
20from typing import (
21 TYPE_CHECKING,
22 Any,
23 ClassVar,
24 Generic,
25 Literal,
26 TypeAlias,
27 TypeVar,
28 cast,
29 overload,
30)
31
32from hypothesis._settings import HealthCheck, Phase, Verbosity, settings
33from hypothesis.control import _current_build_context, current_build_context
34from hypothesis.errors import (
35 HypothesisException,
36 HypothesisWarning,
37 InvalidArgument,
38 NonInteractiveExampleWarning,
39 UnsatisfiedAssumption,
40)
41from hypothesis.internal.conjecture import utils as cu
42from hypothesis.internal.conjecture.data import ConjectureData
43from hypothesis.internal.conjecture.utils import (
44 calc_label_from_cls,
45 calc_label_from_hash,
46 calc_label_from_name,
47 combine_labels,
48)
49from hypothesis.internal.coverage import check_function
50from hypothesis.internal.reflection import (
51 get_pretty_function_description,
52 is_identity_function,
53)
54from hypothesis.strategies._internal.utils import defines_strategy
55from hypothesis.utils.conventions import UniqueIdentifier
56
# ``Ex`` is the covariant type variable for the values a strategy produces
# (see ``SearchStrategy(Generic[Ex])`` and ``do_draw`` below). ``default=Any``
# is only supplied on the type-checking branch - presumably because the
# runtime ``TypeVar`` does not accept ``default`` on every supported Python
# version (TODO confirm).
if TYPE_CHECKING:
    Ex = TypeVar("Ex", covariant=True, default=Any)
else:
    Ex = TypeVar("Ex", covariant=True)

# Extra type variables. ``T`` is used by ``map``/``flatmap``/``__or__`` and,
# together with ``T3``/``T4``/``T5``, by the ``one_of`` overloads.
# ``MappedFrom``/``MappedTo`` are not referenced in this part of the file;
# presumably they belong to MappedStrategy (TODO confirm).
T = TypeVar("T")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
MappedFrom = TypeVar("MappedFrom")
MappedTo = TypeVar("MappedTo")
# Signature of the ``recur`` callback handed to the ``calc_*`` methods.
RecurT: TypeAlias = Callable[["SearchStrategy"], bool]
# Marker stored while a value is still being computed; used by both
# ``recursive_property`` and ``SearchStrategy.label`` to detect re-entry.
calculating = UniqueIdentifier("calculating")

# Labels derived from fixed descriptive names. Their uses lie outside this
# part of the file; going by the names they mark retry loops in
# MappedStrategy and FilteredStrategy (TODO confirm).
MAPPED_SEARCH_STRATEGY_DO_DRAW_LABEL = calc_label_from_name(
    "another attempted draw in MappedStrategy"
)

FILTERED_SEARCH_STRATEGY_DO_DRAW_LABEL = calc_label_from_name(
    "single loop iteration in FilteredStrategy"
)

# Module-wide re-entrant lock held while ``SearchStrategy.label`` computes a
# label.
label_lock = RLock()
80
81
def recursive_property(strategy: "SearchStrategy", name: str, default: object) -> Any:
    """Handle properties which may be mutually recursive among a set of
    strategies.

    These are essentially lazily cached properties, with the ability to set
    an override: If the property has not been explicitly set, we calculate
    it on first access and memoize the result for later.

    The problem is that for properties that depend on each other, a naive
    calculation strategy may hit infinite recursion. Consider for example
    the property is_empty. A strategy defined as x = st.deferred(lambda: x)
    is certainly empty (in order to draw a value from x we would have to
    draw a value from x, for which we would have to draw a value from x,
    ...), but in order to calculate it the naive approach would end up
    calling x.is_empty in order to calculate x.is_empty in order to etc.

    The solution is one of fixed point calculation. We start with a default
    value that is the value of the property in the absence of evidence to
    the contrary, and then update the values of the property for all
    dependent strategies until we reach a fixed point.

    The approach taken roughly follows that in section 4.2 of Adams,
    Michael D., Celeste Hollenbeck, and Matthew Might. "On the complexity
    and performance of parsing with derivatives." ACM SIGPLAN Notices 51.6
    (2016): 224-236.

    ``strategy`` is the strategy whose property is requested. ``name`` must
    be one of ``"is_empty"``, ``"has_reusable_values"`` or ``"is_cacheable"``
    and selects the ``force_<name>``, ``cached_<name>`` and ``calc_<name>``
    attributes that are consulted. ``default`` is the value assumed for a
    strategy whose calculation re-enters itself, and the starting value for
    any strategy first encountered during the fixed point pass.

    Returns the forced value if one is set, else the cached value if one is
    set, else the freshly calculated value. In the last case the result is
    stored as ``cached_<name>`` on every strategy visited along the way.
    """
    assert name in {"is_empty", "has_reusable_values", "is_cacheable"}
    cache_key = "cached_" + name
    calculation = "calc_" + name
    force_key = "force_" + name

    # An explicit override (force_*) wins over a memoized value (cached_*).
    # If neither attribute exists the AttributeError from the second lookup
    # propagates, which callers below use as the "not yet known" signal.
    def forced_value(target: SearchStrategy) -> Any:
        try:
            return getattr(target, force_key)
        except AttributeError:
            return getattr(target, cache_key)

    try:
        return forced_value(strategy)
    except AttributeError:
        pass

    # Working values for every strategy visited during this calculation.
    mapping: dict[SearchStrategy, Any] = {}
    # Distinguishes "not in mapping" from any legitimately stored value.
    sentinel = object()
    hit_recursion = False

    # For a first pass we do a direct recursive calculation of the
    # property, but we block recursively visiting a value in the
    # computation of its property: When that happens, we simply
    # note that it happened and return the default value.
    def recur(strat: SearchStrategy) -> Any:
        nonlocal hit_recursion
        try:
            return forced_value(strat)
        except AttributeError:
            pass
        result = mapping.get(strat, sentinel)
        if result is calculating:
            hit_recursion = True
            return default
        elif result is sentinel:
            # Mark as in-progress before descending so that re-entry is
            # detected by the branch above.
            mapping[strat] = calculating
            mapping[strat] = getattr(strat, calculation)(recur)
            return mapping[strat]
        return result

    recur(strategy)

    # If we hit self-recursion in the computation of any strategy
    # value, our mapping at the end is imprecise - it may or may
    # not have the right values in it. We now need to proceed with
    # a more careful fixed point calculation to get the exact
    # values. Hopefully our mapping is still pretty good and it
    # won't take a large number of updates to reach a fixed point.
    if hit_recursion:
        needs_update = set(mapping)

        # We track which strategies use which in the course of
        # calculating their property value. If A ever uses B in
        # the course of calculating its value, then whenever the
        # value of B changes we might need to update the value of
        # A.
        listeners: dict[SearchStrategy, set[SearchStrategy]] = defaultdict(set)
    else:
        # Nothing to refine: the loop below is skipped, so ``listeners`` is
        # never touched on this path.
        needs_update = None

    # Builds the ``recur`` callback used while recalculating ``strat``: it
    # records ``strat`` as a listener of everything it asks about.
    def recur2(strat: SearchStrategy) -> Any:
        def recur_inner(other: SearchStrategy) -> Any:
            try:
                return forced_value(other)
            except AttributeError:
                pass
            listeners[other].add(strat)
            result = mapping.get(other, sentinel)
            if result is sentinel:
                # First time we see ``other``: start it at the default and
                # schedule it for calculation in a later round.
                assert needs_update is not None
                needs_update.add(other)
                mapping[other] = default
                return default
            return result

        return recur_inner

    count = 0
    seen = set()
    while needs_update:
        count += 1
        # If we seem to be taking a really long time to stabilize we
        # start tracking seen values to attempt to detect an infinite
        # loop. This should be impossible, and most code will never
        # hit the count, but having an assertion for it means that
        # testing is easier to debug and we don't just have a hung
        # test.
        # Note: This is actually covered, by test_very_deep_deferral
        # in tests/cover/test_deferred_strategies.py. Unfortunately it
        # runs into a coverage bug. See
        # https://github.com/nedbat/coveragepy/issues/605
        # for details.
        if count > 50:  # pragma: no cover
            key = frozenset(mapping.items())
            assert key not in seen, (key, name)
            seen.add(key)
        to_update = needs_update
        needs_update = set()
        for strat in to_update:
            new_value = getattr(strat, calculation)(recur2(strat))
            if new_value != mapping[strat]:
                # The value changed, so everything that read it must be
                # recalculated in the next round.
                needs_update.update(listeners[strat])
                mapping[strat] = new_value

    # We now have a complete and accurate calculation of the
    # property values for everything we have seen in the course of
    # running this calculation. We simultaneously update all of
    # them (not just the strategy we started out with).
    for k, v in mapping.items():
        setattr(k, cache_key, v)
    return getattr(strategy, cache_key)
219
220
class SearchStrategy(Generic[Ex]):
    """A ``SearchStrategy`` tells Hypothesis how to generate that kind of input.

    This class is only part of the public API for use in type annotations, so that
    you can write e.g. ``-> SearchStrategy[Foo]`` for your function which returns
    ``builds(Foo, ...)``. Do not inherit from or directly instantiate this class.
    """

    __module__: str = "hypothesis.strategies"
    # Memo of class -> label, shared by every strategy class (see class_label).
    LABELS: ClassVar[dict[type, int]] = {}
    # triggers `assert isinstance(label, int)` under threading when setting this
    # in init instead of a classvar. I'm not sure why, init should be safe. But
    # this works so I'm not looking into it further atm.
    __label: int | UniqueIdentifier | None = None

    def __init__(self):
        # Maps thread ident -> whether validate() has run (or is running) on
        # that thread for this strategy.
        self.validate_called: dict[int, bool] = {}

    def is_currently_empty(self, data: ConjectureData) -> bool:
        """
        Returns whether this strategy is currently empty. Unlike ``empty``,
        which is computed based on static information and cannot change,
        ``is_currently_empty`` may change over time based on choices made
        during the test case.

        This is currently only used for stateful testing, where |Bundle| grows a
        list of values to choose from over the course of a test case.

        ``data`` will only be used for introspection. No values will be drawn
        from it in a way that modifies the choice sequence.
        """
        return self.is_empty

    @property
    def is_empty(self) -> Any:
        # Returns True if this strategy can never draw a value and will always
        # result in the data being marked invalid.
        # The fact that this returns False does not guarantee that a valid value
        # can be drawn - this is not intended to be perfect, and is primarily
        # intended to be an optimisation for some cases.
        return recursive_property(self, "is_empty", True)

    # Returns True if values from this strategy can safely be reused without
    # this causing unexpected behaviour.

    # True if values from this strategy can be implicitly reused (e.g. as
    # background values in a numpy array) without causing surprising
    # user-visible behaviour. Should be false for built-in strategies that
    # produce mutable values, and for strategies that have been mapped/filtered
    # by arbitrary user-provided functions.
    @property
    def has_reusable_values(self) -> Any:
        return recursive_property(self, "has_reusable_values", True)

    @property
    def is_cacheable(self) -> Any:
        """
        Whether it is safe to hold on to instances of this strategy in a cache.
        See _STRATEGY_CACHE.
        """
        return recursive_property(self, "is_cacheable", True)

    def calc_is_cacheable(self, recur: RecurT) -> bool:
        # Calculation hook for ``is_cacheable``. ``recur`` yields the value for
        # other strategies; the base implementation does not need it.
        return True

    def calc_is_empty(self, recur: RecurT) -> bool:
        # Note: It is correct and significant that the default return value
        # from calc_is_empty is False despite the default value for is_empty
        # being true. The reason for this is that strategies should be treated
        # as empty absent evidence to the contrary, but most basic strategies
        # are trivially non-empty and it would be annoying to have to override
        # this method to show that.
        return False

    def calc_has_reusable_values(self, recur: RecurT) -> bool:
        # Calculation hook for ``has_reusable_values``; the base implementation
        # reports that values are not reusable.
        return False

    def example(self) -> Ex:  # FIXME
        """Provide an example of the sort of value that this strategy generates.

        This method is designed for use in a REPL, and will raise an error if
        called from inside |@given| or a strategy definition. For serious use,
        see |@composite| or |st.data|.
        """
        if getattr(sys, "ps1", None) is None and (
            # The main module's __spec__ is None when running interactively
            # or running a source file directly.
            # See https://docs.python.org/3/reference/import.html#main-spec.
            sys.modules["__main__"].__spec__ is not None
            # __spec__ is also None under pytest-xdist. To avoid an unfortunate
            # missed alarm here, always warn under pytest.
            or os.environ.get("PYTEST_CURRENT_TEST") is not None
        ):  # pragma: no branch
            # The other branch *is* covered in cover/test_interactive_example.py;
            # but as that uses `pexpect` for an interactive session `coverage`
            # doesn't see it.
            warnings.warn(
                "The `.example()` method is good for exploring strategies, but should "
                "only be used interactively.  We recommend using `@given` for tests - "
                "it performs better, saves and replays failures to avoid flakiness, "
                f"and reports minimal examples. (strategy: {self!r})",
                NonInteractiveExampleWarning,
                stacklevel=2,
            )

        context = _current_build_context.value
        if context is not None:
            if context.data is not None and context.data.depth > 0:
                raise HypothesisException(
                    "Using example() inside a strategy definition is a bad "
                    "idea. Instead consider using hypothesis.strategies.builds() "
                    "or @hypothesis.strategies.composite to define your strategy."
                    " See https://hypothesis.readthedocs.io/en/latest/reference/"
                    "strategies.html#hypothesis.strategies.builds or "
                    "https://hypothesis.readthedocs.io/en/latest/reference/"
                    "strategies.html#hypothesis.strategies.composite for more "
                    "details."
                )
            else:
                raise HypothesisException(
                    "Using example() inside a test function is a bad "
                    "idea. Instead consider using hypothesis.strategies.data() "
                    "to draw more examples during testing. See "
                    "https://hypothesis.readthedocs.io/en/latest/reference/"
                    "strategies.html#hypothesis.strategies.data for more details."
                )

        # Serve from the batch generated by a previous call if any remain;
        # otherwise (no batch yet, or batch exhausted) start a fresh one.
        try:
            return self.__examples.pop()
        except (AttributeError, IndexError):
            self.__examples: list[Ex] = []

        from hypothesis.core import given

        # Note: this function has a weird name because it might appear in
        # tracebacks, and we want users to know that they can ignore it.
        @given(self)
        @settings(
            database=None,
            # generate only a few examples at a time to avoid slow interactivity
            # for large strategies. The overhead of @given is very small relative
            # to generation, so a small batch size is fine.
            max_examples=10,
            deadline=None,
            verbosity=Verbosity.quiet,
            phases=(Phase.generate,),
            suppress_health_check=list(HealthCheck),
        )
        def example_generating_inner_function(
            ex: Ex,  # type: ignore # mypy is overzealous in preventing covariant params
        ) -> None:
            self.__examples.append(ex)

        example_generating_inner_function()
        shuffle(self.__examples)
        return self.__examples.pop()

    def map(self, pack: Callable[[Ex], T]) -> "SearchStrategy[T]":
        """Returns a new strategy which generates a value from this one, and
        then returns ``pack(value)``.  For example, ``integers().map(str)``
        could generate ``str(5)`` == ``"5"``.
        """
        if is_identity_function(pack):
            return self  # type: ignore  # Mypy has no way to know that `Ex == T`
        return MappedStrategy(self, pack=pack)

    def flatmap(
        self, expand: Callable[[Ex], "SearchStrategy[T]"]
    ) -> "SearchStrategy[T]":  # FIXME
        """Old syntax for a special case of |@composite|:

        .. code-block:: python

            @st.composite
            def flatmap_like(draw, base_strategy, expand):
                value = draw(base_strategy)
                new_strategy = expand(value)
                return draw(new_strategy)

        We find that the greater readability of |@composite| usually outweighs
        the verbosity, with a few exceptions for simple cases or recipes like
        ``from_type(type).flatmap(from_type)`` ("pick a type, get a strategy for
        any instance of that type, and then generate one of those").
        """
        from hypothesis.strategies._internal.flatmapped import FlatMapStrategy

        return FlatMapStrategy(self, expand=expand)

    # Note that we previously had condition extracted to a type alias as
    # PredicateT. However, that was only useful when not specifying a relationship
    # between the generic Ts and some other function param / return value.
    # If we do want to - like here, where we want to say that the Ex arg to condition
    # is of the same type as the strategy's Ex - then you need to write out the
    # entire Callable[[Ex], Any] expression rather than use a type alias.
    # TypeAlias is *not* simply a macro that inserts the text. TypeAlias will not
    # reference the local TypeVar context.
    def filter(self, condition: Callable[[Ex], Any]) -> "SearchStrategy[Ex]":
        """Returns a new strategy that generates values from this strategy
        which satisfy the provided condition.

        Note that if the condition is too hard to satisfy this might result
        in your tests failing with an Unsatisfiable exception.
        A basic version of the filtering logic would look something like:

        .. code-block:: python

            @st.composite
            def filter_like(draw, strategy, condition):
                for _ in range(3):
                    value = draw(strategy)
                    if condition(value):
                        return value
                assume(False)
        """
        return FilteredStrategy(self, conditions=(condition,))

    @property
    def branches(self) -> Sequence["SearchStrategy[Ex]"]:
        # The strategies a union of this strategy should be flattened into.
        # A plain strategy is its own single branch.
        return [self]

    def __or__(self, other: "SearchStrategy[T]") -> "SearchStrategy[Ex | T]":
        """Return a strategy which produces values by randomly drawing from one
        of this strategy or the other strategy.

        This method is part of the public API.
        """
        if not isinstance(other, SearchStrategy):
            raise ValueError(f"Cannot | a SearchStrategy with {other!r}")

        # Unwrap explicitly or'd strategies. This turns the
        # common case of e.g. st.integers() | st.integers() | st.integers() from
        #
        #   one_of(one_of(integers(), integers()), integers())
        #
        # into
        #
        #   one_of(integers(), integers(), integers())
        #
        # This is purely an aesthetic unwrapping, for e.g. reprs. In practice
        # we use .branches / .element_strategies to get the list of possible
        # strategies, so this unwrapping is *not* necessary for correctness.
        strategies: list[SearchStrategy] = []
        strategies.extend(
            self.original_strategies if isinstance(self, OneOfStrategy) else [self]
        )
        strategies.extend(
            other.original_strategies if isinstance(other, OneOfStrategy) else [other]
        )
        return OneOfStrategy(strategies)

    def __bool__(self) -> bool:
        # A strategy is always truthy; warn because testing its truth value is
        # most likely a mistake.
        warnings.warn(
            f"bool({self!r}) is always True, did you mean to draw a value?",
            HypothesisWarning,
            stacklevel=2,
        )
        return True

    def validate(self) -> None:
        """Throw an exception if the strategy is not valid.

        Strategies should implement ``do_validate``, which is called by this
        method. They should not override ``validate``.

        This can happen due to invalid arguments, or lazy construction.
        """
        thread_id = threading.get_ident()
        if self.validate_called.get(thread_id, False):
            return
        # we need to set validate_called before calling do_validate, for
        # recursive / deferred strategies. But if a thread switches after
        # validate_called but before do_validate, we might have a strategy
        # which does weird things like drawing when do_validate would error but
        # its params are technically valid (e.g. a param was passed as 1.0
        # instead of 1) and get into weird internal states.
        #
        # There are two ways to fix this.
        # (1) The first is a per-strategy lock around do_validate. Even though we
        #   expect near-zero lock contention, this still adds the lock overhead.
        # (2) The second is allowing concurrent .validate calls. Since validation
        #   is (assumed to be) deterministic, both threads will produce the same
        #   end state, so the validation order or race conditions does not matter.
        #
        # In order to avoid the lock overhead of (1), we use (2) here. See also
        # discussion in https://github.com/HypothesisWorks/hypothesis/pull/4473.
        try:
            self.validate_called[thread_id] = True
            self.do_validate()
            # Accessed for their side effect: computing and caching the
            # recursive properties as part of validation.
            self.is_empty
            self.has_reusable_values
        except Exception:
            # Allow validation to be retried (and to fail again) next time.
            self.validate_called[thread_id] = False
            raise

    @property
    def class_label(self) -> int:
        # Label derived from the strategy's class alone, memoised in LABELS.
        cls = self.__class__
        try:
            return cls.LABELS[cls]
        except KeyError:
            pass
        result = calc_label_from_cls(cls)
        cls.LABELS[cls] = result
        return result

    @property
    def label(self) -> int:
        # Lazily computed label for this strategy, cached after the first
        # successful calculation. A re-entrant access made while the label is
        # still being calculated (e.g. by a self-referential strategy) gets 0.
        if isinstance((label := self.__label), int):
            # avoid locking if we've already completely computed the label.
            return label

        with label_lock:
            label = self.__label
            if isinstance(label, int):
                # another thread finished the calculation while we were
                # waiting for the lock; reuse its result rather than
                # recomputing.
                return label
            if label is calculating:
                # label_lock is re-entrant, so reaching this point means the
                # current thread is already inside calc_label for this
                # strategy.
                return 0
            self.__label = calculating
            try:
                label = self.calc_label()
            except BaseException:
                # don't leave the in-progress marker behind: otherwise every
                # later access would be mistaken for re-entry and silently
                # return 0 instead of retrying (and re-raising).
                self.__label = None
                raise
            self.__label = label
            return label

    def calc_label(self) -> int:
        # Hook computing the value cached by ``label``; defaults to the
        # per-class label.
        return self.class_label

    def do_validate(self) -> None:
        # Hook called by ``validate``; the base implementation accepts
        # everything.
        pass

    def do_draw(self, data: ConjectureData) -> Ex:
        # Must be overridden by concrete strategies.
        raise NotImplementedError(f"{type(self).__name__}.do_draw")
547
548
549def _is_hashable(value: object) -> tuple[bool, int | None]:
550 # hashing can be expensive; return the hash value if we compute it, so that
551 # callers don't have to recompute.
552 try:
553 return (True, hash(value))
554 except TypeError:
555 return (False, None)
556
557
def is_hashable(value: object) -> bool:
    """Return True if ``value`` can be hashed, discarding the hash itself."""
    hashable, _ = _is_hashable(value)
    return hashable
560
561
class SampledFromStrategy(SearchStrategy[Ex]):
    """Strategy that picks one value out of a fixed collection of elements.

    Roughly what a OneOfStrategy over Just strategies would give, but
    potentially more efficient and more convenient to use.
    """

    # Upper bound on how many elements a single filtered draw will examine.
    _MAX_FILTER_CALLS: ClassVar[int] = 10_000

    def __init__(
        self,
        elements: Sequence[Ex],
        *,
        force_repr: str | None = None,
        force_repr_braces: tuple[str, str] | None = None,
        transformations: tuple[
            tuple[Literal["filter", "map"], Callable[[Ex], Any]],
            ...,
        ] = (),
    ):
        super().__init__()
        self.elements = cu.check_sample(elements, "sampled_from")
        assert self.elements
        self.force_repr = force_repr
        self.force_repr_braces = force_repr_braces
        # Ordered ("map" | "filter", function) pairs applied to a sampled
        # element before it is returned.
        self._transformations = transformations

        self._cached_repr: str | None = None

    def map(self, pack: Callable[[Ex], T]) -> SearchStrategy[T]:
        # Rather than wrapping in a generic mapped strategy, build a sibling
        # instance with ``pack`` appended to the transformation chain.
        chain = (*self._transformations, ("map", pack))
        mapped = type(self)(
            self.elements,
            force_repr=self.force_repr,
            force_repr_braces=self.force_repr_braces,
            transformations=chain,
        )
        # The trailing ("map", pack) step is what makes the element type T.
        return cast(SearchStrategy[T], mapped)

    def filter(self, condition: Callable[[Ex], Any]) -> SearchStrategy[Ex]:
        # Same approach as ``map``: extend the chain on a fresh instance.
        chain = (*self._transformations, ("filter", condition))
        return type(self)(
            self.elements,
            force_repr=self.force_repr,
            force_repr_braces=self.force_repr_braces,
            transformations=chain,
        )

    def __repr__(self):
        if self._cached_repr is not None:
            return self._cached_repr

        describe = get_pretty_function_description
        # Only the first 512 elements are spelled out for long collections.
        if len(self.elements) > 512:
            shown = ", ".join(describe(v) for v in self.elements[:512]) + ", ..."
        else:
            shown = ", ".join(describe(v) for v in self.elements)
        braces = self.force_repr_braces or ("(", ")")
        base = self.force_repr or f"sampled_from({braces[0]}{shown}{braces[1]})"
        suffix = "".join(
            f".{name}({get_pretty_function_description(f)})"
            for name, f in self._transformations
        )
        self._cached_repr = base + suffix
        return self._cached_repr

    def calc_label(self) -> int:
        # A label under-approximates structural equality: distinct strategies
        # may share one. The complication here is that hash() can *differ* for
        # structurally identical inputs, e.g.
        #
        #   s1 = st.sampled_from([st.none()])
        #   s2 = st.sampled_from([st.none()])
        #   assert hash(s1) != hash(s2)
        #
        # (test_labels.py has more cases). So element strategies contribute
        # their own label, and only non-strategy elements contribute a hash.
        #
        # To stay cheap for (very) large sequences the logic is a little more
        # involved than that:
        # * range is special-cased outright, since it is known to hold only
        #   integers and must not be iterated when enormous.
        # * with no strategy among the elements, one hash of the whole
        #   sequence is used instead of hashing element by element.
        # * otherwise each element contributes its strategy label or its hash.
        #
        # The worst case is something like
        # itertools.chain(range(2**100), [st.none()]), which falls back to
        # hashing every int in the range.
        hashable, digest = _is_hashable(self.elements)
        if isinstance(self.elements, range) or (
            hashable
            and all(not isinstance(e, SearchStrategy) for e in self.elements)
        ):
            return combine_labels(self.class_label, calc_label_from_name(str(digest)))

        parts = [self.class_label]
        for item in self.elements:
            # Unhashable elements contribute nothing to the label.
            if is_hashable(item):
                if isinstance(item, SearchStrategy):
                    parts.append(item.label)
                else:
                    parts.append(calc_label_from_hash(item))

        return combine_labels(*parts)

    def calc_has_reusable_values(self, recur: RecurT) -> bool:
        # The usual map/filter wrapper strategies would answer False for us,
        # but our own .map/.filter bypass them - so report False ourselves as
        # soon as any transformation is present.
        return not self._transformations

    def calc_is_cacheable(self, recur: RecurT) -> bool:
        return is_hashable(self.elements)

    def _transform(
        self,
        # https://github.com/python/mypy/issues/7049, we're not writing `element`
        # anywhere in the class so this is still type-safe. mypy is being more
        # conservative than necessary
        element: Ex,  # type: ignore
    ) -> Ex | UniqueIdentifier:
        # Used in UniqueSampledListStrategy
        #
        # Runs ``element`` through the transformation chain in order, returning
        # ``filter_not_satisfied`` as soon as a filter step rejects it.
        current = element
        for kind, func in self._transformations:
            if kind != "map":
                assert kind == "filter"
                if not func(current):
                    return filter_not_satisfied
                continue
            mapped = func(current)
            context = _current_build_context.value
            if context:
                context.record_call(mapped, func, args=[current], kwargs={})
            current = mapped
        return current

    def do_draw(self, data: ConjectureData) -> Ex:
        drawn = self.do_filtered_draw(data)
        if isinstance(drawn, SearchStrategy):
            if all(isinstance(x, SearchStrategy) for x in self.elements):
                # Stash a hint on ``data``: every element is itself a strategy.
                data._sampled_from_all_strategies_elements_message = (
                    "sampled_from was given a collection of strategies: "
                    "{!r}. Was one_of intended?",
                    self.elements,
                )
        if drawn is filter_not_satisfied:
            data.mark_invalid(f"Aborted test because unable to satisfy {self!r}")
        assert not isinstance(drawn, UniqueIdentifier)
        return drawn

    def get_element(self, i: int) -> Ex | UniqueIdentifier:
        # Element at index ``i`` after applying the transformation chain.
        raw = self.elements[i]
        return self._transform(raw)

    def do_filtered_draw(self, data: ConjectureData) -> Ex | UniqueIdentifier:
        # Indices already found to fail the filters during this draw, so no
        # element is ever tested twice.
        rejected: set[int] = set()

        # Begin with plain rejection sampling: quick when it succeeds, and
        # only a little wasted effort when it does not.
        for _ in range(3):
            idx = data.draw_integer(0, len(self.elements) - 1)
            if idx in rejected:
                continue
            candidate = self.get_element(idx)
            if candidate is not filter_not_satisfied:
                return candidate
            if not rejected:
                data.events[f"Retried draw from {self!r} to satisfy filter"] = ""
            rejected.add(idx)

        # Every element has been tried and rejected: nothing left to do.
        remaining = len(self.elements) - len(rejected)
        if not remaining:
            return filter_not_satisfied

        # Arbitrary cap so that huge element lists don't consume too much time.
        remaining = min(remaining, self._MAX_FILTER_CALLS - 3)

        # Guess a position within the not-yet-built list of allowed entries.
        # The guess may turn out to be past the end of that list; that is fine.
        target = data.draw_integer(0, remaining - 1)

        # Collect allowed (index, element) pairs so one can be chosen at
        # random, stopping early if the guessed position is reached. The
        # transformed elements are kept alongside the indices in case of
        # .map(some_stateful_function).
        accepted: list[tuple[int, Ex]] = []
        for idx in range(min(len(self.elements), self._MAX_FILTER_CALLS - 3)):
            if idx in rejected:
                continue
            candidate = self.get_element(idx)
            if candidate is filter_not_satisfied:
                continue
            assert not isinstance(candidate, UniqueIdentifier)
            accepted.append((idx, candidate))
            if len(accepted) > target:
                # Early exit: the guessed position exists, so use it.
                data.draw_integer(0, len(self.elements) - 1, forced=idx)
                return candidate

        # The guess overshot. If nothing at all was allowed, the filter could
        # not be satisfied; otherwise choose among the complete list.
        if not accepted:
            return filter_not_satisfied
        idx, candidate = data.choice(accepted)
        data.draw_integer(0, len(self.elements) - 1, forced=idx)
        return candidate
783
784
785class OneOfStrategy(SearchStrategy[Ex]):
786 """Implements a union of strategies. Given a number of strategies this
787 generates values which could have come from any of them.
788
789 The conditional distribution draws uniformly at random from some
790 non-empty subset of these strategies and then draws from the
791 conditional distribution of that strategy.
792 """
793
794 def __init__(self, strategies: Sequence[SearchStrategy[Ex]]):
795 super().__init__()
796 self.original_strategies = tuple(strategies)
797 self.__element_strategies: Sequence[SearchStrategy[Ex]] | None = None
798 self.__in_branches = False
799 self._branches_lock = RLock()
800
801 def calc_is_empty(self, recur: RecurT) -> bool:
802 return all(recur(e) for e in self.original_strategies)
803
804 def calc_has_reusable_values(self, recur: RecurT) -> bool:
805 return all(recur(e) for e in self.original_strategies)
806
807 def calc_is_cacheable(self, recur: RecurT) -> bool:
808 return all(recur(e) for e in self.original_strategies)
809
810 @property
811 def element_strategies(self) -> Sequence[SearchStrategy[Ex]]:
812 if self.__element_strategies is None:
813 # While strategies are hashable, they use object.__hash__ and are
814 # therefore distinguished only by identity.
815 #
816 # In principle we could "just" define a __hash__ method
817 # (and __eq__, but that's easy in terms of type() and hash())
818 # to make this more powerful, but this is harder than it sounds:
819 #
820 # 1. Strategies are often distinguished by non-hashable attributes,
821 # or by attributes that have the same hash value ("^.+" / b"^.+").
822 # 2. LazyStrategy: can't reify the wrapped strategy without breaking
823 # laziness, so there's a hash each for the lazy and the nonlazy.
824 #
825 # Having made several attempts, the minor benefits of making strategies
826 # hashable are simply not worth the engineering effort it would take.
827 # See also issues #2291 and #2327.
828 seen: set[SearchStrategy] = {self}
829 strategies: list[SearchStrategy] = []
830 for arg in self.original_strategies:
831 check_strategy(arg)
832 if not arg.is_empty:
833 for s in arg.branches:
834 if s not in seen and not s.is_empty:
835 seen.add(s)
836 strategies.append(s)
837 self.__element_strategies = strategies
838 return self.__element_strategies
839
840 def calc_label(self) -> int:
841 return combine_labels(
842 self.class_label, *(p.label for p in self.original_strategies)
843 )
844
def do_draw(self, data: ConjectureData) -> Ex:
    """Choose one alternative that is not currently empty and draw from it.

    The choice itself is a draw from a ``SampledFromStrategy`` over
    ``element_strategies``, filtered on ``is_currently_empty(data)``.
    """
    alternatives = SampledFromStrategy(self.element_strategies)
    usable_alternatives = alternatives.filter(
        lambda s: not s.is_currently_empty(data)
    )
    chosen = data.draw(usable_alternatives)
    return data.draw(chosen)
852
def __repr__(self) -> str:
    """Render as ``one_of(...)`` listing the reprs of the original arguments."""
    rendered_args = ", ".join(repr(arg) for arg in self.original_strategies)
    return "one_of({})".format(rendered_args)
855
def do_validate(self) -> None:
    """Validate each flattened alternative in ``element_strategies``, in order."""
    for branch in self.element_strategies:
        branch.validate()
859
@property
def branches(self) -> Sequence[SearchStrategy[Ex]]:
    """Return the alternatives this strategy contributes to an enclosing union.

    Normally this is ``element_strategies``. If ``branches`` is requested
    again while that list is still being computed under the lock, ``[self]``
    is returned instead.
    """
    if self.__element_strategies is not None:
        # common fast path which avoids the lock
        return self.element_strategies

    # Computing element_strategies reads ``.branches`` of every original
    # strategy, which can lead back to this property. ``__in_branches`` marks
    # that a computation is in progress so that such a nested request gets
    # ``[self]``. ``_branches_lock`` is an RLock, so the thread that holds it
    # can acquire it again.
    with self._branches_lock:
        if not self.__in_branches:
            try:
                self.__in_branches = True
                return self.element_strategies
            finally:
                # Clear the flag whether the computation returned or raised.
                self.__in_branches = False
        else:
            return [self]
875
def filter(self, condition: Callable[[Ex], Any]) -> SearchStrategy[Ex]:
    """Apply ``condition`` to each original alternative instead of the union.

    Returns a ``FilteredStrategy`` with no conditions of its own, wrapping a
    new ``OneOfStrategy`` built from the individually filtered arguments.
    """
    filtered_args = [arg.filter(condition) for arg in self.original_strategies]
    return FilteredStrategy(OneOfStrategy(filtered_args), conditions=())
881
882
# Typing-only overloads for ``one_of``. They cover, in order: a single sequence
# of strategies; one to five positional strategies, whose example types are
# unioned in the return type; and a catch-all for any other number of
# positional strategies, which is typed as ``SearchStrategy[Any]``.
@overload
def one_of(
    __args: Sequence[SearchStrategy[Ex]],
) -> SearchStrategy[Ex]:  # pragma: no cover
    ...


@overload
def one_of(__a1: SearchStrategy[Ex]) -> SearchStrategy[Ex]:  # pragma: no cover
    ...


@overload
def one_of(
    __a1: SearchStrategy[Ex], __a2: SearchStrategy[T]
) -> SearchStrategy[Ex | T]:  # pragma: no cover
    ...


@overload
def one_of(
    __a1: SearchStrategy[Ex], __a2: SearchStrategy[T], __a3: SearchStrategy[T3]
) -> SearchStrategy[Ex | T | T3]:  # pragma: no cover
    ...


@overload
def one_of(
    __a1: SearchStrategy[Ex],
    __a2: SearchStrategy[T],
    __a3: SearchStrategy[T3],
    __a4: SearchStrategy[T4],
) -> SearchStrategy[Ex | T | T3 | T4]:  # pragma: no cover
    ...


@overload
def one_of(
    __a1: SearchStrategy[Ex],
    __a2: SearchStrategy[T],
    __a3: SearchStrategy[T3],
    __a4: SearchStrategy[T4],
    __a5: SearchStrategy[T5],
) -> SearchStrategy[Ex | T | T3 | T4 | T5]:  # pragma: no cover
    ...


@overload
def one_of(*args: SearchStrategy[Any]) -> SearchStrategy[Any]:  # pragma: no cover
    ...
933
934
@defines_strategy(eager=True)
def one_of(
    *args: Sequence[SearchStrategy[Any]] | SearchStrategy[Any],
) -> SearchStrategy[Any]:
    # Mypy workaround alert:  Any is too loose above; the return parameter
    # should be the union of the input parameters.  Unfortunately, Mypy <=0.600
    # raises errors due to incompatible inputs instead.  See #1270 for links.
    # v0.610 doesn't error; it gets inference wrong for 2+ arguments instead.
    """Return a strategy which generates values from any of the argument
    strategies.

    This may be called with one iterable argument instead of multiple
    strategy arguments, in which case ``one_of(x)`` and ``one_of(*x)`` are
    equivalent.

    Examples from this strategy will generally shrink to ones that come from
    strategies earlier in the list, then shrink according to behaviour of the
    strategy that produced them. In order to get good shrinking behaviour,
    try to put simpler strategies first. e.g. ``one_of(none(), text())`` is
    better than ``one_of(text(), none())``.

    This is especially important when using recursive strategies. e.g.
    ``x = st.deferred(lambda: st.none() | st.tuples(x, x))`` will shrink well,
    but ``x = st.deferred(lambda: st.tuples(x, x) | st.none())`` will shrink
    very badly indeed.
    """
    if len(args) == 1:
        (sole_arg,) = args
        if not isinstance(sole_arg, SearchStrategy):
            # A single non-strategy argument is treated as an iterable of
            # strategies when it can be iterated.
            try:
                args = tuple(sole_arg)
            except TypeError:
                # Not iterable: leave ``args`` untouched for the checks below.
                pass
    if len(args) == 1:
        (only,) = args
        if isinstance(only, SearchStrategy):
            # This special-case means that we can one_of over lists of any
            # size without incurring any performance overhead when there is
            # only one strategy, and keeps our reprs simple.
            return only
    if args and all(not isinstance(a, SearchStrategy) for a in args):
        # And this special case is to give a more-specific error message if it
        # seems that the user has confused `one_of()` for `sampled_from()`;
        # the remaining validation is left to OneOfStrategy.  See PR #2627.
        raise InvalidArgument(
            f"Did you mean st.sampled_from({list(args)!r})?  st.one_of() is used "
            "to combine strategies, but all of the arguments were of other types."
        )
    # The one-element-sequence form [(s1, s2, ...)] was unpacked above, so from
    # here on ``args`` is treated as an actual sequence of strategies.
    return OneOfStrategy(cast(Sequence[SearchStrategy], args))
983
984
class MappedStrategy(SearchStrategy[MappedTo], Generic[MappedFrom, MappedTo]):
    """A strategy whose values come from another strategy, passed through
    a conversion function.

    Each draw takes a value from ``mapped_strategy`` and returns
    ``pack(value)``; emptiness and cacheability are those of the wrapped
    strategy.
    """

    def __init__(
        self,
        strategy: SearchStrategy[MappedFrom],
        pack: Callable[[MappedFrom], MappedTo],
    ) -> None:
        """Store the wrapped ``strategy`` and the ``pack`` conversion."""
        super().__init__()
        self.mapped_strategy = strategy
        self.pack = pack

    def calc_is_empty(self, recur: RecurT) -> bool:
        """Empty exactly when the wrapped strategy is reported empty."""
        return recur(self.mapped_strategy)

    def calc_is_cacheable(self, recur: RecurT) -> bool:
        """Cacheable exactly when the wrapped strategy is reported cacheable."""
        return recur(self.mapped_strategy)

    def __repr__(self) -> str:
        """Render as ``<inner repr>.map(<pack description>)``, computed once."""
        if hasattr(self, "_cached_repr"):
            return self._cached_repr
        self._cached_repr = f"{self.mapped_strategy!r}.map({get_pretty_function_description(self.pack)})"
        return self._cached_repr

    def do_validate(self) -> None:
        """Validate the wrapped strategy."""
        self.mapped_strategy.validate()

    def do_draw(self, data: ConjectureData) -> MappedTo:
        """Draw from the wrapped strategy and convert the value with ``pack``.

        Up to three attempts are made; an attempt that raises
        ``UnsatisfiedAssumption`` has its span discarded.  If all attempts
        fail, ``UnsatisfiedAssumption`` is raised.
        """
        with warnings.catch_warnings():
            packs_into_mapping_or_set = isinstance(self.pack, type) and issubclass(
                self.pack, (abc.Mapping, abc.Set)
            )
            if packs_into_mapping_or_set:
                # BytesWarning is silenced while building mappings and sets
                # (presumably because of str/bytes comparisons -- confirm).
                warnings.simplefilter("ignore", BytesWarning)
            for _attempt in range(3):
                try:
                    data.start_span(MAPPED_SEARCH_STRATEGY_DO_DRAW_LABEL)
                    drawn = data.draw(self.mapped_strategy)
                    packed = self.pack(drawn)
                    data.stop_span()
                    current_build_context().record_call(
                        packed, self.pack, args=[drawn], kwargs={}
                    )
                    return packed
                except UnsatisfiedAssumption:
                    data.stop_span(discard=True)
        raise UnsatisfiedAssumption

    @property
    def branches(self) -> Sequence[SearchStrategy[MappedTo]]:
        """One ``MappedStrategy`` per branch of the wrapped strategy, each
        using the same ``pack``.
        """
        return [
            MappedStrategy(inner_branch, pack=self.pack)
            for inner_branch in self.mapped_strategy.branches
        ]

    def filter(
        self, condition: Callable[[MappedTo], Any]
    ) -> "SearchStrategy[MappedTo]":
        """Filter this strategy, rewriting the inner list strategy if possible."""
        # Includes a special case so that we can rewrite filters on collection
        # lengths, when most collections are `st.lists(...).map(the_type)`.
        list_strategy_cls = _list_strategy_type()
        if not isinstance(self.mapped_strategy, list_strategy_cls):
            return super().filter(condition)
        pack_is_collection_ish = (
            isinstance(self.pack, type) and issubclass(self.pack, abc.Collection)
        ) or self.pack in _collection_ish_functions()
        if not pack_is_collection_ish:
            return super().filter(condition)

        # Check whether our inner list strategy can rewrite this filter condition.
        # If not, discard the result and _only_ apply a new outer filter.
        rewritten = list_strategy_cls.filter(self.mapped_strategy, condition)
        if getattr(rewritten, "filtered_strategy", None) is self.mapped_strategy:
            return super().filter(condition)  # didn't rewrite

        # Apply a new outer filter even though we rewrote the inner strategy,
        # because some collections can change the list length (dict, set, etc).
        remapped = type(self)(rewritten, self.pack)
        return FilteredStrategy(remapped, conditions=(condition,))
1063
1064
@lru_cache
def _list_strategy_type() -> Any:
    """Return the ``ListStrategy`` class, importing it on first use.

    The import lives inside the function rather than at module level
    (presumably to avoid a circular import -- TODO confirm). Because the
    function takes no arguments, ``lru_cache`` keeps the result of the first
    successful call and returns it on every later call.
    """
    from hypothesis.strategies._internal.collections import ListStrategy

    return ListStrategy
1070
1071
def _collection_ish_functions() -> Sequence[Any]:
    """Return the callables that are treated like collection constructors.

    The list always contains ``sorted``.  If numpy is already present in
    ``sys.modules`` (it is never imported here), a set of numpy
    array-creation functions is appended.  A new list is built on every call.
    """
    collection_makers = [sorted]
    np = sys.modules.get("numpy")
    if not np:
        return collection_makers

    # c.f. https://numpy.org/doc/stable/reference/routines.array-creation.html
    # Probably only `np.array` and `np.asarray` will be used in practice,
    # but why should that stop us when we've already gone this far?
    collection_makers.extend(
        [
            np.empty_like,
            np.eye,
            np.identity,
            np.ones_like,
            np.zeros_like,
            np.array,
            np.asarray,
            np.asanyarray,
            np.ascontiguousarray,
            np.asmatrix,
            np.copy,
            np.rec.array,
            np.rec.fromarrays,
            np.rec.fromrecords,
            np.diag,
            # bonus undocumented functions from tab-completion:
            np.asarray_chkfinite,
            np.asfortranarray,
        ]
    )
    return collection_makers
1100
1101
# Sentinel returned by FilteredStrategy.do_filtered_draw when none of its
# attempts produced a value that satisfied the filter; FilteredStrategy.do_draw
# compares against it by identity.
filter_not_satisfied = UniqueIdentifier("filter not satisfied")
1103
1104
class FilteredStrategy(SearchStrategy[Ex]):
    """Draw from an underlying strategy, keeping only values that satisfy
    every stored predicate.

    Chained filters are flattened: ``filtered_strategy`` is never itself a
    ``FilteredStrategy``, and ``flat_conditions`` is a tuple holding all the
    predicates, oldest first.
    """

    def __init__(
        self, strategy: SearchStrategy[Ex], conditions: tuple[Callable[[Ex], Any], ...]
    ) -> None:
        """Wrap ``strategy`` with ``conditions``.

        If ``strategy`` is already a ``FilteredStrategy``, its inner strategy
        is used directly and its conditions are placed before ``conditions``.
        This method is also re-run in place by ``do_validate``.
        """
        super().__init__()
        if isinstance(strategy, FilteredStrategy):
            # Flatten chained filters into a single filter with multiple conditions.
            self.flat_conditions: tuple[Callable[[Ex], Any], ...] = (
                strategy.flat_conditions + conditions
            )
            self.filtered_strategy: SearchStrategy[Ex] = strategy.filtered_strategy
        else:
            self.flat_conditions = conditions
            self.filtered_strategy = strategy

        assert isinstance(self.flat_conditions, tuple)
        assert not isinstance(self.filtered_strategy, FilteredStrategy)

        # Cache for the combined predicate built lazily by ``condition``.
        self.__condition: Callable[[Ex], Any] | None = None

    def calc_is_empty(self, recur: RecurT) -> bool:
        """Empty exactly when the inner strategy is reported empty."""
        return recur(self.filtered_strategy)

    def calc_is_cacheable(self, recur: RecurT) -> bool:
        """Cacheable exactly when the inner strategy is reported cacheable."""
        return recur(self.filtered_strategy)

    def __repr__(self) -> str:
        """Render the inner strategy's repr followed by one ``.filter(...)``
        per condition; the string is computed once and stored on the instance.
        """
        if not hasattr(self, "_cached_repr"):
            self._cached_repr = "{!r}{}".format(
                self.filtered_strategy,
                "".join(
                    f".filter({get_pretty_function_description(cond)})"
                    for cond in self.flat_conditions
                ),
            )
        return self._cached_repr

    def do_validate(self) -> None:
        """Validate the inner strategy, then replay the saved predicates
        against it and re-initialize this instance in place with the result.
        """
        # Start by validating our inner filtered_strategy.  If this was a LazyStrategy,
        # validation also reifies it so that subsequent calls to e.g. `.filter()` will
        # be passed through.
        self.filtered_strategy.validate()
        # So now we have a reified inner strategy, we'll replay all our saved
        # predicates in case some or all of them can be rewritten.  Note that this
        # replaces the `fresh` strategy too!
        fresh = self.filtered_strategy
        for cond in self.flat_conditions:
            fresh = fresh.filter(cond)
        if isinstance(fresh, FilteredStrategy):
            # In this case we have at least some non-rewritten filter predicates,
            # so we just re-initialize the strategy.
            FilteredStrategy.__init__(
                self, fresh.filtered_strategy, fresh.flat_conditions
            )
        else:
            # But if *all* the predicates were rewritten... well, do_validate() is
            # an in-place method so we still just re-initialize the strategy!
            FilteredStrategy.__init__(self, fresh, ())

    def filter(self, condition: Callable[[Ex], Any]) -> "FilteredStrategy[Ex]":
        """Return a new ``FilteredStrategy`` that additionally requires
        ``condition``, giving the inner strategy the first chance to absorb it.
        """
        # If we can, it's more efficient to rewrite our strategy to satisfy the
        # condition.  We therefore exploit the fact that the order of predicates
        # doesn't matter (`f(x) and g(x) == g(x) and f(x)`) by attempting to apply
        # condition directly to our filtered strategy as the inner-most filter.
        out = self.filtered_strategy.filter(condition)
        # If it couldn't be rewritten, we'll get a new FilteredStrategy - and then
        # combine the conditions of each in our expected newest=last order.
        if isinstance(out, FilteredStrategy):
            return FilteredStrategy(
                out.filtered_strategy, self.flat_conditions + out.flat_conditions
            )
        # But if it *could* be rewritten, we can return the more efficient form!
        return FilteredStrategy(out, self.flat_conditions)

    @property
    def condition(self) -> Callable[[Ex], Any]:
        """A single predicate equivalent to all of ``flat_conditions``.

        Built on first access and cached in ``__condition``.
        """
        # We write this defensively to avoid any threading race conditions
        # with our manual FilteredStrategy.__init__ for filter-rewriting.
        # See https://github.com/HypothesisWorks/hypothesis/pull/4522.
        if (condition := self.__condition) is not None:
            return condition

        if len(self.flat_conditions) == 1:
            # Avoid an extra indirection in the common case of only one condition.
            condition = self.flat_conditions[0]
        elif len(self.flat_conditions) == 0:
            # Possible, if unlikely, due to filter predicate rewriting
            condition = lambda _: True  # type: ignore # covariant type param
        else:
            condition = lambda x: all(  # type: ignore # covariant type param
                cond(x) for cond in self.flat_conditions
            )
        self.__condition = condition
        return condition

    def do_draw(self, data: ConjectureData) -> Ex:
        """Draw a value satisfying the filter, or mark ``data`` invalid.

        If ``do_filtered_draw`` gives up, ``data.mark_invalid`` is called with
        a message naming this strategy; nothing is returned after that call.
        NOTE(review): mark_invalid presumably raises -- confirm.
        """
        result = self.do_filtered_draw(data)
        if result is not filter_not_satisfied:
            return cast(Ex, result)

        data.mark_invalid(f"Aborted test because unable to satisfy {self!r}")

    def do_filtered_draw(self, data: ConjectureData) -> Ex | UniqueIdentifier:
        """Try up to three times to draw a value accepted by ``condition``.

        Returns the first accepted value, or ``filter_not_satisfied`` if all
        three attempts were rejected. Rejected attempts have their span
        discarded.
        """
        for i in range(3):
            data.start_span(FILTERED_SEARCH_STRATEGY_DO_DRAW_LABEL)
            value = data.draw(self.filtered_strategy)
            if self.condition(value):
                data.stop_span()
                return value
            else:
                data.stop_span(discard=True)
                if i == 0:
                    # Record the retry once per call, on the first rejection.
                    data.events[f"Retried draw from {self!r} to satisfy filter"] = ""

        return filter_not_satisfied

    @property
    def branches(self) -> Sequence[SearchStrategy[Ex]]:
        """One ``FilteredStrategy`` per branch of the inner strategy, each
        carrying all of this strategy's conditions.
        """
        return [
            FilteredStrategy(strategy=strategy, conditions=self.flat_conditions)
            for strategy in self.filtered_strategy.branches
        ]
1227
1228
@check_function
def check_strategy(arg: object, name: str = "") -> None:
    """Raise ``InvalidArgument`` unless ``arg`` is a ``SearchStrategy``.

    ``name``, when given, is shown as ``name=`` in front of the offending
    value.  A list or tuple gets an extra hint suggesting
    ``st.sampled_from``.
    """
    assert isinstance(name, str)
    if isinstance(arg, SearchStrategy):
        return

    # The hint is built from the bare name, before "=" is appended to it.
    hint = (
        ", such as st.sampled_from({}),".format(name or "...")
        if isinstance(arg, (list, tuple))
        else ""
    )
    name = name + "=" if name else ""
    raise InvalidArgument(
        f"Expected a SearchStrategy{hint} but got {name}{arg!r} "
        f"(type={type(arg).__name__})"
    )