Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/parser.py: 55%

132 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:41 +0000

1""" 

2A functionally equivalent parser of the numpy.einsum input parser 

3""" 

4 

5import itertools 

6from typing import Any, Dict, Iterator, List, Tuple, Union 

7 

8import numpy as np 

9 

10from .typing import ArrayType, TensorShapeType 

11 

# Public names of this module (what ``from <module> import *`` exposes).
__all__ = [
    "is_valid_einsum_char",
    "has_valid_einsum_chars_only",
    "get_symbol",
    "gen_unused_symbols",
    "convert_to_valid_einsum_chars",
    "alpha_canonicalize",
    "find_output_str",
    "find_output_shape",
    "possibly_convert_to_numpy",
    "parse_einsum_input",
]

# The 52 ASCII letters, lowercase first and then uppercase. ``get_symbol``
# indexes into this string for ``i < 52`` and ``is_valid_einsum_char`` tests
# membership in it.
_einsum_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

26 

27 

def is_valid_einsum_char(x: str) -> bool:
    """Tell whether ``x`` is a character accepted by numpy einsum.

    Accepted characters are the letters in ``_einsum_symbols_base`` plus the
    einsum punctuation ``,``, ``-``, ``>`` and ``.``.

    **Examples:**

    ```python
    is_valid_einsum_char("a")
    #> True

    is_valid_einsum_char("Ǵ")
    #> False
    ```
    """
    # Letters first, punctuation second.
    if x in _einsum_symbols_base:
        return True
    return x in ",->."

42 

43 

def has_valid_einsum_chars_only(einsum_str: str) -> bool:
    """Tell whether every character of ``einsum_str`` is valid for numpy einsum.

    An empty string is considered valid.

    **Examples:**

    ```python
    has_valid_einsum_chars_only("abAZ")
    #> True

    has_valid_einsum_chars_only("Över")
    #> False
    ```
    """
    return all(is_valid_einsum_char(char) for char in einsum_str)

58 

59 

def get_symbol(i: int) -> str:
    """Get the symbol corresponding to int ``i``.

    Runs through the usual 52 letters before resorting to unicode characters,
    starting at ``chr(192)`` and skipping the surrogate block
    (``chr(55296)`` - ``chr(57343)``), whose members are not valid standalone
    characters.

    **Parameters:**

    - **i** - *(int)* Index of the requested symbol.

    **Returns:**

    - **symbol** - *(str)* A single character; distinct indices ``i >= 0``
      give distinct characters.

    **Examples:**

    ```python
    get_symbol(2)
    #> 'c'

    get_symbol(200)
    #> 'Ŕ'

    get_symbol(20000)
    #> '京'
    ```
    """
    if i < 52:
        return _einsum_symbols_base[i]

    # Index 52 maps to chr(192).
    code_point = i + 140
    if code_point >= 0xD800:
        # The surrogate test must be made on the code point, not on ``i``:
        # otherwise indices 55156..55295 would land inside U+D800..U+DFFF.
        # Jump over the whole 2048-wide surrogate block.
        code_point += 0x800
    return chr(code_point)

84 

85 

def gen_unused_symbols(used: str, n: int) -> Iterator[str]:
    """Lazily yield ``n`` symbols that do not occur in ``used``.

    Candidates are taken from ``get_symbol(0)``, ``get_symbol(1)``, ... in
    order; any candidate found in ``used`` is skipped.

    **Examples:**
    ```python
    list(oe.parser.gen_unused_symbols("abd", 2))
    #> ['c', 'e']
    ```
    """
    produced = 0
    index = 0
    while produced < n:
        candidate = get_symbol(index)
        index += 1
        if candidate not in used:
            yield candidate
            produced += 1

103 

104 

def convert_to_valid_einsum_chars(einsum_str: str) -> str:
    """Convert the str ``einsum_str`` to contain only the alphabetic characters
    valid for numpy einsum. If there are too many symbols, let the backend
    throw an error.

    The distinct index symbols are sorted and renamed to ``get_symbol(0)``,
    ``get_symbol(1)``, ... The einsum punctuation ``,``, ``-``, ``>`` and
    ``.`` (ellipsis) is left untouched.

    Examples
    --------
    >>> oe.parser.convert_to_valid_einsum_chars("Ĥěļļö")
    'cbdda'

    >>> oe.parser.convert_to_valid_einsum_chars("...Ĥ,...ö->...")
    '...b,...a->...'
    """
    # "." must be excluded together with ",->": it is punctuation, not an
    # index, and renaming it would turn "..." into three ordinary indices.
    symbols = sorted(set(einsum_str) - set(",->."))
    replacer = {x: get_symbol(i) for i, x in enumerate(symbols)}
    return "".join(replacer.get(x, x) for x in einsum_str)

118 

119 

def alpha_canonicalize(equation: str) -> str:
    """Rename the indices of ``equation`` in order of first appearance.

    The first distinct index becomes ``get_symbol(0)``, the second
    ``get_symbol(1)`` and so on; the characters ``.``, ``,``, ``-`` and ``>``
    are copied through unchanged.

    Examples
    --------
    >>> oe.parser.alpha_canonicalize("dcba")
    'abcd'

    >>> oe.parser.alpha_canonicalize("Ĥěļļö")
    'abccd'
    """
    mapping: Dict[str, str] = {}
    for char in equation:
        if char not in ".,->" and char not in mapping:
            # The next free symbol is determined by how many were assigned so far.
            mapping[char] = get_symbol(len(mapping))
    return "".join(mapping.get(char, char) for char in equation)

138 

139 

def find_output_str(subscripts: str) -> str:
    """
    Derive the implicit output string for the input ``subscripts`` using the
    canonical einstein summation rule: every index that appears more than once
    is summed over, the remaining ones are kept in sorted order.

    Examples
    --------
    >>> oe.parser.find_output_str("ab,bc")
    'ac'

    >>> oe.parser.find_output_str("a,b")
    'ab'

    >>> oe.parser.find_output_str("a,a,b,b")
    ''
    """
    flat = subscripts.replace(",", "")
    kept = []
    for symbol in sorted(set(flat)):
        # Only indices that occur exactly once survive into the output.
        if flat.count(symbol) == 1:
            kept.append(symbol)
    return "".join(kept)

158 

159 

def find_output_shape(inputs: List[str], shapes: List[TensorShapeType], output: str) -> TensorShapeType:
    """Compute the shape of the output term ``output`` from the input terms
    and their shapes, taking into account broadcasting.

    For each output index, the size is the largest dimension found at the
    first position of that index in every input term that contains it.

    Examples
    --------
    >>> oe.parser.find_output_shape(["ab", "bc"], [(2, 3), (3, 4)], "ac")
    (2, 4)

    # Broadcasting is accounted for
    >>> oe.parser.find_output_shape(["a", "a"], [(4, ), (1, )], "a")
    (4,)
    """
    dims = []
    for symbol in output:
        # Position of the index in every input term (-1 when absent).
        positions = [term.find(symbol) for term in inputs]
        sizes = []
        for shape, position in zip(shapes, positions):
            if position >= 0:
                sizes.append(shape[position])
        dims.append(max(sizes))
    return tuple(dims)

174 

175 

def possibly_convert_to_numpy(x: Any) -> Any:
    """Return ``x`` unchanged if it has a ``shape`` attribute, otherwise wrap
    it with ``np.asanyarray``.

    Examples
    --------
    >>> oe.parser.possibly_convert_to_numpy(5)
    array(5)

    >>> oe.parser.possibly_convert_to_numpy([5, 3])
    array([5, 3])

    >>> oe.parser.possibly_convert_to_numpy(np.array([5, 3]))
    array([5, 3])

    # Any class with a shape is passed through
    >>> class Shape:
    ...     def __init__(self, shape):
    ...         self.shape = shape
    ...

    >>> myshape = Shape((5, 5))
    >>> oe.parser.possibly_convert_to_numpy(myshape)
    <__main__.Shape object at 0x10f850710>
    """
    return x if hasattr(x, "shape") else np.asanyarray(x)

205 

206 

def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str:
    """Translate a list of user subscripts into a subscript string via `symbol_map`.

    ``Ellipsis`` entries become ``"..."``; every other entry is looked up in
    `symbol_map`.

    Examples
    --------
    >>> oe.parser.convert_subscripts(['abc', 'def'], {'abc':'a', 'def':'b'})
    'ab'
    >>> oe.parser.convert_subscripts([Ellipsis, object], {object:'a'})
    '...a'
    """
    # Plain indexing on purpose: the caller is expected to pass a complete map.
    pieces = ["..." if symbol is Ellipsis else symbol_map[symbol] for symbol in old_sub]
    return "".join(pieces)

225 

226 

def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[str, List[Any]]:
    """Convert 'interleaved' input to standard einsum input.

    The interleaved form is ``(op0, sub0, op1, sub1, ..., [out_sub])`` where
    each ``sub`` is a sequence of user symbols (or ``Ellipsis``) and the
    optional trailing element is the output subscript list.

    **Parameters:**

    - **operands** - *(list or tuple)* The interleaved operands and subscripts.

    **Returns:**

    - **subscripts** - *(str)* The equivalent einsum subscript string; it
      contains ``->`` only when an output list was supplied.
    - **operands** - *(list)* The operands, each passed through
      ``possibly_convert_to_numpy``.

    **Raises:**

    - **TypeError** - if a subscript list holds unhashable or mutually
      incomparable symbols.
    - **ValueError** - if the output list uses a symbol that appears in no
      input subscript list.
    """
    items = list(operands)
    n_pairs = len(items) // 2
    # Even positions hold operands, odd positions their subscripts; an odd
    # total length leaves one trailing element, the output subscript list.
    operand_list = items[: 2 * n_pairs : 2]
    subscript_list = items[1 : 2 * n_pairs : 2]
    output_list = items[-1] if len(items) % 2 else None

    # Validate the (cheap) subscripts before touching the operands.
    try:
        # collect all user symbols
        symbol_set = set(itertools.chain.from_iterable(subscript_list))
        output_symbols = set() if output_list is None else set(output_list)

        # remove Ellipsis because it can not be compared with other objects
        symbol_set.discard(Ellipsis)
        output_symbols.discard(Ellipsis)

        # sorting restores a deterministic order that the `set` lost
        ordered_symbols = sorted(symbol_set)
    except TypeError as err:  # unhashable or uncomparable object
        raise TypeError(
            "For this input type lists must contain either Ellipsis "
            "or hashable and comparable object (e.g. int, str)."
        ) from err

    # Every output symbol must be known, otherwise the lookup in
    # `convert_subscripts` would fail with a bare KeyError.
    unknown = output_symbols - symbol_set
    if unknown:
        listed = ", ".join(sorted(repr(symbol) for symbol in unknown))
        raise ValueError(f"Output symbol(s) {listed} did not appear in the input")

    # map user symbols to single-character symbols based on `get_symbol`
    symbol_map = {symbol: get_symbol(idx) for idx, symbol in enumerate(ordered_symbols)}
    converted_operands = [possibly_convert_to_numpy(x) for x in operand_list]

    subscripts = ",".join(convert_subscripts(sub, symbol_map) for sub in subscript_list)
    if output_list is not None:
        subscripts += "->" + convert_subscripts(output_list, symbol_map)

    return subscripts, converted_operands

263 

264 

265def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]: 

266 """ 

267 A reproduction of einsum c side einsum parsing in python. 

268 

269 **Parameters:** 

270 Intakes the same inputs as `contract_path`, but NOT the keyword args. The only 

271 supported keyword argument is: 

272 - **shapes** - *(bool, optional)* Whether ``parse_einsum_input`` should assume arrays (the default) or 

273 array shapes have been supplied. 

274 

275 Returns 

276 ------- 

277 input_strings : str 

278 Parsed input strings 

279 output_string : str 

280 Parsed output string 

281 operands : list of array_like 

282 The operands to use in the numpy contraction 

283 

284 Examples 

285 -------- 

286 The operand list is simplified to reduce printing: 

287 

288 >>> a = np.random.rand(4, 4) 

289 >>> b = np.random.rand(4, 4, 4) 

290 >>> parse_einsum_input(('...a,...a->...', a, b)) 

291 ('za,xza', 'xz', [a, b]) 

292 

293 >>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0])) 

294 ('za,xza', 'xz', [a, b]) 

295 """ 

296 

297 if len(operands) == 0: 

298 raise ValueError("No input operands") 

299 

300 if isinstance(operands[0], str): 

301 subscripts = operands[0].replace(" ", "") 

302 if shapes: 

303 if any([hasattr(o, "shape") for o in operands[1:]]): 

304 raise ValueError( 

305 "shapes is set to True but given at least one operand looks like an array" 

306 " (at least one operand has a shape attribute). " 

307 ) 

308 operands = [possibly_convert_to_numpy(x) for x in operands[1:]] 

309 else: 

310 subscripts, operands = convert_interleaved_input(operands) 

311 

312 if shapes: 

313 operand_shapes = operands 

314 else: 

315 operand_shapes = [o.shape for o in operands] 

316 

317 # Check for proper "->" 

318 if ("-" in subscripts) or (">" in subscripts): 

319 invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1) 

320 if invalid or (subscripts.count("->") != 1): 

321 raise ValueError("Subscripts can only contain one '->'.") 

322 

323 # Parse ellipses 

324 if "." in subscripts: 

325 used = subscripts.replace(".", "").replace(",", "").replace("->", "") 

326 ellipse_inds = "".join(gen_unused_symbols(used, max(len(x) for x in operand_shapes))) 

327 longest = 0 

328 

329 # Do we have an output to account for? 

330 if "->" in subscripts: 

331 input_tmp, output_sub = subscripts.split("->") 

332 split_subscripts = input_tmp.split(",") 

333 out_sub = True 

334 else: 

335 split_subscripts = subscripts.split(",") 

336 out_sub = False 

337 

338 for num, sub in enumerate(split_subscripts): 

339 if "." in sub: 

340 if (sub.count(".") != 3) or (sub.count("...") != 1): 

341 raise ValueError("Invalid Ellipses.") 

342 

343 # Take into account numerical values 

344 if operand_shapes[num] == (): 

345 ellipse_count = 0 

346 else: 

347 ellipse_count = max(len(operand_shapes[num]), 1) - (len(sub) - 3) 

348 

349 if ellipse_count > longest: 

350 longest = ellipse_count 

351 

352 if ellipse_count < 0: 

353 raise ValueError("Ellipses lengths do not match.") 

354 elif ellipse_count == 0: 

355 split_subscripts[num] = sub.replace("...", "") 

356 else: 

357 split_subscripts[num] = sub.replace("...", ellipse_inds[-ellipse_count:]) 

358 

359 subscripts = ",".join(split_subscripts) 

360 

361 # Figure out output ellipses 

362 if longest == 0: 

363 out_ellipse = "" 

364 else: 

365 out_ellipse = ellipse_inds[-longest:] 

366 

367 if out_sub: 

368 subscripts += "->" + output_sub.replace("...", out_ellipse) 

369 else: 

370 # Special care for outputless ellipses 

371 output_subscript = find_output_str(subscripts) 

372 normal_inds = "".join(sorted(set(output_subscript) - set(out_ellipse))) 

373 

374 subscripts += "->" + out_ellipse + normal_inds 

375 

376 # Build output string if does not exist 

377 if "->" in subscripts: 

378 input_subscripts, output_subscript = subscripts.split("->") 

379 else: 

380 input_subscripts, output_subscript = subscripts, find_output_str(subscripts) 

381 

382 # Make sure output subscripts are in the input 

383 for char in output_subscript: 

384 if char not in input_subscripts: 

385 raise ValueError("Output character '{}' did not appear in the input".format(char)) 

386 

387 # Make sure number operands is equivalent to the number of terms 

388 if len(input_subscripts.split(",")) != len(operands): 

389 raise ValueError( 

390 f"Number of einsum subscripts, {len(input_subscripts.split(','))}, must be equal to the " 

391 f"number of operands, {len(operands)}." 

392 ) 

393 

394 return input_subscripts, output_subscript, operands