1from __future__ import annotations
2
3import importlib
4import logging
5import unicodedata
6from bisect import bisect_right
7from codecs import IncrementalDecoder
8from encodings.aliases import aliases
9from functools import lru_cache
10from re import findall
11from typing import Generator
12
13from _multibytecodec import ( # type: ignore[import-not-found,import]
14 MultibyteIncrementalDecoder,
15)
16
17from .constant import (
18 ENCODING_MARKS,
19 IANA_SUPPORTED_SIMILAR,
20 RE_POSSIBLE_ENCODING_INDICATION,
21 UNICODE_RANGES_COMBINED,
22 UNICODE_SECONDARY_RANGE_KEYWORD,
23 UTF8_MAXIMAL_ALLOCATION,
24 COMMON_CJK_CHARACTERS,
25 _LATIN,
26 _CJK,
27 _HANGUL,
28 _KATAKANA,
29 _HIRAGANA,
30 _THAI,
31 _ARABIC,
32 _ARABIC_ISOLATED_FORM,
33 _ACCENT_KEYWORDS,
34 _ACCENTUATED,
35)
36
37
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def _character_flags(character: str) -> int:
    """
    Build the bitmask of name-based classification flags for a character.

    The Unicode name is looked up once; every flag is derived from substring
    checks against it. Characters without a Unicode name yield 0.
    """
    try:
        unicode_name: str = unicodedata.name(character)
    except ValueError:
        # No name registered for this code point: nothing can be inferred.
        return 0

    # (substring expected in the Unicode name, flag bit to raise)
    name_markers = (
        ("LATIN", _LATIN),
        ("CJK", _CJK),
        ("HANGUL", _HANGUL),
        ("KATAKANA", _KATAKANA),
        ("HIRAGANA", _HIRAGANA),
        ("THAI", _THAI),
        ("ARABIC", _ARABIC),
        ("ISOLATED FORM", _ARABIC_ISOLATED_FORM),
    )

    mask: int = 0

    for marker, bit in name_markers:
        if marker in unicode_name:
            mask |= bit

    # A single matching accent keyword is enough to raise the accent flag.
    if any(keyword in unicode_name for keyword in _ACCENT_KEYWORDS):
        mask |= _ACCENTUATED

    return mask
71
72
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_accentuated(character: str) -> bool:
    """Tell whether the _ACCENTUATED bit is set in the character's flags."""
    flags = _character_flags(character)
    return (flags & _ACCENTUATED) != 0
76
77
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def remove_accent(character: str) -> str:
    """
    Return the first code point of the character's Unicode decomposition.

    When the character has no decomposition mapping it is returned unchanged.
    Tagged (compatibility) decompositions such as ``"<compat> 0020 0308"`` are
    supported: the leading ``<tag>`` is skipped so that the first real code
    point is used instead of failing on a non-hexadecimal token.
    """
    decomposed: str = unicodedata.decomposition(character)
    if not decomposed:
        return character

    codes: list[str] = decomposed.split(" ")

    # unicodedata.decomposition() may prefix the mapping with a formatting tag
    # (e.g. "<compat>", "<wide>", "<super>"); it is not a code point and
    # int(tag, 16) would raise ValueError.
    if codes[0].startswith("<"):
        codes = codes[1:]
        if not codes:
            return character

    return chr(int(codes[0], 16))
87
88
# Lookup tables backing the binary search performed by unicode_range().
# _UNICODE_RANGES_SORTED holds (range_start, range_end_exclusive, range_name)
# tuples in ascending order; _UNICODE_RANGE_STARTS mirrors only the starts so
# it can be handed to bisect directly.
_UNICODE_RANGES_SORTED: list[tuple[int, int, str]] = [
    (bounds.start, bounds.stop, label)
    for label, bounds in UNICODE_RANGES_COMBINED.items()
]
_UNICODE_RANGES_SORTED.sort()
_UNICODE_RANGE_STARTS: list[int] = [
    range_start for range_start, _, _ in _UNICODE_RANGES_SORTED
]
96
97
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def unicode_range(character: str) -> str | None:
    """
    Look up the name of the Unicode range containing a single character.

    Returns None when the code point falls outside every known range.
    """
    code_point: int = ord(character)

    # Position of the last range whose start is <= code_point (-1 if none).
    position = bisect_right(_UNICODE_RANGE_STARTS, code_point) - 1
    if position < 0:
        return None

    _, upper_bound, range_name = _UNICODE_RANGES_SORTED[position]

    # upper_bound is exclusive: the candidate only matches below it.
    return range_name if code_point < upper_bound else None
113
114
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_latin(character: str) -> bool:
    """Tell whether the _LATIN bit is set in the character's flags."""
    flags = _character_flags(character)
    return (flags & _LATIN) != 0
118
119
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_punctuation(character: str) -> bool:
    """
    Tell whether a character counts as punctuation.

    True when its Unicode general category contains "P", or when the name of
    its Unicode range contains "Punctuation".
    """
    if "P" in unicodedata.category(character):
        return True

    range_name: str | None = unicode_range(character)

    return range_name is not None and "Punctuation" in range_name
133
134
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_symbol(character: str) -> bool:
    """
    Tell whether a character counts as a symbol.

    True when its Unicode general category contains "S" or "N", or when it
    sits in a Unicode range whose name contains "Forms" while its category is
    not "Lo".
    """
    category: str = unicodedata.category(character)

    if "S" in category or "N" in category:
        return True

    range_name: str | None = unicode_range(character)

    if range_name is None:
        return False

    return category != "Lo" and "Forms" in range_name
148
149
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_emoticon(character: str) -> bool:
    """
    Tell whether the character's Unicode range name contains "Emoticons" or
    "Pictographs". Characters outside every known range are not emoticons.
    """
    range_name: str | None = unicode_range(character)

    if range_name is None:
        return False

    return any(keyword in range_name for keyword in ("Emoticons", "Pictographs"))
158
159
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_separator(character: str) -> bool:
    """
    Tell whether a character acts as a separator.

    Whitespace, the literals "|", "+", "<", ">", any character whose Unicode
    category contains "Z", and the categories "Po", "Pd" and "Pc" qualify.
    """
    if character.isspace():
        return True

    if character in {"|", "+", "<", ">"}:
        return True

    category: str = unicodedata.category(character)

    if "Z" in category:
        return True

    return category in {"Po", "Pd", "Pc"}
168
169
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_case_variable(character: str) -> bool:
    """Tell whether exactly one of str.islower() / str.isupper() holds."""
    lowercase: bool = character.islower()
    uppercase: bool = character.isupper()
    return lowercase != uppercase
173
174
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_cjk(character: str) -> bool:
    """Tell whether the _CJK bit is set in the character's flags."""
    flags = _character_flags(character)
    return (flags & _CJK) != 0
178
179
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_hiragana(character: str) -> bool:
    """Tell whether the _HIRAGANA bit is set in the character's flags."""
    flags = _character_flags(character)
    return (flags & _HIRAGANA) != 0
183
184
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_katakana(character: str) -> bool:
    """Tell whether the _KATAKANA bit is set in the character's flags."""
    flags = _character_flags(character)
    return (flags & _KATAKANA) != 0
188
189
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_hangul(character: str) -> bool:
    """Tell whether the _HANGUL bit is set in the character's flags."""
    flags = _character_flags(character)
    return (flags & _HANGUL) != 0
193
194
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_thai(character: str) -> bool:
    """Tell whether the _THAI bit is set in the character's flags."""
    flags = _character_flags(character)
    return (flags & _THAI) != 0
198
199
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_arabic(character: str) -> bool:
    """Tell whether the _ARABIC bit is set in the character's flags."""
    flags = _character_flags(character)
    return (flags & _ARABIC) != 0
203
204
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_arabic_isolated_form(character: str) -> bool:
    """Tell whether the _ARABIC_ISOLATED_FORM bit is set in the character's flags."""
    flags = _character_flags(character)
    return (flags & _ARABIC_ISOLATED_FORM) != 0
208
209
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_cjk_uncommon(character: str) -> bool:
    """
    Tell whether the character is absent from COMMON_CJK_CHARACTERS.

    No check is made that the character is CJK in the first place.
    """
    listed_as_common = character in COMMON_CJK_CHARACTERS
    return not listed_as_common
213
214
@lru_cache(maxsize=len(UNICODE_RANGES_COMBINED))
def is_unicode_range_secondary(range_name: str) -> bool:
    """Tell whether the range name contains any UNICODE_SECONDARY_RANGE_KEYWORD entry."""
    for keyword in UNICODE_SECONDARY_RANGE_KEYWORD:
        if keyword in range_name:
            return True

    return False
218
219
@lru_cache(maxsize=UTF8_MAXIMAL_ALLOCATION)
def is_unprintable(character: str) -> bool:
    """
    Tell whether a character is neither whitespace nor printable.

    Two characters are deliberately never reported as unprintable:
    "\\x1a" (the ASCII substitute character) and "\\ufeff" (Zero Width
    No-Break Space, which Python does not acknowledge as a space).
    """
    # Whitespace covers \n \t \r \v; anything printable is fine as well.
    if character.isspace() or character.isprintable():
        return False

    return character not in {"\x1a", "\ufeff"}
229
230
def any_specified_encoding(sequence: bytes, search_zone: int = 8192) -> str | None:
    """
    Search the head of a byte sequence for a declared encoding.

    The first ``search_zone`` bytes are decoded as ASCII (undecodable bytes are
    dropped) and scanned with RE_POSSIBLE_ENCODING_INDICATION. The first
    candidate that matches either side of a Python encoding alias entry wins,
    and the Python codec name of that entry is returned. None is returned when
    nothing usable is found.

    Raises TypeError when ``sequence`` is neither bytes nor bytearray.
    """
    if not isinstance(sequence, (bytes, bytearray)):
        raise TypeError

    head = sequence[: min(len(sequence), search_zone)]

    candidates: list[str] = findall(
        RE_POSSIBLE_ENCODING_INDICATION,
        head.decode("ascii", errors="ignore"),
    )

    for candidate in candidates:
        normalized = candidate.lower().replace("-", "_")

        # Linear scan on purpose: the result follows the alias table order.
        for alias, codec_name in aliases.items():
            if normalized == alias or normalized == codec_name:
                return codec_name

    return None
261
262
@lru_cache(maxsize=128)
def is_multi_byte_encoding(name: str) -> bool:
    """
    Tell whether the encoding, given by its Python codec name, is multi-byte.

    The UTF family is listed explicitly; any other codec is multi-byte when
    its IncrementalDecoder derives from MultibyteIncrementalDecoder.
    """
    if name in {
        "utf_8",
        "utf_8_sig",
        "utf_16",
        "utf_16_be",
        "utf_16_le",
        "utf_32",
        "utf_32_le",
        "utf_32_be",
        "utf_7",
    }:
        return True

    decoder_class = importlib.import_module(f"encodings.{name}").IncrementalDecoder

    return issubclass(decoder_class, MultibyteIncrementalDecoder)
282
283
def identify_sig_or_bom(sequence: bytes) -> tuple[str | None, bytes]:
    """
    Detect a leading SIG/BOM in the given sequence.

    Returns the (encoding name, matching mark) pair for the first entry of
    ENCODING_MARKS the sequence starts with, or (None, b"") when none matches.
    """
    for iana_encoding in ENCODING_MARKS:
        declared: bytes | list[bytes] = ENCODING_MARKS[iana_encoding]

        # An entry is either a single mark or a list of alternatives.
        candidates = [declared] if isinstance(declared, bytes) else declared

        for mark in candidates:
            if sequence.startswith(mark):
                return iana_encoding, mark

    return None, b""
300
301
def should_strip_sig_or_bom(iana_encoding: str) -> bool:
    """Return False for "utf_16" and "utf_32", True for any other encoding name."""
    keeps_mark = iana_encoding in {"utf_16", "utf_32"}
    return not keeps_mark
304
305
def iana_name(cp_name: str, strict: bool = True) -> str:
    """
    Return the Python normalized encoding name (not the official IANA name).

    The input is lower-cased and dashes become underscores before it is
    compared with both sides of every entry of the Python alias table. When
    nothing matches, ValueError is raised if ``strict`` is set; otherwise the
    normalized input is returned as is.
    """
    cp_name = cp_name.lower().replace("-", "_")

    for alias, codec_name in aliases.items():
        if cp_name == alias or cp_name == codec_name:
            return codec_name

    if not strict:
        return cp_name

    raise ValueError(f"Unable to retrieve IANA for '{cp_name}'")
321
322
def cp_similarity(iana_name_a: str, iana_name_b: str) -> float:
    """
    Measure how alike two single-byte code pages are.

    Every byte value from 0 to 255 is decoded with both codecs (undecodable
    bytes are ignored) and the share of identical results is returned as a
    ratio between 0.0 and 1.0. Returns 0.0 as soon as either encoding is
    multi-byte.
    """
    if is_multi_byte_encoding(iana_name_a) or is_multi_byte_encoding(iana_name_b):
        return 0.0

    module_a = importlib.import_module(f"encodings.{iana_name_a}")
    module_b = importlib.import_module(f"encodings.{iana_name_b}")

    first: IncrementalDecoder = module_a.IncrementalDecoder(errors="ignore")
    second: IncrementalDecoder = module_b.IncrementalDecoder(errors="ignore")

    identical: int = 0

    for byte_value in range(256):
        single_byte: bytes = bytes([byte_value])
        decoded_first = first.decode(single_byte)
        decoded_second = second.decode(single_byte)
        if decoded_first == decoded_second:
            identical += 1

    return identical / 256
341
342
def is_cp_similar(iana_name_a: str, iana_name_b: str) -> bool:
    """
    Determine if two code pages are at least 80% similar.

    This is a pure lookup: the answer is True only when ``iana_name_b`` is
    listed under ``iana_name_a`` in IANA_SUPPORTED_SIMILAR, a dict that was
    generated using the function cp_similarity.
    """
    if iana_name_a not in IANA_SUPPORTED_SIMILAR:
        return False

    return iana_name_b in IANA_SUPPORTED_SIMILAR[iana_name_a]
352
353
def set_logging_handler(
    name: str = "charset_normalizer",
    level: int = logging.INFO,
    format_string: str = "%(asctime)s | %(levelname)s | %(message)s",
) -> None:
    """
    Attach a formatted stream handler to the named logger and set its level.

    Each call adds a new StreamHandler to the logger.
    """
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_string))

    target_logger = logging.getLogger(name)
    target_logger.setLevel(level)
    target_logger.addHandler(stream_handler)
365
366
def cut_sequence_chunks(
    sequences: bytes,
    encoding_iana: str,
    offsets: range,
    chunk_size: int,
    bom_or_sig_available: bool,
    strip_sig_or_bom: bool,
    sig_payload: bytes,
    is_multi_byte_decoder: bool,
    decoded_payload: str | None = None,
) -> Generator[str, None, None]:
    """
    Yield decoded text chunks taken at the given offsets.

    Two strategies are used:

    * When ``decoded_payload`` is available and the decoder is single-byte,
      chunks are sliced straight out of ``decoded_payload``; iteration stops
      at the first empty slice.
    * Otherwise each chunk is cut from ``sequences`` and decoded with
      ``encoding_iana``. Offsets whose chunk would end more than 8 bytes past
      the end of ``sequences`` are skipped. When a SIG/BOM is present and must
      not be stripped, ``sig_payload`` is prepended to every cut before
      decoding. Multi-byte decoders ignore decoding errors; single-byte
      decoders are strict, so a UnicodeDecodeError propagates to the caller.

    For multi-byte decoders, a chunk that does not start at offset 0 and whose
    beginning cannot be found in ``decoded_payload`` is assumed to be cut in
    the middle of a character: the cut is retried up to three bytes earlier,
    never going below index 0, until the beginning of the chunk is found in
    ``decoded_payload``. If no retry matches, the last attempt is yielded.
    """
    if decoded_payload and is_multi_byte_decoder is False:
        for i in offsets:
            chunk = decoded_payload[i : i + chunk_size]
            if not chunk:
                break
            yield chunk
    else:
        sequences_length: int = len(sequences)

        for i in offsets:
            chunk_end = i + chunk_size
            if chunk_end > sequences_length + 8:
                continue

            cut_sequence = sequences[i:chunk_end]

            if bom_or_sig_available and strip_sig_or_bom is False:
                cut_sequence = sig_payload + cut_sequence

            chunk = cut_sequence.decode(
                encoding_iana,
                errors="ignore" if is_multi_byte_decoder else "strict",
            )

            # multi-byte bad cutting detector and adjustment
            # not the cleanest way to perform that fix but clever enough for now.
            if is_multi_byte_decoder and i > 0:
                chunk_partial_size_chk: int = min(chunk_size, 16)

                if (
                    decoded_payload
                    and chunk[:chunk_partial_size_chk] not in decoded_payload
                ):
                    # Walk the cut start backwards, clamped at index 0: a
                    # negative start would wrap around to the end of
                    # ``sequences`` and produce a meaningless (usually empty)
                    # slice that trivially "matches".
                    for j in range(i, max(i - 4, -1), -1):
                        cut_sequence = sequences[j:chunk_end]

                        if bom_or_sig_available and strip_sig_or_bom is False:
                            cut_sequence = sig_payload + cut_sequence

                        chunk = cut_sequence.decode(encoding_iana, errors="ignore")

                        if chunk[:chunk_partial_size_chk] in decoded_payload:
                            break

            yield chunk