Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/more_itertools/more.py: 19%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1__lazy_modules__ = frozenset({'queue'})
3import math
4import types
6from _thread import allocate_lock
7from collections import Counter, defaultdict, deque
8from collections.abc import Sequence
9from contextlib import suppress
10from functools import cached_property, partial, wraps
11from heapq import heapify, heapreplace
12from itertools import (
13 chain,
14 combinations,
15 compress,
16 count,
17 cycle,
18 dropwhile,
19 groupby,
20 islice,
21 permutations,
22 repeat,
23 starmap,
24 takewhile,
25 tee,
26 zip_longest,
27 product,
28)
29from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
30from math import ceil, prod
31from queue import Empty, Queue
32from random import random, randrange, shuffle, uniform
33from operator import (
34 attrgetter,
35 getitem,
36 is_not,
37 itemgetter,
38 lt,
39 neg,
40 sub,
41 gt,
42)
43from sys import maxsize
44from time import monotonic
46from .recipes import (
47 _marker,
48 consume,
49 first_true,
50 flatten,
51 is_prime,
52 nth,
53 powerset,
54 sieve,
55 take,
56 unique_everseen,
57 all_equal,
58 batched,
59)
# Public names exported by this module; this list is what
# ``from <module> import *`` exposes to callers.
61__all__ = [
62 'AbortThread',
63 'SequenceView',
64 'adjacent',
65 'all_unique',
66 'always_iterable',
67 'always_reversible',
68 'argmax',
69 'argmin',
70 'bucket',
71 'callback_iter',
72 'chunked',
73 'chunked_even',
74 'circular_shifts',
75 'collapse',
76 'combination_index',
77 'combination_with_replacement_index',
78 'concurrent_tee',
79 'consecutive_groups',
80 'constrained_batches',
81 'consumer',
82 'count_cycle',
83 'countable',
84 'derangements',
85 'dft',
86 'difference',
87 'distinct_combinations',
88 'distinct_permutations',
89 'distribute',
90 'divide',
91 'doublestarmap',
92 'duplicates_everseen',
93 'duplicates_justseen',
94 'classify_unique',
95 'exactly_n',
96 'extract',
97 'filter_except',
98 'filter_map',
99 'first',
100 'gray_product',
101 'groupby_transform',
102 'ichunked',
103 'iequals',
104 'idft',
105 'ilen',
106 'interleave',
107 'interleave_evenly',
108 'interleave_longest',
109 'interleave_randomly',
110 'intersperse',
111 'is_sorted',
112 'islice_extended',
113 'iterate',
114 'iter_suppress',
115 'join_mappings',
116 'last',
117 'locate',
118 'longest_common_prefix',
119 'lstrip',
120 'make_decorator',
121 'map_except',
122 'map_if',
123 'map_reduce',
124 'mark_ends',
125 'minmax',
126 'nth_or_last',
127 'nth_permutation',
128 'nth_prime',
129 'nth_product',
130 'nth_combination_with_replacement',
131 'numeric_range',
132 'one',
133 'only',
134 'outer_product',
135 'padded',
136 'partial_product',
137 'partitions',
138 'peekable',
139 'permutation_index',
140 'powerset_of_sets',
141 'product_index',
142 'raise_',
143 'repeat_each',
144 'repeat_last',
145 'replace',
146 'rlocate',
147 'rstrip',
148 'run_length',
149 'sample',
150 'seekable',
151 'serialize',
152 'set_partitions',
153 'side_effect',
154 'sized_iterator',
155 'sliced',
156 'sort_together',
157 'split_after',
158 'split_at',
159 'split_before',
160 'split_into',
161 'split_when',
162 'spy',
163 'stagger',
164 'strip',
165 'strictly_n',
166 'subfactorial',
167 'substrings',
168 'substrings_indexes',
169 'synchronized',
170 'takewhile_inclusive',
171 'time_limited',
172 'unique_in_window',
173 'unique_to_each',
174 'unzip',
175 'value_chain',
176 'windowed',
177 'windowed_complete',
178 'with_iter',
179 'zip_broadcast',
180 'zip_offset',
181]
# Prefer the built-in ``math.sumprod`` (available for Python 3.12+) and only
# define a pure-Python replacement when it cannot be imported.
try:
    from math import sumprod as _fsumprod

except ImportError:  # pragma: no cover
    # Extended precision algorithms from T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8).  Code: mul12()

    def dl_split(x: float):
        "Split a float into two half-precision components."
        scaled = x * 134217729.0  # Veltkamp constant = 2.0 ** 27 + 1
        high = scaled - (scaled - x)
        return high, x - high

    def dl_mul(x, y):
        "Lossless multiplication."
        x_hi, x_lo = dl_split(x)
        y_hi, y_lo = dl_split(y)
        p = x_hi * y_hi
        q = x_hi * y_lo + x_lo * y_hi
        z = p + q
        # Error term of the rounded product; the operation order matters.
        zz = p - z + q + x_lo * y_lo
        return z, zz

    def _fsumprod(p, q):
        # Each product contributes its (value, error) pair to one exact sum.
        return fsum(chain.from_iterable(map(dl_mul, p, q)))
def chunked(iterable, n, strict=False):
    """Split *iterable* into consecutive lists holding *n* items each:

        >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
        [[1, 2, 3], [4, 5, 6]]

    By default, the final list is shorter than *n* when the number of items
    in *iterable* is not a multiple of *n*:

        >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
        [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    When *strict* is ``True`` and the length of *iterable* is not divisible
    by *n*, ``ValueError`` is raised before the short final list would be
    yielded.

    """
    source = iter(iterable)
    # iter(callable, sentinel): keep taking chunks until an empty list
    # signals that the source is exhausted.
    chunks = iter(partial(take, n, source), [])
    if not strict:
        return chunks

    if n is None:
        raise ValueError('n must not be None when using strict mode.')

    def checked():
        for chunk in chunks:
            if len(chunk) != n:
                raise ValueError('iterable is not divisible by n.')
            yield chunk

    return checked()
def first(iterable, default=_marker):
    """Return the leading item of *iterable*, falling back to *default* when
    *iterable* has no items.

        >>> first([0, 1, 2, 3])
        0
        >>> first([], 'some default')
        'some default'

    ``ValueError`` is raised when *iterable* is empty and no *default* was
    supplied.

    :func:`first` is handy when a generator produces expensive-to-retrieve
    values and any single one of them will do. It is marginally shorter than
    ``next(iter(iterable), default)``.

    """
    iterator = iter(iterable)
    try:
        return next(iterator)
    except StopIteration:
        pass

    if default is not _marker:
        return default
    raise ValueError(
        'first() was called on an empty iterable, '
        'and no default value was provided.'
    )
def last(iterable, default=_marker):
    """Return the final item of *iterable*, falling back to *default* when
    *iterable* has no items.

        >>> last([0, 1, 2, 3])
        3
        >>> last([], 'some default')
        'some default'

    ``ValueError`` is raised when *iterable* is empty and no *default* was
    supplied.
    """
    try:
        if not getattr(iterable, '__reversed__', None):
            # A one-slot deque keeps only the most recent item.
            tail = deque(iterable, maxlen=1)
            return tail[-1]
        # Reversible inputs give up their last item without a full pass.
        return next(reversed(iterable))
    except (StopIteration, IndexError):
        if default is not _marker:
            return default
        raise ValueError(
            'last() was called on an empty iterable, '
            'and no default value was provided.'
        )
def nth_or_last(iterable, n, default=_marker):
    """Return the item at position *n* of *iterable*, or its last item when
    it is shorter than that, or *default* when it is empty.

        >>> nth_or_last([0, 1, 2, 3], 2)
        2
        >>> nth_or_last([0, 1], 2)
        1
        >>> nth_or_last([], 0, 'some default')
        'some default'

    ``ValueError`` is raised when *iterable* is empty and no *default* was
    supplied.
    """
    # Only the first n + 1 items can matter; the last of those is the answer.
    leading = islice(iterable, n + 1)
    return last(leading, default=default)
class peekable:
    """Iterator wrapper that supports lookahead and prepending of elements.

    :meth:`peek` reports the value that :func:`next` will return, without
    advancing the iterator:

        >>> p = peekable(['a', 'b'])
        >>> p.peek()
        'a'
        >>> next(p)
        'a'

    Give :meth:`peek` a default to get that back instead of a
    ``StopIteration`` once the iterator is exhausted.

        >>> p = peekable([])
        >>> p.peek('hi')
        'hi'

    The :meth:`prepend` method "inserts" items at the head of the iterable:

        >>> p = peekable([1, 2, 3])
        >>> p.prepend(10, 11, 12)
        >>> next(p)
        10
        >>> p.peek()
        11
        >>> list(p)
        [11, 12, 1, 2, 3]

    Indexing is supported. Index 0 is the item :func:`next` will return,
    index 1 the one after that, and so on. Every value up to the requested
    index is cached.

        >>> p = peekable(['a', 'b', 'c', 'd'])
        >>> p[0]
        'a'
        >>> p[1]
        'b'
        >>> next(p)
        'a'

    Negative indexes work too, but they cache everything left in the source
    iterator, which may require significant storage.

    The truth value of a peekable tells whether it still has items:

        >>> p = peekable(['a', 'b'])
        >>> if p:  # peekable has items
        ...     list(p)
        ['a', 'b']
        >>> if not p:  # peekable is exhausted
        ...     list(p)
        []

    """

    __class_getitem__ = classmethod(types.GenericAlias)

    def __init__(self, iterable):
        self._it = iter(iterable)
        # Items already pulled from the source but not yet handed out.
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        try:
            self.peek()
        except StopIteration:
            return False
        else:
            return True

    def __next__(self):
        return self._cache.popleft() if self._cache else next(self._it)

    def peek(self, default=_marker):
        """Return the item that the next call to ``next()`` will produce.

        When no items are left, return ``default`` if one was given and
        raise ``StopIteration`` otherwise.

        """
        cache = self._cache
        if cache:
            return cache[0]

        try:
            upcoming = next(self._it)
        except StopIteration:
            if default is _marker:
                raise
            return default

        cache.append(upcoming)
        return upcoming

    def prepend(self, *items):
        """Queue *items* so they are the next ones returned by ``next()`` or
        ``self.peek()``, in first in, first out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        Prepending can "resurrect" a peekable that has already raised
        ``StopIteration``.

            >>> p = peekable([])
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration
            >>> p.prepend(1)
            >>> next(p)
            1
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration

        """
        # Pushing onto the left in reverse keeps the caller's order.
        for item in reversed(items):
            self._cache.appendleft(item)

    def _get_slice(self, index):
        # Fill in the slice's missing pieces based on its direction
        step = index.step if index.step is not None else 1
        if step > 0:
            default_start, default_stop = 0, maxsize
        elif step < 0:
            default_start, default_stop = -1, -maxsize - 1
        else:
            raise ValueError('slice step cannot be zero')

        start = default_start if index.start is None else index.start
        stop = default_stop if index.stop is None else index.stop

        if start < 0 or stop < 0:
            # Slicing from the right needs the whole remainder cached.
            self._cache.extend(self._it)
        else:
            # Cache just far enough to cover the rightmost bound.
            wanted = min(max(start, stop) + 1, maxsize)
            cached = len(self._cache)
            if wanted >= cached:
                self._cache.extend(islice(self._it, wanted - cached))

        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        cached = len(self._cache)
        if index < 0:
            # Counting from the end requires knowing where the end is.
            self._cache.extend(self._it)
        elif index >= cached:
            self._cache.extend(islice(self._it, index + 1 - cached))

        return self._cache[index]
def consumer(func):
    """Decorator that advances a PEP-342-style "reverse iterator" to its
    first yield point automatically, so no manual ``next()`` call is needed.

        >>> @consumer
        ... def tally():
        ...     i = 0
        ...     while True:
        ...         print('Thing number %s is %s.' % (i, (yield)))
        ...         i += 1
        ...
        >>> t = tally()
        >>> t.send('red')
        Thing number 0 is red.
        >>> t.send('fish')
        Thing number 1 is fish.

    Without the decorator, ``next(t)`` would have to be called before
    ``t.send()`` could be used.

    """

    @wraps(func)
    def primed(*args, **kwargs):
        generator = func(*args, **kwargs)
        next(generator)  # run up to the first ``yield``
        return generator

    return primed
def ilen(iterable):
    """Count the items in *iterable*.

    For example, there are 168 prime numbers below 1,000:

        >>> ilen(sieve(1000))
        168

    Equivalent to, but faster than::

        def ilen(iterable):
            count = 0
            for _ in iterable:
                count += 1
            return count

    The iterable is fully consumed, so handle with care.

    """
    # zip() wraps every item in a 1-tuple, which is always truthy, so
    # compress() lets exactly one ``1`` through per item.
    ones = repeat(1)
    selectors = zip(iterable)
    return sum(compress(ones, selectors))
def iterate(func, start):
    """Yield ``start``, ``func(start)``, ``func(func(start))``, ...

    The resulting iterator is infinite. To add a stopping condition, use
    :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`.

        >>> take(10, iterate(lambda x: 2*x, 1))
        [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

        >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2
        >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10)))
        [10, 5, 16, 8, 4, 2, 1]

    """
    value = start
    try:
        while True:
            yield value
            value = func(value)
    except StopIteration:
        # A StopIteration escaping *func* simply ends the iteration.
        return
def with_iter(context_manager):
    """Iterate over what *context_manager* yields on entry, leaving the
    ``with`` block (and so closing the resource) once exhausted.

    For example, this will close the file when the iterator is exhausted::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Note that the iterator has to be exhausted for opened files to be
    closed.

    Any context manager which returns an iterable is a candidate for
    ``with_iter``.

    """
    with context_manager as resource:
        # ``yield from`` keeps send()/throw() delegation to the resource.
        yield from resource
class sized_iterator:
    """Iterator over *iterable* that also answers :func:`len`.

        >>> it = map(str, range(5))
        >>> sized_it = sized_iterator(it, 5)
        >>> len(sized_it)
        5
        >>> list(sized_it)
        ['0', '1', '2', '3', '4']

    This helps with tools that rely on :func:`len`, like
    `tqdm <https://pypi.org/project/tqdm/>`__ .

    The supplied *length* is reported as given and never checked against
    the iterable, so be sure to choose a value that reflects reality.
    """

    def __init__(self, iterable, length):
        self._length = length
        self._iterator = iter(iterable)

    def __len__(self):
        return self._length

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)
def one(iterable, too_short=None, too_long=None):
    """Return the sole item of *iterable*. An exception is raised when
    *iterable* is empty or holds more than one item.

    :func:`one` is useful for asserting that an iterable contains exactly
    one item, for instance the result of a database query that is expected
    to return a single row.

    An empty *iterable* raises ``ValueError``. Use the *too_short* keyword
    to raise a different exception:

        >>> it = []
        >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: too few items in iterable (expected 1)
        >>> too_short = IndexError('too few items')
        >>> one(it, too_short=too_short)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        IndexError: too few items

    Likewise, more than one item raises ``ValueError`` unless a different
    exception is given through the *too_long* keyword:

        >>> it = ['too', 'many']
        >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Expected exactly one item in iterable, but got 'too',
        'many', and perhaps more.
        >>> too_long = RuntimeError
        >>> one(it, too_long=too_long)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        RuntimeError

    Note that :func:`one` attempts to advance *iterable* twice to ensure
    there is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    # A private sentinel distinguishes "no item" from any real value.
    missing = object()
    iterator = iter(iterable)

    head = next(iterator, missing)
    if head is missing:
        raise too_short or ValueError('too few items in iterable (expected 1)')

    extra = next(iterator, missing)
    if extra is not missing:
        msg = (
            f'Expected exactly one item in iterable, but got {head!r}, '
            f'{extra!r}, and perhaps more.'
        )
        raise too_long or ValueError(msg)

    return head
def raise_(exception, *args):
    """Raise ``exception(*args)``.

    Lets a ``raise`` be used where only an expression is allowed, such as
    inside a ``lambda``.
    """
    error = exception(*args)
    raise error
def strictly_n(iterable, n, too_short=None, too_long=None):
    """Yield the items of *iterable* while checking that there are exactly
    *n* of them. With fewer than *n* items, function *too_short* is called
    with the actual number of items. With more than *n* items, function
    *too_long* is called with the number ``n + 1``.

        >>> iterable = ['a', 'b', 'c', 'd']
        >>> n = 4
        >>> list(strictly_n(iterable, n))
        ['a', 'b', 'c', 'd']

    Note that the returned iterable must be consumed in order for the check
    to be made.

    By default, *too_short* and *too_long* are functions that raise
    ``ValueError``.

        >>> list(strictly_n('ab', 3))  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Too few items in iterable (got 2)

        >>> list(strictly_n('abc', 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Too many items in iterable (got at least 3)

    Other functions may be supplied instead.
    *too_short* will be called with the number of items in *iterable*.
    *too_long* will be called with `n + 1`.

        >>> def too_short(item_count):
        ...     raise RuntimeError
        >>> it = strictly_n('abcd', 6, too_short=too_short)
        >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        RuntimeError

        >>> def too_long(item_count):
        ...     print('The boss is going to hear about this')
        >>> it = strictly_n('abcdef', 4, too_long=too_long)
        >>> list(it)
        The boss is going to hear about this
        ['a', 'b', 'c', 'd']

    """

    def _raise_too_short(item_count):
        return raise_(
            ValueError,
            f'Too few items in iterable (got {item_count})',
        )

    def _raise_too_long(item_count):
        return raise_(
            ValueError,
            f'Too many items in iterable (got at least {item_count})',
        )

    if too_short is None:
        too_short = _raise_too_short
    if too_long is None:
        too_long = _raise_too_long

    iterator = iter(iterable)

    # Hand out at most n items, counting how many were actually available.
    delivered = 0
    for delivered, item in enumerate(islice(iterator, n), 1):
        yield item

    if delivered < n:
        too_short(delivered)
        return

    # One successful extra read is enough to know there are too many.
    nothing = object()
    if next(iterator, nothing) is not nothing:
        too_long(n + 1)
def distinct_permutations(iterable, r=None):
    """Yield the distinct permutations of the elements in *iterable*, one
    after another.

        >>> sorted(distinct_permutations([1, 0, 1]))
        [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    The result matches ``set(permutations(iterable))``, but duplicates are
    never generated in the first place. For larger input sequences this is
    much more efficient.

    Duplicate permutations arise when the input contains repeated elements.
    The number of items returned is `n! / (x_1! * x_2! * ... * x_n!)`, where
    `n` is the total number of items input, and each `x_i` is the count of a
    distinct item in the input sequence. The function :func:`multinomial`
    computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

        >>> sorted(distinct_permutations([1, 0, 1], r=2))
        [(0, 1), (1, 0), (1, 1)]
        >>> sorted(distinct_permutations(range(3), r=2))
        [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

        >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
        [
            (1, True, '3'),
            (1, '3', True),
            ('3', 1, True)
        ]
        >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
        [
            (1, 2, '3'),
            (1, '3', 2),
            (2, 1, '3'),
            (2, '3', 1),
            ('3', 1, 2),
            ('3', 2, 1)
        ]
    """

    # Algorithm: https://w.wiki/Qai
    def _full(A):
        while True:
            # Emit the current arrangement
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]; when there
            # is none, the arrangement just emitted was the last one.
            i = next(
                (k for k in range(size - 2, -1, -1) if A[k] < A[k + 1]),
                None,
            )
            if i is None:
                return

            # Find the largest index j greater than i such that A[i] < A[j].
            # One always exists because A[i] < A[i + 1].
            j = next(k for k in range(size - 1, i, -1) if A[i] < A[k])

            # Swap A[i] with A[j], then reverse everything after position i
            # to obtain the next permutation.
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above
    def _partial(A, r):
        # Split A into the first r items and the remaining items
        head, tail = A[:r], A[r:]
        head_from_right = range(r - 1, -1, -1)
        tail_from_left = range(len(tail))

        while True:
            # Emit the current arrangement
            yield tuple(head)

            # Scanning the head from the right, stop at the first index whose
            # value is smaller than the largest value seen so far (starting
            # from the maximum of the tail) - call it i.
            pivot = tail[-1]
            for i in head_from_right:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Scanning the tail from the left, swap head[i] with the first
            # tail value that is greater than it.
            for j in tail_from_left:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # When the tail has none, scan the head from the right for the
            # first value greater than head[i] and swap with that instead.
            else:
                for j in head_from_right:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)
    size = len(items)
    if r is None:
        r = size

    try:
        items.sort()
    except TypeError:
        sortable = False

        # Unsortable input: permute sortable stand-ins instead. Every item
        # is represented by the index of the first item equal to it.
        groups = defaultdict(list)
        indices = []
        for item in items:
            position = items.index(item)
            groups[position].append(item)
            indices.append(position)
        indices.sort()

        equivalent_items = {k: cycle(v) for k, v in groups.items()}

        def permuted_items(permuted_indices):
            # Map stand-in indices back to items, rotating through the
            # members of each group of equal items.
            return tuple(
                next(equivalent_items[index]) for index in permuted_indices
            )

    else:
        sortable = True

    if not (0 < r <= size):
        # r == 0 has exactly one (empty) permutation; anything else out of
        # range has none.
        return iter(() if r else ((),))

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if sortable:
        return algorithm(items)

    return (
        permuted_items(permuted_indices)
        for permuted_indices in algorithm(indices)
    )
def derangements(iterable, r=None):
    """Yield the derangements of the elements in *iterable*, one after
    another.

    A derangement is a permutation in which no element stays at its original
    index. In other words, it is a permutation without fixed points.

    Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
    The code below outputs all of the different ways to assign gift
    recipients such that nobody is assigned to himself or herself:

        >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
        ...     print(', '.join(d))
        Bob, Alice, Dave, Carol
        Bob, Carol, Dave, Alice
        Bob, Dave, Alice, Carol
        Carol, Alice, Dave, Bob
        Carol, Dave, Alice, Bob
        Carol, Dave, Bob, Alice
        Dave, Alice, Bob, Carol
        Dave, Carol, Alice, Bob
        Dave, Carol, Bob, Alice

    If *r* is given, only the *r*-length derangements are yielded.

        >>> sorted(derangements(range(3), 2))
        [(1, 0), (1, 2), (2, 0)]
        >>> sorted(derangements([0, 2, 3], 2))
        [(2, 0), (2, 3), (3, 0)]

    Elements are treated as unique based on their position, not on their
    value.

    Consider the Secret Santa example with two *different* people who have
    the *same* name. Then there are two valid gift assignments even though
    it might appear that a person is assigned to themselves:

        >>> names = ['Alice', 'Bob', 'Bob']
        >>> list(derangements(names))
        [('Bob', 'Bob', 'Alice'), ('Bob', 'Alice', 'Bob')]

    To avoid confusion, make the inputs distinct:

        >>> deduped = [f'{name}{index}' for index, name in enumerate(names)]
        >>> list(derangements(deduped))
        [('Bob1', 'Bob2', 'Alice0'), ('Bob2', 'Alice0', 'Bob1')]

    The number of derangements of a set of size *n* is known as the
    "subfactorial of n". For n > 0, the subfactorial is:
    ``round(math.factorial(n) / math.e)``. The more-itertools function
    :func:`subfactorial` computes this directly.

    References:

    * Article: https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics
    * Sizes: https://oeis.org/A000166
    """
    pool = tuple(iterable)
    positions = tuple(range(len(pool)))

    # Permute the positions in lockstep with the items. An arrangement is
    # kept only when none of its positions sits in its own slot. Identity
    # comparison is sound here because every permutation reuses the very
    # int objects stored in ``positions``.
    has_no_fixed_point = (
        all(map(is_not, positions, arrangement))
        for arrangement in permutations(positions, r=r)
    )
    return compress(permutations(pool, r=r), has_no_fixed_point)
def intersperse(e, iterable, n=1):
    """Insert filler element *e* between the items of *iterable*, with *n*
    items separating consecutive fillers.

        >>> list(intersperse('!', [1, 2, 3, 4, 5]))
        [1, '!', 2, '!', 3, '!', 4, '!', 5]

        >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
        [1, 2, None, 3, 4, None, 5]

    """
    if n == 0:
        raise ValueError('n must be > 0')

    if n == 1:
        # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2...
        # islice(..., 1, None) -> x_0, e, x_1, e, x_2...
        mixed = interleave(repeat(e), iterable)
        return islice(mixed, 1, None)

    # interleave(fillers, groups) -> [e], [x_0, x_1], [e], [x_2, x_3]...
    # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]...
    # flatten(...) -> x_0, x_1, e, x_2, x_3...
    fillers = repeat([e])
    groups = chunked(iterable, n)
    mixed = interleave(fillers, groups)
    return flatten(islice(mixed, 1, None))
def unique_to_each(*iterables):
    """For each input iterable, return the elements that appear in none of
    the other input iterables.

    For example, suppose you have a set of packages, each with a set of
    dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    If you remove one package, which dependencies can also be removed?

    If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
    associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed
    for ``pkg_2``, and ``D`` is only needed for ``pkg_3``::

        >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
        [['A'], ['C'], ['D']]

    Duplicates within one input iterable that are absent from the others
    are duplicated in the output as well. Input order is preserved::

        >>> unique_to_each("mississippi", "missouri")
        [['p', 'p'], ['o', 'u', 'r']]

    It is assumed that the elements of each iterable are hashable.

    """
    groups = list(map(list, iterables))

    # Count in how many inputs each element occurs; duplicates inside a
    # single input count once.
    membership = Counter()
    for group in groups:
        membership.update(set(group))

    singles = {element for element, hits in membership.items() if hits == 1}
    return [
        [element for element in group if element in singles]
        for group in groups
    ]
def windowed(seq, n, fillvalue=None, step=1):
    """Slide a window of width *n* across the given iterable and yield each
    position as a tuple.

        >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
        >>> list(all_windows)
        [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    If the window is wider than the iterable, the missing values are taken
    from *fillvalue*:

        >>> list(windowed([1, 2, 3], 4))
        [(1, 2, 3, None)]

    The window moves forward *step* items at a time:

        >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
        [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler
    items to the left:

        >>> iterable = [1, 2, 3, 4]
        >>> n = 3
        >>> padding = [None] * (n - 1)
        >>> list(windowed(chain(padding, iterable), 3))
        [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    if n <= 0:
        raise ValueError('n must be > 0')
    if step < 1:
        raise ValueError('step must be >= 1')

    source = iter(seq)

    # Build the first window; maxlen makes later appends drop the oldest
    # item automatically.
    window = deque(islice(source, n), maxlen=n)

    # Handle a source too short to fill even one window
    if not window:
        return
    if len(window) < n:
        shortfall = n - len(window)
        yield tuple(window) + ((fillvalue,) * shortfall)
        return
    yield tuple(window)

    # Pad the source with just enough fill values to complete the final
    # window.
    pad_count = n - 1 if step >= n else step - 1
    padding = (fillvalue,) * pad_count
    advance = map(window.append, chain(source, padding))

    # Each item pulled from ``advance`` shifts the window by one; emit a
    # snapshot after every ``step`` shifts.
    for _ in islice(advance, step - 1, None, step):
        yield tuple(window)
def substrings(iterable):
    """Yield every substring of *iterable*.

        >>> [''.join(s) for s in substrings('more')]
        ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Note that non-string iterables can also be subdivided.

        >>> list(substrings([0, 1, 2]))
        [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    Like subslices() but returns tuples instead of lists
    and returns the shortest substrings first.

    """
    items = tuple(iterable)
    total = len(items)
    # Widths grow from 1 up to the full length; each width slides left to
    # right across the items.
    for width in range(1, total + 1):
        for start in range(total - width + 1):
            yield items[start : start + width]
def substrings_indexes(seq, reverse=False):
    """Yield every substring of *seq* together with its position.

    Each item is a tuple of the form ``(substr, i, j)``, where
    ``substr == seq[i:j]``.

    This function only works for iterables that support slicing, such as
    ``str`` objects.

        >>> for item in substrings_indexes('more'):
        ...     print(item)
        ('m', 0, 1)
        ('o', 1, 2)
        ('r', 2, 3)
        ('e', 3, 4)
        ('mo', 0, 2)
        ('or', 1, 3)
        ('re', 2, 4)
        ('mor', 0, 3)
        ('ore', 1, 4)
        ('more', 0, 4)

    Set *reverse* to ``True`` to yield the same items in the opposite order.

    """
    widths = range(1, len(seq) + 1)
    ordered_widths = reversed(widths) if reverse else widths
    return (
        (seq[start : start + width], start, start + width)
        for width in ordered_widths
        for start in range(len(seq) - width + 1)
    )
class bucket:
    """Wrap *iterable* in an object that sorts its items into child
    iterables according to a *key* function.

        >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
        >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
        >>> sorted(list(s))  # Get the keys
        ['a', 'b', 'c']
        >>> a_iterable = s['a']
        >>> next(a_iterable)
        'a1'
        >>> next(a_iterable)
        'a2'
        >>> list(s['b'])
        ['b1', 'b2', 'b3']

    The original iterable is advanced as needed, and its items are cached
    until the child iterables use them. This may require significant
    storage.

    By default, selecting a bucket to which no items belong exhausts the
    iterable and caches all values.
    If you specify a *validator* function, selected buckets will instead be
    checked against it.

        >>> from itertools import count
        >>> it = count(1, 2)  # Infinite sequence of odd numbers
        >>> key = lambda x: x % 10  # Bucket by last digit
        >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
        >>> s = bucket(it, key=key, validator=validator)
        >>> 2 in s
        False
        >>> list(s[2])
        []

    .. seealso:: :func:`map_reduce`, :func:`groupby_transform`

        If storage is not a concern, :func:`map_reduce` returns a Python
        dictionary, which is generally easier to work with. If the elements
        with the same key are already adjacent, :func:`groupby_transform`
        or :func:`itertools.groupby` can be used without any caching
        overhead.

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        # Maps each key value to the items seen for it but not yet consumed.
        self._cache = defaultdict(deque)
        # Without a validator every key value is acceptable.
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        if not self._validator(value):
            return False

        # Probe for one matching item, then put it back at the front so
        # the bucket's contents are unchanged.
        absent = object()
        probe = next(self[value], absent)
        if probe is absent:
            return False

        self._cache[value].appendleft(probe)
        return True

    def _get_values(self, value):
        """
        Helper generator producing the parent iterator's items whose key is
        *value*. Items with other keys are stored in the local cache as
        they are encountered.
        """
        while True:
            # Serve matching items that are already cached first, evicting
            # each one as it is emitted.
            if self._cache[value]:
                yield self._cache[value].popleft()
                continue

            # Otherwise advance the parent iterator until a match turns up,
            # caching the rest.
            for item in self._it:
                item_value = self._key(item)
                if item_value == value:
                    yield item
                    break
                if self._validator(item_value):
                    self._cache[item_value].append(item)
            else:
                # Parent exhausted without another match.
                return

    def __iter__(self):
        # Listing the keys requires draining the parent iterator.
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        if self._validator(value):
            return self._get_values(value)
        return iter(())
def spy(iterable, n=1):
    """Peek at the start of *iterable* without losing any items.

    Returns a 2-tuple ``(head, items)``: *head* is a list holding the first
    *n* elements and *items* is an iterator that still yields every element
    of *iterable*, including those in *head*.

    One item is returned in the list by default:

        >>> iterable = 'abcdefg'
        >>> head, iterable = spy(iterable)
        >>> head
        ['a']
        >>> list(iterable)
        ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    Unpacking can be used to get the items themselves instead of a list:

        >>> (first, second), iterable = spy('abcdefg', 2)
        >>> first
        'a'
        >>> second
        'b'

    Asking for more items than exist simply returns all of them:

        >>> head, iterable = spy([1, 2, 3, 4, 5], 10)
        >>> head
        [1, 2, 3, 4, 5]
        >>> list(iterable)
        [1, 2, 3, 4, 5]

    """
    # One copy is read ahead to build the head; the other is handed back
    # untouched.
    untouched, lookahead = tee(iterable)
    head = take(n, lookahead)
    return head, untouched
def interleave(*iterables):
    """Take one item from each of *iterables* in turn, stopping as soon as
    the shortest one runs out.

        >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
        [1, 4, 6, 2, 5, 7]

    See :func:`interleave_longest` for a version that keeps going after the
    shortest iterable is exhausted.

    """
    # Each tuple produced by zip() is one full round across the inputs.
    rounds = zip(*iterables)
    return chain.from_iterable(rounds)
def interleave_longest(*iterables):
    """Take one item from each of *iterables* in turn, skipping over any
    that have already been exhausted.

        >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
        [1, 4, 6, 2, 5, 7, 3, 8]

    The output is the same as that of :func:`roundrobin`, but this function
    may perform better for some inputs (in particular when the number of
    iterables is large).

    """
    # Exhausted inputs are padded with the sentinel, which is then dropped.
    padded_rounds = zip_longest(*iterables, fillvalue=_marker)
    for candidate in chain.from_iterable(padded_rounds):
        if candidate is _marker:
            continue
        yield candidate
def interleave_evenly(iterables, lengths=None):
    """
    Merge several iterables so that the items of each one are spread as
    evenly as possible through the output.

        >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
        >>> list(interleave_evenly(iterables))
        [1, 2, 'a', 3, 4, 'b', 5]

        >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
        >>> list(interleave_evenly(iterables))
        [1, 6, 4, 2, 7, 3, 8, 5]

    The length of every iterable must be known. For iterables that have no
    ``__len__()``, pass the lengths explicitly with *lengths*:

        >>> from itertools import combinations, repeat
        >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
        >>> lengths = [4 * (4 - 1) // 2, 3]
        >>> list(interleave_evenly(iterables, lengths=lengths))
        [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    ``ValueError`` is raised if the lengths cannot be determined, or if the
    number of *lengths* differs from the number of *iterables*.

    Based on Bresenham's algorithm.
    """
    if lengths is not None:
        if len(iterables) != len(lengths):
            raise ValueError('Mismatching number of iterables and lengths.')
    else:
        try:
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )

    dims = len(lengths)
    if dims == 0:
        return

    # Rank the inputs from longest to shortest. The longest one is the
    # primary axis (Bresenham: the longest distance along an axis).
    ranked = sorted(range(dims), key=lambda idx: lengths[idx], reverse=True)
    primary_idx, *secondary_idx = ranked

    primary_len = lengths[primary_idx]
    primary_it = iter(iterables[primary_idx])
    secondary_lens = [lengths[idx] for idx in secondary_idx]
    secondary_its = [iter(iterables[idx]) for idx in secondary_idx]

    errors = [primary_len // dims] * len(secondary_lens)

    remaining = sum(lengths)
    while remaining:
        yield next(primary_it)
        remaining -= 1

        for pos, step in enumerate(secondary_lens):
            # Accumulate the error of this secondary iterable; once it
            # drops below zero an item is due ("diagonal step").
            errors[pos] -= step
            if errors[pos] < 0:
                yield next(secondary_its[pos])
                remaining -= 1
                errors[pos] += primary_len
def interleave_randomly(*iterables):
    """Repeatedly pick one of the input *iterables* at random and yield the
    next item from it.

        >>> iterables = [1, 2, 3], 'abc', (True, False, None)
        >>> list(interleave_randomly(*iterables))  # doctest: +SKIP
        ['a', 'b', 1, 'c', True, False, None, 2, 3]

    The relative order of the items within each input iterable is preserved.
    Note that the possible output sequences with this property are not all
    equally likely to be generated.

    """
    pool = [iter(source) for source in iterables]
    while pool:
        pick = randrange(len(pool))
        try:
            yield next(pool[pick])
        except StopIteration:
            # Drop the exhausted iterator by moving the last one into its
            # slot, which avoids shifting the rest of the list.
            survivor = pool.pop()
            if pick < len(pool):
                pool[pick] = survivor
def collapse(iterable, base_type=None, levels=None):
    """Flatten an iterable that has several levels of nesting (e.g., a list
    of lists of tuples) down to its non-iterable items.

        >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
        >>> list(collapse(iterable))
        [1, 2, 3, 4, 5, 6]

    Binary and text strings are treated as single items and are never
    collapsed.

    Pass *base_type* to protect other types from being collapsed:

        >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
        >>> list(collapse(iterable, base_type=tuple))
        ['ab', ('cd', 'ef'), 'gh', 'ij']

    Pass *levels* to stop flattening below a given depth:

        >>> iterable = [('a', ['b']), ('c', ['d'])]
        >>> list(collapse(iterable))  # Fully flattened
        ['a', 'b', 'c', 'd']
        >>> list(collapse(iterable, levels=1))  # Only one level flattened
        ['a', ['b'], 'c', ['d']]

    """

    def is_leaf(node):
        # Strings, and instances of base_type if given, are never descended
        # into.
        if isinstance(node, (str, bytes)):
            return True
        return base_type is not None and isinstance(node, base_type)

    # Stack of (depth, iterator) pairs. The input itself is treated as the
    # single node of depth 0.
    pending = [(0, iter((iterable,)))]

    while pending:
        depth, nodes = pending[-1]

        # Past the requested depth: emit whatever is left, unflattened.
        if levels is not None and depth > levels:
            pending.pop()
            yield from nodes
            continue

        for node in nodes:
            if is_leaf(node):
                yield node
                continue
            try:
                children = iter(node)
            except TypeError:
                yield node
                continue
            # Descend into the child; the current group stays on the stack
            # and is resumed once the child is exhausted.
            pending.append((depth + 1, children))
            break
        else:
            # This group is exhausted.
            pending.pop()
def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Call *func* on each item of *iterable* (or on each group of
    *chunk_size* items) just before the item is yielded.

    *func* must accept a single argument. Whatever it returns is ignored.

    *before* and *after* are optional callables taking no arguments. They
    are run before iteration starts and after it ends, respectively.

    `side_effect` is handy for logging, updating progress bars, or anything
    else that is not functionally "pure."

    Emitting a status message:

        >>> from more_itertools import consume
        >>> func = lambda item: print('Received {}'.format(item))
        >>> consume(side_effect(func, range(2)))
        Received 0
        Received 1

    Operating on chunks of items:

        >>> pair_sums = []
        >>> func = lambda chunk: pair_sums.append(sum(chunk))
        >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
        [0, 1, 2, 3, 4, 5]
        >>> list(pair_sums)
        [1, 5, 9]

    Writing to a file-like object:

        >>> from io import StringIO
        >>> from more_itertools import consume
        >>> f = StringIO()
        >>> func = lambda x: print(x, file=f)
        >>> before = lambda: print(u'HEADER', file=f)
        >>> after = f.close
        >>> it = [u'a', u'b', u'c']
        >>> consume(side_effect(func, it, before=before, after=after))
        >>> f.closed
        True

    """
    try:
        if before is not None:
            before()

        if chunk_size is not None:
            # func sees the whole group, but items are yielded one by one.
            for group in chunked(iterable, chunk_size):
                func(group)
                yield from group
        else:
            for element in iterable:
                func(element)
                yield element
    finally:
        # Runs on exhaustion, on error, and when the generator is closed.
        if after is not None:
            after()
def sliced(seq, n, strict=False):
    """Yield consecutive slices of length *n* taken from the sequence *seq*.

        >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
        [(1, 2, 3), (4, 5, 6)]

    By default, the last slice yielded is shorter than *n* when the length
    of *seq* is not divisible by *n*:

        >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
        [(1, 2, 3), (4, 5, 6), (7, 8)]

    If *strict* is ``True`` and the length of *seq* is not divisible by
    *n*, ``ValueError`` is raised before the last slice is yielded.

    This function only works for iterables that support slicing.
    For non-sliceable iterables, see :func:`chunked`.

    """
    # Slice at 0, n, 2n, ... and stop at the first empty slice.
    starts = count(0, n)
    pieces = takewhile(len, map(lambda i: seq[i : i + n], starts))
    if not strict:
        return pieces

    def checked():
        for piece in pieces:
            if len(piece) != n:
                raise ValueError("seq is not divisible by n.")
            yield piece

    return checked()
def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Yield lists of items from *iterable*, breaking at every item for
    which the callable *pred* returns ``True``.

        >>> list(split_at('abcdcba', lambda x: x == 'b'))
        [['a'], ['c', 'd', 'c'], ['a']]

        >>> list(split_at(range(10), lambda n: n % 2 == 1))
        [[0], [2], [4], [6], [8], []]

    No more than *maxsplit* splits are made. If *maxsplit* is omitted or
    -1, the number of splits is unlimited:

        >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
        [[0], [2], [4, 5, 6, 7, 8, 9]]

    The delimiting items are dropped from the output by default.
    Set *keep_separator* to ``True`` to emit each one as its own list.

        >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
        [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    splits_left = maxsplit
    current = []
    source = iter(iterable)
    for element in source:
        if not pred(element):
            current.append(element)
            continue

        yield current
        if keep_separator:
            yield [element]
        if splits_left == 1:
            # Last permitted split: everything else forms one final list.
            yield list(source)
            return
        current = []
        splits_left -= 1

    yield current
def split_before(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, starting a new list just before
    each item for which the callable *pred* returns ``True``:

        >>> list(split_before('OneTwo', lambda s: s.isupper()))
        [['O', 'n', 'e'], ['T', 'w', 'o']]

        >>> list(split_before(range(10), lambda n: n % 3 == 0))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    No more than *maxsplit* splits are made. If *maxsplit* is omitted or
    -1, the number of splits is unlimited:

        >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    splits_left = maxsplit
    current = []
    source = iter(iterable)
    for element in source:
        # An empty group is never emitted, so a match on the very first
        # item does not cause a split.
        if pred(element) and current:
            yield current
            if splits_left == 1:
                # Last permitted split: the matching item leads the rest.
                yield [element] + list(source)
                return
            current = []
            splits_left -= 1
        current.append(element)

    if current:
        yield current
def split_after(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, ending each list with an item
    for which the callable *pred* returns ``True``:

        >>> list(split_after('one1two2', lambda s: s.isdigit()))
        [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

        >>> list(split_after(range(10), lambda n: n % 3 == 0))
        [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    No more than *maxsplit* splits are made. If *maxsplit* is omitted or
    -1, the number of splits is unlimited:

        >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
        [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    splits_left = maxsplit
    current = []
    source = iter(iterable)
    for element in source:
        current.append(element)
        if not pred(element):
            continue

        yield current
        if splits_left == 1:
            # Last permitted split: emit the remainder, if there is any.
            remainder = list(source)
            if remainder:
                yield remainder
            return
        current = []
        splits_left -= 1

    if current:
        yield current
def split_when(iterable, pred, maxsplit=-1):
    """Split *iterable* into pieces according to the output of *pred*.
    *pred* is called with each pair of successive items and should return
    ``True`` when the iterable is to be split between them.

    For example, to find runs of increasing numbers, split the iterable when
    element ``i`` is larger than element ``i + 1``:

        >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
        [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    No more than *maxsplit* splits are made. If *maxsplit* is omitted or
    -1, the number of splits is unlimited:

        >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
        ...                 lambda x, y: x > y, maxsplit=2))
        [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    source = iter(iterable)
    missing = object()
    previous = next(source, missing)
    if previous is missing:
        # Empty input produces no groups at all.
        return

    splits_left = maxsplit
    current = [previous]
    for element in source:
        if pred(previous, element):
            yield current
            if splits_left == 1:
                # Last permitted split: the rest forms one final group.
                yield [element] + list(source)
                return
            current = []
            splits_left -= 1

        current.append(element)
        previous = element

    yield current
def split_into(iterable, sizes):
    """For each integer 'n' in *sizes*, yield a list holding the next 'n'
    items of *iterable*.

        >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
        [[1], [2, 3], [4, 5, 6]]

    If the sum of *sizes* is smaller than the length of *iterable*, the
    leftover items of *iterable* are not returned.

        >>> list(split_into([1,2,3,4,5,6], [2,3]))
        [[1, 2], [3, 4, 5]]

    If the sum of *sizes* is larger than the length of *iterable*, the list
    that overruns *iterable* is shorter than requested and all later lists
    are empty:

        >>> list(split_into([1,2,3,4], [1,2,3,4]))
        [[1], [2, 3], [4], []]

    A ``None`` in *sizes* produces a list of everything that is left in
    *iterable*, much like a ``None`` stop value for
    :func:`itertools.islice`. Any sizes after it are ignored:

        >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
        [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    :func:`split_into` is useful for grouping a series of items whose
    groups are not of uniform size, for example a table row in which
    several columns together represent one feature (e.g. a point given
    by x,y,z) while other columns stand alone.
    """
    # A single shared iterator lets each islice pick up where the previous
    # one stopped, even when the input is a generator.
    source = iter(iterable)

    for size in sizes:
        if size is not None:
            yield list(islice(source, size))
            continue
        yield list(source)
        return
def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the items of *iterable* followed by as many copies of
    *fillvalue* as are needed for at least *n* items to be emitted.

        >>> list(padded([1, 2, 3], '?', 5))
        [1, 2, 3, '?', '?']

    If *next_multiple* is ``True``, *fillvalue* is emitted until the number
    of items emitted is a multiple of *n*:

        >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
        [1, 2, 3, 4, None, None]

    If *n* is ``None``, *fillvalue* is emitted indefinitely. ``ValueError``
    is raised if *n* is less than 1.

    To get an iterable of exactly *n* items, truncate with :func:`islice`.

        >>> list(islice(padded([1, 2, 3], '?'), 5))
        [1, 2, 3, '?', '?']
        >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
        [1, 2, 3, 4, 5]

    """
    source = iter(iterable)
    # Shares state with ``source``: real items first, then endless filler.
    filled = chain(source, repeat(fillvalue))

    if n is None:
        return filled
    if n < 1:
        raise ValueError('n must be at least 1')

    if not next_multiple:
        # Guarantee n items up front, then pass through whatever real
        # items remain.
        return chain(islice(filled, n), source)

    def blocks():
        # A block is started only while real items remain, and is then
        # completed to n items, with filler if necessary.
        for leader in source:
            yield chain((leader,), islice(filled, n - 1))

    return chain.from_iterable(blocks())
def repeat_each(iterable, n=2):
    """Yield every element of *iterable* *n* times in a row.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    runs = (repeat(element, n) for element in iterable)
    return chain.from_iterable(runs)
def repeat_last(iterable, default=None):
    """Yield the items of *iterable* and, once it is exhausted, go on
    yielding its last item forever.

        >>> list(islice(repeat_last(range(3)), 5))
        [0, 1, 2, 2, 2]

    If the iterable is empty, yield *default* forever::

        >>> list(islice(repeat_last(range(0), 42), 5))
        [42, 42, 42, 42, 42]

    """
    # The sentinel survives the loop only if the iterable produced nothing.
    last = _marker
    for last in iterable:
        yield last

    if last is _marker:
        last = default
    while True:
        yield last
def distribute(n, iterable):
    """Deal the items of *iterable* out among *n* smaller iterables.

        >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
        >>> list(group_1)
        [1, 3, 5]
        >>> list(group_2)
        [2, 4, 6]

    If the length of *iterable* is not evenly divisible by *n*, the
    returned iterables will not all have the same length:

        >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
        >>> [list(c) for c in children]
        [[1, 4, 7], [2, 5], [3, 6]]

    If the length of *iterable* is smaller than *n*, the last returned
    iterables will be empty:

        >>> children = distribute(5, [1, 2, 3])
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]

    ``ValueError`` is raised if *n* is less than 1.

    This function uses :func:`itertools.tee` and may require significant
    storage.

    If the items in the smaller iterables need to keep the order of the
    original iterable, see :func:`divide`.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    groups = []
    for offset, copy in enumerate(tee(iterable, n)):
        # Each copy starts at its own offset and takes every n-th item.
        groups.append(islice(copy, offset, None, n))
    return groups
def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples whose elements are taken from *iterable* at shifted
    positions. The `i`-th item of each tuple is offset by the `i`-th item
    in *offsets*.

        >>> list(stagger([0, 1, 2, 3]))
        [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
        >>> list(stagger(range(8), offsets=(0, 2, 4)))
        [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default, the sequence ends when the final element of a tuple is the
    last item in the iterable. To continue until the first element of a
    tuple is the last item in the iterable, set *longest* to ``True``::

        >>> list(stagger([0, 1, 2, 3], longest=True))
        [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    Positions that fall outside the sequence are filled with ``None`` by
    default. Pass *fillvalue* to use another value.

    """
    # One independent copy of the input per requested offset.
    copies = tee(iterable, len(offsets))
    return zip_offset(
        *copies,
        offsets=offsets,
        longest=longest,
        fillvalue=fillvalue,
    )
def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the input *iterables* together after shifting the `i`-th
    iterable by the `i`-th item in *offsets*.

        >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
        [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    This can serve as a lightweight alternative to SciPy or pandas for
    analyzing data sets in which some series lead or lag others.

    By default, the sequence ends when the shortest iterable is exhausted.
    To continue until the longest iterable is exhausted, set *longest* to
    ``True``.

        >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
        [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    Positions that fall outside a sequence are filled with ``None`` by
    default. Pass *fillvalue* to use another value.

    ``ValueError`` is raised if the number of *iterables* differs from the
    number of *offsets*.

    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    def shifted(source, offset):
        if offset < 0:
            # Negative offset: delay the series by prepending filler.
            return chain(repeat(fillvalue, -offset), source)
        if offset > 0:
            # Positive offset: advance the series by skipping items.
            return islice(source, offset, None)
        return source

    staggered = [shifted(it, n) for it, n in zip(iterables, offsets)]

    if longest:
        return zip_longest(*staggered, fillvalue=fillvalue)
    return zip(*staggered)
1935def sort_together(
1936 iterables, key_list=(0,), key=None, reverse=False, strict=False
1937):
1938 """Return the input iterables sorted together, with *key_list* as the
1939 priority for sorting. All iterables are trimmed to the length of the
1940 shortest one.
1942 This can be used like the sorting function in a spreadsheet. If each
1943 iterable represents a column of data, the key list determines which
1944 columns are used for sorting.
1946 By default, all iterables are sorted using the ``0``-th iterable::
1948 >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
1949 >>> sort_together(iterables)
1950 [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
1952 Set a different key list to sort according to another iterable.
1953 Specifying multiple keys dictates how ties are broken::
1955 >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
1956 >>> sort_together(iterables, key_list=(1, 2))
1957 [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
1959 To sort by a function of the elements of the iterable, pass a *key*
1960 function. Its arguments are the elements of the iterables corresponding to
1961 the key list::
1963 >>> names = ('a', 'b', 'c')
1964 >>> lengths = (1, 2, 3)
1965 >>> widths = (5, 2, 1)
1966 >>> def area(length, width):
1967 ... return length * width
1968 >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
1969 [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
1971 Set *reverse* to ``True`` to sort in descending order.
1973 >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
1974 [(3, 2, 1), ('a', 'b', 'c')]
1976 If the *strict* keyword argument is ``True``, then
1977 ``ValueError`` will be raised if any of the iterables have
1978 different lengths.
1980 """
1981 if key is None:
1982 # if there is no key function, the key argument to sorted is an
1983 # itemgetter
1984 key_argument = itemgetter(*key_list)
1985 else:
1986 # if there is a key function, call it with the items at the offsets
1987 # specified by the key function as arguments
1988 key_list = list(key_list)
1989 if len(key_list) == 1:
1990 # if key_list contains a single item, pass the item at that offset
1991 # as the only argument to the key function
1992 key_offset = key_list[0]
1993 key_argument = lambda zipped_items: key(zipped_items[key_offset])
1994 else:
1995 # if key_list contains multiple items, use itemgetter to return a
1996 # tuple of items, which we pass as *args to the key function
1997 get_key_items = itemgetter(*key_list)
1998 key_argument = lambda zipped_items: key(
1999 *get_key_items(zipped_items)
2000 )
2002 transposed = zip(*iterables, strict=strict)
2003 reordered = sorted(transposed, key=key_argument, reverse=reverse)
2004 untransposed = zip(*reordered, strict=strict)
2005 return list(untransposed)
def unzip(iterable):
    """The inverse of :func:`zip`: split the elements of the zipped
    *iterable* back into separate iterables.

    The ``i``-th returned iterable yields the ``i``-th element of each
    element of the zipped iterable. The number of returned iterables is
    taken from the length of the first element.

        >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
        >>> letters, numbers = unzip(iterable)
        >>> list(letters)
        ['a', 'b', 'c', 'd']
        >>> list(numbers)
        [1, 2, 3, 4]

    This is similar to using ``zip(*iterable)``, but it avoids reading
    *iterable* into memory. Note, however, that this function uses
    :func:`itertools.tee` and thus may require significant storage.

    """
    peeked, stream = spy(iterable)
    if not peeked:
        # empty iterable, e.g. zip([], [], [])
        return ()

    # spy() returns the head as a one-item list; that item sets the width.
    width = len(peeked[0])

    # An "improperly zipped" input such as iter([(1, 2, 3), (4, 5), (6,)])
    # makes a later column fail with IndexError on a short element (the
    # second column fails at (6,)[1]). Suppressing the IndexError simply
    # ends that column at the first length mismatch.
    columns = []
    for position, copy in enumerate(tee(stream, width)):
        column = map(itemgetter(position), copy)
        columns.append(iter_suppress(column, IndexError))
    return tuple(columns)
def divide(n, iterable):
    """Divide the elements from *iterable* into *n* parts, maintaining
    order.

        >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
        >>> list(group_1)
        [1, 2, 3]
        >>> list(group_2)
        [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

        >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
        >>> [list(c) for c in children]
        [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than n, then the last returned
    iterables will be empty:

        >>> children = divide(5, [1, 2, 3])
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]

    Inputs that cannot be sliced, such as iterators or dictionaries, are
    first copied into a tuple:

        >>> [list(c) for c in divide(2, {'a': 1, 'b': 2, 'c': 3})]
        [['a', 'b'], ['c']]

    ``ValueError`` is raised if *n* is less than 1.

    This function will exhaust the iterable before returning.
    If order is not important, see :func:`distribute`, which does not first
    pull the iterable into memory.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Probe for slicing support. Most non-sliceable objects raise
    # TypeError, but a dict raises KeyError on Python 3.12+, where slice
    # objects are hashable and the probe becomes a failed key lookup.
    try:
        iterable[:0]
    except (TypeError, KeyError):
        seq = tuple(iterable)
    else:
        seq = iterable

    q, r = divmod(len(seq), n)

    ret = []
    stop = 0
    for i in range(1, n + 1):
        start = stop
        # The first r parts each take one extra item so that the remainder
        # is spread over the leading parts.
        stop += q + 1 if i <= r else q
        ret.append(iter(seq[start:stop]))

    return ret
def always_iterable(obj, base_type=(str, bytes)):
    """Return an iterator for *obj*, whatever kind of object it is.

    If *obj* is iterable, the iterator runs over its items::

        >>> obj = (1, 2, 3)
        >>> list(always_iterable(obj))
        [1, 2, 3]

    If *obj* is not iterable, the iterator yields *obj* itself, once::

        >>> obj = 1
        >>> list(always_iterable(obj))
        [1]

    If *obj* is ``None``, the iterator is empty:

        >>> obj = None
        >>> list(always_iterable(None))
        []

    By default, binary and text strings are not considered iterable::

        >>> obj = 'foo'
        >>> list(always_iterable(obj))
        ['foo']

    If *base_type* is set, objects for which ``isinstance(obj, base_type)``
    returns ``True`` won't be considered iterable.

        >>> obj = {'a': 1}
        >>> list(always_iterable(obj))  # Iterate over the dict's keys
        ['a']
        >>> list(always_iterable(obj, base_type=dict))  # Treat dicts as a unit
        [{'a': 1}]

    Set *base_type* to ``None`` to avoid any special handling and treat
    objects Python considers iterable as iterable:

        >>> obj = 'foo'
        >>> list(always_iterable(obj, base_type=None))
        ['f', 'o', 'o']
    """
    if obj is None:
        return iter(())

    treat_as_unit = base_type is not None and isinstance(obj, base_type)
    if not treat_as_unit:
        try:
            return iter(obj)
        except TypeError:
            # Not iterable after all: fall through and wrap it.
            pass

    return iter((obj,))
def adjacent(predicate, iterable, distance=1):
    """Return an iterable over `(bool, item)` tuples, where each `item`
    comes from *iterable* and the `bool` tells whether that item satisfies
    the *predicate* or lies within *distance* places of an item that does.

    For example, to find whether items are adjacent to a ``3``::

        >>> list(adjacent(lambda x: x == 3, range(6)))
        [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    Set *distance* to change what counts as adjacent. For example, to find
    whether items are two places away from a ``3``:

        >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
        [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    This is useful for adding context to the results of a search. A code
    comparison tool, for example, might want to show the lines that have
    changed together with the lines surrounding them.

    The predicate function is called only once for each item in the
    iterable. ``ValueError`` is raised if *distance* is negative.

    See also :func:`groupby_transform`, which can be used with this function
    to group ranges of items with the same `bool` value.

    """
    # distance=0 is allowed mainly so tests can check that it reproduces
    # the results of map()
    if distance < 0:
        raise ValueError('distance must be at least 0')

    for_flags, for_items = tee(iterable)
    # Pad both ends with False so that every item sits at the center of a
    # full-width window.
    margin = [False] * distance
    flags = chain(margin, map(predicate, for_flags), margin)
    window_width = 2 * distance + 1
    near_match = (any(window) for window in windowed(flags, window_width))
    return zip(near_match, for_items)
def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
    """An extension of :func:`itertools.groupby` that can also transform
    the grouped data.

    * *keyfunc* computes a key value for each item in *iterable*
    * *valuefunc* transforms the individual items after grouping
    * *reducefunc* transforms each group of items as a whole

    >>> iterable = 'aAAbBBcCC'
    >>> keyfunc = lambda k: k.upper()
    >>> valuefunc = lambda v: v.lower()
    >>> reducefunc = lambda g: ''.join(g)
    >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
    [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]

    Each optional argument defaults to an identity function if not specified.

    :func:`groupby_transform` is useful for grouping the elements of one
    iterable using a separate iterable as the key. To do this, :func:`zip`
    the iterables and pass a *keyfunc* that extracts the first element and
    a *valuefunc* that extracts the second element::

        >>> from operator import itemgetter
        >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
        >>> values = 'abcdefghi'
        >>> iterable = zip(keys, values)
        >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
        >>> [(k, ''.join(g)) for k, g in grouper]
        [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]

    Note that the order of items in the iterable is significant.
    Only adjacent items are grouped together, so if you don't want any
    duplicate groups, you should sort the iterable by the key function
    or consider :func:`bucket` or :func:`map_reduce`. :func:`map_reduce`
    consumes the iterable immediately and returns a dictionary, while
    :func:`bucket` does not.

    .. seealso:: :func:`bucket`, :func:`map_reduce`

    """

    def with_values(pairs):
        # Lazily apply valuefunc to the members of every group.
        for group_key, members in pairs:
            yield group_key, map(valuefunc, members)

    def with_reduction(pairs):
        # Collapse every group to a single value.
        for group_key, members in pairs:
            yield group_key, reducefunc(members)

    grouped = groupby(iterable, keyfunc)
    if valuefunc:
        grouped = with_values(grouped)
    if reducefunc:
        grouped = with_reduction(grouped)

    return grouped
class numeric_range(Sequence):
    """A ``range()``-like sequence whose *start*, *stop* and *step* may be
    any orderable numeric type.

    Given only *stop*, *start* is ``0`` and *step* is ``1``; items have the
    type of *stop*:

        >>> list(numeric_range(3.5))
        [0.0, 1.0, 2.0, 3.0]

    Given *start* and *stop*, *step* is ``1``; items have the type of
    *start*:

        >>> from decimal import Decimal
        >>> start = Decimal('2.1')
        >>> stop = Decimal('5.1')
        >>> list(numeric_range(start, stop))
        [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    Given all three, items have the type of ``start + step``:

        >>> from fractions import Fraction
        >>> start = Fraction(1, 2)  # Start at 1/2
        >>> stop = Fraction(5, 2)  # End at 5/2
        >>> step = Fraction(1, 2)  # Count by 1/2
        >>> list(numeric_range(start, stop, step))
        [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    A zero *step* raises ``ValueError``. Negative steps count downwards:

        >>> list(numeric_range(3, -1, -1.0))
        [3.0, 2.0, 1.0, 0.0]

    Floating-point limitations apply; the representation of the yielded
    numbers may be surprising.

    ``datetime.datetime`` values work for *start* and *stop* when *step* is a
    ``datetime.timedelta``:

        >>> import datetime
        >>> start = datetime.datetime(2019, 1, 1)
        >>> stop = datetime.datetime(2019, 1, 3)
        >>> step = datetime.timedelta(days=1)
        >>> items = iter(numeric_range(start, stop, step))
        >>> next(items)
        datetime.datetime(2019, 1, 1, 0, 0)
        >>> next(items)
        datetime.datetime(2019, 1, 2, 0, 0)

    """

    # All empty ranges compare equal, so they share one hash value.
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        argc = len(args)
        if argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        if argc > 3:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        if argc == 3:
            self._start, self._stop, self._step = args
        else:
            if argc == 1:
                (self._stop,) = args
                self._start = type(self._stop)(0)
            else:
                self._start, self._stop = args
            # The unit step takes the type of a difference of endpoints.
            self._step = type(self._stop - self._start)(1)

        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        self._growing = self._step > self._zero

    def __bool__(self):
        if not self._growing:
            return self._start > self._stop
        return self._start < self._stop

    def __contains__(self, elem):
        if self._growing:
            if self._start <= elem < self._stop:
                return (elem - self._start) % self._step == self._zero
        elif self._start >= elem > self._stop:
            return (self._start - elem) % (-self._step) == self._zero

        return False

    def __eq__(self, other):
        if not isinstance(other, numeric_range):
            return False

        self_is_empty = not bool(self)
        other_is_empty = not bool(other)
        if self_is_empty or other_is_empty:
            # Equal only when both are empty.
            return self_is_empty and other_is_empty

        return (
            self._start == other._start
            and self._step == other._step
            and self._get_by_index(-1) == other._get_by_index(-1)
        )

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._get_by_index(key)

        if isinstance(key, slice):
            first, past_last, stride = key.indices(self._len)
            return numeric_range(
                self._start + first * self._step,
                self._start + past_last * self._step,
                self._step * stride,
            )

        raise TypeError(
            'numeric range indices must be '
            f'integers or slices, not {type(key).__name__}'
        )

    def __hash__(self):
        if not self:
            return self._EMPTY_HASH
        return hash((self._start, self._get_by_index(-1), self._step))

    def __iter__(self):
        # Multiply rather than accumulate so each value is computed
        # directly from the start.
        values = (self._start + (n * self._step) for n in count())
        # partial(gt, stop)(v) is ``stop > v``; partial(lt, stop)(v) is
        # ``stop < v``.
        before_stop = partial(gt if self._growing else lt, self._stop)
        return takewhile(before_stop, values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        # Normalize to an ascending span with a positive stride.
        if self._growing:
            low, high, stride = self._start, self._stop, self._step
        else:
            low, high, stride = self._stop, self._start, -self._step

        span = high - low
        if span <= self._zero:
            return 0

        # span > 0 and stride > 0: ceiling of span / stride.
        whole, leftover = divmod(span, stride)
        return int(whole) + int(leftover != self._zero)

    def __reduce__(self):
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        head = f"numeric_range({self._start!r}, {self._stop!r}"
        if self._step == 1:
            return head + ")"
        return f"{head}, {self._step!r})"

    def __reversed__(self):
        try:
            last = self._get_by_index(-1)
        except IndexError:
            # Nothing to reverse: the range is empty.
            return iter([])

        backwards = numeric_range(
            last, self._start - self._step, -self._step
        )
        return iter(backwards)

    def count(self, value):
        return int(value in self)

    def index(self, value):
        if self._growing:
            if self._start <= value < self._stop:
                steps, leftover = divmod(value - self._start, self._step)
                if leftover == self._zero:
                    return int(steps)
        elif self._start >= value > self._stop:
            steps, leftover = divmod(self._start - value, -self._step)
            if leftover == self._zero:
                return int(steps)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        length = self._len
        if i < 0:
            i += length
        if i < 0 or i >= length:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step
def count_cycle(iterable, n=None):
    """Repeat the items of *iterable* up to *n* times, pairing each item
    with the number of cycles completed so far. Without *n* the cycling
    never stops.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    """
    if n is None:
        items = tuple(iterable)
        if items:
            # Each cycle number is repeated once per item of the cycle.
            cycle_numbers = repeat_each(count(), len(items))
            return zip(cycle_numbers, cycle(items))
        return iter(())

    return product(range(n), iterable)
def mark_ends(iterable):
    """Yield ``(is_first, is_last, item)`` 3-tuples for each item in
    *iterable*.

    >>> list(mark_ends('ABC'))
    [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    This is handy for treating the first and/or last item of a loop
    specially:

    >>> iterable = ['Header', 100, 200, 'Footer']
    >>> total = 0
    >>> for is_first, is_last, item in mark_ends(iterable):
    ...     if is_first:
    ...         continue  # Skip the header
    ...     if is_last:
    ...         continue  # Skip the footer
    ...     total += item
    >>> print(total)
    300
    """
    iterator = iter(iterable)
    try:
        current = next(iterator)
    except StopIteration:
        # Empty input: nothing to mark.
        return

    # Stay one item behind so that the final item is only emitted once we
    # know nothing follows it.
    is_first = True
    for upcoming in iterator:
        yield is_first, False, current
        current = upcoming
        is_first = False

    yield is_first, True, current
def locate(iterable, pred=bool, window_size=None):
    """Yield the index of every item in *iterable* for which *pred* returns
    ``True``.

    With the default *pred*, :func:`bool`, truthy items are selected:

        >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
        [1, 2, 4]

    Pass a custom *pred* to, e.g., find every position of a particular item:

        >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
        [1, 3]

    When *window_size* is given, *pred* is called with the values of each
    window, which makes it possible to search for sub-sequences. Note that
    *pred* may receive fewer than *window_size* arguments at the end of the
    iterable.

        >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
        >>> pred = lambda *args: args == (1, 2, 3)
        >>> list(locate(iterable, pred=pred, window_size=3))
        [1, 5, 9]

    Combine with :func:`seekable` to find indexes and then fetch the
    matching items:

        >>> from itertools import count
        >>> from more_itertools import seekable
        >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
        >>> it = seekable(source)
        >>> pred = lambda x: x > 100
        >>> indexes = locate(it, pred=pred)
        >>> i = next(indexes)
        >>> it.seek(i)
        >>> next(it)
        106

    """
    if window_size is None:
        verdicts = map(pred, iterable)
        return compress(count(), verdicts)

    if window_size < 1:
        raise ValueError('window size must be at least 1')

    # Short windows are padded with _marker; the padding is removed again
    # before the window is handed to pred.
    windows = windowed(iterable, window_size, fillvalue=_marker)
    verdicts = (
        pred(*(value for value in window if value is not _marker))
        for window in windows
    )
    return compress(count(), verdicts)
def longest_common_prefix(iterables):
    """Yield the items of the longest prefix shared by all of the given
    *iterables*.

    >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
    'ab'

    """
    # Walk the iterables in lockstep and stop at the first position where
    # they disagree.
    columns = zip(*iterables)
    agreeing = takewhile(all_equal, columns)
    return (column[0] for column in agreeing)
def lstrip(iterable, pred):
    """Yield the items of *iterable*, dropping the leading ones for which
    *pred* returns ``True``.

    For example, to remove a set of items from the start of an iterable:

        >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
        >>> pred = lambda x: x in {None, False, ''}
        >>> list(lstrip(iterable, pred))
        [1, 2, None, 3, False, None]

    This is the iterable counterpart of :func:`str.lstrip`; it is
    essentially a wrapper for :func:`itertools.dropwhile`.

    """
    remaining = dropwhile(pred, iterable)
    return remaining
def rstrip(iterable, pred):
    """Yield the items of *iterable*, dropping the trailing ones for which
    *pred* returns ``True``.

    For example, to remove a set of items from the end of an iterable:

        >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
        >>> pred = lambda x: x in {None, False, ''}
        >>> list(rstrip(iterable, pred))
        [None, False, None, 1, 2, None, 3]

    This is the iterable counterpart of :func:`str.rstrip`.

    """
    # Matching items are held back until a non-matching item proves they
    # are not part of the trailing run.
    held_back = []
    for item in iterable:
        if not pred(item):
            yield from held_back
            held_back.clear()
            yield item
        else:
            held_back.append(item)
def strip(iterable, pred):
    """Yield the items of *iterable*, dropping those at the beginning and
    at the end for which *pred* returns ``True``.

    For example, to remove a set of items from both ends of an iterable:

        >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
        >>> pred = lambda x: x in {None, False, ''}
        >>> list(strip(iterable, pred))
        [1, 2, None, 3]

    This is the iterable counterpart of :func:`str.strip`.

    """
    without_leading = lstrip(iterable, pred)
    return rstrip(without_leading, pred)
class islice_extended:
    """A version of :func:`itertools.islice` that also accepts negative
    values for *stop*, *start*, and *step*.

        >>> iterator = iter('abcdefgh')
        >>> list(islice_extended(iterator, -4, -1))
        ['e', 'f', 'g']

    Negative values require caching part of *iterable*, but care is taken
    to keep the amount of memory needed small.

    For example, a negative step works with an infinite iterator:

        >>> from itertools import count
        >>> list(islice_extended(count(), 110, 99, -2))
        [110, 108, 106, 104, 102, 100]

    Slice notation can be used directly as well:

        >>> iterator = map(str, count())
        >>> it = islice_extended(iterator)[10:20:2]
        >>> list(it)
        ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        source = iter(iterable)
        # Without slice arguments this is a plain pass-through wrapper.
        self._iterator = (
            _islice_helper(source, slice(*args)) if args else source
        )

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        if not isinstance(key, slice):
            raise TypeError(
                'islice_extended.__getitem__ argument must be a slice'
            )
        return islice_extended(_islice_helper(self._iterator, key))
def _islice_helper(it, s):
    """Yield the items of the iterator *it* selected by the slice object *s*.

    Negative ``start``, ``stop`` and ``step`` values are handled by
    buffering part of *it* in a ``deque`` or ``list``; which items are
    buffered depends on the sign combination, one branch per case below.

    A ``step`` of ``None`` is treated as ``1``. Raises ``ValueError`` when
    ``s.step`` is ``0``. Because this is a generator function, that check
    only runs once iteration begins.
    """
    start = s.start
    stop = s.stop
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    step = s.step or 1

    if step > 0:
        start = 0 if (start is None) else start

        if start < 0:
            # Consume all but the last -start items.
            # count(1) only produces truthy values, so compress() lets
            # every item through while the counter advances once per item.
            counter = count(1)
            wrapper = compress(it, counter)
            cache = deque(wrapper, maxlen=-start)
            # The next counter value minus one is used as the number of
            # items that were read from *it*.
            len_iter = next(counter) - 1

            # Adjust start to be positive
            i = max(len_iter + start, 0)

            # Adjust stop to be positive
            if stop is None:
                j = len_iter
            elif stop >= 0:
                j = min(stop, len_iter)
            else:
                j = max(len_iter + stop, 0)

            # Slice the cache
            n = j - i
            if n <= 0:
                return

            for index in range(n):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
        elif (stop is not None) and (stop < 0):
            # Advance to the start position
            next(islice(it, start, start), None)

            # When stop is negative, we have to carry -stop items while
            # iterating
            cache = deque(islice(it, -stop), maxlen=-stop)

            # Each new item read from *it* releases the oldest cached item;
            # the -stop items still cached when *it* ends are never yielded.
            for index, item in enumerate(it):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
                cache.append(item)
        else:
            # When both start and stop are positive we have the normal case
            yield from islice(it, start, stop, step)
    else:
        start = -1 if (start is None) else start

        if (stop is not None) and (stop < 0):
            # Consume all but the last items
            n = -stop - 1
            counter = count(1)
            wrapper = compress(it, counter)
            cache = deque(wrapper, maxlen=n)
            len_iter = next(counter) - 1

            # If start and stop are both negative they are comparable and
            # we can just slice. Otherwise we can adjust start to be negative
            # and then slice.
            if start < 0:
                i, j = start, stop
            else:
                i, j = min(start - len_iter, -1), None

            yield from list(cache)[i:j:step]
        else:
            # Advance to the stop position
            if stop is not None:
                m = stop + 1
                next(islice(it, m, m), None)

            # stop is positive, so if start is negative they are not comparable
            # and we need the rest of the items.
            if start < 0:
                i = start
                n = None
            # stop is None and start is positive, so we just need items up to
            # the start index.
            elif stop is None:
                i = None
                n = start + 1
            # Both stop and start are positive, so they are comparable.
            else:
                i = None
                n = start - stop
                if n <= 0:
                    return

            # n is None here only when start is negative, in which case
            # islice() reads everything that is left.
            cache = list(islice(it, n))

            yield from cache[i::step]
def always_reversible(iterable):
    """A version of :func:`reversed` that works for every iterable, not only
    for those implementing the ``Reversible`` or ``Sequence`` protocols.

        >>> print(*always_reversible(x for x in range(3)))
        2 1 0

    Reversible inputs are handed straight to :func:`reversed()`. For
    anything else the remaining items are first cached in a list and then
    yielded in reverse order, which may require significant storage.
    """
    try:
        result = reversed(iterable)
    except TypeError:
        # Not reversible as-is: materialize the items first.
        result = reversed(list(iterable))
    return result
def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items, built on
    :func:`itertools.groupby`. The *ordering* function returns the position
    of an item, which decides whether two items are adjacent.

    The default ordering is the identity function, which is suitable for
    finding runs of numbers:

        >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
        >>> for group in consecutive_groups(iterable):
        ...     print(list(group))
        [1]
        [10, 11, 12]
        [20]
        [30, 31, 32, 33]
        [40]

    To find runs of adjacent letters, use :func:`ord` to convert letters
    to ordinals.

        >>> iterable = 'abcdfgilmnop'
        >>> ordering = ord
        >>> for group in consecutive_groups(iterable, ordering):
        ...     print(list(group))
        ['a', 'b', 'c', 'd']
        ['f', 'g']
        ['i']
        ['l', 'm', 'n', 'o', 'p']

    Every group is an iterator sharing its source with *iterable*. Once the
    next group is requested the previous one is no longer available, so copy
    its elements (e.g., into a ``list``) if they are needed later.

        >>> iterable = [1, 2, 11, 12, 21, 22]
        >>> saved_groups = []
        >>> for group in consecutive_groups(iterable):
        ...     saved_groups.append(list(group))  # Copy group elements
        >>> saved_groups
        [[1, 2], [11, 12], [21, 22]]

    """

    def offset(pair):
        # Within a consecutive run, index minus position stays constant.
        index, item = pair[0], pair[1]
        position = item if ordering is None else ordering(item)
        return index - position

    for _, run in groupby(enumerate(iterable), key=offset):
        yield map(itemgetter(1), run)
def difference(iterable, func=sub, *, initial=None):
    """Undo :func:`itertools.accumulate`. By default the first difference
    of *iterable* is computed with :func:`operator.sub`:

        >>> from itertools import accumulate
        >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
        >>> list(difference(iterable))
        [0, 1, 2, 3, 4]

    Another two-argument *func* may be given instead of
    :func:`operator.sub`. It is applied as follows::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, to do progressive division:

        >>> iterable = [1, 2, 6, 24, 120]
        >>> func = lambda x, y: x // y
        >>> list(difference(iterable, func))
        [1, 2, 3, 4, 5]

    If the *initial* keyword is set, the first element is skipped when
    computing successive differences.

        >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
        >>> list(difference(it, initial=10))
        [1, 2, 3]

    """
    trailing, leading = tee(iterable)
    try:
        # Move *leading* one item ahead of *trailing*.
        head = next(leading)
    except StopIteration:
        return iter([])

    deltas = map(func, leading, trailing)
    if initial is None:
        return chain([head], deltas)

    # With *initial* set, the head element is not part of the output.
    return deltas
class SequenceView(Sequence):
    """Return a read-only view of the sequence object *target*.

    :class:`SequenceView` objects are analogous to Python's built-in
    "dictionary view" types. They provide a dynamic view of a sequence's items,
    meaning that when the sequence updates, so does the view.

        >>> seq = ['0', '1', '2']
        >>> view = SequenceView(seq)
        >>> view
        SequenceView(['0', '1', '2'])
        >>> seq.append('3')
        >>> view
        SequenceView(['0', '1', '2', '3'])

    Sequence views support indexing, slicing, and length queries. They act
    like the underlying sequence, except they don't allow assignment:

        >>> view[1]
        '1'
        >>> view[1:-1]
        ['1', '2']
        >>> len(view)
        4

    Sequence views are useful as an alternative to copying, as they don't
    require (much) extra storage.

    ``TypeError`` is raised if *target* is not a
    :class:`collections.abc.Sequence`.

    """

    def __init__(self, target):
        if not isinstance(target, Sequence):
            # Say what was expected and what was received instead of raising
            # a bare TypeError.
            raise TypeError(
                'SequenceView target must be a Sequence, '
                f'not {type(target).__name__}'
            )
        self._target = target

    def __getitem__(self, index):
        # Indexing and slicing are delegated, so slices have the type the
        # underlying sequence produces.
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        return f'{self.__class__.__name__}({self._target!r})'
class seekable:
    """Wrap an iterator so that it can be moved backward and forward. Items
    from the source iterable are cached as they are produced so they can be
    visited again.

    Call :meth:`seek` with an index to jump to that position in the source
    iterable.

    Seeking to ``0`` "resets" the iterator:

        >>> from itertools import count
        >>> it = seekable((str(n) for n in count()))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> it.seek(0)
        >>> next(it), next(it), next(it)
        ('0', '1', '2')

    Seeking forward works too:

        >>> it = seekable((str(n) for n in range(20)))
        >>> it.seek(10)
        >>> next(it)
        '10'
        >>> it.seek(20)  # Seeking past the end of the source isn't a problem
        >>> list(it)
        []
        >>> it.seek(0)  # Resetting works even after hitting the end
        >>> next(it)
        '0'

    Call :meth:`relative_seek` to move relative to the current position of
    the source iterator.

        >>> it = seekable((str(n) for n in range(20)))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> it.relative_seek(2)
        >>> next(it)
        '5'
        >>> it.relative_seek(-3)  # Source is at '6', we move back to '3'
        >>> next(it)
        '3'
        >>> it.relative_seek(-3)  # Source is at '4', we move back to '1'
        >>> next(it)
        '1'

    Call :meth:`peek` to look at the next item without advancing:

        >>> it = seekable('1234')
        >>> it.peek()
        '1'
        >>> list(it)
        ['1', '2', '3', '4']
        >>> it.peek(default='empty')
        'empty'

    :func:`bool` returns ``True`` while items remain and ``False`` once the
    iterator is at its end:

        >>> it = seekable('5678')
        >>> bool(it)
        True
        >>> list(it)
        ['5', '6', '7', '8']
        >>> bool(it)
        False

    The cache can be inspected with :meth:`elements`, which returns a
    :class:`SequenceView` that updates automatically:

        >>> it = seekable((str(n) for n in range(10)))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> elements = it.elements()
        >>> elements
        SequenceView(['0', '1', '2'])
        >>> next(it)
        '3'
        >>> elements
        SequenceView(['0', '1', '2', '3'])

    Indexing the :class:`seekable` directly returns items from the cache:

        >>> it = seekable((str(n) for n in range(10)))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> it[-1]
        '2'
        >>> it[0]
        '0'

    The cache grows as the source iterable progresses, so beware of wrapping
    very large or infinite iterables. Supply *maxlen* to limit the size of
    the cache (this of course limits how far back you can seek).

        >>> from itertools import count
        >>> it = seekable((str(n) for n in count()), maxlen=2)
        >>> next(it), next(it), next(it), next(it)
        ('0', '1', '2', '3')
        >>> list(it.elements())
        ['2', '3']
        >>> it.seek(0)
        >>> next(it), next(it), next(it), next(it)
        ('2', '3', '4', '5')
        >>> next(it)
        '6'

    """

    def __init__(self, iterable, maxlen=None):
        self._source = iter(iterable)
        # An unbounded list by default; a bounded deque when maxlen is set.
        self._cache = [] if maxlen is None else deque([], maxlen)
        # None means "read from the source"; an int is a cache position.
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        position = self._index
        if position is not None:
            try:
                cached = self._cache[position]
            except IndexError:
                # Ran off the end of the cache: go back to the source.
                self._index = None
            else:
                self._index = position + 1
                return cached

        fresh = next(self._source)
        self._cache.append(fresh)
        return fresh

    def __bool__(self):
        try:
            self.peek()
        except StopIteration:
            return False
        else:
            return True

    def peek(self, default=_marker):
        try:
            upcoming = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default

        # Step back by one so the peeked item is returned again by the
        # next call to next().
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return upcoming

    def elements(self):
        return SequenceView(self._cache)

    def seek(self, index):
        self._index = index
        # When the target lies beyond the cache, read ahead to reach it.
        shortfall = index - len(self._cache)
        if shortfall > 0:
            consume(self, shortfall)

    def relative_seek(self, count):
        if self._index is None:
            self._index = len(self._cache)

        # Never seek before the first item.
        target = max(self._index + count, 0)
        self.seek(target)

    def __getitem__(self, index):
        return self._cache[index]
class run_length:
    """
    :func:`run_length.encode` applies run-length encoding to an iterable.
    It yields each run of repeated items together with the number of times
    the item was repeated:

        >>> uncompressed = 'abbcccdddd'
        >>> list(run_length.encode(uncompressed))
        [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` reverses the encoding. Given an iterable that
    was compressed with run-length encoding, it yields the items of the
    decompressed iterable:

        >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
        >>> list(run_length.decode(compressed))
        ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        runs = groupby(iterable)
        return ((item, ilen(run)) for item, run in runs)

    @staticmethod
    def decode(iterable):
        # Each (item, times) pair becomes repeat(item, times).
        expanded = starmap(repeat, iterable)
        return chain.from_iterable(expanded)
def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly ``n`` items in the iterable are ``True``
    according to the *predicate* function.

        >>> exactly_n([True, True, False], 2)
        True
        >>> exactly_n([True, True, False], 1)
        False
        >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
        True

    The iterable is advanced until ``n + 1`` truthy items have been seen,
    so avoid calling it on infinite iterables.

    """
    matches = filter(predicate, iterable)

    if n < 0:
        return False

    if n == 0:
        # True only if there is no match at all.
        for _ in matches:
            return False
        return True

    # Skip the first n - 1 matches. Exactly one more must then exist: the
    # n-th match has to be there and an (n + 1)-th must not.
    tail = islice(matches, n - 1, None)
    exhausted = object()
    if next(tail, exhausted) is exhausted:
        return False
    return next(tail, exhausted) is exhausted
def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

        >>> list(circular_shifts(range(4)))
        [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    *steps* is the number of places to rotate to the left for each shift
    (or to the right if negative). It defaults to 1.

        >>> list(circular_shifts(range(4), 2))
        [(0, 1, 2, 3), (2, 3, 0, 1)]

        >>> list(circular_shifts(range(4), -1))
        [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    """
    ring = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Rotate once the "wrong" way so that the first rotation inside the
    # loop restores the original order.
    ring.rotate(steps)
    shift = -steps

    # Only len / gcd(len, steps) distinct shifts exist.
    size = len(ring)
    distinct = size // math.gcd(size, shift)

    for _ in range(distinct):
        ring.rotate(shift)
        yield tuple(ring)
def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function that
    modifies an iterable. *result_index* is the position in that function's
    signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition. This can augment what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

        >>> from more_itertools import chunked
        >>> chunker = make_decorator(chunked, result_index=0)
        >>> @chunker(3)
        ... def iter_range(n):
        ...     return iter(range(n))
        ...
        >>> list(iter_range(9))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

        >>> truth_serum = make_decorator(filter, result_index=1)
        >>> @truth_serum(bool)
        ... def boolean_test():
        ...     return [0, 1, '', ' ', False, True]
        ...
        >>> list(boolean_test())
        [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators:

        >>> from more_itertools import peekable
        >>> peekable_function = make_decorator(peekable)
        >>> @peekable_function()
        ... def str_range(*args):
        ...     return (str(x) for x in range(*args))
        ...
        >>> it = str_range(1, 20, 2)
        >>> next(it), next(it), next(it)
        ('1', '3', '5')
        >>> it.peek()
        '7'
        >>> next(it)
        '7'

    The decorated function keeps the metadata of the original function
    (``__name__``, ``__doc__``, ...), which is also reachable through
    ``__wrapped__``.

    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            # Copy the metadata of *f* onto the wrapper so introspection,
            # help() and debugging show the decorated function.
            @wraps(f)
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                # Splice the result of *f* into the positional arguments
                # of wrapping_func at the requested position.
                wrapping_args_ = list(wrapping_args)
                wrapping_args_.insert(result_index, result)
                return wrapping_func(*wrapping_args_, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Return a dictionary that sorts the items of *iterable* into
    categories chosen by *keyfunc*, transforms them with *valuefunc*, and
    then summarizes each category with *reducefunc*.

    *valuefunc* defaults to the identity function if it is unspecified.
    If *reducefunc* is unspecified, no summarization takes place:

        >>> keyfunc = lambda x: x.upper()
        >>> result = map_reduce('abbccc', keyfunc)
        >>> sorted(result.items())
        [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Specifying *valuefunc* transforms the categorized items:

        >>> keyfunc = lambda x: x.upper()
        >>> valuefunc = lambda x: 1
        >>> result = map_reduce('abbccc', keyfunc, valuefunc)
        >>> sorted(result.items())
        [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Specifying *reducefunc* summarizes the categorized items:

        >>> keyfunc = lambda x: x.upper()
        >>> valuefunc = lambda x: 1
        >>> reducefunc = sum
        >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
        >>> sorted(result.items())
        [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the map/reduce
    procedure:

        >>> all_items = range(30)
        >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
        >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
        >>> categories = map_reduce(items, keyfunc=keyfunc)
        >>> sorted(categories.items())
        [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
        >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
        >>> sorted(summaries.items())
        [(0, 90), (1, 75)]

    Note that all items in the iterable are gathered into a list before the
    summarization step, which may require significant storage.

    The returned object is a :obj:`collections.defaultdict` with the
    ``default_factory`` set to ``None``, such that it behaves like a normal
    dictionary.

    .. seealso:: :func:`bucket`, :func:`groupby_transform`

        If storage is a concern, :func:`bucket` can be used without consuming
        the entire iterable right away. If the elements with the same key are
        already adjacent, :func:`groupby_transform` or
        :func:`itertools.groupby` can be used without any caching overhead.

    """
    categories = defaultdict(list)

    for item in iterable:
        category = keyfunc(item)
        entry = item if valuefunc is None else valuefunc(item)
        categories[category].append(entry)

    if reducefunc is not None:
        # Only existing keys are reassigned, so the dictionary does not
        # change size during the loop.
        for category in categories:
            categories[category] = reducefunc(categories[category])

    # From here on a missing key raises KeyError like in a plain dict.
    categories.default_factory = None
    return categories
def rlocate(iterable, pred=bool, window_size=None):
    """Yield the indexes of the items in *iterable* for which *pred* returns
    ``True``, working from the right-hand end towards the left.

    *pred* defaults to :func:`bool`, which selects truthy items:

        >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
        [4, 2, 1]

    Pass a custom *pred* to, e.g., find the indexes of a particular item:

        >>> iterator = iter('abcb')
        >>> pred = lambda x: x == 'b'
        >>> list(rlocate(iterator, pred))
        [3, 1]

    If *window_size* is given, *pred* is called with that many items at a
    time, which allows searching for sub-sequences:

        >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
        >>> pred = lambda *args: args == (1, 2, 3)
        >>> list(rlocate(iterable, pred=pred, window_size=3))
        [9, 5, 1]

    Beware, this function won't return anything for infinite iterables.
    When no *window_size* is given and *iterable* supports both ``len()``
    and ``reversed()``, it is searched from the right. Otherwise it is
    searched from the left and the results are reversed.

    See :func:`locate` for other example applications.

    """
    if window_size is None:
        try:
            length = len(iterable)
            backwards = locate(reversed(iterable), pred)
            # Translate positions in the reversed view back to real indexes.
            return (length - pos - 1 for pos in backwards)
        except TypeError:
            # No len() or no reversed(): fall through to the generic path.
            pass

    found = list(locate(iterable, pred, window_size))
    return reversed(found)
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items of *iterable*, splicing in the items of *substitutes*
    wherever *pred* returns ``True``.

        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
        >>> pred = lambda x: x == 0
        >>> substitutes = (2, 3)
        >>> list(replace(iterable, pred, substitutes))
        [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    *count* caps the number of replacements that are made:

        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
        >>> pred = lambda x: x == 0
        >>> substitutes = [None]
        >>> list(replace(iterable, pred, substitutes, count=2))
        [1, 1, None, 1, 1, None, 1, 1, 0]

    *window_size* sets how many consecutive items are passed to *pred* as
    positional arguments, which allows whole subsequences to be replaced:

        >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
        >>> window_size = 3
        >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
        >>> substitutes = [3, 4]  # Splice in these items
        >>> list(replace(iterable, pred, substitutes, window_size=window_size))
        [3, 4, 5, 3, 4, 5]

    Near the end of the iterable *pred* may be called with fewer than
    *window_size* arguments and must cope with that.

    ``ValueError`` is raised if *window_size* is less than 1.

    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # The substitutes are emitted once per match, so freeze them up front.
    replacement = tuple(substitutes)

    # Pad the tail so that there is one window per input item.
    overlap = window_size - 1
    padded = chain(iterable, repeat(_marker, overlap))
    windows = windowed(padded, window_size)

    replaced = 0
    for window in windows:
        # Hide the internal padding sentinel from pred.
        visible = tuple(x for x in window if x is not _marker)
        if not visible:
            continue

        # pred is consulted first, then the replacement budget.
        if pred(*visible) and (count is None or replaced < count):
            replaced += 1
            yield from replacement
            # Drop the following windows that overlap the matched one,
            # e.g. for size 2 a match on (0, 1) also discards (1, 2).
            consume(windows, overlap)
            continue

        # No match (or the budget is spent): emit the window's first item.
        yield visible[0]
def partitions(iterable):
    """Yield every way of cutting *iterable* into consecutive, non-empty
    pieces while keeping the items in their original order.

        >>> iterable = 'abc'
        >>> for part in partitions(iterable):
        ...     print([''.join(p) for p in part])
        ['abc']
        ['a', 'bc']
        ['ab', 'c']
        ['a', 'b', 'c']

    This is unrelated to :func:`partition`.

    """
    items = list(iterable)
    size = len(items)
    # Each subset of the interior positions 1..size-1 is one set of cuts.
    for cuts in powerset(range(1, size)):
        starts = (0,) + cuts
        stops = cuts + (size,)
        yield [items[lo:hi] for lo, hi in zip(starts, stops)]
def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """
    Yield the set partitions of *iterable* into *k* parts. Unlike
    :func:`partitions`, the blocks need not be contiguous runs.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable, 2):
        ...     print([''.join(p) for p in part])
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']

    When *k* is omitted, the partitions for every possible number of parts
    are generated, fewest parts first.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable):
        ...     print([''.join(p) for p in part])
        ['abc']
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']
        ['a', 'b', 'c']

    *min_size* and *max_size* restrict the size of every block in a
    partition; partitions containing a block outside the range are skipped.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable, min_size=2):
        ...     print([''.join(p) for p in part])
        ['abc']
        >>> for part in set_partitions(iterable, max_size=2):
        ...     print([''.join(p) for p in part])
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']
        ['a', 'b', 'c']

    ``ValueError`` is raised if *k* is less than 1. Nothing is yielded if
    *k* exceeds the number of items or if *min_size* exceeds *max_size*.

    """
    items = list(iterable)
    total = len(items)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        if k > total:
            return

    smallest = 0 if min_size is None else min_size
    largest = total if max_size is None else max_size
    if smallest > largest:
        return

    def _split(pool, parts):
        # Yield every partition of *pool* into exactly *parts* blocks.
        if parts == 1:
            yield [pool]
        elif len(pool) == parts:
            yield [[member] for member in pool]
        else:
            head, *tail = pool
            # Either *head* forms a block of its own ...
            for partial in _split(tail, parts - 1):
                yield [[head], *partial]
            # ... or it joins each block of a partition of the rest in turn.
            for partial in _split(tail, parts):
                for pos, block in enumerate(partial):
                    merged = [head] + block
                    yield partial[:pos] + [merged] + partial[pos + 1 :]

    def _sizes_ok(partition):
        return all(smallest <= len(block) <= largest for block in partition)

    part_counts = range(1, total + 1) if k is None else (k,)
    for parts in part_counts:
        yield from filter(_sizes_ok, _split(items, parts))
class time_limited:
    """
    Yield items from *iterable* until *limit_seconds* have passed.
    If the limit is hit before the iterable is exhausted, the ``timed_out``
    attribute is set to ``True``.

        >>> from time import sleep
        >>> def generator():
        ...     yield 1
        ...     yield 2
        ...     sleep(0.2)
        ...     yield 3
        >>> iterable = time_limited(0.1, generator())
        >>> list(iterable)
        [1, 2]
        >>> iterable.timed_out
        True

    The clock starts when the object is created. The elapsed time is checked
    after an item has been fetched from the iterable but before it is handed
    out, and iteration stops if more than *limit_seconds* have elapsed. So if
    the limit is 1 second but the first item takes 2 seconds to produce, the
    call runs for 2 seconds and yields nothing.
    As a special case, when *limit_seconds* is zero, nothing is ever yielded
    and the underlying iterable is not advanced.

    ``ValueError`` is raised if *limit_seconds* is negative.

    """

    def __init__(self, limit_seconds, iterable):
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be positive')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.limit_seconds != 0:
            # A StopIteration from the source propagates untouched and
            # leaves ``timed_out`` as it was.
            candidate = next(self._iterator)
            elapsed = monotonic() - self._start_time
            if not (elapsed > self.limit_seconds):
                return candidate

        # Either the limit is zero or the budget has been exceeded.
        self.timed_out = True
        raise StopIteration
def only(iterable, default=None, too_long=None):
    """Return the single item of *iterable*.

    If *iterable* is empty, return *default*. If it holds more than one
    item, raise *too_long* if it is given (and truthy), otherwise a
    ``ValueError`` that names the first two items.

        >>> only([], default='missing')
        'missing'
        >>> only([1])
        1
        >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Expected exactly one item in iterable, but got 1, 2,
         and perhaps more.'
        >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        TypeError

    Note that :func:`only` advances *iterable* up to twice in order to rule
    out a second item. See :func:`spy` or :func:`peekable` to inspect an
    iterable less destructively.

    """
    iterator = iter(iterable)
    missing = object()

    first = next(iterator, missing)
    if first is missing:
        return default

    second = next(iterator, missing)
    if second is missing:
        return first

    msg = (
        f'Expected exactly one item in iterable, but got {first!r}, '
        f'{second!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
def ichunked(iterable, n):
    """Break *iterable* into sub-iterables of *n* elements each.
    :func:`ichunked` is like :func:`chunked`, except that the chunks it
    yields are iterables rather than lists.

    When the chunks are consumed in order, the elements of *iterable* are
    not held in memory. When they are consumed out of order,
    :func:`itertools.tee` caches whatever has to be kept.

        >>> from itertools import count
        >>> all_chunks = ichunked(count(), 4)
        >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
        >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
        [4, 5, 6, 7]
        >>> list(c_1)
        [0, 1, 2, 3]
        >>> list(c_3)
        [8, 9, 10, 11]

    """
    source = iter(iterable)
    while True:
        try:
            head = next(source)
        except StopIteration:
            return

        tail = islice(source, n - 1)
        replay, drain = tee(tail)
        # The chunk reads ``tail`` directly first and then ``replay``, which
        # holds whatever ``drain`` pulled through the tee in the meantime.
        yield chain([head], tail, replay)
        # Before starting the next chunk, push the rest of this one through
        # the tee so that ``source`` is positioned at the next chunk.
        consume(drain)
def iequals(*iterables):
    """Return ``True`` if all of the given *iterables* yield equal items in
    the same order, and ``False`` otherwise (including when their lengths
    differ).

    This is handy for comparing iterables of different types, or iterables
    that do not implement equality themselves.

        >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
        True

        >>> iequals("abc", "acb")
        False

    Not to be confused with :func:`all_equal`, which checks whether all
    elements of a single iterable are equal to each other.

    """
    try:
        # ``strict=True`` makes zip raise ValueError on a length mismatch.
        for column in zip(*iterables, strict=True):
            if not all_equal(column):
                return False
    except ValueError:
        return False
    return True
def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

        >>> list(distinct_combinations([0, 0, 1], 2))
        [(0, 0), (0, 1)]

    The result matches ``set(combinations(iterable, r))``, but duplicates
    are never generated in the first place rather than being produced and
    discarded, which is much more efficient for larger inputs.

    ``ValueError`` is raised if *r* is negative.

    """
    if r < 0:
        raise ValueError('r must be non-negative')
    if r == 0:
        yield ()
        return

    pool = tuple(iterable)
    by_value = itemgetter(1)

    # One iterator of (index, value) candidates per position being filled;
    # the depth of the stack is the position currently being chosen.
    stack = [unique_everseen(enumerate(pool), key=by_value)]
    combo = [None] * r
    while stack:
        depth = len(stack) - 1
        try:
            position, value = next(stack[-1])
        except StopIteration:
            # This position has no candidates left: backtrack.
            stack.pop()
            continue

        combo[depth] = value
        if depth + 1 == r:
            yield tuple(combo)
            continue

        # Descend: the next position draws from the items after this one.
        start = position + 1
        stack.append(
            unique_everseen(enumerate(pool[start:], start), key=by_value)
        )
def filter_except(validator, iterable, *exceptions):
    """Yield the items of *iterable* for which calling *validator* does not
    raise one of the given *exceptions*.

    *validator* is called once per item with that item as its only argument.
    Its return value is ignored; an item is rejected only by raising.

        >>> iterable = ['1', '2', 'three', '4', None]
        >>> list(filter_except(int, iterable, ValueError, TypeError))
        ['1', '2', '4']

    Exceptions raised by *validator* that are not listed in *exceptions*
    propagate to the caller as usual.
    """
    for candidate in iterable:
        try:
            validator(candidate)
        except exceptions:
            continue
        yield candidate
def map_except(function, iterable, *exceptions):
    """Yield ``function(item)`` for each item of *iterable*, silently
    skipping the items for which *function* raises one of the given
    *exceptions*.

    *function* is called once per item and should accept one argument.

        >>> iterable = ['1', '2', 'three', '4', None]
        >>> list(map_except(int, iterable, ValueError, TypeError))
        [1, 2, 4]

    Exceptions raised by *function* that are not listed in *exceptions*
    propagate to the caller as usual.
    """
    for candidate in iterable:
        try:
            # NOTE: the ``yield`` sits inside the ``try``, so a listed
            # exception thrown into the generator here is suppressed too.
            yield function(candidate)
        except exceptions:
            continue
def map_if(iterable, pred, func, func_else=None):
    """Yield each item of *iterable* transformed by *func* when
    ``pred(item)`` is truthy, and by *func_else* otherwise.

    *pred*, *func*, and *func_else* should each accept one argument. When
    *func_else* is omitted, items that fail *pred* are yielded unchanged.

        >>> from math import sqrt
        >>> iterable = list(range(-5, 5))
        >>> iterable
        [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
        >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
        [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
        >>> list(map_if(iterable, lambda x: x >= 0,
        ...     lambda x: f'{sqrt(x):.2f}', lambda x: None))
        [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
    """
    if func_else is None:
        for item in iterable:
            if pred(item):
                yield func(item)
            else:
                yield item
    else:
        for item in iterable:
            if pred(item):
                yield func(item)
            else:
                yield func_else(item)
3804def _sample_unweighted(iterator, k, strict):
3805 # Algorithm L in the 1994 paper by Kim-Hung Li:
3806 # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
3808 reservoir = list(islice(iterator, k))
3809 if strict and len(reservoir) < k:
3810 raise ValueError('Sample larger than population')
3811 W = 1.0
3813 with suppress(StopIteration):
3814 while True:
3815 W *= random() ** (1 / k)
3816 skip = floor(log(random()) / log1p(-W))
3817 element = next(islice(iterator, skip, None))
3818 reservoir[randrange(k)] = element
3820 shuffle(reservoir)
3821 return reservoir
3824def _sample_weighted(iterator, k, weights, strict):
3825 # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
3826 # "Weighted random sampling with a reservoir".
3828 # Log-transform for numerical stability for weights that are small/large
3829 weight_keys = (log(random()) / weight for weight in weights)
3831 # Fill up the reservoir (collection of samples) with the first `k`
3832 # weight-keys and elements, then heapify the list.
3833 reservoir = take(k, zip(weight_keys, iterator))
3834 if strict and len(reservoir) < k:
3835 raise ValueError('Sample larger than population')
3837 heapify(reservoir)
3839 # The number of jumps before changing the reservoir is a random variable
3840 # with an exponential distribution. Sample it using random() and logs.
3841 smallest_weight_key, _ = reservoir[0]
3842 weights_to_skip = log(random()) / smallest_weight_key
3844 for weight, element in zip(weights, iterator):
3845 if weight >= weights_to_skip:
3846 # The notation here is consistent with the paper, but we store
3847 # the weight-keys in log-space for better numerical stability.
3848 smallest_weight_key, _ = reservoir[0]
3849 t_w = exp(weight * smallest_weight_key)
3850 r_2 = uniform(t_w, 1) # generate U(t_w, 1)
3851 weight_key = log(r_2) / weight
3852 heapreplace(reservoir, (weight_key, element))
3853 smallest_weight_key, _ = reservoir[0]
3854 weights_to_skip = log(random()) / smallest_weight_key
3855 else:
3856 weights_to_skip -= weight
3858 ret = [element for weight_key, element in reservoir]
3859 shuffle(ret)
3860 return ret
3863def _sample_counted(population, k, counts, strict):
3864 element = None
3865 remaining = 0
3867 def feed(i):
3868 # Advance *i* steps ahead and consume an element
3869 nonlocal element, remaining
3871 while i + 1 > remaining:
3872 i = i - remaining
3873 element = next(population)
3874 remaining = next(counts)
3875 remaining -= i + 1
3876 return element
3878 with suppress(StopIteration):
3879 reservoir = []
3880 for _ in range(k):
3881 reservoir.append(feed(0))
3883 if strict and len(reservoir) < k:
3884 raise ValueError('Sample larger than population')
3886 with suppress(StopIteration):
3887 W = 1.0
3888 while True:
3889 W *= random() ** (1 / k)
3890 skip = floor(log(random()) / log1p(-W))
3891 element = feed(skip)
3892 reservoir[randrange(k)] = element
3894 shuffle(reservoir)
3895 return reservoir
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a *k*-length list of elements chosen (without replacement)
    from the *iterable*.

    Similar to :func:`random.sample`, but it also works on inputs that
    aren't indexable (such as sets and dictionaries) and on inputs whose
    size isn't known in advance (such as generators).

        >>> iterable = range(100)
        >>> sample(iterable, 5)  # doctest: +SKIP
        [81, 60, 96, 16, 4]

    For iterables with repeated elements, *counts* may be supplied to
    describe the repeats.

        >>> iterable = ['a', 'b']
        >>> counts = [3, 4]  # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
        >>> sample(iterable, k=3, counts=counts)  # doctest: +SKIP
        ['a', 'a', 'b']

    An iterable of *weights* may be given instead:

        >>> iterable = range(100)
        >>> weights = (i * i + 1 for i in range(100))
        >>> sample(iterable, 5, weights=weights)  # doctest: +SKIP
        [79, 67, 74, 66, 78]

    Weighted selections are made without replacement.
    After an element is selected, it is removed from the pool and the
    relative weights of the other elements increase (this
    does not match the behavior of :func:`random.sample`'s *counts*
    parameter). *weights* and *counts* are mutually exclusive; passing both
    raises ``TypeError``.

    ``ValueError`` is raised if *k* is negative, and an empty list is
    returned if *k* is zero.

    If the length of *iterable* is less than *k*,
    ``ValueError`` is raised if *strict* is ``True`` and
    all elements are returned (in shuffled order) if *strict* is ``False``.

    By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
    technique is used. When *weights* are provided,
    `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used instead.

    Notes on reproducibility:

    * The algorithms rely on inexact floating-point functions provided
      by the underlying math library (e.g. ``log``, ``log1p``, and ``pow``).
      Those functions can `produce slightly different results
      <https://members.loria.fr/PZimmermann/papers/accuracy.pdf>`_ on
      different builds. Accordingly, selections can vary across builds
      even for the same seed.

    * The algorithms loop over the input and make selections based on
      ordinal position, so selections from unordered collections (such as
      sets) won't reproduce across sessions on the same platform using the
      same seed. For example, this won't reproduce::

        >> seed(8675309)
        >> sample(set('abcdefghijklmnopqrstuvwxyz'), 10)
        ['c', 'p', 'e', 'w', 's', 'a', 'j', 'd', 'n', 't']

    """
    # Obtain the iterator first so a non-iterable input fails immediately.
    iterator = iter(iterable)

    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []

    has_weights = weights is not None
    has_counts = counts is not None
    if has_weights and has_counts:
        raise TypeError('weights and counts are mutually exclusive')

    if has_weights:
        return _sample_weighted(iterator, k, iter(weights), strict)
    if has_counts:
        return _sample_counted(iterator, k, iter(counts), strict)
    return _sample_unweighted(iterator, k, strict)
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Return ``True`` if the items of *iterable* are in sorted order, and
    ``False`` otherwise. *key* and *reverse* mean the same as they do for
    the built-in :func:`sorted` function.

        >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
        True
        >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
        False

    If *strict* is true, the order must be strictly increasing (or strictly
    decreasing with *reverse*), so equal neighbours give ``False``:

        >>> is_sorted([1, 2, 2])
        True
        >>> is_sorted([1, 2, 2], strict=True)
        False

    The function returns ``False`` as soon as it meets the first
    out-of-order item, which means it may disagree with the built-in
    :func:`sorted` function for objects with unusual comparison dynamics
    (like ``math.nan``). If there are no out-of-order items, the iterable is
    exhausted.
    """
    values = iterable if key is None else map(key, iterable)

    # Two views of the same stream, the second one shifted by one item, so
    # that mapping over them compares each item with its successor.
    earlier, later = tee(values)
    next(later, None)
    if reverse:
        earlier, later = later, earlier

    if strict:
        return all(map(lt, earlier, later))
    return not any(map(lt, later, earlier))
class AbortThread(BaseException):
    """Raised inside the callback supplied by :class:`callback_iter` once
    the iterator has been aborted.

    It derives from ``BaseException`` rather than ``Exception``, so an
    ``except Exception`` clause in the wrapped function does not catch it.
    """
class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    .. warning::

        This function is deprecated as of version 11.0.0. It will be removed
        in a future major release.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

        >>> def func(callback=None):
        ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
        ...         if callback:
        ...             callback(i, c)
        ...     return 4


    Use ``with callback_iter(func)`` to get an iterator over the
    ``(args, kwargs)`` pairs that are delivered to the callback.

        >>> with callback_iter(func) as it:
        ...     for args, kwargs in it:
        ...         print(args)
        (1, 'a')
        (2, 'b')
        (3, 'c')

    The function is run on a background thread once iteration starts. The
    ``done`` property reports whether it has finished executing.

        >>> it.done
        True

    If it finished successfully, its return value is available from the
    ``result`` property.

        >>> it.result
        4

    Notes:

    * If the function uses some keyword argument besides ``callback``, supply
      *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently it is polled for
      output.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        self._func = func
        self._callback_kwd = callback_kwd
        self._aborted = False
        self._future = None
        self._wait_seconds = wait_seconds

        # Lazily import concurrent.future
        self._module = __import__('concurrent.futures').futures
        self._executor = self._module.ThreadPoolExecutor(max_workers=1)
        # Creating the generator does not run it: the function is only
        # submitted to the executor on the first ``next()`` call.
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # From now on the callback raises AbortThread in the worker thread.
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    @property
    def done(self):
        future = self._future
        return False if future is None else future.done()

    @property
    def result(self):
        future = self._future
        if future:
            try:
                # Non-blocking: returns or re-raises the function's outcome
                # if it is available, and times out otherwise.
                return future.result(timeout=0)
            except self._module.TimeoutError:
                pass

        raise RuntimeError('Function has not yet completed')

    def _reader(self):
        pending = Queue()

        def callback(*args, **kwargs):
            if self._aborted:
                raise AbortThread('canceled by user')

            pending.put((args, kwargs))

        self._future = self._executor.submit(
            self._func, **{self._callback_kwd: callback}
        )

        # Poll the queue until the function has finished running.
        finished = False
        while not finished:
            try:
                item = pending.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                pending.task_done()
                yield item

            finished = self._future.done()

        # Collect anything queued between the last poll and completion.
        leftovers = []
        while True:
            try:
                item = pending.get_nowait()
            except Empty:
                break
            pending.task_done()
            leftovers.append(item)
        pending.join()
        yield from leftovers
def windowed_complete(iterable, n):
    """
    Yield ``(beginning, middle, end)`` tuples, where:

    * Each ``middle`` has *n* items from *iterable*
    * Each ``beginning`` has the items before the ones in ``middle``
    * Each ``end`` has the items after the ones in ``middle``

        >>> iterable = range(7)
        >>> n = 3
        >>> for beginning, middle, end in windowed_complete(iterable, n):
        ...     print(beginning, middle, end)
        () (0, 1, 2) (3, 4, 5, 6)
        (0,) (1, 2, 3) (4, 5, 6)
        (0, 1) (2, 3, 4) (5, 6)
        (0, 1, 2) (3, 4, 5) (6,)
        (0, 1, 2, 3) (4, 5, 6) ()

    *n* must be at least 0 and at most the length of *iterable*; otherwise
    ``ValueError`` is raised.

    This function will exhaust the iterable and may require significant
    storage.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    items = tuple(iterable)
    if n > len(items):
        raise ValueError('n must be <= len(seq)')

    for start in range(len(items) - n + 1):
        stop = start + n
        yield items[:start], items[start:stop], items[stop:]
def all_unique(iterable, key=None):
    """
    Return ``True`` if no two elements of *iterable* are equal.

        >>> all_unique('ABCB')
        False

    If a *key* function is given, the elements are compared by its results
    instead.

        >>> all_unique('ABCb')
        True
        >>> all_unique('ABCb', str.lower)
        False

    The function returns as soon as the first repeated element is found.
    Iterables mixing hashable and unhashable items are supported, but the
    unhashable ones are tracked in a list, which makes them slower to check.
    """
    hashed = set()
    unhashed = []
    keys = map(key, iterable) if key else iterable
    for candidate in keys:
        try:
            if candidate in hashed:
                return False
            hashed.add(candidate)
        except TypeError:
            # Unhashable: fall back to a linear scan of the ones seen so far.
            if candidate in unhashed:
                return False
            unhashed.append(candidate)
    return True
def nth_product(index, *iterables, repeat=1):
    """Equivalent to ``list(product(*iterables, repeat=repeat))[index]``.

    The products of *iterables* can be ordered lexicographically.
    :func:`nth_product` computes the product at sort position *index*
    directly, without generating the products that come before it.

        >>> nth_product(8, range(2), range(2), range(2), range(2))
        (1, 0, 0, 0)

    The *repeat* keyword argument gives the number of repetitions of the
    iterables. The example above is equivalent to::

        >>> nth_product(8, range(2), repeat=4)
        (1, 0, 0, 0)

    Negative values of *index* count from the end.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    # The pools are handled last-first because the last pool varies fastest.
    pools = tuple(map(tuple, reversed(iterables))) * repeat
    sizes = tuple(map(len, pools))
    total = prod(sizes)

    if index < 0:
        index += total
    if not 0 <= index < total:
        raise IndexError

    # Decode *index* as a mixed-radix number, least significant digit first.
    picked = []
    for pool, size in zip(pools, sizes):
        index, offset = divmod(index, size)
        picked.append(pool[offset])

    picked.reverse()
    return tuple(picked)
def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`nth_permutation`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences.

        >>> nth_permutation('ghijk', 2, 5)
        ('h', 'i')

    If *r* is ``None``, the full length of *iterable* is used. Negative
    values of *index* count from the end.

    ``ValueError`` will be raised if *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = list(iterable)
    n = len(pool)
    if r is None:
        r = n
    total = perm(n, r)  # raises ValueError for negative r

    if index < 0:
        index += total
    if not 0 <= index < total:
        raise IndexError

    # Decompose the (rescaled) index in the factorial number system; each
    # kept digit is an offset into the shrinking pool of unused items.
    offsets = [0] * r
    quotient = index * factorial(n) // total if r < n else index
    for radix in range(1, n + 1):
        quotient, digit = divmod(quotient, radix)
        slot = n - radix
        if 0 <= slot < r:
            offsets[slot] = digit
        if quotient == 0:
            break

    return tuple(pool.pop(offset) for offset in offsets)
def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.


    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`nth_combination_with_replacement`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences with replacement.

        >>> nth_combination_with_replacement(range(5), 3, 5)
        (0, 1, 1)

    Negative values of *index* count from the end.

    ``ValueError`` will be raised if *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if r < 0:
        raise ValueError

    # Number of combinations: an empty pool admits only the empty selection.
    if n:
        total = comb(n + r - 1, r)
    elif r:
        total = 0
    else:
        total = 1

    if index < 0:
        index += total
    if not 0 <= index < total:
        raise IndexError

    chosen = []
    position = 0
    while r:
        r -= 1
        # Step past whole blocks of combinations that start with an earlier
        # pool item until *index* falls inside the current block.
        while n >= 0:
            block = comb(n + r - 1, r)
            if index < block:
                break
            n -= 1
            position += 1
            index -= block
        chosen.append(pool[position])

    return tuple(chosen)
def value_chain(*args):
    """Yield all arguments passed to the function in the same order in which
    they were passed. If an argument itself is iterable then iterate over its
    values.

        >>> list(value_chain(1, 2, 3, [4, 5, 6]))
        [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and are emitted
    as-is:

        >>> list(value_chain('12', '34', ['56', '78']))
        ['12', '34', '56', '78']

    Pre- or postpend a single element to an iterable:

        >>> list(value_chain(1, [2, 3, 4, 5, 6]))
        [1, 2, 3, 4, 5, 6]
        >>> list(value_chain([1, 2, 3, 4, 5], 6))
        [1, 2, 3, 4, 5, 6]

    Multiple levels of nesting are not flattened.

    An argument counts as iterable if :func:`iter` accepts it. A
    ``TypeError`` raised later, while an iterable argument is being
    consumed, propagates to the caller.

    """
    scalar_types = (str, bytes)
    for value in args:
        if isinstance(value, scalar_types):
            yield value
            continue

        # Guard only the iterability probe. Wrapping the whole ``yield from``
        # would also swallow a TypeError raised part-way through iteration
        # and then emit the half-consumed iterable object itself.
        try:
            members = iter(value)
        except TypeError:
            members = (value,)
        yield from members
def product_index(element, *iterables, repeat=1):
    """Equivalent to
    ``list(product(*iterables, repeat=repeat)).index(tuple(element))``

    The products of *iterables* can be ordered lexicographically.
    :func:`product_index` computes the first index of *element* directly,
    without generating the products that come before it.

        >>> product_index([8, 2], range(10), range(5))
        42

    The *repeat* keyword argument gives the number of repetitions of the
    iterables::

        >>> product_index([8, 0, 7], range(10), repeat=3)
        807

    ``ValueError`` will be raised if the given *element* isn't in the product
    of *args*.
    """
    elements = tuple(element)
    pools = tuple(map(tuple, iterables)) * repeat
    if len(elements) != len(pools):
        raise ValueError('element is not a product of args')

    # Accumulate a mixed-radix number, one digit per pool.
    position = 0
    for item, pool in zip(elements, pools):
        position *= len(pool)
        position += pool.index(item)
    return position
def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``,
    where *r* is the length of *element*.

    The length-*r* subsequences of *iterable* can be ordered
    lexicographically. :func:`combination_index` computes the index of the
    first *element* without generating the earlier combinations.

        >>> combination_index('adf', 'abcdefg')
        10

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations of *iterable*.
    """
    wanted = enumerate(element)
    k, target = next(wanted, (None, None))
    if k is None:
        # An empty element is the one and only zero-length combination.
        return 0

    positions = []
    pool = enumerate(iterable)
    for n, candidate in pool:
        if candidate == target:
            positions.append(n)
            next_k, target = next(wanted, (None, None))
            if next_k is None:
                # Every item of the element has been located.
                break
            k = next_k
    else:
        raise ValueError('element is not a combination of iterable')

    # Consume whatever is left of the pool so that n ends up as the index
    # of its final item.
    n, _ = last(pool, default=(n, None))

    complement = 1
    for i, j in enumerate(reversed(positions), start=1):
        gap = n - j
        if i <= gap:
            complement += comb(gap, i)

    return comb(n + 1, k + 1) - complement
def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``,
    where *r* is the length of *element*.

    The length-*r* subsequences with repetition of *iterable* can be ordered
    lexicographically. :func:`combination_with_replacement_index`
    computes the index of the first *element*, without computing the previous
    combinations with replacement.

        >>> combination_with_replacement_index('adf', 'abcdefg')
        20

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations with replacement of *iterable*. Items that are ``None``
    are matched like any other value.
    """
    element = tuple(element)
    size = len(element)
    if not size:
        # The empty element is the single zero-length combination.
        return 0

    pool = tuple(iterable)
    n = len(pool)

    # occupations[p] counts how many items of the element were matched to
    # pool position p. Progress is tracked with an explicit counter rather
    # than a ``None`` sentinel so that ``None`` items are handled correctly.
    occupations = [0] * n
    matched = 0
    for position, candidate in enumerate(pool):
        while matched < size and candidate == element[matched]:
            occupations[position] += 1
            matched += 1
        if matched == size:
            break
    else:
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    index = 0
    cumulative_sum = 0
    for step in range(1, n):
        cumulative_sum += occupations[step - 1]
        j = size + n - 1 - step - cumulative_sum
        i = n - step
        if i <= j:
            index += comb(j, i)

    return index
def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)``,
    where *r* is the length of *element*.

    The ordered length-*r* arrangements of *iterable* can be sorted
    lexicographically. :func:`permutation_index` computes the index of the
    first *element* directly, without computing the previous permutations.

        >>> permutation_index([1, 3, 2], range(5))
        19

    ``ValueError`` will be raised if the given *element* isn't one of the
    permutations of *iterable*.
    """
    remaining = list(iterable)
    # The number of choices shrinks by one for each item already placed.
    radices = range(len(remaining), -1, -1)

    rank = 0
    for radix, item in zip(radices, element):
        offset = remaining.index(item)
        rank = rank * radix + offset
        remaining.pop(offset)

    return rank
class countable:
    """Iterator wrapper that records how many items have been pulled from
    *iterable*.

    The ``items_seen`` attribute is ``0`` initially and grows by one for
    every item handed out:

        >>> iterable = map(str, range(10))
        >>> it = countable(iterable)
        >>> it.items_seen
        0
        >>> next(it), next(it)
        ('0', '1')
        >>> list(it)
        ['2', '3', '4', '5', '6', '7', '8', '9']
        >>> it.items_seen
        10
    """

    def __init__(self, iterable):
        self.items_seen = 0
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        value = next(self._iterator)
        # Reached only when the wrapped iterator actually produced a value,
        # so exhaustion never inflates the count.
        self.items_seen += 1
        return value
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.
    Items are distributed such that the lengths of the lists differ by at
    most 1 item.

        >>> iterable = [1, 2, 3, 4, 5, 6, 7]
        >>> n = 3
        >>> list(chunked_even(iterable, n))  # List lengths: 3, 2, 2
        [[1, 2, 3], [4, 5], [6, 7]]
        >>> list(chunked(iterable, n))  # List lengths: 3, 3, 1
        [[1, 2, 3], [4, 5, 6], [7]]

    """
    source = iter(iterable)

    # Hold some items in reserve so the final chunks can be evened out
    # once the input is known to be exhausted.
    held_back = (n - 1) * (n - 2)
    pending = list(islice(source, held_back))

    # Each step of the islice pulls further items into ``pending``; a full
    # chunk is released from the front every time one is available.
    filling = map(pending.append, source)
    for _ in islice(filling, n, None, n):
        yield pending[:n]
        del pending[:n]

    remaining = len(pending)
    if not remaining:
        return

    # Split what is left into chunks of size ``large <= n`` followed by
    # chunks of size ``small = large - 1``.
    num_chunks = -(-remaining // n)  # ceiling division
    large = -(-remaining // num_chunks)
    small = large - 1
    num_large = remaining - small * num_chunks
    boundary = num_large * large

    if large > 0:
        for start in range(0, boundary, large):
            yield pending[start : start + large]

    if small > 0:
        for start in range(boundary, remaining, small):
            yield pending[start : start + small]
4613def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
4614 """A version of :func:`zip` that "broadcasts" any scalar
4615 (i.e., non-iterable) items into output tuples.
4617 >>> iterable_1 = [1, 2, 3]
4618 >>> iterable_2 = ['a', 'b', 'c']
4619 >>> scalar = '_'
4620 >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
4621 [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]
4623 The *scalar_types* keyword argument determines what types are considered
4624 scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
4625 treat strings and byte strings as iterable:
4627 >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
4628 [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]
4630 If the *strict* keyword argument is ``True``, then
4631 ``ValueError`` will be raised if any of the iterables have
4632 different lengths.
4633 """
4635 def is_scalar(obj):
4636 if scalar_types and isinstance(obj, scalar_types):
4637 return True
4638 try:
4639 iter(obj)
4640 except TypeError:
4641 return True
4642 else:
4643 return False
4645 size = len(objects)
4646 if not size:
4647 return
4649 new_item = [None] * size
4650 iterables, iterable_positions = [], []
4651 for i, obj in enumerate(objects):
4652 if is_scalar(obj):
4653 new_item[i] = obj
4654 else:
4655 iterables.append(iter(obj))
4656 iterable_positions.append(i)
4658 if not iterables:
4659 yield tuple(objects)
4660 return
4662 for item in zip(*iterables, strict=strict):
4663 for i, new_item[i] in zip(iterable_positions, item):
4664 pass
4665 yield tuple(new_item)
def unique_in_window(iterable, n, key=None):
    """Yield the items from *iterable* that haven't been seen recently.
    *n* is the size of the sliding window.

        >>> iterable = [0, 1, 0, 2, 3, 0]
        >>> n = 3
        >>> list(unique_in_window(iterable, n))
        [0, 1, 2, 3, 0]

    The *key* function, if provided, will be used to determine uniqueness:

        >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
        ['a', 'b', 'c', 'd', 'a']

    A sliding window of at most *n* entries is maintained, and an item is
    yielded only if it occurs once in the updated window.

    When `n == 1`, *unique_in_window* is memoryless:

        >>> list(unique_in_window('aab', n=1))
        ['a', 'a', 'b']

    The items in *iterable* must be hashable.

    """
    if n <= 0:
        raise ValueError('n must be greater than 0')

    recent = deque(maxlen=n)
    tally = Counter()

    for item in iterable:
        if len(recent) == n:
            # The oldest entry is about to fall out of the window; drop its
            # key entirely once no copies remain so membership stays exact.
            oldest = recent[0]
            tally[oldest] -= 1
            if not tally[oldest]:
                del tally[oldest]

        marker = item if key is None else key(item)
        if marker not in tally:
            yield item
        tally[marker] += 1
        recent.append(marker)
def duplicates_everseen(iterable, key=None):
    """Yield duplicate elements after their first appearance.

        >>> list(duplicates_everseen('mississippi'))
        ['s', 'i', 's', 's', 'i', 'p', 'i']
        >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
        ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []

    for element in iterable:
        marker = element if key is None else key(element)
        try:
            if marker in hashable_seen:
                yield element
            else:
                hashable_seen.add(marker)
        except TypeError:
            # Unhashable marker: fall back to a linear scan of a list.
            if marker in unhashable_seen:
                yield element
            else:
                unhashable_seen.append(marker)
def duplicates_justseen(iterable, key=None):
    """Yields serially-duplicate elements after their first appearance.

        >>> list(duplicates_justseen('mississippi'))
        ['s', 's', 'p']
        >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
        ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    This function is analogous to :func:`unique_justseen`.

    """
    groups = groupby(iterable, key)

    def group_tails():
        for _, group in groups:
            # The inner loop consumes the first member of the run and then
            # hands over the group iterator, which now holds only the
            # repeats. Once the consumer drains it, the loop ends.
            for _ in group:
                yield group

    return flatten(group_tails())
def classify_unique(iterable, key=None):
    """Classify each element in terms of its uniqueness.

    For each element in the input iterable, return a 3-tuple consisting of:

    1. The element itself
    2. ``False`` if the element is equal to the one preceding it in the input,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
    3. ``False`` if this element has been seen anywhere in the input before,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)

        >>> list(classify_unique('otto'))  # doctest: +NORMALIZE_WHITESPACE
        [('o', True, True),
         ('t', True, True),
         ('t', False, False),
         ('o', True, False)]

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []
    previous = None

    for position, element in enumerate(iterable):
        marker = element if key is None else key(element)

        # The very first element has no predecessor to compare against.
        differs_from_previous = not position or previous != marker
        previous = marker

        try:
            first_occurrence = marker not in hashable_seen
            if first_occurrence:
                hashable_seen.add(marker)
        except TypeError:
            # Unhashable marker: fall back to a linear scan of a list.
            first_occurrence = marker not in unhashable_seen
            if first_occurrence:
                unhashable_seen.append(marker)

        yield element, differs_from_previous, first_occurrence
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Returns both the smallest and largest items from an iterable
    or from two or more arguments.

        >>> minmax([3, 1, 5])
        (1, 5)

        >>> minmax(4, 2, 6)
        (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

        >>> minmax([5, 30], key=str)  # '30' sorts before '5'
        (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

        >>> minmax([], default=(0, 0))
        (0, 0)

    Otherwise ``ValueError`` is raised.

    The input is traversed once, two items at a time: each pair is ordered
    first, then the smaller is compared with the running minimum and the
    larger with the running maximum.

    Note that unlike the builtin ``max`` function, which always returns the
    first item with the maximum value, this function may return another item
    when there are ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    if others:
        candidates = (iterable_or_value, *others)
    else:
        candidates = iterable_or_value

    it = iter(candidates)

    try:
        lo = hi = next(it)
    except StopIteration as exc:
        if default is not _marker:
            return default
        raise ValueError(
            '`minmax()` argument is an empty iterable. '
            'Provide a `default` value to suppress this error.'
        ) from exc

    # The keyed and un-keyed paths are kept separate so the common
    # "key=None" case does no extra bookkeeping.
    if key is None:
        # An odd item count leaves one slot to pad; ``lo`` is a neutral
        # filler because it can never beat the current extremes.
        for x, y in zip_longest(it, it, fillvalue=lo):
            if y < x:
                x, y = y, x
            # x is now the smaller of the pair and y the larger.
            if x < lo:
                lo = x
            if hi < y:
                hi = y
        return lo, hi

    lo_key = hi_key = key(lo)
    for x, y in zip_longest(it, it, fillvalue=lo):
        x_key, y_key = key(x), key(y)
        if y_key < x_key:
            x, y = y, x
            x_key, y_key = y_key, x_key
        if x_key < lo_key:
            lo, lo_key = x, x_key
        if hi_key < y_key:
            hi, hi_key = y, y_key

    return lo, hi
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* with a combined size limited by
    *max_size*.

        >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
        >>> list(constrained_batches(iterable, 10))
        [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If a *max_count* is supplied, the number of items per batch is also
    limited:

        >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
        >>> list(constrained_batches(iterable, 10, max_count = 2))
        [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    If a *get_len* function is supplied, use that instead of :func:`len` to
    determine item size.

    If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
    than *max_size*. Otherwise, allow single items to exceed *max_size*.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    pending = []
    pending_size = 0
    for item in iterable:
        size = get_len(item)
        if strict and size > max_size:
            raise ValueError('item size exceeds maximum size')

        at_count_limit = len(pending) == max_count
        over_size_limit = size + pending_size > max_size
        # Never flush an empty batch: a lone oversized item (strict=False)
        # must still be emitted in a batch of its own.
        if pending and (over_size_limit or at_count_limit):
            yield tuple(pending)
            pending.clear()
            pending_size = 0

        pending.append(item)
        pending_size += size

    if pending:
        yield tuple(pending)
def gray_product(*iterables, repeat=1):
    """Like :func:`itertools.product`, but return tuples in an order such
    that only one element in the generated tuple changes from one iteration
    to the next.

        >>> list(gray_product('AB','CD'))
        [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables. For example, ``gray_product('AB', repeat=3)`` is
    equivalent to ``gray_product('AB', 'AB', 'AB')``.

    This function consumes all of the input iterables before producing output.
    If any of the input iterables have fewer than two items, ``ValueError``
    is raised.

    For information on the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    pools = tuple(map(tuple, iterables)) * repeat
    width = len(pools)
    if any(len(pool) < 2 for pool in pools):
        raise ValueError("each iterable must have two or more items")

    # This is based on "Algorithm H" from section 7.2.1.1, page 20.
    # digits[i] is the index currently selected within pools[i],
    # focus is the array of "focus pointers",
    # direction[i] is +1 or -1: the way digits[i] moves next.
    digits = [0] * width
    focus = list(range(width + 1))
    direction = [1] * width
    while True:
        yield tuple(pool[digit] for pool, digit in zip(pools, digits))
        j = focus[0]
        focus[0] = 0
        if j == width:
            return
        digits[j] += direction[j]
        if digits[j] == 0 or digits[j] == len(pools[j]) - 1:
            # This coordinate hit an end of its pool: reverse it and pass
            # the focus on to the next coordinate.
            direction[j] = -direction[j]
            focus[j] = focus[j + 1]
            focus[j + 1] = j + 1
def partial_product(*iterables, repeat=1):
    """Yields tuples containing one item from each iterator, with subsequent
    tuples changing a single item at a time by advancing each iterator until it
    is exhausted. This sequence guarantees every value in each iterable is
    output at least once without generating all possible combinations.

    This may be useful, for example, when testing an expensive function.

        >>> list(partial_product('AB', 'C', 'DEF'))
        [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables. For example, ``partial_product('AB', repeat=3)`` is
    equivalent to ``partial_product('AB', 'AB', 'AB')``.
    """
    pools = tuple(map(tuple, iterables)) * repeat
    iterators = tuple(map(iter, pools))

    # Seed the working tuple with the first item of every pool; an empty
    # pool means there is nothing to yield at all.
    try:
        current = [next(it) for it in iterators]
    except StopIteration:
        return
    yield tuple(current)

    # Walk the positions left to right, running each iterator to its end
    # while every other position stays fixed.
    for position, it in enumerate(iterators):
        for value in it:
            current[position] = value
            yield tuple(current)
def takewhile_inclusive(predicate, iterable):
    """A variant of :func:`takewhile` that also yields the first element
    for which *predicate* is false.

        >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
        [1, 4, 6]

    :func:`takewhile` would return ``[1, 4]``.
    """
    for item in iterable:
        # Emit first, test afterwards: the failing item is still included.
        yield item
        if predicate(item):
            continue
        return
def outer_product(func, xs, ys, *args, **kwargs):
    """A generalized outer product that applies a binary function to all
    pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
    columns.
    Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.

    Multiplication table:

        >>> from operator import mul
        >>> list(outer_product(mul, range(1, 4), range(1, 6)))
        [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]

    Cross tabulation:

        >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
        >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
        >>> pair_counts = Counter(zip(xs, ys))
        >>> count_rows = lambda x, y: pair_counts[x, y]
        >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
        [(2, 3, 0), (1, 0, 4)]

    Usage with ``*args`` and ``**kwargs``:

        >>> animals = ['cat', 'wolf', 'mouse']
        >>> list(outer_product(min, animals, animals, key=len))
        [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
    """
    # ys is materialized because its length fixes the row width.
    columns = tuple(ys)
    cells = (func(x, y, *args, **kwargs) for x, y in product(xs, columns))
    return batched(cells, n=len(columns))
def iter_suppress(iterable, *exceptions):
    """Yield each of the items from *iterable*. If iterating raises one of
    the given *exceptions*, the exception is swallowed and iteration simply
    ends.

        >>> from itertools import chain
        >>> def breaks_at_five(x):
        ...     while True:
        ...         if x >= 5:
        ...             raise RuntimeError
        ...         yield x
        ...         x += 1
        >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
        >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
        >>> list(chain(it_1, it_2))
        [1, 2, 3, 4, 2, 3, 4]
    """
    try:
        yield from iterable
    except exceptions:
        # One of the listed exception types: finish quietly.
        pass
def filter_map(func, iterable):
    """Apply *func* to every element of *iterable*, yielding only the
    results that are not ``None``.

        >>> elems = ['1', 'a', '2', 'b', '3']
        >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
        [1, 2, 3]
    """
    for result in map(func, iterable):
        # Only ``None`` is dropped; falsy results such as 0 or '' are kept.
        if result is not None:
            yield result
def powerset_of_sets(iterable, *, baseset=set):
    """Yields all possible subsets of the iterable.

        >>> list(powerset_of_sets([1, 2, 3]))  # doctest: +SKIP
        [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
        >>> list(powerset_of_sets([1, 1, 0]))  # doctest: +SKIP
        [set(), {1}, {0}, {0, 1}]

    :func:`powerset_of_sets` takes care to minimize the number
    of hash operations performed.

    The *baseset* parameter determines what kind of sets are
    constructed, either *set* or *frozenset*.
    """
    # One single-item frozenset per distinct input value; dict.fromkeys
    # removes repeats while keeping first-seen order.
    singletons = tuple(dict.fromkeys(map(frozenset, zip(iterable))))
    union = baseset().union
    subset_sizes = range(len(singletons) + 1)
    return chain.from_iterable(
        (union(*members) for members in combinations(singletons, size))
        for size in subset_sizes
    )
def join_mappings(**field_to_map):
    """
    Joins multiple mappings together using their common keys.

    Each keyword names a field and supplies the mapping holding that
    field's values. The result maps every key to a dict of
    ``{field_name: value}`` entries.

    >>> user_scores = {'elliot': 50, 'claris': 60}
    >>> user_times = {'elliot': 30, 'claris': 40}
    >>> join_mappings(score=user_scores, time=user_times)
    {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    joined = {}

    for field_name, mapping in field_to_map.items():
        for key, value in mapping.items():
            # Create the per-key record the first time the key is seen.
            joined.setdefault(key, {})[field_name] = value

    return joined
def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.

    Both vectors are traversed more than once, so they must be
    re-iterable (for example lists).
    """
    real = attrgetter('real')
    imag = attrgetter('imag')

    # Re(sum(a * b)) = sum(a.real * b.real) + sum(-a.imag * b.imag)
    real_part = _fsumprod(
        chain(map(real, v1), map(neg, map(imag, v1))),
        chain(map(real, v2), map(imag, v2)),
    )
    # Im(sum(a * b)) = sum(a.real * b.imag) + sum(a.imag * b.real)
    imag_part = _fsumprod(
        chain(map(real, v1), map(imag, v1)),
        chain(map(imag, v2), map(real, v2)),
    )
    return complex(real_part, imag_part)
def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the corresponding transformed output vector.

        >>> import cmath
        >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
        >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
        >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
        >>> all(map(cmath.isclose, dft(xarr), Xarr))
        True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`idft` for the inverse Discrete Fourier Transform.
    """
    size = len(xarr)
    # The coefficients repeat with period ``size``, so they are computed
    # once and then looked up by index modulo ``size``.
    twiddles = [e ** (n / size * tau * -1j) for n in range(size)]
    for k in range(size):
        row = [twiddles[(k * n) % size] for n in range(size)]
        yield _complex_sumprod(xarr, row)
def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the corresponding
    inverse-transformed output vector.

        >>> import cmath
        >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
        >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
        >>> all(map(cmath.isclose, idft(Xarr), xarr))
        True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`dft` for the Discrete Fourier Transform.
    """
    size = len(Xarr)
    # Same lookup table as the forward transform but with a positive
    # exponent; each output component is divided by ``size``.
    twiddles = [e ** (n / size * tau * 1j) for n in range(size)]
    for k in range(size):
        row = [twiddles[(k * n) % size] for n in range(size)]
        yield _complex_sumprod(Xarr, row) / size
def doublestarmap(func, iterable):
    """Call *func* once per item of *iterable*, passing the item's
    key/value pairs as keyword arguments.

    The difference between :func:`itertools.starmap` and :func:`doublestarmap`
    parallels the distinction between ``func(*a)`` and ``func(**a)``.

        >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
        >>> list(doublestarmap(lambda a, b: a + b, iterable))
        [3, 100]

    ``TypeError`` will be raised if *func*'s signature doesn't match the
    mapping contained in *iterable* or if *iterable* does not contain mappings.
    """
    for kwargs in iterable:
        yield func(**kwargs)
5216def _nth_prime_bounds(n):
5217 """Bounds for the nth prime (counting from 1): lb < p_n < ub."""
5218 # At and above 688,383, the lb/ub spread is under 0.003 * p_n.
5220 if n < 1:
5221 raise ValueError
5223 if n < 6:
5224 return (n, 2.25 * n)
5226 # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
5227 upper_bound = n * log(n * log(n))
5228 lower_bound = upper_bound - n
5229 if n >= 688_383:
5230 upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))
5232 return lower_bound, upper_bound
def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

        >>> nth_prime(0)
        2
        >>> nth_prime(100)
        547

    If *approximate* is set to True, will return a prime close
    to the nth prime. The estimation is much faster than computing
    an exact result.

        >>> nth_prime(200_000_000, approximate=True)  # Exact result is 4222234763
        4217820427

    """
    # The bounds helper counts primes from 1, hence the shift by one.
    lower, upper = _nth_prime_bounds(n + 1)

    if approximate and n > 1_000_000:
        # Start at the midpoint of the bounds, forced odd, and return the
        # first prime found among the odd numbers from there upwards.
        start = floor((lower + upper) / 2) | 1
        return first_true(count(start, step=2), pred=is_prime)

    # Exact answer: sieve up to the upper bound and pick the nth prime.
    return nth(sieve(ceil(upper)), n)
def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

        >>> argmin('efghabcdijkl')
        4
        >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
        3

    For example, look up a label corresponding to the position
    of a value that minimizes a cost function::

        >>> def cost(x):
        ...     "Days for a wound to heal given a subject's age."
        ...     return x**2 - 20*x + 150
        ...
        >>> labels =  ['homer', 'marge', 'bart', 'lisa', 'maggie']
        >>> ages =    [  35,      30,      10,      9,      1    ]

        # Fastest healing family member
        >>> labels[argmin(ages, key=cost)]
        'bart'

        # Age with fastest healing
        >>> min(ages, key=cost)
        10

    """
    values = iterable if key is None else map(key, iterable)
    # min() keeps the first of equal candidates, which gives the "first
    # occurrence" guarantee.
    index, _ = min(enumerate(values), key=itemgetter(1))
    return index
def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

        >>> argmax('abcdefghabcd')
        7
        >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
        3

    For example, identify the best machine learning model::

        >>> models =   ['svm', 'random forest', 'knn', 'naïve bayes']
        >>> accuracy = [  68,        61,          84,       72      ]

        # Most accurate model
        >>> models[argmax(accuracy)]
        'knn'

        # Best accuracy
        >>> max(accuracy)
        84

    """
    values = iterable if key is None else map(key, iterable)
    # max() keeps the first of equal candidates, which gives the "first
    # occurrence" guarantee.
    index, _ = max(enumerate(values), key=itemgetter(1))
    return index
5322def _extract_monotonic(iterator, indices):
5323 'Non-decreasing indices, lazily consumed'
5324 num_read = 0
5325 for index in indices:
5326 advance = index - num_read
5327 try:
5328 value = next(islice(iterator, advance, None))
5329 except ValueError:
5330 if advance != -1 or index < 0:
5331 raise ValueError(f'Invalid index: {index}') from None
5332 except StopIteration:
5333 raise IndexError(index) from None
5334 else:
5335 num_read += advance + 1
5336 yield value
5339def _extract_buffered(iterator, index_and_position):
5340 'Arbitrary index order, greedily consumed'
5341 buffer = {}
5342 iterator_position = -1
5343 next_to_emit = 0
5345 for index, order in index_and_position:
5346 advance = index - iterator_position
5347 if advance:
5348 try:
5349 value = next(islice(iterator, advance - 1, None))
5350 except StopIteration:
5351 raise IndexError(index) from None
5352 iterator_position = index
5354 buffer[order] = value
5356 while next_to_emit in buffer:
5357 yield buffer.pop(next_to_emit)
5358 next_to_emit += 1
def extract(iterable, indices, *, monotonic=False):
    """Yield values at the specified indices.

    Example:

        >>> data = 'abcdefghijklmnopqrstuvwxyz'
        >>> list(extract(data, [7, 4, 11, 11, 14]))
        ['h', 'e', 'l', 'l', 'o']

    The *iterable* is consumed lazily and can be infinite.

    When *monotonic* is false, the *indices* are consumed immediately
    and must be finite. When *monotonic* is true, *indices* are consumed
    lazily and can be infinite but must be non-decreasing.

    Raises ``IndexError`` if an index lies beyond the iterable.
    Raises ``ValueError`` for a negative index or for a decreasing
    index when *monotonic* is true.
    """

    iterator = iter(iterable)
    indices = iter(indices)

    if monotonic:
        return _extract_monotonic(iterator, indices)

    # Pair each index with its output position, then sort by index so the
    # source only ever has to be read forwards.
    ordered = sorted(
        (index, position) for position, index in enumerate(indices)
    )
    if ordered and ordered[0][0] < 0:
        raise ValueError('Indices must be non-negative')
    return _extract_buffered(iterator, ordered)
class serialize:
    """Wrap a non-concurrent iterator with a lock to enforce sequential access.

    Every call to ``__next__``, ``send``, ``throw`` and ``close`` is made
    while holding one non-reentrant lock, allowing iterator and generator
    instances to be shared by multiple consumer threads.
    """

    __slots__ = ('_iterator', '_lock')

    def __init__(self, iterable):
        self._lock = allocate_lock()
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        with self._lock:
            item = next(self._iterator)
        return item

    def send(self, value, /):
        """Send a value to a generator.

        Raises AttributeError if not a generator.
        """
        with self._lock:
            result = self._iterator.send(value)
        return result

    def throw(self, *args):
        """Call throw() on a generator.

        Raises AttributeError if not a generator.
        """
        with self._lock:
            result = self._iterator.throw(*args)
        return result

    def close(self):
        """Call close() on a generator.

        Raises AttributeError if not a generator.
        """
        with self._lock:
            result = self._iterator.close()
        return result
def synchronized(func):
    """Wrap an iterator-returning callable to make its iterators thread-safe.

    Existing itertools and more-itertools can be wrapped so that their
    iterator instances are serialized.

    For example, ``itertools.count`` does not make thread-safe instances,
    but that is easily fixed with::

        atomic_counter = synchronized(itertools.count)

    Can also be used as a decorator for generator functions definitions
    so that the generator instances are serialized::

        @synchronized
        def enumerate_and_timestamp(iterable):
            for count, value in enumerate(iterable):
                yield count, time_ns(), value

    """

    @wraps(func)
    def inner(*args, **kwargs):
        # Each call gets a fresh iterator wrapped in its own serialize().
        return serialize(func(*args, **kwargs))

    return inner
def concurrent_tee(iterable, n=2):
    """Variant of itertools.tee() but with guaranteed threading semantics.

    Takes a non-threadsafe iterator as an input and creates concurrent
    tee objects for other threads to have reliable independent copies of
    the data stream.

    The new iterators are only thread-safe if consumed within a single thread.
    To share just one of the new iterators across multiple threads, wrap it
    with :func:`serialize`.

    Returns a tuple of *n* iterators; ``ValueError`` is raised if *n* is
    negative.
    """

    if n < 0:
        raise ValueError
    if n == 0:
        return ()
    # The first tee wraps the source; the others are built from that first
    # tee so they all share its source iterator and lock.
    first = _concurrent_tee(iterable)
    return (first, *(_concurrent_tee(first) for _ in range(n - 1)))
5491class _concurrent_tee:
5492 __slots__ = ('iterator', 'link', 'lock')
5494 def __init__(self, iterable):
5495 if isinstance(iterable, _concurrent_tee):
5496 self.iterator = iterable.iterator
5497 self.link = iterable.link
5498 self.lock = iterable.lock
5499 else:
5500 self.iterator = iter(iterable)
5501 self.link = [None, None]
5502 self.lock = allocate_lock()
5504 def __iter__(self):
5505 return self
5507 def __next__(self):
5508 link = self.link
5509 if link[1] is None:
5510 with self.lock:
5511 if link[1] is None:
5512 link[0] = next(self.iterator)
5513 link[1] = [None, None]
5514 value, self.link = link
5515 return value
def subfactorial(n):
    """Number of permutations of *n* elements with no fixed points.

    The :func:`subfactorial` function computes the length of
    :func:`derangements`. For example, there are 1,854 ways to
    rearrange the letters in word "epsilon" without leaving any
    letter in its original position:

        >>> from more_itertools import derangements, ilen
        >>> ilen(derangements('epsilon'))
        1854
        >>> subfactorial(len('epsilon'))
        1854

    Reference: https://oeis.org/A000166

    ``ValueError`` is raised if *n* is negative.
    """
    if n < 0:
        raise ValueError
    # Recurrence !i = i * !(i - 1) + (-1)**i with !0 = 1. The loop starts
    # at i = 1, so the first correction term is -1.
    result, sign = 1, -1
    for i in range(1, n + 1):
        result = result * i + sign
        sign = -sign
    return result