Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/lark/utils.py: 61%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

237 statements  

1import unicodedata 

2import os 

3from itertools import product 

4from collections import deque 

5from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable, AbstractSet 

6 

7###{standalone 

8import sys, re 

9import logging 

10from dataclasses import dataclass 

11from typing import Generic, AnyStr 

12 

# Library-wide logger; a StreamHandler is attached so warnings are visible
# even when the embedding application configures no logging of its own.
logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
# Set to highest level, since we have some warnings amongst the code
# By default, we should not output any log messages
logger.setLevel(logging.CRITICAL)

18 

19 

# Sentinel distinguishing "no value supplied" from an explicit None argument.
NO_VALUE = object()

# Generic element type used by the container helpers in this module.
T = TypeVar("T")

23 

24 

def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    """Group the items of *seq* into a dict of lists.

    Each item is bucketed under ``key(item)`` (or the item itself when *key*
    is None), and the stored entry is ``value(item)`` (or the item itself
    when *value* is None).
    """
    result: Dict[Any, Any] = {}
    for item in seq:
        bucket = item if key is None else key(item)
        entry = item if value is None else value(item)
        result.setdefault(bucket, []).append(entry)
    return result

35 

36 

37def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any: 

38 if isinstance(data, dict): 

39 if '__type__' in data: # Object 

40 class_ = namespace[data['__type__']] 

41 return class_.deserialize(data, memo) 

42 elif '@' in data: 

43 return memo[data['@']] 

44 return {key:_deserialize(value, namespace, memo) for key, value in data.items()} 

45 elif isinstance(data, list): 

46 return [_deserialize(value, namespace, memo) for value in data] 

47 return data 

48 

49 

# Bound type variable so Serialize.deserialize returns the invoking subclass.
_T = TypeVar("_T", bound="Serialize")

51 

class Serialize:
    """Safe-ish serialization interface that doesn't rely on Pickle

    Attributes:
        __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
        __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
            Should include all field types that aren't builtin types.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        """Serialize self, memoizing the given types; returns (data, memo-data)."""
        memo = SerializeMemoizer(types_to_memoize)
        return self.serialize(memo), memo.serialize()

    def serialize(self, memo=None) -> Dict[str, Any]:
        """Serialize declared fields into a plain dict, tagged with '__type__'."""
        # Memoized objects are replaced with a reference entry.
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        result: Dict[str, Any] = {}
        for field_name in getattr(self, '__serialize_fields__'):
            result[field_name] = _serialize(getattr(self, field_name), memo)
        result['__type__'] = type(self).__name__
        # Subclasses may post-process the dict via an optional _serialize hook.
        if hasattr(self, '_serialize'):
            self._serialize(result, memo)
        return result

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        """Rebuild an instance from the dict produced by serialize()."""
        namespace = {klass.__name__: klass for klass in getattr(cls, '__serialize_namespace__', [])}

        field_names = getattr(cls, '__serialize_fields__')

        # A reference entry points at an already-deserialized, memoized object.
        if '@' in data:
            return memo[data['@']]

        # Bypass __init__; fields are restored directly.
        inst = cls.__new__(cls)
        for field_name in field_names:
            try:
                setattr(inst, field_name, _deserialize(data[field_name], namespace, memo))
            except KeyError as e:
                raise KeyError("Cannot find key for class", cls, e)

        # Subclasses may finish reconstruction via an optional _deserialize hook.
        if hasattr(inst, '_deserialize'):
            inst._deserialize()

        return inst

97 

98 

class SerializeMemoizer(Serialize):
    """Serialize variant that replaces repeated objects of chosen types with
    numeric references, reducing the size of the serialized output."""

    __serialize_fields__ = 'memoized',

    def __init__(self, types_to_memoize: List) -> None:
        self.types_to_memoize = tuple(types_to_memoize)
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        # Only instances of the registered types get memoized.
        return isinstance(value, self.types_to_memoize)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        # Emit the id -> object table (inverse of the enumerator's mapping).
        return _serialize(self.memoized.reversed(), None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        return _deserialize(data, namespace, memo)

117 

118 

# The third-party `regex` module is optional; below it is only required for
# patterns that use Unicode category escapes such as \p{Lu}.
try:
    import regex
    _has_regex = True
except ImportError:
    _has_regex = False

124 

# Python 3.11 moved the undocumented sre modules under the `re` package;
# import from the correct location for the running interpreter.
if sys.version_info >= (3, 11):
    import re._parser as sre_parse
    import re._constants as sre_constants
else:
    import sre_parse
    import sre_constants

131 

# Matches Unicode category escapes of the form \p{...} (e.g. \p{Lu}),
# which the stdlib sre parser cannot handle.
categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')

133 

def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    """Return the (minimum, maximum) possible match width of the regex *expr*.

    Uses the stdlib sre parser when possible, falling back to the third-party
    `regex` module for patterns sre cannot parse.

    Raises:
        ImportError: if *expr* uses Unicode categories (``\\p{...}``) and the
            `regex` module is not installed.
        ValueError: if *expr* cannot be parsed and `regex` is unavailable.
    """
    if _has_regex:
        # Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with
        # a simple letter, which makes no difference as we are only trying to get the possible lengths of the regex
        # match here below.
        regexp_final = re.sub(categ_pattern, 'A', expr)
    else:
        if re.search(categ_pattern, expr):
            raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
        regexp_final = expr
    try:
        # Fixed in next version (past 0.960) of typeshed
        return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]
    except sre_constants.error:
        if not _has_regex:
            raise ValueError(expr)
        else:
            # sre_parse does not support the new features in regex. To not completely fail in that case,
            # we manually test for the most important info (whether the empty string is matched)
            c = regex.compile(regexp_final)
            # Python 3.11.7 introduced sre_parse.MAXWIDTH that is used instead of MAXREPEAT
            # See lark-parser/lark#1376 and python/cpython#109859
            MAXWIDTH = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT)
            if c.match('') is None:
                # MAXREPEAT is a non-picklable subclass of int, therefore needs to be converted to enable caching
                return 1, int(MAXWIDTH)
            else:
                return 0, int(MAXWIDTH)

162 

163 

@dataclass(frozen=True)
class TextSlice(Generic[AnyStr]):
    """A zero-copy view of a string or bytes object between two indices.

    Lark accepts instances of TextSlice as input (instead of a string),
    when the lexer is 'basic' or 'contextual'.

    Args:
        text (str or bytes): The text to slice.
        start (int): The start index. Negative indices are supported.
        end (int): The end index. Negative indices are supported;
            None means the end of the text.

    Raises:
        TypeError: If `text` is not a `str` or `bytes`.
        AssertionError: If `start` or `end` are out of bounds.

    Examples:
        >>> TextSlice("Hello, World!", 7, -1)
        TextSlice(text='Hello, World!', start=7, end=12)

        >>> TextSlice("Hello, World!", 7, None).count("o")
        1

    """
    text: AnyStr
    start: int
    end: int

    def __post_init__(self):
        if not isinstance(self.text, (str, bytes)):
            raise TypeError("text must be str or bytes")

        size = len(self.text)
        # Normalize negative indices in place; since the dataclass is frozen,
        # mutation must go through object.__setattr__.
        if self.start < 0:
            object.__setattr__(self, 'start', self.start + size)
        assert self.start >= 0

        if self.end is None:
            object.__setattr__(self, 'end', size)
        elif self.end < 0:
            object.__setattr__(self, 'end', self.end + size)
        assert self.end <= size

    @classmethod
    def cast_from(cls, text: 'TextOrSlice') -> 'TextSlice[AnyStr]':
        """Wrap *text* in a full-span TextSlice; no-op if it already is one."""
        if isinstance(text, TextSlice):
            return text
        return cls(text, 0, len(text))

    def is_complete_text(self):
        """True when this slice covers the whole underlying text."""
        return self.start == 0 and self.end == len(self.text)

    def __len__(self):
        return self.end - self.start

    def count(self, substr: AnyStr):
        """Count occurrences of *substr* inside the slice bounds."""
        return self.text.count(substr, self.start, self.end)

    def rindex(self, substr: AnyStr):
        """Highest index of *substr* inside the slice bounds (ValueError if absent)."""
        return self.text.rindex(substr, self.start, self.end)

226 

227 

# Either a plain str/bytes or a TextSlice view over one.
TextOrSlice = Union[AnyStr, 'TextSlice[AnyStr]']
# NOTE(review): a Union that contains Any is equivalent to Any for
# type-checkers — confirm this alias is intentional.
LarkInput = Union[AnyStr, TextSlice[AnyStr], Any]

230 

231###} 

232 

233 

# Unicode general categories permitted at the start of an identifier (PEP 3131).
_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
# Continuation positions additionally allow decimal digits and number letters.
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)

236 

237def _test_unicode_category(s: str, categories: Sequence[str]) -> bool: 

238 if len(s) != 1: 

239 return all(_test_unicode_category(char, categories) for char in s) 

240 return s == '_' or unicodedata.category(s) in categories 

241 

def is_id_continue(s: str) -> bool:
    """
    True when every character of `s` may appear in the tail of a Python
    identifier (Unicode ``ID_CONTINUE``; see PEP 3131): letters, digits,
    combining marks, underscore, etc. — not just Latin/ASCII.
    """
    return _test_unicode_category(s, _ID_CONTINUE)

248 

def is_id_start(s: str) -> bool:
    """
    True when every character of `s` may appear as the first character of a
    Python identifier (Unicode ``ID_START``; see PEP 3131): letters, marks,
    underscore, etc. — not just Latin/ASCII.
    """
    return _test_unicode_category(s, _ID_START)

255 

256 

def dedup_list(l: Iterable[T]) -> List[T]:
    """Return the unique items of *l* as a list, keeping first-seen order.

    Items must be hashable.
    """
    seen = set()
    unique = []
    for item in l:
        if item not in seen:
            seen.add(item)
            unique.append(item)
    return unique

262 

263 

class Enumerator(Serialize):
    """Assigns a stable, sequential integer id to each distinct item it sees."""

    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        # setdefault hands out the next sequential id on first sight;
        # subsequent calls return the already-assigned id.
        return self.enums.setdefault(item, len(self.enums))

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        """Return the inverse mapping (id -> item); ids are unique by construction."""
        inverse = {index: item for item, index in self.enums.items()}
        assert len(inverse) == len(self.enums)
        return inverse

280 

281 

282 

def combine_alternatives(lists):
    """Given a list of alternatives, enumerate every possible concatenation.

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]
    """
    if not lists:
        return [[]]
    # Every alternative group must be non-empty, otherwise the product is empty.
    assert all(alt for alt in lists), lists
    return list(product(*lists))

301 

# Optional: atomicwrites makes cache-file writes atomic, preventing a
# corrupted file if the process is interrupted mid-write.
try:
    import atomicwrites
    _has_atomicwrites = True
except ImportError:
    _has_atomicwrites = False

307 

class FS:
    """Thin filesystem facade; writes go through `atomicwrites` when available."""

    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        # Atomic writes avoid leaving a half-written file behind on crash.
        if "w" in mode and _has_atomicwrites:
            return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)
        return open(name, mode, **kwargs)

317 

318 

class fzset(frozenset):
    """A frozenset whose repr looks like a set literal, e.g. ``{1, 2}``."""
    def __repr__(self):
        inner = ', '.join(repr(member) for member in self)
        return '{' + inner + '}'

322 

323 

def classify_bool(seq: Iterable, pred: Callable) -> Any:
    """Split *seq* into two lists: (items where pred holds, items where it doesn't)."""
    trues = []
    falses = []
    for elem in seq:
        (trues if pred(elem) else falses).append(elem)
    return trues, falses

328 

329 

330def bfs(initial: Iterable, expand: Callable) -> Iterator: 

331 open_q = deque(list(initial)) 

332 visited = set(open_q) 

333 while open_q: 

334 node = open_q.popleft() 

335 yield node 

336 for next_node in expand(node): 

337 if next_node not in visited: 

338 visited.add(next_node) 

339 open_q.append(next_node) 

340 

def bfs_all_unique(initial, expand):
    """bfs, but without a visited set — valid only when *expand* can never
    produce the same node twice."""
    queue = deque(initial)
    while queue:
        node = queue.popleft()
        yield node
        queue.extend(expand(node))

348 

349 

def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    """Recursively convert *value* into plain builtins for serialization.

    Serialize instances delegate to their own ``serialize``; lists and dicts
    are converted element-wise; frozensets become plain lists; anything else
    passes through unchanged.
    """
    if isinstance(value, Serialize):
        return value.serialize(memo)
    if isinstance(value, list):
        return [_serialize(item, memo) for item in value]
    if isinstance(value, frozenset):
        return list(value)  # TODO reversible?
    if isinstance(value, dict):
        return {k: _serialize(v, memo) for k, v in value.items()}
    return value

361 

362 

363 

364 

def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """Decompose *n* into a chain of factor/summand pairs, each <= max_factor.

    Returns pairs [(a, b), ...] such that folding them reconstructs n:

        n = 1
        for a, b in values:
            n = n * a + b

    Currently each pair also satisfies a + b <= max_factor, but that might change.
    """
    assert n >= 0
    assert max_factor > 2
    # Small enough already: a single multiply-by-n step suffices.
    if n <= max_factor:
        return [(n, 0)]

    # Prefer the largest factor whose remainder keeps a + b within bounds,
    # then recurse on the quotient.
    for factor in range(max_factor, 1, -1):
        quotient, remainder = divmod(n, factor)
        if factor + remainder <= max_factor:
            return small_factors(quotient, max_factor) + [(factor, remainder)]
    assert False, "Failed to factorize %s" % n

387 

388 

class OrderedSet(AbstractSet[T]):
    """A minimal insertion-ordered set, backed by a dict.

    (relies on Python dicts preserving insertion order)
    """
    def __init__(self, items: Iterable[T] =()):
        self.d = dict.fromkeys(items)

    def __contains__(self, item: Any) -> bool:
        return item in self.d

    def add(self, item: T):
        # Re-adding an existing item keeps its original position.
        self.d[item] = None

    def __iter__(self) -> Iterator[T]:
        return iter(self.d)

    def remove(self, item: T):
        # KeyError if absent, matching set.remove semantics.
        del self.d[item]

    def __bool__(self):
        return bool(self.d)

    def __len__(self) -> int:
        return len(self.d)

    def __repr__(self):
        body = ', '.join(repr(item) for item in self)
        return f"{type(self).__name__}({body})"