1import math
2import warnings
3
4from collections import Counter, defaultdict, deque, abc
5from collections.abc import Sequence
6from contextlib import suppress
7from functools import cached_property, partial, reduce, wraps
8from heapq import heapify, heapreplace
9from itertools import (
10 chain,
11 combinations,
12 compress,
13 count,
14 cycle,
15 dropwhile,
16 groupby,
17 islice,
18 permutations,
19 repeat,
20 starmap,
21 takewhile,
22 tee,
23 zip_longest,
24 product,
25)
26from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
27from math import ceil
28from queue import Empty, Queue
29from random import random, randrange, shuffle, uniform
30from operator import is_ as operator_is, attrgetter, itemgetter
31from operator import neg, mul, sub, gt, lt
32from sys import hexversion, maxsize
33from time import monotonic
34
35from .recipes import (
36 _marker,
37 _zip_equal,
38 UnequalIterablesError,
39 consume,
40 first_true,
41 flatten,
42 is_prime,
43 nth,
44 powerset,
45 sieve,
46 take,
47 unique_everseen,
48 all_equal,
49 batched,
50)
51
52__all__ = [
53 'AbortThread',
54 'SequenceView',
55 'UnequalIterablesError',
56 'adjacent',
57 'all_unique',
58 'always_iterable',
59 'always_reversible',
60 'argmax',
61 'argmin',
62 'bucket',
63 'callback_iter',
64 'chunked',
65 'chunked_even',
66 'circular_shifts',
67 'collapse',
68 'combination_index',
69 'combination_with_replacement_index',
70 'consecutive_groups',
71 'constrained_batches',
72 'consumer',
73 'count_cycle',
74 'countable',
75 'derangements',
76 'dft',
77 'difference',
78 'distinct_combinations',
79 'distinct_permutations',
80 'distribute',
81 'divide',
82 'doublestarmap',
83 'duplicates_everseen',
84 'duplicates_justseen',
85 'classify_unique',
86 'exactly_n',
87 'filter_except',
88 'filter_map',
89 'first',
90 'gray_product',
91 'groupby_transform',
92 'ichunked',
93 'iequals',
94 'idft',
95 'ilen',
96 'interleave',
97 'interleave_evenly',
98 'interleave_longest',
99 'intersperse',
100 'is_sorted',
101 'islice_extended',
102 'iterate',
103 'iter_suppress',
104 'join_mappings',
105 'last',
106 'locate',
107 'longest_common_prefix',
108 'lstrip',
109 'make_decorator',
110 'map_except',
111 'map_if',
112 'map_reduce',
113 'mark_ends',
114 'minmax',
115 'nth_or_last',
116 'nth_permutation',
117 'nth_prime',
118 'nth_product',
119 'nth_combination_with_replacement',
120 'numeric_range',
121 'one',
122 'only',
123 'outer_product',
124 'padded',
125 'partial_product',
126 'partitions',
127 'peekable',
128 'permutation_index',
129 'powerset_of_sets',
130 'product_index',
131 'raise_',
132 'repeat_each',
133 'repeat_last',
134 'replace',
135 'rlocate',
136 'rstrip',
137 'run_length',
138 'sample',
139 'seekable',
140 'set_partitions',
141 'side_effect',
142 'sliced',
143 'sort_together',
144 'split_after',
145 'split_at',
146 'split_before',
147 'split_into',
148 'split_when',
149 'spy',
150 'stagger',
151 'strip',
152 'strictly_n',
153 'substrings',
154 'substrings_indexes',
155 'takewhile_inclusive',
156 'time_limited',
157 'unique_in_window',
158 'unique_to_each',
159 'unzip',
160 'value_chain',
161 'windowed',
162 'windowed_complete',
163 'with_iter',
164 'zip_broadcast',
165 'zip_equal',
166 'zip_offset',
167]
168
# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _fsumprod

except ImportError:  # pragma: no cover
    # Extended precision algorithms from T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8). Code: mul12()

    def dl_split(x: float):
        "Split a float into two half-precision components."
        # Veltkamp splitting: hi + lo == x exactly.
        t = x * 134217729.0  # Veltkamp constant = 2.0 ** 27 + 1
        hi = t - (t - x)
        lo = x - hi
        return hi, lo

    def dl_mul(x, y):
        "Lossless multiplication."
        # Returns (z, zz) such that z + zz == x * y exactly.
        xx_hi, xx_lo = dl_split(x)
        yy_hi, yy_lo = dl_split(y)
        p = xx_hi * yy_hi
        q = xx_hi * yy_lo + xx_lo * yy_hi
        z = p + q
        zz = p - z + q + xx_lo * yy_lo
        return z, zz

    def _fsumprod(p, q):
        "Accurate dot product: fsum over the exact partial products."
        return fsum(chain.from_iterable(map(dl_mul, p, q)))
199
def chunked(iterable, n, strict=False):
    """Break *iterable* into lists of length *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
    [[1, 2, 3], [4, 5, 6]]

    By default, the last yielded list will have fewer than *n* elements
    if the length of *iterable* is not divisible by *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
    [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    If the length of *iterable* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    list is yielded.

    """
    iterator = iter(iterable)
    # Two-argument iter(): call the lambda repeatedly until it returns
    # the sentinel (an empty list), i.e. until the source runs out.
    chunks = iter(lambda: list(islice(iterator, n)), [])

    if not strict:
        return chunks

    if n is None:
        raise ValueError('n must not be None when using strict mode.')

    def check_full():
        # Any short chunk means the source length wasn't divisible by n.
        for chunk in chunks:
            if len(chunk) != n:
                raise ValueError('iterable is not divisible by n.')
            yield chunk

    return check_full()
233
234
def first(iterable, default=_marker):
    """Return the first item of *iterable*, or *default* if *iterable* is
    empty.

    >>> first([0, 1, 2, 3])
    0
    >>> first([], 'some default')
    'some default'

    If *default* is not provided and the iterable is empty, raise
    ``ValueError``.

    :func:`first` is handy when you have a generator of
    expensive-to-retrieve values and want any arbitrary one. It is
    marginally shorter than ``next(iter(iterable), default)``.

    """
    # The loop body runs at most once: it returns the first item.
    for value in iterable:
        return value
    if default is not _marker:
        return default
    raise ValueError(
        'first() was called on an empty iterable, '
        'and no default value was provided.'
    )
260
261
def last(iterable, default=_marker):
    """Return the last item of *iterable*, or *default* if *iterable* is
    empty.

    >>> last([0, 1, 2, 3])
    3
    >>> last([], 'some default')
    'some default'

    If *default* is not provided and the iterable is empty, raise
    ``ValueError``.
    """
    try:
        if isinstance(iterable, Sequence):
            # Sequences support direct access from the right.
            return iterable[-1]
        if hasattr(iterable, '__reversed__'):
            # Work around https://bugs.python.org/issue38525
            return next(reversed(iterable))
        # Otherwise consume everything, keeping only the final item.
        return deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        if default is not _marker:
            return default
        raise ValueError(
            'last() was called on an empty iterable, '
            'and no default value was provided.'
        )
288
289
def nth_or_last(iterable, n, default=_marker):
    """Return item *n* of *iterable*, falling back to its final item
    when there are fewer than ``n + 1`` items (or to *default* if it's
    empty).

    >>> nth_or_last([0, 1, 2, 3], 2)
    2
    >>> nth_or_last([0, 1], 2)
    1
    >>> nth_or_last([], 0, 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    # Truncate after index n, then take whatever item comes last.
    head = islice(iterable, n + 1)
    return last(head, default=default)
305
306
class peekable:
    """Wrap an iterator to allow lookahead and prepending elements.

    Call :meth:`peek` on the result to get the value that will be returned
    by :func:`next`. This won't advance the iterator:

    >>> p = peekable(['a', 'b'])
    >>> p.peek()
    'a'
    >>> next(p)
    'a'

    Pass :meth:`peek` a default value to return that instead of raising
    ``StopIteration`` when the iterator is exhausted.

    >>> p = peekable([])
    >>> p.peek('hi')
    'hi'

    peekables also offer a :meth:`prepend` method, which "inserts" items
    at the head of the iterable:

    >>> p = peekable([1, 2, 3])
    >>> p.prepend(10, 11, 12)
    >>> next(p)
    10
    >>> p.peek()
    11
    >>> list(p)
    [11, 12, 1, 2, 3]

    peekables can be indexed. Index 0 is the item that will be returned by
    :func:`next`, index 1 is the item after that, and so on:
    The values up to the given index will be cached.

    >>> p = peekable(['a', 'b', 'c', 'd'])
    >>> p[0]
    'a'
    >>> p[1]
    'b'
    >>> next(p)
    'a'

    Negative indexes are supported, but be aware that they will cache the
    remaining items in the source iterator, which may require significant
    storage.

    To check whether a peekable is exhausted, check its truth value:

    >>> p = peekable(['a', 'b'])
    >>> if p:  # peekable has items
    ...     list(p)
    ['a', 'b']
    >>> if not p:  # peekable is exhausted
    ...     list(p)
    []

    """

    def __init__(self, iterable):
        self._it = iter(iterable)
        # Items read ahead of the consumer (via peek/indexing/slicing)
        # plus any prepended items; served before self._it.
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        # Truthy iff at least one more item can be produced.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        """Return the item that will be next returned from ``next()``.

        Return ``default`` if there are no items left. If ``default`` is not
        provided, raise ``StopIteration``.

        """
        # Pull one item from the source into the cache if needed; the
        # cached item is returned again (not consumed) by later peeks.
        if not self._cache:
            try:
                self._cache.append(next(self._it))
            except StopIteration:
                if default is _marker:
                    raise
                return default
        return self._cache[0]

    def prepend(self, *items):
        """Stack up items to be the next ones returned from ``next()`` or
        ``self.peek()``. The items will be returned in
        first in, first out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        It is possible, by prepending items, to "resurrect" a peekable that
        previously raised ``StopIteration``.

        >>> p = peekable([])
        >>> next(p)
        Traceback (most recent call last):
          ...
        StopIteration
        >>> p.prepend(1)
        >>> next(p)
        1
        >>> next(p)
        Traceback (most recent call last):
          ...
        StopIteration

        """
        # extendleft reverses its argument, so reverse first to keep the
        # given items in FIFO order at the front of the cache.
        self._cache.extendleft(reversed(items))

    def __next__(self):
        # Serve cached (peeked or prepended) items before the source.
        if self._cache:
            return self._cache.popleft()

        return next(self._it)

    def _get_slice(self, index):
        # Normalize the slice's arguments
        step = 1 if (index.step is None) else index.step
        if step > 0:
            start = 0 if (index.start is None) else index.start
            stop = maxsize if (index.stop is None) else index.stop
        elif step < 0:
            start = -1 if (index.start is None) else index.start
            stop = (-maxsize - 1) if (index.stop is None) else index.stop
        else:
            raise ValueError('slice step cannot be zero')

        # If either the start or stop index is negative, we'll need to cache
        # the rest of the iterable in order to slice from the right side.
        if (start < 0) or (stop < 0):
            self._cache.extend(self._it)
        # Otherwise we'll need to find the rightmost index and cache to that
        # point.
        else:
            n = min(max(start, stop) + 1, maxsize)
            cache_len = len(self._cache)
            if n >= cache_len:
                self._cache.extend(islice(self._it, n - cache_len))

        # Slice the cached items; list() gives real slice support, which
        # deque lacks.
        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        # Cache up to and including the requested index; a negative index
        # requires caching the entire remaining source.
        cache_len = len(self._cache)
        if index < 0:
            self._cache.extend(self._it)
        elif index >= cache_len:
            self._cache.extend(islice(self._it, index + 1 - cache_len))

        return self._cache[index]
470
471
def consumer(func):
    """Decorator that automatically advances a PEP-342-style "reverse iterator"
    to its first yield point so you don't have to call ``next()`` on it
    manually.

    >>> @consumer
    ... def tally():
    ...     i = 0
    ...     while True:
    ...         print('Thing number %s is %s.' % (i, (yield)))
    ...         i += 1
    ...
    >>> t = tally()
    >>> t.send('red')
    Thing number 0 is red.
    >>> t.send('fish')
    Thing number 1 is fish.

    Without the decorator, you would have to call ``next(t)`` before
    ``t.send()`` could be used.

    """

    @wraps(func)
    def advancer(*args, **kwargs):
        # Build the generator and step it to the first yield so it is
        # immediately ready to receive values via send().
        generator = func(*args, **kwargs)
        next(generator)
        return generator

    return advancer
502
503
def ilen(iterable):
    """Return the number of items in *iterable*.

    For example, there are 168 prime numbers below 1,000:

    >>> ilen(sieve(1000))
    168

    Equivalent to, but faster than::

        def ilen(iterable):
            count = 0
            for _ in iterable:
                count += 1
            return count

    This fully consumes the iterable, so handle with care.

    """
    # This is the "most beautiful of the fast variants" of this function.
    # If you think you can improve on it, please ensure that your version
    # is both 10x faster and 10x more beautiful.
    # How it works: zip(iterable) wraps each item in a 1-tuple, which is
    # always truthy, so compress() selects a 1 (from repeat(1)) for every
    # item and sum() totals them -- the whole loop runs in C.
    return sum(compress(repeat(1), zip(iterable)))
527
528
def iterate(func, start):
    """Return ``start``, ``func(start)``, ``func(func(start))``, ...

    Produces an infinite iterator. To add a stopping condition,
    use :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`:.

    >>> take(10, iterate(lambda x: 2*x, 1))
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2
    >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10)))
    [10, 5, 16, 8, 4, 2, 1]

    """
    # A StopIteration raised by *func* ends the sequence cleanly.
    try:
        while True:
            yield start
            start = func(start)
    except StopIteration:
        return
547
548
def with_iter(context_manager):
    """Wrap an iterable in a ``with`` statement, so it closes once exhausted.

    For example, this will close the file when the iterator is exhausted::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Any context manager which returns an iterable is a candidate for
    ``with_iter``.

    """
    # The context stays open for as long as the generator is being
    # consumed, and is exited when it is exhausted (or closed).
    with context_manager as managed:
        yield from managed
562
563
def one(iterable, too_short=None, too_long=None):
    """Return the first item from *iterable*, which is expected to contain
    only that item. Raise an exception if *iterable* is empty or has more
    than one item.

    :func:`one` is useful for ensuring that an iterable contains only one
    item -- for example, the result of a database query that is expected
    to return a single row.

    If *iterable* is empty, ``ValueError`` will be raised. You may specify a
    different exception with the *too_short* keyword:

    >>> it = []
    >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too few items in iterable (expected 1)'
    >>> too_short = IndexError('too few items')
    >>> one(it, too_short=too_short)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    IndexError: too few items

    Similarly, if *iterable* contains more than one item, ``ValueError`` will
    be raised. You may specify a different exception with the *too_long*
    keyword:

    >>> it = ['too', 'many']
    >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: Expected exactly one item in iterable, but got 'too',
    'many', and perhaps more.
    >>> too_long = RuntimeError
    >>> one(it, too_long=too_long)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    RuntimeError

    Note that :func:`one` attempts to advance *iterable* twice to ensure
    there is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    iterator = iter(iterable)
    sentinel = object()

    head = next(iterator, sentinel)
    if head is sentinel:
        raise too_short or ValueError('too few items in iterable (expected 1)')

    tail = next(iterator, sentinel)
    if tail is not sentinel:
        msg = (
            f'Expected exactly one item in iterable, but got {head!r}, '
            f'{tail!r}, and perhaps more.'
        )
        raise too_long or ValueError(msg)

    return head
618
619
def raise_(exception, *args):
    """Construct *exception* with *args* and raise it.

    Helper used (e.g. by :func:`strictly_n`) to raise from inside a
    lambda expression, where a ``raise`` statement is not allowed.
    """
    raise exception(*args)
622
623
def strictly_n(iterable, n, too_short=None, too_long=None):
    """Validate that *iterable* has exactly *n* items and return them if
    it does. If it has fewer than *n* items, call function *too_short*
    with those items. If it has more than *n* items, call function
    *too_long* with the first ``n + 1`` items.

    >>> iterable = ['a', 'b', 'c', 'd']
    >>> n = 4
    >>> list(strictly_n(iterable, n))
    ['a', 'b', 'c', 'd']

    Note that the returned iterable must be consumed in order for the check to
    be made.

    By default, *too_short* and *too_long* are functions that raise
    ``ValueError``.

    >>> list(strictly_n('ab', 3))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too few items in iterable (got 2)

    >>> list(strictly_n('abc', 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too many items in iterable (got at least 3)

    You can instead supply functions that do something else.
    *too_short* will be called with the number of items in *iterable*.
    *too_long* will be called with `n + 1`.

    >>> def too_short(item_count):
    ...     raise RuntimeError
    >>> it = strictly_n('abcd', 6, too_short=too_short)
    >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    RuntimeError

    >>> def too_long(item_count):
    ...     print('The boss is going to hear about this')
    >>> it = strictly_n('abcdef', 4, too_long=too_long)
    >>> list(it)
    The boss is going to hear about this
    ['a', 'b', 'c', 'd']

    """
    if too_short is None:

        def too_short(item_count):
            raise ValueError(f'Too few items in iterable (got {item_count})')

    if too_long is None:

        def too_long(item_count):
            raise ValueError(
                f'Too many items in iterable (got at least {item_count})'
            )

    iterator = iter(iterable)

    # Pass along at most n items, counting as we go.
    yielded = 0
    for item in islice(iterator, n):
        yield item
        yielded += 1

    if yielded < n:
        too_short(yielded)
        return

    # Anything left over means the iterable was too long; report n + 1
    # and stop without consuming further.
    for _ in iterator:
        too_long(n + 1)
        return
697
698
def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

    >>> sorted(distinct_permutations([1, 0, 1]))
    [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

    >>> sorted(distinct_permutations([1, 0, 1], r=2))
    [(0, 1), (1, 0), (1, 1)]
    >>> sorted(distinct_permutations(range(3), r=2))
    [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

    >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
    [
        (1, True, '3'),
        (1, '3', True),
        ('3', 1, True)
    ]
    >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
    [
        (1, 2, '3'),
        (1, '3', 2),
        (2, 1, '3'),
        (2, '3', 1),
        ('3', 1, 2),
        ('3', 2, 1)
    ]
    """

    # Algorithm: https://w.wiki/Qai
    # Successively produces the "next" lexicographic permutation in place;
    # requires A to start sorted.
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than j such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above
    def _partial(A, r):
        # Split A into the first r items and the last r items
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    # The core algorithms need a sorted starting point. If the items
    # can't be sorted, fall back to permuting sortable integer proxies.
    try:
        items.sort()
        sortable = True
    except TypeError:
        sortable = False

    # Group equal items under the index of their first occurrence
    # (items.index returns the first matching position).
    indices_dict = defaultdict(list)

    for item in items:
        indices_dict[items.index(item)].append(item)

    # Replace each item by its first-occurrence index: an always-sortable
    # proxy with the same equality structure.
    indices = [items.index(item) for item in items]
    indices.sort()

    # Cycle over each group's members so permuted index tuples can be
    # translated back into tuples of the original objects.
    equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

    def permuted_items(permuted_indices):
        return tuple(
            next(equivalent_items[index]) for index in permuted_indices
        )

    size = len(items)
    if r is None:
        r = size

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if 0 < r <= size:
        if sortable:
            return algorithm(items)
        else:
            return (
                permuted_items(permuted_indices)
                for permuted_indices in algorithm(indices)
            )

    # r is out of range (r > size or r < 0): no permutations at all;
    # r == 0: exactly one empty permutation.
    return iter(() if r else ((),))
847
848
def derangements(iterable, r=None):
    """Yield successive derangements of the elements in *iterable*.

    A derangement is a permutation in which no element appears at its original
    index. In other words, a derangement is a permutation that has no fixed points.

    Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
    The code below outputs all of the different ways to assign gift recipients
    such that nobody is assigned to himself or herself:

    >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
    ...     print(', '.join(d))
    Bob, Alice, Dave, Carol
    Bob, Carol, Dave, Alice
    Bob, Dave, Alice, Carol
    Carol, Alice, Dave, Bob
    Carol, Dave, Alice, Bob
    Carol, Dave, Bob, Alice
    Dave, Alice, Bob, Carol
    Dave, Carol, Alice, Bob
    Dave, Carol, Bob, Alice

    If *r* is given, only the *r*-length derangements are yielded.

    >>> sorted(derangements(range(3), 2))
    [(1, 0), (1, 2), (2, 0)]
    >>> sorted(derangements([0, 2, 3], 2))
    [(2, 0), (2, 3), (3, 0)]

    Elements are treated as unique based on their position, not on their value.
    If the input elements are unique, there will be no repeated values within a
    permutation.

    The number of derangements of a set of size *n* is known as the
    "subfactorial of n". For n > 0, the subfactorial is:
    ``round(math.factorial(n) / math.e)``.
    """
    # Wrap each item in a fresh 1-tuple so positions can be compared by
    # identity (``is``), even when different positions hold equal items.
    wrapped = tuple(zip(iterable))
    for candidate in permutations(wrapped, r=r):
        # A fixed point is a wrapper that stayed at its original position.
        if not any(map(operator_is, wrapped, candidate)):
            yield tuple(item[0] for item in candidate)
891
892
def intersperse(e, iterable, n=1):
    """Intersperse filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

    >>> list(intersperse('!', [1, 2, 3, 4, 5]))
    [1, '!', 2, '!', 3, '!', 4, '!', 5]

    >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
    [1, 2, None, 3, 4, None, 5]

    """
    if n == 0:
        raise ValueError('n must be > 0')
    if n == 1:
        # e, x_0, e, x_1, ... with the leading filler sliced off.
        return islice(interleave(repeat(e), iterable), 1, None)
    # Interleave single-filler chunks with n-item chunks, drop the
    # leading filler chunk, then flatten:
    # [x_0 .. x_{n-1}], [e], [x_n ..] ... -> x_0 .. x_{n-1}, e, x_n ...
    filler = repeat([e])
    chunks = chunked(iterable, n)
    return flatten(islice(interleave(filler, chunks), 1, None))
917
918
def unique_to_each(*iterables):
    """Return the elements from each of the input iterables that aren't in the
    other input iterables.

    For example, suppose you have a set of packages, each with a set of
    dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    If you remove one package, which dependencies can also be removed?

    If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
    associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for
    ``pkg_2``, and ``D`` is only needed for ``pkg_3``::

        >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
        [['A'], ['C'], ['D']]

    If there are duplicates in one input iterable that aren't in the others
    they will be duplicated in the output. Input order is preserved::

        >>> unique_to_each("mississippi", "missouri")
        [['p', 'p'], ['o', 'u', 'r']]

    It is assumed that the elements of each iterable are hashable.

    """
    pools = [list(iterable) for iterable in iterables]
    # Count how many *distinct* inputs each element appears in.
    appearances = Counter(chain.from_iterable(map(set, pools)))
    singletons = {element for element, count in appearances.items() if count == 1}
    # Keep each input's order (and internal duplicates) intact.
    return [[item for item in pool if item in singletons] for pool in pools]
950
951
def windowed(seq, n, fillvalue=None, step=1):
    """Return a sliding window of width *n* over the given iterable.

    >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
    >>> list(all_windows)
    [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    When the window is larger than the iterable, *fillvalue* is used in place
    of missing values:

    >>> list(windowed([1, 2, 3], 4))
    [(1, 2, 3, None)]

    Each window will advance in increments of *step*:

    >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
    [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler items
    to the left:

    >>> iterable = [1, 2, 3, 4]
    >>> n = 3
    >>> padding = [None] * (n - 1)
    >>> list(windowed(chain(padding, iterable), 3))
    [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    if n < 0:
        raise ValueError('n must be >= 0')
    if n == 0:
        # A zero-width window is a single empty tuple.
        yield ()
        return
    if step < 1:
        raise ValueError('step must be >= 1')

    iterator = iter(seq)

    # Generate first window
    window = deque(islice(iterator, n), maxlen=n)

    # Deal with the first window not being full
    if not window:
        return
    if len(window) < n:
        yield tuple(window) + ((fillvalue,) * (n - len(window)))
        return
    yield tuple(window)

    # Create the filler for the next windows. The padding ensures
    # we have just enough elements to fill the last window.
    padding = (fillvalue,) * (n - 1 if step >= n else step - 1)
    # Each advance of `filler` appends one source (or padding) item to
    # `window`, evicting the oldest item (the deque's maxlen is n).
    filler = map(window.append, chain(iterator, padding))

    # Generate the rest of the windows
    # islice's stride consumes `step` appends per iteration, sliding the
    # window forward `step` items between yields.
    for _ in islice(filler, step - 1, None, step):
        yield tuple(window)
1008
1009
def substrings(iterable):
    """Yield all of the substrings of *iterable*.

    >>> [''.join(s) for s in substrings('more')]
    ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Note that non-string iterables can also be subdivided.

    >>> list(substrings([0, 1, 2]))
    [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    """
    # Emit the length-1 substrings while collecting the items.
    collected = []
    for element in iterable:
        collected.append(element)
        yield (element,)
    collected = tuple(collected)
    total = len(collected)

    # Then the longer substrings, shortest length first.
    for length in range(2, total + 1):
        for start in range(total - length + 1):
            yield collected[start : start + length]
1034
1035
def substrings_indexes(seq, reverse=False):
    """Yield all substrings and their positions in *seq*

    The items yielded will be a tuple of the form ``(substr, i, j)``, where
    ``substr == seq[i:j]``.

    This function only works for iterables that support slicing, such as
    ``str`` objects.

    >>> for item in substrings_indexes('more'):
    ...     print(item)
    ('m', 0, 1)
    ('o', 1, 2)
    ('r', 2, 3)
    ('e', 3, 4)
    ('mo', 0, 2)
    ('or', 1, 3)
    ('re', 2, 4)
    ('mor', 0, 3)
    ('ore', 1, 4)
    ('more', 0, 4)

    Set *reverse* to ``True`` to yield the same items in the opposite order.


    """
    size = len(seq)
    # Substring lengths, ascending by default, descending when reversed.
    lengths = range(size, 0, -1) if reverse else range(1, size + 1)
    return (
        (seq[start : start + length], start, start + length)
        for length in lengths
        for start in range(size - length + 1)
    )
1068
1069
class bucket:
    """Wrap *iterable* and return an object that buckets the iterable into
    child iterables based on a *key* function.

    >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
    >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
    >>> sorted(list(s))  # Get the keys
    ['a', 'b', 'c']
    >>> a_iterable = s['a']
    >>> next(a_iterable)
    'a1'
    >>> next(a_iterable)
    'a2'
    >>> list(s['b'])
    ['b1', 'b2', 'b3']

    The original iterable will be advanced and its items will be cached until
    they are used by the child iterables. This may require significant storage.

    By default, attempting to select a bucket to which no items belong will
    exhaust the iterable and cache all values.
    If you specify a *validator* function, selected buckets will instead be
    checked against it.

    >>> from itertools import count
    >>> it = count(1, 2)  # Infinite sequence of odd numbers
    >>> key = lambda x: x % 10  # Bucket by last digit
    >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
    >>> s = bucket(it, key=key, validator=validator)
    >>> 2 in s
    False
    >>> list(s[2])
    []

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        # Maps a key value to a deque of items seen for that key but not
        # yet consumed by its child iterable.
        self._cache = defaultdict(deque)
        # With no validator, every key value is considered possible.
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        if not self._validator(value):
            return False

        # Pull one matching item (possibly advancing the source), then
        # put it back at the front of that bucket's cache.
        try:
            item = next(self[value])
        except StopIteration:
            return False
        else:
            self._cache[value].appendleft(item)

        return True

    def _get_values(self, value):
        """
        Helper to yield items from the parent iterator that match *value*.
        Items that don't match are stored in the local cache as they
        are encountered.
        """
        while True:
            # If we've cached some items that match the target value, emit
            # the first one and evict it from the cache.
            if self._cache[value]:
                yield self._cache[value].popleft()
            # Otherwise we need to advance the parent iterator to search for
            # a matching item, caching the rest.
            else:
                while True:
                    try:
                        item = next(self._it)
                    except StopIteration:
                        return
                    item_value = self._key(item)
                    if item_value == value:
                        yield item
                        break
                    elif self._validator(item_value):
                        self._cache[item_value].append(item)

    def __iter__(self):
        # Exhaust the source, caching every valid item, then yield the
        # distinct key values that were encountered.
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        # Invalid keys get an empty iterator without touching the source.
        if not self._validator(value):
            return iter(())

        return self._get_values(value)
1164
1165
def spy(iterable, n=1):
    """Return a 2-tuple of (list of the first *n* items, iterator over the
    complete *iterable*), allowing a "look ahead" without losing any items.

    There is one item in the list by default:

    >>> iterable = 'abcdefg'
    >>> head, iterable = spy(iterable)
    >>> head
    ['a']
    >>> list(iterable)
    ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    You may use unpacking to retrieve items instead of lists:

    >>> (head,), iterable = spy('abcdefg')
    >>> head
    'a'
    >>> (first, second), iterable = spy('abcdefg', 2)
    >>> first
    'a'
    >>> second
    'b'

    The number of items requested can be larger than the number of items in
    the iterable; the head then simply contains everything:

    >>> head, iterable = spy([1, 2, 3, 4, 5], 10)
    >>> head
    [1, 2, 3, 4, 5]
    >>> list(iterable)
    [1, 2, 3, 4, 5]

    """
    # One tee branch is consumed for the preview; the untouched branch is
    # handed back so the caller sees every item again.
    preview, untouched = tee(iterable)
    return list(islice(preview, n)), untouched
1205
1206
def interleave(*iterables):
    """Return a new iterable yielding from each iterable in turn,
    until the shortest is exhausted.

    >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7]

    For a version that doesn't terminate after the shortest iterable is
    exhausted, see :func:`interleave_longest`.

    """
    # zip stops at the shortest input; flattening its tuples gives the
    # round-robin order. The zip object is created eagerly (it is the
    # outermost iterable of the generator expression).
    return (item for group in zip(*iterables) for item in group)
1219
1220
def interleave_longest(*iterables):
    """Return a new iterable yielding from each iterable in turn,
    skipping any that are exhausted.

    >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7, 3, 8]

    This function produces the same output as :func:`roundrobin`, but may
    perform better for some inputs (in particular when the number of iterables
    is large).

    """
    # A unique local sentinel marks the padding zip_longest inserts for
    # exhausted inputs; it is filtered out and never escapes.
    missing = object()
    for group in zip_longest(*iterables, fillvalue=missing):
        yield from (item for item in group if item is not missing)
1237
1238
def interleave_evenly(iterables, lengths=None):
    """
    Interleave multiple iterables so that their elements are evenly distributed
    throughout the output sequence.

    >>> iterables = [1, 2, 3, 4, 5], ['a', 'b']
    >>> list(interleave_evenly(iterables))
    [1, 2, 'a', 3, 4, 'b', 5]

    >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]]
    >>> list(interleave_evenly(iterables))
    [1, 6, 4, 2, 7, 3, 8, 5]

    This function requires iterables of known length. Iterables without
    ``__len__()`` can be used by manually specifying lengths with *lengths*:

    >>> from itertools import combinations, repeat
    >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']]
    >>> lengths = [4 * (4 - 1) // 2, 3]
    >>> list(interleave_evenly(iterables, lengths=lengths))
    [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c']

    Based on Bresenham's algorithm.
    """
    if lengths is None:
        try:
            lengths = [len(it) for it in iterables]
        except TypeError:
            raise ValueError(
                'Iterable lengths could not be determined automatically. '
                'Specify them with the lengths keyword.'
            )
    elif len(iterables) != len(lengths):
        raise ValueError('Mismatching number of iterables and lengths.')

    dims = len(lengths)

    # Process the longest iterable first (Bresenham: the longest distance
    # along an axis drives the stepping).
    order = sorted(range(dims), key=lengths.__getitem__, reverse=True)
    ordered_lengths = [lengths[i] for i in order]
    ordered_iters = [iter(iterables[i]) for i in order]

    primary_len, secondary_lens = ordered_lengths[0], ordered_lengths[1:]
    primary_iter, secondary_iters = ordered_iters[0], ordered_iters[1:]
    # One running error term per secondary iterable.
    errors = [primary_len // dims] * len(secondary_lens)

    remaining = sum(lengths)
    while remaining:
        yield next(primary_iter)
        remaining -= 1
        # Decrease each error term by that iterable's length.
        errors = [err - length for err, length in zip(errors, secondary_lens)]

        # Secondary iterables whose error went negative take a "diagonal
        # step" and emit an item now.
        for idx, err in enumerate(errors):
            if err < 0:
                yield next(secondary_iters[idx])
                remaining -= 1
                errors[idx] += primary_len
1303
1304
def collapse(iterable, base_type=None, levels=None):
    """Flatten an iterable with multiple levels of nesting (e.g., a list of
    lists of tuples) into non-iterable types.

    >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
    >>> list(collapse(iterable))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and
    will not be collapsed.

    To avoid collapsing other types, specify *base_type*:

    >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
    >>> list(collapse(iterable, base_type=tuple))
    ['ab', ('cd', 'ef'), 'gh', 'ij']

    Specify *levels* to stop flattening after a certain level:

    >>> iterable = [('a', ['b']), ('c', ['d'])]
    >>> list(collapse(iterable)) # Fully flattened
    ['a', 'b', 'c', 'd']
    >>> list(collapse(iterable, levels=1)) # Only one level flattened
    ['a', ['b'], 'c', ['d']]

    """
    # Iterative depth-first traversal: an explicit stack of
    # (nesting level, iterator of nodes) pairs avoids recursion limits.
    todo = deque([(0, repeat(iterable, 1))])

    while todo:
        group = todo.popleft()
        depth, nodes = group

        # Past the maximum depth, emit nodes without inspecting them.
        if levels is not None and depth > levels:
            yield from nodes
            continue

        for node in nodes:
            # Strings/bytes (and any *base_type*) are treated as atoms.
            if isinstance(node, (str, bytes)) or (
                (base_type is not None) and isinstance(node, base_type)
            ):
                yield node
                continue

            try:
                children = iter(node)
            except TypeError:
                # Not iterable: emit as-is.
                yield node
            else:
                # Suspend the current group, descend into the child node
                # first, then resume where we left off.
                todo.appendleft(group)
                todo.appendleft((depth + 1, children))
                break
1363
1364
def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Invoke *func* on each item in *iterable* (or on each *chunk_size* group
    of items) before yielding the item.

    `func` must be a function that takes a single argument. Its return value
    will be discarded.

    *before* and *after* are optional functions that take no arguments. They
    will be executed before iteration starts and after it ends, respectively.

    `side_effect` can be used for logging, updating progress bars, or anything
    that is not functionally "pure."

    Emitting a status message:

    >>> from more_itertools import consume
    >>> func = lambda item: print('Received {}'.format(item))
    >>> consume(side_effect(func, range(2)))
    Received 0
    Received 1

    Operating on chunks of items:

    >>> pair_sums = []
    >>> func = lambda chunk: pair_sums.append(sum(chunk))
    >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
    [0, 1, 2, 3, 4, 5]
    >>> list(pair_sums)
    [1, 5, 9]

    Writing to a file-like object:

    >>> from io import StringIO
    >>> from more_itertools import consume
    >>> f = StringIO()
    >>> func = lambda x: print(x, file=f)
    >>> before = lambda: print(u'HEADER', file=f)
    >>> after = f.close
    >>> it = [u'a', u'b', u'c']
    >>> consume(side_effect(func, it, before=before, after=after))
    >>> f.closed
    True

    """
    try:
        if before is not None:
            before()

        if chunk_size is not None:
            # Call *func* once per chunk, then emit the chunk's items.
            for group in chunked(iterable, chunk_size):
                func(group)
                yield from group
        else:
            # Call *func* once per item.
            for element in iterable:
                func(element)
                yield element
    finally:
        # *after* runs even if the consumer abandons the generator.
        if after is not None:
            after()
1424
1425
def sliced(seq, n, strict=False):
    """Yield slices of length *n* from the sequence *seq*.

    >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
    [(1, 2, 3), (4, 5, 6)]

    By the default, the last yielded slice will have fewer than *n* elements
    if the length of *seq* is not divisible by *n*:

    >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
    [(1, 2, 3), (4, 5, 6), (7, 8)]

    If the length of *seq* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    slice is yielded.

    This function will only work for iterables that support slicing.
    For non-sliceable iterables, see :func:`chunked`.

    """
    # Slicing past the end yields empty sequences, which stops takewhile.
    slices = takewhile(len, (seq[i : i + n] for i in count(0, n)))
    if not strict:
        return slices

    def checked():
        # Re-yield each slice, but fail fast on a short (final) one.
        for piece in slices:
            if len(piece) != n:
                raise ValueError("seq is not divisible by n.")
            yield piece

    return checked()
1458
1459
def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Yield lists of items from *iterable*, where each list is delimited by
    an item where callable *pred* returns ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b'))
    [['a'], ['c', 'd', 'c'], ['a']]

    >>> list(split_at(range(10), lambda n: n % 2 == 1))
    [[0], [2], [4], [6], [8], []]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
    [[0], [2], [4, 5, 6, 7, 8, 9]]

    By default, the delimiting items are not included in the output.
    To include them, set *keep_separator* to ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
    [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    group = []
    for element in iterator:
        if not pred(element):
            group.append(element)
            continue

        # Separator found: emit the accumulated group.
        yield group
        if keep_separator:
            yield [element]
        if maxsplit == 1:
            # Split budget exhausted: the rest is a single group.
            yield list(iterator)
            return
        group = []
        maxsplit -= 1

    yield group
1502
1503
def split_before(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends just before
    an item for which callable *pred* returns ``True``:

    >>> list(split_before('OneTwo', lambda s: s.isupper()))
    [['O', 'n', 'e'], ['T', 'w', 'o']]

    >>> list(split_before(range(10), lambda n: n % 3 == 0))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    group = []
    for element in iterator:
        # Only split on a match when there is something to emit; the
        # matching element always begins the next group.
        if not pred(element) or not group:
            group.append(element)
            continue

        yield group
        if maxsplit == 1:
            yield [element, *iterator]
            return
        maxsplit -= 1
        group = [element]

    if group:
        yield group
1537
1538
def split_after(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends with an
    item where callable *pred* returns ``True``:

    >>> list(split_after('one1two2', lambda s: s.isdigit()))
    [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

    >>> list(split_after(range(10), lambda n: n % 3 == 0))
    [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    group = []
    for element in iterator:
        # The matching element closes the current group.
        group.append(element)
        if not pred(element):
            continue

        yield group
        if maxsplit == 1:
            # Split budget exhausted: the remainder is a single group.
            remainder = list(iterator)
            if remainder:
                yield remainder
            return
        group = []
        maxsplit -= 1

    if group:
        yield group
1575
1576
def split_when(iterable, pred, maxsplit=-1):
    """Split *iterable* into pieces based on the output of *pred*.
    *pred* should be a function that takes successive pairs of items and
    returns ``True`` if the iterable should be split in between them.

    For example, to find runs of increasing numbers, split the iterable when
    element ``i`` is larger than element ``i + 1``:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
    [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
    then there is no limit on the number of splits:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
    ...                 lambda x, y: x > y, maxsplit=2))
    [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    try:
        previous = next(iterator)
    except StopIteration:
        # Empty input: no groups at all.
        return

    group = [previous]
    for current in iterator:
        # *pred* inspects each adjacent pair.
        if pred(previous, current):
            yield group
            if maxsplit == 1:
                yield [current, *iterator]
                return
            group = []
            maxsplit -= 1

        group.append(current)
        previous = current

    yield group
1620
1621
def split_into(iterable, sizes):
    """Yield a list of sequential items from *iterable* of length 'n' for each
    integer 'n' in *sizes*.

    >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
    [[1], [2, 3], [4, 5, 6]]

    If the sum of *sizes* is smaller than the length of *iterable*, then the
    remaining items of *iterable* will not be returned.

    >>> list(split_into([1,2,3,4,5,6], [2,3]))
    [[1, 2], [3, 4, 5]]

    If the sum of *sizes* is larger than the length of *iterable*, fewer items
    will be returned in the iteration that overruns the *iterable* and further
    lists will be empty:

    >>> list(split_into([1,2,3,4], [1,2,3,4]))
    [[1], [2, 3], [4], []]

    When a ``None`` object is encountered in *sizes*, the returned list will
    contain items up to the end of *iterable* the same way that
    :func:`itertools.islice` does:

    >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
    [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    :func:`split_into` can be useful for grouping a series of items where the
    sizes of the groups are not uniform. An example would be where in a row
    from a table, multiple columns represent elements of the same feature
    (e.g. a point represented by x,y,z) but, the format is not the same for
    all columns.
    """
    # A single shared iterator lets each islice pick up where the
    # previous group stopped (works for generators too).
    iterator = iter(iterable)

    for size in sizes:
        if size is None:
            # None means "everything that's left"; nothing can follow.
            yield list(iterator)
            return
        yield list(islice(iterator, size))
1665
1666
def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the elements from *iterable*, followed by *fillvalue*, such that
    at least *n* items are emitted.

    >>> list(padded([1, 2, 3], '?', 5))
    [1, 2, 3, '?', '?']

    If *next_multiple* is ``True``, *fillvalue* will be emitted until the
    number of items emitted is a multiple of *n*:

    >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
    [1, 2, 3, 4, None, None]

    If *n* is ``None``, *fillvalue* will be emitted indefinitely.

    To create an *iterable* of exactly size *n*, you can truncate with
    :func:`islice`.

    >>> list(islice(padded([1, 2, 3], '?'), 5))
    [1, 2, 3, '?', '?']
    >>> list(islice(padded([1, 2, 3, 4, 5, 6, 7, 8], '?'), 5))
    [1, 2, 3, 4, 5]

    """
    iterator = iter(iterable)
    # Real items followed by an endless run of the fill value.
    endless = chain(iterator, repeat(fillvalue))

    if n is None:
        return endless
    if n < 1:
        raise ValueError('n must be at least 1')

    if not next_multiple:
        # Pad the first n positions; any items beyond n follow untouched.
        return chain(islice(endless, n), iterator)

    def batches():
        # Each real item starts a batch that is then padded out to size n.
        for head in iterator:
            yield (head,)
            yield islice(endless, n - 1)

    return chain.from_iterable(batches())
1710
1711
def repeat_each(iterable, n=2):
    """Repeat each element in *iterable* *n* times.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    # Each element becomes a lazy run of n copies, flattened in order.
    return chain.from_iterable(repeat(element, n) for element in iterable)
1719
1720
def repeat_last(iterable, default=None):
    """After the *iterable* is exhausted, keep yielding its last element.

    >>> list(islice(repeat_last(range(3)), 5))
    [0, 1, 2, 2, 2]

    If the iterable is empty, yield *default* forever::

    >>> list(islice(repeat_last(range(0), 42), 5))
    [42, 42, 42, 42, 42]

    """
    # A unique sentinel distinguishes "empty input" from any real value.
    nothing = object()
    last = nothing
    for last in iterable:
        yield last
    yield from repeat(default if last is nothing else last)
1738
1739
def distribute(n, iterable):
    """Distribute the items from *iterable* among *n* smaller iterables.

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

    >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 4, 7], [2, 5], [3, 6]]

    If the length of *iterable* is smaller than *n*, then the last returned
    iterables will be empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    This function uses :func:`itertools.tee` and may require significant
    storage.

    If you need the order items in the smaller iterables to match the
    original iterable, see :func:`divide`.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Child k sees items k, k+n, k+2n, ... of its own tee branch.
    groups = []
    for offset, child in enumerate(tee(iterable, n)):
        groups.append(islice(child, offset, None, n))
    return groups
1775
1776
def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples whose elements are offset from *iterable*.
    The amount by which the `i`-th item in each tuple is offset is given by
    the `i`-th item in *offsets*.

    >>> list(stagger([0, 1, 2, 3]))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
    >>> list(stagger(range(8), offsets=(0, 2, 4)))
    [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default, the sequence will end when the final element of a tuple is the
    last item in the iterable. To continue until the first element of a tuple
    is the last item in the iterable, set *longest* to ``True``::

    >>> list(stagger([0, 1, 2, 3], longest=True))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    By default, ``None`` will be used to replace offsets beyond the end of the
    sequence. Specify *fillvalue* to use some other value.

    """
    # One independent copy of the input per requested offset; zip_offset
    # does the actual shifting and zipping.
    copies = tee(iterable, len(offsets))
    return zip_offset(
        *copies, offsets=offsets, longest=longest, fillvalue=fillvalue
    )
1803
1804
def zip_equal(*iterables):
    """``zip`` the input *iterables* together but raise
    ``UnequalIterablesError`` if they aren't all the same length.

    >>> it_1 = range(3)
    >>> it_2 = iter('abc')
    >>> list(zip_equal(it_1, it_2))
    [(0, 'a'), (1, 'b'), (2, 'c')]

    >>> it_1 = range(3)
    >>> it_2 = iter('abcd')
    >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    more_itertools.more.UnequalIterablesError: Iterables have different
    lengths

    """
    # On Python 3.10+ (hexversion of the first alpha that shipped
    # zip(strict=True)), steer users toward the builtin instead.
    if hexversion >= 0x30A00A6:
        deprecation_message = (
            'zip_equal will be removed in a future version of '
            'more-itertools. Use the builtin zip function with '
            'strict=True instead.'
        )
        warnings.warn(deprecation_message, DeprecationWarning)

    return _zip_equal(*iterables)
1834
1835
def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the input *iterables* together, but offset the `i`-th iterable
    by the `i`-th item in *offsets*.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    This can be used as a lightweight alternative to SciPy or pandas to analyze
    data sets in which some series have a lead or lag relationship.

    By default, the sequence will end when the shortest iterable is exhausted.
    To continue until the longest iterable is exhausted, set *longest* to
    ``True``.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    By default, ``None`` will be used to replace offsets beyond the end of the
    sequence. Specify *fillvalue* to use some other value.

    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    shifted = []
    for source, offset in zip(iterables, offsets):
        if offset < 0:
            # Negative offset: lead with |offset| copies of the fill value.
            shifted.append(chain(repeat(fillvalue, -offset), source))
        elif offset > 0:
            # Positive offset: skip the first *offset* items.
            shifted.append(islice(source, offset, None))
        else:
            shifted.append(source)

    if longest:
        return zip_longest(*shifted, fillvalue=fillvalue)
    return zip(*shifted)
1873
1874
def sort_together(
    iterables, key_list=(0,), key=None, reverse=False, strict=False
):
    """Return the input iterables sorted together, with *key_list* as the
    priority for sorting. All iterables are trimmed to the length of the
    shortest one.

    This can be used like the sorting function in a spreadsheet. If each
    iterable represents a column of data, the key list determines which
    columns are used for sorting.

    By default, all iterables are sorted using the ``0``-th iterable::

        >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
        >>> sort_together(iterables)
        [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]

    Set a different key list to sort according to another iterable.
    Specifying multiple keys dictates how ties are broken::

        >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
        >>> sort_together(iterables, key_list=(1, 2))
        [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]

    To sort by a function of the elements of the iterable, pass a *key*
    function. Its arguments are the elements of the iterables corresponding to
    the key list::

        >>> names = ('a', 'b', 'c')
        >>> lengths = (1, 2, 3)
        >>> widths = (5, 2, 1)
        >>> def area(length, width):
        ...     return length * width
        >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
        [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]

    Set *reverse* to ``True`` to sort in descending order.

        >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
        [(3, 2, 1), ('a', 'b', 'c')]

    If the *strict* keyword argument is ``True``, then
    ``UnequalIterablesError`` will be raised if any of the iterables have
    different lengths.

    """
    if key is None:
        # No key function: sort rows directly on the selected columns.
        sort_key = itemgetter(*key_list)
    else:
        key_list = list(key_list)
        if len(key_list) == 1:
            # A single column is passed to *key* as one positional argument.
            column = key_list[0]

            def sort_key(row):
                return key(row[column])

        else:
            # Multiple columns are unpacked into *key* as *args.
            get_columns = itemgetter(*key_list)

            def sort_key(row):
                return key(*get_columns(row))

    # Zip to rows, sort the rows, then unzip back to columns.
    zipper = zip_equal if strict else zip
    return list(
        zipper(*sorted(zipper(*iterables), key=sort_key, reverse=reverse))
    )
1946
1947
def unzip(iterable):
    """The inverse of :func:`zip`, this function disaggregates the elements
    of the zipped *iterable*.

    The ``i``-th iterable contains the ``i``-th element from each element
    of the zipped iterable. The first element is used to determine the
    length of the remaining elements.

    >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> letters, numbers = unzip(iterable)
    >>> list(letters)
    ['a', 'b', 'c', 'd']
    >>> list(numbers)
    [1, 2, 3, 4]

    This is similar to using ``zip(*iterable)``, but it avoids reading
    *iterable* into memory. Note, however, that this function uses
    :func:`itertools.tee` and thus may require significant storage.

    """
    # Peek at the first row to learn how many columns are needed, then
    # put it back in front of the stream.
    iterator = iter(iterable)
    try:
        first_row = next(iterator)
    except StopIteration:
        # Empty input, e.g. zip([], [], [])
        return ()
    iterator = chain((first_row,), iterator)

    copies = tee(iterator, len(first_row))

    def column_getter(index):
        def getter(row):
            try:
                return row[index]
            except IndexError:
                # For "improperly zipped" inputs like
                # iter([(1, 2, 3), (4, 5), (6,)]), a too-short row would
                # break the later columns. Raising StopIteration here
                # simply ends that column's iterator at the first
                # length mismatch instead.
                raise StopIteration

        return getter

    return tuple(map(column_getter(i), it) for i, it in enumerate(copies))
1995
1996
def divide(n, iterable):
    """Divide the elements from *iterable* into *n* parts, maintaining
    order.

    >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 2, 3]
    >>> list(group_2)
    [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

    >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than n, then the last returned
    iterables will be empty:

    >>> children = divide(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    This function will exhaust the iterable before returning.
    If order is not important, see :func:`distribute`, which does not first
    pull the iterable into memory.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Materialize the input unless it already supports slicing.
    try:
        iterable[:0]
    except TypeError:
        seq = tuple(iterable)
    else:
        seq = iterable

    # The first *extra* parts each get one more item than the rest.
    size, extra = divmod(len(seq), n)

    parts = []
    stop = 0
    for index in range(n):
        start = stop
        stop += size + 1 if index < extra else size
        parts.append(iter(seq[start:stop]))

    return parts
2046
2047
def always_iterable(obj, base_type=(str, bytes)):
    """If *obj* is iterable, return an iterator over its items::

        >>> obj = (1, 2, 3)
        >>> list(always_iterable(obj))
        [1, 2, 3]

    If *obj* is not iterable, return a one-item iterable containing *obj*::

        >>> obj = 1
        >>> list(always_iterable(obj))
        [1]

    If *obj* is ``None``, return an empty iterable:

        >>> obj = None
        >>> list(always_iterable(None))
        []

    By default, binary and text strings are not considered iterable::

        >>> obj = 'foo'
        >>> list(always_iterable(obj))
        ['foo']

    If *base_type* is set, objects for which ``isinstance(obj, base_type)``
    returns ``True`` won't be considered iterable.

        >>> obj = {'a': 1}
        >>> list(always_iterable(obj))  # Iterate over the dict's keys
        ['a']
        >>> list(always_iterable(obj, base_type=dict))  # Treat dicts as a unit
        [{'a': 1}]

    Set *base_type* to ``None`` to avoid any special handling and treat objects
    Python considers iterable as iterable:

        >>> obj = 'foo'
        >>> list(always_iterable(obj, base_type=None))
        ['f', 'o', 'o']
    """
    if obj is None:
        return iter(())

    # Instances of *base_type* are wrapped rather than iterated.
    treat_as_atom = (base_type is not None) and isinstance(obj, base_type)
    if not treat_as_atom:
        try:
            return iter(obj)
        except TypeError:
            pass

    return iter((obj,))
2099
2100
def adjacent(predicate, iterable, distance=1):
    """Return an iterable over `(bool, item)` tuples where the `item` is
    drawn from *iterable* and the `bool` indicates whether
    that item satisfies the *predicate* or is adjacent to an item that does.

    For example, to find whether items are adjacent to a ``3``::

        >>> list(adjacent(lambda x: x == 3, range(6)))
        [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    Set *distance* to change what counts as adjacent. For example, to find
    whether items are two places away from a ``3``:

        >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
        [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    This is useful for contextualizing the results of a search function.
    For example, a code comparison tool might want to identify lines that
    have changed, but also surrounding lines to give the viewer of the diff
    context.

    The predicate function will only be called once for each item in the
    iterable.

    See also :func:`groupby_transform`, which can be used with this function
    to group ranges of items with the same `bool` value.

    """
    # Allow distance=0 mainly for testing that it reproduces results with map()
    if distance < 0:
        raise ValueError('distance must be at least 0')

    to_test, originals = tee(iterable)
    # Pad the flag stream on both sides so every sliding window is full.
    pad = [False] * distance
    flags = chain(pad, map(predicate, to_test), pad)
    # An item counts if any flag within *distance* of it is set.
    near_match = map(any, windowed(flags, 2 * distance + 1))
    return zip(near_match, originals)
2138
2139
def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
    """An extension of :func:`itertools.groupby` that can apply transformations
    to the grouped data.

    * *keyfunc* is a function computing a key value for each item in *iterable*
    * *valuefunc* is a function that transforms the individual items from
      *iterable* after grouping
    * *reducefunc* is a function that transforms each group of items

    >>> iterable = 'aAAbBBcCC'
    >>> keyfunc = lambda k: k.upper()
    >>> valuefunc = lambda v: v.lower()
    >>> reducefunc = lambda g: ''.join(g)
    >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
    [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]

    Each optional argument defaults to an identity function if not specified.

    :func:`groupby_transform` is useful when grouping elements of an iterable
    using a separate iterable as the key. To do this, :func:`zip` the iterables
    and pass a *keyfunc* that extracts the first element and a *valuefunc*
    that extracts the second element::

        >>> from operator import itemgetter
        >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
        >>> values = 'abcdefghi'
        >>> iterable = zip(keys, values)
        >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
        >>> [(k, ''.join(g)) for k, g in grouper]
        [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]

    Note that the order of items in the iterable is significant.
    Only adjacent items are grouped together, so if you don't want any
    duplicate groups, you should sort the iterable by the key function.

    """
    # Build a lazy pipeline: group, then optionally map values, then
    # optionally reduce each group.
    groups = groupby(iterable, keyfunc)
    if valuefunc:
        groups = ((label, map(valuefunc, members)) for label, members in groups)
    if reducefunc:
        groups = ((label, reducefunc(members)) for label, members in groups)

    return groups
2183
2184
class numeric_range(abc.Sequence, abc.Hashable):
    """An extension of the built-in ``range()`` function whose arguments can
    be any orderable numeric type.

    With only *stop* specified, *start* defaults to ``0`` and *step*
    defaults to ``1``. The output items will match the type of *stop*:

    >>> list(numeric_range(3.5))
    [0.0, 1.0, 2.0, 3.0]

    With only *start* and *stop* specified, *step* defaults to ``1``. The
    output items will match the type of *start*:

    >>> from decimal import Decimal
    >>> start = Decimal('2.1')
    >>> stop = Decimal('5.1')
    >>> list(numeric_range(start, stop))
    [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    With *start*, *stop*, and *step* specified the output items will match
    the type of ``start + step``:

    >>> from fractions import Fraction
    >>> start = Fraction(1, 2)  # Start at 1/2
    >>> stop = Fraction(5, 2)  # End at 5/2
    >>> step = Fraction(1, 2)  # Count by 1/2
    >>> list(numeric_range(start, stop, step))
    [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    If *step* is zero, ``ValueError`` is raised. Negative steps are supported:

    >>> list(numeric_range(3, -1, -1.0))
    [3.0, 2.0, 1.0, 0.0]

    Be aware of the limitations of floating-point numbers; the representation
    of the yielded numbers may be surprising.

    ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
    is a ``datetime.timedelta`` object:

    >>> import datetime
    >>> start = datetime.datetime(2019, 1, 1)
    >>> stop = datetime.datetime(2019, 1, 3)
    >>> step = datetime.timedelta(days=1)
    >>> items = iter(numeric_range(start, stop, step))
    >>> next(items)
    datetime.datetime(2019, 1, 1, 0, 0)
    >>> next(items)
    datetime.datetime(2019, 1, 2, 0, 0)

    """

    # All empty ranges compare equal (see __eq__), so they share a single
    # hash value -- the same one the built-in empty range uses.
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        # Mirror range()'s calling conventions; defaults are derived from
        # the types of the supplied arguments so that, e.g., Decimal and
        # datetime arithmetic keep their own types.
        argc = len(args)
        if argc == 1:
            # Only *stop* given: start is the type's zero, step is one.
            (self._stop,) = args
            self._start = type(self._stop)(0)
            self._step = type(self._stop - self._start)(1)
        elif argc == 2:
            # *start* and *stop* given: step is one, in the difference type.
            self._start, self._stop = args
            self._step = type(self._stop - self._start)(1)
        elif argc == 3:
            self._start, self._stop, self._step = args
        elif argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        else:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        # The step type's zero is cached for the comparisons below and for
        # the modulo tests in __contains__ and index().
        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        # True when counting upward (positive step), False when downward.
        self._growing = self._step > self._zero

    def __bool__(self):
        # Non-empty iff at least one value lies between start and stop.
        if self._growing:
            return self._start < self._stop
        else:
            return self._start > self._stop

    def __contains__(self, elem):
        # A value is a member iff it lies within the half-open interval and
        # is a whole number of steps away from the start.
        if self._growing:
            if self._start <= elem < self._stop:
                return (elem - self._start) % self._step == self._zero
        else:
            if self._start >= elem > self._stop:
                return (self._start - elem) % (-self._step) == self._zero

        return False

    def __eq__(self, other):
        # Ranges are equal when they produce the same elements: all empty
        # ranges are equal, and non-empty ranges compare start, step, and
        # final element (so differing stops can still be equal).
        if isinstance(other, numeric_range):
            empty_self = not bool(self)
            empty_other = not bool(other)
            if empty_self or empty_other:
                return empty_self and empty_other  # True if both empty
            else:
                return (
                    self._start == other._start
                    and self._step == other._step
                    and self._get_by_index(-1) == other._get_by_index(-1)
                )
        else:
            return False

    def __getitem__(self, key):
        # Integer indexing delegates to _get_by_index; slicing builds a new
        # numeric_range with clamped bounds, like built-in range slicing.
        if isinstance(key, int):
            return self._get_by_index(key)
        elif isinstance(key, slice):
            step = self._step if key.step is None else key.step * self._step

            if key.start is None or key.start <= -self._len:
                start = self._start
            elif key.start >= self._len:
                start = self._stop
            else:  # -self._len < key.start < self._len
                start = self._get_by_index(key.start)

            if key.stop is None or key.stop >= self._len:
                stop = self._stop
            elif key.stop <= -self._len:
                stop = self._start
            else:  # -self._len < key.stop < self._len
                stop = self._get_by_index(key.stop)

            return numeric_range(start, stop, step)
        else:
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )

    def __hash__(self):
        # Hash on (start, last element, step) so ranges that compare equal
        # under __eq__ also hash equally.
        if self:
            return hash((self._start, self._get_by_index(-1), self._step))
        else:
            return self._EMPTY_HASH

    def __iter__(self):
        # Compute each value from the start to avoid accumulating
        # floating-point error, stopping once the bound is reached.
        values = (self._start + (n * self._step) for n in count())
        if self._growing:
            return takewhile(partial(gt, self._stop), values)
        else:
            return takewhile(partial(lt, self._stop), values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        # Normalize to a "growing" orientation, then count full steps plus
        # one extra for any partial step (ceiling division).
        if self._growing:
            start = self._start
            stop = self._stop
            step = self._step
        else:
            start = self._stop
            stop = self._start
            step = -self._step
        distance = stop - start
        if distance <= self._zero:
            return 0
        else:  # distance > 0 and step > 0: regular euclidean division
            q, r = divmod(distance, step)
            return int(q) + int(r != self._zero)

    def __reduce__(self):
        # Support copying and pickling by reconstructing from the arguments.
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        # Omit the step from the repr when it is the default of 1.
        if self._step == 1:
            return f"numeric_range({self._start!r}, {self._stop!r})"
        return (
            f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
        )

    def __reversed__(self):
        # A reversed range starts at the last element and steps backward;
        # the new stop is one step beyond the original start.
        return iter(
            numeric_range(
                self._get_by_index(-1), self._start - self._step, -self._step
            )
        )

    def count(self, value):
        # Each value appears at most once, so the count is 0 or 1.
        return int(value in self)

    def index(self, value):
        # Same membership math as __contains__, but the quotient of the
        # division is the element's index.
        if self._growing:
            if self._start <= value < self._stop:
                q, r = divmod(value - self._start, self._step)
                if r == self._zero:
                    return int(q)
        else:
            if self._start >= value > self._stop:
                q, r = divmod(self._start - value, -self._step)
                if r == self._zero:
                    return int(q)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        # Translate negative indexes, bounds-check, and compute directly
        # from the start (no iteration needed).
        if i < 0:
            i += self._len
        if i < 0 or i >= self._len:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step
2395
2396
def count_cycle(iterable, n=None):
    """Cycle through the items from *iterable* up to *n* times, yielding
    ``(cycle_number, item)`` pairs. If *n* is omitted the process repeats
    indefinitely.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    """
    # Materialize the input once so it can be replayed on every cycle.
    pool = tuple(iterable)
    if not pool:
        return iter(())
    rounds = count() if n is None else range(n)
    return ((cycle_number, item) for cycle_number in rounds for item in pool)
2411
2412
def mark_ends(iterable):
    """Yield 3-tuples of the form ``(is_first, is_last, item)``.

    >>> list(mark_ends('ABC'))
    [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    Use this when looping over an iterable to take special action on its first
    and/or last items:

    >>> iterable = ['Header', 100, 200, 'Footer']
    >>> total = 0
    >>> for is_first, is_last, item in mark_ends(iterable):
    ...     if is_first:
    ...         continue  # Skip the header
    ...     if is_last:
    ...         continue  # Skip the footer
    ...     total += item
    >>> print(total)
    300
    """
    iterator = iter(iterable)

    # An empty iterable produces nothing at all.
    try:
        current = next(iterator)
    except StopIteration:
        return

    # Hold one item back so we know whether *current* has a successor.
    is_first = True
    for upcoming in iterator:
        yield is_first, False, current
        current = upcoming
        is_first = False

    # The held-back item is the last one.
    yield is_first, True, current
2448
2449
def locate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
    [1, 2, 4]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item.

    >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
    [1, 3]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(locate(iterable, pred=pred, window_size=3))
    [1, 5, 9]

    Use with :func:`seekable` to find indexes and then retrieve the associated
    items:

    >>> from itertools import count
    >>> from more_itertools import seekable
    >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
    >>> it = seekable(source)
    >>> pred = lambda x: x > 100
    >>> indexes = locate(it, pred=pred)
    >>> i = next(indexes)
    >>> it.seek(i)
    >>> next(it)
    106

    """
    # Simple case: test items one at a time.
    if window_size is None:
        return (index for index, item in enumerate(iterable) if pred(item))

    if window_size < 1:
        raise ValueError('window size must be at least 1')

    # Windowed case: pred receives window_size items at once; short final
    # windows are padded with the sentinel.
    windows = windowed(iterable, window_size, fillvalue=_marker)
    return compress(count(), starmap(pred, windows))
2496
2497
def longest_common_prefix(iterables):
    """Yield elements of the longest common prefix among given *iterables*.

    >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
    'ab'

    """
    # Walk the iterables in lockstep; stop at the first position where they
    # disagree or where any of them runs out.
    for column in zip(*iterables):
        first = column[0]
        if not all(item == first for item in column[1:]):
            return
        yield first
2506
2507
def lstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the beginning
    for which *pred* returns ``True``.

    For example, to remove a set of items from the start of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(lstrip(iterable, pred))
    [1, 2, None, 3, False, None]

    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.

    """
    return dropwhile(pred, iterable)
2524
2525
def rstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the end
    for which *pred* returns ``True``.

    For example, to remove a set of items from the end of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(rstrip(iterable, pred))
    [None, False, None, 1, 2, None, 3]

    This function is analogous to :func:`str.rstrip`.

    """
    # Buffer matching items; they are emitted only if a non-matching item
    # follows.  Whatever remains buffered at the end is the stripped tail.
    pending = []
    for item in iterable:
        if pred(item):
            pending.append(item)
        else:
            if pending:
                yield from pending
                pending = []
            yield item
2550
2551
def strip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the
    beginning and end for which *pred* returns ``True``.

    For example, to remove a set of items from both ends of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(strip(iterable, pred))
    [1, 2, None, 3]

    This function is analogous to :func:`str.strip`.

    """
    # Trim the front first, then trim the back of what remains.
    without_leading = lstrip(iterable, pred)
    return rstrip(without_leading, pred)
2567
2568
class islice_extended:
    """An extension of :func:`itertools.islice` that supports negative values
    for *stop*, *start*, and *step*.

    >>> iterator = iter('abcdefgh')
    >>> list(islice_extended(iterator, -4, -1))
    ['e', 'f', 'g']

    Slices with negative values require some caching of *iterable*, but this
    function takes care to minimize the amount of memory required.

    For example, you can use a negative step with an infinite iterator:

    >>> from itertools import count
    >>> list(islice_extended(count(), 110, 99, -2))
    [110, 108, 106, 104, 102, 100]

    You can also use slice notation directly:

    >>> iterator = map(str, count())
    >>> it = islice_extended(iterator)[10:20:2]
    >>> list(it)
    ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        # With slice arguments, wrap the source in the slicing helper;
        # otherwise pass items through untouched.
        source = iter(iterable)
        self._iterator = _islice_helper(source, slice(*args)) if args else source

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        # Only slice keys are supported; integer indexing would require
        # consuming the iterator.
        if not isinstance(key, slice):
            raise TypeError('islice_extended.__getitem__ argument must be a slice')

        return islice_extended(_islice_helper(self._iterator, key))
2613
2614
def _islice_helper(it, s):
    """Yield the elements of iterator *it* selected by slice *s*.

    Supports negative values for the slice's start, stop, and step, caching
    only as many items as needed.  Used by :class:`islice_extended`.
    """
    start = s.start
    stop = s.stop
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    step = s.step or 1  # A step of None means 1

    if step > 0:
        # Forward iteration.
        start = 0 if (start is None) else start

        if start < 0:
            # Consume all but the last -start items
            cache = deque(enumerate(it, 1), maxlen=-start)
            len_iter = cache[-1][0] if cache else 0

            # Adjust start to be positive
            i = max(len_iter + start, 0)

            # Adjust stop to be positive
            if stop is None:
                j = len_iter
            elif stop >= 0:
                j = min(stop, len_iter)
            else:
                j = max(len_iter + stop, 0)

            # Slice the cache
            n = j - i
            if n <= 0:
                return

            for index in range(n):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()[1]
                else:
                    # just pop and discard the item
                    cache.popleft()
        elif (stop is not None) and (stop < 0):
            # Advance to the start position
            next(islice(it, start, start), None)

            # When stop is negative, we have to carry -stop items while
            # iterating
            cache = deque(islice(it, -stop), maxlen=-stop)

            for index, item in enumerate(it):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
                cache.append(item)
        else:
            # When both start and stop are positive we have the normal case
            yield from islice(it, start, stop, step)
    else:
        # Backward iteration (negative step); default start is the last item.
        start = -1 if (start is None) else start

        if (stop is not None) and (stop < 0):
            # Consume all but the last items
            n = -stop - 1
            cache = deque(enumerate(it, 1), maxlen=n)
            len_iter = cache[-1][0] if cache else 0

            # If start and stop are both negative they are comparable and
            # we can just slice. Otherwise we can adjust start to be negative
            # and then slice.
            if start < 0:
                i, j = start, stop
            else:
                i, j = min(start - len_iter, -1), None

            for index, item in list(cache)[i:j:step]:
                yield item
        else:
            # Advance to the stop position
            if stop is not None:
                m = stop + 1
                next(islice(it, m, m), None)

            # stop is positive, so if start is negative they are not comparable
            # and we need the rest of the items.
            if start < 0:
                i = start
                n = None
            # stop is None and start is positive, so we just need items up to
            # the start index.
            elif stop is None:
                i = None
                n = start + 1
            # Both stop and start are positive, so they are comparable.
            else:
                i = None
                n = start - stop
                if n <= 0:
                    return

            cache = list(islice(it, n))

            # Emit the cached items in reverse order.
            yield from cache[i::step]
2721
2722
def always_reversible(iterable):
    """An extension of :func:`reversed` that supports all iterables, not
    just those which implement the ``Reversible`` or ``Sequence`` protocols.

    >>> print(*always_reversible(x for x in range(3)))
    2 1 0

    If the iterable is already reversible, this function returns the
    result of :func:`reversed()`. If the iterable is not reversible,
    this function will cache the remaining items in the iterable and
    yield them in reverse order, which may require significant storage.
    """
    # Prefer the object's own reversal support; materialize the items only
    # when reversed() rejects it.
    try:
        return reversed(iterable)
    except TypeError:
        pass
    return reversed(list(iterable))
2739
2740
def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items using :func:`itertools.groupby`.
    The *ordering* function determines whether two items are adjacent by
    returning their position.

    By default, the ordering function is the identity function. This is
    suitable for finding runs of numbers:

    >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
    >>> for group in consecutive_groups(iterable):
    ...     print(list(group))
    [1]
    [10, 11, 12]
    [20]
    [30, 31, 32, 33]
    [40]

    To find runs of adjacent letters, apply :func:`ord` function
    to convert letters to ordinals.

    >>> iterable = 'abcdfgilmnop'
    >>> ordering = ord
    >>> for group in consecutive_groups(iterable, ordering):
    ...     print(list(group))
    ['a', 'b', 'c', 'd']
    ['f', 'g']
    ['i']
    ['l', 'm', 'n', 'o', 'p']

    Each group of consecutive items is an iterator that shares its source with
    *iterable*. When an output group is advanced, the previous group is
    no longer available unless its elements are copied (e.g., into a ``list``).

    >>> iterable = [1, 2, 11, 12, 21, 22]
    >>> saved_groups = []
    >>> for group in consecutive_groups(iterable):
    ...     saved_groups.append(list(group))  # Copy group elements
    >>> saved_groups
    [[1, 2], [11, 12], [21, 22]]

    """
    # Consecutive items have a constant difference between their index and
    # their position, so that difference works as a groupby key.
    if ordering is None:
        key = lambda x: x[0] - x[1]
    else:
        key = lambda x: x[0] - ordering(x[1])

    for k, g in groupby(enumerate(iterable), key=key):
        yield map(itemgetter(1), g)
2789
2790
def difference(iterable, func=sub, *, initial=None):
    """This function is the inverse of :func:`itertools.accumulate`. By default
    it will compute the first difference of *iterable* using
    :func:`operator.sub`:

    >>> from itertools import accumulate
    >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
    >>> list(difference(iterable))
    [0, 1, 2, 3, 4]

    *func* defaults to :func:`operator.sub`, but other functions can be
    specified. They will be applied as follows::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, to do progressive division:

    >>> iterable = [1, 2, 6, 24, 120]
    >>> func = lambda x, y: x // y
    >>> list(difference(iterable, func))
    [1, 2, 3, 4, 5]

    If the *initial* keyword is set, the first element will be skipped when
    computing successive differences.

    >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
    >>> list(difference(it, initial=10))
    [1, 2, 3]

    """
    # Two copies of the input, with one advanced a single position so that
    # func sees each item paired with its predecessor.
    later, earlier = tee(iterable)
    try:
        head = [next(later)]
    except StopIteration:
        return iter([])

    # With *initial* set, the first element is an accumulation seed and is
    # not echoed in the output.
    if initial is not None:
        head = []

    return chain(head, map(func, later, earlier))
2831
2832
class SequenceView(Sequence):
    """Return a read-only view of the sequence object *target*.

    :class:`SequenceView` objects are analogous to Python's built-in
    "dictionary view" types. They provide a dynamic view of a sequence's items,
    meaning that when the sequence updates, so does the view.

    >>> seq = ['0', '1', '2']
    >>> view = SequenceView(seq)
    >>> view
    SequenceView(['0', '1', '2'])
    >>> seq.append('3')
    >>> view
    SequenceView(['0', '1', '2', '3'])

    Sequence views support indexing, slicing, and length queries. They act
    like the underlying sequence, except they don't allow assignment:

    >>> view[1]
    '1'
    >>> view[1:-1]
    ['1', '2']
    >>> len(view)
    4

    Sequence views are useful as an alternative to copying, as they don't
    require (much) extra storage.

    """

    def __init__(self, target):
        # Only genuine sequences can be viewed; reject other iterables.
        if not isinstance(target, Sequence):
            raise TypeError
        self._target = target

    def __getitem__(self, index):
        # Delegate indexing (and slicing) to the underlying sequence.
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self._target)
2876
2877
class seekable:
    """Wrap an iterator to allow for seeking backward and forward. This
    progressively caches the items in the source iterable so they can be
    re-visited.

    Call :meth:`seek` with an index to seek to that position in the source
    iterable.

    To "reset" an iterator, seek to ``0``:

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.seek(0)
    >>> next(it), next(it), next(it)
    ('0', '1', '2')

    You can also seek forward:

    >>> it = seekable((str(n) for n in range(20)))
    >>> it.seek(10)
    >>> next(it)
    '10'
    >>> it.seek(20)  # Seeking past the end of the source isn't a problem
    >>> list(it)
    []
    >>> it.seek(0)  # Resetting works even after hitting the end
    >>> next(it)
    '0'

    Call :meth:`relative_seek` to seek relative to the source iterator's
    current position.

    >>> it = seekable((str(n) for n in range(20)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> it.relative_seek(2)
    >>> next(it)
    '5'
    >>> it.relative_seek(-3)  # Source is at '6', we move back to '3'
    >>> next(it)
    '3'
    >>> it.relative_seek(-3)  # Source is at '4', we move back to '1'
    >>> next(it)
    '1'


    Call :meth:`peek` to look ahead one item without advancing the iterator:

    >>> it = seekable('1234')
    >>> it.peek()
    '1'
    >>> list(it)
    ['1', '2', '3', '4']
    >>> it.peek(default='empty')
    'empty'

    Before the iterator is at its end, calling :func:`bool` on it will return
    ``True``. After it will return ``False``:

    >>> it = seekable('5678')
    >>> bool(it)
    True
    >>> list(it)
    ['5', '6', '7', '8']
    >>> bool(it)
    False

    You may view the contents of the cache with the :meth:`elements` method.
    That returns a :class:`SequenceView`, a view that updates automatically:

    >>> it = seekable((str(n) for n in range(10)))
    >>> next(it), next(it), next(it)
    ('0', '1', '2')
    >>> elements = it.elements()
    >>> elements
    SequenceView(['0', '1', '2'])
    >>> next(it)
    '3'
    >>> elements
    SequenceView(['0', '1', '2', '3'])

    By default, the cache grows as the source iterable progresses, so beware of
    wrapping very large or infinite iterables. Supply *maxlen* to limit the
    size of the cache (this of course limits how far back you can seek).

    >>> from itertools import count
    >>> it = seekable((str(n) for n in count()), maxlen=2)
    >>> next(it), next(it), next(it), next(it)
    ('0', '1', '2', '3')
    >>> list(it.elements())
    ['2', '3']
    >>> it.seek(0)
    >>> next(it), next(it), next(it), next(it)
    ('2', '3', '4', '5')
    >>> next(it)
    '6'

    """

    def __init__(self, iterable, maxlen=None):
        self._source = iter(iterable)
        # The cache holds items already pulled from the source.  With
        # *maxlen*, a bounded deque discards the oldest entries, limiting
        # how far back seeking can go.
        if maxlen is None:
            self._cache = []
        else:
            self._cache = deque([], maxlen)
        # _index is None while consuming new items from the source;
        # otherwise it is the cache position to serve next.
        self._index = None

    def __iter__(self):
        return self

    def __next__(self):
        if self._index is not None:
            # Replay from the cache until it is exhausted.
            try:
                item = self._cache[self._index]
            except IndexError:
                # Past the cached items; resume consuming the source below.
                self._index = None
            else:
                self._index += 1
                return item

        item = next(self._source)
        self._cache.append(item)
        return item

    def __bool__(self):
        # True iff at least one more item is available.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        # Pull the next item, then step the position back one so the same
        # item is served again by the next __next__ call.
        try:
            peeked = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return peeked

    def elements(self):
        # Read-only, auto-updating view of the cached items.
        return SequenceView(self._cache)

    def seek(self, index):
        self._index = index
        # Seeking beyond the cache pulls the difference from the source so
        # the cache catches up (stopping early if the source ends).
        remainder = index - len(self._cache)
        if remainder > 0:
            consume(self, remainder)

    def relative_seek(self, count):
        # Treat the current position as the end of the cache when we are
        # consuming directly from the source.
        if self._index is None:
            self._index = len(self._cache)

        # Clamp at zero so seeking far backward lands on the first item.
        self.seek(max(self._index + count, 0))
3037
3038
class run_length:
    """
    :func:`run_length.encode` compresses an iterable with run-length encoding.
    It yields groups of repeated items with the count of how many times they
    were repeated:

    >>> uncompressed = 'abbcccdddd'
    >>> list(run_length.encode(uncompressed))
    [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` decompresses an iterable that was previously
    compressed with run-length encoding. It yields the items of the
    decompressed iterable:

    >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> list(run_length.decode(compressed))
    ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        # Each groupby run becomes a (value, run length) pair.
        return (
            (value, sum(1 for _ in group)) for value, group in groupby(iterable)
        )

    @staticmethod
    def decode(iterable):
        # Expand each (value, count) pair back into repeated values.
        return chain.from_iterable(
            repeat(value, count) for value, count in iterable
        )
3066
3067
def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly ``n`` items in the iterable are ``True``
    according to the *predicate* function.

    >>> exactly_n([True, True, False], 2)
    True
    >>> exactly_n([True, True, False], 1)
    False
    >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
    True

    The iterable will be advanced until ``n + 1`` truthy items are encountered,
    so avoid calling it on infinite iterables.

    """
    # Stop counting at n + 1 matches: that is already enough to know the
    # answer is False, and it avoids consuming the whole iterable.
    matches = islice(filter(predicate, iterable), n + 1)
    return sum(1 for _ in matches) == n
3084
3085
def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

    >>> list(circular_shifts(range(4)))
    [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    Set *steps* to the number of places to rotate to the left
    (or to the right if negative). Defaults to 1.

    >>> list(circular_shifts(range(4), 2))
    [(0, 1, 2, 3), (2, 3, 0, 1)]

    >>> list(circular_shifts(range(4), -1))
    [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    """
    window = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Pre-rotate one step backward so the first yielded shift is the
    # original ordering.
    window.rotate(steps)
    shift = -steps

    # The sequence of shifts repeats after len // gcd(len, steps) rotations.
    length = len(window)
    num_shifts = length // math.gcd(length, shift)

    for _ in range(num_shifts):
        window.rotate(shift)
        yield tuple(window)
3114
3115
def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function that
    modifies an iterable. *result_index* is the position in that function's
    signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition. This can augment what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

    >>> from more_itertools import chunked
    >>> chunker = make_decorator(chunked, result_index=0)
    >>> @chunker(3)
    ... def iter_range(n):
    ...     return iter(range(n))
    ...
    >>> list(iter_range(9))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

    >>> truth_serum = make_decorator(filter, result_index=1)
    >>> @truth_serum(bool)
    ... def boolean_test():
    ...     return [0, 1, '', ' ', False, True]
    ...
    >>> list(boolean_test())
    [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators:

    >>> from more_itertools import peekable
    >>> peekable_function = make_decorator(peekable)
    >>> @peekable_function()
    ... def str_range(*args):
    ...     return (str(x) for x in range(*args))
    ...
    >>> it = str_range(1, 20, 2)
    >>> next(it), next(it), next(it)
    ('1', '3', '5')
    >>> it.peek()
    '7'
    >>> next(it)
    '7'

    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                # Call the decorated function, splice its result into the
                # wrapping function's arguments at result_index, and wrap.
                result = f(*args, **kwargs)
                call_args = [
                    *wrapping_args[:result_index],
                    result,
                    *wrapping_args[result_index:],
                ]
                return wrapping_func(*call_args, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator
3180
3181
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Return a dictionary that maps the items in *iterable* to categories
    defined by *keyfunc*, transforms them with *valuefunc*, and
    then summarizes them by category with *reducefunc*.

    *valuefunc* defaults to the identity function if it is unspecified.
    If *reducefunc* is unspecified, no summarization takes place:

    >>> keyfunc = lambda x: x.upper()
    >>> result = map_reduce('abbccc', keyfunc)
    >>> sorted(result.items())
    [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Specifying *valuefunc* transforms the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> result = map_reduce('abbccc', keyfunc, valuefunc)
    >>> sorted(result.items())
    [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Specifying *reducefunc* summarizes the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> reducefunc = sum
    >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
    >>> sorted(result.items())
    [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the map/reduce
    procedure:

    >>> all_items = range(30)
    >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
    >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
    >>> categories = map_reduce(items, keyfunc=keyfunc)
    >>> sorted(categories.items())
    [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
    >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
    >>> sorted(summaries.items())
    [(0, 90), (1, 75)]

    Note that all items in the iterable are gathered into a list before the
    summarization step, which may require significant storage.

    The returned object is a :obj:`collections.defaultdict` with the
    ``default_factory`` set to ``None``, such that it behaves like a normal
    dictionary.

    """
    groups = defaultdict(list)

    # Map phase: categorize each item, transforming it if valuefunc is given.
    # keyfunc is called before valuefunc to keep their evaluation order.
    for item in iterable:
        key = keyfunc(item)
        groups[key].append(item if valuefunc is None else valuefunc(item))

    # Reduce phase: collapse each category's list into a summary value.
    if reducefunc is not None:
        for key in groups:
            groups[key] = reducefunc(groups[key])

    # Disable the default factory so missing keys raise KeyError, like a
    # plain dictionary.
    groups.default_factory = None
    return groups
3253
3254
def rlocate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``, starting from the right and moving left.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
    [4, 2, 1]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item:

    >>> iterator = iter('abcb')
    >>> pred = lambda x: x == 'b'
    >>> list(rlocate(iterator, pred))
    [3, 1]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(rlocate(iterable, pred=pred, window_size=3))
    [9, 5, 1]

    Beware, this function won't return anything for infinite iterables.
    If *iterable* is reversible, ``rlocate`` will reverse it and search from
    the right. Otherwise, it will search from the left and return the results
    in reverse order.

    See :func:`locate` for other example applications.

    """
    if window_size is None:
        # A sized, reversible iterable can be searched from the right
        # directly; both len() and reversed() may raise TypeError here,
        # in which case we fall through to the buffering strategy below.
        with suppress(TypeError):
            length = len(iterable)
            return (
                length - 1 - i for i in locate(reversed(iterable), pred)
            )

    return reversed(list(locate(iterable, pred, window_size)))
3296
3297
def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items from *iterable*, replacing the items for which *pred*
    returns ``True`` with the items from the iterable *substitutes*.

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
    >>> pred = lambda x: x == 0
    >>> substitutes = (2, 3)
    >>> list(replace(iterable, pred, substitutes))
    [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    If *count* is given, the number of replacements will be limited:

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
    >>> pred = lambda x: x == 0
    >>> substitutes = [None]
    >>> list(replace(iterable, pred, substitutes, count=2))
    [1, 1, None, 1, 1, None, 1, 1, 0]

    Use *window_size* to control the number of items passed as arguments to
    *pred*. This allows for locating and replacing subsequences.

    >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
    >>> window_size = 3
    >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
    >>> substitutes = [3, 4]  # Splice in these items
    >>> list(replace(iterable, pred, substitutes, window_size=window_size))
    [3, 4, 5, 3, 4, 5]

    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # Save the substitutes iterable, since it's used more than once
    substitutes = tuple(substitutes)

    # Add padding such that the number of windows matches the length of the
    # iterable. The _marker padding is filtered out before yielding below.
    it = chain(iterable, repeat(_marker, window_size - 1))
    windows = windowed(it, window_size)

    # Number of replacements performed so far (compared against *count*)
    n = 0
    for w in windows:
        # If the current window matches our predicate (and we haven't hit
        # our maximum number of replacements), splice in the substitutes
        # and then consume the following windows that overlap with this one.
        # For example, if the iterable is (0, 1, 2, 3, 4...)
        # and the window size is 2, we have (0, 1), (1, 2), (2, 3)...
        # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2)
        if pred(*w):
            if (count is None) or (n < count):
                n += 1
                yield from substitutes
                # Skip the overlapping windows so replaced items aren't
                # also emitted individually
                consume(windows, window_size - 1)
                continue

        # If there was no match (or we've reached the replacement limit),
        # yield the first item from the window. Padding markers are never
        # emitted.
        if w and (w[0] is not _marker):
            yield w[0]
3357
3358
def partitions(iterable):
    """Yield all possible order-preserving partitions of *iterable*.

    >>> iterable = 'abc'
    >>> for part in partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['a', 'b', 'c']

    This is unrelated to :func:`partition`.

    """
    sequence = list(iterable)
    n = len(sequence)
    # Every subset of the interior positions 1..n-1 is a set of cut points;
    # slicing between consecutive cut points produces one partition.
    for cut_points in powerset(range(1, n)):
        bounds = (0,) + cut_points + (n,)
        yield [sequence[start:stop] for start, stop in zip(bounds, bounds[1:])]
3377
3378
def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """
    Yield the set partitions of *iterable* into *k* parts. Set partitions are
    not order-preserving.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, 2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']


    If *k* is not given, every set partition is generated.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    if *min_size* and/or *max_size* are given, the minimum and/or maximum size
    per block in partition is set.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, min_size=2):
    ...     print([''.join(p) for p in part])
    ['abc']
    >>> for part in set_partitions(iterable, max_size=2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    """
    pool = list(iterable)
    n = len(pool)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        elif k > n:
            return

    min_size = 0 if min_size is None else min_size
    max_size = n if max_size is None else max_size
    if min_size > max_size:
        return

    def helper(items, num_blocks):
        # Recursively partition *items* into exactly *num_blocks* blocks
        length = len(items)
        if num_blocks == 1:
            yield [items]
        elif length == num_blocks:
            yield [[x] for x in items]
        else:
            head, *rest = items
            # Either *head* forms a block on its own...
            for smaller in helper(rest, num_blocks - 1):
                yield [[head], *smaller]
            # ...or it joins one of the blocks of a partition of the rest
            for smaller in helper(rest, num_blocks):
                for i in range(len(smaller)):
                    yield smaller[:i] + [[head] + smaller[i]] + smaller[i + 1 :]

    def sizes_ok(partition):
        # Enforce the min_size/max_size constraints on every block
        return all(min_size <= len(block) <= max_size for block in partition)

    block_counts = range(1, n + 1) if k is None else (k,)
    for num_blocks in block_counts:
        yield from filter(sizes_ok, helper(pool, num_blocks))
3458
3459
class time_limited:
    """
    Yield items from *iterable* until *limit_seconds* have passed.
    If the time limit expires before all items have been yielded, the
    ``timed_out`` parameter will be set to ``True``.

    >>> from time import sleep
    >>> def generator():
    ...     yield 1
    ...     yield 2
    ...     sleep(0.2)
    ...     yield 3
    >>> iterable = time_limited(0.1, generator())
    >>> list(iterable)
    [1, 2]
    >>> iterable.timed_out
    True

    Note that the time is checked before each item is yielded, and iteration
    stops if the time elapsed is greater than *limit_seconds*. If your time
    limit is 1 second, but it takes 2 seconds to generate the first item from
    the iterable, the function will run for 2 seconds and not yield anything.
    As a special case, when *limit_seconds* is zero, the iterator never
    returns anything.

    """

    def __init__(self, limit_seconds, iterable):
        # Zero is valid (the iterator yields nothing); only negative values
        # are rejected, so the message says "non-negative", not "positive".
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be non-negative')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        # The clock starts at construction time, not at first iteration
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        # Special case documented above: a zero limit never yields
        if self.limit_seconds == 0:
            self.timed_out = True
            raise StopIteration
        # The next item is retrieved first, so a slow source can overrun
        # the limit; the elapsed-time check happens after retrieval.
        item = next(self._iterator)
        if monotonic() - self._start_time > self.limit_seconds:
            self.timed_out = True
            raise StopIteration

        return item
3508
3509
def only(iterable, default=None, too_long=None):
    """If *iterable* has only one item, return it.
    If it has zero items, return *default*.
    If it has more than one item, raise the exception given by *too_long*,
    which is ``ValueError`` by default.

    >>> only([], default='missing')
    'missing'
    >>> only([1])
    1
    >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 1, 2,
    and perhaps more.
    >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    TypeError

    Note that :func:`only` attempts to advance *iterable* twice to ensure there
    is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    iterator = iter(iterable)
    sentinel = object()

    first_item = next(iterator, sentinel)
    if first_item is sentinel:
        # Zero items
        return default

    second_item = next(iterator, sentinel)
    if second_item is not sentinel:
        # Two or more items
        msg = (
            f'Expected exactly one item in iterable, but got {first_item!r}, '
            f'{second_item!r}, and perhaps more.'
        )
        raise too_long or ValueError(msg)

    return first_item
3545
3546
3547def _ichunk(iterator, n):
3548 cache = deque()
3549 chunk = islice(iterator, n)
3550
3551 def generator():
3552 with suppress(StopIteration):
3553 while True:
3554 if cache:
3555 yield cache.popleft()
3556 else:
3557 yield next(chunk)
3558
3559 def materialize_next(n=1):
3560 # if n not specified materialize everything
3561 if n is None:
3562 cache.extend(chunk)
3563 return len(cache)
3564
3565 to_cache = n - len(cache)
3566
3567 # materialize up to n
3568 if to_cache > 0:
3569 cache.extend(islice(chunk, to_cache))
3570
3571 # return number materialized up to n
3572 return min(n, len(cache))
3573
3574 return (generator(), materialize_next)
3575
3576
def ichunked(iterable, n):
    """Break *iterable* into sub-iterables with *n* elements each.
    :func:`ichunked` is like :func:`chunked`, but it yields iterables
    instead of lists.

    If the sub-iterables are read in order, the elements of *iterable*
    won't be stored in memory.
    If they are read out of order, :func:`itertools.tee` is used to cache
    elements as necessary.

    >>> from itertools import count
    >>> all_chunks = ichunked(count(), 4)
    >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
    >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
    [4, 5, 6, 7]
    >>> list(c_1)
    [0, 1, 2, 3]
    >>> list(c_3)
    [8, 9, 10, 11]

    """
    iterator = iter(iterable)
    while True:
        chunk, materialize_next = _ichunk(iterator, n)

        # An empty first materialization means the source is exhausted
        if materialize_next() == 0:
            break

        yield chunk

        # Buffer whatever the consumer didn't read from this chunk so the
        # next chunk starts at the correct position
        materialize_next(None)
3611
3612
def iequals(*iterables):
    """Return ``True`` if all given *iterables* are equal to each other,
    which means that they contain the same elements in the same order.

    The function is useful for comparing iterables of different data types
    or iterables that do not support equality checks.

    >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
    True

    >>> iequals("abc", "acb")
    False

    Not to be confused with :func:`all_equal`, which checks whether all
    elements of iterable are equal to each other.

    """
    # A fresh object() is never equal to a real element, so a shorter
    # iterable can't spuriously match a longer one.
    fill = object()
    return all(
        all_equal(group) for group in zip_longest(*iterables, fillvalue=fill)
    )
3631
3632
def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

    >>> list(distinct_combinations([0, 0, 1], 2))
    [(0, 0), (0, 1)]

    Equivalent to ``set(combinations(iterable))``, except duplicates are not
    generated and thrown away. For larger input sequences this is much more
    efficient.

    """
    if r < 0:
        raise ValueError('r must be non-negative')
    elif r == 0:
        # The only size-0 combination is the empty tuple
        yield ()
        return
    pool = tuple(iterable)
    # Depth-first backtracking over a stack of generators, one per position
    # in the combination. Each generator yields (index, item) pairs from the
    # pool with duplicate items skipped via unique_everseen.
    generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
    current_combo = [None] * r
    level = 0
    while generators:
        try:
            cur_idx, p = next(generators[-1])
        except StopIteration:
            # No more candidates at this position; backtrack one level
            generators.pop()
            level -= 1
            continue
        current_combo[level] = p
        if level + 1 == r:
            # Combination complete; emit a copy
            yield tuple(current_combo)
        else:
            # Descend: candidates for the next position come strictly after
            # cur_idx, which keeps combinations in lexicographic pool order
            generators.append(
                unique_everseen(
                    enumerate(pool[cur_idx + 1 :], cur_idx + 1),
                    key=itemgetter(1),
                )
            )
            level += 1
3671
3672
def filter_except(validator, iterable, *exceptions):
    """Yield the items from *iterable* for which the *validator* function does
    not raise one of the specified *exceptions*.

    *validator* is called for each item in *iterable*.
    It should be a function that accepts one argument and raises an exception
    if that item is not valid.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(filter_except(int, iterable, ValueError, TypeError))
    ['1', '2', '4']

    If an exception other than one given by *exceptions* is raised by
    *validator*, it is raised like normal.
    """
    for item in iterable:
        try:
            validator(item)
        except exceptions:
            # Item failed validation with an expected exception; drop it
            continue
        yield item
3695
3696
def map_except(function, iterable, *exceptions):
    """Transform each item from *iterable* with *function* and yield the
    result, unless *function* raises one of the specified *exceptions*.

    *function* is called to transform each item in *iterable*.
    It should accept one argument.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(map_except(int, iterable, ValueError, TypeError))
    [1, 2, 4]

    If an exception other than one given by *exceptions* is raised by
    *function*, it is raised like normal.
    """
    for item in iterable:
        try:
            result = function(item)
        except exceptions:
            # Transformation failed with an expected exception; skip item
            continue
        yield result
3716
3717
def map_if(iterable, pred, func, func_else=None):
    """Evaluate each item from *iterable* using *pred*. If the result is
    equivalent to ``True``, transform the item with *func* and yield it.
    Otherwise, transform the item with *func_else* and yield it.

    *pred*, *func*, and *func_else* should each be functions that accept
    one argument. By default, *func_else* is the identity function.

    >>> from math import sqrt
    >>> iterable = list(range(-5, 5))
    >>> iterable
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
    >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
    >>> list(map_if(iterable, lambda x: x >= 0,
    ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
    [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
    """
    # The default "else" branch passes items through unchanged
    on_false = (lambda item: item) if func_else is None else func_else

    for item in iterable:
        if pred(item):
            yield func(item)
        else:
            yield on_false(item)
3744
3745
3746def _sample_unweighted(iterator, k, strict):
3747 # Algorithm L in the 1994 paper by Kim-Hung Li:
3748 # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
3749
3750 reservoir = list(islice(iterator, k))
3751 if strict and len(reservoir) < k:
3752 raise ValueError('Sample larger than population')
3753 W = 1.0
3754
3755 with suppress(StopIteration):
3756 while True:
3757 W *= random() ** (1 / k)
3758 skip = floor(log(random()) / log1p(-W))
3759 element = next(islice(iterator, skip, None))
3760 reservoir[randrange(k)] = element
3761
3762 shuffle(reservoir)
3763 return reservoir
3764
3765
def _sample_weighted(iterator, k, weights, strict):
    # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
    # "Weighted random sampling with a reservoir".

    # Log-transform for numerical stability for weights that are small/large.
    # Note log(random()) < 0, so larger weights give keys closer to zero.
    weight_keys = (log(random()) / weight for weight in weights)

    # Fill up the reservoir (collection of samples) with the first `k`
    # weight-keys and elements, then heapify the list.
    reservoir = take(k, zip(weight_keys, iterator))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Min-heap on the weight-key: reservoir[0] is always the entry most
    # eligible for replacement.
    heapify(reservoir)

    # The number of jumps before changing the reservoir is a random variable
    # with an exponential distribution. Sample it using random() and logs.
    smallest_weight_key, _ = reservoir[0]
    weights_to_skip = log(random()) / smallest_weight_key

    for weight, element in zip(weights, iterator):
        if weight >= weights_to_skip:
            # The notation here is consistent with the paper, but we store
            # the weight-keys in log-space for better numerical stability.
            smallest_weight_key, _ = reservoir[0]
            t_w = exp(weight * smallest_weight_key)
            r_2 = uniform(t_w, 1)  # generate U(t_w, 1)
            weight_key = log(r_2) / weight
            # Evict the smallest-keyed entry and insert the new one
            heapreplace(reservoir, (weight_key, element))
            smallest_weight_key, _ = reservoir[0]
            weights_to_skip = log(random()) / smallest_weight_key
        else:
            # Not selected; its weight still counts toward the jump budget
            weights_to_skip -= weight

    # Discard the weight-keys; return just the elements, in random order
    ret = [element for weight_key, element in reservoir]
    shuffle(ret)
    return ret
3803
3804
def _sample_counted(population, k, counts, strict):
    # Reservoir sampling over a population in which each element of
    # *population* is repeated *counts* times. The repeats are consumed
    # virtually via feed() rather than being expanded in memory. The
    # replacement-skip logic mirrors _sample_unweighted (Algorithm L).

    # Current element of the virtual expanded stream, and how many of its
    # repeats have not yet been consumed
    element = None
    remaining = 0

    def feed(i):
        # Advance *i* steps ahead and consume an element
        nonlocal element, remaining

        while i + 1 > remaining:
            # The skip spans past this element's remaining repeats;
            # move on to the next (element, count) pair
            i = i - remaining
            element = next(population)
            remaining = next(counts)
        remaining -= i + 1
        return element

    # Seed the reservoir with the first k virtual elements; a short
    # population simply leaves the reservoir short (handled below)
    with suppress(StopIteration):
        reservoir = []
        for _ in range(k):
            reservoir.append(feed(0))

    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= random() ** (1 / k)
            # Number of virtual elements to skip before the next replacement
            skip = floor(log(random()) / log1p(-W))
            element = feed(skip)
            reservoir[randrange(k)] = element

    # Return the sample in random order
    shuffle(reservoir)
    return reservoir
3838
3839
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a *k*-length list of elements chosen (without replacement)
    from the *iterable*. Similar to :func:`random.sample`, but works on
    iterables of unknown length.

    >>> iterable = range(100)
    >>> sample(iterable, 5)  # doctest: +SKIP
    [81, 60, 96, 16, 4]

    For iterables with repeated elements, you may supply *counts* to
    indicate the repeats.

    >>> iterable = ['a', 'b']
    >>> counts = [3, 4]  # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
    >>> sample(iterable, k=3, counts=counts)  # doctest: +SKIP
    ['a', 'a', 'b']

    An iterable with *weights* may be given:

    >>> iterable = range(100)
    >>> weights = (i * i + 1 for i in range(100))
    >>> sampled = sample(iterable, 5, weights=weights)  # doctest: +SKIP
    [79, 67, 74, 66, 78]

    Weighted selections are made without replacement.
    After an element is selected, it is removed from the pool and the
    relative weights of the other elements increase (this
    does not match the behavior of :func:`random.sample`'s *counts*
    parameter). Note that *weights* may not be used with *counts*.

    If the length of *iterable* is less than *k*,
    ``ValueError`` is raised if *strict* is ``True`` and
    all elements are returned (in shuffled order) if *strict* is ``False``.

    By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
    technique is used. When *weights* are provided,
    `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used.
    """
    iterator = iter(iterable)

    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []

    # Dispatch to the appropriate reservoir-sampling strategy
    if weights is not None:
        if counts is not None:
            raise TypeError('weights and counts are mutually exclusive')
        return _sample_weighted(iterator, k, iter(weights), strict)

    if counts is not None:
        return _sample_counted(iterator, k, iter(counts), strict)

    return _sample_unweighted(iterator, k, strict)
3899
3900
def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Returns ``True`` if the items of iterable are in sorted order, and
    ``False`` otherwise. *key* and *reverse* have the same meaning that they do
    in the built-in :func:`sorted` function.

    >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
    True
    >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
    False

    If *strict*, tests for strict sorting, that is, returns ``False`` if equal
    elements are found:

    >>> is_sorted([1, 2, 2])
    True
    >>> is_sorted([1, 2, 2], strict=True)
    False

    The function returns ``False`` after encountering the first out-of-order
    item, which means it may produce results that differ from the built-in
    :func:`sorted` function for objects with unusual comparison dynamics
    (like ``math.nan``). If there are no out-of-order items, the iterable is
    exhausted.
    """
    items = iterable if key is None else map(key, iterable)

    # Pair each item with its successor
    current, ahead = tee(items)
    next(ahead, None)
    if reverse:
        current, ahead = ahead, current

    if strict:
        # Every item must be strictly less than its successor
        return all(map(lt, current, ahead))
    # No successor may be strictly less than its predecessor
    return not any(map(lt, ahead, current))
3931
3932
class AbortThread(BaseException):
    # Raised inside callback_iter's callback to unwind the background thread
    # once the user exits the ``with`` block. It derives from BaseException
    # rather than Exception — presumably so that broad ``except Exception``
    # handlers in the wrapped function don't swallow it (TODO confirm).
    pass
3935
3936
class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

    >>> def func(callback=None):
    ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
    ...         if callback:
    ...             callback(i, c)
    ...     return 4


    Use ``with callback_iter(func)`` to get an iterator over the parameters
    that are delivered to the callback.

    >>> with callback_iter(func) as it:
    ...     for args, kwargs in it:
    ...         print(args)
    (1, 'a')
    (2, 'b')
    (3, 'c')

    The function will be called in a background thread. The ``done`` property
    indicates whether it has completed execution.

    >>> it.done
    True

    If it completes successfully, its return value will be available
    in the ``result`` property.

    >>> it.result
    4

    Notes:

    * If the function uses some keyword argument besides ``callback``, supply
      *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently the it is polled for
      output.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        self._func = func
        self._callback_kwd = callback_kwd
        self._aborted = False
        self._future = None
        self._wait_seconds = wait_seconds
        # Lazily import concurrent.futures so the module loads faster when
        # this class is never used
        self._executor = __import__(
            'concurrent.futures'
        ).futures.ThreadPoolExecutor(max_workers=1)
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Make the callback raise AbortThread on its next invocation, then
        # block until the background thread has finished
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    @property
    def done(self):
        # False until the background function has been submitted and has
        # finished (successfully or with an exception)
        if self._future is None:
            return False
        return self._future.done()

    @property
    def result(self):
        if not self.done:
            raise RuntimeError('Function has not yet completed')

        # Re-raises the background function's exception, if it raised one
        return self._future.result()

    def _reader(self):
        q = Queue()

        def callback(*args, **kwargs):
            # Runs in the background thread; raising here unwinds *func*
            if self._aborted:
                raise AbortThread('canceled by user')

            q.put((args, kwargs))

        self._future = self._executor.submit(
            self._func, **{self._callback_kwd: callback}
        )

        # Poll the queue, yielding callback invocations as they arrive,
        # until the background function finishes
        while True:
            try:
                item = q.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                q.task_done()
                yield item

            if self._future.done():
                break

        # Drain anything enqueued between the last get() and the future
        # completing, so no callback invocations are lost
        remaining = []
        while True:
            try:
                item = q.get_nowait()
            except Empty:
                break
            else:
                q.task_done()
                remaining.append(item)
        q.join()
        yield from remaining
4062
4063
def windowed_complete(iterable, n):
    """
    Yield ``(beginning, middle, end)`` tuples, where:

    * Each ``middle`` has *n* items from *iterable*
    * Each ``beginning`` has the items before the ones in ``middle``
    * Each ``end`` has the items after the ones in ``middle``

    >>> iterable = range(7)
    >>> n = 3
    >>> for beginning, middle, end in windowed_complete(iterable, n):
    ...     print(beginning, middle, end)
    () (0, 1, 2) (3, 4, 5, 6)
    (0,) (1, 2, 3) (4, 5, 6)
    (0, 1) (2, 3, 4) (5, 6)
    (0, 1, 2) (3, 4, 5) (6,)
    (0, 1, 2, 3) (4, 5, 6) ()

    Note that *n* must be at least 0 and at most equal to the length of
    *iterable*.

    This function will exhaust the iterable and may require significant
    storage.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    seq = tuple(iterable)
    size = len(seq)

    if n > size:
        raise ValueError('n must be <= len(seq)')

    # Slide the middle window across all valid positions
    for start in range(size - n + 1):
        stop = start + n
        yield seq[:start], seq[start:stop], seq[stop:]
4102
4103
def all_unique(iterable, key=None):
    """
    Returns ``True`` if all the elements of *iterable* are unique (no two
    elements are equal).

    >>> all_unique('ABCB')
    False

    If a *key* function is specified, it will be used to make comparisons.

    >>> all_unique('ABCb')
    True
    >>> all_unique('ABCb', str.lower)
    False

    The function returns as soon as the first non-unique element is
    encountered. Iterables with a mix of hashable and unhashable items can
    be used, but the function will be slower for unhashable items.
    """
    # Hashable elements go in a set (O(1) membership); unhashable ones fall
    # back to a list (O(n) membership).
    hashable_seen = set()
    unhashable_seen = []

    for element in map(key, iterable) if key else iterable:
        try:
            if element in hashable_seen:
                return False
            hashable_seen.add(element)
        except TypeError:
            if element in unhashable_seen:
                return False
            unhashable_seen.append(element)

    return True
4137
4138
def nth_product(index, *args):
    """Equivalent to ``list(product(*args))[index]``.

    The products of *args* can be ordered lexicographically.
    :func:`nth_product` computes the product at sort position *index* without
    computing the previous products.

    >>> nth_product(8, range(2), range(2), range(2), range(2))
    (1, 0, 0, 0)

    ``IndexError`` will be raised if the given *index* is invalid.
    """
    # Work from the least-significant (last) pool to the most-significant
    pools = [tuple(pool) for pool in reversed(args)]
    sizes = [len(pool) for pool in pools]

    # Total number of products; reduce (not math.prod) so that zero pools
    # raise TypeError as before
    total = reduce(mul, sizes)

    if index < 0:
        index += total

    if not 0 <= index < total:
        raise IndexError

    # Peel off one mixed-radix digit per pool
    selection = []
    for pool, size in zip(pools, sizes):
        index, digit = divmod(index, size)
        selection.append(pool[digit])

    return tuple(reversed(selection))
4168
4169
def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]```

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`nth_permutation`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences.

    >>> nth_permutation('ghijk', 2, 5)
    ('h', 'i')

    ``ValueError`` will be raised If *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = list(iterable)
    n = len(pool)

    if r is None or r == n:
        # Full-length permutations
        r, c = n, factorial(n)
    elif not 0 <= r < n:
        raise ValueError
    else:
        c = perm(n, r)
        assert c > 0  # factorial(n)>0, and r<n so perm(n,r) is never zero

    if index < 0:
        index += c

    if not 0 <= index < c:
        raise IndexError

    # Decompose the index in the factorial number system; each digit
    # selects which remaining pool item to take at that position
    result = [0] * r
    quotient = index * factorial(n) // c if r < n else index
    for radix in range(1, n + 1):
        quotient, digit = divmod(quotient, radix)
        if 0 <= n - radix < r:
            result[n - radix] = digit
        if quotient == 0:
            break

    # pool.pop(i) both selects an item and removes it from later choices
    return tuple(map(pool.pop, result))
4212
4213
def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.


    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`nth_combination_with_replacement`
    computes the subsequence at sort position *index* directly, without
    computing the previous subsequences with replacement.

    >>> nth_combination_with_replacement(range(5), 3, 5)
    (0, 1, 1)

    ``ValueError`` will be raised If *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if not 0 <= r <= n:
        raise ValueError

    c = comb(n + r - 1, r)

    if index < 0:
        index += c

    if not 0 <= index < c:
        raise IndexError

    # Choose one pool element per position, left to right. At each step,
    # count how many combinations start with the current candidate; skip
    # whole blocks until the index falls inside one.
    result = []
    pos = 0
    for size in range(r, 0, -1):
        remaining = size - 1
        while n >= 0:
            block = comb(n + remaining - 1, remaining)
            if index < block:
                break
            n -= 1
            pos += 1
            index -= block
        result.append(pool[pos])

    return tuple(result)
4258
4259
def value_chain(*args):
    """Yield all arguments passed to the function in the same order in which
    they were passed. If an argument itself is iterable then iterate over its
    values.

    >>> list(value_chain(1, 2, 3, [4, 5, 6]))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are not considered iterable and are emitted
    as-is:

    >>> list(value_chain('12', '34', ['56', '78']))
    ['12', '34', '56', '78']

    Pre- or postpend a single element to an iterable:

    >>> list(value_chain(1, [2, 3, 4, 5, 6]))
    [1, 2, 3, 4, 5, 6]
    >>> list(value_chain([1, 2, 3, 4, 5], 6))
    [1, 2, 3, 4, 5, 6]

    Multiple levels of nesting are not flattened.

    """
    for obj in args:
        # Strings are iterable but treated as atomic values here
        if isinstance(obj, (str, bytes)):
            yield obj
        else:
            try:
                yield from obj
            except TypeError:
                # Not iterable: emit the object itself
                yield obj
4292
4293
def product_index(element, *args):
    """Equivalent to ``list(product(*args)).index(element)``

    The products of *args* can be ordered lexicographically.
    :func:`product_index` computes the first index of *element* without
    computing the previous products.

    >>> product_index([8, 2], range(10), range(5))
    42

    ``ValueError`` will be raised if the given *element* isn't in the product
    of *args*.
    """
    # Treat the index as a mixed-radix number: each pool contributes one
    # digit, whose base is that pool's length.
    result = 0

    for item, pool in zip_longest(element, args, fillvalue=_marker):
        # A fill value on either side means the lengths don't match, so
        # *element* cannot be a member of the product.
        if item is _marker or pool is _marker:
            raise ValueError('element is not a product of args')

        pool = tuple(pool)
        result = result * len(pool) + pool.index(item)

    return result
4317
4318
def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`combination_index` computes the index of the
    first *element*, without computing the previous combinations.

    >>> combination_index('adf', 'abcdefg')
    10

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations of *iterable*.
    """
    # Pair each item of *element* with its position, so k tracks how many
    # items (k + 1) have been matched so far.
    element = enumerate(element)
    k, y = next(element, (None, None))
    if k is None:
        # An empty element is the first (index 0) combination.
        return 0

    # Scan the pool in order, recording the pool position of each item of
    # *element* as it is matched.
    indexes = []
    pool = enumerate(iterable)
    for n, x in pool:
        if x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                break
            else:
                k = tmp
    else:
        # The pool was exhausted before every item of *element* matched.
        raise ValueError('element is not a combination of iterable')

    # Exhaust the pool to find the position of its last item; n is needed
    # for the combinatorial arithmetic below.
    n, _ = last(pool, default=(n, None))

    # Count the combinations that sort after *element* and subtract them
    # from the total number of combinations, comb(n + 1, k + 1).
    index = 1
    for i, j in enumerate(reversed(indexes), start=1):
        j = n - j
        if i <= j:
            index += comb(j, i)

    return comb(n + 1, k + 1) - index
4360
4361
def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``

    The subsequences with repetition of *iterable* that are of length *r* can
    be ordered lexicographically. :func:`combination_with_replacement_index`
    computes the index of the first *element*, without computing the previous
    combinations with replacement.

    >>> combination_with_replacement_index('adf', 'abcdefg')
    20

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations with replacement of *iterable*.
    """
    element = tuple(element)
    # l is the length r of the element (the combination size).
    l = len(element)
    element = enumerate(element)

    k, y = next(element, (None, None))
    if k is None:
        # An empty element is the first (index 0) combination.
        return 0

    # Record the pool position of each matched item of *element*.
    indexes = []
    pool = tuple(iterable)
    for n, x in enumerate(pool):
        # With replacement, a single pool item may match several
        # consecutive items of *element* — hence the inner while loop.
        while x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                break
            else:
                k = tmp
        if y is None:
            break
    else:
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    n = len(pool)
    # occupations[p] counts how many times pool position p was used.
    occupations = [0] * n
    for p in indexes:
        occupations[p] += 1

    # Accumulate the index combinatorially from the occupation counts.
    index = 0
    cumulative_sum = 0
    for k in range(1, n):
        cumulative_sum += occupations[k - 1]
        j = l + n - 1 - k - cumulative_sum
        i = n - k
        if i <= j:
            index += comb(j, i)

    return index
4417
4418
def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)```

    The subsequences of *iterable* that are of length *r* where order is
    important can be ordered lexicographically. :func:`permutation_index`
    computes the index of the first *element* directly, without computing
    the previous permutations.

    >>> permutation_index([1, 3, 2], range(5))
    19

    ``ValueError`` will be raised if the given *element* isn't one of the
    permutations of *iterable*.
    """
    # Interpret *element* as a factorial-base number: each chosen item's
    # rank among the remaining pool items is one digit, and the radix
    # shrinks by one as items are removed from the pool.
    result = 0
    remaining = list(iterable)
    for radix, item in zip(range(len(remaining), -1, -1), element):
        rank = remaining.index(item)
        result = result * radix + rank
        del remaining[rank]

    return result
4441
4442
class countable:
    """Wrap *iterable* and keep a count of how many items have been consumed.

    The ``items_seen`` attribute starts at ``0`` and increments as the iterable
    is consumed:

    >>> iterable = map(str, range(10))
    >>> it = countable(iterable)
    >>> it.items_seen
    0
    >>> next(it), next(it)
    ('0', '1')
    >>> list(it)
    ['2', '3', '4', '5', '6', '7', '8', '9']
    >>> it.items_seen
    10
    """

    def __init__(self, iterable):
        # The underlying iterator and the running count of consumed items.
        self._iterator = iter(iterable)
        self.items_seen = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Propagate StopIteration from the source without bumping the count.
        value = next(self._iterator)
        self.items_seen += 1
        return value
4473
4474
def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.
    Items are distributed such that the lengths of the lists differ by at
    most 1 item.

    >>> iterable = [1, 2, 3, 4, 5, 6, 7]
    >>> n = 3
    >>> list(chunked_even(iterable, n)) # List lengths: 3, 2, 2
    [[1, 2, 3], [4, 5], [6, 7]]
    >>> list(chunked(iterable, n)) # List lengths: 3, 3, 1
    [[1, 2, 3], [4, 5, 6], [7]]

    """
    iterator = iter(iterable)

    # Initialize a buffer to process the chunks while keeping
    # some back to fill any underfilled chunks
    min_buffer = (n - 1) * (n - 2)
    buffer = list(islice(iterator, min_buffer))

    # Append items until we have a completed chunk.  Each step of the
    # islice corresponds to n appends via map(buffer.append, ...), so a
    # full chunk is ready at the front of the buffer on every iteration.
    for _ in islice(map(buffer.append, iterator), n, None, n):
        yield buffer[:n]
        del buffer[:n]

    # Check if any chunks need additional processing
    if not buffer:
        return
    length = len(buffer)

    # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
    q, r = divmod(length, n)
    num_lists = q + (1 if r > 0 else 0)
    q, r = divmod(length, num_lists)
    full_size = q + (1 if r > 0 else 0)
    partial_size = full_size - 1
    # num_full is the count of chunks that get the larger (full) size.
    num_full = length - partial_size * num_lists

    # Yield chunks of full size
    partial_start_idx = num_full * full_size
    if full_size > 0:
        for i in range(0, partial_start_idx, full_size):
            yield buffer[i : i + full_size]

    # Yield chunks of partial size
    if partial_size > 0:
        for i in range(partial_start_idx, length, partial_size):
            yield buffer[i : i + partial_size]
4523
4524
def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
    """A version of :func:`zip` that "broadcasts" any scalar
    (i.e., non-iterable) items into output tuples.

    >>> iterable_1 = [1, 2, 3]
    >>> iterable_2 = ['a', 'b', 'c']
    >>> scalar = '_'
    >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
    [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]

    The *scalar_types* keyword argument determines what types are considered
    scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to
    treat strings and byte strings as iterable:

    >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
    [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]

    If the *strict* keyword argument is ``True``, then
    ``UnequalIterablesError`` will be raised if any of the iterables have
    different lengths.
    """

    def _is_scalar(candidate):
        # Members of *scalar_types* count as scalars even though they
        # may be iterable.
        if scalar_types and isinstance(candidate, scalar_types):
            return True
        try:
            iter(candidate)
        except TypeError:
            return True
        return False

    total = len(objects)
    if not total:
        return

    # Scalars are written into the output template once; iterable slots
    # are refreshed on every round.
    template = [None] * total
    iterators = []
    positions = []
    for position, obj in enumerate(objects):
        if _is_scalar(obj):
            template[position] = obj
        else:
            iterators.append(iter(obj))
            positions.append(position)

    # Every argument was a scalar: emit a single tuple.
    if not iterators:
        yield tuple(objects)
        return

    zipper = _zip_equal if strict else zip
    for values in zipper(*iterators):
        for position, value in zip(positions, values):
            template[position] = value
        yield tuple(template)
4579
4580
def unique_in_window(iterable, n, key=None):
    """Yield the items from *iterable* that haven't been seen recently.
    *n* is the size of the lookback window.

    >>> iterable = [0, 1, 0, 2, 3, 0]
    >>> n = 3
    >>> list(unique_in_window(iterable, n))
    [0, 1, 2, 3, 0]

    The *key* function, if provided, will be used to determine uniqueness:

    >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
    ['a', 'b', 'c', 'd', 'a']

    The items in *iterable* must be hashable.

    """
    if n <= 0:
        raise ValueError('n must be greater than 0')

    # The deque holds the keys of the most recent *n* items; the mapping
    # counts how many times each key currently appears in the window.
    recent = deque(maxlen=n)
    tally = defaultdict(int)

    for item in iterable:
        signature = item if key is None else key(item)

        # Evict the oldest key before it silently falls off the deque.
        if len(recent) == n:
            oldest = recent[0]
            tally[oldest] -= 1
            if not tally[oldest]:
                del tally[oldest]

        # The item is novel if its key is absent from the window.
        if signature not in tally:
            yield item
        tally[signature] += 1
        recent.append(signature)
4618
4619
def duplicates_everseen(iterable, key=None):
    """Yield duplicate elements after their first appearance.

    >>> list(duplicates_everseen('mississippi'))
    ['s', 'i', 's', 's', 'i', 'p', 'i']
    >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []

    for element in iterable:
        k = element if key is None else key(element)
        try:
            if k in hashable_seen:
                yield element
            else:
                hashable_seen.add(k)
        except TypeError:
            # Unhashable keys fall back to a (slower) list membership test.
            if k in unhashable_seen:
                yield element
            else:
                unhashable_seen.append(k)
4648
4649
def duplicates_justseen(iterable, key=None):
    """Yields serially-duplicate elements after their first appearance.

    >>> list(duplicates_justseen('mississippi'))
    ['s', 's', 'p']
    >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    This function is analogous to :func:`unique_justseen`.

    """
    # For each run of equal keys, skip the first member and emit the rest.
    return chain.from_iterable(
        islice(group, 1, None) for _, group in groupby(iterable, key)
    )
4662
4663
def classify_unique(iterable, key=None):
    """Classify each element in terms of its uniqueness.

    For each element in the input iterable, return a 3-tuple consisting of:

    1. The element itself
    2. ``False`` if the element is equal to the one preceding it in the input,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
    3. ``False`` if this element has been seen anywhere in the input before,
       ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)

    >>> list(classify_unique('otto'))    # doctest: +NORMALIZE_WHITESPACE
    [('o', True,  True),
     ('t', True,  True),
     ('t', False, False),
     ('o', True,  False)]

    This function is analogous to :func:`unique_everseen` and is subject to
    the same performance considerations.

    """
    hashable_seen = set()
    unhashable_seen = []
    prior = None
    first = True

    for element in iterable:
        k = element if key is None else key(element)

        # Run boundary: the very first element, or a change of key.
        new_run = first or prior != k
        first = False
        prior = k

        never_seen = False
        try:
            if k not in hashable_seen:
                hashable_seen.add(k)
                never_seen = True
        except TypeError:
            # Unhashable keys fall back to a (slower) list membership test.
            if k not in unhashable_seen:
                unhashable_seen.append(k)
                never_seen = True

        yield element, new_run, never_seen
4704
4705
def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Returns both the smallest and largest items from an iterable
    or from two or more arguments.

    >>> minmax([3, 1, 5])
    (1, 5)

    >>> minmax(4, 2, 6)
    (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

    >>> minmax([5, 30], key=str)  # '30' sorts before '5'
    (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

    >>> minmax([], default=(0, 0))
    (0, 0)

    Otherwise ``ValueError`` is raised.

    This function makes a single pass over the input elements and takes care to
    minimize the number of comparisons made during processing.

    Note that unlike the builtin ``max`` function, which always returns the first
    item with the maximum value, this function may return another item when there are
    ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    iterable = (iterable_or_value, *others) if others else iterable_or_value

    it = iter(iterable)

    try:
        lo = hi = next(it)
    except StopIteration as exc:
        if default is _marker:
            raise ValueError(
                '`minmax()` argument is an empty iterable. '
                'Provide a `default` value to suppress this error.'
            ) from exc
        return default

    # Different branches depending on the presence of key. This saves a lot
    # of unimportant copies which would slow the "key=None" branch
    # significantly down.
    if key is None:
        # Process items two at a time: first order the pair internally,
        # then compare only the smaller against lo and the larger against
        # hi (three comparisons per pair instead of four).  If the input
        # has an odd count, the fill value pairs the leftover item with
        # the first item, which is already accounted for and therefore
        # cannot change the result.
        for x, y in zip_longest(it, it, fillvalue=lo):
            if y < x:
                x, y = y, x
            if x < lo:
                lo = x
            if hi < y:
                hi = y

    else:
        lo_key = hi_key = key(lo)

        for x, y in zip_longest(it, it, fillvalue=lo):
            # key() is computed once per item; candidates travel together
            # with their keys so nothing is recomputed.
            x_key, y_key = key(x), key(y)

            if y_key < x_key:
                x, y, x_key, y_key = y, x, y_key, x_key
            if x_key < lo_key:
                lo, lo_key = x, x_key
            if hi_key < y_key:
                hi, hi_key = y, y_key

    return lo, hi
4781
4782
def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* with a combined size limited by
    *max_size*.

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10))
    [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    If a *max_count* is supplied, the number of items per batch is also
    limited:

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10, max_count = 2))
    [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    If a *get_len* function is supplied, use that instead of :func:`len` to
    determine item size.

    If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
    than *max_size*. Otherwise, allow single items to exceed *max_size*.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    current = []
    current_size = 0

    for item in iterable:
        size = get_len(item)
        if strict and size > max_size:
            raise ValueError('item size exceeds maximum size')

        # Flush the pending batch when adding this item would overflow
        # either the size budget or the item-count budget.
        if current and (
            current_size + size > max_size or len(current) == max_count
        ):
            yield tuple(current)
            current = []
            current_size = 0

        current.append(item)
        current_size += size

    # Emit whatever is left over.
    if current:
        yield tuple(current)
4831
4832
def gray_product(*iterables):
    """Like :func:`itertools.product`, but return tuples in an order such
    that only one element in the generated tuple changes from one iteration
    to the next.

    >>> list(gray_product('AB','CD'))
    [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    This function consumes all of the input iterables before producing output.
    If any of the input iterables have fewer than two items, ``ValueError``
    is raised.

    For information on the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    all_iterables = tuple(tuple(x) for x in iterables)
    iterable_count = len(all_iterables)
    for iterable in all_iterables:
        if len(iterable) < 2:
            raise ValueError("each iterable must have two or more items")

    # This is based on "Algorithm H" from section 7.2.1.1, page 20.
    # a holds the indexes of the source iterables for the n-tuple to be yielded
    # f is the array of "focus pointers"
    # o is the array of "directions"
    a = [0] * iterable_count
    f = list(range(iterable_count + 1))
    o = [1] * iterable_count
    while True:
        yield tuple(all_iterables[i][a[i]] for i in range(iterable_count))
        j = f[0]
        f[0] = 0
        # Termination: the focus pointer has moved past the last iterable.
        if j == iterable_count:
            break
        # Step position j one place in its current direction.
        a[j] = a[j] + o[j]
        # On reaching either end of iterable j, reverse its direction and
        # update the focus pointers so a different position moves next.
        if a[j] == 0 or a[j] == len(all_iterables[j]) - 1:
            o[j] = -o[j]
            f[j] = f[j + 1]
            f[j + 1] = j + 1
4873
4874
def partial_product(*iterables):
    """Yields tuples containing one item from each iterator, with subsequent
    tuples changing a single item at a time by advancing each iterator until it
    is exhausted. This sequence guarantees every value in each iterable is
    output at least once without generating all possible combinations.

    This may be useful, for example, when testing an expensive function.

    >>> list(partial_product('AB', 'C', 'DEF'))
    [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]
    """
    iterators = [iter(obj) for obj in iterables]

    # Seed the working tuple with the first item of every iterable; if any
    # iterable is empty there is nothing to yield at all.
    try:
        current = [next(iterator) for iterator in iterators]
    except StopIteration:
        return
    yield tuple(current)

    # Exhaust each iterator in turn, changing one slot per output tuple.
    for position, iterator in enumerate(iterators):
        for value in iterator:
            current[position] = value
            yield tuple(current)
4898
4899
def takewhile_inclusive(predicate, iterable):
    """A variant of :func:`takewhile` that yields one additional element.

    >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
    [1, 4, 6]

    :func:`takewhile` would return ``[1, 4]``.
    """
    for item in iterable:
        # The failing item is yielded before iteration stops.
        yield item
        if not predicate(item):
            return
4912
4913
def outer_product(func, xs, ys, *args, **kwargs):
    """A generalized outer product that applies a binary function to all
    pairs of items. Returns a 2D matrix with ``len(xs)`` rows and ``len(ys)``
    columns.
    Also accepts ``*args`` and ``**kwargs`` that are passed to ``func``.

    Multiplication table:

    >>> list(outer_product(mul, range(1, 4), range(1, 6)))
    [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]

    Cross tabulation:

    >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
    >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
    >>> pair_counts = Counter(zip(xs, ys))
    >>> count_rows = lambda x, y: pair_counts[x, y]
    >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
    [(2, 3, 0), (1, 0, 4)]

    Usage with ``*args`` and ``**kwargs``:

    >>> animals = ['cat', 'wolf', 'mouse']
    >>> list(outer_product(min, animals, animals, key=len))
    [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
    """
    ys = tuple(ys)
    # Evaluate func over the cartesian product row by row; each output row
    # pairs one x with every y, so rows have len(ys) entries.
    results = (func(x, y, *args, **kwargs) for x, y in product(xs, ys))
    return batched(results, n=len(ys))
4945
4946
def iter_suppress(iterable, *exceptions):
    """Yield each of the items from *iterable*. If the iteration raises one of
    the specified *exceptions*, that exception will be suppressed and iteration
    will stop.

    >>> from itertools import chain
    >>> def breaks_at_five(x):
    ...     while True:
    ...         if x >= 5:
    ...             raise RuntimeError
    ...         yield x
    ...         x += 1
    >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
    >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
    >>> list(chain(it_1, it_2))
    [1, 2, 3, 4, 2, 3, 4]
    """
    # suppress() swallows any of *exceptions* raised by the delegation,
    # after which this generator simply finishes.
    with suppress(*exceptions):
        yield from iterable
4968
4969
def filter_map(func, iterable):
    """Apply *func* to every element of *iterable*, yielding only those which
    are not ``None``.

    >>> elems = ['1', 'a', '2', 'b', '3']
    >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
    [1, 2, 3]
    """
    for item in iterable:
        result = func(item)
        # Only None is filtered out; falsy results like 0 are kept.
        if result is not None:
            yield result
4982
4983
def powerset_of_sets(iterable):
    """Yields all possible subsets of the iterable.

    >>> list(powerset_of_sets([1, 2, 3]))  # doctest: +SKIP
    [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
    >>> list(powerset_of_sets([1, 1, 0]))  # doctest: +SKIP
    [set(), {1}, {0}, {0, 1}]

    :func:`powerset_of_sets` takes care to minimize the number
    of hash operations performed.
    """
    # Each distinct input item becomes a singleton frozenset (hashed once);
    # dict.fromkeys deduplicates while preserving first-seen order.
    distinct = tuple(
        dict.fromkeys(frozenset((item,)) for item in iterable)
    )
    # Union the singletons for every combination size, smallest first.
    return chain.from_iterable(
        (set().union(*combo) for combo in combinations(distinct, size))
        for size in range(len(distinct) + 1)
    )
5000
5001
def join_mappings(**field_to_map):
    """
    Joins multiple mappings together using their common keys.

    >>> user_scores = {'elliot': 50, 'claris': 60}
    >>> user_times = {'elliot': 30, 'claris': 40}
    >>> join_mappings(score=user_scores, time=user_times)
    {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    # Invert the nesting: outer keys become field names inside each
    # per-key dictionary.
    combined = {}
    for field_name, mapping in field_to_map.items():
        for key, value in mapping.items():
            combined.setdefault(key, {})[field_name] = value
    return combined
5018
5019
def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.
    """
    # (a+bi)(c+di) = (ac - bd) + (ad + bc)i.  Each component is computed as
    # one long real-valued sumprod so the error-free accumulation applies.
    real_lhs = chain((z.real for z in v1), (-z.imag for z in v1))
    real_rhs = chain((w.real for w in v2), (w.imag for w in v2))
    imag_lhs = chain((z.real for z in v1), (z.imag for z in v1))
    imag_rhs = chain((w.imag for w in v2), (w.real for w in v2))
    return complex(
        _fsumprod(real_lhs, real_rhs), _fsumprod(imag_lhs, imag_rhs)
    )
5032
5033
def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the corresponding transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]         # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]         # frequency domain
    >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
    >>> all(map(cmath.isclose, dft(xarr), Xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`idft` for the inverse Discrete Fourier Transform.
    """
    N = len(xarr)
    # Precompute the N complex roots of unity; every coefficient needed
    # below is one of them, selected by the index k * n modulo N.
    roots_of_unity = [e ** (n / N * tau * -1j) for n in range(N)]
    for k in range(N):
        # The k-th output is sum over n of x[n] * exp(-2j*pi*k*n/N),
        # accumulated with the high-precision complex sumprod helper.
        coeffs = [roots_of_unity[k * n % N] for n in range(N)]
        yield _complex_sumprod(xarr, coeffs)
5056
5057
def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the corresponding
    inverse-transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]         # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]         # frequency domain
    >>> all(map(cmath.isclose, idft(Xarr), xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`dft` for the Discrete Fourier Transform.
    """
    N = len(Xarr)
    # Same as dft() but with a positive exponent and a final 1/N scaling.
    roots_of_unity = [e ** (n / N * tau * 1j) for n in range(N)]
    for k in range(N):
        coeffs = [roots_of_unity[k * n % N] for n in range(N)]
        yield _complex_sumprod(Xarr, coeffs) / N
5080
5081
def doublestarmap(func, iterable):
    """Apply *func* to every item of *iterable* by dictionary unpacking
    the item into *func*.

    The difference between :func:`itertools.starmap` and :func:`doublestarmap`
    parallels the distinction between ``func(*a)`` and ``func(**a)``.

    >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
    >>> list(doublestarmap(lambda a, b: a + b, iterable))
    [3, 100]

    ``TypeError`` will be raised if *func*'s signature doesn't match the
    mapping contained in *iterable* or if *iterable* does not contain mappings.
    """
    # Each item must be a mapping; its keys become keyword arguments.
    for mapping in iterable:
        yield func(**mapping)
5098
5099
5100def _nth_prime_bounds(n):
5101 """Bounds for the nth prime (counting from 1): lb <= p_n <= ub."""
5102 # At and above 688,383, the lb/ub spread is under 0.003 * n.
5103
5104 if n < 1:
5105 raise ValueError
5106
5107 if n < 6:
5108 return (n, 2.25 * n)
5109
5110 # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities
5111 upper_bound = n * log(n * log(n))
5112 lower_bound = upper_bound - n
5113 if n >= 688_383:
5114 upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n))
5115
5116 return lower_bound, upper_bound
5117
5118
def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

    >>> nth_prime(0)
    2
    >>> nth_prime(100)
    547

    If *approximate* is set to True, will return a prime number close
    to the nth prime. The estimation is much faster than computing
    an exact result.

    >>> nth_prime(200_000_000, approximate=True)  # Exact result is 4222234763
    4217820427

    """
    # The bounds helper counts primes from 1, so ask about prime n + 1.
    lb, ub = _nth_prime_bounds(n + 1)

    # For small n (or whenever an exact answer was requested), sieve up to
    # the upper bound and select the nth prime directly.
    if not approximate or n <= 1_000_000:
        return nth(sieve(ceil(ub)), n)

    # Search from the midpoint and return the first odd prime
    odd = floor((lb + ub) / 2) | 1
    return first_true(count(odd, step=2), pred=is_prime)
5143
5144
def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

    >>> argmin('efghabcdijkl')
    4
    >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
    3

    For example, look up a label corresponding to the position
    of a value that minimizes a cost function::

        >>> def cost(x):
        ...     "Days for a wound to heal given a subject's age."
        ...     return x**2 - 20*x + 150
        ...
        >>> labels = ['homer', 'marge', 'bart', 'lisa', 'maggie']
        >>> ages = [ 35, 30, 10, 9, 1 ]

        # Fastest healing family member
        >>> labels[argmin(ages, key=cost)]
        'bart'

        # Age with fastest healing
        >>> min(ages, key=cost)
        10

    """
    # Attach indices, minimize on the value, and report the index.  Ties
    # go to the earliest position because min() keeps the first minimum.
    values = iterable if key is None else map(key, iterable)
    position, _ = min(enumerate(values), key=itemgetter(1))
    return position
5176
5177
def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

    >>> argmax('abcdefghabcd')
    7
    >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
    3

    For example, identify the best machine learning model::

        >>> models = ['svm', 'random forest', 'knn', 'naïve bayes']
        >>> accuracy = [ 68, 61, 84, 72 ]

        # Most accurate model
        >>> models[argmax(accuracy)]
        'knn'

        # Best accuracy
        >>> max(accuracy)
        84

    """
    # Attach indices, maximize on the value, and report the index.  Ties
    # go to the earliest position because max() keeps the first maximum.
    values = iterable if key is None else map(key, iterable)
    position, _ = max(enumerate(values), key=itemgetter(1))
    return position