1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11"""A module for miscellaneous useful bits and bobs that don't
12obviously belong anywhere else. If you spot a better home for
13anything that lives here, please move it."""
14
15import array
16import gc
17import sys
18import time
19import warnings
20from array import ArrayType
21from collections.abc import Iterable, Iterator, Sequence
22from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, overload
23
24from sortedcontainers import SortedList
25
26from hypothesis.errors import HypothesisWarning
27
# Storage codes in increasing order of capacity. Every entry except the last
# is an unsigned ``array.array`` typecode; "O" is not an array typecode at
# all, but the marker ``array_or_list`` uses for a plain Python list.
ARRAY_CODES = [
    "B",  # unsigned char
    "H",  # unsigned short
    "I",  # unsigned int
    "L",  # unsigned long
    "Q",  # unsigned long long
    "O",  # arbitrary Python objects, stored in a list
]

# Generic element type shared by the helpers in this module.
T = TypeVar("T")
31
32
def array_or_list(
    code: str, contents: Iterable[int]
) -> Union[list[int], "ArrayType[int]"]:
    """Build storage for ``contents`` using the given storage code.

    The special code ``"O"`` produces a plain list; any other code is
    passed to ``array.array`` as its typecode, so errors raised by
    ``array.array`` (e.g. ``OverflowError`` for values that do not fit)
    propagate to the caller.
    """
    return list(contents) if code == "O" else array.array(code, contents)
39
40
def replace_all(
    ls: Sequence[T],
    replacements: Iterable[tuple[int, int, Sequence[T]]],
) -> list[T]:
    """Return a new list with several slices of ``ls`` swapped out.

    Each replacement is a ``(start, end, value)`` triple meaning that
    ``ls[start:end]`` is replaced by the elements of ``value``. The triples
    are applied in the order given, copying the untouched stretch between
    the end of one replacement and the start of the next.
    """
    pieces: list[T] = []
    cursor = 0
    # Length the output must have; only used for the sanity check below.
    expected_length = len(ls)
    for start, end, value in replacements:
        pieces += ls[cursor:start]
        pieces += value
        cursor = end
        expected_length += len(value) - (end - start)
    pieces += ls[cursor:]
    assert len(pieces) == expected_length
    return pieces
61
62
# Maps each storage code to the next larger one in ARRAY_CODES; the final
# code has no entry because there is nothing to upgrade to.
NEXT_ARRAY_CODE = dict(zip(ARRAY_CODES[:-1], ARRAY_CODES[1:]))
64
65
class IntList(Sequence[int]):
    """Class for storing a list of non-negative integers compactly.

    We store them as the smallest size integer array we can get
    away with. When we try to add an integer that is too large,
    we upgrade the array to the smallest word size needed to store
    the new value. Values too large for any array typecode are kept in
    a plain list instead."""

    __slots__ = ("__underlying",)

    __underlying: Union[list[int], "ArrayType[int]"]

    def __init__(self, values: Sequence[int] = ()):
        """Store ``values`` in the first storage kind that can hold them.

        Raises ``ValueError`` if the values only fit in a plain list and
        any of them is not a non-negative ``int``.
        """
        for code in ARRAY_CODES:
            try:
                underlying = array_or_list(code, values)
                break
            except OverflowError:
                pass
        else:  # pragma: no cover
            raise AssertionError(f"Could not create storage for {values!r}")
        if isinstance(underlying, list):
            # Arrays reject bad values themselves; a list accepts anything,
            # so it has to be validated by hand.
            for v in underlying:
                if not isinstance(v, int) or v < 0:
                    raise ValueError(f"Could not create IntList for {values!r}")
        self.__underlying = underlying

    @classmethod
    def of_length(cls, n: int) -> "IntList":
        """Return an IntList consisting of ``n`` zeros."""
        return cls(array_or_list("B", [0]) * n)

    def count(self, value: int) -> int:
        """Return how many times ``value`` occurs."""
        return self.__underlying.count(value)

    def __repr__(self) -> str:
        return f"IntList({list(self.__underlying)!r})"

    def __len__(self) -> int:
        return len(self.__underlying)

    @overload
    def __getitem__(self, i: int) -> int: ...  # pragma: no cover

    @overload
    def __getitem__(
        self, i: slice
    ) -> Union[list[int], "ArrayType[int]"]: ...  # pragma: no cover

    def __getitem__(
        self, i: Union[int, slice]
    ) -> Union[int, list[int], "ArrayType[int]"]:
        """Return the element at ``i``.

        For a slice this returns a slice of the underlying storage (a list
        or an array), not a new IntList.
        """
        return self.__underlying[i]

    def __delitem__(self, i: Union[int, slice]) -> None:
        del self.__underlying[i]

    def insert(self, i: int, v: int) -> None:
        """Insert ``v`` before index ``i``, upgrading the storage if needed."""
        while True:
            try:
                self.__underlying.insert(i, v)
                return
            except OverflowError:
                # Same contract as __setitem__: only a value that is too
                # large may trigger an upgrade; a negative one is an error.
                assert v > 0
                self.__upgrade()

    def __iter__(self) -> Iterator[int]:
        return iter(self.__underlying)

    def __eq__(self, other: object) -> bool:
        if self is other:
            return True
        if not isinstance(other, IntList):
            return NotImplemented
        return self.__equal_contents(other)

    def __ne__(self, other: object) -> bool:
        if self is other:
            return False
        if not isinstance(other, IntList):
            return NotImplemented
        return not self.__equal_contents(other)

    def __equal_contents(self, other: "IntList") -> bool:
        """Compare elements regardless of which storage kind each side uses."""
        mine, theirs = self.__underlying, other.__underlying
        if isinstance(mine, list) is isinstance(theirs, list):
            return mine == theirs
        # An array never compares equal to a list, even when they hold the
        # same elements, so mixed storage is compared element by element.
        return len(mine) == len(theirs) and all(
            a == b for a, b in zip(mine, theirs)
        )

    def append(self, n: int) -> None:
        """Append ``n``, upgrading the storage if it does not fit.

        The value is appended directly rather than via a placeholder, so a
        failed append leaves the list unchanged.
        """
        while True:
            try:
                self.__underlying.append(n)
                return
            except OverflowError:
                assert n > 0
                self.__upgrade()

    def __setitem__(self, i: int, n: int) -> None:
        """Set element ``i`` to ``n``, upgrading the storage if needed."""
        while True:
            try:
                self.__underlying[i] = n
                return
            except OverflowError:
                assert n > 0
                self.__upgrade()

    def extend(self, ls: Iterable[int]) -> None:
        """Append every value from ``ls`` in order."""
        for n in ls:
            self.append(n)

    def __upgrade(self) -> None:
        """Move the contents to the next larger storage kind."""
        assert isinstance(self.__underlying, array.array)
        code = NEXT_ARRAY_CODE[self.__underlying.typecode]
        self.__underlying = array_or_list(code, self.__underlying)
164
165
def binary_search(lo: int, hi: int, f: Callable[[int], bool]) -> int:
    """Locate the point in [lo, hi) where ``f`` stops agreeing with ``f(lo)``.

    Returns an ``n`` with ``f(n) == f(lo)`` and ``f(n + 1) != f(lo)``.
    The caller must guarantee that ``f(hi) != f(lo)``; this is assumed
    and never checked.
    """
    target = f(lo)

    # Invariant: f(lo) == target and f(hi) is assumed to differ from it.
    while hi - lo > 1:
        midpoint = (lo + hi) // 2
        if f(midpoint) == target:
            lo = midpoint
        else:
            hi = midpoint
    return lo
182
183
class LazySequenceCopy(Generic[T]):
    """A "copy" of a sequence that works by inserting a mask in front
    of the underlying sequence, so that you can mutate it without changing
    the underlying sequence. Effectively behaves as if you could do list(x)
    in O(1) time. The full list API is not supported yet but there's no reason
    in principle it couldn't be.

    Assignments are recorded in a dict keyed by index into the underlying
    sequence, and pops are recorded as a sorted collection of the underlying
    indices that have been removed. Both are created lazily on first use.
    """

    def __init__(self, values: Sequence[T]):
        self.__values = values
        self.__len = len(values)
        self.__mask: Optional[dict[int, T]] = None
        self.__popped_indices: Optional[SortedList[int]] = None

    def __len__(self) -> int:
        if self.__popped_indices is None:
            return self.__len
        return self.__len - len(self.__popped_indices)

    def pop(self, i: int = -1) -> T:
        """Remove and return the element at index ``i`` (default: the last).

        Raises ``IndexError`` if the copy is empty or ``i`` is out of range.
        """
        if len(self) == 0:
            raise IndexError("Cannot pop from empty list")
        i = self.__underlying_index(i)

        # Test membership explicitly rather than using None as a "missing"
        # marker: None is a perfectly valid masked value.
        if self.__mask is not None and i in self.__mask:
            v = self.__mask.pop(i)
        else:
            v = self.__values[i]

        if self.__popped_indices is None:
            self.__popped_indices = SortedList()
        self.__popped_indices.add(i)
        return v

    def swap(self, i: int, j: int) -> None:
        """Swap the elements ls[i], ls[j]."""
        if i == j:
            return
        self[i], self[j] = self[j], self[i]

    def __getitem__(self, i: int) -> T:
        i = self.__underlying_index(i)

        default = self.__values[i]
        if self.__mask is None:
            return default
        else:
            return self.__mask.get(i, default)

    def __setitem__(self, i: int, v: T) -> None:
        i = self.__underlying_index(i)
        if self.__mask is None:
            self.__mask = {}
        self.__mask[i] = v

    def __underlying_index(self, i: int) -> int:
        """Translate a (possibly negative) index into this copy into the
        corresponding index into the underlying sequence.

        Raises ``IndexError`` if ``i`` is out of range for the current length.
        """
        n = len(self)
        if i < -n or i >= n:
            raise IndexError(f"Index {i} out of range [0, {n})")
        if i < 0:
            i += n
        assert 0 <= i < n

        if self.__popped_indices is not None:
            # given an index i in the popped representation of the list, compute
            # its corresponding index in the underlying list. given
            #   l = [1, 4, 2, 10, 188]
            #   l.pop(3)
            #   l.pop(1)
            #   assert l == [1, 2, 188]
            #
            # we want l[i] == self.__values[f(i)], where f is this function.
            assert len(self.__popped_indices) <= len(self.__values)

            # Popped indices are visited in ascending order; each one at or
            # before the current position shifts that position right by one.
            for idx in self.__popped_indices:
                if idx > i:
                    break
                i += 1
        return i

    # even though we have len + getitem, mypyc requires iter.
    def __iter__(self) -> Iterator[T]:
        for i in range(len(self)):
            yield self[i]
268
269
def stack_depth_of_caller() -> int:
    """Return the stack size as seen from the frame two levels above this one.

    From https://stackoverflow.com/a/47956089/9297601 , this is a simple
    but much faster alternative to `len(inspect.stack(0))`. We use it
    with get/set recursionlimit to make stack overflows non-flaky; see
    https://github.com/HypothesisWorks/hypothesis/issues/2494 for details.
    """
    depth = 1
    # The frame offset is relative to this function, so the lookup has to
    # stay directly in this body.
    frame = sys._getframe(2)
    while frame is not None:
        depth += 1
        frame = frame.f_back  # type: ignore[assignment]
    return depth
284
285
class ensure_free_stackframes:
    """Context manager that ensures there are at least N free stackframes (for
    a reasonable value of N).

    On entry the recursion limit is set to the current stack depth plus
    2000; on exit the previous limit is restored, unless the limit was
    changed in the meantime, in which case a warning is issued instead.
    """

    def __enter__(self) -> None:
        depth = stack_depth_of_caller()
        self.old_maxdepth = sys.getrecursionlimit()
        # CPython's default recursion limit is 1000, and pytest seems to
        # raise it to 3000 while tests run, so pick something reasonable
        # relative to where we are now.
        self.new_maxdepth = depth + 2000
        # Since we are raising the limit, be a good citizen and guard
        # against unbounded recursion too. With typical limits of
        # 1000/3000 this can only fire if something really strange is going
        # on, and an intentionally deeply recursive use of this code is
        # hard to imagine.
        assert depth <= 1000, (
            "Hypothesis would usually add %d to the stack depth of %d here, "
            "but we are already much deeper than expected. Aborting now, to "
            "avoid extending the stack limit in an infinite loop..."
            % (self.new_maxdepth - self.old_maxdepth, self.old_maxdepth)
        )
        sys.setrecursionlimit(self.new_maxdepth)

    def __exit__(self, *args, **kwargs):
        if sys.getrecursionlimit() != self.new_maxdepth:  # pragma: no cover
            # Somebody else changed the limit after us; leave their value.
            warnings.warn(
                "The recursion limit will not be reset, since it was changed "
                "from another thread or during execution of a test.",
                HypothesisWarning,
                stacklevel=2,
            )
            return
        sys.setrecursionlimit(self.old_maxdepth)
320
321
def find_integer(f: Callable[[int], bool]) -> int:
    """Finds a (hopefully large) integer such that f(n) is True and f(n + 1) is
    False.

    f(0) is assumed to be True and will not be checked.
    """
    # Small answers are handled by a plain linear scan: there is little to
    # gain from cleverness when the result is tiny, and probing ahead (e.g.
    # trying 2 first when the answer is 0) would only waste calls.
    for candidate in (1, 2, 3, 4):
        if not f(candidate):
            return candidate - 1

    # f(4) holds. known_true is the largest value known to satisfy f, and
    # known_false will become the smallest value known not to.
    known_true = 4
    known_false = 5

    # Probe upwards, doubling each time, until f fails.
    while f(known_false):
        known_true, known_false = known_false, known_false * 2

    # Bisect the gap. Once the two bounds are adjacent we have f(known_true)
    # and not f(known_true + 1), as desired.
    while known_false - known_true > 1:
        midpoint = (known_true + known_false) // 2
        if f(midpoint):
            known_true = midpoint
        else:
            known_false = midpoint
    return known_true
358
359
class NotFound(Exception):
    """Raised by ``SelfOrganisingList.find`` when no value matches."""
362
363
class SelfOrganisingList(Generic[T]):
    """A self-organising list with the move-to-front heuristic.

    A self-organising list is a collection from which we want to retrieve
    items satisfying some predicate. Since predicates may be arbitrary,
    nothing beats a linear scan, but how long that scan takes varies
    enormously: hitting a good item on the first try makes it O(1). The
    idea is to keep reordering the list so that we get lucky as often as
    possible.

    Many heuristics are possible and it is not clear which is best. We use
    the simplest: every time an item is found it is moved to the "front"
    (actually the back in our implementation, because we iterate in
    reverse) of the list.

    """

    def __init__(self, values: Iterable[T] = ()) -> None:
        self.__values = list(values)

    def __repr__(self) -> str:
        return f"SelfOrganisingList({self.__values!r})"

    def add(self, value: T) -> None:
        """Add a value to this list."""
        self.__values.append(value)

    def find(self, condition: Callable[[T], bool]) -> T:
        """Returns some value in this list such that ``condition(value)`` is
        True. If no such value exists raises ``NotFound``."""
        values = self.__values
        # Scan from the back, where recently found and added items live.
        for position in reversed(range(len(values))):
            candidate = values[position]
            if not condition(candidate):
                continue
            # Move the match to the end so it is tried first next time.
            del values[position]
            values.append(candidate)
            return candidate
        raise NotFound("No values satisfying condition")
401
402
# Whether the GC timing hooks have been registered yet.
_gc_initialized = False
# Timestamp of the most recent "start" event; 0 means "no matching start".
_gc_start: float = 0
# Running total of time attributed to garbage collection.
_gc_cumulative_time: float = 0

# gc_callback may run while a test is executing, and the test might have
# monkeypatched perf_counter, so keep our own reference to the real one.
_perf_counter = time.perf_counter


def gc_cumulative_time() -> float:
    """Return the total GC time recorded so far.

    The first call registers the timing hooks (``gc.callbacks`` where
    available, otherwise ``gc.hooks``); later calls only read the total.
    """
    global _gc_initialized
    if _gc_initialized:
        return _gc_cumulative_time

    if hasattr(gc, "callbacks"):
        # CPython
        def gc_callback(
            phase: Literal["start", "stop"], info: dict[str, int]
        ) -> None:
            global _gc_start, _gc_cumulative_time
            try:
                timestamp = _perf_counter()
                if phase == "start":
                    _gc_start = timestamp
                elif phase == "stop" and _gc_start > 0:
                    _gc_cumulative_time += timestamp - _gc_start  # pragma: no cover # ??
            except RecursionError:  # pragma: no cover
                # Avoid flakiness via UnraisableException, which is caught
                # and warned by pytest. The actual callback (this function)
                # is validated to never trigger a RecursionError itself
                # when called by gc.collect.
                # Anyway, we should hit the same error on "start" and
                # "stop", but to ensure we don't get out of sync we just
                # signal that there is no matching start.
                _gc_start = 0

        gc.callbacks.insert(0, gc_callback)
    elif hasattr(gc, "hooks"):  # pragma: no cover  # pypy only
        # PyPy
        def hook(stats: Any) -> None:
            global _gc_cumulative_time
            try:
                _gc_cumulative_time += stats.duration
            except RecursionError:
                pass

        # Only claim hook slots that nobody else is using.
        if gc.hooks.on_gc_minor is None:
            gc.hooks.on_gc_minor = hook
        if gc.hooks.on_gc_collect_step is None:
            gc.hooks.on_gc_collect_step = hook

    _gc_initialized = True
    return _gc_cumulative_time
456
457
def startswith(l1: Sequence[T], l2: Sequence[T]) -> bool:
    """Return True if the sequence ``l1`` begins with the elements of ``l2``."""
    prefix_length = len(l2)
    if prefix_length > len(l1):
        return False
    for mine, theirs in zip(l1[:prefix_length], l2):
        if not (mine == theirs):
            return False
    return True
462
463
def endswith(l1: Sequence[T], l2: Sequence[T]) -> bool:
    """Return True if the sequence ``l1`` ends with the elements of ``l2``."""
    suffix_length = len(l2)
    if suffix_length > len(l1):
        return False
    # With an empty l2 the slice below is all of l1, but zip then yields
    # nothing, so the result is still True.
    for mine, theirs in zip(l1[-suffix_length:], l2):
        if not (mine == theirs):
            return False
    return True
468
469
def bits_to_bytes(n: int) -> int:
    """Return how many whole bytes are needed to hold an n-bit number.

    This is ``n / 8`` rounded up. It is written with a shift instead of
    ``(n + 7) // 8`` because the shift is slightly faster, and this
    function is called often enough for that to matter.
    """
    return (n + 7) >> 3