1"""Model loading and bigram scoring utilities.
2
3Note: ``from __future__ import annotations`` is intentionally omitted because
4this module is compiled with mypyc, which does not support PEP 563 string
5annotations.
6"""
7
8import functools
9import importlib.resources
10import math
11import struct
12import warnings
13import zlib
14
15from chardet.registry import REGISTRY, lookup_encoding
16
17_unpack_uint32 = struct.Struct(">I").unpack_from
18_unpack_float64 = struct.Struct(">d").unpack_from
19_V2_MAGIC = b"CMD2"
20
21# Encodings that map to exactly one language, derived from the registry.
22# Keyed by canonical name only — callers always use canonical names.
23_SINGLE_LANG_MAP: dict[str, str] = {}
24for _enc in REGISTRY.values():
25 if len(_enc.languages) == 1:
26 _SINGLE_LANG_MAP[_enc.name] = _enc.languages[0]
27
28
29def _parse_models_bin(
30 data: bytes,
31) -> tuple[dict[str, memoryview], dict[str, float]]:
32 """Parse the v2 dense zlib-compressed models.bin format.
33
34 :param data: Raw bytes of models.bin (must be non-empty).
35 :returns: A ``(models, norms)`` tuple.
36 :raises ValueError: If the data is corrupt or truncated.
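
    Layout, as parsed below (integers are big-endian):

    - 4-byte magic ``b"CMD2"``, then ``uint32`` model count
    - per model: ``uint32`` name length, UTF-8 name bytes, ``float64`` L2 norm
    - remainder: a zlib blob that decompresses to ``num_models * 65536``
      bytes of concatenated lookup tables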
37 """
38 try:
39 if data[:4] != _V2_MAGIC:
40 msg = "corrupt models.bin: missing CMD2 magic"
41 raise ValueError(msg)
42
43 offset = 4 # skip magic
44 (num_models,) = _unpack_uint32(data, offset)
45 offset += 4
46
47 if num_models > 10_000:
48 msg = f"corrupt models.bin: num_models={num_models} exceeds limit"
49 raise ValueError(msg)
50
51 names: list[str] = []
52 norms: dict[str, float] = {}
53 for _ in range(num_models):
54 (name_len,) = _unpack_uint32(data, offset)
55 offset += 4
56 if name_len > 256:
57 msg = f"corrupt models.bin: name_len={name_len} exceeds 256"
58 raise ValueError(msg)
59 name = data[offset : offset + name_len].decode("utf-8")
60 offset += name_len
61 (norm,) = _unpack_float64(data, offset)
62 offset += 8
63 names.append(name)
64 norms[name] = norm
65
66 # zlib.decompress is faster than decompressobj; trailing bytes are
67 # unlikely in bundled data and would not affect correctness since we
68 # validate decompressed size. train.py uses decompressobj for
69 # stricter checking during model generation.
70 blob = zlib.decompress(data[offset:])
71 expected_size = num_models * 65536
72 if len(blob) != expected_size:
73 msg = (
74 f"corrupt models.bin: decompressed size {len(blob)} "
75 f"!= expected {expected_size}"
76 )
77 raise ValueError(msg)
78
79 # memoryview slices avoid copies; the blob bytes object is kept
80 # alive by the functools.cache on _load_models_data().
81 mv = memoryview(blob)
82 models: dict[str, memoryview] = {}
83 for i, name in enumerate(names):
84 start = i * 65536
85 models[name] = mv[start : start + 65536]
86
87 except zlib.error as e:
88 msg = f"corrupt models.bin: {e}"
89 raise ValueError(msg) from e
90 except (struct.error, UnicodeDecodeError) as e:
91 msg = f"corrupt models.bin: {e}"
92 raise ValueError(msg) from e
93
94 return models, norms
95
96
97@functools.cache
98def _load_models_data() -> tuple[dict[str, memoryview], dict[str, float]]:
99 """Load and parse models.bin, returning (models, norms).
100
101 Cached: only reads from disk on first call.
102 """
103 ref = importlib.resources.files("chardet.models").joinpath("models.bin")
104 data = ref.read_bytes()
105
106 if not data:
107 warnings.warn(
108 "chardet models.bin is empty — statistical detection disabled; "
109 "reinstall chardet to fix",
110 RuntimeWarning,
111 stacklevel=2,
112 )
113 return {}, {}
114
115 return _parse_models_bin(data)
116
117
118def load_models() -> dict[str, memoryview]:
119 """Load all bigram models from the bundled models.bin file.
120
121 Each model is a memoryview of length 65536 (256*256).
122 Index: (b1 << 8) | b2 -> weight (0-255).
123
124 :returns: A dict mapping model key strings to 65536-byte lookup tables.
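
    A minimal usage sketch (the model key shown is hypothetical; real keys
    depend on the bundled models)::

        models = load_models()
        table = models.get("fr/iso-8859-1")
        if table is not None:
            weight = table[(0xE9 << 8) | 0x20]  # weight for bigram 0xE9,0x20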
125 """
126 return _load_models_data()[0]
127
128
129def _build_enc_index(
130 models: dict[str, memoryview],
131) -> dict[str, list[tuple[str | None, memoryview, str]]]:
132 """Build a grouped index from a models dict.
133
134 :param models: Mapping of ``"lang/encoding"`` keys to 65536-byte tables.
135 :returns: Mapping of encoding name to ``[(lang, model, model_key), ...]``.
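
    Illustrative shape (names are hypothetical)::

        {"iso-8859-1": [("fr", <memoryview>, "fr/iso-8859-1"), ...]}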
136 """
137 index: dict[str, list[tuple[str | None, memoryview, str]]] = {}
138 for key, model in models.items():
139 lang, enc = key.split("/", 1)
140 index.setdefault(enc, []).append((lang, model, key))
141
142 # Resolve aliases: if a model key uses a non-canonical name,
143 # copy the entry under the canonical name.
144 for enc_name in list(index):
145 canonical = lookup_encoding(enc_name)
146 if canonical is not None and canonical not in index:
147 index[canonical] = index[enc_name]
148
149 return index
150
151
152@functools.cache
153def get_enc_index() -> dict[str, list[tuple[str | None, memoryview, str]]]:
154 """Return a pre-grouped index mapping encoding name -> [(lang, model, model_key), ...]."""
155 return _build_enc_index(load_models())
156
157
158def infer_language(encoding: str) -> str | None:
159 """Return the language for a single-language encoding, or None.
160
161 :param encoding: The canonical encoding name.
162 :returns: An ISO 639-1 language code, or ``None`` if the encoding is
163 multi-language.
164 """
165 return _SINGLE_LANG_MAP.get(encoding)
166
167
168def has_model_variants(encoding: str) -> bool:
169 """Return True if the encoding has language variants in the model index.
170
171 :param encoding: The canonical encoding name.
172 :returns: ``True`` if bigram models exist for this encoding.
173 """
174 return encoding in get_enc_index()
175
176
177def _get_model_norms() -> dict[str, float]:
178 """Return cached L2 norms for all models, keyed by model key string."""
179 return _load_models_data()[1]
180
181
182@functools.cache
183def get_idf_weights() -> bytearray:
184 """Return a 65536-byte IDF weight table for bigram profile construction.
185
186 Loads a precomputed table from ``idf.bin`` (generated at training time).
187 For each bigram index, the weight reflects how discriminative that bigram
188 is across all models:
189
190 - Bigrams in every model (common ASCII) → weight 1 (minimal signal)
191 - Bigrams in one model → weight 255 (maximum signal)
192 - Bigrams not in any model → weight 1 (unknown, treat as neutral)
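
    A rough sketch of how such a table could be derived from per-bigram
    document frequencies (illustrative only; the real table is produced by
    train.py, whose scaling between the endpoints may differ, and the names
    below are hypothetical)::

        n = num_models
        for idx in range(65536):
            df = doc_freq[idx]  # number of models containing this bigram
            if df == 0:
                weights[idx] = 1
            else:
                # linear rarity scaling into 1..255, for illustration
                weights[idx] = 1 + round(254 * (n - df) / max(n - 1, 1))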
193 """
194 ref = importlib.resources.files("chardet.models").joinpath("idf.bin")
195 data = ref.read_bytes()
196 if len(data) != 65536:
197 warnings.warn(
198 f"chardet idf.bin has wrong size ({len(data)}), "
199 "falling back to uniform weights",
200 RuntimeWarning,
201 stacklevel=2,
202 )
203 return bytearray(b"\x01" * 65536)
204 return bytearray(data)
205
206
207class BigramProfile:
208 """Pre-computed bigram frequency distribution for a data sample.
209
210 Computing this once and reusing it across all models reduces per-model
211 scoring from O(n) to O(distinct_bigrams).
212
213 Stores a dense ``freq`` list of length 65536 indexed by bigram index, plus
214 a ``nonzero`` list of indices with non-zero frequency for fast iteration.
215 Each bigram is weighted by its IDF (inverse document frequency) across all
216 models — bigrams unique to few models get high weight, bigrams common to
217 all models get weight 1.
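
    Typical use is to build one profile per input and score it against many
    models::

        profile = BigramProfile(data)
        for key, model in load_models().items():
            score = score_with_profile(profile, model, key)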
218 """
219
220 __slots__ = ("freq", "input_norm", "nonzero", "weight_sum")
221
222 def __init__(self, data: bytes) -> None:
223 """Compute the bigram frequency distribution for *data*.
224
225 Each bigram is weighted by its IDF (inverse document frequency) across
226 all loaded models. Bigrams unique to few models get high weight;
227 bigrams common to all models get weight 1.
228
229 :param data: The raw byte data to profile.
230 """
231 total_bigrams = len(data) - 1
232 if total_bigrams <= 0:
233 # Use empty lists (not [0]*65536) to avoid a 256KB allocation
234 # for no-op profiles. Safe because score_with_profile returns
235 # early when input_norm == 0.0, so freq is never indexed.
236 self.freq: list[int] = []
237 self.nonzero: list[int] = []
238 self.weight_sum: int = 0
239 self.input_norm: float = 0.0
240 return
241
242 idf = get_idf_weights()
243 freq: list[int] = [0] * 65536
244 nonzero: list[int] = []
245 w_sum = 0
246 for i in range(total_bigrams):
247 idx = (data[i] << 8) | data[i + 1]
248 w = idf[idx]
249 if freq[idx] == 0:
250 nonzero.append(idx)
251 freq[idx] += w
252 w_sum += w
253 self.freq = freq
254 self.nonzero = nonzero
255 self.weight_sum = w_sum
256 norm_sq = 0
257 for idx in nonzero:
258 v = freq[idx]
259 norm_sq += v * v
260 self.input_norm = math.sqrt(norm_sq)
261
262 @classmethod
263 def from_weighted_freq(cls, weighted_freq: dict[int, int]) -> "BigramProfile":
264 """Create a BigramProfile from pre-computed weighted frequencies.
265
266 Computes ``weight_sum`` and ``input_norm`` from *weighted_freq* to
267 ensure consistency between the stored fields.
268
269 :param weighted_freq: Mapping of bigram index to weighted count.
270 :returns: A new :class:`BigramProfile` instance.
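
        Example (indices and counts are arbitrary)::

            profile = BigramProfile.from_weighted_freq({0x6520: 3, 0x2061: 1})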
271 """
272 profile = cls(b"")
273 freq: list[int] = [0] * 65536
274 nonzero: list[int] = []
275 for idx, count in weighted_freq.items():
276 freq[idx] = count
277 if count:
278 nonzero.append(idx)
279 profile.freq = freq
280 profile.nonzero = nonzero
281 profile.weight_sum = sum(weighted_freq.values())
282 profile.input_norm = math.sqrt(sum(v * v for v in weighted_freq.values()))
283 return profile
284
285
286def score_with_profile(
287 profile: BigramProfile, model: bytearray | memoryview, model_key: str = ""
288) -> float:
289 """Score a pre-computed bigram profile against a single model using cosine similarity."""
290 if profile.input_norm == 0.0:
291 return 0.0
292 norms = _get_model_norms()
293 model_norm = norms.get(model_key) if model_key else None
294 if model_norm is None:
295 sq_sum = 0
296 for i in range(65536):
297 v = model[i]
298 if v:
299 sq_sum += v * v
300 model_norm = math.sqrt(sq_sum)
301 if model_norm == 0.0:
302 return 0.0
303 dot = 0
304 freq = profile.freq
305 for idx in profile.nonzero:
306 dot += model[idx] * freq[idx]
307 return dot / (model_norm * profile.input_norm)
308
309
310def score_best_language(
311 data: bytes,
312 encoding: str,
313 profile: BigramProfile | None = None,
314) -> tuple[float, str | None]:
315 """Score data against all language variants of an encoding.
316
317 Returns (best_score, best_language). Uses a pre-grouped index for O(L)
318 lookup where L is the number of language variants for the encoding.
319
320 If *profile* is provided, it is reused instead of recomputing the bigram
321 frequency distribution from *data*.
322
323 :param data: The raw byte data to score.
324 :param encoding: The canonical encoding name to match against.
325 :param profile: Optional pre-computed :class:`BigramProfile` to reuse.
326 :returns: A ``(score, language)`` tuple with the best cosine-similarity
327 score and the corresponding language code (or ``None``).
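
    A minimal usage sketch (``raw`` stands for the input bytes; the encoding
    name is illustrative)::

        score, lang = score_best_language(raw, "iso-8859-1")
        if score > 0.0:
            ...  # lang is the best-matching variant's language, or None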
328 """
329 if not data and profile is None:
330 return 0.0, None
331
332 index = get_enc_index()
333 variants = index.get(encoding)
334 if variants is None:
335 return 0.0, None
336
337 if profile is None:
338 profile = BigramProfile(data)
339
340 best_score = 0.0
341 best_lang: str | None = None
342 for lang, model, model_key in variants:
343 s = score_with_profile(profile, model, model_key)
344 if s > best_score:
345 best_score = s
346 best_lang = lang
347
348 return best_score, best_lang