1import math
2import warnings
3
4from collections import Counter, defaultdict, deque, abc
5from collections.abc import Sequence
6from contextlib import suppress
7from functools import cached_property, partial, reduce, wraps
8from heapq import heapify, heapreplace
9from itertools import (
10 chain,
11 combinations,
12 compress,
13 count,
14 cycle,
15 dropwhile,
16 groupby,
17 islice,
18 permutations,
19 repeat,
20 starmap,
21 takewhile,
22 tee,
23 zip_longest,
24 product,
25)
26from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
27from math import ceil
28from queue import Empty, Queue
29from random import random, randrange, shuffle, uniform
30from operator import (
31 attrgetter,
32 is_not,
33 itemgetter,
34 lt,
35 mul,
36 neg,
37 sub,
38 gt,
39)
40from sys import hexversion, maxsize
41from time import monotonic
42
43from .recipes import (
44 _marker,
45 _zip_equal,
46 UnequalIterablesError,
47 consume,
48 first_true,
49 flatten,
50 is_prime,
51 nth,
52 powerset,
53 sieve,
54 take,
55 unique_everseen,
56 all_equal,
57 batched,
58)
59
60__all__ = [
61 'AbortThread',
62 'SequenceView',
63 'UnequalIterablesError',
64 'adjacent',
65 'all_unique',
66 'always_iterable',
67 'always_reversible',
68 'argmax',
69 'argmin',
70 'bucket',
71 'callback_iter',
72 'chunked',
73 'chunked_even',
74 'circular_shifts',
75 'collapse',
76 'combination_index',
77 'combination_with_replacement_index',
78 'consecutive_groups',
79 'constrained_batches',
80 'consumer',
81 'count_cycle',
82 'countable',
83 'derangements',
84 'dft',
85 'difference',
86 'distinct_combinations',
87 'distinct_permutations',
88 'distribute',
89 'divide',
90 'doublestarmap',
91 'duplicates_everseen',
92 'duplicates_justseen',
93 'classify_unique',
94 'exactly_n',
95 'extract',
96 'filter_except',
97 'filter_map',
98 'first',
99 'gray_product',
100 'groupby_transform',
101 'ichunked',
102 'iequals',
103 'idft',
104 'ilen',
105 'interleave',
106 'interleave_evenly',
107 'interleave_longest',
108 'interleave_randomly',
109 'intersperse',
110 'is_sorted',
111 'islice_extended',
112 'iterate',
113 'iter_suppress',
114 'join_mappings',
115 'last',
116 'locate',
117 'longest_common_prefix',
118 'lstrip',
119 'make_decorator',
120 'map_except',
121 'map_if',
122 'map_reduce',
123 'mark_ends',
124 'minmax',
125 'nth_or_last',
126 'nth_permutation',
127 'nth_prime',
128 'nth_product',
129 'nth_combination_with_replacement',
130 'numeric_range',
131 'one',
132 'only',
133 'outer_product',
134 'padded',
135 'partial_product',
136 'partitions',
137 'peekable',
138 'permutation_index',
139 'powerset_of_sets',
140 'product_index',
141 'raise_',
142 'repeat_each',
143 'repeat_last',
144 'replace',
145 'rlocate',
146 'rstrip',
147 'run_length',
148 'sample',
149 'seekable',
150 'set_partitions',
151 'side_effect',
152 'sliced',
153 'sort_together',
154 'split_after',
155 'split_at',
156 'split_before',
157 'split_into',
158 'split_when',
159 'spy',
160 'stagger',
161 'strip',
162 'strictly_n',
163 'substrings',
164 'substrings_indexes',
165 'takewhile_inclusive',
166 'time_limited',
167 'unique_in_window',
168 'unique_to_each',
169 'unzip',
170 'value_chain',
171 'windowed',
172 'windowed_complete',
173 'with_iter',
174 'zip_broadcast',
175 'zip_equal',
176 'zip_offset',
177]
178
179# math.sumprod is available for Python 3.12+
180try:
181 from math import sumprod as _fsumprod
182
183except ImportError: # pragma: no cover
184 # Extended precision algorithms from T. J. Dekker,
185 # "A Floating-Point Technique for Extending the Available Precision"
186 # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
187 # Formulas: (5.5) (5.6) and (5.8). Code: mul12()
188
189 def dl_split(x: float):
190 "Split a float into two half-precision components."
191 t = x * 134217729.0 # Veltkamp constant = 2.0 ** 27 + 1
192 hi = t - (t - x)
193 lo = x - hi
194 return hi, lo
195
196 def dl_mul(x, y):
197 "Lossless multiplication."
198 xx_hi, xx_lo = dl_split(x)
199 yy_hi, yy_lo = dl_split(y)
200 p = xx_hi * yy_hi
201 q = xx_hi * yy_lo + xx_lo * yy_hi
202 z = p + q
203 zz = p - z + q + xx_lo * yy_lo
204 return z, zz
205
206 def _fsumprod(p, q):
207 return fsum(chain.from_iterable(map(dl_mul, p, q)))
208
209
210def chunked(iterable, n, strict=False):
211 """Break *iterable* into lists of length *n*:
212
213 >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
214 [[1, 2, 3], [4, 5, 6]]
215
216 By the default, the last yielded list will have fewer than *n* elements
217 if the length of *iterable* is not divisible by *n*:
218
219 >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
220 [[1, 2, 3], [4, 5, 6], [7, 8]]
221
222 To use a fill-in value instead, see the :func:`grouper` recipe.
223
224 If the length of *iterable* is not divisible by *n* and *strict* is
225 ``True``, then ``ValueError`` will be raised before the last
226 list is yielded.
227
228 """
229 iterator = iter(partial(take, n, iter(iterable)), [])
230 if strict:
231 if n is None:
232 raise ValueError('n must not be None when using strict mode.')
233
234 def ret():
235 for chunk in iterator:
236 if len(chunk) != n:
237 raise ValueError('iterable is not divisible by n.')
238 yield chunk
239
240 return ret()
241 else:
242 return iterator
243
244
245def first(iterable, default=_marker):
246 """Return the first item of *iterable*, or *default* if *iterable* is
247 empty.
248
249 >>> first([0, 1, 2, 3])
250 0
251 >>> first([], 'some default')
252 'some default'
253
254 If *default* is not provided and there are no items in the iterable,
255 raise ``ValueError``.
256
257 :func:`first` is useful when you have a generator of expensive-to-retrieve
258 values and want any arbitrary one. It is marginally shorter than
259 ``next(iter(iterable), default)``.
260
261 """
262 for item in iterable:
263 return item
264 if default is _marker:
265 raise ValueError(
266 'first() was called on an empty iterable, '
267 'and no default value was provided.'
268 )
269 return default
270
271
272def last(iterable, default=_marker):
273 """Return the last item of *iterable*, or *default* if *iterable* is
274 empty.
275
276 >>> last([0, 1, 2, 3])
277 3
278 >>> last([], 'some default')
279 'some default'
280
281 If *default* is not provided and there are no items in the iterable,
282 raise ``ValueError``.
283 """
284 try:
285 if isinstance(iterable, Sequence):
286 return iterable[-1]
287 # Work around https://bugs.python.org/issue38525
288 if getattr(iterable, '__reversed__', None):
289 return next(reversed(iterable))
290 return deque(iterable, maxlen=1)[-1]
291 except (IndexError, TypeError, StopIteration):
292 if default is _marker:
293 raise ValueError(
294 'last() was called on an empty iterable, '
295 'and no default value was provided.'
296 )
297 return default
298
299
300def nth_or_last(iterable, n, default=_marker):
301 """Return the nth or the last item of *iterable*,
302 or *default* if *iterable* is empty.
303
304 >>> nth_or_last([0, 1, 2, 3], 2)
305 2
306 >>> nth_or_last([0, 1], 2)
307 1
308 >>> nth_or_last([], 0, 'some default')
309 'some default'
310
311 If *default* is not provided and there are no items in the iterable,
312 raise ``ValueError``.
313 """
314 return last(islice(iterable, n + 1), default=default)
315
316
317class peekable:
318 """Wrap an iterator to allow lookahead and prepending elements.
319
320 Call :meth:`peek` on the result to get the value that will be returned
321 by :func:`next`. This won't advance the iterator:
322
323 >>> p = peekable(['a', 'b'])
324 >>> p.peek()
325 'a'
326 >>> next(p)
327 'a'
328
329 Pass :meth:`peek` a default value to return that instead of raising
330 ``StopIteration`` when the iterator is exhausted.
331
332 >>> p = peekable([])
333 >>> p.peek('hi')
334 'hi'
335
336 peekables also offer a :meth:`prepend` method, which "inserts" items
337 at the head of the iterable:
338
339 >>> p = peekable([1, 2, 3])
340 >>> p.prepend(10, 11, 12)
341 >>> next(p)
342 10
343 >>> p.peek()
344 11
345 >>> list(p)
346 [11, 12, 1, 2, 3]
347
348 peekables can be indexed. Index 0 is the item that will be returned by
349 :func:`next`, index 1 is the item after that, and so on:
350 The values up to the given index will be cached.
351
352 >>> p = peekable(['a', 'b', 'c', 'd'])
353 >>> p[0]
354 'a'
355 >>> p[1]
356 'b'
357 >>> next(p)
358 'a'
359
360 Negative indexes are supported, but be aware that they will cache the
361 remaining items in the source iterator, which may require significant
362 storage.
363
364 To check whether a peekable is exhausted, check its truth value:
365
366 >>> p = peekable(['a', 'b'])
367 >>> if p: # peekable has items
368 ... list(p)
369 ['a', 'b']
370 >>> if not p: # peekable is exhausted
371 ... list(p)
372 []
373
374 """
375
376 def __init__(self, iterable):
377 self._it = iter(iterable)
378 self._cache = deque()
379
380 def __iter__(self):
381 return self
382
383 def __bool__(self):
384 try:
385 self.peek()
386 except StopIteration:
387 return False
388 return True
389
390 def peek(self, default=_marker):
391 """Return the item that will be next returned from ``next()``.
392
393 Return ``default`` if there are no items left. If ``default`` is not
394 provided, raise ``StopIteration``.
395
396 """
397 if not self._cache:
398 try:
399 self._cache.append(next(self._it))
400 except StopIteration:
401 if default is _marker:
402 raise
403 return default
404 return self._cache[0]
405
406 def prepend(self, *items):
407 """Stack up items to be the next ones returned from ``next()`` or
408 ``self.peek()``. The items will be returned in
409 first in, first out order::
410
411 >>> p = peekable([1, 2, 3])
412 >>> p.prepend(10, 11, 12)
413 >>> next(p)
414 10
415 >>> list(p)
416 [11, 12, 1, 2, 3]
417
418 It is possible, by prepending items, to "resurrect" a peekable that
419 previously raised ``StopIteration``.
420
421 >>> p = peekable([])
422 >>> next(p)
423 Traceback (most recent call last):
424 ...
425 StopIteration
426 >>> p.prepend(1)
427 >>> next(p)
428 1
429 >>> next(p)
430 Traceback (most recent call last):
431 ...
432 StopIteration
433
434 """
435 self._cache.extendleft(reversed(items))
436
437 def __next__(self):
438 if self._cache:
439 return self._cache.popleft()
440
441 return next(self._it)
442
443 def _get_slice(self, index):
444 # Normalize the slice's arguments
445 step = 1 if (index.step is None) else index.step
446 if step > 0:
447 start = 0 if (index.start is None) else index.start
448 stop = maxsize if (index.stop is None) else index.stop
449 elif step < 0:
450 start = -1 if (index.start is None) else index.start
451 stop = (-maxsize - 1) if (index.stop is None) else index.stop
452 else:
453 raise ValueError('slice step cannot be zero')
454
455 # If either the start or stop index is negative, we'll need to cache
456 # the rest of the iterable in order to slice from the right side.
457 if (start < 0) or (stop < 0):
458 self._cache.extend(self._it)
459 # Otherwise we'll need to find the rightmost index and cache to that
460 # point.
461 else:
462 n = min(max(start, stop) + 1, maxsize)
463 cache_len = len(self._cache)
464 if n >= cache_len:
465 self._cache.extend(islice(self._it, n - cache_len))
466
467 return list(self._cache)[index]
468
469 def __getitem__(self, index):
470 if isinstance(index, slice):
471 return self._get_slice(index)
472
473 cache_len = len(self._cache)
474 if index < 0:
475 self._cache.extend(self._it)
476 elif index >= cache_len:
477 self._cache.extend(islice(self._it, index + 1 - cache_len))
478
479 return self._cache[index]
480
481
482def consumer(func):
483 """Decorator that automatically advances a PEP-342-style "reverse iterator"
484 to its first yield point so you don't have to call ``next()`` on it
485 manually.
486
487 >>> @consumer
488 ... def tally():
489 ... i = 0
490 ... while True:
491 ... print('Thing number %s is %s.' % (i, (yield)))
492 ... i += 1
493 ...
494 >>> t = tally()
495 >>> t.send('red')
496 Thing number 0 is red.
497 >>> t.send('fish')
498 Thing number 1 is fish.
499
500 Without the decorator, you would have to call ``next(t)`` before
501 ``t.send()`` could be used.
502
503 """
504
505 @wraps(func)
506 def wrapper(*args, **kwargs):
507 gen = func(*args, **kwargs)
508 next(gen)
509 return gen
510
511 return wrapper
512
513
514def ilen(iterable):
515 """Return the number of items in *iterable*.
516
517 For example, there are 168 prime numbers below 1,000:
518
519 >>> ilen(sieve(1000))
520 168
521
522 Equivalent to, but faster than::
523
524 def ilen(iterable):
525 count = 0
526 for _ in iterable:
527 count += 1
528 return count
529
530 This fully consumes the iterable, so handle with care.
531
532 """
533 # This is the "most beautiful of the fast variants" of this function.
534 # If you think you can improve on it, please ensure that your version
535 # is both 10x faster and 10x more beautiful.
536 return sum(compress(repeat(1), zip(iterable)))
537
538
539def iterate(func, start):
540 """Return ``start``, ``func(start)``, ``func(func(start))``, ...
541
542 Produces an infinite iterator. To add a stopping condition,
543 use :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`:.
544
545 >>> take(10, iterate(lambda x: 2*x, 1))
546 [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
547
548 >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2
549 >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10)))
550 [10, 5, 16, 8, 4, 2, 1]
551
552 """
553 with suppress(StopIteration):
554 while True:
555 yield start
556 start = func(start)
557
558
559def with_iter(context_manager):
560 """Wrap an iterable in a ``with`` statement, so it closes once exhausted.
561
562 For example, this will close the file when the iterator is exhausted::
563
564 upper_lines = (line.upper() for line in with_iter(open('foo')))
565
566 Any context manager which returns an iterable is a candidate for
567 ``with_iter``.
568
569 """
570 with context_manager as iterable:
571 yield from iterable
572
573
574def one(iterable, too_short=None, too_long=None):
575 """Return the first item from *iterable*, which is expected to contain only
576 that item. Raise an exception if *iterable* is empty or has more than one
577 item.
578
579 :func:`one` is useful for ensuring that an iterable contains only one item.
580 For example, it can be used to retrieve the result of a database query
581 that is expected to return a single row.
582
583 If *iterable* is empty, ``ValueError`` will be raised. You may specify a
584 different exception with the *too_short* keyword:
585
586 >>> it = []
587 >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL
588 Traceback (most recent call last):
589 ...
590 ValueError: too few items in iterable (expected 1)'
591 >>> too_short = IndexError('too few items')
592 >>> one(it, too_short=too_short) # doctest: +IGNORE_EXCEPTION_DETAIL
593 Traceback (most recent call last):
594 ...
595 IndexError: too few items
596
597 Similarly, if *iterable* contains more than one item, ``ValueError`` will
598 be raised. You may specify a different exception with the *too_long*
599 keyword:
600
601 >>> it = ['too', 'many']
602 >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL
603 Traceback (most recent call last):
604 ...
605 ValueError: Expected exactly one item in iterable, but got 'too',
606 'many', and perhaps more.
607 >>> too_long = RuntimeError
608 >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL
609 Traceback (most recent call last):
610 ...
611 RuntimeError
612
613 Note that :func:`one` attempts to advance *iterable* twice to ensure there
614 is only one item. See :func:`spy` or :func:`peekable` to check iterable
615 contents less destructively.
616
617 """
618 iterator = iter(iterable)
619 for first in iterator:
620 for second in iterator:
621 msg = (
622 f'Expected exactly one item in iterable, but got {first!r}, '
623 f'{second!r}, and perhaps more.'
624 )
625 raise too_long or ValueError(msg)
626 return first
627 raise too_short or ValueError('too few items in iterable (expected 1)')
628
629
630def raise_(exception, *args):
631 raise exception(*args)
632
633
634def strictly_n(iterable, n, too_short=None, too_long=None):
635 """Validate that *iterable* has exactly *n* items and return them if
636 it does. If it has fewer than *n* items, call function *too_short*
637 with those items. If it has more than *n* items, call function
638 *too_long* with the first ``n + 1`` items.
639
640 >>> iterable = ['a', 'b', 'c', 'd']
641 >>> n = 4
642 >>> list(strictly_n(iterable, n))
643 ['a', 'b', 'c', 'd']
644
645 Note that the returned iterable must be consumed in order for the check to
646 be made.
647
648 By default, *too_short* and *too_long* are functions that raise
649 ``ValueError``.
650
651 >>> list(strictly_n('ab', 3)) # doctest: +IGNORE_EXCEPTION_DETAIL
652 Traceback (most recent call last):
653 ...
654 ValueError: too few items in iterable (got 2)
655
656 >>> list(strictly_n('abc', 2)) # doctest: +IGNORE_EXCEPTION_DETAIL
657 Traceback (most recent call last):
658 ...
659 ValueError: too many items in iterable (got at least 3)
660
661 You can instead supply functions that do something else.
662 *too_short* will be called with the number of items in *iterable*.
663 *too_long* will be called with `n + 1`.
664
665 >>> def too_short(item_count):
666 ... raise RuntimeError
667 >>> it = strictly_n('abcd', 6, too_short=too_short)
668 >>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL
669 Traceback (most recent call last):
670 ...
671 RuntimeError
672
673 >>> def too_long(item_count):
674 ... print('The boss is going to hear about this')
675 >>> it = strictly_n('abcdef', 4, too_long=too_long)
676 >>> list(it)
677 The boss is going to hear about this
678 ['a', 'b', 'c', 'd']
679
680 """
681 if too_short is None:
682 too_short = lambda item_count: raise_(
683 ValueError,
684 f'Too few items in iterable (got {item_count})',
685 )
686
687 if too_long is None:
688 too_long = lambda item_count: raise_(
689 ValueError,
690 f'Too many items in iterable (got at least {item_count})',
691 )
692
693 it = iter(iterable)
694
695 sent = 0
696 for item in islice(it, n):
697 yield item
698 sent += 1
699
700 if sent < n:
701 too_short(sent)
702 return
703
704 for item in it:
705 too_long(n + 1)
706 return
707
708
709def distinct_permutations(iterable, r=None):
710 """Yield successive distinct permutations of the elements in *iterable*.
711
712 >>> sorted(distinct_permutations([1, 0, 1]))
713 [(0, 1, 1), (1, 0, 1), (1, 1, 0)]
714
715 Equivalent to yielding from ``set(permutations(iterable))``, except
716 duplicates are not generated and thrown away. For larger input sequences
717 this is much more efficient.
718
719 Duplicate permutations arise when there are duplicated elements in the
720 input iterable. The number of items returned is
721 `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
722 items input, and each `x_i` is the count of a distinct item in the input
723 sequence. The function :func:`multinomial` computes this directly.
724
725 If *r* is given, only the *r*-length permutations are yielded.
726
727 >>> sorted(distinct_permutations([1, 0, 1], r=2))
728 [(0, 1), (1, 0), (1, 1)]
729 >>> sorted(distinct_permutations(range(3), r=2))
730 [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
731
732 *iterable* need not be sortable, but note that using equal (``x == y``)
733 but non-identical (``id(x) != id(y)``) elements may produce surprising
734 behavior. For example, ``1`` and ``True`` are equal but non-identical:
735
736 >>> list(distinct_permutations([1, True, '3'])) # doctest: +SKIP
737 [
738 (1, True, '3'),
739 (1, '3', True),
740 ('3', 1, True)
741 ]
742 >>> list(distinct_permutations([1, 2, '3'])) # doctest: +SKIP
743 [
744 (1, 2, '3'),
745 (1, '3', 2),
746 (2, 1, '3'),
747 (2, '3', 1),
748 ('3', 1, 2),
749 ('3', 2, 1)
750 ]
751 """
752
753 # Algorithm: https://w.wiki/Qai
754 def _full(A):
755 while True:
756 # Yield the permutation we have
757 yield tuple(A)
758
759 # Find the largest index i such that A[i] < A[i + 1]
760 for i in range(size - 2, -1, -1):
761 if A[i] < A[i + 1]:
762 break
763 # If no such index exists, this permutation is the last one
764 else:
765 return
766
767 # Find the largest index j greater than j such that A[i] < A[j]
768 for j in range(size - 1, i, -1):
769 if A[i] < A[j]:
770 break
771
772 # Swap the value of A[i] with that of A[j], then reverse the
773 # sequence from A[i + 1] to form the new permutation
774 A[i], A[j] = A[j], A[i]
775 A[i + 1 :] = A[: i - size : -1] # A[i + 1:][::-1]
776
777 # Algorithm: modified from the above
778 def _partial(A, r):
779 # Split A into the first r items and the last r items
780 head, tail = A[:r], A[r:]
781 right_head_indexes = range(r - 1, -1, -1)
782 left_tail_indexes = range(len(tail))
783
784 while True:
785 # Yield the permutation we have
786 yield tuple(head)
787
788 # Starting from the right, find the first index of the head with
789 # value smaller than the maximum value of the tail - call it i.
790 pivot = tail[-1]
791 for i in right_head_indexes:
792 if head[i] < pivot:
793 break
794 pivot = head[i]
795 else:
796 return
797
798 # Starting from the left, find the first value of the tail
799 # with a value greater than head[i] and swap.
800 for j in left_tail_indexes:
801 if tail[j] > head[i]:
802 head[i], tail[j] = tail[j], head[i]
803 break
804 # If we didn't find one, start from the right and find the first
805 # index of the head with a value greater than head[i] and swap.
806 else:
807 for j in right_head_indexes:
808 if head[j] > head[i]:
809 head[i], head[j] = head[j], head[i]
810 break
811
812 # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
813 tail += head[: i - r : -1] # head[i + 1:][::-1]
814 i += 1
815 head[i:], tail[:] = tail[: r - i], tail[r - i :]
816
817 items = list(iterable)
818
819 try:
820 items.sort()
821 sortable = True
822 except TypeError:
823 sortable = False
824
825 indices_dict = defaultdict(list)
826
827 for item in items:
828 indices_dict[items.index(item)].append(item)
829
830 indices = [items.index(item) for item in items]
831 indices.sort()
832
833 equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}
834
835 def permuted_items(permuted_indices):
836 return tuple(
837 next(equivalent_items[index]) for index in permuted_indices
838 )
839
840 size = len(items)
841 if r is None:
842 r = size
843
844 # functools.partial(_partial, ... )
845 algorithm = _full if (r == size) else partial(_partial, r=r)
846
847 if 0 < r <= size:
848 if sortable:
849 return algorithm(items)
850 else:
851 return (
852 permuted_items(permuted_indices)
853 for permuted_indices in algorithm(indices)
854 )
855
856 return iter(() if r else ((),))
857
858
859def derangements(iterable, r=None):
860 """Yield successive derangements of the elements in *iterable*.
861
862 A derangement is a permutation in which no element appears at its original
863 index. In other words, a derangement is a permutation that has no fixed points.
864
865 Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
866 The code below outputs all of the different ways to assign gift recipients
867 such that nobody is assigned to himself or herself:
868
869 >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
870 ... print(', '.join(d))
871 Bob, Alice, Dave, Carol
872 Bob, Carol, Dave, Alice
873 Bob, Dave, Alice, Carol
874 Carol, Alice, Dave, Bob
875 Carol, Dave, Alice, Bob
876 Carol, Dave, Bob, Alice
877 Dave, Alice, Bob, Carol
878 Dave, Carol, Alice, Bob
879 Dave, Carol, Bob, Alice
880
881 If *r* is given, only the *r*-length derangements are yielded.
882
883 >>> sorted(derangements(range(3), 2))
884 [(1, 0), (1, 2), (2, 0)]
885 >>> sorted(derangements([0, 2, 3], 2))
886 [(2, 0), (2, 3), (3, 0)]
887
888 Elements are treated as unique based on their position, not on their value.
889 If the input elements are unique, there will be no repeated values within a
890 permutation.
891
892 The number of derangements of a set of size *n* is known as the
893 "subfactorial of n". For n > 0, the subfactorial is:
894 ``round(math.factorial(n) / math.e)``.
895
896 References:
897
898 * Article: https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics
899 * Sizes: https://oeis.org/A000166
900 """
901 xs = tuple(iterable)
902 ys = tuple(range(len(xs)))
903 return compress(
904 permutations(xs, r=r),
905 map(all, map(map, repeat(is_not), repeat(ys), permutations(ys, r=r))),
906 )
907
908
909def intersperse(e, iterable, n=1):
910 """Intersperse filler element *e* among the items in *iterable*, leaving
911 *n* items between each filler element.
912
913 >>> list(intersperse('!', [1, 2, 3, 4, 5]))
914 [1, '!', 2, '!', 3, '!', 4, '!', 5]
915
916 >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
917 [1, 2, None, 3, 4, None, 5]
918
919 """
920 if n == 0:
921 raise ValueError('n must be > 0')
922 elif n == 1:
923 # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2...
924 # islice(..., 1, None) -> x_0, e, x_1, e, x_2...
925 return islice(interleave(repeat(e), iterable), 1, None)
926 else:
927 # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]...
928 # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]...
929 # flatten(...) -> x_0, x_1, e, x_2, x_3...
930 filler = repeat([e])
931 chunks = chunked(iterable, n)
932 return flatten(islice(interleave(filler, chunks), 1, None))
933
934
935def unique_to_each(*iterables):
936 """Return the elements from each of the input iterables that aren't in the
937 other input iterables.
938
939 For example, suppose you have a set of packages, each with a set of
940 dependencies::
941
942 {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}
943
944 If you remove one package, which dependencies can also be removed?
945
946 If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
947 associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for
948 ``pkg_2``, and ``D`` is only needed for ``pkg_3``::
949
950 >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
951 [['A'], ['C'], ['D']]
952
953 If there are duplicates in one input iterable that aren't in the others
954 they will be duplicated in the output. Input order is preserved::
955
956 >>> unique_to_each("mississippi", "missouri")
957 [['p', 'p'], ['o', 'u', 'r']]
958
959 It is assumed that the elements of each iterable are hashable.
960
961 """
962 pool = [list(it) for it in iterables]
963 counts = Counter(chain.from_iterable(map(set, pool)))
964 uniques = {element for element in counts if counts[element] == 1}
965 return [list(filter(uniques.__contains__, it)) for it in pool]
966
967
968def windowed(seq, n, fillvalue=None, step=1):
969 """Return a sliding window of width *n* over the given iterable.
970
971 >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
972 >>> list(all_windows)
973 [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
974
975 When the window is larger than the iterable, *fillvalue* is used in place
976 of missing values:
977
978 >>> list(windowed([1, 2, 3], 4))
979 [(1, 2, 3, None)]
980
981 Each window will advance in increments of *step*:
982
983 >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
984 [(1, 2, 3), (3, 4, 5), (5, 6, '!')]
985
986 To slide into the iterable's items, use :func:`chain` to add filler items
987 to the left:
988
989 >>> iterable = [1, 2, 3, 4]
990 >>> n = 3
991 >>> padding = [None] * (n - 1)
992 >>> list(windowed(chain(padding, iterable), 3))
993 [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
994 """
995 if n < 0:
996 raise ValueError('n must be >= 0')
997 if n == 0:
998 yield ()
999 return
1000 if step < 1:
1001 raise ValueError('step must be >= 1')
1002
1003 iterator = iter(seq)
1004
1005 # Generate first window
1006 window = deque(islice(iterator, n), maxlen=n)
1007
1008 # Deal with the first window not being full
1009 if not window:
1010 return
1011 if len(window) < n:
1012 yield tuple(window) + ((fillvalue,) * (n - len(window)))
1013 return
1014 yield tuple(window)
1015
1016 # Create the filler for the next windows. The padding ensures
1017 # we have just enough elements to fill the last window.
1018 padding = (fillvalue,) * (n - 1 if step >= n else step - 1)
1019 filler = map(window.append, chain(iterator, padding))
1020
1021 # Generate the rest of the windows
1022 for _ in islice(filler, step - 1, None, step):
1023 yield tuple(window)
1024
1025
1026def substrings(iterable):
1027 """Yield all of the substrings of *iterable*.
1028
1029 >>> [''.join(s) for s in substrings('more')]
1030 ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']
1031
1032 Note that non-string iterables can also be subdivided.
1033
1034 >>> list(substrings([0, 1, 2]))
1035 [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]
1036
1037 """
1038 # The length-1 substrings
1039 seq = []
1040 for item in iterable:
1041 seq.append(item)
1042 yield (item,)
1043 seq = tuple(seq)
1044 item_count = len(seq)
1045
1046 # And the rest
1047 for n in range(2, item_count + 1):
1048 for i in range(item_count - n + 1):
1049 yield seq[i : i + n]
1050
1051
1052def substrings_indexes(seq, reverse=False):
1053 """Yield all substrings and their positions in *seq*
1054
1055 The items yielded will be a tuple of the form ``(substr, i, j)``, where
1056 ``substr == seq[i:j]``.
1057
1058 This function only works for iterables that support slicing, such as
1059 ``str`` objects.
1060
1061 >>> for item in substrings_indexes('more'):
1062 ... print(item)
1063 ('m', 0, 1)
1064 ('o', 1, 2)
1065 ('r', 2, 3)
1066 ('e', 3, 4)
1067 ('mo', 0, 2)
1068 ('or', 1, 3)
1069 ('re', 2, 4)
1070 ('mor', 0, 3)
1071 ('ore', 1, 4)
1072 ('more', 0, 4)
1073
1074 Set *reverse* to ``True`` to yield the same items in the opposite order.
1075
1076
1077 """
1078 r = range(1, len(seq) + 1)
1079 if reverse:
1080 r = reversed(r)
1081 return (
1082 (seq[i : i + L], i, i + L) for L in r for i in range(len(seq) - L + 1)
1083 )
1084
1085
1086class bucket:
1087 """Wrap *iterable* and return an object that buckets the iterable into
1088 child iterables based on a *key* function.
1089
1090 >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
1091 >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character
1092 >>> sorted(list(s)) # Get the keys
1093 ['a', 'b', 'c']
1094 >>> a_iterable = s['a']
1095 >>> next(a_iterable)
1096 'a1'
1097 >>> next(a_iterable)
1098 'a2'
1099 >>> list(s['b'])
1100 ['b1', 'b2', 'b3']
1101
1102 The original iterable will be advanced and its items will be cached until
1103 they are used by the child iterables. This may require significant storage.
1104
1105 By default, attempting to select a bucket to which no items belong will
1106 exhaust the iterable and cache all values.
1107 If you specify a *validator* function, selected buckets will instead be
1108 checked against it.
1109
1110 >>> from itertools import count
1111 >>> it = count(1, 2) # Infinite sequence of odd numbers
1112 >>> key = lambda x: x % 10 # Bucket by last digit
1113 >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only
1114 >>> s = bucket(it, key=key, validator=validator)
1115 >>> 2 in s
1116 False
1117 >>> list(s[2])
1118 []
1119
1120 """
1121
1122 def __init__(self, iterable, key, validator=None):
1123 self._it = iter(iterable)
1124 self._key = key
1125 self._cache = defaultdict(deque)
1126 self._validator = validator or (lambda x: True)
1127
1128 def __contains__(self, value):
1129 if not self._validator(value):
1130 return False
1131
1132 try:
1133 item = next(self[value])
1134 except StopIteration:
1135 return False
1136 else:
1137 self._cache[value].appendleft(item)
1138
1139 return True
1140
1141 def _get_values(self, value):
1142 """
1143 Helper to yield items from the parent iterator that match *value*.
1144 Items that don't match are stored in the local cache as they
1145 are encountered.
1146 """
1147 while True:
1148 # If we've cached some items that match the target value, emit
1149 # the first one and evict it from the cache.
1150 if self._cache[value]:
1151 yield self._cache[value].popleft()
1152 # Otherwise we need to advance the parent iterator to search for
1153 # a matching item, caching the rest.
1154 else:
1155 while True:
1156 try:
1157 item = next(self._it)
1158 except StopIteration:
1159 return
1160 item_value = self._key(item)
1161 if item_value == value:
1162 yield item
1163 break
1164 elif self._validator(item_value):
1165 self._cache[item_value].append(item)
1166
1167 def __iter__(self):
1168 for item in self._it:
1169 item_value = self._key(item)
1170 if self._validator(item_value):
1171 self._cache[item_value].append(item)
1172
1173 return iter(self._cache)
1174
1175 def __getitem__(self, value):
1176 if not self._validator(value):
1177 return iter(())
1178
1179 return self._get_values(value)
1180
1181
1182def spy(iterable, n=1):
1183 """Return a 2-tuple with a list containing the first *n* elements of
1184 *iterable*, and an iterator with the same items as *iterable*.
1185 This allows you to "look ahead" at the items in the iterable without
1186 advancing it.
1187
1188 There is one item in the list by default:
1189
1190 >>> iterable = 'abcdefg'
1191 >>> head, iterable = spy(iterable)
1192 >>> head
1193 ['a']
1194 >>> list(iterable)
1195 ['a', 'b', 'c', 'd', 'e', 'f', 'g']
1196
1197 You may use unpacking to retrieve items instead of lists:
1198
1199 >>> (head,), iterable = spy('abcdefg')
1200 >>> head
1201 'a'
1202 >>> (first, second), iterable = spy('abcdefg', 2)
1203 >>> first
1204 'a'
1205 >>> second
1206 'b'
1207
1208 The number of items requested can be larger than the number of items in
1209 the iterable:
1210
1211 >>> iterable = [1, 2, 3, 4, 5]
1212 >>> head, iterable = spy(iterable, 10)
1213 >>> head
1214 [1, 2, 3, 4, 5]
1215 >>> list(iterable)
1216 [1, 2, 3, 4, 5]
1217
1218 """
1219 p, q = tee(iterable)
1220 return take(n, q), p
1221
1222
1223def interleave(*iterables):
1224 """Return a new iterable yielding from each iterable in turn,
1225 until the shortest is exhausted.
1226
1227 >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
1228 [1, 4, 6, 2, 5, 7]
1229
1230 For a version that doesn't terminate after the shortest iterable is
1231 exhausted, see :func:`interleave_longest`.
1232
1233 """
1234 return chain.from_iterable(zip(*iterables))
1235
1236
1237def interleave_longest(*iterables):
1238 """Return a new iterable yielding from each iterable in turn,
1239 skipping any that are exhausted.
1240
1241 >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
1242 [1, 4, 6, 2, 5, 7, 3, 8]
1243
1244 This function produces the same output as :func:`roundrobin`, but may
1245 perform better for some inputs (in particular when the number of iterables
1246 is large).
1247
1248 """
1249 for xs in zip_longest(*iterables, fillvalue=_marker):
1250 for x in xs:
1251 if x is not _marker:
1252 yield x
1253
1254
1255def interleave_evenly(iterables, lengths=None):
1256 """
1257 Interleave multiple iterables so that their elements are evenly distributed
1258 throughout the output sequence.
1259
1260 >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
1261 >>> list(interleave_evenly(iterables))
1262 [1, 2, 'a', 3, 4, 'b', 5]
1263
1264 >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
1265 >>> list(interleave_evenly(iterables))
1266 [1, 6, 4, 2, 7, 3, 8, 5]
1267
1268 This function requires iterables of known length. Iterables without
1269 ``__len__()`` can be used by manually specifying lengths with *lengths*:
1270
1271 >>> from itertools import combinations, repeat
1272 >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
1273 >>> lengths = [4 * (4 - 1) // 2, 3]
1274 >>> list(interleave_evenly(iterables, lengths=lengths))
1275 [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']
1276
1277 Based on Bresenham's algorithm.
1278 """
1279 if lengths is None:
1280 try:
1281 lengths = [len(it) for it in iterables]
1282 except TypeError:
1283 raise ValueError(
1284 'Iterable lengths could not be determined automatically. '
1285 'Specify them with the lengths keyword.'
1286 )
1287 elif len(iterables) != len(lengths):
1288 raise ValueError('Mismatching number of iterables and lengths.')
1289
1290 dims = len(lengths)
1291
1292 # sort iterables by length, descending
1293 lengths_permute = sorted(
1294 range(dims), key=lambda i: lengths[i], reverse=True
1295 )
1296 lengths_desc = [lengths[i] for i in lengths_permute]
1297 iters_desc = [iter(iterables[i]) for i in lengths_permute]
1298
1299 # the longest iterable is the primary one (Bresenham: the longest
1300 # distance along an axis)
1301 delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
1302 iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
1303 errors = [delta_primary // dims] * len(deltas_secondary)
1304
1305 to_yield = sum(lengths)
1306 while to_yield:
1307 yield next(iter_primary)
1308 to_yield -= 1
1309 # update errors for each secondary iterable
1310 errors = [e - delta for e, delta in zip(errors, deltas_secondary)]
1311
1312 # those iterables for which the error is negative are yielded
1313 # ("diagonal step" in Bresenham)
1314 for i, e_ in enumerate(errors):
1315 if e_ < 0:
1316 yield next(iters_secondary[i])
1317 to_yield -= 1
1318 errors[i] += delta_primary
1319
1320
1321def interleave_randomly(*iterables):
1322 """Return a new iterable randomly selecting from each iterable,
1323 until all iterables are exhausted.
1324
1325 The relative order of the elements in each iterable is preserved,
1326 but the order with respect to other iterables is randomized.
1327
1328 >>> iterables = [1, 2, 3], 'abc', (True, False, None)
1329 >>> list(interleave_randomly(*iterables)) # doctest: +SKIP
1330 ['a', 'b', 1, 'c', True, False, None, 2, 3]
1331 """
1332 iterators = [iter(e) for e in iterables]
1333 while iterators:
1334 idx = randrange(len(iterators))
1335 try:
1336 yield next(iterators[idx])
1337 except StopIteration:
1338 # equivalent to `list.pop` but slightly faster
1339 iterators[idx] = iterators[-1]
1340 del iterators[-1]
1341
1342
1343def collapse(iterable, base_type=None, levels=None):
1344 """Flatten an iterable with multiple levels of nesting (e.g., a list of
1345 lists of tuples) into non-iterable types.
1346
1347 >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
1348 >>> list(collapse(iterable))
1349 [1, 2, 3, 4, 5, 6]
1350
1351 Binary and text strings are not considered iterable and
1352 will not be collapsed.
1353
1354 To avoid collapsing other types, specify *base_type*:
1355
1356 >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
1357 >>> list(collapse(iterable, base_type=tuple))
1358 ['ab', ('cd', 'ef'), 'gh', 'ij']
1359
1360 Specify *levels* to stop flattening after a certain level:
1361
1362 >>> iterable = [('a', ['b']), ('c', ['d'])]
1363 >>> list(collapse(iterable)) # Fully flattened
1364 ['a', 'b', 'c', 'd']
1365 >>> list(collapse(iterable, levels=1)) # Only one level flattened
1366 ['a', ['b'], 'c', ['d']]
1367
1368 """
1369 stack = deque()
1370 # Add our first node group, treat the iterable as a single node
1371 stack.appendleft((0, repeat(iterable, 1)))
1372
1373 while stack:
1374 node_group = stack.popleft()
1375 level, nodes = node_group
1376
1377 # Check if beyond max level
1378 if levels is not None and level > levels:
1379 yield from nodes
1380 continue
1381
1382 for node in nodes:
1383 # Check if done iterating
1384 if isinstance(node, (str, bytes)) or (
1385 (base_type is not None) and isinstance(node, base_type)
1386 ):
1387 yield node
1388 # Otherwise try to create child nodes
1389 else:
1390 try:
1391 tree = iter(node)
1392 except TypeError:
1393 yield node
1394 else:
1395 # Save our current location
1396 stack.appendleft(node_group)
1397 # Append the new child node
1398 stack.appendleft((level + 1, tree))
1399 # Break to process child node
1400 break
1401
1402
1403def side_effect(func, iterable, chunk_size=None, before=None, after=None):
1404 """Invoke *func* on each item in *iterable* (or on each *chunk_size* group
1405 of items) before yielding the item.
1406
1407 `func` must be a function that takes a single argument. Its return value
1408 will be discarded.
1409
1410 *before* and *after* are optional functions that take no arguments. They
1411 will be executed before iteration starts and after it ends, respectively.
1412
1413 `side_effect` can be used for logging, updating progress bars, or anything
1414 that is not functionally "pure."
1415
1416 Emitting a status message:
1417
1418 >>> from more_itertools import consume
1419 >>> func = lambda item: print('Received {}'.format(item))
1420 >>> consume(side_effect(func, range(2)))
1421 Received 0
1422 Received 1
1423
1424 Operating on chunks of items:
1425
1426 >>> pair_sums = []
1427 >>> func = lambda chunk: pair_sums.append(sum(chunk))
1428 >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
1429 [0, 1, 2, 3, 4, 5]
1430 >>> list(pair_sums)
1431 [1, 5, 9]
1432
1433 Writing to a file-like object:
1434
1435 >>> from io import StringIO
1436 >>> from more_itertools import consume
1437 >>> f = StringIO()
1438 >>> func = lambda x: print(x, file=f)
1439 >>> before = lambda: print(u'HEADER', file=f)
1440 >>> after = f.close
1441 >>> it = [u'a', u'b', u'c']
1442 >>> consume(side_effect(func, it, before=before, after=after))
1443 >>> f.closed
1444 True
1445
1446 """
1447 try:
1448 if before is not None:
1449 before()
1450
1451 if chunk_size is None:
1452 for item in iterable:
1453 func(item)
1454 yield item
1455 else:
1456 for chunk in chunked(iterable, chunk_size):
1457 func(chunk)
1458 yield from chunk
1459 finally:
1460 if after is not None:
1461 after()
1462
1463
1464def sliced(seq, n, strict=False):
1465 """Yield slices of length *n* from the sequence *seq*.
1466
1467 >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
1468 [(1, 2, 3), (4, 5, 6)]
1469
1470 By the default, the last yielded slice will have fewer than *n* elements
1471 if the length of *seq* is not divisible by *n*:
1472
1473 >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
1474 [(1, 2, 3), (4, 5, 6), (7, 8)]
1475
1476 If the length of *seq* is not divisible by *n* and *strict* is
1477 ``True``, then ``ValueError`` will be raised before the last
1478 slice is yielded.
1479
1480 This function will only work for iterables that support slicing.
1481 For non-sliceable iterables, see :func:`chunked`.
1482
1483 """
1484 iterator = takewhile(len, (seq[i : i + n] for i in count(0, n)))
1485 if strict:
1486
1487 def ret():
1488 for _slice in iterator:
1489 if len(_slice) != n:
1490 raise ValueError("seq is not divisible by n.")
1491 yield _slice
1492
1493 return ret()
1494 else:
1495 return iterator
1496
1497
1498def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
1499 """Yield lists of items from *iterable*, where each list is delimited by
1500 an item where callable *pred* returns ``True``.
1501
1502 >>> list(split_at('abcdcba', lambda x: x == 'b'))
1503 [['a'], ['c', 'd', 'c'], ['a']]
1504
1505 >>> list(split_at(range(10), lambda n: n % 2 == 1))
1506 [[0], [2], [4], [6], [8], []]
1507
1508 At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
1509 then there is no limit on the number of splits:
1510
1511 >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
1512 [[0], [2], [4, 5, 6, 7, 8, 9]]
1513
1514 By default, the delimiting items are not included in the output.
1515 To include them, set *keep_separator* to ``True``.
1516
1517 >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
1518 [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]
1519
1520 """
1521 if maxsplit == 0:
1522 yield list(iterable)
1523 return
1524
1525 buf = []
1526 it = iter(iterable)
1527 for item in it:
1528 if pred(item):
1529 yield buf
1530 if keep_separator:
1531 yield [item]
1532 if maxsplit == 1:
1533 yield list(it)
1534 return
1535 buf = []
1536 maxsplit -= 1
1537 else:
1538 buf.append(item)
1539 yield buf
1540
1541
1542def split_before(iterable, pred, maxsplit=-1):
1543 """Yield lists of items from *iterable*, where each list ends just before
1544 an item for which callable *pred* returns ``True``:
1545
1546 >>> list(split_before('OneTwo', lambda s: s.isupper()))
1547 [['O', 'n', 'e'], ['T', 'w', 'o']]
1548
1549 >>> list(split_before(range(10), lambda n: n % 3 == 0))
1550 [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
1551
1552 At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
1553 then there is no limit on the number of splits:
1554
1555 >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
1556 [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
1557 """
1558 if maxsplit == 0:
1559 yield list(iterable)
1560 return
1561
1562 buf = []
1563 it = iter(iterable)
1564 for item in it:
1565 if pred(item) and buf:
1566 yield buf
1567 if maxsplit == 1:
1568 yield [item, *it]
1569 return
1570 buf = []
1571 maxsplit -= 1
1572 buf.append(item)
1573 if buf:
1574 yield buf
1575
1576
1577def split_after(iterable, pred, maxsplit=-1):
1578 """Yield lists of items from *iterable*, where each list ends with an
1579 item where callable *pred* returns ``True``:
1580
1581 >>> list(split_after('one1two2', lambda s: s.isdigit()))
1582 [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]
1583
1584 >>> list(split_after(range(10), lambda n: n % 3 == 0))
1585 [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]
1586
1587 At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
1588 then there is no limit on the number of splits:
1589
1590 >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
1591 [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]
1592
1593 """
1594 if maxsplit == 0:
1595 yield list(iterable)
1596 return
1597
1598 buf = []
1599 it = iter(iterable)
1600 for item in it:
1601 buf.append(item)
1602 if pred(item) and buf:
1603 yield buf
1604 if maxsplit == 1:
1605 buf = list(it)
1606 if buf:
1607 yield buf
1608 return
1609 buf = []
1610 maxsplit -= 1
1611 if buf:
1612 yield buf
1613
1614
1615def split_when(iterable, pred, maxsplit=-1):
1616 """Split *iterable* into pieces based on the output of *pred*.
1617 *pred* should be a function that takes successive pairs of items and
1618 returns ``True`` if the iterable should be split in between them.
1619
1620 For example, to find runs of increasing numbers, split the iterable when
1621 element ``i`` is larger than element ``i + 1``:
1622
1623 >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
1624 [[1, 2, 3, 3], [2, 5], [2, 4], [2]]
1625
1626 At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
1627 then there is no limit on the number of splits:
1628
1629 >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
1630 ... lambda x, y: x > y, maxsplit=2))
1631 [[1, 2, 3, 3], [2, 5], [2, 4, 2]]
1632
1633 """
1634 if maxsplit == 0:
1635 yield list(iterable)
1636 return
1637
1638 it = iter(iterable)
1639 try:
1640 cur_item = next(it)
1641 except StopIteration:
1642 return
1643
1644 buf = [cur_item]
1645 for next_item in it:
1646 if pred(cur_item, next_item):
1647 yield buf
1648 if maxsplit == 1:
1649 yield [next_item, *it]
1650 return
1651 buf = []
1652 maxsplit -= 1
1653
1654 buf.append(next_item)
1655 cur_item = next_item
1656
1657 yield buf
1658
1659
1660def split_into(iterable, sizes):
1661 """Yield a list of sequential items from *iterable* of length 'n' for each
1662 integer 'n' in *sizes*.
1663
1664 >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
1665 [[1], [2, 3], [4, 5, 6]]
1666
1667 If the sum of *sizes* is smaller than the length of *iterable*, then the
1668 remaining items of *iterable* will not be returned.
1669
1670 >>> list(split_into([1,2,3,4,5,6], [2,3]))
1671 [[1, 2], [3, 4, 5]]
1672
1673 If the sum of *sizes* is larger than the length of *iterable*, fewer items
1674 will be returned in the iteration that overruns the *iterable* and further
1675 lists will be empty:
1676
1677 >>> list(split_into([1,2,3,4], [1,2,3,4]))
1678 [[1], [2, 3], [4], []]
1679
1680 When a ``None`` object is encountered in *sizes*, the returned list will
1681 contain items up to the end of *iterable* the same way that
1682 :func:`itertools.slice` does:
1683
1684 >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
1685 [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]
1686
1687 :func:`split_into` can be useful for grouping a series of items where the
1688 sizes of the groups are not uniform. An example would be where in a row
1689 from a table, multiple columns represent elements of the same feature
1690 (e.g. a point represented by x,y,z) but, the format is not the same for
1691 all columns.
1692 """
1693 # convert the iterable argument into an iterator so its contents can
1694 # be consumed by islice in case it is a generator
1695 it = iter(iterable)
1696
1697 for size in sizes:
1698 if size is None:
1699 yield list(it)
1700 return
1701 else:
1702 yield list(islice(it, size))
1703
1704
1705def padded(iterable, fillvalue=None, n=None, next_multiple=False):
1706 """Yield the elements from *iterable*, followed by *fillvalue*, such that
1707 at least *n* items are emitted.
1708
1709 >>> list(padded([1, 2, 3], '?', 5))
1710 [1, 2, 3, '?', '?']
1711
1712 If *next_multiple* is ``True``, *fillvalue* will be emitted until the
1713 number of items emitted is a multiple of *n*:
1714
1715 >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
1716 [1, 2, 3, 4, None, None]
1717
1718 If *n* is ``None``, *fillvalue* will be emitted indefinitely.
1719
1720 To create an *iterable* of exactly size *n*, you can truncate with
1721 :func:`islice`.
1722
1723 >>> list(islice(padded([1, 2, 3], '?'), 5))
1724 [1, 2, 3, '?', '?']
1725 >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
1726 [1, 2, 3, 4, 5]
1727
1728 """
1729 iterator = iter(iterable)
1730 iterator_with_repeat = chain(iterator, repeat(fillvalue))
1731
1732 if n is None:
1733 return iterator_with_repeat
1734 elif n < 1:
1735 raise ValueError('n must be at least 1')
1736 elif next_multiple:
1737
1738 def slice_generator():
1739 for first in iterator:
1740 yield (first,)
1741 yield islice(iterator_with_repeat, n - 1)
1742
1743 # While elements exist produce slices of size n
1744 return chain.from_iterable(slice_generator())
1745 else:
1746 # Ensure the first batch is at least size n then iterate
1747 return chain(islice(iterator_with_repeat, n), iterator)
1748
1749
1750def repeat_each(iterable, n=2):
1751 """Repeat each element in *iterable* *n* times.
1752
1753 >>> list(repeat_each('ABC', 3))
1754 ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
1755 """
1756 return chain.from_iterable(map(repeat, iterable, repeat(n)))
1757
1758
1759def repeat_last(iterable, default=None):
1760 """After the *iterable* is exhausted, keep yielding its last element.
1761
1762 >>> list(islice(repeat_last(range(3)), 5))
1763 [0, 1, 2, 2, 2]
1764
1765 If the iterable is empty, yield *default* forever::
1766
1767 >>> list(islice(repeat_last(range(0), 42), 5))
1768 [42, 42, 42, 42, 42]
1769
1770 """
1771 item = _marker
1772 for item in iterable:
1773 yield item
1774 final = default if item is _marker else item
1775 yield from repeat(final)
1776
1777
1778def distribute(n, iterable):
1779 """Distribute the items from *iterable* among *n* smaller iterables.
1780
1781 >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
1782 >>> list(group_1)
1783 [1, 3, 5]
1784 >>> list(group_2)
1785 [2, 4, 6]
1786
1787 If the length of *iterable* is not evenly divisible by *n*, then the
1788 length of the returned iterables will not be identical:
1789
1790 >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
1791 >>> [list(c) for c in children]
1792 [[1, 4, 7], [2, 5], [3, 6]]
1793
1794 If the length of *iterable* is smaller than *n*, then the last returned
1795 iterables will be empty:
1796
1797 >>> children = distribute(5, [1, 2, 3])
1798 >>> [list(c) for c in children]
1799 [[1], [2], [3], [], []]
1800
1801 This function uses :func:`itertools.tee` and may require significant
1802 storage.
1803
1804 If you need the order items in the smaller iterables to match the
1805 original iterable, see :func:`divide`.
1806
1807 """
1808 if n < 1:
1809 raise ValueError('n must be at least 1')
1810
1811 children = tee(iterable, n)
1812 return [islice(it, index, None, n) for index, it in enumerate(children)]
1813
1814
1815def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
1816 """Yield tuples whose elements are offset from *iterable*.
1817 The amount by which the `i`-th item in each tuple is offset is given by
1818 the `i`-th item in *offsets*.
1819
1820 >>> list(stagger([0, 1, 2, 3]))
1821 [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
1822 >>> list(stagger(range(8), offsets=(0, 2, 4)))
1823 [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]
1824
1825 By default, the sequence will end when the final element of a tuple is the
1826 last item in the iterable. To continue until the first element of a tuple
1827 is the last item in the iterable, set *longest* to ``True``::
1828
1829 >>> list(stagger([0, 1, 2, 3], longest=True))
1830 [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]
1831
1832 By default, ``None`` will be used to replace offsets beyond the end of the
1833 sequence. Specify *fillvalue* to use some other value.
1834
1835 """
1836 children = tee(iterable, len(offsets))
1837
1838 return zip_offset(
1839 *children, offsets=offsets, longest=longest, fillvalue=fillvalue
1840 )
1841
1842
1843def zip_equal(*iterables):
1844 """``zip`` the input *iterables* together but raise
1845 ``UnequalIterablesError`` if they aren't all the same length.
1846
1847 >>> it_1 = range(3)
1848 >>> it_2 = iter('abc')
1849 >>> list(zip_equal(it_1, it_2))
1850 [(0, 'a'), (1, 'b'), (2, 'c')]
1851
1852 >>> it_1 = range(3)
1853 >>> it_2 = iter('abcd')
1854 >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
1855 Traceback (most recent call last):
1856 ...
1857 more_itertools.more.UnequalIterablesError: Iterables have different
1858 lengths
1859
1860 """
1861 if hexversion >= 0x30A00A6:
1862 warnings.warn(
1863 (
1864 'zip_equal will be removed in a future version of '
1865 'more-itertools. Use the builtin zip function with '
1866 'strict=True instead.'
1867 ),
1868 DeprecationWarning,
1869 )
1870
1871 return _zip_equal(*iterables)
1872
1873
1874def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
1875 """``zip`` the input *iterables* together, but offset the `i`-th iterable
1876 by the `i`-th item in *offsets*.
1877
1878 >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
1879 [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]
1880
1881 This can be used as a lightweight alternative to SciPy or pandas to analyze
1882 data sets in which some series have a lead or lag relationship.
1883
1884 By default, the sequence will end when the shortest iterable is exhausted.
1885 To continue until the longest iterable is exhausted, set *longest* to
1886 ``True``.
1887
1888 >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
1889 [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]
1890
1891 By default, ``None`` will be used to replace offsets beyond the end of the
1892 sequence. Specify *fillvalue* to use some other value.
1893
1894 """
1895 if len(iterables) != len(offsets):
1896 raise ValueError("Number of iterables and offsets didn't match")
1897
1898 staggered = []
1899 for it, n in zip(iterables, offsets):
1900 if n < 0:
1901 staggered.append(chain(repeat(fillvalue, -n), it))
1902 elif n > 0:
1903 staggered.append(islice(it, n, None))
1904 else:
1905 staggered.append(it)
1906
1907 if longest:
1908 return zip_longest(*staggered, fillvalue=fillvalue)
1909
1910 return zip(*staggered)
1911
1912
1913def sort_together(
1914 iterables, key_list=(0,), key=None, reverse=False, strict=False
1915):
1916 """Return the input iterables sorted together, with *key_list* as the
1917 priority for sorting. All iterables are trimmed to the length of the
1918 shortest one.
1919
1920 This can be used like the sorting function in a spreadsheet. If each
1921 iterable represents a column of data, the key list determines which
1922 columns are used for sorting.
1923
1924 By default, all iterables are sorted using the ``0``-th iterable::
1925
1926 >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
1927 >>> sort_together(iterables)
1928 [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
1929
1930 Set a different key list to sort according to another iterable.
1931 Specifying multiple keys dictates how ties are broken::
1932
1933 >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
1934 >>> sort_together(iterables, key_list=(1, 2))
1935 [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
1936
1937 To sort by a function of the elements of the iterable, pass a *key*
1938 function. Its arguments are the elements of the iterables corresponding to
1939 the key list::
1940
1941 >>> names = ('a', 'b', 'c')
1942 >>> lengths = (1, 2, 3)
1943 >>> widths = (5, 2, 1)
1944 >>> def area(length, width):
1945 ... return length * width
1946 >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
1947 [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
1948
1949 Set *reverse* to ``True`` to sort in descending order.
1950
1951 >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
1952 [(3, 2, 1), ('a', 'b', 'c')]
1953
1954 If the *strict* keyword argument is ``True``, then
1955 ``UnequalIterablesError`` will be raised if any of the iterables have
1956 different lengths.
1957
1958 """
1959 if key is None:
1960 # if there is no key function, the key argument to sorted is an
1961 # itemgetter
1962 key_argument = itemgetter(*key_list)
1963 else:
        # if there is a key function, call it with the items at the offsets
        # specified by the key list as arguments
1966 key_list = list(key_list)
1967 if len(key_list) == 1:
1968 # if key_list contains a single item, pass the item at that offset
1969 # as the only argument to the key function
1970 key_offset = key_list[0]
1971 key_argument = lambda zipped_items: key(zipped_items[key_offset])
1972 else:
1973 # if key_list contains multiple items, use itemgetter to return a
1974 # tuple of items, which we pass as *args to the key function
1975 get_key_items = itemgetter(*key_list)
1976 key_argument = lambda zipped_items: key(
1977 *get_key_items(zipped_items)
1978 )
1979
1980 zipper = zip_equal if strict else zip
1981 return list(
1982 zipper(*sorted(zipper(*iterables), key=key_argument, reverse=reverse))
1983 )
1984
1985
1986def unzip(iterable):
1987 """The inverse of :func:`zip`, this function disaggregates the elements
1988 of the zipped *iterable*.
1989
    The ``i``-th output iterable contains the ``i``-th element from each
    element of the zipped *iterable*. The first element of *iterable* is
    used to determine the number of output iterables.
1993
1994 >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
1995 >>> letters, numbers = unzip(iterable)
1996 >>> list(letters)
1997 ['a', 'b', 'c', 'd']
1998 >>> list(numbers)
1999 [1, 2, 3, 4]
2000
2001 This is similar to using ``zip(*iterable)``, but it avoids reading
2002 *iterable* into memory. Note, however, that this function uses
2003 :func:`itertools.tee` and thus may require significant storage.
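
    If the elements of *iterable* have unequal lengths, each output
    iterable stops at the first position it cannot fill:

    >>> iterable = iter([(1, 2, 3), (4, 5), (6,)])
    >>> first, second, third = unzip(iterable)
    >>> list(first)
    [1, 4, 6]
    >>> list(second)
    [2, 5]
    >>> list(third)
    [3]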
2004
2005 """
2006 head, iterable = spy(iterable)
2007 if not head:
2008 # empty iterable, e.g. zip([], [], [])
2009 return ()
2010 # spy returns a one-length iterable as head
2011 head = head[0]
2012 iterables = tee(iterable, len(head))
2013
2014 # If we have an iterable like iter([(1, 2, 3), (4, 5), (6,)]),
2015 # the second unzipped iterable fails at the third tuple since
2016 # it tries to access (6,)[1].
2017 # Same with the third unzipped iterable and the second tuple.
2018 # To support these "improperly zipped" iterables, we suppress
2019 # the IndexError, which just stops the unzipped iterables at
2020 # first length mismatch.
2021 return tuple(
2022 iter_suppress(map(itemgetter(i), it), IndexError)
2023 for i, it in enumerate(iterables)
2024 )
2025
2026
2027def divide(n, iterable):
2028 """Divide the elements from *iterable* into *n* parts, maintaining
2029 order.
2030
2031 >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
2032 >>> list(group_1)
2033 [1, 2, 3]
2034 >>> list(group_2)
2035 [4, 5, 6]
2036
2037 If the length of *iterable* is not evenly divisible by *n*, then the
2038 length of the returned iterables will not be identical:
2039
2040 >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
2041 >>> [list(c) for c in children]
2042 [[1, 2, 3], [4, 5], [6, 7]]
2043
2044 If the length of the iterable is smaller than n, then the last returned
2045 iterables will be empty:
2046
2047 >>> children = divide(5, [1, 2, 3])
2048 >>> [list(c) for c in children]
2049 [[1], [2], [3], [], []]
2050
2051 This function will exhaust the iterable before returning.
2052 If order is not important, see :func:`distribute`, which does not first
2053 pull the iterable into memory.
2054
2055 """
2056 if n < 1:
2057 raise ValueError('n must be at least 1')
2058
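    # Sequences can be sliced directly; any other iterable is pulled into
    # a tuple first so that its length is known and it can be sliced.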
2059 try:
2060 iterable[:0]
2061 except TypeError:
2062 seq = tuple(iterable)
2063 else:
2064 seq = iterable
2065
2066 q, r = divmod(len(seq), n)
2067
2068 ret = []
2069 stop = 0
2070 for i in range(1, n + 1):
2071 start = stop
2072 stop += q + 1 if i <= r else q
2073 ret.append(iter(seq[start:stop]))
2074
2075 return ret
2076
2077
2078def always_iterable(obj, base_type=(str, bytes)):
2079 """If *obj* is iterable, return an iterator over its items::
2080
2081 >>> obj = (1, 2, 3)
2082 >>> list(always_iterable(obj))
2083 [1, 2, 3]
2084
2085 If *obj* is not iterable, return a one-item iterable containing *obj*::
2086
2087 >>> obj = 1
2088 >>> list(always_iterable(obj))
2089 [1]
2090
2091 If *obj* is ``None``, return an empty iterable:
2092
2093 >>> obj = None
2094 >>> list(always_iterable(None))
2095 []
2096
2097 By default, binary and text strings are not considered iterable::
2098
2099 >>> obj = 'foo'
2100 >>> list(always_iterable(obj))
2101 ['foo']
2102
2103 If *base_type* is set, objects for which ``isinstance(obj, base_type)``
2104 returns ``True`` won't be considered iterable.
2105
2106 >>> obj = {'a': 1}
2107 >>> list(always_iterable(obj)) # Iterate over the dict's keys
2108 ['a']
2109 >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit
2110 [{'a': 1}]
2111
    Set *base_type* to ``None`` to avoid any special handling and treat
    anything Python considers iterable as iterable:
2114
2115 >>> obj = 'foo'
2116 >>> list(always_iterable(obj, base_type=None))
2117 ['f', 'o', 'o']
2118 """
2119 if obj is None:
2120 return iter(())
2121
2122 if (base_type is not None) and isinstance(obj, base_type):
2123 return iter((obj,))
2124
2125 try:
2126 return iter(obj)
2127 except TypeError:
2128 return iter((obj,))
2129
2130
2131def adjacent(predicate, iterable, distance=1):
2132 """Return an iterable over `(bool, item)` tuples where the `item` is
2133 drawn from *iterable* and the `bool` indicates whether
2134 that item satisfies the *predicate* or is adjacent to an item that does.
2135
2136 For example, to find whether items are adjacent to a ``3``::
2137
2138 >>> list(adjacent(lambda x: x == 3, range(6)))
2139 [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]
2140
2141 Set *distance* to change what counts as adjacent. For example, to find
2142 whether items are two places away from a ``3``:
2143
2144 >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
2145 [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]
2146
2147 This is useful for contextualizing the results of a search function.
2148 For example, a code comparison tool might want to identify lines that
2149 have changed, but also surrounding lines to give the viewer of the diff
2150 context.
2151
2152 The predicate function will only be called once for each item in the
2153 iterable.
2154
2155 See also :func:`groupby_transform`, which can be used with this function
2156 to group ranges of items with the same `bool` value.
2157
2158 """
2159 # Allow distance=0 mainly for testing that it reproduces results with map()
2160 if distance < 0:
2161 raise ValueError('distance must be at least 0')
2162
2163 i1, i2 = tee(iterable)
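    # Pad the predicate results with `distance` False values on each side
    # so that every item has a full window of 2 * distance + 1 results.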
2164 padding = [False] * distance
2165 selected = chain(padding, map(predicate, i1), padding)
2166 adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1))
2167 return zip(adjacent_to_selected, i2)
2168
2169
2170def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
2171 """An extension of :func:`itertools.groupby` that can apply transformations
2172 to the grouped data.
2173
2174 * *keyfunc* is a function computing a key value for each item in *iterable*
2175 * *valuefunc* is a function that transforms the individual items from
2176 *iterable* after grouping
2177 * *reducefunc* is a function that transforms each group of items
2178
2179 >>> iterable = 'aAAbBBcCC'
2180 >>> keyfunc = lambda k: k.upper()
2181 >>> valuefunc = lambda v: v.lower()
2182 >>> reducefunc = lambda g: ''.join(g)
2183 >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
2184 [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]
2185
2186 Each optional argument defaults to an identity function if not specified.
2187
2188 :func:`groupby_transform` is useful when grouping elements of an iterable
2189 using a separate iterable as the key. To do this, :func:`zip` the iterables
2190 and pass a *keyfunc* that extracts the first element and a *valuefunc*
2191 that extracts the second element::
2192
2193 >>> from operator import itemgetter
2194 >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
2195 >>> values = 'abcdefghi'
2196 >>> iterable = zip(keys, values)
2197 >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
2198 >>> [(k, ''.join(g)) for k, g in grouper]
2199 [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]
2200
2201 Note that the order of items in the iterable is significant.
2202 Only adjacent items are grouped together, so if you don't want any
2203 duplicate groups, you should sort the iterable by the key function.
2204
2205 """
2206 ret = groupby(iterable, keyfunc)
2207 if valuefunc:
2208 ret = ((k, map(valuefunc, g)) for k, g in ret)
2209 if reducefunc:
2210 ret = ((k, reducefunc(g)) for k, g in ret)
2211
2212 return ret
2213
2214
2215class numeric_range(abc.Sequence, abc.Hashable):
2216 """An extension of the built-in ``range()`` function whose arguments can
2217 be any orderable numeric type.
2218
2219 With only *stop* specified, *start* defaults to ``0`` and *step*
2220 defaults to ``1``. The output items will match the type of *stop*:
2221
2222 >>> list(numeric_range(3.5))
2223 [0.0, 1.0, 2.0, 3.0]
2224
2225 With only *start* and *stop* specified, *step* defaults to ``1``. The
2226 output items will match the type of *start*:
2227
2228 >>> from decimal import Decimal
2229 >>> start = Decimal('2.1')
2230 >>> stop = Decimal('5.1')
2231 >>> list(numeric_range(start, stop))
2232 [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]
2233
2234 With *start*, *stop*, and *step* specified the output items will match
2235 the type of ``start + step``:
2236
2237 >>> from fractions import Fraction
2238 >>> start = Fraction(1, 2) # Start at 1/2
2239 >>> stop = Fraction(5, 2) # End at 5/2
2240 >>> step = Fraction(1, 2) # Count by 1/2
2241 >>> list(numeric_range(start, stop, step))
2242 [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]
2243
2244 If *step* is zero, ``ValueError`` is raised. Negative steps are supported:
2245
2246 >>> list(numeric_range(3, -1, -1.0))
2247 [3.0, 2.0, 1.0, 0.0]
2248
2249 Be aware of the limitations of floating-point numbers; the representation
2250 of the yielded numbers may be surprising.
2251
2252 ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
2253 is a ``datetime.timedelta`` object:
2254
2255 >>> import datetime
2256 >>> start = datetime.datetime(2019, 1, 1)
2257 >>> stop = datetime.datetime(2019, 1, 3)
2258 >>> step = datetime.timedelta(days=1)
2259 >>> items = iter(numeric_range(start, stop, step))
2260 >>> next(items)
2261 datetime.datetime(2019, 1, 1, 0, 0)
2262 >>> next(items)
2263 datetime.datetime(2019, 1, 2, 0, 0)
2264
2265 """
2266
2267 _EMPTY_HASH = hash(range(0, 0))
2268
2269 def __init__(self, *args):
2270 argc = len(args)
2271 if argc == 1:
2272 (self._stop,) = args
2273 self._start = type(self._stop)(0)
2274 self._step = type(self._stop - self._start)(1)
2275 elif argc == 2:
2276 self._start, self._stop = args
2277 self._step = type(self._stop - self._start)(1)
2278 elif argc == 3:
2279 self._start, self._stop, self._step = args
2280 elif argc == 0:
2281 raise TypeError(
2282 f'numeric_range expected at least 1 argument, got {argc}'
2283 )
2284 else:
2285 raise TypeError(
2286 f'numeric_range expected at most 3 arguments, got {argc}'
2287 )
2288
2289 self._zero = type(self._step)(0)
2290 if self._step == self._zero:
2291 raise ValueError('numeric_range() arg 3 must not be zero')
2292 self._growing = self._step > self._zero
2293
2294 def __bool__(self):
2295 if self._growing:
2296 return self._start < self._stop
2297 else:
2298 return self._start > self._stop
2299
2300 def __contains__(self, elem):
2301 if self._growing:
2302 if self._start <= elem < self._stop:
2303 return (elem - self._start) % self._step == self._zero
2304 else:
2305 if self._start >= elem > self._stop:
2306 return (self._start - elem) % (-self._step) == self._zero
2307
2308 return False
2309
2310 def __eq__(self, other):
2311 if isinstance(other, numeric_range):
2312 empty_self = not bool(self)
2313 empty_other = not bool(other)
2314 if empty_self or empty_other:
2315 return empty_self and empty_other # True if both empty
2316 else:
2317 return (
2318 self._start == other._start
2319 and self._step == other._step
2320 and self._get_by_index(-1) == other._get_by_index(-1)
2321 )
2322 else:
2323 return False
2324
2325 def __getitem__(self, key):
2326 if isinstance(key, int):
2327 return self._get_by_index(key)
2328 elif isinstance(key, slice):
2329 step = self._step if key.step is None else key.step * self._step
2330
2331 if key.start is None or key.start <= -self._len:
2332 start = self._start
2333 elif key.start >= self._len:
2334 start = self._stop
2335 else: # -self._len < key.start < self._len
2336 start = self._get_by_index(key.start)
2337
2338 if key.stop is None or key.stop >= self._len:
2339 stop = self._stop
2340 elif key.stop <= -self._len:
2341 stop = self._start
2342 else: # -self._len < key.stop < self._len
2343 stop = self._get_by_index(key.stop)
2344
2345 return numeric_range(start, stop, step)
2346 else:
2347 raise TypeError(
2348 'numeric range indices must be '
2349 f'integers or slices, not {type(key).__name__}'
2350 )
2351
2352 def __hash__(self):
2353 if self:
2354 return hash((self._start, self._get_by_index(-1), self._step))
2355 else:
2356 return self._EMPTY_HASH
2357
2358 def __iter__(self):
2359 values = (self._start + (n * self._step) for n in count())
2360 if self._growing:
2361 return takewhile(partial(gt, self._stop), values)
2362 else:
2363 return takewhile(partial(lt, self._stop), values)
2364
2365 def __len__(self):
2366 return self._len
2367
2368 @cached_property
2369 def _len(self):
2370 if self._growing:
2371 start = self._start
2372 stop = self._stop
2373 step = self._step
2374 else:
2375 start = self._stop
2376 stop = self._start
2377 step = -self._step
2378 distance = stop - start
2379 if distance <= self._zero:
2380 return 0
2381 else: # distance > 0 and step > 0: regular euclidean division
2382 q, r = divmod(distance, step)
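            # Equivalent to ceiling division: ceil(distance / step)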
2383 return int(q) + int(r != self._zero)
2384
2385 def __reduce__(self):
2386 return numeric_range, (self._start, self._stop, self._step)
2387
2388 def __repr__(self):
2389 if self._step == 1:
2390 return f"numeric_range({self._start!r}, {self._stop!r})"
2391 return (
2392 f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
2393 )
2394
2395 def __reversed__(self):
2396 return iter(
2397 numeric_range(
2398 self._get_by_index(-1), self._start - self._step, -self._step
2399 )
2400 )
2401
2402 def count(self, value):
2403 return int(value in self)
2404
2405 def index(self, value):
2406 if self._growing:
2407 if self._start <= value < self._stop:
2408 q, r = divmod(value - self._start, self._step)
2409 if r == self._zero:
2410 return int(q)
2411 else:
2412 if self._start >= value > self._stop:
2413 q, r = divmod(self._start - value, -self._step)
2414 if r == self._zero:
2415 return int(q)
2416
2417 raise ValueError(f"{value} is not in numeric range")
2418
2419 def _get_by_index(self, i):
2420 if i < 0:
2421 i += self._len
2422 if i < 0 or i >= self._len:
2423 raise IndexError("numeric range object index out of range")
2424 return self._start + i * self._step
2425
2426
2427def count_cycle(iterable, n=None):
2428 """Cycle through the items from *iterable* up to *n* times, yielding
2429 the number of completed cycles along with each item. If *n* is omitted the
2430 process repeats indefinitely.
2431
2432 >>> list(count_cycle('AB', 3))
2433 [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]
2434
2435 """
2436 iterable = tuple(iterable)
2437 if not iterable:
2438 return iter(())
2439 counter = count() if n is None else range(n)
2440 return ((i, item) for i in counter for item in iterable)
2441
2442
2443def mark_ends(iterable):
2444 """Yield 3-tuples of the form ``(is_first, is_last, item)``.
2445
2446 >>> list(mark_ends('ABC'))
2447 [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]
2448
2449 Use this when looping over an iterable to take special action on its first
2450 and/or last items:
2451
2452 >>> iterable = ['Header', 100, 200, 'Footer']
2453 >>> total = 0
2454 >>> for is_first, is_last, item in mark_ends(iterable):
2455 ... if is_first:
2456 ... continue # Skip the header
2457 ... if is_last:
2458 ... continue # Skip the footer
2459 ... total += item
2460 >>> print(total)
2461 300
2462 """
2463 it = iter(iterable)
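    # Look ahead by one item: `a` is the current item and `b` is the next.
    # While a successor exists, `a` is not the last item; once the inner
    # loop exhausts the iterator, `a` is the final item.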
2464 for a in it:
2465 first = True
2466 for b in it:
2467 yield first, False, a
2468 a = b
2469 first = False
2470 yield first, True, a
2471
2472
2473def locate(iterable, pred=bool, window_size=None):
2474 """Yield the index of each item in *iterable* for which *pred* returns
2475 ``True``.
2476
2477 *pred* defaults to :func:`bool`, which will select truthy items:
2478
2479 >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
2480 [1, 2, 4]
2481
2482 Set *pred* to a custom function to, e.g., find the indexes for a particular
2483 item.
2484
2485 >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
2486 [1, 3]
2487
2488 If *window_size* is given, then the *pred* function will be called with
2489 that many items. This enables searching for sub-sequences:
2490
2491 >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
2492 >>> pred = lambda *args: args == (1, 2, 3)
2493 >>> list(locate(iterable, pred=pred, window_size=3))
2494 [1, 5, 9]
2495
2496 Use with :func:`seekable` to find indexes and then retrieve the associated
2497 items:
2498
2499 >>> from itertools import count
2500 >>> from more_itertools import seekable
2501 >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
2502 >>> it = seekable(source)
2503 >>> pred = lambda x: x > 100
2504 >>> indexes = locate(it, pred=pred)
2505 >>> i = next(indexes)
2506 >>> it.seek(i)
2507 >>> next(it)
2508 106
2509
2510 """
2511 if window_size is None:
2512 return compress(count(), map(pred, iterable))
2513
2514 if window_size < 1:
2515 raise ValueError('window size must be at least 1')
2516
2517 it = windowed(iterable, window_size, fillvalue=_marker)
2518 return compress(count(), starmap(pred, it))
2519
2520
2521def longest_common_prefix(iterables):
2522 """Yield elements of the longest common prefix among given *iterables*.
2523
2524 >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
2525 'ab'
2526
2527 """
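    # zip(*iterables) yields tuples of aligned elements and stops with the
    # shortest input; keep columns only while all their elements are equal.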
2528 return (c[0] for c in takewhile(all_equal, zip(*iterables)))
2529
2530
2531def lstrip(iterable, pred):
2532 """Yield the items from *iterable*, but strip any from the beginning
2533 for which *pred* returns ``True``.
2534
2535 For example, to remove a set of items from the start of an iterable:
2536
2537 >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
2538 >>> pred = lambda x: x in {None, False, ''}
2539 >>> list(lstrip(iterable, pred))
2540 [1, 2, None, 3, False, None]
2541
    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.
2544
2545 """
2546 return dropwhile(pred, iterable)
2547
2548
2549def rstrip(iterable, pred):
2550 """Yield the items from *iterable*, but strip any from the end
2551 for which *pred* returns ``True``.
2552
2553 For example, to remove a set of items from the end of an iterable:
2554
2555 >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
2556 >>> pred = lambda x: x in {None, False, ''}
2557 >>> list(rstrip(iterable, pred))
2558 [None, False, None, 1, 2, None, 3]
2559
2560 This function is analogous to :func:`str.rstrip`.
2561
2562 """
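    # Buffer items for which pred is true; flush the buffer when an item
    # fails pred, so a trailing run of matching items is never yielded.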
2563 cache = []
2564 cache_append = cache.append
2565 cache_clear = cache.clear
2566 for x in iterable:
2567 if pred(x):
2568 cache_append(x)
2569 else:
2570 yield from cache
2571 cache_clear()
2572 yield x
2573
2574
2575def strip(iterable, pred):
2576 """Yield the items from *iterable*, but strip any from the
2577 beginning and end for which *pred* returns ``True``.
2578
2579 For example, to remove a set of items from both ends of an iterable:
2580
2581 >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
2582 >>> pred = lambda x: x in {None, False, ''}
2583 >>> list(strip(iterable, pred))
2584 [1, 2, None, 3]
2585
2586 This function is analogous to :func:`str.strip`.
2587
2588 """
2589 return rstrip(lstrip(iterable, pred), pred)
2590
2591
2592class islice_extended:
2593 """An extension of :func:`itertools.islice` that supports negative values
2594 for *stop*, *start*, and *step*.
2595
2596 >>> iterator = iter('abcdefgh')
2597 >>> list(islice_extended(iterator, -4, -1))
2598 ['e', 'f', 'g']
2599
2600 Slices with negative values require some caching of *iterable*, but this
2601 function takes care to minimize the amount of memory required.
2602
2603 For example, you can use a negative step with an infinite iterator:
2604
2605 >>> from itertools import count
2606 >>> list(islice_extended(count(), 110, 99, -2))
2607 [110, 108, 106, 104, 102, 100]
2608
2609 You can also use slice notation directly:
2610
2611 >>> iterator = map(str, count())
2612 >>> it = islice_extended(iterator)[10:20:2]
2613 >>> list(it)
2614 ['10', '12', '14', '16', '18']
2615
2616 """
2617
2618 def __init__(self, iterable, *args):
2619 it = iter(iterable)
2620 if args:
2621 self._iterator = _islice_helper(it, slice(*args))
2622 else:
2623 self._iterator = it
2624
2625 def __iter__(self):
2626 return self
2627
2628 def __next__(self):
2629 return next(self._iterator)
2630
2631 def __getitem__(self, key):
2632 if isinstance(key, slice):
2633 return islice_extended(_islice_helper(self._iterator, key))
2634
2635 raise TypeError('islice_extended.__getitem__ argument must be a slice')
2636
2637
2638def _islice_helper(it, s):
2639 start = s.start
2640 stop = s.stop
2641 if s.step == 0:
2642 raise ValueError('step argument must be a non-zero integer or None.')
2643 step = s.step or 1
2644
2645 if step > 0:
2646 start = 0 if (start is None) else start
2647
2648 if start < 0:
2649 # Consume all but the last -start items
2650 cache = deque(enumerate(it, 1), maxlen=-start)
2651 len_iter = cache[-1][0] if cache else 0
2652
2653 # Adjust start to be positive
2654 i = max(len_iter + start, 0)
2655
2656 # Adjust stop to be positive
2657 if stop is None:
2658 j = len_iter
2659 elif stop >= 0:
2660 j = min(stop, len_iter)
2661 else:
2662 j = max(len_iter + stop, 0)
2663
2664 # Slice the cache
2665 n = j - i
2666 if n <= 0:
2667 return
2668
2669 for index in range(n):
2670 if index % step == 0:
2671 # pop and yield the item.
                    # We don't want to use an intermediate variable because
                    # it would extend the lifetime of the current item
2674 yield cache.popleft()[1]
2675 else:
2676 # just pop and discard the item
2677 cache.popleft()
2678 elif (stop is not None) and (stop < 0):
2679 # Advance to the start position
2680 next(islice(it, start, start), None)
2681
2682 # When stop is negative, we have to carry -stop items while
2683 # iterating
2684 cache = deque(islice(it, -stop), maxlen=-stop)
2685
2686 for index, item in enumerate(it):
2687 if index % step == 0:
2688 # pop and yield the item.
                    # We don't want to use an intermediate variable because
                    # it would extend the lifetime of the current item
2691 yield cache.popleft()
2692 else:
2693 # just pop and discard the item
2694 cache.popleft()
2695 cache.append(item)
2696 else:
2697 # When both start and stop are positive we have the normal case
2698 yield from islice(it, start, stop, step)
2699 else:
2700 start = -1 if (start is None) else start
2701
2702 if (stop is not None) and (stop < 0):
2703 # Consume all but the last items
2704 n = -stop - 1
2705 cache = deque(enumerate(it, 1), maxlen=n)
2706 len_iter = cache[-1][0] if cache else 0
2707
2708 # If start and stop are both negative they are comparable and
2709 # we can just slice. Otherwise we can adjust start to be negative
2710 # and then slice.
2711 if start < 0:
2712 i, j = start, stop
2713 else:
2714 i, j = min(start - len_iter, -1), None
2715
2716 for index, item in list(cache)[i:j:step]:
2717 yield item
2718 else:
2719 # Advance to the stop position
2720 if stop is not None:
2721 m = stop + 1
2722 next(islice(it, m, m), None)
2723
2724 # stop is positive, so if start is negative they are not comparable
2725 # and we need the rest of the items.
2726 if start < 0:
2727 i = start
2728 n = None
2729 # stop is None and start is positive, so we just need items up to
2730 # the start index.
2731 elif stop is None:
2732 i = None
2733 n = start + 1
2734 # Both stop and start are positive, so they are comparable.
2735 else:
2736 i = None
2737 n = start - stop
2738 if n <= 0:
2739 return
2740
2741 cache = list(islice(it, n))
2742
2743 yield from cache[i::step]
2744
2745
2746def always_reversible(iterable):
2747 """An extension of :func:`reversed` that supports all iterables, not
2748 just those which implement the ``Reversible`` or ``Sequence`` protocols.
2749
2750 >>> print(*always_reversible(x for x in range(3)))
2751 2 1 0
2752
2753 If the iterable is already reversible, this function returns the
2754 result of :func:`reversed()`. If the iterable is not reversible,
2755 this function will cache the remaining items in the iterable and
2756 yield them in reverse order, which may require significant storage.
2757 """
2758 try:
2759 return reversed(iterable)
2760 except TypeError:
2761 return reversed(list(iterable))
2762
2763
2764def consecutive_groups(iterable, ordering=None):
2765 """Yield groups of consecutive items using :func:`itertools.groupby`.
2766 The *ordering* function determines whether two items are adjacent by
2767 returning their position.
2768
2769 By default, the ordering function is the identity function. This is
2770 suitable for finding runs of numbers:
2771
2772 >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
2773 >>> for group in consecutive_groups(iterable):
2774 ... print(list(group))
2775 [1]
2776 [10, 11, 12]
2777 [20]
2778 [30, 31, 32, 33]
2779 [40]
2780
    To find runs of adjacent letters, apply the :func:`ord` function
    to convert letters to ordinals.
2783
2784 >>> iterable = 'abcdfgilmnop'
2785 >>> ordering = ord
2786 >>> for group in consecutive_groups(iterable, ordering):
2787 ... print(list(group))
2788 ['a', 'b', 'c', 'd']
2789 ['f', 'g']
2790 ['i']
2791 ['l', 'm', 'n', 'o', 'p']
2792
    Each group of consecutive items is an iterator that shares its source
    with *iterable*. When an output group is advanced, the previous group
    is no longer available unless its elements are copied (e.g., into a
    ``list``).
2796
2797 >>> iterable = [1, 2, 11, 12, 21, 22]
2798 >>> saved_groups = []
2799 >>> for group in consecutive_groups(iterable):
2800 ... saved_groups.append(list(group)) # Copy group elements
2801 >>> saved_groups
2802 [[1, 2], [11, 12], [21, 22]]
2803
2804 """
2805 if ordering is None:
2806 key = lambda x: x[0] - x[1]
2807 else:
2808 key = lambda x: x[0] - ordering(x[1])
2809
2810 for k, g in groupby(enumerate(iterable), key=key):
2811 yield map(itemgetter(1), g)
2812
2813
2814def difference(iterable, func=sub, *, initial=None):
2815 """This function is the inverse of :func:`itertools.accumulate`. By default
2816 it will compute the first difference of *iterable* using
2817 :func:`operator.sub`:
2818
2819 >>> from itertools import accumulate
2820 >>> iterable = accumulate([0, 1, 2, 3, 4]) # produces 0, 1, 3, 6, 10
2821 >>> list(difference(iterable))
2822 [0, 1, 2, 3, 4]
2823
2824 *func* defaults to :func:`operator.sub`, but other functions can be
2825 specified. They will be applied as follows::
2826
2827 A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...
2828
2829 For example, to do progressive division:
2830
2831 >>> iterable = [1, 2, 6, 24, 120]
2832 >>> func = lambda x, y: x // y
2833 >>> list(difference(iterable, func))
2834 [1, 2, 3, 4, 5]
2835
2836 If the *initial* keyword is set, the first element will be skipped when
2837 computing successive differences.
2838
2839 >>> it = [10, 11, 13, 16] # from accumulate([1, 2, 3], initial=10)
2840 >>> list(difference(it, initial=10))
2841 [1, 2, 3]
2842
2843 """
2844 a, b = tee(iterable)
2845 try:
2846 first = [next(b)]
2847 except StopIteration:
2848 return iter([])
2849
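    # With *initial* given, the first item of *iterable* is the initial
    # value itself, so drop it and yield only the successive differences.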
2850 if initial is not None:
2851 first = []
2852
2853 return chain(first, map(func, b, a))
2854
2855
2856class SequenceView(Sequence):
2857 """Return a read-only view of the sequence object *target*.
2858
2859 :class:`SequenceView` objects are analogous to Python's built-in
2860 "dictionary view" types. They provide a dynamic view of a sequence's items,
2861 meaning that when the sequence updates, so does the view.
2862
2863 >>> seq = ['0', '1', '2']
2864 >>> view = SequenceView(seq)
2865 >>> view
2866 SequenceView(['0', '1', '2'])
2867 >>> seq.append('3')
2868 >>> view
2869 SequenceView(['0', '1', '2', '3'])
2870
2871 Sequence views support indexing, slicing, and length queries. They act
2872 like the underlying sequence, except they don't allow assignment:
2873
2874 >>> view[1]
2875 '1'
2876 >>> view[1:-1]
2877 ['1', '2']
2878 >>> len(view)
2879 4
2880
2881 Sequence views are useful as an alternative to copying, as they don't
2882 require (much) extra storage.
2883
2884 """
2885
2886 def __init__(self, target):
2887 if not isinstance(target, Sequence):
2888 raise TypeError
2889 self._target = target
2890
2891 def __getitem__(self, index):
2892 return self._target[index]
2893
2894 def __len__(self):
2895 return len(self._target)
2896
2897 def __repr__(self):
2898 return f'{self.__class__.__name__}({self._target!r})'
2899
2900
2901class seekable:
2902 """Wrap an iterator to allow for seeking backward and forward. This
2903 progressively caches the items in the source iterable so they can be
2904 re-visited.
2905
2906 Call :meth:`seek` with an index to seek to that position in the source
2907 iterable.
2908
2909 To "reset" an iterator, seek to ``0``:
2910
2911 >>> from itertools import count
2912 >>> it = seekable((str(n) for n in count()))
2913 >>> next(it), next(it), next(it)
2914 ('0', '1', '2')
2915 >>> it.seek(0)
2916 >>> next(it), next(it), next(it)
2917 ('0', '1', '2')
2918
2919 You can also seek forward:
2920
2921 >>> it = seekable((str(n) for n in range(20)))
2922 >>> it.seek(10)
2923 >>> next(it)
2924 '10'
2925 >>> it.seek(20) # Seeking past the end of the source isn't a problem
2926 >>> list(it)
2927 []
2928 >>> it.seek(0) # Resetting works even after hitting the end
2929 >>> next(it)
2930 '0'
2931
2932 Call :meth:`relative_seek` to seek relative to the source iterator's
2933 current position.
2934
2935 >>> it = seekable((str(n) for n in range(20)))
2936 >>> next(it), next(it), next(it)
2937 ('0', '1', '2')
2938 >>> it.relative_seek(2)
2939 >>> next(it)
2940 '5'
2941 >>> it.relative_seek(-3) # Source is at '6', we move back to '3'
2942 >>> next(it)
2943 '3'
2944 >>> it.relative_seek(-3) # Source is at '4', we move back to '1'
2945 >>> next(it)
2946 '1'
2947
2948
2949 Call :meth:`peek` to look ahead one item without advancing the iterator:
2950
2951 >>> it = seekable('1234')
2952 >>> it.peek()
2953 '1'
2954 >>> list(it)
2955 ['1', '2', '3', '4']
2956 >>> it.peek(default='empty')
2957 'empty'
2958
    Before the iterator is at its end, calling :func:`bool` on it will
    return ``True``. Afterward, it will return ``False``:
2961
2962 >>> it = seekable('5678')
2963 >>> bool(it)
2964 True
2965 >>> list(it)
2966 ['5', '6', '7', '8']
2967 >>> bool(it)
2968 False
2969
2970 You may view the contents of the cache with the :meth:`elements` method.
2971 That returns a :class:`SequenceView`, a view that updates automatically:
2972
2973 >>> it = seekable((str(n) for n in range(10)))
2974 >>> next(it), next(it), next(it)
2975 ('0', '1', '2')
2976 >>> elements = it.elements()
2977 >>> elements
2978 SequenceView(['0', '1', '2'])
2979 >>> next(it)
2980 '3'
2981 >>> elements
2982 SequenceView(['0', '1', '2', '3'])
2983
2984 By default, the cache grows as the source iterable progresses, so beware of
2985 wrapping very large or infinite iterables. Supply *maxlen* to limit the
2986 size of the cache (this of course limits how far back you can seek).
2987
2988 >>> from itertools import count
2989 >>> it = seekable((str(n) for n in count()), maxlen=2)
2990 >>> next(it), next(it), next(it), next(it)
2991 ('0', '1', '2', '3')
2992 >>> list(it.elements())
2993 ['2', '3']
2994 >>> it.seek(0)
2995 >>> next(it), next(it), next(it), next(it)
2996 ('2', '3', '4', '5')
2997 >>> next(it)
2998 '6'
2999
3000 """
3001
3002 def __init__(self, iterable, maxlen=None):
3003 self._source = iter(iterable)
3004 if maxlen is None:
3005 self._cache = []
3006 else:
3007 self._cache = deque([], maxlen)
3008 self._index = None
3009
3010 def __iter__(self):
3011 return self
3012
3013 def __next__(self):
3014 if self._index is not None:
3015 try:
3016 item = self._cache[self._index]
3017 except IndexError:
3018 self._index = None
3019 else:
3020 self._index += 1
3021 return item
3022
3023 item = next(self._source)
3024 self._cache.append(item)
3025 return item
3026
3027 def __bool__(self):
3028 try:
3029 self.peek()
3030 except StopIteration:
3031 return False
3032 return True
3033
3034 def peek(self, default=_marker):
3035 try:
3036 peeked = next(self)
3037 except StopIteration:
3038 if default is _marker:
3039 raise
3040 return default
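        # Rewind the index by one so that the peeked item is returned
        # again by the next call to __next__()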
3041 if self._index is None:
3042 self._index = len(self._cache)
3043 self._index -= 1
3044 return peeked
3045
3046 def elements(self):
3047 return SequenceView(self._cache)
3048
3049 def seek(self, index):
3050 self._index = index
3051 remainder = index - len(self._cache)
3052 if remainder > 0:
3053 consume(self, remainder)
3054
3055 def relative_seek(self, count):
3056 if self._index is None:
3057 self._index = len(self._cache)
3058
3059 self.seek(max(self._index + count, 0))
3060
3061
3062class run_length:
3063 """
3064 :func:`run_length.encode` compresses an iterable with run-length encoding.
3065 It yields groups of repeated items with the count of how many times they
3066 were repeated:
3067
3068 >>> uncompressed = 'abbcccdddd'
3069 >>> list(run_length.encode(uncompressed))
3070 [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
3071
3072 :func:`run_length.decode` decompresses an iterable that was previously
3073 compressed with run-length encoding. It yields the items of the
3074 decompressed iterable:
3075
3076 >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
3077 >>> list(run_length.decode(compressed))
3078 ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']
3079
3080 """
3081
3082 @staticmethod
3083 def encode(iterable):
3084 return ((k, ilen(g)) for k, g in groupby(iterable))
3085
3086 @staticmethod
3087 def decode(iterable):
3088 return chain.from_iterable(starmap(repeat, iterable))
3089
3090
3091def exactly_n(iterable, n, predicate=bool):
3092 """Return ``True`` if exactly ``n`` items in the iterable are ``True``
3093 according to the *predicate* function.
3094
3095 >>> exactly_n([True, True, False], 2)
3096 True
3097 >>> exactly_n([True, True, False], 1)
3098 False
3099 >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
3100 True
3101
3102 The iterable will be advanced until ``n + 1`` truthy items are encountered,
3103 so avoid calling it on infinite iterables.
3104
3105 """
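    # Looking for n + 1 truthy items is enough to decide: the count equals
    # n exactly when an (n + 1)-th truthy item is never found.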
3106 return ilen(islice(filter(predicate, iterable), n + 1)) == n
3107
3108
3109def circular_shifts(iterable, steps=1):
3110 """Yield the circular shifts of *iterable*.
3111
3112 >>> list(circular_shifts(range(4)))
3113 [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]
3114
3115 Set *steps* to the number of places to rotate to the left
3116 (or to the right if negative). Defaults to 1.
3117
3118 >>> list(circular_shifts(range(4), 2))
3119 [(0, 1, 2, 3), (2, 3, 0, 1)]
3120
3121 >>> list(circular_shifts(range(4), -1))
3122 [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]
3123
3124 """
3125 buffer = deque(iterable)
3126 if steps == 0:
3127 raise ValueError('Steps should be a non-zero integer')
3128
3129 buffer.rotate(steps)
3130 steps = -steps
3131 n = len(buffer)
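    # Rotating by `steps` returns to the original order after
    # n // gcd(n, steps) rotations, so only that many shifts are distinct.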
3132 n //= math.gcd(n, steps)
3133
3134 for _ in repeat(None, n):
3135 buffer.rotate(steps)
3136 yield tuple(buffer)
3137
3138
3139def make_decorator(wrapping_func, result_index=0):
3140 """Return a decorator version of *wrapping_func*, which is a function that
3141 modifies an iterable. *result_index* is the position in that function's
3142 signature where the iterable goes.
3143
3144 This lets you use itertools on the "production end," i.e. at function
3145 definition. This can augment what the function returns without changing the
3146 function's code.
3147
3148 For example, to produce a decorator version of :func:`chunked`:
3149
3150 >>> from more_itertools import chunked
3151 >>> chunker = make_decorator(chunked, result_index=0)
3152 >>> @chunker(3)
3153 ... def iter_range(n):
3154 ... return iter(range(n))
3155 ...
3156 >>> list(iter_range(9))
3157 [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
3158
3159 To only allow truthy items to be returned:
3160
3161 >>> truth_serum = make_decorator(filter, result_index=1)
3162 >>> @truth_serum(bool)
3163 ... def boolean_test():
3164 ... return [0, 1, '', ' ', False, True]
3165 ...
3166 >>> list(boolean_test())
3167 [1, ' ', True]
3168
3169 The :func:`peekable` and :func:`seekable` wrappers make for practical
3170 decorators:
3171
3172 >>> from more_itertools import peekable
3173 >>> peekable_function = make_decorator(peekable)
3174 >>> @peekable_function()
3175 ... def str_range(*args):
3176 ... return (str(x) for x in range(*args))
3177 ...
3178 >>> it = str_range(1, 20, 2)
3179 >>> next(it), next(it), next(it)
3180 ('1', '3', '5')
3181 >>> it.peek()
3182 '7'
3183 >>> next(it)
3184 '7'
3185
3186 """
3187
3188 # See https://sites.google.com/site/bbayles/index/decorator_factory for
3189 # notes on how this works.
3190 def decorator(*wrapping_args, **wrapping_kwargs):
3191 def outer_wrapper(f):
3192 def inner_wrapper(*args, **kwargs):
3193 result = f(*args, **kwargs)
3194 wrapping_args_ = list(wrapping_args)
3195 wrapping_args_.insert(result_index, result)
3196 return wrapping_func(*wrapping_args_, **wrapping_kwargs)
3197
3198 return inner_wrapper
3199
3200 return outer_wrapper
3201
3202 return decorator
3203
3204
3205def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
3206 """Return a dictionary that maps the items in *iterable* to categories
3207 defined by *keyfunc*, transforms them with *valuefunc*, and
3208 then summarizes them by category with *reducefunc*.
3209
3210 *valuefunc* defaults to the identity function if it is unspecified.
3211 If *reducefunc* is unspecified, no summarization takes place:
3212
3213 >>> keyfunc = lambda x: x.upper()
3214 >>> result = map_reduce('abbccc', keyfunc)
3215 >>> sorted(result.items())
3216 [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]
3217
3218 Specifying *valuefunc* transforms the categorized items:
3219
3220 >>> keyfunc = lambda x: x.upper()
3221 >>> valuefunc = lambda x: 1
3222 >>> result = map_reduce('abbccc', keyfunc, valuefunc)
3223 >>> sorted(result.items())
3224 [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]
3225
3226 Specifying *reducefunc* summarizes the categorized items:
3227
3228 >>> keyfunc = lambda x: x.upper()
3229 >>> valuefunc = lambda x: 1
3230 >>> reducefunc = sum
3231 >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
3232 >>> sorted(result.items())
3233 [('A', 1), ('B', 2), ('C', 3)]
3234
3235 You may want to filter the input iterable before applying the map/reduce
3236 procedure:
3237
3238 >>> all_items = range(30)
3239 >>> items = [x for x in all_items if 10 <= x <= 20] # Filter
3240 >>> keyfunc = lambda x: x % 2 # Evens map to 0; odds to 1
3241 >>> categories = map_reduce(items, keyfunc=keyfunc)
3242 >>> sorted(categories.items())
3243 [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
3244 >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
3245 >>> sorted(summaries.items())
3246 [(0, 90), (1, 75)]
3247
3248 Note that all items in the iterable are gathered into a list before the
3249 summarization step, which may require significant storage.
3250
3251 The returned object is a :obj:`collections.defaultdict` with the
3252 ``default_factory`` set to ``None``, such that it behaves like a normal
3253 dictionary.
3254
3255 """
3256
3257 ret = defaultdict(list)
3258
3259 if valuefunc is None:
3260 for item in iterable:
3261 key = keyfunc(item)
3262 ret[key].append(item)
3263
3264 else:
3265 for item in iterable:
3266 key = keyfunc(item)
3267 value = valuefunc(item)
3268 ret[key].append(value)
3269
3270 if reducefunc is not None:
3271 for key, value_list in ret.items():
3272 ret[key] = reducefunc(value_list)
3273
3274 ret.default_factory = None
3275 return ret
3276
3277
3278def rlocate(iterable, pred=bool, window_size=None):
3279 """Yield the index of each item in *iterable* for which *pred* returns
3280 ``True``, starting from the right and moving left.
3281
3282 *pred* defaults to :func:`bool`, which will select truthy items:
3283
3284 >>> list(rlocate([0, 1, 1, 0, 1, 0, 0])) # Truthy at 1, 2, and 4
3285 [4, 2, 1]
3286
3287 Set *pred* to a custom function to, e.g., find the indexes for a particular
3288 item:
3289
3290 >>> iterator = iter('abcb')
3291 >>> pred = lambda x: x == 'b'
3292 >>> list(rlocate(iterator, pred))
3293 [3, 1]
3294
3295 If *window_size* is given, then the *pred* function will be called with
3296 that many items. This enables searching for sub-sequences:
3297
3298 >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
3299 >>> pred = lambda *args: args == (1, 2, 3)
3300 >>> list(rlocate(iterable, pred=pred, window_size=3))
3301 [9, 5, 1]
3302
3303 Beware, this function won't return anything for infinite iterables.
3304 If *iterable* is reversible, ``rlocate`` will reverse it and search from
3305 the right. Otherwise, it will search from the left and return the results
3306 in reverse order.
3307
    See :func:`locate` for other example applications.
3309
3310 """
3311 if window_size is None:
3312 try:
3313 len_iter = len(iterable)
3314 return (len_iter - i - 1 for i in locate(reversed(iterable), pred))
3315 except TypeError:
3316 pass
3317
3318 return reversed(list(locate(iterable, pred, window_size)))
3319
3320
3321def replace(iterable, pred, substitutes, count=None, window_size=1):
3322 """Yield the items from *iterable*, replacing the items for which *pred*
3323 returns ``True`` with the items from the iterable *substitutes*.
3324
3325 >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
3326 >>> pred = lambda x: x == 0
3327 >>> substitutes = (2, 3)
3328 >>> list(replace(iterable, pred, substitutes))
3329 [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]
3330
3331 If *count* is given, the number of replacements will be limited:
3332
3333 >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
3334 >>> pred = lambda x: x == 0
3335 >>> substitutes = [None]
3336 >>> list(replace(iterable, pred, substitutes, count=2))
3337 [1, 1, None, 1, 1, None, 1, 1, 0]
3338
3339 Use *window_size* to control the number of items passed as arguments to
3340 *pred*. This allows for locating and replacing subsequences.
3341
3342 >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
3343 >>> window_size = 3
3344 >>> pred = lambda *args: args == (0, 1, 2) # 3 items passed to pred
3345 >>> substitutes = [3, 4] # Splice in these items
3346 >>> list(replace(iterable, pred, substitutes, window_size=window_size))
3347 [3, 4, 5, 3, 4, 5]
3348
3349 """
3350 if window_size < 1:
3351 raise ValueError('window_size must be at least 1')
3352
3353 # Save the substitutes iterable, since it's used more than once
3354 substitutes = tuple(substitutes)
3355
3356 # Add padding such that the number of windows matches the length of the
3357 # iterable
3358 it = chain(iterable, repeat(_marker, window_size - 1))
3359 windows = windowed(it, window_size)
3360
3361 n = 0
3362 for w in windows:
3363 # If the current window matches our predicate (and we haven't hit
3364 # our maximum number of replacements), splice in the substitutes
3365 # and then consume the following windows that overlap with this one.
3366 # For example, if the iterable is (0, 1, 2, 3, 4...)
3367 # and the window size is 2, we have (0, 1), (1, 2), (2, 3)...
3368 # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2)
3369 if pred(*w):
3370 if (count is None) or (n < count):
3371 n += 1
3372 yield from substitutes
3373 consume(windows, window_size - 1)
3374 continue
3375
3376 # If there was no match (or we've reached the replacement limit),
3377 # yield the first item from the window.
3378 if w and (w[0] is not _marker):
3379 yield w[0]
3380
3381
3382def partitions(iterable):
3383 """Yield all possible order-preserving partitions of *iterable*.
3384
3385 >>> iterable = 'abc'
3386 >>> for part in partitions(iterable):
3387 ... print([''.join(p) for p in part])
3388 ['abc']
3389 ['a', 'bc']
3390 ['ab', 'c']
3391 ['a', 'b', 'c']
3392
3393 This is unrelated to :func:`partition`.
3394
3395 """
3396 sequence = list(iterable)
3397 n = len(sequence)
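    # Each subset of the indices 1 through n - 1 picks the cut points;
    # pairing consecutive cut points (with 0 and n added) gives the slices.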
3398 for i in powerset(range(1, n)):
3399 yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))]
3400
3401
3402def set_partitions(iterable, k=None, min_size=None, max_size=None):
3403 """
3404 Yield the set partitions of *iterable* into *k* parts. Set partitions are
3405 not order-preserving.
3406
3407 >>> iterable = 'abc'
3408 >>> for part in set_partitions(iterable, 2):
3409 ... print([''.join(p) for p in part])
3410 ['a', 'bc']
3411 ['ab', 'c']
3412 ['b', 'ac']
3413
3414
3415 If *k* is not given, every set partition is generated.
3416
3417 >>> iterable = 'abc'
3418 >>> for part in set_partitions(iterable):
3419 ... print([''.join(p) for p in part])
3420 ['abc']
3421 ['a', 'bc']
3422 ['ab', 'c']
3423 ['b', 'ac']
3424 ['a', 'b', 'c']
3425
    If *min_size* and/or *max_size* are given, they set the minimum and/or
    maximum size of each block in the partition.
3428
3429 >>> iterable = 'abc'
3430 >>> for part in set_partitions(iterable, min_size=2):
3431 ... print([''.join(p) for p in part])
3432 ['abc']
3433 >>> for part in set_partitions(iterable, max_size=2):
3434 ... print([''.join(p) for p in part])
3435 ['a', 'bc']
3436 ['ab', 'c']
3437 ['b', 'ac']
3438 ['a', 'b', 'c']
3439
3440 """
3441 L = list(iterable)
3442 n = len(L)
3443 if k is not None:
3444 if k < 1:
3445 raise ValueError(
3446 "Can't partition in a negative or zero number of groups"
3447 )
3448 elif k > n:
3449 return
3450
3451 min_size = min_size if min_size is not None else 0
3452 max_size = max_size if max_size is not None else n
3453 if min_size > max_size:
3454 return
3455
3456 def set_partitions_helper(L, k):
3457 n = len(L)
3458 if k == 1:
3459 yield [L]
3460 elif n == k:
3461 yield [[s] for s in L]
3462 else:
3463 e, *M = L
3464 for p in set_partitions_helper(M, k - 1):
3465 yield [[e], *p]
3466 for p in set_partitions_helper(M, k):
3467 for i in range(len(p)):
3468 yield p[:i] + [[e] + p[i]] + p[i + 1 :]
3469
3470 if k is None:
3471 for k in range(1, n + 1):
3472 yield from filter(
3473 lambda z: all(min_size <= len(bk) <= max_size for bk in z),
3474 set_partitions_helper(L, k),
3475 )
3476 else:
3477 yield from filter(
3478 lambda z: all(min_size <= len(bk) <= max_size for bk in z),
3479 set_partitions_helper(L, k),
3480 )
3481
3482
3483class time_limited:
3484 """
3485 Yield items from *iterable* until *limit_seconds* have passed.
3486 If the time limit expires before all items have been yielded, the
3487 ``timed_out`` parameter will be set to ``True``.
3488
3489 >>> from time import sleep
3490 >>> def generator():
3491 ... yield 1
3492 ... yield 2
3493 ... sleep(0.2)
3494 ... yield 3
3495 >>> iterable = time_limited(0.1, generator())
3496 >>> list(iterable)
3497 [1, 2]
3498 >>> iterable.timed_out
3499 True
3500
3501 Note that the time is checked before each item is yielded, and iteration
3502 stops if the time elapsed is greater than *limit_seconds*. If your time
3503 limit is 1 second, but it takes 2 seconds to generate the first item from
3504 the iterable, the function will run for 2 seconds and not yield anything.
    As a special case, when *limit_seconds* is zero, the iterator never
    yields anything.
3507
3508 """
3509
3510 def __init__(self, limit_seconds, iterable):
3511 if limit_seconds < 0:
            raise ValueError('limit_seconds must be non-negative')
3513 self.limit_seconds = limit_seconds
3514 self._iterator = iter(iterable)
3515 self._start_time = monotonic()
3516 self.timed_out = False
3517
3518 def __iter__(self):
3519 return self
3520
3521 def __next__(self):
3522 if self.limit_seconds == 0:
3523 self.timed_out = True
3524 raise StopIteration
3525 item = next(self._iterator)
3526 if monotonic() - self._start_time > self.limit_seconds:
3527 self.timed_out = True
3528 raise StopIteration
3529
3530 return item
3531
3532
3533def only(iterable, default=None, too_long=None):
3534 """If *iterable* has only one item, return it.
3535 If it has zero items, return *default*.
3536 If it has more than one item, raise the exception given by *too_long*,
3537 which is ``ValueError`` by default.
3538
3539 >>> only([], default='missing')
3540 'missing'
3541 >>> only([1])
3542 1
3543 >>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL
3544 Traceback (most recent call last):
3545 ...
3546 ValueError: Expected exactly one item in iterable, but got 1, 2,
    and perhaps more.
3548 >>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL
3549 Traceback (most recent call last):
3550 ...
3551 TypeError
3552
3553 Note that :func:`only` attempts to advance *iterable* twice to ensure there
3554 is only one item. See :func:`spy` or :func:`peekable` to check
3555 iterable contents less destructively.
3556
3557 """
3558 iterator = iter(iterable)
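    # These for loops advance the iterator at most twice: once to fetch
    # the lone item and once more to detect an unexpected second item.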
3559 for first in iterator:
3560 for second in iterator:
3561 msg = (
3562 f'Expected exactly one item in iterable, but got {first!r}, '
3563 f'{second!r}, and perhaps more.'
3564 )
3565 raise too_long or ValueError(msg)
3566 return first
3567 return default
3568
3569
3570def _ichunk(iterator, n):
3571 cache = deque()
3572 chunk = islice(iterator, n)
3573
3574 def generator():
3575 with suppress(StopIteration):
3576 while True:
3577 if cache:
3578 yield cache.popleft()
3579 else:
3580 yield next(chunk)
3581
3582 def materialize_next(n=1):
        # If n is None, materialize everything remaining in the chunk
3584 if n is None:
3585 cache.extend(chunk)
3586 return len(cache)
3587
3588 to_cache = n - len(cache)
3589
3590 # materialize up to n
3591 if to_cache > 0:
3592 cache.extend(islice(chunk, to_cache))
3593
3594 # return number materialized up to n
3595 return min(n, len(cache))
3596
3597 return (generator(), materialize_next)
3598
3599
3600def ichunked(iterable, n):
3601 """Break *iterable* into sub-iterables with *n* elements each.
3602 :func:`ichunked` is like :func:`chunked`, but it yields iterables
3603 instead of lists.
3604
3605 If the sub-iterables are read in order, the elements of *iterable*
3606 won't be stored in memory.
    If they are read out of order, the unread elements of earlier
    sub-iterables are cached (in a ``deque``) as necessary.
3609
3610 >>> from itertools import count
3611 >>> all_chunks = ichunked(count(), 4)
3612 >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
3613 >>> list(c_2) # c_1's elements have been cached; c_3's haven't been
3614 [4, 5, 6, 7]
3615 >>> list(c_1)
3616 [0, 1, 2, 3]
3617 >>> list(c_3)
3618 [8, 9, 10, 11]
3619
3620 """
3621 iterator = iter(iterable)
3622 while True:
3623 # Create new chunk
3624 chunk, materialize_next = _ichunk(iterator, n)
3625
3626 # Check to see whether we're at the end of the source iterable
3627 if not materialize_next():
3628 return
3629
3630 yield chunk
3631
3632 # Fill previous chunk's cache
3633 materialize_next(None)
3634
3635
3636def iequals(*iterables):
3637 """Return ``True`` if all given *iterables* are equal to each other,
3638 which means that they contain the same elements in the same order.
3639
3640 The function is useful for comparing iterables of different data types
3641 or iterables that do not support equality checks.
3642
3643 >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
3644 True
3645
3646 >>> iequals("abc", "acb")
3647 False
3648
3649 Not to be confused with :func:`all_equal`, which checks whether all
3650 elements of iterable are equal to each other.
3651
3652 """
3653 return all(map(all_equal, zip_longest(*iterables, fillvalue=object())))
3654
3655
3656def distinct_combinations(iterable, r):
3657 """Yield the distinct combinations of *r* items taken from *iterable*.
3658
3659 >>> list(distinct_combinations([0, 0, 1], 2))
3660 [(0, 0), (0, 1)]
3661
    Equivalent to ``set(combinations(iterable, r))``, except duplicates are
    not generated and thrown away. For larger input sequences this is much
    more efficient.
3665
3666 """
3667 if r < 0:
3668 raise ValueError('r must be non-negative')
3669 elif r == 0:
3670 yield ()
3671 return
3672 pool = tuple(iterable)
3673 generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
3674 current_combo = [None] * r
3675 level = 0
3676 while generators:
3677 try:
3678 cur_idx, p = next(generators[-1])
3679 except StopIteration:
3680 generators.pop()
3681 level -= 1
3682 continue
3683 current_combo[level] = p
3684 if level + 1 == r:
3685 yield tuple(current_combo)
3686 else:
3687 generators.append(
3688 unique_everseen(
3689 enumerate(pool[cur_idx + 1 :], cur_idx + 1),
3690 key=itemgetter(1),
3691 )
3692 )
3693 level += 1
3694
3695
3696def filter_except(validator, iterable, *exceptions):
3697 """Yield the items from *iterable* for which the *validator* function does
3698 not raise one of the specified *exceptions*.
3699
3700 *validator* is called for each item in *iterable*.
3701 It should be a function that accepts one argument and raises an exception
3702 if that item is not valid.
3703
3704 >>> iterable = ['1', '2', 'three', '4', None]
3705 >>> list(filter_except(int, iterable, ValueError, TypeError))
3706 ['1', '2', '4']
3707
3708 If an exception other than one given by *exceptions* is raised by
3709 *validator*, it is raised like normal.
3710 """
3711 for item in iterable:
3712 try:
3713 validator(item)
3714 except exceptions:
3715 pass
3716 else:
3717 yield item
3718
3719
3720def map_except(function, iterable, *exceptions):
3721 """Transform each item from *iterable* with *function* and yield the
3722 result, unless *function* raises one of the specified *exceptions*.
3723
3724 *function* is called to transform each item in *iterable*.
3725 It should accept one argument.
3726
3727 >>> iterable = ['1', '2', 'three', '4', None]
3728 >>> list(map_except(int, iterable, ValueError, TypeError))
3729 [1, 2, 4]
3730
3731 If an exception other than one given by *exceptions* is raised by
3732 *function*, it is raised like normal.
3733 """
3734 for item in iterable:
3735 try:
3736 yield function(item)
3737 except exceptions:
3738 pass
3739
3740
3741def map_if(iterable, pred, func, func_else=None):
3742 """Evaluate each item from *iterable* using *pred*. If the result is
3743 equivalent to ``True``, transform the item with *func* and yield it.
3744 Otherwise, transform the item with *func_else* and yield it.
3745
3746 *pred*, *func*, and *func_else* should each be functions that accept
3747 one argument. By default, *func_else* is the identity function.
3748
3749 >>> from math import sqrt
3750 >>> iterable = list(range(-5, 5))
3751 >>> iterable
3752 [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
3753 >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
3754 [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
3755 >>> list(map_if(iterable, lambda x: x >= 0,
3756 ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
3757 [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
3758 """
3759
3760 if func_else is None:
3761 for item in iterable:
3762 yield func(item) if pred(item) else item
3763
3764 else:
3765 for item in iterable:
3766 yield func(item) if pred(item) else func_else(item)
3767
3768
3769def _sample_unweighted(iterator, k, strict):
3770 # Algorithm L in the 1994 paper by Kim-Hung Li:
3771 # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
3772
3773 reservoir = list(islice(iterator, k))
3774 if strict and len(reservoir) < k:
3775 raise ValueError('Sample larger than population')
3776 W = 1.0
3777
3778 with suppress(StopIteration):
3779 while True:
3780 W *= random() ** (1 / k)
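            # The number of items to skip before the next replacement
            # follows a geometric distribution; log1p(-W) is log(1 - W).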
3781 skip = floor(log(random()) / log1p(-W))
3782 element = next(islice(iterator, skip, None))
3783 reservoir[randrange(k)] = element
3784
3785 shuffle(reservoir)
3786 return reservoir
3787
3788
3789def _sample_weighted(iterator, k, weights, strict):
3790 # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
3791 # "Weighted random sampling with a reservoir".
3792
    # Log-transform the keys for numerical stability with weights
    # that are very small or very large
3794 weight_keys = (log(random()) / weight for weight in weights)
3795
3796 # Fill up the reservoir (collection of samples) with the first `k`
3797 # weight-keys and elements, then heapify the list.
3798 reservoir = take(k, zip(weight_keys, iterator))
3799 if strict and len(reservoir) < k:
3800 raise ValueError('Sample larger than population')
3801
3802 heapify(reservoir)
3803
3804 # The number of jumps before changing the reservoir is a random variable
3805 # with an exponential distribution. Sample it using random() and logs.
3806 smallest_weight_key, _ = reservoir[0]
3807 weights_to_skip = log(random()) / smallest_weight_key
3808
3809 for weight, element in zip(weights, iterator):
3810 if weight >= weights_to_skip:
3811 # The notation here is consistent with the paper, but we store
3812 # the weight-keys in log-space for better numerical stability.
3813 smallest_weight_key, _ = reservoir[0]
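            # Keys are stored as log(u) / w, so exponentiating
            # weight * smallest_weight_key recovers the paper's threshold
            # t_w = u_min ** (weight / w_min).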
3814 t_w = exp(weight * smallest_weight_key)
3815 r_2 = uniform(t_w, 1) # generate U(t_w, 1)
3816 weight_key = log(r_2) / weight
3817 heapreplace(reservoir, (weight_key, element))
3818 smallest_weight_key, _ = reservoir[0]
3819 weights_to_skip = log(random()) / smallest_weight_key
3820 else:
3821 weights_to_skip -= weight
3822
3823 ret = [element for weight_key, element in reservoir]
3824 shuffle(ret)
3825 return ret
3826
3827
3828def _sample_counted(population, k, counts, strict):
3829 element = None
3830 remaining = 0
3831
3832 def feed(i):
3833 # Advance *i* steps ahead and consume an element
3834 nonlocal element, remaining
3835
3836 while i + 1 > remaining:
3837 i = i - remaining
3838 element = next(population)
3839 remaining = next(counts)
3840 remaining -= i + 1
3841 return element
3842
3843 with suppress(StopIteration):
3844 reservoir = []
3845 for _ in range(k):
3846 reservoir.append(feed(0))
3847
3848 if strict and len(reservoir) < k:
3849 raise ValueError('Sample larger than population')
3850
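    # The skip logic below is Algorithm L, as in _sample_unweighted, but
    # feed() advances through the (element, count) expansion of the
    # population instead of a flat iterator.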
3851 with suppress(StopIteration):
3852 W = 1.0
3853 while True:
3854 W *= random() ** (1 / k)
3855 skip = floor(log(random()) / log1p(-W))
3856 element = feed(skip)
3857 reservoir[randrange(k)] = element
3858
3859 shuffle(reservoir)
3860 return reservoir
3861
3862
3863def sample(iterable, k, weights=None, *, counts=None, strict=False):
3864 """Return a *k*-length list of elements chosen (without replacement)
3865 from the *iterable*.
3866
3867 Similar to :func:`random.sample`, but works on inputs that aren't
3868 indexable (such as sets and dictionaries) and on inputs where the
3869 size isn't known in advance (such as generators).
3870
3871 >>> iterable = range(100)
3872 >>> sample(iterable, 5) # doctest: +SKIP
3873 [81, 60, 96, 16, 4]
3874
3875 For iterables with repeated elements, you may supply *counts* to
3876 indicate the repeats.
3877
3878 >>> iterable = ['a', 'b']
3879 >>> counts = [3, 4] # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
3880 >>> sample(iterable, k=3, counts=counts) # doctest: +SKIP
3881 ['a', 'a', 'b']
3882
3883 An iterable with *weights* may be given:
3884
3885 >>> iterable = range(100)
3886 >>> weights = (i * i + 1 for i in range(100))
    >>> sample(iterable, 5, weights=weights) # doctest: +SKIP
3888 [79, 67, 74, 66, 78]
3889
3890 Weighted selections are made without replacement.
3891 After an element is selected, it is removed from the pool and the
3892 relative weights of the other elements increase (this
3893 does not match the behavior of :func:`random.sample`'s *counts*
3894 parameter). Note that *weights* may not be used with *counts*.
3895
    If the length of *iterable* is less than *k*, ``ValueError`` is raised
    if *strict* is ``True``; if *strict* is ``False``, all elements are
    returned (in shuffled order):
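
    >>> sample(range(3), k=5, strict=True)
    Traceback (most recent call last):
    ...
    ValueError: Sample larger than population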
3899
3900 By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
3901 technique is used. When *weights* are provided,
3902 `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used instead.
3903
3904 Notes on reproducibility:
3905
3906 * The algorithms rely on inexact floating-point functions provided
3907 by the underlying math library (e.g. ``log``, ``log1p``, and ``pow``).
3908 Those functions can `produce slightly different results
3909 <https://members.loria.fr/PZimmermann/papers/accuracy.pdf>`_ on
3910 different builds. Accordingly, selections can vary across builds
3911 even for the same seed.
3912
3913 * The algorithms loop over the input and make selections based on
3914 ordinal position, so selections from unordered collections (such as
3915 sets) won't reproduce across sessions on the same platform using the
3916 same seed. For example, this won't reproduce::
3917
3918 >> seed(8675309)
3919 >> sample(set('abcdefghijklmnopqrstuvwxyz'), 10)
3920 ['c', 'p', 'e', 'w', 's', 'a', 'j', 'd', 'n', 't']
3921
3922 """
3923 iterator = iter(iterable)
3924
3925 if k < 0:
3926 raise ValueError('k must be non-negative')
3927
3928 if k == 0:
3929 return []
3930
3931 if weights is not None and counts is not None:
3932 raise TypeError('weights and counts are mutually exclusive')
3933
3934 elif weights is not None:
3935 weights = iter(weights)
3936 return _sample_weighted(iterator, k, weights, strict)
3937
3938 elif counts is not None:
3939 counts = iter(counts)
3940 return _sample_counted(iterator, k, counts, strict)
3941
3942 else:
3943 return _sample_unweighted(iterator, k, strict)
3944
3945
3946def is_sorted(iterable, key=None, reverse=False, strict=False):
3947 """Returns ``True`` if the items of iterable are in sorted order, and
3948 ``False`` otherwise. *key* and *reverse* have the same meaning that they do
3949 in the built-in :func:`sorted` function.
3950
3951 >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
3952 True
3953 >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
3954 False
3955
3956 If *strict*, tests for strict sorting, that is, returns ``False`` if equal
3957 elements are found:
3958
3959 >>> is_sorted([1, 2, 2])
3960 True
3961 >>> is_sorted([1, 2, 2], strict=True)
3962 False
3963
3964 The function returns ``False`` after encountering the first out-of-order
3965 item, which means it may produce results that differ from the built-in
3966 :func:`sorted` function for objects with unusual comparison dynamics
3967 (like ``math.nan``). If there are no out-of-order items, the iterable is
3968 exhausted.
3969 """
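    # `b` is the same stream advanced by one position, so (a, b) pairs are
    # adjacent items. Strict sorting requires a < b for every pair;
    # otherwise no pair may satisfy b < a (i.e. a <= b everywhere).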
3970 it = iterable if (key is None) else map(key, iterable)
3971 a, b = tee(it)
3972 next(b, None)
3973 if reverse:
3974 b, a = a, b
3975 return all(map(lt, a, b)) if strict else not any(map(lt, b, a))
3976
3977
3978class AbortThread(BaseException):
3979 pass
3980
3981
3982class callback_iter:
3983 """Convert a function that uses callbacks to an iterator.
3984
3985 Let *func* be a function that takes a `callback` keyword argument.
3986 For example:
3987
3988 >>> def func(callback=None):
3989 ... for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
3990 ... if callback:
3991 ... callback(i, c)
3992 ... return 4
3993
3994
3995 Use ``with callback_iter(func)`` to get an iterator over the parameters
3996 that are delivered to the callback.
3997
3998 >>> with callback_iter(func) as it:
3999 ... for args, kwargs in it:
4000 ... print(args)
4001 (1, 'a')
4002 (2, 'b')
4003 (3, 'c')
4004
4005 The function will be called in a background thread. The ``done`` property
4006 indicates whether it has completed execution.
4007
4008 >>> it.done
4009 True
4010
4011 If it completes successfully, its return value will be available
4012 in the ``result`` property.
4013
4014 >>> it.result
4015 4
4016
4017 Notes:
4018
4019 * If the function uses some keyword argument besides ``callback``, supply
4020 *callback_kwd*.
    * If it has finished executing but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently the background function
      is polled for output.
4030
4031 """
4032
4033 def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
4034 self._func = func
4035 self._callback_kwd = callback_kwd
4036 self._aborted = False
4037 self._future = None
4038 self._wait_seconds = wait_seconds
        # Lazily import concurrent.futures
4040 self._executor = __import__(
4041 'concurrent.futures'
4042 ).futures.ThreadPoolExecutor(max_workers=1)
4043 self._iterator = self._reader()
4044
4045 def __enter__(self):
4046 return self
4047
4048 def __exit__(self, exc_type, exc_value, traceback):
4049 self._aborted = True
4050 self._executor.shutdown()
4051
4052 def __iter__(self):
4053 return self
4054
4055 def __next__(self):
4056 return next(self._iterator)
4057
4058 @property
4059 def done(self):
4060 if self._future is None:
4061 return False
4062 return self._future.done()
4063
4064 @property
4065 def result(self):
4066 if not self.done:
4067 raise RuntimeError('Function has not yet completed')
4068
4069 return self._future.result()
4070
4071 def _reader(self):
4072 q = Queue()
4073
4074 def callback(*args, **kwargs):
4075 if self._aborted:
4076 raise AbortThread('canceled by user')
4077
4078 q.put((args, kwargs))
4079
4080 self._future = self._executor.submit(
4081 self._func, **{self._callback_kwd: callback}
4082 )
4083
4084 while True:
4085 try:
4086 item = q.get(timeout=self._wait_seconds)
4087 except Empty:
4088 pass
4089 else:
4090 q.task_done()
4091 yield item
4092
4093 if self._future.done():
4094 break
4095
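        # The background function has finished; drain any items it queued
        # before we noticed.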
4096 remaining = []
4097 while True:
4098 try:
4099 item = q.get_nowait()
4100 except Empty:
4101 break
4102 else:
4103 q.task_done()
4104 remaining.append(item)
4105 q.join()
4106 yield from remaining
4107
4108
4109def windowed_complete(iterable, n):
4110 """
4111 Yield ``(beginning, middle, end)`` tuples, where:
4112
4113 * Each ``middle`` has *n* items from *iterable*
4114 * Each ``beginning`` has the items before the ones in ``middle``
4115 * Each ``end`` has the items after the ones in ``middle``
4116
4117 >>> iterable = range(7)
4118 >>> n = 3
4119 >>> for beginning, middle, end in windowed_complete(iterable, n):
4120 ... print(beginning, middle, end)
4121 () (0, 1, 2) (3, 4, 5, 6)
4122 (0,) (1, 2, 3) (4, 5, 6)
4123 (0, 1) (2, 3, 4) (5, 6)
4124 (0, 1, 2) (3, 4, 5) (6,)
4125 (0, 1, 2, 3) (4, 5, 6) ()
4126
    Note that *n* must be at least 0 and at most equal to the length of
    *iterable*.
4129
4130 This function will exhaust the iterable and may require significant
4131 storage.
4132 """
4133 if n < 0:
4134 raise ValueError('n must be >= 0')
4135
4136 seq = tuple(iterable)
4137 size = len(seq)
4138
4139 if n > size:
4140 raise ValueError('n must be <= len(seq)')
4141
4142 for i in range(size - n + 1):
4143 beginning = seq[:i]
4144 middle = seq[i : i + n]
4145 end = seq[i + n :]
4146 yield beginning, middle, end
4147
4148
4149def all_unique(iterable, key=None):
4150 """
4151 Returns ``True`` if all the elements of *iterable* are unique (no two
4152 elements are equal).
4153
4154 >>> all_unique('ABCB')
4155 False
4156
4157 If a *key* function is specified, it will be used to make comparisons.
4158
4159 >>> all_unique('ABCb')
4160 True
4161 >>> all_unique('ABCb', str.lower)
4162 False
4163
4164 The function returns as soon as the first non-unique element is
4165 encountered. Iterables with a mix of hashable and unhashable items can
4166 be used, but the function will be slower for unhashable items.
4167 """
4168 seenset = set()
4169 seenset_add = seenset.add
4170 seenlist = []
4171 seenlist_add = seenlist.append
4172 for element in map(key, iterable) if key else iterable:
4173 try:
4174 if element in seenset:
4175 return False
4176 seenset_add(element)
4177 except TypeError:
4178 if element in seenlist:
4179 return False
4180 seenlist_add(element)
4181 return True
4182
4183
4184def nth_product(index, *args):
4185 """Equivalent to ``list(product(*args))[index]``.
4186
4187 The products of *args* can be ordered lexicographically.
4188 :func:`nth_product` computes the product at sort position *index* without
4189 computing the previous products.
4190
4191 >>> nth_product(8, range(2), range(2), range(2), range(2))
4192 (1, 0, 0, 0)
4193
4194 ``IndexError`` will be raised if the given *index* is invalid.
4195 """
4196 pools = list(map(tuple, reversed(args)))
4197 ns = list(map(len, pools))
4198
4199 c = reduce(mul, ns)
4200
4201 if index < 0:
4202 index += c
4203
4204 if not 0 <= index < c:
4205 raise IndexError
4206
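    # Decompose *index* in the mixed-radix system whose digit sizes are
    # the pool lengths; the pools were reversed so that the last pool
    # varies fastest.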
4207 result = []
4208 for pool, n in zip(pools, ns):
4209 result.append(pool[index % n])
4210 index //= n
4211
4212 return tuple(reversed(result))
4213
4214
4215def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]``
4217
4218 The subsequences of *iterable* that are of length *r* where order is
4219 important can be ordered lexicographically. :func:`nth_permutation`
4220 computes the subsequence at sort position *index* directly, without
4221 computing the previous subsequences.
4222
4223 >>> nth_permutation('ghijk', 2, 5)
4224 ('h', 'i')
4225
    ``ValueError`` will be raised if *r* is negative or greater than the
    length of *iterable*.
4228 ``IndexError`` will be raised if the given *index* is invalid.
4229 """
4230 pool = list(iterable)
4231 n = len(pool)
4232
4233 if r is None or r == n:
4234 r, c = n, factorial(n)
4235 elif not 0 <= r < n:
4236 raise ValueError
4237 else:
4238 c = perm(n, r)
4239 assert c > 0 # factorial(n)>0, and r<n so perm(n,r) is never zero
4240
4241 if index < 0:
4242 index += c
4243
4244 if not 0 <= index < c:
4245 raise IndexError
4246
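    # Write the scaled index in the factorial number system; each digit
    # selects one of the remaining pool elements (popped at the end).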
4247 result = [0] * r
4248 q = index * factorial(n) // c if r < n else index
4249 for d in range(1, n + 1):
4250 q, i = divmod(q, d)
4251 if 0 <= n - d < r:
4252 result[n - d] = i
4253 if q == 0:
4254 break
4255
4256 return tuple(map(pool.pop, result))
4257
4258
4259def nth_combination_with_replacement(iterable, r, index):
4260 """Equivalent to
4261 ``list(combinations_with_replacement(iterable, r))[index]``.
4262
    The subsequences with repetition of *iterable* that are of length *r* can
4265 be ordered lexicographically. :func:`nth_combination_with_replacement`
4266 computes the subsequence at sort position *index* directly, without
4267 computing the previous subsequences with replacement.
4268
4269 >>> nth_combination_with_replacement(range(5), 3, 5)
4270 (0, 1, 1)
4271
    ``ValueError`` will be raised if *r* is negative or greater than the
    length of *iterable*.
4274 ``IndexError`` will be raised if the given *index* is invalid.
4275 """
4276 pool = tuple(iterable)
4277 n = len(pool)
4278 if (r < 0) or (r > n):
4279 raise ValueError
4280
4281 c = comb(n + r - 1, r)
4282
4283 if index < 0:
4284 index += c
4285
4286 if (index < 0) or (index >= c):
4287 raise IndexError
4288
4289 result = []
4290 i = 0
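    # For each output position, skip past the blocks of combinations that
    # begin with a smaller pool element until the block containing *index*
    # is found, then emit that element.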
4291 while r:
4292 r -= 1
4293 while n >= 0:
4294 num_combs = comb(n + r - 1, r)
4295 if index < num_combs:
4296 break
4297 n -= 1
4298 i += 1
4299 index -= num_combs
4300 result.append(pool[i])
4301
4302 return tuple(result)
4303
4304
4305def value_chain(*args):
4306 """Yield all arguments passed to the function in the same order in which
4307 they were passed. If an argument itself is iterable then iterate over its
4308 values.
4309
4310 >>> list(value_chain(1, 2, 3, [4, 5, 6]))
4311 [1, 2, 3, 4, 5, 6]
4312
4313 Binary and text strings are not considered iterable and are emitted
4314 as-is:
4315
4316 >>> list(value_chain('12', '34', ['56', '78']))
4317 ['12', '34', '56', '78']
4318
    Prepend or append a single element to an iterable:
4320
4321 >>> list(value_chain(1, [2, 3, 4, 5, 6]))
4322 [1, 2, 3, 4, 5, 6]
4323 >>> list(value_chain([1, 2, 3, 4, 5], 6))
4324 [1, 2, 3, 4, 5, 6]
4325
4326 Multiple levels of nesting are not flattened.
4327
4328 """
4329 for value in args:
4330 if isinstance(value, (str, bytes)):
4331 yield value
4332 continue
4333 try:
4334 yield from value
4335 except TypeError:
4336 yield value
4337
4338
4339def product_index(element, *args):
4340 """Equivalent to ``list(product(*args)).index(element)``
4341
4342 The products of *args* can be ordered lexicographically.
4343 :func:`product_index` computes the first index of *element* without
4344 computing the previous products.
4345
4346 >>> product_index([8, 2], range(10), range(5))
4347 42
4348
4349 ``ValueError`` will be raised if the given *element* isn't in the product
4350 of *args*.
4351 """
4352 index = 0
4353
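    # Horner's rule in the mixed-radix system defined by the pool sizes.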
4354 for x, pool in zip_longest(element, args, fillvalue=_marker):
4355 if x is _marker or pool is _marker:
4356 raise ValueError('element is not a product of args')
4357
4358 pool = tuple(pool)
4359 index = index * len(pool) + pool.index(x)
4360
4361 return index
4362
4363
4364def combination_index(element, iterable):
4365 """Equivalent to ``list(combinations(iterable, r)).index(element)``
4366
4367 The subsequences of *iterable* that are of length *r* can be ordered
4368 lexicographically. :func:`combination_index` computes the index of the
4369 first *element*, without computing the previous combinations.
4370
4371 >>> combination_index('adf', 'abcdefg')
4372 10
4373
4374 ``ValueError`` will be raised if the given *element* isn't one of the
4375 combinations of *iterable*.
4376 """
4377 element = enumerate(element)
4378 k, y = next(element, (None, None))
4379 if k is None:
4380 return 0
4381
4382 indexes = []
4383 pool = enumerate(iterable)
4384 for n, x in pool:
4385 if x == y:
4386 indexes.append(n)
4387 tmp, y = next(element, (None, None))
4388 if tmp is None:
4389 break
4390 else:
4391 k = tmp
4392 else:
4393 raise ValueError('element is not a combination of iterable')
4394
4395 n, _ = last(pool, default=(n, None))
4396
    # Rank via the combinatorial number system: count the combinations that
    # sort after *element* and subtract that count from the total.
    index = 1
4399 for i, j in enumerate(reversed(indexes), start=1):
4400 j = n - j
4401 if i <= j:
4402 index += comb(j, i)
4403
4404 return comb(n + 1, k + 1) - index
4405
4406
4407def combination_with_replacement_index(element, iterable):
4408 """Equivalent to
4409 ``list(combinations_with_replacement(iterable, r)).index(element)``
4410
4411 The subsequences with repetition of *iterable* that are of length *r* can
4412 be ordered lexicographically. :func:`combination_with_replacement_index`
4413 computes the index of the first *element*, without computing the previous
4414 combinations with replacement.
4415
4416 >>> combination_with_replacement_index('adf', 'abcdefg')
4417 20
4418
4419 ``ValueError`` will be raised if the given *element* isn't one of the
4420 combinations with replacement of *iterable*.
4421 """
4422 element = tuple(element)
4423 l = len(element)
4424 element = enumerate(element)
4425
4426 k, y = next(element, (None, None))
4427 if k is None:
4428 return 0
4429
4430 indexes = []
4431 pool = tuple(iterable)
4432 for n, x in enumerate(pool):
4433 while x == y:
4434 indexes.append(n)
4435 tmp, y = next(element, (None, None))
4436 if tmp is None:
4437 break
4438 else:
4439 k = tmp
4440 if y is None:
4441 break
4442 else:
4443 raise ValueError(
4444 'element is not a combination with replacement of iterable'
4445 )
4446
4447 n = len(pool)
4448 occupations = [0] * n
4449 for p in indexes:
4450 occupations[p] += 1
4451
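    # Convert the occupation counts of the chosen elements into a
    # lexicographic rank.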
4452 index = 0
4453 cumulative_sum = 0
4454 for k in range(1, n):
4455 cumulative_sum += occupations[k - 1]
4456 j = l + n - 1 - k - cumulative_sum
4457 i = n - k
4458 if i <= j:
4459 index += comb(j, i)
4460
4461 return index
4462
4463
4464def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)``
4466
4467 The subsequences of *iterable* that are of length *r* where order is
4468 important can be ordered lexicographically. :func:`permutation_index`
4469 computes the index of the first *element* directly, without computing
4470 the previous permutations.
4471
4472 >>> permutation_index([1, 3, 2], range(5))
4473 19
4474
4475 ``ValueError`` will be raised if the given *element* isn't one of the
4476 permutations of *iterable*.
4477 """
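    # Horner's rule in the falling-factorial base: each element's rank
    # among the not-yet-used pool items becomes the next digit.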
4478 index = 0
4479 pool = list(iterable)
4480 for i, x in zip(range(len(pool), -1, -1), element):
4481 r = pool.index(x)
4482 index = index * i + r
4483 del pool[r]
4484
4485 return index
4486
4487
4488class countable:
4489 """Wrap *iterable* and keep a count of how many items have been consumed.
4490
4491 The ``items_seen`` attribute starts at ``0`` and increments as the iterable
4492 is consumed:
4493
4494 >>> iterable = map(str, range(10))
4495 >>> it = countable(iterable)
4496 >>> it.items_seen
4497 0
4498 >>> next(it), next(it)
4499 ('0', '1')
4500 >>> list(it)
4501 ['2', '3', '4', '5', '6', '7', '8', '9']
4502 >>> it.items_seen
4503 10
4504 """
4505
4506 def __init__(self, iterable):
4507 self._iterator = iter(iterable)
4508 self.items_seen = 0
4509
4510 def __iter__(self):
4511 return self
4512
4513 def __next__(self):
4514 item = next(self._iterator)
4515 self.items_seen += 1
4516
4517 return item
4518
4519
4520def chunked_even(iterable, n):
4521 """Break *iterable* into lists of approximately length *n*.
    Items are distributed such that the lengths of the lists differ by at
    most 1 item.
4524
4525 >>> iterable = [1, 2, 3, 4, 5, 6, 7]
4526 >>> n = 3
4527 >>> list(chunked_even(iterable, n)) # List lengths: 3, 2, 2
4528 [[1, 2, 3], [4, 5], [6, 7]]
4529 >>> list(chunked(iterable, n)) # List lengths: 3, 3, 1
4530 [[1, 2, 3], [4, 5, 6], [7]]
4531
4532 """
4533 iterator = iter(iterable)
4534
4535 # Initialize a buffer to process the chunks while keeping
4536 # some back to fill any underfilled chunks
4537 min_buffer = (n - 1) * (n - 2)
4538 buffer = list(islice(iterator, min_buffer))
4539
4540 # Append items until we have a completed chunk
    for _ in islice(map(buffer.append, iterator), n - 1, None, n):
4542 yield buffer[:n]
4543 del buffer[:n]
4544
    # Check if any chunks need additional processing
4546 if not buffer:
4547 return
4548 length = len(buffer)
4549
4550 # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
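    # Choose the number of lists first, then the full size; this keeps
    # full_size <= n and all chunk lengths within 1 of each other.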
4551 q, r = divmod(length, n)
4552 num_lists = q + (1 if r > 0 else 0)
4553 q, r = divmod(length, num_lists)
4554 full_size = q + (1 if r > 0 else 0)
4555 partial_size = full_size - 1
4556 num_full = length - partial_size * num_lists
4557
4558 # Yield chunks of full size
4559 partial_start_idx = num_full * full_size
4560 if full_size > 0:
4561 for i in range(0, partial_start_idx, full_size):
4562 yield buffer[i : i + full_size]
4563
4564 # Yield chunks of partial size
4565 if partial_size > 0:
4566 for i in range(partial_start_idx, length, partial_size):
4567 yield buffer[i : i + partial_size]
4568
4569
4570def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
4571 """A version of :func:`zip` that "broadcasts" any scalar
4572 (i.e., non-iterable) items into output tuples.
4573
4574 >>> iterable_1 = [1, 2, 3]
4575 >>> iterable_2 = ['a', 'b', 'c']
4576 >>> scalar = '_'
4577 >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
4578 [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]
4579
4580 The *scalar_types* keyword argument determines what types are considered
4581 scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
4582 treat strings and byte strings as iterable:
4583
4584 >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
4585 [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]
4586
4587 If the *strict* keyword argument is ``True``, then
4588 ``UnequalIterablesError`` will be raised if any of the iterables have
4589 different lengths.
4590 """
4591
4592 def is_scalar(obj):
4593 if scalar_types and isinstance(obj, scalar_types):
4594 return True
4595 try:
4596 iter(obj)
4597 except TypeError:
4598 return True
4599 else:
4600 return False
4601
4602 size = len(objects)
4603 if not size:
4604 return
4605
4606 new_item = [None] * size
4607 iterables, iterable_positions = [], []
4608 for i, obj in enumerate(objects):
4609 if is_scalar(obj):
4610 new_item[i] = obj
4611 else:
4612 iterables.append(iter(obj))
4613 iterable_positions.append(i)
4614
4615 if not iterables:
4616 yield tuple(objects)
4617 return
4618
4619 zipper = _zip_equal if strict else zip
4620 for item in zipper(*iterables):
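        # Assign each iterable's value into its position via the
        # subscripted loop target; the loop body has nothing left to do.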
4621 for i, new_item[i] in zip(iterable_positions, item):
4622 pass
4623 yield tuple(new_item)
4624
4625
4626def unique_in_window(iterable, n, key=None):
4627 """Yield the items from *iterable* that haven't been seen recently.
4628 *n* is the size of the lookback window.
4629
4630 >>> iterable = [0, 1, 0, 2, 3, 0]
4631 >>> n = 3
4632 >>> list(unique_in_window(iterable, n))
4633 [0, 1, 2, 3, 0]
4634
4635 The *key* function, if provided, will be used to determine uniqueness:
4636
4637 >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
4638 ['a', 'b', 'c', 'd', 'a']
4639
4640 The items in *iterable* must be hashable.
4641
4642 """
4643 if n <= 0:
4644 raise ValueError('n must be greater than 0')
4645
4646 window = deque(maxlen=n)
4647 counts = defaultdict(int)
4648 use_key = key is not None
4649
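    # counts tracks how many copies of each key are inside the window, so
    # an item is yielded again only after every copy of its key has slid
    # out of the window.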
4650 for item in iterable:
4651 if len(window) == n:
4652 to_discard = window[0]
4653 if counts[to_discard] == 1:
4654 del counts[to_discard]
4655 else:
4656 counts[to_discard] -= 1
4657
4658 k = key(item) if use_key else item
4659 if k not in counts:
4660 yield item
4661 counts[k] += 1
4662 window.append(k)
4663
4664
4665def duplicates_everseen(iterable, key=None):
4666 """Yield duplicate elements after their first appearance.
4667
4668 >>> list(duplicates_everseen('mississippi'))
4669 ['s', 'i', 's', 's', 'i', 'p', 'i']
4670 >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
4671 ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']
4672
4673 This function is analogous to :func:`unique_everseen` and is subject to
4674 the same performance considerations.
4675
4676 """
4677 seen_set = set()
4678 seen_list = []
4679 use_key = key is not None
4680
4681 for element in iterable:
4682 k = key(element) if use_key else element
4683 try:
4684 if k not in seen_set:
4685 seen_set.add(k)
4686 else:
4687 yield element
4688 except TypeError:
4689 if k not in seen_list:
4690 seen_list.append(k)
4691 else:
4692 yield element
4693
4694
4695def duplicates_justseen(iterable, key=None):
4696 """Yields serially-duplicate elements after their first appearance.
4697
4698 >>> list(duplicates_justseen('mississippi'))
4699 ['s', 's', 'p']
4700 >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
4701 ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']
4702
4703 This function is analogous to :func:`unique_justseen`.
4704
4705 """
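    # For each group, the inner `for _ in g` consumes the first (unique)
    # item and hands the partially-consumed group to flatten(), which
    # yields the remaining duplicates before the next group is requested.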
4706 return flatten(g for _, g in groupby(iterable, key) for _ in g)
4707
4708
4709def classify_unique(iterable, key=None):
4710 """Classify each element in terms of its uniqueness.
4711
4712 For each element in the input iterable, return a 3-tuple consisting of:
4713
4714 1. The element itself
4715 2. ``False`` if the element is equal to the one preceding it in the input,
4716 ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
4717 3. ``False`` if this element has been seen anywhere in the input before,
4718 ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)
4719
4720 >>> list(classify_unique('otto')) # doctest: +NORMALIZE_WHITESPACE
4721 [('o', True, True),
4722 ('t', True, True),
4723 ('t', False, False),
4724 ('o', True, False)]
4725
4726 This function is analogous to :func:`unique_everseen` and is subject to
4727 the same performance considerations.
4728
4729 """
4730 seen_set = set()
4731 seen_list = []
4732 use_key = key is not None
4733 previous = None
4734
4735 for i, element in enumerate(iterable):
4736 k = key(element) if use_key else element
4737 is_unique_justseen = not i or previous != k
4738 previous = k
4739 is_unique_everseen = False
4740 try:
4741 if k not in seen_set:
4742 seen_set.add(k)
4743 is_unique_everseen = True
4744 except TypeError:
4745 if k not in seen_list:
4746 seen_list.append(k)
4747 is_unique_everseen = True
4748 yield element, is_unique_justseen, is_unique_everseen
4749
4750
4751def minmax(iterable_or_value, *others, key=None, default=_marker):
4752 """Returns both the smallest and largest items from an iterable
4753 or from two or more arguments.
4754
4755 >>> minmax([3, 1, 5])
4756 (1, 5)
4757
4758 >>> minmax(4, 2, 6)
4759 (2, 6)
4760
4761 If a *key* function is provided, it will be used to transform the input
4762 items for comparison.
4763
4764 >>> minmax([5, 30], key=str) # '30' sorts before '5'
4765 (30, 5)
4766
4767 If a *default* value is provided, it will be returned if there are no
4768 input items.
4769
4770 >>> minmax([], default=(0, 0))
4771 (0, 0)
4772
4773 Otherwise ``ValueError`` is raised.
4774
4775 This function makes a single pass over the input elements and takes care to
4776 minimize the number of comparisons made during processing.
4777
    Note that unlike the built-in ``max`` function, which always returns the
    first item with the maximum value, this function may return another item
    when there are ties.
4781
4782 This function is based on the
4783 `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
4784 Raymond Hettinger.
4785 """
4786 iterable = (iterable_or_value, *others) if others else iterable_or_value
4787
4788 it = iter(iterable)
4789
4790 try:
4791 lo = hi = next(it)
4792 except StopIteration as exc:
4793 if default is _marker:
4794 raise ValueError(
4795 '`minmax()` argument is an empty iterable. '
4796 'Provide a `default` value to suppress this error.'
4797 ) from exc
4798 return default
4799
4800 # Different branches depending on the presence of key. This saves a lot
4801 # of unimportant copies which would slow the "key=None" branch
4802 # significantly down.
4803 if key is None:
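        # Process items in pairs: ordering the pair costs one comparison,
        # after which only the smaller needs testing against `lo` and the
        # larger against `hi` (about 1.5 comparisons per item).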
4804 for x, y in zip_longest(it, it, fillvalue=lo):
4805 if y < x:
4806 x, y = y, x
4807 if x < lo:
4808 lo = x
4809 if hi < y:
4810 hi = y
4811
4812 else:
4813 lo_key = hi_key = key(lo)
4814
4815 for x, y in zip_longest(it, it, fillvalue=lo):
4816 x_key, y_key = key(x), key(y)
4817
4818 if y_key < x_key:
4819 x, y, x_key, y_key = y, x, y_key, x_key
4820 if x_key < lo_key:
4821 lo, lo_key = x, x_key
4822 if hi_key < y_key:
4823 hi, hi_key = y, y_key
4824
4825 return lo, hi
4826
4827
4828def constrained_batches(
4829 iterable, max_size, max_count=None, get_len=len, strict=True
4830):
4831 """Yield batches of items from *iterable* with a combined size limited by
4832 *max_size*.
4833
4834 >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
4835 >>> list(constrained_batches(iterable, 10))
4836 [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]
4837
4838 If a *max_count* is supplied, the number of items per batch is also
4839 limited:
4840
4841 >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
4842 >>> list(constrained_batches(iterable, 10, max_count = 2))
4843 [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]
4844
4845 If a *get_len* function is supplied, use that instead of :func:`len` to
4846 determine item size.
4847
4848 If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
4849 than *max_size*. Otherwise, allow single items to exceed *max_size*.
4850 """
4851 if max_size <= 0:
4852 raise ValueError('maximum size must be greater than zero')
4853
4854 batch = []
4855 batch_size = 0
4856 batch_count = 0
4857 for item in iterable:
4858 item_len = get_len(item)
4859 if strict and item_len > max_size:
4860 raise ValueError('item size exceeds maximum size')
4861
4862 reached_count = batch_count == max_count
4863 reached_size = item_len + batch_size > max_size
4864 if batch_count and (reached_size or reached_count):
4865 yield tuple(batch)
4866 batch.clear()
4867 batch_size = 0
4868 batch_count = 0
4869
4870 batch.append(item)
4871 batch_size += item_len
4872 batch_count += 1
4873
4874 if batch:
4875 yield tuple(batch)
4876
4877
4878def gray_product(*iterables):
4879 """Like :func:`itertools.product`, but return tuples in an order such
4880 that only one element in the generated tuple changes from one iteration
4881 to the next.
4882
4883 >>> list(gray_product('AB','CD'))
4884 [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]
4885
4886 This function consumes all of the input iterables before producing output.
4887 If any of the input iterables have fewer than two items, ``ValueError``
4888 is raised.
4889
4890 For information on the algorithm, see
4891 `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
4892 of Donald Knuth's *The Art of Computer Programming*.
4893 """
4894 all_iterables = tuple(tuple(x) for x in iterables)
4895 iterable_count = len(all_iterables)
4896 for iterable in all_iterables:
4897 if len(iterable) < 2:
4898 raise ValueError("each iterable must have two or more items")
4899
4900 # This is based on "Algorithm H" from section 7.2.1.1, page 20.
4901 # a holds the indexes of the source iterables for the n-tuple to be yielded
4902 # f is the array of "focus pointers"
4903 # o is the array of "directions"
4904 a = [0] * iterable_count
4905 f = list(range(iterable_count + 1))
4906 o = [1] * iterable_count
4907 while True:
4908 yield tuple(all_iterables[i][a[i]] for i in range(iterable_count))
4909 j = f[0]
4910 f[0] = 0
4911 if j == iterable_count:
4912 break
4913 a[j] = a[j] + o[j]
4914 if a[j] == 0 or a[j] == len(all_iterables[j]) - 1:
4915 o[j] = -o[j]
4916 f[j] = f[j + 1]
4917 f[j + 1] = j + 1
4918
4919
4920def partial_product(*iterables):
4921 """Yields tuples containing one item from each iterator, with subsequent
4922 tuples changing a single item at a time by advancing each iterator until it
4923 is exhausted. This sequence guarantees every value in each iterable is
4924 output at least once without generating all possible combinations.
4925
4926 This may be useful, for example, when testing an expensive function.
4927
4928 >>> list(partial_product('AB', 'C', 'DEF'))
4929 [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]
4930 """
4931
4932 iterators = list(map(iter, iterables))
4933
4934 try:
4935 prod = [next(it) for it in iterators]
4936 except StopIteration:
4937 return
4938 yield tuple(prod)
4939
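    # Exhaust one iterator at a time; `for prod[i] in it` rebinds slot i
    # in place, so each yielded tuple differs from the last in a single
    # position.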
4940 for i, it in enumerate(iterators):
4941 for prod[i] in it:
4942 yield tuple(prod)
4943
4944
4945def takewhile_inclusive(predicate, iterable):
4946 """A variant of :func:`takewhile` that yields one additional element.
4947
4948 >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
4949 [1, 4, 6]
4950
4951 :func:`takewhile` would return ``[1, 4]``.
4952 """
4953 for x in iterable:
4954 yield x
4955 if not predicate(x):
4956 break
4957
4958
4959def outer_product(func, xs, ys, *args, **kwargs):
4960 """A generalized outer product that applies a binary function to all
4961 pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
4962 columns.
4963 Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.
4964
4965 Multiplication table:
4966
4967 >>> list(outer_product(mul, range(1, 4), range(1, 6)))
4968 [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]
4969
4970 Cross tabulation:
4971
4972 >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
4973 >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
4974 >>> pair_counts = Counter(zip(xs, ys))
4975 >>> count_rows = lambda x, y: pair_counts[x, y]
4976 >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
4977 [(2, 3, 0), (1, 0, 4)]
4978
4979 Usage with ``*args`` and ``**kwargs``:
4980
4981 >>> animals = ['cat', 'wolf', 'mouse']
4982 >>> list(outer_product(min, animals, animals, key=len))
4983 [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
4984 """
4985 ys = tuple(ys)
4986 return batched(
4987 starmap(lambda x, y: func(x, y, *args, **kwargs), product(xs, ys)),
4988 n=len(ys),
4989 )
4990
4991
4992def iter_suppress(iterable, *exceptions):
4993 """Yield each of the items from *iterable*. If the iteration raises one of
4994 the specified *exceptions*, that exception will be suppressed and iteration
4995 will stop.
4996
4997 >>> from itertools import chain
4998 >>> def breaks_at_five(x):
4999 ... while True:
5000 ... if x >= 5:
5001 ... raise RuntimeError
5002 ... yield x
5003 ... x += 1
5004 >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
5005 >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
5006 >>> list(chain(it_1, it_2))
5007 [1, 2, 3, 4, 2, 3, 4]
5008 """
5009 try:
5010 yield from iterable
5011 except exceptions:
5012 return
5013
5014
5015def filter_map(func, iterable):
5016 """Apply *func* to every element of *iterable*, yielding only those which
5017 are not ``None``.
5018
5019 >>> elems = ['1', 'a', '2', 'b', '3']
5020 >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
5021 [1, 2, 3]
5022 """
5023 for x in iterable:
5024 y = func(x)
5025 if y is not None:
5026 yield y
5027
5028
5029def powerset_of_sets(iterable):
5030 """Yields all possible subsets of the iterable.
5031
5032 >>> list(powerset_of_sets([1, 2, 3])) # doctest: +SKIP
5033 [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
5034 >>> list(powerset_of_sets([1, 1, 0])) # doctest: +SKIP
5035 [set(), {1}, {0}, {0, 1}]
5036
5037 :func:`powerset_of_sets` takes care to minimize the number
5038 of hash operations performed.
5039 """
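    # zip(iterable) wraps each item in a 1-tuple so that frozenset() builds
    # a singleton set from it; dict.fromkeys() deduplicates while preserving
    # order, ensuring each distinct item is hashed only once.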
5040 sets = tuple(dict.fromkeys(map(frozenset, zip(iterable))))
5041 return chain.from_iterable(
5042 starmap(set().union, combinations(sets, r))
5043 for r in range(len(sets) + 1)
5044 )
5045
5046
5047def join_mappings(**field_to_map):
5048 """
5049 Joins multiple mappings together using their common keys.
5050
5051 >>> user_scores = {'elliot': 50, 'claris': 60}
5052 >>> user_times = {'elliot': 30, 'claris': 40}
5053 >>> join_mappings(score=user_scores, time=user_times)
5054 {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
5055 """
5056 ret = defaultdict(dict)
5057
5058 for field_name, mapping in field_to_map.items():
5059 for key, value in mapping.items():
5060 ret[key][field_name] = value
5061
5062 return dict(ret)
5063
5064
5065def _complex_sumprod(v1, v2):
5066 """High precision sumprod() for complex numbers.
5067 Used by :func:`dft` and :func:`idft`.
5068 """
5069
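    # For (a+bj) * (c+dj), the real part is sum(ac) - sum(bd) and the
    # imaginary part is sum(ad) + sum(bc); each is computed as a single
    # extended-precision sumprod over chained components.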
5070 real = attrgetter('real')
5071 imag = attrgetter('imag')
5072 r1 = chain(map(real, v1), map(neg, map(imag, v1)))
5073 r2 = chain(map(real, v2), map(imag, v2))
5074 i1 = chain(map(real, v1), map(imag, v1))
5075 i2 = chain(map(imag, v2), map(real, v2))
5076 return complex(_fsumprod(r1, r2), _fsumprod(i1, i2))
5077
5078
5079def dft(xarr):
5080 """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
5081 Yields the components of the corresponding transformed output vector.
5082
5083 >>> import cmath
5084 >>> xarr = [1, 2-1j, -1j, -1+2j] # time domain
5085 >>> Xarr = [2, -2-2j, -2j, 4+4j] # frequency domain
5086 >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
5087 >>> all(map(cmath.isclose, dft(xarr), Xarr))
5088 True
5089
5090 Inputs are restricted to numeric types that can add and multiply
5091 with a complex number. This includes int, float, complex, and
5092 Fraction, but excludes Decimal.
5093
5094 See :func:`idft` for the inverse Discrete Fourier Transform.
5095 """
5096 N = len(xarr)
5097 roots_of_unity = [e ** (n / N * tau * -1j) for n in range(N)]
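    # Output k is the sumprod of the input with row k of the DFT matrix,
    # whose entries are powers (k * n mod N) of the primitive root above.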
5098 for k in range(N):
5099 coeffs = [roots_of_unity[k * n % N] for n in range(N)]
5100 yield _complex_sumprod(xarr, coeffs)
5101
5102
5103def idft(Xarr):
5104 """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
5105 complex numbers. Yields the components of the corresponding
5106 inverse-transformed output vector.
5107
5108 >>> import cmath
5109 >>> xarr = [1, 2-1j, -1j, -1+2j] # time domain
5110 >>> Xarr = [2, -2-2j, -2j, 4+4j] # frequency domain
5111 >>> all(map(cmath.isclose, idft(Xarr), xarr))
5112 True
5113
5114 Inputs are restricted to numeric types that can add and multiply
5115 with a complex number. This includes int, float, complex, and
5116 Fraction, but excludes Decimal.
5117
5118 See :func:`dft` for the Discrete Fourier Transform.
5119 """
5120 N = len(Xarr)
5121 roots_of_unity = [e ** (n / N * tau * 1j) for n in range(N)]
5122 for k in range(N):
5123 coeffs = [roots_of_unity[k * n % N] for n in range(N)]
5124 yield _complex_sumprod(Xarr, coeffs) / N
5125
5126
5127def doublestarmap(func, iterable):
5128 """Apply *func* to every item of *iterable* by dictionary unpacking
5129 the item into *func*.
5130
5131 The difference between :func:`itertools.starmap` and :func:`doublestarmap`
5132 parallels the distinction between ``func(*a)`` and ``func(**a)``.
5133
5134 >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
5135 >>> list(doublestarmap(lambda a, b: a + b, iterable))
5136 [3, 100]
5137
5138 ``TypeError`` will be raised if *func*'s signature doesn't match the
5139 mapping contained in *iterable* or if *iterable* does not contain mappings.
5140 """
5141 for item in iterable:
5142 yield func(**item)
5143
5144
5145def _nth_prime_bounds(n):
5146 """Bounds for the nth prime (counting from 1): lb < p_n < ub."""
5147 # At and above 688,383, the lb/ub spread is under 0.003 * p_n.
5148
5149 if n < 1:
5150 raise ValueError
5151
5152 if n < 6:
5153 return (n, 2.25 * n)
5154
5155 # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
5156 upper_bound = n * log(n * log(n))
5157 lower_bound = upper_bound - n
5158 if n >= 688_383:
5159 upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))
5160
5161 return lower_bound, upper_bound
5162
5163
5164def nth_prime(n, *, approximate=False):
5165 """Return the nth prime (counting from 0).
5166
5167 >>> nth_prime(0)
5168 2
5169 >>> nth_prime(100)
5170 547
5171
    If *approximate* is set to ``True``, a prime close to the nth prime is
    returned instead. The estimate is much faster to compute than the
    exact result.
5175
5176 >>> nth_prime(200_000_000, approximate=True) # Exact result is 4222234763
5177 4217820427
5178
5179 """
5180 lb, ub = _nth_prime_bounds(n + 1)
5181
5182 if not approximate or n <= 1_000_000:
5183 return nth(sieve(ceil(ub)), n)
5184
    # Search upward from the midpoint of the bounds, testing odd
    # candidates, and return the first prime found
5186 odd = floor((lb + ub) / 2) | 1
5187 return first_true(count(odd, step=2), pred=is_prime)
5188
5189
5190def argmin(iterable, *, key=None):
5191 """
5192 Index of the first occurrence of a minimum value in an iterable.
5193
5194 >>> argmin('efghabcdijkl')
5195 4
5196 >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
5197 3
5198
5199 For example, look up a label corresponding to the position
5200 of a value that minimizes a cost function::
5201
5202 >>> def cost(x):
5203 ... "Days for a wound to heal given a subject's age."
5204 ... return x**2 - 20*x + 150
5205 ...
5206 >>> labels = ['homer', 'marge', 'bart', 'lisa', 'maggie']
5207 >>> ages = [ 35, 30, 10, 9, 1 ]
5208
5209 # Fastest healing family member
5210 >>> labels[argmin(ages, key=cost)]
5211 'bart'
5212
5213 # Age with fastest healing
5214 >>> min(ages, key=cost)
5215 10
5216
5217 """
5218 if key is not None:
5219 iterable = map(key, iterable)
5220 return min(enumerate(iterable), key=itemgetter(1))[0]
5221
5222
5223def argmax(iterable, *, key=None):
5224 """
5225 Index of the first occurrence of a maximum value in an iterable.
5226
5227 >>> argmax('abcdefghabcd')
5228 7
5229 >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
5230 3
5231
5232 For example, identify the best machine learning model::
5233
5234 >>> models = ['svm', 'random forest', 'knn', 'naïve bayes']
5235 >>> accuracy = [ 68, 61, 84, 72 ]
5236
5237 # Most accurate model
5238 >>> models[argmax(accuracy)]
5239 'knn'
5240
5241 # Best accuracy
5242 >>> max(accuracy)
5243 84
5244
5245 """
5246 if key is not None:
5247 iterable = map(key, iterable)
5248 return max(enumerate(iterable), key=itemgetter(1))[0]
5249
5250
5251def extract(iterable, indices):
5252 """Yield values at the specified indices.
5253
5254 Example:
5255
5256 >>> data = 'abcdefghijklmnopqrstuvwxyz'
5257 >>> list(extract(data, [7, 4, 11, 11, 14]))
5258 ['h', 'e', 'l', 'l', 'o']
5259
5260 The *iterable* is consumed lazily and can be infinite.
5261 The *indices* are consumed immediately and must be finite.
5262
5263 Raises ``IndexError`` if an index lies beyond the iterable.
5264 Raises ``ValueError`` for negative indices.
5265 """
5266
5267 iterator = iter(iterable)
5268 index_and_position = sorted(zip(indices, count()))
5269
5270 if index_and_position and index_and_position[0][0] < 0:
5271 raise ValueError('Indices must be non-negative')
5272
5273 buffer = {}
5274 iterator_position = -1
5275 next_to_emit = 0
5276
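    # Visit the requested indices in sorted order so the iterable is
    # consumed in a single pass, buffering values until they can be
    # emitted in the order they were requested.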
5277 for index, order in index_and_position:
5278 advance = index - iterator_position
5279 if advance:
5280 try:
5281 value = next(islice(iterator, advance - 1, None))
5282 except StopIteration:
5283 raise IndexError(index)
5284 iterator_position = index
5285
5286 buffer[order] = value
5287
5288 while next_to_emit in buffer:
5289 yield buffer.pop(next_to_emit)
5290 next_to_emit += 1