Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/helpers.py: 22%

77 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:41 +0000

1""" 

2Contains helper functions for opt_einsum testing scripts 

3""" 

4 

5from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Optional, Tuple, Union, overload 

6 

7import numpy as np 

8 

9from .parser import get_symbol 

10from .typing import ArrayIndexType, PathType 

11 

12__all__ = ["build_views", "compute_size_by_dict", "find_contraction", "flop_count"] 

13 

14_valid_chars = "abcdefghijklmopqABC" 

15_sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]) 

16_default_dim_dict = {c: s for c, s in zip(_valid_chars, _sizes)} 

17 

18 

def build_views(string: str, dimension_dict: Optional[Dict[str, int]] = None) -> List[np.ndarray]:
    """
    Builds random numpy arrays for testing.

    Parameters
    ----------
    string : str
        Comma-separated tensor subscripts to build; anything after ``->``
        is ignored.
    dimension_dict : dictionary
        Dictionary of index sizes; defaults to the module-level
        ``_default_dim_dict``.

    Returns
    -------
    ret : list of np.ndarray
        The resulting views.

    Examples
    --------
    >>> view = build_views('abbc', {'a': 2, 'b':3, 'c':5})
    >>> view[0].shape
    (2, 3, 3, 5)

    """
    sizes = _default_dim_dict if dimension_dict is None else dimension_dict

    # Only the left-hand side of the equation describes the operands.
    lhs = string.split("->")[0]
    return [np.random.rand(*[sizes[char] for char in term]) for term in lhs.split(",")]

52 

53 

@overload
def compute_size_by_dict(indices: Iterable[int], idx_dict: List[int]) -> int:
    ...


@overload
def compute_size_by_dict(indices: Collection[str], idx_dict: Dict[str, int]) -> int:
    ...


def compute_size_by_dict(indices: Any, idx_dict: Any) -> int:
    """
    Computes the product of the elements in indices based on the dictionary
    idx_dict.

    Parameters
    ----------
    indices : iterable
        Indices to base the product on.
    idx_dict : dictionary
        Dictionary of index sizes.

    Returns
    -------
    ret : int
        The resulting product (1 when ``indices`` is empty).

    Examples
    --------
    >>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
    90

    """
    total = 1
    # Iterating a string is intentional: each character is one index.
    for size in (idx_dict[k] for k in indices):  # lgtm [py/iteration-string-and-sequence]
        total *= size
    return total

91 

92 

def find_contraction(
    positions: Collection[int],
    input_sets: List["ArrayIndexType"],
    output_set: "ArrayIndexType",
) -> Tuple[FrozenSet[str], List["ArrayIndexType"], "ArrayIndexType", "ArrayIndexType"]:
    """
    Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction.
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript

    Returns
    -------
    new_result : set
        The indices of the resulting contraction
    remaining : list
        List of sets that have not been contracted, the new set is appended to
        the end of this list
    idx_removed : set
        Indices removed from the entire contraction
    idx_contraction : set
        The indices used in the current contraction

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """

    # Copy so the caller's list is never mutated; pop from the back so the
    # earlier positions stay valid while removing the contracted operands.
    remaining = list(input_sets)
    inputs = (remaining.pop(i) for i in sorted(positions, reverse=True))

    # Use a bound union on an empty frozenset rather than the unbound
    # ``frozenset.union(*inputs)``: the unbound form raises TypeError when the
    # first operand is a plain ``set`` (as in the examples above) and when
    # ``positions`` is empty.
    idx_contract = frozenset().union(*inputs)
    idx_remain = output_set.union(*remaining)

    # Indices kept by this contraction are those still needed elsewhere.
    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)

    return new_result, remaining, idx_removed, idx_contract

150 

151 

def flop_count(
    idx_contraction: Collection[str],
    inner: bool,
    num_terms: int,
    size_dictionary: Dict[str, int],
) -> int:
    """
    Computes the number of FLOPS in the contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    # One multiply per pairwise combination of terms, plus one op for the
    # inner-product accumulation when required.
    factor = max(num_terms - 1, 1)
    if inner:
        factor += 1

    # Product of the contraction's index sizes (inlined size computation).
    size = 1
    for idx in idx_contraction:
        size *= size_dictionary[idx]

    return size * factor

194 

195 

def rand_equation(
    n: int,
    reg: int,
    n_out: int = 0,
    d_min: int = 2,
    d_max: int = 9,
    seed: Optional[int] = None,
    global_dim: bool = False,
    return_size_dict: bool = False,
) -> Union[Tuple[str, PathType, Dict[str, int]], Tuple[str, PathType]]:
    """Generate a random contraction and shapes.

    Parameters
    ----------
    n : int
        Number of array arguments.
    reg : int
        'Regularity' of the contraction graph. This essentially determines how
        many indices each tensor shares with others on average.
    n_out : int, optional
        Number of output indices (i.e. the number of non-contracted indices).
        Defaults to 0, i.e., a contraction resulting in a scalar.
    d_min : int, optional
        Minimum dimension size.
    d_max : int, optional
        Maximum dimension size.
    seed: int, optional
        If not None, seed numpy's random generator with this.
    global_dim : bool, optional
        Add a global, 'broadcast', dimension to every operand.
    return_size_dict : bool, optional
        Return the mapping of indices to sizes.

    Returns
    -------
    eq : str
        The equation string.
    shapes : list[tuple[int]]
        The array shapes.
    size_dict : dict[str, int]
        The dict of index sizes, only returned if ``return_size_dict=True``.

    Examples
    --------
    >>> eq, shapes = rand_equation(n=10, reg=4, n_out=5, seed=42)
    >>> eq
    'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda'

    >>> shapes
    [(9, 5, 4, 5, 4),
     (4, 4, 8, 5),
     (9, 4, 6, 9),
     (6, 6),
     (6, 9, 7, 8),
     (4,),
     (9, 3, 9, 4, 9),
     (6, 8, 4, 6, 8, 6, 3),
     (4, 7, 8, 8, 6, 9, 6),
     (9, 5, 3, 3, 9, 5)]
    """

    # NOTE: the seeded doctest above pins the exact sequence of np.random
    # calls below — reordering any of them would change documented output.
    if seed is not None:
        np.random.seed(seed)

    # total number of indices: each contracted index appears twice, each
    # output index once
    num_inds = n * reg // 2 + n_out
    inputs = ["" for _ in range(n)]
    output = []

    # get_symbol(i) presumably maps an int to a unique subscript character
    # (imported from .parser) — TODO confirm against its definition
    size_dict = {get_symbol(i): np.random.randint(d_min, d_max + 1) for i in range(num_inds)}

    # generate a list of indices to place either once or twice
    # NOTE(review): gen() mutates ``output`` as a side effect while it is
    # consumed by the permutation below
    def gen():
        for i, ix in enumerate(size_dict):
            # generate an outer index (first n_out symbols)
            if i < n_out:
                output.append(ix)
                yield ix
            # generate a bond (appears twice, so it is contracted away)
            else:
                yield ix
                yield ix

    # add the indices randomly to the inputs
    for i, ix in enumerate(np.random.permutation(list(gen()))):
        # make sure all inputs have at least one index
        if i < n:
            inputs[i] += ix
        else:
            # don't add any traces on same op: re-draw until the chosen
            # operand does not already carry this index
            where = np.random.randint(0, n)
            while ix in inputs[where]:
                where = np.random.randint(0, n)

            inputs[where] += ix

    # possibly add the same global dim to every arg
    if global_dim:
        gdim = get_symbol(num_inds)
        size_dict[gdim] = np.random.randint(d_min, d_max + 1)
        for i in range(n):
            inputs[i] += gdim
        # ``output`` is a list; += with a 1-char string appends that character
        output += gdim

    # randomly transpose the output indices and form equation
    output = "".join(np.random.permutation(output))  # type: ignore
    eq = "{}->{}".format(",".join(inputs), output)

    # make the shapes
    shapes = [tuple(size_dict[ix] for ix in op) for op in inputs]

    ret = (eq, shapes)

    if return_size_dict:
        return ret + (size_dict,)
    else:
        return ret