Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/pandas/core/sorting.py: 47%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

245 statements  

1""" miscellaneous sorting / groupby utilities """ 

2from __future__ import annotations 

3 

4from collections import defaultdict 

5from typing import ( 

6 TYPE_CHECKING, 

7 Callable, 

8 DefaultDict, 

9 cast, 

10) 

11 

12import numpy as np 

13 

14from pandas._libs import ( 

15 algos, 

16 hashtable, 

17 lib, 

18) 

19from pandas._libs.hashtable import unique_label_indices 

20 

21from pandas.core.dtypes.common import ( 

22 ensure_int64, 

23 ensure_platform_int, 

24) 

25from pandas.core.dtypes.generic import ( 

26 ABCMultiIndex, 

27 ABCRangeIndex, 

28) 

29from pandas.core.dtypes.missing import isna 

30 

31from pandas.core.construction import extract_array 

32 

33if TYPE_CHECKING: 

34 from collections.abc import ( 

35 Hashable, 

36 Iterable, 

37 Sequence, 

38 ) 

39 

40 from pandas._typing import ( 

41 ArrayLike, 

42 AxisInt, 

43 IndexKeyFunc, 

44 Level, 

45 NaPosition, 

46 Shape, 

47 SortKind, 

48 npt, 

49 ) 

50 

51 from pandas import ( 

52 MultiIndex, 

53 Series, 

54 ) 

55 from pandas.core.arrays import ExtensionArray 

56 from pandas.core.indexes.base import Index 

57 

58 

def get_indexer_indexer(
    target: Index,
    level: Level | list[Level] | None,
    ascending: list[bool] | bool,
    kind: SortKind,
    na_position: NaPosition,
    sort_remaining: bool,
    key: IndexKeyFunc,
) -> npt.NDArray[np.intp] | None:
    """
    Compute the indexer used by DataFrame/Series ``sort_index``.

    Parameters
    ----------
    target : Index
    level : int or level name or list of ints or list of level names
    ascending : bool or list of bools, default True
    kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}
    na_position : {'first', 'last'}
    sort_remaining : bool
    key : callable, optional

    Returns
    -------
    Optional[ndarray[intp]]
        The indexer for the new index, or None when no reordering is needed.
    """
    # Apply the user key (and canonicalize MultiIndex levels) up front.
    # error: Incompatible types in assignment (expression has type
    # "Union[ExtensionArray, ndarray[Any, Any], Index, Series]", variable has
    # type "Index")
    target = ensure_key_mapped(target, key, levels=level)  # type: ignore[assignment]
    target = target._sort_levels_monotonic()

    if level is not None:
        _, indexer = target.sortlevel(
            level,
            ascending=ascending,
            sort_remaining=sort_remaining,
            na_position=na_position,
        )
        return indexer

    # Check monotonic-ness before sorting the index (GH 11080): if the index
    # is already in the requested order there is nothing to do.
    if (np.all(ascending) and target.is_monotonic_increasing) or (
        not np.any(ascending) and target.is_monotonic_decreasing
    ):
        return None

    if isinstance(target, ABCMultiIndex):
        codes = [lev.codes for lev in target._get_codes_for_sorting()]
        return lexsort_indexer(
            codes, orders=ascending, na_position=na_position, codes_given=True
        )

    # ascending can only be a Sequence for MultiIndex
    return nargsort(
        target,
        kind=kind,
        ascending=cast(bool, ascending),
        na_position=na_position,
    )

120 

121 

def get_group_index(
    labels, shape: Shape, sort: bool, xnull: bool
) -> npt.NDArray[np.int64]:
    """
    Map each tuple of level codes to an offset into the (hypothetical)
    cartesian product of all possible label combinations.

    As long as that product space fits within int64 bounds the returned ids
    can be deconstructed back into per-level codes; otherwise ids still
    uniquely identify label combinations but are not deconstructable.

    - If `sort`, the rank of the returned ids preserves the lexical rank of
      the labels, i.e. the ids can be used to lexically sort the labels.
    - If `xnull`, nulls (-1 codes) are passed through.

    Parameters
    ----------
    labels : sequence of arrays
        Integers identifying levels at each location
    shape : tuple[int, ...]
        Number of unique levels at each location
    sort : bool
        If the ranks of returned ids should match lexical ranks of labels
    xnull : bool
        If true nulls are excluded. i.e. -1 values in the labels are
        passed through.

    Returns
    -------
    An array of type int64 where two elements are equal if their corresponding
    labels are equal at all location.

    Notes
    -----
    The length of `labels` and `shape` must be identical.
    """

    def _int64_cut_off(sizes) -> int:
        # Largest prefix of levels whose product of sizes stays below i8max.
        running = 1
        for pos, nlevels in enumerate(sizes):
            running *= int(nlevels)
            if running >= lib.i8max:
                return pos
        return len(sizes)

    def _maybe_lift(codes, size: int) -> tuple[np.ndarray, int]:
        # Promote NaN codes (-1) by one so every output value is non-negative.
        if (codes == -1).any():
            return codes + 1, size + 1
        return codes, size

    labels = [ensure_int64(x) for x in labels]
    sizes = list(shape)
    if not xnull:
        for i, (codes, size) in enumerate(zip(labels, shape)):
            labels[i], sizes[i] = _maybe_lift(codes, size)

    labels = list(labels)

    # Process the levels in chunks sized so that each chunk needs fewer than
    # lib.i8max distinct int ids, compressing intermediate results as we go.
    while True:
        # how many levels can be done without overflow:
        nlev = _int64_cut_off(sizes)

        # compute flat ids for the first `nlev` levels
        stride = np.prod(sizes[1:nlev], dtype="i8")
        out = stride * labels[0].astype("i8", subok=False, copy=False)

        for i in range(1, nlev):
            if sizes[i] == 0:
                stride = np.int64(0)
            else:
                stride //= sizes[i]
            out += labels[i] * stride

        if xnull:
            # Exclude nulls: any -1 in a contributing level poisons the id.
            mask = labels[0] == -1
            for codes in labels[1:nlev]:
                mask |= codes == -1
            out[mask] = -1

        if nlev == len(sizes):
            # all levels consumed
            break

        # Compress what has been done so far to avoid overflow; obs_ids come
        # out sorted so lexical ranks are retained.
        comp_ids, obs_ids = compress_group_index(out, sort=sort)

        labels = [comp_ids] + labels[nlev:]
        sizes = [len(obs_ids)] + sizes[nlev:]

    return out

212 

213 

def get_compressed_ids(
    labels, sizes: Shape
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.int64]]:
    """
    Compress offsets into the cartesian product of all possible labels.

    The full product space can be huge; this computes small offsets
    (comp_ids) into the list of observed unique labels (obs_group_ids).

    Parameters
    ----------
    labels : list of label arrays
    sizes : tuple[int] of size of the levels

    Returns
    -------
    np.ndarray[np.intp]
        comp_ids
    np.ndarray[np.int64]
        obs_group_ids
    """
    group_ids = get_group_index(labels, sizes, sort=True, xnull=False)
    return compress_group_index(group_ids, sort=True)

236 

237 

def is_int64_overflow_possible(shape: Shape) -> bool:
    """
    Return True when the cartesian product of `shape` could exceed int64.

    Computed in Python ints so the product itself cannot overflow.
    """
    total = 1
    for dim in shape:
        total *= int(dim)
    return total >= lib.i8max

244 

245 

246def _decons_group_index( 

247 comp_labels: npt.NDArray[np.intp], shape: Shape 

248) -> list[npt.NDArray[np.intp]]: 

249 # reconstruct labels 

250 if is_int64_overflow_possible(shape): 

251 # at some point group indices are factorized, 

252 # and may not be deconstructed here! wrong path! 

253 raise ValueError("cannot deconstruct factorized group indices!") 

254 

255 label_list = [] 

256 factor = 1 

257 y = np.array(0) 

258 x = comp_labels 

259 for i in reversed(range(len(shape))): 

260 labels = (x - y) % (factor * shape[i]) // factor 

261 np.putmask(labels, comp_labels < 0, -1) 

262 label_list.append(labels) 

263 y = labels * factor 

264 factor *= shape[i] 

265 return label_list[::-1] 

266 

267 

def decons_obs_group_ids(
    comp_ids: npt.NDArray[np.intp],
    obs_ids: npt.NDArray[np.intp],
    shape: Shape,
    labels: Sequence[npt.NDArray[np.signedinteger]],
    xnull: bool,
) -> list[npt.NDArray[np.intp]]:
    """
    Reconstruct labels from observed group ids.

    Parameters
    ----------
    comp_ids : np.ndarray[np.intp]
    obs_ids: np.ndarray[np.intp]
    shape : tuple[int]
    labels : Sequence[np.ndarray[np.signedinteger]]
    xnull : bool
        If nulls are excluded; i.e. -1 labels are passed through.
    """
    if not xnull:
        # Levels containing -1 were lifted by one when ids were built;
        # widen the shape accordingly so deconstruction matches.
        lift = np.fromiter(((a == -1).any() for a in labels), dtype=np.intp)
        shape = tuple(np.asarray(shape, dtype=np.intp) + lift)

    if not is_int64_overflow_possible(shape):
        # obs ids are deconstructable! take the fast route!
        out = _decons_group_index(obs_ids, shape)
        if xnull or not lift.any():
            return out
        # undo the lift applied above
        return [x - y for x, y in zip(out, lift)]

    # Slow path: pick one representative row per observed group id.
    indexer = unique_label_indices(comp_ids)
    return [lab[indexer].astype(np.intp, subok=False, copy=True) for lab in labels]

299 

300 

def lexsort_indexer(
    keys: Sequence[ArrayLike | Index | Series],
    orders=None,
    na_position: str = "last",
    key: Callable | None = None,
    codes_given: bool = False,
) -> npt.NDArray[np.intp]:
    """
    Perform a lexical sort over a set of keys.

    Parameters
    ----------
    keys : Sequence[ArrayLike | Index | Series]
        Sequence of arrays to be sorted by the indexer
        Sequence[Series] is only if key is not None.
    orders : bool or list of booleans, optional
        Determines the sorting order for each element in keys. If a list,
        it must be the same length as keys. This determines whether the
        corresponding element in keys should be sorted in ascending
        (True) or descending (False) order. if bool, applied to all
        elements as above. if None, defaults to True.
    na_position : {'first', 'last'}, default 'last'
        Determines placement of NA elements in the sorted list ("last" or "first")
    key : Callable, optional
        Callable key function applied to every element in keys before sorting
    codes_given: bool, False
        Avoid categorical materialization if codes are already provided.

    Returns
    -------
    np.ndarray[np.intp]
    """
    from pandas.core.arrays import Categorical

    if na_position not in ["last", "first"]:
        raise ValueError(f"invalid na_position: {na_position}")

    # broadcast a scalar/missing `orders` to one flag per key
    if orders is None:
        orders = [True] * len(keys)
    elif isinstance(orders, bool):
        orders = [orders] * len(keys)

    encoded = []

    for values, ascending in zip(keys, orders):
        values = ensure_key_mapped(values, key)
        if codes_given:
            codes = cast(np.ndarray, values)
            ncats = codes.max() + 1 if len(codes) else 0
        else:
            # factorize through an ordered Categorical to get sortable codes
            cat = Categorical(values, ordered=True)
            codes = cat.codes
            ncats = len(cat.categories)

        na_mask = codes == -1

        if na_position == "last" and na_mask.any():
            # move NA codes past the largest category so they sort last
            codes = np.where(na_mask, ncats, codes)

        if not ascending:
            # descending: flip non-NA codes, leaving NA codes in place
            codes = np.where(na_mask, codes, ncats - codes - 1)

        encoded.append(codes)

    # np.lexsort treats the LAST key as primary, hence the reversal
    return np.lexsort(encoded[::-1])

367 

368 

def nargsort(
    items: ArrayLike | Index | Series,
    kind: SortKind = "quicksort",
    ascending: bool = True,
    na_position: str = "last",
    key: Callable | None = None,
    mask: npt.NDArray[np.bool_] | None = None,
) -> npt.NDArray[np.intp]:
    """
    Drop-in replacement for np.argsort that handles NaNs.

    Adds ascending, na_position, and key parameters.

    (GH #6399, #5231, #27237)

    Parameters
    ----------
    items : np.ndarray, ExtensionArray, Index, or Series
    kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, default 'quicksort'
    ascending : bool, default True
    na_position : {'first', 'last'}, default 'last'
    key : Optional[Callable], default None
    mask : Optional[np.ndarray[bool]], default None
        Passed when called by ExtensionArray.argsort.

    Returns
    -------
    np.ndarray[np.intp]
    """
    if key is not None:
        # apply the key once, then recurse without it
        # see TestDataFrameSortKey, TestRangeIndex::test_sort_values_key
        items = ensure_key_mapped(items, key)
        return nargsort(
            items,
            kind=kind,
            ascending=ascending,
            na_position=na_position,
            key=None,
            mask=mask,
        )

    if isinstance(items, ABCRangeIndex):
        # RangeIndex can compute its own argsort cheaply
        return items.argsort(ascending=ascending)
    if isinstance(items, ABCMultiIndex):
        raise TypeError(
            "nargsort does not support MultiIndex. Use index.sort_values instead."
        )
    items = extract_array(items)

    if mask is None:
        mask = np.asarray(isna(items))

    if not isinstance(items, np.ndarray):
        # i.e. ExtensionArray — delegate to its own argsort
        return items.argsort(
            ascending=ascending,
            kind=kind,
            na_position=na_position,
        )

    positions = np.arange(len(items))
    valid_values = items[~mask]
    valid_positions = positions[~mask]
    na_positions = np.nonzero(mask)[0]

    if not ascending:
        # reverse first so the subsequent stable sort keeps original order
        # among equal elements after the final flip
        valid_values = valid_values[::-1]
        valid_positions = valid_positions[::-1]
    indexer = valid_positions[valid_values.argsort(kind=kind)]
    if not ascending:
        indexer = indexer[::-1]

    # Place the NaNs at the end or the beginning per na_position
    if na_position == "last":
        indexer = np.concatenate([indexer, na_positions])
    elif na_position == "first":
        indexer = np.concatenate([na_positions, indexer])
    else:
        raise ValueError(f"invalid na_position: {na_position}")
    return ensure_platform_int(indexer)

451 

452 

def nargminmax(values: ExtensionArray, method: str, axis: AxisInt = 0):
    """
    np.argmin/argmax for ExtensionArray, with missing-value handling.

    Parameters
    ----------
    values : ExtensionArray
    method : {"argmax", "argmin"}
    axis : int, default 0

    Returns
    -------
    int
    """
    assert method in {"argmax", "argmin"}
    func = {"argmax": np.argmax, "argmin": np.argmin}[method]

    na_mask = np.asarray(isna(values))
    arr_values = values._values_for_argsort()

    if arr_values.ndim > 1:
        if na_mask.any():
            # reduce row-by-row (or column-by-column), skipping NAs
            if axis == 1:
                pairs = zip(arr_values, na_mask)
            else:
                pairs = zip(arr_values.T, na_mask.T)
            return np.array([_nanargminmax(v, m, func) for v, m in pairs])
        # no NAs: plain numpy reduction suffices
        return func(arr_values, axis=axis)

    return _nanargminmax(arr_values, na_mask, func)

484 

485 

486def _nanargminmax(values: np.ndarray, mask: npt.NDArray[np.bool_], func) -> int: 

487 """ 

488 See nanargminmax.__doc__. 

489 """ 

490 idx = np.arange(values.shape[0]) 

491 non_nans = values[~mask] 

492 non_nan_idx = idx[~mask] 

493 

494 return non_nan_idx[func(non_nans)] 

495 

496 

497def _ensure_key_mapped_multiindex( 

498 index: MultiIndex, key: Callable, level=None 

499) -> MultiIndex: 

500 """ 

501 Returns a new MultiIndex in which key has been applied 

502 to all levels specified in level (or all levels if level 

503 is None). Used for key sorting for MultiIndex. 

504 

505 Parameters 

506 ---------- 

507 index : MultiIndex 

508 Index to which to apply the key function on the 

509 specified levels. 

510 key : Callable 

511 Function that takes an Index and returns an Index of 

512 the same shape. This key is applied to each level 

513 separately. The name of the level can be used to 

514 distinguish different levels for application. 

515 level : list-like, int or str, default None 

516 Level or list of levels to apply the key function to. 

517 If None, key function is applied to all levels. Other 

518 levels are left unchanged. 

519 

520 Returns 

521 ------- 

522 labels : MultiIndex 

523 Resulting MultiIndex with modified levels. 

524 """ 

525 

526 if level is not None: 

527 if isinstance(level, (str, int)): 

528 sort_levels = [level] 

529 else: 

530 sort_levels = level 

531 

532 sort_levels = [index._get_level_number(lev) for lev in sort_levels] 

533 else: 

534 sort_levels = list(range(index.nlevels)) # satisfies mypy 

535 

536 mapped = [ 

537 ensure_key_mapped(index._get_level_values(level), key) 

538 if level in sort_levels 

539 else index._get_level_values(level) 

540 for level in range(index.nlevels) 

541 ] 

542 

543 return type(index).from_arrays(mapped) 

544 

545 

546def ensure_key_mapped( 

547 values: ArrayLike | Index | Series, key: Callable | None, levels=None 

548) -> ArrayLike | Index | Series: 

549 """ 

550 Applies a callable key function to the values function and checks 

551 that the resulting value has the same shape. Can be called on Index 

552 subclasses, Series, DataFrames, or ndarrays. 

553 

554 Parameters 

555 ---------- 

556 values : Series, DataFrame, Index subclass, or ndarray 

557 key : Optional[Callable], key to be called on the values array 

558 levels : Optional[List], if values is a MultiIndex, list of levels to 

559 apply the key to. 

560 """ 

561 from pandas.core.indexes.api import Index 

562 

563 if not key: 

564 return values 

565 

566 if isinstance(values, ABCMultiIndex): 

567 return _ensure_key_mapped_multiindex(values, key, level=levels) 

568 

569 result = key(values.copy()) 

570 if len(result) != len(values): 

571 raise ValueError( 

572 "User-provided `key` function must not change the shape of the array." 

573 ) 

574 

575 try: 

576 if isinstance( 

577 values, Index 

578 ): # convert to a new Index subclass, not necessarily the same 

579 result = Index(result) 

580 else: 

581 # try to revert to original type otherwise 

582 type_of_values = type(values) 

583 # error: Too many arguments for "ExtensionArray" 

584 result = type_of_values(result) # type: ignore[call-arg] 

585 except TypeError: 

586 raise TypeError( 

587 f"User-provided `key` function returned an invalid type {type(result)} \ 

588 which could not be converted to {type(values)}." 

589 ) 

590 

591 return result 

592 

593 

def get_flattened_list(
    comp_ids: npt.NDArray[np.intp],
    ngroups: int,
    levels: Iterable[Index],
    labels: Iterable[np.ndarray],
) -> list[tuple]:
    """Map compressed group id -> key tuple."""
    comp_ids = comp_ids.astype(np.int64, copy=False)
    rows: DefaultDict[int, list[int]] = defaultdict(list)
    # For each level, recover which level value each group id corresponds to
    # via a hashtable mapping comp_id -> level code.
    for level_codes, level_values in zip(labels, levels):
        table = hashtable.Int64HashTable(ngroups)
        table.map_keys_to_values(comp_ids, level_codes.astype(np.int64, copy=False))
        for gid in range(ngroups):
            rows[gid].append(level_values[table.get_item(gid)])
    return [tuple(row) for row in rows.values()]

609 

610 

def get_indexer_dict(
    label_list: list[np.ndarray], keys: list[Index]
) -> dict[Hashable, npt.NDArray[np.intp]]:
    """
    Returns
    -------
    dict:
        Labels mapped to indexers.
    """
    shape = tuple(len(x) for x in keys)

    group_index = get_group_index(label_list, shape, sort=True, xnull=True)
    if np.all(group_index == -1):
        # Short-circuit, lib.indices_fast will return the same
        return {}

    if is_int64_overflow_possible(shape):
        # ids were factorized; the observed max bounds the group count
        ngroups = (group_index.size and group_index.max()) + 1
    else:
        ngroups = np.prod(shape, dtype="i8")

    sorter = get_group_index_sorter(group_index, ngroups)

    sorted_labels = [lab.take(sorter) for lab in label_list]
    group_index = group_index.take(sorter)

    return lib.indices_fast(sorter, group_index, keys, sorted_labels)

638 

639 

640# ---------------------------------------------------------------------- 

641# sorting levels...cleverly? 

642 

643 

def get_group_index_sorter(
    group_index: npt.NDArray[np.intp], ngroups: int | None = None
) -> npt.NDArray[np.intp]:
    """
    Stable sort of `group_index`, choosing between two algorithms.

    algos.groupsort_indexer is a counting sort, at least O(ngroups) where
        ngroups = prod(shape)
        shape = map(len, keys)
    i.e. linear in the number of combinations (cartesian product) of unique
    values of groupby keys — potentially huge for multi-key groupby.
    np.argsort(kind='mergesort') is O(count x log(count)) where count is the
    length of the data-frame.
    Both algorithms are stable, which is necessary for correctness of groupby
    operations, e.g. df.groupby(key)[col].transform('first').

    Parameters
    ----------
    group_index : np.ndarray[np.intp]
        signed integer dtype
    ngroups : int or None, default None

    Returns
    -------
    np.ndarray[np.intp]
    """
    if ngroups is None:
        ngroups = 1 + group_index.max()
    count = len(group_index)
    # taking complexities literally; there may be some room for
    # fine-tuning these parameters
    alpha = 0.0
    beta = 1.0
    use_counting_sort = count > 0 and (
        (alpha + beta * ngroups) < (count * np.log(count))
    )
    if use_counting_sort:
        sorter, _ = algos.groupsort_indexer(
            ensure_platform_int(group_index),
            ngroups,
        )
        # sorter _should_ already be intp, but mypy is not yet able to verify
    else:
        sorter = group_index.argsort(kind="mergesort")
    return ensure_platform_int(sorter)

685 

686 

def compress_group_index(
    group_index: npt.NDArray[np.int64], sort: bool = True
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]:
    """
    Compress offsets into the cartesian product of all possible labels.

    The product space can be huge, so this computes offsets (comp_ids) into
    the list of observed unique labels (obs_group_ids).
    """
    if len(group_index) and np.all(group_index[1:] >= group_index[:-1]):
        # GH 53806: fast path for sorted group_index — uniques are found by
        # comparing adjacent elements; the leading -1 run (if any) is not a
        # unique value
        unique_mask = np.concatenate(
            [group_index[:1] > -1, group_index[1:] != group_index[:-1]]
        )
        comp_ids = unique_mask.cumsum()
        comp_ids -= 1
        obs_group_ids = group_index[unique_mask]
    else:
        table = hashtable.Int64HashTable(len(group_index))

        group_index = ensure_int64(group_index)

        # note, group labels come out ascending (ie, 1,2,3 etc)
        comp_ids, obs_group_ids = table.get_labels_groupby(group_index)

        if sort and len(obs_group_ids) > 0:
            obs_group_ids, comp_ids = _reorder_by_uniques(obs_group_ids, comp_ids)

    return ensure_int64(comp_ids), ensure_int64(obs_group_ids)

716 

717 

718def _reorder_by_uniques( 

719 uniques: npt.NDArray[np.int64], labels: npt.NDArray[np.intp] 

720) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.intp]]: 

721 """ 

722 Parameters 

723 ---------- 

724 uniques : np.ndarray[np.int64] 

725 labels : np.ndarray[np.intp] 

726 

727 Returns 

728 ------- 

729 np.ndarray[np.int64] 

730 np.ndarray[np.intp] 

731 """ 

732 # sorter is index where elements ought to go 

733 sorter = uniques.argsort() 

734 

735 # reverse_indexer is where elements came from 

736 reverse_indexer = np.empty(len(sorter), dtype=np.intp) 

737 reverse_indexer.put(sorter, np.arange(len(sorter))) 

738 

739 mask = labels < 0 

740 

741 # move labels to right locations (ie, unsort ascending labels) 

742 labels = reverse_indexer.take(labels) 

743 np.putmask(labels, mask, -1) 

744 

745 # sort observed ids 

746 uniques = uniques.take(sorter) 

747 

748 return uniques, labels