Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/more_itertools/more.py: 19%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import math
2import types
4from collections import Counter, defaultdict, deque
5from collections.abc import Sequence
6from contextlib import suppress
7from functools import cached_property, partial, wraps
8from heapq import heapify, heapreplace
9from itertools import (
10 chain,
11 combinations,
12 compress,
13 count,
14 cycle,
15 dropwhile,
16 groupby,
17 islice,
18 permutations,
19 repeat,
20 starmap,
21 takewhile,
22 tee,
23 zip_longest,
24 product,
25)
26from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
27from math import ceil, prod
28from queue import Empty, Queue
29from random import random, randrange, shuffle, uniform
30from operator import (
31 attrgetter,
32 getitem,
33 is_not,
34 itemgetter,
35 lt,
36 neg,
37 sub,
38 gt,
39)
40from sys import maxsize
41from time import monotonic
42from threading import Lock
44from .recipes import (
45 _marker,
46 consume,
47 first_true,
48 flatten,
49 is_prime,
50 nth,
51 powerset,
52 sieve,
53 take,
54 unique_everseen,
55 all_equal,
56 batched,
57)
# Public names exported by ``from ... import *`` for this module.
# Kept in sorted (code-point) order, as ``sorted()`` would produce, so that
# entries are easy to find and accidental duplicates are easy to spot.
__all__ = [
    'AbortThread',
    'SequenceView',
    'adjacent',
    'all_unique',
    'always_iterable',
    'always_reversible',
    'argmax',
    'argmin',
    'bucket',
    'callback_iter',
    'chunked',
    'chunked_even',
    'circular_shifts',
    'classify_unique',
    'collapse',
    'combination_index',
    'combination_with_replacement_index',
    'concurrent_tee',
    'consecutive_groups',
    'constrained_batches',
    'consumer',
    'count_cycle',
    'countable',
    'derangements',
    'dft',
    'difference',
    'distinct_combinations',
    'distinct_permutations',
    'distribute',
    'divide',
    'doublestarmap',
    'duplicates_everseen',
    'duplicates_justseen',
    'exactly_n',
    'extract',
    'filter_except',
    'filter_map',
    'first',
    'gray_product',
    'groupby_transform',
    'ichunked',
    'idft',
    'iequals',
    'ilen',
    'interleave',
    'interleave_evenly',
    'interleave_longest',
    'interleave_randomly',
    'intersperse',
    'is_sorted',
    'islice_extended',
    'iter_suppress',
    'iterate',
    'join_mappings',
    'last',
    'locate',
    'longest_common_prefix',
    'lstrip',
    'make_decorator',
    'map_except',
    'map_if',
    'map_reduce',
    'mark_ends',
    'minmax',
    'nth_combination_with_replacement',
    'nth_or_last',
    'nth_permutation',
    'nth_prime',
    'nth_product',
    'numeric_range',
    'one',
    'only',
    'outer_product',
    'padded',
    'partial_product',
    'partitions',
    'peekable',
    'permutation_index',
    'powerset_of_sets',
    'product_index',
    'raise_',
    'repeat_each',
    'repeat_last',
    'replace',
    'rlocate',
    'rstrip',
    'run_length',
    'sample',
    'seekable',
    'serialize',
    'set_partitions',
    'side_effect',
    'sized_iterator',
    'sliced',
    'sort_together',
    'split_after',
    'split_at',
    'split_before',
    'split_into',
    'split_when',
    'spy',
    'stagger',
    'strictly_n',
    'strip',
    'substrings',
    'substrings_indexes',
    'synchronized',
    'takewhile_inclusive',
    'time_limited',
    'unique_in_window',
    'unique_to_each',
    'unzip',
    'value_chain',
    'windowed',
    'windowed_complete',
    'with_iter',
    'zip_broadcast',
    'zip_offset',
]
# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _fsumprod

except ImportError:  # pragma: no cover
    # Pure-Python fallback built on error-free multiplication from
    # T. J. Dekker, "A Floating-Point Technique for Extending the Available
    # Precision": https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8). Code: mul12()

    def dl_split(x: float):
        "Split a float into two half-precision components."
        # Veltkamp splitting constant: 2.0 ** 27 + 1
        scaled = x * 134217729.0
        high = scaled - (scaled - x)
        return high, x - high

    def dl_mul(x, y):
        "Lossless multiplication: returns (z, zz), the product and its error."
        x_high, x_low = dl_split(x)
        y_high, y_low = dl_split(y)
        head = x_high * y_high
        cross = x_high * y_low + x_low * y_high
        z = head + cross
        # Evaluated left to right, exactly as in Dekker's mul12().
        zz = head - z + cross + x_low * y_low
        return z, zz

    def _fsumprod(p, q):
        "Extended-precision sum of the pairwise products of *p* and *q*."
        # Feed both halves of every (z, zz) pair to fsum() for exact summation.
        return fsum(part for pair in map(dl_mul, p, q) for part in pair)
def chunked(iterable, n, strict=False):
    """Break *iterable* into lists of length *n*:

        >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
        [[1, 2, 3], [4, 5, 6]]

    By default, the last yielded list will have fewer than *n* elements
    if the length of *iterable* is not divisible by *n*:

        >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
        [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    If the length of *iterable* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    list is yielded.

    """
    # Two-argument iter(): call take(n, <shared iterator>) repeatedly and
    # stop as soon as it returns the sentinel, an empty list.
    iterator = iter(partial(take, n, iter(iterable)), [])
    if strict:
        if n is None:
            raise ValueError('n must not be None when using strict mode.')

        def ret():
            # Only the final chunk can be short, so any size mismatch means
            # the input length is not a multiple of n.
            for chunk in iterator:
                if len(chunk) != n:
                    raise ValueError('iterable is not divisible by n.')
                yield chunk

        return ret()
    else:
        return iterator
def first(iterable, default=_marker):
    """Return the first item of *iterable*; fall back to *default* when
    *iterable* is empty.

        >>> first([0, 1, 2, 3])
        0
        >>> first([], 'some default')
        'some default'

    An empty *iterable* with no *default* raises ``ValueError``.

    Handy when you have a generator of expensive-to-retrieve values and any
    single one will do. It is marginally shorter than
    ``next(iter(iterable), default)``.

    """
    iterator = iter(iterable)
    try:
        return next(iterator)
    except StopIteration:
        pass

    # Raised outside the except clause so no StopIteration context is chained.
    if default is _marker:
        raise ValueError(
            'first() was called on an empty iterable, '
            'and no default value was provided.'
        )
    return default
def last(iterable, default=_marker):
    """Return the final item of *iterable*; fall back to *default* when
    *iterable* is empty.

        >>> last([0, 1, 2, 3])
        3
        >>> last([], 'some default')
        'some default'

    An empty *iterable* with no *default* raises ``ValueError``.
    """
    try:
        if isinstance(iterable, Sequence):
            # Direct indexing: no iteration needed.
            result = iterable[-1]
        elif getattr(iterable, '__reversed__', None):
            # Work around https://bugs.python.org/issue38525
            result = next(reversed(iterable))
        else:
            # Drain the iterator, keeping only the most recent item.
            result = deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        if default is _marker:
            raise ValueError(
                'last() was called on an empty iterable, '
                'and no default value was provided.'
            )
        return default
    return result
def nth_or_last(iterable, n, default=_marker):
    """Return item *n* of *iterable*, or its last item if it is shorter,
    or *default* if it is empty.

        >>> nth_or_last([0, 1, 2, 3], 2)
        2
        >>> nth_or_last([0, 1], 2)
        1
        >>> nth_or_last([], 0, 'some default')
        'some default'

    An empty *iterable* with no *default* raises ``ValueError``.
    """
    # Look at no more than the first n + 1 items, then take the final one.
    prefix = islice(iterable, n + 1)
    return last(prefix, default=default)
class peekable:
    """Wrap an iterator to allow lookahead and prepending elements.

    Call :meth:`peek` on the result to get the value that will be returned
    by :func:`next`. This won't advance the iterator:

        >>> p = peekable(['a', 'b'])
        >>> p.peek()
        'a'
        >>> next(p)
        'a'

    Pass :meth:`peek` a default value to return that instead of raising
    ``StopIteration`` when the iterator is exhausted.

        >>> p = peekable([])
        >>> p.peek('hi')
        'hi'

    peekables also offer a :meth:`prepend` method, which "inserts" items
    at the head of the iterable:

        >>> p = peekable([1, 2, 3])
        >>> p.prepend(10, 11, 12)
        >>> next(p)
        10
        >>> p.peek()
        11
        >>> list(p)
        [11, 12, 1, 2, 3]

    peekables can be indexed. Index 0 is the item that will be returned by
    :func:`next`, index 1 is the item after that, and so on. The values up
    to the given index will be cached:

        >>> p = peekable(['a', 'b', 'c', 'd'])
        >>> p[0]
        'a'
        >>> p[1]
        'b'
        >>> next(p)
        'a'

    Negative indexes are supported, but be aware that they will cache the
    remaining items in the source iterator, which may require significant
    storage.

    To check whether a peekable is exhausted, check its truth value:

        >>> p = peekable(['a', 'b'])
        >>> if p:  # peekable has items
        ...     list(p)
        ['a', 'b']
        >>> if not p:  # peekable is exhausted
        ...     list(p)
        []

    """

    def __init__(self, iterable):
        self._it = iter(iterable)
        # Items already pulled from (or prepended ahead of) self._it; the
        # leftmost element is always the next one to be returned.
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        """Return the item that will be next returned from ``next()``.

        Return ``default`` if there are no items left. If ``default`` is not
        provided, raise ``StopIteration``.

        """
        if not self._cache:
            try:
                self._cache.append(next(self._it))
            except StopIteration:
                if default is _marker:
                    raise
                return default
        return self._cache[0]

    def prepend(self, *items):
        """Stack up items to be the next ones returned from ``next()`` or
        ``self.peek()``. The items will be returned in
        first in, first out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        It is possible, by prepending items, to "resurrect" a peekable that
        previously raised ``StopIteration``.

            >>> p = peekable([])
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration
            >>> p.prepend(1)
            >>> next(p)
            1
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration

        """
        # extendleft() inserts one at a time, reversing the order; reverse
        # first so the items come out in the order they were passed.
        self._cache.extendleft(reversed(items))

    __class_getitem__ = classmethod(types.GenericAlias)

    def __next__(self):
        if self._cache:
            return self._cache.popleft()

        return next(self._it)

    def _get_slice(self, index):
        # Normalize the slice's arguments
        step = 1 if (index.step is None) else index.step
        if step > 0:
            start = 0 if (index.start is None) else index.start
            stop = maxsize if (index.stop is None) else index.stop
        elif step < 0:
            start = -1 if (index.start is None) else index.start
            stop = (-maxsize - 1) if (index.stop is None) else index.stop
        else:
            raise ValueError('slice step cannot be zero')

        # If either the start or stop index is negative, we'll need to cache
        # the rest of the iterable in order to slice from the right side.
        if (start < 0) or (stop < 0):
            self._cache.extend(self._it)
        # Otherwise we'll need to find the rightmost index and cache to that
        # point.
        else:
            n = min(max(start, stop) + 1, maxsize)
            cache_len = len(self._cache)
            if n >= cache_len:
                self._cache.extend(islice(self._it, n - cache_len))

        # deque does not support slicing, so slice a list copy of the cache.
        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        cache_len = len(self._cache)
        if index < 0:
            # Counting from the end requires the whole remaining input.
            self._cache.extend(self._it)
        elif index >= cache_len:
            # Pull just enough items to make position ``index`` available.
            self._cache.extend(islice(self._it, index + 1 - cache_len))

        return self._cache[index]
def consumer(func):
    """Decorator that primes a PEP-342-style "reverse iterator" by advancing
    it to its first ``yield``, so ``send()`` works without a manual
    ``next()`` call.

        >>> @consumer
        ... def tally():
        ...     i = 0
        ...     while True:
        ...         print('Thing number %s is %s.' % (i, (yield)))
        ...         i += 1
        ...
        >>> t = tally()
        >>> t.send('red')
        Thing number 0 is red.
        >>> t.send('fish')
        Thing number 1 is fish.

    Undecorated, ``next(t)`` would have to be called before the first
    ``t.send()``.

    """

    def primed(*args, **kwargs):
        coroutine = func(*args, **kwargs)
        # Run up to the first ``yield`` so the caller can send() right away.
        next(coroutine)
        return coroutine

    return wraps(func)(primed)
def ilen(iterable):
    """Return the number of items in *iterable*.

    For example, there are 168 prime numbers below 1,000:

        >>> ilen(sieve(1000))
        168

    Equivalent to, but faster than::

        def ilen(iterable):
            count = 0
            for _ in iterable:
                count += 1
            return count

    This fully consumes the iterable, so handle with care.

    """
    # Pair each item with the next value of a counter and discard the pairs
    # with a zero-length deque (drains the iterator at C speed). zip() pulls
    # from *iterable* first and stops there once it is exhausted, so the
    # counter advances once per real item and its next value is the count.
    counter = count()
    deque(zip(iterable, counter), maxlen=0)
    return next(counter)
def iterate(func, start):
    """Yield ``start``, ``func(start)``, ``func(func(start))``, and so on.

    The result is infinite unless *func* raises ``StopIteration``, which
    ends it quietly. To stop on some other condition, combine it with
    :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`.

        >>> take(10, iterate(lambda x: 2*x, 1))
        [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

        >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2
        >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10)))
        [10, 5, 16, 8, 4, 2, 1]

    """
    current = start
    # A StopIteration raised by func() terminates the generator cleanly
    # instead of turning into a RuntimeError (PEP 479).
    with suppress(StopIteration):
        while True:
            yield current
            current = func(current)
def with_iter(context_manager):
    """Enter *context_manager*, yield everything from the iterable it
    provides, then exit it.

    For example, this closes the file once every line has been read::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    The context is exited only when the returned generator is exhausted
    (or explicitly closed with ``.close()``), so an opened file stays open
    until then.

    Any context manager whose ``__enter__`` returns an iterable works
    with ``with_iter``.

    """
    with context_manager as managed:
        # ``yield from`` (not a plain for-loop) also forwards send()/throw()
        # when *managed* is itself a generator.
        yield from managed
class sized_iterator:
    """Iterator wrapper that reports a caller-supplied length via ``len()``.

        >>> it = map(str, range(5))
        >>> sized_it = sized_iterator(it, 5)
        >>> len(sized_it)
        5
        >>> list(sized_it)
        ['0', '1', '2', '3', '4']

    Handy for tools that consult :func:`len`, like
    `tqdm <https://pypi.org/project/tqdm/>`__ .

    *length* is stored as given and never checked against the number of
    items actually produced, so make sure it reflects reality.
    """

    def __init__(self, iterable, length):
        self._iterator = iter(iterable)
        self._length = length

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __len__(self):
        # The caller-declared length; not updated as items are consumed.
        return self._length
def one(iterable, too_short=None, too_long=None):
    """Return the first item from *iterable*, which is expected to contain only
    that item. Raise an exception if *iterable* is empty or has more than one
    item.

    :func:`one` is useful for ensuring that an iterable contains only one item.
    For example, it can be used to retrieve the result of a database query
    that is expected to return a single row.

    If *iterable* is empty, ``ValueError`` will be raised. You may specify a
    different exception with the *too_short* keyword:

        >>> it = []
        >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: too few items in iterable (expected 1)
        >>> too_short = IndexError('too few items')
        >>> one(it, too_short=too_short)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        IndexError: too few items

    Similarly, if *iterable* contains more than one item, ``ValueError`` will
    be raised. You may specify a different exception with the *too_long*
    keyword:

        >>> it = ['too', 'many']
        >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Expected exactly one item in iterable, but got 'too',
        'many', and perhaps more.
        >>> too_long = RuntimeError
        >>> one(it, too_long=too_long)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        RuntimeError

    *too_short* and *too_long* may be exception instances or classes; they
    are raised as-is.

    Note that :func:`one` attempts to advance *iterable* twice to ensure there
    is only one item. See :func:`spy` or :func:`peekable` to check iterable
    contents less destructively.

    """
    iterator = iter(iterable)
    # Loop variables are named so they do not shadow the module-level first().
    for first_item in iterator:
        # Any second item means the iterable is too long.
        for second_item in iterator:
            msg = (
                'Expected exactly one item in iterable, but got '
                f'{first_item!r}, {second_item!r}, and perhaps more.'
            )
            raise too_long or ValueError(msg)
        return first_item
    raise too_short or ValueError('too few items in iterable (expected 1)')
def raise_(exception, *args):
    """Raise ``exception(*args)``.

    Lets a raise happen where only an expression is allowed, e.g. in a lambda.
    """
    error = exception(*args)
    raise error
def strictly_n(iterable, n, too_short=None, too_long=None):
    """Yield the items of *iterable*, checking that there are exactly *n*.

    Fewer than *n* items: *too_short* is called with the number of items
    seen. More than *n* items: only the first *n* are yielded and
    *too_long* is called with ``n + 1``.

        >>> iterable = ['a', 'b', 'c', 'd']
        >>> n = 4
        >>> list(strictly_n(iterable, n))
        ['a', 'b', 'c', 'd']

    The check happens while the result is consumed, so it must be iterated
    to the end for a violation to be reported.

    By default both callbacks raise ``ValueError``:

        >>> list(strictly_n('ab', 3))  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Too few items in iterable (got 2)

        >>> list(strictly_n('abc', 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Too many items in iterable (got at least 3)

    Supply your own callables to do something else instead:

        >>> def too_short(item_count):
        ...     raise RuntimeError
        >>> it = strictly_n('abcd', 6, too_short=too_short)
        >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        RuntimeError

        >>> def too_long(item_count):
        ...     print('The boss is going to hear about this')
        >>> it = strictly_n('abcdef', 4, too_long=too_long)
        >>> list(it)
        The boss is going to hear about this
        ['a', 'b', 'c', 'd']

    """
    if too_short is None:

        def too_short(item_count):
            raise ValueError(f'Too few items in iterable (got {item_count})')

    if too_long is None:

        def too_long(item_count):
            raise ValueError(
                f'Too many items in iterable (got at least {item_count})'
            )

    source = iter(iterable)

    emitted = 0
    for item in islice(source, n):
        yield item
        emitted += 1

    if emitted < n:
        too_short(emitted)
        return

    # Probe for one surplus item without yielding it.
    missing = object()
    if next(source, missing) is not missing:
        too_long(n + 1)
def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

        >>> sorted(distinct_permutations([1, 0, 1]))
        [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

        >>> sorted(distinct_permutations([1, 0, 1], r=2))
        [(0, 1), (1, 0), (1, 1)]
        >>> sorted(distinct_permutations(range(3), r=2))
        [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

        >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
        [
            (1, True, '3'),
            (1, '3', True),
            ('3', 1, True)
        ]
        >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
        [
            (1, 2, '3'),
            (1, '3', 2),
            (2, 1, '3'),
            (2, '3', 1),
            ('3', 1, 2),
            ('3', 2, 1)
        ]
    """

    # Algorithm: https://w.wiki/Qai
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than i such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above
    def _partial(A, r):
        # Split A into the first r items (the head) and the remaining
        # len(A) - r items (the tail)
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    try:
        items.sort()
        sortable = True
    except TypeError:
        sortable = False

    size = len(items)
    if r is None:
        r = size

    # r == 0 has exactly one (empty) permutation; r < 0 or r > size has none.
    if not 0 < r <= size:
        return iter(() if r else ((),))

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if sortable:
        return algorithm(items)

    # The items can't be ordered, so permute integer labels instead. An
    # item's label is the index of the first item equal (==) to it, so equal
    # items share a label. Labels are mapped back to the original objects,
    # cycling through each group of equal-but-distinct objects. list.index
    # makes this quadratic, so it is only done when it is actually needed.
    labels = [items.index(item) for item in items]

    indices_dict = defaultdict(list)
    for label, item in zip(labels, items):
        indices_dict[label].append(item)

    equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

    def permuted_items(permuted_indices):
        return tuple(
            next(equivalent_items[index]) for index in permuted_indices
        )

    return (
        permuted_items(permuted_indices)
        for permuted_indices in algorithm(sorted(labels))
    )
def derangements(iterable, r=None):
    """Yield the derangements of *iterable*: permutations in which no element
    stays at its original index (permutations without fixed points).

    Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
    The code below lists every way to assign gift recipients so that
    nobody is assigned to himself or herself:

        >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
        ...     print(', '.join(d))
        Bob, Alice, Dave, Carol
        Bob, Carol, Dave, Alice
        Bob, Dave, Alice, Carol
        Carol, Alice, Dave, Bob
        Carol, Dave, Alice, Bob
        Carol, Dave, Bob, Alice
        Dave, Alice, Bob, Carol
        Dave, Carol, Alice, Bob
        Dave, Carol, Bob, Alice

    When *r* is given, only derangements of length *r* are produced.

        >>> sorted(derangements(range(3), 2))
        [(1, 0), (1, 2), (2, 0)]
        >>> sorted(derangements([0, 2, 3], 2))
        [(2, 0), (2, 3), (3, 0)]

    Uniqueness is decided by position, not by value. With two *different*
    people who share a name, there are two valid assignments even though
    someone might appear to receive their own gift:

        >>> names = ['Alice', 'Bob', 'Bob']
        >>> list(derangements(names))
        [('Bob', 'Bob', 'Alice'), ('Bob', 'Alice', 'Bob')]

    Make the inputs distinct to avoid that confusion:

        >>> deduped = [f'{name}{index}' for index, name in enumerate(names)]
        >>> list(derangements(deduped))
        [('Bob1', 'Bob2', 'Alice0'), ('Bob2', 'Alice0', 'Bob1')]

    The count of derangements of *n* items is the "subfactorial of n"; for
    n > 0 it equals ``round(math.factorial(n) / math.e)``.

    References:

    * Article: https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics
    * Sizes: https://oeis.org/A000166
    """
    pool = tuple(iterable)
    positions = tuple(range(len(pool)))
    # Permute the values and their positions in lockstep; keep a candidate
    # only if every slot received an element from a different position.
    # permutations() reuses the int objects stored in ``positions``, so the
    # identity test is_not behaves like != here.
    return (
        candidate
        for candidate, chosen in zip(
            permutations(pool, r=r), permutations(positions, r=r)
        )
        if all(map(is_not, positions, chosen))
    )
def intersperse(e, iterable, n=1):
    """Intersperse filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

        >>> list(intersperse('!', [1, 2, 3, 4, 5]))
        [1, '!', 2, '!', 3, '!', 4, '!', 5]

        >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
        [1, 2, None, 3, 4, None, 5]

    *n* must be positive; ``ValueError`` is raised immediately otherwise.

    """
    # Reject every non-positive n up front. Checking only n == 0 let a
    # negative n through to chunked(), which does not validate it, so the
    # call returned an iterator instead of failing fast.
    if n <= 0:
        raise ValueError('n must be > 0')
    elif n == 1:
        # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2...
        # islice(..., 1, None) -> x_0, e, x_1, e, x_2...
        return islice(interleave(repeat(e), iterable), 1, None)
    else:
        # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]...
        # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]...
        # flatten(...) -> x_0, x_1, e, x_2, x_3...
        filler = repeat([e])
        chunks = chunked(iterable, n)
        return flatten(islice(interleave(filler, chunks), 1, None))
def unique_to_each(*iterables):
    """For each input iterable, return the elements that appear in no other
    input iterable.

    Consider packages and their dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    Which dependencies could be dropped along with a single package?
    Removing ``pkg_1`` frees ``A``, since neither ``pkg_2`` nor ``pkg_3``
    needs it. Likewise ``C`` belongs only to ``pkg_2`` and ``D`` only to
    ``pkg_3``::

        >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
        [['A'], ['C'], ['D']]

    Repeated elements within one iterable that are absent elsewhere are
    repeated in the output, and input order is kept::

        >>> unique_to_each("mississippi", "missouri")
        [['p', 'p'], ['o', 'u', 'r']]

    Elements must be hashable.

    """
    groups = [list(group) for group in iterables]

    # How many distinct inputs contain each element (each input counts once).
    membership = Counter()
    for group in groups:
        membership.update(set(group))

    exclusive = {element for element, hits in membership.items() if hits == 1}
    return [
        [element for element in group if element in exclusive]
        for group in groups
    ]
def windowed(seq, n, fillvalue=None, step=1):
    """Slide a window of width *n* over *seq* and yield each window as a tuple.

        >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
        >>> list(all_windows)
        [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    Positions past the end of the input are filled with *fillvalue*:

        >>> list(windowed([1, 2, 3], 4))
        [(1, 2, 3, None)]

    Successive windows start *step* items apart:

        >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
        [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To have the window slide in from the left, prepend filler items with
    :func:`chain`:

        >>> iterable = [1, 2, 3, 4]
        >>> n = 3
        >>> padding = [None] * (n - 1)
        >>> list(windowed(chain(padding, iterable), 3))
        [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    if n <= 0:
        raise ValueError('n must be > 0')
    if step < 1:
        raise ValueError('step must be >= 1')

    source = iter(seq)

    # Read the first window directly. With maxlen == n the deque then acts as
    # a ring buffer: every append drops the oldest item.
    ring = deque(islice(source, n), maxlen=n)
    filled = len(ring)

    if filled == 0:
        return
    if filled < n:
        # Input shorter than one window: emit a single padded window.
        yield tuple(ring) + (fillvalue,) * (n - filled)
        return
    yield tuple(ring)

    # Append just enough trailing fill values to complete the final window.
    pad_count = n - 1 if step >= n else step - 1
    appends = map(ring.append, chain(source, repeat(fillvalue, pad_count)))

    # Every step-th append completes the next window.
    for _ in islice(appends, step - 1, None, step):
        yield tuple(ring)
def substrings(iterable):
    """Yield every contiguous run of items in *iterable*, shortest first.

        >>> [''.join(s) for s in substrings('more')]
        ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Non-string iterables work too; each run is yielded as a tuple.

        >>> list(substrings([0, 1, 2]))
        [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    Similar to subslices(), but yields tuples instead of lists and
    produces the shortest runs first.

    """
    items = tuple(iterable)
    total = len(items)
    for width in range(1, total + 1):
        for begin in range(total - width + 1):
            yield items[begin : begin + width]
def substrings_indexes(seq, reverse=False):
    """Yield ``(substr, i, j)`` for every contiguous slice of *seq*, where
    ``substr == seq[i:j]``, shortest slices first.

    *seq* must support ``len()`` and slicing, as ``str`` objects do.

        >>> for item in substrings_indexes('more'):
        ...     print(item)
        ('m', 0, 1)
        ('o', 1, 2)
        ('r', 2, 3)
        ('e', 3, 4)
        ('mo', 0, 2)
        ('or', 1, 3)
        ('re', 2, 4)
        ('mor', 0, 3)
        ('ore', 1, 4)
        ('more', 0, 4)

    Pass ``reverse=True`` to visit the slice widths from longest to shortest.

    """

    def _walk(widths):
        for width in widths:
            for start in range(len(seq) - width + 1):
                stop = start + width
                yield seq[start:stop], start, stop

    # len(seq) is evaluated now, not lazily, matching the eager check callers
    # already rely on for inputs without a length.
    widths = range(1, len(seq) + 1)
    return _walk(reversed(widths) if reverse else widths)
class bucket:
    """Wrap *iterable* and return an object that buckets the iterable into
    child iterables based on a *key* function.

        >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
        >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
        >>> sorted(list(s))  # Get the keys
        ['a', 'b', 'c']
        >>> a_iterable = s['a']
        >>> next(a_iterable)
        'a1'
        >>> next(a_iterable)
        'a2'
        >>> list(s['b'])
        ['b1', 'b2', 'b3']

    The original iterable will be advanced and its items will be cached until
    they are used by the child iterables. This may require significant storage.

    By default, attempting to select a bucket to which no items belong will
    exhaust the iterable and cache all values.
    If you specify a *validator* function, selected buckets will instead be
    checked against it.

        >>> from itertools import count
        >>> it = count(1, 2)  # Infinite sequence of odd numbers
        >>> key = lambda x: x % 10  # Bucket by last digit
        >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
        >>> s = bucket(it, key=key, validator=validator)
        >>> 2 in s
        False
        >>> list(s[2])
        []

    Looking up a key that has no items (with ``in`` or ``[]``) does not add
    that key to the ones produced by iterating over the bucket.

    .. seealso:: :func:`map_reduce`, :func:`groupby_transform`

    If storage is not a concern, :func:`map_reduce` returns a Python
    dictionary, which is generally easier to work with. If the elements
    with the same key are already adjacent, :func:`groupby_transform`
    or :func:`itertools.groupby` can be used without any caching overhead.

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        # key value -> deque of items pulled from the source but not yet
        # handed out by a child iterator.
        self._cache = defaultdict(deque)
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        if not self._validator(value):
            return False

        try:
            item = next(self[value])
        except StopIteration:
            return False
        else:
            # Put the probed item back so membership testing is non-destructive.
            self._cache[value].appendleft(item)

        return True

    def _get_values(self, value):
        """
        Helper to yield items from the parent iterator that match *value*.
        Items that don't match are stored in the local cache as they
        are encountered.
        """
        while True:
            # If we've cached some items that match the target value, emit
            # the first one and evict it from the cache. Use .get() rather
            # than indexing: indexing the defaultdict would insert an empty
            # deque for a key with no items, and that phantom key would then
            # show up when iterating over the bucket.
            cached = self._cache.get(value)
            if cached:
                yield cached.popleft()
            # Otherwise we need to advance the parent iterator to search for
            # a matching item, caching the rest.
            else:
                while True:
                    try:
                        item = next(self._it)
                    except StopIteration:
                        return
                    item_value = self._key(item)
                    if item_value == value:
                        yield item
                        break
                    elif self._validator(item_value):
                        self._cache[item_value].append(item)

    def __iter__(self):
        # Exhaust the source so that every key is known, then iterate keys.
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        if not self._validator(value):
            return iter(())

        return self._get_values(value)
def spy(iterable, n=1):
    """Peek at the first *n* items of *iterable* without losing them.

    Returns a 2-tuple ``(head, iterator)``: *head* is a list holding up to
    *n* leading items, and *iterator* still produces every item of
    *iterable*, including the ones already placed in *head*.

    With the default ``n=1`` the list holds a single item:

    >>> iterable = 'abcdefg'
    >>> head, iterable = spy(iterable)
    >>> head
    ['a']
    >>> list(iterable)
    ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    Unpacking can pull the items straight out of the list:

    >>> (head,), iterable = spy('abcdefg')
    >>> head
    'a'
    >>> (first, second), iterable = spy('abcdefg', 2)
    >>> first
    'a'
    >>> second
    'b'

    Requesting more items than *iterable* contains is fine; *head* is
    simply shorter:

    >>> iterable = [1, 2, 3, 4, 5]
    >>> head, iterable = spy(iterable, 10)
    >>> head
    [1, 2, 3, 4, 5]
    >>> list(iterable)
    [1, 2, 3, 4, 5]

    """
    # One copy is handed back untouched; the other is used for the peek.
    untouched, lookahead = tee(iterable)
    head = take(n, lookahead)
    return head, untouched
def interleave(*iterables):
    """Yield one item from each of *iterables* in rotation, stopping as soon
    as the shortest of them runs out.

    >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7]

    To keep going past the shortest iterable, use
    :func:`interleave_longest`.

    """
    # Each zipped row holds one item per input; flattening the rows gives
    # the round-robin order.
    rows = zip(*iterables)
    return chain.from_iterable(rows)
def interleave_longest(*iterables):
    """Yield one item from each of *iterables* in rotation, dropping each
    iterable from the rotation once it is exhausted.

    >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7, 3, 8]

    The output matches :func:`roundrobin`, but this version may be faster
    for some inputs, particularly when there are many iterables.

    """
    # Missing positions are padded with the private sentinel and skipped.
    for column in zip_longest(*iterables, fillvalue=_marker):
        for value in column:
            if value is _marker:
                continue
            yield value
def interleave_evenly(iterables, lengths=None):
    """
    Interleave multiple iterables so that their elements are evenly distributed
    throughout the output sequence.

    >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
    >>> list(interleave_evenly(iterables))
    [1, 2, 'a', 3, 4, 'b', 5]

    >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
    >>> list(interleave_evenly(iterables))
    [1, 6, 4, 2, 7, 3, 8, 5]

    This function requires iterables of known length. Iterables without
    ``__len__()`` can be used by manually specifying lengths with *lengths*:

    >>> from itertools import combinations, repeat
    >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
    >>> lengths = [4 * (4 - 1) // 2, 3]
    >>> list(interleave_evenly(iterables, lengths=lengths))
    [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    If *iterables* is empty, nothing is yielded:

    >>> list(interleave_evenly([]))
    []

    Based on Bresenham's algorithm.
    """
    if lengths is None:
        try:
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )
    elif len(iterables) != len(lengths):
        raise ValueError('Mismatching number of iterables and lengths.')

    dims = len(lengths)
    # Nothing to interleave. Without this guard, ``lengths_desc[0]`` below
    # would raise IndexError.
    if not dims:
        return

    # sort iterables by length, descending
    lengths_permute = sorted(
        range(dims), key=lambda i: lengths[i], reverse=True
    )
    lengths_desc = [lengths[i] for i in lengths_permute]
    iters_desc = [iter(iterables[i]) for i in lengths_permute]

    # the longest iterable is the primary one (Bresenham: the longest
    # distance along an axis)
    delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
    iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
    errors = [delta_primary // dims] * len(deltas_secondary)

    to_yield = sum(lengths)
    while to_yield:
        yield next(iter_primary)
        to_yield -= 1
        # update errors for each secondary iterable
        errors = [e - delta for e, delta in zip(errors, deltas_secondary)]

        # those iterables for which the error is negative are yielded
        # ("diagonal step" in Bresenham)
        for i, e_ in enumerate(errors):
            if e_ < 0:
                yield next(iters_secondary[i])
                to_yield -= 1
                errors[i] += delta_primary
def interleave_randomly(*iterables):
    """Repeatedly select one of the input *iterables* at random and yield the next
    item from it.

    >>> iterables = [1, 2, 3], 'abc', (True, False, None)
    >>> list(interleave_randomly(*iterables))  # doctest: +SKIP
    ['a', 'b', 1, 'c', True, False, None, 2, 3]

    The relative order of the items within each input iterable is preserved.
    Note that the possible output sequences are not all equally likely to be
    generated.

    """
    iterators = [iter(e) for e in iterables]
    while iterators:
        idx = randrange(len(iterators))
        try:
            yield next(iterators[idx])
        except StopIteration:
            # Remove the exhausted iterator in O(1): overwrite its slot with
            # the last iterator, then drop the last slot. Equivalent to
            # ``list.pop(idx)`` here because the order of ``iterators`` does
            # not matter when the next pick is random.
            iterators[idx] = iterators[-1]
            del iterators[-1]
def collapse(iterable, base_type=None, levels=None):
    """Recursively flatten nested iterables (e.g. a list of lists of tuples)
    down to their non-iterable leaves.

    >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
    >>> list(collapse(iterable))
    [1, 2, 3, 4, 5, 6]

    Text and binary strings count as leaves and are never split apart.

    Pass *base_type* to treat additional types as leaves:

    >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
    >>> list(collapse(iterable, base_type=tuple))
    ['ab', ('cd', 'ef'), 'gh', 'ij']

    Pass *levels* to limit how many levels of nesting are removed:

    >>> iterable = [('a', ['b']), ('c', ['d'])]
    >>> list(collapse(iterable))  # Fully flattened
    ['a', 'b', 'c', 'd']
    >>> list(collapse(iterable, levels=1))  # Only one level flattened
    ['a', ['b'], 'c', ['d']]

    """
    # LIFO stack of (depth, iterator-over-nodes) pairs. The top-level
    # iterable itself starts as the only node at depth 0.
    pending = [(0, repeat(iterable, 1))]

    while pending:
        frame = pending.pop()
        depth, nodes = frame

        # Past the depth limit: emit the remaining nodes untouched.
        if levels is not None and depth > levels:
            yield from nodes
            continue

        for node in nodes:
            is_leaf = isinstance(node, (str, bytes)) or (
                base_type is not None and isinstance(node, base_type)
            )
            if is_leaf:
                yield node
                continue
            try:
                children = iter(node)
            except TypeError:
                yield node
            else:
                # Resume this frame later, descending into the child first.
                pending.append(frame)
                pending.append((depth + 1, children))
                break
def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Call *func* on each item of *iterable* (or on each group of
    *chunk_size* items) just before that item is yielded.

    *func* receives a single argument and its return value is ignored.

    *before* and *after* are optional zero-argument callables. *before*
    runs before iteration begins and *after* runs once it ends, including
    when the generator is closed early or an exception propagates.

    This is handy for logging, progress bars, and other deliberately
    impure actions.

    Printing a status message:

    >>> from more_itertools import consume
    >>> func = lambda item: print('Received {}'.format(item))
    >>> consume(side_effect(func, range(2)))
    Received 0
    Received 1

    Acting on chunks of items:

    >>> pair_sums = []
    >>> func = lambda chunk: pair_sums.append(sum(chunk))
    >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
    [0, 1, 2, 3, 4, 5]
    >>> list(pair_sums)
    [1, 5, 9]

    Writing to a file-like object:

    >>> from io import StringIO
    >>> from more_itertools import consume
    >>> f = StringIO()
    >>> func = lambda x: print(x, file=f)
    >>> before = lambda: print(u'HEADER', file=f)
    >>> after = f.close
    >>> it = [u'a', u'b', u'c']
    >>> consume(side_effect(func, it, before=before, after=after))
    >>> f.closed
    True

    """
    try:
        if before is not None:
            before()

        if chunk_size is not None:
            # Hand each whole chunk to func, then emit its items one by one.
            for group in chunked(iterable, chunk_size):
                func(group)
                yield from group
        else:
            for element in iterable:
                func(element)
                yield element
    finally:
        if after is not None:
            after()
def sliced(seq, n, strict=False):
    """Yield slices of length *n* from the sequence *seq*.

    >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
    [(1, 2, 3), (4, 5, 6)]

    By default, the last yielded slice will have fewer than *n* elements
    if the length of *seq* is not divisible by *n*:

    >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
    [(1, 2, 3), (4, 5, 6), (7, 8)]

    If the length of *seq* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    slice is yielded.

    This function will only work for iterables that support slicing.
    For non-sliceable iterables, see :func:`chunked`.

    """
    # Slice successive windows and stop at the first empty one.
    iterator = takewhile(len, (seq[i : i + n] for i in count(0, n)))
    if strict:

        def ret():
            # Only the final slice can be short; reject it lazily so that
            # every full slice before it is still yielded.
            for _slice in iterator:
                if len(_slice) != n:
                    raise ValueError("seq is not divisible by n.")
                yield _slice

        return ret()
    else:
        return iterator
def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Break *iterable* into lists, using every item for which the callable
    *pred* returns ``True`` as a delimiter.

    >>> list(split_at('abcdcba', lambda x: x == 'b'))
    [['a'], ['c', 'd', 'c'], ['a']]

    >>> list(split_at(range(10), lambda n: n % 2 == 1))
    [[0], [2], [4], [6], [8], []]

    *maxsplit* caps the number of splits performed; the default of -1
    means unlimited:

    >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
    [[0], [2], [4, 5, 6, 7, 8, 9]]

    Delimiters are discarded unless *keep_separator* is ``True``, in which
    case each one is emitted as its own single-item list:

    >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
    [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    current = []
    source = iter(iterable)
    for element in source:
        if not pred(element):
            current.append(element)
            continue

        yield current
        if keep_separator:
            yield [element]
        if maxsplit == 1:
            # Split budget used up: everything left is the final piece.
            yield list(source)
            return
        current = []
        maxsplit -= 1

    yield current
def split_before(iterable, pred, maxsplit=-1):
    """Break *iterable* into lists, starting a new list just before each
    item for which the callable *pred* returns ``True``:

    >>> list(split_before('OneTwo', lambda s: s.isupper()))
    [['O', 'n', 'e'], ['T', 'w', 'o']]

    >>> list(split_before(range(10), lambda n: n % 3 == 0))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    *maxsplit* caps the number of splits performed; the default of -1
    means unlimited:

    >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    group = []
    source = iter(iterable)
    for element in source:
        # pred is evaluated for every item, even while group is empty.
        starts_new_group = pred(element)
        if starts_new_group and group:
            yield group
            if maxsplit == 1:
                yield [element, *source]
                return
            group = []
            maxsplit -= 1
        group.append(element)

    if group:
        yield group
def split_after(iterable, pred, maxsplit=-1):
    """Break *iterable* into lists, closing the current list right after
    each item for which the callable *pred* returns ``True``:

    >>> list(split_after('one1two2', lambda s: s.isdigit()))
    [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

    >>> list(split_after(range(10), lambda n: n % 3 == 0))
    [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    *maxsplit* caps the number of splits performed; the default of -1
    means unlimited:

    >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    group = []
    source = iter(iterable)
    for element in source:
        group.append(element)
        # group always holds at least ``element`` here, so the predicate
        # alone decides whether to close it.
        if not pred(element):
            continue

        yield group
        if maxsplit == 1:
            remainder = list(source)
            if remainder:
                yield remainder
            return
        group = []
        maxsplit -= 1

    if group:
        yield group
def split_when(iterable, pred, maxsplit=-1):
    """Break *iterable* into lists wherever *pred* says so.

    *pred* is called with each pair of neighbouring items ``(x, y)`` and
    should return ``True`` when a split belongs between them.

    For example, to find runs of non-decreasing numbers, split whenever an
    item is larger than the one after it:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
    [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    *maxsplit* caps the number of splits performed; the default of -1
    means unlimited:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
    ...                 lambda x, y: x > y, maxsplit=2))
    [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    source = iter(iterable)
    try:
        previous = next(source)
    except StopIteration:
        # Empty input yields no groups at all.
        return

    group = [previous]
    for current in source:
        if pred(previous, current):
            yield group
            if maxsplit == 1:
                yield [current, *source]
                return
            group = []
            maxsplit -= 1
        group.append(current)
        previous = current

    yield group
def split_into(iterable, sizes):
    """Yield a list of sequential items from *iterable* of length 'n' for each
    integer 'n' in *sizes*.

    >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
    [[1], [2, 3], [4, 5, 6]]

    If the sum of *sizes* is smaller than the length of *iterable*, then the
    remaining items of *iterable* will not be returned.

    >>> list(split_into([1,2,3,4,5,6], [2,3]))
    [[1, 2], [3, 4, 5]]

    If the sum of *sizes* is larger than the length of *iterable*, fewer items
    will be returned in the iteration that overruns the *iterable* and further
    lists will be empty:

    >>> list(split_into([1,2,3,4], [1,2,3,4]))
    [[1], [2, 3], [4], []]

    When a ``None`` object is encountered in *sizes*, the returned list will
    contain items up to the end of *iterable* the same way that
    :func:`itertools.islice` does, and no further sizes are consumed:

    >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
    [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    :func:`split_into` can be useful for grouping a series of items where the
    sizes of the groups are not uniform. An example would be where in a row
    from a table, multiple columns represent elements of the same feature
    (e.g. a point represented by x,y,z) but, the format is not the same for
    all columns.
    """
    # Convert the iterable into a single iterator so successive islice calls
    # continue where the previous one stopped, even for generators.
    it = iter(iterable)

    for size in sizes:
        if size is None:
            yield list(it)
            return
        yield list(islice(it, size))
def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the items of *iterable* followed by copies of *fillvalue* so
    that at least *n* items come out in total.

    >>> list(padded([1, 2, 3], '?', 5))
    [1, 2, 3, '?', '?']

    With *next_multiple* set to ``True``, padding continues until the total
    number of items is a multiple of *n*:

    >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
    [1, 2, 3, 4, None, None]

    When *n* is ``None``, *fillvalue* is repeated forever.

    To get exactly *n* items, truncate the result with :func:`islice`:

    >>> list(islice(padded([1, 2, 3], '?'), 5))
    [1, 2, 3, '?', '?']
    >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
    [1, 2, 3, 4, 5]

    """
    source = iter(iterable)
    endless = chain(source, repeat(fillvalue))

    if n is None:
        return endless
    if n < 1:
        raise ValueError('n must be at least 1')
    if not next_multiple:
        # The first n items (padded if needed), then whatever real items
        # remain after them.
        return chain(islice(endless, n), source)

    def _blocks():
        # Each real item that starts a block is followed by n - 1 more
        # items, taken from the source first and then from the padding.
        for leader in source:
            yield (leader,)
            yield islice(endless, n - 1)

    return chain.from_iterable(_blocks())
def repeat_each(iterable, n=2):
    """Emit every item of *iterable* *n* times in a row.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    return chain.from_iterable(repeat(element, n) for element in iterable)
def repeat_last(iterable, default=None):
    """Yield every item of *iterable*, then repeat the final item forever.

    >>> list(islice(repeat_last(range(3)), 5))
    [0, 1, 2, 2, 2]

    An empty *iterable* produces *default* forever instead::

        >>> list(islice(repeat_last(range(0), 42), 5))
        [42, 42, 42, 42, 42]

    """
    last = _marker
    for element in iterable:
        yield element
        last = element
    # Nothing was seen, so fall back to the caller-supplied default.
    if last is _marker:
        last = default
    yield from repeat(last)
def distribute(n, iterable):
    """Deal the items of *iterable* round-robin into *n* smaller iterables.

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]

    When the length of *iterable* is not a multiple of *n*, the children
    end up with different lengths:

    >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 4, 7], [2, 5], [3, 6]]

    When *iterable* has fewer than *n* items, the trailing children are
    empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    Because this relies on :func:`itertools.tee`, it may need a lot of
    storage.

    To keep each child's items contiguous in the original order, use
    :func:`divide` instead.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Child k takes items k, k + n, k + 2n, ... from its own tee copy.
    result = []
    for start, child in enumerate(tee(iterable, n)):
        result.append(islice(child, start, None, n))
    return result
def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples of items taken from *iterable* at shifted positions.

    Position ``i`` of each tuple is shifted by ``offsets[i]``.

    >>> list(stagger([0, 1, 2, 3]))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
    >>> list(stagger(range(8), offsets=(0, 2, 4)))
    [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    Output normally stops once the last slot of a tuple reaches the final
    item. With *longest* set to ``True`` it continues until the first slot
    reaches the final item::

        >>> list(stagger([0, 1, 2, 3], longest=True))
        [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    Slots that fall outside the sequence are filled with *fillvalue*, which
    defaults to ``None``.

    """
    # One independent copy of the input per offset.
    copies = tee(iterable, len(offsets))
    return zip_offset(
        *copies, offsets=offsets, longest=longest, fillvalue=fillvalue
    )
def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """Like ``zip``, but shift the ``i``-th iterable by ``offsets[i]``.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    Useful as a lightweight alternative to SciPy or pandas when some data
    series lead or lag others.

    Output stops when the shortest shifted iterable runs out. Set *longest*
    to ``True`` to continue until the longest one does:

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    Missing positions are filled with *fillvalue*, which defaults to
    ``None``.

    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    def _shift(source, offset):
        # A negative offset delays the source by prepending fill values;
        # a positive offset advances it by skipping leading items.
        if offset < 0:
            return chain(repeat(fillvalue, -offset), source)
        if offset > 0:
            return islice(source, offset, None)
        return source

    shifted = [_shift(src, off) for src, off in zip(iterables, offsets)]

    if not longest:
        return zip(*shifted)
    return zip_longest(*shifted, fillvalue=fillvalue)
1931def sort_together(
1932 iterables, key_list=(0,), key=None, reverse=False, strict=False
1933):
1934 """Return the input iterables sorted together, with *key_list* as the
1935 priority for sorting. All iterables are trimmed to the length of the
1936 shortest one.
1938 This can be used like the sorting function in a spreadsheet. If each
1939 iterable represents a column of data, the key list determines which
1940 columns are used for sorting.
1942 By default, all iterables are sorted using the ``0``-th iterable::
1944 >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
1945 >>> sort_together(iterables)
1946 [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
1948 Set a different key list to sort according to another iterable.
1949 Specifying multiple keys dictates how ties are broken::
1951 >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
1952 >>> sort_together(iterables, key_list=(1, 2))
1953 [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
1955 To sort by a function of the elements of the iterable, pass a *key*
1956 function. Its arguments are the elements of the iterables corresponding to
1957 the key list::
1959 >>> names = ('a', 'b', 'c')
1960 >>> lengths = (1, 2, 3)
1961 >>> widths = (5, 2, 1)
1962 >>> def area(length, width):
1963 ... return length * width
1964 >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
1965 [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
1967 Set *reverse* to ``True`` to sort in descending order.
1969 >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
1970 [(3, 2, 1), ('a', 'b', 'c')]
1972 If the *strict* keyword argument is ``True``, then
1973 ``ValueError`` will be raised if any of the iterables have
1974 different lengths.
1976 """
1977 if key is None:
1978 # if there is no key function, the key argument to sorted is an
1979 # itemgetter
1980 key_argument = itemgetter(*key_list)
1981 else:
1982 # if there is a key function, call it with the items at the offsets
1983 # specified by the key function as arguments
1984 key_list = list(key_list)
1985 if len(key_list) == 1:
1986 # if key_list contains a single item, pass the item at that offset
1987 # as the only argument to the key function
1988 key_offset = key_list[0]
1989 key_argument = lambda zipped_items: key(zipped_items[key_offset])
1990 else:
1991 # if key_list contains multiple items, use itemgetter to return a
1992 # tuple of items, which we pass as *args to the key function
1993 get_key_items = itemgetter(*key_list)
1994 key_argument = lambda zipped_items: key(
1995 *get_key_items(zipped_items)
1996 )
1998 transposed = zip(*iterables, strict=strict)
1999 reordered = sorted(transposed, key=key_argument, reverse=reverse)
2000 untransposed = zip(*reordered, strict=strict)
2001 return list(untransposed)
def unzip(iterable):
    """Undo :func:`zip`: split a zipped *iterable* back into its columns.

    Column ``i`` yields the ``i``-th element of every item of *iterable*.
    The number of columns is taken from the length of the first item.

    >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> letters, numbers = unzip(iterable)
    >>> list(letters)
    ['a', 'b', 'c', 'd']
    >>> list(numbers)
    [1, 2, 3, 4]

    This behaves like ``zip(*iterable)`` without first reading *iterable*
    into memory. It does, however, rely on :func:`itertools.tee`, which may
    need a lot of storage.

    """
    first, iterable = spy(iterable)
    if not first:
        # Nothing to unzip, e.g. the output of zip([], [], []).
        return ()

    # spy() with the default n=1 returns a one-element list here.
    (sample,) = first
    copies = tee(iterable, len(sample))

    # Items shorter than the first one (e.g. iter([(1, 2, 3), (4, 5), (6,)]))
    # would raise IndexError when a column reaches them. Suppressing that
    # error ends the affected column at the first short item instead.
    return tuple(
        iter_suppress(map(itemgetter(index), copy_), IndexError)
        for index, copy_ in enumerate(copies)
    )
def divide(n, iterable):
    """Divide the elements from *iterable* into *n* parts, maintaining
    order.

    >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 2, 3]
    >>> list(group_2)
    [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

    >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than n, then the last returned
    iterables will be empty:

    >>> children = divide(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    This function will exhaust the iterable before returning.
    If order is not important, see :func:`distribute`, which does not first
    pull the iterable into memory.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Slice the input directly when it supports slicing; otherwise read it
    # into a tuple. Mappings do not support slicing either, but on Python
    # 3.12+ (where slice objects are hashable) they reject ``[:0]`` with
    # KeyError rather than TypeError, so both errors mean "not sliceable".
    try:
        iterable[:0]
    except (TypeError, KeyError):
        seq = tuple(iterable)
    else:
        seq = iterable

    # The first r parts get q + 1 items each; the remaining parts get q.
    q, r = divmod(len(seq), n)

    ret = []
    stop = 0
    for i in range(1, n + 1):
        start = stop
        stop += q + 1 if i <= r else q
        ret.append(iter(seq[start:stop]))

    return ret
def always_iterable(obj, base_type=(str, bytes)):
    """Return an iterator over *obj*'s items when *obj* is iterable::

        >>> obj = (1, 2, 3)
        >>> list(always_iterable(obj))
        [1, 2, 3]

    A non-iterable *obj* is wrapped in a single-item iterator::

        >>> obj = 1
        >>> list(always_iterable(obj))
        [1]

    ``None`` produces an empty iterator:

        >>> obj = None
        >>> list(always_iterable(None))
        []

    Text and binary strings are treated as single values by default::

        >>> obj = 'foo'
        >>> list(always_iterable(obj))
        ['foo']

    Any object that is an instance of *base_type* is likewise treated as a
    single value rather than iterated.

        >>> obj = {'a': 1}
        >>> list(always_iterable(obj))  # Iterate over the dict's keys
        ['a']
        >>> list(always_iterable(obj, base_type=dict))  # Treat dicts as a unit
        [{'a': 1}]

    Passing ``base_type=None`` disables that special case, so anything
    Python can iterate over is iterated:

        >>> obj = 'foo'
        >>> list(always_iterable(obj, base_type=None))
        ['f', 'o', 'o']
    """
    if obj is None:
        return iter(())

    treat_as_single = base_type is not None and isinstance(obj, base_type)
    if not treat_as_single:
        try:
            return iter(obj)
        except TypeError:
            # Not iterable: fall through and wrap it.
            pass
    return iter((obj,))
def adjacent(predicate, iterable, distance=1):
    """Return an iterable of ``(bool, item)`` pairs, one per item of
    *iterable*. The flag is ``True`` when the item satisfies *predicate* or
    lies within *distance* positions of an item that does.

    For example, to mark items at or next to a ``3``::

        >>> list(adjacent(lambda x: x == 3, range(6)))
        [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    Widen the neighbourhood with *distance*, e.g. two places on each side:

        >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
        [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    This helps put search hits in context. For instance, a diff tool might
    show the changed lines together with a few surrounding lines.

    *predicate* is evaluated exactly once per item.

    :func:`groupby_transform` can be combined with this function to group
    runs of items that share the same flag.

    """
    # distance=0 is accepted, mainly so tests can check it reduces to map().
    if distance < 0:
        raise ValueError('distance must be at least 0')

    to_test, to_emit = tee(iterable)
    # Pad both ends with False so edge items still get full-size windows.
    pad = [False] * distance
    hits = chain(pad, map(predicate, to_test), pad)
    window_size = 2 * distance + 1
    near_hit = map(any, windowed(hits, window_size))
    return zip(near_hit, to_emit)
def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
    """An extension of :func:`itertools.groupby` that can apply transformations
    to the grouped data.

    * *keyfunc* computes the grouping key for each item in *iterable*
    * *valuefunc* transforms each individual item of *iterable* after grouping
    * *reducefunc* transforms each whole group of items

    >>> iterable = 'aAAbBBcCC'
    >>> keyfunc = lambda k: k.upper()
    >>> valuefunc = lambda v: v.lower()
    >>> reducefunc = lambda g: ''.join(g)
    >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
    [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]

    Any optional argument that is left out behaves like the identity function.

    :func:`groupby_transform` is handy for grouping the elements of one
    iterable by the values of a separate key iterable: :func:`zip` them
    together, then pass a *keyfunc* that picks the first element and a
    *valuefunc* that picks the second::

        >>> from operator import itemgetter
        >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
        >>> values = 'abcdefghi'
        >>> iterable = zip(keys, values)
        >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
        >>> [(k, ''.join(g)) for k, g in grouper]
        [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]

    The order of items in *iterable* matters. Only adjacent items are grouped
    together, so to avoid duplicate groups either sort the iterable by the key
    function first or use :func:`bucket` or :func:`map_reduce` instead.
    :func:`map_reduce` consumes the iterable immediately and returns a
    dictionary, whereas :func:`bucket` does not.

    .. seealso:: :func:`bucket`, :func:`map_reduce`

    """
    grouped = groupby(iterable, keyfunc)

    # Each optional stage lazily wraps the previous one; nothing is consumed
    # until the caller iterates over the result.
    if valuefunc:
        grouped = (
            (key, map(valuefunc, members)) for key, members in grouped
        )
    if reducefunc:
        grouped = (
            (key, reducefunc(members)) for key, members in grouped
        )

    return grouped
class numeric_range(Sequence):
    """An extension of the built-in ``range()`` function whose arguments can
    be any orderable numeric type.

    When only *stop* is given, *start* defaults to ``0`` and *step*
    defaults to ``1``. The output items will match the type of *stop*:

    >>> list(numeric_range(3.5))
    [0.0, 1.0, 2.0, 3.0]

    When only *start* and *stop* are given, *step* defaults to ``1``. The
    output items will match the type of *start*:

    >>> from decimal import Decimal
    >>> start = Decimal('2.1')
    >>> stop = Decimal('5.1')
    >>> list(numeric_range(start, stop))
    [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    When *start*, *stop*, and *step* are all given, the output items will
    match the type of ``start + step``:

    >>> from fractions import Fraction
    >>> start = Fraction(1, 2)  # Start at 1/2
    >>> stop = Fraction(5, 2)  # End at 5/2
    >>> step = Fraction(1, 2)  # Count by 1/2
    >>> list(numeric_range(start, stop, step))
    [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    A zero *step* raises ``ValueError``. Negative steps are supported:

    >>> list(numeric_range(3, -1, -1.0))
    [3.0, 2.0, 1.0, 0.0]

    Keep the limitations of floating-point numbers in mind; the representation
    of the yielded numbers may be surprising.

    ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
    is a ``datetime.timedelta`` object:

    >>> import datetime
    >>> start = datetime.datetime(2019, 1, 1)
    >>> stop = datetime.datetime(2019, 1, 3)
    >>> step = datetime.timedelta(days=1)
    >>> items = iter(numeric_range(start, stop, step))
    >>> next(items)
    datetime.datetime(2019, 1, 1, 0, 0)
    >>> next(items)
    datetime.datetime(2019, 1, 2, 0, 0)

    """

    # Every empty numeric_range hashes like an empty built-in range; this keeps
    # __hash__ consistent with __eq__, under which all empty ranges are equal.
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        argc = len(args)
        if argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        if argc > 3:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        if argc == 3:
            self._start, self._stop, self._step = args
        else:
            if argc == 1:
                (self._stop,) = args
                # The implicit start is a zero of the same type as *stop*.
                self._start = type(self._stop)(0)
            else:
                self._start, self._stop = args
            # The implicit step is a one of the type of ``stop - start``.
            self._step = type(self._stop - self._start)(1)

        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        self._growing = self._step > self._zero

    def __bool__(self):
        return (
            self._start < self._stop
            if self._growing
            else self._start > self._stop
        )

    def _within_bounds(self, value):
        """Return whether *value* lies between start (inclusive) and stop
        (exclusive) in the direction of travel. Step alignment is not checked.
        """
        if self._growing:
            return self._start <= value < self._stop
        return self._start >= value > self._stop

    def __contains__(self, elem):
        if not self._within_bounds(elem):
            return False
        if self._growing:
            return (elem - self._start) % self._step == self._zero
        return (self._start - elem) % (-self._step) == self._zero

    def __eq__(self, other):
        if not isinstance(other, numeric_range):
            return False
        self_empty = not self
        other_empty = not other
        if self_empty or other_empty:
            # Any two empty ranges compare equal, whatever their bounds.
            return self_empty and other_empty
        return (
            self._start == other._start
            and self._step == other._step
            and self._get_by_index(-1) == other._get_by_index(-1)
        )

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._get_by_index(key)
        if not isinstance(key, slice):
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )
        # Resolve the slice against our length, then map the resulting
        # integer positions back onto values of the range.
        first, last, stride = key.indices(self._len)
        return numeric_range(
            self._start + first * self._step,
            self._start + last * self._step,
            self._step * stride,
        )

    def __hash__(self):
        if not self:
            return self._EMPTY_HASH
        return hash((self._start, self._get_by_index(-1), self._step))

    def __iter__(self):
        values = (self._start + n * self._step for n in count())
        # gt(stop, v) means v < stop; lt(stop, v) means v > stop.
        before_stop = partial(gt if self._growing else lt, self._stop)
        return takewhile(before_stop, values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        if self._growing:
            low, high, stride = self._start, self._stop, self._step
        else:
            low, high, stride = self._stop, self._start, -self._step
        span = high - low
        if span <= self._zero:
            return 0
        # span > 0 and stride > 0, so this is ordinary Euclidean division;
        # a non-zero remainder contributes one more (partial-step) item.
        whole, leftover = divmod(span, stride)
        return int(whole) + int(leftover != self._zero)

    def __reduce__(self):
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        if self._step == 1:
            return f"numeric_range({self._start!r}, {self._stop!r})"
        return (
            f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
        )

    def __reversed__(self):
        try:
            last = self._get_by_index(-1)
        except IndexError:
            # An empty range reverses to an empty iterator.
            return iter([])
        return iter(
            numeric_range(last, self._start - self._step, -self._step)
        )

    def count(self, value):
        return int(value in self)

    def index(self, value):
        if self._within_bounds(value):
            if self._growing:
                whole, leftover = divmod(value - self._start, self._step)
            else:
                whole, leftover = divmod(self._start - value, -self._step)
            if leftover == self._zero:
                return int(whole)
        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        if i < 0:
            i += self._len
        if not 0 <= i < self._len:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step
def count_cycle(iterable, n=None):
    """Cycle through the items from *iterable* up to *n* times, yielding
    the number of completed cycles along with each item. If *n* is omitted the
    process repeats indefinitely.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    """
    if n is not None:
        # Finite case: pair each cycle number in range(n) with every item.
        return product(range(n), iterable)

    seq = tuple(iterable)
    if not seq:
        # Nothing to cycle over; avoid an infinite loop of empty cycles.
        return iter(())

    # Infinite case (n is None here): repeat each cycle number once per item
    # and zip it with the endlessly cycled items.
    return zip(repeat_each(count(), len(seq)), cycle(seq))
def mark_ends(iterable):
    """Yield ``(is_first, is_last, item)`` 3-tuples for each item of
    *iterable*.

    >>> list(mark_ends('ABC'))
    [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    This is useful inside a loop that has to treat the first and/or last
    items specially:

    >>> iterable = ['Header', 100, 200, 'Footer']
    >>> total = 0
    >>> for is_first, is_last, item in mark_ends(iterable):
    ...     if is_first:
    ...         continue  # Skip the header
    ...     if is_last:
    ...         continue  # Skip the footer
    ...     total += item
    >>> print(total)
    300
    """
    source = iter(iterable)
    for pending in source:
        first_flag = True
        # Hold each item back by one step, so that whether it is the last
        # one is known before it is emitted.
        for following in source:
            yield first_flag, False, pending
            pending, first_flag = following, False
        yield first_flag, True, pending
def locate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``.

    *pred* defaults to :func:`bool`, which selects truthy items:

    >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
    [1, 2, 4]

    Set *pred* to a custom function to, e.g., find the indexes of a particular
    item.

    >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
    [1, 3]

    If *window_size* is given, *pred* is called with the values in each
    window, which enables searching for sub-sequences. Near the end of the
    iterable, *pred* may receive fewer than *window_size* arguments.
    ``ValueError`` is raised if *window_size* is less than 1.

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(locate(iterable, pred=pred, window_size=3))
    [1, 5, 9]

    Combine with :func:`seekable` to find indexes and then retrieve the
    associated items:

    >>> from itertools import count
    >>> from more_itertools import seekable
    >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
    >>> it = seekable(source)
    >>> pred = lambda x: x > 100
    >>> indexes = locate(it, pred=pred)
    >>> i = next(indexes)
    >>> it.seek(i)
    >>> next(it)
    106

    """
    if window_size is None:
        flags = map(pred, iterable)
    else:
        if window_size < 1:
            raise ValueError('window size must be at least 1')
        windows = windowed(iterable, window_size, fillvalue=_marker)
        # Drop the padding that windowed() adds to short trailing windows.
        flags = (pred(*[x for x in w if x is not _marker]) for w in windows)
    return compress(count(), flags)
def longest_common_prefix(iterables):
    """Yield the elements of the longest prefix shared by all of *iterables*.

    >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
    'ab'

    """
    # Walk the iterables column by column, stopping at the first column whose
    # entries are not all equal.
    columns = zip(*iterables)
    shared = takewhile(all_equal, columns)
    return (column[0] for column in shared)
def lstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the beginning
    for which *pred* returns ``True``.

    For example, to remove a set of items from the start of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(lstrip(iterable, pred))
    [1, 2, None, 3, False, None]

    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.

    """
    return dropwhile(pred, iterable)
def rstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the end
    for which *pred* returns ``True``.

    For example, to remove a set of items from the end of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(rstrip(iterable, pred))
    [None, False, None, 1, 2, None, 3]

    This function is analogous to :func:`str.rstrip`.

    """
    # Items matching *pred* are held back until a non-matching item shows
    # they were not trailing after all; whatever is still held at the end
    # is the stripped suffix.
    pending = []
    for item in iterable:
        if not pred(item):
            yield from pending
            pending.clear()
            yield item
        else:
            pending.append(item)
def strip(iterable, pred):
    """Yield the items from *iterable*, but strip any from both the
    beginning and the end for which *pred* returns ``True``.

    For example, to remove a set of items from both ends of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(strip(iterable, pred))
    [1, 2, None, 3]

    This function is analogous to :func:`str.strip`.

    """
    leading_removed = lstrip(iterable, pred)
    return rstrip(leading_removed, pred)
class islice_extended:
    """An extension of :func:`itertools.islice` that accepts negative values
    for *stop*, *start*, and *step*.

    >>> iterator = iter('abcdefgh')
    >>> list(islice_extended(iterator, -4, -1))
    ['e', 'f', 'g']

    Slices with negative values require caching part of *iterable*, but
    this function tries to keep the amount of memory used to a minimum.

    For example, a negative step can be used with an infinite iterator:

    >>> from itertools import count
    >>> list(islice_extended(count(), 110, 99, -2))
    [110, 108, 106, 104, 102, 100]

    Slice notation can also be used directly:

    >>> iterator = map(str, count())
    >>> it = islice_extended(iterator)[10:20:2]
    >>> list(it)
    ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        source = iter(iterable)
        # Without slice arguments this is a plain pass-through iterator.
        self._iterator = (
            _islice_helper(source, slice(*args)) if args else source
        )

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        if not isinstance(key, slice):
            raise TypeError(
                'islice_extended.__getitem__ argument must be a slice'
            )
        return islice_extended(_islice_helper(self._iterator, key))
def _islice_helper(it, s):
    """Yield the items of the iterator *it* selected by the slice object *s*.

    This does the work for :class:`islice_extended`. In contrast to
    :func:`itertools.islice`, ``s.start``, ``s.stop`` and ``s.step`` may be
    negative. Negative positions are counted from the end of *it*, which
    forces some or all of the remaining items to be cached.

    Raises ``ValueError`` if ``s.step`` is ``0``. Because this is a generator
    function, the error surfaces on the first ``next()`` call, not when the
    function is called.
    """
    start = s.start
    stop = s.stop
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    step = s.step or 1

    if step > 0:
        start = 0 if (start is None) else start

        if start < 0:
            # Consume all but the last -start items
            # compress() draws one (always truthy) selector from *counter*
            # per item of *it*, so once the deque has exhausted *it*,
            # next(counter) - 1 is the number of items consumed.
            counter = count(1)
            wrapper = compress(it, counter)
            cache = deque(wrapper, maxlen=-start)
            len_iter = next(counter) - 1

            # Adjust start to be positive
            # (i is also the absolute position of cache[0], because the cache
            # holds the last min(len_iter, -start) items).
            i = max(len_iter + start, 0)

            # Adjust stop to be positive
            if stop is None:
                j = len_iter
            elif stop >= 0:
                j = min(stop, len_iter)
            else:
                j = max(len_iter + stop, 0)

            # Slice the cache
            n = j - i
            if n <= 0:
                return

            for index in range(n):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
        elif (stop is not None) and (stop < 0):
            # Advance to the start position
            # (islice(it, start, start) yields nothing but skips *start* items)
            next(islice(it, start, start), None)

            # When stop is negative, we have to carry -stop items while
            # iterating
            cache = deque(islice(it, -stop), maxlen=-stop)

            # Every new item pushes the oldest cached one out. When *it* runs
            # dry, the -stop items still cached are exactly the ones the
            # negative stop excludes, so they are never emitted.
            for index, item in enumerate(it):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
                cache.append(item)
        else:
            # When both start and stop are positive we have the normal case
            # (stop may also be None here)
            yield from islice(it, start, stop, step)
    else:
        start = -1 if (start is None) else start

        if (stop is not None) and (stop < 0):
            # Consume all but the last items
            # Only the last -stop - 1 items can lie strictly after position
            # stop, so no more than that needs to be kept; *counter* again
            # records the total number of items consumed.
            n = -stop - 1
            counter = count(1)
            wrapper = compress(it, counter)
            cache = deque(wrapper, maxlen=n)
            len_iter = next(counter) - 1

            # If start and stop are both negative they are comparable and
            # we can just slice. Otherwise we can adjust start to be negative
            # and then slice.
            if start < 0:
                i, j = start, stop
            else:
                i, j = min(start - len_iter, -1), None

            yield from list(cache)[i:j:step]
        else:
            # Advance to the stop position
            # (skips items 0..stop, which a negative step never reaches)
            if stop is not None:
                m = stop + 1
                next(islice(it, m, m), None)

            # stop is positive, so if start is negative they are not comparable
            # and we need the rest of the items.
            if start < 0:
                i = start
                n = None
            # stop is None and start is positive, so we just need items up to
            # the start index.
            elif stop is None:
                i = None
                n = start + 1
            # Both stop and start are positive, so they are comparable.
            else:
                i = None
                n = start - stop
                if n <= 0:
                    return

            cache = list(islice(it, n))

            yield from cache[i::step]
def always_reversible(iterable):
    """A version of :func:`reversed` that accepts any iterable, not only
    those implementing the ``Reversible`` or ``Sequence`` protocols.

    >>> print(*always_reversible(x for x in range(3)))
    2 1 0

    When *iterable* is already reversible, the result of :func:`reversed()`
    is returned directly. Otherwise the remaining items are cached in a list
    and yielded in reverse order, which may need a lot of storage.
    """
    try:
        backwards = reversed(iterable)
    except TypeError:
        # Not reversible as-is: materialize it first.
        backwards = reversed(list(iterable))
    return backwards
def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items, using :func:`itertools.groupby`.
    The *ordering* function maps each item to its position, and two items
    are adjacent when their positions differ by one.

    By default, the ordering function is the identity function, which is
    suitable for finding runs of numbers:

    >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
    >>> for group in consecutive_groups(iterable):
    ...     print(list(group))
    [1]
    [10, 11, 12]
    [20]
    [30, 31, 32, 33]
    [40]

    To find runs of adjacent letters, apply the :func:`ord` function
    to convert letters to ordinals.

    >>> iterable = 'abcdfgilmnop'
    >>> ordering = ord
    >>> for group in consecutive_groups(iterable, ordering):
    ...     print(list(group))
    ['a', 'b', 'c', 'd']
    ['f', 'g']
    ['i']
    ['l', 'm', 'n', 'o', 'p']

    Each group of consecutive items is an iterator that shares its source with
    *iterable*. When an output group is advanced, the previous group is
    no longer available unless its elements are copied (e.g., into a ``list``).

    >>> iterable = [1, 2, 11, 12, 21, 22]
    >>> saved_groups = []
    >>> for group in consecutive_groups(iterable):
    ...     saved_groups.append(list(group))  # Copy group elements
    >>> saved_groups
    [[1, 2], [11, 12], [21, 22]]

    """

    def _run_key(pair):
        # Within a run, index and position rise together, so their
        # difference stays constant and groupby() keeps them together.
        index, item = pair
        position = item if ordering is None else ordering(item)
        return index - position

    for _, run in groupby(enumerate(iterable), key=_run_key):
        yield map(itemgetter(1), run)
def difference(iterable, func=sub, *, initial=None):
    """The inverse of :func:`itertools.accumulate`. By default it computes
    the first difference of *iterable* using :func:`operator.sub`:

    >>> from itertools import accumulate
    >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
    >>> list(difference(iterable))
    [0, 1, 2, 3, 4]

    *func* defaults to :func:`operator.sub`, but any two-argument function
    can be supplied. It is applied as follows::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, to undo progressive multiplication:

    >>> iterable = [1, 2, 6, 24, 120]
    >>> func = lambda x, y: x // y
    >>> list(difference(iterable, func))
    [1, 2, 3, 4, 5]

    If the *initial* keyword is set, the first element is dropped from the
    output and only the successive differences are yielded.

    >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
    >>> list(difference(it, initial=10))
    [1, 2, 3]

    """
    # *right* runs one item ahead of *left*, so map() pairs each item with
    # its predecessor.
    left, right = tee(iterable)
    try:
        head = next(right)
    except StopIteration:
        return iter([])

    diffs = map(func, right, left)
    if initial is not None:
        return diffs
    return chain([head], diffs)
class SequenceView(Sequence):
    """Return a read-only view of the sequence object *target*.

    :class:`SequenceView` objects are analogous to Python's built-in
    "dictionary view" types. They give a live view of a sequence's items:
    when the underlying sequence changes, so does the view.

    >>> seq = ['0', '1', '2']
    >>> view = SequenceView(seq)
    >>> view
    SequenceView(['0', '1', '2'])
    >>> seq.append('3')
    >>> view
    SequenceView(['0', '1', '2', '3'])

    Sequence views support indexing, slicing, and length queries. They
    behave like the underlying sequence, except that assignment is not
    allowed:

    >>> view[1]
    '1'
    >>> view[1:-1]
    ['1', '2']
    >>> len(view)
    4

    Sequence views are an alternative to copying, since they need little
    extra storage.

    """

    def __init__(self, target):
        if isinstance(target, Sequence):
            self._target = target
        else:
            raise TypeError

    def __getitem__(self, index):
        # Delegate directly, so slices come back as the target's own type.
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        return f'{self.__class__.__name__}({self._target!r})'
class seekable:
    """Wrap an iterator so that it can be moved backward and forward. Items
    pulled from the source iterable are cached as they are produced, so they
    can be revisited later.

    Call :meth:`seek` with an index to move to that position in the source
    iterable.

    To "reset" an iterator, seek to ``0``:

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.seek(0)
    >>> next(it), next(it), next(it)
    ('0', '1', '2')

    Seeking forward works too:

    >>> it = seekable((str(n) for n in range(20)))
    >>> it.seek(10)
    >>> next(it)
    '10'
    >>> it.seek(20)  # Seeking past the end of the source isn't a problem
    >>> list(it)
    []
    >>> it.seek(0)  # Resetting works even after hitting the end
    >>> next(it)
    '0'

    Call :meth:`relative_seek` to move relative to the source iterator's
    current position.

    >>> it = seekable((str(n) for n in range(20)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.relative_seek(2)
    >>> next(it)
    '5'
    >>> it.relative_seek(-3)  # Source is at '6', we move back to '3'
    >>> next(it)
    '3'
    >>> it.relative_seek(-3)  # Source is at '4', we move back to '1'
    >>> next(it)
    '1'

    Call :meth:`peek` to look one item ahead without advancing the iterator:

    >>> it = seekable('1234')
    >>> it.peek()
    '1'
    >>> list(it)
    ['1', '2', '3', '4']
    >>> it.peek(default='empty')
    'empty'

    :func:`bool` returns ``True`` until the iterator is exhausted, and
    ``False`` afterwards:

    >>> it = seekable('5678')
    >>> bool(it)
    True
    >>> list(it)
    ['5', '6', '7', '8']
    >>> bool(it)
    False

    The :meth:`elements` method exposes the cache as a
    :class:`SequenceView`, which updates automatically:

    >>> it = seekable((str(n) for n in range(10)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> elements = it.elements()
    >>> elements
    SequenceView(['0', '1', '2'])
    >>> next(it)
    '3'
    >>> elements
    SequenceView(['0', '1', '2', '3'])

    Indexing the :class:`seekable` directly reads from the cache, which is
    handy for inspecting the most recently produced item:

    >>> it = seekable((str(n) for n in range(10)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it[-1]
    '2'
    >>> it[0]
    '0'

    By default the cache grows as the source iterable advances, so be careful
    when wrapping very large or infinite iterables. Pass *maxlen* to bound
    the size of the cache (which also bounds how far back you can seek).

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()), maxlen=2)
    >>> next(it), next(it), next(it), next(it)
    ('0', '1', '2', '3')
    >>> list(it.elements())
    ['2', '3']
    >>> it.seek(0)
    >>> next(it), next(it), next(it), next(it)
    ('2', '3', '4', '5')
    >>> next(it)
    '6'

    """

    def __init__(self, iterable, maxlen=None):
        self._source = iter(iterable)
        # Unbounded list, or a deque that silently drops its oldest items.
        self._cache = [] if maxlen is None else deque([], maxlen)
        # Cache position to replay from next; None means "read from source".
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        position = self._index
        if position is not None:
            try:
                cached = self._cache[position]
            except IndexError:
                # Replay has caught up with the cache; fall back to the source.
                self._index = None
            else:
                self._index = position + 1
                return cached

        fresh = next(self._source)
        self._cache.append(fresh)
        return fresh

    def __bool__(self):
        try:
            self.peek()
        except StopIteration:
            return False
        else:
            return True

    def peek(self, default=_marker):
        """Return the next item without advancing, or *default* if exhausted
        (re-raising ``StopIteration`` when no default is given)."""
        try:
            upcoming = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default
        # Step the replay position back by one so the same item comes next.
        position = len(self._cache) if self._index is None else self._index
        self._index = position - 1
        return upcoming

    def elements(self):
        """Return a live, read-only view of the cached items."""
        return SequenceView(self._cache)

    def seek(self, index):
        """Move so that the next item is the one at cache position *index*,
        pulling more items from the source if the cache is too short."""
        self._index = index
        shortfall = index - len(self._cache)
        if shortfall > 0:
            consume(self, shortfall)

    def relative_seek(self, count):
        """Move *count* positions relative to the current one, never before
        the start of the cache."""
        base = self._index
        if base is None:
            base = self._index = len(self._cache)
        self.seek(max(base + count, 0))

    def __getitem__(self, index):
        return self._cache[index]
class run_length:
    """
    :func:`run_length.encode` compresses an iterable using run-length
    encoding. It yields each run of repeated items together with the number
    of times it was repeated:

    >>> uncompressed = 'abbcccdddd'
    >>> list(run_length.encode(uncompressed))
    [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` reverses the process for an iterable that was
    compressed with run-length encoding, yielding the items of the
    decompressed iterable:

    >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> list(run_length.decode(compressed))
    ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        return ((item, ilen(run)) for item, run in groupby(iterable))

    @staticmethod
    def decode(iterable):
        # Each (item, n) pair becomes repeat(item, n); flatten the runs.
        expanded = starmap(repeat, iterable)
        return chain.from_iterable(expanded)
def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly ``n`` items in the iterable satisfy the
    *predicate* function.

    >>> exactly_n([True, True, False], 2)
    True
    >>> exactly_n([True, True, False], 1)
    False
    >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
    True

    The iterable is advanced until ``n + 1`` matching items have been seen,
    so avoid calling this on infinite iterables.

    """
    matches = filter(predicate, iterable)
    if n < 0:
        return False

    # Local sentinel: any real match, even None, is distinguishable from it.
    missing = object()
    if n == 0:
        return next(matches, missing) is missing

    # Skip the first n - 1 matches. The n-th match must exist, and an
    # (n + 1)-th must not.
    rest = islice(matches, n - 1, None)
    if next(rest, missing) is missing:
        return False
    return next(rest, missing) is missing
def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

    >>> list(circular_shifts(range(4)))
    [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    *steps* is the number of places to rotate to the left (or to the right
    when negative) between shifts. It defaults to 1.

    >>> list(circular_shifts(range(4), 2))
    [(0, 1, 2, 3), (2, 3, 0, 1)]

    >>> list(circular_shifts(range(4), -1))
    [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    """
    ring = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Rotate one shift "backwards" first, so that the first rotation inside
    # the loop restores the original order.
    ring.rotate(steps)
    size = len(ring)
    # Distinct shifts repeat after size / gcd(size, steps) rotations.
    period = size // math.gcd(size, steps)

    for _ in range(period):
        ring.rotate(-steps)
        yield tuple(ring)
def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function that
    modifies an iterable. *result_index* is the position in that function's
    signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition. This can augment what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

    >>> from more_itertools import chunked
    >>> chunker = make_decorator(chunked, result_index=0)
    >>> @chunker(3)
    ... def iter_range(n):
    ...     return iter(range(n))
    ...
    >>> list(iter_range(9))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

    >>> truth_serum = make_decorator(filter, result_index=1)
    >>> @truth_serum(bool)
    ... def boolean_test():
    ...     return [0, 1, '', ' ', False, True]
    ...
    >>> list(boolean_test())
    [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators:

    >>> from more_itertools import peekable
    >>> peekable_function = make_decorator(peekable)
    >>> @peekable_function()
    ... def str_range(*args):
    ...     return (str(x) for x in range(*args))
    ...
    >>> it = str_range(1, 20, 2)
    >>> next(it), next(it), next(it)
    ('1', '3', '5')
    >>> it.peek()
    '7'
    >>> next(it)
    '7'

    The decorated function keeps its original ``__name__``, ``__doc__`` and
    related metadata (via :func:`functools.wraps`).

    """

    # Three layers: ``decorator`` captures the extra arguments meant for
    # *wrapping_func*, ``outer_wrapper`` receives the function being
    # decorated, and ``inner_wrapper`` calls it and passes its result to
    # *wrapping_func* at position *result_index*.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            # Preserve the decorated function's name, docstring, etc.
            @wraps(f)
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                wrapping_args_ = list(wrapping_args)
                wrapping_args_.insert(result_index, result)
                return wrapping_func(*wrapping_args_, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Return a dictionary that maps the items in *iterable* to categories
    defined by *keyfunc*, transforms them with *valuefunc*, and
    then summarizes them by category with *reducefunc*.

    *valuefunc* defaults to the identity function when it is not given.
    When *reducefunc* is not given, no summarization takes place:

    >>> keyfunc = lambda x: x.upper()
    >>> result = map_reduce('abbccc', keyfunc)
    >>> sorted(result.items())
    [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Giving *valuefunc* transforms the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> result = map_reduce('abbccc', keyfunc, valuefunc)
    >>> sorted(result.items())
    [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Giving *reducefunc* summarizes the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> reducefunc = sum
    >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
    >>> sorted(result.items())
    [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the map/reduce
    procedure:

    >>> all_items = range(30)
    >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
    >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
    >>> categories = map_reduce(items, keyfunc=keyfunc)
    >>> sorted(categories.items())
    [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
    >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
    >>> sorted(summaries.items())
    [(0, 90), (1, 75)]

    All items are gathered into lists before the summarization step, which
    may require significant storage.

    The returned object is a :obj:`collections.defaultdict` whose
    ``default_factory`` is set to ``None``, so it behaves like a normal
    dictionary.

    .. seealso:: :func:`bucket`, :func:`groupby_transform`

    If storage is a concern, :func:`bucket` can be used without consuming the
    entire iterable right away. If the elements with the same key are already
    adjacent, :func:`groupby_transform` or :func:`itertools.groupby` can be
    used without any caching overhead.

    """
    grouped = defaultdict(list)

    for item in iterable:
        category = keyfunc(item)
        grouped[category].append(
            item if valuefunc is None else valuefunc(item)
        )

    if reducefunc is not None:
        # Replacing values of existing keys is safe while iterating.
        for category in grouped:
            grouped[category] = reducefunc(grouped[category])

    # Stop auto-creating entries, so missing keys raise KeyError.
    grouped.default_factory = None
    return grouped
def rlocate(iterable, pred=bool, window_size=None):
    """Yield, from right to left, the index of each item in *iterable* for
    which *pred* returns ``True``.

    With the default *pred* of :func:`bool`, truthy items are selected:

        >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
        [4, 2, 1]

    Pass a custom *pred* to find, e.g., the positions of a given item:

        >>> iterator = iter('abcb')
        >>> pred = lambda x: x == 'b'
        >>> list(rlocate(iterator, pred))
        [3, 1]

    When *window_size* is given, *pred* receives that many items per call,
    which allows searching for sub-sequences:

        >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
        >>> pred = lambda *args: args == (1, 2, 3)
        >>> list(rlocate(iterable, pred=pred, window_size=3))
        [9, 5, 1]

    Nothing is produced for infinite iterables. A sized, reversible
    *iterable* is scanned from the right. Anything else is scanned from the
    left, and the matches are returned in reverse order.

    See :func:`locate` for other example applications.

    """
    if window_size is None:
        try:
            size = len(iterable)
            hits = locate(reversed(iterable), pred)
        except TypeError:
            # Not sized or not reversible: use the forward scan below.
            pass
        else:
            # Convert positions in the reversed view back to forward indexes.
            return (size - 1 - i for i in hits)

    return reversed(list(locate(iterable, pred, window_size)))
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items from *iterable*, replacing the items for which *pred*
    returns ``True`` with the items from the iterable *substitutes*.

        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
        >>> pred = lambda x: x == 0
        >>> substitutes = (2, 3)
        >>> list(replace(iterable, pred, substitutes))
        [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    If *count* is given, the number of replacements will be limited:

        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
        >>> pred = lambda x: x == 0
        >>> substitutes = [None]
        >>> list(replace(iterable, pred, substitutes, count=2))
        [1, 1, None, 1, 1, None, 1, 1, 0]

    Use *window_size* to control the number of items passed as arguments to
    *pred*. This allows for locating and replacing subsequences.

        >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
        >>> window_size = 3
        >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
        >>> substitutes = [3, 4] # Splice in these items
        >>> list(replace(iterable, pred, substitutes, window_size=window_size))
        [3, 4, 5, 3, 4, 5]

    *pred* may receive fewer than *window_size* arguments at the end of
    the iterable and should be able to handle this. An empty *iterable*
    produces no output and *pred* is never called:

        >>> list(replace([], lambda *args: False, [9], window_size=3))
        []

    Raises ``ValueError`` if *window_size* is less than 1.

    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # Save the substitutes iterable, since it's used more than once
    substitutes = tuple(substitutes)

    # Add padding such that the number of windows matches the length of the
    # iterable. _marker is also passed as windowed()'s fill value. When the
    # padded input is shorter than one window (only possible for an empty
    # iterable), windowed() pads that window itself. With its default fill of
    # None, that padding would reach pred and the output as a real item.
    it = chain(iterable, repeat(_marker, window_size - 1))
    windows = windowed(it, window_size, fillvalue=_marker)

    n = 0
    for w in windows:
        # Strip any _marker padding so pred never sees internal sentinels.
        # Near the end of the iterable, pred will receive fewer arguments.
        args = tuple(x for x in w if x is not _marker)
        if not args:
            # Pure padding: there is nothing to test or emit.
            continue

        # If the current window matches our predicate (and we haven't hit
        # our maximum number of replacements), splice in the substitutes
        # and then consume the following windows that overlap with this one.
        # For example, if the iterable is (0, 1, 2, 3, 4...)
        # and the window size is 2, we have (0, 1), (1, 2), (2, 3)...
        # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2)
        if pred(*args) and ((count is None) or (n < count)):
            n += 1
            yield from substitutes
            consume(windows, window_size - 1)
            continue

        # If there was no match (or we've reached the replacement limit),
        # yield the first item from the window.
        yield args[0]
def partitions(iterable):
    """Yield every order-preserving partition of *iterable*. Each partition
    is a list of contiguous list slices.

        >>> iterable = 'abc'
        >>> for part in partitions(iterable):
        ...     print([''.join(p) for p in part])
        ['abc']
        ['a', 'bc']
        ['ab', 'c']
        ['a', 'b', 'c']

    This is unrelated to :func:`partition`.

    """
    items = list(iterable)
    size = len(items)
    # Each partition is defined by a set of cut positions strictly inside
    # the sequence. Enumerate them by number of cuts, then lexicographically.
    cut_points = range(1, size)
    for how_many in range(len(cut_points) + 1):
        for cuts in combinations(cut_points, how_many):
            bounds = (0, *cuts, size)
            yield [items[lo:hi] for lo, hi in zip(bounds, bounds[1:])]
def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """
    Yield the set partitions of *iterable* into *k* parts. Set partitions are
    not order-preserving.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable, 2):
        ...     print([''.join(p) for p in part])
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']


    If *k* is not given, every set partition is generated.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable):
        ...     print([''.join(p) for p in part])
        ['abc']
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']
        ['a', 'b', 'c']

    If *min_size* and/or *max_size* are given, only partitions whose blocks
    all have a size within those bounds are yielded.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable, min_size=2):
        ...     print([''.join(p) for p in part])
        ['abc']
        >>> for part in set_partitions(iterable, max_size=2):
        ...     print([''.join(p) for p in part])
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']
        ['a', 'b', 'c']

    ``ValueError`` is raised (on first iteration) if *k* is less than 1.
    Nothing is yielded if *k* exceeds the number of items, or if
    *min_size* exceeds *max_size*.

    """
    items = list(iterable)
    size = len(items)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        if k > size:
            return

    lower = 0 if min_size is None else min_size
    upper = size if max_size is None else max_size
    if lower > upper:
        return

    def _blocks_fit(partition):
        # True when every block respects the size bounds.
        return all(lower <= len(block) <= upper for block in partition)

    def _partitions_into(elements, groups):
        # Recursively build the partitions of *elements* into *groups* blocks.
        count_ = len(elements)
        if groups == 1:
            yield [elements]
        elif count_ == groups:
            yield [[x] for x in elements]
        else:
            head, *tail = elements
            # Either head forms a block of its own...
            for sub in _partitions_into(tail, groups - 1):
                yield [[head], *sub]
            # ...or it joins one of the blocks of a partition of the tail.
            for sub in _partitions_into(tail, groups):
                for pos in range(len(sub)):
                    yield sub[:pos] + [[head] + sub[pos]] + sub[pos + 1 :]

    group_counts = range(1, size + 1) if k is None else (k,)
    for groups in group_counts:
        for partition in _partitions_into(items, groups):
            if _blocks_fit(partition):
                yield partition
class time_limited:
    """
    Yield items from *iterable* until *limit_seconds* have passed.
    If the time limit expires before all items have been yielded, the
    ``timed_out`` attribute is set to ``True``.

        >>> from time import sleep
        >>> def generator():
        ...     yield 1
        ...     yield 2
        ...     sleep(0.2)
        ...     yield 3
        >>> iterable = time_limited(0.1, generator())
        >>> list(iterable)
        [1, 2]
        >>> iterable.timed_out
        True

    The clock starts when the object is created. It is checked after each
    item is fetched, and iteration stops once the elapsed time exceeds
    *limit_seconds*. So with a 1-second limit and a source that needs
    2 seconds to produce its first item, this runs for 2 seconds and yields
    nothing. As a special case, a *limit_seconds* of zero never yields
    anything and never touches the source. A negative *limit_seconds*
    raises ``ValueError``.

    """

    def __init__(self, limit_seconds, iterable):
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be positive')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.limit_seconds != 0:
            # StopIteration from an exhausted source propagates as-is and
            # leaves timed_out untouched.
            item = next(self._iterator)
            expired = monotonic() - self._start_time > self.limit_seconds
            if not expired:
                return item
        self.timed_out = True
        raise StopIteration
def only(iterable, default=None, too_long=None):
    """Return the single item of *iterable*. Return *default* if it is
    empty. If it holds more than one item, raise *too_long* (an exception
    class or instance), which is a ``ValueError`` by default.

        >>> only([], default='missing')
        'missing'
        >>> only([1])
        1
        >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Expected exactly one item in iterable, but got 1, 2,
         and perhaps more.
        >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        TypeError

    :func:`only` advances *iterable* up to twice to make sure there is only
    one item. Use :func:`spy` or :func:`peekable` to inspect an iterable less
    destructively.

    """
    iterator = iter(iterable)
    missing = object()

    first = next(iterator, missing)
    if first is missing:
        return default

    second = next(iterator, missing)
    if second is missing:
        return first

    msg = (
        f'Expected exactly one item in iterable, but got {first!r}, '
        f'{second!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
def ichunked(iterable, n):
    """Split *iterable* into sub-iterables of *n* elements each. This is
    like :func:`chunked`, except that it yields iterables instead of lists.

    Sub-iterables that are read in order do not keep the elements of
    *iterable* in memory. When they are read out of order,
    :func:`itertools.tee` caches elements as needed.

        >>> from itertools import count
        >>> all_chunks = ichunked(count(), 4)
        >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
        >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
        [4, 5, 6, 7]
        >>> list(c_1)
        [0, 1, 2, 3]
        >>> list(c_3)
        [8, 9, 10, 11]

    """
    source = iter(iterable)
    for head in source:
        tail = islice(source, n - 1)
        replay, filler = tee(tail)
        # The chunk reads *tail* directly. Whatever it has not read by the
        # time the next chunk is requested gets buffered via *filler* and
        # replayed from *replay*.
        yield chain((head,), tail, replay)
        # Exhaust *filler* so this chunk's leftovers move into tee's buffer
        # before the shared source is advanced for the next chunk.
        deque(filler, maxlen=0)
def iequals(*iterables):
    """Return ``True`` if all given *iterables* are equal to each other,
    which means that they contain the same elements in the same order.

    The function is useful for comparing iterables of different data types
    or iterables that do not support equality checks.

        >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
        True

        >>> iequals("abc", "acb")
        False

    Iterables of different lengths are never equal. Exceptions raised while
    iterating the inputs, or while comparing their elements, propagate to
    the caller.

    Not to be confused with :func:`all_equal`, which checks whether all
    elements of iterable are equal to each other.

    """
    # Pad shorter inputs with a private sentinel, so a length mismatch shows
    # up as an unequal position. Catching the ValueError from
    # ``zip(strict=True)`` instead would also swallow ValueErrors raised by
    # the iterables themselves and misreport them as "not equal".
    missing = object()
    return all(map(all_equal, zip_longest(*iterables, fillvalue=missing)))
def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

        >>> list(distinct_combinations([0, 0, 1], 2))
        [(0, 0), (0, 1)]

    Equivalent to ``set(combinations(iterable, r))``, except duplicates are
    not generated and thrown away. For larger input sequences this is much
    more efficient.

    Raises ``ValueError`` (on first iteration) if *r* is negative. For
    ``r == 0`` a single empty tuple is yielded.

    """
    if r < 0:
        raise ValueError('r must be non-negative')
    elif r == 0:
        yield ()
        return
    pool = tuple(iterable)
    # Depth-first search with an explicit stack. generators[level] produces
    # (index, value) candidates for position `level` of the combination.
    # unique_everseen keyed on the value (itemgetter(1)) skips values already
    # tried at that level, which is what avoids duplicate combinations.
    generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
    current_combo = [None] * r
    level = 0
    while generators:
        try:
            cur_idx, p = next(generators[-1])
        except StopIteration:
            # This level's candidates are exhausted: backtrack one level.
            generators.pop()
            level -= 1
            continue
        current_combo[level] = p
        if level + 1 == r:
            yield tuple(current_combo)
        else:
            # Candidates for the next position come only from items after
            # cur_idx; enumerate's start keeps their original pool indexes.
            generators.append(
                unique_everseen(
                    enumerate(pool[cur_idx + 1 :], cur_idx + 1),
                    key=itemgetter(1),
                )
            )
            level += 1
def filter_except(validator, iterable, *exceptions):
    """Yield each item of *iterable* that *validator* accepts, i.e. for which
    calling ``validator(item)`` does not raise one of *exceptions*.

    *validator* takes one argument and signals an invalid item by raising.
    Its return value is ignored.

        >>> iterable = ['1', '2', 'three', '4', None]
        >>> list(filter_except(int, iterable, ValueError, TypeError))
        ['1', '2', '4']

    Exceptions not listed in *exceptions* propagate normally.
    """
    for element in iterable:
        try:
            validator(element)
        except exceptions:
            # Expected failure: drop this element.
            continue
        yield element
def map_except(function, iterable, *exceptions):
    """Apply *function* to each item of *iterable* and yield the results.
    Items for which *function* raises one of *exceptions* are skipped.

    *function* takes a single argument.

        >>> iterable = ['1', '2', 'three', '4', None]
        >>> list(map_except(int, iterable, ValueError, TypeError))
        [1, 2, 4]

    Exceptions not listed in *exceptions* propagate normally.
    """
    for element in iterable:
        try:
            yield function(element)
        except exceptions:
            # Expected failure: skip this element.
            continue
def map_if(iterable, pred, func, func_else=None):
    """For each item of *iterable*, yield ``func(item)`` when ``pred(item)``
    is truthy, and ``func_else(item)`` otherwise.

    *pred*, *func* and *func_else* each take one argument. When *func_else*
    is omitted, non-matching items are yielded unchanged.

        >>> from math import sqrt
        >>> iterable = list(range(-5, 5))
        >>> iterable
        [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
        >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
        [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
        >>> list(map_if(iterable, lambda x: x >= 0,
        ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
        [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
    """
    otherwise = func_else if func_else is not None else (lambda item: item)
    for item in iterable:
        if pred(item):
            yield func(item)
        else:
            yield otherwise(item)
def _sample_unweighted(iterator, k, strict):
    """Return a shuffled list of up to *k* elements drawn from *iterator*
    without replacement.

    Raises ``ValueError`` if *strict* is true and *iterator* yields fewer
    than *k* elements. Otherwise every available element is returned.
    *k* is assumed to be positive (callers handle ``k == 0``).
    """
    # Algorithm L in the 1994 paper by Kim-Hung Li:
    # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
    # NOTE: the exact sequence of random()/randrange()/shuffle() calls below
    # decides which sample a given seed produces.

    # Start with the first k elements.
    reservoir = list(islice(iterator, k))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')
    W = 1.0

    # next() raises StopIteration once the input runs out, ending the loop.
    with suppress(StopIteration):
        while True:
            # Update W and draw how many elements to pass over before the
            # next one that enters the reservoir.
            W *= random() ** (1 / k)
            skip = floor(log(random()) / log1p(-W))
            # Discard `skip` elements and take the one after them.
            element = next(islice(iterator, skip, None))
            # It replaces a randomly chosen reservoir slot.
            reservoir[randrange(k)] = element

    shuffle(reservoir)
    return reservoir
3822def _sample_weighted(iterator, k, weights, strict):
3823 # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
3824 # "Weighted random sampling with a reservoir".
3826 # Log-transform for numerical stability for weights that are small/large
3827 weight_keys = (log(random()) / weight for weight in weights)
3829 # Fill up the reservoir (collection of samples) with the first `k`
3830 # weight-keys and elements, then heapify the list.
3831 reservoir = take(k, zip(weight_keys, iterator))
3832 if strict and len(reservoir) < k:
3833 raise ValueError('Sample larger than population')
3835 heapify(reservoir)
3837 # The number of jumps before changing the reservoir is a random variable
3838 # with an exponential distribution. Sample it using random() and logs.
3839 smallest_weight_key, _ = reservoir[0]
3840 weights_to_skip = log(random()) / smallest_weight_key
3842 for weight, element in zip(weights, iterator):
3843 if weight >= weights_to_skip:
3844 # The notation here is consistent with the paper, but we store
3845 # the weight-keys in log-space for better numerical stability.
3846 smallest_weight_key, _ = reservoir[0]
3847 t_w = exp(weight * smallest_weight_key)
3848 r_2 = uniform(t_w, 1) # generate U(t_w, 1)
3849 weight_key = log(r_2) / weight
3850 heapreplace(reservoir, (weight_key, element))
3851 smallest_weight_key, _ = reservoir[0]
3852 weights_to_skip = log(random()) / smallest_weight_key
3853 else:
3854 weights_to_skip -= weight
3856 ret = [element for weight_key, element in reservoir]
3857 shuffle(ret)
3858 return ret
def _sample_counted(population, k, counts, strict):
    """Return a shuffled list of up to *k* elements sampled without
    replacement from *population*. Each element counts as repeated the
    number of times given by the matching item of *counts*.

    *population* and *counts* are consumed with ``next()``, so both must be
    iterators. The random skip loop is the same as in
    :func:`_sample_unweighted`, but runs of repeated elements are skipped
    arithmetically instead of being expanded. Raises ``ValueError`` if
    *strict* is true and fewer than *k* elements are available.
    """
    # `element` is the population item currently being consumed, and
    # `remaining` is how many of its copies are still unconsumed.
    element = None
    remaining = 0

    def feed(i):
        # Advance *i* steps ahead and consume an element
        nonlocal element, remaining

        # Move on to later population items while the current one has too
        # few copies left to cover the i + 1 steps.
        while i + 1 > remaining:
            i = i - remaining
            element = next(population)
            remaining = next(counts)
        remaining -= i + 1
        return element

    # Take the first k copies. StopIteration here means the expanded
    # population has fewer than k elements.
    with suppress(StopIteration):
        reservoir = []
        for _ in range(k):
            reservoir.append(feed(0))

    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Skip-based replacement loop; ends when feed() raises StopIteration.
    # NOTE: the order of random()/randrange()/shuffle() calls decides which
    # sample a given seed produces.
    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= random() ** (1 / k)
            skip = floor(log(random()) / log1p(-W))
            element = feed(skip)
            reservoir[randrange(k)] = element

    shuffle(reservoir)
    return reservoir
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a list of *k* elements chosen (without replacement) from
    *iterable*.

    Similar to :func:`random.sample`, but it also works on inputs that aren't
    indexable (such as sets and dictionaries) and on inputs whose size isn't
    known in advance (such as generators).

        >>> iterable = range(100)
        >>> sample(iterable, 5)  # doctest: +SKIP
        [81, 60, 96, 16, 4]

    For iterables with repeated elements, *counts* can give the number of
    repeats of each element:

        >>> iterable = ['a', 'b']
        >>> counts = [3, 4]  # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
        >>> sample(iterable, k=3, counts=counts)  # doctest: +SKIP
        ['a', 'a', 'b']

    Alternatively, *weights* can be given:

        >>> iterable = range(100)
        >>> weights = (i * i + 1 for i in range(100))
        >>> sample(iterable, 5, weights=weights)  # doctest: +SKIP
        [79, 67, 74, 66, 78]

    Weighted selections are made without replacement. Once an element is
    picked it leaves the pool, and the relative weights of the remaining
    elements grow. This does not match the behavior of
    :func:`random.sample`'s *counts* parameter. *weights* and *counts* are
    mutually exclusive (``TypeError``).

    If *iterable* has fewer than *k* elements, ``ValueError`` is raised when
    *strict* is ``True``. When *strict* is ``False``, every element is
    returned in shuffled order. A negative *k* raises ``ValueError``, and a
    *k* of zero returns ``[]``.

    By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
    technique is used. When *weights* are provided,
    `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used instead.

    Notes on reproducibility:

    * The algorithms rely on inexact floating-point functions provided
      by the underlying math library (e.g. ``log``, ``log1p``, and ``pow``).
      Those functions can `produce slightly different results
      <https://members.loria.fr/PZimmermann/papers/accuracy.pdf>`_ on
      different builds. Accordingly, selections can vary across builds
      even for the same seed.

    * The algorithms loop over the input and make selections based on
      ordinal position, so selections from unordered collections (such as
      sets) won't reproduce across sessions on the same platform using the
      same seed. For example, this won't reproduce::

          >> seed(8675309)
          >> sample(set('abcdefghijklmnopqrstuvwxyz'), 10)
          ['c', 'p', 'e', 'w', 's', 'a', 'j', 'd', 'n', 't']

    """
    iterator = iter(iterable)

    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []

    if weights is None and counts is None:
        return _sample_unweighted(iterator, k, strict)
    if counts is None:
        return _sample_weighted(iterator, k, iter(weights), strict)
    if weights is None:
        return _sample_counted(iterator, k, iter(counts), strict)
    raise TypeError('weights and counts are mutually exclusive')
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Return ``True`` if the items of *iterable* are in sorted order, and
    ``False`` otherwise. *key* and *reverse* mean the same as they do for
    the built-in :func:`sorted`.

        >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
        True
        >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
        False

    With *strict*, equal neighbours also count as out of order:

        >>> is_sorted([1, 2, 2])
        True
        >>> is_sorted([1, 2, 2], strict=True)
        False

    The check stops at the first out-of-order pair, so for values with
    unusual comparison behaviour (like ``math.nan``) the answer may disagree
    with what :func:`sorted` would produce. If no out-of-order pair exists,
    the whole iterable is consumed.
    """
    values = iterable if key is None else map(key, iterable)
    left, right = tee(values)
    next(right, None)  # `right` now runs one step ahead of `left`
    if reverse:
        left, right = right, left
    if strict:
        return all(map(lt, left, right))
    return not any(map(lt, right, left))
class AbortThread(BaseException):
    """Signals that a :class:`callback_iter` has been aborted.

    :class:`callback_iter`'s callback raises it once the ``with`` block has
    been exited. It derives from :class:`BaseException` rather than
    :class:`Exception`, so ``except Exception`` handlers do not catch it.
    """
class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    .. warning::

        This class is deprecated as of version 11.0.0. It will be removed in a future
        major release.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

    >>> def func(callback=None):
    ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
    ...         if callback:
    ...             callback(i, c)
    ...     return 4


    Use ``with callback_iter(func)`` to get an iterator over the parameters
    that are delivered to the callback.

    >>> with callback_iter(func) as it:
    ...     for args, kwargs in it:
    ...         print(args)
    (1, 'a')
    (2, 'b')
    (3, 'c')

    The function will be called in a background thread. The ``done`` property
    indicates whether it has completed execution.

    >>> it.done
    True

    If it completes successfully, its return value will be available
    in the ``result`` property.

    >>> it.result
    4

    Notes:

    * If the function uses some keyword argument besides ``callback``, supply
      *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently the output queue is
      polled.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        self._func = func
        self._callback_kwd = callback_kwd
        # Set by __exit__; checked by the callback to stop the worker.
        self._aborted = False
        # Future for the submitted call; stays None until iteration starts.
        self._future = None
        # Timeout, in seconds, for each poll of the output queue.
        self._wait_seconds = wait_seconds

        # Lazily import concurrent.futures (at instantiation, not module load)
        self._module = __import__('concurrent.futures').futures
        # A single worker thread runs *func*.
        self._executor = self._module.ThreadPoolExecutor(max_workers=1)
        # _reader() is a generator, so *func* is not submitted until the
        # first call to __next__.
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # From now on the callback raises AbortThread inside the worker.
        # shutdown() then waits for the worker thread to finish.
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    @property
    def done(self):
        # False until iteration has started and *func* has been submitted.
        if self._future is None:
            return False
        return self._future.done()

    @property
    def result(self):
        if self._future:
            try:
                # timeout=0: return the value (or re-raise func's exception)
                # if the call has finished; otherwise raise TimeoutError.
                return self._future.result(timeout=0)
            except self._module.TimeoutError:
                pass

        raise RuntimeError('Function has not yet completed')

    def _reader(self):
        # Generator: submits *func* to the executor, then yields the
        # (args, kwargs) pair of each callback invocation.
        q = Queue()

        def callback(*args, **kwargs):
            if self._aborted:
                raise AbortThread('canceled by user')

            q.put((args, kwargs))

        self._future = self._executor.submit(
            self._func, **{self._callback_kwd: callback}
        )

        # Poll until *func* finishes. The timeout makes sure done() is
        # re-checked even when no callbacks arrive.
        while True:
            try:
                item = q.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                q.task_done()
                yield item

            if self._future.done():
                break

        # *func* has finished: collect anything queued after the last poll.
        remaining = []
        while True:
            try:
                item = q.get_nowait()
            except Empty:
                break
            else:
                q.task_done()
                remaining.append(item)
        # Every item taken from q above has been marked task_done().
        q.join()
        yield from remaining
def windowed_complete(iterable, n):
    """
    Yield ``(beginning, middle, end)`` tuples, where:

    * Each ``middle`` has *n* items from *iterable*
    * Each ``beginning`` has the items before the ones in ``middle``
    * Each ``end`` has the items after the ones in ``middle``

        >>> iterable = range(7)
        >>> n = 3
        >>> for beginning, middle, end in windowed_complete(iterable, n):
        ...     print(beginning, middle, end)
        () (0, 1, 2) (3, 4, 5, 6)
        (0,) (1, 2, 3) (4, 5, 6)
        (0, 1) (2, 3, 4) (5, 6)
        (0, 1, 2) (3, 4, 5) (6,)
        (0, 1, 2, 3) (4, 5, 6) ()

    *n* must be at least 0 and at most equal to the length of *iterable*.
    Otherwise ``ValueError`` is raised on first iteration.

    This function exhausts the iterable and may require significant
    storage.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    seq = tuple(iterable)
    size = len(seq)

    if n > size:
        raise ValueError('n must be <= len(seq)')

    for start in range(size - n + 1):
        stop = start + n
        yield seq[:start], seq[start:stop], seq[stop:]
def all_unique(iterable, key=None):
    """
    Return ``True`` if no two elements of *iterable* are equal.

        >>> all_unique('ABCB')
        False

    If a *key* function is given, elements are compared by ``key(element)``.

        >>> all_unique('ABCb')
        True
        >>> all_unique('ABCb', str.lower)
        False

    Iteration stops at the first repeated element. Hashable and unhashable
    elements may be mixed. Hashable ones are tracked in a set, while
    unhashable ones fall back to a slower list scan.
    """
    hashable_seen = set()
    unhashable_seen = []
    values = map(key, iterable) if key else iterable
    for value in values:
        try:
            if value in hashable_seen:
                return False
            hashable_seen.add(value)
        except TypeError:
            # Unhashable value: linear membership check instead.
            if value in unhashable_seen:
                return False
            unhashable_seen.append(value)
    return True
def nth_product(index, *iterables, repeat=1):
    """Equivalent to ``list(product(*iterables, repeat=repeat))[index]``.

    The products of *iterables* can be ordered lexicographically.
    :func:`nth_product` computes the product at sort position *index* without
    computing the previous products.

        >>> nth_product(8, range(2), range(2), range(2), range(2))
        (1, 0, 0, 0)

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables. The above example is equivalent to::

        >>> nth_product(8, range(2), repeat=4)
        (1, 0, 0, 0)

    Negative indexes count from the end. ``IndexError`` is raised if the
    given *index* is out of range.
    """
    # Pools are reversed so the least significant "digit" (the last pool)
    # is decoded first.
    pools = tuple(tuple(pool) for pool in reversed(iterables)) * repeat
    sizes = [len(pool) for pool in pools]

    total = prod(sizes)

    if index < 0:
        index += total
    if not 0 <= index < total:
        raise IndexError

    picked = []
    for pool, size in zip(pools, sizes):
        index, offset = divmod(index, size)
        picked.append(pool[offset])

    return tuple(picked[::-1])
def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`nth_permutation`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences.

        >>> nth_permutation('ghijk', 2, 5)
        ('h', 'i')

    *r* may be ``None`` to mean the full length. ``ValueError`` will be
    raised if *r* is negative. ``IndexError`` will be raised if the given
    *index* is invalid (including any *index* when *r* exceeds the length
    of *iterable*).
    """
    pool = list(iterable)
    size = len(pool)
    if r is None:
        r = size
    total = perm(size, r)

    if index < 0:
        index += total
    if not 0 <= index < total:
        raise IndexError

    # For r < n, scale the index up to the rank of a full-length permutation.
    # Then decode that rank in the factorial number system, least significant
    # digit first. Only the r most significant digits are kept.
    picks = [0] * r
    rank = index * factorial(size) // total if r < size else index
    for base in range(1, size + 1):
        rank, digit = divmod(rank, base)
        position = size - base
        if position < r:
            picks[position] = digit
        if not rank:
            break

    # Each digit indexes into the pool of items not yet chosen.
    return tuple(pool.pop(pick) for pick in picks)
def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.


    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`nth_combination_with_replacement`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences with replacement.

        >>> nth_combination_with_replacement(range(5), 3, 5)
        (0, 1, 1)

    ``ValueError`` will be raised if *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if r < 0:
        raise ValueError
    if n:
        c = comb(n + r - 1, r)
    else:
        # An empty pool has only the empty combination.
        c = 0 if r else 1

    if index < 0:
        index += c
    if not 0 <= index < c:
        raise IndexError

    picked = []
    i = 0
    # Fill one slot at a time. slots_left counts the slots after this one.
    for slots_left in range(r - 1, -1, -1):
        # Skip past every block of combinations that starts with pool[i]
        # while the index lies beyond that block.
        while n >= 0:
            block_size = comb(n + slots_left - 1, slots_left)
            if index < block_size:
                break
            n -= 1
            i += 1
            index -= block_size
        picked.append(pool[i])

    return tuple(picked)
def value_chain(*args):
    """Yield all arguments passed to the function in the same order in which
    they were passed. If an argument itself is iterable then iterate over its
    values.

        >>> list(value_chain(1, 2, 3, [4, 5, 6]))
        [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and are emitted
    as-is:

        >>> list(value_chain('12', '34', ['56', '78']))
        ['12', '34', '56', '78']

    Pre- or postpend a single element to an iterable:

        >>> list(value_chain(1, [2, 3, 4, 5, 6]))
        [1, 2, 3, 4, 5, 6]
        >>> list(value_chain([1, 2, 3, 4, 5], 6))
        [1, 2, 3, 4, 5, 6]

    Multiple levels of nesting are not flattened.

    An argument counts as non-iterable only when ``iter()`` rejects it. A
    ``TypeError`` raised while an argument is being iterated propagates to
    the caller.

    """
    scalar_types = (str, bytes)
    for value in args:
        if isinstance(value, scalar_types):
            yield value
            continue
        # Probe iterability separately from consumption. Wrapping the whole
        # `yield from` in try/except would also catch a TypeError raised by
        # the iterable itself. That would emit its partial contents followed
        # by the iterable object as if it were a scalar.
        try:
            iterator = iter(value)
        except TypeError:
            yield value
            continue
        yield from iterator
def product_index(element, *iterables, repeat=1):
    """Equivalent to ``list(product(*iterables, repeat=repeat)).index(tuple(element))``

    The products of *iterables* can be ordered lexicographically.
    :func:`product_index` computes the first index of *element* without
    computing the previous products.

        >>> product_index([8, 2], range(10), range(5))
        42

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables::

        >>> product_index([8, 0, 7], range(10), repeat=3)
        807

    ``ValueError`` will be raised if the given *element* isn't in the product
    of *iterables*.
    """
    target = tuple(element)
    pools = tuple(tuple(pool) for pool in iterables) * repeat
    if len(target) != len(pools):
        raise ValueError('element is not a product of args')

    # Read the element as a mixed-radix number, one digit per pool.
    position = 0
    for value, pool in zip(target, pools):
        position *= len(pool)
        position += pool.index(value)
    return position
def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``

    The length-*r* subsequences of *iterable* have a natural lexicographic
    order. :func:`combination_index` returns the position of the first
    occurrence of *element* in that order, without generating the earlier
    combinations.

    >>> combination_index('adf', 'abcdefg')
    10

    ``ValueError`` is raised if *element* is not one of the combinations of
    *iterable*.
    """
    wanted = enumerate(element)
    k, target = next(wanted, (None, None))
    if k is None:
        # The empty combination always comes first.
        return 0

    # Greedily match each wanted item, in order, against the iterable and
    # record the position where it was found.
    positions = []
    source = enumerate(iterable)
    for n, item in source:
        if item == target:
            positions.append(n)
            following, target = next(wanted, (None, None))
            if following is None:
                break
            k = following
    else:
        raise ValueError('element is not a combination of iterable')

    # Drain what is left so that n ends up as the iterable's last index.
    for n, _ in source:
        pass

    offset = 1
    for i, pos in enumerate(reversed(positions), start=1):
        gap = n - pos
        if i <= gap:
            offset += comb(gap, i)

    return comb(n + 1, k + 1) - offset
def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``

    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`combination_with_replacement_index`
    computes the index of the first *element*, without computing the previous
    combinations with replacement.

    >>> combination_with_replacement_index('adf', 'abcdefg')
    20

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations with replacement of *iterable*.
    """
    element = tuple(element)
    l = len(element)
    if not l:
        return 0

    pool = tuple(iterable)
    n = len(pool)

    # occupations[p] counts how many items of *element* are matched to
    # pool[p]. Items are matched greedily, left to right; each one must be
    # found at or after the pool position of the previous one. A position
    # counter (not a ``None`` sentinel) tracks progress through *element*,
    # so ``None`` values inside *element* are handled correctly.
    occupations = [0] * n
    matched = 0
    for p, x in enumerate(pool):
        while matched < l and x == element[matched]:
            occupations[p] += 1
            matched += 1
        if matched == l:
            break
    else:
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    index = 0
    cumulative_sum = 0
    for k in range(1, n):
        cumulative_sum += occupations[k - 1]
        j = l + n - 1 - k - cumulative_sum
        i = n - k
        if i <= j:
            index += comb(j, i)

    return index
def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)``

    The ordered length-*r* subsequences of *iterable* have a natural
    lexicographic order. :func:`permutation_index` returns the position of
    the first occurrence of *element* in that order directly, without
    generating the earlier permutations.

    >>> permutation_index([1, 3, 2], range(5))
    19

    ``ValueError`` is raised if *element* is not one of the permutations of
    *iterable*.
    """
    remaining = list(iterable)
    radices = range(len(remaining), -1, -1)
    result = 0
    # Each chosen item is a digit whose radix is the number of pool items
    # still available when it is picked.
    for radix, item in zip(radices, element):
        offset = remaining.index(item)
        result = result * radix + offset
        remaining.pop(offset)
    return result
class countable:
    """Wrap *iterable* and keep a count of how many items have been consumed.

    ``items_seen`` starts at ``0`` and grows by one for every item handed
    out by the wrapper:

    >>> iterable = map(str, range(10))
    >>> it = countable(iterable)
    >>> it.items_seen
    0
    >>> next(it), next(it)
    ('0', '1')
    >>> list(it)
    ['2', '3', '4', '5', '6', '7', '8', '9']
    >>> it.items_seen
    10
    """

    def __init__(self, iterable):
        self._iterator = iter(iterable)
        self.items_seen = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Fetch first: if the source is exhausted, StopIteration propagates
        # and the counter is left unchanged.
        value = next(self._iterator)
        self.items_seen = self.items_seen + 1
        return value
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.

    Items are distributed so that the lengths of the lists differ by at most
    1 item.

    >>> iterable = [1, 2, 3, 4, 5, 6, 7]
    >>> n = 3
    >>> list(chunked_even(iterable, n))  # List lengths: 3, 2, 2
    [[1, 2, 3], [4, 5], [6, 7]]
    >>> list(chunked(iterable, n))  # List lengths: 3, 3, 1
    [[1, 2, 3], [4, 5, 6], [7]]

    """
    source = iter(iterable)

    # Hold back a reserve of items so that, once the input runs out, enough
    # remain to even out the sizes of the final chunks.
    reserve = (n - 1) * (n - 2)
    pending = list(islice(source, reserve))

    # After the first n + 1 appends, and every n appends after that, a full
    # chunk can be released while the reserve is still kept in hand.
    appended = map(pending.append, source)
    for _ in islice(appended, n, None, n):
        yield pending[:n]
        del pending[:n]

    total = len(pending)
    if total == 0:
        return

    # Use the fewest lists of size <= n, then balance them: the first
    # `num_full` lists get `full_size` items and the rest get one fewer.
    num_lists = -(-total // n)
    full_size = -(-total // num_lists)
    partial_size = full_size - 1
    num_full = total - partial_size * num_lists

    boundary = num_full * full_size
    if full_size > 0:
        for start in range(0, boundary, full_size):
            yield pending[start : start + full_size]
    if partial_size > 0:
        for start in range(boundary, total, partial_size):
            yield pending[start : start + partial_size]
4611def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
4612 """A version of :func:`zip` that "broadcasts" any scalar
4613 (i.e., non-iterable) items into output tuples.
4615 >>> iterable_1 = [1, 2, 3]
4616 >>> iterable_2 = ['a', 'b', 'c']
4617 >>> scalar = '_'
4618 >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
4619 [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]
4621 The *scalar_types* keyword argument determines what types are considered
4622 scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
4623 treat strings and byte strings as iterable:
4625 >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
4626 [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]
4628 If the *strict* keyword argument is ``True``, then
4629 ``ValueError`` will be raised if any of the iterables have
4630 different lengths.
4631 """
4633 def is_scalar(obj):
4634 if scalar_types and isinstance(obj, scalar_types):
4635 return True
4636 try:
4637 iter(obj)
4638 except TypeError:
4639 return True
4640 else:
4641 return False
4643 size = len(objects)
4644 if not size:
4645 return
4647 new_item = [None] * size
4648 iterables, iterable_positions = [], []
4649 for i, obj in enumerate(objects):
4650 if is_scalar(obj):
4651 new_item[i] = obj
4652 else:
4653 iterables.append(iter(obj))
4654 iterable_positions.append(i)
4656 if not iterables:
4657 yield tuple(objects)
4658 return
4660 for item in zip(*iterables, strict=strict):
4661 for i, new_item[i] in zip(iterable_positions, item):
4662 pass
4663 yield tuple(new_item)
def unique_in_window(iterable, n, key=None):
    """Yield the items from *iterable* that haven't been seen recently.
    *n* is the size of the sliding window.

    >>> iterable = [0, 1, 0, 2, 3, 0]
    >>> n = 3
    >>> list(unique_in_window(iterable, n))
    [0, 1, 2, 3, 0]

    If given, the *key* function decides which items count as the same:

    >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
    ['a', 'b', 'c', 'd', 'a']

    A window of at most *n* keys is maintained. An item is yielded when its
    key does not already appear in the window at the time it arrives.

    With `n == 1`, *unique_in_window* has no memory:

    >>> list(unique_in_window('aab', n=1))
    ['a', 'a', 'b']

    The items in *iterable* must be hashable.

    """
    if n <= 0:
        raise ValueError('n must be greater than 0')

    recent = deque(maxlen=n)
    tally = Counter()

    for item in iterable:
        # A full window is about to drop its oldest key (the append below
        # evicts it), so take that key out of the tally first.
        if len(recent) == n:
            oldest = recent[0]
            if tally[oldest] == 1:
                del tally[oldest]
            else:
                tally[oldest] -= 1

        k = item if key is None else key(item)
        if k not in tally:
            yield item
        tally[k] += 1
        recent.append(k)
def duplicates_everseen(iterable, key=None):
    """Yield duplicate elements after their first appearance.

    >>> list(duplicates_everseen('mississippi'))
    ['s', 'i', 's', 's', 'i', 'p', 'i']
    >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations: hashable keys are tracked in a set,
    and unhashable keys fall back to a (linear-time) list.

    """
    seen_set = set()
    seen_list = []
    use_key = key is not None

    for element in iterable:
        k = key(element) if use_key else element
        # Only the membership test and insertion are guarded: a TypeError here
        # means *k* is unhashable, so fall back to the list. The yield stays
        # outside the try so that a TypeError thrown into this generator is
        # not mistaken for an unhashable key and silently swallowed.
        try:
            is_duplicate = k in seen_set
            if not is_duplicate:
                seen_set.add(k)
        except TypeError:
            is_duplicate = k in seen_list
            if not is_duplicate:
                seen_list.append(k)
        if is_duplicate:
            yield element
def duplicates_justseen(iterable, key=None):
    """Yields serially-duplicate elements after their first appearance.

    >>> list(duplicates_justseen('mississippi'))
    ['s', 's', 'p']
    >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    This function is analogous to :func:`unique_justseen`.

    """
    # Within each run of consecutive equal items (compared by *key*), skip
    # the first item and yield the rest.
    return chain.from_iterable(
        islice(run, 1, None) for _, run in groupby(iterable, key)
    )
def classify_unique(iterable, key=None):
    """Classify each element in terms of its uniqueness.

    For every element of the input iterable, yield a 3-tuple made of:

    1. The element itself
    2. ``False`` if the element equals the one immediately before it in the
       input, ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
    3. ``False`` if the element has already appeared anywhere earlier in the
       input, ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)

    >>> list(classify_unique('otto'))    # doctest: +NORMALIZE_WHITESPACE
    [('o', True,  True),
     ('t', True,  True),
     ('t', False, False),
     ('o', True,  False)]

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []
    prev_key = None
    first = True

    for element in iterable:
        k = element if key is None else key(element)
        # The first element has nothing before it, so it is never compared.
        justseen = True if first else prev_key != k
        first = False
        prev_key = k

        everseen = False
        try:
            if k not in hashable_seen:
                hashable_seen.add(k)
                everseen = True
        except TypeError:
            # Unhashable key: fall back to a linear list scan.
            if k not in unhashable_seen:
                unhashable_seen.append(k)
                everseen = True
        yield element, justseen, everseen
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Return both the smallest and the largest item, as a ``(min, max)``
    pair, from an iterable or from two or more positional arguments.

    >>> minmax([3, 1, 5])
    (1, 5)

    >>> minmax(4, 2, 6)
    (2, 6)

    If a *key* function is given, items are compared by their key instead of
    by themselves.

    >>> minmax([5, 30], key=str)  # '30' sorts before '5'
    (30, 5)

    If a *default* value is given, it is returned when there are no input
    items.

    >>> minmax([], default=(0, 0))
    (0, 0)

    Otherwise ``ValueError`` is raised.

    The input is traversed once, and items are compared in pairs to keep the
    number of comparisons low.

    Unlike the builtin ``max`` function, which always returns the first item
    with the maximum value, this function may return a different item when
    there are ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    source = (iterable_or_value, *others) if others else iterable_or_value
    items = iter(source)

    try:
        lo = hi = next(items)
    except StopIteration as exc:
        if default is _marker:
            raise ValueError(
                '`minmax()` argument is an empty iterable. '
                'Provide a `default` value to suppress this error.'
            ) from exc
        return default

    # Items are consumed two at a time (an odd tail is padded with the first
    # item). Ordering each pair first means a pair needs three comparisons
    # rather than four. The no-key path is kept separate so that it avoids
    # the key bookkeeping entirely.
    if key is None:
        for small, large in zip_longest(items, items, fillvalue=lo):
            if large < small:
                small, large = large, small
            if small < lo:
                lo = small
            if hi < large:
                hi = large
        return lo, hi

    lo_key = hi_key = key(lo)
    for small, large in zip_longest(items, items, fillvalue=lo):
        small_key, large_key = key(small), key(large)
        if large_key < small_key:
            small, large, small_key, large_key = (
                large,
                small,
                large_key,
                small_key,
            )
        if small_key < lo_key:
            lo, lo_key = small, small_key
        if hi_key < large_key:
            hi, hi_key = large, large_key
    return lo, hi
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* whose combined size does not
    exceed *max_size*.

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10))
    [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If *max_count* is given, the number of items in each batch is capped as
    well:

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10, max_count = 2))
    [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    *get_len* measures the size of an item; it defaults to :func:`len`.

    When *strict* is ``True``, ``ValueError`` is raised for any single item
    larger than *max_size*. When it is ``False``, such an item is placed in a
    batch of its own.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    current = []
    current_size = 0
    for item in iterable:
        size = get_len(item)
        if strict and size > max_size:
            raise ValueError('item size exceeds maximum size')

        over_count = len(current) == max_count
        over_size = size + current_size > max_size
        # Never flush an empty batch; that way an oversized item (allowed
        # when strict is False) still starts a batch of its own.
        if current and (over_size or over_count):
            yield tuple(current)
            current = []
            current_size = 0

        current.append(item)
        current_size += size

    if current:
        yield tuple(current)
def gray_product(*iterables, repeat=1):
    """Like :func:`itertools.product`, but yield the tuples in an order where
    consecutive tuples differ in exactly one position.

    >>> list(gray_product('AB','CD'))
    [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    The *repeat* keyword argument repeats the iterables; for example,
    ``gray_product('AB', repeat=3)`` is the same as
    ``gray_product('AB', 'AB', 'AB')``.

    All input iterables are read completely before any output is produced.
    ``ValueError`` is raised if any of them has fewer than two items.

    For details of the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    pools = tuple(map(tuple, iterables)) * repeat
    width = len(pools)
    for pool in pools:
        if len(pool) < 2:
            raise ValueError("each iterable must have two or more items")

    # Knuth's "Algorithm H", section 7.2.1.1, page 20.
    digits = [0] * width  # current index into each pool
    focus = list(range(width + 1))  # focus pointers
    direction = [1] * width  # +1 or -1 step for each position

    while True:
        yield tuple(pool[d] for pool, d in zip(pools, digits))
        pos = focus[0]
        focus[0] = 0
        if pos == width:
            return
        digits[pos] += direction[pos]
        if digits[pos] == 0 or digits[pos] == len(pools[pos]) - 1:
            # This position hit an end of its range: reverse it and pass
            # the focus on to the next position.
            direction[pos] = -direction[pos]
            focus[pos] = focus[pos + 1]
            focus[pos + 1] = pos + 1
def partial_product(*iterables, repeat=1):
    """Yield tuples taking one item from each iterable, changing a single
    position at a time. Each iterable is advanced in turn until it is
    exhausted, so every value of every iterable appears at least once
    without generating the full Cartesian product.

    This can help, for example, when testing an expensive function.

    >>> list(partial_product('AB', 'C', 'DEF'))
    [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]

    The *repeat* keyword argument repeats the iterables; for example,
    ``partial_product('AB', repeat=3)`` is the same as
    ``partial_product('AB', 'AB', 'AB')``.
    """

    pools = tuple(map(tuple, iterables)) * repeat
    iterators = tuple(map(iter, pools))

    # If any iterable is empty, no tuple can be formed at all.
    try:
        current = [next(it) for it in iterators]
    except StopIteration:
        return
    yield tuple(current)

    # Walk each position through its remaining values. Later positions
    # keep their first value until they get their turn.
    for position, it in enumerate(iterators):
        for value in it:
            current[position] = value
            yield tuple(current)
def takewhile_inclusive(predicate, iterable):
    """A variant of :func:`takewhile` that also yields the first element for
    which *predicate* is false.

    >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
    [1, 4, 6]

    :func:`takewhile` would return ``[1, 4]``.
    """
    for item in iterable:
        # Emit the item before testing it, so the failing item is included.
        yield item
        if predicate(item):
            continue
        return
def outer_product(func, xs, ys, *args, **kwargs):
    """A generalized outer product that applies a binary function to all
    pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
    columns.
    Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.

    Multiplication table:

    >>> from operator import mul
    >>> list(outer_product(mul, range(1, 4), range(1, 6)))
    [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]

    Cross tabulation:

    >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
    >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
    >>> pair_counts = Counter(zip(xs, ys))
    >>> count_rows = lambda x, y: pair_counts[x, y]
    >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
    [(2, 3, 0), (1, 0, 4)]

    Usage with ``*args`` and ``**kwargs``:

    >>> animals = ['cat', 'wolf', 'mouse']
    >>> list(outer_product(min, animals, animals, key=len))
    [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]

    If *ys* is empty, each of the ``len(xs)`` rows is an empty tuple.
    Both *xs* and *ys* are consumed when the function is called; *func* is
    applied lazily, one row at a time.
    """
    # Materialize both inputs up front.
    xs = tuple(xs)
    ys = tuple(ys)
    # Build each row directly rather than slicing a flat stream with
    # batched(): batched(..., n=0) raises ValueError, so an empty *ys*
    # would fail instead of producing len(xs) empty rows.
    return (tuple(func(x, y, *args, **kwargs) for y in ys) for x in xs)
def iter_suppress(iterable, *exceptions):
    """Yield every item of *iterable*. If iterating raises one of the given
    *exceptions*, that exception is swallowed and iteration ends quietly.

    >>> from itertools import chain
    >>> def breaks_at_five(x):
    ...     while True:
    ...         if x >= 5:
    ...             raise RuntimeError
    ...         yield x
    ...         x += 1
    >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
    >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
    >>> list(chain(it_1, it_2))
    [1, 2, 3, 4, 2, 3, 4]
    """
    try:
        yield from iterable
    except exceptions:
        # A listed exception simply ends this generator.
        pass
def filter_map(func, iterable):
    """Apply *func* to each element of *iterable* and yield only the results
    that are not ``None``.

    >>> elems = ['1', 'a', '2', 'b', '3']
    >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
    [1, 2, 3]
    """
    # is_not(None, result) is `None is not result`; falsy results such as 0
    # or '' are kept.
    yield from filter(partial(is_not, None), map(func, iterable))
def powerset_of_sets(iterable, *, baseset=set):
    """Yield every subset of the distinct items of *iterable*, in order of
    increasing size.

    >>> list(powerset_of_sets([1, 2, 3]))    # doctest: +SKIP
    [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
    >>> list(powerset_of_sets([1, 1, 0]))    # doctest: +SKIP
    [set(), {1}, {0}, {0, 1}]

    :func:`powerset_of_sets` is written to keep the number of hash
    operations small.

    *baseset* chooses the type of the sets produced: either *set* or
    *frozenset*.
    """
    # zip() wraps each item in a 1-tuple, so every frozenset below is a
    # singleton; dict.fromkeys drops repeats while keeping first-seen order.
    unique = dict.fromkeys(frozenset(single) for single in zip(iterable))
    singletons = tuple(unique)
    make_union = baseset().union
    by_size = (
        starmap(make_union, combinations(singletons, size))
        for size in range(len(singletons) + 1)
    )
    return chain.from_iterable(by_size)
def join_mappings(**field_to_map):
    """
    Merge several mappings on their keys. Each keyword argument names a
    field; the result maps every key to a dict of ``{field: value}``.

    >>> user_scores = {'elliot': 50, 'claris': 60}
    >>> user_times = {'elliot': 30, 'claris': 40}
    >>> join_mappings(score=user_scores, time=user_times)
    {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    joined = {}
    for field_name, mapping in field_to_map.items():
        for key, value in mapping.items():
            joined.setdefault(key, {})[field_name] = value
    return joined
def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.
    """
    # Split both vectors into real and imaginary components.
    a = [z.real for z in v1]
    b = [z.imag for z in v1]
    c = [w.real for w in v2]
    d = [w.imag for w in v2]
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i, summed over all pairs, with
    # each real-valued dot product handed to _fsumprod (presumably the
    # module's high-precision real sumprod; defined elsewhere).
    real_part = _fsumprod(a + [-x for x in b], c + d)
    imag_part = _fsumprod(a + b, d + c)
    return complex(real_part, imag_part)
def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the transformed (frequency-domain) vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
    >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
    >>> all(map(cmath.isclose, dft(xarr), Xarr))
    True

    Inputs must be numeric types that can be added to and multiplied by a
    complex number: int, float, complex and Fraction qualify, Decimal does
    not.

    :func:`idft` computes the inverse transform.
    """
    N = len(xarr)
    # roots[m] = exp(-2*pi*i*m/N); output k pairs input n with roots[k*n % N].
    roots = [e ** (n / N * tau * -1j) for n in range(N)]
    for k in range(N):
        twiddles = [roots[(k * n) % N] for n in range(N)]
        yield _complex_sumprod(xarr, twiddles)
def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the inverse-transformed
    (time-domain) vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
    >>> all(map(cmath.isclose, idft(Xarr), xarr))
    True

    Inputs must be numeric types that can be added to and multiplied by a
    complex number: int, float, complex and Fraction qualify, Decimal does
    not.

    :func:`dft` computes the forward transform.
    """
    N = len(Xarr)
    # roots[m] = exp(+2*pi*i*m/N); the sign is flipped relative to dft().
    roots = [e ** (n / N * tau * 1j) for n in range(N)]
    for k in range(N):
        twiddles = [roots[(k * n) % N] for n in range(N)]
        yield _complex_sumprod(Xarr, twiddles) / N
def doublestarmap(func, iterable):
    """Call *func* on each item of *iterable*, unpacking the item as keyword
    arguments.

    :func:`itertools.starmap` is to ``func(*a)`` what :func:`doublestarmap`
    is to ``func(**a)``.

    >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
    >>> list(doublestarmap(lambda a, b: a + b, iterable))
    [3, 100]

    ``TypeError`` is raised if an item is not a mapping, or if its keys do
    not fit *func*'s signature.
    """
    yield from map(lambda mapping: func(**mapping), iterable)
5206def _nth_prime_bounds(n):
5207 """Bounds for the nth prime (counting from 1): lb < p_n < ub."""
5208 # At and above 688,383, the lb/ub spread is under 0.003 * p_n.
5210 if n < 1:
5211 raise ValueError
5213 if n < 6:
5214 return (n, 2.25 * n)
5216 # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
5217 upper_bound = n * log(n * log(n))
5218 lower_bound = upper_bound - n
5219 if n >= 688_383:
5220 upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))
5222 return lower_bound, upper_bound
def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

    >>> nth_prime(0)
    2
    >>> nth_prime(100)
    547

    With *approximate* set to True, a prime near the nth prime is returned
    instead. This estimate is much faster to compute than the exact answer.

    >>> nth_prime(200_000_000, approximate=True)  # Exact result is 4222234763
    4217820427

    """
    lower, upper = _nth_prime_bounds(n + 1)

    if approximate and n > 1_000_000:
        # Start at the odd number at or just above the midpoint of the bounds
        # and return the first prime found from there upward.
        start = floor((lower + upper) / 2) | 1
        return first_true(count(start, step=2), pred=is_prime)

    # Exact answer: sieve every prime below the upper bound and take the nth.
    return nth(sieve(ceil(upper)), n)
def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

        >>> argmin('efghabcdijkl')
        4
        >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
        3

    For example, look up a label corresponding to the position
    of a value that minimizes a cost function::

        >>> def cost(x):
        ...     "Days for a wound to heal given a subject's age."
        ...     return x**2 - 20*x + 150
        ...
        >>> labels =  ['homer', 'marge', 'bart', 'lisa', 'maggie']
        >>> ages =    [  35,      30,      10,     9,      1    ]

        # Fastest healing family member
        >>> labels[argmin(ages, key=cost)]
        'bart'

        # Age with fastest healing
        >>> min(ages, key=cost)
        10

    """
    values = iterable if key is None else map(key, iterable)
    # min() keeps the first of equal values, so ties go to the lowest index.
    position, _ = min(enumerate(values), key=itemgetter(1))
    return position
def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

        >>> argmax('abcdefghabcd')
        7
        >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
        3

    For example, identify the best machine learning model::

        >>> models =   ['svm', 'random forest', 'knn', 'naïve bayes']
        >>> accuracy = [  68,        61,           84,       72      ]

        # Most accurate model
        >>> models[argmax(accuracy)]
        'knn'

        # Best accuracy
        >>> max(accuracy)
        84

    """
    values = iterable if key is None else map(key, iterable)
    # max() keeps the first of equal values, so ties go to the lowest index.
    position, _ = max(enumerate(values), key=itemgetter(1))
    return position
5312def _extract_monotonic(iterator, indices):
5313 'Non-decreasing indices, lazily consumed'
5314 num_read = 0
5315 for index in indices:
5316 advance = index - num_read
5317 try:
5318 value = next(islice(iterator, advance, None))
5319 except ValueError:
5320 if advance != -1 or index < 0:
5321 raise ValueError(f'Invalid index: {index}') from None
5322 except StopIteration:
5323 raise IndexError(index) from None
5324 else:
5325 num_read += advance + 1
5326 yield value
5329def _extract_buffered(iterator, index_and_position):
5330 'Arbitrary index order, greedily consumed'
5331 buffer = {}
5332 iterator_position = -1
5333 next_to_emit = 0
5335 for index, order in index_and_position:
5336 advance = index - iterator_position
5337 if advance:
5338 try:
5339 value = next(islice(iterator, advance - 1, None))
5340 except StopIteration:
5341 raise IndexError(index) from None
5342 iterator_position = index
5344 buffer[order] = value
5346 while next_to_emit in buffer:
5347 yield buffer.pop(next_to_emit)
5348 next_to_emit += 1
def extract(iterable, indices, *, monotonic=False):
    """Yield values at the specified indices.

    Example:

    >>> data = 'abcdefghijklmnopqrstuvwxyz'
    >>> list(extract(data, [7, 4, 11, 11, 14]))
    ['h', 'e', 'l', 'l', 'o']

    *iterable* is read lazily and may be infinite.

    When *monotonic* is false, *indices* is read in full immediately and so
    must be finite. When *monotonic* is true, *indices* is read lazily and
    may be infinite, but it must be non-decreasing.

    ``IndexError`` is raised for an index past the end of the iterable.
    ``ValueError`` is raised for a negative index, or for a decreasing index
    when *monotonic* is true.
    """

    iterator = iter(iterable)
    indices = iter(indices)

    if monotonic:
        return _extract_monotonic(iterator, indices)

    # Tag each index with its output position, then sort by index so the
    # iterable can be read in one forward pass.
    index_and_position = sorted(zip(indices, count()))
    if index_and_position:
        smallest_index = index_and_position[0][0]
        if smallest_index < 0:
            raise ValueError('Indices must be non-negative')
    return _extract_buffered(iterator, index_and_position)
class serialize:
    """Wrap a non-concurrent iterator with a lock to enforce sequential access.

    One non-reentrant :class:`threading.Lock` is held around every
    ``__next__`` call, and around the generator methods ``send``, ``throw``
    and ``close``. This lets multiple consumer threads share a single
    iterator or generator instance.
    """

    __slots__ = ('_iterator', '_lock')

    def __init__(self, iterable):
        self._iterator = iter(iterable)
        self._lock = Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self._lock:
            item = next(self._iterator)
        return item

    def send(self, value, /):
        """Pass *value* to the wrapped generator's ``send()`` under the lock.

        Raises AttributeError if the wrapped iterator has no ``send`` method
        (for example, when it is not a generator).
        """
        with self._lock:
            result = self._iterator.send(value)
        return result

    def throw(self, *args):
        """Forward *args* to the wrapped generator's ``throw()`` under the lock.

        Raises AttributeError if the wrapped iterator has no ``throw`` method
        (for example, when it is not a generator).
        """
        with self._lock:
            result = self._iterator.throw(*args)
        return result

    def close(self):
        """Call the wrapped generator's ``close()`` under the lock.

        Raises AttributeError if the wrapped iterator has no ``close`` method
        (for example, when it is not a generator).
        """
        with self._lock:
            result = self._iterator.close()
        return result
def synchronized(func):
    """Wrap a callable that returns an iterator so that the iterators it
    returns are thread-safe.

    Existing itertools and more-itertools functions can be wrapped this way,
    so that the iterator instances they create are serialized.

    For example, ``itertools.count`` instances are not thread-safe, but
    that is easily fixed with::

        atomic_counter = synchronized(itertools.count)

    It also works as a decorator on generator function definitions, so that
    each generator instance is serialized::

        @synchronized
        def enumerate_and_timestamp(iterable):
            for count, value in enumerate(iterable):
                yield count, time_ns(), value

    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Every iterator produced by *func* is wrapped in its own lock.
        return serialize(func(*args, **kwargs))

    return wrapper
def concurrent_tee(iterable, n=2):
    """Variant of itertools.tee() but with guaranteed threading semantics.

    Takes an iterator that is not thread-safe and returns *n* concurrent tee
    objects, so that other threads each get a reliable, independent copy of
    the data stream.

    Each new iterator is thread-safe only while it is consumed by a single
    thread. To share one of them across several threads, wrap it with
    :func:`serialize`.
    """

    if n < 0:
        raise ValueError
    if n == 0:
        return ()
    # Each follower copies the leader's source iterator, starting link and
    # lock, so all of them read the same shared stream.
    leader = _concurrent_tee(iterable)
    followers = [_concurrent_tee(leader) for _ in range(n - 1)]
    return (leader, *followers)
5481class _concurrent_tee:
5482 __slots__ = ('iterator', 'link', 'lock')
5484 def __init__(self, iterable):
5485 if isinstance(iterable, _concurrent_tee):
5486 self.iterator = iterable.iterator
5487 self.link = iterable.link
5488 self.lock = iterable.lock
5489 else:
5490 self.iterator = iter(iterable)
5491 self.link = [None, None]
5492 self.lock = Lock()
5494 def __iter__(self):
5495 return self
5497 def __next__(self):
5498 link = self.link
5499 if link[1] is None:
5500 with self.lock:
5501 if link[1] is None:
5502 link[0] = next(self.iterator)
5503 link[1] = [None, None]
5504 value, self.link = link
5505 return value