Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/more_itertools/more.py: 19%
Shortcuts on this page:
r m x — toggle line displays
j k — next/prev highlighted chunk
0 (zero) — top of page
1 (one) — first highlighted chunk
1import math
3from collections import Counter, defaultdict, deque
4from collections.abc import Sequence
5from contextlib import suppress
6from functools import cached_property, partial, reduce, wraps
7from heapq import heapify, heapreplace
8from itertools import (
9 chain,
10 combinations,
11 compress,
12 count,
13 cycle,
14 dropwhile,
15 groupby,
16 islice,
17 permutations,
18 repeat,
19 starmap,
20 takewhile,
21 tee,
22 zip_longest,
23 product,
24)
25from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
26from math import ceil
27from queue import Empty, Queue
28from random import random, randrange, shuffle, uniform
29from operator import (
30 attrgetter,
31 is_not,
32 itemgetter,
33 lt,
34 mul,
35 neg,
36 sub,
37 gt,
38)
39from sys import maxsize
40from time import monotonic
42from .recipes import (
43 _marker,
44 consume,
45 first_true,
46 flatten,
47 is_prime,
48 nth,
49 powerset,
50 sieve,
51 take,
52 unique_everseen,
53 all_equal,
54 batched,
55)
# Public API of this module. Kept strictly sorted (ASCII order, so class
# names with leading capitals come first) so that additions are easy to
# review and duplicates are easy to spot.
__all__ = [
    'AbortThread',
    'SequenceView',
    'adjacent',
    'all_unique',
    'always_iterable',
    'always_reversible',
    'argmax',
    'argmin',
    'bucket',
    'callback_iter',
    'chunked',
    'chunked_even',
    'circular_shifts',
    'classify_unique',
    'collapse',
    'combination_index',
    'combination_with_replacement_index',
    'consecutive_groups',
    'constrained_batches',
    'consumer',
    'count_cycle',
    'countable',
    'derangements',
    'dft',
    'difference',
    'distinct_combinations',
    'distinct_permutations',
    'distribute',
    'divide',
    'doublestarmap',
    'duplicates_everseen',
    'duplicates_justseen',
    'exactly_n',
    'extract',
    'filter_except',
    'filter_map',
    'first',
    'gray_product',
    'groupby_transform',
    'ichunked',
    'idft',
    'iequals',
    'ilen',
    'interleave',
    'interleave_evenly',
    'interleave_longest',
    'interleave_randomly',
    'intersperse',
    'is_sorted',
    'islice_extended',
    'iter_suppress',
    'iterate',
    'join_mappings',
    'last',
    'locate',
    'longest_common_prefix',
    'lstrip',
    'make_decorator',
    'map_except',
    'map_if',
    'map_reduce',
    'mark_ends',
    'minmax',
    'nth_combination_with_replacement',
    'nth_or_last',
    'nth_permutation',
    'nth_prime',
    'nth_product',
    'numeric_range',
    'one',
    'only',
    'outer_product',
    'padded',
    'partial_product',
    'partitions',
    'peekable',
    'permutation_index',
    'powerset_of_sets',
    'product_index',
    'raise_',
    'repeat_each',
    'repeat_last',
    'replace',
    'rlocate',
    'rstrip',
    'run_length',
    'sample',
    'seekable',
    'set_partitions',
    'side_effect',
    'sliced',
    'sort_together',
    'split_after',
    'split_at',
    'split_before',
    'split_into',
    'split_when',
    'spy',
    'stagger',
    'strictly_n',
    'strip',
    'substrings',
    'substrings_indexes',
    'takewhile_inclusive',
    'time_limited',
    'unique_in_window',
    'unique_to_each',
    'unzip',
    'value_chain',
    'windowed',
    'windowed_complete',
    'with_iter',
    'zip_broadcast',
    'zip_offset',
]
# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _fsumprod
except ImportError:  # pragma: no cover
    # Extended precision algorithms from T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8). Code: mul12()
    # NOTE: the exact order of the floating-point operations below is what
    # makes these routines exact; do not "simplify" the arithmetic.

    def dl_split(x: float):
        "Split a float into two half-precision components."
        t = x * 134217729.0  # Veltkamp constant = 2.0 ** 27 + 1
        hi = t - (t - x)
        lo = x - hi
        # hi + lo == x exactly, and each part fits in half the mantissa.
        return hi, lo

    def dl_mul(x, y):
        "Lossless multiplication."
        xx_hi, xx_lo = dl_split(x)
        yy_hi, yy_lo = dl_split(y)
        p = xx_hi * yy_hi
        q = xx_hi * yy_lo + xx_lo * yy_hi
        z = p + q
        zz = p - z + q + xx_lo * yy_lo
        # z + zz == x * y exactly: z is the rounded product, zz the error.
        return z, zz

    def _fsumprod(p, q):
        # Sum all of the exact (product, error) pairs with fsum so the
        # final rounding happens only once.
        return fsum(chain.from_iterable(map(dl_mul, p, q)))
def chunked(iterable, n, strict=False):
    """Break *iterable* into lists of length *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
    [[1, 2, 3], [4, 5, 6]]

    By default, the last yielded list will have fewer than *n* elements
    if the length of *iterable* is not divisible by *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
    [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    If *strict* is ``True`` and the length of *iterable* is not divisible
    by *n*, ``ValueError`` is raised before the last (short) list would
    be yielded.
    """
    # take() returns [] once the source is exhausted; the two-argument
    # iter() uses that as the sentinel to stop.
    chunks = iter(partial(take, n, iter(iterable)), [])
    if not strict:
        return chunks

    if n is None:
        raise ValueError('n must not be None when using strict mode.')

    def validated():
        # Re-yield each chunk, rejecting any that isn't exactly n long.
        for chunk in chunks:
            if len(chunk) != n:
                raise ValueError('iterable is not divisible by n.')
            yield chunk

    return validated()
def first(iterable, default=_marker):
    """Return the first item of *iterable*, or *default* if *iterable* is
    empty.

    >>> first([0, 1, 2, 3])
    0
    >>> first([], 'some default')
    'some default'

    With no items and no *default*, ``ValueError`` is raised.

    Useful for pulling one arbitrary value out of a generator of
    expensive-to-retrieve values; marginally shorter than
    ``next(iter(iterable), default)``.
    """
    # The loop body runs at most once: we return on the first item.
    for value in iterable:
        return value
    if default is not _marker:
        return default
    raise ValueError(
        'first() was called on an empty iterable, '
        'and no default value was provided.'
    )
def last(iterable, default=_marker):
    """Return the last item of *iterable*, or *default* if *iterable* is
    empty.

    >>> last([0, 1, 2, 3])
    3
    >>> last([], 'some default')
    'some default'

    With no items and no *default*, ``ValueError`` is raised.
    """
    try:
        # Fast paths first: direct indexing, then reverse iteration.
        if isinstance(iterable, Sequence):
            return iterable[-1]
        # Work around https://bugs.python.org/issue38525
        if getattr(iterable, '__reversed__', None):
            return next(reversed(iterable))
        # Otherwise consume everything, retaining only the final item.
        return deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        if default is not _marker:
            return default
        raise ValueError(
            'last() was called on an empty iterable, '
            'and no default value was provided.'
        )
def nth_or_last(iterable, n, default=_marker):
    """Return the nth or the last item of *iterable*,
    or *default* if *iterable* is empty.

    >>> nth_or_last([0, 1, 2, 3], 2)
    2
    >>> nth_or_last([0, 1], 2)
    1
    >>> nth_or_last([], 0, 'some default')
    'some default'

    With no items and no *default*, ``ValueError`` is raised.
    """
    # Truncate to the first n + 1 items; the last of those is the answer.
    head = islice(iterable, n + 1)
    return last(head, default=default)
class peekable:
    """Wrap an iterator to allow lookahead and prepending elements.

    Call :meth:`peek` on the result to get the value that will be returned
    by :func:`next`. This won't advance the iterator:

    >>> p = peekable(['a', 'b'])
    >>> p.peek()
    'a'
    >>> next(p)
    'a'

    Pass :meth:`peek` a default value to return that instead of raising
    ``StopIteration`` when the iterator is exhausted.

    >>> p = peekable([])
    >>> p.peek('hi')
    'hi'

    peekables also offer a :meth:`prepend` method, which "inserts" items
    at the head of the iterable:

    >>> p = peekable([1, 2, 3])
    >>> p.prepend(10, 11, 12)
    >>> next(p)
    10
    >>> p.peek()
    11
    >>> list(p)
    [11, 12, 1, 2, 3]

    peekables can be indexed. Index 0 is the item that will be returned by
    :func:`next`, index 1 is the item after that, and so on:
    The values up to the given index will be cached.

    >>> p = peekable(['a', 'b', 'c', 'd'])
    >>> p[0]
    'a'
    >>> p[1]
    'b'
    >>> next(p)
    'a'

    Negative indexes are supported, but be aware that they will cache the
    remaining items in the source iterator, which may require significant
    storage.

    To check whether a peekable is exhausted, check its truth value:

    >>> p = peekable(['a', 'b'])
    >>> if p:  # peekable has items
    ...     list(p)
    ['a', 'b']
    >>> if not p:  # peekable is exhausted
    ...     list(p)
    []

    """

    def __init__(self, iterable):
        # _cache holds items that have been pulled from _it but not yet
        # returned by __next__ (from peeking, indexing, or prepending).
        self._it = iter(iterable)
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        # Truthy iff at least one more item is available.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        """Return the item that will be next returned from ``next()``.

        Return ``default`` if there are no items left. If ``default`` is not
        provided, raise ``StopIteration``.

        """
        if not self._cache:
            try:
                self._cache.append(next(self._it))
            except StopIteration:
                if default is _marker:
                    raise
                return default
        return self._cache[0]

    def prepend(self, *items):
        """Stack up items to be the next ones returned from ``next()`` or
        ``self.peek()``. The items will be returned in
        first in, first out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        It is possible, by prepending items, to "resurrect" a peekable that
        previously raised ``StopIteration``.

            >>> p = peekable([])
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration
            >>> p.prepend(1)
            >>> next(p)
            1
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration

        """
        # reversed() + extendleft() preserves the given order at the front.
        self._cache.extendleft(reversed(items))

    def __next__(self):
        # Serve cached (peeked/prepended) items before the source iterator.
        if self._cache:
            return self._cache.popleft()

        return next(self._it)

    def _get_slice(self, index):
        # Normalize the slice's arguments
        step = 1 if (index.step is None) else index.step
        if step > 0:
            start = 0 if (index.start is None) else index.start
            stop = maxsize if (index.stop is None) else index.stop
        elif step < 0:
            start = -1 if (index.start is None) else index.start
            stop = (-maxsize - 1) if (index.stop is None) else index.stop
        else:
            raise ValueError('slice step cannot be zero')

        # If either the start or stop index is negative, we'll need to cache
        # the rest of the iterable in order to slice from the right side.
        if (start < 0) or (stop < 0):
            self._cache.extend(self._it)
        # Otherwise we'll need to find the rightmost index and cache to that
        # point.
        else:
            n = min(max(start, stop) + 1, maxsize)
            cache_len = len(self._cache)
            if n >= cache_len:
                self._cache.extend(islice(self._it, n - cache_len))

        # Delegate the actual slicing semantics to list.
        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        # A negative index requires caching everything; a positive index
        # only requires caching up to that position.
        cache_len = len(self._cache)
        if index < 0:
            self._cache.extend(self._it)
        elif index >= cache_len:
            self._cache.extend(islice(self._it, index + 1 - cache_len))

        return self._cache[index]
def consumer(func):
    """Decorator that automatically advances a PEP-342-style "reverse iterator"
    to its first yield point so you don't have to call ``next()`` on it
    manually.

    >>> @consumer
    ... def tally():
    ...     i = 0
    ...     while True:
    ...         print('Thing number %s is %s.' % (i, (yield)))
    ...         i += 1
    ...
    >>> t = tally()
    >>> t.send('red')
    Thing number 0 is red.
    >>> t.send('fish')
    Thing number 1 is fish.

    Without the decorator, you would have to call ``next(t)`` before
    ``t.send()`` could be used.
    """

    @wraps(func)
    def primed(*args, **kwargs):
        coroutine = func(*args, **kwargs)
        # Advance to the first yield so send() works immediately.
        next(coroutine)
        return coroutine

    return primed
def ilen(iterable):
    """Return the number of items in *iterable*.

    For example, there are 168 prime numbers below 1,000:

    >>> ilen(sieve(1000))
    168

    Equivalent to, but faster than::

        def ilen(iterable):
            count = 0
            for _ in iterable:
                count += 1
            return count

    This fully consumes the iterable, so handle with care.

    """
    # This is the "most beautiful of the fast variants" of this function.
    # If you think you can improve on it, please ensure that your version
    # is both 10x faster and 10x more beautiful.
    # How it works: zip(iterable) wraps each item in a 1-tuple, which is
    # always truthy (even when the item itself is falsy), so compress()
    # emits one 1 from repeat(1) per input item and sum() counts them —
    # the entire loop runs at C speed.
    return sum(compress(repeat(1), zip(iterable)))
def iterate(func, start):
    """Return ``start``, ``func(start)``, ``func(func(start))``, ...

    Produces an infinite iterator. To add a stopping condition,
    use :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`:.

    >>> take(10, iterate(lambda x: 2*x, 1))
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2
    >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10)))
    [10, 5, 16, 8, 4, 2, 1]

    """
    current = start
    try:
        while True:
            yield current
            current = func(current)
    except StopIteration:
        # A func that raises StopIteration ends the stream cleanly.
        return
def with_iter(context_manager):
    """Wrap an iterable in a ``with`` statement, so it closes once exhausted.

    For example, this will close the file when the iterator is exhausted::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Any context manager which returns an iterable is a candidate for
    ``with_iter``.
    """
    # The with block stays open for the generator's whole lifetime, so
    # __exit__ runs when iteration finishes (or the generator is closed).
    with context_manager as iterable:
        for item in iterable:
            yield item
def one(iterable, too_short=None, too_long=None):
    """Return the only item in *iterable*; raise an exception if it is empty
    or holds more than one item.

    >>> one(['item'])
    'item'

    Useful for e.g. asserting a database query returned a single row.

    An empty iterable raises ``ValueError`` (or *too_short*, if given):

    >>> one([])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: too few items in iterable (expected 1)
    >>> one([], too_short=IndexError('too few items'))
    ... # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    IndexError: too few items

    More than one item raises ``ValueError`` (or *too_long*, if given):

    >>> one(['too', 'many'])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 'too',
    'many', and perhaps more.
    >>> one(['too', 'many'], too_long=RuntimeError)
    ... # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    RuntimeError

    Note that :func:`one` attempts to advance *iterable* twice to ensure
    there is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.
    """
    iterator = iter(iterable)
    # Outer loop runs at most once; inner loop checks for a second item.
    for item in iterator:
        for extra in iterator:
            raise too_long or ValueError(
                f'Expected exactly one item in iterable, but got {item!r}, '
                f'{extra!r}, and perhaps more.'
            )
        return item
    raise too_short or ValueError('too few items in iterable (expected 1)')
def raise_(exception, *args):
    """Construct *exception* with *args* and raise it (usable in lambdas)."""
    instance = exception(*args)
    raise instance
def strictly_n(iterable, n, too_short=None, too_long=None):
    """Validate that *iterable* has exactly *n* items and yield them if it
    does. If it has fewer than *n* items, call function *too_short* with the
    actual number of items. If it has more than *n* items, call function
    *too_long* with the number ``n + 1``.

    >>> iterable = ['a', 'b', 'c', 'd']
    >>> n = 4
    >>> list(strictly_n(iterable, n))
    ['a', 'b', 'c', 'd']

    Note that the returned iterable must be consumed in order for the check
    to be made.

    By default, *too_short* and *too_long* are functions that raise
    ``ValueError``.

    >>> list(strictly_n('ab', 3))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: too few items in iterable (got 2)

    >>> list(strictly_n('abc', 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: too many items in iterable (got at least 3)

    You can instead supply functions that do something else.
    *too_short* will be called with the number of items in *iterable*.
    *too_long* will be called with `n + 1`.

    >>> def too_short(item_count):
    ...     raise RuntimeError
    >>> it = strictly_n('abcd', 6, too_short=too_short)
    >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    RuntimeError

    >>> def too_long(item_count):
    ...     print('The boss is going to hear about this')
    >>> it = strictly_n('abcdef', 4, too_long=too_long)
    >>> list(it)
    The boss is going to hear about this
    ['a', 'b', 'c', 'd']

    """
    if too_short is None:

        def too_short(item_count):
            raise_(
                ValueError,
                f'Too few items in iterable (got {item_count})',
            )

    if too_long is None:

        def too_long(item_count):
            raise_(
                ValueError,
                f'Too many items in iterable (got at least {item_count})',
            )

    iterator = iter(iterable)

    # Pass through at most n items, counting as we go.
    yielded = 0
    for item in islice(iterator, n):
        yield item
        yielded += 1

    if yielded < n:
        too_short(yielded)
        return

    # Probing for one more item tells us whether there were too many.
    for item in iterator:
        too_long(n + 1)
        return
def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

    >>> sorted(distinct_permutations([1, 0, 1]))
    [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

    >>> sorted(distinct_permutations([1, 0, 1], r=2))
    [(0, 1), (1, 0), (1, 1)]
    >>> sorted(distinct_permutations(range(3), r=2))
    [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

    >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
    [
        (1, True, '3'),
        (1, '3', True),
        ('3', 1, True)
    ]
    >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
    [
        (1, 2, '3'),
        (1, '3', 2),
        (2, 1, '3'),
        (2, '3', 1),
        ('3', 1, 2),
        ('3', 2, 1)
    ]
    """

    # Algorithm: https://w.wiki/Qai
    # Steps a sorted list A in-place to the next lexicographic permutation,
    # which visits each distinct arrangement exactly once.
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than j such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above
    def _partial(A, r):
        # Split A into the first r items and the last r items
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    try:
        items.sort()
        sortable = True
    except TypeError:
        sortable = False

        # Unsortable items: permute *indices* instead, where equal items
        # share the index of their first occurrence. cycle() rotates
        # through the equal originals so each output position gets one.
        indices_dict = defaultdict(list)

        for item in items:
            indices_dict[items.index(item)].append(item)

        indices = [items.index(item) for item in items]
        indices.sort()

        equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

        def permuted_items(permuted_indices):
            return tuple(
                next(equivalent_items[index]) for index in permuted_indices
            )

    size = len(items)
    if r is None:
        r = size

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if 0 < r <= size:
        if sortable:
            return algorithm(items)
        else:
            # Map permuted index tuples back to the original items.
            return (
                permuted_items(permuted_indices)
                for permuted_indices in algorithm(indices)
            )

    # r == 0 yields the single empty permutation; r out of range yields
    # nothing (matching itertools.permutations).
    return iter(() if r else ((),))
854def derangements(iterable, r=None):
855 """Yield successive derangements of the elements in *iterable*.
857 A derangement is a permutation in which no element appears at its original
858 index. In other words, a derangement is a permutation that has no fixed points.
860 Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
861 The code below outputs all of the different ways to assign gift recipients
862 such that nobody is assigned to himself or herself:
864 >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
865 ... print(', '.join(d))
866 Bob, Alice, Dave, Carol
867 Bob, Carol, Dave, Alice
868 Bob, Dave, Alice, Carol
869 Carol, Alice, Dave, Bob
870 Carol, Dave, Alice, Bob
871 Carol, Dave, Bob, Alice
872 Dave, Alice, Bob, Carol
873 Dave, Carol, Alice, Bob
874 Dave, Carol, Bob, Alice
876 If *r* is given, only the *r*-length derangements are yielded.
878 >>> sorted(derangements(range(3), 2))
879 [(1, 0), (1, 2), (2, 0)]
880 >>> sorted(derangements([0, 2, 3], 2))
881 [(2, 0), (2, 3), (3, 0)]
883 Elements are treated as unique based on their position, not on their value.
885 Consider the Secret Santa example with two *different* people who have
886 the *same* name. Then there are two valid gift assignments even though
887 it might appear that a person is assigned to themselves:
889 >>> names = ['Alice', 'Bob', 'Bob']
890 >>> list(derangements(names))
891 [('Bob', 'Bob', 'Alice'), ('Bob', 'Alice', 'Bob')]
893 To avoid confusion, make the inputs distinct:
895 >>> deduped = [f'{name}{index}' for index, name in enumerate(names)]
896 >>> list(derangements(deduped))
897 [('Bob1', 'Bob2', 'Alice0'), ('Bob2', 'Alice0', 'Bob1')]
899 The number of derangements of a set of size *n* is known as the
900 "subfactorial of n". For n > 0, the subfactorial is:
901 ``round(math.factorial(n) / math.e)``.
903 References:
905 * Article: https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics
906 * Sizes: https://oeis.org/A000166
907 """
908 xs = tuple(iterable)
909 ys = tuple(range(len(xs)))
910 return compress(
911 permutations(xs, r=r),
912 map(all, map(map, repeat(is_not), repeat(ys), permutations(ys, r=r))),
913 )
def intersperse(e, iterable, n=1):
    """Intersperse filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

    >>> list(intersperse('!', [1, 2, 3, 4, 5]))
    [1, '!', 2, '!', 3, '!', 4, '!', 5]

    >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
    [1, 2, None, 3, 4, None, 5]

    """
    if n == 0:
        raise ValueError('n must be > 0')
    if n == 1:
        # interleave gives e, x_0, e, x_1, ...; drop the leading filler.
        return islice(interleave(repeat(e), iterable), 1, None)
    # For n > 1: chunk the items, weave [e] between chunks, drop the
    # leading filler chunk, then flatten back to single items.
    spacers = repeat([e])
    groups = chunked(iterable, n)
    return flatten(islice(interleave(spacers, groups), 1, None))
def unique_to_each(*iterables):
    """Return the elements from each of the input iterables that aren't in
    the other input iterables.

    For example, suppose you have a set of packages, each with a set of
    dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    If you remove one package, which dependencies can also be removed?
    ``A`` is only needed for ``pkg_1``, ``C`` only for ``pkg_2``, and ``D``
    only for ``pkg_3``:

    >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
    [['A'], ['C'], ['D']]

    If there are duplicates in one input iterable that aren't in the others
    they will be duplicated in the output. Input order is preserved:

    >>> unique_to_each("mississippi", "missouri")
    [['p', 'p'], ['o', 'u', 'r']]

    It is assumed that the elements of each iterable are hashable.

    """
    materialized = [list(it) for it in iterables]
    # Count in how many *distinct* inputs each element occurs.
    occurrence = Counter()
    for seq in materialized:
        occurrence.update(set(seq))
    singles = {element for element, count in occurrence.items() if count == 1}
    # Keep only single-input elements, preserving each input's order.
    return [[element for element in seq if element in singles] for seq in materialized]
def windowed(seq, n, fillvalue=None, step=1):
    """Return a sliding window of width *n* over the given iterable.

    >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
    >>> list(all_windows)
    [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    When the window is larger than the iterable, *fillvalue* is used in place
    of missing values:

    >>> list(windowed([1, 2, 3], 4))
    [(1, 2, 3, None)]

    Each window will advance in increments of *step*:

    >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
    [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler items
    to the left:

    >>> iterable = [1, 2, 3, 4]
    >>> n = 3
    >>> padding = [None] * (n - 1)
    >>> list(windowed(chain(padding, iterable), 3))
    [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    if n < 0:
        raise ValueError('n must be >= 0')
    if n == 0:
        yield ()
        return
    if step < 1:
        raise ValueError('step must be >= 1')

    iterator = iter(seq)

    # Generate first window
    window = deque(islice(iterator, n), maxlen=n)

    # Deal with the first window not being full
    if not window:
        return
    if len(window) < n:
        yield tuple(window) + ((fillvalue,) * (n - len(window)))
        return
    yield tuple(window)

    # Create the filler for the next windows. The padding ensures
    # we have just enough elements to fill the last window.
    padding = (fillvalue,) * (n - 1 if step >= n else step - 1)
    # Each element drawn from `filler` appends one source item into the
    # deque as a side effect (old items fall off the left, maxlen=n).
    filler = map(window.append, chain(iterator, padding))

    # Generate the rest of the windows
    # islice advances `filler` — and therefore the deque — by `step`
    # items per iteration, then we snapshot the current window.
    for _ in islice(filler, step - 1, None, step):
        yield tuple(window)
def substrings(iterable):
    """Yield all of the substrings of *iterable*.

    >>> [''.join(s) for s in substrings('more')]
    ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Note that non-string iterables can also be subdivided.

    >>> list(substrings([0, 1, 2]))
    [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    """
    # First pass: yield length-1 substrings while collecting the items.
    collected = []
    for element in iterable:
        collected.append(element)
        yield (element,)
    collected = tuple(collected)
    total = len(collected)

    # Then every longer contiguous run, shortest lengths first.
    for length in range(2, total + 1):
        for start in range(total - length + 1):
            yield collected[start : start + length]
def substrings_indexes(seq, reverse=False):
    """Yield all substrings and their positions in *seq*

    The items yielded will be a tuple of the form ``(substr, i, j)``, where
    ``substr == seq[i:j]``.

    This function only works for iterables that support slicing, such as
    ``str`` objects.

    >>> for item in substrings_indexes('more'):
    ...     print(item)
    ('m', 0, 1)
    ('o', 1, 2)
    ('r', 2, 3)
    ('e', 3, 4)
    ('mo', 0, 2)
    ('or', 1, 3)
    ('re', 2, 4)
    ('mor', 0, 3)
    ('ore', 1, 4)
    ('more', 0, 4)

    Set *reverse* to ``True`` to yield the same items in the opposite order.


    """
    # Lengths ascend by default; descend when reverse is requested.
    lengths = range(1, len(seq) + 1)
    if reverse:
        lengths = reversed(lengths)
    return (
        (seq[start : start + size], start, start + size)
        for size in lengths
        for start in range(len(seq) - size + 1)
    )
class bucket:
    """Wrap *iterable* and return an object that buckets the iterable into
    child iterables based on a *key* function.

    >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
    >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
    >>> sorted(list(s))  # Get the keys
    ['a', 'b', 'c']
    >>> a_iterable = s['a']
    >>> next(a_iterable)
    'a1'
    >>> next(a_iterable)
    'a2'
    >>> list(s['b'])
    ['b1', 'b2', 'b3']

    The original iterable will be advanced and its items will be cached until
    they are used by the child iterables. This may require significant storage.

    By default, attempting to select a bucket to which no items belong will
    exhaust the iterable and cache all values.
    If you specify a *validator* function, selected buckets will instead be
    checked against it.

    >>> from itertools import count
    >>> it = count(1, 2)  # Infinite sequence of odd numbers
    >>> key = lambda x: x % 10  # Bucket by last digit
    >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
    >>> s = bucket(it, key=key, validator=validator)
    >>> 2 in s
    False
    >>> list(s[2])
    []

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        # Maps each key value to a deque of not-yet-consumed items.
        self._cache = defaultdict(deque)
        # When no validator is given, every key value is accepted.
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        # True iff at least one item with this key is available.
        if not self._validator(value):
            return False

        try:
            item = next(self[value])
        except StopIteration:
            return False
        else:
            # Put the probed item back so it is still yielded later.
            self._cache[value].appendleft(item)

        return True

    def _get_values(self, value):
        """
        Helper to yield items from the parent iterator that match *value*.
        Items that don't match are stored in the local cache as they
        are encountered.
        """
        while True:
            # If we've cached some items that match the target value, emit
            # the first one and evict it from the cache.
            if self._cache[value]:
                yield self._cache[value].popleft()
            # Otherwise we need to advance the parent iterator to search for
            # a matching item, caching the rest.
            else:
                while True:
                    try:
                        item = next(self._it)
                    except StopIteration:
                        return
                    item_value = self._key(item)
                    if item_value == value:
                        yield item
                        break
                    elif self._validator(item_value):
                        self._cache[item_value].append(item)

    def __iter__(self):
        # Iterating the bucket itself exhausts the source (caching every
        # valid item) and yields the set of keys seen.
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        # Invalid keys get an empty iterator instead of scanning the source.
        if not self._validator(value):
            return iter(())

        return self._get_values(value)
def spy(iterable, n=1):
    """Look ahead at the first *n* items of *iterable* without consuming it.

    Returns a 2-tuple: a list holding up to *n* leading elements, and an
    iterator equivalent to the original *iterable* (including those leading
    elements).

    There is one item in the list by default:

        >>> iterable = 'abcdefg'
        >>> head, iterable = spy(iterable)
        >>> head
        ['a']
        >>> list(iterable)
        ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    You may use unpacking to retrieve items instead of lists:

        >>> (head,), iterable = spy('abcdefg')
        >>> head
        'a'
        >>> (first, second), iterable = spy('abcdefg', 2)
        >>> first
        'a'
        >>> second
        'b'

    The number of items requested can be larger than the number of items in
    the iterable:

        >>> iterable = [1, 2, 3, 4, 5]
        >>> head, iterable = spy(iterable, 10)
        >>> head
        [1, 2, 3, 4, 5]
        >>> list(iterable)
        [1, 2, 3, 4, 5]

    """
    # tee gives two independent cursors: one to peek at, one to hand back.
    ahead, original = tee(iterable)
    return list(islice(ahead, n)), original
def interleave(*iterables):
    """Yield one element from each iterable in turn, stopping as soon as
    any single iterable runs out.

    >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7, 3, 8][:6]  # doctest: +SKIP
    >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7]

    For a version that continues after the shortest iterable is exhausted,
    see :func:`interleave_longest`.

    """
    # zip truncates at the shortest input; flattening its tuples gives
    # the round-robin order.
    rounds = zip(*iterables)
    return chain.from_iterable(rounds)
def interleave_longest(*iterables):
    """Yield one element from each iterable in turn, skipping iterables
    that have already been exhausted.

    >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7, 3, 8]

    This function produces the same output as :func:`roundrobin`, but may
    perform better for some inputs (in particular when the number of iterables
    is large).

    """
    # A unique sentinel marks the padding zip_longest inserts for
    # exhausted inputs; those slots are filtered back out.
    missing = object()
    for bundle in zip_longest(*iterables, fillvalue=missing):
        for element in bundle:
            if element is not missing:
                yield element
def interleave_evenly(iterables, lengths=None):
    """
    Interleave multiple iterables so that their elements are evenly distributed
    throughout the output sequence.

    >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
    >>> list(interleave_evenly(iterables))
    [1, 2, 'a', 3, 4, 'b', 5]

    >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
    >>> list(interleave_evenly(iterables))
    [1, 6, 4, 2, 7, 3, 8, 5]

    This function requires iterables of known length. Iterables without
    ``__len__()`` can be used by manually specifying lengths with *lengths*:

    >>> from itertools import combinations, repeat
    >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
    >>> lengths = [4 * (4 - 1) // 2, 3]
    >>> list(interleave_evenly(iterables, lengths=lengths))
    [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    Based on Bresenham's algorithm.
    """
    if lengths is None:
        # Lengths are required; fall back to len() and translate the
        # failure into a helpful ValueError.
        try:
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )
    elif len(iterables) != len(lengths):
        raise ValueError('Mismatching number of iterables and lengths.')

    dims = len(lengths)

    # sort iterables by length, descending
    lengths_permute = sorted(
        range(dims), key=lambda i: lengths[i], reverse=True
    )
    lengths_desc = [lengths[i] for i in lengths_permute]
    iters_desc = [iter(iterables[i]) for i in lengths_permute]

    # the longest iterable is the primary one (Bresenham: the longest
    # distance along an axis)
    delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
    iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
    # Error accumulators, one per secondary iterable; the initial offset
    # of delta_primary // dims spreads the first emissions out evenly.
    errors = [delta_primary // dims] * len(deltas_secondary)

    # Total number of items still to emit across all iterables.
    to_yield = sum(lengths)
    while to_yield:
        yield next(iter_primary)
        to_yield -= 1
        # update errors for each secondary iterable
        errors = [e - delta for e, delta in zip(errors, deltas_secondary)]

        # those iterables for which the error is negative are yielded
        # ("diagonal step" in Bresenham)
        for i, e_ in enumerate(errors):
            if e_ < 0:
                yield next(iters_secondary[i])
                to_yield -= 1
                errors[i] += delta_primary
def interleave_randomly(*iterables):
    """Repeatedly pick one of the input *iterables* at random and yield its
    next item, until all inputs are exhausted.

    >>> iterables = [1, 2, 3], 'abc', (True, False, None)
    >>> list(interleave_randomly(*iterables))  # doctest: +SKIP
    ['a', 'b', 1, 'c', True, False, None, 2, 3]

    Within each input iterable, the relative order of items is preserved.
    Note that the possible output sequences are not all equally likely.

    """
    pending = [iter(source) for source in iterables]
    while pending:
        choice = randrange(len(pending))
        try:
            yield next(pending[choice])
        except StopIteration:
            # Swap-and-drop removal: replace the exhausted iterator with
            # the last one and shrink the pool — O(1), and the pool's
            # order doesn't matter since selection is random.
            pending[choice] = pending[-1]
            pending.pop()
def collapse(iterable, base_type=None, levels=None):
    """Flatten an iterable with multiple levels of nesting (e.g., a list of
    lists of tuples) into non-iterable types.

    >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
    >>> list(collapse(iterable))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and
    will not be collapsed.

    To avoid collapsing other types, specify *base_type*:

    >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
    >>> list(collapse(iterable, base_type=tuple))
    ['ab', ('cd', 'ef'), 'gh', 'ij']

    Specify *levels* to stop flattening after a certain level:

    >>> iterable = [('a', ['b']), ('c', ['d'])]
    >>> list(collapse(iterable))  # Fully flattened
    ['a', 'b', 'c', 'd']
    >>> list(collapse(iterable, levels=1))  # Only one level flattened
    ['a', ['b'], 'c', ['d']]

    """
    # Explicit stack of (nesting level, iterator of nodes) pairs —
    # an iterative depth-first traversal, avoiding recursion limits.
    stack = deque()
    # Add our first node group, treat the iterable as a single node
    stack.appendleft((0, repeat(iterable, 1)))

    while stack:
        node_group = stack.popleft()
        level, nodes = node_group

        # Check if beyond max level
        if levels is not None and level > levels:
            yield from nodes
            continue

        for node in nodes:
            # Check if done iterating
            if isinstance(node, (str, bytes)) or (
                (base_type is not None) and isinstance(node, base_type)
            ):
                yield node
            # Otherwise try to create child nodes
            else:
                try:
                    tree = iter(node)
                except TypeError:
                    yield node
                else:
                    # Save our current location.  `nodes` is a live
                    # iterator, so re-pushing node_group resumes after
                    # the current node once the child is processed.
                    stack.appendleft(node_group)
                    # Append the new child node
                    stack.appendleft((level + 1, tree))
                    # Break to process child node
                    break
def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Invoke *func* on each item of *iterable* (or on each group of
    *chunk_size* items) and then yield the item(s) unchanged.

    *func* takes a single argument; its return value is ignored.

    *before* and *after* are optional zero-argument callables, run before
    iteration starts and after it finishes (even on error), respectively.

    Useful for logging, progress bars, or any other impure bookkeeping.

    Emitting a status message:

        >>> from more_itertools import consume
        >>> func = lambda item: print('Received {}'.format(item))
        >>> consume(side_effect(func, range(2)))
        Received 0
        Received 1

    Operating on chunks of items:

        >>> pair_sums = []
        >>> func = lambda chunk: pair_sums.append(sum(chunk))
        >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
        [0, 1, 2, 3, 4, 5]
        >>> list(pair_sums)
        [1, 5, 9]

    Writing to a file-like object:

        >>> from io import StringIO
        >>> from more_itertools import consume
        >>> f = StringIO()
        >>> func = lambda x: print(x, file=f)
        >>> before = lambda: print(u'HEADER', file=f)
        >>> after = f.close
        >>> it = [u'a', u'b', u'c']
        >>> consume(side_effect(func, it, before=before, after=after))
        >>> f.closed
        True

    """
    try:
        if before is not None:
            before()

        if chunk_size is None:
            # Item-at-a-time mode.
            for element in iterable:
                func(element)
                yield element
        else:
            # Chunked mode: func sees each chunk, items still flow through
            # one by one.
            for batch in chunked(iterable, chunk_size):
                func(batch)
                yield from batch
    finally:
        # Run the teardown hook even if the consumer abandons us early.
        if after is not None:
            after()
def sliced(seq, n, strict=False):
    """Yield slices of length *n* from the sequence *seq*.

    >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
    [(1, 2, 3), (4, 5, 6)]

    By default, the last yielded slice will have fewer than *n* elements
    if the length of *seq* is not divisible by *n*:

    >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
    [(1, 2, 3), (4, 5, 6), (7, 8)]

    If the length of *seq* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    slice is yielded.

    This function will only work for iterables that support slicing.
    For non-sliceable iterables, see :func:`chunked`.

    """
    # Slicing past the end yields empty sequences, which takewhile(len)
    # uses as the stop signal.
    slices = takewhile(len, (seq[i : i + n] for i in count(0, n)))

    if not strict:
        return slices

    def checked():
        # Re-yield each slice, but fail fast on a short (final) one.
        for piece in slices:
            if len(piece) != n:
                raise ValueError("seq is not divisible by n.")
            yield piece

    return checked()
def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Yield lists of items from *iterable*, splitting at items for which
    the callable *pred* returns ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b'))
    [['a'], ['c', 'd', 'c'], ['a']]

    >>> list(split_at(range(10), lambda n: n % 2 == 1))
    [[0], [2], [4], [6], [8], []]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
    [[0], [2], [4, 5, 6, 7, 8, 9]]

    By default, the delimiting items are not included in the output.
    To include them, set *keep_separator* to ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
    [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        # No splitting allowed: hand back everything as one piece.
        yield list(iterable)
        return

    chunk = []
    iterator = iter(iterable)
    for element in iterator:
        if not pred(element):
            chunk.append(element)
            continue
        # Hit a separator: emit what we have, then the separator itself
        # if requested.
        yield chunk
        if keep_separator:
            yield [element]
        if maxsplit == 1:
            # Last permitted split: the rest is a single final piece.
            yield list(iterator)
            return
        chunk = []
        maxsplit -= 1
    yield chunk
def split_before(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*; each list ends just before an
    item for which the callable *pred* returns ``True``:

    >>> list(split_before('OneTwo', lambda s: s.isupper()))
    [['O', 'n', 'e'], ['T', 'w', 'o']]

    >>> list(split_before(range(10), lambda n: n % 3 == 0))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    chunk = []
    iterator = iter(iterable)
    for element in iterator:
        # A matching element starts a new chunk — unless the current
        # chunk is empty (e.g. the very first element matches).
        if pred(element) and chunk:
            yield chunk
            if maxsplit == 1:
                # Last permitted split: remainder becomes one final piece.
                yield [element, *iterator]
                return
            chunk = []
            maxsplit -= 1
        chunk.append(element)
    if chunk:
        yield chunk
def split_after(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*; each list ends with an item
    for which the callable *pred* returns ``True``:

    >>> list(split_after('one1two2', lambda s: s.isdigit()))
    [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

    >>> list(split_after(range(10), lambda n: n % 3 == 0))
    [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    chunk = []
    iterator = iter(iterable)
    for element in iterator:
        chunk.append(element)
        # A matching element closes out the current chunk.
        if pred(element) and chunk:
            yield chunk
            if maxsplit == 1:
                # Last permitted split: drain the rest into one piece,
                # skipping it entirely if nothing remains.
                remainder = list(iterator)
                if remainder:
                    yield remainder
                return
            chunk = []
            maxsplit -= 1
    if chunk:
        yield chunk
def split_when(iterable, pred, maxsplit=-1):
    """Split *iterable* between consecutive items *x*, *y* whenever
    ``pred(x, y)`` is true.

    For example, to find runs of increasing numbers, split the iterable when
    element ``i`` is larger than element ``i + 1``:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
    [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
    ...                 lambda x, y: x > y, maxsplit=2))
    [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    try:
        previous = next(iterator)
    except StopIteration:
        # Empty input: nothing to yield at all.
        return

    chunk = [previous]
    for current in iterator:
        if pred(previous, current):
            yield chunk
            if maxsplit == 1:
                # Last permitted split: remainder is one final piece.
                yield [current, *iterator]
                return
            chunk = []
            maxsplit -= 1

        chunk.append(current)
        previous = current

    yield chunk
def split_into(iterable, sizes):
    """Yield a list of sequential items from *iterable* of length 'n' for each
    integer 'n' in *sizes*.

    >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
    [[1], [2, 3], [4, 5, 6]]

    If the sum of *sizes* is smaller than the length of *iterable*, then the
    remaining items of *iterable* will not be returned.

    >>> list(split_into([1,2,3,4,5,6], [2,3]))
    [[1, 2], [3, 4, 5]]

    If the sum of *sizes* is larger than the length of *iterable*, fewer items
    will be returned in the iteration that overruns the *iterable* and further
    lists will be empty:

    >>> list(split_into([1,2,3,4], [1,2,3,4]))
    [[1], [2, 3], [4], []]

    When a ``None`` object is encountered in *sizes*, the returned list will
    contain items up to the end of *iterable* the same way that
    :func:`itertools.islice` does:

    >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
    [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    :func:`split_into` can be useful for grouping a series of items where the
    sizes of the groups are not uniform. An example would be where in a row
    from a table, multiple columns represent elements of the same feature
    (e.g. a point represented by x,y,z) but, the format is not the same for
    all columns.
    """
    # Use a single shared iterator so successive islice calls continue
    # where the previous group left off (works for generators too).
    iterator = iter(iterable)

    for size in sizes:
        if size is None:
            # None means "the rest" — consume everything and stop.
            yield list(iterator)
            return
        yield list(islice(iterator, size))
def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the elements of *iterable*, then enough copies of *fillvalue*
    that at least *n* items are emitted.

    >>> list(padded([1, 2, 3], '?', 5))
    [1, 2, 3, '?', '?']

    If *next_multiple* is ``True``, *fillvalue* will be emitted until the
    number of items emitted is a multiple of *n*:

    >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
    [1, 2, 3, 4, None, None]

    If *n* is ``None``, *fillvalue* will be emitted indefinitely.

    To create an *iterable* of exactly size *n*, you can truncate with
    :func:`islice`.

    >>> list(islice(padded([1, 2, 3], '?'), 5))
    [1, 2, 3, '?', '?']
    >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
    [1, 2, 3, 4, 5]

    """
    iterator = iter(iterable)
    # The source followed by an endless supply of padding.
    padded_it = chain(iterator, repeat(fillvalue))

    if n is None:
        return padded_it
    if n < 1:
        raise ValueError('n must be at least 1')

    if not next_multiple:
        # Emit at least n items (padding as needed), then whatever real
        # items remain beyond the first n.
        return chain(islice(padded_it, n), iterator)

    def batches():
        # Each real item that starts a group forces n - 1 more items,
        # drawn from the source first and padding after.
        for head in iterator:
            yield (head,)
            yield islice(padded_it, n - 1)

    return chain.from_iterable(batches())
def repeat_each(iterable, n=2):
    """Repeat each element in *iterable* *n* times.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    # Lazily expand each element into its own n-item run.
    return chain.from_iterable(repeat(element, n) for element in iterable)
def repeat_last(iterable, default=None):
    """After the *iterable* is exhausted, keep yielding its last element.

    >>> list(islice(repeat_last(range(3)), 5))
    [0, 1, 2, 2, 2]

    If the iterable is empty, yield *default* forever::

    >>> list(islice(repeat_last(range(0), 42), 5))
    [42, 42, 42, 42, 42]

    """
    # The sentinel survives only if the loop body never runs, i.e. the
    # iterable was empty.
    sentinel = object()
    last = sentinel
    for last in iterable:
        yield last
    yield from repeat(default if last is sentinel else last)
def distribute(n, iterable):
    """Deal the items of *iterable* out among *n* smaller iterables,
    round-robin style.

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

    >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 4, 7], [2, 5], [3, 6]]

    If the length of *iterable* is smaller than *n*, then the last returned
    iterables will be empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    This function uses :func:`itertools.tee` and may require significant
    storage.

    If you need the order items in the smaller iterables to match the
    original iterable, see :func:`divide`.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Child i takes every n-th item starting from offset i.
    copies = tee(iterable, n)
    return [
        islice(copy, offset, None, n) for offset, copy in enumerate(copies)
    ]
def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples of elements offset from *iterable*: the `i`-th element
    of each tuple is shifted by the `i`-th entry in *offsets*.

    >>> list(stagger([0, 1, 2, 3]))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
    >>> list(stagger(range(8), offsets=(0, 2, 4)))
    [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default, the sequence will end when the final element of a tuple is the
    last item in the iterable. To continue until the first element of a tuple
    is the last item in the iterable, set *longest* to ``True``::

    >>> list(stagger([0, 1, 2, 3], longest=True))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    By default, ``None`` will be used to replace offsets beyond the end of the
    sequence. Specify *fillvalue* to use some other value.

    """
    # One independent copy of the input per requested offset; the actual
    # shifting is delegated to zip_offset.
    copies = tee(iterable, len(offsets))
    return zip_offset(
        *copies, offsets=offsets, longest=longest, fillvalue=fillvalue
    )
def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the input *iterables* together, shifting the `i`-th iterable
    by the `i`-th entry in *offsets*.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    This can be used as a lightweight alternative to SciPy or pandas to analyze
    data sets in which some series have a lead or lag relationship.

    By default, the sequence will end when the shortest iterable is exhausted.
    To continue until the longest iterable is exhausted, set *longest* to
    ``True``.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    By default, ``None`` will be used to replace offsets beyond the end of the
    sequence. Specify *fillvalue* to use some other value.

    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    def shifted(it, offset):
        # Negative offsets pad the front; positive offsets skip ahead.
        if offset < 0:
            return chain(repeat(fillvalue, -offset), it)
        if offset > 0:
            return islice(it, offset, None)
        return it

    columns = [
        shifted(it, offset) for it, offset in zip(iterables, offsets)
    ]

    if longest:
        return zip_longest(*columns, fillvalue=fillvalue)
    return zip(*columns)
1890def sort_together(
1891 iterables, key_list=(0,), key=None, reverse=False, strict=False
1892):
1893 """Return the input iterables sorted together, with *key_list* as the
1894 priority for sorting. All iterables are trimmed to the length of the
1895 shortest one.
1897 This can be used like the sorting function in a spreadsheet. If each
1898 iterable represents a column of data, the key list determines which
1899 columns are used for sorting.
1901 By default, all iterables are sorted using the ``0``-th iterable::
1903 >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
1904 >>> sort_together(iterables)
1905 [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
1907 Set a different key list to sort according to another iterable.
1908 Specifying multiple keys dictates how ties are broken::
1910 >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
1911 >>> sort_together(iterables, key_list=(1, 2))
1912 [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
1914 To sort by a function of the elements of the iterable, pass a *key*
1915 function. Its arguments are the elements of the iterables corresponding to
1916 the key list::
1918 >>> names = ('a', 'b', 'c')
1919 >>> lengths = (1, 2, 3)
1920 >>> widths = (5, 2, 1)
1921 >>> def area(length, width):
1922 ... return length * width
1923 >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
1924 [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
1926 Set *reverse* to ``True`` to sort in descending order.
1928 >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
1929 [(3, 2, 1), ('a', 'b', 'c')]
1931 If the *strict* keyword argument is ``True``, then
1932 ``ValueError`` will be raised if any of the iterables have
1933 different lengths.
1935 """
1936 if key is None:
1937 # if there is no key function, the key argument to sorted is an
1938 # itemgetter
1939 key_argument = itemgetter(*key_list)
1940 else:
1941 # if there is a key function, call it with the items at the offsets
1942 # specified by the key function as arguments
1943 key_list = list(key_list)
1944 if len(key_list) == 1:
1945 # if key_list contains a single item, pass the item at that offset
1946 # as the only argument to the key function
1947 key_offset = key_list[0]
1948 key_argument = lambda zipped_items: key(zipped_items[key_offset])
1949 else:
1950 # if key_list contains multiple items, use itemgetter to return a
1951 # tuple of items, which we pass as *args to the key function
1952 get_key_items = itemgetter(*key_list)
1953 key_argument = lambda zipped_items: key(
1954 *get_key_items(zipped_items)
1955 )
1957 transposed = zip(*iterables, strict=strict)
1958 reordered = sorted(transposed, key=key_argument, reverse=reverse)
1959 untransposed = zip(*reordered, strict=strict)
1960 return list(untransposed)
def unzip(iterable):
    """The inverse of :func:`zip`: split a zipped *iterable* back into
    its component streams.

    The ``i``-th returned iterable contains the ``i``-th element from each
    element of the zipped iterable. The first element is used to determine
    how many streams to produce.

    >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> letters, numbers = unzip(iterable)
    >>> list(letters)
    ['a', 'b', 'c', 'd']
    >>> list(numbers)
    [1, 2, 3, 4]

    This is similar to using ``zip(*iterable)``, but it avoids reading
    *iterable* into memory. Note, however, that this function uses
    :func:`itertools.tee` and thus may require significant storage.

    """
    head, iterable = spy(iterable)
    if not head:
        # Nothing zipped in: nothing to unzip (e.g. zip([], [], [])).
        return ()

    # Width of the output is the width of the first zipped element.
    width = len(head[0])
    copies = tee(iterable, width)

    # Ragged inner tuples would raise IndexError partway through a
    # column; suppressing it simply ends that column at the first
    # too-short element instead of blowing up.
    return tuple(
        iter_suppress(map(itemgetter(position), copy), IndexError)
        for position, copy in enumerate(copies)
    )
def divide(n, iterable):
    """Split the elements of *iterable* into *n* contiguous parts,
    maintaining order.

    >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 2, 3]
    >>> list(group_2)
    [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

    >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than n, then the last returned
    iterables will be empty:

    >>> children = divide(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    This function will exhaust the iterable before returning.
    If order is not important, see :func:`distribute`, which does not first
    pull the iterable into memory.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Slicing is required; materialize only if the input can't be sliced.
    try:
        iterable[:0]
    except TypeError:
        pool = tuple(iterable)
    else:
        pool = iterable

    # Each part gets `size` items; the first `remainder` parts get one extra.
    size, remainder = divmod(len(pool), n)

    parts = []
    begin = 0
    for index in range(n):
        end = begin + size + (1 if index < remainder else 0)
        parts.append(iter(pool[begin:end]))
        begin = end

    return parts
def always_iterable(obj, base_type=(str, bytes)):
    """If *obj* is iterable, return an iterator over its items::

        >>> obj = (1, 2, 3)
        >>> list(always_iterable(obj))
        [1, 2, 3]

    If *obj* is not iterable, return a one-item iterable containing *obj*::

        >>> obj = 1
        >>> list(always_iterable(obj))
        [1]

    If *obj* is ``None``, return an empty iterable:

        >>> obj = None
        >>> list(always_iterable(None))
        []

    By default, binary and text strings are not considered iterable::

        >>> obj = 'foo'
        >>> list(always_iterable(obj))
        ['foo']

    If *base_type* is set, objects for which ``isinstance(obj, base_type)``
    returns ``True`` won't be considered iterable.

        >>> obj = {'a': 1}
        >>> list(always_iterable(obj))  # Iterate over the dict's keys
        ['a']
        >>> list(always_iterable(obj, base_type=dict))  # Treat dicts as a unit
        [{'a': 1}]

    Set *base_type* to ``None`` to avoid any special handling and treat objects
    Python considers iterable as iterable:

        >>> obj = 'foo'
        >>> list(always_iterable(obj, base_type=None))
        ['f', 'o', 'o']
    """
    if obj is None:
        return iter(())

    # Objects of base_type are treated as atoms even if iterable.
    treat_as_atom = base_type is not None and isinstance(obj, base_type)
    if not treat_as_atom:
        try:
            return iter(obj)
        except TypeError:
            # Not actually iterable; fall through to wrapping.
            pass

    return iter((obj,))
def adjacent(predicate, iterable, distance=1):
    """Yield ``(bool, item)`` tuples where ``item`` comes from *iterable*
    and the bool tells whether that item satisfies *predicate* or sits
    within *distance* places of an item that does.

    For example, to find whether items are adjacent to a ``3``::

        >>> list(adjacent(lambda x: x == 3, range(6)))
        [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    Set *distance* to change what counts as adjacent. For example, to find
    whether items are two places away from a ``3``:

        >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
        [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    This is useful for contextualizing the results of a search function.
    For example, a code comparison tool might want to identify lines that
    have changed, but also surrounding lines to give the viewer of the diff
    context.

    The predicate function will only be called once for each item in the
    iterable.

    See also :func:`groupby_transform`, which can be used with this function
    to group ranges of items with the same `bool` value.

    """
    # Allow distance=0 mainly for testing that it reproduces results
    # obtainable with plain map().
    if distance < 0:
        raise ValueError('distance must be at least 0')

    probe, originals = tee(iterable)
    # False padding on both ends so the sliding window is well-defined
    # at the boundaries.
    pad = [False] * distance
    flags = chain(pad, map(predicate, probe), pad)
    nearby = map(any, windowed(flags, 2 * distance + 1))
    return zip(nearby, originals)
def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
    """An extension of :func:`itertools.groupby` that can apply transformations
    to the grouped data.

    * *keyfunc* is a function computing a key value for each item in *iterable*
    * *valuefunc* is a function that transforms the individual items from
      *iterable* after grouping
    * *reducefunc* is a function that transforms each group of items

    >>> iterable = 'aAAbBBcCC'
    >>> keyfunc = lambda k: k.upper()
    >>> valuefunc = lambda v: v.lower()
    >>> reducefunc = lambda g: ''.join(g)
    >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
    [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]

    Each optional argument defaults to an identity function if not specified.

    :func:`groupby_transform` is useful when grouping elements of an iterable
    using a separate iterable as the key. To do this, :func:`zip` the iterables
    and pass a *keyfunc* that extracts the first element and a *valuefunc*
    that extracts the second element::

        >>> from operator import itemgetter
        >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
        >>> values = 'abcdefghi'
        >>> iterable = zip(keys, values)
        >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
        >>> [(k, ''.join(g)) for k, g in grouper]
        [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]

    Note that the order of items in the iterable is significant.
    Only adjacent items are grouped together, so if you don't want any
    duplicate groups, you should sort the iterable by the key function.

    """
    groups = groupby(iterable, keyfunc)
    # Layer the optional transformations lazily on top of groupby's output.
    if valuefunc:
        groups = ((k, map(valuefunc, group)) for k, group in groups)
    if reducefunc:
        groups = ((k, reducefunc(group)) for k, group in groups)

    return groups
class numeric_range(Sequence):
    """An extension of the built-in ``range()`` function whose arguments can
    be any orderable numeric type.

    With only *stop* specified, *start* defaults to ``0`` and *step*
    defaults to ``1``. The output items will match the type of *stop*:

    >>> list(numeric_range(3.5))
    [0.0, 1.0, 2.0, 3.0]

    With only *start* and *stop* specified, *step* defaults to ``1``. The
    output items will match the type of *start*:

    >>> from decimal import Decimal
    >>> start = Decimal('2.1')
    >>> stop = Decimal('5.1')
    >>> list(numeric_range(start, stop))
    [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    With *start*, *stop*, and *step* specified the output items will match
    the type of ``start + step``:

    >>> from fractions import Fraction
    >>> start = Fraction(1, 2)  # Start at 1/2
    >>> stop = Fraction(5, 2)  # End at 5/2
    >>> step = Fraction(1, 2)  # Count by 1/2
    >>> list(numeric_range(start, stop, step))
    [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    If *step* is zero, ``ValueError`` is raised. Negative steps are supported:

    >>> list(numeric_range(3, -1, -1.0))
    [3.0, 2.0, 1.0, 0.0]

    Be aware of the limitations of floating-point numbers; the representation
    of the yielded numbers may be surprising.

    ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
    is a ``datetime.timedelta`` object:

    >>> import datetime
    >>> start = datetime.datetime(2019, 1, 1)
    >>> stop = datetime.datetime(2019, 1, 3)
    >>> step = datetime.timedelta(days=1)
    >>> items = iter(numeric_range(start, stop, step))
    >>> next(items)
    datetime.datetime(2019, 1, 1, 0, 0)
    >>> next(items)
    datetime.datetime(2019, 1, 2, 0, 0)

    """

    # Shared hash for all empty ranges, mirroring built-in range() behavior
    # where all empty ranges compare equal and hash alike.
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        argc = len(args)
        if argc == 1:
            (self._stop,) = args
            # Construct zero/one in the caller's numeric type so all
            # arithmetic stays within that type (Decimal, Fraction, etc.).
            self._start = type(self._stop)(0)
            self._step = type(self._stop - self._start)(1)
        elif argc == 2:
            self._start, self._stop = args
            # The step's type is derived from the difference type, which may
            # differ from the endpoint type (e.g. datetime -> timedelta).
            self._step = type(self._stop - self._start)(1)
        elif argc == 3:
            self._start, self._stop, self._step = args
        elif argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        else:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        # Zero in the step's type, used for sign tests and modulo checks.
        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        # True when counting upward (positive step).
        self._growing = self._step > self._zero

    def __bool__(self):
        # Non-empty iff at least one item would be yielded.
        if self._growing:
            return self._start < self._stop
        else:
            return self._start > self._stop

    def __contains__(self, elem):
        # An element is contained iff it lies within the half-open interval
        # and is an exact whole number of steps from the start.
        if self._growing:
            if self._start <= elem < self._stop:
                return (elem - self._start) % self._step == self._zero
        else:
            if self._start >= elem > self._stop:
                return (self._start - elem) % (-self._step) == self._zero

        return False

    def __eq__(self, other):
        if isinstance(other, numeric_range):
            empty_self = not bool(self)
            empty_other = not bool(other)
            if empty_self or empty_other:
                return empty_self and empty_other  # True if both empty
            else:
                # Non-empty ranges are equal when they yield the same items:
                # same first item, same step, same last item.
                return (
                    self._start == other._start
                    and self._step == other._step
                    and self._get_by_index(-1) == other._get_by_index(-1)
                )
        else:
            return False

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._get_by_index(key)
        elif isinstance(key, slice):
            # Scale the slice's step by our own step.
            step = self._step if key.step is None else key.step * self._step

            # Clamp the slice's start into our bounds, like range() does.
            if key.start is None or key.start <= -self._len:
                start = self._start
            elif key.start >= self._len:
                start = self._stop
            else:  # -self._len < key.start < self._len
                start = self._get_by_index(key.start)

            # Clamp the slice's stop into our bounds.
            if key.stop is None or key.stop >= self._len:
                stop = self._stop
            elif key.stop <= -self._len:
                stop = self._start
            else:  # -self._len < key.stop < self._len
                stop = self._get_by_index(key.stop)

            return numeric_range(start, stop, step)
        else:
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )

    def __hash__(self):
        if self:
            # Hash on (first, last, step) so equal ranges hash equally.
            return hash((self._start, self._get_by_index(-1), self._step))
        else:
            return self._EMPTY_HASH

    def __iter__(self):
        values = (self._start + (n * self._step) for n in count())
        # Stop before crossing the stop boundary, in either direction.
        if self._growing:
            return takewhile(partial(gt, self._stop), values)
        else:
            return takewhile(partial(lt, self._stop), values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        # Normalize to a positive-step problem, then count steps.
        if self._growing:
            start = self._start
            stop = self._stop
            step = self._step
        else:
            start = self._stop
            stop = self._start
            step = -self._step
        distance = stop - start
        if distance <= self._zero:
            return 0
        else:  # distance > 0 and step > 0: regular euclidean division
            q, r = divmod(distance, step)
            # A non-zero remainder means one extra (partial-step) item.
            return int(q) + int(r != self._zero)

    def __reduce__(self):
        # Support pickling by reconstructing from the three arguments.
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        if self._step == 1:
            return f"numeric_range({self._start!r}, {self._stop!r})"
        return (
            f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
        )

    def __reversed__(self):
        # Walk from the last item back past the start, negating the step.
        return iter(
            numeric_range(
                self._get_by_index(-1), self._start - self._step, -self._step
            )
        )

    def count(self, value):
        # Each value appears at most once.
        return int(value in self)

    def index(self, value):
        # Mirrors __contains__, but returns the step count instead of a bool.
        if self._growing:
            if self._start <= value < self._stop:
                q, r = divmod(value - self._start, self._step)
                if r == self._zero:
                    return int(q)
        else:
            if self._start >= value > self._stop:
                q, r = divmod(self._start - value, -self._step)
                if r == self._zero:
                    return int(q)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        # Support negative indexing like built-in sequences.
        if i < 0:
            i += self._len
        if i < 0 or i >= self._len:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step
def count_cycle(iterable, n=None):
    """Cycle through the items from *iterable* up to *n* times, yielding
    the number of completed cycles along with each item. If *n* is omitted
    the process repeats indefinitely.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    """
    items = tuple(iterable)
    if not items:
        return iter(())
    rounds = count() if n is None else range(n)
    # For each cycle number, emit every item paired with that number.
    return ((cycle_num, item) for cycle_num in rounds for item in items)
def mark_ends(iterable):
    """Yield 3-tuples of the form ``(is_first, is_last, item)``.

    >>> list(mark_ends('ABC'))
    [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    Use this when looping over an iterable to take special action on its
    first and/or last items:

    >>> iterable = ['Header', 100, 200, 'Footer']
    >>> total = 0
    >>> for is_first, is_last, item in mark_ends(iterable):
    ...     if is_first:
    ...         continue  # Skip the header
    ...     if is_last:
    ...         continue  # Skip the footer
    ...     total += item
    >>> print(total)
    300
    """
    it = iter(iterable)
    try:
        pending = next(it)
    except StopIteration:
        # Empty input yields nothing.
        return

    is_first = True
    for current in it:
        # There is a successor, so *pending* cannot be the last item.
        yield is_first, False, pending
        pending = current
        is_first = False

    # The held-back item is the last one (and the first, if it's alone).
    yield is_first, True, pending
def locate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
    [1, 2, 4]

    Set *pred* to a custom function to, e.g., find the indexes for a
    particular item:

    >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
    [1, 3]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(locate(iterable, pred=pred, window_size=3))
    [1, 5]

    :raises ValueError: if *window_size* is less than 1
    """
    if window_size is None:
        # Single-item mode: emit the index of each item passing *pred*.
        return (index for index, item in enumerate(iterable) if pred(item))

    if window_size < 1:
        raise ValueError('window size must be at least 1')

    # Windowed mode: pad with the sentinel so every item starts a window,
    # and call *pred* with the window's items as positional arguments.
    windows = windowed(iterable, window_size, fillvalue=_marker)
    return (index for index, w in enumerate(windows) if pred(*w))
def longest_common_prefix(iterables):
    """Yield elements of the longest common prefix among the given
    *iterables*.

    >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
    'ab'

    """
    # Walk the iterables in lockstep; keep emitting the shared element
    # while every position still agrees.
    matching_columns = takewhile(all_equal, zip(*iterables))
    return map(itemgetter(0), matching_columns)
def lstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the beginning
    for which *pred* returns ``True``.

    For example, to remove a set of items from the start of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(lstrip(iterable, pred))
    [1, 2, None, 3, False, None]

    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.

    """
    return dropwhile(pred, iterable)
def rstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the end
    for which *pred* returns ``True``.

    For example, to remove a set of items from the end of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(rstrip(iterable, pred))
    [None, False, None, 1, 2, None, 3]

    This function is analogous to :func:`str.rstrip`.

    """
    # Buffer items that match *pred*; they are emitted only when a later
    # non-matching item proves they weren't at the end.
    pending = []
    for item in iterable:
        if pred(item):
            pending.append(item)
        else:
            yield from pending
            pending.clear()
            yield item
def strip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the
    beginning and end for which *pred* returns ``True``.

    For example, to remove a set of items from both ends of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(strip(iterable, pred))
    [1, 2, None, 3]

    This function is analogous to :func:`str.strip`.

    """
    # Strip the front first, then lazily strip the back of what remains.
    front_stripped = lstrip(iterable, pred)
    return rstrip(front_stripped, pred)
class islice_extended:
    """An extension of :func:`itertools.islice` that supports negative
    values for *stop*, *start*, and *step*.

    >>> iterator = iter('abcdefgh')
    >>> list(islice_extended(iterator, -4, -1))
    ['e', 'f', 'g']

    Slices with negative values require some caching of *iterable*, but this
    class takes care to minimize the amount of memory required.

    For example, you can use a negative step with an infinite iterator:

    >>> from itertools import count
    >>> list(islice_extended(count(), 110, 99, -2))
    [110, 108, 106, 104, 102, 100]

    You can also use slice notation directly:

    >>> iterator = map(str, count())
    >>> it = islice_extended(iterator)[10:20:2]
    >>> list(it)
    ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        source = iter(iterable)
        # With explicit slice arguments, delegate to the helper; otherwise
        # pass items through unchanged.
        self._iterator = (
            _islice_helper(source, slice(*args)) if args else source
        )

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        if not isinstance(key, slice):
            raise TypeError(
                'islice_extended.__getitem__ argument must be a slice'
            )

        return islice_extended(_islice_helper(self._iterator, key))
def _islice_helper(it, s):
    """Yield items from iterator *it* selected by slice *s*, supporting
    negative *start*, *stop*, and *step* by buffering only as many items
    as necessary in a bounded deque.

    :raises ValueError: if the slice's step is zero
    """
    start = s.start
    stop = s.stop
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    step = s.step or 1

    if step > 0:
        start = 0 if (start is None) else start

        if start < 0:
            # Consume all but the last -start items
            # (enumerate from 1 so the last cached pair records the total
            # number of items consumed from *it*)
            cache = deque(enumerate(it, 1), maxlen=-start)
            len_iter = cache[-1][0] if cache else 0

            # Adjust start to be positive
            i = max(len_iter + start, 0)

            # Adjust stop to be positive
            if stop is None:
                j = len_iter
            elif stop >= 0:
                j = min(stop, len_iter)
            else:
                j = max(len_iter + stop, 0)

            # Slice the cache
            n = j - i
            if n <= 0:
                return

            for index in range(n):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()[1]
                else:
                    # just pop and discard the item
                    cache.popleft()
        elif (stop is not None) and (stop < 0):
            # Advance to the start position
            next(islice(it, start, start), None)

            # When stop is negative, we have to carry -stop items while
            # iterating
            cache = deque(islice(it, -stop), maxlen=-stop)

            for index, item in enumerate(it):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
                cache.append(item)
        else:
            # When both start and stop are positive we have the normal case
            yield from islice(it, start, stop, step)
    else:
        # Negative step: items are yielded in reverse order.
        start = -1 if (start is None) else start

        if (stop is not None) and (stop < 0):
            # Consume all but the last items
            n = -stop - 1
            cache = deque(enumerate(it, 1), maxlen=n)
            len_iter = cache[-1][0] if cache else 0

            # If start and stop are both negative they are comparable and
            # we can just slice. Otherwise we can adjust start to be negative
            # and then slice.
            if start < 0:
                i, j = start, stop
            else:
                i, j = min(start - len_iter, -1), None

            for index, item in list(cache)[i:j:step]:
                yield item
        else:
            # Advance to the stop position
            if stop is not None:
                m = stop + 1
                next(islice(it, m, m), None)

            # stop is positive, so if start is negative they are not comparable
            # and we need the rest of the items.
            if start < 0:
                i = start
                n = None
            # stop is None and start is positive, so we just need items up to
            # the start index.
            elif stop is None:
                i = None
                n = start + 1
            # Both stop and start are positive, so they are comparable.
            else:
                i = None
                n = start - stop
                if n <= 0:
                    return

            cache = list(islice(it, n))

            yield from cache[i::step]
def always_reversible(iterable):
    """An extension of :func:`reversed` that supports all iterables, not
    just those which implement the ``Reversible`` or ``Sequence`` protocols.

    >>> print(*always_reversible(x for x in range(3)))
    2 1 0

    If the iterable is already reversible, this function returns the
    result of :func:`reversed`. Otherwise the remaining items are cached
    and yielded in reverse order, which may require significant storage.
    """
    with suppress(TypeError):
        return reversed(iterable)

    # Not natively reversible: materialize the items, then reverse the copy.
    return reversed(list(iterable))
def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items using :func:`itertools.groupby`.
    The *ordering* function determines whether two items are adjacent by
    returning their position.

    By default, the ordering function is the identity function. This is
    suitable for finding runs of numbers:

    >>> for group in consecutive_groups([1, 10, 11, 12, 20]):
    ...     print(list(group))
    [1]
    [10, 11, 12]
    [20]

    To find runs of adjacent letters, apply the :func:`ord` function to
    convert letters to ordinals:

    >>> for group in consecutive_groups('abcdfg', ord):
    ...     print(list(group))
    ['a', 'b', 'c', 'd']
    ['f', 'g']

    Each group of consecutive items is an iterator that shares its source
    with *iterable*. When an output group is advanced, the previous group
    is no longer available unless its elements are copied (e.g., into a
    ``list``).

    >>> saved = [list(g) for g in consecutive_groups([1, 2, 11, 12])]
    >>> saved
    [[1, 2], [11, 12]]

    """
    # Consecutive items have a constant difference between their enumeration
    # index and their position, so that difference works as a group key.
    if ordering is None:
        def key(pair):
            return pair[0] - pair[1]
    else:
        def key(pair):
            return pair[0] - ordering(pair[1])

    for _, group in groupby(enumerate(iterable), key=key):
        yield map(itemgetter(1), group)
def difference(iterable, func=sub, *, initial=None):
    """This function is the inverse of :func:`itertools.accumulate`. By
    default it computes the first difference of *iterable* using
    :func:`operator.sub`:

    >>> from itertools import accumulate
    >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
    >>> list(difference(iterable))
    [0, 1, 2, 3, 4]

    *func* defaults to :func:`operator.sub`, but other functions can be
    specified. They will be applied as follows::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, to do progressive division:

    >>> list(difference([1, 2, 6, 24, 120], lambda x, y: x // y))
    [1, 2, 3, 4, 5]

    If the *initial* keyword is set, the first element will be skipped when
    computing successive differences.

    >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
    >>> list(difference(it, initial=10))
    [1, 2, 3]

    """
    a, b = tee(iterable)

    # Capture the first item from one copy; if there is none, the input
    # was empty and there are no differences to compute.
    head = list(islice(b, 1))
    if not head:
        return iter([])

    # When *initial* was used in the accumulation, the first item is not
    # re-emitted as-is.
    if initial is not None:
        head = []

    # Pair each item with its predecessor (b is one item ahead of a).
    return chain(head, map(func, b, a))
class SequenceView(Sequence):
    """Return a read-only view of the sequence object *target*.

    :class:`SequenceView` objects are analogous to Python's built-in
    "dictionary view" types: they provide a dynamic view of a sequence's
    items, so when the sequence updates, so does the view.

    >>> seq = ['0', '1', '2']
    >>> view = SequenceView(seq)
    >>> view
    SequenceView(['0', '1', '2'])
    >>> seq.append('3')
    >>> view
    SequenceView(['0', '1', '2', '3'])

    Sequence views support indexing, slicing, and length queries. They act
    like the underlying sequence, except they don't allow assignment:

    >>> view[1]
    '1'
    >>> view[1:-1]
    ['1', '2']
    >>> len(view)
    4

    Sequence views are useful as an alternative to copying, as they don't
    require (much) extra storage.

    :raises TypeError: if *target* is not a :class:`Sequence`
    """

    def __init__(self, target):
        # Reject anything that isn't a genuine sequence up front.
        if not isinstance(target, Sequence):
            raise TypeError
        self._target = target

    def __getitem__(self, index):
        # Delegate indexing (including slices) to the underlying sequence.
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        return f'{type(self).__name__}({self._target!r})'
class seekable:
    """Wrap an iterator to allow for seeking backward and forward. This
    progressively caches the items in the source iterable so they can be
    re-visited.

    Call :meth:`seek` with an index to seek to that position in the source
    iterable.

    To "reset" an iterator, seek to ``0``:

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.seek(0)
    >>> next(it), next(it), next(it)
    ('0', '1', '2')

    You can also seek forward:

    >>> it = seekable((str(n) for n in range(20)))
    >>> it.seek(10)
    >>> next(it)
    '10'
    >>> it.seek(20)  # Seeking past the end of the source isn't a problem
    >>> list(it)
    []
    >>> it.seek(0)  # Resetting works even after hitting the end
    >>> next(it)
    '0'

    Call :meth:`relative_seek` to seek relative to the source iterator's
    current position.

    >>> it = seekable((str(n) for n in range(20)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.relative_seek(2)
    >>> next(it)
    '5'
    >>> it.relative_seek(-3)  # Source is at '6', we move back to '3'
    >>> next(it)
    '3'
    >>> it.relative_seek(-3)  # Source is at '4', we move back to '1'
    >>> next(it)
    '1'

    Call :meth:`peek` to look ahead one item without advancing the iterator:

    >>> it = seekable('1234')
    >>> it.peek()
    '1'
    >>> list(it)
    ['1', '2', '3', '4']
    >>> it.peek(default='empty')
    'empty'

    Before the iterator is at its end, calling :func:`bool` on it will return
    ``True``. After it will return ``False``:

    >>> it = seekable('5678')
    >>> bool(it)
    True
    >>> list(it)
    ['5', '6', '7', '8']
    >>> bool(it)
    False

    You may view the contents of the cache with the :meth:`elements` method.
    That returns a :class:`SequenceView`, a view that updates automatically:

    >>> it = seekable((str(n) for n in range(10)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> elements = it.elements()
    >>> elements
    SequenceView(['0', '1', '2'])
    >>> next(it)
    '3'
    >>> elements
    SequenceView(['0', '1', '2', '3'])

    By default, the cache grows as the source iterable progresses, so beware of
    wrapping very large or infinite iterables. Supply *maxlen* to limit the
    size of the cache (this of course limits how far back you can seek).

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()), maxlen=2)
    >>> next(it), next(it), next(it), next(it)
    ('0', '1', '2', '3')
    >>> list(it.elements())
    ['2', '3']
    >>> it.seek(0)
    >>> next(it), next(it), next(it), next(it)
    ('2', '3', '4', '5')
    >>> next(it)
    '6'

    """

    def __init__(self, iterable, maxlen=None):
        # The wrapped iterator; consumed only when the cache is exhausted.
        self._source = iter(iterable)
        # Cache of items seen so far; bounded deque when maxlen is given.
        if maxlen is None:
            self._cache = []
        else:
            self._cache = deque([], maxlen)
        # Read position within the cache, or None when reading directly
        # from the source.
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        # Serve from the cache while a seek position is active.
        if self._index is not None:
            try:
                item = self._cache[self._index]
            except IndexError:
                # Ran off the end of the cache; resume reading the source.
                self._index = None
            else:
                self._index += 1
                return item

        # Pull from the source and remember the item for future seeks.
        item = next(self._source)
        self._cache.append(item)
        return item

    def __bool__(self):
        # Truthy while at least one more item is available.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        # Advance one item, then step the read position back so the next
        # __next__ call returns the same item again.
        try:
            peeked = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return peeked

    def elements(self):
        # A live, read-only view of the cache contents.
        return SequenceView(self._cache)

    def seek(self, index):
        self._index = index
        # Seeking past the cached region consumes the source to catch up.
        remainder = index - len(self._cache)
        if remainder > 0:
            consume(self, remainder)

    def relative_seek(self, count):
        # Anchor at the end of the cache if we're reading from the source.
        if self._index is None:
            self._index = len(self._cache)

        self.seek(max(self._index + count, 0))
class run_length:
    """
    :func:`run_length.encode` compresses an iterable with run-length
    encoding. It yields groups of repeated items with the count of how many
    times they were repeated:

    >>> list(run_length.encode('abbcccdddd'))
    [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` decompresses an iterable that was previously
    compressed with run-length encoding. It yields the items of the
    decompressed iterable:

    >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> list(run_length.decode(compressed))
    ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        # Each groupby group is a run of equal values; record its length.
        return ((value, ilen(run)) for value, run in groupby(iterable))

    @staticmethod
    def decode(iterable):
        # Expand each (value, count) pair back into `count` copies.
        runs = starmap(repeat, iterable)
        return chain.from_iterable(runs)
def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly ``n`` items in the iterable are ``True``
    according to the *predicate* function.

    >>> exactly_n([True, True, False], 2)
    True
    >>> exactly_n([True, True, False], 1)
    False
    >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
    True

    The iterable will be advanced until ``n + 1`` truthy items are
    encountered, so avoid calling it on infinite iterables.

    """
    # Examine at most n + 1 matches: finding any more proves the count
    # isn't exactly n, so there's no need to exhaust the input.
    matches = filter(predicate, iterable)
    return ilen(islice(matches, n + 1)) == n
def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

    >>> list(circular_shifts(range(4)))
    [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    Set *steps* to the number of places to rotate to the left
    (or to the right if negative). Defaults to 1.

    >>> list(circular_shifts(range(4), 2))
    [(0, 1, 2, 3), (2, 3, 0, 1)]

    >>> list(circular_shifts(range(4), -1))
    [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    :raises ValueError: if *steps* is zero
    """
    ring = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Pre-rotate in the opposite direction so the first yield is the
    # unshifted input.
    ring.rotate(steps)
    shift = -steps

    # Number of distinct shifts before the cycle repeats itself.
    size = len(ring)
    total = size // math.gcd(size, shift)

    for _ in range(total):
        ring.rotate(shift)
        yield tuple(ring)
def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function
    that modifies an iterable. *result_index* is the position in that
    function's signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition, augmenting what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

    >>> from more_itertools import chunked
    >>> chunker = make_decorator(chunked, result_index=0)
    >>> @chunker(3)
    ... def iter_range(n):
    ...     return iter(range(n))
    ...
    >>> list(iter_range(9))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

    >>> truth_serum = make_decorator(filter, result_index=1)
    >>> @truth_serum(bool)
    ... def boolean_test():
    ...     return [0, 1, '', ' ', False, True]
    ...
    >>> list(boolean_test())
    [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators.
    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                # Call the decorated function, then splice its result into
                # wrapping_func's positional arguments at result_index.
                result = f(*args, **kwargs)
                spliced_args = [*wrapping_args]
                spliced_args.insert(result_index, result)
                return wrapping_func(*spliced_args, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Return a dictionary that maps the items in *iterable* to categories
    defined by *keyfunc*, transforms them with *valuefunc*, and
    then summarizes them by category with *reducefunc*.

    *valuefunc* defaults to the identity function if it is unspecified.
    If *reducefunc* is unspecified, no summarization takes place:

    >>> keyfunc = lambda x: x.upper()
    >>> sorted(map_reduce('abbccc', keyfunc).items())
    [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Specifying *valuefunc* transforms the categorized items:

    >>> sorted(map_reduce('abbccc', keyfunc, lambda x: 1).items())
    [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Specifying *reducefunc* summarizes the categorized items:

    >>> sorted(map_reduce('abbccc', keyfunc, lambda x: 1, sum).items())
    [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the map/reduce
    procedure:

    >>> items = [x for x in range(30) if 10 <= x <= 20]  # Filter
    >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
    >>> sorted(map_reduce(items, keyfunc=keyfunc, reducefunc=sum).items())
    [(0, 90), (1, 75)]

    Note that all items in the iterable are gathered into a list before the
    summarization step, which may require significant storage.

    The returned object is a :obj:`collections.defaultdict` with the
    ``default_factory`` set to ``None``, such that it behaves like a normal
    dictionary.

    """
    buckets = defaultdict(list)

    for item in iterable:
        # Compute the key first, matching the documented call order.
        key = keyfunc(item)
        if valuefunc is None:
            buckets[key].append(item)
        else:
            buckets[key].append(valuefunc(item))

    if reducefunc is not None:
        # Summarize each category in place.
        for key in buckets:
            buckets[key] = reducefunc(buckets[key])

    # Disable auto-vivification so it behaves like a plain dict from now on.
    buckets.default_factory = None
    return buckets
def rlocate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``, starting from the right and moving left.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
    [4, 2, 1]

    Set *pred* to a custom function to, e.g., find the indexes for a
    particular item:

    >>> list(rlocate(iter('abcb'), lambda x: x == 'b'))
    [3, 1]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(rlocate(iterable, pred=pred, window_size=3))
    [5, 1]

    Beware, this function won't return anything for infinite iterables.
    If *iterable* is reversible, ``rlocate`` will reverse it and search from
    the right. Otherwise, it will search from the left and return the
    results in reverse order.

    See :func:`locate` for other example applications.

    """
    if window_size is None:
        try:
            # Sized, reversible inputs can be searched lazily from the
            # right; both len() and reversed() may raise TypeError here.
            length = len(iterable)
            reversed_it = reversed(iterable)
        except TypeError:
            pass
        else:
            return (length - i - 1 for i in locate(reversed_it, pred))

    # Fall back to a full left-to-right search, reversed at the end.
    return reversed(list(locate(iterable, pred, window_size)))
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items of *iterable*, splicing in the items of
    *substitutes* wherever *pred* matches.

    *count*, if given, caps the number of replacements.  *window_size*
    controls how many consecutive items are passed to *pred*, which
    allows locating and replacing whole sub-sequences.

    Raises ``ValueError`` if *window_size* is less than 1.
    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # *substitutes* is replayed once per match, so realize it up front.
    subs = tuple(substitutes)

    # Pad the tail with markers so every item begins exactly one window.
    padded = chain(iterable, repeat(_marker, window_size - 1))
    windows = windowed(padded, window_size)

    replacements = 0
    for window in windows:
        # On a match (within the replacement budget), emit the
        # substitutes and skip the overlapping windows.  E.g. with
        # window size 2 and windows (0, 1), (1, 2), (2, 3)..., a match
        # on (0, 1) must also discard (1, 2).
        if pred(*window) and (count is None or replacements < count):
            replacements += 1
            yield from subs
            consume(windows, window_size - 1)
        elif window and window[0] is not _marker:
            # No match (or budget spent): pass the leading item through.
            yield window[0]
def partitions(iterable):
    """Yield every order-preserving partition of *iterable*.

        >>> for part in partitions('abc'):
        ...     print([''.join(p) for p in part])
        ['abc']
        ['a', 'bc']
        ['ab', 'c']
        ['a', 'b', 'c']

    Each partition is a list of contiguous slices of the input.  This is
    unrelated to :func:`partition`.
    """
    items = list(iterable)
    size = len(items)
    # Every subset of interior cut points defines one partition.
    for cuts in powerset(range(1, size)):
        bounds = (0,) + cuts + (size,)
        yield [items[lo:hi] for lo, hi in zip(bounds, bounds[1:])]
def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """Yield the set partitions of *iterable* into *k* blocks.  Set
    partitions are not order-preserving.

        >>> for part in set_partitions('abc', 2):
        ...     print([''.join(p) for p in part])
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']

    When *k* is None, partitions of every size are generated.
    *min_size* and *max_size*, if given, restrict the permitted block
    sizes.

    Raises ``ValueError`` when *k* is less than 1; yields nothing when
    the constraints cannot be satisfied.
    """
    items = list(iterable)
    n = len(items)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        elif k > n:
            return

    min_size = 0 if min_size is None else min_size
    max_size = n if max_size is None else max_size
    if min_size > max_size:
        return

    def helper(seq, blocks):
        # Recursively partition *seq* into exactly *blocks* blocks.
        if blocks == 1:
            yield [seq]
        elif len(seq) == blocks:
            yield [[member] for member in seq]
        else:
            head, *rest = seq
            # Either *head* forms a singleton block...
            for part in helper(rest, blocks - 1):
                yield [[head], *part]
            # ...or it joins each block of a smaller partition in turn.
            for part in helper(rest, blocks):
                for i in range(len(part)):
                    yield part[:i] + [[head] + part[i]] + part[i + 1 :]

    def sizes_ok(part):
        return all(min_size <= len(block) <= max_size for block in part)

    if k is None:
        for blocks in range(1, n + 1):
            yield from filter(sizes_ok, helper(items, blocks))
    else:
        yield from filter(sizes_ok, helper(items, k))
class time_limited:
    """Yield items from *iterable* until *limit_seconds* have passed.

    If the time limit expires before the source is exhausted, iteration
    stops and the ``timed_out`` attribute is set to ``True``.

    >>> from time import sleep
    >>> def generator():
    ...     yield 1
    ...     yield 2
    ...     sleep(0.2)
    ...     yield 3
    >>> iterable = time_limited(0.1, generator())
    >>> list(iterable)
    [1, 2]
    >>> iterable.timed_out
    True

    Note that the time is checked before each item is yielded, and
    iteration stops if the time elapsed is greater than *limit_seconds*.
    If your time limit is 1 second, but it takes 2 seconds to generate
    the first item from the iterable, the function will run for 2
    seconds and not yield anything.  As a special case, when
    *limit_seconds* is zero, the iterator never returns anything.

    Raises ``ValueError`` if *limit_seconds* is negative.
    """

    def __init__(self, limit_seconds, iterable):
        if limit_seconds < 0:
            # Zero is allowed (the "yield nothing" special case), so
            # only reject strictly negative limits.  The message matches
            # the actual check, which accepts zero.
            raise ValueError('limit_seconds must be non-negative')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        # The clock starts at construction, not at first iteration.
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.limit_seconds == 0:
            self.timed_out = True
            raise StopIteration
        item = next(self._iterator)
        # The check happens after pulling the item, so a slow producer
        # can overrun the limit (see the class docstring).
        if monotonic() - self._start_time > self.limit_seconds:
            self.timed_out = True
            raise StopIteration

        return item
def only(iterable, default=None, too_long=None):
    """Return the sole item of *iterable*.

    Return *default* when the iterable is empty.  When it holds more
    than one item, raise *too_long* (``ValueError`` by default).

        >>> only([], default='missing')
        'missing'
        >>> only([1])
        1

    Note that the iterable is advanced up to twice, so use :func:`spy`
    or :func:`peekable` for a less destructive check.
    """
    sentinel = object()
    iterator = iter(iterable)

    head = next(iterator, sentinel)
    if head is sentinel:
        return default

    tail = next(iterator, sentinel)
    if tail is not sentinel:
        msg = (
            f'Expected exactly one item in iterable, but got {head!r}, '
            f'{tail!r}, and perhaps more.'
        )
        raise too_long or ValueError(msg)

    return head
3547def _ichunk(iterator, n):
3548 cache = deque()
3549 chunk = islice(iterator, n)
3551 def generator():
3552 with suppress(StopIteration):
3553 while True:
3554 if cache:
3555 yield cache.popleft()
3556 else:
3557 yield next(chunk)
3559 def materialize_next(n=1):
3560 # if n not specified materialize everything
3561 if n is None:
3562 cache.extend(chunk)
3563 return len(cache)
3565 to_cache = n - len(cache)
3567 # materialize up to n
3568 if to_cache > 0:
3569 cache.extend(islice(chunk, to_cache))
3571 # return number materialized up to n
3572 return min(n, len(cache))
3574 return (generator(), materialize_next)
def ichunked(iterable, n):
    """Break *iterable* into sub-iterables with *n* elements each.

    Like :func:`chunked`, but yields iterables instead of lists.  Chunks
    read in order stream directly from the source without extra storage;
    reading them out of order caches the skipped elements.

        >>> from itertools import count
        >>> all_chunks = ichunked(count(), 4)
        >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
        >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
        [4, 5, 6, 7]
        >>> list(c_1)
        [0, 1, 2, 3]
        >>> list(c_3)
        [8, 9, 10, 11]

    """
    source = iter(iterable)
    while True:
        chunk, materialize_next = _ichunk(source, n)

        # An empty first materialization means the source is exhausted.
        if not materialize_next():
            return

        yield chunk

        # Cache this chunk's unread items before carving the next chunk
        # out of the source.
        materialize_next(None)
def iequals(*iterables):
    """Return ``True`` if all *iterables* contain equal elements in the
    same order.

        >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
        True

    Useful for comparing iterables of different types or ones that don't
    support equality directly.  Not to be confused with
    :func:`all_equal`, which compares the elements of one iterable to
    each other.
    """
    try:
        paired = zip(*iterables, strict=True)
        return all(all_equal(column) for column in paired)
    except ValueError:
        # Raised by the strict zip when the lengths differ.
        return False
def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

    >>> list(distinct_combinations([0, 0, 1], 2))
    [(0, 0), (0, 1)]

    Equivalent to ``set(combinations(iterable))``, except duplicates are not
    generated and thrown away. For larger input sequences this is much more
    efficient.

    Raises ``ValueError`` if *r* is negative.
    """
    if r < 0:
        raise ValueError('r must be non-negative')
    elif r == 0:
        # There is exactly one zero-length combination.
        yield ()
        return
    pool = tuple(iterable)
    # Iterative depth-first search over the pool.  generators[level]
    # yields the (index, value) candidates for combination position
    # *level*, deduplicated by value via unique_everseen so that equal
    # pool items are tried only once per position.
    generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
    current_combo = [None] * r
    level = 0
    while generators:
        try:
            cur_idx, p = next(generators[-1])
        except StopIteration:
            # Candidates at this depth are exhausted; backtrack.
            generators.pop()
            level -= 1
            continue
        current_combo[level] = p
        if level + 1 == r:
            # The combination is complete.
            yield tuple(current_combo)
        else:
            # Descend: candidates for the next slot are the distinct
            # pool values strictly after the position just chosen.
            generators.append(
                unique_everseen(
                    enumerate(pool[cur_idx + 1 :], cur_idx + 1),
                    key=itemgetter(1),
                )
            )
            level += 1
def filter_except(validator, iterable, *exceptions):
    """Yield the items of *iterable* for which ``validator(item)`` does
    not raise one of the given *exceptions*.

        >>> iterable = ['1', '2', 'three', '4', None]
        >>> list(filter_except(int, iterable, ValueError, TypeError))
        ['1', '2', '4']

    *validator* accepts one argument; any exception it raises that is
    not listed in *exceptions* propagates normally.
    """
    for candidate in iterable:
        try:
            validator(candidate)
        except exceptions:
            # Invalid item: silently skip it.
            continue
        yield candidate
def map_except(function, iterable, *exceptions):
    """Apply *function* to each item of *iterable* and yield the result,
    silently skipping items for which *function* raises one of the given
    *exceptions*.

        >>> iterable = ['1', '2', 'three', '4', None]
        >>> list(map_except(int, iterable, ValueError, TypeError))
        [1, 2, 4]

    Any exception not listed in *exceptions* propagates normally.
    """
    for item in iterable:
        with suppress(*exceptions):
            yield function(item)
def map_if(iterable, pred, func, func_else=None):
    """Yield ``func(item)`` for items where ``pred(item)`` is truthy and
    ``func_else(item)`` for the rest.

    *pred*, *func*, and *func_else* each take one argument.  By default
    *func_else* is the identity, so non-matching items pass through
    unchanged.

        >>> list(map_if(range(-2, 3), lambda x: x > 1, lambda x: 'big'))
        [-2, -1, 0, 1, 'big']
    """
    transform_else = (lambda item: item) if func_else is None else func_else
    for item in iterable:
        yield func(item) if pred(item) else transform_else(item)
3749def _sample_unweighted(iterator, k, strict):
3750 # Algorithm L in the 1994 paper by Kim-Hung Li:
3751 # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
3753 reservoir = list(islice(iterator, k))
3754 if strict and len(reservoir) < k:
3755 raise ValueError('Sample larger than population')
3756 W = 1.0
3758 with suppress(StopIteration):
3759 while True:
3760 W *= random() ** (1 / k)
3761 skip = floor(log(random()) / log1p(-W))
3762 element = next(islice(iterator, skip, None))
3763 reservoir[randrange(k)] = element
3765 shuffle(reservoir)
3766 return reservoir
3769def _sample_weighted(iterator, k, weights, strict):
3770 # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
3771 # "Weighted random sampling with a reservoir".
3773 # Log-transform for numerical stability for weights that are small/large
3774 weight_keys = (log(random()) / weight for weight in weights)
3776 # Fill up the reservoir (collection of samples) with the first `k`
3777 # weight-keys and elements, then heapify the list.
3778 reservoir = take(k, zip(weight_keys, iterator))
3779 if strict and len(reservoir) < k:
3780 raise ValueError('Sample larger than population')
3782 heapify(reservoir)
3784 # The number of jumps before changing the reservoir is a random variable
3785 # with an exponential distribution. Sample it using random() and logs.
3786 smallest_weight_key, _ = reservoir[0]
3787 weights_to_skip = log(random()) / smallest_weight_key
3789 for weight, element in zip(weights, iterator):
3790 if weight >= weights_to_skip:
3791 # The notation here is consistent with the paper, but we store
3792 # the weight-keys in log-space for better numerical stability.
3793 smallest_weight_key, _ = reservoir[0]
3794 t_w = exp(weight * smallest_weight_key)
3795 r_2 = uniform(t_w, 1) # generate U(t_w, 1)
3796 weight_key = log(r_2) / weight
3797 heapreplace(reservoir, (weight_key, element))
3798 smallest_weight_key, _ = reservoir[0]
3799 weights_to_skip = log(random()) / smallest_weight_key
3800 else:
3801 weights_to_skip -= weight
3803 ret = [element for weight_key, element in reservoir]
3804 shuffle(ret)
3805 return ret
3808def _sample_counted(population, k, counts, strict):
3809 element = None
3810 remaining = 0
3812 def feed(i):
3813 # Advance *i* steps ahead and consume an element
3814 nonlocal element, remaining
3816 while i + 1 > remaining:
3817 i = i - remaining
3818 element = next(population)
3819 remaining = next(counts)
3820 remaining -= i + 1
3821 return element
3823 with suppress(StopIteration):
3824 reservoir = []
3825 for _ in range(k):
3826 reservoir.append(feed(0))
3828 if strict and len(reservoir) < k:
3829 raise ValueError('Sample larger than population')
3831 with suppress(StopIteration):
3832 W = 1.0
3833 while True:
3834 W *= random() ** (1 / k)
3835 skip = floor(log(random()) / log1p(-W))
3836 element = feed(skip)
3837 reservoir[randrange(k)] = element
3839 shuffle(reservoir)
3840 return reservoir
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a *k*-length list of elements chosen (without replacement)
    from *iterable*.

    Similar to :func:`random.sample`, but works on non-indexable inputs
    (such as sets and dictionaries) and on inputs whose size isn't known
    in advance (such as generators).

        >>> sample(range(100), 5)  # doctest: +SKIP
        [81, 60, 96, 16, 4]

    Supply *counts* to express repeated elements, or *weights* (an
    iterable of numbers) for weighted selection; the two are mutually
    exclusive and mixing them raises ``TypeError``.  Weighted selections
    are made without replacement: after an element is selected it is
    removed from the pool and the relative weights of the rest increase
    (this does not match :func:`random.sample`'s *counts* behavior).

    If *iterable* has fewer than *k* elements, ``ValueError`` is raised
    when *strict* is true; otherwise all elements are returned in
    shuffled order.

    Unweighted sampling uses `Algorithm L <https://w.wiki/ANrM>`__ and
    weighted sampling uses `Algorithm A-ExpJ <https://w.wiki/ANrS>`__.
    Reproducibility caveats: both algorithms rely on inexact
    floating-point math functions, which can differ slightly across
    builds, and both select by ordinal position, so unordered inputs
    (such as sets) won't reproduce across sessions even with the same
    seed.
    """
    iterator = iter(iterable)

    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []

    if weights is not None and counts is not None:
        raise TypeError('weights and counts are mutually exclusive')
    if weights is not None:
        return _sample_weighted(iterator, k, iter(weights), strict)
    if counts is not None:
        return _sample_counted(iterator, k, iter(counts), strict)
    return _sample_unweighted(iterator, k, strict)
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Return ``True`` if the items of *iterable* are in sorted order.

    *key* and *reverse* behave as they do in the built-in
    :func:`sorted`.  With *strict*, equal neighbors also count as
    unsorted:

        >>> is_sorted([1, 2, 2])
        True
        >>> is_sorted([1, 2, 2], strict=True)
        False

    The function returns ``False`` at the first out-of-order pair, so it
    may disagree with :func:`sorted` for objects with unusual comparison
    dynamics (like ``math.nan``).  Otherwise the input is exhausted.
    """
    keyed = iterable if key is None else map(key, iterable)
    left, right = tee(keyed)
    next(right, None)  # *right* lags *left* by one position
    if reverse:
        left, right = right, left
    # Strict order: every adjacent pair must strictly increase.
    # Non-strict: no adjacent pair may strictly decrease.
    if strict:
        return all(map(lt, left, right))
    return not any(map(lt, right, left))
class AbortThread(BaseException):
    """Raised inside a :class:`callback_iter` worker callback to cancel it.

    Derives from ``BaseException`` so that a generic ``except Exception``
    in the wrapped function won't swallow the cancellation.
    """

    pass
class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

    >>> def func(callback=None):
    ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
    ...         if callback:
    ...             callback(i, c)
    ...     return 4


    Use ``with callback_iter(func)`` to get an iterator over the parameters
    that are delivered to the callback.

    >>> with callback_iter(func) as it:
    ...     for args, kwargs in it:
    ...         print(args)
    (1, 'a')
    (2, 'b')
    (3, 'c')

    The function will be called in a background thread. The ``done`` property
    indicates whether it has completed execution.

    >>> it.done
    True

    If it completes successfully, its return value will be available
    in the ``result`` property.

    >>> it.result
    4

    Notes:

    * If the function uses some keyword argument besides ``callback``, supply
      *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently the it is polled for
      output.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        self._func = func
        self._callback_kwd = callback_kwd
        self._aborted = False
        self._future = None
        self._wait_seconds = wait_seconds
        # Lazily import concurrent.futures; a single worker thread runs
        # the wrapped function.
        self._executor = __import__(
            'concurrent.futures'
        ).futures.ThreadPoolExecutor(max_workers=1)
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Make the callback raise AbortThread on its next invocation,
        # then wait for the worker thread to finish.
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    @property
    def done(self):
        # True once the background call has finished (or raised).
        if self._future is None:
            return False
        return self._future.done()

    @property
    def result(self):
        # Return value of *func*; raises RuntimeError while it is still
        # running, or re-raises whatever exception *func* raised.
        if not self.done:
            raise RuntimeError('Function has not yet completed')

        return self._future.result()

    def _reader(self):
        # Generator that drives the background call and yields the
        # (args, kwargs) pairs delivered to the callback via a queue.
        q = Queue()

        def callback(*args, **kwargs):
            # Runs in the worker thread; AbortThread unwinds the worker
            # when the consumer has exited the ``with`` block.
            if self._aborted:
                raise AbortThread('canceled by user')

            q.put((args, kwargs))

        self._future = self._executor.submit(
            self._func, **{self._callback_kwd: callback}
        )

        # Poll the queue until the background call finishes, yielding
        # items as they arrive.
        while True:
            try:
                item = q.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                q.task_done()
                yield item

            if self._future.done():
                break

        # Drain anything delivered between the last poll and completion.
        remaining = []
        while True:
            try:
                item = q.get_nowait()
            except Empty:
                break
            else:
                q.task_done()
                remaining.append(item)
        q.join()
        yield from remaining
def windowed_complete(iterable, n):
    """Yield ``(beginning, middle, end)`` tuples, where ``middle`` is
    each contiguous *n*-item window of *iterable* and ``beginning`` /
    ``end`` are the items before and after it.

        >>> for b, m, e in windowed_complete(range(5), 3):
        ...     print(b, m, e)
        () (0, 1, 2) (3, 4)
        (0,) (1, 2, 3) (4,)
        (0, 1) (2, 3, 4) ()

    *n* must be at least 0 and at most the length of *iterable*; the
    input is fully consumed and stored.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    seq = tuple(iterable)
    size = len(seq)

    if n > size:
        raise ValueError('n must be <= len(seq)')

    for start in range(size - n + 1):
        stop = start + n
        yield seq[:start], seq[start:stop], seq[stop:]
def all_unique(iterable, key=None):
    """Return ``True`` if no two elements of *iterable* are equal.

        >>> all_unique('ABCB')
        False

    A *key* function, if given, is applied before comparison:

        >>> all_unique('ABCb', str.lower)
        False

    The scan stops at the first duplicate.  Iterables mixing hashable
    and unhashable items are supported, but unhashable items are slower
    (linear-search membership test).
    """
    hashable_seen = set()
    unhashable_seen = []
    elements = iterable if not key else map(key, iterable)
    for element in elements:
        try:
            if element in hashable_seen:
                return False
            hashable_seen.add(element)
        except TypeError:
            # Unhashable element: fall back to the list.
            if element in unhashable_seen:
                return False
            unhashable_seen.append(element)
    return True
def nth_product(index, *args):
    """Equivalent to ``list(product(*args))[index]``.

    Computes the product at lexicographic sort position *index* without
    generating the earlier products.

        >>> nth_product(8, range(2), range(2), range(2), range(2))
        (1, 0, 0, 0)

    ``IndexError`` will be raised if *index* is out of range.
    """
    # Work from the least-significant (rightmost) pool outward.
    pools = [tuple(pool) for pool in reversed(args)]
    sizes = [len(pool) for pool in pools]

    total = reduce(mul, sizes)
    if index < 0:
        index += total
    if not 0 <= index < total:
        raise IndexError

    # Peel off one mixed-radix digit per pool.
    selection = []
    for pool, size in zip(pools, sizes):
        index, digit = divmod(index, size)
        selection.append(pool[digit])

    return tuple(reversed(selection))
def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]``.

    Computes the length-*r* permutation at lexicographic sort position
    *index* without generating the earlier permutations.

        >>> nth_permutation('ghijk', 2, 5)
        ('h', 'i')

    ``ValueError`` will be raised if *r* is negative or greater than the
    length of *iterable*; ``IndexError`` if *index* is out of range.
    """
    pool = list(iterable)
    n = len(pool)

    if r is None or r == n:
        r, c = n, factorial(n)
    elif not 0 <= r < n:
        raise ValueError
    else:
        c = perm(n, r)
    assert c > 0  # factorial(n)>0, and r<n so perm(n,r) is never zero

    if index < 0:
        index += c
    if not 0 <= index < c:
        raise IndexError

    # Express *index* in the factorial number system; each digit then
    # selects which of the remaining pool items to take.
    digits = [0] * r
    quotient = index * factorial(n) // c if r < n else index
    for radix in range(1, n + 1):
        quotient, digit = divmod(quotient, radix)
        slot = n - radix
        if 0 <= slot < r:
            digits[slot] = digit
        if quotient == 0:
            break

    # Each digit pops its choice out of the shrinking pool.
    return tuple(map(pool.pop, digits))
def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.

    Computes the length-*r* combination with replacement at
    lexicographic sort position *index* without generating the earlier
    ones.

        >>> nth_combination_with_replacement(range(5), 3, 5)
        (0, 1, 1)

    ``ValueError`` will be raised if *r* is negative or greater than the
    length of *iterable*; ``IndexError`` if *index* is out of range.
    """
    pool = tuple(iterable)
    n = len(pool)
    if not 0 <= r <= n:
        raise ValueError

    total = comb(n + r - 1, r)
    if index < 0:
        index += total
    if not 0 <= index < total:
        raise IndexError

    # Fill one slot at a time, skipping entire blocks of combinations
    # until *index* lands inside the block for the current pool item.
    result = []
    pos = 0
    while r:
        r -= 1
        while n >= 0:
            block = comb(n + r - 1, r)
            if index < block:
                break
            n -= 1
            pos += 1
            index -= block
        result.append(pool[pos])

    return tuple(result)
def value_chain(*args):
    """Yield each argument in order, iterating into any argument that is
    itself iterable.

        >>> list(value_chain(1, 2, 3, [4, 5, 6]))
        [1, 2, 3, 4, 5, 6]

    ``str`` and ``bytes`` are treated as atoms and yielded as-is:

        >>> list(value_chain('12', '34', ['56', '78']))
        ['12', '34', '56', '78']

    Only one level of nesting is flattened.
    """
    for arg in args:
        if isinstance(arg, (str, bytes)):
            yield arg
        else:
            try:
                yield from arg
            except TypeError:
                # Not iterable: emit the argument itself.
                yield arg
def product_index(element, *args):
    """Equivalent to ``list(product(*args)).index(element)``.

    Computes the lexicographic sort position of *element* within the
    product of *args* without generating the earlier products.

        >>> product_index([8, 2], range(10), range(5))
        42

    ``ValueError`` will be raised if *element* isn't in the product of
    *args*.
    """
    values = tuple(element)
    pools = tuple(map(tuple, args))
    if len(values) != len(pools):
        raise ValueError('element is not a product of args')

    # Accumulate a mixed-radix number, one digit per pool.
    index = 0
    for value, pool in zip(values, pools):
        index = index * len(pool) + pool.index(value)
    return index
def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`combination_index` computes the index of the
    first *element*, without computing the previous combinations.

    >>> combination_index('adf', 'abcdefg')
    10

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations of *iterable*.
    """
    element = enumerate(element)
    k, y = next(element, (None, None))
    if k is None:
        # The empty combination sorts first.
        return 0

    # Match the members of *element* against *iterable* in order,
    # recording each member's position in the pool.
    indexes = []
    pool = enumerate(iterable)
    for n, x in pool:
        if x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                # All of *element* has been matched.
                break
            else:
                k = tmp
    else:
        # The pool was exhausted before *element* was.
        raise ValueError('element is not a combination of iterable')

    # Drain the pool so *n* ends up as the index of its last item.
    n, _ = last(pool, default=(n, None))

    # Count the combinations at or after *element*; subtracting from the
    # total comb(n + 1, k + 1) yields the sort position.
    index = 1
    for i, j in enumerate(reversed(indexes), start=1):
        j = n - j
        if i <= j:
            index += comb(j, i)

    return comb(n + 1, k + 1) - index
def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``

    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`combination_with_replacement_index`
    computes the index of the first *element*, without computing the previous
    combinations with replacement.

    >>> combination_with_replacement_index('adf', 'abcdefg')
    20

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations with replacement of *iterable*.
    """
    element = tuple(element)
    l = len(element)
    element = enumerate(element)

    k, y = next(element, (None, None))
    if k is None:
        # The empty combination sorts first.
        return 0

    # Match the members of *element* against the pool in order.  The
    # inner while lets one pool position match repeatedly (replacement).
    indexes = []
    pool = tuple(iterable)
    for n, x in enumerate(pool):
        while x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                # All of *element* has been matched.
                break
            else:
                k = tmp
        if y is None:
            break
    else:
        # The pool was exhausted before *element* was.
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    # occupations[p] counts how many times pool position p was used.
    n = len(pool)
    occupations = [0] * n
    for p in indexes:
        occupations[p] += 1

    # Sum the number of combinations that precede *element*.
    index = 0
    cumulative_sum = 0
    for k in range(1, n):
        cumulative_sum += occupations[k - 1]
        j = l + n - 1 - k - cumulative_sum
        i = n - k
        if i <= j:
            index += comb(j, i)

    return index
def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)```

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`permutation_index`
    computes the index of the first *element* directly, without computing
    the previous permutations.

    >>> permutation_index([1, 3, 2], range(5))
    19

    ``ValueError`` will be raised if the given *element* isn't one of the
    permutations of *iterable*.
    """
    # Factorial number system: each chosen item contributes its position
    # among the not-yet-chosen pool items.
    remaining = list(iterable)
    result = 0
    for radix, value in zip(range(len(remaining), -1, -1), element):
        position = remaining.index(value)
        result = result * radix + position
        remaining.pop(position)

    return result
class countable:
    """Wrap *iterable* and keep a count of how many items have been consumed.

    The ``items_seen`` attribute starts at ``0`` and increments as the iterable
    is consumed:

    >>> iterable = map(str, range(10))
    >>> it = countable(iterable)
    >>> it.items_seen
    0
    >>> next(it), next(it)
    ('0', '1')
    >>> list(it)
    ['2', '3', '4', '5', '6', '7', '8', '9']
    >>> it.items_seen
    10
    """

    def __init__(self, iterable):
        # Public running total of items handed out so far.
        self.items_seen = 0
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        # Let StopIteration from the wrapped iterator propagate before
        # touching the counter, so exhausted calls don't inflate it.
        value = next(self._iterator)
        self.items_seen += 1
        return value
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.
    Items are distributed such the lengths of the lists differ by at most
    1 item.

    >>> iterable = [1, 2, 3, 4, 5, 6, 7]
    >>> n = 3
    >>> list(chunked_even(iterable, n))  # List lengths: 3, 2, 2
    [[1, 2, 3], [4, 5], [6, 7]]
    >>> list(chunked(iterable, n))  # List lengths: 3, 3, 1
    [[1, 2, 3], [4, 5, 6], [7]]

    """
    iterator = iter(iterable)

    # Initialize a buffer to process the chunks while keeping
    # some back to fill any underfilled chunks
    min_buffer = (n - 1) * (n - 2)
    buffer = list(islice(iterator, min_buffer))

    # Append items until we have a completed chunk
    # (each map() element is one append; the step-n islice fires once per n
    # appended items, so the buffer keeps a constant reserve while full
    # n-sized chunks are emitted from its front)
    for _ in islice(map(buffer.append, iterator), n, None, n):
        yield buffer[:n]
        del buffer[:n]

    # Check if any chunks need addition processing
    if not buffer:
        return
    length = len(buffer)

    # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
    q, r = divmod(length, n)
    num_lists = q + (1 if r > 0 else 0)
    q, r = divmod(length, num_lists)
    full_size = q + (1 if r > 0 else 0)
    partial_size = full_size - 1
    num_full = length - partial_size * num_lists

    # Yield chunks of full size
    # (the first num_full chunks get full_size items each)
    partial_start_idx = num_full * full_size
    if full_size > 0:
        for i in range(0, partial_start_idx, full_size):
            yield buffer[i : i + full_size]

    # Yield chunks of partial size
    # (remaining chunks are one item shorter, keeping lengths within 1)
    if partial_size > 0:
        for i in range(partial_start_idx, length, partial_size):
            yield buffer[i : i + partial_size]
4548def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
4549 """A version of :func:`zip` that "broadcasts" any scalar
4550 (i.e., non-iterable) items into output tuples.
4552 >>> iterable_1 = [1, 2, 3]
4553 >>> iterable_2 = ['a', 'b', 'c']
4554 >>> scalar = '_'
4555 >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
4556 [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]
4558 The *scalar_types* keyword argument determines what types are considered
4559 scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
4560 treat strings and byte strings as iterable:
4562 >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
4563 [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]
4565 If the *strict* keyword argument is ``True``, then
4566 ``ValueError`` will be raised if any of the iterables have
4567 different lengths.
4568 """
4570 def is_scalar(obj):
4571 if scalar_types and isinstance(obj, scalar_types):
4572 return True
4573 try:
4574 iter(obj)
4575 except TypeError:
4576 return True
4577 else:
4578 return False
4580 size = len(objects)
4581 if not size:
4582 return
4584 new_item = [None] * size
4585 iterables, iterable_positions = [], []
4586 for i, obj in enumerate(objects):
4587 if is_scalar(obj):
4588 new_item[i] = obj
4589 else:
4590 iterables.append(iter(obj))
4591 iterable_positions.append(i)
4593 if not iterables:
4594 yield tuple(objects)
4595 return
4597 for item in zip(*iterables, strict=strict):
4598 for i, new_item[i] in zip(iterable_positions, item):
4599 pass
4600 yield tuple(new_item)
def unique_in_window(iterable, n, key=None):
    """Yield the items from *iterable* that haven't been seen recently.
    *n* is the size of the sliding window.

    >>> iterable = [0, 1, 0, 2, 3, 0]
    >>> n = 3
    >>> list(unique_in_window(iterable, n))
    [0, 1, 2, 3, 0]

    The *key* function, if provided, will be used to determine uniqueness:

    >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
    ['a', 'b', 'c', 'd', 'a']

    When `n == 1`, *unique_in_window* is memoryless:

    >>> list(unique_in_window('aab', n=1))
    ['a', 'a', 'b']

    The items in *iterable* must be hashable.
    """
    if n <= 0:
        raise ValueError('n must be greater than 0')

    # The deque tracks the keys currently inside the window; the Counter
    # tracks how many copies of each key the window holds.
    window = deque(maxlen=n)
    counts = Counter()

    for item in iterable:
        # Evict the oldest key before the deque rolls it out.
        if len(window) == n:
            oldest = window[0]
            counts[oldest] -= 1
            if not counts[oldest]:
                del counts[oldest]

        k = item if key is None else key(item)
        if k not in counts:
            yield item
        counts[k] += 1
        window.append(k)
def duplicates_everseen(iterable, key=None):
    """Yield duplicate elements after their first appearance.

    >>> list(duplicates_everseen('mississippi'))
    ['s', 'i', 's', 's', 'i', 'p', 'i']
    >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    # Hashable keys go in a set; unhashable keys fall back to a slower list.
    hashable_seen = set()
    unhashable_seen = []

    for element in iterable:
        k = element if key is None else key(element)
        try:
            first_time = k not in hashable_seen
            if first_time:
                hashable_seen.add(k)
        except TypeError:
            first_time = k not in unhashable_seen
            if first_time:
                unhashable_seen.append(k)
        if not first_time:
            yield element
def duplicates_justseen(iterable, key=None):
    """Yields serially-duplicate elements after their first appearance.

    >>> list(duplicates_justseen('mississippi'))
    ['s', 's', 'p']
    >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    This function is analogous to :func:`unique_justseen`.

    """
    # Call groupby eagerly so a non-iterable argument fails immediately.
    groups = groupby(iterable, key)

    def _tails():
        # Skip the first item of every run; the rest are duplicates.
        for _, group in groups:
            next(group, None)
            yield from group

    return _tails()
def classify_unique(iterable, key=None):
    """Classify each element in terms of its uniqueness.

    For each element in the input iterable, return a 3-tuple consisting of:

    1. The element itself
    2. ``False`` if the element is equal to the one preceding it in the input,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
    3. ``False`` if this element has been seen anywhere in the input before,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)

    >>> list(classify_unique('otto'))    # doctest: +NORMALIZE_WHITESPACE
    [('o', True,  True),
     ('t', True,  True),
     ('t', False, False),
     ('o', True,  False)]

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    # Hashable keys use a set; unhashable keys use a list fallback.
    hashable_seen = set()
    unhashable_seen = []
    previous = None

    for position, element in enumerate(iterable):
        k = element if key is None else key(element)

        # The very first element is always "just seen" unique.
        justseen = not position or previous != k
        previous = k

        everseen = False
        try:
            if k not in hashable_seen:
                hashable_seen.add(k)
                everseen = True
        except TypeError:
            if k not in unhashable_seen:
                unhashable_seen.append(k)
                everseen = True

        yield element, justseen, everseen
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Returns both the smallest and largest items from an iterable
    or from two or more arguments.

    >>> minmax([3, 1, 5])
    (1, 5)

    >>> minmax(4, 2, 6)
    (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

    >>> minmax([5, 30], key=str)  # '30' sorts before '5'
    (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

    >>> minmax([], default=(0, 0))
    (0, 0)

    Otherwise ``ValueError`` is raised.

    This function makes a single pass over the input elements and takes care to
    minimize the number of comparisons made during processing.

    Note that unlike the builtin ``max`` function, which always returns the first
    item with the maximum value, this function may return another item when there are
    ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    # Multiple positional arguments are treated as the items themselves.
    iterable = (iterable_or_value, *others) if others else iterable_or_value

    it = iter(iterable)

    try:
        # Seed both results with the first item.
        lo = hi = next(it)
    except StopIteration as exc:
        if default is _marker:
            raise ValueError(
                '`minmax()` argument is an empty iterable. '
                'Provide a `default` value to suppress this error.'
            ) from exc
        return default

    # Different branches depending on the presence of key. This saves a lot
    # of unimportant copies which would slow the "key=None" branch
    # significantly down.
    if key is None:
        # Pull items two at a time: order the pair internally, then compare
        # the smaller against lo and the larger against hi -- three
        # comparisons per pair instead of four.  zip_longest(it, it) pairs
        # consecutive items; an odd-length tail is padded with lo, which
        # cannot change either result.
        for x, y in zip_longest(it, it, fillvalue=lo):
            if y < x:
                x, y = y, x
            if x < lo:
                lo = x
            if hi < y:
                hi = y

    else:
        lo_key = hi_key = key(lo)

        # Same pairwise strategy, but comparisons use the cached key values.
        for x, y in zip_longest(it, it, fillvalue=lo):
            x_key, y_key = key(x), key(y)

            if y_key < x_key:
                x, y, x_key, y_key = y, x, y_key, x_key
            if x_key < lo_key:
                lo, lo_key = x, x_key
            if hi_key < y_key:
                hi, hi_key = y, y_key

    return lo, hi
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* with a combined size limited by
    *max_size*.

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10))
    [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If a *max_count* is supplied, the number of items per batch is also
    limited:

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10, max_count = 2))
    [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    If a *get_len* function is supplied, use that instead of :func:`len` to
    determine item size.

    If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
    than *max_size*. Otherwise, allow single items to exceed *max_size*.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    current = []
    current_size = 0

    for item in iterable:
        size = get_len(item)
        if strict and size > max_size:
            raise ValueError('item size exceeds maximum size')

        # Flush the current batch when adding this item would overflow
        # either the size budget or the item-count budget.
        too_big = current_size + size > max_size
        too_many = len(current) == max_count
        if current and (too_big or too_many):
            yield tuple(current)
            current = []
            current_size = 0

        current.append(item)
        current_size += size

    if current:
        yield tuple(current)
def gray_product(*iterables):
    """Like :func:`itertools.product`, but return tuples in an order such
    that only one element in the generated tuple changes from one iteration
    to the next.

    >>> list(gray_product('AB','CD'))
    [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    This function consumes all of the input iterables before producing output.
    If any of the input iterables have fewer than two items, ``ValueError``
    is raised.

    For information on the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    all_iterables = tuple(tuple(x) for x in iterables)
    iterable_count = len(all_iterables)
    # The algorithm requires every coordinate to have at least two values.
    for iterable in all_iterables:
        if len(iterable) < 2:
            raise ValueError("each iterable must have two or more items")

    # This is based on "Algorithm H" from section 7.2.1.1, page 20.
    # a holds the indexes of the source iterables for the n-tuple to be yielded
    # f is the array of "focus pointers"
    # o is the array of "directions"
    a = [0] * iterable_count
    f = list(range(iterable_count + 1))
    o = [1] * iterable_count
    while True:
        yield tuple(all_iterables[i][a[i]] for i in range(iterable_count))
        # f[0] names the coordinate that moves on this step.
        j = f[0]
        f[0] = 0
        # j == iterable_count is the algorithm's termination sentinel.
        if j == iterable_count:
            break
        # Step coordinate j one position in its current direction.
        a[j] = a[j] + o[j]
        if a[j] == 0 or a[j] == len(all_iterables[j]) - 1:
            # Coordinate j reached an end of its range: reverse its
            # direction and hand its focus pointer along.
            o[j] = -o[j]
            f[j] = f[j + 1]
            f[j + 1] = j + 1
def partial_product(*iterables):
    """Yields tuples containing one item from each iterator, with subsequent
    tuples changing a single item at a time by advancing each iterator until it
    is exhausted. This sequence guarantees every value in each iterable is
    output at least once without generating all possible combinations.

    This may be useful, for example, when testing an expensive function.

    >>> list(partial_product('AB', 'C', 'DEF'))
    [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]
    """
    iterators = [iter(it) for it in iterables]

    # Start with the first item from every iterator; if any iterator is
    # empty there is nothing to yield at all.
    try:
        current = [next(it) for it in iterators]
    except StopIteration:
        return
    yield tuple(current)

    # Advance one coordinate at a time, left to right, yielding after
    # each single-item change.
    for position, iterator in enumerate(iterators):
        for value in iterator:
            current[position] = value
            yield tuple(current)
def takewhile_inclusive(predicate, iterable):
    """A variant of :func:`takewhile` that yields one additional element.

    >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
    [1, 4, 6]

    :func:`takewhile` would return ``[1, 4]``.
    """
    # Yield first, then test -- so the failing element is included.
    for item in iterable:
        yield item
        if not predicate(item):
            return
def outer_product(func, xs, ys, *args, **kwargs):
    """A generalized outer product that applies a binary function to all
    pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
    columns.
    Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.

    Multiplication table:

    >>> list(outer_product(mul, range(1, 4), range(1, 6)))
    [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]

    Cross tabulation:

    >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
    >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
    >>> pair_counts = Counter(zip(xs, ys))
    >>> count_rows = lambda x, y: pair_counts[x, y]
    >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
    [(2, 3, 0), (1, 0, 4)]

    Usage with ``*args`` and ``**kwargs``:

    >>> animals = ['cat', 'wolf', 'mouse']
    >>> list(outer_product(min, animals, animals, key=len))
    [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
    """
    # Materialize ys so its length is known and it can be reused per row.
    ys = tuple(ys)
    values = (func(x, y, *args, **kwargs) for x, y in product(xs, ys))
    # Each row of the matrix is one batch of len(ys) results.
    return batched(values, n=len(ys))
def iter_suppress(iterable, *exceptions):
    """Yield each of the items from *iterable*. If the iteration raises one of
    the specified *exceptions*, that exception will be suppressed and iteration
    will stop.

    >>> from itertools import chain
    >>> def breaks_at_five(x):
    ...     while True:
    ...         if x >= 5:
    ...             raise RuntimeError
    ...         yield x
    ...         x += 1
    >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
    >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
    >>> list(chain(it_1, it_2))
    [1, 2, 3, 4, 2, 3, 4]
    """
    # suppress() exits the block cleanly on a matching exception,
    # which simply ends this generator.
    with suppress(*exceptions):
        yield from iterable
def filter_map(func, iterable):
    """Apply *func* to every element of *iterable*, yielding only those which
    are not ``None``.

    >>> elems = ['1', 'a', '2', 'b', '3']
    >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
    [1, 2, 3]
    """
    # map() applies lazily, so this is a straight filter on the results.
    for result in map(func, iterable):
        if result is not None:
            yield result
def powerset_of_sets(iterable, *, baseset=set):
    """Yields all possible subsets of the iterable.

    >>> list(powerset_of_sets([1, 2, 3]))  # doctest: +SKIP
    [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
    >>> list(powerset_of_sets([1, 1, 0]))  # doctest: +SKIP
    [set(), {1}, {0}, {0, 1}]

    :func:`powerset_of_sets` takes care to minimize the number
    of hash operations performed.

    The *baseset* parameter determines what kind of sets are
    constructed, either *set* or *frozenset*.
    """
    # Wrap each distinct item in a frozenset once, preserving first-seen
    # order, so its hash is computed a single time.
    singletons = tuple(dict.fromkeys(frozenset((item,)) for item in iterable))
    combine = baseset().union
    return chain.from_iterable(
        starmap(combine, combinations(singletons, size))
        for size in range(len(singletons) + 1)
    )
def join_mappings(**field_to_map):
    """
    Joins multiple mappings together using their common keys.

    >>> user_scores = {'elliot': 50, 'claris': 60}
    >>> user_times = {'elliot': 30, 'claris': 40}
    >>> join_mappings(score=user_scores, time=user_times)
    {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    joined = {}
    for field_name, mapping in field_to_map.items():
        for key, value in mapping.items():
            # Each key accumulates one entry per field it appears in.
            joined.setdefault(key, {})[field_name] = value
    return joined
def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.
    """
    # (a+bi)(c+di) = (ac - bd) + (ad + bc)i; each part is accumulated as
    # one long real-valued sum of products.
    reals_1 = chain((z.real for z in v1), (-z.imag for z in v1))
    reals_2 = chain((z.real for z in v2), (z.imag for z in v2))
    imags_1 = chain((z.real for z in v1), (z.imag for z in v1))
    imags_2 = chain((z.imag for z in v2), (z.real for z in v2))
    return complex(_fsumprod(reals_1, reals_2), _fsumprod(imags_1, imags_2))
def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the corresponding transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
    >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
    >>> all(map(cmath.isclose, dft(xarr), Xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`idft` for the inverse Discrete Fourier Transform.
    """
    N = len(xarr)
    # Every coefficient is one of the N complex roots of unity, so
    # precompute them once.
    roots_of_unity = [e ** (n / N * tau * -1j) for n in range(N)]
    for k in range(N):
        row = [roots_of_unity[k * n % N] for n in range(N)]
        yield _complex_sumprod(xarr, row)
def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the corresponding
    inverse-transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
    >>> all(map(cmath.isclose, idft(Xarr), xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`dft` for the Discrete Fourier Transform.
    """
    N = len(Xarr)
    # Same roots as dft() but with a positive exponent; results are
    # scaled down by N.
    roots_of_unity = [e ** (n / N * tau * 1j) for n in range(N)]
    for k in range(N):
        row = [roots_of_unity[k * n % N] for n in range(N)]
        yield _complex_sumprod(Xarr, row) / N
def doublestarmap(func, iterable):
    """Apply *func* to every item of *iterable* by dictionary unpacking
    the item into *func*.

    The difference between :func:`itertools.starmap` and :func:`doublestarmap`
    parallels the distinction between ``func(*a)`` and ``func(**a)``.

    >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
    >>> list(doublestarmap(lambda a, b: a + b, iterable))
    [3, 100]

    ``TypeError`` will be raised if *func*'s signature doesn't match the
    mapping contained in *iterable* or if *iterable* does not contain mappings.
    """
    for mapping in iterable:
        yield func(**mapping)
5133def _nth_prime_bounds(n):
5134 """Bounds for the nth prime (counting from 1): lb < p_n < ub."""
5135 # At and above 688,383, the lb/ub spread is under 0.003 * p_n.
5137 if n < 1:
5138 raise ValueError
5140 if n < 6:
5141 return (n, 2.25 * n)
5143 # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
5144 upper_bound = n * log(n * log(n))
5145 lower_bound = upper_bound - n
5146 if n >= 688_383:
5147 upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))
5149 return lower_bound, upper_bound
def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

    >>> nth_prime(0)
    2
    >>> nth_prime(100)
    547

    If *approximate* is set to True, will return a prime close
    to the nth prime. The estimation is much faster than computing
    an exact result.

    >>> nth_prime(200_000_000, approximate=True) # Exact result is 4222234763
    4217820427
    """
    lower, upper = _nth_prime_bounds(n + 1)

    # Small inputs are cheap enough to sieve exactly, even when an
    # approximation was requested.
    if not approximate or n <= 1_000_000:
        return nth(sieve(ceil(upper)), n)

    # Otherwise start from the midpoint of the bounds and return the
    # first odd prime at or after it.
    start = floor((lower + upper) / 2) | 1
    return first_true(count(start, step=2), pred=is_prime)
def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

    >>> argmin('efghabcdijkl')
    4
    >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
    3

    For example, look up a label corresponding to the position
    of a value that minimizes a cost function::

        >>> def cost(x):
        ...     "Days for a wound to heal given a subject's age."
        ...     return x**2 - 20*x + 150
        ...
        >>> labels = ['homer', 'marge', 'bart', 'lisa', 'maggie']
        >>> ages = [ 35, 30, 10, 9, 1 ]

        # Fastest healing family member
        >>> labels[argmin(ages, key=cost)]
        'bart'

        # Age with fastest healing
        >>> min(ages, key=cost)
        10

    """
    # Compare (index, value) pairs on the value; min() keeps the first
    # minimum, so ties resolve to the earliest index.
    values = iterable if key is None else map(key, iterable)
    best_index, _ = min(enumerate(values), key=itemgetter(1))
    return best_index
def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

    >>> argmax('abcdefghabcd')
    7
    >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
    3

    For example, identify the best machine learning model::

        >>> models = ['svm', 'random forest', 'knn', 'naïve bayes']
        >>> accuracy = [ 68, 61, 84, 72 ]

        # Most accurate model
        >>> models[argmax(accuracy)]
        'knn'

        # Best accuracy
        >>> max(accuracy)
        84

    """
    # Compare (index, value) pairs on the value; max() keeps the first
    # maximum, so ties resolve to the earliest index.
    values = iterable if key is None else map(key, iterable)
    best_index, _ = max(enumerate(values), key=itemgetter(1))
    return best_index
def extract(iterable, indices):
    """Yield values at the specified indices.

    Example:

    >>> data = 'abcdefghijklmnopqrstuvwxyz'
    >>> list(extract(data, [7, 4, 11, 11, 14]))
    ['h', 'e', 'l', 'l', 'o']

    The *iterable* is consumed lazily and can be infinite.
    The *indices* are consumed immediately and must be finite.

    Raises ``IndexError`` if an index lies beyond the iterable.
    Raises ``ValueError`` for negative indices.
    """
    iterator = iter(iterable)

    # Pair each index with its output position, then process the indices
    # in ascending order so the iterable is traversed only once.
    requests = sorted(zip(indices, count()))
    if requests and requests[0][0] < 0:
        raise ValueError('Indices must be non-negative')

    pending = {}
    current = -1
    emit = 0

    for target, order in requests:
        gap = target - current
        if gap:
            # Skip ahead to the target index; a zero gap (duplicate
            # index) reuses the previously fetched value.
            try:
                value = next(islice(iterator, gap - 1, None))
            except StopIteration:
                raise IndexError(target)
            current = target

        pending[order] = value

        # Emit results in the caller's original index order.
        while emit in pending:
            yield pending.pop(emit)
            emit += 1