1"""Imported from the recipes section of the itertools documentation.
2
3All functions taken from the recipes section of the itertools library docs
4[1]_.
5Some backward-compatible usability improvements have been made.
6
.. [1] https://docs.python.org/3/library/itertools.html#itertools-recipes
8
9"""
10
11import random
12
13from bisect import bisect_left, insort
14from collections import deque
15from contextlib import suppress
16from collections.abc import Sized
17from functools import lru_cache, partial, reduce
18from heapq import heappush, heappushpop
19from itertools import (
20 accumulate,
21 chain,
22 combinations,
23 compress,
24 count,
25 cycle,
26 groupby,
27 islice,
28 product,
29 repeat,
30 starmap,
31 takewhile,
32 tee,
33 zip_longest,
34)
35from math import prod, comb, isqrt, gcd
36from operator import mul, not_, itemgetter, getitem, index
37from random import randrange, sample, choice
38from sys import hexversion
39
# Public API of this module: the names exported by ``from ... import *``.
# Private helpers (leading underscore) and ``UnequalIterablesError`` are not
# listed here.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'pairwise',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'running_median',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]

# Unique sentinel object. It is used as a fill value where ``None`` could be
# a legitimate item (e.g. by ``_zip_equal_generator`` to detect an iterable
# that ran out early).
_marker = object()
94
95
# ``zip(..., strict=True)`` exists on Python 3.10+. Probe for the keyword once
# at import time; older interpreters reject it with TypeError, in which case
# the plain (non-checking) ``zip`` is used instead.
try:
    zip(strict=True)
    _zip_strict = partial(zip, strict=True)
except TypeError:
    _zip_strict = zip
103
104
# ``math.sumprod`` exists on Python 3.12+. On older versions fall back to the
# pure-Python ``dotproduct`` defined later in this module.
try:
    from math import sumprod as _sumprod
except ImportError:

    def _sumprod(x, y):
        # ``dotproduct`` is resolved at call time, so it may be defined
        # further down the module than this fallback.
        return dotproduct(x, y)
110
111
# ``heapq.heappush_max`` and ``heapq.heappushpop_max`` exist on Python 3.14+.
# Record whether they could be imported so callers can pick a code path.
try:
    from heapq import heappush_max, heappushpop_max
except ImportError:
    _max_heap_available = False
else:
    _max_heap_available = True
119
120
def take(n, iterable):
    """Return the first *n* items of *iterable* as a list.

    >>> take(3, range(10))
    [0, 1, 2]

    If there are fewer than *n* items in the iterable, all of them are
    returned.

    >>> take(10, range(3))
    [0, 1, 2]

    """
    prefix = islice(iterable, n)
    return list(prefix)
135
136
def tabulate(function, start=0):
    """Return an iterator over the results of ``function(start)``,
    ``function(start + 1)``, ``function(start + 2)``...

    *function* should be a function that accepts one integer argument.

    If *start* is not specified it defaults to 0. It will be incremented each
    time the iterator is advanced.

    >>> square = lambda x: x ** 2
    >>> iterator = tabulate(square, -3)
    >>> take(4, iterator)
    [9, 4, 1, 0]

    """
    # count(start) is endless, so the resulting iterator never stops.
    return map(function, count(start))
153
154
def tail(n, iterable):
    """Return an iterator over the last *n* items of *iterable*.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    # Unsized input: stream it through a bounded deque that keeps only the
    # final *n* items. A non-iterable makes deque raise TypeError, so no
    # explicit Iterable check is needed.
    if not isinstance(iterable, Sized):
        return iter(deque(iterable, maxlen=n))

    # Sized input: skip straight to the last *n* positions with islice.
    first_kept = max(0, len(iterable) - n)
    return islice(iterable, first_kept, None)
171
172
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps. If *n* is ``None``, consume it
    entirely.

    Efficiently exhausts an iterator without returning values. Defaults to
    consuming the whole iterator, but an optional second argument may be
    provided to limit consumption.

    >>> i = (x for x in range(10))
    >>> next(i)
    0
    >>> consume(i, 3)
    >>> next(i)
    4
    >>> consume(i)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    If the iterator has fewer items remaining than the provided limit, the
    whole iterator will be consumed.

    >>> i = (x for x in range(3))
    >>> consume(i, 5)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    """
    # Use functions that consume iterators at C speed.
    if n is None:
        # feed the entire iterator into a zero-length deque
        deque(iterator, maxlen=0)
    else:
        # advance to the empty slice starting at position n
        next(islice(iterator, n, n), None)
211
212
def nth(iterable, n, default=None):
    """Return the item at index *n* of *iterable*, or *default* if the
    iterable has too few items.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    from_n_onward = islice(iterable, n, None)
    return next(from_n_onward, default)
224
225
def all_equal(iterable, key=None):
    """
    Returns ``True`` if all the elements are equal to each other.

    >>> all_equal('aaaa')
    True
    >>> all_equal('aaab')
    False

    A function that accepts a single argument and returns a transformed version
    of each input item can be specified with *key*:

    >>> all_equal('AaaA', key=str.casefold)
    True
    >>> all_equal([1, 2, 3], key=lambda x: x < 10)
    True

    """
    # groupby collapses runs of equal keys. The input is "all equal" exactly
    # when there is at most one run, so look for a second group.
    groups = groupby(iterable, key)
    missing = object()
    if next(groups, missing) is missing:
        return True
    return next(groups, missing) is missing
250
251
def quantify(iterable, pred=bool):
    """Return how many times the predicate is true.

    *pred* is applied to each item and the results are summed, so it should
    return a bool (or 0/1). It defaults to :func:`bool`, which counts truthy
    items.

    >>> quantify([True, False, True])
    2

    """
    return sum(map(pred, iterable))
260
261
def pad_none(iterable):
    """Yield the elements of *iterable*, then ``None`` forever.

    >>> take(5, pad_none(range(3)))
    [0, 1, 2, None, None]

    Useful for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    endless_nones = repeat(None)
    return chain(iterable, endless_nones)


# Backward-compatible alias for ``pad_none``.
padnone = pad_none
277
278
def ncycles(iterable, n):
    """Return the elements of *iterable* repeated *n* times in sequence.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    # Materialize once so the input can be replayed even if it is an iterator.
    saved = tuple(iterable)
    return chain.from_iterable(repeat(saved, n))
287
288
def dotproduct(vec1, vec2):
    """Return the dot product of two iterables of numbers.

    >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
    33.5
    >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
    33.5

    In Python 3.12 and later, use ``math.sumprod()`` instead.
    """
    # Pair the vectors up (stopping at the shorter one) and multiply each pair.
    pairs = zip(vec1, vec2)
    return sum(starmap(mul, pairs))
300
301
def flatten(listOfLists):
    """Return an iterator that removes one level of nesting from
    *listOfLists*.

    >>> list(flatten([[0, 1], [2, 3]]))
    [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    # from_iterable pulls sub-iterables lazily, so the outer iterable may be
    # infinite.
    concatenate = chain.from_iterable
    return concatenate(listOfLists)
312
313
def repeatfunc(func, times=None, *args):
    """Call *func* with *args* repeatedly, returning an iterable over the
    results.

    If *times* is specified, the iterable will terminate after that many
    repetitions:

        >>> from operator import add
        >>> times = 4
        >>> args = 3, 5
        >>> list(repeatfunc(add, times, *args))
        [8, 8, 8, 8]

    If *times* is ``None`` the iterable will not terminate:

        >>> from random import randrange
        >>> times = None
        >>> args = 1, 11
        >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
        [2, 4, 8, 1, 8, 4]

    """
    # The same argument tuple is replayed, endlessly or *times* times.
    argument_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(func, argument_stream)
339
340
def _pairwise(iterable):
    """Returns an iterator of paired items, overlapping, from the original
    iterable.

    >>> take(4, pairwise(count()))
    [(0, 1), (1, 2), (2, 3), (3, 4)]

    On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.

    """
    # Two independent copies; the second is advanced by one so zip pairs each
    # item with its successor. Inputs with fewer than two items yield nothing.
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)
353
354
# Prefer the C implementation from ``itertools`` (Python 3.10+), falling back
# to the pure-Python ``_pairwise`` when it cannot be imported. The builtin is
# wrapped in a Python function so that ``_pairwise``'s docstring can be
# attached to the public name (presumably because a builtin's ``__doc__``
# cannot be reassigned).
try:
    from itertools import pairwise as itertools_pairwise
except ImportError:
    pairwise = _pairwise
else:

    def pairwise(iterable):
        return itertools_pairwise(iterable)

    pairwise.__doc__ = _pairwise.__doc__
365
366
class UnequalIterablesError(ValueError):
    """Raised when iterables that were expected to have equal lengths differ.

    *details*, when given, is a ``(first_size, index, size)`` triple that is
    interpolated into the message.
    """

    def __init__(self, details=None):
        base = 'Iterables have different lengths'
        if details is None:
            message = base
        else:
            suffix = ': index 0 has length {}; index {} has length {}'
            message = base + suffix.format(*details)
        super().__init__(message)
376
377
def _zip_equal_generator(iterables):
    # Zip the iterables, padding exhausted ones with the private ``_marker``
    # sentinel. Seeing the sentinel in a row means some iterable ended before
    # the others.
    for row in zip_longest(*iterables, fillvalue=_marker):
        if any(item is _marker for item in row):
            raise UnequalIterablesError()
        yield row
384
385
def _zip_equal(*iterables):
    """Zip *iterables*, raising ``UnequalIterablesError`` if their lengths
    differ.

    If every iterable supports ``len()``, the lengths are compared up front
    and the built-in ``zip`` is returned. Otherwise the check happens lazily
    while items are read. With no iterables at all, an empty iterator is
    returned, matching the built-in ``zip()``.
    """
    if not iterables:
        # Nothing to compare; indexing iterables[0] below would IndexError.
        return zip()

    # Check whether the iterables are all the same size.
    try:
        first_size = len(iterables[0])
        for i, it in enumerate(iterables[1:], 1):
            size = len(it)
            if size != first_size:
                raise UnequalIterablesError(details=(first_size, i, size))
        # All sizes are equal, we can use the built-in zip.
        return zip(*iterables)
    # If any one of the iterables didn't have a length, start reading
    # them until one runs out.
    except TypeError:
        return _zip_equal_generator(iterables)
400
401
def grouper(iterable, n, incomplete='fill', fillvalue=None):
    """Group elements from *iterable* into fixed-length groups of length *n*.

    >>> list(grouper('ABCDEF', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    The keyword arguments *incomplete* and *fillvalue* control what happens for
    iterables whose length is not a multiple of *n*.

    When *incomplete* is `'fill'`, the last group will contain instances of
    *fillvalue*.

    >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]

    When *incomplete* is `'ignore'`, the last group will not be emitted.

    >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    When *incomplete* is `'strict'`, a subclass of `ValueError` will be raised.

    >>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
    >>> list(iterator)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    UnequalIterablesError

    """
    # n references to ONE iterator: zipping them pulls n consecutive items
    # into each tuple.
    shared = [iter(iterable)] * n
    if incomplete == 'fill':
        return zip_longest(*shared, fillvalue=fillvalue)
    elif incomplete == 'strict':
        return _zip_equal(*shared)
    elif incomplete == 'ignore':
        return zip(*shared)
    raise ValueError('Expected fill, strict, or ignore')
440
441
def roundrobin(*iterables):
    """Visit input iterables in a cycle until each is exhausted.

    >>> list(roundrobin('ABC', 'D', 'EF'))
    ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`, but
    may perform better for some inputs (in particular when the number of
    iterables is small).

    """
    # Algorithm credited to George Sakkis
    iterators = map(iter, iterables)
    # Each pass cycles over the currently active iterators. When ``next`` hits
    # an exhausted one, its StopIteration ends ``map(next, ...)``. The cycle is
    # then positioned just after the exhausted iterator, so the next pass takes
    # the following ``num_active - 1`` iterators, i.e. the ones still alive.
    for num_active in range(len(iterables), 0, -1):
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)
458
459
def partition(pred, iterable):
    """
    Returns a 2-tuple of iterables derived from the input iterable.
    The first yields the items that have ``pred(item) == False``.
    The second yields the items that have ``pred(item) == True``.

        >>> is_odd = lambda x: x % 2 != 0
        >>> iterable = range(10)
        >>> even_items, odd_items = partition(is_odd, iterable)
        >>> list(even_items), list(odd_items)
        ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    If *pred* is None, :func:`bool` is used.

        >>> iterable = [0, 1, False, True, '', ' ']
        >>> false_items, true_items = partition(None, iterable)
        >>> list(false_items), list(true_items)
        ([0, False, ''], [1, True, ' '])

    """
    predicate = bool if pred is None else pred

    # Three copies of the input: one per output, plus one to feed predicate.
    # The predicate results are themselves split so each is computed once.
    for_false, for_true, for_pred = tee(iterable, 3)
    flags_a, flags_b = tee(map(predicate, for_pred))
    falses = compress(for_false, map(not_, flags_a))
    trues = compress(for_true, flags_b)
    return (falses, trues)
486
487
def powerset(iterable):
    """Yields all possible subsets of the iterable.

        >>> list(powerset([1, 2, 3]))
        [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    :func:`powerset` will operate on iterables that aren't :class:`set`
    instances, so repeated elements in the input will produce repeated elements
    in the output.

        >>> seq = [1, 1, 0]
        >>> list(powerset(seq))
        [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.
    """
    pool = list(iterable)
    # Subset sizes 0, 1, ..., len(pool), each expanded into its combinations.
    sizes = range(len(pool) + 1)
    return chain.from_iterable(map(combinations, repeat(pool), sizes))
507
508
def unique_everseen(iterable, key=None):
    """
    Yield unique elements, preserving order.

        >>> list(unique_everseen('AAAABBBCCDAABBB'))
        ['A', 'B', 'C', 'D']
        >>> list(unique_everseen('ABBCcAD', str.lower))
        ['A', 'B', 'C', 'D']

    Sequences with a mix of hashable and unhashable items can be used.
    The function will be slower (i.e., `O(n^2)`) for unhashable items.

    Remember that ``list`` objects are unhashable - you can use the *key*
    parameter to transform the list to a tuple (which is hashable) to
    avoid a slowdown.

        >>> iterable = ([1, 2], [2, 3], [1, 2])
        >>> list(unique_everseen(iterable))  # Slow
        [[1, 2], [2, 3]]
        >>> list(unique_everseen(iterable, key=tuple))  # Faster
        [[1, 2], [2, 3]]

    Similarly, you may want to convert unhashable ``set`` objects with
    ``key=frozenset``. For ``dict`` objects,
    ``key=lambda x: frozenset(x.items())`` can be used.

    """
    seenset = set()
    seenset_add = seenset.add
    seenlist = []
    seenlist_add = seenlist.append
    use_key = key is not None

    for element in iterable:
        k = key(element) if use_key else element
        # Only the bookkeeping sits inside the try. The yield is kept outside
        # so that a TypeError thrown into the generator by its consumer is not
        # mistaken for an unhashable key (which would re-yield the element).
        try:
            if k in seenset:
                continue
            seenset_add(k)
        except TypeError:
            # Unhashable key: fall back to a linear scan of a list.
            if k in seenlist:
                continue
            seenlist_add(k)
        yield element
552
553
def unique_justseen(iterable, key=None):
    """Yield elements in order, dropping consecutive duplicates.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    """
    if key is not None:
        # With a key, the group key differs from the element, so take the
        # first member of each run instead.
        return (next(run) for _, run in groupby(iterable, key))

    # Without a key, each group's key *is* the representative element.
    return map(itemgetter(0), groupby(iterable))
567
568
def unique(iterable, key=None, reverse=False):
    """Yields unique elements in sorted order.

    >>> list(unique([[1, 2], [3, 4], [1, 2]]))
    [[1, 2], [3, 4]]

    *key* and *reverse* are passed to :func:`sorted`.

    >>> list(unique('ABBcCAD', str.casefold))
    ['A', 'B', 'c', 'D']
    >>> list(unique('ABBcCAD', str.casefold, reverse=True))
    ['D', 'c', 'B', 'A']

    The elements in *iterable* need not be hashable, but they must be
    comparable for sorting to work.
    """
    # Sorting places items with equal keys next to each other, so dropping
    # consecutive duplicates leaves one representative per distinct key.
    ordered = sorted(iterable, key=key, reverse=reverse)
    return unique_justseen(ordered, key=key)
587
588
def iter_except(func, exception, first=None):
    """Yields results from a function repeatedly until an exception is raised.

    Converts a call-until-exception interface to an iterator interface.
    Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
    to end the loop.

        >>> l = [0, 1, 2]
        >>> list(iter_except(l.pop, IndexError))
        [2, 1, 0]

    Multiple exceptions can be specified as a stopping condition:

        >>> l = [1, 2, 3, '...', 4, 5, 6]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        [7, 6, 5]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        [4, 3, 2]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        []

    """
    # ``suppress`` wraps every call (and every yield), so the first call to
    # *first* or *func* that raises *exception* ends the generator cleanly.
    # Exceptions of other types propagate to the consumer.
    with suppress(exception):
        if first is not None:
            # Optional one-time call whose result is yielded before the
            # repeated calls to *func*.
            yield first()
        while True:
            yield func()
616
617
def first_true(iterable, default=None, pred=None):
    """
    Returns the first true value in the iterable.

    If no true value is found, returns *default*

    If *pred* is not None, returns the first item for which
    ``pred(item) == True`` .

        >>> first_true(range(10))
        1
        >>> first_true(range(10), pred=lambda x: x > 5)
        6
        >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
        'missing'

    """
    # filter(None, ...) keeps truthy items; otherwise it keeps items that
    # satisfy *pred*. Only the first match is ever computed.
    matches = filter(pred, iterable)
    return next(matches, default)
636
637
def random_product(*args, repeat=1):
    """Draw an item at random from each of the input iterables.

        >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
        ('c', 3, 'Z')

    If *repeat* is provided as a keyword argument, that many items will be
    drawn from each iterable.

        >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
        ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*args, repeat=repeat)``.

    """
    # Materialize each input once; list repetition then reuses those tuples
    # for the extra rounds.
    pools = [tuple(pool) for pool in args] * repeat
    return tuple(choice(pool) for pool in pools)
656
657
def random_permutation(iterable, r=None):
    """Return a random *r* length permutation of the elements in *iterable*.

    If *r* is not specified or is ``None``, then *r* defaults to the length of
    *iterable*.

        >>> random_permutation(range(5))  # doctest:+SKIP
        (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    r = len(pool) if r is None else r
    # random.sample raises ValueError if r exceeds the pool size.
    return tuple(sample(pool, r))
674
675
def random_combination(iterable, r):
    """Return a random *r* length subsequence of the elements in *iterable*.

        >>> random_combination(range(5), 3)  # doctest:+SKIP
        (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    n = len(pool)
    # Sorting the sampled positions preserves the input order, as
    # itertools.combinations does.
    indices = sorted(sample(range(n), r))
    return tuple(pool[i] for i in indices)
690
691
def random_combination_with_replacement(iterable, r):
    """Return a random *r* length subsequence of elements in *iterable*,
    allowing individual elements to be repeated.

        >>> random_combination_with_replacement(range(3), 5)  # doctest:+SKIP
        (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    n = len(pool)
    # Independent draws (repeats allowed), sorted to match the ordering of
    # itertools.combinations_with_replacement.
    indices = sorted(randrange(n) for _ in range(r))
    return tuple(pool[i] for i in indices)
707
708
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

        >>> nth_combination(range(5), 3, 5)
        (0, 3, 4)

    ``ValueError`` will be raised if *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if (r < 0) or (r > n):
        raise ValueError

    # Total number of r-length combinations of n items.
    c = comb(n, r)

    # Negative indexes count from the end, as with sequences.
    if index < 0:
        index += c

    if (index < 0) or (index >= c):
        raise IndexError

    # Choose one element per step. At each candidate position, c is the
    # number of combinations that start with it. Skip the candidate while
    # index lies beyond that block of combinations.
    result = []
    while r:
        c, n, r = c * r // n, n - 1, r - 1
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)
749
750
def prepend(value, iterator):
    """Yield *value*, followed by the elements in *iterator*.

        >>> value = '0'
        >>> iterator = ['1', '2', '3']
        >>> list(prepend(value, iterator))
        ['0', '1', '2', '3']

    To prepend multiple values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterator)
764
765
def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

        >>> list(convolve([1, -1, -20], [1, -3]))
        [1, -4, -17, 60]

    Examples of popular kinds of kernels:

    * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      For image data, this blurs the image and reduces noise.
    * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function evaluated at evenly spaced inputs.
    * The kernel ``[1, -2, 1]`` estimates the second derivative of a
      function evaluated at evenly spaced inputs.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently. The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article:  https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown:  https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # This implementation comes from an older version of the itertools
    # documentation.  While the newer implementation is a bit clearer,
    # this one was kept because the inlined window logic is faster
    # and it avoids an unnecessary deque-to-tuple conversion.
    flipped = tuple(kernel)[::-1]
    width = len(flipped)
    # The window starts full of zeros so the leading partial overlaps are
    # produced; the trailing zeros appended to the signal flush the tail.
    window = deque([0] * width, maxlen=width)
    for value in chain(signal, repeat(0, width - 1)):
        window.append(value)
        yield _sumprod(flipped, window)
807
808
def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that allows complete access to the
    remainder of the iterator.

         >>> it = iter('ABCdEfGhI')
         >>> all_upper, remainder = before_and_after(str.isupper, it)
         >>> ''.join(all_upper)
         'ABC'
         >>> ''.join(remainder) # takewhile() would lose the 'd'
         'dEfGhI'

    Note that the first iterator must be fully consumed before the second
    iterator can generate valid results.
    """
    trues, after = tee(it)
    # ``compress`` pulls a datum from ``takewhile`` before pulling a selector
    # from ``zip(after)`` (whose 1-tuples are always truthy). So ``after`` is
    # advanced exactly once per item that passed the predicate. The first
    # failing item is consumed by ``takewhile`` through ``trues`` but stays
    # buffered in the tee for ``after``.
    trues = compress(takewhile(predicate, trues), zip(after))
    return trues, after
826
827
def triplewise(iterable):
    """Return overlapping triplets from *iterable*.

    >>> list(triplewise('ABCDE'))
    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # This deviates from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Stagger the copies: the third starts two items ahead, the second one.
    for _ in range(2):
        next(third, None)
    next(second, None)
    return zip(first, second, third)
842
843
def _sliding_window_islice(iterable, n):
    # Fast path for small, non-zero values of n: make n copies of the input,
    # advance the k-th copy by k items, and zip the staggered copies.
    staggered = tee(iterable, n)
    for skip, branch in enumerate(staggered):
        next(islice(branch, skip, skip), None)
    return zip(*staggered)
850
851
def _sliding_window_deque(iterable, n):
    # Normal path for other values of n: keep a bounded deque pre-filled with
    # the first n - 1 items; each further item completes a full window.
    source = iter(iterable)
    buffer = deque(islice(source, n - 1), maxlen=n)
    push = buffer.append
    for item in source:
        push(item)
        yield tuple(buffer)
859
860
def sliding_window(iterable, n):
    """Return a sliding window of width *n* over *iterable*.

    >>> list(sliding_window(range(6), 4))
    [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]

    If *iterable* has fewer than *n* items, then nothing is yielded:

    >>> list(sliding_window(range(3), 4))
    []

    For a variant with more features, see :func:`windowed`.
    """
    # Dispatch on window width: trivial widths use builtins, small widths use
    # staggered tee copies, large widths use a bounded deque.
    if n == 1:
        return zip(iterable)
    if n == 2:
        return pairwise(iterable)
    if 2 < n <= 20:
        return _sliding_window_islice(iterable, n)
    if n > 20:
        return _sliding_window_deque(iterable, n)
    raise ValueError(f'n should be at least one, not {n}')
884
885
def subslices(iterable):
    """Return all contiguous non-empty subslices of *iterable*.

        >>> list(subslices('ABC'))
        [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This is similar to :func:`substrings`, but emits items in a different
    order.
    """
    seq = list(iterable)
    # Every pair start < stop of boundaries 0..len(seq) is one subslice.
    bounds = combinations(range(len(seq) + 1), 2)
    return (seq[start:stop] for start, stop in bounds)
898
899
def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

    >>> roots = [5, -4, 3]            # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
    [1, -4, -17, 60]

    Note that polynomial coefficients are specified in descending power order.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """

    # This recipe differs from the one in itertools docs in that it
    # applies list() after each call to convolve().  This avoids
    # hitting stack limits with nested generators.

    def multiply_by_linear(poly, root):
        # Multiply the running polynomial by the factor (x - root).
        return list(convolve(poly, (1, -root)))

    return reduce(multiply_by_linear, roots, [1])
920
921
def iter_index(iterable, value, start=0, stop=None):
    """Yield the index of each place in *iterable* that *value* occurs,
    beginning with index *start* and ending before index *stop*.

    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
    [1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
    [1, 4]

    The behavior for non-scalar *values* matches the built-in Python types.

    >>> list(iter_index('ABCDABCD', 'AB'))
    [0, 4]
    >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
    []
    >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
    [0, 2]

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    """
    seq_index = getattr(iterable, 'index', None)
    if seq_index is not None:
        # Fast path for sequences: let the object's own ``index`` method jump
        # to each successive match until it raises ValueError.
        if stop is None:
            stop = len(iterable)
        position = start - 1
        with suppress(ValueError):
            while True:
                position = seq_index(value, position + 1, stop)
                yield position
        return

    # Slow path for general iterables: compare every element in the window.
    window = islice(iterable, start, stop)
    for position, element in enumerate(window, start):
        if element is value or element == value:
            yield position
961
962
def sieve(n):
    """Yield the primes less than n.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    """
    # This implementation comes from an older version of the itertools
    # documentation.  The newer implementation is easier to read but is
    # less lazy.
    if n > 2:
        yield 2
    start = 3
    # data[i] == 1 marks i as a remaining prime candidate. The (0, 1) pattern
    # rules out every even index up front. Index 1 is never reported because
    # scanning starts at 3.
    data = bytearray((0, 1)) * (n // 2)
    # Only candidates p <= isqrt(n) need to strike out multiples. iter_index
    # reads the live bytearray, so earlier strikes are visible here.
    for p in iter_index(data, 1, start, stop=isqrt(n) + 1):
        # Surviving candidates below p * p are prime; emit them now (lazily).
        yield from iter_index(data, 1, start, p * p)
        # Clear the odd multiples of p from p * p upward (step 2p skips evens).
        data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
        start = p * p
    # Everything still marked from the last start position is prime.
    yield from iter_index(data, 1, start)
982
983
def _batched(iterable, n, *, strict=False):  # pragma: no cover
    """Batch data into tuples of length *n*. If the number of items in
    *iterable* is not divisible by *n*:
    * The last batch will be shorter if *strict* is ``False``.
    * :exc:`ValueError` will be raised if *strict* is ``True``.

    >>> list(batched('ABCDEFG', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

    On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        if not chunk:
            return
        # Only the final chunk can be short; strict mode rejects it.
        if strict and len(chunk) != n:
            raise ValueError('batched(): incomplete batch')
        yield chunk
1002
1003
# 0x30D00A2 is Python 3.13.0a2. On those versions the C ``itertools.batched``
# is used; its *strict* keyword is forwarded here. It is wrapped in a Python
# function so that ``_batched``'s docstring can be attached. Older versions
# use the pure-Python ``_batched`` defined above.
if hexversion >= 0x30D00A2:  # pragma: no cover
    from itertools import batched as itertools_batched

    def batched(iterable, n, *, strict=False):
        return itertools_batched(iterable, n, strict=strict)

    batched.__doc__ = _batched.__doc__
else:
    batched = _batched
1013
1014
def transpose(it):
    """Swap the rows and columns of the input matrix.

    >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
    [(1, 11), (2, 22), (3, 33)]

    The caller should ensure that the dimensions of the input are compatible.
    If the input is empty, no output will be produced.
    """
    # Rows are gathered eagerly and zipped column-wise. On Python 3.10+,
    # _zip_strict rejects rows of differing lengths.
    rows = tuple(it)
    return _zip_strict(*rows)
1025
1026
def _is_scalar(value, stringlike=(str, bytes)):
    "Return True for non-iterables and for instances of *stringlike*."
    try:
        iter(value)
        iterable = True
    except TypeError:
        iterable = False
    return not iterable or isinstance(value, stringlike)
1034
1035
def _flatten_tensor(tensor):
    "Depth-first iterator over scalars in a tensor."
    # Descend one nesting level per loop iteration until the values reached
    # are scalars. Each level is assumed uniform (all arrays or all scalars),
    # so only the first value of a level is inspected.
    iterator = iter(tensor)
    while True:
        # Peek at the first value at the current depth.
        try:
            value = next(iterator)
        except StopIteration:
            # Nothing at this depth: return the (exhausted) iterator.
            return iterator
        # Put the peeked value back in front of the stream.
        iterator = chain((value,), iterator)
        if _is_scalar(value):
            return iterator
        # Values are sub-arrays: concatenate them to go one level deeper.
        iterator = chain.from_iterable(iterator)
1048
1049
def reshape(matrix, shape):
    """Change the shape of a *matrix*.

    If *shape* is an integer, the matrix must be two dimensional
    and the shape is interpreted as the desired number of columns:

        >>> matrix = [(0, 1), (2, 3), (4, 5)]
        >>> cols = 3
        >>> list(reshape(matrix, cols))
        [(0, 1, 2), (3, 4, 5)]

    If *shape* is a tuple (or other iterable), the input matrix can have
    any number of dimensions. It will first be flattened and then rebuilt
    to the desired shape which can also be multidimensional:

        >>> matrix = [(0, 1), (2, 3), (4, 5)]    # Start with a 3 x 2 matrix

        >>> list(reshape(matrix, (2, 3)))        # Make a 2 x 3 matrix
        [(0, 1, 2), (3, 4, 5)]

        >>> list(reshape(matrix, (6,)))          # Make a vector of length six
        [0, 1, 2, 3, 4, 5]

        >>> list(reshape(matrix, (2, 1, 3, 1)))  # Make 2 x 1 x 3 x 1 tensor
        [(((0,), (1,), (2,)),), (((3,), (4,), (5,)),)]

    Each dimension is assumed to be uniform, either all arrays or all scalars.
    Flattening stops when the first value in a dimension is a scalar.
    Scalars are bytes, strings, and non-iterables.
    The reshape iterator stops when the requested shape is complete
    or when the input is exhausted, whichever comes first.

    """
    if isinstance(shape, int):
        # Integer shape: 2-D input regrouped into rows of *shape* columns.
        return batched(chain.from_iterable(matrix), shape)

    first_dim, *inner_dims = shape
    stream = _flatten_tensor(matrix)
    # Rebuild from the innermost dimension outward.
    for dim in reversed(inner_dims):
        stream = batched(stream, dim)
    return islice(stream, first_dim)
1088
1089
def matmul(m1, m2):
    """Multiply two matrices.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller should ensure that the dimensions of the input matrices are
    compatible with each other.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    width = len(m2[0])
    columns = transpose(m2)
    # Every (row, column) pair yields one cell; regroup cells into rows.
    cells = starmap(_sumprod, product(m1, columns))
    return batched(cells, width)
1103
1104
def _factor_pollard(n):
    # Return a factor of n using Pollard's rho algorithm.
    # Efficient when n is odd and composite.
    # Each b selects a pseudo-random map x -> (x * x + b) % n to walk.
    for b in range(1, n):
        x = y = 2
        d = 1
        # x advances one step and y two steps per round (Floyd's cycle
        # detection) until x - y shares a nontrivial factor with n.
        while d == 1:
            x = (x * x + b) % n
            y = (y * y + b) % n
            y = (y * y + b) % n
            d = gcd(x - y, n)
        # d == n means this walk failed to split n; retry with the next b.
        if d != n:
            return d
    raise ValueError('prime or under 5')  # pragma: no cover
1119
1120
# Primes below 211, computed once at import time for trial division in factor().
_primes_below_211 = tuple(sieve(211))
1122
1123
def factor(n):
    """Yield the prime factors of n.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]

    Finds small factors with trial division.  Larger factors are
    either verified as prime with ``is_prime`` or split into
    smaller factors with Pollard's rho algorithm.
    """

    # Corner case reduction
    if n < 2:
        return

    # Trial division reduction: divide out every prime below 211.
    for prime in _primes_below_211:
        while not n % prime:
            yield prime
            n //= prime

    # Pollard's rho reduction
    # Every remaining value has no prime factor below 211. So any value below
    # 211**2 must itself be prime; larger values are checked with is_prime()
    # or split. Both halves of a split are appended to ``todo`` while it is
    # being iterated, so the same loop processes them.
    primes = []
    todo = [n] if n > 1 else []
    for n in todo:
        if n < 211**2 or is_prime(n):
            primes.append(n)
        else:
            fact = _factor_pollard(n)
            todo += (fact, n // fact)
    # Splitting discovers factors out of order, so sort before yielding.
    yield from sorted(primes)
1155
1156
def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a specific value.

    Computes with better numeric stability than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125

    Note that polynomial coefficients are specified in descending power order.
    An empty coefficient sequence evaluates to zero of the same type as *x*.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    term_count = len(coefficients)
    if not term_count:
        return type(x)(0)
    # Exponents count down to 0 so they line up with the descending-order
    # coefficients.
    exponents = range(term_count - 1, -1, -1)
    powers = map(pow, repeat(x), exponents)
    return _sumprod(coefficients, powers)
1178
1179
def sum_of_squares(it):
    """Return the sum of the squares of the input values.

    >>> sum_of_squares([10, 20, 30])
    1400

    The input is read only once. ``tee`` splits it into two independent
    iterators so that every value can be multiplied by itself.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    left, right = tee(it)
    return _sumprod(left, right)
1189
1190
def polynomial_derivative(coefficients):
    """Compute the first derivative of a polynomial.

    Evaluate the derivative of ``x³ - 4 x² - 17 x + 60``:

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]

    Note that polynomial coefficients are specified in descending power order.
    The result is a new list one element shorter than the input (empty for
    a constant or empty polynomial).

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    highest_power = len(coefficients) - 1
    # Pair each coefficient with its exponent, highest first. The constant
    # term has no exponent partner, so zip drops it.
    exponents = range(highest_power, 0, -1)
    return [coef * power for coef, power in zip(coefficients, exponents)]
1208
1209
def totient(n):
    """Return the count of natural numbers up to *n* that are coprime with *n*.

    Euler's totient function φ(n) gives the number of totatives.
    Totatives are integers k in the range 1 ≤ k ≤ n such that gcd(n, k) = 1.

    >>> n = 9
    >>> totient(n)
    6

    >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
    >>> totatives
    [1, 2, 4, 5, 7, 8]
    >>> len(totatives)
    6

    The value is computed from the distinct prime factors p of *n* using
    the product formula φ(n) = n · ∏(1 - 1/p), kept in integer arithmetic.

    Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    result = n
    for p in set(factor(n)):
        # Multiply by (1 - 1/p) without leaving the integers.
        result -= result // p
    return result
1232
1233
# Miller–Rabin primality test: https://oeis.org/A014233
#
# Deterministic base sets consumed by is_prime(). Each entry is
# (limit, bases). is_prime() uses the first entry whose limit exceeds n and
# declares n prime exactly when n is a strong probable prime to every base
# in that tuple. Entries must therefore stay sorted by increasing limit.
# Inputs at or above the last limit fall back to random bases.
#
# Rows whose bases are the first k primes correspond to terms of OEIS
# A014233. The other base sets appear to come from separately published
# deterministic-base results (TODO: cite the exact sources).
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    # limit below is 2**64
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        3317044064679887385961981,
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]
1248
1249
@lru_cache
def _shift_to_odd(n):
    """Return s, d such that 2**s * d == n and d is odd.

    Results are memoized because _strong_probable_prime() calls this with
    n - 1 once per base, and is_prime() may test several bases for one n.
    """
    # (n - 1) ^ n has ones exactly at n's trailing-zero positions plus its
    # lowest set bit, so its bit length minus one counts the trailing zeros.
    trailing_zeros = ((n - 1) ^ n).bit_length() - 1
    odd_part = n >> trailing_zeros
    assert (
        (1 << trailing_zeros) * odd_part == n
        and odd_part & 1
        and trailing_zeros >= 0
    )
    return trailing_zeros, odd_part
1257
1258
def _strong_probable_prime(n, base):
    """Return True if odd n > 2 is a strong probable prime to *base*.

    Writing n - 1 as 2**s * d with d odd, n passes when base**d is 1 or
    n - 1 (mod n), or when one of the repeated squarings
    base**(d * 2**r) for 0 < r < s equals n - 1 (mod n).
    """
    assert (n > 2) and (n & 1) and (2 <= base < n)

    n_minus_1 = n - 1
    s, d = _shift_to_odd(n_minus_1)

    x = pow(base, d, n)
    if x in (1, n_minus_1):
        return True

    for _ in range(1, s):
        x = pow(x, 2, n)
        if x == n_minus_1:
            return True

    return False
1274
1275
# Separate instance of Random() that doesn't share state
# with the default user instance of Random().
# is_prime() draws its random Miller–Rabin bases from this bound method, so
# testing large numbers neither consumes nor depends on the global
# generator's state. In particular, random.seed() does not make those draws
# reproducible, because an unseeded Random() seeds itself from the system.
_private_randrange = random.Random().randrange
1279
1280
def is_prime(n):
    """Return ``True`` if *n* is prime and ``False`` otherwise.

    Basic examples:

    >>> is_prime(37)
    True
    >>> is_prime(3 * 13)
    False
    >>> is_prime(18_446_744_073_709_551_557)
    True

    Find the next prime over one billion:

    >>> next(filter(is_prime, count(10**9)))
    1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

    >>> from random import seed, randrange, getrandbits
    >>> seed(18675309)

    >>> next(filter(is_prime, map(getrandbits, repeat(200))))
    893303929355758292373272075469392561129886005037663238028407

    >>> next(filter(is_prime, map(randrange, repeat(10**60))))
    269638077304026462407872868003560484232362454342414618963649

    This function is exact for values of *n* below 10**24. For larger inputs,
    the probabilistic Miller-Rabin primality test has a less than 1 in 2**128
    chance of a false positive.
    """

    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Cheap rejection of even numbers and multiples of 3, 5, 7, 11 and 13.
    if not n & 1 or any(not n % p for p in (3, 5, 7, 11, 13)):
        return False

    # Use the first deterministic base set whose limit exceeds n. Past the
    # end of the table, fall back to 64 random bases; each base passes a
    # composite with probability at most 1/4, giving the 2**-128 bound.
    bases = next((b for limit, b in _perfect_tests if n < limit), None)
    if bases is None:
        bases = (_private_randrange(2, n - 1) for _ in range(64))

    return all(_strong_probable_prime(n, base) for base in bases)
1327
1328
def loops(n):
    """Return an iterable of *n* items for cheap fixed-count looping.

    It behaves like ``range(n)`` in a ``for`` statement, but it yields
    ``None`` every time instead of constructing integers.

    >>> i = 0
    >>> for _ in loops(5):
    ...     i += 1
    >>> i
    5

    """
    return repeat(None, times=n)
1341
1342
def multinomial(*counts):
    """Number of distinct arrangements of a multiset.

    The expression ``multinomial(3, 4, 2)`` has several equivalent
    interpretations:

    * In the expansion of ``(a + b + c)⁹``, the coefficient of the
      ``a³b⁴c²`` term is 1260.

    * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4
      greens, and 2 blues.

    * There are 1260 unique ways to place 9 distinct objects into three bins
      with sizes 3, 4, and 2.

    The :func:`multinomial` function computes the length of
    :func:`distinct_permutations`.  For example, there are 83,160 distinct
    anagrams of the word "abracadabra":

        >>> from more_itertools import distinct_permutations, ilen
        >>> ilen(distinct_permutations('abracadabra'))
        83160

    This can be computed directly from the letter counts, 5a 2b 2r 1c 1d:

        >>> from collections import Counter
        >>> list(Counter('abracadabra').values())
        [5, 2, 2, 1, 1]
        >>> multinomial(5, 2, 2, 1, 1)
        83160

    A binomial coefficient is a special case of multinomial where there are
    only two categories.  For example, the number of ways to arrange 12 balls
    with 5 reds and 7 blues is ``multinomial(5, 7)`` or ``math.comb(12, 5)``.

    Likewise, factorial is a special case of multinomial where
    the multiplicities are all just 1 so that
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    The result is built as a running product of binomial coefficients:
    after adding each category of size k to a running total t, multiply
    by comb(t, k). With no counts the result is 1.

    Reference:  https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    result = 1
    running_total = 0
    for k in counts:
        running_total += k
        # Choose which of the running_total positions this category fills.
        result *= comb(running_total, k)
    return result
1386
1387
def _running_median_minheap_and_maxheap(iterator):  # pragma: no cover
    """Non-windowed running_median() for Python 3.14+ (native max-heaps).

    `lower` is a max-heap holding the smaller half of the values seen so
    far, and `upper` is a min-heap holding the larger half. After each step
    len(lower) equals len(upper) or len(upper) + 1. Iteration ends silently
    when *iterator* is exhausted.
    """
    pull = iterator.__next__
    lower = []
    upper = []

    try:
        while True:
            # Odd count: pass the new value through `upper` so the smallest
            # large value drops into `lower`, whose top is now the median.
            heappush_max(lower, heappushpop(upper, pull()))
            yield lower[0]

            # Even count: pass the new value through `lower` so the largest
            # small value rises into `upper`; the median averages both tops.
            heappush(upper, heappushpop_max(lower, pull()))
            yield (lower[0] + upper[0]) / 2
    except StopIteration:
        return
1402
1403
def _running_median_minheap_only(iterator):  # pragma: no cover
    """Backport of non-windowed running_median() for Python 3.13 and prior.

    Only min-heap primitives are used. `lower` simulates a max-heap by
    storing negated values of the smaller half, and `upper` is a min-heap
    of the larger half. After each step len(lower) equals len(upper) or
    len(upper) + 1. Iteration ends silently when *iterator* is exhausted.
    """
    pull = iterator.__next__
    lower = []
    upper = []

    try:
        while True:
            # Odd count: the smallest large value moves into `lower` (negated);
            # its top, un-negated, is the median.
            heappush(lower, -heappushpop(upper, pull()))
            yield -lower[0]

            # Even count: the largest small value moves up into `upper`;
            # the median is the mean of max(lower) == -lower[0] and upper[0].
            heappush(upper, -heappushpop(lower, -pull()))
            yield (upper[0] - lower[0]) / 2
    except StopIteration:
        return
1418
1419
def _running_median_windowed(iterator, maxlen):
    """Yield the median of the most recent *maxlen* values after each input.

    `recent` keeps the window in arrival order so the oldest value can be
    evicted. `sorted_vals` holds the same values in sorted order so the
    median can be read by index.
    """
    recent = deque()
    sorted_vals = []

    for value in iterator:
        recent.append(value)
        insort(sorted_vals, value)

        if len(sorted_vals) > maxlen:
            oldest = recent.popleft()
            del sorted_vals[bisect_left(sorted_vals, oldest)]

        half, is_odd = divmod(len(sorted_vals), 2)
        if is_odd:
            yield sorted_vals[half]
        else:
            yield (sorted_vals[half - 1] + sorted_vals[half]) / 2
1437
1438
def running_median(iterable, *, maxlen=None):
    """Cumulative median of values seen so far or values in a sliding window.

    Set *maxlen* to a positive integer to specify the maximum size
    of the sliding window. The default of *None* is equivalent to
    an unbounded window.

    For example:

        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0]))
        [5.0, 7.0, 5.0, 7.0, 8.0, 8.5]
        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0], maxlen=3))
        [5.0, 7.0, 5.0, 9.0, 8.0, 9.0]

    Supports numeric types such as int, float, Decimal, and Fraction,
    but not complex numbers which are unorderable.

    *maxlen* is validated when this function is called, not when iteration
    starts. A non-integer raises ``TypeError`` and a value of zero or less
    raises ``ValueError``.

    On Python 3.13 and prior, max-heaps are simulated with negated values.
    The negation causes Decimal inputs to apply context rounding, making
    the results slightly different from those obtained by
    statistics.median().
    """

    iterator = iter(iterable)

    if maxlen is None:
        # Unbounded window: use native max-heaps when heapq provides them.
        if _max_heap_available:
            return _running_median_minheap_and_maxheap(iterator)
        return _running_median_minheap_only(iterator)

    window_size = index(maxlen)
    if window_size <= 0:
        raise ValueError('Window size should be positive')
    return _running_median_windowed(iterator, window_size)