1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
from bisect import bisect_right
from collections.abc import Iterable, Iterator, Sequence
from typing import TYPE_CHECKING, TypeAlias, cast, final
13
14if TYPE_CHECKING:
15 from typing_extensions import Self
16
# A tuple of ``(start, end)`` integer pairs. This is the type of
# ``IntervalSet.intervals``, where both endpoints are treated as inclusive.
IntervalsT: TypeAlias = tuple[tuple[int, int], ...]
18
19
20# @final makes mypy happy with the Self return annotations. We otherwise run
21# afoul of:
22# > You should not use Self as the return annotation if the method is not
23# > guaranteed to return an instance of a subclass when the class is subclassed
24# > https://docs.python.org/3/library/typing.html#typing.Self
25
26
27@final
28class IntervalSet:
29 """
30 A compact and efficient representation of a set of ``(a, b)`` intervals. Can
31 be treated like a set of integers, in that ``n in intervals`` will return
32 ``True`` if ``n`` is contained in any of the ``(a, b)`` intervals, and
33 ``False`` otherwise.
34 """
35
36 @classmethod
37 def from_string(cls, s: str) -> "Self":
38 """Return a tuple of intervals, covering the codepoints of characters in `s`.
39
40 >>> IntervalSet.from_string('abcdef0123456789')
41 ((48, 57), (97, 102))
42 """
43 x = cls([(ord(c), ord(c)) for c in sorted(s)])
44 return x.union(x)
45
46 def __init__(self, intervals: Iterable[Sequence[int]] = ()) -> None:
47 self.intervals: IntervalsT = cast(
48 IntervalsT, tuple(tuple(v) for v in intervals)
49 )
50 # cast above is validated by this length assertion. check here instead of
51 # before to not exhaust generators before we create intervals from it
52 assert all(len(v) == 2 for v in self.intervals)
53
54 self.offsets: list[int] = [0]
55 for u, v in self.intervals:
56 self.offsets.append(self.offsets[-1] + v - u + 1)
57 self.size = self.offsets.pop()
58 self._idx_of_zero = self.index_above(ord("0"))
59 self._idx_of_Z = min(self.index_above(ord("Z")), len(self) - 1)
60
61 def __len__(self) -> int:
62 return self.size
63
64 def __iter__(self) -> Iterable[int]:
65 for u, v in self.intervals:
66 yield from range(u, v + 1)
67
68 def __getitem__(self, i: int) -> int:
69 if i < 0:
70 i = self.size + i
71 if i < 0 or i >= self.size:
72 raise IndexError(f"Invalid index {i} for [0, {self.size})")
73 # Want j = maximal such that offsets[j] <= i
74
75 j = len(self.intervals) - 1
76 if self.offsets[j] > i:
77 hi = j
78 lo = 0
79 # Invariant: offsets[lo] <= i < offsets[hi]
80 while lo + 1 < hi:
81 mid = (lo + hi) // 2
82 if self.offsets[mid] <= i:
83 lo = mid
84 else:
85 hi = mid
86 j = lo
87 t = i - self.offsets[j]
88 u, v = self.intervals[j]
89 r = u + t
90 assert r <= v
91 return r
92
93 def __contains__(self, elem: str | int) -> bool:
94 if isinstance(elem, str):
95 elem = ord(elem)
96 assert 0 <= elem <= 0x10FFFF
97 return any(start <= elem <= end for start, end in self.intervals)
98
99 def __repr__(self) -> str:
100 return f"IntervalSet({self.intervals!r})"
101
102 def index(self, value: int) -> int:
103 for offset, (u, v) in zip(self.offsets, self.intervals, strict=True):
104 if u == value:
105 return offset
106 elif u > value:
107 raise ValueError(f"{value} is not in list")
108 if value <= v:
109 return offset + (value - u)
110 raise ValueError(f"{value} is not in list")
111
112 def index_above(self, value: int) -> int:
113 for offset, (u, v) in zip(self.offsets, self.intervals, strict=True):
114 if u >= value:
115 return offset
116 if value <= v:
117 return offset + (value - u)
118 return self.size
119
120 def __or__(self, other: "Self") -> "Self":
121 return self.union(other)
122
123 def __sub__(self, other: "Self") -> "Self":
124 return self.difference(other)
125
126 def __and__(self, other: "Self") -> "Self":
127 return self.intersection(other)
128
129 def __eq__(self, other: object) -> bool:
130 return isinstance(other, IntervalSet) and (other.intervals == self.intervals)
131
132 def __hash__(self) -> int:
133 return hash(self.intervals)
134
135 def union(self, other: "Self") -> "Self":
136 """Merge two sequences of intervals into a single tuple of intervals.
137
138 Any integer bounded by `x` or `y` is also bounded by the result.
139
140 >>> union([(3, 10)], [(1, 2), (5, 17)])
141 ((1, 17),)
142 """
143 assert isinstance(other, type(self))
144 x = self.intervals
145 y = other.intervals
146 if not x:
147 return IntervalSet(y)
148 if not y:
149 return IntervalSet(x)
150 intervals = sorted(x + y, reverse=True)
151 result = [intervals.pop()]
152 while intervals:
153 # 1. intervals is in descending order
154 # 2. pop() takes from the RHS.
155 # 3. (a, b) was popped 1st, then (u, v) was popped 2nd
156 # 4. Therefore: a <= u
157 # 5. We assume that u <= v and a <= b
158 # 6. So we need to handle 2 cases of overlap, and one disjoint case
159 # | u--v | u----v | u--v |
160 # | a----b | a--b | a--b |
161 u, v = intervals.pop()
162 a, b = result[-1]
163 if u <= b + 1:
164 # Overlap cases
165 result[-1] = (a, max(v, b))
166 else:
167 # Disjoint case
168 result.append((u, v))
169 return IntervalSet(result)
170
171 def difference(self, other: "Self") -> "Self":
172 """Set difference for lists of intervals. That is, returns a list of
173 intervals that bounds all values bounded by x that are not also bounded by
174 y. x and y are expected to be in sorted order.
175
176 For example difference([(1, 10)], [(2, 3), (9, 15)]) would
177 return [(1, 1), (4, 8)], removing the values 2, 3, 9 and 10 from the
178 interval.
179 """
180 assert isinstance(other, type(self))
181 x = self.intervals
182 y = other.intervals
183 if not y:
184 return IntervalSet(x)
185 x = list(map(list, x))
186 i = 0
187 j = 0
188 result: list[Iterable[int]] = []
189 while i < len(x) and j < len(y):
190 # Iterate in parallel over x and y. j stays pointing at the smallest
191 # interval in the left hand side that could still overlap with some
192 # element of x at index >= i.
193 # Similarly, i is not incremented until we know that it does not
194 # overlap with any element of y at index >= j.
195
196 xl, xr = x[i]
197 assert xl <= xr
198 yl, yr = y[j]
199 assert yl <= yr
200
201 if yr < xl:
202 # The interval at y[j] is strictly to the left of the interval at
203 # x[i], so will not overlap with it or any later interval of x.
204 j += 1
205 elif yl > xr:
206 # The interval at y[j] is strictly to the right of the interval at
207 # x[i], so all of x[i] goes into the result as no further intervals
208 # in y will intersect it.
209 result.append(x[i])
210 i += 1
211 elif yl <= xl:
212 if yr >= xr:
213 # x[i] is contained entirely in y[j], so we just skip over it
214 # without adding it to the result.
215 i += 1
216 else:
217 # The beginning of x[i] is contained in y[j], so we update the
218 # left endpoint of x[i] to remove this, and increment j as we
219 # now have moved past it. Note that this is not added to the
220 # result as is, as more intervals from y may intersect it so it
221 # may need updating further.
222 x[i][0] = yr + 1
223 j += 1
224 else:
225 # yl > xl, so the left hand part of x[i] is not contained in y[j],
226 # so there are some values we should add to the result.
227 result.append((xl, yl - 1))
228
229 if yr + 1 <= xr:
230 # If y[j] finishes before x[i] does, there may be some values
231 # in x[i] left that should go in the result (or they may be
232 # removed by a later interval in y), so we update x[i] to
233 # reflect that and increment j because it no longer overlaps
234 # with any remaining element of x.
235 x[i][0] = yr + 1
236 j += 1
237 else:
238 # Every element of x[i] other than the initial part we have
239 # already added is contained in y[j], so we move to the next
240 # interval.
241 i += 1
242 # Any remaining intervals in x do not overlap with any of y, as if they did
243 # we would not have incremented j to the end, so can be added to the result
244 # as they are.
245 result.extend(x[i:])
246 return IntervalSet(map(tuple, result))
247
248 def intersection(self, other: "Self") -> "Self":
249 """Set intersection for lists of intervals."""
250 assert isinstance(other, type(self)), other
251 intervals = []
252 i = j = 0
253 while i < len(self.intervals) and j < len(other.intervals):
254 u, v = self.intervals[i]
255 U, V = other.intervals[j]
256 if u > V:
257 j += 1
258 elif U > v:
259 i += 1
260 else:
261 intervals.append((max(u, U), min(v, V)))
262 if v < V:
263 i += 1
264 else:
265 j += 1
266 return IntervalSet(intervals)
267
268 def char_in_shrink_order(self, i: int) -> str:
269 # We would like it so that, where possible, shrinking replaces
270 # characters with simple ascii characters, so we rejig this
271 # bit so that the smallest values are 0, 1, 2, ..., Z.
272 #
273 # Imagine that numbers are laid out as abc0yyyZ...
274 # this rearranges them so that they are laid out as
275 # 0yyyZcba..., which gives a better shrinking order.
276 if i <= self._idx_of_Z:
277 # We want to rewrite the integers [0, n] inclusive
278 # to [zero_point, Z_point].
279 n = self._idx_of_Z - self._idx_of_zero
280 if i <= n:
281 i += self._idx_of_zero
282 else:
283 # We want to rewrite the integers [n + 1, Z_point] to
284 # [zero_point, 0] (reversing the order so that codepoints below
285 # zero_point shrink upwards).
286 i = self._idx_of_zero - (i - n)
287 assert i < self._idx_of_zero
288 assert 0 <= i <= self._idx_of_Z
289
290 return chr(self[i])
291
292 def index_from_char_in_shrink_order(self, c: str) -> int:
293 """
294 Inverse of char_in_shrink_order.
295 """
296 assert len(c) == 1
297 i = self.index(ord(c))
298
299 if i <= self._idx_of_Z:
300 n = self._idx_of_Z - self._idx_of_zero
301 # Rewrite [zero_point, Z_point] to [0, n].
302 if self._idx_of_zero <= i <= self._idx_of_Z:
303 i -= self._idx_of_zero
304 assert 0 <= i <= n
305 # Rewrite [zero_point, 0] to [n + 1, Z_point].
306 else:
307 i = self._idx_of_zero - i + n
308 assert n + 1 <= i <= self._idx_of_Z
309 assert 0 <= i <= self._idx_of_Z
310
311 return i