Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/more_itertools/more.py: 19%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import math
2import types
4from collections import Counter, defaultdict, deque
5from collections.abc import Sequence
6from contextlib import suppress
7from functools import cached_property, partial, wraps
8from heapq import heapify, heapreplace
9from itertools import (
10 chain,
11 combinations,
12 compress,
13 count,
14 cycle,
15 dropwhile,
16 groupby,
17 islice,
18 permutations,
19 repeat,
20 starmap,
21 takewhile,
22 tee,
23 zip_longest,
24 product,
25)
26from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
27from math import ceil, prod
28from queue import Empty, Queue
29from random import random, randrange, shuffle, uniform
30from operator import (
31 attrgetter,
32 getitem,
33 is_not,
34 itemgetter,
35 lt,
36 neg,
37 sub,
38 gt,
39)
40from sys import maxsize
41from time import monotonic
42from threading import Lock
44from .recipes import (
45 _marker,
46 consume,
47 first_true,
48 flatten,
49 is_prime,
50 nth,
51 powerset,
52 sieve,
53 take,
54 unique_everseen,
55 all_equal,
56 batched,
57)
59__all__ = [
60 'AbortThread',
61 'SequenceView',
62 'adjacent',
63 'all_unique',
64 'always_iterable',
65 'always_reversible',
66 'argmax',
67 'argmin',
68 'bucket',
69 'callback_iter',
70 'chunked',
71 'chunked_even',
72 'circular_shifts',
73 'collapse',
74 'combination_index',
75 'combination_with_replacement_index',
76 'concurrent_tee',
77 'consecutive_groups',
78 'constrained_batches',
79 'consumer',
80 'count_cycle',
81 'countable',
82 'derangements',
83 'dft',
84 'difference',
85 'distinct_combinations',
86 'distinct_permutations',
87 'distribute',
88 'divide',
89 'doublestarmap',
90 'duplicates_everseen',
91 'duplicates_justseen',
92 'classify_unique',
93 'exactly_n',
94 'extract',
95 'filter_except',
96 'filter_map',
97 'first',
98 'gray_product',
99 'groupby_transform',
100 'ichunked',
101 'iequals',
102 'idft',
103 'ilen',
104 'interleave',
105 'interleave_evenly',
106 'interleave_longest',
107 'interleave_randomly',
108 'intersperse',
109 'is_sorted',
110 'islice_extended',
111 'iterate',
112 'iter_suppress',
113 'join_mappings',
114 'last',
115 'locate',
116 'longest_common_prefix',
117 'lstrip',
118 'make_decorator',
119 'map_except',
120 'map_if',
121 'map_reduce',
122 'mark_ends',
123 'minmax',
124 'nth_or_last',
125 'nth_permutation',
126 'nth_prime',
127 'nth_product',
128 'nth_combination_with_replacement',
129 'numeric_range',
130 'one',
131 'only',
132 'outer_product',
133 'padded',
134 'partial_product',
135 'partitions',
136 'peekable',
137 'permutation_index',
138 'powerset_of_sets',
139 'product_index',
140 'raise_',
141 'repeat_each',
142 'repeat_last',
143 'replace',
144 'rlocate',
145 'rstrip',
146 'run_length',
147 'sample',
148 'seekable',
149 'serialize',
150 'set_partitions',
151 'side_effect',
152 'sized_iterator',
153 'sliced',
154 'sort_together',
155 'split_after',
156 'split_at',
157 'split_before',
158 'split_into',
159 'split_when',
160 'spy',
161 'stagger',
162 'strip',
163 'strictly_n',
164 'substrings',
165 'substrings_indexes',
166 'synchronized',
167 'takewhile_inclusive',
168 'time_limited',
169 'unique_in_window',
170 'unique_to_each',
171 'unzip',
172 'value_chain',
173 'windowed',
174 'windowed_complete',
175 'with_iter',
176 'zip_broadcast',
177 'zip_offset',
178]
180# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _fsumprod

except ImportError:  # pragma: no cover
    # Fallback used when ``math.sumprod`` cannot be imported.
    #
    # Extended precision algorithms from T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8).  Code: mul12()

    def dl_split(x: float):
        """Split a float into two half-precision components.

        Returns the pair ``(hi, lo)`` where ``lo`` is computed as
        ``x - hi``.
        """
        t = x * 134217729.0  # Veltkamp constant = 2.0 ** 27 + 1
        hi = t - (t - x)
        lo = x - hi
        return hi, lo

    def dl_mul(x, y):
        """Lossless multiplication.

        Returns the pair ``(z, zz)``: the rounded product and a correction
        term built from the split halves of *x* and *y*.
        """
        xx_hi, xx_lo = dl_split(x)
        yy_hi, yy_lo = dl_split(y)
        p = xx_hi * yy_hi
        q = xx_hi * yy_lo + xx_lo * yy_hi
        z = p + q
        zz = p - z + q + xx_lo * yy_lo
        return z, zz

    def _fsumprod(p, q):
        # Multiply *p* and *q* element by element with dl_mul(), flatten the
        # resulting (z, zz) pairs and add everything up with fsum().
        return fsum(chain.from_iterable(map(dl_mul, p, q)))
def chunked(iterable, n, strict=False):
    """Split *iterable* into consecutive lists holding *n* items each:

        >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
        [[1, 2, 3], [4, 5, 6]]

    By default, when the number of items is not a multiple of *n*, the
    final list is simply shorter than the others:

        >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
        [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    With *strict* set to ``True``, a ``ValueError`` is raised instead of
    yielding a short final list. *n* must not be ``None`` in strict mode.

    """
    source = iter(iterable)
    # Two-argument iter(): keep calling take(n, source) until it returns [].
    chunks = iter(partial(take, n, source), [])
    if not strict:
        return chunks

    if n is None:
        raise ValueError('n must not be None when using strict mode.')

    def checked():
        # Pass full-sized chunks through; complain about a short one.
        for piece in chunks:
            if len(piece) != n:
                raise ValueError('iterable is not divisible by n.')
            yield piece

    return checked()
def first(iterable, default=_marker):
    """Return the leading item of *iterable*; fall back to *default* when
    *iterable* has no items.

        >>> first([0, 1, 2, 3])
        0
        >>> first([], 'some default')
        'some default'

    When the iterable is empty and no *default* was supplied,
    ``ValueError`` is raised.

    :func:`first` is handy for grabbing an arbitrary value out of a
    generator whose items are expensive to produce. It is marginally
    shorter than ``next(iter(iterable), default)``.

    """
    # A private sentinel, so that any value the iterable may yield
    # (including ``_marker`` itself) is still recognised as an item.
    missing = object()
    item = next(iter(iterable), missing)
    if item is not missing:
        return item
    if default is not _marker:
        return default
    raise ValueError(
        'first() was called on an empty iterable, '
        'and no default value was provided.'
    )
def last(iterable, default=_marker):
    """Return the final item of *iterable*; fall back to *default* when
    *iterable* has no items.

        >>> last([0, 1, 2, 3])
        3
        >>> last([], 'some default')
        'some default'

    When the iterable is empty and no *default* was supplied,
    ``ValueError`` is raised.
    """
    try:
        if isinstance(iterable, Sequence):
            # Sequences can be indexed from the end directly.
            result = iterable[-1]
        elif getattr(iterable, '__reversed__', None):
            # Work around https://bugs.python.org/issue38525
            result = next(reversed(iterable))
        else:
            # Generic iterables: run through them keeping only one item.
            result = deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        if default is not _marker:
            return default
        raise ValueError(
            'last() was called on an empty iterable, '
            'and no default value was provided.'
        )
    else:
        return result
def nth_or_last(iterable, n, default=_marker):
    """Return the item at position *n* of *iterable*, or its last item when
    it is shorter than that, or *default* when it is empty.

        >>> nth_or_last([0, 1, 2, 3], 2)
        2
        >>> nth_or_last([0, 1], 2)
        1
        >>> nth_or_last([], 0, 'some default')
        'some default'

    When the iterable is empty and no *default* was supplied,
    ``ValueError`` is raised.
    """
    # The item at index n is the last of the first n + 1 items.
    prefix = islice(iterable, n + 1)
    return last(prefix, default=default)
class peekable:
    """Iterator wrapper that supports lookahead and pushing items back on
    the front.

    :meth:`peek` reports the value that :func:`next` is going to return,
    without advancing the iterator:

        >>> p = peekable(['a', 'b'])
        >>> p.peek()
        'a'
        >>> next(p)
        'a'

    Give :meth:`peek` a default to get that value back, rather than a
    ``StopIteration``, once the iterator has run out.

        >>> p = peekable([])
        >>> p.peek('hi')
        'hi'

    The :meth:`prepend` method "inserts" items at the head of the iterable:

        >>> p = peekable([1, 2, 3])
        >>> p.prepend(10, 11, 12)
        >>> next(p)
        10
        >>> p.peek()
        11
        >>> list(p)
        [11, 12, 1, 2, 3]

    Indexing is supported. Index 0 is the item :func:`next` will return,
    index 1 is the one after it, and so on. Everything up to the requested
    index is cached.

        >>> p = peekable(['a', 'b', 'c', 'd'])
        >>> p[0]
        'a'
        >>> p[1]
        'b'
        >>> next(p)
        'a'

    Negative indexes work too, but they cache every remaining item of the
    source iterator, which may require significant storage.

    The truth value of a peekable tells whether it still has items:

        >>> p = peekable(['a', 'b'])
        >>> if p:  # peekable has items
        ...     list(p)
        ['a', 'b']
        >>> if not p:  # peekable is exhausted
        ...     list(p)
        []

    """

    __class_getitem__ = classmethod(types.GenericAlias)

    def __init__(self, iterable):
        self._it = iter(iterable)
        # Items already pulled from the source (or prepended) but not yet
        # handed out by __next__.
        self._cache = deque()

    def __iter__(self):
        return self

    def __next__(self):
        # Cached items always go out before the source is touched again.
        return self._cache.popleft() if self._cache else next(self._it)

    def __bool__(self):
        try:
            self.peek()
        except StopIteration:
            return False
        else:
            return True

    def peek(self, default=_marker):
        """Return the item that the following ``next()`` call will produce.

        When nothing is left, return ``default``; without a ``default``,
        raise ``StopIteration``.

        """
        if self._cache:
            return self._cache[0]

        try:
            upcoming = next(self._it)
        except StopIteration:
            if default is _marker:
                raise
            return default

        self._cache.append(upcoming)
        return self._cache[0]

    def prepend(self, *items):
        """Queue *items* so that they come out of ``next()`` or
        ``self.peek()`` ahead of everything else. They are returned in
        first in, first out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        Prepending can "resurrect" a peekable that has already raised
        ``StopIteration``.

            >>> p = peekable([])
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration
            >>> p.prepend(1)
            >>> next(p)
            1
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration

        """
        # Pushing onto the left in reverse keeps the given order.
        for item in reversed(items):
            self._cache.appendleft(item)

    def _get_slice(self, index):
        # Work out the effective step and the direction-dependent defaults
        step = 1 if (index.step is None) else index.step
        if step > 0:
            default_start, default_stop = 0, maxsize
        elif step < 0:
            default_start, default_stop = -1, (-maxsize - 1)
        else:
            raise ValueError('slice step cannot be zero')
        start = default_start if (index.start is None) else index.start
        stop = default_stop if (index.stop is None) else index.stop

        if (start < 0) or (stop < 0):
            # Slicing relative to the right end needs the whole remainder
            # of the source in the cache.
            self._cache.extend(self._it)
        else:
            # Otherwise cache just far enough to cover the rightmost index.
            needed = min(max(start, stop) + 1, maxsize)
            shortfall = needed - len(self._cache)
            if shortfall >= 0:
                self._cache.extend(islice(self._it, shortfall))

        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        cached = len(self._cache)
        if index < 0:
            # Counting from the end: everything has to be cached first.
            self._cache.extend(self._it)
        elif index >= cached:
            # Pull just enough items to make the index valid.
            self._cache.extend(islice(self._it, index + 1 - cached))

        return self._cache[index]
def consumer(func):
    """Decorator for PEP-342-style "reverse iterators": the generator is
    advanced to its first yield point right away, so callers do not have to
    call ``next()`` on it before sending values.

        >>> @consumer
        ... def tally():
        ...     i = 0
        ...     while True:
        ...         print('Thing number %s is %s.' % (i, (yield)))
        ...         i += 1
        ...
        >>> t = tally()
        >>> t.send('red')
        Thing number 0 is red.
        >>> t.send('fish')
        Thing number 1 is fish.

    Without the decorator, ``next(t)`` would be required before the first
    ``t.send()``.

    """

    @wraps(func)
    def primed(*args, **kwargs):
        generator = func(*args, **kwargs)
        # Run up to the first ``yield`` so the generator accepts send().
        next(generator)
        return generator

    return primed
def ilen(iterable):
    """Count the items in *iterable*.

    For example, there are 168 prime numbers below 1,000:

        >>> ilen(sieve(1000))
        168

    Equivalent to, but faster than::

        def ilen(iterable):
            count = 0
            for _ in iterable:
                count += 1
            return count

    The iterable is consumed completely, so handle with care.

    """
    # This is the "most beautiful of the fast variants" of this function.
    # If you think you can improve on it, please ensure that your version
    # is both 10x faster and 10x more beautiful.
    ones = repeat(1)
    # zip() wraps each item in a 1-tuple; a non-empty tuple is always
    # truthy, so compress() lets through exactly one 1 per item.
    wrapped_items = zip(iterable)
    return sum(compress(ones, wrapped_items))
def iterate(func, start):
    """Yield ``start``, ``func(start)``, ``func(func(start))``, ...

    The resulting iterator is infinite unless *func* raises
    ``StopIteration``, which ends it. To add a stopping condition, use
    :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`.

    >>> take(10, iterate(lambda x: 2*x, 1))
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2
    >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10)))
    [10, 5, 16, 8, 4, 2, 1]

    """
    current = start
    try:
        while True:
            yield current
            current = func(current)
    except StopIteration:
        # Treated as the signal to finish quietly.
        return
def with_iter(context_manager):
    """Iterate over what *context_manager* provides while keeping it inside
    a ``with`` statement, so that it is closed once iteration finishes.

    For example, this will close the file when the iterator is exhausted::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Note that the iterator has to be exhausted for an opened file to be
    closed.

    Any context manager whose ``__enter__`` returns an iterable can be used
    with ``with_iter``.

    """
    with context_manager as managed:
        yield from managed
class sized_iterator:
    """Iterator over *iterable* that also answers ``len()``.

        >>> it = map(str, range(5))
        >>> sized_it = sized_iterator(it, 5)
        >>> len(sized_it)
        5
        >>> list(sized_it)
        ['0', '1', '2', '3', '4']

    This is useful for tools that use :func:`len`, like
    `tqdm <https://pypi.org/project/tqdm/>`__ .

    The given *length* is stored and reported as-is; it is never checked
    against the iterable, so be sure to choose a value that reflects
    reality.
    """

    def __init__(self, iterable, length):
        self._length = length
        self._iterator = iter(iterable)

    def __len__(self):
        # Always the value passed to the constructor.
        return self._length

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)
def one(iterable, too_short=None, too_long=None):
    """Return the single item of *iterable*. An exception is raised when
    *iterable* is empty or holds more than one item.

    :func:`one` is a convenient way to assert that an iterable contains
    exactly one item, for instance when fetching the result of a database
    query that should match a single row.

    An empty *iterable* raises ``ValueError``. Pass *too_short* to raise
    something else:

        >>> it = []
        >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: too few items in iterable (expected 1)
        >>> too_short = IndexError('too few items')
        >>> one(it, too_short=too_short)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        IndexError: too few items

    An *iterable* with more than one item also raises ``ValueError``. Pass
    *too_long* to raise something else:

        >>> it = ['too', 'many']
        >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Expected exactly one item in iterable, but got 'too',
        'many', and perhaps more.
        >>> too_long = RuntimeError
        >>> one(it, too_long=too_long)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        RuntimeError

    Note that :func:`one` attempts to advance *iterable* twice to make sure
    there is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    iterator = iter(iterable)
    missing = object()  # distinguishes "no item" from any real item

    first = next(iterator, missing)
    if first is missing:
        raise too_short or ValueError('too few items in iterable (expected 1)')

    second = next(iterator, missing)
    if second is missing:
        return first

    msg = (
        f'Expected exactly one item in iterable, but got {first!r}, '
        f'{second!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
def raise_(exception, *args):
    """Call *exception* with *args* and raise the result.

    Being a function, this can be used where a ``raise`` statement is not
    allowed, such as inside a ``lambda``.
    """
    error = exception(*args)
    raise error
def strictly_n(iterable, n, too_short=None, too_long=None):
    """Yield the items of *iterable* while checking that there are exactly
    *n* of them. With fewer than *n* items, the function *too_short* is
    called with the number of items found. With more than *n* items, the
    function *too_long* is called with the number ``n + 1``.

        >>> iterable = ['a', 'b', 'c', 'd']
        >>> n = 4
        >>> list(strictly_n(iterable, n))
        ['a', 'b', 'c', 'd']

    Note that the check only happens as the returned iterable is consumed.

    By default, *too_short* and *too_long* are functions that raise
    ``ValueError``.

        >>> list(strictly_n('ab', 3))  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: too few items in iterable (got 2)

        >>> list(strictly_n('abc', 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: too many items in iterable (got at least 3)

    You can instead supply functions that do something else.
    *too_short* will be called with the number of items in *iterable*.
    *too_long* will be called with `n + 1`.

        >>> def too_short(item_count):
        ...     raise RuntimeError
        >>> it = strictly_n('abcd', 6, too_short=too_short)
        >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        RuntimeError

        >>> def too_long(item_count):
        ...     print('The boss is going to hear about this')
        >>> it = strictly_n('abcdef', 4, too_long=too_long)
        >>> list(it)
        The boss is going to hear about this
        ['a', 'b', 'c', 'd']

    """
    if too_short is None:

        def too_short(item_count):
            raise ValueError(f'Too few items in iterable (got {item_count})')

    if too_long is None:

        def too_long(item_count):
            raise ValueError(
                f'Too many items in iterable (got at least {item_count})'
            )

    iterator = iter(iterable)

    # Hand out at most n items, keeping track of how many there were.
    yielded = 0
    for yielded, item in enumerate(islice(iterator, n), 1):
        yield item

    if yielded < n:
        too_short(yielded)
        return

    # One further item is enough to know the iterable is too long.
    for _ in iterator:
        too_long(n + 1)
        return
def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

        >>> sorted(distinct_permutations([1, 0, 1]))
        [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

        >>> sorted(distinct_permutations([1, 0, 1], r=2))
        [(0, 1), (1, 0), (1, 1)]
        >>> sorted(distinct_permutations(range(3), r=2))
        [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

        >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
        [
            (1, True, '3'),
            (1, '3', True),
            ('3', 1, True)
        ]
        >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
        [
            (1, 2, '3'),
            (1, '3', 2),
            (2, 1, '3'),
            (2, '3', 1),
            ('3', 1, 2),
            ('3', 2, 1)
        ]
    """

    # Algorithm: https://w.wiki/Qai
    # Works in place on the list A; reads ``size`` from the enclosing scope.
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than i such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above
    def _partial(A, r):
        # Split A into the first r items (head) and the remaining items (tail)
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    try:
        items.sort()
        sortable = True
    except TypeError:
        sortable = False

        # The items cannot be ordered, so permute integer stand-ins instead:
        # each item is represented by the index of the first item equal to it.
        indices_dict = defaultdict(list)

        for item in items:
            indices_dict[items.index(item)].append(item)

        indices = [items.index(item) for item in items]
        indices.sort()

        # For every stand-in index, cycle through the items it represents.
        equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

        def permuted_items(permuted_indices):
            # Translate a permutation of stand-in indexes back into items.
            return tuple(
                next(equivalent_items[index]) for index in permuted_indices
            )

    size = len(items)
    if r is None:
        r = size

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if 0 < r <= size:
        if sortable:
            return algorithm(items)
        else:
            return (
                permuted_items(permuted_indices)
                for permuted_indices in algorithm(indices)
            )

    # r == 0 yields a single empty permutation; any other r outside
    # 1..size yields nothing.
    return iter(() if r else ((),))
def derangements(iterable, r=None):
    """Yield successive derangements of the elements in *iterable*.

    A derangement is a permutation that leaves no element at its original
    index, that is, a permutation without fixed points.

    Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
    The code below outputs all of the different ways to assign gift
    recipients such that nobody is assigned to himself or herself:

        >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
        ...    print(', '.join(d))
        Bob, Alice, Dave, Carol
        Bob, Carol, Dave, Alice
        Bob, Dave, Alice, Carol
        Carol, Alice, Dave, Bob
        Carol, Dave, Alice, Bob
        Carol, Dave, Bob, Alice
        Dave, Alice, Bob, Carol
        Dave, Carol, Alice, Bob
        Dave, Carol, Bob, Alice

    If *r* is given, only the *r*-length derangements are yielded.

        >>> sorted(derangements(range(3), 2))
        [(1, 0), (1, 2), (2, 0)]
        >>> sorted(derangements([0, 2, 3], 2))
        [(2, 0), (2, 3), (3, 0)]

    Elements are told apart by position, not by value.

    Consider the Secret Santa example with two *different* people who have
    the *same* name. Then there are two valid gift assignments even though
    it might appear that a person is assigned to themselves:

        >>> names = ['Alice', 'Bob', 'Bob']
        >>> list(derangements(names))
        [('Bob', 'Bob', 'Alice'), ('Bob', 'Alice', 'Bob')]

    To avoid confusion, make the inputs distinct:

        >>> deduped = [f'{name}{index}' for index, name in enumerate(names)]
        >>> list(derangements(deduped))
        [('Bob1', 'Bob2', 'Alice0'), ('Bob2', 'Alice0', 'Bob1')]

    The number of derangements of a set of size *n* is known as the
    "subfactorial of n".  For n > 0, the subfactorial is:
    ``round(math.factorial(n) / math.e)``.

    References:

    * Article:  https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics
    * Sizes:  https://oeis.org/A000166
    """
    elements = tuple(iterable)
    positions = tuple(range(len(elements)))

    # Permute the positions in lockstep with the elements; a permutation is
    # kept only when every slot holds a position other than its own.
    position_perms = permutations(positions, r=r)
    keep = (all(map(is_not, positions, perm)) for perm in position_perms)
    return compress(permutations(elements, r=r), keep)
def intersperse(e, iterable, n=1):
    """Intersperse filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

        >>> list(intersperse('!', [1, 2, 3, 4, 5]))
        [1, '!', 2, '!', 3, '!', 4, '!', 5]

        >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
        [1, 2, None, 3, 4, None, 5]

    *n* must be positive. A zero or negative *n* raises ``ValueError``
    immediately, when the function is called.

    """
    # Reject every non-positive n here. A negative n would otherwise only
    # fail later, during iteration, with an unrelated slicing error.
    if n <= 0:
        raise ValueError('n must be > 0')

    if n == 1:
        # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2...
        # islice(..., 1, None) -> x_0, e, x_1, e, x_2...
        return islice(interleave(repeat(e), iterable), 1, None)

    # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]...
    # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]...
    # flatten(...) -> x_0, x_1, e, x_2, x_3...
    filler = repeat([e])
    chunks = chunked(iterable, n)
    return flatten(islice(interleave(filler, chunks), 1, None))
def unique_to_each(*iterables):
    """For each input iterable, return the elements that appear in none of
    the other input iterables.

    For example, suppose you have a set of packages, each with a set of
    dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    If you remove one package, which dependencies can also be removed?

    If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
    associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed
    for ``pkg_2``, and ``D`` is only needed for ``pkg_3``::

        >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
        [['A'], ['C'], ['D']]

    An element repeated within one input, and absent from the others, is
    repeated in the output as well. Input order is preserved::

        >>> unique_to_each("mississippi", "missouri")
        [['p', 'p'], ['o', 'u', 'r']]

    The elements of each iterable are assumed to be hashable.

    """
    groups = list(map(list, iterables))

    # Count, for every element, how many of the inputs contain it.
    membership = Counter()
    for group in groups:
        membership.update(set(group))

    exclusive = {item for item, seen in membership.items() if seen == 1}
    return [[item for item in group if item in exclusive] for group in groups]
def windowed(seq, n, fillvalue=None, step=1):
    """Yield tuples of width *n* that slide over the given iterable.

        >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
        >>> list(all_windows)
        [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    If the iterable is shorter than the window, the missing positions are
    filled with *fillvalue*:

        >>> list(windowed([1, 2, 3], 4))
        [(1, 2, 3, None)]

    The window moves forward by *step* items each time:

        >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
        [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler items
    to the left:

        >>> iterable = [1, 2, 3, 4]
        >>> n = 3
        >>> padding = [None] * (n - 1)
        >>> list(windowed(chain(padding, iterable), 3))
        [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    if n <= 0:
        raise ValueError('n must be > 0')
    if step < 1:
        raise ValueError('step must be >= 1')

    source = iter(seq)

    # First window: whatever the first n items are.
    window = deque(islice(source, n), maxlen=n)
    filled = len(window)
    if filled == 0:
        return
    if filled < n:
        # The input is shorter than one window: pad it and stop.
        yield tuple(window) + ((fillvalue,) * (n - filled))
        return
    yield tuple(window)

    # Trailing padding, so that leftover items still form a final window
    # without producing a window made of padding only.
    pad_count = n - 1 if step >= n else step - 1
    padding = (fillvalue,) * pad_count

    # Push items through the bounded deque and emit a window after every
    # *step* of them; a trailing partial step emits nothing.
    since_last = 0
    for item in chain(source, padding):
        window.append(item)
        since_last += 1
        if since_last == step:
            since_last = 0
            yield tuple(window)
def substrings(iterable):
    """Yield every contiguous run of items of *iterable*, as tuples.

        >>> [''.join(s) for s in substrings('more')]
        ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    The input does not have to be a string.

        >>> list(substrings([0, 1, 2]))
        [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    Like subslices() but returns tuples instead of lists
    and returns the shortest substrings first.

    """
    items = tuple(iterable)
    total = len(items)
    for width in range(1, total + 1):
        # All start positions at which a run of this width still fits.
        spans = (
            slice(start, start + width) for start in range(total - width + 1)
        )
        yield from map(items.__getitem__, spans)
def substrings_indexes(seq, reverse=False):
    """Yield every substring of *seq* together with its position.

    Each yielded item is a tuple ``(substr, i, j)`` such that
    ``substr == seq[i:j]``.

    *seq* has to support ``len()`` and slicing, as ``str`` objects do.

    >>> for item in substrings_indexes('more'):
    ...    print(item)
    ('m', 0, 1)
    ('o', 1, 2)
    ('r', 2, 3)
    ('e', 3, 4)
    ('mo', 0, 2)
    ('or', 1, 3)
    ('re', 2, 4)
    ('mor', 0, 3)
    ('ore', 1, 4)
    ('more', 0, 4)

    Set *reverse* to ``True`` to go through the substring lengths from
    longest to shortest. Within one length the start positions still run
    from left to right.

    """
    longest = len(seq)
    if reverse:
        lengths = range(longest, 0, -1)
    else:
        lengths = range(1, longest + 1)
    return (
        (seq[start : start + length], start, start + length)
        for length in lengths
        for start in range(len(seq) - length + 1)
    )
class bucket:
    """Wrap *iterable* in an object that sorts its items into child
    iterables according to a *key* function.

        >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
        >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
        >>> sorted(list(s))  # Get the keys
        ['a', 'b', 'c']
        >>> a_iterable = s['a']
        >>> next(a_iterable)
        'a1'
        >>> next(a_iterable)
        'a2'
        >>> list(s['b'])
        ['b1', 'b2', 'b3']

    The original iterable is advanced as needed, and items are cached until
    the child iterables use them. This may require significant storage.

    By default, selecting a bucket that no item belongs to exhausts the
    iterable and caches all values.
    If a *validator* function is given, selected buckets are checked
    against it instead.

        >>> from itertools import count
        >>> it = count(1, 2)  # Infinite sequence of odd numbers
        >>> key = lambda x: x % 10  # Bucket by last digit
        >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
        >>> s = bucket(it, key=key, validator=validator)
        >>> 2 in s
        False
        >>> list(s[2])
        []

    .. seealso:: :func:`map_reduce`, :func:`groupby_transform`

        If storage is not a concern, :func:`map_reduce` returns a Python
        dictionary, which is generally easier to work with. If the elements
        with the same key are already adjacent, :func:`groupby_transform`
        or :func:`itertools.groupby` can be used without any caching overhead.

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        # Maps a key value to the items read so far that are waiting for
        # their child iterable.
        self._cache = defaultdict(deque)
        # Without a validator every key value is acceptable.
        self._validator = validator or (lambda x: True)

    def __getitem__(self, value):
        if self._validator(value):
            return self._get_values(value)

        return iter(())

    def __contains__(self, value):
        if not self._validator(value):
            return False

        try:
            probe = next(self[value])
        except StopIteration:
            return False

        # The probe took one item out of the bucket; put it back in front.
        self._cache[value].appendleft(probe)
        return True

    def __iter__(self):
        # Drain the source into the cache, then iterate the cache's keys.
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def _get_values(self, value):
        """
        Generator behind each child iterable: yields the items whose key
        equals *value*. Items with other keys met along the way are kept in
        the cache, provided the validator accepts their key.
        """
        while True:
            # Serve from the cache first, removing what is handed out.
            if self._cache[value]:
                yield self._cache[value].popleft()
                continue

            # Nothing cached: read the source until a match turns up,
            # caching the rest.
            for item in self._it:
                item_value = self._key(item)
                if item_value == value:
                    yield item
                    break
                if self._validator(item_value):
                    self._cache[item_value].append(item)
            else:
                # The source ran out without a match.
                return
def spy(iterable, n=1):
    """Peek at the first *n* items of *iterable* without losing them.

    Returns a 2-tuple ``(head, iterator)``: *head* is a list holding the
    first *n* items and *iterator* yields every item of *iterable*,
    including those in *head*.

    By default the list contains one item:

        >>> iterable = 'abcdefg'
        >>> head, iterable = spy(iterable)
        >>> head
        ['a']
        >>> list(iterable)
        ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    Unpacking can be used to get the items themselves rather than a list:

        >>> (head,), iterable = spy('abcdefg')
        >>> head
        'a'
        >>> (first, second), iterable = spy('abcdefg', 2)
        >>> first
        'a'
        >>> second
        'b'

    Asking for more items than the iterable holds is fine; the list is
    simply shorter:

        >>> iterable = [1, 2, 3, 4, 5]
        >>> head, iterable = spy(iterable, 10)
        >>> head
        [1, 2, 3, 4, 5]
        >>> list(iterable)
        [1, 2, 3, 4, 5]

    """
    # One tee branch is read ahead for the preview; the other is returned
    # untouched so it still starts from the first item.
    untouched, lookahead = tee(iterable)
    preview = list(islice(lookahead, n))
    return preview, untouched
def interleave(*iterables):
    """Yield one item from each of *iterables* in turn, stopping as soon as
    the shortest one runs out.

        >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
        [1, 4, 6, 2, 5, 7]

    See :func:`interleave_longest` for a variant that keeps going after the
    shortest iterable is exhausted.

    """
    # Each tuple produced by zip is one "round" across all inputs; chaining
    # the rounds flattens them into a single stream.
    rounds = zip(*iterables)
    return chain.from_iterable(rounds)
def interleave_longest(*iterables):
    """Yield one item from each of *iterables* in turn, skipping over any
    that have already been exhausted.

        >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
        [1, 4, 6, 2, 5, 7, 3, 8]

    The output matches that of :func:`roundrobin`, but this function may
    perform better for some inputs (in particular when the number of
    iterables is large).

    """
    # Exhausted inputs are padded with the private sentinel, which is then
    # filtered out of every round.
    for round_items in zip_longest(*iterables, fillvalue=_marker):
        yield from (item for item in round_items if item is not _marker)
def interleave_evenly(iterables, lengths=None):
    """
    Interleave multiple iterables so that their elements are evenly distributed
    throughout the output sequence.

    >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
    >>> list(interleave_evenly(iterables))
    [1, 2, 'a', 3, 4, 'b', 5]

    >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
    >>> list(interleave_evenly(iterables))
    [1, 6, 4, 2, 7, 3, 8, 5]

    This function requires iterables of known length. Iterables without
    ``__len__()`` can be used by manually specifying lengths with *lengths*:

    >>> from itertools import combinations, repeat
    >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
    >>> lengths = [4 * (4 - 1) // 2, 3]
    >>> list(interleave_evenly(iterables, lengths=lengths))
    [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    An empty collection of iterables produces no items:

    >>> list(interleave_evenly([]))
    []

    ``ValueError`` is raised when iteration starts if the lengths cannot be
    determined automatically, or if *lengths* is given but its size differs
    from the number of iterables.

    Based on Bresenham's algorithm.
    """
    if lengths is None:
        try:
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )
    elif len(iterables) != len(lengths):
        raise ValueError('Mismatching number of iterables and lengths.')

    dims = len(lengths)
    if dims == 0:
        # Nothing to interleave. Without this guard, picking the primary
        # iterable below would fail with IndexError.
        return

    # sort iterables by length, descending
    lengths_permute = sorted(
        range(dims), key=lambda i: lengths[i], reverse=True
    )
    lengths_desc = [lengths[i] for i in lengths_permute]
    iters_desc = [iter(iterables[i]) for i in lengths_permute]

    # the longest iterable is the primary one (Bresenham: the longest
    # distance along an axis)
    delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
    iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
    errors = [delta_primary // dims] * len(deltas_secondary)

    to_yield = sum(lengths)
    while to_yield:
        yield next(iter_primary)
        to_yield -= 1
        # update errors for each secondary iterable
        errors = [
            error - delta for error, delta in zip(errors, deltas_secondary)
        ]

        # those iterables for which the error is negative are yielded
        # ("diagonal step" in Bresenham)
        for i, error in enumerate(errors):
            if error < 0:
                yield next(iters_secondary[i])
                to_yield -= 1
                errors[i] += delta_primary
def interleave_randomly(*iterables):
    """Yield items by repeatedly picking one of the input *iterables* at
    random and taking its next item.

        >>> iterables = [1, 2, 3], 'abc', (True, False, None)
        >>> list(interleave_randomly(*iterables))  # doctest: +SKIP
        ['a', 'b', 1, 'c', True, False, None, 2, 3]

    The relative order of the items from each input iterable is preserved.
    Note that the possible output sequences with this property are not
    equally likely to be generated.

    """
    live = [iter(source) for source in iterables]
    while live:
        pick = randrange(len(live))
        try:
            yield next(live[pick])
        except StopIteration:
            # Drop the exhausted iterator by moving the last one into its
            # slot; this avoids shifting the tail of the list.
            tail = live.pop()
            if pick < len(live):
                live[pick] = tail
def collapse(iterable, base_type=None, levels=None):
    """Flatten an arbitrarily nested iterable (e.g., a list of lists of
    tuples) into a stream of non-iterable items.

        >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
        >>> list(collapse(iterable))
        [1, 2, 3, 4, 5, 6]

    Text and binary strings are treated as atoms and are never collapsed.

    Pass *base_type* to protect additional types from being collapsed:

        >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
        >>> list(collapse(iterable, base_type=tuple))
        ['ab', ('cd', 'ef'), 'gh', 'ij']

    Pass *levels* to stop flattening below a given nesting depth:

        >>> iterable = [('a', ['b']), ('c', ['d'])]
        >>> list(collapse(iterable))  # Fully flattened
        ['a', 'b', 'c', 'd']
        >>> list(collapse(iterable, levels=1))  # Only one level flattened
        ['a', ['b'], 'c', ['d']]

    """

    def is_atom(obj):
        # Strings would otherwise recurse forever (their items are strings).
        if isinstance(obj, (str, bytes)):
            return True
        return (base_type is not None) and isinstance(obj, base_type)

    # Explicit depth-first stack of (depth, iterator) frames. The input
    # itself is wrapped as the single node of the depth-0 frame.
    pending = [(0, repeat(iterable, 1))]

    while pending:
        frame = pending.pop()
        depth, children = frame

        # Past the requested depth: emit whatever is left untouched.
        if levels is not None and depth > levels:
            yield from children
            continue

        for child in children:
            if is_atom(child):
                yield child
                continue

            try:
                subtree = iter(child)
            except TypeError:
                # Not iterable, so it is a leaf.
                yield child
                continue

            # Descend: remember where we were, then process the subtree
            # first on the next turn of the outer loop.
            pending.append(frame)
            pending.append((depth + 1, subtree))
            break
def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Call *func* on each item of *iterable* (or on each group of
    *chunk_size* items) just before the item is yielded.

    *func* must accept a single argument. Whatever it returns is ignored.

    *before* and *after* are optional zero-argument callables, run before
    iteration starts and after it ends, respectively.

    `side_effect` suits logging, progress bars, and anything else that is
    not functionally "pure."

    Emitting a status message:

        >>> from more_itertools import consume
        >>> func = lambda item: print('Received {}'.format(item))
        >>> consume(side_effect(func, range(2)))
        Received 0
        Received 1

    Operating on chunks of items:

        >>> pair_sums = []
        >>> func = lambda chunk: pair_sums.append(sum(chunk))
        >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
        [0, 1, 2, 3, 4, 5]
        >>> list(pair_sums)
        [1, 5, 9]

    Writing to a file-like object:

        >>> from io import StringIO
        >>> from more_itertools import consume
        >>> f = StringIO()
        >>> func = lambda x: print(x, file=f)
        >>> before = lambda: print(u'HEADER', file=f)
        >>> after = f.close
        >>> it = [u'a', u'b', u'c']
        >>> consume(side_effect(func, it, before=before, after=after))
        >>> f.closed
        True

    """
    try:
        if before is not None:
            before()

        if chunk_size is not None:
            # Chunked mode: func sees the whole group, then its members
            # are passed through one at a time.
            for group in chunked(iterable, chunk_size):
                func(group)
                yield from group
        else:
            for element in iterable:
                func(element)
                yield element
    finally:
        # Runs on normal completion and also when iteration is cut short
        # or an exception propagates.
        if after is not None:
            after()
def sliced(seq, n, strict=False):
    """Yield consecutive slices of length *n* taken from the sequence *seq*.

        >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
        [(1, 2, 3), (4, 5, 6)]

    By default, when the length of *seq* is not divisible by *n*, the last
    slice yielded is shorter than *n*:

        >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
        [(1, 2, 3), (4, 5, 6), (7, 8)]

    With *strict* set to ``True``, a length that is not divisible by *n*
    causes ``ValueError`` to be raised before the short final slice would
    be yielded.

    Only objects that support slicing can be used. For iterables that do
    not, see :func:`chunked`.

    """
    # Slice at offsets 0, n, 2n, ... and stop at the first empty slice,
    # which marks the end of the sequence.
    offsets = count(0, n)
    windows = map(lambda start: seq[start : start + n], offsets)
    non_empty = takewhile(len, windows)

    if not strict:
        return non_empty

    def checked():
        for window in non_empty:
            if len(window) != n:
                raise ValueError("seq is not divisible by n.")
            yield window

    return checked()
def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Yield lists of items from *iterable*, breaking at every item for
    which the callable *pred* returns ``True``.

        >>> list(split_at('abcdcba', lambda x: x == 'b'))
        [['a'], ['c', 'd', 'c'], ['a']]

        >>> list(split_at(range(10), lambda n: n % 2 == 1))
        [[0], [2], [4], [6], [8], []]

    No more than *maxsplit* splits are made. When *maxsplit* is omitted or
    -1, the number of splits is unlimited:

        >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
        [[0], [2], [4, 5, 6, 7, 8, 9]]

    The delimiting items are dropped from the output by default. Set
    *keep_separator* to ``True`` to emit each one as its own list.

        >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
        [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    splits_left = maxsplit
    source = iter(iterable)
    group = []
    for element in source:
        if not pred(element):
            group.append(element)
            continue

        # Separator found: close the current group.
        yield group
        if keep_separator:
            yield [element]
        if splits_left == 1:
            # Last permitted split: everything else is one final group.
            yield list(source)
            return
        group = []
        splits_left -= 1

    yield group
def split_before(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, starting a new list just before
    each item for which the callable *pred* returns ``True``:

        >>> list(split_before('OneTwo', lambda s: s.isupper()))
        [['O', 'n', 'e'], ['T', 'w', 'o']]

        >>> list(split_before(range(10), lambda n: n % 3 == 0))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    No more than *maxsplit* splits are made. When *maxsplit* is omitted or
    -1, the number of splits is unlimited:

        >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    splits_left = maxsplit
    source = iter(iterable)
    group = []
    for element in source:
        # pred is evaluated for every item; a match only splits when there
        # is something to emit, so no empty leading list is produced.
        starts_group = pred(element)
        if starts_group and group:
            yield group
            if splits_left == 1:
                # Last permitted split: the match plus everything after it
                # form the final group.
                remainder = [element]
                remainder.extend(source)
                yield remainder
                return
            group = []
            splits_left -= 1
        group.append(element)

    if group:
        yield group
def split_after(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, closing each list right after
    an item for which the callable *pred* returns ``True``:

        >>> list(split_after('one1two2', lambda s: s.isdigit()))
        [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

        >>> list(split_after(range(10), lambda n: n % 3 == 0))
        [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    No more than *maxsplit* splits are made. When *maxsplit* is omitted or
    -1, the number of splits is unlimited:

        >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
        [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    splits_left = maxsplit
    source = iter(iterable)
    group = []
    for element in source:
        group.append(element)
        if not pred(element):
            continue

        # The matching item terminates the current group.
        yield group
        if splits_left == 1:
            # Last permitted split: whatever remains is one final group,
            # emitted only when it is non-empty.
            remainder = list(source)
            if remainder:
                yield remainder
            return
        group = []
        splits_left -= 1

    if group:
        yield group
def split_when(iterable, pred, maxsplit=-1):
    """Split *iterable* into lists wherever *pred* says so.

    *pred* is called with each pair of successive items and should return
    ``True`` when the iterable is to be split between them.

    For example, to extract runs of increasing numbers, split whenever
    element ``i`` is larger than element ``i + 1``:

        >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
        [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    No more than *maxsplit* splits are made. When *maxsplit* is omitted or
    -1, the number of splits is unlimited:

        >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
        ...                 lambda x, y: x > y, maxsplit=2))
        [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    source = iter(iterable)
    nothing = object()
    previous = next(source, nothing)
    if previous is nothing:
        # Empty input yields no groups at all.
        return

    splits_left = maxsplit
    group = [previous]
    for current in source:
        if pred(previous, current):
            yield group
            if splits_left == 1:
                # Last permitted split: the rest forms the final group.
                yield [current, *source]
                return
            group = []
            splits_left -= 1

        group.append(current)
        previous = current

    yield group
def split_into(iterable, sizes):
    """For each integer ``n`` in *sizes*, yield a list of the next ``n``
    items from *iterable*.

        >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
        [[1], [2, 3], [4, 5, 6]]

    When *sizes* add up to less than the length of *iterable*, the leftover
    items are not returned.

        >>> list(split_into([1,2,3,4,5,6], [2,3]))
        [[1, 2], [3, 4, 5]]

    When *sizes* add up to more than the length of *iterable*, the list that
    overruns the input comes back short and all later lists are empty:

        >>> list(split_into([1,2,3,4], [1,2,3,4]))
        [[1], [2, 3], [4], []]

    A ``None`` in *sizes* means "everything that is left", the same way a
    ``None`` stop works for :func:`itertools.islice`. Nothing is yielded
    for any sizes after it:

        >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
        [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    :func:`split_into` is handy for grouping a series of items into groups
    of non-uniform size, for example a table row in which several columns
    together describe one feature (such as a point given as x, y, z) while
    other columns follow a different layout.
    """
    # Work from a single shared iterator so each islice call continues where
    # the previous one stopped, even when the input is a generator.
    source = iter(iterable)

    for width in sizes:
        if width is None:
            yield list(source)
            return
        yield list(islice(source, width))
def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the items of *iterable*, then *fillvalue* as often as needed so
    that at least *n* items are emitted.

        >>> list(padded([1, 2, 3], '?', 5))
        [1, 2, 3, '?', '?']

    With *next_multiple* set to ``True``, *fillvalue* is emitted until the
    total number of items is a multiple of *n*:

        >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
        [1, 2, 3, 4, None, None]

    When *n* is ``None``, *fillvalue* is emitted forever.

    To get an iterable of exactly *n* items, truncate the result with
    :func:`islice`.

        >>> list(islice(padded([1, 2, 3], '?'), 5))
        [1, 2, 3, '?', '?']
        >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
        [1, 2, 3, 4, 5]

    ``ValueError`` is raised if *n* is given and is less than 1.

    """
    source = iter(iterable)
    # Shares `source`: once the real items run out it supplies fill values.
    endless = chain(source, repeat(fillvalue))

    if n is None:
        return endless

    if n < 1:
        raise ValueError('n must be at least 1')

    if not next_multiple:
        # The first n items may include padding; any real items beyond that
        # follow unpadded.
        return chain(islice(endless, n), source)

    def groups():
        # A group is started only while a real item is available, and is
        # then topped up to n items, padding if the source runs dry.
        for leader in source:
            yield (leader,)
            yield islice(endless, n - 1)

    return chain.from_iterable(groups())
def repeat_each(iterable, n=2):
    """Yield every element of *iterable* *n* times in a row.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    # One n-long run per element, flattened into a single stream.
    runs = map(lambda element: repeat(element, n), iterable)
    return chain.from_iterable(runs)
def repeat_last(iterable, default=None):
    """Yield the items of *iterable*, then keep yielding the last one
    forever.

        >>> list(islice(repeat_last(range(3)), 5))
        [0, 1, 2, 2, 2]

    When the iterable is empty, *default* is yielded forever instead::

        >>> list(islice(repeat_last(range(0), 42), 5))
        [42, 42, 42, 42, 42]

    """
    # The loop variable keeps the sentinel only if the loop body never ran.
    last = _marker
    for last in iterable:
        yield last
    yield from repeat(default if last is _marker else last)
def distribute(n, iterable):
    """Deal the items of *iterable* out, round-robin, into *n* smaller
    iterables.

        >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
        >>> list(group_1)
        [1, 3, 5]
        >>> list(group_2)
        [2, 4, 6]

    When the length of *iterable* is not evenly divisible by *n*, the
    returned iterables differ in length:

        >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
        >>> [list(c) for c in children]
        [[1, 4, 7], [2, 5], [3, 6]]

    When *iterable* has fewer than *n* items, the last returned iterables
    are empty:

        >>> children = distribute(5, [1, 2, 3])
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]

    This function relies on :func:`itertools.tee` and may require
    significant storage.

    If the items within each smaller iterable should stay in contiguous
    original order, see :func:`divide`.

    ``ValueError`` is raised if *n* is less than 1.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    parts = []
    # Copy number k of the input keeps items k, k + n, k + 2n, ...
    for offset, copy in enumerate(tee(iterable, n)):
        parts.append(islice(copy, offset, None, n))
    return parts
def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples whose elements are taken from *iterable* at shifted
    positions.

    The `i`-th element of each tuple is offset by the `i`-th value in
    *offsets*.

        >>> list(stagger([0, 1, 2, 3]))
        [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
        >>> list(stagger(range(8), offsets=(0, 2, 4)))
        [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default, output stops once the final element of a tuple is the last
    item of the iterable. Set *longest* to ``True`` to continue until the
    first element of a tuple is the last item of the iterable::

        >>> list(stagger([0, 1, 2, 3], longest=True))
        [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    Positions that fall outside the sequence are filled with ``None`` by
    default. Pass *fillvalue* to use a different value.

    """
    # One independent copy of the input per offset; the shifting itself is
    # delegated to zip_offset.
    copies = tee(iterable, len(offsets))
    return zip_offset(
        *copies,
        offsets=offsets,
        longest=longest,
        fillvalue=fillvalue,
    )
def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the input *iterables* together after shifting the `i`-th
    iterable by the `i`-th value in *offsets*.

        >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
        [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    This can serve as a lightweight alternative to SciPy or pandas for
    analyzing data sets in which some series lead or lag others.

    By default, output stops when the shortest (shifted) iterable is
    exhausted. Set *longest* to ``True`` to continue until the longest one
    is exhausted.

        >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
        [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    Positions that fall outside a sequence are filled with ``None`` by
    default. Pass *fillvalue* to use a different value.

    ``ValueError`` is raised if the number of iterables differs from the
    number of offsets.

    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    def shift(source, amount):
        if amount < 0:
            # Negative offset: delay the series by prepending fill values.
            return chain(repeat(fillvalue, -amount), source)
        if amount > 0:
            # Positive offset: advance the series by skipping leading items.
            return islice(source, amount, None)
        return source

    shifted = [shift(source, amount) for source, amount in zip(iterables, offsets)]

    if longest:
        return zip_longest(*shifted, fillvalue=fillvalue)
    return zip(*shifted)
1931def sort_together(
1932 iterables, key_list=(0,), key=None, reverse=False, strict=False
1933):
1934 """Return the input iterables sorted together, with *key_list* as the
1935 priority for sorting. All iterables are trimmed to the length of the
1936 shortest one.
1938 This can be used like the sorting function in a spreadsheet. If each
1939 iterable represents a column of data, the key list determines which
1940 columns are used for sorting.
1942 By default, all iterables are sorted using the ``0``-th iterable::
1944 >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
1945 >>> sort_together(iterables)
1946 [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
1948 Set a different key list to sort according to another iterable.
1949 Specifying multiple keys dictates how ties are broken::
1951 >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
1952 >>> sort_together(iterables, key_list=(1, 2))
1953 [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
1955 To sort by a function of the elements of the iterable, pass a *key*
1956 function. Its arguments are the elements of the iterables corresponding to
1957 the key list::
1959 >>> names = ('a', 'b', 'c')
1960 >>> lengths = (1, 2, 3)
1961 >>> widths = (5, 2, 1)
1962 >>> def area(length, width):
1963 ... return length * width
1964 >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
1965 [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
1967 Set *reverse* to ``True`` to sort in descending order.
1969 >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
1970 [(3, 2, 1), ('a', 'b', 'c')]
1972 If the *strict* keyword argument is ``True``, then
1973 ``ValueError`` will be raised if any of the iterables have
1974 different lengths.
1976 """
1977 if key is None:
1978 # if there is no key function, the key argument to sorted is an
1979 # itemgetter
1980 key_argument = itemgetter(*key_list)
1981 else:
1982 # if there is a key function, call it with the items at the offsets
1983 # specified by the key function as arguments
1984 key_list = list(key_list)
1985 if len(key_list) == 1:
1986 # if key_list contains a single item, pass the item at that offset
1987 # as the only argument to the key function
1988 key_offset = key_list[0]
1989 key_argument = lambda zipped_items: key(zipped_items[key_offset])
1990 else:
1991 # if key_list contains multiple items, use itemgetter to return a
1992 # tuple of items, which we pass as *args to the key function
1993 get_key_items = itemgetter(*key_list)
1994 key_argument = lambda zipped_items: key(
1995 *get_key_items(zipped_items)
1996 )
1998 transposed = zip(*iterables, strict=strict)
1999 reordered = sorted(transposed, key=key_argument, reverse=reverse)
2000 untransposed = zip(*reordered, strict=strict)
2001 return list(untransposed)
def unzip(iterable):
    """Undo a :func:`zip`: split the zipped *iterable* back into its
    separate component iterables.

    The ``i``-th returned iterable yields the ``i``-th element of every
    item in the zipped iterable. The number of returned iterables is taken
    from the length of the first item.

        >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
        >>> letters, numbers = unzip(iterable)
        >>> list(letters)
        ['a', 'b', 'c', 'd']
        >>> list(numbers)
        [1, 2, 3, 4]

    This resembles ``zip(*iterable)`` but does not read *iterable* into
    memory up front. Note, however, that it relies on
    :func:`itertools.tee` and so may still require significant storage.

    """
    head, iterable = spy(iterable)
    if not head:
        # Empty input, e.g. zip([], [], []): there is nothing to unzip.
        return ()

    # spy returns the peeked items as a list; the single peeked item fixes
    # how many columns there are.
    width = len(head[0])

    # With an input such as iter([(1, 2, 3), (4, 5), (6,)]), the second
    # column cannot read (6,)[1] and the third cannot read (4, 5)[2].
    # Suppressing IndexError makes each column simply stop at the first
    # item that is too short, so such "improperly zipped" inputs are
    # tolerated.
    columns = []
    for position, copy in enumerate(tee(iterable, width)):
        column = map(itemgetter(position), copy)
        columns.append(iter_suppress(column, IndexError))
    return tuple(columns)
def divide(n, iterable):
    """Divide the elements from *iterable* into *n* parts, maintaining
    order.

        >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
        >>> list(group_1)
        [1, 2, 3]
        >>> list(group_2)
        [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

        >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
        >>> [list(c) for c in children]
        [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than n, then the last returned
    iterables will be empty:

        >>> children = divide(5, [1, 2, 3])
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]

    Inputs that do not support slicing (generators, sets, dictionaries, ...)
    are first copied into a tuple, so the iterable is exhausted before this
    function returns:

        >>> [list(c) for c in divide(2, {'a': 1, 'b': 2, 'c': 3})]
        [['a', 'b'], ['c']]

    If order is not important, see :func:`distribute`, which does not first
    pull the iterable into memory.

    ``ValueError`` is raised if *n* is less than 1.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    try:
        # Probe for slicing support with an empty slice.
        iterable[:0]
    except (TypeError, KeyError):
        # TypeError: the object is not subscriptable or rejects slices.
        # KeyError: a mapping treated the slice as a missing key. Slice
        # objects are hashable since Python 3.12, so a dict lookup with a
        # slice now raises KeyError rather than TypeError.
        seq = tuple(iterable)
    else:
        seq = iterable

    # The first r parts get q + 1 items; the remaining parts get q.
    q, r = divmod(len(seq), n)

    ret = []
    stop = 0
    for i in range(1, n + 1):
        start = stop
        stop += q + 1 if i <= r else q
        ret.append(iter(seq[start:stop]))

    return ret
def always_iterable(obj, base_type=(str, bytes)):
    """Return an iterator for *obj*, whatever *obj* is.

    If *obj* is iterable, the iterator runs over its items::

        >>> obj = (1, 2, 3)
        >>> list(always_iterable(obj))
        [1, 2, 3]

    If *obj* is not iterable, the iterator yields *obj* itself, once::

        >>> obj = 1
        >>> list(always_iterable(obj))
        [1]

    If *obj* is ``None``, the iterator is empty:

        >>> obj = None
        >>> list(always_iterable(None))
        []

    By default, text and binary strings are treated as single items rather
    than as iterables::

        >>> obj = 'foo'
        >>> list(always_iterable(obj))
        ['foo']

    More generally, any object for which ``isinstance(obj, base_type)`` is
    ``True`` is treated as a single item.

        >>> obj = {'a': 1}
        >>> list(always_iterable(obj))  # Iterate over the dict's keys
        ['a']
        >>> list(always_iterable(obj, base_type=dict))  # Treat dicts as a unit
        [{'a': 1}]

    Set *base_type* to ``None`` to switch off the special handling, so that
    everything Python considers iterable is iterated:

        >>> obj = 'foo'
        >>> list(always_iterable(obj, base_type=None))
        ['f', 'o', 'o']
    """
    if obj is None:
        return iter(())

    is_unit = (base_type is not None) and isinstance(obj, base_type)
    if not is_unit:
        try:
            return iter(obj)
        except TypeError:
            # Not iterable: fall through and wrap it as a single item.
            pass

    return iter((obj,))
def adjacent(predicate, iterable, distance=1):
    """Return an iterable of `(bool, item)` tuples, where each `item` comes
    from *iterable* and the `bool` tells whether that item satisfies the
    *predicate* or lies next to an item that does.

    For example, to find out which items are adjacent to a ``3``::

        >>> list(adjacent(lambda x: x == 3, range(6)))
        [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    Use *distance* to widen what counts as adjacent. For example, to find
    items within two places of a ``3``:

        >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
        [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    This helps put the results of a search into context. A code comparison
    tool, for instance, might want to show the lines that changed together
    with a few surrounding lines.

    The predicate is called only once per item of the iterable.

    See also :func:`groupby_transform`, which can be combined with this
    function to group runs of items sharing the same `bool` value.

    ``ValueError`` is raised if *distance* is negative.

    """
    # distance=0 is permitted mainly so tests can check that it reproduces
    # the results of map()
    if distance < 0:
        raise ValueError('distance must be at least 0')

    for_flags, for_items = tee(iterable)
    # Pad both ends with False so that every item, including the first and
    # the last, sits at the centre of a full-width window.
    edge = [False] * distance
    flags = chain(edge, map(predicate, for_flags), edge)
    near_match = map(any, windowed(flags, 2 * distance + 1))
    return zip(near_match, for_items)
def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
    """A version of :func:`itertools.groupby` that can also transform the
    grouped data.

    * *keyfunc* computes a key value for each item in *iterable*
    * *valuefunc* transforms the individual items of *iterable* after
      grouping
    * *reducefunc* transforms each group of items as a whole

    >>> iterable = 'aAAbBBcCC'
    >>> keyfunc = lambda k: k.upper()
    >>> valuefunc = lambda v: v.lower()
    >>> reducefunc = lambda g: ''.join(g)
    >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
    [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]

    Any argument that is left out behaves like an identity function.

    :func:`groupby_transform` is handy for grouping the elements of one
    iterable by the values of another. :func:`zip` the two together, then
    pass a *keyfunc* that picks the first element and a *valuefunc* that
    picks the second::

        >>> from operator import itemgetter
        >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
        >>> values = 'abcdefghi'
        >>> iterable = zip(keys, values)
        >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
        >>> [(k, ''.join(g)) for k, g in grouper]
        [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]

    The order of items in the iterable matters: only adjacent items are
    grouped together. To avoid duplicate groups, sort the iterable by the
    key function first, or consider :func:`bucket` or :func:`map_reduce`.
    :func:`map_reduce` consumes the iterable immediately and returns a
    dictionary, while :func:`bucket` does not.

    .. seealso:: :func:`bucket`, :func:`map_reduce`

    """

    def with_mapped_values(pairs):
        for group_key, members in pairs:
            yield group_key, map(valuefunc, members)

    def with_reduced_groups(pairs):
        for group_key, members in pairs:
            yield group_key, reducefunc(members)

    # Each stage lazily wraps the previous one; omitted stages are skipped.
    grouped = groupby(iterable, keyfunc)
    if valuefunc:
        grouped = with_mapped_values(grouped)
    if reducefunc:
        grouped = with_reduced_groups(grouped)

    return grouped
class numeric_range(Sequence):
    """A ``range()``-like sequence whose arguments may be any orderable
    numeric type.

    Given only *stop*, *start* is ``0`` and *step* is ``1``; the items have
    the same type as *stop*:

        >>> list(numeric_range(3.5))
        [0.0, 1.0, 2.0, 3.0]

    Given *start* and *stop*, *step* is ``1`` and the items have the same
    type as *start*:

        >>> from decimal import Decimal
        >>> start = Decimal('2.1')
        >>> stop = Decimal('5.1')
        >>> list(numeric_range(start, stop))
        [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    Given all three, the items have the type of ``start + step``:

        >>> from fractions import Fraction
        >>> start = Fraction(1, 2)  # Start at 1/2
        >>> stop = Fraction(5, 2)  # End at 5/2
        >>> step = Fraction(1, 2)  # Count by 1/2
        >>> list(numeric_range(start, stop, step))
        [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    A zero *step* raises ``ValueError``. Negative steps count downwards:

        >>> list(numeric_range(3, -1, -1.0))
        [3.0, 2.0, 1.0, 0.0]

    Keep the limits of floating-point arithmetic in mind; the yielded
    values may not look the way you expect.

    ``datetime.datetime`` endpoints work when *step* is a
    ``datetime.timedelta``:

        >>> import datetime
        >>> start = datetime.datetime(2019, 1, 1)
        >>> stop = datetime.datetime(2019, 1, 3)
        >>> step = datetime.timedelta(days=1)
        >>> items = iter(numeric_range(start, stop, step))
        >>> next(items)
        datetime.datetime(2019, 1, 1, 0, 0)
        >>> next(items)
        datetime.datetime(2019, 1, 2, 0, 0)

    """

    # All empty ranges compare equal, so they must share one hash value.
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        argc = len(args)
        if argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        if argc > 3:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        if argc == 3:
            self._start, self._stop, self._step = args
        else:
            if argc == 2:
                self._start, self._stop = args
            else:
                self._stop = args[0]
                # The implied start is a zero of the same type as stop.
                self._start = type(self._stop)(0)
            # The implied step is a unit of whatever type a difference of
            # the endpoints has.
            self._step = type(self._stop - self._start)(1)

        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        self._growing = self._step > self._zero

    def __bool__(self):
        # Non-empty exactly when start lies strictly "before" stop in the
        # direction of travel.
        if not self._growing:
            return self._start > self._stop
        return self._start < self._stop

    def __contains__(self, elem):
        if self._growing:
            if not (self._start <= elem < self._stop):
                return False
            return (elem - self._start) % self._step == self._zero

        if not (self._start >= elem > self._stop):
            return False
        return (self._start - elem) % (-self._step) == self._zero

    def __eq__(self, other):
        if not isinstance(other, numeric_range):
            return False

        self_is_empty = not bool(self)
        other_is_empty = not bool(other)
        if self_is_empty or other_is_empty:
            # Equal only when both are empty.
            return self_is_empty and other_is_empty

        # Non-empty ranges are equal when they produce the same items,
        # i.e. same first item, same stride and same last item.
        return (
            self._start == other._start
            and self._step == other._step
            and self._get_by_index(-1) == other._get_by_index(-1)
        )

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._get_by_index(key)

        if not isinstance(key, slice):
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )

        first, last, stride = key.indices(self._len)
        return numeric_range(
            self._start + first * self._step,
            self._start + last * self._step,
            self._step * stride,
        )

    def __hash__(self):
        if not self:
            return self._EMPTY_HASH
        return hash((self._start, self._get_by_index(-1), self._step))

    def __iter__(self):
        candidates = (self._start + (n * self._step) for n in count())
        # partial(op, stop)(value) evaluates ``stop op value``.
        still_inside = partial(gt if self._growing else lt, self._stop)
        return takewhile(still_inside, candidates)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        # Normalise to an ascending range so one formula covers both
        # directions.
        if self._growing:
            low, high, stride = self._start, self._stop, self._step
        else:
            low, high, stride = self._stop, self._start, -self._step

        span = high - low
        if span <= self._zero:
            return 0

        # span > 0 and stride > 0: ceiling division.
        whole, leftover = divmod(span, stride)
        return int(whole) + int(leftover != self._zero)

    def __reduce__(self):
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        if self._step == 1:
            shown = (self._start, self._stop)
        else:
            shown = (self._start, self._stop, self._step)
        return f"numeric_range({', '.join(map(repr, shown))})"

    def __reversed__(self):
        try:
            last_item = self._get_by_index(-1)
        except IndexError:
            # An empty range has no last item.
            return iter([])

        backwards = numeric_range(
            last_item, self._start - self._step, -self._step
        )
        return iter(backwards)

    def count(self, value):
        return int(value in self)

    def index(self, value):
        if self._growing:
            if self._start <= value < self._stop:
                position, leftover = divmod(value - self._start, self._step)
                if leftover == self._zero:
                    return int(position)
        elif self._start >= value > self._stop:
            position, leftover = divmod(self._start - value, -self._step)
            if leftover == self._zero:
                return int(position)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        size = self._len
        position = i + size if i < 0 else i
        if position < 0 or position >= size:
            raise IndexError("numeric range object index out of range")
        return self._start + position * self._step
def count_cycle(iterable, n=None):
    """Cycle through the items from *iterable* up to *n* times, yielding
    the number of completed cycles along with each item. If *n* is omitted
    the process repeats indefinitely.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    An empty *iterable* always produces an empty iterator, whatever *n* is.

    """
    if n is not None:
        # Bounded case: the output is the Cartesian product of the cycle
        # numbers and the items.
        return product(range(n), iterable)

    seq = tuple(iterable)
    if not seq:
        # Without this guard the endless loop below would spin forever
        # without producing anything.
        return iter(())

    # Unbounded case: pair every item with cycle number 0, then 1, ...
    return chain.from_iterable(
        zip(repeat(cycle_number), seq) for cycle_number in count()
    )
def mark_ends(iterable):
    """Yield ``(is_first, is_last, item)`` 3-tuples for every item.

    >>> list(mark_ends('ABC'))
    [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    This is handy when the first and/or last item of a loop needs special
    treatment:

    >>> iterable = ['Header', 100, 200, 'Footer']
    >>> total = 0
    >>> for is_first, is_last, item in mark_ends(iterable):
    ...     if is_first:
    ...         continue  # Skip the header
    ...     if is_last:
    ...         continue  # Skip the footer
    ...     total += item
    >>> print(total)
    300
    """
    iterator = iter(iterable)
    try:
        current = next(iterator)
    except StopIteration:
        # Nothing to mark.
        return

    # Stay one item behind the source, so that when it runs dry the item
    # still held back is known to be the last one.
    at_start = True
    for upcoming in iterator:
        yield at_start, False, current
        current = upcoming
        at_start = False
    yield at_start, True, current
def locate(iterable, pred=bool, window_size=None):
    """Yield the index of every item in *iterable* for which *pred* returns
    ``True``.

    With the default *pred*, :func:`bool`, truthy items are selected:

        >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
        [1, 2, 4]

    Pass your own *pred* to, for example, find every occurrence of a
    particular item.

        >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
        [1, 3]

    When *window_size* is given, *pred* is called with the values of each
    window as separate arguments, which makes it possible to search for
    sub-sequences. Near the end of the iterable *pred* may receive fewer
    than *window_size* arguments.

        >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
        >>> pred = lambda *args: args == (1, 2, 3)
        >>> list(locate(iterable, pred=pred, window_size=3))
        [1, 5, 9]

    Combine with :func:`seekable` to find indexes and then fetch the
    matching items:

        >>> from itertools import count
        >>> from more_itertools import seekable
        >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
        >>> it = seekable(source)
        >>> pred = lambda x: x > 100
        >>> indexes = locate(it, pred=pred)
        >>> i = next(indexes)
        >>> it.seek(i)
        >>> next(it)
        106

    """
    positions = count()

    if window_size is None:
        verdicts = map(pred, iterable)
        return compress(positions, verdicts)

    if window_size < 1:
        raise ValueError('window size must be at least 1')

    # Short trailing windows are padded with the sentinel, which is removed
    # again before the values reach *pred*.
    windows = windowed(iterable, window_size, fillvalue=_marker)
    verdicts = (
        pred(*[value for value in window if value is not _marker])
        for window in windows
    )
    return compress(positions, verdicts)
def longest_common_prefix(iterables):
    """Yield the items of the longest prefix shared by all of the given
    *iterables*.

    >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
    'ab'

    """
    # Walk the inputs in lockstep and stop at the first position where they
    # disagree (or where the shortest input ends).
    columns = zip(*iterables)
    agreeing = takewhile(all_equal, columns)
    return (column[0] for column in agreeing)
def lstrip(iterable, pred):
    """Yield the items from *iterable*, leaving out the leading items for
    which *pred* returns ``True``.

    For example, to remove a set of items from the start of an iterable:

        >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
        >>> pred = lambda x: x in {None, False, ''}
        >>> list(lstrip(iterable, pred))
        [1, 2, None, 3, False, None]

    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.

    """
    remaining = dropwhile(pred, iterable)
    return remaining
def rstrip(iterable, pred):
    """Yield the items from *iterable*, leaving out the trailing items for
    which *pred* returns ``True``.

    For example, to remove a set of items from the end of an iterable:

        >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
        >>> pred = lambda x: x in {None, False, ''}
        >>> list(rstrip(iterable, pred))
        [None, False, None, 1, 2, None, 3]

    This function is analogous to :func:`str.rstrip`.

    """
    # Items matching *pred* are held back until a non-matching item shows
    # that they were not at the end. Whatever is still held when the input
    # runs out is dropped.
    held_back = []
    for item in iterable:
        if not pred(item):
            yield from held_back
            held_back.clear()
            yield item
        else:
            held_back.append(item)
def strip(iterable, pred):
    """Yield the items from *iterable*, leaving out the leading and
    trailing items for which *pred* returns ``True``.

    For example, to remove a set of items from both ends of an iterable:

        >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
        >>> pred = lambda x: x in {None, False, ''}
        >>> list(strip(iterable, pred))
        [1, 2, None, 3]

    This function is analogous to :func:`str.strip`.

    """
    # Trim the front first, then feed the result through the back trimmer.
    without_front = lstrip(iterable, pred)
    return rstrip(without_front, pred)
class islice_extended:
    """A version of :func:`itertools.islice` that also accepts negative
    values for *stop*, *start*, and *step*.

        >>> iterator = iter('abcdefgh')
        >>> list(islice_extended(iterator, -4, -1))
        ['e', 'f', 'g']

    Negative values make it necessary to cache part of *iterable*, but care
    is taken to keep as little in memory as possible.

    For example, a negative step works even on an infinite iterator:

        >>> from itertools import count
        >>> list(islice_extended(count(), 110, 99, -2))
        [110, 108, 106, 104, 102, 100]

    Slice notation can be used directly as well:

        >>> iterator = map(str, count())
        >>> it = islice_extended(iterator)[10:20:2]
        >>> list(it)
        ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        source = iter(iterable)
        # With no slice arguments the object is a plain pass-through.
        self._iterator = (
            _islice_helper(source, slice(*args)) if args else source
        )

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        if not isinstance(key, slice):
            raise TypeError(
                'islice_extended.__getitem__ argument must be a slice'
            )
        return islice_extended(_islice_helper(self._iterator, key))
def _islice_helper(it, s):
    """Yield the items of iterator *it* selected by the slice object *s*.

    Any of ``s.start``, ``s.stop`` and ``s.step`` may be negative or
    ``None``. Only as many items as the slice requires are cached.
    ``ValueError`` is raised (on first iteration) for a zero step.
    """
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    begin = s.start
    end = s.stop
    stride = s.step or 1

    if stride > 0:
        if begin is None:
            begin = 0

        if begin < 0:
            # Keep only the last -begin items while counting how many
            # items went by in total.
            tally = count(1)
            tail = deque(compress(it, tally), maxlen=-begin)
            total = next(tally) - 1

            # Translate both bounds into non-negative positions.
            lower = max(total + begin, 0)
            if end is None:
                upper = total
            elif end >= 0:
                upper = min(end, total)
            else:
                upper = max(total + end, 0)

            # Walk the cached tail.
            span = upper - lower
            if span <= 0:
                return

            for offset in range(span):
                if offset % stride == 0:
                    # Yield straight from popleft(); binding the item to
                    # a local name would keep it alive longer than needed.
                    yield tail.popleft()
                else:
                    # Not on the stride: drop it.
                    tail.popleft()
        elif (end is not None) and (end < 0):
            # Skip ahead to the first requested position.
            next(islice(it, begin, begin), None)

            # A negative end means the last -end items must never be
            # yielded, so the output lags the input by that many items.
            lagging = deque(islice(it, -end), maxlen=-end)

            for offset, incoming in enumerate(it):
                if offset % stride == 0:
                    # As above: no intermediate name for the item.
                    yield lagging.popleft()
                else:
                    # Not on the stride: drop it.
                    lagging.popleft()
                lagging.append(incoming)
        else:
            # Neither bound is negative: plain islice does the job.
            yield from islice(it, begin, end, stride)
    else:
        if begin is None:
            begin = -1

        if (end is not None) and (end < 0):
            # Only the items located after position `end` can be selected.
            keep = -end - 1
            tally = count(1)
            tail = deque(compress(it, tally), maxlen=keep)
            total = next(tally) - 1

            # Two negative bounds can be applied to the cached tail as they
            # are. A non-negative begin is first re-expressed relative to
            # the end of the input.
            if begin < 0:
                first, last = begin, end
            else:
                first, last = min(begin - total, -1), None

            yield from list(tail)[first:last:stride]
        else:
            # Skip everything up to and including position `end`.
            if end is not None:
                skip = end + 1
                next(islice(it, skip, skip), None)

            if begin < 0:
                # begin is relative to the end while end is not, so all
                # remaining items are needed.
                first = begin
                needed = None
            elif end is None:
                # No lower limit: items up to position `begin` suffice.
                first = None
                needed = begin + 1
            else:
                # Both bounds are non-negative and can be compared.
                first = None
                needed = begin - end
                if needed <= 0:
                    return

            window = list(islice(it, needed))

            yield from window[first::stride]
def always_reversible(iterable):
    """A version of :func:`reversed` that works on every iterable, not only
    on those implementing the ``Reversible`` or ``Sequence`` protocols.

        >>> print(*always_reversible(x for x in range(3)))
        2 1 0

    When *iterable* is already reversible, the result of :func:`reversed()`
    is returned. Otherwise the remaining items are first cached in a list
    and then yielded in reverse order, which may need a lot of storage.
    """
    try:
        backwards = reversed(iterable)
    except TypeError:
        # Not reversible as it stands: materialise it first.
        backwards = reversed(list(iterable))
    return backwards
def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items using :func:`itertools.groupby`.
    The *ordering* function returns the position of an item, and two items
    are adjacent when their positions differ by one.

    The default ordering is the identity function, which is what is needed
    to find runs of numbers:

        >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
        >>> for group in consecutive_groups(iterable):
        ...     print(list(group))
        [1]
        [10, 11, 12]
        [20]
        [30, 31, 32, 33]
        [40]

    To find runs of adjacent letters, use :func:`ord` to turn the letters
    into ordinals.

        >>> iterable = 'abcdfgilmnop'
        >>> ordering = ord
        >>> for group in consecutive_groups(iterable, ordering):
        ...     print(list(group))
        ['a', 'b', 'c', 'd']
        ['f', 'g']
        ['i']
        ['l', 'm', 'n', 'o', 'p']

    Every group is an iterator that shares its source with *iterable*.
    Once the next group is requested, the previous one is no longer
    available unless its items were copied first (e.g., into a ``list``).

        >>> iterable = [1, 2, 11, 12, 21, 22]
        >>> saved_groups = []
        >>> for group in consecutive_groups(iterable):
        ...     saved_groups.append(list(group))  # Copy group elements
        >>> saved_groups
        [[1, 2], [11, 12], [21, 22]]

    """
    # Inside a run of consecutive items, "index minus position" stays
    # constant, so that difference serves as the grouping key.
    if ordering is None:

        def drift(pair):
            index, item = pair
            return index - item

    else:

        def drift(pair):
            index, item = pair
            return index - ordering(item)

    for _, run in groupby(enumerate(iterable), key=drift):
        yield map(itemgetter(1), run)
def difference(iterable, func=sub, *, initial=None):
    """The inverse of :func:`itertools.accumulate`. By default the first
    difference of *iterable* is computed with :func:`operator.sub`:

        >>> from itertools import accumulate
        >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
        >>> list(difference(iterable))
        [0, 1, 2, 3, 4]

    Another two-argument *func* may be given instead of
    :func:`operator.sub`. It is applied like this::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, to do progressive division:

        >>> iterable = [1, 2, 6, 24, 120]
        >>> func = lambda x, y: x // y
        >>> list(difference(iterable, func))
        [1, 2, 3, 4, 5]

    When the *initial* keyword is set, the first element is left out of
    the output.

        >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
        >>> list(difference(it, initial=10))
        [1, 2, 3]

    """
    trailing, leading = tee(iterable)
    try:
        head = next(leading)
    except StopIteration:
        return iter([])

    # `leading` now runs one item ahead of `trailing`, so mapping over both
    # pairs every item with its predecessor.
    steps = map(func, leading, trailing)
    if initial is None:
        return chain([head], steps)
    return steps
class SequenceView(Sequence):
    """A read-only view of the sequence object *target*.

    :class:`SequenceView` objects resemble Python's built-in "dictionary
    view" types: the view is dynamic, so changes to the underlying
    sequence show up in the view.

        >>> seq = ['0', '1', '2']
        >>> view = SequenceView(seq)
        >>> view
        SequenceView(['0', '1', '2'])
        >>> seq.append('3')
        >>> view
        SequenceView(['0', '1', '2', '3'])

    Indexing, slicing, and length queries are supported and behave as they
    do on the underlying sequence; assignment is not possible:

        >>> view[1]
        '1'
        >>> view[1:-1]
        ['1', '2']
        >>> len(view)
        4

    A view can stand in for a copy, and needs hardly any extra storage.

    """

    def __init__(self, target):
        if isinstance(target, Sequence):
            self._target = target
        else:
            raise TypeError

    def __getitem__(self, index):
        # Integers and slices alike are delegated to the target.
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._target)
class seekable:
    """Wrap an iterator so that it can be moved backward and forward. The
    items of the source iterable are cached as they are produced, which is
    what makes it possible to visit them again.

    Call :meth:`seek` with an index to move to that position in the source
    iterable.

    Seeking to ``0`` "resets" the iterator:

        >>> from itertools import count
        >>> it = seekable((str(n) for n in count()))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> it.seek(0)
        >>> next(it), next(it), next(it)
        ('0', '1', '2')

    Seeking forward works too:

        >>> it = seekable((str(n) for n in range(20)))
        >>> it.seek(10)
        >>> next(it)
        '10'
        >>> it.seek(20)  # Seeking past the end of the source isn't a problem
        >>> list(it)
        []
        >>> it.seek(0)  # Resetting works even after hitting the end
        >>> next(it)
        '0'

    Call :meth:`relative_seek` to move relative to the current position of
    the source iterator.

        >>> it = seekable((str(n) for n in range(20)))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> it.relative_seek(2)
        >>> next(it)
        '5'
        >>> it.relative_seek(-3)  # Source is at '6', we move back to '3'
        >>> next(it)
        '3'
        >>> it.relative_seek(-3)  # Source is at '4', we move back to '1'
        >>> next(it)
        '1'

    Call :meth:`peek` to look at the next item without advancing:

        >>> it = seekable('1234')
        >>> it.peek()
        '1'
        >>> list(it)
        ['1', '2', '3', '4']
        >>> it.peek(default='empty')
        'empty'

    :func:`bool` returns ``True`` while items remain and ``False`` once the
    iterator is at its end:

        >>> it = seekable('5678')
        >>> bool(it)
        True
        >>> list(it)
        ['5', '6', '7', '8']
        >>> bool(it)
        False

    The :meth:`elements` method exposes the cache as a
    :class:`SequenceView`, a view that updates automatically:

        >>> it = seekable((str(n) for n in range(10)))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> elements = it.elements()
        >>> elements
        SequenceView(['0', '1', '2'])
        >>> next(it)
        '3'
        >>> elements
        SequenceView(['0', '1', '2', '3'])

    Indexing the :class:`seekable` itself returns items from the cache:

        >>> it = seekable((str(n) for n in range(10)))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> it[-1]
        '2'
        >>> it[0]
        '0'

    The cache normally grows as the source iterable advances, so take care
    when wrapping very large or infinite iterables. Pass *maxlen* to bound
    the cache (which also bounds how far back you can seek).

        >>> from itertools import count
        >>> it = seekable((str(n) for n in count()), maxlen=2)
        >>> next(it), next(it), next(it), next(it)
        ('0', '1', '2', '3')
        >>> list(it.elements())
        ['2', '3']
        >>> it.seek(0)
        >>> next(it), next(it), next(it), next(it)
        ('2', '3', '4', '5')
        >>> next(it)
        '6'

    """

    def __init__(self, iterable, maxlen=None):
        self._source = iter(iterable)
        # An unbounded cache is a list; a bounded one is a deque that drops
        # its oldest entries.
        self._cache = [] if maxlen is None else deque([], maxlen)
        # Position of the next cached item to replay, or None when items
        # are being read from the source.
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        position = self._index
        if position is not None:
            try:
                replayed = self._cache[position]
            except IndexError:
                # Ran off the end of the cache: go back to the source.
                self._index = None
            else:
                self._index = position + 1
                return replayed

        fresh = next(self._source)
        self._cache.append(fresh)
        return fresh

    def __bool__(self):
        try:
            self.peek()
        except StopIteration:
            return False
        else:
            return True

    def peek(self, default=_marker):
        try:
            upcoming = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default

        # Step back by one so that the peeked item is produced again by
        # the next call to __next__.
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return upcoming

    def elements(self):
        return SequenceView(self._cache)

    def seek(self, index):
        self._index = index
        # A target beyond the cache is reached by pulling the missing
        # items from the source.
        shortfall = index - len(self._cache)
        if shortfall > 0:
            consume(self, shortfall)

    def relative_seek(self, count):
        if self._index is None:
            self._index = len(self._cache)

        target = max(self._index + count, 0)
        self.seek(target)

    def __getitem__(self, index):
        return self._cache[index]
class run_length:
    """
    :func:`run_length.encode` compresses an iterable with run-length
    encoding. Each run of repeated items is yielded as the item together
    with the number of times it was repeated:

        >>> uncompressed = 'abbcccdddd'
        >>> list(run_length.encode(uncompressed))
        [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` reverses the process: given an iterable that
    was compressed with run-length encoding, it yields the items of the
    decompressed iterable:

        >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
        >>> list(run_length.decode(compressed))
        ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        return ((item, ilen(run)) for item, run in groupby(iterable))

    @staticmethod
    def decode(iterable):
        # Each (item, count) pair expands to repeat(item, count).
        expanded_runs = starmap(repeat, iterable)
        return chain.from_iterable(expanded_runs)
def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly ``n`` items in the iterable are ``True``
    according to the *predicate* function.

        >>> exactly_n([True, True, False], 2)
        True
        >>> exactly_n([True, True, False], 1)
        False
        >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
        True

    The iterable is advanced until ``n + 1`` truthy items have been seen,
    so avoid calling this on infinite iterables.

    """
    matches = filter(predicate, iterable)

    if n <= 0:
        # A negative count can never be met; a count of zero is met when
        # nothing matches at all.
        return n == 0 and not any(True for _ in matches)

    # Skip the first n - 1 matches. What follows must contain the nth match
    # and nothing after it.
    remaining = islice(matches, n - 1, None)
    if not any(True for _ in remaining):
        return False
    return not any(True for _ in remaining)
def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

        >>> list(circular_shifts(range(4)))
        [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    *steps* is the number of places to rotate to the left between shifts
    (or to the right if negative). Defaults to 1.

        >>> list(circular_shifts(range(4), 2))
        [(0, 1, 2, 3), (2, 3, 0, 1)]

        >>> list(circular_shifts(range(4), -1))
        [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    """
    ring = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Pre-rotate in the opposite direction so that the first rotation in
    # the loop brings back the original order.
    ring.rotate(steps)
    leftward = -steps

    # After len / gcd(len, steps) rotations the shifts start to repeat.
    size = len(ring)
    distinct = size // math.gcd(size, leftward)

    for _ in range(distinct):
        ring.rotate(leftward)
        yield tuple(ring)
def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function that
    modifies an iterable. *result_index* is the position in that function's
    signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition. This can augment what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

        >>> from more_itertools import chunked
        >>> chunker = make_decorator(chunked, result_index=0)
        >>> @chunker(3)
        ... def iter_range(n):
        ...     return iter(range(n))
        ...
        >>> list(iter_range(9))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

        >>> truth_serum = make_decorator(filter, result_index=1)
        >>> @truth_serum(bool)
        ... def boolean_test():
        ...     return [0, 1, '', ' ', False, True]
        ...
        >>> list(boolean_test())
        [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators:

        >>> from more_itertools import peekable
        >>> peekable_function = make_decorator(peekable)
        >>> @peekable_function()
        ... def str_range(*args):
        ...     return (str(x) for x in range(*args))
        ...
        >>> it = str_range(1, 20, 2)
        >>> next(it), next(it), next(it)
        ('1', '3', '5')
        >>> it.peek()
        '7'
        >>> next(it)
        '7'

    The decorated function keeps the name, docstring and other metadata of
    the function it wraps:

        >>> str_range.__name__
        'str_range'

    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            # Copy __name__, __doc__, etc. from the decorated function so
            # that introspection and documentation tools keep working.
            @wraps(f)
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                # Splice the function's result into the positional
                # arguments of the wrapping function at the requested
                # position.
                wrapping_args_ = list(wrapping_args)
                wrapping_args_.insert(result_index, result)
                return wrapping_func(*wrapping_args_, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Return a dictionary that sorts the items of *iterable* into
    categories given by *keyfunc*, transforms them with *valuefunc*, and
    then summarizes each category with *reducefunc*.

    *valuefunc* defaults to the identity function. Without a *reducefunc*
    no summarization takes place:

        >>> keyfunc = lambda x: x.upper()
        >>> result = map_reduce('abbccc', keyfunc)
        >>> sorted(result.items())
        [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Specifying *valuefunc* transforms the categorized items:

        >>> keyfunc = lambda x: x.upper()
        >>> valuefunc = lambda x: 1
        >>> result = map_reduce('abbccc', keyfunc, valuefunc)
        >>> sorted(result.items())
        [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Specifying *reducefunc* summarizes the categorized items:

        >>> keyfunc = lambda x: x.upper()
        >>> valuefunc = lambda x: 1
        >>> reducefunc = sum
        >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
        >>> sorted(result.items())
        [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the
    map/reduce procedure:

        >>> all_items = range(30)
        >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
        >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
        >>> categories = map_reduce(items, keyfunc=keyfunc)
        >>> sorted(categories.items())
        [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
        >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
        >>> sorted(summaries.items())
        [(0, 90), (1, 75)]

    All items are collected into lists before the summarization step runs,
    which may require significant storage.

    The returned object is a :obj:`collections.defaultdict` whose
    ``default_factory`` is set to ``None``, so it behaves like a normal
    dictionary.

    .. seealso:: :func:`bucket`, :func:`groupby_transform`

        If storage is a concern, :func:`bucket` can be used without
        consuming the entire iterable right away. If the elements with the
        same key are already adjacent, :func:`groupby_transform` or
        :func:`itertools.groupby` can be used without any caching overhead.

    """
    categories = defaultdict(list)

    # Map phase: file every (optionally transformed) item under its key.
    if valuefunc is None:
        keyed = ((keyfunc(entry), entry) for entry in iterable)
    else:
        keyed = ((keyfunc(entry), valuefunc(entry)) for entry in iterable)
    for category, value in keyed:
        categories[category].append(value)

    # Reduce phase: replace each list by its summary. Only existing keys
    # are reassigned, which is safe while iterating over the dictionary.
    if reducefunc is not None:
        for category in categories:
            categories[category] = reducefunc(categories[category])

    # Turn off auto-creation so that a missing key raises KeyError.
    categories.default_factory = None
    return categories
def rlocate(iterable, pred=bool, window_size=None):
    """Yield the indexes of the items in *iterable* for which *pred*
    returns ``True``, rightmost match first.

    *pred* defaults to :func:`bool`, selecting truthy items:

        >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
        [4, 2, 1]

    A custom *pred* can, for example, look for a particular item:

        >>> iterator = iter('abcb')
        >>> pred = lambda x: x == 'b'
        >>> list(rlocate(iterator, pred))
        [3, 1]

    When *window_size* is given, *pred* is called with that many items at
    a time, which allows searching for sub-sequences:

        >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
        >>> pred = lambda *args: args == (1, 2, 3)
        >>> list(rlocate(iterable, pred=pred, window_size=3))
        [9, 5, 1]

    Nothing is ever produced for an infinite iterable. If *iterable*
    supports ``len()`` and ``reversed()`` (and no *window_size* is given)
    it is scanned from the right; otherwise it is scanned from the left
    and the collected results are reversed.

    See :func:`locate` for other example applications.

    """

    def scan_from_left():
        return reversed(list(locate(iterable, pred, window_size)))

    if window_size is not None:
        return scan_from_left()

    try:
        total = len(iterable)
        # Convert positions in the reversed view back to real indexes.
        return (
            total - offset - 1 for offset in locate(reversed(iterable), pred)
        )
    except TypeError:
        # Not sized or not reversible.
        return scan_from_left()
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items of *iterable*, splicing in the items of
    *substitutes* wherever *pred* returns ``True``.

        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
        >>> pred = lambda x: x == 0
        >>> substitutes = (2, 3)
        >>> list(replace(iterable, pred, substitutes))
        [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    *count*, when given, caps the number of replacements:

        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
        >>> pred = lambda x: x == 0
        >>> substitutes = [None]
        >>> list(replace(iterable, pred, substitutes, count=2))
        [1, 1, None, 1, 1, None, 1, 1, 0]

    *window_size* sets how many consecutive items are passed to *pred*,
    which makes it possible to find and replace sub-sequences:

        >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
        >>> window_size = 3
        >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
        >>> substitutes = [3, 4]  # Splice in these items
        >>> list(replace(iterable, pred, substitutes, window_size=window_size))
        [3, 4, 5, 3, 4, 5]

    Near the end of the iterable *pred* may be called with fewer than
    *window_size* arguments and must cope with that.

    Raises ``ValueError`` if *window_size* is less than 1.

    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # Materialize once: the substitutes are emitted for every match.
    replacement = tuple(substitutes)

    # Pad the tail so that there is one window per input item.
    overlap = window_size - 1
    padded = chain(iterable, repeat(_marker, overlap))
    windows = windowed(padded, window_size)

    made = 0
    for window in windows:
        # Hide the padding sentinel from pred.
        visible = tuple(item for item in window if item is not _marker)
        if not visible:
            continue

        # pred is evaluated first, so it is still called once the
        # replacement limit has been reached.
        if pred(*visible) and (count is None or made < count):
            made += 1
            yield from replacement
            # Drop the following windows that overlap the matched one.
            # E.g. for windows (0, 1), (1, 2), (2, 3)... a match on (0, 1)
            # must also discard (1, 2).
            consume(windows, overlap)
            continue

        # No replacement here: emit the window's leading item.
        yield visible[0]
def partitions(iterable):
    """Yield every way of cutting *iterable* into consecutive, non-empty
    pieces while keeping the items in their original order.

        >>> iterable = 'abc'
        >>> for part in partitions(iterable):
        ...     print([''.join(p) for p in part])
        ['abc']
        ['a', 'bc']
        ['ab', 'c']
        ['a', 'b', 'c']

    Each partition is a list of lists. The input is read into a list
    first. This is unrelated to :func:`partition`.

    """
    items = list(iterable)
    size = len(items)
    # Every subset of the interior positions 1..size-1 is one choice of
    # cut points.
    for cuts in powerset(range(1, size)):
        starts = (0,) + cuts
        stops = cuts + (size,)
        yield [items[lo:hi] for lo, hi in zip(starts, stops)]
def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """
    Yield the set partitions of *iterable* into *k* blocks. Unlike
    :func:`partitions`, blocks need not consist of adjacent items.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable, 2):
        ...     print([''.join(p) for p in part])
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']

    When *k* is omitted, partitions for every block count from 1 to the
    number of items are produced.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable):
        ...     print([''.join(p) for p in part])
        ['abc']
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']
        ['a', 'b', 'c']

    *min_size* and *max_size* restrict the size of each block; partitions
    containing a block outside that range are left out.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable, min_size=2):
        ...     print([''.join(p) for p in part])
        ['abc']
        >>> for part in set_partitions(iterable, max_size=2):
        ...     print([''.join(p) for p in part])
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']
        ['a', 'b', 'c']

    Raises ``ValueError`` (on first iteration) if *k* is less than 1.
    Nothing is yielded when *k* exceeds the number of items or when
    *min_size* is greater than *max_size*.

    """
    items = list(iterable)
    total = len(items)

    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        if k > total:
            return

    smallest = 0 if min_size is None else min_size
    largest = total if max_size is None else max_size
    if smallest > largest:
        return

    def split(seq, parts):
        # Recursively yield the partitions of *seq* into exactly *parts*
        # blocks.
        if parts == 1:
            yield [seq]
        elif len(seq) == parts:
            yield [[member] for member in seq]
        else:
            head, *tail = seq
            # Either *head* forms a block on its own...
            for partition in split(tail, parts - 1):
                yield [[head], *partition]
            # ...or it joins each block of a partition of the rest in turn.
            for partition in split(tail, parts):
                for pos, block in enumerate(partition):
                    yield (
                        partition[:pos]
                        + [[head] + block]
                        + partition[pos + 1 :]
                    )

    def sizes_ok(partition):
        return all(smallest <= len(block) <= largest for block in partition)

    block_counts = range(1, total + 1) if k is None else (k,)
    for parts in block_counts:
        yield from filter(sizes_ok, split(items, parts))
class time_limited:
    """
    Iterate over *iterable*, stopping once more than *limit_seconds* have
    elapsed since this object was created. If iteration is cut short by
    the time limit, the ``timed_out`` attribute is set to ``True``.

        >>> from time import sleep
        >>> def generator():
        ...     yield 1
        ...     yield 2
        ...     sleep(0.2)
        ...     yield 3
        >>> iterable = time_limited(0.1, generator())
        >>> list(iterable)
        [1, 2]
        >>> iterable.timed_out
        True

    The clock is only consulted after an item has been fetched from the
    underlying iterator and before it is handed out; an item that arrives
    late is discarded. With a 1 second limit and a first item that takes
    2 seconds to produce, iteration runs for 2 seconds and yields nothing.
    As a special case, a *limit_seconds* of zero never yields anything and
    does not touch the underlying iterator.

    Raises ``ValueError`` if *limit_seconds* is negative.

    """

    def __init__(self, limit_seconds, iterable):
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be positive')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        # The budget starts at construction time, not at first iteration.
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def _expire(self):
        # Record the timeout and end the iteration.
        self.timed_out = True
        raise StopIteration

    def __next__(self):
        if self.limit_seconds == 0:
            self._expire()

        item = next(self._iterator)
        elapsed = monotonic() - self._start_time
        if elapsed > self.limit_seconds:
            self._expire()

        return item
def only(iterable, default=None, too_long=None):
    """Return the sole item of *iterable*.

    An empty iterable gives *default*. More than one item raises
    *too_long* if it is given (and truthy), otherwise a ``ValueError``
    naming the first two items.

        >>> only([], default='missing')
        'missing'
        >>> only([1])
        1
        >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Expected exactly one item in iterable, but got 1, 2,
         and perhaps more.
        >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        TypeError

    Up to two items are pulled from *iterable* to check that there is
    exactly one. See :func:`spy` or :func:`peekable` for less destructive
    ways to inspect an iterable.

    """
    iterator = iter(iterable)
    # Private sentinel so that any real item, including None, is
    # distinguishable from exhaustion.
    exhausted = object()

    first = next(iterator, exhausted)
    if first is exhausted:
        return default

    second = next(iterator, exhausted)
    if second is exhausted:
        return first

    msg = (
        f'Expected exactly one item in iterable, but got {first!r}, '
        f'{second!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
def ichunked(iterable, n):
    """Split *iterable* into consecutive sub-iterables of *n* items each.
    This is the lazy counterpart of :func:`chunked`: iterators are yielded
    instead of lists.

    Reading the chunks in order keeps nothing in memory. When a later
    chunk is requested before an earlier one has been read,
    :func:`itertools.tee` caches the skipped items.

        >>> from itertools import count
        >>> all_chunks = ichunked(count(), 4)
        >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
        >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
        [4, 5, 6, 7]
        >>> list(c_1)
        [0, 1, 2, 3]
        >>> list(c_3)
        [8, 9, 10, 11]

    """
    source = iter(iterable)
    for head in source:
        tail = islice(source, n - 1)
        replay, filler = tee(tail)
        # The chunk reads *tail* directly while it is still live, then
        # falls back to whatever the tee buffered for *replay*.
        yield chain([head], tail, replay)
        # Before starting the next chunk, push any unread items of this
        # one into the tee buffer.
        consume(filler)
def iequals(*iterables):
    """Return ``True`` if all of the given *iterables* yield equal items
    in the same order, and ``False`` otherwise (including when their
    lengths differ).

    Handy for comparing iterables of different types, or ones that do not
    implement equality themselves.

        >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
        True

        >>> iequals("abc", "acb")
        False

    Not to be confused with :func:`all_equal`, which checks whether the
    elements of a single iterable are all equal to each other.

    """
    try:
        # strict=True makes zip raise ValueError on a length mismatch.
        columns = zip(*iterables, strict=True)
        return all(map(all_equal, columns))
    except ValueError:
        return False
def distinct_combinations(iterable, r):
    """Yield each distinct *r*-length combination of the items in
    *iterable* exactly once.

        >>> list(distinct_combinations([0, 0, 1], 2))
        [(0, 0), (0, 1)]

    The result is the same as de-duplicating
    ``combinations(iterable, r)``, but duplicates are never generated in
    the first place, which is much cheaper for large inputs.

    Raises ``ValueError`` (on first iteration) if *r* is negative.

    """
    if r < 0:
        raise ValueError('r must be non-negative')
    if r == 0:
        yield ()
        return

    pool = tuple(iterable)
    by_value = itemgetter(1)

    # One iterator per slot being filled. Each yields (index, value)
    # candidates for its slot, skipping values already tried there.
    stack = [unique_everseen(enumerate(pool), key=by_value)]
    combo = [None] * r

    while stack:
        try:
            position, value = next(stack[-1])
        except StopIteration:
            # This slot has no candidates left: backtrack.
            stack.pop()
            continue

        combo[len(stack) - 1] = value
        if len(stack) == r:
            yield tuple(combo)
        else:
            # Candidates for the next slot come from after *position*.
            follow = position + 1
            stack.append(
                unique_everseen(
                    enumerate(pool[follow:], follow),
                    key=by_value,
                )
            )
def filter_except(validator, iterable, *exceptions):
    """Yield the items of *iterable* that *validator* accepts, i.e. those
    for which calling it does not raise one of *exceptions*.

    *validator* is called once per item with that item as its only
    argument; its return value is ignored.

        >>> iterable = ['1', '2', 'three', '4', None]
        >>> list(filter_except(int, iterable, ValueError, TypeError))
        ['1', '2', '4']

    An exception raised by *validator* that is not listed in *exceptions*
    propagates as usual.
    """
    for item in iterable:
        try:
            validator(item)
        except exceptions:
            continue
        yield item
def map_except(function, iterable, *exceptions):
    """Transform each item from *iterable* with *function* and yield the
    result, unless *function* raises one of the specified *exceptions*,
    in which case the item is skipped.

    *function* is called to transform each item in *iterable*.
    It should accept one argument.

        >>> iterable = ['1', '2', 'three', '4', None]
        >>> list(map_except(int, iterable, ValueError, TypeError))
        [1, 2, 4]

    If an exception other than one given by *exceptions* is raised by
    *function*, it is raised like normal.

    Only exceptions raised by *function* itself are suppressed. An
    exception thrown into the generator by its consumer (for example with
    ``generator.throw()``) always propagates, even if its type is listed
    in *exceptions*.
    """
    for item in iterable:
        # Keep the yield outside the try block: otherwise an exception
        # injected at the yield point would be mistaken for a failure of
        # *function* and silently swallowed.
        try:
            result = function(item)
        except exceptions:
            continue
        yield result
def map_if(iterable, pred, func, func_else=None):
    """Yield each item of *iterable* transformed by *func* when
    ``pred(item)`` is truthy, and by *func_else* otherwise.

    *pred*, *func* and *func_else* are each called with a single item.
    When *func_else* is omitted, items that fail *pred* are passed
    through unchanged.

        >>> from math import sqrt
        >>> iterable = list(range(-5, 5))
        >>> iterable
        [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
        >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
        [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
        >>> list(map_if(iterable, lambda x: x >= 0,
        ...             lambda x: f'{sqrt(x):.2f}', lambda x: None))
        [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
    """
    for item in iterable:
        if pred(item):
            yield func(item)
        elif func_else is None:
            yield item
        else:
            yield func_else(item)
def _sample_unweighted(iterator, k, strict):
    """Return *k* items chosen uniformly, without replacement, from
    *iterator*, in shuffled order.

    If *iterator* holds fewer than *k* items, ``ValueError`` is raised
    when *strict* is true; otherwise all of the items are returned.
    """
    # Algorithm L in the 1994 paper by Kim-Hung Li:
    # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".

    picked = list(islice(iterator, k))
    if strict and len(picked) < k:
        raise ValueError('Sample larger than population')

    weight = 1.0
    try:
        while True:
            weight *= random() ** (1 / k)
            # Number of items to skip before the next replacement.
            gap = floor(log(random()) / log1p(-weight))
            # Raises StopIteration once the input runs out, which ends
            # the loop before another slot is drawn.
            newcomer = next(islice(iterator, gap, None))
            picked[randrange(k)] = newcomer
    except StopIteration:
        pass

    shuffle(picked)
    return picked
def _sample_weighted(iterator, k, weights, strict):
    """Return up to *k* items from *iterator*, chosen without replacement
    with odds governed by the parallel *weights* iterator, in shuffled
    order.

    *iterator* and *weights* must both be iterators; they are consumed in
    step. If fewer than *k* items are available, ``ValueError`` is raised
    when *strict* is true; otherwise every available item is returned
    (an empty list for an empty population).
    """
    # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
    # "Weighted random sampling with a reservoir".

    # Log-transform for numerical stability for weights that are small/large
    weight_keys = (log(random()) / weight for weight in weights)

    # Fill up the reservoir (collection of samples) with the first `k`
    # weight-keys and elements, then heapify the list.
    reservoir = list(islice(zip(weight_keys, iterator), k))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Nothing to sample from: without this guard the reservoir[0] lookup
    # below would raise IndexError for an empty population.
    if not reservoir:
        return []

    heapify(reservoir)

    # The number of jumps before changing the reservoir is a random variable
    # with an exponential distribution. Sample it using random() and logs.
    smallest_weight_key, _ = reservoir[0]
    weights_to_skip = log(random()) / smallest_weight_key

    for weight, element in zip(weights, iterator):
        if weight >= weights_to_skip:
            # The notation here is consistent with the paper, but we store
            # the weight-keys in log-space for better numerical stability.
            smallest_weight_key, _ = reservoir[0]
            t_w = exp(weight * smallest_weight_key)
            r_2 = uniform(t_w, 1)  # generate U(t_w, 1)
            weight_key = log(r_2) / weight
            # Evict the entry with the smallest key in favor of this one.
            heapreplace(reservoir, (weight_key, element))
            smallest_weight_key, _ = reservoir[0]
            weights_to_skip = log(random()) / smallest_weight_key
        else:
            weights_to_skip -= weight

    ret = [element for weight_key, element in reservoir]
    shuffle(ret)
    return ret
def _sample_counted(population, k, counts, strict):
    """Reservoir-sample *k* elements from *population*, treating each
    element as if it were repeated by the matching entry of *counts*.

    *population* and *counts* are advanced with ``next()``, so both must
    be iterators. Returns the sampled elements as a shuffled list. Raises
    ``ValueError`` when *strict* is true and fewer than *k* elements could
    be drawn.
    """
    # The element currently being repeated, and how many of its repeats
    # have not been consumed yet.
    element = None
    remaining = 0

    def feed(i):
        # Advance *i* steps ahead and consume an element
        nonlocal element, remaining

        # Pull new (element, count) pairs until the current one has enough
        # repeats left to cover the *i* skipped steps plus the one taken.
        while i + 1 > remaining:
            i = i - remaining
            element = next(population)
            remaining = next(counts)
        remaining -= i + 1
        return element

    # Initial fill. A StopIteration escaping from feed() means the input
    # ran out, which just ends the fill early.
    with suppress(StopIteration):
        reservoir = []
        for _ in range(k):
            reservoir.append(feed(0))

    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Same skip-then-replace loop as in _sample_unweighted, with feed()
    # doing the skipping across repeated elements. It ends when feed()
    # raises StopIteration.
    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= random() ** (1 / k)
            skip = floor(log(random()) / log1p(-W))
            element = feed(skip)
            reservoir[randrange(k)] = element

    shuffle(reservoir)
    return reservoir
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a list of *k* elements chosen without replacement from
    *iterable*.

    This resembles :func:`random.sample`, but also accepts inputs that
    cannot be indexed (sets, dictionaries) or whose length is not known
    ahead of time (generators).

        >>> iterable = range(100)
        >>> sample(iterable, 5)  # doctest: +SKIP
        [81, 60, 96, 16, 4]

    *counts* may be supplied to describe repeated elements compactly:

        >>> iterable = ['a', 'b']
        >>> counts = [3, 4]  # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
        >>> sample(iterable, k=3, counts=counts)  # doctest: +SKIP
        ['a', 'a', 'b']

    Alternatively, an iterable of *weights* may be given:

        >>> iterable = range(100)
        >>> weights = (i * i + 1 for i in range(100))
        >>> sample(iterable, 5, weights=weights)  # doctest: +SKIP
        [79, 67, 74, 66, 78]

    Weighted selection is also without replacement: once an element is
    picked it leaves the pool and the relative weights of the remaining
    elements grow (this differs from the *counts* parameter of
    :func:`random.sample`). *weights* and *counts* are mutually exclusive;
    passing both raises ``TypeError``.

    A negative *k* raises ``ValueError``; a *k* of zero returns an empty
    list. If *iterable* has fewer than *k* elements, ``ValueError`` is
    raised when *strict* is ``True``; when it is ``False`` all elements
    are returned in shuffled order.

    Unweighted sampling uses the `Algorithm L <https://w.wiki/ANrM>`__
    reservoir technique. With *weights*,
    `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used instead.

    Notes on reproducibility:

    * The algorithms depend on inexact floating-point functions from the
      platform's math library (e.g. ``log``, ``log1p``, and ``pow``),
      which can `produce slightly different results
      <https://members.loria.fr/PZimmermann/papers/accuracy.pdf>`_ on
      different builds. Selections may therefore vary across builds even
      with the same seed.

    * Selections are made by ordinal position while looping over the
      input, so sampling from unordered collections (such as sets) does
      not reproduce across sessions, even on the same platform with the
      same seed. For example, this won't reproduce::

        >> seed(8675309)
        >> sample(set('abcdefghijklmnopqrstuvwxyz'), 10)
        ['c', 'p', 'e', 'w', 's', 'a', 'j', 'd', 'n', 't']

    """
    iterator = iter(iterable)

    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []

    if weights is not None and counts is not None:
        raise TypeError('weights and counts are mutually exclusive')

    # Hand off to the helper that matches the arguments supplied.
    if weights is not None:
        return _sample_weighted(iterator, k, iter(weights), strict)
    if counts is not None:
        return _sample_counted(iterator, k, iter(counts), strict)
    return _sample_unweighted(iterator, k, strict)
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Return ``True`` if the items of *iterable* are in sorted order and
    ``False`` otherwise. *key* and *reverse* mean the same as they do for
    the built-in :func:`sorted`.

        >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
        True
        >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
        False

    With *strict*, equal neighbours also count as out of order:

        >>> is_sorted([1, 2, 2])
        True
        >>> is_sorted([1, 2, 2], strict=True)
        False

    The check stops at the first out-of-order item, so the answer may
    disagree with :func:`sorted` for objects with unusual comparison
    behaviour (such as ``math.nan``). If nothing is out of order, the
    iterable is exhausted.
    """
    values = iterable if key is None else map(key, iterable)

    # Two views of the same stream, the second running one item ahead.
    earlier, later = tee(values)
    next(later, None)
    if reverse:
        earlier, later = later, earlier

    if strict:
        # Every neighbouring pair must be strictly increasing.
        return all(map(lt, earlier, later))
    # No neighbouring pair may be decreasing.
    return not any(map(lt, later, earlier))
class AbortThread(BaseException):
    """Raised from the callback that :class:`callback_iter` hands to the
    wrapped function once iteration has been aborted.

    It derives from ``BaseException`` rather than ``Exception``, so
    ``except Exception`` handlers do not catch it.
    """
class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    .. warning::

        This function is deprecated as of version 11.0.0. It will be
        removed in a future major release.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

        >>> def func(callback=None):
        ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
        ...         if callback:
        ...             callback(i, c)
        ...     return 4


    Use ``with callback_iter(func)`` to get an iterator over the parameters
    that are delivered to the callback. Each item is an ``(args, kwargs)``
    pair.

        >>> with callback_iter(func) as it:
        ...     for args, kwargs in it:
        ...         print(args)
        (1, 'a')
        (2, 'b')
        (3, 'c')

    The function will be called in a background thread. The ``done`` property
    indicates whether it has completed execution.

        >>> it.done
        True

    If it completes successfully, its return value will be available
    in the ``result`` property.

        >>> it.result
        4

    Notes:

    * If the function uses some keyword argument besides ``callback``, supply
      *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently it is polled for
      output.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        self._func = func
        self._callback_kwd = callback_kwd
        # Set by __exit__; checked by the callback on every invocation.
        self._aborted = False
        # Stays None until the reader generator is first advanced and
        # submits the function to the executor.
        self._future = None
        self._wait_seconds = wait_seconds

        # Lazily import concurrent.futures
        self._module = __import__('concurrent.futures').futures
        self._executor = self._module.ThreadPoolExecutor(max_workers=1)
        # Creating the generator does not run any of its body yet.
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # From now on the callback raises AbortThread when it is called.
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    @property
    def done(self):
        # False until the function has been submitted.
        if self._future is None:
            return False
        return self._future.done()

    @property
    def result(self):
        if self._future:
            try:
                # Zero timeout: never block waiting for completion. An
                # exception raised by the function is re-raised here.
                return self._future.result(timeout=0)
            except self._module.TimeoutError:
                pass

        raise RuntimeError('Function has not yet completed')

    def _reader(self):
        # Generator backing __next__: runs the function in the executor
        # and yields the (args, kwargs) of each callback invocation.
        q = Queue()

        def callback(*args, **kwargs):
            # Runs inside the function's worker thread.
            if self._aborted:
                raise AbortThread('canceled by user')

            q.put((args, kwargs))

        self._future = self._executor.submit(
            self._func, **{self._callback_kwd: callback}
        )

        # Poll the queue, waiting up to wait_seconds per attempt, until
        # the function has finished.
        while True:
            try:
                item = q.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                q.task_done()
                yield item

            if self._future.done():
                break

        # Drain whatever was queued between the last poll and completion.
        remaining = []
        while True:
            try:
                item = q.get_nowait()
            except Empty:
                break
            else:
                q.task_done()
                remaining.append(item)
        q.join()
        yield from remaining
def windowed_complete(iterable, n):
    """
    Slide a window of *n* items over *iterable* and yield
    ``(beginning, middle, end)`` tuples, where:

    * ``middle`` holds the *n* items in the window
    * ``beginning`` holds every item before the window
    * ``end`` holds every item after the window

        >>> iterable = range(7)
        >>> n = 3
        >>> for beginning, middle, end in windowed_complete(iterable, n):
        ...     print(beginning, middle, end)
        () (0, 1, 2) (3, 4, 5, 6)
        (0,) (1, 2, 3) (4, 5, 6)
        (0, 1) (2, 3, 4) (5, 6)
        (0, 1, 2) (3, 4, 5) (6,)
        (0, 1, 2, 3) (4, 5, 6) ()

    *n* must be at least 0 and at most the length of *iterable*;
    otherwise ``ValueError`` is raised on first iteration.

    The iterable is read completely into a tuple up front, which may
    require significant storage.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    items = tuple(iterable)
    total = len(items)
    if n > total:
        raise ValueError('n must be <= len(seq)')

    for start in range(total - n + 1):
        stop = start + n
        yield items[:start], items[start:stop], items[stop:]
def all_unique(iterable, key=None):
    """
    Return ``True`` if no two elements of *iterable* are equal.

        >>> all_unique('ABCB')
        False

    If *key* is given, elements are compared by ``key(element)``.

        >>> all_unique('ABCb')
        True
        >>> all_unique('ABCb', str.lower)
        False

    The function returns as soon as the first repeated element is found.
    Hashable and unhashable elements may be mixed; unhashable ones are
    tracked in a list, so checking them is slower.
    """
    hashed = set()
    unhashed = []
    candidates = map(key, iterable) if key else iterable

    for candidate in candidates:
        try:
            if candidate in hashed:
                return False
            hashed.add(candidate)
        except TypeError:
            # Unhashable: fall back to a linear search.
            if candidate in unhashed:
                return False
            unhashed.append(candidate)
    return True
def nth_product(index, *iterables, repeat=1):
    """Equivalent to ``list(product(*iterables, repeat=repeat))[index]``.

    The products of *iterables* can be ordered lexicographically;
    :func:`nth_product` goes straight to the product at sort position
    *index* without generating the ones before it.

        >>> nth_product(8, range(2), range(2), range(2), range(2))
        (1, 0, 0, 0)

    The *repeat* keyword argument specifies the number of repetitions of
    the iterables. The above example is equivalent to::

        >>> nth_product(8, range(2), repeat=4)
        (1, 0, 0, 0)

    A negative *index* counts from the end. ``IndexError`` is raised if
    *index* is out of range.
    """
    # Store the pools last-to-first so that the fastest-varying position
    # is decoded first.
    pools = tuple(tuple(it) for it in reversed(iterables)) * repeat
    sizes = tuple(len(pool) for pool in pools)
    total = prod(sizes)

    if index < 0:
        index += total
    if index < 0 or index >= total:
        raise IndexError

    # Read *index* as a mixed-radix number with one digit per pool.
    picked = []
    for pool, size in zip(pools, sizes):
        index, digit = divmod(index, size)
        picked.append(pool[digit])

    picked.reverse()
    return tuple(picked)
def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]``

    The *r*-length permutations of *iterable* can be ordered
    lexicographically; :func:`nth_permutation` goes straight to the one at
    sort position *index* without generating the ones before it.

        >>> nth_permutation('ghijk', 2, 5)
        ('h', 'i')

    An *r* of ``None`` means the full length of *iterable*. A negative
    *index* counts from the end.

    ``ValueError`` will be raised if *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = list(iterable)
    n = len(pool)
    if r is None:
        r = n
    total = perm(n, r)

    if index < 0:
        index += total
    if index < 0 or index >= total:
        raise IndexError

    # Decode the index in the factorial number system, least significant
    # digit first. For r < n it is first scaled up to the position among
    # the full-length permutations.
    picks = [0] * r
    quotient = index * factorial(n) // total if r < n else index
    for radix in range(1, n + 1):
        quotient, digit = divmod(quotient, radix)
        slot = n - radix
        if 0 <= slot < r:
            picks[slot] = digit
        if quotient == 0:
            break

    # Each digit indexes into the items not yet used.
    return tuple(pool.pop(position) for position in picks)
def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.

    The *r*-length combinations with repetition of *iterable* can be
    ordered lexicographically;
    :func:`nth_combination_with_replacement` goes straight to the one at
    sort position *index* without generating the ones before it.

        >>> nth_combination_with_replacement(range(5), 3, 5)
        (0, 1, 1)

    A negative *index* counts from the end.

    ``ValueError`` will be raised if *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if r < 0:
        raise ValueError

    # Number of combinations in total. An empty pool only admits the
    # empty combination.
    if n:
        total = comb(n + r - 1, r)
    elif r:
        total = 0
    else:
        total = 1

    if index < 0:
        index += total
    if index < 0 or index >= total:
        raise IndexError

    chosen = []
    position = 0
    while r:
        r -= 1
        # Step past whole groups of combinations that start at earlier
        # pool positions until the group containing *index* is found.
        while n >= 0:
            group = comb(n + r - 1, r)
            if index < group:
                break
            n -= 1
            position += 1
            index -= group
        chosen.append(pool[position])

    return tuple(chosen)
def value_chain(*args):
    """Yield all arguments passed to the function in the same order in which
    they were passed. If an argument itself is iterable then iterate over its
    values.

        >>> list(value_chain(1, 2, 3, [4, 5, 6]))
        [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and are emitted
    as-is:

        >>> list(value_chain('12', '34', ['56', '78']))
        ['12', '34', '56', '78']

    Pre- or postpend a single element to an iterable:

        >>> list(value_chain(1, [2, 3, 4, 5, 6]))
        [1, 2, 3, 4, 5, 6]
        >>> list(value_chain([1, 2, 3, 4, 5], 6))
        [1, 2, 3, 4, 5, 6]

    Multiple levels of nesting are not flattened.

    A ``TypeError`` raised while an iterable argument is being iterated
    propagates to the caller; only the failure to obtain an iterator marks
    an argument as a scalar.

    """
    scalar_types = (str, bytes)
    for value in args:
        if isinstance(value, scalar_types):
            yield value
            continue
        # Guard only the iterability probe. Wrapping the whole
        # ``yield from`` would also catch a TypeError raised midway
        # through iteration and then emit the half-consumed iterable
        # itself as if it were a scalar.
        try:
            iterator = iter(value)
        except TypeError:
            yield value
        else:
            yield from iterator
def product_index(element, *iterables, repeat=1):
    """Equivalent to
    ``list(product(*iterables, repeat=repeat)).index(tuple(element))``

    The products of *iterables* can be ordered lexicographically;
    :func:`product_index` computes the first index of *element* without
    generating the products before it.

        >>> product_index([8, 2], range(10), range(5))
        42

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables::

        >>> product_index([8, 0, 7], range(10), repeat=3)
        807

    ``ValueError`` will be raised if the given *element* isn't in the
    product of *iterables*.
    """
    target = tuple(element)
    pools = tuple(tuple(it) for it in iterables) * repeat
    if len(target) != len(pools):
        raise ValueError('element is not a product of args')

    # Accumulate a mixed-radix number with one digit per pool. A value
    # missing from its pool makes tuple.index raise ValueError.
    position = 0
    for item, pool in zip(target, pools):
        position = position * len(pool) + pool.index(item)
    return position
def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`combination_index` computes the index of the
    first *element*, without computing the previous combinations.

        >>> combination_index('adf', 'abcdefg')
        10

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations of *iterable*.
    """
    targets = enumerate(element)
    k, wanted = next(targets, (None, None))
    if k is None:
        # An empty element is the one and only combination of length zero.
        return 0

    matched = []
    pool = enumerate(iterable)
    for n, candidate in pool:
        if candidate == wanted:
            matched.append(n)
            position, wanted = next(targets, (None, None))
            if position is None:
                # Every item of the element has been located.
                break
            k = position
    else:
        raise ValueError('element is not a combination of iterable')

    # Drain what is left of the pool so that ``n`` ends up as its last index.
    n, _ = last(pool, default=(n, None))

    # Count this combination plus the ones after it, then subtract that from
    # the total number of combinations.
    trailing = 1
    for i, j in enumerate(reversed(matched), start=1):
        distance = n - j
        if i <= distance:
            trailing += comb(distance, i)

    return comb(n + 1, k + 1) - trailing
def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``

    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`combination_with_replacement_index`
    computes the index of the first *element*, without computing the previous
    combinations with replacement.

        >>> combination_with_replacement_index('adf', 'abcdefg')
        20

    Items are matched by equality only, so *element* may contain any value,
    including ``None``.

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations with replacement of *iterable*.
    """
    element = tuple(element)
    size = len(element)
    if not size:
        # The empty combination is the only one of length zero.
        return 0

    pool = tuple(iterable)
    n = len(pool)

    # occupations[p] is how many items of the element were matched against
    # pool position p. Progress is tracked with an explicit counter rather
    # than a ``None`` sentinel, so ``None`` is usable as an ordinary item.
    occupations = [0] * n
    matched = 0
    for position, candidate in enumerate(pool):
        while matched < size and candidate == element[matched]:
            occupations[position] += 1
            matched += 1
        if matched == size:
            break
    else:
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    index = 0
    cumulative_sum = 0
    for k in range(1, n):
        cumulative_sum += occupations[k - 1]
        j = size + n - 1 - k - cumulative_sum
        i = n - k
        if i <= j:
            index += comb(j, i)

    return index
def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)``

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`permutation_index`
    computes the index of the first *element* directly, without computing
    the previous permutations.

        >>> permutation_index([1, 3, 2], range(5))
        19

    ``ValueError`` will be raised if the given *element* isn't one of the
    permutations of *iterable*.
    """
    remaining = list(iterable)
    rank = 0
    # The radix shrinks by one per position because every chosen item is
    # removed from the pool of candidates.
    for radix, item in zip(range(len(remaining), -1, -1), element):
        offset = remaining.index(item)  # ValueError if item is unavailable
        rank = rank * radix + offset
        remaining.pop(offset)

    return rank
class countable:
    """Iterator wrapper that records how many items have been drawn from
    *iterable*.

    The ``items_seen`` attribute starts at ``0`` and goes up by one for each
    item handed out:

        >>> iterable = map(str, range(10))
        >>> it = countable(iterable)
        >>> it.items_seen
        0
        >>> next(it), next(it)
        ('0', '1')
        >>> list(it)
        ['2', '3', '4', '5', '6', '7', '8', '9']
        >>> it.items_seen
        10
    """

    def __init__(self, iterable):
        self.items_seen = 0
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        value = next(self._iterator)
        # Reached only when the wrapped iterator produced a value, so
        # exhaustion (StopIteration) never bumps the counter.
        self.items_seen += 1
        return value
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.
    Items are distributed such that the lengths of the lists differ by at
    most 1 item.

        >>> iterable = [1, 2, 3, 4, 5, 6, 7]
        >>> n = 3
        >>> list(chunked_even(iterable, n))  # List lengths: 3, 2, 2
        [[1, 2, 3], [4, 5], [6, 7]]
        >>> list(chunked(iterable, n))  # List lengths: 3, 3, 1
        [[1, 2, 3], [4, 5, 6], [7]]

    ``ValueError`` is raised when iteration starts if *n* is less than 1.
    In that case no item is taken from *iterable*.

    """
    # Reject bad sizes up front. Without this check a negative *n* would
    # first drain the input into the buffer and only then fail inside islice.
    if n < 1:
        raise ValueError('n must be at least one')

    iterator = iter(iterable)

    # Initialize a buffer to process the chunks while keeping
    # some back to fill any underfilled chunks
    min_buffer = (n - 1) * (n - 2)
    buffer = list(islice(iterator, min_buffer))

    # Append items until we have a completed chunk. The islice() only
    # surfaces every n-th append, so one chunk is released per n new items.
    for _ in islice(map(buffer.append, iterator), n, None, n):
        yield buffer[:n]
        del buffer[:n]

    # Check if any chunks need additional processing
    if not buffer:
        return
    length = len(buffer)

    # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
    q, r = divmod(length, n)
    num_lists = q + (1 if r > 0 else 0)
    q, r = divmod(length, num_lists)
    full_size = q + (1 if r > 0 else 0)
    partial_size = full_size - 1
    num_full = length - partial_size * num_lists

    # Yield chunks of full size
    partial_start_idx = num_full * full_size
    if full_size > 0:
        for i in range(0, partial_start_idx, full_size):
            yield buffer[i : i + full_size]

    # Yield chunks of partial size
    if partial_size > 0:
        for i in range(partial_start_idx, length, partial_size):
            yield buffer[i : i + partial_size]
4610def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
4611 """A version of :func:`zip` that "broadcasts" any scalar
4612 (i.e., non-iterable) items into output tuples.
4614 >>> iterable_1 = [1, 2, 3]
4615 >>> iterable_2 = ['a', 'b', 'c']
4616 >>> scalar = '_'
4617 >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
4618 [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]
4620 The *scalar_types* keyword argument determines what types are considered
4621 scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
4622 treat strings and byte strings as iterable:
4624 >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
4625 [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]
4627 If the *strict* keyword argument is ``True``, then
4628 ``ValueError`` will be raised if any of the iterables have
4629 different lengths.
4630 """
4632 def is_scalar(obj):
4633 if scalar_types and isinstance(obj, scalar_types):
4634 return True
4635 try:
4636 iter(obj)
4637 except TypeError:
4638 return True
4639 else:
4640 return False
4642 size = len(objects)
4643 if not size:
4644 return
4646 new_item = [None] * size
4647 iterables, iterable_positions = [], []
4648 for i, obj in enumerate(objects):
4649 if is_scalar(obj):
4650 new_item[i] = obj
4651 else:
4652 iterables.append(iter(obj))
4653 iterable_positions.append(i)
4655 if not iterables:
4656 yield tuple(objects)
4657 return
4659 for item in zip(*iterables, strict=strict):
4660 for i, new_item[i] in zip(iterable_positions, item):
4661 pass
4662 yield tuple(new_item)
def unique_in_window(iterable, n, key=None):
    """Yield the items from *iterable* that haven't been seen recently.
    *n* is the size of the sliding window.

        >>> iterable = [0, 1, 0, 2, 3, 0]
        >>> n = 3
        >>> list(unique_in_window(iterable, n))
        [0, 1, 2, 3, 0]

    The *key* function, if provided, will be used to determine uniqueness:

        >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
        ['a', 'b', 'c', 'd', 'a']

    A sliding window of at most *n* entries is maintained, and an item is
    yielded only if it occurs exactly once in the updated window.

    When `n == 1`, *unique_in_window* is memoryless:

        >>> list(unique_in_window('aab', n=1))
        ['a', 'a', 'b']

    The items in *iterable* must be hashable. ``ValueError`` is raised if
    *n* is not greater than zero.

    """
    if n <= 0:
        raise ValueError('n must be greater than 0')

    recent = deque(maxlen=n)
    tally = Counter()

    for item in iterable:
        if len(recent) == n:
            # The oldest entry is about to fall out of the deque. Forget it
            # completely once its count drops to zero.
            oldest = recent[0]
            tally[oldest] -= 1
            if not tally[oldest]:
                del tally[oldest]

        marker = item if key is None else key(item)
        if marker not in tally:
            yield item
        tally[marker] += 1
        recent.append(marker)
def duplicates_everseen(iterable, key=None):
    """Yield every element that has already appeared earlier in *iterable*.

        >>> list(duplicates_everseen('mississippi'))
        ['s', 'i', 's', 's', 'i', 'p', 'i']
        >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
        ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    If *key* is given, elements are compared by ``key(element)``.

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []  # linear-search fallback for unhashable keys

    for element in iterable:
        k = element if key is None else key(element)
        try:
            if k in hashable_seen:
                yield element
            else:
                hashable_seen.add(k)
        except TypeError:
            if k in unhashable_seen:
                yield element
            else:
                unhashable_seen.append(k)
def duplicates_justseen(iterable, key=None):
    """Yield elements that repeat the element immediately before them.

        >>> list(duplicates_justseen('mississippi'))
        ['s', 's', 'p']
        >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
        ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    This function is analogous to :func:`unique_justseen`.

    """
    runs = groupby(iterable, key)

    def run_tails():
        for _, run in runs:
            # The ``for`` statement consumes the first item of the run. The
            # run iterator is then handed out, so flattening it produces only
            # the remaining (duplicate) items. Once those are consumed the
            # loop finds the run exhausted and moves on to the next one.
            for _ in run:
                yield run

    return flatten(run_tails())
def classify_unique(iterable, key=None):
    """Tag every element of *iterable* with two uniqueness flags.

    For each element in the input iterable, return a 3-tuple consisting of:

    1. The element itself
    2. ``False`` if the element is equal to the one preceding it in the input,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
    3. ``False`` if this element has been seen anywhere in the input before,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)

        >>> list(classify_unique('otto'))  # doctest: +NORMALIZE_WHITESPACE
        [('o', True, True),
         ('t', True, True),
         ('t', False, False),
         ('o', True, False)]

    If *key* is given, elements are compared by ``key(element)``.

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []  # linear-search fallback for unhashable keys
    previous = None

    for position, element in enumerate(iterable):
        k = element if key is None else key(element)
        # The very first element has no predecessor to compare against.
        new_run = position == 0 or previous != k
        previous = k
        try:
            first_time = k not in hashable_seen
            if first_time:
                hashable_seen.add(k)
        except TypeError:
            first_time = k not in unhashable_seen
            if first_time:
                unhashable_seen.append(k)
        yield element, new_run, first_time
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Return the smallest and the largest item as a ``(min, max)`` pair,
    taken either from a single iterable or from two or more arguments.

        >>> minmax([3, 1, 5])
        (1, 5)

        >>> minmax(4, 2, 6)
        (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

        >>> minmax([5, 30], key=str)  # '30' sorts before '5'
        (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

        >>> minmax([], default=(0, 0))
        (0, 0)

    Otherwise ``ValueError`` is raised.

    This function makes a single pass over the input elements and takes care to
    minimize the number of comparisons made during processing.

    Note that unlike the builtin ``max`` function, which always returns the first
    item with the maximum value, this function may return another item when there are
    ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    if others:
        iterable = (iterable_or_value, *others)
    else:
        iterable = iterable_or_value

    it = iter(iterable)

    try:
        lo = hi = next(it)
    except StopIteration as exc:
        if default is not _marker:
            return default
        raise ValueError(
            '`minmax()` argument is an empty iterable. '
            'Provide a `default` value to suppress this error.'
        ) from exc

    # Items are consumed two at a time. Ordering the pair first means the
    # smaller one is only compared with ``lo`` and the larger one only with
    # ``hi``: three comparisons per pair. An odd leftover item is paired
    # with the first item.
    pairs = zip_longest(it, it, fillvalue=lo)

    # The keyed and un-keyed loops are kept separate so that the common
    # "key=None" case does no bookkeeping for keys.
    if key is not None:
        lo_key = hi_key = key(lo)
        for first, second in pairs:
            first_key, second_key = key(first), key(second)
            if second_key < first_key:
                first, second = second, first
                first_key, second_key = second_key, first_key
            if first_key < lo_key:
                lo, lo_key = first, first_key
            if hi_key < second_key:
                hi, hi_key = second, second_key
        return lo, hi

    for first, second in pairs:
        if second < first:
            first, second = second, first
        if first < lo:
            lo = first
        if hi < second:
            hi = second
    return lo, hi
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield tuples of consecutive items from *iterable* whose combined size
    does not exceed *max_size*.

        >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
        >>> list(constrained_batches(iterable, 10))
        [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If a *max_count* is supplied, the number of items per batch is also
    limited:

        >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
        >>> list(constrained_batches(iterable, 10, max_count = 2))
        [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    If a *get_len* function is supplied, use that instead of :func:`len` to
    determine item size.

    If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
    than *max_size*. Otherwise, allow single items to exceed *max_size*.

    ``ValueError`` is also raised if *max_size* is not greater than zero.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    pending = []
    pending_size = 0
    for item in iterable:
        item_len = get_len(item)
        if strict and item_len > max_size:
            raise ValueError('item size exceeds maximum size')

        count_full = len(pending) == max_count
        size_full = item_len + pending_size > max_size
        # An empty batch is never flushed: the current item always goes
        # somewhere, even if (with strict=False) it is oversized by itself.
        if pending and (size_full or count_full):
            yield tuple(pending)
            pending.clear()
            pending_size = 0

        pending.append(item)
        pending_size += item_len

    if pending:
        yield tuple(pending)
def gray_product(*iterables, repeat=1):
    """Like :func:`itertools.product`, but return tuples in an order such
    that only one element in the generated tuple changes from one iteration
    to the next.

        >>> list(gray_product('AB','CD'))
        [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables. For example, ``gray_product('AB', repeat=3)`` is
    equivalent to ``gray_product('AB', 'AB', 'AB')``.

    This function consumes all of the input iterables before producing output.
    If any of the input iterables have fewer than two items, ``ValueError``
    is raised.

    For information on the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    pools = tuple(map(tuple, iterables)) * repeat
    width = len(pools)
    if any(len(pool) < 2 for pool in pools):
        raise ValueError("each iterable must have two or more items")

    # This is based on "Algorithm H" from section 7.2.1.1, page 20.
    # positions: index into each pool for the tuple about to be yielded
    # focus:     the array of "focus pointers"
    # direction: the array of "directions" (+1 or -1 per pool)
    positions = [0] * width
    focus = list(range(width + 1))
    direction = [1] * width
    while True:
        yield tuple(pool[at] for pool, at in zip(pools, positions))
        j = focus[0]
        focus[0] = 0
        if j == width:
            return
        positions[j] += direction[j]
        if positions[j] in (0, len(pools[j]) - 1):
            # Hit either end of pool j: turn around and pass the focus on.
            direction[j] = -direction[j]
            focus[j] = focus[j + 1]
            focus[j + 1] = j + 1
def partial_product(*iterables, repeat=1):
    """Yields tuples containing one item from each iterator, with subsequent
    tuples changing a single item at a time by advancing each iterator until it
    is exhausted. This sequence guarantees every value in each iterable is
    output at least once without generating all possible combinations.

    This may be useful, for example, when testing an expensive function.

        >>> list(partial_product('AB', 'C', 'DEF'))
        [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables. For example, ``partial_product('AB', repeat=3)`` is
    equivalent to ``partial_product('AB', 'AB', 'AB')``.

    Nothing is yielded if any of the iterables is empty.
    """
    pools = tuple(map(tuple, iterables)) * repeat
    iterators = tuple(map(iter, pools))

    try:
        current = [next(it) for it in iterators]
    except StopIteration:
        # At least one pool is empty, so no tuple can be formed.
        return
    yield tuple(current)

    # Walk the positions left to right, running each iterator dry in turn.
    for slot, it in enumerate(iterators):
        for item in it:
            current[slot] = item
            yield tuple(current)
def takewhile_inclusive(predicate, iterable):
    """A variant of :func:`takewhile` that also yields the first element for
    which *predicate* is false.

        >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
        [1, 4, 6]

    :func:`takewhile` would return ``[1, 4]``.
    """
    for item in iterable:
        # Emit first, test afterwards: the failing item is part of the output.
        yield item
        if predicate(item):
            continue
        return
def outer_product(func, xs, ys, *args, **kwargs):
    """A generalized outer product that applies a binary function to all
    pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
    columns.
    Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.

    Multiplication table:

        >>> from operator import mul
        >>> list(outer_product(mul, range(1, 4), range(1, 6)))
        [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]

    Cross tabulation:

        >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
        >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
        >>> pair_counts = Counter(zip(xs, ys))
        >>> count_rows = lambda x, y: pair_counts[x, y]
        >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
        [(2, 3, 0), (1, 0, 4)]

    Usage with ``*args`` and ``**kwargs``:

        >>> animals = ['cat', 'wolf', 'mouse']
        >>> list(outer_product(min, animals, animals, key=len))
        [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
    """
    # *ys* is materialized because its length sets the row width.
    columns = tuple(ys)

    def cell(x, y):
        return func(x, y, *args, **kwargs)

    # product() varies y fastest, so every run of len(columns) results
    # makes up one row of the matrix.
    cells = starmap(cell, product(xs, columns))
    return batched(cells, n=len(columns))
def iter_suppress(iterable, *exceptions):
    """Yield each of the items from *iterable*. If the iteration raises one of
    the specified *exceptions*, that exception will be suppressed and iteration
    will stop.

    Exceptions of any other type propagate to the caller. If no *exceptions*
    are given, nothing is suppressed.

        >>> from itertools import chain
        >>> def breaks_at_five(x):
        ...     while True:
        ...         if x >= 5:
        ...             raise RuntimeError
        ...         yield x
        ...         x += 1
        >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
        >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
        >>> list(chain(it_1, it_2))
        [1, 2, 3, 4, 2, 3, 4]
    """
    try:
        # Delegate to the iterable. Items yielded before an error have
        # already been handed to the consumer.
        yield from iterable
    except exceptions:
        # *exceptions* is the tuple of positional arguments, so the clause
        # matches any of the listed types. Ending the generator here makes
        # the failure look like normal exhaustion.
        return
def filter_map(func, iterable):
    """Call *func* on every element of *iterable* and yield the results,
    leaving out those that are ``None``.

        >>> elems = ['1', 'a', '2', 'b', '3']
        >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
        [1, 2, 3]

    Only ``None`` is dropped; other falsy results such as ``0`` or ``''``
    are kept.
    """
    for item in iterable:
        if (mapped := func(item)) is not None:
            yield mapped
def powerset_of_sets(iterable, *, baseset=set):
    """Yields all possible subsets of the iterable.

        >>> list(powerset_of_sets([1, 2, 3]))  # doctest: +SKIP
        [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
        >>> list(powerset_of_sets([1, 1, 0]))  # doctest: +SKIP
        [set(), {1}, {0}, {0, 1}]

    :func:`powerset_of_sets` takes care to minimize the number
    of hash operations performed.

    The *baseset* parameter determines what kind of sets are
    constructed, either *set* or *frozenset*.
    """
    # One single-item frozenset per distinct input item. The dict keeps the
    # first-seen order while dropping repeats.
    singletons = tuple(dict.fromkeys(map(frozenset, zip(iterable))))
    # Bound method of an empty *baseset*: calling it with some singletons
    # returns their union as a new *baseset* instance.
    merge = baseset().union
    subsets_by_size = (
        starmap(merge, combinations(singletons, size))
        for size in range(len(singletons) + 1)
    )
    return chain.from_iterable(subsets_by_size)
def join_mappings(**field_to_map):
    """
    Join several mappings on their keys.

    Each keyword argument names a field and supplies a mapping. The result
    maps every key found in any of the mappings to a dict of
    ``{field name: value}`` for the fields that contain that key.

        >>> user_scores = {'elliot': 50, 'claris': 60}
        >>> user_times = {'elliot': 30, 'claris': 40}
        >>> join_mappings(score=user_scores, time=user_times)
        {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    joined = {}

    for field_name, mapping in field_to_map.items():
        for key, value in mapping.items():
            joined.setdefault(key, {})[field_name] = value

    return joined
def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.

    Both arguments are traversed more than once, so they should be
    re-iterable (e.g. lists) rather than one-shot iterators.
    """

    get_real = attrgetter('real')
    get_imag = attrgetter('imag')

    # Real part of sum(a * b): sum(a.real * b.real) - sum(a.imag * b.imag),
    # written as one long sum of products by negating the imaginary parts.
    real_lhs = chain(map(get_real, v1), map(neg, map(get_imag, v1)))
    real_rhs = chain(map(get_real, v2), map(get_imag, v2))

    # Imaginary part: sum(a.real * b.imag) + sum(a.imag * b.real)
    imag_lhs = chain(map(get_real, v1), map(get_imag, v1))
    imag_rhs = chain(map(get_imag, v2), map(get_real, v2))

    real_part = _fsumprod(real_lhs, real_rhs)
    imag_part = _fsumprod(imag_lhs, imag_rhs)
    return complex(real_part, imag_part)
def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the corresponding transformed output vector.

        >>> import cmath
        >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
        >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
        >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
        >>> all(map(cmath.isclose, dft(xarr), Xarr))
        True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number.  This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`idft` for the inverse Discrete Fourier Transform.
    """
    size = len(xarr)
    # The `size` complex roots of unity, going clockwise (negative exponent).
    twiddles = [e ** (n / size * tau * -1j) for n in range(size)]
    for k in range(size):
        # Reducing k * n modulo `size` reuses the precomputed table instead
        # of evaluating a new exponential for every (k, n) pair.
        row = [twiddles[k * n % size] for n in range(size)]
        yield _complex_sumprod(xarr, row)
def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the corresponding
    inverse-transformed output vector.

        >>> import cmath
        >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
        >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
        >>> all(map(cmath.isclose, idft(Xarr), xarr))
        True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number.  This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`dft` for the Discrete Fourier Transform.
    """
    size = len(Xarr)
    # Same table as in dft() but with a positive exponent.
    twiddles = [e ** (n / size * tau * 1j) for n in range(size)]
    for k in range(size):
        row = [twiddles[k * n % size] for n in range(size)]
        # The inverse transform carries the 1/N normalization.
        yield _complex_sumprod(Xarr, row) / size
def doublestarmap(func, iterable):
    """Call *func* once per item of *iterable*, passing the item's entries
    as keyword arguments, and yield each result.

    The difference between :func:`itertools.starmap` and :func:`doublestarmap`
    parallels the distinction between ``func(*a)`` and ``func(**a)``.

        >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
        >>> list(doublestarmap(lambda a, b: a + b, iterable))
        [3, 100]

    ``TypeError`` will be raised if *func*'s signature doesn't match the
    mapping contained in *iterable* or if *iterable* does not contain mappings.
    """
    for keyword_arguments in iterable:
        yield func(**keyword_arguments)
5205def _nth_prime_bounds(n):
5206 """Bounds for the nth prime (counting from 1): lb < p_n < ub."""
5207 # At and above 688,383, the lb/ub spread is under 0.003 * p_n.
5209 if n < 1:
5210 raise ValueError
5212 if n < 6:
5213 return (n, 2.25 * n)
5215 # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
5216 upper_bound = n * log(n * log(n))
5217 lower_bound = upper_bound - n
5218 if n >= 688_383:
5219 upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))
5221 return lower_bound, upper_bound
def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

        >>> nth_prime(0)
        2
        >>> nth_prime(100)
        547

    If *approximate* is set to True, will return a prime close
    to the nth prime.  The estimation is much faster than computing
    an exact result.

        >>> nth_prime(200_000_000, approximate=True)  # Exact result is 4222234763
        4217820427

    For *n* up to 1,000,000 the exact prime is returned even when
    *approximate* is true.

    """
    # The bounds helper counts primes from 1, hence the shift by one.
    lower, upper = _nth_prime_bounds(n + 1)

    if approximate and n > 1_000_000:
        # Search from the midpoint and return the first odd prime
        start = floor((lower + upper) / 2) | 1
        return first_true(count(start, step=2), pred=is_prime)

    # Exact answer: sieve everything below the upper bound, take item n.
    return nth(sieve(ceil(upper)), n)
def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

        >>> argmin('efghabcdijkl')
        4
        >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
        3

    If *key* is given, the values compared are ``key(item)``. For example,
    look up a label corresponding to the position of a value that minimizes
    a cost function::

        >>> def cost(x):
        ...     "Days for a wound to heal given a subject's age."
        ...     return x**2 - 20*x + 150
        ...
        >>> labels =  ['homer', 'marge', 'bart', 'lisa', 'maggie']
        >>> ages =    [  35,      30,      10,      9,      1    ]

        # Fastest healing family member
        >>> labels[argmin(ages, key=cost)]
        'bart'

        # Age with fastest healing
        >>> min(ages, key=cost)
        10

    """
    values = iterable if key is None else map(key, iterable)
    # min() keeps the earliest of equally small values, which is what makes
    # this the *first* occurrence.
    position, _ = min(enumerate(values), key=itemgetter(1))
    return position
def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

        >>> argmax('abcdefghabcd')
        7
        >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
        3

    If *key* is given, the values compared are ``key(item)``. For example,
    identify the best machine learning model::

        >>> models =   ['svm', 'random forest', 'knn', 'naïve bayes']
        >>> accuracy = [  68,        61,          84,       72      ]

        # Most accurate model
        >>> models[argmax(accuracy)]
        'knn'

        # Best accuracy
        >>> max(accuracy)
        84

    """
    values = iterable if key is None else map(key, iterable)
    # max() keeps the earliest of equally large values, which is what makes
    # this the *first* occurrence.
    position, _ = max(enumerate(values), key=itemgetter(1))
    return position
def _extract_monotonic(iterator, indices):
    """Non-decreasing indices, lazily consumed

    Helper for :func:`extract`. Yields ``iterator``'s item at each position
    listed in *indices*, reading both inputs only as far as needed.

    Raises ``IndexError`` if an index lies beyond the end of *iterator* and
    ``ValueError`` for a negative index or an index smaller than the
    previous one.
    """
    # Number of items taken from *iterator* so far, i.e. one more than the
    # position of the most recently fetched item.
    num_read = 0
    for index in indices:
        # How many items to skip before reaching the wanted one.
        advance = index - num_read
        try:
            value = next(islice(iterator, advance, None))
        except ValueError:
            # islice() rejects a negative start. advance == -1 with a
            # non-negative index means the previous index was repeated; in
            # that case fall through and yield the retained `value` again.
            # Anything else is a negative or decreasing index.
            if advance != -1 or index < 0:
                raise ValueError(f'Invalid index: {index}') from None
        except StopIteration:
            raise IndexError(index) from None
        else:
            # Skipped `advance` items and consumed one more.
            num_read += advance + 1
        yield value
5328def _extract_buffered(iterator, index_and_position):
5329 'Arbitrary index order, greedily consumed'
5330 buffer = {}
5331 iterator_position = -1
5332 next_to_emit = 0
5334 for index, order in index_and_position:
5335 advance = index - iterator_position
5336 if advance:
5337 try:
5338 value = next(islice(iterator, advance - 1, None))
5339 except StopIteration:
5340 raise IndexError(index) from None
5341 iterator_position = index
5343 buffer[order] = value
5345 while next_to_emit in buffer:
5346 yield buffer.pop(next_to_emit)
5347 next_to_emit += 1
def extract(iterable, indices, *, monotonic=False):
    """Yield values at the specified indices.

    Example:

        >>> data = 'abcdefghijklmnopqrstuvwxyz'
        >>> list(extract(data, [7, 4, 11, 11, 14]))
        ['h', 'e', 'l', 'l', 'o']

    The *iterable* is consumed lazily and can be infinite.

    When *monotonic* is false, the *indices* are consumed immediately
    and must be finite. When *monotonic* is true, *indices* are consumed
    lazily and can be infinite but must be non-decreasing.

    Raises ``IndexError`` if an index lies beyond the iterable.
    Raises ``ValueError`` for a negative index or for a decreasing
    index when *monotonic* is true.
    """

    iterator = iter(iterable)
    indices = iter(indices)

    if not monotonic:
        # Pair every index with its output position, then sort by index so
        # that the iterable can be read strictly forward.
        ordered = sorted(zip(indices, count()))
        # After sorting, the smallest index comes first.
        if ordered and ordered[0][0] < 0:
            raise ValueError('Indices must be non-negative')
        return _extract_buffered(iterator, ordered)

    return _extract_monotonic(iterator, indices)
class serialize:
    """Wrap a non-concurrent iterator with a lock to enforce sequential access.

    Applies a non-reentrant lock around calls to ``__next__``, allowing
    iterator and generator instances to be shared by multiple consumer
    threads.

    The generator methods ``send``, ``throw`` and ``close`` are forwarded
    under the same lock.
    """

    __slots__ = ('_iterator', '_lock')

    def __init__(self, iterable):
        self._lock = Lock()
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        with self._lock:
            value = next(self._iterator)
        return value

    def send(self, value, /):
        """Send a value to a generator.

        Raises AttributeError if not a generator.
        """
        with self._lock:
            reply = self._iterator.send(value)
        return reply

    def throw(self, *args):
        """Call throw() on a generator.

        Raises AttributeError if not a generator.
        """
        with self._lock:
            reply = self._iterator.throw(*args)
        return reply

    def close(self):
        """Call close() on a generator.

        Raises AttributeError if not a generator.
        """
        with self._lock:
            reply = self._iterator.close()
        return reply
def synchronized(func):
    """Wrap an iterator-returning callable to make its iterators thread-safe.

    Every call to the wrapped callable is forwarded to *func*, and the
    iterator it returns is passed through :func:`serialize`. Existing
    itertools and more-itertools can be wrapped this way so that their
    iterator instances are serialized.

    For example, ``itertools.count`` does not make thread-safe instances,
    but that is easily fixed with::

        atomic_counter = synchronized(itertools.count)

    Can also be used as a decorator for generator functions definitions
    so that the generator instances are serialized::

        @synchronized
        def enumerate_and_timestamp(iterable):
            for count, value in enumerate(iterable):
                yield count, time_ns(), value

    """

    @wraps(func)
    def inner(*args, **kwargs):
        return serialize(func(*args, **kwargs))

    return inner
def concurrent_tee(iterable, n=2):
    """Variant of itertools.tee() but with guaranteed threading semantics.

    Takes a non-threadsafe iterator as an input and creates concurrent
    tee objects for other threads to have reliable independent copies of
    the data stream.

    Returns a tuple of *n* iterators. ``ValueError`` is raised if *n* is
    negative.

    The new iterators are only thread-safe if consumed within a single thread.
    To share just one of the new iterators across multiple threads, wrap it
    with :func:`serialize`.
    """

    if n < 0:
        raise ValueError
    if n == 0:
        return ()

    # The first tee object wraps the source. The others are built from it
    # rather than from *iterable*.
    first = _concurrent_tee(iterable)
    clones = [_concurrent_tee(first) for _ in range(n - 1)]
    return (first, *clones)
class _concurrent_tee:
    # Iterator type returned by concurrent_tee().
    #
    # iterator: the underlying source iterator
    # link:     two-item list ``[value, next_link]`` marking this object's
    #           current position; ``next_link`` is None until the value for
    #           this position has been fetched from the source
    # lock:     guards fetching from the source iterator
    __slots__ = ('iterator', 'link', 'lock')

    def __init__(self, iterable):
        if isinstance(iterable, _concurrent_tee):
            # Built from an existing tee object: use the same source
            # iterator and lock, and start at that object's current link.
            self.iterator = iterable.iterator
            self.link = iterable.link
            self.lock = iterable.lock
        else:
            # Built from a plain iterable: set up a fresh source iterator,
            # an empty first link and a new lock.
            self.iterator = iter(iterable)
            self.link = [None, None]
            self.lock = Lock()

    def __iter__(self):
        return self

    def __next__(self):
        link = self.link
        if link[1] is None:
            # This link has not been filled yet. Take the lock, then look
            # again: the link may have been filled by another tee object
            # in the meantime, in which case nothing more is fetched.
            with self.lock:
                if link[1] is None:
                    # The value is stored before the successor link is
                    # attached. If next() raises (e.g. StopIteration), the
                    # link stays unfilled and the exception propagates.
                    link[0] = next(self.iterator)
                    link[1] = [None, None]
        # Hand out the stored value and step on to the following link.
        value, self.link = link
        return value