Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/lark/utils.py: 58%

211 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-14 06:19 +0000

1import unicodedata 

2import os 

3from itertools import product 

4from collections import deque 

5from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable, AbstractSet 

6 

7###{standalone 

8import sys, re 

9import logging 

10 

# Package-wide logger for lark; a StreamHandler is attached so messages are
# visible even if the application never configures logging.
logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
# Set to highest level, since we have some warnings amongst the code
# By default, we should not output any log messages
logger.setLevel(logging.CRITICAL)


# Sentinel object to distinguish "no value supplied" from an explicit None.
NO_VALUE = object()

# Generic type variable used by the helpers below (e.g. dedup_list, OrderedSet).
T = TypeVar("T")

21 

22 

def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    """Group the items of *seq* into a dict of lists.

    Each item is filed under ``key(item)`` (the item itself when *key* is None),
    and the stored entry is ``value(item)`` (the item itself when *value* is None).
    """
    groups: Dict[Any, Any] = {}
    for item in seq:
        k = item if key is None else key(item)
        v = item if value is None else value(item)
        groups.setdefault(k, []).append(v)
    return groups

33 

34 

35def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any: 

36 if isinstance(data, dict): 

37 if '__type__' in data: # Object 

38 class_ = namespace[data['__type__']] 

39 return class_.deserialize(data, memo) 

40 elif '@' in data: 

41 return memo[data['@']] 

42 return {key:_deserialize(value, namespace, memo) for key, value in data.items()} 

43 elif isinstance(data, list): 

44 return [_deserialize(value, namespace, memo) for value in data] 

45 return data 

46 

47 

# Bound type variable so Serialize.deserialize returns the *subclass* type.
_T = TypeVar("_T", bound="Serialize")

49 

class Serialize:
    """Safe-ish serialization interface that doesn't rely on Pickle

    Attributes:
        __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
        __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
            Should include all field types that aren't builtin types.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        # Serialize the payload and the memo table it references, as a pair.
        memoizer = SerializeMemoizer(types_to_memoize)
        return self.serialize(memoizer), memoizer.serialize()

    def serialize(self, memo = None) -> Dict[str, Any]:
        # Memoized instances collapse into a reference record.
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        fields = getattr(self, '__serialize_fields__')
        result = {field: _serialize(getattr(self, field), memo) for field in fields}
        result['__type__'] = type(self).__name__
        # Optional subclass hook for post-processing the serialized dict.
        if hasattr(self, '_serialize'):
            self._serialize(result, memo)  # type: ignore[attr-defined]
        return result

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        # Map class names to the classes we are allowed to instantiate.
        namespace = {klass.__name__: klass for klass in getattr(cls, '__serialize_namespace__', [])}

        fields = getattr(cls, '__serialize_fields__')

        # A memo reference stands in for the whole object.
        if '@' in data:
            return memo[data['@']]

        # Bypass __init__: attributes are restored straight from the payload.
        inst = cls.__new__(cls)
        for field in fields:
            try:
                setattr(inst, field, _deserialize(data[field], namespace, memo))
            except KeyError as e:
                raise KeyError("Cannot find key for class", cls, e)

        # Optional subclass hook for fixing up the restored instance.
        if hasattr(inst, '_deserialize'):
            inst._deserialize()  # type: ignore[attr-defined]

        return inst

95 

96 

class SerializeMemoizer(Serialize):
    "A version of serialize that memoizes objects to reduce space"

    # Only the memo table itself needs to be serialized.
    __serialize_fields__ = 'memoized',

    def __init__(self, types_to_memoize: List) -> None:
        # isinstance() below requires a tuple, not a list.
        self.types_to_memoize = tuple(types_to_memoize)
        # Assigns each memoized object a dense integer id.
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        # True if `value` should be replaced by a memo reference when serialized.
        return isinstance(value, self.types_to_memoize)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        # Emit the memo table as {id: serialized object}.
        return _serialize(self.memoized.reversed(), None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        # Rebuild the memo table; unlike Serialize.deserialize this returns a plain dict.
        return _deserialize(data, namespace, memo)

115 

116 

# Optional dependency: the third-party `regex` module supports features
# (e.g. Unicode categories) that the stdlib `re` lacks.
try:
    import regex
    _has_regex = True
except ImportError:
    _has_regex = False

# Python 3.11 moved the private sre modules under the `re` package.
if sys.version_info >= (3, 11):
    import re._parser as sre_parse
    import re._constants as sre_constants
else:
    import sre_parse
    import sre_constants

# Matches Unicode-category escapes such as \p{Lu}, which only `regex` understands.
categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')

131 

def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    """Return the (min, max) width that the regexp *expr* can match.

    Uses ``sre_parse`` for the exact answer; when the pattern relies on
    `regex`-only features, falls back to probing whether it matches the
    empty string.
    """
    if _has_regex:
        # `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`,
        # so replace them with a plain letter -- this makes no difference since
        # we only want the possible match lengths.
        regexp_final = re.sub(categ_pattern, 'A', expr)
    else:
        if re.search(categ_pattern, expr):
            raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
        regexp_final = expr
    try:
        # Fixed in next version (past 0.960) of typeshed
        return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]  # type: ignore[attr-defined]
    except sre_constants.error:
        if not _has_regex:
            raise ValueError(expr)
        # sre_parse does not support the new features in regex. To not completely
        # fail in that case, we manually test for the most important info
        # (whether the empty string is matched).
        compiled = regex.compile(regexp_final)
        # Python 3.11.7 introduced sre_parse.MAXWIDTH, used instead of MAXREPEAT.
        # See lark-parser/lark#1376 and python/cpython#109859.
        MAXWIDTH = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT)
        # MAXREPEAT is an unpicklable subclass of int, so convert to a plain
        # int to keep the result cacheable.
        min_width = 1 if compiled.match('') is None else 0
        return min_width, int(MAXWIDTH)

160 

161###} 

162 

163 

# Unicode general categories accepted by the identifier helpers below as
# start / continue characters (modeled on PEP 3131 identifier rules).
_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)

166 

167def _test_unicode_category(s: str, categories: Sequence[str]) -> bool: 

168 if len(s) != 1: 

169 return all(_test_unicode_category(char, categories) for char in s) 

170 return s == '_' or unicodedata.category(s) in categories 

171 

def is_id_continue(s: str) -> bool:
    """Return True if every character of ``s`` may appear after the first
    character of a Python identifier (``ID_CONTINUE``; see PEP 3131).

    Unicode-aware: diacritics, Indian vowels, non-latin digits, etc. all pass.
    """
    return _test_unicode_category(s, _ID_CONTINUE)

178 

def is_id_start(s: str) -> bool:
    """Return True if every character of ``s`` may start a Python identifier
    (``ID_START``; see PEP 3131).

    Unicode-aware: alphabetic characters from any script pass.
    """
    return _test_unicode_category(s, _ID_START)

185 

186 

def dedup_list(l: Sequence[T]) -> List[T]:
    """Return a copy of *l* with duplicates removed, preserving the order of
    first occurrence. Assumes the list entries are hashable.
    """
    # dict preserves insertion order (guaranteed since Python 3.7), making this
    # simpler and ~2x faster than tracking membership in a separate set.
    return list(dict.fromkeys(l))

196 

197 

class Enumerator(Serialize):
    """Maps each distinct item to a dense integer id, in order of first appearance."""

    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        # A new item receives the next sequential id; known items keep theirs.
        return self.enums.setdefault(item, len(self.enums))

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        """Return the inverse mapping {id: item}."""
        inverse = {idx: item for item, idx in self.enums.items()}
        # Ids are unique, so the inversion must be lossless.
        assert len(inverse) == len(self.enums)
        return inverse

214 

215 

216 

def combine_alternatives(lists):
    """
    Accepts a list of alternatives, and enumerates all their possible concatenations.

    Each concatenation is returned as a tuple (``itertools.product`` semantics);
    only the no-alternatives case returns ``[[]]``.

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]
    """
    if not lists:
        return [[]]
    # An empty alternative would silently collapse the product to nothing;
    # fail loudly instead.
    assert all(l for l in lists), lists
    return list(product(*lists))

235 

# Optional dependency: when available, `atomicwrites` lets FS.open write
# files atomically (no partially-written file on crash).
try:
    # atomicwrites doesn't have type bindings
    import atomicwrites  # type: ignore[import]
    _has_atomicwrites = True
except ImportError:
    _has_atomicwrites = False

242 

class FS:
    """Small filesystem facade; writes atomically when `atomicwrites` is installed."""

    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        """Open *name*; write modes are routed through atomicwrites if available,
        so a crash mid-write cannot leave a truncated file behind."""
        if _has_atomicwrites and "w" in mode:
            return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)
        return open(name, mode, **kwargs)

252 

253 

254 

def isascii(s: str) -> bool:
    """Return True if *s* contains only ASCII characters.

    Uses str.isascii on Python 3.7+; earlier versions fall back to an
    encode-and-see attempt.
    """
    if sys.version_info < (3, 7):
        try:
            s.encode('ascii')
        except (UnicodeDecodeError, UnicodeEncodeError):
            return False
        return True
    return s.isascii()

265 

266 

class fzset(frozenset):
    """A frozenset whose repr looks like a plain set literal: ``{...}``."""
    def __repr__(self):
        return "{" + ", ".join(map(repr, self)) + "}"

270 

271 

def classify_bool(seq: Iterable, pred: Callable) -> Any:
    """Split *seq* into two lists: (elements where pred is truthy, the rest).

    Order within each list follows the original sequence.
    """
    matched = []
    unmatched = []
    for elem in seq:
        (matched if pred(elem) else unmatched).append(elem)
    return matched, unmatched

276 

277 

278def bfs(initial: Iterable, expand: Callable) -> Iterator: 

279 open_q = deque(list(initial)) 

280 visited = set(open_q) 

281 while open_q: 

282 node = open_q.popleft() 

283 yield node 

284 for next_node in expand(node): 

285 if next_node not in visited: 

286 visited.add(next_node) 

287 open_q.append(next_node) 

288 

def bfs_all_unique(initial, expand):
    "bfs, but doesn't keep track of visited (aka seen), because there can be no repetitions"
    queue = deque(initial)
    while queue:
        node = queue.popleft()
        yield node
        # No visited-set: the caller guarantees expand never revisits a node.
        queue.extend(expand(node))

296 

297 

def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    """Recursively convert *value* into plain, JSON-friendly structures,
    delegating Serialize instances to their own serialize()."""
    if isinstance(value, Serialize):
        return value.serialize(memo)
    if isinstance(value, list):
        return [_serialize(item, memo) for item in value]
    if isinstance(value, frozenset):
        return list(value)  # TODO reversible?
    if isinstance(value, dict):
        return {k: _serialize(v, memo) for k, v in value.items()}
    # assert value is None or isinstance(value, (int, float, str, tuple)), value
    return value

309 

310 

311 

312 

def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """
    Splits n up into smaller factors and summands <= max_factor.
    Returns a list of [(a, b), ...]
    so that the following code returns n:

        n = 1
        for a, b in values:
            n = n * a + b

    Currently, we also keep a + b <= max_factor, but that might change
    """
    assert n >= 0
    assert max_factor > 2
    # Small enough to represent directly.
    if n <= max_factor:
        return [(n, 0)]

    # Greedily peel off the largest factor `a` whose remainder `b` keeps
    # a + b within the limit, then factor the quotient recursively.
    for a in range(max_factor, 1, -1):
        quotient, b = divmod(n, a)
        if a + b <= max_factor:
            return small_factors(quotient, max_factor) + [(a, b)]
    assert False, "Failed to factorize %s" % n

335 

336 

class OrderedSet(AbstractSet[T]):
    """A minimal OrderedSet implementation, using a dictionary.

    (relies on the dictionary being ordered)
    """
    def __init__(self, items: Iterable[T] =()):
        # dict.fromkeys keeps insertion order and deduplicates in one pass.
        self.d = dict.fromkeys(items)

    def __contains__(self, item: Any) -> bool:
        return item in self.d

    def add(self, item: T):
        # Re-adding an existing item keeps its original position.
        self.d[item] = None

    def __iter__(self) -> Iterator[T]:
        return iter(self.d)

    def remove(self, item: T):
        # Like set.remove: raises KeyError if the item is absent.
        del self.d[item]

    def __bool__(self):
        return bool(self.d)

    def __len__(self) -> int:
        return len(self.d)

    def __repr__(self):
        # Show the contents (in order) instead of the opaque default repr.
        return f"{type(self).__name__}({', '.join(map(repr, self))})"