Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/lark/utils.py: 61%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import unicodedata
2import os
3from itertools import product
4from collections import deque
5from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable, AbstractSet
7###{standalone
8import sys, re
9import logging
10from dataclasses import dataclass
11from typing import Generic, AnyStr
# Package-wide logger. A stream handler is attached, but the level is pinned
# to CRITICAL: the code base emits some warnings, and by default no log
# message should be printed.
logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.CRITICAL)


# Unique sentinel object.
# NOTE(review): presumably marks "no value supplied" where None is a legitimate value - confirm against callers.
NO_VALUE = object()

# Generic type variable shared by the helpers of this module.
T = TypeVar("T")
def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    """Group the items of ``seq`` into a dict of lists.

    Args:
        seq: The items to group.
        key: Computes the group of an item; the item itself when omitted.
        value: Computes what is stored for an item; the item itself when omitted.

    Returns:
        A dict mapping every group key to the list of stored values, in the
        order the items were encountered.
    """
    groups: Dict[Any, Any] = {}
    for element in seq:
        group_key = element if key is None else key(element)
        stored = element if value is None else value(element)
        groups.setdefault(group_key, []).append(stored)
    return groups
def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
    """Recursively rebuild a value from its serialized form.

    Args:
        data: The serialized value (dict, list, or anything else).
        namespace: Maps a ``'__type__'`` name to the class whose
            ``deserialize(data, memo)`` classmethod rebuilds the object.
        memo: Maps the key stored under ``'@'`` to an already-built object.

    Returns:
        The rebuilt value. Values that are neither dict nor list are returned
        unchanged.
    """
    if isinstance(data, list):
        return [_deserialize(item, namespace, memo) for item in data]

    if not isinstance(data, dict):
        return data

    if '__type__' in data:  # Serialized object
        target_cls = namespace[data['__type__']]
        return target_cls.deserialize(data, memo)

    if '@' in data:  # Reference into the memo
        return memo[data['@']]

    return {name: _deserialize(item, namespace, memo) for name, item in data.items()}
_T = TypeVar("_T", bound="Serialize")


class Serialize:
    """Safe-ish serialization interface that doesn't rely on Pickle

    Attributes:
        __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
        __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
                                        Should include all field types that aren't builtin types.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        """Serialize this object with memoization of the given types.

        Returns:
            A pair ``(serialized object, serialized memo)``.
        """
        memoizer = SerializeMemoizer(types_to_memoize)
        payload = self.serialize(memoizer)
        return payload, memoizer.serialize()

    def serialize(self, memo = None) -> Dict[str, Any]:
        """Return a dict with the serialized fields plus a ``'__type__'`` entry.

        When ``memo`` is given and this object is one of its memoized types,
        only a ``{'@': index}`` reference is returned. An optional
        ``_serialize(res, memo)`` hook on the instance is called last.
        """
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        res = {}
        for field_name in self.__serialize_fields__:
            res[field_name] = _serialize(getattr(self, field_name), memo)
        res['__type__'] = type(self).__name__

        if hasattr(self, '_serialize'):
            self._serialize(res, memo)
        return res

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        """Rebuild an instance of ``cls`` from ``data``.

        The instance is created without calling ``__init__``. An optional
        ``_deserialize()`` hook on the instance is called once all fields
        are set.

        Raises:
            KeyError: If a field is missing (or a nested lookup fails).
        """
        allowed_classes = getattr(cls, '__serialize_namespace__', [])
        namespace = {klass.__name__: klass for klass in allowed_classes}

        field_names = getattr(cls, '__serialize_fields__')

        if '@' in data:
            return memo[data['@']]

        inst = cls.__new__(cls)
        for field_name in field_names:
            try:
                field_value = _deserialize(data[field_name], namespace, memo)
                setattr(inst, field_name, field_value)
            except KeyError as e:
                raise KeyError("Cannot find key for class", cls, e)

        if hasattr(inst, '_deserialize'):
            inst._deserialize()

        return inst
class SerializeMemoizer(Serialize):
    "A version of serialize that memoizes objects to reduce space"

    __serialize_fields__ = ('memoized',)

    def __init__(self, types_to_memoize: List) -> None:
        # Kept as a tuple so it can be handed directly to isinstance().
        self.types_to_memoize = tuple(types_to_memoize)
        # Assigns a running index to every memoized object.
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        "Tell whether `value` is an instance of one of the memoized types"
        return isinstance(value, self.types_to_memoize)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        "Serialize the memo as a mapping of index -> serialized object"
        index_to_object = self.memoized.reversed()
        return _serialize(index_to_object, None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        "Rebuild a serialized memo; delegates to the module-level `_deserialize`"
        return _deserialize(data, namespace, memo)
119try:
120 import regex
121 _has_regex = True
122except ImportError:
123 _has_regex = False
125if sys.version_info >= (3, 11):
126 import re._parser as sre_parse
127 import re._constants as sre_constants
128else:
129 import sre_parse
130 import sre_constants
# Matches Unicode category escapes of the form `\p{Mn}`.
categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')


def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    """Return the minimum and maximum match length of the regexp ``expr``.

    Returns:
        ``[min, max]`` as computed by ``sre_parse``, or, when ``sre_parse``
        cannot parse the expression but the ``regex`` module is available,
        a ``(min, max)`` tuple where ``min`` is 0 or 1 (depending on whether
        the empty string matches) and ``max`` is the largest possible width.

    Raises:
        ImportError: If ``expr`` uses Unicode categories and ``regex`` is
            not installed.
        ValueError: If ``expr`` cannot be parsed and ``regex`` is not installed.
    """
    if not _has_regex:
        if re.search(categ_pattern, expr):
            raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
        pattern = expr
    else:
        # Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these
        # with a simple letter, which makes no difference as we are only trying to get the possible
        # lengths of the regex match here below.
        pattern = categ_pattern.sub('A', expr)

    try:
        # Fixed in next version (past 0.960) of typeshed
        widths = sre_parse.parse(pattern).getwidth()
    except sre_constants.error:
        if not _has_regex:
            raise ValueError(expr)

        # sre_parse does not support the new features in regex. To not completely fail in that case,
        # we manually test for the most important info (whether the empty string is matched)
        compiled = regex.compile(pattern)
        # Python 3.11.7 introduced sre_parse.MAXWIDTH that is used instead of MAXREPEAT
        # See lark-parser/lark#1376 and python/cpython#109859
        max_width = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT)
        min_width = 1 if compiled.match('') is None else 0
        # MAXREPEAT is a non-picklable subclass of int, therefore needs to be converted to enable caching
        return min_width, int(max_width)

    return [int(width) for width in widths]
@dataclass(frozen=True)
class TextSlice(Generic[AnyStr]):
    """A view of a string or bytes object, between the start and end indices.

    Never creates a copy.

    Lark accepts instances of TextSlice as input (instead of a string),
    when the lexer is 'basic' or 'contextual'.

    Args:
        text (str or bytes): The text to slice.
        start (int): The start index. Negative indices are supported.
        end (int): The end index. Negative indices are supported.
            ``None`` means "up to the end of the text".

    Raises:
        TypeError: If `text` is not a `str` or `bytes`.
        AssertionError: If `start` or `end` are out of bounds.

    Examples:
        >>> TextSlice("Hello, World!", 7, -1)
        TextSlice(text='Hello, World!', start=7, end=12)

        >>> TextSlice("Hello, World!", 7, None).count("o")
        1

    """
    text: AnyStr
    start: int
    end: int

    def __post_init__(self):
        if not isinstance(self.text, (str, bytes)):
            raise TypeError("text must be str or bytes")

        size = len(self.text)

        # Normalize negative indices the way Python slicing does.
        start = self.start
        if start < 0:
            start += size

        end = self.end
        if end is None:
            end = size
        elif end < 0:
            end += size

        # Raised explicitly (rather than via `assert`) so that the documented
        # bounds check still runs under `python -O`. Both ends of the range are
        # checked for both indices.
        if not 0 <= start <= size:
            raise AssertionError("start index %r is out of bounds for text of length %d" % (self.start, size))
        if not 0 <= end <= size:
            raise AssertionError("end index %r is out of bounds for text of length %d" % (self.end, size))

        # The dataclass is frozen, so normalized values must bypass __setattr__.
        object.__setattr__(self, 'start', start)
        object.__setattr__(self, 'end', end)

    @classmethod
    def cast_from(cls, text: 'TextOrSlice') -> 'TextSlice[AnyStr]':
        """Return `text` itself if it is already a TextSlice, else a slice covering all of it."""
        if isinstance(text, TextSlice):
            return text

        return cls(text, 0, len(text))

    def is_complete_text(self):
        """Tell whether the slice spans the whole underlying text."""
        return self.start == 0 and self.end == len(self.text)

    def __len__(self):
        return self.end - self.start

    def count(self, substr: AnyStr):
        """Count the occurrences of `substr` within the slice."""
        return self.text.count(substr, self.start, self.end)

    def rindex(self, substr: AnyStr):
        """Return the highest index (in the underlying text) of `substr` within the slice.

        Raises ValueError when `substr` is not found, like `str.rindex`.
        """
        return self.text.rindex(substr, self.start, self.end)


TextOrSlice = Union[AnyStr, 'TextSlice[AnyStr]']
230###}
# Unicode general categories accepted by `is_id_start`.
_ID_START = ('Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc')
# `is_id_continue` additionally accepts decimal and letter numbers.
_ID_CONTINUE = (*_ID_START, 'Nd', 'Nl')
def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
    """Tell whether every character of ``s`` is ``'_'`` or belongs to one of ``categories``.

    ``categories`` holds Unicode general category codes as returned by
    ``unicodedata.category``. The empty string yields ``True``.
    """
    if len(s) == 1:
        return s == '_' or unicodedata.category(s) in categories
    # Zero or several characters: each one must pass on its own.
    return all(_test_unicode_category(single, categories) for single in s)
def is_id_continue(s: str) -> bool:
    """
    Checks if all characters in `s` are underscores or fall in one of the Unicode categories listed in
    `_ID_CONTINUE` (letters, marks, connector punctuation and numbers - so diacritics, indian vowels,
    non-latin numbers, etc. all pass). Meant to mirror Python's `ID_CONTINUE` identifier characters.
    See PEP 3131 for details.
    """
    return _test_unicode_category(s, categories=_ID_CONTINUE)
def is_id_start(s: str) -> bool:
    """
    Checks if all characters in `s` are underscores or fall in one of the Unicode categories listed in
    `_ID_START` (letters, marks and connector punctuation - so diacritics, indian vowels, etc. all pass;
    digits do not). Meant to mirror Python's `ID_START` identifier characters. See PEP 3131 for details.
    """
    return _test_unicode_category(s, categories=_ID_START)
def dedup_list(l: Iterable[T]) -> List[T]:
    """Return the entries of `l` without duplicates, keeping the
    first occurrence of each entry in its original position.
    Assumes that the entries are hashable."""
    seen = set()
    unique = []
    for entry in l:
        if entry in seen:
            continue
        seen.add(entry)
        unique.append(entry)
    return unique
class Enumerator(Serialize):
    "Assigns consecutive integer indices (starting at 0) to hashable items"

    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        "Return the index of `item`, assigning the next free index on first sight"
        # len() is evaluated before any insertion, so a new item gets the next index.
        return self.enums.setdefault(item, len(self.enums))

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        "Return the inverse mapping, index -> item"
        index_to_item = {index: item for item, index in self.enums.items()}
        assert len(index_to_item) == len(self.enums)
        return index_to_item
def combine_alternatives(lists):
    """
    Accepts a list of alternatives, and enumerates all their possible concatenations.

    Each concatenation is a tuple holding one item from every alternative, in order
    (this is `itertools.product`). The one exception is an empty `lists`, for which
    the result is `[[]]`: a single, empty concatenation.

    Every alternative must be non-empty; otherwise an AssertionError is raised.

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]
    """
    if not lists:
        return [[]]
    # An empty alternative would silently collapse the whole product to nothing.
    assert all(l for l in lists), lists
    return list(product(*lists))
301try:
302 import atomicwrites
303 _has_atomicwrites = True
304except ImportError:
305 _has_atomicwrites = False
class FS:
    "Small file-system facade: `exists` and `open`"

    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        "Open `name`; writes go through `atomicwrites` when that package is installed"
        write_atomically = _has_atomicwrites and "w" in mode
        if not write_atomically:
            return open(name, mode, **kwargs)
        return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)
class fzset(frozenset):
    "A frozenset whose repr is the compact `{a, b, c}` form"

    def __repr__(self):
        members = ', '.join(repr(member) for member in self)
        return '{' + members + '}'
def classify_bool(seq: Iterable, pred: Callable) -> Tuple[List, List]:
    """Split ``seq`` in two lists according to a predicate.

    Args:
        seq: The items to partition.
        pred: Called once per item; a truthy result selects the first list.

    Returns:
        ``(true_elems, false_elems)``, each preserving the order of ``seq``.
    """
    true_elems: List = []
    false_elems: List = []
    for elem in seq:
        # An explicit loop instead of a side effect hidden in a comprehension.
        bucket = true_elems if pred(elem) else false_elems
        bucket.append(elem)
    return true_elems, false_elems
def bfs(initial: Iterable, expand: Callable) -> Iterator:
    """Yield nodes in breadth-first order, starting from ``initial``.

    ``expand(node)`` returns the neighbours of a node. Nodes must be
    hashable; a node reached through ``expand`` is yielded at most once.
    """
    frontier = deque(initial)
    seen = set(frontier)
    while frontier:
        current = frontier.popleft()
        yield current
        for neighbour in expand(current):
            if neighbour in seen:
                continue
            seen.add(neighbour)
            frontier.append(neighbour)
def bfs_all_unique(initial, expand):
    "bfs, but doesn't keep track of visited (aka seen), because there can be no repetitions"
    frontier = deque(initial)
    while frontier:
        current = frontier.popleft()
        yield current
        frontier.extend(expand(current))
def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    """Recursively convert ``value`` to its serialized form.

    ``Serialize`` instances serialize themselves; lists and dict values are
    converted element by element; a frozenset becomes a plain list of its
    (unconverted) members. Anything else is returned unchanged.
    """
    if isinstance(value, Serialize):
        return value.serialize(memo)
    if isinstance(value, list):
        return [_serialize(item, memo) for item in value]
    if isinstance(value, frozenset):
        return list(value)  # TODO reversible?
    if isinstance(value, dict):
        return {name: _serialize(item, memo) for name, item in value.items()}
    return value
def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """
    Splits n up into smaller factors and summands <= max_factor.
    Returns a list of [(a, b), ...]
    so that the following code returns n:

    n = 1
    for a, b in values:
        n = n * a + b

    Currently, we also keep a + b <= max_factor, but that might change
    """
    assert n >= 0
    assert max_factor > 2
    if n <= max_factor:
        return [(n, 0)]

    # Try the largest multiplier first; the quotient is then factorized in turn.
    for multiplier in reversed(range(2, max_factor + 1)):
        quotient, summand = divmod(n, multiplier)
        if multiplier + summand > max_factor:
            continue
        return [*small_factors(quotient, max_factor), (multiplier, summand)]
    assert False, "Failed to factorize %s" % n
class OrderedSet(AbstractSet[T]):
    """A minimal OrderedSet implementation, using a dictionary.

    (relies on the dictionary being ordered)
    """
    def __init__(self, items: Iterable[T] =()):
        # Only the keys matter; every value is None.
        self.d = {item: None for item in items}

    def __contains__(self, item: Any) -> bool:
        return item in self.d

    def add(self, item: T):
        "Insert `item`; an item already present keeps its position"
        self.d[item] = None

    def __iter__(self) -> Iterator[T]:
        return iter(self.d)

    def remove(self, item: T):
        "Delete `item`; raises KeyError when it is not present"
        del self.d[item]

    def __bool__(self):
        return bool(self.d)

    def __len__(self) -> int:
        return len(self.d)

    def __repr__(self):
        members = ', '.join(repr(member) for member in self)
        return f"{type(self).__name__}({members})"