Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/lark/utils.py: 61%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

236 statements  

1import unicodedata 

2import os 

3from itertools import product 

4from collections import deque 

5from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable, AbstractSet 

6 

7###{standalone 

8import sys, re 

9import logging 

10from dataclasses import dataclass 

11from typing import Generic, AnyStr 

12 

# Generic type variable used by the helpers further down in this module.
T = TypeVar("T")

# A unique sentinel: a bare object() compares equal only to itself.
NO_VALUE = object()

logger: logging.Logger = logging.getLogger("lark")
# Lark logs a few warnings from inside its own code. The threshold is raised to the
# highest standard level so that, by default, only CRITICAL records get through.
logger.setLevel(logging.CRITICAL)
logger.addHandler(logging.StreamHandler())

23 

24 

def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    """Group the items of `seq` into a dict of lists.

    Each item is filed under ``key(item)`` (or the item itself when `key` is None).
    The stored entry is ``value(item)`` (or the item itself when `value` is None).
    Groups appear in first-seen order, and members keep the order of `seq`.
    """
    groups: Dict[Any, Any] = {}
    for item in seq:
        group_key = item if key is None else key(item)
        member = item if value is None else value(item)
        groups.setdefault(group_key, []).append(member)
    return groups

35 

36 

def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
    """Recursively rebuild a value from its serialized (dict/list) form.

    - list: each element is deserialized.
    - dict with ``'__type__'``: the class of that name is looked up in `namespace`
      (KeyError if absent), and its ``deserialize(data, memo)`` result is returned.
    - dict with ``'@'`` (and no ``'__type__'``): the memoized object ``memo[data['@']]``.
    - any other dict: each value is deserialized, and keys are kept.
    - anything else: returned unchanged.
    """
    if isinstance(data, list):
        return [_deserialize(item, namespace, memo) for item in data]
    if not isinstance(data, dict):
        return data

    if '__type__' in data:
        target_cls = namespace[data['__type__']]
        return target_cls.deserialize(data, memo)
    if '@' in data:
        return memo[data['@']]
    return {k: _deserialize(v, namespace, memo) for k, v in data.items()}

48 

49 

_T = TypeVar("_T", bound="Serialize")


class Serialize:
    """Safe-ish serialization interface that doesn't rely on Pickle

    Deserialization only instantiates classes listed in ``__serialize_namespace__``.
    It never instantiates arbitrary types named in the data.

    Attributes:
        __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
        __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
                                        Should include all field types that aren't builtin types.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        """Serialize ``self`` while memoizing instances of ``types_to_memoize``.

        Returns a pair ``(data, memo_data)``. Inside ``data``, memoized objects appear as
        ``{'@': id}`` references, and ``memo_data`` maps each id to the object's serialized form.
        """
        memo = SerializeMemoizer(types_to_memoize)
        return self.serialize(memo), memo.serialize()

    def serialize(self, memo = None) -> Dict[str, Any]:
        """Return a dict of this object's ``__serialize_fields__`` plus a ``'__type__'`` tag.

        If `memo` is given and this object's type is memoized by it, only a reference
        ``{'@': id}`` is returned. Subclasses may define ``_serialize(res, memo)`` to adjust
        the resulting dict in place.
        """
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        fields = getattr(self, '__serialize_fields__')
        res = {f: _serialize(getattr(self, f), memo) for f in fields}
        res['__type__'] = type(self).__name__
        if hasattr(self, '_serialize'):
            self._serialize(res, memo)
        return res

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        """Rebuild an instance of `cls` from `data`, as produced by `serialize`.

        A memo reference ``{'@': id}`` returns ``memo[id]`` directly. Otherwise a bare
        instance is created with ``cls.__new__`` (``__init__`` is NOT called). Each field
        in ``__serialize_fields__`` is then deserialized, with the classes in
        ``__serialize_namespace__`` as the only ones that may be instantiated. Finally, an
        optional ``_deserialize()`` hook runs on the new instance.

        Raises:
            KeyError: a field listed in ``__serialize_fields__`` is missing from `data`.
                KeyErrors raised while deserializing a field's *value* propagate unchanged.
        """
        # Memo references need none of the per-class setup below, so handle them first.
        if '@' in data:
            return memo[data['@']]

        namespace = {c.__name__: c for c in getattr(cls, '__serialize_namespace__', [])}
        fields = getattr(cls, '__serialize_fields__')

        inst = cls.__new__(cls)
        for f in fields:
            # Only the field lookup is guarded. A KeyError from nested deserialization
            # (e.g. an unknown memo id) must not be misreported as a missing field.
            try:
                raw = data[f]
            except KeyError as e:
                raise KeyError("Cannot find key for class", cls, e) from e
            setattr(inst, f, _deserialize(raw, namespace, memo))

        if hasattr(inst, '_deserialize'):
            inst._deserialize()

        return inst

97 

98 

class SerializeMemoizer(Serialize):
    """A version of serialize that memoizes objects to reduce space.

    Instances of ``types_to_memoize`` are given integer ids by an ``Enumerator``. Everywhere
    else they are serialized as ``{'@': id}`` references, and ``serialize()`` emits the
    id -> full serialized object table.
    """

    __serialize_fields__ = ('memoized',)

    def __init__(self, types_to_memoize: List) -> None:
        self.types_to_memoize = tuple(types_to_memoize)  # tuple, so it can be passed to isinstance()
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        """Return True if `value` is an instance of one of the memoized types."""
        return isinstance(value, self.types_to_memoize)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        """Serialize the id -> object table (objects serialized in full, without memoization)."""
        id_to_obj = self.memoized.reversed()
        return _serialize(id_to_obj, None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        """Rebuild the id -> object table produced by ``serialize``."""
        return _deserialize(data, namespace, memo)

117 

118 

try:
    import regex
    _has_regex = True
except ImportError:
    _has_regex = False

if sys.version_info >= (3, 11):
    import re._parser as sre_parse
    import re._constants as sre_constants
else:
    import sre_parse
    import sre_constants

categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')

def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    r"""Return the minimum and maximum length of strings matched by the regexp `expr`.

    Widths come from the stdlib's internal regex parser (``sre_parse``).

    Returns:
        Normally a two-element *list* ``[min, max]``. If ``sre_parse`` rejects the pattern
        but the third-party ``regex`` module is installed, a *tuple* ``(min, MAXWIDTH)``
        is returned instead. In that case `min` is 0 if the pattern matches the empty
        string and 1 otherwise. Note the differing container types.

    Raises:
        ImportError: `expr` contains ``\p{...}`` categories and ``regex`` is not installed.
        ValueError: ``sre_parse`` rejects `expr` and ``regex`` is not installed.
        Any error from ``regex.compile`` propagates unchanged.
    """
    if not _has_regex and re.search(categ_pattern, expr):
        raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
    # sre_parse can't handle Unicode categories such as \p{Mn}. Each one matches a single
    # character, so swapping it for a plain letter leaves the width unchanged.
    regexp_final = re.sub(categ_pattern, 'A', expr) if _has_regex else expr

    try:
        min_w, max_w = sre_parse.parse(regexp_final).getwidth()
    except sre_constants.error:
        if not _has_regex:
            raise ValueError(expr)
        # sre_parse does not support regex's extended syntax. Fall back to the one cheap
        # fact we can check (does the empty string match?) and report the max as unbounded.
        compiled = regex.compile(regexp_final)
        # Python 3.11.7 introduced sre_parse.MAXWIDTH, used instead of MAXREPEAT.
        # See lark-parser/lark#1376 and python/cpython#109859
        unbounded = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT)
        min_possible = 0 if compiled.match('') is not None else 1
        # MAXREPEAT is a non-picklable int subclass; convert it to int to enable caching.
        return min_possible, int(unbounded)
    return [int(min_w), int(max_w)]

162 

163 

@dataclass(frozen=True)
class TextSlice(Generic[AnyStr]):
    """A view of a string or bytes object, between the start and end indices.

    Never creates a copy.

    Lark accepts instances of TextSlice as input (instead of a string),
    when the lexer is 'basic' or 'contextual'.

    Args:
        text (str or bytes): The text to slice.
        start (int): The start index. Negative indices are supported.
        end (int): The end index. Negative indices are supported; None means len(text).

    Raises:
        TypeError: If `text` is not a `str` or `bytes`.
        AssertionError: If `start` or `end` (after normalizing negative indices) fall
            outside ``0..len(text)``. These are ``assert`` checks, so they are skipped
            under ``python -O``.

    Examples:
        >>> TextSlice("Hello, World!", 7, -1)
        TextSlice(text='Hello, World!', start=7, end=12)

        >>> TextSlice("Hello, World!", 7, None).count("o")
        1

    """
    text: AnyStr
    start: int
    end: int

    def __post_init__(self):
        if not isinstance(self.text, (str, bytes)):
            raise TypeError("text must be str or bytes")

        text_len = len(self.text)

        # The dataclass is frozen, so normalized values are written via object.__setattr__.
        if self.start < 0:
            object.__setattr__(self, 'start', self.start + text_len)
        # Both bounds are checked: a too-negative index, or one past the end, would otherwise
        # slip through and later make __len__ negative.
        assert 0 <= self.start <= text_len, ("start index out of bounds", self.start, text_len)

        if self.end is None:
            object.__setattr__(self, 'end', text_len)
        elif self.end < 0:
            object.__setattr__(self, 'end', self.end + text_len)
        assert 0 <= self.end <= text_len, ("end index out of bounds", self.end, text_len)

    @classmethod
    def cast_from(cls, text: 'TextOrSlice') -> 'TextSlice[AnyStr]':
        """Return `text` unchanged if it is already a TextSlice, else a slice covering all of it."""
        if isinstance(text, TextSlice):
            return text

        return cls(text, 0, len(text))

    def is_complete_text(self):
        """True if this slice spans the entire underlying text."""
        return self.start == 0 and self.end == len(self.text)

    def __len__(self):
        return self.end - self.start

    def count(self, substr: AnyStr):
        """Count non-overlapping occurrences of `substr` within the slice."""
        return self.text.count(substr, self.start, self.end)

    def rindex(self, substr: AnyStr):
        """Highest index (into the full text) of `substr` within the slice; ValueError if absent."""
        return self.text.rindex(substr, self.start, self.end)


TextOrSlice = Union[AnyStr, 'TextSlice[AnyStr]']

229 

230###} 

231 

232 

# Unicode general categories accepted for the first character of an identifier:
# letters (Lu, Ll, Lt, Lm, Lo), combining marks (Mn, Mc) and connector punctuation (Pc).
_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
# Continuation characters additionally allow decimal digits (Nd) and letter numbers (Nl).
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)

def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
    """Return True if every character of `s` is '_' or has a general category in `categories`.

    An empty string yields True (vacuous truth).
    """
    return all(ch == '_' or unicodedata.category(ch) in categories for ch in s)

def is_id_continue(s: str) -> bool:
    """
    Checks if all characters in `s` may continue an identifier: letters, combining marks,
    connector punctuation (such as '_'), decimal digits and letter numbers, per the Unicode
    standard (so diacritics, Indian vowel signs, non-Latin digits, etc. all pass).

    Loosely modelled on Python's ``ID_CONTINUE`` (see PEP 3131), but not identical to it.
    Note that an empty string returns True.
    """
    return _test_unicode_category(s, _ID_CONTINUE)

def is_id_start(s: str) -> bool:
    """
    Checks if all characters in `s` may start an identifier: letters, combining marks and
    connector punctuation (such as '_'), per the Unicode standard (so diacritics and
    non-Latin letters pass). Digits and other number categories do NOT pass.

    Loosely modelled on Python's ``ID_START`` (see PEP 3131), but not identical to it:
    combining marks are accepted here, and letter numbers (Nl) are not.
    Note that an empty string returns True.
    """
    return _test_unicode_category(s, _ID_START)

254 

255 

def dedup_list(l: Iterable[T]) -> List[T]:
    """Return the items of `l` with duplicates removed, keeping first-occurrence order.

    Items must be hashable.
    """
    seen: Dict[T, None] = {}
    for item in l:
        seen.setdefault(item, None)
    return list(seen)

261 

262 

class Enumerator(Serialize):
    """Assigns consecutive integer ids (0, 1, 2, ...) to hashable items, in first-seen order."""

    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        """Return the id of `item`, allocating the next free id on first sight."""
        # len() is evaluated before insertion, so a new item receives the next id.
        return self.enums.setdefault(item, len(self.enums))

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        """Return the inverse mapping, id -> item."""
        inverse = dict(zip(self.enums.values(), self.enums.keys()))
        assert len(inverse) == len(self.enums)
        return inverse

279 

280 

281 

def combine_alternatives(lists):
    """
    Accepts a list of alternatives, and enumerates all their possible concatenations.

    Each combination is a tuple (as produced by ``itertools.product``). The one exception
    is an empty input, which yields ``[[]]``. Every alternative must be non-empty.

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]
    """
    if not lists:
        return [[]]
    # An empty alternative would make the product empty, silently dropping every combination.
    assert all(l for l in lists), lists
    return list(product(*lists))

300 

try:
    import atomicwrites
    _has_atomicwrites = True
except ImportError:
    _has_atomicwrites = False

class FS:
    """Minimal filesystem shim.

    ``exists`` is ``os.path.exists``. ``open`` routes any mode containing "w" through
    ``atomicwrites.atomic_write(..., overwrite=True)`` when that package is installed;
    everything else goes to the builtin ``open``.
    """
    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        use_atomic = _has_atomicwrites and "w" in mode
        if not use_atomic:
            # Inside a method, `open` resolves to the builtin, not to this staticmethod.
            return open(name, mode, **kwargs)
        return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)

316 

317 

class fzset(frozenset):
    """A frozenset whose repr uses set-literal braces, e.g. ``{1, 2}``."""

    def __repr__(self):
        inner = ', '.join(repr(elem) for elem in self)
        return '{' + inner + '}'

321 

322 

def classify_bool(seq: Iterable, pred: Callable) -> Tuple[List[Any], List[Any]]:
    """Partition `seq` by the truthiness of `pred`, preserving order.

    Returns ``(true_elems, false_elems)``. `pred` is called exactly once per element.
    """
    # A plain loop replaces the previous side-effecting `pred(x) or list.append(x)` trick
    # inside a comprehension. Results and call order are unchanged.
    true_elems: List[Any] = []
    false_elems: List[Any] = []
    for elem in seq:
        if pred(elem):
            true_elems.append(elem)
        else:
            false_elems.append(elem)
    return true_elems, false_elems

327 

328 

def bfs(initial: Iterable, expand: Callable) -> Iterator:
    """Lazily yield nodes in breadth-first order, starting from `initial`.

    ``expand(node)`` returns a node's neighbours. A neighbour is queued only the first time
    it is seen. Duplicates already present in `initial` are still yielded as given.
    """
    queue = deque(initial)
    seen = set(queue)
    while queue:
        current = queue.popleft()
        yield current
        for neighbour in expand(current):
            if neighbour in seen:
                continue
            seen.add(neighbour)
            queue.append(neighbour)

339 

def bfs_all_unique(initial, expand):
    """Breadth-first traversal without a visited set.

    Only valid when the expansion can never produce a node twice (e.g. a tree);
    otherwise nodes are yielded repeatedly.
    """
    pending = deque(initial)
    while pending:
        node = pending.popleft()
        yield node
        pending.extend(expand(node))

347 

348 

def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    """Convert `value` into plain data for serialization.

    - Serialize instances: ``value.serialize(memo)``.
    - lists: element-wise, recursively.
    - frozensets: converted to a list; elements are NOT recursed into.
    - dicts: values recursively, keys unchanged.
    - anything else (None, numbers, strings, tuples, ...): returned unchanged.
    """
    if isinstance(value, Serialize):
        return value.serialize(memo)
    if isinstance(value, list):
        return [_serialize(item, memo) for item in value]
    if isinstance(value, frozenset):
        # TODO reversible? Deserialization yields a list, not a frozenset.
        return list(value)
    if isinstance(value, dict):
        return {k: _serialize(v, memo) for k, v in value.items()}
    return value

360 

361 

362 

363 

def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """
    Splits n into a chain of small factors and summands, all <= max_factor.

    Returns a list of pairs [(a, b), ...] such that the following reconstructs n:

        n = 1
        for a, b in values:
            n = n * a + b

    Currently each pair also satisfies a + b <= max_factor, but that might change.
    """
    assert n >= 0
    assert max_factor > 2
    # Peel off (a, b) pairs from the large end, then reverse so the smallest base comes first.
    steps: List[Tuple[int, int]] = []
    while n > max_factor:
        for a in range(max_factor, 1, -1):
            quotient, b = divmod(n, a)
            if a + b <= max_factor:
                break
        else:
            assert False, "Failed to factorize %s" % n
        steps.append((a, b))
        n = quotient
    steps.append((n, 0))
    steps.reverse()
    return steps

386 

387 

class OrderedSet(AbstractSet[T]):
    """A minimal set that remembers insertion order, backed by a dict (dicts are ordered).

    Membership, add, remove, iteration, len/bool and repr are implemented here. The other
    read-only set operations come from the AbstractSet base class.
    """

    def __init__(self, items: Iterable[T] =()):
        self.d = {item: None for item in items}

    def __contains__(self, item: Any) -> bool:
        return item in self.d

    def add(self, item: T):
        self.d.setdefault(item, None)

    def __iter__(self) -> Iterator[T]:
        return iter(self.d)

    def remove(self, item: T):
        """Remove `item`; raises KeyError if it is absent."""
        self.d.pop(item)

    def __bool__(self):
        return len(self.d) > 0

    def __len__(self) -> int:
        return len(self.d)

    def __repr__(self):
        return '%s(%s)' % (type(self).__name__, ', '.join(repr(elem) for elem in self))

402 def __iter__(self) -> Iterator[T]: 

403 return iter(self.d) 

404 

405 def remove(self, item: T): 

406 del self.d[item] 

407 

408 def __bool__(self): 

409 return bool(self.d) 

410 

411 def __len__(self) -> int: 

412 return len(self.d) 

413 

414 def __repr__(self): 

415 return f"{type(self).__name__}({', '.join(map(repr,self))})"