1"""Confusion group resolution for similar single-byte encodings.
2
3At runtime, loads pre-computed distinguishing byte maps from confusion.bin
4and uses them to resolve statistical scoring ties between similar encodings.
5
6Build-time computation (``compute_confusion_groups``, ``compute_distinguishing_maps``,
7``serialize_confusion_data``) lives in ``scripts/confusion_training.py``.
8
9Note: ``from __future__ import annotations`` is intentionally omitted because
10this module is compiled with mypyc, which does not support PEP 563 string
11annotations.
12"""
13
14import functools
15import importlib.resources
16import struct
17import warnings
18
19from chardet.models import (
20 BigramProfile,
21 get_enc_index,
22 get_idf_weights,
23 score_with_profile,
24)
25from chardet.pipeline import DetectionResult
26from chardet.registry import lookup_encoding
27
# Type alias for the distinguishing map structure:
# Maps (enc_a, enc_b) -> (distinguishing_byte_set, {byte_val: (cat_a, cat_b)})
DistinguishingMaps = dict[
    tuple[str, str],
    tuple[frozenset[int], dict[int, tuple[str, str]]],
]

# uint8 -> Unicode general category, inverse of the mapping in
# scripts/confusion_training.py used at serialization time.  The tuple
# order below *is* the wire format: index == serialized category byte.
_INT_TO_CATEGORY: dict[int, str] = dict(
    enumerate(
        (
            "Lu", "Ll", "Lt", "Lm", "Lo",
            "Mn", "Mc", "Me",
            "Nd", "Nl", "No",
            "Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po",
            "Sm", "Sc", "Sk", "So",
            "Zs", "Zl", "Zp",
            "Cc", "Cf", "Cs", "Co", "Cn",
        )
    )
)

# Inverse mapping for serialization — used by scripts/confusion_training.py.
_CATEGORY_TO_INT: dict[str, int] = {v: k for k, v in _INT_TO_CATEGORY.items()}


def deserialize_confusion_data_from_bytes(data: bytes) -> DistinguishingMaps:
    """Load confusion group data from raw bytes.

    Wire format (integers are big-endian): a ``uint16`` pair count, then
    for each pair two length-prefixed UTF-8 encoding names, a ``uint8``
    entry count, and that many ``(byte_value, cat_a, cat_b)`` uint8 triples.

    :param data: The raw binary content of a confusion.bin file.
    :returns: A :data:`DistinguishingMaps` dictionary keyed by encoding pairs.
    """
    (pair_count,) = struct.unpack_from("!H", data, 0)
    pos = 2

    maps: DistinguishingMaps = {}
    for _ in range(pair_count):
        # Two length-prefixed encoding names, read back to back.
        names: list[str] = []
        for _ in range(2):
            (name_len,) = struct.unpack_from("!B", data, pos)
            pos += 1
            names.append(data[pos : pos + name_len].decode("utf-8"))
            pos += name_len

        (entry_count,) = struct.unpack_from("!B", data, pos)
        pos += 1

        categories: dict[int, tuple[str, str]] = {}
        for _ in range(entry_count):
            byte_val, cat_a_code, cat_b_code = struct.unpack_from("!BBB", data, pos)
            pos += 3
            # Unknown category codes degrade to "Cn" (unassigned).
            categories[byte_val] = (
                _INT_TO_CATEGORY.get(cat_a_code, "Cn"),
                _INT_TO_CATEGORY.get(cat_b_code, "Cn"),
            )

        # The categories keys are exactly the distinguishing byte values.
        maps[(names[0], names[1])] = (frozenset(categories), categories)

    return maps
112
113
@functools.cache
def load_confusion_data() -> DistinguishingMaps:
    """Load and memoize confusion data from the bundled confusion.bin file.

    :returns: A :data:`DistinguishingMaps` dictionary keyed by encoding pairs,
        with keys normalized to canonical codec names.
    :raises ValueError: If the bundled file cannot be parsed.
    """
    resource = importlib.resources.files("chardet.models").joinpath("confusion.bin")
    payload = resource.read_bytes()
    if not payload:
        # A missing/empty data file is survivable: detection still works,
        # only tie-breaking between similar encodings is lost.
        warnings.warn(
            "chardet confusion.bin is empty — confusion resolution disabled; "
            "reinstall chardet to fix",
            RuntimeWarning,
            stacklevel=2,
        )
        return {}

    try:
        parsed = deserialize_confusion_data_from_bytes(payload)
    except (struct.error, UnicodeDecodeError) as exc:
        raise ValueError(f"corrupt confusion.bin: {exc}") from exc

    # Normalize keys to canonical codec names so pipeline output matches.
    return {
        (lookup_encoding(a) or a, lookup_encoding(b) or b): value
        for (a, b), value in parsed.items()
    }
142
143
144# Unicode general category preference scores for voting resolution.
145# Higher scores indicate more linguistically meaningful characters.
146_CATEGORY_PREFERENCE: dict[str, int] = {
147 "Lu": 10,
148 "Ll": 10,
149 "Lt": 10,
150 "Lm": 9,
151 "Lo": 9,
152 "Nd": 8,
153 "Nl": 7,
154 "No": 7,
155 "Pc": 6,
156 "Pd": 6,
157 "Ps": 6,
158 "Pe": 6,
159 "Pi": 6,
160 "Pf": 6,
161 "Po": 6,
162 "Sc": 5,
163 "Sm": 5,
164 "Sk": 4,
165 "So": 4,
166 "Zs": 3,
167 "Zl": 3,
168 "Zp": 3,
169 "Cf": 2,
170 "Cc": 1,
171 "Co": 1,
172 "Cs": 0,
173 "Cn": 0,
174 "Mn": 5,
175 "Mc": 5,
176 "Me": 5,
177}
178
179
180def resolve_by_category_voting(
181 data: bytes,
182 enc_a: str,
183 enc_b: str,
184 diff_bytes: frozenset[int],
185 categories: dict[int, tuple[str, str]],
186) -> str | None:
187 """Resolve between two encodings using Unicode category voting.
188
189 For each distinguishing byte present in the data, compare the Unicode
190 general category under each encoding. The encoding whose interpretation
191 has the higher category preference score gets a vote. The encoding with
192 more votes wins.
193
194 :param data: The raw byte data to examine.
195 :param enc_a: First encoding name.
196 :param enc_b: Second encoding name.
197 :param diff_bytes: Byte values where the two encodings differ.
198 :param categories: Mapping of byte value to ``(cat_a, cat_b)`` Unicode
199 general category pairs.
200 :returns: The winning encoding name, or ``None`` if tied.
201 """
202 votes_a = 0
203 votes_b = 0
204 relevant = frozenset(data) & diff_bytes
205 if not relevant:
206 return None
207 for bv in relevant:
208 cat_a, cat_b = categories[bv]
209 pref_a = _CATEGORY_PREFERENCE.get(cat_a, 0)
210 pref_b = _CATEGORY_PREFERENCE.get(cat_b, 0)
211 if pref_a > pref_b:
212 votes_a += pref_a - pref_b
213 elif pref_b > pref_a:
214 votes_b += pref_b - pref_a
215 if votes_a > votes_b:
216 return enc_a
217 if votes_b > votes_a:
218 return enc_b
219 return None
220
221
def _best_variant_score(
    profile: BigramProfile,
    index: dict[str, list[tuple[str | None, memoryview, str]]],
    enc: str,
) -> float:
    """Return the highest bigram score among *enc*'s language variants.

    :param profile: The focused bigram profile to score against.
    :param index: Encoding-name index of ``(language, model, model_key)``
        variant tuples.
    :param enc: The encoding whose variants are scored.
    :returns: The best variant score, or ``0.0`` when *enc* has no variants.
    """
    variants = index.get(enc)
    if not variants:
        return 0.0
    # Running max: seed with the first variant so negative best scores are
    # still returned (0.0 is reserved for "no variants at all").
    best = score_with_profile(profile, variants[0][1], variants[0][2])
    for _language, model, model_key in variants[1:]:
        current = score_with_profile(profile, model, model_key)
        if current > best:
            best = current
    return best
235
236
237def resolve_by_bigram_rescore(
238 data: bytes,
239 enc_a: str,
240 enc_b: str,
241 diff_bytes: frozenset[int],
242) -> str | None:
243 """Resolve between two encodings by re-scoring only distinguishing bigrams.
244
245 Builds a focused bigram profile containing only bigrams where at least one
246 byte is a distinguishing byte, then scores both encodings against their
247 best language model.
248
249 :param data: The raw byte data to examine.
250 :param enc_a: First encoding name.
251 :param enc_b: Second encoding name.
252 :param diff_bytes: Byte values where the two encodings differ.
253 :returns: The winning encoding name, or ``None`` if tied.
254 """
255 if len(data) < 2:
256 return None
257
258 idf = get_idf_weights()
259 freq: dict[int, int] = {}
260 for i in range(len(data) - 1):
261 b1 = data[i]
262 b2 = data[i + 1]
263 if b1 not in diff_bytes and b2 not in diff_bytes:
264 continue
265 idx = (b1 << 8) | b2
266 freq[idx] = freq.get(idx, 0) + idf[idx]
267
268 if not freq:
269 return None
270
271 profile = BigramProfile.from_weighted_freq(freq)
272
273 index = get_enc_index()
274 best_a = _best_variant_score(profile, index, enc_a)
275 best_b = _best_variant_score(profile, index, enc_b)
276
277 if best_a > best_b:
278 return enc_a
279 if best_b > best_a:
280 return enc_b
281 return None
282
283
def _find_pair_key(
    maps: DistinguishingMaps,
    enc_a: str,
    enc_b: str,
) -> tuple[str, str] | None:
    """Return the key under which this encoding pair is stored in *maps*.

    Pairs are stored under a single orientation, so both orderings are
    probed.

    :param maps: The loaded distinguishing maps.
    :param enc_a: First encoding name.
    :param enc_b: Second encoding name.
    :returns: The stored key tuple, or ``None`` if the pair is unknown.
    """
    for candidate in ((enc_a, enc_b), (enc_b, enc_a)):
        if candidate in maps:
            return candidate
    return None
295
296
# Maximum confidence gap from the top result for candidates beyond
# position 1 to participate in confusion resolution.  Position 1 is
# always eligible regardless of this band (see resolve_confusion_groups).
_CONFUSION_BAND: float = 0.005
300
301
def resolve_confusion_groups(
    data: bytes,
    results: list[DetectionResult],
) -> list[DetectionResult]:
    """Resolve confusion between similar encodings in the top results.

    The leading result is compared against each runner-up. Position 1 is
    always examined (preserving the original top-2 behavior); deeper
    positions are examined only while they stay within ``_CONFUSION_BAND``
    of the leader's confidence. Bigram re-scoring decides the winner, with
    category voting as the fallback.

    :param data: The raw byte data to examine.
    :param results: Detection results sorted by confidence descending.
    :returns: A reordered list of :class:`DetectionResult` with the winner
        first.
    """
    if len(results) < 2:
        return results

    leader = results[0]
    if leader.encoding is None:
        return results

    confusion_maps = load_confusion_data()
    leader_conf = leader.confidence

    for pos, challenger in enumerate(results[1:], start=1):
        if challenger.encoding is None:
            continue
        # Results are sorted, so once a position-2+ candidate falls outside
        # the band every later one does too — stop scanning.
        if pos > 1 and leader_conf - challenger.confidence > _CONFUSION_BAND:
            break

        pair_key = _find_pair_key(confusion_maps, leader.encoding, challenger.encoding)
        if pair_key is None:
            continue

        diff_bytes, categories = confusion_maps[pair_key]
        first, second = pair_key

        winner = resolve_by_bigram_rescore(data, first, second, diff_bytes)
        if winner is None:
            # Bigram evidence was inconclusive; fall back to category voting.
            winner = resolve_by_category_voting(
                data, first, second, diff_bytes, categories
            )

        if winner is not None and winner == challenger.encoding:
            # Carry over the leader's confidence so the promotion survives
            # any downstream confidence-based sort.
            promoted = DetectionResult(
                challenger.encoding,
                leader.confidence,
                challenger.language,
                challenger.mime_type,
            )
            remaining = [r for idx, r in enumerate(results) if idx != pos]
            return [promoted, *remaining]

    return results