1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import copy
12import re
13import warnings
14from functools import cache, lru_cache, partial
15from typing import Optional
16
17from hypothesis.errors import HypothesisWarning, InvalidArgument
18from hypothesis.internal import charmap
19from hypothesis.internal.conjecture.data import ConjectureData
20from hypothesis.internal.conjecture.providers import COLLECTION_DEFAULT_MAX_SIZE
21from hypothesis.internal.filtering import max_len, min_len
22from hypothesis.internal.intervalsets import IntervalSet
23from hypothesis.internal.reflection import get_pretty_function_description
24from hypothesis.strategies._internal.collections import ListStrategy
25from hypothesis.strategies._internal.lazy import unwrap_strategies
26from hypothesis.strategies._internal.strategies import (
27 OneOfStrategy,
28 SampledFromStrategy,
29 SearchStrategy,
30)
31from hypothesis.vendor.pretty import pretty
32
33
# Cache size is limited by sys.maxunicode, but passing None makes it slightly faster.
@cache
def _check_is_single_character(c):
    """Return ``c`` unchanged if it is a string of exactly one character.

    Raises:
        InvalidArgument: if ``c`` is not a ``str``, or if its length is not 1.
    """
    # The cache is shared by every caller to keep this check cheap, which is
    # why the error messages describe only the offending value.
    if isinstance(c, str):
        if len(c) == 1:
            return c
        raise InvalidArgument(f"Got {c!r} (length {len(c)} != 1)")
    type_ = get_pretty_function_description(type(c))
    raise InvalidArgument(f"Got non-string {c!r} (type {type_})")
45
46
class OneCharStringStrategy(SearchStrategy[str]):
    """A strategy which generates single character strings of text type."""

    def __init__(
        self, intervals: IntervalSet, force_repr: Optional[str] = None
    ) -> None:
        super().__init__()
        assert isinstance(intervals, IntervalSet)
        # Optional replacement for the default repr; see __repr__.
        self._force_repr = force_repr
        # The set of characters this strategy may generate.
        self.intervals = intervals

    @classmethod
    def from_characters_args(
        cls,
        *,
        codec=None,
        min_codepoint=None,
        max_codepoint=None,
        categories=None,
        exclude_characters=None,
        include_characters=None,
    ):
        """Build a strategy from ``characters()``-style keyword arguments.

        The arguments other than ``codec`` are forwarded to ``charmap.query``;
        if ``codec`` is given, the result is intersected with the characters
        from ``charmap.intervals_from_codec(codec)``.

        Raises:
            InvalidArgument: if the arguments leave no characters to generate.
        """
        assert set(categories or ()).issubset(charmap.categories())
        intervals = charmap.query(
            min_codepoint=min_codepoint,
            max_codepoint=max_codepoint,
            categories=categories,
            exclude_characters=exclude_characters,
            include_characters=include_characters,
        )
        if codec is not None:
            intervals &= charmap.intervals_from_codec(codec)

        # Collect "name=value" for each argument worth showing in the repr.
        shown = []
        for name, value in (
            ("codec", codec),
            ("min_codepoint", min_codepoint),
            ("max_codepoint", max_codepoint),
            ("categories", categories),
            ("exclude_characters", exclude_characters),
            ("include_characters", include_characters),
        ):
            if value in (None, ""):
                continue
            # Omit categories when it is exactly "every category except Cs".
            if name == "categories" and set(value) == set(charmap.categories()) - {
                "Cs"
            }:
                continue
            shown.append(f"{name}={value!r}")
        _arg_repr = ", ".join(shown)

        if intervals:
            return cls(intervals, force_repr=f"characters({_arg_repr})")
        raise InvalidArgument(
            "No characters are allowed to be generated by this "
            f"combination of arguments: {_arg_repr}"
        )

    @classmethod
    def from_alphabet(cls, alphabet):
        """Build a strategy from an alphabet given as a string or a strategy.

        Accepted strategies (after unwrapping) are instances of this class,
        ``SampledFromStrategy`` whose elements are all single characters, and
        ``OneOfStrategy`` whose branches are themselves accepted alphabets.

        Raises:
            InvalidArgument: for any other strategy, or for a sampled element
                that is not a single-character string.
        """
        if isinstance(alphabet, str):
            return cls.from_characters_args(categories=(), include_characters=alphabet)

        assert isinstance(alphabet, SearchStrategy)
        char_strategy = unwrap_strategies(alphabet)
        if isinstance(char_strategy, cls):
            return char_strategy
        if isinstance(char_strategy, SampledFromStrategy):
            for element in char_strategy.elements:
                _check_is_single_character(element)
            return cls.from_characters_args(
                categories=(),
                include_characters=char_strategy.elements,
            )
        if isinstance(char_strategy, OneOfStrategy):
            # Union the character sets of every branch, converting each
            # branch recursively.
            combined = IntervalSet()
            for branch in char_strategy.element_strategies:
                combined = combined.union(cls.from_alphabet(branch).intervals)
            return cls(combined, force_repr=repr(alphabet))
        raise InvalidArgument(
            f"{alphabet=} must be a sampled_from() or characters() strategy"
        )

    def __repr__(self) -> str:
        if self._force_repr:
            return self._force_repr
        return f"OneCharStringStrategy({self.intervals!r})"

    def do_draw(self, data: ConjectureData) -> str:
        # A single character is a string of exactly length one.
        return data.draw_string(self.intervals, min_size=1, max_size=1)
130
131
# Names of str/bytes methods which return a truthy value for *every* nonempty
# input, so using one as a filter is equivalent to requiring a nonempty value.
# NOTE(review): "join" requires a second argument, so it cannot succeed as a
# one-argument predicate; it is kept here for backward compatibility.
_nonempty_names = (
    "capitalize",
    "expandtabs",
    "join",
    "lower",
    "splitlines",
    "swapcase",
    "title",
    "upper",
)
# Names of str/bytes methods which are falsy for the empty input *and* whose
# result also depends on the contents of a nonempty input.
#
# "split" and "rsplit" belong here rather than above: called without a
# separator they return [] for whitespace-only input such as " ".
#
# "isascii" is deliberately absent: it returns True for the empty string, so
# it does not imply that the value is nonempty.
_nonempty_and_content_names = (
    "islower",
    "isupper",
    "isalnum",
    "isalpha",
    "isdigit",
    "isspace",
    "istitle",
    "lstrip",
    "rstrip",
    "strip",
    "rsplit",
    "split",
)
157
158
class TextStrategy(ListStrategy[str]):
    """Generate strings by joining values drawn from the element strategy."""

    # See https://docs.python.org/3/library/stdtypes.html#string-methods
    # These methods always return Truthy values for any nonempty string.
    _nonempty_filters = (
        *ListStrategy._nonempty_filters,
        str,
        str.casefold,
        str.encode,
        *[getattr(str, name) for name in _nonempty_names],
    )
    _nonempty_and_content_filters = (
        str.isdecimal,
        str.isnumeric,
        *[getattr(str, name) for name in _nonempty_and_content_names],
    )

    def do_draw(self, data):
        # Only our own OneCharStringStrategy can be handed straight to
        # data.draw_string. A user-provided element strategy may define a
        # different distribution than data.draw_string, so it has to go
        # through the generic ListStrategy draw instead.
        elems = unwrap_strategies(self.element_strategy)
        if not isinstance(elems, OneCharStringStrategy):
            return "".join(super().do_draw(data))

        upper = self.max_size
        if upper == float("inf"):
            # An unbounded max_size is replaced by the default collection cap.
            upper = COLLECTION_DEFAULT_MAX_SIZE
        return data.draw_string(
            elems.intervals, min_size=self.min_size, max_size=upper
        )

    def __repr__(self) -> str:
        element_repr = repr(self.element_strategy)
        # The default alphabet is left out of the repr.
        parts = [] if element_repr == "characters()" else [element_repr]
        if self.min_size:
            parts.append(f"min_size={self.min_size}")
        if self.max_size < float("inf"):
            parts.append(f"max_size={self.max_size}")
        return "text(" + ", ".join(parts) + ")"

    def filter(self, condition):
        """Return a strategy for the strings which satisfy ``condition``.

        ``str.isidentifier`` over a OneCharStringStrategy alphabet is rewritten
        into a strategy that builds identifiers directly. Other conditions are
        offered to ``_string_filter_rewrite`` before falling back to the
        inherited ``filter``.
        """
        elems = unwrap_strategies(self.element_strategy)
        builds_identifiers = (
            condition is str.isidentifier
            and self.max_size >= 1
            and isinstance(elems, OneCharStringStrategy)
        )
        if not builds_identifiers:
            rewritten = _string_filter_rewrite(self, str, condition)
            if rewritten is None:
                return super().filter(condition)
            return rewritten

        from hypothesis.strategies import builds, nothing

        id_start, id_continue = _identifier_characters()
        start_chars = elems.intervals & id_start
        if not start_chars:
            # No permitted character can begin an identifier.
            return nothing()

        first_char = OneCharStringStrategy(start_chars)
        remainder = TextStrategy(
            OneCharStringStrategy(elems.intervals & id_continue),
            min_size=max(0, self.min_size - 1),
            max_size=self.max_size - 1,
        )
        # Filter to ensure that NFKC normalization keeps working in future
        return builds("{}{}".format, first_char, remainder).filter(str.isidentifier)
228
229
def _string_filter_rewrite(self, kind, condition):
    """Try to replace ``self.filter(condition)`` with a more direct strategy.

    Called from the ``filter`` methods of the text and bytes strategies.

    Args:
        self: the strategy being filtered. It must provide ``min_size``,
            ``max_size``, ``_nonempty_and_content_filters`` and ``filter``,
            plus ``element_strategy`` when ``kind`` is ``str``.
        kind: ``str`` or ``bytes``, the type of value ``self`` generates.
        condition: the predicate that was passed to ``.filter()``.

    Returns:
        A replacement strategy, or ``None`` when no rewrite applies and the
        caller should fall back to its ordinary filter handling.

    Warns:
        HypothesisWarning: when ``condition`` is a method that is probably a
            mistake as a filter (e.g. ``str.lower`` or ``pattern.finditer``).
    """
    if condition in (kind.lower, kind.title, kind.upper):
        k = kind.__name__
        # stacklevel=3 attributes the warning to whoever called the strategy's
        # .filter() method (this helper is one frame below that method), which
        # is consistent with the other warnings in this function.
        warnings.warn(
            f"You applied {k}.{condition.__name__} as a filter, but this allows "
            f"all nonempty strings! Did you mean {k}.is{condition.__name__}?",
            HypothesisWarning,
            stacklevel=3,
        )

    # Bound methods of a compiled regex pattern of the matching kind can be
    # turned into a regex-based strategy. For text, this is only done when the
    # alphabet is a OneCharStringStrategy.
    if (
        (
            kind is bytes
            or isinstance(
                unwrap_strategies(self.element_strategy), OneCharStringStrategy
            )
        )
        and isinstance(pattern := getattr(condition, "__self__", None), re.Pattern)
        and isinstance(pattern.pattern, kind)
    ):
        from hypothesis.strategies._internal.regex import regex_strategy

        if condition.__name__ == "match":
            # Replace with an easier-to-handle equivalent condition
            caret, close = ("^(?:", ")") if kind is str else (b"^(?:", b")")
            pattern = re.compile(caret + pattern.pattern + close, flags=pattern.flags)
            condition = pattern.search

        if condition.__name__ in ("search", "findall", "fullmatch"):
            s = regex_strategy(
                pattern,
                fullmatch=condition.__name__ == "fullmatch",
                alphabet=self.element_strategy if kind is str else None,
            )
            # The regex strategy does not know about our size bounds, so they
            # are re-applied as filters.
            if self.min_size > 0:
                s = s.filter(partial(min_len, self.min_size))
            if self.max_size < float("inf"):
                s = s.filter(partial(max_len, self.max_size))
            return s
        elif condition.__name__ in ("finditer", "scanner"):
            # PyPy implements `finditer` as an alias to their `scanner` method
            warnings.warn(
                f"You applied {pretty(condition)} as a filter, but this allows "
                f"any string at all! Did you mean .findall ?",
                HypothesisWarning,
                stacklevel=3,
            )
            return self
        elif condition.__name__ == "split":
            warnings.warn(
                f"You applied {pretty(condition)} as a filter, but this allows "
                f"any nonempty string! Did you mean .search ?",
                HypothesisWarning,
                stacklevel=3,
            )
            return self.filter(bool)

    # We use ListStrategy filter logic for the conditions that *only* imply
    # the string is nonempty. Here, we increment the min_size but still apply
    # the filter for conditions that imply nonempty *and specific contents*.
    if condition in self._nonempty_and_content_filters and self.max_size >= 1:
        # Work on a shallow copy so the caller's strategy is left unmodified.
        self = copy.copy(self)
        self.min_size = max(1, self.min_size)
        return ListStrategy.filter(self, condition)

    return None
296
297
# Excerpted from https://www.unicode.org/Public/15.0.0/ucd/PropList.txt
# Python updates its Unicode version between minor releases, but fortunately
# these properties do not change between the Unicode versions in question.
# Each data line has the form "<code point or range> ; <property> # <details>".
# _identifier_characters() extracts the code points and property name from
# these lines with a regular expression; all other lines are ignored.
_PROPLIST = """
# ================================================

1885..1886 ; Other_ID_Start # Mn [2] MONGOLIAN LETTER ALI GALI BALUDA..MONGOLIAN LETTER ALI GALI THREE BALUDA
2118 ; Other_ID_Start # Sm SCRIPT CAPITAL P
212E ; Other_ID_Start # So ESTIMATED SYMBOL
309B..309C ; Other_ID_Start # Sk [2] KATAKANA-HIRAGANA VOICED SOUND MARK..KATAKANA-HIRAGANA SEMI-VOICED SOUND MARK

# Total code points: 6

# ================================================

00B7 ; Other_ID_Continue # Po MIDDLE DOT
0387 ; Other_ID_Continue # Po GREEK ANO TELEIA
1369..1371 ; Other_ID_Continue # No [9] ETHIOPIC DIGIT ONE..ETHIOPIC DIGIT NINE
19DA ; Other_ID_Continue # No NEW TAI LUE THAM DIGIT ONE

# Total code points: 12
"""
320
321
@lru_cache
def _identifier_characters():
    """See https://docs.python.org/3/reference/lexical_analysis.html#identifiers

    Returns:
        A tuple ``(id_start, id_continue)`` of interval sets: the characters
        treated as valid at the start of an identifier, and those treated as
        valid in the rest of it (which includes every ``id_start`` character).

    The function takes no arguments, so ``lru_cache`` means the result is
    computed only once.
    """
    # Start by computing the set of special characters
    chars = {"Other_ID_Start": "", "Other_ID_Continue": ""}
    for line in _PROPLIST.splitlines():
        if m := re.match(r"([0-9A-F.]+) +; (\w+) # ", line):
            codes, prop = m.groups()
            # `codes` is either a single code point ("2118") or an inclusive
            # range ("1885..1886"). Split on the separator instead of slicing
            # a fixed four characters, so that code points written with more
            # than four hex digits are also parsed correctly.
            first, _, last = codes.partition("..")
            span = range(int(first, base=16), int(last or first, base=16) + 1)
            chars[prop] += "".join(map(chr, span))

    # Then get the basic set by Unicode category and known extras
    id_start = charmap.query(
        categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl"),
        include_characters="_" + chars["Other_ID_Start"],
    )
    id_start -= IntervalSet.from_string(
        # Magic value: the characters which NFKC-normalize to be invalid identifiers.
        # Conveniently they're all in `id_start`, so we only need to do this once.
        "\u037a\u0e33\u0eb3\u2e2f\u309b\u309c\ufc5e\ufc5f\ufc60\ufc61\ufc62\ufc63"
        "\ufdfa\ufdfb\ufe70\ufe72\ufe74\ufe76\ufe78\ufe7a\ufe7c\ufe7e\uff9e\uff9f"
    )
    id_continue = id_start | charmap.query(
        categories=("Mn", "Mc", "Nd", "Pc"),
        include_characters=chars["Other_ID_Continue"],
    )
    return id_start, id_continue
349
350
class BytesStrategy(SearchStrategy):
    """Generate byte strings via ``data.draw_bytes`` within the size bounds."""

    # Filters which only imply that the value is nonempty.
    _nonempty_filters = (
        *ListStrategy._nonempty_filters,
        bytes,
        *[getattr(bytes, name) for name in _nonempty_names],
    )
    # Filters which imply nonempty *and* constrain the contents.
    _nonempty_and_content_filters = tuple(
        getattr(bytes, name) for name in _nonempty_and_content_names
    )

    def __init__(self, min_size: int, max_size: Optional[int]):
        super().__init__()
        self.min_size = min_size
        # A missing upper bound falls back to the default collection cap.
        if max_size is None:
            max_size = COLLECTION_DEFAULT_MAX_SIZE
        self.max_size = max_size

    def do_draw(self, data):
        return data.draw_bytes(self.min_size, self.max_size)

    def filter(self, condition):
        """Return a strategy for the byte strings which satisfy ``condition``.

        A rewrite from ``_string_filter_rewrite`` is used when one applies;
        otherwise the condition is handled by ``ListStrategy.filter``.
        """
        rewritten = _string_filter_rewrite(self, bytes, condition)
        if rewritten is None:
            return ListStrategy.filter(self, condition)
        return rewritten