1"""Imported from the recipes section of the itertools documentation.
2
3All functions taken from the recipes section of the itertools library docs
4[1]_.
5Some backward-compatible usability improvements have been made.
6
7.. [1] http://docs.python.org/library/itertools.html#recipes
8
9"""
10
11import random
12
13from collections import deque
14from contextlib import suppress
15from collections.abc import Sized
16from functools import lru_cache, partial
17from itertools import (
18 accumulate,
19 chain,
20 combinations,
21 compress,
22 count,
23 cycle,
24 groupby,
25 islice,
26 product,
27 repeat,
28 starmap,
29 takewhile,
30 tee,
31 zip_longest,
32)
33from math import prod, comb, isqrt, gcd
34from operator import mul, not_, itemgetter, getitem
35from random import randrange, sample, choice
36from sys import hexversion
37
# Public API of this module. Private helpers (the ``_``-prefixed names below)
# are deliberately left out, so ``from <module> import *`` does not export them.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'pairwise',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]

# Private sentinel used as a fill value: it is compared by identity
# (``is _marker``), so no user-supplied value can be mistaken for it.
_marker = object()
91
92
# zip with strict is available for Python 3.10+.
# Probe for it by calling zip() with only the keyword argument: if the
# keyword is rejected with TypeError, fall back to plain (non-strict) zip;
# otherwise bind a zip that raises on length mismatches.
try:
    zip(strict=True)
except TypeError:
    _zip_strict = zip
else:
    _zip_strict = partial(zip, strict=True)
100
101
# math.sumprod is available for Python 3.12+. On older interpreters, fall
# back to dotproduct() (defined further down; it is looked up at call time,
# so the forward reference is safe).
try:
    from math import sumprod as _sumprod
except ImportError:

    def _sumprod(x, y):
        return dotproduct(x, y)
107
108
def take(n, iterable):
    """Return a list holding the first *n* items of *iterable*.

    >>> take(3, range(10))
    [0, 1, 2]

    When *iterable* yields fewer than *n* items, every item it has is
    returned.

    >>> take(10, range(3))
    [0, 1, 2]

    """
    head = islice(iterable, n)
    return list(head)
123
124
def tabulate(function, start=0):
    """Return an iterator over ``function(start)``, ``function(start + 1)``,
    ``function(start + 2)``, and so on.

    *function* should accept a single integer argument. *start* defaults to
    0 and is incremented each time the iterator is advanced.

    >>> square = lambda x: x ** 2
    >>> iterator = tabulate(square, -3)
    >>> take(4, iterator)
    [9, 4, 1, 0]

    """
    arguments = count(start)
    return map(function, arguments)
141
142
def tail(n, iterable):
    """Return an iterator over the final *n* items of *iterable*.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    # Inputs without a length are streamed through a bounded deque, which
    # keeps only the most recent *n* items. Non-iterables raise TypeError
    # from deque (or islice below), so there is no explicit Iterable check.
    if not isinstance(iterable, Sized):
        return iter(deque(iterable, maxlen=n))
    # Sized inputs: skip directly past everything except the last n items.
    skip = max(0, len(iterable) - n)
    return islice(iterable, skip, None)
159
160
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps, or exhaust it when *n* is ``None``.

    Efficiently drains an iterator without returning any values. By default
    the whole iterator is consumed; pass *n* to stop after that many items.

    >>> i = (x for x in range(10))
    >>> next(i)
    0
    >>> consume(i, 3)
    >>> next(i)
    4
    >>> consume(i)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    If fewer than *n* items remain, the iterator is consumed completely.

    >>> i = (x for x in range(3))
    >>> consume(i, 5)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    """
    # Both branches let C code do the iterating.
    if n is not None:
        # Requesting the empty slice that begins at position n makes islice
        # skip (and discard) the first n items.
        next(islice(iterator, n, n), None)
        return
    # A deque that can hold nothing simply discards everything it is fed.
    deque(iterator, maxlen=0)
199
200
def nth(iterable, n, default=None):
    """Return the item at position *n* of *iterable*, or *default* when the
    iterable is too short.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    remainder = islice(iterable, n, None)
    return next(remainder, default)
212
213
def all_equal(iterable, key=None):
    """
    Return ``True`` when every element equals every other element.

    >>> all_equal('aaaa')
    True
    >>> all_equal('aaab')
    False

    Pass *key*, a one-argument function, to compare transformed versions of
    the items instead:

    >>> all_equal('AaaA', key=str.casefold)
    True
    >>> all_equal([1, 2, 3], key=lambda x: x < 10)
    True

    """
    runs = groupby(iterable, key)
    # Zero or one run of equal keys means all items agree; finding a second
    # run means at least two items differ. The first operand is always
    # truthy (True or a non-empty (key, group) tuple), so the result is the
    # bool produced by ``not``.
    return next(runs, True) and not next(runs, False)
238
239
def quantify(iterable, pred=bool):
    """Return how many items of *iterable* satisfy *pred*.

    >>> quantify([True, False, True])
    2

    """
    return sum(pred(item) for item in iterable)
248
249
def pad_none(iterable):
    """Yield the items of *iterable*, then ``None`` forever.

    >>> take(5, pad_none(range(3)))
    [0, 1, 2, None, None]

    Handy for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    endless_nones = repeat(None)
    return chain(iterable, endless_nones)


padnone = pad_none
265
266
def ncycles(iterable, n):
    """Yield the items of *iterable* over and over, *n* times in total.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    # Materialize once so the input can be replayed even if it is a
    # one-shot iterator.
    saved = tuple(iterable)
    return chain.from_iterable(repeat(saved, n))
275
276
def dotproduct(vec1, vec2):
    """Return the dot product of two iterables of numbers.

    >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
    33.5
    >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
    33.5

    Like :func:`zip`, this stops at the end of the shorter input.
    In Python 3.12 and later, use ``math.sumprod()`` instead.
    """
    return sum(a * b for a, b in zip(vec1, vec2))
288
289
def flatten(listOfLists):
    """Lazily remove one level of nesting from an iterable of iterables.

    >>> list(flatten([[0, 1], [2, 3]]))
    [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    outer = listOfLists
    # from_iterable (rather than chain(*outer)) keeps the outer iterable
    # lazy, so it may be infinite.
    return chain.from_iterable(outer)
300
301
def repeatfunc(func, times=None, *args):
    """Repeatedly call *func* with *args* and iterate over the results.

    When *times* is given, the iterable stops after that many calls:

    >>> from operator import add
    >>> times = 4
    >>> args = 3, 5
    >>> list(repeatfunc(add, times, *args))
    [8, 8, 8, 8]

    When *times* is ``None``, the iterable never ends:

    >>> from random import randrange
    >>> times = None
    >>> args = 1, 11
    >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
    [2, 4, 8, 1, 8, 4]

    """
    arg_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(func, arg_stream)
327
328
329def _pairwise(iterable):
330 """Returns an iterator of paired items, overlapping, from the original
331
332 >>> take(4, pairwise(count()))
333 [(0, 1), (1, 2), (2, 3), (3, 4)]
334
335 On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
336
337 """
338 a, b = tee(iterable)
339 next(b, None)
340 return zip(a, b)
341
342
# Prefer itertools.pairwise when it exists; if importing it fails, fall
# back to the tee-based recipe above.
try:
    from itertools import pairwise as itertools_pairwise
except ImportError:
    pairwise = _pairwise
else:

    # Thin Python wrapper around the itertools version so that it can carry
    # the recipe's docstring.
    def pairwise(iterable):
        return itertools_pairwise(iterable)

    pairwise.__doc__ = _pairwise.__doc__
353
354
class UnequalIterablesError(ValueError):
    """Raised when iterables that must have equal lengths do not.

    *details*, when provided, is a ``(first_size, index, size)`` triple that
    is formatted into the message to identify the mismatching iterable.
    """

    def __init__(self, details=None):
        base = 'Iterables have different lengths'
        if details is None:
            message = base
        else:
            suffix = ': index 0 has length {}; index {} has length {}'
            message = base + suffix.format(*details)
        super().__init__(message)
364
365
def _zip_equal_generator(iterables):
    # Slow path for inputs without len(): pad shorter inputs with the
    # _marker sentinel and raise as soon as any output tuple contains it.
    for combo in zip_longest(*iterables, fillvalue=_marker):
        if any(val is _marker for val in combo):
            raise UnequalIterablesError()
        yield combo
372
373
374def _zip_equal(*iterables):
375 # Check whether the iterables are all the same size.
376 try:
377 first_size = len(iterables[0])
378 for i, it in enumerate(iterables[1:], 1):
379 size = len(it)
380 if size != first_size:
381 raise UnequalIterablesError(details=(first_size, i, size))
382 # All sizes are equal, we can use the built-in zip.
383 return zip(*iterables)
384 # If any one of the iterables didn't have a length, start reading
385 # them until one runs out.
386 except TypeError:
387 return _zip_equal_generator(iterables)
388
389
def grouper(iterable, n, incomplete='fill', fillvalue=None):
    """Split *iterable* into consecutive groups of *n* items each.

    >>> list(grouper('ABCDEF', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    *incomplete* and *fillvalue* determine what happens when the length of
    *iterable* is not a multiple of *n*.

    With *incomplete* set to `'fill'`, the final group is padded with
    *fillvalue*.

    >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]

    With `'ignore'`, the final partial group is dropped.

    >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    With `'strict'`, a subclass of `ValueError` is raised.

    >>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
    >>> list(iterator)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    UnequalIterablesError

    """
    # n references to one shared iterator: each output tuple pulls n
    # successive items from it, which chops the input into groups.
    shared = [iter(iterable)] * n
    if incomplete == 'fill':
        return zip_longest(*shared, fillvalue=fillvalue)
    elif incomplete == 'strict':
        return _zip_equal(*shared)
    elif incomplete == 'ignore':
        return zip(*shared)
    raise ValueError('Expected fill, strict, or ignore')
428
429
def roundrobin(*iterables):
    """Take one item from each input iterable in turn until all are exhausted.

    >>> list(roundrobin('ABC', 'D', 'EF'))
    ['A', 'D', 'E', 'B', 'F', 'C']

    The output matches :func:`interleave_longest`, but this function may be
    faster for some inputs (in particular when there are only a few
    iterables).

    """
    # Algorithm credited to George Sakkis. Each pass cycles over the
    # iterators that are still active. When one is exhausted, next() raises
    # StopIteration, which ends that map(); the following pass takes one
    # fewer iterator from the cycle, resuming just after the exhausted one.
    remaining = map(iter, iterables)
    for active in reversed(range(1, len(iterables) + 1)):
        remaining = cycle(islice(remaining, active))
        yield from map(next, remaining)
446
447
def partition(pred, iterable):
    """
    Split *iterable* into a 2-tuple of iterables.
    The first yields the items for which ``pred(item)`` is false.
    The second yields the items for which ``pred(item)`` is true.

    >>> is_odd = lambda x: x % 2 != 0
    >>> iterable = range(10)
    >>> even_items, odd_items = partition(is_odd, iterable)
    >>> list(even_items), list(odd_items)
    ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    If *pred* is None, :func:`bool` is used.

    >>> iterable = [0, 1, False, True, '', ' ']
    >>> false_items, true_items = partition(None, iterable)
    >>> list(false_items), list(true_items)
    ([0, False, ''], [1, True, ' '])

    """
    predicate = bool if pred is None else pred

    # Three copies of the input: one per output, plus one to evaluate the
    # predicate on. The verdicts are tee'd so *predicate* runs once per item.
    false_source, true_source, to_test = tee(iterable, 3)
    verdicts_a, verdicts_b = tee(map(predicate, to_test))
    falses = compress(false_source, map(not_, verdicts_a))
    trues = compress(true_source, verdicts_b)
    return (falses, trues)
474
475
def powerset(iterable):
    """Yield every subset of *iterable*, from smallest to largest.

    >>> list(powerset([1, 2, 3]))
    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    :func:`powerset` does not require :class:`set` input, so repeated
    elements in the input produce repeated elements in the output.

    >>> seq = [1, 1, 0]
    >>> list(powerset(seq))
    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.
    """
    pool = list(iterable)
    by_size = (combinations(pool, size) for size in range(len(pool) + 1))
    return chain.from_iterable(by_size)
495
496
def unique_everseen(iterable, key=None):
    """
    Yield unique elements, preserving order.

    >>> list(unique_everseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D']
    >>> list(unique_everseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'D']

    Sequences with a mix of hashable and unhashable items can be used.
    The function will be slower (i.e., `O(n^2)`) for unhashable items.

    Remember that ``list`` objects are unhashable - you can use the *key*
    parameter to transform the list to a tuple (which is hashable) to
    avoid a slowdown.

    >>> iterable = ([1, 2], [2, 3], [1, 2])
    >>> list(unique_everseen(iterable))  # Slow
    [[1, 2], [2, 3]]
    >>> list(unique_everseen(iterable, key=tuple))  # Faster
    [[1, 2], [2, 3]]

    Similarly, you may want to convert unhashable ``set`` objects with
    ``key=frozenset``. For ``dict`` objects,
    ``key=lambda x: frozenset(x.items())`` can be used.

    """
    # Hashable keys are tracked in a set (O(1) lookups); unhashable keys
    # fall back to a list (O(n) lookups).
    seenset = set()
    seenset_add = seenset.add
    seenlist = []
    seenlist_add = seenlist.append
    use_key = key is not None

    for element in iterable:
        k = key(element) if use_key else element
        # Only the membership bookkeeping is guarded: a TypeError here means
        # k is unhashable. The yield stays outside the try, so an exception
        # thrown into this generator by its consumer is never mistaken for
        # that case (which previously re-yielded the same element).
        try:
            if k in seenset:
                continue
            seenset_add(k)
        except TypeError:
            if k in seenlist:
                continue
            seenlist_add(k)
        yield element
540
541
def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing each run of consecutive duplicates
    to its first element.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    """
    if key is not None:
        # With a key, the group key may differ from the element itself, so
        # take the first actual element of each run.
        return (next(run) for _, run in groupby(iterable, key))
    return (first for first, _ in groupby(iterable))
555
556
def unique(iterable, key=None, reverse=False):
    """Yield the distinct elements of *iterable* in sorted order.

    >>> list(unique([[1, 2], [3, 4], [1, 2]]))
    [[1, 2], [3, 4]]

    *key* and *reverse* have the same meaning as for :func:`sorted`.

    >>> list(unique('ABBcCAD', str.casefold))
    ['A', 'B', 'c', 'D']
    >>> list(unique('ABBcCAD', str.casefold, reverse=True))
    ['D', 'c', 'B', 'A']

    Elements need not be hashable, but they must be mutually comparable so
    they can be sorted.
    """
    ordered = list(iterable)
    ordered.sort(key=key, reverse=reverse)
    # After sorting, equal elements are adjacent, so dropping consecutive
    # duplicates leaves each distinct element once.
    return unique_justseen(ordered, key=key)
575
576
def iter_except(func, exception, first=None):
    """Call a function repeatedly, yielding its results until it raises.

    Turns a call-until-exception interface into an iterator interface.
    Similar to ``iter(func, sentinel)``, except that an exception rather
    than a sentinel value ends the loop.

    >>> l = [0, 1, 2]
    >>> list(iter_except(l.pop, IndexError))
    [2, 1, 0]

    Multiple exceptions can be specified as a stopping condition:

    >>> l = [1, 2, 3, '...', 4, 5, 6]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [7, 6, 5]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [4, 3, 2]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    []

    """
    # Callables to invoke in order: *first* once (if given), then *func*
    # indefinitely. The first *exception* raised quietly ends iteration.
    prologue = () if first is None else (first,)
    with suppress(exception):
        for call in chain(prologue, repeat(func)):
            yield call()
604
605
def first_true(iterable, default=None, pred=None):
    """
    Return the first truthy value in *iterable*, or *default* if none is
    found.

    When *pred* is not None, return the first item for which ``pred(item)``
    is true instead.

    >>> first_true(range(10))
    1
    >>> first_true(range(10), pred=lambda x: x > 5)
    6
    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
    'missing'

    """
    for match in filter(pred, iterable):
        return match
    return default
624
625
def random_product(*args, repeat=1):
    """Pick one item at random from each of the input iterables.

    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
    ('c', 3, 'Z')

    If the keyword argument *repeat* is given, that many items are drawn
    from each iterable.

    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
    ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*args, repeat=repeat)``.

    """
    materialized = [tuple(pool) for pool in args]
    return tuple(
        choice(pool) for _ in range(repeat) for pool in materialized
    )
644
645
def random_permutation(iterable, r=None):
    """Return a random permutation of length *r* of the items in *iterable*.

    If *r* is omitted or ``None``, it defaults to the length of *iterable*.

    >>> random_permutation(range(5))  # doctest:+SKIP
    (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    if r is None:
        r = len(pool)
    return tuple(sample(pool, r))
662
663
def random_combination(iterable, r):
    """Return a random length-*r* subsequence of the items in *iterable*.

    >>> random_combination(range(5), 3)  # doctest:+SKIP
    (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    # Sorting the sampled positions preserves the input's relative order.
    positions = sorted(sample(range(len(pool)), r))
    return tuple(map(pool.__getitem__, positions))
678
679
def random_combination_with_replacement(iterable, r):
    """Return a random length-*r* subsequence of the items in *iterable*,
    where individual items may appear more than once.

    >>> random_combination_with_replacement(range(3), 5)  # doctest:+SKIP
    (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    size = len(pool)
    picks = sorted([randrange(size) for _ in range(r)])
    return tuple(pool[j] for j in picks)
695
696
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

        >>> nth_combination(range(5), 3, 5)
        (0, 3, 4)

    Negative values of *index* count from the end, as with list indexing.

    ``ValueError`` will be raised if *r* is negative or greater than the
    length of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if (r < 0) or (r > n):
        raise ValueError('r must be between 0 and the length of iterable')

    # Total number of length-r combinations.
    c = comb(n, r)

    if index < 0:
        index += c

    if (index < 0) or (index >= c):
        raise IndexError('index out of range')

    # Walk the combinations lexicographically. At each step c is the number
    # of combinations that start with the current candidate element; skip
    # whole blocks of them until *index* falls inside one.
    result = []
    while r:
        c, n, r = c * r // n, n - 1, r - 1
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)
737
738
def prepend(value, iterator):
    """Yield *value* first, then every element of *iterator*.

    >>> value = '0'
    >>> iterator = ['1', '2', '3']
    >>> list(prepend(value, iterator))
    ['0', '1', '2', '3']

    To prepend several values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterator)
752
753
def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

        >>> list(convolve([1, -1, -20], [1, -3]))
        [1, -4, -17, 60]

    Examples of popular kinds of kernels:

    * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      For image data, this blurs the image and reduces noise.
    * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function evaluated at evenly spaced inputs.
    * The kernel ``[1, -2, 1]`` estimates the second derivative of a
      function evaluated at evenly spaced inputs.

    Convolution is mathematically commutative, but the two inputs are
    treated differently: the signal is consumed lazily and may be
    infinite, while the kernel is read in full before any output is made.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article:  https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown:  https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # This implementation comes from an older version of the itertools
    # documentation. While the newer implementation is a bit clearer,
    # this one was kept because the inlined window logic is faster
    # and it avoids an unnecessary deque-to-tuple conversion.
    flipped = tuple(kernel)[::-1]
    width = len(flipped)
    # Start with a window of zeros, and append width - 1 trailing zeros to
    # the signal so the tail of the convolution is flushed out.
    window = deque([0] * width, maxlen=width)
    for value in chain(signal, repeat(0, width - 1)):
        window.append(value)
        yield _sumprod(flipped, window)
795
796
def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that keeps the rest of the iterator
    fully accessible.

    >>> it = iter('ABCdEfGhI')
    >>> all_upper, remainder = before_and_after(str.isupper, it)
    >>> ''.join(all_upper)
    'ABC'
    >>> ''.join(remainder)  # takewhile() would lose the 'd'
    'dEfGhI'

    The first iterator must be fully consumed before the second one can
    produce valid results.
    """
    leading, remainder = tee(it)
    # zip(remainder) yields a 1-tuple (always truthy) per item, so compress
    # passes through everything takewhile accepts while advancing
    # *remainder* one step per accepted item. Once takewhile stops, compress
    # stops without pulling another selector, so the first rejected item is
    # still at the front of *remainder*.
    accepted = compress(takewhile(predicate, leading), zip(remainder))
    return accepted, remainder
814
815
def triplewise(iterable):
    """Yield overlapping triples of consecutive items from *iterable*.

    >>> list(triplewise('ABCDE'))
    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # This deviates from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Offset the copies by 0, 1 and 2 items (third first, as before).
    next(islice(third, 2, 2), None)
    next(second, None)
    return zip(first, second, third)
830
831
832def _sliding_window_islice(iterable, n):
833 # Fast path for small, non-zero values of n.
834 iterators = tee(iterable, n)
835 for i, iterator in enumerate(iterators):
836 next(islice(iterator, i, i), None)
837 return zip(*iterators)
838
839
840def _sliding_window_deque(iterable, n):
841 # Normal path for other values of n.
842 iterator = iter(iterable)
843 window = deque(islice(iterator, n - 1), maxlen=n)
844 for x in iterator:
845 window.append(x)
846 yield tuple(window)
847
848
def sliding_window(iterable, n):
    """Yield overlapping windows of *n* consecutive items from *iterable*.

    >>> list(sliding_window(range(6), 4))
    [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]

    Nothing is yielded when *iterable* has fewer than *n* items:

    >>> list(sliding_window(range(3), 4))
    []

    For a variant with more features, see :func:`windowed`.
    """
    # Dispatch on the window width to the fastest suitable implementation.
    if n == 1:
        return zip(iterable)
    if n == 2:
        return pairwise(iterable)
    if n > 20:
        return _sliding_window_deque(iterable, n)
    if n > 2:
        return _sliding_window_islice(iterable, n)
    raise ValueError(f'n should be at least one, not {n}')
872
873
def subslices(iterable):
    """Yield every contiguous, non-empty slice of *iterable* as a list.

    >>> list(subslices('ABC'))
    [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This resembles :func:`substrings`, but emits items in a different
    order.
    """
    seq = list(iterable)
    # Each pair (start, stop) with start < stop delimits one slice.
    bounds = combinations(range(len(seq) + 1), 2)
    return (seq[start:stop] for start, stop in bounds)
886
887
def polynomial_from_roots(roots):
    """Compute the coefficients of the polynomial with the given roots.

    >>> roots = [5, -4, 3]  # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
    [1, -4, -17, 60]

    Coefficients are listed in descending order of power.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # This recipe differs from the one in itertools docs in that it
    # applies list() after each call to convolve(). This avoids
    # hitting stack limits with nested generators.
    coefficients = [1]
    for root in roots:
        # Multiply by the linear factor (x - root).
        coefficients = list(convolve(coefficients, (1, -root)))
    return coefficients
906
907
def iter_index(iterable, value, start=0, stop=None):
    """Yield each index in *iterable* at which *value* occurs, starting at
    index *start* (inclusive) and ending before index *stop*.


    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
    [1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
    [1, 4]

    For non-scalar *value* the behavior matches that of the built-in types.

    >>> list(iter_index('ABCDABCD', 'AB'))
    [0, 4]
    >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
    []
    >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
    [0, 2]

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    """
    seq_index = getattr(iterable, 'index', None)
    if seq_index is not None:
        # Fast path: objects with .index() (sequences) search natively;
        # .index() raising ValueError signals there are no more matches.
        end = len(iterable) if stop is None else stop
        position = start - 1
        with suppress(ValueError):
            while True:
                position = seq_index(value, position + 1, end)
                yield position
        return

    # Slow path: scan a general iterable element by element.
    for position, element in enumerate(islice(iterable, start, stop), start):
        if element is value or element == value:
            yield position
947
948
def sieve(n):
    """Yield the primes less than n.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    """
    # This implementation comes from an older version of the itertools
    # documentation. The newer implementation is easier to read but is
    # less lazy.
    if n > 2:
        yield 2
    start = 3
    # data[i] == 1 marks i as a prime candidate. Repeating (0, 1) marks only
    # the odd indices, so even numbers are never candidates.
    data = bytearray((0, 1)) * (n // 2)
    # Every candidate p found below isqrt(n) + 1 is prime. Before striking
    # its multiples, yield the candidates in [start, p*p): any number below
    # p*p not struck out by a smaller prime is itself prime.
    for p in iter_index(data, 1, start, stop=isqrt(n) + 1):
        yield from iter_index(data, 1, start, p * p)
        # Clear odd multiples of p from p*p onward (a step of 2p skips the
        # even multiples, which are already unmarked).
        data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
        start = p * p
    # All remaining marked indices from start onward are prime.
    yield from iter_index(data, 1, start)
968
969
970def _batched(iterable, n, *, strict=False):
971 """Batch data into tuples of length *n*. If the number of items in
972 *iterable* is not divisible by *n*:
973 * The last batch will be shorter if *strict* is ``False``.
974 * :exc:`ValueError` will be raised if *strict* is ``True``.
975
976 >>> list(batched('ABCDEFG', 3))
977 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
978
979 On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
980 """
981 if n < 1:
982 raise ValueError('n must be at least one')
983 iterator = iter(iterable)
984 while batch := tuple(islice(iterator, n)):
985 if strict and len(batch) != n:
986 raise ValueError('batched(): incomplete batch')
987 yield batch
988
989
990if hexversion >= 0x30D00A2: # pragma: no cover
991 from itertools import batched as itertools_batched
992
993 def batched(iterable, n, *, strict=False):
994 return itertools_batched(iterable, n, strict=strict)
995
996else:
997 batched = _batched
998
999 batched.__doc__ = _batched.__doc__
1000
1001
def transpose(it):
    """Swap the rows and columns of the input matrix.

    >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
    [(1, 11), (2, 22), (3, 33)]

    The caller is responsible for supplying rows of compatible lengths.
    Where strict zip is available (Python 3.10+), mismatched rows raise
    :exc:`ValueError` during iteration; otherwise the output is truncated
    to the shortest row. An empty input produces no output.
    """
    rows = it
    return _zip_strict(*rows)
1012
1013
def reshape(matrix, cols):
    """Reshape the 2-D *matrix* so that each row has *cols* columns.

    >>> matrix = [(0, 1), (2, 3), (4, 5)]
    >>> cols = 3
    >>> list(reshape(matrix, cols))
    [(0, 1, 2), (3, 4, 5)]
    """
    # Flatten row by row, then regroup into rows of the new width.
    flat = chain.from_iterable(matrix)
    return batched(flat, cols)
1023
1024
def matmul(m1, m2):
    """Multiply two matrices.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller must make sure the dimensions of the two matrices are
    compatible.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # The width of the product equals the number of columns of m2.
    width = len(m2[0])
    columns = transpose(m2)
    # Each output cell is the dot product of a row of m1 and a column of m2.
    cells = starmap(_sumprod, product(m1, columns))
    return batched(cells, width)
1038
1039
1040def _factor_pollard(n):
1041 # Return a factor of n using Pollard's rho algorithm.
1042 # Efficient when n is odd and composite.
1043 for b in range(1, n):
1044 x = y = 2
1045 d = 1
1046 while d == 1:
1047 x = (x * x + b) % n
1048 y = (y * y + b) % n
1049 y = (y * y + b) % n
1050 d = gcd(x - y, n)
1051 if d != n:
1052 return d
1053 raise ValueError('prime or under 5')
1054
1055
# Small primes used by factor() for trial division. Every composite below
# 211**2 has a prime factor below 211, which is what lets factor() treat a
# leftover value under that bound as prime.
_primes_below_211 = tuple(sieve(211))
1057
1058
def factor(n):
    """Yield the prime factors of n in ascending order.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]

    Small factors are found by trial division. Any larger cofactor is
    either confirmed prime with ``is_prime`` or split into smaller factors
    with Pollard's rho algorithm.
    """

    # Corner case: nothing to factor below 2.
    if n < 2:
        return

    # Strip out every small prime by trial division.
    for p in _primes_below_211:
        while n % p == 0:
            yield p
            n //= p

    # Split the remaining cofactor with Pollard's rho until only primes are
    # left. Work is processed first-in, first-out.
    primes = []
    pending = deque([n] if n > 1 else [])
    while pending:
        m = pending.popleft()
        # No prime below 211 divides m, so m < 211**2 implies m is prime.
        if m < 211**2 or is_prime(m):
            primes.append(m)
        else:
            d = _factor_pollard(m)
            pending.extend((d, m // d))
    yield from sorted(primes)
1090
1091
def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a given value of *x*.

    Computes with better numeric stability than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125

    Coefficients are listed in descending order of power.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    length = len(coefficients)
    if length == 0:
        # The empty polynomial is zero, expressed in the type of x.
        return type(x)(0)
    # Powers of x from highest (length - 1) down to 0, matching the order
    # of the coefficients.
    powers = (x ** exponent for exponent in range(length - 1, -1, -1))
    return _sumprod(coefficients, powers)
1113
1114
def sum_of_squares(it):
    """Return the sum of the squares of the values in *it*.

    >>> sum_of_squares([10, 20, 30])
    1400

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Pair each value with a copy of itself and sum the products.
    left, right = tee(it)
    return _sumprod(left, right)
1124
1125
def polynomial_derivative(coefficients):
    """Compute the coefficients of a polynomial's first derivative.

    Evaluate the derivative of ``x³ - 4 x² - 17 x + 60``:

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]

    Coefficients are listed in descending order of power.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    degree = len(coefficients) - 1
    # Multiply each coefficient by its power; the constant term (power 0)
    # drops out because zip stops when the exponents run out.
    exponents = range(degree, 0, -1)
    return [coef * power for coef, power in zip(coefficients, exponents)]
1143
1144
def totient(n):
    """Return the count of natural numbers up to *n* that are coprime with *n*.

    Euler's totient function φ(n) counts the totatives of *n*. Totatives are
    the integers k with 1 ≤ k ≤ n such that gcd(n, k) = 1.

    >>> n = 9
    >>> totient(n)
    6

    >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
    >>> totatives
    [1, 2, 4, 5, 7, 8]
    >>> len(totatives)
    6

    The value is computed from Euler's product formula: for each distinct
    prime p dividing n, multiply by (1 - 1/p). Integer arithmetic keeps this
    exact.

    Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    result = n
    for p in set(factor(n)):
        # result * (1 - 1/p) written as an exact integer subtraction.
        result -= result // p
    return result
1167
1168
# Miller–Rabin primality test: https://oeis.org/A014233
#
# Deterministic witness sets for the strong probable prime test. Each entry
# is ``(limit, bases)``: an odd n below ``limit`` that is a strong probable
# prime to every base in ``bases`` is prime. Entries are sorted by increasing
# limit, so is_prime() uses the first entry whose limit exceeds n.
# Consequently every n that reaches a row is at least the previous row's
# limit, which keeps each base below n (as _strong_probable_prime asserts).
# Inputs at or above the last limit fall back to random bases in is_prime().
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    # The limit on the next row is 2**64.
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        3317044064679887385961981,
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]
1183
1184
1185@lru_cache
1186def _shift_to_odd(n):
1187 'Return s, d such that 2**s * d == n'
1188 s = ((n - 1) ^ n).bit_length() - 1
1189 d = n >> s
1190 assert (1 << s) * d == n and d & 1 and s >= 0
1191 return s, d
1192
1193
def _strong_probable_prime(n, base):
    """Return True if odd *n* > 2 is a strong probable prime to *base*.

    Writes n - 1 as 2**s * d with d odd. It then checks whether base**d is
    1 or n - 1 modulo n, or whether any of the next s - 1 successive
    squarings hits n - 1.
    """
    assert (n > 2) and (n & 1) and (2 <= base < n)

    s, d = _shift_to_odd(n - 1)
    minus_one = n - 1

    x = pow(base, d, n)
    if x in (1, minus_one):
        return True

    for _ in range(s - 1):
        x = pow(x, 2, n)
        if x == minus_one:
            return True

    return False
1209
1210
# Separate instance of Random() that doesn't share state
# with the default user instance of Random().
# is_prime() draws its fallback Miller-Rabin bases from this generator. As a
# result, seeding or consuming the module-level ``random`` functions does not
# change which bases it picks, and calling is_prime() does not advance the
# user's random stream.
_private_randrange = random.Random().randrange
1214
1215
def is_prime(n):
    """Return ``True`` if *n* is prime and ``False`` otherwise.

    Basic examples:

    >>> is_prime(37)
    True
    >>> is_prime(3 * 13)
    False
    >>> is_prime(18_446_744_073_709_551_557)
    True

    Find the next prime over one billion:

    >>> next(filter(is_prime, count(10**9)))
    1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

    >>> from random import seed, randrange, getrandbits
    >>> seed(18675309)

    >>> next(filter(is_prime, map(getrandbits, repeat(200))))
    893303929355758292373272075469392561129886005037663238028407

    >>> next(filter(is_prime, map(randrange, repeat(10**60))))
    269638077304026462407872868003560484232362454342414618963649

    This function is exact for values of *n* below 10**24. For larger inputs,
    the probabilistic Miller-Rabin primality test has a less than 1 in 2**128
    chance of a false positive.
    """
    # Inputs below 17 are answered directly from the small primes.
    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Cheap trial division by 2, 3, 5, 7, 11 and 13 rejects most composites
    # before the more expensive strong probable prime tests.
    if not n & 1 or not all(n % p for p in (3, 5, 7, 11, 13)):
        return False

    # Use the first deterministic witness set whose limit exceeds n.
    # Past the last limit, fall back to 64 random bases; each one lets a
    # composite through with probability at most 1/4.
    bases = next(
        (witnesses for limit, witnesses in _perfect_tests if n < limit),
        None,
    )
    if bases is None:
        bases = (_private_randrange(2, n - 1) for _ in range(64))

    return all(_strong_probable_prime(n, base) for base in bases)
1262
1263
def loops(n):
    """Return an iterable of *n* items meant only for driving a loop.

    It counts like ``range(n)``, but every item is the same ``None``
    object, so no integers are created along the way.

    >>> i = 0
    >>> for _ in loops(5):
    ...     i += 1
    >>> i
    5

    """
    return repeat(None, times=n)
1276
1277
def multinomial(*counts):
    """Return the number of distinct arrangements of a multiset.

    The expression ``multinomial(3, 4, 2)`` has several equivalent
    interpretations:

    * In the expansion of ``(a + b + c)⁹``, the coefficient of the
      ``a³b⁴c²`` term is 1260.

    * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4
      greens, and 2 blues.

    * There are 1260 unique ways to place 9 distinct objects into three bins
      with sizes 3, 4, and 2.

    The :func:`multinomial` function computes the length of
    :func:`distinct_permutations`. For example, there are 83,160 distinct
    anagrams of the word "abracadabra":

        >>> from more_itertools import distinct_permutations, ilen
        >>> ilen(distinct_permutations('abracadabra'))
        83160

    This can be computed directly from the letter counts, 5a 2b 2r 1c 1d:

        >>> from collections import Counter
        >>> list(Counter('abracadabra').values())
        [5, 2, 2, 1, 1]
        >>> multinomial(5, 2, 2, 1, 1)
        83160

    A binomial coefficient is a special case of multinomial where there are
    only two categories. For example, the number of ways to arrange 12 balls
    with 5 reds and 7 blues is ``multinomial(5, 7)`` or ``math.comb(12, 5)``.

    Likewise, factorial is a special case of multinomial where
    the multiplicities are all just 1 so that
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    With no arguments the result is 1 (the empty product).

    Reference: https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    # Place the groups one after another. After a group of size k has been
    # added, there are comb(total, k) ways to choose which of the `total`
    # slots seen so far it occupies.
    result = 1
    total = 0
    for k in counts:
        total += k
        result *= comb(total, k)
    return result