1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import threading
12import warnings
13from collections.abc import Callable, Generator
14from contextlib import contextmanager
15from typing import Any, TypeVar
16
17from hypothesis.errors import HypothesisWarning, InvalidArgument
18from hypothesis.internal.conjecture.data import ConjectureData
19from hypothesis.internal.reflection import (
20 get_pretty_function_description,
21 is_first_param_referenced_in_function,
22 is_identity_function,
23)
24from hypothesis.internal.validation import check_type
25from hypothesis.strategies._internal.strategies import (
26 OneOfStrategy,
27 SearchStrategy,
28 check_strategy,
29)
30from hypothesis.utils.deprecation import note_deprecation
31
# Type of the values produced by the strategy that LimitedStrategy wraps.
T = TypeVar("T")
33
34
class LimitReached(BaseException):
    """Signal that a LimitedStrategy has no draws left in its current budget.

    Raised by ``LimitedStrategy.do_draw`` and caught by
    ``RecursiveStrategy.do_draw``, which then retries the draw. It derives from
    ``BaseException`` rather than ``Exception`` -- presumably so that
    ``except Exception`` handlers in user code do not swallow it (TODO confirm).
    """
37
38
class LimitedStrategy(SearchStrategy[T]):
    """Wrap a strategy so it can only be drawn a bounded number of times.

    Draws are only permitted inside :meth:`capped`, which installs a budget.
    Each successful draw spends one unit of that budget. Once the budget is
    spent, further draws raise :class:`LimitReached` instead of drawing.

    The budget (``marker``) and the "inside capped" flag
    (``currently_capped``) live on a ``threading.local``, so every thread
    sees its own values.
    """

    def __init__(self, strategy: SearchStrategy[T]):
        super().__init__()
        self.base_strategy = strategy
        # Per-thread storage backing the ``marker`` and ``currently_capped``
        # properties below.
        self._threadlocal = threading.local()

    @property
    def marker(self) -> int:
        """Number of draws still permitted on this thread (0 if never set)."""
        try:
            return self._threadlocal.marker
        except AttributeError:
            return 0

    @marker.setter
    def marker(self, remaining: int) -> None:
        self._threadlocal.marker = remaining

    @property
    def currently_capped(self) -> bool:
        """Whether this thread is inside :meth:`capped` (False if never set)."""
        try:
            return self._threadlocal.currently_capped
        except AttributeError:
            return False

    @currently_capped.setter
    def currently_capped(self, flag: bool) -> None:
        self._threadlocal.currently_capped = flag

    def __repr__(self) -> str:
        return f"LimitedStrategy({self.base_strategy!r})"

    def do_validate(self) -> None:
        # Validation is delegated entirely to the wrapped strategy.
        self.base_strategy.validate()

    def do_draw(self, data: ConjectureData) -> T:
        """Spend one unit of budget and draw from the wrapped strategy.

        Raises LimitReached when no budget is left. Drawing outside
        ``capped`` trips the assertion.
        """
        assert self.currently_capped
        remaining = self.marker
        if remaining > 0:
            self.marker = remaining - 1
            return data.draw(self.base_strategy)
        raise LimitReached

    @contextmanager
    def capped(self, max_templates: int) -> Generator[None, None, None]:
        """Allow up to ``max_templates`` draws for the duration of the block.

        On exit only the ``currently_capped`` flag is restored. ``marker``
        keeps whatever budget is left, so callers can still read it afterwards
        to see how many draws were made.
        """
        previous = self.currently_capped
        try:
            self.currently_capped = True
            self.marker = max_templates
            yield
        finally:
            self.currently_capped = previous
83
84
class RecursiveStrategy(SearchStrategy):
    """Strategy producing values built by nesting ``extend`` around ``base``.

    This implements ``recursive(base, extend, min_leaves=..., max_leaves=...)``,
    the form its ``__repr__`` renders. Every leaf is drawn from ``base``
    through a :class:`LimitedStrategy`, so one draw may use at most
    ``max_leaves`` leaves. A draw that exceeds the budget raises
    :class:`LimitReached` and is retried. When ``min_leaves`` is set, a draw
    that used too few leaves is retried a few times before the test case is
    marked invalid.
    """

    def __init__(
        self,
        base: SearchStrategy,
        extend: Callable[[SearchStrategy], SearchStrategy],
        min_leaves: int | None,
        max_leaves: int,
    ):
        super().__init__()
        self.min_leaves = min_leaves
        self.max_leaves = max_leaves
        self.base = base
        # All leaves go through this wrapper so the leaf budget can be both
        # enforced (LimitReached) and counted (its ``marker``).
        self.limited_base = LimitedStrategy(base)
        self.extend = extend

        # strategies[i] nests ``extend`` up to i levels deep around the base:
        # [base, extend(base), extend(base | extend(base)), ...].
        # Levels are added while 2 ** (len(strategies) - 1) <= max_leaves, so
        # the number of levels grows logarithmically with the leaf budget.
        strategies = [self.limited_base, self.extend(self.limited_base)]
        while 2 ** (len(strategies) - 1) <= max_leaves:
            strategies.append(extend(OneOfStrategy(tuple(strategies))))
        # If min_leaves > 1, a single bare leaf can never satisfy it, so we
        # never draw from base directly.
        if min_leaves is not None and min_leaves > 1:
            strategies = strategies[1:]
        self.strategy = OneOfStrategy(strategies)

    def __repr__(self) -> str:
        # Cached because describing ``extend`` can be relatively costly and
        # the repr is also used to build event keys in do_draw.
        if not hasattr(self, "_cached_repr"):
            self._cached_repr = (
                f"recursive({self.base!r}, "
                f"{get_pretty_function_description(self.extend)}, "
                f"min_leaves={self.min_leaves}, max_leaves={self.max_leaves})"
            )
        return self._cached_repr

    def do_validate(self) -> None:
        """Validate ``base``, ``extend`` and the leaf bounds.

        Raises:
            InvalidArgument: if ``extend`` ignores its argument while
                ``min_leaves`` was given explicitly, if either bound is not a
                positive int, or if ``min_leaves`` exceeds ``max_leaves``.
        """
        check_strategy(self.base, "base")
        extended = self.extend(self.limited_base)
        check_strategy(extended, f"extend({self.limited_base!r})")
        self.limited_base.validate()
        extended.validate()

        if is_identity_function(self.extend):
            warnings.warn(
                "extend=lambda x: x is a no-op; you probably want to use a "
                "different extend function, or just use the base strategy directly.",
                HypothesisWarning,
                stacklevel=5,
            )

        if not is_first_param_referenced_in_function(self.extend):
            msg = (
                f"extend={get_pretty_function_description(self.extend)} doesn't use "
                "its argument, and thus can't actually recurse!"
            )
            # Callers that did not pass min_leaves only get a deprecation
            # notice; passing min_leaves explicitly makes this a hard error.
            if self.min_leaves is None:
                note_deprecation(
                    msg,
                    since="2026-01-12",
                    has_codemod=False,
                    stacklevel=1,
                )
            else:
                raise InvalidArgument(msg)

        if self.min_leaves is not None:
            check_type(int, self.min_leaves, "min_leaves")
        check_type(int, self.max_leaves, "max_leaves")
        if self.min_leaves is not None and self.min_leaves <= 0:
            raise InvalidArgument(
                f"min_leaves={self.min_leaves!r} must be greater than zero"
            )
        if self.max_leaves <= 0:
            raise InvalidArgument(
                f"max_leaves={self.max_leaves!r} must be greater than zero"
            )
        # An unset min_leaves behaves like 1 for this comparison.
        if (self.min_leaves or 1) > self.max_leaves:
            raise InvalidArgument(
                f"min_leaves={self.min_leaves!r} must be less than or equal to "
                f"max_leaves={self.max_leaves!r}"
            )

    def do_draw(self, data: ConjectureData) -> Any:
        """Draw one value whose leaf count lies within the configured bounds.

        A draw that exceeds ``max_leaves`` is retried (recorded as an event).
        A draw with fewer than ``min_leaves`` leaves is retried up to 5 times
        in total, after which the test case is marked invalid.
        """
        min_leaves_retries = 0
        while True:
            # Only the capped draw itself can raise LimitReached, so the try
            # covers just that draw and the leaf count taken from its budget.
            try:
                with self.limited_base.capped(self.max_leaves):
                    result = data.draw(self.strategy)
                    # Each leaf drawn decremented the marker from max_leaves.
                    leaves_drawn = self.max_leaves - self.limited_base.marker
            except LimitReached:
                data.events[
                    f"Draw for {self!r} exceeded "
                    f"max_leaves={self.max_leaves} and had to be retried"
                ] = ""
                continue

            if self.min_leaves and leaves_drawn < self.min_leaves:
                data.events[
                    f"Draw for {self!r} had fewer than "
                    f"min_leaves={self.min_leaves} and had to be retried"
                ] = ""
                min_leaves_retries += 1
                if min_leaves_retries < 5:
                    continue
                data.mark_invalid(f"min_leaves={self.min_leaves} unsatisfied")
            return result