1"""Imported from the recipes section of the itertools documentation.
2
3All functions taken from the recipes section of the itertools library docs
4[1]_.
5Some backward-compatible usability improvements have been made.
6
.. [1] https://docs.python.org/3/library/itertools.html#itertools-recipes
8
9"""
10
11import random
12
13from bisect import bisect_left, insort
14from collections import deque
15from contextlib import suppress
16from functools import lru_cache, reduce
17from heapq import heappush, heappushpop
18from itertools import (
19 accumulate,
20 chain,
21 combinations,
22 compress,
23 count,
24 cycle,
25 groupby,
26 islice,
27 pairwise,
28 product,
29 repeat,
30 starmap,
31 takewhile,
32 tee,
33 zip_longest,
34)
35from math import prod, comb, isqrt, gcd
36from operator import mul, getitem, index, is_, itemgetter, not_, truediv
37from random import randrange, sample, choice, shuffle
38from sys import hexversion
39
# Names exported by ``from <this module> import *``; this list is the
# module's declared public API.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_derangement',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'running_mean',
    'running_median',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]
93
# Unique module-private object; presumably used as a "no value supplied"
# sentinel by functions further down the file (use sites not shown here).
_marker = object()
95
96
# heapq max-heap functions are available for Python 3.14+.
# ``_max_heap_available`` records whether the import succeeded so that other
# code can check the flag instead of re-attempting the import.
try:
    from heapq import heappush_max, heappushpop_max
except ImportError:  # pragma: no cover
    _max_heap_available = False
else:  # pragma: no cover
    _max_heap_available = True
104
105
def take(n, iterable):
    """Collect at most the first *n* items of *iterable* into a list.

        >>> take(3, range(10))
        [0, 1, 2]

    When *iterable* holds fewer than *n* items, the result simply contains
    every item that was available.

        >>> take(10, range(3))
        [0, 1, 2]

    """
    leading = islice(iterable, n)
    return list(leading)
120
121
def tabulate(function, start=0):
    """Return an iterator over ``function(start)``, ``function(start + 1)``,
    ``function(start + 2)``, and so on without end.

    *function* must accept a single integer argument.

    *start* is the first integer passed to *function* (0 when omitted); each
    step of the returned iterator uses the next integer.

        >>> square = lambda x: x ** 2
        >>> iterator = tabulate(square, -3)
        >>> take(4, iterator)
        [9, 4, 1, 0]

    """
    arguments = count(start)
    return map(function, arguments)
138
139
def tail(n, iterable):
    """Return an iterator over the final *n* items of *iterable*.

        >>> t = tail(3, 'ABCDEFG')
        >>> list(t)
        ['E', 'F', 'G']

    """
    try:
        length = len(iterable)
    except TypeError:
        # No len(): stream everything through a bounded deque, which keeps
        # only the most recent n items.
        trailing = deque(iterable, maxlen=n)
        return iter(trailing)
    # Sized input: skip directly past the leading items.
    skip = max(0, length - n)
    return islice(iterable, skip, None)
154
155
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps, discarding the items. If *n* is
    ``None``, exhaust it completely.

    Nothing is returned; the function exists purely for its effect on the
    iterator's position.

        >>> i = (x for x in range(10))
        >>> next(i)
        0
        >>> consume(i, 3)
        >>> next(i)
        4
        >>> consume(i)
        >>> next(i)
        Traceback (most recent call last):
          File "<stdin>", line 1, in <module>
        StopIteration

    Asking for more steps than there are items left just exhausts the
    iterator.

        >>> i = (x for x in range(3))
        >>> consume(i, 5)
        >>> next(i)
        Traceback (most recent call last):
          File "<stdin>", line 1, in <module>
        StopIteration

    """
    # Both branches delegate the actual stepping to C-implemented tools.
    if n is not None:
        # The slice [n:n] is empty, but reaching it skips n items.
        next(islice(iterator, n, n), None)
        return
    # A zero-length deque swallows every item and stores none.
    deque(iterator, maxlen=0)
194
195
def nth(iterable, n, default=None):
    """Return the item at position *n* of *iterable*, or *default* when the
    iterable has no such position.

        >>> l = range(10)
        >>> nth(l, 3)
        3
        >>> nth(l, 20, "zebra")
        'zebra'

    """
    from_n_onward = islice(iterable, n, None)
    return next(from_n_onward, default)
207
208
def all_equal(iterable, key=None):
    """
    Return ``True`` when every element compares equal to the others (an
    empty iterable counts as all-equal).

        >>> all_equal('aaaa')
        True
        >>> all_equal('aaab')
        False

    *key*, if given, is a one-argument function applied to each item before
    comparison:

        >>> all_equal('AaaA', key=str.casefold)
        True
        >>> all_equal([1, 2, 3], key=lambda x: x < 10)
        True

    """
    # groupby yields (key, group) tuples, never None, so None is a safe
    # "nothing left" marker here.
    runs = groupby(iterable, key)
    if next(runs, None) is None:
        # No items at all.
        return True
    # All equal exactly when there is no second run of differing items.
    return next(runs, None) is None
233
234
def quantify(iterable, pred=bool):
    """Return how many items of *iterable* make *pred* true.

    The results of *pred* are summed, so *pred* is expected to return
    ``True``/``False`` (or other values that add up meaningfully).

        >>> quantify([True, False, True])
        2

    """
    verdicts = map(pred, iterable)
    return sum(verdicts)
243
244
def pad_none(iterable):
    """Yield the items of *iterable*, then ``None`` forever.

        >>> take(5, pad_none(range(3)))
        [0, 1, 2, None, None]

    Useful for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    endless_nones = repeat(None)
    return chain(iterable, endless_nones)
257
258
# Alias: ``padnone`` is the same function object as ``pad_none``; both
# spellings are listed in ``__all__``.
padnone = pad_none
260
261
def ncycles(iterable, n):
    """Return an iterator that runs through the elements of *iterable*
    *n* times in a row.

        >>> list(ncycles(["a", "b"], 3))
        ['a', 'b', 'a', 'b', 'a', 'b']

    """
    # Snapshot the input once so it can be replayed for every cycle.
    snapshot = tuple(iterable)
    replays = repeat(snapshot, n)
    return chain.from_iterable(replays)
270
271
def dotproduct(vec1, vec2):
    """Return the dot product of two iterables: the sum of the pairwise
    products of their items.

        >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
        33.5
        >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
        33.5

    In Python 3.12 and later, use ``math.sumprod()`` instead.
    """
    pairwise_products = map(mul, vec1, vec2)
    return sum(pairwise_products)
283
284
# math.sumprod is available for Python 3.12+.
# When it cannot be imported, fall back to the pure-Python ``dotproduct``
# defined above so that ``_sumprod`` is always callable.
try:
    from math import sumprod as _sumprod
except ImportError:  # pragma: no cover
    _sumprod = dotproduct
290
291
def flatten(listOfLists):
    """Return an iterator that removes one level of nesting from
    *listOfLists*, yielding the items of each inner iterable in order.

        >>> list(flatten([[0, 1], [2, 3]]))
        [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    flattened = chain.from_iterable(listOfLists)
    return flattened
302
303
def repeatfunc(func, times=None, *args):
    """Return an iterator whose every item is the result of a fresh call
    ``func(*args)``.

    When *times* is given, exactly that many calls are made:

        >>> from operator import add
        >>> times = 4
        >>> args = 3, 5
        >>> list(repeatfunc(add, times, *args))
        [8, 8, 8, 8]

    When *times* is ``None`` the iterator never ends:

        >>> from random import randrange
        >>> times = None
        >>> args = 1, 11
        >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
        [2, 4, 8, 1, 8, 4]

    """
    # Repeat the same argument tuple, endlessly or *times* times, and let
    # starmap unpack it into each call.
    argument_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(func, argument_stream)
329
330
331def grouper(iterable, n, incomplete='fill', fillvalue=None):
332 """Group elements from *iterable* into fixed-length groups of length *n*.
333
334 >>> list(grouper('ABCDEF', 3))
335 [('A', 'B', 'C'), ('D', 'E', 'F')]
336
337 The keyword arguments *incomplete* and *fillvalue* control what happens for
338 iterables whose length is not a multiple of *n*.
339
340 When *incomplete* is `'fill'`, the last group will contain instances of
341 *fillvalue*.
342
343 >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
344 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
345
346 When *incomplete* is `'ignore'`, the last group will not be emitted.
347
348 >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
349 [('A', 'B', 'C'), ('D', 'E', 'F')]
350
351 When *incomplete* is `'strict'`, a `ValueError` will be raised.
352
353 >>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
354 >>> list(iterator) # doctest: +IGNORE_EXCEPTION_DETAIL
355 Traceback (most recent call last):
356 ...
357 ValueError
358
359 """
360 iterators = [iter(iterable)] * n
361 match incomplete:
362 case 'fill':
363 return zip_longest(*iterators, fillvalue=fillvalue)
364 case 'strict':
365 return zip(*iterators, strict=True)
366 case 'ignore':
367 return zip(*iterators)
368 case _:
369 raise ValueError('Expected fill, strict, or ignore')
370
371
def roundrobin(*iterables):
    """Take one item from each input in turn, dropping inputs as they run
    out, until all are exhausted.

        >>> list(roundrobin('ABC', 'D', 'EF'))
        ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`, but
    may perform better for some inputs (in particular when the number of
    iterables is small).

    """
    # Algorithm credited to George Sakkis
    active = map(iter, iterables)
    remaining = len(iterables)
    while remaining > 0:
        # Cycle over the iterators still considered active. The first one to
        # raise StopIteration ends the ``yield from``; the next pass then
        # rebuilds the cycle with one iterator fewer.
        active = cycle(islice(active, remaining))
        yield from map(next, active)
        remaining -= 1
388
389
def partition(pred, iterable):
    """
    Split *iterable* into two iterators, returned as a 2-tuple.
    The first yields the items for which *pred* is false.
    The second yields the items for which *pred* is true.

        >>> is_odd = lambda x: x % 2 != 0
        >>> iterable = range(10)
        >>> even_items, odd_items = partition(is_odd, iterable)
        >>> list(even_items), list(odd_items)
        ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    If *pred* is None, :func:`bool` is used.

        >>> iterable = [0, 1, False, True, '', ' ']
        >>> false_items, true_items = partition(None, iterable)
        >>> list(false_items), list(true_items)
        ([0, False, ''], [1, True, ' '])

    """
    predicate = bool if pred is None else pred

    # Two copies of the data plus one copy that feeds the predicate; the
    # predicate results are themselves tee'd so *pred* runs once per item.
    false_source, true_source, probe = tee(iterable, 3)
    verdicts_for_false, verdicts_for_true = tee(map(predicate, probe))
    rejected = compress(false_source, map(not_, verdicts_for_false))
    accepted = compress(true_source, verdicts_for_true)
    return (rejected, accepted)
416
417
def powerset(iterable):
    """Return an iterator over every subset of *iterable*, as tuples,
    ordered from the empty tuple up to the full input.

        >>> list(powerset([1, 2, 3]))
        [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    :func:`powerset` will operate on iterables that aren't :class:`set`
    instances, so repeated elements in the input will produce repeated elements
    in the output.

        >>> seq = [1, 1, 0]
        >>> list(powerset(seq))
        [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.
    """
    pool = list(iterable)
    sizes = range(len(pool) + 1)
    # One combinations() iterator per subset size, chained lazily.
    per_size = map(combinations, repeat(pool), sizes)
    return chain.from_iterable(per_size)
437
438
def unique_everseen(iterable, key=None):
    """
    Yield each distinct element once, in order of first appearance.

        >>> list(unique_everseen('AAAABBBCCDAABBB'))
        ['A', 'B', 'C', 'D']
        >>> list(unique_everseen('ABBCcAD', str.lower))
        ['A', 'B', 'C', 'D']

    When *key* is given, elements are compared by ``key(element)`` while the
    original elements are the ones yielded.

    Hashable and unhashable items may be mixed. Unhashable ones are tracked
    in a list, so the function will be slower (i.e., `O(n^2)`) for them.

    Remember that ``list`` objects are unhashable - you can use the *key*
    parameter to transform the list to a tuple (which is hashable) to
    avoid a slowdown.

        >>> iterable = ([1, 2], [2, 3], [1, 2])
        >>> list(unique_everseen(iterable))  # Slow
        [[1, 2], [2, 3]]
        >>> list(unique_everseen(iterable, key=tuple))  # Faster
        [[1, 2], [2, 3]]

    Similarly, you may want to convert unhashable ``set`` objects with
    ``key=frozenset``. For ``dict`` objects,
    ``key=lambda x: frozenset(x.items())`` can be used.

    """
    seen_hashable = set()
    seen_unhashable = []

    for element in iterable:
        k = element if key is None else key(element)
        try:
            # Set membership raises TypeError for unhashable keys.
            if k in seen_hashable:
                continue
            seen_hashable.add(k)
            yield element
        except TypeError:
            # Unhashable key: fall back to a linear scan of a list.
            if k in seen_unhashable:
                continue
            seen_unhashable.append(k)
            yield element
482
483
def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing each run of consecutive
    duplicates into its first element.

        >>> list(unique_justseen('AAAABBBCCDAABBB'))
        ['A', 'B', 'C', 'D', 'A', 'B']
        >>> list(unique_justseen('ABBCcAD', str.lower))
        ['A', 'B', 'C', 'A', 'D']

    """
    runs = groupby(iterable, key)
    if key is None:
        # Without a key, the group key is the element itself.
        return map(itemgetter(0), runs)
    # With a key, take the first original element of each group.
    groups = map(itemgetter(1), runs)
    return map(next, groups)
497
498
def unique(iterable, key=None, reverse=False):
    """Yield the distinct elements of *iterable* in sorted order.

        >>> list(unique([[1, 2], [3, 4], [1, 2]]))
        [[1, 2], [3, 4]]

    *key* and *reverse* are passed to :func:`sorted`; *key* is also used to
    decide which elements count as duplicates.

        >>> list(unique('ABBcCAD', str.casefold))
        ['A', 'B', 'c', 'D']
        >>> list(unique('ABBcCAD', str.casefold, reverse=True))
        ['D', 'c', 'B', 'A']

    The elements in *iterable* need not be hashable, but they must be
    comparable for sorting to work.
    """
    # Sorting puts duplicates next to each other, where unique_justseen can
    # collapse them.
    ordered = sorted(iterable, key=key, reverse=reverse)
    return unique_justseen(ordered, key=key)
517
518
def iter_except(func, exception, first=None):
    """Yields results from a function repeatedly until an exception is raised.

    Converts a call-until-exception interface to an iterator interface.
    Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
    to end the loop.

        >>> l = [0, 1, 2]
        >>> list(iter_except(l.pop, IndexError))
        [2, 1, 0]

    Multiple exceptions can be specified as a stopping condition:

        >>> l = [1, 2, 3, '...', 4, 5, 6]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        [7, 6, 5]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        [4, 3, 2]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        []

    If *first* is not ``None``, it is called once with no arguments and its
    result is yielded before the first call to *func*. An *exception* raised
    by *first* ends the iteration just like one raised by *func*.

    """
    # suppress() ends the generator quietly when *exception* is raised by
    # either callable; any other exception propagates to the consumer.
    with suppress(exception):
        if first is not None:
            yield first()
        while True:
            yield func()
546
547
def first_true(iterable, default=None, pred=None):
    """
    Return the first item of *iterable* that is true, or *default* when
    there is none.

    If *pred* is not None, the first item for which ``pred(item)`` is true
    is returned instead.

        >>> first_true(range(10))
        1
        >>> first_true(range(10), pred=lambda x: x > 5)
        6
        >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
        'missing'

    """
    # filter(None, ...) keeps the items that are themselves true.
    matches = filter(pred, iterable)
    return next(matches, default)
566
567
def random_product(*iterables, repeat=1):
    """Pick one random item from each of the input iterables and return the
    picks as a tuple.

        >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
        ('c', 3, 'Z')

    The keyword argument *repeat* makes that many rounds of picks over the
    inputs, concatenated into a single tuple.

        >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
        ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*args, repeat=repeat)``.

    """
    pools = tuple(tuple(candidates) for candidates in iterables) * repeat
    return tuple(choice(pool) for pool in pools)
586
587
def random_permutation(iterable, r=None):
    """Return a tuple holding a random permutation of length *r* of the
    elements in *iterable*.

    When *r* is omitted or ``None``, all elements are used.

        >>> random_permutation(range(5))  # doctest:+SKIP
        (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    if r is None:
        r = len(pool)
    return tuple(sample(pool, r))
604
605
def random_combination(iterable, r):
    """Return a tuple of *r* randomly chosen elements of *iterable*, kept in
    their original relative order.

        >>> random_combination(range(5), 3)  # doctest:+SKIP
        (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    # Draw r distinct positions, then restore input order.
    positions = sample(range(len(pool)), r)
    positions.sort()
    return tuple(pool[position] for position in positions)
620
621
def random_combination_with_replacement(iterable, r):
    """Return a tuple of *r* randomly chosen elements of *iterable*, in
    input order, where the same element may be chosen more than once.

        >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
        (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    size = len(pool)
    # r independent draws of a position (repeats allowed), then ordered.
    positions = [randrange(size) for _ in range(r)]
    positions.sort()
    return tuple(pool[position] for position in positions)
637
638
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

        >>> nth_combination(range(5), 3, 5)
        (0, 3, 4)

    A negative *index* counts from the end, as with sequence indexing.

    ``ValueError`` will be raised if *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    # Total number of r-length combinations; math.comb itself raises
    # ValueError for a negative r.
    c = comb(n, r)

    if index < 0:
        index += c
    if not 0 <= index < c:
        raise IndexError('combination index out of range')

    result = []
    while r:
        # c becomes the number of combinations that start with the current
        # candidate element: C(n - 1, r - 1) == C(n, r) * r // n.
        c, n, r = c * r // n, n - 1, r - 1
        while index >= c:
            # The target lies beyond all combinations starting with this
            # candidate: skip them and move on to the next candidate.
            index -= c
            c, n = c * (n - r) // n, n - 1
        # n counts the pool elements after the chosen one, so index the
        # pool from its end.
        result.append(pool[-1 - n])

    return tuple(result)
671
672
def prepend(value, iterator):
    """Return an iterator that yields *value* first and then every element
    of *iterator*.

        >>> value = '0'
        >>> iterator = ['1', '2', '3']
        >>> list(prepend(value, iterator))
        ['0', '1', '2', '3']

    To prepend multiple values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterator)
686
687
def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

        >>> list(convolve([1, -1, -20], [1, -3]))
        [1, -4, -17, 60]

    Examples of popular kinds of kernels:

    * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      For image data, this blurs the image and reduces noise.
    * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function evaluated at evenly spaced inputs.
    * The kernel ``[1, -2, 1]`` estimates the second derivative of a
      function evaluated at evenly spaced inputs.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently.  The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article:  https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown:  https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # This implementation comes from an older version of the itertools
    # documentation.  While the newer implementation is a bit clearer,
    # this one was kept because the inlined window logic is faster
    # and it avoids an unnecessary deque-to-tuple conversion.
    flipped_kernel = tuple(kernel)[::-1]
    width = len(flipped_kernel)
    # Start with a window of zeros as wide as the kernel; maxlen makes each
    # append drop the oldest value.
    window = deque([0], maxlen=width) * width
    # Trailing zeros flush the last real samples through the window.
    padded_signal = chain(signal, repeat(0, width - 1))
    for value in padded_signal:
        window.append(value)
        yield _sumprod(flipped_kernel, window)
729
730
def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that allows complete access to the
    remainder of the iterator.

    Returns a pair of iterators: the leading items for which *predicate*
    is true, and everything after them (starting with the first item that
    failed the predicate).

         >>> it = iter('ABCdEfGhI')
         >>> all_upper, remainder = before_and_after(str.isupper, it)
         >>> ''.join(all_upper)
         'ABC'
         >>> ''.join(remainder) # takewhile() would lose the 'd'
         'dEfGhI'

    Note that the first iterator must be fully consumed before the second
    iterator can generate valid results.
    """
    leading, remainder = tee(it)
    passing = takewhile(predicate, leading)
    # zip(remainder) yields 1-tuples, which are always truthy, so compress()
    # drops nothing; its only job is to advance *remainder* by one for each
    # item that passed the predicate.
    lockstep = zip(remainder)
    return compress(passing, lockstep), remainder
748
749
def triplewise(iterable):
    """Return an iterator of overlapping 3-tuples taken from *iterable*.

        >>> list(triplewise('ABCDE'))
        [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # This deviates from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Offset the copies by 0, 1 and 2 positions before zipping them.
    for _ in range(2):
        next(third, None)
    next(second, None)
    return zip(first, second, third)
764
765
766def _sliding_window_islice(iterable, n):
767 # Fast path for small, non-zero values of n.
768 iterators = tee(iterable, n)
769 for i, iterator in enumerate(iterators):
770 next(islice(iterator, i, i), None)
771 return zip(*iterators)
772
773
774def _sliding_window_deque(iterable, n):
775 # Normal path for other values of n.
776 iterator = iter(iterable)
777 window = deque(islice(iterator, n - 1), maxlen=n)
778 for x in iterator:
779 window.append(x)
780 yield tuple(window)
781
782
783def sliding_window(iterable, n):
784 """Return a sliding window of width *n* over *iterable*.
785
786 >>> list(sliding_window(range(6), 4))
787 [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]
788
789 If *iterable* has fewer than *n* items, then nothing is yielded:
790
791 >>> list(sliding_window(range(3), 4))
792 []
793
794 For a variant with more features, see :func:`windowed`.
795 """
796 if n > 20:
797 return _sliding_window_deque(iterable, n)
798 elif n > 2:
799 return _sliding_window_islice(iterable, n)
800 elif n == 2:
801 return pairwise(iterable)
802 elif n == 1:
803 return zip(iterable)
804 else:
805 raise ValueError(f'n should be at least one, not {n}')
806
807
def subslices(iterable):
    """Return an iterator over every contiguous, non-empty slice of
    *iterable*, each as a list.

        >>> list(subslices('ABC'))
        [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This is similar to :func:`substrings`, but emits items in a different
    order.
    """
    items = list(iterable)
    # Every (start, stop) pair with start < stop names one non-empty slice.
    bounds = combinations(range(len(items) + 1), 2)
    windows = starmap(slice, bounds)
    return map(getitem, repeat(items), windows)
820
821
def polynomial_from_roots(roots):
    """Return the coefficient list of the polynomial whose roots are
    *roots*, i.e. the expansion of the product of ``(x - root)`` factors.

        >>> roots = [5, -4, 3]            # (x - 5) * (x + 4) * (x - 3)
        >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
        [1, -4, -17, 60]

    Note that polynomial coefficients are specified in descending power order.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """

    # This recipe differs from the one in itertools docs in that it
    # applies list() after each call to convolve(). This avoids
    # hitting stack limits with nested generators.

    def multiply_by_factor(poly, root):
        # Multiply the polynomial so far by the linear factor (x - root).
        return list(convolve(poly, (1, -root)))

    return reduce(multiply_by_factor, roots, [1])
842
843
def iter_index(iterable, value, start=0, stop=None):
    """Yield every index at which *value* occurs in *iterable*, searching
    from index *start* (inclusive) up to index *stop* (exclusive).


    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
    [1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
    [1, 4]

    The behavior for non-scalar *values* matches the built-in Python types.

    >>> list(iter_index('ABCDABCD', 'AB'))
    [0, 4]
    >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
    []
    >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
    [0, 2]

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    """
    find = getattr(iterable, 'index', None)
    if find is not None:
        # Fast path for sequences: let the object's own index() method do
        # the searching, resuming just past each hit.
        if stop is None:
            stop = len(iterable)
        position = start - 1
        # index() signals "no more occurrences" with ValueError.
        with suppress(ValueError):
            while True:
                position = find(value, position + 1, stop)
                yield position
        return

    # Slow path for general iterables: examine each item in the range.
    candidates = islice(iterable, start, stop)
    for position, element in enumerate(candidates, start):
        if element is value or element == value:
            yield position
883
884
def sieve(n):
    """Yield, in increasing order, the primes that are less than *n*.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    """
    # This implementation comes from an older version of the itertools
    # documentation.  The newer implementation is easier to read but is
    # less lazy.
    if n > 2:
        yield 2
    resume = 3
    # flags[i] == 1 marks i as a prime candidate; even positions start at 0.
    flags = bytearray((0, 1)) * (n // 2)
    root_limit = isqrt(n) + 1
    for p in iter_index(flags, 1, resume, stop=root_limit):
        square = p * p
        stride = p + p
        # Candidates below p*p that are still flagged are prime.
        yield from iter_index(flags, 1, resume, square)
        # Clear the odd multiples of p from p*p onward.
        flags[square:n:stride] = bytes(len(range(square, n, stride)))
        resume = square
    yield from iter_index(flags, 1, resume)
904
905
906def _batched(iterable, n, *, strict=False): # pragma: no cover
907 """Batch data into tuples of length *n*. If the number of items in
908 *iterable* is not divisible by *n*:
909 * The last batch will be shorter if *strict* is ``False``.
910 * :exc:`ValueError` will be raised if *strict* is ``True``.
911
912 >>> list(batched('ABCDEFG', 3))
913 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
914
915 On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
916 """
917 if n < 1:
918 raise ValueError('n must be at least one')
919 iterator = iter(iterable)
920 while batch := tuple(islice(iterator, n)):
921 if strict and len(batch) != n:
922 raise ValueError('batched(): incomplete batch')
923 yield batch
924
925
# 0x30D00A2 is the ``sys.hexversion`` encoding of Python 3.13.0 alpha 2.
# From that version on, delegate to the built-in ``itertools.batched``;
# on older interpreters use the pure-Python ``_batched`` defined above.
if hexversion >= 0x30D00A2:  # pragma: no cover
    from itertools import batched as itertools_batched

    # Thin wrapper that forwards to the built-in while exposing the same
    # signature as ``_batched`` (*strict* is keyword-only).
    def batched(iterable, n, *, strict=False):
        return itertools_batched(iterable, n, strict=strict)

    # Reuse the fallback's docstring for the wrapper.
    batched.__doc__ = _batched.__doc__
else:  # pragma: no cover
    batched = _batched
935
936
937def transpose(it):
938 """Swap the rows and columns of the input matrix.
939
940 >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
941 [(1, 11), (2, 22), (3, 33)]
942
943 The caller should ensure that the dimensions of the input are compatible.
944 If the input is empty, no output will be produced.
945 """
946 return zip(*it, strict=True)
947
948
949def _is_scalar(value, stringlike=(str, bytes)):
950 "Scalars are bytes, strings, and non-iterables."
951 try:
952 iter(value)
953 except TypeError:
954 return True
955 return isinstance(value, stringlike)
956
957
def _flatten_tensor(tensor):
    "Depth-first iterator over scalars in a tensor."
    exhausted = object()
    stream = iter(tensor)
    while True:
        # Peek at the first value of the current nesting level.
        head = next(stream, exhausted)
        if head is exhausted:
            return stream
        # Put the peeked value back in front of the stream.
        stream = chain((head,), stream)
        if _is_scalar(head):
            return stream
        # Still nested: strip one level and look again.
        stream = chain.from_iterable(stream)
970
971
def reshape(matrix, shape):
    """Rearrange the values of *matrix* into a new *shape*.

    If *shape* is an integer, the matrix must be two dimensional
    and the shape is interpreted as the desired number of columns:

        >>> matrix = [(0, 1), (2, 3), (4, 5)]
        >>> cols = 3
        >>> list(reshape(matrix, cols))
        [(0, 1, 2), (3, 4, 5)]

    If *shape* is a tuple (or other iterable), the input matrix can have
    any number of dimensions. It will first be flattened and then rebuilt
    to the desired shape which can also be multidimensional:

        >>> matrix = [(0, 1), (2, 3), (4, 5)]    # Start with a 3 x 2 matrix

        >>> list(reshape(matrix, (2, 3)))        # Make a 2 x 3 matrix
        [(0, 1, 2), (3, 4, 5)]

        >>> list(reshape(matrix, (6,)))          # Make a vector of length six
        [0, 1, 2, 3, 4, 5]

        >>> list(reshape(matrix, (2, 1, 3, 1)))  # Make 2 x 1 x 3 x 1 tensor
        [(((0,), (1,), (2,)),), (((3,), (4,), (5,)),)]

    Each dimension is assumed to be uniform, either all arrays or all scalars.
    Flattening stops when the first value in a dimension is a scalar.
    Scalars are bytes, strings, and non-iterables.
    The reshape iterator stops when the requested shape is complete
    or when the input is exhausted, whichever comes first.

    """
    if isinstance(shape, int):
        flat_values = chain.from_iterable(matrix)
        return batched(flat_values, shape)

    outer_dim, *inner_dims = shape
    stream = _flatten_tensor(matrix)
    # Rebuild from the innermost dimension outward.
    for dim in reversed(inner_dims):
        stream = batched(stream, dim)
    # The outermost dimension only limits how many items are produced.
    return islice(stream, outer_dim)
1011
1012
def matmul(m1, m2):
    """Return an iterator over the rows of the matrix product of *m1* and
    *m2*.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller should ensure that the dimensions of the input matrices are
    compatible with each other.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # The product has as many columns as m2 does.
    column_count = len(m2[0])
    # Pair every row of m1 with every column of m2, row-major.
    row_column_pairs = product(m1, transpose(m2))
    entries = starmap(_sumprod, row_column_pairs)
    return batched(entries, column_count)
1026
1027
1028def _factor_pollard(n):
1029 # Return a factor of n using Pollard's rho algorithm.
1030 # Efficient when n is odd and composite.
1031 for b in range(1, n):
1032 x = y = 2
1033 d = 1
1034 while d == 1:
1035 x = (x * x + b) % n
1036 y = (y * y + b) % n
1037 y = (y * y + b) % n
1038 d = gcd(x - y, n)
1039 if d != n:
1040 return d
1041 raise ValueError('prime or under 5') # pragma: no cover
1042
1043
# Primes below 211, computed once at import time; ``factor`` uses this
# table for its trial-division pass.
_primes_below_211 = tuple(sieve(211))
1045
1046
def factor(n):
    """Yield the prime factors of *n* in ascending order, with repeated
    factors repeated. Nothing is yielded when *n* is less than 2.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]

    Finds small factors with trial division.  Larger factors are
    either verified as prime with ``is_prime`` or split into
    smaller factors with Pollard's rho algorithm.
    """

    # Corner case reduction
    if n < 2:
        return

    # Trial division reduction
    for prime in _primes_below_211:
        while n % prime == 0:
            yield prime
            n //= prime

    # Pollard's rho reduction
    large_primes = []
    # Work list that grows while it is being iterated: each composite is
    # replaced by two smaller cofactors appended at the end.
    pending = [n] if n > 1 else []
    for candidate in pending:
        # With all primes below 211 divided out, anything under 211**2
        # has no smaller factor left and must itself be prime.
        if candidate < 211**2 or is_prime(candidate):
            large_primes.append(candidate)
        else:
            divisor = _factor_pollard(candidate)
            pending.extend((divisor, candidate // divisor))
    large_primes.sort()
    yield from large_primes
1078
1079
def polynomial_eval(coefficients, x):
    """Return the value at *x* of the polynomial given by *coefficients*.

    Computes with better numeric stability than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125

    With no coefficients the result is a zero of the same type as *x*.

    Note that polynomial coefficients are specified in descending power order.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    term_count = len(coefficients)
    if not term_count:
        return type(x)(0)
    # Exponents term_count - 1 down to 0, matching the coefficient order.
    exponents = range(term_count - 1, -1, -1)
    powers = map(pow, repeat(x), exponents)
    return _sumprod(coefficients, powers)
1101
1102
def sum_of_squares(it):
    """Return the sum of the squares of the values in *it*.

    >>> sum_of_squares([10, 20, 30])
    1400

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Two copies of the same stream, multiplied pairwise and summed.
    left, right = tee(it)
    return _sumprod(left, right)
1112
1113
def polynomial_derivative(coefficients):
    """Return the coefficient list of the first derivative of a polynomial.

    Evaluate the derivative of ``x³ - 4 x² - 17 x + 60``:

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]

    Note that polynomial coefficients are specified in descending power order.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    term_count = len(coefficients)
    # Exponents of every term except the constant, highest first; zip stops
    # before the constant term, which vanishes on differentiation.
    exponents = range(term_count - 1, 0, -1)
    return [
        coefficient * exponent
        for coefficient, exponent in zip(coefficients, exponents)
    ]
1131
1132
def totient(n):
    """Return how many natural numbers up to *n* are coprime with *n*.

    This is Euler's totient function φ(n), the number of totatives of
    *n*.  Totatives are the integers k with 1 ≤ k ≤ n and gcd(n, k) = 1.

    >>> n = 9
    >>> totient(n)
    6

    >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
    >>> totatives
    [1, 2, 4, 5, 7, 8]
    >>> len(totatives)
    6

    Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    remaining = n
    # Each distinct prime factor p removes one p-th of what is left.
    for p in set(factor(n)):
        remaining -= remaining // p
    return remaining
1155
1156
# Miller–Rabin primality test: https://oeis.org/A014233
#
# Each entry is a ``(limit, bases)`` pair.  is_prime() scans the table from
# the top, stops at the first entry whose limit exceeds *n*, and runs the
# strong probable prime test with that entry's bases.  Because the first
# match wins, the entries have to stay in ascending order of limit.
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    # The next limit is 2**64.
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        3317044064679887385961981,
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]
1171
1172
@lru_cache
def _shift_to_odd(n):
    """Split *n* into a power of two and an odd cofactor.

    Returns ``(s, d)`` with ``2**s * d == n`` and *d* odd.
    """
    # XOR with the predecessor sets exactly the bits up to and including
    # the lowest set bit of n, so the bit length is one more than the
    # number of trailing zero bits.
    low_bits = (n - 1) ^ n
    exponent = low_bits.bit_length() - 1
    odd_part = n >> exponent
    assert (1 << exponent) * odd_part == n and odd_part & 1 and exponent >= 0
    return exponent, odd_part
1180
1181
def _strong_probable_prime(n, base):
    """Return True if odd *n* > 2 passes one Miller–Rabin round for *base*.

    Writing ``n - 1 == 2**s * d`` with *d* odd, the round passes when
    ``base**d`` is congruent to 1 or -1 modulo *n*, or when one of the
    following ``s - 1`` repeated squarings is congruent to -1.
    """
    assert (n > 2) and (n & 1) and (2 <= base < n)

    minus_one = n - 1
    squarings, odd_part = _shift_to_odd(minus_one)

    residue = pow(base, odd_part, n)
    if residue in (1, minus_one):
        return True

    for _ in range(squarings - 1):
        residue = residue * residue % n
        if residue == minus_one:
            return True

    return False
1197
1198
# Separate instance of Random() that doesn't share state
# with the default user instance of Random().  is_prime() draws its
# random test bases from this generator, so those draws do not consume
# values from the module-level functions in ``random``.
_private_randrange = random.Random().randrange
1202
1203
def is_prime(n):
    """Report whether *n* is a prime number.

    Returns ``True`` for primes and ``False`` for everything else.

    Basic examples:

        >>> is_prime(37)
        True
        >>> is_prime(3 * 13)
        False
        >>> is_prime(18_446_744_073_709_551_557)
        True

    Find the next prime over one billion:

        >>> next(filter(is_prime, count(10**9)))
        1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

        >>> from random import seed, randrange, getrandbits
        >>> seed(18675309)

        >>> next(filter(is_prime, map(getrandbits, repeat(200))))
        893303929355758292373272075469392561129886005037663238028407

        >>> next(filter(is_prime, map(randrange, repeat(10**60))))
        269638077304026462407872868003560484232362454342414618963649

    This function is exact for values of *n* below 10**24.  For larger inputs,
    the probabilistic Miller-Rabin primality test has a less than 1 in 2**128
    chance of a false positive.
    """

    # Everything below 17 is settled by direct lookup.
    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Cheap rejection of multiples of the primes below 17.
    if not n & 1 or any(not n % p for p in (3, 5, 7, 11, 13)):
        return False

    # Use the fixed bases of the first table entry whose limit exceeds n.
    bases = next(
        (fixed for limit, fixed in _perfect_tests if n < limit), None
    )
    if bases is None:
        # Beyond the table: fall back to 64 lazily drawn random bases.
        bases = (_private_randrange(2, n - 1) for _ in range(64))

    return all(_strong_probable_prime(n, base) for base in bases)
1250
1251
def loops(n):
    """Return an iterable of *n* placeholder items for cheap counting loops.

    Works like ``range(n)`` for looping a fixed number of times, but no
    integer objects are created: every item is ``None``.

        >>> i = 0
        >>> for _ in loops(5):
        ...     i += 1
        >>> i
        5

    """
    placeholder = None
    return repeat(placeholder, n)
1264
1265
def multinomial(*counts):
    """Count the distinct arrangements of a multiset.

    *counts* gives the multiplicity of each kind of element.  The
    expression ``multinomial(3, 4, 2)`` can be read in several
    equivalent ways:

    * In the expansion of ``(a + b + c)⁹``, the coefficient of the
      ``a³b⁴c²`` term is 1260.

    * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4
      greens, and 2 blues.

    * There are 1260 unique ways to place 9 distinct objects into three bins
      with sizes 3, 4, and 2.

    The :func:`multinomial` function computes the length of
    :func:`distinct_permutations`.  For example, there are 83,160 distinct
    anagrams of the word "abracadabra":

        >>> from more_itertools import distinct_permutations, ilen
        >>> ilen(distinct_permutations('abracadabra'))
        83160

    This can be computed directly from the letter counts, 5a 2b 2r 1c 1d:

        >>> from collections import Counter
        >>> list(Counter('abracadabra').values())
        [5, 2, 2, 1, 1]
        >>> multinomial(5, 2, 2, 1, 1)
        83160

    With exactly two categories this is a binomial coefficient: the
    number of ways to arrange 12 balls with 5 reds and 7 blues is
    ``multinomial(5, 7)`` or ``math.comb(12, 5)``.

    With every multiplicity equal to 1 this is a factorial, so
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    Reference: https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    # Place each group in turn: choose its slots out of all the
    # positions used so far, including its own.
    running_totals = accumulate(counts)
    return prod(
        comb(total, size) for total, size in zip(running_totals, counts)
    )
1309
1310
def _running_median_minheap_and_maxheap(iterator):  # pragma: no cover
    """Non-windowed running_median() for Python 3.14+.

    The values seen so far are split between a max-heap holding the
    lower half and a min-heap holding the upper half.  Exhaustion of
    *iterator* surfaces as StopIteration from the bound ``__next__``
    and ends the generator.
    """
    pull = iterator.__next__
    lower = []  # max-heap
    upper = []  # min-heap (same size as or one smaller than lower)

    with suppress(StopIteration):
        while True:
            # Odd count: route the new value through the upper half and
            # move that half's smallest into the lower half.
            promoted = heappushpop(upper, pull())
            heappush_max(lower, promoted)
            yield lower[0]

            # Even count: route the new value through the lower half and
            # move that half's largest into the upper half.
            demoted = heappushpop_max(lower, pull())
            heappush(upper, demoted)
            yield (lower[0] + upper[0]) / 2
1325
1326
def _running_median_minheap_only(iterator):  # pragma: no cover
    """Backport of non-windowed running_median() for Python 3.13 and prior.

    The lower half of the values is kept in a min-heap of negated
    values, which stands in for a max-heap.  Exhaustion of *iterator*
    surfaces as StopIteration from the bound ``__next__`` and ends the
    generator.
    """
    pull = iterator.__next__
    neg_lower = []  # max-heap (actually a minheap with negated values)
    upper = []  # min-heap (same size as or one smaller than neg_lower)

    with suppress(StopIteration):
        while True:
            # Odd count: the smallest of the upper half moves down.
            promoted = heappushpop(upper, pull())
            heappush(neg_lower, -promoted)
            yield -neg_lower[0]

            # Even count: the largest of the lower half moves up.
            neg_demoted = heappushpop(neg_lower, -pull())
            heappush(upper, -neg_demoted)
            yield (upper[0] - neg_lower[0]) / 2
1341
1342
def _running_median_windowed(iterator, maxlen):
    """Yield the median of the most recent *maxlen* values of *iterator*.

    A deque remembers arrival order so that the oldest value can be
    evicted, and a parallel sorted list supplies the middle element(s).
    """
    recent = deque()
    in_order = []

    for value in iterator:
        recent.append(value)
        insort(in_order, value)

        if len(in_order) > maxlen:
            # Evict the oldest value from the sorted view as well.
            expired = recent.popleft()
            in_order.pop(bisect_left(in_order, expired))

        mid, is_odd = divmod(len(in_order), 2)
        if is_odd:
            yield in_order[mid]
        else:
            yield (in_order[mid - 1] + in_order[mid]) / 2
1360
1361
def running_median(iterable, *, maxlen=None):
    """Yield the median of the values seen so far, or of a sliding window.

    Pass a positive integer as *maxlen* to limit the median to that
    many most recent values.  The default of *None* means the window is
    unbounded and every value seen so far takes part.

    For example:

        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0]))
        [5.0, 7.0, 5.0, 7.0, 8.0, 8.5]
        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0], maxlen=3))
        [5.0, 7.0, 5.0, 9.0, 8.0, 9.0]

    Raises ``ValueError`` as soon as it is called if *maxlen* is zero or
    negative.

    Supports numeric types such as int, float, Decimal, and Fraction,
    but not complex numbers which are unorderable.

    On Python 3.13 and earlier, max-heaps are simulated with negated
    values.  The negation causes Decimal inputs to apply context
    rounding, making the results slightly different than that obtained
    by statistics.median().
    """

    iterator = iter(iterable)

    if maxlen is None:
        # Unbounded window: pick the heap strategy this interpreter supports.
        if _max_heap_available:
            return _running_median_minheap_and_maxheap(iterator)  # pragma: no cover
        return _running_median_minheap_only(iterator)  # pragma: no cover

    maxlen = index(maxlen)
    if maxlen <= 0:
        raise ValueError('Window size should be positive')
    return _running_median_windowed(iterator, maxlen)
1397
1398
def running_mean(iterable):
    """Yield the mean of the values consumed so far, one per input value.

        >>> iterable = [100, 200, 600]
        >>> all_means = running_mean(iterable)
        >>> next(all_means)
        100.0
        >>> next(all_means)
        150.0
        >>> next(all_means)
        300.0
    """
    # Divide each running total by the number of values behind it.
    running_totals = accumulate(iterable)
    sizes = count(1)
    return map(truediv, running_totals, sizes)
1412
1413
def random_derangement(iterable):
    """Pick one derangement of the elements of *iterable* at random.

    A derangement is a reordering in which no element stays in its
    original position.  The result is a tuple.  An empty input gives an
    empty tuple, and a single-element input raises ``IndexError``
    because no derangement exists.

    Equivalent to but much faster than ``choice(list(derangements(iterable)))``.

    """
    pool = tuple(iterable)
    size = len(pool)
    if size == 0:
        return ()
    if size == 1:
        raise IndexError('No derangments to choose from')

    order = list(range(size))
    # Holds the very same int objects as ``order``, so an identity test
    # is enough to spot an index that is still in its starting slot.
    home = tuple(order)

    # Reshuffle until no index remains in its starting slot.
    shuffle(order)
    while any(map(is_, home, order)):
        shuffle(order)
    return itemgetter(*order)(pool)