Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/lark/utils.py: 58%
211 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-14 06:19 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-14 06:19 +0000
1import unicodedata
2import os
3from itertools import product
4from collections import deque
5from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable, AbstractSet
7###{standalone
8import sys, re
9import logging
# Unique marker object; being a fresh `object()`, it can only be recognised by identity (`is`).
NO_VALUE = object()

# Generic element type shared by the typed helpers of this module.
T = TypeVar("T")

logger: logging.Logger = logging.getLogger("lark")
# Some warnings are emitted from within the library, so keep the logger silent by default:
# only CRITICAL records get through unless the level is lowered.
logger.setLevel(logging.CRITICAL)
logger.addHandler(logging.StreamHandler())
def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    """Group the items of `seq` into a dict mapping each key to a list of values.

    Args:
        seq: items to group.
        key: maps an item to its group key; the item itself is used when omitted.
        value: maps an item to what gets stored in the group; the item itself when omitted.

    Returns:
        A plain dict ``{group_key: [values...]}``; both groups and values keep first-seen order.
    """
    groups: Dict[Any, Any] = {}
    for item in seq:
        group_key = item if key is None else key(item)
        group_value = item if value is None else value(item)
        groups.setdefault(group_key, []).append(group_value)
    return groups
def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
    """Recursively rebuild a value produced by `_serialize`.

    - lists are rebuilt element by element;
    - a dict with ``'__type__'`` is handed to ``namespace[type_name].deserialize(data, memo)``;
    - a dict with ``'@'`` is a memo reference and is looked up in `memo`;
    - any other dict has its values rebuilt recursively;
    - everything else is returned unchanged.
    """
    if isinstance(data, list):
        return [_deserialize(elem, namespace, memo) for elem in data]
    if not isinstance(data, dict):
        return data

    if '__type__' in data:  # a serialized object
        target_cls = namespace[data['__type__']]
        return target_cls.deserialize(data, memo)
    if '@' in data:  # a reference into the memo table
        return memo[data['@']]
    return {k: _deserialize(v, namespace, memo) for k, v in data.items()}
_T = TypeVar("_T", bound="Serialize")


class Serialize:
    """Mixin providing a dict-based (de)serialization protocol that does not rely on pickle.

    Subclasses declare:
        __serialize_fields__ (List[str]): names of the attributes written by `serialize` and
            restored by `deserialize`.
        __serialize_namespace__ (list): classes that `deserialize` is allowed to instantiate for
            nested ``{'__type__': name, ...}`` records. Every field type that is not a builtin
            belongs here.

    Optional hooks: ``_serialize(res, memo)`` may post-process the output dict, and
    ``_deserialize()`` is called on a freshly rebuilt instance.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        """Serialize `self`, storing instances of `types_to_memoize` once in a memo table.

        Returns a ``(data, memo_table)`` pair.
        """
        memoizer = SerializeMemoizer(types_to_memoize)
        data = self.serialize(memoizer)
        return data, memoizer.serialize()

    def serialize(self, memo = None) -> Dict[str, Any]:
        # Objects of a memoized type are written as a reference to their id in the memo table.
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        res = {}
        for name in getattr(self, '__serialize_fields__'):
            res[name] = _serialize(getattr(self, name), memo)
        res['__type__'] = type(self).__name__
        if hasattr(self, '_serialize'):
            self._serialize(res, memo)  # type: ignore[attr-defined]
        return res

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        """Rebuild an instance of `cls` from the dict produced by `serialize`.

        A ``{'@': id}`` record is resolved through `memo` instead of building a new instance.
        The instance is created with ``cls.__new__`` (``__init__`` is not run).

        Raises:
            KeyError: a field listed in ``__serialize_fields__`` is missing from `data`, or a
                nested namespace/memo lookup fails.
        """
        allowed = {klass.__name__: klass for klass in getattr(cls, '__serialize_namespace__', [])}
        fields = getattr(cls, '__serialize_fields__')

        if '@' in data:
            return memo[data['@']]

        inst = cls.__new__(cls)
        for name in fields:
            try:
                setattr(inst, name, _deserialize(data[name], allowed, memo))
            except KeyError as e:
                raise KeyError("Cannot find key for class", cls, e)

        if hasattr(inst, '_deserialize'):
            inst._deserialize()  # type: ignore[attr-defined]
        return inst
class SerializeMemoizer(Serialize):
    "Memo table used by `Serialize.memo_serialize`: stores each object of selected types only once"

    __serialize_fields__ = ('memoized',)

    def __init__(self, types_to_memoize: List) -> None:
        # Kept as a tuple so it can be handed straight to isinstance().
        self.types_to_memoize = tuple(types_to_memoize)
        # Assigns an integer id to every memoized object.
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        """Return True if `value` is an instance of one of the memoized types."""
        return isinstance(value, self.types_to_memoize)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        """Serialize the memo table as ``{id: serialized_object}``."""
        id_to_obj = self.memoized.reversed()
        return _serialize(id_to_obj, None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        """Rebuild the ``{id: object}`` memo table from its serialized form."""
        return _deserialize(data, namespace, memo)
try:
    import regex
    _has_regex = True
except ImportError:
    _has_regex = False

if sys.version_info >= (3, 11):
    import re._parser as sre_parse
    import re._constants as sre_constants
else:
    import sre_parse
    import sre_constants

# Unicode property escapes such as `\p{Lu}` or the negated `\P{L}`. Only the third-party `regex`
# module understands them; `re`/`sre_parse` reject them as bad escapes.
categ_pattern = re.compile(r'\\[pP]\{[A-Za-z_]+\}')

def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    r"""Return the minimum and maximum length of the strings that the regexp `expr` can match.

    When `sre_parse` can analyse the pattern, the exact widths are returned as a 2-item list.
    Otherwise (only possible when `regex` is installed, for syntax that `sre_parse` does not
    support), a 2-tuple is returned. Its minimum is 0 if the pattern matches the empty string and
    1 otherwise, and its maximum is the parser's "unbounded" width.

    Raises:
        ImportError: `expr` uses a Unicode property escape (``\p{..}`` / ``\P{..}``) but the `regex`
            module is not installed.
        ValueError: `expr` cannot be parsed and `regex` is not installed.
    """
    if _has_regex:
        # `sre_parse` cannot deal with Unicode property escapes. Each one matches a single character,
        # so replacing it with a plain letter leaves the computed widths unchanged.
        regexp_final = categ_pattern.sub('A', expr)
    else:
        if categ_pattern.search(expr):
            raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
        regexp_final = expr
    try:
        # Fixed in next version (past 0.960) of typeshed
        return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]  # type: ignore[attr-defined]
    except sre_constants.error as err:
        if not _has_regex:
            raise ValueError(expr) from err
        # sre_parse does not support the newer features of `regex`. Rather than fail completely,
        # test for the most important piece of information: whether the empty string is matched.
        c = regex.compile(regexp_final)
        # Python 3.11.7 introduced sre_parse.MAXWIDTH, which is used instead of MAXREPEAT.
        # See lark-parser/lark#1376 and python/cpython#109859
        MAXWIDTH = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT)
        # MAXREPEAT is a non-picklable subclass of int, so convert it to allow caching.
        if c.match('') is None:
            return 1, int(MAXWIDTH)
        else:
            return 0, int(MAXWIDTH)
161###}
# Unicode general categories accepted by `is_id_start` / `is_id_continue`.
# NOTE: these are close to, but not the same as, Python's identifier rules (PEP 3131).
# Here combining marks (Mn, Mc) and connector punctuation (Pc) are allowed as the first character,
# while letter numbers (Nl) are only allowed after it.
_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)

def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
    """Return True if every character of `s` is '_' or has a Unicode category in `categories`.

    The empty string passes vacuously.
    """
    return all(ch == '_' or unicodedata.category(ch) in categories for ch in s)

def is_id_continue(s: str) -> bool:
    """
    Checks if all characters in `s` may appear inside an identifier.

    These are letters (Lu, Ll, Lt, Lm, Lo), combining marks (Mn, Mc), connector punctuation (Pc, which
    includes '_'), decimal digits (Nd) and letter numbers (Nl). Unicode-aware, so diacritics, Indic vowel
    signs, non-Latin digits, etc. all pass. This matches the general categories of Python's
    `ID_CONTINUE` (PEP 3131), minus its small `Other_ID_*` extras. Returns True for the empty string.
    """
    return _test_unicode_category(s, _ID_CONTINUE)

def is_id_start(s: str) -> bool:
    """
    Checks if all characters in `s` may start an identifier.

    These are letters (Lu, Ll, Lt, Lm, Lo), combining marks (Mn, Mc) and connector punctuation (Pc,
    which includes '_'). Digits of any script do NOT pass. Unlike Python's `ID_START` (PEP 3131),
    this accepts Mn/Mc/Pc and rejects letter numbers (Nl). Returns True for the empty string.
    """
    return _test_unicode_category(s, _ID_START)
def dedup_list(l: Sequence[T]) -> List[T]:
    """Return the items of `l` with duplicates removed, keeping the first occurrence of each.

    The original order is preserved. The items must be hashable.
    """
    # dict keys keep insertion order (guaranteed since Python 3.7, and already true for CPython 3.6
    # and PyPy), so the keys are exactly the unique items in first-seen order.
    return list(dict.fromkeys(l))
class Enumerator(Serialize):
    """Assigns consecutive integer ids (0, 1, 2, ...) to items in the order they are first seen."""

    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        """Return the id of `item`, allocating the next free id if it has not been seen yet."""
        # The default is evaluated before insertion, so a new item gets the current size as its id.
        return self.enums.setdefault(item, len(self.enums))

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        """Return the inverse mapping ``{id: item}``."""
        inverse = dict(zip(self.enums.values(), self.enums.keys()))
        assert len(inverse) == len(self.enums)
        return inverse
def combine_alternatives(lists):
    """
    Accepts a list of alternatives, and enumerates all their possible concatenations.

    Each combination is returned as a tuple, as produced by `itertools.product`. The one exception
    is an empty input, which yields a single empty list.

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]
    """
    if not lists:
        return [[]]
    # Every group of alternatives must be non-empty; otherwise the product would silently be empty.
    assert all(l for l in lists), lists
    return list(product(*lists))
236try:
237 # atomicwrites doesn't have type bindings
238 import atomicwrites # type: ignore[import]
239 _has_atomicwrites = True
240except ImportError:
241 _has_atomicwrites = False
class FS:
    """Small filesystem facade: an `exists` check plus an `open` that prefers atomic writes."""

    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        """Open `name` like the builtin `open`.

        When the optional `atomicwrites` package was importable and `mode` contains "w", the file
        is opened through ``atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)``.
        """
        if not (_has_atomicwrites and "w" in mode):
            # Class attributes are not in scope inside methods, so this is the builtin `open`.
            return open(name, mode, **kwargs)
        return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)
def isascii(s: str) -> bool:
    """Return True if every character of `s` is ASCII (True for the empty string).

    `str.isascii` only exists on Python 3.7+; older interpreters fall back to trying an ASCII encode.
    """
    if sys.version_info < (3, 7):
        try:
            s.encode('ascii')
        except (UnicodeDecodeError, UnicodeEncodeError):
            return False
        return True
    return s.isascii()
class fzset(frozenset):
    """A frozenset whose repr uses set-literal braces, e.g. ``{1, 2}``."""

    def __repr__(self):
        return '{' + ', '.join(repr(item) for item in self) + '}'
def classify_bool(seq: Iterable, pred: Callable) -> Any:
    """Split `seq` into ``(matching, non_matching)`` lists according to the truthiness of `pred`.

    `pred` is called exactly once per element; both lists keep the original order.
    """
    matching = []
    non_matching = []
    for elem in seq:
        if pred(elem):
            matching.append(elem)
        else:
            non_matching.append(elem)
    return matching, non_matching
def bfs(initial: Iterable, expand: Callable) -> Iterator:
    """Lazily yield nodes in breadth-first order, starting from `initial`.

    ``expand(node)`` returns the neighbours of `node`. It is called only after `node` has been
    yielded. Each neighbour is queued at most once. The starting nodes are queued as given,
    so duplicates in `initial` are yielded as many times as they appear.
    """
    frontier = deque(list(initial))
    seen = set(frontier)
    while frontier:
        current = frontier.popleft()
        yield current
        for neighbour in expand(current):
            if neighbour in seen:
                continue
            seen.add(neighbour)
            frontier.append(neighbour)
def bfs_all_unique(initial, expand):
    "Breadth-first traversal without a visited set: callers guarantee no node is reached twice"
    queue = deque(list(initial))
    while queue:
        current = queue.popleft()
        yield current
        queue.extend(expand(current))
def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    """Convert `value` into plain data for serialization.

    `Serialize` objects are serialized via their own method (passing `memo` along). Lists and dict
    values are converted recursively. Frozensets become plain lists. Everything else is returned
    unchanged.
    """
    if isinstance(value, Serialize):
        return value.serialize(memo)
    if isinstance(value, list):
        return [_serialize(item, memo) for item in value]
    if isinstance(value, frozenset):
        # NOTE: elements are not converted recursively, and the result is a list, which
        # `_deserialize` rebuilds as a list rather than a frozenset.  TODO reversible?
        return list(value)
    if isinstance(value, dict):
        return {k: _serialize(v, memo) for k, v in value.items()}
    # assert value is None or isinstance(value, (int, float, str, tuple)), value
    return value
def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """
    Splits n up into smaller factors and summands <= max_factor.
    Returns a list of [(a, b), ...]
    so that the following code rebuilds n:

        acc = 1
        for a, b in values:
            acc = acc * a + b

    Currently, we also keep a + b <= max_factor, but that might change
    """
    assert n >= 0
    assert max_factor > 2
    steps: List[Tuple[int, int]] = []
    remaining = n
    # Peel off the outermost (a, b) pairs first; they belong at the end of the result.
    while remaining > max_factor:
        for a in range(max_factor, 1, -1):
            quotient, b = divmod(remaining, a)
            if a + b <= max_factor:
                steps.append((a, b))
                remaining = quotient
                break
        else:
            assert False, "Failed to factorize %s" % remaining
    steps.append((remaining, 0))
    steps.reverse()
    return steps
class OrderedSet(AbstractSet[T]):
    """A minimal insertion-ordered set, backed by the keys of a dict.

    Iteration follows insertion order because dicts preserve it. The read-only set operators
    come from the `AbstractSet` mixin methods.
    """

    def __init__(self, items: Iterable[T] =()):
        # Only the keys matter; every value is None.
        self.d = dict.fromkeys(items)

    def __len__(self) -> int:
        return len(self.d)

    def __bool__(self):
        return len(self.d) > 0

    def __contains__(self, item: Any) -> bool:
        return item in self.d

    def __iter__(self) -> Iterator[T]:
        return iter(self.d)

    def add(self, item: T):
        """Insert `item`; an item already present keeps its original position."""
        self.d[item] = None

    def remove(self, item: T):
        """Delete `item`; raises KeyError if it is not present."""
        del self.d[item]