1import math
2import warnings
3
4from collections import Counter, defaultdict, deque, abc
5from collections.abc import Sequence
6from contextlib import suppress
7from functools import cached_property, partial, reduce, wraps
8from heapq import heapify, heapreplace
9from itertools import (
10 chain,
11 combinations,
12 compress,
13 count,
14 cycle,
15 dropwhile,
16 groupby,
17 islice,
18 permutations,
19 repeat,
20 starmap,
21 takewhile,
22 tee,
23 zip_longest,
24 product,
25)
26from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
27from math import ceil
28from queue import Empty, Queue
29from random import random, randrange, shuffle, uniform
30from operator import (
31 attrgetter,
32 is_not,
33 itemgetter,
34 lt,
35 mul,
36 neg,
37 sub,
38 gt,
39)
40from sys import hexversion, maxsize
41from time import monotonic
42
43from .recipes import (
44 _marker,
45 _zip_equal,
46 UnequalIterablesError,
47 consume,
48 first_true,
49 flatten,
50 is_prime,
51 nth,
52 powerset,
53 sieve,
54 take,
55 unique_everseen,
56 all_equal,
57 batched,
58)
59
60__all__ = [
61 'AbortThread',
62 'SequenceView',
63 'UnequalIterablesError',
64 'adjacent',
65 'all_unique',
66 'always_iterable',
67 'always_reversible',
68 'argmax',
69 'argmin',
70 'bucket',
71 'callback_iter',
72 'chunked',
73 'chunked_even',
74 'circular_shifts',
75 'collapse',
76 'combination_index',
77 'combination_with_replacement_index',
78 'consecutive_groups',
79 'constrained_batches',
80 'consumer',
81 'count_cycle',
82 'countable',
83 'derangements',
84 'dft',
85 'difference',
86 'distinct_combinations',
87 'distinct_permutations',
88 'distribute',
89 'divide',
90 'doublestarmap',
91 'duplicates_everseen',
92 'duplicates_justseen',
93 'classify_unique',
94 'exactly_n',
95 'filter_except',
96 'filter_map',
97 'first',
98 'gray_product',
99 'groupby_transform',
100 'ichunked',
101 'iequals',
102 'idft',
103 'ilen',
104 'interleave',
105 'interleave_evenly',
106 'interleave_longest',
107 'intersperse',
108 'is_sorted',
109 'islice_extended',
110 'iterate',
111 'iter_suppress',
112 'join_mappings',
113 'last',
114 'locate',
115 'longest_common_prefix',
116 'lstrip',
117 'make_decorator',
118 'map_except',
119 'map_if',
120 'map_reduce',
121 'mark_ends',
122 'minmax',
123 'nth_or_last',
124 'nth_permutation',
125 'nth_prime',
126 'nth_product',
127 'nth_combination_with_replacement',
128 'numeric_range',
129 'one',
130 'only',
131 'outer_product',
132 'padded',
133 'partial_product',
134 'partitions',
135 'peekable',
136 'permutation_index',
137 'powerset_of_sets',
138 'product_index',
139 'raise_',
140 'repeat_each',
141 'repeat_last',
142 'replace',
143 'rlocate',
144 'rstrip',
145 'run_length',
146 'sample',
147 'seekable',
148 'set_partitions',
149 'side_effect',
150 'sliced',
151 'sort_together',
152 'split_after',
153 'split_at',
154 'split_before',
155 'split_into',
156 'split_when',
157 'spy',
158 'stagger',
159 'strip',
160 'strictly_n',
161 'substrings',
162 'substrings_indexes',
163 'takewhile_inclusive',
164 'time_limited',
165 'unique_in_window',
166 'unique_to_each',
167 'unzip',
168 'value_chain',
169 'windowed',
170 'windowed_complete',
171 'with_iter',
172 'zip_broadcast',
173 'zip_equal',
174 'zip_offset',
175]
176
# math.sumprod is available for Python 3.12+; on older versions fall back
# to a pure-Python extended-precision implementation.
try:
    from math import sumprod as _fsumprod

except ImportError:  # pragma: no cover
    # Extended precision algorithms from T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8). Code: mul12()

    def dl_split(x: float):
        "Split a float into two half-precision components."
        # Multiplying by the Veltkamp constant rounds away the low-order
        # bits; t - (t - x) then recovers the high half exactly.
        t = x * 134217729.0  # Veltkamp constant = 2.0 ** 27 + 1
        hi = t - (t - x)
        lo = x - hi
        return hi, lo

    def dl_mul(x, y):
        "Lossless multiplication."
        # Split both operands so the partial products below are exact.
        xx_hi, xx_lo = dl_split(x)
        yy_hi, yy_lo = dl_split(y)
        p = xx_hi * yy_hi
        q = xx_hi * yy_lo + xx_lo * yy_hi
        z = p + q
        # zz is the rounding error of z: (z, zz) represents the exact
        # product x * y as an unevaluated sum of two floats.
        zz = p - z + q + xx_lo * yy_lo
        return z, zz

    def _fsumprod(p, q):
        # Accurate dot product: sum all exact partial products with fsum.
        return fsum(chain.from_iterable(map(dl_mul, p, q)))
206
207
def chunked(iterable, n, strict=False):
    """Break *iterable* into lists of length *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
    [[1, 2, 3], [4, 5, 6]]

    By the default, the last yielded list will have fewer than *n* elements
    if the length of *iterable* is not divisible by *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
    [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    If the length of *iterable* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    list is yielded.

    """
    source = iter(iterable)
    # Repeatedly slice off the next n items; the empty list sentinel
    # stops the iteration once the source is exhausted.
    chunks = iter(lambda: list(islice(source, n)), [])

    if not strict:
        return chunks

    if n is None:
        raise ValueError('n must not be None when using strict mode.')

    def checked():
        # Re-yield each chunk, verifying it is full-sized.
        for chunk in chunks:
            if len(chunk) != n:
                raise ValueError('iterable is not divisible by n.')
            yield chunk

    return checked()
241
242
def first(iterable, default=_marker):
    """Return the first item of *iterable*, or *default* if *iterable* is
    empty.

    >>> first([0, 1, 2, 3])
    0
    >>> first([], 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.

    :func:`first` is useful when you have a generator of expensive-to-retrieve
    values and want any arbitrary one. It is marginally shorter than
    ``next(iter(iterable), default)``.

    """
    try:
        return next(iter(iterable))
    except StopIteration:
        if default is not _marker:
            return default
        raise ValueError(
            'first() was called on an empty iterable, '
            'and no default value was provided.'
        ) from None
268
269
def last(iterable, default=_marker):
    """Return the last item of *iterable*, or *default* if *iterable* is
    empty.

    >>> last([0, 1, 2, 3])
    3
    >>> last([], 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    try:
        # Fast paths first: direct indexing, then reverse iteration.
        if isinstance(iterable, Sequence):
            return iterable[-1]
        # Work around https://bugs.python.org/issue38525
        if hasattr(iterable, '__reversed__'):
            return next(reversed(iterable))
        # Fallback: drain the iterable, keeping only the final item.
        return deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        # All three failures mean the input had no items.
        if default is not _marker:
            return default
        raise ValueError(
            'last() was called on an empty iterable, '
            'and no default value was provided.'
        )
296
297
def nth_or_last(iterable, n, default=_marker):
    """Return the nth or the last item of *iterable*,
    or *default* if *iterable* is empty.

    >>> nth_or_last([0, 1, 2, 3], 2)
    2
    >>> nth_or_last([0, 1], 2)
    1
    >>> nth_or_last([], 0, 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    # Truncate to the first n + 1 items; the last of those is the answer.
    prefix = islice(iterable, n + 1)
    return last(prefix, default=default)
313
314
class peekable:
    """Wrap an iterator to allow lookahead and prepending elements.

    Call :meth:`peek` on the result to get the value that will be returned
    by :func:`next`. This won't advance the iterator:

    >>> p = peekable(['a', 'b'])
    >>> p.peek()
    'a'
    >>> next(p)
    'a'

    Pass :meth:`peek` a default value to return that instead of raising
    ``StopIteration`` when the iterator is exhausted.

    >>> p = peekable([])
    >>> p.peek('hi')
    'hi'

    peekables also offer a :meth:`prepend` method, which "inserts" items
    at the head of the iterable:

    >>> p = peekable([1, 2, 3])
    >>> p.prepend(10, 11, 12)
    >>> next(p)
    10
    >>> p.peek()
    11
    >>> list(p)
    [11, 12, 1, 2, 3]

    peekables can be indexed. Index 0 is the item that will be returned by
    :func:`next`, index 1 is the item after that, and so on:
    The values up to the given index will be cached.

    >>> p = peekable(['a', 'b', 'c', 'd'])
    >>> p[0]
    'a'
    >>> p[1]
    'b'
    >>> next(p)
    'a'

    Negative indexes are supported, but be aware that they will cache the
    remaining items in the source iterator, which may require significant
    storage.

    To check whether a peekable is exhausted, check its truth value:

    >>> p = peekable(['a', 'b'])
    >>> if p:  # peekable has items
    ...     list(p)
    ['a', 'b']
    >>> if not p:  # peekable is exhausted
    ...     list(p)
    []

    """

    def __init__(self, iterable):
        self._it = iter(iterable)
        # Items pulled from the iterator (by peek(), indexing, or
        # prepend()) but not yet consumed by __next__().
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        # Truthy if at least one more item is available.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        """Return the item that will be next returned from ``next()``.

        Return ``default`` if there are no items left. If ``default`` is not
        provided, raise ``StopIteration``.

        """
        # Pull one item into the cache if it's empty; the cached head is
        # what __next__() would return.
        if not self._cache:
            try:
                self._cache.append(next(self._it))
            except StopIteration:
                if default is _marker:
                    raise
                return default
        return self._cache[0]

    def prepend(self, *items):
        """Stack up items to be the next ones returned from ``next()`` or
        ``self.peek()``. The items will be returned in
        first in, first out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        It is possible, by prepending items, to "resurrect" a peekable that
        previously raised ``StopIteration``.

        >>> p = peekable([])
        >>> next(p)
        Traceback (most recent call last):
          ...
        StopIteration
        >>> p.prepend(1)
        >>> next(p)
        1
        >>> next(p)
        Traceback (most recent call last):
          ...
        StopIteration

        """
        # extendleft() reverses its argument, so reverse first to keep
        # the caller's ordering.
        self._cache.extendleft(reversed(items))

    def __next__(self):
        # Serve cached (peeked or prepended) items before advancing the
        # underlying iterator.
        if self._cache:
            return self._cache.popleft()

        return next(self._it)

    def _get_slice(self, index):
        """Cache as much of the source as the slice needs, then slice the
        cache."""
        # Normalize the slice's arguments
        step = 1 if (index.step is None) else index.step
        if step > 0:
            start = 0 if (index.start is None) else index.start
            stop = maxsize if (index.stop is None) else index.stop
        elif step < 0:
            start = -1 if (index.start is None) else index.start
            stop = (-maxsize - 1) if (index.stop is None) else index.stop
        else:
            raise ValueError('slice step cannot be zero')

        # If either the start or stop index is negative, we'll need to cache
        # the rest of the iterable in order to slice from the right side.
        if (start < 0) or (stop < 0):
            self._cache.extend(self._it)
        # Otherwise we'll need to find the rightmost index and cache to that
        # point.
        else:
            n = min(max(start, stop) + 1, maxsize)
            cache_len = len(self._cache)
            if n >= cache_len:
                self._cache.extend(islice(self._it, n - cache_len))

        # list() applies the (possibly negative) slice in the usual way.
        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        # Cache up to and including the requested index; a negative index
        # requires caching the entire remaining source.
        cache_len = len(self._cache)
        if index < 0:
            self._cache.extend(self._it)
        elif index >= cache_len:
            self._cache.extend(islice(self._it, index + 1 - cache_len))

        return self._cache[index]
478
479
def consumer(func):
    """Decorator that automatically advances a PEP-342-style "reverse iterator"
    to its first yield point so you don't have to call ``next()`` on it
    manually.

    >>> @consumer
    ... def tally():
    ...     i = 0
    ...     while True:
    ...         print('Thing number %s is %s.' % (i, (yield)))
    ...         i += 1
    ...
    >>> t = tally()
    >>> t.send('red')
    Thing number 0 is red.
    >>> t.send('fish')
    Thing number 1 is fish.

    Without the decorator, you would have to call ``next(t)`` before
    ``t.send()`` could be used.

    """

    @wraps(func)
    def primed(*args, **kwargs):
        coroutine = func(*args, **kwargs)
        # send(None) is equivalent to next(): it advances the generator
        # to its first yield so it is ready to receive values.
        coroutine.send(None)
        return coroutine

    return primed
510
511
def ilen(iterable):
    """Return the number of items in *iterable*.

    >>> ilen(x for x in range(1000) if x % 3 == 0)
    334

    Equivalent to, but faster than::

        def ilen(iterable):
            count = 0
            for _ in iterable:
                count += 1
            return count

    This fully consumes the iterable, so handle with care.

    """
    # Pair each item with a tick from count(), discard the pairs at C
    # speed via a zero-length deque, then read how far the counter got.
    counter = count()
    deque(zip(iterable, counter), maxlen=0)
    return next(counter)
535
536
def iterate(func, start):
    """Return ``start``, ``func(start)``, ``func(func(start))``, ...

    Produces an infinite iterator. To add a stopping condition,
    use :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`:.

    >>> take(10, iterate(lambda x: 2*x, 1))
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2
    >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10)))
    [10, 5, 16, 8, 4, 2, 1]

    """
    # If *func* raises StopIteration, end the stream gracefully instead
    # of propagating the exception.
    try:
        while True:
            yield start
            start = func(start)
    except StopIteration:
        return
555
556
def with_iter(context_manager):
    """Wrap an iterable in a ``with`` statement, so it closes once exhausted.

    For example, this will close the file when the iterator is exhausted::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Any context manager which returns an iterable is a candidate for
    ``with_iter``.

    """
    # The context manager stays open for as long as the caller keeps
    # iterating; it is exited when the stream is exhausted (or closed).
    with context_manager as managed:
        yield from managed
570
571
def one(iterable, too_short=None, too_long=None):
    """Return the first item from *iterable*, which is expected to contain only
    that item. Raise an exception if *iterable* is empty or has more than one
    item.

    :func:`one` is useful for ensuring that an iterable contains only one item.
    For example, it can be used to retrieve the result of a database query
    that is expected to return a single row.

    If *iterable* is empty, ``ValueError`` will be raised. You may specify a
    different exception with the *too_short* keyword:

    >>> it = []
    >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too few items in iterable (expected 1)'
    >>> too_short = IndexError('too few items')
    >>> one(it, too_short=too_short)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    IndexError: too few items

    Similarly, if *iterable* contains more than one item, ``ValueError`` will
    be raised. You may specify a different exception with the *too_long*
    keyword:

    >>> it = ['too', 'many']
    >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: Expected exactly one item in iterable, but got 'too',
    'many', and perhaps more.
    >>> too_long = RuntimeError
    >>> one(it, too_long=too_long)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    RuntimeError

    Note that :func:`one` attempts to advance *iterable* twice to ensure there
    is only one item. See :func:`spy` or :func:`peekable` to check iterable
    contents less destructively.

    """
    iterator = iter(iterable)

    # First advance: there must be at least one item.
    try:
        first = next(iterator)
    except StopIteration:
        raise (
            too_short or ValueError('too few items in iterable (expected 1)')
        ) from None

    # Second advance: there must NOT be a second item.
    try:
        second = next(iterator)
    except StopIteration:
        return first

    msg = (
        f'Expected exactly one item in iterable, but got {first!r}, '
        f'{second!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
626
627
def raise_(exception, *args):
    """Instantiate *exception* with *args* and raise it.

    Helper for callbacks that need to raise, e.g. from inside a lambda
    (which cannot contain a ``raise`` statement).
    """
    raise exception(*args)
630
631
def strictly_n(iterable, n, too_short=None, too_long=None):
    """Validate that *iterable* has exactly *n* items and return them if
    it does. If it has fewer than *n* items, call function *too_short*
    with those items. If it has more than *n* items, call function
    *too_long* with the first ``n + 1`` items.

    >>> iterable = ['a', 'b', 'c', 'd']
    >>> n = 4
    >>> list(strictly_n(iterable, n))
    ['a', 'b', 'c', 'd']

    Note that the returned iterable must be consumed in order for the check to
    be made.

    By default, *too_short* and *too_long* are functions that raise
    ``ValueError``.

    >>> list(strictly_n('ab', 3))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too few items in iterable (got 2)

    >>> list(strictly_n('abc', 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too many items in iterable (got at least 3)

    You can instead supply functions that do something else.
    *too_short* will be called with the number of items in *iterable*.
    *too_long* will be called with `n + 1`.

    >>> def too_short(item_count):
    ...     raise RuntimeError
    >>> it = strictly_n('abcd', 6, too_short=too_short)
    >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    RuntimeError

    >>> def too_long(item_count):
    ...     print('The boss is going to hear about this')
    >>> it = strictly_n('abcdef', 4, too_long=too_long)
    >>> list(it)
    The boss is going to hear about this
    ['a', 'b', 'c', 'd']

    """
    # Default callbacks raise ValueError with a descriptive message.
    if too_short is None:

        def too_short(item_count):
            raise ValueError(f'Too few items in iterable (got {item_count})')

    if too_long is None:

        def too_long(item_count):
            raise ValueError(
                f'Too many items in iterable (got at least {item_count})'
            )

    iterator = iter(iterable)

    # Pass through at most n items, counting as we go.
    yielded = 0
    for item in islice(iterator, n):
        yield item
        yielded += 1

    if yielded < n:
        too_short(yielded)
        return

    # Any leftover item means the input was too long.
    for _ in iterator:
        too_long(n + 1)
        return
705
706
def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

    >>> sorted(distinct_permutations([1, 0, 1]))
    [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

    >>> sorted(distinct_permutations([1, 0, 1], r=2))
    [(0, 1), (1, 0), (1, 1)]
    >>> sorted(distinct_permutations(range(3), r=2))
    [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

    >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
    [
        (1, True, '3'),
        (1, '3', True),
        ('3', 1, True)
    ]
    >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
    [
        (1, 2, '3'),
        (1, '3', 2),
        (2, 1, '3'),
        (2, '3', 1),
        ('3', 1, 2),
        ('3', 2, 1)
    ]
    """

    # Algorithm: https://w.wiki/Qai
    # NOTE: both helpers read ``size`` from the enclosing scope; it is
    # assigned below, before either generator is iterated.
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than j such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above
    def _partial(A, r):
        # Split A into the first r items and the last r items
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    # If the items are sortable, sorting gives the canonical starting
    # permutation required by the algorithm. Otherwise fall back to
    # permuting canonical indices (computed below).
    try:
        items.sort()
        sortable = True
    except TypeError:
        sortable = False

    indices_dict = defaultdict(list)

    # items.index(item) returns the position of the FIRST equal item, so
    # all mutually-equal items share one canonical index. Group the
    # original objects under that index.
    for item in items:
        indices_dict[items.index(item)].append(item)

    # Replace each item with its canonical index; equal items become
    # equal (and sortable) integers.
    indices = [items.index(item) for item in items]
    indices.sort()

    # Cycle through the original equal-but-distinct objects so they are
    # reused round-robin when translating index permutations back.
    equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

    def permuted_items(permuted_indices):
        # Translate a permutation of canonical indices back to the
        # original objects.
        return tuple(
            next(equivalent_items[index]) for index in permuted_indices
        )

    size = len(items)
    if r is None:
        r = size

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if 0 < r <= size:
        if sortable:
            return algorithm(items)
        else:
            return (
                permuted_items(permuted_indices)
                for permuted_indices in algorithm(indices)
            )

    # r == 0 yields a single empty permutation; r < 0 or r > size yields
    # nothing at all.
    return iter(() if r else ((),))
855
856
def derangements(iterable, r=None):
    """Yield successive derangements of the elements in *iterable*.

    A derangement is a permutation in which no element appears at its original
    index. In other words, a derangement is a permutation that has no fixed points.

    Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
    The code below outputs all of the different ways to assign gift recipients
    such that nobody is assigned to himself or herself:

    >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
    ...     print(', '.join(d))
    Bob, Alice, Dave, Carol
    Bob, Carol, Dave, Alice
    Bob, Dave, Alice, Carol
    Carol, Alice, Dave, Bob
    Carol, Dave, Alice, Bob
    Carol, Dave, Bob, Alice
    Dave, Alice, Bob, Carol
    Dave, Carol, Alice, Bob
    Dave, Carol, Bob, Alice

    If *r* is given, only the *r*-length derangements are yielded.

    >>> sorted(derangements(range(3), 2))
    [(1, 0), (1, 2), (2, 0)]
    >>> sorted(derangements([0, 2, 3], 2))
    [(2, 0), (2, 3), (3, 0)]

    Elements are treated as unique based on their position, not on their value.
    If the input elements are unique, there will be no repeated values within a
    permutation.

    The number of derangements of a set of size *n* is known as the
    "subfactorial of n". For n > 0, the subfactorial is:
    ``round(math.factorial(n) / math.e)``.
    References:

    * Article: https://www.numberanalytics.com/blog/ultimate-guide-to-derangements-in-combinatorics
    * Sizes: https://oeis.org/A000166
    """
    pool = tuple(iterable)
    positions = tuple(range(len(pool)))

    # Walk the position permutations in lockstep with the value
    # permutations; keep a value permutation only when no position is a
    # fixed point (p[i] != i for every i).
    selectors = (
        all(map(is_not, positions, candidate))
        for candidate in permutations(positions, r=r)
    )
    return compress(permutations(pool, r=r), selectors)
904
905
def intersperse(e, iterable, n=1):
    """Intersperse filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

    >>> list(intersperse('!', [1, 2, 3, 4, 5]))
    [1, '!', 2, '!', 3, '!', 4, '!', 5]

    >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
    [1, 2, None, 3, 4, None, 5]

    """
    if n == 0:
        raise ValueError('n must be > 0')

    if n == 1:
        # e, x_0, e, x_1, ... -> drop the leading filler with islice.
        return islice(interleave(repeat(e), iterable), 1, None)

    # Group the items into chunks of n, interleave single-filler lists
    # between them, drop the leading filler, and flatten the result.
    fillers = repeat([e])
    groups = chunked(iterable, n)
    return flatten(islice(interleave(fillers, groups), 1, None))
930
931
def unique_to_each(*iterables):
    """Return the elements from each of the input iterables that aren't in the
    other input iterables.

    For example, suppose you have a set of packages, each with a set of
    dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    If you remove one package, which dependencies can also be removed?

    If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
    associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for
    ``pkg_2``, and ``D`` is only needed for ``pkg_3``::

        >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
        [['A'], ['C'], ['D']]

    If there are duplicates in one input iterable that aren't in the others
    they will be duplicated in the output. Input order is preserved::

        >>> unique_to_each("mississippi", "missouri")
        [['p', 'p'], ['o', 'u', 'r']]

    It is assumed that the elements of each iterable are hashable.

    """
    # Materialize each input so it can be traversed twice.
    groups = [list(group) for group in iterables]

    # Count in how many of the inputs each element appears (deduplicating
    # within a single input first).
    tally = Counter()
    for group in groups:
        tally.update(set(group))

    # Elements seen in exactly one input are the "unique" ones.
    singletons = {element for element, times in tally.items() if times == 1}

    return [
        [element for element in group if element in singletons]
        for group in groups
    ]
963
964
def windowed(seq, n, fillvalue=None, step=1):
    """Return a sliding window of width *n* over the given iterable.

    >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
    >>> list(all_windows)
    [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    When the window is larger than the iterable, *fillvalue* is used in place
    of missing values:

    >>> list(windowed([1, 2, 3], 4))
    [(1, 2, 3, None)]

    Each window will advance in increments of *step*:

    >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
    [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler items
    to the left:

    >>> iterable = [1, 2, 3, 4]
    >>> n = 3
    >>> padding = [None] * (n - 1)
    >>> list(windowed(chain(padding, iterable), 3))
    [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    if n < 0:
        raise ValueError('n must be >= 0')
    if n == 0:
        yield ()
        return
    if step < 1:
        raise ValueError('step must be >= 1')

    iterator = iter(seq)

    # The first window is special: it may be empty or only partly full.
    window = deque(islice(iterator, n), maxlen=n)
    if not window:
        return
    if len(window) < n:
        yield tuple(window) + (fillvalue,) * (n - len(window))
        return
    yield tuple(window)

    # Pad the remaining stream with just enough fill values so the last
    # window can be completed (but never more than n - 1 of them).
    pad_count = n - 1 if step >= n else step - 1
    tail = chain(iterator, repeat(fillvalue, pad_count))

    # Slide forward *step* items at a time; a short advance means the
    # stream (including padding) is exhausted and no window is emitted.
    while True:
        advanced = 0
        for item in islice(tail, step):
            window.append(item)
            advanced += 1
        if advanced < step:
            return
        yield tuple(window)
1021
1022
def substrings(iterable):
    """Yield all of the substrings of *iterable*.

    >>> [''.join(s) for s in substrings('more')]
    ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Note that non-string iterables can also be subdivided.

    >>> list(substrings([0, 1, 2]))
    [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    """
    # Emit the length-1 substrings while collecting the whole input.
    collected = []
    for element in iterable:
        collected.append(element)
        yield (element,)
    collected = tuple(collected)

    # Then emit the longer substrings, ordered by increasing length.
    total = len(collected)
    for length in range(2, total + 1):
        yield from (
            collected[start : start + length]
            for start in range(total - length + 1)
        )
1047
1048
def substrings_indexes(seq, reverse=False):
    """Yield all substrings and their positions in *seq*

    The items yielded will be a tuple of the form ``(substr, i, j)``, where
    ``substr == seq[i:j]``.

    This function only works for iterables that support slicing, such as
    ``str`` objects.

    >>> for item in substrings_indexes('more'):
    ...     print(item)
    ('m', 0, 1)
    ('o', 1, 2)
    ('r', 2, 3)
    ('e', 3, 4)
    ('mo', 0, 2)
    ('or', 1, 3)
    ('re', 2, 4)
    ('mor', 0, 3)
    ('ore', 1, 4)
    ('more', 0, 4)

    Set *reverse* to ``True`` to yield the same items in the opposite order.


    """
    # Substring lengths, shortest first (or longest first if reversed).
    sizes = range(1, len(seq) + 1)
    if reverse:
        sizes = reversed(sizes)
    return (
        (seq[start : start + size], start, start + size)
        for size in sizes
        for start in range(len(seq) - size + 1)
    )
1081
1082
class bucket:
    """Wrap *iterable* and return an object that buckets the iterable into
    child iterables based on a *key* function.

    >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
    >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
    >>> sorted(list(s))  # Get the keys
    ['a', 'b', 'c']
    >>> a_iterable = s['a']
    >>> next(a_iterable)
    'a1'
    >>> next(a_iterable)
    'a2'
    >>> list(s['b'])
    ['b1', 'b2', 'b3']

    The original iterable will be advanced and its items will be cached until
    they are used by the child iterables. This may require significant storage.

    By default, attempting to select a bucket to which no items belong will
    exhaust the iterable and cache all values.
    If you specify a *validator* function, selected buckets will instead be
    checked against it.

    >>> from itertools import count
    >>> it = count(1, 2)  # Infinite sequence of odd numbers
    >>> key = lambda x: x % 10  # Bucket by last digit
    >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
    >>> s = bucket(it, key=key, validator=validator)
    >>> 2 in s
    False
    >>> list(s[2])
    []

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        # Maps a bucket key to a deque of items pulled from the source
        # but not yet consumed by that bucket's child iterator.
        self._cache = defaultdict(deque)
        # With no validator, every key is considered valid.
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        """Return True if at least one item belongs to bucket *value*."""
        if not self._validator(value):
            return False

        try:
            item = next(self[value])
        except StopIteration:
            return False
        else:
            # Peeked an item to answer the question; put it back at the
            # front of that bucket's cache so nothing is lost.
            self._cache[value].appendleft(item)

        return True

    def _get_values(self, value):
        """
        Helper to yield items from the parent iterator that match *value*.
        Items that don't match are stored in the local cache as they
        are encountered.
        """
        while True:
            # If we've cached some items that match the target value, emit
            # the first one and evict it from the cache.
            if self._cache[value]:
                yield self._cache[value].popleft()
            # Otherwise we need to advance the parent iterator to search for
            # a matching item, caching the rest.
            else:
                while True:
                    try:
                        item = next(self._it)
                    except StopIteration:
                        return
                    item_value = self._key(item)
                    if item_value == value:
                        yield item
                        break
                    elif self._validator(item_value):
                        self._cache[item_value].append(item)

    def __iter__(self):
        """Exhaust the source iterable and return an iterator over the
        bucket keys seen so far."""
        # Drain the parent iterator entirely, caching every valid item so
        # the set of keys is complete.
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        """Return the child iterator for bucket *value* (empty if the
        validator rejects *value*)."""
        if not self._validator(value):
            return iter(())

        return self._get_values(value)
1177
1178
def spy(iterable, n=1):
    """Return a 2-tuple with a list containing the first *n* elements of
    *iterable*, and an iterator with the same items as *iterable*.
    This allows you to "look ahead" at the items in the iterable without
    advancing it.

    There is one item in the list by default:

    >>> iterable = 'abcdefg'
    >>> head, iterable = spy(iterable)
    >>> head
    ['a']
    >>> list(iterable)
    ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    You may use unpacking to retrieve items instead of lists:

    >>> (head,), iterable = spy('abcdefg')
    >>> head
    'a'
    >>> (first, second), iterable = spy('abcdefg', 2)
    >>> first
    'a'
    >>> second
    'b'

    The number of items requested can be larger than the number of items in
    the iterable:

    >>> iterable = [1, 2, 3, 4, 5]
    >>> head, iterable = spy(iterable, 10)
    >>> head
    [1, 2, 3, 4, 5]
    >>> list(iterable)
    [1, 2, 3, 4, 5]

    """
    # Duplicate the stream: consume the peeked prefix from one copy and
    # hand back the untouched copy, which still starts at the beginning.
    ahead, untouched = tee(iterable)
    return list(islice(ahead, n)), untouched
1218
1219
def interleave(*iterables):
    """Return a new iterable yielding from each iterable in turn,
    until the shortest is exhausted.

    >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7]

    For a version that doesn't terminate after the shortest iterable is
    exhausted, see :func:`interleave_longest`.

    """
    # zip truncates at the shortest input; flatten its tuples lazily.
    return (item for group in zip(*iterables) for item in group)
1232
1233
def interleave_longest(*iterables):
    """Return a new iterable yielding from each iterable in turn,
    skipping any that are exhausted.

    >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7, 3, 8]

    This function produces the same output as :func:`roundrobin`, but may
    perform better for some inputs (in particular when the number of iterables
    is large).

    """
    # A unique sentinel marks the slots zip_longest pads in for exhausted
    # inputs; those are filtered out by identity.
    missing = object()
    for group in zip_longest(*iterables, fillvalue=missing):
        for item in group:
            if item is not missing:
                yield item
1250
1251
def interleave_evenly(iterables, lengths=None):
    """
    Interleave multiple iterables so that their elements are evenly distributed
    throughout the output sequence.

    >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
    >>> list(interleave_evenly(iterables))
    [1, 2, 'a', 3, 4, 'b', 5]

    >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
    >>> list(interleave_evenly(iterables))
    [1, 6, 4, 2, 7, 3, 8, 5]

    This function requires iterables of known length. Iterables without
    ``__len__()`` can be used by manually specifying lengths with *lengths*:

    >>> from itertools import combinations, repeat
    >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
    >>> lengths = [4 * (4 - 1) // 2, 3]
    >>> list(interleave_evenly(iterables, lengths=lengths))
    [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    Based on Bresenham's algorithm.
    """
    # Lengths are needed up front for the error bookkeeping below.
    if lengths is None:
        try:
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )
    elif len(iterables) != len(lengths):
        raise ValueError('Mismatching number of iterables and lengths.')

    dims = len(lengths)

    # sort iterables by length, descending
    lengths_permute = sorted(
        range(dims), key=lambda i: lengths[i], reverse=True
    )
    lengths_desc = [lengths[i] for i in lengths_permute]
    iters_desc = [iter(iterables[i]) for i in lengths_permute]

    # the longest iterable is the primary one (Bresenham: the longest
    # distance along an axis)
    delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:]
    iter_primary, iters_secondary = iters_desc[0], iters_desc[1:]
    # One error accumulator per secondary iterable; a secondary emits an
    # item whenever its accumulated error drops below zero.
    errors = [delta_primary // dims] * len(deltas_secondary)

    # Total number of items still to emit; loop ends when all are yielded.
    to_yield = sum(lengths)
    while to_yield:
        # Step along the primary axis: always emit from the longest iterable.
        yield next(iter_primary)
        to_yield -= 1
        # update errors for each secondary iterable
        errors = [e - delta for e, delta in zip(errors, deltas_secondary)]

        # those iterables for which the error is negative are yielded
        # ("diagonal step" in Bresenham)
        for i, e_ in enumerate(errors):
            if e_ < 0:
                yield next(iters_secondary[i])
                to_yield -= 1
                errors[i] += delta_primary
1316
1317
def collapse(iterable, base_type=None, levels=None):
    """Flatten an iterable with multiple levels of nesting (e.g., a list of
    lists of tuples) into non-iterable types.

    >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
    >>> list(collapse(iterable))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and
    will not be collapsed.

    To avoid collapsing other types, specify *base_type*:

    >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
    >>> list(collapse(iterable, base_type=tuple))
    ['ab', ('cd', 'ef'), 'gh', 'ij']

    Specify *levels* to stop flattening after a certain level:

    >>> iterable = [('a', ['b']), ('c', ['d'])]
    >>> list(collapse(iterable)) # Fully flattened
    ['a', 'b', 'c', 'd']
    >>> list(collapse(iterable, levels=1)) # Only one level flattened
    ['a', ['b'], 'c', ['d']]

    """
    # Depth-first traversal with an explicit stack of (depth, node iterator)
    # pairs, so deeply nested input cannot overflow the call stack.
    # The whole input is treated as a single root node at depth 0.
    stack = deque([(0, repeat(iterable, 1))])

    while stack:
        group = stack.popleft()
        depth, nodes = group

        # Groups nested deeper than *levels* are emitted without flattening
        if levels is not None and depth > levels:
            yield from nodes
            continue

        for node in nodes:
            # Strings/bytes (and *base_type* instances) are atomic
            if isinstance(node, (str, bytes)) or (
                (base_type is not None) and isinstance(node, base_type)
            ):
                yield node
                continue

            # Values that don't support iteration are atomic too
            try:
                children = iter(node)
            except TypeError:
                yield node
                continue

            # Suspend this group (its iterator remembers our position),
            # descend into the child, and resume the group afterwards.
            stack.appendleft(group)
            stack.appendleft((depth + 1, children))
            break
1376
1377
def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Invoke *func* on each item in *iterable* (or on each *chunk_size* group
    of items) before yielding the item.

    `func` must be a function that takes a single argument. Its return value
    will be discarded.

    *before* and *after* are optional functions that take no arguments. They
    will be executed before iteration starts and after it ends, respectively.

    `side_effect` can be used for logging, updating progress bars, or anything
    that is not functionally "pure."

    Emitting a status message:

    >>> from more_itertools import consume
    >>> func = lambda item: print('Received {}'.format(item))
    >>> consume(side_effect(func, range(2)))
    Received 0
    Received 1

    Operating on chunks of items:

    >>> pair_sums = []
    >>> func = lambda chunk: pair_sums.append(sum(chunk))
    >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
    [0, 1, 2, 3, 4, 5]
    >>> list(pair_sums)
    [1, 5, 9]

    Writing to a file-like object:

    >>> from io import StringIO
    >>> from more_itertools import consume
    >>> f = StringIO()
    >>> func = lambda x: print(x, file=f)
    >>> before = lambda: print(u'HEADER', file=f)
    >>> after = f.close
    >>> it = [u'a', u'b', u'c']
    >>> consume(side_effect(func, it, before=before, after=after))
    >>> f.closed
    True

    """
    # *before* runs inside the try so that *after* still fires (via
    # finally) even if *before* or *func* raises.
    try:
        if before is not None:
            before()

        if chunk_size is not None:
            for batch in chunked(iterable, chunk_size):
                func(batch)
                yield from batch
        else:
            for element in iterable:
                func(element)
                yield element
    finally:
        if after is not None:
            after()
1437
1438
def sliced(seq, n, strict=False):
    """Yield slices of length *n* from the sequence *seq*.

    >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
    [(1, 2, 3), (4, 5, 6)]

    By default, the last yielded slice will have fewer than *n* elements
    if the length of *seq* is not divisible by *n*:

    >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
    [(1, 2, 3), (4, 5, 6), (7, 8)]

    If the length of *seq* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    slice is yielded.

    This function will only work for iterables that support slicing.
    For non-sliceable iterables, see :func:`chunked`.

    """
    # Slice at every multiple of n; the first empty slice marks the end.
    slices = takewhile(len, (seq[start : start + n] for start in count(0, n)))
    if not strict:
        return slices

    def _strict_slices():
        for piece in slices:
            if len(piece) == n:
                yield piece
            else:
                raise ValueError("seq is not divisible by n.")

    return _strict_slices()
1471
1472
def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Yield lists of items from *iterable*, where each list is delimited by
    an item where callable *pred* returns ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b'))
    [['a'], ['c', 'd', 'c'], ['a']]

    >>> list(split_at(range(10), lambda n: n % 2 == 1))
    [[0], [2], [4], [6], [8], []]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
    [[0], [2], [4, 5, 6, 7, 8, 9]]

    By default, the delimiting items are not included in the output.
    To include them, set *keep_separator* to ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
    [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    group = []
    for element in iterator:
        if not pred(element):
            group.append(element)
            continue
        # Hit a separator: emit the finished group (and optionally the
        # separator itself in its own list).
        yield group
        if keep_separator:
            yield [element]
        if maxsplit == 1:
            # Split budget used up; the rest is a single final group.
            yield list(iterator)
            return
        group = []
        maxsplit -= 1
    yield group
1515
1516
def split_before(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends just before
    an item for which callable *pred* returns ``True``:

    >>> list(split_before('OneTwo', lambda s: s.isupper()))
    [['O', 'n', 'e'], ['T', 'w', 'o']]

    >>> list(split_before(range(10), lambda n: n % 3 == 0))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    group = []
    for element in iterator:
        # A separator only closes a group once one has accumulated, so a
        # leading separator doesn't produce an empty first group.
        if group and pred(element):
            yield group
            if maxsplit == 1:
                yield [element, *iterator]
                return
            group = []
            maxsplit -= 1
        group.append(element)
    if group:
        yield group
1550
1551
def split_after(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends with an
    item where callable *pred* returns ``True``:

    >>> list(split_after('one1two2', lambda s: s.isdigit()))
    [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

    >>> list(split_after(range(10), lambda n: n % 3 == 0))
    [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    group = []
    for element in iterator:
        group.append(element)
        # The group is closed immediately after a matching item.
        if pred(element):
            yield group
            if maxsplit == 1:
                # Split budget used up: the remainder forms a final group,
                # but only if it's non-empty.
                remainder = list(iterator)
                if remainder:
                    yield remainder
                return
            group = []
            maxsplit -= 1
    if group:
        yield group
1588
1589
def split_when(iterable, pred, maxsplit=-1):
    """Split *iterable* into pieces based on the output of *pred*.
    *pred* should be a function that takes successive pairs of items and
    returns ``True`` if the iterable should be split in between them.

    For example, to find runs of increasing numbers, split the iterable when
    element ``i`` is larger than element ``i + 1``:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
    [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
    ...                 lambda x, y: x > y, maxsplit=2))
    [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    try:
        previous = next(iterator)
    except StopIteration:
        # Empty input: emit nothing at all
        return

    group = [previous]
    for current in iterator:
        # Split between *previous* and *current* when pred says so.
        if pred(previous, current):
            yield group
            if maxsplit == 1:
                yield [current, *iterator]
                return
            group = []
            maxsplit -= 1

        group.append(current)
        previous = current

    yield group
1633
1634
def split_into(iterable, sizes):
    """Yield a list of sequential items from *iterable* of length 'n' for each
    integer 'n' in *sizes*.

    >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
    [[1], [2, 3], [4, 5, 6]]

    If the sum of *sizes* is smaller than the length of *iterable*, then the
    remaining items of *iterable* will not be returned.

    >>> list(split_into([1,2,3,4,5,6], [2,3]))
    [[1, 2], [3, 4, 5]]

    If the sum of *sizes* is larger than the length of *iterable*, fewer items
    will be returned in the iteration that overruns the *iterable* and further
    lists will be empty:

    >>> list(split_into([1,2,3,4], [1,2,3,4]))
    [[1], [2, 3], [4], []]

    When a ``None`` object is encountered in *sizes*, the returned list will
    contain items up to the end of *iterable* the same way that
    :func:`itertools.islice` does:

    >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
    [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    :func:`split_into` can be useful for grouping a series of items where the
    sizes of the groups are not uniform. An example would be where in a row
    from a table, multiple columns represent elements of the same feature
    (e.g. a point represented by x,y,z) but, the format is not the same for
    all columns.
    """
    # Convert the iterable argument into an iterator so its contents can
    # be consumed by islice in case it is a generator; successive islice
    # calls then resume where the previous one stopped.
    it = iter(iterable)

    for size in sizes:
        if size is None:
            # None consumes the remainder, mirroring islice(it, None)
            yield list(it)
            return
        yield list(islice(it, size))
1678
1679
def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the elements from *iterable*, followed by *fillvalue*, such that
    at least *n* items are emitted.

    >>> list(padded([1, 2, 3], '?', 5))
    [1, 2, 3, '?', '?']

    If *next_multiple* is ``True``, *fillvalue* will be emitted until the
    number of items emitted is a multiple of *n*:

    >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
    [1, 2, 3, 4, None, None]

    If *n* is ``None``, *fillvalue* will be emitted indefinitely.

    To create an *iterable* of exactly size *n*, you can truncate with
    :func:`islice`.

    >>> list(islice(padded([1, 2, 3], '?'), 5))
    [1, 2, 3, '?', '?']
    >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
    [1, 2, 3, 4, 5]

    """
    iterator = iter(iterable)
    # The source followed by an endless supply of fill values
    padded_iterator = chain(iterator, repeat(fillvalue))

    if n is None:
        return padded_iterator
    if n < 1:
        raise ValueError('n must be at least 1')

    if next_multiple:

        def batches():
            # Each real leading element starts a batch that is topped up
            # to exactly *n* items with fill values if necessary.
            for leading in iterator:
                yield (leading,)
                yield islice(padded_iterator, n - 1)

        return chain.from_iterable(batches())

    # Pad the first n items if needed, then pass the rest through untouched
    return chain(islice(padded_iterator, n), iterator)
1723
1724
def repeat_each(iterable, n=2):
    """Repeat each element in *iterable* *n* times.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    # Lazily expand each element into n copies and flatten the result
    return chain.from_iterable(repeat(element, n) for element in iterable)
1732
1733
def repeat_last(iterable, default=None):
    """After the *iterable* is exhausted, keep yielding its last element.

    >>> list(islice(repeat_last(range(3)), 5))
    [0, 1, 2, 2, 2]

    If the iterable is empty, yield *default* forever::

    >>> list(islice(repeat_last(range(0), 42), 5))
    [42, 42, 42, 42, 42]

    """
    # A unique sentinel distinguishes "no items seen" from a last item
    # that happens to be None (or any other value).
    missing = last = object()
    for last in iterable:
        yield last
    yield from repeat(default if last is missing else last)
1751
1752
def distribute(n, iterable):
    """Distribute the items from *iterable* among *n* smaller iterables.

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

    >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 4, 7], [2, 5], [3, 6]]

    If the length of *iterable* is smaller than *n*, then the last returned
    iterables will be empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    This function uses :func:`itertools.tee` and may require significant
    storage.

    If you need the order items in the smaller iterables to match the
    original iterable, see :func:`divide`.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Child k takes every n-th item of its own tee copy, starting at
    # offset k — i.e. items are dealt out round-robin.
    copies = tee(iterable, n)
    return [islice(copy, start, None, n) for start, copy in enumerate(copies)]
1788
1789
def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples whose elements are offset from *iterable*.
    The amount by which the `i`-th item in each tuple is offset is given by
    the `i`-th item in *offsets*.

    >>> list(stagger([0, 1, 2, 3]))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
    >>> list(stagger(range(8), offsets=(0, 2, 4)))
    [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default, the sequence will end when the final element of a tuple is the
    last item in the iterable. To continue until the first element of a tuple
    is the last item in the iterable, set *longest* to ``True``::

    >>> list(stagger([0, 1, 2, 3], longest=True))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    By default, ``None`` will be used to replace offsets beyond the end of the
    sequence. Specify *fillvalue* to use some other value.

    """
    # One independent copy of the input per requested offset; zip_offset
    # does the actual shifting and zipping.
    copies = tee(iterable, len(offsets))
    return zip_offset(
        *copies, offsets=offsets, longest=longest, fillvalue=fillvalue
    )
1816
1817
def zip_equal(*iterables):
    """``zip`` the input *iterables* together but raise
    ``UnequalIterablesError`` if they aren't all the same length.

    >>> it_1 = range(3)
    >>> it_2 = iter('abc')
    >>> list(zip_equal(it_1, it_2))
    [(0, 'a'), (1, 'b'), (2, 'c')]

    >>> it_1 = range(3)
    >>> it_2 = iter('abcd')
    >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    more_itertools.more.UnequalIterablesError: Iterables have different
    lengths

    """
    # Deprecated on Python >= 3.10.0a6, where zip(strict=True) exists.
    if hexversion >= 0x30A00A6:
        warnings.warn(
            'zip_equal will be removed in a future version of '
            'more-itertools. Use the builtin zip function with '
            'strict=True instead.',
            DeprecationWarning,
        )

    return _zip_equal(*iterables)
1847
1848
def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the input *iterables* together, but offset the `i`-th iterable
    by the `i`-th item in *offsets*.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    This can be used as a lightweight alternative to SciPy or pandas to analyze
    data sets in which some series have a lead or lag relationship.

    By default, the sequence will end when the shortest iterable is exhausted.
    To continue until the longest iterable is exhausted, set *longest* to
    ``True``.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    By default, ``None`` will be used to replace offsets beyond the end of the
    sequence. Specify *fillvalue* to use some other value.

    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    # A positive offset skips leading items; a negative one pads the
    # front with fill values.
    shifted = []
    for iterable, offset in zip(iterables, offsets):
        if offset > 0:
            shifted.append(islice(iterable, offset, None))
        elif offset < 0:
            shifted.append(chain(repeat(fillvalue, -offset), iterable))
        else:
            shifted.append(iterable)

    if longest:
        return zip_longest(*shifted, fillvalue=fillvalue)
    return zip(*shifted)
1886
1887
def sort_together(
    iterables, key_list=(0,), key=None, reverse=False, strict=False
):
    """Return the input iterables sorted together, with *key_list* as the
    priority for sorting. All iterables are trimmed to the length of the
    shortest one.

    This can be used like the sorting function in a spreadsheet. If each
    iterable represents a column of data, the key list determines which
    columns are used for sorting.

    By default, all iterables are sorted using the ``0``-th iterable::

    >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
    >>> sort_together(iterables)
    [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]

    Set a different key list to sort according to another iterable.
    Specifying multiple keys dictates how ties are broken::

    >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
    >>> sort_together(iterables, key_list=(1, 2))
    [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]

    To sort by a function of the elements of the iterable, pass a *key*
    function. Its arguments are the elements of the iterables corresponding to
    the key list::

    >>> names = ('a', 'b', 'c')
    >>> lengths = (1, 2, 3)
    >>> widths = (5, 2, 1)
    >>> def area(length, width):
    ...     return length * width
    >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
    [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]

    Set *reverse* to ``True`` to sort in descending order.

    >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
    [(3, 2, 1), ('a', 'b', 'c')]

    If the *strict* keyword argument is ``True``, then
    ``UnequalIterablesError`` will be raised if any of the iterables have
    different lengths.

    """
    if key is None:
        # No key function: sort rows directly on the items at the
        # key_list offsets.
        key_argument = itemgetter(*key_list)
    else:
        key_list = list(key_list)
        if len(key_list) == 1:
            # A single offset: pass that one item to the key function.
            offset = key_list[0]

            def key_argument(zipped_items):
                return key(zipped_items[offset])

        else:
            # Multiple offsets: unpack the selected items as positional
            # arguments to the key function.
            get_key_items = itemgetter(*key_list)

            def key_argument(zipped_items):
                return key(*get_key_items(zipped_items))

    # Transpose to rows, sort the rows, then transpose back to columns.
    zipper = zip_equal if strict else zip
    rows = sorted(zipper(*iterables), key=key_argument, reverse=reverse)
    return list(zipper(*rows))
1959
1960
def unzip(iterable):
    """The inverse of :func:`zip`, this function disaggregates the elements
    of the zipped *iterable*.

    The ``i``-th iterable contains the ``i``-th element from each element
    of the zipped iterable. The first element is used to determine the
    length of the remaining elements.

    >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> letters, numbers = unzip(iterable)
    >>> list(letters)
    ['a', 'b', 'c', 'd']
    >>> list(numbers)
    [1, 2, 3, 4]

    This is similar to using ``zip(*iterable)``, but it avoids reading
    *iterable* into memory. Note, however, that this function uses
    :func:`itertools.tee` and thus may require significant storage.

    """
    iterator = iter(iterable)
    try:
        # Peek at the first element to learn how many streams to produce.
        first = next(iterator)
    except StopIteration:
        # empty iterable, e.g. zip([], [], [])
        return ()
    # Reattach the peeked element, then make one tee copy per stream.
    copies = tee(chain((first,), iterator), len(first))

    def make_getter(index):
        def getter(obj):
            try:
                return obj[index]
            except IndexError:
                # Some element was too short for this stream. Rather than
                # raise, end the stream here so "improperly zipped" inputs
                # like iter([(1, 2, 3), (4, 5), (6,)]) are supported: each
                # unzipped iterable simply stops at the first length
                # mismatch.
                raise StopIteration

        return getter

    return tuple(map(make_getter(i), it) for i, it in enumerate(copies))
2008
2009
def divide(n, iterable):
    """Divide the elements from *iterable* into *n* parts, maintaining
    order.

    >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 2, 3]
    >>> list(group_2)
    [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

    >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than n, then the last returned
    iterables will be empty:

    >>> children = divide(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    This function will exhaust the iterable before returning.
    If order is not important, see :func:`distribute`, which does not first
    pull the iterable into memory.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Materialize only when the input doesn't already support slicing
    try:
        iterable[:0]
    except TypeError:
        seq = tuple(iterable)
    else:
        seq = iterable

    # The first *remainder* parts each get one extra item
    quotient, remainder = divmod(len(seq), n)
    sizes = [quotient + 1] * remainder + [quotient] * (n - remainder)

    parts = []
    start = 0
    for size in sizes:
        parts.append(iter(seq[start : start + size]))
        start += size

    return parts
2059
2060
def always_iterable(obj, base_type=(str, bytes)):
    """If *obj* is iterable, return an iterator over its items::

    >>> obj = (1, 2, 3)
    >>> list(always_iterable(obj))
    [1, 2, 3]

    If *obj* is not iterable, return a one-item iterable containing *obj*::

    >>> obj = 1
    >>> list(always_iterable(obj))
    [1]

    If *obj* is ``None``, return an empty iterable:

    >>> obj = None
    >>> list(always_iterable(None))
    []

    By default, binary and text strings are not considered iterable::

    >>> obj = 'foo'
    >>> list(always_iterable(obj))
    ['foo']

    If *base_type* is set, objects for which ``isinstance(obj, base_type)``
    returns ``True`` won't be considered iterable.

    >>> obj = {'a': 1}
    >>> list(always_iterable(obj)) # Iterate over the dict's keys
    ['a']
    >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit
    [{'a': 1}]

    Set *base_type* to ``None`` to avoid any special handling and treat objects
    Python considers iterable as iterable:

    >>> obj = 'foo'
    >>> list(always_iterable(obj, base_type=None))
    ['f', 'o', 'o']
    """
    if obj is None:
        return iter(())

    # base_type instances are treated as single units even if iterable
    treat_as_unit = (base_type is not None) and isinstance(obj, base_type)
    if not treat_as_unit:
        try:
            return iter(obj)
        except TypeError:
            pass

    return iter((obj,))
2112
2113
def adjacent(predicate, iterable, distance=1):
    """Return an iterable over `(bool, item)` tuples where the `item` is
    drawn from *iterable* and the `bool` indicates whether
    that item satisfies the *predicate* or is adjacent to an item that does.

    For example, to find whether items are adjacent to a ``3``::

    >>> list(adjacent(lambda x: x == 3, range(6)))
    [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    Set *distance* to change what counts as adjacent. For example, to find
    whether items are two places away from a ``3``:

    >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
    [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    This is useful for contextualizing the results of a search function.
    For example, a code comparison tool might want to identify lines that
    have changed, but also surrounding lines to give the viewer of the diff
    context.

    The predicate function will only be called once for each item in the
    iterable.

    See also :func:`groupby_transform`, which can be used with this function
    to group ranges of items with the same `bool` value.

    """
    # Allow distance=0 mainly for testing that it reproduces results with map()
    if distance < 0:
        raise ValueError('distance must be at least 0')

    items, originals = tee(iterable)
    # Pad the match flags with *distance* False values on both sides so the
    # sliding window below is centered on each original item.
    flags = chain(
        repeat(False, distance),
        map(predicate, items),
        repeat(False, distance),
    )
    window_size = 2 * distance + 1

    def sliding_any(bools):
        # True for each window of *window_size* flags containing a match
        window = deque(islice(bools, window_size - 1), maxlen=window_size)
        for flag in bools:
            window.append(flag)
            yield any(window)

    return zip(sliding_any(flags), originals)
2151
2152
def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
    """An extension of :func:`itertools.groupby` that can apply transformations
    to the grouped data.

    * *keyfunc* is a function computing a key value for each item in *iterable*
    * *valuefunc* is a function that transforms the individual items from
      *iterable* after grouping
    * *reducefunc* is a function that transforms each group of items

    >>> iterable = 'aAAbBBcCC'
    >>> keyfunc = lambda k: k.upper()
    >>> valuefunc = lambda v: v.lower()
    >>> reducefunc = lambda g: ''.join(g)
    >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
    [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]

    Each optional argument defaults to an identity function if not specified.

    :func:`groupby_transform` is useful when grouping elements of an iterable
    using a separate iterable as the key. To do this, :func:`zip` the iterables
    and pass a *keyfunc* that extracts the first element and a *valuefunc*
    that extracts the second element::

    >>> from operator import itemgetter
    >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
    >>> values = 'abcdefghi'
    >>> iterable = zip(keys, values)
    >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
    >>> [(k, ''.join(g)) for k, g in grouper]
    [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]

    Note that the order of items in the iterable is significant.
    Only adjacent items are grouped together, so if you don't want any
    duplicate groups, you should sort the iterable by the key function.

    """
    # Group first, then lazily wrap the groups with the optional
    # transformations so nothing is evaluated until iteration.
    groups = groupby(iterable, keyfunc)
    if valuefunc:
        groups = ((k, map(valuefunc, group)) for k, group in groups)
    if reducefunc:
        groups = ((k, reducefunc(group)) for k, group in groups)

    return groups
2196
2197
class numeric_range(abc.Sequence, abc.Hashable):
    """An extension of the built-in ``range()`` function whose arguments can
    be any orderable numeric type.

    With only *stop* specified, *start* defaults to ``0`` and *step*
    defaults to ``1``. The output items will match the type of *stop*:

    >>> list(numeric_range(3.5))
    [0.0, 1.0, 2.0, 3.0]

    With only *start* and *stop* specified, *step* defaults to ``1``. The
    output items will match the type of *start*:

    >>> from decimal import Decimal
    >>> start = Decimal('2.1')
    >>> stop = Decimal('5.1')
    >>> list(numeric_range(start, stop))
    [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    With *start*, *stop*, and *step* specified the output items will match
    the type of ``start + step``:

    >>> from fractions import Fraction
    >>> start = Fraction(1, 2)  # Start at 1/2
    >>> stop = Fraction(5, 2)  # End at 5/2
    >>> step = Fraction(1, 2)  # Count by 1/2
    >>> list(numeric_range(start, stop, step))
    [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    If *step* is zero, ``ValueError`` is raised. Negative steps are supported:

    >>> list(numeric_range(3, -1, -1.0))
    [3.0, 2.0, 1.0, 0.0]

    Be aware of the limitations of floating-point numbers; the representation
    of the yielded numbers may be surprising.

    ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
    is a ``datetime.timedelta`` object:

    >>> import datetime
    >>> start = datetime.datetime(2019, 1, 1)
    >>> stop = datetime.datetime(2019, 1, 3)
    >>> step = datetime.timedelta(days=1)
    >>> items = iter(numeric_range(start, stop, step))
    >>> next(items)
    datetime.datetime(2019, 1, 1, 0, 0)
    >>> next(items)
    datetime.datetime(2019, 1, 2, 0, 0)

    """

    # All empty ranges hash alike, mirroring the behavior of built-in range().
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        argc = len(args)
        if argc == 1:
            # numeric_range(stop): derive zero-valued start and unit step in
            # types compatible with stop (e.g. 3.5 -> start 0.0, step 1.0).
            (self._stop,) = args
            self._start = type(self._stop)(0)
            self._step = type(self._stop - self._start)(1)
        elif argc == 2:
            # numeric_range(start, stop): the step's type is derived from the
            # difference so that e.g. datetime endpoints get a timedelta step.
            self._start, self._stop = args
            self._step = type(self._stop - self._start)(1)
        elif argc == 3:
            self._start, self._stop, self._step = args
        elif argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        else:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        # A zero of the step's type, used for all sign/modulo comparisons.
        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        # True when counting upward (positive step).
        self._growing = self._step > self._zero

    def __bool__(self):
        # Non-empty iff at least one value lies between start and stop.
        if self._growing:
            return self._start < self._stop
        else:
            return self._start > self._stop

    def __contains__(self, elem):
        # Membership requires elem to be inside the half-open interval AND
        # to be an exact whole number of steps away from start.
        if self._growing:
            if self._start <= elem < self._stop:
                return (elem - self._start) % self._step == self._zero
        else:
            if self._start >= elem > self._stop:
                return (self._start - elem) % (-self._step) == self._zero

        return False

    def __eq__(self, other):
        if isinstance(other, numeric_range):
            empty_self = not bool(self)
            empty_other = not bool(other)
            if empty_self or empty_other:
                return empty_self and empty_other  # True if both empty
            else:
                # Compare the last element rather than stop, so ranges that
                # produce the same values compare equal even if their stops
                # differ (as with built-in range()).
                return (
                    self._start == other._start
                    and self._step == other._step
                    and self._get_by_index(-1) == other._get_by_index(-1)
                )
        else:
            return False

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._get_by_index(key)
        elif isinstance(key, slice):
            # Scale the slice's step by our own step to stay in our units.
            step = self._step if key.step is None else key.step * self._step

            # Clamp out-of-bounds slice indices to the range's ends; only
            # in-bounds indices need conversion to actual values.
            if key.start is None or key.start <= -self._len:
                start = self._start
            elif key.start >= self._len:
                start = self._stop
            else:  # -self._len < key.start < self._len
                start = self._get_by_index(key.start)

            if key.stop is None or key.stop >= self._len:
                stop = self._stop
            elif key.stop <= -self._len:
                stop = self._start
            else:  # -self._len < key.stop < self._len
                stop = self._get_by_index(key.stop)

            return numeric_range(start, stop, step)
        else:
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )

    def __hash__(self):
        # Hash on (start, last, step) so equal ranges hash equally,
        # consistent with __eq__ above.
        if self:
            return hash((self._start, self._get_by_index(-1), self._step))
        else:
            return self._EMPTY_HASH

    def __iter__(self):
        # Generate start + n*step and cut off at stop; partial(gt, stop)
        # tests stop > value (and lt tests stop < value when shrinking).
        values = (self._start + (n * self._step) for n in count())
        if self._growing:
            return takewhile(partial(gt, self._stop), values)
        else:
            return takewhile(partial(lt, self._stop), values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        # Normalize to a growing range so a single ceiling-division formula
        # applies: len = ceil((stop - start) / step).
        if self._growing:
            start = self._start
            stop = self._stop
            step = self._step
        else:
            start = self._stop
            stop = self._start
            step = -self._step
        distance = stop - start
        if distance <= self._zero:
            return 0
        else:  # distance > 0 and step > 0: regular euclidean division
            q, r = divmod(distance, step)
            return int(q) + int(r != self._zero)

    def __reduce__(self):
        # Support pickling by reconstructing from the original arguments.
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        if self._step == 1:
            return f"numeric_range({self._start!r}, {self._stop!r})"
        return (
            f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
        )

    def __reversed__(self):
        # Walk from the last element back past start using a negated step.
        return iter(
            numeric_range(
                self._get_by_index(-1), self._start - self._step, -self._step
            )
        )

    def count(self, value):
        # A range never repeats values, so the count is 0 or 1.
        return int(value in self)

    def index(self, value):
        # Same interval-plus-divisibility test as __contains__, but returns
        # the number of steps from start when the value is present.
        if self._growing:
            if self._start <= value < self._stop:
                q, r = divmod(value - self._start, self._step)
                if r == self._zero:
                    return int(q)
        else:
            if self._start >= value > self._stop:
                q, r = divmod(self._start - value, -self._step)
                if r == self._zero:
                    return int(q)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        # Translate a (possibly negative) integer index into a value.
        if i < 0:
            i += self._len
        if i < 0 or i >= self._len:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step
2408
2409
def count_cycle(iterable, n=None):
    """Cycle through the items of *iterable* up to *n* times, pairing each
    item with the number of completed cycles. With *n* omitted, cycle
    forever.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    """
    # Materialize once so the items can be replayed each cycle.
    pool = tuple(iterable)
    if not pool:
        return iter(())
    cycle_numbers = count() if n is None else range(n)
    return (
        (cycle_number, item) for cycle_number in cycle_numbers for item in pool
    )
2424
2425
def mark_ends(iterable):
    """Yield 3-tuples of the form ``(is_first, is_last, item)``.

    >>> list(mark_ends('ABC'))
    [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    Use this when looping over an iterable to take special action on its first
    and/or last items:

    >>> iterable = ['Header', 100, 200, 'Footer']
    >>> total = 0
    >>> for is_first, is_last, item in mark_ends(iterable):
    ...     if is_first:
    ...         continue  # Skip the header
    ...     if is_last:
    ...         continue  # Skip the footer
    ...     total += item
    >>> print(total)
    300
    """
    it = iter(iterable)

    # An empty iterable produces nothing at all.
    try:
        previous = next(it)
    except StopIteration:
        return

    # Hold one item back so we can tell whether it is the last one.
    is_first = True
    for current in it:
        yield is_first, False, previous
        previous = current
        is_first = False

    # The held-back item is the final one (and also the first, if singleton).
    yield is_first, True, previous
2461
2462
def locate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
    [1, 2, 4]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item.

    >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
    [1, 3]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(locate(iterable, pred=pred, window_size=3))
    [1, 5, 9]

    Use with :func:`seekable` to find indexes and then retrieve the associated
    items:

    >>> from itertools import count
    >>> from more_itertools import seekable
    >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
    >>> it = seekable(source)
    >>> pred = lambda x: x > 100
    >>> indexes = locate(it, pred=pred)
    >>> i = next(indexes)
    >>> it.seek(i)
    >>> next(it)
    106

    """
    # Single-item case: keep the indexes where the predicate is truthy.
    if window_size is None:
        return compress(count(), map(pred, iterable))

    if window_size < 1:
        raise ValueError('window size must be at least 1')

    # Windowed case: apply the predicate to each window's items, padding
    # short trailing windows with a sentinel.
    windows = windowed(iterable, window_size, fillvalue=_marker)
    return compress(count(), starmap(pred, windows))
2509
2510
def longest_common_prefix(iterables):
    """Yield elements of the longest common prefix among given *iterables*.

    >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
    'ab'

    """

    # A column (one position across all iterables) belongs to the common
    # prefix only while every iterable agrees on its value.
    def _uniform(column):
        head = column[0]
        return all(item == head for item in column)

    return (column[0] for column in takewhile(_uniform, zip(*iterables)))
2519
2520
def lstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the beginning
    for which *pred* returns ``True``.

    For example, to remove a set of items from the start of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(lstrip(iterable, pred))
    [1, 2, None, 3, False, None]

    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.

    """
    # Fixed docstring typos ("to to", "an wrapper"); code is unchanged.
    return dropwhile(pred, iterable)
2537
2538
def rstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the end
    for which *pred* returns ``True``.

    For example, to remove a set of items from the end of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(rstrip(iterable, pred))
    [None, False, None, 1, 2, None, 3]

    This function is analogous to :func:`str.rstrip`.

    """
    # Buffer matching items; they are emitted only if a non-matching item
    # follows them (a trailing run of matches is discarded).
    pending = []
    for item in iterable:
        if pred(item):
            pending.append(item)
        else:
            yield from pending
            pending = []
            yield item
2563
2564
def strip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the
    beginning and end for which *pred* returns ``True``.

    For example, to remove a set of items from both ends of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(strip(iterable, pred))
    [1, 2, None, 3]

    This function is analogous to :func:`str.strip`.

    """
    # Skip the leading run of matching items, then buffer matching items
    # so that a trailing run of them is never emitted.
    pending = []
    for item in dropwhile(pred, iterable):
        if pred(item):
            pending.append(item)
        else:
            yield from pending
            pending = []
            yield item
2580
2581
class islice_extended:
    """An extension of :func:`itertools.islice` that supports negative values
    for *stop*, *start*, and *step*.

    >>> iterator = iter('abcdefgh')
    >>> list(islice_extended(iterator, -4, -1))
    ['e', 'f', 'g']

    Slices with negative values require some caching of *iterable*, but this
    function takes care to minimize the amount of memory required.

    For example, you can use a negative step with an infinite iterator:

    >>> from itertools import count
    >>> list(islice_extended(count(), 110, 99, -2))
    [110, 108, 106, 104, 102, 100]

    You can also use slice notation directly:

    >>> iterator = map(str, count())
    >>> it = islice_extended(iterator)[10:20:2]
    >>> list(it)
    ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        # With slice arguments, delegate to the helper; otherwise pass the
        # source iterator through untouched.
        source = iter(iterable)
        self._iterator = _islice_helper(source, slice(*args)) if args else source

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        # Only slice subscripts are meaningful; reject everything else.
        if not isinstance(key, slice):
            raise TypeError(
                'islice_extended.__getitem__ argument must be a slice'
            )
        return islice_extended(_islice_helper(self._iterator, key))
2626
2627
def _islice_helper(it, s):
    """Yield the items of iterator *it* selected by slice *s*, supporting
    negative *start*, *stop*, and *step* values (used by islice_extended).

    Negative indices are resolved by caching just enough items in a bounded
    deque to know where the end of the iterator falls.
    """
    start = s.start
    stop = s.stop
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    step = s.step or 1

    if step > 0:
        start = 0 if (start is None) else start

        if start < 0:
            # Consume all but the last -start items
            cache = deque(enumerate(it, 1), maxlen=-start)
            len_iter = cache[-1][0] if cache else 0

            # Adjust start to be positive
            i = max(len_iter + start, 0)

            # Adjust stop to be positive
            if stop is None:
                j = len_iter
            elif stop >= 0:
                j = min(stop, len_iter)
            else:
                j = max(len_iter + stop, 0)

            # Slice the cache
            n = j - i
            if n <= 0:
                return

            for index in range(n):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()[1]
                else:
                    # just pop and discard the item
                    cache.popleft()
        elif (stop is not None) and (stop < 0):
            # Advance to the start position
            next(islice(it, start, start), None)

            # When stop is negative, we have to carry -stop items while
            # iterating
            cache = deque(islice(it, -stop), maxlen=-stop)

            for index, item in enumerate(it):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
                cache.append(item)
        else:
            # When both start and stop are positive we have the normal case
            yield from islice(it, start, stop, step)
    else:
        start = -1 if (start is None) else start

        if (stop is not None) and (stop < 0):
            # Consume all but the last items
            n = -stop - 1
            cache = deque(enumerate(it, 1), maxlen=n)
            len_iter = cache[-1][0] if cache else 0

            # If start and stop are both negative they are comparable and
            # we can just slice. Otherwise we can adjust start to be negative
            # and then slice.
            if start < 0:
                i, j = start, stop
            else:
                i, j = min(start - len_iter, -1), None

            for index, item in list(cache)[i:j:step]:
                yield item
        else:
            # Advance to the stop position
            if stop is not None:
                m = stop + 1
                next(islice(it, m, m), None)

            # stop is positive, so if start is negative they are not comparable
            # and we need the rest of the items.
            if start < 0:
                i = start
                n = None
            # stop is None and start is positive, so we just need items up to
            # the start index.
            elif stop is None:
                i = None
                n = start + 1
            # Both stop and start are positive, so they are comparable.
            else:
                i = None
                n = start - stop
                if n <= 0:
                    return

            cache = list(islice(it, n))

            yield from cache[i::step]
2734
2735
def always_reversible(iterable):
    """An extension of :func:`reversed` that supports all iterables, not
    just those which implement the ``Reversible`` or ``Sequence`` protocols.

    >>> print(*always_reversible(x for x in range(3)))
    2 1 0

    If the iterable is already reversible, this function returns the
    result of :func:`reversed()`. If the iterable is not reversible,
    this function will cache the remaining items in the iterable and
    yield them in reverse order, which may require significant storage.
    """
    # Try the cheap path first; fall through to materializing the items
    # only when the iterable does not support reversed().
    try:
        return reversed(iterable)
    except TypeError:
        pass
    return reversed(list(iterable))
2752
2753
def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items using :func:`itertools.groupby`.
    The *ordering* function determines whether two items are adjacent by
    returning their position.

    By default, the ordering function is the identity function. This is
    suitable for finding runs of numbers:

    >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
    >>> for group in consecutive_groups(iterable):
    ...     print(list(group))
    [1]
    [10, 11, 12]
    [20]
    [30, 31, 32, 33]
    [40]

    To find runs of adjacent letters, apply :func:`ord` function
    to convert letters to ordinals.

    >>> iterable = 'abcdfgilmnop'
    >>> ordering = ord
    >>> for group in consecutive_groups(iterable, ordering):
    ...     print(list(group))
    ['a', 'b', 'c', 'd']
    ['f', 'g']
    ['i']
    ['l', 'm', 'n', 'o', 'p']

    Each group of consecutive items is an iterator that shares it source with
    *iterable*. When an an output group is advanced, the previous group is
    no longer available unless its elements are copied (e.g., into a ``list``).

    >>> iterable = [1, 2, 11, 12, 21, 22]
    >>> saved_groups = []
    >>> for group in consecutive_groups(iterable):
    ...     saved_groups.append(list(group))  # Copy group elements
    >>> saved_groups
    [[1, 2], [11, 12], [21, 22]]

    """
    # Consecutive items share a constant difference between their enumerated
    # index and their position value, so that difference is the group key.
    if ordering is None:

        def _key(pair):
            index, item = pair
            return index - item

    else:

        def _key(pair):
            index, item = pair
            return index - ordering(item)

    for _, group in groupby(enumerate(iterable), key=_key):
        yield map(itemgetter(1), group)
2802
2803
def difference(iterable, func=sub, *, initial=None):
    """This function is the inverse of :func:`itertools.accumulate`. By default
    it will compute the first difference of *iterable* using
    :func:`operator.sub`:

    >>> from itertools import accumulate
    >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
    >>> list(difference(iterable))
    [0, 1, 2, 3, 4]

    *func* defaults to :func:`operator.sub`, but other functions can be
    specified. They will be applied as follows::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, to do progressive division:

    >>> iterable = [1, 2, 6, 24, 120]
    >>> func = lambda x, y: x // y
    >>> list(difference(iterable, func))
    [1, 2, 3, 4, 5]

    If the *initial* keyword is set, the first element will be skipped when
    computing successive differences.

    >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
    >>> list(difference(it, initial=10))
    [1, 2, 3]

    """
    # Two synchronized copies: `leading` runs one item ahead of `trailing`
    # so func sees each item paired with its predecessor.
    trailing, leading = tee(iterable)
    try:
        head = [next(leading)]
    except StopIteration:
        return iter([])

    # When an initial value was supplied, the first item is itself a
    # difference and should not be emitted verbatim.
    if initial is not None:
        head = []

    return chain(head, map(func, leading, trailing))
2844
2845
class SequenceView(Sequence):
    """Return a read-only view of the sequence object *target*.

    :class:`SequenceView` objects are analogous to Python's built-in
    "dictionary view" types. They provide a dynamic view of a sequence's items,
    meaning that when the sequence updates, so does the view.

    >>> seq = ['0', '1', '2']
    >>> view = SequenceView(seq)
    >>> view
    SequenceView(['0', '1', '2'])
    >>> seq.append('3')
    >>> view
    SequenceView(['0', '1', '2', '3'])

    Sequence views support indexing, slicing, and length queries. They act
    like the underlying sequence, except they don't allow assignment:

    >>> view[1]
    '1'
    >>> view[1:-1]
    ['1', '2']
    >>> len(view)
    4

    Sequence views are useful as an alternative to copying, as they don't
    require (much) extra storage.

    """

    def __init__(self, target):
        # Only genuine sequences can back a view.
        if not isinstance(target, Sequence):
            raise TypeError
        self._target = target

    def __getitem__(self, index):
        # Delegate indexing (including slices) to the underlying sequence.
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        return f'{type(self).__name__}({self._target!r})'
2889
2890
class seekable:
    """Wrap an iterator to allow for seeking backward and forward. This
    progressively caches the items in the source iterable so they can be
    re-visited.

    Call :meth:`seek` with an index to seek to that position in the source
    iterable.

    To "reset" an iterator, seek to ``0``:

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.seek(0)
    >>> next(it), next(it), next(it)
    ('0', '1', '2')

    You can also seek forward:

    >>> it = seekable((str(n) for n in range(20)))
    >>> it.seek(10)
    >>> next(it)
    '10'
    >>> it.seek(20)  # Seeking past the end of the source isn't a problem
    >>> list(it)
    []
    >>> it.seek(0)  # Resetting works even after hitting the end
    >>> next(it)
    '0'

    Call :meth:`relative_seek` to seek relative to the source iterator's
    current position.

    >>> it = seekable((str(n) for n in range(20)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.relative_seek(2)
    >>> next(it)
    '5'
    >>> it.relative_seek(-3)  # Source is at '6', we move back to '3'
    >>> next(it)
    '3'
    >>> it.relative_seek(-3)  # Source is at '4', we move back to '1'
    >>> next(it)
    '1'


    Call :meth:`peek` to look ahead one item without advancing the iterator:

    >>> it = seekable('1234')
    >>> it.peek()
    '1'
    >>> list(it)
    ['1', '2', '3', '4']
    >>> it.peek(default='empty')
    'empty'

    Before the iterator is at its end, calling :func:`bool` on it will return
    ``True``. After it will return ``False``:

    >>> it = seekable('5678')
    >>> bool(it)
    True
    >>> list(it)
    ['5', '6', '7', '8']
    >>> bool(it)
    False

    You may view the contents of the cache with the :meth:`elements` method.
    That returns a :class:`SequenceView`, a view that updates automatically:

    >>> it = seekable((str(n) for n in range(10)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> elements = it.elements()
    >>> elements
    SequenceView(['0', '1', '2'])
    >>> next(it)
    '3'
    >>> elements
    SequenceView(['0', '1', '2', '3'])

    By default, the cache grows as the source iterable progresses, so beware of
    wrapping very large or infinite iterables. Supply *maxlen* to limit the
    size of the cache (this of course limits how far back you can seek).

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()), maxlen=2)
    >>> next(it), next(it), next(it), next(it)
    ('0', '1', '2', '3')
    >>> list(it.elements())
    ['2', '3']
    >>> it.seek(0)
    >>> next(it), next(it), next(it), next(it)
    ('2', '3', '4', '5')
    >>> next(it)
    '6'

    """

    def __init__(self, iterable, maxlen=None):
        # _source: the wrapped iterator. _cache: every item seen so far
        # (bounded by maxlen when given). _index: position within the cache
        # while replaying, or None when reading fresh items from the source.
        self._source = iter(iterable)
        if maxlen is None:
            self._cache = []
        else:
            self._cache = deque([], maxlen)
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        # While replaying, serve items out of the cache; once the cache is
        # exhausted (IndexError), switch back to the live source.
        if self._index is not None:
            try:
                item = self._cache[self._index]
            except IndexError:
                self._index = None
            else:
                self._index += 1
                return item

        # Fresh item: pull from the source and remember it for later seeks.
        item = next(self._source)
        self._cache.append(item)
        return item

    def __bool__(self):
        # Truthy while at least one more item can be produced.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        # Advance one item, then step the replay index back so the next
        # call to __next__ returns the same item again.
        try:
            peeked = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return peeked

    def elements(self):
        # Expose the cache as a read-only, auto-updating view.
        return SequenceView(self._cache)

    def seek(self, index):
        # Jump to a cache position; if it is beyond the cache, consume
        # from the source until the cache reaches that position (or ends).
        self._index = index
        remainder = index - len(self._cache)
        if remainder > 0:
            consume(self, remainder)

    def relative_seek(self, count):
        # Seek relative to the current replay position, clamped at 0.
        # When not replaying, the current position is the end of the cache.
        if self._index is None:
            self._index = len(self._cache)

        self.seek(max(self._index + count, 0))
3050
3051
class run_length:
    """
    :func:`run_length.encode` compresses an iterable with run-length encoding.
    It yields groups of repeated items with the count of how many times they
    were repeated:

    >>> uncompressed = 'abbcccdddd'
    >>> list(run_length.encode(uncompressed))
    [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` decompresses an iterable that was previously
    compressed with run-length encoding. It yields the items of the
    decompressed iterable:

    >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> list(run_length.decode(compressed))
    ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        # Each run of equal items becomes a (value, run-length) pair.
        return (
            (value, sum(1 for _ in run)) for value, run in groupby(iterable)
        )

    @staticmethod
    def decode(iterable):
        # Expand each (value, count) pair back into `count` repetitions.
        return chain.from_iterable(starmap(repeat, iterable))
3079
3080
def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly ``n`` items in the iterable are ``True``
    according to the *predicate* function.

    >>> exactly_n([True, True, False], 2)
    True
    >>> exactly_n([True, True, False], 1)
    False
    >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
    True

    The iterable will be advanced until ``n + 1`` truthy items are encountered,
    so avoid calling it on infinite iterables.

    """
    # Count at most n + 1 matches: finding n + 1 proves "more than n"
    # without consuming the rest of the iterable.
    matches = islice(filter(predicate, iterable), n + 1)
    return sum(1 for _ in matches) == n
3097
3098
def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

    >>> list(circular_shifts(range(4)))
    [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    Set *steps* to the number of places to rotate to the left
    (or to the right if negative). Defaults to 1.

    >>> list(circular_shifts(range(4), 2))
    [(0, 1, 2, 3), (2, 3, 0, 1)]

    >>> list(circular_shifts(range(4), -1))
    [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    """
    window = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Pre-rotate backwards so the first rotation in the loop yields the
    # original ordering. deque.rotate(k) moves right, so negate for a
    # leftward shift.
    window.rotate(steps)
    steps = -steps

    # The sequence of shifts repeats after size / gcd(size, steps) rotations.
    size = len(window)
    num_shifts = size // math.gcd(size, steps)

    for _ in range(num_shifts):
        window.rotate(steps)
        yield tuple(window)
3127
3128
def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function that
    modifies an iterable. *result_index* is the position in that function's
    signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition. This can augment what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

    >>> from more_itertools import chunked
    >>> chunker = make_decorator(chunked, result_index=0)
    >>> @chunker(3)
    ... def iter_range(n):
    ...     return iter(range(n))
    ...
    >>> list(iter_range(9))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

    >>> truth_serum = make_decorator(filter, result_index=1)
    >>> @truth_serum(bool)
    ... def boolean_test():
    ...     return [0, 1, '', ' ', False, True]
    ...
    >>> list(boolean_test())
    [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators:

    >>> from more_itertools import peekable
    >>> peekable_function = make_decorator(peekable)
    >>> @peekable_function()
    ... def str_range(*args):
    ...     return (str(x) for x in range(*args))
    ...
    >>> it = str_range(1, 20, 2)
    >>> next(it), next(it), next(it)
    ('1', '3', '5')
    >>> it.peek()
    '7'
    >>> next(it)
    '7'

    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrap_args, **wrap_kwargs):
        def outer(func):
            def inner(*args, **kwargs):
                # Call the decorated function, then splice its result into
                # wrapping_func's positional arguments at result_index.
                call_args = list(wrap_args)
                call_args.insert(result_index, func(*args, **kwargs))
                return wrapping_func(*call_args, **wrap_kwargs)

            return inner

        return outer

    return decorator
3193
3194
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Return a dictionary that maps the items in *iterable* to categories
    defined by *keyfunc*, transforms them with *valuefunc*, and
    then summarizes them by category with *reducefunc*.

    *valuefunc* defaults to the identity function if it is unspecified.
    If *reducefunc* is unspecified, no summarization takes place:

    >>> keyfunc = lambda x: x.upper()
    >>> result = map_reduce('abbccc', keyfunc)
    >>> sorted(result.items())
    [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Specifying *valuefunc* transforms the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> result = map_reduce('abbccc', keyfunc, valuefunc)
    >>> sorted(result.items())
    [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Specifying *reducefunc* summarizes the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> reducefunc = sum
    >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
    >>> sorted(result.items())
    [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the map/reduce
    procedure:

    >>> all_items = range(30)
    >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
    >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
    >>> categories = map_reduce(items, keyfunc=keyfunc)
    >>> sorted(categories.items())
    [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
    >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
    >>> sorted(summaries.items())
    [(0, 90), (1, 75)]

    Note that all items in the iterable are gathered into a list before the
    summarization step, which may require significant storage.

    The returned object is a :obj:`collections.defaultdict` with the
    ``default_factory`` set to ``None``, such that it behaves like a normal
    dictionary.

    """
    categories = defaultdict(list)

    # Map phase: bucket each (optionally transformed) item by its key.
    for item in iterable:
        value = item if valuefunc is None else valuefunc(item)
        categories[keyfunc(item)].append(value)

    # Reduce phase: collapse each bucket into a summary value.
    if reducefunc is not None:
        for key, bucket in categories.items():
            categories[key] = reducefunc(bucket)

    # Disable defaulting so missing keys raise KeyError like a plain dict.
    categories.default_factory = None
    return categories
3266
3267
def rlocate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``, starting from the right and moving left.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
    [4, 2, 1]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item:

    >>> iterator = iter('abcb')
    >>> pred = lambda x: x == 'b'
    >>> list(rlocate(iterator, pred))
    [3, 1]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(rlocate(iterable, pred=pred, window_size=3))
    [9, 5, 1]

    Beware, this function won't return anything for infinite iterables.
    If *iterable* is reversible, ``rlocate`` will reverse it and search from
    the right. Otherwise, it will search from the left and return the results
    in reverse order.

    See :func:`locate` for other example applications.

    """
    if window_size is None:
        # Fast path: a sized, reversible iterable can be scanned from the
        # right and the indexes mapped back to left-based positions.
        try:
            length = len(iterable)
        except TypeError:
            pass
        else:
            return (length - 1 - i for i in locate(reversed(iterable), pred))

    # Fallback: search left-to-right, then reverse the collected results.
    return reversed(list(locate(iterable, pred, window_size)))
3309
3310
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items from *iterable*, replacing the items for which *pred*
    returns ``True`` with the items from the iterable *substitutes*.

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
    >>> pred = lambda x: x == 0
    >>> substitutes = (2, 3)
    >>> list(replace(iterable, pred, substitutes))
    [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    If *count* is given, the number of replacements will be limited:

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
    >>> pred = lambda x: x == 0
    >>> substitutes = [None]
    >>> list(replace(iterable, pred, substitutes, count=2))
    [1, 1, None, 1, 1, None, 1, 1, 0]

    Use *window_size* to control the number of items passed as arguments to
    *pred*. This allows for locating and replacing subsequences.

    >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
    >>> window_size = 3
    >>> pred = lambda *args: args == (0, 1, 2) # 3 items passed to pred
    >>> substitutes = [3, 4] # Splice in these items
    >>> list(replace(iterable, pred, substitutes, window_size=window_size))
    [3, 4, 5, 3, 4, 5]

    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # The substitutes are emitted once per match, so they must be reusable.
    substitutes = tuple(substitutes)

    # Pad the tail with sentinels so every item starts exactly one window.
    padded = chain(iterable, repeat(_marker, window_size - 1))
    windows = windowed(padded, window_size)

    replacements = 0
    for window in windows:
        # On a match (within the replacement budget), emit the substitutes
        # and discard the overlapping windows. E.g. for (0, 1, 2, 3, ...)
        # with window_size 2, a match on (0, 1) must also consume (1, 2).
        if pred(*window) and (count is None or replacements < count):
            replacements += 1
            yield from substitutes
            consume(windows, window_size - 1)
        # Otherwise pass through the window's first item, unless it is
        # tail padding.
        elif window and window[0] is not _marker:
            yield window[0]
3370
3371
def partitions(iterable):
    """Yield all possible order-preserving partitions of *iterable*.

    >>> iterable = 'abc'
    >>> for part in partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['a', 'b', 'c']

    This is unrelated to :func:`partition`.

    """
    seq = list(iterable)
    size = len(seq)
    # Every subset of the interior cut points (1..size-1) defines one
    # partition; slice between consecutive boundaries.
    for cuts in powerset(range(1, size)):
        boundaries = (0,) + cuts + (size,)
        yield [seq[a:b] for a, b in zip(boundaries, boundaries[1:])]
3390
3391
def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """
    Yield the set partitions of *iterable* into *k* parts. Set partitions are
    not order-preserving.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, 2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']


    If *k* is not given, every set partition is generated.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    if *min_size* and/or *max_size* are given, the minimum and/or maximum size
    per block in partition is set.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, min_size=2):
    ...     print([''.join(p) for p in part])
    ['abc']
    >>> for part in set_partitions(iterable, max_size=2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    """
    items = list(iterable)
    n = len(items)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        elif k > n:
            return

    min_size = 0 if min_size is None else min_size
    max_size = n if max_size is None else max_size
    if min_size > max_size:
        return

    def helper(lst, parts):
        # Recursively partition *lst* into exactly *parts* blocks.
        if parts == 1:
            yield [lst]
        elif len(lst) == parts:
            yield [[item] for item in lst]
        else:
            head, *rest = lst
            # Either the head is a singleton block...
            for partition in helper(rest, parts - 1):
                yield [[head], *partition]
            # ...or it joins one of the blocks of a smaller partition.
            for partition in helper(rest, parts):
                for idx in range(len(partition)):
                    yield (
                        partition[:idx]
                        + [[head] + partition[idx]]
                        + partition[idx + 1 :]
                    )

    def size_ok(partition):
        return all(min_size <= len(block) <= max_size for block in partition)

    part_counts = range(1, n + 1) if k is None else [k]
    for parts in part_counts:
        yield from filter(size_ok, helper(items, parts))
3471
3472
class time_limited:
    """
    Yield items from *iterable* until *limit_seconds* have passed.
    If the time limit expires before all items have been yielded, the
    ``timed_out`` parameter will be set to ``True``.

    >>> from time import sleep
    >>> def generator():
    ...     yield 1
    ...     yield 2
    ...     sleep(0.2)
    ...     yield 3
    >>> iterable = time_limited(0.1, generator())
    >>> list(iterable)
    [1, 2]
    >>> iterable.timed_out
    True

    Note that the time is checked before each item is yielded, and iteration
    stops if the time elapsed is greater than *limit_seconds*. If your time
    limit is 1 second, but it takes 2 seconds to generate the first item from
    the iterable, the function will run for 2 seconds and not yield anything.
    As a special case, when *limit_seconds* is zero, the iterator never
    returns anything.

    """

    def __init__(self, limit_seconds, iterable):
        # Zero is allowed (it means "yield nothing"), so only reject
        # negative limits. The previous message said "positive", which
        # contradicted both this check and the documented zero behavior.
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be non-negative')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        # The clock starts at construction time, not at first iteration.
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        # Special case: a zero-second limit never yields anything.
        if self.limit_seconds == 0:
            self.timed_out = True
            raise StopIteration
        # Pull the next item first; the elapsed-time check happens after,
        # so a slow producer can overrun the limit by one item's time.
        item = next(self._iterator)
        if monotonic() - self._start_time > self.limit_seconds:
            self.timed_out = True
            raise StopIteration

        return item
3521
3522
def only(iterable, default=None, too_long=None):
    """If *iterable* has only one item, return it.
    If it has zero items, return *default*.
    If it has more than one item, raise the exception given by *too_long*,
    which is ``ValueError`` by default.

    >>> only([], default='missing')
    'missing'
    >>> only([1])
    1
    >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 1, 2,
     and perhaps more.'
    >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    TypeError

    Note that :func:`only` attempts to advance *iterable* twice to ensure there
    is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    iterator = iter(iterable)

    # Zero items: fall back to the default.
    try:
        value = next(iterator)
    except StopIteration:
        return value if False else default

    # Exactly one item: return it.
    try:
        extra = next(iterator)
    except StopIteration:
        return value

    # Two or more items: complain.
    msg = (
        f'Expected exactly one item in iterable, but got {value!r}, '
        f'{extra!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)
3558
3559
3560def _ichunk(iterator, n):
3561 cache = deque()
3562 chunk = islice(iterator, n)
3563
3564 def generator():
3565 with suppress(StopIteration):
3566 while True:
3567 if cache:
3568 yield cache.popleft()
3569 else:
3570 yield next(chunk)
3571
3572 def materialize_next(n=1):
3573 # if n not specified materialize everything
3574 if n is None:
3575 cache.extend(chunk)
3576 return len(cache)
3577
3578 to_cache = n - len(cache)
3579
3580 # materialize up to n
3581 if to_cache > 0:
3582 cache.extend(islice(chunk, to_cache))
3583
3584 # return number materialized up to n
3585 return min(n, len(cache))
3586
3587 return (generator(), materialize_next)
3588
3589
def ichunked(iterable, n):
    """Break *iterable* into sub-iterables with *n* elements each.
    :func:`ichunked` is like :func:`chunked`, but it yields iterables
    instead of lists.

    If the sub-iterables are read in order, the elements of *iterable*
    won't be stored in memory.
    If they are read out of order, :func:`itertools.tee` is used to cache
    elements as necessary.

    >>> from itertools import count
    >>> all_chunks = ichunked(count(), 4)
    >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
    >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
    [4, 5, 6, 7]
    >>> list(c_1)
    [0, 1, 2, 3]
    >>> list(c_3)
    [8, 9, 10, 11]

    """
    source = iter(iterable)
    while True:
        chunk, materialize_next = _ichunk(source, n)

        # An empty chunk means the source is exhausted.
        if not materialize_next():
            return

        yield chunk

        # Buffer whatever the consumer didn't read so the next chunk
        # starts at the right position in the source.
        materialize_next(None)
3624
3625
def iequals(*iterables):
    """Return ``True`` if all given *iterables* are equal to each other,
    which means that they contain the same elements in the same order.

    The function is useful for comparing iterables of different data types
    or iterables that do not support equality checks.

    >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
    True

    >>> iequals("abc", "acb")
    False

    Not to be confused with :func:`all_equal`, which checks whether all
    elements of iterable are equal to each other.

    """
    # A fresh sentinel marks exhausted iterables so length mismatches
    # compare unequal.
    sentinel = object()
    zipped = zip_longest(*iterables, fillvalue=sentinel)
    return all(all_equal(group) for group in zipped)
3644
3645
def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

    >>> list(distinct_combinations([0, 0, 1], 2))
    [(0, 0), (0, 1)]

    Equivalent to ``set(combinations(iterable))``, except duplicates are not
    generated and thrown away. For larger input sequences this is much more
    efficient.

    """
    if r < 0:
        raise ValueError('r must be non-negative')
    elif r == 0:
        # The only 0-length combination is the empty tuple
        yield ()
        return
    pool = tuple(iterable)
    # Depth-first search with explicit backtracking. generators[i] yields
    # the candidate (index, item) pairs for position i of the combination,
    # de-duplicated by item so each distinct value is tried only once.
    generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
    current_combo = [None] * r
    level = 0
    while generators:
        try:
            cur_idx, p = next(generators[-1])
        except StopIteration:
            # No more distinct candidates at this position: backtrack
            generators.pop()
            level -= 1
            continue
        current_combo[level] = p
        if level + 1 == r:
            # Combination is complete; emit a copy
            yield tuple(current_combo)
        else:
            # Descend: candidates for the next position are the items
            # after cur_idx, again de-duplicated by value
            generators.append(
                unique_everseen(
                    enumerate(pool[cur_idx + 1 :], cur_idx + 1),
                    key=itemgetter(1),
                )
            )
            level += 1
3684
3685
def filter_except(validator, iterable, *exceptions):
    """Yield the items from *iterable* for which the *validator* function does
    not raise one of the specified *exceptions*.

    *validator* is called for each item in *iterable*.
    It should be a function that accepts one argument and raises an exception
    if that item is not valid.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(filter_except(int, iterable, ValueError, TypeError))
    ['1', '2', '4']

    If an exception other than one given by *exceptions* is raised by
    *validator*, it is raised like normal.
    """
    for item in iterable:
        # An item is kept only if the validator runs without raising one
        # of the listed exceptions; other exceptions propagate.
        try:
            validator(item)
        except exceptions:
            continue
        yield item
3708
3709
def map_except(function, iterable, *exceptions):
    """Transform each item from *iterable* with *function* and yield the
    result, unless *function* raises one of the specified *exceptions*.

    *function* is called to transform each item in *iterable*.
    It should accept one argument.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(map_except(int, iterable, ValueError, TypeError))
    [1, 2, 4]

    If an exception other than one given by *exceptions* is raised by
    *function*, it is raised like normal.
    """
    for item in iterable:
        # Items whose transformation raises a listed exception are
        # silently dropped; other exceptions propagate.
        with suppress(*exceptions):
            yield function(item)
3729
3730
def map_if(iterable, pred, func, func_else=None):
    """Evaluate each item from *iterable* using *pred*. If the result is
    equivalent to ``True``, transform the item with *func* and yield it.
    Otherwise, transform the item with *func_else* and yield it.

    *pred*, *func*, and *func_else* should each be functions that accept
    one argument. By default, *func_else* is the identity function.

    >>> from math import sqrt
    >>> iterable = list(range(-5, 5))
    >>> iterable
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
    >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
    >>> list(map_if(iterable, lambda x: x >= 0,
    ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
    [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
    """
    # Default the "else" branch to the identity function.
    if func_else is None:
        func_else = lambda item: item

    for item in iterable:
        yield func(item) if pred(item) else func_else(item)
3757
3758
3759def _sample_unweighted(iterator, k, strict):
3760 # Algorithm L in the 1994 paper by Kim-Hung Li:
3761 # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
3762
3763 reservoir = list(islice(iterator, k))
3764 if strict and len(reservoir) < k:
3765 raise ValueError('Sample larger than population')
3766 W = 1.0
3767
3768 with suppress(StopIteration):
3769 while True:
3770 W *= random() ** (1 / k)
3771 skip = floor(log(random()) / log1p(-W))
3772 element = next(islice(iterator, skip, None))
3773 reservoir[randrange(k)] = element
3774
3775 shuffle(reservoir)
3776 return reservoir
3777
3778
def _sample_weighted(iterator, k, weights, strict):
    # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
    # "Weighted random sampling with a reservoir".

    # Log-transform for numerical stability for weights that are small/large
    weight_keys = (log(random()) / weight for weight in weights)

    # Fill up the reservoir (collection of samples) with the first `k`
    # weight-keys and elements, then heapify the list.
    reservoir = take(k, zip(weight_keys, iterator))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # After heapify, the entry with the smallest weight-key is reservoir[0].
    heapify(reservoir)

    # The number of jumps before changing the reservoir is a random variable
    # with an exponential distribution. Sample it using random() and logs.
    smallest_weight_key, _ = reservoir[0]
    weights_to_skip = log(random()) / smallest_weight_key

    for weight, element in zip(weights, iterator):
        if weight >= weights_to_skip:
            # The notation here is consistent with the paper, but we store
            # the weight-keys in log-space for better numerical stability.
            smallest_weight_key, _ = reservoir[0]
            t_w = exp(weight * smallest_weight_key)
            r_2 = uniform(t_w, 1)  # generate U(t_w, 1)
            weight_key = log(r_2) / weight
            # Evict the minimum-key entry and insert the new element.
            heapreplace(reservoir, (weight_key, element))
            # Re-sample the skip distance from the new minimum key.
            smallest_weight_key, _ = reservoir[0]
            weights_to_skip = log(random()) / smallest_weight_key
        else:
            weights_to_skip -= weight

    # Return the elements only, shuffled so output order doesn't expose
    # the heap's internal ordering.
    ret = [element for weight_key, element in reservoir]
    shuffle(ret)
    return ret
3816
3817
def _sample_counted(population, k, counts, strict):
    # Reservoir sampling over a population where counts[i] gives the number
    # of repeats of population[i]; the repeats are never materialized.
    # `element`/`remaining` track the current item and its unconsumed count.
    element = None
    remaining = 0

    def feed(i):
        # Advance *i* steps ahead and consume an element
        nonlocal element, remaining

        while i + 1 > remaining:
            # Current element's repeats are exhausted; pull the next
            # element and its count.
            i = i - remaining
            element = next(population)
            remaining = next(counts)
        remaining -= i + 1
        return element

    # Seed the reservoir with the first k virtual elements.
    with suppress(StopIteration):
        reservoir = []
        for _ in range(k):
            reservoir.append(feed(0))

    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Algorithm L skip-ahead (same formulas as _sample_unweighted), with
    # the skips routed through feed() to honor the counts.
    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= random() ** (1 / k)
            skip = floor(log(random()) / log1p(-W))
            element = feed(skip)
            reservoir[randrange(k)] = element

    shuffle(reservoir)
    return reservoir
3851
3852
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a *k*-length list of elements chosen (without replacement)
    from the *iterable*. Similar to :func:`random.sample`, but works on
    iterables of unknown length.

    >>> iterable = range(100)
    >>> sample(iterable, 5)  # doctest: +SKIP
    [81, 60, 96, 16, 4]

    For iterables with repeated elements, you may supply *counts* to
    indicate the repeats.

    >>> iterable = ['a', 'b']
    >>> counts = [3, 4]  # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
    >>> sample(iterable, k=3, counts=counts)  # doctest: +SKIP
    ['a', 'a', 'b']

    An iterable with *weights* may be given:

    >>> iterable = range(100)
    >>> weights = (i * i + 1 for i in range(100))
    >>> sampled = sample(iterable, 5, weights=weights)  # doctest: +SKIP
    [79, 67, 74, 66, 78]

    Weighted selections are made without replacement.
    After an element is selected, it is removed from the pool and the
    relative weights of the other elements increase (this
    does not match the behavior of :func:`random.sample`'s *counts*
    parameter). Note that *weights* may not be used with *counts*.

    If the length of *iterable* is less than *k*,
    ``ValueError`` is raised if *strict* is ``True`` and
    all elements are returned (in shuffled order) if *strict* is ``False``.

    By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
    technique is used. When *weights* are provided,
    `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used.
    """
    iterator = iter(iterable)

    # Validate k before dispatching.
    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []

    # Dispatch to the appropriate reservoir-sampling strategy.
    if weights is not None:
        if counts is not None:
            raise TypeError('weights and counts are mutually exclusive')
        return _sample_weighted(iterator, k, iter(weights), strict)

    if counts is not None:
        return _sample_counted(iterator, k, iter(counts), strict)

    return _sample_unweighted(iterator, k, strict)
3912
3913
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Returns ``True`` if the items of iterable are in sorted order, and
    ``False`` otherwise. *key* and *reverse* have the same meaning that they do
    in the built-in :func:`sorted` function.

    >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
    True
    >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
    False

    If *strict*, tests for strict sorting, that is, returns ``False`` if equal
    elements are found:

    >>> is_sorted([1, 2, 2])
    True
    >>> is_sorted([1, 2, 2], strict=True)
    False

    The function returns ``False`` after encountering the first out-of-order
    item, which means it may produce results that differ from the built-in
    :func:`sorted` function for objects with unusual comparison dynamics
    (like ``math.nan``). If there are no out-of-order items, the iterable is
    exhausted.
    """
    # Walk the (possibly key-transformed) items pairwise.
    keyed = iterable if key is None else map(key, iterable)
    current, ahead = tee(keyed)
    next(ahead, None)
    if reverse:
        current, ahead = ahead, current
    if strict:
        # Every adjacent pair must be strictly increasing.
        return all(map(lt, current, ahead))
    # No adjacent pair may be decreasing.
    return not any(map(lt, ahead, current))
3944
3945
class AbortThread(BaseException):
    """Raised inside the background thread of :class:`callback_iter` when
    its ``with`` block is exited before the wrapped function finishes.

    Derives from ``BaseException`` (not ``Exception``) so that the wrapped
    function's own ``except Exception`` handlers cannot swallow it.
    """

    pass
3948
3949
class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

    >>> def func(callback=None):
    ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
    ...         if callback:
    ...             callback(i, c)
    ...     return 4


    Use ``with callback_iter(func)`` to get an iterator over the parameters
    that are delivered to the callback.

    >>> with callback_iter(func) as it:
    ...     for args, kwargs in it:
    ...         print(args)
    (1, 'a')
    (2, 'b')
    (3, 'c')

    The function will be called in a background thread. The ``done`` property
    indicates whether it has completed execution.

    >>> it.done
    True

    If it completes successfully, its return value will be available
    in the ``result`` property.

    >>> it.result
    4

    Notes:

    * If the function uses some keyword argument besides ``callback``, supply
      *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently the it is polled for
      output.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        # The wrapped function; it runs on a single background worker.
        self._func = func
        # Name of the keyword argument used to pass the callback to func.
        self._callback_kwd = callback_kwd
        # Set on __exit__; tells the callback to raise AbortThread.
        self._aborted = False
        # Future for the background call; created lazily in _reader.
        self._future = None
        # Polling interval for the queue of callback invocations.
        self._wait_seconds = wait_seconds
        # Lazily import concurrent.future
        self._executor = __import__(
            'concurrent.futures'
        ).futures.ThreadPoolExecutor(max_workers=1)
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Flag the callback to abort, then wait for the worker to finish.
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    @property
    def done(self):
        # False until the background call has been submitted and finished.
        if self._future is None:
            return False
        return self._future.done()

    @property
    def result(self):
        if not self.done:
            raise RuntimeError('Function has not yet completed')

        # Re-raises the function's exception, if it raised one.
        return self._future.result()

    def _reader(self):
        # Callback invocations flow from the worker thread through q.
        q = Queue()

        def callback(*args, **kwargs):
            # Runs on the worker thread. Raising here unwinds func after
            # the user leaves the `with` block.
            if self._aborted:
                raise AbortThread('canceled by user')

            q.put((args, kwargs))

        self._future = self._executor.submit(
            self._func, **{self._callback_kwd: callback}
        )

        # Poll the queue, yielding each callback's (args, kwargs), until
        # the background call completes.
        while True:
            try:
                item = q.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                q.task_done()
                yield item

            if self._future.done():
                break

        # Drain anything enqueued between the last get and completion.
        remaining = []
        while True:
            try:
                item = q.get_nowait()
            except Empty:
                break
            else:
                q.task_done()
                remaining.append(item)
        q.join()
        yield from remaining
4075
4076
def windowed_complete(iterable, n):
    """
    Yield ``(beginning, middle, end)`` tuples, where:

    * Each ``middle`` has *n* items from *iterable*
    * Each ``beginning`` has the items before the ones in ``middle``
    * Each ``end`` has the items after the ones in ``middle``

    >>> iterable = range(7)
    >>> n = 3
    >>> for beginning, middle, end in windowed_complete(iterable, n):
    ...     print(beginning, middle, end)
    () (0, 1, 2) (3, 4, 5, 6)
    (0,) (1, 2, 3) (4, 5, 6)
    (0, 1) (2, 3, 4) (5, 6)
    (0, 1, 2) (3, 4, 5) (6,)
    (0, 1, 2, 3) (4, 5, 6) ()

    Note that *n* must be at least 0 and most equal to the length of
    *iterable*.

    This function will exhaust the iterable and may require significant
    storage.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    seq = tuple(iterable)
    size = len(seq)

    if n > size:
        raise ValueError('n must be <= len(seq)')

    # Slide the n-item middle window across the sequence, slicing out
    # the prefix and suffix around it.
    for start in range(size - n + 1):
        stop = start + n
        yield seq[:start], seq[start:stop], seq[stop:]
4115
4116
def all_unique(iterable, key=None):
    """
    Returns ``True`` if all the elements of *iterable* are unique (no two
    elements are equal).

    >>> all_unique('ABCB')
    False

    If a *key* function is specified, it will be used to make comparisons.

    >>> all_unique('ABCb')
    True
    >>> all_unique('ABCb', str.lower)
    False

    The function returns as soon as the first non-unique element is
    encountered. Iterables with a mix of hashable and unhashable items can
    be used, but the function will be slower for unhashable items.
    """
    # Hashable elements go in a set (O(1) membership); unhashable ones
    # fall back to a list (O(n) membership).
    hashables = set()
    unhashables = []
    for element in map(key, iterable) if key else iterable:
        try:
            if element in hashables:
                return False
            hashables.add(element)
        except TypeError:
            if element in unhashables:
                return False
            unhashables.append(element)
    return True
4150
4151
def nth_product(index, *args):
    """Equivalent to ``list(product(*args))[index]``.

    The products of *args* can be ordered lexicographically.
    :func:`nth_product` computes the product at sort position *index* without
    computing the previous products.

    >>> nth_product(8, range(2), range(2), range(2), range(2))
    (1, 0, 0, 0)

    ``IndexError`` will be raised if the given *index* is invalid.
    """
    # Work from the least-significant (last) pool to the most-significant.
    pools = [tuple(pool) for pool in reversed(args)]
    sizes = [len(pool) for pool in pools]

    total = reduce(mul, sizes)

    # Support negative indexing, as lists do.
    if index < 0:
        index += total

    if not 0 <= index < total:
        raise IndexError

    # Decompose index in the mixed-radix system defined by the pool sizes.
    result = []
    for pool, size in zip(pools, sizes):
        index, digit = divmod(index, size)
        result.append(pool[digit])

    return tuple(reversed(result))
4181
4182
def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]```

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`nth_permutation`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences.

    >>> nth_permutation('ghijk', 2, 5)
    ('h', 'i')

    ``ValueError`` will be raised If *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = list(iterable)
    n = len(pool)

    # c is the total number of length-r permutations.
    if r is None or r == n:
        r, c = n, factorial(n)
    elif not 0 <= r < n:
        raise ValueError
    else:
        c = perm(n, r)
        assert c > 0  # factorial(n)>0, and r<n so perm(n,r) is never zero

    # Support negative indexing, as lists do.
    if index < 0:
        index += c

    if not 0 <= index < c:
        raise IndexError

    result = [0] * r
    # Scale index into full-permutation space, then decompose it in the
    # factorial number system: the loop peels off digits base 1, 2, ..., n.
    q = index * factorial(n) // c if r < n else index
    for d in range(1, n + 1):
        q, i = divmod(q, d)
        if 0 <= n - d < r:
            result[n - d] = i
        if q == 0:
            break

    # map() calls pool.pop lazily while the tuple is built, so each
    # factoradic digit indexes into the elements not yet chosen.
    return tuple(map(pool.pop, result))
4225
4226
def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.


    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`nth_combination_with_replacement`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences with replacement.

    >>> nth_combination_with_replacement(range(5), 3, 5)
    (0, 1, 1)

    ``ValueError`` will be raised If *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if (r < 0) or (r > n):
        raise ValueError

    # Total number of size-r multisets drawn from n items.
    c = comb(n + r - 1, r)

    # Support negative indexing, as lists do.
    if index < 0:
        index += c

    if (index < 0) or (index >= c):
        raise IndexError

    result = []
    i = 0
    while r:
        r -= 1
        # Choose the element at position i for this slot: skip whole
        # blocks of combinations that start with earlier elements,
        # decrementing n (remaining choices) as we pass each one.
        while n >= 0:
            num_combs = comb(n + r - 1, r)
            if index < num_combs:
                break
            n -= 1
            i += 1
            index -= num_combs
        result.append(pool[i])

    return tuple(result)
4271
4272
def value_chain(*args):
    """Yield all arguments passed to the function in the same order in which
    they were passed. If an argument itself is iterable then iterate over its
    values.

    >>> list(value_chain(1, 2, 3, [4, 5, 6]))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and are emitted
    as-is:

    >>> list(value_chain('12', '34', ['56', '78']))
    ['12', '34', '56', '78']

    Pre- or postpend a single element to an iterable:

    >>> list(value_chain(1, [2, 3, 4, 5, 6]))
    [1, 2, 3, 4, 5, 6]
    >>> list(value_chain([1, 2, 3, 4, 5], 6))
    [1, 2, 3, 4, 5, 6]

    Multiple levels of nesting are not flattened.

    """
    for arg in args:
        # Strings are technically iterable but are treated as atoms.
        if isinstance(arg, (str, bytes)):
            yield arg
        else:
            # Non-iterable arguments raise TypeError and are emitted as-is.
            try:
                yield from arg
            except TypeError:
                yield arg
4305
4306
def product_index(element, *args):
    """Equivalent to ``list(product(*args)).index(element)``

    The products of *args* can be ordered lexicographically.
    :func:`product_index` computes the first index of *element* without
    computing the previous products.

    >>> product_index([8, 2], range(10), range(5))
    42

    ``ValueError`` will be raised if the given *element* isn't in the product
    of *args*.
    """
    missing = object()
    index = 0

    # element and args must have the same length; a sentinel from
    # zip_longest reveals a mismatch.
    for item, pool in zip_longest(element, args, fillvalue=missing):
        if item is missing or pool is missing:
            raise ValueError('element is not a product of args')

        pool = tuple(pool)
        # Mixed-radix positional arithmetic: each pool is one "digit".
        index = index * len(pool) + pool.index(item)

    return index
4330
4331
def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`combination_index` computes the index of the
    first *element*, without computing the previous combinations.

    >>> combination_index('adf', 'abcdefg')
    10

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations of *iterable*.
    """
    # k tracks the position within element of the item being searched for;
    # y is the item itself.
    element = enumerate(element)
    k, y = next(element, (None, None))
    if k is None:
        # An empty element is the first (index 0) combination.
        return 0

    # Record, for each item of element, the index in iterable where it
    # matched.
    indexes = []
    pool = enumerate(iterable)
    for n, x in pool:
        if x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                break
            else:
                k = tmp
    else:
        # iterable was exhausted before every item of element matched.
        raise ValueError('element is not a combination of iterable')

    # Exhaust the rest of iterable so n ends up as its last index.
    n, _ = last(pool, default=(n, None))

    # Count the combinations at or after element's position; subtracting
    # from the total comb(n + 1, k + 1) gives element's index.
    index = 1
    for i, j in enumerate(reversed(indexes), start=1):
        j = n - j
        if i <= j:
            index += comb(j, i)

    return comb(n + 1, k + 1) - index
4373
4374
def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``

    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`combination_with_replacement_index`
    computes the index of the first *element*, without computing the previous
    combinations with replacement.

    >>> combination_with_replacement_index('adf', 'abcdefg')
    20

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations with replacement of *iterable*.
    """
    # k tracks the position within element of the item being searched for;
    # y is the item itself.
    element = tuple(element)
    l = len(element)
    element = enumerate(element)

    k, y = next(element, (None, None))
    if k is None:
        # An empty element is the first (index 0) combination.
        return 0

    # For each item of element, record the index in pool where it matched.
    # A `while` is used because, with replacement, consecutive items of
    # element may match the same pool entry.
    indexes = []
    pool = tuple(iterable)
    for n, x in enumerate(pool):
        while x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                break
            else:
                k = tmp
        if y is None:
            break
    else:
        # pool was exhausted before every item of element matched.
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    # occupations[p] is how many times pool[p] appears in element.
    n = len(pool)
    occupations = [0] * n
    for p in indexes:
        occupations[p] += 1

    # Accumulate the number of combinations that precede element in
    # lexicographic order (a stars-and-bars style count).
    index = 0
    cumulative_sum = 0
    for k in range(1, n):
        cumulative_sum += occupations[k - 1]
        j = l + n - 1 - k - cumulative_sum
        i = n - k
        if i <= j:
            index += comb(j, i)

    return index
4430
4431
def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)```

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`permutation_index`
    computes the index of the first *element* directly, without computing
    the previous permutations.

    >>> permutation_index([1, 3, 2], range(5))
    19

    ``ValueError`` will be raised if the given *element* isn't one of the
    permutations of *iterable*.
    """
    remaining = list(iterable)
    size = len(remaining)
    index = 0

    # Factorial-number-system arithmetic: each chosen item contributes its
    # position among the items still available, then is removed.
    for item in element:
        position = remaining.index(item)
        index = index * size + position
        del remaining[position]
        size -= 1

    return index
4454
4455
class countable:
    """Wrap *iterable* and keep a count of how many items have been consumed.

    The ``items_seen`` attribute starts at ``0`` and increments as the iterable
    is consumed:

    >>> iterable = map(str, range(10))
    >>> it = countable(iterable)
    >>> it.items_seen
    0
    >>> next(it), next(it)
    ('0', '1')
    >>> list(it)
    ['2', '3', '4', '5', '6', '7', '8', '9']
    >>> it.items_seen
    10
    """

    def __init__(self, iterable):
        # Count is public; the wrapped iterator is an implementation detail.
        self.items_seen = 0
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        result = next(self._iterator)
        self.items_seen += 1
        return result
4486
4487
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.
    Items are distributed such that the lengths of the lists differ by at
    most 1 item.

    >>> iterable = [1, 2, 3, 4, 5, 6, 7]
    >>> n = 3
    >>> list(chunked_even(iterable, n)) # List lengths: 3, 2, 2
    [[1, 2, 3], [4, 5], [6, 7]]
    >>> list(chunked(iterable, n)) # List lengths: 3, 3, 1
    [[1, 2, 3], [4, 5, 6], [7]]

    """
    iterator = iter(iterable)

    # Initialize a buffer to process the chunks while keeping
    # some back to fill any underfilled chunks
    min_buffer = (n - 1) * (n - 2)
    buffer = list(islice(iterator, min_buffer))

    # Append items until we have a completed chunk. The islice step of n
    # fires once per n items appended, so each iteration of this loop
    # corresponds to one full chunk at the front of the buffer.
    for _ in islice(map(buffer.append, iterator), n, None, n):
        yield buffer[:n]
        del buffer[:n]

    # Check if any chunks need additional processing
    if not buffer:
        return
    length = len(buffer)

    # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
    q, r = divmod(length, n)
    num_lists = q + (1 if r > 0 else 0)
    q, r = divmod(length, num_lists)
    full_size = q + (1 if r > 0 else 0)
    partial_size = full_size - 1
    num_full = length - partial_size * num_lists

    # Yield chunks of full size
    partial_start_idx = num_full * full_size
    if full_size > 0:
        for i in range(0, partial_start_idx, full_size):
            yield buffer[i : i + full_size]

    # Yield chunks of partial size
    if partial_size > 0:
        for i in range(partial_start_idx, length, partial_size):
            yield buffer[i : i + partial_size]
4536
4537
def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
    """A version of :func:`zip` that "broadcasts" any scalar
    (i.e., non-iterable) items into output tuples.

    >>> iterable_1 = [1, 2, 3]
    >>> iterable_2 = ['a', 'b', 'c']
    >>> scalar = '_'
    >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
    [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]

    The *scalar_types* keyword argument determines what types are considered
    scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
    treat strings and byte strings as iterable:

    >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
    [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]

    If the *strict* keyword argument is ``True``, then
    ``UnequalIterablesError`` will be raised if any of the iterables have
    different lengths.
    """

    def _is_scalar(obj):
        # Configured scalar types are treated as scalars even if iterable.
        if scalar_types and isinstance(obj, scalar_types):
            return True
        try:
            iter(obj)
        except TypeError:
            return True
        return False

    if not objects:
        return

    # template holds the scalar values; iterable slots are filled in per row.
    template = [None] * len(objects)
    iterators = []
    positions = []
    for i, obj in enumerate(objects):
        if _is_scalar(obj):
            template[i] = obj
        else:
            iterators.append(iter(obj))
            positions.append(i)

    # All scalars: emit the single broadcast tuple.
    if not iterators:
        yield tuple(objects)
        return

    zipper = _zip_equal if strict else zip
    for values in zipper(*iterators):
        for i, value in zip(positions, values):
            template[i] = value
        yield tuple(template)
4592
4593
def unique_in_window(iterable, n, key=None):
    """Yield the items from *iterable* that haven't been seen recently.
    *n* is the size of the lookback window.

    >>> iterable = [0, 1, 0, 2, 3, 0]
    >>> n = 3
    >>> list(unique_in_window(iterable, n))
    [0, 1, 2, 3, 0]

    The *key* function, if provided, will be used to determine uniqueness:

    >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
    ['a', 'b', 'c', 'd', 'a']

    The items in *iterable* must be hashable.

    """
    if n <= 0:
        raise ValueError('n must be greater than 0')

    # window tracks the keys of the last n items; counts tracks how many
    # times each key currently appears in the window.
    window = deque(maxlen=n)
    counts = defaultdict(int)

    for item in iterable:
        # Evict the oldest key once the window is full.
        if len(window) == n:
            oldest = window[0]
            counts[oldest] -= 1
            if not counts[oldest]:
                del counts[oldest]

        k = item if key is None else key(item)
        if k not in counts:
            yield item
        counts[k] += 1
        window.append(k)
4631
4632
def duplicates_everseen(iterable, key=None):
    """Yield duplicate elements after their first appearance.

    >>> list(duplicates_everseen('mississippi'))
    ['s', 'i', 's', 's', 'i', 'p', 'i']
    >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []

    for element in iterable:
        k = element if key is None else key(element)
        try:
            # Fast path: hashable keys go in a set.
            first_time = k not in hashable_seen
            if first_time:
                hashable_seen.add(k)
        except TypeError:
            # Slow path: unhashable keys fall back to a list scan.
            first_time = k not in unhashable_seen
            if first_time:
                unhashable_seen.append(k)
        if not first_time:
            yield element
4661
4662
def duplicates_justseen(iterable, key=None):
    """Yields serially-duplicate elements after their first appearance.

    >>> list(duplicates_justseen('mississippi'))
    ['s', 's', 'p']
    >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    This function is analogous to :func:`unique_justseen`.

    """
    # The previous implementation used a side-effecting comprehension that
    # depended on how ``flatten`` lazily interleaves with ``groupby``'s
    # group invalidation. This explicit generator has identical output and
    # is straightforward to read: drop the first item of each run of equal
    # (by *key*) elements and yield the rest.
    for _, group in groupby(iterable, key):
        next(group, None)  # discard the first occurrence in the run
        yield from group
4675
4676
def classify_unique(iterable, key=None):
    """Classify each element in terms of its uniqueness.

    For each element in the input iterable, return a 3-tuple consisting of:

    1. The element itself
    2. ``False`` if the element is equal to the one preceding it in the input,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
    3. ``False`` if this element has been seen anywhere in the input before,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)

    >>> list(classify_unique('otto'))    # doctest: +NORMALIZE_WHITESPACE
    [('o', True,  True),
     ('t', True,  True),
     ('t', False, False),
     ('o', True,  False)]

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []
    prev_key = None
    is_first = True

    for element in iterable:
        k = element if key is None else key(element)

        # "Just seen" compares against the immediately preceding key only.
        unique_justseen = is_first or prev_key != k
        is_first = False
        prev_key = k

        # "Ever seen" checks all prior keys, falling back to a list scan
        # for unhashable keys.
        unique_everseen = False
        try:
            if k not in hashable_seen:
                hashable_seen.add(k)
                unique_everseen = True
        except TypeError:
            if k not in unhashable_seen:
                unhashable_seen.append(k)
                unique_everseen = True

        yield element, unique_justseen, unique_everseen
4717
4718
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Returns both the smallest and largest items from an iterable
    or from two or more arguments.

    >>> minmax([3, 1, 5])
    (1, 5)

    >>> minmax(4, 2, 6)
    (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

    >>> minmax([5, 30], key=str)  # '30' sorts before '5'
    (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

    >>> minmax([], default=(0, 0))
    (0, 0)

    Otherwise ``ValueError`` is raised.

    This function makes a single pass over the input elements and takes care to
    minimize the number of comparisons made during processing.

    Note that unlike the builtin ``max`` function, which always returns the first
    item with the maximum value, this function may return another item when there are
    ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    # Support both minmax(iterable) and minmax(a, b, ...) call forms,
    # mirroring the builtin min/max.
    iterable = (iterable_or_value, *others) if others else iterable_or_value

    it = iter(iterable)

    try:
        lo = hi = next(it)
    except StopIteration as exc:
        if default is _marker:
            raise ValueError(
                '`minmax()` argument is an empty iterable. '
                'Provide a `default` value to suppress this error.'
            ) from exc
        return default

    # Different branches depending on the presence of key. This saves a lot
    # of unimportant copies which would slow the "key=None" branch
    # significantly down.
    if key is None:
        # Consume items two at a time: ordering the pair first means three
        # comparisons per two items instead of four. An odd trailing item
        # is paired with `lo` via the fillvalue, which cannot change the
        # result.
        for x, y in zip_longest(it, it, fillvalue=lo):
            if y < x:
                x, y = y, x
            if x < lo:
                lo = x
            if hi < y:
                hi = y

    else:
        lo_key = hi_key = key(lo)

        for x, y in zip_longest(it, it, fillvalue=lo):
            # Same pairing trick, but comparisons use cached key values so
            # that key() is called exactly once per item.
            x_key, y_key = key(x), key(y)

            if y_key < x_key:
                x, y, x_key, y_key = y, x, y_key, x_key
            if x_key < lo_key:
                lo, lo_key = x, x_key
            if hi_key < y_key:
                hi, hi_key = y, y_key

    return lo, hi
4794
4795
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* with a combined size limited by
    *max_size*.

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10))
    [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If a *max_count* is supplied, the number of items per batch is also
    limited:

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10, max_count = 2))
    [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    If a *get_len* function is supplied, use that instead of :func:`len` to
    determine item size.

    If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
    than *max_size*. Otherwise, allow single items to exceed *max_size*.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    current = []
    current_size = 0

    for item in iterable:
        size = get_len(item)
        if strict and size > max_size:
            raise ValueError('item size exceeds maximum size')

        # Flush the batch if adding this item would exceed either limit.
        overflows = (
            current_size + size > max_size or len(current) == max_count
        )
        if current and overflows:
            yield tuple(current)
            current = []
            current_size = 0

        current.append(item)
        current_size += size

    # Emit any leftover items as the final batch.
    if current:
        yield tuple(current)
4844
4845
def gray_product(*iterables):
    """Like :func:`itertools.product`, but return tuples in an order such
    that only one element in the generated tuple changes from one iteration
    to the next.

    >>> list(gray_product('AB','CD'))
    [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    This function consumes all of the input iterables before producing output.
    If any of the input iterables have fewer than two items, ``ValueError``
    is raised.

    For information on the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    all_iterables = tuple(tuple(x) for x in iterables)
    iterable_count = len(all_iterables)
    for iterable in all_iterables:
        if len(iterable) < 2:
            raise ValueError("each iterable must have two or more items")

    # This is based on "Algorithm H" from section 7.2.1.1, page 20.
    # a holds the indexes of the source iterables for the n-tuple to be yielded
    # f is the array of "focus pointers"
    # o is the array of "directions"
    a = [0] * iterable_count
    f = list(range(iterable_count + 1))
    o = [1] * iterable_count
    while True:
        yield tuple(all_iterables[i][a[i]] for i in range(iterable_count))
        # f[0] names the coordinate to step next; a value of iterable_count
        # signals that every tuple has been produced.
        j = f[0]
        f[0] = 0
        if j == iterable_count:
            break
        # Step coordinate j in its current direction; when it hits either
        # end, reverse direction and update the focus pointers.
        a[j] = a[j] + o[j]
        if a[j] == 0 or a[j] == len(all_iterables[j]) - 1:
            o[j] = -o[j]
            f[j] = f[j + 1]
            f[j + 1] = j + 1
4886
4887
def partial_product(*iterables):
    """Yields tuples containing one item from each iterator, with subsequent
    tuples changing a single item at a time by advancing each iterator until it
    is exhausted. This sequence guarantees every value in each iterable is
    output at least once without generating all possible combinations.

    This may be useful, for example, when testing an expensive function.

    >>> list(partial_product('AB', 'C', 'DEF'))
    [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]
    """

    iterators = [iter(obj) for obj in iterables]

    # Seed with the first item of every iterable; if any is empty there is
    # nothing to yield.
    current = []
    for iterator in iterators:
        try:
            current.append(next(iterator))
        except StopIteration:
            return
    yield tuple(current)

    # Exhaust one iterator at a time, yielding after every single change.
    for position, iterator in enumerate(iterators):
        for value in iterator:
            current[position] = value
            yield tuple(current)
4911
4912
def takewhile_inclusive(predicate, iterable):
    """A variant of :func:`takewhile` that yields one additional element.

    >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
    [1, 4, 6]

    :func:`takewhile` would return ``[1, 4]``.
    """
    for item in iterable:
        # Unlike takewhile, the failing item is emitted before stopping.
        yield item
        if not predicate(item):
            return
4925
4926
4927def outer_product(func, xs, ys, *args, **kwargs):
4928 """A generalized outer product that applies a binary function to all
4929 pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
4930 columns.
4931 Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.
4932
4933 Multiplication table:
4934
4935 >>> list(outer_product(mul, range(1, 4), range(1, 6)))
4936 [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]
4937
4938 Cross tabulation:
4939
4940 >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
4941 >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
4942 >>> pair_counts = Counter(zip(xs, ys))
4943 >>> count_rows = lambda x, y: pair_counts[x, y]
4944 >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
4945 [(2, 3, 0), (1, 0, 4)]
4946
4947 Usage with ``*args`` and ``**kwargs``:
4948
4949 >>> animals = ['cat', 'wolf', 'mouse']
4950 >>> list(outer_product(min, animals, animals, key=len))
4951 [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
4952 """
4953 ys = tuple(ys)
4954 return batched(
4955 starmap(lambda x, y: func(x, y, *args, **kwargs), product(xs, ys)),
4956 n=len(ys),
4957 )
4958
4959
def iter_suppress(iterable, *exceptions):
    """Yield each of the items from *iterable*. If the iteration raises one of
    the specified *exceptions*, that exception will be suppressed and iteration
    will stop.

    >>> from itertools import chain
    >>> def breaks_at_five(x):
    ...     while True:
    ...         if x >= 5:
    ...             raise RuntimeError
    ...         yield x
    ...         x += 1
    >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
    >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
    >>> list(chain(it_1, it_2))
    [1, 2, 3, 4, 2, 3, 4]
    """
    # contextlib.suppress swallows exactly the listed exception types and
    # ends the generator cleanly when one is raised mid-iteration.
    with suppress(*exceptions):
        yield from iterable
4981
4982
def filter_map(func, iterable):
    """Apply *func* to every element of *iterable*, yielding only those which
    are not ``None``.

    >>> elems = ['1', 'a', '2', 'b', '3']
    >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
    [1, 2, 3]
    """
    for item in iterable:
        result = func(item)
        if result is not None:
            yield result
4995
4996
def powerset_of_sets(iterable):
    """Yields all possible subsets of the iterable.

    >>> list(powerset_of_sets([1, 2, 3]))  # doctest: +SKIP
    [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
    >>> list(powerset_of_sets([1, 1, 0]))  # doctest: +SKIP
    [set(), {1}, {0}, {0, 1}]

    :func:`powerset_of_sets` takes care to minimize the number
    of hash operations performed.
    """
    # Wrap each item in a frozenset singleton (hashing it exactly once)
    # and deduplicate while preserving first-seen order.
    singletons = tuple(dict.fromkeys(frozenset((item,)) for item in iterable))
    union = set().union
    # Emit subsets by increasing size, r = 0 .. len(singletons).
    for r in range(len(singletons) + 1):
        for combo in combinations(singletons, r):
            yield union(*combo)
5013
5014
def join_mappings(**field_to_map):
    """
    Joins multiple mappings together using their common keys.

    >>> user_scores = {'elliot': 50, 'claris': 60}
    >>> user_times = {'elliot': 30, 'claris': 40}
    >>> join_mappings(score=user_scores, time=user_times)
    {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    joined = {}

    # Invert the nesting: outer keys become inner field names.
    for field_name, mapping in field_to_map.items():
        for key, value in mapping.items():
            joined.setdefault(key, {})[field_name] = value

    return joined
5031
5032
def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.
    """

    real = attrgetter('real')
    imag = attrgetter('imag')
    # For z1 = a + bi and z2 = c + di:
    #   real(z1 * z2) = a*c - b*d
    #   imag(z1 * z2) = a*d + b*c
    # Each part is expressed as a single extended dot product so that
    # _fsumprod can accumulate it at high precision in one pass.
    r1 = chain(map(real, v1), map(neg, map(imag, v1)))
    r2 = chain(map(real, v2), map(imag, v2))
    i1 = chain(map(real, v1), map(imag, v1))
    i2 = chain(map(imag, v2), map(real, v2))
    return complex(_fsumprod(r1, r2), _fsumprod(i1, i2))
5045
5046
def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the corresponding transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]    # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]    # frequency domain
    >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
    >>> all(map(cmath.isclose, dft(xarr), Xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`idft` for the inverse Discrete Fourier Transform.
    """
    N = len(xarr)
    # Precompute all N roots of unity (negative orientation); entry
    # k*n % N is the twiddle factor for output k and input n.
    twiddles = [e ** (n / N * tau * -1j) for n in range(N)]
    for k in range(N):
        row = [twiddles[k * n % N] for n in range(N)]
        yield _complex_sumprod(xarr, row)
5069
5070
def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the corresponding
    inverse-transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]    # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]    # frequency domain
    >>> all(map(cmath.isclose, idft(Xarr), xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`dft` for the Discrete Fourier Transform.
    """
    N = len(Xarr)
    # Precompute all N roots of unity (positive orientation); entry
    # k*n % N is the twiddle factor for output k and input n.
    twiddles = [e ** (n / N * tau * 1j) for n in range(N)]
    for k in range(N):
        row = [twiddles[k * n % N] for n in range(N)]
        yield _complex_sumprod(Xarr, row) / N
5093
5094
def doublestarmap(func, iterable):
    """Apply *func* to every item of *iterable* by dictionary unpacking
    the item into *func*.

    The difference between :func:`itertools.starmap` and :func:`doublestarmap`
    parallels the distinction between ``func(*a)`` and ``func(**a)``.

    >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
    >>> list(doublestarmap(lambda a, b: a + b, iterable))
    [3, 100]

    ``TypeError`` will be raised if *func*'s signature doesn't match the
    mapping contained in *iterable* or if *iterable* does not contain mappings.
    """
    for mapping in iterable:
        yield func(**mapping)
5111
5112
5113def _nth_prime_bounds(n):
5114 """Bounds for the nth prime (counting from 1): lb <= p_n <= ub."""
5115 # At and above 688,383, the lb/ub spread is under 0.003 * n.
5116
5117 if n < 1:
5118 raise ValueError
5119
5120 if n < 6:
5121 return (n, 2.25 * n)
5122
5123 # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
5124 upper_bound = n * log(n * log(n))
5125 lower_bound = upper_bound - n
5126 if n >= 688_383:
5127 upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))
5128
5129 return lower_bound, upper_bound
5130
5131
def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

    >>> nth_prime(0)
    2
    >>> nth_prime(100)
    547

    If *approximate* is set to True, will return a prime in the close
    to the nth prime. The estimation is much faster than computing
    an exact result.

    >>> nth_prime(200_000_000, approximate=True) # Exact result is 4222234763
    4217820427

    """
    # Bounds count primes from 1, so shift by one.
    lb, ub = _nth_prime_bounds(n + 1)

    # Exact answer: sieve up to the upper bound and pick the nth prime.
    if not approximate or n <= 1_000_000:
        return nth(sieve(ceil(ub)), n)

    # Approximation: start at an odd number near the midpoint of the
    # bounds and return the first prime at or beyond it.
    midpoint_odd = floor((lb + ub) / 2) | 1
    return first_true(count(midpoint_odd, step=2), pred=is_prime)
5156
5157
def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

    >>> argmin('efghabcdijkl')
    4
    >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
    3

    For example, look up a label corresponding to the position
    of a value that minimizes a cost function::

        >>> def cost(x):
        ...     "Days for a wound to heal given a subject's age."
        ...     return x**2 - 20*x + 150
        ...
        >>> labels = ['homer', 'marge', 'bart', 'lisa', 'maggie']
        >>> ages = [ 35, 30, 10, 9, 1 ]

        # Fastest healing family member
        >>> labels[argmin(ages, key=cost)]
        'bart'

        # Age with fastest healing
        >>> min(ages, key=cost)
        10

    """
    # Transform values through key if given, then minimize over
    # (index, value) pairs comparing by value only.
    values = iterable if key is None else map(key, iterable)
    smallest = min(enumerate(values), key=itemgetter(1))
    return smallest[0]
5189
5190
def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

    >>> argmax('abcdefghabcd')
    7
    >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
    3

    For example, identify the best machine learning model::

        >>> models = ['svm', 'random forest', 'knn', 'naïve bayes']
        >>> accuracy = [ 68,      61,          84,       72 ]

        # Most accurate model
        >>> models[argmax(accuracy)]
        'knn'

        # Best accuracy
        >>> max(accuracy)
        84

    """
    # Transform values through key if given, then maximize over
    # (index, value) pairs comparing by value only.
    values = iterable if key is None else map(key, iterable)
    largest = max(enumerate(values), key=itemgetter(1))
    return largest[0]