Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/chardet/models/__init__.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

153 statements  

1"""Model loading and bigram scoring utilities. 

2 

3Note: ``from __future__ import annotations`` is intentionally omitted because 

4this module is compiled with mypyc, which does not support PEP 563 string 

5annotations. 

6""" 

7 

8import functools 

9import importlib.resources 

10import math 

11import struct 

12import warnings 

13 

14from chardet.registry import REGISTRY, lookup_encoding 

15 

_unpack_uint32 = struct.Struct(">I").unpack_from
_iter_3bytes = struct.Struct(">BBB").iter_unpack

#: Weight applied to non-ASCII bigrams during profile construction.
#: Imported by pipeline/confusion.py for focused bigram re-scoring.
NON_ASCII_BIGRAM_WEIGHT: int = 8

# Encodings that map to exactly one language, derived from the registry.
# Keyed by canonical name only — callers always use canonical names.
# Built with a comprehension so no loop variable leaks into the module
# namespace (a plain ``for`` loop would leave ``_enc`` bound after import).
_SINGLE_LANG_MAP: dict[str, str] = {
    enc.name: enc.languages[0]
    for enc in REGISTRY.values()
    if len(enc.languages) == 1
}

28 

29 

30def _parse_models_bin( 

31 data: bytes, 

32) -> tuple[dict[str, bytearray], dict[str, float]]: 

33 """Parse the binary models.bin format into model tables and L2 norms. 

34 

35 :param data: Raw bytes of models.bin (must be non-empty). 

36 :returns: A ``(models, norms)`` tuple. 

37 :raises ValueError: If the data is corrupt or truncated. 

38 """ 

39 models: dict[str, bytearray] = {} 

40 norms: dict[str, float] = {} 

41 _sqrt = math.sqrt 

42 _unpack_u32 = _unpack_uint32 

43 _iter_bbb = _iter_3bytes 

44 try: 

45 offset = 0 

46 (num_encodings,) = _unpack_u32(data, offset) 

47 offset += 4 

48 

49 if num_encodings > 10_000: 

50 msg = f"corrupt models.bin: num_encodings={num_encodings} exceeds limit" 

51 raise ValueError(msg) 

52 

53 for _ in range(num_encodings): 

54 (name_len,) = _unpack_u32(data, offset) 

55 offset += 4 

56 if name_len > 256: 

57 msg = f"corrupt models.bin: name_len={name_len} exceeds 256" 

58 raise ValueError(msg) 

59 name = data[offset : offset + name_len].decode("utf-8") 

60 offset += name_len 

61 (num_entries,) = _unpack_u32(data, offset) 

62 offset += 4 

63 if num_entries > 65536: 

64 msg = f"corrupt models.bin: num_entries={num_entries} exceeds 65536" 

65 raise ValueError(msg) 

66 

67 table = bytearray(65536) 

68 sq_sum = 0 

69 expected_bytes = num_entries * 3 

70 chunk = data[offset : offset + expected_bytes] 

71 if len(chunk) != expected_bytes: 

72 msg = f"corrupt models.bin: truncated entry data for {name!r}" 

73 raise ValueError(msg) 

74 offset += expected_bytes 

75 for b1, b2, weight in _iter_bbb(chunk): 

76 table[(b1 << 8) | b2] = weight 

77 sq_sum += weight * weight 

78 models[name] = table 

79 norms[name] = _sqrt(sq_sum) 

80 except (struct.error, UnicodeDecodeError) as e: 

81 msg = f"corrupt models.bin: {e}" 

82 raise ValueError(msg) from e 

83 

84 return models, norms 

85 

86 

@functools.cache
def _load_models_data() -> tuple[dict[str, bytearray], dict[str, float]]:
    """Load and parse models.bin, returning ``(models, norms)``.

    Cached: the disk is read only on the first call.
    """
    raw = (
        importlib.resources.files("chardet.models")
        .joinpath("models.bin")
        .read_bytes()
    )
    if raw:
        return _parse_models_bin(raw)

    # Missing payload: degrade gracefully instead of crashing detection.
    warnings.warn(
        "chardet models.bin is empty — statistical detection disabled; "
        "reinstall chardet to fix",
        RuntimeWarning,
        stacklevel=2,
    )
    return {}, {}

106 

107 

def load_models() -> dict[str, bytearray]:
    """Load all bigram models from the bundled models.bin file.

    Each model is a bytearray of length 65536 (256*256), indexed by
    ``(b1 << 8) | b2`` and holding weights in the 0-255 range.

    :returns: A dict mapping model key strings to 65536-byte lookup tables.
    """
    models, _ = _load_models_data()
    return models

117 

118 

119def _build_enc_index( 

120 models: dict[str, bytearray], 

121) -> dict[str, list[tuple[str | None, bytearray, str]]]: 

122 """Build a grouped index from a models dict. 

123 

124 :param models: Mapping of ``"lang/encoding"`` keys to 65536-byte tables. 

125 :returns: Mapping of encoding name to ``[(lang, model, model_key), ...]``. 

126 """ 

127 index: dict[str, list[tuple[str | None, bytearray, str]]] = {} 

128 for key, model in models.items(): 

129 lang, enc = key.split("/", 1) 

130 index.setdefault(enc, []).append((lang, model, key)) 

131 

132 # Resolve aliases: if a model key uses a non-canonical name, 

133 # copy the entry under the canonical name. 

134 for enc_name in list(index): 

135 canonical = lookup_encoding(enc_name) 

136 if canonical is not None and canonical not in index: 

137 index[canonical] = index[enc_name] 

138 

139 return index 

140 

141 

@functools.cache
def get_enc_index() -> dict[str, list[tuple[str | None, bytearray, str]]]:
    """Return a pre-grouped index: encoding name -> [(lang, model, model_key), ...]."""
    all_models = load_models()
    return _build_enc_index(all_models)

146 

147 

def infer_language(encoding: str) -> str | None:
    """Return the language for a single-language encoding, or None.

    :param encoding: The canonical encoding name.
    :returns: An ISO 639-1 language code, or ``None`` if the encoding is
        multi-language.
    """
    try:
        return _SINGLE_LANG_MAP[encoding]
    except KeyError:
        return None

156 

157 

def has_model_variants(encoding: str) -> bool:
    """Return True if the encoding has language variants in the model index.

    :param encoding: The canonical encoding name.
    :returns: ``True`` if bigram models exist for this encoding.
    """
    index = get_enc_index()
    return encoding in index

165 

166 

def _get_model_norms() -> dict[str, float]:
    """Return cached L2 norms for all models, keyed by model key string."""
    _, norms = _load_models_data()
    return norms

170 

171 

class BigramProfile:
    """Pre-computed bigram frequency distribution for a data sample.

    Built once per input and shared across every model comparison, which
    cuts per-model scoring from O(n) down to O(distinct_bigrams).

    Holds a single ``weighted_freq`` dict mapping bigram index to
    *count * weight*, where the weight is 8 for bigrams containing any
    non-ASCII byte and 1 otherwise. Pre-multiplying the weight here keeps
    the scoring inner loop to a single branch-free dict traversal.
    """

    __slots__ = ("input_norm", "weight_sum", "weighted_freq")

    def __init__(self, data: bytes) -> None:
        """Compute the bigram frequency distribution for *data*.

        :param data: The raw byte data to profile.
        """
        pair_count = len(data) - 1
        if pair_count <= 0:
            # Fewer than two bytes: no bigrams, empty profile.
            self.weighted_freq: dict[int, int] = {}
            self.weight_sum: int = 0
            self.input_norm: float = 0.0
            return

        counts: dict[int, int] = {}
        total_weight = 0
        heavy = NON_ASCII_BIGRAM_WEIGHT
        lookup = counts.get  # bound method: fastest lookup in the hot loop
        for pos in range(pair_count):
            first = data[pos]
            second = data[pos + 1]
            key = (first << 8) | second
            # (first | second) > 0x7F  <=>  at least one byte is non-ASCII.
            step = heavy if (first | second) > 0x7F else 1
            counts[key] = lookup(key, 0) + step
            total_weight += step
        self.weighted_freq = counts
        self.weight_sum = total_weight
        self.input_norm = math.sqrt(sum(c * c for c in counts.values()))

    @classmethod
    def from_weighted_freq(cls, weighted_freq: dict[int, int]) -> "BigramProfile":
        """Create a BigramProfile from pre-computed weighted frequencies.

        ``weight_sum`` and ``input_norm`` are derived from *weighted_freq*
        so the three fields always stay mutually consistent.

        :param weighted_freq: Mapping of bigram index to weighted count.
        :returns: A new :class:`BigramProfile` instance.
        """
        profile = cls(b"")
        profile.weighted_freq = weighted_freq
        values = weighted_freq.values()
        profile.weight_sum = sum(values)
        profile.input_norm = math.sqrt(sum(v * v for v in values))
        return profile

231 

232 

def score_with_profile(
    profile: "BigramProfile", model: bytearray, model_key: str = ""
) -> float:
    """Score a pre-computed bigram profile against one model via cosine similarity.

    :param profile: Pre-computed :class:`BigramProfile` for the input data.
    :param model: 65536-byte bigram weight table.
    :param model_key: Optional model key; when given, the model's L2 norm is
        read from the cached norms table instead of being recomputed.
    :returns: The cosine similarity, or ``0.0`` for an empty profile or an
        all-zero model.
    """
    if profile.input_norm == 0.0:
        return 0.0
    # Look up the cached norm lazily — only when a key is supplied. This
    # avoids triggering a models.bin load when scoring ad-hoc tables.
    model_norm: float | None = None
    if model_key:
        model_norm = _get_model_norms().get(model_key)
    if model_norm is None:
        # Fallback: derive the L2 norm directly from the table
        # (zero entries contribute nothing to the sum).
        model_norm = math.sqrt(sum(v * v for v in model))
    if model_norm == 0.0:
        return 0.0
    dot = sum(model[idx] * wcount for idx, wcount in profile.weighted_freq.items())
    return dot / (model_norm * profile.input_norm)

254 

255 

def score_best_language(
    data: bytes,
    encoding: str,
    profile: BigramProfile | None = None,
) -> tuple[float, str | None]:
    """Score data against all language variants of an encoding.

    Uses the pre-grouped index for O(L) lookup, where L is the number of
    language variants registered for *encoding*.

    When *profile* is supplied it is reused directly instead of recomputing
    the bigram frequency distribution from *data*.

    :param data: The raw byte data to score.
    :param encoding: The canonical encoding name to match against.
    :param profile: Optional pre-computed :class:`BigramProfile` to reuse.
    :returns: A ``(score, language)`` tuple with the best cosine-similarity
        score and the corresponding language code (or ``None``).
    """
    if not data and profile is None:
        return 0.0, None

    variants = get_enc_index().get(encoding)
    if variants is None:
        return 0.0, None

    scored = profile if profile is not None else BigramProfile(data)

    top_score = 0.0
    top_lang: str | None = None
    for language, table, key in variants:
        candidate = score_with_profile(scored, table, key)
        if candidate > top_score:
            top_score, top_lang = candidate, language

    return top_score, top_lang