1"""Imported from the recipes section of the itertools documentation.
2
All functions are taken from the recipes section of the itertools library
docs [1]_.
Some backward-compatible usability improvements have been made.
6
.. [1] https://docs.python.org/3/library/itertools.html#itertools-recipes
8
9"""
10
11import random
12
13from bisect import bisect_left, insort
14from collections import deque
15from contextlib import suppress
16from functools import lru_cache, partial, reduce
17from heapq import heappush, heappushpop
18from itertools import (
19 accumulate,
20 chain,
21 combinations,
22 compress,
23 count,
24 cycle,
25 groupby,
26 islice,
27 product,
28 repeat,
29 starmap,
30 takewhile,
31 tee,
32 zip_longest,
33)
34from math import prod, comb, isqrt, gcd
35from operator import mul, not_, itemgetter, getitem, index
36from random import randrange, sample, choice
37from sys import hexversion
38
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'pairwise',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'running_median',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]
91
# Private sentinel. _zip_equal_generator passes it as the zip_longest fill
# value, so seeing it in a row means one of the inputs ran out early.
_marker = object()
93
94
# zip with strict is available for Python 3.10+.
# Probe for support by calling zip() with the keyword. Older interpreters
# raise TypeError, and then _zip_strict falls back to a plain zip that does
# not check lengths.
try:
    zip(strict=True)
except TypeError:  # pragma: no cover
    _zip_strict = zip
else:  # pragma: no cover
    _zip_strict = partial(zip, strict=True)
102
103
# math.sumprod is available for Python 3.12+. On older versions, define an
# equivalent on top of dotproduct(). The name is resolved at call time,
# so the later definition of dotproduct is fine.
try:
    from math import sumprod as _sumprod
except ImportError:  # pragma: no cover

    def _sumprod(x, y):
        """Fallback for ``math.sumprod`` on Python < 3.12."""
        return dotproduct(x, y)
109
110
# heapq max-heap functions are available for Python 3.14+.
# _max_heap_available records whether heappush_max / heappushpop_max could be
# imported. NOTE(review): presumably consulted by code later in the module
# (not visible here) to choose between these and a fallback.
try:
    from heapq import heappush_max, heappushpop_max
except ImportError:  # pragma: no cover
    _max_heap_available = False
else:  # pragma: no cover
    _max_heap_available = True
118
119
def take(n, iterable):
    """Return a list holding the first *n* items of *iterable*.

    >>> take(3, range(10))
    [0, 1, 2]

    When *iterable* is shorter than *n*, every available item is returned.

    >>> take(10, range(3))
    [0, 1, 2]

    """
    head = islice(iterable, n)
    return list(head)
134
135
def tabulate(function, start=0):
    """Return an iterator over ``function(start)``, ``function(start + 1)``,
    ``function(start + 2)``, and so on.

    *function* should accept a single integer argument. *start* defaults to
    0 and increases by one each time the iterator is advanced.

    >>> square = lambda x: x ** 2
    >>> iterator = tabulate(square, -3)
    >>> take(4, iterator)
    [9, 4, 1, 0]

    """
    arguments = count(start)
    return map(function, arguments)
152
153
def tail(n, iterable):
    """Return an iterator over the final *n* items of *iterable*.

    Sized inputs are skipped ahead directly. Unsized inputs are streamed
    through a bounded deque that retains only the last *n* items.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    try:
        size = len(iterable)
    except TypeError:
        # No length available: keep just the trailing n items while reading.
        last_items = deque(iterable, maxlen=n)
        return iter(last_items)
    skip = max(0, size - n)
    return islice(iterable, skip, None)
168
169
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps, or exhaust it if *n* is ``None``.

    Values are discarded without being returned. All of the work happens
    inside C-implemented helpers, so this is fast.

    >>> i = (x for x in range(10))
    >>> next(i)
    0
    >>> consume(i, 3)
    >>> next(i)
    4
    >>> consume(i)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    If fewer than *n* items remain, the iterator is simply exhausted.

    >>> i = (x for x in range(3))
    >>> consume(i, 5)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    """
    if n is not None:
        # The slice [n:n] is empty, but reaching it skips n items.
        next(islice(iterator, n, n), None)
        return
    # A zero-length deque discards everything it is fed.
    deque(iterator, maxlen=0)
208
209
def nth(iterable, n, default=None):
    """Return the item at position *n* of *iterable*, or *default* if the
    iterable is too short.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    remainder = islice(iterable, n, None)
    return next(remainder, default)
221
222
def all_equal(iterable, key=None):
    """
    Returns ``True`` if all the elements are equal to each other.

    >>> all_equal('aaaa')
    True
    >>> all_equal('aaab')
    False

    A function that accepts a single argument and returns a transformed version
    of each input item can be specified with *key*:

    >>> all_equal('AaaA', key=str.casefold)
    True
    >>> all_equal([1, 2, 3], key=lambda x: x < 10)
    True

    """
    # groupby collapses runs of equal keys. Everything is equal exactly when
    # there is at most one run. Groups are (key, group) tuples, never None.
    runs = groupby(iterable, key)
    if next(runs, None) is None:
        return True
    return next(runs, None) is None
247
248
def quantify(iterable, pred=bool):
    """Return how many items of *iterable* satisfy *pred*.

    The results of *pred* are summed, so a predicate returning booleans
    yields a count of true results.

    >>> quantify([True, False, True])
    2

    """
    outcomes = map(pred, iterable)
    return sum(outcomes)
257
258
def pad_none(iterable):
    """Yield the items of *iterable*, then ``None`` forever.

    >>> take(5, pad_none(range(3)))
    [0, 1, 2, None, None]

    Useful for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    endless_nones = repeat(None)
    return chain(iterable, endless_nones)
271
272
# Alias of pad_none under its older spelling (presumably kept for backward
# compatibility; both names are listed in __all__).
padnone = pad_none
274
275
def ncycles(iterable, n):
    """Yield the elements of *iterable* over and over, *n* times in total.

    The input is materialized once, so one-shot iterators work too.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    saved = tuple(iterable)
    copies = repeat(saved, n)
    return chain.from_iterable(copies)
284
285
def dotproduct(vec1, vec2):
    """Return the sum of pairwise products of *vec1* and *vec2*.

    Pairing stops at the end of the shorter input.

    >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
    33.5
    >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
    33.5

    In Python 3.12 and later, use ``math.sumprod()`` instead.
    """
    products = map(mul, vec1, vec2)
    return sum(products)
297
298
def flatten(listOfLists):
    """Lazily remove exactly one level of nesting from *listOfLists*.

    >>> list(flatten([[0, 1], [2, 3]]))
    [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    inner_items = chain.from_iterable(listOfLists)
    return inner_items
309
310
def repeatfunc(func, times=None, *args):
    """Return an iterator of repeated ``func(*args)`` results.

    If *times* is given, the iterator stops after that many calls:

    >>> from operator import add
    >>> times = 4
    >>> args = 3, 5
    >>> list(repeatfunc(add, times, *args))
    [8, 8, 8, 8]

    If *times* is ``None`` the iterator never terminates:

    >>> from random import randrange
    >>> times = None
    >>> args = 1, 11
    >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
    [2, 4, 8, 1, 8, 4]

    """
    argument_tuples = repeat(args) if times is None else repeat(args, times)
    return starmap(func, argument_tuples)
336
337
338def _pairwise(iterable):
339 """Returns an iterator of paired items, overlapping, from the original
340
341 >>> take(4, pairwise(count()))
342 [(0, 1), (1, 2), (2, 3), (3, 4)]
343
344 On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
345
346 """
347 a, b = tee(iterable)
348 next(b, None)
349 return zip(a, b)
350
351
# Prefer the C implementation itertools.pairwise (Python 3.10+). Otherwise
# fall back to the pure-Python tee-based _pairwise defined above.
try:
    from itertools import pairwise as itertools_pairwise
except ImportError:  # pragma: no cover
    pairwise = _pairwise
else:  # pragma: no cover
    # Wrapping the builtin in a Python function lets _pairwise's docstring
    # (with its doctest) be attached to the public name.
    def pairwise(iterable):
        return itertools_pairwise(iterable)

    pairwise.__doc__ = _pairwise.__doc__
362
363
class UnequalIterablesError(ValueError):
    """Raised when iterables that were expected to match in length do not.

    *details*, if given, is a 3-item sequence formatted into the message as
    the length of index 0, the offending index, and that index's length.
    """

    def __init__(self, details=None):
        message = 'Iterables have different lengths'
        if details is not None:
            suffix = ': index 0 has length {}; index {} has length {}'
            message = message + suffix.format(*details)
        super().__init__(message)
373
374
def _zip_equal_generator(iterables):
    # Lazy length check for inputs without len(). Pad short inputs with the
    # private _marker sentinel and fail on the first row that contains it.
    for row in zip_longest(*iterables, fillvalue=_marker):
        if any(item is _marker for item in row):
            raise UnequalIterablesError()
        yield row
381
382
383def _zip_equal(*iterables):
384 # Check whether the iterables are all the same size.
385 try:
386 first_size = len(iterables[0])
387 for i, it in enumerate(iterables[1:], 1):
388 size = len(it)
389 if size != first_size:
390 raise UnequalIterablesError(details=(first_size, i, size))
391 # All sizes are equal, we can use the built-in zip.
392 return zip(*iterables)
393 # If any one of the iterables didn't have a length, start reading
394 # them until one runs out.
395 except TypeError:
396 return _zip_equal_generator(iterables)
397
398
def grouper(iterable, n, incomplete='fill', fillvalue=None):
    """Split *iterable* into consecutive tuples of length *n*.

    >>> list(grouper('ABCDEF', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    *incomplete* decides what happens when the length of *iterable* is not a
    multiple of *n*.

    ``'fill'`` (the default) pads the final group with *fillvalue*:

    >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]

    ``'ignore'`` drops the final partial group:

    >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    ``'strict'`` raises a subclass of `ValueError` when iteration reaches the
    partial group:

    >>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
    >>> list(iterator)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    UnequalIterablesError

    Any other *incomplete* value raises ``ValueError`` immediately.
    """
    # n references to ONE shared iterator: each output tuple takes n
    # consecutive items from it.
    shared = iter(iterable)
    columns = [shared] * n
    if incomplete == 'fill':
        result = zip_longest(*columns, fillvalue=fillvalue)
    elif incomplete == 'strict':
        result = _zip_equal(*columns)
    elif incomplete == 'ignore':
        result = zip(*columns)
    else:
        raise ValueError('Expected fill, strict, or ignore')
    return result
437
438
def roundrobin(*iterables):
    """Visit input iterables in a cycle until each is exhausted.

    >>> list(roundrobin('ABC', 'D', 'EF'))
    ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`, but
    may perform better for some inputs (in particular when the number of
    iterables is small).

    """
    # Algorithm credited to George Sakkis
    iterators = map(iter, iterables)
    # Each pass handles one fewer active iterator than the last.
    for num_active in range(len(iterables), 0, -1):
        # Take the next num_active iterators from the previous cycle. That
        # cycle is positioned just past the iterator that was exhausted, so
        # this drops it and keeps the remaining ones in round-robin order.
        iterators = cycle(islice(iterators, num_active))
        # map(next, ...) ends as soon as any iterator raises StopIteration,
        # which ends the pass and starts the next one with one fewer.
        yield from map(next, iterators)
455
456
def partition(pred, iterable):
    """
    Split *iterable* into two lazy iterables by the truth of *pred*.
    The first yields the items where ``pred(item)`` is false.
    The second yields the items where ``pred(item)`` is true.
    *pred* is evaluated only once per item.

    >>> is_odd = lambda x: x % 2 != 0
    >>> iterable = range(10)
    >>> even_items, odd_items = partition(is_odd, iterable)
    >>> list(even_items), list(odd_items)
    ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    If *pred* is None, :func:`bool` is used.

    >>> iterable = [0, 1, False, True, '', ' ']
    >>> false_items, true_items = partition(None, iterable)
    >>> list(false_items), list(true_items)
    ([0, False, ''], [1, True, ' '])

    """
    if pred is None:
        pred = bool

    for_false, for_true, for_testing = tee(iterable, 3)
    verdicts_a, verdicts_b = tee(map(pred, for_testing))
    falses = compress(for_false, map(not_, verdicts_a))
    trues = compress(for_true, verdicts_b)
    return (falses, trues)
483
484
def powerset(iterable):
    """Yield every subset of *iterable* as a tuple, smallest subsets first.

    >>> list(powerset([1, 2, 3]))
    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    The input is treated positionally, not as a :class:`set`, so repeated
    input elements lead to repeated subsets in the output.

    >>> seq = [1, 1, 0]
    >>> list(powerset(seq))
    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.
    """
    pool = list(iterable)
    sizes = range(len(pool) + 1)
    return chain.from_iterable(map(combinations, repeat(pool), sizes))
504
505
def unique_everseen(iterable, key=None):
    """
    Yield the first occurrence of each distinct element, preserving order.

    >>> list(unique_everseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D']
    >>> list(unique_everseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'D']

    Hashable keys are tracked in a set. Unhashable keys fall back to a list,
    which makes each check `O(n)` (so `O(n^2)` overall for those items).

    Remember that ``list`` objects are unhashable - you can use the *key*
    parameter to transform the list to a tuple (which is hashable) to
    avoid a slowdown.

    >>> iterable = ([1, 2], [2, 3], [1, 2])
    >>> list(unique_everseen(iterable))  # Slow
    [[1, 2], [2, 3]]
    >>> list(unique_everseen(iterable, key=tuple))  # Faster
    [[1, 2], [2, 3]]

    Similarly, you may want to convert unhashable ``set`` objects with
    ``key=frozenset``. For ``dict`` objects,
    ``key=lambda x: frozenset(x.items())`` can be used.

    """
    seen_hashable = set()
    remember_hashable = seen_hashable.add
    seen_unhashable = []
    remember_unhashable = seen_unhashable.append

    for element in iterable:
        k = element if key is None else key(element)
        try:
            # Set membership raises TypeError for unhashable keys.
            is_new = k not in seen_hashable
            if is_new:
                remember_hashable(k)
                yield element
        except TypeError:
            if k not in seen_unhashable:
                remember_unhashable(k)
                yield element
549
550
def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing each run of consecutive duplicates
    to its first element.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    """
    if key is not None:
        # With a key, the group key is the transformed value, so take the
        # first original element from each group instead.
        groups = map(itemgetter(1), groupby(iterable, key))
        return map(next, groups)
    return map(itemgetter(0), groupby(iterable))
564
565
def unique(iterable, key=None, reverse=False):
    """Yield the distinct elements of *iterable* in sorted order.

    >>> list(unique([[1, 2], [3, 4], [1, 2]]))
    [[1, 2], [3, 4]]

    *key* and *reverse* are passed to :func:`sorted`. *key* also decides
    which elements count as duplicates.

    >>> list(unique('ABBcCAD', str.casefold))
    ['A', 'B', 'c', 'D']
    >>> list(unique('ABBcCAD', str.casefold, reverse=True))
    ['D', 'c', 'B', 'A']

    The elements in *iterable* need not be hashable, but they must be
    comparable for sorting to work.
    """
    ordered = sorted(iterable, key=key, reverse=reverse)
    # After sorting, duplicates are adjacent: keep the first of each run.
    if key is None:
        return map(itemgetter(0), groupby(ordered))
    return map(next, map(itemgetter(1), groupby(ordered, key)))
584
585
def iter_except(func, exception, first=None):
    """Yield ``func()`` repeatedly until it raises *exception*.

    Converts a call-until-exception interface to an iterator interface, like
    ``iter(func, sentinel)`` but with an exception as the stop signal. If
    *first* is given, ``first()`` is called once and yielded before any call
    to *func*.

    >>> l = [0, 1, 2]
    >>> list(iter_except(l.pop, IndexError))
    [2, 1, 0]

    Multiple exceptions can be specified as a stopping condition:

    >>> l = [1, 2, 3, '...', 4, 5, 6]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [7, 6, 5]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [4, 3, 2]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    []

    """
    with suppress(exception):
        if first is not None:
            yield first()
        while 1:
            value = func()
            yield value
613
614
def first_true(iterable, default=None, pred=None):
    """
    Return the first item of *iterable* that is truthy, or *default* if
    there is none.

    If *pred* is not None, return the first item for which ``pred(item)`` is
    truthy instead.

    >>> first_true(range(10))
    1
    >>> first_true(range(10), pred=lambda x: x > 5)
    6
    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
    'missing'

    """
    candidates = filter(pred, iterable)
    return next(candidates, default)
633
634
def random_product(*args, repeat=1):
    """Pick one random item from each input iterable and return them as a
    tuple.

    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
    ('c', 3, 'Z')

    With *repeat* (keyword only), the whole sequence of inputs is sampled
    that many times.

    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
    ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*args, repeat=repeat)``.

    """
    pools = tuple(map(tuple, args)) * repeat
    return tuple(map(choice, pools))
653
654
def random_permutation(iterable, r=None):
    """Return a random ordering of *r* distinct positions of *iterable*.

    *r* defaults to the full length of *iterable* when omitted or ``None``.

    >>> random_permutation(range(5))  # doctest:+SKIP
    (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    if r is None:
        r = len(pool)
    return tuple(sample(pool, r))
671
672
def random_combination(iterable, r):
    """Return *r* randomly chosen items of *iterable*, kept in their
    original relative order.

    >>> random_combination(range(5), 3)  # doctest:+SKIP
    (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    chosen = sample(range(len(pool)), r)
    chosen.sort()
    return tuple(map(pool.__getitem__, chosen))
687
688
def random_combination_with_replacement(iterable, r):
    """Return *r* random items of *iterable*, allowing repeats, in the
    input's positional order.

    >>> random_combination_with_replacement(range(3), 5)  # doctest:+SKIP
    (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    n = len(pool)
    picks = [randrange(n) for _ in range(r)]
    picks.sort()
    return tuple(pool[i] for i in picks)
704
705
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

    >>> nth_combination(range(5), 3, 5)
    (0, 3, 4)

    Negative *index* values count from the end, as with list indexing.

    ``ValueError`` will be raised if *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if (r < 0) or (r > n):
        raise ValueError('r must be between 0 and the length of iterable')

    # Total number of r-length combinations.
    c = comb(n, r)

    if index < 0:
        index += c

    if (index < 0) or (index >= c):
        raise IndexError('index out of range')

    # Pick elements left to right. At each step, c counts the combinations
    # that start with the current candidate; skip whole blocks of them until
    # index falls inside one.
    result = []
    while r:
        c, n, r = c * r // n, n - 1, r - 1
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)
746
747
def prepend(value, iterator):
    """Yield *value*, then everything from *iterator*.

    >>> value = '0'
    >>> iterator = ['1', '2', '3']
    >>> list(prepend(value, iterator))
    ['0', '1', '2', '3']

    To prepend multiple values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterator)
761
762
def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

    >>> list(convolve([1, -1, -20], [1, -3]))
    [1, -4, -17, 60]

    Examples of popular kinds of kernels:

    * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      For image data, this blurs the image and reduces noise.
    * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function evaluated at evenly spaced inputs.
    * The kernel ``[1, -2, 1]`` estimates the second derivative of a
      function evaluated at evenly spaced inputs.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently. The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article:  https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown:  https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # This implementation comes from an older version of the itertools
    # documentation. It keeps the window logic inline (faster) and passes
    # the deque straight to _sumprod (no deque-to-tuple conversion).
    reversed_kernel = tuple(kernel)[::-1]
    width = len(reversed_kernel)
    # Start with a window of zeros so the leading outputs are partial sums.
    window = deque(repeat(0, width), maxlen=width)
    # Trailing zeros flush the signal's tail through the window.
    padded_signal = chain(signal, repeat(0, width - 1))
    for sample_value in padded_signal:
        window.append(sample_value)
        yield _sumprod(reversed_kernel, window)
804
805
def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that allows complete access to the
    remainder of the iterator.

    >>> it = iter('ABCdEfGhI')
    >>> all_upper, remainder = before_and_after(str.isupper, it)
    >>> ''.join(all_upper)
    'ABC'
    >>> ''.join(remainder)  # takewhile() would lose the 'd'
    'dEfGhI'

    Note that the first iterator must be fully consumed before the second
    iterator can generate valid results.
    """
    trues, after = tee(it)
    # zip(after) wraps each item of `after` in a 1-tuple, which is always
    # truthy, so compress emits every item that takewhile passes. Pulling
    # each selector also advances `after` by one, keeping it in lockstep
    # with the accepted prefix. The first item that fails the predicate is
    # consumed only from `trues`, so `after` resumes at that item.
    trues = compress(takewhile(predicate, trues), zip(after))
    return trues, after
823
824
def triplewise(iterable):
    """Return overlapping triplets from *iterable*.

    >>> list(triplewise('ABCDE'))
    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # This deviates from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Stagger the copies by 0, 1 and 2 positions before zipping.
    next(third, None)
    next(third, None)
    next(second, None)
    return zip(first, second, third)
839
840
841def _sliding_window_islice(iterable, n):
842 # Fast path for small, non-zero values of n.
843 iterators = tee(iterable, n)
844 for i, iterator in enumerate(iterators):
845 next(islice(iterator, i, i), None)
846 return zip(*iterators)
847
848
849def _sliding_window_deque(iterable, n):
850 # Normal path for other values of n.
851 iterator = iter(iterable)
852 window = deque(islice(iterator, n - 1), maxlen=n)
853 for x in iterator:
854 window.append(x)
855 yield tuple(window)
856
857
def sliding_window(iterable, n):
    """Return a sliding window of width *n* over *iterable*.

    >>> list(sliding_window(range(6), 4))
    [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]

    If *iterable* has fewer than *n* items, then nothing is yielded:

    >>> list(sliding_window(range(3), 4))
    []

    ``ValueError`` is raised if *n* is less than one.

    For a variant with more features, see :func:`windowed`.
    """
    # Dispatch on the window width: each range has its fastest strategy.
    if n == 1:
        return zip(iterable)
    if n == 2:
        return pairwise(iterable)
    if n > 20:
        return _sliding_window_deque(iterable, n)
    if n > 2:
        return _sliding_window_islice(iterable, n)
    raise ValueError(f'n should be at least one, not {n}')
881
882
def subslices(iterable):
    """Return every contiguous, non-empty slice of *iterable* as a list.

    >>> list(subslices('ABC'))
    [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This is similar to :func:`substrings`, but emits items in a different
    order.
    """
    seq = list(iterable)
    # Every pair start < stop drawn from 0..len(seq) defines one slice.
    bounds = combinations(range(len(seq) + 1), 2)
    return map(seq.__getitem__, starmap(slice, bounds))
895
896
def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

    >>> roots = [5, -4, 3]  # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
    [1, -4, -17, 60]

    Note that polynomial coefficients are specified in descending power order.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Multiply in one (x - root) factor at a time. Unlike the itertools docs
    # recipe, each convolve() result is materialized with list(), so many
    # roots do not build a deep stack of nested generators.
    coefficients = [1]
    for root in roots:
        factor_coefficients = (1, -root)
        coefficients = list(convolve(coefficients, factor_coefficients))
    return coefficients
917
918
def iter_index(iterable, value, start=0, stop=None):
    """Yield the index of each place in *iterable* that *value* occurs,
    beginning with index *start* and ending before index *stop*.


    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
    [1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
    [1, 4]

    The behavior for non-scalar *values* matches the built-in Python types.

    >>> list(iter_index('ABCDABCD', 'AB'))
    [0, 4]
    >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
    []
    >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
    [0, 2]

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    """
    seq_index = getattr(iterable, 'index', None)
    if seq_index is not None:
        # Fast path: objects with an .index() method (sequences). Call it
        # repeatedly from just past the previous hit until it raises
        # ValueError.
        if stop is None:
            stop = len(iterable)
        position = start - 1
        with suppress(ValueError):
            while True:
                position = seq_index(value, position + 1, stop)
                yield position
        return

    # Slow path: scan a general iterable element by element.
    window = islice(iterable, start, stop)
    for i, element in enumerate(window, start):
        if element is value or element == value:
            yield i
958
959
def sieve(n):
    """Yield the primes less than n.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    Nothing is yielded when *n* is 2 or less, including negative *n*.

    """
    # This implementation comes from an older version of the itertools
    # documentation. The newer implementation is easier to read but is
    # less lazy.
    if n <= 2:
        # No primes below n. Returning early also avoids isqrt(n) raising
        # ValueError for negative n.
        return
    yield 2
    start = 3
    # data[i] == 1 marks i as a prime candidate. Only odd positions start
    # marked, since 2 was already yielded.
    data = bytearray((0, 1)) * (n // 2)
    for p in iter_index(data, 1, start, stop=isqrt(n) + 1):
        # Candidates below p*p that survived are prime: emit them now.
        yield from iter_index(data, 1, start, p * p)
        # Clear the odd multiples of p, starting at p*p (step 2p skips evens).
        data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
        start = p * p
    yield from iter_index(data, 1, start)
979
980
981def _batched(iterable, n, *, strict=False): # pragma: no cover
982 """Batch data into tuples of length *n*. If the number of items in
983 *iterable* is not divisible by *n*:
984 * The last batch will be shorter if *strict* is ``False``.
985 * :exc:`ValueError` will be raised if *strict* is ``True``.
986
987 >>> list(batched('ABCDEFG', 3))
988 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
989
990 On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
991 """
992 if n < 1:
993 raise ValueError('n must be at least one')
994 iterator = iter(iterable)
995 while batch := tuple(islice(iterator, n)):
996 if strict and len(batch) != n:
997 raise ValueError('batched(): incomplete batch')
998 yield batch
999
1000
# On Python 3.13.0a2+ (hexversion 0x30D00A2), delegate to the C
# implementation itertools.batched, which accepts the *strict* keyword.
# Otherwise use the pure-Python _batched above.
if hexversion >= 0x30D00A2:  # pragma: no cover
    from itertools import batched as itertools_batched

    # A Python-level wrapper, so _batched's docstring can be attached.
    def batched(iterable, n, *, strict=False):
        return itertools_batched(iterable, n, strict=strict)

    batched.__doc__ = _batched.__doc__
else:  # pragma: no cover
    batched = _batched
1010
1011
def transpose(it):
    """Swap the rows and columns of the input matrix.

    >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
    [(1, 11), (2, 22), (3, 33)]

    The caller should ensure that the dimensions of the input are compatible:
    on Python 3.10+ rows of unequal length raise ``ValueError`` (via
    ``zip(strict=True)``), while older versions truncate silently.
    If the input is empty, no output will be produced.
    """
    rows = it
    columns = _zip_strict(*rows)
    return columns
1022
1023
1024def _is_scalar(value, stringlike=(str, bytes)):
1025 "Scalars are bytes, strings, and non-iterables."
1026 try:
1027 iter(value)
1028 except TypeError:
1029 return True
1030 return isinstance(value, stringlike)
1031
1032
def _flatten_tensor(tensor):
    "Depth-first iterator over scalars in a tensor."
    # Each pass peeks at the first value of the current level. Scalar means
    # fully flattened; otherwise flatten one more level and look again. Only
    # the first value is checked, so each dimension is assumed uniform.
    iterator = iter(tensor)
    while True:
        try:
            value = next(iterator)
        except StopIteration:
            # Empty at this depth: return the (exhausted) iterator.
            return iterator
        # Push the peeked value back in front of the remaining items.
        iterator = chain((value,), iterator)
        if _is_scalar(value):
            return iterator
        # Values are sub-arrays: concatenate them to go one level deeper.
        iterator = chain.from_iterable(iterator)
1045
1046
def reshape(matrix, shape):
    """Change the shape of a *matrix*.

    If *shape* is an integer, the matrix must be two dimensional
    and the shape is interpreted as the desired number of columns:

        >>> matrix = [(0, 1), (2, 3), (4, 5)]
        >>> cols = 3
        >>> list(reshape(matrix, cols))
        [(0, 1, 2), (3, 4, 5)]

    If *shape* is a tuple (or other iterable), the input matrix can have
    any number of dimensions. It will first be flattened and then rebuilt
    to the desired shape which can also be multidimensional:

        >>> matrix = [(0, 1), (2, 3), (4, 5)]    # Start with a 3 x 2 matrix

        >>> list(reshape(matrix, (2, 3)))       # Make a 2 x 3 matrix
        [(0, 1, 2), (3, 4, 5)]

        >>> list(reshape(matrix, (6,)))         # Make a vector of length six
        [0, 1, 2, 3, 4, 5]

        >>> list(reshape(matrix, (2, 1, 3, 1)))  # Make 2 x 1 x 3 x 1 tensor
        [(((0,), (1,), (2,)),), (((3,), (4,), (5,)),)]

    Each dimension is assumed to be uniform, either all arrays or all scalars.
    Flattening stops when the first value in a dimension is a scalar.
    Scalars are bytes, strings, and non-iterables.
    The reshape iterator stops when the requested shape is complete
    or when the input is exhausted, whichever comes first.

    """
    if isinstance(shape, int):
        rows_joined = chain.from_iterable(matrix)
        return batched(rows_joined, shape)
    first_dim, *dims = shape
    reshaped = _flatten_tensor(matrix)
    # Group from the innermost dimension outward.
    for dim in reversed(dims):
        reshaped = batched(reshaped, dim)
    return islice(reshaped, first_dim)
1086
1087
def matmul(m1, m2):
    """Multiply two matrices.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller should ensure that the dimensions of the input matrices are
    compatible with each other.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    n = len(m2[0])
    # Pair every row of m1 with every column of m2. Each dot product is one
    # output cell, in row-major order, grouped back into rows of width n.
    cells = starmap(_sumprod, product(m1, transpose(m2)))
    return batched(cells, n)
1101
1102
1103def _factor_pollard(n):
1104 # Return a factor of n using Pollard's rho algorithm.
1105 # Efficient when n is odd and composite.
1106 for b in range(1, n):
1107 x = y = 2
1108 d = 1
1109 while d == 1:
1110 x = (x * x + b) % n
1111 y = (y * y + b) % n
1112 y = (y * y + b) % n
1113 d = gcd(x - y, n)
1114 if d != n:
1115 return d
1116 raise ValueError('prime or under 5') # pragma: no cover
1117
1118
# Primes below 211, computed once at import time for factor()'s
# trial-division stage.
_primes_below_211 = tuple(sieve(211))
1120
1121
def factor(n):
    """Yield the prime factors of n.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]

    Finds small factors with trial division.  Larger factors are
    either verified as prime with ``is_prime`` or split into
    smaller factors with Pollard's rho algorithm.
    """

    # Corner case reduction
    if n < 2:
        return

    # Trial division reduction
    for prime in _primes_below_211:
        while not n % prime:
            yield prime
            n //= prime

    # Pollard's rho reduction
    primes = []
    todo = [n] if n > 1 else []
    # `todo` is extended while being iterated, so it acts as a work queue:
    # composite pieces are split and both halves are processed later in
    # this same loop.
    for n in todo:
        # Any remaining piece has no prime factor below 211. So if it is
        # below 211**2 it cannot be composite, and the is_prime check is
        # skipped.
        if n < 211**2 or is_prime(n):
            primes.append(n)
        else:
            fact = _factor_pollard(n)
            todo += (fact, n // fact)
    yield from sorted(primes)
1153
1154
def polynomial_eval(coefficients, x):
    """Return the value at *x* of the polynomial with these *coefficients*.

    Computes with better numeric stability than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125

    Coefficients run from the highest power down to the constant term.
    An empty coefficient list gives a zero of the same type as *x*.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    term_count = len(coefficients)
    if not term_count:
        return type(x)(0)
    # Exponents descend to match the coefficient order: n-1, n-2, ..., 0.
    exponents = range(term_count - 1, -1, -1)
    powers = (pow(x, e) for e in exponents)
    return _sumprod(coefficients, powers)
1176
1177
def sum_of_squares(it):
    """Return the sum of every input value multiplied by itself.

    >>> sum_of_squares([10, 20, 30])
    1400

    *it* may be any iterable, including a one-shot iterator.  It is
    split with ``tee`` so each value can be used as both factors.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    left, right = tee(it)
    return _sumprod(left, right)
1187
1188
def polynomial_derivative(coefficients):
    """Return the coefficients of the first derivative of a polynomial.

    Differentiate ``x³ - 4 x² - 17 x + 60``:

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]

    Coefficients run from the highest power down to the constant term.
    Each coefficient is multiplied by its exponent and the constant term
    is dropped, so the result is one entry shorter than the input.  An
    empty or single-entry input gives an empty list.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    top_exponent = len(coefficients) - 1
    # Exponents n-1, n-2, ..., 1; zip stops before the constant term.
    exponents = range(top_exponent, 0, -1)
    return [coeff * exp for coeff, exp in zip(coefficients, exponents)]
1206
1207
def totient(n):
    """Return how many integers from 1 to *n* are coprime with *n*.

    Euler's totient function φ(n) counts the totatives of *n*: the
    integers k in the range 1 ≤ k ≤ n such that gcd(n, k) = 1.

    >>> n = 9
    >>> totient(n)
    6

    >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
    >>> totatives
    [1, 2, 4, 5, 7, 8]
    >>> len(totatives)
    6

    The count is computed from the distinct prime factors of *n*.  For
    each such prime p, the running value loses its p-th part, which is
    the product formula φ(n) = n · Π(1 - 1/p) in exact integer arithmetic.

    Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    result = n
    for p in set(factor(n)):
        result -= result // p
    return result
1230
1231
# Miller–Rabin primality test: https://oeis.org/A014233
#
# Deterministic base sets, as (limit, bases) pairs.  For an odd n below
# `limit`, passing the strong-probable-prime test for every listed base
# is taken as proof of primality (per the referenced literature).
# Entries are ordered by increasing limit, so is_prime() uses the first
# entry whose limit exceeds n.  Past the last limit (about 3.3e24),
# is_prime() falls back to randomly chosen bases.
#
# is_prime() handles n < 17 before consulting this table, and each entry
# is reached only by n at or above the previous limit.  So every base is
# smaller than the n it is applied to, as the `2 <= base < n` assertion
# in _strong_probable_prime() requires.
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    # limit below is 2**64
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        3317044064679887385961981,
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]
1246
1247
1248@lru_cache
1249def _shift_to_odd(n):
1250 'Return s, d such that 2**s * d == n'
1251 s = ((n - 1) ^ n).bit_length() - 1
1252 d = n >> s
1253 assert (1 << s) * d == n and d & 1 and s >= 0
1254 return s, d
1255
1256
def _strong_probable_prime(n, base):
    """Run one Miller–Rabin round on odd *n* > 2 with the given *base*.

    Writes ``n - 1 = 2**s * d`` with *d* odd.  Returns True if
    ``base**d`` is 1 or n - 1 modulo *n*, or if one of the following
    ``s - 1`` squarings reaches n - 1.  A False result proves *n*
    composite.  True means *n* is a strong probable prime to *base*.
    """
    assert (n > 2) and (n & 1) and (2 <= base < n)

    minus_one = n - 1
    s, d = _shift_to_odd(minus_one)

    x = pow(base, d, n)
    if x in (1, minus_one):
        return True

    for _ in range(s - 1):
        x = pow(x, 2, n)
        if x == minus_one:
            return True

    return False
1272
1273
# Separate instance of Random() that doesn't share state
# with the default user instance of Random().
# is_prime() draws its fallback Miller-Rabin bases from this generator.
# Testing a large number therefore neither consumes nor depends on the
# state of the module-level generator, for example after the user calls
# random.seed().
_private_randrange = random.Random().randrange
1277
1278
def is_prime(n):
    """Return ``True`` if *n* is prime and ``False`` otherwise.

    Basic examples:

    >>> is_prime(37)
    True
    >>> is_prime(3 * 13)
    False
    >>> is_prime(18_446_744_073_709_551_557)
    True

    Find the next prime over one billion:

    >>> next(filter(is_prime, count(10**9)))
    1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

    >>> from random import seed, randrange, getrandbits
    >>> seed(18675309)

    >>> next(filter(is_prime, map(getrandbits, repeat(200))))
    893303929355758292373272075469392561129886005037663238028407

    >>> next(filter(is_prime, map(randrange, repeat(10**60))))
    269638077304026462407872868003560484232362454342414618963649

    Inputs below 17 are settled by lookup.  Other inputs are screened by
    trial division by the primes up to 13 and then run through
    Miller-Rabin.  Below 10**24 a fixed, proven set of bases makes the
    answer exact.  For larger inputs, 64 random bases give less than a
    1 in 2**128 chance of reporting a composite as prime.
    """
    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Cheap screening: reject even numbers, then multiples of 3 to 13.
    if not n & 1:
        return False
    if not all(n % p for p in (3, 5, 7, 11, 13)):
        return False

    # Take the first deterministic base set whose limit exceeds n...
    bases = next(
        (known for limit, known in _perfect_tests if n < limit),
        None,
    )
    if bases is None:
        # ...or, beyond every limit, 64 random bases from [2, n - 2].
        bases = (_private_randrange(2, n - 1) for _ in range(64))

    return all(_strong_probable_prime(n, base) for base in bases)
1325
1326
def loops(n):
    """Return an iterator of *n* ``None`` values for index-free looping.

    Use it in place of ``range(n)`` when the loop variable is unused: no
    integer object is produced on each step.  A zero or negative *n*
    produces nothing.

    >>> i = 0
    >>> for _ in loops(5):
    ...     i += 1
    >>> i
    5

    """
    placeholder = None
    return repeat(placeholder, n)
1339
1340
def multinomial(*counts):
    """Return the number of distinct arrangements of a multiset.

    The value of ``multinomial(3, 4, 2)`` can be read in several
    equivalent ways:

    * It is the coefficient of ``a³b⁴c²`` in the expansion of
      ``(a + b + c)⁹``, namely 1260.

    * It is the number of distinct ways to line up 9 balls made of 3
      reds, 4 greens, and 2 blues: 1260.

    * It is the number of ways to place 9 distinct objects into three
      bins of sizes 3, 4, and 2: again 1260.

    :func:`multinomial` gives the length of :func:`distinct_permutations`
    without enumerating it.  For example, "abracadabra" has 83,160
    distinct anagrams:

    >>> from more_itertools import distinct_permutations, ilen
    >>> ilen(distinct_permutations('abracadabra'))
    83160

    The same number follows directly from the letter counts, 5a 2b 2r
    1c 1d:

    >>> from collections import Counter
    >>> list(Counter('abracadabra').values())
    [5, 2, 2, 1, 1]
    >>> multinomial(5, 2, 2, 1, 1)
    83160

    With only two categories it reduces to a binomial coefficient.  For
    example, arranging 12 balls with 5 reds and 7 blues gives
    ``multinomial(5, 7)``, which equals ``math.comb(12, 5)``.

    When every multiplicity is 1 it reduces to a factorial, for example
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    Calling it with no counts returns 1.

    Reference: https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    # Add the groups one at a time.  A group of size k joining a running
    # total of `total` items can pick its positions in comb(total, k) ways.
    result = 1
    for total, k in zip(accumulate(counts), counts):
        result *= comb(total, k)
    return result
1384
1385
def _running_median_minheap_and_maxheap(iterator):  # pragma: no cover
    "Non-windowed running_median() for Python 3.14+"

    pull = iterator.__next__
    lower = []  # max-heap of the smaller half
    upper = []  # min-heap of the larger half (same size as or one smaller than lower)

    with suppress(StopIteration):
        while True:
            # Odd count so far: the minimum of upper plus the new value
            # moves into lower, whose top is then the median.
            heappush_max(lower, heappushpop(upper, pull()))
            yield lower[0]

            # Even count so far: the maximum of lower plus the new value
            # moves into upper, and the median is the mean of both tops.
            heappush(upper, heappushpop_max(lower, pull()))
            yield (lower[0] + upper[0]) / 2
1400
1401
1402def _running_median_minheap_only(iterator): # pragma: no cover
1403 "Backport of non-windowed running_median() for Python 3.13 and prior."
1404
1405 read = iterator.__next__
1406 lo = [] # max-heap (actually a minheap with negated values)
1407 hi = [] # min-heap (same size as or one smaller than lo)
1408
1409 with suppress(StopIteration):
1410 while True:
1411 heappush(lo, -heappushpop(hi, read()))
1412 yield -lo[0]
1413
1414 heappush(hi, -heappushpop(lo, -read()))
1415 yield (hi[0] - lo[0]) / 2
1416
1417
1418def _running_median_windowed(iterator, maxlen):
1419 "Yield median of values in a sliding window."
1420
1421 window = deque()
1422 ordered = []
1423
1424 for x in iterator:
1425 window.append(x)
1426 insort(ordered, x)
1427
1428 if len(ordered) > maxlen:
1429 i = bisect_left(ordered, window.popleft())
1430 del ordered[i]
1431
1432 n = len(ordered)
1433 m = n // 2
1434 yield ordered[m] if n & 1 else (ordered[m - 1] + ordered[m]) / 2
1435
1436
def running_median(iterable, *, maxlen=None):
    """Yield the median of all values seen so far, or of a sliding window.

    Set *maxlen* to a positive integer to limit the window to that many
    of the most recent values.  The default of *None* means the window
    is unbounded.  An invalid *maxlen* is rejected immediately, before
    any value is consumed.

    For example:

    >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0]))
    [5.0, 7.0, 5.0, 7.0, 8.0, 8.5]
    >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0], maxlen=3))
    [5.0, 7.0, 5.0, 9.0, 8.0, 9.0]

    Supports numeric types such as int, float, Decimal, and Fraction,
    but not complex numbers, which are unorderable.

    On Python 3.13 and earlier, max-heaps are simulated by negating
    values.  The negation makes Decimal inputs undergo context rounding,
    so results can differ slightly from those of statistics.median().
    """

    iterator = iter(iterable)

    if maxlen is None:
        if _max_heap_available:
            return _running_median_minheap_and_maxheap(iterator)  # pragma: no cover
        return _running_median_minheap_only(iterator)  # pragma: no cover

    maxlen = index(maxlen)
    if maxlen <= 0:
        raise ValueError('Window size should be positive')
    return _running_median_windowed(iterator, maxlen)