"""Imported from the recipes section of the itertools documentation.

All functions are taken from the recipes section of the itertools library
docs [1]_.
Some backward-compatible usability improvements have been made.

.. [1] https://docs.python.org/3/library/itertools.html#itertools-recipes

"""
10
11import random
12
13from bisect import bisect_left, insort
14from collections import deque
15from contextlib import suppress
16from functools import lru_cache, reduce
17from heapq import heappush, heappushpop
18from itertools import (
19 accumulate,
20 chain,
21 combinations,
22 compress,
23 count,
24 cycle,
25 groupby,
26 islice,
27 pairwise,
28 product,
29 repeat,
30 starmap,
31 takewhile,
32 tee,
33 zip_longest,
34)
35from math import prod, comb, isqrt, gcd
36from operator import mul, is_, not_, itemgetter, getitem, index
37from random import randrange, sample, choice, shuffle
38from sys import hexversion
39
# Public API of this module. Some names listed here (e.g. ``loops``,
# ``multinomial``, ``random_derangement``, ``running_median``) are not defined
# in this part of the file; presumably they are defined further down.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_derangement',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'running_median',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]

# Module-private sentinel object. Its uses are not visible in this part of
# the file (NOTE(review): presumably a "no argument supplied" marker).
_marker = object()


# heapq max-heap functions are available for Python 3.14+
# ``_max_heap_available`` records whether the import succeeded; presumably
# later code checks it to choose between the native max-heap helpers and a
# fallback (not visible here).
try:
    from heapq import heappush_max, heappushpop_max
except ImportError:  # pragma: no cover
    _max_heap_available = False
else:  # pragma: no cover
    _max_heap_available = True
103
104
def take(n, iterable):
    """Collect the first *n* items of *iterable* into a list.

    >>> take(3, range(10))
    [0, 1, 2]

    When *iterable* holds fewer than *n* items, every item is returned.

    >>> take(10, range(3))
    [0, 1, 2]

    """
    head = islice(iterable, n)
    return list(head)
119
120
def tabulate(function, start=0):
    """Return an iterator over the results of ``function(start)``,
    ``function(start + 1)``, ``function(start + 2)``...

    *function* should be a callable that accepts one integer argument.

    If *start* is not specified it defaults to 0. It will be incremented each
    time the iterator is advanced.

    >>> square = lambda x: x ** 2
    >>> iterator = tabulate(square, -3)
    >>> take(4, iterator)
    [9, 4, 1, 0]

    """
    # count() supplies start, start + 1, ... lazily, so the result is infinite.
    return map(function, count(start))
137
138
def tail(n, iterable):
    """Return an iterator over the final *n* items of *iterable*.

    Inputs that support :func:`len` are sliced lazily. Any other iterable is
    consumed through a bounded :class:`collections.deque`.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    try:
        length = len(iterable)
    except TypeError:
        # Unsized input: keep only the most recent n items while iterating.
        window = deque(iterable, maxlen=n)
        return iter(window)
    skip = max(0, length - n)
    return islice(iterable, skip, None)
153
154
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps. If *n* is ``None``, consume it
    entirely.

    Efficiently exhausts an iterator without returning values. Defaults to
    consuming the whole iterator, but an optional second argument may be
    provided to limit consumption.

    >>> i = (x for x in range(10))
    >>> next(i)
    0
    >>> consume(i, 3)
    >>> next(i)
    4
    >>> consume(i)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    If the iterator has fewer items remaining than the provided limit, the
    whole iterator will be consumed.

    >>> i = (x for x in range(3))
    >>> consume(i, 5)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    """
    # Use functions that consume iterators at C speed.
    if n is None:
        # feed the entire iterator into a zero-length deque
        deque(iterator, maxlen=0)
    else:
        # advance to the empty slice starting at position n
        next(islice(iterator, n, n), None)
193
194
def nth(iterable, n, default=None):
    """Return the item at position *n* of *iterable*, or *default* when the
    iterable runs out first.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    remainder = islice(iterable, n, None)
    return next(remainder, default)
206
207
def all_equal(iterable, key=None):
    """
    Return ``True`` when every element equals every other element.

    >>> all_equal('aaaa')
    True
    >>> all_equal('aaab')
    False

    Pass *key*, a one-argument function, to compare transformed values
    instead of the items themselves:

    >>> all_equal('AaaA', key=str.casefold)
    True
    >>> all_equal([1, 2, 3], key=lambda x: x < 10)
    True

    """
    # groupby() opens a new group whenever the (keyed) value changes, so the
    # input is uniform exactly when at most one group appears.
    leading_groups = islice(groupby(iterable, key), 2)
    return len(list(leading_groups)) <= 1
232
233
def quantify(iterable, pred=bool):
    """Return the number of items in *iterable* for which *pred* is true.

    >>> quantify([True, False, True])
    2

    """
    return sum(pred(item) for item in iterable)
242
243
def pad_none(iterable):
    """Yield every element of *iterable*, then ``None`` forever after.

    >>> take(5, pad_none(range(3)))
    [0, 1, 2, None, None]

    Handy for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    endless_nones = repeat(None)
    return chain(iterable, endless_nones)
256
257
# Backward-compatible alias for pad_none().
padnone = pad_none
259
260
def ncycles(iterable, n):
    """Yield the elements of *iterable* over and over, *n* times in total.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    # Snapshot the input once so that single-pass iterators can be replayed.
    saved = tuple(iterable)
    return chain.from_iterable(repeat(saved, n))
269
270
def dotproduct(vec1, vec2):
    """Return the sum of the pairwise products of two iterables.

    >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
    33.5
    >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
    33.5

    Pairing stops at the end of the shorter input.

    In Python 3.12 and later, use ``math.sumprod()`` instead.
    """
    return sum(a * b for a, b in zip(vec1, vec2))
282
283
# math.sumprod is available for Python 3.12+
# On older versions fall back to dotproduct(), which also returns the sum of
# pairwise products. NOTE: math.sumprod raises ValueError for inputs of
# unequal length, whereas dotproduct() stops at the shorter input (map()
# truncates), so callers should pass equal-length inputs.
try:
    from math import sumprod as _sumprod
except ImportError:  # pragma: no cover
    _sumprod = dotproduct
289
290
def flatten(listOfLists):
    """Remove one level of nesting, yielding each inner item in turn.

    >>> list(flatten([[0, 1], [2, 3]]))
    [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    nested = listOfLists
    return chain.from_iterable(nested)
301
302
def repeatfunc(func, times=None, *args):
    """Repeatedly call *func* with *args*, returning an iterable of the
    results.

    When *times* is given, the iterable stops after that many calls:

    >>> from operator import add
    >>> times = 4
    >>> args = 3, 5
    >>> list(repeatfunc(add, times, *args))
    [8, 8, 8, 8]

    When *times* is ``None``, the iterable never ends:

    >>> from random import randrange
    >>> times = None
    >>> args = 1, 11
    >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
    [2, 4, 8, 1, 8, 4]

    """
    arg_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(func, arg_stream)
328
329
330def grouper(iterable, n, incomplete='fill', fillvalue=None):
331 """Group elements from *iterable* into fixed-length groups of length *n*.
332
333 >>> list(grouper('ABCDEF', 3))
334 [('A', 'B', 'C'), ('D', 'E', 'F')]
335
336 The keyword arguments *incomplete* and *fillvalue* control what happens for
337 iterables whose length is not a multiple of *n*.
338
339 When *incomplete* is `'fill'`, the last group will contain instances of
340 *fillvalue*.
341
342 >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
343 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
344
345 When *incomplete* is `'ignore'`, the last group will not be emitted.
346
347 >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
348 [('A', 'B', 'C'), ('D', 'E', 'F')]
349
350 When *incomplete* is `'strict'`, a `ValueError` will be raised.
351
352 >>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
353 >>> list(iterator) # doctest: +IGNORE_EXCEPTION_DETAIL
354 Traceback (most recent call last):
355 ...
356 ValueError
357
358 """
359 iterators = [iter(iterable)] * n
360 match incomplete:
361 case 'fill':
362 return zip_longest(*iterators, fillvalue=fillvalue)
363 case 'strict':
364 return zip(*iterators, strict=True)
365 case 'ignore':
366 return zip(*iterators)
367 case _:
368 raise ValueError('Expected fill, strict, or ignore')
369
370
def roundrobin(*iterables):
    """Visit input iterables in a cycle until each is exhausted.

    >>> list(roundrobin('ABC', 'D', 'EF'))
    ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`, but
    may perform better for some inputs (in particular when the number of
    iterables is small).

    """
    # Algorithm credited to George Sakkis
    # Each pass cycles over the currently active iterators and yields one item
    # from each in turn. When one of them is exhausted, next() raises
    # StopIteration inside map(), which ends that pass. The cycle is then
    # positioned just after the exhausted iterator, so taking the next
    # num_active - 1 entries from it selects every remaining iterator (in the
    # correct order) and drops the exhausted one.
    iterators = map(iter, iterables)
    for num_active in range(len(iterables), 0, -1):
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)
387
388
def partition(pred, iterable):
    """
    Split *iterable* into a pair of iterables.
    The first yields the items for which ``pred(item) == False``.
    The second yields the items for which ``pred(item) == True``.

    >>> is_odd = lambda x: x % 2 != 0
    >>> iterable = range(10)
    >>> even_items, odd_items = partition(is_odd, iterable)
    >>> list(even_items), list(odd_items)
    ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    When *pred* is None, :func:`bool` is used.

    >>> iterable = [0, 1, False, True, '', ' ']
    >>> false_items, true_items = partition(None, iterable)
    >>> list(false_items), list(true_items)
    ([0, False, ''], [1, True, ' '])

    """
    predicate = bool if pred is None else pred

    # Three views of the input: one per output, plus one to compute the
    # predicate exactly once per item.
    false_source, true_source, probe = tee(iterable, 3)
    flags_for_false, flags_for_true = tee(map(predicate, probe))
    falses = compress(false_source, map(not_, flags_for_false))
    trues = compress(true_source, flags_for_true)
    return falses, trues
415
416
def powerset(iterable):
    """Yield every subset of *iterable*, from smallest to largest.

    >>> list(powerset([1, 2, 3]))
    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    :func:`powerset` does not deduplicate its input the way a :class:`set`
    would, so repeated input elements lead to repeated subsets.

    >>> seq = [1, 1, 0]
    >>> list(powerset(seq))
    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.
    """
    pool = list(iterable)
    sizes = range(len(pool) + 1)
    return chain.from_iterable(map(combinations, repeat(pool), sizes))
436
437
def unique_everseen(iterable, key=None):
    """
    Yield each distinct element the first time it appears, keeping order.

    >>> list(unique_everseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D']
    >>> list(unique_everseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'D']

    Hashable and unhashable items may be mixed. Unhashable items are tracked
    in a list, which makes them slower to check (i.e., `O(n^2)`).

    ``list`` objects are unhashable; pass ``key=tuple`` to turn them into
    hashable tuples and keep the fast path.

    >>> iterable = ([1, 2], [2, 3], [1, 2])
    >>> list(unique_everseen(iterable))  # Slow
    [[1, 2], [2, 3]]
    >>> list(unique_everseen(iterable, key=tuple))  # Faster
    [[1, 2], [2, 3]]

    Likewise, ``key=frozenset`` handles ``set`` objects, and
    ``key=lambda x: frozenset(x.items())`` handles ``dict`` objects.

    """
    hashable_seen = set()
    remember_hashable = hashable_seen.add
    unhashable_seen = []
    remember_unhashable = unhashable_seen.append

    for item in iterable:
        marker = item if key is None else key(item)
        try:
            if marker in hashable_seen:
                continue
            remember_hashable(marker)
            yield item
        except TypeError:
            # The marker cannot be hashed; fall back to a linear search.
            if marker not in unhashable_seen:
                remember_unhashable(marker)
                yield item
481
482
def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing each run of consecutive duplicates
    to its first member.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    """
    if key is not None:
        # Group keys are transformed values; take the first real member of
        # each run instead.
        runs = map(itemgetter(1), groupby(iterable, key))
        return map(next, runs)

    # Without a key, each group's key is itself the representative element.
    return map(itemgetter(0), groupby(iterable))
496
497
def unique(iterable, key=None, reverse=False):
    """Yield the distinct elements of *iterable* in sorted order.

    >>> list(unique([[1, 2], [3, 4], [1, 2]]))
    [[1, 2], [3, 4]]

    *key* and *reverse* are forwarded to :func:`sorted`.

    >>> list(unique('ABBcCAD', str.casefold))
    ['A', 'B', 'c', 'D']
    >>> list(unique('ABBcCAD', str.casefold, reverse=True))
    ['D', 'c', 'B', 'A']

    Elements need not be hashable, but they must be mutually comparable so
    that they can be sorted.
    """
    ordered = sorted(iterable, key=key, reverse=reverse)
    # Sorting makes equal elements adjacent, so keeping the first member of
    # each run of equal (keyed) values leaves one of each.
    runs = groupby(ordered, key)
    if key is None:
        return map(itemgetter(0), runs)
    return map(next, map(itemgetter(1), runs))
516
517
def iter_except(func, exception, first=None):
    """Call *func* over and over, yielding its results, until it raises
    *exception*.

    This adapts a call-until-exception interface into an iterator. It is like
    ``iter(func, sentinel)``, except that an exception ends the loop instead
    of a sentinel value.

    >>> l = [0, 1, 2]
    >>> list(iter_except(l.pop, IndexError))
    [2, 1, 0]

    A tuple of exceptions may be given as the stopping condition:

    >>> l = [1, 2, 3, '...', 4, 5, 6]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [7, 6, 5]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [4, 3, 2]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    []

    If *first* is given, it is called once before *func* is called.

    """
    # An optional one-shot initializer followed by func, forever.
    producers = chain([first] if first is not None else [], repeat(func))
    with suppress(exception):
        for produce in producers:
            yield produce()
545
546
def first_true(iterable, default=None, pred=None):
    """
    Return the first truthy value in *iterable*.

    *default* is returned when no such value exists.

    When *pred* is not None, return the first item for which
    ``pred(item) == True`` instead.

    >>> first_true(range(10))
    1
    >>> first_true(range(10), pred=lambda x: x > 5)
    6
    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
    'missing'

    """
    matches = filter(pred, iterable)
    return next(matches, default)
565
566
def random_product(*iterables, repeat=1):
    """Pick one item at random from each of the input iterables.

    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
    ('c', 3, 'Z')

    When *repeat* is given as a keyword argument, that many picks are made
    from each iterable.

    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
    ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*args, repeat=repeat)``.

    """
    pools = [tuple(pool) for pool in iterables] * repeat
    return tuple(choice(pool) for pool in pools)
585
586
def random_permutation(iterable, r=None):
    """Return a random ordering of *r* elements drawn from *iterable*.

    *r* defaults to the length of *iterable* when omitted or ``None``.

    >>> random_permutation(range(5))  # doctest:+SKIP
    (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    if r is None:
        r = len(pool)
    return tuple(sample(pool, r))
603
604
def random_combination(iterable, r):
    """Return a random subsequence of *r* elements of *iterable*, preserving
    their original relative order.

    >>> random_combination(range(5), 3)  # doctest:+SKIP
    (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    chosen = sample(range(len(pool)), r)
    chosen.sort()
    return tuple(pool[i] for i in chosen)
619
620
def random_combination_with_replacement(iterable, r):
    """Return a random subsequence of *r* elements of *iterable*, where any
    element may be picked more than once.

    >>> random_combination_with_replacement(range(3), 5)  # doctest:+SKIP
    (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    size = len(pool)
    picks = [randrange(size) for _ in range(r)]
    picks.sort()
    return tuple(pool[i] for i in picks)
636
637
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

    >>> nth_combination(range(5), 3, 5)
    (0, 3, 4)

    A negative *index* counts from the end, as with sequence indexing.

    ``ValueError`` will be raised if *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    # Total number of r-length combinations, i.e. the size of the index space.
    # math.comb() raises ValueError for a negative r.
    c = comb(n, r)

    # Support negative indexes the same way sequences do.
    if index < 0:
        index += c
    if not 0 <= index < c:
        raise IndexError

    result = []
    while r:
        # comb(n, r) * r // n == comb(n - 1, r - 1): after this update, c is
        # the number of combinations whose next chosen element is
        # pool[-1 - n].
        c, n, r = c * r // n, n - 1, r - 1
        # Skip whole blocks of combinations until *index* lands inside one;
        # each skip moves on to the next candidate element
        # (comb(n, r) * (n - r) // n == comb(n - 1, r)).
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)
670
671
def prepend(value, iterator):
    """Yield *value* first, then every element of *iterator*.

    >>> value = '0'
    >>> iterator = ['1', '2', '3']
    >>> list(prepend(value, iterator))
    ['0', '1', '2', '3']

    To put several values in front, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterator)
685
686
def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

    >>> list(convolve([1, -1, -20], [1, -3]))
    [1, -4, -17, 60]

    Common kernels include:

    * ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      On image data, this blurs the image and reduces noise.
    * ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function sampled at evenly spaced inputs.
    * ``[1, -2, 1]`` estimates the second derivative of a
      function sampled at evenly spaced inputs.

    Convolution is mathematically commutative, but the two inputs are
    handled differently here. The signal is read lazily and may be
    infinite. The kernel is read in full before any output is produced.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article:  https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown:  https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # Taken from an older version of the itertools docs. The newer recipe
    # reads a bit more clearly, but this inlined window is faster and skips
    # a needless deque-to-tuple conversion on every step.
    reversed_kernel = tuple(kernel)[::-1]
    width = len(reversed_kernel)
    # Pre-fill with zeros so the first outputs see implicit leading zeros.
    window = deque(repeat(0, width), maxlen=width)
    # Trailing zeros flush the end of the signal through the window.
    padded_signal = chain(signal, repeat(0, width - 1))
    for sample_value in padded_signal:
        window.append(sample_value)
        yield _sumprod(reversed_kernel, window)
728
729
def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that allows complete access to the
    remainder of the iterator.

    >>> it = iter('ABCdEfGhI')
    >>> all_upper, remainder = before_and_after(str.isupper, it)
    >>> ''.join(all_upper)
    'ABC'
    >>> ''.join(remainder) # takewhile() would lose the 'd'
    'dEfGhI'

    Note that the first iterator must be fully consumed before the second
    iterator can generate valid results.
    """
    # Two independent views of the same input.
    trues, after = tee(it)
    # compress() pulls a datum from takewhile() first and only then a selector
    # from zip(after). As a result, *after* is advanced in lockstep past each
    # item that satisfied *predicate*. When takewhile() stops at the first
    # failing item, compress() stops without advancing *after*, so that item
    # is the first one *after* yields. The 1-tuples from zip() are always
    # truthy, so compress() never filters anything out.
    trues = compress(takewhile(predicate, trues), zip(after))
    return trues, after
747
748
def triplewise(iterable):
    """Yield overlapping triples of consecutive items from *iterable*.

    >>> list(triplewise('ABCDE'))
    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # Deliberately differs from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Stagger the three copies by 0, 1 and 2 positions.
    for _ in range(2):
        next(third, None)
    next(second, None)
    return zip(first, second, third)
763
764
765def _sliding_window_islice(iterable, n):
766 # Fast path for small, non-zero values of n.
767 iterators = tee(iterable, n)
768 for i, iterator in enumerate(iterators):
769 next(islice(iterator, i, i), None)
770 return zip(*iterators)
771
772
773def _sliding_window_deque(iterable, n):
774 # Normal path for other values of n.
775 iterator = iter(iterable)
776 window = deque(islice(iterator, n - 1), maxlen=n)
777 for x in iterator:
778 window.append(x)
779 yield tuple(window)
780
781
def sliding_window(iterable, n):
    """Yield a sliding window of width *n* over *iterable*.

    >>> list(sliding_window(range(6), 4))
    [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]

    Nothing is yielded when *iterable* holds fewer than *n* items:

    >>> list(sliding_window(range(3), 4))
    []

    For a variant with more features, see :func:`windowed`.
    """
    # Dispatch on window width: deque for wide windows, staggered tee()
    # copies for narrow ones, and builtins for widths 2 and 1.
    if n > 20:
        return _sliding_window_deque(iterable, n)
    if n > 2:
        return _sliding_window_islice(iterable, n)
    if n == 2:
        return pairwise(iterable)
    if n == 1:
        return zip(iterable)
    raise ValueError(f'n should be at least one, not {n}')
805
806
def subslices(iterable):
    """Yield every contiguous, non-empty slice of *iterable* as a list.

    >>> list(subslices('ABC'))
    [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    :func:`substrings` is similar, but yields its items in a different
    order.
    """
    seq = list(iterable)
    # Every (start, stop) pair with start < stop, in lexicographic order.
    bounds = combinations(range(len(seq) + 1), 2)
    return (seq[start:stop] for start, stop in bounds)
819
820
def polynomial_from_roots(roots):
    """Return the coefficients of the monic polynomial with the given roots.

    >>> roots = [5, -4, 3]  # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
    [1, -4, -17, 60]

    Coefficients are listed from the highest power down to the constant.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Unlike the itertools docs recipe, each intermediate product is
    # materialized with list(); otherwise many roots would stack up nested
    # convolve() generators and could hit the recursion limit.
    coefficients = [1]
    for root in roots:
        linear_term = (1, -root)
        coefficients = list(convolve(coefficients, linear_term))
    return coefficients
841
842
def iter_index(iterable, value, start=0, stop=None):
    """Yield each index of *iterable* at which *value* occurs, searching from
    index *start* up to but not including index *stop*.


    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
    [1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
    [1, 4]

    Non-scalar *values* behave as they do for the built-in Python types.

    >>> list(iter_index('ABCDABCD', 'AB'))
    [0, 4]
    >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
    []
    >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
    [0, 2]

    :func:`locate` offers a more general way to find the indexes tied to
    particular values.

    """
    seq_index = getattr(iterable, 'index', None)
    if seq_index is not None:
        # Fast path: let the sequence's own index() method do the search.
        if stop is None:
            stop = len(iterable)
        position = start
        with suppress(ValueError):
            while True:
                position = seq_index(value, position, stop)
                yield position
                position += 1
    else:
        # Slow path for arbitrary iterables: compare element by element.
        window = islice(iterable, start, stop)
        for position, element in enumerate(window, start):
            if element is value or element == value:
                yield position
882
883
def sieve(n):
    """Yield the primes less than n.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    """
    # This implementation comes from an older version of the itertools
    # documentation. The newer implementation is easier to read but is
    # less lazy.
    if n > 2:
        yield 2
    start = 3
    # data[i] == 1 marks i as a prime candidate. The repeated (0, 1) pattern
    # marks only odd indexes, so even numbers are excluded from the start.
    # The length 2 * (n // 2) is at most n, so every index is below n.
    data = bytearray((0, 1)) * (n // 2)
    # Each still-marked p up to isqrt(n) is prime.
    for p in iter_index(data, 1, start, stop=isqrt(n) + 1):
        # Every number still marked in [start, p * p) is prime, because any
        # composite below p * p has a smaller prime factor that has already
        # been sieved. Emitting these now is what keeps the generator lazy.
        yield from iter_index(data, 1, start, p * p)
        # Clear the odd multiples of p from p * p onward (step 2p skips evens).
        data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
        start = p * p
    # Everything still marked from the last checkpoint onward is prime.
    yield from iter_index(data, 1, start)
903
904
905def _batched(iterable, n, *, strict=False): # pragma: no cover
906 """Batch data into tuples of length *n*. If the number of items in
907 *iterable* is not divisible by *n*:
908 * The last batch will be shorter if *strict* is ``False``.
909 * :exc:`ValueError` will be raised if *strict* is ``True``.
910
911 >>> list(batched('ABCDEFG', 3))
912 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
913
914 On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
915 """
916 if n < 1:
917 raise ValueError('n must be at least one')
918 iterator = iter(iterable)
919 while batch := tuple(islice(iterator, n)):
920 if strict and len(batch) != n:
921 raise ValueError('batched(): incomplete batch')
922 yield batch
923
924
# From Python 3.13.0a2 (hexversion 0x30D00A2) onward, delegate to
# itertools.batched, which accepts the *strict* keyword used below. On older
# versions, fall back to the pure-Python _batched() defined above.
if hexversion >= 0x30D00A2:  # pragma: no cover
    from itertools import batched as itertools_batched

    # A thin wrapper rather than a plain alias; presumably so the public
    # function can carry this module's docstring (assigned just below).
    def batched(iterable, n, *, strict=False):
        return itertools_batched(iterable, n, strict=strict)

    batched.__doc__ = _batched.__doc__
else:  # pragma: no cover
    batched = _batched
934
935
936def transpose(it):
937 """Swap the rows and columns of the input matrix.
938
939 >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
940 [(1, 11), (2, 22), (3, 33)]
941
942 The caller should ensure that the dimensions of the input are compatible.
943 If the input is empty, no output will be produced.
944 """
945 return zip(*it, strict=True)
946
947
948def _is_scalar(value, stringlike=(str, bytes)):
949 "Scalars are bytes, strings, and non-iterables."
950 try:
951 iter(value)
952 except TypeError:
953 return True
954 return isinstance(value, stringlike)
955
956
def _flatten_tensor(tensor):
    "Return a depth-first iterator over the scalars in a tensor."
    iterator = iter(tensor)
    while True:
        # Peek at the first value of the current nesting level.
        try:
            value = next(iterator)
        except StopIteration:
            # Empty level: hand back the exhausted iterator.
            return iterator
        # Push the peeked value back onto the front of the stream.
        iterator = chain((value,), iterator)
        if _is_scalar(value):
            # This level holds scalars, so the stream is fully flattened.
            return iterator
        # Otherwise treat every item at this level as iterable (uniform
        # dimensions are assumed) and strip one level of nesting.
        iterator = chain.from_iterable(iterator)
969
970
def reshape(matrix, shape):
    """Return *matrix* rearranged into a new *shape*.

    An integer *shape* requires a two-dimensional matrix and is read as the
    desired number of columns:

    >>> matrix = [(0, 1), (2, 3), (4, 5)]
    >>> cols = 3
    >>> list(reshape(matrix, cols))
    [(0, 1, 2), (3, 4, 5)]

    With a tuple (or other iterable) *shape*, the input may have any number
    of dimensions. It is flattened first and then regrouped into the
    requested, possibly multidimensional, shape:

    >>> matrix = [(0, 1), (2, 3), (4, 5)]    # Start with a 3 x 2 matrix

    >>> list(reshape(matrix, (2, 3)))        # Make a 2 x 3 matrix
    [(0, 1, 2), (3, 4, 5)]

    >>> list(reshape(matrix, (6,)))          # Make a vector of length six
    [0, 1, 2, 3, 4, 5]

    >>> list(reshape(matrix, (2, 1, 3, 1)))  # Make 2 x 1 x 3 x 1 tensor
    [(((0,), (1,), (2,)),), (((3,), (4,), (5,)),)]

    Every dimension is assumed to be uniform: either all arrays or all
    scalars. Flattening stops at the first dimension whose leading value is
    a scalar. Bytes, strings, and non-iterables count as scalars.
    Iteration ends once the requested shape is filled or the input runs
    out, whichever happens first.

    """
    if isinstance(shape, int):
        # Column count for a 2-D matrix: concatenate the rows and re-chunk.
        return batched(chain.from_iterable(matrix), shape)
    first_dim, *dims = shape
    stream = _flatten_tensor(matrix)
    # Build the nesting from the innermost dimension outward.
    for size in reversed(dims):
        stream = batched(stream, size)
    return islice(stream, first_dim)
1010
1011
def matmul(m1, m2):
    """Return the matrix product of *m1* and *m2*, one row tuple at a time.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller is responsible for passing matrices with compatible
    dimensions.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    columns = len(m2[0])
    # Each row of m1 paired with each column of m2, in row-major order.
    row_col_pairs = product(m1, transpose(m2))
    cells = starmap(_sumprod, row_col_pairs)
    return batched(cells, columns)
1025
1026
1027def _factor_pollard(n):
1028 # Return a factor of n using Pollard's rho algorithm.
1029 # Efficient when n is odd and composite.
1030 for b in range(1, n):
1031 x = y = 2
1032 d = 1
1033 while d == 1:
1034 x = (x * x + b) % n
1035 y = (y * y + b) % n
1036 y = (y * y + b) % n
1037 d = gcd(x - y, n)
1038 if d != n:
1039 return d
1040 raise ValueError('prime or under 5') # pragma: no cover
1041
1042
# Primes below 211, used for trial division in factor(). Because 211 is
# prime, any cofactor below 211**2 that survives trial division is prime.
_primes_below_211 = tuple(sieve(211))
1044
1045
def factor(n):
    """Yield the prime factors of n in ascending order, repeated according
    to multiplicity.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]

    Trial division removes the small factors. Each remaining cofactor is
    either confirmed prime with ``is_prime`` or split further with
    Pollard's rho algorithm.
    """
    # Nothing to factor for n < 2.
    if n < 2:
        return

    # Trial division strips every prime factor below 211.
    for p in _primes_below_211:
        while n % p == 0:
            yield p
            n //= p

    # The remainder has no factor below 211. Split composites with
    # Pollard's rho until only primes are left (FIFO work list).
    found = []
    pending = [n] if n > 1 else []
    for candidate in pending:
        # A cofactor below 211**2 with no factor below 211 must be prime.
        if candidate < 211**2 or is_prime(candidate):
            found.append(candidate)
        else:
            divisor = _factor_pollard(candidate)
            pending.extend((divisor, candidate // divisor))
    yield from sorted(found)
1077
1078
def polynomial_eval(coefficients, x):
    """Return the value of a polynomial at *x*.

    More numerically stable than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125

    Coefficients are listed from the highest power down to the constant.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    term_count = len(coefficients)
    if not term_count:
        # Empty polynomial: a zero of the same type as x.
        return type(x)(0)
    exponents = reversed(range(term_count))
    powers = (pow(x, e) for e in exponents)
    return _sumprod(coefficients, powers)
1100
1101
def sum_of_squares(it):
    """Return the sum of the squares of the values in *it*.

    >>> sum_of_squares([10, 20, 30])
    1400

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Two copies of the same stream, so each value is multiplied by itself.
    left, right = tee(it)
    return _sumprod(left, right)
1111
1112
def polynomial_derivative(coefficients):
    """Return the coefficients of a polynomial's first derivative.

    Differentiate ``x³ - 4 x² - 17 x + 60``:

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]

    Coefficients are listed from the highest power down to the constant.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    degree = len(coefficients) - 1
    # The coefficient of x**k becomes k * coefficient. The constant term
    # drops out because zip() stops when the exponents run out.
    exponents = range(degree, 0, -1)
    return [coef * power for coef, power in zip(coefficients, exponents)]
1130
1131
def totient(n):
    """Return Euler's totient φ(n): how many of 1..n are coprime with *n*.

    Totatives are the integers k in the range 1 ≤ k ≤ n such that
    gcd(n, k) = 1, and φ(n) counts them.

    >>> n = 9
    >>> totient(n)
    6

    >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
    >>> totatives
    [1, 2, 4, 5, 7, 8]
    >>> len(totatives)
    6

    Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    result = n
    # Apply the product formula φ(n) = n * Π(1 - 1/p) over the distinct prime
    # factors p.  Each step divides exactly, because every remaining p still
    # divides the running result.
    for p in set(factor(n)):
        result -= result // p
    return result
1154
1155
# Miller–Rabin primality test: https://oeis.org/A014233
#
# Deterministic Miller-Rabin base sets.  Each entry is ``(limit, bases)``:
# is_prime() scans the list in order, takes the first entry with
# ``n < limit`` and runs the strong probable-prime test only for those bases.
# Per the published deterministic Miller-Rabin results (see A014233 and the
# literature it cites), passing every listed base proves primality for odd
# n below the limit.
#
# Keep the entries sorted by increasing limit.  Small n must select the
# small base sets, because _strong_probable_prime() asserts ``base < n``.
# For n at or above the last limit, is_prime() falls back to random bases.
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    # 18446744073709551616 == 2**64
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        3317044064679887385961981,
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]
1170
1171
1172@lru_cache
1173def _shift_to_odd(n):
1174 'Return s, d such that 2**s * d == n'
1175 s = ((n - 1) ^ n).bit_length() - 1
1176 d = n >> s
1177 assert (1 << s) * d == n and d & 1 and s >= 0
1178 return s, d
1179
1180
def _strong_probable_prime(n, base):
    """Return True if odd *n* is a strong probable prime to *base*.

    One Miller-Rabin round: write ``n - 1 == 2**s * d`` with *d* odd.  *n*
    passes if ``base**d ≡ 1 (mod n)``, or if ``base**(2**r * d) ≡ -1 (mod n)``
    for some ``0 <= r < s``.  A False result proves *n* composite.
    """
    assert (n > 2) and (n & 1) and (2 <= base < n)

    s, d = _shift_to_odd(n - 1)
    minus_one = n - 1

    x = pow(base, d, n)
    if x in (1, minus_one):
        return True

    # Square up to s - 1 more times, looking for -1 mod n.
    remaining = s - 1
    while remaining > 0:
        x = x * x % n
        if x == minus_one:
            return True
        remaining -= 1

    return False
1196
1197
# Separate instance of Random() that doesn't share state
# with the default user instance of Random().  is_prime() draws its random
# Miller-Rabin bases from this bound method.  A user's random.seed() call
# therefore does not influence primality testing, and primality testing
# does not consume values from the user's random stream.
_private_randrange = random.Random().randrange
1201
1202
def is_prime(n):
    """Return ``True`` if *n* is prime and ``False`` otherwise.

    Basic examples:

    >>> is_prime(37)
    True
    >>> is_prime(3 * 13)
    False
    >>> is_prime(18_446_744_073_709_551_557)
    True

    Find the next prime over one billion:

    >>> next(filter(is_prime, count(10**9)))
    1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

    >>> from random import seed, randrange, getrandbits
    >>> seed(18675309)

    >>> next(filter(is_prime, map(getrandbits, repeat(200))))
    893303929355758292373272075469392561129886005037663238028407

    >>> next(filter(is_prime, map(randrange, repeat(10**60))))
    269638077304026462407872868003560484232362454342414618963649

    The answer is exact for every *n* below 10**24, which is covered by
    known deterministic Miller-Rabin base sets.  Larger inputs get 64
    rounds with random bases, so a composite is reported as prime with
    probability less than 1 in 2**128.
    """

    # Everything below 17 is settled by direct lookup.
    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Cheap rejection of multiples of the primes below 17.
    if not n & 1:
        return False
    if any(not n % p for p in (3, 5, 7, 11, 13)):
        return False

    # Use the first (smallest) deterministic base set whose bound covers n.
    bases = next(
        (candidates for limit, candidates in _perfect_tests if n < limit),
        None,
    )
    if bases is None:
        # Beyond every known bound: fall back to random witnesses.
        bases = (_private_randrange(2, n - 1) for _ in range(64))

    return all(_strong_probable_prime(n, base) for base in bases)
1249
1250
def loops(n):
    """Return an iterator that yields exactly *n* items, for cheap looping.

    It behaves like ``range(n)`` in a ``for`` statement, but every item
    is ``None``, so no integer objects are created.

    >>> i = 0
    >>> for _ in loops(5):
    ...     i += 1
    >>> i
    5

    """
    filler = None
    return repeat(filler, n)
1263
1264
def multinomial(*counts):
    """Count the distinct arrangements of a multiset with the given *counts*.

    ``multinomial(3, 4, 2)`` can be read in several equivalent ways:

    * In the expansion of ``(a + b + c)⁹``, the coefficient of the
      ``a³b⁴c²`` term is 1260.

    * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4
      greens, and 2 blues.

    * There are 1260 unique ways to place 9 distinct objects into three bins
      with sizes 3, 4, and 2.

    :func:`multinomial` gives the length of :func:`distinct_permutations`
    without enumerating anything.  For example, the word "abracadabra" has
    83,160 distinct anagrams:

        >>> from more_itertools import distinct_permutations, ilen
        >>> ilen(distinct_permutations('abracadabra'))
        83160

    The same number follows directly from the letter counts 5a 2b 2r 1c 1d:

        >>> from collections import Counter
        >>> list(Counter('abracadabra').values())
        [5, 2, 2, 1, 1]
        >>> multinomial(5, 2, 2, 1, 1)
        83160

    With only two categories this reduces to a binomial coefficient.  For
    example, 12 balls with 5 reds and 7 blues give ``multinomial(5, 7)``,
    which equals ``math.comb(12, 5)``.

    With every multiplicity equal to 1 it reduces to a factorial:
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    Called with no arguments it returns 1.

    Reference: https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    result = 1
    total = 0
    # Place each category in turn: choose its k slots among the total placed so far.
    for k in counts:
        total += k
        result *= comb(total, k)
    return result
1308
1309
def _running_median_minheap_and_maxheap(iterator):  # pragma: no cover
    """Yield the cumulative median of *iterator* (Python 3.14+ max-heaps).

    *lower* holds the smaller half as a max-heap and *upper* the larger half
    as a min-heap.  *upper* always has the same size as *lower* or one fewer,
    so the median is lower[0] or the mean of both tops.
    """
    pull = iterator.__next__
    lower = []  # max-heap
    upper = []  # min-heap (same size as or one smaller than lower)

    with suppress(StopIteration):
        while True:
            # Odd count: pass the new value through upper so that lower
            # receives the smallest value of the upper half.
            heappush_max(lower, heappushpop(upper, pull()))
            yield lower[0]

            # Even count: pass the new value through lower so that upper
            # receives the largest value of the lower half.
            heappush(upper, heappushpop_max(lower, pull()))
            yield (lower[0] + upper[0]) / 2
1324
1325
1326def _running_median_minheap_only(iterator): # pragma: no cover
1327 "Backport of non-windowed running_median() for Python 3.13 and prior."
1328
1329 read = iterator.__next__
1330 lo = [] # max-heap (actually a minheap with negated values)
1331 hi = [] # min-heap (same size as or one smaller than lo)
1332
1333 with suppress(StopIteration):
1334 while True:
1335 heappush(lo, -heappushpop(hi, read()))
1336 yield -lo[0]
1337
1338 heappush(hi, -heappushpop(lo, -read()))
1339 yield (hi[0] - lo[0]) / 2
1340
1341
1342def _running_median_windowed(iterator, maxlen):
1343 "Yield median of values in a sliding window."
1344
1345 window = deque()
1346 ordered = []
1347
1348 for x in iterator:
1349 window.append(x)
1350 insort(ordered, x)
1351
1352 if len(ordered) > maxlen:
1353 i = bisect_left(ordered, window.popleft())
1354 del ordered[i]
1355
1356 n = len(ordered)
1357 m = n // 2
1358 yield ordered[m] if n & 1 else (ordered[m - 1] + ordered[m]) / 2
1359
1360
def running_median(iterable, *, maxlen=None):
    """Yield the median of all values seen so far, or of a sliding window.

    Leave *maxlen* as *None* for a cumulative (unbounded) median, or pass a
    positive integer to restrict the median to that many most recent values.
    A *maxlen* that is not a positive integer raises immediately, before
    any value is consumed.

    For example:

        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0]))
        [5.0, 7.0, 5.0, 7.0, 8.0, 8.5]
        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0], maxlen=3))
        [5.0, 7.0, 5.0, 9.0, 8.0, 9.0]

    Works with orderable numeric types such as int, float, Decimal, and
    Fraction.  Complex numbers are unorderable and therefore unsupported.

    On Python 3.13 and earlier, the max-heap is emulated by negating values.
    For Decimal inputs the negation applies context rounding, so results can
    differ slightly from statistics.median().
    """

    iterator = iter(iterable)

    if maxlen is None:
        if _max_heap_available:
            return _running_median_minheap_and_maxheap(iterator)  # pragma: no cover
        return _running_median_minheap_only(iterator)  # pragma: no cover

    window_size = index(maxlen)
    if window_size <= 0:
        raise ValueError('Window size should be positive')
    return _running_median_windowed(iterator, window_size)
1396
1397
def random_derangement(iterable):
    """Return a random derangement of elements in the iterable.

    A derangement is a permutation in which no element stays at its original
    position.  Positions, not values, are compared, so equal values may
    still coincide.

    Equivalent to but much faster than ``choice(list(derangements(iterable)))``,
    including its edge cases:

    * An empty input has exactly one (empty) derangement, so ``()`` is
      returned.
    * A single-element input has none, so :exc:`IndexError` is raised, just
      as ``choice([])`` would.

    The result is a tuple.  Randomness comes from :func:`random.shuffle`, so
    :func:`random.seed` makes it reproducible.

    """
    seq = tuple(iterable)
    if len(seq) < 2:
        if len(seq) == 0:
            return ()
        raise IndexError('No derangements to choose from')
    perm = list(range(len(seq)))
    # ``start`` must hold the very same int objects as ``perm``.  Each index
    # object is distinct, so ``start[i] is perm[i]`` holds exactly when
    # position i is fixed.  That keeps the identity test correct even for
    # indices outside CPython's small-int cache.
    start = tuple(perm)
    # Rejection sampling: about 1/e of random permutations are derangements,
    # so the expected number of shuffles is roughly e.
    while True:
        shuffle(perm)
        if not any(map(is_, start, perm)):
            return itemgetter(*perm)(seq)