1"""Imported from the recipes section of the itertools documentation.
2
3All functions taken from the recipes section of the itertools library docs
4[1]_.
5Some backward-compatible usability improvements have been made.
6
.. [1] https://docs.python.org/3/library/itertools.html#itertools-recipes
8
9"""
10
11import random
12
13from collections import deque
14from contextlib import suppress
15from collections.abc import Sized
16from functools import lru_cache, partial
17from itertools import (
18 accumulate,
19 chain,
20 combinations,
21 compress,
22 count,
23 cycle,
24 groupby,
25 islice,
26 product,
27 repeat,
28 starmap,
29 takewhile,
30 tee,
31 zip_longest,
32)
33from math import prod, comb, isqrt, gcd
34from operator import mul, not_, itemgetter, getitem
35from random import randrange, sample, choice
36from sys import hexversion
37
# Public API of this module: the names exported by ``from ... import *``.
# ``padnone`` is kept alongside ``pad_none`` as a backward-compatible alias.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'pairwise',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]
89
# Private sentinel. _zip_equal_generator() uses it as the zip_longest()
# fillvalue, so any occurrence means one of the iterables ran out early.
_marker = object()


# zip with strict is available for Python 3.10+. Older versions reject any
# keyword argument with TypeError, so fall back to the plain (unchecked) zip.
try:
    zip(strict=True)
except TypeError:
    _zip_strict = zip
else:
    _zip_strict = partial(zip, strict=True)


# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _sumprod
except ImportError:

    def _sumprod(x, y):
        """Fallback for :func:`math.sumprod`: the sum of pairwise products.

        Unlike ``math.sumprod``, inputs of unequal length are silently
        truncated to the shorter one.
        """
        # dotproduct() is defined further down in this module. It is looked
        # up when _sumprod is called, not when it is defined, so the forward
        # reference is safe.
        return dotproduct(x, y)
107
108
def take(n, iterable):
    """Collect the first *n* items of *iterable* into a list.

    >>> take(3, range(10))
    [0, 1, 2]

    When *iterable* runs out before *n* items have been produced, every
    item it did produce is returned.

    >>> take(10, range(3))
    [0, 1, 2]

    """
    head = islice(iterable, n)
    return [*head]
123
124
def tabulate(function, start=0):
    """Return an iterator over the results of ``function(start)``,
    ``function(start + 1)``, ``function(start + 2)``...

    *function* should be a function that accepts one integer argument.

    If *start* is not specified it defaults to 0. It will be incremented each
    time the iterator is advanced.

    >>> square = lambda x: x ** 2
    >>> iterator = tabulate(square, -3)
    >>> take(4, iterator)
    [9, 4, 1, 0]

    The returned iterator is infinite; bound it with :func:`take` or
    :func:`itertools.islice`.

    """
    return map(function, count(start))
141
142
def tail(n, iterable):
    """Produce an iterator over the final *n* items of *iterable*.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    # Without a known length the only option is to buffer the input,
    # keeping just the newest n items. No Iterable check is made up front:
    # deque() or islice() raises TypeError on its own for non-iterables.
    if not isinstance(iterable, Sized):
        return iter(deque(iterable, maxlen=n))

    # With a length available, skip straight past the leading items.
    skip = max(0, len(iterable) - n)
    return islice(iterable, skip, None)
159
160
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps. If *n* is ``None``, consume it
    entirely.

    Efficiently exhausts an iterator without returning values. Defaults to
    consuming the whole iterator, but an optional second argument may be
    provided to limit consumption.

    >>> i = (x for x in range(10))
    >>> next(i)
    0
    >>> consume(i, 3)
    >>> next(i)
    4
    >>> consume(i)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    If the iterator has fewer items remaining than the provided limit, the
    whole iterator will be consumed.

    >>> i = (x for x in range(3))
    >>> consume(i, 5)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    Returns ``None``.
    """
    # Use functions that consume iterators at C speed.
    if n is None:
        # feed the entire iterator into a zero-length deque
        deque(iterator, maxlen=0)
    else:
        # advance to the empty slice starting at position n; islice skips
        # the first n items before it notices the slice is empty
        next(islice(iterator, n, n), None)
199
200
def nth(iterable, n, default=None):
    """Return the item at position *n*, or *default* when there is none.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    remainder = islice(iterable, n, None)
    return next(remainder, default)
212
213
def all_equal(iterable, key=None):
    """
    Return ``True`` when every element equals every other one.

    >>> all_equal('aaaa')
    True
    >>> all_equal('aaab')
    False

    Pass *key* (a one-argument function) to compare transformed versions
    of the items instead:

    >>> all_equal('AaaA', key=str.casefold)
    True
    >>> all_equal([1, 2, 3], key=lambda x: x < 10)
    True

    An empty iterable counts as all-equal.

    """
    # groupby() starts a new group each time the (keyed) value changes, so
    # the items are all equal exactly when there is at most one group.
    groups = groupby(iterable, key)
    next(groups, None)  # step over the first group, if there is one
    # groupby yields (key, group) tuples, never None, so None here means
    # there was no second group.
    return next(groups, None) is None
238
239
def quantify(iterable, pred=bool):
    """Return how many times the predicate is true.

    >>> quantify([True, False, True])
    2

    *pred* defaults to :func:`bool`, so by default the truthy items are
    counted. The results of *pred* are added together with :func:`sum`, so
    it should return a ``bool`` (or an ``int``).

    """
    return sum(map(pred, iterable))
248
249
def pad_none(iterable):
    """Yield every item of *iterable*, followed by ``None`` forever.

    >>> take(5, pad_none(range(3)))
    [0, 1, 2, None, None]

    Useful for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    endless_none = repeat(None)
    return chain.from_iterable((iterable, endless_none))


# Backward-compatible alias for pad_none().
padnone = pad_none
265
266
def ncycles(iterable, n):
    """Return the items of *iterable*, replayed *n* times in a row.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    # Materialize once so that single-pass iterators can be replayed.
    saved = tuple(iterable)
    copies = repeat(saved, n)
    return chain.from_iterable(copies)
275
276
def dotproduct(vec1, vec2):
    """Return the sum of the element-wise products of two iterables.

    >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
    33.5
    >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
    33.5

    If the inputs differ in length, the extra items of the longer one are
    ignored.

    In Python 3.12 and later, use ``math.sumprod()`` instead.
    """
    return sum(left * right for left, right in zip(vec1, vec2))
288
289
def flatten(listOfLists):
    """Remove one level of nesting, returning an iterator over the items of
    each inner iterable in turn.

    >>> list(flatten([[0, 1], [2, 3]]))
    [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    # The outer iterable is turned into an iterator right away, so a
    # non-iterable argument fails immediately; the inner ones stay lazy.
    outer = iter(listOfLists)
    return chain.from_iterable(outer)
300
301
def repeatfunc(func, times=None, *args):
    """Return an iterator over the results of calling ``func(*args)`` over
    and over.

    Give *times* to stop after that many calls:

    >>> from operator import add
    >>> times = 4
    >>> args = 3, 5
    >>> list(repeatfunc(add, times, *args))
    [8, 8, 8, 8]

    Leave *times* as ``None`` for an endless iterator:

    >>> from random import randrange
    >>> times = None
    >>> args = 1, 11
    >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
    [2, 4, 8, 1, 8, 4]

    """
    arg_tuples = repeat(args) if times is None else repeat(args, times)
    return starmap(func, arg_tuples)
327
328
329def _pairwise(iterable):
330 """Returns an iterator of paired items, overlapping, from the original
331
332 >>> take(4, pairwise(count()))
333 [(0, 1), (1, 2), (2, 3), (3, 4)]
334
335 On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
336
337 """
338 a, b = tee(iterable)
339 next(b, None)
340 return zip(a, b)
341
342
343try:
344 from itertools import pairwise as itertools_pairwise
345except ImportError:
346 pairwise = _pairwise
347else:
348
349 def pairwise(iterable):
350 return itertools_pairwise(iterable)
351
352 pairwise.__doc__ = _pairwise.__doc__
353
354
class UnequalIterablesError(ValueError):
    """Raised when iterables that are required to match in length do not."""

    def __init__(self, details=None):
        # When given, *details* is a 3-item sequence: the length of the
        # iterable at index 0, the index of a mismatching iterable, and
        # that iterable's length.
        message = 'Iterables have different lengths'
        if details is not None:
            suffix = ': index 0 has length {}; index {} has length {}'
            message = message + suffix.format(*details)
        super().__init__(message)
364
365
def _zip_equal_generator(iterables):
    # Zip *iterables* lazily, raising UnequalIterablesError as soon as one
    # of them runs out before the others. zip_longest() pads the short ones
    # with the private _marker sentinel, which can never be a real item.
    for combo in zip_longest(*iterables, fillvalue=_marker):
        if any(val is _marker for val in combo):
            raise UnequalIterablesError()
        yield combo
372
373
def _zip_equal(*iterables):
    """Zip *iterables* together, raising if their lengths differ.

    When every input has a length, the lengths are compared up front and
    UnequalIterablesError (with details of the first mismatch) is raised
    immediately. Otherwise the check is deferred to _zip_equal_generator(),
    which raises when the shortest input runs out. With no inputs at all an
    empty iterator is returned, as with ``zip()``.
    """
    # Nothing to compare; match zip() instead of failing on iterables[0]
    # with an IndexError (e.g. grouper(x, 0, incomplete='strict')).
    if not iterables:
        return zip()

    # Check whether the iterables are all the same size.
    try:
        first_size = len(iterables[0])
        for i, it in enumerate(iterables[1:], 1):
            size = len(it)
            if size != first_size:
                raise UnequalIterablesError(details=(first_size, i, size))
        # All sizes are equal, we can use the built-in zip.
        return zip(*iterables)
    # If any one of the iterables didn't have a length, start reading
    # them until one runs out.
    except TypeError:
        return _zip_equal_generator(iterables)
388
389
def grouper(iterable, n, incomplete='fill', fillvalue=None):
    """Split *iterable* into consecutive tuples of length *n*.

    >>> list(grouper('ABCDEF', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    The keyword arguments *incomplete* and *fillvalue* decide what happens
    when the length of *iterable* is not a multiple of *n*.

    With *incomplete* set to `'fill'`, the final group is padded with
    *fillvalue*.

    >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]

    With *incomplete* set to `'ignore'`, the final short group is dropped.

    >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    With *incomplete* set to `'strict'`, a subclass of `ValueError` is raised.

    >>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
    >>> list(iterator)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    UnequalIterablesError

    Any other value of *incomplete* raises ``ValueError`` immediately.

    """
    # n references to one shared iterator: each output tuple therefore
    # pulls n consecutive items from it.
    shared = iter(iterable)
    lanes = [shared] * n

    if incomplete == 'fill':
        return zip_longest(*lanes, fillvalue=fillvalue)
    elif incomplete == 'strict':
        return _zip_equal(*lanes)
    elif incomplete == 'ignore':
        return zip(*lanes)
    raise ValueError('Expected fill, strict, or ignore')
428
429
def roundrobin(*iterables):
    """Visit input iterables in a cycle until each is exhausted.

    >>> list(roundrobin('ABC', 'D', 'EF'))
    ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`, but
    may perform better for some inputs (in particular when the number of
    iterables is small).

    """
    # Algorithm credited to George Sakkis
    # Lazily build one iterator per input.
    iterators = map(iter, iterables)
    # Each pass of this loop handles one "round" with num_active live
    # iterators. The count drops by one every time an iterator runs dry.
    for num_active in range(len(iterables), 0, -1):
        # Take the next num_active iterators from the previous cycle and
        # loop over them endlessly. The previous cycle resumes just after
        # the iterator that was exhausted, so that one is left out.
        iterators = cycle(islice(iterators, num_active))
        # map() stops as soon as next() raises StopIteration, i.e. the
        # first time any of the remaining iterators is exhausted.
        yield from map(next, iterators)
446
447
def partition(pred, iterable):
    """
    Split *iterable* into a pair of iterators based on *pred*.
    The first yields the items for which ``pred(item) == False``.
    The second yields the items for which ``pred(item) == True``.

    >>> is_odd = lambda x: x % 2 != 0
    >>> iterable = range(10)
    >>> even_items, odd_items = partition(is_odd, iterable)
    >>> list(even_items), list(odd_items)
    ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    When *pred* is None, :func:`bool` is used.

    >>> iterable = [0, 1, False, True, '', ' ']
    >>> false_items, true_items = partition(None, iterable)
    >>> list(false_items), list(true_items)
    ([0, False, ''], [1, True, ' '])

    """
    predicate = bool if pred is None else pred

    # Three independent copies of the input: one per output, plus one that
    # feeds the predicate.
    for_false, for_true, for_pred = tee(iterable, 3)
    # pred() runs once per item; its results are shared by both outputs.
    flags_false, flags_true = tee(map(predicate, for_pred))

    false_items = compress(for_false, map(not_, flags_false))
    true_items = compress(for_true, flags_true)
    return (false_items, true_items)
474
475
def powerset(iterable):
    """Yield every subset of *iterable*, as tuples, from smallest to largest.

    >>> list(powerset([1, 2, 3]))
    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    The input is not treated as a :class:`set`, so repeated elements in it
    lead to repeated subsets in the output.

    >>> seq = [1, 1, 0]
    >>> list(powerset(seq))
    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.
    """
    pool = list(iterable)
    sizes = range(len(pool) + 1)
    return chain.from_iterable(map(combinations, repeat(pool), sizes))
495
496
def unique_everseen(iterable, key=None):
    """
    Yield unique elements, preserving order.

    >>> list(unique_everseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D']
    >>> list(unique_everseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'D']

    Sequences with a mix of hashable and unhashable items can be used.
    The function will be slower (i.e., `O(n^2)`) for unhashable items.

    Remember that ``list`` objects are unhashable - you can use the *key*
    parameter to transform the list to a tuple (which is hashable) to
    avoid a slowdown.

    >>> iterable = ([1, 2], [2, 3], [1, 2])
    >>> list(unique_everseen(iterable))  # Slow
    [[1, 2], [2, 3]]
    >>> list(unique_everseen(iterable, key=tuple))  # Faster
    [[1, 2], [2, 3]]

    Similarly, you may want to convert unhashable ``set`` objects with
    ``key=frozenset``. For ``dict`` objects,
    ``key=lambda x: frozenset(x.items())`` can be used.

    """
    seenset = set()
    seenset_add = seenset.add
    seenlist = []
    seenlist_add = seenlist.append
    use_key = key is not None

    for element in iterable:
        k = key(element) if use_key else element
        # Only the set lookup/insertion sits inside the try: a TypeError
        # there means *k* is unhashable, so fall back to the O(n) list.
        # The yield stays outside, so a TypeError thrown into this
        # generator (gen.throw) propagates instead of being mistaken for an
        # unhashable key and re-yielding the element.
        try:
            if k in seenset:
                continue
            seenset_add(k)
        except TypeError:
            if k in seenlist:
                continue
            seenlist_add(k)
        yield element
540
541
def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing each run of consecutive
    duplicates into its first element.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    """
    runs = groupby(iterable, key)
    if key is None:
        # Without a key, each group's key is itself the run's element.
        return map(itemgetter(0), runs)
    # With a key, report the first original element of every run.
    return (next(members) for _, members in runs)
555
556
def unique(iterable, key=None, reverse=False):
    """Yield the distinct elements of *iterable* in sorted order.

    >>> list(unique([[1, 2], [3, 4], [1, 2]]))
    [[1, 2], [3, 4]]

    *key* and *reverse* are forwarded to :func:`sorted`.

    >>> list(unique('ABBcCAD', str.casefold))
    ['A', 'B', 'c', 'D']
    >>> list(unique('ABBcCAD', str.casefold, reverse=True))
    ['D', 'c', 'B', 'A']

    The elements need not be hashable, but they must be mutually
    comparable so that they can be sorted.
    """
    ordered = sorted(iterable, key=key, reverse=reverse)
    # After sorting, equal elements are adjacent: keep the first of each run
    # (the same thing unique_justseen() does).
    if key is None:
        return map(itemgetter(0), groupby(ordered))
    return map(next, map(itemgetter(1), groupby(ordered, key)))
575
576
def iter_except(func, exception, first=None):
    """Call *func* over and over, yielding its results, until it raises
    *exception*.

    This turns a call-until-exception interface into an iterator. It
    resembles ``iter(func, sentinel)``, except that an exception rather
    than a sentinel value ends the loop. If *first* is given, it is called
    once beforehand and its result is yielded first.

    >>> l = [0, 1, 2]
    >>> list(iter_except(l.pop, IndexError))
    [2, 1, 0]

    A tuple of exception types may be given as the stopping condition:

    >>> l = [1, 2, 3, '...', 4, 5, 6]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [7, 6, 5]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [4, 3, 2]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    []

    """
    with suppress(exception):
        if first is not None:
            yield first()
        for _ in repeat(None):
            yield func()
604
605
def first_true(iterable, default=None, pred=None):
    """
    Return the first item of *iterable* that tests true.

    When no such item exists, *default* is returned instead.

    When *pred* is given, the first item for which ``pred(item)`` is
    true is returned; otherwise the items' own truthiness is used.

    >>> first_true(range(10))
    1
    >>> first_true(range(10), pred=lambda x: x > 5)
    6
    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
    'missing'

    """
    test = bool if pred is None else pred
    for item in iterable:
        if test(item):
            return item
    return default
624
625
def random_product(*args, repeat=1):
    """Draw an item at random from each of the input iterables.

    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
    ('c', 3, 'Z')

    If *repeat* is provided as a keyword argument, that many items will be
    drawn from each iterable.

    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
    ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*args, repeat=repeat)``.

    Every input iterable is fully consumed into a tuple first. Randomness
    comes from the shared module-level :mod:`random` state.

    """
    pools = [tuple(pool) for pool in args] * repeat
    return tuple(choice(pool) for pool in pools)
644
645
def random_permutation(iterable, r=None):
    """Return a random *r* length permutation of the elements in *iterable*.

    If *r* is not specified or is ``None``, then *r* defaults to the length of
    *iterable*.

    >>> random_permutation(range(5))  # doctest:+SKIP
    (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    :func:`random.sample` raises ``ValueError`` if *r* is negative or
    larger than the number of elements.

    """
    pool = tuple(iterable)
    r = len(pool) if r is None else r
    return tuple(sample(pool, r))
662
663
def random_combination(iterable, r):
    """Return a random *r* length subsequence of the elements in *iterable*.

    >>> random_combination(range(5), 3)  # doctest:+SKIP
    (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    The chosen elements keep their original relative order.

    """
    pool = tuple(iterable)
    n = len(pool)
    # Sorting the sampled positions preserves the input order of the picks.
    indices = sorted(sample(range(n), r))
    return tuple(pool[i] for i in indices)
678
679
def random_combination_with_replacement(iterable, r):
    """Return a random *r* length subsequence of elements in *iterable*,
    allowing individual elements to be repeated.

    >>> random_combination_with_replacement(range(3), 5)  # doctest:+SKIP
    (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    n = len(pool)
    # Independent draws (repeats allowed), sorted to respect input order.
    indices = sorted(randrange(n) for i in range(r))
    return tuple(pool[i] for i in indices)
695
696
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

    >>> nth_combination(range(5), 3, 5)
    (0, 3, 4)

    ``ValueError`` will be raised if *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    Negative values of *index* count from the end, as with list indexing.
    """
    pool = tuple(iterable)
    n = len(pool)
    if (r < 0) or (r > n):
        raise ValueError('r must be between 0 and the length of iterable')

    # Total number of r-length combinations of the pool.
    c = comb(n, r)

    if index < 0:
        index += c

    if (index < 0) or (index >= c):
        raise IndexError('index out of range')

    # Pick elements one at a time. After each update below, c counts the
    # combinations (consistent with the choices made so far) whose next
    # element is pool[-1 - n]; their remaining r elements come from the n
    # elements after it. Skip whole blocks of c until index lands in one.
    result = []
    while r:
        c, n, r = c * r // n, n - 1, r - 1
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)
737
738
def prepend(value, iterator):
    """Return an iterator that produces *value* and then the items of
    *iterator*.

    >>> value = '0'
    >>> iterator = ['1', '2', '3']
    >>> list(prepend(value, iterator))
    ['0', '1', '2', '3']

    For prepending several values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterator)
752
753
def convolve(signal, kernel):
    """Compute the discrete linear convolution of two iterables, which is
    the same as multiplying the polynomials they describe.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

    >>> list(convolve([1, -1, -20], [1, -3]))
    [1, -4, -17, 60]

    Examples of popular kinds of kernels:

    * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      For image data, this blurs the image and reduces noise.
    * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function evaluated at evenly spaced inputs.
    * The kernel ``[1, -2, 1]`` estimates the second derivative of a
      function evaluated at evenly spaced inputs.

    Although convolution is mathematically commutative, the two inputs are
    handled differently: *signal* is read lazily and may be infinite,
    while *kernel* is read in full before any output is produced.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article:  https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown:  https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # This implementation comes from an older version of the itertools
    # documentation. While the newer implementation is a bit clearer,
    # this one was kept because the inlined window logic is faster
    # and it avoids an unnecessary deque-to-tuple conversion.
    flipped = tuple(kernel)[::-1]
    width = len(flipped)
    # Window over the most recent `width` signal values, primed with zeros
    # so that the output begins with the partial overlap at the left edge.
    window = deque(repeat(0, width), maxlen=width)
    # The trailing zeros flush out the partial overlap at the right edge.
    for value in chain(signal, repeat(0, width - 1)):
        window.append(value)
        yield _sumprod(flipped, window)
795
796
def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that keeps the rest of the iterator
    fully accessible.

    >>> it = iter('ABCdEfGhI')
    >>> all_upper, remainder = before_and_after(str.isupper, it)
    >>> ''.join(all_upper)
    'ABC'
    >>> ''.join(remainder)  # takewhile() would lose the 'd'
    'dEfGhI'

    The first iterator must be consumed completely before the second one
    produces valid results.
    """
    head_source, remainder = tee(it)
    head = takewhile(predicate, head_source)
    # compress() pulls from `head` before its selector. Each accepted item
    # therefore advances `remainder` by one through zip(remainder). The item
    # that fails the predicate is never matched, so it stays in `remainder`.
    advance_remainder = zip(remainder)
    return compress(head, advance_remainder), remainder
814
815
def triplewise(iterable):
    """Yield every run of three consecutive items of *iterable* as a tuple.

    >>> list(triplewise('ABCDE'))
    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # This deviates from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Stagger the copies: `third` leads by two items, `second` by one.
    next(islice(third, 2, 2), None)
    next(islice(second, 1, 1), None)
    return zip(first, second, third)
830
831
def _sliding_window_islice(iterable, n):
    # Fast path for small, non-zero values of n: make n tee copies, advance
    # copy k by k items, then zip the staggered copies together.
    streams = tee(iterable, n)
    for offset, stream in enumerate(streams):
        # Exhausting islice(stream, offset) discards `offset` items.
        deque(islice(stream, offset), maxlen=0)
    return zip(*streams)
838
839
def _sliding_window_deque(iterable, n):
    # Normal path for other values of n: keep a bounded deque of the latest
    # n items, pre-filled with n - 1 of them, and emit a snapshot per item.
    source = iter(iterable)
    window = deque(islice(source, n - 1), maxlen=n)
    for item in source:
        window.append(item)
        yield tuple(window)
847
848
def sliding_window(iterable, n):
    """Yield overlapping tuples of *n* consecutive items from *iterable*.

    >>> list(sliding_window(range(6), 4))
    [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]

    An *iterable* shorter than *n* yields nothing:

    >>> list(sliding_window(range(3), 4))
    []

    For a variant with more features, see :func:`windowed`.
    """
    # Pick the cheapest implementation for the window width.
    if n == 1:
        return zip(iterable)
    if n == 2:
        return pairwise(iterable)
    if n > 20:
        return _sliding_window_deque(iterable, n)
    if n > 2:
        return _sliding_window_islice(iterable, n)
    raise ValueError(f'n should be at least one, not {n}')
872
873
def subslices(iterable):
    """Yield every contiguous, non-empty slice of *iterable* as a list.

    >>> list(subslices('ABC'))
    [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This resembles :func:`substrings`, but the slices come out in a
    different order.
    """
    seq = list(iterable)
    # Each pair of cut points i < j in 0..len(seq) selects the non-empty
    # slice seq[i:j]; combinations() produces the pairs in the order shown.
    cut_pairs = combinations(range(len(seq) + 1), 2)
    return (seq[i:j] for i, j in cut_pairs)
886
887
def polynomial_from_roots(roots):
    """Return the coefficients of the monic polynomial with the given roots.

    >>> roots = [5, -4, 3]  # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
    [1, -4, -17, 60]

    Coefficients are listed from the highest power down to the constant.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # This recipe differs from the one in itertools docs in that it
    # applies list() after each call to convolve(). This avoids
    # hitting stack limits with nested generators.
    coefficients = [1]
    for root in roots:
        # Multiply the running product by the linear factor (x - root).
        factor_terms = (1, -root)
        coefficients = [*convolve(coefficients, factor_terms)]
    return coefficients
906
907
def iter_index(iterable, value, start=0, stop=None):
    """Yield each index at which *value* occurs in *iterable*, searching
    from index *start* (inclusive) up to index *stop* (exclusive).

    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
    [1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
    [1, 4]

    Non-scalar *values* are matched the same way the built-in types'
    ``index()`` methods match them.

    >>> list(iter_index('ABCDABCD', 'AB'))
    [0, 4]
    >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
    []
    >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
    [0, 2]

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    """
    find = getattr(iterable, 'index', None)
    if find is not None:
        # Fast path for sequences: let index() do the search, resuming
        # each time just past the previous hit. ValueError means no more.
        end = len(iterable) if stop is None else stop
        position = start
        with suppress(ValueError):
            while True:
                position = find(value, position, end)
                yield position
                position += 1
        return

    # Slow path for general iterables: compare element by element.
    window = islice(iterable, start, stop)
    for i, element in enumerate(window, start):
        if element is value or element == value:
            yield i
947
948
def sieve(n):
    """Yield the primes less than n.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    """
    # This implementation comes from an older version of the itertools
    # documentation. The newer implementation is easier to read but is
    # less lazy.
    if n > 2:
        yield 2
    start = 3
    # data[i] == 1 marks i as a prime candidate. The repeated (0, 1) pattern
    # marks every even index as composite and every odd index as a
    # candidate. 2 was yielded above, and indices 0 and 1 are never
    # searched because scanning begins at start = 3.
    data = bytearray((0, 1)) * (n // 2)
    # Only candidates p <= isqrt(n) can have multiples left to cross out.
    # iter_index() calls data.index() lazily, one hit at a time, so it sees
    # the entries cleared by earlier iterations of this loop.
    for p in iter_index(data, 1, start, stop=isqrt(n) + 1):
        # Every candidate still marked in [start, p * p) has survived
        # crossing out by all smaller primes, so it is prime.
        yield from iter_index(data, 1, start, p * p)
        # Clear the odd multiples of p from p * p upward. The step 2p skips
        # even multiples, which are already 0. A zero-filled bytes object of
        # the matching length is assigned to the extended slice.
        data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
        start = p * p
    # Everything still marked from the last start onward is prime.
    yield from iter_index(data, 1, start)
968
969
def _batched(iterable, n, *, strict=False):
    """Group the items of *iterable* into tuples of length *n*. When the
    item count is not a multiple of *n*:

    * with *strict* ``False`` the final tuple is simply shorter;
    * with *strict* ``True`` :exc:`ValueError` is raised instead.

    >>> list(batched('ABCDEFG', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

    On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    # iter(callable, sentinel) keeps slicing off n items until an empty
    # tuple signals that the source is exhausted.
    for batch in iter(lambda: tuple(islice(source, n)), ()):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch


if hexversion >= 0x30D00A2:  # pragma: no cover
    # itertools.batched gained the *strict* keyword in Python 3.13.
    from itertools import batched as itertools_batched

    def batched(iterable, n, *, strict=False):
        return itertools_batched(iterable, n, strict=strict)

    batched.__doc__ = _batched.__doc__
else:
    batched = _batched
999
1000
def transpose(it):
    """Return an iterator over the columns of the input matrix, so rows
    and columns trade places.

    >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
    [(1, 11), (2, 22), (3, 33)]

    The caller is responsible for passing rows of compatible lengths; on
    Python 3.10+ mismatched rows raise ``ValueError`` via ``zip(strict=True)``.
    An empty input produces no output.
    """
    rows = tuple(it)
    return _zip_strict(*rows)
1011
1012
def reshape(matrix, cols):
    """Regroup the items of the 2-D *matrix* into rows of *cols* items.

    >>> matrix = [(0, 1), (2, 3), (4, 5)]
    >>> cols = 3
    >>> list(reshape(matrix, cols))
    [(0, 1, 2), (3, 4, 5)]
    """
    # Read the matrix in row-major order, then cut into rows of `cols`.
    flat = chain.from_iterable(matrix)
    return batched(flat, cols)
1022
1023
def matmul(m1, m2):
    """Return the matrix product of *m1* and *m2*, as an iterator of row
    tuples.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller is responsible for supplying matrices whose dimensions are
    compatible.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    out_cols = len(m2[0])
    columns = transpose(m2)
    # Every (row, column) pair in row-major order, reduced to its dot
    # product, then regrouped into output rows of width out_cols.
    cells = starmap(_sumprod, product(m1, columns))
    return batched(cells, out_cols)
1037
1038
def _factor_pollard(n):
    # Find a nontrivial factor of n with Pollard's rho, using Floyd cycle
    # detection (tortoise steps once, hare twice). Works best for odd
    # composite n; raises ValueError when no factor turns up.
    for increment in range(1, n):
        tortoise = hare = 2
        divisor = 1
        while divisor == 1:
            tortoise = (tortoise * tortoise + increment) % n
            hare = (hare * hare + increment) % n
            hare = (hare * hare + increment) % n
            divisor = gcd(tortoise - hare, n)
        # divisor == n means the walk cycled without exposing a factor;
        # retry with the next polynomial x*x + increment.
        if divisor != n:
            return divisor
    raise ValueError('prime or under 5')
1053
1054
# Small primes used for trial division in factor(). Once these are divided
# out, any remaining value below 211**2 cannot have two prime factors, so it
# must itself be prime; factor() relies on this shortcut.
_primes_below_211 = tuple(sieve(211))
1056
1057
def factor(n):
    """Yield the prime factors of n, smallest first, with multiplicity.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]

    Small factors are found by trial division. Anything larger is either
    confirmed prime with ``is_prime`` or split into smaller pieces with
    Pollard's rho algorithm. Values below 2 yield nothing.
    """
    # Nothing to factor below 2.
    if n < 2:
        return

    # Strip out every small prime factor by trial division.
    for p in _primes_below_211:
        while n % p == 0:
            yield p
            n //= p

    # What is left has no prime factor below 211, so any piece under
    # 211**2 must be prime. Bigger pieces are confirmed prime or split
    # with Pollard's rho, and the halves are queued (FIFO) for more work.
    found = []
    pending = deque([n] if n > 1 else [])
    while pending:
        m = pending.popleft()
        if m < 211**2 or is_prime(m):
            found.append(m)
        else:
            d = _factor_pollard(m)
            pending.extend((d, m // d))
    yield from sorted(found)
1089
1090
def polynomial_eval(coefficients, x):
    """Return the value of the polynomial with the given coefficients at *x*.

    Computes with better numeric stability than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125

    Coefficients are listed from the highest power down to the constant.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    n = len(coefficients)
    if not n:
        # An empty polynomial is zero, expressed in the type of x.
        return type(x)(0)
    # x**(n-1), x**(n-2), ..., x**0 to line up with the coefficients.
    exponents = range(n - 1, -1, -1)
    powers = map(pow, repeat(x), exponents)
    return _sumprod(coefficients, powers)
1112
1113
def sum_of_squares(it):
    """Add up the squares of the input values.

    >>> sum_of_squares([10, 20, 30])
    1400

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Two copies of the same stream, multiplied pairwise and summed.
    left, right = tee(it)
    return _sumprod(left, right)
1123
1124
def polynomial_derivative(coefficients):
    """Return the coefficients of the first derivative of a polynomial.

    Evaluate the derivative of ``x³ - 4 x² - 17 x + 60``:

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]

    Coefficients are listed from the highest power down to the constant.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    highest_power = len(coefficients) - 1
    # The coefficient of x**k becomes k * coefficient. The constant term
    # (k == 0) drops out because zip() stops at the shorter input.
    exponents = range(highest_power, 0, -1)
    return [c * k for c, k in zip(coefficients, exponents)]
1142
1143
def totient(n):
    """Return how many integers from 1 to *n* are coprime with *n*.

    Euler's totient function φ(n) counts the totatives of *n*.
    Totatives are integers k in the range 1 ≤ k ≤ n such that gcd(n, k) = 1.

    >>> n = 9
    >>> totient(n)
    6

    >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
    >>> totatives
    [1, 2, 4, 5, 7, 8]
    >>> len(totatives)
    6

    Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    result = n
    # φ(n) = n * Π(1 - 1/p) over the distinct prime factors p. Each p still
    # divides the running result when it is applied, so result - result // p
    # is exact integer arithmetic.
    for p in set(factor(n)):
        result -= result // p
    return result
1166
1167
# Deterministic Miller–Rabin base sets: https://oeis.org/A014233
#
# Each entry is (limit, bases), ordered by increasing limit. is_prime() uses
# the first entry whose limit exceeds n and declares n prime only if n is a
# strong probable prime to every listed base. When n is at least the last
# limit, is_prime() falls back to random bases instead.
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    # limit is 2**64; a seven-base set covering all 64-bit integers.
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        3317044064679887385961981,
        # The first 13 primes.
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]
1182
1183
@lru_cache
def _shift_to_odd(n):
    """Split *n* into a power of two times an odd number.

    Returns ``(s, d)`` with ``2**s * d == n`` and *d* odd.
    """
    # (n - 1) ^ n has every bit set up to and including n's lowest set bit.
    # Its bit length is therefore one more than n's trailing-zero count.
    trailing_zeros = ((n - 1) ^ n).bit_length() - 1
    odd_part = n >> trailing_zeros
    assert (
        (1 << trailing_zeros) * odd_part == n
        and odd_part & 1
        and trailing_zeros >= 0
    )
    return trailing_zeros, odd_part
1191
1192
def _strong_probable_prime(n, base):
    """Run one Miller–Rabin round on odd *n* using witness *base*.

    Returns ``False`` when *base* proves *n* composite. Returns ``True``
    when *n* is a strong probable prime to that base.
    """
    assert (n > 2) and (n & 1) and (2 <= base < n)

    minus_one = n - 1
    # Decompose n - 1 as 2**twos * odd.
    twos, odd = _shift_to_odd(minus_one)

    residue = pow(base, odd, n)
    if residue in (1, minus_one):
        return True

    # Square up to twos - 1 more times. Hitting n - 1 (i.e. -1 mod n) along
    # the way means n passes for this base.
    for _ in range(twos - 1):
        residue = residue * residue % n
        if residue == minus_one:
            return True

    return False
1208
1209
# Separate instance of Random() that doesn't share state with the default
# module-level generator used by the ``random`` module's functions.
# is_prime() draws its random Miller–Rabin bases from here. Seeding or
# consuming the global generator therefore neither affects nor is affected
# by primality testing.
_private_randrange = random.Random().randrange
1213
1214
def is_prime(n):
    """Return ``True`` if *n* is prime and ``False`` otherwise.

    Basic examples:

    >>> is_prime(37)
    True
    >>> is_prime(3 * 13)
    False
    >>> is_prime(18_446_744_073_709_551_557)
    True

    Find the next prime over one billion:

    >>> next(filter(is_prime, count(10**9)))
    1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

    >>> from random import seed, randrange, getrandbits
    >>> seed(18675309)

    >>> next(filter(is_prime, map(getrandbits, repeat(200))))
    893303929355758292373272075469392561129886005037663238028407

    >>> next(filter(is_prime, map(randrange, repeat(10**60))))
    269638077304026462407872868003560484232362454342414618963649

    The result is exact for *n* below 10**24. Larger inputs use the
    probabilistic Miller-Rabin primality test, which has a less than 1 in
    2**128 chance of a false positive.
    """
    # Small inputs: membership in the primes below 17 decides the answer.
    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Trial division by the primes below 17 rejects most composites cheaply.
    if not (n & 1 and n % 3 and n % 5 and n % 7 and n % 11 and n % 13):
        return False

    # Use the first deterministic base set whose limit exceeds n.
    bases = next(
        (witnesses for limit, witnesses in _perfect_tests if n < limit),
        None,
    )
    if bases is None:
        # Beyond every known bound: 64 random witnesses, drawn lazily so
        # all() can stop at the first failure.
        bases = (_private_randrange(2, n - 1) for _ in range(64))

    return all(_strong_probable_prime(n, base) for base in bases)
1261
1262
def loops(n):
    """Return an iterable that yields exactly *n* items, for fast looping.

    Works like ``range(n)`` in a ``for`` statement but yields ``None`` on
    every step instead of creating integers.

    >>> i = 0
    >>> for _ in loops(5):
    ...     i += 1
    >>> i
    5

    """
    # repeat() hands back the same None object n times, so no new integer
    # objects are created per iteration.
    return repeat(None, n)
1275
1276
def multinomial(*counts):
    """Return the number of distinct arrangements of a multiset.

    The value ``multinomial(3, 4, 2)`` can be read in several equivalent
    ways:

    * In the expansion of ``(a + b + c)⁹``, the coefficient of the
      ``a³b⁴c²`` term is 1260.

    * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4
      greens, and 2 blues.

    * There are 1260 unique ways to place 9 distinct objects into three bins
      with sizes 3, 4, and 2.

    The :func:`multinomial` function computes the length of
    :func:`distinct_permutations`. For example, there are 83,160 distinct
    anagrams of the word "abracadabra":

        >>> from more_itertools import distinct_permutations, ilen
        >>> ilen(distinct_permutations('abracadabra'))
        83160

    The same number follows directly from the letter counts, 5a 2b 2r 1c 1d:

        >>> from collections import Counter
        >>> list(Counter('abracadabra').values())
        [5, 2, 2, 1, 1]
        >>> multinomial(5, 2, 2, 1, 1)
        83160

    A binomial coefficient is the special case with exactly two categories.
    For example, the number of ways to arrange 12 balls with 5 reds and
    7 blues is ``multinomial(5, 7)`` or ``math.comb(12, 5)``.

    Likewise, factorial is the special case where every multiplicity is 1, so
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    Reference: https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    arrangements = 1
    for subtotal, k in zip(accumulate(counts), counts):
        # Choose which of the first `subtotal` positions the k items of the
        # newest category occupy.
        arrangements *= comb(subtotal, k)
    return arrangements