Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/lark/utils.py: 50%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import unicodedata
2import os
3from itertools import product
4from collections import deque
5from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable, AbstractSet
7###{standalone
8import sys, re
9import logging
# Library-wide logger for lark; a StreamHandler is attached so messages are
# visible even if the host application never configures logging.
logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
# Set to highest level, since we have some warnings amongst the code
# By default, we should not output any log messages
logger.setLevel(logging.CRITICAL)


# Sentinel distinct from None, for "no value was provided".
NO_VALUE = object()

# Generic type variable used by the typed helpers below.
T = TypeVar("T")
def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    """Group the items of *seq* into a dict of lists.

    Each item is bucketed under ``key(item)`` (or the item itself when *key*
    is None); the stored values are ``value(item)`` (or the item itself when
    *value* is None). Insertion order of buckets and values is preserved.
    """
    buckets: Dict[Any, Any] = {}
    for elem in seq:
        group = key(elem) if key is not None else elem
        entry = value(elem) if value is not None else elem
        buckets.setdefault(group, []).append(entry)
    return buckets
35def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
36 if isinstance(data, dict):
37 if '__type__' in data: # Object
38 class_ = namespace[data['__type__']]
39 return class_.deserialize(data, memo)
40 elif '@' in data:
41 return memo[data['@']]
42 return {key:_deserialize(value, namespace, memo) for key, value in data.items()}
43 elif isinstance(data, list):
44 return [_deserialize(value, namespace, memo) for value in data]
45 return data
_T = TypeVar("_T", bound="Serialize")


class Serialize:
    """Safe-ish serialization interface that doesn't rely on Pickle

    Attributes:
        __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
        __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
            Should include all field types that aren't builtin types.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        # Returns (serialized self, serialized memo table).
        memo = SerializeMemoizer(types_to_memoize)
        return self.serialize(memo), memo.serialize()

    def serialize(self, memo=None) -> Dict[str, Any]:
        # Instances of memoized types collapse to a reference entry.
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        field_names = getattr(self, '__serialize_fields__')
        result = {name: _serialize(getattr(self, name), memo) for name in field_names}
        result['__type__'] = type(self).__name__
        if hasattr(self, '_serialize'):
            self._serialize(result, memo)  # optional per-class hook
        return result

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        namespace = {c.__name__: c for c in getattr(cls, '__serialize_namespace__', [])}

        field_names = getattr(cls, '__serialize_fields__')

        if '@' in data:  # reference to a memoized object
            return memo[data['@']]

        # Bypass __init__; fields are restored directly.
        inst = cls.__new__(cls)
        for name in field_names:
            try:
                setattr(inst, name, _deserialize(data[name], namespace, memo))
            except KeyError as e:
                raise KeyError("Cannot find key for class", cls, e)

        if hasattr(inst, '_deserialize'):
            inst._deserialize()  # optional per-class hook

        return inst
class SerializeMemoizer(Serialize):
    "A version of serialize that memoizes objects to reduce space"

    # Only the memo table itself is serialized.
    __serialize_fields__ = 'memoized',

    def __init__(self, types_to_memoize: List) -> None:
        # tuple() so it can be passed straight to isinstance().
        self.types_to_memoize = tuple(types_to_memoize)
        # Maps each memoized object to a small integer id.
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        # True when *value* should be replaced by a memo reference.
        return isinstance(value, self.types_to_memoize)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        # Emits {id: serialized object} — the inverse of the memo table.
        return _serialize(self.memoized.reversed(), None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        return _deserialize(data, namespace, memo)
117try:
118 import regex
119 _has_regex = True
120except ImportError:
121 _has_regex = False
123if sys.version_info >= (3, 11):
124 import re._parser as sre_parse
125 import re._constants as sre_constants
126else:
127 import sre_parse
128 import sre_constants
# Matches Unicode category escapes such as \p{Mn}, which the stdlib `re`
# parser cannot handle (only the third-party `regex` module can).
categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')
def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    """Return the (minimum, maximum) possible match width of the regex *expr*.

    Widths come from ``sre_parse.parse(...).getwidth()``. When the pattern
    uses `regex`-module-only syntax that `sre_parse` rejects, falls back to
    (0 or 1, MAXWIDTH) depending on whether the empty string matches.
    Raises ImportError if *expr* uses ``\\p{...}`` categories and `regex`
    is not installed, and ValueError when parsing fails without `regex`.
    """
    if _has_regex:
        # Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with
        # a simple letter, which makes no difference as we are only trying to get the possible lengths of the regex
        # match here below.
        regexp_final = re.sub(categ_pattern, 'A', expr)
    else:
        if re.search(categ_pattern, expr):
            raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
        regexp_final = expr
    try:
        # Fixed in next version (past 0.960) of typeshed
        return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]
    except sre_constants.error:
        if not _has_regex:
            raise ValueError(expr)
        else:
            # sre_parse does not support the new features in regex. To not completely fail in that case,
            # we manually test for the most important info (whether the empty string is matched)
            c = regex.compile(regexp_final)
            # Python 3.11.7 introduced sre_parse.MAXWIDTH that is used instead of MAXREPEAT
            # See lark-parser/lark#1376 and python/cpython#109859
            MAXWIDTH = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT)
            if c.match('') is None:
                # MAXREPEAT is a none pickable subclass of int, therefore needs to be converted to enable caching
                return 1, int(MAXWIDTH)
            else:
                return 0, int(MAXWIDTH)
161###}
# Unicode general categories accepted by is_id_start / is_id_continue
# (identifier rules in the spirit of PEP 3131).
_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)
167def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
168 if len(s) != 1:
169 return all(_test_unicode_category(char, categories) for char in s)
170 return s == '_' or unicodedata.category(s) in categories
def is_id_continue(s: str) -> bool:
    """
    Checks if all characters in `s` are alphanumeric characters (Unicode standard, so diacritics, indian vowels, non-latin
    numbers, etc. all pass). Synonymous with a Python `ID_CONTINUE` identifier. See PEP 3131 for details.

    Underscores always pass; an empty string vacuously passes.
    """
    return _test_unicode_category(s, _ID_CONTINUE)
def is_id_start(s: str) -> bool:
    """
    Checks if all characters in `s` are alphabetic characters (Unicode standard, so diacritics, indian vowels, non-latin
    numbers, etc. all pass). Synonymous with a Python `ID_START` identifier. See PEP 3131 for details.

    Underscores always pass; an empty string vacuously passes.
    """
    return _test_unicode_category(s, _ID_START)
def dedup_list(l: Sequence[T]) -> List[T]:
    """Return a copy of *l* with duplicates removed, preserving the original
    order of first appearance. Assumes the entries are hashable."""
    # A dict comprehension keeps insertion order and drops repeats.
    return list({item: None for item in l})
class Enumerator(Serialize):
    """Assigns consecutive integer ids to items, in order of first appearance."""

    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        # len(self.enums) is evaluated before the insert, so a new item
        # receives the next free id; an existing item keeps its old one.
        return self.enums.setdefault(item, len(self.enums))

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        """Return the inverse mapping {id: item}."""
        inverse = {idx: item for item, idx in self.enums.items()}
        assert len(inverse) == len(self.enums)
        return inverse
def combine_alternatives(lists):
    """
    Accepts a list of alternatives, and enumerates all their possible concatenations.

    Each concatenation is returned as a tuple (the docstring previously showed
    lists, but `itertools.product` yields tuples).

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]
    """
    if not lists:
        return [[]]
    # Every alternative must be non-empty, otherwise the product would be empty.
    assert all(l for l in lists), lists
    return list(product(*lists))
232try:
233 import atomicwrites
234 _has_atomicwrites = True
235except ImportError:
236 _has_atomicwrites = False
class FS:
    """Small filesystem facade; writes go through `atomicwrites` when available."""

    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        """Open *name*; write modes are atomic if `atomicwrites` is installed."""
        if _has_atomicwrites and "w" in mode:
            return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)
        # Read (or atomicwrites-less) path: plain builtin open.
        return open(name, mode, **kwargs)
class fzset(frozenset):
    """A frozenset whose repr looks like a set literal, e.g. ``{1, 2}``."""
    def __repr__(self):
        return '{%s}' % ', '.join(repr(member) for member in self)
def classify_bool(seq: Iterable, pred: Callable) -> Any:
    """Partition *seq* into ``(true_elems, false_elems)`` according to *pred*.

    Order within each list follows the original sequence. The previous
    implementation abused a list comprehension for its side effect
    (``... or false_elems.append(elem)``), which needed a ``type: ignore``;
    an explicit loop says the same thing plainly.
    """
    true_elems: List = []
    false_elems: List = []
    for elem in seq:
        if pred(elem):
            true_elems.append(elem)
        else:
            false_elems.append(elem)
    return true_elems, false_elems
def bfs(initial: Iterable, expand: Callable) -> Iterator:
    """Breadth-first traversal starting from *initial*.

    Yields each reachable node exactly once; *expand* maps a node to its
    neighbors.
    """
    queue = deque(initial)
    seen = set(queue)
    while queue:
        current = queue.popleft()
        yield current
        for child in expand(current):
            if child not in seen:
                seen.add(child)
                queue.append(child)
def bfs_all_unique(initial, expand):
    """Breadth-first traversal without a visited set.

    Only safe when *expand* can never produce a node twice — otherwise
    nodes are yielded repeatedly (or the traversal never terminates).
    """
    pending = deque(initial)
    while pending:
        current = pending.popleft()
        yield current
        pending.extend(expand(current))
def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    """Recursively convert *value* into plain JSON-friendly containers.

    Serialize instances delegate to their own serialize(); lists and dicts
    are rebuilt element by element; frozensets become lists; everything
    else passes through unchanged.
    """
    if isinstance(value, Serialize):
        return value.serialize(memo)
    if isinstance(value, list):
        return [_serialize(item, memo) for item in value]
    if isinstance(value, frozenset):
        return list(value)  # TODO reversible?
    if isinstance(value, dict):
        return {k: _serialize(v, memo) for k, v in value.items()}
    # assert value is None or isinstance(value, (int, float, str, tuple)), value
    return value
def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """
    Splits n up into smaller factors and summands <= max_factor.
    Returns a list of [(a, b), ...]
    so that the following code returns n:

        n = 1
        for a, b in values:
            n = n * a + b

    Currently, we also keep a + b <= max_factor, but that might change
    """
    assert n >= 0
    assert max_factor > 2
    # Small enough: a single multiply step suffices.
    if n <= max_factor:
        return [(n, 0)]

    # Try the largest factor first so the recursion shrinks n quickly.
    for factor in range(max_factor, 1, -1):
        quotient, remainder = divmod(n, factor)
        if factor + remainder <= max_factor:
            return small_factors(quotient, max_factor) + [(factor, remainder)]
    assert False, "Failed to factorize %s" % n
class OrderedSet(AbstractSet[T]):
    """A minimal OrderedSet implementation, using a dictionary.

    (relies on the dictionary being ordered)
    """

    def __init__(self, items: Iterable[T] = ()):
        # Keys carry the membership + order; values are ignored.
        self.d = dict.fromkeys(items)

    def add(self, item: T):
        self.d[item] = None

    def remove(self, item: T):
        # Raises KeyError if absent, like set.remove.
        del self.d[item]

    def __contains__(self, item: Any) -> bool:
        return item in self.d

    def __iter__(self) -> Iterator[T]:
        return iter(self.d)

    def __len__(self) -> int:
        return len(self.d)

    def __bool__(self):
        return bool(self.d)

    def __repr__(self):
        members = ', '.join(repr(x) for x in self)
        return f"{type(self).__name__}({members})"