Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/more_itertools/more.py: 19%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import math
3from collections import Counter, defaultdict, deque
4from collections.abc import Sequence
5from contextlib import suppress
6from functools import cached_property, partial, wraps
7from heapq import heapify, heapreplace
8from itertools import (
9 chain,
10 combinations,
11 compress,
12 count,
13 cycle,
14 dropwhile,
15 groupby,
16 islice,
17 permutations,
18 repeat,
19 starmap,
20 takewhile,
21 tee,
22 zip_longest,
23 product,
24)
25from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
26from math import ceil, prod
27from queue import Empty, Queue
28from random import random, randrange, shuffle, uniform
29from operator import (
30 attrgetter,
31 getitem,
32 is_not,
33 itemgetter,
34 lt,
35 neg,
36 sub,
37 gt,
38)
39from sys import maxsize
40from time import monotonic
41from threading import Lock
43from .recipes import (
44 _marker,
45 consume,
46 first_true,
47 flatten,
48 is_prime,
49 nth,
50 powerset,
51 sieve,
52 take,
53 unique_everseen,
54 all_equal,
55 batched,
56)
# Public API of this module: the names exported by
# ``from more_itertools.more import *``. Internal helpers are omitted.
__all__ = [
    'AbortThread',
    'SequenceView',
    'adjacent',
    'all_unique',
    'always_iterable',
    'always_reversible',
    'argmax',
    'argmin',
    'bucket',
    'callback_iter',
    'chunked',
    'chunked_even',
    'circular_shifts',
    'collapse',
    'combination_index',
    'combination_with_replacement_index',
    'concurrent_tee',
    'consecutive_groups',
    'constrained_batches',
    'consumer',
    'count_cycle',
    'countable',
    'derangements',
    'dft',
    'difference',
    'distinct_combinations',
    'distinct_permutations',
    'distribute',
    'divide',
    'doublestarmap',
    'duplicates_everseen',
    'duplicates_justseen',
    'classify_unique',
    'exactly_n',
    'extract',
    'filter_except',
    'filter_map',
    'first',
    'gray_product',
    'groupby_transform',
    'ichunked',
    'iequals',
    'idft',
    'ilen',
    'interleave',
    'interleave_evenly',
    'interleave_longest',
    'interleave_randomly',
    'intersperse',
    'is_sorted',
    'islice_extended',
    'iterate',
    'iter_suppress',
    'join_mappings',
    'last',
    'locate',
    'longest_common_prefix',
    'lstrip',
    'make_decorator',
    'map_except',
    'map_if',
    'map_reduce',
    'mark_ends',
    'minmax',
    'nth_or_last',
    'nth_permutation',
    'nth_prime',
    'nth_product',
    'nth_combination_with_replacement',
    'numeric_range',
    'one',
    'only',
    'outer_product',
    'padded',
    'partial_product',
    'partitions',
    'peekable',
    'permutation_index',
    'powerset_of_sets',
    'product_index',
    'raise_',
    'repeat_each',
    'repeat_last',
    'replace',
    'rlocate',
    'rstrip',
    'run_length',
    'sample',
    'seekable',
    'serialize',
    'set_partitions',
    'side_effect',
    'sized_iterator',
    'sliced',
    'sort_together',
    'split_after',
    'split_at',
    'split_before',
    'split_into',
    'split_when',
    'spy',
    'stagger',
    'strip',
    'strictly_n',
    'substrings',
    'substrings_indexes',
    'synchronized',
    'takewhile_inclusive',
    'time_limited',
    'unique_in_window',
    'unique_to_each',
    'unzip',
    'value_chain',
    'windowed',
    'windowed_complete',
    'with_iter',
    'zip_broadcast',
    'zip_offset',
]
# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _fsumprod
except ImportError:  # pragma: no cover
    # Extended precision algorithms from T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8). Code: mul12()

    def dl_split(x: float):
        "Split a float into two half-precision components."
        t = x * 134217729.0  # Veltkamp constant = 2.0 ** 27 + 1
        hi = t - (t - x)
        lo = x - hi
        return hi, lo

    def dl_mul(x, y):
        "Lossless multiplication."
        # Returns (z, zz) such that x * y == z + zz exactly, with z being
        # the rounded product and zz the rounding error.
        xx_hi, xx_lo = dl_split(x)
        yy_hi, yy_lo = dl_split(y)
        p = xx_hi * yy_hi
        q = xx_hi * yy_lo + xx_lo * yy_hi
        z = p + q
        zz = p - z + q + xx_lo * yy_lo
        return z, zz

    def _fsumprod(p, q):
        # Sum all exact product/error pairs with fsum for an accurately
        # rounded dot product, mirroring math.sumprod's accuracy.
        return fsum(chain.from_iterable(map(dl_mul, p, q)))
def chunked(iterable, n, strict=False):
    """Break *iterable* into lists of length *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
    [[1, 2, 3], [4, 5, 6]]

    By default, the last yielded list will have fewer than *n* elements
    if the length of *iterable* is not divisible by *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
    [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    If the length of *iterable* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    list is yielded.

    """
    iterator = iter(iterable)

    def chunks():
        # Slice off up to *n* items at a time until the source runs dry.
        while True:
            chunk = list(islice(iterator, n))
            if not chunk:
                return
            yield chunk

    if not strict:
        return chunks()

    # Strict mode: every chunk must be exactly *n* long, so an unbounded
    # chunk size makes no sense.
    if n is None:
        raise ValueError('n must not be None when using strict mode.')

    def checked_chunks():
        # Re-yield each chunk, verifying its length first.
        for chunk in chunks():
            if len(chunk) != n:
                raise ValueError('iterable is not divisible by n.')
            yield chunk

    return checked_chunks()
def first(iterable, default=_marker):
    """Return the first item of *iterable*, or *default* if *iterable* is
    empty.

    >>> first([0, 1, 2, 3])
    0
    >>> first([], 'some default')
    'some default'

    Raise ``ValueError`` when the iterable is empty and no *default* was
    given. :func:`first` is handy for grabbing one value out of an
    expensive-to-consume iterator, and is marginally shorter than
    ``next(iter(iterable), default)``.
    """
    try:
        return next(iter(iterable))
    except StopIteration:
        if default is not _marker:
            return default
        raise ValueError(
            'first() was called on an empty iterable, '
            'and no default value was provided.'
        ) from None
def last(iterable, default=_marker):
    """Return the last item of *iterable*, or *default* if *iterable* is
    empty.

    >>> last([0, 1, 2, 3])
    3
    >>> last([], 'some default')
    'some default'

    Raise ``ValueError`` when the iterable is empty and no *default* was
    given.
    """
    try:
        if isinstance(iterable, Sequence):
            # Fast path: sequences support direct indexing from the end.
            return iterable[-1]
        if getattr(iterable, '__reversed__', None):
            # Work around https://bugs.python.org/issue38525
            return next(reversed(iterable))
        # General case: funnel everything through a one-slot deque so only
        # the final item is retained.
        return deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        if default is not _marker:
            return default
        raise ValueError(
            'last() was called on an empty iterable, '
            'and no default value was provided.'
        )
def nth_or_last(iterable, n, default=_marker):
    """Return item *n* of *iterable*, or its final item when it has fewer
    than ``n + 1`` items, or *default* when it is empty.

    >>> nth_or_last([0, 1, 2, 3], 2)
    2
    >>> nth_or_last([0, 1], 2)
    1
    >>> nth_or_last([], 0, 'some default')
    'some default'

    Raise ``ValueError`` when the iterable is empty and no *default* was
    given.
    """
    # Truncate to the first n + 1 items; the last of those is the answer.
    head = islice(iterable, n + 1)
    return last(head, default=default)
class peekable:
    """Wrap an iterator to allow lookahead and prepending elements.

    Call :meth:`peek` on the result to get the value that will be returned
    by :func:`next`. This won't advance the iterator:

    >>> p = peekable(['a', 'b'])
    >>> p.peek()
    'a'
    >>> next(p)
    'a'

    Pass :meth:`peek` a default value to return that instead of raising
    ``StopIteration`` when the iterator is exhausted.

    >>> p = peekable([])
    >>> p.peek('hi')
    'hi'

    peekables also offer a :meth:`prepend` method, which "inserts" items
    at the head of the iterable:

    >>> p = peekable([1, 2, 3])
    >>> p.prepend(10, 11, 12)
    >>> next(p)
    10
    >>> p.peek()
    11
    >>> list(p)
    [11, 12, 1, 2, 3]

    peekables can be indexed. Index 0 is the item that will be returned by
    :func:`next`, index 1 is the item after that, and so on:
    The values up to the given index will be cached.

    >>> p = peekable(['a', 'b', 'c', 'd'])
    >>> p[0]
    'a'
    >>> p[1]
    'b'
    >>> next(p)
    'a'

    Negative indexes are supported, but be aware that they will cache the
    remaining items in the source iterator, which may require significant
    storage.

    To check whether a peekable is exhausted, check its truth value:

    >>> p = peekable(['a', 'b'])
    >>> if p:  # peekable has items
    ...     list(p)
    ['a', 'b']
    >>> if not p:  # peekable is exhausted
    ...     list(p)
    []

    """

    def __init__(self, iterable):
        self._it = iter(iterable)
        # Items peeked or prepended but not yet consumed. The left end of
        # the deque is the next item to be returned by __next__.
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        # True iff at least one item remains. peek() may advance the
        # underlying iterator by one item, but it stores it in the cache,
        # so no data is lost.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        """Return the item that will be next returned from ``next()``.

        Return ``default`` if there are no items left. If ``default`` is not
        provided, raise ``StopIteration``.

        """
        # Pull one item into the cache if it's empty; the front of the
        # cache is always the next value.
        if not self._cache:
            try:
                self._cache.append(next(self._it))
            except StopIteration:
                if default is _marker:
                    raise
                return default
        return self._cache[0]

    def prepend(self, *items):
        """Stack up items to be the next ones returned from ``next()`` or
        ``self.peek()``. The items will be returned in
        first in, first out order::

        >>> p = peekable([1, 2, 3])
        >>> p.prepend(10, 11, 12)
        >>> next(p)
        10
        >>> list(p)
        [11, 12, 1, 2, 3]

        It is possible, by prepending items, to "resurrect" a peekable that
        previously raised ``StopIteration``.

        >>> p = peekable([])
        >>> next(p)
        Traceback (most recent call last):
          ...
        StopIteration
        >>> p.prepend(1)
        >>> next(p)
        1
        >>> next(p)
        Traceback (most recent call last):
          ...
        StopIteration

        """
        # extendleft inserts items in reverse, so reverse first to keep
        # the caller's order.
        self._cache.extendleft(reversed(items))

    def __next__(self):
        # Serve cached (peeked/prepended) items before touching the
        # underlying iterator.
        if self._cache:
            return self._cache.popleft()

        return next(self._it)

    def _get_slice(self, index):
        # Normalize the slice's arguments
        step = 1 if (index.step is None) else index.step
        if step > 0:
            start = 0 if (index.start is None) else index.start
            stop = maxsize if (index.stop is None) else index.stop
        elif step < 0:
            start = -1 if (index.start is None) else index.start
            stop = (-maxsize - 1) if (index.stop is None) else index.stop
        else:
            raise ValueError('slice step cannot be zero')

        # If either the start or stop index is negative, we'll need to cache
        # the rest of the iterable in order to slice from the right side.
        if (start < 0) or (stop < 0):
            self._cache.extend(self._it)
        # Otherwise we'll need to find the rightmost index and cache to that
        # point.
        else:
            n = min(max(start, stop) + 1, maxsize)
            cache_len = len(self._cache)
            if n >= cache_len:
                self._cache.extend(islice(self._it, n - cache_len))

        # Slicing the cached list gives ordinary list slice semantics.
        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        # For a scalar index, cache just enough items to reach it; a
        # negative index requires caching everything that remains.
        cache_len = len(self._cache)
        if index < 0:
            self._cache.extend(self._it)
        elif index >= cache_len:
            self._cache.extend(islice(self._it, index + 1 - cache_len))

        return self._cache[index]
def consumer(func):
    """Decorator that primes a PEP-342-style "reverse iterator" by
    advancing it to its first yield point, so you don't have to call
    ``next()`` on it manually.

    >>> @consumer
    ... def tally():
    ...     i = 0
    ...     while True:
    ...         print('Thing number %s is %s.' % (i, (yield)))
    ...         i += 1
    ...
    >>> t = tally()
    >>> t.send('red')
    Thing number 0 is red.
    >>> t.send('fish')
    Thing number 1 is fish.

    Without the decorator, you would have to call ``next(t)`` before
    ``t.send()`` could be used.

    """

    @wraps(func)
    def primed(*args, **kwargs):
        # Build the generator, then advance it once so it is paused at
        # its first yield and ready to receive values via send().
        generator = func(*args, **kwargs)
        next(generator)
        return generator

    return primed
def ilen(iterable):
    """Return the number of items in *iterable*.

    >>> ilen(x for x in range(1000) if x % 3 == 0)
    334

    Equivalent to, but faster than::

        def ilen(iterable):
            count = 0
            for _ in iterable:
                count += 1
            return count

    This fully consumes the iterable, so handle with care.

    """
    # Pair every item with a ticking count() and throw the pairs away via
    # a zero-length deque; the counter's next value is then the length.
    counter = count()
    deque(zip(iterable, counter), maxlen=0)
    return next(counter)
def iterate(func, start):
    """Return ``start``, ``func(start)``, ``func(func(start))``, ...

    Produces an infinite iterator. To add a stopping condition,
    use :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`:.

    >>> take(10, iterate(lambda x: 2*x, 1))
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    If *func* raises ``StopIteration``, iteration ends cleanly.
    """
    while True:
        yield start
        # A StopIteration raised by func ends the sequence gracefully
        # instead of propagating out of the generator.
        try:
            start = func(start)
        except StopIteration:
            return
def with_iter(context_manager):
    """Wrap an iterable in a ``with`` statement, so it closes once exhausted.

    For example, this will close the file when the iterator is exhausted::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Note that you have to actually exhaust the iterator for opened files
    to be closed. Any context manager which returns an iterable is a
    candidate for ``with_iter``.

    """
    with context_manager as iterable:
        for item in iterable:
            yield item
class sized_iterator:
    """Wrap *iterable* so the result also reports *length* via ``len()``.

    >>> it = map(str, range(5))
    >>> sized_it = sized_iterator(it, 5)
    >>> len(sized_it)
    5
    >>> list(sized_it)
    ['0', '1', '2', '3', '4']

    This is useful for tools that use :func:`len`, like
    `tqdm <https://pypi.org/project/tqdm/>`__ .

    The given *length* is trusted, not verified — choose a value that
    reflects reality.
    """

    def __init__(self, iterable, length):
        # Keep the iterator itself plus the caller-supplied item count.
        self._it = iter(iterable)
        self._len = length

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._it)

    def __len__(self):
        return self._len
def one(iterable, too_short=None, too_long=None):
    """Return the first item from *iterable*, which is expected to contain
    only that item. Raise an exception if *iterable* is empty or has more
    than one item.

    :func:`one` is useful for ensuring an iterable contains exactly one
    item — for example, the result of a database query expected to return
    a single row.

    If *iterable* is empty, ``ValueError`` will be raised. You may specify
    a different exception with the *too_short* keyword:

    >>> one([]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: too few items in iterable (expected 1)'

    Similarly, if *iterable* contains more than one item, ``ValueError``
    will be raised; override with the *too_long* keyword:

    >>> one(['too', 'many']) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 'too',
    'many', and perhaps more.

    Note that :func:`one` attempts to advance *iterable* twice to ensure
    there is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    iterator = iter(iterable)

    # First advance: must produce the single expected item.
    try:
        value = next(iterator)
    except StopIteration:
        raise (
            too_short or ValueError('too few items in iterable (expected 1)')
        ) from None

    # Second advance: must be exhausted, otherwise there were extras.
    try:
        extra = next(iterator)
    except StopIteration:
        return value

    msg = (
        f'Expected exactly one item in iterable, but got {value!r}, '
        f'{extra!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
def raise_(exception, *args):
    """Raise *exception*, constructed with positional *args*.

    A helper that lets a ``raise`` appear where an expression is required,
    e.g. inside a ``lambda``.
    """
    raise exception(*args)
def strictly_n(iterable, n, too_short=None, too_long=None):
    """Validate that *iterable* has exactly *n* items and yield them if it
    does. If it has fewer than *n* items, call function *too_short* with
    the actual number of items. If it has more than *n* items, call
    function *too_long* with the number ``n + 1``.

    >>> list(strictly_n(['a', 'b', 'c', 'd'], 4))
    ['a', 'b', 'c', 'd']

    Note that the returned iterable must be consumed in order for the
    check to be made.

    By default, *too_short* and *too_long* are functions that raise
    ``ValueError``.

    >>> list(strictly_n('ab', 3)) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: too few items in iterable (got 2)

    >>> list(strictly_n('abc', 2)) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: too many items in iterable (got at least 3)

    You can instead supply functions that do something else.
    *too_short* will be called with the number of items in *iterable*;
    *too_long* will be called with `n + 1`.

    >>> def too_long(item_count):
    ...     print('The boss is going to hear about this')
    >>> it = strictly_n('abcdef', 4, too_long=too_long)
    >>> list(it)
    The boss is going to hear about this
    ['a', 'b', 'c', 'd']

    """
    if too_short is None:

        def too_short(item_count):
            raise ValueError(f'Too few items in iterable (got {item_count})')

    if too_long is None:

        def too_long(item_count):
            raise ValueError(
                f'Too many items in iterable (got at least {item_count})'
            )

    iterator = iter(iterable)

    # Pass through at most n items, counting as we go.
    yielded = 0
    for item in islice(iterator, n):
        yield item
        yielded += 1

    # Too few: report the actual count.
    if yielded < n:
        too_short(yielded)
        return

    # Too many: anything left over means at least n + 1 items existed.
    for item in iterator:
        too_long(n + 1)
        return
def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

    >>> sorted(distinct_permutations([1, 0, 1]))
    [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

    >>> sorted(distinct_permutations([1, 0, 1], r=2))
    [(0, 1), (1, 0), (1, 1)]
    >>> sorted(distinct_permutations(range(3), r=2))
    [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

    >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
    [
        (1, True, '3'),
        (1, '3', True),
        ('3', 1, True)
    ]
    >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
    [
        (1, 2, '3'),
        (1, '3', 2),
        (2, 1, '3'),
        (2, '3', 1),
        ('3', 1, 2),
        ('3', 2, 1)
    ]
    """

    # Algorithm: https://w.wiki/Qai
    # Full-length case: classic lexicographic next-permutation on a sorted
    # list, which naturally skips duplicate arrangements.
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than i such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above
    # r-length case: only the first r slots ("head") are emitted; the
    # remaining items ("tail") supply replacement candidates.
    def _partial(A, r):
        # Split A into the first r items and the last r items
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    try:
        items.sort()
        sortable = True
    except TypeError:
        sortable = False

        # Unsortable input: permute positions instead of values. Map each
        # item to the index of its first equal item so that equal values
        # collapse onto the same index, then sort the indices.
        indices_dict = defaultdict(list)

        for item in items:
            indices_dict[items.index(item)].append(item)

        indices = [items.index(item) for item in items]
        indices.sort()

        # Cycle through the equal originals so output tuples reuse the
        # actual input objects.
        equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

        def permuted_items(permuted_indices):
            return tuple(
                next(equivalent_items[index]) for index in permuted_indices
            )

    size = len(items)
    if r is None:
        r = size

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if 0 < r <= size:
        if sortable:
            return algorithm(items)
        else:
            return (
                permuted_items(permuted_indices)
                for permuted_indices in algorithm(indices)
            )

    # Degenerate cases: r > size yields nothing; r == 0 yields one empty
    # permutation.
    return iter(() if r else ((),))
def derangements(iterable, r=None):
    """Yield successive derangements of the elements in *iterable*.

    A derangement is a permutation in which no element appears at its
    original index — a permutation with no fixed points.

    >>> for d in derangements(['Alice', 'Bob', 'Carol']):
    ...     print(', '.join(d))
    Bob, Carol, Alice
    Carol, Alice, Bob

    If *r* is given, only the *r*-length derangements are yielded.

    >>> sorted(derangements(range(3), 2))
    [(1, 0), (1, 2), (2, 0)]
    >>> sorted(derangements([0, 2, 3], 2))
    [(2, 0), (2, 3), (3, 0)]

    Elements are treated as unique based on their position, not on their
    value, so inputs with equal elements can yield tuples that appear to
    have fixed points. De-duplicate the input (e.g. by suffixing an index)
    to avoid that.

    The number of derangements of a set of size *n* is the "subfactorial
    of n"; for n > 0 it equals ``round(math.factorial(n) / math.e)``.

    Sizes: https://oeis.org/A000166
    """
    pool = tuple(iterable)
    positions = tuple(range(len(pool)))

    # A permutation of the position indices is a derangement exactly when
    # every output position differs from its original position. Small-int
    # identity (is_not) matches the original behavior here.
    selectors = (
        all(map(is_not, positions, candidate))
        for candidate in permutations(positions, r=r)
    )
    return compress(permutations(pool, r=r), selectors)
def intersperse(e, iterable, n=1):
    """Intersperse filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

    >>> list(intersperse('!', [1, 2, 3, 4, 5]))
    [1, '!', 2, '!', 3, '!', 4, '!', 5]

    >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
    [1, 2, None, 3, 4, None, 5]

    """
    if n == 0:
        raise ValueError('n must be > 0')

    if n == 1:
        # Alternate filler and items, then drop the leading filler:
        # e, x_0, e, x_1, ... -> x_0, e, x_1, e, ...
        alternating = interleave(repeat(e), iterable)
        return islice(alternating, 1, None)

    # Alternate single-filler lists with n-item chunks, drop the leading
    # filler list, and flatten:
    # [x_0..x_{n-1}], [e], [x_n..], ... -> x_0, ..., e, x_n, ...
    alternating = interleave(repeat([e]), chunked(iterable, n))
    return flatten(islice(alternating, 1, None))
def unique_to_each(*iterables):
    """Return the elements from each of the input iterables that aren't in
    the other input iterables.

    For example, suppose you have a set of packages, each with a set of
    dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    If you remove one package, which dependencies can also be removed?

    If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
    associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed
    for ``pkg_2``, and ``D`` is only needed for ``pkg_3``::

        >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
        [['A'], ['C'], ['D']]

    If there are duplicates in one input iterable that aren't in the others
    they will be duplicated in the output. Input order is preserved::

        >>> unique_to_each("mississippi", "missouri")
        [['p', 'p'], ['o', 'u', 'r']]

    It is assumed that the elements of each iterable are hashable.

    """
    # Materialize the inputs so each can be traversed twice.
    materialized = [list(it) for it in iterables]

    # Count in how many DISTINCT inputs each element appears.
    appearances = Counter(chain.from_iterable(map(set, materialized)))
    singletons = {element for element, n in appearances.items() if n == 1}

    # For each input, keep (in order, with duplicates) only the elements
    # that appear in no other input.
    return [
        [element for element in it if element in singletons]
        for it in materialized
    ]
def windowed(seq, n, fillvalue=None, step=1):
    """Return a sliding window of width *n* over the given iterable.

    >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
    >>> list(all_windows)
    [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    When the window is larger than the iterable, *fillvalue* is used in place
    of missing values:

    >>> list(windowed([1, 2, 3], 4))
    [(1, 2, 3, None)]

    Each window will advance in increments of *step*:

    >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
    [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler items
    to the left:

    >>> iterable = [1, 2, 3, 4]
    >>> n = 3
    >>> padding = [None] * (n - 1)
    >>> list(windowed(chain(padding, iterable), 3))
    [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    if n < 0:
        raise ValueError('n must be >= 0')
    if n == 0:
        # A zero-width window is a single empty tuple.
        yield ()
        return
    if step < 1:
        raise ValueError('step must be >= 1')

    iterator = iter(seq)

    # Generate first window
    window = deque(islice(iterator, n), maxlen=n)

    # Deal with the first window not being full
    if not window:
        return
    if len(window) < n:
        yield tuple(window) + ((fillvalue,) * (n - len(window)))
        return
    yield tuple(window)

    # Create the filler for the next windows. The padding ensures
    # we have just enough elements to fill the last window.
    padding = (fillvalue,) * (n - 1 if step >= n else step - 1)
    # Each next() on this map pushes one more source item into the
    # fixed-size deque, evicting the oldest item from the left.
    filler = map(window.append, chain(iterator, padding))

    # Generate the rest of the windows
    # islice with a stride consumes `step` appends between yields, so the
    # deque advances by exactly `step` items per window.
    for _ in islice(filler, step - 1, None, step):
        yield tuple(window)
def substrings(iterable):
    """Yield all of the substrings of *iterable*.

    >>> [''.join(s) for s in substrings('more')]
    ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Note that non-string iterables can also be subdivided.

    >>> list(substrings([0, 1, 2]))
    [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    Like subslices() but returns tuples instead of lists
    and returns the shortest substrings first.

    """
    seq = tuple(iterable)
    total = len(seq)
    # Emit all windows of each length, shortest lengths first.
    for length in range(1, total + 1):
        for start in range(total - length + 1):
            yield seq[start : start + length]
def substrings_indexes(seq, reverse=False):
    """Yield all substrings and their positions in *seq*

    The items yielded will be a tuple of the form ``(substr, i, j)``, where
    ``substr == seq[i:j]``.

    This function only works for iterables that support slicing, such as
    ``str`` objects.

    >>> for item in substrings_indexes('more'):
    ...     print(item)
    ('m', 0, 1)
    ('o', 1, 2)
    ('r', 2, 3)
    ('e', 3, 4)
    ('mo', 0, 2)
    ('or', 1, 3)
    ('re', 2, 4)
    ('mor', 0, 3)
    ('ore', 1, 4)
    ('more', 0, 4)

    Set *reverse* to ``True`` to yield the same items in the opposite order.


    """
    # Substring lengths, shortest first (or longest first when reversed).
    lengths = range(1, len(seq) + 1)
    if reverse:
        lengths = reversed(lengths)
    return (
        (seq[start : start + length], start, start + length)
        for length in lengths
        for start in range(len(seq) - length + 1)
    )
class bucket:
    """Wrap *iterable* and return an object that buckets the iterable into
    child iterables based on a *key* function.

    >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
    >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
    >>> sorted(list(s))  # Get the keys
    ['a', 'b', 'c']
    >>> a_iterable = s['a']
    >>> next(a_iterable)
    'a1'
    >>> next(a_iterable)
    'a2'
    >>> list(s['b'])
    ['b1', 'b2', 'b3']

    The original iterable will be advanced and its items will be cached until
    they are used by the child iterables. This may require significant
    storage.

    By default, attempting to select a bucket to which no items belong will
    exhaust the iterable and cache all values.
    If you specify a *validator* function, selected buckets will instead be
    checked against it.

    >>> from itertools import count
    >>> it = count(1, 2)  # Infinite sequence of odd numbers
    >>> key = lambda x: x % 10  # Bucket by last digit
    >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
    >>> s = bucket(it, key=key, validator=validator)
    >>> 2 in s
    False
    >>> list(s[2])
    []

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        # Per-bucket-value queues of items seen but not yet consumed.
        self._cache = defaultdict(deque)
        # With no validator, every key value is considered valid.
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        if not self._validator(value):
            return False

        # Probe the bucket for one item; put it back if one was found so
        # membership testing doesn't consume it.
        try:
            item = next(self[value])
        except StopIteration:
            return False
        else:
            self._cache[value].appendleft(item)

        return True

    def _get_values(self, value):
        """
        Helper to yield items from the parent iterator that match *value*.
        Items that don't match are stored in the local cache as they
        are encountered.
        """
        while True:
            # If we've cached some items that match the target value, emit
            # the first one and evict it from the cache.
            if self._cache[value]:
                yield self._cache[value].popleft()
            # Otherwise we need to advance the parent iterator to search for
            # a matching item, caching the rest.
            else:
                while True:
                    try:
                        item = next(self._it)
                    except StopIteration:
                        return
                    item_value = self._key(item)
                    if item_value == value:
                        yield item
                        break
                    elif self._validator(item_value):
                        self._cache[item_value].append(item)

    def __iter__(self):
        # Iterating the bucket object yields the key values. This exhausts
        # the source iterator, caching every valid item on the way.
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        # Invalid keys get an empty iterator without touching the source.
        if not self._validator(value):
            return iter(())

        return self._get_values(value)
def spy(iterable, n=1):
    """Return a 2-tuple of (the first *n* items of *iterable* as a list,
    an iterator yielding all of *iterable*'s items).

    This lets you "look ahead" at the items without advancing the iterable.

    There is one item in the list by default:

    >>> iterable = 'abcdefg'
    >>> head, iterable = spy(iterable)
    >>> head
    ['a']
    >>> list(iterable)
    ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    Unpacking works too:

    >>> (first, second), iterable = spy('abcdefg', 2)
    >>> first
    'a'
    >>> second
    'b'

    If *n* exceeds the number of available items, the head is simply
    shorter than requested:

    >>> head, iterable = spy([1, 2, 3], 10)
    >>> head
    [1, 2, 3]
    >>> list(iterable)
    [1, 2, 3]

    """
    # Tee the input: consume the head from one copy, hand back the other.
    ahead, main = tee(iterable)
    return list(islice(ahead, n)), main
def interleave(*iterables):
    """Yield one item from each iterable in turn, stopping when the
    shortest input is exhausted.

    >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7]

    For a version that continues until every iterable is exhausted, see
    :func:`interleave_longest`.

    """
    # zip truncates at the shortest input, giving the stop condition.
    for batch in zip(*iterables):
        yield from batch
def interleave_longest(*iterables):
    """Yield one item from each iterable in turn, skipping iterables that
    have been exhausted.

    >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7, 3, 8]

    This produces the same output as :func:`roundrobin`, but may perform
    better for some inputs (in particular when the number of iterables
    is large).

    """
    # A private sentinel marks the padding added for exhausted inputs.
    missing = object()
    for batch in zip_longest(*iterables, fillvalue=missing):
        for item in batch:
            if item is not missing:
                yield item
def interleave_evenly(iterables, lengths=None):
    """
    Interleave multiple iterables so that their elements are evenly distributed
    throughout the output sequence.

    >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
    >>> list(interleave_evenly(iterables))
    [1, 2, 'a', 3, 4, 'b', 5]

    >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
    >>> list(interleave_evenly(iterables))
    [1, 6, 4, 2, 7, 3, 8, 5]

    This function requires iterables of known length. Iterables without
    ``__len__()`` can be used by manually specifying lengths with *lengths*:

    >>> from itertools import combinations, repeat
    >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
    >>> lengths = [4 * (4 - 1) // 2, 3]
    >>> list(interleave_evenly(iterables, lengths=lengths))
    [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    Based on Bresenham's algorithm.

    :raises ValueError: if lengths cannot be determined automatically and
        *lengths* is not given, or if *lengths* does not match *iterables*.
    """
    if lengths is None:
        # Iterables without __len__ raise TypeError here; surface that as
        # a ValueError pointing at the *lengths* keyword.
        try:
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )
    elif len(iterables) != len(lengths):
        raise ValueError('Mismatching number of iterables and lengths.')

    dims = len(lengths)

    # sort iterables by length, descending
    lengths_permute = sorted(
        range(dims), key=lambda i: lengths[i], reverse=True
    )
    lengths_desc = [lengths[i] for i in lengths_permute]
    iters_desc = [iter(iterables[i]) for i in lengths_permute]

    # the longest iterable is the primary one (Bresenham: the longest
    # distance along an axis)
    delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
    iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
    # Per-secondary error accumulators, initialized so the first secondary
    # emissions are spread away from the very start of the output.
    errors = [delta_primary // dims] * len(deltas_secondary)

    # Total number of items still to emit; the loop ends exactly when
    # every input has been fully consumed.
    to_yield = sum(lengths)
    while to_yield:
        yield next(iter_primary)
        to_yield -= 1
        # update errors for each secondary iterable
        errors = [e - delta for e, delta in zip(errors, deltas_secondary)]

        # those iterables for which the error is negative are yielded
        # ("diagonal step" in Bresenham)
        for i, e_ in enumerate(errors):
            if e_ < 0:
                yield next(iters_secondary[i])
                to_yield -= 1
                errors[i] += delta_primary
def interleave_randomly(*iterables):
    """Repeatedly select one of the input *iterables* at random and yield the next
    item from it.

    >>> iterables = [1, 2, 3], 'abc', (True, False, None)
    >>> list(interleave_randomly(*iterables))  # doctest: +SKIP
    ['a', 'b', 1, 'c', True, False, None, 2, 3]

    The relative order of the items within each input iterable is preserved.
    Note that the possible output sequences are not all equally likely.

    """
    active = [iter(it) for it in iterables]
    while active:
        chosen = randrange(len(active))
        try:
            item = next(active[chosen])
        except StopIteration:
            # Remove the exhausted iterator via swap-with-last, which is
            # slightly faster than list.pop(chosen).
            active[chosen] = active[-1]
            del active[-1]
        else:
            yield item
def collapse(iterable, base_type=None, levels=None):
    """Flatten an iterable with multiple levels of nesting (e.g., a list of
    lists of tuples) into non-iterable types.

    >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
    >>> list(collapse(iterable))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are treated as atoms and are not collapsed.
    To treat other types as atoms, specify *base_type*:

    >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
    >>> list(collapse(iterable, base_type=tuple))
    ['ab', ('cd', 'ef'), 'gh', 'ij']

    Specify *levels* to stop flattening after a certain depth:

    >>> iterable = [('a', ['b']), ('c', ['d'])]
    >>> list(collapse(iterable))  # Fully flattened
    ['a', 'b', 'c', 'd']
    >>> list(collapse(iterable, levels=1))  # Only one level flattened
    ['a', ['b'], 'c', ['d']]

    """
    # Explicit stack of (depth, iterator-of-nodes) pairs; depth-first
    # traversal without recursion.
    pending = deque()
    # Seed with the whole input wrapped as a single depth-0 node.
    pending.appendleft((0, repeat(iterable, 1)))

    while pending:
        group = pending.popleft()
        depth, nodes = group

        # Past the requested depth: emit nodes without descending.
        if levels is not None and depth > levels:
            yield from nodes
            continue

        for node in nodes:
            # Strings/bytes and user-specified base types are atoms.
            is_atom = isinstance(node, (str, bytes)) or (
                (base_type is not None) and isinstance(node, base_type)
            )
            if is_atom:
                yield node
                continue

            try:
                children = iter(node)
            except TypeError:
                # Not iterable at all: emit as-is.
                yield node
            else:
                # Remember where we are in this group, then descend.
                pending.appendleft(group)
                pending.appendleft((depth + 1, children))
                break
def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Invoke *func* on each item in *iterable* (or on each *chunk_size* group
    of items) before yielding the item. The return value of *func* is
    discarded.

    *before* and *after* are optional zero-argument callables run before
    iteration starts and after it ends (even on early exit), respectively.

    `side_effect` can be used for logging, updating progress bars, or anything
    that is not functionally "pure."

    >>> from more_itertools import consume
    >>> func = lambda item: print('Received {}'.format(item))
    >>> consume(side_effect(func, range(2)))
    Received 0
    Received 1

    Operating on chunks of items:

    >>> pair_sums = []
    >>> func = lambda chunk: pair_sums.append(sum(chunk))
    >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
    [0, 1, 2, 3, 4, 5]
    >>> list(pair_sums)
    [1, 5, 9]

    """
    try:
        if before is not None:
            before()

        if chunk_size is not None:
            # Call func once per chunk, then pass the items through.
            for group in chunked(iterable, chunk_size):
                func(group)
                yield from group
        else:
            # Call func once per item.
            for element in iterable:
                func(element)
                yield element
    finally:
        # Run the teardown hook even if the consumer stops early.
        if after is not None:
            after()
def sliced(seq, n, strict=False):
    """Yield slices of length *n* from the sequence *seq*.

    >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
    [(1, 2, 3), (4, 5, 6)]

    By default, the last slice may be shorter than *n* when the length of
    *seq* is not evenly divisible:

    >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
    [(1, 2, 3), (4, 5, 6), (7, 8)]

    With *strict* set to ``True``, a ``ValueError`` is raised before such a
    short final slice would be yielded.

    This function only works for iterables that support slicing.
    For non-sliceable iterables, see :func:`chunked`.

    """
    # Slice at every multiple of n; an empty slice signals the end.
    all_slices = takewhile(
        len, (seq[start : start + n] for start in count(0, n))
    )
    if not strict:
        return all_slices

    def checked():
        for piece in all_slices:
            if len(piece) != n:
                raise ValueError("seq is not divisible by n.")
            yield piece

    return checked()
def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Yield lists of items from *iterable*, delimited by items for which
    callable *pred* returns ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b'))
    [['a'], ['c', 'd', 'c'], ['a']]

    >>> list(split_at(range(10), lambda n: n % 2 == 1))
    [[0], [2], [4], [6], [8], []]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
    [[0], [2], [4, 5, 6, 7, 8, 9]]

    Delimiters are dropped unless *keep_separator* is ``True``, in which
    case each one is yielded as its own single-item list:

    >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
    [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    group = []
    for element in iterator:
        if not pred(element):
            group.append(element)
            continue

        # Found a separator: emit the group built so far.
        yield group
        if keep_separator:
            yield [element]
        if maxsplit == 1:
            # Last allowed split: emit the remainder unsplit.
            yield list(iterator)
            return
        group = []
        maxsplit -= 1
    yield group
def split_before(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends just before
    an item for which callable *pred* returns ``True``:

    >>> list(split_before('OneTwo', lambda s: s.isupper()))
    [['O', 'n', 'e'], ['T', 'w', 'o']]

    >>> list(split_before(range(10), lambda n: n % 3 == 0))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    chunk = []
    for element in iterator:
        # A split happens before a matching element, but never before the
        # very first item (chunk is still empty then).
        if pred(element) and chunk:
            yield chunk
            if maxsplit == 1:
                # Last allowed split: the match and the rest go together.
                yield [element, *iterator]
                return
            chunk = []
            maxsplit -= 1
        chunk.append(element)
    if chunk:
        yield chunk
def split_after(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends with an
    item where callable *pred* returns ``True``:

    >>> list(split_after('one1two2', lambda s: s.isdigit()))
    [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

    >>> list(split_after(range(10), lambda n: n % 3 == 0))
    [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    chunk = []
    for element in iterator:
        chunk.append(element)
        # A split happens right after a matching element.
        if pred(element):
            yield chunk
            if maxsplit == 1:
                # Last allowed split: emit the remainder (if any) unsplit.
                remainder = list(iterator)
                if remainder:
                    yield remainder
                return
            chunk = []
            maxsplit -= 1
    if chunk:
        yield chunk
def split_when(iterable, pred, maxsplit=-1):
    """Split *iterable* into pieces based on the output of *pred*, which is
    called with successive pairs of items and should return ``True`` when
    the iterable is to be split between them.

    For example, to find runs of increasing numbers, split the iterable when
    element ``i`` is larger than element ``i + 1``:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
    [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
    ...                 lambda x, y: x > y, maxsplit=2))
    [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    try:
        previous = next(iterator)
    except StopIteration:
        # Empty input yields nothing.
        return

    chunk = [previous]
    for current in iterator:
        # pred sees each adjacent (previous, current) pair exactly once.
        if pred(previous, current):
            yield chunk
            if maxsplit == 1:
                # Last allowed split: the rest stays together.
                yield [current, *iterator]
                return
            chunk = []
            maxsplit -= 1

        chunk.append(current)
        previous = current

    yield chunk
def split_into(iterable, sizes):
    """Yield a list of sequential items from *iterable* of length 'n' for each
    integer 'n' in *sizes*.

    >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
    [[1], [2, 3], [4, 5, 6]]

    If the sum of *sizes* is smaller than the length of *iterable*, the
    surplus items are not returned:

    >>> list(split_into([1,2,3,4,5,6], [2,3]))
    [[1, 2], [3, 4, 5]]

    If the sum of *sizes* is larger, the overrunning list comes up short and
    any further lists are empty:

    >>> list(split_into([1,2,3,4], [1,2,3,4]))
    [[1], [2, 3], [4], []]

    A ``None`` in *sizes* means "everything that is left", the same way
    :func:`itertools.islice` treats it:

    >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
    [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    This is useful for grouping a series of items where the group sizes are
    not uniform, e.g. columns of a table row that encode multi-part values.
    """
    # Use a single iterator so each slice picks up where the last stopped.
    iterator = iter(iterable)

    for size in sizes:
        if size is None:
            # Drain everything that remains and stop.
            yield list(iterator)
            return
        yield list(islice(iterator, size))
def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the elements from *iterable*, followed by *fillvalue*, such that
    at least *n* items are emitted.

    >>> list(padded([1, 2, 3], '?', 5))
    [1, 2, 3, '?', '?']

    If *next_multiple* is ``True``, *fillvalue* is emitted until the total
    number of items is a multiple of *n*:

    >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
    [1, 2, 3, 4, None, None]

    If *n* is ``None``, *fillvalue* is emitted indefinitely. To cap the
    output at exactly *n* items, combine with :func:`islice`:

    >>> list(islice(padded([1, 2, 3], '?'), 5))
    [1, 2, 3, '?', '?']

    :raises ValueError: if *n* is given and less than 1.
    """
    source = iter(iterable)
    padded_source = chain(source, repeat(fillvalue))

    if n is None:
        return padded_source
    if n < 1:
        raise ValueError('n must be at least 1')

    if next_multiple:

        def batches():
            # Pull one real element to learn whether another batch is
            # needed, then pad that batch out to exactly n items.
            for lead in source:
                yield (lead,)
                yield islice(padded_source, n - 1)

        return chain.from_iterable(batches())

    # Pad the front out to at least n items, then pass the rest through.
    return chain(islice(padded_source, n), source)
def repeat_each(iterable, n=2):
    """Repeat each element in *iterable* *n* times.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    # Lazily expand each element into an n-long run.
    return chain.from_iterable(repeat(element, n) for element in iterable)
def repeat_last(iterable, default=None):
    """After the *iterable* is exhausted, keep yielding its last element.

    >>> list(islice(repeat_last(range(3)), 5))
    [0, 1, 2, 2, 2]

    If the iterable is empty, yield *default* forever::

    >>> list(islice(repeat_last(range(0), 42), 5))
    [42, 42, 42, 42, 42]

    """
    # A private sentinel distinguishes "no items seen" from any real value.
    missing = object()
    last = missing
    for last in iterable:
        yield last
    yield from repeat(default if last is missing else last)
def distribute(n, iterable):
    """Distribute the items from *iterable* among *n* smaller iterables,
    dealing them out round-robin style.

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]

    When the length of *iterable* is not evenly divisible by *n*, the
    returned iterables differ in length; when it is smaller than *n*, the
    trailing iterables are empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    This function uses :func:`itertools.tee` and may require significant
    storage. If the order of items must match the original iterable, see
    :func:`divide`.

    :raises ValueError: if *n* is less than 1.
    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Copy the input n times; child i takes every n-th item starting at i.
    copies = tee(iterable, n)
    return [islice(copy, offset, None, n) for offset, copy in enumerate(copies)]
def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples whose elements are offset from *iterable*.
    The `i`-th item of each tuple is shifted by the `i`-th item of *offsets*.

    >>> list(stagger([0, 1, 2, 3]))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
    >>> list(stagger(range(8), offsets=(0, 2, 4)))
    [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default the sequence stops when the final element of a tuple is the
    last item in the iterable; set *longest* to ``True`` to continue until
    the first element of a tuple is the last item:

    >>> list(stagger([0, 1, 2, 3], longest=True))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    Positions beyond the ends of the sequence are filled with *fillvalue*
    (``None`` by default).

    """
    # One independent copy of the input per requested offset.
    copies = tee(iterable, len(offsets))

    return zip_offset(
        *copies, offsets=offsets, longest=longest, fillvalue=fillvalue
    )
def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the input *iterables* together, but offset the `i`-th iterable
    by the `i`-th item in *offsets*.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    This can be used as a lightweight alternative to SciPy or pandas to
    analyze data sets in which some series have a lead or lag relationship.

    By default the output stops at the shortest (shifted) iterable; set
    *longest* to ``True`` to continue until the longest is exhausted,
    substituting *fillvalue* for missing positions:

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    :raises ValueError: if the counts of iterables and offsets differ.
    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    shifted = []
    for source, offset in zip(iterables, offsets):
        if offset > 0:
            # Positive offset: drop the first `offset` items.
            shifted.append(islice(source, offset, None))
        elif offset < 0:
            # Negative offset: prepend `-offset` fill values.
            shifted.append(chain(repeat(fillvalue, -offset), source))
        else:
            shifted.append(source)

    if longest:
        return zip_longest(*shifted, fillvalue=fillvalue)

    return zip(*shifted)
1924def sort_together(
1925 iterables, key_list=(0,), key=None, reverse=False, strict=False
1926):
1927 """Return the input iterables sorted together, with *key_list* as the
1928 priority for sorting. All iterables are trimmed to the length of the
1929 shortest one.
1931 This can be used like the sorting function in a spreadsheet. If each
1932 iterable represents a column of data, the key list determines which
1933 columns are used for sorting.
1935 By default, all iterables are sorted using the ``0``-th iterable::
1937 >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
1938 >>> sort_together(iterables)
1939 [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
1941 Set a different key list to sort according to another iterable.
1942 Specifying multiple keys dictates how ties are broken::
1944 >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
1945 >>> sort_together(iterables, key_list=(1, 2))
1946 [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
1948 To sort by a function of the elements of the iterable, pass a *key*
1949 function. Its arguments are the elements of the iterables corresponding to
1950 the key list::
1952 >>> names = ('a', 'b', 'c')
1953 >>> lengths = (1, 2, 3)
1954 >>> widths = (5, 2, 1)
1955 >>> def area(length, width):
1956 ... return length * width
1957 >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
1958 [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
1960 Set *reverse* to ``True`` to sort in descending order.
1962 >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
1963 [(3, 2, 1), ('a', 'b', 'c')]
1965 If the *strict* keyword argument is ``True``, then
1966 ``ValueError`` will be raised if any of the iterables have
1967 different lengths.
1969 """
1970 if key is None:
1971 # if there is no key function, the key argument to sorted is an
1972 # itemgetter
1973 key_argument = itemgetter(*key_list)
1974 else:
1975 # if there is a key function, call it with the items at the offsets
1976 # specified by the key function as arguments
1977 key_list = list(key_list)
1978 if len(key_list) == 1:
1979 # if key_list contains a single item, pass the item at that offset
1980 # as the only argument to the key function
1981 key_offset = key_list[0]
1982 key_argument = lambda zipped_items: key(zipped_items[key_offset])
1983 else:
1984 # if key_list contains multiple items, use itemgetter to return a
1985 # tuple of items, which we pass as *args to the key function
1986 get_key_items = itemgetter(*key_list)
1987 key_argument = lambda zipped_items: key(
1988 *get_key_items(zipped_items)
1989 )
1991 transposed = zip(*iterables, strict=strict)
1992 reordered = sorted(transposed, key=key_argument, reverse=reverse)
1993 untransposed = zip(*reordered, strict=strict)
1994 return list(untransposed)
def unzip(iterable):
    """The inverse of :func:`zip`: disaggregate the elements of the zipped
    *iterable* into separate iterators.

    The ``i``-th returned iterable contains the ``i``-th element from each
    element of the zipped iterable. The first element determines how many
    iterables are returned.

    >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> letters, numbers = unzip(iterable)
    >>> list(letters)
    ['a', 'b', 'c', 'd']
    >>> list(numbers)
    [1, 2, 3, 4]

    This is similar to ``zip(*iterable)`` but avoids reading *iterable*
    into memory up front. Note, however, that it relies on
    :func:`itertools.tee` and thus may still require significant storage.

    """
    iterator = iter(iterable)
    try:
        first = next(iterator)
    except StopIteration:
        # empty iterable, e.g. zip([], [], [])
        return ()

    # Re-attach the peeked element and make one copy per position.
    copies = tee(chain([first], iterator), len(first))

    def column(index, source):
        # "Improperly zipped" inputs (elements of unequal length) would
        # raise IndexError; suppress it so the column just ends at the
        # first too-short element.
        with suppress(IndexError):
            for row in source:
                yield row[index]

    return tuple(column(index, source) for index, source in enumerate(copies))
def divide(n, iterable):
    """Divide the elements from *iterable* into *n* parts, maintaining
    order.

    >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 2, 3]
    >>> list(group_2)
    [4, 5, 6]

    When the length of *iterable* is not evenly divisible by *n*, the
    earlier parts get one extra item; when it is smaller than *n*, the
    trailing parts are empty:

    >>> children = divide(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    This function exhausts the iterable before returning. If order is not
    important, see :func:`distribute`, which does not first pull the
    iterable into memory.

    :raises ValueError: if *n* is less than 1.
    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Materialize the input unless it already supports slicing.
    try:
        iterable[:0]
    except TypeError:
        seq = tuple(iterable)
    else:
        seq = iterable

    # Each part gets `size` items; the first `extra` parts get one more.
    size, extra = divmod(len(seq), n)

    parts = []
    stop = 0
    for part_number in range(n):
        start = stop
        stop += size + (1 if part_number < extra else 0)
        parts.append(iter(seq[start:stop]))

    return parts
def always_iterable(obj, base_type=(str, bytes)):
    """If *obj* is iterable, return an iterator over its items; otherwise
    return a one-item iterator containing *obj*; for ``None``, return an
    empty iterator.

    >>> list(always_iterable((1, 2, 3)))
    [1, 2, 3]
    >>> list(always_iterable(1))
    [1]
    >>> list(always_iterable(None))
    []

    Binary and text strings are treated as single units by default:

    >>> list(always_iterable('foo'))
    ['foo']

    Instances of *base_type* are likewise treated as units rather than
    iterated:

    >>> obj = {'a': 1}
    >>> list(always_iterable(obj))  # Iterate over the dict's keys
    ['a']
    >>> list(always_iterable(obj, base_type=dict))  # Treat dicts as a unit
    [{'a': 1}]

    Pass ``base_type=None`` to disable the special handling entirely:

    >>> list(always_iterable('foo', base_type=None))
    ['f', 'o', 'o']
    """
    if obj is None:
        return iter(())

    treat_as_unit = (base_type is not None) and isinstance(obj, base_type)
    if not treat_as_unit:
        try:
            return iter(obj)
        except TypeError:
            # Not iterable: fall through to the single-item case.
            pass

    return iter((obj,))
def adjacent(predicate, iterable, distance=1):
    """Return an iterable over `(bool, item)` tuples where the `item` is
    drawn from *iterable* and the `bool` indicates whether that item
    satisfies the *predicate* or is within *distance* of an item that does.

    >>> list(adjacent(lambda x: x == 3, range(6)))
    [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
    [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    This is useful for contextualizing the results of a search function —
    for example, a diff tool that marks changed lines plus surrounding
    context. The predicate is called exactly once per item.

    See also :func:`groupby_transform`, which can be used with this function
    to group ranges of items with the same `bool` value.

    :raises ValueError: if *distance* is negative.
    """
    # Allow distance=0 mainly for testing that it reproduces results with map()
    if distance < 0:
        raise ValueError('distance must be at least 0')

    flags_source, originals = tee(iterable)
    padding = [False] * distance
    # Pad both ends so every item has a full window of 2*distance + 1 flags.
    flags = chain(padding, map(predicate, flags_source), padding)
    size = 2 * distance + 1

    def windowed_any(source):
        # True for each full-size sliding window containing a truthy flag.
        window = deque(maxlen=size)
        for flag in source:
            window.append(flag)
            if len(window) == size:
                yield any(window)

    # zip truncates to the length of the original iterable.
    return zip(windowed_any(flags), originals)
def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
    """An extension of :func:`itertools.groupby` that can apply transformations
    to the grouped data.

    * *keyfunc* computes a key value for each item in *iterable*
    * *valuefunc* transforms the individual items after grouping
    * *reducefunc* transforms each group of items

    >>> iterable = 'aAAbBBcCC'
    >>> keyfunc = lambda k: k.upper()
    >>> valuefunc = lambda v: v.lower()
    >>> reducefunc = lambda g: ''.join(g)
    >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
    [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]

    Each optional argument defaults to an identity function if not specified.

    To group elements of one iterable by a parallel iterable, :func:`zip`
    them and use *keyfunc*/*valuefunc* to pull the pieces apart::

        >>> from operator import itemgetter
        >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
        >>> values = 'abcdefghi'
        >>> iterable = zip(keys, values)
        >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
        >>> [(k, ''.join(g)) for k, g in grouper]
        [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]

    As with :func:`itertools.groupby`, only adjacent items are grouped
    together; sort the iterable by the key function first to avoid
    duplicate groups.

    """
    for group_key, group in groupby(iterable, keyfunc):
        if valuefunc:
            group = map(valuefunc, group)
        if reducefunc:
            group = reducefunc(group)
        yield (group_key, group)
class numeric_range(Sequence):
    """An extension of the built-in ``range()`` function whose arguments can
    be any orderable numeric type.

    With only *stop* specified, *start* defaults to ``0`` and *step*
    defaults to ``1``. The output items will match the type of *stop*:

    >>> list(numeric_range(3.5))
    [0.0, 1.0, 2.0, 3.0]

    With only *start* and *stop* specified, *step* defaults to ``1``. The
    output items will match the type of *start*:

    >>> from decimal import Decimal
    >>> start = Decimal('2.1')
    >>> stop = Decimal('5.1')
    >>> list(numeric_range(start, stop))
    [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    With *start*, *stop*, and *step* specified the output items will match
    the type of ``start + step``:

    >>> from fractions import Fraction
    >>> start = Fraction(1, 2)  # Start at 1/2
    >>> stop = Fraction(5, 2)  # End at 5/2
    >>> step = Fraction(1, 2)  # Count by 1/2
    >>> list(numeric_range(start, stop, step))
    [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    If *step* is zero, ``ValueError`` is raised. Negative steps are supported:

    >>> list(numeric_range(3, -1, -1.0))
    [3.0, 2.0, 1.0, 0.0]

    Be aware of the limitations of floating-point numbers; the representation
    of the yielded numbers may be surprising.

    ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
    is a ``datetime.timedelta`` object:

    >>> import datetime
    >>> start = datetime.datetime(2019, 1, 1)
    >>> stop = datetime.datetime(2019, 1, 3)
    >>> step = datetime.timedelta(days=1)
    >>> items = iter(numeric_range(start, stop, step))
    >>> next(items)
    datetime.datetime(2019, 1, 1, 0, 0)
    >>> next(items)
    datetime.datetime(2019, 1, 2, 0, 0)

    """

    # All empty ranges compare equal (see __eq__), so they all share the
    # hash of the built-in empty range.
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        # Mirror range()'s call signatures: (stop), (start, stop), or
        # (start, stop, step). Default start/step are constructed from the
        # argument types so that non-int types (Decimal, Fraction,
        # datetime/timedelta) stay internally consistent.
        argc = len(args)
        if argc == 1:
            (self._stop,) = args
            self._start = type(self._stop)(0)
            self._step = type(self._stop - self._start)(1)
        elif argc == 2:
            self._start, self._stop = args
            self._step = type(self._stop - self._start)(1)
        elif argc == 3:
            self._start, self._stop, self._step = args
        elif argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        else:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        # A zero of the step's type, used for all comparisons below (e.g.
        # timedelta(0) when step is a timedelta).
        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        # True when the range counts upward (positive step).
        self._growing = self._step > self._zero

    def __bool__(self):
        # Non-empty iff start has not already passed stop in the direction
        # of travel.
        if self._growing:
            return self._start < self._stop
        else:
            return self._start > self._stop

    def __contains__(self, elem):
        # Membership requires elem to lie in the half-open interval AND be
        # a whole number of steps away from start.
        if self._growing:
            if self._start <= elem < self._stop:
                return (elem - self._start) % self._step == self._zero
        else:
            if self._start >= elem > self._stop:
                return (self._start - elem) % (-self._step) == self._zero

        return False

    def __eq__(self, other):
        if isinstance(other, numeric_range):
            empty_self = not bool(self)
            empty_other = not bool(other)
            if empty_self or empty_other:
                return empty_self and empty_other  # True if both empty
            else:
                # Same first item, same step, and same last item implies
                # identical contents even if the declared stops differ.
                return (
                    self._start == other._start
                    and self._step == other._step
                    and self._get_by_index(-1) == other._get_by_index(-1)
                )
        else:
            return False

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._get_by_index(key)
        elif isinstance(key, slice):
            # Compose the slice's step with this range's step.
            step = self._step if key.step is None else key.step * self._step

            # Clamp the slice's start to the range's bounds, translating
            # index positions into actual values.
            if key.start is None or key.start <= -self._len:
                start = self._start
            elif key.start >= self._len:
                start = self._stop
            else:  # -self._len < key.start < self._len
                start = self._get_by_index(key.start)

            # Same clamping for the slice's stop.
            if key.stop is None or key.stop >= self._len:
                stop = self._stop
            elif key.stop <= -self._len:
                stop = self._start
            else:  # -self._len < key.stop < self._len
                stop = self._get_by_index(key.stop)

            return numeric_range(start, stop, step)
        else:
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )

    def __hash__(self):
        if self:
            # Hash on (first, last, step) to agree with __eq__.
            return hash((self._start, self._get_by_index(-1), self._step))
        else:
            return self._EMPTY_HASH

    def __iter__(self):
        # Generate start + n*step and cut off at the first value that
        # passes stop in the direction of travel.
        values = (self._start + (n * self._step) for n in count())
        if self._growing:
            return takewhile(partial(gt, self._stop), values)
        else:
            return takewhile(partial(lt, self._stop), values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        # Normalize to a growing range so distance and step are positive,
        # then use euclidean division to count the items.
        if self._growing:
            start = self._start
            stop = self._stop
            step = self._step
        else:
            start = self._stop
            stop = self._start
            step = -self._step
        distance = stop - start
        if distance <= self._zero:
            return 0
        else:  # distance > 0 and step > 0: regular euclidean division
            q, r = divmod(distance, step)
            # A nonzero remainder means one extra, partial step at the end.
            return int(q) + int(r != self._zero)

    def __reduce__(self):
        # Pickle support: rebuild from the three constructor arguments.
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        if self._step == 1:
            return f"numeric_range({self._start!r}, {self._stop!r})"
        return (
            f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
        )

    def __reversed__(self):
        # Walk from the last item back just past start with a negated step.
        return iter(
            numeric_range(
                self._get_by_index(-1), self._start - self._step, -self._step
            )
        )

    def count(self, value):
        # A range holds at most one copy of any value.
        return int(value in self)

    def index(self, value):
        # Same membership test as __contains__, but returns the number of
        # steps from start instead of a bool.
        if self._growing:
            if self._start <= value < self._stop:
                q, r = divmod(value - self._start, self._step)
                if r == self._zero:
                    return int(q)
        else:
            if self._start >= value > self._stop:
                q, r = divmod(self._start - value, -self._step)
                if r == self._zero:
                    return int(q)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        # Support negative indexing like built-in sequences.
        if i < 0:
            i += self._len
        if i < 0 or i >= self._len:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step
def count_cycle(iterable, n=None):
    """Cycle through the items from *iterable* up to *n* times, yielding
    the number of completed cycles along with each item. If *n* is omitted the
    process repeats indefinitely.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    """
    if n is not None:
        # A bounded count is exactly the Cartesian product of the cycle
        # indexes with the items.
        return product(range(n), iterable)

    # Unbounded case: materialize the items so they can be replayed.
    seq = tuple(iterable)
    if not seq:
        return iter(())
    # n is None here (the finite case returned above), so count forever.
    # The original computed `count() if n is None else range(n)`; the
    # range(n) branch was unreachable and has been removed.
    return ((i, item) for i in count() for item in seq)
def mark_ends(iterable):
    """Yield 3-tuples of the form ``(is_first, is_last, item)``.

    >>> list(mark_ends('ABC'))
    [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    Use this when looping over an iterable to take special action on its first
    and/or last items:

    >>> iterable = ['Header', 100, 200, 'Footer']
    >>> total = 0
    >>> for is_first, is_last, item in mark_ends(iterable):
    ...     if is_first:
    ...         continue  # Skip the header
    ...     if is_last:
    ...         continue  # Skip the footer
    ...     total += item
    >>> print(total)
    300
    """
    iterator = iter(iterable)
    # Look ahead by one item: `current` is only emitted once we know
    # whether another item follows it.
    try:
        current = next(iterator)
    except StopIteration:
        return
    is_first = True
    for upcoming in iterator:
        yield is_first, False, current
        current = upcoming
        is_first = False
    yield is_first, True, current
def locate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
    [1, 2, 4]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item.

    >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
    [1, 3]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(locate(iterable, pred=pred, window_size=3))
    [1, 5, 9]

    Use with :func:`seekable` to find indexes and then retrieve the associated
    items:

    >>> from itertools import count
    >>> from more_itertools import seekable
    >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
    >>> it = seekable(source)
    >>> pred = lambda x: x > 100
    >>> indexes = locate(it, pred=pred)
    >>> i = next(indexes)
    >>> it.seek(i)
    >>> next(it)
    106
    """
    # Simple case: test items one at a time and emit matching positions.
    if window_size is None:
        return (index for index, item in enumerate(iterable) if pred(item))

    if window_size < 1:
        raise ValueError('window size must be at least 1')

    # Windowed case: pred receives window_size items at a time; short
    # final windows are padded with the sentinel.
    windows = windowed(iterable, window_size, fillvalue=_marker)
    return (index for index, window in enumerate(windows) if pred(*window))
def longest_common_prefix(iterables):
    """Yield elements of the longest common prefix among given *iterables*.

    >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
    'ab'

    """
    # Walk the iterables in lockstep; keep emitting while every column
    # agrees on a single value.
    columns = zip(*iterables)
    agreeing = takewhile(
        lambda column: all(value == column[0] for value in column), columns
    )
    return (column[0] for column in agreeing)
def lstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the beginning
    for which *pred* returns ``True``.

    For example, to remove a set of items from the start of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(lstrip(iterable, pred))
    [1, 2, None, 3, False, None]

    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.

    """
    # dropwhile does exactly this: skip leading items while pred is true,
    # then pass everything else through untouched.
    return dropwhile(pred, iterable)
def rstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the end
    for which *pred* returns ``True``.

    For example, to remove a set of items from the end of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(rstrip(iterable, pred))
    [None, False, None, 1, 2, None, 3]

    This function is analogous to :func:`str.rstrip`.

    """
    # Buffer items that match pred; only release the buffer once a
    # non-matching item proves they were not trailing.
    buffered = []
    for item in iterable:
        if pred(item):
            buffered.append(item)
        else:
            yield from buffered
            buffered.clear()
            yield item
def strip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the
    beginning and end for which *pred* returns ``True``.

    For example, to remove a set of items from both ends of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(strip(iterable, pred))
    [1, 2, None, 3]

    This function is analogous to :func:`str.strip`.

    """
    # Skip leading matches with dropwhile, then buffer trailing matches
    # and only release them when a non-matching item follows.
    pending = []
    for item in dropwhile(pred, iterable):
        if pred(item):
            pending.append(item)
        else:
            yield from pending
            pending.clear()
            yield item
class islice_extended:
    """An extension of :func:`itertools.islice` that supports negative values
    for *stop*, *start*, and *step*.

    >>> iterator = iter('abcdefgh')
    >>> list(islice_extended(iterator, -4, -1))
    ['e', 'f', 'g']

    Slices with negative values require some caching of *iterable*, but this
    function takes care to minimize the amount of memory required.

    For example, you can use a negative step with an infinite iterator:

    >>> from itertools import count
    >>> list(islice_extended(count(), 110, 99, -2))
    [110, 108, 106, 104, 102, 100]

    You can also use slice notation directly:

    >>> iterator = map(str, count())
    >>> it = islice_extended(iterator)[10:20:2]
    >>> list(it)
    ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        # With slice arguments, delegate to the helper that implements
        # negative indices; otherwise just wrap the iterator unchanged.
        iterator = iter(iterable)
        if args:
            self._iterator = _islice_helper(iterator, slice(*args))
        else:
            self._iterator = iterator

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        # Only slice syntax is supported; single-integer indexing would
        # require consuming the iterator.
        if not isinstance(key, slice):
            raise TypeError(
                'islice_extended.__getitem__ argument must be a slice'
            )
        return islice_extended(_islice_helper(self._iterator, key))
def _islice_helper(it, s):
    """Yield the items of iterator *it* selected by slice *s*, supporting
    negative start/stop/step (which :func:`itertools.islice` does not).

    Negative indices require buffering part of *it*; each branch below
    keeps that buffer (a deque or list) as small as possible.
    """
    start = s.start
    stop = s.stop
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    step = s.step or 1

    if step > 0:
        start = 0 if (start is None) else start

        if start < 0:
            # Consume all but the last -start items
            # (enumerate from 1 so the last counter equals the total length)
            cache = deque(enumerate(it, 1), maxlen=-start)
            len_iter = cache[-1][0] if cache else 0

            # Adjust start to be positive
            i = max(len_iter + start, 0)

            # Adjust stop to be positive
            if stop is None:
                j = len_iter
            elif stop >= 0:
                j = min(stop, len_iter)
            else:
                j = max(len_iter + stop, 0)

            # Slice the cache
            n = j - i
            if n <= 0:
                return

            for index in range(n):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()[1]
                else:
                    # just pop and discard the item
                    cache.popleft()
        elif (stop is not None) and (stop < 0):
            # Advance to the start position
            next(islice(it, start, start), None)

            # When stop is negative, we have to carry -stop items while
            # iterating
            cache = deque(islice(it, -stop), maxlen=-stop)

            # Each new item pushed in proves the popped item is at least
            # -stop positions from the end, so it is safe to emit.
            for index, item in enumerate(it):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
                cache.append(item)
        else:
            # When both start and stop are positive we have the normal case
            yield from islice(it, start, stop, step)
    else:
        # Negative step: items are yielded in reverse order.
        start = -1 if (start is None) else start

        if (stop is not None) and (stop < 0):
            # Consume all but the last items
            n = -stop - 1
            cache = deque(enumerate(it, 1), maxlen=n)
            len_iter = cache[-1][0] if cache else 0

            # If start and stop are both negative they are comparable and
            # we can just slice. Otherwise we can adjust start to be negative
            # and then slice.
            if start < 0:
                i, j = start, stop
            else:
                i, j = min(start - len_iter, -1), None

            for index, item in list(cache)[i:j:step]:
                yield item
        else:
            # Advance to the stop position
            if stop is not None:
                m = stop + 1
                next(islice(it, m, m), None)

            # stop is positive, so if start is negative they are not comparable
            # and we need the rest of the items.
            if start < 0:
                i = start
                n = None
            # stop is None and start is positive, so we just need items up to
            # the start index.
            elif stop is None:
                i = None
                n = start + 1
            # Both stop and start are positive, so they are comparable.
            else:
                i = None
                n = start - stop
                if n <= 0:
                    return

            cache = list(islice(it, n))

            yield from cache[i::step]
def always_reversible(iterable):
    """An extension of :func:`reversed` that supports all iterables, not
    just those which implement the ``Reversible`` or ``Sequence`` protocols.

    >>> print(*always_reversible(x for x in range(3)))
    2 1 0

    If the iterable is already reversible, this function returns the
    result of :func:`reversed()`. If the iterable is not reversible,
    this function will cache the remaining items in the iterable and
    yield them in reverse order, which may require significant storage.
    """
    # EAFP: try the native protocol first; only materialize the items
    # when reversed() rejects the input.
    try:
        result = reversed(iterable)
    except TypeError:
        result = reversed(list(iterable))
    return result
def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items using :func:`itertools.groupby`.
    The *ordering* function determines whether two items are adjacent by
    returning their position.

    By default, the ordering function is the identity function. This is
    suitable for finding runs of numbers:

    >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
    >>> for group in consecutive_groups(iterable):
    ...     print(list(group))
    [1]
    [10, 11, 12]
    [20]
    [30, 31, 32, 33]
    [40]

    To find runs of adjacent letters, apply the :func:`ord` function
    to convert letters to ordinals.

    >>> iterable = 'abcdfgilmnop'
    >>> ordering = ord
    >>> for group in consecutive_groups(iterable, ordering):
    ...     print(list(group))
    ['a', 'b', 'c', 'd']
    ['f', 'g']
    ['i']
    ['l', 'm', 'n', 'o', 'p']

    Each group of consecutive items is an iterator that shares its source with
    *iterable*. When an output group is advanced, the previous group is
    no longer available unless its elements are copied (e.g., into a ``list``).

    >>> iterable = [1, 2, 11, 12, 21, 22]
    >>> saved_groups = []
    >>> for group in consecutive_groups(iterable):
    ...     saved_groups.append(list(group))  # Copy group elements
    >>> saved_groups
    [[1, 2], [11, 12], [21, 22]]

    """
    # Consecutive items keep a constant difference between their index
    # and their position, so grouping on that difference finds the runs.
    if ordering is None:

        def key(pair):
            return pair[0] - pair[1]

    else:

        def key(pair):
            return pair[0] - ordering(pair[1])

    for _, group in groupby(enumerate(iterable), key=key):
        yield map(itemgetter(1), group)
def difference(iterable, func=sub, *, initial=None):
    """This function is the inverse of :func:`itertools.accumulate`. By default
    it will compute the first difference of *iterable* using
    :func:`operator.sub`:

    >>> from itertools import accumulate
    >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
    >>> list(difference(iterable))
    [0, 1, 2, 3, 4]

    *func* defaults to :func:`operator.sub`, but other functions can be
    specified. They will be applied as follows::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, to do progressive division:

    >>> iterable = [1, 2, 6, 24, 120]
    >>> func = lambda x, y: x // y
    >>> list(difference(iterable, func))
    [1, 2, 3, 4, 5]

    If the *initial* keyword is set, the first element will be skipped when
    computing successive differences.

    >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
    >>> list(difference(it, initial=10))
    [1, 2, 3]

    """
    # Two copies of the input: `current` runs one item ahead of `lagged`,
    # so zipping them pairs each item with its predecessor.
    lagged, current = tee(iterable)
    try:
        head = [next(current)]
    except StopIteration:
        return iter([])

    differences = map(func, current, lagged)
    # With an initial value, the first item is itself a difference and the
    # untouched head is dropped.
    if initial is not None:
        return differences

    return chain(head, differences)
class SequenceView(Sequence):
    """Return a read-only view of the sequence object *target*.

    :class:`SequenceView` objects are analogous to Python's built-in
    "dictionary view" types. They provide a dynamic view of a sequence's items,
    meaning that when the sequence updates, so does the view.

    >>> seq = ['0', '1', '2']
    >>> view = SequenceView(seq)
    >>> view
    SequenceView(['0', '1', '2'])
    >>> seq.append('3')
    >>> view
    SequenceView(['0', '1', '2', '3'])

    Sequence views support indexing, slicing, and length queries. They act
    like the underlying sequence, except they don't allow assignment:

    >>> view[1]
    '1'
    >>> view[1:-1]
    ['1', '2']
    >>> len(view)
    4

    Sequence views are useful as an alternative to copying, as they don't
    require (much) extra storage.

    """

    def __init__(self, target):
        # Only real sequences can back a view; reject anything else early.
        if not isinstance(target, Sequence):
            raise TypeError
        self._target = target

    def __getitem__(self, index):
        # Delegate directly so indexes and slices behave like the target.
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        return f'{type(self).__name__}({self._target!r})'
class seekable:
    """Wrap an iterator to allow for seeking backward and forward. This
    progressively caches the items in the source iterable so they can be
    re-visited.

    Call :meth:`seek` with an index to seek to that position in the source
    iterable.

    To "reset" an iterator, seek to ``0``:

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.seek(0)
    >>> next(it), next(it), next(it)
    ('0', '1', '2')

    You can also seek forward:

    >>> it = seekable((str(n) for n in range(20)))
    >>> it.seek(10)
    >>> next(it)
    '10'
    >>> it.seek(20)  # Seeking past the end of the source isn't a problem
    >>> list(it)
    []
    >>> it.seek(0)  # Resetting works even after hitting the end
    >>> next(it)
    '0'

    Call :meth:`relative_seek` to seek relative to the source iterator's
    current position.

    >>> it = seekable((str(n) for n in range(20)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.relative_seek(2)
    >>> next(it)
    '5'
    >>> it.relative_seek(-3)  # Source is at '6', we move back to '3'
    >>> next(it)
    '3'
    >>> it.relative_seek(-3)  # Source is at '4', we move back to '1'
    >>> next(it)
    '1'

    Call :meth:`peek` to look ahead one item without advancing the iterator:

    >>> it = seekable('1234')
    >>> it.peek()
    '1'
    >>> list(it)
    ['1', '2', '3', '4']
    >>> it.peek(default='empty')
    'empty'

    Before the iterator is at its end, calling :func:`bool` on it will return
    ``True``. After it will return ``False``:

    >>> it = seekable('5678')
    >>> bool(it)
    True
    >>> list(it)
    ['5', '6', '7', '8']
    >>> bool(it)
    False

    You may view the contents of the cache with the :meth:`elements` method.
    That returns a :class:`SequenceView`, a view that updates automatically:

    >>> it = seekable((str(n) for n in range(10)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> elements = it.elements()
    >>> elements
    SequenceView(['0', '1', '2'])
    >>> next(it)
    '3'
    >>> elements
    SequenceView(['0', '1', '2', '3'])

    By default, the cache grows as the source iterable progresses, so beware of
    wrapping very large or infinite iterables. Supply *maxlen* to limit the
    size of the cache (this of course limits how far back you can seek).

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()), maxlen=2)
    >>> next(it), next(it), next(it), next(it)
    ('0', '1', '2', '3')
    >>> list(it.elements())
    ['2', '3']
    >>> it.seek(0)
    >>> next(it), next(it), next(it), next(it)
    ('2', '3', '4', '5')
    >>> next(it)
    '6'

    """

    def __init__(self, iterable, maxlen=None):
        self._source = iter(iterable)
        # With maxlen, the cache is a bounded deque: old items fall off and
        # can no longer be seeked back to.
        if maxlen is None:
            self._cache = []
        else:
            self._cache = deque([], maxlen)
        # _index is None while reading fresh items from the source;
        # otherwise it is the next cache position to replay from.
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        # Replay from the cache first, if a seek put us inside it.
        if self._index is not None:
            try:
                item = self._cache[self._index]
            except IndexError:
                # Ran off the end of the cache; fall through to the source.
                self._index = None
            else:
                self._index += 1
                return item

        # Pull a new item and remember it for future seeks.
        item = next(self._source)
        self._cache.append(item)
        return item

    def __bool__(self):
        # True while at least one more item is available.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        # Advance by one, then step the cache index back so the same item
        # is returned by the next __next__ call.
        try:
            peeked = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return peeked

    def elements(self):
        # Live, read-only view of everything cached so far.
        return SequenceView(self._cache)

    def seek(self, index):
        self._index = index
        # Seeking beyond the cache: advance the source to fill the gap.
        remainder = index - len(self._cache)
        if remainder > 0:
            consume(self, remainder)

    def relative_seek(self, count):
        # When reading from the source, the current position is the cache end.
        if self._index is None:
            self._index = len(self._cache)

        self.seek(max(self._index + count, 0))
class run_length:
    """
    :func:`run_length.encode` compresses an iterable with run-length encoding.
    It yields groups of repeated items with the count of how many times they
    were repeated:

    >>> uncompressed = 'abbcccdddd'
    >>> list(run_length.encode(uncompressed))
    [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` decompresses an iterable that was previously
    compressed with run-length encoding. It yields the items of the
    decompressed iterable:

    >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> list(run_length.decode(compressed))
    ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        # groupby collects each run; counting the group yields its length.
        return (
            (value, sum(1 for _ in group)) for value, group in groupby(iterable)
        )

    @staticmethod
    def decode(iterable):
        # Expand each (value, count) pair back into `count` copies.
        return chain.from_iterable(
            repeat(value, times) for value, times in iterable
        )
def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly ``n`` items in the iterable are ``True``
    according to the *predicate* function.

    >>> exactly_n([True, True, False], 2)
    True
    >>> exactly_n([True, True, False], 1)
    False
    >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
    True

    The iterable will be advanced until ``n + 1`` truthy items are encountered,
    so avoid calling it on infinite iterables.

    """
    # A negative target can never be met; return without consuming input.
    if n < 0:
        return False

    # Count matches, bailing out as soon as the count exceeds the target
    # so no more of the iterable is consumed than necessary.
    tally = 0
    for _ in filter(predicate, iterable):
        tally += 1
        if tally > n:
            return False
    return tally == n
def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

    >>> list(circular_shifts(range(4)))
    [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    Set *steps* to the number of places to rotate to the left
    (or to the right if negative).  Defaults to 1.

    >>> list(circular_shifts(range(4), 2))
    [(0, 1, 2, 3), (2, 3, 0, 1)]

    >>> list(circular_shifts(range(4), -1))
    [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    """
    ring = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Pre-rotate one step in the opposite direction so the first rotation
    # inside the loop restores (and yields) the original ordering.
    ring.rotate(steps)
    steps = -steps
    size = len(ring)
    # The shift sequence returns to the start after size/gcd rotations;
    # stop there to avoid yielding duplicates.
    shifts = size // math.gcd(size, steps)
    for _ in range(shifts):
        ring.rotate(steps)
        yield tuple(ring)
def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function that
    modifies an iterable. *result_index* is the position in that function's
    signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition. This can augment what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

    >>> from more_itertools import chunked
    >>> chunker = make_decorator(chunked, result_index=0)
    >>> @chunker(3)
    ... def iter_range(n):
    ...     return iter(range(n))
    ...
    >>> list(iter_range(9))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

    >>> truth_serum = make_decorator(filter, result_index=1)
    >>> @truth_serum(bool)
    ... def boolean_test():
    ...     return [0, 1, '', ' ', False, True]
    ...
    >>> list(boolean_test())
    [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators:

    >>> from more_itertools import peekable
    >>> peekable_function = make_decorator(peekable)
    >>> @peekable_function()
    ... def str_range(*args):
    ...     return (str(x) for x in range(*args))
    ...
    >>> it = str_range(1, 20, 2)
    >>> next(it), next(it), next(it)
    ('1', '3', '5')
    >>> it.peek()
    '7'
    >>> next(it)
    '7'

    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            # functools.wraps preserves the decorated function's metadata
            # (__name__, __doc__, __module__, ...), which was previously
            # lost behind the generic wrapper.
            @wraps(f)
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                # Splice the function's result into wrapping_func's
                # argument list at the requested position.
                wrapping_args_ = list(wrapping_args)
                wrapping_args_.insert(result_index, result)
                return wrapping_func(*wrapping_args_, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Return a dictionary that maps the items in *iterable* to categories
    defined by *keyfunc*, transforms them with *valuefunc*, and
    then summarizes them by category with *reducefunc*.

    *valuefunc* defaults to the identity function if it is unspecified.
    If *reducefunc* is unspecified, no summarization takes place:

    >>> keyfunc = lambda x: x.upper()
    >>> result = map_reduce('abbccc', keyfunc)
    >>> sorted(result.items())
    [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Specifying *valuefunc* transforms the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> result = map_reduce('abbccc', keyfunc, valuefunc)
    >>> sorted(result.items())
    [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Specifying *reducefunc* summarizes the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> reducefunc = sum
    >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
    >>> sorted(result.items())
    [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the map/reduce
    procedure:

    >>> all_items = range(30)
    >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
    >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
    >>> categories = map_reduce(items, keyfunc=keyfunc)
    >>> sorted(categories.items())
    [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
    >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
    >>> sorted(summaries.items())
    [(0, 90), (1, 75)]

    Note that all items in the iterable are gathered into a list before the
    summarization step, which may require significant storage.

    The returned object is a :obj:`collections.defaultdict` with the
    ``default_factory`` set to ``None``, such that it behaves like a normal
    dictionary.

    """
    result = defaultdict(list)

    # Group the (optionally transformed) items by their computed key.
    for item in iterable:
        value = item if valuefunc is None else valuefunc(item)
        result[keyfunc(item)].append(value)

    # Collapse each category's list to a single summary value.
    if reducefunc is not None:
        for key in result:
            result[key] = reducefunc(result[key])

    # Disable auto-vivification so missing keys raise KeyError as usual.
    result.default_factory = None
    return result
def rlocate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``, starting from the right and moving left.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
    [4, 2, 1]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item:

    >>> iterator = iter('abcb')
    >>> pred = lambda x: x == 'b'
    >>> list(rlocate(iterator, pred))
    [3, 1]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(rlocate(iterable, pred=pred, window_size=3))
    [9, 5, 1]

    Beware, this function won't return anything for infinite iterables.
    If *iterable* is reversible, ``rlocate`` will reverse it and search from
    the right. Otherwise, it will search from the left and return the results
    in reverse order.

    See :func:`locate` to for other example applications.

    """
    if window_size is None:
        # Fast path: a reversible input can be scanned right-to-left
        # lazily, converting reversed positions back to forward indexes.
        try:
            total = len(iterable)
            backward = reversed(iterable)
        except TypeError:
            pass
        else:
            return (total - i - 1 for i in locate(backward, pred))

    # Fallback: scan forward, materialize the matches, and reverse them.
    return reversed(list(locate(iterable, pred, window_size)))
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items from *iterable*, replacing the items for which *pred*
    returns ``True`` with the items from the iterable *substitutes*.

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
    >>> pred = lambda x: x == 0
    >>> substitutes = (2, 3)
    >>> list(replace(iterable, pred, substitutes))
    [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    If *count* is given, the number of replacements will be limited:

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
    >>> pred = lambda x: x == 0
    >>> substitutes = [None]
    >>> list(replace(iterable, pred, substitutes, count=2))
    [1, 1, None, 1, 1, None, 1, 1, 0]

    Use *window_size* to control the number of items passed as arguments to
    *pred*. This allows for locating and replacing subsequences.

    >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
    >>> window_size = 3
    >>> pred = lambda *args: args == (0, 1, 2) # 3 items passed to pred
    >>> substitutes = [3, 4] # Splice in these items
    >>> list(replace(iterable, pred, substitutes, window_size=window_size))
    [3, 4, 5, 3, 4, 5]

    ``ValueError`` is raised if *window_size* is less than 1.
    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # Save the substitutes iterable, since it's used more than once
    substitutes = tuple(substitutes)

    # Add padding such that the number of windows matches the length of the
    # iterable
    it = chain(iterable, repeat(_marker, window_size - 1))
    windows = windowed(it, window_size)

    # n counts how many replacements have been made so far (vs. *count*)
    n = 0
    for w in windows:
        # If the current window matches our predicate (and we haven't hit
        # our maximum number of replacements), splice in the substitutes
        # and then consume the following windows that overlap with this one.
        # For example, if the iterable is (0, 1, 2, 3, 4...)
        # and the window size is 2, we have (0, 1), (1, 2), (2, 3)...
        # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2)
        if pred(*w):
            if (count is None) or (n < count):
                n += 1
                yield from substitutes
                # Skip the overlapping windows by advancing the shared
                # *windows* iterator in place.
                consume(windows, window_size - 1)
                continue

        # If there was no match (or we've reached the replacement limit),
        # yield the first item from the window. The _marker padding added
        # above is never emitted.
        if w and (w[0] is not _marker):
            yield w[0]
def partitions(iterable):
    """Yield all possible order-preserving partitions of *iterable*.

    >>> iterable = 'abc'
    >>> for part in partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['a', 'b', 'c']

    This is unrelated to :func:`partition`.

    """
    items = list(iterable)
    n = len(items)
    # Every subset of the interior positions 1..n-1 is a set of cut points;
    # slicing between consecutive bounds produces one partition.
    for cuts in powerset(range(1, n)):
        bounds = (0,) + cuts + (n,)
        yield [items[lo:hi] for lo, hi in zip(bounds, bounds[1:])]
def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """
    Yield the set partitions of *iterable* into *k* parts. Set partitions are
    not order-preserving.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, 2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']


    If *k* is not given, every set partition is generated.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    if *min_size* and/or *max_size* are given, the minimum and/or maximum size
    per block in partition is set.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, min_size=2):
    ...     print([''.join(p) for p in part])
    ['abc']
    >>> for part in set_partitions(iterable, max_size=2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    """
    pool = list(iterable)
    n = len(pool)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        elif k > n:
            return

    # Unset size bounds default to the loosest possible constraint.
    min_size = min_size if min_size is not None else 0
    max_size = max_size if max_size is not None else n
    if min_size > max_size:
        return

    def helper(items, parts):
        # Recursively partition *items* into exactly *parts* blocks.
        if parts == 1:
            yield [items]
        elif len(items) == parts:
            yield [[item] for item in items]
        else:
            head, *tail = items
            # Either *head* forms a singleton block...
            for rest in helper(tail, parts - 1):
                yield [[head], *rest]
            # ...or it joins one of the blocks of a partition of *tail*.
            for rest in helper(tail, parts):
                for i in range(len(rest)):
                    yield rest[:i] + [[head] + rest[i]] + rest[i + 1 :]

    def size_ok(partition):
        return all(min_size <= len(block) <= max_size for block in partition)

    block_counts = range(1, n + 1) if k is None else [k]
    for count in block_counts:
        yield from filter(size_ok, helper(pool, count))
class time_limited:
    """
    Yield items from *iterable* until *limit_seconds* have passed.
    If the time limit expires before all items have been yielded, the
    ``timed_out`` parameter will be set to ``True``.

    >>> from time import sleep
    >>> def generator():
    ...     yield 1
    ...     yield 2
    ...     sleep(0.2)
    ...     yield 3
    >>> iterable = time_limited(0.1, generator())
    >>> list(iterable)
    [1, 2]
    >>> iterable.timed_out
    True

    Note that the time is checked before each item is yielded, and iteration
    stops if the time elapsed is greater than *limit_seconds*. If your time
    limit is 1 second, but it takes 2 seconds to generate the first item from
    the iterable, the function will run for 2 seconds and not yield anything.
    As a special case, when *limit_seconds* is zero, the iterator never
    returns anything.

    """

    def __init__(self, limit_seconds, iterable):
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be positive')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        # The clock starts at construction, not at first iteration.
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        # A zero-second budget never produces anything.
        if self.limit_seconds == 0:
            self.timed_out = True
            raise StopIteration

        item = next(self._iterator)
        elapsed = monotonic() - self._start_time
        if elapsed > self.limit_seconds:
            self.timed_out = True
            raise StopIteration

        return item
def only(iterable, default=None, too_long=None):
    """If *iterable* has only one item, return it.
    If it has zero items, return *default*.
    If it has more than one item, raise the exception given by *too_long*,
    which is ``ValueError`` by default.

    >>> only([], default='missing')
    'missing'
    >>> only([1])
    1
    >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 1, 2,
     and perhaps more.'
    >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    TypeError

    Note that :func:`only` attempts to advance *iterable* twice to ensure there
    is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    iterator = iter(iterable)
    sentinel = object()

    first = next(iterator, sentinel)
    if first is sentinel:
        # Zero items
        return default

    second = next(iterator, sentinel)
    if second is sentinel:
        # Exactly one item
        return first

    msg = (
        f'Expected exactly one item in iterable, but got {first!r}, '
        f'{second!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
def ichunked(iterable, n):
    """Break *iterable* into sub-iterables with *n* elements each.
    :func:`ichunked` is like :func:`chunked`, but it yields iterables
    instead of lists.

    If the sub-iterables are read in order, the elements of *iterable*
    won't be stored in memory.
    If they are read out of order, :func:`itertools.tee` is used to cache
    elements as necessary.

    >>> from itertools import count
    >>> all_chunks = ichunked(count(), 4)
    >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
    >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
    [4, 5, 6, 7]
    >>> list(c_1)
    [0, 1, 2, 3]
    >>> list(c_3)
    [8, 9, 10, 11]

    """
    iterator = iter(iterable)
    for first in iterator:
        # The current chunk is *first* plus up to n - 1 following items.
        rest = islice(iterator, n - 1)
        # tee() lets *cacher* (below) pull unread items into *cache* so
        # the chunk can still be read after the loop has moved on.
        cache, cacher = tee(rest)
        yield chain([first], rest, cache)
        # Before producing the next chunk, drain whatever the consumer left
        # unread. The items flow through the tee into *cache*, so a later
        # out-of-order read of the yielded chunk still sees them (after
        # *rest*, which is exhausted by this point). Items already read
        # directly from *rest* are gone from the tee and are not re-cached.
        consume(cacher)
def iequals(*iterables):
    """Return ``True`` if all given *iterables* are equal to each other,
    which means that they contain the same elements in the same order.

    The function is useful for comparing iterables of different data types
    or iterables that do not support equality checks.

    >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
    True

    >>> iequals("abc", "acb")
    False

    Not to be confused with :func:`all_equal`, which checks whether all
    elements of iterable are equal to each other.

    """
    # zip(strict=True) raises ValueError as soon as any iterable runs out
    # before the others, which means the lengths differ.
    try:
        for group in zip(*iterables, strict=True):
            if not all_equal(group):
                return False
    except ValueError:
        return False
    return True
def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

    >>> list(distinct_combinations([0, 0, 1], 2))
    [(0, 0), (0, 1)]

    Equivalent to ``set(combinations(iterable))``, except duplicates are not
    generated and thrown away. For larger input sequences this is much more
    efficient.

    Raises ``ValueError`` if *r* is negative.
    """
    if r < 0:
        raise ValueError('r must be non-negative')
    elif r == 0:
        # The only 0-length combination is the empty tuple.
        yield ()
        return
    pool = tuple(iterable)
    # Depth-first search with an explicit stack of generators: the generator
    # at stack depth *level* produces the candidates for position *level* of
    # the combination. unique_everseen (keyed on the item value, not its
    # index) skips duplicate values so each combination appears only once.
    generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
    current_combo = [None] * r
    level = 0
    while generators:
        try:
            cur_idx, p = next(generators[-1])
        except StopIteration:
            # Candidates at this depth exhausted: backtrack.
            generators.pop()
            level -= 1
            continue
        current_combo[level] = p
        if level + 1 == r:
            # Combination complete; stay at this level for the next candidate.
            yield tuple(current_combo)
        else:
            # Descend: later positions draw only from items after cur_idx,
            # which keeps combinations ordered and free of repeats.
            generators.append(
                unique_everseen(
                    enumerate(pool[cur_idx + 1 :], cur_idx + 1),
                    key=itemgetter(1),
                )
            )
            level += 1
def filter_except(validator, iterable, *exceptions):
    """Yield the items from *iterable* for which the *validator* function does
    not raise one of the specified *exceptions*.

    *validator* is called for each item in *iterable*.
    It should be a function that accepts one argument and raises an exception
    if that item is not valid.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(filter_except(int, iterable, ValueError, TypeError))
    ['1', '2', '4']

    If an exception other than one given by *exceptions* is raised by
    *validator*, it is raised like normal.
    """
    for item in iterable:
        # Invalid items are identified by the exception they trigger;
        # anything not listed propagates to the caller.
        try:
            validator(item)
        except exceptions:
            continue
        yield item
def map_except(function, iterable, *exceptions):
    """Transform each item from *iterable* with *function* and yield the
    result, unless *function* raises one of the specified *exceptions*.

    *function* is called to transform each item in *iterable*.
    It should accept one argument.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(map_except(int, iterable, ValueError, TypeError))
    [1, 2, 4]

    If an exception other than one given by *exceptions* is raised by
    *function*, it is raised like normal.
    """
    for obj in iterable:
        # The yield stays inside the try so listed exceptions are swallowed
        # exactly where they occur; unlisted ones propagate unchanged.
        try:
            yield function(obj)
        except exceptions:
            pass
def map_if(iterable, pred, func, func_else=None):
    """Evaluate each item from *iterable* using *pred*. If the result is
    equivalent to ``True``, transform the item with *func* and yield it.
    Otherwise, transform the item with *func_else* and yield it.

    *pred*, *func*, and *func_else* should each be functions that accept
    one argument. By default, *func_else* is the identity function.

    >>> from math import sqrt
    >>> iterable = list(range(-5, 5))
    >>> iterable
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
    >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
    >>> list(map_if(iterable, lambda x: x >= 0,
    ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
    [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
    """
    # A missing func_else means "pass the item through unchanged".
    fallback = (lambda value: value) if func_else is None else func_else
    for item in iterable:
        yield func(item) if pred(item) else fallback(item)
3761def _sample_unweighted(iterator, k, strict):
3762 # Algorithm L in the 1994 paper by Kim-Hung Li:
3763 # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
3765 reservoir = list(islice(iterator, k))
3766 if strict and len(reservoir) < k:
3767 raise ValueError('Sample larger than population')
3768 W = 1.0
3770 with suppress(StopIteration):
3771 while True:
3772 W *= random() ** (1 / k)
3773 skip = floor(log(random()) / log1p(-W))
3774 element = next(islice(iterator, skip, None))
3775 reservoir[randrange(k)] = element
3777 shuffle(reservoir)
3778 return reservoir
def _sample_weighted(iterator, k, weights, strict):
    # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
    # "Weighted random sampling with a reservoir".

    # Log-transform for numerical stability for weights that are small/large
    weight_keys = (log(random()) / weight for weight in weights)

    # Fill up the reservoir (collection of samples) with the first `k`
    # weight-keys and elements, then heapify the list.
    reservoir = take(k, zip(weight_keys, iterator))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Min-heap: the root holds the (log-space) smallest weight-key, i.e.
    # the candidate most likely to be evicted next.
    heapify(reservoir)

    # The number of jumps before changing the reservoir is a random variable
    # with an exponential distribution. Sample it using random() and logs.
    smallest_weight_key, _ = reservoir[0]
    weights_to_skip = log(random()) / smallest_weight_key

    for weight, element in zip(weights, iterator):
        if weight >= weights_to_skip:
            # The notation here is consistent with the paper, but we store
            # the weight-keys in log-space for better numerical stability.
            smallest_weight_key, _ = reservoir[0]
            t_w = exp(weight * smallest_weight_key)
            r_2 = uniform(t_w, 1)  # generate U(t_w, 1)
            weight_key = log(r_2) / weight
            # Evict the root and insert the new (weight-key, element) pair.
            heapreplace(reservoir, (weight_key, element))
            # Re-sample the skip budget from the new heap minimum.
            smallest_weight_key, _ = reservoir[0]
            weights_to_skip = log(random()) / smallest_weight_key
        else:
            # Not selected: just spend this element's weight from the budget.
            weights_to_skip -= weight

    # Equivalent to [element for weight_key, element in reservoir], but
    # shuffled so the output order doesn't leak the heap layout.
    ret = [element for weight_key, element in reservoir]
    shuffle(ret)
    return ret
3820def _sample_counted(population, k, counts, strict):
3821 element = None
3822 remaining = 0
3824 def feed(i):
3825 # Advance *i* steps ahead and consume an element
3826 nonlocal element, remaining
3828 while i + 1 > remaining:
3829 i = i - remaining
3830 element = next(population)
3831 remaining = next(counts)
3832 remaining -= i + 1
3833 return element
3835 with suppress(StopIteration):
3836 reservoir = []
3837 for _ in range(k):
3838 reservoir.append(feed(0))
3840 if strict and len(reservoir) < k:
3841 raise ValueError('Sample larger than population')
3843 with suppress(StopIteration):
3844 W = 1.0
3845 while True:
3846 W *= random() ** (1 / k)
3847 skip = floor(log(random()) / log1p(-W))
3848 element = feed(skip)
3849 reservoir[randrange(k)] = element
3851 shuffle(reservoir)
3852 return reservoir
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a *k*-length list of elements chosen (without replacement)
    from the *iterable*.

    Similar to :func:`random.sample`, but works on inputs that aren't
    indexable (such as sets and dictionaries) and on inputs where the
    size isn't known in advance (such as generators).

    >>> iterable = range(100)
    >>> sample(iterable, 5)  # doctest: +SKIP
    [81, 60, 96, 16, 4]

    For iterables with repeated elements, you may supply *counts* to
    indicate the repeats.

    >>> iterable = ['a', 'b']
    >>> counts = [3, 4]  # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
    >>> sample(iterable, k=3, counts=counts)  # doctest: +SKIP
    ['a', 'a', 'b']

    An iterable with *weights* may be given:

    >>> iterable = range(100)
    >>> weights = (i * i + 1 for i in range(100))
    >>> sampled = sample(iterable, 5, weights=weights)  # doctest: +SKIP
    [79, 67, 74, 66, 78]

    Weighted selections are made without replacement.
    After an element is selected, it is removed from the pool and the
    relative weights of the other elements increase (this
    does not match the behavior of :func:`random.sample`'s *counts*
    parameter). Note that *weights* may not be used with *counts*.

    If the length of *iterable* is less than *k*,
    ``ValueError`` is raised if *strict* is ``True`` and
    all elements are returned (in shuffled order) if *strict* is ``False``.

    By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
    technique is used. When *weights* are provided,
    `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used instead.

    Notes on reproducibility: the algorithms rely on inexact floating-point
    math functions whose results can differ slightly between builds, and
    they select by ordinal position, so samples from unordered collections
    (such as sets) won't reproduce across sessions even with the same seed.
    """
    # iter() runs first so a non-iterable input always raises TypeError,
    # even when k == 0.
    iterator = iter(iterable)

    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []

    if weights is not None and counts is not None:
        raise TypeError('weights and counts are mutually exclusive')

    # Dispatch to the reservoir-sampling variant that fits the inputs.
    if weights is not None:
        return _sample_weighted(iterator, k, iter(weights), strict)
    if counts is not None:
        return _sample_counted(iterator, k, iter(counts), strict)
    return _sample_unweighted(iterator, k, strict)
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Returns ``True`` if the items of iterable are in sorted order, and
    ``False`` otherwise. *key* and *reverse* have the same meaning that they do
    in the built-in :func:`sorted` function.

    >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
    True
    >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
    False

    If *strict*, tests for strict sorting, that is, returns ``False`` if equal
    elements are found:

    >>> is_sorted([1, 2, 2])
    True
    >>> is_sorted([1, 2, 2], strict=True)
    False

    The function returns ``False`` after encountering the first out-of-order
    item, which means it may produce results that differ from the built-in
    :func:`sorted` function for objects with unusual comparison dynamics
    (like ``math.nan``). If there are no out-of-order items, the iterable is
    exhausted.
    """
    keyed = iterable if key is None else map(key, iterable)
    # Pair each item with its successor by offsetting one tee branch.
    left, right = tee(keyed)
    next(right, None)
    if reverse:
        left, right = right, left
    if strict:
        # Every neighbor pair must be strictly increasing.
        return all(map(lt, left, right))
    # No neighbor pair may be decreasing.
    return not any(map(lt, right, left))
class AbortThread(BaseException):
    """Raised inside a :class:`callback_iter` worker thread to stop it.

    Subclasses ``BaseException`` rather than ``Exception`` so that an
    ``except Exception`` handler in the user's function won't swallow it.
    """

    pass
class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

    >>> def func(callback=None):
    ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
    ...         if callback:
    ...             callback(i, c)
    ...     return 4


    Use ``with callback_iter(func)`` to get an iterator over the parameters
    that are delivered to the callback.

    >>> with callback_iter(func) as it:
    ...     for args, kwargs in it:
    ...         print(args)
    (1, 'a')
    (2, 'b')
    (3, 'c')

    The function will be called in a background thread. The ``done`` property
    indicates whether it has completed execution.

    >>> it.done
    True

    If it completes successfully, its return value will be available
    in the ``result`` property.

    >>> it.result
    4

    Notes:

    * If the function uses some keyword argument besides ``callback``, supply
      *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently the it is polled for
      output.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        self._func = func
        self._callback_kwd = callback_kwd
        # Set by __exit__; checked by the callback to abort the worker.
        self._aborted = False
        self._future = None
        self._wait_seconds = wait_seconds
        # Lazily import concurrent.future
        self._module = __import__('concurrent.futures').futures
        # A single worker thread runs *func*; its output is read via a queue.
        self._executor = self._module.ThreadPoolExecutor(max_workers=1)
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Make the next callback invocation raise AbortThread, then wait
        # for the worker thread to finish.
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    @property
    def done(self):
        # False until _reader() has submitted the function to the executor.
        if self._future is None:
            return False
        return self._future.done()

    @property
    def result(self):
        # A zero-second timeout returns immediately: either the function has
        # finished (return its value / re-raise its exception) or it hasn't.
        if self._future:
            try:
                return self._future.result(timeout=0)
            except self._module.TimeoutError:
                pass

        raise RuntimeError('Function has not yet completed')

    def _reader(self):
        q = Queue()

        def callback(*args, **kwargs):
            # Runs in the worker thread; AbortThread unwinds the user's
            # function when the context manager has exited.
            if self._aborted:
                raise AbortThread('canceled by user')

            q.put((args, kwargs))

        self._future = self._executor.submit(
            self._func, **{self._callback_kwd: callback}
        )

        # Poll the queue, yielding each (args, kwargs) delivered to the
        # callback, until the function finishes.
        while True:
            try:
                item = q.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                q.task_done()
                yield item

            if self._future.done():
                break

        # The function may have produced more output between the last get()
        # and done() becoming true; drain the queue before finishing.
        remaining = []
        while True:
            try:
                item = q.get_nowait()
            except Empty:
                break
            else:
                q.task_done()
                remaining.append(item)
        q.join()
        yield from remaining
def windowed_complete(iterable, n):
    """
    Yield ``(beginning, middle, end)`` tuples, where:

    * Each ``middle`` has *n* items from *iterable*
    * Each ``beginning`` has the items before the ones in ``middle``
    * Each ``end`` has the items after the ones in ``middle``

    >>> iterable = range(7)
    >>> n = 3
    >>> for beginning, middle, end in windowed_complete(iterable, n):
    ...     print(beginning, middle, end)
    () (0, 1, 2) (3, 4, 5, 6)
    (0,) (1, 2, 3) (4, 5, 6)
    (0, 1) (2, 3, 4) (5, 6)
    (0, 1, 2) (3, 4, 5) (6,)
    (0, 1, 2, 3) (4, 5, 6) ()

    Note that *n* must be at least 0 and most equal to the length of
    *iterable*.

    This function will exhaust the iterable and may require significant
    storage.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    seq = tuple(iterable)
    size = len(seq)

    if n > size:
        raise ValueError('n must be <= len(seq)')

    # Slide the n-item middle window across the sequence; everything before
    # it is the beginning, everything after it the end.
    for start in range(size - n + 1):
        stop = start + n
        yield seq[:start], seq[start:stop], seq[stop:]
def all_unique(iterable, key=None):
    """
    Returns ``True`` if all the elements of *iterable* are unique (no two
    elements are equal).

    >>> all_unique('ABCB')
    False

    If a *key* function is specified, it will be used to make comparisons.

    >>> all_unique('ABCb')
    True
    >>> all_unique('ABCb', str.lower)
    False

    The function returns as soon as the first non-unique element is
    encountered. Iterables with a mix of hashable and unhashable items can
    be used, but the function will be slower for unhashable items.
    """
    # Hashable values go in a set (O(1) membership); unhashable ones fall
    # back to a list (O(n) membership).
    hashables = set()
    unhashables = []
    for value in map(key, iterable) if key else iterable:
        try:
            if value in hashables:
                return False
            hashables.add(value)
        except TypeError:
            if value in unhashables:
                return False
            unhashables.append(value)
    return True
def nth_product(index, *iterables, repeat=1):
    """Equivalent to ``list(product(*iterables, repeat=repeat))[index]``.

    The products of *iterables* can be ordered lexicographically.
    :func:`nth_product` computes the product at sort position *index* without
    computing the previous products.

    >>> nth_product(8, range(2), range(2), range(2), range(2))
    (1, 0, 0, 0)

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables. The above example is equivalent to::

    >>> nth_product(8, range(2), repeat=4)
    (1, 0, 0, 0)

    ``IndexError`` will be raised if the given *index* is invalid.
    """
    # Work from the least-significant (last) pool outward, so reverse first.
    pools = tuple(map(tuple, reversed(iterables))) * repeat
    sizes = tuple(map(len, pools))

    total = prod(sizes)

    # Support negative indexing, like a sequence would.
    if index < 0:
        index += total
    if not 0 <= index < total:
        raise IndexError

    # Decompose *index* into mixed-radix digits, one per pool.
    selection = []
    for pool, size in zip(pools, sizes):
        index, digit = divmod(index, size)
        selection.append(pool[digit])

    return tuple(reversed(selection))
def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]```

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`nth_permutation`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences.

    >>> nth_permutation('ghijk', 2, 5)
    ('h', 'i')

    ``ValueError`` will be raised If *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = list(iterable)
    n = len(pool)
    if r is None:
        r = n
    total = perm(n, r)  # raises ValueError when r is negative

    # Support negative indexing, like a sequence would.
    if index < 0:
        index += total
    if not 0 <= index < total:
        raise IndexError

    # Decompose *index* into factorial-number-system digits; only the first
    # r digits (positions 0..r-1) are kept.
    digits = [0] * r
    quotient = index * factorial(n) // total if r < n else index
    for radix in range(1, n + 1):
        quotient, digit = divmod(quotient, radix)
        if 0 <= n - radix < r:
            digits[n - radix] = digit
        if quotient == 0:
            break

    # pool.pop re-indexes the remaining items, matching the digits' meaning.
    return tuple(map(pool.pop, digits))
def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.


    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`nth_combination_with_replacement`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences with replacement.

    >>> nth_combination_with_replacement(range(5), 3, 5)
    (0, 1, 1)

    ``ValueError`` will be raised If *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if r < 0:
        raise ValueError
    # Number of multisets of size r from n items: C(n + r - 1, r); the
    # empty pool admits exactly one size-0 selection and none otherwise.
    total = comb(n + r - 1, r) if n else 0 if r else 1

    # Support negative indexing, like a sequence would.
    if index < 0:
        index += total
    if not 0 <= index < total:
        raise IndexError

    result = []
    pos = 0
    while r:
        r -= 1
        # Find the smallest eligible pool item whose block of combinations
        # contains *index*, subtracting the blocks that come before it.
        while n >= 0:
            block = comb(n + r - 1, r)
            if index < block:
                break
            n -= 1
            pos += 1
            index -= block
        result.append(pool[pos])

    return tuple(result)
def value_chain(*args):
    """Yield all arguments passed to the function in the same order in which
    they were passed. If an argument itself is iterable then iterate over its
    values.

    >>> list(value_chain(1, 2, 3, [4, 5, 6]))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and are emitted
    as-is:

    >>> list(value_chain('12', '34', ['56', '78']))
    ['12', '34', '56', '78']

    Pre- or postpend a single element to an iterable:

    >>> list(value_chain(1, [2, 3, 4, 5, 6]))
    [1, 2, 3, 4, 5, 6]
    >>> list(value_chain([1, 2, 3, 4, 5], 6))
    [1, 2, 3, 4, 5, 6]

    Multiple levels of nesting are not flattened.

    """
    for value in args:
        # Text and byte strings are treated as scalars, not as iterables.
        if isinstance(value, (str, bytes)):
            yield value
        else:
            try:
                yield from value
            except TypeError:
                # Not iterable: emit the value itself.
                yield value
def product_index(element, *iterables, repeat=1):
    """Equivalent to ``list(product(*iterables, repeat=repeat)).index(tuple(element))``

    The products of *iterables* can be ordered lexicographically.
    :func:`product_index` computes the first index of *element* without
    computing the previous products.

    >>> product_index([8, 2], range(10), range(5))
    42

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables::

    >>> product_index([8, 0, 7], range(10), repeat=3)
    807

    ``ValueError`` will be raised if the given *element* isn't in the product
    of *args*.
    """
    items = tuple(element)
    pools = tuple(map(tuple, iterables)) * repeat
    if len(items) != len(pools):
        raise ValueError('element is not a product of args')

    # Horner's scheme over mixed radixes: each pool contributes one digit.
    # pool.index() raises ValueError if an item isn't in its pool.
    index = 0
    for item, pool in zip(items, pools):
        index = index * len(pool) + pool.index(item)
    return index
def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`combination_index` computes the index of the
    first *element*, without computing the previous combinations.

    >>> combination_index('adf', 'abcdefg')
    10

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations of *iterable*.
    """
    element = enumerate(element)
    k, y = next(element, (None, None))
    if k is None:
        # An empty element is the single 0-length combination, at index 0.
        return 0

    # Find the position in *iterable* of each item of *element*, in order.
    # k ends up as the index of the last item of *element* (i.e. r - 1).
    indexes = []
    pool = enumerate(iterable)
    for n, x in pool:
        if x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                break
            else:
                k = tmp
    else:
        # The pool ran out before all of *element* was matched.
        raise ValueError('element is not a combination of iterable')

    # Exhaust the pool to learn the index of its last item; n becomes
    # len(iterable) - 1 (unchanged if the pool is already exhausted).
    n, _ = last(pool, default=(n, None))

    # Count the combinations at or after *element* via the complement:
    # positions are mirrored (j = n - j) so comb() can count them.
    index = 1
    for i, j in enumerate(reversed(indexes), start=1):
        j = n - j
        if i <= j:
            index += comb(j, i)

    return comb(n + 1, k + 1) - index
def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``

    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`combination_with_replacement_index`
    computes the index of the first *element*, without computing the previous
    combinations with replacement.

    >>> combination_with_replacement_index('adf', 'abcdefg')
    20

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations with replacement of *iterable*.
    """
    element = tuple(element)
    l = len(element)  # r, the length of the combination
    element = enumerate(element)

    k, y = next(element, (None, None))
    if k is None:
        # An empty element is always the first (index 0) combination.
        return 0

    # Record the pool position at which each item of *element* matches.
    # The inner while loop lets one pool item match several consecutive
    # (repeated) element items.
    indexes = []
    pool = tuple(iterable)
    for n, x in enumerate(pool):
        while x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                break
            else:
                k = tmp
        if y is None:
            # Every item of *element* has been matched.
            break
    else:
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    # Convert the matched positions into per-slot occupancy counts
    # (a multiset over the pool positions).
    n = len(pool)
    occupations = [0] * n
    for p in indexes:
        occupations[p] += 1

    # Count preceding combinations-with-replacement by summing binomial
    # coefficients (stars-and-bars encoding of the occupancies).
    index = 0
    cumulative_sum = 0
    for k in range(1, n):
        cumulative_sum += occupations[k - 1]
        j = l + n - 1 - k - cumulative_sum
        i = n - k
        if i <= j:
            index += comb(j, i)

    return index
def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)``

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`permutation_index`
    computes the index of the first *element* directly, without computing
    the previous permutations.

    >>> permutation_index([1, 3, 2], range(5))
    19

    ``ValueError`` will be raised if the given *element* isn't one of the
    permutations of *iterable*.
    """
    # Factorial number system: the radix shrinks by one at each step as
    # matched items are removed from the candidate pool.
    remaining = list(iterable)
    result = 0
    for radix, value in zip(range(len(remaining), -1, -1), element):
        position = remaining.index(value)
        result = result * radix + position
        del remaining[position]

    return result
class countable:
    """Wrap *iterable* and keep a count of how many items have been consumed.

    The ``items_seen`` attribute starts at ``0`` and increments as the
    iterable is consumed:

    >>> iterable = map(str, range(10))
    >>> it = countable(iterable)
    >>> it.items_seen
    0
    >>> next(it), next(it)
    ('0', '1')
    >>> list(it)
    ['2', '3', '4', '5', '6', '7', '8', '9']
    >>> it.items_seen
    10
    """

    def __init__(self, iterable):
        # Underlying iterator and a running tally of delivered items.
        self._it = iter(iterable)
        self.items_seen = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Only count items that were actually produced; a StopIteration
        # from the underlying iterator propagates before the increment.
        value = next(self._it)
        self.items_seen = self.items_seen + 1
        return value
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.

    Items are distributed such that the lengths of the lists differ by at
    most 1 item.

    >>> iterable = [1, 2, 3, 4, 5, 6, 7]
    >>> n = 3
    >>> list(chunked_even(iterable, n))  # List lengths: 3, 2, 2
    [[1, 2, 3], [4, 5], [6, 7]]
    >>> list(chunked(iterable, n))  # List lengths: 3, 3, 1
    [[1, 2, 3], [4, 5, 6], [7]]

    """
    iterator = iter(iterable)

    # Initialize a buffer to process the chunks while keeping
    # some back to fill any underfilled chunks
    min_buffer = (n - 1) * (n - 2)
    buffer = list(islice(iterator, min_buffer))

    # Append items until we have a completed chunk. The islice-over-map
    # trick appends n source items per step and yields once per full chunk,
    # always leaving at least min_buffer items held back.
    for _ in islice(map(buffer.append, iterator), n, None, n):
        yield buffer[:n]
        del buffer[:n]

    # Check if any chunks need additional processing
    if not buffer:
        return
    length = len(buffer)

    # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
    q, r = divmod(length, n)
    num_lists = q + (1 if r > 0 else 0)
    q, r = divmod(length, num_lists)
    full_size = q + (1 if r > 0 else 0)
    partial_size = full_size - 1
    num_full = length - partial_size * num_lists

    # Yield chunks of full size
    partial_start_idx = num_full * full_size
    if full_size > 0:
        for i in range(0, partial_start_idx, full_size):
            yield buffer[i : i + full_size]

    # Yield chunks of partial size
    if partial_size > 0:
        for i in range(partial_start_idx, length, partial_size):
            yield buffer[i : i + partial_size]
4565def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
4566 """A version of :func:`zip` that "broadcasts" any scalar
4567 (i.e., non-iterable) items into output tuples.
4569 >>> iterable_1 = [1, 2, 3]
4570 >>> iterable_2 = ['a', 'b', 'c']
4571 >>> scalar = '_'
4572 >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
4573 [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]
4575 The *scalar_types* keyword argument determines what types are considered
4576 scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
4577 treat strings and byte strings as iterable:
4579 >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
4580 [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]
4582 If the *strict* keyword argument is ``True``, then
4583 ``ValueError`` will be raised if any of the iterables have
4584 different lengths.
4585 """
4587 def is_scalar(obj):
4588 if scalar_types and isinstance(obj, scalar_types):
4589 return True
4590 try:
4591 iter(obj)
4592 except TypeError:
4593 return True
4594 else:
4595 return False
4597 size = len(objects)
4598 if not size:
4599 return
4601 new_item = [None] * size
4602 iterables, iterable_positions = [], []
4603 for i, obj in enumerate(objects):
4604 if is_scalar(obj):
4605 new_item[i] = obj
4606 else:
4607 iterables.append(iter(obj))
4608 iterable_positions.append(i)
4610 if not iterables:
4611 yield tuple(objects)
4612 return
4614 for item in zip(*iterables, strict=strict):
4615 for i, new_item[i] in zip(iterable_positions, item):
4616 pass
4617 yield tuple(new_item)
def unique_in_window(iterable, n, key=None):
    """Yield the items from *iterable* that haven't been seen recently.
    *n* is the size of the sliding window.

    >>> iterable = [0, 1, 0, 2, 3, 0]
    >>> n = 3
    >>> list(unique_in_window(iterable, n))
    [0, 1, 2, 3, 0]

    The *key* function, if provided, will be used to determine uniqueness:

    >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
    ['a', 'b', 'c', 'd', 'a']

    Updates a sliding window no larger than n and yields a value
    if the item only occurs once in the updated window.

    When `n == 1`, *unique_in_window* is memoryless:

    >>> list(unique_in_window('aab', n=1))
    ['a', 'a', 'b']

    The items in *iterable* must be hashable.
    """
    if n <= 0:
        raise ValueError('n must be greater than 0')

    # window holds the keys of the last n items, in arrival order;
    # counts tracks how many times each key is currently in the window.
    window = deque(maxlen=n)
    counts = Counter()

    for item in iterable:
        # Once the window is full, evict its oldest key.
        if len(window) == n:
            oldest = window[0]
            counts[oldest] -= 1
            if not counts[oldest]:
                del counts[oldest]

        k = item if key is None else key(item)
        if k not in counts:
            yield item
        counts[k] += 1
        window.append(k)
def duplicates_everseen(iterable, key=None):
    """Yield duplicate elements after their first appearance.

    >>> list(duplicates_everseen('mississippi'))
    ['s', 'i', 's', 's', 'i', 'p', 'i']
    >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.
    """
    # Hashable keys go in a set; unhashable keys fall back to a list with
    # linear-time membership tests.
    hashed = set()
    unhashed = []

    for element in iterable:
        k = element if key is None else key(element)
        try:
            first_time = k not in hashed
            if first_time:
                hashed.add(k)
        except TypeError:
            first_time = k not in unhashed
            if first_time:
                unhashed.append(k)
        if not first_time:
            yield element
def duplicates_justseen(iterable, key=None):
    """Yields serially-duplicate elements after their first appearance.

    >>> list(duplicates_justseen('mississippi'))
    ['s', 's', 'p']
    >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    This function is analogous to :func:`unique_justseen`.
    """
    # Group consecutive equal items; everything after the first item of
    # each run is a serial duplicate. groupby is invoked eagerly so that
    # errors surface at call time, while iteration itself stays lazy.
    grouped = groupby(iterable, key)

    def _tails():
        for _, run in grouped:
            next(run, None)  # discard the first occurrence of each run
            yield from run

    return _tails()
def classify_unique(iterable, key=None):
    """Classify each element in terms of its uniqueness.

    For each element in the input iterable, return a 3-tuple consisting of:

    1. The element itself
    2. ``False`` if the element is equal to the one preceding it in the
       input, ``True`` otherwise (i.e. the equivalent of
       :func:`unique_justseen`)
    3. ``False`` if this element has been seen anywhere in the input before,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)

    >>> list(classify_unique('otto'))  # doctest: +NORMALIZE_WHITESPACE
    [('o', True,  True),
     ('t', True,  True),
     ('t', False, False),
     ('o', True,  False)]

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.
    """
    # Hashable keys live in a set; unhashable keys use a slower list.
    hashable_seen = set()
    unhashable_seen = []
    previous = None

    for position, element in enumerate(iterable):
        k = element if key is None else key(element)

        # "Just seen": the very first item is always new; afterwards,
        # compare against the immediately preceding key.
        justseen = position == 0 or previous != k
        previous = k

        everseen = False
        try:
            if k not in hashable_seen:
                hashable_seen.add(k)
                everseen = True
        except TypeError:
            if k not in unhashable_seen:
                unhashable_seen.append(k)
                everseen = True

        yield element, justseen, everseen
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Returns both the smallest and largest items from an iterable
    or from two or more arguments.

    >>> minmax([3, 1, 5])
    (1, 5)

    >>> minmax(4, 2, 6)
    (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

    >>> minmax([5, 30], key=str)  # '30' sorts before '5'
    (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

    >>> minmax([], default=(0, 0))
    (0, 0)

    Otherwise ``ValueError`` is raised.

    This function makes a single pass over the input elements and takes care
    to minimize the number of comparisons made during processing.

    Note that unlike the builtin ``max`` function, which always returns the
    first item with the maximum value, this function may return another item
    when there are ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    # Support both minmax(iterable) and minmax(a, b, ...) call styles.
    iterable = (iterable_or_value, *others) if others else iterable_or_value

    it = iter(iterable)

    try:
        lo = hi = next(it)
    except StopIteration as exc:
        if default is _marker:
            raise ValueError(
                '`minmax()` argument is an empty iterable. '
                'Provide a `default` value to suppress this error.'
            ) from exc
        return default

    # Different branches depending on the presence of key. This saves a lot
    # of unimportant copies which would slow the "key=None" branch
    # significantly down.
    if key is None:
        # Process items pairwise: order the pair internally, then compare
        # only the smaller against lo and the larger against hi (3
        # comparisons per 2 items instead of 4). An odd-length input is
        # padded with the first item, which cannot change the result.
        for x, y in zip_longest(it, it, fillvalue=lo):
            if y < x:
                x, y = y, x
            if x < lo:
                lo = x
            if hi < y:
                hi = y

    else:
        lo_key = hi_key = key(lo)

        for x, y in zip_longest(it, it, fillvalue=lo):
            # Same pairwise strategy, but compare cached key values so
            # key() is evaluated only once per item in the pair.
            x_key, y_key = key(x), key(y)

            if y_key < x_key:
                x, y, x_key, y_key = y, x, y_key, x_key
            if x_key < lo_key:
                lo, lo_key = x, x_key
            if hi_key < y_key:
                hi, hi_key = y, y_key

    return lo, hi
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* with a combined size limited by
    *max_size*.

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10))
    [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If a *max_count* is supplied, the number of items per batch is also
    limited:

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10, max_count = 2))
    [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    If a *get_len* function is supplied, use that instead of :func:`len` to
    determine item size.

    If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
    than *max_size*. Otherwise, allow single items to exceed *max_size*.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    current = []
    current_size = 0

    for item in iterable:
        item_len = get_len(item)
        if strict and item_len > max_size:
            raise ValueError('item size exceeds maximum size')

        # Flush the open batch when adding this item would overflow it,
        # either by size or by count. Empty batches are never flushed.
        if current:
            too_many = len(current) == max_count
            too_big = current_size + item_len > max_size
            if too_many or too_big:
                yield tuple(current)
                current = []
                current_size = 0

        current.append(item)
        current_size += item_len

    # Flush whatever remains.
    if current:
        yield tuple(current)
def gray_product(*iterables, repeat=1):
    """Like :func:`itertools.product`, but return tuples in an order such
    that only one element in the generated tuple changes from one iteration
    to the next.

    >>> list(gray_product('AB','CD'))
    [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables. For example, ``gray_product('AB', repeat=3)`` is
    equivalent to ``gray_product('AB', 'AB', 'AB')``.

    This function consumes all of the input iterables before producing output.
    If any of the input iterables have fewer than two items, ``ValueError``
    is raised.

    For information on the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    all_iterables = tuple(map(tuple, iterables)) * repeat
    iterable_count = len(all_iterables)
    # The Gray-code stepping below requires at least two items per pool.
    for iterable in all_iterables:
        if len(iterable) < 2:
            raise ValueError("each iterable must have two or more items")

    # This is based on "Algorithm H" from section 7.2.1.1, page 20.
    # a holds the indexes of the source iterables for the n-tuple to be yielded
    # f is the array of "focus pointers"
    # o is the array of "directions" (+1 or -1 sweep direction per position)
    a = [0] * iterable_count
    f = list(range(iterable_count + 1))
    o = [1] * iterable_count
    while True:
        yield tuple(all_iterables[i][a[i]] for i in range(iterable_count))
        j = f[0]
        f[0] = 0
        if j == iterable_count:
            # The focus pointer ran off the end: all tuples were yielded.
            break
        a[j] = a[j] + o[j]
        if a[j] == 0 or a[j] == len(all_iterables[j]) - 1:
            # Position j hit either end of its pool: reverse its direction
            # and pass the focus to the next position.
            o[j] = -o[j]
            f[j] = f[j + 1]
            f[j + 1] = j + 1
def partial_product(*iterables, repeat=1):
    """Yields tuples containing one item from each iterator, with subsequent
    tuples changing a single item at a time by advancing each iterator until
    it is exhausted. This sequence guarantees every value in each iterable is
    output at least once without generating all possible combinations.

    This may be useful, for example, when testing an expensive function.

    >>> list(partial_product('AB', 'C', 'DEF'))
    [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]

    The *repeat* keyword argument specifies the number of repetitions
    of the iterables. For example, ``partial_product('AB', repeat=3)`` is
    equivalent to ``partial_product('AB', 'AB', 'AB')``.
    """
    pools = tuple(tuple(pool) for pool in iterables) * repeat
    streams = [iter(pool) for pool in pools]

    # Seed with one item from every stream; if any pool is empty there is
    # nothing to yield at all.
    current = []
    for stream in streams:
        try:
            current.append(next(stream))
        except StopIteration:
            return
    yield tuple(current)

    # Exhaust one stream at a time, left to right, mutating a single
    # position of the current tuple per step.
    for position, stream in enumerate(streams):
        for value in stream:
            current[position] = value
            yield tuple(current)
def takewhile_inclusive(predicate, iterable):
    """A variant of :func:`takewhile` that yields one additional element.

    >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
    [1, 4, 6]

    :func:`takewhile` would return ``[1, 4]``.
    """
    # Unlike takewhile, the item that fails the predicate is still yielded
    # (it is emitted before the predicate is checked).
    for item in iterable:
        yield item
        if not predicate(item):
            return
def outer_product(func, xs, ys, *args, **kwargs):
    """A generalized outer product that applies a binary function to all
    pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
    columns.

    Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.

    Multiplication table:

    >>> from operator import mul
    >>> list(outer_product(mul, range(1, 4), range(1, 6)))
    [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]

    Cross tabulation:

    >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
    >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
    >>> pair_counts = Counter(zip(xs, ys))
    >>> count_rows = lambda x, y: pair_counts[x, y]
    >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
    [(2, 3, 0), (1, 0, 4)]

    Usage with ``*args`` and ``**kwargs``:

    >>> animals = ['cat', 'wolf', 'mouse']
    >>> list(outer_product(min, animals, animals, key=len))
    [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
    """
    # Evaluate func over the row-major cartesian product of xs and ys,
    # then regroup the flat result stream into rows of len(ys) cells.
    ys = tuple(ys)
    cells = (func(x, y, *args, **kwargs) for x, y in product(xs, ys))
    return batched(cells, n=len(ys))
def iter_suppress(iterable, *exceptions):
    """Yield each of the items from *iterable*. If the iteration raises one of
    the specified *exceptions*, that exception will be suppressed and iteration
    will stop.

    >>> from itertools import chain
    >>> def breaks_at_five(x):
    ...     while True:
    ...         if x >= 5:
    ...             raise RuntimeError
    ...         yield x
    ...         x += 1
    >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
    >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
    >>> list(chain(it_1, it_2))
    [1, 2, 3, 4, 2, 3, 4]
    """
    # contextlib.suppress swallows exactly the listed exception types;
    # leaving the with-block simply ends the generator.
    with suppress(*exceptions):
        yield from iterable
def filter_map(func, iterable):
    """Apply *func* to every element of *iterable*, yielding only those which
    are not ``None``.

    >>> elems = ['1', 'a', '2', 'b', '3']
    >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
    [1, 2, 3]
    """
    # Note: only None is filtered out; falsy results such as 0 or '' pass.
    for result in map(func, iterable):
        if result is not None:
            yield result
def powerset_of_sets(iterable, *, baseset=set):
    """Yields all possible subsets of the iterable.

    >>> list(powerset_of_sets([1, 2, 3]))  # doctest: +SKIP
    [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
    >>> list(powerset_of_sets([1, 1, 0]))  # doctest: +SKIP
    [set(), {1}, {0}, {0, 1}]

    :func:`powerset_of_sets` takes care to minimize the number
    of hash operations performed.

    The *baseset* parameter determines what kind of sets are
    constructed, either *set* or *frozenset*.
    """
    # Wrap each input item in a singleton frozenset (hashing it exactly
    # once) and deduplicate while preserving the original order.
    singletons = tuple(dict.fromkeys(frozenset((item,)) for item in iterable))

    # Each subset is the union of some combination of singletons; iterating
    # r from 0 upward yields subsets in order of increasing size.
    empty_union = baseset().union
    return (
        empty_union(*chosen)
        for r in range(len(singletons) + 1)
        for chosen in combinations(singletons, r)
    )
def join_mappings(**field_to_map):
    """
    Joins multiple mappings together using their common keys.

    >>> user_scores = {'elliot': 50, 'claris': 60}
    >>> user_times = {'elliot': 30, 'claris': 40}
    >>> join_mappings(score=user_scores, time=user_times)
    {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    # Invert the nesting: keyword name becomes the inner field name, keyed
    # by each mapping's own keys.
    joined = {}
    for field, mapping in field_to_map.items():
        for key, value in mapping.items():
            joined.setdefault(key, {})[field] = value

    return joined
def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.
    """
    # Split the complex dot product into two real sumprods so _fsumprod can
    # carry extended precision through the accumulation:
    #   real part: sum(a.real * b.real) - sum(a.imag * b.imag)
    #   imag part: sum(a.real * b.imag) + sum(a.imag * b.real)
    # Each pair of sums is expressed as a single chained sumprod.
    real_left = chain((x.real for x in v1), (-x.imag for x in v1))
    real_right = chain((y.real for y in v2), (y.imag for y in v2))
    imag_left = chain((x.real for x in v1), (x.imag for x in v1))
    imag_right = chain((y.imag for y in v2), (y.real for y in v2))
    return complex(
        _fsumprod(real_left, real_right), _fsumprod(imag_left, imag_right)
    )
def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the corresponding transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]                # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]                # frequency domain
    >>> all(map(cmath.isclose, dft(xarr), Xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`idft` for the inverse Discrete Fourier Transform.
    """
    N = len(xarr)
    # All N distinct twiddle factors (negative direction for the forward
    # transform); exponents repeat modulo N, so they are precomputed once.
    twiddles = [e ** (n / N * tau * -1j) for n in range(N)]
    for k in range(N):
        row = [twiddles[k * n % N] for n in range(N)]
        yield _complex_sumprod(xarr, row)
def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the corresponding
    inverse-transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]                # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]                # frequency domain
    >>> all(map(cmath.isclose, idft(Xarr), xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`dft` for the Discrete Fourier Transform.
    """
    N = len(Xarr)
    # Same twiddle factors as dft() but in the positive direction; the
    # result of each dot product is normalized by N.
    twiddles = [e ** (n / N * tau * 1j) for n in range(N)]
    for k in range(N):
        row = [twiddles[k * n % N] for n in range(N)]
        yield _complex_sumprod(Xarr, row) / N
def doublestarmap(func, iterable):
    """Apply *func* to every item of *iterable* by dictionary unpacking
    the item into *func*.

    The difference between :func:`itertools.starmap` and :func:`doublestarmap`
    parallels the distinction between ``func(*a)`` and ``func(**a)``.

    >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
    >>> list(doublestarmap(lambda a, b: a + b, iterable))
    [3, 100]

    ``TypeError`` will be raised if *func*'s signature doesn't match the
    mapping contained in *iterable* or if *iterable* does not contain
    mappings.
    """
    # Each item is unpacked as keyword arguments; mismatched keys surface
    # as TypeError from the call itself.
    for mapping in iterable:
        yield func(**mapping)
5160def _nth_prime_bounds(n):
5161 """Bounds for the nth prime (counting from 1): lb < p_n < ub."""
5162 # At and above 688,383, the lb/ub spread is under 0.003 * p_n.
5164 if n < 1:
5165 raise ValueError
5167 if n < 6:
5168 return (n, 2.25 * n)
5170 # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
5171 upper_bound = n * log(n * log(n))
5172 lower_bound = upper_bound - n
5173 if n >= 688_383:
5174 upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))
5176 return lower_bound, upper_bound
def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

    >>> nth_prime(0)
    2
    >>> nth_prime(100)
    547

    If *approximate* is set to True, will return a prime close
    to the nth prime. The estimation is much faster than computing
    an exact result.

    >>> nth_prime(200_000_000, approximate=True)  # Exact result is 4222234763
    4217820427

    """
    lb, ub = _nth_prime_bounds(n + 1)

    # Exact answer (always used for small n): sieve up to the upper bound
    # and take the nth prime directly.
    if not approximate or n <= 1_000_000:
        return nth(sieve(ceil(ub)), n)

    # Approximation: start at an odd number near the midpoint of the
    # bounds and scan forward for the first prime.
    midpoint = floor((lb + ub) / 2) | 1
    return first_true(count(midpoint, step=2), pred=is_prime)
def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

    >>> argmin('efghabcdijkl')
    4
    >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
    3

    For example, look up a label corresponding to the position
    of a value that minimizes a cost function::

        >>> def cost(x):
        ...     "Days for a wound to heal given a subject's age."
        ...     return x**2 - 20*x + 150
        ...
        >>> labels = ['homer', 'marge', 'bart', 'lisa', 'maggie']
        >>> ages = [ 35, 30, 10, 9, 1 ]

        # Fastest healing family member
        >>> labels[argmin(ages, key=cost)]
        'bart'

        # Age with fastest healing
        >>> min(ages, key=cost)
        10

    """
    # Transform values first (if requested), then select by value while
    # carrying the index alongside. min() prefers the earliest minimum.
    values = iterable if key is None else map(key, iterable)
    best_index, _ = min(enumerate(values), key=itemgetter(1))
    return best_index
def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

    >>> argmax('abcdefghabcd')
    7
    >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
    3

    For example, identify the best machine learning model::

        >>> models = ['svm', 'random forest', 'knn', 'naïve bayes']
        >>> accuracy = [ 68, 61, 84, 72 ]

        # Most accurate model
        >>> models[argmax(accuracy)]
        'knn'

        # Best accuracy
        >>> max(accuracy)
        84

    """
    # Transform values first (if requested), then select by value while
    # carrying the index alongside. max() prefers the earliest maximum.
    values = iterable if key is None else map(key, iterable)
    best_index, _ = max(enumerate(values), key=itemgetter(1))
    return best_index
def _extract_monotonic(iterator, indices):
    'Non-decreasing indices, lazily consumed'
    # num_read counts how many items have been consumed from *iterator*,
    # i.e. one past the index of the most recently read item.
    num_read = 0
    for index in indices:
        advance = index - num_read
        try:
            # islice raises ValueError when advance is negative; that case
            # is classified below instead of being pre-checked.
            value = next(islice(iterator, advance, None))
        except ValueError:
            # advance == -1 with a non-negative index means the same index
            # was requested again: suppress the error and re-yield the
            # previous value. Any other negative advance (or a negative
            # index) violates the monotonic contract.
            if advance != -1 or index < 0:
                raise ValueError(f'Invalid index: {index}') from None
        except StopIteration:
            # The iterator ended before reaching *index*.
            raise IndexError(index) from None
        else:
            num_read += advance + 1
        yield value
5283def _extract_buffered(iterator, index_and_position):
5284 'Arbitrary index order, greedily consumed'
5285 buffer = {}
5286 iterator_position = -1
5287 next_to_emit = 0
5289 for index, order in index_and_position:
5290 advance = index - iterator_position
5291 if advance:
5292 try:
5293 value = next(islice(iterator, advance - 1, None))
5294 except StopIteration:
5295 raise IndexError(index) from None
5296 iterator_position = index
5298 buffer[order] = value
5300 while next_to_emit in buffer:
5301 yield buffer.pop(next_to_emit)
5302 next_to_emit += 1
def extract(iterable, indices, *, monotonic=False):
    """Yield values at the specified indices.

    Example:

    >>> data = 'abcdefghijklmnopqrstuvwxyz'
    >>> list(extract(data, [7, 4, 11, 11, 14]))
    ['h', 'e', 'l', 'l', 'o']

    The *iterable* is consumed lazily and can be infinite.

    When *monotonic* is false, the *indices* are consumed immediately
    and must be finite. When *monotonic* is true, *indices* are consumed
    lazily and can be infinite but must be non-decreasing.

    Raises ``IndexError`` if an index lies beyond the iterable.
    Raises ``ValueError`` for a negative index or for a decreasing
    index when *monotonic* is true.
    """
    iterator = iter(iterable)
    indices = iter(indices)

    # Lazy path: non-decreasing indices need no buffering at all.
    if monotonic:
        return _extract_monotonic(iterator, indices)

    # Greedy path: read all indices now, remember each one's output
    # position, and serve them in sorted index order with buffering.
    pairs = sorted(zip(indices, count()))
    if pairs and pairs[0][0] < 0:
        raise ValueError('Indices must be non-negative')
    return _extract_buffered(iterator, pairs)
class serialize:
    """Wrap a non-concurrent iterator with a lock to enforce sequential
    access.

    Applies a non-reentrant lock around calls to ``__next__``, allowing
    iterator and generator instances to be shared by multiple consumer
    threads.
    """

    __slots__ = ('iterator', 'lock')

    def __init__(self, iterable):
        self.iterator = iter(iterable)
        self.lock = Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # Explicit acquire/release so the underlying next() runs while the
        # lock is held, serializing concurrent consumers.
        lock = self.lock
        lock.acquire()
        try:
            return next(self.iterator)
        finally:
            lock.release()
def synchronized(func):
    """Wrap an iterator-returning callable to make its iterators thread-safe.

    Existing itertools and more-itertools can be wrapped so that their
    iterator instances are serialized.

    For example, ``itertools.count`` does not make thread-safe instances,
    but that is easily fixed with::

        atomic_counter = synchronized(itertools.count)

    Can also be used as a decorator for generator function definitions
    so that the generator instances are serialized::

        @synchronized
        def enumerate_and_timestamp(iterable):
            for count, value in enumerate(iterable):
                yield count, time_ns(), value

    """

    # Every iterator produced by *func* is wrapped in `serialize`, which
    # guards its __next__ with a lock.
    @wraps(func)
    def wrapper(*args, **kwargs):
        return serialize(func(*args, **kwargs))

    return wrapper
def concurrent_tee(iterable, n=2):
    """Variant of itertools.tee() but with guaranteed threading semantics.

    Takes a non-threadsafe iterator as an input and creates concurrent
    tee objects for other threads to have reliable independent copies of
    the data stream.

    The new iterators are only thread-safe if consumed within a single
    thread. To share just one of the new iterators across multiple threads,
    wrap it with :func:`serialize`.
    """
    if n < 0:
        raise ValueError
    if n == 0:
        return ()

    # The first tee wraps the source; the remaining copies share its
    # iterator, link chain, and lock.
    primary = _concurrent_tee(iterable)
    copies = [primary]
    copies.extend(_concurrent_tee(primary) for _ in range(n - 1))
    return tuple(copies)
class _concurrent_tee:
    # Thread-safe analogue of the linked-list cells behind itertools.tee:
    # all copies share one source iterator, one lock, and a chain of link
    # cells of the form [value, next_link]. Each instance's `link` points
    # at the cell holding the next value it will yield.
    __slots__ = ('iterator', 'link', 'lock')

    def __init__(self, iterable):
        if isinstance(iterable, _concurrent_tee):
            # Copying an existing tee: share its iterator, current link
            # cell (position), and lock.
            self.iterator = iterable.iterator
            self.link = iterable.link
            self.lock = iterable.lock
        else:
            self.iterator = iter(iterable)
            self.link = [None, None]
            self.lock = Lock()

    def __iter__(self):
        return self

    def __next__(self):
        link = self.link
        if link[1] is None:
            # This cell has not been filled yet. Double-checked locking:
            # only one thread advances the shared iterator; the re-check
            # inside the lock covers the race between the unlocked test
            # and lock acquisition.
            with self.lock:
                if link[1] is None:
                    link[0] = next(self.iterator)
                    link[1] = [None, None]
        # Read the value and step this instance to the next cell.
        value, self.link = link
        return value