Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/lark/utils.py: 61%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import unicodedata
2import os
3from itertools import product
4from collections import deque
5from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable, AbstractSet
7###{standalone
8import sys, re
9import logging
10from dataclasses import dataclass
11from typing import Generic, AnyStr
# Logger registered under the name "lark"; a stream handler is attached at import time.
logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
# Set to highest level, since we have some warnings amongst the code
# By default, we should not output any log messages
logger.setLevel(logging.CRITICAL)


# A unique object; presumably used as a "no value given" marker where None is
# itself a meaningful value -- confirm against the callers.
NO_VALUE = object()

# Generic type variable used by the annotations in this module.
T = TypeVar("T")
def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    """Group the items of ``seq`` into lists.

    Args:
        seq: The items to group.
        key: Computes the group an item belongs to. When omitted, the item
            itself is used as the group key.
        value: Computes what is stored for an item. When omitted, the item
            itself is stored.

    Returns:
        A dict mapping each group key to the list of stored values, in the
        order the items were encountered.
    """
    groups: Dict[Any, Any] = {}
    for item in seq:
        group_key = item if key is None else key(item)
        member = item if value is None else value(item)
        groups.setdefault(group_key, []).append(member)
    return groups
37def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
38 if isinstance(data, dict):
39 if '__type__' in data: # Object
40 class_ = namespace[data['__type__']]
41 return class_.deserialize(data, memo)
42 elif '@' in data:
43 return memo[data['@']]
44 return {key:_deserialize(value, namespace, memo) for key, value in data.items()}
45 elif isinstance(data, list):
46 return [_deserialize(value, namespace, memo) for value in data]
47 return data
_T = TypeVar("_T", bound="Serialize")

class Serialize:
    """Safe-ish serialization interface that doesn't rely on Pickle

    Attributes:
        __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
        __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
                                        Should include all field types that aren't builtin types.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        """Serialize ``self`` with memoization of the given types.

        Returns:
            A pair ``(serialized_self, serialized_memo)``.
        """
        memoizer = SerializeMemoizer(types_to_memoize)
        # Serialize the object first: that is what fills the memoizer.
        serialized = self.serialize(memoizer)
        return serialized, memoizer.serialize()

    def serialize(self, memo = None) -> Dict[str, Any]:
        """Return a dict holding the serialized fields plus a ``'__type__'`` tag.

        If ``memo`` is given and reports this object as memoizable, a
        ``{'@': id}`` reference is returned instead.
        """
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        res = {}
        for field_name in getattr(self, '__serialize_fields__'):
            res[field_name] = _serialize(getattr(self, field_name), memo)
        res['__type__'] = type(self).__name__
        # Optional hook that lets a subclass adjust the result in place.
        if hasattr(self, '_serialize'):
            self._serialize(res, memo)
        return res

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        """Build an instance of ``cls`` from ``data`` without calling ``__init__``.

        Raises:
            KeyError: If restoring one of the serialized fields raises a
                KeyError (for example, the field is missing from ``data``).
        """
        allowed = getattr(cls, '__serialize_namespace__', [])
        namespace = {klass.__name__: klass for klass in allowed}

        fields = getattr(cls, '__serialize_fields__')

        if '@' in data:
            return memo[data['@']]

        inst = cls.__new__(cls)
        for field_name in fields:
            try:
                setattr(inst, field_name, _deserialize(data[field_name], namespace, memo))
            except KeyError as e:
                raise KeyError("Cannot find key for class", cls, e)

        # Optional hook that lets a subclass finish its own setup.
        if hasattr(inst, '_deserialize'):
            inst._deserialize()

        return inst
class SerializeMemoizer(Serialize):
    "A version of serialize that memoizes objects to reduce space"

    __serialize_fields__ = ('memoized',)

    def __init__(self, types_to_memoize: List) -> None:
        # Stored as a tuple so it can be handed straight to isinstance().
        self.types_to_memoize = tuple(types_to_memoize)
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        "Tell whether ``value`` is an instance of one of the memoized types"
        memoizable_types = self.types_to_memoize
        return isinstance(value, memoizable_types)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        "Serialize the memoized objects, keyed by the id the enumerator gave them"
        by_id = self.memoized.reversed()
        return _serialize(by_id, None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        "Rebuild the memo table by delegating to the module-level ``_deserialize``"
        return _deserialize(data, namespace, memo)
try:
    import regex
    _has_regex = True
except ImportError:
    _has_regex = False

if sys.version_info >= (3, 11):
    import re._parser as sre_parse
    import re._constants as sre_constants
else:
    import sre_parse
    import sre_constants

categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')

def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    """Return the minimal and maximal length of a match of ``expr``.

    Args:
        expr: The regular expression source.

    Returns:
        ``[min, max]`` as computed by ``sre_parse``. When ``sre_parse`` rejects
        the pattern but the ``regex`` module is available, a ``(min, max)``
        tuple where ``min`` is 0 or 1 (whether the empty string matches) and
        ``max`` is the parser's maximal width.

    Raises:
        ImportError: If ``expr`` uses Unicode categories (``\\p{..}``) and the
            ``regex`` module is not installed.
        ValueError: If ``sre_parse`` rejects ``expr`` and the ``regex`` module
            is not installed.
    """
    if not _has_regex:
        if categ_pattern.search(expr):
            raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
        regexp_final = expr
    else:
        # Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with
        # a simple letter, which makes no difference as we are only trying to get the possible lengths of
        # the regex match below.
        regexp_final = categ_pattern.sub('A', expr)

    try:
        # Fixed in next version (past 0.960) of typeshed
        return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]
    except sre_constants.error:
        if not _has_regex:
            raise ValueError(expr)
        # sre_parse does not support the new features in regex. To not completely fail in that case,
        # we manually test for the most important info (whether the empty string is matched)
        compiled = regex.compile(regexp_final)
        # Python 3.11.7 introduced sre_parse.MAXWIDTH that is used instead of MAXREPEAT
        # See lark-parser/lark#1376 and python/cpython#109859
        MAXWIDTH = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT)
        min_width = 1 if compiled.match('') is None else 0
        # MAXREPEAT is a non-picklable subclass of int, therefore needs to be converted to enable caching
        return min_width, int(MAXWIDTH)
@dataclass(frozen=True)
class TextSlice(Generic[AnyStr]):
    """A view of a string or bytes object, between the start and end indices.

    Never creates a copy.

    Lark accepts instances of TextSlice as input (instead of a string),
    when the lexer is 'basic' or 'contextual'.

    Args:
        text (str or bytes): The text to slice.
        start (int): The start index. Negative indices are supported.
        end (int): The end index. Negative indices are supported, and None
            stands for the end of `text`.

    Raises:
        TypeError: If `text` is not a `str` or `bytes`.
        AssertionError: If a negative `start` or `end` reaches before the
            beginning of `text`.

    Examples:
        >>> TextSlice("Hello, World!", 7, -1)
        TextSlice(text='Hello, World!', start=7, end=12)

        >>> TextSlice("Hello, World!", 7, None).count("o")
        1

    """
    text: AnyStr
    start: int
    end: int

    def __post_init__(self):
        if not isinstance(self.text, (str, bytes)):
            raise TypeError("text must be str or bytes")

        # The dataclass is frozen, so normalized indices have to be written
        # through object.__setattr__.
        if self.start < 0:
            object.__setattr__(self, 'start', self.start + len(self.text))
            assert self.start >= 0

        if self.end is None:
            object.__setattr__(self, 'end', len(self.text))
        elif self.end < 0:
            object.__setattr__(self, 'end', self.end + len(self.text))
            # A negative index that underflows would otherwise stay negative
            # and go unnoticed; check the lower bound just like for `start`.
            assert 0 <= self.end <= len(self.text)

    @classmethod
    def cast_from(cls, text: 'TextOrSlice') -> 'TextSlice[AnyStr]':
        """Return `text` unchanged if it is a TextSlice, else wrap it whole."""
        if isinstance(text, TextSlice):
            return text

        return cls(text, 0, len(text))

    def is_complete_text(self):
        """Tell whether the slice spans all of `text`."""
        return self.start == 0 and self.end == len(self.text)

    def __len__(self):
        return self.end - self.start

    def count(self, substr: AnyStr):
        """Count occurrences of `substr` between `start` and `end`."""
        return self.text.count(substr, self.start, self.end)

    def rindex(self, substr: AnyStr):
        """Return the highest index in `text` of `substr` between `start` and `end`.

        The index is relative to `text`, not to the slice. Raises ValueError
        (from the underlying `rindex`) if `substr` is not found.
        """
        return self.text.rindex(substr, self.start, self.end)


TextOrSlice = Union[AnyStr, 'TextSlice[AnyStr]']
LarkInput = Union[AnyStr, TextSlice[AnyStr], Any]
231###}
# Unicode general categories accepted by `is_id_start` / `is_id_continue`.
_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)

def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
    """Check that every character of `s` is `_` or belongs to one of `categories`.

    An empty `s` yields True, since there is no character to reject.
    """
    return all(
        char == '_' or unicodedata.category(char) in categories
        for char in s
    )

def is_id_continue(s: str) -> bool:
    """
    Checks if all characters in `s` are alphanumeric characters (Unicode standard, so diacritics, indian vowels, non-latin
    numbers, etc. all pass). Synonymous with a Python `ID_CONTINUE` identifier. See PEP 3131 for details.
    """
    allowed_categories = _ID_CONTINUE
    return _test_unicode_category(s, allowed_categories)

def is_id_start(s: str) -> bool:
    """
    Checks if all characters in `s` are alphabetic characters (Unicode standard, so diacritics, indian vowels, etc. all
    pass; digits do not). Synonymous with a Python `ID_START` identifier. See PEP 3131 for details.
    """
    allowed_categories = _ID_START
    return _test_unicode_category(s, allowed_categories)
def dedup_list(l: Iterable[T]) -> List[T]:
    """Return the items of `l` as a list with duplicates removed.

    The first occurrence of each item is kept, so the original order is
    preserved. Assumes that the entries are hashable.
    """
    # Dict keys are unique and keep insertion order.
    return [*dict.fromkeys(l)]
class Enumerator(Serialize):
    """Hands out consecutive integer ids (starting at 0) to the items it is asked about."""

    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        """Return the id of `item`, assigning the next free id on first sight."""
        return self.enums.setdefault(item, len(self.enums))

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        """Return the inverse mapping, from id to item."""
        by_id = {}
        for item, index in self.enums.items():
            by_id[index] = item
        # Every item must have received its own id.
        assert len(by_id) == len(self.enums)
        return by_id
def combine_alternatives(lists):
    """
    Accepts a list of alternatives, and enumerates all their possible concatenations.

    Each concatenation is a tuple, except for empty input, where the single
    (empty) concatenation is an empty list.

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]
    """
    if lists:
        # An empty alternative would make the whole product empty.
        assert all(lists), lists
        return list(product(*lists))
    return [[]]
302try:
303 import atomicwrites
304 _has_atomicwrites = True
305except ImportError:
306 _has_atomicwrites = False
class FS:
    """Small file-system facade: `exists` and `open`."""

    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        """Open `name`; write modes go through `atomicwrites` when it is installed."""
        write_atomically = _has_atomicwrites and "w" in mode
        if not write_atomically:
            return open(name, mode, **kwargs)
        return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)
class fzset(frozenset):
    """A frozenset whose repr uses plain braces, e.g. ``{1, 2}``."""

    def __repr__(self):
        members = ', '.join(repr(member) for member in self)
        return '{%s}' % members
def classify_bool(seq: Iterable, pred: Callable) -> Any:
    """Split `seq` in two according to `pred`.

    Returns:
        A pair ``(true_elems, false_elems)`` of lists: the elements for which
        `pred` returned a truthy value, and the remaining ones. Both keep the
        order of `seq`.
    """
    true_elems = []
    false_elems = []
    for elem in seq:
        bucket = true_elems if pred(elem) else false_elems
        bucket.append(elem)
    return true_elems, false_elems
def bfs(initial: Iterable, expand: Callable) -> Iterator:
    """Yield nodes in breadth-first order, starting from `initial`.

    Args:
        initial: The starting nodes (must be hashable).
        expand: Called with a node; returns an iterable of its successors.

    A successor is queued only the first time it is seen.
    """
    frontier = deque(initial)
    seen = set(frontier)
    while frontier:
        current = frontier.popleft()
        yield current
        for successor in expand(current):
            if successor in seen:
                continue
            seen.add(successor)
            frontier.append(successor)
def bfs_all_unique(initial, expand):
    "bfs, but doesn't keep track of visited (aka seen), because there can be no repetitions"
    frontier = deque(initial)
    while frontier:
        current = frontier.popleft()
        yield current
        frontier.extend(expand(current))
def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    """Recursively convert `value` into its serialized form.

    `Serialize` instances serialize themselves; lists and dicts are converted
    element by element; a frozenset becomes a list of its (unconverted)
    members; anything else is returned as is.
    """
    if isinstance(value, Serialize):
        return value.serialize(memo)
    if isinstance(value, list):
        return [_serialize(item, memo) for item in value]
    if isinstance(value, frozenset):
        return list(value)  # TODO reversible?
    if isinstance(value, dict):
        return {k: _serialize(v, memo) for k, v in value.items()}
    # assert value is None or isinstance(value, (int, float, str, tuple)), value
    return value
def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """
    Splits n up into smaller factors and summands <= max_factor.
    Returns a list of [(a, b), ...]
    so that the following code returns n:

    n = 1
    for a, b in values:
        n = n * a + b

    Currently, we also keep a + b <= max_factor, but that might change
    """
    assert n >= 0
    assert max_factor > 2

    # Peel off (factor, summand) pairs from the end of the result until the
    # remaining quotient is small enough to stand on its own.
    trailing: List[Tuple[int, int]] = []
    remaining = n
    while remaining > max_factor:
        for a in range(max_factor, 1, -1):
            quotient, b = divmod(remaining, a)
            if a + b <= max_factor:
                trailing.append((a, b))
                remaining = quotient
                break
        else:
            assert False, "Failed to factorize %s" % remaining

    trailing.append((remaining, 0))
    trailing.reverse()
    return trailing
class OrderedSet(AbstractSet[T]):
    """A minimal OrderedSet implementation, using a dictionary.

    (relies on the dictionary being ordered)
    """
    def __init__(self, items: Iterable[T] =()):
        # Only the keys matter; every value is None.
        self.d = dict.fromkeys(items)

    def __contains__(self, item: Any) -> bool:
        return item in self.d

    def add(self, item: T):
        """Add `item`; an item already present keeps its position."""
        self.d[item] = None

    def __iter__(self) -> Iterator[T]:
        return iter(self.d)

    def remove(self, item: T):
        """Remove `item`, raising KeyError if it is not present."""
        del self.d[item]

    def __bool__(self):
        return bool(self.d)

    def __len__(self) -> int:
        return len(self.d)

    def __repr__(self):
        members = ', '.join(repr(member) for member in self)
        return f"{type(self).__name__}({members})"