1"""Imported from the recipes section of the itertools documentation.
2
3All functions taken from the recipes section of the itertools library docs
4[1]_.
5Some backward-compatible usability improvements have been made.
6
7.. [1] http://docs.python.org/library/itertools.html#recipes
8
9"""
10
11import random
12
13from bisect import bisect_left, insort
14from collections import deque
15from contextlib import suppress
16from functools import lru_cache, reduce
17from heapq import heappush, heappushpop
18from itertools import (
19 accumulate,
20 chain,
21 combinations,
22 compress,
23 count,
24 cycle,
25 filterfalse,
26 groupby,
27 islice,
28 pairwise as itertools_pairwise,
29 product,
30 repeat,
31 starmap,
32 takewhile,
33 tee,
34 zip_longest,
35)
36from math import prod, comb, isqrt, gcd
37from operator import mul, getitem, index, is_, itemgetter, truediv
38from random import randrange, sample, choice, shuffle
39from sys import hexversion
40
# Public names exported by ``import *``. Several entries (for example
# 'is_prime', 'loops', 'multinomial', 'random_derangement', 'running_mean',
# 'running_median') are presumably defined further down in this module;
# confirm when editing this list.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'pairwise',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_derangement',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'running_mean',
    'running_median',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]
95
# Private sentinel: a fresh ``object()`` is equal only to itself, so it can
# stand for "no value" where ``None`` may be legitimate data. Presumably used
# by functions later in this module.
_marker = object()
97
98
# heapq max-heap functions are available for Python 3.14+
# _max_heap_available records whether heappush_max / heappushpop_max could be
# imported, so later code can choose between them and a fallback (the
# consumers are not part of this section).
try:
    from heapq import heappush_max, heappushpop_max
except ImportError:  # pragma: no cover
    _max_heap_available = False
else:  # pragma: no cover
    _max_heap_available = True
106
107
def take(n, iterable):
    """Collect the first *n* items of *iterable* into a list.

    >>> take(3, range(10))
    [0, 1, 2]

    A shorter iterable simply yields everything it has:

    >>> take(10, range(3))
    [0, 1, 2]

    """
    first_n = islice(iterable, n)
    return [*first_n]
122
123
def tabulate(function, start=0):
    """Lazily produce ``function(start)``, ``function(start + 1)``,
    ``function(start + 2)``, and so on, without end.

    *function* must accept a single integer argument. *start* defaults to
    0 and grows by one each time the iterator is advanced.

    >>> square = lambda x: x ** 2
    >>> iterator = tabulate(square, -3)
    >>> take(4, iterator)
    [9, 4, 1, 0]

    """
    arguments = count(start)
    return map(function, arguments)
140
141
def tail(n, iterable):
    """Return an iterator over the final *n* items of *iterable*.

    Inputs that support ``len()`` are sliced directly. Anything else is
    streamed through a bounded deque that keeps only the last *n* items.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    try:
        length = len(iterable)
    except TypeError:
        # No len(): stream everything, retaining only the last n items.
        return iter(deque(iterable, maxlen=n))
    skip = max(0, length - n)
    return islice(iterable, skip, None)
156
157
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps, or exhaust it completely when *n*
    is ``None`` (the default).

    The skipped items are discarded and nothing is returned.

    >>> i = (x for x in range(10))
    >>> next(i)
    0
    >>> consume(i, 3)
    >>> next(i)
    4
    >>> consume(i)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    Asking for more steps than remain simply exhausts the iterator.

    >>> i = (x for x in range(3))
    >>> consume(i, 5)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    """
    if n is not None:
        # An empty islice positioned at n pulls (and drops) up to n items.
        next(islice(iterator, n, n), None)
        return
    # A zero-length deque swallows every remaining item at C speed.
    deque(iterator, maxlen=0)
196
197
def nth(iterable, n, default=None):
    """Return the item at position *n* of *iterable*, or *default* when the
    iterable has no such position.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    remainder = islice(iterable, n, None)
    return next(remainder, default)
209
210
def all_equal(iterable, key=None):
    """
    Return ``True`` if all the elements compare equal, else ``False``.
    An empty iterable counts as all-equal.

    >>> all_equal('aaaa')
    True
    >>> all_equal('aaab')
    False

    Pass *key*, a one-argument function, to compare transformed values
    instead of the elements themselves:

    >>> all_equal('AaaA', key=str.casefold)
    True
    >>> all_equal([1, 2, 3], key=lambda x: x < 10)
    True

    """
    runs = groupby(iterable, key)
    # groupby() starts a new run whenever the (keyed) value changes, so the
    # elements are all equal exactly when there is at most one run.
    # Its (key, group) pairs are never None, so None means "exhausted".
    if next(runs, None) is None:
        return True
    return next(runs, None) is None
235
236
def quantify(iterable, pred=bool):
    """Count how many items of *iterable* satisfy *pred*.

    The results of ``pred(item)`` are summed, so a boolean predicate gives
    the number of items for which it is true.

    >>> quantify([True, False, True])
    2

    """
    results = map(pred, iterable)
    return sum(results)
245
246
def pad_none(iterable):
    """Yield every element of *iterable*, then ``None`` forever after.

    >>> take(5, pad_none(range(3)))
    [0, 1, 2, None, None]

    Useful for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    endless_nones = repeat(None)
    return chain(iterable, endless_nones)
259
260
# Backward-compatible alias for pad_none().
padnone = pad_none
262
263
def ncycles(iterable, n):
    """Yield the elements of *iterable*, in order, *n* times over.

    *iterable* is read once into a tuple, so one-shot iterators work too.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    saved = tuple(iterable)
    copies = repeat(saved, n)
    return chain.from_iterable(copies)
272
273
def dotproduct(vec1, vec2):
    """Return the sum of the element-wise products of *vec1* and *vec2*.

    >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
    33.5
    >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
    33.5

    Pairing stops at the end of the shorter input.

    In Python 3.12 and later, use ``math.sumprod()`` instead.
    """
    products = map(mul, vec1, vec2)
    return sum(products)
285
286
# math.sumprod is available for Python 3.12+
# On older interpreters, fall back to the pure-Python dotproduct() above.
# convolve(), matmul(), polynomial_eval() and sum_of_squares() call
# _sumprod, so they work on either version.
try:
    from math import sumprod as _sumprod
except ImportError:  # pragma: no cover
    _sumprod = dotproduct
292
293
def flatten(list_of_lists):
    """Lazily chain the inner iterables of *list_of_lists* together,
    removing exactly one level of nesting.

    >>> list(flatten([[0, 1], [2, 3]]))
    [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    flattened = chain.from_iterable(list_of_lists)
    return flattened
304
305
def repeatfunc(function, times=None, *args):
    """Return an iterator over the results of calling ``function(*args)``
    again for every item.

    An integer *times* caps the number of calls:

    >>> from operator import add
    >>> times = 4
    >>> args = 3, 5
    >>> list(repeatfunc(add, times, *args))
    [8, 8, 8, 8]

    With *times* left as ``None``, the iterator never ends:

    >>> from random import randrange
    >>> times = None
    >>> args = 1, 11
    >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
    [2, 4, 8, 1, 8, 4]

    """
    argument_tuples = repeat(args) if times is None else repeat(args, times)
    return starmap(function, argument_tuples)
331
332
333def pairwise(iterable):
334 """
335 Wrapper for :func:`itertools.pairwise`.
336
337 .. warning::
338
339 This function is deprecated as of version 11.0.0. It will be removed in a future
340 major release.
341 """
342 return itertools_pairwise(iterable)
343
344
345def grouper(iterable, n, incomplete='fill', fillvalue=None):
346 """Group elements from *iterable* into fixed-length groups of length *n*.
347
348 >>> list(grouper('ABCDEF', 3))
349 [('A', 'B', 'C'), ('D', 'E', 'F')]
350
351 The keyword arguments *incomplete* and *fillvalue* control what happens for
352 iterables whose length is not a multiple of *n*.
353
354 When *incomplete* is `'fill'`, the last group will contain instances of
355 *fillvalue*.
356
357 >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
358 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
359
360 When *incomplete* is `'ignore'`, the last group will not be emitted.
361
362 >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
363 [('A', 'B', 'C'), ('D', 'E', 'F')]
364
365 When *incomplete* is `'strict'`, a `ValueError` will be raised.
366
367 >>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
368 >>> list(iterator) # doctest: +IGNORE_EXCEPTION_DETAIL
369 Traceback (most recent call last):
370 ...
371 ValueError
372
373 """
374 iterators = [iter(iterable)] * n
375 match incomplete:
376 case 'fill':
377 return zip_longest(*iterators, fillvalue=fillvalue)
378 case 'strict':
379 return zip(*iterators, strict=True)
380 case 'ignore':
381 return zip(*iterators)
382 case _:
383 raise ValueError('Expected fill, strict, or ignore')
384
385
def roundrobin(*iterables):
    """Visit input iterables in a cycle until each is exhausted.

    >>> list(roundrobin('ABC', 'D', 'EF'))
    ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`, but
    may perform better for some inputs (in particular when the number of
    iterables is small).

    """
    # Algorithm credited to George Sakkis
    # Lazily turn every input into an iterator.
    iterators = map(iter, iterables)
    # Each pass runs until one iterator is exhausted; the next pass keeps
    # one fewer iterator.
    for num_active in range(len(iterables), 0, -1):
        # islice takes the next num_active iterators from the previous
        # cycle, which resumes just after the one that ran dry, so the
        # survivors keep their round-robin order.
        iterators = cycle(islice(iterators, num_active))
        # map(next, ...) ends as soon as any iterator raises StopIteration.
        yield from map(next, iterators)
402
403
def partition(pred, iterable):
    """
    Split *iterable* into two lazy iterators according to *pred*.

    Returns ``(false_items, true_items)``. The first yields the items for
    which ``pred(item)`` is false, and the second those for which it is true.
    Each item is passed to *pred* once. Items destined for the other
    iterator are buffered, so the two may be consumed in any order.

    >>> is_odd = lambda x: x % 2 != 0
    >>> iterable = range(10)
    >>> even_items, odd_items = partition(is_odd, iterable)
    >>> list(even_items), list(odd_items)
    ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    If *pred* is None, :func:`bool` is used.

    >>> iterable = [0, 1, False, True, '', ' ']
    >>> false_items, true_items = partition(None, iterable)
    >>> list(false_items), list(true_items)
    ([0, False, ''], [1, True, ' '])

    """
    predicate = bool if pred is None else pred
    source = iter(iterable)
    falses = deque()
    trues = deque()

    def _drain(pending):
        # Serve buffered items first. Otherwise pull one item from the shared
        # source, route it to the matching buffer, and look again.
        while True:
            if pending:
                yield pending.popleft()
                continue
            for item in source:
                (trues if predicate(item) else falses).append(item)
                break
            else:
                return

    return _drain(falses), _drain(trues)
442
443
def powerset(iterable):
    """Yield every subset of *iterable* as a tuple, from smallest to largest.

    >>> list(powerset([1, 2, 3]))
    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    The input is treated as a sequence, not a :class:`set`, so repeated
    elements produce repeated subsets:

    >>> seq = [1, 1, 0]
    >>> list(powerset(seq))
    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.
    """
    items = list(iterable)
    by_size = (combinations(items, size) for size in range(len(items) + 1))
    return chain.from_iterable(by_size)
463
464
def unique_everseen(iterable, key=None):
    """Yield each element the first time it (or its *key*) is seen,
    preserving order. Every element ever seen is remembered.

    >>> list(unique_everseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D']
    >>> list(unique_everseen('ABBCcAD', str.casefold))
    ['A', 'B', 'C', 'D']

    Raises ``TypeError`` for unhashable items.

    Some unhashable objects can be converted to hashable objects
    using the *key* parameter:

    * For ``list`` objects, try ``key=tuple``.
    * For ``set`` objects, try ``key=frozenset``.
    * For ``dict`` objects, try ``key=lambda x: frozenset(x.items())``
      or in Python 3.15 and later, set ``key=frozendict``.

    Alternatively, consider the ``unique()`` itertool recipe. It sorts
    the data and then uses equality to eliminate duplicates. Hashability
    is not required.

    """
    seen = set()
    remember = seen.add
    if key is None:
        for element in iterable:
            if element not in seen:
                remember(element)
                yield element
        return
    for element in iterable:
        fingerprint = key(element)
        if fingerprint not in seen:
            remember(fingerprint)
            yield element
499
500
def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing each run of consecutive
    duplicates (optionally compared by *key*) to its first element.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    """
    if key is not None:
        # Take the first original element out of each run of equal keys.
        runs = map(itemgetter(1), groupby(iterable, key))
        return map(next, runs)
    # Without a key, each run's group key is its representative element.
    return map(itemgetter(0), groupby(iterable))
514
515
def unique(iterable, key=None, reverse=False):
    """Yield each distinct element once, in sorted order.

    >>> list(unique([[1, 2], [3, 4], [1, 2]]))
    [[1, 2], [3, 4]]

    *key* and *reverse* are forwarded to :func:`sorted` (and *key* is also
    used to decide which neighbours are duplicates).

    >>> list(unique('ABBcCAD', str.casefold))
    ['A', 'B', 'c', 'D']
    >>> list(unique('ABBcCAD', str.casefold, reverse=True))
    ['D', 'c', 'B', 'A']

    Elements need not be hashable, but they must be orderable so that they
    can be sorted.
    """
    # Sorting puts equal elements next to each other, so dropping
    # consecutive duplicates leaves exactly one of each.
    return unique_justseen(sorted(iterable, key=key, reverse=reverse), key=key)
534
535
def iter_except(function, exception, first=None):
    """Call *function* over and over, yielding each result, until it raises
    *exception*.

    This adapts a call-until-exception interface into an iterator, much like
    ``iter(function, sentinel)`` but ending on an exception instead of a
    sentinel value. If *first* is given, its result is yielded before the
    repeated calls begin.

    >>> l = [0, 1, 2]
    >>> list(iter_except(l.pop, IndexError))
    [2, 1, 0]

    *exception* may also be a tuple of exception types:

    >>> l = [1, 2, 3, '...', 4, 5, 6]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [7, 6, 5]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [4, 3, 2]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    []

    """
    with suppress(exception):
        if first is not None:
            yield first()
        while True:
            result = function()
            yield result
563
564
def first_true(iterable, default=None, pred=None):
    """
    Return the first truthy element of *iterable*, or *default* if there
    is none.

    With a *pred*, return the first element for which ``pred(element)`` is
    truthy instead.

    >>> first_true(range(10))
    1
    >>> first_true(range(10), pred=lambda x: x > 5)
    6
    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
    'missing'

    """
    matches = filter(pred, iterable)
    return next(matches, default)
583
584
def random_product(*iterables, repeat=1):
    """Pick one random item from each of the input iterables.

    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
    ('c', 3, 'Z')

    The keyword-only *repeat* draws that many rounds over the iterables:

    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
    ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*iterables, repeat=repeat)``.

    """
    pools = [tuple(pool) for pool in iterables]
    return tuple(choice(pool) for pool in pools * repeat)
603
604
def random_permutation(iterable, r=None):
    """Return a random ordering of *r* distinct elements of *iterable*.

    When *r* is omitted or ``None``, every element is used.

    >>> random_permutation(range(5))  # doctest:+SKIP
    (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    if r is None:
        r = len(pool)
    return tuple(sample(pool, r))
621
622
def random_combination(iterable, r):
    """Return *r* randomly chosen elements of *iterable*, kept in their
    original relative order.

    >>> random_combination(range(5), 3)  # doctest:+SKIP
    (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    positions = sample(range(len(pool)), r)
    positions.sort()
    return tuple(pool[pos] for pos in positions)
637
638
def random_combination_with_replacement(iterable, r):
    """Return *r* randomly chosen elements of *iterable*, where the same
    element may be picked more than once; picks keep the input order.

    >>> random_combination_with_replacement(range(3), 5)  # doctest:+SKIP
    (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    size = len(pool)
    picks = sorted([randrange(size) for _ in range(r)])
    return tuple(map(pool.__getitem__, picks))
654
655
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

    >>> nth_combination(range(5), 3, 5)
    (0, 3, 4)

    A negative *index* counts from the end, as with list indexing.

    ``ValueError`` will be raised if *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid; this
    includes every *index* when *r* exceeds the length of *iterable*,
    because there are then no combinations at all.
    """
    pool = tuple(iterable)
    n = len(pool)
    # Total number of r-length combinations. math.comb raises ValueError for
    # a negative r and returns 0 when r > n.
    c = comb(n, r)

    # Support negative indexes the way sequences do.
    if index < 0:
        index += c
    if not 0 <= index < c:
        raise IndexError

    result = []
    while r:
        # C(n, r) * r // n == C(n - 1, r - 1): after the update, c counts
        # the combinations that begin with the candidate pool[-1 - n].
        c, n, r = c * r // n, n - 1, r - 1
        # Skip candidates whose whole block of combinations lies before
        # index; C(n, r) * (n - r) // n == C(n - 1, r) is the next block.
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)
688
689
def prepend(value, iterable):
    """Yield *value* first, then every element of *iterable*.

    >>> value = '0'
    >>> iterable = ['1', '2', '3']
    >>> list(prepend(value, iterable))
    ['0', '1', '2', '3']

    To prepend multiple values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterable)
703
704
def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

    >>> list(convolve([1, -1, -20], [1, -3]))
    [1, -4, -17, 60]

    Examples of popular kinds of kernels:

    * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      For image data, this blurs the image and reduces noise.
    * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function evaluated at evenly spaced inputs.
    * The kernel ``[1, -2, 1]`` estimates the second derivative of a
      function evaluated at evenly spaced inputs.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently. The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article: https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown: https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # This implementation comes from an older version of the itertools
    # documentation. While the newer implementation is a bit clearer,
    # this one was kept because the inlined window logic is faster
    # and it avoids an unnecessary deque-to-tuple conversion.
    # Reverse the kernel so that a dot product with the window of the most
    # recent n signal values gives sum(kernel[k] * signal[i - k]).
    kernel = tuple(kernel)[::-1]
    n = len(kernel)
    # n leading zeros act as zero padding before the signal starts.
    window = deque([0], maxlen=n) * n
    # n - 1 trailing zeros flush out the tail of the full convolution.
    for x in chain(signal, repeat(0, n - 1)):
        window.append(x)
        yield _sumprod(kernel, window)
746
747
def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that allows complete access to the
    remainder of the iterator.

    >>> it = iter('ABCdEfGhI')
    >>> all_upper, remainder = before_and_after(str.isupper, it)
    >>> ''.join(all_upper)
    'ABC'
    >>> ''.join(remainder)  # takewhile() would lose the 'd'
    'dEfGhI'

    Note that the first iterator must be fully consumed before the second
    iterator can generate valid results.
    """
    # Two views of the same input.
    trues, after = tee(it)
    # takewhile() emits the leading run where predicate holds; the item that
    # fails it is consumed only from the `trues` branch. zip(after) always
    # produces truthy 1-tuples, so compress() acts as a pass-through that
    # advances `after` once per emitted item. `after` therefore still
    # begins at the first failing item (see the 'd' in the example).
    trues = compress(takewhile(predicate, trues), zip(after))
    return trues, after
765
766
def triplewise(iterable):
    """Yield each run of three consecutive items of *iterable* as a tuple.

    >>> list(triplewise('ABCDE'))
    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # This deviates from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Offset the copies by 0, 1 and 2 items respectively.
    next(second, None)
    next(third, None)
    next(third, None)
    return zip(first, second, third)
781
782
def _sliding_window_islice(iterable, n):
    # Fast path for small, non-zero values of n: make n tee'd copies, shift
    # copy k forward by k items (copy 0 stays put), and zip them together.
    copies = tee(iterable, n)
    for shift in range(1, n):
        next(islice(copies[shift], shift, shift), None)
    return zip(*copies)
789
790
def _sliding_window_deque(iterable, n):
    # Normal path for other values of n: pre-load n - 1 items into a deque
    # bounded at n, then emit a tuple snapshot after every new arrival.
    source = iter(iterable)
    window = deque(maxlen=n)
    window.extend(islice(source, n - 1))
    for item in source:
        window.append(item)
        yield tuple(window)
798
799
800def sliding_window(iterable, n):
801 """Return a sliding window of width *n* over *iterable*.
802
803 >>> list(sliding_window(range(6), 4))
804 [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]
805
806 If *iterable* has fewer than *n* items, then nothing is yielded:
807
808 >>> list(sliding_window(range(3), 4))
809 []
810
811 For a variant with more features, see :func:`windowed`.
812 """
813 if n > 20:
814 return _sliding_window_deque(iterable, n)
815 elif n > 2:
816 return _sliding_window_islice(iterable, n)
817 elif n == 2:
818 return pairwise(iterable)
819 elif n == 1:
820 return zip(iterable)
821 else:
822 raise ValueError(f'n should be at least one, not {n}')
823
824
def subslices(iterable):
    """Return every contiguous, non-empty slice of *iterable* as a list.

    >>> list(subslices('ABC'))
    [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This is similar to :func:`substrings`, but emits items in a different
    order.
    """
    seq = list(iterable)
    # Every (start, stop) pair with start < stop, in lexicographic order.
    spans = combinations(range(len(seq) + 1), 2)
    return map(seq.__getitem__, starmap(slice, spans))
837
838
def polynomial_from_roots(roots):
    """Return the coefficients of the monic polynomial with the given roots.

    >>> roots = [5, -4, 3]  # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
    [1, -4, -17, 60]

    Coefficients are listed in descending power order.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """

    # This recipe differs from the one in itertools docs in that it
    # applies list() after each call to convolve(). This avoids
    # hitting stack limits with nested generators.

    coefficients = [1]
    for root in roots:
        # Multiply by the linear factor (x - root).
        linear_factor = (1, -root)
        coefficients = list(convolve(coefficients, linear_factor))
    return coefficients
859
860
def iter_index(iterable, value, start=0, stop=None):
    """Yield every index at which *value* occurs in *iterable*, searching
    from index *start* (inclusive) up to *stop* (exclusive).

    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
    [1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
    [1, 4]

    Objects with an ``index`` method use it, so non-scalar *values* behave
    the way the built-in types do:

    >>> list(iter_index('ABCDABCD', 'AB'))
    [0, 4]
    >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
    []
    >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
    [0, 2]

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    """
    seq_index = getattr(iterable, 'index', None)
    if seq_index is not None:
        # Fast path for sequences: let the native .index() do the search,
        # resuming one past each hit until it raises ValueError.
        end = len(iterable) if stop is None else stop
        position = start - 1
        with suppress(ValueError):
            while True:
                position = seq_index(value, position + 1, end)
                yield position
        return
    # Slow path for general iterables: compare each element in the window.
    for position, element in enumerate(islice(iterable, start, stop), start):
        if element is value or element == value:
            yield position
900
901
def sieve(n):
    """Yield the primes less than n.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    """
    # This implementation comes from an older version of the itertools
    # documentation. The newer implementation is easier to read but is
    # less lazy.
    if n > 2:
        yield 2
    start = 3
    # data[k] == 1 flags k as a prime candidate. The repeated (0, 1) pattern
    # flags only odd indexes, so even numbers are never reported (2 was
    # yielded separately above).
    data = bytearray((0, 1)) * (n // 2)
    # Each flagged p up to isqrt(n) is prime: every smaller prime has
    # already cleared its multiples.
    for p in iter_index(data, 1, start, stop=isqrt(n) + 1):
        # Candidates below p * p are already final, so yield them before
        # clearing more multiples.
        yield from iter_index(data, 1, start, p * p)
        # Clear the odd multiples of p from p * p upward; the 2p step skips
        # even multiples, which were never flagged.
        data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
        start = p * p
    # Everything still flagged from start onward is prime.
    yield from iter_index(data, 1, start)
921
922
def _batched(iterable, n, *, strict=False):  # pragma: no cover
    """Group the items of *iterable* into tuples of length *n*.

    When the item count is not a multiple of *n*:

    * if *strict* is ``False``, the final tuple is shorter;
    * if *strict* is ``True``, :exc:`ValueError` is raised instead.

    >>> list(batched('ABCDEFG', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

    On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        if not chunk:
            return
        if strict and len(chunk) != n:
            raise ValueError('batched(): incomplete batch')
        yield chunk
941
942
# itertools.batched accepts the *strict* keyword from Python 3.13 onward
# (0x30D00A2 is the hexversion of 3.13.0a2). Delegate to it there, and fall
# back to the pure-Python _batched() on older interpreters.
if hexversion >= 0x30D00A2:  # pragma: no cover
    from itertools import batched as itertools_batched

    def batched(iterable, n, *, strict=False):
        # Thin wrapper so the public signature and docstring stay the same
        # on every Python version.
        return itertools_batched(iterable, n, strict=strict)

    batched.__doc__ = _batched.__doc__
else:  # pragma: no cover
    batched = _batched
952
953
954def transpose(matrix):
955 """Swap the rows and columns of the input matrix.
956
957 >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
958 [(1, 11), (2, 22), (3, 33)]
959
960 The caller should ensure that the dimensions of the input are compatible.
961 If the input is empty, no output will be produced.
962 """
963 return zip(*matrix, strict=True)
964
965
def _is_scalar(value, stringlike=(str, bytes)):
    """Return True if *value* should be treated as a single scalar.

    Anything :func:`iter` rejects counts as a scalar, and so do instances of
    *stringlike* (strings and bytes by default), even though they are
    iterable.
    """
    try:
        iter(value)
    except TypeError:
        return True
    else:
        return isinstance(value, stringlike)
973
974
def _flatten_tensor(tensor):
    "Depth-first iterator over scalars in a tensor."
    # Peek at the first value of the current level. If it is a scalar, the
    # current iterator already yields scalars and is returned. Otherwise,
    # chain the sub-iterables together to descend one level and repeat.
    # Only the first value per level is inspected; the dimensions are
    # assumed to be uniform.
    iterator = iter(tensor)
    while True:
        try:
            value = next(iterator)
        except StopIteration:
            # Empty level: nothing to flatten.
            return iterator
        # Put the peeked value back in front of the rest.
        iterator = chain((value,), iterator)
        if _is_scalar(value):
            return iterator
        # Remove one level of nesting, then inspect again.
        iterator = chain.from_iterable(iterator)
987
988
def reshape(matrix, shape):
    """Rearrange the values of *matrix* into a new *shape*.

    An integer *shape* requires a two-dimensional matrix and gives the
    desired number of columns:

    >>> matrix = [(0, 1), (2, 3), (4, 5)]
    >>> cols = 3
    >>> list(reshape(matrix, cols))
    [(0, 1, 2), (3, 4, 5)]

    A tuple (or other iterable) *shape* accepts a matrix with any number of
    dimensions. The input is first flattened and then rebuilt to the
    requested, possibly multidimensional, shape:

    >>> matrix = [(0, 1), (2, 3), (4, 5)]  # Start with a 3 x 2 matrix

    >>> list(reshape(matrix, (2, 3)))  # Make a 2 x 3 matrix
    [(0, 1, 2), (3, 4, 5)]

    >>> list(reshape(matrix, (6,)))  # Make a vector of length six
    [0, 1, 2, 3, 4, 5]

    >>> list(reshape(matrix, (2, 1, 3, 1)))  # Make 2 x 1 x 3 x 1 tensor
    [(((0,), (1,), (2,)),), (((3,), (4,), (5,)),)]

    Each dimension is assumed to be uniform, either all arrays or all
    scalars, and flattening stops at the first scalar found in a
    dimension. Scalars are bytes, strings, and non-iterables. The result
    ends when the requested shape is complete or the input runs out,
    whichever happens first.

    """
    if isinstance(shape, int):
        # Two-dimensional case: regroup the row-major stream into rows.
        return batched(chain.from_iterable(matrix), shape)
    first_dim, *inner_dims = shape
    stream = _flatten_tensor(matrix)
    # Group from the innermost dimension outward.
    for dim in reversed(inner_dims):
        stream = batched(stream, dim)
    return islice(stream, first_dim)
1028
1029
def matmul(m1, m2):
    """Return the matrix product of *m1* and *m2* as an iterator of rows.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller must make sure the two matrices have compatible dimensions.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    columns = len(m2[0])
    # product() pairs each row of m1, in order, with every column of m2;
    # each pair's dot product is one entry of the result.
    entries = starmap(_sumprod, product(m1, transpose(m2)))
    return batched(entries, columns)
1043
1044
def _factor_pollard(n):
    # Return a factor of n using Pollard's rho algorithm, trying successive
    # polynomial constants b until one gives a divisor other than n.
    # Efficient when n is odd and composite.
    for b in range(1, n):
        tortoise = hare = 2
        divisor = 1
        while divisor == 1:
            tortoise = (tortoise * tortoise + b) % n
            hare = (hare * hare + b) % n
            hare = (hare * hare + b) % n
            divisor = gcd(tortoise - hare, n)
        if divisor != n:
            return divisor
    raise ValueError('prime or under 5')  # pragma: no cover
1059
1060
# Every prime below 211, used by factor() for trial division. After dividing
# these out, any remainder below 211**2 must itself be prime.
_primes_below_211 = tuple(sieve(211))
1062
1063
def factor(n):
    """Yield the prime factors of n.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]

    Small factors are found with trial division. Whatever remains is
    either confirmed prime with ``is_prime`` or split further with
    Pollard's rho algorithm. Numbers below 2 have no prime factors.
    """

    # Corner case: nothing to factor.
    if n < 2:
        return

    # Divide out every small prime, as many times as it divides n.
    for small_prime in _primes_below_211:
        while n % small_prime == 0:
            yield small_prime
            n //= small_prime

    # The remainder has only prime factors >= 211. Split composites with
    # Pollard's rho; `pending` grows while it is being iterated.
    found = []
    pending = [n] if n > 1 else []
    for candidate in pending:
        if candidate < 211**2 or is_prime(candidate):
            found.append(candidate)
        else:
            divisor = _factor_pollard(candidate)
            pending.extend((divisor, candidate // divisor))
    yield from sorted(found)
1095
1096
def polynomial_eval(coefficients, x):
    """Evaluate the polynomial with the given *coefficients* at *x*.

    This is numerically more stable than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125

    Coefficients are listed in descending power order. An empty
    coefficient list evaluates to a zero of the same type as *x*.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    term_count = len(coefficients)
    if term_count == 0:
        return type(x)(0)
    exponents = reversed(range(term_count))
    powers = map(pow, repeat(x), exponents)
    return _sumprod(coefficients, powers)
1118
1119
def sum_of_squares(iterable):
    """Add up the square of every value in *iterable*.

    >>> sum_of_squares([10, 20, 30])
    1400

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Two independent copies of the stream, multiplied element-wise.
    left, right = tee(iterable)
    return _sumprod(left, right)
1129
1130
def polynomial_derivative(coefficients):
    """Return the coefficients of the first derivative of a polynomial.

    Differentiate ``x³ - 4 x² - 17 x + 60``:

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]

    Coefficients are listed in descending power order, for both the input
    and the result.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    n = len(coefficients)
    # Coefficient i (descending order) multiplies x**(n - 1 - i). zip() stops
    # before the constant term, whose derivative is zero.
    exponents = range(n - 1, 0, -1)
    return [coeff * power for coeff, power in zip(coefficients, exponents)]
1148
1149
def totient(n):
    """Return the count of natural numbers up to *n* that are coprime with *n*.

    Euler's totient function φ(n) gives the number of totatives.
    Totatives are integers k in the range 1 ≤ k ≤ n such that gcd(n, k) = 1.

    >>> n = 9
    >>> totient(n)
    6

    >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
    >>> totatives
    [1, 2, 4, 5, 7, 8]
    >>> len(totatives)
    6

    Uses the product formula φ(n) = n · ∏(1 - 1/p) over the distinct primes
    p dividing n.  Inputs below 2 have no prime factors and are returned
    unchanged.

    Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    result = n
    for p in set(factor(n)):
        # Scale by (1 - 1/p).  p still divides result here, so the integer
        # division is exact.
        result -= result // p
    return result
1172
1173
# Miller–Rabin primality test: https://oeis.org/A014233
#
# Deterministic witness sets, as ``(limit, bases)`` pairs listed in increasing
# order of ``limit``.  is_prime() picks the first entry with ``n < limit`` and
# declares n prime when it is a strong probable prime to every listed base;
# beyond the last limit it falls back to random bases.  Every base in an entry
# is smaller than any n that can reach that entry, as required by the
# ``base < n`` assertion in _strong_probable_prime().
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    # The limit on the next line is 2**64.
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        3317044064679887385961981,
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]
1188
1189
@lru_cache
def _shift_to_odd(n):
    """Split *n* into ``(s, d)`` such that ``2**s * d == n`` with *d* odd.

    ``(n - 1) ^ n`` is a mask of ones covering the lowest set bit of *n*
    and every bit beneath it, so its bit length minus one counts the
    trailing zero bits of *n*.  Results are memoized; _strong_probable_prime()
    requests the same ``n - 1`` once per witness base.
    """
    low_mask = (n - 1) ^ n
    shift = low_mask.bit_length() - 1
    odd = n >> shift
    assert (1 << shift) * odd == n and odd & 1 and shift >= 0
    return shift, odd
1197
1198
def _strong_probable_prime(n, base):
    """Return True if odd *n* > 2 is a strong probable prime to *base*.

    Writing ``n - 1 == 2**s * d`` with *d* odd, *n* passes when
    ``base**d % n == 1`` or ``base**(d * 2**r) % n == n - 1`` for some
    ``0 <= r < s``.
    """
    assert (n > 2) and (n & 1) and (2 <= base < n)

    n_minus_1 = n - 1
    s, d = _shift_to_odd(n_minus_1)

    y = pow(base, d, n)
    if y in (1, n_minus_1):
        return True

    # Square repeatedly (r = 1 .. s-1), watching for n - 1 along the way.
    for _ in range(1, s):
        y = pow(y, 2, n)
        if y == n_minus_1:
            return True

    return False
1214
1215
# Separate instance of Random() that doesn't share state
# with the default user instance of Random().  is_prime() draws its random
# Miller–Rabin bases from here, so calling it neither advances nor depends on
# the global generator that users may have seeded (the seeded doctests in
# is_prime() rely on this).
_private_randrange = random.Random().randrange
1219
1220
def is_prime(n):
    """Return ``True`` if *n* is prime and ``False`` otherwise.

    Basic examples:

    >>> is_prime(37)
    True
    >>> is_prime(3 * 13)
    False
    >>> is_prime(18_446_744_073_709_551_557)
    True

    Find the next prime over one billion:

    >>> next(filter(is_prime, count(10**9)))
    1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

    >>> from random import seed, randrange, getrandbits
    >>> seed(18675309)

    >>> next(filter(is_prime, map(getrandbits, repeat(200))))
    893303929355758292373272075469392561129886005037663238028407

    >>> next(filter(is_prime, map(randrange, repeat(10**60))))
    269638077304026462407872868003560484232362454342414618963649

    Values below 17 and multiples of the primes up to 13 are decided
    directly.  Everything else goes through the Miller-Rabin strong
    probable prime test.  The answer is exact for *n* below 10**24, where a
    fixed set of witness bases is known to suffice.  Larger inputs are
    checked against 64 random bases drawn from a private generator, which
    leaves less than a 1 in 2**128 chance of a false positive.
    """
    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Cheap screen: even numbers and multiples of 3, 5, 7, 11, 13.
    if not n & 1 or any(n % p == 0 for p in (3, 5, 7, 11, 13)):
        return False

    # Smallest deterministic witness set whose bound covers n, if any.
    witness_bases = next(
        (bases for limit, bases in _perfect_tests if n < limit), None
    )
    if witness_bases is None:
        witness_bases = (_private_randrange(2, n - 1) for _ in range(64))

    return all(_strong_probable_prime(n, b) for b in witness_bases)
1267
1268
def loops(n):
    """Returns an iterable with *n* elements for efficient looping.
    Like ``range(n)`` but doesn't create integers.

    The returned ``itertools.repeat`` iterator yields ``None`` *n* times,
    so it suits loops whose loop variable is never used.

    >>> i = 0
    >>> for _ in loops(5):
    ...     i += 1
    >>> i
    5

    """
    return repeat(None, n)
1281
1282
def multinomial(*counts):
    """Number of distinct arrangements of a multiset.

    ``multinomial(3, 4, 2)`` equals 1260, which can be read in several
    equivalent ways:

    * In the expansion of ``(a + b + c)⁹``, the coefficient of the
      ``a³b⁴c²`` term is 1260.

    * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4
      greens, and 2 blues.

    * There are 1260 unique ways to place 9 distinct objects into three bins
      with sizes 3, 4, and 2.

    The :func:`multinomial` function computes the length of
    :func:`distinct_permutations`. For example, there are 83,160 distinct
    anagrams of the word "abracadabra":

    >>> from more_itertools import distinct_permutations, ilen
    >>> ilen(distinct_permutations('abracadabra'))
    83160

    The same number follows directly from the letter counts, 5a 2b 2r 1c 1d:

    >>> from collections import Counter
    >>> list(Counter('abracadabra').values())
    [5, 2, 2, 1, 1]
    >>> multinomial(5, 2, 2, 1, 1)
    83160

    A binomial coefficient is a special case of multinomial where there are
    only two categories. For example, the number of ways to arrange 12 balls
    with 5 reds and 7 blues is ``multinomial(5, 7)`` or ``math.comb(12, 5)``.

    Likewise, factorial is a special case of multinomial where
    the multiplicities are all just 1 so that
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    Called with no counts, the result is 1 (the empty product).

    Reference: https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    # Add one group at a time: once `placed` slots are in use, the k members
    # of the newest group can occupy any k of them, giving comb(placed, k).
    arrangements = 1
    placed = 0
    for k in counts:
        placed += k
        arrangements *= comb(placed, k)
    return arrangements
1326
1327
def _running_median_minheap_and_maxheap(iterator): # pragma: no cover
    "Non-windowed running_median() for Python 3.14+"

    # Two-heap scheme: ``lo`` holds the smaller half of the values read so
    # far and ``hi`` the larger half.  Each loop pass reads two values and
    # yields after each read, alternating which heap grows.  After an odd
    # number of reads lo has one extra element and its top is the median.
    # After an even number the heaps are the same size and the median is the
    # mean of the two tops.

    read = iterator.__next__
    lo = [] # max-heap
    hi = [] # min-heap (same size as or one smaller than lo)

    # StopIteration from either read() ends the generator cleanly.
    with suppress(StopIteration):
        while True:
            # Route the new value through hi: only the smallest element of
            # hi plus the new value moves to lo, keeping every lo <= every hi.
            heappush_max(lo, heappushpop(hi, read()))
            yield lo[0]

            # Route the new value through lo: only the largest element of
            # lo plus the new value moves to hi, rebalancing the sizes.
            heappush(hi, heappushpop_max(lo, read()))
            yield (lo[0] + hi[0]) / 2
1342
1343
def _running_median_minheap_only(iterator): # pragma: no cover
    "Backport of non-windowed running_median() for Python 3.13 and prior."

    # Same two-heap scheme as the 3.14+ variant, but the lower half is stored
    # negated in a min-heap.  So ``-lo[0]`` is the largest value of the lower
    # half, and ``(hi[0] - lo[0]) / 2`` is the mean of the two middle values.

    read = iterator.__next__
    lo = [] # max-heap (actually a minheap with negated values)
    hi = [] # min-heap (same size as or one smaller than lo)

    # StopIteration from either read() ends the generator cleanly.
    with suppress(StopIteration):
        while True:
            # Smallest of hi plus the new value moves (negated) into lo.
            heappush(lo, -heappushpop(hi, read()))
            yield -lo[0]

            # Push the negated new value through lo.  Re-negating the popped
            # entry (the largest real value in the lower half) moves it to hi.
            heappush(hi, -heappushpop(lo, -read()))
            yield (hi[0] - lo[0]) / 2
1358
1359
def _running_median_windowed(iterator, maxlen):
    """Yield the median of the most recent *maxlen* values.

    ``arrival`` remembers insertion order so the oldest value can be
    evicted.  ``ranked`` holds the same values in sorted order so the
    median can be read off by position.
    """
    arrival = deque()
    ranked = []

    for value in iterator:
        arrival.append(value)
        insort(ranked, value)

        if len(ranked) > maxlen:
            # Remove one copy of the evicted value; any equal copy will do.
            oldest = arrival.popleft()
            del ranked[bisect_left(ranked, oldest)]

        half, odd = divmod(len(ranked), 2)
        if odd:
            yield ranked[half]
        else:
            yield (ranked[half - 1] + ranked[half]) / 2
1377
1378
def running_median(iterable, *, maxlen=None):
    """Cumulative median of values seen so far or values in a sliding window.

    Set *maxlen* to a positive integer to specify the maximum size
    of the sliding window. The default of *None* is equivalent to
    an unbounded window.

    For example:

        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0]))
        [5.0, 7.0, 5.0, 7.0, 8.0, 8.5]
        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0], maxlen=3))
        [5.0, 7.0, 5.0, 9.0, 8.0, 9.0]

    Supports numeric types such as int, float, Decimal, and Fraction,
    but not complex numbers which are unorderable.

    A non-integer *maxlen* raises ``TypeError`` and a non-positive one
    raises ``ValueError``.  Both errors are raised immediately, not when
    iteration starts.

    On Python 3.13 and prior, max-heaps are simulated with
    negative values. The negation causes Decimal inputs to apply context
    rounding, making the results slightly different than that obtained
    by statistics.median().
    """
    iterator = iter(iterable)

    if maxlen is None:
        if _max_heap_available:
            return _running_median_minheap_and_maxheap(iterator)  # pragma: no cover
        return _running_median_minheap_only(iterator)  # pragma: no cover

    window_size = index(maxlen)
    if window_size <= 0:
        raise ValueError('Window size should be positive')
    return _running_median_windowed(iterator, window_size)
1414
1415
def _windowed_running_mean(iterator, n):
    """Yield the mean of the last *n* values seen (fewer while filling up).

    A running total is maintained incrementally.  Each new value is added
    and, once the window overflows, the value falling out is subtracted.
    """
    recent = deque()
    total = 0
    for x in iterator:
        total += x
        recent.append(x)
        if len(recent) > n:
            total -= recent.popleft()
        size = len(recent)
        yield total / size
1425
1426
def running_mean(iterable, *, maxlen=None):
    """Cumulative mean of values seen so far or values in a sliding window.

    Set *maxlen* to a positive integer to specify the maximum size
    of the sliding window. The default of *None* is equivalent to
    an unbounded window.

    For example:

        >>> list(running_mean([40, 30, 50, 46, 39, 44]))
        [40.0, 35.0, 40.0, 41.5, 41.0, 41.5]

        >>> list(running_mean([40, 30, 50, 46, 39, 44], maxlen=3))
        [40.0, 35.0, 40.0, 42.0, 45.0, 43.0]

    Supports numeric types such as int, float, complex, Decimal, and Fraction.

    A non-integer *maxlen* raises ``TypeError`` and a non-positive one
    raises ``ValueError``.  Both errors are raised immediately, before any
    input is consumed.

    No extra effort is made to reduce round-off errors for float inputs.
    So the results may be slightly different from ``statistics.mean``.

    """

    iterator = iter(iterable)

    if maxlen is None:
        return map(truediv, accumulate(iterator), count(1))

    # Validate the same way running_median() does.  operator.index() accepts
    # ints (and objects implementing __index__) but rejects floats such as
    # 2.5, which would otherwise silently behave like a window of 2.
    maxlen = index(maxlen)
    if maxlen <= 0:
        raise ValueError('Window size should be positive')

    return _windowed_running_mean(iterator, maxlen)
1458
1459
def random_derangement(iterable):
    """Return a random derangement of elements in the iterable.

    Equivalent to but much faster than ``choice(list(derangements(iterable)))``.

    A derangement is a rearrangement in which no element stays at its
    original *position*.  Positions are compared, not values, so equal
    values may still line up with each other.  With two elements the only
    derangement is the swap:

    >>> random_derangement('ab')
    ('b', 'a')

    An empty input has exactly one (empty) derangement, so ``()`` is
    returned.  A single element has none, so ``IndexError`` is raised, just
    as :func:`random.choice` does for an empty sequence.

    The result is found by rejection sampling: the index list is shuffled
    until no position maps to itself.
    """
    seq = tuple(iterable)
    if len(seq) < 2:
        if len(seq) == 0:
            return ()
        raise IndexError('No derangements to choose from')
    perm = list(range(len(seq)))
    start = tuple(perm)
    while True:
        shuffle(perm)
        # ``start`` and ``perm`` hold the very same index objects, so an
        # identity check is enough to detect any fixed point.
        if not any(map(is_, start, perm)):
            return itemgetter(*perm)(seq)