1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import threading
12from contextlib import contextmanager
13
14from hypothesis.errors import InvalidArgument
15from hypothesis.internal.reflection import get_pretty_function_description
16from hypothesis.internal.validation import check_type
17from hypothesis.strategies._internal.strategies import (
18 OneOfStrategy,
19 SearchStrategy,
20 check_strategy,
21)
22
23
class LimitReached(BaseException):
    """Raised by ``LimitedStrategy.do_draw`` when its draw budget is used up.

    It derives from ``BaseException`` rather than ``Exception``, presumably so
    that ordinary ``except Exception`` handlers between the raise and
    ``RecursiveStrategy.do_draw`` do not swallow it — TODO confirm.
    """
26
27
class LimitedStrategy(SearchStrategy):
    """Wrap ``strategy`` so that only a bounded number of draws succeed.

    The budget is armed by the ``capped`` context manager. Each successful
    ``do_draw`` consumes one unit, and once none remain ``do_draw`` raises
    ``LimitReached``. The budget (``marker``) and the ``currently_capped``
    flag are stored on a ``threading.local``, so each thread sees its own
    values.
    """

    def __init__(self, strategy):
        super().__init__()
        self.base_strategy = strategy
        # Backing storage for the per-thread ``marker`` / ``currently_capped``.
        self._threadlocal = threading.local()

    @property
    def marker(self):
        """Draws still allowed for the current thread (0 if never set)."""
        return getattr(self._threadlocal, "marker", 0)

    @marker.setter
    def marker(self, value):
        self._threadlocal.marker = value

    @property
    def currently_capped(self):
        """True while the current thread is inside a ``capped`` block."""
        return getattr(self._threadlocal, "currently_capped", False)

    @currently_capped.setter
    def currently_capped(self, value):
        self._threadlocal.currently_capped = value

    def __repr__(self):
        return f"LimitedStrategy({self.base_strategy!r})"

    def do_validate(self):
        self.base_strategy.validate()

    def do_draw(self, data):
        """Consume one unit of budget and draw from the wrapped strategy.

        Raises:
            LimitReached: if no budget remains for this thread.
        """
        # Internal invariant: drawing is only expected inside ``capped``.
        assert self.currently_capped
        remaining = self.marker
        if remaining <= 0:
            raise LimitReached
        self.marker = remaining - 1
        return data.draw(self.base_strategy)

    @contextmanager
    def capped(self, max_templates):
        """Allow up to ``max_templates`` draws within the ``with`` block.

        On exit the previous ``currently_capped`` flag is restored. ``marker``
        is not restored; it keeps whatever value the block left it at.
        """
        try:
            previously_capped = self.currently_capped
            self.currently_capped = True
            self.marker = max_templates
            yield
        finally:
            self.currently_capped = previously_capped
72
73
class RecursiveStrategy(SearchStrategy):
    """Strategy for ``recursive(base, extend, max_leaves)``.

    The constructor builds a tower of strategies:
    ``[limited_base, extend(limited_base)]``, then repeatedly
    ``extend(one_of(<everything so far>))`` while
    ``2 ** (len(tower) - 1) <= max_leaves``. ``self.strategy`` is a
    ``OneOfStrategy`` over that tower. Every leaf is drawn through
    ``limited_base``, whose budget is capped at ``max_leaves`` per top-level
    draw.
    """

    def __init__(self, base, extend, max_leaves):
        # Run the base initializer, as LimitedStrategy does, so per-instance
        # state set up by SearchStrategy.__init__ exists on this strategy too.
        super().__init__()
        self.max_leaves = max_leaves
        self.base = base
        self.limited_base = LimitedStrategy(base)
        self.extend = extend

        strategies = [self.limited_base, self.extend(self.limited_base)]
        # max_leaves is not validated until do_validate. Only grow the tower
        # for a genuine int: float("inf") would loop forever here, and e.g. a
        # str would raise TypeError, both pre-empting the InvalidArgument that
        # do_validate reports for non-int values.
        if isinstance(max_leaves, int):
            while 2 ** (len(strategies) - 1) <= max_leaves:
                strategies.append(extend(OneOfStrategy(tuple(strategies))))
        self.strategy = OneOfStrategy(strategies)

    def __repr__(self):
        # Computed lazily and cached on first use.
        if not hasattr(self, "_cached_repr"):
            self._cached_repr = "recursive(%r, %s, max_leaves=%d)" % (
                self.base,
                get_pretty_function_description(self.extend),
                self.max_leaves,
            )
        return self._cached_repr

    def do_validate(self):
        """Validate ``base``, ``extend(base)`` and ``max_leaves``.

        Raises:
            InvalidArgument: if ``max_leaves`` is not a positive int, or if
                ``base`` / ``extend(...)`` is not a valid strategy.
        """
        check_strategy(self.base, "base")
        extended = self.extend(self.limited_base)
        check_strategy(extended, f"extend({self.limited_base!r})")
        self.limited_base.validate()
        extended.validate()
        check_type(int, self.max_leaves, "max_leaves")
        if self.max_leaves <= 0:
            raise InvalidArgument(
                f"max_leaves={self.max_leaves!r} must be greater than zero"
            )

    def do_draw(self, data):
        """Draw from the tower with at most ``max_leaves`` base draws.

        If a draw exceeds the cap (``LimitReached``), an event is recorded
        the first time and the draw is retried. This loops until a draw fits.
        Presumably termination relies on ``data`` raising once its own budget
        runs out — TODO confirm.
        """
        count = 0
        while True:
            try:
                with self.limited_base.capped(self.max_leaves):
                    return data.draw(self.strategy)
            except LimitReached:
                if count == 0:
                    msg = f"Draw for {self!r} exceeded max_leaves and had to be retried"
                    data.events[msg] = ""
                count += 1