1import itertools
2from collections.abc import Hashable
3from operator import itemgetter
4from typing import Any, Callable, Dict, Iterable, List, Tuple, TypeVar, Union
5
6from .._typing import T_num, T_obj
7
8
def cluster_list(xs: List[T_num], tolerance: T_num = 0) -> List[List[T_num]]:
    """Group numbers into clusters of neighbouring values.

    The values are sorted ascending and walked in order. Each value joins the
    current cluster when it is at most ``tolerance`` greater than the value
    immediately before it; otherwise it starts a new cluster. Because the
    comparison is against the *previous* value rather than the cluster's
    first value, a single cluster may span more than ``tolerance`` overall
    (e.g. ``[1, 2, 3]`` with tolerance 1 is one cluster).

    Args:
        xs: The numbers to cluster. The input is not modified.
        tolerance: Maximum gap between consecutive sorted values that still
            keeps them in the same cluster. With the default of ``0`` only
            equal values end up together.

    Returns:
        The clusters in ascending order, each an ascending list of the input
        values (duplicates are kept). An empty input yields ``[]``.
    """
    ordered = sorted(xs)
    if not ordered:
        return []

    groups = []
    current_group = [ordered[0]]
    # One code path for every tolerance, so duplicates are grouped the same
    # way at tolerance 0 as at any positive tolerance (the former tolerance-0
    # shortcut put each duplicate in its own cluster).
    for prev, x in zip(ordered, ordered[1:]):
        if x <= prev + tolerance:
            current_group.append(x)
        else:
            groups.append(current_group)
            current_group = [x]
    groups.append(current_group)
    return groups
27
28
def make_cluster_dict(values: Iterable[T_num], tolerance: T_num) -> Dict[T_num, int]:
    """Map each distinct value to the index of the cluster containing it.

    The values are de-duplicated with a ``set`` (so hash-equal values such as
    ``1`` and ``1.0`` share one entry), clustered with ``cluster_list``, and
    each value is mapped to the 0-based position of its cluster in the list
    ``cluster_list`` returns.

    Args:
        values: The numbers to cluster; any iterable, consumed once.
        tolerance: Forwarded unchanged to ``cluster_list``.

    Returns:
        A dict from every distinct input value to its cluster index.
    """
    distinct = list(set(values))
    clusters = cluster_list(distinct, tolerance)
    return {
        member: cluster_index
        for cluster_index, members in enumerate(clusters)
        for member in members
    }
37
38
# Element type accepted by cluster_objects: either a T_obj or a tuple.
# Keys are read from these objects via key_fn, or via operator.itemgetter
# when key_fn is not callable.
Clusterable = TypeVar(
    "Clusterable",
    T_obj,
    Tuple[Any, ...],
)
40
41
def cluster_objects(
    xs: List[Clusterable],
    key_fn: Union[Hashable, Callable[[Clusterable], T_num]],
    tolerance: T_num,
    preserve_order: bool = False,
) -> List[List[Clusterable]]:
    """Group objects whose numeric keys fall into the same cluster.

    A key is computed for every object and the keys are clustered with
    ``make_cluster_dict``; objects are then grouped by their key's cluster.

    Args:
        xs: The objects to group. Any iterable is accepted; it is read once
            and materialized into a list.
        key_fn: Either a callable returning an object's numeric key, or a
            hashable item key (e.g. a dict key or tuple index) that is looked
            up on each object via ``operator.itemgetter``.
        tolerance: Forwarded to ``make_cluster_dict`` to control how close
            keys must be to share a cluster.
        preserve_order: If False (default), groups are ordered by cluster
            index (i.e. ascending key) and objects inside a group keep their
            input order. If True, the input order is kept as-is and each
            group is a run of *consecutive* objects in the same cluster, so
            a cluster whose members are not adjacent in ``xs`` is returned
            as several groups.

    Returns:
        A list of groups, each a non-empty list of the original objects.
    """
    if not callable(key_fn):
        # A non-callable key_fn names the item to read from each object.
        key_fn = itemgetter(key_fn)

    # The objects are traversed more than once below; materializing them
    # keeps a one-shot iterator from silently producing no groups.
    objects = list(xs)
    # Evaluate each key exactly once and reuse it for the cluster lookup.
    keys = [key_fn(obj) for obj in objects]
    cluster_dict = make_cluster_dict(keys, tolerance)

    cluster_of = itemgetter(1)
    tagged = [(obj, cluster_dict[key]) for obj, key in zip(objects, keys)]
    if not preserve_order:
        # list.sort is stable, so members of a cluster keep their input order.
        tagged.sort(key=cluster_of)

    return [
        [obj for obj, _ in run]
        for _, run in itertools.groupby(tagged, key=cluster_of)
    ]