Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/parser.py: 15%

124 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1#!/usr/bin/env python 

2# coding: utf-8 

3""" 

4A functionally equivalent parser of the numpy.einsum input parser 

5""" 

6 

7import itertools 

8from collections import OrderedDict 

9 

10import numpy as np 

11 

# Names exported by ``from ... import *``.
__all__ = [
    "is_valid_einsum_char",
    "has_valid_einsum_chars_only",
    "get_symbol",
    "gen_unused_symbols",
    "convert_to_valid_einsum_chars",
    "alpha_canonicalize",
    "find_output_str",
    "find_output_shape",
    "possibly_convert_to_numpy",
    "parse_einsum_input",
]

# The 52 ASCII letters, lowercase first then uppercase, indexed by ``get_symbol``.
_einsum_symbols_base = (
    'abcdefghijklmnopqrstuvwxyz'
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
)

19 

20 

def is_valid_einsum_char(x):
    """Return whether ``x`` is a character accepted by numpy einsum.

    A character is accepted when it is one of the 52 ASCII letters held in
    ``_einsum_symbols_base`` or one of the structural characters ``,->.``.

    Examples
    --------
    >>> is_valid_einsum_char("a")
    True

    >>> is_valid_einsum_char("Ǵ")
    False
    """
    if x in _einsum_symbols_base:
        return True
    return x in ',->.'

33 

34 

def has_valid_einsum_chars_only(einsum_str):
    """Return whether every character of ``einsum_str`` is valid for numpy einsum.

    Each character is tested with ``is_valid_einsum_char``; an empty string
    is therefore reported as valid.

    Examples
    --------
    >>> has_valid_einsum_chars_only("abAZ")
    True

    >>> has_valid_einsum_chars_only("Över")
    False
    """
    return all(is_valid_einsum_char(char) for char in einsum_str)

47 

48 

def get_symbol(i):
    """Get the symbol corresponding to int ``i``.

    The first 52 indices map onto the usual ASCII letters (``a-z`` then
    ``A-Z``).  Larger indices map onto unicode characters starting at
    ``chr(192)``.  The UTF-16 surrogate block ``U+D800``-``U+DFFF`` is skipped
    because surrogate code points are not valid standalone characters and
    cannot be encoded (e.g. to UTF-8).

    Parameters
    ----------
    i : int
        Index of the requested symbol.

    Returns
    -------
    str
        A single-character string, distinct for each distinct ``i >= 0``.

    Examples
    --------
    >>> get_symbol(2)
    'c'

    >>> get_symbol(200)
    'Ŕ'

    >>> get_symbol(20000)
    '京'

    >>> get_symbol(55156)
    '\ue000'
    """
    if i < 52:
        return _einsum_symbols_base[i]

    code_point = i + 140
    # Jump over the 2048 surrogate code points so that every returned symbol
    # is an encodable character; the mapping stays one-to-one because all
    # later indices are shifted by the same amount.
    if code_point >= 0xD800:
        code_point += 0x800
    return chr(code_point)

67 

68 

def gen_unused_symbols(used, n):
    """Lazily yield ``n`` symbols that do not occur in ``used``.

    Candidates are taken from ``get_symbol`` in increasing index order, and
    any candidate found in ``used`` is skipped.

    Examples
    --------
    >>> list(oe.parser.gen_unused_symbols("abd", 2))
    ['c', 'e']
    """
    index = 0
    emitted = 0
    while emitted < n:
        candidate = get_symbol(index)
        index += 1
        if candidate not in used:
            yield candidate
            emitted += 1

85 

86 

def convert_to_valid_einsum_chars(einsum_str):
    """Convert the str ``einsum_str`` to contain only the alphabetic characters
    valid for numpy einsum. If there are too many symbols, let the backend
    throw an error.

    The distinct index symbols are sorted and replaced, in that order, by
    ``get_symbol(0)``, ``get_symbol(1)``, ...  The structural characters
    ``,``, ``-``, ``>`` and ``.`` are left untouched so that term separators,
    the output arrow and ellipses (``...``) survive the conversion.

    Parameters
    ----------
    einsum_str : str
        The einsum expression (or fragment) to convert.

    Returns
    -------
    str
        The expression with every index symbol replaced.

    Examples
    --------
    >>> oe.parser.convert_to_valid_einsum_chars("Ĥěļļö")
    'cbdda'
    """
    # '.' must be excluded as well: otherwise an ellipsis would be remapped
    # to a letter and silently change the meaning of the expression.
    symbols = sorted(set(einsum_str) - set(',->.'))
    replacer = {x: get_symbol(i) for i, x in enumerate(symbols)}
    return "".join(replacer.get(x, x) for x in einsum_str)

100 

101 

def alpha_canonicalize(equation):
    """Rename the indices of ``equation`` in order of first appearance.

    The first distinct index becomes ``get_symbol(0)``, the second
    ``get_symbol(1)`` and so on.  The characters ``.,->`` are kept as they are.

    Examples
    --------
    >>> oe.parser.alpha_canonicalize("dcba")
    'abcd'

    >>> oe.parser.alpha_canonicalize("Ĥěļļö")
    'abccd'
    """
    mapping = OrderedDict()
    for char in equation:
        if char not in '.,->' and char not in mapping:
            # The next free canonical symbol is indexed by how many we have so far.
            mapping[char] = get_symbol(len(mapping))
    return ''.join(mapping.get(char, char) for char in equation)

120 

121 

def find_output_str(subscripts):
    """
    Build the implicit output string for the input ``subscripts``.

    Following the canonical Einstein summation rule, an index that occurs
    more than once across the inputs is summed over; the indices that occur
    exactly once are returned in sorted order.

    Examples
    --------
    >>> oe.parser.find_output_str("ab,bc")
    'ac'

    >>> oe.parser.find_output_str("a,b")
    'ab'

    >>> oe.parser.find_output_str("a,a,b,b")
    ''
    """
    joined = subscripts.replace(",", "")
    kept = []
    for symbol in sorted(set(joined)):
        if joined.count(symbol) == 1:
            kept.append(symbol)
    return "".join(kept)

139 

140 

def find_output_shape(inputs, shapes, output):
    """Compute the shape of the output term from the input terms and shapes.

    For every index in ``output`` the size is the largest dimension found for
    that index among the inputs that contain it, which accounts for
    broadcasting of size-1 dimensions.

    Examples
    --------
    >>> oe.parser.find_output_shape(["ab", "bc"], [(2, 3), (3, 4)], "ac")
    (2, 4)

    # Broadcasting is accounted for
    >>> oe.parser.find_output_shape(["a", "a"], [(4, ), (1, )], "a")
    (4,)
    """
    dims = []
    for char in output:
        # Position of this index in each input term (-1 when absent).
        positions = [term.find(char) for term in inputs]
        sizes = []
        for shape, pos in zip(shapes, positions):
            if pos >= 0:
                sizes.append(shape[pos])
        dims.append(max(sizes))
    return tuple(dims)

156 

157 

def possibly_convert_to_numpy(x):
    """Return ``x`` unchanged if it has a ``shape`` attribute, otherwise
    convert it with ``np.asanyarray``.

    Examples
    --------
    >>> oe.parser.possibly_convert_to_numpy(5)
    array(5)

    >>> oe.parser.possibly_convert_to_numpy([5, 3])
    array([5, 3])

    >>> oe.parser.possibly_convert_to_numpy(np.array([5, 3]))
    array([5, 3])

    # Any class with a shape is passed through
    >>> class Shape:
    ...     def __init__(self, shape):
    ...         self.shape = shape
    ...

    >>> myshape = Shape((5, 5))
    >>> oe.parser.possibly_convert_to_numpy(myshape)
    <__main__.Shape object at 0x10f850710>
    """
    return x if hasattr(x, 'shape') else np.asanyarray(x)

187 

188 

def convert_subscripts(old_sub, symbol_map):
    """Translate a list of user subscripts into a subscript string via `symbol_map`.

    ``Ellipsis`` entries become ``"..."``; every other entry is looked up in
    ``symbol_map``.

    Examples
    --------
    >>> oe.parser.convert_subscripts(['abc', 'def'], {'abc':'a', 'def':'b'})
    'ab'
    >>> oe.parser.convert_subscripts([Ellipsis, object], {object:'a'})
    '...a'
    """
    # A plain lookup is used: the caller is expected to have built
    # `symbol_map` from these very subscripts.
    return "".join("..." if entry is Ellipsis else symbol_map[entry] for entry in old_sub)

207 

208 

def convert_interleaved_input(operands):
    """Convert 'interleaved' input to standard einsum input.

    ``operands`` is read as ``(op0, sub0, op1, sub1, ..., [output_sub])``:
    consecutive pairs are an operand followed by its subscript list, and a
    single trailing element, if present, is the output subscript list.

    Returns a ``(subscripts, operands)`` tuple holding the equivalent
    subscript string and the list of operands, each passed through
    ``possibly_convert_to_numpy``.
    """
    flat = list(operands)
    paired_len = 2 * (len(operands) // 2)

    operand_list = flat[0:paired_len:2]
    subscript_list = flat[1:paired_len:2]
    leftover = flat[paired_len:]

    output_list = leftover[-1] if leftover else None
    operands = [possibly_convert_to_numpy(arr) for arr in operand_list]

    # Map every user symbol onto a single-character symbol from `get_symbol`,
    # assigning characters following the sorted order of the user symbols.
    try:
        symbol_set = set(itertools.chain.from_iterable(subscript_list))

        # Ellipsis cannot be ordered against other objects, so drop it before sorting.
        symbol_set.discard(Ellipsis)

        symbol_map = {}
        for idx, symbol in enumerate(sorted(symbol_set)):
            symbol_map[symbol] = get_symbol(idx)

    except TypeError:  # unhashable or uncomparable object
        raise TypeError("For this input type lists must contain either Ellipsis "
                        "or hashable and comparable object (e.g. int, str).")

    terms = [convert_subscripts(sub, symbol_map) for sub in subscript_list]
    subscripts = ','.join(terms)
    if output_list is not None:
        subscripts = subscripts + "->" + convert_subscripts(output_list, symbol_map)

    return subscripts, operands

244 

245 

def parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Parameters
    ----------
    operands : sequence
        Either ``(subscripts_str, op0, op1, ...)`` or the interleaved form
        ``(op0, sub0, op1, sub1, ..., [output_sub])``.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        If no operands are given, the ``->`` arrow or an ellipsis is
        malformed, an ellipsis does not fit the rank of its operand, an
        output index is missing from the inputs, or the number of subscript
        terms differs from the number of operands.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b])

    >>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b])
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [possibly_convert_to_numpy(x) for x in operands[1:]]

    else:
        subscripts, operands = convert_interleaved_input(operands)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        # Do we have an output to account for?
        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        # Expanding an ellipsis indexes ``operands`` term by term, so the
        # counts must agree before going further; otherwise a mismatch would
        # surface as an IndexError or as an opaque error from ``max()`` on an
        # empty sequence instead of the intended ValueError.
        if len(split_subscripts) != len(operands):
            raise ValueError("Number of einsum subscripts must be equal to the " "number of operands.")

        # Reserve as many unused symbols as the highest operand rank; each
        # ellipsis is replaced by a right-aligned slice of them.
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        ellipse_inds = "".join(gen_unused_symbols(used, max(len(x.shape) for x in operands)))
        longest = 0

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    # Number of dimensions not named explicitly in this term.
                    ellipse_count = max(len(operands[num].shape), 1) - (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    split_subscripts[num] = sub.replace('...', ellipse_inds[-ellipse_count:])

        subscripts = ",".join(split_subscripts)

        # Figure out output ellipses
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            output_subscript = find_output_str(subscripts)
            normal_inds = ''.join(sorted(set(output_subscript) - set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts, output_subscript = subscripts, find_output_str(subscripts)

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character '{}' did not appear in the input".format(char))

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the " "number of operands.")

    return input_subscripts, output_subscript, operands

356 return input_subscripts, output_subscript, operands