1"""Model loading and bigram scoring utilities.
2
3Note: ``from __future__ import annotations`` is intentionally omitted because
4this module is compiled with mypyc, which does not support PEP 563 string
5annotations.
6"""
7
8import functools
9import importlib.resources
10import math
11import struct
12import warnings
13
14from chardet.registry import REGISTRY, lookup_encoding
15
# Pre-compiled struct readers, hoisted to module level so the hot parsing
# loop avoids repeated attribute lookups (this module is mypyc-compiled).
_unpack_uint32 = struct.Struct(">I").unpack_from
_iter_3bytes = struct.Struct(">BBB").iter_unpack

#: Weight applied to non-ASCII bigrams during profile construction.
#: Imported by pipeline/confusion.py for focused bigram re-scoring.
NON_ASCII_BIGRAM_WEIGHT: int = 8
# Encodings that map to exactly one language, derived from the registry.
# Keyed by canonical name only — callers always use canonical names.
# A dict comprehension keeps the loop variable out of the module namespace
# (the previous ``for _enc`` loop leaked ``_enc`` as a module attribute).
_SINGLE_LANG_MAP: dict[str, str] = {
    enc.name: enc.languages[0]
    for enc in REGISTRY.values()
    if len(enc.languages) == 1
}
28
29
def _parse_models_bin(
    data: bytes,
) -> tuple[dict[str, bytearray], dict[str, float]]:
    """Parse the binary models.bin format into model tables and L2 norms.

    :param data: Raw bytes of models.bin (must be non-empty).
    :returns: A ``(models, norms)`` tuple.
    :raises ValueError: If the data is corrupt or truncated.
    """
    models: dict[str, bytearray] = {}
    norms: dict[str, float] = {}
    # Locally-bound readers keep the parsing loop free of global lookups.
    read_u32 = struct.Struct(">I").unpack_from
    read_triples = struct.Struct(">BBB").iter_unpack
    sqrt = math.sqrt
    pos = 0
    try:
        (encoding_count,) = read_u32(data, pos)
        pos += 4

        # Sanity bound: reject absurd header values before looping.
        if encoding_count > 10_000:
            msg = f"corrupt models.bin: num_encodings={encoding_count} exceeds limit"
            raise ValueError(msg)

        for _ in range(encoding_count):
            (name_len,) = read_u32(data, pos)
            pos += 4
            if name_len > 256:
                msg = f"corrupt models.bin: name_len={name_len} exceeds 256"
                raise ValueError(msg)
            name = data[pos : pos + name_len].decode("utf-8")
            pos += name_len
            (num_entries,) = read_u32(data, pos)
            pos += 4
            if num_entries > 65536:
                msg = f"corrupt models.bin: num_entries={num_entries} exceeds 65536"
                raise ValueError(msg)

            # Each entry is a (b1, b2, weight) byte triple.
            payload_len = num_entries * 3
            payload = data[pos : pos + payload_len]
            if len(payload) != payload_len:
                msg = f"corrupt models.bin: truncated entry data for {name!r}"
                raise ValueError(msg)
            pos += payload_len

            table = bytearray(65536)
            squared_sum = 0
            for hi, lo, weight in read_triples(payload):
                table[(hi << 8) | lo] = weight
                squared_sum += weight * weight
            models[name] = table
            norms[name] = sqrt(squared_sum)
    except (struct.error, UnicodeDecodeError) as e:
        # Out-of-range unpacks and bad UTF-8 names both mean corruption.
        msg = f"corrupt models.bin: {e}"
        raise ValueError(msg) from e

    return models, norms
85
86
@functools.cache
def _load_models_data() -> tuple[dict[str, bytearray], dict[str, float]]:
    """Load and parse models.bin, returning (models, norms).

    Cached: only reads from disk on first call.
    """
    resource = importlib.resources.files("chardet.models").joinpath("models.bin")
    raw = resource.read_bytes()

    # Normal path: a non-empty file is parsed into tables and norms.
    if raw:
        return _parse_models_bin(raw)

    # Empty file: degrade gracefully instead of crashing detection.
    warnings.warn(
        "chardet models.bin is empty — statistical detection disabled; "
        "reinstall chardet to fix",
        RuntimeWarning,
        stacklevel=2,
    )
    return {}, {}
106
107
def load_models() -> dict[str, bytearray]:
    """Load all bigram models from the bundled models.bin file.

    Each model is a bytearray of length 65536 (256*256).
    Index: (b1 << 8) | b2 -> weight (0-255).

    :returns: A dict mapping model key strings to 65536-byte lookup tables.
    """
    models, _norms = _load_models_data()
    return models
117
118
def _build_enc_index(
    models: dict[str, bytearray],
) -> dict[str, list[tuple[str | None, bytearray, str]]]:
    """Build a grouped index from a models dict.

    :param models: Mapping of ``"lang/encoding"`` keys to 65536-byte tables.
    :returns: Mapping of encoding name to ``[(lang, model, model_key), ...]``.
    """
    grouped: dict[str, list[tuple[str | None, bytearray, str]]] = {}
    for model_key, table in models.items():
        language, enc_name = model_key.split("/", 1)
        grouped.setdefault(enc_name, []).append((language, table, model_key))

    # Alias resolution: when a model key uses a non-canonical encoding
    # name, expose the same entry list under the canonical name too.
    for alias in list(grouped):
        canonical = lookup_encoding(alias)
        if canonical is not None and canonical not in grouped:
            grouped[canonical] = grouped[alias]

    return grouped
140
141
@functools.cache
def get_enc_index() -> dict[str, list[tuple[str | None, bytearray, str]]]:
    """Return a pre-grouped index mapping encoding name -> [(lang, model, model_key), ...]."""
    models = load_models()
    return _build_enc_index(models)
146
147
def infer_language(encoding: str) -> str | None:
    """Return the language for a single-language encoding, or None.

    :param encoding: The canonical encoding name.
    :returns: An ISO 639-1 language code, or ``None`` if the encoding is
        multi-language.
    """
    try:
        return _SINGLE_LANG_MAP[encoding]
    except KeyError:
        return None
156
157
def has_model_variants(encoding: str) -> bool:
    """Return True if the encoding has language variants in the model index.

    :param encoding: The canonical encoding name.
    :returns: ``True`` if bigram models exist for this encoding.
    """
    index = get_enc_index()
    return encoding in index
165
166
def _get_model_norms() -> dict[str, float]:
    """Return cached L2 norms for all models, keyed by model key string."""
    _models, norms = _load_models_data()
    return norms
170
171
class BigramProfile:
    """Pre-computed bigram frequency distribution for a data sample.

    Computing this once and reusing it across all models reduces per-model
    scoring from O(n) to O(distinct_bigrams).

    Stores a single ``weighted_freq`` dict mapping bigram index to
    *count * weight* (weight is 8 for non-ASCII bigrams, 1 otherwise).
    Pre-multiplying the weight during construction means the scoring
    inner loop is a single dict traversal with no branching.
    """

    __slots__ = ("input_norm", "weight_sum", "weighted_freq")

    def __init__(self, data: bytes) -> None:
        """Compute the bigram frequency distribution for *data*.

        :param data: The raw byte data to profile.
        """
        if len(data) < 2:
            # Fewer than two bytes means no bigrams at all.
            self.weighted_freq: dict[int, int] = {}
            self.weight_sum: int = 0
            self.input_norm: float = 0.0
            return

        counts: dict[int, int] = {}
        total = 0
        heavy = NON_ASCII_BIGRAM_WEIGHT
        lookup = counts.get  # hoisted for the hot loop
        prev = data[0]
        for cur in data[1:]:
            idx = (prev << 8) | cur
            # Non-ASCII pairs carry extra weight; ASCII pairs count once.
            step = heavy if (prev > 0x7F or cur > 0x7F) else 1
            counts[idx] = lookup(idx, 0) + step
            total += step
            prev = cur

        self.weighted_freq = counts
        self.weight_sum = total
        self.input_norm = math.sqrt(sum(c * c for c in counts.values()))

    @classmethod
    def from_weighted_freq(cls, weighted_freq: dict[int, int]) -> "BigramProfile":
        """Create a BigramProfile from pre-computed weighted frequencies.

        ``weight_sum`` and ``input_norm`` are derived from *weighted_freq*
        so the three fields always stay mutually consistent.

        :param weighted_freq: Mapping of bigram index to weighted count.
        :returns: A new :class:`BigramProfile` instance.
        """
        profile = cls(b"")
        profile.weighted_freq = weighted_freq
        total = 0
        square_sum = 0
        for count in weighted_freq.values():
            total += count
            square_sum += count * count
        profile.weight_sum = total
        profile.input_norm = math.sqrt(square_sum)
        return profile
231
232
def score_with_profile(
    profile: "BigramProfile", model: bytearray, model_key: str = ""
) -> float:
    """Score a pre-computed bigram profile against a single model using cosine similarity.

    :param profile: Pre-computed :class:`BigramProfile` for the input data.
    :param model: A 65536-byte bigram weight table.
    :param model_key: Optional model key used to look up a cached L2 norm.
        When empty or missing from the cache, the norm is computed on the fly.
    :returns: The cosine similarity, or ``0.0`` when either norm is zero.
    """
    if profile.input_norm == 0.0:
        return 0.0
    # Only consult the norm cache when a key was actually supplied.
    # Previously _get_model_norms() was called unconditionally, which
    # forced models.bin to be loaded from disk even for ad-hoc models
    # scored without a key.
    model_norm: float | None = None
    if model_key:
        model_norm = _get_model_norms().get(model_key)
    if model_norm is None:
        # Fallback: compute the L2 norm directly from the table.
        model_norm = math.sqrt(sum(v * v for v in model if v))
    # Guard applies to cached norms too: an all-zero model table yields a
    # cached norm of 0.0, which previously bypassed this check (inside the
    # recompute branch only) and raised ZeroDivisionError below.
    if model_norm == 0.0:
        return 0.0
    dot = 0
    for idx, wcount in profile.weighted_freq.items():
        dot += model[idx] * wcount
    return dot / (model_norm * profile.input_norm)
254
255
def score_best_language(
    data: bytes,
    encoding: str,
    profile: BigramProfile | None = None,
) -> tuple[float, str | None]:
    """Score data against all language variants of an encoding.

    Returns (best_score, best_language). Uses a pre-grouped index for O(L)
    lookup where L is the number of language variants for the encoding.

    If *profile* is provided, it is reused instead of recomputing the bigram
    frequency distribution from *data*.

    :param data: The raw byte data to score.
    :param encoding: The canonical encoding name to match against.
    :param profile: Optional pre-computed :class:`BigramProfile` to reuse.
    :returns: A ``(score, language)`` tuple with the best cosine-similarity
        score and the corresponding language code (or ``None``).
    """
    # Nothing to score and nothing pre-computed: bail out early.
    if profile is None and not data:
        return 0.0, None

    variants = get_enc_index().get(encoding)
    if variants is None:
        # No bigram models exist for this encoding.
        return 0.0, None

    if profile is None:
        profile = BigramProfile(data)

    # Track the running best as a (score, language) pair.
    best: tuple[float, str | None] = (0.0, None)
    for lang, model, model_key in variants:
        candidate = score_with_profile(profile, model, model_key)
        if candidate > best[0]:
            best = (candidate, lang)

    return best