Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/more_itertools/more.py: 19%
Shortcuts on this page:
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import math
3from collections import Counter, defaultdict, deque
4from collections.abc import Sequence
5from contextlib import suppress
6from functools import cached_property, partial, wraps
7from heapq import heapify, heapreplace
8from itertools import (
9 chain,
10 combinations,
11 compress,
12 count,
13 cycle,
14 dropwhile,
15 groupby,
16 islice,
17 permutations,
18 repeat,
19 starmap,
20 takewhile,
21 tee,
22 zip_longest,
23 product,
24)
25from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
26from math import ceil, prod
27from queue import Empty, Queue
28from random import random, randrange, shuffle, uniform
29from operator import (
30 attrgetter,
31 getitem,
32 is_not,
33 itemgetter,
34 lt,
35 neg,
36 sub,
37 gt,
38)
39from sys import maxsize
40from time import monotonic
41from threading import Lock
43from .recipes import (
44 _marker,
45 consume,
46 first_true,
47 flatten,
48 is_prime,
49 nth,
50 powerset,
51 sieve,
52 take,
53 unique_everseen,
54 all_equal,
55 batched,
56)
# Explicit public API of this module, consumed by ``from more_itertools import *``.
# NOTE(review): the list is mostly alphabetical, but a few entries (e.g.
# 'classify_unique', 'idft', 'nth_combination_with_replacement', 'strictly_n')
# sit near related names rather than in strict sorted position.
__all__ = [
    'AbortThread',
    'SequenceView',
    'adjacent',
    'all_unique',
    'always_iterable',
    'always_reversible',
    'argmax',
    'argmin',
    'bucket',
    'callback_iter',
    'chunked',
    'chunked_even',
    'circular_shifts',
    'collapse',
    'combination_index',
    'combination_with_replacement_index',
    'concurrent_tee',
    'consecutive_groups',
    'constrained_batches',
    'consumer',
    'count_cycle',
    'countable',
    'derangements',
    'dft',
    'difference',
    'distinct_combinations',
    'distinct_permutations',
    'distribute',
    'divide',
    'doublestarmap',
    'duplicates_everseen',
    'duplicates_justseen',
    'classify_unique',
    'exactly_n',
    'extract',
    'filter_except',
    'filter_map',
    'first',
    'gray_product',
    'groupby_transform',
    'ichunked',
    'iequals',
    'idft',
    'ilen',
    'interleave',
    'interleave_evenly',
    'interleave_longest',
    'interleave_randomly',
    'intersperse',
    'is_sorted',
    'islice_extended',
    'iterate',
    'iter_suppress',
    'join_mappings',
    'last',
    'locate',
    'longest_common_prefix',
    'lstrip',
    'make_decorator',
    'map_except',
    'map_if',
    'map_reduce',
    'mark_ends',
    'minmax',
    'nth_or_last',
    'nth_permutation',
    'nth_prime',
    'nth_product',
    'nth_combination_with_replacement',
    'numeric_range',
    'one',
    'only',
    'outer_product',
    'padded',
    'partial_product',
    'partitions',
    'peekable',
    'permutation_index',
    'powerset_of_sets',
    'product_index',
    'raise_',
    'repeat_each',
    'repeat_last',
    'replace',
    'rlocate',
    'rstrip',
    'run_length',
    'sample',
    'seekable',
    'serialize',
    'set_partitions',
    'side_effect',
    'sliced',
    'sort_together',
    'split_after',
    'split_at',
    'split_before',
    'split_into',
    'split_when',
    'spy',
    'stagger',
    'strip',
    'strictly_n',
    'substrings',
    'substrings_indexes',
    'synchronized',
    'takewhile_inclusive',
    'time_limited',
    'unique_in_window',
    'unique_to_each',
    'unzip',
    'value_chain',
    'windowed',
    'windowed_complete',
    'with_iter',
    'zip_broadcast',
    'zip_offset',
]
# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _fsumprod
except ImportError: # pragma: no cover
    # Extended precision algorithms from T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8). Code: mul12()

    def dl_split(x: float):
        "Split a float into two half-precision components."
        # NOTE: the exact order of these float operations is what makes the
        # split exact -- do not reassociate or "simplify" them.
        t = x * 134217729.0 # Veltkamp constant = 2.0 ** 27 + 1
        hi = t - (t - x)
        lo = x - hi
        return hi, lo

    def dl_mul(x, y):
        "Lossless multiplication."
        # Dekker's mul12(): returns (z, zz) such that x * y == z + zz exactly,
        # where z is the rounded product and zz the rounding error.
        xx_hi, xx_lo = dl_split(x)
        yy_hi, yy_lo = dl_split(y)
        p = xx_hi * yy_hi
        q = xx_hi * yy_lo + xx_lo * yy_hi
        z = p + q
        zz = p - z + q + xx_lo * yy_lo
        return z, zz

    def _fsumprod(p, q):
        # Each pairwise product is represented exactly by two floats;
        # fsum then accumulates all of them without intermediate rounding.
        return fsum(chain.from_iterable(map(dl_mul, p, q)))
def chunked(iterable, n, strict=False):
    """Break *iterable* into lists of length *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
    [[1, 2, 3], [4, 5, 6]]

    By default, the last yielded list will have fewer than *n* elements
    if the length of *iterable* is not divisible by *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
    [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    If the length of *iterable* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    list is yielded.
    """
    # Repeatedly take n items until an empty list (the sentinel) comes back.
    chunks = iter(partial(take, n, iter(iterable)), [])
    if not strict:
        return chunks
    if n is None:
        raise ValueError('n must not be None when using strict mode.')

    def checked():
        # Every chunk except possibly a trailing short one has length n;
        # a short chunk in strict mode is an error.
        for chunk in chunks:
            if len(chunk) != n:
                raise ValueError('iterable is not divisible by n.')
            yield chunk

    return checked()
def first(iterable, default=_marker):
    """Return the first item of *iterable*, or *default* if *iterable* is
    empty.

    >>> first([0, 1, 2, 3])
    0
    >>> first([], 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.

    :func:`first` is useful when you have a generator of expensive-to-retrieve
    values and want any arbitrary one. It is marginally shorter than
    ``next(iter(iterable), default)``.
    """
    iterator = iter(iterable)
    try:
        return next(iterator)
    except StopIteration:
        pass
    # Reached only when the iterable was empty.
    if default is _marker:
        raise ValueError(
            'first() was called on an empty iterable, '
            'and no default value was provided.'
        )
    return default
def last(iterable, default=_marker):
    """Return the last item of *iterable*, or *default* if *iterable* is
    empty.

    >>> last([0, 1, 2, 3])
    3
    >>> last([], 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    try:
        # Fast paths: direct indexing for sequences, reverse iteration for
        # objects that support it; otherwise keep only the final item.
        if isinstance(iterable, Sequence):
            return iterable[-1]
        # Work around https://bugs.python.org/issue38525
        if getattr(iterable, '__reversed__', None):
            return next(reversed(iterable))
        return deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        if default is not _marker:
            return default
        raise ValueError(
            'last() was called on an empty iterable, '
            'and no default value was provided.'
        )
def nth_or_last(iterable, n, default=_marker):
    """Return the nth or the last item of *iterable*,
    or *default* if *iterable* is empty.

    >>> nth_or_last([0, 1, 2, 3], 2)
    2
    >>> nth_or_last([0, 1], 2)
    1
    >>> nth_or_last([], 0, 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    # Keep at most the first n + 1 items; the last of those is the answer.
    head = islice(iterable, n + 1)
    return last(head, default=default)
class peekable:
    """Wrap an iterator to allow lookahead and prepending elements.

    Call :meth:`peek` on the result to get the value that will be returned
    by :func:`next`. This won't advance the iterator:

    >>> p = peekable(['a', 'b'])
    >>> p.peek()
    'a'
    >>> next(p)
    'a'

    Pass :meth:`peek` a default value to return that instead of raising
    ``StopIteration`` when the iterator is exhausted.

    >>> p = peekable([])
    >>> p.peek('hi')
    'hi'

    peekables also offer a :meth:`prepend` method, which "inserts" items
    at the head of the iterable:

    >>> p = peekable([1, 2, 3])
    >>> p.prepend(10, 11, 12)
    >>> next(p)
    10
    >>> p.peek()
    11
    >>> list(p)
    [11, 12, 1, 2, 3]

    peekables can be indexed. Index 0 is the item that will be returned by
    :func:`next`, index 1 is the item after that, and so on:
    The values up to the given index will be cached.

    >>> p = peekable(['a', 'b', 'c', 'd'])
    >>> p[0]
    'a'
    >>> p[1]
    'b'
    >>> next(p)
    'a'

    Negative indexes are supported, but be aware that they will cache the
    remaining items in the source iterator, which may require significant
    storage.

    To check whether a peekable is exhausted, check its truth value:

    >>> p = peekable(['a', 'b'])
    >>> if p:  # peekable has items
    ...     list(p)
    ['a', 'b']
    >>> if not p:  # peekable is exhausted
    ...     list(p)
    []

    """

    def __init__(self, iterable):
        # _cache holds items pulled from _it (or prepended) that have not
        # yet been returned; it is always consumed before _it is advanced.
        self._it = iter(iterable)
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        # Truthy iff at least one more item is available; peek() caches it.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        """Return the item that will be next returned from ``next()``.

        Return ``default`` if there are no items left. If ``default`` is not
        provided, raise ``StopIteration``.

        """
        if not self._cache:
            try:
                self._cache.append(next(self._it))
            except StopIteration:
                # Re-raise only when the caller gave no default.
                if default is _marker:
                    raise
                return default
        return self._cache[0]

    def prepend(self, *items):
        """Stack up items to be the next ones returned from ``next()`` or
        ``self.peek()``. The items will be returned in
        first in, first out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        It is possible, by prepending items, to "resurrect" a peekable that
        previously raised ``StopIteration``.

            >>> p = peekable([])
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration
            >>> p.prepend(1)
            >>> next(p)
            1
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration

        """
        # extendleft reverses its argument, so reverse first to preserve
        # the caller's ordering (FIFO).
        self._cache.extendleft(reversed(items))

    def __next__(self):
        # Serve cached (peeked or prepended) items before the source.
        if self._cache:
            return self._cache.popleft()

        return next(self._it)

    def _get_slice(self, index):
        # Compute the result of self[index] for a slice object by caching
        # enough of the source iterator and slicing the cache.

        # Normalize the slice's arguments
        step = 1 if (index.step is None) else index.step
        if step > 0:
            start = 0 if (index.start is None) else index.start
            stop = maxsize if (index.stop is None) else index.stop
        elif step < 0:
            start = -1 if (index.start is None) else index.start
            stop = (-maxsize - 1) if (index.stop is None) else index.stop
        else:
            raise ValueError('slice step cannot be zero')

        # If either the start or stop index is negative, we'll need to cache
        # the rest of the iterable in order to slice from the right side.
        if (start < 0) or (stop < 0):
            self._cache.extend(self._it)
        # Otherwise we'll need to find the rightmost index and cache to that
        # point.
        else:
            n = min(max(start, stop) + 1, maxsize)
            cache_len = len(self._cache)
            if n >= cache_len:
                self._cache.extend(islice(self._it, n - cache_len))

        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        # Integer indexing: cache items up to and including the requested
        # position (or everything, for a negative index), then index the cache.
        cache_len = len(self._cache)
        if index < 0:
            self._cache.extend(self._it)
        elif index >= cache_len:
            self._cache.extend(islice(self._it, index + 1 - cache_len))

        return self._cache[index]
def consumer(func):
    """Decorator that automatically advances a PEP-342-style "reverse iterator"
    to its first yield point so you don't have to call ``next()`` on it
    manually.

    >>> @consumer
    ... def tally():
    ...     i = 0
    ...     while True:
    ...         print('Thing number %s is %s.' % (i, (yield)))
    ...         i += 1
    ...
    >>> t = tally()
    >>> t.send('red')
    Thing number 0 is red.
    >>> t.send('fish')
    Thing number 1 is fish.

    Without the decorator, you would have to call ``next(t)`` before
    ``t.send()`` could be used.
    """

    @wraps(func)
    def primed(*args, **kwargs):
        # Create the generator and advance it to the first ``yield`` so the
        # caller can immediately use ``send()``.
        coroutine = func(*args, **kwargs)
        next(coroutine)
        return coroutine

    return primed
def ilen(iterable):
    """Return the number of items in *iterable*.

    >>> ilen(x for x in range(1000) if x % 3 == 0)
    334

    Equivalent to, but faster than, a hand-written counting loop.

    This fully consumes the iterable, so handle with care.
    """
    # Pair every item with a value from an infinite counter, discard the
    # pairs as fast as possible (maxlen=0), then read how far the counter
    # advanced. The zip/deque machinery runs entirely at C speed.
    counter = count()
    deque(zip(iterable, counter), maxlen=0)
    return next(counter)
def iterate(func, start):
    """Return ``start``, ``func(start)``, ``func(func(start))``, ...

    Produces an infinite iterator. To add a stopping condition,
    use :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`:

    >>> take(10, iterate(lambda x: 2*x, 1))
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2
    >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10)))
    [10, 5, 16, 8, 4, 2, 1]

    If *func* raises ``StopIteration``, iteration ends cleanly.
    """
    value = start
    # A StopIteration raised by func terminates the generator quietly.
    with suppress(StopIteration):
        while True:
            yield value
            value = func(value)
def with_iter(context_manager):
    """Wrap an iterable in a ``with`` statement, so it closes once exhausted.

    For example, this will close the file when the iterator is exhausted::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Note that you have to actually exhaust the iterator for opened files
    to be closed.

    Any context manager which returns an iterable is a candidate for
    ``with_iter``.
    """
    # The context manager stays open exactly as long as iteration continues.
    with context_manager as iterable:
        for item in iterable:
            yield item
def one(iterable, too_short=None, too_long=None):
    """Return the first item from *iterable*, which is expected to contain
    only that item. Raise an exception if *iterable* is empty or has more
    than one item.

    :func:`one` is useful for ensuring that an iterable contains only one
    item -- for example, retrieving the result of a database query that is
    expected to return a single row.

    If *iterable* is empty, ``ValueError`` will be raised. You may specify
    a different exception with the *too_short* keyword:

    >>> one([]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: too few items in iterable (expected 1)'

    Similarly, if *iterable* contains more than one item, ``ValueError``
    will be raised. You may specify a different exception with the
    *too_long* keyword:

    >>> one(['too', 'many']) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 'too',
    'many', and perhaps more.

    Note that :func:`one` attempts to advance *iterable* twice to ensure
    there is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.
    """
    iterator = iter(iterable)

    # The outer loop body runs at most once: finding any item proves the
    # iterable is non-empty; finding a second item proves it is too long.
    for head in iterator:
        for extra in iterator:
            raise too_long or ValueError(
                f'Expected exactly one item in iterable, but got {head!r}, '
                f'{extra!r}, and perhaps more.'
            )
        return head

    raise too_short or ValueError('too few items in iterable (expected 1)')
def raise_(exception, *args):
    """Raise *exception*, constructed with positional *args*.

    Useful inside lambdas, which cannot contain a ``raise`` statement.
    """
    raise exception(*args)
def strictly_n(iterable, n, too_short=None, too_long=None):
    """Validate that *iterable* has exactly *n* items and return them if
    it does. If it has fewer than *n* items, call function *too_short*
    with the actual number of items. If it has more than *n* items, call
    function *too_long* with the number ``n + 1``.

    >>> iterable = ['a', 'b', 'c', 'd']
    >>> n = 4
    >>> list(strictly_n(iterable, n))
    ['a', 'b', 'c', 'd']

    Note that the returned iterable must be consumed in order for the check
    to be made.

    By default, *too_short* and *too_long* are functions that raise
    ``ValueError``.

    You can instead supply functions that do something else.
    *too_short* will be called with the number of items in *iterable*.
    *too_long* will be called with `n + 1`.

    >>> def too_long(item_count):
    ...     print('The boss is going to hear about this')
    >>> it = strictly_n('abcdef', 4, too_long=too_long)
    >>> list(it)
    The boss is going to hear about this
    ['a', 'b', 'c', 'd']
    """
    if too_short is None:

        def too_short(item_count):
            raise ValueError(f'Too few items in iterable (got {item_count})')

    if too_long is None:

        def too_long(item_count):
            raise ValueError(
                f'Too many items in iterable (got at least {item_count})'
            )

    iterator = iter(iterable)

    # Pass through at most n items, counting how many actually appeared.
    yielded = 0
    for item in islice(iterator, n):
        yield item
        yielded += 1

    if yielded < n:
        too_short(yielded)
        return

    # A single leftover item is enough to prove the iterable was too long.
    for _ in iterator:
        too_long(n + 1)
        return
def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

    >>> sorted(distinct_permutations([1, 0, 1]))
    [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

    >>> sorted(distinct_permutations([1, 0, 1], r=2))
    [(0, 1), (1, 0), (1, 1)]
    >>> sorted(distinct_permutations(range(3), r=2))
    [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

    >>> list(distinct_permutations([1, True, '3'])) # doctest: +SKIP
    [
        (1, True, '3'),
        (1, '3', True),
        ('3', 1, True)
    ]
    >>> list(distinct_permutations([1, 2, '3'])) # doctest: +SKIP
    [
        (1, 2, '3'),
        (1, '3', 2),
        (2, 1, '3'),
        (2, '3', 1),
        ('3', 1, 2),
        ('3', 2, 1)
    ]
    """

    # Algorithm: https://w.wiki/Qai
    # Generates permutations of a sorted list A in lexicographic order,
    # skipping duplicates for free ("next permutation" stepping).
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than i such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1] # A[i + 1:][::-1]

    # Algorithm: modified from the above
    # Yields only the first r items of each distinct permutation, keeping
    # the unused items sorted in `tail` so duplicates are still skipped.
    def _partial(A, r):
        # Split A into the first r items and the remaining items
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1] # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    try:
        items.sort()
        sortable = True
    except TypeError:
        sortable = False

        # Unsortable input: permute *indices* instead. Equal items share the
        # first index at which such an item appears, so duplicate index
        # permutations do not occur.
        indices_dict = defaultdict(list)

        for item in items:
            indices_dict[items.index(item)].append(item)

        indices = [items.index(item) for item in items]
        indices.sort()

        # Map each shared index back to its equal items, cycling through
        # them so each original object is reused in turn.
        equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

        def permuted_items(permuted_indices):
            return tuple(
                next(equivalent_items[index]) for index in permuted_indices
            )

    size = len(items)
    if r is None:
        r = size

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if 0 < r <= size:
        if sortable:
            return algorithm(items)
        else:
            return (
                permuted_items(permuted_indices)
                for permuted_indices in algorithm(indices)
            )

    # r == 0 yields the single empty permutation; r out of range yields none.
    return iter(() if r else ((),))
def derangements(iterable, r=None):
    """Yield successive derangements of the elements in *iterable*.

    A derangement is a permutation in which no element appears at its
    original index -- a permutation with no fixed points.

    Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
    The code below outputs all of the different ways to assign gift
    recipients such that nobody is assigned to himself or herself:

    >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
    ...     print(', '.join(d))
    Bob, Alice, Dave, Carol
    Bob, Carol, Dave, Alice
    Bob, Dave, Alice, Carol
    Carol, Alice, Dave, Bob
    Carol, Dave, Alice, Bob
    Carol, Dave, Bob, Alice
    Dave, Alice, Bob, Carol
    Dave, Carol, Alice, Bob
    Dave, Carol, Bob, Alice

    If *r* is given, only the *r*-length derangements are yielded.

    >>> sorted(derangements(range(3), 2))
    [(1, 0), (1, 2), (2, 0)]
    >>> sorted(derangements([0, 2, 3], 2))
    [(2, 0), (2, 3), (3, 0)]

    Elements are treated as unique based on their position, not on their
    value, so equal-valued elements may appear to be fixed points:

    >>> names = ['Alice', 'Bob', 'Bob']
    >>> list(derangements(names))
    [('Bob', 'Bob', 'Alice'), ('Bob', 'Alice', 'Bob')]

    To avoid confusion, make the inputs distinct:

    >>> deduped = [f'{name}{index}' for index, name in enumerate(names)]
    >>> list(derangements(deduped))
    [('Bob1', 'Bob2', 'Alice0'), ('Bob2', 'Alice0', 'Bob1')]

    The number of derangements of a set of size *n* is known as the
    "subfactorial of n". For n > 0, the subfactorial is:
    ``round(math.factorial(n) / math.e)``.

    References:

    * Article: https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics
    * Sizes: https://oeis.org/A000166
    """
    pool = tuple(iterable)
    # Permute positions rather than values so that equal-valued elements
    # are still treated as distinct; keep only permutations where no
    # output position holds its own original index.
    candidates = permutations(range(len(pool)), r)
    return (
        tuple(pool[index] for index in candidate)
        for candidate in candidates
        if all(index != place for place, index in enumerate(candidate))
    )
def intersperse(e, iterable, n=1):
    """Intersperse filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

    >>> list(intersperse('!', [1, 2, 3, 4, 5]))
    [1, '!', 2, '!', 3, '!', 4, '!', 5]

    >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
    [1, 2, None, 3, 4, None, 5]

    """
    if n == 0:
        raise ValueError('n must be > 0')
    if n == 1:
        # interleave yields e, x_0, e, x_1, ...; drop the leading filler.
        return islice(interleave(repeat(e), iterable), 1, None)
    # Alternate single-element filler lists with n-item chunks, drop the
    # leading filler, then flatten back into a single stream.
    separators = repeat([e])
    groups = chunked(iterable, n)
    return flatten(islice(interleave(separators, groups), 1, None))
def unique_to_each(*iterables):
    """Return the elements from each of the input iterables that aren't in
    the other input iterables.

    For example, suppose you have a set of packages, each with a set of
    dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    If you remove one package, which dependencies can also be removed?

    >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
    [['A'], ['C'], ['D']]

    If there are duplicates in one input iterable that aren't in the others
    they will be duplicated in the output. Input order is preserved::

    >>> unique_to_each("mississippi", "missouri")
    [['p', 'p'], ['o', 'u', 'r']]

    It is assumed that the elements of each iterable are hashable.

    """
    # Materialize every input so each can be traversed twice.
    groups = [list(group) for group in iterables]
    # Count how many distinct inputs each element appears in.
    membership = Counter(chain.from_iterable(set(group) for group in groups))
    singles = {element for element, times in membership.items() if times == 1}
    return [
        [element for element in group if element in singles]
        for group in groups
    ]
def windowed(seq, n, fillvalue=None, step=1):
    """Return a sliding window of width *n* over the given iterable.

    >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
    >>> list(all_windows)
    [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    When the window is larger than the iterable, *fillvalue* is used in place
    of missing values:

    >>> list(windowed([1, 2, 3], 4))
    [(1, 2, 3, None)]

    Each window will advance in increments of *step*:

    >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
    [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler items
    to the left:

    >>> iterable = [1, 2, 3, 4]
    >>> n = 3
    >>> padding = [None] * (n - 1)
    >>> list(windowed(chain(padding, iterable), 3))
    [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    if n < 0:
        raise ValueError('n must be >= 0')
    if n == 0:
        # A zero-width window is yielded exactly once.
        yield ()
        return
    if step < 1:
        raise ValueError('step must be >= 1')

    iterator = iter(seq)

    # Generate first window
    window = deque(islice(iterator, n), maxlen=n)

    # Deal with the first window not being full
    if not window:
        return
    if len(window) < n:
        yield tuple(window) + ((fillvalue,) * (n - len(window)))
        return
    yield tuple(window)

    # Create the filler for the next windows. The padding ensures
    # we have just enough elements to fill the last window.
    padding = (fillvalue,) * (n - 1 if step >= n else step - 1)
    # `filler` is a lazy side-effect pipeline: each item drawn from it
    # appends one source (or padding) element to `window`, evicting the
    # oldest element because the deque is bounded by maxlen=n.
    filler = map(window.append, chain(iterator, padding))

    # Generate the rest of the windows
    # islice's step argument advances `step` source items per iteration,
    # so each yield sees the window shifted by `step` positions.
    for _ in islice(filler, step - 1, None, step):
        yield tuple(window)
def substrings(iterable):
    """Yield all of the substrings of *iterable*.

    >>> [''.join(s) for s in substrings('more')]
    ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Note that non-string iterables can also be subdivided.

    >>> list(substrings([0, 1, 2]))
    [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    Like subslices() but returns tuples instead of lists
    and returns the shortest substrings first.
    """
    seq = tuple(iterable)
    total = len(seq)
    # Emit every window of width 1, then width 2, and so on.
    for width in range(1, total + 1):
        for start in range(total - width + 1):
            yield seq[start : start + width]
def substrings_indexes(seq, reverse=False):
    """Yield all substrings and their positions in *seq*

    The items yielded will be a tuple of the form ``(substr, i, j)``, where
    ``substr == seq[i:j]``.

    This function only works for iterables that support slicing, such as
    ``str`` objects.

    >>> for item in substrings_indexes('more'):
    ...     print(item)
    ('m', 0, 1)
    ('o', 1, 2)
    ('r', 2, 3)
    ('e', 3, 4)
    ('mo', 0, 2)
    ('or', 1, 3)
    ('re', 2, 4)
    ('mor', 0, 3)
    ('ore', 1, 4)
    ('more', 0, 4)

    Set *reverse* to ``True`` to yield the same items in the opposite order.
    """
    lengths = range(1, len(seq) + 1)
    if reverse:
        lengths = reversed(lengths)
    # For each substring length, slide a window across *seq*.
    return (
        (seq[start : start + length], start, start + length)
        for length in lengths
        for start in range(len(seq) - length + 1)
    )
class bucket:
    """Wrap *iterable* and return an object that buckets the iterable into
    child iterables based on a *key* function.

    >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
    >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
    >>> sorted(list(s))  # Get the keys
    ['a', 'b', 'c']
    >>> a_iterable = s['a']
    >>> next(a_iterable)
    'a1'
    >>> next(a_iterable)
    'a2'
    >>> list(s['b'])
    ['b1', 'b2', 'b3']

    The original iterable will be advanced and its items will be cached until
    they are used by the child iterables. This may require significant storage.

    By default, attempting to select a bucket to which no items belong will
    exhaust the iterable and cache all values.
    If you specify a *validator* function, selected buckets will instead be
    checked against it.

    >>> from itertools import count
    >>> it = count(1, 2)  # Infinite sequence of odd numbers
    >>> key = lambda x: x % 10  # Bucket by last digit
    >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
    >>> s = bucket(it, key=key, validator=validator)
    >>> 2 in s
    False
    >>> list(s[2])
    []

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        # Items pulled from _it but not yet served, grouped by key value.
        self._cache = defaultdict(deque)
        # With no validator, every key value is considered possible.
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        if not self._validator(value):
            return False

        # Peek by pulling one item from the bucket, then put it back.
        try:
            item = next(self[value])
        except StopIteration:
            return False
        else:
            self._cache[value].appendleft(item)

        return True

    def _get_values(self, value):
        """
        Helper to yield items from the parent iterator that match *value*.
        Items that don't match are stored in the local cache as they
        are encountered.
        """
        while True:
            # If we've cached some items that match the target value, emit
            # the first one and evict it from the cache.
            if self._cache[value]:
                yield self._cache[value].popleft()
            # Otherwise we need to advance the parent iterator to search for
            # a matching item, caching the rest.
            else:
                while True:
                    try:
                        item = next(self._it)
                    except StopIteration:
                        return
                    item_value = self._key(item)
                    if item_value == value:
                        yield item
                        break
                    elif self._validator(item_value):
                        self._cache[item_value].append(item)

    def __iter__(self):
        # Exhaust the source, caching every valid item, then yield the
        # distinct key values that were seen.
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        # Invalid keys get an empty iterator without touching the source.
        if not self._validator(value):
            return iter(())

        return self._get_values(value)
def spy(iterable, n=1):
    """Return a 2-tuple: a list holding the first *n* elements of
    *iterable*, and an iterator yielding all of *iterable*'s items.
    This lets you "look ahead" at upcoming items without advancing
    the iterable.

    There is one item in the list by default:

        >>> iterable = 'abcdefg'
        >>> head, iterable = spy(iterable)
        >>> head
        ['a']
        >>> list(iterable)
        ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    You may use unpacking to retrieve items instead of lists:

        >>> (head,), iterable = spy('abcdefg')
        >>> head
        'a'
        >>> (first, second), iterable = spy('abcdefg', 2)
        >>> first
        'a'
        >>> second
        'b'

    The number of items requested can be larger than the number of items in
    the iterable:

        >>> iterable = [1, 2, 3, 4, 5]
        >>> head, iterable = spy(iterable, 10)
        >>> head
        [1, 2, 3, 4, 5]
        >>> list(iterable)
        [1, 2, 3, 4, 5]

    """
    # tee gives two independent cursors; consuming the head from one
    # leaves the other positioned at the start.
    ahead, original = tee(iterable)
    return list(islice(ahead, n)), original
def interleave(*iterables):
    """Return a new iterable yielding from each iterable in turn,
    until the shortest is exhausted.

        >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
        [1, 4, 6, 2, 5, 7]

    For a version that doesn't terminate after the shortest iterable is
    exhausted, see :func:`interleave_longest`.

    """
    # zip stops at the shortest input; flattening its tuples yields the
    # items round-robin.  Built eagerly here so iterators are opened at
    # call time, lazily consumed afterwards.
    rounds = zip(*iterables)
    return chain.from_iterable(rounds)
def interleave_longest(*iterables):
    """Return a new iterable yielding from each iterable in turn,
    skipping any that are exhausted.

        >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
        [1, 4, 6, 2, 5, 7, 3, 8]

    This function produces the same output as :func:`roundrobin`, but may
    perform better for some inputs (in particular when the number of iterables
    is large).

    """
    # A fresh object() can never appear in user data, so it safely marks
    # slots belonging to already-exhausted iterables.
    fill = object()
    for bundle in zip_longest(*iterables, fillvalue=fill):
        yield from (item for item in bundle if item is not fill)
def interleave_evenly(iterables, lengths=None):
    """
    Interleave multiple iterables so that their elements are evenly distributed
    throughout the output sequence.

    >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
    >>> list(interleave_evenly(iterables))
    [1, 2, 'a', 3, 4, 'b', 5]

    >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
    >>> list(interleave_evenly(iterables))
    [1, 6, 4, 2, 7, 3, 8, 5]

    This function requires iterables of known length. Iterables without
    ``__len__()`` can be used by manually specifying lengths with *lengths*:

    >>> from itertools import combinations, repeat
    >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
    >>> lengths = [4 * (4 - 1) // 2, 3]
    >>> list(interleave_evenly(iterables, lengths=lengths))
    [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    Based on Bresenham's algorithm.

    :raises ValueError: if *lengths* is omitted and any iterable lacks
        ``__len__()``, or if the number of lengths doesn't match the
        number of iterables.
    """
    if lengths is None:
        try:
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )
    elif len(iterables) != len(lengths):
        raise ValueError('Mismatching number of iterables and lengths.')

    dims = len(lengths)

    # sort iterables by length, descending
    lengths_permute = sorted(
        range(dims), key=lambda i: lengths[i], reverse=True
    )
    lengths_desc = [lengths[i] for i in lengths_permute]
    iters_desc = [iter(iterables[i]) for i in lengths_permute]

    # the longest iterable is the primary one (Bresenham: the longest
    # distance along an axis)
    delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
    iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
    # One error accumulator (Bresenham decision variable) per secondary
    # iterable, all starting at the same fraction of the primary length.
    errors = [delta_primary // dims] * len(deltas_secondary)

    # Total number of items still to emit; loop ends exactly when all
    # inputs are exhausted, so next() below never raises StopIteration.
    to_yield = sum(lengths)
    while to_yield:
        yield next(iter_primary)
        to_yield -= 1
        # update errors for each secondary iterable
        errors = [e - delta for e, delta in zip(errors, deltas_secondary)]

        # those iterables for which the error is negative are yielded
        # ("diagonal step" in Bresenham)
        for i, e_ in enumerate(errors):
            if e_ < 0:
                yield next(iters_secondary[i])
                to_yield -= 1
                errors[i] += delta_primary
def interleave_randomly(*iterables):
    """Repeatedly select one of the input *iterables* at random and yield the next
    item from it.

        >>> iterables = [1, 2, 3], 'abc', (True, False, None)
        >>> list(interleave_randomly(*iterables))  # doctest: +SKIP
        ['a', 'b', 1, 'c', True, False, None, 2, 3]

    The relative order of the items in each input iterable will preserved. Note the
    sequences of items with this property are not equally likely to be generated.

    """
    active = [iter(it) for it in iterables]
    while active:
        choice = randrange(len(active))
        try:
            item = next(active[choice])
        except StopIteration:
            # Drop the exhausted iterator: overwrite it with the last one
            # and shrink the list (cheaper than list.pop(choice)).
            active[choice] = active[-1]
            del active[-1]
        else:
            yield item
def collapse(iterable, base_type=None, levels=None):
    """Flatten an iterable with multiple levels of nesting (e.g., a list of
    lists of tuples) into non-iterable types.

        >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
        >>> list(collapse(iterable))
        [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and
    will not be collapsed.

    To avoid collapsing other types, specify *base_type*:

        >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
        >>> list(collapse(iterable, base_type=tuple))
        ['ab', ('cd', 'ef'), 'gh', 'ij']

    Specify *levels* to stop flattening after a certain level:

        >>> iterable = [('a', ['b']), ('c', ['d'])]
        >>> list(collapse(iterable))  # Fully flattened
        ['a', 'b', 'c', 'd']
        >>> list(collapse(iterable, levels=1))  # Only one level flattened
        ['a', ['b'], 'c', ['d']]

    """
    # Depth-first traversal with an explicit stack of (level, nodes)
    # pairs, avoiding recursion (and Python's recursion limit).
    stack = deque()
    # Add our first node group, treat the iterable as a single node
    stack.appendleft((0, repeat(iterable, 1)))

    while stack:
        node_group = stack.popleft()
        level, nodes = node_group

        # Check if beyond max level
        if levels is not None and level > levels:
            yield from nodes
            continue

        for node in nodes:
            # Check if done iterating: strings/bytes and *base_type*
            # instances are treated as atoms, never descended into.
            if isinstance(node, (str, bytes)) or (
                (base_type is not None) and isinstance(node, base_type)
            ):
                yield node
            # Otherwise try to create child nodes
            else:
                try:
                    tree = iter(node)
                except TypeError:
                    # Not iterable: emit it as a leaf value.
                    yield node
                else:
                    # Save our current location: the partially-consumed
                    # `nodes` iterator is re-pushed so traversal resumes
                    # here after the child is processed.
                    stack.appendleft(node_group)
                    # Append the new child node
                    stack.appendleft((level + 1, tree))
                    # Break to process child node
                    break
def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Invoke *func* on each item in *iterable* (or on each *chunk_size* group
    of items) before yielding the item.

    `func` must be a function that takes a single argument. Its return value
    will be discarded.

    *before* and *after* are optional functions that take no arguments. They
    will be executed before iteration starts and after it ends, respectively.

    `side_effect` can be used for logging, updating progress bars, or anything
    that is not functionally "pure."

    Emitting a status message:

        >>> from more_itertools import consume
        >>> func = lambda item: print('Received {}'.format(item))
        >>> consume(side_effect(func, range(2)))
        Received 0
        Received 1

    Operating on chunks of items:

        >>> pair_sums = []
        >>> func = lambda chunk: pair_sums.append(sum(chunk))
        >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
        [0, 1, 2, 3, 4, 5]
        >>> list(pair_sums)
        [1, 5, 9]

    Writing to a file-like object:

        >>> from io import StringIO
        >>> from more_itertools import consume
        >>> f = StringIO()
        >>> func = lambda x: print(x, file=f)
        >>> before = lambda: print(u'HEADER', file=f)
        >>> after = f.close
        >>> it = [u'a', u'b', u'c']
        >>> consume(side_effect(func, it, before=before, after=after))
        >>> f.closed
        True

    """
    try:
        # Because this is a generator, *before* fires when iteration
        # begins (first next() call), not when side_effect() is called.
        if before is not None:
            before()

        if chunk_size is None:
            for item in iterable:
                func(item)
                yield item
        else:
            # Chunked mode: func sees each chunk (a list) before its
            # items are yielded individually.
            for chunk in chunked(iterable, chunk_size):
                func(chunk)
                yield from chunk
    finally:
        # The finally clause runs on exhaustion, on generator .close()
        # (including garbage collection), or if the consumer raises —
        # so *after* is called even for partially-consumed iteration.
        if after is not None:
            after()
def sliced(seq, n, strict=False):
    """Yield slices of length *n* from the sequence *seq*.

        >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
        [(1, 2, 3), (4, 5, 6)]

    By the default, the last yielded slice will have fewer than *n* elements
    if the length of *seq* is not divisible by *n*:

        >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
        [(1, 2, 3), (4, 5, 6), (7, 8)]

    If the length of *seq* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    slice is yielded.

    This function will only work for iterables that support slicing.
    For non-sliceable iterables, see :func:`chunked`.

    """
    # Slicing past the end of a sequence yields empty slices, which are
    # falsy -- takewhile(len, ...) stops there.
    slices = takewhile(len, (seq[start : start + n] for start in count(0, n)))

    if not strict:
        return slices

    def checked():
        # Re-yield each slice, verifying its length lazily so the error
        # surfaces only when the short final slice is reached.
        for piece in slices:
            if len(piece) != n:
                raise ValueError("seq is not divisible by n.")
            yield piece

    return checked()
def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Yield lists of items from *iterable*, where each list is delimited by
    an item where callable *pred* returns ``True``.

        >>> list(split_at('abcdcba', lambda x: x == 'b'))
        [['a'], ['c', 'd', 'c'], ['a']]

        >>> list(split_at(range(10), lambda n: n % 2 == 1))
        [[0], [2], [4], [6], [8], []]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

        >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
        [[0], [2], [4, 5, 6, 7, 8, 9]]

    By default, the delimiting items are not included in the output.
    To include them, set *keep_separator* to ``True``.

        >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
        [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    chunk = []
    for element in iterator:
        # Non-separators just accumulate into the current chunk.
        if not pred(element):
            chunk.append(element)
            continue
        yield chunk
        if keep_separator:
            yield [element]
        # Final allowed split: the rest of the input is one last chunk.
        if maxsplit == 1:
            yield list(iterator)
            return
        chunk = []
        maxsplit -= 1
    yield chunk
def split_before(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends just before
    an item for which callable *pred* returns ``True``:

        >>> list(split_before('OneTwo', lambda s: s.isupper()))
        [['O', 'n', 'e'], ['T', 'w', 'o']]

        >>> list(split_before(range(10), lambda n: n % 3 == 0))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

        >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    chunk = []
    for element in iterator:
        # The `chunk` check prevents emitting an empty leading list when
        # the very first item satisfies the predicate.
        if pred(element) and chunk:
            yield chunk
            if maxsplit == 1:
                # Final split: current element plus everything remaining.
                yield [element, *iterator]
                return
            chunk = []
            maxsplit -= 1
        chunk.append(element)
    if chunk:
        yield chunk
def split_after(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends with an
    item where callable *pred* returns ``True``:

        >>> list(split_after('one1two2', lambda s: s.isdigit()))
        [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

        >>> list(split_after(range(10), lambda n: n % 3 == 0))
        [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

        >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
        [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    chunk = []
    for element in iterator:
        # The separator item belongs to the chunk it terminates.
        chunk.append(element)
        if pred(element) and chunk:
            yield chunk
            if maxsplit == 1:
                # Final split: emit the remainder only if it is non-empty.
                remainder = list(iterator)
                if remainder:
                    yield remainder
                return
            chunk = []
            maxsplit -= 1
    if chunk:
        yield chunk
def split_when(iterable, pred, maxsplit=-1):
    """Split *iterable* into pieces based on the output of *pred*.
    *pred* should be a function that takes successive pairs of items and
    returns ``True`` if the iterable should be split in between them.

    For example, to find runs of increasing numbers, split the iterable when
    element ``i`` is larger than element ``i + 1``:

        >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
        [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

        >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
        ...                 lambda x, y: x > y, maxsplit=2))
        [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    try:
        previous = next(iterator)
    except StopIteration:
        # Empty input produces no chunks at all.
        return

    chunk = [previous]
    for current in iterator:
        # pred sees each adjacent (previous, current) pair exactly once.
        if pred(previous, current):
            yield chunk
            if maxsplit == 1:
                yield [current, *iterator]
                return
            chunk = []
            maxsplit -= 1

        chunk.append(current)
        previous = current

    yield chunk
def split_into(iterable, sizes):
    """Yield a list of sequential items from *iterable* of length 'n' for each
    integer 'n' in *sizes*.

        >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
        [[1], [2, 3], [4, 5, 6]]

    If the sum of *sizes* is smaller than the length of *iterable*, then the
    remaining items of *iterable* will not be returned.

        >>> list(split_into([1,2,3,4,5,6], [2,3]))
        [[1, 2], [3, 4, 5]]

    If the sum of *sizes* is larger than the length of *iterable*, fewer items
    will be returned in the iteration that overruns the *iterable* and further
    lists will be empty:

        >>> list(split_into([1,2,3,4], [1,2,3,4]))
        [[1], [2, 3], [4], []]

    When a ``None`` object is encountered in *sizes*, the returned list will
    contain items up to the end of *iterable* the same way that
    :func:`itertools.islice` does:

        >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
        [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    :func:`split_into` can be useful for grouping a series of items where the
    sizes of the groups are not uniform. An example would be where in a row
    from a table, multiple columns represent elements of the same feature
    (e.g. a point represented by x,y,z) but, the format is not the same for
    all columns.
    """
    # A single shared iterator lets successive islice calls pick up where
    # the previous group stopped (important for generator inputs).
    iterator = iter(iterable)

    for size in sizes:
        if size is None:
            # None means "everything that's left"; nothing can follow it.
            yield list(iterator)
            return
        yield list(islice(iterator, size))
def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the elements from *iterable*, followed by *fillvalue*, such that
    at least *n* items are emitted.

        >>> list(padded([1, 2, 3], '?', 5))
        [1, 2, 3, '?', '?']

    If *next_multiple* is ``True``, *fillvalue* will be emitted until the
    number of items emitted is a multiple of *n*:

        >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
        [1, 2, 3, 4, None, None]

    If *n* is ``None``, *fillvalue* will be emitted indefinitely.

    To create an *iterable* of exactly size *n*, you can truncate with
    :func:`islice`.

        >>> list(islice(padded([1, 2, 3], '?'), 5))
        [1, 2, 3, '?', '?']
        >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
        [1, 2, 3, 4, 5]

    :raises ValueError: if *n* is less than 1.
    """
    # `iterator` and `iterator_with_repeat` share the same underlying
    # iterator, so advancing one advances the other past real items.
    iterator = iter(iterable)
    iterator_with_repeat = chain(iterator, repeat(fillvalue))

    if n is None:
        return iterator_with_repeat
    elif n < 1:
        raise ValueError('n must be at least 1')
    elif next_multiple:

        def slice_generator():
            # Pull one real item to detect whether any input remains; if
            # so, emit it plus n - 1 more (padding as needed) so each
            # batch ends on a multiple of n.
            for first in iterator:
                yield (first,)
                yield islice(iterator_with_repeat, n - 1)

        # While elements exist produce slices of size n
        return chain.from_iterable(slice_generator())
    else:
        # Ensure the first batch is at least size n then iterate
        return chain(islice(iterator_with_repeat, n), iterator)
def repeat_each(iterable, n=2):
    """Repeat each element in *iterable* *n* times.

        >>> list(repeat_each('ABC', 3))
        ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    # Each element expands to its own repeat() run; chaining the runs
    # keeps everything lazy.
    return chain.from_iterable(repeat(element, n) for element in iterable)
def repeat_last(iterable, default=None):
    """After the *iterable* is exhausted, keep yielding its last element.

        >>> list(islice(repeat_last(range(3)), 5))
        [0, 1, 2, 2, 2]

    If the iterable is empty, yield *default* forever::

        >>> list(islice(repeat_last(range(0), 42), 5))
        [42, 42, 42, 42, 42]

    """
    # If the loop never runs, `last` keeps the sentinel, signalling an
    # empty input.
    sentinel = last = object()
    for last in iterable:
        yield last
    yield from repeat(default if last is sentinel else last)
def distribute(n, iterable):
    """Distribute the items from *iterable* among *n* smaller iterables.

        >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
        >>> list(group_1)
        [1, 3, 5]
        >>> list(group_2)
        [2, 4, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

        >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
        >>> [list(c) for c in children]
        [[1, 4, 7], [2, 5], [3, 6]]

    If the length of *iterable* is smaller than *n*, then the last returned
    iterables will be empty:

        >>> children = distribute(5, [1, 2, 3])
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]

    This function uses :func:`itertools.tee` and may require significant
    storage.

    If you need the order items in the smaller iterables to match the
    original iterable, see :func:`divide`.

    :raises ValueError: if *n* is less than 1.
    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Give every child its own tee'd copy; child k takes items
    # k, k + n, k + 2n, ... via a strided islice.
    copies = tee(iterable, n)
    return [
        islice(copy, start, None, n) for start, copy in enumerate(copies)
    ]
def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples whose elements are offset from *iterable*.
    The amount by which the `i`-th item in each tuple is offset is given by
    the `i`-th item in *offsets*.

        >>> list(stagger([0, 1, 2, 3]))
        [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
        >>> list(stagger(range(8), offsets=(0, 2, 4)))
        [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default, the sequence will end when the final element of a tuple is the
    last item in the iterable. To continue until the first element of a tuple
    is the last item in the iterable, set *longest* to ``True``::

        >>> list(stagger([0, 1, 2, 3], longest=True))
        [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    By default, ``None`` will be used to replace offsets beyond the end of the
    sequence. Specify *fillvalue* to use some other value.

    """
    # One tee'd copy per offset; zip_offset does the actual shifting.
    copies = tee(iterable, len(offsets))
    return zip_offset(
        *copies, offsets=offsets, longest=longest, fillvalue=fillvalue
    )
def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the input *iterables* together, but offset the `i`-th iterable
    by the `i`-th item in *offsets*.

        >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
        [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    This can be used as a lightweight alternative to SciPy or pandas to analyze
    data sets in which some series have a lead or lag relationship.

    By default, the sequence will end when the shortest iterable is exhausted.
    To continue until the longest iterable is exhausted, set *longest* to
    ``True``.

        >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
        [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    By default, ``None`` will be used to replace offsets beyond the end of the
    sequence. Specify *fillvalue* to use some other value.

    :raises ValueError: if the number of *iterables* and *offsets* differ.
    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    shifted = []
    for iterable, offset in zip(iterables, offsets):
        if offset > 0:
            # Positive offset: skip the first `offset` items.
            shifted.append(islice(iterable, offset, None))
        elif offset < 0:
            # Negative offset: pad the front with fill values.
            shifted.append(chain(repeat(fillvalue, -offset), iterable))
        else:
            shifted.append(iterable)

    if longest:
        return zip_longest(*shifted, fillvalue=fillvalue)
    return zip(*shifted)
1892def sort_together(
1893 iterables, key_list=(0,), key=None, reverse=False, strict=False
1894):
1895 """Return the input iterables sorted together, with *key_list* as the
1896 priority for sorting. All iterables are trimmed to the length of the
1897 shortest one.
1899 This can be used like the sorting function in a spreadsheet. If each
1900 iterable represents a column of data, the key list determines which
1901 columns are used for sorting.
1903 By default, all iterables are sorted using the ``0``-th iterable::
1905 >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
1906 >>> sort_together(iterables)
1907 [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
1909 Set a different key list to sort according to another iterable.
1910 Specifying multiple keys dictates how ties are broken::
1912 >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
1913 >>> sort_together(iterables, key_list=(1, 2))
1914 [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
1916 To sort by a function of the elements of the iterable, pass a *key*
1917 function. Its arguments are the elements of the iterables corresponding to
1918 the key list::
1920 >>> names = ('a', 'b', 'c')
1921 >>> lengths = (1, 2, 3)
1922 >>> widths = (5, 2, 1)
1923 >>> def area(length, width):
1924 ... return length * width
1925 >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
1926 [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
1928 Set *reverse* to ``True`` to sort in descending order.
1930 >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
1931 [(3, 2, 1), ('a', 'b', 'c')]
1933 If the *strict* keyword argument is ``True``, then
1934 ``ValueError`` will be raised if any of the iterables have
1935 different lengths.
1937 """
1938 if key is None:
1939 # if there is no key function, the key argument to sorted is an
1940 # itemgetter
1941 key_argument = itemgetter(*key_list)
1942 else:
1943 # if there is a key function, call it with the items at the offsets
1944 # specified by the key function as arguments
1945 key_list = list(key_list)
1946 if len(key_list) == 1:
1947 # if key_list contains a single item, pass the item at that offset
1948 # as the only argument to the key function
1949 key_offset = key_list[0]
1950 key_argument = lambda zipped_items: key(zipped_items[key_offset])
1951 else:
1952 # if key_list contains multiple items, use itemgetter to return a
1953 # tuple of items, which we pass as *args to the key function
1954 get_key_items = itemgetter(*key_list)
1955 key_argument = lambda zipped_items: key(
1956 *get_key_items(zipped_items)
1957 )
1959 transposed = zip(*iterables, strict=strict)
1960 reordered = sorted(transposed, key=key_argument, reverse=reverse)
1961 untransposed = zip(*reordered, strict=strict)
1962 return list(untransposed)
def unzip(iterable):
    """The inverse of :func:`zip`, this function disaggregates the elements
    of the zipped *iterable*.

    The ``i``-th iterable contains the ``i``-th element from each element
    of the zipped iterable. The first element is used to determine the
    length of the remaining elements.

        >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
        >>> letters, numbers = unzip(iterable)
        >>> list(letters)
        ['a', 'b', 'c', 'd']
        >>> list(numbers)
        [1, 2, 3, 4]

    This is similar to using ``zip(*iterable)``, but it avoids reading
    *iterable* into memory. Note, however, that this function uses
    :func:`itertools.tee` and thus may require significant storage.

    """
    # Peek at the first element to learn how many streams to produce.
    head, iterable = spy(iterable)
    if not head:
        # Empty input behaves like zip([], [], ...): nothing at all.
        return ()
    first = head[0]
    copies = tee(iterable, len(first))

    # Support "improperly zipped" inputs like [(1, 2, 3), (4, 5), (6,)]:
    # a tuple shorter than the first raises IndexError in itemgetter;
    # suppressing it simply ends that unzipped iterable at the first
    # length mismatch instead of propagating the error.
    return tuple(
        iter_suppress(map(itemgetter(position), copy), IndexError)
        for position, copy in enumerate(copies)
    )
def divide(n, iterable):
    """Divide the elements from *iterable* into *n* parts, maintaining
    order.

        >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
        >>> list(group_1)
        [1, 2, 3]
        >>> list(group_2)
        [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

        >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
        >>> [list(c) for c in children]
        [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than n, then the last returned
    iterables will be empty:

        >>> children = divide(5, [1, 2, 3])
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]

    This function will exhaust the iterable before returning.
    If order is not important, see :func:`distribute`, which does not first
    pull the iterable into memory.

    :raises ValueError: if *n* is less than 1.
    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Use the input directly if it is sliceable; otherwise materialize it.
    try:
        iterable[:0]
    except TypeError:
        seq = tuple(iterable)
    else:
        seq = iterable

    # The first r parts get q + 1 items each, the rest get q.
    q, r = divmod(len(seq), n)

    parts = []
    stop = 0
    for index in range(n):
        start = stop
        stop += q + 1 if index < r else q
        parts.append(iter(seq[start:stop]))

    return parts
def always_iterable(obj, base_type=(str, bytes)):
    """If *obj* is iterable, return an iterator over its items::

        >>> obj = (1, 2, 3)
        >>> list(always_iterable(obj))
        [1, 2, 3]

    If *obj* is not iterable, return a one-item iterable containing *obj*::

        >>> obj = 1
        >>> list(always_iterable(obj))
        [1]

    If *obj* is ``None``, return an empty iterable:

        >>> obj = None
        >>> list(always_iterable(None))
        []

    By default, binary and text strings are not considered iterable::

        >>> obj = 'foo'
        >>> list(always_iterable(obj))
        ['foo']

    If *base_type* is set, objects for which ``isinstance(obj, base_type)``
    returns ``True`` won't be considered iterable.

        >>> obj = {'a': 1}
        >>> list(always_iterable(obj))  # Iterate over the dict's keys
        ['a']
        >>> list(always_iterable(obj, base_type=dict))  # Treat dicts as a unit
        [{'a': 1}]

    Set *base_type* to ``None`` to avoid any special handling and treat objects
    Python considers iterable as iterable:

        >>> obj = 'foo'
        >>> list(always_iterable(obj, base_type=None))
        ['f', 'o', 'o']
    """
    # Guard clauses for the two special cases: None and "atomic" types.
    if obj is None:
        return iter(())
    if base_type is not None and isinstance(obj, base_type):
        return iter((obj,))

    # EAFP: let iter() decide whether obj is iterable.
    try:
        return iter(obj)
    except TypeError:
        return iter((obj,))
def adjacent(predicate, iterable, distance=1):
    """Return an iterable over `(bool, item)` tuples where the `item` is
    drawn from *iterable* and the `bool` indicates whether
    that item satisfies the *predicate* or is adjacent to an item that does.

    For example, to find whether items are adjacent to a ``3``::

        >>> list(adjacent(lambda x: x == 3, range(6)))
        [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    Set *distance* to change what counts as adjacent. For example, to find
    whether items are two places away from a ``3``:

        >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
        [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    This is useful for contextualizing the results of a search function.
    For example, a code comparison tool might want to identify lines that
    have changed, but also surrounding lines to give the viewer of the diff
    context.

    The predicate function will only be called once for each item in the
    iterable.

    See also :func:`groupby_transform`, which can be used with this function
    to group ranges of items with the same `bool` value.

    :raises ValueError: if *distance* is negative.
    """
    # Allow distance=0 mainly for testing that it reproduces results with map()
    if distance < 0:
        raise ValueError('distance must be at least 0')

    items, originals = tee(iterable)
    # Pad the flag stream on both sides so windows at the edges are full.
    pad = [False] * distance
    flags = chain(pad, map(predicate, items), pad)
    # An item is "adjacent" if any flag within `distance` of it is True.
    nearby = map(any, windowed(flags, 2 * distance + 1))
    return zip(nearby, originals)
def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
    """An extension of :func:`itertools.groupby` that can apply transformations
    to the grouped data.

    * *keyfunc* is a function computing a key value for each item in *iterable*
    * *valuefunc* is a function that transforms the individual items from
      *iterable* after grouping
    * *reducefunc* is a function that transforms each group of items

        >>> iterable = 'aAAbBBcCC'
        >>> keyfunc = lambda k: k.upper()
        >>> valuefunc = lambda v: v.lower()
        >>> reducefunc = lambda g: ''.join(g)
        >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
        [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]

    Each optional argument defaults to an identity function if not specified.

    :func:`groupby_transform` is useful when grouping elements of an iterable
    using a separate iterable as the key. To do this, :func:`zip` the iterables
    and pass a *keyfunc* that extracts the first element and a *valuefunc*
    that extracts the second element::

        >>> from operator import itemgetter
        >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
        >>> values = 'abcdefghi'
        >>> iterable = zip(keys, values)
        >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
        >>> [(k, ''.join(g)) for k, g in grouper]
        [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]

    Note that the order of items in the iterable is significant.
    Only adjacent items are grouped together, so if you don't want any
    duplicate groups, you should sort the iterable by the key function.

    """
    grouped = groupby(iterable, keyfunc)
    # Wrap the (key, group) pairs lazily; each optional transform layers
    # another generator on top of the previous one.
    if valuefunc:
        grouped = ((label, map(valuefunc, group)) for label, group in grouped)
    if reducefunc:
        grouped = ((label, reducefunc(group)) for label, group in grouped)
    return grouped
class numeric_range(Sequence):
    """An extension of the built-in ``range()`` function whose arguments can
    be any orderable numeric type.

    With only *stop* specified, *start* defaults to ``0`` and *step*
    defaults to ``1``. The output items will match the type of *stop*:

    >>> list(numeric_range(3.5))
    [0.0, 1.0, 2.0, 3.0]

    With only *start* and *stop* specified, *step* defaults to ``1``. The
    output items will match the type of *start*:

    >>> from decimal import Decimal
    >>> start = Decimal('2.1')
    >>> stop = Decimal('5.1')
    >>> list(numeric_range(start, stop))
    [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    With *start*, *stop*, and *step* specified the output items will match
    the type of ``start + step``:

    >>> from fractions import Fraction
    >>> start = Fraction(1, 2)  # Start at 1/2
    >>> stop = Fraction(5, 2)  # End at 5/2
    >>> step = Fraction(1, 2)  # Count by 1/2
    >>> list(numeric_range(start, stop, step))
    [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    If *step* is zero, ``ValueError`` is raised. Negative steps are supported:

    >>> list(numeric_range(3, -1, -1.0))
    [3.0, 2.0, 1.0, 0.0]

    Be aware of the limitations of floating-point numbers; the representation
    of the yielded numbers may be surprising.

    ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
    is a ``datetime.timedelta`` object:

    >>> import datetime
    >>> start = datetime.datetime(2019, 1, 1)
    >>> stop = datetime.datetime(2019, 1, 3)
    >>> step = datetime.timedelta(days=1)
    >>> items = iter(numeric_range(start, stop, step))
    >>> next(items)
    datetime.datetime(2019, 1, 1, 0, 0)
    >>> next(items)
    datetime.datetime(2019, 1, 2, 0, 0)

    """

    # Shared hash value for all empty ranges, matching built-in range():
    # hash(numeric_range(0)) == hash(range(0, 0)).
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        # Mirror range()'s 1/2/3-argument forms. Defaults are derived from
        # the *types* of the arguments (e.g. type(stop)(0)) so that Decimal,
        # Fraction, datetime, etc. all produce same-typed outputs.
        argc = len(args)
        if argc == 1:
            (self._stop,) = args
            self._start = type(self._stop)(0)
            # type(stop - start) handles cases like datetime, where the
            # difference type (timedelta) differs from the endpoint type.
            self._step = type(self._stop - self._start)(1)
        elif argc == 2:
            self._start, self._stop = args
            self._step = type(self._stop - self._start)(1)
        elif argc == 3:
            self._start, self._stop, self._step = args
        elif argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        else:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        # Zero of the step's type; used for all sign/modulo comparisons so
        # exotic numeric types compare against a same-typed zero.
        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        # True when counting upward; controls comparisons throughout.
        self._growing = self._step > self._zero

    def __bool__(self):
        # Non-empty iff start has not already passed stop in the step
        # direction.
        if self._growing:
            return self._start < self._stop
        else:
            return self._start > self._stop

    def __contains__(self, elem):
        # Membership requires elem to lie within the half-open interval and
        # to be an exact whole number of steps from start.
        if self._growing:
            if self._start <= elem < self._stop:
                return (elem - self._start) % self._step == self._zero
        else:
            if self._start >= elem > self._stop:
                # Negate the step so the modulo operand is positive.
                return (self._start - elem) % (-self._step) == self._zero

        return False

    def __eq__(self, other):
        # Two ranges are equal if they yield the same items: all empty
        # ranges are equal; otherwise compare start, step, and last item
        # (stop itself may differ, e.g. range(0, 5, 2) == range(0, 6, 2)).
        if isinstance(other, numeric_range):
            empty_self = not bool(self)
            empty_other = not bool(other)
            if empty_self or empty_other:
                return empty_self and empty_other  # True if both empty
            else:
                return (
                    self._start == other._start
                    and self._step == other._step
                    and self._get_by_index(-1) == other._get_by_index(-1)
                )
        else:
            return False

    def __getitem__(self, key):
        # Integer indexing delegates to _get_by_index; slicing produces a
        # new numeric_range with clamped endpoints (like range slicing).
        if isinstance(key, int):
            return self._get_by_index(key)
        elif isinstance(key, slice):
            step = self._step if key.step is None else key.step * self._step

            if key.start is None or key.start <= -self._len:
                start = self._start
            elif key.start >= self._len:
                start = self._stop
            else:  # -self._len < key.start < self._len
                start = self._get_by_index(key.start)

            if key.stop is None or key.stop >= self._len:
                stop = self._stop
            elif key.stop <= -self._len:
                stop = self._start
            else:  # -self._len < key.stop < self._len
                stop = self._get_by_index(key.stop)

            return numeric_range(start, stop, step)
        else:
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )

    def __hash__(self):
        # Hash on (start, last item, step) so that equal ranges hash
        # equally, consistent with __eq__ above.
        if self:
            return hash((self._start, self._get_by_index(-1), self._step))
        else:
            return self._EMPTY_HASH

    def __iter__(self):
        # Generate start + n*step and cut off with takewhile; gt/lt are
        # partially applied with stop as the first argument, so
        # gt(stop, value) means value < stop.
        values = (self._start + (n * self._step) for n in count())
        if self._growing:
            return takewhile(partial(gt, self._stop), values)
        else:
            return takewhile(partial(lt, self._stop), values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        # Normalize to a growing range, then use euclidean division;
        # a nonzero remainder means one extra (partial-step) item.
        # Cached because len() is used repeatedly by slicing/indexing.
        if self._growing:
            start = self._start
            stop = self._stop
            step = self._step
        else:
            start = self._stop
            stop = self._start
            step = -self._step
        distance = stop - start
        if distance <= self._zero:
            return 0
        else:  # distance > 0 and step > 0: regular euclidean division
            q, r = divmod(distance, step)
            return int(q) + int(r != self._zero)

    def __reduce__(self):
        # Support pickling by reconstructing from the three arguments.
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        if self._step == 1:
            return f"numeric_range({self._start!r}, {self._stop!r})"
        return (
            f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
        )

    def __reversed__(self):
        # Walk backward from the last item; the new stop is one step
        # before start so that start itself is included.
        return iter(
            numeric_range(
                self._get_by_index(-1), self._start - self._step, -self._step
            )
        )

    def count(self, value):
        # Items are unique, so the count is 0 or 1.
        return int(value in self)

    def index(self, value):
        # Same interval/divisibility test as __contains__, but returns the
        # step count instead of a boolean.
        if self._growing:
            if self._start <= value < self._stop:
                q, r = divmod(value - self._start, self._step)
                if r == self._zero:
                    return int(q)
        else:
            if self._start >= value > self._stop:
                q, r = divmod(self._start - value, -self._step)
                if r == self._zero:
                    return int(q)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        # Support negative indices like built-in sequences.
        if i < 0:
            i += self._len
        if i < 0 or i >= self._len:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step
def count_cycle(iterable, n=None):
    """Cycle through the items from *iterable* up to *n* times, yielding
    the number of completed cycles along with each item. If *n* is omitted the
    process repeats indefinitely.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    :param iterable: the items to cycle through (materialized when *n* is
        ``None``, so it may be a one-shot iterator)
    :param n: number of cycles, or ``None`` to cycle forever
    :returns: an iterator of ``(cycle_number, item)`` tuples
    """
    # Finite case: product() pairs each cycle number with each item and
    # handles materializing the iterable itself.
    if n is not None:
        return product(range(n), iterable)

    # Infinite case: materialize once so we can replay the items.
    seq = tuple(iterable)
    if not seq:
        # An empty sequence would otherwise loop forever yielding nothing.
        return iter(())
    # n is None here (the finite case returned above), so count forever.
    # Equivalent to zip(repeat_each(count(), len(seq)), cycle(seq)) but
    # without the dead `else range(n)` branch the old code carried.
    return ((i, item) for i in count() for item in seq)
def mark_ends(iterable):
    """Yield 3-tuples of the form ``(is_first, is_last, item)``.

    >>> list(mark_ends('ABC'))
    [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    Use this when looping over an iterable to take special action on its first
    and/or last items:

    >>> iterable = ['Header', 100, 200, 'Footer']
    >>> total = 0
    >>> for is_first, is_last, item in mark_ends(iterable):
    ...     if is_first:
    ...         continue  # Skip the header
    ...     if is_last:
    ...         continue  # Skip the footer
    ...     total += item
    >>> print(total)
    300
    """
    iterator = iter(iterable)

    # Prime with the first item; an empty iterable yields nothing.
    try:
        previous = next(iterator)
    except StopIteration:
        return

    # Hold one item back at all times: when the loop body runs we know
    # `previous` is not the last item, so is_last is False for it.
    is_first = True
    for current in iterator:
        yield is_first, False, previous
        previous = current
        is_first = False

    # The held-back item is the last one (and also the first, if singleton).
    yield is_first, True, previous
def locate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
    [1, 2, 4]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item.

    >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
    [1, 3]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(locate(iterable, pred=pred, window_size=3))
    [1, 5, 9]

    Use with :func:`seekable` to find indexes and then retrieve the associated
    items:

    >>> from itertools import count
    >>> from more_itertools import seekable
    >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
    >>> it = seekable(source)
    >>> pred = lambda x: x > 100
    >>> indexes = locate(it, pred=pred)
    >>> i = next(indexes)
    >>> it.seek(i)
    >>> next(it)
    106

    """
    # Windowed search: feed each window of items to pred, padded at the
    # end with the sentinel so every index gets a full-size window.
    if window_size is not None:
        if window_size < 1:
            raise ValueError('window size must be at least 1')
        windows = windowed(iterable, window_size, fillvalue=_marker)
        return compress(count(), starmap(pred, windows))

    # Simple case: keep the indexes whose items satisfy pred.
    return compress(count(), map(pred, iterable))
def longest_common_prefix(iterables):
    """Yield elements of the longest common prefix among given *iterables*.

    >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
    'ab'

    """
    # zip() truncates at the shortest iterable; keep taking positions while
    # every iterable agrees, then emit the (shared) first element of each.
    matching_columns = takewhile(all_equal, zip(*iterables))
    return map(itemgetter(0), matching_columns)
def lstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the beginning
    for which *pred* returns ``True``.

    For example, to remove a set of items from the start of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(lstrip(iterable, pred))
    [1, 2, None, 3, False, None]

    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.

    """
    # dropwhile skips leading items while pred holds, then yields the rest.
    return dropwhile(pred, iterable)
def rstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the end
    for which *pred* returns ``True``.

    For example, to remove a set of items from the end of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(rstrip(iterable, pred))
    [None, False, None, 1, 2, None, 3]

    This function is analogous to :func:`str.rstrip`.

    """
    # Buffer matching items; they are only emitted if a non-matching item
    # follows them. A matching run at the very end is thus discarded.
    buffered = []
    for item in iterable:
        if pred(item):
            buffered.append(item)
        else:
            if buffered:
                yield from buffered
                buffered = []
            yield item
def strip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the
    beginning and end for which *pred* returns ``True``.

    For example, to remove a set of items from both ends of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(strip(iterable, pred))
    [1, 2, None, 3]

    This function is analogous to :func:`str.strip`.

    """
    # Strip the left edge first, then strip the right edge of the result.
    left_stripped = lstrip(iterable, pred)
    return rstrip(left_stripped, pred)
class islice_extended:
    """An extension of :func:`itertools.islice` that supports negative values
    for *stop*, *start*, and *step*.

    >>> iterator = iter('abcdefgh')
    >>> list(islice_extended(iterator, -4, -1))
    ['e', 'f', 'g']

    Slices with negative values require some caching of *iterable*, but this
    function takes care to minimize the amount of memory required.

    For example, you can use a negative step with an infinite iterator:

    >>> from itertools import count
    >>> list(islice_extended(count(), 110, 99, -2))
    [110, 108, 106, 104, 102, 100]

    You can also use slice notation directly:

    >>> iterator = map(str, count())
    >>> it = islice_extended(iterator)[10:20:2]
    >>> list(it)
    ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        # With slice arguments, wrap the source in the slicing helper;
        # otherwise just pass items through untouched.
        source = iter(iterable)
        self._iterator = (
            _islice_helper(source, slice(*args)) if args else source
        )

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        # Only slice keys are meaningful; integer indexing would require
        # consuming the iterator, so it is rejected.
        if not isinstance(key, slice):
            raise TypeError(
                'islice_extended.__getitem__ argument must be a slice'
            )
        return islice_extended(_islice_helper(self._iterator, key))
def _islice_helper(it, s):
    # Implement extended slicing (negative start/stop/step) over iterator
    # *it* for slice object *s*.  Negative values require caching a bounded
    # window of items in a deque; positive-only slices fall through to
    # plain itertools.islice.
    start = s.start
    stop = s.stop
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    step = s.step or 1

    if step > 0:
        # Forward iteration.
        start = 0 if (start is None) else start

        if start < 0:
            # Negative start: we need the last -start items, so run the
            # iterator to exhaustion keeping only that many.
            # Consume all but the last -start items
            cache = deque(enumerate(it, 1), maxlen=-start)
            # enumerate(…, 1) lets the last cached pair reveal the total
            # length of the source.
            len_iter = cache[-1][0] if cache else 0

            # Adjust start to be positive
            i = max(len_iter + start, 0)

            # Adjust stop to be positive
            if stop is None:
                j = len_iter
            elif stop >= 0:
                j = min(stop, len_iter)
            else:
                j = max(len_iter + stop, 0)

            # Slice the cache
            n = j - i
            if n <= 0:
                return

            for index in range(n):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()[1]
                else:
                    # just pop and discard the item
                    cache.popleft()
        elif (stop is not None) and (stop < 0):
            # Non-negative start but negative stop: we must lag -stop items
            # behind the iterator so we can stop that far from the end.
            # Advance to the start position
            next(islice(it, start, start), None)

            # When stop is negative, we have to carry -stop items while
            # iterating
            cache = deque(islice(it, -stop), maxlen=-stop)

            for index, item in enumerate(it):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
                cache.append(item)
        else:
            # When both start and stop are positive we have the normal case
            yield from islice(it, start, stop, step)
    else:
        # Backward iteration (step < 0); default start is the last item.
        start = -1 if (start is None) else start

        if (stop is not None) and (stop < 0):
            # Negative stop: keep the trailing -stop - 1 items, then slice
            # the cache in reverse.
            # Consume all but the last items
            n = -stop - 1
            cache = deque(enumerate(it, 1), maxlen=n)
            len_iter = cache[-1][0] if cache else 0

            # If start and stop are both negative they are comparable and
            # we can just slice. Otherwise we can adjust start to be negative
            # and then slice.
            if start < 0:
                i, j = start, stop
            else:
                i, j = min(start - len_iter, -1), None

            for index, item in list(cache)[i:j:step]:
                yield item
        else:
            # Advance to the stop position
            if stop is not None:
                m = stop + 1
                next(islice(it, m, m), None)

            # stop is positive, so if start is negative they are not comparable
            # and we need the rest of the items.
            if start < 0:
                i = start
                n = None
            # stop is None and start is positive, so we just need items up to
            # the start index.
            elif stop is None:
                i = None
                n = start + 1
            # Both stop and start are positive, so they are comparable.
            else:
                i = None
                n = start - stop
                if n <= 0:
                    return

            # Materialize the needed window and emit it reversed via the
            # negative step.
            cache = list(islice(it, n))

            yield from cache[i::step]
def always_reversible(iterable):
    """An extension of :func:`reversed` that supports all iterables, not
    just those which implement the ``Reversible`` or ``Sequence`` protocols.

    >>> print(*always_reversible(x for x in range(3)))
    2 1 0

    If the iterable is already reversible, this function returns the
    result of :func:`reversed()`. If the iterable is not reversible,
    this function will cache the remaining items in the iterable and
    yield them in reverse order, which may require significant storage.
    """
    # EAFP: reversed() raises TypeError for non-reversible iterables, in
    # which case we materialize the items and reverse the resulting list.
    try:
        result = reversed(iterable)
    except TypeError:
        result = reversed(list(iterable))
    return result
def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items using :func:`itertools.groupby`.
    The *ordering* function determines whether two items are adjacent by
    returning their position.

    By default, the ordering function is the identity function. This is
    suitable for finding runs of numbers:

    >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
    >>> for group in consecutive_groups(iterable):
    ...     print(list(group))
    [1]
    [10, 11, 12]
    [20]
    [30, 31, 32, 33]
    [40]

    To find runs of adjacent letters, apply :func:`ord` function
    to convert letters to ordinals.

    >>> iterable = 'abcdfgilmnop'
    >>> ordering = ord
    >>> for group in consecutive_groups(iterable, ordering):
    ...     print(list(group))
    ['a', 'b', 'c', 'd']
    ['f', 'g']
    ['i']
    ['l', 'm', 'n', 'o', 'p']

    Each group of consecutive items is an iterator that shares it source with
    *iterable*. When an an output group is advanced, the previous group is
    no longer available unless its elements are copied (e.g., into a ``list``).

    >>> iterable = [1, 2, 11, 12, 21, 22]
    >>> saved_groups = []
    >>> for group in consecutive_groups(iterable):
    ...     saved_groups.append(list(group))  # Copy group elements
    >>> saved_groups
    [[1, 2], [11, 12], [21, 22]]

    """
    # Within a run of consecutive items, (index - position) is constant,
    # so grouping by that difference isolates each run.
    if ordering is None:
        delta = lambda pair: pair[0] - pair[1]
    else:
        delta = lambda pair: pair[0] - ordering(pair[1])

    for _, run in groupby(enumerate(iterable), key=delta):
        # Strip the enumeration index back off each item.
        yield map(itemgetter(1), run)
def difference(iterable, func=sub, *, initial=None):
    """This function is the inverse of :func:`itertools.accumulate`. By default
    it will compute the first difference of *iterable* using
    :func:`operator.sub`:

    >>> from itertools import accumulate
    >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
    >>> list(difference(iterable))
    [0, 1, 2, 3, 4]

    *func* defaults to :func:`operator.sub`, but other functions can be
    specified. They will be applied as follows::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, to do progressive division:

    >>> iterable = [1, 2, 6, 24, 120]
    >>> func = lambda x, y: x // y
    >>> list(difference(iterable, func))
    [1, 2, 3, 4, 5]

    If the *initial* keyword is set, the first element will be skipped when
    computing successive differences.

    >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
    >>> list(difference(it, initial=10))
    [1, 2, 3]

    """
    # Duplicate the input; `trailing` stays one item behind `leading` so
    # that func sees consecutive pairs.
    trailing, leading = tee(iterable)
    try:
        head = [next(leading)]
    except StopIteration:
        # Empty input has no differences at all.
        return iter([])

    # When initial was supplied, the first item is itself a difference
    # from that initial value, so it is not re-emitted verbatim.
    if initial is not None:
        head = []

    return chain(head, map(func, leading, trailing))
class SequenceView(Sequence):
    """Return a read-only view of the sequence object *target*.

    :class:`SequenceView` objects are analogous to Python's built-in
    "dictionary view" types. They provide a dynamic view of a sequence's items,
    meaning that when the sequence updates, so does the view.

    >>> seq = ['0', '1', '2']
    >>> view = SequenceView(seq)
    >>> view
    SequenceView(['0', '1', '2'])
    >>> seq.append('3')
    >>> view
    SequenceView(['0', '1', '2', '3'])

    Sequence views support indexing, slicing, and length queries. They act
    like the underlying sequence, except they don't allow assignment:

    >>> view[1]
    '1'
    >>> view[1:-1]
    ['1', '2']
    >>> len(view)
    4

    Sequence views are useful as an alternative to copying, as they don't
    require (much) extra storage.

    """

    def __init__(self, target):
        # Only true sequences are viewable; reject iterators and the like.
        if not isinstance(target, Sequence):
            raise TypeError
        self._target = target

    def __len__(self):
        return len(self._target)

    def __getitem__(self, index):
        # Delegate both integer indexing and slicing to the target.
        return self._target[index]

    def __repr__(self):
        return f'{type(self).__name__}({self._target!r})'
class seekable:
    """Wrap an iterator to allow for seeking backward and forward. This
    progressively caches the items in the source iterable so they can be
    re-visited.

    Call :meth:`seek` with an index to seek to that position in the source
    iterable.

    To "reset" an iterator, seek to ``0``:

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.seek(0)
    >>> next(it), next(it), next(it)
    ('0', '1', '2')

    You can also seek forward:

    >>> it = seekable((str(n) for n in range(20)))
    >>> it.seek(10)
    >>> next(it)
    '10'
    >>> it.seek(20)  # Seeking past the end of the source isn't a problem
    >>> list(it)
    []
    >>> it.seek(0)  # Resetting works even after hitting the end
    >>> next(it)
    '0'

    Call :meth:`relative_seek` to seek relative to the source iterator's
    current position.

    >>> it = seekable((str(n) for n in range(20)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.relative_seek(2)
    >>> next(it)
    '5'
    >>> it.relative_seek(-3)  # Source is at '6', we move back to '3'
    >>> next(it)
    '3'
    >>> it.relative_seek(-3)  # Source is at '4', we move back to '1'
    >>> next(it)
    '1'

    Call :meth:`peek` to look ahead one item without advancing the iterator:

    >>> it = seekable('1234')
    >>> it.peek()
    '1'
    >>> list(it)
    ['1', '2', '3', '4']
    >>> it.peek(default='empty')
    'empty'

    Before the iterator is at its end, calling :func:`bool` on it will return
    ``True``. After it will return ``False``:

    >>> it = seekable('5678')
    >>> bool(it)
    True
    >>> list(it)
    ['5', '6', '7', '8']
    >>> bool(it)
    False

    You may view the contents of the cache with the :meth:`elements` method.
    That returns a :class:`SequenceView`, a view that updates automatically:

    >>> it = seekable((str(n) for n in range(10)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> elements = it.elements()
    >>> elements
    SequenceView(['0', '1', '2'])
    >>> next(it)
    '3'
    >>> elements
    SequenceView(['0', '1', '2', '3'])

    By default, the cache grows as the source iterable progresses, so beware of
    wrapping very large or infinite iterables. Supply *maxlen* to limit the
    size of the cache (this of course limits how far back you can seek).

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()), maxlen=2)
    >>> next(it), next(it), next(it), next(it)
    ('0', '1', '2', '3')
    >>> list(it.elements())
    ['2', '3']
    >>> it.seek(0)
    >>> next(it), next(it), next(it), next(it)
    ('2', '3', '4', '5')
    >>> next(it)
    '6'

    """

    def __init__(self, iterable, maxlen=None):
        # _source: the wrapped iterator (consumed at most once).
        # _cache: every item seen so far (bounded by maxlen if given).
        # _index: position in the cache to replay from, or None when
        # reading directly from the source.
        self._source = iter(iterable)
        if maxlen is None:
            self._cache = []
        else:
            # deque with maxlen discards the oldest items automatically.
            self._cache = deque([], maxlen)
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        # Replay mode: serve from the cache until it runs out, then fall
        # back to the live source.
        if self._index is not None:
            try:
                item = self._cache[self._index]
            except IndexError:
                # Cache exhausted; resume pulling from the source below.
                self._index = None
            else:
                self._index += 1
                return item

        # Live mode: pull from the source and remember the item.
        item = next(self._source)
        self._cache.append(item)
        return item

    def __bool__(self):
        # Truthy while at least one more item is available; peek() caches
        # the probed item so nothing is lost.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        # Advance one item, then step the replay index back so the item is
        # produced again on the next __next__ call.  The _marker sentinel
        # distinguishes "no default given" from default=None.
        try:
            peeked = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return peeked

    def elements(self):
        # Read-only live view of the cache contents.
        return SequenceView(self._cache)

    def seek(self, index):
        # Jump to an absolute cache position; if it is beyond what has been
        # cached, consume the source forward to reach it (or its end).
        self._index = index
        remainder = index - len(self._cache)
        if remainder > 0:
            consume(self, remainder)

    def relative_seek(self, count):
        # Seek relative to the current replay position, clamped at 0.
        # NOTE: the parameter shadows itertools.count within this method.
        if self._index is None:
            self._index = len(self._cache)

        self.seek(max(self._index + count, 0))
class run_length:
    """
    :func:`run_length.encode` compresses an iterable with run-length encoding.
    It yields groups of repeated items with the count of how many times they
    were repeated:

    >>> uncompressed = 'abbcccdddd'
    >>> list(run_length.encode(uncompressed))
    [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` decompresses an iterable that was previously
    compressed with run-length encoding. It yields the items of the
    decompressed iterable:

    >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> list(run_length.decode(compressed))
    ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        # groupby splits the input into runs of equal items; ilen counts
        # each run without materializing it.
        return ((value, ilen(run)) for value, run in groupby(iterable))

    @staticmethod
    def decode(iterable):
        # Expand each (value, times) pair back into `times` copies.
        return chain.from_iterable(
            repeat(value, times) for value, times in iterable
        )
def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly ``n`` items in the iterable are ``True``
    according to the *predicate* function.

    >>> exactly_n([True, True, False], 2)
    True
    >>> exactly_n([True, True, False], 1)
    False
    >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
    True

    The iterable will be advanced until ``n + 1`` truthy items are encountered,
    so avoid calling it on infinite iterables.

    """
    matches = filter(predicate, iterable)

    # A negative count can never be satisfied.
    if n < 0:
        return False

    # Use a fresh object as an end-of-iteration sentinel, since the
    # matching items themselves may be any value.
    sentinel = object()

    if n == 0:
        # Exactly zero matches means the filter yields nothing at all.
        return next(matches, sentinel) is sentinel

    # Skip to the n-th match (if any), then check that it exists and that
    # no (n + 1)-th match follows it.
    tail = islice(matches, n - 1, None)
    if next(tail, sentinel) is sentinel:
        return False  # fewer than n matches
    return next(tail, sentinel) is sentinel  # exactly n iff nothing follows
def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

    >>> list(circular_shifts(range(4)))
    [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    Set *steps* to the number of places to rotate to the left
    (or to the right if negative). Defaults to 1.

    >>> list(circular_shifts(range(4), 2))
    [(0, 1, 2, 3), (2, 3, 0, 1)]

    >>> list(circular_shifts(range(4), -1))
    [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    """
    # Build the buffer before validating steps, so a one-shot iterator is
    # consumed the same way regardless of the steps value.
    shifted = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Pre-rotate one step backward so the first in-loop rotation restores
    # the original ordering as the first yielded shift.
    shifted.rotate(steps)
    step = -steps

    # The sequence of shifts repeats after size / gcd(size, step) rotations.
    size = len(shifted)
    total = size // math.gcd(size, step)

    for _ in range(total):
        shifted.rotate(step)
        yield tuple(shifted)
def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function that
    modifies an iterable. *result_index* is the position in that function's
    signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition. This can augment what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

    >>> from more_itertools import chunked
    >>> chunker = make_decorator(chunked, result_index=0)
    >>> @chunker(3)
    ... def iter_range(n):
    ...     return iter(range(n))
    ...
    >>> list(iter_range(9))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

    >>> truth_serum = make_decorator(filter, result_index=1)
    >>> @truth_serum(bool)
    ... def boolean_test():
    ...     return [0, 1, '', ' ', False, True]
    ...
    >>> list(boolean_test())
    [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators:

    >>> from more_itertools import peekable
    >>> peekable_function = make_decorator(peekable)
    >>> @peekable_function()
    ... def str_range(*args):
    ...     return (str(x) for x in range(*args))
    ...
    >>> it = str_range(1, 20, 2)
    >>> next(it), next(it), next(it)
    ('1', '3', '5')
    >>> it.peek()
    '7'
    >>> next(it)
    '7'

    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                # Call the decorated function, then splice its result into
                # wrapping_func's arguments at result_index.
                result = f(*args, **kwargs)
                call_args = (
                    *wrapping_args[:result_index],
                    result,
                    *wrapping_args[result_index:],
                )
                return wrapping_func(*call_args, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Return a dictionary that maps the items in *iterable* to categories
    defined by *keyfunc*, transforms them with *valuefunc*, and
    then summarizes them by category with *reducefunc*.

    *valuefunc* defaults to the identity function if it is unspecified.
    If *reducefunc* is unspecified, no summarization takes place:

    >>> keyfunc = lambda x: x.upper()
    >>> result = map_reduce('abbccc', keyfunc)
    >>> sorted(result.items())
    [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Specifying *valuefunc* transforms the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> result = map_reduce('abbccc', keyfunc, valuefunc)
    >>> sorted(result.items())
    [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Specifying *reducefunc* summarizes the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> reducefunc = sum
    >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
    >>> sorted(result.items())
    [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the map/reduce
    procedure:

    >>> all_items = range(30)
    >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
    >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
    >>> categories = map_reduce(items, keyfunc=keyfunc)
    >>> sorted(categories.items())
    [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
    >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
    >>> sorted(summaries.items())
    [(0, 90), (1, 75)]

    Note that all items in the iterable are gathered into a list before the
    summarization step, which may require significant storage.

    The returned object is a :obj:`collections.defaultdict` with the
    ``default_factory`` set to ``None``, such that it behaves like a normal
    dictionary.

    """
    groups = defaultdict(list)

    # Two separate loops so the valuefunc-is-None test happens once rather
    # than per item.
    if valuefunc is None:
        for element in iterable:
            groups[keyfunc(element)].append(element)
    else:
        for element in iterable:
            groups[keyfunc(element)].append(valuefunc(element))

    # Summarize each category in place; keys are unchanged, so iterating
    # the dict while assigning is safe.
    if reducefunc is not None:
        for category in groups:
            groups[category] = reducefunc(groups[category])

    # Disable defaulting so missing keys raise KeyError like a plain dict.
    groups.default_factory = None
    return groups
def rlocate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``, starting from the right and moving left.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
    [4, 2, 1]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item:

    >>> iterator = iter('abcb')
    >>> pred = lambda x: x == 'b'
    >>> list(rlocate(iterator, pred))
    [3, 1]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(rlocate(iterable, pred=pred, window_size=3))
    [9, 5, 1]

    Beware, this function won't return anything for infinite iterables.
    If *iterable* is reversible, ``rlocate`` will reverse it and search from
    the right. Otherwise, it will search from the left and return the results
    in reverse order.

    See :func:`locate` to for other example applications.

    """
    if window_size is None:
        # Fast path for sized, reversible inputs: search the reversed
        # sequence and translate indexes back. len() or reversed() raising
        # TypeError means the input is a plain iterator; fall through.
        with suppress(TypeError):
            len_iter = len(iterable)
            return (len_iter - i - 1 for i in locate(reversed(iterable), pred))

    # General path: locate from the left, then reverse the materialized
    # result (this consumes the whole iterable).
    return reversed(list(locate(iterable, pred, window_size)))
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items from *iterable*, replacing the items for which *pred*
    returns ``True`` with the items from the iterable *substitutes*.

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
    >>> pred = lambda x: x == 0
    >>> substitutes = (2, 3)
    >>> list(replace(iterable, pred, substitutes))
    [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    If *count* is given, the number of replacements will be limited:

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
    >>> pred = lambda x: x == 0
    >>> substitutes = [None]
    >>> list(replace(iterable, pred, substitutes, count=2))
    [1, 1, None, 1, 1, None, 1, 1, 0]

    Use *window_size* to control the number of items passed as arguments to
    *pred*. This allows for locating and replacing subsequences.

    >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
    >>> window_size = 3
    >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
    >>> substitutes = [3, 4]  # Splice in these items
    >>> list(replace(iterable, pred, substitutes, window_size=window_size))
    [3, 4, 5, 3, 4, 5]

    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # Materialize the substitutes, since they may be yielded repeatedly.
    subs = tuple(substitutes)

    # Pad with sentinels so the number of windows equals the number of
    # input items; the first slot of each window is one input item.
    padded = chain(iterable, repeat(_marker, window_size - 1))
    windows = windowed(padded, window_size)

    num_replaced = 0
    for window in windows:
        # A matching window (while under the replacement limit) emits the
        # substitutes and skips the overlapping windows that follow. For
        # example, with window size 2 and windows (0, 1), (1, 2), (2, 3)...,
        # a match on (0, 1) must also discard (1, 2).
        if pred(*window) and (count is None or num_replaced < count):
            num_replaced += 1
            yield from subs
            consume(windows, window_size - 1)
        # Otherwise pass through the window's leading item, unless it is
        # padding.
        elif window and window[0] is not _marker:
            yield window[0]
def partitions(iterable):
    """Yield all possible order-preserving partitions of *iterable*.

    >>> iterable = 'abc'
    >>> for part in partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['a', 'b', 'c']

    This is unrelated to :func:`partition`.

    """
    seq = list(iterable)
    size = len(seq)
    # Every subset of the interior cut points (1..size-1) defines one
    # partition; slice between consecutive boundaries.
    for cuts in powerset(range(1, size)):
        bounds = (0,) + cuts + (size,)
        yield [seq[a:b] for a, b in zip(bounds, bounds[1:])]
def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """
    Yield the set partitions of *iterable* into *k* parts. Set partitions are
    not order-preserving.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, 2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']

    If *k* is not given, every set partition is generated.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    if *min_size* and/or *max_size* are given, the minimum and/or maximum size
    per block in partition is set.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, min_size=2):
    ...     print([''.join(p) for p in part])
    ['abc']
    >>> for part in set_partitions(iterable, max_size=2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    """
    pool = list(iterable)
    n = len(pool)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        elif k > n:
            return

    min_size = 0 if min_size is None else min_size
    max_size = n if max_size is None else max_size
    if min_size > max_size:
        return

    def helper(items, parts):
        # Recursively build the partitions of *items* into exactly *parts*
        # blocks: the head element either forms its own block, or joins one
        # of the blocks of a smaller partition of the tail.
        if parts == 1:
            yield [items]
        elif len(items) == parts:
            yield [[single] for single in items]
        else:
            head, *tail = items
            for smaller in helper(tail, parts - 1):
                yield [[head], *smaller]
            for smaller in helper(tail, parts):
                for i, block in enumerate(smaller):
                    yield smaller[:i] + [[head] + block] + smaller[i + 1 :]

    def size_ok(partition):
        # Enforce the per-block size bounds.
        return all(min_size <= len(block) <= max_size for block in partition)

    group_counts = range(1, n + 1) if k is None else [k]
    for parts in group_counts:
        yield from filter(size_ok, helper(pool, parts))
class time_limited:
    """
    Yield items from *iterable* until *limit_seconds* have passed.
    If the time limit expires before all items have been yielded, the
    ``timed_out`` parameter will be set to ``True``.

    >>> from time import sleep
    >>> def generator():
    ...     yield 1
    ...     yield 2
    ...     sleep(0.2)
    ...     yield 3
    >>> iterable = time_limited(0.1, generator())
    >>> list(iterable)
    [1, 2]
    >>> iterable.timed_out
    True

    Note that the time is checked before each item is yielded, and iteration
    stops if the time elapsed is greater than *limit_seconds*. If your time
    limit is 1 second, but it takes 2 seconds to generate the first item from
    the iterable, the function will run for 2 seconds and not yield anything.
    As a special case, when *limit_seconds* is zero, the iterator never
    returns anything.

    """

    def __init__(self, limit_seconds, iterable):
        # Zero is explicitly allowed (documented special case above), so the
        # check and the message both say "non-negative". The previous message
        # ("must be positive") contradicted the ``< 0`` check.
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be non-negative')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        # The clock starts at construction, not at first iteration.
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        # Special case: a zero limit never yields anything.
        if self.limit_seconds == 0:
            self.timed_out = True
            raise StopIteration
        item = next(self._iterator)
        # The elapsed-time check happens after the item is produced, so a
        # slow source may overrun the limit by the cost of one item.
        if monotonic() - self._start_time > self.limit_seconds:
            self.timed_out = True
            raise StopIteration

        return item
def only(iterable, default=None, too_long=None):
    """If *iterable* has only one item, return it.
    If it has zero items, return *default*.
    If it has more than one item, raise the exception given by *too_long*,
    which is ``ValueError`` by default.

    >>> only([], default='missing')
    'missing'
    >>> only([1])
    1
    >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 1, 2,
    and perhaps more.'
    >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    TypeError

    Note that :func:`only` attempts to advance *iterable* twice to ensure there
    is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    it = iter(iterable)
    try:
        value = next(it)
    except StopIteration:
        # Empty iterable.
        return default
    try:
        extra = next(it)
    except StopIteration:
        # Exactly one item.
        return value
    msg = (
        f'Expected exactly one item in iterable, but got {value!r}, '
        f'{extra!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
def ichunked(iterable, n):
    """Break *iterable* into sub-iterables with *n* elements each.
    :func:`ichunked` is like :func:`chunked`, but it yields iterables
    instead of lists.

    If the sub-iterables are read in order, the elements of *iterable*
    won't be stored in memory.
    If they are read out of order, :func:`itertools.tee` is used to cache
    elements as necessary.

    >>> from itertools import count
    >>> all_chunks = ichunked(count(), 4)
    >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
    >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
    [4, 5, 6, 7]
    >>> list(c_1)
    [0, 1, 2, 3]
    >>> list(c_3)
    [8, 9, 10, 11]

    """
    iterator = iter(iterable)
    for first in iterator:
        # Each chunk is *first* plus up to n - 1 further items drawn from
        # the shared source iterator.
        rest = islice(iterator, n - 1)
        # ``cacher`` is drained below; every item it pulls through ``tee``
        # is buffered so that ``cache`` can replay it later.
        cache, cacher = tee(rest)
        # An in-order reader consumes ``rest`` directly (nothing is
        # buffered); ``cache`` then finds the islice exhausted and adds
        # nothing. An out-of-order reader finds ``rest`` already spent and
        # reads the buffered copies from ``cache`` instead.
        yield chain([first], rest, cache)
        # Skip past any items of this chunk the consumer didn't read,
        # routing them through ``tee`` into the cache for later replay.
        consume(cacher)
3593def iequals(*iterables):
3594 """Return ``True`` if all given *iterables* are equal to each other,
3595 which means that they contain the same elements in the same order.
3597 The function is useful for comparing iterables of different data types
3598 or iterables that do not support equality checks.
3600 >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
3601 True
3603 >>> iequals("abc", "acb")
3604 False
3606 Not to be confused with :func:`all_equal`, which checks whether all
3607 elements of iterable are equal to each other.
3609 """
3610 try:
3611 return all(map(all_equal, zip(*iterables, strict=True)))
3612 except ValueError:
3613 return False
def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

    >>> list(distinct_combinations([0, 0, 1], 2))
    [(0, 0), (0, 1)]

    Equivalent to ``set(combinations(iterable))``, except duplicates are not
    generated and thrown away. For larger input sequences this is much more
    efficient.

    """
    if r < 0:
        raise ValueError('r must be non-negative')
    elif r == 0:
        # The empty combination is the only combination of size 0.
        yield ()
        return
    pool = tuple(iterable)
    # Depth-first search over combination slots. Each stack level holds a
    # generator of (index, value) candidates for that slot; unique_everseen
    # keyed on the value skips duplicate candidates at the same level.
    generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
    current_combo = [None] * r
    level = 0
    while generators:
        try:
            cur_idx, p = next(generators[-1])
        except StopIteration:
            # This level has no more candidates; backtrack one slot.
            generators.pop()
            level -= 1
            continue
        current_combo[level] = p
        if level + 1 == r:
            # All slots filled: emit a complete combination.
            yield tuple(current_combo)
        else:
            # Descend: candidates for the next slot come only from pool
            # positions after cur_idx, again de-duplicated by value.
            generators.append(
                unique_everseen(
                    enumerate(pool[cur_idx + 1 :], cur_idx + 1),
                    key=itemgetter(1),
                )
            )
            level += 1
def filter_except(validator, iterable, *exceptions):
    """Yield the items from *iterable* for which the *validator* function does
    not raise one of the specified *exceptions*.

    *validator* is called for each item in *iterable*.
    It should be a function that accepts one argument and raises an exception
    if that item is not valid.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(filter_except(int, iterable, ValueError, TypeError))
    ['1', '2', '4']

    If an exception other than one given by *exceptions* is raised by
    *validator*, it is raised like normal.
    """
    for candidate in iterable:
        try:
            validator(candidate)
        except exceptions:
            # Invalid item: skip it silently.
            continue
        yield candidate
def map_except(function, iterable, *exceptions):
    """Transform each item from *iterable* with *function* and yield the
    result, unless *function* raises one of the specified *exceptions*.

    *function* is called to transform each item in *iterable*.
    It should accept one argument.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(map_except(int, iterable, ValueError, TypeError))
    [1, 2, 4]

    If an exception other than one given by *exceptions* is raised by
    *function*, it is raised like normal.
    """
    for value in iterable:
        try:
            transformed = function(value)
        except exceptions:
            # Untransformable item: drop it silently.
            continue
        yield transformed
def map_if(iterable, pred, func, func_else=None):
    """Evaluate each item from *iterable* using *pred*. If the result is
    equivalent to ``True``, transform the item with *func* and yield it.
    Otherwise, transform the item with *func_else* and yield it.

    *pred*, *func*, and *func_else* should each be functions that accept
    one argument. By default, *func_else* is the identity function.

    >>> from math import sqrt
    >>> iterable = list(range(-5, 5))
    >>> iterable
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
    >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
    >>> list(map_if(iterable, lambda x: x >= 0,
    ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
    [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
    """
    # Default the "else" branch to the identity function so a single loop
    # covers both cases.
    if func_else is None:
        def func_else(item):
            return item

    for item in iterable:
        yield func(item) if pred(item) else func_else(item)
def _sample_unweighted(iterator, k, strict):
    # Algorithm L in the 1994 paper by Kim-Hung Li:
    # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".

    # Fill the reservoir with the first k items.
    reservoir = list(islice(iterator, k))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')
    W = 1.0

    # Repeatedly skip ahead a random number of items, then overwrite a
    # uniformly chosen reservoir slot. next() raises StopIteration when the
    # population is exhausted, which ends the loop via suppress().
    with suppress(StopIteration):
        while True:
            W *= random() ** (1 / k)
            skip = floor(log(random()) / log1p(-W))
            element = next(islice(iterator, skip, None))
            reservoir[randrange(k)] = element

    # Return the sample in random order rather than population order.
    shuffle(reservoir)
    return reservoir
def _sample_weighted(iterator, k, weights, strict):
    # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
    # "Weighted random sampling with a reservoir".

    # Log-transform for numerical stability for weights that are small/large
    weight_keys = (log(random()) / weight for weight in weights)

    # Fill up the reservoir (collection of samples) with the first `k`
    # weight-keys and elements, then heapify the list.
    reservoir = take(k, zip(weight_keys, iterator))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Min-heap ordered by weight-key; reservoir[0] is always the smallest.
    heapify(reservoir)

    # The number of jumps before changing the reservoir is a random variable
    # with an exponential distribution. Sample it using random() and logs.
    smallest_weight_key, _ = reservoir[0]
    weights_to_skip = log(random()) / smallest_weight_key

    for weight, element in zip(weights, iterator):
        if weight >= weights_to_skip:
            # The notation here is consistent with the paper, but we store
            # the weight-keys in log-space for better numerical stability.
            smallest_weight_key, _ = reservoir[0]
            t_w = exp(weight * smallest_weight_key)
            r_2 = uniform(t_w, 1)  # generate U(t_w, 1)
            weight_key = log(r_2) / weight
            # Replace the current minimum with the new entry, then resample
            # the skip distance from the new minimum key.
            heapreplace(reservoir, (weight_key, element))
            smallest_weight_key, _ = reservoir[0]
            weights_to_skip = log(random()) / smallest_weight_key
        else:
            weights_to_skip -= weight

    # Discard the keys and return the elements in random order.
    ret = [element for weight_key, element in reservoir]
    shuffle(ret)
    return ret
def _sample_counted(population, k, counts, strict):
    # Reservoir sampling over a population given as distinct elements plus a
    # parallel iterable of repeat counts; same skip-and-replace scheme as
    # _sample_unweighted.
    element = None
    remaining = 0

    def feed(i):
        # Advance *i* steps ahead and consume an element
        nonlocal element, remaining

        while i + 1 > remaining:
            i = i - remaining
            element = next(population)
            remaining = next(counts)
        remaining -= i + 1
        return element

    # Fill the reservoir; a short population simply leaves it short here.
    with suppress(StopIteration):
        reservoir = []
        for _ in range(k):
            reservoir.append(feed(0))

    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Skip a random distance through the virtual (element x count) expansion,
    # then overwrite a uniformly chosen reservoir slot; StopIteration from
    # feed() ends the loop via suppress().
    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= random() ** (1 / k)
            skip = floor(log(random()) / log1p(-W))
            element = feed(skip)
            reservoir[randrange(k)] = element

    # Return the sample in random order rather than population order.
    shuffle(reservoir)
    return reservoir
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a *k*-length list of elements chosen (without replacement)
    from the *iterable*.

    Similar to :func:`random.sample`, but works on inputs that aren't
    indexable (such as sets and dictionaries) and on inputs where the
    size isn't known in advance (such as generators).

    >>> iterable = range(100)
    >>> sample(iterable, 5)  # doctest: +SKIP
    [81, 60, 96, 16, 4]

    For iterables with repeated elements, you may supply *counts* to
    indicate the repeats.

    >>> iterable = ['a', 'b']
    >>> counts = [3, 4]  # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
    >>> sample(iterable, k=3, counts=counts)  # doctest: +SKIP
    ['a', 'a', 'b']

    An iterable with *weights* may be given:

    >>> iterable = range(100)
    >>> weights = (i * i + 1 for i in range(100))
    >>> sampled = sample(iterable, 5, weights=weights)  # doctest: +SKIP
    [79, 67, 74, 66, 78]

    Weighted selections are made without replacement.
    After an element is selected, it is removed from the pool and the
    relative weights of the other elements increase (this
    does not match the behavior of :func:`random.sample`'s *counts*
    parameter). Note that *weights* may not be used with *counts*.

    If the length of *iterable* is less than *k*,
    ``ValueError`` is raised if *strict* is ``True`` and
    all elements are returned (in shuffled order) if *strict* is ``False``.

    By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
    technique is used. When *weights* are provided,
    `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used instead.

    Notes on reproducibility:

    * The algorithms rely on inexact floating-point functions provided
      by the underlying math library (e.g. ``log``, ``log1p``, and ``pow``).
      Those functions can `produce slightly different results
      <https://members.loria.fr/PZimmermann/papers/accuracy.pdf>`_ on
      different builds. Accordingly, selections can vary across builds
      even for the same seed.

    * The algorithms loop over the input and make selections based on
      ordinal position, so selections from unordered collections (such as
      sets) won't reproduce across sessions on the same platform using the
      same seed. For example, this won't reproduce::

        >> seed(8675309)
        >> sample(set('abcdefghijklmnopqrstuvwxyz'), 10)
        ['c', 'p', 'e', 'w', 's', 'a', 'j', 'd', 'n', 't']

    """
    # Materializing the iterator first preserves the TypeError behavior for
    # non-iterable inputs regardless of the value of k.
    iterator = iter(iterable)

    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []

    # Dispatch to the appropriate reservoir-sampling helper.
    if weights is not None:
        if counts is not None:
            raise TypeError('weights and counts are mutually exclusive')
        return _sample_weighted(iterator, k, iter(weights), strict)
    if counts is not None:
        return _sample_counted(iterator, k, iter(counts), strict)
    return _sample_unweighted(iterator, k, strict)
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Returns ``True`` if the items of iterable are in sorted order, and
    ``False`` otherwise. *key* and *reverse* have the same meaning that they do
    in the built-in :func:`sorted` function.

    >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
    True
    >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
    False

    If *strict*, tests for strict sorting, that is, returns ``False`` if equal
    elements are found:

    >>> is_sorted([1, 2, 2])
    True
    >>> is_sorted([1, 2, 2], strict=True)
    False

    The function returns ``False`` after encountering the first out-of-order
    item, which means it may produce results that differ from the built-in
    :func:`sorted` function for objects with unusual comparison dynamics
    (like ``math.nan``). If there are no out-of-order items, the iterable is
    exhausted.
    """
    items = iterable if key is None else map(key, iterable)
    # Compare each item against its successor via two offset copies.
    left, right = tee(items)
    next(right, None)
    if reverse:
        left, right = right, left
    if strict:
        # Every adjacent pair must be strictly increasing.
        return all(map(lt, left, right))
    # No adjacent pair may be decreasing.
    return not any(map(lt, right, left))
class AbortThread(BaseException):
    """Raised inside the background thread of :class:`callback_iter` to stop
    it once the ``with`` block has exited.

    Subclasses ``BaseException`` rather than ``Exception`` so that a broad
    ``except Exception`` inside the wrapped function does not swallow it.
    """

    pass
class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

    >>> def func(callback=None):
    ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
    ...         if callback:
    ...             callback(i, c)
    ...     return 4


    Use ``with callback_iter(func)`` to get an iterator over the parameters
    that are delivered to the callback.

    >>> with callback_iter(func) as it:
    ...     for args, kwargs in it:
    ...         print(args)
    (1, 'a')
    (2, 'b')
    (3, 'c')

    The function will be called in a background thread. The ``done`` property
    indicates whether it has completed execution.

    >>> it.done
    True

    If it completes successfully, its return value will be available
    in the ``result`` property.

    >>> it.result
    4

    Notes:

    * If the function uses some keyword argument besides ``callback``, supply
      *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently the it is polled for
      output.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        self._func = func
        self._callback_kwd = callback_kwd
        # Checked by the callback; set in __exit__ to abort the worker.
        self._aborted = False
        self._future = None
        self._wait_seconds = wait_seconds
        # Lazily import concurrent.future
        self._module = __import__('concurrent.futures').futures
        # A single worker thread runs *func*.
        self._executor = self._module.ThreadPoolExecutor(max_workers=1)
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Make the next callback invocation raise AbortThread in the worker,
        # then block until the executor has shut down.
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    @property
    def done(self):
        # False until _reader() has submitted the function to the executor.
        if self._future is None:
            return False
        return self._future.done()

    @property
    def result(self):
        if self._future:
            try:
                # timeout=0 means: don't wait, just check.
                return self._future.result(timeout=0)
            except self._module.TimeoutError:
                # Still running; fall through to the RuntimeError below.
                pass

        raise RuntimeError('Function has not yet completed')

    def _reader(self):
        # Generator that drives the worker: callback invocations are pushed
        # onto a queue by the worker thread and yielded here.
        q = Queue()

        def callback(*args, **kwargs):
            # Runs in the worker thread for every callback invocation.
            if self._aborted:
                raise AbortThread('canceled by user')

            q.put((args, kwargs))

        self._future = self._executor.submit(
            self._func, **{self._callback_kwd: callback}
        )

        # Poll the queue, re-checking periodically whether the worker has
        # finished (so a completed function doesn't leave us blocked).
        while True:
            try:
                item = q.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                q.task_done()
                yield item

            if self._future.done():
                break

        # The worker may have queued items between our last get() and its
        # completion; drain and yield them.
        remaining = []
        while True:
            try:
                item = q.get_nowait()
            except Empty:
                break
            else:
                q.task_done()
                remaining.append(item)
        q.join()
        yield from remaining
def windowed_complete(iterable, n):
    """
    Yield ``(beginning, middle, end)`` tuples, where:

    * Each ``middle`` has *n* items from *iterable*
    * Each ``beginning`` has the items before the ones in ``middle``
    * Each ``end`` has the items after the ones in ``middle``

    >>> iterable = range(7)
    >>> n = 3
    >>> for beginning, middle, end in windowed_complete(iterable, n):
    ...     print(beginning, middle, end)
    () (0, 1, 2) (3, 4, 5, 6)
    (0,) (1, 2, 3) (4, 5, 6)
    (0, 1) (2, 3, 4) (5, 6)
    (0, 1, 2) (3, 4, 5) (6,)
    (0, 1, 2, 3) (4, 5, 6) ()

    Note that *n* must be at least 0 and most equal to the length of
    *iterable*.

    This function will exhaust the iterable and may require significant
    storage.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    seq = tuple(iterable)
    size = len(seq)

    if n > size:
        raise ValueError('n must be <= len(seq)')

    # Slide the n-item middle window across the sequence, slicing out the
    # prefix and suffix around it.
    for start in range(size - n + 1):
        stop = start + n
        yield seq[:start], seq[start:stop], seq[stop:]
def all_unique(iterable, key=None):
    """
    Returns ``True`` if all the elements of *iterable* are unique (no two
    elements are equal).

    >>> all_unique('ABCB')
    False

    If a *key* function is specified, it will be used to make comparisons.

    >>> all_unique('ABCb')
    True
    >>> all_unique('ABCb', str.lower)
    False

    The function returns as soon as the first non-unique element is
    encountered. Iterables with a mix of hashable and unhashable items can
    be used, but the function will be slower for unhashable items.
    """
    # Hashable values go in a set (fast membership); unhashable values fall
    # back to a list scan.
    hashables = set()
    unhashables = []
    for value in map(key, iterable) if key else iterable:
        try:
            if value in hashables:
                return False
            hashables.add(value)
        except TypeError:
            if value in unhashables:
                return False
            unhashables.append(value)
    return True
def nth_product(index, *iterables, repeat=1):
    """Equivalent to ``list(product(*iterables, repeat=repeat))[index]``.

    The products of *iterables* can be ordered lexicographically.
    :func:`nth_product` computes the product at sort position *index* without
    computing the previous products.

    >>> nth_product(8, range(2), range(2), range(2), range(2))
    (1, 0, 0, 0)

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables. The above example is equivalent to::

    >>> nth_product(8, range(2), repeat=4)
    (1, 0, 0, 0)

    ``IndexError`` will be raised if the given *index* is invalid.
    """
    # Work from the least-significant (rightmost) pool outward.
    pools = tuple(map(tuple, reversed(iterables))) * repeat
    sizes = tuple(map(len, pools))

    total = prod(sizes)

    # Support negative indexing from the end, then validate the range.
    if index < 0:
        index += total
    if not 0 <= index < total:
        raise IndexError

    # Decompose the index in the mixed-radix system defined by the pool
    # sizes; each digit selects one element.
    selection = []
    for pool, size in zip(pools, sizes):
        index, digit = divmod(index, size)
        selection.append(pool[digit])

    return tuple(reversed(selection))
def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]```

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`nth_permutation`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences.

    >>> nth_permutation('ghijk', 2, 5)
    ('h', 'i')

    ``ValueError`` will be raised If *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = list(iterable)
    n = len(pool)
    if r is None:
        r = n
    # Total number of length-r permutations. math.perm raises the documented
    # ValueError when r is negative.
    c = perm(n, r)

    # Support negative indexing from the end, then validate the range.
    if index < 0:
        index += c
    if not 0 <= index < c:
        raise IndexError

    result = [0] * r
    # Scale the index up to the corresponding full-length (r == n)
    # permutation index, then extract its factorial-base digits.
    q = index * factorial(n) // c if r < n else index
    for d in range(1, n + 1):
        q, i = divmod(q, d)
        if 0 <= n - d < r:
            result[n - d] = i
        if q == 0:
            break

    # Each digit selects (and removes) an element from the shrinking pool.
    return tuple(map(pool.pop, result))
def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.


    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`nth_combination_with_replacement`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences with replacement.

    >>> nth_combination_with_replacement(range(5), 3, 5)
    (0, 1, 1)

    ``ValueError`` will be raised If *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if r < 0:
        raise ValueError
    # Total count: C(n + r - 1, r), with the degenerate empty-pool cases
    # handled explicitly.
    c = comb(n + r - 1, r) if n else 0 if r else 1

    # Support negative indexing from the end, then validate the range.
    if index < 0:
        index += c
    if not 0 <= index < c:
        raise IndexError

    result = []
    i = 0
    while r:
        r -= 1
        # Choose the element for this slot: skip whole blocks of
        # combinations that start with smaller elements, decrementing the
        # remaining pool size (n) and advancing the pool cursor (i).
        while n >= 0:
            num_combs = comb(n + r - 1, r)
            if index < num_combs:
                break
            n -= 1
            i += 1
            index -= num_combs
        result.append(pool[i])

    return tuple(result)
def value_chain(*args):
    """Yield all arguments passed to the function in the same order in which
    they were passed. If an argument itself is iterable then iterate over its
    values.

    >>> list(value_chain(1, 2, 3, [4, 5, 6]))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and are emitted
    as-is:

    >>> list(value_chain('12', '34', ['56', '78']))
    ['12', '34', '56', '78']

    Pre- or postpend a single element to an iterable:

    >>> list(value_chain(1, [2, 3, 4, 5, 6]))
    [1, 2, 3, 4, 5, 6]
    >>> list(value_chain([1, 2, 3, 4, 5], 6))
    [1, 2, 3, 4, 5, 6]

    Multiple levels of nesting are not flattened.

    """
    for arg in args:
        # Strings and bytes are iterable, but are treated as scalars here.
        if isinstance(arg, (str, bytes)):
            yield arg
            continue
        try:
            yield from arg
        except TypeError:
            # Not iterable: emit the argument itself.
            yield arg
def product_index(element, *iterables, repeat=1):
    """Equivalent to ``list(product(*iterables, repeat=repeat)).index(tuple(element))``

    The products of *iterables* can be ordered lexicographically.
    :func:`product_index` computes the first index of *element* without
    computing the previous products.

    >>> product_index([8, 2], range(10), range(5))
    42

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables::

    >>> product_index([8, 0, 7], range(10), repeat=3)
    807

    ``ValueError`` will be raised if the given *element* isn't in the product
    of *args*.
    """
    members = tuple(element)
    pools = tuple(map(tuple, iterables)) * repeat
    if len(members) != len(pools):
        raise ValueError('element is not a product of args')

    # Accumulate the index as a mixed-radix number: each position's digit is
    # the member's offset within its pool. pool.index raises ValueError when
    # a member is missing.
    result = 0
    for member, pool in zip(members, pools):
        result = result * len(pool) + pool.index(member)
    return result
def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`combination_index` computes the index of the
    first *element*, without computing the previous combinations.

    >>> combination_index('adf', 'abcdefg')
    10

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations of *iterable*.
    """
    element = enumerate(element)
    k, y = next(element, (None, None))
    # The empty combination is always at index 0.
    if k is None:
        return 0

    indexes = []
    pool = enumerate(iterable)
    # Greedily match each item of *element* against its first occurrence in
    # the pool, recording the matched positions. The for/else raises when
    # the pool runs out before every item is matched.
    for n, x in pool:
        if x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                break
            else:
                k = tmp
    else:
        raise ValueError('element is not a combination of iterable')

    # Exhaust the pool to find the position of its last item; n then holds
    # the final pool index (k holds the final element index).
    n, _ = last(pool, default=(n, None))

    # Count the combinations that sort at or after *element*, then subtract
    # from the total C(n + 1, k + 1).
    index = 1
    for i, j in enumerate(reversed(indexes), start=1):
        j = n - j
        if i <= j:
            index += comb(j, i)

    return comb(n + 1, k + 1) - index
def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``

    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`combination_with_replacement_index`
    computes the index of the first *element*, without computing the previous
    combinations with replacement.

    >>> combination_with_replacement_index('adf', 'abcdefg')
    20

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations with replacement of *iterable*.
    """
    element = tuple(element)
    l = len(element)
    element = enumerate(element)

    k, y = next(element, (None, None))
    # The empty combination is always at index 0.
    if k is None:
        return 0

    indexes = []
    pool = tuple(iterable)
    # Match each item of *element* against the pool; the inner while allows
    # one pool position to absorb consecutive repeats. The for/else raises
    # when the pool runs out before every item is matched.
    for n, x in enumerate(pool):
        while x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                break
            else:
                k = tmp
        if y is None:
            break
    else:
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    # occupations[p] counts how many times pool position p is used.
    n = len(pool)
    occupations = [0] * n
    for p in indexes:
        occupations[p] += 1

    # Sum the sizes of the blocks of combinations that precede *element*
    # in lexicographic order.
    index = 0
    cumulative_sum = 0
    for k in range(1, n):
        cumulative_sum += occupations[k - 1]
        j = l + n - 1 - k - cumulative_sum
        i = n - k
        if i <= j:
            index += comb(j, i)

    return index
def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)```

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`permutation_index`
    computes the index of the first *element* directly, without computing
    the previous permutations.

    >>> permutation_index([1, 3, 2], range(5))
    19

    ``ValueError`` will be raised if the given *element* isn't one of the
    permutations of *iterable*.
    """
    # Mixed-radix accumulation: each matched item contributes its position
    # among the remaining candidates, weighted by the shrinking pool size.
    remaining = list(iterable)
    index = 0
    for radix, value in zip(range(len(remaining), -1, -1), element):
        position = remaining.index(value)
        index = index * radix + position
        del remaining[position]

    return index
class countable:
    """Wrap *iterable* and keep a count of how many items have been consumed.

    The ``items_seen`` attribute starts at ``0`` and increments as the iterable
    is consumed:

    >>> iterable = map(str, range(10))
    >>> it = countable(iterable)
    >>> it.items_seen
    0
    >>> next(it), next(it)
    ('0', '1')
    >>> list(it)
    ['2', '3', '4', '5', '6', '7', '8', '9']
    >>> it.items_seen
    10
    """

    def __init__(self, iterable):
        # items_seen is public, read-only by convention.
        self._iterator = iter(iterable)
        self.items_seen = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Only bump the counter after the underlying iterator produced
        # a value, so StopIteration leaves the count untouched.
        value = next(self._iterator)
        self.items_seen += 1
        return value
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.
    Items are distributed such the lengths of the lists differ by at most
    1 item.

    >>> iterable = [1, 2, 3, 4, 5, 6, 7]
    >>> n = 3
    >>> list(chunked_even(iterable, n))  # List lengths: 3, 2, 2
    [[1, 2, 3], [4, 5], [6, 7]]
    >>> list(chunked(iterable, n))  # List lengths: 3, 3, 1
    [[1, 2, 3], [4, 5, 6], [7]]

    """
    iterator = iter(iterable)

    # Initialize a buffer to process the chunks while keeping
    # some back to fill any underfilled chunks
    min_buffer = (n - 1) * (n - 2)
    buffer = list(islice(iterator, min_buffer))

    # Append items until we have a completed chunk.
    # map(buffer.append, iterator) feeds the buffer one item at a time;
    # islice(..., n, None, n) fires once per n appended items.
    for _ in islice(map(buffer.append, iterator), n, None, n):
        yield buffer[:n]
        del buffer[:n]

    # Check if any chunks need additional processing
    if not buffer:
        return
    length = len(buffer)

    # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
    q, r = divmod(length, n)
    num_lists = q + (1 if r > 0 else 0)
    q, r = divmod(length, num_lists)
    full_size = q + (1 if r > 0 else 0)
    partial_size = full_size - 1
    num_full = length - partial_size * num_lists

    # Yield chunks of full size
    partial_start_idx = num_full * full_size
    if full_size > 0:
        for i in range(0, partial_start_idx, full_size):
            yield buffer[i : i + full_size]

    # Yield chunks of partial size
    if partial_size > 0:
        for i in range(partial_start_idx, length, partial_size):
            yield buffer[i : i + partial_size]
4533def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
4534 """A version of :func:`zip` that "broadcasts" any scalar
4535 (i.e., non-iterable) items into output tuples.
4537 >>> iterable_1 = [1, 2, 3]
4538 >>> iterable_2 = ['a', 'b', 'c']
4539 >>> scalar = '_'
4540 >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
4541 [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]
4543 The *scalar_types* keyword argument determines what types are considered
4544 scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
4545 treat strings and byte strings as iterable:
4547 >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
4548 [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]
4550 If the *strict* keyword argument is ``True``, then
4551 ``ValueError`` will be raised if any of the iterables have
4552 different lengths.
4553 """
4555 def is_scalar(obj):
4556 if scalar_types and isinstance(obj, scalar_types):
4557 return True
4558 try:
4559 iter(obj)
4560 except TypeError:
4561 return True
4562 else:
4563 return False
4565 size = len(objects)
4566 if not size:
4567 return
4569 new_item = [None] * size
4570 iterables, iterable_positions = [], []
4571 for i, obj in enumerate(objects):
4572 if is_scalar(obj):
4573 new_item[i] = obj
4574 else:
4575 iterables.append(iter(obj))
4576 iterable_positions.append(i)
4578 if not iterables:
4579 yield tuple(objects)
4580 return
4582 for item in zip(*iterables, strict=strict):
4583 for i, new_item[i] in zip(iterable_positions, item):
4584 pass
4585 yield tuple(new_item)
def unique_in_window(iterable, n, key=None):
    """Yield the items from *iterable* that haven't been seen recently.
    *n* is the size of the sliding window.

    >>> iterable = [0, 1, 0, 2, 3, 0]
    >>> n = 3
    >>> list(unique_in_window(iterable, n))
    [0, 1, 2, 3, 0]

    The *key* function, if provided, will be used to determine uniqueness:

    >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
    ['a', 'b', 'c', 'd', 'a']

    Updates a sliding window no larger than n and yields a value
    if the item only occurs once in the updated window.

    When `n == 1`, *unique_in_window* is memoryless:

    >>> list(unique_in_window('aab', n=1))
    ['a', 'a', 'b']

    The items in *iterable* must be hashable.

    """
    if n <= 0:
        raise ValueError('n must be greater than 0')

    # window holds the last n keys; counts tracks their multiplicities.
    window = deque(maxlen=n)
    counts = Counter()

    for item in iterable:
        # Evict the oldest key before it falls off the deque.
        if len(window) == n:
            oldest = window[0]
            counts[oldest] -= 1
            if not counts[oldest]:
                del counts[oldest]

        k = item if key is None else key(item)
        if counts[k] == 0:
            yield item
        counts[k] += 1
        window.append(k)
def duplicates_everseen(iterable, key=None):
    """Yield duplicate elements after their first appearance.

    >>> list(duplicates_everseen('mississippi'))
    ['s', 'i', 's', 's', 'i', 'p', 'i']
    >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    # Hashable keys go into a set; unhashable keys fall back to a
    # linear-scan list (detected via TypeError from the set lookup).
    hashed = set()
    unhashed = []

    for element in iterable:
        k = element if key is None else key(element)
        try:
            first_time = k not in hashed
            if first_time:
                hashed.add(k)
        except TypeError:
            first_time = k not in unhashed
            if first_time:
                unhashed.append(k)
        if not first_time:
            yield element
def duplicates_justseen(iterable, key=None):
    """Yields serially-duplicate elements after their first appearance.

    >>> list(duplicates_justseen('mississippi'))
    ['s', 's', 'p']
    >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    This function is analogous to :func:`unique_justseen`.

    """
    # For each run of equal-keyed items, skip the first and emit the rest.
    for _, group in groupby(iterable, key):
        next(group, None)
        yield from group
def classify_unique(iterable, key=None):
    """Classify each element in terms of its uniqueness.

    For each element in the input iterable, return a 3-tuple consisting of:

    1. The element itself
    2. ``False`` if the element is equal to the one preceding it in the input,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
    3. ``False`` if this element has been seen anywhere in the input before,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)

    >>> list(classify_unique('otto'))    # doctest: +NORMALIZE_WHITESPACE
    [('o', True, True),
     ('t', True, True),
     ('t', False, False),
     ('o', True, False)]

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    # Hashable keys are tracked in a set; unhashable ones fall back to a
    # linear-scan list when the set lookup raises TypeError.
    hashable_seen = set()
    unhashable_seen = []
    previous = None

    for position, element in enumerate(iterable):
        k = element if key is None else key(element)

        # Compare against the previous key before overwriting it.
        differs_from_previous = position == 0 or previous != k
        previous = k

        never_seen = False
        try:
            if k not in hashable_seen:
                hashable_seen.add(k)
                never_seen = True
        except TypeError:
            if k not in unhashable_seen:
                unhashable_seen.append(k)
                never_seen = True

        yield element, differs_from_previous, never_seen
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Returns both the smallest and largest items from an iterable
    or from two or more arguments.

    >>> minmax([3, 1, 5])
    (1, 5)

    >>> minmax(4, 2, 6)
    (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

    >>> minmax([5, 30], key=str)  # '30' sorts before '5'
    (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

    >>> minmax([], default=(0, 0))
    (0, 0)

    Otherwise ``ValueError`` is raised.

    This function makes a single pass over the input elements and takes care to
    minimize the number of comparisons made during processing.

    Note that unlike the builtin ``max`` function, which always returns the first
    item with the maximum value, this function may return another item when there are
    ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    iterable = (iterable_or_value, *others) if others else iterable_or_value

    it = iter(iterable)

    try:
        lo = hi = next(it)
    except StopIteration as exc:
        if default is _marker:
            raise ValueError(
                '`minmax()` argument is an empty iterable. '
                'Provide a `default` value to suppress this error.'
            ) from exc
        return default

    # Different branches depending on the presence of key. This saves a lot
    # of unimportant copies which would slow the "key=None" branch
    # significantly down.
    if key is None:
        # Pull items two at a time; ordering the pair first means only
        # three comparisons per two items instead of four.  An odd-length
        # tail is padded with *lo*, which cannot change either bound.
        for x, y in zip_longest(it, it, fillvalue=lo):
            if y < x:
                x, y = y, x
            if x < lo:
                lo = x
            if hi < y:
                hi = y

    else:
        lo_key = hi_key = key(lo)

        # Same pairing strategy, but track the computed keys alongside the
        # items so *key* is called exactly once per item.
        for x, y in zip_longest(it, it, fillvalue=lo):
            x_key, y_key = key(x), key(y)

            if y_key < x_key:
                x, y, x_key, y_key = y, x, y_key, x_key
            if x_key < lo_key:
                lo, lo_key = x, x_key
            if hi_key < y_key:
                hi, hi_key = y, y_key

    return lo, hi
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* with a combined size limited by
    *max_size*.

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10))
    [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If a *max_count* is supplied, the number of items per batch is also
    limited:

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10, max_count = 2))
    [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    If a *get_len* function is supplied, use that instead of :func:`len` to
    determine item size.

    If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
    than *max_size*. Otherwise, allow single items to exceed *max_size*.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    current = []
    current_size = 0

    for item in iterable:
        size = get_len(item)
        if strict and size > max_size:
            raise ValueError('item size exceeds maximum size')

        # Flush the batch if adding this item would overflow either limit.
        # When max_count is None, len(current) == None is never true.
        if current and (
            current_size + size > max_size or len(current) == max_count
        ):
            yield tuple(current)
            current = []
            current_size = 0

        current.append(item)
        current_size += size

    if current:
        yield tuple(current)
def gray_product(*iterables, repeat=1):
    """Like :func:`itertools.product`, but return tuples in an order such
    that only one element in the generated tuple changes from one iteration
    to the next.

    >>> list(gray_product('AB','CD'))
    [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables. For example, ``gray_product('AB', repeat=3)`` is
    equivalent to ``gray_product('AB', 'AB', 'AB')``.

    This function consumes all of the input iterables before producing output.
    If any of the input iterables have fewer than two items, ``ValueError``
    is raised.

    For information on the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    all_iterables = tuple(map(tuple, iterables * repeat))
    iterable_count = len(all_iterables)
    for iterable in all_iterables:
        if len(iterable) < 2:
            raise ValueError("each iterable must have two or more items")

    # This is based on "Algorithm H" from section 7.2.1.1, page 20.
    # a holds the indexes of the source iterables for the n-tuple to be yielded
    # f is the array of "focus pointers"
    # o is the array of "directions"
    a = [0] * iterable_count
    f = list(range(iterable_count + 1))
    o = [1] * iterable_count
    while True:
        yield tuple(all_iterables[i][a[i]] for i in range(iterable_count))
        # j is the position whose index will change next.
        j = f[0]
        f[0] = 0
        if j == iterable_count:
            # The focus pointer ran off the end: every tuple was produced.
            break
        # Step position j one place in its current direction.
        a[j] = a[j] + o[j]
        if a[j] == 0 or a[j] == len(all_iterables[j]) - 1:
            # Hit either end of iterable j: reverse direction and pass
            # the focus along to the next position.
            o[j] = -o[j]
            f[j] = f[j + 1]
            f[j + 1] = j + 1
def partial_product(*iterables, repeat=1):
    """Yields tuples containing one item from each iterator, with subsequent
    tuples changing a single item at a time by advancing each iterator until it
    is exhausted. This sequence guarantees every value in each iterable is
    output at least once without generating all possible combinations.

    This may be useful, for example, when testing an expensive function.

    >>> list(partial_product('AB', 'C', 'DEF'))
    [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables. For example, ``partial_product('AB', repeat=3)`` is
    equivalent to ``partial_product('AB', 'AB', 'AB')``.
    """

    iterators = [iter(pool) for pool in iterables * repeat]

    # Seed the current tuple with the first item of every iterator;
    # if any iterator is empty there is nothing to produce.
    current = []
    for iterator in iterators:
        try:
            current.append(next(iterator))
        except StopIteration:
            return
    yield tuple(current)

    # Advance one position at a time, left to right, emitting a tuple
    # after each single-item change.
    for position, iterator in enumerate(iterators):
        for value in iterator:
            current[position] = value
            yield tuple(current)
def takewhile_inclusive(predicate, iterable):
    """A variant of :func:`takewhile` that yields one additional element.

    >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
    [1, 4, 6]

    :func:`takewhile` would return ``[1, 4]``.
    """
    # Every item is yielded before testing it, so the first failing item
    # is included in the output.
    for item in iterable:
        yield item
        if not predicate(item):
            return
def outer_product(func, xs, ys, *args, **kwargs):
    """A generalized outer product that applies a binary function to all
    pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
    columns.

    Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.

    Multiplication table:

    >>> from operator import mul
    >>> list(outer_product(mul, range(1, 4), range(1, 6)))
    [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]

    Cross tabulation:

    >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
    >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
    >>> pair_counts = Counter(zip(xs, ys))
    >>> count_rows = lambda x, y: pair_counts[x, y]
    >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
    [(2, 3, 0), (1, 0, 4)]

    Usage with ``*args`` and ``**kwargs``:

    >>> animals = ['cat', 'wolf', 'mouse']
    >>> list(outer_product(min, animals, animals, key=len))
    [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
    """
    # ys must be reusable (product iterates it once per x) and sized
    # (each output row has len(ys) entries).
    ys = tuple(ys)
    results = (func(x, y, *args, **kwargs) for x, y in product(xs, ys))
    return batched(results, n=len(ys))
def iter_suppress(iterable, *exceptions):
    """Yield each of the items from *iterable*. If the iteration raises one of
    the specified *exceptions*, that exception will be suppressed and iteration
    will stop.

    >>> from itertools import chain
    >>> def breaks_at_five(x):
    ...     while True:
    ...         if x >= 5:
    ...             raise RuntimeError
    ...         yield x
    ...         x += 1
    >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
    >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
    >>> list(chain(it_1, it_2))
    [1, 2, 3, 4, 2, 3, 4]
    """
    # contextlib.suppress ends the generator cleanly when one of the
    # given exceptions escapes the underlying iterable.
    with suppress(*exceptions):
        yield from iterable
def filter_map(func, iterable):
    """Apply *func* to every element of *iterable*, yielding only those which
    are not ``None``.

    >>> elems = ['1', 'a', '2', 'b', '3']
    >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
    [1, 2, 3]
    """
    # Lazy: func is applied one element at a time as the result is consumed.
    return (result for result in map(func, iterable) if result is not None)
def powerset_of_sets(iterable, *, baseset=set):
    """Yields all possible subsets of the iterable.

    >>> list(powerset_of_sets([1, 2, 3]))  # doctest: +SKIP
    [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
    >>> list(powerset_of_sets([1, 1, 0]))  # doctest: +SKIP
    [set(), {1}, {0}, {0, 1}]

    :func:`powerset_of_sets` takes care to minimize the number
    of hash operations performed.

    The *baseset* parameter determines what kind of sets are
    constructed, either *set* or *frozenset*.
    """
    # Build one frozenset per distinct input item (deduplicated in
    # first-seen order), hashing each item only once.
    singletons = tuple(dict.fromkeys(frozenset((item,)) for item in iterable))
    combine = baseset().union
    return chain.from_iterable(
        starmap(combine, combinations(singletons, size))
        for size in range(len(singletons) + 1)
    )
def join_mappings(**field_to_map):
    """
    Joins multiple mappings together using their common keys.

    >>> user_scores = {'elliot': 50, 'claris': 60}
    >>> user_times = {'elliot': 30, 'claris': 40}
    >>> join_mappings(score=user_scores, time=user_times)
    {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    # Invert the nesting: outer keys become the inner field names.
    joined = {}
    for field_name, mapping in field_to_map.items():
        for key, value in mapping.items():
            joined.setdefault(key, {})[field_name] = value
    return joined
def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.

    Expands (a+bi)(c+di) into real/imaginary dot products so each part
    can be accumulated with a high-precision real sumprod.
    """

    real = attrgetter('real')
    imag = attrgetter('imag')
    # Real part: sum(a*c) - sum(b*d), expressed as one dot product by
    # negating the imaginary components of v1.
    real_part = _fsumprod(
        chain(map(real, v1), map(neg, map(imag, v1))),
        chain(map(real, v2), map(imag, v2)),
    )
    # Imaginary part: sum(a*d) + sum(b*c).
    imag_part = _fsumprod(
        chain(map(real, v1), map(imag, v1)),
        chain(map(imag, v2), map(real, v2)),
    )
    return complex(real_part, imag_part)
def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the corresponding transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]          # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]          # frequency domain
    >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
    >>> all(map(cmath.isclose, dft(xarr), Xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`idft` for the inverse Discrete Fourier Transform.
    """
    N = len(xarr)
    # Precompute the N complex roots of unity; exponents repeat mod N.
    twiddles = [e ** (n / N * tau * -1j) for n in range(N)]
    for k in range(N):
        yield _complex_sumprod(
            xarr, [twiddles[k * n % N] for n in range(N)]
        )
def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the corresponding
    inverse-transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]          # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]          # frequency domain
    >>> all(map(cmath.isclose, idft(Xarr), xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`dft` for the Discrete Fourier Transform.
    """
    N = len(Xarr)
    # Conjugate roots of unity (positive exponent) undo the forward
    # transform; dividing by N restores the original scale.
    twiddles = [e ** (n / N * tau * 1j) for n in range(N)]
    for k in range(N):
        yield _complex_sumprod(
            Xarr, [twiddles[k * n % N] for n in range(N)]
        ) / N
def doublestarmap(func, iterable):
    """Apply *func* to every item of *iterable* by dictionary unpacking
    the item into *func*.

    The difference between :func:`itertools.starmap` and :func:`doublestarmap`
    parallels the distinction between ``func(*a)`` and ``func(**a)``.

    >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
    >>> list(doublestarmap(lambda a, b: a + b, iterable))
    [3, 100]

    ``TypeError`` will be raised if *func*'s signature doesn't match the
    mapping contained in *iterable* or if *iterable* does not contain mappings.
    """
    for mapping in iterable:
        yield func(**mapping)
5127def _nth_prime_bounds(n):
5128 """Bounds for the nth prime (counting from 1): lb < p_n < ub."""
5129 # At and above 688,383, the lb/ub spread is under 0.003 * p_n.
5131 if n < 1:
5132 raise ValueError
5134 if n < 6:
5135 return (n, 2.25 * n)
5137 # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
5138 upper_bound = n * log(n * log(n))
5139 lower_bound = upper_bound - n
5140 if n >= 688_383:
5141 upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))
5143 return lower_bound, upper_bound
def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

    >>> nth_prime(0)
    2
    >>> nth_prime(100)
    547

    If *approximate* is set to True, will return a prime close
    to the nth prime. The estimation is much faster than computing
    an exact result.

    >>> nth_prime(200_000_000, approximate=True)  # Exact result is 4222234763
    4217820427

    """
    lb, ub = _nth_prime_bounds(n + 1)

    if approximate and n > 1_000_000:
        # Fast path: begin at the odd midpoint of the bounds and return
        # the first prime at or after it.
        start = floor((lb + ub) / 2) | 1
        return first_true(count(start, step=2), pred=is_prime)

    # Exact path: sieve up to the upper bound and take the nth prime.
    return nth(sieve(ceil(ub)), n)
def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

    >>> argmin('efghabcdijkl')
    4
    >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
    3

    For example, look up a label corresponding to the position
    of a value that minimizes a cost function::

        >>> def cost(x):
        ...     "Days for a wound to heal given a subject's age."
        ...     return x**2 - 20*x + 150
        ...
        >>> labels = ['homer', 'marge', 'bart', 'lisa', 'maggie']
        >>> ages =   [  35,      30,     10,     9,       1    ]

        # Fastest healing family member
        >>> labels[argmin(ages, key=cost)]
        'bart'

        # Age with fastest healing
        >>> min(ages, key=cost)
        10

    """
    # min() keeps the first of equal values, so ties resolve to the
    # earliest index.
    values = map(key, iterable) if key is not None else iterable
    smallest_index, _ = min(enumerate(values), key=itemgetter(1))
    return smallest_index
def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

    >>> argmax('abcdefghabcd')
    7
    >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
    3

    For example, identify the best machine learning model::

        >>> models = ['svm', 'random forest', 'knn', 'naïve bayes']
        >>> accuracy = [  68,      61,          84,       72     ]

        # Most accurate model
        >>> models[argmax(accuracy)]
        'knn'

        # Best accuracy
        >>> max(accuracy)
        84

    """
    # max() keeps the first of equal values, so ties resolve to the
    # earliest index.
    values = map(key, iterable) if key is not None else iterable
    largest_index, _ = max(enumerate(values), key=itemgetter(1))
    return largest_index
5233def _extract_monotonic(iterator, indices):
5234 'Non-decreasing indices, lazily consumed'
5235 num_read = 0
5236 for index in indices:
5237 advance = index - num_read
5238 try:
5239 value = next(islice(iterator, advance, None))
5240 except ValueError:
5241 if advance != -1 or index < 0:
5242 raise ValueError(f'Invalid index: {index}') from None
5243 except StopIteration:
5244 raise IndexError(index) from None
5245 else:
5246 num_read += advance + 1
5247 yield value
def _extract_buffered(iterator, index_and_position):
    'Arbitrary index order, greedily consumed'
    # index_and_position is sorted by index; each pair is
    # (index into the iterator, position in the original request order).
    # Values are read in index order but emitted in request order,
    # buffered by request position until their turn comes up.
    buffer = {}
    iterator_position = -1
    next_to_emit = 0

    for index, order in index_and_position:
        advance = index - iterator_position
        if advance:
            # Skip ahead to *index*; advance == 0 (a duplicate index)
            # reuses the previously read value.
            try:
                value = next(islice(iterator, advance - 1, None))
            except StopIteration:
                raise IndexError(index) from None
            iterator_position = index

        buffer[order] = value

        # Flush every buffered value whose request position is next.
        while next_to_emit in buffer:
            yield buffer.pop(next_to_emit)
            next_to_emit += 1
def extract(iterable, indices, *, monotonic=False):
    """Yield values at the specified indices.

    Example:

    >>> data = 'abcdefghijklmnopqrstuvwxyz'
    >>> list(extract(data, [7, 4, 11, 11, 14]))
    ['h', 'e', 'l', 'l', 'o']

    The *iterable* is consumed lazily and can be infinite.

    When *monotonic* is false, the *indices* are consumed immediately
    and must be finite. When *monotonic* is true, *indices* are consumed
    lazily and can be infinite but must be non-decreasing.

    Raises ``IndexError`` if an index lies beyond the iterable.
    Raises ``ValueError`` for a negative index or for a decreasing
    index when *monotonic* is true.
    """

    iterator = iter(iterable)

    # Lazy path: indices arrive in non-decreasing order.
    if monotonic:
        return _extract_monotonic(iterator, iter(indices))

    # Greedy path: pair each index with its request position, then
    # process in index order. Sorting puts any negative index first.
    pairs = sorted(zip(indices, count()))
    if pairs and pairs[0][0] < 0:
        raise ValueError('Indices must be non-negative')
    return _extract_buffered(iterator, pairs)
class serialize:
    """Wrap a non-concurrent iterator with a lock to enforce sequential access.

    Applies a non-reentrant lock around calls to ``__next__``, allowing
    iterator and generator instances to be shared by multiple consumer
    threads.
    """

    __slots__ = ('iterator', 'lock')

    def __init__(self, iterable):
        self.iterator = iter(iterable)
        self.lock = Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # Only one thread may advance the underlying iterator at a time.
        with self.lock:
            return next(self.iterator)
def synchronized(func):
    """Wrap an iterator-returning callable to make its iterators thread-safe.

    Existing itertools and more-itertools can be wrapped so that their
    iterator instances are serialized.

    For example, ``itertools.count`` does not make thread-safe instances,
    but that is easily fixed with::

        atomic_counter = synchronized(itertools.count)

    Can also be used as a decorator for generator functions definitions
    so that the generator instances are serialized::

        @synchronized
        def enumerate_and_timestamp(iterable):
            for count, value in enumerate(iterable):
                yield count, time_ns(), value

    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Wrap each freshly created iterator in a serializing proxy.
        return serialize(func(*args, **kwargs))

    return wrapper
def concurrent_tee(iterable, n=2):
    """Variant of itertools.tee() but with guaranteed threading semantics.

    Takes a non-threadsafe iterator as an input and creates concurrent
    tee objects for other threads to have reliable independent copies of
    the data stream.

    The new iterators are only thread-safe if consumed within a single thread.
    To share just one of the new iterators across multiple threads, wrap it
    with :func:`serialize`.
    """

    if n < 0:
        raise ValueError
    if n == 0:
        return ()
    # The first tee wraps the source; the rest are copies that share its
    # underlying iterator, link chain, and lock.
    root = _concurrent_tee(iterable)
    return (root,) + tuple(_concurrent_tee(root) for _ in range(n - 1))
class _concurrent_tee:
    """Lock-protected tee iterator.

    All copies share one source ``iterator``, one ``lock``, and a chain of
    ``link`` cells ``[value, next_link]``; each copy advances through the
    chain independently, so already-seen values stay reachable for the
    copies that have not consumed them yet.
    """

    __slots__ = ('iterator', 'link', 'lock')

    def __init__(self, iterable):
        if isinstance(iterable, _concurrent_tee):
            # Copy constructor: share state with the source tee so both
            # iterate the same underlying stream.
            self.iterator = iterable.iterator
            self.link = iterable.link
            self.lock = iterable.lock
        else:
            self.iterator = iter(iterable)
            self.link = [None, None]
            self.lock = Lock()

    def __iter__(self):
        return self

    def __next__(self):
        link = self.link
        # An unfilled cell has link[1] is None.  Double-checked locking:
        # re-test under the lock in case another thread filled it first.
        if link[1] is None:
            with self.lock:
                if link[1] is None:
                    link[0] = next(self.iterator)
                    link[1] = [None, None]
        value, self.link = link
        return value