1import math
2
3from collections import Counter, defaultdict, deque
4from collections.abc import Sequence
5from contextlib import suppress
6from functools import cached_property, partial, reduce, wraps
7from heapq import heapify, heapreplace
8from itertools import (
9 chain,
10 combinations,
11 compress,
12 count,
13 cycle,
14 dropwhile,
15 groupby,
16 islice,
17 permutations,
18 repeat,
19 starmap,
20 takewhile,
21 tee,
22 zip_longest,
23 product,
24)
25from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
26from math import ceil
27from queue import Empty, Queue
28from random import random, randrange, shuffle, uniform
29from operator import (
30 attrgetter,
31 is_not,
32 itemgetter,
33 lt,
34 mul,
35 neg,
36 sub,
37 gt,
38)
39from sys import maxsize
40from time import monotonic
41
42from .recipes import (
43 _marker,
44 consume,
45 first_true,
46 flatten,
47 is_prime,
48 nth,
49 powerset,
50 sieve,
51 take,
52 unique_everseen,
53 all_equal,
54 batched,
55)
56
# Names exported by ``from <module> import *``; this list declares the
# public API of the module.
__all__ = [
    'AbortThread',
    'SequenceView',
    'adjacent',
    'all_unique',
    'always_iterable',
    'always_reversible',
    'argmax',
    'argmin',
    'bucket',
    'callback_iter',
    'chunked',
    'chunked_even',
    'circular_shifts',
    'collapse',
    'combination_index',
    'combination_with_replacement_index',
    'consecutive_groups',
    'constrained_batches',
    'consumer',
    'count_cycle',
    'countable',
    'derangements',
    'dft',
    'difference',
    'distinct_combinations',
    'distinct_permutations',
    'distribute',
    'divide',
    'doublestarmap',
    'duplicates_everseen',
    'duplicates_justseen',
    'classify_unique',
    'exactly_n',
    'extract',
    'filter_except',
    'filter_map',
    'first',
    'gray_product',
    'groupby_transform',
    'ichunked',
    'iequals',
    'idft',
    'ilen',
    'interleave',
    'interleave_evenly',
    'interleave_longest',
    'interleave_randomly',
    'intersperse',
    'is_sorted',
    'islice_extended',
    'iterate',
    'iter_suppress',
    'join_mappings',
    'last',
    'locate',
    'longest_common_prefix',
    'lstrip',
    'make_decorator',
    'map_except',
    'map_if',
    'map_reduce',
    'mark_ends',
    'minmax',
    'nth_or_last',
    'nth_permutation',
    'nth_prime',
    'nth_product',
    'nth_combination_with_replacement',
    'numeric_range',
    'one',
    'only',
    'outer_product',
    'padded',
    'partial_product',
    'partitions',
    'peekable',
    'permutation_index',
    'powerset_of_sets',
    'product_index',
    'raise_',
    'repeat_each',
    'repeat_last',
    'replace',
    'rlocate',
    'rstrip',
    'run_length',
    'sample',
    'seekable',
    'set_partitions',
    'side_effect',
    'sliced',
    'sort_together',
    'split_after',
    'split_at',
    'split_before',
    'split_into',
    'split_when',
    'spy',
    'stagger',
    'strip',
    'strictly_n',
    'substrings',
    'substrings_indexes',
    'takewhile_inclusive',
    'time_limited',
    'unique_in_window',
    'unique_to_each',
    'unzip',
    'value_chain',
    'windowed',
    'windowed_complete',
    'with_iter',
    'zip_broadcast',
    'zip_offset',
]
173
# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _fsumprod

except ImportError:  # pragma: no cover
    # Pure-Python fallback, defined only when ``math.sumprod`` cannot be
    # imported.
    #
    # Extended precision algorithms from T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8).  Code: mul12()

    def dl_split(x: float):
        """Split a float into two half-precision components.

        Returns the pair ``(hi, lo)``, where ``lo`` is computed as
        ``x - hi``.
        """
        t = x * 134217729.0  # Veltkamp constant = 2.0 ** 27 + 1
        hi = t - (t - x)
        lo = x - hi
        return hi, lo

    def dl_mul(x, y):
        """Lossless multiplication.

        Returns the pair ``(z, zz)``: *z* is the leading part of the product
        and *zz* the correction term, following Dekker's ``mul12``.
        """
        xx_hi, xx_lo = dl_split(x)
        yy_hi, yy_lo = dl_split(y)
        p = xx_hi * yy_hi
        q = xx_hi * yy_lo + xx_lo * yy_hi
        z = p + q
        zz = p - z + q + xx_lo * yy_lo
        return z, zz

    def _fsumprod(p, q):
        """Return the sum of the pairwise products of *p* and *q*.

        Each product is expanded by ``dl_mul`` into its two components and
        all of the components are added together with ``fsum``.
        """
        return fsum(chain.from_iterable(map(dl_mul, p, q)))
203
204
def chunked(iterable, n, strict=False):
    """Split *iterable* into consecutive lists holding *n* items each:

        >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
        [[1, 2, 3], [4, 5, 6]]

    By default, when the number of items is not a multiple of *n*, the
    final list is simply shorter than the others:

        >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
        [[1, 2, 3], [4, 5, 6], [7, 8]]

    See the :func:`grouper` recipe if you would rather pad the final list
    with a fill-in value.

    With *strict* set to ``True`` a short final list is never yielded;
    ``ValueError`` is raised in its place.  Strict mode also rejects
    ``n=None`` with a ``ValueError``.

    """
    # Two-argument iter(): keep calling take(n, source) until it hands back
    # the sentinel [], i.e. until the source has nothing left to give.
    source = iter(iterable)
    chunks = iter(partial(take, n, source), [])
    if not strict:
        return chunks

    if n is None:
        raise ValueError('n must not be None when using strict mode.')

    def checked():
        # Pass full chunks through; refuse a chunk of any other length.
        for piece in chunks:
            if len(piece) != n:
                raise ValueError('iterable is not divisible by n.')
            yield piece

    return checked()
238
239
def first(iterable, default=_marker):
    """Return the first item of *iterable*, or *default* when *iterable*
    yields nothing.

        >>> first([0, 1, 2, 3])
        0
        >>> first([], 'some default')
        'some default'

    ``ValueError`` is raised when *iterable* is empty and no *default* was
    supplied.

    :func:`first` comes in handy when a generator produces values that are
    expensive to retrieve and any single one of them will do.  It is
    marginally shorter than ``next(iter(iterable), default)``.

    """
    # A private sentinel tells "no items" apart from any value the iterable
    # could legitimately produce.
    exhausted = object()
    head = next(iter(iterable), exhausted)
    if head is not exhausted:
        return head

    if default is _marker:
        raise ValueError(
            'first() was called on an empty iterable, '
            'and no default value was provided.'
        )
    return default
265
266
def last(iterable, default=_marker):
    """Return the final item of *iterable*, or *default* when *iterable*
    yields nothing.

        >>> last([0, 1, 2, 3])
        3
        >>> last([], 'some default')
        'some default'

    ``ValueError`` is raised when *iterable* is empty and no *default* was
    supplied.
    """
    try:
        if isinstance(iterable, Sequence):
            # Sequences can be indexed from the end directly.
            tail = iterable[-1]
        elif getattr(iterable, '__reversed__', None):
            # Work around https://bugs.python.org/issue38525
            tail = next(reversed(iterable))
        else:
            # Otherwise consume everything, keeping only the newest item.
            tail = deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        if default is not _marker:
            return default
        raise ValueError(
            'last() was called on an empty iterable, '
            'and no default value was provided.'
        )
    return tail
293
294
def nth_or_last(iterable, n, default=_marker):
    """Return the item at position *n* of *iterable*, falling back to its
    final item when it is shorter than that, or to *default* when it is
    empty.

        >>> nth_or_last([0, 1, 2, 3], 2)
        2
        >>> nth_or_last([0, 1], 2)
        1
        >>> nth_or_last([], 0, 'some default')
        'some default'

    ``ValueError`` is raised when *iterable* is empty and no *default* was
    supplied.
    """
    # The last of the first n + 1 items is item n, or the final item when
    # there are fewer than n + 1 of them.
    leading = islice(iterable, n + 1)
    return last(leading, default=default)
310
311
class peekable:
    """Iterator wrapper that supports lookahead, indexing and prepending.

    :meth:`peek` reports the value that the next call to :func:`next` will
    return, without advancing the iterator:

        >>> p = peekable(['a', 'b'])
        >>> p.peek()
        'a'
        >>> next(p)
        'a'

    Give :meth:`peek` a default to get that value back, rather than a
    ``StopIteration``, once the iterator has run dry:

        >>> p = peekable([])
        >>> p.peek('hi')
        'hi'

    :meth:`prepend` pushes items onto the front, so that they are returned
    before anything else:

        >>> p = peekable([1, 2, 3])
        >>> p.prepend(10, 11, 12)
        >>> next(p)
        10
        >>> p.peek()
        11
        >>> list(p)
        [11, 12, 1, 2, 3]

    Indexing looks ahead as well: index 0 is what :func:`next` would return,
    index 1 is the item after it, and so on.  Every item up to the requested
    index is cached.

        >>> p = peekable(['a', 'b', 'c', 'd'])
        >>> p[0]
        'a'
        >>> p[1]
        'b'
        >>> next(p)
        'a'

    Negative indexes work too, at a price: the whole remainder of the source
    iterator has to be cached, which may take a lot of memory.

    The truth value of a peekable tells whether any items are left:

        >>> p = peekable(['a', 'b'])
        >>> if p:  # peekable has items
        ...     list(p)
        ['a', 'b']
        >>> if not p:  # peekable is exhausted
        ...     list(p)
        []

    """

    def __init__(self, iterable):
        self._it = iter(iterable)
        # Items already pulled from the source (or prepended) that have not
        # been handed out yet; the leftmost one is returned next.
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        try:
            self.peek()
        except StopIteration:
            return False
        else:
            return True

    def peek(self, default=_marker):
        """Return the item that the next call to ``next()`` will return.

        When nothing is left, return ``default``; without a ``default``,
        raise ``StopIteration``.

        """
        if self._cache:
            return self._cache[0]

        try:
            upcoming = next(self._it)
        except StopIteration:
            if default is _marker:
                raise
            return default

        # Remember the item so that next() still returns it later.
        self._cache.append(upcoming)
        return upcoming

    def prepend(self, *items):
        """Queue up items to be returned ahead of everything else by
        ``next()`` or ``self.peek()``.  They come back in first in, first
        out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        Prepending can "resurrect" a peekable that has already raised
        ``StopIteration``.

            >>> p = peekable([])
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration
            >>> p.prepend(1)
            >>> next(p)
            1
            >>> next(p)
            Traceback (most recent call last):
              ...
            StopIteration

        """
        # Pushing the items on the left one at a time, last one first,
        # leaves them in their original order at the head of the cache.
        for item in reversed(items):
            self._cache.appendleft(item)

    def __next__(self):
        if not self._cache:
            return next(self._it)

        return self._cache.popleft()

    def _get_slice(self, index):
        # Fill in the defaults that depend on the direction of the slice
        step = index.step
        if step is None:
            step = 1

        if step > 0:
            fallback_start, fallback_stop = 0, maxsize
        elif step < 0:
            fallback_start, fallback_stop = -1, -maxsize - 1
        else:
            raise ValueError('slice step cannot be zero')

        start = fallback_start if (index.start is None) else index.start
        stop = fallback_stop if (index.stop is None) else index.stop

        if (start < 0) or (stop < 0):
            # A bound counted from the right can only be resolved once the
            # rest of the source has been cached.
            self._cache.extend(self._it)
        else:
            # Otherwise caching up to the rightmost bound is enough.
            wanted = min(max(start, stop) + 1, maxsize)
            shortfall = wanted - len(self._cache)
            if shortfall >= 0:
                self._cache.extend(islice(self._it, shortfall))

        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        if index < 0:
            # Counting from the end requires the whole remainder.
            self._cache.extend(self._it)
        else:
            shortfall = index + 1 - len(self._cache)
            if shortfall > 0:
                self._cache.extend(islice(self._it, shortfall))

        return self._cache[index]
475
476
def consumer(func):
    """Decorator for PEP-342-style "reverse iterators": the generator is
    advanced to its first yield point as soon as it is created, so there is
    no need to call ``next()`` on it by hand.

        >>> @consumer
        ... def tally():
        ...     i = 0
        ...     while True:
        ...         print('Thing number %s is %s.' % (i, (yield)))
        ...         i += 1
        ...
        >>> t = tally()
        >>> t.send('red')
        Thing number 0 is red.
        >>> t.send('fish')
        Thing number 1 is fish.

    Without the decorator, ``next(t)`` would have to be called before
    ``t.send()`` could be used.

    """

    @wraps(func)
    def primed(*args, **kwargs):
        generator = func(*args, **kwargs)
        # Run up to the first ``yield`` so that send() is accepted.
        next(generator)
        return generator

    return primed
507
508
def ilen(iterable):
    """Count the items in *iterable* and return that number.

    For example, there are 168 prime numbers below 1,000:

        >>> ilen(sieve(1000))
        168

    Equivalent to, but faster than::

        def ilen(iterable):
            count = 0
            for _ in iterable:
                count += 1
            return count

    The iterable is consumed completely, so handle with care.

    """
    # Same pipeline as the classic one-liner, merely spelled out: zip()
    # wraps every item in a 1-tuple, which is always truthy, so compress()
    # lets exactly one ``1`` through per item and sum() adds them up.
    ones = repeat(1)
    selectors = zip(iterable)
    return sum(compress(ones, selectors))
532
533
def iterate(func, start):
    """Yield ``start``, ``func(start)``, ``func(func(start))``, ...

    The resulting iterator is infinite unless *func* raises
    ``StopIteration``, which ends it quietly.  To impose a stopping
    condition, use :func:`take`, ``takewhile``, or
    :func:`takewhile_inclusive`.

        >>> take(10, iterate(lambda x: 2*x, 1))
        [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

        >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2
        >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10)))
        [10, 5, 16, 8, 4, 2, 1]

    """
    value = start
    try:
        while True:
            yield value
            value = func(value)
    except StopIteration:
        # Treated as the end of the sequence rather than as an error.
        return
552
553
def with_iter(context_manager):
    """Iterate over the object produced by entering *context_manager*,
    leaving the ``with`` block once it is exhausted.

    For example, this closes the file when the iterator is exhausted::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Any context manager whose ``with ... as`` target is an iterable is a
    candidate for ``with_iter``.

    """
    # Nothing is entered until the first item is requested, because this
    # function is a generator.
    with context_manager as items:
        yield from items
567
568
def one(iterable, too_short=None, too_long=None):
    """Return the only item of *iterable*.  An exception is raised when
    *iterable* is empty or holds more than one item.

    :func:`one` is a convenient way of making sure that an iterable contains
    exactly one item, for instance when fetching the result of a database
    query that is expected to return a single row.

    An empty *iterable* results in ``ValueError``.  Pass the *too_short*
    keyword to have a different exception raised:

        >>> it = []
        >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: too few items in iterable (expected 1)
        >>> too_short = IndexError('too few items')
        >>> one(it, too_short=too_short)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        IndexError: too few items

    Likewise, more than one item results in ``ValueError``, unless the
    *too_long* keyword names a different exception:

        >>> it = ['too', 'many']
        >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Expected exactly one item in iterable, but got 'too',
        'many', and perhaps more.
        >>> too_long = RuntimeError
        >>> one(it, too_long=too_long)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        RuntimeError

    Note that :func:`one` tries to advance *iterable* twice in order to make
    sure that there is only one item.  See :func:`spy` or :func:`peekable`
    for less destructive ways of checking the contents of an iterable.

    """
    # Sentinel marking "the iterator had nothing (more) to give".
    absent = object()
    iterator = iter(iterable)

    head = next(iterator, absent)
    if head is absent:
        raise too_short or ValueError('too few items in iterable (expected 1)')

    extra = next(iterator, absent)
    if extra is not absent:
        msg = (
            f'Expected exactly one item in iterable, but got {head!r}, '
            f'{extra!r}, and perhaps more.'
        )
        raise too_long or ValueError(msg)

    return head
623
624
def raise_(exception, *args):
    """Call *exception* with *args* and raise the result.

    Being a function, it allows raising from places where only an
    expression is permitted, such as the body of a ``lambda``.
    """
    error = exception(*args)
    raise error
627
628
def strictly_n(iterable, n, too_short=None, too_long=None):
    """Yield the items of *iterable* while checking that there are exactly
    *n* of them.  With fewer than *n* items, the function *too_short* is
    called with the actual number of items.  With more than *n* items, the
    function *too_long* is called with the number ``n + 1``.

        >>> iterable = ['a', 'b', 'c', 'd']
        >>> n = 4
        >>> list(strictly_n(iterable, n))
        ['a', 'b', 'c', 'd']

    Note that the check only takes place as the returned iterable is being
    consumed.

    Unless told otherwise, *too_short* and *too_long* are functions that
    raise ``ValueError``.

        >>> list(strictly_n('ab', 3))  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Too few items in iterable (got 2)

        >>> list(strictly_n('abc', 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Too many items in iterable (got at least 3)

    Functions that do something else can be supplied instead.
    *too_short* is called with the number of items in *iterable*.
    *too_long* is called with `n + 1`.

        >>> def too_short(item_count):
        ...     raise RuntimeError
        >>> it = strictly_n('abcd', 6, too_short=too_short)
        >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        RuntimeError

        >>> def too_long(item_count):
        ...     print('The boss is going to hear about this')
        >>> it = strictly_n('abcdef', 4, too_long=too_long)
        >>> list(it)
        The boss is going to hear about this
        ['a', 'b', 'c', 'd']

    """
    if too_short is None:

        def too_short(item_count):
            return raise_(
                ValueError,
                f'Too few items in iterable (got {item_count})',
            )

    if too_long is None:

        def too_long(item_count):
            return raise_(
                ValueError,
                f'Too many items in iterable (got at least {item_count})',
            )

    source = iter(iterable)

    # Hand out at most n items, keeping track of how many there were.
    emitted = 0
    for emitted, item in enumerate(islice(source, n), 1):
        yield item

    if emitted < n:
        too_short(emitted)
        return

    # One further item is all it takes to know that there are too many.
    nothing = object()
    if next(source, nothing) is not nothing:
        too_long(n + 1)
702
703
def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

        >>> sorted(distinct_permutations([1, 0, 1]))
        [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

        >>> sorted(distinct_permutations([1, 0, 1], r=2))
        [(0, 1), (1, 0), (1, 1)]
        >>> sorted(distinct_permutations(range(3), r=2))
        [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    When *r* is ``0`` a single empty tuple is yielded; for any other *r*
    outside ``1..len(items)`` nothing is yielded.

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

        >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
        [
            (1, True, '3'),
            (1, '3', True),
            ('3', 1, True)
        ]
        >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
        [
            (1, 2, '3'),
            (1, '3', 2),
            (2, 1, '3'),
            (2, '3', 1),
            ('3', 1, 2),
            ('3', 2, 1)
        ]
    """

    # Algorithm: https://w.wiki/Qai
    # Both helpers permute the list A in place and expect it to start out
    # sorted in ascending order.
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than i such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above
    def _partial(A, r):
        # Split A into the first r items and the remaining items
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    try:
        items.sort()
        sortable = True
    except TypeError:
        sortable = False

        # The items cannot be ordered, so integer stand-ins are permuted
        # instead: every item is represented by the index of the first item
        # in the list that compares equal to it.
        indices_dict = defaultdict(list)

        for item in items:
            indices_dict[items.index(item)].append(item)

        indices = [items.index(item) for item in items]
        indices.sort()

        # For each stand-in index, cycle through the items it represents.
        equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

        def permuted_items(permuted_indices):
            # Translate a permutation of stand-in indexes back into items.
            return tuple(
                next(equivalent_items[index]) for index in permuted_indices
            )

    size = len(items)
    if r is None:
        r = size

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if 0 < r <= size:
        if sortable:
            return algorithm(items)
        else:
            return (
                permuted_items(permuted_indices)
                for permuted_indices in algorithm(indices)
            )

    # r == 0 has exactly one (empty) permutation; any other r out of range
    # has none.
    return iter(() if r else ((),))
852
853
def derangements(iterable, r=None):
    """Yield successive derangements of the elements in *iterable*.

    A derangement is a permutation without fixed points, that is, one in
    which no element stays at the index where it started.

    Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
    The code below prints every way of assigning gift recipients in which
    nobody ends up with their own name:

        >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
        ...    print(', '.join(d))
        Bob, Alice, Dave, Carol
        Bob, Carol, Dave, Alice
        Bob, Dave, Alice, Carol
        Carol, Alice, Dave, Bob
        Carol, Dave, Alice, Bob
        Carol, Dave, Bob, Alice
        Dave, Alice, Bob, Carol
        Dave, Carol, Alice, Bob
        Dave, Carol, Bob, Alice

    If *r* is given, only the *r*-length derangements are yielded.

        >>> sorted(derangements(range(3), 2))
        [(1, 0), (1, 2), (2, 0)]
        >>> sorted(derangements([0, 2, 3], 2))
        [(2, 0), (2, 3), (3, 0)]

    What makes an element unique is its position, not its value.

    Take the Secret Santa example with two *different* people who share the
    *same* name.  There are two valid gift assignments, even though it may
    look as if somebody has been assigned to themselves:

        >>> names = ['Alice', 'Bob', 'Bob']
        >>> list(derangements(names))
        [('Bob', 'Bob', 'Alice'), ('Bob', 'Alice', 'Bob')]

    To avoid confusion, make the inputs distinct:

        >>> deduped = [f'{name}{index}' for index, name in enumerate(names)]
        >>> list(derangements(deduped))
        [('Bob1', 'Bob2', 'Alice0'), ('Bob2', 'Alice0', 'Bob1')]

    The number of derangements of a set of size *n* is known as the
    "subfactorial of n".  For n > 0, the subfactorial is:
    ``round(math.factorial(n) / math.e)``.

    References:

    * Article:  https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics
    * Sizes:  https://oeis.org/A000166
    """
    pool = tuple(iterable)
    positions = tuple(range(len(pool)))

    # Permute the items and their positions in lockstep: both calls produce
    # their results in the same order.
    item_perms = permutations(pool, r=r)
    index_perms = permutations(positions, r=r)

    # A permutation is kept when every slot holds some other position than
    # its own.  The position objects are shared with *positions*, which is
    # why an identity test is sufficient.
    keep = (all(map(is_not, positions, perm)) for perm in index_perms)
    return compress(item_perms, keep)
914
915
def intersperse(e, iterable, n=1):
    """Intersperse filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

        >>> list(intersperse('!', [1, 2, 3, 4, 5]))
        [1, '!', 2, '!', 3, '!', 4, '!', 5]

        >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
        [1, 2, None, 3, 4, None, 5]

    ``ValueError`` is raised immediately, before *iterable* is touched, when
    *n* is zero or negative.

    """
    # Reject every non-positive n up front.  Checking for ``n == 0`` alone
    # let negative values through, and those only failed later on, while
    # the result was being iterated, with an error unrelated to *n*.
    # ``None`` is deliberately not rejected, so it is still forwarded to
    # chunked() the way it always has been.
    if n is not None and n < 1:
        raise ValueError('n must be > 0')

    if n == 1:
        # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2...
        # islice(..., 1, None) -> x_0, e, x_1, e, x_2...
        return islice(interleave(repeat(e), iterable), 1, None)

    # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]...
    # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]...
    # flatten(...) -> x_0, x_1, e, x_2, x_3...
    filler = repeat([e])
    chunks = chunked(iterable, n)
    return flatten(islice(interleave(filler, chunks), 1, None))
940
941
def unique_to_each(*iterables):
    """For each input iterable, return the elements that occur in none of
    the other input iterables.

    For example, suppose you have a set of packages, each with a set of
    dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    If one package is removed, which dependencies can go along with it?

    Removing ``pkg_1`` makes ``A`` unnecessary, because neither ``pkg_2``
    nor ``pkg_3`` uses it.  In the same way, ``C`` is only needed for
    ``pkg_2``, and ``D`` is only needed for ``pkg_3``::

        >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
        [['A'], ['C'], ['D']]

    An element that is repeated within one input, but absent from all the
    others, is repeated in the output as well.  Input order is preserved::

        >>> unique_to_each("mississippi", "missouri")
        [['p', 'p'], ['o', 'u', 'r']]

    It is assumed that the elements of each iterable are hashable.

    """
    groups = [list(it) for it in iterables]

    # For every element, count the number of inputs that contain it;
    # repeats within a single input count once.
    membership = Counter()
    for group in groups:
        membership.update(set(group))

    exclusive = {
        element for element, owners in membership.items() if owners == 1
    }
    return [
        [element for element in group if element in exclusive]
        for group in groups
    ]
973
974
def windowed(seq, n, fillvalue=None, step=1):
    """Slide a window of width *n* over the given iterable and yield its
    contents, as a tuple, at every stop.

        >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
        >>> list(all_windows)
        [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    A window that is wider than the iterable gets its missing values from
    *fillvalue*:

        >>> list(windowed([1, 2, 3], 4))
        [(1, 2, 3, None)]

    The window moves forward *step* items at a time:

        >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
        [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler items
    to the left:

        >>> iterable = [1, 2, 3, 4]
        >>> n = 3
        >>> padding = [None] * (n - 1)
        >>> list(windowed(chain(padding, iterable), 3))
        [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]

    ``ValueError`` is raised for a negative *n* and for a *step* smaller
    than 1.  A width of 0 yields a single empty tuple.
    """
    if n < 0:
        raise ValueError('n must be >= 0')
    if n == 0:
        yield ()
        return
    if step < 1:
        raise ValueError('step must be >= 1')

    source = iter(seq)

    # First window.  Thanks to maxlen, every later append() pushes the
    # oldest item out on the other side.
    window = deque(islice(source, n), maxlen=n)
    if not window:
        return

    missing = n - len(window)
    if missing > 0:
        # The input ended before a single window was full: pad and stop.
        yield tuple(window) + (fillvalue,) * missing
        return
    yield tuple(window)

    # Just enough filler to complete the final window, and no more.
    pad_count = (n - 1) if step >= n else (step - 1)
    feed = map(window.append, chain(source, (fillvalue,) * pad_count))

    # Every step-th append marks a position at which the window is reported.
    for _ in islice(feed, step - 1, None, step):
        yield tuple(window)
1031
1032
def substrings(iterable):
    """Yield every contiguous run of items of *iterable* as a tuple,
    shortest runs first.

        >>> [''.join(s) for s in substrings('more')]
        ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Iterables other than strings can be subdivided as well.

        >>> list(substrings([0, 1, 2]))
        [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    """
    # Runs of length 1 can be emitted while the input is still being read;
    # the items are collected along the way for the longer runs.
    collected = []
    for item in iterable:
        collected.append(item)
        yield (item,)

    pool = tuple(collected)
    total = len(pool)

    # Longer runs: for each width, slide over the pool from left to right.
    for width in range(2, total + 1):
        start_count = total - width + 1
        yield from (pool[start : start + width] for start in range(start_count))
1057
1058
def substrings_indexes(seq, reverse=False):
    """Yield every substring of *seq* together with its position.

    Each item is a tuple of the form ``(substr, i, j)``, for which
    ``substr == seq[i:j]`` holds.

    *seq* has to support ``len()`` and slicing, as ``str`` objects do.

        >>> for item in substrings_indexes('more'):
        ...    print(item)
        ('m', 0, 1)
        ('o', 1, 2)
        ('r', 2, 3)
        ('e', 3, 4)
        ('mo', 0, 2)
        ('or', 1, 3)
        ('re', 2, 4)
        ('mor', 0, 3)
        ('ore', 1, 4)
        ('more', 0, 4)

    With *reverse* set to ``True`` the substring lengths are visited from
    longest to shortest; substrings of equal length still appear from left
    to right.

    """
    lengths = range(1, len(seq) + 1)
    if reverse:
        lengths = reversed(lengths)

    def slices():
        for width in lengths:
            for begin in range(len(seq) - width + 1):
                end = begin + width
                yield seq[begin:end], begin, end

    return slices()
1091
1092
class bucket:
    """Wrap *iterable* and sort its items into child iterables, as directed
    by a *key* function.

        >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
        >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
        >>> sorted(list(s))  # Get the keys
        ['a', 'b', 'c']
        >>> a_iterable = s['a']
        >>> next(a_iterable)
        'a1'
        >>> next(a_iterable)
        'a2'
        >>> list(s['b'])
        ['b1', 'b2', 'b3']

    The original iterable is advanced on demand, and items are cached until
    the child iterable they belong to asks for them.  This may require
    significant storage.

    By default, selecting a bucket to which no items belong exhausts the
    iterable and caches all of its values.  If a *validator* function is
    given, selected buckets are checked against it instead, and items whose
    key it rejects are not cached.

        >>> from itertools import count
        >>> it = count(1, 2)  # Infinite sequence of odd numbers
        >>> key = lambda x: x % 10  # Bucket by last digit
        >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
        >>> s = bucket(it, key=key, validator=validator)
        >>> 2 in s
        False
        >>> list(s[2])
        []

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        # Key value -> items already read from the source that are waiting
        # for their child iterable to pick them up.
        self._cache = defaultdict(deque)
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        if not self._validator(value):
            return False

        # Try to obtain one item of the bucket; there is nothing to find
        # when the child iterable comes up empty.
        nothing = object()
        found = next(self[value], nothing)
        if found is nothing:
            return False

        # Put the item back at the front so that it is not lost.
        self._cache[value].appendleft(found)
        return True

    def _get_values(self, value):
        """
        Helper that yields the items of the parent iterator whose key is
        *value*.  Other items met along the way are kept in the local cache.
        """
        while True:
            waiting = self._cache[value]
            if waiting:
                # Serve cached matches first, evicting them as we go.
                yield waiting.popleft()
                continue

            # Nothing cached: advance the parent iterator until a match
            # turns up, caching the rest.
            for item in self._it:
                item_value = self._key(item)
                if item_value == value:
                    yield item
                    break
                if self._validator(item_value):
                    self._cache[item_value].append(item)
            else:
                # The parent iterator is exhausted.
                return

    def __iter__(self):
        # Every key has to be known, so the rest of the source is read.
        for item in self._it:
            group = self._key(item)
            if self._validator(group):
                self._cache[group].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        if self._validator(value):
            return self._get_values(value)

        return iter(())
1187
1188
def spy(iterable, n=1):
    """Peek at the first *n* items of *iterable* without losing them.

    Returns a 2-tuple: a list holding the first *n* items, and an iterator
    that still produces every item of *iterable*, starting from the first.

    By default the list holds a single item:

        >>> iterable = 'abcdefg'
        >>> head, iterable = spy(iterable)
        >>> head
        ['a']
        >>> list(iterable)
        ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    The list can be unpacked directly to get at the items:

        >>> (head,), iterable = spy('abcdefg')
        >>> head
        'a'
        >>> (first, second), iterable = spy('abcdefg', 2)
        >>> first
        'a'
        >>> second
        'b'

    Asking for more items than the iterable holds simply gives a shorter
    list:

        >>> iterable = [1, 2, 3, 4, 5]
        >>> head, iterable = spy(iterable, 10)
        >>> head
        [1, 2, 3, 4, 5]
        >>> list(iterable)
        [1, 2, 3, 4, 5]

    """
    # One tee branch is handed back untouched; the other is read for the
    # look-ahead list.
    untouched, lookahead = tee(iterable)
    head = take(n, lookahead)
    return head, untouched
1228
1229
def interleave(*iterables):
    """Take one item from each of the *iterables* in turn, stopping as soon
    as the shortest one runs out.

        >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
        [1, 4, 6, 2, 5, 7]

    See :func:`interleave_longest` for a variant that keeps going after the
    shortest iterable is exhausted.

    """
    # Each tuple from zip() is one complete round across all inputs.
    rounds = zip(*iterables)
    return chain.from_iterable(rounds)
1242
1243
def interleave_longest(*iterables):
    """Take one item from each of the *iterables* in turn, leaving out any
    iterable that has already been exhausted.

        >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
        [1, 4, 6, 2, 5, 7, 3, 8]

    The output is the same as that of :func:`roundrobin`, but this function
    may perform better for some inputs (in particular when the number of
    iterables is large).

    """
    # Exhausted inputs are padded with the sentinel, which is then dropped.
    rounds = zip_longest(*iterables, fillvalue=_marker)
    for candidate in chain.from_iterable(rounds):
        if candidate is not _marker:
            yield candidate
1260
1261
def interleave_evenly(iterables, lengths=None):
    """
    Interleave multiple iterables so that their elements are evenly distributed
    throughout the output sequence.

    >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
    >>> list(interleave_evenly(iterables))
    [1, 2, 'a', 3, 4, 'b', 5]

    >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
    >>> list(interleave_evenly(iterables))
    [1, 6, 4, 2, 7, 3, 8, 5]

    This function requires iterables of known length. Iterables without
    ``__len__()`` can be used by manually specifying lengths with *lengths*:

    >>> from itertools import combinations, repeat
    >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
    >>> lengths = [4 * (4 - 1) // 2, 3]
    >>> list(interleave_evenly(iterables, lengths=lengths))
    [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    An empty collection of iterables produces no items:

    >>> list(interleave_evenly([]))
    []

    ``ValueError`` is raised if *lengths* is not given and the length of some
    iterable cannot be determined with ``len()``, or if *lengths* is given
    but does not have one entry per iterable.

    Based on Bresenham's algorithm.
    """
    if lengths is None:
        try:
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )
    elif len(iterables) != len(lengths):
        raise ValueError('Mismatching number of iterables and lengths.')

    dims = len(lengths)

    # With no iterables there is no primary iterable to pick below; there is
    # simply nothing to yield.
    if dims == 0:
        return

    # sort iterables by length, descending
    lengths_permute = sorted(
        range(dims), key=lambda i: lengths[i], reverse=True
    )
    lengths_desc = [lengths[i] for i in lengths_permute]
    iters_desc = [iter(iterables[i]) for i in lengths_permute]

    # the longest iterable is the primary one (Bresenham: the longest
    # distance along an axis)
    delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
    iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
    errors = [delta_primary // dims] * len(deltas_secondary)

    # total number of items still to be emitted across all iterables
    to_yield = sum(lengths)
    while to_yield:
        yield next(iter_primary)
        to_yield -= 1
        # update errors for each secondary iterable
        errors = [e - delta for e, delta in zip(errors, deltas_secondary)]

        # those iterables for which the error is negative are yielded
        # ("diagonal step" in Bresenham)
        for i, e_ in enumerate(errors):
            if e_ < 0:
                yield next(iters_secondary[i])
                to_yield -= 1
                errors[i] += delta_primary
1326
1327
def interleave_randomly(*iterables):
    """Yield the items of the input *iterables*, each time drawing the next
    item from an input chosen at random.

        >>> iterables = [1, 2, 3], 'abc', (True, False, None)
        >>> list(interleave_randomly(*iterables))  # doctest: +SKIP
        ['a', 'b', 1, 'c', True, False, None, 2, 3]

    The relative order of the items within each input iterable is preserved.
    Note that the possible output orderings with this property are not all
    equally likely to be generated.

    """
    pool = [iter(source) for source in iterables]
    while pool:
        pick = randrange(len(pool))
        try:
            yield next(pool[pick])
        except StopIteration:
            # Drop the exhausted iterator in O(1): overwrite it with the last
            # one, then shrink the list by one.
            pool[pick] = pool[-1]
            pool.pop()
1349
1350
def collapse(iterable, base_type=None, levels=None):
    """Flatten an arbitrarily nested iterable (e.g., a list of lists of
    tuples) down to its non-iterable items.

        >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
        >>> list(collapse(iterable))
        [1, 2, 3, 4, 5, 6]

    Text and binary strings are treated as atomic values and are never
    broken up into characters or bytes.

    Other types can be protected from flattening with *base_type*:

        >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
        >>> list(collapse(iterable, base_type=tuple))
        ['ab', ('cd', 'ef'), 'gh', 'ij']

    Use *levels* to limit how many levels of nesting are removed:

        >>> iterable = [('a', ['b']), ('c', ['d'])]
        >>> list(collapse(iterable))  # Fully flattened
        ['a', 'b', 'c', 'd']
        >>> list(collapse(iterable, levels=1))  # Only one level flattened
        ['a', ['b'], 'c', ['d']]

    """
    # Explicit stack of (depth, iterator) frames; the top of the stack is the
    # end of the list. The input itself is wrapped as a single node at depth 0.
    stack = [(0, repeat(iterable, 1))]

    while stack:
        frame = stack.pop()
        depth, nodes = frame

        # Past the requested depth: emit whatever is left as-is.
        if levels is not None and depth > levels:
            yield from nodes
            continue

        for node in nodes:
            # Strings and protected types are leaves.
            if isinstance(node, (str, bytes)):
                yield node
                continue
            if base_type is not None and isinstance(node, base_type):
                yield node
                continue

            try:
                children = iter(node)
            except TypeError:
                # Not iterable, so it is a leaf too.
                yield node
                continue

            # Descend: remember where we were, then handle the children
            # before resuming this frame.
            stack.append(frame)
            stack.append((depth + 1, children))
            break
1409
1410
def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Yield the items of *iterable* unchanged, calling *func* on each item
    (or on each group of *chunk_size* items) just before it is yielded.

    `func` must accept a single argument. Whatever it returns is ignored.

    *before* and *after* are optional callables taking no arguments. They are
    run before iteration starts and after it ends, respectively.

    `side_effect` is handy for logging, updating progress bars, or anything
    else that is not functionally "pure."

    Emitting a status message:

        >>> from more_itertools import consume
        >>> func = lambda item: print('Received {}'.format(item))
        >>> consume(side_effect(func, range(2)))
        Received 0
        Received 1

    Operating on chunks of items:

        >>> pair_sums = []
        >>> func = lambda chunk: pair_sums.append(sum(chunk))
        >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
        [0, 1, 2, 3, 4, 5]
        >>> list(pair_sums)
        [1, 5, 9]

    Writing to a file-like object:

        >>> from io import StringIO
        >>> from more_itertools import consume
        >>> f = StringIO()
        >>> func = lambda x: print(x, file=f)
        >>> before = lambda: print(u'HEADER', file=f)
        >>> after = f.close
        >>> it = [u'a', u'b', u'c']
        >>> consume(side_effect(func, it, before=before, after=after))
        >>> f.closed
        True

    """
    try:
        if before is not None:
            before()

        if chunk_size is not None:
            # func sees a whole group at a time; items are still yielded
            # individually.
            for group in chunked(iterable, chunk_size):
                func(group)
                yield from group
        else:
            for element in iterable:
                func(element)
                yield element
    finally:
        # Runs on normal exhaustion, on error, and when the generator is
        # closed early.
        if after is not None:
            after()
1470
1471
def sliced(seq, n, strict=False):
    """Yield consecutive slices of length *n* taken from the sequence *seq*.

        >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
        [(1, 2, 3), (4, 5, 6)]

    By default, the final slice is shorter than *n* when the length of *seq*
    is not a multiple of *n*:

        >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
        [(1, 2, 3), (4, 5, 6), (7, 8)]

    With *strict* set to ``True``, a ``ValueError`` is raised instead of
    yielding such a short final slice.

    Only objects that support slicing can be used here. For other
    iterables, see :func:`chunked`.

    """
    # Slice at offsets 0, n, 2n, ... and stop at the first empty slice.
    pieces = takewhile(len, map(lambda start: seq[start : start + n], count(0, n)))
    if not strict:
        return pieces

    def checked():
        for piece in pieces:
            if len(piece) != n:
                raise ValueError("seq is not divisible by n.")
            yield piece

    return checked()
1504
1505
def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Break *iterable* into lists, using each item for which the callable
    *pred* returns ``True`` as a delimiter.

        >>> list(split_at('abcdcba', lambda x: x == 'b'))
        [['a'], ['c', 'd', 'c'], ['a']]

        >>> list(split_at(range(10), lambda n: n % 2 == 1))
        [[0], [2], [4], [6], [8], []]

    No more than *maxsplit* splits are performed. When *maxsplit* is omitted
    or -1, the number of splits is unlimited:

        >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
        [[0], [2], [4, 5, 6, 7, 8, 9]]

    The delimiting items are left out of the output by default. Set
    *keep_separator* to ``True`` to emit each one as its own list.

        >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
        [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    current = []
    source = iter(iterable)
    splits_left = maxsplit
    for element in source:
        if not pred(element):
            current.append(element)
            continue

        # Delimiter found: close the current group.
        yield current
        if keep_separator:
            yield [element]
        if splits_left == 1:
            # Final allowed split: everything that remains is one group.
            yield list(source)
            return
        current = []
        splits_left -= 1

    yield current
1548
1549
def split_before(iterable, pred, maxsplit=-1):
    """Break *iterable* into lists, starting a new list at each item for
    which the callable *pred* returns ``True``:

        >>> list(split_before('OneTwo', lambda s: s.isupper()))
        [['O', 'n', 'e'], ['T', 'w', 'o']]

        >>> list(split_before(range(10), lambda n: n % 3 == 0))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    No more than *maxsplit* splits are performed. When *maxsplit* is omitted
    or -1, the number of splits is unlimited:

        >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    group = []
    source = iter(iterable)
    splits_left = maxsplit
    for element in source:
        # An empty group is never emitted, so a match on the very first item
        # does not produce a split.
        if pred(element) and group:
            yield group
            if splits_left == 1:
                # Final allowed split: the match and all remaining items
                # form the last group.
                yield [element, *source]
                return
            group = []
            splits_left -= 1
        group.append(element)

    if group:
        yield group
1583
1584
def split_after(iterable, pred, maxsplit=-1):
    """Break *iterable* into lists, closing each list right after an item
    for which the callable *pred* returns ``True``:

        >>> list(split_after('one1two2', lambda s: s.isdigit()))
        [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

        >>> list(split_after(range(10), lambda n: n % 3 == 0))
        [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    No more than *maxsplit* splits are performed. When *maxsplit* is omitted
    or -1, the number of splits is unlimited:

        >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
        [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    group = []
    source = iter(iterable)
    splits_left = maxsplit
    for element in source:
        group.append(element)
        if not pred(element):
            continue

        # The matching item ends the current group.
        yield group
        if splits_left == 1:
            # Final allowed split: the rest, if any, is a single group.
            remainder = list(source)
            if remainder:
                yield remainder
            return
        group = []
        splits_left -= 1

    if group:
        yield group
1621
1622
def split_when(iterable, pred, maxsplit=-1):
    """Break *iterable* into lists wherever *pred* says so.
    *pred* is called with each pair of neighbouring items and should return
    ``True`` if a split belongs between the two.

    For example, to extract runs of non-decreasing numbers, split whenever
    element ``i`` is larger than element ``i + 1``:

        >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
        [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    No more than *maxsplit* splits are performed. When *maxsplit* is omitted
    or -1, the number of splits is unlimited:

        >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
        ...                 lambda x, y: x > y, maxsplit=2))
        [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    source = iter(iterable)
    try:
        previous = next(source)
    except StopIteration:
        # Empty input produces no groups at all.
        return

    run = [previous]
    splits_left = maxsplit
    for current in source:
        if pred(previous, current):
            yield run
            if splits_left == 1:
                # Final allowed split: *current* and everything after it
                # form the last group.
                yield [current, *source]
                return
            run = []
            splits_left -= 1

        run.append(current)
        previous = current

    yield run
1666
1667
def split_into(iterable, sizes):
    """For each integer 'n' in *sizes*, yield a list of the next 'n' items
    of *iterable*.

        >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
        [[1], [2, 3], [4, 5, 6]]

    When *sizes* adds up to less than the length of *iterable*, the leftover
    items are not returned.

        >>> list(split_into([1,2,3,4,5,6], [2,3]))
        [[1, 2], [3, 4, 5]]

    When *sizes* adds up to more than the length of *iterable*, the list that
    runs past the end is cut short and the lists after it are empty:

        >>> list(split_into([1,2,3,4], [1,2,3,4]))
        [[1], [2, 3], [4], []]

    A ``None`` entry in *sizes* means "everything that is left", in the same
    way that a ``None`` stop works for :func:`itertools.islice`. No sizes
    after it are processed:

        >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
        [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    :func:`split_into` is useful for grouping a series of items where the
    groups are not all the same size, for example a table row in which
    several columns together describe one feature (e.g. a point given by
    x, y, z) but the layout differs between features.
    """
    # A single shared iterator lets successive islice() calls each continue
    # where the previous one stopped, whatever kind of iterable was passed.
    source = iter(iterable)

    for size in sizes:
        if size is None:
            yield list(source)
            return
        yield list(islice(source, size))
1711
1712
def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the items of *iterable* and then pad with *fillvalue* so that at
    least *n* items are produced.

        >>> list(padded([1, 2, 3], '?', 5))
        [1, 2, 3, '?', '?']

    With *next_multiple* set to ``True``, padding continues only until the
    number of items produced is a multiple of *n*:

        >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
        [1, 2, 3, 4, None, None]

    With *n* set to ``None``, *fillvalue* is produced without end.

    To get an iterable of exactly *n* items, truncate the result with
    :func:`islice`.

        >>> list(islice(padded([1, 2, 3], '?'), 5))
        [1, 2, 3, '?', '?']
        >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
        [1, 2, 3, 4, 5]

    ``ValueError`` is raised if *n* is given and is less than 1.

    """
    source = iter(iterable)
    # Shares *source*: reading from either one advances the other.
    endless = chain(source, repeat(fillvalue))

    if n is None:
        return endless

    if n < 1:
        raise ValueError('n must be at least 1')

    if not next_multiple:
        # The first n items come from the padded stream; anything still left
        # in the source follows unpadded.
        return chain(islice(endless, n), source)

    def groups():
        # Start a group only when a real item exists, then fill the rest of
        # the group from the padded stream.
        for leader in source:
            yield (leader,)
            yield islice(endless, n - 1)

    return chain.from_iterable(groups())
1756
1757
def repeat_each(iterable, n=2):
    """Yield every element of *iterable* *n* times in a row.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    # One repeat() run per element, flattened back into a single stream.
    runs = (repeat(element, n) for element in iterable)
    return chain.from_iterable(runs)
1765
1766
def repeat_last(iterable, default=None):
    """Yield the items of *iterable* and then keep yielding its final item
    forever.

        >>> list(islice(repeat_last(range(3)), 5))
        [0, 1, 2, 2, 2]

    An empty iterable results in *default* being yielded forever::

        >>> list(islice(repeat_last(range(0), 42), 5))
        [42, 42, 42, 42, 42]

    """
    # The sentinel survives the loop only if the iterable produced nothing.
    last = _marker
    for last in iterable:
        yield last

    if last is _marker:
        last = default
    yield from repeat(last)
1784
1785
def distribute(n, iterable):
    """Deal the items of *iterable* out, one at a time, to *n* smaller
    iterables.

        >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
        >>> list(group_1)
        [1, 3, 5]
        >>> list(group_2)
        [2, 4, 6]

    When the length of *iterable* is not a multiple of *n*, the returned
    iterables differ in length:

        >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
        >>> [list(c) for c in children]
        [[1, 4, 7], [2, 5], [3, 6]]

    When *iterable* has fewer than *n* items, the trailing returned iterables
    are empty:

        >>> children = distribute(5, [1, 2, 3])
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]

    This function is built on :func:`itertools.tee` and may require
    significant storage.

    If each smaller iterable should instead hold a contiguous run of the
    original items, see :func:`divide`.

    ``ValueError`` is raised if *n* is less than 1.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    parts = []
    for offset, child in enumerate(tee(iterable, n)):
        # Child number `offset` takes items offset, offset + n, ...
        parts.append(islice(child, offset, None, n))
    return parts
1821
1822
def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples of items from *iterable*, each position in the tuple
    being shifted by its own offset.
    The shift applied to the `i`-th item of each tuple is the `i`-th item of
    *offsets*.

        >>> list(stagger([0, 1, 2, 3]))
        [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
        >>> list(stagger(range(8), offsets=(0, 2, 4)))
        [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default, output stops once the last element of a tuple is the final
    item of the iterable. Set *longest* to ``True`` to continue until the
    first element of a tuple is the final item of the iterable::

        >>> list(stagger([0, 1, 2, 3], longest=True))
        [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    Positions that fall outside the iterable are filled with ``None`` by
    default. Pass *fillvalue* to use a different value.

    """
    # One independent copy of the input per offset; zip_offset applies the
    # shifts.
    copies = tee(iterable, len(offsets))
    return zip_offset(
        *copies, offsets=offsets, longest=longest, fillvalue=fillvalue
    )
1849
1850
def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the input *iterables* together after shifting the `i`-th
    iterable by the `i`-th item in *offsets*.

        >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
        [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    This can serve as a lightweight alternative to SciPy or pandas for
    analyzing data sets in which some series lead or lag others.

    By default, output stops when the shortest shifted iterable is exhausted.
    Set *longest* to ``True`` to continue until the longest one is exhausted.

        >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
        [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    Positions that fall outside an iterable are filled with ``None`` by
    default. Pass *fillvalue* to use a different value.

    ``ValueError`` is raised if the number of *iterables* differs from the
    number of *offsets*.

    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    def shift(source, offset):
        if offset < 0:
            # Negative offset: delay the iterable with leading fill values.
            return chain(repeat(fillvalue, -offset), source)
        if offset > 0:
            # Positive offset: skip that many leading items.
            return islice(source, offset, None)
        return source

    shifted = [shift(source, offset) for source, offset in zip(iterables, offsets)]

    if not longest:
        return zip(*shifted)

    return zip_longest(*shifted, fillvalue=fillvalue)
1888
1889
1890def sort_together(
1891 iterables, key_list=(0,), key=None, reverse=False, strict=False
1892):
1893 """Return the input iterables sorted together, with *key_list* as the
1894 priority for sorting. All iterables are trimmed to the length of the
1895 shortest one.
1896
1897 This can be used like the sorting function in a spreadsheet. If each
1898 iterable represents a column of data, the key list determines which
1899 columns are used for sorting.
1900
1901 By default, all iterables are sorted using the ``0``-th iterable::
1902
1903 >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
1904 >>> sort_together(iterables)
1905 [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
1906
1907 Set a different key list to sort according to another iterable.
1908 Specifying multiple keys dictates how ties are broken::
1909
1910 >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
1911 >>> sort_together(iterables, key_list=(1, 2))
1912 [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
1913
1914 To sort by a function of the elements of the iterable, pass a *key*
1915 function. Its arguments are the elements of the iterables corresponding to
1916 the key list::
1917
1918 >>> names = ('a', 'b', 'c')
1919 >>> lengths = (1, 2, 3)
1920 >>> widths = (5, 2, 1)
1921 >>> def area(length, width):
1922 ... return length * width
1923 >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
1924 [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
1925
1926 Set *reverse* to ``True`` to sort in descending order.
1927
1928 >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
1929 [(3, 2, 1), ('a', 'b', 'c')]
1930
1931 If the *strict* keyword argument is ``True``, then
1932 ``ValueError`` will be raised if any of the iterables have
1933 different lengths.
1934
1935 """
1936 if key is None:
1937 # if there is no key function, the key argument to sorted is an
1938 # itemgetter
1939 key_argument = itemgetter(*key_list)
1940 else:
1941 # if there is a key function, call it with the items at the offsets
1942 # specified by the key function as arguments
1943 key_list = list(key_list)
1944 if len(key_list) == 1:
1945 # if key_list contains a single item, pass the item at that offset
1946 # as the only argument to the key function
1947 key_offset = key_list[0]
1948 key_argument = lambda zipped_items: key(zipped_items[key_offset])
1949 else:
1950 # if key_list contains multiple items, use itemgetter to return a
1951 # tuple of items, which we pass as *args to the key function
1952 get_key_items = itemgetter(*key_list)
1953 key_argument = lambda zipped_items: key(
1954 *get_key_items(zipped_items)
1955 )
1956
1957 transposed = zip(*iterables, strict=strict)
1958 reordered = sorted(transposed, key=key_argument, reverse=reverse)
1959 untransposed = zip(*reordered, strict=strict)
1960 return list(untransposed)
1961
1962
def unzip(iterable):
    """Undo a :func:`zip`: split the elements of the zipped *iterable* back
    into separate iterables.

    The ``i``-th returned iterable produces the ``i``-th element of each
    element of the zipped iterable. The number of returned iterables is
    taken from the length of the first element.

        >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
        >>> letters, numbers = unzip(iterable)
        >>> list(letters)
        ['a', 'b', 'c', 'd']
        >>> list(numbers)
        [1, 2, 3, 4]

    This resembles ``zip(*iterable)``, but does not read *iterable* into
    memory up front. Note, however, that this function uses
    :func:`itertools.tee` and thus may require significant storage.

    """
    head, iterable = spy(iterable)
    if not head:
        # Nothing to unzip, e.g. zip([], [], [])
        return ()

    # spy() hands back a one-item list; its sole member is the first element.
    width = len(head[0])
    columns = tee(iterable, width)

    # For an "improperly zipped" input such as
    # iter([(1, 2, 3), (4, 5), (6,)]), the second output would fail on
    # (6,)[1] and the third on (4, 5)[2]. Suppressing IndexError makes each
    # output simply stop at the first element that is too short for it.
    return tuple(
        iter_suppress(map(itemgetter(index), column), IndexError)
        for index, column in enumerate(columns)
    )
2002
2003
def divide(n, iterable):
    """Divide the elements from *iterable* into *n* parts, maintaining
    order.

        >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
        >>> list(group_1)
        [1, 2, 3]
        >>> list(group_2)
        [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

        >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
        >>> [list(c) for c in children]
        [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than n, then the last returned
    iterables will be empty:

        >>> children = divide(5, [1, 2, 3])
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]

    Inputs that cannot be sliced, such as iterators, sets and dicts, are
    first copied into a tuple:

        >>> children = divide(2, {'a': 1, 'b': 2, 'c': 3})
        >>> [list(c) for c in children]
        [['a', 'b'], ['c']]

    This function will exhaust the iterable before returning.
    If order is not important, see :func:`distribute`, which does not first
    pull the iterable into memory.

    Returns a list of *n* iterators. ``ValueError`` is raised if *n* is less
    than 1.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Probe for slicing support with an empty slice. Most unsliceable
    # objects raise TypeError; a dict raises KeyError instead on Python
    # versions where slice objects are hashable (3.12+), because the slice
    # is then treated as a missing key.
    try:
        iterable[:0]
    except (TypeError, KeyError):
        seq = tuple(iterable)
    else:
        seq = iterable

    # Every part gets q items; the first r parts get one extra.
    q, r = divmod(len(seq), n)

    ret = []
    stop = 0
    for i in range(1, n + 1):
        start = stop
        stop += q + 1 if i <= r else q
        ret.append(iter(seq[start:stop]))

    return ret
2053
2054
def always_iterable(obj, base_type=(str, bytes)):
    """Return an iterator for *obj*, whether or not *obj* is itself iterable.

    An iterable *obj* is iterated over as usual::

        >>> obj = (1, 2, 3)
        >>> list(always_iterable(obj))
        [1, 2, 3]

    A non-iterable *obj* is wrapped in a one-item iterable::

        >>> obj = 1
        >>> list(always_iterable(obj))
        [1]

    ``None`` gives an empty iterable:

        >>> obj = None
        >>> list(always_iterable(None))
        []

    By default, text and binary strings count as single items rather than
    as iterables::

        >>> obj = 'foo'
        >>> list(always_iterable(obj))
        ['foo']

    When *base_type* is set, any object for which
    ``isinstance(obj, base_type)`` is ``True`` is treated as a single item.

        >>> obj = {'a': 1}
        >>> list(always_iterable(obj))  # Iterate over the dict's keys
        ['a']
        >>> list(always_iterable(obj, base_type=dict))  # Treat dicts as a unit
        [{'a': 1}]

    Set *base_type* to ``None`` to turn off the special handling, so that
    everything Python considers iterable is iterated over:

        >>> obj = 'foo'
        >>> list(always_iterable(obj, base_type=None))
        ['f', 'o', 'o']
    """
    if obj is None:
        return iter(())

    is_single_item = base_type is not None and isinstance(obj, base_type)
    if not is_single_item:
        try:
            return iter(obj)
        except TypeError:
            # Not iterable after all: fall through and wrap it.
            pass

    return iter((obj,))
2106
2107
def adjacent(predicate, iterable, distance=1):
    """Return an iterable of `(bool, item)` tuples in which each `item`
    comes from *iterable* and the `bool` tells whether that item satisfies
    the *predicate* or lies next to an item that does.

    For example, to see which items are next to a ``3``::

        >>> list(adjacent(lambda x: x == 3, range(6)))
        [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    Use *distance* to change how far "next to" reaches. For example, to see
    which items are within two places of a ``3``:

        >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
        [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    This helps put the results of a search into context. For example, a
    code comparison tool might want to show the lines that changed together
    with the lines around them.

    The predicate function is called only once per item in the iterable.

    See also :func:`groupby_transform`, which can be combined with this
    function to group runs of items that share the same `bool` value.

    ``ValueError`` is raised if *distance* is negative.

    """
    # distance=0 is allowed mainly so tests can check that the results then
    # match those of map()
    if distance < 0:
        raise ValueError('distance must be at least 0')

    for_predicate, for_output = tee(iterable)
    edge = [False] * distance
    # Predicate results padded on both sides, so that a window of width
    # 2 * distance + 1 exists for every item, including those at the ends.
    flags = chain(edge, map(predicate, for_predicate), edge)
    window_size = 2 * distance + 1
    near_match = map(any, windowed(flags, window_size))
    return zip(near_match, for_output)
2145
2146
def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
    """A variant of :func:`itertools.groupby` that can also transform the
    grouped data.

    * *keyfunc* computes the key value for each item in *iterable*
    * *valuefunc* transforms the individual items of *iterable* after
      grouping
    * *reducefunc* transforms each whole group of items

    >>> iterable = 'aAAbBBcCC'
    >>> keyfunc = lambda k: k.upper()
    >>> valuefunc = lambda v: v.lower()
    >>> reducefunc = lambda g: ''.join(g)
    >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
    [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]

    An argument that is left out behaves like an identity function.

    :func:`groupby_transform` is useful for grouping the elements of one
    iterable using a second iterable as the key. To do so, :func:`zip` the
    two together, then pass a *keyfunc* that extracts the first element and
    a *valuefunc* that extracts the second element::

        >>> from operator import itemgetter
        >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
        >>> values = 'abcdefghi'
        >>> iterable = zip(keys, values)
        >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
        >>> [(k, ''.join(g)) for k, g in grouper]
        [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]

    Note that the order of the items in the iterable matters. Only adjacent
    items are grouped together, so to avoid duplicate groups, sort the
    iterable by the key function first.

    """

    def with_values(groups):
        # Map valuefunc lazily over the members of each group.
        for label, members in groups:
            yield label, map(valuefunc, members)

    def with_reduction(groups):
        # Collapse each group into a single value.
        for label, members in groups:
            yield label, reducefunc(members)

    grouped = groupby(iterable, keyfunc)
    if valuefunc:
        grouped = with_values(grouped)
    if reducefunc:
        grouped = with_reduction(grouped)

    return grouped
2190
2191
class numeric_range(Sequence):
    """An extension of the built-in ``range()`` function whose arguments can
    be any orderable numeric type.

    With only *stop* specified, *start* defaults to ``0`` and *step*
    defaults to ``1``. The output items will match the type of *stop*:

        >>> list(numeric_range(3.5))
        [0.0, 1.0, 2.0, 3.0]

    With only *start* and *stop* specified, *step* defaults to ``1``. The
    output items will match the type of *start*:

        >>> from decimal import Decimal
        >>> start = Decimal('2.1')
        >>> stop = Decimal('5.1')
        >>> list(numeric_range(start, stop))
        [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    With *start*, *stop*, and *step* specified the output items will match
    the type of ``start + step``:

        >>> from fractions import Fraction
        >>> start = Fraction(1, 2)  # Start at 1/2
        >>> stop = Fraction(5, 2)  # End at 5/2
        >>> step = Fraction(1, 2)  # Count by 1/2
        >>> list(numeric_range(start, stop, step))
        [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    If *step* is zero, ``ValueError`` is raised. Negative steps are supported:

        >>> list(numeric_range(3, -1, -1.0))
        [3.0, 2.0, 1.0, 0.0]

    Be aware of the limitations of floating-point numbers; the representation
    of the yielded numbers may be surprising.

    ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
    is a ``datetime.timedelta`` object:

        >>> import datetime
        >>> start = datetime.datetime(2019, 1, 1)
        >>> stop = datetime.datetime(2019, 1, 3)
        >>> step = datetime.timedelta(days=1)
        >>> items = iter(numeric_range(start, stop, step))
        >>> next(items)
        datetime.datetime(2019, 1, 1, 0, 0)
        >>> next(items)
        datetime.datetime(2019, 1, 2, 0, 0)

    Indexing with an integer returns a single item; indexing with a slice
    returns a new :class:`numeric_range`. Slices with a negative step walk
    the range backwards, as they do for the built-in ``range``:

        >>> list(numeric_range(0.0, 3.0)[::-1])
        [2.0, 1.0, 0.0]

    """

    # Every empty numeric_range hashes like an empty built-in range.
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        """Accept ``(stop)``, ``(start, stop)`` or ``(start, stop, step)``.

        Raises ``TypeError`` for any other number of arguments and
        ``ValueError`` when *step* equals zero.
        """
        argc = len(args)
        if argc == 1:
            (self._stop,) = args
            # Build zero and one in the argument's own numeric type.
            self._start = type(self._stop)(0)
            self._step = type(self._stop - self._start)(1)
        elif argc == 2:
            self._start, self._stop = args
            self._step = type(self._stop - self._start)(1)
        elif argc == 3:
            self._start, self._stop, self._step = args
        elif argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        else:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        # A zero of the step's type, e.g. timedelta(0) for timedelta steps.
        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        self._growing = self._step > self._zero

    def __bool__(self):
        """Return ``True`` if the range contains at least one item."""
        if self._growing:
            return self._start < self._stop
        else:
            return self._start > self._stop

    def __contains__(self, elem):
        """Return ``True`` if *elem* lies on the grid of items of the range."""
        if self._growing:
            if self._start <= elem < self._stop:
                return (elem - self._start) % self._step == self._zero
        else:
            if self._start >= elem > self._stop:
                return (self._start - elem) % (-self._step) == self._zero

        return False

    def __eq__(self, other):
        """Two ranges are equal when both are empty, or when they share the
        same start, step and last item. Any other type compares unequal.
        """
        if isinstance(other, numeric_range):
            empty_self = not bool(self)
            empty_other = not bool(other)
            if empty_self or empty_other:
                return empty_self and empty_other  # True if both empty
            else:
                return (
                    self._start == other._start
                    and self._step == other._step
                    and self._get_by_index(-1) == other._get_by_index(-1)
                )
        else:
            return False

    def __getitem__(self, key):
        """Return the item at integer *key*, or a new range for a slice.

        Raises ``IndexError`` for an out-of-range integer and ``TypeError``
        when *key* is neither an ``int`` nor a ``slice``.
        """
        if isinstance(key, int):
            return self._get_by_index(key)
        elif isinstance(key, slice):
            if key.step is not None and key.step < 0:
                # Reverse-direction slice. The defaults for an omitted start
                # or stop depend on the direction (last item / one before the
                # first item), so let slice.indices() resolve them and then
                # map the integer positions back onto the numeric grid.
                first, bound, stride = key.indices(self._len)
                return numeric_range(
                    self._start + first * self._step,
                    self._start + bound * self._step,
                    stride * self._step,
                )

            step = self._step if key.step is None else key.step * self._step

            if key.start is None or key.start <= -self._len:
                start = self._start
            elif key.start >= self._len:
                start = self._stop
            else:  # -self._len < key.start < self._len
                start = self._get_by_index(key.start)

            if key.stop is None or key.stop >= self._len:
                stop = self._stop
            elif key.stop <= -self._len:
                stop = self._start
            else:  # -self._len < key.stop < self._len
                stop = self._get_by_index(key.stop)

            return numeric_range(start, stop, step)
        else:
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )

    def __hash__(self):
        """Hash consistently with :meth:`__eq__`."""
        if self:
            return hash((self._start, self._get_by_index(-1), self._step))
        else:
            return self._EMPTY_HASH

    def __iter__(self):
        """Iterate over ``start + n * step`` until *stop* is reached."""
        values = (self._start + (n * self._step) for n in count())
        if self._growing:
            # gt(stop, value) <=> value < stop
            return takewhile(partial(gt, self._stop), values)
        else:
            # lt(stop, value) <=> value > stop
            return takewhile(partial(lt, self._stop), values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        """Number of items in the range (computed once, then cached)."""
        # Normalise to a growing range so a single formula covers both cases.
        if self._growing:
            start = self._start
            stop = self._stop
            step = self._step
        else:
            start = self._stop
            stop = self._start
            step = -self._step
        distance = stop - start
        if distance <= self._zero:
            return 0
        else:  # distance > 0 and step > 0: regular euclidean division
            q, r = divmod(distance, step)
            # A non-zero remainder means one more (partial) step fits.
            return int(q) + int(r != self._zero)

    def __reduce__(self):
        """Support pickling by re-creating the range from its arguments."""
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        if self._step == 1:
            return f"numeric_range({self._start!r}, {self._stop!r})"
        return (
            f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
        )

    def __reversed__(self):
        """Iterate from the last item back to the first."""
        return iter(
            numeric_range(
                self._get_by_index(-1), self._start - self._step, -self._step
            )
        )

    def count(self, value):
        """Return ``1`` if *value* is in the range, else ``0``."""
        return int(value in self)

    def index(self, value):
        """Return the position of *value*; raise ``ValueError`` if absent."""
        if self._growing:
            if self._start <= value < self._stop:
                q, r = divmod(value - self._start, self._step)
                if r == self._zero:
                    return int(q)
        else:
            if self._start >= value > self._stop:
                q, r = divmod(self._start - value, -self._step)
                if r == self._zero:
                    return int(q)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        """Return item *i*; negative *i* counts from the end."""
        if i < 0:
            i += self._len
        if i < 0 or i >= self._len:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step
2402
2403
def count_cycle(iterable, n=None):
    """Repeat the items of *iterable* up to *n* times, tagging each item
    with the number of full passes completed before it. When *n* is left
    out the cycling never stops.

        >>> list(count_cycle('AB', 3))
        [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    An empty *iterable* produces an empty iterator.

    """
    items = tuple(iterable)
    if not items:
        return iter(())

    passes = range(n) if n is not None else count()
    # Each pass number is repeated once per item so it lines up with cycle().
    pass_numbers = repeat_each(passes, len(items))
    return zip(pass_numbers, cycle(items))
2418
2419
def mark_ends(iterable):
    """Yield ``(is_first, is_last, item)`` 3-tuples for each item.

        >>> list(mark_ends('ABC'))
        [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    This makes it easy to treat the first and/or last item of a loop
    specially:

        >>> iterable = ['Header', 100, 200, 'Footer']
        >>> total = 0
        >>> for is_first, is_last, item in mark_ends(iterable):
        ...     if is_first:
        ...         continue  # Skip the header
        ...     if is_last:
        ...         continue  # Skip the footer
        ...     total += item
        >>> print(total)
        300
    """
    source = iter(iterable)
    for held in source:
        at_start = True
        # `held` lags one item behind the source: it can only be reported
        # as "not last" once a successor has actually been seen.
        for upcoming in source:
            yield at_start, False, held
            held, at_start = upcoming, False
        # The source ran dry, so whatever is still held is the last item.
        yield at_start, True, held
2448
2449
def locate(iterable, pred=bool, window_size=None):
    """Yield the index of every position in *iterable* at which *pred*
    returns a true value.

    With the default *pred* (:func:`bool`) the indexes of truthy items are
    produced:

        >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
        [1, 2, 4]

    Pass a custom *pred* to look for something specific, such as a
    particular item:

        >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
        [1, 3]

    When *window_size* is given, *pred* is called with that many consecutive
    items as separate arguments, which allows searching for sub-sequences:

        >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
        >>> pred = lambda *args: args == (1, 2, 3)
        >>> list(locate(iterable, pred=pred, window_size=3))
        [1, 5, 9]

    Combine with :func:`seekable` to find indexes and then fetch the
    corresponding items:

        >>> from itertools import count
        >>> from more_itertools import seekable
        >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
        >>> it = seekable(source)
        >>> pred = lambda x: x > 100
        >>> indexes = locate(it, pred=pred)
        >>> i = next(indexes)
        >>> it.seek(i)
        >>> next(it)
        106

    Raises ``ValueError`` if *window_size* is less than 1.

    """
    if window_size is None:
        verdicts = map(pred, iterable)
    elif window_size < 1:
        raise ValueError('window size must be at least 1')
    else:
        windows = windowed(iterable, window_size, fillvalue=_marker)
        verdicts = starmap(pred, windows)

    # Keep only the running indexes whose verdict is true.
    return compress(count(), verdicts)
2496
2497
def longest_common_prefix(iterables):
    """Yield the leading elements shared by all of the given *iterables*.

        >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
        'ab'

    """
    # Walk the inputs in lockstep; each "column" holds one item per input.
    columns = zip(*iterables)
    agreeing = takewhile(all_equal, columns)
    return (column[0] for column in agreeing)
2506
2507
def lstrip(iterable, pred):
    """Yield the items of *iterable*, skipping the leading ones for which
    *pred* returns ``True``.

    For example, to drop a set of unwanted values from the front:

        >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
        >>> pred = lambda x: x in {None, False, ''}
        >>> list(lstrip(iterable, pred))
        [1, 2, None, 3, False, None]

    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.

    """
    return dropwhile(pred, iterable)
2524
2525
def rstrip(iterable, pred):
    """Yield the items of *iterable*, leaving off the trailing ones for
    which *pred* returns ``True``.

    For example, to drop a set of unwanted values from the end:

        >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
        >>> pred = lambda x: x in {None, False, ''}
        >>> list(rstrip(iterable, pred))
        [None, False, None, 1, 2, None, 3]

    This function is analogous to :func:`str.rstrip`.

    """
    # Matching items are held back until a non-matching item proves that
    # they are not part of the tail.
    held_back = []
    for item in iterable:
        if not pred(item):
            yield from held_back
            held_back.clear()
            yield item
        else:
            held_back.append(item)
2550
2551
def strip(iterable, pred):
    """Yield the items of *iterable*, leaving off those at both the start
    and the end for which *pred* returns ``True``.

    For example, to drop a set of unwanted values from both ends:

        >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
        >>> pred = lambda x: x in {None, False, ''}
        >>> list(strip(iterable, pred))
        [1, 2, None, 3]

    This function is analogous to :func:`str.strip`.

    """
    without_leading = lstrip(iterable, pred)
    return rstrip(without_leading, pred)
2567
2568
class islice_extended:
    """A variant of :func:`itertools.islice` that also accepts negative
    values for *start*, *stop*, and *step*.

        >>> iterator = iter('abcdefgh')
        >>> list(islice_extended(iterator, -4, -1))
        ['e', 'f', 'g']

    Negative values make it necessary to cache part of *iterable*; care is
    taken to keep that cache as small as possible.

    A negative step even works on an infinite iterator:

        >>> from itertools import count
        >>> list(islice_extended(count(), 110, 99, -2))
        [110, 108, 106, 104, 102, 100]

    Slice notation may be used as well:

        >>> iterator = map(str, count())
        >>> it = islice_extended(iterator)[10:20:2]
        >>> list(it)
        ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        source = iter(iterable)
        # Without slice arguments the source is passed through unchanged.
        self._iterator = (
            _islice_helper(source, slice(*args)) if args else source
        )

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        if not isinstance(key, slice):
            raise TypeError(
                'islice_extended.__getitem__ argument must be a slice'
            )

        return islice_extended(_islice_helper(self._iterator, key))
2613
2614
def _islice_helper(it, s):
    """Generator yielding the items of iterator *it* selected by slice *s*.

    Unlike :func:`itertools.islice`, negative values of ``s.start``,
    ``s.stop`` and ``s.step`` are handled. Each combination of signs gets its
    own branch below so that only as many items as that case requires are
    buffered.

    Raises ``ValueError`` if ``s.step`` is ``0``. Because this is a generator
    function, the error surfaces when the generator is first advanced.
    """
    start = s.start
    stop = s.stop
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    # A step of None means the default step of 1.
    step = s.step or 1

    if step > 0:
        start = 0 if (start is None) else start

        if start < 0:
            # Consume all but the last -start items
            # (enumerate from 1 so the number stored with the final cached
            # item is the total count of items seen)
            cache = deque(enumerate(it, 1), maxlen=-start)
            len_iter = cache[-1][0] if cache else 0

            # Adjust start to be positive
            i = max(len_iter + start, 0)

            # Adjust stop to be positive
            if stop is None:
                j = len_iter
            elif stop >= 0:
                j = min(stop, len_iter)
            else:
                j = max(len_iter + stop, 0)

            # Slice the cache
            n = j - i
            if n <= 0:
                return

            for index in range(n):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()[1]
                else:
                    # just pop and discard the item
                    cache.popleft()
        elif (stop is not None) and (stop < 0):
            # Advance to the start position
            # (the empty islice skips `start` items without storing them)
            next(islice(it, start, start), None)

            # When stop is negative, we have to carry -stop items while
            # iterating
            cache = deque(islice(it, -stop), maxlen=-stop)

            # Each newly read item proves that the oldest cached item is
            # not among the final -stop items, so that one may be released.
            for index, item in enumerate(it):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
                cache.append(item)
        else:
            # When both start and stop are positive we have the normal case
            yield from islice(it, start, stop, step)
    else:
        # Negative step: the default start is the last item.
        start = -1 if (start is None) else start

        if (stop is not None) and (stop < 0):
            # Consume all but the last items
            n = -stop - 1
            cache = deque(enumerate(it, 1), maxlen=n)
            len_iter = cache[-1][0] if cache else 0

            # If start and stop are both negative they are comparable and
            # we can just slice. Otherwise we can adjust start to be negative
            # and then slice.
            if start < 0:
                i, j = start, stop
            else:
                i, j = min(start - len_iter, -1), None

            for index, item in list(cache)[i:j:step]:
                yield item
        else:
            # Advance to the stop position
            if stop is not None:
                m = stop + 1
                next(islice(it, m, m), None)

            # stop is positive, so if start is negative they are not comparable
            # and we need the rest of the items.
            if start < 0:
                i = start
                n = None
            # stop is None and start is positive, so we just need items up to
            # the start index.
            elif stop is None:
                i = None
                n = start + 1
            # Both stop and start are positive, so they are comparable.
            else:
                i = None
                n = start - stop
                if n <= 0:
                    return

            # Buffer the needed items, then walk the buffer backwards.
            cache = list(islice(it, n))

            yield from cache[i::step]
2721
2722
def always_reversible(iterable):
    """A version of :func:`reversed` that works on any iterable, not only on
    those implementing the ``Reversible`` or ``Sequence`` protocols.

        >>> print(*always_reversible(x for x in range(3)))
        2 1 0

    When :func:`reversed` accepts *iterable*, its result is returned
    directly. Otherwise the remaining items are first collected into a list
    and that list is reversed, which may require significant storage.
    """
    try:
        backwards = reversed(iterable)
    except TypeError:
        # Not natively reversible: materialise what is left, then reverse.
        backwards = reversed(list(iterable))
    return backwards
2739
2740
def consecutive_groups(iterable, ordering=None):
    """Yield runs of consecutive items, built on :func:`itertools.groupby`.
    The *ordering* function maps each item to its position; two neighbouring
    items belong to the same run when their positions differ by one.

    By default *ordering* is the identity function, which is the right
    choice for finding runs of numbers:

        >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
        >>> for group in consecutive_groups(iterable):
        ...     print(list(group))
        [1]
        [10, 11, 12]
        [20]
        [30, 31, 32, 33]
        [40]

    To find runs of adjacent letters, use :func:`ord` to turn the letters
    into ordinals:

        >>> iterable = 'abcdfgilmnop'
        >>> ordering = ord
        >>> for group in consecutive_groups(iterable, ordering):
        ...     print(list(group))
        ['a', 'b', 'c', 'd']
        ['f', 'g']
        ['i']
        ['l', 'm', 'n', 'o', 'p']

    Each yielded group is an iterator that shares its source with
    *iterable*. Once the next group is requested, the previous one is no
    longer available unless its elements were copied (e.g., into a
    ``list``):

        >>> iterable = [1, 2, 11, 12, 21, 22]
        >>> saved_groups = []
        >>> for group in consecutive_groups(iterable):
        ...     saved_groups.append(list(group))  # Copy group elements
        >>> saved_groups
        [[1, 2], [11, 12], [21, 22]]

    """
    # Within a run, index and position grow in lockstep, so their
    # difference stays constant; it changes whenever there is a gap.
    if ordering is None:

        def offset(pair):
            index, item = pair
            return index - item

    else:

        def offset(pair):
            index, item = pair
            return index - ordering(item)

    for _, run in groupby(enumerate(iterable), key=offset):
        # Drop the enumerate() index again before handing the run out.
        yield map(itemgetter(1), run)
2789
2790
def difference(iterable, func=sub, *, initial=None):
    """Undo :func:`itertools.accumulate`. With the default *func*
    (:func:`operator.sub`) this computes the first difference of
    *iterable*:

        >>> from itertools import accumulate
        >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
        >>> list(difference(iterable))
        [0, 1, 2, 3, 4]

    Another two-argument *func* may be supplied. It is applied to each item
    and its predecessor like this::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, progressive division:

        >>> iterable = [1, 2, 6, 24, 120]
        >>> func = lambda x, y: x // y
        >>> list(difference(iterable, func))
        [1, 2, 3, 4, 5]

    If the *initial* keyword is set, the first element is left out of the
    output; only the successive differences are produced.

        >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
        >>> list(difference(it, initial=10))
        [1, 2, 3]

    """
    lagging, leading = tee(iterable)
    try:
        head = next(leading)
    except StopIteration:
        # Nothing to difference.
        return iter([])

    # `leading` is now one item ahead of `lagging`, so mapping over the two
    # pairs every item with its predecessor.
    prefix = [head] if initial is None else []
    return chain(prefix, map(func, leading, lagging))
2831
2832
class SequenceView(Sequence):
    """A read-only view onto the sequence object *target*.

    :class:`SequenceView` objects play the same role for sequences that
    "dictionary view" types play for dictionaries: the view is dynamic, so
    changes made to the underlying sequence show up in the view.

        >>> seq = ['0', '1', '2']
        >>> view = SequenceView(seq)
        >>> view
        SequenceView(['0', '1', '2'])
        >>> seq.append('3')
        >>> view
        SequenceView(['0', '1', '2', '3'])

    Indexing, slicing, and length queries are supported and behave as they
    do on the underlying sequence; assignment is not available:

        >>> view[1]
        '1'
        >>> view[1:-1]
        ['1', '2']
        >>> len(view)
        4

    Views are a cheap alternative to copying, since they need (almost) no
    extra storage.

    Raises ``TypeError`` if *target* is not a ``Sequence``.

    """

    def __init__(self, target):
        if isinstance(target, Sequence):
            self._target = target
        else:
            raise TypeError

    def __getitem__(self, index):
        # Integers and slices alike are delegated to the target.
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}({self._target!r})'
2876
2877
class seekable:
    """Wrap an iterator so that it can be moved backward and forward. The
    items of the source iterable are cached as they are produced, which is
    what makes it possible to visit them again.

    Call :meth:`seek` with an index to jump to that position in the source
    iterable.

    Seeking to ``0`` "resets" the iterator:

        >>> from itertools import count
        >>> it = seekable((str(n) for n in count()))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> it.seek(0)
        >>> next(it), next(it), next(it)
        ('0', '1', '2')

    Seeking forward works too:

        >>> it = seekable((str(n) for n in range(20)))
        >>> it.seek(10)
        >>> next(it)
        '10'
        >>> it.seek(20)  # Seeking past the end of the source isn't a problem
        >>> list(it)
        []
        >>> it.seek(0)  # Resetting works even after hitting the end
        >>> next(it)
        '0'

    Call :meth:`relative_seek` to move relative to the current position of
    the source iterator.

        >>> it = seekable((str(n) for n in range(20)))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> it.relative_seek(2)
        >>> next(it)
        '5'
        >>> it.relative_seek(-3)  # Source is at '6', we move back to '3'
        >>> next(it)
        '3'
        >>> it.relative_seek(-3)  # Source is at '4', we move back to '1'
        >>> next(it)
        '1'


    Call :meth:`peek` to look at the next item without advancing:

        >>> it = seekable('1234')
        >>> it.peek()
        '1'
        >>> list(it)
        ['1', '2', '3', '4']
        >>> it.peek(default='empty')
        'empty'

    :func:`bool` returns ``True`` for as long as items remain, and ``False``
    once the iterator is at its end:

        >>> it = seekable('5678')
        >>> bool(it)
        True
        >>> list(it)
        ['5', '6', '7', '8']
        >>> bool(it)
        False

    The :meth:`elements` method exposes the cache as a
    :class:`SequenceView`, a view that updates automatically:

        >>> it = seekable((str(n) for n in range(10)))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> elements = it.elements()
        >>> elements
        SequenceView(['0', '1', '2'])
        >>> next(it)
        '3'
        >>> elements
        SequenceView(['0', '1', '2', '3'])

    By default the cache keeps growing as the source is consumed, so be
    careful with very large or infinite iterables. Supply *maxlen* to cap
    the cache size (which of course limits how far back you can seek).

        >>> from itertools import count
        >>> it = seekable((str(n) for n in count()), maxlen=2)
        >>> next(it), next(it), next(it), next(it)
        ('0', '1', '2', '3')
        >>> list(it.elements())
        ['2', '3']
        >>> it.seek(0)
        >>> next(it), next(it), next(it), next(it)
        ('2', '3', '4', '5')
        >>> next(it)
        '6'

    """

    def __init__(self, iterable, maxlen=None):
        self._source = iter(iterable)
        # A bounded deque silently drops the oldest items once it is full.
        self._cache = [] if maxlen is None else deque([], maxlen)
        # None means "read from the source"; an int is a cache position.
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        position = self._index
        if position is not None:
            try:
                cached = self._cache[position]
            except IndexError:
                # Ran off the end of the cache: switch back to the source.
                self._index = None
            else:
                self._index = position + 1
                return cached

        fresh = next(self._source)
        self._cache.append(fresh)
        return fresh

    def __bool__(self):
        try:
            self.peek()
        except StopIteration:
            return False
        else:
            return True

    def peek(self, default=_marker):
        try:
            upcoming = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default

        # Step back one position so the peeked item is returned again.
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return upcoming

    def elements(self):
        return SequenceView(self._cache)

    def seek(self, index):
        self._index = index
        # If the target lies beyond the cache, read ahead to fill the gap.
        shortfall = index - len(self._cache)
        if shortfall > 0:
            consume(self, shortfall)

    def relative_seek(self, count):
        if self._index is None:
            self._index = len(self._cache)

        # Never move to a position before the start of the cache.
        target = max(self._index + count, 0)
        self.seek(target)
3037
3038
class run_length:
    """
    :func:`run_length.encode` applies run-length encoding to an iterable.
    Every run of repeated items is reported as a pair of the item and the
    number of times it was repeated:

        >>> uncompressed = 'abbcccdddd'
        >>> list(run_length.encode(uncompressed))
        [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` reverses the process. Given an iterable of
    ``(item, count)`` pairs, it yields the items of the decompressed
    iterable:

        >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
        >>> list(run_length.decode(compressed))
        ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        runs = groupby(iterable)
        return ((value, ilen(members)) for value, members in runs)

    @staticmethod
    def decode(iterable):
        # Each (item, count) pair is expanded with repeat(item, count).
        expanded = starmap(repeat, iterable)
        return chain.from_iterable(expanded)
3066
3067
def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if precisely ``n`` items of the iterable are ``True``
    according to the *predicate* function.

        >>> exactly_n([True, True, False], 2)
        True
        >>> exactly_n([True, True, False], 1)
        False
        >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
        True

    The iterable is advanced until ``n + 1`` truthy items have been found,
    so avoid calling this on infinite iterables.

    """
    matches = filter(predicate, iterable)
    # One match more than n already settles the answer, so stop there.
    capped = islice(matches, n + 1)
    return ilen(capped) == n
3084
3085
def circular_shifts(iterable, steps=1):
    """Yield each distinct circular shift of *iterable* as a tuple.

        >>> list(circular_shifts(range(4)))
        [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    *steps* is the number of places to rotate to the left between
    successive shifts (or to the right if negative). Defaults to 1.

        >>> list(circular_shifts(range(4), 2))
        [(0, 1, 2, 3), (2, 3, 0, 1)]

        >>> list(circular_shifts(range(4), -1))
        [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    Raises ``ValueError`` if *steps* is zero.

    """
    ring = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Rotate one move in the opposite direction first, so that the first
    # move inside the loop reproduces the original order.
    ring.rotate(steps)
    size = len(ring)
    # The shifts start repeating after size / gcd(size, steps) moves.
    distinct = size // math.gcd(size, -steps)

    for _ in range(distinct):
        ring.rotate(-steps)
        yield tuple(ring)
3114
3115
def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function that
    modifies an iterable. *result_index* is the position in that function's
    signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition. This can augment what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

        >>> from more_itertools import chunked
        >>> chunker = make_decorator(chunked, result_index=0)
        >>> @chunker(3)
        ... def iter_range(n):
        ...     return iter(range(n))
        ...
        >>> list(iter_range(9))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

        >>> truth_serum = make_decorator(filter, result_index=1)
        >>> @truth_serum(bool)
        ... def boolean_test():
        ...     return [0, 1, '', ' ', False, True]
        ...
        >>> list(boolean_test())
        [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators:

        >>> from more_itertools import peekable
        >>> peekable_function = make_decorator(peekable)
        >>> @peekable_function()
        ... def str_range(*args):
        ...     return (str(x) for x in range(*args))
        ...
        >>> it = str_range(1, 20, 2)
        >>> next(it), next(it), next(it)
        ('1', '3', '5')
        >>> it.peek()
        '7'
        >>> next(it)
        '7'

    The decorated function keeps its metadata (``__name__``, ``__doc__``,
    ...), and the undecorated function stays reachable as ``__wrapped__``:

        >>> str_range.__name__
        'str_range'

    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        # First level: capture the extra arguments for *wrapping_func*.
        def outer_wrapper(f):
            # Second level: receive the function being decorated.
            @wraps(f)
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                # Copy before inserting so the captured arguments are not
                # modified from one call to the next.
                wrapping_args_ = list(wrapping_args)
                wrapping_args_.insert(result_index, result)
                return wrapping_func(*wrapping_args_, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator
3180
3181
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Build a dictionary that sorts the items of *iterable* into categories
    given by *keyfunc*, transforms them with *valuefunc*, and then
    summarizes each category with *reducefunc*.

    *valuefunc* defaults to the identity function if it is unspecified.
    If *reducefunc* is unspecified, no summarization takes place:

        >>> keyfunc = lambda x: x.upper()
        >>> result = map_reduce('abbccc', keyfunc)
        >>> sorted(result.items())
        [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Specifying *valuefunc* transforms the categorized items:

        >>> keyfunc = lambda x: x.upper()
        >>> valuefunc = lambda x: 1
        >>> result = map_reduce('abbccc', keyfunc, valuefunc)
        >>> sorted(result.items())
        [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Specifying *reducefunc* summarizes the categorized items:

        >>> keyfunc = lambda x: x.upper()
        >>> valuefunc = lambda x: 1
        >>> reducefunc = sum
        >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
        >>> sorted(result.items())
        [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the map/reduce
    procedure:

        >>> all_items = range(30)
        >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
        >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
        >>> categories = map_reduce(items, keyfunc=keyfunc)
        >>> sorted(categories.items())
        [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
        >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
        >>> sorted(summaries.items())
        [(0, 90), (1, 75)]

    Note that every item of the iterable is collected into a list before
    the summarization step runs, which may require significant storage.

    The returned object is a :obj:`collections.defaultdict` with the
    ``default_factory`` set to ``None``, such that it behaves like a normal
    dictionary.

    """
    buckets = defaultdict(list)

    for item in iterable:
        key = keyfunc(item)
        value = item if valuefunc is None else valuefunc(item)
        buckets[key].append(value)

    if reducefunc is not None:
        # Only existing keys are reassigned, so iterating here is safe.
        for key in buckets:
            buckets[key] = reducefunc(buckets[key])

    # With the factory switched off, missing keys raise KeyError as usual.
    buckets.default_factory = None
    return buckets
3253
3254
def rlocate(iterable, pred=bool, window_size=None):
    """Yield the indexes of the items in *iterable* for which *pred* returns
    ``True``, working from the right end toward the left.

    *pred* defaults to :func:`bool`, which selects truthy items:

        >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
        [4, 2, 1]

    Pass a custom *pred* to, e.g., find the indexes of a particular item:

        >>> iterator = iter('abcb')
        >>> pred = lambda x: x == 'b'
        >>> list(rlocate(iterator, pred))
        [3, 1]

    If *window_size* is given, *pred* is called with that many items at a
    time, which allows searching for sub-sequences:

        >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
        >>> pred = lambda *args: args == (1, 2, 3)
        >>> list(rlocate(iterable, pred=pred, window_size=3))
        [9, 5, 1]

    Beware, this function won't return anything for infinite iterables.
    When no *window_size* is given and *iterable* supports both ``len()``
    and ``reversed()``, it is searched from the right. Otherwise it is
    searched from the left and the collected results are reversed.

    See :func:`locate` for other example applications.

    """
    if window_size is None:
        # Both len() and reversed() may raise TypeError; in that case fall
        # through to the generic path below.
        with suppress(TypeError):
            total = len(iterable)
            return (
                total - pos - 1 for pos in locate(reversed(iterable), pred)
            )

    matches = list(locate(iterable, pred, window_size))
    return reversed(matches)
3296
3297
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items from *iterable*, with every item for which *pred*
    returns ``True`` replaced by the items of the iterable *substitutes*.

        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
        >>> pred = lambda x: x == 0
        >>> substitutes = (2, 3)
        >>> list(replace(iterable, pred, substitutes))
        [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    *count*, if given, caps the number of replacements:

        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
        >>> pred = lambda x: x == 0
        >>> substitutes = [None]
        >>> list(replace(iterable, pred, substitutes, count=2))
        [1, 1, None, 1, 1, None, 1, 1, 0]

    *window_size* sets how many consecutive items are passed to *pred* as
    positional arguments, which allows whole subsequences to be located and
    replaced:

        >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
        >>> window_size = 3
        >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
        >>> substitutes = [3, 4]  # Splice in these items
        >>> list(replace(iterable, pred, substitutes, window_size=window_size))
        [3, 4, 5, 3, 4, 5]

    ``ValueError`` is raised if *window_size* is less than 1.

    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # The substitutes may be emitted many times, so keep them in a tuple.
    replacement = tuple(substitutes)

    # Pad the tail so that there is exactly one window per input item.
    overlap = window_size - 1
    padded = chain(iterable, repeat(_marker, overlap))
    windows = windowed(padded, window_size)

    replaced = 0
    for window in windows:
        # On a match (while under the replacement limit), emit the
        # substitutes and skip the following windows that overlap this one.
        # E.g. for (0, 1, 2, 3, ...) with a window size of 2 the windows are
        # (0, 1), (1, 2), (2, 3), ...; a match on (0, 1) must also discard
        # (1, 2).
        if pred(*window) and ((count is None) or (replaced < count)):
            replaced += 1
            yield from replacement
            consume(windows, overlap)
            continue

        # No replacement happened: pass through the window's first item,
        # unless it is padding.
        if window and (window[0] is not _marker):
            yield window[0]
3357
3358
def partitions(iterable):
    """Yield every way of splitting *iterable* into consecutive,
    order-preserving pieces.

        >>> iterable = 'abc'
        >>> for part in partitions(iterable):
        ...     print([''.join(p) for p in part])
        ['abc']
        ['a', 'bc']
        ['ab', 'c']
        ['a', 'b', 'c']

    Each partition is a list of lists. This is unrelated to
    :func:`partition`.

    """
    items = list(iterable)
    size = len(items)
    # Each subset of the interior positions 1..size-1 is one choice of cut
    # points.
    for cuts in powerset(range(1, size)):
        starts = (0,) + cuts
        stops = cuts + (size,)
        yield [items[lo:hi] for lo, hi in zip(starts, stops)]
3377
3378
def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """
    Yield the set partitions of *iterable* into *k* parts. Set partitions
    are not order-preserving.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable, 2):
        ...     print([''.join(p) for p in part])
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']

    If *k* is omitted, the partitions for every number of parts from 1 up to
    the number of items are generated in turn.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable):
        ...     print([''.join(p) for p in part])
        ['abc']
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']
        ['a', 'b', 'c']

    *min_size* and *max_size* restrict the results to partitions whose
    blocks all have a length within those (inclusive) bounds.

        >>> iterable = 'abc'
        >>> for part in set_partitions(iterable, min_size=2):
        ...     print([''.join(p) for p in part])
        ['abc']
        >>> for part in set_partitions(iterable, max_size=2):
        ...     print([''.join(p) for p in part])
        ['a', 'bc']
        ['ab', 'c']
        ['b', 'ac']
        ['a', 'b', 'c']

    ``ValueError`` is raised if *k* is less than 1. Nothing is yielded if *k*
    exceeds the number of items or if *min_size* is greater than *max_size*.

    """
    items = list(iterable)
    total = len(items)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        if k > total:
            return

    smallest = 0 if min_size is None else min_size
    largest = total if max_size is None else max_size
    if smallest > largest:
        return

    def split(seq, parts):
        # Recursively yield the partitions of *seq* into exactly *parts*
        # blocks.
        if parts == 1:
            yield [seq]
        elif len(seq) == parts:
            yield [[member] for member in seq]
        else:
            head, *tail = seq
            # Either the first item forms a block of its own...
            for partition in split(tail, parts - 1):
                yield [[head], *partition]
            # ...or it joins one of the blocks of a partition of the rest.
            for partition in split(tail, parts):
                for pos, block in enumerate(partition):
                    yield [
                        *partition[:pos],
                        [head, *block],
                        *partition[pos + 1 :],
                    ]

    def sizes_ok(partition):
        return all(smallest <= len(block) <= largest for block in partition)

    part_counts = range(1, total + 1) if k is None else (k,)
    for parts in part_counts:
        yield from filter(sizes_ok, split(items, parts))
3458
3459
class time_limited:
    """
    Yield items from *iterable* until *limit_seconds* have passed.
    If the time limit expires before all items have been yielded, the
    ``timed_out`` attribute is set to ``True``.

        >>> from time import sleep
        >>> def generator():
        ...     yield 1
        ...     yield 2
        ...     sleep(0.2)
        ...     yield 3
        >>> iterable = time_limited(0.1, generator())
        >>> list(iterable)
        [1, 2]
        >>> iterable.timed_out
        True

    The clock starts when the object is created. The elapsed time is checked
    after an item has been fetched from *iterable* and before it is handed
    out; iteration stops if the elapsed time is greater than
    *limit_seconds*. If your time limit is 1 second, but it takes 2 seconds
    to generate the first item, the iterator will run for 2 seconds and not
    yield anything. As a special case, when *limit_seconds* is zero, nothing
    is fetched or yielded at all.

    ``ValueError`` is raised if *limit_seconds* is negative.

    """

    def __init__(self, limit_seconds, iterable):
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be positive')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        limit = self.limit_seconds
        if limit != 0:
            # StopIteration from the source propagates without setting
            # ``timed_out``.
            item = next(self._iterator)
            if not (monotonic() - self._start_time > limit):
                return item

        # Either the limit is zero or the deadline has passed.
        self.timed_out = True
        raise StopIteration
3508
3509
def only(iterable, default=None, too_long=None):
    """Return the single item of *iterable*.

    If *iterable* is empty, return *default*. If it has more than one item,
    raise the exception given by *too_long*, or ``ValueError`` if *too_long*
    is not given.

        >>> only([], default='missing')
        'missing'
        >>> only([1])
        1
        >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        ValueError: Expected exactly one item in iterable, but got 1, 2,
         and perhaps more.
        >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        TypeError

    Note that :func:`only` attempts to advance *iterable* twice to ensure
    there is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    iterator = iter(iterable)
    # A private sentinel, so that any value (including None) may be an item.
    missing = object()

    first = next(iterator, missing)
    if first is missing:
        return default

    second = next(iterator, missing)
    if second is missing:
        return first

    msg = (
        f'Expected exactly one item in iterable, but got {first!r}, '
        f'{second!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
3545
3546
def _ichunk(iterator, n):
    """Return ``(chunk, materialize_next)`` for the next *n* items of
    *iterator*.

    ``chunk`` is a generator over those items. ``materialize_next(m)`` pulls
    items from *iterator* into an internal cache, which ``chunk`` drains
    before reading further, so that *iterator* can be advanced past this
    chunk without losing the chunk's items.
    """
    buffered = deque()
    window = islice(iterator, n)

    def generator():
        # Serve cached items first, then read directly from the window; the
        # StopIteration raised by an exhausted window ends the generator.
        with suppress(StopIteration):
            while True:
                yield buffered.popleft() if buffered else next(window)

    def materialize_next(n=1):
        # With n=None, cache everything that is left in the window and
        # report the cache size.
        if n is None:
            buffered.extend(window)
            return len(buffered)

        # Otherwise top the cache up to n items, if it is short...
        shortfall = n - len(buffered)
        if shortfall > 0:
            buffered.extend(islice(window, shortfall))

        # ...and report how many items (at most n) are cached.
        return min(n, len(buffered))

    return (generator(), materialize_next)
3575
3576
def ichunked(iterable, n):
    """Break *iterable* into sub-iterables with *n* elements each.
    :func:`ichunked` is like :func:`chunked`, but it yields iterables
    instead of lists.

    If the sub-iterables are read in order, the elements of *iterable*
    won't be stored in memory.
    If they are read out of order, the elements of earlier chunks are
    cached as necessary.

        >>> from itertools import count
        >>> all_chunks = ichunked(count(), 4)
        >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
        >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
        [4, 5, 6, 7]
        >>> list(c_1)
        [0, 1, 2, 3]
        >>> list(c_3)
        [8, 9, 10, 11]

    """
    source = iter(iterable)
    while True:
        chunk, materialize = _ichunk(source, n)

        # Try to cache one element; getting none means the source is
        # exhausted and there is no further chunk to hand out.
        if not materialize():
            return

        yield chunk

        # Before moving on, cache whatever the consumer left unread in the
        # chunk that was just handed out.
        materialize(None)
3611
3612
def iequals(*iterables):
    """Return ``True`` if all given *iterables* are equal to each other,
    meaning that they contain the same elements in the same order.

    This is useful for comparing iterables of different data types, or
    iterables that do not support equality checks.

        >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
        True

        >>> iequals("abc", "acb")
        False

    Iterables of differing lengths are not equal.

    Not to be confused with :func:`all_equal`, which checks whether all
    elements of one iterable are equal to each other.

    """
    try:
        # strict=True makes zip raise ValueError on a length mismatch.
        for column in zip(*iterables, strict=True):
            if not all_equal(column):
                return False
    except ValueError:
        return False
    return True
3634
3635
def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

        >>> list(distinct_combinations([0, 0, 1], 2))
        [(0, 0), (0, 1)]

    Equivalent to ``set(combinations(iterable, r))``, except duplicates are
    not generated and thrown away. For larger input sequences this is much
    more efficient.

    ``ValueError`` is raised if *r* is negative.

    """
    if r < 0:
        raise ValueError('r must be non-negative')
    if r == 0:
        yield ()
        return

    pool = tuple(iterable)
    by_value = itemgetter(1)

    def candidates(start):
        # (index, item) pairs from position *start* on, one per distinct
        # item.
        return unique_everseen(enumerate(pool[start:], start), key=by_value)

    # Depth-first search with an explicit stack: stack[d] supplies the
    # choices for slot d of the combination being built.
    stack = [candidates(0)]
    combo = [None] * r
    while stack:
        pair = next(stack[-1], None)
        if pair is None:
            # This slot has no choices left; back up to the previous slot.
            stack.pop()
            continue

        position, value = pair
        depth = len(stack) - 1
        combo[depth] = value
        if depth + 1 == r:
            yield tuple(combo)
        else:
            # Later slots only draw from items after the one just chosen.
            stack.append(candidates(position + 1))
3674
3675
def filter_except(validator, iterable, *exceptions):
    """Yield the items from *iterable* for which calling *validator* does
    not raise one of the given *exceptions*.

    *validator* is called once per item with that item as its only argument;
    it should raise an exception to signal that the item is not valid. Its
    return value is ignored.

        >>> iterable = ['1', '2', 'three', '4', None]
        >>> list(filter_except(int, iterable, ValueError, TypeError))
        ['1', '2', '4']

    An exception raised by *validator* that is not among *exceptions*
    propagates as usual.
    """
    for candidate in iterable:
        try:
            validator(candidate)
        except exceptions:
            continue
        yield candidate
3698
3699
def map_except(function, iterable, *exceptions):
    """Yield ``function(item)`` for each item in *iterable*, skipping the
    items for which *function* raises one of the given *exceptions*.

    *function* is called once per item and should accept one argument.

        >>> iterable = ['1', '2', 'three', '4', None]
        >>> list(map_except(int, iterable, ValueError, TypeError))
        [1, 2, 4]

    An exception raised by *function* that is not among *exceptions*
    propagates as usual.
    """
    for element in iterable:
        try:
            yield function(element)
        except exceptions:
            continue
3719
3720
def map_if(iterable, pred, func, func_else=None):
    """Yield each item of *iterable* transformed by *func* if *pred* returns
    a true value for it, and by *func_else* otherwise.

    *pred*, *func*, and *func_else* should each accept one argument. When
    *func_else* is omitted, items that fail *pred* are yielded unchanged.

        >>> from math import sqrt
        >>> iterable = list(range(-5, 5))
        >>> iterable
        [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
        >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
        [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
        >>> list(map_if(iterable, lambda x: x >= 0,
        ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
        [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
    """
    if func_else is None:
        for element in iterable:
            if pred(element):
                yield func(element)
            else:
                yield element
    else:
        for element in iterable:
            if pred(element):
                yield func(element)
            else:
                yield func_else(element)
3747
3748
def _sample_unweighted(iterator, k, strict):
    """Return a shuffled list of up to *k* items sampled from *iterator*.

    Raises ``ValueError`` if *strict* is true and fewer than *k* items are
    available.
    """
    # Algorithm L in the 1994 paper by Kim-Hung Li:
    # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".

    # Start with the first k items as the reservoir.
    reservoir = list(islice(iterator, k))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    W = 1.0
    exponent = 1 / k
    try:
        while True:
            W *= random() ** exponent
            skip = floor(log(random()) / log1p(-W))
            # Discard `skip` items and take the one after them; the source
            # running dry raises StopIteration, which ends the loop.
            # The right-hand side is evaluated before randrange() is called.
            reservoir[randrange(k)] = next(islice(iterator, skip, None))
    except StopIteration:
        pass

    shuffle(reservoir)
    return reservoir
3767
3768
def _sample_weighted(iterator, k, weights, strict):
    """Return a shuffled list of up to *k* items sampled from *iterator*,
    with the items paired positionally with the values from *weights*.

    *iterator* and *weights* must both be iterators, since each is consumed
    in two stages. Raises ``ValueError`` if *strict* is true and fewer than
    *k* (weight, item) pairs are available. With *strict* false, an empty
    population yields an empty list.
    """
    # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
    # "Weighted random sampling with a reservoir".

    # Log-transform for numerical stability for weights that are small/large
    weight_keys = (log(random()) / weight for weight in weights)

    # Fill up the reservoir (collection of samples) with the first `k`
    # weight-keys and elements, then heapify the list.
    reservoir = list(islice(zip(weight_keys, iterator), k))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Nothing to sample from: without this guard the `reservoir[0]` lookup
    # below would raise IndexError for an empty population.
    if not reservoir:
        return []

    heapify(reservoir)

    # The number of jumps before changing the reservoir is a random variable
    # with an exponential distribution. Sample it using random() and logs.
    smallest_weight_key, _ = reservoir[0]
    weights_to_skip = log(random()) / smallest_weight_key

    for weight, element in zip(weights, iterator):
        if weight >= weights_to_skip:
            # The notation here is consistent with the paper, but we store
            # the weight-keys in log-space for better numerical stability.
            smallest_weight_key, _ = reservoir[0]
            t_w = exp(weight * smallest_weight_key)
            r_2 = uniform(t_w, 1)  # generate U(t_w, 1)
            weight_key = log(r_2) / weight
            # Evict the entry with the smallest key in favor of this one.
            heapreplace(reservoir, (weight_key, element))
            smallest_weight_key, _ = reservoir[0]
            weights_to_skip = log(random()) / smallest_weight_key
        else:
            weights_to_skip -= weight

    ret = [element for weight_key, element in reservoir]
    shuffle(ret)
    return ret
3806
3807
def _sample_counted(population, k, counts, strict):
    """Return a shuffled list of up to *k* items sampled from *population*,
    where each item is treated as repeated by the matching value of *counts*.

    *population* and *counts* must both be iterators. Raises ``ValueError``
    if *strict* is true and fewer than *k* (repeated) items are available.
    """
    current = None  # the item currently being repeated
    left = 0  # how many repeats of `current` have not been passed yet

    def feed(steps):
        # Skip *steps* (virtual) items and return the one after them.
        nonlocal current, left

        while steps + 1 > left:
            steps = steps - left
            current = next(population)
            left = next(counts)
        left -= steps + 1
        return current

    # Fill the reservoir with the first k virtual items, or fewer if the
    # input runs out first.
    reservoir = []
    try:
        for _ in range(k):
            reservoir.append(feed(0))
    except StopIteration:
        pass

    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Same skip-and-replace loop as the unweighted sampler, advancing
    # through the virtual items with feed(); it ends when the input runs
    # out.
    W = 1.0
    try:
        while True:
            W *= random() ** (1 / k)
            skip = floor(log(random()) / log1p(-W))
            # feed() runs before randrange(), as the right-hand side is
            # evaluated first.
            reservoir[randrange(k)] = feed(skip)
    except StopIteration:
        pass

    shuffle(reservoir)
    return reservoir
3841
3842
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a *k*-length list of elements chosen (without replacement)
    from the *iterable*.

    Similar to :func:`random.sample`, but works on inputs that aren't
    indexable (such as sets and dictionaries) and on inputs where the
    size isn't known in advance (such as generators).

        >>> iterable = range(100)
        >>> sample(iterable, 5)  # doctest: +SKIP
        [81, 60, 96, 16, 4]

    For iterables with repeated elements, you may supply *counts* to
    indicate the repeats.

        >>> iterable = ['a', 'b']
        >>> counts = [3, 4]  # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
        >>> sample(iterable, k=3, counts=counts)  # doctest: +SKIP
        ['a', 'a', 'b']

    An iterable with *weights* may be given:

        >>> iterable = range(100)
        >>> weights = (i * i + 1 for i in range(100))
        >>> sampled = sample(iterable, 5, weights=weights)  # doctest: +SKIP
        [79, 67, 74, 66, 78]

    Weighted selections are made without replacement.
    After an element is selected, it is removed from the pool and the
    relative weights of the other elements increase (this
    does not match the behavior of :func:`random.sample`'s *counts*
    parameter). *weights* may not be used together with *counts*; doing so
    raises ``TypeError``.

    ``ValueError`` is raised if *k* is negative. If *k* is zero, an empty
    list is returned. If the length of *iterable* is less than *k*,
    ``ValueError`` is raised if *strict* is ``True`` and
    all elements are returned (in shuffled order) if *strict* is ``False``.

    By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
    technique is used. When *weights* are provided,
    `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used instead.

    Notes on reproducibility:

    * The algorithms rely on inexact floating-point functions provided
      by the underlying math library (e.g. ``log``, ``log1p``, and ``pow``).
      Those functions can `produce slightly different results
      <https://members.loria.fr/PZimmermann/papers/accuracy.pdf>`_ on
      different builds. Accordingly, selections can vary across builds
      even for the same seed.

    * The algorithms loop over the input and make selections based on
      ordinal position, so selections from unordered collections (such as
      sets) won't reproduce across sessions on the same platform using the
      same seed. For example, this won't reproduce::

        >> seed(8675309)
        >> sample(set('abcdefghijklmnopqrstuvwxyz'), 10)
        ['c', 'p', 'e', 'w', 's', 'a', 'j', 'd', 'n', 't']

    """
    iterator = iter(iterable)

    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []

    # Dispatch on which of the optional weights/counts were supplied.
    if weights is None:
        if counts is None:
            return _sample_unweighted(iterator, k, strict)
        return _sample_counted(iterator, k, iter(counts), strict)

    if counts is not None:
        raise TypeError('weights and counts are mutually exclusive')
    return _sample_weighted(iterator, k, iter(weights), strict)
3924
3925
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Returns ``True`` if the items of iterable are in sorted order, and
    ``False`` otherwise. *key* and *reverse* have the same meaning that they
    do in the built-in :func:`sorted` function.

        >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
        True
        >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
        False

    If *strict*, tests for strict sorting, that is, returns ``False`` if
    equal elements are found:

        >>> is_sorted([1, 2, 2])
        True
        >>> is_sorted([1, 2, 2], strict=True)
        False

    The function returns ``False`` after encountering the first out-of-order
    item, which means it may produce results that differ from the built-in
    :func:`sorted` function for objects with unusual comparison dynamics
    (like ``math.nan``). If there are no out-of-order items, the iterable is
    exhausted.
    """
    values = iterable if key is None else map(key, iterable)

    # Two views of the same stream, the second running one item ahead, so
    # that walking them in lockstep gives each pair of neighbors.
    earlier, later = tee(values)
    next(later, None)

    # `low` should compare below `high` for the requested direction.
    if reverse:
        low, high = later, earlier
    else:
        low, high = earlier, later

    if strict:
        # Every neighbor pair must be strictly increasing.
        return all(map(lt, low, high))
    # No neighbor pair may be strictly decreasing.
    return not any(map(lt, high, low))
3956
3957
class AbortThread(BaseException):
    """Raised by the callback that :class:`callback_iter` passes to its
    wrapped function, once the ``callback_iter`` has been exited.

    Derives from ``BaseException`` rather than ``Exception``.
    """
3960
3961
class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

        >>> def func(callback=None):
        ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
        ...         if callback:
        ...             callback(i, c)
        ...     return 4


    Use ``with callback_iter(func)`` to get an iterator over the parameters
    that are delivered to the callback. Each item is an ``(args, kwargs)``
    pair.

        >>> with callback_iter(func) as it:
        ...     for args, kwargs in it:
        ...         print(args)
        (1, 'a')
        (2, 'b')
        (3, 'c')

    The function will be called in a background thread. The ``done``
    property indicates whether it has completed execution.

        >>> it.done
        True

    If it completes successfully, its return value will be available
    in the ``result`` property.

        >>> it.result
        4

    Notes:

    * If the function uses some keyword argument besides ``callback``,
      supply *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property
      from outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently it is polled for
      output.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        # Imported here so that concurrent.futures is only loaded when a
        # callback_iter is actually created.
        from concurrent.futures import ThreadPoolExecutor

        self._func = func
        self._callback_kwd = callback_kwd
        self._aborted = False
        self._future = None
        self._wait_seconds = wait_seconds
        self._executor = ThreadPoolExecutor(max_workers=1)
        # A generator: the function is not submitted until the first item
        # is requested.
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # From now on the callback raises AbortThread; shutdown() then waits
        # for the worker to finish.
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    @property
    def done(self):
        future = self._future
        return False if future is None else future.done()

    @property
    def result(self):
        if self.done:
            return self._future.result()
        raise RuntimeError('Function has not yet completed')

    def _reader(self):
        q = Queue()

        def callback(*args, **kwargs):
            if self._aborted:
                raise AbortThread('canceled by user')

            q.put((args, kwargs))

        call_kwargs = {self._callback_kwd: callback}
        self._future = self._executor.submit(self._func, **call_kwargs)

        # Poll the queue, waiting at most `wait_seconds` per attempt, until
        # the function has finished.
        while True:
            try:
                item = q.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                q.task_done()
                yield item

            if self._future.done():
                break

        # Collect whatever was queued between the last poll and completion.
        leftovers = []
        while True:
            try:
                item = q.get_nowait()
            except Empty:
                break
            q.task_done()
            leftovers.append(item)
        q.join()
        yield from leftovers
4087
4088
def windowed_complete(iterable, n):
    """
    Yield ``(beginning, middle, end)`` tuples, where:

    * Each ``middle`` has *n* items from *iterable*
    * Each ``beginning`` has the items before the ones in ``middle``
    * Each ``end`` has the items after the ones in ``middle``

        >>> iterable = range(7)
        >>> n = 3
        >>> for beginning, middle, end in windowed_complete(iterable, n):
        ...     print(beginning, middle, end)
        () (0, 1, 2) (3, 4, 5, 6)
        (0,) (1, 2, 3) (4, 5, 6)
        (0, 1) (2, 3, 4) (5, 6)
        (0, 1, 2) (3, 4, 5) (6,)
        (0, 1, 2, 3) (4, 5, 6) ()

    Note that *n* must be at least 0 and at most equal to the length of
    *iterable*; otherwise ``ValueError`` is raised.

    This function will exhaust the iterable and may require significant
    storage.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    seq = tuple(iterable)
    size = len(seq)

    if n > size:
        raise ValueError('n must be <= len(seq)')

    for start in range(size - n + 1):
        stop = start + n
        yield seq[:start], seq[start:stop], seq[stop:]
4127
4128
def all_unique(iterable, key=None):
    """
    Returns ``True`` if all the elements of *iterable* are unique (no two
    elements are equal).

        >>> all_unique('ABCB')
        False

    If a *key* function is specified, it will be used to make comparisons.

        >>> all_unique('ABCb')
        True
        >>> all_unique('ABCb', str.lower)
        False

    The function returns as soon as the first non-unique element is
    encountered. Iterables with a mix of hashable and unhashable items can
    be used, but the function will be slower for unhashable items.
    """
    hashed = set()
    unhashed = []
    for element in map(key, iterable) if key else iterable:
        try:
            if element in hashed:
                return False
            hashed.add(element)
        except TypeError:
            # Unhashable elements fall back to a linear search in a list.
            if element in unhashed:
                return False
            unhashed.append(element)
    return True
4162
4163
def nth_product(index, *args):
    """Equivalent to ``list(product(*args))[index]``.

    The products of *args* can be ordered lexicographically.
    :func:`nth_product` computes the product at sort position *index*
    without computing the previous products.

        >>> nth_product(8, range(2), range(2), range(2), range(2))
        (1, 0, 0, 0)

    Negative values of *index* count from the end. When no iterables are
    given there is exactly one product, the empty tuple:

        >>> nth_product(0)
        ()

    ``IndexError`` will be raised if the given *index* is invalid.
    """
    # Work from the last iterable to the first: the last one varies fastest
    # and so corresponds to the least significant "digit" of the index.
    pools = list(map(tuple, reversed(args)))
    ns = list(map(len, pools))

    # Total number of products. The initial value of 1 covers the case of no
    # iterables, where reduce() would otherwise raise TypeError.
    c = reduce(mul, ns, 1)

    if index < 0:
        index += c

    if not 0 <= index < c:
        raise IndexError

    # Peel off one mixed-radix digit per pool.
    result = []
    for pool, n in zip(pools, ns):
        result.append(pool[index % n])
        index //= n

    return tuple(reversed(result))
4193
4194
def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]``

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`nth_permutation`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences.

        >>> nth_permutation('ghijk', 2, 5)
        ('h', 'i')

    If *r* is ``None``, the length of *iterable* is used. Negative values of
    *index* count from the end.

    ``ValueError`` will be raised if *r* is negative or greater than the
    length of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = list(iterable)
    n = len(pool)

    if r is None or r == n:
        r = n
        total = factorial(n)
    elif 0 <= r < n:
        total = perm(n, r)
    else:
        raise ValueError
    assert total > 0  # factorial(n)>0, and r<n so perm(n,r) is never zero

    if index < 0:
        index += total

    if not 0 <= index < total:
        raise IndexError

    picks = [0] * r
    # For r < n, scale the index up to the matching position among the full
    # length-n permutations.
    quotient = index if r == n else index * factorial(n) // total
    # Successive divmod by 1, 2, ..., n produces the digits, least
    # significant first; only the digits for the first r slots are kept.
    for radix in range(1, n + 1):
        quotient, digit = divmod(quotient, radix)
        slot = n - radix
        if 0 <= slot < r:
            picks[slot] = digit
        if quotient == 0:
            break

    # Each digit is a position in the pool of items not yet taken.
    return tuple(pool.pop(pick) for pick in picks)
4237
4238
def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.


    The subsequences with repetition of *iterable* that are of length *r*
    can be ordered lexicographically.
    :func:`nth_combination_with_replacement` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences with replacement.

        >>> nth_combination_with_replacement(range(5), 3, 5)
        (0, 1, 1)

    Negative values of *index* count from the end. With *r* equal to zero
    the only combination is the empty tuple, including for an empty
    *iterable*:

        >>> nth_combination_with_replacement([], 0, 0)
        ()

    ``ValueError`` will be raised if *r* is negative or greater than the
    length of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if (r < 0) or (r > n):
        raise ValueError

    # Total number of combinations. For r == 0 there is exactly one (the
    # empty tuple); computing it directly avoids comb(-1, 0), which raises
    # ValueError when the pool is empty.
    c = comb(n + r - 1, r) if r else 1

    if index < 0:
        index += c

    if (index < 0) or (index >= c):
        raise IndexError

    result = []
    i = 0  # position in the pool of the item currently being considered
    while r:
        r -= 1
        while n >= 0:
            # Number of ways to complete the combination if pool[i] is
            # chosen for the current slot.
            num_combs = comb(n + r - 1, r)
            if index < num_combs:
                break
            # The target lies beyond all of those: move to the next item.
            n -= 1
            i += 1
            index -= num_combs
        result.append(pool[i])

    return tuple(result)
4283
4284
def value_chain(*args):
    """Yield all arguments passed to the function in the same order in which
    they were passed. If an argument itself is iterable then iterate over its
    values.

        >>> list(value_chain(1, 2, 3, [4, 5, 6]))
        [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and are emitted
    as-is:

        >>> list(value_chain('12', '34', ['56', '78']))
        ['12', '34', '56', '78']

    Pre- or postpend a single element to an iterable:

        >>> list(value_chain(1, [2, 3, 4, 5, 6]))
        [1, 2, 3, 4, 5, 6]
        >>> list(value_chain([1, 2, 3, 4, 5], 6))
        [1, 2, 3, 4, 5, 6]

    Multiple levels of nesting are not flattened.

    A ``TypeError`` raised while an iterable argument is being consumed is
    propagated to the caller; it is not taken as a sign that the argument
    is a scalar.

    """
    for value in args:
        if isinstance(value, (str, bytes)):
            yield value
            continue
        # Only the iter() call decides whether the argument is iterable.
        # Keeping the delegation outside the try block means an error raised
        # part-way through iteration is not mistaken for "not iterable".
        try:
            iterator = iter(value)
        except TypeError:
            yield value
        else:
            yield from iterator
4317
4318
def product_index(element, *args):
    """Equivalent to ``list(product(*args)).index(element)``

    The products of *args* can be ordered lexicographically.
    :func:`product_index` computes the first index of *element* without
    computing the previous products.

        >>> product_index([8, 2], range(10), range(5))
        42

    *element* may be any iterable, including a one-shot iterator.

    ``ValueError`` will be raised if the given *element* isn't in the product
    of *args*.
    """
    elements = tuple(element)
    pools = tuple(map(tuple, args))
    # Compare the materialised tuples: *element* itself may be an iterator
    # that does not support len().
    if len(elements) != len(pools):
        raise ValueError('element is not a product of args')

    # Mixed-radix accumulation: each pool contributes one digit whose base
    # is the pool's length. tuple.index() raises ValueError for a missing
    # item.
    index = 0
    for elem, pool in zip(elements, pools):
        index = index * len(pool) + pool.index(elem)

    return index
4342
4343
def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``,
    where *r* is the length of *element*.

    Combinations of a fixed length can be ordered lexicographically by the
    positions of their items in *iterable*. :func:`combination_index`
    computes the index of the first occurrence of *element* in that
    ordering without generating the combinations that come before it.

        >>> combination_index('adf', 'abcdefg')
        10

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations of *iterable*.
    """
    targets = enumerate(element)
    k, wanted = next(targets, (None, None))
    if k is None:
        # Nothing to look for: the empty combination sits at index 0.
        return 0

    # Walk the pool once, recording where each successive item of the
    # element is found. ``k`` ends up as the last position in the element.
    positions = []
    pool = enumerate(iterable)
    for n, candidate in pool:
        if candidate == wanted:
            positions.append(n)
            following, wanted = next(targets, (None, None))
            if following is None:
                break
            k = following
    else:
        raise ValueError('element is not a combination of iterable')

    # Drain whatever is left of the pool so that ``n`` is its last position.
    n, _ = last(pool, default=(n, None))

    # Subtract the accumulated binomial terms (plus one) from the total
    # number of combinations, comb(n + 1, k + 1).
    offset = 1
    for i, position in enumerate(reversed(positions), start=1):
        distance = n - position
        if i <= distance:
            offset += comb(distance, i)

    return comb(n + 1, k + 1) - offset
4384
4385
def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``,
    where *r* is the length of *element*.

    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`combination_with_replacement_index`
    computes the index of the first *element*, without computing the previous
    combinations with replacement.

        >>> combination_with_replacement_index('adf', 'abcdefg')
        20

    Items are matched by position rather than by a ``None`` sentinel, so
    ``None`` is handled like any other item.

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations with replacement of *iterable*.
    """
    element = tuple(element)
    size = len(element)
    if not size:
        # The empty selection is the only one of length zero.
        return 0

    pool = tuple(iterable)
    n = len(pool)

    # occupations[p] counts how many items of the element were matched to
    # pool position p. The element is consumed left to right, and a single
    # pool item may absorb several consecutive element items.
    occupations = [0] * n
    matched = 0
    for position, x in enumerate(pool):
        while matched < size and x == element[matched]:
            occupations[position] += 1
            matched += 1
        if matched == size:
            break
    else:
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    index = 0
    cumulative_sum = 0
    for k in range(1, n):
        cumulative_sum += occupations[k - 1]
        j = size + n - 1 - k - cumulative_sum
        i = n - k
        if i <= j:
            index += comb(j, i)

    return index
4441
4442
def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)``,
    where *r* is the length of *element*.

    Permutations of a fixed length can be ordered lexicographically by the
    positions of their items in *iterable*. :func:`permutation_index`
    computes the index of the first occurrence of *element* in that
    ordering directly, without generating the earlier permutations.

        >>> permutation_index([1, 3, 2], range(5))
        19

    ``ValueError`` will be raised if the given *element* isn't one of the
    permutations of *iterable*.
    """
    remaining = list(iterable)
    rank = 0
    # The radix shrinks by one per step because every chosen item is removed
    # from the candidates. list.index() raises ValueError for a missing item.
    for radix, item in zip(range(len(remaining), -1, -1), element):
        position = remaining.index(item)
        rank = rank * radix + position
        remaining.pop(position)

    return rank
4465
4466
class countable:
    """Iterator wrapper that records how many items it has handed out.

    ``items_seen`` is ``0`` right after construction and grows by one for
    every item successfully retrieved from the wrapped *iterable*:

        >>> iterable = map(str, range(10))
        >>> it = countable(iterable)
        >>> it.items_seen
        0
        >>> next(it), next(it)
        ('0', '1')
        >>> list(it)
        ['2', '3', '4', '5', '6', '7', '8', '9']
        >>> it.items_seen
        10
    """

    def __init__(self, iterable):
        self.items_seen = 0
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        # If the source is exhausted (or raises), the exception propagates
        # before the counter is touched, so only delivered items are counted.
        value = next(self._iterator)
        self.items_seen += 1
        return value
4497
4498
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.
    Items are distributed such that the lengths of the lists differ by at
    most 1 item.

        >>> iterable = [1, 2, 3, 4, 5, 6, 7]
        >>> n = 3
        >>> list(chunked_even(iterable, n))  # List lengths: 3, 2, 2
        [[1, 2, 3], [4, 5], [6, 7]]
        >>> list(chunked(iterable, n))  # List lengths: 3, 3, 1
        [[1, 2, 3], [4, 5, 6], [7]]

    """
    source = iter(iterable)

    # Keep a reserve of items in hand so that the final, possibly
    # underfilled, chunks can still be evened out.
    reserve = (n - 1) * (n - 2)
    pending = list(islice(source, reserve))

    # map() appends each further item to ``pending`` as a side effect; the
    # islice() only surfaces every n-th append (after an initial offset of
    # n), and each time it does one complete chunk is released.
    for _ in islice(map(pending.append, source), n, None, n):
        yield pending[:n]
        del pending[:n]

    total = len(pending)
    if not total:
        return

    # Ceiling divisions: how many lists are needed, and how big the larger
    # ones must be. The smaller lists hold exactly one item fewer.
    list_count = -(-total // n)
    large = -(-total // list_count)
    small = large - 1
    large_count = total - small * list_count
    boundary = large_count * large

    # The larger chunks come first ...
    for start in range(0, boundary, large):
        yield pending[start : start + large]

    # ... followed by the smaller ones, if there are any.
    if small > 0:
        for start in range(boundary, total, small):
            yield pending[start : start + small]
4547
4548
4549def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
4550 """A version of :func:`zip` that "broadcasts" any scalar
4551 (i.e., non-iterable) items into output tuples.
4552
4553 >>> iterable_1 = [1, 2, 3]
4554 >>> iterable_2 = ['a', 'b', 'c']
4555 >>> scalar = '_'
4556 >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
4557 [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]
4558
4559 The *scalar_types* keyword argument determines what types are considered
4560 scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
4561 treat strings and byte strings as iterable:
4562
4563 >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
4564 [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]
4565
4566 If the *strict* keyword argument is ``True``, then
4567 ``ValueError`` will be raised if any of the iterables have
4568 different lengths.
4569 """
4570
4571 def is_scalar(obj):
4572 if scalar_types and isinstance(obj, scalar_types):
4573 return True
4574 try:
4575 iter(obj)
4576 except TypeError:
4577 return True
4578 else:
4579 return False
4580
4581 size = len(objects)
4582 if not size:
4583 return
4584
4585 new_item = [None] * size
4586 iterables, iterable_positions = [], []
4587 for i, obj in enumerate(objects):
4588 if is_scalar(obj):
4589 new_item[i] = obj
4590 else:
4591 iterables.append(iter(obj))
4592 iterable_positions.append(i)
4593
4594 if not iterables:
4595 yield tuple(objects)
4596 return
4597
4598 for item in zip(*iterables, strict=strict):
4599 for i, new_item[i] in zip(iterable_positions, item):
4600 pass
4601 yield tuple(new_item)
4602
4603
4604def unique_in_window(iterable, n, key=None):
4605 """Yield the items from *iterable* that haven't been seen recently.
4606 *n* is the size of the lookback window.
4607
4608 >>> iterable = [0, 1, 0, 2, 3, 0]
4609 >>> n = 3
4610 >>> list(unique_in_window(iterable, n))
4611 [0, 1, 2, 3, 0]
4612
4613 The *key* function, if provided, will be used to determine uniqueness:
4614
4615 >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
4616 ['a', 'b', 'c', 'd', 'a']
4617
4618 The items in *iterable* must be hashable.
4619
4620 """
4621 if n <= 0:
4622 raise ValueError('n must be greater than 0')
4623
4624 window = deque(maxlen=n)
4625 counts = defaultdict(int)
4626 use_key = key is not None
4627
4628 for item in iterable:
4629 if len(window) == n:
4630 to_discard = window[0]
4631 if counts[to_discard] == 1:
4632 del counts[to_discard]
4633 else:
4634 counts[to_discard] -= 1
4635
4636 k = key(item) if use_key else item
4637 if k not in counts:
4638 yield item
4639 counts[k] += 1
4640 window.append(k)
4641
4642
def duplicates_everseen(iterable, key=None):
    """Yield every element that repeats an earlier one, skipping only the
    first appearance of each distinct value.

        >>> list(duplicates_everseen('mississippi'))
        ['s', 'i', 's', 's', 'i', 'p', 'i']
        >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
        ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []

    for element in iterable:
        k = element if key is None else key(element)
        try:
            if k in hashable_seen:
                yield element
            else:
                hashable_seen.add(k)
        except TypeError:
            # Unhashable keys cannot go in the set; track them in a list.
            if k in unhashable_seen:
                yield element
            else:
                unhashable_seen.append(k)
4671
4672
def duplicates_justseen(iterable, key=None):
    """Yield the elements that repeat the element immediately before them,
    i.e. everything except the first member of each consecutive run.

        >>> list(duplicates_justseen('mississippi'))
        ['s', 's', 'p']
        >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
        ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    This function is analogous to :func:`unique_justseen`.

    """
    # For every run produced by groupby(), drop its first member and chain
    # together what remains.
    return chain.from_iterable(
        islice(run, 1, None) for _, run in groupby(iterable, key)
    )
4685
4686
def classify_unique(iterable, key=None):
    """Classify each element in terms of its uniqueness.

    For each element in the input iterable, return a 3-tuple consisting of:

    1. The element itself
    2. ``False`` if the element is equal to the one preceding it in the input,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
    3. ``False`` if this element has been seen anywhere in the input before,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)

        >>> list(classify_unique('otto'))  # doctest: +NORMALIZE_WHITESPACE
        [('o', True, True),
         ('t', True, True),
         ('t', False, False),
         ('o', True, False)]

    If *key* is given, elements are compared by ``key(element)``.

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []
    at_start = True
    previous = None

    for element in iterable:
        k = element if key is None else key(element)

        # The very first element has no predecessor to compare against.
        new_run = at_start or previous != k
        at_start = False
        previous = k

        first_time = False
        try:
            if k not in hashable_seen:
                hashable_seen.add(k)
                first_time = True
        except TypeError:
            # Unhashable keys cannot go in the set; track them in a list.
            if k not in unhashable_seen:
                unhashable_seen.append(k)
                first_time = True

        yield element, new_run, first_time
4727
4728
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Returns both the smallest and largest items from an iterable
    or from two or more arguments.

        >>> minmax([3, 1, 5])
        (1, 5)

        >>> minmax(4, 2, 6)
        (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

        >>> minmax([5, 30], key=str)  # '30' sorts before '5'
        (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

        >>> minmax([], default=(0, 0))
        (0, 0)

    Otherwise ``ValueError`` is raised.

    The input is traversed once. Items are taken two at a time and ordered
    against each other first, so that only the smaller one is compared with
    the running minimum and only the larger one with the running maximum.

    Note that unlike the builtin ``max`` function, which always returns the first
    item with the maximum value, this function may return another item when there are
    ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    if others:
        iterable = (iterable_or_value, *others)
    else:
        iterable = iterable_or_value

    it = iter(iterable)

    try:
        lo = hi = next(it)
    except StopIteration as exc:
        if default is not _marker:
            return default
        raise ValueError(
            '`minmax()` argument is an empty iterable. '
            'Provide a `default` value to suppress this error.'
        ) from exc

    # Pull items in pairs from the same iterator. An odd item out is paired
    # with the first item, which cannot change the result.
    pairs = zip_longest(it, it, fillvalue=lo)

    # The two branches are kept separate so that the common "key=None" case
    # does not pay for key bookkeeping.
    if key is None:
        for a, b in pairs:
            if b < a:
                a, b = b, a
            if a < lo:
                lo = a
            if hi < b:
                hi = b
        return lo, hi

    lo_key = hi_key = key(lo)
    for a, b in pairs:
        a_key, b_key = key(a), key(b)

        if b_key < a_key:
            a, b, a_key, b_key = b, a, b_key, a_key
        if a_key < lo_key:
            lo, lo_key = a, a_key
        if hi_key < b_key:
            hi, hi_key = b, b_key

    return lo, hi
4804
4805
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* with a combined size limited by
    *max_size*.

        >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
        >>> list(constrained_batches(iterable, 10))
        [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If a *max_count* is supplied, the number of items per batch is also
    limited:

        >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
        >>> list(constrained_batches(iterable, 10, max_count = 2))
        [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    The size of each item is measured with *get_len*, which defaults to
    :func:`len`.

    If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
    than *max_size*. Otherwise, allow single items to exceed *max_size*.

    ``ValueError`` is also raised if *max_size* is not greater than zero.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    current = []
    current_size = 0
    for item in iterable:
        item_len = get_len(item)
        if strict and item_len > max_size:
            raise ValueError('item size exceeds maximum size')

        count_full = len(current) == max_count
        would_overflow = item_len + current_size > max_size
        # An empty batch always accepts the item; this is what lets an
        # oversized item through when strict is false.
        if current and (would_overflow or count_full):
            yield tuple(current)
            current.clear()
            current_size = 0

        current.append(item)
        current_size += item_len

    if current:
        yield tuple(current)
4854
4855
def gray_product(*iterables):
    """Like :func:`itertools.product`, but return tuples in an order such
    that only one element in the generated tuple changes from one iteration
    to the next.

        >>> list(gray_product('AB','CD'))
        [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    All of the input iterables are read completely before any output is
    produced. ``ValueError`` is raised if any of them has fewer than two
    items.

    For information on the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    pools = tuple(map(tuple, iterables))
    dimension = len(pools)
    for pool in pools:
        if len(pool) < 2:
            raise ValueError("each iterable must have two or more items")

    # This is based on "Algorithm H" from section 7.2.1.1, page 20.
    # ``position`` holds, per pool, the index of the item in the current tuple
    # ``focus`` is the array of "focus pointers"
    # ``direction`` is the array of "directions" (+1 or -1)
    position = [0] * dimension
    focus = list(range(dimension + 1))
    direction = [1] * dimension
    while True:
        yield tuple(pool[i] for pool, i in zip(pools, position))
        j = focus[0]
        focus[0] = 0
        if j == dimension:
            return
        position[j] += direction[j]
        # On reaching either end of pool j, turn around and pass the focus
        # on to the next coordinate.
        if position[j] in (0, len(pools[j]) - 1):
            direction[j] = -direction[j]
            focus[j] = focus[j + 1]
            focus[j + 1] = j + 1
4896
4897
def partial_product(*iterables):
    """Yields tuples containing one item from each iterator, with subsequent
    tuples changing a single item at a time by advancing each iterator until it
    is exhausted. This sequence guarantees every value in each iterable is
    output at least once without generating all possible combinations.

    This may be useful, for example, when testing an expensive function.

        >>> list(partial_product('AB', 'C', 'DEF'))
        [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]

    Nothing is yielded if any of the iterables is empty.
    """
    iterators = [iter(iterable) for iterable in iterables]

    # Seed the working tuple with the first item of every iterator.
    current = []
    for it in iterators:
        try:
            current.append(next(it))
        except StopIteration:
            return
    yield tuple(current)

    # Run each iterator dry in turn, leaving the other slots untouched.
    for slot, it in enumerate(iterators):
        for item in it:
            current[slot] = item
            yield tuple(current)
4921
4922
def takewhile_inclusive(predicate, iterable):
    """A variant of :func:`takewhile` that also yields the first element for
    which *predicate* is false, and stops after it.

        >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
        [1, 4, 6]

    :func:`takewhile` would return ``[1, 4]``.
    """
    for item in iterable:
        # Emit first, test afterwards: the failing item is part of the output.
        yield item
        if predicate(item):
            continue
        return
4935
4936
def outer_product(func, xs, ys, *args, **kwargs):
    """A generalized outer product that applies a binary function to all
    pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
    columns.
    Any extra ``*args`` and ``**kwargs`` are forwarded to every call of
    ``func``.

    Multiplication table:

        >>> list(outer_product(mul, range(1, 4), range(1, 6)))
        [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]

    Cross tabulation:

        >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
        >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
        >>> pair_counts = Counter(zip(xs, ys))
        >>> count_rows = lambda x, y: pair_counts[x, y]
        >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
        [(2, 3, 0), (1, 0, 4)]

    Usage with ``*args`` and ``**kwargs``:

        >>> animals = ['cat', 'wolf', 'mouse']
        >>> list(outer_product(min, animals, animals, key=len))
        [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
    """
    # The column values are materialised because their count sets the row
    # width.
    ys = tuple(ys)

    def apply(x, y):
        return func(x, y, *args, **kwargs)

    cells = starmap(apply, product(xs, ys))
    return batched(cells, n=len(ys))
4968
4969
def iter_suppress(iterable, *exceptions):
    """Yield the items of *iterable* until it is exhausted or until iterating
    it raises one of the given *exceptions*. In the latter case the exception
    is swallowed and iteration simply ends.

        >>> from itertools import chain
        >>> def breaks_at_five(x):
        ...     while True:
        ...         if x >= 5:
        ...             raise RuntimeError
        ...         yield x
        ...         x += 1
        >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
        >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
        >>> list(chain(it_1, it_2))
        [1, 2, 3, 4, 2, 3, 4]

    Exceptions of any other type propagate to the caller.
    """
    try:
        yield from iterable
    except exceptions:
        # One of the listed exceptions: finish quietly.
        pass
4991
4992
def filter_map(func, iterable):
    """Call *func* on each element of *iterable* and yield the results,
    leaving out any result that is ``None``.

        >>> elems = ['1', 'a', '2', 'b', '3']
        >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
        [1, 2, 3]
    """
    for result in map(func, iterable):
        # Only None is dropped; other falsy results such as 0 are kept.
        if result is not None:
            yield result
5005
5006
def powerset_of_sets(iterable):
    """Yields all possible subsets of the iterable.

        >>> list(powerset_of_sets([1, 2, 3]))  # doctest: +SKIP
        [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
        >>> list(powerset_of_sets([1, 1, 0]))  # doctest: +SKIP
        [set(), {1}, {0}, {0, 1}]

    :func:`powerset_of_sets` takes care to minimize the number
    of hash operations performed.
    """
    # Wrap every item in a one-element frozenset and drop repeats while
    # keeping first-seen order (dict keys preserve insertion order).
    singletons = tuple(dict.fromkeys(map(frozenset, zip(iterable))))

    # Every subset is the union of a combination of singletons; union() on an
    # empty set returns a new set each time it is called.
    merge = set().union
    return chain.from_iterable(
        starmap(merge, combinations(singletons, size))
        for size in range(len(singletons) + 1)
    )
5023
5024
def join_mappings(**field_to_map):
    """
    Joins multiple mappings together using their common keys.

    Each keyword argument names a field and supplies a mapping from keys to
    that field's values. The result maps every key to a dict of the fields
    in which it appears.

        >>> user_scores = {'elliot': 50, 'claris': 60}
        >>> user_times = {'elliot': 30, 'claris': 40}
        >>> join_mappings(score=user_scores, time=user_times)
        {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    joined = {}

    for field_name, mapping in field_to_map.items():
        for key, value in mapping.items():
            # Create the per-key record the first time a key shows up.
            joined.setdefault(key, {})[field_name] = value

    return joined
5041
5042
def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.

    *v1* and *v2* are each traversed several times, so they must be
    re-iterable (for example lists), not one-shot iterators.
    """

    re_part = attrgetter('real')
    im_part = attrgetter('imag')

    # Real component: pairs re(a) with re(b), and -im(a) with im(b).
    real_left = chain(map(re_part, v1), map(neg, map(im_part, v1)))
    real_right = chain(map(re_part, v2), map(im_part, v2))

    # Imaginary component: pairs re(a) with im(b), and im(a) with re(b).
    imag_left = chain(map(re_part, v1), map(im_part, v1))
    imag_right = chain(map(im_part, v2), map(re_part, v2))

    return complex(
        _fsumprod(real_left, real_right),
        _fsumprod(imag_left, imag_right),
    )
5055
5056
def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the corresponding transformed output vector.

        >>> import cmath
        >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
        >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
        >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
        >>> all(map(cmath.isclose, dft(xarr), Xarr))
        True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number.  This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`idft` for the inverse Discrete Fourier Transform.
    """
    size = len(xarr)
    # Powers of e ** (-1j * tau / size), computed once and looked up by
    # exponent modulo ``size``.
    twiddles = [e ** (n / size * tau * -1j) for n in range(size)]
    for k in range(size):
        row = [twiddles[k * n % size] for n in range(size)]
        yield _complex_sumprod(xarr, row)
5079
5080
def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the corresponding
    inverse-transformed output vector.

        >>> import cmath
        >>> xarr = [1, 2-1j, -1j, -1+2j]  # time domain
        >>> Xarr = [2, -2-2j, -2j, 4+4j]  # frequency domain
        >>> all(map(cmath.isclose, idft(Xarr), xarr))
        True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number.  This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`dft` for the Discrete Fourier Transform.
    """
    size = len(Xarr)
    # Powers of e ** (+1j * tau / size), computed once and looked up by
    # exponent modulo ``size``.
    twiddles = [e ** (n / size * tau * 1j) for n in range(size)]
    for k in range(size):
        row = [twiddles[k * n % size] for n in range(size)]
        # Each output component is scaled by 1 / size.
        yield _complex_sumprod(Xarr, row) / size
5103
5104
def doublestarmap(func, iterable):
    """Call *func* once per item of *iterable*, passing the item's entries
    as keyword arguments, and yield each result.

    The difference between :func:`itertools.starmap` and :func:`doublestarmap`
    parallels the distinction between ``func(*a)`` and ``func(**a)``.

        >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
        >>> list(doublestarmap(lambda a, b: a + b, iterable))
        [3, 100]

    ``TypeError`` will be raised if *func*'s signature doesn't match the
    mapping contained in *iterable* or if *iterable* does not contain mappings.
    """
    for kwargs in iterable:
        yield func(**kwargs)
5121
5122
def _nth_prime_bounds(n):
    """Bounds for the nth prime (counting from 1): lb < p_n < ub.

    Returns the pair ``(lb, ub)``. ``ValueError`` is raised if *n* is less
    than 1.
    """
    # At and above 688,383, the lb/ub spread is under 0.003 * p_n.

    if n < 1:
        raise ValueError

    if n < 6:
        # Simple linear bounds are used for the first five primes.
        return (n, 2.25 * n)

    # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
    log_n = log(n)
    upper_bound = n * log(n * log_n)
    lower_bound = upper_bound - n
    if n >= 688_383:
        # A tighter upper bound applies from this point on.
        upper_bound -= n * (1.0 - (log(log_n) - 2.0) / log_n)

    return lower_bound, upper_bound
5140
5141
def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

        >>> nth_prime(0)
        2
        >>> nth_prime(100)
        547

    If *approximate* is set to True, will return a prime close
    to the nth prime.  The estimation is much faster than computing
    an exact result.

        >>> nth_prime(200_000_000, approximate=True)  # Exact result is 4222234763
        4217820427

    For *n* up to 1,000,000 the exact prime is returned even when
    *approximate* is true.

    """
    lower, upper = _nth_prime_bounds(n + 1)

    if approximate and n > 1_000_000:
        # Start at the midpoint of the bounds, forced odd, and take the
        # first prime at or after it.
        start = floor((lower + upper) / 2) | 1
        return first_true(count(start, step=2), pred=is_prime)

    # Exact answer: sieve up to the upper bound and pick the nth prime.
    return nth(sieve(ceil(upper)), n)
5166
5167
def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

        >>> argmin('efghabcdijkl')
        4
        >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
        3

    If *key* is given, the values are compared by ``key(value)``.

    For example, look up a label corresponding to the position
    of a value that minimizes a cost function::

        >>> def cost(x):
        ...     "Days for a wound to heal given a subject's age."
        ...     return x**2 - 20*x + 150
        ...
        >>> labels =  ['homer', 'marge', 'bart', 'lisa', 'maggie']
        >>> ages =    [  35,      30,      10,      9,      1    ]

        # Fastest healing family member
        >>> labels[argmin(ages, key=cost)]
        'bart'

        # Age with fastest healing
        >>> min(ages, key=cost)
        10

    """
    values = iterable if key is None else map(key, iterable)
    # min() keeps the first of several equal candidates, so the earliest
    # position wins ties.
    position, _ = min(enumerate(values), key=itemgetter(1))
    return position
5199
5200
def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

        >>> argmax('abcdefghabcd')
        7
        >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
        3

    If *key* is given, the values are compared by ``key(value)``.

    For example, identify the best machine learning model::

        >>> models =   ['svm', 'random forest', 'knn', 'naïve bayes']
        >>> accuracy = [  68,        61,          84,       72      ]

        # Most accurate model
        >>> models[argmax(accuracy)]
        'knn'

        # Best accuracy
        >>> max(accuracy)
        84

    """
    values = iterable if key is None else map(key, iterable)
    # max() keeps the first of several equal candidates, so the earliest
    # position wins ties.
    position, _ = max(enumerate(values), key=itemgetter(1))
    return position
5227
5228
def extract(iterable, indices):
    """Yield values at the specified indices.

    Example:

        >>> data = 'abcdefghijklmnopqrstuvwxyz'
        >>> list(extract(data, [7, 4, 11, 11, 14]))
        ['h', 'e', 'l', 'l', 'o']

    Values are yielded in the order the *indices* were given, and an index
    may be repeated.

    The *iterable* is consumed lazily and can be infinite.
    The *indices* are consumed immediately and must be finite.

    Raises ``IndexError`` if an index lies beyond the iterable.
    Raises ``ValueError`` for negative indices.
    """

    iterator = iter(iterable)

    # Visit the requested indices in ascending order so the iterable is
    # traversed only once; ``slot`` remembers where each request belongs in
    # the output.
    requests = sorted(zip(indices, count()))

    if requests and requests[0][0] < 0:
        raise ValueError('Indices must be non-negative')

    pending = {}
    reached = -1
    emit_slot = 0

    for index, slot in requests:
        gap = index - reached
        # A zero gap means the same index was requested again: reuse the
        # value fetched for it on the previous pass.
        if gap:
            try:
                value = next(islice(iterator, gap - 1, None))
            except StopIteration:
                raise IndexError(index)
            reached = index

        pending[slot] = value

        # Release every value whose turn in the output order has come.
        while emit_slot in pending:
            yield pending.pop(emit_slot)
            emit_slot += 1