Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/more_itertools/more.py: 18%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import math
3from collections import Counter, defaultdict, deque
4from collections.abc import Sequence
5from contextlib import suppress
6from functools import cached_property, partial, reduce, wraps
7from heapq import heapify, heapreplace
8from itertools import (
9 chain,
10 combinations,
11 compress,
12 count,
13 cycle,
14 dropwhile,
15 groupby,
16 islice,
17 permutations,
18 repeat,
19 starmap,
20 takewhile,
21 tee,
22 zip_longest,
23 product,
24)
25from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
26from math import ceil
27from queue import Empty, Queue
28from random import random, randrange, shuffle, uniform
29from operator import (
30 attrgetter,
31 getitem,
32 is_not,
33 itemgetter,
34 lt,
35 mul,
36 neg,
37 sub,
38 gt,
39)
40from sys import maxsize
41from time import monotonic
42from threading import Lock
44from .recipes import (
45 _marker,
46 consume,
47 first_true,
48 flatten,
49 is_prime,
50 nth,
51 powerset,
52 sieve,
53 take,
54 unique_everseen,
55 all_equal,
56 batched,
57)
# Public API exported by ``from more_itertools import *``.
# NOTE: a few entries are intentionally out of alphabetical order
# (e.g. 'classify_unique'); do not re-sort, as the list's value is public.
__all__ = [
    'AbortThread',
    'SequenceView',
    'adjacent',
    'all_unique',
    'always_iterable',
    'always_reversible',
    'argmax',
    'argmin',
    'bucket',
    'callback_iter',
    'chunked',
    'chunked_even',
    'circular_shifts',
    'collapse',
    'combination_index',
    'combination_with_replacement_index',
    'concurrent_tee',
    'consecutive_groups',
    'constrained_batches',
    'consumer',
    'count_cycle',
    'countable',
    'derangements',
    'dft',
    'difference',
    'distinct_combinations',
    'distinct_permutations',
    'distribute',
    'divide',
    'doublestarmap',
    'duplicates_everseen',
    'duplicates_justseen',
    'classify_unique',
    'exactly_n',
    'extract',
    'filter_except',
    'filter_map',
    'first',
    'gray_product',
    'groupby_transform',
    'ichunked',
    'iequals',
    'idft',
    'ilen',
    'interleave',
    'interleave_evenly',
    'interleave_longest',
    'interleave_randomly',
    'intersperse',
    'is_sorted',
    'islice_extended',
    'iterate',
    'iter_suppress',
    'join_mappings',
    'last',
    'locate',
    'longest_common_prefix',
    'lstrip',
    'make_decorator',
    'map_except',
    'map_if',
    'map_reduce',
    'mark_ends',
    'minmax',
    'nth_or_last',
    'nth_permutation',
    'nth_prime',
    'nth_product',
    'nth_combination_with_replacement',
    'numeric_range',
    'one',
    'only',
    'outer_product',
    'padded',
    'partial_product',
    'partitions',
    'peekable',
    'permutation_index',
    'powerset_of_sets',
    'product_index',
    'raise_',
    'repeat_each',
    'repeat_last',
    'replace',
    'rlocate',
    'rstrip',
    'run_length',
    'sample',
    'seekable',
    'serialize',
    'set_partitions',
    'side_effect',
    'sliced',
    'sort_together',
    'split_after',
    'split_at',
    'split_before',
    'split_into',
    'split_when',
    'spy',
    'stagger',
    'strip',
    'strictly_n',
    'substrings',
    'substrings_indexes',
    'takewhile_inclusive',
    'time_limited',
    'unique_in_window',
    'unique_to_each',
    'unzip',
    'value_chain',
    'windowed',
    'windowed_complete',
    'with_iter',
    'zip_broadcast',
    'zip_offset',
]
# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _fsumprod
except ImportError:  # pragma: no cover
    # Extended precision algorithms from T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8). Code: mul12()
    #
    # NOTE: the arithmetic below relies on IEEE-754 rounding at every
    # individual operation.  Do not "simplify" these expressions
    # algebraically (e.g. ``t - (t - x)`` is NOT ``x``) — the rounding
    # error being captured would be lost.

    def dl_split(x: float):
        "Split a float into two half-precision components."
        t = x * 134217729.0  # Veltkamp constant = 2.0 ** 27 + 1
        hi = t - (t - x)
        lo = x - hi
        return hi, lo

    def dl_mul(x, y):
        "Lossless multiplication."
        # Returns (z, zz) such that x * y == z + zz exactly, with z the
        # rounded product and zz the rounding error.
        xx_hi, xx_lo = dl_split(x)
        yy_hi, yy_lo = dl_split(y)
        p = xx_hi * yy_hi
        q = xx_hi * yy_lo + xx_lo * yy_hi
        z = p + q
        zz = p - z + q + xx_lo * yy_lo
        return z, zz

    def _fsumprod(p, q):
        # Accurate dot product: feed every (product, error) pair into
        # fsum, which performs an exact float summation.
        return fsum(chain.from_iterable(map(dl_mul, p, q)))
def chunked(iterable, n, strict=False):
    """Break *iterable* into lists of length *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
    [[1, 2, 3], [4, 5, 6]]

    By default, the last yielded list will have fewer than *n* elements
    if the length of *iterable* is not divisible by *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
    [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    If the length of *iterable* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    list is yielded.

    """
    # Repeatedly take *n* items until an empty chunk signals exhaustion.
    source = iter(iterable)
    chunks = iter(partial(take, n, source), [])

    if not strict:
        return chunks

    if n is None:
        raise ValueError('n must not be None when using strict mode.')

    def _validated():
        # Re-yield each chunk, rejecting any short one.
        for chunk in chunks:
            if len(chunk) != n:
                raise ValueError('iterable is not divisible by n.')
            yield chunk

    return _validated()
def first(iterable, default=_marker):
    """Return the first item of *iterable*, or *default* if *iterable* is
    empty.

    >>> first([0, 1, 2, 3])
    0
    >>> first([], 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.

    :func:`first` is useful when you have a generator of expensive-to-retrieve
    values and want any arbitrary one. It is marginally shorter than
    ``next(iter(iterable), default)``.

    """
    # The loop body runs at most once: it returns the first item, or
    # falls through if the iterable produced nothing.
    for value in iterable:
        return value
    if default is not _marker:
        return default
    raise ValueError(
        'first() was called on an empty iterable, '
        'and no default value was provided.'
    )
def last(iterable, default=_marker):
    """Return the last item of *iterable*, or *default* if *iterable* is
    empty.

    >>> last([0, 1, 2, 3])
    3
    >>> last([], 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    try:
        # Fast paths: direct indexing for sequences, then reverse
        # iteration; fall back to draining into a one-item deque.
        if isinstance(iterable, Sequence):
            return iterable[-1]
        # Work around https://bugs.python.org/issue38525
        if getattr(iterable, '__reversed__', None):
            return next(reversed(iterable))
        return deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        # Each fast path signals emptiness with a different exception.
        if default is not _marker:
            return default
        raise ValueError(
            'last() was called on an empty iterable, '
            'and no default value was provided.'
        )
def nth_or_last(iterable, n, default=_marker):
    """Return the nth or the last item of *iterable*,
    or *default* if *iterable* is empty.

    >>> nth_or_last([0, 1, 2, 3], 2)
    2
    >>> nth_or_last([0, 1], 2)
    1
    >>> nth_or_last([], 0, 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    # Take at most n + 1 items; the final one is either item n
    # (zero-indexed) or the last item of a shorter iterable.
    head = islice(iterable, n + 1)
    return last(head, default=default)
class peekable:
    """Wrap an iterator to allow lookahead and prepending elements.

    Call :meth:`peek` on the result to get the value that will be returned
    by :func:`next`. This won't advance the iterator:

    >>> p = peekable(['a', 'b'])
    >>> p.peek()
    'a'
    >>> next(p)
    'a'

    Pass :meth:`peek` a default value to return that instead of raising
    ``StopIteration`` when the iterator is exhausted.

    >>> p = peekable([])
    >>> p.peek('hi')
    'hi'

    peekables also offer a :meth:`prepend` method, which "inserts" items
    at the head of the iterable:

    >>> p = peekable([1, 2, 3])
    >>> p.prepend(10, 11, 12)
    >>> next(p)
    10
    >>> p.peek()
    11
    >>> list(p)
    [11, 12, 1, 2, 3]

    peekables can be indexed. Index 0 is the item that will be returned by
    :func:`next`, index 1 is the item after that, and so on:
    The values up to the given index will be cached.

    >>> p = peekable(['a', 'b', 'c', 'd'])
    >>> p[0]
    'a'
    >>> p[1]
    'b'
    >>> next(p)
    'a'

    Negative indexes are supported, but be aware that they will cache the
    remaining items in the source iterator, which may require significant
    storage.

    To check whether a peekable is exhausted, check its truth value:

    >>> p = peekable(['a', 'b'])
    >>> if p:  # peekable has items
    ...     list(p)
    ['a', 'b']
    >>> if not p:  # peekable is exhausted
    ...     list(p)
    []

    """

    def __init__(self, iterable):
        # _it: the wrapped source iterator.
        # _cache: items pulled ahead of the consumer (via peek/indexing)
        # or pushed back (via prepend); always served before _it.
        self._it = iter(iterable)
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        # Truthy iff at least one more item can be produced.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        """Return the item that will be next returned from ``next()``.

        Return ``default`` if there are no items left. If ``default`` is not
        provided, raise ``StopIteration``.

        """
        # Pull one item into the cache (if needed) and report it without
        # consuming it.
        if not self._cache:
            try:
                self._cache.append(next(self._it))
            except StopIteration:
                if default is _marker:
                    raise
                return default
        return self._cache[0]

    def prepend(self, *items):
        """Stack up items to be the next ones returned from ``next()`` or
        ``self.peek()``. The items will be returned in
        first in, first out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        It is possible, by prepending items, to "resurrect" a peekable that
        previously raised ``StopIteration``.

            >>> p = peekable([])
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration
            >>> p.prepend(1)
            >>> next(p)
            1
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration

        """
        # extendleft reverses its argument, so reverse first to keep
        # FIFO order.
        self._cache.extendleft(reversed(items))

    def __next__(self):
        # Serve cached/prepended items before touching the source.
        if self._cache:
            return self._cache.popleft()

        return next(self._it)

    def _get_slice(self, index):
        # Normalize the slice's arguments
        step = 1 if (index.step is None) else index.step
        if step > 0:
            start = 0 if (index.start is None) else index.start
            stop = maxsize if (index.stop is None) else index.stop
        elif step < 0:
            start = -1 if (index.start is None) else index.start
            stop = (-maxsize - 1) if (index.stop is None) else index.stop
        else:
            raise ValueError('slice step cannot be zero')

        # If either the start or stop index is negative, we'll need to cache
        # the rest of the iterable in order to slice from the right side.
        if (start < 0) or (stop < 0):
            self._cache.extend(self._it)
        # Otherwise we'll need to find the rightmost index and cache to that
        # point.
        else:
            n = min(max(start, stop) + 1, maxsize)
            cache_len = len(self._cache)
            if n >= cache_len:
                self._cache.extend(islice(self._it, n - cache_len))

        # Slice the cache; items stay cached (not consumed).
        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        # For a single index, cache everything (negative index) or just
        # enough items to cover it, then index into the cache.
        cache_len = len(self._cache)
        if index < 0:
            self._cache.extend(self._it)
        elif index >= cache_len:
            self._cache.extend(islice(self._it, index + 1 - cache_len))

        return self._cache[index]
def consumer(func):
    """Decorator that automatically advances a PEP-342-style "reverse iterator"
    to its first yield point so you don't have to call ``next()`` on it
    manually.

    >>> @consumer
    ... def tally():
    ...     i = 0
    ...     while True:
    ...         print('Thing number %s is %s.' % (i, (yield)))
    ...         i += 1
    ...
    >>> t = tally()
    >>> t.send('red')
    Thing number 0 is red.
    >>> t.send('fish')
    Thing number 1 is fish.

    Without the decorator, you would have to call ``next(t)`` before
    ``t.send()`` could be used.

    """

    @wraps(func)
    def primed(*args, **kwargs):
        # Build the generator, then advance it to its first ``yield`` so
        # the caller can use ``send()`` immediately.
        generator = func(*args, **kwargs)
        next(generator)
        return generator

    return primed
def ilen(iterable):
    """Return the number of items in *iterable*.

    For example, there are 168 prime numbers below 1,000:

    >>> ilen(sieve(1000))
    168

    Equivalent to, but faster than::

        def ilen(iterable):
            count = 0
            for _ in iterable:
                count += 1
            return count

    This fully consumes the iterable, so handle with care.

    """
    # Pair each item with a running counter and discard the pairs at C
    # speed (a zero-length deque consumes without storing).  The counter's
    # next value is then exactly the number of items consumed.
    counter = count()
    deque(zip(iterable, counter), maxlen=0)
    return next(counter)
def iterate(func, start):
    """Return ``start``, ``func(start)``, ``func(func(start))``, ...

    Produces an infinite iterator. To add a stopping condition,
    use :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`:.

    >>> take(10, iterate(lambda x: 2*x, 1))
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2
    >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10)))
    [10, 5, 16, 8, 4, 2, 1]

    """
    value = start
    # If *func* raises StopIteration, end the stream gracefully instead
    # of propagating it out of the generator.
    with suppress(StopIteration):
        while True:
            yield value
            value = func(value)
def with_iter(context_manager):
    """Wrap an iterable in a ``with`` statement, so it closes once exhausted.

    For example, this will close the file when the iterator is exhausted::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Any context manager which returns an iterable is a candidate for
    ``with_iter``.

    """
    # The context manager stays open for the generator's lifetime and is
    # exited when iteration finishes (or the generator is closed).
    with context_manager as iterable:
        for item in iterable:
            yield item
def one(iterable, too_short=None, too_long=None):
    """Return the first item from *iterable*, which is expected to contain only
    that item. Raise an exception if *iterable* is empty or has more than one
    item.

    :func:`one` is useful for ensuring that an iterable contains only one item.
    For example, it can be used to retrieve the result of a database query
    that is expected to return a single row.

    If *iterable* is empty, ``ValueError`` will be raised. You may specify a
    different exception with the *too_short* keyword:

    >>> it = []
    >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too few items in iterable (expected 1)
    >>> too_short = IndexError('too few items')
    >>> one(it, too_short=too_short)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    IndexError: too few items

    Similarly, if *iterable* contains more than one item, ``ValueError`` will
    be raised. You may specify a different exception with the *too_long*
    keyword:

    >>> it = ['too', 'many']
    >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: Expected exactly one item in iterable, but got 'too',
    'many', and perhaps more.
    >>> too_long = RuntimeError
    >>> one(it, too_long=too_long)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    RuntimeError

    Note that :func:`one` attempts to advance *iterable* twice to ensure there
    is only one item. See :func:`spy` or :func:`peekable` to check iterable
    contents less destructively.

    """
    iterator = iter(iterable)

    # Grab the first item; the for/else runs the else clause only when
    # the iterator was empty.
    for first in iterator:
        break
    else:
        raise too_short or ValueError('too few items in iterable (expected 1)')

    # A second item means the "exactly one" contract is violated.
    for second in iterator:
        msg = (
            f'Expected exactly one item in iterable, but got {first!r}, '
            f'{second!r}, and perhaps more.'
        )
        raise too_long or ValueError(msg)

    return first
def raise_(exception, *args):
    """Call *exception* with *args* and raise the result.

    This lets callbacks that must be expressions (e.g. lambdas) raise —
    see the default *too_short*/*too_long* handlers in :func:`strictly_n`.
    """
    raise exception(*args)
def strictly_n(iterable, n, too_short=None, too_long=None):
    """Validate that *iterable* has exactly *n* items and return them if
    it does. If it has fewer than *n* items, call function *too_short*
    with the actual number of items. If it has more than *n* items, call
    function *too_long* with the number ``n + 1``.

    >>> iterable = ['a', 'b', 'c', 'd']
    >>> n = 4
    >>> list(strictly_n(iterable, n))
    ['a', 'b', 'c', 'd']

    Note that the returned iterable must be consumed in order for the check to
    be made.

    By default, *too_short* and *too_long* are functions that raise
    ``ValueError``.

    >>> list(strictly_n('ab', 3))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too few items in iterable (got 2)

    >>> list(strictly_n('abc', 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too many items in iterable (got at least 3)

    You can instead supply functions that do something else.
    *too_short* will be called with the number of items in *iterable*.
    *too_long* will be called with `n + 1`.

    >>> def too_short(item_count):
    ...     raise RuntimeError
    >>> it = strictly_n('abcd', 6, too_short=too_short)
    >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    RuntimeError

    >>> def too_long(item_count):
    ...     print('The boss is going to hear about this')
    >>> it = strictly_n('abcdef', 4, too_long=too_long)
    >>> list(it)
    The boss is going to hear about this
    ['a', 'b', 'c', 'd']

    """
    # Default handlers raise ValueError via raise_ (lambdas can't raise).
    if too_short is None:
        too_short = lambda item_count: raise_(
            ValueError,
            f'Too few items in iterable (got {item_count})',
        )

    if too_long is None:
        too_long = lambda item_count: raise_(
            ValueError,
            f'Too many items in iterable (got at least {item_count})',
        )

    iterator = iter(iterable)

    # Pass through at most n items, counting as we go.
    yielded = 0
    for item in islice(iterator, n):
        yield item
        yielded += 1

    if yielded < n:
        too_short(yielded)
        return

    # Probe for one extra item; its presence means the input was too long.
    sentinel = object()
    if next(iterator, sentinel) is not sentinel:
        too_long(n + 1)
def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

    >>> sorted(distinct_permutations([1, 0, 1]))
    [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

    >>> sorted(distinct_permutations([1, 0, 1], r=2))
    [(0, 1), (1, 0), (1, 1)]
    >>> sorted(distinct_permutations(range(3), r=2))
    [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

    >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
    [
        (1, True, '3'),
        (1, '3', True),
        ('3', 1, True)
    ]
    >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
    [
        (1, 2, '3'),
        (1, '3', 2),
        (2, 1, '3'),
        (2, '3', 1),
        ('3', 1, 2),
        ('3', 2, 1)
    ]
    """

    # Algorithm: https://w.wiki/Qai
    # NOTE: ``size`` is a closure variable assigned below, before either
    # generator is first advanced.
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than i such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above
    def _partial(A, r):
        # Split A into the first r items and the last r items
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    try:
        items.sort()
        sortable = True
    except TypeError:
        # Unsortable input: permute indexes instead of items, then map
        # each index back to one of its equal items via a cycle.
        sortable = False

        indices_dict = defaultdict(list)

        # Equal items share a single canonical index (items.index returns
        # the first occurrence), so equal items collapse together.
        for item in items:
            indices_dict[items.index(item)].append(item)

        indices = [items.index(item) for item in items]
        indices.sort()

        equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

        def permuted_items(permuted_indices):
            return tuple(
                next(equivalent_items[index]) for index in permuted_indices
            )

    size = len(items)
    if r is None:
        r = size

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if 0 < r <= size:
        if sortable:
            return algorithm(items)
        else:
            return (
                permuted_items(permuted_indices)
                for permuted_indices in algorithm(indices)
            )

    # r == 0 yields the single empty permutation; r out of range yields
    # nothing (matching itertools.permutations).
    return iter(() if r else ((),))
def derangements(iterable, r=None):
    """Yield successive derangements of the elements in *iterable*.

    A derangement is a permutation in which no element appears at its original
    index. In other words, a derangement is a permutation that has no fixed
    points.

    Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
    The code below outputs all of the different ways to assign gift recipients
    such that nobody is assigned to himself or herself:

    >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
    ...     print(', '.join(d))
    Bob, Alice, Dave, Carol
    Bob, Carol, Dave, Alice
    Bob, Dave, Alice, Carol
    Carol, Alice, Dave, Bob
    Carol, Dave, Alice, Bob
    Carol, Dave, Bob, Alice
    Dave, Alice, Bob, Carol
    Dave, Carol, Alice, Bob
    Dave, Carol, Bob, Alice

    If *r* is given, only the *r*-length derangements are yielded.

    >>> sorted(derangements(range(3), 2))
    [(1, 0), (1, 2), (2, 0)]
    >>> sorted(derangements([0, 2, 3], 2))
    [(2, 0), (2, 3), (3, 0)]

    Elements are treated as unique based on their position, not on their value.

    Consider the Secret Santa example with two *different* people who have
    the *same* name. Then there are two valid gift assignments even though
    it might appear that a person is assigned to themselves:

    >>> names = ['Alice', 'Bob', 'Bob']
    >>> list(derangements(names))
    [('Bob', 'Bob', 'Alice'), ('Bob', 'Alice', 'Bob')]

    To avoid confusion, make the inputs distinct:

    >>> deduped = [f'{name}{index}' for index, name in enumerate(names)]
    >>> list(derangements(deduped))
    [('Bob1', 'Bob2', 'Alice0'), ('Bob2', 'Alice0', 'Bob1')]

    The number of derangements of a set of size *n* is known as the
    "subfactorial of n". For n > 0, the subfactorial is:
    ``round(math.factorial(n) / math.e)``.

    References:

    * Article: https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics
    * Sizes: https://oeis.org/A000166
    """
    pool = tuple(iterable)
    # Permute positions rather than values so that equal elements are
    # still treated as distinct, then keep only the permutations with no
    # fixed point (no position mapped to itself).
    return (
        tuple(pool[i] for i in indices)
        for indices in permutations(range(len(pool)), r)
        if all(position != i for position, i in enumerate(indices))
    )
def intersperse(e, iterable, n=1):
    """Intersperse filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

    >>> list(intersperse('!', [1, 2, 3, 4, 5]))
    [1, '!', 2, '!', 3, '!', 4, '!', 5]

    >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
    [1, 2, None, 3, 4, None, 5]

    """
    if n == 0:
        raise ValueError('n must be > 0')

    if n == 1:
        # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2...
        # dropping the leading filler -> x_0, e, x_1, e, x_2...
        return islice(interleave(repeat(e), iterable), 1, None)

    # interleave the filler chunks [e] with n-item chunks of the input,
    # drop the leading filler chunk, and flatten:
    # [x_0, x_1], [e], [x_2, x_3], ... -> x_0, x_1, e, x_2, x_3, ...
    groups = chunked(iterable, n)
    mixed = interleave(repeat([e]), groups)
    return flatten(islice(mixed, 1, None))
def unique_to_each(*iterables):
    """Return the elements from each of the input iterables that aren't in the
    other input iterables.

    For example, suppose you have a set of packages, each with a set of
    dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    If you remove one package, which dependencies can also be removed?

    If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
    associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed
    for ``pkg_2``, and ``D`` is only needed for ``pkg_3``::

        >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
        [['A'], ['C'], ['D']]

    If there are duplicates in one input iterable that aren't in the others
    they will be duplicated in the output. Input order is preserved::

        >>> unique_to_each("mississippi", "missouri")
        [['p', 'p'], ['o', 'u', 'r']]

    It is assumed that the elements of each iterable are hashable.

    """
    # Materialize each input so it can be scanned twice.
    groups = [list(it) for it in iterables]

    # Count, for each element, how many *distinct* inputs contain it.
    membership = Counter(chain.from_iterable(set(group) for group in groups))
    exclusives = {element for element, n in membership.items() if n == 1}

    # Keep (with duplicates, in order) the elements unique to each input.
    return [[item for item in group if item in exclusives] for group in groups]
def windowed(seq, n, fillvalue=None, step=1):
    """Return a sliding window of width *n* over the given iterable.

    >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
    >>> list(all_windows)
    [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    When the window is larger than the iterable, *fillvalue* is used in place
    of missing values:

    >>> list(windowed([1, 2, 3], 4))
    [(1, 2, 3, None)]

    Each window will advance in increments of *step*:

    >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
    [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler items
    to the left:

    >>> iterable = [1, 2, 3, 4]
    >>> n = 3
    >>> padding = [None] * (n - 1)
    >>> list(windowed(chain(padding, iterable), 3))
    [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    # Validation order matters: n is checked (and the n == 0 case handled)
    # before step is validated.
    if n < 0:
        raise ValueError('n must be >= 0')
    if n == 0:
        yield ()
        return
    if step < 1:
        raise ValueError('step must be >= 1')

    iterator = iter(seq)

    # Generate first window
    window = deque(islice(iterator, n), maxlen=n)

    # Deal with the first window not being full
    if not window:
        return
    if len(window) < n:
        yield tuple(window) + ((fillvalue,) * (n - len(window)))
        return
    yield tuple(window)

    # Create the filler for the next windows. The padding ensures
    # we have just enough elements to fill the last window.
    padding = (fillvalue,) * (n - 1 if step >= n else step - 1)
    # ``filler`` is a lazy map: each time it is advanced, one element from
    # the (padded) source is appended to ``window``, sliding the deque
    # (maxlen=n) forward by one.
    filler = map(window.append, chain(iterator, padding))

    # Generate the rest of the windows
    # islice advances ``filler`` by ``step`` appends between yields, so the
    # deque has slid ``step`` positions each time we snapshot it.
    for _ in islice(filler, step - 1, None, step):
        yield tuple(window)
def substrings(iterable):
    """Yield all of the substrings of *iterable*.

    >>> [''.join(s) for s in substrings('more')]
    ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Note that non-string iterables can also be subdivided.

    >>> list(substrings([0, 1, 2]))
    [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    Like subslices() but returns tuples instead of lists
    and returns the shortest substrings first.

    """
    seq = tuple(iterable)
    total = len(seq)
    # Shortest substrings first; within each length, left to right.
    for length in range(1, total + 1):
        for start in range(total - length + 1):
            yield seq[start : start + length]
def substrings_indexes(seq, reverse=False):
    """Yield all substrings and their positions in *seq*

    The items yielded will be a tuple of the form ``(substr, i, j)``, where
    ``substr == seq[i:j]``.

    This function only works for iterables that support slicing, such as
    ``str`` objects.

    >>> for item in substrings_indexes('more'):
    ...     print(item)
    ('m', 0, 1)
    ('o', 1, 2)
    ('r', 2, 3)
    ('e', 3, 4)
    ('mo', 0, 2)
    ('or', 1, 3)
    ('re', 2, 4)
    ('mor', 0, 3)
    ('ore', 1, 4)
    ('more', 0, 4)

    Set *reverse* to ``True`` to yield the same items in the opposite order.

    """
    # Lengths run short-to-long by default, long-to-short when reversed.
    lengths = range(1, len(seq) + 1)
    if reverse:
        lengths = reversed(lengths)
    return (
        (seq[start : start + size], start, start + size)
        for size in lengths
        for start in range(len(seq) - size + 1)
    )
class bucket:
    """Wrap *iterable* and return an object that buckets the iterable into
    child iterables based on a *key* function.

    >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
    >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
    >>> sorted(list(s))  # Get the keys
    ['a', 'b', 'c']
    >>> a_iterable = s['a']
    >>> next(a_iterable)
    'a1'
    >>> next(a_iterable)
    'a2'
    >>> list(s['b'])
    ['b1', 'b2', 'b3']

    The original iterable will be advanced and its items will be cached until
    they are used by the child iterables. This may require significant storage.

    By default, attempting to select a bucket to which no items belong will
    exhaust the iterable and cache all values.
    If you specify a *validator* function, selected buckets will instead be
    checked against it.

    >>> from itertools import count
    >>> it = count(1, 2)  # Infinite sequence of odd numbers
    >>> key = lambda x: x % 10  # Bucket by last digit
    >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
    >>> s = bucket(it, key=key, validator=validator)
    >>> 2 in s
    False
    >>> list(s[2])
    []

    """

    def __init__(self, iterable, key, validator=None):
        # _it: shared source iterator, advanced lazily by the children.
        # _cache: per-key deques of items pulled from _it but not yet
        # consumed by the matching child iterable.
        self._it = iter(iterable)
        self._key = key
        self._cache = defaultdict(deque)
        # With no validator, every key is considered valid.
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        if not self._validator(value):
            return False

        # Peek: pull one matching item, then push it back onto the front
        # of that key's cache so nothing is lost.
        try:
            item = next(self[value])
        except StopIteration:
            return False
        else:
            self._cache[value].appendleft(item)

        return True

    def _get_values(self, value):
        """
        Helper to yield items from the parent iterator that match *value*.
        Items that don't match are stored in the local cache as they
        are encountered.
        """
        while True:
            # If we've cached some items that match the target value, emit
            # the first one and evict it from the cache.
            if self._cache[value]:
                yield self._cache[value].popleft()
            # Otherwise we need to advance the parent iterator to search for
            # a matching item, caching the rest.
            else:
                while True:
                    try:
                        item = next(self._it)
                    except StopIteration:
                        return
                    item_value = self._key(item)
                    if item_value == value:
                        yield item
                        break
                    elif self._validator(item_value):
                        self._cache[item_value].append(item)

    def __iter__(self):
        # Exhaust the source so that every key is known, then iterate the
        # cache's keys (which may include keys created by earlier lookups).
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        # Invalid keys get an empty iterator without touching the source.
        if not self._validator(value):
            return iter(())

        return self._get_values(value)
def spy(iterable, n=1):
    """Return a 2-tuple: a list holding the first *n* elements of
    *iterable*, and an iterator that still yields *all* of its items.
    This lets you "look ahead" without advancing the caller's view.

    >>> head, iterable = spy('abcdefg')
    >>> head
    ['a']
    >>> list(iterable)
    ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    Unpacking works too:

    >>> (first, second), iterable = spy('abcdefg', 2)
    >>> first
    'a'
    >>> second
    'b'

    Asking for more items than exist is fine; the head is simply shorter:

    >>> head, iterable = spy([1, 2, 3, 4, 5], 10)
    >>> head
    [1, 2, 3, 4, 5]
    >>> list(iterable)
    [1, 2, 3, 4, 5]

    """
    # tee gives two independent branches; consume the head from one and
    # hand the untouched branch back to the caller.
    untouched, lookahead = tee(iterable)
    return list(islice(lookahead, n)), untouched
def interleave(*iterables):
    """Yield one item from each iterable in turn, stopping as soon as any
    of them runs out.

    >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7]

    For a version that continues after the shortest iterable is
    exhausted, see :func:`interleave_longest`.

    """
    # zip truncates at the shortest input, which gives exactly the
    # required termination behavior.
    for bundle in zip(*iterables):
        yield from bundle
def interleave_longest(*iterables):
    """Yield one item from each iterable in turn, simply skipping any
    iterable that has been exhausted.

    >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7, 3, 8]

    This function produces the same output as :func:`roundrobin`, but may
    perform better for some inputs (in particular when the number of iterables
    is large).

    """
    # A fresh sentinel marks "this iterable was already exhausted" slots
    # produced by zip_longest; those are filtered out of the output.
    exhausted = object()
    for bundle in zip_longest(*iterables, fillvalue=exhausted):
        yield from (element for element in bundle if element is not exhausted)
def interleave_evenly(iterables, lengths=None):
    """
    Interleave multiple iterables so that their elements are evenly distributed
    throughout the output sequence.

    >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
    >>> list(interleave_evenly(iterables))
    [1, 2, 'a', 3, 4, 'b', 5]

    >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
    >>> list(interleave_evenly(iterables))
    [1, 6, 4, 2, 7, 3, 8, 5]

    This function requires iterables of known length. Iterables without
    ``__len__()`` can be used by manually specifying lengths with *lengths*:

    >>> from itertools import combinations, repeat
    >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
    >>> lengths = [4 * (4 - 1) // 2, 3]
    >>> list(interleave_evenly(iterables, lengths=lengths))
    [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    Based on Bresenham's algorithm.
    """
    if lengths is None:
        try:
            # Works only for sized iterables; generators raise TypeError.
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )
    elif len(iterables) != len(lengths):
        raise ValueError('Mismatching number of iterables and lengths.')

    dims = len(lengths)

    # sort iterables by length, descending
    lengths_permute = sorted(
        range(dims), key=lambda i: lengths[i], reverse=True
    )
    lengths_desc = [lengths[i] for i in lengths_permute]
    iters_desc = [iter(iterables[i]) for i in lengths_permute]

    # the longest iterable is the primary one (Bresenham: the longest
    # distance along an axis)
    delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
    iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
    # Initial error terms: starting at delta_primary // dims staggers the
    # secondary yields instead of bunching them at the start.
    errors = [delta_primary // dims] * len(deltas_secondary)

    to_yield = sum(lengths)
    while to_yield:
        # One step along the primary axis per outer iteration.
        yield next(iter_primary)
        to_yield -= 1
        # update errors for each secondary iterable
        errors = [e - delta for e, delta in zip(errors, deltas_secondary)]

        # those iterables for which the error is negative are yielded
        # ("diagonal step" in Bresenham)
        for i, e_ in enumerate(errors):
            if e_ < 0:
                yield next(iters_secondary[i])
                to_yield -= 1
                errors[i] += delta_primary
def interleave_randomly(*iterables):
    """Repeatedly select one of the input *iterables* at random and yield the next
    item from it.

    >>> iterables = [1, 2, 3], 'abc', (True, False, None)
    >>> list(interleave_randomly(*iterables))  # doctest: +SKIP
    ['a', 'b', 1, 'c', True, False, None, 2, 3]

    The relative order of the items within each input iterable is preserved.
    Note that the possible output sequences are not all equally likely.

    """
    active = [iter(source) for source in iterables]
    exhausted = object()
    while active:
        chosen = randrange(len(active))
        item = next(active[chosen], exhausted)
        if item is exhausted:
            # Swap-with-last removal: O(1), and the order of `active`
            # doesn't matter since selection is random anyway.
            active[chosen] = active[-1]
            del active[-1]
        else:
            yield item
def collapse(iterable, base_type=None, levels=None):
    """Flatten an iterable with multiple levels of nesting (e.g., a list of
    lists of tuples) into non-iterable types.

    >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
    >>> list(collapse(iterable))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and
    will not be collapsed.

    To avoid collapsing other types, specify *base_type*:

    >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
    >>> list(collapse(iterable, base_type=tuple))
    ['ab', ('cd', 'ef'), 'gh', 'ij']

    Specify *levels* to stop flattening after a certain level:

    >>> iterable = [('a', ['b']), ('c', ['d'])]
    >>> list(collapse(iterable))  # Fully flattened
    ['a', 'b', 'c', 'd']
    >>> list(collapse(iterable, levels=1))  # Only one level flattened
    ['a', ['b'], 'c', ['d']]

    """
    # Iterative depth-first traversal: the stack holds (level, nodes)
    # pairs, where `nodes` is a (possibly partially consumed) iterator.
    stack = deque()
    # Add our first node group, treat the iterable as a single node
    stack.appendleft((0, repeat(iterable, 1)))

    while stack:
        node_group = stack.popleft()
        level, nodes = node_group

        # Check if beyond max level
        if levels is not None and level > levels:
            yield from nodes
            continue

        for node in nodes:
            # Check if done iterating
            if isinstance(node, (str, bytes)) or (
                (base_type is not None) and isinstance(node, base_type)
            ):
                yield node
            # Otherwise try to create child nodes
            else:
                try:
                    tree = iter(node)
                except TypeError:
                    # Not iterable: it's a leaf value.
                    yield node
                else:
                    # Save our current location: the partially consumed
                    # `nodes` iterator is re-pushed so traversal resumes
                    # here after the child is fully processed.
                    stack.appendleft(node_group)
                    # Append the new child node
                    stack.appendleft((level + 1, tree))
                    # Break to process child node
                    break
def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Invoke *func* on each item in *iterable* (or on each *chunk_size* group
    of items) before yielding the item.

    *func* must take a single argument; its return value is discarded.
    *before* and *after* are optional zero-argument callables run before
    iteration starts and after it ends (even on early exit), respectively.

    Useful for logging, progress bars, or anything else that is not
    functionally "pure".

    >>> from more_itertools import consume
    >>> func = lambda item: print('Received {}'.format(item))
    >>> consume(side_effect(func, range(2)))
    Received 0
    Received 1

    Operating on chunks of items:

    >>> pair_sums = []
    >>> func = lambda chunk: pair_sums.append(sum(chunk))
    >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
    [0, 1, 2, 3, 4, 5]
    >>> list(pair_sums)
    [1, 5, 9]

    """
    try:
        if before is not None:
            before()

        if chunk_size is not None:
            # Chunked mode: func sees a list of up to chunk_size items.
            for group in chunked(iterable, chunk_size):
                func(group)
                yield from group
        else:
            # Per-item mode.
            for element in iterable:
                func(element)
                yield element
    finally:
        # Runs even if the consumer abandons the generator early.
        if after is not None:
            after()
def sliced(seq, n, strict=False):
    """Yield slices of length *n* from the sequence *seq*.

    >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
    [(1, 2, 3), (4, 5, 6)]

    By default the final slice may be short when ``len(seq)`` is not a
    multiple of *n*:

    >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
    [(1, 2, 3), (4, 5, 6), (7, 8)]

    With *strict* set to ``True``, a ``ValueError`` is raised instead of
    yielding that short final slice.

    Only works for sliceable iterables; for the general case see
    :func:`chunked`.

    """
    # Slicing past the end yields an empty (falsy) slice, which stops
    # the takewhile.
    slices = takewhile(len, (seq[i : i + n] for i in count(0, n)))
    if not strict:
        return slices

    def checked():
        # Re-yield each slice, but fail fast on a short one.
        for piece in slices:
            if len(piece) != n:
                raise ValueError("seq is not divisible by n.")
            yield piece

    return checked()
def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Yield lists of items from *iterable*, delimited by items for which
    callable *pred* returns ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b'))
    [['a'], ['c', 'd', 'c'], ['a']]

    >>> list(split_at(range(10), lambda n: n % 2 == 1))
    [[0], [2], [4], [6], [8], []]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
    [[0], [2], [4, 5, 6, 7, 8, 9]]

    Separators are dropped unless *keep_separator* is ``True``, in which
    case each one is emitted as its own single-item list:

    >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
    [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    it = iter(iterable)
    group = []
    for element in it:
        if not pred(element):
            group.append(element)
            continue
        # Separator found: flush the current group.
        yield group
        if keep_separator:
            yield [element]
        if maxsplit == 1:
            # Split budget exhausted; the rest is one final group.
            yield list(it)
            return
        group = []
        maxsplit -= 1
    yield group
def split_before(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, each ending just before an
    item for which callable *pred* returns ``True``:

    >>> list(split_before('OneTwo', lambda s: s.isupper()))
    [['O', 'n', 'e'], ['T', 'w', 'o']]

    >>> list(split_before(range(10), lambda n: n % 3 == 0))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    it = iter(iterable)
    group = []
    for element in it:
        # pred is always called, even for the very first item (which can
        # never trigger a split because the group is still empty).
        if pred(element) and group:
            yield group
            if maxsplit == 1:
                # The split element starts the final, unlimited group.
                yield [element, *it]
                return
            group = []
            maxsplit -= 1
        group.append(element)
    if group:
        yield group
def split_after(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, each ending with an item
    for which callable *pred* returns ``True``:

    >>> list(split_after('one1two2', lambda s: s.isdigit()))
    [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

    >>> list(split_after(range(10), lambda n: n % 3 == 0))
    [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    it = iter(iterable)
    group = []
    for element in it:
        group.append(element)
        # group is never empty here: element was just appended.
        if pred(element):
            yield group
            if maxsplit == 1:
                # Remaining items form the last group (omitted if empty).
                tail = list(it)
                if tail:
                    yield tail
                return
            group = []
            maxsplit -= 1
    if group:
        yield group
def split_when(iterable, pred, maxsplit=-1):
    """Split *iterable* into pieces based on the output of *pred*, a
    function of successive pairs of items that returns ``True`` when the
    iterable should be split between them.

    For example, to find runs of increasing numbers, split when element
    ``i`` is larger than element ``i + 1``:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
    [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
    ...                 lambda x, y: x > y, maxsplit=2))
    [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    it = iter(iterable)
    sentinel = object()
    previous = next(it, sentinel)
    if previous is sentinel:
        # Empty input: yield nothing.
        return

    group = [previous]
    for current in it:
        if pred(previous, current):
            yield group
            if maxsplit == 1:
                # Last allowed split: the rest is one final group.
                yield [current, *it]
                return
            group = []
            maxsplit -= 1

        group.append(current)
        previous = current

    yield group
def split_into(iterable, sizes):
    """Yield a list of sequential items from *iterable* of length 'n' for
    each integer 'n' in *sizes*.

    >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
    [[1], [2, 3], [4, 5, 6]]

    If *sizes* sums to less than the length of *iterable*, leftover items
    are not returned:

    >>> list(split_into([1,2,3,4,5,6], [2,3]))
    [[1, 2], [3, 4, 5]]

    If *sizes* sums to more, the overrunning group is short and later
    groups are empty:

    >>> list(split_into([1,2,3,4], [1,2,3,4]))
    [[1], [2, 3], [4], []]

    A ``None`` in *sizes* means "everything remaining", as with
    :func:`itertools.islice`:

    >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
    [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    Useful for grouping a series of items where group sizes are not
    uniform, e.g. columns of a table that encode multi-part features.
    """
    # A single shared iterator lets each islice pick up where the
    # previous group stopped (and supports generator inputs).
    iterator = iter(iterable)

    for size in sizes:
        if size is None:
            # Drain everything that's left and stop.
            yield list(iterator)
            return
        yield list(islice(iterator, size))
def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the elements from *iterable*, then enough copies of
    *fillvalue* that at least *n* items are emitted.

    >>> list(padded([1, 2, 3], '?', 5))
    [1, 2, 3, '?', '?']

    If *next_multiple* is ``True``, pad until the total count is a
    multiple of *n*:

    >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
    [1, 2, 3, 4, None, None]

    If *n* is ``None``, *fillvalue* is emitted indefinitely. Combine with
    :func:`islice` to get an iterable of exactly *n* items:

    >>> list(islice(padded([1, 2, 3], '?'), 5))
    [1, 2, 3, '?', '?']
    >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
    [1, 2, 3, 4, 5]

    Raises ``ValueError`` if *n* is given and less than 1.
    """
    source = iter(iterable)
    endless = chain(source, repeat(fillvalue))

    if n is None:
        return endless
    if n < 1:
        raise ValueError('n must be at least 1')

    if not next_multiple:
        # Guarantee at least n items, then pass through whatever remains.
        return chain(islice(endless, n), source)

    def n_sized_slices():
        # Each real item that starts a slice is followed by n - 1 more
        # items (real ones, then fill as needed), so every slice has
        # exactly n items and padding stops at the next multiple.
        for first in source:
            yield (first,)
            yield islice(endless, n - 1)

    return chain.from_iterable(n_sized_slices())
def repeat_each(iterable, n=2):
    """Repeat each element in *iterable* *n* times.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    return chain.from_iterable(repeat(element, n) for element in iterable)
def repeat_last(iterable, default=None):
    """After *iterable* is exhausted, keep yielding its last element.

    >>> list(islice(repeat_last(range(3)), 5))
    [0, 1, 2, 2, 2]

    If the iterable is empty, yield *default* forever::

    >>> list(islice(repeat_last(range(0), 42), 5))
    [42, 42, 42, 42, 42]

    """
    # `last` keeps the sentinel value only if the loop body never runs,
    # i.e. the iterable was empty.
    last = sentinel = object()
    for last in iterable:
        yield last
    yield from repeat(default if last is sentinel else last)
def distribute(n, iterable):
    """Distribute the items from *iterable* among *n* smaller iterables,
    dealing them out round-robin style.

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]

    When ``len(iterable)`` is not a multiple of *n*, the children have
    unequal lengths; when it is smaller than *n*, trailing children are
    empty:

    >>> [list(c) for c in distribute(3, [1, 2, 3, 4, 5, 6, 7])]
    [[1, 4, 7], [2, 5], [3, 6]]
    >>> [list(c) for c in distribute(5, [1, 2, 3])]
    [[1], [2], [3], [], []]

    Uses :func:`itertools.tee`, so it may require significant storage.
    If the original order must be preserved within each child, see
    :func:`divide`.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Child i takes every n-th item starting from offset i.
    result = []
    for offset, child in enumerate(tee(iterable, n)):
        result.append(islice(child, offset, None, n))
    return result
def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples whose elements are offset from *iterable*.
    The amount by which the `i`-th item in each tuple is offset is given by
    the `i`-th item in *offsets*.

    >>> list(stagger([0, 1, 2, 3]))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
    >>> list(stagger(range(8), offsets=(0, 2, 4)))
    [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default, the sequence will end when the final element of a tuple is the
    last item in the iterable. To continue until the first element of a tuple
    is the last item in the iterable, set *longest* to ``True``::

    >>> list(stagger([0, 1, 2, 3], longest=True))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    By default, ``None`` will be used to replace offsets beyond the end of the
    sequence. Specify *fillvalue* to use some other value.

    """
    # One independent copy of the iterable per offset; zip_offset does
    # the actual shifting and zipping. Uses tee, so this may require
    # significant storage for large offsets.
    children = tee(iterable, len(offsets))

    return zip_offset(
        *children, offsets=offsets, longest=longest, fillvalue=fillvalue
    )
def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the input *iterables* together, but offset the `i`-th
    iterable by the `i`-th item in *offsets*.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    A lightweight way to analyze series with a lead/lag relationship.

    By default the output ends with the shortest shifted iterable; set
    *longest* to ``True`` to continue until the longest one is done:

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    Positions beyond the end of a sequence are filled with *fillvalue*
    (default ``None``). Raises ``ValueError`` when the number of
    iterables and offsets differ.
    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    shifted = []
    for source, offset in zip(iterables, offsets):
        if offset > 0:
            # Positive offset: skip the first `offset` items.
            shifted.append(islice(source, offset, None))
        elif offset < 0:
            # Negative offset: prepend `-offset` fill values.
            shifted.append(chain(repeat(fillvalue, -offset), source))
        else:
            shifted.append(source)

    if longest:
        return zip_longest(*shifted, fillvalue=fillvalue)

    return zip(*shifted)
1890def sort_together(
1891 iterables, key_list=(0,), key=None, reverse=False, strict=False
1892):
1893 """Return the input iterables sorted together, with *key_list* as the
1894 priority for sorting. All iterables are trimmed to the length of the
1895 shortest one.
1897 This can be used like the sorting function in a spreadsheet. If each
1898 iterable represents a column of data, the key list determines which
1899 columns are used for sorting.
1901 By default, all iterables are sorted using the ``0``-th iterable::
1903 >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
1904 >>> sort_together(iterables)
1905 [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
1907 Set a different key list to sort according to another iterable.
1908 Specifying multiple keys dictates how ties are broken::
1910 >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
1911 >>> sort_together(iterables, key_list=(1, 2))
1912 [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
1914 To sort by a function of the elements of the iterable, pass a *key*
1915 function. Its arguments are the elements of the iterables corresponding to
1916 the key list::
1918 >>> names = ('a', 'b', 'c')
1919 >>> lengths = (1, 2, 3)
1920 >>> widths = (5, 2, 1)
1921 >>> def area(length, width):
1922 ... return length * width
1923 >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
1924 [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
1926 Set *reverse* to ``True`` to sort in descending order.
1928 >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
1929 [(3, 2, 1), ('a', 'b', 'c')]
1931 If the *strict* keyword argument is ``True``, then
1932 ``ValueError`` will be raised if any of the iterables have
1933 different lengths.
1935 """
1936 if key is None:
1937 # if there is no key function, the key argument to sorted is an
1938 # itemgetter
1939 key_argument = itemgetter(*key_list)
1940 else:
1941 # if there is a key function, call it with the items at the offsets
1942 # specified by the key function as arguments
1943 key_list = list(key_list)
1944 if len(key_list) == 1:
1945 # if key_list contains a single item, pass the item at that offset
1946 # as the only argument to the key function
1947 key_offset = key_list[0]
1948 key_argument = lambda zipped_items: key(zipped_items[key_offset])
1949 else:
1950 # if key_list contains multiple items, use itemgetter to return a
1951 # tuple of items, which we pass as *args to the key function
1952 get_key_items = itemgetter(*key_list)
1953 key_argument = lambda zipped_items: key(
1954 *get_key_items(zipped_items)
1955 )
1957 transposed = zip(*iterables, strict=strict)
1958 reordered = sorted(transposed, key=key_argument, reverse=reverse)
1959 untransposed = zip(*reordered, strict=strict)
1960 return list(untransposed)
def unzip(iterable):
    """The inverse of :func:`zip`, this function disaggregates the elements
    of the zipped *iterable*.

    The ``i``-th iterable contains the ``i``-th element from each element
    of the zipped iterable. The first element is used to determine the
    length of the remaining elements.

    >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> letters, numbers = unzip(iterable)
    >>> list(letters)
    ['a', 'b', 'c', 'd']
    >>> list(numbers)
    [1, 2, 3, 4]

    This is similar to using ``zip(*iterable)``, but it avoids reading
    *iterable* into memory. Note, however, that this function uses
    :func:`itertools.tee` and thus may require significant storage.

    """
    # Peek at the first element (without consuming it) to learn how many
    # output iterables are needed.
    head, iterable = spy(iterable)
    if not head:
        # empty iterable, e.g. zip([], [], [])
        return ()
    # spy returns a one-length iterable as head
    head = head[0]
    iterables = tee(iterable, len(head))

    # If we have an iterable like iter([(1, 2, 3), (4, 5), (6,)]),
    # the second unzipped iterable fails at the third tuple since
    # it tries to access (6,)[1].
    # Same with the third unzipped iterable and the second tuple.
    # To support these "improperly zipped" iterables, we suppress
    # the IndexError, which just stops the unzipped iterables at
    # first length mismatch.
    return tuple(
        iter_suppress(map(itemgetter(i), it), IndexError)
        for i, it in enumerate(iterables)
    )
def divide(n, iterable):
    """Divide the elements from *iterable* into *n* contiguous parts,
    maintaining order.

    >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 2, 3]
    >>> list(group_2)
    [4, 5, 6]

    When ``len(iterable)`` is not a multiple of *n*, the parts have
    unequal lengths; when it is smaller than *n*, trailing parts are
    empty:

    >>> [list(c) for c in divide(3, [1, 2, 3, 4, 5, 6, 7])]
    [[1, 2, 3], [4, 5], [6, 7]]
    >>> [list(c) for c in divide(5, [1, 2, 3])]
    [[1], [2], [3], [], []]

    The iterable is fully consumed before this function returns. If order
    doesn't matter, see :func:`distribute`, which doesn't pull everything
    into memory first.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Materialize only when the input doesn't already support slicing.
    try:
        iterable[:0]
    except TypeError:
        seq = tuple(iterable)
    else:
        seq = iterable

    size, remainder = divmod(len(seq), n)

    # The first `remainder` parts get one extra element each.
    parts = []
    stop = 0
    for index in range(n):
        start = stop
        stop += size + (1 if index < remainder else 0)
        parts.append(iter(seq[start:stop]))

    return parts
def always_iterable(obj, base_type=(str, bytes)):
    """If *obj* is iterable, return an iterator over its items::

        >>> list(always_iterable((1, 2, 3)))
        [1, 2, 3]

    If *obj* is not iterable, return a one-item iterator containing it;
    if it is ``None``, return an empty iterator::

        >>> list(always_iterable(1))
        [1]
        >>> list(always_iterable(None))
        []

    Objects matching *base_type* (by default binary and text strings) are
    treated as non-iterable::

        >>> list(always_iterable('foo'))
        ['foo']
        >>> list(always_iterable({'a': 1}, base_type=dict))
        [{'a': 1}]

    Pass ``base_type=None`` to disable that special handling::

        >>> list(always_iterable('foo', base_type=None))
        ['f', 'o', 'o']
    """
    if obj is None:
        return iter(())

    treat_as_atom = base_type is not None and isinstance(obj, base_type)
    if not treat_as_atom:
        try:
            return iter(obj)
        except TypeError:
            # Not actually iterable; fall through to wrapping it.
            pass

    return iter((obj,))
def adjacent(predicate, iterable, distance=1):
    """Return an iterable over `(bool, item)` tuples where the `item` is
    drawn from *iterable* and the `bool` indicates whether
    that item satisfies the *predicate* or is adjacent to an item that does.

    For example, to find whether items are adjacent to a ``3``::

        >>> list(adjacent(lambda x: x == 3, range(6)))
        [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    Set *distance* to change what counts as adjacent. For example, to find
    whether items are two places away from a ``3``:

        >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
        [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    This is useful for contextualizing the results of a search function.
    For example, a code comparison tool might want to identify lines that
    have changed, but also surrounding lines to give the viewer of the diff
    context.

    The predicate function will only be called once for each item in the
    iterable.

    See also :func:`groupby_transform`, which can be used with this function
    to group ranges of items with the same `bool` value.

    """
    # Allow distance=0 mainly for testing that it reproduces results with map()
    if distance < 0:
        raise ValueError('distance must be at least 0')

    # One tee branch feeds the predicate, the other supplies the items
    # unchanged for the output tuples.
    i1, i2 = tee(iterable)
    padding = [False] * distance
    # Pad both ends so the sliding window below is full-width even at
    # the boundaries of the iterable.
    selected = chain(padding, map(predicate, i1), padding)
    # An item is "adjacent" if any position within `distance` of it
    # (inclusive of itself) satisfied the predicate.
    adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1))
    return zip(adjacent_to_selected, i2)
2147def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
2148 """An extension of :func:`itertools.groupby` that can apply transformations
2149 to the grouped data.
2151 * *keyfunc* is a function computing a key value for each item in *iterable*
2152 * *valuefunc* is a function that transforms the individual items from
2153 *iterable* after grouping
2154 * *reducefunc* is a function that transforms each group of items
2156 >>> iterable = 'aAAbBBcCC'
2157 >>> keyfunc = lambda k: k.upper()
2158 >>> valuefunc = lambda v: v.lower()
2159 >>> reducefunc = lambda g: ''.join(g)
2160 >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
2161 [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]
2163 Each optional argument defaults to an identity function if not specified.
2165 :func:`groupby_transform` is useful when grouping elements of an iterable
2166 using a separate iterable as the key. To do this, :func:`zip` the iterables
2167 and pass a *keyfunc* that extracts the first element and a *valuefunc*
2168 that extracts the second element::
2170 >>> from operator import itemgetter
2171 >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
2172 >>> values = 'abcdefghi'
2173 >>> iterable = zip(keys, values)
2174 >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
2175 >>> [(k, ''.join(g)) for k, g in grouper]
2176 [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]
2178 Note that the order of items in the iterable is significant.
2179 Only adjacent items are grouped together, so if you don't want any
2180 duplicate groups, you should sort the iterable by the key function.
2182 """
2183 ret = groupby(iterable, keyfunc)
2184 if valuefunc:
2185 ret = ((k, map(valuefunc, g)) for k, g in ret)
2186 if reducefunc:
2187 ret = ((k, reducefunc(g)) for k, g in ret)
2189 return ret
class numeric_range(Sequence):
    """An extension of the built-in ``range()`` function whose arguments can
    be any orderable numeric type.

    With only *stop* specified, *start* defaults to ``0`` and *step*
    defaults to ``1``. The output items will match the type of *stop*:

    >>> list(numeric_range(3.5))
    [0.0, 1.0, 2.0, 3.0]

    With only *start* and *stop* specified, *step* defaults to ``1``. The
    output items will match the type of *start*:

    >>> from decimal import Decimal
    >>> start = Decimal('2.1')
    >>> stop = Decimal('5.1')
    >>> list(numeric_range(start, stop))
    [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    With *start*, *stop*, and *step* specified the output items will match
    the type of ``start + step``:

    >>> from fractions import Fraction
    >>> start = Fraction(1, 2)  # Start at 1/2
    >>> stop = Fraction(5, 2)  # End at 5/2
    >>> step = Fraction(1, 2)  # Count by 1/2
    >>> list(numeric_range(start, stop, step))
    [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    If *step* is zero, ``ValueError`` is raised. Negative steps are supported:

    >>> list(numeric_range(3, -1, -1.0))
    [3.0, 2.0, 1.0, 0.0]

    Be aware of the limitations of floating-point numbers; the representation
    of the yielded numbers may be surprising.

    ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
    is a ``datetime.timedelta`` object:

    >>> import datetime
    >>> start = datetime.datetime(2019, 1, 1)
    >>> stop = datetime.datetime(2019, 1, 3)
    >>> step = datetime.timedelta(days=1)
    >>> items = iter(numeric_range(start, stop, step))
    >>> next(items)
    datetime.datetime(2019, 1, 1, 0, 0)
    >>> next(items)
    datetime.datetime(2019, 1, 2, 0, 0)

    """

    # All empty ranges compare equal (see __eq__), so they all share the
    # hash of an empty built-in range object.
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        # Mimic range()'s signatures: (stop), (start, stop), or
        # (start, stop, step).  Defaults are constructed from the argument
        # types so that e.g. Decimal or datetime inputs get typed defaults.
        argc = len(args)
        if argc == 1:
            (self._stop,) = args
            self._start = type(self._stop)(0)
            self._step = type(self._stop - self._start)(1)
        elif argc == 2:
            self._start, self._stop = args
            self._step = type(self._stop - self._start)(1)
        elif argc == 3:
            self._start, self._stop, self._step = args
        elif argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        else:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        # A zero of the step's type, used for sign tests and modulo checks.
        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        self._growing = self._step > self._zero

    def __bool__(self):
        # Non-empty iff at least one item lies between start and stop.
        if self._growing:
            return self._start < self._stop
        else:
            return self._start > self._stop

    def __contains__(self, elem):
        # Membership requires the element to be inside the half-open
        # interval AND a whole number of steps away from start.
        if self._growing:
            if self._start <= elem < self._stop:
                return (elem - self._start) % self._step == self._zero
        else:
            if self._start >= elem > self._stop:
                return (self._start - elem) % (-self._step) == self._zero

        return False

    def __eq__(self, other):
        if isinstance(other, numeric_range):
            empty_self = not bool(self)
            empty_other = not bool(other)
            if empty_self or empty_other:
                return empty_self and empty_other  # True if both empty
            else:
                # Ranges are equal iff they have the same start, step, and
                # final element; the stop values may differ without
                # changing the contents.
                return (
                    self._start == other._start
                    and self._step == other._step
                    and self._get_by_index(-1) == other._get_by_index(-1)
                )
        else:
            return False

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._get_by_index(key)
        elif isinstance(key, slice):
            # Compose the slice's step with our own step.
            step = self._step if key.step is None else key.step * self._step

            # Clamp the slice's start index to the range's bounds.
            if key.start is None or key.start <= -self._len:
                start = self._start
            elif key.start >= self._len:
                start = self._stop
            else:  # -self._len < key.start < self._len
                start = self._get_by_index(key.start)

            # Clamp the slice's stop index to the range's bounds.
            if key.stop is None or key.stop >= self._len:
                stop = self._stop
            elif key.stop <= -self._len:
                stop = self._start
            else:  # -self._len < key.stop < self._len
                stop = self._get_by_index(key.stop)

            return numeric_range(start, stop, step)
        else:
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )

    def __hash__(self):
        # Hash on (start, last element, step) so that ranges that compare
        # equal (see __eq__) also hash equal.
        if self:
            return hash((self._start, self._get_by_index(-1), self._step))
        else:
            return self._EMPTY_HASH

    def __iter__(self):
        # Generate start + n * step and stop once the bound is crossed.
        values = (self._start + (n * self._step) for n in count())
        if self._growing:
            return takewhile(partial(gt, self._stop), values)
        else:
            return takewhile(partial(lt, self._stop), values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        # Normalize to a growing range, then count items via euclidean
        # division of the distance by the step.
        if self._growing:
            start = self._start
            stop = self._stop
            step = self._step
        else:
            start = self._stop
            stop = self._start
            step = -self._step
        distance = stop - start
        if distance <= self._zero:
            return 0
        else:  # distance > 0 and step > 0: regular euclidean division
            q, r = divmod(distance, step)
            return int(q) + int(r != self._zero)

    def __reduce__(self):
        # Pickle support: rebuild from the original constructor arguments.
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        if self._step == 1:
            return f"numeric_range({self._start!r}, {self._stop!r})"
        return (
            f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
        )

    def __reversed__(self):
        # Walk backward from the last element, stepping past start.
        return iter(
            numeric_range(
                self._get_by_index(-1), self._start - self._step, -self._step
            )
        )

    def count(self, value):
        # Each value occurs at most once in a range.
        return int(value in self)

    def index(self, value):
        # Same membership math as __contains__, but returning the
        # quotient (the index) instead of a boolean.
        if self._growing:
            if self._start <= value < self._stop:
                q, r = divmod(value - self._start, self._step)
                if r == self._zero:
                    return int(q)
        else:
            if self._start >= value > self._stop:
                q, r = divmod(self._start - value, -self._step)
                if r == self._zero:
                    return int(q)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        # Support negative indexing like built-in sequences.
        if i < 0:
            i += self._len
        if i < 0 or i >= self._len:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step
def count_cycle(iterable, n=None):
    """Cycle through the items from *iterable* up to *n* times, yielding
    the number of completed cycles along with each item. If *n* is omitted the
    process repeats indefinitely.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    """
    if n is not None:
        # A finite number of cycles is exactly the Cartesian product of
        # the cycle indexes with the items.
        return product(range(n), iterable)

    # n is None from here on, so cycle forever.  (The previous version
    # computed ``count() if n is None else range(n)``, but the else branch
    # was unreachable; it has been simplified to ``count()``.)
    seq = tuple(iterable)
    if not seq:
        return iter(())
    return zip(repeat_each(count(), len(seq)), cycle(seq))
def mark_ends(iterable):
    """Yield 3-tuples of the form ``(is_first, is_last, item)``.

    >>> list(mark_ends('ABC'))
    [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    Use this when looping over an iterable to take special action on its first
    and/or last items:

    >>> iterable = ['Header', 100, 200, 'Footer']
    >>> total = 0
    >>> for is_first, is_last, item in mark_ends(iterable):
    ...     if is_first:
    ...         continue  # Skip the header
    ...     if is_last:
    ...         continue  # Skip the footer
    ...     total += item
    >>> print(total)
    300
    """
    iterator = iter(iterable)

    # Grab the first item; an empty iterable yields nothing.
    try:
        current = next(iterator)
    except StopIteration:
        return

    # Emit each item one step behind, so we know whether a successor
    # exists before we have to declare the current item "last".
    is_first = True
    for upcoming in iterator:
        yield is_first, False, current
        current = upcoming
        is_first = False
    yield is_first, True, current
def locate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
    [1, 2, 4]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item.

    >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
    [1, 3]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(locate(iterable, pred=pred, window_size=3))
    [1, 5, 9]

    Use with :func:`seekable` to find indexes and then retrieve the associated
    items:

    >>> from itertools import count
    >>> from more_itertools import seekable
    >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
    >>> it = seekable(source)
    >>> pred = lambda x: x > 100
    >>> indexes = locate(it, pred=pred)
    >>> i = next(indexes)
    >>> it.seek(i)
    >>> next(it)
    106

    """
    # Sliding-window mode: call pred with window_size items at a time.
    if window_size is not None:
        if window_size < 1:
            raise ValueError('window size must be at least 1')
        windows = windowed(iterable, window_size, fillvalue=_marker)
        return compress(count(), starmap(pred, windows))

    # Plain mode: keep the indexes where pred(item) is truthy.
    return compress(count(), map(pred, iterable))
def longest_common_prefix(iterables):
    """Yield elements of the longest common prefix among given *iterables*.

    >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
    'ab'

    """
    # Evaluate zip() eagerly so that a non-iterable argument fails at
    # call time rather than on first iteration.
    columns = zip(*iterables)

    def prefix_elements():
        for column in columns:
            # A column belongs to the common prefix only if every iterable
            # agrees on it.  groupby collapses equal runs, so exactly one
            # group means all the values compare equal.
            groups = groupby(column)
            next(groups)
            if next(groups, None) is not None:
                return
            yield column[0]

    return prefix_elements()
def lstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the beginning
    for which *pred* returns ``True``.

    For example, to remove a set of items from the start of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(lstrip(iterable, pred))
    [1, 2, None, 3, False, None]

    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.

    """
    # dropwhile discards leading items while pred holds, then passes the
    # rest through untouched.
    return dropwhile(pred, iterable)
def rstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the end
    for which *pred* returns ``True``.

    For example, to remove a set of items from the end of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(rstrip(iterable, pred))
    [None, False, None, 1, 2, None, 3]

    This function is analogous to :func:`str.rstrip`.

    """
    # Buffer items that match pred; they are only emitted if a
    # non-matching item follows them, so a trailing run is dropped.
    pending = []
    for item in iterable:
        if pred(item):
            pending.append(item)
        else:
            if pending:
                yield from pending
                pending = []
            yield item
def strip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the
    beginning and end for which *pred* returns ``True``.

    For example, to remove a set of items from both ends of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(strip(iterable, pred))
    [1, 2, None, 3]

    This function is analogous to :func:`str.strip`.

    """
    # Trim the front first, then trim the back of what remains.
    left_trimmed = lstrip(iterable, pred)
    return rstrip(left_trimmed, pred)
class islice_extended:
    """An extension of :func:`itertools.islice` that supports negative values
    for *stop*, *start*, and *step*.

    >>> iterator = iter('abcdefgh')
    >>> list(islice_extended(iterator, -4, -1))
    ['e', 'f', 'g']

    Slices with negative values require some caching of *iterable*, but this
    function takes care to minimize the amount of memory required.

    For example, you can use a negative step with an infinite iterator:

    >>> from itertools import count
    >>> list(islice_extended(count(), 110, 99, -2))
    [110, 108, 106, 104, 102, 100]

    You can also use slice notation directly:

    >>> iterator = map(str, count())
    >>> it = islice_extended(iterator)[10:20:2]
    >>> list(it)
    ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        source = iter(iterable)
        # With slice arguments, delegate to the helper; without them the
        # iterator passes through untouched.
        if args:
            self._iterator = _islice_helper(source, slice(*args))
        else:
            self._iterator = source

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        # Only slice keys are supported; integer indexing would require
        # different caching semantics.
        if not isinstance(key, slice):
            raise TypeError(
                'islice_extended.__getitem__ argument must be a slice'
            )
        return islice_extended(_islice_helper(self._iterator, key))
def _islice_helper(it, s):
    """Yield the items of iterator *it* selected by slice *s*, supporting
    negative start/stop/step values (the engine behind islice_extended).

    Negative indices require buffering part of the iterator in a bounded
    deque; each branch below buffers only as much as that case needs.
    """
    start = s.start
    stop = s.stop
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    step = s.step or 1

    if step > 0:
        start = 0 if (start is None) else start

        if start < 0:
            # Consume all but the last -start items
            # (enumerate from 1 so the final counter is the total length)
            cache = deque(enumerate(it, 1), maxlen=-start)
            len_iter = cache[-1][0] if cache else 0

            # Adjust start to be positive
            i = max(len_iter + start, 0)

            # Adjust stop to be positive
            if stop is None:
                j = len_iter
            elif stop >= 0:
                j = min(stop, len_iter)
            else:
                j = max(len_iter + stop, 0)

            # Slice the cache
            n = j - i
            if n <= 0:
                return

            for index in range(n):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()[1]
                else:
                    # just pop and discard the item
                    cache.popleft()
        elif (stop is not None) and (stop < 0):
            # Advance to the start position
            next(islice(it, start, start), None)

            # When stop is negative, we have to carry -stop items while
            # iterating
            cache = deque(islice(it, -stop), maxlen=-stop)

            for index, item in enumerate(it):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
                cache.append(item)
        else:
            # When both start and stop are positive we have the normal case
            yield from islice(it, start, stop, step)
    else:
        # step < 0: items are produced in reverse order.
        start = -1 if (start is None) else start

        if (stop is not None) and (stop < 0):
            # Consume all but the last items
            n = -stop - 1
            cache = deque(enumerate(it, 1), maxlen=n)
            len_iter = cache[-1][0] if cache else 0

            # If start and stop are both negative they are comparable and
            # we can just slice. Otherwise we can adjust start to be negative
            # and then slice.
            if start < 0:
                i, j = start, stop
            else:
                i, j = min(start - len_iter, -1), None

            for index, item in list(cache)[i:j:step]:
                yield item
        else:
            # Advance to the stop position
            if stop is not None:
                m = stop + 1
                next(islice(it, m, m), None)

            # stop is positive, so if start is negative they are not comparable
            # and we need the rest of the items.
            if start < 0:
                i = start
                n = None
            # stop is None and start is positive, so we just need items up to
            # the start index.
            elif stop is None:
                i = None
                n = start + 1
            # Both stop and start are positive, so they are comparable.
            else:
                i = None
                n = start - stop
                if n <= 0:
                    return

            cache = list(islice(it, n))

            yield from cache[i::step]
def always_reversible(iterable):
    """An extension of :func:`reversed` that supports all iterables, not
    just those which implement the ``Reversible`` or ``Sequence`` protocols.

    >>> print(*always_reversible(x for x in range(3)))
    2 1 0

    If the iterable is already reversible, the result of :func:`reversed`
    is returned directly. Otherwise the remaining items are cached in a
    list and yielded in reverse order, which may require significant
    storage.
    """
    # Non-reversible iterables make reversed() raise TypeError; fall back
    # to materializing the items in that case.
    with suppress(TypeError):
        return reversed(iterable)
    return reversed(list(iterable))
def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items using :func:`itertools.groupby`.
    The *ordering* function determines whether two items are adjacent by
    returning their position.

    By default, the ordering function is the identity function. This is
    suitable for finding runs of numbers:

    >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
    >>> for group in consecutive_groups(iterable):
    ...     print(list(group))
    [1]
    [10, 11, 12]
    [20]
    [30, 31, 32, 33]
    [40]

    To find runs of adjacent letters, apply :func:`ord` function
    to convert letters to ordinals.

    >>> iterable = 'abcdfgilmnop'
    >>> ordering = ord
    >>> for group in consecutive_groups(iterable, ordering):
    ...     print(list(group))
    ['a', 'b', 'c', 'd']
    ['f', 'g']
    ['i']
    ['l', 'm', 'n', 'o', 'p']

    Each group of consecutive items is an iterator that shares it source with
    *iterable*. When an an output group is advanced, the previous group is
    no longer available unless its elements are copied (e.g., into a ``list``).

    >>> iterable = [1, 2, 11, 12, 21, 22]
    >>> saved_groups = []
    >>> for group in consecutive_groups(iterable):
    ...     saved_groups.append(list(group))  # Copy group elements
    >>> saved_groups
    [[1, 2], [11, 12], [21, 22]]

    """
    # Consecutive items have positions that rise in lockstep with their
    # enumeration index, so (index - position) is constant within a run.
    if ordering is None:
        def key(pair):
            index, item = pair
            return index - item
    else:
        def key(pair):
            index, item = pair
            return index - ordering(item)

    for _, group in groupby(enumerate(iterable), key=key):
        yield map(itemgetter(1), group)
def difference(iterable, func=sub, *, initial=None):
    """This function is the inverse of :func:`itertools.accumulate`. By default
    it will compute the first difference of *iterable* using
    :func:`operator.sub`:

    >>> from itertools import accumulate
    >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
    >>> list(difference(iterable))
    [0, 1, 2, 3, 4]

    *func* defaults to :func:`operator.sub`, but other functions can be
    specified. They will be applied as follows::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, to do progressive division:

    >>> iterable = [1, 2, 6, 24, 120]
    >>> func = lambda x, y: x // y
    >>> list(difference(iterable, func))
    [1, 2, 3, 4, 5]

    If the *initial* keyword is set, the first element will be skipped when
    computing successive differences.

    >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
    >>> list(difference(it, initial=10))
    [1, 2, 3]

    """
    # Two copies of the input: `leading` runs one item ahead of `trailing`.
    trailing, leading = tee(iterable)

    # The first output item is the first input item, unless *initial* was
    # given, in which case it is dropped.
    try:
        head = [next(leading)]
    except StopIteration:
        return iter([])
    if initial is not None:
        head = []

    # Pair each item with its predecessor and apply *func*.
    return chain(head, map(func, leading, trailing))
class SequenceView(Sequence):
    """Return a read-only view of the sequence object *target*.

    :class:`SequenceView` objects are analogous to Python's built-in
    "dictionary view" types. They provide a dynamic view of a sequence's items,
    meaning that when the sequence updates, so does the view.

    >>> seq = ['0', '1', '2']
    >>> view = SequenceView(seq)
    >>> view
    SequenceView(['0', '1', '2'])
    >>> seq.append('3')
    >>> view
    SequenceView(['0', '1', '2', '3'])

    Sequence views support indexing, slicing, and length queries. They act
    like the underlying sequence, except they don't allow assignment:

    >>> view[1]
    '1'
    >>> view[1:-1]
    ['1', '2']
    >>> len(view)
    4

    Sequence views are useful as an alternative to copying, as they don't
    require (much) extra storage.

    """

    def __init__(self, target):
        # Only genuine sequences can back a view; reject anything else.
        if not isinstance(target, Sequence):
            raise TypeError
        self._target = target

    def __getitem__(self, index):
        # Delegate indexing and slicing straight to the underlying sequence.
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        return f'{type(self).__name__}({self._target!r})'
class seekable:
    """Wrap an iterator to allow for seeking backward and forward. This
    progressively caches the items in the source iterable so they can be
    re-visited.

    Call :meth:`seek` with an index to seek to that position in the source
    iterable.

    To "reset" an iterator, seek to ``0``:

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.seek(0)
    >>> next(it), next(it), next(it)
    ('0', '1', '2')

    You can also seek forward:

    >>> it = seekable((str(n) for n in range(20)))
    >>> it.seek(10)
    >>> next(it)
    '10'
    >>> it.seek(20)  # Seeking past the end of the source isn't a problem
    >>> list(it)
    []
    >>> it.seek(0)  # Resetting works even after hitting the end
    >>> next(it)
    '0'

    Call :meth:`relative_seek` to seek relative to the source iterator's
    current position.

    >>> it = seekable((str(n) for n in range(20)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.relative_seek(2)
    >>> next(it)
    '5'
    >>> it.relative_seek(-3)  # Source is at '6', we move back to '3'
    >>> next(it)
    '3'
    >>> it.relative_seek(-3)  # Source is at '4', we move back to '1'
    >>> next(it)
    '1'


    Call :meth:`peek` to look ahead one item without advancing the iterator:

    >>> it = seekable('1234')
    >>> it.peek()
    '1'
    >>> list(it)
    ['1', '2', '3', '4']
    >>> it.peek(default='empty')
    'empty'

    Before the iterator is at its end, calling :func:`bool` on it will return
    ``True``. After it will return ``False``:

    >>> it = seekable('5678')
    >>> bool(it)
    True
    >>> list(it)
    ['5', '6', '7', '8']
    >>> bool(it)
    False

    You may view the contents of the cache with the :meth:`elements` method.
    That returns a :class:`SequenceView`, a view that updates automatically:

    >>> it = seekable((str(n) for n in range(10)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> elements = it.elements()
    >>> elements
    SequenceView(['0', '1', '2'])
    >>> next(it)
    '3'
    >>> elements
    SequenceView(['0', '1', '2', '3'])

    By default, the cache grows as the source iterable progresses, so beware of
    wrapping very large or infinite iterables. Supply *maxlen* to limit the
    size of the cache (this of course limits how far back you can seek).

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()), maxlen=2)
    >>> next(it), next(it), next(it), next(it)
    ('0', '1', '2', '3')
    >>> list(it.elements())
    ['2', '3']
    >>> it.seek(0)
    >>> next(it), next(it), next(it), next(it)
    ('2', '3', '4', '5')
    >>> next(it)
    '6'

    """

    def __init__(self, iterable, maxlen=None):
        self._source = iter(iterable)
        # With maxlen, the cache is a bounded deque and the oldest items
        # are discarded; otherwise everything seen so far is retained.
        if maxlen is None:
            self._cache = []
        else:
            self._cache = deque([], maxlen)
        # _index is None while reading fresh items from the source;
        # otherwise it's the cache position the next item comes from.
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        # After a backward seek, serve items from the cache first.
        if self._index is not None:
            try:
                item = self._cache[self._index]
            except IndexError:
                # Ran off the end of the cache; resume the source.
                self._index = None
            else:
                self._index += 1
                return item

        item = next(self._source)
        self._cache.append(item)
        return item

    def __bool__(self):
        # True while at least one more item can be produced.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        # Advance one item, then step the cache index back one so the
        # next call to __next__ returns the same item again.
        try:
            peeked = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return peeked

    def elements(self):
        # A live, read-only view of the cached items.
        return SequenceView(self._cache)

    def seek(self, index):
        self._index = index
        # Seeking past what's cached pulls the shortfall from the source.
        remainder = index - len(self._cache)
        if remainder > 0:
            consume(self, remainder)

    def relative_seek(self, count):
        # When reading directly from the source, the current position is
        # the end of the cache.
        if self._index is None:
            self._index = len(self._cache)

        self.seek(max(self._index + count, 0))
class run_length:
    """
    :func:`run_length.encode` compresses an iterable with run-length encoding.
    It yields groups of repeated items with the count of how many times they
    were repeated:

    >>> uncompressed = 'abbcccdddd'
    >>> list(run_length.encode(uncompressed))
    [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` decompresses an iterable that was previously
    compressed with run-length encoding. It yields the items of the
    decompressed iterable:

    >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> list(run_length.decode(compressed))
    ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        # groupby collapses each run of equal items into one group; pair
        # the value with the size of its run.
        return (
            (value, sum(1 for _ in group))
            for value, group in groupby(iterable)
        )

    @staticmethod
    def decode(iterable):
        # Expand each (value, count) pair back into `count` copies.
        return chain.from_iterable(repeat(value, n) for value, n in iterable)
def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly ``n`` items in the iterable are ``True``
    according to the *predicate* function.

    >>> exactly_n([True, True, False], 2)
    True
    >>> exactly_n([True, True, False], 1)
    False
    >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
    True

    The iterable will be advanced until ``n + 1`` truthy items are encountered,
    so avoid calling it on infinite iterables.

    """
    # A negative count can never be matched.
    if n < 0:
        return False

    matches = filter(predicate, iterable)
    # Inspect no more than n + 1 matching items: seeing fewer than n means
    # "not enough", and seeing an (n + 1)th means "too many".  This stops
    # consuming the input as soon as the answer is known.
    return sum(1 for _ in islice(matches, n + 1)) == n
def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

    >>> list(circular_shifts(range(4)))
    [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    Set *steps* to the number of places to rotate to the left
    (or to the right if negative). Defaults to 1.

    >>> list(circular_shifts(range(4), 2))
    [(0, 1, 2, 3), (2, 3, 0, 1)]

    >>> list(circular_shifts(range(4), -1))
    [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    """
    rotated = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Pre-rotate the opposite way so that the first yield inside the loop
    # reproduces the original ordering.
    rotated.rotate(steps)
    shift = -steps

    # The sequence of shifts returns to the start after
    # len / gcd(len, shift) rotations; stop there to avoid duplicates.
    total = len(rotated)
    distinct = total // math.gcd(total, shift)

    for _ in range(distinct):
        rotated.rotate(shift)
        yield tuple(rotated)
def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function that
    modifies an iterable. *result_index* is the position in that function's
    signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition. This can augment what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

    >>> from more_itertools import chunked
    >>> chunker = make_decorator(chunked, result_index=0)
    >>> @chunker(3)
    ... def iter_range(n):
    ...     return iter(range(n))
    ...
    >>> list(iter_range(9))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

    >>> truth_serum = make_decorator(filter, result_index=1)
    >>> @truth_serum(bool)
    ... def boolean_test():
    ...     return [0, 1, '', ' ', False, True]
    ...
    >>> list(boolean_test())
    [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators:

    >>> from more_itertools import peekable
    >>> peekable_function = make_decorator(peekable)
    >>> @peekable_function()
    ... def str_range(*args):
    ...     return (str(x) for x in range(*args))
    ...
    >>> it = str_range(1, 20, 2)
    >>> next(it), next(it), next(it)
    ('1', '3', '5')
    >>> it.peek()
    '7'
    >>> next(it)
    '7'

    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                # Splice the wrapped function's result into the argument
                # list at result_index before calling wrapping_func.
                spliced_args = list(wrapping_args)
                spliced_args.insert(result_index, result)
                return wrapping_func(*spliced_args, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Return a dictionary that maps the items in *iterable* to categories
    defined by *keyfunc*, transforms them with *valuefunc*, and
    then summarizes them by category with *reducefunc*.

    *valuefunc* defaults to the identity function if it is unspecified.
    If *reducefunc* is unspecified, no summarization takes place:

    >>> keyfunc = lambda x: x.upper()
    >>> result = map_reduce('abbccc', keyfunc)
    >>> sorted(result.items())
    [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Specifying *valuefunc* transforms the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> result = map_reduce('abbccc', keyfunc, valuefunc)
    >>> sorted(result.items())
    [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Specifying *reducefunc* summarizes the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> reducefunc = sum
    >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
    >>> sorted(result.items())
    [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the map/reduce
    procedure:

    >>> all_items = range(30)
    >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
    >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
    >>> categories = map_reduce(items, keyfunc=keyfunc)
    >>> sorted(categories.items())
    [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
    >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
    >>> sorted(summaries.items())
    [(0, 90), (1, 75)]

    Note that all items in the iterable are gathered into a list before the
    summarization step, which may require significant storage.

    The returned object is a :obj:`collections.defaultdict` with the
    ``default_factory`` set to ``None``, such that it behaves like a normal
    dictionary.

    """
    result = defaultdict(list)

    # Categorize every item, optionally transforming it first.
    if valuefunc is None:
        for item in iterable:
            result[keyfunc(item)].append(item)
    else:
        for item in iterable:
            result[keyfunc(item)].append(valuefunc(item))

    # Optionally collapse each category's list into a summary value.
    if reducefunc is not None:
        for key in result:
            result[key] = reducefunc(result[key])

    # Disable the default factory so missing keys raise KeyError like a
    # plain dict.
    result.default_factory = None
    return result
def rlocate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``, starting from the right and moving left.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
    [4, 2, 1]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item:

    >>> iterator = iter('abcb')
    >>> pred = lambda x: x == 'b'
    >>> list(rlocate(iterator, pred))
    [3, 1]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(rlocate(iterable, pred=pred, window_size=3))
    [9, 5, 1]

    Beware, this function won't return anything for infinite iterables.
    If *iterable* is reversible, ``rlocate`` will reverse it and search from
    the right. Otherwise, it will search from the left and return the results
    in reverse order.

    See :func:`locate` to for other example applications.

    """
    if window_size is None:
        # Fast path: a sized, reversible input can be searched from the
        # right directly, converting reversed indexes back to forward ones.
        # len() or reversed() raising TypeError sends us to the fallback.
        try:
            total = len(iterable)
            return (total - i - 1 for i in locate(reversed(iterable), pred))
        except TypeError:
            pass

    # Fallback: search from the left and reverse the collected results.
    return reversed(list(locate(iterable, pred, window_size)))
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items from *iterable*, replacing the items for which *pred*
    returns ``True`` with the items from the iterable *substitutes*.

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
    >>> pred = lambda x: x == 0
    >>> substitutes = (2, 3)
    >>> list(replace(iterable, pred, substitutes))
    [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    If *count* is given, the number of replacements will be limited:

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
    >>> pred = lambda x: x == 0
    >>> substitutes = [None]
    >>> list(replace(iterable, pred, substitutes, count=2))
    [1, 1, None, 1, 1, None, 1, 1, 0]

    Use *window_size* to control the number of items passed as arguments to
    *pred*. This allows for locating and replacing subsequences.

    >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
    >>> window_size = 3
    >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
    >>> substitutes = [3, 4] # Splice in these items
    >>> list(replace(iterable, pred, substitutes, window_size=window_size))
    [3, 4, 5, 3, 4, 5]

    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # The substitutes are emitted once per match, so materialize them.
    substitutes = tuple(substitutes)

    # Pad the end so that the number of windows equals the number of items.
    padding = repeat(_marker, window_size - 1)
    windows = windowed(chain(iterable, padding), window_size)

    num_replaced = 0
    for window in windows:
        # When the window matches and the replacement budget isn't spent,
        # emit the substitutes and skip the windows overlapping this one.
        # E.g. with window size 2 over (0, 1, 2, ...) the windows are
        # (0, 1), (1, 2), ...; a match on (0, 1) must also consume (1, 2).
        if pred(*window) and ((count is None) or (num_replaced < count)):
            num_replaced += 1
            yield from substitutes
            consume(windows, window_size - 1)
        elif window and (window[0] is not _marker):
            # No replacement here: pass through the window's leading item,
            # skipping the end-of-stream padding.
            yield window[0]
def partitions(iterable):
    """Yield all possible order-preserving partitions of *iterable*.

    >>> iterable = 'abc'
    >>> for part in partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['a', 'b', 'c']

    This is unrelated to :func:`partition`.

    """
    sequence = list(iterable)
    n = len(sequence)
    # Every subset of the interior cut points {1, ..., n - 1} determines a
    # unique order-preserving partition.
    for cuts in powerset(range(1, n)):
        starts = (0,) + cuts
        stops = cuts + (n,)
        yield [sequence[start:stop] for start, stop in zip(starts, stops)]
def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """
    Yield the set partitions of *iterable* into *k* parts. Set partitions are
    not order-preserving.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, 2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']


    If *k* is not given, every set partition is generated.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    if *min_size* and/or *max_size* are given, the minimum and/or maximum size
    per block in partition is set.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, min_size=2):
    ...     print([''.join(p) for p in part])
    ['abc']
    >>> for part in set_partitions(iterable, max_size=2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    """
    L = list(iterable)
    n = len(L)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        elif k > n:
            # More groups than items: nothing to yield.
            return

    min_size = min_size if min_size is not None else 0
    max_size = max_size if max_size is not None else n
    if min_size > max_size:
        return

    def helper(seq, parts):
        # Recursively yield partitions of *seq* into exactly *parts* blocks.
        if parts == 1:
            yield [seq]
        elif len(seq) == parts:
            yield [[item] for item in seq]
        else:
            head, *rest = seq
            # Either the head forms a singleton block...
            for smaller in helper(rest, parts - 1):
                yield [[head], *smaller]
            # ...or it joins each block of a partition of the rest in turn.
            for smaller in helper(rest, parts):
                for i, block in enumerate(smaller):
                    yield smaller[:i] + [[head] + block] + smaller[i + 1 :]

    def size_ok(partition):
        # Enforce the per-block size constraints.
        return all(min_size <= len(block) <= max_size for block in partition)

    group_counts = range(1, n + 1) if k is None else (k,)
    for num_groups in group_counts:
        yield from filter(size_ok, helper(L, num_groups))
class time_limited:
    """
    Yield items from *iterable* until *limit_seconds* have passed.
    If the time limit expires before all items have been yielded, the
    ``timed_out`` parameter will be set to ``True``.

    >>> from time import sleep
    >>> def generator():
    ...     yield 1
    ...     yield 2
    ...     sleep(0.2)
    ...     yield 3
    >>> iterable = time_limited(0.1, generator())
    >>> list(iterable)
    [1, 2]
    >>> iterable.timed_out
    True

    Note that the time is checked before each item is yielded, and iteration
    stops if the time elapsed is greater than *limit_seconds*. If your time
    limit is 1 second, but it takes 2 seconds to generate the first item from
    the iterable, the function will run for 2 seconds and not yield anything.
    As a special case, when *limit_seconds* is zero, the iterator never
    returns anything.

    """

    def __init__(self, limit_seconds, iterable):
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be positive')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        # Special case: a zero-second limit never yields anything.
        if self.limit_seconds == 0:
            self.timed_out = True
            raise StopIteration

        # Pull the next item first; the clock is checked afterward, so a
        # slow source may overrun the limit by one item's production time.
        item = next(self._iterator)

        elapsed = monotonic() - self._start_time
        if elapsed > self.limit_seconds:
            self.timed_out = True
            raise StopIteration

        return item
def only(iterable, default=None, too_long=None):
    """If *iterable* has only one item, return it.
    If it has zero items, return *default*.
    If it has more than one item, raise the exception given by *too_long*,
    which is ``ValueError`` by default.

    >>> only([], default='missing')
    'missing'
    >>> only([1])
    1
    >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 1, 2,
     and perhaps more.'
    >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    TypeError

    Note that :func:`only` attempts to advance *iterable* twice to ensure there
    is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    iterator = iter(iterable)
    sentinel = object()

    first = next(iterator, sentinel)
    if first is sentinel:
        # Zero items.
        return default

    second = next(iterator, sentinel)
    if second is not sentinel:
        # Two or more items.
        msg = (
            f'Expected exactly one item in iterable, but got {first!r}, '
            f'{second!r}, and perhaps more.'
        )
        raise too_long or ValueError(msg)

    return first
def _ichunk(iterator, n):
    # Helper for ichunked(): carve the next *n* items off *iterator* and
    # return (sub-iterable, materialize_next), where materialize_next lets
    # the caller eagerly cache items from the chunk.
    cache = deque()
    chunk = islice(iterator, n)

    def generator():
        # Serve already-cached items first, then pull lazily from the chunk.
        with suppress(StopIteration):
            while True:
                yield cache.popleft() if cache else next(chunk)

    def materialize_next(n=1):
        # With n=None, cache everything left in the chunk.
        if n is None:
            cache.extend(chunk)
            return len(cache)

        # Otherwise cache just enough to have n items available.
        shortfall = n - len(cache)
        if shortfall > 0:
            cache.extend(islice(chunk, shortfall))

        # Report how many items are actually available, up to n.
        return min(n, len(cache))

    return (generator(), materialize_next)
def ichunked(iterable, n):
    """Break *iterable* into sub-iterables with *n* elements each.
    :func:`ichunked` is like :func:`chunked`, but it yields iterables
    instead of lists.

    If the sub-iterables are read in order, the elements of *iterable*
    won't be stored in memory.
    If they are read out of order, :func:`itertools.tee` is used to cache
    elements as necessary.

    >>> from itertools import count
    >>> all_chunks = ichunked(count(), 4)
    >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
    >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
    [4, 5, 6, 7]
    >>> list(c_1)
    [0, 1, 2, 3]
    >>> list(c_3)
    [8, 9, 10, 11]

    """
    iterator = iter(iterable)
    while True:
        chunk, materialize_next = _ichunk(iterator, n)

        # If not even one item can be materialized, the source is exhausted.
        if not materialize_next():
            break

        yield chunk

        # Cache this chunk's unread items so the next chunk starts at the
        # correct position even if the caller never consumed this one.
        materialize_next(None)
def iequals(*iterables):
    """Return ``True`` if all given *iterables* are equal to each other,
    which means that they contain the same elements in the same order.

    The function is useful for comparing iterables of different data types
    or iterables that do not support equality checks.

    >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
    True

    >>> iequals("abc", "acb")
    False

    Not to be confused with :func:`all_equal`, which checks whether all
    elements of iterable are equal to each other.

    """
    try:
        # strict=True makes zip raise ValueError when lengths differ.
        groups = zip(*iterables, strict=True)
        return all(all_equal(group) for group in groups)
    except ValueError:
        # Iterables of different lengths cannot be equal.
        return False
def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

    >>> list(distinct_combinations([0, 0, 1], 2))
    [(0, 0), (0, 1)]

    Equivalent to ``set(combinations(iterable))``, except duplicates are not
    generated and thrown away. For larger input sequences this is much more
    efficient.

    """
    if r < 0:
        raise ValueError('r must be non-negative')
    elif r == 0:
        # The only 0-length combination is the empty tuple.
        yield ()
        return
    pool = tuple(iterable)
    # Iterative backtracking. *generators* is a stack with one generator per
    # filled position of the combination; each one yields (index, item)
    # pairs from the pool with duplicate items filtered out, so no duplicate
    # combination is ever produced.
    generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
    current_combo = [None] * r
    level = 0  # position in current_combo being filled
    while generators:
        try:
            cur_idx, p = next(generators[-1])
        except StopIteration:
            # No more choices at this position: backtrack one level.
            generators.pop()
            level -= 1
            continue
        current_combo[level] = p
        if level + 1 == r:
            # Combination complete; emit a snapshot.
            yield tuple(current_combo)
        else:
            # Descend: the next position chooses from the items after
            # cur_idx, again with duplicates filtered.
            generators.append(
                unique_everseen(
                    enumerate(pool[cur_idx + 1 :], cur_idx + 1),
                    key=itemgetter(1),
                )
            )
            level += 1
def filter_except(validator, iterable, *exceptions):
    """Yield the items from *iterable* for which the *validator* function does
    not raise one of the specified *exceptions*.

    *validator* is called for each item in *iterable*.
    It should be a function that accepts one argument and raises an exception
    if that item is not valid.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(filter_except(int, iterable, ValueError, TypeError))
    ['1', '2', '4']

    If an exception other than one given by *exceptions* is raised by
    *validator*, it is raised like normal.
    """
    for item in iterable:
        try:
            validator(item)
        except exceptions:
            # Item failed validation with an expected exception: drop it.
            continue
        yield item
def map_except(function, iterable, *exceptions):
    """Transform each item from *iterable* with *function* and yield the
    result, unless *function* raises one of the specified *exceptions*.

    *function* is called to transform each item in *iterable*.
    It should accept one argument.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(map_except(int, iterable, ValueError, TypeError))
    [1, 2, 4]

    If an exception other than one given by *exceptions* is raised by
    *function*, it is raised like normal.
    """
    for item in iterable:
        # Skip items whose transformation raises an expected exception.
        with suppress(*exceptions):
            yield function(item)
def map_if(iterable, pred, func, func_else=None):
    """Evaluate each item from *iterable* using *pred*. If the result is
    equivalent to ``True``, transform the item with *func* and yield it.
    Otherwise, transform the item with *func_else* and yield it.

    *pred*, *func*, and *func_else* should each be functions that accept
    one argument. By default, *func_else* is the identity function.

    >>> from math import sqrt
    >>> iterable = list(range(-5, 5))
    >>> iterable
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
    >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
    >>> list(map_if(iterable, lambda x: x >= 0,
    ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
    [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
    """
    if func_else is None:
        # Default: pass non-matching items through unchanged.
        for item in iterable:
            yield func(item) if pred(item) else item
    else:
        for item in iterable:
            yield func(item) if pred(item) else func_else(item)
def _sample_unweighted(iterator, k, strict):
    # Algorithm L in the 1994 paper by Kim-Hung Li:
    # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
    # Returns a k-length list sampled uniformly from *iterator*.

    # The first k items fill the reservoir directly.
    reservoir = list(islice(iterator, k))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')
    W = 1.0

    # next() raises StopIteration when the iterator is exhausted; suppress()
    # ends the loop then, leaving the reservoir as the sample.
    with suppress(StopIteration):
        while True:
            W *= random() ** (1 / k)
            # Number of items to skip before the next replacement.
            skip = floor(log(random()) / log1p(-W))
            element = next(islice(iterator, skip, None))
            # Replace a uniformly-chosen reservoir slot.
            reservoir[randrange(k)] = element

    # Shuffle so the output order doesn't reflect the input order.
    shuffle(reservoir)
    return reservoir
def _sample_weighted(iterator, k, weights, strict):
    # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
    # "Weighted random sampling with a reservoir".
    # Returns a k-length list sampled from *iterator* in proportion to
    # *weights*, without replacement.

    # Log-transform for numerical stability for weights that are small/large
    weight_keys = (log(random()) / weight for weight in weights)

    # Fill up the reservoir (collection of samples) with the first `k`
    # weight-keys and elements, then heapify the list.
    reservoir = take(k, zip(weight_keys, iterator))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Min-heap ordered by weight-key: reservoir[0] is always the candidate
    # to be evicted next.
    heapify(reservoir)

    # The number of jumps before changing the reservoir is a random variable
    # with an exponential distribution. Sample it using random() and logs.
    smallest_weight_key, _ = reservoir[0]
    weights_to_skip = log(random()) / smallest_weight_key

    for weight, element in zip(weights, iterator):
        if weight >= weights_to_skip:
            # The notation here is consistent with the paper, but we store
            # the weight-keys in log-space for better numerical stability.
            smallest_weight_key, _ = reservoir[0]
            t_w = exp(weight * smallest_weight_key)
            r_2 = uniform(t_w, 1)  # generate U(t_w, 1)
            weight_key = log(r_2) / weight
            # Evict the current minimum and re-draw the jump length.
            heapreplace(reservoir, (weight_key, element))
            smallest_weight_key, _ = reservoir[0]
            weights_to_skip = log(random()) / smallest_weight_key
        else:
            weights_to_skip -= weight

    # Equivalent to [element for weight_key, element in sorted(reservoir)],
    # but with the sort order randomized.
    ret = [element for weight_key, element in reservoir]
    shuffle(ret)
    return ret
def _sample_counted(population, k, counts, strict):
    # Reservoir sampling over a population given as (element, count) pairs,
    # without expanding the repeats into individual items.  Uses the same
    # skip-based scheme as _sample_unweighted.
    element = None
    remaining = 0

    def feed(i):
        # Advance *i* steps ahead and consume an element
        nonlocal element, remaining

        # Walk through whole repeat-blocks until step i falls inside one.
        # next() propagates StopIteration when the population runs out.
        while i + 1 > remaining:
            i = i - remaining
            element = next(population)
            remaining = next(counts)
        remaining -= i + 1
        return element

    # Fill the reservoir with the first k (virtual) items.
    with suppress(StopIteration):
        reservoir = []
        for _ in range(k):
            reservoir.append(feed(0))

    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= random() ** (1 / k)
            # Number of virtual items to skip before the next replacement.
            skip = floor(log(random()) / log1p(-W))
            element = feed(skip)
            # Replace a uniformly-chosen reservoir slot.
            reservoir[randrange(k)] = element

    # Shuffle so the output order doesn't reflect the input order.
    shuffle(reservoir)
    return reservoir
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a *k*-length list of elements chosen (without replacement)
    from the *iterable*.

    Similar to :func:`random.sample`, but works on inputs that aren't
    indexable (such as sets and dictionaries) and on inputs where the
    size isn't known in advance (such as generators).

    >>> iterable = range(100)
    >>> sample(iterable, 5)  # doctest: +SKIP
    [81, 60, 96, 16, 4]

    For iterables with repeated elements, you may supply *counts* to
    indicate the repeats.

    >>> iterable = ['a', 'b']
    >>> counts = [3, 4]  # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
    >>> sample(iterable, k=3, counts=counts)  # doctest: +SKIP
    ['a', 'a', 'b']

    An iterable with *weights* may be given:

    >>> iterable = range(100)
    >>> weights = (i * i + 1 for i in range(100))
    >>> sampled = sample(iterable, 5, weights=weights)  # doctest: +SKIP
    [79, 67, 74, 66, 78]

    Weighted selections are made without replacement.
    After an element is selected, it is removed from the pool and the
    relative weights of the other elements increase (this
    does not match the behavior of :func:`random.sample`'s *counts*
    parameter). Note that *weights* may not be used with *counts*.

    If the length of *iterable* is less than *k*,
    ``ValueError`` is raised if *strict* is ``True`` and
    all elements are returned (in shuffled order) if *strict* is ``False``.

    By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
    technique is used. When *weights* are provided,
    `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used instead.

    Notes on reproducibility:

    * The algorithms rely on inexact floating-point functions provided
      by the underlying math library (e.g. ``log``, ``log1p``, and ``pow``).
      Those functions can `produce slightly different results
      <https://members.loria.fr/PZimmermann/papers/accuracy.pdf>`_ on
      different builds. Accordingly, selections can vary across builds
      even for the same seed.

    * The algorithms loop over the input and make selections based on
      ordinal position, so selections from unordered collections (such as
      sets) won't reproduce across sessions on the same platform using the
      same seed. For example, this won't reproduce::

        >> seed(8675309)
        >> sample(set('abcdefghijklmnopqrstuvwxyz'), 10)
        ['c', 'p', 'e', 'w', 's', 'a', 'j', 'd', 'n', 't']

    """
    iterator = iter(iterable)

    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []

    if weights is not None and counts is not None:
        raise TypeError('weights and counts are mutually exclusive')

    # Dispatch to the reservoir-sampling variant matching the arguments.
    if weights is not None:
        return _sample_weighted(iterator, k, iter(weights), strict)
    if counts is not None:
        return _sample_counted(iterator, k, iter(counts), strict)
    return _sample_unweighted(iterator, k, strict)
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Returns ``True`` if the items of iterable are in sorted order, and
    ``False`` otherwise. *key* and *reverse* have the same meaning that they do
    in the built-in :func:`sorted` function.

    >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
    True
    >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
    False

    If *strict*, tests for strict sorting, that is, returns ``False`` if equal
    elements are found:

    >>> is_sorted([1, 2, 2])
    True
    >>> is_sorted([1, 2, 2], strict=True)
    False

    The function returns ``False`` after encountering the first out-of-order
    item, which means it may produce results that differ from the built-in
    :func:`sorted` function for objects with unusual comparison dynamics
    (like ``math.nan``). If there are no out-of-order items, the iterable is
    exhausted.
    """
    items = iterable if (key is None) else map(key, iterable)

    # Pair each item with its successor: a lags one position behind b.
    a, b = tee(items)
    next(b, None)
    if reverse:
        a, b = b, a

    if strict:
        # Every item must be strictly less than its successor.
        return all(map(lt, a, b))
    # No item may be strictly greater than its successor.
    return not any(map(lt, b, a))
class AbortThread(BaseException):
    # Raised by callback_iter's callback (in the worker thread) once the
    # iterator has been aborted.  Derives from BaseException rather than
    # Exception so that wrapped functions catching Exception won't
    # accidentally suppress it.
    pass
class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

    >>> def func(callback=None):
    ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
    ...         if callback:
    ...             callback(i, c)
    ...     return 4


    Use ``with callback_iter(func)`` to get an iterator over the parameters
    that are delivered to the callback.

    >>> with callback_iter(func) as it:
    ...     for args, kwargs in it:
    ...         print(args)
    (1, 'a')
    (2, 'b')
    (3, 'c')

    The function will be called in a background thread. The ``done`` property
    indicates whether it has completed execution.

    >>> it.done
    True

    If it completes successfully, its return value will be available
    in the ``result`` property.

    >>> it.result
    4

    Notes:

    * If the function uses some keyword argument besides ``callback``, supply
      *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently the it is polled for
      output.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        self._func = func
        self._callback_kwd = callback_kwd
        self._aborted = False
        self._future = None
        self._wait_seconds = wait_seconds
        # Lazily import concurrent.future
        self._executor = __import__(
            'concurrent.futures'
        ).futures.ThreadPoolExecutor(max_workers=1)
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Flag the callback to raise AbortThread on its next invocation,
        # then wait for the worker thread to finish.
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    @property
    def done(self):
        # True once the wrapped function has finished executing.
        if self._future is None:
            return False
        return self._future.done()

    @property
    def result(self):
        if not self.done:
            raise RuntimeError('Function has not yet completed')

        # Re-raises the wrapped function's exception, if it raised one.
        return self._future.result()

    def _reader(self):
        # Generator backing __next__: yields (args, kwargs) tuples as the
        # callback receives them.
        q = Queue()

        def callback(*args, **kwargs):
            # Runs in the worker thread on every callback invocation.
            if self._aborted:
                raise AbortThread('canceled by user')

            q.put((args, kwargs))

        self._future = self._executor.submit(
            self._func, **{self._callback_kwd: callback}
        )

        # Poll the queue, yielding items as they arrive, until the wrapped
        # function completes.  The timeout keeps the done() check responsive.
        while True:
            try:
                item = q.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                q.task_done()
                yield item

            if self._future.done():
                break

        # Drain anything queued between the last poll and completion.
        remaining = []
        while True:
            try:
                item = q.get_nowait()
            except Empty:
                break
            else:
                q.task_done()
                remaining.append(item)
        q.join()
        yield from remaining
def windowed_complete(iterable, n):
    """
    Yield ``(beginning, middle, end)`` tuples, where:

    * Each ``middle`` has *n* items from *iterable*
    * Each ``beginning`` has the items before the ones in ``middle``
    * Each ``end`` has the items after the ones in ``middle``

    >>> iterable = range(7)
    >>> n = 3
    >>> for beginning, middle, end in windowed_complete(iterable, n):
    ...     print(beginning, middle, end)
    () (0, 1, 2) (3, 4, 5, 6)
    (0,) (1, 2, 3) (4, 5, 6)
    (0, 1) (2, 3, 4) (5, 6)
    (0, 1, 2) (3, 4, 5) (6,)
    (0, 1, 2, 3) (4, 5, 6) ()

    Note that *n* must be at least 0 and most equal to the length of
    *iterable*.

    This function will exhaust the iterable and may require significant
    storage.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    seq = tuple(iterable)
    size = len(seq)

    if n > size:
        raise ValueError('n must be <= len(seq)')

    # Slide the n-item middle window across the sequence.
    for start in range(size - n + 1):
        stop = start + n
        yield seq[:start], seq[start:stop], seq[stop:]
def all_unique(iterable, key=None):
    """
    Returns ``True`` if all the elements of *iterable* are unique (no two
    elements are equal).

    >>> all_unique('ABCB')
    False

    If a *key* function is specified, it will be used to make comparisons.

    >>> all_unique('ABCb')
    True
    >>> all_unique('ABCb', str.lower)
    False

    The function returns as soon as the first non-unique element is
    encountered. Iterables with a mix of hashable and unhashable items can
    be used, but the function will be slower for unhashable items.
    """
    hashables = set()
    unhashables = []
    elements = map(key, iterable) if key else iterable
    for element in elements:
        try:
            if element in hashables:
                return False
            hashables.add(element)
        except TypeError:
            # Unhashable: fall back to a slower list-membership check.
            if element in unhashables:
                return False
            unhashables.append(element)
    return True
def nth_product(index, *args):
    """Equivalent to ``list(product(*args))[index]``.

    The products of *args* can be ordered lexicographically.
    :func:`nth_product` computes the product at sort position *index* without
    computing the previous products.

    >>> nth_product(8, range(2), range(2), range(2), range(2))
    (1, 0, 0, 0)

    ``IndexError`` will be raised if the given *index* is invalid.
    """
    # Work from the least-significant (last) pool to the most significant.
    pools = [tuple(pool) for pool in reversed(args)]
    sizes = [len(pool) for pool in pools]

    total = reduce(mul, sizes)

    # Support negative indexing, as list indexing does.
    if index < 0:
        index += total

    if not 0 <= index < total:
        raise IndexError

    # Treat *index* as a mixed-radix number whose digits select pool items.
    result = []
    for pool, size in zip(pools, sizes):
        index, digit = divmod(index, size)
        result.append(pool[digit])

    return tuple(reversed(result))
def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]```

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`nth_permutation`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences.

    >>> nth_permutation('ghijk', 2, 5)
    ('h', 'i')

    ``ValueError`` will be raised If *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = list(iterable)
    n = len(pool)

    if r is None or r == n:
        r, num_perms = n, factorial(n)
    elif not 0 <= r < n:
        raise ValueError
    else:
        num_perms = perm(n, r)
    assert num_perms > 0  # factorial(n)>0, and r<n so perm(n,r) is never zero

    # Support negative indexing, as list indexing does.
    if index < 0:
        index += num_perms

    if not 0 <= index < num_perms:
        raise IndexError

    # Decode *index* into factorial-base digits; digit d selects which of
    # the remaining items is taken at each position.
    digits = [0] * r
    quotient = index * factorial(n) // num_perms if r < n else index
    for d in range(1, n + 1):
        quotient, digit = divmod(quotient, d)
        if 0 <= n - d < r:
            digits[n - d] = digit
        if quotient == 0:
            break

    # Popping mutates pool, so each digit indexes the items still unused.
    return tuple(pool.pop(i) for i in digits)
def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.


    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`nth_combination_with_replacement`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences with replacement.

    >>> nth_combination_with_replacement(range(5), 3, 5)
    (0, 1, 1)

    ``ValueError`` will be raised If *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if not 0 <= r <= n:
        raise ValueError

    total = comb(n + r - 1, r)

    # Support negative indexing, as list indexing does.
    if index < 0:
        index += total

    if not 0 <= index < total:
        raise IndexError

    result = []
    i = 0
    while r:
        r -= 1
        # Advance the pool position until the block of combinations starting
        # with pool[i] contains *index*, shrinking the index as blocks of
        # earlier-starting combinations are skipped.
        while n >= 0:
            num_combs = comb(n + r - 1, r)
            if index < num_combs:
                break
            n -= 1
            i += 1
            index -= num_combs
        result.append(pool[i])

    return tuple(result)
def value_chain(*args):
    """Yield all arguments passed to the function in the same order in which
    they were passed. If an argument itself is iterable then iterate over its
    values.

    >>> list(value_chain(1, 2, 3, [4, 5, 6]))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and are emitted
    as-is:

    >>> list(value_chain('12', '34', ['56', '78']))
    ['12', '34', '56', '78']

    Pre- or postpend a single element to an iterable:

    >>> list(value_chain(1, [2, 3, 4, 5, 6]))
    [1, 2, 3, 4, 5, 6]
    >>> list(value_chain([1, 2, 3, 4, 5], 6))
    [1, 2, 3, 4, 5, 6]

    Multiple levels of nesting are not flattened.

    """
    for value in args:
        # Strings are iterable but treated as scalars here.
        if isinstance(value, (str, bytes)):
            yield value
        else:
            try:
                yield from value
            except TypeError:
                # Not iterable: emit the value itself.
                yield value
def product_index(element, *args):
    """Equivalent to ``list(product(*args)).index(element)``

    The products of *args* can be ordered lexicographically.
    :func:`product_index` computes the first index of *element* without
    computing the previous products.

    >>> product_index([8, 2], range(10), range(5))
    42

    ``ValueError`` will be raised if the given *element* isn't in the product
    of *args*.
    """
    elements = tuple(element)
    pools = tuple(map(tuple, args))
    if len(elements) != len(pools):
        raise ValueError('element is not a product of args')

    # Accumulate the index as a mixed-radix number: each pool contributes
    # one digit, whose base is that pool's length.
    index = 0
    for pool, item in zip(pools, elements):
        position = pool.index(item)  # raises ValueError if absent
        index = index * len(pool) + position
    return index
def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`combination_index` computes the index of the
    first *element*, without computing the previous combinations.

    >>> combination_index('adf', 'abcdefg')
    10

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations of *iterable*.
    """
    element = enumerate(element)
    k, y = next(element, (None, None))
    if k is None:
        # An empty element is the first (index 0) combination.
        return 0

    # Match each item of *element* against *iterable* in order, recording
    # the position in *iterable* where each was found.
    indexes = []
    pool = enumerate(iterable)
    for n, x in pool:
        if x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                # All items of element matched.
                break
            else:
                k = tmp
    else:
        # Iterable exhausted before matching every item of element.
        raise ValueError('element is not a combination of iterable')

    # Exhaust the pool so that n becomes the index of the last item of
    # iterable; the default covers the case where the pool is already spent.
    n, _ = last(pool, default=(n, None))

    # Sum binomial terms for the positions after each matched index, then
    # subtract from the total to get the lexicographic rank.
    index = 1
    for i, j in enumerate(reversed(indexes), start=1):
        j = n - j
        if i <= j:
            index += comb(j, i)

    return comb(n + 1, k + 1) - index
def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``

    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`combination_with_replacement_index`
    computes the index of the first *element*, without computing the previous
    combinations with replacement.

    >>> combination_with_replacement_index('adf', 'abcdefg')
    20

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations with replacement of *iterable*.
    """
    element = tuple(element)
    l = len(element)
    element = enumerate(element)

    k, y = next(element, (None, None))
    if k is None:
        # An empty element is the first (index 0) combination.
        return 0

    # Match each item of *element* against the pool in order, recording the
    # pool position of every match.  Repeats in *element* match the same
    # pool position repeatedly (the inner while loop).
    indexes = []
    pool = tuple(iterable)
    for n, x in enumerate(pool):
        while x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                break
            else:
                k = tmp
        if y is None:
            # Every item of element has been matched.
            break
    else:
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    # Count how many times each pool position was used.
    n = len(pool)
    occupations = [0] * n
    for position in indexes:
        occupations[position] += 1

    # Accumulate the lexicographic rank from the occupation counts.
    index = 0
    cumulative_sum = 0
    for k in range(1, n):
        cumulative_sum += occupations[k - 1]
        j = l + n - 1 - k - cumulative_sum
        i = n - k
        if i <= j:
            index += comb(j, i)

    return index
def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)```

    The length-*r* permutations of *iterable* have a lexicographic order;
    this computes the rank of *element* in that order directly, without
    generating the permutations that come before it.

    >>> permutation_index([1, 3, 2], range(5))
    19

    ``ValueError`` will be raised if the given *element* isn't one of the
    permutations of *iterable*.
    """
    # Factorial number system: each chosen item contributes its position
    # among the not-yet-used items, weighted by the number of remaining
    # choices.
    remaining = list(iterable)
    sizes = range(len(remaining), -1, -1)
    index = 0
    for size, value in zip(sizes, element):
        position = remaining.index(value)
        index = index * size + position
        del remaining[position]

    return index
class countable:
    """Iterator wrapper that counts how many items have been consumed.

    The running total is exposed as the ``items_seen`` attribute, which
    starts at ``0``:

    >>> iterable = map(str, range(10))
    >>> it = countable(iterable)
    >>> it.items_seen
    0
    >>> next(it), next(it)
    ('0', '1')
    >>> list(it)
    ['2', '3', '4', '5', '6', '7', '8', '9']
    >>> it.items_seen
    10
    """

    def __init__(self, iterable):
        # items_seen is public; the wrapped iterator stays private.
        self.items_seen = 0
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        # Pull first so a StopIteration doesn't bump the count.
        value = next(self._iterator)
        self.items_seen += 1
        return value
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.
    Items are distributed such the lengths of the lists differ by at most
    1 item.

    >>> iterable = [1, 2, 3, 4, 5, 6, 7]
    >>> n = 3
    >>> list(chunked_even(iterable, n))  # List lengths: 3, 2, 2
    [[1, 2, 3], [4, 5], [6, 7]]
    >>> list(chunked(iterable, n))  # List lengths: 3, 3, 1
    [[1, 2, 3], [4, 5, 6], [7]]

    """
    iterator = iter(iterable)

    # Initialize a buffer to process the chunks while keeping
    # some back to fill any underfilled chunks
    min_buffer = (n - 1) * (n - 2)
    buffer = list(islice(iterator, min_buffer))

    # Append items until we have a completed chunk.
    # `map(buffer.append, iterator)` lazily appends one item per element
    # pulled from it; striding over that stream with `islice(..., n, None,
    # n)` runs the loop body once per (at least) n appends, so each pass
    # emits one full n-sized chunk while keeping >= min_buffer items in
    # reserve for the evening-out logic below.
    for _ in islice(map(buffer.append, iterator), n, None, n):
        yield buffer[:n]
        del buffer[:n]

    # Check if any chunks need addition processing
    if not buffer:
        return
    length = len(buffer)

    # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
    q, r = divmod(length, n)
    num_lists = q + (1 if r > 0 else 0)  # ceil(length / n) chunks remain
    q, r = divmod(length, num_lists)
    full_size = q + (1 if r > 0 else 0)  # ceil(length / num_lists)
    partial_size = full_size - 1
    # Number of chunks that receive the extra item; the remaining
    # num_lists - num_full chunks get partial_size items each.
    num_full = length - partial_size * num_lists

    # Yield chunks of full size
    partial_start_idx = num_full * full_size
    if full_size > 0:
        for i in range(0, partial_start_idx, full_size):
            yield buffer[i : i + full_size]

    # Yield chunks of partial size
    if partial_size > 0:
        for i in range(partial_start_idx, length, partial_size):
            yield buffer[i : i + partial_size]
4564def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
4565 """A version of :func:`zip` that "broadcasts" any scalar
4566 (i.e., non-iterable) items into output tuples.
4568 >>> iterable_1 = [1, 2, 3]
4569 >>> iterable_2 = ['a', 'b', 'c']
4570 >>> scalar = '_'
4571 >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
4572 [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]
4574 The *scalar_types* keyword argument determines what types are considered
4575 scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
4576 treat strings and byte strings as iterable:
4578 >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
4579 [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]
4581 If the *strict* keyword argument is ``True``, then
4582 ``ValueError`` will be raised if any of the iterables have
4583 different lengths.
4584 """
4586 def is_scalar(obj):
4587 if scalar_types and isinstance(obj, scalar_types):
4588 return True
4589 try:
4590 iter(obj)
4591 except TypeError:
4592 return True
4593 else:
4594 return False
4596 size = len(objects)
4597 if not size:
4598 return
4600 new_item = [None] * size
4601 iterables, iterable_positions = [], []
4602 for i, obj in enumerate(objects):
4603 if is_scalar(obj):
4604 new_item[i] = obj
4605 else:
4606 iterables.append(iter(obj))
4607 iterable_positions.append(i)
4609 if not iterables:
4610 yield tuple(objects)
4611 return
4613 for item in zip(*iterables, strict=strict):
4614 for i, new_item[i] in zip(iterable_positions, item):
4615 pass
4616 yield tuple(new_item)
def unique_in_window(iterable, n, key=None):
    """Yield the items from *iterable* that haven't been seen recently.
    *n* is the size of the sliding window.

    >>> iterable = [0, 1, 0, 2, 3, 0]
    >>> n = 3
    >>> list(unique_in_window(iterable, n))
    [0, 1, 2, 3, 0]

    The *key* function, if provided, will be used to determine uniqueness:

    >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
    ['a', 'b', 'c', 'd', 'a']

    An item is emitted only if its key does not already occur in the
    window of the previous ``n - 1`` keys.  With ``n == 1`` the function
    is memoryless:

    >>> list(unique_in_window('aab', n=1))
    ['a', 'a', 'b']

    The items in *iterable* must be hashable.

    """
    if n <= 0:
        raise ValueError('n must be greater than 0')

    window = deque(maxlen=n)
    counts = Counter()
    for item in iterable:
        # Retire the oldest key before it is pushed out of the window.
        if len(window) == n:
            oldest = window[0]
            counts[oldest] -= 1
            if not counts[oldest]:
                del counts[oldest]

        k = item if key is None else key(item)
        if k not in counts:
            yield item
        counts[k] += 1
        window.append(k)
def duplicates_everseen(iterable, key=None):
    """Yield duplicate elements after their first appearance.

    >>> list(duplicates_everseen('mississippi'))
    ['s', 'i', 's', 's', 'i', 'p', 'i']
    >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    # Hashable keys go in a set; unhashable keys fall back to a list
    # (linear scan), mirroring unique_everseen.
    hashable_seen = set()
    unhashable_seen = []

    for item in iterable:
        k = item if key is None else key(item)
        try:
            novel = k not in hashable_seen
            if novel:
                hashable_seen.add(k)
        except TypeError:
            novel = k not in unhashable_seen
            if novel:
                unhashable_seen.append(k)
        if not novel:
            yield item
def duplicates_justseen(iterable, key=None):
    """Yields serially-duplicate elements after their first appearance.

    >>> list(duplicates_justseen('mississippi'))
    ['s', 's', 'p']
    >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    This function is analogous to :func:`unique_justseen`.
    """
    # For every run of equal-keyed items, lazily skip the first item and
    # emit the rest.
    return chain.from_iterable(
        islice(group, 1, None) for _, group in groupby(iterable, key)
    )
def classify_unique(iterable, key=None):
    """Classify each element in terms of its uniqueness.

    For each element in the input iterable, return a 3-tuple consisting of:

    1. The element itself
    2. ``False`` if the element is equal to the one preceding it in the input,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
    3. ``False`` if this element has been seen anywhere in the input before,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)

    >>> list(classify_unique('otto'))    # doctest: +NORMALIZE_WHITESPACE
    [('o', True, True),
     ('t', True, True),
     ('t', False, False),
     ('o', True, False)]

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []
    prior = None

    for position, item in enumerate(iterable):
        k = item if key is None else key(item)

        # "Just seen": differs from the immediately preceding key (the
        # first item is trivially unique).
        justseen_unique = position == 0 or prior != k
        prior = k

        # "Ever seen": first occurrence anywhere in the input.  Unhashable
        # keys fall back to a list with a linear scan.
        everseen_unique = False
        try:
            if k not in hashable_seen:
                hashable_seen.add(k)
                everseen_unique = True
        except TypeError:
            if k not in unhashable_seen:
                unhashable_seen.append(k)
                everseen_unique = True

        yield item, justseen_unique, everseen_unique
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Returns both the smallest and largest items from an iterable
    or from two or more arguments.

    >>> minmax([3, 1, 5])
    (1, 5)

    >>> minmax(4, 2, 6)
    (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

    >>> minmax([5, 30], key=str)  # '30' sorts before '5'
    (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

    >>> minmax([], default=(0, 0))
    (0, 0)

    Otherwise ``ValueError`` is raised.

    This function makes a single pass over the input elements and takes care to
    minimize the number of comparisons made during processing.

    Note that unlike the builtin ``max`` function, which always returns the first
    item with the maximum value, this function may return another item when there are
    ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    # Called either as minmax(iterable) or minmax(a, b, ...).
    iterable = (iterable_or_value, *others) if others else iterable_or_value

    it = iter(iterable)

    try:
        lo = hi = next(it)
    except StopIteration as exc:
        if default is _marker:
            raise ValueError(
                '`minmax()` argument is an empty iterable. '
                'Provide a `default` value to suppress this error.'
            ) from exc
        return default

    # Different branches depending on the presence of key. This saves a lot
    # of unimportant copies which would slow the "key=None" branch
    # significantly down.
    if key is None:
        # Take items two at a time: order the pair first, then compare the
        # smaller against lo and the larger against hi -- 3 comparisons per
        # pair instead of 4.  zip_longest(it, it) pairs consecutive items;
        # an odd-length tail is padded with lo, which can't change either
        # bound, so the padding is a harmless no-op.
        for x, y in zip_longest(it, it, fillvalue=lo):
            if y < x:
                x, y = y, x
            if x < lo:
                lo = x
            if hi < y:
                hi = y

    else:
        # Same pairing strategy, but carry the key values alongside the
        # items so each key is computed exactly once per new item.
        lo_key = hi_key = key(lo)

        for x, y in zip_longest(it, it, fillvalue=lo):
            x_key, y_key = key(x), key(y)

            if y_key < x_key:
                x, y, x_key, y_key = y, x, y_key, x_key
            if x_key < lo_key:
                lo, lo_key = x, x_key
            if hi_key < y_key:
                hi, hi_key = y, y_key

    return lo, hi
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* with a combined size limited by
    *max_size*.

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10))
    [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If a *max_count* is supplied, the number of items per batch is also
    limited:

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10, max_count = 2))
    [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    If a *get_len* function is supplied, use that instead of :func:`len` to
    determine item size.

    If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
    than *max_size*. Otherwise, allow single items to exceed *max_size*.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    current = []
    current_size = 0

    for item in iterable:
        size = get_len(item)
        if strict and size > max_size:
            raise ValueError('item size exceeds maximum size')

        # Flush the running batch when adding this item would overflow
        # either constraint.  A non-strict oversized item still gets its
        # own (overfull) batch.
        if current and (
            current_size + size > max_size or len(current) == max_count
        ):
            yield tuple(current)
            current = []
            current_size = 0

        current.append(item)
        current_size += size

    if current:
        yield tuple(current)
def gray_product(*iterables):
    """Like :func:`itertools.product`, but return tuples in an order such
    that only one element in the generated tuple changes from one iteration
    to the next.

    >>> list(gray_product('AB','CD'))
    [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    This function consumes all of the input iterables before producing output.
    If any of the input iterables have fewer than two items, ``ValueError``
    is raised.

    For information on the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    all_iterables = tuple(tuple(x) for x in iterables)
    iterable_count = len(all_iterables)
    # The algorithm requires every coordinate to have at least two values;
    # a 0- or 1-item iterable would break the direction-reversal step.
    for iterable in all_iterables:
        if len(iterable) < 2:
            raise ValueError("each iterable must have two or more items")

    # This is based on "Algorithm H" from section 7.2.1.1, page 20.
    # a holds the indexes of the source iterables for the n-tuple to be yielded
    # f is the array of "focus pointers"
    # o is the array of "directions"
    a = [0] * iterable_count
    f = list(range(iterable_count + 1))
    o = [1] * iterable_count
    while True:
        yield tuple(all_iterables[i][a[i]] for i in range(iterable_count))
        # j is the coordinate to change next; reset the first focus pointer.
        j = f[0]
        f[0] = 0
        if j == iterable_count:
            # Focus ran off the end: every tuple has been produced.
            break
        # Step coordinate j one position in its current direction.
        a[j] = a[j] + o[j]
        if a[j] == 0 or a[j] == len(all_iterables[j]) - 1:
            # Coordinate j hit either end: reverse its direction and pass
            # the focus along to the next coordinate.
            o[j] = -o[j]
            f[j] = f[j + 1]
            f[j + 1] = j + 1
def partial_product(*iterables):
    """Yields tuples containing one item from each iterator, with subsequent
    tuples changing a single item at a time by advancing each iterator until it
    is exhausted. This sequence guarantees every value in each iterable is
    output at least once without generating all possible combinations.

    This may be useful, for example, when testing an expensive function.

    >>> list(partial_product('AB', 'C', 'DEF'))
    [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]
    """
    iterators = [iter(iterable) for iterable in iterables]

    # Seed with the first item of every iterable; any empty input means
    # there is nothing to yield at all.
    try:
        current = [next(iterator) for iterator in iterators]
    except StopIteration:
        return
    yield tuple(current)

    # Exhaust each position in turn, mutating only that slot.
    for position, iterator in enumerate(iterators):
        for value in iterator:
            current[position] = value
            yield tuple(current)
def takewhile_inclusive(predicate, iterable):
    """A variant of :func:`takewhile` that also yields the first failing
    element.

    >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
    [1, 4, 6]

    :func:`takewhile` would return ``[1, 4]``.
    """
    for item in iterable:
        # Emit first, then test -- so the item that breaks the predicate
        # is still included.
        yield item
        if not predicate(item):
            return
def outer_product(func, xs, ys, *args, **kwargs):
    """A generalized outer product that applies a binary function to all
    pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
    columns.
    Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.

    Multiplication table:

    >>> list(outer_product(mul, range(1, 4), range(1, 6)))
    [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]

    Cross tabulation:

    >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
    >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
    >>> pair_counts = Counter(zip(xs, ys))
    >>> count_rows = lambda x, y: pair_counts[x, y]
    >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
    [(2, 3, 0), (1, 0, 4)]

    Usage with ``*args`` and ``**kwargs``:

    >>> animals = ['cat', 'wolf', 'mouse']
    >>> list(outer_product(min, animals, animals, key=len))
    [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
    """
    # ys is paired with every x by product(), so it must be re-iterable
    # and its length known: materialize it once.
    ys = tuple(ys)
    # product(xs, ys) emits pairs in row-major order; applying func to each
    # pair and regrouping the flat result stream into rows of len(ys)
    # produces the matrix one lazily-evaluated row at a time.
    return batched(
        starmap(lambda x, y: func(x, y, *args, **kwargs), product(xs, ys)),
        n=len(ys),
    )
def iter_suppress(iterable, *exceptions):
    """Yield each of the items from *iterable*. If the iteration raises one of
    the specified *exceptions*, that exception will be suppressed and iteration
    will stop.

    >>> from itertools import chain
    >>> def breaks_at_five(x):
    ...     while True:
    ...         if x >= 5:
    ...             raise RuntimeError
    ...         yield x
    ...         x += 1
    >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
    >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
    >>> list(chain(it_1, it_2))
    [1, 2, 3, 4, 2, 3, 4]
    """
    # suppress() swallows exactly the listed exception types; when one
    # fires, the generator simply ends.
    with suppress(*exceptions):
        yield from iterable
def filter_map(func, iterable):
    """Apply *func* to every element of *iterable*, yielding only those which
    are not ``None``.

    >>> elems = ['1', 'a', '2', 'b', '3']
    >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
    [1, 2, 3]
    """
    return (result for result in map(func, iterable) if result is not None)
def powerset_of_sets(iterable, *, baseset=set):
    """Yields all possible subsets of the iterable.

    >>> list(powerset_of_sets([1, 2, 3]))  # doctest: +SKIP
    [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
    >>> list(powerset_of_sets([1, 1, 0]))  # doctest: +SKIP
    [set(), {1}, {0}, {0, 1}]

    :func:`powerset_of_sets` takes care to minimize the number
    of hash operations performed.

    The *baseset* parameter determines what kind of sets are
    constructed, either *set* or *frozenset*.
    """
    # Hash each input item exactly once by boxing it into a singleton
    # frozenset; dict.fromkeys deduplicates while preserving order.
    singletons = tuple(
        dict.fromkeys(frozenset((item,)) for item in iterable)
    )
    merge = baseset().union
    # All subsets, by size, as unions of the singleton frozensets.
    return chain.from_iterable(
        (merge(*chosen) for chosen in combinations(singletons, size))
        for size in range(len(singletons) + 1)
    )
def join_mappings(**field_to_map):
    """
    Joins multiple mappings together using their common keys.

    >>> user_scores = {'elliot': 50, 'claris': 60}
    >>> user_times = {'elliot': 30, 'claris': 40}
    >>> join_mappings(score=user_scores, time=user_times)
    {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    joined = {}
    for field, mapping in field_to_map.items():
        for key, value in mapping.items():
            joined.setdefault(key, {})[field] = value
    return joined
def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.

    Note: v1 and v2 must be re-iterable sequences -- each is iterated
    twice per component below.
    """
    # (a+bi)(c+di) = (ac - bd) + (ad + bc)i.  Express each component as a
    # single high-precision sum of products over concatenated streams.
    real_left = chain((z.real for z in v1), (-z.imag for z in v1))
    real_right = chain((z.real for z in v2), (z.imag for z in v2))
    imag_left = chain((z.real for z in v1), (z.imag for z in v1))
    imag_right = chain((z.imag for z in v2), (z.real for z in v2))
    return complex(
        _fsumprod(real_left, real_right), _fsumprod(imag_left, imag_right)
    )
def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the corresponding transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
    >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
    >>> all(map(cmath.isclose, dft(xarr), Xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`idft` for the inverse Discrete Fourier Transform.
    """
    N = len(xarr)
    # Precompute the N powers of the primitive Nth root of unity; entry
    # k*n mod N is the twiddle factor for output k, input n.
    roots = [e ** (n / N * tau * -1j) for n in range(N)]
    for k in range(N):
        twiddles = [roots[k * n % N] for n in range(N)]
        yield _complex_sumprod(xarr, twiddles)
def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the corresponding
    inverse-transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
    >>> all(map(cmath.isclose, idft(Xarr), xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`dft` for the Discrete Fourier Transform.
    """
    N = len(Xarr)
    # Same twiddle-factor table as dft(), but with the conjugate root,
    # and each output scaled down by N.
    roots = [e ** (n / N * tau * 1j) for n in range(N)]
    for k in range(N):
        twiddles = [roots[k * n % N] for n in range(N)]
        yield _complex_sumprod(Xarr, twiddles) / N
def doublestarmap(func, iterable):
    """Apply *func* to every item of *iterable* by dictionary unpacking
    the item into *func*.

    The difference between :func:`itertools.starmap` and :func:`doublestarmap`
    parallels the distinction between ``func(*a)`` and ``func(**a)``.

    >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
    >>> list(doublestarmap(lambda a, b: a + b, iterable))
    [3, 100]

    ``TypeError`` will be raised if *func*'s signature doesn't match the
    mapping contained in *iterable* or if *iterable* does not contain mappings.
    """
    return (func(**kwargs) for kwargs in iterable)
def _nth_prime_bounds(n):
    """Bounds for the nth prime (counting from 1): lb < p_n < ub."""
    # At and above 688,383, the lb/ub spread is under 0.003 * p_n.

    if n < 1:
        raise ValueError
    if n < 6:
        # The asymptotic bounds below require n >= 6; for the first five
        # primes (2, 3, 5, 7, 11), n and 2.25 * n bracket p_n directly.
        return (n, 2.25 * n)

    # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
    # For n >= 6:  n * (log(n * log n) - 1) < p_n < n * log(n * log n).
    upper_bound = n * log(n * log(n))
    lower_bound = upper_bound - n
    if n >= 688_383:
        # Tighter correction term to the upper bound, valid from here on.
        upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))

    return lower_bound, upper_bound
def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

    >>> nth_prime(0)
    2
    >>> nth_prime(100)
    547

    If *approximate* is set to True, will return a prime close
    to the nth prime. The estimation is much faster than computing
    an exact result.

    >>> nth_prime(200_000_000, approximate=True)  # Exact result is 4222234763
    4217820427

    """
    lower, upper = _nth_prime_bounds(n + 1)

    if not approximate or n <= 1_000_000:
        # Exact: sieve everything below the upper bound and index into it.
        return nth(sieve(ceil(upper)), n)

    # Approximate: start at the midpoint of the bounds (forced odd) and
    # scan upward over odd numbers for the first prime.
    start = floor((lower + upper) / 2) | 1
    return first_true(count(start, step=2), pred=is_prime)
def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

    >>> argmin('efghabcdijkl')
    4
    >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
    3

    For example, look up a label corresponding to the position
    of a value that minimizes a cost function::

        >>> def cost(x):
        ...     "Days for a wound to heal given a subject's age."
        ...     return x**2 - 20*x + 150
        ...
        >>> labels = ['homer', 'marge', 'bart', 'lisa', 'maggie']
        >>> ages = [ 35, 30, 10, 9, 1 ]

        # Fastest healing family member
        >>> labels[argmin(ages, key=cost)]
        'bart'

        # Age with fastest healing
        >>> min(ages, key=cost)
        10

    """
    values = iterable if key is None else map(key, iterable)
    # min() keeps the first of equal values, so ties resolve to the
    # earliest index.
    best_index, _ = min(enumerate(values), key=itemgetter(1))
    return best_index
def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

    >>> argmax('abcdefghabcd')
    7
    >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
    3

    For example, identify the best machine learning model::

        >>> models = ['svm', 'random forest', 'knn', 'naïve bayes']
        >>> accuracy = [ 68, 61, 84, 72 ]

        # Most accurate model
        >>> models[argmax(accuracy)]
        'knn'

        # Best accuracy
        >>> max(accuracy)
        84

    """
    values = iterable if key is None else map(key, iterable)
    # max() keeps the first of equal values, so ties resolve to the
    # earliest index.
    best_index, _ = max(enumerate(values), key=itemgetter(1))
    return best_index
def extract(iterable, indices):
    """Yield values at the specified indices.

    Example:

    >>> data = 'abcdefghijklmnopqrstuvwxyz'
    >>> list(extract(data, [7, 4, 11, 11, 14]))
    ['h', 'e', 'l', 'l', 'o']

    The *iterable* is consumed lazily and can be infinite.
    The *indices* are consumed immediately and must be finite.

    Raises ``IndexError`` if an index lies beyond the iterable.
    Raises ``ValueError`` for negative indices.
    """
    iterator = iter(iterable)
    # Pair each requested index with its output position, then process
    # in ascending index order so the iterable is traversed only once.
    requests = sorted(zip(indices, count()))

    # requests is sorted, so a negative minimum means a negative index.
    if requests and requests[0][0] < 0:
        raise ValueError('Indices must be non-negative')

    pending = {}
    position = -1
    emit = 0

    for index, order in requests:
        skip = index - position
        if skip:
            # Jump forward; `value` is reused as-is for repeated indices.
            try:
                value = next(islice(iterator, skip - 1, None))
            except StopIteration:
                raise IndexError(index)
            position = index

        pending[order] = value

        # Release results in the caller's original order as soon as a
        # contiguous prefix is available.
        while emit in pending:
            yield pending.pop(emit)
            emit += 1
class serialize:
    """Funnel all ``__next__`` calls on an iterator through a lock.

    Wrapping a non-thread-safe iterator or generator with this class lets
    multiple consumer threads share the one instance: the non-reentrant
    lock guarantees that only one thread at a time advances the
    underlying iterator.
    """

    __slots__ = ('iterator', 'lock')

    def __init__(self, iterable):
        self.lock = Lock()
        self.iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        # The lock serializes access; StopIteration propagates normally.
        with self.lock:
            return next(self.iterator)
def concurrent_tee(iterable, n=2):
    """Variant of itertools.tee() but with guaranteed threading semantics.

    Takes a non-threadsafe iterator as an input and creates concurrent
    tee objects for other threads to have reliable independent copies of
    the data stream.

    The new iterators are only thread-safe if consumed within a single thread.
    To share just one of the new iterators across multiple threads, wrap it
    with :func:`serialize`.
    """
    if n < 0:
        raise ValueError
    if n == 0:
        return ()

    # Tee-ing the first copy shares its source, position, and lock.
    primary = _concurrent_tee(iterable)
    copies = [primary]
    while len(copies) < n:
        copies.append(_concurrent_tee(primary))
    return tuple(copies)
class _concurrent_tee:
    """Internal helper for :func:`concurrent_tee`.

    Copies share the source iterator, a lock, and a singly-linked list of
    cells ``[value, next_cell]``; each instance keeps its own current cell,
    so exhausting one copy does not affect the others.
    """

    def __init__(self, iterable):
        it = iter(iterable)
        if isinstance(it, _concurrent_tee):
            # Tee-ing an existing tee: share its source, current cell,
            # and lock so both copies stay synchronized on one stream.
            self.iterator = it.iterator
            self.link = it.link
            self.lock = it.lock
        else:
            self.iterator = it
            self.link = [None, None]  # [value, next]; next=None means unfilled
            self.lock = Lock()

    def __iter__(self):
        return self

    def __next__(self):
        link = self.link
        if link[1] is None:
            # Cell not filled yet: double-checked locking so exactly one
            # thread pulls the next item from the shared source iterator.
            with self.lock:
                if link[1] is None:
                    link[0] = next(self.iterator)
                    link[1] = [None, None]
        # Read the value and advance this copy to the next cell.
        value, self.link = link
        return value