import math

from collections import Counter, defaultdict, deque
from collections.abc import Sequence
from contextlib import suppress
from functools import cached_property, partial, reduce, wraps
from heapq import heapify, heapreplace
from itertools import (
    chain,
    combinations,
    compress,
    count,
    cycle,
    dropwhile,
    groupby,
    islice,
    permutations,
    product,
    repeat,
    starmap,
    takewhile,
    tee,
    zip_longest,
)
from math import (
    ceil,
    comb,
    e,
    exp,
    factorial,
    floor,
    fsum,
    log,
    log1p,
    perm,
    tau,
)
from queue import Empty, Queue
from random import random, randrange, shuffle, uniform
from operator import (
    attrgetter,
    getitem,
    gt,
    is_not,
    itemgetter,
    lt,
    mul,
    neg,
    sub,
)
from sys import maxsize
from threading import Lock
from time import monotonic

from .recipes import (
    _marker,
    all_equal,
    batched,
    consume,
    first_true,
    flatten,
    is_prime,
    nth,
    powerset,
    sieve,
    take,
    unique_everseen,
)


__all__ = [
    'AbortThread',
    'SequenceView',
    'adjacent',
    'all_unique',
    'always_iterable',
    'always_reversible',
    'argmax',
    'argmin',
    'bucket',
    'callback_iter',
    'chunked',
    'chunked_even',
    'circular_shifts',
    'classify_unique',
    'collapse',
    'combination_index',
    'combination_with_replacement_index',
    'concurrent_tee',
    'consecutive_groups',
    'constrained_batches',
    'consumer',
    'count_cycle',
    'countable',
    'derangements',
    'dft',
    'difference',
    'distinct_combinations',
    'distinct_permutations',
    'distribute',
    'divide',
    'doublestarmap',
    'duplicates_everseen',
    'duplicates_justseen',
    'exactly_n',
    'extract',
    'filter_except',
    'filter_map',
    'first',
    'gray_product',
    'groupby_transform',
    'ichunked',
    'idft',
    'iequals',
    'ilen',
    'interleave',
    'interleave_evenly',
    'interleave_longest',
    'interleave_randomly',
    'intersperse',
    'is_sorted',
    'islice_extended',
    'iter_suppress',
    'iterate',
    'join_mappings',
    'last',
    'locate',
    'longest_common_prefix',
    'lstrip',
    'make_decorator',
    'map_except',
    'map_if',
    'map_reduce',
    'mark_ends',
    'minmax',
    'nth_combination_with_replacement',
    'nth_or_last',
    'nth_permutation',
    'nth_prime',
    'nth_product',
    'numeric_range',
    'one',
    'only',
    'outer_product',
    'padded',
    'partial_product',
    'partitions',
    'peekable',
    'permutation_index',
    'powerset_of_sets',
    'product_index',
    'raise_',
    'repeat_each',
    'repeat_last',
    'replace',
    'rlocate',
    'rstrip',
    'run_length',
    'sample',
    'seekable',
    'serialize',
    'set_partitions',
    'side_effect',
    'sliced',
    'sort_together',
    'split_after',
    'split_at',
    'split_before',
    'split_into',
    'split_when',
    'spy',
    'stagger',
    'strictly_n',
    'strip',
    'substrings',
    'substrings_indexes',
    'synchronized',
    'takewhile_inclusive',
    'time_limited',
    'unique_in_window',
    'unique_to_each',
    'unzip',
    'value_chain',
    'windowed',
    'windowed_complete',
    'with_iter',
    'zip_broadcast',
    'zip_offset',
]


# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _fsumprod

except ImportError:  # pragma: no cover
    # Extended precision algorithms from T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8). Code: mul12()

    def dl_split(x: float):
        "Split a float into two half-precision components."
        t = x * 134217729.0  # Veltkamp constant = 2.0 ** 27 + 1
        hi = t - (t - x)
        lo = x - hi
        return hi, lo

    def dl_mul(x, y):
        "Lossless multiplication."
        xx_hi, xx_lo = dl_split(x)
        yy_hi, yy_lo = dl_split(y)
        p = xx_hi * yy_hi
        q = xx_hi * yy_lo + xx_lo * yy_hi
        z = p + q
        zz = p - z + q + xx_lo * yy_lo
        return z, zz

    def _fsumprod(p, q):
        return fsum(chain.from_iterable(map(dl_mul, p, q)))
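
# Illustrative sketch, not part of the library API: whichever branch
# supplies _fsumprod, it preserves low-order digits that a naive dot
# product loses to rounding. For example:
#
#     >>> p, q = [1e20, 1.0, -1e20], [1.0, 1.0, 1.0]
#     >>> sum(x * y for x, y in zip(p, q))  # 1.0 is absorbed by 1e20
#     0.0
#     >>> _fsumprod(p, q)  # exact products, compensated summation
#     1.0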


def chunked(iterable, n, strict=False):
    """Break *iterable* into lists of length *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
    [[1, 2, 3], [4, 5, 6]]

    By default, the last yielded list will have fewer than *n* elements
    if the length of *iterable* is not divisible by *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
    [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    If the length of *iterable* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    list is yielded.

    """
    iterator = iter(partial(take, n, iter(iterable)), [])
    if strict:
        if n is None:
            raise ValueError('n must not be None when using strict mode.')

        def ret():
            for chunk in iterator:
                if len(chunk) != n:
                    raise ValueError('iterable is not divisible by n.')
                yield chunk

        return ret()
    else:
        return iterator


def first(iterable, default=_marker):
    """Return the first item of *iterable*, or *default* if *iterable* is
    empty.

    >>> first([0, 1, 2, 3])
    0
    >>> first([], 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.

    :func:`first` is useful when you have a generator of expensive-to-retrieve
    values and want any arbitrary one. It is marginally shorter than
    ``next(iter(iterable), default)``.

    """
    for item in iterable:
        return item
    if default is _marker:
        raise ValueError(
            'first() was called on an empty iterable, '
            'and no default value was provided.'
        )
    return default


def last(iterable, default=_marker):
    """Return the last item of *iterable*, or *default* if *iterable* is
    empty.

    >>> last([0, 1, 2, 3])
    3
    >>> last([], 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    try:
        if isinstance(iterable, Sequence):
            return iterable[-1]
        # Work around https://bugs.python.org/issue38525
        if getattr(iterable, '__reversed__', None):
            return next(reversed(iterable))
        return deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        if default is _marker:
            raise ValueError(
                'last() was called on an empty iterable, '
                'and no default value was provided.'
            )
        return default


def nth_or_last(iterable, n, default=_marker):
    """Return the nth or the last item of *iterable*,
    or *default* if *iterable* is empty.

    >>> nth_or_last([0, 1, 2, 3], 2)
    2
    >>> nth_or_last([0, 1], 2)
    1
    >>> nth_or_last([], 0, 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    return last(islice(iterable, n + 1), default=default)


class peekable:
    """Wrap an iterator to allow lookahead and prepending elements.

    Call :meth:`peek` on the result to get the value that will be returned
    by :func:`next`. This won't advance the iterator:

    >>> p = peekable(['a', 'b'])
    >>> p.peek()
    'a'
    >>> next(p)
    'a'

    Pass :meth:`peek` a default value to return that instead of raising
    ``StopIteration`` when the iterator is exhausted.

    >>> p = peekable([])
    >>> p.peek('hi')
    'hi'

    peekables also offer a :meth:`prepend` method, which "inserts" items
    at the head of the iterable:

    >>> p = peekable([1, 2, 3])
    >>> p.prepend(10, 11, 12)
    >>> next(p)
    10
    >>> p.peek()
    11
    >>> list(p)
    [11, 12, 1, 2, 3]

    peekables can be indexed. Index 0 is the item that will be returned by
    :func:`next`, index 1 is the item after that, and so on.
    The values up to the given index will be cached.

    >>> p = peekable(['a', 'b', 'c', 'd'])
    >>> p[0]
    'a'
    >>> p[1]
    'b'
    >>> next(p)
    'a'

    Negative indexes are supported, but be aware that they will cache the
    remaining items in the source iterator, which may require significant
    storage.

    To check whether a peekable is exhausted, check its truth value:

    >>> p = peekable(['a', 'b'])
    >>> if p:  # peekable has items
    ...     list(p)
    ['a', 'b']
    >>> if not p:  # peekable is exhausted
    ...     list(p)
    []

    """

    def __init__(self, iterable):
        self._it = iter(iterable)
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        """Return the item that will be next returned from ``next()``.

        Return ``default`` if there are no items left. If ``default`` is not
        provided, raise ``StopIteration``.

        """
        if not self._cache:
            try:
                self._cache.append(next(self._it))
            except StopIteration:
                if default is _marker:
                    raise
                return default
        return self._cache[0]

    def prepend(self, *items):
        """Stack up items to be the next ones returned from ``next()`` or
        ``self.peek()``. The items will be returned in
        first in, first out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        It is possible, by prepending items, to "resurrect" a peekable that
        previously raised ``StopIteration``.

            >>> p = peekable([])
            >>> next(p)
            Traceback (most recent call last):
            ...
            StopIteration
            >>> p.prepend(1)
            >>> next(p)
            1
            >>> next(p)
            Traceback (most recent call last):
            ...
            StopIteration

        """
        self._cache.extendleft(reversed(items))

    def __next__(self):
        if self._cache:
            return self._cache.popleft()

        return next(self._it)

    def _get_slice(self, index):
        # Normalize the slice's arguments
        step = 1 if (index.step is None) else index.step
        if step > 0:
            start = 0 if (index.start is None) else index.start
            stop = maxsize if (index.stop is None) else index.stop
        elif step < 0:
            start = -1 if (index.start is None) else index.start
            stop = (-maxsize - 1) if (index.stop is None) else index.stop
        else:
            raise ValueError('slice step cannot be zero')

        # If either the start or stop index is negative, we'll need to cache
        # the rest of the iterable in order to slice from the right side.
        if (start < 0) or (stop < 0):
            self._cache.extend(self._it)
        # Otherwise we'll need to find the rightmost index and cache to that
        # point.
        else:
            n = min(max(start, stop) + 1, maxsize)
            cache_len = len(self._cache)
            if n >= cache_len:
                self._cache.extend(islice(self._it, n - cache_len))

        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        cache_len = len(self._cache)
        if index < 0:
            self._cache.extend(self._it)
        elif index >= cache_len:
            self._cache.extend(islice(self._it, index + 1 - cache_len))

        return self._cache[index]


def consumer(func):
    """Decorator that automatically advances a PEP-342-style "reverse iterator"
    to its first yield point so you don't have to call ``next()`` on it
    manually.

    >>> @consumer
    ... def tally():
    ...     i = 0
    ...     while True:
    ...         print('Thing number %s is %s.' % (i, (yield)))
    ...         i += 1
    ...
    >>> t = tally()
    >>> t.send('red')
    Thing number 0 is red.
    >>> t.send('fish')
    Thing number 1 is fish.

    Without the decorator, you would have to call ``next(t)`` before
    ``t.send()`` could be used.

    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        gen = func(*args, **kwargs)
        next(gen)
        return gen

    return wrapper


def ilen(iterable):
    """Return the number of items in *iterable*.

    For example, there are 168 prime numbers below 1,000:

    >>> ilen(sieve(1000))
    168

    Equivalent to, but faster than::

        def ilen(iterable):
            count = 0
            for _ in iterable:
                count += 1
            return count

    This fully consumes the iterable, so handle with care.

    """
    # This is the "most beautiful of the fast variants" of this function.
    # If you think you can improve on it, please ensure that your version
    # is both 10x faster and 10x more beautiful.
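    # zip() wraps each item in a 1-tuple, which is always truthy, so
    # compress() emits one 1 per item and sum() totals them without
    # building an intermediate list.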
    return sum(compress(repeat(1), zip(iterable)))


def iterate(func, start):
    """Return ``start``, ``func(start)``, ``func(func(start))``, ...

    Produces an infinite iterator. To add a stopping condition,
    use :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`:

    >>> take(10, iterate(lambda x: 2 * x, 1))
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    >>> collatz = lambda x: 3 * x + 1 if x % 2 == 1 else x // 2
    >>> list(takewhile_inclusive(lambda x: x != 1, iterate(collatz, 10)))
    [10, 5, 16, 8, 4, 2, 1]

    """
    with suppress(StopIteration):
        while True:
            yield start
            start = func(start)


def with_iter(context_manager):
    """Wrap an iterable in a ``with`` statement, so it closes once exhausted.

    For example, this will close the file when the iterator is exhausted::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Any context manager which returns an iterable is a candidate for
    ``with_iter``.

    """
    with context_manager as iterable:
        yield from iterable


def one(iterable, too_short=None, too_long=None):
    """Return the first item from *iterable*, which is expected to contain only
    that item. Raise an exception if *iterable* is empty or has more than one
    item.

    :func:`one` is useful for ensuring that an iterable contains only one item.
    For example, it can be used to retrieve the result of a database query
    that is expected to return a single row.

    If *iterable* is empty, ``ValueError`` will be raised. You may specify a
    different exception with the *too_short* keyword:

    >>> it = []
    >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: too few items in iterable (expected 1)
    >>> too_short = IndexError('too few items')
    >>> one(it, too_short=too_short)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    IndexError: too few items

    Similarly, if *iterable* contains more than one item, ``ValueError`` will
    be raised. You may specify a different exception with the *too_long*
    keyword:

    >>> it = ['too', 'many']
    >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 'too',
    'many', and perhaps more.
    >>> too_long = RuntimeError
    >>> one(it, too_long=too_long)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    RuntimeError

    Note that :func:`one` attempts to advance *iterable* twice to ensure there
    is only one item. See :func:`spy` or :func:`peekable` to check iterable
    contents less destructively.

    """
    iterator = iter(iterable)
    for first in iterator:
        for second in iterator:
            msg = (
                f'Expected exactly one item in iterable, but got {first!r}, '
                f'{second!r}, and perhaps more.'
            )
            raise too_long or ValueError(msg)
        return first
    raise too_short or ValueError('too few items in iterable (expected 1)')


def raise_(exception, *args):
    raise exception(*args)


def strictly_n(iterable, n, too_short=None, too_long=None):
    """Validate that *iterable* has exactly *n* items and return them if
    it does. If it has fewer than *n* items, call function *too_short*
    with the actual number of items. If it has more than *n* items, call
    function *too_long* with the number ``n + 1``.

    >>> iterable = ['a', 'b', 'c', 'd']
    >>> n = 4
    >>> list(strictly_n(iterable, n))
    ['a', 'b', 'c', 'd']

    Note that the returned iterable must be consumed in order for the check to
    be made.

    By default, *too_short* and *too_long* are functions that raise
    ``ValueError``.

    >>> list(strictly_n('ab', 3))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: too few items in iterable (got 2)

    >>> list(strictly_n('abc', 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: too many items in iterable (got at least 3)

    You can instead supply functions that do something else.
    *too_short* will be called with the number of items in *iterable*.
    *too_long* will be called with ``n + 1``.

    >>> def too_short(item_count):
    ...     raise RuntimeError
    >>> it = strictly_n('abcd', 6, too_short=too_short)
    >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    RuntimeError

    >>> def too_long(item_count):
    ...     print('The boss is going to hear about this')
    >>> it = strictly_n('abcdef', 4, too_long=too_long)
    >>> list(it)
    The boss is going to hear about this
    ['a', 'b', 'c', 'd']

    """
    if too_short is None:
        too_short = lambda item_count: raise_(
            ValueError,
            f'Too few items in iterable (got {item_count})',
        )

    if too_long is None:
        too_long = lambda item_count: raise_(
            ValueError,
            f'Too many items in iterable (got at least {item_count})',
        )

    it = iter(iterable)

    sent = 0
    for item in islice(it, n):
        yield item
        sent += 1

    if sent < n:
        too_short(sent)
        return

    for item in it:
        too_long(n + 1)
        return


def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

    >>> sorted(distinct_permutations([1, 0, 1]))
    [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

    >>> sorted(distinct_permutations([1, 0, 1], r=2))
    [(0, 1), (1, 0), (1, 1)]
    >>> sorted(distinct_permutations(range(3), r=2))
    [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

    >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
    [
        (1, True, '3'),
        (1, '3', True),
        ('3', 1, True)
    ]
    >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
    [
        (1, 2, '3'),
        (1, '3', 2),
        (2, 1, '3'),
        (2, '3', 1),
        ('3', 1, 2),
        ('3', 2, 1)
    ]
    """

    # Algorithm: https://w.wiki/Qai
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than i such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above
    def _partial(A, r):
        # Split A into the first r items and the remaining items
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    try:
        items.sort()
        sortable = True
    except TypeError:
        sortable = False
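
        # The items aren't sortable, so permute their indices instead.
        # items.index(item) returns the position of the first item equal
        # to a given one, so equal items collapse onto a single index and
        # duplicate permutations disappear.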
        indices_dict = defaultdict(list)

        for item in items:
            indices_dict[items.index(item)].append(item)

        indices = [items.index(item) for item in items]
        indices.sort()

        equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

        def permuted_items(permuted_indices):
            return tuple(
                next(equivalent_items[index]) for index in permuted_indices
            )

    size = len(items)
    if r is None:
        r = size

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if 0 < r <= size:
        if sortable:
            return algorithm(items)
        else:
            return (
                permuted_items(permuted_indices)
                for permuted_indices in algorithm(indices)
            )

    return iter(() if r else ((),))


def derangements(iterable, r=None):
    """Yield successive derangements of the elements in *iterable*.

    A derangement is a permutation in which no element appears at its
    original index. In other words, a derangement is a permutation that has
    no fixed points.

    Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
    The code below outputs all of the different ways to assign gift recipients
    such that nobody is assigned to himself or herself:

    >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
    ...     print(', '.join(d))
    Bob, Alice, Dave, Carol
    Bob, Carol, Dave, Alice
    Bob, Dave, Alice, Carol
    Carol, Alice, Dave, Bob
    Carol, Dave, Alice, Bob
    Carol, Dave, Bob, Alice
    Dave, Alice, Bob, Carol
    Dave, Carol, Alice, Bob
    Dave, Carol, Bob, Alice

    If *r* is given, only the *r*-length derangements are yielded.

    >>> sorted(derangements(range(3), 2))
    [(1, 0), (1, 2), (2, 0)]
    >>> sorted(derangements([0, 2, 3], 2))
    [(2, 0), (2, 3), (3, 0)]

    Elements are treated as unique based on their position, not on their
    value.

    Consider the Secret Santa example with two *different* people who have
    the *same* name. Then there are two valid gift assignments even though
    it might appear that a person is assigned to themselves:

    >>> names = ['Alice', 'Bob', 'Bob']
    >>> list(derangements(names))
    [('Bob', 'Bob', 'Alice'), ('Bob', 'Alice', 'Bob')]

    To avoid confusion, make the inputs distinct:

    >>> deduped = [f'{name}{index}' for index, name in enumerate(names)]
    >>> list(derangements(deduped))
    [('Bob1', 'Bob2', 'Alice0'), ('Bob2', 'Alice0', 'Bob1')]

    The number of derangements of a set of size *n* is known as the
    "subfactorial of n". For n > 0, the subfactorial is
    ``round(math.factorial(n) / math.e)``.

    References:

    * Article: https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics
    * Sizes: https://oeis.org/A000166
    """
    xs = tuple(iterable)
    ys = tuple(range(len(xs)))
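    # Permute the items and their index positions in lockstep, and keep a
    # permutation only when every permuted position differs from its
    # original index (i.e. the permutation has no fixed points).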
    return compress(
        permutations(xs, r=r),
        map(all, map(map, repeat(is_not), repeat(ys), permutations(ys, r=r))),
    )


def intersperse(e, iterable, n=1):
    """Intersperse filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

    >>> list(intersperse('!', [1, 2, 3, 4, 5]))
    [1, '!', 2, '!', 3, '!', 4, '!', 5]

    >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
    [1, 2, None, 3, 4, None, 5]

    """
    if n == 0:
        raise ValueError('n must be > 0')
    elif n == 1:
        # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2...
        # islice(..., 1, None) -> x_0, e, x_1, e, x_2...
        return islice(interleave(repeat(e), iterable), 1, None)
    else:
        # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]...
        # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]...
        # flatten(...) -> x_0, x_1, e, x_2, x_3...
        filler = repeat([e])
        chunks = chunked(iterable, n)
        return flatten(islice(interleave(filler, chunks), 1, None))


def unique_to_each(*iterables):
    """Return the elements from each of the input iterables that aren't in the
    other input iterables.

    For example, suppose you have a set of packages, each with a set of
    dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    If you remove one package, which dependencies can also be removed?

    If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
    associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for
    ``pkg_2``, and ``D`` is only needed for ``pkg_3``::

        >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
        [['A'], ['C'], ['D']]

    If there are duplicates in one input iterable that aren't in the others
    they will be duplicated in the output. Input order is preserved::

        >>> unique_to_each("mississippi", "missouri")
        [['p', 'p'], ['o', 'u', 'r']]

    It is assumed that the elements of each iterable are hashable.

    """
    pool = [list(it) for it in iterables]
    counts = Counter(chain.from_iterable(map(set, pool)))
    uniques = {element for element in counts if counts[element] == 1}
    return [list(filter(uniques.__contains__, it)) for it in pool]


def windowed(seq, n, fillvalue=None, step=1):
    """Return a sliding window of width *n* over the given iterable.

    >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
    >>> list(all_windows)
    [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    When the window is larger than the iterable, *fillvalue* is used in place
    of missing values:

    >>> list(windowed([1, 2, 3], 4))
    [(1, 2, 3, None)]

    Each window will advance in increments of *step*:

    >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
    [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler items
    to the left:

    >>> iterable = [1, 2, 3, 4]
    >>> n = 3
    >>> padding = [None] * (n - 1)
    >>> list(windowed(chain(padding, iterable), 3))
    [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    if n < 0:
        raise ValueError('n must be >= 0')
    if n == 0:
        yield ()
        return
    if step < 1:
        raise ValueError('step must be >= 1')

    iterator = iter(seq)

    # Generate first window
    window = deque(islice(iterator, n), maxlen=n)

    # Deal with the first window not being full
    if not window:
        return
    if len(window) < n:
        yield tuple(window) + ((fillvalue,) * (n - len(window)))
        return
    yield tuple(window)

    # Create the filler for the next windows. The padding ensures
    # we have just enough elements to fill the last window.
    padding = (fillvalue,) * (n - 1 if step >= n else step - 1)
    filler = map(window.append, chain(iterator, padding))

    # Generate the rest of the windows
    for _ in islice(filler, step - 1, None, step):
        yield tuple(window)


def substrings(iterable):
    """Yield all of the substrings of *iterable*.

    >>> [''.join(s) for s in substrings('more')]
    ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Note that non-string iterables can also be subdivided.

    >>> list(substrings([0, 1, 2]))
    [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    Like :func:`subslices`, but returns tuples instead of lists
    and yields the shortest substrings first.

    """
    seq = tuple(iterable)
    item_count = len(seq)
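    # For each substring length n, build the slice objects
    # (0, n), (1, n + 1), ... and apply each one to the cached sequence.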
    for n in range(1, item_count + 1):
        slices = map(slice, range(item_count), range(n, item_count + 1))
        yield from map(getitem, repeat(seq), slices)


def substrings_indexes(seq, reverse=False):
    """Yield all substrings and their positions in *seq*

    The items yielded will be a tuple of the form ``(substr, i, j)``, where
    ``substr == seq[i:j]``.

    This function only works for iterables that support slicing, such as
    ``str`` objects.

    >>> for item in substrings_indexes('more'):
    ...     print(item)
    ('m', 0, 1)
    ('o', 1, 2)
    ('r', 2, 3)
    ('e', 3, 4)
    ('mo', 0, 2)
    ('or', 1, 3)
    ('re', 2, 4)
    ('mor', 0, 3)
    ('ore', 1, 4)
    ('more', 0, 4)

    Set *reverse* to ``True`` to yield the same items in the opposite order.

    """
    r = range(1, len(seq) + 1)
    if reverse:
        r = reversed(r)
    return (
        (seq[i : i + L], i, i + L) for L in r for i in range(len(seq) - L + 1)
    )


class bucket:
    """Wrap *iterable* and return an object that buckets the iterable into
    child iterables based on a *key* function.

    >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
    >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
    >>> sorted(list(s))  # Get the keys
    ['a', 'b', 'c']
    >>> a_iterable = s['a']
    >>> next(a_iterable)
    'a1'
    >>> next(a_iterable)
    'a2'
    >>> list(s['b'])
    ['b1', 'b2', 'b3']

    The original iterable will be advanced and its items will be cached until
    they are used by the child iterables. This may require significant storage.

    By default, attempting to select a bucket to which no items belong will
    exhaust the iterable and cache all values.
    If you specify a *validator* function, selected buckets will instead be
    checked against it.

    >>> from itertools import count
    >>> it = count(1, 2)  # Infinite sequence of odd numbers
    >>> key = lambda x: x % 10  # Bucket by last digit
    >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
    >>> s = bucket(it, key=key, validator=validator)
    >>> 2 in s
    False
    >>> list(s[2])
    []

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        self._cache = defaultdict(deque)
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        if not self._validator(value):
            return False

        try:
            item = next(self[value])
        except StopIteration:
            return False
        else:
            self._cache[value].appendleft(item)

        return True

    def _get_values(self, value):
        """
        Helper to yield items from the parent iterator that match *value*.

        Items that don't match are stored in the local cache as they
        are encountered.
        """
        while True:
            # If we've cached some items that match the target value, emit
            # the first one and evict it from the cache.
            if self._cache[value]:
                yield self._cache[value].popleft()
            # Otherwise we need to advance the parent iterator to search for
            # a matching item, caching the rest.
            else:
                while True:
                    try:
                        item = next(self._it)
                    except StopIteration:
                        return
                    item_value = self._key(item)
                    if item_value == value:
                        yield item
                        break
                    elif self._validator(item_value):
                        self._cache[item_value].append(item)

    def __iter__(self):
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        if not self._validator(value):
            return iter(())

        return self._get_values(value)


def spy(iterable, n=1):
    """Return a 2-tuple with a list containing the first *n* elements of
    *iterable*, and an iterator with the same items as *iterable*.
    This allows you to "look ahead" at the items in the iterable without
    advancing it.

    There is one item in the list by default:

    >>> iterable = 'abcdefg'
    >>> head, iterable = spy(iterable)
    >>> head
    ['a']
    >>> list(iterable)
    ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    You may use unpacking to retrieve items instead of lists:

    >>> (head,), iterable = spy('abcdefg')
    >>> head
    'a'
    >>> (first, second), iterable = spy('abcdefg', 2)
    >>> first
    'a'
    >>> second
    'b'

    The number of items requested can be larger than the number of items in
    the iterable:

    >>> iterable = [1, 2, 3, 4, 5]
    >>> head, iterable = spy(iterable, 10)
    >>> head
    [1, 2, 3, 4, 5]
    >>> list(iterable)
    [1, 2, 3, 4, 5]

    """
    p, q = tee(iterable)
    return take(n, q), p


def interleave(*iterables):
    """Return a new iterable yielding from each iterable in turn,
    until the shortest is exhausted.

    >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7]

    For a version that doesn't terminate after the shortest iterable is
    exhausted, see :func:`interleave_longest`.

    """
    return chain.from_iterable(zip(*iterables))


def interleave_longest(*iterables):
    """Return a new iterable yielding from each iterable in turn,
    skipping any that are exhausted.

    >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7, 3, 8]

    This function produces the same output as :func:`roundrobin`, but may
    perform better for some inputs (in particular when the number of iterables
    is large).

    """
    for xs in zip_longest(*iterables, fillvalue=_marker):
        for x in xs:
            if x is not _marker:
                yield x


def interleave_evenly(iterables, lengths=None):
    """
    Interleave multiple iterables so that their elements are evenly
    distributed throughout the output sequence.

    >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
    >>> list(interleave_evenly(iterables))
    [1, 2, 'a', 3, 4, 'b', 5]

    >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
    >>> list(interleave_evenly(iterables))
    [1, 6, 4, 2, 7, 3, 8, 5]

    This function requires iterables of known length. Iterables without
    ``__len__()`` can be used by manually specifying lengths with *lengths*:

    >>> from itertools import combinations, repeat
    >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
    >>> lengths = [4 * (4 - 1) // 2, 3]
    >>> list(interleave_evenly(iterables, lengths=lengths))
    [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    Based on Bresenham's algorithm.
    """
    if lengths is None:
        try:
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )
    elif len(iterables) != len(lengths):
        raise ValueError('Mismatching number of iterables and lengths.')

    dims = len(lengths)

    # sort iterables by length, descending
    lengths_permute = sorted(
        range(dims), key=lambda i: lengths[i], reverse=True
    )
    lengths_desc = [lengths[i] for i in lengths_permute]
    iters_desc = [iter(iterables[i]) for i in lengths_permute]

    # the longest iterable is the primary one (Bresenham: the longest
    # distance along an axis)
    delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
    iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
    errors = [delta_primary // dims] * len(deltas_secondary)

    to_yield = sum(lengths)
    while to_yield:
        yield next(iter_primary)
        to_yield -= 1
        # update errors for each secondary iterable
        errors = [e - delta for e, delta in zip(errors, deltas_secondary)]

        # those iterables for which the error is negative are yielded
        # ("diagonal step" in Bresenham)
        for i, e_ in enumerate(errors):
            if e_ < 0:
                yield next(iters_secondary[i])
                to_yield -= 1
                errors[i] += delta_primary


def interleave_randomly(*iterables):
    """Repeatedly select one of the input *iterables* at random and yield the
    next item from it.

    >>> iterables = [1, 2, 3], 'abc', (True, False, None)
    >>> list(interleave_randomly(*iterables))  # doctest: +SKIP
    ['a', 'b', 1, 'c', True, False, None, 2, 3]

    The relative order of the items in each input iterable is preserved. Note
    that the possible output sequences are not all equally likely to be
    generated.

    """
    iterators = [iter(e) for e in iterables]
    while iterators:
        idx = randrange(len(iterators))
        try:
            yield next(iterators[idx])
        except StopIteration:
            # equivalent to `list.pop` but slightly faster
            iterators[idx] = iterators[-1]
            del iterators[-1]


def collapse(iterable, base_type=None, levels=None):
    """Flatten an iterable with multiple levels of nesting (e.g., a list of
    lists of tuples) into non-iterable types.

    >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
    >>> list(collapse(iterable))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and
    will not be collapsed.

    To avoid collapsing other types, specify *base_type*:

    >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
    >>> list(collapse(iterable, base_type=tuple))
    ['ab', ('cd', 'ef'), 'gh', 'ij']

    Specify *levels* to stop flattening after a certain level:

    >>> iterable = [('a', ['b']), ('c', ['d'])]
    >>> list(collapse(iterable))  # Fully flattened
    ['a', 'b', 'c', 'd']
    >>> list(collapse(iterable, levels=1))  # Only one level flattened
    ['a', ['b'], 'c', ['d']]

    """
    stack = deque()
    # Add our first node group, treat the iterable as a single node
    stack.appendleft((0, repeat(iterable, 1)))

    while stack:
        node_group = stack.popleft()
        level, nodes = node_group

        # Check if beyond max level
        if levels is not None and level > levels:
            yield from nodes
            continue

        for node in nodes:
            # Check if done iterating
            if isinstance(node, (str, bytes)) or (
                (base_type is not None) and isinstance(node, base_type)
            ):
                yield node
            # Otherwise try to create child nodes
            else:
                try:
                    tree = iter(node)
                except TypeError:
                    yield node
                else:
                    # Save our current location
                    stack.appendleft(node_group)
                    # Append the new child node
                    stack.appendleft((level + 1, tree))
                    # Break to process child node
                    break


def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Invoke *func* on each item in *iterable* (or on each *chunk_size* group
    of items) before yielding the item.

    `func` must be a function that takes a single argument. Its return value
    will be discarded.

    *before* and *after* are optional functions that take no arguments. They
    will be executed before iteration starts and after it ends, respectively.

    `side_effect` can be used for logging, updating progress bars, or anything
    that is not functionally "pure."

    Emitting a status message:

    >>> from more_itertools import consume
    >>> func = lambda item: print('Received {}'.format(item))
    >>> consume(side_effect(func, range(2)))
    Received 0
    Received 1

    Operating on chunks of items:

    >>> pair_sums = []
    >>> func = lambda chunk: pair_sums.append(sum(chunk))
    >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
    [0, 1, 2, 3, 4, 5]
    >>> list(pair_sums)
    [1, 5, 9]

    Writing to a file-like object:

    >>> from io import StringIO
    >>> from more_itertools import consume
    >>> f = StringIO()
    >>> func = lambda x: print(x, file=f)
    >>> before = lambda: print(u'HEADER', file=f)
    >>> after = f.close
    >>> it = [u'a', u'b', u'c']
    >>> consume(side_effect(func, it, before=before, after=after))
    >>> f.closed
    True

    """
    try:
        if before is not None:
            before()

        if chunk_size is None:
            for item in iterable:
                func(item)
                yield item
        else:
            for chunk in chunked(iterable, chunk_size):
                func(chunk)
                yield from chunk
    finally:
        if after is not None:
            after()


def sliced(seq, n, strict=False):
    """Yield slices of length *n* from the sequence *seq*.

    >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
    [(1, 2, 3), (4, 5, 6)]

    By default, the last yielded slice will have fewer than *n* elements
    if the length of *seq* is not divisible by *n*:

    >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
    [(1, 2, 3), (4, 5, 6), (7, 8)]

    If the length of *seq* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    slice is yielded.

    This function will only work for iterables that support slicing.
    For non-sliceable iterables, see :func:`chunked`.

    """
    iterator = takewhile(len, (seq[i : i + n] for i in count(0, n)))
    if strict:

        def ret():
            for _slice in iterator:
                if len(_slice) != n:
                    raise ValueError("seq is not divisible by n.")
                yield _slice

        return ret()
    else:
        return iterator


def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Yield lists of items from *iterable*, where each list is delimited by
    an item where callable *pred* returns ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b'))
    [['a'], ['c', 'd', 'c'], ['a']]

    >>> list(split_at(range(10), lambda n: n % 2 == 1))
    [[0], [2], [4], [6], [8], []]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
    [[0], [2], [4, 5, 6, 7, 8, 9]]

    By default, the delimiting items are not included in the output.
    To include them, set *keep_separator* to ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
    [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    buf = []
    it = iter(iterable)
    for item in it:
        if pred(item):
            yield buf
            if keep_separator:
                yield [item]
            if maxsplit == 1:
                yield list(it)
                return
            buf = []
            maxsplit -= 1
        else:
            buf.append(item)
    yield buf


def split_before(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends just before
    an item for which callable *pred* returns ``True``:

    >>> list(split_before('OneTwo', lambda s: s.isupper()))
    [['O', 'n', 'e'], ['T', 'w', 'o']]

    >>> list(split_before(range(10), lambda n: n % 3 == 0))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    buf = []
    it = iter(iterable)
    for item in it:
        if pred(item) and buf:
            yield buf
            if maxsplit == 1:
                yield [item, *it]
                return
            buf = []
            maxsplit -= 1
        buf.append(item)
    if buf:
        yield buf


def split_after(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends with an
    item where callable *pred* returns ``True``:

    >>> list(split_after('one1two2', lambda s: s.isdigit()))
    [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

    >>> list(split_after(range(10), lambda n: n % 3 == 0))
    [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    buf = []
    it = iter(iterable)
    for item in it:
        buf.append(item)
        if pred(item) and buf:
            yield buf
            if maxsplit == 1:
                buf = list(it)
                if buf:
                    yield buf
                return
            buf = []
            maxsplit -= 1
    if buf:
        yield buf


def split_when(iterable, pred, maxsplit=-1):
    """Split *iterable* into pieces based on the output of *pred*.
    *pred* should be a function that takes successive pairs of items and
    returns ``True`` if the iterable should be split in between them.

    For example, to find runs of increasing numbers, split the iterable when
    element ``i`` is larger than element ``i + 1``:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
    [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
    ...                 lambda x, y: x > y, maxsplit=2))
    [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    it = iter(iterable)
    try:
        cur_item = next(it)
    except StopIteration:
        return

    buf = [cur_item]
    for next_item in it:
        if pred(cur_item, next_item):
            yield buf
            if maxsplit == 1:
                yield [next_item, *it]
                return
            buf = []
            maxsplit -= 1

        buf.append(next_item)
        cur_item = next_item

    yield buf


def split_into(iterable, sizes):
    """Yield a list of sequential items from *iterable* of length 'n' for each
    integer 'n' in *sizes*.

    >>> list(split_into([1, 2, 3, 4, 5, 6], [1, 2, 3]))
    [[1], [2, 3], [4, 5, 6]]

    If the sum of *sizes* is smaller than the length of *iterable*, then the
    remaining items of *iterable* will not be returned.

    >>> list(split_into([1, 2, 3, 4, 5, 6], [2, 3]))
    [[1, 2], [3, 4, 5]]

    If the sum of *sizes* is larger than the length of *iterable*, fewer items
    will be returned in the iteration that overruns the *iterable* and further
    lists will be empty:

    >>> list(split_into([1, 2, 3, 4], [1, 2, 3, 4]))
    [[1], [2, 3], [4], []]

    When a ``None`` object is encountered in *sizes*, the returned list will
    contain items up to the end of *iterable* the same way that
    :func:`itertools.islice` does:

    >>> list(split_into([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], [2, 3, None]))
    [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    :func:`split_into` can be useful for grouping a series of items where the
    sizes of the groups are not uniform. An example would be where in a row
    from a table, multiple columns represent elements of the same feature
    (e.g. a point represented by x, y, z) but the format is not the same for
    all columns.
    """
    # convert the iterable argument into an iterator so its contents can
    # be consumed by islice in case it is a generator
    it = iter(iterable)

    for size in sizes:
        if size is None:
            yield list(it)
            return
        else:
            yield list(islice(it, size))


def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the elements from *iterable*, followed by *fillvalue*, such that
    at least *n* items are emitted.

    >>> list(padded([1, 2, 3], '?', 5))
    [1, 2, 3, '?', '?']

    If *next_multiple* is ``True``, *fillvalue* will be emitted until the
    number of items emitted is a multiple of *n*:

    >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
    [1, 2, 3, 4, None, None]

    If *n* is ``None``, *fillvalue* will be emitted indefinitely.

    To create an *iterable* of exactly size *n*, you can truncate with
    :func:`islice`.

    >>> list(islice(padded([1, 2, 3], '?'), 5))
    [1, 2, 3, '?', '?']
    >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
    [1, 2, 3, 4, 5]

    """
    iterator = iter(iterable)
    iterator_with_repeat = chain(iterator, repeat(fillvalue))

    if n is None:
        return iterator_with_repeat
    elif n < 1:
        raise ValueError('n must be at least 1')
    elif next_multiple:

        def slice_generator():
            for first in iterator:
                yield (first,)
                yield islice(iterator_with_repeat, n - 1)

        # While elements exist produce slices of size n
        return chain.from_iterable(slice_generator())
    else:
        # Ensure the first batch is at least size n then iterate
        return chain(islice(iterator_with_repeat, n), iterator)


def repeat_each(iterable, n=2):
    """Repeat each element in *iterable* *n* times.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    return chain.from_iterable(map(repeat, iterable, repeat(n)))


def repeat_last(iterable, default=None):
    """After the *iterable* is exhausted, keep yielding its last element.

    >>> list(islice(repeat_last(range(3)), 5))
    [0, 1, 2, 2, 2]

    If the iterable is empty, yield *default* forever::

        >>> list(islice(repeat_last(range(0), 42), 5))
        [42, 42, 42, 42, 42]

    """
    item = _marker
    for item in iterable:
        yield item
    final = default if item is _marker else item
    yield from repeat(final)


def distribute(n, iterable):
    """Distribute the items from *iterable* among *n* smaller iterables.

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

    >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 4, 7], [2, 5], [3, 6]]

    If the length of *iterable* is smaller than *n*, then the last returned
    iterables will be empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    This function uses :func:`itertools.tee` and may require significant
    storage.

    If you need the order of items in the smaller iterables to match the
    original iterable, see :func:`divide`.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    children = tee(iterable, n)
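    # Each child reads every n-th item of its own tee branch, starting at
    # its own offset.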
    return [islice(it, index, None, n) for index, it in enumerate(children)]


def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples whose elements are offset from *iterable*.
    The amount by which the `i`-th item in each tuple is offset is given by
    the `i`-th item in *offsets*.

    >>> list(stagger([0, 1, 2, 3]))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
    >>> list(stagger(range(8), offsets=(0, 2, 4)))
    [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default, the sequence will end when the final element of a tuple is the
    last item in the iterable. To continue until the first element of a tuple
    is the last item in the iterable, set *longest* to ``True``::

        >>> list(stagger([0, 1, 2, 3], longest=True))
        [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    By default, ``None`` will be used to replace offsets beyond the end of the
    sequence. Specify *fillvalue* to use some other value.

    """
    children = tee(iterable, len(offsets))

    return zip_offset(
        *children, offsets=offsets, longest=longest, fillvalue=fillvalue
    )


def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the input *iterables* together, but offset the `i`-th iterable
    by the `i`-th item in *offsets*.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    This can be used as a lightweight alternative to SciPy or pandas to analyze
    data sets in which some series have a lead or lag relationship.

    By default, the sequence will end when the shortest iterable is exhausted.
    To continue until the longest iterable is exhausted, set *longest* to
    ``True``.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    By default, ``None`` will be used to replace offsets beyond the end of the
    sequence. Specify *fillvalue* to use some other value.

    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    staggered = []
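    # A negative offset pads the front of the iterable with *fillvalue*;
    # a positive offset skips that many leading items; zero leaves the
    # iterable unchanged.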
    for it, n in zip(iterables, offsets):
        if n < 0:
            staggered.append(chain(repeat(fillvalue, -n), it))
        elif n > 0:
            staggered.append(islice(it, n, None))
        else:
            staggered.append(it)

    if longest:
        return zip_longest(*staggered, fillvalue=fillvalue)

    return zip(*staggered)


def sort_together(
    iterables, key_list=(0,), key=None, reverse=False, strict=False
):
    """Return the input iterables sorted together, with *key_list* as the
    priority for sorting. All iterables are trimmed to the length of the
    shortest one.

    This can be used like the sorting function in a spreadsheet. If each
    iterable represents a column of data, the key list determines which
    columns are used for sorting.

    By default, all iterables are sorted using the ``0``-th iterable::

        >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
        >>> sort_together(iterables)
        [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]

    Set a different key list to sort according to another iterable.
    Specifying multiple keys dictates how ties are broken::

        >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
        >>> sort_together(iterables, key_list=(1, 2))
        [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]

    To sort by a function of the elements of the iterable, pass a *key*
    function. Its arguments are the elements of the iterables corresponding to
    the key list::

        >>> names = ('a', 'b', 'c')
        >>> lengths = (1, 2, 3)
        >>> widths = (5, 2, 1)
        >>> def area(length, width):
        ...     return length * width
        >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
        [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]

    Set *reverse* to ``True`` to sort in descending order.

    >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
    [(3, 2, 1), ('a', 'b', 'c')]

    If the *strict* keyword argument is ``True``, then
    ``ValueError`` will be raised if any of the iterables have
    different lengths.

    """
    if key is None:
        # if there is no key function, the key argument to sorted is an
        # itemgetter
        key_argument = itemgetter(*key_list)
    else:
        # if there is a key function, call it with the items at the offsets
        # specified by key_list as arguments
        key_list = list(key_list)
        if len(key_list) == 1:
            # if key_list contains a single item, pass the item at that offset
            # as the only argument to the key function
            key_offset = key_list[0]
            key_argument = lambda zipped_items: key(zipped_items[key_offset])
        else:
            # if key_list contains multiple items, use itemgetter to return a
            # tuple of items, which we pass as *args to the key function
            get_key_items = itemgetter(*key_list)
            key_argument = lambda zipped_items: key(
                *get_key_items(zipped_items)
            )

    transposed = zip(*iterables, strict=strict)
    reordered = sorted(transposed, key=key_argument, reverse=reverse)
    untransposed = zip(*reordered, strict=strict)
    return list(untransposed)
1964def unzip(iterable):
1965 """The inverse of :func:`zip`, this function disaggregates the elements
1966 of the zipped *iterable*.
1968 The ``i``-th iterable contains the ``i``-th element from each element
1969 of the zipped iterable. The first element is used to determine the
1970 length of the remaining elements.
1972 >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
1973 >>> letters, numbers = unzip(iterable)
1974 >>> list(letters)
1975 ['a', 'b', 'c', 'd']
1976 >>> list(numbers)
1977 [1, 2, 3, 4]
1979 This is similar to using ``zip(*iterable)``, but it avoids reading
1980 *iterable* into memory. Note, however, that this function uses
1981 :func:`itertools.tee` and thus may require significant storage.
1983 """
1984 head, iterable = spy(iterable)
1985 if not head:
1986 # empty iterable, e.g. zip([], [], [])
1987 return ()
1988 # spy returns a one-item list as head
1989 head = head[0]
1990 iterables = tee(iterable, len(head))
1992 # If we have an iterable like iter([(1, 2, 3), (4, 5), (6,)]),
1993 # the second unzipped iterable fails at the third tuple since
1994 # it tries to access (6,)[1].
1995 # Same with the third unzipped iterable and the second tuple.
1996 # To support these "improperly zipped" iterables, we suppress
1997 # the IndexError, which just stops the unzipped iterables at
1998 # first length mismatch.
1999 return tuple(
2000 iter_suppress(map(itemgetter(i), it), IndexError)
2001 for i, it in enumerate(iterables)
2002 )
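# Editor's sketch of the IndexError suppression above (hypothetical
# input, not one of the library's doctests): for an "improperly zipped"
# iterable, later unzipped iterables simply stop at the first short tuple:
#
#     >>> letters, numbers = unzip([('a', 1), ('b', 2), ('c',)])
#     >>> list(letters)
#     ['a', 'b', 'c']
#     >>> list(numbers)  # ('c',)[1] raises IndexError, which is suppressed
#     [1, 2]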
2005def divide(n, iterable):
2006 """Divide the elements from *iterable* into *n* parts, maintaining
2007 order.
2009 >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
2010 >>> list(group_1)
2011 [1, 2, 3]
2012 >>> list(group_2)
2013 [4, 5, 6]
2015 If the length of *iterable* is not evenly divisible by *n*, then the
2016 length of the returned iterables will not be identical:
2018 >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
2019 >>> [list(c) for c in children]
2020 [[1, 2, 3], [4, 5], [6, 7]]
2022 If the length of the iterable is smaller than *n*, then the last returned
2023 iterables will be empty:
2025 >>> children = divide(5, [1, 2, 3])
2026 >>> [list(c) for c in children]
2027 [[1], [2], [3], [], []]
2029 This function will exhaust the iterable before returning.
2030 If order is not important, see :func:`distribute`, which does not first
2031 pull the iterable into memory.
2033 """
2034 if n < 1:
2035 raise ValueError('n must be at least 1')
2037 try:
2038 iterable[:0]
2039 except TypeError:
2040 seq = tuple(iterable)
2041 else:
2042 seq = iterable
2044 q, r = divmod(len(seq), n)
2046 ret = []
2047 stop = 0
2048 for i in range(1, n + 1):
2049 start = stop
2050 stop += q + 1 if i <= r else q
2051 ret.append(iter(seq[start:stop]))
2053 return ret
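# Editor's note on the divmod arithmetic above (a worked example, not
# library code): for divide(3, [1, 2, 3, 4, 5, 6, 7]), divmod(7, 3)
# gives q, r = 2, 1, so the first r == 1 part receives q + 1 == 3 items
# and the remaining parts receive q == 2 each: [1, 2, 3], [4, 5], [6, 7].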
2056def always_iterable(obj, base_type=(str, bytes)):
2057 """If *obj* is iterable, return an iterator over its items::
2059 >>> obj = (1, 2, 3)
2060 >>> list(always_iterable(obj))
2061 [1, 2, 3]
2063 If *obj* is not iterable, return a one-item iterable containing *obj*::
2065 >>> obj = 1
2066 >>> list(always_iterable(obj))
2067 [1]
2069 If *obj* is ``None``, return an empty iterable:
2071 >>> obj = None
2072 >>> list(always_iterable(None))
2073 []
2075 By default, binary and text strings are not considered iterable::
2077 >>> obj = 'foo'
2078 >>> list(always_iterable(obj))
2079 ['foo']
2081 If *base_type* is set, objects for which ``isinstance(obj, base_type)``
2082 returns ``True`` won't be considered iterable.
2084 >>> obj = {'a': 1}
2085 >>> list(always_iterable(obj)) # Iterate over the dict's keys
2086 ['a']
2087 >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit
2088 [{'a': 1}]
2090 Set *base_type* to ``None`` to avoid any special handling and treat objects
2091 Python considers iterable as iterable:
2093 >>> obj = 'foo'
2094 >>> list(always_iterable(obj, base_type=None))
2095 ['f', 'o', 'o']
2096 """
2097 if obj is None:
2098 return iter(())
2100 if (base_type is not None) and isinstance(obj, base_type):
2101 return iter((obj,))
2103 try:
2104 return iter(obj)
2105 except TypeError:
2106 return iter((obj,))
2109def adjacent(predicate, iterable, distance=1):
2110 """Return an iterable over `(bool, item)` tuples where the `item` is
2111 drawn from *iterable* and the `bool` indicates whether
2112 that item satisfies the *predicate* or is adjacent to an item that does.
2114 For example, to find whether items are adjacent to a ``3``::
2116 >>> list(adjacent(lambda x: x == 3, range(6)))
2117 [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]
2119 Set *distance* to change what counts as adjacent. For example, to find
2120 whether items are two places away from a ``3``:
2122 >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
2123 [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]
2125 This is useful for contextualizing the results of a search function.
2126 For example, a code comparison tool might want to identify lines that
2127 have changed, but also surrounding lines to give the viewer of the diff
2128 context.
2130 The predicate function will only be called once for each item in the
2131 iterable.
2133 See also :func:`groupby_transform`, which can be used with this function
2134 to group ranges of items with the same `bool` value.
2136 """
2137 # Allow distance=0 mainly for testing that it reproduces results with map()
2138 if distance < 0:
2139 raise ValueError('distance must be at least 0')
2141 i1, i2 = tee(iterable)
2142 padding = [False] * distance
2143 selected = chain(padding, map(predicate, i1), padding)
2144 adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1))
2145 return zip(adjacent_to_selected, i2)
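# Editor's trace of the padding trick above (hypothetical values): for
# adjacent(lambda x: x == 3, range(6)), the padded boolean stream is
# [False, F, F, F, True, F, F, False]; any() over each window of
# 2 * distance + 1 == 3 booleans marks indexes 2, 3, and 4 as True,
# matching the doctest output.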
2148def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
2149 """An extension of :func:`itertools.groupby` that can apply transformations
2150 to the grouped data.
2152 * *keyfunc* is a function computing a key value for each item in *iterable*
2153 * *valuefunc* is a function that transforms the individual items from
2154 *iterable* after grouping
2155 * *reducefunc* is a function that transforms each group of items
2157 >>> iterable = 'aAAbBBcCC'
2158 >>> keyfunc = lambda k: k.upper()
2159 >>> valuefunc = lambda v: v.lower()
2160 >>> reducefunc = lambda g: ''.join(g)
2161 >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
2162 [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]
2164 Each optional argument defaults to an identity function if not specified.
2166 :func:`groupby_transform` is useful when grouping elements of an iterable
2167 using a separate iterable as the key. To do this, :func:`zip` the iterables
2168 and pass a *keyfunc* that extracts the first element and a *valuefunc*
2169 that extracts the second element::
2171 >>> from operator import itemgetter
2172 >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
2173 >>> values = 'abcdefghi'
2174 >>> iterable = zip(keys, values)
2175 >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
2176 >>> [(k, ''.join(g)) for k, g in grouper]
2177 [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]
2179 Note that the order of items in the iterable is significant.
2180 Only adjacent items are grouped together, so if you don't want any
2181 duplicate groups, you should sort the iterable by the key function.
2183 """
2184 ret = groupby(iterable, keyfunc)
2185 if valuefunc:
2186 ret = ((k, map(valuefunc, g)) for k, g in ret)
2187 if reducefunc:
2188 ret = ((k, reducefunc(g)) for k, g in ret)
2190 return ret
2193class numeric_range(Sequence):
2194 """An extension of the built-in ``range()`` function whose arguments can
2195 be any orderable numeric type.
2197 With only *stop* specified, *start* defaults to ``0`` and *step*
2198 defaults to ``1``. The output items will match the type of *stop*:
2200 >>> list(numeric_range(3.5))
2201 [0.0, 1.0, 2.0, 3.0]
2203 With only *start* and *stop* specified, *step* defaults to ``1``. The
2204 output items will match the type of *start*:
2206 >>> from decimal import Decimal
2207 >>> start = Decimal('2.1')
2208 >>> stop = Decimal('5.1')
2209 >>> list(numeric_range(start, stop))
2210 [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]
2212 With *start*, *stop*, and *step* specified the output items will match
2213 the type of ``start + step``:
2215 >>> from fractions import Fraction
2216 >>> start = Fraction(1, 2) # Start at 1/2
2217 >>> stop = Fraction(5, 2) # End at 5/2
2218 >>> step = Fraction(1, 2) # Count by 1/2
2219 >>> list(numeric_range(start, stop, step))
2220 [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]
2222 If *step* is zero, ``ValueError`` is raised. Negative steps are supported:
2224 >>> list(numeric_range(3, -1, -1.0))
2225 [3.0, 2.0, 1.0, 0.0]
2227 Be aware of the limitations of floating-point numbers; the representation
2228 of the yielded numbers may be surprising.
2230 ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
2231 is a ``datetime.timedelta`` object:
2233 >>> import datetime
2234 >>> start = datetime.datetime(2019, 1, 1)
2235 >>> stop = datetime.datetime(2019, 1, 3)
2236 >>> step = datetime.timedelta(days=1)
2237 >>> items = iter(numeric_range(start, stop, step))
2238 >>> next(items)
2239 datetime.datetime(2019, 1, 1, 0, 0)
2240 >>> next(items)
2241 datetime.datetime(2019, 1, 2, 0, 0)
2243 """
2245 _EMPTY_HASH = hash(range(0, 0))
2247 def __init__(self, *args):
2248 argc = len(args)
2249 if argc == 1:
2250 (self._stop,) = args
2251 self._start = type(self._stop)(0)
2252 self._step = type(self._stop - self._start)(1)
2253 elif argc == 2:
2254 self._start, self._stop = args
2255 self._step = type(self._stop - self._start)(1)
2256 elif argc == 3:
2257 self._start, self._stop, self._step = args
2258 elif argc == 0:
2259 raise TypeError(
2260 f'numeric_range expected at least 1 argument, got {argc}'
2261 )
2262 else:
2263 raise TypeError(
2264 f'numeric_range expected at most 3 arguments, got {argc}'
2265 )
2267 self._zero = type(self._step)(0)
2268 if self._step == self._zero:
2269 raise ValueError('numeric_range() arg 3 must not be zero')
2270 self._growing = self._step > self._zero
2272 def __bool__(self):
2273 if self._growing:
2274 return self._start < self._stop
2275 else:
2276 return self._start > self._stop
2278 def __contains__(self, elem):
2279 if self._growing:
2280 if self._start <= elem < self._stop:
2281 return (elem - self._start) % self._step == self._zero
2282 else:
2283 if self._start >= elem > self._stop:
2284 return (self._start - elem) % (-self._step) == self._zero
2286 return False
2288 def __eq__(self, other):
2289 if isinstance(other, numeric_range):
2290 empty_self = not bool(self)
2291 empty_other = not bool(other)
2292 if empty_self or empty_other:
2293 return empty_self and empty_other # True if both empty
2294 else:
2295 return (
2296 self._start == other._start
2297 and self._step == other._step
2298 and self._get_by_index(-1) == other._get_by_index(-1)
2299 )
2300 else:
2301 return False
2303 def __getitem__(self, key):
2304 if isinstance(key, int):
2305 return self._get_by_index(key)
2306 elif isinstance(key, slice):
2307 step = self._step if key.step is None else key.step * self._step
2309 if key.start is None or key.start <= -self._len:
2310 start = self._start
2311 elif key.start >= self._len:
2312 start = self._stop
2313 else: # -self._len < key.start < self._len
2314 start = self._get_by_index(key.start)
2316 if key.stop is None or key.stop >= self._len:
2317 stop = self._stop
2318 elif key.stop <= -self._len:
2319 stop = self._start
2320 else: # -self._len < key.stop < self._len
2321 stop = self._get_by_index(key.stop)
2323 return numeric_range(start, stop, step)
2324 else:
2325 raise TypeError(
2326 'numeric range indices must be '
2327 f'integers or slices, not {type(key).__name__}'
2328 )
2330 def __hash__(self):
2331 if self:
2332 return hash((self._start, self._get_by_index(-1), self._step))
2333 else:
2334 return self._EMPTY_HASH
2336 def __iter__(self):
2337 values = (self._start + (n * self._step) for n in count())
2338 if self._growing:
2339 return takewhile(partial(gt, self._stop), values)
2340 else:
2341 return takewhile(partial(lt, self._stop), values)
2343 def __len__(self):
2344 return self._len
2346 @cached_property
2347 def _len(self):
2348 if self._growing:
2349 start = self._start
2350 stop = self._stop
2351 step = self._step
2352 else:
2353 start = self._stop
2354 stop = self._start
2355 step = -self._step
2356 distance = stop - start
2357 if distance <= self._zero:
2358 return 0
2359 else: # distance > 0 and step > 0: regular euclidean division
2360 q, r = divmod(distance, step)
2361 return int(q) + int(r != self._zero)
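# Editor's note on the division above (a worked example): for
# numeric_range(1, 10, 2), distance == 9 and divmod(9, 2) == (4, 1),
# so _len is 4 + 1 == 5, counting the items 1, 3, 5, 7, and 9.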
2363 def __reduce__(self):
2364 return numeric_range, (self._start, self._stop, self._step)
2366 def __repr__(self):
2367 if self._step == 1:
2368 return f"numeric_range({self._start!r}, {self._stop!r})"
2369 return (
2370 f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
2371 )
2373 def __reversed__(self):
2374 return iter(
2375 numeric_range(
2376 self._get_by_index(-1), self._start - self._step, -self._step
2377 )
2378 )
2380 def count(self, value):
2381 return int(value in self)
2383 def index(self, value):
2384 if self._growing:
2385 if self._start <= value < self._stop:
2386 q, r = divmod(value - self._start, self._step)
2387 if r == self._zero:
2388 return int(q)
2389 else:
2390 if self._start >= value > self._stop:
2391 q, r = divmod(self._start - value, -self._step)
2392 if r == self._zero:
2393 return int(q)
2395 raise ValueError(f"{value} is not in numeric range")
2397 def _get_by_index(self, i):
2398 if i < 0:
2399 i += self._len
2400 if i < 0 or i >= self._len:
2401 raise IndexError("numeric range object index out of range")
2402 return self._start + i * self._step
2405def count_cycle(iterable, n=None):
2406 """Cycle through the items from *iterable* up to *n* times, yielding
2407 the number of completed cycles along with each item. If *n* is omitted the
2408 process repeats indefinitely.
2410 >>> list(count_cycle('AB', 3))
2411 [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]
2413 """
2414 if n is not None:
2415 return product(range(n), iterable)
2416 seq = tuple(iterable)
2417 if not seq:
2418 return iter(())
2419 counter = count() # n is None here; the finite case returned above
2420 return zip(repeat_each(counter, len(seq)), cycle(seq))
2423def mark_ends(iterable):
2424 """Yield 3-tuples of the form ``(is_first, is_last, item)``.
2426 >>> list(mark_ends('ABC'))
2427 [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]
2429 Use this when looping over an iterable to take special action on its first
2430 and/or last items:
2432 >>> iterable = ['Header', 100, 200, 'Footer']
2433 >>> total = 0
2434 >>> for is_first, is_last, item in mark_ends(iterable):
2435 ... if is_first:
2436 ... continue # Skip the header
2437 ... if is_last:
2438 ... continue # Skip the footer
2439 ... total += item
2440 >>> print(total)
2441 300
2442 """
2443 it = iter(iterable)
2444 for a in it:
2445 first = True
2446 for b in it:
2447 yield first, False, a
2448 a = b
2449 first = False
2450 yield first, True, a
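# Editor's trace of the lookahead above (hypothetical input 'AB'): the
# outer loop binds a = 'A'; the inner loop sees b = 'B' and yields
# (True, False, 'A') before rebinding a = 'B'; when the iterator is
# exhausted, the trailing yield emits (False, True, 'B').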
2453def locate(iterable, pred=bool, window_size=None):
2454 """Yield the index of each item in *iterable* for which *pred* returns
2455 ``True``.
2457 *pred* defaults to :func:`bool`, which will select truthy items:
2459 >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
2460 [1, 2, 4]
2462 Set *pred* to a custom function to, e.g., find the indexes for a particular
2463 item.
2465 >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
2466 [1, 3]
2468 If *window_size* is given, then the *pred* function will be called with
2469 that many items. This enables searching for sub-sequences:
2471 >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
2472 >>> pred = lambda *args: args == (1, 2, 3)
2473 >>> list(locate(iterable, pred=pred, window_size=3))
2474 [1, 5, 9]
2476 Use with :func:`seekable` to find indexes and then retrieve the associated
2477 items:
2479 >>> from itertools import count
2480 >>> from more_itertools import seekable
2481 >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
2482 >>> it = seekable(source)
2483 >>> pred = lambda x: x > 100
2484 >>> indexes = locate(it, pred=pred)
2485 >>> i = next(indexes)
2486 >>> it.seek(i)
2487 >>> next(it)
2488 106
2490 """
2491 if window_size is None:
2492 return compress(count(), map(pred, iterable))
2494 if window_size < 1:
2495 raise ValueError('window size must be at least 1')
2497 it = windowed(iterable, window_size, fillvalue=_marker)
2498 return compress(count(), starmap(pred, it))
2501def longest_common_prefix(iterables):
2502 """Yield elements of the longest common prefix among given *iterables*.
2504 >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
2505 'ab'
2507 """
2508 return (c[0] for c in takewhile(all_equal, zip(*iterables)))
2511def lstrip(iterable, pred):
2512 """Yield the items from *iterable*, but strip any from the beginning
2513 for which *pred* returns ``True``.
2515 For example, to remove a set of items from the start of an iterable:
2517 >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
2518 >>> pred = lambda x: x in {None, False, ''}
2519 >>> list(lstrip(iterable, pred))
2520 [1, 2, None, 3, False, None]
2522 This function is analogous to :func:`str.lstrip`, and is essentially
2523 a wrapper for :func:`itertools.dropwhile`.
2525 """
2526 return dropwhile(pred, iterable)
2529def rstrip(iterable, pred):
2530 """Yield the items from *iterable*, but strip any from the end
2531 for which *pred* returns ``True``.
2533 For example, to remove a set of items from the end of an iterable:
2535 >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
2536 >>> pred = lambda x: x in {None, False, ''}
2537 >>> list(rstrip(iterable, pred))
2538 [None, False, None, 1, 2, None, 3]
2540 This function is analogous to :func:`str.rstrip`.
2542 """
2543 cache = []
2544 cache_append = cache.append
2545 cache_clear = cache.clear
2546 for x in iterable:
2547 if pred(x):
2548 cache_append(x)
2549 else:
2550 yield from cache
2551 cache_clear()
2552 yield x
2555def strip(iterable, pred):
2556 """Yield the items from *iterable*, but strip any from the
2557 beginning and end for which *pred* returns ``True``.
2559 For example, to remove a set of items from both ends of an iterable:
2561 >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
2562 >>> pred = lambda x: x in {None, False, ''}
2563 >>> list(strip(iterable, pred))
2564 [1, 2, None, 3]
2566 This function is analogous to :func:`str.strip`.
2568 """
2569 return rstrip(lstrip(iterable, pred), pred)
2572class islice_extended:
2573 """An extension of :func:`itertools.islice` that supports negative values
2574 for *stop*, *start*, and *step*.
2576 >>> iterator = iter('abcdefgh')
2577 >>> list(islice_extended(iterator, -4, -1))
2578 ['e', 'f', 'g']
2580 Slices with negative values require some caching of *iterable*, but this
2581 function takes care to minimize the amount of memory required.
2583 For example, you can use a negative step with an infinite iterator:
2585 >>> from itertools import count
2586 >>> list(islice_extended(count(), 110, 99, -2))
2587 [110, 108, 106, 104, 102, 100]
2589 You can also use slice notation directly:
2591 >>> iterator = map(str, count())
2592 >>> it = islice_extended(iterator)[10:20:2]
2593 >>> list(it)
2594 ['10', '12', '14', '16', '18']
2596 """
2598 def __init__(self, iterable, *args):
2599 it = iter(iterable)
2600 if args:
2601 self._iterator = _islice_helper(it, slice(*args))
2602 else:
2603 self._iterator = it
2605 def __iter__(self):
2606 return self
2608 def __next__(self):
2609 return next(self._iterator)
2611 def __getitem__(self, key):
2612 if isinstance(key, slice):
2613 return islice_extended(_islice_helper(self._iterator, key))
2615 raise TypeError('islice_extended.__getitem__ argument must be a slice')
2618def _islice_helper(it, s):
2619 start = s.start
2620 stop = s.stop
2621 if s.step == 0:
2622 raise ValueError('step argument must be a non-zero integer or None.')
2623 step = s.step or 1
2625 if step > 0:
2626 start = 0 if (start is None) else start
2628 if start < 0:
2629 # Consume all but the last -start items
2630 cache = deque(enumerate(it, 1), maxlen=-start)
2631 len_iter = cache[-1][0] if cache else 0
2633 # Adjust start to be positive
2634 i = max(len_iter + start, 0)
2636 # Adjust stop to be positive
2637 if stop is None:
2638 j = len_iter
2639 elif stop >= 0:
2640 j = min(stop, len_iter)
2641 else:
2642 j = max(len_iter + stop, 0)
2644 # Slice the cache
2645 n = j - i
2646 if n <= 0:
2647 return
2649 for index in range(n):
2650 if index % step == 0:
2651 # pop and yield the item.
2652 # We don't want to use an intermediate variable;
2653 # it would extend the lifetime of the current item
2654 yield cache.popleft()[1]
2655 else:
2656 # just pop and discard the item
2657 cache.popleft()
2658 elif (stop is not None) and (stop < 0):
2659 # Advance to the start position
2660 next(islice(it, start, start), None)
2662 # When stop is negative, we have to carry -stop items while
2663 # iterating
2664 cache = deque(islice(it, -stop), maxlen=-stop)
2666 for index, item in enumerate(it):
2667 if index % step == 0:
2668 # pop and yield the item.
2669 # We don't want to use an intermediate variable;
2670 # it would extend the lifetime of the current item
2671 yield cache.popleft()
2672 else:
2673 # just pop and discard the item
2674 cache.popleft()
2675 cache.append(item)
2676 else:
2677 # When both start and stop are positive we have the normal case
2678 yield from islice(it, start, stop, step)
2679 else:
2680 start = -1 if (start is None) else start
2682 if (stop is not None) and (stop < 0):
2683 # Consume all but the last n items
2684 n = -stop - 1
2685 cache = deque(enumerate(it, 1), maxlen=n)
2686 len_iter = cache[-1][0] if cache else 0
2688 # If start and stop are both negative they are comparable and
2689 # we can just slice. Otherwise we can adjust start to be negative
2690 # and then slice.
2691 if start < 0:
2692 i, j = start, stop
2693 else:
2694 i, j = min(start - len_iter, -1), None
2696 for index, item in list(cache)[i:j:step]:
2697 yield item
2698 else:
2699 # Advance to the stop position
2700 if stop is not None:
2701 m = stop + 1
2702 next(islice(it, m, m), None)
2704 # stop is positive, so if start is negative they are not comparable
2705 # and we need the rest of the items.
2706 if start < 0:
2707 i = start
2708 n = None
2709 # stop is None and start is positive, so we just need items up to
2710 # the start index.
2711 elif stop is None:
2712 i = None
2713 n = start + 1
2714 # Both stop and start are positive, so they are comparable.
2715 else:
2716 i = None
2717 n = start - stop
2718 if n <= 0:
2719 return
2721 cache = list(islice(it, n))
2723 yield from cache[i::step]
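# Editor's trace of the negative-step branch above (hypothetical values):
# for islice_extended(count(), 110, 99, -2), stop is non-negative, so the
# helper advances past index 99 with next(islice(it, 100, 100), None),
# caches the next n = start - stop = 11 items (100 through 110), and
# yields cache[None::-2]: 110, 108, 106, 104, 102, 100.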
2726def always_reversible(iterable):
2727 """An extension of :func:`reversed` that supports all iterables, not
2728 just those which implement the ``Reversible`` or ``Sequence`` protocols.
2730 >>> print(*always_reversible(x for x in range(3)))
2731 2 1 0
2733 If the iterable is already reversible, this function returns the
2734 result of :func:`reversed()`. If the iterable is not reversible,
2735 this function will cache the remaining items in the iterable and
2736 yield them in reverse order, which may require significant storage.
2737 """
2738 try:
2739 return reversed(iterable)
2740 except TypeError:
2741 return reversed(list(iterable))
2744def consecutive_groups(iterable, ordering=None):
2745 """Yield groups of consecutive items using :func:`itertools.groupby`.
2746 The *ordering* function determines whether two items are adjacent by
2747 returning their position.
2749 By default, the ordering function is the identity function. This is
2750 suitable for finding runs of numbers:
2752 >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
2753 >>> for group in consecutive_groups(iterable):
2754 ... print(list(group))
2755 [1]
2756 [10, 11, 12]
2757 [20]
2758 [30, 31, 32, 33]
2759 [40]
2761 To find runs of adjacent letters, apply the :func:`ord` function
2762 to convert letters to ordinals:
2764 >>> iterable = 'abcdfgilmnop'
2765 >>> ordering = ord
2766 >>> for group in consecutive_groups(iterable, ordering):
2767 ... print(list(group))
2768 ['a', 'b', 'c', 'd']
2769 ['f', 'g']
2770 ['i']
2771 ['l', 'm', 'n', 'o', 'p']
2773 Each group of consecutive items is an iterator that shares its source with
2774 *iterable*. When an output group is advanced, the previous group is
2775 no longer available unless its elements are copied (e.g., into a ``list``).
2777 >>> iterable = [1, 2, 11, 12, 21, 22]
2778 >>> saved_groups = []
2779 >>> for group in consecutive_groups(iterable):
2780 ... saved_groups.append(list(group)) # Copy group elements
2781 >>> saved_groups
2782 [[1, 2], [11, 12], [21, 22]]
2784 """
2785 if ordering is None:
2786 key = lambda x: x[0] - x[1]
2787 else:
2788 key = lambda x: x[0] - ordering(x[1])
2790 for k, g in groupby(enumerate(iterable), key=key):
2791 yield map(itemgetter(1), g)
2794def difference(iterable, func=sub, *, initial=None):
2795 """This function is the inverse of :func:`itertools.accumulate`. By default
2796 it will compute the first difference of *iterable* using
2797 :func:`operator.sub`:
2799 >>> from itertools import accumulate
2800 >>> iterable = accumulate([0, 1, 2, 3, 4]) # produces 0, 1, 3, 6, 10
2801 >>> list(difference(iterable))
2802 [0, 1, 2, 3, 4]
2804 *func* defaults to :func:`operator.sub`, but other functions can be
2805 specified. They will be applied as follows::
2807 A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...
2809 For example, to do progressive division:
2811 >>> iterable = [1, 2, 6, 24, 120]
2812 >>> func = lambda x, y: x // y
2813 >>> list(difference(iterable, func))
2814 [1, 2, 3, 4, 5]
2816 If the *initial* keyword is set, the first element will be skipped when
2817 computing successive differences.
2819 >>> it = [10, 11, 13, 16] # from accumulate([1, 2, 3], initial=10)
2820 >>> list(difference(it, initial=10))
2821 [1, 2, 3]
2823 """
2824 a, b = tee(iterable)
2825 try:
2826 first = [next(b)]
2827 except StopIteration:
2828 return iter([])
2830 if initial is not None:
2831 first = []
2833 return chain(first, map(func, b, a))
2836class SequenceView(Sequence):
2837 """Return a read-only view of the sequence object *target*.
2839 :class:`SequenceView` objects are analogous to Python's built-in
2840 "dictionary view" types. They provide a dynamic view of a sequence's items,
2841 meaning that when the sequence updates, so does the view.
2843 >>> seq = ['0', '1', '2']
2844 >>> view = SequenceView(seq)
2845 >>> view
2846 SequenceView(['0', '1', '2'])
2847 >>> seq.append('3')
2848 >>> view
2849 SequenceView(['0', '1', '2', '3'])
2851 Sequence views support indexing, slicing, and length queries. They act
2852 like the underlying sequence, except they don't allow assignment:
2854 >>> view[1]
2855 '1'
2856 >>> view[1:-1]
2857 ['1', '2']
2858 >>> len(view)
2859 4
2861 Sequence views are useful as an alternative to copying, as they don't
2862 require (much) extra storage.
2864 """
2866 def __init__(self, target):
2867 if not isinstance(target, Sequence):
2868 raise TypeError
2869 self._target = target
2871 def __getitem__(self, index):
2872 return self._target[index]
2874 def __len__(self):
2875 return len(self._target)
2877 def __repr__(self):
2878 return f'{self.__class__.__name__}({self._target!r})'
2881class seekable:
2882 """Wrap an iterator to allow for seeking backward and forward. This
2883 progressively caches the items in the source iterable so they can be
2884 re-visited.
2886 Call :meth:`seek` with an index to seek to that position in the source
2887 iterable.
2889 To "reset" an iterator, seek to ``0``:
2891 >>> from itertools import count
2892 >>> it = seekable((str(n) for n in count()))
2893 >>> next(it), next(it), next(it)
2894 ('0', '1', '2')
2895 >>> it.seek(0)
2896 >>> next(it), next(it), next(it)
2897 ('0', '1', '2')
2899 You can also seek forward:
2901 >>> it = seekable((str(n) for n in range(20)))
2902 >>> it.seek(10)
2903 >>> next(it)
2904 '10'
2905 >>> it.seek(20) # Seeking past the end of the source isn't a problem
2906 >>> list(it)
2907 []
2908 >>> it.seek(0) # Resetting works even after hitting the end
2909 >>> next(it)
2910 '0'
2912 Call :meth:`relative_seek` to seek relative to the source iterator's
2913 current position.
2915 >>> it = seekable((str(n) for n in range(20)))
2916 >>> next(it), next(it), next(it)
2917 ('0', '1', '2')
2918 >>> it.relative_seek(2)
2919 >>> next(it)
2920 '5'
2921 >>> it.relative_seek(-3) # Source is at '6', we move back to '3'
2922 >>> next(it)
2923 '3'
2924 >>> it.relative_seek(-3) # Source is at '4', we move back to '1'
2925 >>> next(it)
2926 '1'
2929 Call :meth:`peek` to look ahead one item without advancing the iterator:
2931 >>> it = seekable('1234')
2932 >>> it.peek()
2933 '1'
2934 >>> list(it)
2935 ['1', '2', '3', '4']
2936 >>> it.peek(default='empty')
2937 'empty'
2939 Before the iterator is at its end, calling :func:`bool` on it will return
2940 ``True``. After it will return ``False``:
2942 >>> it = seekable('5678')
2943 >>> bool(it)
2944 True
2945 >>> list(it)
2946 ['5', '6', '7', '8']
2947 >>> bool(it)
2948 False
2950 You may view the contents of the cache with the :meth:`elements` method.
2951 That returns a :class:`SequenceView`, a view that updates automatically:
2953 >>> it = seekable((str(n) for n in range(10)))
2954 >>> next(it), next(it), next(it)
2955 ('0', '1', '2')
2956 >>> elements = it.elements()
2957 >>> elements
2958 SequenceView(['0', '1', '2'])
2959 >>> next(it)
2960 '3'
2961 >>> elements
2962 SequenceView(['0', '1', '2', '3'])
2964 By default, the cache grows as the source iterable progresses, so beware of
2965 wrapping very large or infinite iterables. Supply *maxlen* to limit the
2966 size of the cache (this of course limits how far back you can seek).
2968 >>> from itertools import count
2969 >>> it = seekable((str(n) for n in count()), maxlen=2)
2970 >>> next(it), next(it), next(it), next(it)
2971 ('0', '1', '2', '3')
2972 >>> list(it.elements())
2973 ['2', '3']
2974 >>> it.seek(0)
2975 >>> next(it), next(it), next(it), next(it)
2976 ('2', '3', '4', '5')
2977 >>> next(it)
2978 '6'
2980 """
2982 def __init__(self, iterable, maxlen=None):
2983 self._source = iter(iterable)
2984 if maxlen is None:
2985 self._cache = []
2986 else:
2987 self._cache = deque([], maxlen)
2988 self._index = None
2990 def __iter__(self):
2991 return self
2993 def __next__(self):
2994 if self._index is not None:
2995 try:
2996 item = self._cache[self._index]
2997 except IndexError:
2998 self._index = None
2999 else:
3000 self._index += 1
3001 return item
3003 item = next(self._source)
3004 self._cache.append(item)
3005 return item
3007 def __bool__(self):
3008 try:
3009 self.peek()
3010 except StopIteration:
3011 return False
3012 return True
3014 def peek(self, default=_marker):
3015 try:
3016 peeked = next(self)
3017 except StopIteration:
3018 if default is _marker:
3019 raise
3020 return default
3021 if self._index is None:
3022 self._index = len(self._cache)
3023 self._index -= 1
3024 return peeked
3026 def elements(self):
3027 return SequenceView(self._cache)
3029 def seek(self, index):
3030 self._index = index
3031 remainder = index - len(self._cache)
3032 if remainder > 0:
3033 consume(self, remainder)
3035 def relative_seek(self, count):
3036 if self._index is None:
3037 self._index = len(self._cache)
3039 self.seek(max(self._index + count, 0))
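# Editor's sketch of relative_seek above (hypothetical state): with the
# cache at ['0', '1', '2'] and _index at None, relative_seek(-3) first
# sets _index to len(cache) == 3, then calls seek(max(3 - 3, 0)), moving
# the iterator back to the first cached item.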
3042class run_length:
3043 """
3044 :func:`run_length.encode` compresses an iterable with run-length encoding.
3045 It yields groups of repeated items with the count of how many times they
3046 were repeated:
3048 >>> uncompressed = 'abbcccdddd'
3049 >>> list(run_length.encode(uncompressed))
3050 [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
3052 :func:`run_length.decode` decompresses an iterable that was previously
3053 compressed with run-length encoding. It yields the items of the
3054 decompressed iterable:
3056 >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
3057 >>> list(run_length.decode(compressed))
3058 ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']
3060 """
3062 @staticmethod
3063 def encode(iterable):
3064 return ((k, ilen(g)) for k, g in groupby(iterable))
3066 @staticmethod
3067 def decode(iterable):
3068 return chain.from_iterable(starmap(repeat, iterable))
3071def exactly_n(iterable, n, predicate=bool):
3072 """Return ``True`` if exactly ``n`` items in the iterable are ``True``
3073 according to the *predicate* function.
3075 >>> exactly_n([True, True, False], 2)
3076 True
3077 >>> exactly_n([True, True, False], 1)
3078 False
3079 >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
3080 True
3082 The iterable will be advanced until ``n + 1`` truthy items are encountered,
3083 so avoid calling it on infinite iterables.
3085 """
3086 iterator = filter(predicate, iterable)
3087 if n <= 0:
3088 if n < 0:
3089 return False
3090 for _ in iterator:
3091 return False
3092 return True
3094 iterator = islice(iterator, n - 1, None)
3095 for _ in iterator:
3096 for _ in iterator:
3097 return False
3098 return True
3099 return False
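# Editor's note on the islice trick above (hypothetical values): after
# filtering, islice skips the first n - 1 truthy items; the outer loop
# then demands an n-th item and the inner loop rejects an (n + 1)-th.
# For exactly_n([1, 1, 1], 2), one item is skipped, a second is found,
# a third is found, and False is returned.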
3102def circular_shifts(iterable, steps=1):
3103 """Yield the circular shifts of *iterable*.
3105 >>> list(circular_shifts(range(4)))
3106 [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]
3108 Set *steps* to the number of places to rotate to the left
3109 (or to the right if negative). Defaults to 1.
3111 >>> list(circular_shifts(range(4), 2))
3112 [(0, 1, 2, 3), (2, 3, 0, 1)]
3114 >>> list(circular_shifts(range(4), -1))
3115 [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]
3117 """
3118 buffer = deque(iterable)
3119 if steps == 0:
3120 raise ValueError('Steps should be a non-zero integer')
3122 buffer.rotate(steps)
3123 steps = -steps
3124 n = len(buffer)
3125 n //= math.gcd(n, steps)
3127 for _ in repeat(None, n):
3128 buffer.rotate(steps)
3129 yield tuple(buffer)
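# Editor's note on the gcd arithmetic above (a worked example): for
# circular_shifts(range(4), 2), n == 4 and math.gcd(4, -2) == 2, so
# n //= gcd leaves 2 distinct shifts before the rotation cycles back:
# (0, 1, 2, 3) and (2, 3, 0, 1), as in the doctest.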
3132def make_decorator(wrapping_func, result_index=0):
3133 """Return a decorator version of *wrapping_func*, which is a function that
3134 modifies an iterable. *result_index* is the position in that function's
3135 signature where the iterable goes.
3137 This lets you use itertools on the "production end," i.e. at function
3138 definition. This can augment what the function returns without changing the
3139 function's code.
3141 For example, to produce a decorator version of :func:`chunked`:
3143 >>> from more_itertools import chunked
3144 >>> chunker = make_decorator(chunked, result_index=0)
3145 >>> @chunker(3)
3146 ... def iter_range(n):
3147 ... return iter(range(n))
3148 ...
3149 >>> list(iter_range(9))
3150 [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
3152 To only allow truthy items to be returned:
3154 >>> truth_serum = make_decorator(filter, result_index=1)
3155 >>> @truth_serum(bool)
3156 ... def boolean_test():
3157 ... return [0, 1, '', ' ', False, True]
3158 ...
3159 >>> list(boolean_test())
3160 [1, ' ', True]
3162 The :func:`peekable` and :func:`seekable` wrappers make for practical
3163 decorators:
3165 >>> from more_itertools import peekable
3166 >>> peekable_function = make_decorator(peekable)
3167 >>> @peekable_function()
3168 ... def str_range(*args):
3169 ... return (str(x) for x in range(*args))
3170 ...
3171 >>> it = str_range(1, 20, 2)
3172 >>> next(it), next(it), next(it)
3173 ('1', '3', '5')
3174 >>> it.peek()
3175 '7'
3176 >>> next(it)
3177 '7'
3179 """
3181 # See https://sites.google.com/site/bbayles/index/decorator_factory for
3182 # notes on how this works.
3183 def decorator(*wrapping_args, **wrapping_kwargs):
3184 def outer_wrapper(f):
3185 def inner_wrapper(*args, **kwargs):
3186 result = f(*args, **kwargs)
3187 wrapping_args_ = list(wrapping_args)
3188 wrapping_args_.insert(result_index, result)
3189 return wrapping_func(*wrapping_args_, **wrapping_kwargs)
3191 return inner_wrapper
3193 return outer_wrapper
3195 return decorator
3198def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
3199 """Return a dictionary that maps the items in *iterable* to categories
3200 defined by *keyfunc*, transforms them with *valuefunc*, and
3201 then summarizes them by category with *reducefunc*.
3203 *valuefunc* defaults to the identity function if it is unspecified.
3204 If *reducefunc* is unspecified, no summarization takes place:
3206 >>> keyfunc = lambda x: x.upper()
3207 >>> result = map_reduce('abbccc', keyfunc)
3208 >>> sorted(result.items())
3209 [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]
3211 Specifying *valuefunc* transforms the categorized items:
3213 >>> keyfunc = lambda x: x.upper()
3214 >>> valuefunc = lambda x: 1
3215 >>> result = map_reduce('abbccc', keyfunc, valuefunc)
3216 >>> sorted(result.items())
3217 [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]
3219 Specifying *reducefunc* summarizes the categorized items:
3221 >>> keyfunc = lambda x: x.upper()
3222 >>> valuefunc = lambda x: 1
3223 >>> reducefunc = sum
3224 >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
3225 >>> sorted(result.items())
3226 [('A', 1), ('B', 2), ('C', 3)]
3228 You may want to filter the input iterable before applying the map/reduce
3229 procedure:
3231 >>> all_items = range(30)
3232 >>> items = [x for x in all_items if 10 <= x <= 20] # Filter
3233 >>> keyfunc = lambda x: x % 2 # Evens map to 0; odds to 1
3234 >>> categories = map_reduce(items, keyfunc=keyfunc)
3235 >>> sorted(categories.items())
3236 [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
3237 >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
3238 >>> sorted(summaries.items())
3239 [(0, 90), (1, 75)]
3241 Note that all items in the iterable are gathered into a list before the
3242 summarization step, which may require significant storage.
3244 The returned object is a :obj:`collections.defaultdict` with the
3245 ``default_factory`` set to ``None``, such that it behaves like a normal
3246 dictionary.
3248 """
3250 ret = defaultdict(list)
3252 if valuefunc is None:
3253 for item in iterable:
3254 key = keyfunc(item)
3255 ret[key].append(item)
3257 else:
3258 for item in iterable:
3259 key = keyfunc(item)
3260 value = valuefunc(item)
3261 ret[key].append(value)
3263 if reducefunc is not None:
3264 for key, value_list in ret.items():
3265 ret[key] = reducefunc(value_list)
3267 ret.default_factory = None
3268 return ret
3271def rlocate(iterable, pred=bool, window_size=None):
3272 """Yield the index of each item in *iterable* for which *pred* returns
3273 ``True``, starting from the right and moving left.
3275 *pred* defaults to :func:`bool`, which will select truthy items:
3277 >>> list(rlocate([0, 1, 1, 0, 1, 0, 0])) # Truthy at 1, 2, and 4
3278 [4, 2, 1]
3280 Set *pred* to a custom function to, e.g., find the indexes for a particular
3281 item:
3283 >>> iterator = iter('abcb')
3284 >>> pred = lambda x: x == 'b'
3285 >>> list(rlocate(iterator, pred))
3286 [3, 1]
3288 If *window_size* is given, then the *pred* function will be called with
3289 that many items. This enables searching for sub-sequences:
3291 >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
3292 >>> pred = lambda *args: args == (1, 2, 3)
3293 >>> list(rlocate(iterable, pred=pred, window_size=3))
3294 [9, 5, 1]
3296 Beware, this function won't return anything for infinite iterables.
3297 If *iterable* is reversible, ``rlocate`` will reverse it and search from
3298 the right. Otherwise, it will search from the left and return the results
3299 in reverse order.
3301 See :func:`locate` for other example applications.
3303 """
3304 if window_size is None:
3305 try:
3306 len_iter = len(iterable)
3307 return (len_iter - i - 1 for i in locate(reversed(iterable), pred))
3308 except TypeError:
3309 pass
3311 return reversed(list(locate(iterable, pred, window_size)))
3314def replace(iterable, pred, substitutes, count=None, window_size=1):
3315 """Yield the items from *iterable*, replacing the items for which *pred*
3316 returns ``True`` with the items from the iterable *substitutes*.
3318 >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
3319 >>> pred = lambda x: x == 0
3320 >>> substitutes = (2, 3)
3321 >>> list(replace(iterable, pred, substitutes))
3322 [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]
3324 If *count* is given, the number of replacements will be limited:
3326 >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
3327 >>> pred = lambda x: x == 0
3328 >>> substitutes = [None]
3329 >>> list(replace(iterable, pred, substitutes, count=2))
3330 [1, 1, None, 1, 1, None, 1, 1, 0]
3332 Use *window_size* to control the number of items passed as arguments to
3333 *pred*. This allows for locating and replacing subsequences.
3335 >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
3336 >>> window_size = 3
3337 >>> pred = lambda *args: args == (0, 1, 2) # 3 items passed to pred
3338 >>> substitutes = [3, 4] # Splice in these items
3339 >>> list(replace(iterable, pred, substitutes, window_size=window_size))
3340 [3, 4, 5, 3, 4, 5]
3342 """
3343 if window_size < 1:
3344 raise ValueError('window_size must be at least 1')
3346 # Save the substitutes iterable, since it's used more than once
3347 substitutes = tuple(substitutes)
3349 # Add padding such that the number of windows matches the length of the
3350 # iterable
3351 it = chain(iterable, repeat(_marker, window_size - 1))
3352 windows = windowed(it, window_size)
3354 n = 0
3355 for w in windows:
3356 # If the current window matches our predicate (and we haven't hit
3357 # our maximum number of replacements), splice in the substitutes
3358 # and then consume the following windows that overlap with this one.
3359 # For example, if the iterable is (0, 1, 2, 3, 4...)
3360 # and the window size is 2, we have (0, 1), (1, 2), (2, 3)...
3361 # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2)
3362 if pred(*w):
3363 if (count is None) or (n < count):
3364 n += 1
3365 yield from substitutes
3366 consume(windows, window_size - 1)
3367 continue
3369 # If there was no match (or we've reached the replacement limit),
3370 # yield the first item from the window.
3371 if w and (w[0] is not _marker):
3372 yield w[0]
3375def partitions(iterable):
3376 """Yield all possible order-preserving partitions of *iterable*.
3378 >>> iterable = 'abc'
3379 >>> for part in partitions(iterable):
3380 ... print([''.join(p) for p in part])
3381 ['abc']
3382 ['a', 'bc']
3383 ['ab', 'c']
3384 ['a', 'b', 'c']
3386 This is unrelated to :func:`partition`.
3388 """
3389 sequence = list(iterable)
3390 n = len(sequence)
3391 for i in powerset(range(1, n)):
3392 yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))]
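# Editor's sketch of the cut-point powerset above (hypothetical values):
# for 'abc', powerset(range(1, 3)) yields (), (1,), (2,), and (1, 2);
# each set of cut points is framed by 0 and n == 3, so (1, 2) produces
# the slices [0:1], [1:2], [2:3], i.e. ['a', 'b', 'c'].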
3395def set_partitions(iterable, k=None, min_size=None, max_size=None):
3396 """
3397 Yield the set partitions of *iterable* into *k* parts. Set partitions are
3398 not order-preserving.
3400 >>> iterable = 'abc'
3401 >>> for part in set_partitions(iterable, 2):
3402 ... print([''.join(p) for p in part])
3403 ['a', 'bc']
3404 ['ab', 'c']
3405 ['b', 'ac']
3408 If *k* is not given, every set partition is generated.
3410 >>> iterable = 'abc'
3411 >>> for part in set_partitions(iterable):
3412 ... print([''.join(p) for p in part])
3413 ['abc']
3414 ['a', 'bc']
3415 ['ab', 'c']
3416 ['b', 'ac']
3417 ['a', 'b', 'c']
3419 If *min_size* and/or *max_size* are given, the minimum and/or maximum
3420 size of each block in the partition is set.
3422 >>> iterable = 'abc'
3423 >>> for part in set_partitions(iterable, min_size=2):
3424 ... print([''.join(p) for p in part])
3425 ['abc']
3426 >>> for part in set_partitions(iterable, max_size=2):
3427 ... print([''.join(p) for p in part])
3428 ['a', 'bc']
3429 ['ab', 'c']
3430 ['b', 'ac']
3431 ['a', 'b', 'c']
3433 """
3434 L = list(iterable)
3435 n = len(L)
3436 if k is not None:
3437 if k < 1:
3438 raise ValueError(
3439 "Can't partition into a negative or zero number of groups"
3440 )
3441 elif k > n:
3442 return
3444 min_size = min_size if min_size is not None else 0
3445 max_size = max_size if max_size is not None else n
3446 if min_size > max_size:
3447 return
3449 def set_partitions_helper(L, k):
3450 n = len(L)
3451 if k == 1:
3452 yield [L]
3453 elif n == k:
3454 yield [[s] for s in L]
3455 else:
3456 e, *M = L
3457 for p in set_partitions_helper(M, k - 1):
3458 yield [[e], *p]
3459 for p in set_partitions_helper(M, k):
3460 for i in range(len(p)):
3461 yield p[:i] + [[e] + p[i]] + p[i + 1 :]
3463 if k is None:
3464 for k in range(1, n + 1):
3465 yield from filter(
3466 lambda z: all(min_size <= len(bk) <= max_size for bk in z),
3467 set_partitions_helper(L, k),
3468 )
3469 else:
3470 yield from filter(
3471 lambda z: all(min_size <= len(bk) <= max_size for bk in z),
3472 set_partitions_helper(L, k),
3473 )
3476class time_limited:
3477 """
3478 Yield items from *iterable* until *limit_seconds* have passed.
3479 If the time limit expires before all items have been yielded, the
3480 ``timed_out`` parameter will be set to ``True``.
3482 >>> from time import sleep
3483 >>> def generator():
3484 ... yield 1
3485 ... yield 2
3486 ... sleep(0.2)
3487 ... yield 3
3488 >>> iterable = time_limited(0.1, generator())
3489 >>> list(iterable)
3490 [1, 2]
3491 >>> iterable.timed_out
3492 True
3494 Note that the time is checked before each item is yielded, and iteration
3495 stops if the time elapsed is greater than *limit_seconds*. If your time
3496 limit is 1 second, but it takes 2 seconds to generate the first item from
3497 the iterable, the function will run for 2 seconds and not yield anything.
3498 As a special case, when *limit_seconds* is zero, the iterator times out
3499 immediately and yields nothing.
3501 """
3503 def __init__(self, limit_seconds, iterable):
3504 if limit_seconds < 0:
3505 raise ValueError('limit_seconds must be non-negative')
3506 self.limit_seconds = limit_seconds
3507 self._iterator = iter(iterable)
3508 self._start_time = monotonic()
3509 self.timed_out = False
3511 def __iter__(self):
3512 return self
3514 def __next__(self):
3515 if self.limit_seconds == 0:
3516 self.timed_out = True
3517 raise StopIteration
3518 item = next(self._iterator)
3519 if monotonic() - self._start_time > self.limit_seconds:
3520 self.timed_out = True
3521 raise StopIteration
3523 return item
3526def only(iterable, default=None, too_long=None):
3527 """If *iterable* has only one item, return it.
3528 If it has zero items, return *default*.
3529 If it has more than one item, raise the exception given by *too_long*,
3530 which is ``ValueError`` by default.
3532 >>> only([], default='missing')
3533 'missing'
3534 >>> only([1])
3535 1
3536 >>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL
3537 Traceback (most recent call last):
3538 ...
3539 ValueError: Expected exactly one item in iterable, but got 1, 2,
3540 and perhaps more.
3541 >>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL
3542 Traceback (most recent call last):
3543 ...
3544 TypeError
3546 Note that :func:`only` attempts to advance *iterable* twice to ensure there
3547 is only one item. See :func:`spy` or :func:`peekable` to check
3548 iterable contents less destructively.
3550 """
3551 iterator = iter(iterable)
3552 for first in iterator:
3553 for second in iterator:
3554 msg = (
3555 f'Expected exactly one item in iterable, but got {first!r}, '
3556 f'{second!r}, and perhaps more.'
3557 )
3558 raise too_long or ValueError(msg)
3559 return first
3560 return default
3563def _ichunk(iterator, n):
3564 cache = deque()
3565 chunk = islice(iterator, n)
3567 def generator():
3568 with suppress(StopIteration):
3569 while True:
3570 if cache:
3571 yield cache.popleft()
3572 else:
3573 yield next(chunk)
3575 def materialize_next(n=1):
3576 # if n is None, materialize everything
3577 if n is None:
3578 cache.extend(chunk)
3579 return len(cache)
3581 to_cache = n - len(cache)
3583 # materialize up to n
3584 if to_cache > 0:
3585 cache.extend(islice(chunk, to_cache))
3587 # return number materialized up to n
3588 return min(n, len(cache))
3590 return (generator(), materialize_next)
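# Editor's sketch of the cache handshake above (hypothetical values):
# ichunked() calls materialize_next() once per chunk, pulling one item
# into the cache to test whether the source is exhausted; generator()
# then drains the cache before continuing from the islice, and
# materialize_next(None) spills the rest of a chunk into the cache when
# the next chunk is requested early.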
3593def ichunked(iterable, n):
3594 """Break *iterable* into sub-iterables with *n* elements each.
3595 :func:`ichunked` is like :func:`chunked`, but it yields iterables
3596 instead of lists.
3598 If the sub-iterables are read in order, the elements of *iterable*
3599 won't be stored in memory.
3600 If they are read out of order, :func:`itertools.tee` is used to cache
3601 elements as necessary.
3603 >>> from itertools import count
3604 >>> all_chunks = ichunked(count(), 4)
3605 >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
3606 >>> list(c_2) # c_1's elements have been cached; c_3's haven't been
3607 [4, 5, 6, 7]
3608 >>> list(c_1)
3609 [0, 1, 2, 3]
3610 >>> list(c_3)
3611 [8, 9, 10, 11]
3613 """
3614 iterator = iter(iterable)
3615 while True:
3616 # Create new chunk
3617 chunk, materialize_next = _ichunk(iterator, n)
3619 # Check to see whether we're at the end of the source iterable
3620 if not materialize_next():
3621 return
3623 yield chunk
3625 # Fill previous chunk's cache
3626 materialize_next(None)
3629def iequals(*iterables):
3630 """Return ``True`` if all given *iterables* are equal to each other,
3631 which means that they contain the same elements in the same order.
3633 The function is useful for comparing iterables of different data types
3634 or iterables that do not support equality checks.
3636 >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
3637 True
3639 >>> iequals("abc", "acb")
3640 False
3642 Not to be confused with :func:`all_equal`, which checks whether all
3643 elements of iterable are equal to each other.
3645 """
3646 try:
3647 return all(map(all_equal, zip(*iterables, strict=True)))
3648 except ValueError:
3649 return False
3652def distinct_combinations(iterable, r):
3653 """Yield the distinct combinations of *r* items taken from *iterable*.
3655 >>> list(distinct_combinations([0, 0, 1], 2))
3656 [(0, 0), (0, 1)]
3658 Equivalent to ``set(combinations(iterable, r))``, except duplicates are not
3659 generated and thrown away. For larger input sequences this is much more
3660 efficient.
3662 """
3663 if r < 0:
3664 raise ValueError('r must be non-negative')
3665 elif r == 0:
3666 yield ()
3667 return
3668 pool = tuple(iterable)
3669 generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
3670 current_combo = [None] * r
3671 level = 0
3672 while generators:
3673 try:
3674 cur_idx, p = next(generators[-1])
3675 except StopIteration:
3676 generators.pop()
3677 level -= 1
3678 continue
3679 current_combo[level] = p
3680 if level + 1 == r:
3681 yield tuple(current_combo)
3682 else:
3683 generators.append(
3684 unique_everseen(
3685 enumerate(pool[cur_idx + 1 :], cur_idx + 1),
3686 key=itemgetter(1),
3687 )
3688 )
3689 level += 1
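# Editor's trace of the generator stack above (hypothetical values): for
# distinct_combinations([0, 0, 1], 2), the root generator yields only
# (0, 0) and (2, 1) -- unique_everseen drops the duplicate 0 -- so the
# search explores prefix 0, emits (0, 0) and (0, 1) from its suffix pool,
# then explores prefix 1, whose suffix pool is empty.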
3692def filter_except(validator, iterable, *exceptions):
3693 """Yield the items from *iterable* for which the *validator* function does
3694 not raise one of the specified *exceptions*.
3696 *validator* is called for each item in *iterable*.
3697 It should be a function that accepts one argument and raises an exception
3698 if that item is not valid.
3700 >>> iterable = ['1', '2', 'three', '4', None]
3701 >>> list(filter_except(int, iterable, ValueError, TypeError))
3702 ['1', '2', '4']
3704 If an exception other than one given by *exceptions* is raised by
3705 *validator*, it is raised like normal.
3706 """
3707 for item in iterable:
3708 try:
3709 validator(item)
3710 except exceptions:
3711 pass
3712 else:
3713 yield item
3716def map_except(function, iterable, *exceptions):
3717 """Transform each item from *iterable* with *function* and yield the
3718 result, unless *function* raises one of the specified *exceptions*.
3720 *function* is called to transform each item in *iterable*.
3721 It should accept one argument.
3723 >>> iterable = ['1', '2', 'three', '4', None]
3724 >>> list(map_except(int, iterable, ValueError, TypeError))
3725 [1, 2, 4]
3727 If an exception other than one given by *exceptions* is raised by
3728 *function*, it is raised like normal.
3729 """
3730 for item in iterable:
3731 try:
3732 yield function(item)
3733 except exceptions:
3734 pass
3737def map_if(iterable, pred, func, func_else=None):
3738 """Evaluate each item from *iterable* using *pred*. If the result is
3739 equivalent to ``True``, transform the item with *func* and yield it.
3740 Otherwise, transform the item with *func_else* and yield it.
3742 *pred*, *func*, and *func_else* should each be functions that accept
3743 one argument. By default, *func_else* is the identity function.
3745 >>> from math import sqrt
3746 >>> iterable = list(range(-5, 5))
3747 >>> iterable
3748 [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
3749 >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
3750 [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
3751 >>> list(map_if(iterable, lambda x: x >= 0,
3752 ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
3753 [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
3754 """
3756 if func_else is None:
3757 for item in iterable:
3758 yield func(item) if pred(item) else item
3760 else:
3761 for item in iterable:
3762 yield func(item) if pred(item) else func_else(item)
3765def _sample_unweighted(iterator, k, strict):
3766 # Algorithm L in the 1994 paper by Kim-Hung Li:
3767 # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
3769 reservoir = list(islice(iterator, k))
3770 if strict and len(reservoir) < k:
3771 raise ValueError('Sample larger than population')
3772 W = 1.0
3774 with suppress(StopIteration):
3775 while True:
3776 W *= random() ** (1 / k)
3777 skip = floor(log(random()) / log1p(-W))
3778 element = next(islice(iterator, skip, None))
3779 reservoir[randrange(k)] = element
3781 shuffle(reservoir)
3782 return reservoir
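# Editor's note on Algorithm L above (a sketch, not a proof): W decays by
# a factor of random() ** (1 / k) per replacement, and
# log(random()) / log1p(-W) draws a geometrically distributed skip count,
# so per the cited paper the expected number of replacements and random
# draws is O(k * (1 + log(N / k))); skipped elements are still consumed
# from the iterator via islice.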
3785def _sample_weighted(iterator, k, weights, strict):
3786 # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
3787 # "Weighted random sampling with a reservoir".
3789 # Log-transform for numerical stability for weights that are small/large
3790 weight_keys = (log(random()) / weight for weight in weights)
3792 # Fill up the reservoir (collection of samples) with the first `k`
3793 # weight-keys and elements, then heapify the list.
3794 reservoir = take(k, zip(weight_keys, iterator))
3795 if strict and len(reservoir) < k:
3796 raise ValueError('Sample larger than population')
3798 heapify(reservoir)
3800 # The number of jumps before changing the reservoir is a random variable
3801 # with an exponential distribution. Sample it using random() and logs.
3802 smallest_weight_key, _ = reservoir[0]
3803 weights_to_skip = log(random()) / smallest_weight_key
3805 for weight, element in zip(weights, iterator):
3806 if weight >= weights_to_skip:
3807 # The notation here is consistent with the paper, but we store
3808 # the weight-keys in log-space for better numerical stability.
3809 smallest_weight_key, _ = reservoir[0]
3810 t_w = exp(weight * smallest_weight_key)
3811 r_2 = uniform(t_w, 1) # generate U(t_w, 1)
3812 weight_key = log(r_2) / weight
3813 heapreplace(reservoir, (weight_key, element))
3814 smallest_weight_key, _ = reservoir[0]
3815 weights_to_skip = log(random()) / smallest_weight_key
3816 else:
3817 weights_to_skip -= weight
3819 ret = [element for weight_key, element in reservoir]
3820 shuffle(ret)
3821 return ret
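# Editor's note on A-ExpJ above (a sketch): each element gets the key
# log(random()) / weight, kept in log-space for numerical stability; the
# heap retains the k largest keys, and weights_to_skip is an exponential
# jump deciding how much total weight to pass over before the next key
# is drawn, avoiding a random draw per element.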
3824def _sample_counted(population, k, counts, strict):
3825 element = None
3826 remaining = 0
3828 def feed(i):
3829 # Advance *i* steps ahead and consume an element
3830 nonlocal element, remaining
3832 while i + 1 > remaining:
3833 i = i - remaining
3834 element = next(population)
3835 remaining = next(counts)
3836 remaining -= i + 1
3837 return element
3839 with suppress(StopIteration):
3840 reservoir = []
3841 for _ in range(k):
3842 reservoir.append(feed(0))
3844 if strict and len(reservoir) < k:
3845 raise ValueError('Sample larger than population')
3847 with suppress(StopIteration):
3848 W = 1.0
3849 while True:
3850 W *= random() ** (1 / k)
3851 skip = floor(log(random()) / log1p(-W))
3852 element = feed(skip)
3853 reservoir[randrange(k)] = element
3855 shuffle(reservoir)
3856 return reservoir
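# Editor's trace of feed() above (hypothetical values): with population
# iter('ab') and counts iter([3, 4]), feed(0) pulls 'a' with remaining
# count 2 and returns 'a'; a later feed(4) spends the two remaining 'a'
# copies, pulls 'b' with count 4, consumes three 'b' copies, and returns
# 'b' with remaining == 1.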
3859def sample(iterable, k, weights=None, *, counts=None, strict=False):
3860 """Return a *k*-length list of elements chosen (without replacement)
3861 from the *iterable*.
3863 Similar to :func:`random.sample`, but works on inputs that aren't
3864 indexable (such as sets and dictionaries) and on inputs where the
3865 size isn't known in advance (such as generators).
3867 >>> iterable = range(100)
3868 >>> sample(iterable, 5) # doctest: +SKIP
3869 [81, 60, 96, 16, 4]
3871 For iterables with repeated elements, you may supply *counts* to
3872 indicate the repeats.
3874 >>> iterable = ['a', 'b']
3875 >>> counts = [3, 4] # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
3876 >>> sample(iterable, k=3, counts=counts) # doctest: +SKIP
3877 ['a', 'a', 'b']
3879 An iterable with *weights* may be given:
3881 >>> iterable = range(100)
3882 >>> weights = (i * i + 1 for i in range(100))
3883 >>> sample(iterable, 5, weights=weights) # doctest: +SKIP
3884 [79, 67, 74, 66, 78]
3886 Weighted selections are made without replacement.
3887 After an element is selected, it is removed from the pool and the
3888 relative weights of the other elements increase (this
3889 does not match the behavior of :func:`random.sample`'s *counts*
3890 parameter). Note that *weights* may not be used with *counts*.
3892 If the length of *iterable* is less than *k*,
3893 ``ValueError`` is raised if *strict* is ``True`` and
3894 all elements are returned (in shuffled order) if *strict* is ``False``.
3896 By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
3897 technique is used. When *weights* are provided,
3898 `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used instead.
3900 Notes on reproducibility:
3902 * The algorithms rely on inexact floating-point functions provided
3903 by the underlying math library (e.g. ``log``, ``log1p``, and ``pow``).
3904 Those functions can `produce slightly different results
3905 <https://members.loria.fr/PZimmermann/papers/accuracy.pdf>`_ on
3906 different builds. Accordingly, selections can vary across builds
3907 even for the same seed.
3909 * The algorithms loop over the input and make selections based on
3910 ordinal position, so selections from unordered collections (such as
3911 sets) won't reproduce across sessions on the same platform using the
3912 same seed. For example, this won't reproduce::
3914 >> seed(8675309)
3915 >> sample(set('abcdefghijklmnopqrstuvwxyz'), 10)
3916 ['c', 'p', 'e', 'w', 's', 'a', 'j', 'd', 'n', 't']
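
A short doctest illustrating *strict* (this follows directly from the
documented behavior above for inputs shorter than *k*):

>>> sorted(sample(range(3), k=5))
[0, 1, 2]
>>> sample(range(3), k=5, strict=True)
Traceback (most recent call last):
...
ValueError: Sample larger than population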
3918 """
3919 iterator = iter(iterable)
3921 if k < 0:
3922 raise ValueError('k must be non-negative')
3924 if k == 0:
3925 return []
3927 if weights is not None and counts is not None:
3928 raise TypeError('weights and counts are mutually exclusive')
3930 elif weights is not None:
3931 weights = iter(weights)
3932 return _sample_weighted(iterator, k, weights, strict)
3934 elif counts is not None:
3935 counts = iter(counts)
3936 return _sample_counted(iterator, k, counts, strict)
3938 else:
3939 return _sample_unweighted(iterator, k, strict)
3942def is_sorted(iterable, key=None, reverse=False, strict=False):
3943 """Returns ``True`` if the items of iterable are in sorted order, and
3944 ``False`` otherwise. *key* and *reverse* have the same meaning that they do
3945 in the built-in :func:`sorted` function.
3947 >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
3948 True
3949 >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
3950 False
3952 If *strict*, tests for strict sorting, that is, returns ``False`` if equal
3953 elements are found:
3955 >>> is_sorted([1, 2, 2])
3956 True
3957 >>> is_sorted([1, 2, 2], strict=True)
3958 False
3960 The function returns ``False`` after encountering the first out-of-order
3961 item, which means it may produce results that differ from the built-in
3962 :func:`sorted` function for objects with unusual comparison dynamics
3963 (like ``math.nan``). If there are no out-of-order items, the iterable is
3964 exhausted.
3965 """
3966 it = iterable if (key is None) else map(key, iterable)
3967 a, b = tee(it)
3968 next(b, None)
3969 if reverse:
3970 b, a = a, b
3971 return all(map(lt, a, b)) if strict else not any(map(lt, b, a))
3974class AbortThread(BaseException):
3975 pass
3978class callback_iter:
3979 """Convert a function that uses callbacks to an iterator.
3981 Let *func* be a function that takes a `callback` keyword argument.
3982 For example:
3984 >>> def func(callback=None):
3985 ... for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
3986 ... if callback:
3987 ... callback(i, c)
3988 ... return 4
3991 Use ``with callback_iter(func)`` to get an iterator over the parameters
3992 that are delivered to the callback.
3994 >>> with callback_iter(func) as it:
3995 ... for args, kwargs in it:
3996 ... print(args)
3997 (1, 'a')
3998 (2, 'b')
3999 (3, 'c')
4001 The function will be called in a background thread. The ``done`` property
4002 indicates whether it has completed execution.
4004 >>> it.done
4005 True
4007 If it completes successfully, its return value will be available
4008 in the ``result`` property.
4010 >>> it.result
4011 4
4013 Notes:
4015 * If the function uses some keyword argument besides ``callback``, supply
4016 *callback_kwd*.
4017 * If it finished executing, but raised an exception, accessing the
4018 ``result`` property will raise the same exception.
4019 * If it hasn't finished executing, accessing the ``result``
4020 property from within the ``with`` block will raise ``RuntimeError``.
4021 * If it hasn't finished executing, accessing the ``result`` property from
4022 outside the ``with`` block will raise a
4023 ``more_itertools.AbortThread`` exception.
4024 * Provide *wait_seconds* to adjust how frequently the function is polled for
4025 output.
4027 """
4029 def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
4030 self._func = func
4031 self._callback_kwd = callback_kwd
4032 self._aborted = False
4033 self._future = None
4034 self._wait_seconds = wait_seconds
4035 # Lazily import concurrent.futures
4036 self._executor = __import__(
4037 'concurrent.futures'
4038 ).futures.ThreadPoolExecutor(max_workers=1)
4039 self._iterator = self._reader()
4041 def __enter__(self):
4042 return self
4044 def __exit__(self, exc_type, exc_value, traceback):
4045 self._aborted = True
4046 self._executor.shutdown()
4048 def __iter__(self):
4049 return self
4051 def __next__(self):
4052 return next(self._iterator)
4054 @property
4055 def done(self):
4056 if self._future is None:
4057 return False
4058 return self._future.done()
4060 @property
4061 def result(self):
4062 if not self.done:
4063 raise RuntimeError('Function has not yet completed')
4065 return self._future.result()
4067 def _reader(self):
4068 q = Queue()
4070 def callback(*args, **kwargs):
4071 if self._aborted:
4072 raise AbortThread('canceled by user')
4074 q.put((args, kwargs))
4076 self._future = self._executor.submit(
4077 self._func, **{self._callback_kwd: callback}
4078 )
4080 while True:
4081 try:
4082 item = q.get(timeout=self._wait_seconds)
4083 except Empty:
4084 pass
4085 else:
4086 q.task_done()
4087 yield item
4089 if self._future.done():
4090 break
4092 remaining = []
4093 while True:
4094 try:
4095 item = q.get_nowait()
4096 except Empty:
4097 break
4098 else:
4099 q.task_done()
4100 remaining.append(item)
4101 q.join()
4102 yield from remaining
4105def windowed_complete(iterable, n):
4106 """
4107 Yield ``(beginning, middle, end)`` tuples, where:
4109 * Each ``middle`` has *n* items from *iterable*
4110 * Each ``beginning`` has the items before the ones in ``middle``
4111 * Each ``end`` has the items after the ones in ``middle``
4113 >>> iterable = range(7)
4114 >>> n = 3
4115 >>> for beginning, middle, end in windowed_complete(iterable, n):
4116 ... print(beginning, middle, end)
4117 () (0, 1, 2) (3, 4, 5, 6)
4118 (0,) (1, 2, 3) (4, 5, 6)
4119 (0, 1) (2, 3, 4) (5, 6)
4120 (0, 1, 2) (3, 4, 5) (6,)
4121 (0, 1, 2, 3) (4, 5, 6) ()
4123 Note that *n* must be at least 0 and at most equal to the length of
4124 *iterable*.
4126 This function will exhaust the iterable and may require significant
4127 storage.
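
For illustration, each yielded triple partitions the input (a property
check on a small input):

>>> all(b + m + e == tuple(range(7))
...     for b, m, e in windowed_complete(range(7), 3))
True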
4128 """
4129 if n < 0:
4130 raise ValueError('n must be >= 0')
4132 seq = tuple(iterable)
4133 size = len(seq)
4135 if n > size:
4136 raise ValueError('n must be <= len(seq)')
4138 for i in range(size - n + 1):
4139 beginning = seq[:i]
4140 middle = seq[i : i + n]
4141 end = seq[i + n :]
4142 yield beginning, middle, end
4145def all_unique(iterable, key=None):
4146 """
4147 Returns ``True`` if all the elements of *iterable* are unique (no two
4148 elements are equal).
4150 >>> all_unique('ABCB')
4151 False
4153 If a *key* function is specified, it will be used to make comparisons.
4155 >>> all_unique('ABCb')
4156 True
4157 >>> all_unique('ABCb', str.lower)
4158 False
4160 The function returns as soon as the first non-unique element is
4161 encountered. Iterables with a mix of hashable and unhashable items can
4162 be used, but the function will be slower for unhashable items.
4163 """
4164 seenset = set()
4165 seenset_add = seenset.add
4166 seenlist = []
4167 seenlist_add = seenlist.append
4168 for element in map(key, iterable) if key else iterable:
4169 try:
4170 if element in seenset:
4171 return False
4172 seenset_add(element)
4173 except TypeError:
4174 if element in seenlist:
4175 return False
4176 seenlist_add(element)
4177 return True
4180def nth_product(index, *args):
4181 """Equivalent to ``list(product(*args))[index]``.
4183 The products of *args* can be ordered lexicographically.
4184 :func:`nth_product` computes the product at sort position *index* without
4185 computing the previous products.
4187 >>> nth_product(8, range(2), range(2), range(2), range(2))
4188 (1, 0, 0, 0)
4190 ``IndexError`` will be raised if the given *index* is invalid.
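
For small inputs, the result can be checked against the materialized
product (a sketch; it is not how the function computes its answer):

>>> from itertools import product
>>> list(product(range(2), repeat=4))[8]
(1, 0, 0, 0)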
4191 """
4192 pools = list(map(tuple, reversed(args)))
4193 ns = list(map(len, pools))
4195 c = reduce(mul, ns)
4197 if index < 0:
4198 index += c
4200 if not 0 <= index < c:
4201 raise IndexError
4203 result = []
4204 for pool, n in zip(pools, ns):
4205 result.append(pool[index % n])
4206 index //= n
4208 return tuple(reversed(result))
4211def nth_permutation(iterable, r, index):
4212 """Equivalent to ``list(permutations(iterable, r))[index]```
4214 The subsequences of *iterable* that are of length *r* where order is
4215 important can be ordered lexicographically. :func:`nth_permutation`
4216 computes the subsequence at sort position *index* directly, without
4217 computing the previous subsequences.
4219 >>> nth_permutation('ghijk', 2, 5)
4220 ('h', 'i')
4222 ``ValueError`` will be raised if *r* is negative or greater than the length
4223 of *iterable*.
4224 ``IndexError`` will be raised if the given *index* is invalid.
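
A cross-check against the materialized permutations (small inputs only):

>>> from itertools import permutations
>>> list(permutations('ghijk', 2))[5]
('h', 'i')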
4225 """
4226 pool = list(iterable)
4227 n = len(pool)
4229 if r is None or r == n:
4230 r, c = n, factorial(n)
4231 elif not 0 <= r < n:
4232 raise ValueError
4233 else:
4234 c = perm(n, r)
4235 assert c > 0 # factorial(n)>0, and r<n so perm(n,r) is never zero
4237 if index < 0:
4238 index += c
4240 if not 0 <= index < c:
4241 raise IndexError
4243 result = [0] * r
4244 q = index * factorial(n) // c if r < n else index
4245 for d in range(1, n + 1):
4246 q, i = divmod(q, d)
4247 if 0 <= n - d < r:
4248 result[n - d] = i
4249 if q == 0:
4250 break
4252 return tuple(map(pool.pop, result))
4255def nth_combination_with_replacement(iterable, r, index):
4256 """Equivalent to
4257 ``list(combinations_with_replacement(iterable, r))[index]``.
4260 The subsequences with repetition of *iterable* that are of length *r* can
4261 be ordered lexicographically. :func:`nth_combination_with_replacement`
4262 computes the subsequence at sort position *index* directly, without
4263 computing the previous subsequences with replacement.
4265 >>> nth_combination_with_replacement(range(5), 3, 5)
4266 (0, 1, 1)
4268 ``ValueError`` will be raised if *r* is negative or greater than the length
4269 of *iterable*.
4270 ``IndexError`` will be raised if the given *index* is invalid.
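
The same result, obtained by materializing the combinations (an
illustrative check for small inputs):

>>> from itertools import combinations_with_replacement
>>> list(combinations_with_replacement(range(5), 3))[5]
(0, 1, 1)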
4271 """
4272 pool = tuple(iterable)
4273 n = len(pool)
4274 if (r < 0) or (r > n):
4275 raise ValueError
4277 c = comb(n + r - 1, r)
4279 if index < 0:
4280 index += c
4282 if (index < 0) or (index >= c):
4283 raise IndexError
4285 result = []
4286 i = 0
4287 while r:
4288 r -= 1
4289 while n >= 0:
4290 num_combs = comb(n + r - 1, r)
4291 if index < num_combs:
4292 break
4293 n -= 1
4294 i += 1
4295 index -= num_combs
4296 result.append(pool[i])
4298 return tuple(result)
4301def value_chain(*args):
4302 """Yield all arguments passed to the function in the same order in which
4303 they were passed. If an argument itself is iterable then iterate over its
4304 values.
4306 >>> list(value_chain(1, 2, 3, [4, 5, 6]))
4307 [1, 2, 3, 4, 5, 6]
4309 Binary and text strings are not considered iterable and are emitted
4310 as-is:
4312 >>> list(value_chain('12', '34', ['56', '78']))
4313 ['12', '34', '56', '78']
4315 Pre- or postpend a single element to an iterable:
4317 >>> list(value_chain(1, [2, 3, 4, 5, 6]))
4318 [1, 2, 3, 4, 5, 6]
4319 >>> list(value_chain([1, 2, 3, 4, 5], 6))
4320 [1, 2, 3, 4, 5, 6]
4322 Multiple levels of nesting are not flattened.
4324 """
4325 scalar_types = (str, bytes)
4326 for value in args:
4327 if isinstance(value, scalar_types):
4328 yield value
4329 continue
4330 try:
4331 yield from value
4332 except TypeError:
4333 yield value
4336def product_index(element, *args):
4337 """Equivalent to ``list(product(*args)).index(element)``
4339 The products of *args* can be ordered lexicographically.
4340 :func:`product_index` computes the first index of *element* without
4341 computing the previous products.
4343 >>> product_index([8, 2], range(10), range(5))
4344 42
4346 ``ValueError`` will be raised if the given *element* isn't in the product
4347 of *args*.
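
A round-trip check against the materialized product (small inputs only):

>>> from itertools import product
>>> list(product(range(10), range(5))).index((8, 2))
42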
4348 """
4349 elements = tuple(element)
4350 pools = tuple(map(tuple, args))
4351 if len(elements) != len(pools):
4352 raise ValueError('element is not a product of args')
4354 index = 0
4355 for elem, pool in zip(elements, pools):
4356 index = index * len(pool) + pool.index(elem)
4357 return index
4360def combination_index(element, iterable):
4361 """Equivalent to ``list(combinations(iterable, r)).index(element)``
4363 The subsequences of *iterable* that are of length *r* can be ordered
4364 lexicographically. :func:`combination_index` computes the index of the
4365 first *element*, without computing the previous combinations.
4367 >>> combination_index('adf', 'abcdefg')
4368 10
4370 ``ValueError`` will be raised if the given *element* isn't one of the
4371 combinations of *iterable*.
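
For illustration, the equivalent brute-force computation:

>>> from itertools import combinations
>>> list(combinations('abcdefg', 3)).index(('a', 'd', 'f'))
10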
4372 """
4373 element = enumerate(element)
4374 k, y = next(element, (None, None))
4375 if k is None:
4376 return 0
4378 indexes = []
4379 pool = enumerate(iterable)
4380 for n, x in pool:
4381 if x == y:
4382 indexes.append(n)
4383 tmp, y = next(element, (None, None))
4384 if tmp is None:
4385 break
4386 else:
4387 k = tmp
4388 else:
4389 raise ValueError('element is not a combination of iterable')
4391 n, _ = last(pool, default=(n, None))
4393 index = 1
4394 for i, j in enumerate(reversed(indexes), start=1):
4395 j = n - j
4396 if i <= j:
4397 index += comb(j, i)
4399 return comb(n + 1, k + 1) - index
4402def combination_with_replacement_index(element, iterable):
4403 """Equivalent to
4404 ``list(combinations_with_replacement(iterable, r)).index(element)``
4406 The subsequences with repetition of *iterable* that are of length *r* can
4407 be ordered lexicographically. :func:`combination_with_replacement_index`
4408 computes the index of the first *element*, without computing the previous
4409 combinations with replacement.
4411 >>> combination_with_replacement_index('adf', 'abcdefg')
4412 20
4414 ``ValueError`` will be raised if the given *element* isn't one of the
4415 combinations with replacement of *iterable*.
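
For illustration, the equivalent brute-force computation:

>>> from itertools import combinations_with_replacement
>>> list(combinations_with_replacement('abcdefg', 3)).index(('a', 'd', 'f'))
20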
4416 """
4417 element = tuple(element)
4418 length = len(element)
4419 element = enumerate(element)
4421 k, y = next(element, (None, None))
4422 if k is None:
4423 return 0
4425 indexes = []
4426 pool = tuple(iterable)
4427 for n, x in enumerate(pool):
4428 while x == y:
4429 indexes.append(n)
4430 tmp, y = next(element, (None, None))
4431 if tmp is None:
4432 break
4433 else:
4434 k = tmp
4435 if y is None:
4436 break
4437 else:
4438 raise ValueError(
4439 'element is not a combination with replacement of iterable'
4440 )
4442 n = len(pool)
4443 occupations = [0] * n
4444 for p in indexes:
4445 occupations[p] += 1
4447 index = 0
4448 cumulative_sum = 0
4449 for k in range(1, n):
4450 cumulative_sum += occupations[k - 1]
4451 j = length + n - 1 - k - cumulative_sum
4452 i = n - k
4453 if i <= j:
4454 index += comb(j, i)
4456 return index
4459def permutation_index(element, iterable):
4460 """Equivalent to ``list(permutations(iterable, r)).index(element)```
4462 The subsequences of *iterable* that are of length *r* where order is
4463 important can be ordered lexicographically. :func:`permutation_index`
4464 computes the index of the first *element* directly, without computing
4465 the previous permutations.
4467 >>> permutation_index([1, 3, 2], range(5))
4468 19
4470 ``ValueError`` will be raised if the given *element* isn't one of the
4471 permutations of *iterable*.
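
A cross-check against the materialized permutations (small inputs only):

>>> from itertools import permutations
>>> list(permutations(range(5), 3)).index((1, 3, 2))
19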
4472 """
4473 index = 0
4474 pool = list(iterable)
4475 for i, x in zip(range(len(pool), -1, -1), element):
4476 r = pool.index(x)
4477 index = index * i + r
4478 del pool[r]
4480 return index
4483class countable:
4484 """Wrap *iterable* and keep a count of how many items have been consumed.
4486 The ``items_seen`` attribute starts at ``0`` and increments as the iterable
4487 is consumed:
4489 >>> iterable = map(str, range(10))
4490 >>> it = countable(iterable)
4491 >>> it.items_seen
4492 0
4493 >>> next(it), next(it)
4494 ('0', '1')
4495 >>> list(it)
4496 ['2', '3', '4', '5', '6', '7', '8', '9']
4497 >>> it.items_seen
4498 10
4499 """
4501 def __init__(self, iterable):
4502 self._iterator = iter(iterable)
4503 self.items_seen = 0
4505 def __iter__(self):
4506 return self
4508 def __next__(self):
4509 item = next(self._iterator)
4510 self.items_seen += 1
4512 return item
4515def chunked_even(iterable, n):
4516 """Break *iterable* into lists of approximately length *n*.
4517 Items are distributed such that the lengths of the lists differ by at most
4518 1 item.
4520 >>> iterable = [1, 2, 3, 4, 5, 6, 7]
4521 >>> n = 3
4522 >>> list(chunked_even(iterable, n)) # List lengths: 3, 2, 2
4523 [[1, 2, 3], [4, 5], [6, 7]]
4524 >>> list(chunked(iterable, n)) # List lengths: 3, 3, 1
4525 [[1, 2, 3], [4, 5, 6], [7]]
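
Chunk lengths differ by at most one, with any shorter chunks at the end
(an illustrative check on a small input):

>>> [len(chunk) for chunk in chunked_even(range(8), 3)]
[3, 3, 2]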
4527 """
4528 iterator = iter(iterable)
4530 # Initialize a buffer to process the chunks while keeping
4531 # some back to fill any underfilled chunks
4532 min_buffer = (n - 1) * (n - 2)
4533 buffer = list(islice(iterator, min_buffer))
4535 # Append items until we have a completed chunk
4536 for _ in islice(map(buffer.append, iterator), n, None, n):
4537 yield buffer[:n]
4538 del buffer[:n]
4540 # Check if any chunks need additional processing
4541 if not buffer:
4542 return
4543 length = len(buffer)
4545 # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
4546 q, r = divmod(length, n)
4547 num_lists = q + (1 if r > 0 else 0)
4548 q, r = divmod(length, num_lists)
4549 full_size = q + (1 if r > 0 else 0)
4550 partial_size = full_size - 1
4551 num_full = length - partial_size * num_lists
4553 # Yield chunks of full size
4554 partial_start_idx = num_full * full_size
4555 if full_size > 0:
4556 for i in range(0, partial_start_idx, full_size):
4557 yield buffer[i : i + full_size]
4559 # Yield chunks of partial size
4560 if partial_size > 0:
4561 for i in range(partial_start_idx, length, partial_size):
4562 yield buffer[i : i + partial_size]
4565def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
4566 """A version of :func:`zip` that "broadcasts" any scalar
4567 (i.e., non-iterable) items into output tuples.
4569 >>> iterable_1 = [1, 2, 3]
4570 >>> iterable_2 = ['a', 'b', 'c']
4571 >>> scalar = '_'
4572 >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
4573 [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]
4575 The *scalar_types* keyword argument determines what types are considered
4576 scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
4577 treat strings and byte strings as iterable:
4579 >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
4580 [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]
4582 If the *strict* keyword argument is ``True``, then
4583 ``ValueError`` will be raised if any of the iterables have
4584 different lengths.
4585 """
4587 def is_scalar(obj):
4588 if scalar_types and isinstance(obj, scalar_types):
4589 return True
4590 try:
4591 iter(obj)
4592 except TypeError:
4593 return True
4594 else:
4595 return False
4597 size = len(objects)
4598 if not size:
4599 return
4601 new_item = [None] * size
4602 iterables, iterable_positions = [], []
4603 for i, obj in enumerate(objects):
4604 if is_scalar(obj):
4605 new_item[i] = obj
4606 else:
4607 iterables.append(iter(obj))
4608 iterable_positions.append(i)
4610 if not iterables:
4611 yield tuple(objects)
4612 return
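# The inner for-loop below writes each value directly into new_item at its
# iterable's position; its body is intentionally empty.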
4614 for item in zip(*iterables, strict=strict):
4615 for i, new_item[i] in zip(iterable_positions, item):
4616 pass
4617 yield tuple(new_item)
4620def unique_in_window(iterable, n, key=None):
4621 """Yield the items from *iterable* that haven't been seen recently.
4622 *n* is the size of the sliding window.
4624 >>> iterable = [0, 1, 0, 2, 3, 0]
4625 >>> n = 3
4626 >>> list(unique_in_window(iterable, n))
4627 [0, 1, 2, 3, 0]
4629 The *key* function, if provided, will be used to determine uniqueness:
4631 >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
4632 ['a', 'b', 'c', 'd', 'a']
4634 The function maintains a sliding window no larger than *n* and yields an
4635 item only if it does not already occur in the current window.
4637 When `n == 1`, *unique_in_window* is memoryless:
4639 >>> list(unique_in_window('aab', n=1))
4640 ['a', 'a', 'b']
4642 The items in *iterable* must be hashable.
4644 """
4645 if n <= 0:
4646 raise ValueError('n must be greater than 0')
4648 window = deque(maxlen=n)
4649 counts = Counter()
4650 use_key = key is not None
4652 for item in iterable:
4653 if len(window) == n:
4654 to_discard = window[0]
4655 if counts[to_discard] == 1:
4656 del counts[to_discard]
4657 else:
4658 counts[to_discard] -= 1
4660 k = key(item) if use_key else item
4661 if k not in counts:
4662 yield item
4663 counts[k] += 1
4664 window.append(k)
4667def duplicates_everseen(iterable, key=None):
4668 """Yield duplicate elements after their first appearance.
4670 >>> list(duplicates_everseen('mississippi'))
4671 ['s', 'i', 's', 's', 'i', 'p', 'i']
4672 >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
4673 ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']
4675 This function is analogous to :func:`unique_everseen` and is subject to
4676 the same performance considerations.
4678 """
4679 seen_set = set()
4680 seen_list = []
4681 use_key = key is not None
4683 for element in iterable:
4684 k = key(element) if use_key else element
4685 try:
4686 if k not in seen_set:
4687 seen_set.add(k)
4688 else:
4689 yield element
4690 except TypeError:
4691 if k not in seen_list:
4692 seen_list.append(k)
4693 else:
4694 yield element
4697def duplicates_justseen(iterable, key=None):
4698 """Yields serially-duplicate elements after their first appearance.
4700 >>> list(duplicates_justseen('mississippi'))
4701 ['s', 's', 'p']
4702 >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
4703 ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']
4705 This function is analogous to :func:`unique_justseen`.
4707 """
4708 return flatten(g for _, g in groupby(iterable, key) for _ in g)
4711def classify_unique(iterable, key=None):
4712 """Classify each element in terms of its uniqueness.
4714 For each element in the input iterable, return a 3-tuple consisting of:
4716 1. The element itself
4717 2. ``False`` if the element is equal to the one preceding it in the input,
4718 ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
4719 3. ``False`` if this element has been seen anywhere in the input before,
4720 ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)
4722 >>> list(classify_unique('otto')) # doctest: +NORMALIZE_WHITESPACE
4723 [('o', True, True),
4724 ('t', True, True),
4725 ('t', False, False),
4726 ('o', True, False)]
4728 This function is analogous to :func:`unique_everseen` and is subject to
4729 the same performance considerations.
4731 """
4732 seen_set = set()
4733 seen_list = []
4734 use_key = key is not None
4735 previous = None
4737 for i, element in enumerate(iterable):
4738 k = key(element) if use_key else element
4739 is_unique_justseen = not i or previous != k
4740 previous = k
4741 is_unique_everseen = False
4742 try:
4743 if k not in seen_set:
4744 seen_set.add(k)
4745 is_unique_everseen = True
4746 except TypeError:
4747 if k not in seen_list:
4748 seen_list.append(k)
4749 is_unique_everseen = True
4750 yield element, is_unique_justseen, is_unique_everseen
4753def minmax(iterable_or_value, *others, key=None, default=_marker):
4754 """Returns both the smallest and largest items from an iterable
4755 or from two or more arguments.
4757 >>> minmax([3, 1, 5])
4758 (1, 5)
4760 >>> minmax(4, 2, 6)
4761 (2, 6)
4763 If a *key* function is provided, it will be used to transform the input
4764 items for comparison.
4766 >>> minmax([5, 30], key=str) # '30' sorts before '5'
4767 (30, 5)
4769 If a *default* value is provided, it will be returned if there are no
4770 input items.
4772 >>> minmax([], default=(0, 0))
4773 (0, 0)
4775 Otherwise ``ValueError`` is raised.
4777 This function makes a single pass over the input elements and takes care to
4778 minimize the number of comparisons made during processing.
4780 Note that unlike the builtin ``max`` function, which always returns the first
4781 item with the maximum value, this function may return another item when there are
4782 ties.
4784 This function is based on the
4785 `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
4786 Raymond Hettinger.
4787 """
4788 iterable = (iterable_or_value, *others) if others else iterable_or_value
4790 it = iter(iterable)
4792 try:
4793 lo = hi = next(it)
4794 except StopIteration as exc:
4795 if default is _marker:
4796 raise ValueError(
4797 '`minmax()` argument is an empty iterable. '
4798 'Provide a `default` value to suppress this error.'
4799 ) from exc
4800 return default
4802 # Different branches depending on the presence of key. This saves a lot
4803 # of unimportant copies which would slow the "key=None" branch
4804 # significantly down.
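    # Elements are consumed pairwise via ``zip_longest(it, it)``: ordering
    # each pair first costs one comparison, after which only the smaller
    # item is tested against `lo` and only the larger against `hi`, for
    # roughly 1.5 comparisons per element instead of 2.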
4805 if key is None:
4806 for x, y in zip_longest(it, it, fillvalue=lo):
4807 if y < x:
4808 x, y = y, x
4809 if x < lo:
4810 lo = x
4811 if hi < y:
4812 hi = y
4814 else:
4815 lo_key = hi_key = key(lo)
4817 for x, y in zip_longest(it, it, fillvalue=lo):
4818 x_key, y_key = key(x), key(y)
4820 if y_key < x_key:
4821 x, y, x_key, y_key = y, x, y_key, x_key
4822 if x_key < lo_key:
4823 lo, lo_key = x, x_key
4824 if hi_key < y_key:
4825 hi, hi_key = y, y_key
4827 return lo, hi
4830def constrained_batches(
4831 iterable, max_size, max_count=None, get_len=len, strict=True
4832):
4833 """Yield batches of items from *iterable* with a combined size limited by
4834 *max_size*.
4836 >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
4837 >>> list(constrained_batches(iterable, 10))
4838 [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]
4840 If a *max_count* is supplied, the number of items per batch is also
4841 limited:
4843 >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
4844 >>> list(constrained_batches(iterable, 10, max_count = 2))
4845 [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]
4847 If a *get_len* function is supplied, use that instead of :func:`len` to
4848 determine item size.
4850 If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
4851 than *max_size*. Otherwise, allow single items to exceed *max_size*.
4852 """
4853 if max_size <= 0:
4854 raise ValueError('maximum size must be greater than zero')
4856 batch = []
4857 batch_size = 0
4858 batch_count = 0
4859 for item in iterable:
4860 item_len = get_len(item)
4861 if strict and item_len > max_size:
4862 raise ValueError('item size exceeds maximum size')
4864 reached_count = batch_count == max_count
4865 reached_size = item_len + batch_size > max_size
4866 if batch_count and (reached_size or reached_count):
4867 yield tuple(batch)
4868 batch.clear()
4869 batch_size = 0
4870 batch_count = 0
4872 batch.append(item)
4873 batch_size += item_len
4874 batch_count += 1
4876 if batch:
4877 yield tuple(batch)
4880def gray_product(*iterables):
4881 """Like :func:`itertools.product`, but return tuples in an order such
4882 that only one element in the generated tuple changes from one iteration
4883 to the next.
4885 >>> list(gray_product('AB','CD'))
4886 [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]
4888 This function consumes all of the input iterables before producing output.
4889 If any of the input iterables have fewer than two items, ``ValueError``
4890 is raised.
4892 For information on the algorithm, see
4893 `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
4894 of Donald Knuth's *The Art of Computer Programming*.
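
An illustrative check of the one-change property on a small input:

>>> tuples = list(gray_product('AB', 'CD'))
>>> all(sum(x != y for x, y in zip(t, u)) == 1
...     for t, u in zip(tuples, tuples[1:]))
True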
4895 """
4896 all_iterables = tuple(tuple(x) for x in iterables)
4897 iterable_count = len(all_iterables)
4898 for iterable in all_iterables:
4899 if len(iterable) < 2:
4900 raise ValueError("each iterable must have two or more items")
4902 # This is based on "Algorithm H" from section 7.2.1.1, page 20.
4903 # a holds the indexes of the source iterables for the n-tuple to be yielded
4904 # f is the array of "focus pointers"
4905 # o is the array of "directions"
4906 a = [0] * iterable_count
4907 f = list(range(iterable_count + 1))
4908 o = [1] * iterable_count
4909 while True:
4910 yield tuple(all_iterables[i][a[i]] for i in range(iterable_count))
4911 j = f[0]
4912 f[0] = 0
4913 if j == iterable_count:
4914 break
4915 a[j] = a[j] + o[j]
4916 if a[j] == 0 or a[j] == len(all_iterables[j]) - 1:
4917 o[j] = -o[j]
4918 f[j] = f[j + 1]
4919 f[j + 1] = j + 1
4922def partial_product(*iterables):
4923 """Yields tuples containing one item from each iterator, with subsequent
4924 tuples changing a single item at a time by advancing each iterator until it
4925 is exhausted. This sequence guarantees every value in each iterable is
4926 output at least once without generating all possible combinations.
4928 This may be useful, for example, when testing an expensive function.
4930 >>> list(partial_product('AB', 'C', 'DEF'))
4931 [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]
4932 """
4934 iterators = list(map(iter, iterables))
4936 try:
4937 prod = [next(it) for it in iterators]
4938 except StopIteration:
4939 return
4940 yield tuple(prod)
4942 for i, it in enumerate(iterators):
4943 for prod[i] in it:
4944 yield tuple(prod)
4947def takewhile_inclusive(predicate, iterable):
4948 """A variant of :func:`takewhile` that yields one additional element.
4950 >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
4951 [1, 4, 6]
4953 :func:`takewhile` would return ``[1, 4]``.
4954 """
4955 for x in iterable:
4956 yield x
4957 if not predicate(x):
4958 break
4961def outer_product(func, xs, ys, *args, **kwargs):
4962 """A generalized outer product that applies a binary function to all
4963 pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
4964 columns.
4965 Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.
4967 Multiplication table:
4969 >>> list(outer_product(mul, range(1, 4), range(1, 6)))
4970 [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]
4972 Cross tabulation:
4974 >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
4975 >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
4976 >>> pair_counts = Counter(zip(xs, ys))
4977 >>> count_rows = lambda x, y: pair_counts[x, y]
4978 >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
4979 [(2, 3, 0), (1, 0, 4)]
4981 Usage with ``*args`` and ``**kwargs``:
4983 >>> animals = ['cat', 'wolf', 'mouse']
4984 >>> list(outer_product(min, animals, animals, key=len))
4985 [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
4986 """
4987 ys = tuple(ys)
4988 return batched(
4989 starmap(lambda x, y: func(x, y, *args, **kwargs), product(xs, ys)),
4990 n=len(ys),
4991 )
4994def iter_suppress(iterable, *exceptions):
4995 """Yield each of the items from *iterable*. If the iteration raises one of
4996 the specified *exceptions*, that exception will be suppressed and iteration
4997 will stop.
4999 >>> from itertools import chain
5000 >>> def breaks_at_five(x):
5001 ... while True:
5002 ... if x >= 5:
5003 ... raise RuntimeError
5004 ... yield x
5005 ... x += 1
5006 >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
5007 >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
5008 >>> list(chain(it_1, it_2))
5009 [1, 2, 3, 4, 2, 3, 4]
5010 """
5011 try:
5012 yield from iterable
5013 except exceptions:
5014 return
5017def filter_map(func, iterable):
5018 """Apply *func* to every element of *iterable*, yielding only those which
5019 are not ``None``.
5021 >>> elems = ['1', 'a', '2', 'b', '3']
5022 >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
5023 [1, 2, 3]
5024 """
5025 for x in iterable:
5026 y = func(x)
5027 if y is not None:
5028 yield y
5031def powerset_of_sets(iterable, *, baseset=set):
5032 """Yields all possible subsets of the iterable.
5034 >>> list(powerset_of_sets([1, 2, 3])) # doctest: +SKIP
5035 [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
5036 >>> list(powerset_of_sets([1, 1, 0])) # doctest: +SKIP
5037 [set(), {1}, {0}, {0, 1}]
5039 :func:`powerset_of_sets` takes care to minimize the number
5040 of hash operations performed.
5042 The *baseset* parameter determines what kind of sets are
5043 constructed, either *set* or *frozenset*.
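
As a sanity check, *n* distinct items yield ``2 ** n`` subsets:

>>> len(list(powerset_of_sets(range(4))))
16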
5044 """
5045 sets = tuple(dict.fromkeys(map(frozenset, zip(iterable))))
5046 union = baseset().union
5047 return chain.from_iterable(
5048 starmap(union, combinations(sets, r)) for r in range(len(sets) + 1)
5049 )
5052def join_mappings(**field_to_map):
5053 """
5054 Joins multiple mappings together using their common keys.
5056 >>> user_scores = {'elliot': 50, 'claris': 60}
5057 >>> user_times = {'elliot': 30, 'claris': 40}
5058 >>> join_mappings(score=user_scores, time=user_times)
5059 {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
5060 """
5061 ret = defaultdict(dict)
5063 for field_name, mapping in field_to_map.items():
5064 for key, value in mapping.items():
5065 ret[key][field_name] = value
5067 return dict(ret)
5070def _complex_sumprod(v1, v2):
5071 """High precision sumprod() for complex numbers.
5072 Used by :func:`dft` and :func:`idft`.
5073 """
5075 real = attrgetter('real')
5076 imag = attrgetter('imag')
5077 r1 = chain(map(real, v1), map(neg, map(imag, v1)))
5078 r2 = chain(map(real, v2), map(imag, v2))
5079 i1 = chain(map(real, v1), map(imag, v1))
5080 i2 = chain(map(imag, v2), map(real, v2))
5081 return complex(_fsumprod(r1, r2), _fsumprod(i1, i2))
5084def dft(xarr):
5085 """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
5086 Yields the components of the corresponding transformed output vector.
5088 >>> import cmath
5089 >>> xarr = [1, 2-1j, -1j, -1+2j] # time domain
5090 >>> Xarr = [2, -2-2j, -2j, 4+4j] # frequency domain
5091 >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
5092 >>> all(map(cmath.isclose, dft(xarr), Xarr))
5093 True
5095 Inputs are restricted to numeric types that can add and multiply
5096 with a complex number. This includes int, float, complex, and
5097 Fraction, but excludes Decimal.
5099 See :func:`idft` for the inverse Discrete Fourier Transform.
5100 """
5101 N = len(xarr)
5102 roots_of_unity = [e ** (n / N * tau * -1j) for n in range(N)]
5103 for k in range(N):
5104 coeffs = [roots_of_unity[k * n % N] for n in range(N)]
5105 yield _complex_sumprod(xarr, coeffs)
5108def idft(Xarr):
5109 """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
5110 complex numbers. Yields the components of the corresponding
5111 inverse-transformed output vector.
5113 >>> import cmath
5114 >>> xarr = [1, 2-1j, -1j, -1+2j] # time domain
5115 >>> Xarr = [2, -2-2j, -2j, 4+4j] # frequency domain
5116 >>> all(map(cmath.isclose, idft(Xarr), xarr))
5117 True
5119 Inputs are restricted to numeric types that can add and multiply
5120 with a complex number. This includes int, float, complex, and
5121 Fraction, but excludes Decimal.
5123 See :func:`dft` for the Discrete Fourier Transform.
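
A round-trip check: applying :func:`dft` and then :func:`idft` recovers
the original vector, up to floating-point error:

>>> all(map(cmath.isclose, idft(list(dft(xarr))), xarr))
True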
5124 """
5125 N = len(Xarr)
5126 roots_of_unity = [e ** (n / N * tau * 1j) for n in range(N)]
5127 for k in range(N):
5128 coeffs = [roots_of_unity[k * n % N] for n in range(N)]
5129 yield _complex_sumprod(Xarr, coeffs) / N
5132def doublestarmap(func, iterable):
5133 """Apply *func* to every item of *iterable* by dictionary unpacking
5134 the item into *func*.
5136 The difference between :func:`itertools.starmap` and :func:`doublestarmap`
5137 parallels the distinction between ``func(*a)`` and ``func(**a)``.
5139 >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
5140 >>> list(doublestarmap(lambda a, b: a + b, iterable))
5141 [3, 100]
5143 ``TypeError`` will be raised if *func*'s signature doesn't match the
5144 mapping contained in *iterable* or if *iterable* does not contain mappings.
5145 """
5146 for item in iterable:
5147 yield func(**item)
5150def _nth_prime_bounds(n):
5151 """Bounds for the nth prime (counting from 1): lb < p_n < ub."""
5152 # At and above 688,383, the lb/ub spread is under 0.003 * p_n.
5154 if n < 1:
5155 raise ValueError
5157 if n < 6:
5158 return (n, 2.25 * n)
5160 # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
5161 upper_bound = n * log(n * log(n))
5162 lower_bound = upper_bound - n
5163 if n >= 688_383:
5164 upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))
5166 return lower_bound, upper_bound
5169def nth_prime(n, *, approximate=False):
5170 """Return the nth prime (counting from 0).
5172 >>> nth_prime(0)
5173 2
5174 >>> nth_prime(100)
5175 547
5177 If *approximate* is set to ``True``, a prime close to the nth prime is
5178 returned instead. The estimate is much faster to compute than the
5179 exact result.
5181 >>> nth_prime(200_000_000, approximate=True) # Exact result is 4222234763
5182 4217820427
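
A sanity check against the start of the prime sequence:

>>> [nth_prime(i) for i in range(5)]
[2, 3, 5, 7, 11]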
5184 """
5185 lb, ub = _nth_prime_bounds(n + 1)
5187 if not approximate or n <= 1_000_000:
5188 return nth(sieve(ceil(ub)), n)
5190 # Search from the midpoint and return the first odd prime
5191 odd = floor((lb + ub) / 2) | 1
5192 return first_true(count(odd, step=2), pred=is_prime)
5195def argmin(iterable, *, key=None):
5196 """
5197 Index of the first occurrence of a minimum value in an iterable.
5199 >>> argmin('efghabcdijkl')
5200 4
5201 >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
5202 3
5204 For example, look up a label corresponding to the position
5205 of a value that minimizes a cost function::
5207 >>> def cost(x):
5208 ... "Days for a wound to heal given a subject's age."
5209 ... return x**2 - 20*x + 150
5210 ...
5211 >>> labels = ['homer', 'marge', 'bart', 'lisa', 'maggie']
5212 >>> ages = [ 35, 30, 10, 9, 1 ]
5214 # Fastest healing family member
5215 >>> labels[argmin(ages, key=cost)]
5216 'bart'
5218 # Age with fastest healing
5219 >>> min(ages, key=cost)
5220 10
5222 """
5223 if key is not None:
5224 iterable = map(key, iterable)
5225 return min(enumerate(iterable), key=itemgetter(1))[0]
5228def argmax(iterable, *, key=None):
5229 """
5230 Index of the first occurrence of a maximum value in an iterable.
5232 >>> argmax('abcdefghabcd')
5233 7
5234 >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
5235 3
5237 For example, identify the best machine learning model::
5239 >>> models = ['svm', 'random forest', 'knn', 'naïve bayes']
5240 >>> accuracy = [ 68, 61, 84, 72 ]
5242 # Most accurate model
5243 >>> models[argmax(accuracy)]
5244 'knn'
5246 # Best accuracy
5247 >>> max(accuracy)
5248 84
5250 """
5251 if key is not None:
5252 iterable = map(key, iterable)
5253 return max(enumerate(iterable), key=itemgetter(1))[0]
5256def _extract_monotonic(iterator, indices):
5257 'Non-decreasing indices, lazily consumed'
5258 num_read = 0
5259 for index in indices:
5260 advance = index - num_read
5261 try:
5262 value = next(islice(iterator, advance, None))
5263 except ValueError:
5264 if advance != -1 or index < 0:
5265 raise ValueError(f'Invalid index: {index}') from None
5266 except StopIteration:
5267 raise IndexError(index) from None
5268 else:
5269 num_read += advance + 1
5270 yield value
5273def _extract_buffered(iterator, index_and_position):
5274 'Arbitrary index order, greedily consumed'
5275 buffer = {}
5276 iterator_position = -1
5277 next_to_emit = 0
5279 for index, order in index_and_position:
5280 advance = index - iterator_position
5281 if advance:
5282 try:
5283 value = next(islice(iterator, advance - 1, None))
5284 except StopIteration:
5285 raise IndexError(index) from None
5286 iterator_position = index
5288 buffer[order] = value
5290 while next_to_emit in buffer:
5291 yield buffer.pop(next_to_emit)
5292 next_to_emit += 1
5295def extract(iterable, indices, *, monotonic=False):
5296 """Yield values at the specified indices.
5298 Example:
5300 >>> data = 'abcdefghijklmnopqrstuvwxyz'
5301 >>> list(extract(data, [7, 4, 11, 11, 14]))
5302 ['h', 'e', 'l', 'l', 'o']
5304 The *iterable* is consumed lazily and can be infinite.
5306 When *monotonic* is false, the *indices* are consumed immediately
5307 and must be finite. When *monotonic* is true, *indices* are consumed
5308 lazily and can be infinite but must be non-decreasing.
5310 Raises ``IndexError`` if an index lies beyond the iterable.
5311 Raises ``ValueError`` for a negative index or for a decreasing
5312 index when *monotonic* is true.
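
With *monotonic* true, *indices* may be infinite; for illustration, take
every third letter lazily:

>>> from itertools import count, islice
>>> list(islice(extract(data, count(0, 3), monotonic=True), 4))
['a', 'd', 'g', 'j']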
5313 """
5315 iterator = iter(iterable)
5316 indices = iter(indices)
5318 if monotonic:
5319 return _extract_monotonic(iterator, indices)
5321 index_and_position = sorted(zip(indices, count()))
5322 if index_and_position and index_and_position[0][0] < 0:
5323 raise ValueError('Indices must be non-negative')
5324 return _extract_buffered(iterator, index_and_position)
5327class serialize:
5328 """Wrap a non-concurrent iterator with a lock to enforce sequential access.
5330 Applies a non-reentrant lock around calls to ``__next__``, allowing
5331 iterator and generator instances to be shared by multiple consumer
5332 threads.
5333 """
5335 __slots__ = ('iterator', 'lock')
5337 def __init__(self, iterable):
5338 self.iterator = iter(iterable)
5339 self.lock = Lock()
5341 def __iter__(self):
5342 return self
5344 def __next__(self):
5345 with self.lock:
5346 return next(self.iterator)
5349def synchronized(func):
5350 """Wrap an iterator-returning callable to make its iterators thread-safe.
5352 Existing itertools and more-itertools can be wrapped so that their
5353 iterator instances are serialized.
5355 For example, ``itertools.count`` does not make thread-safe instances,
5356 but that is easily fixed with::
5358 atomic_counter = synchronized(itertools.count)
5360 Can also be used as a decorator for generator function definitions
5361 so that the generator instances are serialized::
5363 @synchronized
5364 def enumerate_and_timestamp(iterable):
5365 for count, value in enumerate(iterable):
5366 yield count, time_ns(), value
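
A minimal single-threaded doctest of the counter example above:

>>> from itertools import count, islice
>>> atomic_counter = synchronized(count)
>>> list(islice(atomic_counter(), 3))
[0, 1, 2]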
5368 """
5370 @wraps(func)
5371 def inner(*args, **kwargs):
5372 iterator = func(*args, **kwargs)
5373 return serialize(iterator)
5375 return inner
5378def concurrent_tee(iterable, n=2):
5379 """Variant of itertools.tee() but with guaranteed threading semantics.
5381 Takes a non-threadsafe iterator as an input and creates concurrent
5382 tee objects for other threads to have reliable independent copies of
5383 the data stream.
5385 Each of the new iterators is safe to use only from a single thread.
5386 To share just one of the new iterators across multiple threads, wrap it
5387 with :func:`serialize`.
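
A single-threaded sketch showing the independent copies:

>>> a, b = concurrent_tee(range(3))
>>> list(a), list(b)
([0, 1, 2], [0, 1, 2])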
5388 """
5390 if n < 0:
5391 raise ValueError
5392 if n == 0:
5393 return ()
5394 iterator = _concurrent_tee(iterable)
5395 result = [iterator]
5396 for _ in range(n - 1):
5397 result.append(_concurrent_tee(iterator))
5398 return tuple(result)
5401class _concurrent_tee:
5402 __slots__ = ('iterator', 'link', 'lock')
5404 def __init__(self, iterable):
5405 it = iter(iterable)
5406 if isinstance(it, _concurrent_tee):
5407 self.iterator = it.iterator
5408 self.link = it.link
5409 self.lock = it.lock
5410 else:
5411 self.iterator = it
5412 self.link = [None, None]
5413 self.lock = Lock()
5415 def __iter__(self):
5416 return self
5418 def __next__(self):
5419 link = self.link
5420 if link[1] is None:
5421 with self.lock:
5422 if link[1] is None:
5423 link[0] = next(self.iterator)
5424 link[1] = [None, None]
5425 value, self.link = link
5426 return value