1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import threading
12import warnings
13from collections.abc import Callable
14from contextlib import contextmanager
15
16from hypothesis.errors import HypothesisWarning, InvalidArgument
17from hypothesis.internal.reflection import (
18 get_pretty_function_description,
19 is_first_param_referenced_in_function,
20 is_identity_function,
21)
22from hypothesis.internal.validation import check_type
23from hypothesis.strategies._internal.strategies import (
24 OneOfStrategy,
25 SearchStrategy,
26 check_strategy,
27)
28from hypothesis.utils.deprecation import note_deprecation
29
30
class LimitReached(BaseException):
    """Raised by ``LimitedStrategy.do_draw`` when its draw budget is used up.

    It derives from ``BaseException`` rather than ``Exception``; presumably
    this is so that ``except Exception`` handlers between the draw and the
    retry loop do not swallow it (NOTE(review): confirm intent).
    """
33
34
class LimitedStrategy(SearchStrategy):
    """Wrap ``strategy`` so that only a bounded number of draws succeed.

    Inside a ``capped(n)`` context each ``do_draw`` spends one unit of a
    per-thread budget (``marker``).  Once the budget reaches zero, the next
    draw raises ``LimitReached`` instead of drawing from the wrapped strategy.
    """

    def __init__(self, strategy):
        super().__init__()
        self.base_strategy = strategy
        # Budget state lives on a threading.local, so each thread reads and
        # writes its own ``marker`` / ``currently_capped`` values.
        self._threadlocal = threading.local()

    @property
    def marker(self):
        """Remaining draw budget for this thread (0 if never set)."""
        return getattr(self._threadlocal, "marker", 0)

    @marker.setter
    def marker(self, value):
        self._threadlocal.marker = value

    @property
    def currently_capped(self):
        """Whether this thread is currently inside ``capped`` (default False)."""
        return getattr(self._threadlocal, "currently_capped", False)

    @currently_capped.setter
    def currently_capped(self, value):
        self._threadlocal.currently_capped = value

    def __repr__(self) -> str:
        return f"LimitedStrategy({self.base_strategy!r})"

    def do_validate(self) -> None:
        # Validation is delegated entirely to the wrapped strategy.
        self.base_strategy.validate()

    def do_draw(self, data):
        # Draws are only expected while a budget is active via ``capped``.
        assert self.currently_capped
        remaining = self.marker
        if remaining <= 0:
            raise LimitReached
        self.marker = remaining - 1
        return data.draw(self.base_strategy)

    @contextmanager
    def capped(self, max_templates):
        """Allow at most ``max_templates`` draws while the context is active.

        On exit, the previous ``currently_capped`` flag is restored.  The
        ``marker`` budget is left at whatever value the block ended with.
        """
        previously_capped = self.currently_capped
        try:
            self.currently_capped = True
            self.marker = max_templates
            yield
        finally:
            self.currently_capped = previously_capped
79
80
class RecursiveStrategy(SearchStrategy):
    """Strategy for recursive data built from ``base`` leaves via ``extend``.

    ``base`` supplies the leaves, and ``extend`` maps a strategy to a larger
    strategy that contains it.  Leaves are drawn through a
    ``LimitedStrategy``, so a single draw uses at most ``max_leaves`` of them.
    When ``min_leaves`` is set, draws that use fewer leaves are retried.

    Type and range checks on ``min_leaves`` / ``max_leaves`` happen in
    ``do_validate``, not in ``__init__``.
    """

    def __init__(
        self,
        base: SearchStrategy,
        extend: Callable[[SearchStrategy], SearchStrategy],
        min_leaves: int | None,
        max_leaves: int,
    ):
        super().__init__()
        self.min_leaves = min_leaves
        self.max_leaves = max_leaves
        self.base = base
        self.limited_base = LimitedStrategy(base)
        self.extend = extend

        # Layers: base, extend(base), extend(one_of(previous layers)), ...
        # Another layer is added while 2 ** (layers so far - 1) <= max_leaves.
        strategies = [self.limited_base, self.extend(self.limited_base)]
        while 2 ** (len(strategies) - 1) <= max_leaves:
            strategies.append(extend(OneOfStrategy(tuple(strategies))))
        # If min_leaves > 1, we can never draw from base directly
        if min_leaves is not None and min_leaves > 1:
            strategies = strategies[1:]
        self.strategy = OneOfStrategy(tuple(strategies))

    def __repr__(self) -> str:
        # Computed lazily and cached, since describing ``extend`` may be costly.
        if not hasattr(self, "_cached_repr"):
            self._cached_repr = (
                f"recursive({self.base!r}, "
                f"{get_pretty_function_description(self.extend)}, "
                f"min_leaves={self.min_leaves}, max_leaves={self.max_leaves})"
            )
        return self._cached_repr

    def do_validate(self) -> None:
        """Validate ``base``, the result of ``extend``, and the leaf bounds.

        Warns with ``HypothesisWarning`` if ``extend`` is the identity
        function.  If ``extend`` never references its argument, this emits a
        deprecation note when ``min_leaves`` is None and raises
        ``InvalidArgument`` otherwise.  ``InvalidArgument`` is also raised
        when ``min_leaves`` / ``max_leaves`` are not positive or when
        ``min_leaves`` exceeds ``max_leaves``.  The integer type checks are
        delegated to ``check_type``.
        """
        check_strategy(self.base, "base")
        extended = self.extend(self.limited_base)
        check_strategy(extended, f"extend({self.limited_base!r})")
        self.limited_base.validate()
        extended.validate()

        if is_identity_function(self.extend):
            warnings.warn(
                "extend=lambda x: x is a no-op; you probably want to use a "
                "different extend function, or just use the base strategy directly.",
                HypothesisWarning,
                stacklevel=5,
            )

        if not is_first_param_referenced_in_function(self.extend):
            msg = (
                f"extend={get_pretty_function_description(self.extend)} doesn't use "
                "its argument, and thus can't actually recurse!"
            )
            # Only a deprecation when min_leaves was not given; an error otherwise.
            if self.min_leaves is None:
                note_deprecation(
                    msg,
                    since="2026-01-12",
                    has_codemod=False,
                    stacklevel=1,
                )
            else:
                raise InvalidArgument(msg)

        if self.min_leaves is not None:
            check_type(int, self.min_leaves, "min_leaves")
        check_type(int, self.max_leaves, "max_leaves")
        if self.min_leaves is not None and self.min_leaves <= 0:
            raise InvalidArgument(
                f"min_leaves={self.min_leaves!r} must be greater than zero"
            )
        if self.max_leaves <= 0:
            raise InvalidArgument(
                f"max_leaves={self.max_leaves!r} must be greater than zero"
            )
        # An unset min_leaves behaves as 1 for this bound check.
        if (self.min_leaves or 1) > self.max_leaves:
            raise InvalidArgument(
                f"min_leaves={self.min_leaves!r} must be less than or equal to "
                f"max_leaves={self.max_leaves!r}"
            )

    def do_draw(self, data):
        """Draw a value built from between min_leaves and max_leaves leaves.

        A draw that would exceed ``max_leaves`` surfaces as ``LimitReached``
        and is retried.  A draw using fewer than ``min_leaves`` leaves is
        retried; after five such short draws ``data.mark_invalid`` is called.
        """
        min_leaves_retries = 0
        while True:
            try:
                with self.limited_base.capped(self.max_leaves):
                    result = data.draw(self.strategy)
                    # capped() starts the budget at max_leaves and every leaf
                    # draw spends one unit, so the amount spent is the leaf count.
                    leaves_drawn = self.max_leaves - self.limited_base.marker
                    if self.min_leaves and leaves_drawn < self.min_leaves:
                        data.events[
                            f"Draw for {self!r} had fewer than "
                            f"min_leaves={self.min_leaves} and had to be retried"
                        ] = ""
                        min_leaves_retries += 1
                        if min_leaves_retries < 5:
                            continue
                        # NOTE(review): presumably mark_invalid raises and aborts
                        # this draw; if it returned, the short result would be
                        # returned below.
                        data.mark_invalid(f"min_leaves={self.min_leaves} unsatisfied")
                    return result
            except LimitReached:
                # NOTE(review): this path has no retry cap; it relies on
                # ``data`` eventually ending the loop (presumably by
                # overrunning its choice budget) — confirm.
                data.events[
                    f"Draw for {self!r} exceeded "
                    f"max_leaves={self.max_leaves} and had to be retried"
                ] = ""