import unicodedata
import os
from itertools import product
from collections import deque
from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable

###{standalone
import sys, re
import logging

logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
# Set to highest level, since we have some warnings amongst the code
# By default, we should not output any log messages
logger.setLevel(logging.CRITICAL)


NO_VALUE = object()

T = TypeVar("T")


def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    d: Dict[Any, Any] = {}
    for item in seq:
        k = key(item) if (key is not None) else item
        v = value(item) if (value is not None) else item
        try:
            d[k].append(v)
        except KeyError:
            d[k] = [v]
    return d
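
# Example (illustrative sketch, not part of the lark source): group words by
# their first letter, keeping the word itself as the value.
#
#   classify(["apple", "avocado", "banana"], key=lambda s: s[0])
#   # -> {'a': ['apple', 'avocado'], 'b': ['banana']}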


def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
    if isinstance(data, dict):
        if '__type__' in data:   # Object
            class_ = namespace[data['__type__']]
            return class_.deserialize(data, memo)
        elif '@' in data:
            return memo[data['@']]
        return {key:_deserialize(value, namespace, memo) for key, value in data.items()}
    elif isinstance(data, list):
        return [_deserialize(value, namespace, memo) for value in data]
    return data
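
# Example (illustrative sketch, not part of the lark source): plain containers
# are rebuilt recursively, and {'@': n} resolves a memoized back-reference.
#
#   _deserialize({'xs': [1, 2], 'ref': {'@': 0}}, {}, memo={0: 'shared'})
#   # -> {'xs': [1, 2], 'ref': 'shared'}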


_T = TypeVar("_T", bound="Serialize")

class Serialize:
    """Safe-ish serialization interface that doesn't rely on Pickle

    Attributes:
        __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
        __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
                                        Should include all field types that aren't builtin types.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        memo = SerializeMemoizer(types_to_memoize)
        return self.serialize(memo), memo.serialize()

    def serialize(self, memo = None) -> Dict[str, Any]:
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        fields = getattr(self, '__serialize_fields__')
        res = {f: _serialize(getattr(self, f), memo) for f in fields}
        res['__type__'] = type(self).__name__
        if hasattr(self, '_serialize'):
            self._serialize(res, memo)  # type: ignore[attr-defined]
        return res

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        namespace = getattr(cls, '__serialize_namespace__', [])
        namespace = {c.__name__:c for c in namespace}

        fields = getattr(cls, '__serialize_fields__')

        if '@' in data:
            return memo[data['@']]

        inst = cls.__new__(cls)
        for f in fields:
            try:
                setattr(inst, f, _deserialize(data[f], namespace, memo))
            except KeyError as e:
                raise KeyError("Cannot find key for class", cls, e)

        if hasattr(inst, '_deserialize'):
            inst._deserialize()  # type: ignore[attr-defined]

        return inst
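
# Example (illustrative sketch; `Point` is a hypothetical subclass, not part of
# lark): a round-trip through serialize/deserialize.
#
#   class Point(Serialize):
#       __serialize_fields__ = 'x', 'y'
#       def __init__(self, x, y):
#           self.x, self.y = x, y
#
#   data = Point(1, 2).serialize()        # {'x': 1, 'y': 2, '__type__': 'Point'}
#   p = Point.deserialize(data, memo={})  # rebuilds a Point without calling __init__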


class SerializeMemoizer(Serialize):
    "A version of serialize that memoizes objects to reduce space"

    __serialize_fields__ = 'memoized',

    def __init__(self, types_to_memoize: List) -> None:
        self.types_to_memoize = tuple(types_to_memoize)
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        return isinstance(value, self.types_to_memoize)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        return _serialize(self.memoized.reversed(), None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        return _deserialize(data, namespace, memo)


try:
    import regex
    _has_regex = True
except ImportError:
    _has_regex = False

if sys.version_info >= (3, 11):
    import re._parser as sre_parse
    import re._constants as sre_constants
else:
    import sre_parse
    import sre_constants

categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')

def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    if _has_regex:
        # Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with
        # a simple letter, which makes no difference as we are only trying to get the possible lengths of the
        # regex match below.
        regexp_final = re.sub(categ_pattern, 'A', expr)
    else:
        if re.search(categ_pattern, expr):
            raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
        regexp_final = expr
    try:
        # Fixed in next version (past 0.960) of typeshed
        return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]  # type: ignore[attr-defined]
    except sre_constants.error:
        if not _has_regex:
            raise ValueError(expr)
        else:
            # sre_parse does not support the new features in regex. To avoid failing completely in that case,
            # we manually test for the most important info (whether the empty string is matched)
            c = regex.compile(regexp_final)
            if c.match('') is None:
                # MAXREPEAT is a non-picklable subclass of int, so it needs to be converted to enable caching
                return 1, int(sre_constants.MAXREPEAT)
            else:
                return 0, int(sre_constants.MAXREPEAT)

###}
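
# Example (illustrative sketch, not part of the lark source): the result is the
# (min, max) match width of the pattern.
#
#   get_regexp_width(r'a?b')  # -> [1, 2]
#   get_regexp_width(r'ab*')  # -> [1, 4294967295] (MAXREPEAT stands in for "unbounded")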


_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)

def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
    if len(s) != 1:
        return all(_test_unicode_category(char, categories) for char in s)
    return s == '_' or unicodedata.category(s) in categories

def is_id_continue(s: str) -> bool:
    """
    Checks if all characters in `s` are alphanumeric characters (Unicode standard, so diacritics, Indic vowel
    signs, non-Latin numbers, etc. all pass). Synonymous with a Python `ID_CONTINUE` identifier. See PEP 3131
    for details.
    """
    return _test_unicode_category(s, _ID_CONTINUE)

def is_id_start(s: str) -> bool:
    """
    Checks if all characters in `s` are alphabetic characters (Unicode standard, so diacritics, Indic vowel
    signs, non-Latin letters, etc. all pass). Synonymous with a Python `ID_START` identifier. See PEP 3131
    for details.
    """
    return _test_unicode_category(s, _ID_START)
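
# Example (illustrative sketch, not part of the lark source):
#
#   is_id_start('foo')     # -> True
#   is_id_start('2foo')    # -> False ('2' is a digit, so not valid at the start)
#   is_id_continue('foo2') # -> True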


def dedup_list(l: List[T]) -> List[T]:
    """Given a list (l), removes duplicates from it while
       preserving the original order. Assumes that
       the list entries are hashable."""
    dedup = set()
    # `dedup.add` returns None (falsy), so first occurrences pass the filter and get recorded; that's expected
    return [x for x in l if not (x in dedup or dedup.add(x))]  # type: ignore[func-returns-value]
    # 2x faster (ordered in PyPy and CPython 3.6+, guaranteed to be ordered in Python 3.7+)
    # return list(dict.fromkeys(l))
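
# Example (illustrative sketch, not part of the lark source):
#
#   dedup_list([1, 2, 1, 3, 2])  # -> [1, 2, 3]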


class Enumerator(Serialize):
    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        if item not in self.enums:
            self.enums[item] = len(self.enums)
        return self.enums[item]

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        r = {v: k for k, v in self.enums.items()}
        assert len(r) == len(self.enums)
        return r
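
# Example (illustrative sketch, not part of the lark source): items get stable,
# dense integer ids in first-seen order.
#
#   e = Enumerator()
#   e.get('a'), e.get('b'), e.get('a')  # -> (0, 1, 0)
#   e.reversed()                        # -> {0: 'a', 1: 'b'}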


def combine_alternatives(lists):
    """
    Accepts a list of alternatives, and enumerates all their possible concatenations.

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]
    """
    if not lists:
        return [[]]
    assert all(l for l in lists), lists
    return list(product(*lists))


try:
    import atomicwrites
    _has_atomicwrites = True
except ImportError:
    _has_atomicwrites = False

class FS:
    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        if _has_atomicwrites and "w" in mode:
            return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)
        else:
            return open(name, mode, **kwargs)
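
# Example (illustrative sketch, not part of the lark source; the filename is
# made up): writes go through atomicwrites when it is installed, otherwise
# fall back to the builtin open. Both support the context-manager protocol.
#
#   with FS.open('grammar.cache', 'w') as f:
#       f.write('...')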


def isascii(s: str) -> bool:
    """str.isascii only exists in python3.7+"""
    if sys.version_info >= (3, 7):
        return s.isascii()
    else:
        try:
            s.encode('ascii')
            return True
        except (UnicodeDecodeError, UnicodeEncodeError):
            return False
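
# Example (illustrative sketch, not part of the lark source):
#
#   isascii('hello')  # -> True
#   isascii('héllo')  # -> False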


class fzset(frozenset):
    def __repr__(self):
        return '{%s}' % ', '.join(map(repr, self))


def classify_bool(seq: Sequence, pred: Callable) -> Any:
    false_elems = []
    # `list.append` returns None (falsy), so elements failing `pred` are collected
    # into false_elems as a side effect of the filter expression
    true_elems = [elem for elem in seq if pred(elem) or false_elems.append(elem)]  # type: ignore[func-returns-value]
    return true_elems, false_elems
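
# Example (illustrative sketch, not part of the lark source): partition a
# sequence by a predicate in one pass.
#
#   classify_bool([1, 2, 3, 4], lambda x: x % 2 == 0)  # -> ([2, 4], [1, 3])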


def bfs(initial: Sequence, expand: Callable) -> Iterator:
    open_q = deque(list(initial))
    visited = set(open_q)
    while open_q:
        node = open_q.popleft()
        yield node
        for next_node in expand(node):
            if next_node not in visited:
                visited.add(next_node)
                open_q.append(next_node)

def bfs_all_unique(initial, expand):
    "bfs, but doesn't keep track of visited (aka seen), because there can be no repetitions"
    open_q = deque(list(initial))
    while open_q:
        node = open_q.popleft()
        yield node
        open_q += expand(node)
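
# Example (illustrative sketch, not part of the lark source): breadth-first
# traversal of a small DAG given as an adjacency mapping.
#
#   graph = {'a': 'bc', 'b': 'd', 'c': 'd', 'd': ''}
#   list(bfs('a', lambda n: graph[n]))  # -> ['a', 'b', 'c', 'd']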


def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    if isinstance(value, Serialize):
        return value.serialize(memo)
    elif isinstance(value, list):
        return [_serialize(elem, memo) for elem in value]
    elif isinstance(value, frozenset):
        return list(value)  # TODO reversible?
    elif isinstance(value, dict):
        return {key:_serialize(elem, memo) for key, elem in value.items()}
    # assert value is None or isinstance(value, (int, float, str, tuple)), value
    return value


def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """
    Splits n up into smaller factors and summands <= max_factor.
    Returns a list of [(a, b), ...]
    so that the following code returns n:

        n = 1
        for a, b in values:
            n = n * a + b

    Currently, we also keep a + b <= max_factor, but that might change
    """
    assert n >= 0
    assert max_factor > 2
    if n <= max_factor:
        return [(n, 0)]

    for a in range(max_factor, 1, -1):
        r, b = divmod(n, a)
        if a + b <= max_factor:
            return small_factors(r, max_factor) + [(a, b)]
    assert False, "Failed to factorize %s" % n
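
# Example (illustrative sketch, not part of the lark source): factor 10 with
# pieces no larger than 4, then rebuild it with the loop from the docstring.
#
#   small_factors(10, 4)  # -> [(3, 0), (3, 1)], since (1*3 + 0)*3 + 1 == 10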