Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/more_itertools/more.py: 19%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import math
2import warnings
4from collections import Counter, defaultdict, deque, abc
5from collections.abc import Sequence
6from contextlib import suppress
7from functools import cached_property, partial, reduce, wraps
8from heapq import heapify, heapreplace
9from itertools import (
10 chain,
11 combinations,
12 compress,
13 count,
14 cycle,
15 dropwhile,
16 groupby,
17 islice,
18 permutations,
19 repeat,
20 starmap,
21 takewhile,
22 tee,
23 zip_longest,
24 product,
25)
26from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
27from math import ceil
28from queue import Empty, Queue
29from random import random, randrange, shuffle, uniform
30from operator import (
31 attrgetter,
32 is_not,
33 itemgetter,
34 lt,
35 mul,
36 neg,
37 sub,
38 gt,
39)
40from sys import hexversion, maxsize
41from time import monotonic
43from .recipes import (
44 _marker,
45 _zip_equal,
46 UnequalIterablesError,
47 consume,
48 first_true,
49 flatten,
50 is_prime,
51 nth,
52 powerset,
53 sieve,
54 take,
55 unique_everseen,
56 all_equal,
57 batched,
58)
# Public names exported by ``from more_itertools import *``.
# Kept in ASCII sort order (class-style names first) so additions are
# easy to diff; several entries had drifted out of order.
__all__ = [
    'AbortThread',
    'SequenceView',
    'UnequalIterablesError',
    'adjacent',
    'all_unique',
    'always_iterable',
    'always_reversible',
    'argmax',
    'argmin',
    'bucket',
    'callback_iter',
    'chunked',
    'chunked_even',
    'circular_shifts',
    'classify_unique',
    'collapse',
    'combination_index',
    'combination_with_replacement_index',
    'consecutive_groups',
    'constrained_batches',
    'consumer',
    'count_cycle',
    'countable',
    'derangements',
    'dft',
    'difference',
    'distinct_combinations',
    'distinct_permutations',
    'distribute',
    'divide',
    'doublestarmap',
    'duplicates_everseen',
    'duplicates_justseen',
    'exactly_n',
    'filter_except',
    'filter_map',
    'first',
    'gray_product',
    'groupby_transform',
    'ichunked',
    'idft',
    'iequals',
    'ilen',
    'interleave',
    'interleave_evenly',
    'interleave_longest',
    'intersperse',
    'is_sorted',
    'islice_extended',
    'iter_suppress',
    'iterate',
    'join_mappings',
    'last',
    'locate',
    'longest_common_prefix',
    'lstrip',
    'make_decorator',
    'map_except',
    'map_if',
    'map_reduce',
    'mark_ends',
    'minmax',
    'nth_combination_with_replacement',
    'nth_or_last',
    'nth_permutation',
    'nth_prime',
    'nth_product',
    'numeric_range',
    'one',
    'only',
    'outer_product',
    'padded',
    'partial_product',
    'partitions',
    'peekable',
    'permutation_index',
    'powerset_of_sets',
    'product_index',
    'raise_',
    'repeat_each',
    'repeat_last',
    'replace',
    'rlocate',
    'rstrip',
    'run_length',
    'sample',
    'seekable',
    'set_partitions',
    'side_effect',
    'sliced',
    'sort_together',
    'split_after',
    'split_at',
    'split_before',
    'split_into',
    'split_when',
    'spy',
    'stagger',
    'strictly_n',
    'strip',
    'substrings',
    'substrings_indexes',
    'takewhile_inclusive',
    'time_limited',
    'unique_in_window',
    'unique_to_each',
    'unzip',
    'value_chain',
    'windowed',
    'windowed_complete',
    'with_iter',
    'zip_broadcast',
    'zip_equal',
    'zip_offset',
]
# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _fsumprod

except ImportError:  # pragma: no cover
    # Pre-3.12 fallback: an accurate sum of products built from the
    # extended precision algorithms of T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8). Code: mul12()

    def dl_split(x: float):
        "Split a float into two half-precision components."
        t = x * 134217729.0  # Veltkamp constant = 2.0 ** 27 + 1
        hi = t - (t - x)
        lo = x - hi
        return hi, lo

    def dl_mul(x, y):
        "Lossless multiplication."
        # Each operand is split so that the partial products below are
        # exact in double precision; (z, zz) sums exactly to x * y.
        xx_hi, xx_lo = dl_split(x)
        yy_hi, yy_lo = dl_split(y)
        p = xx_hi * yy_hi
        q = xx_hi * yy_lo + xx_lo * yy_hi
        z = p + q
        zz = p - z + q + xx_lo * yy_lo
        return z, zz

    def _fsumprod(p, q):
        # Expand every pairwise product into its exact (hi, lo) pair and
        # let fsum add all components without intermediate rounding.
        return fsum(chain.from_iterable(map(dl_mul, p, q)))
def chunked(iterable, n, strict=False):
    """Break *iterable* into lists of length *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
    [[1, 2, 3], [4, 5, 6]]

    By default the last list may be shorter than *n* when the length of
    *iterable* is not divisible by *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
    [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    With ``strict=True``, a ``ValueError`` is raised (while the result is
    being consumed) if the length of *iterable* is not divisible by *n*.
    """
    # iter() with a sentinel: keep taking n-item chunks until an empty
    # chunk signals exhaustion.
    chunk_iterator = iter(partial(take, n, iter(iterable)), [])
    if not strict:
        return chunk_iterator

    if n is None:
        raise ValueError('n must not be None when using strict mode.')

    def validated():
        # Re-yield each chunk, rejecting any short (final) chunk.
        for piece in chunk_iterator:
            if len(piece) != n:
                raise ValueError('iterable is not divisible by n.')
            yield piece

    return validated()
def first(iterable, default=_marker):
    """Return the first item of *iterable*, or *default* if *iterable* is
    empty.

    >>> first([0, 1, 2, 3])
    0
    >>> first([], 'some default')
    'some default'

    If *default* is not provided and there are no items, ``ValueError``
    is raised. Handy for grabbing one arbitrary value from a generator
    of expensive-to-retrieve items.
    """
    try:
        return next(iter(iterable))
    except StopIteration:
        if default is _marker:
            raise ValueError(
                'first() was called on an empty iterable, '
                'and no default value was provided.'
            ) from None
        return default
def last(iterable, default=_marker):
    """Return the last item of *iterable*, or *default* if *iterable* is
    empty.

    >>> last([0, 1, 2, 3])
    3
    >>> last([], 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    try:
        # Fast path: sequences support indexing from the end directly.
        if isinstance(iterable, Sequence):
            return iterable[-1]
        # Work around https://bugs.python.org/issue38525
        if getattr(iterable, '__reversed__', None):
            return next(reversed(iterable))
        # General path: drain the iterator, keeping only the final item.
        return deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        if default is not _marker:
            return default
        raise ValueError(
            'last() was called on an empty iterable, '
            'and no default value was provided.'
        )
def nth_or_last(iterable, n, default=_marker):
    """Return the nth or the last item of *iterable*,
    or *default* if *iterable* is empty.

    >>> nth_or_last([0, 1, 2, 3], 2)
    2
    >>> nth_or_last([0, 1], 2)
    1
    >>> nth_or_last([], 0, 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    # Take at most the first n + 1 items, then hand off to last().
    prefix = islice(iterable, n + 1)
    return last(prefix, default=default)
class peekable:
    """Wrap an iterator to allow lookahead and prepending elements.

    Call :meth:`peek` on the result to get the value that will be returned
    by :func:`next`. This won't advance the iterator:

    >>> p = peekable(['a', 'b'])
    >>> p.peek()
    'a'
    >>> next(p)
    'a'

    Pass :meth:`peek` a default value to return that instead of raising
    ``StopIteration`` when the iterator is exhausted.

    >>> p = peekable([])
    >>> p.peek('hi')
    'hi'

    peekables also offer a :meth:`prepend` method, which "inserts" items
    at the head of the iterable:

    >>> p = peekable([1, 2, 3])
    >>> p.prepend(10, 11, 12)
    >>> next(p)
    10
    >>> p.peek()
    11
    >>> list(p)
    [11, 12, 1, 2, 3]

    peekables can be indexed. Index 0 is the item that will be returned by
    :func:`next`, index 1 is the item after that, and so on:
    The values up to the given index will be cached.

    >>> p = peekable(['a', 'b', 'c', 'd'])
    >>> p[0]
    'a'
    >>> p[1]
    'b'
    >>> next(p)
    'a'

    Negative indexes are supported, but be aware that they will cache the
    remaining items in the source iterator, which may require significant
    storage.

    To check whether a peekable is exhausted, check its truth value:

    >>> p = peekable(['a', 'b'])
    >>> if p:  # peekable has items
    ...     list(p)
    ['a', 'b']
    >>> if not p:  # peekable is exhausted
    ...     list(p)
    []

    """

    def __init__(self, iterable):
        self._it = iter(iterable)
        # Items already pulled from the source (via peek, prepend, or
        # indexing) waiting to be emitted, in emission order.
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        # Truthy iff at least one more item can be produced.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        """Return the item that will be next returned from ``next()``.

        Return ``default`` if there are no items left. If ``default`` is not
        provided, raise ``StopIteration``.

        """
        # Pull one item into the cache (if needed) and look at it without
        # removing it.
        if not self._cache:
            try:
                self._cache.append(next(self._it))
            except StopIteration:
                if default is _marker:
                    raise
                return default
        return self._cache[0]

    def prepend(self, *items):
        """Stack up items to be the next ones returned from ``next()`` or
        ``self.peek()``. The items will be returned in
        first in, first out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        It is possible, by prepending items, to "resurrect" a peekable that
        previously raised ``StopIteration``.

            >>> p = peekable([])
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration
            >>> p.prepend(1)
            >>> next(p)
            1
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration

        """
        # extendleft reverses its argument, so reverse first to keep
        # first-in, first-out order.
        self._cache.extendleft(reversed(items))

    def __next__(self):
        # Cached (peeked/prepended) items take priority over the source.
        if self._cache:
            return self._cache.popleft()

        return next(self._it)

    def _get_slice(self, index):
        # Normalize the slice's arguments
        step = 1 if (index.step is None) else index.step
        if step > 0:
            start = 0 if (index.start is None) else index.start
            stop = maxsize if (index.stop is None) else index.stop
        elif step < 0:
            start = -1 if (index.start is None) else index.start
            stop = (-maxsize - 1) if (index.stop is None) else index.stop
        else:
            raise ValueError('slice step cannot be zero')

        # If either the start or stop index is negative, we'll need to cache
        # the rest of the iterable in order to slice from the right side.
        if (start < 0) or (stop < 0):
            self._cache.extend(self._it)
        # Otherwise we'll need to find the rightmost index and cache to that
        # point.
        else:
            n = min(max(start, stop) + 1, maxsize)
            cache_len = len(self._cache)
            if n >= cache_len:
                self._cache.extend(islice(self._it, n - cache_len))

        # Slicing the cached list gives ordinary list-slice semantics.
        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        cache_len = len(self._cache)
        # A negative index counts from the end, so everything must be
        # cached; otherwise cache only up to the requested position.
        if index < 0:
            self._cache.extend(self._it)
        elif index >= cache_len:
            self._cache.extend(islice(self._it, index + 1 - cache_len))

        return self._cache[index]
def consumer(func):
    """Decorator that automatically advances a PEP-342-style "reverse
    iterator" to its first yield point, so callers can use ``send()``
    immediately instead of priming the generator with ``next()`` first.

    >>> @consumer
    ... def tally():
    ...     i = 0
    ...     while True:
    ...         print('Thing number %s is %s.' % (i, (yield)))
    ...         i += 1
    ...
    >>> t = tally()
    >>> t.send('red')
    Thing number 0 is red.
    >>> t.send('fish')
    Thing number 1 is fish.
    """

    @wraps(func)
    def primed(*args, **kwargs):
        # Create the generator, advance to the first yield, hand it back.
        coroutine = func(*args, **kwargs)
        next(coroutine)
        return coroutine

    return primed
def ilen(iterable):
    """Return the number of items in *iterable*.

    >>> ilen(x for x in range(1000) if x % 3 == 0)
    334

    Equivalent to ``sum(1 for _ in iterable)``, but faster.

    This fully consumes the iterable, so handle with care.
    """
    # zip() wraps every item in a 1-tuple, each with len() == 1; sum()
    # adds those at C speed without retaining the items themselves.
    return sum(map(len, zip(iterable)))
def iterate(func, start):
    """Return ``start``, ``func(start)``, ``func(func(start))``, ...

    Produces an infinite iterator; bound it with :func:`take`,
    ``takewhile``, or :func:`takewhile_inclusive`:

    >>> take(10, iterate(lambda x: 2*x, 1))
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    Iteration stops cleanly if *func* raises ``StopIteration``.
    """
    current = start
    while True:
        yield current
        try:
            current = func(current)
        except StopIteration:
            # A StopIteration from func ends the stream gracefully.
            return
def with_iter(context_manager):
    """Enter *context_manager*, yield every item of the iterable it
    produces, and exit the context once exhausted.

    For example, this closes the file when the iterator runs out::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Any context manager that returns an iterable works here.
    """
    with context_manager as iterable:
        for item in iterable:
            yield item
def one(iterable, too_short=None, too_long=None):
    """Return the sole item of *iterable*, raising if it is empty or has
    more than one item.

    Useful for, e.g., a database query expected to return a single row.

    >>> one(['item'])
    'item'

    An empty input raises ``ValueError``, or *too_short* if given:

    >>> one([], too_short=IndexError)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    IndexError

    More than one item raises ``ValueError``, or *too_long* if given:

    >>> one(['too', 'many'])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 'too',
    'many', and perhaps more.

    Note that :func:`one` attempts to advance *iterable* twice; see
    :func:`spy` or :func:`peekable` for less destructive checks.
    """
    iterator = iter(iterable)

    try:
        value = next(iterator)
    except StopIteration:
        raise (
            too_long if False else
            too_short or ValueError('too few items in iterable (expected 1)')
        ) from None

    try:
        extra = next(iterator)
    except StopIteration:
        return value

    msg = (
        f'Expected exactly one item in iterable, but got {value!r}, '
        f'{extra!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
def raise_(exception, *args):
    """Raise *exception*, constructed with *args*.

    Exists so raising can happen where a statement is not allowed,
    e.g. inside a ``lambda`` (see :func:`strictly_n`'s defaults).
    """
    raise exception(*args)
def strictly_n(iterable, n, too_short=None, too_long=None):
    """Validate that *iterable* has exactly *n* items and yield them.

    >>> list(strictly_n(['a', 'b', 'c', 'd'], 4))
    ['a', 'b', 'c', 'd']

    The returned iterable must be consumed for the check to run.

    If fewer than *n* items are present, *too_short* is called with the
    number of items actually yielded; if more are present, *too_long* is
    called with ``n + 1``. By default both raise ``ValueError``:

    >>> list(strictly_n('ab', 3))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: too few items in iterable (got 2)

    Supply your own callables to change that behavior:

    >>> def too_long(item_count):
    ...     print('The boss is going to hear about this')
    >>> list(strictly_n('abcdef', 4, too_long=too_long))
    The boss is going to hear about this
    ['a', 'b', 'c', 'd']
    """
    if too_short is None:
        too_short = lambda item_count: raise_(
            ValueError,
            f'Too few items in iterable (got {item_count})',
        )

    if too_long is None:
        too_long = lambda item_count: raise_(
            ValueError,
            f'Too many items in iterable (got at least {item_count})',
        )

    iterator = iter(iterable)
    yielded = 0

    # Emit up to n items, counting as we go.
    while yielded < n:
        try:
            item = next(iterator)
        except StopIteration:
            # Ran dry early: report how many items actually came through.
            too_short(yielded)
            return
        yield item
        yielded += 1

    # Probe for one extra item; its existence means the input is too long.
    try:
        next(iterator)
    except StopIteration:
        pass
    else:
        too_long(n + 1)
def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

    >>> sorted(distinct_permutations([1, 0, 1]))
    [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

    >>> sorted(distinct_permutations([1, 0, 1], r=2))
    [(0, 1), (1, 0), (1, 1)]
    >>> sorted(distinct_permutations(range(3), r=2))
    [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

    >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
    [
        (1, True, '3'),
        (1, '3', True),
        ('3', 1, True)
    ]
    >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
    [
        (1, 2, '3'),
        (1, '3', 2),
        (2, 1, '3'),
        (2, '3', 1),
        ('3', 1, 2),
        ('3', 2, 1)
    ]
    """

    # Algorithm: https://w.wiki/Qai
    # (lexicographic next-permutation over a sorted list)
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than i such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above to emit only the first r items
    def _partial(A, r):
        # Split A into the first r items and the last r items
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    # Sorting is required by the algorithm; fall back to index-based
    # equivalence classes when the items cannot be compared.
    try:
        items.sort()
        sortable = True
    except TypeError:
        sortable = False

        indices_dict = defaultdict(list)

        for item in items:
            indices_dict[items.index(item)].append(item)

        indices = [items.index(item) for item in items]
        indices.sort()

        equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

        def permuted_items(permuted_indices):
            # Map a permutation of indices back to the original items,
            # cycling through equal-by-value items in order.
            return tuple(
                next(equivalent_items[index]) for index in permuted_indices
            )

    size = len(items)
    if r is None:
        r = size

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if 0 < r <= size:
        if sortable:
            return algorithm(items)
        else:
            return (
                permuted_items(permuted_indices)
                for permuted_indices in algorithm(indices)
            )

    # r == 0 yields the single empty permutation; r out of range yields
    # nothing.
    return iter(() if r else ((),))
def derangements(iterable, r=None):
    """Yield successive derangements of the elements in *iterable*.

    A derangement is a permutation with no fixed points: no element
    appears at its original index.

    >>> sorted(derangements(range(3)))
    [(1, 2, 0), (2, 0, 1)]

    If *r* is given, only the *r*-length derangements are yielded.

    >>> sorted(derangements(range(3), 2))
    [(1, 0), (1, 2), (2, 0)]
    >>> sorted(derangements([0, 2, 3], 2))
    [(2, 0), (2, 3), (3, 0)]

    Elements are treated as unique based on their position, not on their
    value, so unique inputs give permutations with no repeated values.

    The number of derangements of a set of size *n* is the subfactorial
    of *n*; for n > 0 it equals ``round(math.factorial(n) / math.e)``.

    References:

    * Article: https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics
    * Sizes: https://oeis.org/A000166
    """
    pool = tuple(iterable)
    positions = tuple(range(len(pool)))
    # Walk the permutations of the items and of their positions in
    # lockstep; keep a permutation only when no position object sits at
    # its original spot (identity check = "unique based on position").
    return (
        item_perm
        for item_perm, index_perm in zip(
            permutations(pool, r=r), permutations(positions, r=r)
        )
        if all(i is not j for i, j in zip(index_perm, positions))
    )
def intersperse(e, iterable, n=1):
    """Introduce filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

    >>> list(intersperse('!', [1, 2, 3, 4, 5]))
    [1, '!', 2, '!', 3, '!', 4, '!', 5]

    >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
    [1, 2, None, 3, 4, None, 5]

    """
    if n == 0:
        raise ValueError('n must be > 0')
    if n == 1:
        # e, x_0, e, x_1, ... with the leading filler sliced away.
        return islice(interleave(repeat(e), iterable), 1, None)
    # Group the items into n-sized chunks, weave [e] between the chunks
    # (dropping the leading filler), then flatten back to single items.
    filler = repeat([e])
    chunks = chunked(iterable, n)
    return flatten(islice(interleave(filler, chunks), 1, None))
def unique_to_each(*iterables):
    """Return, for each input iterable, the elements that appear in none
    of the other input iterables.

    For example, given packages and their dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    the dependencies removable along with each package are:

    >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
    [['A'], ['C'], ['D']]

    Duplicates within one input are kept, and input order is preserved:

    >>> unique_to_each("mississippi", "missouri")
    [['p', 'p'], ['o', 'u', 'r']]

    Elements of each iterable must be hashable.
    """
    materialized = [list(it) for it in iterables]
    # Count how many distinct inputs each element occurs in.
    membership = Counter(
        chain.from_iterable(set(group) for group in materialized)
    )
    singletons = {elem for elem, cnt in membership.items() if cnt == 1}
    return [
        [item for item in group if item in singletons]
        for group in materialized
    ]
def windowed(seq, n, fillvalue=None, step=1):
    """Return a sliding window of width *n* over the given iterable.

    >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
    >>> list(all_windows)
    [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    When the window is larger than the iterable, *fillvalue* is used in place
    of missing values:

    >>> list(windowed([1, 2, 3], 4))
    [(1, 2, 3, None)]

    Each window will advance in increments of *step*:

    >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
    [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler items
    to the left:

    >>> iterable = [1, 2, 3, 4]
    >>> n = 3
    >>> padding = [None] * (n - 1)
    >>> list(windowed(chain(padding, iterable), 3))
    [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    if n < 0:
        raise ValueError('n must be >= 0')
    if n == 0:
        yield ()
        return
    if step < 1:
        raise ValueError('step must be >= 1')

    iterator = iter(seq)

    # Generate first window
    window = deque(islice(iterator, n), maxlen=n)

    # Deal with the first window not being full
    if not window:
        return
    if len(window) < n:
        yield tuple(window) + ((fillvalue,) * (n - len(window)))
        return
    yield tuple(window)

    # Create the filler for the next windows. The padding ensures
    # we have just enough elements to fill the last window.
    padding = (fillvalue,) * (n - 1 if step >= n else step - 1)
    # Each element consumed from `filler` appends one source item into the
    # deque (maxlen=n evicts from the left), so the window slides as a
    # side effect of iterating `filler`.
    filler = map(window.append, chain(iterator, padding))

    # Generate the rest of the windows: consuming `step` filler elements
    # between yields advances the window by `step` positions.
    for _ in islice(filler, step - 1, None, step):
        yield tuple(window)
def substrings(iterable):
    """Yield all of the substrings of *iterable*.

    >>> [''.join(s) for s in substrings('more')]
    ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Works on any iterable, not only strings:

    >>> list(substrings([0, 1, 2]))
    [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    """
    # Emit each length-1 substring lazily while collecting the items.
    collected = []
    for element in iterable:
        collected.append(element)
        yield (element,)
    collected = tuple(collected)
    total = len(collected)

    # Then every longer run, shortest lengths first.
    for length in range(2, total + 1):
        for start in range(total - length + 1):
            yield collected[start : start + length]
def substrings_indexes(seq, reverse=False):
    """Yield all substrings of *seq* along with their positions.

    Each item is a tuple ``(substr, i, j)`` such that ``substr == seq[i:j]``.
    *seq* must support slicing (e.g. ``str``).

    >>> for item in substrings_indexes('more'):
    ...     print(item)
    ('m', 0, 1)
    ('o', 1, 2)
    ('r', 2, 3)
    ('e', 3, 4)
    ('mo', 0, 2)
    ('or', 1, 3)
    ('re', 2, 4)
    ('mor', 0, 3)
    ('ore', 1, 4)
    ('more', 0, 4)

    Set *reverse* to ``True`` to yield the same items in the opposite order.
    """
    lengths = range(1, len(seq) + 1)
    if reverse:
        lengths = reversed(lengths)
    # Shortest substrings first (or longest first when reversed), each
    # paired with its start and end offsets.
    return (
        (seq[start : start + size], start, start + size)
        for size in lengths
        for start in range(len(seq) - size + 1)
    )
class bucket:
    """Wrap *iterable* and return an object that buckets the iterable into
    child iterables based on a *key* function.

    >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
    >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
    >>> sorted(list(s))  # Get the keys
    ['a', 'b', 'c']
    >>> a_iterable = s['a']
    >>> next(a_iterable)
    'a1'
    >>> next(a_iterable)
    'a2'
    >>> list(s['b'])
    ['b1', 'b2', 'b3']

    The original iterable will be advanced and its items will be cached until
    they are used by the child iterables. This may require significant storage.

    By default, attempting to select a bucket to which no items belong will
    exhaust the iterable and cache all values.
    If you specify a *validator* function, selected buckets will instead be
    checked against it.

    >>> from itertools import count
    >>> it = count(1, 2)  # Infinite sequence of odd numbers
    >>> key = lambda x: x % 10  # Bucket by last digit
    >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
    >>> s = bucket(it, key=key, validator=validator)
    >>> 2 in s
    False
    >>> list(s[2])
    []

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        # Per-key queues of items seen but not yet consumed by a child.
        self._cache = defaultdict(deque)
        # Without a validator, every key is considered possible.
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        if not self._validator(value):
            return False

        # Probe the bucket for one item; push it back if one exists so
        # membership testing doesn't consume it.
        try:
            item = next(self[value])
        except StopIteration:
            return False
        else:
            self._cache[value].appendleft(item)

        return True

    def _get_values(self, value):
        """
        Helper to yield items from the parent iterator that match *value*.
        Items that don't match are stored in the local cache as they
        are encountered.
        """
        while True:
            # If we've cached some items that match the target value, emit
            # the first one and evict it from the cache.
            if self._cache[value]:
                yield self._cache[value].popleft()
            # Otherwise we need to advance the parent iterator to search for
            # a matching item, caching the rest.
            else:
                while True:
                    try:
                        item = next(self._it)
                    except StopIteration:
                        return
                    item_value = self._key(item)
                    if item_value == value:
                        yield item
                        break
                    elif self._validator(item_value):
                        self._cache[item_value].append(item)

    def __iter__(self):
        # Iterating the bucket yields the set of keys; this drains the
        # source entirely, caching every (valid) item for later retrieval.
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        # Invalid keys yield an empty iterator without touching the source.
        if not self._validator(value):
            return iter(())

        return self._get_values(value)
def spy(iterable, n=1):
    """Return a 2-tuple of (list of the first *n* items, iterator over all
    items of *iterable*), allowing a "look ahead" without advancing the
    original iterable.

    One item is previewed by default:

        >>> iterable = 'abcdefg'
        >>> head, iterable = spy(iterable)
        >>> head
        ['a']
        >>> list(iterable)
        ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    Unpacking works too:

        >>> (first, second), iterable = spy('abcdefg', 2)
        >>> first
        'a'
        >>> second
        'b'

    Asking for more items than exist is fine; the preview list is simply
    shorter:

        >>> head, iterable = spy([1, 2, 3, 4, 5], 10)
        >>> head
        [1, 2, 3, 4, 5]
        >>> list(iterable)
        [1, 2, 3, 4, 5]

    """
    # tee gives two independent branches; consuming the preview from one
    # leaves the other positioned at the very start.
    full, preview = tee(iterable)
    return list(islice(preview, n)), full
def interleave(*iterables):
    """Yield items from each of the *iterables* in turn, stopping as soon
    as any one of them is exhausted.

    >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7]

    See :func:`interleave_longest` for a variant that keeps going until
    every iterable is exhausted.
    """
    # zip truncates at the shortest input; flattening its tuples gives
    # the round-robin order.
    return chain.from_iterable(zip(*iterables))
def interleave_longest(*iterables):
    """Yield items from each of the *iterables* in turn, skipping any that
    have run out.

    >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7, 3, 8]

    Output matches :func:`roundrobin`, but this may be faster when the
    number of iterables is large.
    """
    # A unique sentinel marks the gaps left by exhausted iterables, so
    # legitimate values (including None) pass through untouched.
    exhausted = object()
    for group in zip_longest(*iterables, fillvalue=exhausted):
        yield from (item for item in group if item is not exhausted)
def interleave_evenly(iterables, lengths=None):
    """
    Interleave multiple iterables so that their elements are evenly distributed
    throughout the output sequence.

    >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
    >>> list(interleave_evenly(iterables))
    [1, 2, 'a', 3, 4, 'b', 5]

    >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
    >>> list(interleave_evenly(iterables))
    [1, 6, 4, 2, 7, 3, 8, 5]

    This function requires iterables of known length. Iterables without
    ``__len__()`` can be used by manually specifying lengths with *lengths*:

    >>> from itertools import combinations, repeat
    >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
    >>> lengths = [4 * (4 - 1) // 2, 3]
    >>> list(interleave_evenly(iterables, lengths=lengths))
    [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    Based on Bresenham's algorithm.

    :raises ValueError: if lengths cannot be determined automatically and
        *lengths* is not supplied, or if *lengths* has a different number
        of entries than *iterables*.
    """
    if lengths is None:
        try:
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )
    elif len(iterables) != len(lengths):
        raise ValueError('Mismatching number of iterables and lengths.')

    dims = len(lengths)

    # sort iterables by length, descending
    lengths_permute = sorted(
        range(dims), key=lambda i: lengths[i], reverse=True
    )
    lengths_desc = [lengths[i] for i in lengths_permute]
    iters_desc = [iter(iterables[i]) for i in lengths_permute]

    # the longest iterable is the primary one (Bresenham: the longest
    # distance along an axis)
    delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
    iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
    # Initial error offsets center each secondary stream within the output.
    errors = [delta_primary // dims] * len(deltas_secondary)

    # Total number of items still to be emitted; the loop ends exactly
    # when every input has been consumed.
    to_yield = sum(lengths)
    while to_yield:
        yield next(iter_primary)
        to_yield -= 1
        # update errors for each secondary iterable
        errors = [e - delta for e, delta in zip(errors, deltas_secondary)]

        # those iterables for which the error is negative are yielded
        # ("diagonal step" in Bresenham)
        for i, e_ in enumerate(errors):
            if e_ < 0:
                yield next(iters_secondary[i])
                to_yield -= 1
                errors[i] += delta_primary
def collapse(iterable, base_type=None, levels=None):
    """Flatten an iterable with multiple levels of nesting (e.g., a list of
    lists of tuples) into non-iterable types.

    >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
    >>> list(collapse(iterable))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and
    will not be collapsed.

    To avoid collapsing other types, specify *base_type*:

    >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
    >>> list(collapse(iterable, base_type=tuple))
    ['ab', ('cd', 'ef'), 'gh', 'ij']

    Specify *levels* to stop flattening after a certain level:

    >>> iterable = [('a', ['b']), ('c', ['d'])]
    >>> list(collapse(iterable)) # Fully flattened
    ['a', 'b', 'c', 'd']
    >>> list(collapse(iterable, levels=1)) # Only one level flattened
    ['a', ['b'], 'c', ['d']]

    """
    # Iterative depth-first traversal: the stack holds (depth, iterator)
    # pairs, avoiding recursion-limit issues on deeply nested input.
    stack = deque()
    # Add our first node group, treat the iterable as a single node
    stack.appendleft((0, repeat(iterable, 1)))

    while stack:
        node_group = stack.popleft()
        level, nodes = node_group

        # Check if beyond max level
        if levels is not None and level > levels:
            yield from nodes
            continue

        for node in nodes:
            # Check if done iterating
            if isinstance(node, (str, bytes)) or (
                (base_type is not None) and isinstance(node, base_type)
            ):
                yield node
            # Otherwise try to create child nodes
            else:
                try:
                    tree = iter(node)
                except TypeError:
                    yield node
                else:
                    # Save our current location
                    # NOTE: *nodes* is an iterator, so when this group is
                    # popped again it resumes right after this node.
                    stack.appendleft(node_group)
                    # Append the new child node
                    stack.appendleft((level + 1, tree))
                    # Break to process child node
                    break
def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Invoke *func* on each item of *iterable* (or, when *chunk_size* is
    given, on each list of that many items) before yielding the item.

    *func* takes a single argument; its return value is ignored.

    *before* and *after* are optional zero-argument callables run before
    iteration starts and after it finishes (even on error), respectively.

    Useful for logging, progress reporting, and other impure work:

    >>> from more_itertools import consume
    >>> func = lambda item: print('Received {}'.format(item))
    >>> consume(side_effect(func, range(2)))
    Received 0
    Received 1

    Operating on chunks:

    >>> pair_sums = []
    >>> func = lambda chunk: pair_sums.append(sum(chunk))
    >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
    [0, 1, 2, 3, 4, 5]
    >>> list(pair_sums)
    [1, 5, 9]

    Writing to a file-like object:

    >>> from io import StringIO
    >>> from more_itertools import consume
    >>> f = StringIO()
    >>> func = lambda x: print(x, file=f)
    >>> before = lambda: print(u'HEADER', file=f)
    >>> after = f.close
    >>> it = [u'a', u'b', u'c']
    >>> consume(side_effect(func, it, before=before, after=after))
    >>> f.closed
    True

    """
    try:
        if before is not None:
            before()

        if chunk_size is not None:
            for group in chunked(iterable, chunk_size):
                func(group)
                yield from group
        else:
            for element in iterable:
                func(element)
                yield element
    finally:
        # Run cleanup regardless of how iteration ended.
        if after is not None:
            after()
def sliced(seq, n, strict=False):
    """Yield slices of length *n* from the sequence *seq*.

    >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
    [(1, 2, 3), (4, 5, 6)]

    When ``len(seq)`` is not a multiple of *n*, the final slice is shorter:

    >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
    [(1, 2, 3), (4, 5, 6), (7, 8)]

    With *strict* set to ``True``, ``ValueError`` is raised instead of
    yielding that short final slice.

    Only works for sliceable iterables; use :func:`chunked` otherwise.
    """
    # Empty slices (falsy) signal that we have walked off the end.
    slices = takewhile(len, (seq[start : start + n] for start in count(0, n)))
    if not strict:
        return slices

    def checked():
        for piece in slices:
            if len(piece) != n:
                raise ValueError("seq is not divisible by n.")
            yield piece

    return checked()
def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Yield lists of items from *iterable*, delimited by items for which
    *pred* returns ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b'))
    [['a'], ['c', 'd', 'c'], ['a']]

    >>> list(split_at(range(10), lambda n: n % 2 == 1))
    [[0], [2], [4], [6], [8], []]

    At most *maxsplit* splits are performed; ``-1`` (the default) means
    unlimited:

    >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
    [[0], [2], [4, 5, 6, 7, 8, 9]]

    Separators are dropped unless *keep_separator* is true, in which case
    each appears as its own single-item list:

    >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
    [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    group = []
    for item in iterator:
        if not pred(item):
            group.append(item)
            continue
        # Hit a separator: emit the accumulated group.
        yield group
        if keep_separator:
            yield [item]
        if maxsplit == 1:
            # Last permitted split; everything else is one final group.
            yield list(iterator)
            return
        group = []
        maxsplit -= 1
    yield group
def split_before(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*; each list ends just before an
    item for which *pred* returns ``True``:

    >>> list(split_before('OneTwo', lambda s: s.isupper()))
    [['O', 'n', 'e'], ['T', 'w', 'o']]

    >>> list(split_before(range(10), lambda n: n % 3 == 0))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    At most *maxsplit* splits are performed; ``-1`` (the default) means
    unlimited:

    >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    group = []
    for item in iterator:
        # Only split when there is something accumulated; a match on the
        # very first item does not create an empty leading group.
        if group and pred(item):
            yield group
            if maxsplit == 1:
                yield [item, *iterator]
                return
            group = []
            maxsplit -= 1
        group.append(item)
    if group:
        yield group
def split_after(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*; each list ends with an item
    for which *pred* returns ``True``:

    >>> list(split_after('one1two2', lambda s: s.isdigit()))
    [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

    >>> list(split_after(range(10), lambda n: n % 3 == 0))
    [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    At most *maxsplit* splits are performed; ``-1`` (the default) means
    unlimited:

    >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    group = []
    for item in iterator:
        group.append(item)
        if pred(item):
            yield group
            if maxsplit == 1:
                # Final permitted split: emit the rest as one group,
                # unless nothing remains.
                remainder = list(iterator)
                if remainder:
                    yield remainder
                return
            group = []
            maxsplit -= 1
    if group:
        yield group
def split_when(iterable, pred, maxsplit=-1):
    """Split *iterable* between successive pairs of items for which *pred*
    returns ``True``.

    *pred* receives each adjacent pair; a truthy result splits the
    iterable between them. For example, to isolate runs of increasing
    numbers:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
    [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    At most *maxsplit* splits are performed; ``-1`` (the default) means
    unlimited:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
    ...                 lambda x, y: x > y, maxsplit=2))
    [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    try:
        previous = next(iterator)
    except StopIteration:
        return

    group = [previous]
    for current in iterator:
        if pred(previous, current):
            yield group
            if maxsplit == 1:
                yield [current, *iterator]
                return
            group = []
            maxsplit -= 1
        group.append(current)
        previous = current

    yield group
def split_into(iterable, sizes):
    """Yield lists of consecutive items from *iterable*, one list of length
    ``n`` for each integer ``n`` in *sizes*.

    >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
    [[1], [2, 3], [4, 5, 6]]

    Items beyond ``sum(sizes)`` are not returned:

    >>> list(split_into([1,2,3,4,5,6], [2,3]))
    [[1, 2], [3, 4, 5]]

    If *sizes* asks for more than is available, the overrunning list is
    short and later lists are empty:

    >>> list(split_into([1,2,3,4], [1,2,3,4]))
    [[1], [2, 3], [4], []]

    A ``None`` in *sizes* collects everything that remains, as
    :func:`itertools.islice` does:

    >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
    [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    Handy for grouping rows whose columns encode features of differing
    widths (e.g. an x,y,z point next to a scalar).
    """
    # Use a single iterator so each islice picks up where the last ended.
    iterator = iter(iterable)
    for size in sizes:
        if size is None:
            yield list(iterator)
            return
        yield list(islice(iterator, size))
def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the items of *iterable* followed by *fillvalue* until at least
    *n* items have been emitted.

    >>> list(padded([1, 2, 3], '?', 5))
    [1, 2, 3, '?', '?']

    With *next_multiple* set, padding continues until the count of emitted
    items is a multiple of *n*:

    >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
    [1, 2, 3, 4, None, None]

    With *n* unset, *fillvalue* is emitted forever. Combine with
    :func:`islice` to get exactly *n* items:

    >>> list(islice(padded([1, 2, 3], '?'), 5))
    [1, 2, 3, '?', '?']
    >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
    [1, 2, 3, 4, 5]

    :raises ValueError: if *n* is given and less than 1.
    """
    iterator = iter(iterable)
    padded_iterator = chain(iterator, repeat(fillvalue))

    if n is None:
        return padded_iterator
    if n < 1:
        raise ValueError('n must be at least 1')

    if not next_multiple:
        # Take at least n items (padding as needed), then whatever remains.
        return chain(islice(padded_iterator, n), iterator)

    def batches():
        # Each remaining real item starts a batch that is padded to size n.
        for first in iterator:
            yield (first,)
            yield islice(padded_iterator, n - 1)

    return chain.from_iterable(batches())
def repeat_each(iterable, n=2):
    """Yield each element of *iterable* *n* times in a row.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    # map pairs each element with n, producing one repeat() per element.
    return chain.from_iterable(map(repeat, iterable, repeat(n)))
def repeat_last(iterable, default=None):
    """Yield the items of *iterable*, then repeat its final item forever.

    >>> list(islice(repeat_last(range(3)), 5))
    [0, 1, 2, 2, 2]

    An empty iterable yields *default* forever::

    >>> list(islice(repeat_last(range(0), 42), 5))
    [42, 42, 42, 42, 42]

    """
    # A unique sentinel distinguishes "no items at all" from a last item
    # that happens to equal default.
    sentinel = object()
    last = sentinel
    for last in iterable:
        yield last
    yield from repeat(default if last is sentinel else last)
def distribute(n, iterable):
    """Deal the items of *iterable* out among *n* smaller iterables,
    round-robin style.

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]

    When the length is not evenly divisible by *n*, the groups differ in
    size:

    >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 4, 7], [2, 5], [3, 6]]

    With fewer items than *n*, trailing groups are empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    Uses :func:`itertools.tee`, so storage may be significant. See
    :func:`divide` if original ordering within groups matters.

    :raises ValueError: if *n* is less than 1.
    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Each tee branch sees the whole input; striding by n with a distinct
    # starting offset selects every n-th item for that group.
    groups = []
    for offset, branch in enumerate(tee(iterable, n)):
        groups.append(islice(branch, offset, None, n))
    return groups
def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples of items from *iterable*, where the ``i``-th element of
    each tuple is shifted by the ``i``-th entry of *offsets*.

    >>> list(stagger([0, 1, 2, 3]))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
    >>> list(stagger(range(8), offsets=(0, 2, 4)))
    [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default iteration stops when a tuple's last element is the final
    item. Set *longest* to ``True`` to continue until the tuple's first
    element is the final item::

    >>> list(stagger([0, 1, 2, 3], longest=True))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    Positions past either end of the sequence are filled with *fillvalue*
    (``None`` by default).
    """
    # One independent copy of the input per requested offset.
    copies = tee(iterable, len(offsets))
    return zip_offset(
        *copies, offsets=offsets, longest=longest, fillvalue=fillvalue
    )
def zip_equal(*iterables):
    """``zip`` the *iterables* together, raising ``UnequalIterablesError``
    if they are not all the same length.

    >>> it_1 = range(3)
    >>> it_2 = iter('abc')
    >>> list(zip_equal(it_1, it_2))
    [(0, 'a'), (1, 'b'), (2, 'c')]

    >>> it_1 = range(3)
    >>> it_2 = iter('abcd')
    >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    more_itertools.more.UnequalIterablesError: Iterables have different
    lengths

    """
    # Deprecated on 3.10+, where zip(..., strict=True) is built in.
    if hexversion >= 0x30A00A6:
        warnings.warn(
            'zip_equal will be removed in a future version of '
            'more-itertools. Use the builtin zip function with '
            'strict=True instead.',
            DeprecationWarning,
        )

    return _zip_equal(*iterables)
def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the *iterables* together, shifting the ``i``-th one by the
    ``i``-th entry of *offsets*.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    A lightweight way to analyze series with a lead/lag relationship.

    Iteration stops with the shortest iterable unless *longest* is set:

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    Positions shifted past either end are filled with *fillvalue*
    (``None`` by default).

    :raises ValueError: if the counts of iterables and offsets differ.
    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    shifted = []
    for source, offset in zip(iterables, offsets):
        if offset > 0:
            # Positive offset: skip the first `offset` items.
            shifted.append(islice(source, offset, None))
        elif offset < 0:
            # Negative offset: prepend `-offset` fill values.
            shifted.append(chain(repeat(fillvalue, -offset), source))
        else:
            shifted.append(source)

    if longest:
        return zip_longest(*shifted, fillvalue=fillvalue)
    return zip(*shifted)
def sort_together(
    iterables, key_list=(0,), key=None, reverse=False, strict=False
):
    """Return the input *iterables* sorted together, using *key_list* as
    the sorting priority. All iterables are trimmed to the length of the
    shortest one.

    Think of it as spreadsheet sorting: each iterable is a column, and
    *key_list* names the columns used to order the rows.

    By default the ``0``-th iterable drives the sort::

        >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
        >>> sort_together(iterables)
        [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]

    Multiple keys break ties in order::

        >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
        >>> sort_together(iterables, key_list=(1, 2))
        [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]

    A *key* function receives the elements selected by *key_list*::

        >>> names = ('a', 'b', 'c')
        >>> lengths = (1, 2, 3)
        >>> widths = (5, 2, 1)
        >>> def area(length, width):
        ...     return length * width
        >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
        [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]

    *reverse* sorts descending::

        >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
        [(3, 2, 1), ('a', 'b', 'c')]

    With *strict* set, ``UnequalIterablesError`` is raised when the
    iterables have different lengths.
    """
    if key is None:
        # No key function: sort rows directly on the selected columns.
        key_argument = itemgetter(*key_list)
    else:
        key_list = list(key_list)
        if len(key_list) == 1:
            # Single column: pass that one element to the key function.
            offset = key_list[0]

            def key_argument(row):
                return key(row[offset])

        else:
            # Several columns: unpack them as positional arguments.
            select = itemgetter(*key_list)

            def key_argument(row):
                return key(*select(row))

    zipper = zip_equal if strict else zip
    return list(
        zipper(*sorted(zipper(*iterables), key=key_argument, reverse=reverse))
    )
def unzip(iterable):
    """The inverse of :func:`zip`: split a zipped *iterable* back into its
    component streams.

    The ``i``-th returned iterable yields the ``i``-th element of each
    item. The first item determines how many streams are produced.

    >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> letters, numbers = unzip(iterable)
    >>> list(letters)
    ['a', 'b', 'c', 'd']
    >>> list(numbers)
    [1, 2, 3, 4]

    Unlike ``zip(*iterable)``, this does not read the whole input into
    memory up front — though :func:`itertools.tee` may still need
    significant storage.
    """
    # Peek at the first item to learn the width, then stitch it back on.
    iterator = iter(iterable)
    try:
        first = next(iterator)
    except StopIteration:
        # Empty input, e.g. zip([], [], [])
        return ()

    copies = tee(chain((first,), iterator), len(first))

    def positional_getter(index):
        def getter(element):
            try:
                return element[index]
            except IndexError:
                # Tolerate "improperly zipped" input such as
                # iter([(1, 2, 3), (4, 5), (6,)]): a stream simply ends
                # at the first element that is too short, so raising
                # StopIteration here terminates the enclosing map().
                raise StopIteration

        return getter

    return tuple(
        map(positional_getter(index), copy)
        for index, copy in enumerate(copies)
    )
def divide(n, iterable):
    """Split the items of *iterable* into *n* contiguous parts, keeping
    their order.

    >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 2, 3]
    >>> list(group_2)
    [4, 5, 6]

    When the length is not evenly divisible by *n*, the parts differ in
    size:

    >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 2, 3], [4, 5], [6, 7]]

    With fewer items than *n*, trailing parts are empty:

    >>> children = divide(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    The whole iterable is consumed before returning; see
    :func:`distribute` for a lazier alternative when order within groups
    does not matter.

    :raises ValueError: if *n* is less than 1.
    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Materialize only when the input does not already support slicing.
    try:
        iterable[:0]
    except TypeError:
        seq = tuple(iterable)
    else:
        seq = iterable

    # The first `remainder` parts get one extra item each.
    quotient, remainder = divmod(len(seq), n)
    sizes = [quotient + 1] * remainder + [quotient] * (n - remainder)

    parts = []
    position = 0
    for size in sizes:
        parts.append(iter(seq[position : position + size]))
        position += size
    return parts
def always_iterable(obj, base_type=(str, bytes)):
    """Return an iterator over *obj*'s items when it is iterable::

        >>> obj = (1, 2, 3)
        >>> list(always_iterable(obj))
        [1, 2, 3]

    Non-iterable objects are wrapped in a one-item iterator::

        >>> obj = 1
        >>> list(always_iterable(obj))
        [1]

    ``None`` becomes an empty iterator:

        >>> obj = None
        >>> list(always_iterable(None))
        []

    Text and binary strings are treated as single units by default::

        >>> obj = 'foo'
        >>> list(always_iterable(obj))
        ['foo']

    Any type matched by *base_type* is likewise treated as a unit:

        >>> obj = {'a': 1}
        >>> list(always_iterable(obj)) # Iterate over the dict's keys
        ['a']
        >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit
        [{'a': 1}]

    Pass ``base_type=None`` to disable the special-casing entirely:

        >>> obj = 'foo'
        >>> list(always_iterable(obj, base_type=None))
        ['f', 'o', 'o']
    """
    if obj is None:
        return iter(())

    if base_type is not None and isinstance(obj, base_type):
        return iter((obj,))

    # EAFP: anything iter() accepts is iterable; everything else is
    # wrapped as a single item.
    try:
        return iter(obj)
    except TypeError:
        return iter((obj,))
def adjacent(predicate, iterable, distance=1):
    """Return an iterable of ``(bool, item)`` pairs where the bool is
    ``True`` when the item satisfies *predicate* or sits within *distance*
    positions of one that does.

    To mark items next to a ``3``::

        >>> list(adjacent(lambda x: x == 3, range(6)))
        [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    Widen the neighborhood with *distance*:

        >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
        [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    Useful for adding context around search hits — e.g. showing the lines
    surrounding a change in a diff viewer.

    *predicate* is called exactly once per item.

    See :func:`groupby_transform` for grouping runs that share the same
    bool value.

    :raises ValueError: if *distance* is negative.
    """
    # Allow distance=0 mainly for testing that it reproduces results with map()
    if distance < 0:
        raise ValueError('distance must be at least 0')

    probes, originals = tee(iterable)
    # Pad the flag stream on both sides so the sliding window is defined
    # at the edges of the sequence.
    padding = [False] * distance
    flags = chain(padding, map(predicate, probes), padding)
    near_a_hit = map(any, windowed(flags, 2 * distance + 1))
    return zip(near_a_hit, originals)
def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
    """An extension of :func:`itertools.groupby` that can transform the
    grouped data.

    * *keyfunc* computes the grouping key for each item
    * *valuefunc* transforms individual items after grouping
    * *reducefunc* transforms each whole group

    >>> iterable = 'aAAbBBcCC'
    >>> keyfunc = lambda k: k.upper()
    >>> valuefunc = lambda v: v.lower()
    >>> reducefunc = lambda g: ''.join(g)
    >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
    [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]

    Any argument left unspecified acts as an identity function.

    A common use is grouping one iterable by another: zip them, then use
    a *keyfunc* that takes the first element and a *valuefunc* that takes
    the second::

        >>> from operator import itemgetter
        >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
        >>> values = 'abcdefghi'
        >>> iterable = zip(keys, values)
        >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
        >>> [(k, ''.join(g)) for k, g in grouper]
        [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]

    As with :func:`itertools.groupby`, only adjacent items are grouped;
    sort by the key function first if duplicate groups are unwanted.
    """
    groups = groupby(iterable, keyfunc)
    if valuefunc:
        groups = ((k, map(valuefunc, group)) for k, group in groups)
    if reducefunc:
        groups = ((k, reducefunc(group)) for k, group in groups)
    return groups
class numeric_range(abc.Sequence, abc.Hashable):
    """An extension of the built-in ``range()`` function whose arguments can
    be any orderable numeric type.

    With only *stop* specified, *start* defaults to ``0`` and *step*
    defaults to ``1``. The output items will match the type of *stop*:

    >>> list(numeric_range(3.5))
    [0.0, 1.0, 2.0, 3.0]

    With only *start* and *stop* specified, *step* defaults to ``1``. The
    output items will match the type of *start*:

    >>> from decimal import Decimal
    >>> start = Decimal('2.1')
    >>> stop = Decimal('5.1')
    >>> list(numeric_range(start, stop))
    [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    With *start*, *stop*, and *step* specified the output items will match
    the type of ``start + step``:

    >>> from fractions import Fraction
    >>> start = Fraction(1, 2)  # Start at 1/2
    >>> stop = Fraction(5, 2)  # End at 5/2
    >>> step = Fraction(1, 2)  # Count by 1/2
    >>> list(numeric_range(start, stop, step))
    [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    If *step* is zero, ``ValueError`` is raised. Negative steps are supported:

    >>> list(numeric_range(3, -1, -1.0))
    [3.0, 2.0, 1.0, 0.0]

    Be aware of the limitations of floating-point numbers; the representation
    of the yielded numbers may be surprising.

    ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
    is a ``datetime.timedelta`` object:

    >>> import datetime
    >>> start = datetime.datetime(2019, 1, 1)
    >>> stop = datetime.datetime(2019, 1, 3)
    >>> step = datetime.timedelta(days=1)
    >>> items = iter(numeric_range(start, stop, step))
    >>> next(items)
    datetime.datetime(2019, 1, 1, 0, 0)
    >>> next(items)
    datetime.datetime(2019, 1, 2, 0, 0)

    """

    # Shared hash value for every empty range, mirroring how built-in
    # empty ranges all hash equal.
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        # Mirror range()'s call signatures: (stop), (start, stop), or
        # (start, stop, step).  Defaults are constructed from the types of
        # the given arguments so that e.g. Decimal or datetime inputs get
        # type-consistent defaults.
        argc = len(args)
        if argc == 1:
            (self._stop,) = args
            self._start = type(self._stop)(0)
            self._step = type(self._stop - self._start)(1)
        elif argc == 2:
            self._start, self._stop = args
            self._step = type(self._stop - self._start)(1)
        elif argc == 3:
            self._start, self._stop, self._step = args
        elif argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        else:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        # A zero of the step's type, used for sign tests and modulo checks.
        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        # True when the range counts upward (positive step).
        self._growing = self._step > self._zero

    def __bool__(self):
        # Non-empty iff at least one item lies strictly before stop.
        if self._growing:
            return self._start < self._stop
        else:
            return self._start > self._stop

    def __contains__(self, elem):
        # Membership requires the element to be within the half-open bounds
        # and an exact whole number of steps away from start.
        if self._growing:
            if self._start <= elem < self._stop:
                return (elem - self._start) % self._step == self._zero
        else:
            if self._start >= elem > self._stop:
                return (self._start - elem) % (-self._step) == self._zero

        return False

    def __eq__(self, other):
        # Like built-in range: two ranges are equal when they yield the same
        # items, regardless of how the bounds were originally specified.
        if isinstance(other, numeric_range):
            empty_self = not bool(self)
            empty_other = not bool(other)
            if empty_self or empty_other:
                return empty_self and empty_other  # True if both empty
            else:
                return (
                    self._start == other._start
                    and self._step == other._step
                    and self._get_by_index(-1) == other._get_by_index(-1)
                )
        else:
            return False

    def __getitem__(self, key):
        # Integer indexing delegates to _get_by_index; slicing produces a
        # new numeric_range whose bounds are clamped to this range.
        if isinstance(key, int):
            return self._get_by_index(key)
        elif isinstance(key, slice):
            step = self._step if key.step is None else key.step * self._step

            if key.start is None or key.start <= -self._len:
                start = self._start
            elif key.start >= self._len:
                start = self._stop
            else:  # -self._len < key.start < self._len
                start = self._get_by_index(key.start)

            if key.stop is None or key.stop >= self._len:
                stop = self._stop
            elif key.stop <= -self._len:
                stop = self._start
            else:  # -self._len < key.stop < self._len
                stop = self._get_by_index(key.stop)

            return numeric_range(start, stop, step)
        else:
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )

    def __hash__(self):
        # Hash on (first, last, step) so ranges that compare equal hash
        # equal; all empty ranges share _EMPTY_HASH.
        if self:
            return hash((self._start, self._get_by_index(-1), self._step))
        else:
            return self._EMPTY_HASH

    def __iter__(self):
        # Generate start + n*step and cut off once stop is reached; the
        # comparison direction depends on the sign of the step.
        values = (self._start + (n * self._step) for n in count())
        if self._growing:
            return takewhile(partial(gt, self._stop), values)
        else:
            return takewhile(partial(lt, self._stop), values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        # Normalize to a forward-counting range, then use Euclidean
        # division; a nonzero remainder means one extra (partial) step fits.
        if self._growing:
            start = self._start
            stop = self._stop
            step = self._step
        else:
            start = self._stop
            stop = self._start
            step = -self._step
        distance = stop - start
        if distance <= self._zero:
            return 0
        else:  # distance > 0 and step > 0: regular euclidean division
            q, r = divmod(distance, step)
            return int(q) + int(r != self._zero)

    def __reduce__(self):
        # Support pickling by reconstructing from the three arguments.
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        if self._step == 1:
            return f"numeric_range({self._start!r}, {self._stop!r})"
        return (
            f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
        )

    def __reversed__(self):
        # A reversed range starts at the last item and steps backward past
        # the original start.
        return iter(
            numeric_range(
                self._get_by_index(-1), self._start - self._step, -self._step
            )
        )

    def count(self, value):
        # Items are unique, so the count is 0 or 1.
        return int(value in self)

    def index(self, value):
        # The index is the whole number of steps from start to value; a
        # nonzero remainder means the value is between grid points.
        if self._growing:
            if self._start <= value < self._stop:
                q, r = divmod(value - self._start, self._step)
                if r == self._zero:
                    return int(q)
        else:
            if self._start >= value > self._stop:
                q, r = divmod(self._start - value, -self._step)
                if r == self._zero:
                    return int(q)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        # Resolve negative indices relative to the end, then bounds-check.
        if i < 0:
            i += self._len
        if i < 0 or i >= self._len:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step
def count_cycle(iterable, n=None):
    """Cycle through the items from *iterable* up to *n* times, yielding
    the number of completed cycles along with each item. If *n* is omitted the
    process repeats indefinitely.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    """
    pool = tuple(iterable)
    if not pool:
        return iter(())
    if n is None:
        cycle_numbers = count()
    else:
        cycle_numbers = range(n)
    return (
        (cycle_no, item) for cycle_no in cycle_numbers for item in pool
    )
def mark_ends(iterable):
    """Yield 3-tuples of the form ``(is_first, is_last, item)``.

    >>> list(mark_ends('ABC'))
    [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    Use this when looping over an iterable to take special action on its first
    and/or last items:

    >>> iterable = ['Header', 100, 200, 'Footer']
    >>> total = 0
    >>> for is_first, is_last, item in mark_ends(iterable):
    ...     if is_first:
    ...         continue  # Skip the header
    ...     if is_last:
    ...         continue  # Skip the footer
    ...     total += item
    >>> print(total)
    300
    """
    it = iter(iterable)

    # Hold one item back so we can tell whether the current one is last.
    try:
        previous = next(it)
    except StopIteration:
        return

    is_first = True
    for current in it:
        yield is_first, False, previous
        previous = current
        is_first = False
    yield is_first, True, previous
def locate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
    [1, 2, 4]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item.

    >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
    [1, 3]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(locate(iterable, pred=pred, window_size=3))
    [1, 5, 9]

    Use with :func:`seekable` to find indexes and then retrieve the associated
    items:

    >>> from itertools import count
    >>> from more_itertools import seekable
    >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
    >>> it = seekable(source)
    >>> pred = lambda x: x > 100
    >>> indexes = locate(it, pred=pred)
    >>> i = next(indexes)
    >>> it.seek(i)
    >>> next(it)
    106

    """
    # Windowed search: evaluate the predicate on sliding windows instead of
    # single items.
    if window_size is not None:
        if window_size < 1:
            raise ValueError('window size must be at least 1')
        windows = windowed(iterable, window_size, fillvalue=_marker)
        return compress(count(), starmap(pred, windows))

    # Single-item search: keep the indexes whose items test true.
    return compress(count(), map(pred, iterable))
def longest_common_prefix(iterables):
    """Yield elements of the longest common prefix among given *iterables*.

    >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
    'ab'

    """
    # Walk the iterables in lockstep; keep yielding while every column of
    # items agrees, then stop at the first disagreement.
    matching_columns = takewhile(all_equal, zip(*iterables))
    return (column[0] for column in matching_columns)
def lstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the beginning
    for which *pred* returns ``True``.

    For example, to remove a set of items from the start of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(lstrip(iterable, pred))
    [1, 2, None, 3, False, None]

    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.

    """
    # dropwhile skips leading items while pred holds, then yields the rest
    # unconditionally -- exactly str.lstrip's semantics for iterables.
    return dropwhile(pred, iterable)
def rstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the end
    for which *pred* returns ``True``.

    For example, to remove a set of items from the end of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(rstrip(iterable, pred))
    [None, False, None, 1, 2, None, 3]

    This function is analogous to :func:`str.rstrip`.

    """
    # Buffer items that match the predicate; they are only emitted if a
    # non-matching item follows them.  A trailing run of matches is thus
    # discarded when the iterable ends.
    pending = []
    for item in iterable:
        if pred(item):
            pending.append(item)
        else:
            if pending:
                yield from pending
                pending = []
            yield item
def strip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the
    beginning and end for which *pred* returns ``True``.

    For example, to remove a set of items from both ends of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(strip(iterable, pred))
    [1, 2, None, 3]

    This function is analogous to :func:`str.strip`.

    """
    # Strip the front first, then strip the back of what remains.
    without_leading = lstrip(iterable, pred)
    return rstrip(without_leading, pred)
class islice_extended:
    """An extension of :func:`itertools.islice` that supports negative values
    for *stop*, *start*, and *step*.

    >>> iterator = iter('abcdefgh')
    >>> list(islice_extended(iterator, -4, -1))
    ['e', 'f', 'g']

    Slices with negative values require some caching of *iterable*, but this
    function takes care to minimize the amount of memory required.

    For example, you can use a negative step with an infinite iterator:

    >>> from itertools import count
    >>> list(islice_extended(count(), 110, 99, -2))
    [110, 108, 106, 104, 102, 100]

    You can also use slice notation directly:

    >>> iterator = map(str, count())
    >>> it = islice_extended(iterator)[10:20:2]
    >>> list(it)
    ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        # With slice arguments, wrap the source in the slicing helper;
        # without them, pass the source through untouched.
        source = iter(iterable)
        if not args:
            self._iterator = source
        else:
            self._iterator = _islice_helper(source, slice(*args))

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        # Only slice keys are meaningful; single-index access would require
        # consuming and possibly caching the iterator.
        if not isinstance(key, slice):
            raise TypeError(
                'islice_extended.__getitem__ argument must be a slice'
            )
        return islice_extended(_islice_helper(self._iterator, key))
def _islice_helper(it, s):
    """Yield items of iterator *it* selected by slice object *s*, supporting
    negative start/stop/step (the engine behind :class:`islice_extended`).

    Negative indices are resolved by buffering just enough of the iterator
    in a bounded deque; the buffer size is the minimum the slice allows.
    Raises ``ValueError`` if ``s.step`` is zero.
    """
    start = s.start
    stop = s.stop
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    step = s.step or 1

    if step > 0:
        # Forward iteration.
        start = 0 if (start is None) else start

        if start < 0:
            # Consume all but the last -start items
            # (enumerate from 1 so the last counter is the total length).
            cache = deque(enumerate(it, 1), maxlen=-start)
            len_iter = cache[-1][0] if cache else 0

            # Adjust start to be positive
            i = max(len_iter + start, 0)

            # Adjust stop to be positive
            if stop is None:
                j = len_iter
            elif stop >= 0:
                j = min(stop, len_iter)
            else:
                j = max(len_iter + stop, 0)

            # Slice the cache
            n = j - i
            if n <= 0:
                return

            for index in range(n):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()[1]
                else:
                    # just pop and discard the item
                    cache.popleft()
        elif (stop is not None) and (stop < 0):
            # Advance to the start position
            next(islice(it, start, start), None)

            # When stop is negative, we have to carry -stop items while
            # iterating
            cache = deque(islice(it, -stop), maxlen=-stop)

            for index, item in enumerate(it):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
                cache.append(item)
        else:
            # When both start and stop are positive we have the normal case
            yield from islice(it, start, stop, step)
    else:
        # Backward iteration (negative step).
        start = -1 if (start is None) else start

        if (stop is not None) and (stop < 0):
            # Consume all but the last items
            n = -stop - 1
            cache = deque(enumerate(it, 1), maxlen=n)
            len_iter = cache[-1][0] if cache else 0

            # If start and stop are both negative they are comparable and
            # we can just slice. Otherwise we can adjust start to be negative
            # and then slice.
            if start < 0:
                i, j = start, stop
            else:
                i, j = min(start - len_iter, -1), None

            for index, item in list(cache)[i:j:step]:
                yield item
        else:
            # Advance to the stop position
            if stop is not None:
                m = stop + 1
                next(islice(it, m, m), None)

            # stop is positive, so if start is negative they are not comparable
            # and we need the rest of the items.
            if start < 0:
                i = start
                n = None
            # stop is None and start is positive, so we just need items up to
            # the start index.
            elif stop is None:
                i = None
                n = start + 1
            # Both stop and start are positive, so they are comparable.
            else:
                i = None
                n = start - stop
                if n <= 0:
                    return

            cache = list(islice(it, n))

            yield from cache[i::step]
def always_reversible(iterable):
    """An extension of :func:`reversed` that supports all iterables, not
    just those which implement the ``Reversible`` or ``Sequence`` protocols.

    >>> print(*always_reversible(x for x in range(3)))
    2 1 0

    If the iterable is already reversible, this function returns the
    result of :func:`reversed()`. If the iterable is not reversible,
    this function will cache the remaining items in the iterable and
    yield them in reverse order, which may require significant storage.
    """
    try:
        return reversed(iterable)
    except TypeError:
        # Not reversible -- materialize the remaining items and reverse
        # the copy instead.
        pass
    return reversed(list(iterable))
def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items using :func:`itertools.groupby`.
    The *ordering* function determines whether two items are adjacent by
    returning their position.

    By default, the ordering function is the identity function. This is
    suitable for finding runs of numbers:

    >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
    >>> for group in consecutive_groups(iterable):
    ...     print(list(group))
    [1]
    [10, 11, 12]
    [20]
    [30, 31, 32, 33]
    [40]

    To find runs of adjacent letters, apply :func:`ord` function
    to convert letters to ordinals.

    >>> iterable = 'abcdfgilmnop'
    >>> ordering = ord
    >>> for group in consecutive_groups(iterable, ordering):
    ...     print(list(group))
    ['a', 'b', 'c', 'd']
    ['f', 'g']
    ['i']
    ['l', 'm', 'n', 'o', 'p']

    Each group of consecutive items is an iterator that shares it source with
    *iterable*. When an an output group is advanced, the previous group is
    no longer available unless its elements are copied (e.g., into a ``list``).

    >>> iterable = [1, 2, 11, 12, 21, 22]
    >>> saved_groups = []
    >>> for group in consecutive_groups(iterable):
    ...     saved_groups.append(list(group))  # Copy group elements
    >>> saved_groups
    [[1, 2], [11, 12], [21, 22]]

    """
    # Consecutive items have a constant difference between their enumeration
    # index and their position value, so grouping on that difference yields
    # the runs.
    if ordering is None:

        def key(pair):
            index, item = pair
            return index - item

    else:

        def key(pair):
            index, item = pair
            return index - ordering(item)

    for _, run in groupby(enumerate(iterable), key=key):
        yield map(itemgetter(1), run)
def difference(iterable, func=sub, *, initial=None):
    """This function is the inverse of :func:`itertools.accumulate`. By default
    it will compute the first difference of *iterable* using
    :func:`operator.sub`:

    >>> from itertools import accumulate
    >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
    >>> list(difference(iterable))
    [0, 1, 2, 3, 4]

    *func* defaults to :func:`operator.sub`, but other functions can be
    specified. They will be applied as follows::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, to do progressive division:

    >>> iterable = [1, 2, 6, 24, 120]
    >>> func = lambda x, y: x // y
    >>> list(difference(iterable, func))
    [1, 2, 3, 4, 5]

    If the *initial* keyword is set, the first element will be skipped when
    computing successive differences.

    >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
    >>> list(difference(it, initial=10))
    [1, 2, 3]

    """
    # Two copies of the iterable: *leading* runs one item ahead of *lagging*
    # so func sees each item paired with its predecessor.
    lagging, leading = tee(iterable)
    try:
        head = [next(leading)]
    except StopIteration:
        return iter([])

    # When *initial* was used with accumulate, the first item is synthetic
    # and should not be emitted.
    if initial is not None:
        head = []

    return chain(head, map(func, leading, lagging))
class SequenceView(Sequence):
    """Return a read-only view of the sequence object *target*.

    :class:`SequenceView` objects are analogous to Python's built-in
    "dictionary view" types. They provide a dynamic view of a sequence's items,
    meaning that when the sequence updates, so does the view.

    >>> seq = ['0', '1', '2']
    >>> view = SequenceView(seq)
    >>> view
    SequenceView(['0', '1', '2'])
    >>> seq.append('3')
    >>> view
    SequenceView(['0', '1', '2', '3'])

    Sequence views support indexing, slicing, and length queries. They act
    like the underlying sequence, except they don't allow assignment:

    >>> view[1]
    '1'
    >>> view[1:-1]
    ['1', '2']
    >>> len(view)
    4

    Sequence views are useful as an alternative to copying, as they don't
    require (much) extra storage.

    """

    def __init__(self, target):
        # Reject non-sequences up front so later delegation can't fail in
        # surprising ways.
        if not isinstance(target, Sequence):
            raise TypeError
        self._target = target

    def __len__(self):
        return len(self._target)

    def __getitem__(self, index):
        # Delegate both integer indexing and slicing to the target.
        return self._target[index]

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self._target)
class seekable:
    """Wrap an iterator to allow for seeking backward and forward. This
    progressively caches the items in the source iterable so they can be
    re-visited.

    Call :meth:`seek` with an index to seek to that position in the source
    iterable.

    To "reset" an iterator, seek to ``0``:

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.seek(0)
    >>> next(it), next(it), next(it)
    ('0', '1', '2')

    You can also seek forward:

    >>> it = seekable((str(n) for n in range(20)))
    >>> it.seek(10)
    >>> next(it)
    '10'
    >>> it.seek(20)  # Seeking past the end of the source isn't a problem
    >>> list(it)
    []
    >>> it.seek(0)  # Resetting works even after hitting the end
    >>> next(it)
    '0'

    Call :meth:`relative_seek` to seek relative to the source iterator's
    current position.

    >>> it = seekable((str(n) for n in range(20)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.relative_seek(2)
    >>> next(it)
    '5'
    >>> it.relative_seek(-3)  # Source is at '6', we move back to '3'
    >>> next(it)
    '3'
    >>> it.relative_seek(-3)  # Source is at '4', we move back to '1'
    >>> next(it)
    '1'


    Call :meth:`peek` to look ahead one item without advancing the iterator:

    >>> it = seekable('1234')
    >>> it.peek()
    '1'
    >>> list(it)
    ['1', '2', '3', '4']
    >>> it.peek(default='empty')
    'empty'

    Before the iterator is at its end, calling :func:`bool` on it will return
    ``True``. After it will return ``False``:

    >>> it = seekable('5678')
    >>> bool(it)
    True
    >>> list(it)
    ['5', '6', '7', '8']
    >>> bool(it)
    False

    You may view the contents of the cache with the :meth:`elements` method.
    That returns a :class:`SequenceView`, a view that updates automatically:

    >>> it = seekable((str(n) for n in range(10)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> elements = it.elements()
    >>> elements
    SequenceView(['0', '1', '2'])
    >>> next(it)
    '3'
    >>> elements
    SequenceView(['0', '1', '2', '3'])

    By default, the cache grows as the source iterable progresses, so beware of
    wrapping very large or infinite iterables. Supply *maxlen* to limit the
    size of the cache (this of course limits how far back you can seek).

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()), maxlen=2)
    >>> next(it), next(it), next(it), next(it)
    ('0', '1', '2', '3')
    >>> list(it.elements())
    ['2', '3']
    >>> it.seek(0)
    >>> next(it), next(it), next(it), next(it)
    ('2', '3', '4', '5')
    >>> next(it)
    '6'

    """

    def __init__(self, iterable, maxlen=None):
        self._source = iter(iterable)
        # Unbounded list cache by default; a bounded deque when *maxlen* is
        # given (old items are then discarded and can't be seeked back to).
        if maxlen is None:
            self._cache = []
        else:
            self._cache = deque([], maxlen)
        # _index is None while reading fresh items from the source; after a
        # seek it is the cache position to serve next.
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        # Replay from the cache first if a seek put us inside it.
        if self._index is not None:
            try:
                item = self._cache[self._index]
            except IndexError:
                # Replayed past the cached region; resume the source.
                self._index = None
            else:
                self._index += 1
                return item

        # Pull a fresh item from the source and remember it.
        item = next(self._source)
        self._cache.append(item)
        return item

    def __bool__(self):
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        # Advance one item, then step the cache index back so the same item
        # is served again on the next __next__ call.
        try:
            peeked = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return peeked

    def elements(self):
        # Live, read-only view of the cached items.
        return SequenceView(self._cache)

    def seek(self, index):
        self._index = index
        # If the target is beyond the cache, draw the missing items from
        # the source so the position is reachable.
        remainder = index - len(self._cache)
        if remainder > 0:
            consume(self, remainder)

    def relative_seek(self, count):
        # Establish the current position if we're reading from the source,
        # then seek by the offset (clamped at the beginning of the cache).
        if self._index is None:
            self._index = len(self._cache)

        self.seek(max(self._index + count, 0))
class run_length:
    """
    :func:`run_length.encode` compresses an iterable with run-length encoding.
    It yields groups of repeated items with the count of how many times they
    were repeated:

    >>> uncompressed = 'abbcccdddd'
    >>> list(run_length.encode(uncompressed))
    [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` decompresses an iterable that was previously
    compressed with run-length encoding. It yields the items of the
    decompressed iterable:

    >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> list(run_length.decode(compressed))
    ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        # Each group of equal adjacent items becomes a (value, count) pair.
        return (
            (value, ilen(group)) for value, group in groupby(iterable)
        )

    @staticmethod
    def decode(iterable):
        # Expand each (value, count) pair back into repeated values.
        return chain.from_iterable(starmap(repeat, iterable))
def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly ``n`` items in the iterable are ``True``
    according to the *predicate* function.

    >>> exactly_n([True, True, False], 2)
    True
    >>> exactly_n([True, True, False], 1)
    False
    >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
    True

    The iterable will be advanced until ``n + 1`` truthy items are encountered,
    so avoid calling it on infinite iterables.

    """
    # Stop counting after n + 1 matches: one match too many already proves
    # the answer is False, so we never need to scan further.
    matches = filter(predicate, iterable)
    return ilen(islice(matches, n + 1)) == n
def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

    >>> list(circular_shifts(range(4)))
    [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    Set *steps* to the number of places to rotate to the left
    (or to the right if negative).  Defaults to 1.

    >>> list(circular_shifts(range(4), 2))
    [(0, 1, 2, 3), (2, 3, 0, 1)]

    >>> list(circular_shifts(range(4), -1))
    [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    """
    window = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Pre-rotate one step backward so the first rotation in the loop yields
    # the original, unshifted sequence.
    window.rotate(steps)
    shift = -steps

    # Rotating by `shift` returns to the start after len/gcd(len, shift)
    # rotations, so that's how many distinct shifts there are.
    length = len(window)
    num_shifts = length // math.gcd(length, shift)

    for _ in range(num_shifts):
        window.rotate(shift)
        yield tuple(window)
def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function that
    modifies an iterable. *result_index* is the position in that function's
    signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition. This can augment what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

    >>> from more_itertools import chunked
    >>> chunker = make_decorator(chunked, result_index=0)
    >>> @chunker(3)
    ... def iter_range(n):
    ...     return iter(range(n))
    ...
    >>> list(iter_range(9))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

    >>> truth_serum = make_decorator(filter, result_index=1)
    >>> @truth_serum(bool)
    ... def boolean_test():
    ...     return [0, 1, '', ' ', False, True]
    ...
    >>> list(boolean_test())
    [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators:

    >>> from more_itertools import peekable
    >>> peekable_function = make_decorator(peekable)
    >>> @peekable_function()
    ... def str_range(*args):
    ...     return (str(x) for x in range(*args))
    ...
    >>> it = str_range(1, 20, 2)
    >>> next(it), next(it), next(it)
    ('1', '3', '5')
    >>> it.peek()
    '7'
    >>> next(it)
    '7'

    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                # Call the decorated function, then splice its result into
                # the wrapping function's arguments at result_index.
                result = f(*args, **kwargs)
                call_args = [
                    *wrapping_args[:result_index],
                    result,
                    *wrapping_args[result_index:],
                ]
                return wrapping_func(*call_args, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Return a dictionary that maps the items in *iterable* to categories
    defined by *keyfunc*, transforms them with *valuefunc*, and
    then summarizes them by category with *reducefunc*.

    *valuefunc* defaults to the identity function if it is unspecified.
    If *reducefunc* is unspecified, no summarization takes place:

    >>> keyfunc = lambda x: x.upper()
    >>> result = map_reduce('abbccc', keyfunc)
    >>> sorted(result.items())
    [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Specifying *valuefunc* transforms the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> result = map_reduce('abbccc', keyfunc, valuefunc)
    >>> sorted(result.items())
    [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Specifying *reducefunc* summarizes the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> reducefunc = sum
    >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
    >>> sorted(result.items())
    [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the map/reduce
    procedure:

    >>> all_items = range(30)
    >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
    >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
    >>> categories = map_reduce(items, keyfunc=keyfunc)
    >>> sorted(categories.items())
    [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
    >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
    >>> sorted(summaries.items())
    [(0, 90), (1, 75)]

    Note that all items in the iterable are gathered into a list before the
    summarization step, which may require significant storage.

    The returned object is a :obj:`collections.defaultdict` with the
    ``default_factory`` set to ``None``, such that it behaves like a normal
    dictionary.

    """
    groups = defaultdict(list)

    # Categorize: group each (optionally transformed) item under its key.
    for item in iterable:
        value = item if valuefunc is None else valuefunc(item)
        groups[keyfunc(item)].append(value)

    # Summarize: collapse each group's list with reducefunc, if given.
    if reducefunc is not None:
        for key in groups:
            groups[key] = reducefunc(groups[key])

    # Disable auto-vivification so lookups of missing keys raise KeyError.
    groups.default_factory = None
    return groups
def rlocate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``, starting from the right and moving left.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
    [4, 2, 1]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item:

    >>> iterator = iter('abcb')
    >>> pred = lambda x: x == 'b'
    >>> list(rlocate(iterator, pred))
    [3, 1]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(rlocate(iterable, pred=pred, window_size=3))
    [9, 5, 1]

    Beware, this function won't return anything for infinite iterables.
    If *iterable* is reversible, ``rlocate`` will reverse it and search from
    the right. Otherwise, it will search from the left and return the results
    in reverse order.

    See :func:`locate` to for other example applications.

    """
    if window_size is None:
        # Fast path: a sized, reversible iterable can be scanned backward
        # directly, translating reversed indexes back to forward ones.
        # (Both len() and reversed() may raise TypeError; fall through.)
        try:
            n = len(iterable)
            return (n - 1 - i for i in locate(reversed(iterable), pred))
        except TypeError:
            pass

    # General path: scan forward, collect all hits, emit them in reverse.
    return reversed(list(locate(iterable, pred, window_size)))
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items from *iterable*, replacing the items for which *pred*
    returns ``True`` with the items from the iterable *substitutes*.

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
    >>> pred = lambda x: x == 0
    >>> substitutes = (2, 3)
    >>> list(replace(iterable, pred, substitutes))
    [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    If *count* is given, the number of replacements will be limited:

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
    >>> pred = lambda x: x == 0
    >>> substitutes = [None]
    >>> list(replace(iterable, pred, substitutes, count=2))
    [1, 1, None, 1, 1, None, 1, 1, 0]

    Use *window_size* to control the number of items passed as arguments to
    *pred*. This allows for locating and replacing subsequences.

    >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
    >>> window_size = 3
    >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
    >>> substitutes = [3, 4]  # Splice in these items
    >>> list(replace(iterable, pred, substitutes, window_size=window_size))
    [3, 4, 5, 3, 4, 5]

    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # Substitutes may be emitted many times, so snapshot them once.
    substitutes = tuple(substitutes)

    # Pad the tail so the number of windows equals the number of items.
    padded = chain(iterable, repeat(_marker, window_size - 1))
    windows = windowed(padded, window_size)

    replacements = 0
    for window in windows:
        # On a match (while under the replacement budget), emit the
        # substitutes and skip the windows that overlap the matched span.
        # E.g. with window_size 2 and windows (0, 1), (1, 2), (2, 3)...,
        # a match on (0, 1) must also discard (1, 2).
        if pred(*window):
            if (count is None) or (replacements < count):
                replacements += 1
                yield from substitutes
                consume(windows, window_size - 1)
                continue

        # No match (or budget exhausted): emit the window's leading item,
        # unless it is tail padding.
        if window and (window[0] is not _marker):
            yield window[0]
def partitions(iterable):
    """Yield all possible order-preserving partitions of *iterable*.

    >>> iterable = 'abc'
    >>> for part in partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['a', 'b', 'c']

    This is unrelated to :func:`partition`.

    """
    items = list(iterable)
    n = len(items)
    # Each subset of interior positions is a set of cut points.
    for cuts in powerset(range(1, n)):
        bounds = (0,) + cuts + (n,)
        yield [items[start:stop] for start, stop in zip(bounds, bounds[1:])]
def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """
    Yield the set partitions of *iterable* into *k* parts. Set partitions are
    not order-preserving.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, 2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']


    If *k* is not given, every set partition is generated.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    if *min_size* and/or *max_size* are given, the minimum and/or maximum size
    per block in partition is set.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, min_size=2):
    ...     print([''.join(p) for p in part])
    ['abc']
    >>> for part in set_partitions(iterable, max_size=2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    """
    L = list(iterable)
    n = len(L)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        elif k > n:
            return

    min_size = 0 if min_size is None else min_size
    max_size = n if max_size is None else max_size
    if min_size > max_size:
        return

    def helper(items, parts):
        # Recursively partition *items* into exactly *parts* blocks.
        if parts == 1:
            yield [items]
        elif len(items) == parts:
            yield [[item] for item in items]
        else:
            head, *rest = items
            # Either the head is a singleton block...
            for partition in helper(rest, parts - 1):
                yield [[head], *partition]
            # ...or it joins one of the blocks of a smaller partition.
            for partition in helper(rest, parts):
                for i, block in enumerate(partition):
                    yield partition[:i] + [[head] + block] + partition[i + 1 :]

    def sizes_ok(partition):
        return all(min_size <= len(block) <= max_size for block in partition)

    if k is None:
        for parts in range(1, n + 1):
            yield from filter(sizes_ok, helper(L, parts))
    else:
        yield from filter(sizes_ok, helper(L, k))
class time_limited:
    """An iterator that yields items from *iterable* until *limit_seconds*
    have elapsed. If the limit expires before the source is exhausted, the
    ``timed_out`` attribute is set to ``True``.

    >>> from time import sleep
    >>> def generator():
    ...     yield 1
    ...     yield 2
    ...     sleep(0.2)
    ...     yield 3
    >>> iterable = time_limited(0.1, generator())
    >>> list(iterable)
    [1, 2]
    >>> iterable.timed_out
    True

    Note that the time is checked before each item is yielded, and iteration
    stops if the time elapsed is greater than *limit_seconds*. If your time
    limit is 1 second, but it takes 2 seconds to generate the first item from
    the iterable, the function will run for 2 seconds and not yield anything.
    As a special case, when *limit_seconds* is zero, the iterator never
    returns anything.

    """

    def __init__(self, limit_seconds, iterable):
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be positive')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        # A zero-second limit never yields anything at all.
        if self.limit_seconds == 0:
            self.timed_out = True
            raise StopIteration
        upcoming = next(self._iterator)
        elapsed = monotonic() - self._start_time
        if elapsed > self.limit_seconds:
            self.timed_out = True
            raise StopIteration

        return upcoming
def only(iterable, default=None, too_long=None):
    """If *iterable* has only one item, return it.
    If it has zero items, return *default*.
    If it has more than one item, raise the exception given by *too_long*,
    which is ``ValueError`` by default.

    >>> only([], default='missing')
    'missing'
    >>> only([1])
    1
    >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 1, 2,
     and perhaps more.'
    >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    TypeError

    Note that :func:`only` attempts to advance *iterable* twice to ensure
    there is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    iterator = iter(iterable)
    sentinel = object()

    first = next(iterator, sentinel)
    if first is sentinel:
        return default

    second = next(iterator, sentinel)
    if second is sentinel:
        return first

    msg = (
        f'Expected exactly one item in iterable, but got {first!r}, '
        f'{second!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
3554def _ichunk(iterator, n):
3555 cache = deque()
3556 chunk = islice(iterator, n)
3558 def generator():
3559 with suppress(StopIteration):
3560 while True:
3561 if cache:
3562 yield cache.popleft()
3563 else:
3564 yield next(chunk)
3566 def materialize_next(n=1):
3567 # if n not specified materialize everything
3568 if n is None:
3569 cache.extend(chunk)
3570 return len(cache)
3572 to_cache = n - len(cache)
3574 # materialize up to n
3575 if to_cache > 0:
3576 cache.extend(islice(chunk, to_cache))
3578 # return number materialized up to n
3579 return min(n, len(cache))
3581 return (generator(), materialize_next)
def ichunked(iterable, n):
    """Break *iterable* into sub-iterables with *n* elements each.
    :func:`ichunked` is like :func:`chunked`, but it yields iterables
    instead of lists.

    If the sub-iterables are read in order, the elements of *iterable*
    won't be stored in memory.
    If they are read out of order, :func:`itertools.tee` is used to cache
    elements as necessary.

    >>> from itertools import count
    >>> all_chunks = ichunked(count(), 4)
    >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
    >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
    [4, 5, 6, 7]
    >>> list(c_1)
    [0, 1, 2, 3]
    >>> list(c_3)
    [8, 9, 10, 11]

    """
    iterator = iter(iterable)
    while True:
        chunk, materialize_next = _ichunk(iterator, n)

        # If not even one item can be materialized, the source is done.
        if materialize_next() == 0:
            break

        yield chunk

        # Buffer whatever the consumer didn't read before moving on, so
        # the next chunk starts at the right position.
        materialize_next(None)
def iequals(*iterables):
    """Return ``True`` if all given *iterables* are equal to each other,
    which means that they contain the same elements in the same order.

    The function is useful for comparing iterables of different data types
    or iterables that do not support equality checks.

    >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
    True

    >>> iequals("abc", "acb")
    False

    Not to be confused with :func:`all_equal`, which checks whether all
    elements of iterable are equal to each other.

    """
    # A fresh object() can never equal a real element, so length
    # mismatches show up as unequal groups.
    fill = object()
    for group in zip_longest(*iterables, fillvalue=fill):
        if not all_equal(group):
            return False
    return True
def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

    >>> list(distinct_combinations([0, 0, 1], 2))
    [(0, 0), (0, 1)]

    Equivalent to ``set(combinations(iterable))``, except duplicates are not
    generated and thrown away. For larger input sequences this is much more
    efficient.

    """
    if r < 0:
        raise ValueError('r must be non-negative')
    elif r == 0:
        yield ()
        return
    pool = tuple(iterable)
    # One generator per combination slot; each yields the distinct values
    # that remain after the previous slot's choice.
    stack = [unique_everseen(enumerate(pool), key=itemgetter(1))]
    combo = [None] * r
    depth = 0
    while stack:
        try:
            idx, value = next(stack[-1])
        except StopIteration:
            # This slot is exhausted; backtrack to the previous one.
            stack.pop()
            depth -= 1
            continue
        combo[depth] = value
        if depth + 1 == r:
            yield tuple(combo)
        else:
            # Descend: candidates for the next slot come after idx.
            stack.append(
                unique_everseen(
                    enumerate(pool[idx + 1 :], idx + 1),
                    key=itemgetter(1),
                )
            )
            depth += 1
def filter_except(validator, iterable, *exceptions):
    """Yield the items from *iterable* for which the *validator* function does
    not raise one of the specified *exceptions*.

    *validator* is called for each item in *iterable*.
    It should be a function that accepts one argument and raises an exception
    if that item is not valid.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(filter_except(int, iterable, ValueError, TypeError))
    ['1', '2', '4']

    If an exception other than one given by *exceptions* is raised by
    *validator*, it is raised like normal.
    """
    for item in iterable:
        try:
            validator(item)
        except exceptions:
            # Validation failed with an expected exception: drop the item.
            continue
        yield item
def map_except(function, iterable, *exceptions):
    """Transform each item from *iterable* with *function* and yield the
    result, unless *function* raises one of the specified *exceptions*.

    *function* is called to transform each item in *iterable*.
    It should accept one argument.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(map_except(int, iterable, ValueError, TypeError))
    [1, 2, 4]

    If an exception other than one given by *exceptions* is raised by
    *function*, it is raised like normal.
    """
    for value in iterable:
        # Items whose transformation fails with an expected exception
        # are silently skipped.
        with suppress(*exceptions):
            yield function(value)
def map_if(iterable, pred, func, func_else=None):
    """Evaluate each item from *iterable* using *pred*. If the result is
    equivalent to ``True``, transform the item with *func* and yield it.
    Otherwise, transform the item with *func_else* and yield it.

    *pred*, *func*, and *func_else* should each be functions that accept
    one argument. By default, *func_else* is the identity function.

    >>> from math import sqrt
    >>> iterable = list(range(-5, 5))
    >>> iterable
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
    >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
    >>> list(map_if(iterable, lambda x: x >= 0,
    ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
    [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
    """
    # With no func_else, failing items pass through untransformed
    # (the callable is never invoked for them).
    identity_else = func_else is None
    for item in iterable:
        if pred(item):
            yield func(item)
        elif identity_else:
            yield item
        else:
            yield func_else(item)
def _sample_unweighted(iterator, k, strict):
    # Reservoir sampling of *k* items from *iterator* without weights.
    # Algorithm L in the 1994 paper by Kim-Hung Li:
    # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".

    # Seed the reservoir with the first k items.
    reservoir = list(islice(iterator, k))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')
    W = 1.0

    # StopIteration from next() below simply means the input is exhausted.
    with suppress(StopIteration):
        while True:
            W *= random() ** (1 / k)
            # Number of items to jump over before the next replacement,
            # derived from W per Algorithm L.
            skip = floor(log(random()) / log1p(-W))
            element = next(islice(iterator, skip, None))
            # Overwrite a uniformly chosen reservoir slot.
            reservoir[randrange(k)] = element

    # Randomize output order so it doesn't reflect arrival order.
    shuffle(reservoir)
    return reservoir
def _sample_weighted(iterator, k, weights, strict):
    # Weighted reservoir sampling of *k* items from *iterator*.
    # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
    # "Weighted random sampling with a reservoir".

    # Log-transform for numerical stability for weights that are small/large
    weight_keys = (log(random()) / weight for weight in weights)

    # Fill up the reservoir (collection of samples) with the first `k`
    # weight-keys and elements, then heapify the list.
    reservoir = take(k, zip(weight_keys, iterator))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Min-heap ordered by weight-key: reservoir[0] is the smallest key.
    heapify(reservoir)

    # The number of jumps before changing the reservoir is a random variable
    # with an exponential distribution. Sample it using random() and logs.
    smallest_weight_key, _ = reservoir[0]
    weights_to_skip = log(random()) / smallest_weight_key

    for weight, element in zip(weights, iterator):
        if weight >= weights_to_skip:
            # The notation here is consistent with the paper, but we store
            # the weight-keys in log-space for better numerical stability.
            smallest_weight_key, _ = reservoir[0]
            t_w = exp(weight * smallest_weight_key)
            r_2 = uniform(t_w, 1)  # generate U(t_w, 1)
            weight_key = log(r_2) / weight
            # Replace the heap's minimum with the new entry.
            heapreplace(reservoir, (weight_key, element))
            smallest_weight_key, _ = reservoir[0]
            weights_to_skip = log(random()) / smallest_weight_key
        else:
            # Not yet time to replace; consume this element's weight.
            weights_to_skip -= weight

    # Drop the weight-keys and randomize output order.
    ret = [element for weight_key, element in reservoir]
    shuffle(ret)
    return ret
def _sample_counted(population, k, counts, strict):
    # Reservoir sampling of *k* items where each population element is
    # repeated according to the parallel *counts* iterator.
    element = None
    remaining = 0

    def feed(i):
        # Advance *i* steps ahead and consume an element
        nonlocal element, remaining

        # Walk forward through (element, count) pairs until the requested
        # virtual position falls inside the current element's run.
        while i + 1 > remaining:
            i = i - remaining
            element = next(population)
            remaining = next(counts)
        remaining -= i + 1
        return element

    # Seed the reservoir; a short population just leaves it partially full.
    with suppress(StopIteration):
        reservoir = []
        for _ in range(k):
            reservoir.append(feed(0))

    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Same skip-based replacement scheme as Algorithm L (see
    # _sample_unweighted), but stepping through the virtual expanded
    # population via feed().
    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= random() ** (1 / k)
            skip = floor(log(random()) / log1p(-W))
            element = feed(skip)
            reservoir[randrange(k)] = element

    # Randomize output order so it doesn't reflect arrival order.
    shuffle(reservoir)
    return reservoir
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a *k*-length list of elements chosen (without replacement)
    from the *iterable*.

    Similar to :func:`random.sample`, but works on inputs that aren't
    indexable (such as sets and dictionaries) and on inputs where the
    size isn't known in advance (such as generators).

    >>> iterable = range(100)
    >>> sample(iterable, 5)  # doctest: +SKIP
    [81, 60, 96, 16, 4]

    For iterables with repeated elements, you may supply *counts* to
    indicate the repeats.

    >>> iterable = ['a', 'b']
    >>> counts = [3, 4]  # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
    >>> sample(iterable, k=3, counts=counts)  # doctest: +SKIP
    ['a', 'a', 'b']

    An iterable with *weights* may be given:

    >>> iterable = range(100)
    >>> weights = (i * i + 1 for i in range(100))
    >>> sampled = sample(iterable, 5, weights=weights)  # doctest: +SKIP
    [79, 67, 74, 66, 78]

    Weighted selections are made without replacement.
    After an element is selected, it is removed from the pool and the
    relative weights of the other elements increase (this
    does not match the behavior of :func:`random.sample`'s *counts*
    parameter). Note that *weights* may not be used with *counts*.

    If the length of *iterable* is less than *k*,
    ``ValueError`` is raised if *strict* is ``True`` and
    all elements are returned (in shuffled order) if *strict* is ``False``.

    By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
    technique is used. When *weights* are provided,
    `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used instead.

    Notes on reproducibility:

    * The algorithms rely on inexact floating-point functions provided
      by the underlying math library (e.g. ``log``, ``log1p``, and ``pow``).
      Those functions can `produce slightly different results
      <https://members.loria.fr/PZimmermann/papers/accuracy.pdf>`_ on
      different builds. Accordingly, selections can vary across builds
      even for the same seed.

    * The algorithms loop over the input and make selections based on
      ordinal position, so selections from unordered collections (such as
      sets) won't reproduce across sessions on the same platform using the
      same seed. For example, this won't reproduce::

        >> seed(8675309)
        >> sample(set('abcdefghijklmnopqrstuvwxyz'), 10)
        ['c', 'p', 'e', 'w', 's', 'a', 'j', 'd', 'n', 't']

    """
    iterator = iter(iterable)

    # Validate before dispatching to a helper.
    if k < 0:
        raise ValueError('k must be non-negative')

    if k == 0:
        return []

    if weights is not None and counts is not None:
        raise TypeError('weights and counts are mutually exclusive')

    # Dispatch to the appropriate reservoir-sampling implementation.
    if weights is not None:
        return _sample_weighted(iterator, k, iter(weights), strict)

    if counts is not None:
        return _sample_counted(iterator, k, iter(counts), strict)

    return _sample_unweighted(iterator, k, strict)
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Returns ``True`` if the items of iterable are in sorted order, and
    ``False`` otherwise. *key* and *reverse* have the same meaning that they do
    in the built-in :func:`sorted` function.

    >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
    True
    >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
    False

    If *strict*, tests for strict sorting, that is, returns ``False`` if equal
    elements are found:

    >>> is_sorted([1, 2, 2])
    True
    >>> is_sorted([1, 2, 2], strict=True)
    False

    The function returns ``False`` after encountering the first out-of-order
    item, which means it may produce results that differ from the built-in
    :func:`sorted` function for objects with unusual comparison dynamics
    (like ``math.nan``). If there are no out-of-order items, the iterable is
    exhausted.
    """
    keyed = iterable if key is None else map(key, iterable)
    # Pair each item with its successor via two offset copies.
    current, ahead = tee(keyed)
    next(ahead, None)
    if reverse:
        current, ahead = ahead, current
    if strict:
        # Every adjacent pair must be strictly increasing.
        return all(map(lt, current, ahead))
    # No adjacent pair may be decreasing.
    return not any(map(lt, ahead, current))
class AbortThread(BaseException):
    """Raised inside a :class:`callback_iter` background thread to stop it.

    Subclasses ``BaseException`` so that ``except Exception`` handlers in
    the wrapped function do not accidentally swallow the abort signal.
    """

    pass
class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

    >>> def func(callback=None):
    ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
    ...         if callback:
    ...             callback(i, c)
    ...     return 4

    Use ``with callback_iter(func)`` to get an iterator over the parameters
    that are delivered to the callback.

    >>> with callback_iter(func) as it:
    ...     for args, kwargs in it:
    ...         print(args)
    (1, 'a')
    (2, 'b')
    (3, 'c')

    The function will be called in a background thread. The ``done`` property
    indicates whether it has completed execution.

    >>> it.done
    True

    If it completes successfully, its return value will be available
    in the ``result`` property.

    >>> it.result
    4

    Notes:

    * If the function uses some keyword argument besides ``callback``, supply
      *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently the it is polled for
      output.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        self._func = func
        self._callback_kwd = callback_kwd
        self._aborted = False
        self._future = None
        self._wait_seconds = wait_seconds
        # Lazily import concurrent.future
        self._executor = __import__(
            'concurrent.futures'
        ).futures.ThreadPoolExecutor(max_workers=1)
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Signal the worker to abort (its callback raises AbortThread on
        # the next invocation), then wait for the executor to wind down.
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    @property
    def done(self):
        # False until the background call has been submitted and finished.
        if self._future is None:
            return False
        return self._future.done()

    @property
    def result(self):
        if not self.done:
            raise RuntimeError('Function has not yet completed')

        # Re-raises the worker's exception, if it raised one.
        return self._future.result()

    def _reader(self):
        # Generator that drives the background call and yields each
        # (args, kwargs) pair the callback receives, via a queue.
        q = Queue()

        def callback(*args, **kwargs):
            # Runs in the worker thread; raising here unwinds *func*.
            if self._aborted:
                raise AbortThread('canceled by user')

            q.put((args, kwargs))

        self._future = self._executor.submit(
            self._func, **{self._callback_kwd: callback}
        )

        # Poll the queue until the background call completes.
        while True:
            try:
                item = q.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                q.task_done()
                yield item

            if self._future.done():
                break

        # Drain anything the callback enqueued after the last poll.
        remaining = []
        while True:
            try:
                item = q.get_nowait()
            except Empty:
                break
            else:
                q.task_done()
                remaining.append(item)
        q.join()
        yield from remaining
def windowed_complete(iterable, n):
    """
    Yield ``(beginning, middle, end)`` tuples, where:

    * Each ``middle`` has *n* items from *iterable*
    * Each ``beginning`` has the items before the ones in ``middle``
    * Each ``end`` has the items after the ones in ``middle``

    >>> iterable = range(7)
    >>> n = 3
    >>> for beginning, middle, end in windowed_complete(iterable, n):
    ...     print(beginning, middle, end)
    () (0, 1, 2) (3, 4, 5, 6)
    (0,) (1, 2, 3) (4, 5, 6)
    (0, 1) (2, 3, 4) (5, 6)
    (0, 1, 2) (3, 4, 5) (6,)
    (0, 1, 2, 3) (4, 5, 6) ()

    Note that *n* must be at least 0 and most equal to the length of
    *iterable*.

    This function will exhaust the iterable and may require significant
    storage.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    seq = tuple(iterable)
    size = len(seq)

    if n > size:
        raise ValueError('n must be <= len(seq)')

    # Slide the middle window across the sequence, one step at a time.
    for start in range(size - n + 1):
        stop = start + n
        yield seq[:start], seq[start:stop], seq[stop:]
def all_unique(iterable, key=None):
    """
    Returns ``True`` if all the elements of *iterable* are unique (no two
    elements are equal).

    >>> all_unique('ABCB')
    False

    If a *key* function is specified, it will be used to make comparisons.

    >>> all_unique('ABCb')
    True
    >>> all_unique('ABCb', str.lower)
    False

    The function returns as soon as the first non-unique element is
    encountered. Iterables with a mix of hashable and unhashable items can
    be used, but the function will be slower for unhashable items.
    """
    hashables = set()
    unhashables = []
    for candidate in map(key, iterable) if key else iterable:
        try:
            if candidate in hashables:
                return False
            hashables.add(candidate)
        except TypeError:
            # Unhashable values fall back to a (slower) linear scan.
            if candidate in unhashables:
                return False
            unhashables.append(candidate)
    return True
def nth_product(index, *args):
    """Equivalent to ``list(product(*args))[index]``.

    The products of *args* can be ordered lexicographically.
    :func:`nth_product` computes the product at sort position *index* without
    computing the previous products.

    >>> nth_product(8, range(2), range(2), range(2), range(2))
    (1, 0, 0, 0)

    ``IndexError`` will be raised if the given *index* is invalid.
    """
    # Work from the least-significant (last) pool to the most significant.
    pools = [tuple(pool) for pool in reversed(args)]
    sizes = [len(pool) for pool in pools]

    total = reduce(mul, sizes)

    if index < 0:
        index += total

    if not 0 <= index < total:
        raise IndexError

    # Decode *index* as digits of a mixed-radix number, least
    # significant digit first.
    selection = []
    for pool, size in zip(pools, sizes):
        index, digit = divmod(index, size)
        selection.append(pool[digit])

    return tuple(reversed(selection))
def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]```

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`nth_permutation`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences.

    >>> nth_permutation('ghijk', 2, 5)
    ('h', 'i')

    ``ValueError`` will be raised If *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = list(iterable)
    n = len(pool)

    if r is None or r == n:
        r, c = n, factorial(n)
    elif not 0 <= r < n:
        raise ValueError
    else:
        c = perm(n, r)
        assert c > 0  # factorial(n)>0, and r<n so perm(n,r) is never zero

    if index < 0:
        index += c

    if not 0 <= index < c:
        raise IndexError

    # Decode *index* into factorial-number-system digits; digit d ranges
    # over 0..d-1 and selects which remaining pool item to take.
    result = [0] * r
    q = index * factorial(n) // c if r < n else index
    for radix in range(1, n + 1):
        q, digit = divmod(q, radix)
        if 0 <= n - radix < r:
            result[n - radix] = digit
        if q == 0:
            break

    # Popping by digit consumes each chosen item from the pool.
    return tuple(map(pool.pop, result))
def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.


    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`nth_combination_with_replacement`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences with replacement.

    >>> nth_combination_with_replacement(range(5), 3, 5)
    (0, 1, 1)

    ``ValueError`` will be raised If *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if not 0 <= r <= n:
        raise ValueError

    total = comb(n + r - 1, r)

    if index < 0:
        index += total

    if not 0 <= index < total:
        raise IndexError

    selection = []
    pos = 0
    while r:
        r -= 1
        # Advance *pos* past elements whose block of combinations
        # lies entirely before *index*.
        while n >= 0:
            block = comb(n + r - 1, r)
            if index < block:
                break
            n -= 1
            pos += 1
            index -= block
        selection.append(pool[pos])

    return tuple(selection)
def value_chain(*args):
    """Yield all arguments passed to the function in the same order in which
    they were passed. If an argument itself is iterable then iterate over its
    values.

    >>> list(value_chain(1, 2, 3, [4, 5, 6]))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and are emitted
    as-is:

    >>> list(value_chain('12', '34', ['56', '78']))
    ['12', '34', '56', '78']

    Pre- or postpend a single element to an iterable:

    >>> list(value_chain(1, [2, 3, 4, 5, 6]))
    [1, 2, 3, 4, 5, 6]
    >>> list(value_chain([1, 2, 3, 4, 5], 6))
    [1, 2, 3, 4, 5, 6]

    Multiple levels of nesting are not flattened.

    """
    for arg in args:
        # Strings and bytes are treated as atoms, not as iterables.
        if isinstance(arg, (str, bytes)):
            yield arg
        else:
            # EAFP: non-iterables raise TypeError and are emitted as-is.
            try:
                yield from arg
            except TypeError:
                yield arg
def product_index(element, *args):
    """Equivalent to ``list(product(*args)).index(element)``

    The products of *args* can be ordered lexicographically.
    :func:`product_index` computes the first index of *element* without
    computing the previous products.

    >>> product_index([8, 2], range(10), range(5))
    42

    ``ValueError`` will be raised if the given *element* isn't in the product
    of *args*.
    """
    index = 0

    # A fresh sentinel flags a length mismatch between element and args.
    missing = object()
    for x, pool in zip_longest(element, args, fillvalue=missing):
        if x is missing or pool is missing:
            raise ValueError('element is not a product of args')

        # Accumulate mixed-radix digits, most significant pool first.
        pool = tuple(pool)
        index = index * len(pool) + pool.index(x)

    return index
def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`combination_index` computes the index of the
    first *element*, without computing the previous combinations.

    >>> combination_index('adf', 'abcdefg')
    10

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations of *iterable*.
    """
    element = enumerate(element)
    k, target = next(element, (None, None))
    if k is None:
        # An empty element is the (single) 0-length combination.
        return 0

    matches = []
    pool = enumerate(iterable)
    for n, item in pool:
        if item == target:
            matches.append(n)
            nxt, target = next(element, (None, None))
            if nxt is None:
                break
            k = nxt
    else:
        raise ValueError('element is not a combination of iterable')

    # Drain the pool so that n ends up as the last index of the iterable.
    for n, _ in pool:
        pass

    index = 1
    for i, j in enumerate(reversed(matches), start=1):
        gap = n - j
        if i <= gap:
            index += comb(gap, i)

    return comb(n + 1, k + 1) - index
def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``

    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`combination_with_replacement_index`
    computes the index of the first *element*, without computing the previous
    combinations with replacement.

    >>> combination_with_replacement_index('adf', 'abcdefg')
    20

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations with replacement of *iterable*.
    """
    element = tuple(element)
    length = len(element)
    element = enumerate(element)

    k, target = next(element, (None, None))
    if k is None:
        # An empty element is the (single) 0-length combination.
        return 0

    positions = []
    pool = tuple(iterable)
    for n, item in enumerate(pool):
        # A pool item may be chosen repeatedly; record every repetition.
        while item == target:
            positions.append(n)
            nxt, target = next(element, (None, None))
            if nxt is None:
                break
            k = nxt
        if target is None:
            break
    else:
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    # Tally how many times each pool position was used.
    n = len(pool)
    counts = [0] * n
    for position in positions:
        counts[position] += 1

    index = 0
    running = 0
    for step in range(1, n):
        running += counts[step - 1]
        j = length + n - 1 - step - running
        i = n - step
        if i <= j:
            index += comb(j, i)

    return index
def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)```

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`permutation_index`
    computes the index of the first *element* directly, without computing
    the previous permutations.

    >>> permutation_index([1, 3, 2], range(5))
    19

    ``ValueError`` will be raised if the given *element* isn't one of the
    permutations of *iterable*.
    """
    # Interpret the index as a number in a mixed-radix (factorial) system:
    # each chosen item contributes its rank among the not-yet-used items.
    remaining = list(iterable)
    radix = len(remaining)
    result = 0
    for item in element:
        rank = remaining.index(item)
        result = result * radix + rank
        remaining.pop(rank)
        radix -= 1

    return result
class countable:
    """Wrap *iterable* and keep a count of how many items have been consumed.

    The ``items_seen`` attribute starts at ``0`` and increments as the iterable
    is consumed:

    >>> iterable = map(str, range(10))
    >>> it = countable(iterable)
    >>> it.items_seen
    0
    >>> next(it), next(it)
    ('0', '1')
    >>> list(it)
    ['2', '3', '4', '5', '6', '7', '8', '9']
    >>> it.items_seen
    10
    """

    def __init__(self, iterable):
        # Keep our own iterator so repeated iteration resumes where it left off.
        self._iterator = iter(iterable)
        self.items_seen = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Fetch first; only count items that were actually produced.
        value = next(self._iterator)
        self.items_seen += 1
        return value
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.
    Items are distributed such the lengths of the lists differ by at most
    1 item.

    >>> iterable = [1, 2, 3, 4, 5, 6, 7]
    >>> n = 3
    >>> list(chunked_even(iterable, n))  # List lengths: 3, 2, 2
    [[1, 2, 3], [4, 5], [6, 7]]
    >>> list(chunked(iterable, n))  # List lengths: 3, 3, 1
    [[1, 2, 3], [4, 5, 6], [7]]

    """
    iterator = iter(iterable)

    # Initialize a buffer to process the chunks while keeping
    # some back to fill any underfilled chunks
    min_buffer = (n - 1) * (n - 2)
    buffer = list(islice(iterator, min_buffer))

    # Append items until we have a completed chunk.  Each islice step
    # appends n items (via map's side effect) before producing a value,
    # so a full chunk is always buffered before being yielded.
    for _ in islice(map(buffer.append, iterator), n, None, n):
        yield buffer[:n]
        del buffer[:n]

    # Check if any chunks need additional processing
    if not buffer:
        return
    length = len(buffer)

    # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
    q, r = divmod(length, n)
    num_lists = q + (1 if r > 0 else 0)
    q, r = divmod(length, num_lists)
    full_size = q + (1 if r > 0 else 0)
    partial_size = full_size - 1
    num_full = length - partial_size * num_lists

    # Yield chunks of full size
    partial_start_idx = num_full * full_size
    if full_size > 0:
        for i in range(0, partial_start_idx, full_size):
            yield buffer[i : i + full_size]

    # Yield chunks of partial size
    if partial_size > 0:
        for i in range(partial_start_idx, length, partial_size):
            yield buffer[i : i + partial_size]
def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
    """A version of :func:`zip` that "broadcasts" any scalar
    (i.e., non-iterable) items into output tuples.

    >>> iterable_1 = [1, 2, 3]
    >>> iterable_2 = ['a', 'b', 'c']
    >>> scalar = '_'
    >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
    [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]

    The *scalar_types* keyword argument determines what types are considered
    scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
    treat strings and byte strings as iterable:

    >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
    [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]

    If the *strict* keyword argument is ``True``, then
    ``UnequalIterablesError`` will be raised if any of the iterables have
    different lengths.
    """

    def _is_scalar(obj):
        # Members of *scalar_types* broadcast even though they are iterable.
        if scalar_types and isinstance(obj, scalar_types):
            return True
        # Anything that does not support iteration is a scalar.
        try:
            iter(obj)
        except TypeError:
            return True
        return False

    if not objects:
        return

    # Pre-fill an output template with the scalar values; the iterable
    # slots are overwritten on every round.
    template = [None] * len(objects)
    iterators = []
    positions = []
    for index, obj in enumerate(objects):
        if _is_scalar(obj):
            template[index] = obj
        else:
            iterators.append(iter(obj))
            positions.append(index)

    # All-scalar input produces exactly one tuple.
    if not iterators:
        yield tuple(objects)
        return

    zipper = _zip_equal if strict else zip
    for values in zipper(*iterators):
        for index, value in zip(positions, values):
            template[index] = value
        yield tuple(template)
def unique_in_window(iterable, n, key=None):
    """Yield the items from *iterable* that haven't been seen recently.
    *n* is the size of the lookback window.

    >>> iterable = [0, 1, 0, 2, 3, 0]
    >>> n = 3
    >>> list(unique_in_window(iterable, n))
    [0, 1, 2, 3, 0]

    The *key* function, if provided, will be used to determine uniqueness:

    >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
    ['a', 'b', 'c', 'd', 'a']

    The items in *iterable* must be hashable.

    """
    if n <= 0:
        raise ValueError('n must be greater than 0')

    # The deque remembers the last n (keyed) items; the tally map lets us
    # test membership in the window in O(1).
    window = deque(maxlen=n)
    tallies = defaultdict(int)

    for item in iterable:
        # Evict the oldest entry's tally before it falls out of the deque.
        if len(window) == n:
            oldest = window[0]
            tallies[oldest] -= 1
            if not tallies[oldest]:
                del tallies[oldest]

        marker = item if key is None else key(item)
        if marker not in tallies:
            yield item
        tallies[marker] += 1
        window.append(marker)
def duplicates_everseen(iterable, key=None):
    """Yield duplicate elements after their first appearance.

    >>> list(duplicates_everseen('mississippi'))
    ['s', 'i', 's', 's', 'i', 'p', 'i']
    >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    # Hashable markers go in a set; unhashable ones fall back to a list
    # (linear scan) when the set membership test raises TypeError.
    seen_hashable = set()
    seen_unhashable = []

    for item in iterable:
        marker = item if key is None else key(item)
        try:
            if marker in seen_hashable:
                yield item
            else:
                seen_hashable.add(marker)
        except TypeError:
            if marker in seen_unhashable:
                yield item
            else:
                seen_unhashable.append(marker)
def duplicates_justseen(iterable, key=None):
    """Yields serially-duplicate elements after their first appearance.

    >>> list(duplicates_justseen('mississippi'))
    ['s', 's', 'p']
    >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    This function is analogous to :func:`unique_justseen`.

    """
    for _, group in groupby(iterable, key):
        # Discard the first occurrence of each run, then emit the repeats.
        next(group, None)
        yield from group
def classify_unique(iterable, key=None):
    """Classify each element in terms of its uniqueness.

    For each element in the input iterable, return a 3-tuple consisting of:

    1. The element itself
    2. ``False`` if the element is equal to the one preceding it in the input,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
    3. ``False`` if this element has been seen anywhere in the input before,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)

    >>> list(classify_unique('otto'))  # doctest: +NORMALIZE_WHITESPACE
    [('o', True, True),
     ('t', True, True),
     ('t', False, False),
     ('o', True, False)]

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    # Hashable markers are tracked in a set; unhashable ones in a list.
    seen_set = set()
    seen_list = []
    prev = None

    for index, element in enumerate(iterable):
        marker = element if key is None else key(element)
        # The first element always counts as "just seen" unique.
        just_seen_unique = index == 0 or prev != marker
        prev = marker

        ever_seen_unique = False
        try:
            if marker not in seen_set:
                seen_set.add(marker)
                ever_seen_unique = True
        except TypeError:
            # Unhashable marker: fall back to a linear scan.
            if marker not in seen_list:
                seen_list.append(marker)
                ever_seen_unique = True

        yield element, just_seen_unique, ever_seen_unique
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Returns both the smallest and largest items from an iterable
    or from two or more arguments.

    >>> minmax([3, 1, 5])
    (1, 5)

    >>> minmax(4, 2, 6)
    (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

    >>> minmax([5, 30], key=str)  # '30' sorts before '5'
    (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

    >>> minmax([], default=(0, 0))
    (0, 0)

    Otherwise ``ValueError`` is raised.

    This function makes a single pass over the input elements and takes care to
    minimize the number of comparisons made during processing.

    Note that unlike the builtin ``max`` function, which always returns the first
    item with the maximum value, this function may return another item when there are
    ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    iterable = (iterable_or_value, *others) if others else iterable_or_value

    it = iter(iterable)

    try:
        lo = hi = next(it)
    except StopIteration as exc:
        if default is _marker:
            raise ValueError(
                '`minmax()` argument is an empty iterable. '
                'Provide a `default` value to suppress this error.'
            ) from exc
        return default

    # Different branches depending on the presence of key. This saves a lot
    # of unimportant copies which would slow the "key=None" branch
    # significantly down.
    if key is None:
        # Items are consumed in pairs (zip_longest over the same iterator
        # twice); ordering each pair first means only 3 comparisons are
        # needed per 2 items instead of 4.  An odd final item is padded
        # with the first item, which cannot change the result.
        for x, y in zip_longest(it, it, fillvalue=lo):
            if y < x:
                x, y = y, x
            if x < lo:
                lo = x
            if hi < y:
                hi = y

    else:
        lo_key = hi_key = key(lo)

        for x, y in zip_longest(it, it, fillvalue=lo):
            # Cache the key values so each item's key is computed once.
            x_key, y_key = key(x), key(y)

            if y_key < x_key:
                x, y, x_key, y_key = y, x, y_key, x_key
            if x_key < lo_key:
                lo, lo_key = x, x_key
            if hi_key < y_key:
                hi, hi_key = y, y_key

    return lo, hi
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* with a combined size limited by
    *max_size*.

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10))
    [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If a *max_count* is supplied, the number of items per batch is also
    limited:

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10, max_count = 2))
    [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    If a *get_len* function is supplied, use that instead of :func:`len` to
    determine item size.

    If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
    than *max_size*. Otherwise, allow single items to exceed *max_size*.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    current = []
    current_size = 0

    for item in iterable:
        item_size = get_len(item)
        if strict and item_size > max_size:
            raise ValueError('item size exceeds maximum size')

        # Flush the batch-in-progress when adding this item would
        # overflow the size budget or the item count.
        would_overflow = (
            current_size + item_size > max_size
            or len(current) == max_count
        )
        if current and would_overflow:
            yield tuple(current)
            current = []
            current_size = 0

        current.append(item)
        current_size += item_size

    # Emit any leftover items as a final batch.
    if current:
        yield tuple(current)
def gray_product(*iterables):
    """Like :func:`itertools.product`, but return tuples in an order such
    that only one element in the generated tuple changes from one iteration
    to the next.

    >>> list(gray_product('AB','CD'))
    [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    This function consumes all of the input iterables before producing output.
    If any of the input iterables have fewer than two items, ``ValueError``
    is raised.

    For information on the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    all_iterables = tuple(tuple(x) for x in iterables)
    iterable_count = len(all_iterables)
    # Algorithm H requires each coordinate to have at least two values.
    for iterable in all_iterables:
        if len(iterable) < 2:
            raise ValueError("each iterable must have two or more items")

    # This is based on "Algorithm H" from section 7.2.1.1, page 20.
    # a holds the indexes of the source iterables for the n-tuple to be yielded
    # f is the array of "focus pointers"
    # o is the array of "directions"
    a = [0] * iterable_count
    f = list(range(iterable_count + 1))
    o = [1] * iterable_count
    while True:
        yield tuple(all_iterables[i][a[i]] for i in range(iterable_count))
        # j is the coordinate to change next; f[0] == iterable_count
        # signals that every tuple has been produced.
        j = f[0]
        f[0] = 0
        if j == iterable_count:
            break
        # Step coordinate j in its current direction; reverse the
        # direction and update the focus pointers at either end.
        a[j] = a[j] + o[j]
        if a[j] == 0 or a[j] == len(all_iterables[j]) - 1:
            o[j] = -o[j]
            f[j] = f[j + 1]
            f[j + 1] = j + 1
def partial_product(*iterables):
    """Yields tuples containing one item from each iterator, with subsequent
    tuples changing a single item at a time by advancing each iterator until it
    is exhausted. This sequence guarantees every value in each iterable is
    output at least once without generating all possible combinations.

    This may be useful, for example, when testing an expensive function.

    >>> list(partial_product('AB', 'C', 'DEF'))
    [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]
    """
    iterators = [iter(obj) for obj in iterables]

    # Seed the working tuple with each iterator's first item; any empty
    # input means there is nothing to yield at all.
    try:
        current = [next(iterator) for iterator in iterators]
    except StopIteration:
        return
    yield tuple(current)

    # Advance one position at a time, left to right, emitting a tuple
    # after each single-item change.
    for position, iterator in enumerate(iterators):
        for value in iterator:
            current[position] = value
            yield tuple(current)
def takewhile_inclusive(predicate, iterable):
    """A variant of :func:`takewhile` that yields one additional element.

    >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
    [1, 4, 6]

    :func:`takewhile` would return ``[1, 4]``.
    """
    # Unlike takewhile, the item that fails the predicate is still
    # yielded (before iteration stops).
    for item in iterable:
        yield item
        if not predicate(item):
            return
def outer_product(func, xs, ys, *args, **kwargs):
    """A generalized outer product that applies a binary function to all
    pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
    columns.

    Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.

    Multiplication table:

    >>> list(outer_product(mul, range(1, 4), range(1, 6)))
    [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]

    Cross tabulation:

    >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
    >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
    >>> pair_counts = Counter(zip(xs, ys))
    >>> count_rows = lambda x, y: pair_counts[x, y]
    >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
    [(2, 3, 0), (1, 0, 4)]

    Usage with ``*args`` and ``**kwargs``:

    >>> animals = ['cat', 'wolf', 'mouse']
    >>> list(outer_product(min, animals, animals, key=len))
    [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
    """
    ys = tuple(ys)

    # Close over the extra arguments so the per-pair call site stays simple.
    def cell(x, y):
        return func(x, y, *args, **kwargs)

    # Compute every cell in row-major order, then regroup into rows of
    # len(ys) cells each.
    all_cells = starmap(cell, product(xs, ys))
    return batched(all_cells, n=len(ys))
def iter_suppress(iterable, *exceptions):
    """Yield each of the items from *iterable*. If the iteration raises one of
    the specified *exceptions*, that exception will be suppressed and iteration
    will stop.

    >>> from itertools import chain
    >>> def breaks_at_five(x):
    ...     while True:
    ...         if x >= 5:
    ...             raise RuntimeError
    ...         yield x
    ...         x += 1
    >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
    >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
    >>> list(chain(it_1, it_2))
    [1, 2, 3, 4, 2, 3, 4]
    """
    # suppress() swallows only the listed exception types; leaving the
    # with-block ends the generator cleanly.
    with suppress(*exceptions):
        yield from iterable
def filter_map(func, iterable):
    """Apply *func* to every element of *iterable*, yielding only those which
    are not ``None``.

    >>> elems = ['1', 'a', '2', 'b', '3']
    >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
    [1, 2, 3]
    """
    # Note: only None results are dropped; falsy values like 0 pass through.
    return (result for result in map(func, iterable) if result is not None)
def powerset_of_sets(iterable):
    """Yields all possible subsets of the iterable.

    >>> list(powerset_of_sets([1, 2, 3]))  # doctest: +SKIP
    [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
    >>> list(powerset_of_sets([1, 1, 0]))  # doctest: +SKIP
    [set(), {1}, {0}, {0, 1}]

    :func:`powerset_of_sets` takes care to minimize the number
    of hash operations performed.
    """
    # Wrap each distinct item in a singleton frozenset (deduplicating in
    # order) so each item is hashed only a small number of times.
    singletons = tuple(
        dict.fromkeys(frozenset(pair) for pair in zip(iterable))
    )
    # Subsets are produced smallest-first, each as the union of a
    # combination of singletons.
    subset_groups = (
        starmap(set().union, combinations(singletons, size))
        for size in range(len(singletons) + 1)
    )
    return chain.from_iterable(subset_groups)
def join_mappings(**field_to_map):
    """
    Joins multiple mappings together using their common keys.

    >>> user_scores = {'elliot': 50, 'claris': 60}
    >>> user_times = {'elliot': 30, 'claris': 40}
    >>> join_mappings(score=user_scores, time=user_times)
    {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    joined = {}

    # Each keyword name becomes a field in the per-key result dicts.
    for field_name, mapping in field_to_map.items():
        for key, value in mapping.items():
            joined.setdefault(key, {})[field_name] = value

    return joined
def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.

    NOTE(review): *v1* and *v2* are each iterated more than once below, so
    they are presumably re-iterable sequences, not one-shot iterators —
    confirm against callers.
    """
    real = attrgetter('real')
    imag = attrgetter('imag')
    # (a+bi)(c+di) = (ac - bd) + (ad + bc)i.  Each component is expressed
    # as one long dot product so _fsumprod can accumulate it accurately:
    #   real part: [a..., -b...] . [c..., d...]
    #   imag part: [a...,  b...] . [d..., c...]
    r1 = chain(map(real, v1), map(neg, map(imag, v1)))
    r2 = chain(map(real, v2), map(imag, v2))
    i1 = chain(map(real, v1), map(imag, v1))
    i2 = chain(map(imag, v2), map(real, v2))
    return complex(_fsumprod(r1, r2), _fsumprod(i1, i2))
def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the corresponding transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
    >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
    >>> all(map(cmath.isclose, dft(xarr), Xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`idft` for the inverse Discrete Fourier Transform.
    """
    N = len(xarr)
    # Precompute the N complex roots of unity e^(-2*pi*i*n/N).
    roots_of_unity = [e ** (n / N * tau * -1j) for n in range(N)]
    for k in range(N):
        # X_k = sum_n x_n * e^(-2*pi*i*k*n/N); the exponent index k*n is
        # reduced mod N because the roots of unity are periodic.
        coeffs = [roots_of_unity[k * n % N] for n in range(N)]
        yield _complex_sumprod(xarr, coeffs)
def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the corresponding
    inverse-transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
    >>> all(map(cmath.isclose, idft(Xarr), xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`dft` for the Discrete Fourier Transform.
    """
    N = len(Xarr)
    # Same roots-of-unity table as dft(), but with a positive exponent.
    roots_of_unity = [e ** (n / N * tau * 1j) for n in range(N)]
    for k in range(N):
        coeffs = [roots_of_unity[k * n % N] for n in range(N)]
        # The inverse transform is normalized by 1/N.
        yield _complex_sumprod(Xarr, coeffs) / N
def doublestarmap(func, iterable):
    """Apply *func* to every item of *iterable* by dictionary unpacking
    the item into *func*.

    The difference between :func:`itertools.starmap` and :func:`doublestarmap`
    parallels the distinction between ``func(*a)`` and ``func(**a)``.

    >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
    >>> list(doublestarmap(lambda a, b: a + b, iterable))
    [3, 100]

    ``TypeError`` will be raised if *func*'s signature doesn't match the
    mapping contained in *iterable* or if *iterable* does not contain mappings.
    """
    # Each item is expected to be a mapping whose keys name func's parameters.
    return (func(**kwargs) for kwargs in iterable)
def _nth_prime_bounds(n):
    """Bounds for the nth prime (counting from 1): lb < p_n < ub."""
    # At and above 688,383, the lb/ub spread is under 0.003 * p_n.

    if n < 1:
        raise ValueError

    if n < 6:
        # Small cases: p_n is at least n and below 2.25 * n for n in 1..5.
        return (n, 2.25 * n)

    # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
    upper_bound = n * log(n * log(n))
    lower_bound = upper_bound - n
    if n >= 688_383:
        # A tighter upper bound that is valid for large n.
        upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))

    return lower_bound, upper_bound
def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

    >>> nth_prime(0)
    2
    >>> nth_prime(100)
    547

    If *approximate* is set to True, will return a prime close to the
    nth prime. The estimation is much faster than computing
    an exact result.

    >>> nth_prime(200_000_000, approximate=True)  # Exact result is 4222234763
    4217820427

    """
    # The bounds helper counts primes starting from 1, hence n + 1.
    lb, ub = _nth_prime_bounds(n + 1)

    if not approximate or n <= 1_000_000:
        # Sieve up to the upper bound and pick out the nth prime exactly.
        return nth(sieve(ceil(ub)), n)

    # Search from the midpoint and return the first odd prime
    odd = floor((lb + ub) / 2) | 1
    return first_true(count(odd, step=2), pred=is_prime)
def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

    >>> argmin('efghabcdijkl')
    4
    >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
    3

    If *key* is given, comparisons are made on the transformed values:

    >>> argmin([5, 30], key=str)  # '30' sorts before '5'
    1
    """
    # Pair each value with its position and let min() do the scan; ties go
    # to the earliest index because enumerate yields positions in order.
    values = iterable if key is None else map(key, iterable)
    best_index, _ = min(enumerate(values), key=itemgetter(1))
    return best_index
def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

    >>> argmax('abcdefghabcd')
    7
    >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
    3

    If *key* is given, comparisons are made on the transformed values:

    >>> argmax([30, 5], key=str)  # '5' sorts after '30'
    1
    """
    # Pair each value with its position and let max() do the scan; ties go
    # to the earliest index because enumerate yields positions in order.
    values = iterable if key is None else map(key, iterable)
    best_index, _ = max(enumerate(values), key=itemgetter(1))
    return best_index