1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import threading
12from contextlib import contextmanager
13
14from hypothesis.errors import InvalidArgument
15from hypothesis.internal.reflection import get_pretty_function_description
16from hypothesis.internal.validation import check_type
17from hypothesis.strategies._internal.strategies import (
18 OneOfStrategy,
19 SearchStrategy,
20 check_strategy,
21)
22
23
class LimitReached(BaseException):
    """Signal that a ``LimitedStrategy`` has no draw budget left.

    Raised from ``LimitedStrategy.do_draw`` and caught in
    ``RecursiveStrategy.do_draw``, which then retries the whole draw.

    Because this derives from ``BaseException`` rather than ``Exception``,
    ``except Exception`` handlers between the raise and the catch do not
    intercept it (presumably the reason for the base class -- confirm).
    """
26
27
class LimitedStrategy(SearchStrategy):
    """Wrap a strategy so that only a bounded number of draws succeed.

    The remaining budget (``marker``) and the "inside a ``capped()`` block"
    flag (``currently_capped``) live on a ``threading.local`` object, so each
    thread reads and writes its own values.
    """

    def __init__(self, strategy):
        super().__init__()
        self.base_strategy = strategy
        # Per-thread storage backing ``marker`` and ``currently_capped``.
        self._threadlocal = threading.local()

    @property
    def marker(self):
        """Draws still permitted in the current thread (0 if never set)."""
        return getattr(self._threadlocal, "marker", 0)

    @marker.setter
    def marker(self, value):
        self._threadlocal.marker = value

    @property
    def currently_capped(self):
        """Whether the current thread is inside ``capped()`` (default False)."""
        return getattr(self._threadlocal, "currently_capped", False)

    @currently_capped.setter
    def currently_capped(self, value):
        self._threadlocal.currently_capped = value

    def __repr__(self) -> str:
        return f"LimitedStrategy({self.base_strategy!r})"

    def do_validate(self) -> None:
        """Validate the wrapped strategy."""
        self.base_strategy.validate()

    def do_draw(self, data):
        """Draw from the wrapped strategy, spending one unit of budget.

        Raises:
            LimitReached: if the current thread's budget is already used up.
        """
        # Internal invariant: draws are only expected inside ``capped()``,
        # which is what sets the budget.
        assert self.currently_capped
        if self.marker <= 0:
            raise LimitReached
        self.marker -= 1
        return data.draw(self.base_strategy)

    @contextmanager
    def capped(self, max_templates):
        """Allow up to ``max_templates`` draws for the duration of the block.

        On exit ``currently_capped`` is restored to its value at entry.  The
        budget (``marker``) is not restored; it keeps whatever is left.
        """
        # Capture the previous state *before* entering ``try`` so that the
        # ``finally`` clause always has a value to restore, even if an
        # exception arrives before anything else in the block has run.
        was_capped = self.currently_capped
        try:
            self.currently_capped = True
            self.marker = max_templates
            yield
        finally:
            self.currently_capped = was_capped
72
73
class RecursiveStrategy(SearchStrategy):
    """Draw from ``base`` or from repeated applications of ``extend``.

    ``base`` is wrapped in a ``LimitedStrategy`` so that the number of base
    draws made during one top-level draw can be capped at ``max_leaves``.
    ``self.strategy`` is a ``OneOfStrategy`` over a ladder of alternatives:
    the limited base, ``extend(limited_base)``, and then further layers of
    ``extend(OneOfStrategy(all previous layers))``, added for as long as
    ``2 ** (number_of_layers - 1) <= max_leaves``.
    """

    def __init__(self, base, extend, max_leaves):
        super().__init__()
        self.max_leaves = max_leaves
        self.base = base
        self.limited_base = LimitedStrategy(base)
        self.extend = extend

        strategies = [self.limited_base, self.extend(self.limited_base)]
        # Only grow the ladder for a genuine int.  Arguments are not checked
        # until do_validate() runs, so without this guard a non-numeric
        # max_leaves raises a bare TypeError here and float("inf") never
        # terminates.  Invalid values are left for do_validate() to report.
        if isinstance(max_leaves, int):
            while 2 ** (len(strategies) - 1) <= max_leaves:
                strategies.append(extend(OneOfStrategy(tuple(strategies))))
        self.strategy = OneOfStrategy(strategies)

    def __repr__(self) -> str:
        # Built lazily and cached on the instance after the first call.
        if not hasattr(self, "_cached_repr"):
            extend_description = get_pretty_function_description(self.extend)
            # ``!r`` rather than ``%d``: repr must not raise for a
            # not-yet-validated max_leaves, nor silently truncate a float.
            self._cached_repr = (
                f"recursive({self.base!r}, {extend_description}, "
                f"max_leaves={self.max_leaves!r})"
            )
        return self._cached_repr

    def do_validate(self) -> None:
        """Check the constructor arguments.

        Checks, in order: ``base`` via ``check_strategy``; the result of
        ``extend(limited_base)`` via ``check_strategy``; both of those via
        ``validate()``; that ``max_leaves`` passes ``check_type(int, ...)``;
        and that it is positive.

        Raises:
            InvalidArgument: if ``max_leaves <= 0``.  The ``check_*`` helpers
                presumably raise for their own failures -- see their
                definitions.
        """
        check_strategy(self.base, "base")
        extended = self.extend(self.limited_base)
        check_strategy(extended, f"extend({self.limited_base!r})")
        self.limited_base.validate()
        extended.validate()
        check_type(int, self.max_leaves, "max_leaves")
        if self.max_leaves <= 0:
            raise InvalidArgument(
                f"max_leaves={self.max_leaves!r} must be greater than zero"
            )

    def do_draw(self, data):
        """Draw a value, retrying whenever the leaf budget is exceeded.

        Each attempt runs inside ``limited_base.capped(max_leaves)``.  If the
        attempt raises ``LimitReached`` it is discarded and a fresh attempt
        is made; the first such retry is recorded in ``data.events``.
        """
        retried = False
        while True:
            try:
                with self.limited_base.capped(self.max_leaves):
                    return data.draw(self.strategy)
            except LimitReached:
                # Record the event once per top-level draw, not per retry.
                if not retried:
                    msg = f"Draw for {self!r} exceeded max_leaves and had to be retried"
                    data.events[msg] = ""
                    retried = True