Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/more_itertools/more.py: 19%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1636 statements  

1import math 

2import warnings 

3 

4from collections import Counter, defaultdict, deque, abc 

5from collections.abc import Sequence 

6from contextlib import suppress 

7from functools import cached_property, partial, reduce, wraps 

8from heapq import heapify, heapreplace 

9from itertools import ( 

10 chain, 

11 combinations, 

12 compress, 

13 count, 

14 cycle, 

15 dropwhile, 

16 groupby, 

17 islice, 

18 permutations, 

19 repeat, 

20 starmap, 

21 takewhile, 

22 tee, 

23 zip_longest, 

24 product, 

25) 

26from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau 

27from math import ceil 

28from queue import Empty, Queue 

29from random import random, randrange, shuffle, uniform 

30from operator import is_ as operator_is, attrgetter, itemgetter 

31from operator import neg, mul, sub, gt, lt 

32from sys import hexversion, maxsize 

33from time import monotonic 

34 

35from .recipes import ( 

36 _marker, 

37 _zip_equal, 

38 UnequalIterablesError, 

39 consume, 

40 first_true, 

41 flatten, 

42 is_prime, 

43 nth, 

44 powerset, 

45 sieve, 

46 take, 

47 unique_everseen, 

48 all_equal, 

49 batched, 

50) 

51 

52__all__ = [ 

53 'AbortThread', 

54 'SequenceView', 

55 'UnequalIterablesError', 

56 'adjacent', 

57 'all_unique', 

58 'always_iterable', 

59 'always_reversible', 

60 'argmax', 

61 'argmin', 

62 'bucket', 

63 'callback_iter', 

64 'chunked', 

65 'chunked_even', 

66 'circular_shifts', 

67 'collapse', 

68 'combination_index', 

69 'combination_with_replacement_index', 

70 'consecutive_groups', 

71 'constrained_batches', 

72 'consumer', 

73 'count_cycle', 

74 'countable', 

75 'derangements', 

76 'dft', 

77 'difference', 

78 'distinct_combinations', 

79 'distinct_permutations', 

80 'distribute', 

81 'divide', 

82 'doublestarmap', 

83 'duplicates_everseen', 

84 'duplicates_justseen', 

85 'classify_unique', 

86 'exactly_n', 

87 'filter_except', 

88 'filter_map', 

89 'first', 

90 'gray_product', 

91 'groupby_transform', 

92 'ichunked', 

93 'iequals', 

94 'idft', 

95 'ilen', 

96 'interleave', 

97 'interleave_evenly', 

98 'interleave_longest', 

99 'intersperse', 

100 'is_sorted', 

101 'islice_extended', 

102 'iterate', 

103 'iter_suppress', 

104 'join_mappings', 

105 'last', 

106 'locate', 

107 'longest_common_prefix', 

108 'lstrip', 

109 'make_decorator', 

110 'map_except', 

111 'map_if', 

112 'map_reduce', 

113 'mark_ends', 

114 'minmax', 

115 'nth_or_last', 

116 'nth_permutation', 

117 'nth_prime', 

118 'nth_product', 

119 'nth_combination_with_replacement', 

120 'numeric_range', 

121 'one', 

122 'only', 

123 'outer_product', 

124 'padded', 

125 'partial_product', 

126 'partitions', 

127 'peekable', 

128 'permutation_index', 

129 'powerset_of_sets', 

130 'product_index', 

131 'raise_', 

132 'repeat_each', 

133 'repeat_last', 

134 'replace', 

135 'rlocate', 

136 'rstrip', 

137 'run_length', 

138 'sample', 

139 'seekable', 

140 'set_partitions', 

141 'side_effect', 

142 'sliced', 

143 'sort_together', 

144 'split_after', 

145 'split_at', 

146 'split_before', 

147 'split_into', 

148 'split_when', 

149 'spy', 

150 'stagger', 

151 'strip', 

152 'strictly_n', 

153 'substrings', 

154 'substrings_indexes', 

155 'takewhile_inclusive', 

156 'time_limited', 

157 'unique_in_window', 

158 'unique_to_each', 

159 'unzip', 

160 'value_chain', 

161 'windowed', 

162 'windowed_complete', 

163 'with_iter', 

164 'zip_broadcast', 

165 'zip_equal', 

166 'zip_offset', 

167] 

168 

# math.sumprod is available for Python 3.12+
try:
    from math import sumprod as _fsumprod

except ImportError:  # pragma: no cover
    # Extended precision algorithms from T. J. Dekker,
    # "A Floating-Point Technique for Extending the Available Precision"
    # https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
    # Formulas: (5.5) (5.6) and (5.8). Code: mul12()

    def dl_split(x: float):
        "Split a float into two half-precision components."
        # Veltkamp splitting: multiplying by 2**27 + 1 and cancelling
        # rounds away the low bits, so hi holds the top half of the
        # mantissa exactly and lo holds the remainder (hi + lo == x).
        t = x * 134217729.0  # Veltkamp constant = 2.0 ** 27 + 1
        hi = t - (t - x)
        lo = x - hi
        return hi, lo

    def dl_mul(x, y):
        "Lossless multiplication."
        # Multiply the half-width components pairwise: z is the ordinary
        # rounded product, and zz recovers its rounding error exactly,
        # so z + zz == x * y with no precision lost.
        xx_hi, xx_lo = dl_split(x)
        yy_hi, yy_lo = dl_split(y)
        p = xx_hi * yy_hi
        q = xx_hi * yy_lo + xx_lo * yy_hi
        z = p + q
        zz = p - z + q + xx_lo * yy_lo
        return z, zz

    def _fsumprod(p, q):
        # Sum of products with extended precision: feed each product and
        # its error term into fsum() so rounding happens only once, at
        # the very end.  Fallback for math.sumprod on Python < 3.12.
        return fsum(chain.from_iterable(map(dl_mul, p, q)))

198 

199 

def chunked(iterable, n, strict=False):
    """Break *iterable* into lists of length *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
    [[1, 2, 3], [4, 5, 6]]

    By default, the last yielded list will have fewer than *n* elements
    if the length of *iterable* is not divisible by *n*:

    >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
    [[1, 2, 3], [4, 5, 6], [7, 8]]

    To use a fill-in value instead, see the :func:`grouper` recipe.

    If the length of *iterable* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    list is yielded.

    """
    source = iter(iterable)
    # Two-argument iter() calls the lambda until it returns the sentinel
    # (an empty list), i.e. until the source iterator is exhausted.
    chunks = iter(lambda: list(islice(source, n)), [])
    if not strict:
        return chunks

    if n is None:
        raise ValueError('n must not be None when using strict mode.')

    def checked():
        # Re-yield each chunk, verifying it is full-sized.
        for chunk in chunks:
            if len(chunk) != n:
                raise ValueError('iterable is not divisible by n.')
            yield chunk

    return checked()

233 

234 

def first(iterable, default=_marker):
    """Return the first item of *iterable*, or *default* if *iterable* is
    empty.

    >>> first([0, 1, 2, 3])
    0
    >>> first([], 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.

    :func:`first` is useful when you have a generator of expensive-to-retrieve
    values and want any arbitrary one. It is marginally shorter than
    ``next(iter(iterable), default)``.

    """
    # The loop body runs at most once: it returns the first item, if any.
    for value in iterable:
        return value
    # Empty iterable: fall back to the default, or complain.
    if default is not _marker:
        return default
    raise ValueError(
        'first() was called on an empty iterable, '
        'and no default value was provided.'
    )

260 

261 

def last(iterable, default=_marker):
    """Return the last item of *iterable*, or *default* if *iterable* is
    empty.

    >>> last([0, 1, 2, 3])
    3
    >>> last([], 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    try:
        if isinstance(iterable, Sequence):
            # Sequences support O(1) indexing from the end.
            return iterable[-1]
        if hasattr(iterable, '__reversed__'):
            # Work around https://bugs.python.org/issue38525
            return next(reversed(iterable))
        # Last resort: consume everything, keeping only the final item.
        return deque(iterable, maxlen=1)[-1]
    except (IndexError, TypeError, StopIteration):
        # IndexError/TypeError/StopIteration all signal "no last item".
        if default is not _marker:
            return default
        raise ValueError(
            'last() was called on an empty iterable, '
            'and no default value was provided.'
        )

288 

289 

def nth_or_last(iterable, n, default=_marker):
    """Return the nth or the last item of *iterable*,
    or *default* if *iterable* is empty.

    >>> nth_or_last([0, 1, 2, 3], 2)
    2
    >>> nth_or_last([0, 1], 2)
    1
    >>> nth_or_last([], 0, 'some default')
    'some default'

    If *default* is not provided and there are no items in the iterable,
    raise ``ValueError``.
    """
    # Keep at most the first n + 1 items; the last of those is the answer.
    prefix = islice(iterable, n + 1)
    return last(prefix, default=default)

305 

306 

class peekable:
    """Wrap an iterator to allow lookahead and prepending elements.

    Call :meth:`peek` on the result to get the value that will be returned
    by :func:`next`. This won't advance the iterator:

    >>> p = peekable(['a', 'b'])
    >>> p.peek()
    'a'
    >>> next(p)
    'a'

    Pass :meth:`peek` a default value to return that instead of raising
    ``StopIteration`` when the iterator is exhausted.

    >>> p = peekable([])
    >>> p.peek('hi')
    'hi'

    peekables also offer a :meth:`prepend` method, which "inserts" items
    at the head of the iterable:

    >>> p = peekable([1, 2, 3])
    >>> p.prepend(10, 11, 12)
    >>> next(p)
    10
    >>> p.peek()
    11
    >>> list(p)
    [11, 12, 1, 2, 3]

    peekables can be indexed. Index 0 is the item that will be returned by
    :func:`next`, index 1 is the item after that, and so on:
    The values up to the given index will be cached.

    >>> p = peekable(['a', 'b', 'c', 'd'])
    >>> p[0]
    'a'
    >>> p[1]
    'b'
    >>> next(p)
    'a'

    Negative indexes are supported, but be aware that they will cache the
    remaining items in the source iterator, which may require significant
    storage.

    To check whether a peekable is exhausted, check its truth value:

    >>> p = peekable(['a', 'b'])
    >>> if p:  # peekable has items
    ...     list(p)
    ['a', 'b']
    >>> if not p:  # peekable is exhausted
    ...     list(p)
    []

    """

    def __init__(self, iterable):
        # The underlying source; items move from here into _cache on demand.
        self._it = iter(iterable)
        # Items that were peeked, prepended, or cached for indexing but not
        # yet consumed.  Served (from the left) before the source iterator.
        self._cache = deque()

    def __iter__(self):
        return self

    def __bool__(self):
        # Truthy iff at least one more item is available; peek() caches the
        # probed item, so nothing is lost.
        try:
            self.peek()
        except StopIteration:
            return False
        return True

    def peek(self, default=_marker):
        """Return the item that will be next returned from ``next()``.

        Return ``default`` if there are no items left. If ``default`` is not
        provided, raise ``StopIteration``.

        """
        # Fill the cache with one item if it's empty; the head of the cache
        # is always the next item to be returned.
        if not self._cache:
            try:
                self._cache.append(next(self._it))
            except StopIteration:
                if default is _marker:
                    raise
                return default
        return self._cache[0]

    def prepend(self, *items):
        """Stack up items to be the next ones returned from ``next()`` or
        ``self.peek()``. The items will be returned in
        first in, first out order::

            >>> p = peekable([1, 2, 3])
            >>> p.prepend(10, 11, 12)
            >>> next(p)
            10
            >>> list(p)
            [11, 12, 1, 2, 3]

        It is possible, by prepending items, to "resurrect" a peekable that
        previously raised ``StopIteration``.

        >>> p = peekable([])
        >>> next(p)
        Traceback (most recent call last):
          ...
        StopIteration
        >>> p.prepend(1)
        >>> next(p)
        1
        >>> next(p)
        Traceback (most recent call last):
          ...
        StopIteration

        """
        # extendleft() inserts in reverse, so reversing first preserves
        # the first-in, first-out order promised above.
        self._cache.extendleft(reversed(items))

    def __next__(self):
        # Serve cached (peeked/prepended) items before the source iterator.
        if self._cache:
            return self._cache.popleft()

        return next(self._it)

    def _get_slice(self, index):
        # Normalize the slice's arguments, substituting the same defaults
        # that list slicing would use for the given step direction.
        step = 1 if (index.step is None) else index.step
        if step > 0:
            start = 0 if (index.start is None) else index.start
            stop = maxsize if (index.stop is None) else index.stop
        elif step < 0:
            start = -1 if (index.start is None) else index.start
            stop = (-maxsize - 1) if (index.stop is None) else index.stop
        else:
            raise ValueError('slice step cannot be zero')

        # If either the start or stop index is negative, we'll need to cache
        # the rest of the iterable in order to slice from the right side.
        if (start < 0) or (stop < 0):
            self._cache.extend(self._it)
        # Otherwise we'll need to find the rightmost index and cache to that
        # point.
        else:
            n = min(max(start, stop) + 1, maxsize)
            cache_len = len(self._cache)
            if n >= cache_len:
                self._cache.extend(islice(self._it, n - cache_len))

        # Delegate the actual slicing semantics to list.
        return list(self._cache)[index]

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self._get_slice(index)

        # A negative index counts from the end, which requires caching the
        # entire remaining source; a non-negative index only needs caching
        # up to that position.
        cache_len = len(self._cache)
        if index < 0:
            self._cache.extend(self._it)
        elif index >= cache_len:
            self._cache.extend(islice(self._it, index + 1 - cache_len))

        return self._cache[index]

470 

471 

def consumer(func):
    """Decorator that automatically advances a PEP-342-style "reverse iterator"
    to its first yield point so you don't have to call ``next()`` on it
    manually.

    >>> @consumer
    ... def tally():
    ...     i = 0
    ...     while True:
    ...         print('Thing number %s is %s.' % (i, (yield)))
    ...         i += 1
    ...
    >>> t = tally()
    >>> t.send('red')
    Thing number 0 is red.
    >>> t.send('fish')
    Thing number 1 is fish.

    Without the decorator, you would have to call ``next(t)`` before
    ``t.send()`` could be used.

    """

    @wraps(func)
    def primed(*args, **kwargs):
        coroutine = func(*args, **kwargs)
        # Advance to the first yield so .send() works immediately.
        next(coroutine)
        return coroutine

    return primed

502 

503 

def ilen(iterable):
    """Return the number of items in *iterable*.

    For example, there are 168 prime numbers below 1,000:

    >>> ilen(sieve(1000))
    168

    Equivalent to, but faster than::

        def ilen(iterable):
            count = 0
            for _ in iterable:
                count += 1
            return count

    This fully consumes the iterable, so handle with care.

    """
    # Pair each item with a steadily advancing count(), drain the pairs
    # through a zero-length deque (which discards them at C speed), and
    # then the counter's next value is exactly the number of items seen.
    counter = count()
    deque(zip(iterable, counter), maxlen=0)
    return next(counter)

527 

528 

def iterate(func, start):
    """Return ``start``, ``func(start)``, ``func(func(start))``, ...

    Produces an infinite iterator. To add a stopping condition,
    use :func:`take`, ``takewhile``, or :func:`takewhile_inclusive`:.

    >>> take(10, iterate(lambda x: 2*x, 1))
    [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    >>> collatz = lambda x: 3*x + 1 if x%2==1 else x // 2
    >>> list(takewhile_inclusive(lambda x: x!=1, iterate(collatz, 10)))
    [10, 5, 16, 8, 4, 2, 1]

    """
    # If func raises StopIteration, the generator simply ends.
    with suppress(StopIteration):
        value = start
        while True:
            yield value
            value = func(value)

548 

def with_iter(context_manager):
    """Wrap an iterable in a ``with`` statement, so it closes once exhausted.

    For example, this will close the file when the iterator is exhausted::

        upper_lines = (line.upper() for line in with_iter(open('foo')))

    Any context manager which returns an iterable is a candidate for
    ``with_iter``.

    """
    # The context manager stays open exactly as long as iteration lasts.
    with context_manager as iterable:
        for item in iterable:
            yield item

562 

563 

def one(iterable, too_short=None, too_long=None):
    """Return the first item from *iterable*, which is expected to contain only
    that item. Raise an exception if *iterable* is empty or has more than one
    item.

    :func:`one` is useful for ensuring that an iterable contains only one item.
    For example, it can be used to retrieve the result of a database query
    that is expected to return a single row.

    If *iterable* is empty, ``ValueError`` will be raised. You may specify a
    different exception with the *too_short* keyword:

    >>> it = []
    >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too few items in iterable (expected 1)'
    >>> too_short = IndexError('too few items')
    >>> one(it, too_short=too_short)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    IndexError: too few items

    Similarly, if *iterable* contains more than one item, ``ValueError`` will
    be raised. You may specify a different exception with the *too_long*
    keyword:

    >>> it = ['too', 'many']
    >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: Expected exactly one item in iterable, but got 'too',
    'many', and perhaps more.
    >>> too_long = RuntimeError
    >>> one(it, too_long=too_long)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    RuntimeError

    Note that :func:`one` attempts to advance *iterable* twice to ensure there
    is only one item. See :func:`spy` or :func:`peekable` to check iterable
    contents less destructively.

    """
    iterator = iter(iterable)
    sentinel = object()

    # Empty iterable: "too short".
    first_value = next(iterator, sentinel)
    if first_value is sentinel:
        raise too_short or ValueError('too few items in iterable (expected 1)')

    # A second item means the iterable is "too long".
    second_value = next(iterator, sentinel)
    if second_value is not sentinel:
        msg = (
            f'Expected exactly one item in iterable, but got {first_value!r}, '
            f'{second_value!r}, and perhaps more.'
        )
        raise too_long or ValueError(msg)

    return first_value

618 

619 

def raise_(exception, *args):
    """Construct *exception* with positional *args* and raise it.

    Exists so that raising can appear in expression position — e.g. inside
    a ``lambda``, as in :func:`strictly_n`'s default callbacks.
    """
    raise exception(*args)

622 

623 

def strictly_n(iterable, n, too_short=None, too_long=None):
    """Validate that *iterable* has exactly *n* items and return them if
    it does. If it has fewer than *n* items, call function *too_short*
    with those items. If it has more than *n* items, call function
    *too_long* with the first ``n + 1`` items.

    >>> iterable = ['a', 'b', 'c', 'd']
    >>> n = 4
    >>> list(strictly_n(iterable, n))
    ['a', 'b', 'c', 'd']

    Note that the returned iterable must be consumed in order for the check to
    be made.

    By default, *too_short* and *too_long* are functions that raise
    ``ValueError``.

    >>> list(strictly_n('ab', 3))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too few items in iterable (got 2)

    >>> list(strictly_n('abc', 2))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    ValueError: too many items in iterable (got at least 3)

    You can instead supply functions that do something else.
    *too_short* will be called with the number of items in *iterable*.
    *too_long* will be called with `n + 1`.

    >>> def too_short(item_count):
    ...     raise RuntimeError
    >>> it = strictly_n('abcd', 6, too_short=too_short)
    >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    RuntimeError

    >>> def too_long(item_count):
    ...     print('The boss is going to hear about this')
    >>> it = strictly_n('abcdef', 4, too_long=too_long)
    >>> list(it)
    The boss is going to hear about this
    ['a', 'b', 'c', 'd']

    """

    def _default_too_short(item_count):
        raise_(
            ValueError,
            f'Too few items in iterable (got {item_count})',
        )

    def _default_too_long(item_count):
        raise_(
            ValueError,
            f'Too many items in iterable (got at least {item_count})',
        )

    if too_short is None:
        too_short = _default_too_short
    if too_long is None:
        too_long = _default_too_long

    iterator = iter(iterable)

    # Yield up to n items, counting how many actually arrived.
    yielded = 0
    for item in islice(iterator, n):
        yield item
        yielded += 1

    if yielded < n:
        too_short(yielded)
        return

    # One more item than n means the iterable was too long.
    for item in iterator:
        too_long(n + 1)
        return

698 

def distinct_permutations(iterable, r=None):
    """Yield successive distinct permutations of the elements in *iterable*.

    >>> sorted(distinct_permutations([1, 0, 1]))
    [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

    Equivalent to yielding from ``set(permutations(iterable))``, except
    duplicates are not generated and thrown away. For larger input sequences
    this is much more efficient.

    Duplicate permutations arise when there are duplicated elements in the
    input iterable. The number of items returned is
    `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
    items input, and each `x_i` is the count of a distinct item in the input
    sequence. The function :func:`multinomial` computes this directly.

    If *r* is given, only the *r*-length permutations are yielded.

    >>> sorted(distinct_permutations([1, 0, 1], r=2))
    [(0, 1), (1, 0), (1, 1)]
    >>> sorted(distinct_permutations(range(3), r=2))
    [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

    *iterable* need not be sortable, but note that using equal (``x == y``)
    but non-identical (``id(x) != id(y)``) elements may produce surprising
    behavior. For example, ``1`` and ``True`` are equal but non-identical:

    >>> list(distinct_permutations([1, True, '3']))  # doctest: +SKIP
    [
        (1, True, '3'),
        (1, '3', True),
        ('3', 1, True)
    ]
    >>> list(distinct_permutations([1, 2, '3']))  # doctest: +SKIP
    [
        (1, 2, '3'),
        (1, '3', 2),
        (2, 1, '3'),
        (2, '3', 1),
        ('3', 1, 2),
        ('3', 2, 1)
    ]
    """

    # Algorithm: https://w.wiki/Qai
    # (Narayana Pandita's next-lexicographic-permutation: starting from the
    # sorted list, each step produces the next permutation in order, which
    # skips duplicates automatically.)
    def _full(A):
        while True:
            # Yield the permutation we have
            yield tuple(A)

            # Find the largest index i such that A[i] < A[i + 1]
            for i in range(size - 2, -1, -1):
                if A[i] < A[i + 1]:
                    break
            # If no such index exists, this permutation is the last one
            else:
                return

            # Find the largest index j greater than i such that A[i] < A[j]
            for j in range(size - 1, i, -1):
                if A[i] < A[j]:
                    break

            # Swap the value of A[i] with that of A[j], then reverse the
            # sequence from A[i + 1] to form the new permutation
            A[i], A[j] = A[j], A[i]
            A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]

    # Algorithm: modified from the above to track only the first r items
    # (head); the remaining items (tail) supply replacements.
    def _partial(A, r):
        # Split A into the first r items and the last r items
        head, tail = A[:r], A[r:]
        right_head_indexes = range(r - 1, -1, -1)
        left_tail_indexes = range(len(tail))

        while True:
            # Yield the permutation we have
            yield tuple(head)

            # Starting from the right, find the first index of the head with
            # value smaller than the maximum value of the tail - call it i.
            pivot = tail[-1]
            for i in right_head_indexes:
                if head[i] < pivot:
                    break
                pivot = head[i]
            # If no such index exists, this permutation is the last one
            else:
                return

            # Starting from the left, find the first value of the tail
            # with a value greater than head[i] and swap.
            for j in left_tail_indexes:
                if tail[j] > head[i]:
                    head[i], tail[j] = tail[j], head[i]
                    break
            # If we didn't find one, start from the right and find the first
            # index of the head with a value greater than head[i] and swap.
            else:
                for j in right_head_indexes:
                    if head[j] > head[i]:
                        head[i], head[j] = head[j], head[i]
                        break

            # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
            tail += head[: i - r : -1]  # head[i + 1:][::-1]
            i += 1
            head[i:], tail[:] = tail[: r - i], tail[r - i :]

    items = list(iterable)

    try:
        items.sort()
        sortable = True
    except TypeError:
        sortable = False

        # Unsortable items: permute surrogate integer indices instead, where
        # equal items share the same index so duplicates still collapse.
        indices_dict = defaultdict(list)

        for item in items:
            indices_dict[items.index(item)].append(item)

        indices = [items.index(item) for item in items]
        indices.sort()

        # Map each surrogate index back to a rotation of its equal items.
        equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}

        def permuted_items(permuted_indices):
            return tuple(
                next(equivalent_items[index]) for index in permuted_indices
            )

    size = len(items)
    if r is None:
        r = size

    # functools.partial(_partial, ... )
    algorithm = _full if (r == size) else partial(_partial, r=r)

    if 0 < r <= size:
        if sortable:
            return algorithm(items)
        else:
            return (
                permuted_items(permuted_indices)
                for permuted_indices in algorithm(indices)
            )

    # r == 0 yields the single empty permutation; r out of range yields none.
    return iter(() if r else ((),))

847 

848 

def derangements(iterable, r=None):
    """Yield successive derangements of the elements in *iterable*.

    A derangement is a permutation in which no element appears at its original
    index. In other words, a derangement is a permutation that has no fixed points.

    Suppose Alice, Bob, Carol, and Dave are playing Secret Santa.
    The code below outputs all of the different ways to assign gift recipients
    such that nobody is assigned to himself or herself:

    >>> for d in derangements(['Alice', 'Bob', 'Carol', 'Dave']):
    ...     print(', '.join(d))
    Bob, Alice, Dave, Carol
    Bob, Carol, Dave, Alice
    Bob, Dave, Alice, Carol
    Carol, Alice, Dave, Bob
    Carol, Dave, Alice, Bob
    Carol, Dave, Bob, Alice
    Dave, Alice, Bob, Carol
    Dave, Carol, Alice, Bob
    Dave, Carol, Bob, Alice

    If *r* is given, only the *r*-length derangements are yielded.

    >>> sorted(derangements(range(3), 2))
    [(1, 0), (1, 2), (2, 0)]
    >>> sorted(derangements([0, 2, 3], 2))
    [(2, 0), (2, 3), (3, 0)]

    Elements are treated as unique based on their position, not on their value.
    If the input elements are unique, there will be no repeated values within a
    permutation.

    The number of derangements of a set of size *n* is known as the
    "subfactorial of n". For n > 0, the subfactorial is:
    ``round(math.factorial(n) / math.e)``.
    """
    # Wrap each element in a 1-tuple so every position holds a distinct
    # object; identity (not equality) then detects fixed points, which is
    # what makes positions — rather than values — count as "original".
    pool = tuple(zip(iterable))
    for candidate in permutations(pool, r=r):
        if not any(map(operator_is, pool, candidate)):
            yield tuple(item for (item,) in candidate)

891 

892 

def intersperse(e, iterable, n=1):
    """Intersperse filler element *e* among the items in *iterable*, leaving
    *n* items between each filler element.

    >>> list(intersperse('!', [1, 2, 3, 4, 5]))
    [1, '!', 2, '!', 3, '!', 4, '!', 5]

    >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
    [1, 2, None, 3, 4, None, 5]

    """
    if n == 0:
        raise ValueError('n must be > 0')

    if n == 1:
        # Weave e before every item, then slice off the leading filler:
        # e, x_0, e, x_1, ... -> x_0, e, x_1, e, ...
        return islice(interleave(repeat(e), iterable), 1, None)

    # Group the items into n-sized chunks, weave single-item filler lists
    # between them, drop the leading filler, and flatten back out:
    # [x_0, x_1], [e], [x_2, x_3], ... -> x_0, x_1, e, x_2, x_3, ...
    chunks = chunked(iterable, n)
    fillers = repeat([e])
    return flatten(islice(interleave(fillers, chunks), 1, None))

917 

918 

def unique_to_each(*iterables):
    """Return the elements from each of the input iterables that aren't in the
    other input iterables.

    For example, suppose you have a set of packages, each with a set of
    dependencies::

        {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}

    If you remove one package, which dependencies can also be removed?

    If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
    associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for
    ``pkg_2``, and ``D`` is only needed for ``pkg_3``::

        >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
        [['A'], ['C'], ['D']]

    If there are duplicates in one input iterable that aren't in the others
    they will be duplicated in the output. Input order is preserved::

        >>> unique_to_each("mississippi", "missouri")
        [['p', 'p'], ['o', 'u', 'r']]

    It is assumed that the elements of each iterable are hashable.

    """
    materialized = [list(group) for group in iterables]
    # Count each distinct element once per input iterable it appears in.
    membership_counts = Counter(
        chain.from_iterable(set(group) for group in materialized)
    )
    # Elements that appear in exactly one input iterable.
    singletons = {
        element
        for element, tally in membership_counts.items()
        if tally == 1
    }
    # Filter each input down to its singletons, preserving order and
    # within-iterable duplicates.
    return [
        [element for element in group if element in singletons]
        for group in materialized
    ]

950 

951 

def windowed(seq, n, fillvalue=None, step=1):
    """Return a sliding window of width *n* over the given iterable.

    >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
    >>> list(all_windows)
    [(1, 2, 3), (2, 3, 4), (3, 4, 5)]

    When the window is larger than the iterable, *fillvalue* is used in place
    of missing values:

    >>> list(windowed([1, 2, 3], 4))
    [(1, 2, 3, None)]

    Each window will advance in increments of *step*:

    >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
    [(1, 2, 3), (3, 4, 5), (5, 6, '!')]

    To slide into the iterable's items, use :func:`chain` to add filler items
    to the left:

    >>> iterable = [1, 2, 3, 4]
    >>> n = 3
    >>> padding = [None] * (n - 1)
    >>> list(windowed(chain(padding, iterable), 3))
    [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
    """
    if n < 0:
        raise ValueError('n must be >= 0')
    if n == 0:
        yield ()
        return
    if step < 1:
        raise ValueError('step must be >= 1')

    source = iter(seq)

    # The first window is filled directly from the source.
    window = deque(islice(source, n), maxlen=n)
    if not window:
        return
    missing = n - len(window)
    if missing:
        # Short input: pad the single window on the right and stop.
        yield tuple(window) + (fillvalue,) * missing
        return
    yield tuple(window)

    # Pad the source with just enough fill values to complete the final
    # window: step - 1 when windows overlap, n - 1 when they don't.
    pad_count = step - 1 if step < n else n - 1
    # Each item drawn from `appender` pushes one element into the window
    # (deque's maxlen evicts from the left automatically).
    appender = map(window.append, chain(source, (fillvalue,) * pad_count))

    # Advance `step` appends between yields.
    for _ in islice(appender, step - 1, None, step):
        yield tuple(window)

1009 

def substrings(iterable):
    """Yield all of the substrings of *iterable*.

    >>> [''.join(s) for s in substrings('more')]
    ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']

    Note that non-string iterables can also be subdivided.

    >>> list(substrings([0, 1, 2]))
    [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]

    """
    # Yield the length-1 substrings while collecting the items, so the
    # input is consumed only once.
    collected = []
    for element in iterable:
        collected.append(element)
        yield (element,)
    collected = tuple(collected)
    total = len(collected)

    # Then every contiguous slice of each larger width.
    for width in range(2, total + 1):
        for start in range(total - width + 1):
            yield collected[start : start + width]

1035 

def substrings_indexes(seq, reverse=False):
    """Yield all substrings and their positions in *seq*

    The items yielded will be a tuple of the form ``(substr, i, j)``, where
    ``substr == seq[i:j]``.

    This function only works for iterables that support slicing, such as
    ``str`` objects.

    >>> for item in substrings_indexes('more'):
    ...    print(item)
    ('m', 0, 1)
    ('o', 1, 2)
    ('r', 2, 3)
    ('e', 3, 4)
    ('mo', 0, 2)
    ('or', 1, 3)
    ('re', 2, 4)
    ('mor', 0, 3)
    ('ore', 1, 4)
    ('more', 0, 4)

    Set *reverse* to ``True`` to yield the same items in the opposite order.


    """
    # Widths ascend by default, descend when reverse is requested.
    lengths = range(1, len(seq) + 1)
    if reverse:
        lengths = reversed(lengths)
    # For each width, slide a window of that width across seq.
    return (
        (seq[start : start + size], start, start + size)
        for size in lengths
        for start in range(len(seq) - size + 1)
    )

1068 

1069 

class bucket:
    """Wrap *iterable* and return an object that buckets the iterable into
    child iterables based on a *key* function.

    >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
    >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
    >>> sorted(list(s))  # Get the keys
    ['a', 'b', 'c']
    >>> a_iterable = s['a']
    >>> next(a_iterable)
    'a1'
    >>> list(s['b'])
    ['b1', 'b2', 'b3']

    The original iterable is advanced lazily; items are cached until their
    bucket consumes them, which may require significant storage.

    By default, selecting a bucket that no items belong to exhausts the
    iterable and caches everything. Supply a *validator* function to reject
    keys up front instead:

    >>> from itertools import count
    >>> it = count(1, 2)  # Infinite sequence of odd numbers
    >>> s = bucket(it, key=lambda x: x % 10, validator=lambda x: x % 2)
    >>> 2 in s
    False
    >>> list(s[2])
    []

    """

    def __init__(self, iterable, key, validator=None):
        self._it = iter(iterable)
        self._key = key
        self._cache = defaultdict(deque)
        # With no validator, every key is considered acceptable.
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        if not self._validator(value):
            return False

        try:
            item = next(self[value])
        except StopIteration:
            return False

        # Put the probed item back at the front so it isn't lost.
        self._cache[value].appendleft(item)
        return True

    def _get_values(self, value):
        """Yield items from the parent iterator that match *value*,
        caching any non-matching items for their own buckets.
        """
        while True:
            queue = self._cache[value]
            if queue:
                # Serve previously cached matches first.
                yield queue.popleft()
                continue

            # Advance the parent iterator until a match appears.
            for item in self._it:
                item_value = self._key(item)
                if item_value == value:
                    yield item
                    break
                elif self._validator(item_value):
                    self._cache[item_value].append(item)
            else:
                # Parent iterator exhausted with no match.
                return

    def __iter__(self):
        # Exhaust the source, filing every valid item into its bucket,
        # then iterate over the known keys.
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        return iter(self._cache)

    def __getitem__(self, value):
        if not self._validator(value):
            return iter(())

        return self._get_values(value)

1164 

1165 

def spy(iterable, n=1):
    """Return a 2-tuple: a list of the first *n* elements of *iterable*,
    and an iterator yielding the same items as *iterable*. This lets you
    "look ahead" without advancing the original iterable.

    >>> head, iterable = spy('abcdefg')
    >>> head
    ['a']
    >>> list(iterable)
    ['a', 'b', 'c', 'd', 'e', 'f', 'g']

    Unpacking retrieves individual items:

    >>> (first, second), iterable = spy('abcdefg', 2)
    >>> first, second
    ('a', 'b')

    If *n* exceeds the number of available items, the head simply contains
    everything:

    >>> head, iterable = spy([1, 2, 3], 10)
    >>> head
    [1, 2, 3]
    >>> list(iterable)
    [1, 2, 3]

    """
    # Fork the stream: consume up to n items from one copy for the
    # preview, hand back the untouched copy.
    preview, untouched = tee(iterable)
    return list(islice(preview, n)), untouched

1205 

1206 

def interleave(*iterables):
    """Yield items from each iterable in turn, stopping when the shortest
    is exhausted.

    >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7]

    For a version that continues until every iterable is exhausted, see
    :func:`interleave_longest`.

    """
    # zip stops at the shortest input, which gives the truncation
    # behavior for free.
    for group in zip(*iterables):
        yield from group

1219 

1220 

def interleave_longest(*iterables):
    """Yield items from each iterable in turn, skipping any that are
    exhausted.

    >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
    [1, 4, 6, 2, 5, 7, 3, 8]

    This produces the same output as :func:`roundrobin`, but may perform
    better for some inputs (in particular when the number of iterables is
    large).

    """
    # A unique sentinel marks the padding so exhausted inputs can be
    # filtered out of each column.
    sentinel = object()
    for column in zip_longest(*iterables, fillvalue=sentinel):
        for entry in column:
            if entry is not sentinel:
                yield entry

1237 

1238 

1239def interleave_evenly(iterables, lengths=None): 

1240 """ 

1241 Interleave multiple iterables so that their elements are evenly distributed 

1242 throughout the output sequence. 

1243 

1244 >>> iterables = [1, 2, 3, 4, 5], ['a', 'b'] 

1245 >>> list(interleave_evenly(iterables)) 

1246 [1, 2, 'a', 3, 4, 'b', 5] 

1247 

1248 >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]] 

1249 >>> list(interleave_evenly(iterables)) 

1250 [1, 6, 4, 2, 7, 3, 8, 5] 

1251 

1252 This function requires iterables of known length. Iterables without 

1253 ``__len__()`` can be used by manually specifying lengths with *lengths*: 

1254 

1255 >>> from itertools import combinations, repeat 

1256 >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']] 

1257 >>> lengths = [4 * (4 - 1) // 2, 3] 

1258 >>> list(interleave_evenly(iterables, lengths=lengths)) 

1259 [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c'] 

1260 

1261 Based on Bresenham's algorithm. 

1262 """ 

1263 if lengths is None: 

1264 try: 

1265 lengths = [len(it) for it in iterables] 

1266 except TypeError: 

1267 raise ValueError( 

1268 'Iterable lengths could not be determined automatically. ' 

1269 'Specify them with the lengths keyword.' 

1270 ) 

1271 elif len(iterables) != len(lengths): 

1272 raise ValueError('Mismatching number of iterables and lengths.') 

1273 

1274 dims = len(lengths) 

1275 

1276 # sort iterables by length, descending 

1277 lengths_permute = sorted( 

1278 range(dims), key=lambda i: lengths[i], reverse=True 

1279 ) 

1280 lengths_desc = [lengths[i] for i in lengths_permute] 

1281 iters_desc = [iter(iterables[i]) for i in lengths_permute] 

1282 

1283 # the longest iterable is the primary one (Bresenham: the longest 

1284 # distance along an axis) 

1285 delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:] 

1286 iter_primary, iters_secondary = iters_desc[0], iters_desc[1:] 

1287 errors = [delta_primary // dims] * len(deltas_secondary) 

1288 

1289 to_yield = sum(lengths) 

1290 while to_yield: 

1291 yield next(iter_primary) 

1292 to_yield -= 1 

1293 # update errors for each secondary iterable 

1294 errors = [e - delta for e, delta in zip(errors, deltas_secondary)] 

1295 

1296 # those iterables for which the error is negative are yielded 

1297 # ("diagonal step" in Bresenham) 

1298 for i, e_ in enumerate(errors): 

1299 if e_ < 0: 

1300 yield next(iters_secondary[i]) 

1301 to_yield -= 1 

1302 errors[i] += delta_primary 

1303 

1304 

def collapse(iterable, base_type=None, levels=None):
    """Flatten an iterable with multiple levels of nesting (e.g., a list
    of lists of tuples) into non-iterable types.

    >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
    >>> list(collapse(iterable))
    [1, 2, 3, 4, 5, 6]

    Binary and text strings are never collapsed. To protect other types,
    specify *base_type*:

    >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
    >>> list(collapse(iterable, base_type=tuple))
    ['ab', ('cd', 'ef'), 'gh', 'ij']

    Specify *levels* to stop flattening after a certain depth:

    >>> iterable = [('a', ['b']), ('c', ['d'])]
    >>> list(collapse(iterable, levels=1))
    ['a', ['b'], 'c', ['d']]

    """
    # Depth-first traversal using an explicit stack of
    # (depth, iterator-of-nodes) pairs; the whole input counts as one
    # node at depth 0.
    stack = deque([(0, repeat(iterable, 1))])

    while stack:
        node_group = stack.popleft()
        level, nodes = node_group

        # Past the depth limit: emit nodes unexamined.
        if levels is not None and level > levels:
            yield from nodes
            continue

        for node in nodes:
            # Strings (and user-protected types) are treated as atoms.
            if isinstance(node, (str, bytes)) or (
                (base_type is not None) and isinstance(node, base_type)
            ):
                yield node
                continue

            try:
                subtree = iter(node)
            except TypeError:
                # Not iterable: it is a leaf.
                yield node
            else:
                # Park the partially-consumed group and descend into the
                # child first (depth-first order).
                stack.appendleft(node_group)
                stack.appendleft((level + 1, subtree))
                break

1363 

1364 

def side_effect(func, iterable, chunk_size=None, before=None, after=None):
    """Invoke *func* on each item in *iterable* (or on each *chunk_size*
    group of items) before yielding the item.

    *func* must take a single argument; its return value is discarded.

    *before* and *after* are optional zero-argument callables run before
    iteration starts and after it ends (even on early exit), respectively.

    Useful for logging, progress bars, or anything else that is not
    functionally "pure":

    >>> func = lambda item: print('Received {}'.format(item))
    >>> list(side_effect(func, range(2)))
    Received 0
    Received 1
    [0, 1]

    Operating on chunks of items:

    >>> pair_sums = []
    >>> func = lambda chunk: pair_sums.append(sum(chunk))
    >>> list(side_effect(func, [0, 1, 2, 3], 2))
    [0, 1, 2, 3]
    >>> pair_sums
    [1, 5]

    """
    try:
        if before is not None:
            before()

        if chunk_size is None:
            # Per-item mode.
            for element in iterable:
                func(element)
                yield element
        else:
            # Per-chunk mode: func sees each chunk, items flow through.
            for group in chunked(iterable, chunk_size):
                func(group)
                yield from group
    finally:
        # Guarantee the teardown hook runs even if iteration stops early.
        if after is not None:
            after()

1424 

1425 

def sliced(seq, n, strict=False):
    """Yield slices of length *n* from the sequence *seq*.

    >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
    [(1, 2, 3), (4, 5, 6)]

    By default, the last yielded slice will have fewer than *n* elements
    if the length of *seq* is not divisible by *n*:

    >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
    [(1, 2, 3), (4, 5, 6), (7, 8)]

    If the length of *seq* is not divisible by *n* and *strict* is
    ``True``, then ``ValueError`` will be raised before the last
    slice is yielded.

    This function will only work for iterables that support slicing.
    For non-sliceable iterables, see :func:`chunked`.

    """
    # Slice at every multiple of n; an empty slice marks the end.
    iterator = takewhile(len, (seq[i : i + n] for i in count(0, n)))
    if strict:

        def ret():
            # Re-validate each slice so the error surfaces lazily, just
            # before the short final slice would have been yielded.
            for _slice in iterator:
                if len(_slice) != n:
                    raise ValueError("seq is not divisible by n.")
                yield _slice

        return ret()
    else:
        return iterator

1458 

1459 

def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
    """Yield lists of items from *iterable*, delimited by items for which
    callable *pred* returns ``True``.

    >>> list(split_at('abcdcba', lambda x: x == 'b'))
    [['a'], ['c', 'd', 'c'], ['a']]

    >>> list(split_at(range(10), lambda n: n % 2 == 1))
    [[0], [2], [4], [6], [8], []]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
    [[0], [2], [4, 5, 6, 7, 8, 9]]

    Delimiters are dropped unless *keep_separator* is ``True``, in which
    case each one is yielded as its own single-item list:

    >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
    [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    group = []
    iterator = iter(iterable)
    for item in iterator:
        if not pred(item):
            group.append(item)
            continue

        # Hit a separator: flush the current group.
        yield group
        if keep_separator:
            yield [item]
        if maxsplit == 1:
            # Split budget spent; the rest is a single group.
            yield list(iterator)
            return
        group = []
        maxsplit -= 1

    yield group

1502 

1503 

def split_before(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends just
    before an item for which callable *pred* returns ``True``:

    >>> list(split_before('OneTwo', lambda s: s.isupper()))
    [['O', 'n', 'e'], ['T', 'w', 'o']]

    >>> list(split_before(range(10), lambda n: n % 3 == 0))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    """
    if maxsplit == 0:
        yield list(iterable)
        return

    group = []
    iterator = iter(iterable)
    for item in iterator:
        # A matching item starts a new group, unless we're at the very
        # beginning (empty leading groups are not emitted).
        if pred(item) and group:
            yield group
            if maxsplit == 1:
                # Split budget spent: the match and everything after it
                # form the final group.
                yield [item, *iterator]
                return
            group = []
            maxsplit -= 1
        group.append(item)

    if group:
        yield group

1537 

1538 

def split_after(iterable, pred, maxsplit=-1):
    """Yield lists of items from *iterable*, where each list ends with an
    item for which callable *pred* returns ``True``:

    >>> list(split_after('one1two2', lambda s: s.isdigit()))
    [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]

    >>> list(split_after(range(10), lambda n: n % 3 == 0))
    [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
    [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    group = []
    iterator = iter(iterable)
    for item in iterator:
        group.append(item)
        # The group always contains at least `item` here, so a match
        # always closes a (non-empty) group.
        if pred(item):
            yield group
            if maxsplit == 1:
                # Split budget spent: emit the remainder (if any) as one
                # final group.
                remainder = list(iterator)
                if remainder:
                    yield remainder
                return
            group = []
            maxsplit -= 1

    if group:
        yield group

1575 

1576 

def split_when(iterable, pred, maxsplit=-1):
    """Split *iterable* into pieces based on the output of *pred*.
    *pred* takes successive pairs of items and returns ``True`` when the
    iterable should be split in between them.

    For example, to find runs of increasing numbers, split when element
    ``i`` is larger than element ``i + 1``:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
    [[1, 2, 3, 3], [2, 5], [2, 4], [2]]

    At most *maxsplit* splits are done; -1 (the default) means no limit:

    >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
    ...                 lambda x, y: x > y, maxsplit=2))
    [[1, 2, 3, 3], [2, 5], [2, 4, 2]]

    """
    if maxsplit == 0:
        yield list(iterable)
        return

    iterator = iter(iterable)
    try:
        previous = next(iterator)
    except StopIteration:
        # Empty input: nothing to yield.
        return

    group = [previous]
    for current in iterator:
        if pred(previous, current):
            yield group
            if maxsplit == 1:
                # Split budget spent: the rest is one final group.
                yield [current, *iterator]
                return
            group = []
            maxsplit -= 1

        group.append(current)
        previous = current

    yield group

1620 

1621 

def split_into(iterable, sizes):
    """Yield a list of sequential items from *iterable* of length 'n' for
    each integer 'n' in *sizes*.

    >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
    [[1], [2, 3], [4, 5, 6]]

    If the sum of *sizes* is smaller than the length of *iterable*, then
    the remaining items of *iterable* will not be returned.

    >>> list(split_into([1,2,3,4,5,6], [2,3]))
    [[1, 2], [3, 4, 5]]

    If the sum of *sizes* is larger than the length of *iterable*, fewer
    items will be returned in the iteration that overruns the *iterable*
    and further lists will be empty:

    >>> list(split_into([1,2,3,4], [1,2,3,4]))
    [[1], [2, 3], [4], []]

    When a ``None`` object is encountered in *sizes*, the returned list
    will contain items up to the end of *iterable* the same way that
    :func:`itertools.islice` does:

    >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
    [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]

    :func:`split_into` can be useful for grouping a series of items where
    the sizes of the groups are not uniform. An example would be where in
    a row from a table, multiple columns represent elements of the same
    feature (e.g. a point represented by x,y,z) but, the format is not
    the same for all columns.
    """
    # Convert the iterable argument into an iterator so its contents can
    # be consumed by islice in case it is a generator.
    it = iter(iterable)

    for size in sizes:
        if size is None:
            # A None size means "everything that's left".
            yield list(it)
            return
        else:
            yield list(islice(it, size))

1665 

1666 

def padded(iterable, fillvalue=None, n=None, next_multiple=False):
    """Yield the elements from *iterable*, followed by *fillvalue*, so
    that at least *n* items are emitted.

    >>> list(padded([1, 2, 3], '?', 5))
    [1, 2, 3, '?', '?']

    If *next_multiple* is ``True``, *fillvalue* is emitted until the
    number of items emitted is a multiple of *n*:

    >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
    [1, 2, 3, 4, None, None]

    If *n* is ``None``, *fillvalue* is emitted indefinitely. To produce
    exactly *n* items, truncate with :func:`islice`:

    >>> list(islice(padded([1, 2, 3], '?'), 5))
    [1, 2, 3, '?', '?']

    """
    iterator = iter(iterable)
    padded_iterator = chain(iterator, repeat(fillvalue))

    if n is None:
        return padded_iterator
    if n < 1:
        raise ValueError('n must be at least 1')

    if next_multiple:

        def groups():
            # Whenever a real element remains, emit it plus enough items
            # (real or fill) to round the output up to a multiple of n.
            for head in iterator:
                yield (head,)
                yield islice(padded_iterator, n - 1)

        return chain.from_iterable(groups())

    # Guarantee at least n items, then pass any remainder straight through.
    return chain(islice(padded_iterator, n), iterator)

1710 

1711 

def repeat_each(iterable, n=2):
    """Repeat each element in *iterable* *n* times.

    >>> list(repeat_each('ABC', 3))
    ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
    """
    # Lazily expand each element into n copies.
    return chain.from_iterable(repeat(element, n) for element in iterable)

1719 

1720 

def repeat_last(iterable, default=None):
    """After the *iterable* is exhausted, keep yielding its last element.

    >>> list(islice(repeat_last(range(3)), 5))
    [0, 1, 2, 2, 2]

    If the iterable is empty, yield *default* forever::

    >>> list(islice(repeat_last(range(0), 42), 5))
    [42, 42, 42, 42, 42]

    """
    # The sentinel survives only if the loop body never runs, i.e. the
    # iterable was empty.
    last = sentinel = object()
    for last in iterable:
        yield last
    yield from repeat(default if last is sentinel else last)

1738 

1739 

def distribute(n, iterable):
    """Distribute the items from *iterable* among *n* smaller iterables,
    dealing them out round-robin style.

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]

    If the length of *iterable* is not evenly divisible by *n*, the
    returned iterables will have differing lengths:

    >>> [list(c) for c in distribute(3, [1, 2, 3, 4, 5, 6, 7])]
    [[1, 4, 7], [2, 5], [3, 6]]

    If the length of *iterable* is smaller than *n*, the trailing
    iterables will be empty:

    >>> [list(c) for c in distribute(5, [1, 2, 3])]
    [[1], [2], [3], [], []]

    This function uses :func:`itertools.tee` and may require significant
    storage. If the order of items must be preserved within each group,
    see :func:`divide`.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Each child gets every n-th item, starting at its own offset.
    copies = tee(iterable, n)
    return [islice(copy, offset, None, n) for offset, copy in enumerate(copies)]

1775 

1776 

def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
    """Yield tuples whose elements are offset from *iterable*.
    The `i`-th item in each tuple is offset by the `i`-th item in
    *offsets*.

    >>> list(stagger([0, 1, 2, 3]))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
    >>> list(stagger(range(8), offsets=(0, 2, 4)))
    [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]

    By default, the sequence ends when the final element of a tuple is
    the last item in the iterable. Set *longest* to ``True`` to continue
    until the first element of a tuple is the last item:

    >>> list(stagger([0, 1, 2, 3], longest=True))
    [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]

    Positions beyond the end of the sequence are filled with ``None`` by
    default; specify *fillvalue* to use some other value.

    """
    # One independent copy of the input per requested offset; zip_offset
    # does the actual shifting and alignment.
    copies = tee(iterable, len(offsets))
    return zip_offset(
        *copies, offsets=offsets, longest=longest, fillvalue=fillvalue
    )

1803 

1804 

def zip_equal(*iterables):
    """``zip`` the input *iterables* together, but raise
    ``UnequalIterablesError`` if they aren't all the same length.

    >>> list(zip_equal(range(3), iter('abc')))
    [(0, 'a'), (1, 'b'), (2, 'c')]

    >>> list(zip_equal(range(3), iter('abcd')))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    more_itertools.more.UnequalIterablesError: Iterables have different
    lengths

    """
    # Python 3.10+ provides zip(..., strict=True); steer callers there.
    if hexversion >= 0x30A00A6:
        message = (
            'zip_equal will be removed in a future version of '
            'more-itertools. Use the builtin zip function with '
            'strict=True instead.'
        )
        warnings.warn(message, DeprecationWarning)

    return _zip_equal(*iterables)

1834 

1835 

def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
    """``zip`` the input *iterables* together, but offset the `i`-th
    iterable by the `i`-th item in *offsets*.

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]

    This can be used as a lightweight alternative to SciPy or pandas to
    analyze data sets in which some series have a lead or lag
    relationship.

    By default, the sequence ends when the shortest iterable is
    exhausted. Set *longest* to ``True`` to continue until the longest
    one is:

    >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
    [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]

    Positions beyond the end of a sequence are filled with ``None`` by
    default; specify *fillvalue* to use some other value.

    """
    if len(iterables) != len(offsets):
        raise ValueError("Number of iterables and offsets didn't match")

    shifted = []
    for source, offset in zip(iterables, offsets):
        if offset < 0:
            # Negative offset: delay the stream by prepending fill values.
            shifted.append(chain(repeat(fillvalue, -offset), source))
        elif offset > 0:
            # Positive offset: advance the stream by dropping leading items.
            shifted.append(islice(source, offset, None))
        else:
            shifted.append(source)

    if longest:
        return zip_longest(*shifted, fillvalue=fillvalue)
    return zip(*shifted)

1873 

1874 

def sort_together(
    iterables, key_list=(0,), key=None, reverse=False, strict=False
):
    """Return the input iterables sorted together, with *key_list* as the
    priority for sorting. All iterables are trimmed to the length of the
    shortest one.

    Like the sorting function in a spreadsheet: if each iterable is a
    column of data, *key_list* determines which columns drive the sort.

    By default, all iterables are sorted using the ``0``-th iterable::

        >>> sort_together([(4, 3, 2, 1), ('a', 'b', 'c', 'd')])
        [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]

    Specifying multiple keys dictates how ties are broken::

        >>> sort_together(
        ...     [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')], key_list=(1, 2)
        ... )
        [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]

    To sort by a function of the elements, pass a *key* function; its
    arguments are the elements of the iterables named by *key_list*::

        >>> sort_together(
        ...     [('a', 'b', 'c'), (1, 2, 3), (5, 2, 1)],
        ...     key_list=(1, 2),
        ...     key=lambda length, width: length * width,
        ... )
        [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]

    Set *reverse* to ``True`` to sort in descending order.

    If *strict* is ``True``, ``UnequalIterablesError`` is raised when the
    iterables have different lengths.

    """
    if key is None:
        # No key function: sort rows directly on the selected columns.
        key_argument = itemgetter(*key_list)
    else:
        key_list = list(key_list)
        if len(key_list) == 1:
            # Single column: the key function receives that one element.
            offset = key_list[0]

            def key_argument(zipped):
                return key(zipped[offset])

        else:
            # Multiple columns: unpack them as positional arguments.
            get_key_items = itemgetter(*key_list)

            def key_argument(zipped):
                return key(*get_key_items(zipped))

    # zip_equal is only evaluated when strict is truthy.
    zipper = zip_equal if strict else zip
    return list(
        zipper(*sorted(zipper(*iterables), key=key_argument, reverse=reverse))
    )

1946 

1947 

def unzip(iterable):
    """The inverse of :func:`zip`: disaggregate the elements of the
    zipped *iterable*.

    The ``i``-th returned iterable contains the ``i``-th element of each
    item in *iterable*. The first item determines how many iterables are
    returned.

    >>> letters, numbers = unzip([('a', 1), ('b', 2), ('c', 3)])
    >>> list(letters)
    ['a', 'b', 'c']
    >>> list(numbers)
    [1, 2, 3]

    Similar to ``zip(*iterable)``, but avoids reading *iterable* into
    memory up front. Note, however, that :func:`itertools.tee` is used
    and may require significant storage.

    """
    # Peek at the first element (without consuming the main stream) to
    # learn how many output streams are needed.
    stream, peek = tee(iterable)
    head = list(islice(peek, 1))
    if not head:
        # Empty input, e.g. zip([], [], [])
        return ()
    first = head[0]
    streams = tee(stream, len(first))

    def _getter(index):
        def getter(obj):
            try:
                return obj[index]
            except IndexError:
                # For "improperly zipped" input like
                # iter([(1, 2, 3), (4, 5), (6,)]), stop this output
                # stream at the first element that is too short instead
                # of raising. (StopIteration from map's callable ends
                # iteration of the map object.)
                raise StopIteration

        return getter

    return tuple(map(_getter(i), s) for i, s in enumerate(streams))

1995 

1996 

def divide(n, iterable):
    """Divide the elements from *iterable* into *n* contiguous parts,
    maintaining order.

    >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 2, 3]
    >>> list(group_2)
    [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, the
    earlier parts receive the extra items:

    >>> [list(c) for c in divide(3, [1, 2, 3, 4, 5, 6, 7])]
    [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than *n*, the trailing
    parts will be empty:

    >>> [list(c) for c in divide(5, [1, 2, 3])]
    [[1], [2], [3], [], []]

    This function exhausts the iterable before returning. If order is
    not important, see :func:`distribute`, which does not first pull the
    iterable into memory.

    """
    if n < 1:
        raise ValueError('n must be at least 1')

    # Materialize the input unless it already supports slicing.
    try:
        iterable[:0]
    except TypeError:
        seq = tuple(iterable)
    else:
        seq = iterable

    # The first `leftover` parts get one extra item each.
    size, leftover = divmod(len(seq), n)

    parts = []
    stop = 0
    for index in range(n):
        start = stop
        stop += size + (1 if index < leftover else 0)
        parts.append(iter(seq[start:stop]))

    return parts

2046 

2047 

def always_iterable(obj, base_type=(str, bytes)):
    """If *obj* is iterable, return an iterator over its items::

        >>> obj = (1, 2, 3)
        >>> list(always_iterable(obj))
        [1, 2, 3]

    If *obj* is not iterable, return a one-item iterable containing *obj*::

        >>> obj = 1
        >>> list(always_iterable(obj))
        [1]

    If *obj* is ``None``, return an empty iterable:

        >>> obj = None
        >>> list(always_iterable(None))
        []

    By default, binary and text strings are not considered iterable::

        >>> obj = 'foo'
        >>> list(always_iterable(obj))
        ['foo']

    If *base_type* is set, objects for which ``isinstance(obj, base_type)``
    returns ``True`` won't be considered iterable.

        >>> obj = {'a': 1}
        >>> list(always_iterable(obj))  # Iterate over the dict's keys
        ['a']
        >>> list(always_iterable(obj, base_type=dict))  # Treat dicts as a unit
        [{'a': 1}]

    Set *base_type* to ``None`` to avoid any special handling and treat objects
    Python considers iterable as iterable:

        >>> obj = 'foo'
        >>> list(always_iterable(obj, base_type=None))
        ['f', 'o', 'o']
    """
    # None maps to an empty iterable, regardless of base_type.
    if obj is None:
        return iter(())

    # Types listed in base_type are treated as single items, even if they
    # happen to be iterable (strings and bytes by default).
    treat_as_unit = base_type is not None and isinstance(obj, base_type)
    if treat_as_unit:
        return iter((obj,))

    # Otherwise, defer to iter(); non-iterables become one-item iterables.
    try:
        return iter(obj)
    except TypeError:
        return iter((obj,))

2099 

2100 

def adjacent(predicate, iterable, distance=1):
    """Return an iterable over `(bool, item)` tuples where the `item` is
    drawn from *iterable* and the `bool` indicates whether
    that item satisfies the *predicate* or is adjacent to an item that does.

    For example, to find whether items are adjacent to a ``3``::

        >>> list(adjacent(lambda x: x == 3, range(6)))
        [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]

    Set *distance* to change what counts as adjacent. For example, to find
    whether items are two places away from a ``3``:

        >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
        [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]

    This is useful for contextualizing the results of a search function.
    For example, a code comparison tool might want to identify lines that
    have changed, but also surrounding lines to give the viewer of the diff
    context.

    The predicate function will only be called once for each item in the
    iterable.

    See also :func:`groupby_transform`, which can be used with this function
    to group ranges of items with the same `bool` value.

    """
    # distance=0 is permitted mainly so the result can be checked against a
    # plain map() of the predicate.
    if distance < 0:
        raise ValueError('distance must be at least 0')

    items_for_predicate, items_for_output = tee(iterable)

    # Pad the predicate results on both sides so the sliding window below
    # is well-defined at the edges of the sequence.
    pad = [False] * distance
    hits = chain(pad, map(predicate, items_for_predicate), pad)

    # An item is "adjacent" if any hit falls within its window.
    window_size = 2 * distance + 1
    nearby = map(any, windowed(hits, window_size))
    return zip(nearby, items_for_output)

2138 

2139 

def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
    """An extension of :func:`itertools.groupby` that can apply transformations
    to the grouped data.

    * *keyfunc* is a function computing a key value for each item in *iterable*
    * *valuefunc* is a function that transforms the individual items from
      *iterable* after grouping
    * *reducefunc* is a function that transforms each group of items

    >>> iterable = 'aAAbBBcCC'
    >>> keyfunc = lambda k: k.upper()
    >>> valuefunc = lambda v: v.lower()
    >>> reducefunc = lambda g: ''.join(g)
    >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
    [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]

    Each optional argument defaults to an identity function if not specified.

    :func:`groupby_transform` is useful when grouping elements of an iterable
    using a separate iterable as the key. To do this, :func:`zip` the iterables
    and pass a *keyfunc* that extracts the first element and a *valuefunc*
    that extracts the second element::

        >>> from operator import itemgetter
        >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
        >>> values = 'abcdefghi'
        >>> iterable = zip(keys, values)
        >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
        >>> [(k, ''.join(g)) for k, g in grouper]
        [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]

    Note that the order of items in the iterable is significant.
    Only adjacent items are grouped together, so if you don't want any
    duplicate groups, you should sort the iterable by the key function.

    """
    groups = groupby(iterable, keyfunc)
    # Layer the optional transformations lazily on top of the raw groups.
    if valuefunc:
        groups = ((key, map(valuefunc, group)) for key, group in groups)
    if reducefunc:
        groups = ((key, reducefunc(group)) for key, group in groups)
    return groups

2183 

2184 

class numeric_range(abc.Sequence, abc.Hashable):
    """An extension of the built-in ``range()`` function whose arguments can
    be any orderable numeric type.

    With only *stop* specified, *start* defaults to ``0`` and *step*
    defaults to ``1``. The output items will match the type of *stop*:

    >>> list(numeric_range(3.5))
    [0.0, 1.0, 2.0, 3.0]

    With only *start* and *stop* specified, *step* defaults to ``1``. The
    output items will match the type of *start*:

    >>> from decimal import Decimal
    >>> start = Decimal('2.1')
    >>> stop = Decimal('5.1')
    >>> list(numeric_range(start, stop))
    [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]

    With *start*, *stop*, and *step* specified the output items will match
    the type of ``start + step``:

    >>> from fractions import Fraction
    >>> start = Fraction(1, 2)  # Start at 1/2
    >>> stop = Fraction(5, 2)  # End at 5/2
    >>> step = Fraction(1, 2)  # Count by 1/2
    >>> list(numeric_range(start, stop, step))
    [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]

    If *step* is zero, ``ValueError`` is raised. Negative steps are supported:

    >>> list(numeric_range(3, -1, -1.0))
    [3.0, 2.0, 1.0, 0.0]

    Be aware of the limitations of floating-point numbers; the representation
    of the yielded numbers may be surprising.

    ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
    is a ``datetime.timedelta`` object:

    >>> import datetime
    >>> start = datetime.datetime(2019, 1, 1)
    >>> stop = datetime.datetime(2019, 1, 3)
    >>> step = datetime.timedelta(days=1)
    >>> items = iter(numeric_range(start, stop, step))
    >>> next(items)
    datetime.datetime(2019, 1, 1, 0, 0)
    >>> next(items)
    datetime.datetime(2019, 1, 2, 0, 0)

    """

    # Shared hash for all empty ranges, matching the hash of an empty
    # built-in range so empty ranges hash consistently.
    _EMPTY_HASH = hash(range(0, 0))

    def __init__(self, *args):
        argc = len(args)
        if argc == 1:
            (self._stop,) = args
            # Derive zero and one in types compatible with *stop*, so that
            # mixed types (e.g. datetime start/stop with timedelta step)
            # still work.
            self._start = type(self._stop)(0)
            self._step = type(self._stop - self._start)(1)
        elif argc == 2:
            self._start, self._stop = args
            self._step = type(self._stop - self._start)(1)
        elif argc == 3:
            self._start, self._stop, self._step = args
        elif argc == 0:
            raise TypeError(
                f'numeric_range expected at least 1 argument, got {argc}'
            )
        else:
            raise TypeError(
                f'numeric_range expected at most 3 arguments, got {argc}'
            )

        # Zero in the step's type, used for sign and divisibility checks.
        self._zero = type(self._step)(0)
        if self._step == self._zero:
            raise ValueError('numeric_range() arg 3 must not be zero')
        self._growing = self._step > self._zero

    def __bool__(self):
        # Truthy iff the range contains at least one item.
        if self._growing:
            return self._start < self._stop
        else:
            return self._start > self._stop

    def __contains__(self, elem):
        # An element is contained if it lies in the half-open interval and
        # is an exact whole number of steps away from start.
        if self._growing:
            if self._start <= elem < self._stop:
                return (elem - self._start) % self._step == self._zero
        else:
            if self._start >= elem > self._stop:
                return (self._start - elem) % (-self._step) == self._zero

        return False

    def __eq__(self, other):
        if isinstance(other, numeric_range):
            empty_self = not bool(self)
            empty_other = not bool(other)
            if empty_self or empty_other:
                return empty_self and empty_other  # True if both empty
            else:
                # Non-empty ranges are equal when they produce the same
                # items: same first item, same step, same last item.
                return (
                    self._start == other._start
                    and self._step == other._step
                    and self._get_by_index(-1) == other._get_by_index(-1)
                )
        else:
            return False

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._get_by_index(key)
        elif isinstance(key, slice):
            # Slicing returns a new numeric_range; steps compose by
            # multiplication.
            step = self._step if key.step is None else key.step * self._step

            # Clamp the slice start to the range's bounds.
            if key.start is None or key.start <= -self._len:
                start = self._start
            elif key.start >= self._len:
                start = self._stop
            else:  # -self._len < key.start < self._len
                start = self._get_by_index(key.start)

            # Clamp the slice stop to the range's bounds.
            if key.stop is None or key.stop >= self._len:
                stop = self._stop
            elif key.stop <= -self._len:
                stop = self._start
            else:  # -self._len < key.stop < self._len
                stop = self._get_by_index(key.stop)

            return numeric_range(start, stop, step)
        else:
            raise TypeError(
                'numeric range indices must be '
                f'integers or slices, not {type(key).__name__}'
            )

    def __hash__(self):
        if self:
            # Hash on (first, last, step) so equal ranges hash equally.
            return hash((self._start, self._get_by_index(-1), self._step))
        else:
            return self._EMPTY_HASH

    def __iter__(self):
        values = (self._start + (n * self._step) for n in count())
        # partial(gt, stop)(v) is stop > v: take values while strictly
        # inside the bound (and the mirror image for shrinking ranges).
        if self._growing:
            return takewhile(partial(gt, self._stop), values)
        else:
            return takewhile(partial(lt, self._stop), values)

    def __len__(self):
        return self._len

    @cached_property
    def _len(self):
        # Normalize to a growing range so a single euclidean division
        # computes the length.
        if self._growing:
            start = self._start
            stop = self._stop
            step = self._step
        else:
            start = self._stop
            stop = self._start
            step = -self._step
        distance = stop - start
        if distance <= self._zero:
            return 0
        else:  # distance > 0 and step > 0: regular euclidean division
            q, r = divmod(distance, step)
            return int(q) + int(r != self._zero)

    def __reduce__(self):
        # Support pickling by reconstructing from the three arguments.
        return numeric_range, (self._start, self._stop, self._step)

    def __repr__(self):
        if self._step == 1:
            return f"numeric_range({self._start!r}, {self._stop!r})"
        return (
            f"numeric_range({self._start!r}, {self._stop!r}, {self._step!r})"
        )

    def __reversed__(self):
        # A reversed range runs from the last item back past start, with
        # the step negated.
        return iter(
            numeric_range(
                self._get_by_index(-1), self._start - self._step, -self._step
            )
        )

    def count(self, value):
        # Items are unique, so the count is 0 or 1.
        return int(value in self)

    def index(self, value):
        # The index is the whole number of steps from start to value, if
        # value lands exactly on a step boundary.
        if self._growing:
            if self._start <= value < self._stop:
                q, r = divmod(value - self._start, self._step)
                if r == self._zero:
                    return int(q)
        else:
            if self._start >= value > self._stop:
                q, r = divmod(self._start - value, -self._step)
                if r == self._zero:
                    return int(q)

        raise ValueError(f"{value} is not in numeric range")

    def _get_by_index(self, i):
        # Support negative indexing like built-in sequences.
        if i < 0:
            i += self._len
        if i < 0 or i >= self._len:
            raise IndexError("numeric range object index out of range")
        return self._start + i * self._step

2395 

2396 

def count_cycle(iterable, n=None):
    """Cycle through the items from *iterable* up to *n* times, yielding
    the number of completed cycles along with each item. If *n* is omitted the
    process repeats indefinitely.

    >>> list(count_cycle('AB', 3))
    [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]

    """
    # Materialize once so the items can be replayed on every cycle.
    pool = tuple(iterable)
    if not pool:
        return iter(())
    cycles = range(n) if n is not None else count()
    return ((cycle_number, item) for cycle_number in cycles for item in pool)

2411 

2412 

def mark_ends(iterable):
    """Yield 3-tuples of the form ``(is_first, is_last, item)``.

    >>> list(mark_ends('ABC'))
    [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]

    Use this when looping over an iterable to take special action on its first
    and/or last items:

    >>> iterable = ['Header', 100, 200, 'Footer']
    >>> total = 0
    >>> for is_first, is_last, item in mark_ends(iterable):
    ...     if is_first:
    ...         continue  # Skip the header
    ...     if is_last:
    ...         continue  # Skip the footer
    ...     total += item
    >>> print(total)
    300
    """
    iterator = iter(iterable)

    # Prime the one-item lookahead; an empty iterable yields nothing.
    try:
        current = next(iterator)
    except StopIteration:
        return

    # Each item is emitted only after the next one is fetched, so we know
    # whether it is the last.
    first = True
    for upcoming in iterator:
        yield first, False, current
        current = upcoming
        first = False

    yield first, True, current

2448 

2449 

def locate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
    [1, 2, 4]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item.

    >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
    [1, 3]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(locate(iterable, pred=pred, window_size=3))
    [1, 5, 9]

    Use with :func:`seekable` to find indexes and then retrieve the associated
    items:

    >>> from itertools import count
    >>> from more_itertools import seekable
    >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
    >>> it = seekable(source)
    >>> pred = lambda x: x > 100
    >>> indexes = locate(it, pred=pred)
    >>> i = next(indexes)
    >>> it.seek(i)
    >>> next(it)
    106

    """
    # Simple case: test each item individually.
    if window_size is None:
        return (index for index, item in enumerate(iterable) if pred(item))

    if window_size < 1:
        raise ValueError('window size must be at least 1')

    # Windowed case: call pred with window_size items at a time; short final
    # windows are padded with the sentinel.
    windows = windowed(iterable, window_size, fillvalue=_marker)
    return (index for index, window in enumerate(windows) if pred(*window))

2496 

2497 

def longest_common_prefix(iterables):
    """Yield elements of the longest common prefix among given *iterables*.

    >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
    'ab'

    """

    # A column (one position across all iterables) extends the common
    # prefix only when every entry in it is equal.
    def _all_same(column):
        head = column[0]
        return all(element == head for element in column)

    return (column[0] for column in takewhile(_all_same, zip(*iterables)))

2506 

2507 

def lstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the beginning
    for which *pred* returns ``True``.

    For example, to remove a set of items from the start of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(lstrip(iterable, pred))
    [1, 2, None, 3, False, None]

    This function is analogous to :func:`str.lstrip`, and is essentially
    a wrapper for :func:`itertools.dropwhile`.

    """
    return dropwhile(pred, iterable)

2524 

2525 

def rstrip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the end
    for which *pred* returns ``True``.

    For example, to remove a set of items from the end of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(rstrip(iterable, pred))
    [None, False, None, 1, 2, None, 3]

    This function is analogous to :func:`str.rstrip`.

    """
    # Matching items are held back; they are only emitted once a
    # non-matching item proves they weren't at the end.
    pending = []
    for item in iterable:
        if pred(item):
            pending.append(item)
        else:
            if pending:
                yield from pending
                pending = []
            yield item

2550 

2551 

def strip(iterable, pred):
    """Yield the items from *iterable*, but strip any from the
    beginning and end for which *pred* returns ``True``.

    For example, to remove a set of items from both ends of an iterable:

    >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
    >>> pred = lambda x: x in {None, False, ''}
    >>> list(strip(iterable, pred))
    [1, 2, None, 3]

    This function is analogous to :func:`str.strip`.

    """
    # Strip the front first, then the back of what remains.
    trimmed_front = lstrip(iterable, pred)
    return rstrip(trimmed_front, pred)

2567 

2568 

class islice_extended:
    """An extension of :func:`itertools.islice` that supports negative values
    for *stop*, *start*, and *step*.

    >>> iterator = iter('abcdefgh')
    >>> list(islice_extended(iterator, -4, -1))
    ['e', 'f', 'g']

    Slices with negative values require some caching of *iterable*, but this
    function takes care to minimize the amount of memory required.

    For example, you can use a negative step with an infinite iterator:

    >>> from itertools import count
    >>> list(islice_extended(count(), 110, 99, -2))
    [110, 108, 106, 104, 102, 100]

    You can also use slice notation directly:

    >>> iterator = map(str, count())
    >>> it = islice_extended(iterator)[10:20:2]
    >>> list(it)
    ['10', '12', '14', '16', '18']

    """

    def __init__(self, iterable, *args):
        source = iter(iterable)
        # With no slice arguments, act as a plain pass-through iterator.
        if not args:
            self._iterator = source
        else:
            self._iterator = _islice_helper(source, slice(*args))

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getitem__(self, key):
        if not isinstance(key, slice):
            raise TypeError('islice_extended.__getitem__ argument must be a slice')
        # Slicing wraps the current iterator so slices can be chained.
        return islice_extended(_islice_helper(self._iterator, key))

2613 

2614 

def _islice_helper(it, s):
    """Yield the items of iterator *it* selected by slice *s*, supporting
    negative *start*, *stop*, and *step* by caching only as many items as
    the negative bounds require.
    """
    start = s.start
    stop = s.stop
    if s.step == 0:
        raise ValueError('step argument must be a non-zero integer or None.')
    step = s.step or 1

    if step > 0:
        start = 0 if (start is None) else start

        if start < 0:
            # Consume all but the last -start items
            cache = deque(enumerate(it, 1), maxlen=-start)
            # The enumeration count of the final cached item is the total
            # length of the iterator.
            len_iter = cache[-1][0] if cache else 0

            # Adjust start to be positive
            i = max(len_iter + start, 0)

            # Adjust stop to be positive
            if stop is None:
                j = len_iter
            elif stop >= 0:
                j = min(stop, len_iter)
            else:
                j = max(len_iter + stop, 0)

            # Slice the cache
            n = j - i
            if n <= 0:
                return

            for index in range(n):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()[1]
                else:
                    # just pop and discard the item
                    cache.popleft()
        elif (stop is not None) and (stop < 0):
            # Advance to the start position
            next(islice(it, start, start), None)

            # When stop is negative, we have to carry -stop items while
            # iterating
            cache = deque(islice(it, -stop), maxlen=-stop)

            for index, item in enumerate(it):
                if index % step == 0:
                    # pop and yield the item.
                    # We don't want to use an intermediate variable
                    # it would extend the lifetime of the current item
                    yield cache.popleft()
                else:
                    # just pop and discard the item
                    cache.popleft()
                cache.append(item)
        else:
            # When both start and stop are positive we have the normal case
            yield from islice(it, start, stop, step)
    else:
        # Negative step: items are produced in reverse order.
        start = -1 if (start is None) else start

        if (stop is not None) and (stop < 0):
            # Consume all but the last items
            n = -stop - 1
            cache = deque(enumerate(it, 1), maxlen=n)
            # As above, the enumeration count gives the total length.
            len_iter = cache[-1][0] if cache else 0

            # If start and stop are both negative they are comparable and
            # we can just slice. Otherwise we can adjust start to be negative
            # and then slice.
            if start < 0:
                i, j = start, stop
            else:
                i, j = min(start - len_iter, -1), None

            for index, item in list(cache)[i:j:step]:
                yield item
        else:
            # Advance to the stop position
            if stop is not None:
                m = stop + 1
                next(islice(it, m, m), None)

            # stop is positive, so if start is negative they are not comparable
            # and we need the rest of the items.
            if start < 0:
                i = start
                n = None
            # stop is None and start is positive, so we just need items up to
            # the start index.
            elif stop is None:
                i = None
                n = start + 1
            # Both stop and start are positive, so they are comparable.
            else:
                i = None
                n = start - stop
                if n <= 0:
                    return

            cache = list(islice(it, n))

            yield from cache[i::step]

2721 

2722 

def always_reversible(iterable):
    """An extension of :func:`reversed` that supports all iterables, not
    just those which implement the ``Reversible`` or ``Sequence`` protocols.

    >>> print(*always_reversible(x for x in range(3)))
    2 1 0

    If the iterable is already reversible, this function returns the
    result of :func:`reversed()`. If the iterable is not reversible,
    this function will cache the remaining items in the iterable and
    yield them in reverse order, which may require significant storage.
    """
    # Fast path: the object supports reversed() natively.
    try:
        return reversed(iterable)
    except TypeError:
        pass
    # Fallback: materialize the items and reverse the copy.
    return reversed(list(iterable))

2739 

2740 

def consecutive_groups(iterable, ordering=None):
    """Yield groups of consecutive items using :func:`itertools.groupby`.
    The *ordering* function determines whether two items are adjacent by
    returning their position.

    By default, the ordering function is the identity function. This is
    suitable for finding runs of numbers:

    >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
    >>> for group in consecutive_groups(iterable):
    ...     print(list(group))
    [1]
    [10, 11, 12]
    [20]
    [30, 31, 32, 33]
    [40]

    To find runs of adjacent letters, apply :func:`ord` function
    to convert letters to ordinals.

    >>> iterable = 'abcdfgilmnop'
    >>> ordering = ord
    >>> for group in consecutive_groups(iterable, ordering):
    ...     print(list(group))
    ['a', 'b', 'c', 'd']
    ['f', 'g']
    ['i']
    ['l', 'm', 'n', 'o', 'p']

    Each group of consecutive items is an iterator that shares its source with
    *iterable*. When an output group is advanced, the previous group is
    no longer available unless its elements are copied (e.g., into a ``list``).

    >>> iterable = [1, 2, 11, 12, 21, 22]
    >>> saved_groups = []
    >>> for group in consecutive_groups(iterable):
    ...     saved_groups.append(list(group))  # Copy group elements
    >>> saved_groups
    [[1, 2], [11, 12], [21, 22]]

    """
    # Within a run of consecutive items, (enumeration index - position)
    # is constant, so it serves as the groupby key.
    if ordering is None:
        key = lambda x: x[0] - x[1]
    else:
        key = lambda x: x[0] - ordering(x[1])

    for k, g in groupby(enumerate(iterable), key=key):
        # Strip the enumeration index, yielding just the items.
        yield map(itemgetter(1), g)

2789 

2790 

def difference(iterable, func=sub, *, initial=None):
    """This function is the inverse of :func:`itertools.accumulate`. By default
    it will compute the first difference of *iterable* using
    :func:`operator.sub`:

    >>> from itertools import accumulate
    >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
    >>> list(difference(iterable))
    [0, 1, 2, 3, 4]

    *func* defaults to :func:`operator.sub`, but other functions can be
    specified. They will be applied as follows::

        A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...

    For example, to do progressive division:

    >>> iterable = [1, 2, 6, 24, 120]
    >>> func = lambda x, y: x // y
    >>> list(difference(iterable, func))
    [1, 2, 3, 4, 5]

    If the *initial* keyword is set, the first element will be skipped when
    computing successive differences.

    >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
    >>> list(difference(it, initial=10))
    [1, 2, 3]

    """
    # *lagged* runs one item ahead of *source*, pairing each item with
    # its predecessor.
    source, lagged = tee(iterable)
    try:
        head = next(lagged)
    except StopIteration:
        return iter([])

    diffs = map(func, lagged, source)
    if initial is not None:
        # The first item came from *initial* upstream; emit differences only.
        return diffs
    return chain([head], diffs)

2831 

2832 

class SequenceView(Sequence):
    """Return a read-only view of the sequence object *target*.

    :class:`SequenceView` objects are analogous to Python's built-in
    "dictionary view" types. They provide a dynamic view of a sequence's items,
    meaning that when the sequence updates, so does the view.

    >>> seq = ['0', '1', '2']
    >>> view = SequenceView(seq)
    >>> view
    SequenceView(['0', '1', '2'])
    >>> seq.append('3')
    >>> view
    SequenceView(['0', '1', '2', '3'])

    Sequence views support indexing, slicing, and length queries. They act
    like the underlying sequence, except they don't allow assignment:

    >>> view[1]
    '1'
    >>> view[1:-1]
    ['1', '2']
    >>> len(view)
    4

    Sequence views are useful as an alternative to copying, as they don't
    require (much) extra storage.

    """

    def __init__(self, target):
        # Only genuine sequences can be viewed; reject everything else.
        if not isinstance(target, Sequence):
            raise TypeError
        self._target = target

    def __getitem__(self, index):
        # Delegate both integer indexing and slicing to the target.
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        return f'{type(self).__name__}({self._target!r})'

2876 

2877 

2878class seekable: 

2879 """Wrap an iterator to allow for seeking backward and forward. This 

2880 progressively caches the items in the source iterable so they can be 

2881 re-visited. 

2882 

2883 Call :meth:`seek` with an index to seek to that position in the source 

2884 iterable. 

2885 

2886 To "reset" an iterator, seek to ``0``: 

2887 

2888 >>> from itertools import count 

2889 >>> it = seekable((str(n) for n in count())) 

2890 >>> next(it), next(it), next(it) 

2891 ('0', '1', '2') 

2892 >>> it.seek(0) 

2893 >>> next(it), next(it), next(it) 

2894 ('0', '1', '2') 

2895 

2896 You can also seek forward: 

2897 

2898 >>> it = seekable((str(n) for n in range(20))) 

2899 >>> it.seek(10) 

2900 >>> next(it) 

2901 '10' 

2902 >>> it.seek(20) # Seeking past the end of the source isn't a problem 

2903 >>> list(it) 

2904 [] 

2905 >>> it.seek(0) # Resetting works even after hitting the end 

2906 >>> next(it) 

2907 '0' 

2908 

2909 Call :meth:`relative_seek` to seek relative to the source iterator's 

2910 current position. 

2911 

2912 >>> it = seekable((str(n) for n in range(20))) 

2913 >>> next(it), next(it), next(it) 

2914 ('0', '1', '2') 

2915 >>> it.relative_seek(2) 

2916 >>> next(it) 

2917 '5' 

2918 >>> it.relative_seek(-3) # Source is at '6', we move back to '3' 

2919 >>> next(it) 

2920 '3' 

2921 >>> it.relative_seek(-3) # Source is at '4', we move back to '1' 

2922 >>> next(it) 

2923 '1' 

2924 

2925 

2926 Call :meth:`peek` to look ahead one item without advancing the iterator: 

2927 

2928 >>> it = seekable('1234') 

2929 >>> it.peek() 

2930 '1' 

2931 >>> list(it) 

2932 ['1', '2', '3', '4'] 

2933 >>> it.peek(default='empty') 

2934 'empty' 

2935 

2936 Before the iterator is at its end, calling :func:`bool` on it will return 

2937 ``True``. After it will return ``False``: 

2938 

2939 >>> it = seekable('5678') 

2940 >>> bool(it) 

2941 True 

2942 >>> list(it) 

2943 ['5', '6', '7', '8'] 

2944 >>> bool(it) 

2945 False 

2946 

2947 You may view the contents of the cache with the :meth:`elements` method. 

2948 That returns a :class:`SequenceView`, a view that updates automatically: 

2949 

2950 >>> it = seekable((str(n) for n in range(10))) 

2951 >>> next(it), next(it), next(it) 

2952 ('0', '1', '2') 

2953 >>> elements = it.elements() 

2954 >>> elements 

2955 SequenceView(['0', '1', '2']) 

2956 >>> next(it) 

2957 '3' 

2958 >>> elements 

2959 SequenceView(['0', '1', '2', '3']) 

2960 

2961 By default, the cache grows as the source iterable progresses, so beware of 

2962 wrapping very large or infinite iterables. Supply *maxlen* to limit the 

2963 size of the cache (this of course limits how far back you can seek). 

2964 

2965 >>> from itertools import count 

2966 >>> it = seekable((str(n) for n in count()), maxlen=2) 

2967 >>> next(it), next(it), next(it), next(it) 

2968 ('0', '1', '2', '3') 

2969 >>> list(it.elements()) 

2970 ['2', '3'] 

2971 >>> it.seek(0) 

2972 >>> next(it), next(it), next(it), next(it) 

2973 ('2', '3', '4', '5') 

2974 >>> next(it) 

2975 '6' 

2976 

2977 """ 

2978 

2979 def __init__(self, iterable, maxlen=None): 

2980 self._source = iter(iterable) 

2981 if maxlen is None: 

2982 self._cache = [] 

2983 else: 

2984 self._cache = deque([], maxlen) 

2985 self._index = None 

2986 

2987 def __iter__(self): 

2988 return self 

2989 

2990 def __next__(self): 

2991 if self._index is not None: 

2992 try: 

2993 item = self._cache[self._index] 

2994 except IndexError: 

2995 self._index = None 

2996 else: 

2997 self._index += 1 

2998 return item 

2999 

3000 item = next(self._source) 

3001 self._cache.append(item) 

3002 return item 

3003 

3004 def __bool__(self): 

3005 try: 

3006 self.peek() 

3007 except StopIteration: 

3008 return False 

3009 return True 

3010 

    def peek(self, default=_marker):
        # Return the next item without consuming it.  If the source is
        # exhausted, return *default* if one was given, otherwise re-raise
        # StopIteration.
        try:
            peeked = next(self)
        except StopIteration:
            if default is _marker:
                raise
            return default
        # next() either advanced the cache index or appended to the cache;
        # step the index back one position so the peeked item is served again
        # on the next call to __next__.
        if self._index is None:
            self._index = len(self._cache)
        self._index -= 1
        return peeked

3022 

    def elements(self):
        # Expose the cache through a SequenceView, so callers see a view of
        # the cached items that updates as iteration progresses.
        return SequenceView(self._cache)

3025 

    def seek(self, index):
        # Reposition to *index* within the cache.  If the cache does not yet
        # reach that far, advance ourselves (which appends to the cache) by
        # the shortfall.
        self._index = index
        remainder = index - len(self._cache)
        if remainder > 0:
            consume(self, remainder)

3031 

3032 def relative_seek(self, count): 

3033 if self._index is None: 

3034 self._index = len(self._cache) 

3035 

3036 self.seek(max(self._index + count, 0)) 

3037 

3038 

class run_length:
    """
    :func:`run_length.encode` compresses an iterable with run-length encoding.
    It yields groups of repeated items with the count of how many times they
    were repeated:

    >>> uncompressed = 'abbcccdddd'
    >>> list(run_length.encode(uncompressed))
    [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` decompresses an iterable that was previously
    compressed with run-length encoding. It yields the items of the
    decompressed iterable:

    >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
    >>> list(run_length.decode(compressed))
    ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']

    """

    @staticmethod
    def encode(iterable):
        # Each groupby group is exhausted to count the run length.
        return (
            (value, sum(1 for _ in group))
            for value, group in groupby(iterable)
        )

    @staticmethod
    def decode(iterable):
        # Expand each (value, times) pair back into a run of values.
        return chain.from_iterable(
            repeat(value, times) for value, times in iterable
        )

3066 

3067 

def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly ``n`` items in the iterable are ``True``
    according to the *predicate* function.

    >>> exactly_n([True, True, False], 2)
    True
    >>> exactly_n([True, True, False], 1)
    False
    >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
    True

    The iterable will be advanced until ``n + 1`` truthy items are encountered,
    so avoid calling it on infinite iterables.

    """
    # Count matches, bailing out as soon as we see one too many.
    matched = 0
    for item in iterable:
        if predicate(item):
            matched += 1
            if matched > n:
                return False
    return matched == n

3084 

3085 

def circular_shifts(iterable, steps=1):
    """Yield the circular shifts of *iterable*.

    >>> list(circular_shifts(range(4)))
    [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]

    Set *steps* to the number of places to rotate to the left
    (or to the right if negative). Defaults to 1.

    >>> list(circular_shifts(range(4), 2))
    [(0, 1, 2, 3), (2, 3, 0, 1)]

    >>> list(circular_shifts(range(4), -1))
    [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]

    """
    window = deque(iterable)
    if steps == 0:
        raise ValueError('Steps should be a non-zero integer')

    # Pre-rotate backwards so the first yield is the unshifted sequence.
    window.rotate(steps)
    shift = -steps
    size = len(window)
    # The sequence of shifts repeats after size // gcd(size, shift) steps.
    num_shifts = size // math.gcd(size, shift)

    for _ in range(num_shifts):
        window.rotate(shift)
        yield tuple(window)

3114 

3115 

def make_decorator(wrapping_func, result_index=0):
    """Return a decorator version of *wrapping_func*, which is a function that
    modifies an iterable. *result_index* is the position in that function's
    signature where the iterable goes.

    This lets you use itertools on the "production end," i.e. at function
    definition. This can augment what the function returns without changing the
    function's code.

    For example, to produce a decorator version of :func:`chunked`:

    >>> from more_itertools import chunked
    >>> chunker = make_decorator(chunked, result_index=0)
    >>> @chunker(3)
    ... def iter_range(n):
    ...     return iter(range(n))
    ...
    >>> list(iter_range(9))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    To only allow truthy items to be returned:

    >>> truth_serum = make_decorator(filter, result_index=1)
    >>> @truth_serum(bool)
    ... def boolean_test():
    ...     return [0, 1, '', ' ', False, True]
    ...
    >>> list(boolean_test())
    [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers make for practical
    decorators:

    >>> from more_itertools import peekable
    >>> peekable_function = make_decorator(peekable)
    >>> @peekable_function()
    ... def str_range(*args):
    ...     return (str(x) for x in range(*args))
    ...
    >>> it = str_range(1, 20, 2)
    >>> next(it), next(it), next(it)
    ('1', '3', '5')
    >>> it.peek()
    '7'
    >>> next(it)
    '7'

    """

    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                # Splice the wrapped function's result into the argument
                # list at *result_index* before calling *wrapping_func*.
                call_args = [
                    *wrapping_args[:result_index],
                    result,
                    *wrapping_args[result_index:],
                ]
                return wrapping_func(*call_args, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator

3180 

3181 

def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Return a dictionary that maps the items in *iterable* to categories
    defined by *keyfunc*, transforms them with *valuefunc*, and
    then summarizes them by category with *reducefunc*.

    *valuefunc* defaults to the identity function if it is unspecified.
    If *reducefunc* is unspecified, no summarization takes place:

    >>> keyfunc = lambda x: x.upper()
    >>> result = map_reduce('abbccc', keyfunc)
    >>> sorted(result.items())
    [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

    Specifying *valuefunc* transforms the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> result = map_reduce('abbccc', keyfunc, valuefunc)
    >>> sorted(result.items())
    [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]

    Specifying *reducefunc* summarizes the categorized items:

    >>> keyfunc = lambda x: x.upper()
    >>> valuefunc = lambda x: 1
    >>> reducefunc = sum
    >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
    >>> sorted(result.items())
    [('A', 1), ('B', 2), ('C', 3)]

    You may want to filter the input iterable before applying the map/reduce
    procedure:

    >>> all_items = range(30)
    >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
    >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
    >>> categories = map_reduce(items, keyfunc=keyfunc)
    >>> sorted(categories.items())
    [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
    >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
    >>> sorted(summaries.items())
    [(0, 90), (1, 75)]

    Note that all items in the iterable are gathered into a list before the
    summarization step, which may require significant storage.

    The returned object is a :obj:`collections.defaultdict` with the
    ``default_factory`` set to ``None``, such that it behaves like a normal
    dictionary.

    """
    groups = defaultdict(list)

    # Categorize (and optionally transform) every item.
    for item in iterable:
        value = item if valuefunc is None else valuefunc(item)
        groups[keyfunc(item)].append(value)

    # Optionally collapse each category's list into a summary value.
    if reducefunc is not None:
        for key in groups:
            groups[key] = reducefunc(groups[key])

    # Disable defaulting so missing keys raise KeyError like a plain dict.
    groups.default_factory = None
    return groups

3253 

3254 

def rlocate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``, starting from the right and moving left.

    *pred* defaults to :func:`bool`, which will select truthy items:

    >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
    [4, 2, 1]

    Set *pred* to a custom function to, e.g., find the indexes for a particular
    item:

    >>> iterator = iter('abcb')
    >>> pred = lambda x: x == 'b'
    >>> list(rlocate(iterator, pred))
    [3, 1]

    If *window_size* is given, then the *pred* function will be called with
    that many items. This enables searching for sub-sequences:

    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    >>> pred = lambda *args: args == (1, 2, 3)
    >>> list(rlocate(iterable, pred=pred, window_size=3))
    [9, 5, 1]

    Beware, this function won't return anything for infinite iterables.
    If *iterable* is reversible, ``rlocate`` will reverse it and search from
    the right. Otherwise, it will search from the left and return the results
    in reverse order.

    See :func:`locate` to for other example applications.

    """
    if window_size is None:
        # A sized, reversible iterable can be searched lazily from the
        # right; anything else falls through to full materialization.
        try:
            length = len(iterable)
            return (
                length - i - 1 for i in locate(reversed(iterable), pred)
            )
        except TypeError:
            pass

    return reversed(list(locate(iterable, pred, window_size)))

3296 

3297 

def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items from *iterable*, replacing the items for which *pred*
    returns ``True`` with the items from the iterable *substitutes*.

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
    >>> pred = lambda x: x == 0
    >>> substitutes = (2, 3)
    >>> list(replace(iterable, pred, substitutes))
    [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    If *count* is given, the number of replacements will be limited:

    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
    >>> pred = lambda x: x == 0
    >>> substitutes = [None]
    >>> list(replace(iterable, pred, substitutes, count=2))
    [1, 1, None, 1, 1, None, 1, 1, 0]

    Use *window_size* to control the number of items passed as arguments to
    *pred*. This allows for locating and replacing subsequences.

    >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
    >>> window_size = 3
    >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
    >>> substitutes = [3, 4]  # Splice in these items
    >>> list(replace(iterable, pred, substitutes, window_size=window_size))
    [3, 4, 5, 3, 4, 5]

    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # Save the substitutes iterable, since it's used more than once
    substitutes = tuple(substitutes)

    # Add padding such that the number of windows matches the length of the
    # iterable.  The padding is the _marker sentinel, which is filtered out
    # before yielding below.
    it = chain(iterable, repeat(_marker, window_size - 1))
    windows = windowed(it, window_size)

    # n counts the replacements made so far (bounded by *count* when given).
    n = 0
    for w in windows:
        # If the current window matches our predicate (and we haven't hit
        # our maximum number of replacements), splice in the substitutes
        # and then consume the following windows that overlap with this one.
        # For example, if the iterable is (0, 1, 2, 3, 4...)
        # and the window size is 2, we have (0, 1), (1, 2), (2, 3)...
        # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2)
        if pred(*w):
            if (count is None) or (n < count):
                n += 1
                yield from substitutes
                consume(windows, window_size - 1)
                continue

        # If there was no match (or we've reached the replacement limit),
        # yield the first item from the window.
        if w and (w[0] is not _marker):
            yield w[0]

3357 

3358 

def partitions(iterable):
    """Yield all possible order-preserving partitions of *iterable*.

    >>> iterable = 'abc'
    >>> for part in partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['a', 'b', 'c']

    This is unrelated to :func:`partition`.

    """
    seq = list(iterable)
    n = len(seq)
    # Every subset of the interior positions 1..n-1 defines a set of cut
    # points; enumerate subsets smallest-first (the powerset).
    cut_pool = range(1, n)
    cut_sets = chain.from_iterable(
        combinations(cut_pool, r) for r in range(len(cut_pool) + 1)
    )
    for cuts in cut_sets:
        yield [seq[a:b] for a, b in zip((0, *cuts), (*cuts, n))]

3377 

3378 

def set_partitions(iterable, k=None, min_size=None, max_size=None):
    """
    Yield the set partitions of *iterable* into *k* parts. Set partitions are
    not order-preserving.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, 2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']


    If *k* is not given, every set partition is generated.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable):
    ...     print([''.join(p) for p in part])
    ['abc']
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    if *min_size* and/or *max_size* are given, the minimum and/or maximum size
    per block in partition is set.

    >>> iterable = 'abc'
    >>> for part in set_partitions(iterable, min_size=2):
    ...     print([''.join(p) for p in part])
    ['abc']
    >>> for part in set_partitions(iterable, max_size=2):
    ...     print([''.join(p) for p in part])
    ['a', 'bc']
    ['ab', 'c']
    ['b', 'ac']
    ['a', 'b', 'c']

    """
    items = list(iterable)
    n = len(items)
    if k is not None:
        if k < 1:
            raise ValueError(
                "Can't partition in a negative or zero number of groups"
            )
        elif k > n:
            return

    min_size = 0 if min_size is None else min_size
    max_size = n if max_size is None else max_size
    if min_size > max_size:
        return

    def helper(L, k):
        # Recursively yield the partitions of L into exactly k blocks.
        n = len(L)
        if k == 1:
            yield [L]
        elif n == k:
            yield [[s] for s in L]
        else:
            e, *M = L
            # Either e is a singleton block...
            for p in helper(M, k - 1):
                yield [[e], *p]
            # ...or e joins one of the blocks of a k-partition of the rest.
            for p in helper(M, k):
                for i in range(len(p)):
                    yield p[:i] + [[e] + p[i]] + p[i + 1 :]

    def size_ok(partition):
        return all(min_size <= len(block) <= max_size for block in partition)

    # With no k, generate partitions for every possible number of blocks.
    group_counts = range(1, n + 1) if k is None else (k,)
    for num_groups in group_counts:
        yield from filter(size_ok, helper(items, num_groups))

3458 

3459 

class time_limited:
    """
    Yield items from *iterable* until *limit_seconds* have passed.
    If the time limit expires before all items have been yielded, the
    ``timed_out`` parameter will be set to ``True``.

    >>> from time import sleep
    >>> def generator():
    ...     yield 1
    ...     yield 2
    ...     sleep(0.2)
    ...     yield 3
    >>> iterable = time_limited(0.1, generator())
    >>> list(iterable)
    [1, 2]
    >>> iterable.timed_out
    True

    Note that the time is checked before each item is yielded, and iteration
    stops if the time elapsed is greater than *limit_seconds*. If your time
    limit is 1 second, but it takes 2 seconds to generate the first item from
    the iterable, the function will run for 2 seconds and not yield anything.
    As a special case, when *limit_seconds* is zero, the iterator never
    returns anything.

    """

    def __init__(self, limit_seconds, iterable):
        if limit_seconds < 0:
            raise ValueError('limit_seconds must be positive')
        self.limit_seconds = limit_seconds
        self._iterator = iter(iterable)
        # The clock starts at construction time, not at first iteration.
        self._start_time = monotonic()
        self.timed_out = False

    def __iter__(self):
        return self

    def __next__(self):
        # A zero-second limit never yields anything.
        if self.limit_seconds == 0:
            self.timed_out = True
            raise StopIteration

        item = next(self._iterator)
        elapsed = monotonic() - self._start_time
        if elapsed > self.limit_seconds:
            self.timed_out = True
            raise StopIteration

        return item

3508 

3509 

def only(iterable, default=None, too_long=None):
    """If *iterable* has only one item, return it.
    If it has zero items, return *default*.
    If it has more than one item, raise the exception given by *too_long*,
    which is ``ValueError`` by default.

    >>> only([], default='missing')
    'missing'
    >>> only([1])
    1
    >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    ValueError: Expected exactly one item in iterable, but got 1, 2,
    and perhaps more.'
    >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    TypeError

    Note that :func:`only` attempts to advance *iterable* twice to ensure there
    is only one item. See :func:`spy` or :func:`peekable` to check
    iterable contents less destructively.

    """
    iterator = iter(iterable)

    try:
        first = next(iterator)
    except StopIteration:
        return default

    try:
        second = next(iterator)
    except StopIteration:
        return first

    msg = (
        f'Expected exactly one item in iterable, but got {first!r}, '
        f'{second!r}, and perhaps more.'
    )
    raise too_long or ValueError(msg)

3545 

3546 

3547def _ichunk(iterator, n): 

3548 cache = deque() 

3549 chunk = islice(iterator, n) 

3550 

3551 def generator(): 

3552 with suppress(StopIteration): 

3553 while True: 

3554 if cache: 

3555 yield cache.popleft() 

3556 else: 

3557 yield next(chunk) 

3558 

3559 def materialize_next(n=1): 

3560 # if n not specified materialize everything 

3561 if n is None: 

3562 cache.extend(chunk) 

3563 return len(cache) 

3564 

3565 to_cache = n - len(cache) 

3566 

3567 # materialize up to n 

3568 if to_cache > 0: 

3569 cache.extend(islice(chunk, to_cache)) 

3570 

3571 # return number materialized up to n 

3572 return min(n, len(cache)) 

3573 

3574 return (generator(), materialize_next) 

3575 

3576 

def ichunked(iterable, n):
    """Break *iterable* into sub-iterables with *n* elements each.
    :func:`ichunked` is like :func:`chunked`, but it yields iterables
    instead of lists.

    If the sub-iterables are read in order, the elements of *iterable*
    won't be stored in memory.
    If they are read out of order, :func:`itertools.tee` is used to cache
    elements as necessary.

    >>> from itertools import count
    >>> all_chunks = ichunked(count(), 4)
    >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
    >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
    [4, 5, 6, 7]
    >>> list(c_1)
    [0, 1, 2, 3]
    >>> list(c_3)
    [8, 9, 10, 11]

    """
    source = iter(iterable)
    while True:
        chunk, materialize = _ichunk(source, n)

        # An empty chunk means the source is exhausted; stop.
        if materialize() == 0:
            return

        yield chunk

        # Buffer whatever remains of this chunk so the next chunk starts at
        # the right place even if the caller never reads this one.
        materialize(None)

3611 

3612 

def iequals(*iterables):
    """Return ``True`` if all given *iterables* are equal to each other,
    which means that they contain the same elements in the same order.

    The function is useful for comparing iterables of different data types
    or iterables that do not support equality checks.

    >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
    True

    >>> iequals("abc", "acb")
    False

    Not to be confused with :func:`all_equal`, which checks whether all
    elements of iterable are equal to each other.

    """

    def row_is_uniform(row):
        # True when every element in *row* is equal (at most one group).
        groups = groupby(row)
        return next(groups, True) and not next(groups, False)

    # Pad with a unique sentinel so length mismatches produce unequal rows.
    padded_rows = zip_longest(*iterables, fillvalue=object())
    return all(map(row_is_uniform, padded_rows))

3631 

3632 

def distinct_combinations(iterable, r):
    """Yield the distinct combinations of *r* items taken from *iterable*.

    >>> list(distinct_combinations([0, 0, 1], 2))
    [(0, 0), (0, 1)]

    Equivalent to ``set(combinations(iterable))``, except duplicates are not
    generated and thrown away. For larger input sequences this is much more
    efficient.

    """
    if r < 0:
        raise ValueError('r must be non-negative')
    elif r == 0:
        # The empty combination is the only combination of size zero.
        yield ()
        return
    pool = tuple(iterable)
    # Depth-first search over a stack of generators, one per combination
    # slot.  Each generator yields (index, item) pairs de-duplicated by
    # item, so duplicate branches are pruned before being explored.
    generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
    current_combo = [None] * r
    level = 0
    while generators:
        try:
            cur_idx, p = next(generators[-1])
        except StopIteration:
            # Current slot exhausted; backtrack to the previous slot.
            generators.pop()
            level -= 1
            continue
        current_combo[level] = p
        if level + 1 == r:
            yield tuple(current_combo)
        else:
            # Descend: candidates for the next slot come strictly after
            # cur_idx in the pool, again de-duplicated by value.
            generators.append(
                unique_everseen(
                    enumerate(pool[cur_idx + 1 :], cur_idx + 1),
                    key=itemgetter(1),
                )
            )
            level += 1

3671 

3672 

def filter_except(validator, iterable, *exceptions):
    """Yield the items from *iterable* for which the *validator* function does
    not raise one of the specified *exceptions*.

    *validator* is called for each item in *iterable*.
    It should be a function that accepts one argument and raises an exception
    if that item is not valid.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(filter_except(int, iterable, ValueError, TypeError))
    ['1', '2', '4']

    If an exception other than one given by *exceptions* is raised by
    *validator*, it is raised like normal.
    """
    for candidate in iterable:
        try:
            validator(candidate)
        except exceptions:
            # Invalid item: skip it silently.
            continue
        yield candidate

3695 

3696 

def map_except(function, iterable, *exceptions):
    """Transform each item from *iterable* with *function* and yield the
    result, unless *function* raises one of the specified *exceptions*.

    *function* is called to transform each item in *iterable*.
    It should accept one argument.

    >>> iterable = ['1', '2', 'three', '4', None]
    >>> list(map_except(int, iterable, ValueError, TypeError))
    [1, 2, 4]

    If an exception other than one given by *exceptions* is raised by
    *function*, it is raised like normal.
    """
    for value in iterable:
        # Items whose transformation raises a listed exception are skipped.
        with suppress(*exceptions):
            yield function(value)

3716 

3717 

def map_if(iterable, pred, func, func_else=None):
    """Evaluate each item from *iterable* using *pred*. If the result is
    equivalent to ``True``, transform the item with *func* and yield it.
    Otherwise, transform the item with *func_else* and yield it.

    *pred*, *func*, and *func_else* should each be functions that accept
    one argument. By default, *func_else* is the identity function.

    >>> from math import sqrt
    >>> iterable = list(range(-5, 5))
    >>> iterable
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
    >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig'))
    [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig']
    >>> list(map_if(iterable, lambda x: x >= 0,
    ... lambda x: f'{sqrt(x):.2f}', lambda x: None))
    [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00']
    """
    # Defaulting the "else" branch to identity lets both cases share one
    # loop body.
    if func_else is None:
        func_else = lambda item: item

    for item in iterable:
        yield func(item) if pred(item) else func_else(item)

3744 

3745 

def _sample_unweighted(iterator, k, strict):
    # Algorithm L in the 1994 paper by Kim-Hung Li:
    # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".

    # Seed the reservoir with the first k items.
    reservoir = list(islice(iterator, k))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')
    # W is the algorithm's running key threshold; it shrinks geometrically,
    # which makes the skip distances below grow as the stream progresses.
    W = 1.0

    with suppress(StopIteration):
        while True:
            W *= random() ** (1 / k)
            # Number of items to skip before the next replacement; drawn
            # from a geometric distribution per the paper.
            skip = floor(log(random()) / log1p(-W))
            element = next(islice(iterator, skip, None))
            # Replace a uniformly chosen reservoir slot.
            reservoir[randrange(k)] = element

    # Shuffle so callers don't see a position-biased ordering.
    shuffle(reservoir)
    return reservoir

3764 

3765 

def _sample_weighted(iterator, k, weights, strict):
    # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
    # "Weighted random sampling with a reservoir".

    # Log-transform for numerical stability for weights that are small/large
    weight_keys = (log(random()) / weight for weight in weights)

    # Fill up the reservoir (collection of samples) with the first `k`
    # weight-keys and elements, then heapify the list.
    reservoir = take(k, zip(weight_keys, iterator))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Min-heap ordered by weight-key; reservoir[0] is the weakest entry.
    heapify(reservoir)

    # The number of jumps before changing the reservoir is a random variable
    # with an exponential distribution. Sample it using random() and logs.
    smallest_weight_key, _ = reservoir[0]
    weights_to_skip = log(random()) / smallest_weight_key

    for weight, element in zip(weights, iterator):
        if weight >= weights_to_skip:
            # The notation here is consistent with the paper, but we store
            # the weight-keys in log-space for better numerical stability.
            smallest_weight_key, _ = reservoir[0]
            t_w = exp(weight * smallest_weight_key)
            r_2 = uniform(t_w, 1)  # generate U(t_w, 1)
            weight_key = log(r_2) / weight
            # Evict the weakest entry and re-draw the jump distance.
            heapreplace(reservoir, (weight_key, element))
            smallest_weight_key, _ = reservoir[0]
            weights_to_skip = log(random()) / smallest_weight_key
        else:
            weights_to_skip -= weight

    # Drop the keys and shuffle so ordering carries no information.
    ret = [element for weight_key, element in reservoir]
    shuffle(ret)
    return ret

3803 

3804 

def _sample_counted(population, k, counts, strict):
    # Reservoir sampling (Algorithm L) over a population given as
    # (element, count) streams, without expanding the repeats.
    element = None
    remaining = 0

    def feed(i):
        # Advance *i* steps ahead and consume an element
        nonlocal element, remaining

        # Walk forward through the counted runs until the target position
        # falls inside the current element's run.
        while i + 1 > remaining:
            i = i - remaining
            element = next(population)
            remaining = next(counts)
        remaining -= i + 1
        return element

    # Seed the reservoir with the first k (virtual) items.
    with suppress(StopIteration):
        reservoir = []
        for _ in range(k):
            reservoir.append(feed(0))

    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    with suppress(StopIteration):
        # W and skip follow Algorithm L, as in _sample_unweighted.
        W = 1.0
        while True:
            W *= random() ** (1 / k)
            skip = floor(log(random()) / log1p(-W))
            element = feed(skip)
            reservoir[randrange(k)] = element

    # Shuffle so callers don't see a position-biased ordering.
    shuffle(reservoir)
    return reservoir

3838 

3839 

def sample(iterable, k, weights=None, *, counts=None, strict=False):
    """Return a *k*-length list of elements chosen (without replacement)
    from the *iterable*. Similar to :func:`random.sample`, but works on
    iterables of unknown length.

    >>> iterable = range(100)
    >>> sample(iterable, 5)  # doctest: +SKIP
    [81, 60, 96, 16, 4]

    For iterables with repeated elements, you may supply *counts* to
    indicate the repeats.

    >>> iterable = ['a', 'b']
    >>> counts = [3, 4]  # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
    >>> sample(iterable, k=3, counts=counts)  # doctest: +SKIP
    ['a', 'a', 'b']

    An iterable with *weights* may be given:

    >>> iterable = range(100)
    >>> weights = (i * i + 1 for i in range(100))
    >>> sampled = sample(iterable, 5, weights=weights)  # doctest: +SKIP
    [79, 67, 74, 66, 78]

    Weighted selections are made without replacement.
    After an element is selected, it is removed from the pool and the
    relative weights of the other elements increase (this
    does not match the behavior of :func:`random.sample`'s *counts*
    parameter). Note that *weights* may not be used with *counts*.

    If the length of *iterable* is less than *k*,
    ``ValueError`` is raised if *strict* is ``True`` and
    all elements are returned (in shuffled order) if *strict* is ``False``.

    By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
    technique is used. When *weights* are provided,
    `Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used.
    """
    iterator = iter(iterable)

    # Validate arguments before any sampling work starts.
    if k < 0:
        raise ValueError('k must be non-negative')

    if k == 0:
        return []

    if weights is not None and counts is not None:
        raise TypeError('weights and counts are mutually exclusive')

    # Dispatch to the appropriate reservoir-sampling strategy.
    if weights is not None:
        return _sample_weighted(iterator, k, iter(weights), strict)

    if counts is not None:
        return _sample_counted(iterator, k, iter(counts), strict)

    return _sample_unweighted(iterator, k, strict)

3899 

3900 

def is_sorted(iterable, key=None, reverse=False, strict=False):
    """Returns ``True`` if the items of iterable are in sorted order, and
    ``False`` otherwise. *key* and *reverse* have the same meaning that they do
    in the built-in :func:`sorted` function.

    >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
    True
    >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
    False

    If *strict*, tests for strict sorting, that is, returns ``False`` if equal
    elements are found:

    >>> is_sorted([1, 2, 2])
    True
    >>> is_sorted([1, 2, 2], strict=True)
    False

    The function returns ``False`` after encountering the first out-of-order
    item, which means it may produce results that differ from the built-in
    :func:`sorted` function for objects with unusual comparison dynamics
    (like ``math.nan``). If there are no out-of-order items, the iterable is
    exhausted.
    """
    keyed = iterable if key is None else map(key, iterable)

    # Pair each item with its successor.
    current, successor = tee(keyed)
    next(successor, None)
    if reverse:
        # Descending order is ascending order with the pairs swapped.
        current, successor = successor, current

    if strict:
        # Every item must be strictly less than its successor.
        return all(map(lt, current, successor))
    # No item may be strictly greater than its successor.
    return not any(map(lt, successor, current))

3931 

3932 

class AbortThread(BaseException):
    """Raised inside a :class:`callback_iter` background thread to stop it.

    Derives from ``BaseException`` rather than ``Exception`` so that it is
    not swallowed by ``except Exception`` handlers in the wrapped function.
    """

    pass

3935 

3936 

class callback_iter:
    """Convert a function that uses callbacks to an iterator.

    Let *func* be a function that takes a `callback` keyword argument.
    For example:

    >>> def func(callback=None):
    ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
    ...         if callback:
    ...             callback(i, c)
    ...     return 4


    Use ``with callback_iter(func)`` to get an iterator over the parameters
    that are delivered to the callback.

    >>> with callback_iter(func) as it:
    ...     for args, kwargs in it:
    ...         print(args)
    (1, 'a')
    (2, 'b')
    (3, 'c')

    The function will be called in a background thread. The ``done`` property
    indicates whether it has completed execution.

    >>> it.done
    True

    If it completes successfully, its return value will be available
    in the ``result`` property.

    >>> it.result
    4

    Notes:

    * If the function uses some keyword argument besides ``callback``, supply
      *callback_kwd*.
    * If it finished executing, but raised an exception, accessing the
      ``result`` property will raise the same exception.
    * If it hasn't finished executing, accessing the ``result``
      property from within the ``with`` block will raise ``RuntimeError``.
    * If it hasn't finished executing, accessing the ``result`` property from
      outside the ``with`` block will raise a
      ``more_itertools.AbortThread`` exception.
    * Provide *wait_seconds* to adjust how frequently the it is polled for
      output.

    """

    def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
        # func: the callback-style function to run in a worker thread.
        self._func = func
        # Name of the keyword argument through which *func* receives the
        # callback (defaults to 'callback').
        self._callback_kwd = callback_kwd
        # Set by __exit__; checked by the callback so the worker aborts.
        self._aborted = False
        # Future for the running function; None until _reader first runs.
        self._future = None
        # Polling interval used while waiting for callback output.
        self._wait_seconds = wait_seconds
        # Lazily import concurrent.future
        self._executor = __import__(
            'concurrent.futures'
        ).futures.ThreadPoolExecutor(max_workers=1)
        # The generator that drives the worker and yields callback payloads.
        self._iterator = self._reader()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Flag the abort first so the callback raises AbortThread on its
        # next invocation, then wait for the worker thread to wind down.
        self._aborted = True
        self._executor.shutdown()

    def __iter__(self):
        return self

    def __next__(self):
        # Delegate to the _reader generator created in __init__.
        return next(self._iterator)

    @property
    def done(self):
        """Whether the wrapped function has finished executing.

        False until iteration has started (no future exists yet).
        """
        if self._future is None:
            return False
        return self._future.done()

    @property
    def result(self):
        """Return value of the wrapped function.

        Raises ``RuntimeError`` if it hasn't completed; re-raises the
        function's own exception if it failed.
        """
        if not self.done:
            raise RuntimeError('Function has not yet completed')

        return self._future.result()

    def _reader(self):
        # Queue through which the callback hands (args, kwargs) tuples from
        # the worker thread to this (consumer) thread.
        q = Queue()

        def callback(*args, **kwargs):
            # Runs in the worker thread: abort if requested, otherwise
            # forward the callback parameters to the consumer.
            if self._aborted:
                raise AbortThread('canceled by user')

            q.put((args, kwargs))

        # Launch the function; submission is deferred until the first
        # next() on this generator.
        self._future = self._executor.submit(
            self._func, **{self._callback_kwd: callback}
        )

        # Poll the queue, yielding items as they arrive, until the
        # function completes.  The timeout keeps us checking done-ness.
        while True:
            try:
                item = q.get(timeout=self._wait_seconds)
            except Empty:
                pass
            else:
                q.task_done()
                yield item

            if self._future.done():
                break

        # The function finished: drain anything still queued, then yield it.
        remaining = []
        while True:
            try:
                item = q.get_nowait()
            except Empty:
                break
            else:
                q.task_done()
                remaining.append(item)
        q.join()
        yield from remaining

4062 

4063 

def windowed_complete(iterable, n):
    """Yield ``(beginning, middle, end)`` tuples for every window of size
    *n* over *iterable*:

    * ``middle`` holds *n* consecutive items
    * ``beginning`` holds everything before ``middle``
    * ``end`` holds everything after ``middle``

    >>> for parts in windowed_complete(range(7), 3):
    ...     print(*parts)
    () (0, 1, 2) (3, 4, 5, 6)
    (0,) (1, 2, 3) (4, 5, 6)
    (0, 1) (2, 3, 4) (5, 6)
    (0, 1, 2) (3, 4, 5) (6,)
    (0, 1, 2, 3) (4, 5, 6) ()

    *n* must satisfy ``0 <= n <= len(iterable)``; otherwise ``ValueError``
    is raised.  The whole input is materialized up front, so this may
    require significant storage.
    """
    if n < 0:
        raise ValueError('n must be >= 0')

    seq = tuple(iterable)
    size = len(seq)
    if n > size:
        raise ValueError('n must be <= len(seq)')

    # Slide the middle window across all valid start positions.
    for split in range(size - n + 1):
        yield seq[:split], seq[split : split + n], seq[split + n :]

4102 

4103 

def all_unique(iterable, key=None):
    """Return ``True`` if no two elements of *iterable* are equal.

    >>> all_unique('ABCB')
    False

    When *key* is given, comparisons are made on ``key(element)``:

    >>> all_unique('ABCb')
    True
    >>> all_unique('ABCb', str.lower)
    False

    Stops at the first repeated element.  Hashable and unhashable items may
    be mixed; unhashable items fall back to a slower list-based check.
    """
    hashables = set()
    unhashables = []
    items = iterable if key is None else map(key, iterable)
    for item in items:
        try:
            if item in hashables:
                return False
            hashables.add(item)
        except TypeError:
            # Unhashable: linear scan through previously seen items.
            if item in unhashables:
                return False
            unhashables.append(item)
    return True

4137 

4138 

def nth_product(index, *args):
    """Equivalent to ``list(product(*args))[index]``.

    Computes the product at lexicographic sort position *index* directly,
    without generating the earlier products.

    >>> nth_product(8, range(2), range(2), range(2), range(2))
    (1, 0, 0, 0)

    A negative *index* counts from the end.  ``IndexError`` is raised for
    an out-of-range *index*.
    """
    # Work from the least-significant (rightmost) pool outward.
    pools = [tuple(pool) for pool in args][::-1]
    sizes = [len(pool) for pool in pools]

    total = reduce(mul, sizes)

    if index < 0:
        index += total

    if not 0 <= index < total:
        raise IndexError

    # Peel off one mixed-radix digit per pool.
    selections = []
    for pool, size in zip(pools, sizes):
        index, digit = divmod(index, size)
        selections.append(pool[digit])

    selections.reverse()
    return tuple(selections)

4168 

4169 

def nth_permutation(iterable, r, index):
    """Equivalent to ``list(permutations(iterable, r))[index]```

    Computes the length-*r* permutation at lexicographic sort position
    *index* directly, without generating the earlier permutations.

    >>> nth_permutation('ghijk', 2, 5)
    ('h', 'i')

    Raises ``ValueError`` if *r* is negative or exceeds the input length,
    and ``IndexError`` if *index* is out of range.
    """
    pool = list(iterable)
    n = len(pool)

    if r is None or r == n:
        r, total = n, factorial(n)
    elif not 0 <= r < n:
        raise ValueError
    else:
        total = perm(n, r)
        # perm(n, r) with 0 <= r < n is always positive.
        assert total > 0

    if index < 0:
        index += total

    if not 0 <= index < total:
        raise IndexError

    # Convert the index to factorial-base digits; only the leading r
    # digits select items, so scale partial permutations up first.
    digits = [0] * r
    q = index if r == n else index * factorial(n) // total
    for divisor in range(1, n + 1):
        q, remainder = divmod(q, divisor)
        if 0 <= n - divisor < r:
            digits[n - divisor] = remainder
        if q == 0:
            break

    # Each digit is a position in the shrinking pool.
    return tuple(map(pool.pop, digits))

4212 

4213 

def nth_combination_with_replacement(iterable, r, index):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r))[index]``.

    Computes the length-*r* combination with replacement at lexicographic
    sort position *index* directly, without generating the earlier ones.

    >>> nth_combination_with_replacement(range(5), 3, 5)
    (0, 1, 1)

    Raises ``ValueError`` if *r* is negative or exceeds the input length,
    and ``IndexError`` if *index* is out of range.
    """
    pool = tuple(iterable)
    n = len(pool)
    if not 0 <= r <= n:
        raise ValueError

    total = comb(n + r - 1, r)

    if index < 0:
        index += total

    if not 0 <= index < total:
        raise IndexError

    selections = []
    position = 0
    while r:
        r -= 1
        # Find the smallest item whose block of combinations contains
        # *index*, skipping whole blocks as we go.
        while n >= 0:
            block = comb(n + r - 1, r)
            if index < block:
                break
            n -= 1
            position += 1
            index -= block
        selections.append(pool[position])

    return tuple(selections)

4258 

4259 

def value_chain(*args):
    """Yield every argument in order, flattening one level of iteration.

    Iterable arguments contribute their items; non-iterable arguments
    (and strings/bytes, which are treated as atoms) are yielded as-is.

    >>> list(value_chain(1, 2, 3, [4, 5, 6]))
    [1, 2, 3, 4, 5, 6]
    >>> list(value_chain('12', '34', ['56', '78']))
    ['12', '34', '56', '78']
    >>> list(value_chain(1, [2, 3, 4, 5, 6]))
    [1, 2, 3, 4, 5, 6]

    Only a single level of nesting is flattened.
    """
    for arg in args:
        # Text and byte strings are deliberately not expanded.
        if isinstance(arg, (str, bytes)):
            yield arg
        else:
            try:
                yield from arg
            except TypeError:
                # Not iterable: emit the value itself.
                yield arg

4292 

4293 

def product_index(element, *args):
    """Equivalent to ``list(product(*args)).index(element)``

    Computes the lexicographic position of *element* among the products of
    *args* without generating the earlier products.

    >>> product_index([8, 2], range(10), range(5))
    42

    Raises ``ValueError`` if *element* is not a product of *args* (including
    a length mismatch).
    """
    index = 0

    # Accumulate mixed-radix digits; _marker flags a length mismatch.
    for item, pool in zip_longest(element, args, fillvalue=_marker):
        if item is _marker or pool is _marker:
            raise ValueError('element is not a product of args')

        pool = tuple(pool)
        index = index * len(pool) + pool.index(item)

    return index

4317 

4318 

def combination_index(element, iterable):
    """Equivalent to ``list(combinations(iterable, r)).index(element)``

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`combination_index` computes the index of the
    first *element*, without computing the previous combinations.

    >>> combination_index('adf', 'abcdefg')
    10

    ``ValueError`` will be raised if the given *element* isn't one of the
    combinations of *iterable*.
    """
    element = enumerate(element)
    k, y = next(element, (None, None))
    if k is None:
        # The empty combination is index 0 for any iterable.
        return 0

    # Walk the pool, recording the pool position of each item of *element*
    # in turn; ``k`` ends up holding the last element index, i.e. r - 1.
    indexes = []
    pool = enumerate(iterable)
    for n, x in pool:
        if x == y:
            indexes.append(n)
            tmp, y = next(element, (None, None))
            if tmp is None:
                # Every item of *element* was matched.
                break
            else:
                k = tmp
    else:
        # The pool was exhausted before all of *element* matched.
        raise ValueError('element is not a combination of iterable')

    # Drain the rest of the pool so ``n`` becomes the largest pool index
    # (len(iterable) - 1); ``last`` keeps the current ``n`` if nothing is
    # left.
    n, _ = last(pool, default=(n, None))

    # Count the combinations that sort at or after *element*; subtracting
    # that count from the total, comb(n + 1, k + 1), gives the index.
    index = 1
    for i, j in enumerate(reversed(indexes), start=1):
        j = n - j
        if i <= j:
            index += comb(j, i)

    return comb(n + 1, k + 1) - index

4360 

4361 

def combination_with_replacement_index(element, iterable):
    """Equivalent to
    ``list(combinations_with_replacement(iterable, r)).index(element)``

    Computes the lexicographic position of *element* among the length-*r*
    combinations with replacement of *iterable*, without generating the
    earlier combinations.

    >>> combination_with_replacement_index('adf', 'abcdefg')
    20

    Raises ``ValueError`` if *element* is not a combination with
    replacement of *iterable*.
    """
    element = tuple(element)
    length = len(element)
    if not element:
        # The empty combination is index 0 for any iterable.
        return 0

    sentinel = object()
    targets = iter(element)
    y = next(targets)

    # Record the pool position of each item of *element*, allowing
    # repeated matches at the same pool position.
    indexes = []
    pool = tuple(iterable)
    for n, x in enumerate(pool):
        while x == y:
            indexes.append(n)
            y = next(targets, sentinel)
            if y is sentinel:
                # Every item of *element* was matched.
                y = None
                break
        if y is None:
            break
    else:
        raise ValueError(
            'element is not a combination with replacement of iterable'
        )

    # Convert the multiset of matched positions to an index via a
    # stars-and-bars style count.
    pool_size = len(pool)
    occupancy = Counter(indexes)

    index = 0
    running_total = 0
    for k in range(1, pool_size):
        running_total += occupancy[k - 1]
        j = length + pool_size - 1 - k - running_total
        i = pool_size - k
        if i <= j:
            index += comb(j, i)

    return index

4417 

4418 

def permutation_index(element, iterable):
    """Equivalent to ``list(permutations(iterable, r)).index(element)```

    Computes the lexicographic position of *element* among the length-*r*
    permutations of *iterable*, without generating the earlier
    permutations.

    >>> permutation_index([1, 3, 2], range(5))
    19

    Raises ``ValueError`` if *element* is not a permutation of *iterable*.
    """
    index = 0
    pool = list(iterable)
    # Treat the element as digits in a shrinking mixed-radix system:
    # each chosen item's position in the remaining pool is one digit.
    for radix, value in zip(range(len(pool), -1, -1), element):
        position = pool.index(value)
        index = index * radix + position
        del pool[position]

    return index

4441 

4442 

class countable:
    """Wrap *iterable*, counting consumed items in ``items_seen``.

    The counter starts at ``0`` and goes up by one for each item yielded:

    >>> it = countable(map(str, range(10)))
    >>> it.items_seen
    0
    >>> next(it), next(it)
    ('0', '1')
    >>> list(it)
    ['2', '3', '4', '5', '6', '7', '8', '9']
    >>> it.items_seen
    10
    """

    def __init__(self, iterable):
        self._iterator = iter(iterable)
        self.items_seen = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Fetch first, so a StopIteration/error leaves the count untouched.
        next_item = next(self._iterator)
        self.items_seen += 1
        return next_item

4473 

4474 

def chunked_even(iterable, n):
    """Break *iterable* into lists of approximately length *n*.
    Items are distributed such the lengths of the lists differ by at most
    1 item.

    >>> iterable = [1, 2, 3, 4, 5, 6, 7]
    >>> n = 3
    >>> list(chunked_even(iterable, n))  # List lengths: 3, 2, 2
    [[1, 2, 3], [4, 5], [6, 7]]
    >>> list(chunked(iterable, n))  # List lengths: 3, 3, 1
    [[1, 2, 3], [4, 5, 6], [7]]

    """
    iterator = iter(iterable)

    # Initialize a buffer to process the chunks while keeping
    # some back to fill any underfilled chunks
    min_buffer = (n - 1) * (n - 2)
    buffer = list(islice(iterator, min_buffer))

    # Append items until we have a completed chunk.  The map() applies
    # buffer.append for its side effect; islice(..., n, None, n) ticks
    # once per n appended items, so each tick means a full chunk (beyond
    # the held-back minimum) is ready to emit.
    for _ in islice(map(buffer.append, iterator), n, None, n):
        yield buffer[:n]
        del buffer[:n]

    # Check if any leftover items need additional processing
    if not buffer:
        return
    length = len(buffer)

    # Chunks are either size `full_size <= n` or `partial_size = full_size - 1`
    q, r = divmod(length, n)
    num_lists = q + (1 if r > 0 else 0)
    q, r = divmod(length, num_lists)
    full_size = q + (1 if r > 0 else 0)
    partial_size = full_size - 1
    num_full = length - partial_size * num_lists

    # Yield chunks of full size
    partial_start_idx = num_full * full_size
    if full_size > 0:
        for i in range(0, partial_start_idx, full_size):
            yield buffer[i : i + full_size]

    # Yield chunks of partial size
    if partial_size > 0:
        for i in range(partial_start_idx, length, partial_size):
            yield buffer[i : i + partial_size]

4523 

4524 

def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
    """A version of :func:`zip` that "broadcasts" any scalar
    (i.e., non-iterable) items into output tuples.

    >>> iterable_1 = [1, 2, 3]
    >>> iterable_2 = ['a', 'b', 'c']
    >>> scalar = '_'
    >>> list(zip_broadcast(iterable_1, iterable_2, scalar))
    [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')]

    *scalar_types* controls which types count as scalar even though they
    are iterable; it defaults to ``(str, bytes)``.  Pass ``None`` to treat
    strings and byte strings as iterables:

    >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None))
    [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')]

    With ``strict=True``, ``UnequalIterablesError`` is raised if the
    iterable arguments have different lengths.
    """

    def _broadcastable(obj):
        # An object is zipped (not broadcast) only if it is iterable and
        # not one of the scalar types.
        if scalar_types and isinstance(obj, scalar_types):
            return False
        try:
            iter(obj)
        except TypeError:
            return False
        return True

    if not objects:
        return

    # Scalars are fixed into the template once; iterables get refreshed
    # into their positions on every output tuple.
    template = [None] * len(objects)
    iterators = []
    positions = []
    for position, obj in enumerate(objects):
        if _broadcastable(obj):
            iterators.append(iter(obj))
            positions.append(position)
        else:
            template[position] = obj

    if not iterators:
        # All arguments are scalar: a single output tuple.
        yield tuple(objects)
        return

    zipper = _zip_equal if strict else zip
    for values in zipper(*iterators):
        for position, value in zip(positions, values):
            template[position] = value
        yield tuple(template)

4579 

4580 

def unique_in_window(iterable, n, key=None):
    """Yield items from *iterable* that did not occur within the previous
    *n* items.

    >>> list(unique_in_window([0, 1, 0, 2, 3, 0], 3))
    [0, 1, 2, 3, 0]

    Uniqueness is decided on ``key(item)`` when *key* is given:

    >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower()))
    ['a', 'b', 'c', 'd', 'a']

    Items (or their keys) must be hashable.  Raises ``ValueError`` if *n*
    is not positive.
    """
    if n <= 0:
        raise ValueError('n must be greater than 0')

    window = deque(maxlen=n)
    counts = defaultdict(int)

    for item in iterable:
        # Retire the key about to fall out of the lookback window.
        if len(window) == n:
            expiring = window[0]
            if counts[expiring] == 1:
                del counts[expiring]
            else:
                counts[expiring] -= 1

        k = item if key is None else key(item)
        if k not in counts:
            yield item
        counts[k] += 1
        window.append(k)

4618 

4619 

def duplicates_everseen(iterable, key=None):
    """Yield each element every time it appears after its first occurrence.

    >>> list(duplicates_everseen('mississippi'))
    ['s', 'i', 's', 's', 'i', 'p', 'i']
    >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']

    The complement of :func:`unique_everseen`; like it, unhashable keys
    fall back to a slower list-based membership check.
    """
    hashable_seen = set()
    unhashable_seen = []

    for element in iterable:
        k = element if key is None else key(element)
        try:
            if k in hashable_seen:
                yield element
            else:
                hashable_seen.add(k)
        except TypeError:
            if k in unhashable_seen:
                yield element
            else:
                unhashable_seen.append(k)

4648 

4649 

def duplicates_justseen(iterable, key=None):
    """Yield each element of every run of equal items, except the first.

    >>> list(duplicates_justseen('mississippi'))
    ['s', 's', 'p']
    >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
    ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']

    The complement of :func:`unique_justseen`.
    """
    # For each run produced by groupby, skip its first member and emit
    # the rest.
    return chain.from_iterable(
        islice(group, 1, None) for _, group in groupby(iterable, key)
    )

4662 

4663 

def classify_unique(iterable, key=None):
    """For each element, yield ``(element, justseen_unique, everseen_unique)``:

    1. The element itself
    2. ``False`` if it equals the immediately preceding element,
       ``True`` otherwise (cf. :func:`unique_justseen`)
    3. ``False`` if it occurred anywhere earlier in the input,
       ``True`` otherwise (cf. :func:`unique_everseen`)

    >>> list(classify_unique('otto'))    # doctest: +NORMALIZE_WHITESPACE
    [('o', True,  True),
     ('t', True,  True),
     ('t', False, False),
     ('o', True,  False)]

    Unhashable keys fall back to a slower list-based membership check.
    """
    hashable_seen = set()
    unhashable_seen = []
    previous = None

    for i, element in enumerate(iterable):
        k = element if key is None else key(element)
        # The very first element is always new relative to its predecessor.
        unique_justseen = not i or previous != k
        previous = k

        unique_everseen = False
        try:
            if k not in hashable_seen:
                hashable_seen.add(k)
                unique_everseen = True
        except TypeError:
            if k not in unhashable_seen:
                unhashable_seen.append(k)
                unique_everseen = True

        yield element, unique_justseen, unique_everseen

4704 

4705 

def minmax(iterable_or_value, *others, key=None, default=_marker):
    """Returns both the smallest and largest items from an iterable
    or from two or more arguments.

    >>> minmax([3, 1, 5])
    (1, 5)

    >>> minmax(4, 2, 6)
    (2, 6)

    If a *key* function is provided, it will be used to transform the input
    items for comparison.

    >>> minmax([5, 30], key=str)  # '30' sorts before '5'
    (30, 5)

    If a *default* value is provided, it will be returned if there are no
    input items.

    >>> minmax([], default=(0, 0))
    (0, 0)

    Otherwise ``ValueError`` is raised.

    This function makes a single pass over the input elements and takes care to
    minimize the number of comparisons made during processing.

    Note that unlike the builtin ``max`` function, which always returns the first
    item with the maximum value, this function may return another item when there are
    ties.

    This function is based on the
    `recipe <https://code.activestate.com/recipes/577916-fast-minmax-function>`__ by
    Raymond Hettinger.
    """
    # Multiple positional arguments are treated as the items themselves.
    iterable = (iterable_or_value, *others) if others else iterable_or_value

    it = iter(iterable)

    try:
        lo = hi = next(it)
    except StopIteration as exc:
        if default is _marker:
            raise ValueError(
                '`minmax()` argument is an empty iterable. '
                'Provide a `default` value to suppress this error.'
            ) from exc
        return default

    # Different branches depending on the presence of key. This saves a lot
    # of unimportant copies which would slow the "key=None" branch
    # significantly down.
    if key is None:
        # zip_longest(it, it) draws items in pairs from the same iterator;
        # padding an odd-length tail with ``lo`` is harmless.  Ordering each
        # pair first means only 3 comparisons per 2 items instead of 4.
        for x, y in zip_longest(it, it, fillvalue=lo):
            if y < x:
                x, y = y, x
            if x < lo:
                lo = x
            if hi < y:
                hi = y

    else:
        lo_key = hi_key = key(lo)

        for x, y in zip_longest(it, it, fillvalue=lo):
            # key() is invoked exactly once per item.
            x_key, y_key = key(x), key(y)

            if y_key < x_key:
                x, y, x_key, y_key = y, x, y_key, x_key
            if x_key < lo_key:
                lo, lo_key = x, x_key
            if hi_key < y_key:
                hi, hi_key = y, y_key

    return lo, hi

4781 

4782 

def constrained_batches(
    iterable, max_size, max_count=None, get_len=len, strict=True
):
    """Yield batches of items from *iterable* whose combined size does not
    exceed *max_size*.

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10))
    [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]

    *max_count*, when supplied, also caps the number of items per batch:

    >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
    >>> list(constrained_batches(iterable, 10, max_count = 2))
    [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]

    *get_len* replaces :func:`len` for measuring item size.  With
    ``strict=True`` (the default) a single item larger than *max_size*
    raises ``ValueError``; otherwise such items get a batch of their own.
    """
    if max_size <= 0:
        raise ValueError('maximum size must be greater than zero')

    current = []
    current_size = 0

    for item in iterable:
        size = get_len(item)
        if strict and size > max_size:
            raise ValueError('item size exceeds maximum size')

        # Flush the batch before adding an item that would overflow it.
        if current and (
            current_size + size > max_size or len(current) == max_count
        ):
            yield tuple(current)
            current = []
            current_size = 0

        current.append(item)
        current_size += size

    if current:
        yield tuple(current)

4831 

4832 

def gray_product(*iterables):
    """Like :func:`itertools.product`, but return tuples in an order such
    that only one element in the generated tuple changes from one iteration
    to the next.

    >>> list(gray_product('AB','CD'))
    [('A', 'C'), ('B', 'C'), ('B', 'D'), ('A', 'D')]

    This function consumes all of the input iterables before producing output.
    If any of the input iterables have fewer than two items, ``ValueError``
    is raised.

    For information on the algorithm, see
    `this section <https://www-cs-faculty.stanford.edu/~knuth/fasc2a.ps.gz>`__
    of Donald Knuth's *The Art of Computer Programming*.
    """
    all_iterables = tuple(tuple(x) for x in iterables)
    iterable_count = len(all_iterables)
    # The algorithm requires every coordinate to have at least two values;
    # with fewer, a "change one element at a time" order cannot exist.
    for iterable in all_iterables:
        if len(iterable) < 2:
            raise ValueError("each iterable must have two or more items")

    # This is based on "Algorithm H" from section 7.2.1.1, page 20.
    # a holds the indexes of the source iterables for the n-tuple to be yielded
    # f is the array of "focus pointers"
    # o is the array of "directions"
    a = [0] * iterable_count
    f = list(range(iterable_count + 1))
    o = [1] * iterable_count
    while True:
        yield tuple(all_iterables[i][a[i]] for i in range(iterable_count))
        # f[0] names the coordinate to step next; iterable_count means done.
        j = f[0]
        f[0] = 0
        if j == iterable_count:
            break
        # Step coordinate j in its current direction.
        a[j] = a[j] + o[j]
        # At either end of coordinate j: reverse its direction and pass the
        # "focus" on so a different coordinate moves next time.
        if a[j] == 0 or a[j] == len(all_iterables[j]) - 1:
            o[j] = -o[j]
            f[j] = f[j + 1]
            f[j + 1] = j + 1

4873 

4874 

def partial_product(*iterables):
    """Yield tuples with one item from each iterable, changing a single
    position at a time by advancing each iterator in turn until it is
    exhausted.  Every input value appears in at least one output tuple,
    without generating the full Cartesian product.

    >>> list(partial_product('AB', 'C', 'DEF'))
    [('A', 'C', 'D'), ('B', 'C', 'D'), ('B', 'C', 'E'), ('B', 'C', 'F')]

    Useful, for example, to exercise an expensive function cheaply.
    If any iterable is empty, nothing is yielded.
    """
    iterators = [iter(it) for it in iterables]

    # Seed the working tuple with the first item of every iterator.
    current = []
    for iterator in iterators:
        try:
            current.append(next(iterator))
        except StopIteration:
            return
    yield tuple(current)

    # Advance one position at a time, left to right.
    for position, iterator in enumerate(iterators):
        for value in iterator:
            current[position] = value
            yield tuple(current)

4898 

4899 

def takewhile_inclusive(predicate, iterable):
    """Like :func:`takewhile`, but also yield the first failing element.

    >>> list(takewhile_inclusive(lambda x: x < 5, [1, 4, 6, 4, 1]))
    [1, 4, 6]

    (:func:`takewhile` would stop at ``[1, 4]``.)
    """
    for item in iterable:
        # Yield unconditionally; stop *after* the first failing item.
        yield item
        if not predicate(item):
            return

4912 

4913 

def outer_product(func, xs, ys, *args, **kwargs):
    """A generalized outer product: apply binary *func* to every pair from
    *xs* and *ys*, producing a 2D matrix with ``len(xs)`` rows and
    ``len(ys)`` columns.  Extra ``*args``/``**kwargs`` are forwarded to
    *func*.

    Multiplication table:

    >>> list(outer_product(mul, range(1, 4), range(1, 6)))
    [(1, 2, 3, 4, 5), (2, 4, 6, 8, 10), (3, 6, 9, 12, 15)]

    Cross tabulation:

    >>> xs = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B', 'B']
    >>> ys = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z', 'Z']
    >>> pair_counts = Counter(zip(xs, ys))
    >>> count_rows = lambda x, y: pair_counts[x, y]
    >>> list(outer_product(count_rows, sorted(set(xs)), sorted(set(ys))))
    [(2, 3, 0), (1, 0, 4)]

    With forwarded arguments:

    >>> animals = ['cat', 'wolf', 'mouse']
    >>> list(outer_product(min, animals, animals, key=len))
    [('cat', 'cat', 'cat'), ('cat', 'wolf', 'wolf'), ('cat', 'wolf', 'mouse')]
    """
    # ys is materialized so it can be re-iterated per row and measured.
    ys = tuple(ys)
    cells = (func(x, y, *args, **kwargs) for x, y in product(xs, ys))
    # Group the flat stream of cells into rows of len(ys) columns.
    return batched(cells, n=len(ys))

4945 

4946 

def iter_suppress(iterable, *exceptions):
    """Yield the items of *iterable*; if iteration raises one of
    *exceptions*, suppress it and end the iteration cleanly.

    >>> from itertools import chain
    >>> def breaks_at_five(x):
    ...     while True:
    ...         if x >= 5:
    ...             raise RuntimeError
    ...         yield x
    ...         x += 1
    >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
    >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
    >>> list(chain(it_1, it_2))
    [1, 2, 3, 4, 2, 3, 4]
    """
    # contextlib.suppress swallows exactly the listed exception types and
    # lets everything else (including GeneratorExit) propagate.
    with suppress(*exceptions):
        yield from iterable

4968 

4969 

def filter_map(func, iterable):
    """Apply *func* to each element of *iterable*, yielding only the
    results that are not ``None``.

    >>> elems = ['1', 'a', '2', 'b', '3']
    >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
    [1, 2, 3]
    """
    for item in iterable:
        if (result := func(item)) is not None:
            yield result

4982 

4983 

def powerset_of_sets(iterable):
    """Yield every subset of *iterable* as a ``set``, deduplicating equal
    input items.

    >>> list(powerset_of_sets([1, 2, 3]))  # doctest: +SKIP
    [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}]
    >>> list(powerset_of_sets([1, 1, 0]))  # doctest: +SKIP
    [set(), {1}, {0}, {0, 1}]

    Hash operations are kept to a minimum by hashing each item only once.
    """
    # One frozenset per distinct item, in first-seen order.  zip() wraps
    # each item in a 1-tuple so items that are themselves iterable are
    # not flattened by the frozenset constructor.
    distinct = tuple(dict.fromkeys(map(frozenset, zip(iterable))))
    subsets_by_size = (
        starmap(set().union, combinations(distinct, size))
        for size in range(len(distinct) + 1)
    )
    return chain.from_iterable(subsets_by_size)

5000 

5001 

def join_mappings(**field_to_map):
    """Join several mappings into one nested dict keyed by their shared keys.

    >>> user_scores = {'elliot': 50, 'claris': 60}
    >>> user_times = {'elliot': 30, 'claris': 40}
    >>> join_mappings(score=user_scores, time=user_times)
    {'elliot': {'score': 50, 'time': 30}, 'claris': {'score': 60, 'time': 40}}
    """
    joined = {}
    for field_name, mapping in field_to_map.items():
        for key, value in mapping.items():
            # Each outer key collects one entry per field it appears in.
            joined.setdefault(key, {})[field_name] = value
    return joined

5018 

5019 

def _complex_sumprod(v1, v2):
    """High precision sumprod() for complex numbers.
    Used by :func:`dft` and :func:`idft`.
    """
    # Split each vector into real and imaginary parts, then express the
    # complex dot product as two real high-precision sumprods:
    #   Re = sum(re1*re2) - sum(im1*im2)
    #   Im = sum(re1*im2) + sum(im1*re2)
    re1 = [z.real for z in v1]
    im1 = [z.imag for z in v1]
    re2 = [z.real for z in v2]
    im2 = [z.imag for z in v2]
    real_part = _fsumprod(re1 + [-y for y in im1], re2 + im2)
    imag_part = _fsumprod(re1 + im1, im2 + re2)
    return complex(real_part, imag_part)

5032 

5033 

def dft(xarr):
    """Discrete Fourier Transform. *xarr* is a sequence of complex numbers.
    Yields the components of the corresponding transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]     # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]     # frequency domain
    >>> magnitudes, phases = zip(*map(cmath.polar, Xarr))
    >>> all(map(cmath.isclose, dft(xarr), Xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`idft` for the inverse Discrete Fourier Transform.
    """
    N = len(xarr)
    # All N complex roots of unity, precomputed once; the exponent for
    # term (k, n) is k*n reduced modulo N.
    unity = [e ** (n / N * tau * -1j) for n in range(N)]
    for k in range(N):
        yield _complex_sumprod(xarr, [unity[k * n % N] for n in range(N)])

5056 

5057 

def idft(Xarr):
    """Inverse Discrete Fourier Transform. *Xarr* is a sequence of
    complex numbers. Yields the components of the corresponding
    inverse-transformed output vector.

    >>> import cmath
    >>> xarr = [1, 2-1j, -1j, -1+2j]     # time domain
    >>> Xarr = [2, -2-2j, -2j, 4+4j]     # frequency domain
    >>> all(map(cmath.isclose, idft(Xarr), xarr))
    True

    Inputs are restricted to numeric types that can add and multiply
    with a complex number. This includes int, float, complex, and
    Fraction, but excludes Decimal.

    See :func:`dft` for the Discrete Fourier Transform.
    """
    N = len(Xarr)
    # Same roots of unity as dft() but with a positive exponent; the
    # inverse transform divides each component by N.
    unity = [e ** (n / N * tau * 1j) for n in range(N)]
    for k in range(N):
        yield _complex_sumprod(Xarr, [unity[k * n % N] for n in range(N)]) / N

5080 

5081 

def doublestarmap(func, iterable):
    """Apply *func* to every item of *iterable* by dictionary unpacking
    the item into *func*.

    The difference between :func:`itertools.starmap` and :func:`doublestarmap`
    parallels the distinction between ``func(*a)`` and ``func(**a)``.

    >>> iterable = [{'a': 1, 'b': 2}, {'a': 40, 'b': 60}]
    >>> list(doublestarmap(lambda a, b: a + b, iterable))
    [3, 100]

    ``TypeError`` will be raised if *func*'s signature doesn't match the
    mapping contained in *iterable* or if *iterable* does not contain mappings.
    """
    # Each item is expected to be a mapping; ** unpacks it into keyword
    # arguments for func.
    for kwargs in iterable:
        yield func(**kwargs)

5098 

5099 

5100def _nth_prime_bounds(n): 

5101 """Bounds for the nth prime (counting from 1): lb <= p_n <= ub.""" 

5102 # At and above 688,383, the lb/ub spread is under 0.003 * n. 

5103 

5104 if n < 1: 

5105 raise ValueError 

5106 

5107 if n < 6: 

5108 return (n, 2.25 * n) 

5109 

5110 # https://en.wikipedia.org/wiki/Prime-counting_function#Inequalities 

5111 upper_bound = n * log(n * log(n)) 

5112 lower_bound = upper_bound - n 

5113 if n >= 688_383: 

5114 upper_bound -= n * (1.0 - (log(log(n)) - 2.0) / log(n)) 

5115 

5116 return lower_bound, upper_bound 

5117 

5118 

def nth_prime(n, *, approximate=False):
    """Return the nth prime (counting from 0).

    >>> nth_prime(0)
    2
    >>> nth_prime(100)
    547

    If *approximate* is set to True, will return a prime close to the
    nth prime. The estimation is much faster than computing an exact
    result.

    >>> nth_prime(200_000_000, approximate=True)  # Exact result is 4222234763
    4217820427

    """
    lower, upper = _nth_prime_bounds(n + 1)

    # Exact answer for small n (or whenever requested): sieve up to the
    # upper bound and index into the resulting primes.
    exact = not approximate or n <= 1_000_000
    if exact:
        return nth(sieve(ceil(upper)), n)

    # Approximation: start at the odd number nearest the midpoint of the
    # bounds and walk upward by 2 until a prime is found.
    midpoint_odd = floor((lower + upper) / 2) | 1
    return first_true(count(midpoint_odd, step=2), pred=is_prime)

5143 

5144 

def argmin(iterable, *, key=None):
    """
    Index of the first occurrence of a minimum value in an iterable.

    >>> argmin('efghabcdijkl')
    4
    >>> argmin([3, 2, 1, 0, 4, 2, 1, 0])
    3

    For example, look up a label corresponding to the position
    of a value that minimizes a cost function::

        >>> def cost(x):
        ...     "Days for a wound to heal given a subject's age."
        ...     return x**2 - 20*x + 150
        ...
        >>> labels = ['homer', 'marge', 'bart', 'lisa', 'maggie']
        >>> ages = [ 35, 30, 10, 9, 1 ]

        # Fastest healing family member
        >>> labels[argmin(ages, key=cost)]
        'bart'

        # Age with fastest healing
        >>> min(ages, key=cost)
        10

    """
    # Transform values through key (if given), pair each with its index,
    # and let min() pick the first pair with the smallest value.
    values = iterable if key is None else map(key, iterable)
    best_index, _ = min(enumerate(values), key=lambda pair: pair[1])
    return best_index

5176 

5177 

def argmax(iterable, *, key=None):
    """
    Index of the first occurrence of a maximum value in an iterable.

    >>> argmax('abcdefghabcd')
    7
    >>> argmax([0, 1, 2, 3, 3, 2, 1, 0])
    3

    For example, identify the best machine learning model::

        >>> models = ['svm', 'random forest', 'knn', 'naïve bayes']
        >>> accuracy = [ 68, 61, 84, 72 ]

        # Most accurate model
        >>> models[argmax(accuracy)]
        'knn'

        # Best accuracy
        >>> max(accuracy)
        84

    """
    # Transform values through key (if given), pair each with its index,
    # and let max() pick the first pair with the largest value.
    values = iterable if key is None else map(key, iterable)
    best_index, _ = max(enumerate(values), key=lambda pair: pair[1])
    return best_index