1"""Imported from the recipes section of the itertools documentation.
2
3All functions taken from the recipes section of the itertools library docs
4[1]_.
5Some backward-compatible usability improvements have been made.
6
.. [1] https://docs.python.org/3/library/itertools.html#itertools-recipes
8
9"""
10
11import random
12
13from bisect import bisect_left, insort
14from collections import deque
15from contextlib import suppress
16from functools import lru_cache, reduce
17from heapq import heappush, heappushpop
18from itertools import (
19 accumulate,
20 chain,
21 combinations,
22 compress,
23 count,
24 cycle,
25 groupby,
26 islice,
27 pairwise,
28 product,
29 repeat,
30 starmap,
31 takewhile,
32 tee,
33 zip_longest,
34)
35from math import prod, comb, isqrt, gcd
36from operator import mul, is_, not_, itemgetter, getitem, index
37from random import randrange, sample, choice, shuffle
38from sys import hexversion
39
# Public API of this module. Some names listed here (for example
# ``is_prime``, ``loops``, ``multinomial``, ``running_median`` and
# ``random_derangement``) are not defined in this part of the file and are
# presumably defined further down.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_derangement',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'running_median',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]

# A unique object() instance. It is not used in this part of the file;
# presumably it serves elsewhere in the module as a sentinel meaning
# "no value supplied" -- TODO confirm against its users.
_marker = object()


# heapq max-heap functions are available for Python 3.14+
# ``_max_heap_available`` records whether that import succeeded, so code
# elsewhere in the module (not visible here) can pick an implementation.
try:
    from heapq import heappush_max, heappushpop_max
except ImportError:  # pragma: no cover
    _max_heap_available = False
else:  # pragma: no cover
    _max_heap_available = True
103
104
def take(n, iterable):
    """Collect the first *n* items of *iterable* into a list.

    >>> take(3, range(10))
    [0, 1, 2]

    A shorter iterable simply contributes everything it has:

    >>> take(10, range(3))
    [0, 1, 2]

    """
    prefix = islice(iterable, n)
    return list(prefix)
119
120
def tabulate(function, start=0):
    """Return an iterator over the results of ``function(start)``,
    ``function(start + 1)``, ``function(start + 2)``...

    *function* should be a callable that accepts one integer argument.

    If *start* is not specified it defaults to 0. It is incremented each
    time the iterator is advanced.

    >>> square = lambda x: x ** 2
    >>> iterator = tabulate(square, -3)
    >>> take(4, iterator)
    [9, 4, 1, 0]

    """
    # count() supplies start, start + 1, ... lazily; map() applies
    # *function* only when the caller advances the iterator.
    return map(function, count(start))
137
138
def tail(n, iterable):
    """Iterate over the final *n* items of *iterable*.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    try:
        length = len(iterable)
    except TypeError:
        # Unsized input: stream it through a bounded deque that keeps only
        # the last n items.
        return iter(deque(iterable, maxlen=n))
    # Sized input: skip straight to the start of the final n items.
    return islice(iterable, max(0, length - n), None)
153
154
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps. If *n* is ``None``, consume it
    entirely.

    Efficiently exhausts an iterator without returning values. Defaults to
    consuming the whole iterator, but an optional second argument may be
    provided to limit consumption.

    >>> i = (x for x in range(10))
    >>> next(i)
    0
    >>> consume(i, 3)
    >>> next(i)
    4
    >>> consume(i)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    If the iterator has fewer items remaining than the provided limit, the
    whole iterator will be consumed.

    >>> i = (x for x in range(3))
    >>> consume(i, 5)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    """
    # Use functions that consume iterators at C speed.
    if n is None:
        # feed the entire iterator into a zero-length deque
        deque(iterator, maxlen=0)
    else:
        # advance to the empty slice starting at position n
        next(islice(iterator, n, n), None)
193
194
def nth(iterable, n, default=None):
    """Return the item at position *n* of *iterable*, or *default* if the
    iterable runs out first.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    remainder = islice(iterable, n, None)
    return next(remainder, default)
206
207
def all_equal(iterable, key=None):
    """
    Report whether every element of *iterable* equals every other.

    >>> all_equal('aaaa')
    True
    >>> all_equal('aaab')
    False

    Pass *key*, a one-argument function, to compare transformed values
    instead of the items themselves:

    >>> all_equal('AaaA', key=str.casefold)
    True
    >>> all_equal([1, 2, 3], key=lambda x: x < 10)
    True

    """
    # groupby() collapses runs of equal (keyed) items, so everything is equal
    # exactly when there is at most one run. Stop as soon as a second run
    # appears, without reading further.
    first_two_runs = islice(groupby(iterable, key), 2)
    return len(list(first_two_runs)) <= 1
232
233
def quantify(iterable, pred=bool):
    """Return how many times *pred* is true over the items of *iterable*.

    >>> quantify([True, False, True])
    2

    """
    verdicts = map(pred, iterable)
    return sum(verdicts)
242
243
def pad_none(iterable):
    """Yield the items of *iterable*, then yield ``None`` forever.

    >>> take(5, pad_none(range(3)))
    [0, 1, 2, None, None]

    Useful for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    endless_nones = repeat(None)
    return chain(iterable, endless_nones)
256
257
# Alternate name for :func:`pad_none`; both spellings are exported in
# ``__all__``.
padnone = pad_none
259
260
def ncycles(iterable, n):
    """Return the items of *iterable*, repeated *n* times over.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    # Snapshot the input once so one-shot iterators can be replayed.
    saved = tuple(iterable)
    return chain.from_iterable(repeat(saved, n))
269
270
def dotproduct(vec1, vec2):
    """Return the sum of the pairwise products of *vec1* and *vec2*.

    >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
    33.5
    >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
    33.5

    In Python 3.12 and later, use ``math.sumprod()`` instead.
    """
    products = map(mul, vec1, vec2)
    return sum(products)
282
283
# math.sumprod is available for Python 3.12+
# Later recipes (convolve, matmul, polynomial_eval, sum_of_squares) call
# ``_sumprod`` for sums of products; on older interpreters it falls back to
# the pure-Python ``dotproduct`` defined above.
try:
    from math import sumprod as _sumprod
except ImportError:  # pragma: no cover
    _sumprod = dotproduct
289
290
def flatten(listOfLists):
    """Chain the sub-iterables of *listOfLists* into one flat iterator,
    removing exactly one level of nesting.

    >>> list(flatten([[0, 1], [2, 3]]))
    [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    flat = chain.from_iterable(listOfLists)
    return flat
301
302
def repeatfunc(func, times=None, *args):
    """Return an iterator that calls ``func(*args)`` over and over and yields
    each result.

    When *times* is given, the iterator stops after that many calls:

    >>> from operator import add
    >>> times = 4
    >>> args = 3, 5
    >>> list(repeatfunc(add, times, *args))
    [8, 8, 8, 8]

    When *times* is ``None``, the iterator never ends:

    >>> from random import randrange
    >>> times = None
    >>> args = 1, 11
    >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
    [2, 4, 8, 1, 8, 4]

    """
    arg_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(func, arg_stream)
328
329
330def grouper(iterable, n, incomplete='fill', fillvalue=None):
331 """Group elements from *iterable* into fixed-length groups of length *n*.
332
333 >>> list(grouper('ABCDEF', 3))
334 [('A', 'B', 'C'), ('D', 'E', 'F')]
335
336 The keyword arguments *incomplete* and *fillvalue* control what happens for
337 iterables whose length is not a multiple of *n*.
338
339 When *incomplete* is `'fill'`, the last group will contain instances of
340 *fillvalue*.
341
342 >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
343 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
344
345 When *incomplete* is `'ignore'`, the last group will not be emitted.
346
347 >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
348 [('A', 'B', 'C'), ('D', 'E', 'F')]
349
350 When *incomplete* is `'strict'`, a `ValueError` will be raised.
351
352 >>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
353 >>> list(iterator) # doctest: +IGNORE_EXCEPTION_DETAIL
354 Traceback (most recent call last):
355 ...
356 ValueError
357
358 """
359 iterators = [iter(iterable)] * n
360 match incomplete:
361 case 'fill':
362 return zip_longest(*iterators, fillvalue=fillvalue)
363 case 'strict':
364 return zip(*iterators, strict=True)
365 case 'ignore':
366 return zip(*iterators)
367 case _:
368 raise ValueError('Expected fill, strict, or ignore')
369
370
def roundrobin(*iterables):
    """Take one item from each input iterable in turn, dropping each input
    once it runs dry, until all are exhausted.

    >>> list(roundrobin('ABC', 'D', 'EF'))
    ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`, but
    may perform better for some inputs (in particular when the number of
    iterables is small).

    """
    # Algorithm credited to George Sakkis
    remaining = len(iterables)
    nexts = map(iter, iterables)
    while remaining:
        # Cycle over the live iterators. When one is exhausted, map(next, ...)
        # stops, and the cycle is left positioned just past the dead
        # iterator. Re-slicing it then keeps only the survivors.
        nexts = cycle(islice(nexts, remaining))
        yield from map(next, nexts)
        remaining -= 1
387
388
def partition(pred, iterable):
    """
    Split *iterable* into two lazy iterables by the truth of *pred*.
    The first yields the items for which ``pred(item) == False``.
    The second yields the items for which ``pred(item) == True``.

    >>> is_odd = lambda x: x % 2 != 0
    >>> iterable = range(10)
    >>> even_items, odd_items = partition(is_odd, iterable)
    >>> list(even_items), list(odd_items)
    ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    If *pred* is None, :func:`bool` is used.

    >>> iterable = [0, 1, False, True, '', ' ']
    >>> false_items, true_items = partition(None, iterable)
    >>> list(false_items), list(true_items)
    ([0, False, ''], [1, True, ' '])

    """
    if pred is None:
        pred = bool

    # Three copies of the input: one for each output and one to evaluate
    # pred exactly once per item. The verdicts are then shared by both sides.
    left_src, right_src, probe = tee(iterable, 3)
    flags_a, flags_b = tee(map(pred, probe))
    falses = compress(left_src, map(not_, flags_a))
    trues = compress(right_src, flags_b)
    return (falses, trues)
415
416
def powerset(iterable):
    """Yield every subset of the items in *iterable*, ordered by size.

    >>> list(powerset([1, 2, 3]))
    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    :func:`powerset` works on arbitrary iterables, not just :class:`set`
    instances, so repeated input elements produce repeated output elements.

    >>> seq = [1, 1, 0]
    >>> list(powerset(seq))
    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.
    """
    pool = list(iterable)
    sizes = range(len(pool) + 1)
    return chain.from_iterable(combinations(pool, size) for size in sizes)
436
437
def unique_everseen(iterable, key=None):
    """
    Yield each distinct element the first time it appears, keeping the
    original order.

    >>> list(unique_everseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D']
    >>> list(unique_everseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'D']

    Inputs that mix hashable and unhashable items are supported, but
    unhashable items are tracked in a list, which is slower (i.e., `O(n^2)`).

    ``list`` objects are unhashable; passing ``key=tuple`` turns them into
    hashable tuples and avoids the slowdown.

    >>> iterable = ([1, 2], [2, 3], [1, 2])
    >>> list(unique_everseen(iterable))  # Slow
    [[1, 2], [2, 3]]
    >>> list(unique_everseen(iterable, key=tuple))  # Faster
    [[1, 2], [2, 3]]

    Likewise, unhashable ``set`` objects can use ``key=frozenset``, and
    ``dict`` objects can use ``key=lambda x: frozenset(x.items())``.

    """
    hashable_seen = set()
    remember_hashable = hashable_seen.add
    unhashable_seen = []
    remember_unhashable = unhashable_seen.append
    has_key = key is not None

    for item in iterable:
        marker = key(item) if has_key else item
        try:
            if marker not in hashable_seen:
                remember_hashable(marker)
                yield item
        except TypeError:
            # marker is unhashable: fall back to a linear scan of a list.
            if marker not in unhashable_seen:
                remember_unhashable(marker)
                yield item
481
482
def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing each run of consecutive
    duplicates into its first element.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    """
    if key is not None:
        # With a key, the group key is the transformed value; take the first
        # original element from each group instead.
        groups = groupby(iterable, key)
        return map(next, map(itemgetter(1), groups))

    return map(itemgetter(0), groupby(iterable))
496
497
def unique(iterable, key=None, reverse=False):
    """Yield the distinct elements of *iterable* in sorted order.

    >>> list(unique([[1, 2], [3, 4], [1, 2]]))
    [[1, 2], [3, 4]]

    *key* and *reverse* are forwarded to :func:`sorted`.

    >>> list(unique('ABBcCAD', str.casefold))
    ['A', 'B', 'c', 'D']
    >>> list(unique('ABBcCAD', str.casefold, reverse=True))
    ['D', 'c', 'B', 'A']

    Elements need not be hashable, but they must be mutually comparable so
    they can be sorted.
    """
    # Sorting brings equal elements together, so dropping adjacent
    # duplicates leaves exactly one of each.
    ordered = sorted(iterable, key=key, reverse=reverse)
    return unique_justseen(ordered, key=key)
516
517
def iter_except(func, exception, first=None):
    """Call *func* repeatedly, yielding its results until it raises
    *exception*.

    This turns a call-until-exception interface into an iterator. It is
    like ``iter(func, sentinel)``, except that an exception rather than a
    sentinel value ends the loop.

    >>> l = [0, 1, 2]
    >>> list(iter_except(l.pop, IndexError))
    [2, 1, 0]

    A tuple of exceptions may be given as the stopping condition:

    >>> l = [1, 2, 3, '...', 4, 5, 6]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [7, 6, 5]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [4, 3, 2]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    []

    If *first* is given, it is called once before *func*, for example to
    perform set-up.

    """
    calls = repeat(func)
    if first is not None:
        calls = chain((first,), calls)
    with suppress(exception):
        for call in calls:
            yield call()
545
546
def first_true(iterable, default=None, pred=None):
    """
    Return the first truthy value in *iterable*, or *default* if there is
    none.

    When *pred* is given, return the first item for which
    ``pred(item) == True`` instead.

    >>> first_true(range(10))
    1
    >>> first_true(range(10), pred=lambda x: x > 5)
    6
    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
    'missing'

    """
    candidates = filter(pred, iterable)
    return next(candidates, default)
565
566
def random_product(*args, repeat=1):
    """Draw an item at random from each of the input iterables.

    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
    ('c', 3, 'Z')

    If *repeat* is provided as a keyword argument, that many items will be
    drawn from each iterable.

    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
    ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*args, repeat=repeat)``.

    """
    # Materialize each input once so random.choice can index it. Multiplying
    # the list repeats the whole sequence of pools *repeat* times, which
    # matches the column order of itertools.product(..., repeat=repeat).
    pools = [tuple(pool) for pool in args] * repeat
    return tuple(choice(pool) for pool in pools)
585
586
def random_permutation(iterable, r=None):
    """Return a random *r* length permutation of the elements in *iterable*.

    If *r* is not specified or is ``None``, then *r* defaults to the length of
    *iterable*.

    >>> random_permutation(range(5))  # doctest:+SKIP
    (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    r = len(pool) if r is None else r
    # random.sample draws r distinct positions in random order, which is
    # exactly a random r-permutation of the pool.
    return tuple(sample(pool, r))
603
604
def random_combination(iterable, r):
    """Return a random *r* length subsequence of the elements in *iterable*.

    >>> random_combination(range(5), 3)  # doctest:+SKIP
    (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    n = len(pool)
    # Sorting the sampled positions preserves the input order, as
    # itertools.combinations does.
    indices = sorted(sample(range(n), r))
    return tuple(pool[i] for i in indices)
619
620
def random_combination_with_replacement(iterable, r):
    """Return a random *r* length subsequence of elements in *iterable*,
    allowing individual elements to be repeated.

    >>> random_combination_with_replacement(range(3), 5)  # doctest:+SKIP
    (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    n = len(pool)
    # Independent draws (repeats allowed), sorted to preserve input order.
    indices = sorted(randrange(n) for _ in range(r))
    return tuple(pool[i] for i in indices)
636
637
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

    >>> nth_combination(range(5), 3, 5)
    (0, 3, 4)

    Negative values of *index* count from the end, as with list indexing.

    ``ValueError`` will be raised if *r* is negative or greater than the
    length of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if (r < 0) or (r > n):
        raise ValueError('r must be between 0 and the length of iterable')

    # Total number of r-combinations of the pool.
    c = comb(n, r)

    if index < 0:
        index += c

    if (index < 0) or (index >= c):
        raise IndexError('index out of range')

    result = []
    while r:
        # c becomes the number of combinations that start with the current
        # candidate element. While index lies beyond that block, skip the
        # candidate and shrink c to the count for the next one.
        c, n, r = c * r // n, n - 1, r - 1
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)
678
679
def prepend(value, iterator):
    """Yield *value* first, then everything produced by *iterator*.

    >>> value = '0'
    >>> iterator = ['1', '2', '3']
    >>> list(prepend(value, iterator))
    ['0', '1', '2', '3']

    To prepend multiple values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterator)
693
694
def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

    >>> list(convolve([1, -1, -20], [1, -3]))
    [1, -4, -17, 60]

    Examples of popular kinds of kernels:

    * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      For image data, this blurs the image and reduces noise.
    * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function evaluated at evenly spaced inputs.
    * The kernel ``[1, -2, 1]`` estimates the second derivative of a
      function evaluated at evenly spaced inputs.

    Convolutions are mathematically commutative, but the two inputs are
    treated differently here. The signal is consumed lazily and may be
    infinite. The kernel is read completely before any output is produced.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article: https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown: https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # This implementation comes from an older version of the itertools
    # documentation. While the newer implementation is a bit clearer,
    # this one was kept because the inlined window logic is faster
    # and it avoids an unnecessary deque-to-tuple conversion.
    flipped = tuple(kernel)[::-1]
    width = len(flipped)
    # Start with a window of zeros and flush with width - 1 trailing zeros,
    # so the output covers the full (len(signal) + width - 1) support.
    window = deque([0], maxlen=width) * width
    padded = chain(signal, repeat(0, width - 1))
    for value in padded:
        window.append(value)
        yield _sumprod(flipped, window)
736
737
def before_and_after(predicate, it):
    """Like :func:`takewhile`, but also gives access to the rest of the
    iterator, including the first item that failed *predicate*.

    >>> it = iter('ABCdEfGhI')
    >>> all_upper, remainder = before_and_after(str.isupper, it)
    >>> ''.join(all_upper)
    'ABC'
    >>> ''.join(remainder)  # takewhile() would lose the 'd'
    'dEfGhI'

    Note that the first iterator must be fully consumed before the second
    iterator can generate valid results.
    """
    # Two views of one stream. compress() pulls one selector from
    # zip(remainder) for each item that takewhile() accepts; those 1-tuples
    # are always truthy. That steps the remainder past exactly the accepted
    # prefix and leaves the rejected item in place.
    head, remainder = tee(it)
    head = compress(takewhile(predicate, head), zip(remainder))
    return head, remainder
755
756
def triplewise(iterable):
    """Yield each run of three consecutive items from *iterable*.

    >>> list(triplewise('ABCDE'))
    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # This deviates from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Stagger the copies by 0, 1 and 2 positions.
    next(second, None)
    next(third, None)
    next(third, None)
    return zip(first, second, third)
771
772
773def _sliding_window_islice(iterable, n):
774 # Fast path for small, non-zero values of n.
775 iterators = tee(iterable, n)
776 for i, iterator in enumerate(iterators):
777 next(islice(iterator, i, i), None)
778 return zip(*iterators)
779
780
781def _sliding_window_deque(iterable, n):
782 # Normal path for other values of n.
783 iterator = iter(iterable)
784 window = deque(islice(iterator, n - 1), maxlen=n)
785 for x in iterator:
786 window.append(x)
787 yield tuple(window)
788
789
790def sliding_window(iterable, n):
791 """Return a sliding window of width *n* over *iterable*.
792
793 >>> list(sliding_window(range(6), 4))
794 [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]
795
796 If *iterable* has fewer than *n* items, then nothing is yielded:
797
798 >>> list(sliding_window(range(3), 4))
799 []
800
801 For a variant with more features, see :func:`windowed`.
802 """
803 if n > 20:
804 return _sliding_window_deque(iterable, n)
805 elif n > 2:
806 return _sliding_window_islice(iterable, n)
807 elif n == 2:
808 return pairwise(iterable)
809 elif n == 1:
810 return zip(iterable)
811 else:
812 raise ValueError(f'n should be at least one, not {n}')
813
814
def subslices(iterable):
    """Yield every contiguous, non-empty slice of *iterable* as a list.

    >>> list(subslices('ABC'))
    [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This is similar to :func:`substrings`, but yields items in a different
    order.
    """
    items = list(iterable)
    # Each pair lo < hi of cut points in 0..len(items) is one slice.
    cut_points = combinations(range(len(items) + 1), 2)
    windows = starmap(slice, cut_points)
    return map(getitem, repeat(items), windows)
827
828
def polynomial_from_roots(roots):
    """Build a polynomial's coefficients from its roots.

    >>> roots = [5, -4, 3]  # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
    [1, -4, -17, 60]

    Coefficients are listed from the highest power down.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """

    # This recipe differs from the one in itertools docs in that it
    # applies list() after each call to convolve(). This avoids
    # hitting stack limits with nested generators.

    coefficients = [1]
    for root in roots:
        # Multiply the running product by the linear factor (x - root).
        linear_term = (1, -root)
        coefficients = list(convolve(coefficients, linear_term))
    return coefficients
849
850
def iter_index(iterable, value, start=0, stop=None):
    """Yield the index of each place in *iterable* that *value* occurs,
    beginning with index *start* and ending before index *stop*.

    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
    [1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
    [1, 4]

    Objects with an ``index()`` method are searched with that method, so for
    non-scalar *values* the behavior matches the built-in Python types.

    >>> list(iter_index('ABCDABCD', 'AB'))
    [0, 4]
    >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
    []
    >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
    [0, 2]

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    """
    finder = getattr(iterable, 'index', None)
    if finder is not None:
        # Fast path for sequences: let the sequence's own index() method do
        # the scanning, resuming one past each hit until it raises
        # ValueError.
        end = len(iterable) if stop is None else stop
        position = start - 1
        with suppress(ValueError):
            while True:
                position = finder(value, position + 1, end)
                yield position
        return

    # Slow path for general iterables: compare element by element.
    window = islice(iterable, start, stop)
    for position, element in enumerate(window, start):
        if element is value or element == value:
            yield position
890
891
def sieve(n):
    """Yield the primes less than n.

    Primes are produced in increasing order, and lazily: each batch is
    yielded as soon as it is known to be prime.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    """
    # This implementation comes from an older version of the itertools
    # documentation. The newer implementation is easier to read but is
    # less lazy.
    if n > 2:
        yield 2
    start = 3
    # data[i] == 1 marks i as a candidate. Only odd indices start as
    # candidates, so even numbers never need crossing out. Index 1 is never
    # reported because scanning always begins at start >= 3.
    data = bytearray((0, 1)) * (n // 2)
    # Only primes p <= isqrt(n) can have multiples below n that are not
    # already crossed out by a smaller prime.
    for p in iter_index(data, 1, start, stop=isqrt(n) + 1):
        # Every composite below p * p has a prime factor smaller than p and
        # has already been crossed out, so all remaining candidates in
        # [start, p * p) are prime. Emit them now to stay lazy.
        yield from iter_index(data, 1, start, p * p)
        # Cross out the odd multiples of p from p * p upward. The step of
        # 2 * p skips even multiples, which were never candidates.
        data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
        start = p * p
    # Every candidate left after the last sieving prime is prime.
    yield from iter_index(data, 1, start)
911
912
913def _batched(iterable, n, *, strict=False): # pragma: no cover
914 """Batch data into tuples of length *n*. If the number of items in
915 *iterable* is not divisible by *n*:
916 * The last batch will be shorter if *strict* is ``False``.
917 * :exc:`ValueError` will be raised if *strict* is ``True``.
918
919 >>> list(batched('ABCDEFG', 3))
920 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
921
922 On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
923 """
924 if n < 1:
925 raise ValueError('n must be at least one')
926 iterator = iter(iterable)
927 while batch := tuple(islice(iterator, n)):
928 if strict and len(batch) != n:
929 raise ValueError('batched(): incomplete batch')
930 yield batch
931
932
# Select the public ``batched``. From hexversion 0x30D00A2 (Python 3.13.0a2)
# onward, delegate to itertools.batched and pass *strict* through; otherwise
# fall back to the pure-Python ``_batched`` defined above.
if hexversion >= 0x30D00A2:  # pragma: no cover
    from itertools import batched as itertools_batched

    # A thin wrapper is used rather than a bare alias, so the pure-Python
    # docstring can be attached below. Presumably this is because the
    # built-in's own ``__doc__`` cannot be reassigned.
    def batched(iterable, n, *, strict=False):
        return itertools_batched(iterable, n, strict=strict)

    batched.__doc__ = _batched.__doc__
else:  # pragma: no cover
    batched = _batched
942
943
944def transpose(it):
945 """Swap the rows and columns of the input matrix.
946
947 >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
948 [(1, 11), (2, 22), (3, 33)]
949
950 The caller should ensure that the dimensions of the input are compatible.
951 If the input is empty, no output will be produced.
952 """
953 return zip(*it, strict=True)
954
955
956def _is_scalar(value, stringlike=(str, bytes)):
957 "Scalars are bytes, strings, and non-iterables."
958 try:
959 iter(value)
960 except TypeError:
961 return True
962 return isinstance(value, stringlike)
963
964
965def _flatten_tensor(tensor):
966 "Depth-first iterator over scalars in a tensor."
967 iterator = iter(tensor)
968 while True:
969 try:
970 value = next(iterator)
971 except StopIteration:
972 return iterator
973 iterator = chain((value,), iterator)
974 if _is_scalar(value):
975 return iterator
976 iterator = chain.from_iterable(iterator)
977
978
def reshape(matrix, shape):
    """Rearrange the elements of *matrix* into a new *shape*.

    When *shape* is an integer, *matrix* must be two dimensional and
    *shape* gives the desired number of columns:

    >>> matrix = [(0, 1), (2, 3), (4, 5)]
    >>> cols = 3
    >>> list(reshape(matrix, cols))
    [(0, 1, 2), (3, 4, 5)]

    When *shape* is a tuple (or other iterable), *matrix* may have any
    number of dimensions. It is flattened first and then rebuilt into the
    requested shape, which may itself be multidimensional:

    >>> matrix = [(0, 1), (2, 3), (4, 5)]  # Start with a 3 x 2 matrix

    >>> list(reshape(matrix, (2, 3)))  # Make a 2 x 3 matrix
    [(0, 1, 2), (3, 4, 5)]

    >>> list(reshape(matrix, (6,)))  # Make a vector of length six
    [0, 1, 2, 3, 4, 5]

    >>> list(reshape(matrix, (2, 1, 3, 1)))  # Make 2 x 1 x 3 x 1 tensor
    [(((0,), (1,), (2,)),), (((3,), (4,), (5,)),)]

    Every dimension is assumed to be uniform: either all arrays or all
    scalars. Flattening stops at the first dimension whose leading value is
    a scalar. Scalars are bytes, strings, and non-iterables.
    The reshape iterator stops when the requested shape is complete
    or when the input is exhausted, whichever comes first.

    """
    if isinstance(shape, int):
        return batched(chain.from_iterable(matrix), shape)
    outer, *inner = shape
    flat = _flatten_tensor(matrix)
    # Group innermost dimensions first, then wrap outward.
    nested = reduce(batched, reversed(inner), flat)
    return islice(nested, outer)
1018
1019
def matmul(m1, m2):
    """Return the matrix product of *m1* and *m2*, one row at a time.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller is responsible for passing matrices with compatible
    dimensions.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    width = len(m2[0])
    columns = transpose(m2)
    # product() pairs every row of m1 with every column of m2 in row-major
    # order; each pair's dot product is one output cell.
    cells = starmap(_sumprod, product(m1, columns))
    return batched(cells, width)
1033
1034
1035def _factor_pollard(n):
1036 # Return a factor of n using Pollard's rho algorithm.
1037 # Efficient when n is odd and composite.
1038 for b in range(1, n):
1039 x = y = 2
1040 d = 1
1041 while d == 1:
1042 x = (x * x + b) % n
1043 y = (y * y + b) % n
1044 y = (y * y + b) % n
1045 d = gcd(x - y, n)
1046 if d != n:
1047 return d
1048 raise ValueError('prime or under 5') # pragma: no cover
1049
1050
# Primes below 211, used by ``factor`` for trial division. Because all of
# them are divided out first, ``factor`` can treat any remaining cofactor
# below ``211**2`` as prime.
_primes_below_211 = tuple(sieve(211))
1052
1053
def factor(n):
    """Yield the prime factors of n, smallest first.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]

    Small factors are removed by trial division. Any larger cofactor is
    either confirmed prime with ``is_prime`` or split with Pollard's rho
    algorithm.
    """

    # Corner case reduction
    if n < 2:
        return

    # Trial division reduction
    for p in _primes_below_211:
        while not n % p:
            yield p
            n //= p

    # Pollard's rho reduction, processed as a first-in, first-out work queue.
    # A cofactor below 211**2 has no prime factor below 211 (all were
    # divided out above), so it must itself be prime.
    found = []
    pending = deque([n] if n > 1 else [])
    while pending:
        candidate = pending.popleft()
        if candidate < 211**2 or is_prime(candidate):
            found.append(candidate)
        else:
            divisor = _factor_pollard(candidate)
            pending.extend((divisor, candidate // divisor))
    yield from sorted(found)
1085
1086
def polynomial_eval(coefficients, x):
    """Evaluate the polynomial with the given *coefficients* at *x*.

    Computes with better numeric stability than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125

    Coefficients are listed from the highest power down.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    size = len(coefficients)
    if size == 0:
        # Empty polynomial: zero, in the same numeric type as x.
        return type(x)(0)
    exponents = reversed(range(size))
    powers = map(pow, repeat(x), exponents)
    return _sumprod(coefficients, powers)
1108
1109
def sum_of_squares(it):
    """Add up the squares of the values in *it*.

    >>> sum_of_squares([10, 20, 30])
    1400

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # A sum of squares is the dot product of the input with itself.
    left, right = tee(it)
    return _sumprod(left, right)
1119
1120
def polynomial_derivative(coefficients):
    """Return the coefficients of the first derivative of a polynomial.

    Evaluate the derivative of ``x³ - 4 x² - 17 x + 60``:

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]

    Coefficients are listed from the highest power down.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Pair each coefficient with its exponent (degree, degree - 1, ..., 1).
    # map() stops at the shorter input, which drops the constant term.
    exponents = range(len(coefficients) - 1, 0, -1)
    return list(map(mul, coefficients, exponents))
1138
1139
def totient(n):
    """Return the count of natural numbers up to *n* that are coprime with *n*.

    Euler's totient function φ(n) gives the number of totatives.
    Totatives are integers k in the range 1 ≤ k ≤ n such that gcd(n, k) = 1.

    >>> n = 9
    >>> totient(n)
    6

    >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
    >>> totatives
    [1, 2, 4, 5, 7, 8]
    >>> len(totatives)
    6

    Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    result = n
    # Euler's product formula: scale by (1 - 1/p) once per distinct prime p.
    # Written as ``result - result // p``, it stays exact in integers because
    # each remaining prime still divides the running value.
    for p in set(factor(n)):
        result -= result // p
    return result
1162
1163
# Miller–Rabin primality test: https://oeis.org/A014233
#
# Each entry is (limit, bases). is_prime() scans this list and uses the
# FIRST entry with n < limit. It then requires n to be a strong probable
# prime to every listed base, so the entries must stay sorted by increasing
# limit.
#
# Each base set is meant to admit no strong pseudoprime below its limit,
# which makes the test deterministic for those n. Past the last limit,
# is_prime() falls back to random bases.
#
# Every base here is smaller than the previous entry's limit. is_prime()
# also handles n < 17 separately, so 2 <= base < n always holds, as
# _strong_probable_prime() asserts.
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    # Limit is 2**64: this seven-base set covers all 64-bit inputs.
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        # The first 13 primes as bases; this limit is just above 10**24.
        3317044064679887385961981,
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]
1178
1179
1180@lru_cache
1181def _shift_to_odd(n):
1182 'Return s, d such that 2**s * d == n'
1183 s = ((n - 1) ^ n).bit_length() - 1
1184 d = n >> s
1185 assert (1 << s) * d == n and d & 1 and s >= 0
1186 return s, d
1187
1188
def _strong_probable_prime(n, base):
    """Return True if odd *n* > 2 passes one Miller-Rabin round for *base*.

    Write ``n - 1 == 2**s * d`` with *d* odd. *n* passes if
    ``base**d % n == 1``, or if ``base**(2**r * d) % n == n - 1`` for some
    ``0 <= r < s``.
    """
    assert (n > 2) and (n & 1) and (2 <= base < n)

    exponent_of_two, odd_part = _shift_to_odd(n - 1)
    minus_one = n - 1

    residue = pow(base, odd_part, n)
    if residue in (1, minus_one):
        return True

    # Square up to s - 1 more times, looking for -1 (mod n).
    for _ in range(exponent_of_two - 1):
        residue = residue * residue % n
        if residue == minus_one:
            return True

    return False
1204
1205
# Separate instance of Random() that doesn't share state
# with the default user instance of Random().
# is_prime() draws its random Miller-Rabin bases from this bound method.
# As a result, testing large inputs never advances the module-level generator
# that callers may have seeded (is_prime's own doctests seed it and rely
# on its sequence).
_private_randrange = random.Random().randrange
1209
1210
def is_prime(n):
    """Return ``True`` if *n* is prime and ``False`` otherwise.

    Values below 17 are answered from a lookup set. Multiples of the primes
    below 17 are rejected by trial division. Everything else is settled by
    Miller-Rabin strong-probable-prime rounds.

    Basic examples:

    >>> is_prime(37)
    True
    >>> is_prime(3 * 13)
    False
    >>> is_prime(18_446_744_073_709_551_557)
    True

    Find the next prime over one billion:

    >>> next(filter(is_prime, count(10**9)))
    1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

    >>> from random import seed, randrange, getrandbits
    >>> seed(18675309)

    >>> next(filter(is_prime, map(getrandbits, repeat(200))))
    893303929355758292373272075469392561129886005037663238028407

    >>> next(filter(is_prime, map(randrange, repeat(10**60))))
    269638077304026462407872868003560484232362454342414618963649

    This function is exact for values of *n* below 10**24. For larger inputs,
    the probabilistic Miller-Rabin primality test has a less than 1 in 2**128
    chance of a false positive.
    """
    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Cheap rejection of even numbers and multiples of 3, 5, 7, 11 and 13.
    if not n & 1 or any(not n % p for p in (3, 5, 7, 11, 13)):
        return False

    # The first table entry whose limit exceeds n supplies a deterministic
    # base set.
    bases = next((witnesses for bound, witnesses in _perfect_tests if n < bound), None)
    if bases is None:
        # Past the table: 64 random bases from the private generator, drawn
        # lazily so that a failing round stops further draws.
        bases = (_private_randrange(2, n - 1) for _ in range(64))

    for base in bases:
        if not _strong_probable_prime(n, base):
            return False
    return True
1257
1258
def loops(n):
    """Return an iterable that produces exactly *n* items, for counted loops.

    Works like ``range(n)`` when the loop variable is unused, but every item
    is ``None``, so no integer objects are created. A zero or negative *n*
    produces nothing.

    >>> i = 0
    >>> for _ in loops(5):
    ...     i += 1
    >>> i
    5

    """
    return repeat(None, n)
1271
1272
def multinomial(*counts):
    """Number of distinct arrangements of a multiset.

    The expression ``multinomial(3, 4, 2)`` has several equivalent
    interpretations:

    * In the expansion of ``(a + b + c)⁹``, the coefficient of the
      ``a³b⁴c²`` term is 1260.

    * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4
      greens, and 2 blues.

    * There are 1260 unique ways to place 9 distinct objects into three bins
      with sizes 3, 4, and 2.

    The :func:`multinomial` function computes the length of
    :func:`distinct_permutations`.  For example, there are 83,160 distinct
    anagrams of the word "abracadabra":

        >>> from more_itertools import distinct_permutations, ilen
        >>> ilen(distinct_permutations('abracadabra'))
        83160

    This can be computed directly from the letter counts, 5a 2b 2r 1c 1d:

        >>> from collections import Counter
        >>> list(Counter('abracadabra').values())
        [5, 2, 2, 1, 1]
        >>> multinomial(5, 2, 2, 1, 1)
        83160

    A binomial coefficient is a special case of multinomial where there are
    only two categories.  For example, the number of ways to arrange 12 balls
    with 5 reds and 7 blues is ``multinomial(5, 7)`` or ``math.comb(12, 5)``.

    Likewise, factorial is a special case of multinomial where
    the multiplicities are all just 1 so that
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    With no arguments the result is 1 (the empty product).

    Reference:  https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    # Add the groups one at a time. A group of size k joining ``total``
    # slots can be placed in comb(total, k) ways among the slots so far.
    arrangements = 1
    total = 0
    for group_size in counts:
        total += group_size
        arrangements *= comb(total, group_size)
    return arrangements
1316
1317
def _running_median_minheap_and_maxheap(iterator):  # pragma: no cover
    """Non-windowed running_median() for Python 3.14+.

    Yields one median per value read: the median of every value read so far.
    Uses heapq's max-heap functions (heappush_max, heappushpop_max). This
    module only imports those when they are available.
    """

    read = iterator.__next__
    lo = []  # max-heap
    hi = []  # min-heap (same size as or one smaller than lo)

    # Exhausting the iterator raises StopIteration from read(). suppress()
    # turns that into a normal end of this generator.
    with suppress(StopIteration):
        while True:
            # Odd count: pass the new value through hi, so the smallest of
            # hi + {value} moves into lo. lo now holds one extra element,
            # and its maximum is the median.
            heappush_max(lo, heappushpop(hi, read()))
            yield lo[0]

            # Even count: pass the new value through lo, so the largest of
            # lo + {value} moves into hi. The halves are now equal in size,
            # and the median is the mean of the two middle values.
            heappush(hi, heappushpop_max(lo, read()))
            yield (lo[0] + hi[0]) / 2
1332
1333
1334def _running_median_minheap_only(iterator): # pragma: no cover
1335 "Backport of non-windowed running_median() for Python 3.13 and prior."
1336
1337 read = iterator.__next__
1338 lo = [] # max-heap (actually a minheap with negated values)
1339 hi = [] # min-heap (same size as or one smaller than lo)
1340
1341 with suppress(StopIteration):
1342 while True:
1343 heappush(lo, -heappushpop(hi, read()))
1344 yield -lo[0]
1345
1346 heappush(hi, -heappushpop(lo, -read()))
1347 yield (hi[0] - lo[0]) / 2
1348
1349
1350def _running_median_windowed(iterator, maxlen):
1351 "Yield median of values in a sliding window."
1352
1353 window = deque()
1354 ordered = []
1355
1356 for x in iterator:
1357 window.append(x)
1358 insort(ordered, x)
1359
1360 if len(ordered) > maxlen:
1361 i = bisect_left(ordered, window.popleft())
1362 del ordered[i]
1363
1364 n = len(ordered)
1365 m = n // 2
1366 yield ordered[m] if n & 1 else (ordered[m - 1] + ordered[m]) / 2
1367
1368
def running_median(iterable, *, maxlen=None):
    """Cumulative median of values seen so far or values in a sliding window.

    Set *maxlen* to a positive integer to specify the maximum size
    of the sliding window. The default of *None* is equivalent to
    an unbounded window.

    For example:

    >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0]))
    [5.0, 7.0, 5.0, 7.0, 8.0, 8.5]
    >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0], maxlen=3))
    [5.0, 7.0, 5.0, 9.0, 8.0, 9.0]

    Supports numeric types such as int, float, Decimal, and Fraction,
    but not complex numbers which are unorderable.

    *maxlen* is converted with ``operator.index``. A non-positive value
    raises ``ValueError`` at call time, before any input is consumed.

    On Python 3.13 and earlier, max-heaps are simulated with
    negated values. The negation causes Decimal inputs to apply context
    rounding, making the results slightly different from those obtained
    by statistics.median().
    """
    iterator = iter(iterable)

    if maxlen is None:
        # Unbounded window: choose the two-heap strategy this Python supports.
        if _max_heap_available:
            return _running_median_minheap_and_maxheap(iterator)  # pragma: no cover
        return _running_median_minheap_only(iterator)  # pragma: no cover

    window_size = index(maxlen)
    if window_size <= 0:
        raise ValueError('Window size should be positive')
    return _running_median_windowed(iterator, window_size)
1404
1405
def random_derangement(iterable):
    """Return a random derangement of elements in the iterable.

    A derangement is a permutation that leaves no element at its original
    position. Positions, not values, are what count: when the input repeats
    a value, an equal value may still land at that index.

    Equivalent to but much faster than ``choice(list(derangements(iterable)))``.

    The result is always a tuple. An empty input returns ``()``, its only
    derangement. A one-element input has no derangement and raises
    ``IndexError``.

    """
    seq = tuple(iterable)
    size = len(seq)
    if size < 2:
        if size == 0:
            return ()
        raise IndexError('No derangements to choose from')
    perm = list(range(size))
    start = tuple(perm)
    # Rejection sampling: reshuffle until no index is a fixed point. Roughly
    # 1/e of all permutations are derangements, so few retries are expected.
    # ``is_`` compares the int objects themselves. shuffle() only moves them,
    # and each value appears exactly once, so identity matches equality here.
    while True:
        shuffle(perm)
        if not any(map(is_, start, perm)):
            return itemgetter(*perm)(seq)