1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
import bisect
from typing import Union
12
13
class IntervalSet:
    """A set of integers (Unicode codepoints in practice) stored as a tuple
    of inclusive ``(start, end)`` intervals.

    Positional indexing, ``index``/``index_above`` and the shrink-order
    helpers assume the intervals are sorted and pairwise disjoint;
    ``from_string`` and ``union`` always produce intervals in that form.
    """

    @classmethod
    def from_string(cls, s):
        """Return an IntervalSet covering the codepoints of the characters in `s`.

        >>> IntervalSet.from_string('abcdef0123456789')
        IntervalSet(((48, 57), (97, 102)))
        """
        x = cls((ord(c), ord(c)) for c in sorted(s))
        # Unioning the set with itself merges duplicate and adjacent
        # single-codepoint intervals into normalized form.
        return x.union(x)

    def __init__(self, intervals):
        self.intervals = tuple(intervals)
        # offsets[k] is the position (in iteration order) of the first
        # element of intervals[k]; the final running total is the size.
        self.offsets = [0]
        for u, v in self.intervals:
            self.offsets.append(self.offsets[-1] + v - u + 1)
        self.size = self.offsets.pop()
        # Cached positions used by char_in_shrink_order and its inverse.
        self._idx_of_zero = self.index_above(ord("0"))
        self._idx_of_Z = min(self.index_above(ord("Z")), len(self) - 1)

    def __len__(self):
        return self.size

    def __iter__(self):
        for u, v in self.intervals:
            yield from range(u, v + 1)

    def __getitem__(self, i):
        """Return the element at position `i` in iteration order.

        Negative indices count from the end; out-of-range indices raise
        IndexError.
        """
        if i < 0:
            i = self.size + i
        if i < 0 or i >= self.size:
            raise IndexError(f"Invalid index {i} for [0, {self.size})")
        # j is the maximal index such that offsets[j] <= i, i.e. the interval
        # holding position i. offsets[0] == 0 <= i, so j is never negative.
        j = bisect.bisect_right(self.offsets, i) - 1
        t = i - self.offsets[j]
        u, v = self.intervals[j]
        r = u + t
        assert r <= v
        return r

    def __contains__(self, elem: Union[str, int]) -> bool:
        """Return whether `elem` (a codepoint or a one-character string) is in the set."""
        if isinstance(elem, str):
            elem = ord(elem)
        assert 0 <= elem <= 0x10FFFF
        return any(start <= elem <= end for start, end in self.intervals)

    def __repr__(self):
        return f"IntervalSet({self.intervals!r})"

    def index(self, value: int) -> int:
        """Return the position of `value` in iteration order.

        Raises ValueError if `value` is not in the set.
        """
        for offset, (u, v) in zip(self.offsets, self.intervals):
            if u == value:
                return offset
            elif u > value:
                # Intervals are sorted, so no later interval can contain it.
                raise ValueError(f"{value} is not in list")
            if value <= v:
                return offset + (value - u)
        raise ValueError(f"{value} is not in list")

    def index_above(self, value: int) -> int:
        """Return the position of the smallest element >= `value`.

        Returns len(self) if every element is below `value`.
        """
        for offset, (u, v) in zip(self.offsets, self.intervals):
            if u >= value:
                return offset
            if value <= v:
                return offset + (value - u)
        return self.size

    def __or__(self, other):
        return self.union(other)

    def __sub__(self, other):
        return self.difference(other)

    def __and__(self, other):
        return self.intersection(other)

    def __eq__(self, other):
        # Defer to the other operand for foreign types; Python then falls
        # back to identity, so ``==`` with a non-IntervalSet is still False.
        if not isinstance(other, IntervalSet):
            return NotImplemented
        return other.intervals == self.intervals

    def __hash__(self):
        return hash(self.intervals)

    def union(self, other):
        """Return the union of two IntervalSets, with merged intervals.

        Every integer contained in `self` or `other` is contained in the
        result. Overlapping and adjacent intervals are merged.

        >>> IntervalSet([(3, 10)]).union(IntervalSet([(1, 2), (5, 17)]))
        IntervalSet(((1, 17),))
        """
        assert isinstance(other, type(self))
        x = self.intervals
        y = other.intervals
        if not x:
            return IntervalSet(y)
        if not y:
            return IntervalSet(x)
        intervals = sorted(x + y, reverse=True)
        result = [intervals.pop()]
        while intervals:
            # 1. intervals is in descending order
            # 2. pop() takes from the RHS.
            # 3. (a, b) was popped 1st, then (u, v) was popped 2nd
            # 4. Therefore: a <= u
            # 5. We assume that u <= v and a <= b
            # 6. So we need to handle 2 cases of overlap, and one disjoint case
            #    |   u--v   |  u----v  |       u--v  |
            #    |  a----b  |   a--b   |  a--b       |
            u, v = intervals.pop()
            a, b = result[-1]
            if u <= b + 1:
                # Overlap (or adjacency) cases
                result[-1] = (a, max(v, b))
            else:
                # Disjoint case
                result.append((u, v))
        return IntervalSet(result)

    def difference(self, other):
        """Return an IntervalSet of the values in `self` that are not in `other`.

        Both operands' intervals are expected to be in sorted order.

        >>> IntervalSet([(1, 10)]).difference(IntervalSet([(2, 3), (9, 15)]))
        IntervalSet(((1, 1), (4, 8)))

        That is, the values 2, 3, 9 and 10 are removed from (1, 10).
        """
        assert isinstance(other, type(self))
        x = self.intervals
        y = other.intervals
        if not y:
            return IntervalSet(x)
        # Mutable copies, because left endpoints get trimmed in place below.
        x = list(map(list, x))
        i = 0
        j = 0
        result = []
        while i < len(x) and j < len(y):
            # Iterate in parallel over x and y. j stays pointing at the smallest
            # interval in the left hand side that could still overlap with some
            # element of x at index >= i.
            # Similarly, i is not incremented until we know that it does not
            # overlap with any element of y at index >= j.

            xl, xr = x[i]
            assert xl <= xr
            yl, yr = y[j]
            assert yl <= yr

            if yr < xl:
                # The interval at y[j] is strictly to the left of the interval at
                # x[i], so will not overlap with it or any later interval of x.
                j += 1
            elif yl > xr:
                # The interval at y[j] is strictly to the right of the interval at
                # x[i], so all of x[i] goes into the result as no further intervals
                # in y will intersect it.
                result.append(x[i])
                i += 1
            elif yl <= xl:
                if yr >= xr:
                    # x[i] is contained entirely in y[j], so we just skip over it
                    # without adding it to the result.
                    i += 1
                else:
                    # The beginning of x[i] is contained in y[j], so we update the
                    # left endpoint of x[i] to remove this, and increment j as we
                    # now have moved past it. Note that this is not added to the
                    # result as is, as more intervals from y may intersect it so it
                    # may need updating further.
                    x[i][0] = yr + 1
                    j += 1
            else:
                # yl > xl, so the left hand part of x[i] is not contained in y[j],
                # so there are some values we should add to the result.
                result.append((xl, yl - 1))

                if yr + 1 <= xr:
                    # If y[j] finishes before x[i] does, there may be some values
                    # in x[i] left that should go in the result (or they may be
                    # removed by a later interval in y), so we update x[i] to
                    # reflect that and increment j because it no longer overlaps
                    # with any remaining element of x.
                    x[i][0] = yr + 1
                    j += 1
                else:
                    # Every element of x[i] other than the initial part we have
                    # already added is contained in y[j], so we move to the next
                    # interval.
                    i += 1
        # Any remaining intervals in x do not overlap with any of y, as if they did
        # we would not have incremented j to the end, so can be added to the result
        # as they are.
        result.extend(x[i:])
        return IntervalSet(map(tuple, result))

    def intersection(self, other):
        """Return an IntervalSet of the values contained in both `self` and `other`."""
        assert isinstance(other, type(self)), other
        intervals = []
        i = j = 0
        while i < len(self.intervals) and j < len(other.intervals):
            u, v = self.intervals[i]
            U, V = other.intervals[j]
            if u > V:
                j += 1
            elif U > v:
                i += 1
            else:
                intervals.append((max(u, U), min(v, V)))
                # Advance whichever interval ends first; the other may still
                # overlap the next interval on the opposite side.
                if v < V:
                    i += 1
                else:
                    j += 1
        return IntervalSet(intervals)

    def char_in_shrink_order(self, i: int) -> str:
        """Return the character at position `i` of the shrink ordering."""
        # We would like it so that, where possible, shrinking replaces
        # characters with simple ascii characters, so we rejig this
        # bit so that the smallest values are 0, 1, 2, ..., Z.
        #
        # Imagine that numbers are laid out as abc0yyyZ...
        # this rearranges them so that they are laid out as
        # 0yyyZcba..., which gives a better shrinking order.
        if i <= self._idx_of_Z:
            # We want to rewrite the integers [0, n] inclusive
            # to [zero_point, Z_point].
            n = self._idx_of_Z - self._idx_of_zero
            if i <= n:
                i += self._idx_of_zero
            else:
                # We want to rewrite the integers [n + 1, Z_point] to
                # [zero_point, 0] (reversing the order so that codepoints below
                # zero_point shrink upwards).
                i = self._idx_of_zero - (i - n)
                assert i < self._idx_of_zero
            assert 0 <= i <= self._idx_of_Z

        return chr(self[i])

    def index_from_char_in_shrink_order(self, c: str) -> int:
        """
        Inverse of char_in_shrink_order: return the shrink-order position of `c`.

        Raises ValueError if `c` is not in the set.
        """
        assert len(c) == 1
        i = self.index(ord(c))

        if i <= self._idx_of_Z:
            n = self._idx_of_Z - self._idx_of_zero
            # Rewrite [zero_point, Z_point] to [0, n].
            if self._idx_of_zero <= i <= self._idx_of_Z:
                i -= self._idx_of_zero
                assert 0 <= i <= n
            # Rewrite [zero_point, 0] to [n + 1, Z_point].
            else:
                i = self._idx_of_zero - i + n
                assert n + 1 <= i <= self._idx_of_Z
            assert 0 <= i <= self._idx_of_Z

        return i