1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import math
12import unicodedata
13from collections import defaultdict
14from collections.abc import Callable, Iterator, Sequence
15from dataclasses import dataclass
16from functools import lru_cache
17from typing import (
18 TYPE_CHECKING,
19 Any,
20 Literal,
21 TypeAlias,
22 cast,
23)
24
25from hypothesis.internal.conjecture.choice import (
26 ChoiceNode,
27 ChoiceT,
28 choice_equal,
29 choice_from_index,
30 choice_key,
31 choice_permitted,
32 choice_to_index,
33 choices_key,
34)
35from hypothesis.internal.conjecture.data import (
36 ConjectureData,
37 ConjectureResult,
38 Spans,
39 Status,
40 _Overrun,
41 draw_choice,
42)
43from hypothesis.internal.conjecture.junkdrawer import (
44 endswith,
45 find_integer,
46 replace_all,
47 startswith,
48)
49from hypothesis.internal.conjecture.shrinking import (
50 Bytes,
51 Float,
52 Integer,
53 Ordering,
54 String,
55)
56from hypothesis.internal.conjecture.shrinking.choicetree import (
57 ChoiceTree,
58 prefix_selection_order,
59 random_selection_order,
60)
61from hypothesis.internal.floats import MAX_PRECISE_INTEGER
62
63if TYPE_CHECKING:
64 from random import Random
65
66 from hypothesis.internal.conjecture.engine import ConjectureRunner
67
# Signature of the predicate a Shrinker uses to decide whether a test-function
# result (a ConjectureResult, or an _Overrun sentinel) still exhibits the
# property being shrunk for.
ShrinkPredicateT: TypeAlias = Callable[[ConjectureResult | _Overrun], bool]
69
70
def sort_key(nodes: Sequence[ChoiceNode]) -> tuple[int, tuple[int, ...]]:
    """Shortlex key that orders choice sequences from simplest to most complex.

    Two sequences are compared first by length, then lexicographically by the
    index of each choice in its own simplicity order (``choice_to_index``).
    The rationale for this ordering:

    1. A shorter sequence made fewer decisions while constructing the test
       case, so it is always considered simpler.
    2. For sequences of equal length, swapping a choice for one with a lower
       index swaps it for a simpler/smaller value.
    3. Earlier choices are compared first: values drawn early in generation
       tend to be used in more places and so have a bigger impact on the
       final result, which makes reducing them the higher priority.
    """
    indices = [choice_to_index(node.value, node.constraints) for node in nodes]
    return (len(nodes), tuple(indices))
93
94
95@lru_cache(maxsize=4096)
96def _natural_simpler_chars(c, intervals):
97 """Return single-char replacements for ``c`` derived from natural text
98 transformations - case mapping (upper, lower, casefold) and unicode
99 decomposition (NFD, NFKD). We take each individual character of the
100 transformed form so that e.g. ``ß`` can shrink to ``s`` via casefold
101 even though the full case-folded form is two characters.
102
103 Only candidates which are in ``intervals`` and which have a strictly
104 smaller index in shrink order than ``c`` are returned, sorted by that
105 shrink-order index. Callers must pass a single character that is itself
106 in ``intervals``.
107 """
108 candidates: set[str] = set()
109 for form in ("NFKD", "NFD"):
110 candidates.update(unicodedata.normalize(form, c))
111 for transformed in (c.upper(), c.lower(), c.casefold()):
112 candidates.update(transformed)
113 candidates.discard(c)
114 original_idx = intervals.index_from_char_in_shrink_order(c)
115 result = sorted(
116 (intervals.index_from_char_in_shrink_order(cand), cand)
117 for cand in candidates
118 if cand in intervals
119 )
120 return [cand for idx, cand in result if idx < original_idx]
121
122
123@dataclass(slots=True, frozen=False)
124class ShrinkPass:
125 function: Any
126 name: str | None = None
127 last_prefix: Any = ()
128
129 # some execution statistics
130 calls: int = 0
131 misaligned: int = 0
132 shrinks: int = 0
133 deletions: int = 0
134
135 def __post_init__(self):
136 if self.name is None:
137 self.name = self.function.__name__
138
139 def __hash__(self):
140 return hash(self.name)
141
142
class StopShrinking(Exception):
    """Raised to abort shrinking early, e.g. by ``Shrinker.check_calls`` once
    too many test-function calls have been made since the last shrink."""
145
146
147class Shrinker:
148 """A shrinker is a child object of a ConjectureRunner which is designed to
149 manage the associated state of a particular shrink problem. That is, we
150 have some initial ConjectureData object and some property of interest
151 that it satisfies, and we want to find a ConjectureData object with a
152 shortlex (see sort_key above) smaller choice sequence that exhibits the same
153 property.
154
155 Currently the only property of interest we use is that the status is
156 INTERESTING and the interesting_origin takes on some fixed value, but we
157 may potentially be interested in other use cases later.
158 However we assume that data with a status < VALID never satisfies the predicate.
159
160 The shrinker keeps track of a value shrink_target which represents the
161 current best known ConjectureData object satisfying the predicate.
162 It refines this value by repeatedly running *shrink passes*, which are
163 methods that perform a series of transformations to the current shrink_target
164 and evaluate the underlying test function to find new ConjectureData
165 objects. If any of these satisfy the predicate, the shrink_target
166 is updated automatically. Shrinking runs until no shrink pass can
167 improve the shrink_target, at which point it stops. It may also be
168 terminated if the underlying engine throws RunIsComplete, but that
169 is handled by the calling code rather than the Shrinker.
170
171 =======================
172 Designing Shrink Passes
173 =======================
174
175 Generally a shrink pass is just any function that calls
176 cached_test_function and/or consider_new_nodes a number of times,
177 but there are a couple of useful things to bear in mind.
178
179 A shrink pass *makes progress* if running it changes self.shrink_target
180 (i.e. it tries a shortlex smaller ConjectureData object satisfying
181 the predicate). The desired end state of shrinking is to find a
182 value such that no shrink pass can make progress, i.e. that we
183 are at a local minimum for each shrink pass.
184
    In aid of this goal, the main invariant that a shrink pass must
    satisfy is that whether it makes progress must be deterministic.
187 It is fine (encouraged even) for the specific progress it makes
188 to be non-deterministic, but if you run a shrink pass, it makes
189 no progress, and then you immediately run it again, it should
190 never succeed on the second time. This allows us to stop as soon
191 as we have run each shrink pass and seen no progress on any of
192 them.
193
194 This means that e.g. it's fine to try each of N deletions
195 or replacements in a random order, but it's not OK to try N random
196 deletions (unless you have already shrunk at least once, though we
197 don't currently take advantage of this loophole).
198
199 Shrink passes need to be written so as to be robust against
200 change in the underlying shrink target. It is generally safe
201 to assume that the shrink target does not change prior to the
    point of first modification - e.g. if you change no choices at or
    before index ``i``, all spans whose start is ``<= i`` still exist,
    and the choice sequence is still of length ``>= i + 1``. This can
    only be violated by bad user code which
206 relies on an external source of non-determinism.
207
208 When the underlying shrink_target changes, shrink
209 passes should not run substantially more test_function calls
210 on success than they do on failure. Say, no more than a constant
211 factor more. In particular shrink passes should not iterate to a
212 fixed point.
213
214 This means that shrink passes are often written with loops that
215 are carefully designed to do the right thing in the case that no
216 shrinks occurred and try to adapt to any changes to do a reasonable
217 job. e.g. say we wanted to write a shrink pass that tried deleting
218 each individual choice (this isn't an especially good pass,
219 but it leads to a simple illustrative example), we might do it
220 by iterating over the choice sequence like so:
221
222 .. code-block:: python
223
224 i = 0
225 while i < len(self.shrink_target.nodes):
226 if not self.consider_new_nodes(
227 self.shrink_target.nodes[:i] + self.shrink_target.nodes[i + 1 :]
228 ):
229 i += 1
230
231 The reason for writing the loop this way is that i is always a
232 valid index into the current choice sequence, even if the current sequence
233 changes as a result of our actions. When the choice sequence changes,
234 we leave the index where it is rather than restarting from the
235 beginning, and carry on. This means that the number of steps we
236 run in this case is always bounded above by the number of steps
237 we would run if nothing works.
238
239 Another thing to bear in mind about shrink pass design is that
240 they should prioritise *progress*. If you have N operations that
241 you need to run, you should try to order them in such a way as
242 to avoid stalling, where you have long periods of test function
243 invocations where no shrinks happen. This is bad because whenever
244 we shrink we reduce the amount of work the shrinker has to do
245 in future, and often speed up the test function, so we ideally
    want those shrinks to happen much earlier in the process.
247
248 Sometimes stalls are inevitable of course - e.g. if the pass
249 makes no progress, then the entire thing is just one long stall,
250 but it's helpful to design it so that stalls are less likely
251 in typical behaviour.
252
253 The two easiest ways to do this are:
254
255 * Just run the N steps in random order. As long as a
256 reasonably large proportion of the operations succeed, this
257 guarantees the expected stall length is quite short. The
    bookkeeping for making sure this does the right thing when
259 it succeeds can be quite annoying.
260 * When you have any sort of nested loop, loop in such a way
261 that both loop variables change each time. This prevents
262 stalls which occur when one particular value for the outer
263 loop is impossible to make progress on, rendering the entire
264 inner loop into a stall.
265
266 However, although progress is good, too much progress can be
267 a bad sign! If you're *only* seeing successful reductions,
268 that's probably a sign that you are making changes that are
269 too timid. Two useful things to offset this:
270
271 * It's worth writing shrink passes which are *adaptive*, in
272 the sense that when operations seem to be working really
273 well we try to bundle multiple of them together. This can
274 often be used to turn what would be O(m) successful calls
275 into O(log(m)).
276 * It's often worth trying one or two special minimal values
277 before trying anything more fine grained (e.g. replacing
278 the whole thing with zero).
279
280 """
281
282 def derived_value(fn):
283 """It's useful during shrinking to have access to derived values of
284 the current shrink target.
285
286 This decorator allows you to define these as cached properties. They
287 are calculated once, then cached until the shrink target changes, then
288 recalculated the next time they are used."""
289
290 def accept(self):
291 try:
292 return self.__derived_values[fn.__name__]
293 except KeyError:
294 return self.__derived_values.setdefault(fn.__name__, fn(self))
295
296 accept.__name__ = fn.__name__
297 return property(accept)
298
299 def __init__(
300 self,
301 engine: "ConjectureRunner",
302 initial: ConjectureData | ConjectureResult,
303 predicate: ShrinkPredicateT | None,
304 *,
305 allow_transition: (
306 Callable[[ConjectureData | ConjectureResult, ConjectureData], bool] | None
307 ),
308 explain: bool,
309 in_target_phase: bool = False,
310 ):
311 """Create a shrinker for a particular engine, with a given starting
312 point and predicate. When shrink() is called it will attempt to find an
313 example for which predicate is True and which is strictly smaller than
314 initial.
315
316 Note that initial is a ConjectureData object, and predicate
317 takes ConjectureData objects.
318 """
319 assert predicate is not None or allow_transition is not None
320 self.engine = engine
321 self.__predicate = predicate or (lambda data: True)
322 self.__allow_transition = allow_transition or (lambda source, destination: True)
323 self.__derived_values: dict = {}
324
325 self.initial_size = len(initial.choices)
326 # We keep track of the current best example on the shrink_target
327 # attribute.
328 self.shrink_target = initial
329 self.clear_change_tracking()
330 self.shrinks = 0
331
332 # We terminate shrinks that seem to have reached their logical
333 # conclusion: If we've called the underlying test function at
334 # least self.max_stall times since the last time we shrunk,
335 # it's time to stop shrinking.
336 self.max_stall = 200
337 self.initial_calls = self.engine.call_count
338 self.initial_misaligned = self.engine.misaligned_count
339 self.calls_at_last_shrink = self.initial_calls
340
341 self.shrink_passes: list[ShrinkPass] = [
342 ShrinkPass(self.try_trivial_spans),
343 self.node_program("X" * 5),
344 self.node_program("X" * 4),
345 self.node_program("X" * 3),
346 self.node_program("X" * 2),
347 self.node_program("X" * 1),
348 ShrinkPass(self.pass_to_descendant),
349 ShrinkPass(self.reorder_spans),
350 ShrinkPass(self.minimize_duplicated_choices),
351 ShrinkPass(self.minimize_individual_choices),
352 ShrinkPass(self.redistribute_numeric_pairs),
353 ShrinkPass(self.lower_integers_together),
354 ShrinkPass(self.lower_duplicated_characters),
355 ShrinkPass(self.normalize_unicode_chars),
356 ]
357
358 # Because the shrinker is also used to `pareto_optimise` in the target phase,
359 # we sometimes want to allow extending buffers instead of aborting at the end.
360 self.__extend: Literal["full"] | int = "full" if in_target_phase else 0
361 self.should_explain = explain
362
363 @derived_value # type: ignore
364 def cached_calculations(self):
365 return {}
366
367 def cached(self, *keys):
368 def accept(f):
369 cache_key = (f.__name__, *keys)
370 try:
371 return self.cached_calculations[cache_key]
372 except KeyError:
373 return self.cached_calculations.setdefault(cache_key, f())
374
375 return accept
376
377 @property
378 def calls(self) -> int:
379 """Return the number of calls that have been made to the underlying
380 test function."""
381 return self.engine.call_count
382
383 @property
384 def misaligned(self) -> int:
385 return self.engine.misaligned_count
386
387 def check_calls(self) -> None:
388 if self.calls - self.calls_at_last_shrink >= self.max_stall:
389 raise StopShrinking
390
391 def cached_test_function(
392 self, nodes: Sequence[ChoiceNode]
393 ) -> tuple[bool, ConjectureResult | _Overrun | None]:
394 nodes = nodes[: len(self.nodes)]
395
396 if startswith(nodes, self.nodes):
397 return (True, None)
398
399 if sort_key(self.nodes) < sort_key(nodes):
400 return (False, None)
401
402 # sometimes our shrinking passes try obviously invalid things. We handle
403 # discarding them in one place here.
404 if any(not choice_permitted(node.value, node.constraints) for node in nodes):
405 return (False, None)
406
407 result = self.engine.cached_test_function(
408 [n.value for n in nodes], extend=self.__extend
409 )
410 previous = self.shrink_target
411 self.incorporate_test_data(result)
412 self.check_calls()
413 return (previous is not self.shrink_target, result)
414
415 def consider_new_nodes(self, nodes: Sequence[ChoiceNode]) -> bool:
416 return self.cached_test_function(nodes)[0]
417
418 def incorporate_test_data(self, data):
419 """Takes a ConjectureData or Overrun object updates the current
420 shrink_target if this data represents an improvement over it."""
421 if data.status < Status.VALID or data is self.shrink_target:
422 return
423 if (
424 self.__predicate(data)
425 and sort_key(data.nodes) < sort_key(self.shrink_target.nodes)
426 and self.__allow_transition(self.shrink_target, data)
427 ):
428 self.update_shrink_target(data)
429
430 def debug(self, msg: str) -> None:
431 self.engine.debug(msg)
432
433 @property
434 def random(self) -> "Random":
435 return self.engine.random
436
    def shrink(self) -> None:
        """Run the full set of shrinks and update shrink_target.

        This method is "mostly idempotent" - calling it twice is unlikely to
        have any effect, though it has a non-zero probability of doing so.

        Runs ``initial_coarse_reduction`` then ``greedy_shrink``. If either
        raises StopShrinking, the explain phase is disabled. However shrinking
        ended, per-pass profiling is written via ``debug`` when the engine's
        ``report_debug_info`` is set. ``explain`` is then called, unless an
        exception other than StopShrinking is propagating.
        """

        try:
            self.initial_coarse_reduction()
            self.greedy_shrink()
        except StopShrinking:
            # If we stopped shrinking because we're making slow progress (instead of
            # reaching a local optimum), don't run the explain-phase logic.
            self.should_explain = False
        finally:
            if self.engine.report_debug_info:

                # Pluralisation helper for the report below.
                def s(n):
                    return "s" if n != 1 else ""

                total_deleted = self.initial_size - len(self.shrink_target.choices)
                calls = self.engine.call_count - self.initial_calls
                misaligned = self.engine.misaligned_count - self.initial_misaligned

                self.debug(
                    "---------------------\n"
                    "Shrink pass profiling\n"
                    "---------------------\n\n"
                    f"Shrinking made a total of {calls} call{s(calls)} of which "
                    f"{self.shrinks} shrank and {misaligned} were misaligned. This "
                    f"deleted {total_deleted} choices out of {self.initial_size}."
                )
                # Report passes that shrank at least once first, then the rest.
                for useful in [True, False]:
                    self.debug("")
                    if useful:
                        self.debug("Useful passes:")
                    else:
                        self.debug("Useless passes:")
                    self.debug("")
                    # Most calls first; ties broken by fewest deletions/shrinks.
                    for pass_ in sorted(
                        self.shrink_passes,
                        key=lambda t: (-t.calls, t.deletions, t.shrinks),
                    ):
                        # Passes that never ran are omitted entirely.
                        if pass_.calls == 0:
                            continue
                        if (pass_.shrinks != 0) != useful:
                            continue

                        self.debug(
                            f" * {pass_.name} made {pass_.calls} call{s(pass_.calls)} of which "
                            f"{pass_.shrinks} shrank and {pass_.misaligned} were misaligned, "
                            f"deleting {pass_.deletions} choice{s(pass_.deletions)}."
                        )
                self.debug("")
        self.explain()
492
493 def explain(self) -> None:
494 if not self.should_explain or not self.shrink_target.arg_slices:
495 return
496 with self.engine._log_phase_statistics("explain"):
497 self._explain()
498
499 def _explain(self) -> None:
500 self.max_stall = 2**100
501 shrink_target = self.shrink_target
502 nodes = self.nodes
503 choices = self.choices
504 chunks: dict[tuple[int, int], list[tuple[ChoiceT, ...]]] = defaultdict(list)
505
506 # Before we start running experiments, let's check for known inputs which would
507 # make them redundant. The shrinking process means that we've already tried many
508 # variations on the minimal example, so this can save a lot of time.
509 seen_passing_seq = self.engine.passing_choice_sequences(
510 prefix=self.nodes[: min(self.shrink_target.arg_slices)[0]]
511 )
512
513 # Now that we've shrunk to a minimal failing example, it's time to try
514 # varying each part that we've noted will go in the final report. Consider
515 # slices in largest-first order
516 for start, end in sorted(
517 self.shrink_target.arg_slices, key=lambda x: (-(x[1] - x[0]), x)
518 ):
519 # Check for any previous examples that match the prefix and suffix,
520 # so we can skip if we found a passing example while shrinking.
521 if any(
522 startswith(seen, nodes[:start]) and endswith(seen, nodes[end:])
523 for seen in seen_passing_seq
524 ):
525 continue
526
527 # Skip slices that are subsets of already-explained slices.
528 # If a larger slice can vary freely, so can its sub-slices.
529 # Note: (0, 0) is a special marker for the "together" comment that
530 # applies to the whole test, not a specific slice, so we exclude it.
531 if any(
532 s <= start and end <= e
533 for s, e in self.shrink_target.slice_comments
534 if (s, e) != (0, 0)
535 ):
536 continue
537
538 # Try a few targeted candidates before falling back to random sampling,
539 # so that simple cases like ``assert n1 == n2`` -- where the only
540 # passing value of ``n1`` is exactly ``n2``'s value -- aren't reported
541 # as freely-variable just because random sampling missed it.
542 candidates = list(self._explain_candidates(start, end))
543
544 # Run our experiments
545 n_same_failures = 0
546 note = "or any other generated value"
547 # TODO: is 100 same-failures out of 500 attempts a good heuristic?
548 for n_attempt in range(500 + len(candidates)): # pragma: no branch
549 # no-branch here because we don't coverage-test the abort-at-500 logic.
550
551 if n_attempt - 10 - len(candidates) > n_same_failures * 5:
552 # stop early if we're seeing mostly invalid examples
553 break # pragma: no cover
554
555 if n_attempt < len(candidates):
556 replacement = list(candidates[n_attempt])
557 else:
558 # replace start:end with random values
559 replacement = []
560 for i in range(start, end):
561 node = nodes[i]
562 if not node.was_forced:
563 value = draw_choice(
564 node.type, node.constraints, random=self.random
565 )
566 node = node.copy(with_value=value)
567 replacement.append(node.value)
568
569 attempt = choices[:start] + tuple(replacement) + choices[end:]
570 result = self.engine.cached_test_function(attempt, extend="full")
571
572 if result.status is Status.OVERRUN:
573 continue # pragma: no cover # flakily covered
574 result = cast(ConjectureResult, result)
575 if not (
576 len(attempt) == len(result.choices)
577 and endswith(result.nodes, nodes[end:])
578 ):
579 # Turns out this was a variable-length part, so grab the infix...
580 for span1, span2 in zip(
581 shrink_target.spans, result.spans, strict=False
582 ):
583 assert span1.start == span2.start
584 assert span1.start <= start
585 if span1.start == start and span1.end == end:
586 result_end = span2.end
587 break
588 else:
589 raise NotImplementedError("Expected matching prefixes")
590
591 attempt = (
592 choices[:start]
593 + result.choices[start:result_end]
594 + choices[end:]
595 )
596 chunks[(start, end)].append(result.choices[start:result_end])
597 result = self.engine.cached_test_function(attempt)
598
599 if result.status is Status.OVERRUN:
600 continue # pragma: no cover # flakily covered
601 result = cast(ConjectureResult, result)
602 else:
603 chunks[(start, end)].append(result.choices[start:end])
604
605 if shrink_target is not self.shrink_target: # pragma: no cover
606 # If we've shrunk further without meaning to, bail out.
607 self.shrink_target.slice_comments.clear()
608 return
609 if result.status is Status.VALID:
610 # The test passed, indicating that this param can't vary freely.
611 # However, it's really hard to write a simple and reliable covering
612 # test, because of our `seen_passing_buffers` check above.
613 break # pragma: no cover
614 if self.__predicate(result): # pragma: no branch
615 n_same_failures += 1
616 if n_same_failures >= 100:
617 self.shrink_target.slice_comments[(start, end)] = note
618 break
619
620 # Finally, if we've found multiple independently-variable parts, check whether
621 # they can all be varied together.
622 if len(self.shrink_target.slice_comments) <= 1:
623 return
624 n_same_failures_together = 0
625 # Only include slices that were actually added to slice_comments
626 chunks_by_start_index = sorted(
627 (k, v) for k, v in chunks.items() if k in self.shrink_target.slice_comments
628 )
629 for _ in range(500): # pragma: no branch
630 # no-branch here because we don't coverage-test the abort-at-500 logic.
631 new_choices: list[ChoiceT] = []
632 prev_end = 0
633 for (start, end), ls in chunks_by_start_index:
634 assert prev_end <= start < end, "these chunks must be nonoverlapping"
635 new_choices.extend(choices[prev_end:start])
636 new_choices.extend(self.random.choice(ls))
637 prev_end = end
638
639 result = self.engine.cached_test_function(new_choices)
640
641 # This *can't* be a shrink because none of the components were.
642 assert shrink_target is self.shrink_target
643 if result.status == Status.VALID:
644 self.shrink_target.slice_comments[(0, 0)] = (
645 "The test sometimes passed when commented parts were varied together."
646 )
647 break # Test passed, this param can't vary freely.
648 if self.__predicate(result): # pragma: no branch
649 n_same_failures_together += 1
650 if n_same_failures_together >= 100:
651 self.shrink_target.slice_comments[(0, 0)] = (
652 "The test always failed when commented parts were varied together."
653 )
654 break
655
656 def _explain_candidates(
657 self, start: int, end: int
658 ) -> "Iterator[tuple[ChoiceT, ...]]":
659 """Yield deterministic candidate replacements for ``nodes[start:end]``.
660
661 Random sampling alone misses cases like ``assert n1 == n2``, where the
662 only passing value of ``n1`` is exactly ``n2``'s value. We try
663 substituting values from each other arg slice with matching length and
664 types, which catches such comparisons. Invalid borrowed values just
665 produce an irrelevant test result the outer loop discards.
666 """
667 nodes = self.nodes
668 target_types = tuple(nodes[i].type for i in range(start, end))
669 current_key = choices_key(tuple(nodes[i].value for i in range(start, end)))
670 seen: set[tuple[Any, ...]] = {current_key}
671 for start2, end2 in sorted(self.shrink_target.arg_slices):
672 if (start2, end2) == (start, end) or (end2 - start2) != (end - start):
673 continue
674 if (
675 tuple(nodes[start2 + j].type for j in range(end - start))
676 != target_types
677 ):
678 continue
679 borrowed = tuple(nodes[start2 + j].value for j in range(end - start))
680 key = choices_key(borrowed)
681 if key in seen:
682 continue
683 seen.add(key)
684 yield borrowed
685
686 def greedy_shrink(self) -> None:
687 """Run a full set of greedy shrinks (that is, ones that will only ever
688 move to a better target) and update shrink_target appropriately.
689
690 This method iterates to a fixed point and so is idempontent - calling
691 it twice will have exactly the same effect as calling it once.
692 """
693 self.fixate_shrink_passes(self.shrink_passes)
694
695 def initial_coarse_reduction(self):
696 """Performs some preliminary reductions that should not be
697 repeated as part of the main shrink passes.
698
699 The main reason why these can't be included as part of shrink
700 passes is that they have much more ability to make the test
701 case "worse". e.g. they might rerandomise part of it, significantly
702 increasing the value of individual nodes, which works in direct
703 opposition to the lexical shrinking and will frequently undo
704 its work.
705 """
706 self.reduce_each_alternative()
707
708 @derived_value # type: ignore
709 def spans_starting_at(self):
710 result = [[] for _ in self.shrink_target.nodes]
711 for i, ex in enumerate(self.spans):
712 # We can have zero-length spans that start at the end
713 if ex.start < len(result):
714 result[ex.start].append(i)
715 return tuple(map(tuple, result))
716
    def reduce_each_alternative(self):
        """This is a pass that is designed to rerandomise use of the
        one_of strategy or things that look like it, in order to try
        to move from later strategies to earlier ones in the branch
        order.

        It does this by trying to systematically lower each value it
        finds that looks like it might be the branch decision for
        one_of, and then attempts to repair any changes in shape that
        this causes.

        Candidates are non-forced integer nodes with ``min_value == 0`` and
        a value of at most 10. For each one whose zeroing changes the shape
        of the test case, every lower value ``0 .. value - 1`` is tried in
        turn via ``try_lower_node_as_alternative`` until one succeeds.
        """
        i = 0
        # Re-check the length each iteration: the shrink target may change
        # (and shrink) as we go, and ``i`` stays a valid index into it.
        while i < len(self.shrink_target.nodes):
            nodes = self.shrink_target.nodes
            node = nodes[i]
            if (
                node.type == "integer"
                and not node.was_forced
                and node.value <= 10
                and node.constraints["min_value"] == 0
            ):
                assert isinstance(node.value, int)

                # We've found a plausible candidate for a ``one_of`` choice.
                # We now want to see if the shape of the test case actually depends
                # on it. If it doesn't, then we don't need to do this (comparatively
                # costly) pass, and can let much simpler lexicographic reduction
                # handle it later.
                #
                # We test this by trying to set the value to zero and seeing if the
                # shape changes, as measured by either changing the number of subsequent
                # nodes, or changing the nodes in such a way as to cause one of the
                # previous values to no longer be valid in its position.
                zero_attempt = self.cached_test_function(
                    nodes[:i] + (nodes[i].copy(with_value=0),) + nodes[i + 1 :]
                )[1]
                # Skip when zeroing itself became the new shrink target, when
                # nothing was run (result is None), or when the run was invalid.
                if (
                    zero_attempt is not self.shrink_target
                    and zero_attempt is not None
                    and zero_attempt.status >= Status.VALID
                ):
                    changed_shape = len(zero_attempt.nodes) != len(nodes)

                    if not changed_shape:
                        # Same length: a later node changing type, or its old
                        # value no longer being permitted, also counts as a
                        # change of shape.
                        for j in range(i + 1, len(nodes)):
                            zero_node = zero_attempt.nodes[j]
                            orig_node = nodes[j]
                            if (
                                zero_node.type != orig_node.type
                                or not choice_permitted(
                                    orig_node.value, zero_node.constraints
                                )
                            ):
                                changed_shape = True
                                break
                    if changed_shape:
                        # Try each lower branch value, smallest first.
                        for v in range(node.value):
                            if self.try_lower_node_as_alternative(i, v):
                                break
            i += 1
777
778 def try_lower_node_as_alternative(self, i, v):
779 """Attempt to lower `self.shrink_target.nodes[i]` to `v`,
780 while rerandomising and attempting to repair any subsequent
781 changes to the shape of the test case that this causes."""
782 nodes = self.shrink_target.nodes
783 if self.consider_new_nodes(
784 nodes[:i] + (nodes[i].copy(with_value=v),) + nodes[i + 1 :]
785 ):
786 return True
787
788 prefix = nodes[:i] + (nodes[i].copy(with_value=v),)
789 initial = self.shrink_target
790 spans = self.spans_starting_at[i]
791 for _ in range(3):
792 random_attempt = self.engine.cached_test_function(
793 [n.value for n in prefix], extend=len(nodes)
794 )
795 if random_attempt.status < Status.VALID:
796 continue
797 self.incorporate_test_data(random_attempt)
798 for j in spans:
799 initial_span = initial.spans[j]
800 attempt_span = random_attempt.spans[j]
801 contents = random_attempt.nodes[attempt_span.start : attempt_span.end]
802 self.consider_new_nodes(
803 nodes[:i] + contents + nodes[initial_span.end :]
804 )
805 if initial is not self.shrink_target:
806 return True
807 return False
808
809 @derived_value # type: ignore
810 def shrink_pass_choice_trees(self) -> dict[Any, ChoiceTree]:
811 return defaultdict(ChoiceTree)
812
813 def step(self, shrink_pass: ShrinkPass, *, random_order: bool = False) -> bool:
814 tree = self.shrink_pass_choice_trees[shrink_pass]
815 if tree.exhausted:
816 return False
817
818 initial_shrinks = self.shrinks
819 initial_calls = self.calls
820 initial_misaligned = self.misaligned
821 size = len(self.shrink_target.choices)
822 assert shrink_pass.name is not None
823 self.engine.explain_next_call_as(shrink_pass.name)
824
825 if random_order:
826 selection_order = random_selection_order(self.random)
827 else:
828 selection_order = prefix_selection_order(shrink_pass.last_prefix)
829
830 try:
831 shrink_pass.last_prefix = tree.step(
832 selection_order,
833 lambda chooser: shrink_pass.function(chooser),
834 )
835 finally:
836 shrink_pass.calls += self.calls - initial_calls
837 shrink_pass.misaligned += self.misaligned - initial_misaligned
838 shrink_pass.shrinks += self.shrinks - initial_shrinks
839 shrink_pass.deletions += size - len(self.shrink_target.choices)
840 self.engine.clear_call_explanation()
841 return True
842
    def fixate_shrink_passes(self, passes: list[ShrinkPass]) -> None:
        """Run steps from each pass in ``passes`` until the current shrink target
        is a fixed point of all of them.

        Each iteration of the outer loop runs every pass in turn. A pass keeps
        running until its choice tree is exhausted or it fails ``max_failures``
        (20) times in a row. Afterwards ``passes`` is re-sorted **in place**:
        - passes that reduced the number of choices come first;
        - passes that changed the target without shortening it come next;
        - passes that changed nothing come last.
        The outer loop stops once a full iteration runs no steps at all.
        """
        any_ran = True
        while any_ran:
            any_ran = False

            # Maps each pass to its sort key for the reorder at the end of
            # this iteration: -1 (shortened), 0 (changed), 1 (no change).
            reordering = {}

            # We run remove_discarded after every pass to do cleanup
            # keeping track of whether that actually works. Either there is
            # no discarded data and it is basically free, or it reliably works
            # and deletes data, or it doesn't work. In that latter case we turn
            # it off for the rest of this loop through the passes, but will
            # try again once all of the passes have been run.
            can_discard = self.remove_discarded()

            calls_at_loop_start = self.calls

            # We keep track of how many calls can be made by a single step
            # without making progress and use this to decide how much to pad
            # out self.max_stall by as we go along.
            max_calls_per_failing_step = 1

            for sp in passes:
                if can_discard:
                    can_discard = self.remove_discarded()

                before_sp = self.shrink_target

                # Run the shrink pass until it fails to make any progress
                # max_failures times in a row. This implicitly boosts shrink
                # passes that are more likely to work.
                failures = 0
                max_failures = 20
                while failures < max_failures:
                    # We don't allow more than max_stall consecutive failures
                    # to shrink, but this means that if we're unlucky and the
                    # shrink passes are in a bad order where only the ones at
                    # the end are useful, if we're not careful this heuristic
                    # might stop us before we've tried everything. In order to
                    # avoid that happening, we make sure that there's always
                    # plenty of breathing room to make it through a single
                    # iteration of the fixate_shrink_passes loop.
                    self.max_stall = max(
                        self.max_stall,
                        2 * max_calls_per_failing_step
                        + (self.calls - calls_at_loop_start),
                    )

                    prev = self.shrink_target
                    initial_calls = self.calls
                    # It's better for us to run shrink passes in a deterministic
                    # order, to avoid repeat work, but this can cause us to create
                    # long stalls when there are a lot of steps which fail to do
                    # anything useful. In order to avoid this, once we've noticed
                    # we're in a stall (i.e. half of max_failures calls have failed
                    # to do anything) we switch to randomly jumping around. If we
                    # find a success then we'll resume deterministic order from
                    # there which, with any luck, is in a new good region.
                    if not self.step(sp, random_order=failures >= max_failures // 2):
                        # step returns False when there is nothing to do because
                        # the entire choice tree is exhausted. If this happens
                        # we break because we literally can't run this pass any
                        # more than we already have until something else makes
                        # progress.
                        break
                    any_ran = True

                    # Don't count steps that didn't actually try to do
                    # anything as failures. Otherwise, this call is a failure
                    # if it failed to make any changes to the shrink target.
                    if initial_calls != self.calls:
                        if prev is not self.shrink_target:
                            failures = 0
                        else:
                            max_calls_per_failing_step = max(
                                max_calls_per_failing_step, self.calls - initial_calls
                            )
                            failures += 1

                # We reorder the shrink passes so that on our next run through
                # we try good ones first. The rule is that shrink passes that
                # did nothing useful are the worst, shrink passes that reduced
                # the length are the best.
                if self.shrink_target is before_sp:
                    reordering[sp] = 1
                elif len(self.choices) < len(before_sp.choices):
                    reordering[sp] = -1
                else:
                    reordering[sp] = 0

            # list.sort is stable, so passes with equal keys keep their order.
            passes.sort(key=reordering.__getitem__)
936
937 @property
938 def nodes(self) -> tuple[ChoiceNode, ...]:
939 return self.shrink_target.nodes
940
941 @property
942 def choices(self) -> tuple[ChoiceT, ...]:
943 return self.shrink_target.choices
944
945 @property
946 def spans(self) -> Spans:
947 return self.shrink_target.spans
948
949 @derived_value # type: ignore
950 def spans_by_label(self):
951 """
952 A mapping of labels to a list of spans with that label. Spans in the list
953 are ordered by their normal index order.
954 """
955
956 spans_by_label = defaultdict(list)
957 for ex in self.spans:
958 spans_by_label[ex.label].append(ex)
959 return dict(spans_by_label)
960
961 @derived_value # type: ignore
962 def distinct_labels(self):
963 return sorted(self.spans_by_label, key=str)
964
    def pass_to_descendant(self, chooser):
        """Attempt to replace each span with a descendant span.

        This is designed to deal with strategies that call themselves
        recursively. For example, suppose we had:

        binary_tree = st.deferred(
            lambda: st.one_of(
                st.integers(), st.tuples(binary_tree, binary_tree)))

        This pass guarantees that we can replace any binary tree with one of
        its subtrees - each of those will create an interval that the parent
        could validly be replaced with, and this pass will try doing that.

        This is pretty expensive - it takes O(len(intervals)^2) - so we run it
        late in the process when we've got the number of intervals as far down
        as possible.
        """

        # Only labels with at least two spans can have an ancestor/descendant
        # pair sharing that label.
        label = chooser.choose(
            self.distinct_labels, lambda l: len(self.spans_by_label[l]) >= 2
        )

        spans = self.spans_by_label[label]
        i = chooser.choose(range(len(spans) - 1))
        ancestor = spans[i]

        # Bail out if the next span with this label starts at or after the
        # ancestor's end, i.e. it cannot be nested inside it. (The first
        # condition is defensive: i < len(spans) - 1 from the choose above.)
        if i + 1 == len(spans) or spans[i + 1].start >= ancestor.end:
            return

        @self.cached(label, i)
        def descendants():
            # Binary search for ``hi``, the first span after i that starts at
            # or after the ancestor's end. NOTE(review): this assumes span
            # starts are non-decreasing in index order - confirm.
            lo = i + 1
            hi = len(spans)
            while lo + 1 < hi:
                mid = (lo + hi) // 2
                if spans[mid].start >= ancestor.end:
                    hi = mid
                else:
                    lo = mid
            # Keep only strictly smaller spans, so a replacement always removes
            # at least one choice.
            return [
                span
                for span in spans[i + 1 : hi]
                if span.choice_count < ancestor.choice_count
            ]

        descendant = chooser.choose(descendants, lambda ex: ex.choice_count > 0)

        assert ancestor.start <= descendant.start
        assert ancestor.end >= descendant.end
        assert descendant.choice_count < ancestor.choice_count

        # Splice the descendant's nodes in place of the whole ancestor span.
        self.consider_new_nodes(
            self.nodes[: ancestor.start]
            + self.nodes[descendant.start : descendant.end]
            + self.nodes[ancestor.end :]
        )
1022
1023 def lower_common_node_offset(self):
1024 """Sometimes we find ourselves in a situation where changes to one part
1025 of the choice sequence unlock changes to other parts. Sometimes this is
1026 good, but sometimes this can cause us to exhibit exponential slow
1027 downs!
1028
1029 e.g. suppose we had the following:
1030
1031 m = draw(integers(min_value=0))
1032 n = draw(integers(min_value=0))
1033 assert abs(m - n) > 1
1034
1035 If this fails then we'll end up with a loop where on each iteration we
1036 reduce each of m and n by 2 - m can't go lower because of n, then n
1037 can't go lower because of m.
1038
1039 This will take us O(m) iterations to complete, which is exponential in
1040 the data size, as we gradually zig zag our way towards zero.
1041
1042 This can only happen if we're failing to reduce the size of the choice
1043 sequence: The number of iterations that reduce the length of the choice
1044 sequence is bounded by that length.
1045
1046 So what we do is this: We keep track of which nodes are changing, and
1047 then if there's some non-zero common offset to them we try and minimize
1048 them all at once by lowering that offset.
1049
1050 This may not work, and it definitely won't get us out of all possible
1051 exponential slow downs (an example of where it doesn't is where the
1052 shape of the nodes changes as a result of this bouncing behaviour),
1053 but it fails fast when it doesn't work and gets us out of a really
1054 nastily slow case when it does.
1055 """
1056 if len(self.__changed_nodes) <= 1:
1057 return
1058
1059 changed = []
1060 for i in sorted(self.__changed_nodes):
1061 node = self.nodes[i]
1062 if node.trivial or node.type != "integer":
1063 continue
1064 changed.append(node)
1065
1066 if not changed:
1067 return
1068
1069 ints = [
1070 abs(node.value - node.constraints["shrink_towards"]) for node in changed
1071 ]
1072 offset = min(ints)
1073 assert offset > 0
1074
1075 for i in range(len(ints)):
1076 ints[i] -= offset
1077
1078 st = self.shrink_target
1079
1080 def offset_node(node, n):
1081 return (
1082 node.index,
1083 node.index + 1,
1084 [node.copy(with_value=node.constraints["shrink_towards"] + n)],
1085 )
1086
1087 def consider(n, sign):
1088 return self.consider_new_nodes(
1089 replace_all(
1090 st.nodes,
1091 [
1092 offset_node(node, sign * (n + v))
1093 for node, v in zip(changed, ints, strict=False)
1094 ],
1095 )
1096 )
1097
1098 # shrink from both sides
1099 Integer.shrink(offset, lambda n: consider(n, 1))
1100 Integer.shrink(offset, lambda n: consider(n, -1))
1101 self.clear_change_tracking()
1102
1103 def clear_change_tracking(self):
1104 self.__last_checked_changed_at = self.shrink_target
1105 self.__all_changed_nodes = set()
1106
1107 def mark_changed(self, i):
1108 self.__changed_nodes.add(i)
1109
1110 @property
1111 def __changed_nodes(self) -> set[int]:
1112 if self.__last_checked_changed_at is self.shrink_target:
1113 return self.__all_changed_nodes
1114
1115 prev_target = self.__last_checked_changed_at
1116 new_target = self.shrink_target
1117 assert prev_target is not new_target
1118 prev_nodes = prev_target.nodes
1119 new_nodes = new_target.nodes
1120 assert sort_key(new_target.nodes) < sort_key(prev_target.nodes)
1121
1122 if len(prev_nodes) != len(new_nodes) or any(
1123 n1.type != n2.type for n1, n2 in zip(prev_nodes, new_nodes, strict=True)
1124 ):
1125 # should we check constraints are equal as well?
1126 self.__all_changed_nodes = set()
1127 else:
1128 assert len(prev_nodes) == len(new_nodes)
1129 for i, (n1, n2) in enumerate(zip(prev_nodes, new_nodes, strict=True)):
1130 assert n1.type == n2.type
1131 if not choice_equal(n1.value, n2.value):
1132 self.__all_changed_nodes.add(i)
1133
1134 return self.__all_changed_nodes
1135
1136 def update_shrink_target(self, new_target):
1137 assert isinstance(new_target, ConjectureResult)
1138 self.shrinks += 1
1139 # If we are just taking a long time to shrink we don't want to
1140 # trigger this heuristic, so whenever we shrink successfully
1141 # we give ourselves a bit of breathing room to make sure we
1142 # would find a shrink that took that long to find the next time.
1143 # The case where we're taking a long time but making steady
1144 # progress is handled by `finish_shrinking_deadline` in engine.py
1145 self.max_stall = max(
1146 self.max_stall, (self.calls - self.calls_at_last_shrink) * 2
1147 )
1148 self.calls_at_last_shrink = self.calls
1149 self.shrink_target = new_target
1150 self.__derived_values = {}
1151
1152 def try_shrinking_nodes(self, nodes, n):
1153 """Attempts to replace each node in the nodes list with n. Returns
1154 True if it succeeded (which may include some additional modifications
1155 to shrink_target).
1156
1157 In current usage it is expected that each of the nodes currently have
1158 the same value and choice_type, although this is not essential. Note that
1159 n must be < the node at min(nodes) or this is not a valid shrink.
1160
1161 This method will attempt to do some small amount of work to delete data
1162 that occurs after the end of the nodes. This is useful for cases where
1163 there is some size dependency on the value of a node.
1164 """
1165 # If the length of the shrink target has changed from under us such that
1166 # the indices are out of bounds, give up on the replacement.
1167 # TODO_BETTER_SHRINK: we probably want to narrow down the root cause here at some point.
1168 if any(node.index >= len(self.nodes) for node in nodes):
1169 return # pragma: no cover
1170
1171 initial_attempt = replace_all(
1172 self.nodes,
1173 [(node.index, node.index + 1, [node.copy(with_value=n)]) for node in nodes],
1174 )
1175
1176 attempt = self.cached_test_function(initial_attempt)[1]
1177
1178 if attempt is None:
1179 return False
1180
1181 if attempt is self.shrink_target:
1182 # if the initial shrink was a success, try lowering offsets.
1183 self.lower_common_node_offset()
1184 return True
1185
1186 # If this produced something completely invalid we ditch it
1187 # here rather than trying to persevere.
1188 if attempt.status is Status.OVERRUN:
1189 # Lowering a size-controlling choice can make the realigned (and
1190 # now boring) collection stop triggering the failure, so the test
1191 # draws further and overruns before we see the realignment -- this
1192 # is common in stateful tests, where a non-failing step is followed
1193 # by more steps. Re-run without the length limit to recover the
1194 # realigned tree, which the repair logic below can then act on.
1195 attempt = self.engine.cached_test_function(
1196 [n.value for n in initial_attempt], extend="full"
1197 )
1198 if attempt.status is Status.OVERRUN:
1199 return False
1200
1201 if attempt.status is Status.INVALID:
1202 return False
1203
1204 # When we lower a choice that controls the size of a later collection,
1205 # eg
1206 #
1207 # n = data.draw_integer()
1208 # s = data.draw_string(min_size=n, max_size=n)
1209 #
1210 # the recorded value for that collection no longer fits the constraints
1211 # the test function actually used, so the engine realigns the tree by
1212 # substituting a freshly-generated (simplest) value -- discarding
1213 # whatever made the collection interesting. (We can't rely on
1214 # ``attempt.misaligned_at`` to detect this, because the realigned choice
1215 # sequence is often independently cached as an ordinary, non-misaligned
1216 # result.) We detect a string/bytes node whose recorded value is now too
1217 # long, and retry with it truncated to fit. We try preserving content
1218 # from either end, since the interesting part may be at the start or the
1219 # end (see test_can_shrink_variable_string_draws).
1220 for i in range(min(len(initial_attempt), len(attempt.nodes))):
1221 node = initial_attempt[i]
1222 attempt_node = attempt.nodes[i]
1223 if (
1224 node.type == attempt_node.type
1225 and node.type in {"string", "bytes"}
1226 and not node.was_forced
1227 and len(node.value) > attempt_node.constraints["max_size"]
1228 ):
1229 max_size = attempt_node.constraints["max_size"]
1230 for truncated in (node.value[:max_size], node.value[-max_size:]):
1231 if self.consider_new_nodes(
1232 initial_attempt[:i]
1233 + [
1234 node.copy(
1235 with_constraints=attempt_node.constraints,
1236 with_value=truncated,
1237 )
1238 ]
1239 + initial_attempt[i + 1 :]
1240 ):
1241 return True
1242
1243 lost_nodes = len(self.nodes) - len(attempt.nodes)
1244 if lost_nodes <= 0:
1245 return False
1246
1247 start = nodes[0].index
1248 end = nodes[-1].index + 1
1249 # We now look for contiguous regions to delete that might help fix up
1250 # this failed shrink. We only look for contiguous regions of the right
1251 # lengths because doing anything more than that starts to get very
1252 # expensive. See minimize_individual_choices for where we
1253 # try to be more aggressive.
1254 regions_to_delete = {(end, end + lost_nodes)}
1255
1256 for ex in self.spans:
1257 if ex.start > start:
1258 continue
1259 if ex.end <= end:
1260 continue
1261
1262 if ex.index >= len(attempt.spans):
1263 continue # pragma: no cover
1264
1265 replacement = attempt.spans[ex.index]
1266 in_original = [c for c in ex.children if c.start >= end]
1267 in_replaced = [c for c in replacement.children if c.start >= end]
1268
1269 if len(in_replaced) >= len(in_original) or not in_replaced:
1270 continue
1271
1272 # We've found a span where some of the children went missing
1273 # as a result of this change, and just replacing it with the data
1274 # it would have had and removing the spillover didn't work. This
1275 # means that some of its children towards the right must be
1276 # important, so we try to arrange it so that it retains its
1277 # rightmost children instead of its leftmost.
1278 regions_to_delete.add(
1279 (in_original[0].start, in_original[-len(in_replaced)].start)
1280 )
1281
1282 for u, v in sorted(regions_to_delete, key=lambda x: x[1] - x[0], reverse=True):
1283 try_with_deleted = initial_attempt[:u] + initial_attempt[v:]
1284 if self.consider_new_nodes(try_with_deleted):
1285 return True
1286
1287 return False
1288
1289 def remove_discarded(self):
1290 """Try removing all nodes marked as discarded.
1291
1292 This is primarily to deal with data that has been ignored while
1293 doing rejection sampling - e.g. as a result of an integer range, or a
1294 filtered strategy.
1295
1296 Such data will also be handled by the ``node_program("X")`` deletion
1297 passes, but those are necessarily more conservative and will try
1298 deleting each contiguous run of nodes individually. The common case is
1299 that all data drawn and rejected can just be thrown away immediately in
1300 one block, so this pass will be much faster than trying each one
1301 individually when it works.
1302
1303 returns False if there is discarded data and removing it does not work,
1304 otherwise returns True.
1305 """
1306 while self.shrink_target.has_discards:
1307 discarded = []
1308
1309 for ex in self.shrink_target.spans:
1310 if (
1311 ex.choice_count > 0
1312 and ex.discarded
1313 and (not discarded or ex.start >= discarded[-1][-1])
1314 ):
1315 discarded.append((ex.start, ex.end))
1316
1317 # This can happen if we have discards but they are all of
1318 # zero length. This shouldn't happen very often so it's
1319 # faster to check for it here than at the point of example
1320 # generation.
1321 if not discarded:
1322 break
1323
1324 attempt = list(self.nodes)
1325 for u, v in reversed(discarded):
1326 del attempt[u:v]
1327
1328 if not self.consider_new_nodes(tuple(attempt)):
1329 return False
1330 return True
1331
1332 @derived_value # type: ignore
1333 def duplicated_nodes(self):
1334 """Returns a list of nodes grouped (choice_type, value)."""
1335 duplicates = defaultdict(list)
1336 for node in self.nodes:
1337 duplicates[(node.type, choice_key(node.value))].append(node)
1338 return list(duplicates.values())
1339
1340 def node_program(self, program: str) -> ShrinkPass:
1341 return ShrinkPass(
1342 lambda chooser: self._node_program(chooser, program),
1343 name=f"node_program_{program}",
1344 )
1345
1346 def _node_program(self, chooser, program):
1347 n = len(program)
1348 # Adaptively attempt to run the node program at the current
1349 # index. If this successfully applies the node program ``k`` times
1350 # then this runs in ``O(log(k))`` test function calls.
1351 i = chooser.choose(range(len(self.nodes) - n + 1))
1352
1353 # First, run the node program at the chosen index. If this fails,
1354 # don't do any extra work, so that failure is as cheap as possible.
1355 if not self.run_node_program(i, program, original=self.shrink_target):
1356 return
1357
1358 # Because we run in a random order we will often find ourselves in the middle
1359 # of a region where we could run the node program. We thus start by moving
1360 # left to the beginning of that region if possible in order to start from
1361 # the beginning of that region.
1362 def offset_left(k):
1363 return i - k * n
1364
1365 i = offset_left(
1366 find_integer(
1367 lambda k: self.run_node_program(
1368 offset_left(k), program, original=self.shrink_target
1369 )
1370 )
1371 )
1372
1373 original = self.shrink_target
1374 # Now try to run the node program multiple times here.
1375 find_integer(
1376 lambda k: self.run_node_program(i, program, original=original, repeats=k)
1377 )
1378
1379 def minimize_duplicated_choices(self, chooser):
1380 """Find choices that have been duplicated in multiple places and attempt
1381 to minimize all of the duplicates simultaneously.
1382
1383 This lets us handle cases where two values can't be shrunk
1384 independently of each other but can easily be shrunk together.
1385 For example if we had something like:
1386
1387 ls = data.draw(lists(integers()))
1388 y = data.draw(integers())
1389 assert y not in ls
1390
1391 Suppose we drew y = 3 and after shrinking we have ls = [3]. If we were
1392 to replace both 3s with 0, this would be a valid shrink, but if we were
1393 to replace either 3 with 0 on its own the test would start passing.
1394
1395 It is also useful for when that duplication is accidental and the value
1396 of the choices don't matter very much because it allows us to replace
1397 more values at once.
1398 """
1399 nodes = chooser.choose(self.duplicated_nodes)
1400 # we can't lower any nodes which are trivial. try proceeding with the
1401 # remaining nodes.
1402 nodes = [node for node in nodes if not node.trivial]
1403 if len(nodes) <= 1:
1404 return
1405
1406 self.minimize_nodes(nodes)
1407
1408 def redistribute_numeric_pairs(self, chooser):
1409 """If there is a sum of generated numbers that we need their sum
1410 to exceed some bound, lowering one of them requires raising the
1411 other. This pass enables that."""
1412
1413 # look for a pair of nodes (node1, node2) which are both numeric
1414 # and aren't separated by too many other nodes. We'll decrease node1 and
1415 # increase node2 (note that the other way around doesn't make sense as
1416 # it's strictly worse in the ordering).
1417 def can_choose_node(node):
1418 # don't choose nan, inf, or floats above the threshold where f + 1 > f
1419 # (which is not necessarily true for floats above MAX_PRECISE_INTEGER).
1420 # The motivation for the last condition is to avoid trying weird
1421 # non-shrinks where we raise one node and think we lowered another
1422 # (but didn't).
1423 return node.type in {"integer", "float"} and not (
1424 node.type == "float"
1425 and (math.isnan(node.value) or abs(node.value) >= MAX_PRECISE_INTEGER)
1426 )
1427
1428 node1 = chooser.choose(
1429 self.nodes,
1430 lambda node: can_choose_node(node) and not node.trivial,
1431 )
1432 node2 = chooser.choose(
1433 self.nodes,
1434 lambda node: (
1435 can_choose_node(node)
1436 # Note that it's fine for node2 to be trivial, because we're going to
1437 # explicitly make it *not* trivial by adding to its value.
1438 and not node.was_forced
1439 # to avoid quadratic behavior, scan ahead only a small amount for
1440 # the related node.
1441 and node1.index < node.index <= node1.index + 4
1442 ),
1443 )
1444
1445 m: int | float = node1.value
1446 n: int | float = node2.value
1447
1448 def boost(k: int) -> bool:
1449 # floats always shrink towards 0
1450 shrink_towards = (
1451 node1.constraints["shrink_towards"] if node1.type == "integer" else 0
1452 )
1453 if k > abs(m - shrink_towards):
1454 return False
1455
1456 # We are trying to move node1 (m) closer to shrink_towards, and node2
1457 # (n) farther away from shrink_towards. If m is below shrink_towards,
1458 # we want to add to m and subtract from n, and vice versa if above
1459 # shrink_towards.
1460 if m < shrink_towards:
1461 k = -k
1462
1463 try:
1464 v1 = m - k
1465 v2 = n + k
1466 except OverflowError: # pragma: no cover
1467 # if n or m is a float and k is over sys.float_info.max, coercing
1468 # k to a float will overflow.
1469 return False
1470
1471 # if we've increased node2 to the point that we're past max precision,
1472 # give up - things have become too unstable.
1473 if node1.type == "float" and abs(v2) >= MAX_PRECISE_INTEGER:
1474 return False
1475
1476 return self.consider_new_nodes(
1477 self.nodes[: node1.index]
1478 + (node1.copy(with_value=v1),)
1479 + self.nodes[node1.index + 1 : node2.index]
1480 + (node2.copy(with_value=v2),)
1481 + self.nodes[node2.index + 1 :]
1482 )
1483
1484 find_integer(boost)
1485
1486 def lower_integers_together(self, chooser):
1487 node1 = chooser.choose(
1488 self.nodes, lambda n: n.type == "integer" and not n.trivial
1489 )
1490 # Search up to 3 nodes ahead, to avoid quadratic time.
1491 node2 = self.nodes[
1492 chooser.choose(
1493 range(node1.index + 1, min(len(self.nodes), node1.index + 3 + 1)),
1494 lambda i: (
1495 self.nodes[i].type == "integer" and not self.nodes[i].was_forced
1496 ),
1497 )
1498 ]
1499
1500 # one might expect us to require node2 to be nontrivial, and to minimize
1501 # the node which is closer to its shrink_towards, rather than node1
1502 # unconditionally. In reality, it's acceptable for us to transition node2
1503 # from trivial to nontrivial, because the shrink ordering is dominated by
1504 # the complexity of the earlier node1. What matters is minimizing node1.
1505 shrink_towards = node1.constraints["shrink_towards"]
1506
1507 def consider(n):
1508 return self.consider_new_nodes(
1509 self.nodes[: node1.index]
1510 + (node1.copy(with_value=node1.value - n),)
1511 + self.nodes[node1.index + 1 : node2.index]
1512 + (node2.copy(with_value=node2.value - n),)
1513 + self.nodes[node2.index + 1 :]
1514 )
1515
1516 find_integer(lambda n: consider(shrink_towards - n))
1517 find_integer(lambda n: consider(n - shrink_towards))
1518
1519 def lower_duplicated_characters(self, chooser):
1520 """
1521 Select two string choices no more than 4 choices apart and simultaneously
1522 lower characters which appear in both strings. This helps cases where the
1523 same character must appear in two strings, but the actual value of the
1524 character is not relevant.
1525
1526 This shrinking pass currently only tries lowering *all* instances of the
1527 duplicated character in both strings. So for instance, given two choices:
1528
1529 "bbac"
1530 "abbb"
1531
1532 we would try lowering all five of the b characters simultaneously. This
1533 may fail to shrink some cases where only certain character indices are
1534 correlated, for instance if only the b at index 1 could be lowered
1535 simultaneously and the rest did in fact actually have to be a `b`.
1536
1537 It would be nice to try shrinking that case as well, but we would need good
1538 safeguards because it could get very expensive to try all combinations.
1539 I expect lowering all duplicates to handle most cases in the meantime.
1540 """
1541 node1 = chooser.choose(
1542 self.nodes, lambda n: n.type == "string" and not n.trivial
1543 )
1544
1545 # limit search to up to 4 choices ahead, to avoid quadratic behavior
1546 node2 = self.nodes[
1547 chooser.choose(
1548 range(node1.index + 1, min(len(self.nodes), node1.index + 1 + 4)),
1549 lambda i: (
1550 self.nodes[i].type == "string"
1551 and not self.nodes[i].trivial
1552 # select nodes which have at least one of the same character present
1553 and set(node1.value) & set(self.nodes[i].value)
1554 ),
1555 )
1556 ]
1557
1558 duplicated_characters = set(node1.value) & set(node2.value)
1559 # deterministic ordering
1560 char = chooser.choose(sorted(duplicated_characters))
1561 intervals = node1.constraints["intervals"]
1562
1563 def copy_node(node, n):
1564 # replace all duplicate characters in each string. This might miss
1565 # some shrinks compared to only replacing some, but trying all possible
1566 # combinations of indices could get expensive if done without some
1567 # thought.
1568 return node.copy(
1569 with_value=node.value.replace(char, intervals.char_in_shrink_order(n))
1570 )
1571
1572 Integer.shrink(
1573 intervals.index_from_char_in_shrink_order(char),
1574 lambda n: self.consider_new_nodes(
1575 self.nodes[: node1.index]
1576 + (copy_node(node1, n),)
1577 + self.nodes[node1.index + 1 : node2.index]
1578 + (copy_node(node2, n),)
1579 + self.nodes[node2.index + 1 :]
1580 ),
1581 )
1582
1583 def normalize_unicode_chars(self, chooser):
1584 """For string nodes, try replacing characters with simpler equivalents
1585 from natural text transformations: unicode decomposition (NFD, NFKD)
1586 and case mapping. For example, an accented latin letter is reduced
1587 to its base form, a ligature is reduced to its first base character,
1588 a mathematical alphanumeric symbol is reduced to its plain ascii
1589 counterpart, and a lowercase letter is replaced with its uppercase
1590 form (which has a smaller shrink-order index in the default
1591 alphabet).
1592
1593 The codepoint shrinker is binary-search based, so it can get stuck on
1594 a high codepoint whose simpler equivalents aren't reached by halving
1595 / shifting / masking. This pass directly tries the natural simpler
1596 forms one character at a time.
1597 """
1598 node = chooser.choose(
1599 self.nodes,
1600 lambda n: n.type == "string"
1601 and any(
1602 _natural_simpler_chars(c, n.constraints["intervals"]) for c in n.value
1603 ),
1604 )
1605 intervals = node.constraints["intervals"]
1606 i = chooser.choose(
1607 range(len(node.value)),
1608 lambda j: bool(_natural_simpler_chars(node.value[j], intervals)),
1609 )
1610 for replacement in _natural_simpler_chars(node.value[i], intervals):
1611 new_value = node.value[:i] + replacement + node.value[i + 1 :]
1612 if self.consider_new_nodes(
1613 self.nodes[: node.index]
1614 + (node.copy(with_value=new_value),)
1615 + self.nodes[node.index + 1 :]
1616 ):
1617 return
1618
1619 def minimize_nodes(self, nodes):
1620 choice_type = nodes[0].type
1621 value = nodes[0].value
1622 # unlike choice_type and value, constraints are *not* guaranteed to be equal among all
1623 # passed nodes. We arbitrarily use the constraints of the first node. I think
1624 # this is unsound (= leads to us trying shrinks that could not have been
1625 # generated), but those get discarded at test-time, and this enables useful
1626 # slips where constraints are not equal but are close enough that doing the
1627 # same operation on both basically just works.
1628 constraints = nodes[0].constraints
1629 assert all(
1630 node.type == choice_type and choice_equal(node.value, value)
1631 for node in nodes
1632 )
1633
1634 if choice_type == "integer":
1635 shrink_towards = constraints["shrink_towards"]
1636 # try shrinking from both sides towards shrink_towards.
1637 # we're starting from n = abs(shrink_towards - value). Because the
1638 # shrinker will not check its starting value, we need to try
1639 # shrinking to n first.
1640 self.try_shrinking_nodes(nodes, abs(shrink_towards - value))
1641 Integer.shrink(
1642 abs(shrink_towards - value),
1643 lambda n: self.try_shrinking_nodes(nodes, shrink_towards + n),
1644 )
1645 Integer.shrink(
1646 abs(shrink_towards - value),
1647 lambda n: self.try_shrinking_nodes(nodes, shrink_towards - n),
1648 )
1649 elif choice_type == "float":
1650 self.try_shrinking_nodes(nodes, abs(value))
1651 Float.shrink(
1652 abs(value),
1653 lambda val: self.try_shrinking_nodes(nodes, val),
1654 )
1655 Float.shrink(
1656 abs(value),
1657 lambda val: self.try_shrinking_nodes(nodes, -val),
1658 )
1659 elif choice_type == "boolean":
1660 # must be True, otherwise would be trivial and not selected.
1661 assert value is True
1662 # only one thing to try: false!
1663 self.try_shrinking_nodes(nodes, False)
1664 elif choice_type == "bytes":
1665 Bytes.shrink(
1666 value,
1667 lambda val: self.try_shrinking_nodes(nodes, val),
1668 min_size=constraints["min_size"],
1669 )
1670 elif choice_type == "string":
1671 String.shrink(
1672 value,
1673 lambda val: self.try_shrinking_nodes(nodes, val),
1674 intervals=constraints["intervals"],
1675 min_size=constraints["min_size"],
1676 )
1677 else:
1678 raise NotImplementedError
1679
1680 def try_trivial_spans(self, chooser):
1681 i = chooser.choose(range(len(self.spans)))
1682
1683 prev = self.shrink_target
1684 nodes = self.shrink_target.nodes
1685 span = self.spans[i]
1686 prefix = nodes[: span.start]
1687 replacement = tuple(
1688 [
1689 (
1690 node
1691 if node.was_forced
1692 else node.copy(
1693 with_value=choice_from_index(0, node.type, node.constraints)
1694 )
1695 )
1696 for node in nodes[span.start : span.end]
1697 ]
1698 )
1699 suffix = nodes[span.end :]
1700 attempt = self.cached_test_function(prefix + replacement + suffix)[1]
1701
1702 if self.shrink_target is not prev:
1703 return
1704
1705 if isinstance(attempt, ConjectureResult):
1706 new_span = attempt.spans[i]
1707 new_replacement = attempt.nodes[new_span.start : new_span.end]
1708 self.consider_new_nodes(prefix + new_replacement + suffix)
1709
    def minimize_individual_choices(self, chooser):
        """Attempt to minimize each choice in sequence.

        This is the pass that ensures that e.g. each integer we draw is a
        minimum value. So it's the part that guarantees that if we e.g. do

            x = data.draw(integers())
            assert x < 10

        then in our shrunk example, x = 10 rather than say 97.

        If we are unsuccessful at minimizing a choice of interest we then
        check if that's because it's changing the size of the test case and,
        if so, we also make an attempt to delete parts of the test case to
        see if that fixes it.

        We handle most of the common cases in try_shrinking_nodes which is
        pretty good at clearing out large contiguous blocks of dead space,
        but it fails when there is data that has to stay in particular places
        in the list.

        Each call handles a single non-trivial node picked by ``chooser``.

        1. It first calls ``minimize_nodes`` on that node alone.
        2. Only if that leaves the shrink target unchanged, and the node is
           an integer, does it try the fixup: lower the value by one, then
           delete either one span or one single node located after it.
        """
        node = chooser.choose(self.nodes, lambda node: not node.trivial)
        initial_target = self.shrink_target

        self.minimize_nodes([node])
        if self.shrink_target is not initial_target:
            # the shrink target changed, so our shrink worked. Defer doing
            # anything more intelligent until this shrink fails.
            return

        # the shrink failed. One particularly common case where minimizing a
        # node can fail is the antipattern of drawing a size and then drawing a
        # collection of that size, or more generally when there is a size
        # dependency on some single node. We'll explicitly try and fix up this
        # common case here: if decreasing an integer node by one would reduce
        # the size of the generated input, we'll try deleting things after that
        # node and see if the resulting attempt works.

        if node.type != "integer":
            # Only try this fixup logic on integer draws. Almost all size
            # dependencies are on integer draws, and if it's not, it's doing
            # something convoluted enough that it is unlikely to shrink well anyway.
            # TODO: extend to floats? we probably currently fail on the following,
            # albeit convoluted example:
            # n = int(data.draw(st.floats()))
            # s = data.draw(st.lists(st.integers(), min_size=n, max_size=n))
            return

        # The current nodes with only `node`'s value decreased by one.
        lowered = (
            self.nodes[: node.index]
            + (node.copy(with_value=node.value - 1),)
            + self.nodes[node.index + 1 :]
        )
        attempt = self.cached_test_function(lowered)[1]
        if (
            attempt is None
            or attempt.status < Status.VALID
            or len(attempt.nodes) == len(self.nodes)
            or len(attempt.nodes) == node.index + 1
        ):
            # no point in trying our size-dependency-logic if our attempt at
            # lowering the node resulted in:
            # * an invalid conjecture data
            # * the same number of nodes as before
            # * no nodes beyond the lowered node (nothing to try to delete afterwards)
            return

        # If the lowered attempt were the shrink target, then the original
        # shrink should have worked and we could never have got here.
        assert attempt is not self.shrink_target

        # `self.cached` (defined elsewhere) is applied as a decorator and the
        # name is later used as an int in `range(...)`. So it presumably
        # evaluates the function once per key and binds its return value.
        # Confirm against the definition of `cached`.
        @self.cached(node.index)
        def first_span_after_node():
            # Binary search for the smallest span index whose start is at or
            # after node.index. This relies on span starts being non-decreasing
            # in span index. Because `lo` starts at 0 and we return `hi`,
            # index 0 is never returned when there is at least one span.
            lo = 0
            hi = len(self.spans)
            while lo + 1 < hi:
                mid = (lo + hi) // 2
                span = self.spans[mid]
                if span.start >= node.index:
                    hi = mid
                else:
                    lo = mid
            return hi

        # we try deleting both entire spans, and single nodes.
        # If we wanted to get more aggressive, we could try deleting n
        # consecutive nodes (that don't cross a span boundary) for say
        # n <= 2 or n <= 3.
        if chooser.choose([True, False]):
            # Delete one non-empty span that starts at or after the lowered
            # node. Span indices refer to the current shrink target's spans.
            span = self.spans[
                chooser.choose(
                    range(first_span_after_node, len(self.spans)),
                    lambda i: self.spans[i].choice_count > 0,
                )
            ]
            self.consider_new_nodes(lowered[: span.start] + lowered[span.end :])
        else:
            # Delete a single node after the lowered one. Note this rebinds
            # `node` to the node being deleted.
            node = self.nodes[chooser.choose(range(node.index + 1, len(self.nodes)))]
            self.consider_new_nodes(lowered[: node.index] + lowered[node.index + 1 :])
1809
1810 def reorder_spans(self, chooser):
1811 """This pass allows us to reorder the children of each span.
1812
1813 For example, consider the following:
1814
1815 .. code-block:: python
1816
1817 import hypothesis.strategies as st
1818 from hypothesis import given
1819
1820
1821 @given(st.text(), st.text())
1822 def test_not_equal(x, y):
1823 assert x != y
1824
1825 Without the ability to reorder x and y this could fail either with
1826 ``x=""``, ``y="0"``, or the other way around. With reordering it will
1827 reliably fail with ``x=""``, ``y="0"``.
1828 """
1829 span = chooser.choose(self.spans)
1830
1831 label = chooser.choose(span.children).label
1832 spans = [c for c in span.children if c.label == label]
1833 if len(spans) <= 1:
1834 return
1835
1836 endpoints = [(span.start, span.end) for span in spans]
1837 st = self.shrink_target
1838
1839 Ordering.shrink(
1840 range(len(spans)),
1841 lambda indices: self.consider_new_nodes(
1842 replace_all(
1843 st.nodes,
1844 [
1845 (
1846 u,
1847 v,
1848 st.nodes[spans[i].start : spans[i].end],
1849 )
1850 for (u, v), i in zip(endpoints, indices, strict=True)
1851 ],
1852 )
1853 ),
1854 key=lambda i: sort_key(st.nodes[spans[i].start : spans[i].end]),
1855 )
1856
1857 def run_node_program(self, i, program, original, repeats=1):
1858 """Node programs are a mini-DSL for node rewriting, defined as a sequence
1859 of commands that can be run at some index into the nodes
1860
1861 Commands are:
1862
1863 * "X", delete this node
1864
1865 This method runs the node program in ``program`` at node index
1866 ``i`` on the ConjectureData ``original``. If ``repeats > 1`` then it
1867 will attempt to approximate the results of running it that many times.
1868
1869 Returns True if this successfully changes the underlying shrink target,
1870 else False.
1871 """
1872 if i + len(program) > len(original.nodes) or i < 0:
1873 return False
1874 attempt = list(original.nodes)
1875 for _ in range(repeats):
1876 for k, command in reversed(list(enumerate(program))):
1877 j = i + k
1878 if j >= len(attempt):
1879 return False
1880
1881 if command == "X":
1882 del attempt[j]
1883 else:
1884 raise NotImplementedError(f"Unrecognised command {command!r}")
1885
1886 return self.consider_new_nodes(attempt)