1"""Imported from the recipes section of the itertools documentation.
2
3All functions taken from the recipes section of the itertools library docs
4[1]_.
5Some backward-compatible usability improvements have been made.
6
.. [1] https://docs.python.org/3/library/itertools.html#itertools-recipes
8
9"""
10
11import random
12
13from bisect import bisect_left, insort
14from collections import deque
15from contextlib import suppress
16from functools import lru_cache, reduce
17from heapq import heappush, heappushpop
18from itertools import (
19 accumulate,
20 chain,
21 combinations,
22 compress,
23 count,
24 cycle,
25 groupby,
26 islice,
27 pairwise,
28 product,
29 repeat,
30 starmap,
31 takewhile,
32 tee,
33 zip_longest,
34)
35from math import prod, comb, isqrt, gcd
36from operator import mul, is_, not_, itemgetter, getitem, index
37from random import randrange, sample, choice, shuffle
38from sys import hexversion
39
# Public names exported by ``from <this module> import *``.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_derangement',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'running_median',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]
92
# Private, unique sentinel object. It is not referenced in the visible part
# of this module; presumably used elsewhere as a "no value supplied" marker
# that cannot collide with user data — confirm.
_marker = object()
94
95
# heapq max-heap functions are available for Python 3.14+.
# Try to import them and record the outcome in _max_heap_available so later
# code can check whether heappush_max / heappushpop_max exist.
try:
    from heapq import heappush_max, heappushpop_max
except ImportError:  # pragma: no cover
    # Older Python: the max-heap helpers are missing.
    _max_heap_available = False
else:  # pragma: no cover
    _max_heap_available = True
103
104
def take(n, iterable):
    """Collect the first *n* items of *iterable* into a list.

    >>> take(3, range(10))
    [0, 1, 2]

    When *iterable* yields fewer than *n* items, the list simply holds
    everything that was available.

    >>> take(10, range(3))
    [0, 1, 2]

    """
    leading = islice(iterable, n)
    return list(leading)
119
120
def tabulate(function, start=0):
    """Return an endless iterator over ``function(start)``,
    ``function(start + 1)``, ``function(start + 2)``, and so on.

    *function* should accept a single integer argument. *start* defaults
    to 0 and grows by one every time the iterator is advanced.

    >>> square = lambda x: x ** 2
    >>> iterator = tabulate(square, -3)
    >>> take(4, iterator)
    [9, 4, 1, 0]

    """
    arguments = count(start)
    return map(function, arguments)
137
138
def tail(n, iterable):
    """Return an iterator over the last *n* items of *iterable*.

    Inputs that support ``len()`` are sliced directly. Anything else is
    streamed through a bounded deque that keeps only the final *n* items.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    try:
        size = len(iterable)
    except TypeError:
        # Unsized input: the deque drops older items as new ones arrive.
        return iter(deque(iterable, maxlen=n))
    # Sized input: skip everything that comes before the last n items.
    skip = size - n if size > n else 0
    return islice(iterable, skip, None)
153
154
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps. If *n* is ``None``, exhaust it.

    Items are discarded without being returned. With no second argument the
    whole iterator is consumed; otherwise at most *n* items are skipped.

    >>> i = (x for x in range(10))
    >>> next(i)
    0
    >>> consume(i, 3)
    >>> next(i)
    4
    >>> consume(i)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    If fewer than *n* items remain, the iterator is simply exhausted.

    >>> i = (x for x in range(3))
    >>> consume(i, 5)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    """
    # Both branches hand the iteration to C code, so no Python-level loop runs.
    if n is not None:
        # Requesting the empty slice that begins at position n forces the
        # first n items to be pulled and thrown away.
        next(islice(iterator, n, n), None)
        return
    # A zero-length deque discards everything it is fed.
    deque(iterator, maxlen=0)
193
194
def nth(iterable, n, default=None):
    """Return the item at position *n* of *iterable*, or *default* if the
    iterable runs out first.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    from_n_onward = islice(iterable, n, None)
    return next(from_n_onward, default)
206
207
def all_equal(iterable, key=None):
    """
    Return ``True`` if every element is equal to every other element.

    >>> all_equal('aaaa')
    True
    >>> all_equal('aaab')
    False

    Pass *key*, a one-argument function, to compare transformed versions of
    the items instead of the items themselves:

    >>> all_equal('AaaA', key=str.casefold)
    True
    >>> all_equal([1, 2, 3], key=lambda x: x < 10)
    True

    """
    runs = groupby(iterable, key)
    if next(runs, None) is None:
        # Empty input: vacuously all equal.
        return True
    # A second run of equal values means some element differs from the first.
    return next(runs, None) is None
232
233
def quantify(iterable, pred=bool):
    """Return how many items of *iterable* satisfy *pred*.

    The results of *pred* are summed, so a predicate that returns booleans
    produces a count.

    >>> quantify([True, False, True])
    2

    """
    return sum(pred(item) for item in iterable)
242
243
def pad_none(iterable):
    """Yield the elements of *iterable*, then ``None`` forever after.

    >>> take(5, pad_none(range(3)))
    [0, 1, 2, None, None]

    Useful for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    trailing_nones = repeat(None)
    return chain(iterable, trailing_nones)
256
257
# Alternate spelling bound to the same function as pad_none (both names are
# listed in __all__).
padnone = pad_none
259
260
def ncycles(iterable, n):
    """Yield the elements of *iterable* over and over, *n* times in total.

    The input is materialized once, so it may be a one-shot iterator.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    saved = tuple(iterable)
    copies = repeat(saved, n)
    return chain.from_iterable(copies)
269
270
def dotproduct(vec1, vec2):
    """Return the dot product of two iterables of numbers.

    Pairs are formed positionally; iteration stops at the shorter input.

    >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
    33.5
    >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
    33.5

    In Python 3.12 and later, use ``math.sumprod()`` instead.
    """
    pairwise_products = map(mul, vec1, vec2)
    return sum(pairwise_products)
282
283
# math.sumprod is available for Python 3.12+.
# _sumprod(p, q) returns the sum of products of corresponding items; it is
# used by convolve, matmul, polynomial_eval and sum_of_squares. Older
# interpreters fall back to the pure-Python dotproduct() defined above.
try:
    from math import sumprod as _sumprod
except ImportError:  # pragma: no cover
    _sumprod = dotproduct
289
290
def flatten(listOfLists):
    """Return an iterator that removes one level of nesting.

    Each inner iterable is walked lazily, in order.

    >>> list(flatten([[0, 1], [2, 3]]))
    [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    flattened = chain.from_iterable(listOfLists)
    return flattened
301
302
def repeatfunc(func, times=None, *args):
    """Return an iterable of the results of calling ``func(*args)`` repeatedly.

    With an integer *times*, the iterable stops after that many calls:

    >>> from operator import add
    >>> times = 4
    >>> args = 3, 5
    >>> list(repeatfunc(add, times, *args))
    [8, 8, 8, 8]

    With *times* set to ``None`` it never stops:

    >>> from random import randrange
    >>> times = None
    >>> args = 1, 11
    >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
    [2, 4, 8, 1, 8, 4]

    """
    argument_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(func, argument_stream)
328
329
330def grouper(iterable, n, incomplete='fill', fillvalue=None):
331 """Group elements from *iterable* into fixed-length groups of length *n*.
332
333 >>> list(grouper('ABCDEF', 3))
334 [('A', 'B', 'C'), ('D', 'E', 'F')]
335
336 The keyword arguments *incomplete* and *fillvalue* control what happens for
337 iterables whose length is not a multiple of *n*.
338
339 When *incomplete* is `'fill'`, the last group will contain instances of
340 *fillvalue*.
341
342 >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
343 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
344
345 When *incomplete* is `'ignore'`, the last group will not be emitted.
346
347 >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
348 [('A', 'B', 'C'), ('D', 'E', 'F')]
349
350 When *incomplete* is `'strict'`, a `ValueError` will be raised.
351
352 >>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
353 >>> list(iterator) # doctest: +IGNORE_EXCEPTION_DETAIL
354 Traceback (most recent call last):
355 ...
356 ValueError
357
358 """
359 iterators = [iter(iterable)] * n
360 if incomplete == 'fill':
361 return zip_longest(*iterators, fillvalue=fillvalue)
362 if incomplete == 'strict':
363 return zip(*iterators, strict=True)
364 if incomplete == 'ignore':
365 return zip(*iterators)
366 else:
367 raise ValueError('Expected fill, strict, or ignore')
368
369
def roundrobin(*iterables):
    """Visit input iterables in a cycle until each is exhausted.

    >>> list(roundrobin('ABC', 'D', 'EF'))
    ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`, but
    may perform better for some inputs (in particular when the number of
    iterables is small).

    """
    # Algorithm credited to George Sakkis
    # Each pass cycles over the currently active iterators and calls next()
    # on each in turn. When one is exhausted, its StopIteration ends map() and
    # therefore the `yield from`. The cycle() is then positioned just after
    # the exhausted iterator, so islice() rebuilds the rotation from the
    # remaining num_active - 1 iterators, preserving their order.
    iterators = map(iter, iterables)
    for num_active in range(len(iterables), 0, -1):
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)
386
387
def partition(pred, iterable):
    """
    Split *iterable* into a pair of iterators based on *pred*.
    The first yields the items for which ``pred(item)`` is false.
    The second yields the items for which ``pred(item)`` is true.

    >>> is_odd = lambda x: x % 2 != 0
    >>> iterable = range(10)
    >>> even_items, odd_items = partition(is_odd, iterable)
    >>> list(even_items), list(odd_items)
    ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    When *pred* is ``None``, :func:`bool` is used.

    >>> iterable = [0, 1, False, True, '', ' ']
    >>> false_items, true_items = partition(None, iterable)
    >>> list(false_items), list(true_items)
    ([0, False, ''], [1, True, ' '])

    *pred* is called once per item; its results are shared by both outputs.
    """
    predicate = bool if pred is None else pred

    for_falses, for_trues, for_flags = tee(iterable, 3)
    flags_a, flags_b = tee(map(predicate, for_flags))
    falses = compress(for_falses, map(not_, flags_a))
    trues = compress(for_trues, flags_b)
    return (falses, trues)
414
415
def powerset(iterable):
    """Yield every subset of *iterable* as a tuple, from smallest to largest.

    >>> list(powerset([1, 2, 3]))
    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    The input is treated as a sequence, not a :class:`set`, so duplicate
    elements lead to duplicate subsets:

    >>> seq = [1, 1, 0]
    >>> list(powerset(seq))
    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.
    """
    pool = list(iterable)
    sizes = range(len(pool) + 1)
    return chain.from_iterable(map(combinations, repeat(pool), sizes))
435
436
def unique_everseen(iterable, key=None):
    """
    Yield each distinct element the first time it appears, keeping order.

    >>> list(unique_everseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D']
    >>> list(unique_everseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'D']

    Hashable and unhashable items may be mixed. Unhashable items are
    tracked in a list, so checking them is slower (`O(n^2)` overall).

    ``list`` objects are unhashable; pass ``key=tuple`` to make them
    hashable and avoid the slowdown:

    >>> iterable = ([1, 2], [2, 3], [1, 2])
    >>> list(unique_everseen(iterable))  # Slow
    [[1, 2], [2, 3]]
    >>> list(unique_everseen(iterable, key=tuple))  # Faster
    [[1, 2], [2, 3]]

    Likewise, use ``key=frozenset`` for ``set`` objects, or
    ``key=lambda x: frozenset(x.items())`` for ``dict`` objects.

    """
    hashable_seen = set()
    unhashable_seen = []

    for element in iterable:
        k = element if key is None else key(element)
        try:
            # Set membership raises TypeError for unhashable keys.
            if k in hashable_seen:
                continue
            hashable_seen.add(k)
            yield element
        except TypeError:
            # Fall back to a linear scan for unhashable keys.
            if k in unhashable_seen:
                continue
            unhashable_seen.append(k)
            yield element
480
481
def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing runs of consecutive duplicates.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    With a *key*, the first element of each run is yielded, not the key.
    """
    if key is not None:
        runs = groupby(iterable, key)
        # Take the first element from each run's group iterator.
        return map(next, map(itemgetter(1), runs))

    # Without a key, each group key is the run's first element itself.
    return map(itemgetter(0), groupby(iterable))
495
496
def unique(iterable, key=None, reverse=False):
    """Yield the distinct elements of *iterable* in sorted order.

    >>> list(unique([[1, 2], [3, 4], [1, 2]]))
    [[1, 2], [3, 4]]

    *key* and *reverse* are forwarded to :func:`sorted`.

    >>> list(unique('ABBcCAD', str.casefold))
    ['A', 'B', 'c', 'D']
    >>> list(unique('ABBcCAD', str.casefold, reverse=True))
    ['D', 'c', 'B', 'A']

    Elements do not have to be hashable, but they must be mutually
    comparable so they can be sorted.
    """
    # Sorting places equal (keyed) elements next to each other, so removing
    # consecutive duplicates leaves one representative of each.
    ordered = sorted(iterable, key=key, reverse=reverse)
    return unique_justseen(ordered, key=key)
515
516
def iter_except(func, exception, first=None):
    """Call *func* repeatedly, yielding its results until *exception* is raised.

    This turns a call-until-exception interface into an iterator, much like
    ``iter(func, sentinel)`` but ending on an exception instead of a value.
    If *first* is given, it is called once before *func* and its result is
    yielded first.

    >>> l = [0, 1, 2]
    >>> list(iter_except(l.pop, IndexError))
    [2, 1, 0]

    A tuple of exceptions can be given as the stopping condition:

    >>> l = [1, 2, 3, '...', 4, 5, 6]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [7, 6, 5]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [4, 3, 2]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    []

    """
    calls = repeat(func)
    if first is not None:
        calls = chain((first,), calls)
    # The expected exception, raised by first() or func(), ends the iteration.
    with suppress(exception):
        for call in calls:
            yield call()
544
545
def first_true(iterable, default=None, pred=None):
    """
    Return the first truthy value in *iterable*, or *default* if none exists.

    When *pred* is given, return the first item for which ``pred(item)`` is
    true instead.

    >>> first_true(range(10))
    1
    >>> first_true(range(10), pred=lambda x: x > 5)
    6
    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
    'missing'

    """
    matches = filter(pred, iterable)
    return next(matches, default)
564
565
def random_product(*args, repeat=1):
    """Pick one item at random from each input iterable.

    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
    ('c', 3, 'Z')

    With the keyword argument *repeat*, the whole set of inputs is sampled
    that many times in a row:

    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
    ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*args, repeat=repeat)``.

    """
    materialized = [tuple(arg) for arg in args]
    return tuple(map(choice, materialized * repeat))
584
585
def random_permutation(iterable, r=None):
    """Return a random *r*-length permutation of the elements of *iterable*.

    *r* defaults to the length of *iterable* when omitted or ``None``.

    >>> random_permutation(range(5))  # doctest:+SKIP
    (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    if r is None:
        r = len(pool)
    return tuple(sample(pool, r))
602
603
def random_combination(iterable, r):
    """Return a random *r*-length subsequence of the elements of *iterable*.

    Selected elements keep their original relative order.

    >>> random_combination(range(5), 3)  # doctest:+SKIP
    (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    chosen = sample(range(len(pool)), r)
    chosen.sort()  # restore input order
    return tuple(map(pool.__getitem__, chosen))
618
619
def random_combination_with_replacement(iterable, r):
    """Return a random *r*-length subsequence of elements of *iterable*,
    where the same element may be chosen more than once.

    >>> random_combination_with_replacement(range(3), 5)  # doctest:+SKIP
    (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    size = len(pool)
    picks = sorted([randrange(size) for _ in range(r)])
    return tuple(pool[i] for i in picks)
635
636
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

    >>> nth_combination(range(5), 3, 5)
    (0, 3, 4)

    ``ValueError`` will be raised if *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    Negative values of *index* count back from the last combination.
    """
    pool = tuple(iterable)
    n = len(pool)
    if (r < 0) or (r > n):
        raise ValueError

    # c = comb(n, r), the total number of r-length combinations. It is built
    # from the smaller of r and n - r to keep the loop short.
    c = 1
    k = min(r, n - r)
    for i in range(1, k + 1):
        c = c * (n - k + i) // i

    # Negative indexes count from the end, as with sequence indexing.
    if index < 0:
        index += c

    if (index < 0) or (index >= c):
        raise IndexError

    # Choose the output elements left to right. c is rescaled with exact
    # integer arithmetic to count the combinations that start with the
    # current candidate; while index lies beyond that block, skip the
    # candidate and move on to the next element of pool.
    result = []
    while r:
        c, n, r = c * r // n, n - 1, r - 1
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)
677
678
def prepend(value, iterator):
    """Yield *value* first, then every element of *iterator*.

    >>> value = '0'
    >>> iterator = ['1', '2', '3']
    >>> list(prepend(value, iterator))
    ['0', '1', '2', '3']

    To prepend multiple values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterator)
692
693
def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

    >>> list(convolve([1, -1, -20], [1, -3]))
    [1, -4, -17, 60]

    Examples of popular kinds of kernels:

    * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      For image data, this blurs the image and reduces noise.
    * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function evaluated at evenly spaced inputs.
    * The kernel ``[1, -2, 1]`` estimates the second derivative of a
      function evaluated at evenly spaced inputs.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently. The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article:  https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown:  https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # This implementation comes from an older version of the itertools
    # documentation. While the newer implementation is a bit clearer,
    # this one was kept because the inlined window logic is faster
    # and it avoids an unnecessary deque-to-tuple conversion.
    # Reverse the kernel so a plain sum-of-products with the sliding window
    # (oldest value first) yields each convolution term.
    kernel = tuple(kernel)[::-1]
    n = len(kernel)
    # Start with n zeros; maxlen=n makes append() drop the oldest value.
    window = deque([0], maxlen=n) * n
    # The n - 1 trailing zeros flush the tail of the output once the signal
    # ends, giving len(signal) + n - 1 outputs in total.
    for x in chain(signal, repeat(0, n - 1)):
        window.append(x)
        yield _sumprod(kernel, window)
735
736
def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that allows complete access to the
    remainder of the iterator.

    >>> it = iter('ABCdEfGhI')
    >>> all_upper, remainder = before_and_after(str.isupper, it)
    >>> ''.join(all_upper)
    'ABC'
    >>> ''.join(remainder)  # takewhile() would lose the 'd'
    'dEfGhI'

    Note that the first iterator must be fully consumed before the second
    iterator can generate valid results.
    """
    trues, after = tee(it)
    # compress() pulls one selector from zip(after) for every item that
    # takewhile() accepts; the 1-tuples are always truthy, so nothing is
    # filtered out. The only effect is to advance `after` in step with the
    # accepted items. When takewhile() stops, it has already read the first
    # rejected item from `trues`, and tee() keeps that item buffered for
    # `after`. So `after` resumes exactly at the first rejected item.
    trues = compress(takewhile(predicate, trues), zip(after))
    return trues, after
754
755
def triplewise(iterable):
    """Return successive overlapping triples from *iterable*.

    >>> list(triplewise('ABCDE'))
    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # This deviates from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Stagger the copies: `second` starts one item ahead, `third` two ahead.
    next(second, None)
    next(third, None)
    next(third, None)
    return zip(first, second, third)
770
771
def _sliding_window_islice(iterable, n):
    # Fast path for small, non-zero values of n: n independent copies of the
    # input, each advanced one step further than the previous, zipped back
    # together.
    copies = tee(iterable, n)
    for offset, copy in enumerate(copies):
        # Skip `offset` items by asking for an empty slice that starts there.
        next(islice(copy, offset, offset), None)
    return zip(*copies)
778
779
def _sliding_window_deque(iterable, n):
    # Normal path for other values of n: keep the latest n items in a
    # bounded deque and emit a snapshot after each new item arrives.
    source = iter(iterable)
    window = deque(islice(source, n - 1), maxlen=n)
    for item in source:
        window.append(item)
        yield tuple(window)
787
788
def sliding_window(iterable, n):
    """Yield overlapping windows of width *n* from *iterable* as tuples.

    >>> list(sliding_window(range(6), 4))
    [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]

    Nothing is yielded when *iterable* holds fewer than *n* items:

    >>> list(sliding_window(range(3), 4))
    []

    ``ValueError`` is raised when *n* is less than one.

    For a variant with more features, see :func:`windowed`.
    """
    # Choose the cheapest strategy for the requested width.
    if n > 20:
        return _sliding_window_deque(iterable, n)
    if n > 2:
        return _sliding_window_islice(iterable, n)
    if n == 2:
        return pairwise(iterable)
    if n == 1:
        return zip(iterable)
    raise ValueError(f'n should be at least one, not {n}')
812
813
def subslices(iterable):
    """Return every contiguous, non-empty slice of *iterable* as a list.

    >>> list(subslices('ABC'))
    [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This is similar to :func:`substrings`, but emits items in a different
    order.
    """
    seq = list(iterable)
    # Every (start, stop) pair with start < stop over the slice boundaries.
    bounds = combinations(range(len(seq) + 1), 2)
    return map(seq.__getitem__, starmap(slice, bounds))
826
827
def polynomial_from_roots(roots):
    """Return the coefficients of the monic polynomial with the given roots.

    >>> roots = [5, -4, 3]  # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
    [1, -4, -17, 60]

    Coefficients are listed from the highest power down.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """

    # This recipe differs from the one in itertools docs in that it
    # applies list() after each call to convolve(). This avoids
    # hitting stack limits with nested generators.

    coefficients = [1]
    for root in roots:
        # Multiply the running product by the linear factor (x - root).
        coefficients = list(convolve(coefficients, (1, -root)))
    return coefficients
848
849
def iter_index(iterable, value, start=0, stop=None):
    """Yield every index at which *value* occurs in *iterable*, searching
    from index *start* (inclusive) up to index *stop* (exclusive).


    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
    [1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
    [1, 4]

    For non-scalar *values*, matching follows the input type's own
    ``index()`` method.

    >>> list(iter_index('ABCDABCD', 'AB'))
    [0, 4]
    >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
    []
    >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
    [0, 2]

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    """
    seq_index = getattr(iterable, 'index', None)
    if seq_index is None:
        # General iterables: compare every element within the window.
        window = islice(iterable, start, stop)
        for position, element in enumerate(window, start):
            if element is value or element == value:
                yield position
        return

    # Sequences: let the type's index() method do the searching. It raises
    # ValueError once no further match exists, which ends the generator.
    if stop is None:
        stop = len(iterable)
    position = start - 1
    with suppress(ValueError):
        while True:
            position = seq_index(value, position + 1, stop)
            yield position
889
890
def sieve(n):
    """Yield the primes less than n.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    """
    # This implementation comes from an older version of the itertools
    # documentation. The newer implementation is easier to read but is
    # less lazy.
    if n > 2:
        yield 2
    start = 3
    # data[i] == 1 marks i as a prime candidate. The (0, 1) pattern leaves
    # every even index unmarked, so only odd numbers are candidates.
    data = bytearray((0, 1)) * (n // 2)
    for p in iter_index(data, 1, start, stop=isqrt(n) + 1):
        # Every multiple of a smaller prime is already cleared, so any marked
        # index below p * p is prime and can be yielded now.
        yield from iter_index(data, 1, start, p * p)
        # Clear the odd multiples of p from p * p upward (step 2p skips the
        # even multiples, which are already unmarked).
        data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
        start = p * p
    # Whatever is still marked past the last start is prime.
    yield from iter_index(data, 1, start)
910
911
def _batched(iterable, n, *, strict=False):  # pragma: no cover
    """Group the items of *iterable* into tuples of length *n*. When the
    number of items is not a multiple of *n*:
    * the final batch is shorter if *strict* is ``False``;
    * :exc:`ValueError` is raised if *strict* is ``True``.

    >>> list(batched('ABCDEFG', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

    On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        batch = tuple(islice(source, n))
        if not batch:
            return
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch
930
931
# Select the batched() implementation by interpreter version. On Python
# 3.13.0a2 and newer, delegate to itertools.batched, forwarding ``strict``.
# On older interpreters, use the pure-Python _batched defined above.
if hexversion >= 0x30D00A2:  # pragma: no cover
    from itertools import batched as itertools_batched

    def batched(iterable, n, *, strict=False):
        return itertools_batched(iterable, n, strict=strict)

    # Reuse _batched's documentation, including its doctest.
    batched.__doc__ = _batched.__doc__
else:  # pragma: no cover
    batched = _batched
941
942
943def transpose(it):
944 """Swap the rows and columns of the input matrix.
945
946 >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
947 [(1, 11), (2, 22), (3, 33)]
948
949 The caller should ensure that the dimensions of the input are compatible.
950 If the input is empty, no output will be produced.
951 """
952 return zip(*it, strict=True)
953
954
def _is_scalar(value, stringlike=(str, bytes)):
    "Scalars are bytes, strings, and non-iterables."
    try:
        iter(value)
    except TypeError:
        is_iterable = False
    else:
        is_iterable = True
    # Strings and bytes are iterable but still count as scalars.
    return not is_iterable or isinstance(value, stringlike)
962
963
def _flatten_tensor(tensor):
    "Depth-first iterator over scalars in a tensor."
    iterator = iter(tensor)
    while True:
        try:
            value = next(iterator)
        except StopIteration:
            # Nothing at this depth: return the exhausted iterator.
            return iterator
        # Peek at the first value, then push it back onto the front.
        iterator = chain((value,), iterator)
        if _is_scalar(value):
            # This dimension holds scalars, so the stream is fully flat.
            return iterator
        # The values are sub-arrays: splice them together to go one level
        # deeper, then peek again on the next loop iteration.
        iterator = chain.from_iterable(iterator)
976
977
def reshape(matrix, shape):
    """Rearrange the values of *matrix* into a new *shape*.

    An integer *shape* requires a two-dimensional matrix and gives the
    desired number of columns:

    >>> matrix = [(0, 1), (2, 3), (4, 5)]
    >>> cols = 3
    >>> list(reshape(matrix, cols))
    [(0, 1, 2), (3, 4, 5)]

    A tuple (or other iterable) *shape* allows any number of dimensions on
    both sides. The input is flattened first and then regrouped into the
    requested, possibly multidimensional, shape:

    >>> matrix = [(0, 1), (2, 3), (4, 5)]  # Start with a 3 x 2 matrix

    >>> list(reshape(matrix, (2, 3)))  # Make a 2 x 3 matrix
    [(0, 1, 2), (3, 4, 5)]

    >>> list(reshape(matrix, (6,)))  # Make a vector of length six
    [0, 1, 2, 3, 4, 5]

    >>> list(reshape(matrix, (2, 1, 3, 1)))  # Make 2 x 1 x 3 x 1 tensor
    [(((0,), (1,), (2,)),), (((3,), (4,), (5,)),)]

    Every dimension must be uniform: either all sub-arrays or all scalars.
    Flattening stops at the first dimension whose leading value is a scalar.
    Scalars are bytes, strings, and non-iterables.
    Iteration ends when the requested shape is complete or the input runs
    out, whichever comes first.

    """
    if isinstance(shape, int):
        values = chain.from_iterable(matrix)
        return batched(values, shape)

    first_dim, *inner_dims = shape
    stream = _flatten_tensor(matrix)
    # Group from the innermost dimension outward.
    for dim in reversed(inner_dims):
        stream = batched(stream, dim)
    return islice(stream, first_dim)
1017
1018
def matmul(m1, m2):
    """Return the matrix product of *m1* and *m2*, one row tuple at a time.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The inner dimensions of the two matrices must agree; that is the
    caller's responsibility.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    width = len(m2[0])
    columns = transpose(m2)
    # Pair every row of m1 with every column of m2, in row-major order.
    row_column_pairs = product(m1, columns)
    cells = starmap(_sumprod, row_column_pairs)
    return batched(cells, width)
1032
1033
def _factor_pollard(n):
    # Find a non-trivial factor of n with Pollard's rho algorithm, using
    # Floyd cycle detection on the sequence x -> x*x + offset (mod n).
    # Efficient when n is odd and composite.
    for offset in range(1, n):
        tortoise = hare = 2
        divisor = 1
        while divisor == 1:
            tortoise = (tortoise * tortoise + offset) % n
            hare = (hare * hare + offset) % n
            hare = (hare * hare + offset) % n
            divisor = gcd(tortoise - hare, n)
        # divisor == n means this offset failed; retry with the next one.
        if divisor != n:
            return divisor
    raise ValueError('prime or under 5')  # pragma: no cover
1048
1049
# Primes below 211, used by factor() for trial division. 211 is itself
# prime, so any cofactor left after removing these primes that is below
# 211**2 has no smaller prime factor and must be prime.
_primes_below_211 = tuple(sieve(211))
1051
1052
def factor(n):
    """Yield the prime factors of n.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]

    Finds small factors with trial division.  Larger factors are
    either verified as prime with ``is_prime`` or split into
    smaller factors with Pollard's rho algorithm.

    Factors are yielded in ascending order. Nothing is yielded for n < 2.
    """

    # Corner case reduction
    if n < 2:
        return

    # Trial division reduction
    for prime in _primes_below_211:
        while not n % prime:
            yield prime
            n //= prime

    # Pollard's rho reduction
    primes = []
    todo = [n] if n > 1 else []
    # `todo` deliberately grows while being iterated: each composite is split
    # by _factor_pollard() and both parts are appended for later processing.
    for n in todo:
        # Every remaining factor is >= 211, so a value below 211**2 cannot
        # be composite.
        if n < 211**2 or is_prime(n):
            primes.append(n)
        else:
            fact = _factor_pollard(n)
            todo += (fact, n // fact)
    yield from sorted(primes)
1084
1085
def polynomial_eval(coefficients, x):
    """Return the value of the polynomial with the given *coefficients* at *x*.

    Computes with better numeric stability than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125

    Coefficients are listed from the highest power down. An empty
    coefficient list evaluates to zero of the same type as *x*.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    length = len(coefficients)
    if length == 0:
        return type(x)(0)
    # Exponents run from the leading degree down to zero.
    powers = (x ** exponent for exponent in range(length - 1, -1, -1))
    return _sumprod(coefficients, powers)
1107
1108
def sum_of_squares(it):
    """Return the sum of the squares of the values in *it*.

    >>> sum_of_squares([10, 20, 30])
    1400

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Two independent copies of the input, multiplied pairwise and summed.
    left, right = tee(it)
    return _sumprod(left, right)
1118
1119
def polynomial_derivative(coefficients):
    """Return the coefficients of the first derivative of a polynomial.

    Differentiate ``x³ - 4 x² - 17 x + 60``:

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]

    Coefficients are listed from the highest power down; the constant term
    drops out.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    degree = len(coefficients) - 1
    exponents = range(degree, 0, -1)
    return [coefficient * exponent for coefficient, exponent in zip(coefficients, exponents)]
1137
1138
def totient(n):
    """Return the count of natural numbers up to *n* that are coprime with *n*.

    Euler's totient function φ(n) gives the number of totatives.
    Totatives are integers k in the range 1 ≤ k ≤ n such that gcd(n, k) = 1.

    >>> n = 9
    >>> totient(n)
    6

    >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
    >>> totatives
    [1, 2, 4, 5, 7, 8]
    >>> len(totatives)
    6

    The value is computed from the product formula φ(n) = n · ∏(1 - 1/p)
    over the distinct prime factors p of n, using exact integer steps.

    Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    distinct_primes = set(factor(n))
    # acc - acc // p equals acc * (1 - 1/p) exactly, because p still divides
    # acc after the other (coprime) primes have been applied.
    return reduce(lambda acc, p: acc - acc // p, distinct_primes, n)
1161
1162
# Miller–Rabin primality test: https://oeis.org/A014233
#
# Deterministic witness sets, ordered by increasing limit.  Each entry is
# (limit, bases).  is_prime() walks this list and uses the first entry whose
# limit exceeds n, running one strong-probable-prime round per base.  The
# base sets are known to make Miller–Rabin exact for every n below their
# limit (see the reference above).  For n at or beyond the last limit
# (about 3.3e24), is_prime() switches to randomly chosen bases instead.
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    # Limit is exactly 2**64; this seven-base set is not a run of primes.
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        3317044064679887385961981,
        # The first thirteen primes, 2 through 41.
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]
1177
1178
@lru_cache
def _shift_to_odd(n):
    'Split n into (s, d) with d odd and 2**s * d == n.'
    # For positive n, (n - 1) ^ n sets exactly the trailing zero bits of n
    # plus its lowest one bit, so its bit length is s + 1.
    low_bits = (n - 1) ^ n
    shift = low_bits.bit_length() - 1
    odd_part = n >> shift
    assert (1 << shift) * odd_part == n and odd_part & 1 and shift >= 0
    return shift, odd_part
1186
1187
def _strong_probable_prime(n, base):
    'One Miller-Rabin round: is odd *n* a strong probable prime to *base*?'
    assert (n > 2) and (n & 1) and (2 <= base < n)

    # Write n - 1 as 2**s * d with d odd.
    s, d = _shift_to_odd(n - 1)
    minus_one = n - 1

    x = pow(base, d, n)
    if x in (1, minus_one):
        return True

    # Square up to s - 1 more times, looking for -1 (mod n).
    for _ in range(s - 1):
        x = pow(x, 2, n)
        if x == minus_one:
            return True

    return False
1203
1204
# Separate instance of Random() that doesn't share state
# with the default user instance of Random().  is_prime() draws its
# fallback Miller-Rabin bases from this instance, so user calls such as
# random.seed() do not influence those bases, and drawing them does not
# advance the user's random stream.
_private_randrange = random.Random().randrange
1208
1209
def is_prime(n):
    """Return ``True`` if *n* is prime and ``False`` otherwise.

    Basic examples:

    >>> is_prime(37)
    True
    >>> is_prime(3 * 13)
    False
    >>> is_prime(18_446_744_073_709_551_557)
    True

    Find the next prime over one billion:

    >>> next(filter(is_prime, count(10**9)))
    1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

    >>> from random import seed, randrange, getrandbits
    >>> seed(18675309)

    >>> next(filter(is_prime, map(getrandbits, repeat(200))))
    893303929355758292373272075469392561129886005037663238028407

    >>> next(filter(is_prime, map(randrange, repeat(10**60))))
    269638077304026462407872868003560484232362454342414618963649

    Values below 17 are answered from a lookup set, and multiples of the
    primes up to 13 are rejected immediately.  Everything else goes through
    Miller-Rabin with a known deterministic witness set where one applies.

    This function is exact for values of *n* below 10**24. For larger inputs,
    the probabilistic Miller-Rabin primality test has a less than 1 in 2**128
    chance of a false positive.
    """

    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Cheap rejection of multiples of 2, 3, 5, 7, 11 and 13.
    if not n & 1 or not (n % 3 and n % 5 and n % 7 and n % 11 and n % 13):
        return False

    # First witness set whose limit is above n, if there is one.
    bases = next((ws for limit, ws in _perfect_tests if n < limit), None)
    if bases is None:
        # Beyond every deterministic limit: up to 64 random rounds.  The
        # bases are drawn lazily so all() can stop at the first failure.
        bases = (_private_randrange(2, n - 1) for _ in range(64))

    return all(_strong_probable_prime(n, b) for b in bases)
1256
1257
def loops(n):
    """Return an iterable of *n* items for efficiently looping *n* times.

    Works like ``range(n)`` in a ``for`` statement, except every item is
    ``None``, so no integers are created.

    >>> i = 0
    >>> for _ in loops(5):
    ...     i += 1
    >>> i
    5

    """
    return repeat(None, times=n)
1270
1271
def multinomial(*counts):
    """Number of distinct arrangements of a multiset.

    The expression ``multinomial(3, 4, 2)`` has several equivalent
    interpretations:

    * In the expansion of ``(a + b + c)⁹``, the coefficient of the
      ``a³b⁴c²`` term is 1260.

    * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4
      greens, and 2 blues.

    * There are 1260 unique ways to place 9 distinct objects into three bins
      with sizes 3, 4, and 2.

    The :func:`multinomial` function computes the length of
    :func:`distinct_permutations`. For example, there are 83,160 distinct
    anagrams of the word "abracadabra":

        >>> from more_itertools import distinct_permutations, ilen
        >>> ilen(distinct_permutations('abracadabra'))
        83160

    This can be computed directly from the letter counts, 5a 2b 2r 1c 1d:

        >>> from collections import Counter
        >>> list(Counter('abracadabra').values())
        [5, 2, 2, 1, 1]
        >>> multinomial(5, 2, 2, 1, 1)
        83160

    A binomial coefficient is a special case of multinomial where there are
    only two categories. For example, the number of ways to arrange 12 balls
    with 5 reds and 7 blues is ``multinomial(5, 7)`` or ``math.comb(12, 5)``.

    Likewise, factorial is a special case of multinomial where
    the multiplicities are all just 1 so that
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    With no arguments the result is 1 (the empty product).

    Reference: https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    # Place each group in turn: after adding k items to a running total,
    # there are comb(total, k) ways to choose which slots the new group fills.
    total = 0
    result = 1
    for k in counts:
        total += k
        result *= comb(total, k)
    return result
1315
1316
def _running_median_minheap_and_maxheap(iterator):  # pragma: no cover
    "Unbounded running median using native max-heaps (Python 3.14+)."

    pull = iterator.__next__
    lower = []  # max-heap holding the smaller half
    upper = []  # min-heap holding the larger half; len(upper) <= len(lower)

    try:
        while True:
            # Odd count: pass the value through `upper` so the smallest of
            # the large half moves down into `lower`; the median is lower[0].
            heappush_max(lower, heappushpop(upper, pull()))
            yield lower[0]

            # Even count: pass the value through `lower` so its largest moves
            # up into `upper`; the median is the mean of the two tops.
            heappush(upper, heappushpop_max(lower, pull()))
            yield (lower[0] + upper[0]) / 2
    except StopIteration:
        return
1331
1332
def _running_median_minheap_only(iterator):  # pragma: no cover
    "Unbounded running median without native max-heaps (Python 3.13 and prior)."

    pull = iterator.__next__
    lower = []  # smaller half, stored negated so the min-heap acts as a max-heap
    upper = []  # larger half as a min-heap; len(upper) <= len(lower)

    try:
        while True:
            # Odd count: pass the value through `upper`, then move the
            # smallest of the large half (negated) down into `lower`.
            heappush(lower, -heappushpop(upper, pull()))
            yield -lower[0]

            # Even count: pass the negated value through `lower`, then move
            # the largest of the small half (un-negated) up into `upper`.
            heappush(upper, -heappushpop(lower, -pull()))
            yield (upper[0] - lower[0]) / 2
    except StopIteration:
        return
1347
1348
def _running_median_windowed(iterator, maxlen):
    "Yield the median of the most recent *maxlen* values after each input."

    recent = deque()  # values in arrival order, oldest on the left
    sorted_vals = []  # the same values kept in sorted order

    for value in iterator:
        recent.append(value)
        insort(sorted_vals, value)

        if len(recent) > maxlen:
            # Evict the oldest value from both views of the window.
            oldest = recent.popleft()
            del sorted_vals[bisect_left(sorted_vals, oldest)]

        size = len(sorted_vals)
        mid = size // 2
        if size % 2:
            yield sorted_vals[mid]
        else:
            yield (sorted_vals[mid - 1] + sorted_vals[mid]) / 2
1366
1367
def running_median(iterable, *, maxlen=None):
    """Cumulative median of values seen so far or values in a sliding window.

    Set *maxlen* to a positive integer to specify the maximum size
    of the sliding window. The default of *None* is equivalent to
    an unbounded window.

    For example:

        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0]))
        [5.0, 7.0, 5.0, 7.0, 8.0, 8.5]
        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0], maxlen=3))
        [5.0, 7.0, 5.0, 9.0, 8.0, 9.0]

    Supports numeric types such as int, float, Decimal, and Fraction,
    but not complex numbers which are unorderable.

    *maxlen* is validated immediately: a non-integer raises ``TypeError``
    and a value that is zero or negative raises ``ValueError``.

    On Python 3.13 and prior, max-heaps are simulated with negated
    values. The negation causes Decimal inputs to apply context
    rounding, making the results slightly different from those
    returned by statistics.median().
    """

    iterator = iter(iterable)

    if maxlen is None:
        if _max_heap_available:
            return _running_median_minheap_and_maxheap(iterator)  # pragma: no cover
        return _running_median_minheap_only(iterator)  # pragma: no cover

    window_size = index(maxlen)
    if window_size <= 0:
        raise ValueError('Window size should be positive')
    return _running_median_windowed(iterator, window_size)
1403
1404
def random_derangement(iterable):
    """Return a random derangement of elements in the iterable.

    Equivalent to but much faster than ``choice(list(derangements(iterable)))``.

    Elements are deranged by *position*: no element stays at its original
    index, but equal values at different positions may still land on each
    other's slots. The result is a tuple.

    An empty iterable returns ``()`` (its only derangement). A single-element
    iterable has no derangement and raises ``IndexError``, matching
    ``choice([])``.

    """
    seq = tuple(iterable)
    if len(seq) < 2:
        if len(seq) == 0:
            return ()
        raise IndexError('No derangements to choose from')
    perm = list(range(len(seq)))
    start = tuple(perm)
    # Rejection sampling: reshuffle until no index is a fixed point.
    while True:
        shuffle(perm)
        # start and perm hold the very same int objects (shuffle only
        # reorders references), so identity via is_ is an exact and
        # cheaper stand-in for == here.
        if not any(map(is_, start, perm)):
            return itemgetter(*perm)(seq)