Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/lark/utils.py: 50% (204 statements)

import unicodedata
import os
from itertools import product
from collections import deque
from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable, AbstractSet

###{standalone
import sys, re
import logging

logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
# Set to highest level, since we have some warnings amongst the code
# By default, we should not output any log messages
logger.setLevel(logging.CRITICAL)


NO_VALUE = object()

T = TypeVar("T")


def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    d: Dict[Any, Any] = {}
    for item in seq:
        k = key(item) if (key is not None) else item
        v = value(item) if (value is not None) else item
        try:
            d[k].append(v)
        except KeyError:
            d[k] = [v]
    return d
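
# Usage sketch (editor's illustration, not part of the library source):
# `classify` buckets an iterable into a dict of lists, keyed by `key(item)`.
#
#     >>> classify(range(6), key=lambda n: n % 2)
#     {0: [0, 2, 4], 1: [1, 3, 5]}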


def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
    if isinstance(data, dict):
        if '__type__' in data:  # Object
            class_ = namespace[data['__type__']]
            return class_.deserialize(data, memo)
        elif '@' in data:
            return memo[data['@']]
        return {key: _deserialize(value, namespace, memo) for key, value in data.items()}
    elif isinstance(data, list):
        return [_deserialize(value, namespace, memo) for value in data]
    return data


_T = TypeVar("_T", bound="Serialize")

class Serialize:
    """Safe-ish serialization interface that doesn't rely on Pickle

    Attributes:
        __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
        __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
            Should include all field types that aren't builtin types.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        memo = SerializeMemoizer(types_to_memoize)
        return self.serialize(memo), memo.serialize()

    def serialize(self, memo = None) -> Dict[str, Any]:
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        fields = getattr(self, '__serialize_fields__')
        res = {f: _serialize(getattr(self, f), memo) for f in fields}
        res['__type__'] = type(self).__name__
        if hasattr(self, '_serialize'):
            self._serialize(res, memo)
        return res

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        namespace = getattr(cls, '__serialize_namespace__', [])
        namespace = {c.__name__: c for c in namespace}

        fields = getattr(cls, '__serialize_fields__')

        if '@' in data:
            return memo[data['@']]

        inst = cls.__new__(cls)
        for f in fields:
            try:
                setattr(inst, f, _deserialize(data[f], namespace, memo))
            except KeyError as e:
                raise KeyError("Cannot find key for class", cls, e)

        if hasattr(inst, '_deserialize'):
            inst._deserialize()

        return inst
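
# Protocol sketch (editor's illustration; `Point` is a hypothetical subclass,
# not part of lark). Declared fields round-trip through plain dicts:
#
#     class Point(Serialize):
#         __serialize_fields__ = 'x', 'y'
#
#     >>> p = Point.__new__(Point); p.x, p.y = 1, 2
#     >>> data = p.serialize()
#     >>> data == {'x': 1, 'y': 2, '__type__': 'Point'}
#     True
#     >>> q = Point.deserialize(data, memo={})
#     >>> (q.x, q.y)
#     (1, 2)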


class SerializeMemoizer(Serialize):
    "A version of serialize that memoizes objects to reduce space"

    __serialize_fields__ = 'memoized',

    def __init__(self, types_to_memoize: List) -> None:
        self.types_to_memoize = tuple(types_to_memoize)
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        return isinstance(value, self.types_to_memoize)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        return _serialize(self.memoized.reversed(), None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        return _deserialize(data, namespace, memo)


try:
    import regex
    _has_regex = True
except ImportError:
    _has_regex = False

if sys.version_info >= (3, 11):
    import re._parser as sre_parse
    import re._constants as sre_constants
else:
    import sre_parse
    import sre_constants

categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')


def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    if _has_regex:
        # Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with
        # a simple letter, which makes no difference since we are only trying to get the possible lengths
        # of the regex match below.
        regexp_final = re.sub(categ_pattern, 'A', expr)
    else:
        if re.search(categ_pattern, expr):
            raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
        regexp_final = expr
    try:
        # Fixed in the next version (past 0.960) of typeshed
        return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]
    except sre_constants.error:
        if not _has_regex:
            raise ValueError(expr)
        else:
            # sre_parse does not support the new features in regex. To avoid failing completely in that case,
            # we manually test for the most important info (whether the empty string is matched)
            c = regex.compile(regexp_final)
            # Python 3.11.7 introduced sre_parse.MAXWIDTH, which is used instead of MAXREPEAT
            # See lark-parser/lark#1376 and python/cpython#109859
            MAXWIDTH = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT)
            if c.match('') is None:
                # MAXREPEAT is a non-picklable subclass of int, so it needs to be converted to enable caching
                return 1, int(MAXWIDTH)
            else:
                return 0, int(MAXWIDTH)
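
# Behavior sketch (editor's illustration): on the common path this returns the
# (min, max) match width reported by sre_parse, as a list of ints.
#
#     >>> get_regexp_width('a{2,5}')
#     [2, 5]
#     >>> get_regexp_width('abc')
#     [3, 3]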


###}


_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)


def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
    if len(s) != 1:
        return all(_test_unicode_category(char, categories) for char in s)
    return s == '_' or unicodedata.category(s) in categories

def is_id_continue(s: str) -> bool:
    """
    Checks if all characters in `s` are alphanumeric characters (Unicode standard, so diacritics, Indian vowels, non-Latin
    numbers, etc. all pass). Synonymous with a Python `ID_CONTINUE` identifier. See PEP 3131 for details.
    """
    return _test_unicode_category(s, _ID_CONTINUE)

def is_id_start(s: str) -> bool:
    """
    Checks if all characters in `s` are alphabetic characters (Unicode standard, so diacritics, Indian vowels, non-Latin
    numbers, etc. all pass). Synonymous with a Python `ID_START` identifier. See PEP 3131 for details.
    """
    return _test_unicode_category(s, _ID_START)
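
# Usage sketch (editor's illustration):
#
#     >>> is_id_start('foo'), is_id_start('2x')
#     (True, False)
#     >>> is_id_continue('x2'), is_id_continue('x-y')
#     (True, False)
#
# Note that '_' is special-cased to pass both checks.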


def dedup_list(l: Sequence[T]) -> List[T]:
    """Given a list (l), removes duplicates from it,
    preserving the original order of the list. Assumes that
    the list entries are hashable."""
    return list(dict.fromkeys(l))
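
# Usage sketch (editor's illustration):
#
#     >>> dedup_list([3, 1, 3, 2, 1])
#     [3, 1, 2]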


class Enumerator(Serialize):
    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        if item not in self.enums:
            self.enums[item] = len(self.enums)
        return self.enums[item]

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        r = {v: k for k, v in self.enums.items()}
        assert len(r) == len(self.enums)
        return r
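
# Usage sketch (editor's illustration): `get` assigns stable, dense ids in
# first-seen order; `reversed` inverts the mapping.
#
#     >>> e = Enumerator()
#     >>> e.get('a'), e.get('b'), e.get('a')
#     (0, 1, 0)
#     >>> e.reversed()
#     {0: 'a', 1: 'b'}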


def combine_alternatives(lists):
    """
    Accepts a list of alternatives, and enumerates all their possible concatenations.

    Note that `itertools.product` yields tuples, so each concatenation comes back as a tuple:

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]
    """
    if not lists:
        return [[]]
    assert all(l for l in lists), lists
    return list(product(*lists))


try:
    import atomicwrites
    _has_atomicwrites = True
except ImportError:
    _has_atomicwrites = False

class FS:
    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        if _has_atomicwrites and "w" in mode:
            return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)
        else:
            return open(name, mode, **kwargs)


class fzset(frozenset):
    def __repr__(self):
        return '{%s}' % ', '.join(map(repr, self))


def classify_bool(seq: Iterable, pred: Callable) -> Any:
    false_elems = []
    # `list.append` returns None (falsy), so elements failing the predicate are
    # collected into false_elems as a side effect of the comprehension
    true_elems = [elem for elem in seq if pred(elem) or false_elems.append(elem)]  # type: ignore[func-returns-value]
    return true_elems, false_elems
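
# Usage sketch (editor's illustration): partitions a sequence by a predicate.
#
#     >>> classify_bool(range(5), lambda n: n % 2 == 0)
#     ([0, 2, 4], [1, 3])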


def bfs(initial: Iterable, expand: Callable) -> Iterator:
    open_q = deque(list(initial))
    visited = set(open_q)
    while open_q:
        node = open_q.popleft()
        yield node
        for next_node in expand(node):
            if next_node not in visited:
                visited.add(next_node)
                open_q.append(next_node)
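
# Usage sketch (editor's illustration; `graph` is a hypothetical adjacency dict):
#
#     >>> graph = {1: [2, 3], 2: [4], 3: [4], 4: []}
#     >>> list(bfs([1], lambda n: graph[n]))
#     [1, 2, 3, 4]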

270 

271def bfs_all_unique(initial, expand): 

272 "bfs, but doesn't keep track of visited (aka seen), because there can be no repetitions" 

273 open_q = deque(list(initial)) 

274 while open_q: 

275 node = open_q.popleft() 

276 yield node 

277 open_q += expand(node) 


def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    if isinstance(value, Serialize):
        return value.serialize(memo)
    elif isinstance(value, list):
        return [_serialize(elem, memo) for elem in value]
    elif isinstance(value, frozenset):
        return list(value)  # TODO reversible?
    elif isinstance(value, dict):
        return {key: _serialize(elem, memo) for key, elem in value.items()}
    # assert value is None or isinstance(value, (int, float, str, tuple)), value
    return value


def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """
    Splits n up into smaller factors and summands <= max_factor.
    Returns a list of [(a, b), ...]
    so that the following code returns n:

    n = 1
    for a, b in values:
        n = n * a + b

    Currently, we also keep a + b <= max_factor, but that might change
    """
    assert n >= 0
    assert max_factor > 2
    if n <= max_factor:
        return [(n, 0)]

    for a in range(max_factor, 1, -1):
        r, b = divmod(n, a)
        if a + b <= max_factor:
            return small_factors(r, max_factor) + [(a, b)]
    assert False, "Failed to factorize %s" % n
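
# Usage sketch (editor's illustration), checking the invariant documented above:
#
#     >>> small_factors(100, 10)
#     [(10, 0), (10, 0)]
#     >>> n = 1
#     >>> for a, b in small_factors(100, 10):
#     ...     n = n * a + b
#     >>> n
#     100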


class OrderedSet(AbstractSet[T]):
    """A minimal OrderedSet implementation, using a dictionary.

    (relies on the dictionary being ordered)
    """
    def __init__(self, items: Iterable[T] = ()):
        self.d = dict.fromkeys(items)

    def __contains__(self, item: Any) -> bool:
        return item in self.d

    def add(self, item: T):
        self.d[item] = None

    def __iter__(self) -> Iterator[T]:
        return iter(self.d)

    def remove(self, item: T):
        del self.d[item]

    def __bool__(self):
        return bool(self.d)

    def __len__(self) -> int:
        return len(self.d)

    def __repr__(self):
        return f"{type(self).__name__}({', '.join(map(repr, self))})"
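
# Usage sketch (editor's illustration): insertion order is preserved, and
# re-adding an existing item does not move it.
#
#     >>> s = OrderedSet([3, 1, 2, 1])
#     >>> s.add(4); s.add(1)
#     >>> list(s)
#     [3, 1, 2, 4]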