1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11from collections.abc import Iterable, Sequence
12from typing import TYPE_CHECKING, Union, cast, final
13
14if TYPE_CHECKING:
15 from typing import TypeAlias
16
17 from typing_extensions import Self
18
# Storage format used by IntervalSet: a tuple of (start, end) integer pairs.
# The annotation is quoted because TypeAlias is only imported under
# TYPE_CHECKING and therefore does not exist at runtime.
IntervalsT: "TypeAlias" = tuple[tuple[int, int], ...]
20
21
22# @final makes mypy happy with the Self return annotations. We otherwise run
23# afoul of:
24# > You should not use Self as the return annotation if the method is not
25# > guaranteed to return an instance of a subclass when the class is subclassed
26# > https://docs.python.org/3/library/typing.html#typing.Self
27
28
@final
class IntervalSet:
    """A set of integers stored as a tuple of inclusive ``(start, end)`` pairs.

    The constructor stores the intervals exactly as given; it neither sorts
    nor merges them. Positional lookups (``index``, ``index_above``) and
    ``difference`` walk the intervals assuming they are in ascending order and
    do not overlap.

    Instances compare equal and hash by their ``intervals`` tuple.
    """

    @classmethod
    def from_string(cls, s: str) -> "Self":
        """Return an IntervalSet covering the codepoints of the characters in `s`.

        >>> IntervalSet.from_string('abcdef0123456789')
        IntervalSet(((48, 57), (97, 102)))
        """
        x = cls([(ord(c), ord(c)) for c in sorted(s)])
        # Taking the union of x with itself merges duplicate and adjacent
        # single-codepoint intervals into maximal runs.
        return x.union(x)

    def __init__(self, intervals: Iterable[Sequence[int]] = ()) -> None:
        """Store `intervals` (an iterable of two-element sequences) as tuples
        and precompute the lookup tables used by the other methods.
        """
        # The type is quoted so that the alias is not evaluated at runtime on
        # every construction; typing.cast accepts a string type expression.
        self.intervals: IntervalsT = cast(
            "IntervalsT", tuple(tuple(v) for v in intervals)
        )
        # cast above is validated by this length assertion. check here instead of
        # before to not exhaust generators before we create intervals from it
        assert all(len(v) == 2 for v in self.intervals)

        # offsets[k] is the number of elements held by intervals[:k], i.e. the
        # position of the first element of intervals[k].
        self.offsets: list[int] = [0]
        for u, v in self.intervals:
            self.offsets.append(self.offsets[-1] + v - u + 1)
        # The final running total is the overall element count, not an offset.
        self.size = self.offsets.pop()
        # Position of "0", or of the first element above it.
        self._idx_of_zero = self.index_above(ord("0"))
        # Position of "Z", or of the first element above it, clamped to the
        # last valid position.
        self._idx_of_Z = min(self.index_above(ord("Z")), len(self) - 1)

    def __len__(self) -> int:
        """Return the total number of integers covered by the intervals."""
        return self.size

    def __iter__(self) -> Iterable[int]:
        """Yield every covered integer, interval by interval."""
        for u, v in self.intervals:
            yield from range(u, v + 1)

    def __getitem__(self, i: int) -> int:
        """Return the element at position `i`, counting through the intervals
        in stored order. Negative positions count from the end.

        Raises IndexError if `i` is out of range; the message reports the
        index as the caller passed it.
        """
        requested = i
        if i < 0:
            i = self.size + i
        if i < 0 or i >= self.size:
            raise IndexError(f"Invalid index {requested} for [0, {self.size})")
        # Want j = maximal such that offsets[j] <= i

        j = len(self.intervals) - 1
        if self.offsets[j] > i:
            hi = j
            lo = 0
            # Invariant: offsets[lo] <= i < offsets[hi]
            while lo + 1 < hi:
                mid = (lo + hi) // 2
                if self.offsets[mid] <= i:
                    lo = mid
                else:
                    hi = mid
            j = lo
        t = i - self.offsets[j]
        u, v = self.intervals[j]
        r = u + t
        assert r <= v
        return r

    def __contains__(self, elem: Union[str, int]) -> bool:
        """Return whether `elem` (a codepoint, or a single character that is
        converted with ``ord``) lies inside any interval.
        """
        if isinstance(elem, str):
            elem = ord(elem)
        assert 0 <= elem <= 0x10FFFF
        return any(start <= elem <= end for start, end in self.intervals)

    def __repr__(self) -> str:
        return f"IntervalSet({self.intervals!r})"

    def index(self, value: int) -> int:
        """Return the position of `value`, such that ``self[position] == value``.

        Raises ValueError if `value` is not covered by any interval.
        """
        for offset, (u, v) in zip(self.offsets, self.intervals):
            if u == value:
                return offset
            elif u > value:
                # This interval starts past value, so value fell into a gap.
                raise ValueError(f"{value} is not in list")
            if value <= v:
                return offset + (value - u)
        raise ValueError(f"{value} is not in list")

    def index_above(self, value: int) -> int:
        """Return the position of `value` if it is covered, otherwise the
        position of the first interval start that is >= `value`.

        Returns ``len(self)`` when no such element exists.
        """
        for offset, (u, v) in zip(self.offsets, self.intervals):
            if u >= value:
                return offset
            if value <= v:
                return offset + (value - u)
        return self.size

    def __or__(self, other: "Self") -> "Self":
        """``a | b`` is ``a.union(b)``."""
        return self.union(other)

    def __sub__(self, other: "Self") -> "Self":
        """``a - b`` is ``a.difference(b)``."""
        return self.difference(other)

    def __and__(self, other: "Self") -> "Self":
        """``a & b`` is ``a.intersection(b)``."""
        return self.intersection(other)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, IntervalSet) and (other.intervals == self.intervals)

    def __hash__(self) -> int:
        return hash(self.intervals)

    def union(self, other: "Self") -> "Self":
        """Merge the intervals of `self` and `other` into a single IntervalSet.

        Any integer covered by `self` or `other` is also covered by the result.
        Overlapping and directly adjacent intervals are coalesced. If either
        operand is empty, the other operand's intervals are returned as-is.

        >>> IntervalSet([(3, 10)]).union(IntervalSet([(1, 2), (5, 17)]))
        IntervalSet(((1, 17),))
        """
        assert isinstance(other, type(self))
        x = self.intervals
        y = other.intervals
        if not x:
            return IntervalSet(y)
        if not y:
            return IntervalSet(x)
        intervals = sorted(x + y, reverse=True)
        result = [intervals.pop()]
        while intervals:
            # 1. intervals is in descending order
            # 2. pop() takes from the RHS.
            # 3. (a, b) was popped 1st, then (u, v) was popped 2nd
            # 4. Therefore: a <= u
            # 5. We assume that u <= v and a <= b
            # 6. So we need to handle 2 cases of overlap, and one disjoint case
            #    |   u--v     |   u----v   |       u--v  |
            #    |   a----b   |   a--b     |  a--b       |
            u, v = intervals.pop()
            a, b = result[-1]
            if u <= b + 1:
                # Overlap cases
                result[-1] = (a, max(v, b))
            else:
                # Disjoint case
                result.append((u, v))
        return IntervalSet(result)

    def difference(self, other: "Self") -> "Self":
        """Return the set difference ``self - other`` as a new IntervalSet.

        The result covers every value covered by `self` that is not also
        covered by `other`. The intervals of both operands are expected to be
        in sorted order.

        >>> IntervalSet([(1, 10)]).difference(IntervalSet([(2, 3), (9, 15)]))
        IntervalSet(((1, 1), (4, 8)))
        """
        assert isinstance(other, type(self))
        x = self.intervals
        y = other.intervals
        if not y:
            return IntervalSet(x)
        # Work on mutable copies, because the left endpoint of x[i] is trimmed
        # in place as intervals of y eat into it.
        x = list(map(list, x))
        i = 0
        j = 0
        result: list[Iterable[int]] = []
        while i < len(x) and j < len(y):
            # Iterate in parallel over x and y. j stays pointing at the smallest
            # interval in y that could still overlap with some element of x at
            # index >= i.
            # Similarly, i is not incremented until we know that it does not
            # overlap with any element of y at index >= j.

            xl, xr = x[i]
            assert xl <= xr
            yl, yr = y[j]
            assert yl <= yr

            if yr < xl:
                # The interval at y[j] is strictly to the left of the interval at
                # x[i], so will not overlap with it or any later interval of x.
                j += 1
            elif yl > xr:
                # The interval at y[j] is strictly to the right of the interval at
                # x[i], so all of x[i] goes into the result as no further intervals
                # in y will intersect it.
                result.append(x[i])
                i += 1
            elif yl <= xl:
                if yr >= xr:
                    # x[i] is contained entirely in y[j], so we just skip over it
                    # without adding it to the result.
                    i += 1
                else:
                    # The beginning of x[i] is contained in y[j], so we update the
                    # left endpoint of x[i] to remove this, and increment j as we
                    # now have moved past it. Note that this is not added to the
                    # result as is, as more intervals from y may intersect it so it
                    # may need updating further.
                    x[i][0] = yr + 1
                    j += 1
            else:
                # yl > xl, so the left hand part of x[i] is not contained in y[j],
                # so there are some values we should add to the result.
                result.append((xl, yl - 1))

                if yr + 1 <= xr:
                    # If y[j] finishes before x[i] does, there may be some values
                    # in x[i] left that should go in the result (or they may be
                    # removed by a later interval in y), so we update x[i] to
                    # reflect that and increment j because it no longer overlaps
                    # with any remaining element of x.
                    x[i][0] = yr + 1
                    j += 1
                else:
                    # Every element of x[i] other than the initial part we have
                    # already added is contained in y[j], so we move to the next
                    # interval.
                    i += 1
        # Any remaining intervals in x do not overlap with any of y, as if they did
        # we would not have incremented j to the end, so can be added to the result
        # as they are.
        result.extend(x[i:])
        return IntervalSet(map(tuple, result))

    def intersection(self, other: "Self") -> "Self":
        """Return a new IntervalSet covering the values covered by both `self`
        and `other`.

        >>> IntervalSet([(3, 10)]).intersection(IntervalSet([(1, 2), (5, 17)]))
        IntervalSet(((5, 10),))
        """
        assert isinstance(other, type(self)), other
        intervals = []
        i = j = 0
        while i < len(self.intervals) and j < len(other.intervals):
            u, v = self.intervals[i]
            U, V = other.intervals[j]
            if u > V:
                # other's interval lies entirely below ours.
                j += 1
            elif U > v:
                # Our interval lies entirely below other's.
                i += 1
            else:
                intervals.append((max(u, U), min(v, V)))
                # Advance whichever interval ends first; the other one may
                # still overlap the next interval on the opposite side.
                if v < V:
                    i += 1
                else:
                    j += 1
        return IntervalSet(intervals)

    def char_in_shrink_order(self, i: int) -> str:
        """Return the character at position `i` of the shrink ordering, in
        which positions up to that of "Z" are remapped as described below and
        all later positions are left unchanged.
        """
        # We would like it so that, where possible, shrinking replaces
        # characters with simple ascii characters, so we rejig this
        # bit so that the smallest values are 0, 1, 2, ..., Z.
        #
        # Imagine that numbers are laid out as abc0yyyZ...
        # this rearranges them so that they are laid out as
        # 0yyyZcba..., which gives a better shrinking order.
        if i <= self._idx_of_Z:
            # We want to rewrite the integers [0, n] inclusive
            # to [zero_point, Z_point].
            n = self._idx_of_Z - self._idx_of_zero
            if i <= n:
                i += self._idx_of_zero
            else:
                # We want to rewrite the integers [n + 1, Z_point] to
                # [zero_point, 0] (reversing the order so that codepoints below
                # zero_point shrink upwards).
                i = self._idx_of_zero - (i - n)
                assert i < self._idx_of_zero
            assert 0 <= i <= self._idx_of_Z

        return chr(self[i])

    def index_from_char_in_shrink_order(self, c: str) -> int:
        """
        Inverse of char_in_shrink_order.

        Raises ValueError (from ``index``) if `c` is not in the set.
        """
        assert len(c) == 1
        i = self.index(ord(c))

        if i <= self._idx_of_Z:
            n = self._idx_of_Z - self._idx_of_zero
            # Rewrite [zero_point, Z_point] to [0, n].
            if self._idx_of_zero <= i <= self._idx_of_Z:
                i -= self._idx_of_zero
                assert 0 <= i <= n
            # Rewrite [zero_point, 0] to [n + 1, Z_point].
            else:
                i = self._idx_of_zero - i + n
                assert n + 1 <= i <= self._idx_of_Z
            assert 0 <= i <= self._idx_of_Z

        return i