1"""Sorted Set
2=============
3
4:doc:`Sorted Containers<index>` is an Apache2 licensed Python sorted
5collections library, written in pure-Python, and fast as C-extensions. The
6:doc:`introduction<introduction>` is the best way to get started.
7
8Sorted set implementations:
9
10.. currentmodule:: sortedcontainers
11
12* :class:`SortedSet`
13
14"""
15
16from itertools import chain
17from operator import eq, ne, gt, ge, lt, le
18from textwrap import dedent
19
20from .sortedlist import SortedList, recursive_repr
21
22###############################################################################
23# BEGIN Python 2/3 Shims
24###############################################################################
25
26try:
27 from collections.abc import MutableSet, Sequence, Set
28except ImportError:
29 from collections import MutableSet, Sequence, Set
30
31###############################################################################
32# END Python 2/3 Shims
33###############################################################################
34
35
class SortedSet(MutableSet, Sequence):
    """Sorted set is a sorted mutable set.

    Sorted set values are maintained in sorted order. The design of sorted set
    is simple: sorted set uses a set for set-operations and maintains a sorted
    list of values.

    Sorted set values must be hashable and comparable. The hash and total
    ordering of values must not change while they are stored in the sorted set.

    Mutable set methods:

    * :func:`SortedSet.__contains__`
    * :func:`SortedSet.__iter__`
    * :func:`SortedSet.__len__`
    * :func:`SortedSet.add`
    * :func:`SortedSet.discard`

    Sequence methods:

    * :func:`SortedSet.__getitem__`
    * :func:`SortedSet.__delitem__`
    * :func:`SortedSet.__reversed__`

    Methods for removing values:

    * :func:`SortedSet.clear`
    * :func:`SortedSet.pop`
    * :func:`SortedSet.remove`

    Set-operation methods:

    * :func:`SortedSet.difference`
    * :func:`SortedSet.difference_update`
    * :func:`SortedSet.intersection`
    * :func:`SortedSet.intersection_update`
    * :func:`SortedSet.symmetric_difference`
    * :func:`SortedSet.symmetric_difference_update`
    * :func:`SortedSet.union`
    * :func:`SortedSet.update`

    Methods for miscellany:

    * :func:`SortedSet.copy`
    * :func:`SortedSet.count`
    * :func:`SortedSet.__repr__`
    * :func:`SortedSet._check`

    Sorted list methods available:

    * :func:`SortedList.bisect_left`
    * :func:`SortedList.bisect_right`
    * :func:`SortedList.index`
    * :func:`SortedList.irange`
    * :func:`SortedList.islice`
    * :func:`SortedList._reset`

    Additional sorted list methods available, if key-function used:

    * :func:`SortedKeyList.bisect_key_left`
    * :func:`SortedKeyList.bisect_key_right`
    * :func:`SortedKeyList.irange_key`

    Sorted set comparisons use subset and superset relations. Two sorted sets
    are equal if and only if every element of each sorted set is contained in
    the other (each is a subset of the other). A sorted set is less than
    another sorted set if and only if the first sorted set is a proper subset
    of the second sorted set (is a subset, but is not equal). A sorted set is
    greater than another sorted set if and only if the first sorted set is a
    proper superset of the second sorted set (is a superset, but is not equal).

    """
    def __init__(self, iterable=None, key=None):
        """Initialize sorted set instance.

        Optional `iterable` argument provides an initial iterable of values to
        initialize the sorted set.

        Optional `key` argument defines a callable that, like the `key`
        argument to Python's `sorted` function, extracts a comparison key from
        each value. The default, none, compares values directly.

        Runtime complexity: `O(n*log(n))`

        >>> ss = SortedSet([3, 1, 2, 5, 4])
        >>> ss
        SortedSet([1, 2, 3, 4, 5])
        >>> from operator import neg
        >>> ss = SortedSet([3, 1, 2, 5, 4], neg)
        >>> ss
        SortedSet([5, 4, 3, 2, 1], key=<built-in function neg>)

        :param iterable: initial values (optional)
        :param key: function used to extract comparison key (optional)

        """
        self._key = key

        # SortedSet._fromset calls SortedSet.__init__ after initializing the
        # _set attribute. So only create a new set if the _set attribute is not
        # already present.

        if not hasattr(self, '_set'):
            self._set = set()

        # The sorted list is seeded from whatever the set already holds, so an
        # instance built by _fromset starts with both containers in agreement.

        self._list = SortedList(self._set, key=key)

        # Expose some set methods publicly.

        _set = self._set
        self.isdisjoint = _set.isdisjoint
        self.issubset = _set.issubset
        self.issuperset = _set.issuperset

        # Expose some sorted list methods publicly.

        _list = self._list
        self.bisect_left = _list.bisect_left
        self.bisect = _list.bisect
        self.bisect_right = _list.bisect_right
        self.index = _list.index
        self.irange = _list.irange
        self.islice = _list.islice
        self._reset = _list._reset

        # The key-based lookups are only exposed when a key function is used.

        if key is not None:
            self.bisect_key_left = _list.bisect_key_left
            self.bisect_key_right = _list.bisect_key_right
            self.bisect_key = _list.bisect_key
            self.irange_key = _list.irange_key

        if iterable is not None:
            self._update(iterable)


    @classmethod
    def _fromset(cls, values, key=None):
        """Initialize sorted set from existing set.

        Used internally by set operations that return a new set.

        The `values` set is adopted as the backing set without copying, so
        callers must hand over a set they no longer use themselves.

        :param values: set adopted as the backing set of the new instance
        :param key: function used to extract comparison key (optional)
        :return: new sorted set

        """
        # Bypass __init__ on construction so that _set can be installed first;
        # __init__ keeps an already present _set instead of creating a new one.
        sorted_set = object.__new__(cls)
        sorted_set._set = values
        sorted_set.__init__(key=key)
        return sorted_set


    @property
    def key(self):
        """Function used to extract comparison key from values.

        Sorted set compares values directly when the key function is none.

        """
        return self._key


    def __contains__(self, value):
        """Return true if `value` is an element of the sorted set.

        ``ss.__contains__(value)`` <==> ``value in ss``

        Runtime complexity: `O(1)`

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> 3 in ss
        True

        :param value: search for value in sorted set
        :return: true if `value` in sorted set

        """
        return value in self._set


    def __getitem__(self, index):
        """Lookup value at `index` in sorted set.

        ``ss.__getitem__(index)`` <==> ``ss[index]``

        Supports slicing.

        Runtime complexity: `O(log(n))` -- approximate.

        >>> ss = SortedSet('abcde')
        >>> ss[2]
        'c'
        >>> ss[-1]
        'e'
        >>> ss[2:5]
        ['c', 'd', 'e']

        :param index: integer or slice for indexing
        :return: value or list of values
        :raises IndexError: if index out of range

        """
        return self._list[index]


    def __delitem__(self, index):
        """Remove value at `index` from sorted set.

        ``ss.__delitem__(index)`` <==> ``del ss[index]``

        Supports slicing.

        Runtime complexity: `O(log(n))` -- approximate.

        >>> ss = SortedSet('abcde')
        >>> del ss[2]
        >>> ss
        SortedSet(['a', 'b', 'd', 'e'])
        >>> del ss[:2]
        >>> ss
        SortedSet(['d', 'e'])

        :param index: integer or slice for indexing
        :raises IndexError: if index out of range

        """
        _set = self._set
        _list = self._list
        # Look the value(s) up through the list first: an out-of-range index
        # raises here, before either container has been modified.
        if isinstance(index, slice):
            values = _list[index]
            _set.difference_update(values)
        else:
            value = _list[index]
            _set.remove(value)
        del _list[index]


    def __make_cmp(set_op, symbol, doc):
        "Make comparator method."
        def comparer(self, other):
            "Compare method for sorted set and set."
            if isinstance(other, SortedSet):
                return set_op(self._set, other._set)
            elif isinstance(other, Set):
                return set_op(self._set, other)
            # Any other operand type is left for Python to resolve.
            return NotImplemented

        set_op_name = set_op.__name__
        comparer.__name__ = '__{0}__'.format(set_op_name)
        doc_str = """Return true if and only if sorted set is {0} `other`.

        ``ss.__{1}__(other)`` <==> ``ss {2} other``

        Comparisons use subset and superset semantics as with sets.

        Runtime complexity: `O(n)`

        :param other: `other` set
        :return: true if sorted set is {0} `other`

        """
        comparer.__doc__ = dedent(doc_str.format(doc, set_op_name, symbol))
        return comparer


    # The factory is called as a plain function while the class body executes
    # and is only wrapped as a staticmethod once all comparators are built.
    __eq__ = __make_cmp(eq, '==', 'equal to')
    __ne__ = __make_cmp(ne, '!=', 'not equal to')
    __lt__ = __make_cmp(lt, '<', 'a proper subset of')
    __gt__ = __make_cmp(gt, '>', 'a proper superset of')
    __le__ = __make_cmp(le, '<=', 'a subset of')
    __ge__ = __make_cmp(ge, '>=', 'a superset of')
    __make_cmp = staticmethod(__make_cmp)


    def __len__(self):
        """Return the size of the sorted set.

        ``ss.__len__()`` <==> ``len(ss)``

        :return: size of sorted set

        """
        return len(self._set)


    def __iter__(self):
        """Return an iterator over the sorted set.

        ``ss.__iter__()`` <==> ``iter(ss)``

        Iterating the sorted set while adding or deleting values may raise a
        :exc:`RuntimeError` or fail to iterate over all values.

        """
        return iter(self._list)


    def __reversed__(self):
        """Return a reverse iterator over the sorted set.

        ``ss.__reversed__()`` <==> ``reversed(ss)``

        Iterating the sorted set while adding or deleting values may raise a
        :exc:`RuntimeError` or fail to iterate over all values.

        """
        return reversed(self._list)


    def add(self, value):
        """Add `value` to sorted set.

        If inserting `value` into the sorted list raises, the sorted set is
        left unchanged.

        Runtime complexity: `O(log(n))` -- approximate.

        >>> ss = SortedSet()
        >>> ss.add(3)
        >>> ss.add(1)
        >>> ss.add(2)
        >>> ss
        SortedSet([1, 2, 3])

        :param value: value to add to sorted set

        """
        _set = self._set
        if value not in _set:
            # Insert into the sorted list before the set: the list insertion
            # is the step that can raise (for example when `value` cannot be
            # ordered against the stored values), and doing it first leaves
            # the set untouched on failure so the two never disagree.
            self._list.add(value)
            _set.add(value)

    _add = add


    def clear(self):
        """Remove all values from sorted set.

        Runtime complexity: `O(n)`

        """
        self._set.clear()
        self._list.clear()


    def copy(self):
        """Return a shallow copy of the sorted set.

        Runtime complexity: `O(n)`

        :return: new sorted set

        """
        # _fromset adopts the set it is given, so pass a fresh copy.
        return self._fromset(set(self._set), key=self._key)

    __copy__ = copy


    def count(self, value):
        """Return number of occurrences of `value` in the sorted set.

        Runtime complexity: `O(1)`

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> ss.count(3)
        1

        :param value: value to count in sorted set
        :return: count

        """
        return 1 if value in self._set else 0


    def discard(self, value):
        """Remove `value` from sorted set if it is a member.

        If `value` is not a member, do nothing.

        Runtime complexity: `O(log(n))` -- approximate.

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> ss.discard(5)
        >>> ss.discard(0)
        >>> ss == set([1, 2, 3, 4])
        True

        :param value: `value` to discard from sorted set

        """
        _set = self._set
        if value in _set:
            _set.remove(value)
            self._list.remove(value)

    _discard = discard


    def pop(self, index=-1):
        """Remove and return value at `index` in sorted set.

        Raise :exc:`IndexError` if the sorted set is empty or index is out of
        range.

        Negative indices are supported.

        Runtime complexity: `O(log(n))` -- approximate.

        >>> ss = SortedSet('abcde')
        >>> ss.pop()
        'e'
        >>> ss.pop(2)
        'c'
        >>> ss
        SortedSet(['a', 'b', 'd'])

        :param int index: index of value (default -1)
        :return: value
        :raises IndexError: if index is out of range

        """
        # pylint: disable=arguments-differ
        value = self._list.pop(index)
        self._set.remove(value)
        return value


    def remove(self, value):
        """Remove `value` from sorted set; `value` must be a member.

        If `value` is not a member, raise :exc:`KeyError`.

        Runtime complexity: `O(log(n))` -- approximate.

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> ss.remove(5)
        >>> ss == set([1, 2, 3, 4])
        True
        >>> ss.remove(0)
        Traceback (most recent call last):
          ...
        KeyError: 0

        :param value: `value` to remove from sorted set
        :raises KeyError: if `value` is not in sorted set

        """
        # The set removal runs first so that a missing value raises KeyError
        # before the sorted list is touched.
        self._set.remove(value)
        self._list.remove(value)


    def difference(self, *iterables):
        """Return the difference of two or more sets as a new sorted set.

        The `difference` method also corresponds to operator ``-``.

        ``ss.__sub__(iterable)`` <==> ``ss - iterable``

        The difference is all values that are in this sorted set but not the
        other `iterables`.

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> ss.difference([4, 5, 6, 7])
        SortedSet([1, 2, 3])

        :param iterables: iterable arguments
        :return: new sorted set

        """
        diff = self._set.difference(*iterables)
        return self._fromset(diff, key=self._key)

    __sub__ = difference


    def difference_update(self, *iterables):
        """Remove all values of `iterables` from this sorted set.

        The `difference_update` method also corresponds to operator ``-=``.

        ``ss.__isub__(iterable)`` <==> ``ss -= iterable``

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> _ = ss.difference_update([4, 5, 6, 7])
        >>> ss
        SortedSet([1, 2, 3])

        :param iterables: iterable arguments
        :return: itself

        """
        _set = self._set
        _list = self._list
        values = set(chain(*iterables))
        # When the number of values to remove is large relative to the set
        # (more than a quarter of its size), update the set in bulk and
        # rebuild the sorted list; otherwise discard the values one by one.
        if (4 * len(values)) > len(_set):
            _set.difference_update(values)
            _list.clear()
            _list.update(_set)
        else:
            _discard = self._discard
            for value in values:
                _discard(value)
        return self

    __isub__ = difference_update


    def intersection(self, *iterables):
        """Return the intersection of two or more sets as a new sorted set.

        The `intersection` method also corresponds to operator ``&``.

        ``ss.__and__(iterable)`` <==> ``ss & iterable``

        The intersection is all values that are in this sorted set and each of
        the other `iterables`.

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> ss.intersection([4, 5, 6, 7])
        SortedSet([4, 5])

        :param iterables: iterable arguments
        :return: new sorted set

        """
        intersect = self._set.intersection(*iterables)
        return self._fromset(intersect, key=self._key)

    __and__ = intersection
    __rand__ = __and__


    def intersection_update(self, *iterables):
        """Update the sorted set with the intersection of `iterables`.

        The `intersection_update` method also corresponds to operator ``&=``.

        ``ss.__iand__(iterable)`` <==> ``ss &= iterable``

        Keep only values found in itself and all `iterables`.

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> _ = ss.intersection_update([4, 5, 6, 7])
        >>> ss
        SortedSet([4, 5])

        :param iterables: iterable arguments
        :return: itself

        """
        _set = self._set
        _list = self._list
        # Update the set in place, then rebuild the sorted list from it.
        _set.intersection_update(*iterables)
        _list.clear()
        _list.update(_set)
        return self

    __iand__ = intersection_update


    def symmetric_difference(self, other):
        """Return the symmetric difference with `other` as a new sorted set.

        The `symmetric_difference` method also corresponds to operator ``^``.

        ``ss.__xor__(other)`` <==> ``ss ^ other``

        The symmetric difference is all values that are in exactly one of the
        sets.

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> ss.symmetric_difference([4, 5, 6, 7])
        SortedSet([1, 2, 3, 6, 7])

        :param other: `other` iterable
        :return: new sorted set

        """
        diff = self._set.symmetric_difference(other)
        return self._fromset(diff, key=self._key)

    __xor__ = symmetric_difference
    __rxor__ = __xor__


    def symmetric_difference_update(self, other):
        """Update the sorted set with the symmetric difference with `other`.

        The `symmetric_difference_update` method also corresponds to operator
        ``^=``.

        ``ss.__ixor__(other)`` <==> ``ss ^= other``

        Keep only values found in exactly one of itself and `other`.

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> _ = ss.symmetric_difference_update([4, 5, 6, 7])
        >>> ss
        SortedSet([1, 2, 3, 6, 7])

        :param other: `other` iterable
        :return: itself

        """
        _set = self._set
        _list = self._list
        # Update the set in place, then rebuild the sorted list from it.
        _set.symmetric_difference_update(other)
        _list.clear()
        _list.update(_set)
        return self

    __ixor__ = symmetric_difference_update


    def union(self, *iterables):
        """Return new sorted set with values from itself and all `iterables`.

        The `union` method also corresponds to operator ``|``.

        ``ss.__or__(iterable)`` <==> ``ss | iterable``

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> ss.union([4, 5, 6, 7])
        SortedSet([1, 2, 3, 4, 5, 6, 7])

        :param iterables: iterable arguments
        :return: new sorted set

        """
        return self.__class__(chain(iter(self), *iterables), key=self._key)

    __or__ = union
    __ror__ = __or__


    def update(self, *iterables):
        """Update the sorted set adding values from all `iterables`.

        The `update` method also corresponds to operator ``|=``.

        ``ss.__ior__(iterable)`` <==> ``ss |= iterable``

        >>> ss = SortedSet([1, 2, 3, 4, 5])
        >>> _ = ss.update([4, 5, 6, 7])
        >>> ss
        SortedSet([1, 2, 3, 4, 5, 6, 7])

        :param iterables: iterable arguments
        :return: itself

        """
        _set = self._set
        _list = self._list
        values = set(chain(*iterables))
        # When the number of incoming values is large relative to the set
        # (more than a quarter of its size), update the set in bulk and
        # rebuild the sorted list; otherwise add the values one by one.
        if (4 * len(values)) > len(_set):
            _set.update(values)
            _list.clear()
            _list.update(_set)
        else:
            _add = self._add
            for value in values:
                _add(value)
        return self

    __ior__ = update
    _update = update


    def __reduce__(self):
        """Support for pickle.

        The tricks played with exposing methods in :func:`SortedSet.__init__`
        confuse pickle so customize the reducer.

        :return: tuple of the type and its `(values, key)` constructor
            arguments

        """
        return (type(self), (self._set, self._key))


    @recursive_repr()
    def __repr__(self):
        """Return string representation of sorted set.

        ``ss.__repr__()`` <==> ``repr(ss)``

        :return: string representation

        """
        _key = self._key
        # The key argument is only shown when a key function is in use.
        key = '' if _key is None else ', key={0!r}'.format(_key)
        type_name = type(self).__name__
        return '{0}({1!r}{2})'.format(type_name, list(self), key)


    def _check(self):
        """Check invariants of sorted set.

        Verifies the sorted list's own invariants, then that the set and the
        sorted list have the same length and that every listed value is in
        the set.

        Runtime complexity: `O(n)`

        """
        _set = self._set
        _list = self._list
        _list._check()
        assert len(_set) == len(_list)
        assert all(value in _set for value in _list)