1import itertools
2from operator import itemgetter
3from typing import Dict, Iterable, Optional
4
5from .._typing import T_bbox, T_num, T_obj, T_obj_list
6from .clustering import cluster_objects
7
8
def objects_to_rect(objects: Iterable[T_obj]) -> Dict[str, T_num]:
    """
    Compute the tightest rectangle that encloses every object in
    `objects`.

    The objects are first reduced to a single bounding box, which is
    then converted into a dict with "x0", "top", "x1", and "bottom" keys.
    """
    enclosing_bbox = objects_to_bbox(objects)
    return bbox_to_rect(enclosing_bbox)
16
17
def objects_to_bbox(objects: Iterable[T_obj]) -> T_bbox:
    """
    Compute the tightest bounding box that encloses every object in
    `objects`.

    Each object must provide "x0", "top", "x1", and "bottom" keys.
    """
    return merge_bboxes(bbox_getter(obj) for obj in objects)
24
25
# Fetches an object's four bounding-box coordinates as a tuple, in
# (x0, top, x1, bottom) order.
bbox_getter = itemgetter(
    "x0",
    "top",
    "x1",
    "bottom",
)
27
28
def obj_to_bbox(obj: T_obj) -> T_bbox:
    """
    Extract the (x0, top, x1, bottom) bounding box of a single object.
    """
    # The annotated local gives the untyped itemgetter result a declared type.
    coords: T_bbox = bbox_getter(obj)
    return coords
35
36
def bbox_to_rect(bbox: T_bbox) -> Dict[str, T_num]:
    """
    Convert a bounding box into a rectangle, i.e. a dict with the keys
    "x0", "top", "x1", and "bottom" (taken from the bbox in that order).
    """
    rect_keys = ("x0", "top", "x1", "bottom")
    return {key: bbox[position] for position, key in enumerate(rect_keys)}
43
44
def merge_bboxes(bboxes: Iterable[T_bbox]) -> T_bbox:
    """
    Given an iterable of bounding boxes, return the smallest bounding box
    that contains them all.

    Raises ValueError if `bboxes` yields no boxes, since there is nothing
    to enclose.
    """
    # Materialize first: `bboxes` may be a one-shot iterator, and emptiness
    # has to be checked before the boxes are transposed into columns.
    boxes = list(bboxes)
    if not boxes:
        raise ValueError("Cannot merge an empty collection of bounding boxes.")
    x0, top, x1, bottom = zip(*boxes)
    return (min(x0), min(top), max(x1), max(bottom))
52
53
def get_bbox_overlap(a: T_bbox, b: T_bbox) -> Optional[T_bbox]:
    """
    Return the bounding box of the region shared by `a` and `b`, or None
    if they do not overlap.

    Boxes that only touch along an edge still count as overlapping (the
    result has zero width or zero height); boxes whose shared region has
    both zero width and zero height do not.
    """
    a_x0, a_top, a_x1, a_bottom = a
    b_x0, b_top, b_x1, b_bottom = b

    shared_x0 = max(a_x0, b_x0)
    shared_x1 = min(a_x1, b_x1)
    shared_bottom = min(a_bottom, b_bottom)
    shared_top = max(a_top, b_top)

    shared_width = shared_x1 - shared_x0
    shared_height = shared_bottom - shared_top

    if not (shared_height >= 0 and shared_width >= 0):
        return None
    # A zero-by-zero region (a single shared point) is not an overlap.
    if shared_height + shared_width > 0:
        return (shared_x0, shared_top, shared_x1, shared_bottom)
    return None
67
68
def calculate_area(bbox: T_bbox) -> T_num:
    """
    Return the area (width times height) of a bounding box.

    Raises ValueError if the box's right edge is left of its left edge,
    or its bottom edge is above its top edge.
    """
    x0, top, x1, bottom = bbox
    if x1 < x0 or bottom < top:
        raise ValueError(f"{bbox} has a negative width or height.")
    width = x1 - x0
    height = bottom - top
    return width * height
74
75
def clip_obj(obj: T_obj, bbox: T_bbox) -> Optional[T_obj]:
    """
    Return a dict copy of `obj` trimmed to the part that lies inside
    `bbox`, or None if the object does not overlap `bbox`.

    In the copy, "x0", "top", "x1", and "bottom" are set to the overlap,
    "width" and "height" are recomputed from them, and "doctop" (when
    present) is shifted by the same amount as "top". The input object
    itself is left untouched.
    """
    overlap = get_bbox_overlap(obj_to_bbox(obj), bbox)
    if overlap is None:
        return None

    clipped = dict(obj)
    clipped.update(bbox_to_rect(overlap))

    if "doctop" in clipped:
        # Keep the document-level offset in step with the trimmed top edge.
        top_shift = clipped["top"] - obj["top"]
        clipped["doctop"] = obj["doctop"] + top_shift

    clipped["width"] = clipped["x1"] - clipped["x0"]
    clipped["height"] = clipped["bottom"] - clipped["top"]
    return clipped
94
95
def intersects_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list:
    """
    Keep only the objects whose bounding box overlaps `bbox`.
    """

    def overlaps(obj: T_obj) -> bool:
        return get_bbox_overlap(obj_to_bbox(obj), bbox) is not None

    return list(filter(overlaps, objs))
101
102
def within_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list:
    """
    Keep only the objects that lie entirely inside `bbox`, i.e. those
    whose overlap with `bbox` equals their own bounding box.
    """

    def is_contained(obj: T_obj) -> bool:
        own_bbox = obj_to_bbox(obj)
        return get_bbox_overlap(own_bbox, bbox) == own_bbox

    return list(filter(is_contained, objs))
112
113
def outside_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list:
    """
    Keep only the objects that have no overlap with `bbox` at all.
    """

    def is_disjoint(obj: T_obj) -> bool:
        return get_bbox_overlap(obj_to_bbox(obj), bbox) is None

    return list(filter(is_disjoint, objs))
119
120
def crop_to_bbox(objs: Iterable[T_obj], bbox: T_bbox) -> T_obj_list:
    """
    Clip every object to `bbox` and return the clipped copies, dropping
    the objects that do not intersect `bbox`.
    """
    clipped_objs = (clip_obj(obj, bbox) for obj in objs)
    # clip_obj returns None for objects that miss the bbox entirely.
    return [clipped for clipped in clipped_objs if clipped]
127
128
def move_object(obj: T_obj, axis: str, value: T_num) -> T_obj:
    """
    Return a copy of `obj` shifted by `value` along `axis`.

    With axis "h", "x0" and "x1" are shifted. With axis "v", "top" and
    "bottom" are shifted, along with "doctop" if present; if "y0" is
    present, "y0" and "y1" are shifted in the opposite direction.

    The result is built with the same class as `obj`; `obj` itself is
    not modified.
    """
    assert axis in ("h", "v")
    if axis == "h":
        shifted = [(edge, obj[edge] + value) for edge in ("x0", "x1")]
    elif axis == "v":
        shifted = [(edge, obj[edge] + value) for edge in ("top", "bottom")]
        if "doctop" in obj:
            shifted.append(("doctop", obj["doctop"] + value))
        if "y0" in obj:
            # y0/y1 move against top/bottom, hence the subtraction.
            shifted.append(("y0", obj["y0"] - value))
            shifted.append(("y1", obj["y1"] - value))
    # Later pairs win, so the shifted values replace the originals.
    return obj.__class__((*obj.items(), *shifted))
149
150
def snap_objects(objs: Iterable[T_obj], attr: str, tolerance: T_num) -> T_obj_list:
    """
    Align objects that share (nearly) the same `attr` value.

    The objects are grouped by `cluster_objects`, keyed on `attr` and
    using `tolerance`; every object is then moved so that its `attr`
    equals the mean `attr` of its group. The moved copies are returned
    group by group.

    `attr` must be one of "x0", "x1", "top", or "bottom"; any other value
    raises KeyError.
    """
    axis = {"x0": "h", "x1": "h", "top": "v", "bottom": "v"}[attr]
    get_attr = itemgetter(attr)
    clusters = cluster_objects(list(objs), get_attr, tolerance)

    snapped = []
    for cluster in clusters:
        mean = sum(map(get_attr, cluster)) / len(cluster)
        snapped.extend(move_object(obj, axis, mean - obj[attr]) for obj in cluster)
    return snapped
161
162
def resize_object(obj: T_obj, key: str, value: T_num) -> T_obj:
    """
    Return a copy of `obj` with the edge named by `key` moved to `value`,
    keeping the opposite edge where it is.

    `key` must be one of "x0", "x1", "top", or "bottom", and `value` must
    not cross the opposite edge (both are checked with assertions).
    Moving "x0"/"x1" recomputes "width"; moving "top"/"bottom" adjusts
    "height". Moving "top" also shifts "doctop" and "y1" when the object
    has them; moving "bottom" also shifts "y0" when present.

    The result is built with the same class as `obj`; `obj` itself is
    not modified.
    """
    assert key in ("x0", "x1", "top", "bottom")
    old_value = obj[key]
    diff = value - old_value
    new_items = [
        (key, value),
    ]
    if key == "x0":
        assert value <= obj["x1"]
        new_items.append(("width", obj["x1"] - value))
    elif key == "x1":
        assert value >= obj["x0"]
        new_items.append(("width", value - obj["x0"]))
    elif key == "top":
        assert value <= obj["bottom"]
        # "doctop" is optional on objects (move_object and clip_obj also
        # guard for it), so only shift it when it is actually there.
        if "doctop" in obj:
            new_items.append(("doctop", obj["doctop"] + diff))
        new_items.append(("height", obj["height"] - diff))
        if "y1" in obj:
            # y1 moves against top, hence the subtraction.
            new_items.append(("y1", obj["y1"] - diff))
    elif key == "bottom":
        assert value >= obj["top"]
        new_items.append(("height", obj["height"] + diff))
        if "y0" in obj:
            # y0 moves against bottom, hence the subtraction.
            new_items.append(("y0", obj["y0"] - diff))
    # Later pairs win, so the new values replace the originals.
    return obj.__class__(tuple(obj.items()) + tuple(new_items))
188
189
def curve_to_edges(curve: T_obj) -> T_obj_list:
    """
    Split a curve into one "curve_edge" object per pair of consecutive
    points in curve["pts"].

    Each edge covers the extent of its two points. Its "orientation" is
    "v" when the points share an x value, "h" when they share a y value
    (and not an x value), and None otherwise. "doctop" is the edge's top
    offset by the curve's own doctop-minus-top difference.
    """
    points = curve["pts"]
    edges = []
    for start, end in zip(points, points[1:]):
        xs = (start[0], end[0])
        ys = (start[1], end[1])

        if xs[0] == xs[1]:
            orientation = "v"
        elif ys[0] == ys[1]:
            orientation = "h"
        else:
            orientation = None

        edge_top = min(ys)
        edges.append(
            {
                "object_type": "curve_edge",
                "x0": min(xs),
                "x1": max(xs),
                "top": edge_top,
                "doctop": edge_top + (curve["doctop"] - curve["top"]),
                "bottom": max(ys),
                "width": abs(xs[0] - xs[1]),
                "height": abs(ys[0] - ys[1]),
                "orientation": orientation,
            }
        )
    return edges
206
207
def rect_to_edges(rect: T_obj) -> T_obj_list:
    """
    Break a rectangle into its four sides, returned as "rect_edge"
    objects in the order top, bottom, left, right.

    Every edge starts as a copy of `rect` and overrides only the keys
    that differ: horizontal edges get zero height, vertical edges get
    zero width, and each edge is tagged with its "orientation".
    """
    top_edge = dict(
        rect,
        object_type="rect_edge",
        height=0,
        y0=rect["y1"],
        bottom=rect["top"],
        orientation="h",
    )
    bottom_edge = dict(
        rect,
        object_type="rect_edge",
        height=0,
        y1=rect["y0"],
        top=rect["top"] + rect["height"],
        doctop=rect["doctop"] + rect["height"],
        orientation="h",
    )
    left_edge = dict(
        rect,
        object_type="rect_edge",
        width=0,
        x1=rect["x0"],
        orientation="v",
    )
    right_edge = dict(
        rect,
        object_type="rect_edge",
        width=0,
        x0=rect["x1"],
        orientation="v",
    )
    return [top_edge, bottom_edge, left_edge, right_edge]
246
247
def line_to_edge(line: T_obj) -> T_obj:
    """
    Return a dict copy of `line` with an added "orientation" key: "h"
    when the line's top equals its bottom, "v" otherwise.
    """
    is_flat = line["top"] == line["bottom"]
    return dict(line, orientation="h" if is_flat else "v")
252
253
def obj_to_edges(obj: T_obj) -> T_obj_list:
    """
    Convert an object into a list of edges, based on its "object_type".

    Objects whose type already contains "_edge" are returned as-is in a
    one-element list; a "line" becomes a single edge; a "rect" or
    "curve" is split into its edges. Any other type raises KeyError.
    """
    kind = obj["object_type"]
    if "_edge" in kind:
        return [obj]
    if kind == "line":
        return [line_to_edge(obj)]
    converters = {"rect": rect_to_edges, "curve": curve_to_edges}
    return converters[kind](obj)
262
263
def filter_edges(
    edges: Iterable[T_obj],
    orientation: Optional[str] = None,
    edge_type: Optional[str] = None,
    min_length: T_num = 1,
) -> T_obj_list:
    """
    Keep only the edges that match the given criteria.

    `orientation` ("v" or "h") and `edge_type` (compared against each
    edge's "object_type") are only applied when not None. An edge must
    also be at least `min_length` long, where length is its "height" for
    "v" edges and its "width" for all others.

    Raises ValueError if `orientation` is not "v", "h", or None.
    """
    if orientation not in ("v", "h", None):
        raise ValueError("Orientation must be 'v' or 'h'")

    def keep(edge: T_obj) -> bool:
        edge_orientation = edge["orientation"]
        if edge_type is not None and not (edge["object_type"] == edge_type):
            return False
        if orientation is not None and not (edge_orientation == orientation):
            return False
        length_key = "height" if edge_orientation == "v" else "width"
        return bool(edge[length_key] >= min_length)

    return [edge for edge in edges if keep(edge)]