1"""Imported from the recipes section of the itertools documentation.
2
3All functions taken from the recipes section of the itertools library docs
4[1]_.
5Some backward-compatible usability improvements have been made.
6
.. [1] https://docs.python.org/3/library/itertools.html#itertools-recipes
8
9"""
10
11import random
12
13from bisect import bisect_left, insort
14from collections import deque
15from contextlib import suppress
16from functools import lru_cache, reduce
17from heapq import heappush, heappushpop
18from itertools import (
19 accumulate,
20 chain,
21 combinations,
22 compress,
23 count,
24 cycle,
25 groupby,
26 islice,
27 pairwise,
28 product,
29 repeat,
30 starmap,
31 takewhile,
32 tee,
33 zip_longest,
34)
35from math import prod, comb, isqrt, gcd
36from operator import mul, getitem, index, is_, itemgetter, truediv
37from random import randrange, sample, choice, shuffle
38from sys import hexversion
39
# Names exported by ``from <module> import *``: the public API of this module.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_derangement',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'running_mean',
    'running_median',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]

# Private sentinel: a fresh object() that cannot compare identical to any
# value a caller passes in.
_marker = object()
95
96
97# heapq max-heap functions are available for Python 3.14+
98try:
99 from heapq import heappush_max, heappushpop_max
100except ImportError: # pragma: no cover
101 _max_heap_available = False
102else: # pragma: no cover
103 _max_heap_available = True
104
105
def take(n, iterable):
    """Collect the first *n* items of *iterable* into a list.

    >>> take(3, range(10))
    [0, 1, 2]

    A shorter iterable simply yields everything it has:

    >>> take(10, range(3))
    [0, 1, 2]

    """
    # islice stops after n items without pulling an extra one.
    return [*islice(iterable, n)]
120
121
def tabulate(function, start=0):
    """Return an iterator over ``function(start)``, ``function(start + 1)``,
    ``function(start + 2)``, and so on.

    *function* must accept a single integer argument. *start* defaults to 0
    and is incremented by one each time the iterator is advanced.

    >>> square = lambda x: x ** 2
    >>> iterator = tabulate(square, -3)
    >>> take(4, iterator)
    [9, 4, 1, 0]

    """
    arguments = count(start)
    return map(function, arguments)
138
139
def tail(n, iterable):
    """Return an iterator over the final *n* items of *iterable*.

    Sized inputs are sliced directly; other iterables are buffered through
    a bounded deque.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    try:
        length = len(iterable)
    except TypeError:
        # No len(): keep only the trailing n items as we stream through.
        return iter(deque(iterable, maxlen=n))
    skip = max(0, length - n)
    return islice(iterable, skip, None)
154
155
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps, or exhaust it when *n* is ``None``.

    Nothing is returned. The work happens at C speed, so this is an
    efficient way to discard items.

    >>> i = (x for x in range(10))
    >>> next(i)
    0
    >>> consume(i, 3)
    >>> next(i)
    4
    >>> consume(i)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    If fewer than *n* items remain, the iterator is simply exhausted.

    >>> i = (x for x in range(3))
    >>> consume(i, 5)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    """
    if n is not None:
        # Requesting the empty slice [n:n] makes islice skip n items first.
        next(islice(iterator, n, n), None)
        return
    # A zero-length deque discards everything fed to it.
    deque(iterator, maxlen=0)
194
195
def nth(iterable, n, default=None):
    """Return the item at position *n*, or *default* if there is none.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    remaining = islice(iterable, n, None)
    return next(remaining, default)
207
208
def all_equal(iterable, key=None):
    """
    Return ``True`` when every element equals every other one.

    >>> all_equal('aaaa')
    True
    >>> all_equal('aaab')
    False

    *key*, if given, is applied to each item before comparing:

    >>> all_equal('AaaA', key=str.casefold)
    True
    >>> all_equal([1, 2, 3], key=lambda x: x < 10)
    True

    """
    # groupby starts a new group at every change of key. Asking for at most
    # two groups stops as soon as a second distinct key shows up.
    groups = islice(groupby(iterable, key), 2)
    return len(list(groups)) <= 1
233
234
def quantify(iterable, pred=bool):
    """Return the sum of ``pred(item)`` over *iterable*.

    With a predicate that returns booleans, this counts how many items
    satisfy it.

    >>> quantify([True, False, True])
    2

    """
    results = map(pred, iterable)
    return sum(results)
243
244
def pad_none(iterable):
    """Yield the items of *iterable*, then ``None`` forever.

    >>> take(5, pad_none(range(3)))
    [0, 1, 2, None, None]

    Handy for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    padding = repeat(None)
    return chain(iterable, padding)
257
258
# Alternate spelling of pad_none; both names are listed in __all__.
padnone = pad_none
260
261
def ncycles(iterable, n):
    """Yield the elements of *iterable*, repeated *n* times in a row.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    # Snapshot the input so that one-shot iterators can be replayed.
    saved = tuple(iterable)
    return chain.from_iterable(repeat(saved, n))
270
271
def dotproduct(vec1, vec2):
    """Return the sum of pairwise products of two iterables.

    Iteration stops at the end of the shorter input.

    >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
    33.5
    >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
    33.5

    In Python 3.12 and later, use ``math.sumprod()`` instead.
    """
    return sum(a * b for a, b in zip(vec1, vec2))
283
284
# math.sumprod is available for Python 3.12+. _sumprod is the dot-product
# kernel used by the numeric helpers in this module; older interpreters fall
# back to the pure-Python dotproduct() defined above.
try:
    from math import sumprod as _sumprod
except ImportError:  # pragma: no cover
    _sumprod = dotproduct
290
291
def flatten(listOfLists):
    """Return an iterator that removes one level of nesting.

    >>> list(flatten([[0, 1], [2, 3]]))
    [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    # chain.from_iterable is lazy: the outer iterable is not touched until
    # the result is iterated.
    flattened = chain.from_iterable(listOfLists)
    return flattened
302
303
def repeatfunc(func, times=None, *args):
    """Repeatedly call ``func(*args)`` and iterate over the results.

    With an integer *times*, the iterator stops after that many calls:

    >>> from operator import add
    >>> times = 4
    >>> args = 3, 5
    >>> list(repeatfunc(add, times, *args))
    [8, 8, 8, 8]

    With *times* set to ``None``, it never stops:

    >>> from random import randrange
    >>> times = None
    >>> args = 1, 11
    >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
    [2, 4, 8, 1, 8, 4]

    """
    arg_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(func, arg_stream)
329
330
331def grouper(iterable, n, incomplete='fill', fillvalue=None):
332 """Group elements from *iterable* into fixed-length groups of length *n*.
333
334 >>> list(grouper('ABCDEF', 3))
335 [('A', 'B', 'C'), ('D', 'E', 'F')]
336
337 The keyword arguments *incomplete* and *fillvalue* control what happens for
338 iterables whose length is not a multiple of *n*.
339
340 When *incomplete* is `'fill'`, the last group will contain instances of
341 *fillvalue*.
342
343 >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
344 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
345
346 When *incomplete* is `'ignore'`, the last group will not be emitted.
347
348 >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
349 [('A', 'B', 'C'), ('D', 'E', 'F')]
350
351 When *incomplete* is `'strict'`, a `ValueError` will be raised.
352
353 >>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
354 >>> list(iterator) # doctest: +IGNORE_EXCEPTION_DETAIL
355 Traceback (most recent call last):
356 ...
357 ValueError
358
359 """
360 iterators = [iter(iterable)] * n
361 match incomplete:
362 case 'fill':
363 return zip_longest(*iterators, fillvalue=fillvalue)
364 case 'strict':
365 return zip(*iterators, strict=True)
366 case 'ignore':
367 return zip(*iterators)
368 case _:
369 raise ValueError('Expected fill, strict, or ignore')
370
371
def roundrobin(*iterables):
    """Take one item from each input in turn until all are exhausted.

    >>> list(roundrobin('ABC', 'D', 'EF'))
    ['A', 'D', 'E', 'B', 'F', 'C']

    The output matches :func:`interleave_longest`, but this may be faster
    for some inputs (especially when there are few iterables).

    """
    # Algorithm credited to George Sakkis. Each time an input runs dry,
    # map(next, ...) stops. The cycle is then rebuilt from the inputs that
    # are still live; islice drops the exhausted one, which is the one the
    # cycle would have reached next.
    remaining = map(iter, iterables)
    for active in reversed(range(1, len(iterables) + 1)):
        remaining = cycle(islice(remaining, active))
        yield from map(next, remaining)
388
389
def partition(pred, iterable):
    """
    Split *iterable* into two iterators based on *pred*.

    The first yields the items for which ``pred(item)`` is false.
    The second yields the items for which ``pred(item)`` is true.

    >>> is_odd = lambda x: x % 2 != 0
    >>> iterable = range(10)
    >>> even_items, odd_items = partition(is_odd, iterable)
    >>> list(even_items), list(odd_items)
    ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    A *pred* of ``None`` means :func:`bool`.

    >>> iterable = [0, 1, False, True, '', ' ']
    >>> false_items, true_items = partition(None, iterable)
    >>> list(false_items), list(true_items)
    ([0, False, ''], [1, True, ' '])

    """
    if pred is None:
        pred = bool
    source = iter(iterable)

    falses = deque()
    trues = deque()

    def drain(own_queue):
        # Serve buffered items first. Otherwise pull one item from the shared
        # source, route it to the matching queue, and look again.
        while True:
            while own_queue:
                yield own_queue.popleft()
            try:
                item = next(source)
            except StopIteration:
                return
            target = trues if pred(item) else falses
            target.append(item)

    return drain(falses), drain(trues)
428
429
def powerset(iterable):
    """Yield every subset of *iterable* as a tuple, smallest first.

    >>> list(powerset([1, 2, 3]))
    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    The input is not treated as a :class:`set`, so duplicate elements
    produce duplicate subsets.

    >>> seq = [1, 1, 0]
    >>> list(powerset(seq))
    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.
    """
    pool = list(iterable)
    sizes = range(len(pool) + 1)
    return chain.from_iterable(combinations(pool, size) for size in sizes)
449
450
def unique_everseen(iterable, key=None):
    """
    Yield each distinct element once, in first-seen order.

    >>> list(unique_everseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D']
    >>> list(unique_everseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'D']

    Hashable and unhashable items may be mixed. Unhashable items are
    tracked in a list, so they cost `O(n^2)` overall.

    Lists are unhashable; passing ``key=tuple`` converts them to hashable
    tuples and avoids the slowdown.

    >>> iterable = ([1, 2], [2, 3], [1, 2])
    >>> list(unique_everseen(iterable))  # Slow
    [[1, 2], [2, 3]]
    >>> list(unique_everseen(iterable, key=tuple))  # Faster
    [[1, 2], [2, 3]]

    Likewise use ``key=frozenset`` for ``set`` objects, and
    ``key=lambda x: frozenset(x.items())`` for ``dict`` objects.

    """
    seen_hashable = set()
    remember_hashable = seen_hashable.add
    seen_unhashable = []
    remember_unhashable = seen_unhashable.append

    for element in iterable:
        marker = element if key is None else key(element)
        try:
            # Set membership raises TypeError for unhashable markers.
            if marker not in seen_hashable:
                remember_hashable(marker)
                yield element
        except TypeError:
            # Fall back to a linear scan for unhashable markers.
            if marker not in seen_unhashable:
                remember_unhashable(marker)
                yield element
494
495
def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing runs of consecutive duplicates.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    """
    if key is not None:
        # Keep the first original element from each run of equal keys.
        runs = map(itemgetter(1), groupby(iterable, key))
        return map(next, runs)
    # Without a key, each group's key is the element itself.
    return map(itemgetter(0), groupby(iterable))
509
510
def unique(iterable, key=None, reverse=False):
    """Yield the distinct elements of *iterable* in sorted order.

    >>> list(unique([[1, 2], [3, 4], [1, 2]]))
    [[1, 2], [3, 4]]

    *key* and *reverse* are forwarded to :func:`sorted`.

    >>> list(unique('ABBcCAD', str.casefold))
    ['A', 'B', 'c', 'D']
    >>> list(unique('ABBcCAD', str.casefold, reverse=True))
    ['D', 'c', 'B', 'A']

    Elements do not need to be hashable, but they must be mutually
    comparable so that they can be sorted.
    """
    # Sorting places equal elements next to each other, so dropping
    # consecutive duplicates leaves exactly one of each.
    return unique_justseen(sorted(iterable, key=key, reverse=reverse), key=key)
529
530
def iter_except(func, exception, first=None):
    """Call *func* repeatedly, yielding its results until *exception* is
    raised.

    This adapts a call-until-exception interface into an iterator. It is
    like ``iter(func, sentinel)``, but an exception ends the loop instead of
    a sentinel value. If *first* is given, it is called once before *func*.

    >>> l = [0, 1, 2]
    >>> list(iter_except(l.pop, IndexError))
    [2, 1, 0]

    A tuple of exceptions may be given as the stopping condition:

    >>> l = [1, 2, 3, '...', 4, 5, 6]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [7, 6, 5]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [4, 3, 2]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    []

    """
    with suppress(exception):
        if first is not None:
            yield first()
        while True:
            value = func()
            yield value
558
559
def first_true(iterable, default=None, pred=None):
    """
    Return the first truthy item of *iterable*, or *default* if none exists.

    When *pred* is given, return the first item for which ``pred(item)``
    is true.

    >>> first_true(range(10))
    1
    >>> first_true(range(10), pred=lambda x: x > 5)
    6
    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
    'missing'

    """
    candidates = filter(pred, iterable)
    return next(candidates, default)
578
579
def random_product(*iterables, repeat=1):
    """Pick one random item from each input iterable.

    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
    ('c', 3, 'Z')

    With the keyword argument *repeat*, that many items are drawn from each
    iterable.

    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
    ('a', 2, 'd', 3)

    This is equivalent to a random selection from
    ``itertools.product(*args, repeat=repeat)``.

    """
    # Note: *repeat* is an int here and shadows itertools.repeat.
    pools = [tuple(pool) for pool in iterables] * repeat
    return tuple(choice(pool) for pool in pools)
598
599
def random_permutation(iterable, r=None):
    """Return a random ordering of *r* elements drawn from *iterable*.

    *r* defaults to the full length of *iterable* when omitted or ``None``.

    >>> random_permutation(range(5))  # doctest:+SKIP
    (3, 4, 0, 1, 2)

    This is equivalent to a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    if r is None:
        r = len(pool)
    return tuple(sample(pool, r))
616
617
def random_combination(iterable, r):
    """Return a random *r*-length subsequence of *iterable*, preserving the
    original element order.

    >>> random_combination(range(5), 3)  # doctest:+SKIP
    (2, 3, 4)

    This is equivalent to a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    # Sorting the sampled positions keeps the chosen items in input order.
    chosen = sorted(sample(range(len(pool)), r))
    return tuple(pool[i] for i in chosen)
632
633
def random_combination_with_replacement(iterable, r):
    """Return a random *r*-length subsequence of *iterable* in which the
    same element may be chosen more than once.

    >>> random_combination_with_replacement(range(3), 5)  # doctest:+SKIP
    (0, 0, 1, 2, 2)

    This is equivalent to a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    size = len(pool)
    picks = sorted([randrange(size) for _ in range(r)])
    return tuple(pool[i] for i in picks)
649
650
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

    >>> nth_combination(range(5), 3, 5)
    (0, 3, 4)

    Negative *index* values count from the end, as with list indexing.

    ``ValueError`` will be raised if *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid. This
    includes every *index* when *r* is larger than the length of
    *iterable*, because there are then no combinations at all.
    """
    pool = tuple(iterable)
    n = len(pool)
    # Total number of r-length combinations. math.comb raises ValueError for
    # negative r and returns 0 when r > n.
    c = comb(n, r)

    # Support negative indexing, then reject anything out of range.
    if index < 0:
        index += c
    if not 0 <= index < c:
        raise IndexError

    result = []
    while r:
        # c becomes comb(n - 1, r - 1): the number of remaining combinations
        # that take the first still-available element, pool[-1 - n] (with n
        # already decremented).
        c, n, r = c * r // n, n - 1, r - 1
        # While index lies beyond every combination starting with that
        # element, skip the element. Drop it from the pool and rescale c to
        # comb(n - 1, r), the count for the next candidate element.
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)
683
684
def prepend(value, iterator):
    """Yield *value* first, then everything in *iterator*.

    >>> value = '0'
    >>> iterator = ['1', '2', '3']
    >>> list(prepend(value, iterator))
    ['0', '1', '2', '3']

    To prepend several values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterator)
698
699
def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

    >>> list(convolve([1, -1, -20], [1, -3]))
    [1, -4, -17, 60]

    Examples of popular kinds of kernels:

    * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      For image data, this blurs the image and reduces noise.
    * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function evaluated at evenly spaced inputs.
    * The kernel ``[1, -2, 1]`` estimates the second derivative of a
      function evaluated at evenly spaced inputs.

    Convolution is mathematically commutative, but the two inputs are
    handled differently here. The signal is consumed lazily and may be
    infinite. The kernel is read in full before any output is produced.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article:  https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown:  https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # This implementation comes from an older version of the itertools
    # documentation. While the newer implementation is a bit clearer,
    # this one was kept because the inlined window logic is faster
    # and it avoids an unnecessary deque-to-tuple conversion.
    taps = tuple(kernel)[::-1]
    width = len(taps)
    # Sliding window primed with zeros; maxlen evicts the oldest sample.
    window = deque([0] * width, maxlen=width)
    # Trailing zeros flush the signal's tail through the kernel.
    for value in chain(signal, repeat(0, width - 1)):
        window.append(value)
        yield _sumprod(taps, window)
741
742
def before_and_after(predicate, it):
    """Like :func:`takewhile`, but the rest of the iterator stays available.

    >>> it = iter('ABCdEfGhI')
    >>> all_upper, remainder = before_and_after(str.isupper, it)
    >>> ''.join(all_upper)
    'ABC'
    >>> ''.join(remainder)  # takewhile() would lose the 'd'
    'dEfGhI'

    The first iterator must be fully consumed before the second one
    produces valid results.
    """
    head_source, remainder = tee(it)
    # compress() pulls one item from zip(remainder) for every item that
    # passes the predicate, so remainder advances in lockstep and stops
    # right at the first failing element.
    head = compress(takewhile(predicate, head_source), zip(remainder))
    return head, remainder
760
761
def triplewise(iterable):
    """Yield overlapping triples of consecutive items from *iterable*.

    >>> list(triplewise('ABCDE'))
    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # This deviates from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Offset the copies by 0, 1 and 2 positions.
    next(second, None)
    next(third, None)
    next(third, None)
    return zip(first, second, third)
776
777
778def _sliding_window_islice(iterable, n):
779 # Fast path for small, non-zero values of n.
780 iterators = tee(iterable, n)
781 for i, iterator in enumerate(iterators):
782 next(islice(iterator, i, i), None)
783 return zip(*iterators)
784
785
786def _sliding_window_deque(iterable, n):
787 # Normal path for other values of n.
788 iterator = iter(iterable)
789 window = deque(islice(iterator, n - 1), maxlen=n)
790 for x in iterator:
791 window.append(x)
792 yield tuple(window)
793
794
795def sliding_window(iterable, n):
796 """Return a sliding window of width *n* over *iterable*.
797
798 >>> list(sliding_window(range(6), 4))
799 [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]
800
801 If *iterable* has fewer than *n* items, then nothing is yielded:
802
803 >>> list(sliding_window(range(3), 4))
804 []
805
806 For a variant with more features, see :func:`windowed`.
807 """
808 if n > 20:
809 return _sliding_window_deque(iterable, n)
810 elif n > 2:
811 return _sliding_window_islice(iterable, n)
812 elif n == 2:
813 return pairwise(iterable)
814 elif n == 1:
815 return zip(iterable)
816 else:
817 raise ValueError(f'n should be at least one, not {n}')
818
819
def subslices(iterable):
    """Yield every contiguous, non-empty slice of *iterable* as a list.

    >>> list(subslices('ABC'))
    [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This resembles :func:`substrings`, but emits items in a different
    order.
    """
    seq = list(iterable)
    # Every pair lo < hi of boundary positions defines one slice.
    bounds = combinations(range(len(seq) + 1), 2)
    return (seq[lo:hi] for lo, hi in bounds)
832
833
def polynomial_from_roots(roots):
    """Return the coefficients of the monic polynomial with the given roots.

    >>> roots = [5, -4, 3]  # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
    [1, -4, -17, 60]

    Coefficients are listed from the highest power down.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Unlike the recipe in the itertools docs, each convolve() result is
    # materialized with list(). This avoids hitting stack limits with
    # deeply nested generators.
    return reduce(
        lambda poly, root: list(convolve(poly, (1, -root))),
        roots,
        [1],
    )
854
855
def iter_index(iterable, value, start=0, stop=None):
    """Yield each index at which *value* occurs in *iterable*, searching
    from *start* (inclusive) up to *stop* (exclusive).


    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
    [1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
    [1, 4]

    For non-scalar *value* arguments, matching follows the built-in types'
    own ``index()`` behavior.

    >>> list(iter_index('ABCDABCD', 'AB'))
    [0, 4]
    >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
    []
    >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
    [0, 2]

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    """
    seq_index = getattr(iterable, 'index', None)
    if seq_index is not None:
        # Fast path: delegate each search to the sequence's own index().
        if stop is None:
            stop = len(iterable)
        position = start - 1
        with suppress(ValueError):
            while True:
                position = seq_index(value, position + 1, stop)
                yield position
        return

    # Slow path: scan a general iterable element by element.
    window = islice(iterable, start, stop)
    for position, element in enumerate(window, start):
        if element is value or element == value:
            yield position
895
896
def sieve(n):
    """Yield the primes less than *n*, in increasing order.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    There are no primes below 2, so nothing is yielded when *n* is at most
    2. This includes negative values of *n*.

    >>> list(sieve(2))
    []
    >>> list(sieve(-5))
    []

    """
    # This implementation comes from an older version of the itertools
    # documentation. The newer implementation is easier to read but is
    # less lazy.
    if n <= 2:
        # Nothing to yield. Returning early also keeps math.isqrt() below
        # from raising ValueError when n is negative.
        return
    yield 2
    start = 3
    # data[i] == 1 marks i as a prime candidate. Only odd indices start
    # out set, so multiples of each odd prime p are cleared with a stride
    # of 2 * p (the even multiples were never candidates).
    data = bytearray((0, 1)) * (n // 2)
    for p in iter_index(data, 1, start, stop=isqrt(n) + 1):
        # Every candidate below p * p that survived sieving is prime.
        yield from iter_index(data, 1, start, p * p)
        data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
        start = p * p
    # The remaining candidates above sqrt(n) are all prime.
    yield from iter_index(data, 1, start)
916
917
918def _batched(iterable, n, *, strict=False): # pragma: no cover
919 """Batch data into tuples of length *n*. If the number of items in
920 *iterable* is not divisible by *n*:
921 * The last batch will be shorter if *strict* is ``False``.
922 * :exc:`ValueError` will be raised if *strict* is ``True``.
923
924 >>> list(batched('ABCDEFG', 3))
925 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
926
927 On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
928 """
929 if n < 1:
930 raise ValueError('n must be at least one')
931 iterator = iter(iterable)
932 while batch := tuple(islice(iterator, n)):
933 if strict and len(batch) != n:
934 raise ValueError('batched(): incomplete batch')
935 yield batch
936
937
if hexversion < 0x30D00A2:  # pragma: no cover
    # Before Python 3.13.0a2: use the pure-Python implementation.
    batched = _batched
else:  # pragma: no cover
    # Newer interpreters provide itertools.batched with a *strict* keyword.
    # Wrap it so that the public signature and docstring match _batched.
    from itertools import batched as itertools_batched

    def batched(iterable, n, *, strict=False):
        return itertools_batched(iterable, n, strict=strict)

    batched.__doc__ = _batched.__doc__
947
948
949def transpose(it):
950 """Swap the rows and columns of the input matrix.
951
952 >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
953 [(1, 11), (2, 22), (3, 33)]
954
955 The caller should ensure that the dimensions of the input are compatible.
956 If the input is empty, no output will be produced.
957 """
958 return zip(*it, strict=True)
959
960
961def _is_scalar(value, stringlike=(str, bytes)):
962 "Scalars are bytes, strings, and non-iterables."
963 try:
964 iter(value)
965 except TypeError:
966 return True
967 return isinstance(value, stringlike)
968
969
def _flatten_tensor(tensor):
    """Depth-first iterator over scalars in a tensor.

    At each nesting level, only the first element is inspected. If it is
    not a scalar (per _is_scalar), one more level is flattened with
    chain.from_iterable. Each dimension is therefore assumed to be uniform.
    Returns an iterator; this function is not itself a generator.
    """
    iterator = iter(tensor)
    while True:
        try:
            value = next(iterator)
        except StopIteration:
            # Empty at this level: hand back the (exhausted) iterator.
            return iterator
        # Put the peeked value back at the front of the stream.
        iterator = chain((value,), iterator)
        if _is_scalar(value):
            # This level holds scalars, so the stream is fully flattened.
            return iterator
        # Elements are themselves iterables: flatten one more level.
        iterator = chain.from_iterable(iterator)
982
983
def reshape(matrix, shape):
    """Change the shape of a *matrix*.

    An integer *shape* requires a two-dimensional matrix and gives the
    desired number of columns:

    >>> matrix = [(0, 1), (2, 3), (4, 5)]
    >>> cols = 3
    >>> list(reshape(matrix, cols))
    [(0, 1, 2), (3, 4, 5)]

    A tuple (or other iterable) *shape* accepts a matrix with any number of
    dimensions. The matrix is flattened first and then rebuilt to the
    requested shape, which may also be multidimensional:

    >>> matrix = [(0, 1), (2, 3), (4, 5)]    # Start with a 3 x 2 matrix

    >>> list(reshape(matrix, (2, 3)))        # Make a 2 x 3 matrix
    [(0, 1, 2), (3, 4, 5)]

    >>> list(reshape(matrix, (6,)))          # Make a vector of length six
    [0, 1, 2, 3, 4, 5]

    >>> list(reshape(matrix, (2, 1, 3, 1)))  # Make 2 x 1 x 3 x 1 tensor
    [(((0,), (1,), (2,)),), (((3,), (4,), (5,)),)]

    Every dimension is assumed uniform: either all arrays or all scalars.
    Flattening stops at the first dimension whose first value is a scalar.
    Scalars are bytes, strings, and non-iterables.
    The reshape iterator stops when the requested shape is complete
    or when the input is exhausted, whichever comes first.

    """
    if isinstance(shape, int):
        # Two-dimensional input: shape is the number of columns.
        return batched(chain.from_iterable(matrix), shape)
    first_dim, *inner_dims = shape
    stream = _flatten_tensor(matrix)
    # Regroup from the innermost dimension outward.
    for dim in reversed(inner_dims):
        stream = batched(stream, dim)
    return islice(stream, first_dim)
1023
1024
def matmul(m1, m2):
    """Return the matrix product of *m1* and *m2*, one row tuple at a time.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller must ensure the two matrices have compatible dimensions.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    n_cols = len(m2[0])
    columns = transpose(m2)
    # Dot every row of m1 with every column of m2, in row-major order.
    cells = starmap(_sumprod, product(m1, columns))
    return batched(cells, n_cols)
1038
1039
1040def _factor_pollard(n):
1041 # Return a factor of n using Pollard's rho algorithm.
1042 # Efficient when n is odd and composite.
1043 for b in range(1, n):
1044 x = y = 2
1045 d = 1
1046 while d == 1:
1047 x = (x * x + b) % n
1048 y = (y * y + b) % n
1049 y = (y * y + b) % n
1050 d = gcd(x - y, n)
1051 if d != n:
1052 return d
1053 raise ValueError('prime or under 5') # pragma: no cover
1054
1055
# Primes below 211, computed once at import time. factor() uses them for
# trial division before falling back to Pollard's rho.
_primes_below_211 = tuple(sieve(211))
1057
1058
def factor(n):
    """Yield the prime factors of *n* in non-decreasing order, with
    repetition.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]

    Small factors are removed by trial division. Larger remaining factors
    are either confirmed prime with ``is_prime`` or split further with
    Pollard's rho algorithm.
    """
    # Nothing to factor below 2.
    if n < 2:
        return

    # Trial division by the small primes.
    for p in _primes_below_211:
        while n % p == 0:
            yield p
            n //= p

    # Split whatever is left with Pollard's rho. A FIFO queue processes the
    # pieces in the order they are discovered.
    found = []
    queue = deque([n] if n > 1 else [])
    while queue:
        m = queue.popleft()
        if m < 211**2 or is_prime(m):
            found.append(m)
        else:
            d = _factor_pollard(m)
            queue.extend((d, m // d))
    yield from sorted(found)
1090
1091
def polynomial_eval(coefficients, x):
    """Evaluate the polynomial with the given *coefficients* at *x*.

    Computes with better numeric stability than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125

    Coefficients are listed from the highest power down.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    size = len(coefficients)
    if not size:
        # Empty polynomial: return a zero of the same type as x.
        return type(x)(0)
    exponents = range(size - 1, -1, -1)
    powers = map(pow, repeat(x), exponents)
    return _sumprod(coefficients, powers)
1113
1114
def sum_of_squares(it):
    """Return the sum of each input value multiplied by itself.

    >>> sum_of_squares([10, 20, 30])
    1400

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Two independent views of the same stream, dotted with each other.
    left, right = tee(it)
    return _sumprod(left, right)
1124
1125
def polynomial_derivative(coefficients):
    """Return the coefficients of the first derivative of a polynomial.

    Derivative of ``x³ - 4 x² - 17 x + 60``:

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]

    Coefficients are listed from the highest power down.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Pair each coefficient with its exponent. zip stops before the
    # constant term, which has exponent 0 and vanishes.
    exponents = range(len(coefficients) - 1, 0, -1)
    return [coef * power for coef, power in zip(coefficients, exponents)]
1143
1144
def totient(n):
    """Return the count of natural numbers up to *n* that are coprime with *n*.

    Euler's totient function φ(n) gives the number of totatives.
    Totatives are the integers k in the range 1 ≤ k ≤ n for which
    gcd(n, k) = 1.

    >>> n = 9
    >>> totient(n)
    6

    >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
    >>> totatives
    [1, 2, 4, 5, 7, 8]
    >>> len(totatives)
    6

    Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    # Product formula φ(n) = n * Π(1 - 1/p) over the distinct primes p
    # dividing n, computed with exact integer arithmetic: each step removes
    # the multiples of p from the running count.
    result = n
    for p in set(factor(n)):
        result -= result // p
    return result
1167
1168
# Deterministic Miller–Rabin witness sets (see https://oeis.org/A014233 for
# the prime-base rows; the other rows are smaller published witness sets).
# Each entry is (limit, bases): for n < limit, testing exactly these bases
# with _strong_probable_prime() is intended to be a definitive primality test.
# is_prime() picks the first entry whose limit exceeds n, so the entries must
# stay sorted by ascending limit.  18446744073709551616 is 2**64.
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        3317044064679887385961981,
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]
1183
1184
1185@lru_cache
1186def _shift_to_odd(n):
1187 'Return s, d such that 2**s * d == n'
1188 s = ((n - 1) ^ n).bit_length() - 1
1189 d = n >> s
1190 assert (1 << s) * d == n and d & 1 and s >= 0
1191 return s, d
1192
1193
def _strong_probable_prime(n, base):
    """Return True if odd *n* > 2 passes one Miller–Rabin round for *base*."""
    assert (n > 2) and (n & 1) and (2 <= base < n)

    # Write n - 1 as 2**s * d with d odd.
    s, d = _shift_to_odd(n - 1)

    x = pow(base, d, n)
    if x in (1, n - 1):
        return True

    # Square up to s - 1 more times.  Reaching n - 1 means *n* is a strong
    # probable prime to this base.
    squarings_left = s - 1
    while squarings_left > 0:
        x = x * x % n
        if x == n - 1:
            return True
        squarings_left -= 1

    return False
1209
1210
# Separate instance of Random() that doesn't share state
# with the default user instance of Random().  is_prime() draws its random
# Miller–Rabin bases from this generator.  Seeding or consuming the global
# ``random`` module therefore neither affects nor is affected by primality
# testing.
_private_randrange = random.Random().randrange
1214
1215
def is_prime(n):
    """Return ``True`` if *n* is prime and ``False`` otherwise.

    Basic examples:

    >>> is_prime(37)
    True
    >>> is_prime(3 * 13)
    False
    >>> is_prime(18_446_744_073_709_551_557)
    True

    Find the next prime over one billion:

    >>> next(filter(is_prime, count(10**9)))
    1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

    >>> from random import seed, randrange, getrandbits
    >>> seed(18675309)

    >>> next(filter(is_prime, map(getrandbits, repeat(200))))
    893303929355758292373272075469392561129886005037663238028407

    >>> next(filter(is_prime, map(randrange, repeat(10**60))))
    269638077304026462407872868003560484232362454342414618963649

    The result is exact for values of *n* below 10**24. For larger inputs,
    the probabilistic Miller-Rabin primality test has a less than 1 in 2**128
    chance of a false positive.
    """
    # Inputs below 17 are settled by direct lookup.
    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Trial division by the primes below 17.  ``n & 1`` is evaluated first
    # and rejects non-integer inputs with TypeError.
    if not n & 1:
        return False
    if any(not n % p for p in (3, 5, 7, 11, 13)):
        return False

    # Use the smallest deterministic witness set that covers n.  Past the
    # end of the table, fall back to 64 random bases.
    bases = next(
        (witnesses for limit, witnesses in _perfect_tests if n < limit),
        None,
    )
    if bases is None:
        bases = (_private_randrange(2, n - 1) for _ in range(64))

    return all(_strong_probable_prime(n, base) for base in bases)
1262
1263
def loops(n):
    """Return an iterable that produces *n* items, for efficient looping.

    It can stand in for ``range(n)`` in a ``for`` loop.  Every item is
    ``None``, so no integer objects are created.

    >>> i = 0
    >>> for _ in loops(5):
    ...     i += 1
    >>> i
    5

    """
    placeholder = None
    return repeat(placeholder, n)
1276
1277
def multinomial(*counts):
    """Number of distinct arrangements of a multiset.

    The expression ``multinomial(3, 4, 2)`` has several equivalent
    interpretations:

    * In the expansion of ``(a + b + c)⁹``, the coefficient of the
      ``a³b⁴c²`` term is 1260.

    * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4
      greens, and 2 blues.

    * There are 1260 unique ways to place 9 distinct objects into three bins
      with sizes 3, 4, and 2.

    The :func:`multinomial` function computes the length of
    :func:`distinct_permutations`. For example, there are 83,160 distinct
    anagrams of the word "abracadabra":

        >>> from more_itertools import distinct_permutations, ilen
        >>> ilen(distinct_permutations('abracadabra'))
        83160

    This can be computed directly from the letter counts, 5a 2b 2r 1c 1d:

        >>> from collections import Counter
        >>> list(Counter('abracadabra').values())
        [5, 2, 2, 1, 1]
        >>> multinomial(5, 2, 2, 1, 1)
        83160

    A binomial coefficient is a special case of multinomial where there are
    only two categories. For example, the number of ways to arrange 12 balls
    with 5 reds and 7 blues is ``multinomial(5, 7)`` or ``math.comb(12, 5)``.

    Likewise, factorial is a special case of multinomial where
    the multiplicities are all just 1 so that
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    Reference: https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    # Place the groups one at a time.  After adding a group of size k, the
    # total number of slots filled so far is ``placed``.  The new group can
    # occupy any k of those slots, giving comb(placed, k) choices.
    placed = 0
    arrangements = 1
    for k in counts:
        placed += k
        arrangements *= comb(placed, k)
    return arrangements
1321
1322
def _running_median_minheap_and_maxheap(iterator):  # pragma: no cover
    "Non-windowed running_median() for Python 3.14+"

    # ``lower`` is a max-heap holding the smaller half of the values seen.
    # ``upper`` is a min-heap holding the larger half.  ``lower`` always has
    # as many items as ``upper`` or exactly one more.
    pull = iterator.__next__
    lower = []
    upper = []

    try:
        while True:
            # Odd count: pass the new value through ``upper`` so the smallest
            # of the large half moves down.  ``lower[0]`` is then the median.
            heappush_max(lower, heappushpop(upper, pull()))
            yield lower[0]

            # Even count: pass the new value through ``lower`` so its largest
            # moves up.  The median is the mean of the two middle values.
            heappush(upper, heappushpop_max(lower, pull()))
            yield (lower[0] + upper[0]) / 2
    except StopIteration:
        return
1337
1338
1339def _running_median_minheap_only(iterator): # pragma: no cover
1340 "Backport of non-windowed running_median() for Python 3.13 and prior."
1341
1342 read = iterator.__next__
1343 lo = [] # max-heap (actually a minheap with negated values)
1344 hi = [] # min-heap (same size as or one smaller than lo)
1345
1346 with suppress(StopIteration):
1347 while True:
1348 heappush(lo, -heappushpop(hi, read()))
1349 yield -lo[0]
1350
1351 heappush(hi, -heappushpop(lo, -read()))
1352 yield (hi[0] - lo[0]) / 2
1353
1354
1355def _running_median_windowed(iterator, maxlen):
1356 "Yield median of values in a sliding window."
1357
1358 window = deque()
1359 ordered = []
1360
1361 for x in iterator:
1362 window.append(x)
1363 insort(ordered, x)
1364
1365 if len(ordered) > maxlen:
1366 i = bisect_left(ordered, window.popleft())
1367 del ordered[i]
1368
1369 n = len(ordered)
1370 m = n // 2
1371 yield ordered[m] if n & 1 else (ordered[m - 1] + ordered[m]) / 2
1372
1373
def running_median(iterable, *, maxlen=None):
    """Cumulative median of values seen so far or values in a sliding window.

    Set *maxlen* to a positive integer to specify the maximum size
    of the sliding window. The default of *None* is equivalent to
    an unbounded window.

    For example:

        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0]))
        [5.0, 7.0, 5.0, 7.0, 8.0, 8.5]
        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0], maxlen=3))
        [5.0, 7.0, 5.0, 9.0, 8.0, 9.0]

    Supports numeric types such as int, float, Decimal, and Fraction,
    but not complex numbers which are unorderable.

    On Python 3.13 and earlier, max-heaps are simulated with negated
    values. The negation causes Decimal inputs to apply context
    rounding, so results can differ slightly from those of
    statistics.median().

    Raises ``TypeError`` if *maxlen* is not an integer and ``ValueError``
    if it is not positive.
    """
    iterator = iter(iterable)

    # Windowed variant: *maxlen* must be a positive integer.
    if maxlen is not None:
        window_size = index(maxlen)
        if window_size <= 0:
            raise ValueError('Window size should be positive')
        return _running_median_windowed(iterator, window_size)

    # Unbounded variant: use true max-heaps when heapq provides them.
    if _max_heap_available:
        return _running_median_minheap_and_maxheap(iterator)  # pragma: no cover
    return _running_median_minheap_only(iterator)  # pragma: no cover
1409
1410
1411def _windowed_running_mean(iterator, n):
1412 window = deque()
1413 running_sum = 0
1414 for value in iterator:
1415 window.append(value)
1416 running_sum += value
1417 if len(window) > n:
1418 running_sum -= window.popleft()
1419 yield running_sum / len(window)
1420
1421
def running_mean(iterable, *, maxlen=None):
    """Cumulative mean of values seen so far or values in a sliding window.

    Set *maxlen* to a positive integer to specify the maximum size
    of the sliding window. The default of *None* is equivalent to
    an unbounded window.

    For example:

        >>> list(running_mean([40, 30, 50, 46, 39, 44]))
        [40.0, 35.0, 40.0, 41.5, 41.0, 41.5]

        >>> list(running_mean([40, 30, 50, 46, 39, 44], maxlen=3))
        [40.0, 35.0, 40.0, 42.0, 45.0, 43.0]

    Supports numeric types such as int, float, complex, Decimal, and Fraction.

    No extra effort is made to reduce round-off errors for float inputs,
    so results may differ slightly from `statistics.mean`.

    Raises ``ValueError`` if *maxlen* is not positive.

    """
    iterator = iter(iterable)

    if maxlen is not None:
        if maxlen <= 0:
            raise ValueError('Window size should be positive')
        return _windowed_running_mean(iterator, maxlen)

    # Cumulative mean: divide each running total by the number of terms
    # summed so far (1, 2, 3, ...).
    return map(truediv, accumulate(iterator), count(1))
1453
1454
def random_derangement(iterable):
    """Return a random derangement of elements in the iterable.

    A derangement is a permutation in which no element stays at its
    original position.  Elements are distinguished by position, not by
    value.

    Equivalent to but much faster than ``choice(list(derangements(iterable)))``.

    Returns an empty tuple for empty input.  Raises ``IndexError`` when the
    input has exactly one element, since it has no derangement.

    """
    seq = tuple(iterable)
    if len(seq) < 2:
        if len(seq) == 0:
            return ()
        raise IndexError('No derangements to choose from')
    perm = list(range(len(seq)))
    start = tuple(perm)
    # Rejection sampling: reshuffle until no index remains in place.
    # ``start`` and ``perm`` share the same int objects, each value created
    # once by range().  An identity match at a position therefore means that
    # index did not move.
    while True:
        shuffle(perm)
        if not any(map(is_, start, perm)):
            return itemgetter(*perm)(seq)