1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import math
12import unicodedata
13from collections import defaultdict
14from collections.abc import Callable, Iterator, Sequence
15from dataclasses import dataclass
16from functools import lru_cache
17from typing import (
18 TYPE_CHECKING,
19 Any,
20 Literal,
21 TypeAlias,
22 cast,
23)
24
25from hypothesis.internal.conjecture.choice import (
26 ChoiceNode,
27 ChoiceT,
28 choice_equal,
29 choice_from_index,
30 choice_key,
31 choice_permitted,
32 choice_to_index,
33)
34from hypothesis.internal.conjecture.data import (
35 ConjectureData,
36 ConjectureResult,
37 Spans,
38 Status,
39 _Overrun,
40 draw_choice,
41)
42from hypothesis.internal.conjecture.junkdrawer import (
43 endswith,
44 find_integer,
45 replace_all,
46 startswith,
47)
48from hypothesis.internal.conjecture.shrinking import (
49 Bytes,
50 Float,
51 Integer,
52 Ordering,
53 String,
54)
55from hypothesis.internal.conjecture.shrinking.choicetree import (
56 ChoiceTree,
57 prefix_selection_order,
58 random_selection_order,
59)
60from hypothesis.internal.floats import MAX_PRECISE_INTEGER
61
62if TYPE_CHECKING:
63 from random import Random
64
65 from hypothesis.internal.conjecture.engine import ConjectureRunner
66
# Type of the predicate a Shrinker uses to decide whether a test-function
# result is acceptable as a shrink target: it is called with a
# ConjectureResult (or an _Overrun) and returns a bool.
ShrinkPredicateT: TypeAlias = Callable[[ConjectureResult | _Overrun], bool]
68
69
def sort_key(nodes: Sequence[ChoiceNode]) -> tuple[int, tuple[int, ...]]:
    """Shortlex key ordering choice sequences from simplest to most complex.

    A sequence ``x`` sorts before ``y`` when it is shorter, or when both have
    the same length and the per-node indices produced by ``choice_to_index``
    compare lexicographically smaller.

    Why this ordering:

    1. A shorter sequence means fewer decisions were needed to construct the
       test case.
    2. At equal length, a lower choice index at some position stands for a
       simpler/smaller value there.
    3. Choices drawn early in generation can be used in more places and so
       can have a larger impact on the final result, which is why the
       lexicographic comparison prioritises reducing earlier choices over
       later ones.
    """
    length = len(nodes)
    indices = tuple(
        choice_to_index(node.value, node.constraints) for node in nodes
    )
    return length, indices
92
93
94@lru_cache(maxsize=4096)
95def _natural_simpler_chars(c, intervals):
96 """Return single-char replacements for ``c`` derived from natural text
97 transformations - case mapping (upper, lower, casefold) and unicode
98 decomposition (NFD, NFKD). We take each individual character of the
99 transformed form so that e.g. ``ß`` can shrink to ``s`` via casefold
100 even though the full case-folded form is two characters.
101
102 Only candidates which are in ``intervals`` and which have a strictly
103 smaller index in shrink order than ``c`` are returned, sorted by that
104 shrink-order index. Callers must pass a single character that is itself
105 in ``intervals``.
106 """
107 candidates: set[str] = set()
108 for form in ("NFKD", "NFD"):
109 candidates.update(unicodedata.normalize(form, c))
110 for transformed in (c.upper(), c.lower(), c.casefold()):
111 candidates.update(transformed)
112 candidates.discard(c)
113 original_idx = intervals.index_from_char_in_shrink_order(c)
114 result = sorted(
115 (intervals.index_from_char_in_shrink_order(cand), cand)
116 for cand in candidates
117 if cand in intervals
118 )
119 return [cand for idx, cand in result if idx < original_idx]
120
121
@dataclass(slots=True, frozen=False)
class ShrinkPass:
    """A single shrink pass together with its bookkeeping.

    ``function`` is the callable that ``Shrinker.step`` invokes with a
    chooser argument. The remaining fields record where the pass last left
    off and how much work it has done, which ``Shrinker.shrink`` prints in
    its debug profiling report.
    """

    function: Any
    # Display name; defaults to ``function.__name__`` in ``__post_init__``.
    name: str | None = None
    # Prefix returned by the previous ``tree.step`` call in ``Shrinker.step``;
    # deterministic (non-random) selection resumes from here.
    last_prefix: Any = ()

    # some execution statistics
    # (accumulated by ``Shrinker.step`` after every step of this pass)
    calls: int = 0
    misaligned: int = 0
    shrinks: int = 0
    deletions: int = 0

    def __post_init__(self):
        # Fall back to the wrapped function's name when no name was given.
        if self.name is None:
            self.name = self.function.__name__

    def __hash__(self):
        # Hash by name only, so passes can be used as dict keys (e.g. in
        # ``Shrinker.shrink_pass_choice_trees``). The dataclass-generated
        # ``__eq__`` compares every field, but equal instances necessarily
        # share a name, so hash/eq stay consistent.
        return hash(self.name)
140
141
class StopShrinking(Exception):
    """Raised by ``Shrinker.check_calls`` when too many test-function calls
    have been made since the last successful shrink; ``Shrinker.shrink``
    catches it and stops shrinking early."""

    pass
144
145
146class Shrinker:
147 """A shrinker is a child object of a ConjectureRunner which is designed to
148 manage the associated state of a particular shrink problem. That is, we
149 have some initial ConjectureData object and some property of interest
150 that it satisfies, and we want to find a ConjectureData object with a
151 shortlex (see sort_key above) smaller choice sequence that exhibits the same
152 property.
153
154 Currently the only property of interest we use is that the status is
155 INTERESTING and the interesting_origin takes on some fixed value, but we
156 may potentially be interested in other use cases later.
157 However we assume that data with a status < VALID never satisfies the predicate.
158
159 The shrinker keeps track of a value shrink_target which represents the
160 current best known ConjectureData object satisfying the predicate.
161 It refines this value by repeatedly running *shrink passes*, which are
162 methods that perform a series of transformations to the current shrink_target
163 and evaluate the underlying test function to find new ConjectureData
164 objects. If any of these satisfy the predicate, the shrink_target
165 is updated automatically. Shrinking runs until no shrink pass can
166 improve the shrink_target, at which point it stops. It may also be
167 terminated if the underlying engine throws RunIsComplete, but that
168 is handled by the calling code rather than the Shrinker.
169
170 =======================
171 Designing Shrink Passes
172 =======================
173
174 Generally a shrink pass is just any function that calls
175 cached_test_function and/or consider_new_nodes a number of times,
176 but there are a couple of useful things to bear in mind.
177
178 A shrink pass *makes progress* if running it changes self.shrink_target
179 (i.e. it tries a shortlex smaller ConjectureData object satisfying
180 the predicate). The desired end state of shrinking is to find a
181 value such that no shrink pass can make progress, i.e. that we
182 are at a local minimum for each shrink pass.
183
    In aid of this goal, the main invariant that a shrink pass must
    satisfy is that whether it makes progress must be deterministic.
186 It is fine (encouraged even) for the specific progress it makes
187 to be non-deterministic, but if you run a shrink pass, it makes
188 no progress, and then you immediately run it again, it should
189 never succeed on the second time. This allows us to stop as soon
190 as we have run each shrink pass and seen no progress on any of
191 them.
192
193 This means that e.g. it's fine to try each of N deletions
194 or replacements in a random order, but it's not OK to try N random
195 deletions (unless you have already shrunk at least once, though we
196 don't currently take advantage of this loophole).
197
198 Shrink passes need to be written so as to be robust against
199 change in the underlying shrink target. It is generally safe
200 to assume that the shrink target does not change prior to the
    point of first modification - e.g. if you change no choices at or
    before index ``i``, all spans whose start is ``<= i`` still exist,
    and the choice sequence is still of length ``>= i + 1``. This can
    only be violated by bad user code which
205 relies on an external source of non-determinism.
206
207 When the underlying shrink_target changes, shrink
208 passes should not run substantially more test_function calls
209 on success than they do on failure. Say, no more than a constant
210 factor more. In particular shrink passes should not iterate to a
211 fixed point.
212
213 This means that shrink passes are often written with loops that
214 are carefully designed to do the right thing in the case that no
215 shrinks occurred and try to adapt to any changes to do a reasonable
216 job. e.g. say we wanted to write a shrink pass that tried deleting
217 each individual choice (this isn't an especially good pass,
218 but it leads to a simple illustrative example), we might do it
219 by iterating over the choice sequence like so:
220
221 .. code-block:: python
222
223 i = 0
224 while i < len(self.shrink_target.nodes):
225 if not self.consider_new_nodes(
226 self.shrink_target.nodes[:i] + self.shrink_target.nodes[i + 1 :]
227 ):
228 i += 1
229
230 The reason for writing the loop this way is that i is always a
231 valid index into the current choice sequence, even if the current sequence
232 changes as a result of our actions. When the choice sequence changes,
233 we leave the index where it is rather than restarting from the
234 beginning, and carry on. This means that the number of steps we
235 run in this case is always bounded above by the number of steps
236 we would run if nothing works.
237
238 Another thing to bear in mind about shrink pass design is that
239 they should prioritise *progress*. If you have N operations that
240 you need to run, you should try to order them in such a way as
241 to avoid stalling, where you have long periods of test function
242 invocations where no shrinks happen. This is bad because whenever
243 we shrink we reduce the amount of work the shrinker has to do
244 in future, and often speed up the test function, so we ideally
    want those shrinks to happen much earlier in the process.
246
247 Sometimes stalls are inevitable of course - e.g. if the pass
248 makes no progress, then the entire thing is just one long stall,
249 but it's helpful to design it so that stalls are less likely
250 in typical behaviour.
251
252 The two easiest ways to do this are:
253
254 * Just run the N steps in random order. As long as a
255 reasonably large proportion of the operations succeed, this
256 guarantees the expected stall length is quite short. The
    bookkeeping for making sure this does the right thing when
258 it succeeds can be quite annoying.
259 * When you have any sort of nested loop, loop in such a way
260 that both loop variables change each time. This prevents
261 stalls which occur when one particular value for the outer
262 loop is impossible to make progress on, rendering the entire
263 inner loop into a stall.
264
265 However, although progress is good, too much progress can be
266 a bad sign! If you're *only* seeing successful reductions,
267 that's probably a sign that you are making changes that are
268 too timid. Two useful things to offset this:
269
270 * It's worth writing shrink passes which are *adaptive*, in
271 the sense that when operations seem to be working really
272 well we try to bundle multiple of them together. This can
273 often be used to turn what would be O(m) successful calls
274 into O(log(m)).
275 * It's often worth trying one or two special minimal values
276 before trying anything more fine grained (e.g. replacing
277 the whole thing with zero).
278
279 """
280
281 def derived_value(fn):
282 """It's useful during shrinking to have access to derived values of
283 the current shrink target.
284
285 This decorator allows you to define these as cached properties. They
286 are calculated once, then cached until the shrink target changes, then
287 recalculated the next time they are used."""
288
289 def accept(self):
290 try:
291 return self.__derived_values[fn.__name__]
292 except KeyError:
293 return self.__derived_values.setdefault(fn.__name__, fn(self))
294
295 accept.__name__ = fn.__name__
296 return property(accept)
297
298 def __init__(
299 self,
300 engine: "ConjectureRunner",
301 initial: ConjectureData | ConjectureResult,
302 predicate: ShrinkPredicateT | None,
303 *,
304 allow_transition: (
305 Callable[[ConjectureData | ConjectureResult, ConjectureData], bool] | None
306 ),
307 explain: bool,
308 in_target_phase: bool = False,
309 ):
310 """Create a shrinker for a particular engine, with a given starting
311 point and predicate. When shrink() is called it will attempt to find an
312 example for which predicate is True and which is strictly smaller than
313 initial.
314
315 Note that initial is a ConjectureData object, and predicate
316 takes ConjectureData objects.
317 """
318 assert predicate is not None or allow_transition is not None
319 self.engine = engine
320 self.__predicate = predicate or (lambda data: True)
321 self.__allow_transition = allow_transition or (lambda source, destination: True)
322 self.__derived_values: dict = {}
323
324 self.initial_size = len(initial.choices)
325 # We keep track of the current best example on the shrink_target
326 # attribute.
327 self.shrink_target = initial
328 self.clear_change_tracking()
329 self.shrinks = 0
330
331 # We terminate shrinks that seem to have reached their logical
332 # conclusion: If we've called the underlying test function at
333 # least self.max_stall times since the last time we shrunk,
334 # it's time to stop shrinking.
335 self.max_stall = 200
336 self.initial_calls = self.engine.call_count
337 self.initial_misaligned = self.engine.misaligned_count
338 self.calls_at_last_shrink = self.initial_calls
339
340 self.shrink_passes: list[ShrinkPass] = [
341 ShrinkPass(self.try_trivial_spans),
342 self.node_program("X" * 5),
343 self.node_program("X" * 4),
344 self.node_program("X" * 3),
345 self.node_program("X" * 2),
346 self.node_program("X" * 1),
347 ShrinkPass(self.pass_to_descendant),
348 ShrinkPass(self.reorder_spans),
349 ShrinkPass(self.minimize_duplicated_choices),
350 ShrinkPass(self.minimize_individual_choices),
351 ShrinkPass(self.redistribute_numeric_pairs),
352 ShrinkPass(self.lower_integers_together),
353 ShrinkPass(self.lower_duplicated_characters),
354 ShrinkPass(self.normalize_unicode_chars),
355 ]
356
357 # Because the shrinker is also used to `pareto_optimise` in the target phase,
358 # we sometimes want to allow extending buffers instead of aborting at the end.
359 self.__extend: Literal["full"] | int = "full" if in_target_phase else 0
360 self.should_explain = explain
361
362 @derived_value # type: ignore
363 def cached_calculations(self):
364 return {}
365
366 def cached(self, *keys):
367 def accept(f):
368 cache_key = (f.__name__, *keys)
369 try:
370 return self.cached_calculations[cache_key]
371 except KeyError:
372 return self.cached_calculations.setdefault(cache_key, f())
373
374 return accept
375
376 @property
377 def calls(self) -> int:
378 """Return the number of calls that have been made to the underlying
379 test function."""
380 return self.engine.call_count
381
382 @property
383 def misaligned(self) -> int:
384 return self.engine.misaligned_count
385
386 def check_calls(self) -> None:
387 if self.calls - self.calls_at_last_shrink >= self.max_stall:
388 raise StopShrinking
389
390 def cached_test_function(
391 self, nodes: Sequence[ChoiceNode]
392 ) -> tuple[bool, ConjectureResult | _Overrun | None]:
393 nodes = nodes[: len(self.nodes)]
394
395 if startswith(nodes, self.nodes):
396 return (True, None)
397
398 if sort_key(self.nodes) < sort_key(nodes):
399 return (False, None)
400
401 # sometimes our shrinking passes try obviously invalid things. We handle
402 # discarding them in one place here.
403 if any(not choice_permitted(node.value, node.constraints) for node in nodes):
404 return (False, None)
405
406 result = self.engine.cached_test_function(
407 [n.value for n in nodes], extend=self.__extend
408 )
409 previous = self.shrink_target
410 self.incorporate_test_data(result)
411 self.check_calls()
412 return (previous is not self.shrink_target, result)
413
414 def consider_new_nodes(self, nodes: Sequence[ChoiceNode]) -> bool:
415 return self.cached_test_function(nodes)[0]
416
417 def incorporate_test_data(self, data):
418 """Takes a ConjectureData or Overrun object updates the current
419 shrink_target if this data represents an improvement over it."""
420 if data.status < Status.VALID or data is self.shrink_target:
421 return
422 if (
423 self.__predicate(data)
424 and sort_key(data.nodes) < sort_key(self.shrink_target.nodes)
425 and self.__allow_transition(self.shrink_target, data)
426 ):
427 self.update_shrink_target(data)
428
429 def debug(self, msg: str) -> None:
430 self.engine.debug(msg)
431
432 @property
433 def random(self) -> "Random":
434 return self.engine.random
435
436 def shrink(self) -> None:
437 """Run the full set of shrinks and update shrink_target.
438
439 This method is "mostly idempotent" - calling it twice is unlikely to
440 have any effect, though it has a non-zero probability of doing so.
441 """
442
443 try:
444 self.initial_coarse_reduction()
445 self.greedy_shrink()
446 except StopShrinking:
447 # If we stopped shrinking because we're making slow progress (instead of
448 # reaching a local optimum), don't run the explain-phase logic.
449 self.should_explain = False
450 finally:
451 if self.engine.report_debug_info:
452
453 def s(n):
454 return "s" if n != 1 else ""
455
456 total_deleted = self.initial_size - len(self.shrink_target.choices)
457 calls = self.engine.call_count - self.initial_calls
458 misaligned = self.engine.misaligned_count - self.initial_misaligned
459
460 self.debug(
461 "---------------------\n"
462 "Shrink pass profiling\n"
463 "---------------------\n\n"
464 f"Shrinking made a total of {calls} call{s(calls)} of which "
465 f"{self.shrinks} shrank and {misaligned} were misaligned. This "
466 f"deleted {total_deleted} choices out of {self.initial_size}."
467 )
468 for useful in [True, False]:
469 self.debug("")
470 if useful:
471 self.debug("Useful passes:")
472 else:
473 self.debug("Useless passes:")
474 self.debug("")
475 for pass_ in sorted(
476 self.shrink_passes,
477 key=lambda t: (-t.calls, t.deletions, t.shrinks),
478 ):
479 if pass_.calls == 0:
480 continue
481 if (pass_.shrinks != 0) != useful:
482 continue
483
484 self.debug(
485 f" * {pass_.name} made {pass_.calls} call{s(pass_.calls)} of which "
486 f"{pass_.shrinks} shrank and {pass_.misaligned} were misaligned, "
487 f"deleting {pass_.deletions} choice{s(pass_.deletions)}."
488 )
489 self.debug("")
490 self.explain()
491
492 def explain(self) -> None:
493
494 if not self.should_explain or not self.shrink_target.arg_slices:
495 return
496
497 self.max_stall = 2**100
498 shrink_target = self.shrink_target
499 nodes = self.nodes
500 choices = self.choices
501 chunks: dict[tuple[int, int], list[tuple[ChoiceT, ...]]] = defaultdict(list)
502
503 # Before we start running experiments, let's check for known inputs which would
504 # make them redundant. The shrinking process means that we've already tried many
505 # variations on the minimal example, so this can save a lot of time.
506 seen_passing_seq = self.engine.passing_choice_sequences(
507 prefix=self.nodes[: min(self.shrink_target.arg_slices)[0]]
508 )
509
510 # Now that we've shrunk to a minimal failing example, it's time to try
511 # varying each part that we've noted will go in the final report. Consider
512 # slices in largest-first order
513 for start, end in sorted(
514 self.shrink_target.arg_slices, key=lambda x: (-(x[1] - x[0]), x)
515 ):
516 # Check for any previous examples that match the prefix and suffix,
517 # so we can skip if we found a passing example while shrinking.
518 if any(
519 startswith(seen, nodes[:start]) and endswith(seen, nodes[end:])
520 for seen in seen_passing_seq
521 ):
522 continue
523
524 # Skip slices that are subsets of already-explained slices.
525 # If a larger slice can vary freely, so can its sub-slices.
526 # Note: (0, 0) is a special marker for the "together" comment that
527 # applies to the whole test, not a specific slice, so we exclude it.
528 if any(
529 s <= start and end <= e
530 for s, e in self.shrink_target.slice_comments
531 if (s, e) != (0, 0)
532 ):
533 continue
534
535 # Try a few targeted candidates before falling back to random sampling,
536 # so that simple cases like ``assert n1 == n2`` -- where the only
537 # passing value of ``n1`` is exactly ``n2``'s value -- aren't reported
538 # as freely-variable just because random sampling missed it.
539 candidates = list(self._explain_candidates(start, end))
540
541 # Run our experiments
542 n_same_failures = 0
543 note = "or any other generated value"
544 # TODO: is 100 same-failures out of 500 attempts a good heuristic?
545 for n_attempt in range(500 + len(candidates)): # pragma: no branch
546 # no-branch here because we don't coverage-test the abort-at-500 logic.
547
548 if n_attempt - 10 - len(candidates) > n_same_failures * 5:
549 # stop early if we're seeing mostly invalid examples
550 break # pragma: no cover
551
552 if n_attempt < len(candidates):
553 replacement = list(candidates[n_attempt])
554 else:
555 # replace start:end with random values
556 replacement = []
557 for i in range(start, end):
558 node = nodes[i]
559 if not node.was_forced:
560 value = draw_choice(
561 node.type, node.constraints, random=self.random
562 )
563 node = node.copy(with_value=value)
564 replacement.append(node.value)
565
566 attempt = choices[:start] + tuple(replacement) + choices[end:]
567 result = self.engine.cached_test_function(attempt, extend="full")
568
569 if result.status is Status.OVERRUN:
570 continue # pragma: no cover # flakily covered
571 result = cast(ConjectureResult, result)
572 if not (
573 len(attempt) == len(result.choices)
574 and endswith(result.nodes, nodes[end:])
575 ):
576 # Turns out this was a variable-length part, so grab the infix...
577 for span1, span2 in zip(
578 shrink_target.spans, result.spans, strict=False
579 ):
580 assert span1.start == span2.start
581 assert span1.start <= start
582 if span1.start == start and span1.end == end:
583 result_end = span2.end
584 break
585 else:
586 raise NotImplementedError("Expected matching prefixes")
587
588 attempt = (
589 choices[:start]
590 + result.choices[start:result_end]
591 + choices[end:]
592 )
593 chunks[(start, end)].append(result.choices[start:result_end])
594 result = self.engine.cached_test_function(attempt)
595
596 if result.status is Status.OVERRUN:
597 continue # pragma: no cover # flakily covered
598 result = cast(ConjectureResult, result)
599 else:
600 chunks[(start, end)].append(result.choices[start:end])
601
602 if shrink_target is not self.shrink_target: # pragma: no cover
603 # If we've shrunk further without meaning to, bail out.
604 self.shrink_target.slice_comments.clear()
605 return
606 if result.status is Status.VALID:
607 # The test passed, indicating that this param can't vary freely.
608 # However, it's really hard to write a simple and reliable covering
609 # test, because of our `seen_passing_buffers` check above.
610 break # pragma: no cover
611 if self.__predicate(result): # pragma: no branch
612 n_same_failures += 1
613 if n_same_failures >= 100:
614 self.shrink_target.slice_comments[(start, end)] = note
615 break
616
617 # Finally, if we've found multiple independently-variable parts, check whether
618 # they can all be varied together.
619 if len(self.shrink_target.slice_comments) <= 1:
620 return
621 n_same_failures_together = 0
622 # Only include slices that were actually added to slice_comments
623 chunks_by_start_index = sorted(
624 (k, v) for k, v in chunks.items() if k in self.shrink_target.slice_comments
625 )
626 for _ in range(500): # pragma: no branch
627 # no-branch here because we don't coverage-test the abort-at-500 logic.
628 new_choices: list[ChoiceT] = []
629 prev_end = 0
630 for (start, end), ls in chunks_by_start_index:
631 assert prev_end <= start < end, "these chunks must be nonoverlapping"
632 new_choices.extend(choices[prev_end:start])
633 new_choices.extend(self.random.choice(ls))
634 prev_end = end
635
636 result = self.engine.cached_test_function(new_choices)
637
638 # This *can't* be a shrink because none of the components were.
639 assert shrink_target is self.shrink_target
640 if result.status == Status.VALID:
641 self.shrink_target.slice_comments[(0, 0)] = (
642 "The test sometimes passed when commented parts were varied together."
643 )
644 break # Test passed, this param can't vary freely.
645 if self.__predicate(result): # pragma: no branch
646 n_same_failures_together += 1
647 if n_same_failures_together >= 100:
648 self.shrink_target.slice_comments[(0, 0)] = (
649 "The test always failed when commented parts were varied together."
650 )
651 break
652
653 def _explain_candidates(
654 self, start: int, end: int
655 ) -> "Iterator[tuple[ChoiceT, ...]]":
656 """Yield deterministic candidate replacements for ``nodes[start:end]``.
657
658 Random sampling alone misses cases like ``assert n1 == n2``, where the
659 only passing value of ``n1`` is exactly ``n2``'s value. We try
660 substituting values from each other arg slice with matching length and
661 types, which catches such comparisons. Invalid borrowed values just
662 produce an irrelevant test result the outer loop discards.
663 """
664 nodes = self.nodes
665 target_types = tuple(nodes[i].type for i in range(start, end))
666 current_keys = tuple(choice_key(nodes[i].value) for i in range(start, end))
667 seen: set[tuple[Any, ...]] = {current_keys}
668 for s2, e2 in sorted(self.shrink_target.arg_slices):
669 if (s2, e2) == (start, end) or (e2 - s2) != (end - start):
670 continue
671 if tuple(nodes[s2 + j].type for j in range(end - start)) != target_types:
672 continue
673 borrowed = tuple(nodes[s2 + j].value for j in range(end - start))
674 key = tuple(choice_key(v) for v in borrowed)
675 if key in seen:
676 continue
677 seen.add(key)
678 yield borrowed
679
680 def greedy_shrink(self) -> None:
681 """Run a full set of greedy shrinks (that is, ones that will only ever
682 move to a better target) and update shrink_target appropriately.
683
684 This method iterates to a fixed point and so is idempontent - calling
685 it twice will have exactly the same effect as calling it once.
686 """
687 self.fixate_shrink_passes(self.shrink_passes)
688
689 def initial_coarse_reduction(self):
690 """Performs some preliminary reductions that should not be
691 repeated as part of the main shrink passes.
692
693 The main reason why these can't be included as part of shrink
694 passes is that they have much more ability to make the test
695 case "worse". e.g. they might rerandomise part of it, significantly
696 increasing the value of individual nodes, which works in direct
697 opposition to the lexical shrinking and will frequently undo
698 its work.
699 """
700 self.reduce_each_alternative()
701
702 @derived_value # type: ignore
703 def spans_starting_at(self):
704 result = [[] for _ in self.shrink_target.nodes]
705 for i, ex in enumerate(self.spans):
706 # We can have zero-length spans that start at the end
707 if ex.start < len(result):
708 result[ex.start].append(i)
709 return tuple(map(tuple, result))
710
    def reduce_each_alternative(self):
        """This is a pass that is designed to rerandomise use of the
        one_of strategy or things that look like it, in order to try
        to move from later strategies to earlier ones in the branch
        order.

        It does this by trying to systematically lower each value it
        finds that looks like it might be the branch decision for
        one_of, and then attempts to repair any changes in shape that
        this causes.

        Candidates are unforced integer choices whose ``min_value`` is 0
        and whose current value is at most 10. Each candidate is first set
        to zero. Only if that changes the shape of the test case do we try
        every smaller value in turn via ``try_lower_node_as_alternative``.
        """
        i = 0
        while i < len(self.shrink_target.nodes):
            # Re-read the target every iteration: earlier steps may have
            # shrunk it.
            nodes = self.shrink_target.nodes
            node = nodes[i]
            if (
                node.type == "integer"
                and not node.was_forced
                and node.value <= 10
                and node.constraints["min_value"] == 0
            ):
                assert isinstance(node.value, int)

                # We've found a plausible candidate for a ``one_of`` choice.
                # We now want to see if the shape of the test case actually depends
                # on it. If it doesn't, then we don't need to do this (comparatively
                # costly) pass, and can let much simpler lexicographic reduction
                # handle it later.
                #
                # We test this by trying to set the value to zero and seeing if the
                # shape changes, as measured by either changing the number of subsequent
                # nodes, or changing the nodes in such a way as to cause one of the
                # previous values to no longer be valid in its position.
                zero_attempt = self.cached_test_function(
                    nodes[:i] + (nodes[i].copy(with_value=0),) + nodes[i + 1 :]
                )[1]
                # Only analyse an actual engine result (cached_test_function
                # returns None when it skipped the engine) that did not itself
                # become the new shrink target and is at least VALID.
                if (
                    zero_attempt is not self.shrink_target
                    and zero_attempt is not None
                    and zero_attempt.status >= Status.VALID
                ):
                    changed_shape = len(zero_attempt.nodes) != len(nodes)

                    if not changed_shape:
                        # Same length: the shape still changed if any later
                        # node has a different type, or if its original value
                        # is not permitted under the new node's constraints.
                        for j in range(i + 1, len(nodes)):
                            zero_node = zero_attempt.nodes[j]
                            orig_node = nodes[j]
                            if (
                                zero_node.type != orig_node.type
                                or not choice_permitted(
                                    orig_node.value, zero_node.constraints
                                )
                            ):
                                changed_shape = True
                                break
                    if changed_shape:
                        # Try each smaller value, smallest first, and stop at
                        # the first one that succeeds.
                        for v in range(node.value):
                            if self.try_lower_node_as_alternative(i, v):
                                break
            i += 1
771
    def try_lower_node_as_alternative(self, i, v):
        """Attempt to lower `self.shrink_target.nodes[i]` to `v`,
        while rerandomising and attempting to repair any subsequent
        changes to the shape of the test case that this causes.

        Returns True as soon as an attempt succeeds (the plain substitution
        is accepted, or the shrink target changes). Otherwise returns False.
        """
        nodes = self.shrink_target.nodes
        # First try the plain substitution, leaving every other node as-is.
        if self.consider_new_nodes(
            nodes[:i] + (nodes[i].copy(with_value=v),) + nodes[i + 1 :]
        ):
            return True

        prefix = nodes[:i] + (nodes[i].copy(with_value=v),)
        initial = self.shrink_target
        # Indices of the spans that start at position i in the current target.
        spans = self.spans_starting_at[i]
        # Up to three times: keep the lowered prefix and let the engine extend
        # it (extend=len(nodes)). Then, for each span starting at i, splice
        # the attempt's contents for that span back between the original
        # prefix and the original nodes after the span.
        for _ in range(3):
            random_attempt = self.engine.cached_test_function(
                [n.value for n in prefix], extend=len(nodes)
            )
            if random_attempt.status < Status.VALID:
                continue
            self.incorporate_test_data(random_attempt)
            for j in spans:
                initial_span = initial.spans[j]
                attempt_span = random_attempt.spans[j]
                contents = random_attempt.nodes[attempt_span.start : attempt_span.end]
                self.consider_new_nodes(
                    nodes[:i] + contents + nodes[initial_span.end :]
                )
                if initial is not self.shrink_target:
                    return True
        return False
802
803 @derived_value # type: ignore
804 def shrink_pass_choice_trees(self) -> dict[Any, ChoiceTree]:
805 return defaultdict(ChoiceTree)
806
807 def step(self, shrink_pass: ShrinkPass, *, random_order: bool = False) -> bool:
808 tree = self.shrink_pass_choice_trees[shrink_pass]
809 if tree.exhausted:
810 return False
811
812 initial_shrinks = self.shrinks
813 initial_calls = self.calls
814 initial_misaligned = self.misaligned
815 size = len(self.shrink_target.choices)
816 assert shrink_pass.name is not None
817 self.engine.explain_next_call_as(shrink_pass.name)
818
819 if random_order:
820 selection_order = random_selection_order(self.random)
821 else:
822 selection_order = prefix_selection_order(shrink_pass.last_prefix)
823
824 try:
825 shrink_pass.last_prefix = tree.step(
826 selection_order,
827 lambda chooser: shrink_pass.function(chooser),
828 )
829 finally:
830 shrink_pass.calls += self.calls - initial_calls
831 shrink_pass.misaligned += self.misaligned - initial_misaligned
832 shrink_pass.shrinks += self.shrinks - initial_shrinks
833 shrink_pass.deletions += size - len(self.shrink_target.choices)
834 self.engine.clear_call_explanation()
835 return True
836
    def fixate_shrink_passes(self, passes: list[ShrinkPass]) -> None:
        """Run steps from each pass in ``passes`` until the current shrink target
        is a fixed point of all of them.

        Each outer iteration visits every pass in turn. It steps the pass until
        either its choice tree is exhausted or it has had ``max_failures`` (20)
        consecutive failing steps. Steps that made no test-function calls are
        not counted as failures.

        At the end of each outer iteration, ``passes`` is re-sorted *in place*
        (the caller's list is mutated):
        * first, passes that reduced the number of choices;
        * then, passes that changed the target without reducing its length;
        * last, passes that did not change the target.

        The loop stops once a full outer iteration runs no steps at all.
        """
        any_ran = True
        while any_ran:
            any_ran = False

            # Sort key per pass for this iteration: -1 = reduced the length,
            # 0 = changed the target otherwise, 1 = left the target unchanged.
            reordering = {}

            # We run remove_discarded after every pass to do cleanup
            # keeping track of whether that actually works. Either there is
            # no discarded data and it is basically free, or it reliably works
            # and deletes data, or it doesn't work. In that latter case we turn
            # it off for the rest of this loop through the passes, but will
            # try again once all of the passes have been run.
            can_discard = self.remove_discarded()

            calls_at_loop_start = self.calls

            # We keep track of how many calls can be made by a single step
            # without making progress and use this to test how much to pad
            # out self.max_stall by as we go along.
            max_calls_per_failing_step = 1

            for sp in passes:
                if can_discard:
                    can_discard = self.remove_discarded()

                # Remember the target before this pass so we can classify the
                # pass for reordering afterwards.
                before_sp = self.shrink_target

                # Run the shrink pass until it fails to make any progress
                # max_failures times in a row. This implicitly boosts shrink
                # passes that are more likely to work.
                failures = 0
                max_failures = 20
                while failures < max_failures:
                    # We don't allow more than max_stall consecutive failures
                    # to shrink, but this means that if we're unlucky and the
                    # shrink passes are in a bad order where only the ones at
                    # the end are useful, if we're not careful this heuristic
                    # might stop us before we've tried everything. In order to
                    # avoid that happening, we make sure that there's always
                    # plenty of breathing room to make it through a single
                    # iteration of the fixate_shrink_passes loop.
                    self.max_stall = max(
                        self.max_stall,
                        2 * max_calls_per_failing_step
                        + (self.calls - calls_at_loop_start),
                    )

                    prev = self.shrink_target
                    initial_calls = self.calls
                    # It's better for us to run shrink passes in a deterministic
                    # order, to avoid repeat work, but this can cause us to create
                    # long stalls when there are a lot of steps which fail to do
                    # anything useful. In order to avoid this, once we've noticed
                    # we're in a stall (i.e. half of max_failures calls have failed
                    # to do anything) we switch to randomly jumping around. If we
                    # find a success then we'll resume deterministic order from
                    # there which, with any luck, is in a new good region.
                    if not self.step(sp, random_order=failures >= max_failures // 2):
                        # step returns False when there is nothing to do because
                        # the entire choice tree is exhausted. If this happens
                        # we break because we literally can't run this pass any
                        # more than we already have until something else makes
                        # progress.
                        break
                    any_ran = True

                    # Don't count steps that didn't actually try to do
                    # anything as failures. Otherwise, this call is a failure
                    # if it failed to make any changes to the shrink target.
                    if initial_calls != self.calls:
                        if prev is not self.shrink_target:
                            failures = 0
                        else:
                            max_calls_per_failing_step = max(
                                max_calls_per_failing_step, self.calls - initial_calls
                            )
                            failures += 1

                # We reorder the shrink passes so that on our next run through
                # we try good ones first. The rule is that shrink passes that
                # did nothing useful are the worst, shrink passes that reduced
                # the length are the best.
                if self.shrink_target is before_sp:
                    reordering[sp] = 1
                elif len(self.choices) < len(before_sp.choices):
                    reordering[sp] = -1
                else:
                    reordering[sp] = 0

            # In-place, stable sort: passes with equal keys keep their order.
            passes.sort(key=reordering.__getitem__)
930
931 @property
932 def nodes(self) -> tuple[ChoiceNode, ...]:
933 return self.shrink_target.nodes
934
935 @property
936 def choices(self) -> tuple[ChoiceT, ...]:
937 return self.shrink_target.choices
938
939 @property
940 def spans(self) -> Spans:
941 return self.shrink_target.spans
942
943 @derived_value # type: ignore
944 def spans_by_label(self):
945 """
946 A mapping of labels to a list of spans with that label. Spans in the list
947 are ordered by their normal index order.
948 """
949
950 spans_by_label = defaultdict(list)
951 for ex in self.spans:
952 spans_by_label[ex.label].append(ex)
953 return dict(spans_by_label)
954
955 @derived_value # type: ignore
956 def distinct_labels(self):
957 return sorted(self.spans_by_label, key=str)
958
    def pass_to_descendant(self, chooser):
        """Attempt to replace each span with a descendant span.

        This is designed to deal with strategies that call themselves
        recursively. For example, suppose we had:

        binary_tree = st.deferred(
            lambda: st.one_of(
                st.integers(), st.tuples(binary_tree, binary_tree)))

        This pass guarantees that we can replace any binary tree with one of
        its subtrees - each of those will create an interval that the parent
        could validly be replaced with, and this pass will try doing that.

        This is pretty expensive - it takes O(len(intervals)^2) - so we run it
        late in the process when we've got the number of intervals as far down
        as possible.
        """

        # Only labels with at least two spans can contain an ancestor /
        # descendant pair sharing that label.
        label = chooser.choose(
            self.distinct_labels, lambda l: len(self.spans_by_label[l]) >= 2
        )

        spans = self.spans_by_label[label]
        i = chooser.choose(range(len(spans) - 1))
        ancestor = spans[i]

        # The same-label spans are in index order (see spans_by_label). If the
        # next one already starts at or after the ancestor's end, nothing with
        # this label is nested inside the ancestor.
        if i + 1 == len(spans) or spans[i + 1].start >= ancestor.end:
            return

        # Memoised under the key (label, i) by self.cached (defined elsewhere).
        @self.cached(label, i)
        def descendants():
            # Binary search for hi = the first index after i whose span starts
            # at or beyond the ancestor's end. Every span in spans[i + 1 : hi]
            # therefore starts inside the ancestor.
            lo = i + 1
            hi = len(spans)
            while lo + 1 < hi:
                mid = (lo + hi) // 2
                if spans[mid].start >= ancestor.end:
                    hi = mid
                else:
                    lo = mid
            # Only strictly smaller spans make the replacement a shrink.
            return [
                span
                for span in spans[i + 1 : hi]
                if span.choice_count < ancestor.choice_count
            ]

        descendant = chooser.choose(descendants, lambda ex: ex.choice_count > 0)

        assert ancestor.start <= descendant.start
        assert ancestor.end >= descendant.end
        assert descendant.choice_count < ancestor.choice_count

        # Splice the descendant's nodes in place of the whole ancestor span.
        self.consider_new_nodes(
            self.nodes[: ancestor.start]
            + self.nodes[descendant.start : descendant.end]
            + self.nodes[ancestor.end :]
        )
1016
    def lower_common_node_offset(self):
        """Sometimes we find ourselves in a situation where changes to one part
        of the choice sequence unlock changes to other parts. Sometimes this is
        good, but sometimes this can cause us to exhibit exponential slow
        downs!

        e.g. suppose we had the following:

        m = draw(integers(min_value=0))
        n = draw(integers(min_value=0))
        assert abs(m - n) > 1

        If this fails then we'll end up with a loop where on each iteration we
        reduce each of m and n by 2 - m can't go lower because of n, then n
        can't go lower because of m.

        This will take us O(m) iterations to complete, which is exponential in
        the data size, as we gradually zig zag our way towards zero.

        This can only happen if we're failing to reduce the size of the choice
        sequence: The number of iterations that reduce the length of the choice
        sequence is bounded by that length.

        So what we do is this: We keep track of which nodes are changing, and
        then if there's some non-zero common offset to them we try and minimize
        them all at once by lowering that offset.

        This may not work, and it definitely won't get us out of all possible
        exponential slow downs (an example of where it doesn't is where the
        shape of the nodes changes as a result of this bouncing behaviour),
        but it fails fast when it doesn't work and gets us out of a really
        nastily slow case when it does.
        """
        # A common offset only makes sense with at least two changed nodes.
        if len(self.__changed_nodes) <= 1:
            return

        # Collect the changed, non-trivial integer nodes in index order.
        changed = []
        for i in sorted(self.__changed_nodes):
            node = self.nodes[i]
            if node.trivial or node.type != "integer":
                continue
            changed.append(node)

        if not changed:
            return

        # Distance of each node from its own shrink_towards.
        ints = [
            abs(node.value - node.constraints["shrink_towards"]) for node in changed
        ]
        offset = min(ints)
        # Presumably non-trivial integer nodes never sit exactly at
        # shrink_towards, so the shared offset is positive. NOTE(review):
        # this relies on ChoiceNode.trivial's definition elsewhere.
        assert offset > 0

        # Keep only each node's distance in excess of the shared offset.
        for i in range(len(ints)):
            ints[i] -= offset

        # Snapshot the target so every attempt is built from the same base,
        # even if an earlier attempt succeeds and replaces shrink_target.
        st = self.shrink_target

        def offset_node(node, n):
            # replace_all triple that sets this single node to shrink_towards + n.
            return (
                node.index,
                node.index + 1,
                [node.copy(with_value=node.constraints["shrink_towards"] + n)],
            )

        def consider(n, sign):
            # Lower the shared offset to n, placing every changed node on the
            # side of shrink_towards given by sign and keeping its excess v.
            return self.consider_new_nodes(
                replace_all(
                    st.nodes,
                    [
                        offset_node(node, sign * (n + v))
                        for node, v in zip(changed, ints, strict=False)
                    ],
                )
            )

        # shrink from both sides
        Integer.shrink(offset, lambda n: consider(n, 1))
        Integer.shrink(offset, lambda n: consider(n, -1))
        self.clear_change_tracking()
1096
1097 def clear_change_tracking(self):
1098 self.__last_checked_changed_at = self.shrink_target
1099 self.__all_changed_nodes = set()
1100
1101 def mark_changed(self, i):
1102 self.__changed_nodes.add(i)
1103
    @property
    def __changed_nodes(self) -> set[int]:
        """Indices of nodes whose values changed since tracking was last reset.

        If the shrink target is still the one recorded by the last
        clear_change_tracking, the stored set is returned unchanged.

        Otherwise the previous target's nodes are compared with the current
        target's:
        * If their length or any node's choice type differs, the indices are
          not comparable, so the set is reset to empty.
        * Otherwise, every index whose value differs (per choice_equal) is
          added to the set.
        """
        if self.__last_checked_changed_at is self.shrink_target:
            return self.__all_changed_nodes

        prev_target = self.__last_checked_changed_at
        new_target = self.shrink_target
        assert prev_target is not new_target
        prev_nodes = prev_target.nodes
        new_nodes = new_target.nodes
        # The shrink target only ever moves to strictly simpler choice sequences.
        assert sort_key(new_target.nodes) < sort_key(prev_target.nodes)

        # The `or` short-circuits on a length mismatch, so the strict zip below
        # only runs on equal-length sequences.
        if len(prev_nodes) != len(new_nodes) or any(
            n1.type != n2.type for n1, n2 in zip(prev_nodes, new_nodes, strict=True)
        ):
            # should we check constraints are equal as well?
            self.__all_changed_nodes = set()
        else:
            assert len(prev_nodes) == len(new_nodes)
            for i, (n1, n2) in enumerate(zip(prev_nodes, new_nodes, strict=True)):
                assert n1.type == n2.type
                if not choice_equal(n1.value, n2.value):
                    self.__all_changed_nodes.add(i)

        # NOTE(review): this block does not advance __last_checked_changed_at,
        # so later reads compare against the same checkpoint again (adding to a
        # set is idempotent). Confirm that is intended.
        return self.__all_changed_nodes
1129
1130 def update_shrink_target(self, new_target):
1131 assert isinstance(new_target, ConjectureResult)
1132 self.shrinks += 1
1133 # If we are just taking a long time to shrink we don't want to
1134 # trigger this heuristic, so whenever we shrink successfully
1135 # we give ourselves a bit of breathing room to make sure we
1136 # would find a shrink that took that long to find the next time.
1137 # The case where we're taking a long time but making steady
1138 # progress is handled by `finish_shrinking_deadline` in engine.py
1139 self.max_stall = max(
1140 self.max_stall, (self.calls - self.calls_at_last_shrink) * 2
1141 )
1142 self.calls_at_last_shrink = self.calls
1143 self.shrink_target = new_target
1144 self.__derived_values = {}
1145
1146 def try_shrinking_nodes(self, nodes, n):
1147 """Attempts to replace each node in the nodes list with n. Returns
1148 True if it succeeded (which may include some additional modifications
1149 to shrink_target).
1150
1151 In current usage it is expected that each of the nodes currently have
1152 the same value and choice_type, although this is not essential. Note that
1153 n must be < the node at min(nodes) or this is not a valid shrink.
1154
1155 This method will attempt to do some small amount of work to delete data
1156 that occurs after the end of the nodes. This is useful for cases where
1157 there is some size dependency on the value of a node.
1158 """
1159 # If the length of the shrink target has changed from under us such that
1160 # the indices are out of bounds, give up on the replacement.
1161 # TODO_BETTER_SHRINK: we probably want to narrow down the root cause here at some point.
1162 if any(node.index >= len(self.nodes) for node in nodes):
1163 return # pragma: no cover
1164
1165 initial_attempt = replace_all(
1166 self.nodes,
1167 [(node.index, node.index + 1, [node.copy(with_value=n)]) for node in nodes],
1168 )
1169
1170 attempt = self.cached_test_function(initial_attempt)[1]
1171
1172 if attempt is None:
1173 return False
1174
1175 if attempt is self.shrink_target:
1176 # if the initial shrink was a success, try lowering offsets.
1177 self.lower_common_node_offset()
1178 return True
1179
1180 # If this produced something completely invalid we ditch it
1181 # here rather than trying to persevere.
1182 if attempt.status is Status.OVERRUN:
1183 return False
1184
1185 if attempt.status is Status.INVALID:
1186 return False
1187
1188 if attempt.misaligned_at is not None:
1189 # we're invalid due to a misalignment in the tree. We'll try to fix
1190 # a very specific type of misalignment here: where we have a node of
1191 # {"size": n} and tried to draw the same node, but with {"size": m < n}.
1192 # This can occur with eg
1193 #
1194 # n = data.draw_integer()
1195 # s = data.draw_string(min_size=n)
1196 #
1197 # where we try lowering n, resulting in the test_function drawing a lower
1198 # min_size than our attempt had for the draw_string node.
1199 #
1200 # We'll now try realigning this tree by:
1201 # * replacing the constraints in our attempt with what test_function tried
1202 # to draw in practice
1203 # * truncating the value of that node to match min_size
1204 #
1205 # This helps in the specific case of drawing a value and then drawing
1206 # a collection of that size...and not much else. In practice this
1207 # helps because this antipattern is fairly common.
1208
1209 # TODO we'll probably want to apply the same trick as in the valid
1210 # case of this function of preserving from the right instead of
1211 # preserving from the left. see test_can_shrink_variable_string_draws.
1212
1213 index, attempt_choice_type, attempt_constraints, _attempt_forced = (
1214 attempt.misaligned_at
1215 )
1216 node = self.nodes[index]
1217 if node.type != attempt_choice_type:
1218 return False # pragma: no cover
1219 if node.was_forced:
1220 return False # pragma: no cover
1221
1222 if node.type in {"string", "bytes"}:
1223 # if the size *increased*, we would have to guess what to pad with
1224 # in order to try fixing up this attempt. Just give up.
1225 if node.constraints["min_size"] <= attempt_constraints["min_size"]:
1226 # attempts which increase min_size tend to overrun rather than
1227 # be misaligned, making a covering case difficult.
1228 return False # pragma: no cover
1229 # the size decreased in our attempt. Try again, but truncate the value
1230 # to that size by removing any elements past min_size.
1231 return self.consider_new_nodes(
1232 initial_attempt[: node.index]
1233 + [
1234 initial_attempt[node.index].copy(
1235 with_constraints=attempt_constraints,
1236 with_value=initial_attempt[node.index].value[
1237 : attempt_constraints["min_size"]
1238 ],
1239 )
1240 ]
1241 + initial_attempt[node.index :]
1242 )
1243
1244 lost_nodes = len(self.nodes) - len(attempt.nodes)
1245 if lost_nodes <= 0:
1246 return False
1247
1248 start = nodes[0].index
1249 end = nodes[-1].index + 1
1250 # We now look for contiguous regions to delete that might help fix up
1251 # this failed shrink. We only look for contiguous regions of the right
1252 # lengths because doing anything more than that starts to get very
1253 # expensive. See minimize_individual_choices for where we
1254 # try to be more aggressive.
1255 regions_to_delete = {(end, end + lost_nodes)}
1256
1257 for ex in self.spans:
1258 if ex.start > start:
1259 continue
1260 if ex.end <= end:
1261 continue
1262
1263 if ex.index >= len(attempt.spans):
1264 continue # pragma: no cover
1265
1266 replacement = attempt.spans[ex.index]
1267 in_original = [c for c in ex.children if c.start >= end]
1268 in_replaced = [c for c in replacement.children if c.start >= end]
1269
1270 if len(in_replaced) >= len(in_original) or not in_replaced:
1271 continue
1272
1273 # We've found a span where some of the children went missing
1274 # as a result of this change, and just replacing it with the data
1275 # it would have had and removing the spillover didn't work. This
1276 # means that some of its children towards the right must be
1277 # important, so we try to arrange it so that it retains its
1278 # rightmost children instead of its leftmost.
1279 regions_to_delete.add(
1280 (in_original[0].start, in_original[-len(in_replaced)].start)
1281 )
1282
1283 for u, v in sorted(regions_to_delete, key=lambda x: x[1] - x[0], reverse=True):
1284 try_with_deleted = initial_attempt[:u] + initial_attempt[v:]
1285 if self.consider_new_nodes(try_with_deleted):
1286 return True
1287
1288 return False
1289
1290 def remove_discarded(self):
1291 """Try removing all bytes marked as discarded.
1292
1293 This is primarily to deal with data that has been ignored while
1294 doing rejection sampling - e.g. as a result of an integer range, or a
1295 filtered strategy.
1296
1297 Such data will also be handled by the adaptive_example_deletion pass,
1298 but that pass is necessarily more conservative and will try deleting
1299 each interval individually. The common case is that all data drawn and
1300 rejected can just be thrown away immediately in one block, so this pass
1301 will be much faster than trying each one individually when it works.
1302
1303 returns False if there is discarded data and removing it does not work,
1304 otherwise returns True.
1305 """
1306 while self.shrink_target.has_discards:
1307 discarded = []
1308
1309 for ex in self.shrink_target.spans:
1310 if (
1311 ex.choice_count > 0
1312 and ex.discarded
1313 and (not discarded or ex.start >= discarded[-1][-1])
1314 ):
1315 discarded.append((ex.start, ex.end))
1316
1317 # This can happen if we have discards but they are all of
1318 # zero length. This shouldn't happen very often so it's
1319 # faster to check for it here than at the point of example
1320 # generation.
1321 if not discarded:
1322 break
1323
1324 attempt = list(self.nodes)
1325 for u, v in reversed(discarded):
1326 del attempt[u:v]
1327
1328 if not self.consider_new_nodes(tuple(attempt)):
1329 return False
1330 return True
1331
1332 @derived_value # type: ignore
1333 def duplicated_nodes(self):
1334 """Returns a list of nodes grouped (choice_type, value)."""
1335 duplicates = defaultdict(list)
1336 for node in self.nodes:
1337 duplicates[(node.type, choice_key(node.value))].append(node)
1338 return list(duplicates.values())
1339
1340 def node_program(self, program: str) -> ShrinkPass:
1341 return ShrinkPass(
1342 lambda chooser: self._node_program(chooser, program),
1343 name=f"node_program_{program}",
1344 )
1345
    def _node_program(self, chooser, program):
        """One step of a node-program shrink pass.

        Chooses a start index ``i`` at which ``program`` (one instruction per
        node, so it covers ``len(program)`` nodes) fits in the current nodes.
        If running it there once fails, it returns immediately. Otherwise it
        slides ``i`` left in whole-program strides while the program keeps
        applying, then tries running it repeatedly from there.
        """
        n = len(program)
        # Adaptively attempt to run the node program at the current
        # index. If this successfully applies the node program ``k`` times
        # then this runs in ``O(log(k))`` test function calls.
        i = chooser.choose(range(len(self.nodes) - n + 1))

        # First, run the node program at the chosen index. If this fails,
        # don't do any extra work, so that failure is as cheap as possible.
        if not self.run_node_program(i, program, original=self.shrink_target):
            return

        # Because we run in a random order we will often find ourselves in the middle
        # of a region where we could run the node program. We thus start by moving
        # left to the beginning of that region if possible in order to start from
        # the beginning of that region.
        def offset_left(k):
            # Reads the enclosing ``i`` at call time. find_integer below runs to
            # completion before ``i`` is rebound, so all of its probes use the
            # original start index.
            return i - k * n

        i = offset_left(
            find_integer(
                lambda k: self.run_node_program(
                    offset_left(k), program, original=self.shrink_target
                )
            )
        )

        # Each repeat count is tried against this fixed snapshot.
        original = self.shrink_target
        # Now try to run the node program multiple times here.
        find_integer(
            lambda k: self.run_node_program(i, program, original=original, repeats=k)
        )
1378
1379 def minimize_duplicated_choices(self, chooser):
1380 """Find choices that have been duplicated in multiple places and attempt
1381 to minimize all of the duplicates simultaneously.
1382
1383 This lets us handle cases where two values can't be shrunk
1384 independently of each other but can easily be shrunk together.
1385 For example if we had something like:
1386
1387 ls = data.draw(lists(integers()))
1388 y = data.draw(integers())
1389 assert y not in ls
1390
1391 Suppose we drew y = 3 and after shrinking we have ls = [3]. If we were
1392 to replace both 3s with 0, this would be a valid shrink, but if we were
1393 to replace either 3 with 0 on its own the test would start passing.
1394
1395 It is also useful for when that duplication is accidental and the value
1396 of the choices don't matter very much because it allows us to replace
1397 more values at once.
1398 """
1399 nodes = chooser.choose(self.duplicated_nodes)
1400 # we can't lower any nodes which are trivial. try proceeding with the
1401 # remaining nodes.
1402 nodes = [node for node in nodes if not node.trivial]
1403 if len(nodes) <= 1:
1404 return
1405
1406 self.minimize_nodes(nodes)
1407
1408 def redistribute_numeric_pairs(self, chooser):
1409 """If there is a sum of generated numbers that we need their sum
1410 to exceed some bound, lowering one of them requires raising the
1411 other. This pass enables that."""
1412
1413 # look for a pair of nodes (node1, node2) which are both numeric
1414 # and aren't separated by too many other nodes. We'll decrease node1 and
1415 # increase node2 (note that the other way around doesn't make sense as
1416 # it's strictly worse in the ordering).
1417 def can_choose_node(node):
1418 # don't choose nan, inf, or floats above the threshold where f + 1 > f
1419 # (which is not necessarily true for floats above MAX_PRECISE_INTEGER).
1420 # The motivation for the last condition is to avoid trying weird
1421 # non-shrinks where we raise one node and think we lowered another
1422 # (but didn't).
1423 return node.type in {"integer", "float"} and not (
1424 node.type == "float"
1425 and (math.isnan(node.value) or abs(node.value) >= MAX_PRECISE_INTEGER)
1426 )
1427
1428 node1 = chooser.choose(
1429 self.nodes,
1430 lambda node: can_choose_node(node) and not node.trivial,
1431 )
1432 node2 = chooser.choose(
1433 self.nodes,
1434 lambda node: (
1435 can_choose_node(node)
1436 # Note that it's fine for node2 to be trivial, because we're going to
1437 # explicitly make it *not* trivial by adding to its value.
1438 and not node.was_forced
1439 # to avoid quadratic behavior, scan ahead only a small amount for
1440 # the related node.
1441 and node1.index < node.index <= node1.index + 4
1442 ),
1443 )
1444
1445 m: int | float = node1.value
1446 n: int | float = node2.value
1447
1448 def boost(k: int) -> bool:
1449 # floats always shrink towards 0
1450 shrink_towards = (
1451 node1.constraints["shrink_towards"] if node1.type == "integer" else 0
1452 )
1453 if k > abs(m - shrink_towards):
1454 return False
1455
1456 # We are trying to move node1 (m) closer to shrink_towards, and node2
1457 # (n) farther away from shrink_towards. If m is below shrink_towards,
1458 # we want to add to m and subtract from n, and vice versa if above
1459 # shrink_towards.
1460 if m < shrink_towards:
1461 k = -k
1462
1463 try:
1464 v1 = m - k
1465 v2 = n + k
1466 except OverflowError: # pragma: no cover
1467 # if n or m is a float and k is over sys.float_info.max, coercing
1468 # k to a float will overflow.
1469 return False
1470
1471 # if we've increased node2 to the point that we're past max precision,
1472 # give up - things have become too unstable.
1473 if node1.type == "float" and abs(v2) >= MAX_PRECISE_INTEGER:
1474 return False
1475
1476 return self.consider_new_nodes(
1477 self.nodes[: node1.index]
1478 + (node1.copy(with_value=v1),)
1479 + self.nodes[node1.index + 1 : node2.index]
1480 + (node2.copy(with_value=v2),)
1481 + self.nodes[node2.index + 1 :]
1482 )
1483
1484 find_integer(boost)
1485
1486 def lower_integers_together(self, chooser):
1487 node1 = chooser.choose(
1488 self.nodes, lambda n: n.type == "integer" and not n.trivial
1489 )
1490 # Search up to 3 nodes ahead, to avoid quadratic time.
1491 node2 = self.nodes[
1492 chooser.choose(
1493 range(node1.index + 1, min(len(self.nodes), node1.index + 3 + 1)),
1494 lambda i: (
1495 self.nodes[i].type == "integer" and not self.nodes[i].was_forced
1496 ),
1497 )
1498 ]
1499
1500 # one might expect us to require node2 to be nontrivial, and to minimize
1501 # the node which is closer to its shrink_towards, rather than node1
1502 # unconditionally. In reality, it's acceptable for us to transition node2
1503 # from trivial to nontrivial, because the shrink ordering is dominated by
1504 # the complexity of the earlier node1. What matters is minimizing node1.
1505 shrink_towards = node1.constraints["shrink_towards"]
1506
1507 def consider(n):
1508 return self.consider_new_nodes(
1509 self.nodes[: node1.index]
1510 + (node1.copy(with_value=node1.value - n),)
1511 + self.nodes[node1.index + 1 : node2.index]
1512 + (node2.copy(with_value=node2.value - n),)
1513 + self.nodes[node2.index + 1 :]
1514 )
1515
1516 find_integer(lambda n: consider(shrink_towards - n))
1517 find_integer(lambda n: consider(n - shrink_towards))
1518
1519 def lower_duplicated_characters(self, chooser):
1520 """
1521 Select two string choices no more than 4 choices apart and simultaneously
1522 lower characters which appear in both strings. This helps cases where the
1523 same character must appear in two strings, but the actual value of the
1524 character is not relevant.
1525
1526 This shrinking pass currently only tries lowering *all* instances of the
1527 duplicated character in both strings. So for instance, given two choices:
1528
1529 "bbac"
1530 "abbb"
1531
1532 we would try lowering all five of the b characters simultaneously. This
1533 may fail to shrink some cases where only certain character indices are
1534 correlated, for instance if only the b at index 1 could be lowered
1535 simultaneously and the rest did in fact actually have to be a `b`.
1536
1537 It would be nice to try shrinking that case as well, but we would need good
1538 safeguards because it could get very expensive to try all combinations.
1539 I expect lowering all duplicates to handle most cases in the meantime.
1540 """
1541 node1 = chooser.choose(
1542 self.nodes, lambda n: n.type == "string" and not n.trivial
1543 )
1544
1545 # limit search to up to 4 choices ahead, to avoid quadratic behavior
1546 node2 = self.nodes[
1547 chooser.choose(
1548 range(node1.index + 1, min(len(self.nodes), node1.index + 1 + 4)),
1549 lambda i: (
1550 self.nodes[i].type == "string"
1551 and not self.nodes[i].trivial
1552 # select nodes which have at least one of the same character present
1553 and set(node1.value) & set(self.nodes[i].value)
1554 ),
1555 )
1556 ]
1557
1558 duplicated_characters = set(node1.value) & set(node2.value)
1559 # deterministic ordering
1560 char = chooser.choose(sorted(duplicated_characters))
1561 intervals = node1.constraints["intervals"]
1562
1563 def copy_node(node, n):
1564 # replace all duplicate characters in each string. This might miss
1565 # some shrinks compared to only replacing some, but trying all possible
1566 # combinations of indices could get expensive if done without some
1567 # thought.
1568 return node.copy(
1569 with_value=node.value.replace(char, intervals.char_in_shrink_order(n))
1570 )
1571
1572 Integer.shrink(
1573 intervals.index_from_char_in_shrink_order(char),
1574 lambda n: self.consider_new_nodes(
1575 self.nodes[: node1.index]
1576 + (copy_node(node1, n),)
1577 + self.nodes[node1.index + 1 : node2.index]
1578 + (copy_node(node2, n),)
1579 + self.nodes[node2.index + 1 :]
1580 ),
1581 )
1582
1583 def normalize_unicode_chars(self, chooser):
1584 """For string nodes, try replacing characters with simpler equivalents
1585 from natural text transformations: unicode decomposition (NFD, NFKD)
1586 and case mapping. For example, an accented latin letter is reduced
1587 to its base form, a ligature is reduced to its first base character,
1588 a mathematical alphanumeric symbol is reduced to its plain ascii
1589 counterpart, and a lowercase letter is replaced with its uppercase
1590 form (which has a smaller shrink-order index in the default
1591 alphabet).
1592
1593 The codepoint shrinker is binary-search based, so it can get stuck on
1594 a high codepoint whose simpler equivalents aren't reached by halving
1595 / shifting / masking. This pass directly tries the natural simpler
1596 forms one character at a time.
1597 """
1598 node = chooser.choose(
1599 self.nodes,
1600 lambda n: n.type == "string"
1601 and any(
1602 _natural_simpler_chars(c, n.constraints["intervals"]) for c in n.value
1603 ),
1604 )
1605 intervals = node.constraints["intervals"]
1606 i = chooser.choose(
1607 range(len(node.value)),
1608 lambda j: bool(_natural_simpler_chars(node.value[j], intervals)),
1609 )
1610 for replacement in _natural_simpler_chars(node.value[i], intervals):
1611 new_value = node.value[:i] + replacement + node.value[i + 1 :]
1612 if self.consider_new_nodes(
1613 self.nodes[: node.index]
1614 + (node.copy(with_value=new_value),)
1615 + self.nodes[node.index + 1 :]
1616 ):
1617 return
1618
1619 def minimize_nodes(self, nodes):
1620 choice_type = nodes[0].type
1621 value = nodes[0].value
1622 # unlike choice_type and value, constraints are *not* guaranteed to be equal among all
1623 # passed nodes. We arbitrarily use the constraints of the first node. I think
1624 # this is unsound (= leads to us trying shrinks that could not have been
1625 # generated), but those get discarded at test-time, and this enables useful
1626 # slips where constraints are not equal but are close enough that doing the
1627 # same operation on both basically just works.
1628 constraints = nodes[0].constraints
1629 assert all(
1630 node.type == choice_type and choice_equal(node.value, value)
1631 for node in nodes
1632 )
1633
1634 if choice_type == "integer":
1635 shrink_towards = constraints["shrink_towards"]
1636 # try shrinking from both sides towards shrink_towards.
1637 # we're starting from n = abs(shrink_towards - value). Because the
1638 # shrinker will not check its starting value, we need to try
1639 # shrinking to n first.
1640 self.try_shrinking_nodes(nodes, abs(shrink_towards - value))
1641 Integer.shrink(
1642 abs(shrink_towards - value),
1643 lambda n: self.try_shrinking_nodes(nodes, shrink_towards + n),
1644 )
1645 Integer.shrink(
1646 abs(shrink_towards - value),
1647 lambda n: self.try_shrinking_nodes(nodes, shrink_towards - n),
1648 )
1649 elif choice_type == "float":
1650 self.try_shrinking_nodes(nodes, abs(value))
1651 Float.shrink(
1652 abs(value),
1653 lambda val: self.try_shrinking_nodes(nodes, val),
1654 )
1655 Float.shrink(
1656 abs(value),
1657 lambda val: self.try_shrinking_nodes(nodes, -val),
1658 )
1659 elif choice_type == "boolean":
1660 # must be True, otherwise would be trivial and not selected.
1661 assert value is True
1662 # only one thing to try: false!
1663 self.try_shrinking_nodes(nodes, False)
1664 elif choice_type == "bytes":
1665 Bytes.shrink(
1666 value,
1667 lambda val: self.try_shrinking_nodes(nodes, val),
1668 min_size=constraints["min_size"],
1669 )
1670 elif choice_type == "string":
1671 String.shrink(
1672 value,
1673 lambda val: self.try_shrinking_nodes(nodes, val),
1674 intervals=constraints["intervals"],
1675 min_size=constraints["min_size"],
1676 )
1677 else:
1678 raise NotImplementedError
1679
1680 def try_trivial_spans(self, chooser):
1681 i = chooser.choose(range(len(self.spans)))
1682
1683 prev = self.shrink_target
1684 nodes = self.shrink_target.nodes
1685 span = self.spans[i]
1686 prefix = nodes[: span.start]
1687 replacement = tuple(
1688 [
1689 (
1690 node
1691 if node.was_forced
1692 else node.copy(
1693 with_value=choice_from_index(0, node.type, node.constraints)
1694 )
1695 )
1696 for node in nodes[span.start : span.end]
1697 ]
1698 )
1699 suffix = nodes[span.end :]
1700 attempt = self.cached_test_function(prefix + replacement + suffix)[1]
1701
1702 if self.shrink_target is not prev:
1703 return
1704
1705 if isinstance(attempt, ConjectureResult):
1706 new_span = attempt.spans[i]
1707 new_replacement = attempt.nodes[new_span.start : new_span.end]
1708 self.consider_new_nodes(prefix + new_replacement + suffix)
1709
    def minimize_individual_choices(self, chooser):
        """Attempt to minimize each choice in sequence.

        This is the pass that ensures that e.g. each integer we draw is a
        minimum value. So it's the part that guarantees that if we e.g. do

        x = data.draw(integers())
        assert x < 10

        then in our shrunk example, x = 10 rather than say 97.

        If we are unsuccessful at minimizing a choice of interest we then
        check if that's because it's changing the size of the test case and,
        if so, we also make an attempt to delete parts of the test case to
        see if that fixes it.

        We handle most of the common cases in try_shrinking_nodes which is
        pretty good at clearing out large contiguous blocks of dead space,
        but it fails when there is data that has to stay in particular places
        in the list.

        Each call handles a single node, picked by ``chooser`` among the
        current nodes that are not ``trivial``:

        1. Run ``minimize_nodes`` on that node alone. If the shrink target
           changes, stop.
        2. Otherwise, only for integer nodes, build ``lowered``: the current
           nodes with this node's value decreased by one. Run it. Continue
           only if the run produced a result with status at least VALID, with
           a different node count than the current target, and not ending
           exactly at the lowered node.
        3. Offer ``lowered`` with one deletion applied. Either delete a whole
           span (chosen from index ``first_span_after_node`` onwards, among
           spans with ``choice_count > 0``) or a single node after the
           lowered one.
        """
        # Only non-trivial nodes are candidates; trivial ones presumably have
        # nothing left to minimize.
        node = chooser.choose(self.nodes, lambda node: not node.trivial)
        initial_target = self.shrink_target

        self.minimize_nodes([node])
        if self.shrink_target is not initial_target:
            # the shrink target changed, so our shrink worked. Defer doing
            # anything more intelligent until this shrink fails.
            return

        # the shrink failed. One particularly common case where minimizing a
        # node can fail is the antipattern of drawing a size and then drawing a
        # collection of that size, or more generally when there is a size
        # dependency on some single node. We'll explicitly try and fix up this
        # common case here: if decreasing an integer node by one would reduce
        # the size of the generated input, we'll try deleting things after that
        # node and see if the resulting attempt works.

        if node.type != "integer":
            # Only try this fixup logic on integer draws. Almost all size
            # dependencies are on integer draws, and if it's not, it's doing
            # something convoluted enough that it is unlikely to shrink well anyway.
            # TODO: extend to floats? we probably currently fail on the following,
            # albeit convoluted example:
            # n = int(data.draw(st.floats()))
            # s = data.draw(st.lists(st.integers(), min_size=n, max_size=n))
            return

        # Same length as self.nodes, with only this node's value changed.
        # Positions taken from self.nodes / self.spans therefore index
        # `lowered` directly in the deletions below.
        lowered = (
            self.nodes[: node.index]
            + (node.copy(with_value=node.value - 1),)
            + self.nodes[node.index + 1 :]
        )
        # Element [1] of the returned pair is the run's result: None, or an
        # object exposing .status and .nodes (as used just below).
        attempt = self.cached_test_function(lowered)[1]
        if (
            attempt is None
            or attempt.status < Status.VALID
            or len(attempt.nodes) == len(self.nodes)
            or len(attempt.nodes) == node.index + 1
        ):
            # no point in trying our size-dependency-logic if our attempt at
            # lowering the node resulted in:
            # * an invalid conjecture data
            # * the same number of nodes as before
            # * no nodes beyond the lowered node (nothing to try to delete afterwards)
            return

        # If it were then the original shrink should have worked and we could
        # never have got here.
        assert attempt is not self.shrink_target

        # `self.cached(...)` evidently binds the function's (cached) *result*
        # to the name `first_span_after_node`, since that name is used
        # directly as an int in `range(...)` below. The cache key is
        # node.index.
        @self.cached(node.index)
        def first_span_after_node():
            # Binary search for the first span whose start is >= node.index.
            # Assumes self.spans is ordered by start position (TODO confirm).
            # lo=0 is never itself tested, so the result is always >= 1.
            lo = 0
            hi = len(self.spans)
            while lo + 1 < hi:
                mid = (lo + hi) // 2
                span = self.spans[mid]
                if span.start >= node.index:
                    hi = mid
                else:
                    lo = mid
            return hi

        # we try deleting both entire spans, and single nodes.
        # If we wanted to get more aggressive, we could try deleting n
        # consecutive nodes (that don't cross a span boundary) for say
        # n <= 2 or n <= 3.
        if chooser.choose([True, False]):
            # Delete one whole span (only spans containing choices are
            # candidates) from `lowered`.
            span = self.spans[
                chooser.choose(
                    range(first_span_after_node, len(self.spans)),
                    lambda i: self.spans[i].choice_count > 0,
                )
            ]
            self.consider_new_nodes(lowered[: span.start] + lowered[span.end :])
        else:
            # Delete one single node strictly after the lowered node. Note
            # that `node` is rebound here to the node being deleted.
            node = self.nodes[chooser.choose(range(node.index + 1, len(self.nodes)))]
            self.consider_new_nodes(lowered[: node.index] + lowered[node.index + 1 :])
1809
1810 def reorder_spans(self, chooser):
1811 """This pass allows us to reorder the children of each span.
1812
1813 For example, consider the following:
1814
1815 .. code-block:: python
1816
1817 import hypothesis.strategies as st
1818 from hypothesis import given
1819
1820
1821 @given(st.text(), st.text())
1822 def test_not_equal(x, y):
1823 assert x != y
1824
1825 Without the ability to reorder x and y this could fail either with
1826 ``x=""``, ``y="0"``, or the other way around. With reordering it will
1827 reliably fail with ``x=""``, ``y="0"``.
1828 """
1829 span = chooser.choose(self.spans)
1830
1831 label = chooser.choose(span.children).label
1832 spans = [c for c in span.children if c.label == label]
1833 if len(spans) <= 1:
1834 return
1835
1836 endpoints = [(span.start, span.end) for span in spans]
1837 st = self.shrink_target
1838
1839 Ordering.shrink(
1840 range(len(spans)),
1841 lambda indices: self.consider_new_nodes(
1842 replace_all(
1843 st.nodes,
1844 [
1845 (
1846 u,
1847 v,
1848 st.nodes[spans[i].start : spans[i].end],
1849 )
1850 for (u, v), i in zip(endpoints, indices, strict=True)
1851 ],
1852 )
1853 ),
1854 key=lambda i: sort_key(st.nodes[spans[i].start : spans[i].end]),
1855 )
1856
1857 def run_node_program(self, i, program, original, repeats=1):
1858 """Node programs are a mini-DSL for node rewriting, defined as a sequence
1859 of commands that can be run at some index into the nodes
1860
1861 Commands are:
1862
1863 * "X", delete this node
1864
1865 This method runs the node program in ``program`` at node index
1866 ``i`` on the ConjectureData ``original``. If ``repeats > 1`` then it
1867 will attempt to approximate the results of running it that many times.
1868
1869 Returns True if this successfully changes the underlying shrink target,
1870 else False.
1871 """
1872 if i + len(program) > len(original.nodes) or i < 0:
1873 return False
1874 attempt = list(original.nodes)
1875 for _ in range(repeats):
1876 for k, command in reversed(list(enumerate(program))):
1877 j = i + k
1878 if j >= len(attempt):
1879 return False
1880
1881 if command == "X":
1882 del attempt[j]
1883 else:
1884 raise NotImplementedError(f"Unrecognised command {command!r}")
1885
1886 return self.consider_new_nodes(attempt)