Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/lark/utils.py: 56%

195 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:30 +0000

1import unicodedata 

2import os 

3from itertools import product 

4from collections import deque 

5from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable 

6 

7###{standalone 

8import sys, re 

9import logging 

10 

# Library-wide logger; applications may attach handlers / adjust the level.
logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
# Set to highest level, since we have some warnings amongst the code
# By default, we should not output any log messages
logger.setLevel(logging.CRITICAL)


# Sentinel object, used where None is itself a meaningful value and cannot
# mark "no value given".
NO_VALUE = object()

T = TypeVar("T")

21 

22 

def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
    """Group the items of *seq* into a dict of lists.

    Each item is stored under ``key(item)`` (the item itself when *key* is
    None), and the stored value is ``value(item)`` (the item itself when
    *value* is None).
    """
    groups: Dict[Any, Any] = {}
    for item in seq:
        group_key = item if key is None else key(item)
        group_value = item if value is None else value(item)
        groups.setdefault(group_key, []).append(group_value)
    return groups

33 

34 

35def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any: 

36 if isinstance(data, dict): 

37 if '__type__' in data: # Object 

38 class_ = namespace[data['__type__']] 

39 return class_.deserialize(data, memo) 

40 elif '@' in data: 

41 return memo[data['@']] 

42 return {key:_deserialize(value, namespace, memo) for key, value in data.items()} 

43 elif isinstance(data, list): 

44 return [_deserialize(value, namespace, memo) for value in data] 

45 return data 

46 

47 

_T = TypeVar("_T", bound="Serialize")

class Serialize:
    """Safe-ish serialization interface that doesn't rely on Pickle

    Attributes:
        __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
        __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
            Should include all field types that aren't builtin types.
    """

    def memo_serialize(self, types_to_memoize: List) -> Any:
        # Serialize while memoizing instances of the given types; returns
        # both the serialized data and the memo table needed to restore it.
        memoizer = SerializeMemoizer(types_to_memoize)
        return self.serialize(memoizer), memoizer.serialize()

    def serialize(self, memo = None) -> Dict[str, Any]:
        # Memoized objects collapse to a reference entry.
        if memo and memo.in_types(self):
            return {'@': memo.memoized.get(self)}

        result: Dict[str, Any] = {}
        for field_name in getattr(self, '__serialize_fields__'):
            result[field_name] = _serialize(getattr(self, field_name), memo)
        result['__type__'] = type(self).__name__
        # Hook: subclasses may post-process the serialized dict.
        if hasattr(self, '_serialize'):
            self._serialize(result, memo)  # type: ignore[attr-defined]
        return result

    @classmethod
    def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
        namespace = {klass.__name__: klass for klass in getattr(cls, '__serialize_namespace__', [])}
        field_names = getattr(cls, '__serialize_fields__')

        # A memo reference stands in for the whole object.
        if '@' in data:
            return memo[data['@']]

        # Bypass __init__; attributes are restored directly from the data.
        inst = cls.__new__(cls)
        for field_name in field_names:
            try:
                setattr(inst, field_name, _deserialize(data[field_name], namespace, memo))
            except KeyError as e:
                raise KeyError("Cannot find key for class", cls, e)

        # Hook: subclasses may finish reconstruction (e.g. rebuild caches).
        if hasattr(inst, '_deserialize'):
            inst._deserialize()  # type: ignore[attr-defined]

        return inst

95 

96 

class SerializeMemoizer(Serialize):
    "A version of serialize that memoizes objects to reduce space"

    __serialize_fields__ = ('memoized',)

    def __init__(self, types_to_memoize: List) -> None:
        # Tuple form so it can be passed straight to isinstance().
        self.types_to_memoize = tuple(types_to_memoize)
        self.memoized = Enumerator()

    def in_types(self, value: Serialize) -> bool:
        # True when *value* should be replaced by a memo reference.
        return isinstance(value, self.types_to_memoize)

    def serialize(self) -> Dict[int, Any]:  # type: ignore[override]
        # Emit the id -> object table (inverse of the enumerator's mapping).
        return _serialize(self.memoized.reversed(), None)

    @classmethod
    def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]:  # type: ignore[override]
        # The memo table deserializes as plain data, not as an instance.
        return _deserialize(data, namespace, memo)

115 

116 

117try: 

118 import regex 

119 _has_regex = True 

120except ImportError: 

121 _has_regex = False 

122 

123if sys.version_info >= (3, 11): 

124 import re._parser as sre_parse 

125 import re._constants as sre_constants 

126else: 

127 import sre_parse 

128 import sre_constants 

129 

130categ_pattern = re.compile(r'\\p{[A-Za-z_]+}') 

131 

def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
    """Return the (minimum, maximum) match length of the regexp *expr*.

    Returns a list ``[min, max]`` when sre_parse can analyze the pattern.
    When it can't but `regex` is available, falls back to a coarse answer
    based only on whether the empty string matches.
    """
    if _has_regex:
        # Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with
        # a simple letter, which makes no difference as we are only trying to get the possible lengths of the regex
        # match here below.
        regexp_final = categ_pattern.sub('A', expr)
    else:
        if categ_pattern.search(expr):
            raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
        regexp_final = expr
    try:
        # Fixed in next version (past 0.960) of typeshed
        lo, hi = sre_parse.parse(regexp_final).getwidth()  # type: ignore[attr-defined]
        return [int(lo), int(hi)]
    except sre_constants.error:
        if not _has_regex:
            raise ValueError(expr)
        # sre_parse does not support the new features in regex. To not completely fail in that case,
        # we manually test for the most important info (whether the empty string is matched)
        # MAXREPEAT is a none pickable subclass of int, therefore needs to be converted to enable caching
        if regex.compile(regexp_final).match('') is None:
            return 1, int(sre_constants.MAXREPEAT)
        return 0, int(sre_constants.MAXREPEAT)

157 

158###} 

159 

160 

# Unicode general categories permitted for the first / subsequent characters
# of a Python-style identifier (PEP 3131).
_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)

def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
    # For multi-character strings, every character must pass individually.
    if len(s) != 1:
        return all(_test_unicode_category(ch, categories) for ch in s)
    if s == '_':
        return True
    return unicodedata.category(s) in categories

def is_id_continue(s: str) -> bool:
    """
    Checks if all characters in `s` are alphanumeric characters (Unicode standard, so diacritics, indian vowels, non-latin
    numbers, etc. all pass). Synonymous with a Python `ID_CONTINUE` identifier. See PEP 3131 for details.
    """
    return _test_unicode_category(s, _ID_CONTINUE)

def is_id_start(s: str) -> bool:
    """
    Checks if all characters in `s` are alphabetic characters (Unicode standard, so diacritics, indian vowels, non-latin
    numbers, etc. all pass). Synonymous with a Python `ID_START` identifier. See PEP 3131 for details.
    """
    return _test_unicode_category(s, _ID_START)

182 

183 

def dedup_list(l: List[T]) -> List[T]:
    """Return a copy of *l* with duplicates removed, keeping the first
    occurrence of each element and preserving the original order.
    Assumes the entries are hashable."""
    seen = set()
    result = []
    for item in l:
        if item not in seen:
            seen.add(item)
            result.append(item)
    return result

193 

194 

class Enumerator(Serialize):
    """Assigns a dense integer id (0, 1, 2, ...) to each distinct item seen."""

    def __init__(self) -> None:
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        # Return the id for *item*, allocating the next free id on first sight.
        try:
            return self.enums[item]
        except KeyError:
            new_id = len(self.enums)
            self.enums[item] = new_id
            return new_id

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        # Invert the mapping; ids are unique by construction, so no collisions.
        inverse = {num: item for item, num in self.enums.items()}
        assert len(inverse) == len(self.enums)
        return inverse

211 

212 

213 

def combine_alternatives(lists):
    """
    Accepts a list of alternatives, and enumerates all their possible concatenations.

    Returns a list of tuples, one per combination (except for the empty
    input, which yields ``[[]]``).

    Examples:
        >>> combine_alternatives([range(2), [4,5]])
        [(0, 4), (0, 5), (1, 4), (1, 5)]

        >>> combine_alternatives(["abc", "xy", '$'])
        [('a', 'x', '$'), ('a', 'y', '$'), ('b', 'x', '$'), ('b', 'y', '$'), ('c', 'x', '$'), ('c', 'y', '$')]

        >>> combine_alternatives([])
        [[]]
    """
    if not lists:
        return [[]]
    # An empty alternative would silently erase every combination.
    assert all(l for l in lists), lists
    return list(product(*lists))

232 

_has_atomicwrites = False
try:
    import atomicwrites
except ImportError:
    pass
else:
    _has_atomicwrites = True

class FS:
    """Small filesystem facade; writes are atomic when `atomicwrites` is installed."""

    exists = staticmethod(os.path.exists)

    @staticmethod
    def open(name, mode="r", **kwargs):
        # Prefer an atomic write-and-replace so readers never observe a
        # partially-written file; fall back to the builtin open otherwise.
        if _has_atomicwrites and "w" in mode:
            return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)
        return open(name, mode, **kwargs)

248 

249 

250 

def isascii(s: str) -> bool:
    """ str.isascii only exists in python3.7+ """
    if sys.version_info >= (3, 7):
        return s.isascii()
    # Pre-3.7 fallback: ASCII iff the string round-trips through the codec.
    try:
        s.encode('ascii')
    except (UnicodeDecodeError, UnicodeEncodeError):
        return False
    return True

261 

262 

class fzset(frozenset):
    """A frozenset whose repr looks like a set literal (``{...}``)."""
    def __repr__(self):
        return '{%s}' % ', '.join(repr(item) for item in self)

266 

267 

def classify_bool(seq: Sequence, pred: Callable) -> Any:
    """Split *seq* into ``(true_elems, false_elems)`` according to *pred*,
    preserving the original order within each list."""
    true_elems = []
    false_elems = []
    for elem in seq:
        (true_elems if pred(elem) else false_elems).append(elem)
    return true_elems, false_elems

272 

273 

def bfs(initial: Sequence, expand: Callable) -> Iterator:
    """Yield nodes in breadth-first order, starting from *initial* and using
    *expand* to produce each node's successors. Each node is yielded once."""
    queue = deque(initial)
    visited = set(queue)
    while queue:
        current = queue.popleft()
        yield current
        for neighbour in expand(current):
            if neighbour not in visited:
                visited.add(neighbour)
                queue.append(neighbour)

284 

def bfs_all_unique(initial, expand):
    "bfs, but doesn't keep track of visited (aka seen), because there can be no repetitions"
    queue = deque(initial)
    while queue:
        current = queue.popleft()
        yield current
        queue.extend(expand(current))

292 

293 

def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
    """Recursively convert *value* into builtins that `_deserialize` can undo.

    `Serialize` instances delegate to their own `serialize`; lists and dicts
    are converted element-wise; frozensets become plain lists.
    """
    if isinstance(value, Serialize):
        return value.serialize(memo)
    if isinstance(value, list):
        return [_serialize(item, memo) for item in value]
    if isinstance(value, frozenset):
        return list(value)  # TODO reversible?
    if isinstance(value, dict):
        return {k: _serialize(v, memo) for k, v in value.items()}
    # Scalars (None, numbers, strings, tuples) pass through untouched.
    return value

305 

306 

307 

308 

def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
    """
    Splits n up into smaller factors and summands <= max_factor.
    Returns a list of [(a, b), ...]
    so that the following code returns n:

        n = 1
        for a, b in values:
            n = n * a + b

    Currently, we also keep a + b <= max_factor, but that might change
    """
    assert n >= 0
    assert max_factor > 2
    # Small enough to express directly as a single (n, 0) step.
    if n <= max_factor:
        return [(n, 0)]

    # Greedily pull out the largest factor whose remainder keeps factor + remainder
    # within bounds, then recurse on the quotient. factor == 2 always satisfies
    # the bound (remainder <= 1, so 2 + remainder <= 3 <= max_factor).
    for factor in range(max_factor, 1, -1):
        quotient, remainder = divmod(n, factor)
        if factor + remainder <= max_factor:
            return small_factors(quotient, max_factor) + [(factor, remainder)]
    assert False, "Failed to factorize %s" % n