Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/helpers.py: 17%

70 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

"""
Contains helper functions for opt_einsum testing scripts: building random
test arrays and random contraction equations, contraction bookkeeping
(``find_contraction``), and simple size / FLOP estimates.
"""

4 

5from collections import OrderedDict 

6 

7import numpy as np 

8 

9from .parser import get_symbol 

10 

__all__ = ["build_views", "compute_size_by_dict", "find_contraction", "flop_count", "rand_equation"]

# Default index characters and sizes, used by ``build_views`` when no
# ``dimension_dict`` is supplied.
# NOTE(review): 'n' is absent from the character list - confirm whether intentional.
_valid_chars = "abcdefghijklmopqABC"
_sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4])
# Characters and sizes are paired positionally; both sequences have 19 entries.
_default_dim_dict = dict(zip(_valid_chars, _sizes))

16 

17 

def build_views(string, dimension_dict=None):
    """
    Builds random numpy arrays for testing.

    Parameters
    ----------
    string : str or iterable of str
        Either an einsum equation such as ``'ab,bc->ac'`` (only the
        comma-separated input terms to the left of ``'->'`` are used), or an
        iterable of individual term strings such as ``['ab', 'bc']``.
    dimension_dict : dict, optional
        Mapping of each index character to its size. Defaults to the
        module-level ``_default_dim_dict``.

    Returns
    -------
    views : list of np.ndarray
        One array per input term, created with ``np.random.rand`` using the
        sizes of that term's indices.

    Examples
    --------
    >>> view = build_views('abbc', {'a': 2, 'b': 3, 'c': 5})
    >>> view[0].shape
    (2, 3, 3, 5)

    >>> view = build_views(['abbc'], {'a': 2, 'b': 3, 'c': 5})
    >>> view[0].shape
    (2, 3, 3, 5)

    """

    if dimension_dict is None:
        dimension_dict = _default_dim_dict

    if isinstance(string, str):
        # An equation string: keep only the operand terms, ignore the output.
        terms = string.split('->')[0].split(',')
    else:
        # Already a collection of individual term strings (the documented
        # list form), so use it as-is.
        terms = list(string)

    views = []
    for term in terms:
        dims = [dimension_dict[x] for x in term]
        views.append(np.random.rand(*dims))
    return views

51 

52 

def compute_size_by_dict(indices, idx_dict):
    """
    Multiplies together the sizes that ``idx_dict`` assigns to ``indices``.

    Parameters
    ----------
    indices : iterable
        Indices to look up; a plain string counts as an iterable of
        single-character indices. A repeated index contributes its size once
        per occurrence.
    idx_dict : dictionary
        Mapping from each index to its size.

    Returns
    -------
    ret : int
        The product of the looked-up sizes, or ``1`` when ``indices`` is empty.

    Examples
    --------
    >>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
    90

    """
    product = 1
    for dim_size in map(idx_dict.__getitem__, indices):
        product *= dim_size
    return product

80 

81 

def find_contraction(positions, input_sets, output_set):
    """
    Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction. Positions are
        expected to be distinct.
    input_sets : list
        List of sets (``set`` or ``frozenset``) that represent the lhs side of
        the einsum subscript. The list itself is not modified.
    output_set : set
        Set (``set`` or ``frozenset``) that represents the rhs side of the
        overall einsum subscript

    Returns
    -------
    new_result : set
        The indices of the resulting contraction
    remaining : list
        List of sets that have not been contracted, the new set is appended to
        the end of this list
    idx_removed : set
        Indices removed from the entire contraction
    idx_contraction : set
        The indices used in the current contraction

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """

    remaining = list(input_sets)
    # Pop from the highest position downwards so earlier pops do not shift
    # the positions still waiting to be removed.
    contracted = [remaining.pop(i) for i in sorted(positions, reverse=True)]
    # ``set().union`` (unlike the unbound ``set.union``) accepts frozensets
    # and an empty argument list.
    idx_contract = set().union(*contracted)
    # Indices still needed afterwards: the final output plus every
    # untouched operand.
    idx_remain = output_set.union(*remaining)

    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)

    return new_result, remaining, idx_removed, idx_contract

135 

136 

def flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Computes the number of FLOPS in the contraction.

    The estimate is the product of the sizes of all indices in
    ``idx_contraction`` times an operation factor. The factor is
    ``max(1, num_terms - 1)``, increased by one when ``inner`` is true.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """

    overall_size = compute_size_by_dict(idx_contraction, size_dictionary)
    # At least one operation per element, even for a single term.
    op_factor = max(1, num_terms - 1)
    if inner:
        # The summation of an inner product costs one more operation per element.
        op_factor += 1

    return overall_size * op_factor

174 

175 

def rand_equation(n, reg, n_out=0, d_min=2, d_max=9, seed=None, global_dim=False, return_size_dict=False):
    """Generate a random contraction and shapes.

    Parameters
    ----------
    n : int
        Number of array arguments.
    reg : int
        'Regularity' of the contraction graph. This essentially determines how
        many indices each tensor shares with others on average.
    n_out : int, optional
        Number of output indices (i.e. the number of non-contracted indices).
        Defaults to 0, i.e., a contraction resulting in a scalar.
    d_min : int, optional
        Minimum dimension size.
    d_max : int, optional
        Maximum dimension size.
    seed : int, optional
        If not None, seed numpy's global random generator (``np.random.seed``)
        with this. This also affects any later ``np.random`` calls.
    global_dim : bool, optional
        Add a global, 'broadcast', dimension to every operand.
    return_size_dict : bool, optional
        Return the mapping of indices to sizes.

    Returns
    -------
    eq : str
        The equation string.
    shapes : list[tuple[int]]
        The array shapes.
    size_dict : dict[str, int]
        The dict of index sizes, only returned if ``return_size_dict=True``.

    Raises
    ------
    ValueError
        If there are bond indices to place (``n * reg // 2 > 0``) but fewer
        than two operands. Each bond must join two different operands.

    Examples
    --------
    >>> eq, shapes = rand_equation(n=10, reg=4, n_out=5, seed=42)
    >>> eq
    'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda'

    >>> shapes
    [(9, 5, 4, 5, 4),
     (4, 4, 8, 5),
     (9, 4, 6, 9),
     (6, 6),
     (6, 9, 7, 8),
     (4,),
     (9, 3, 9, 4, 9),
     (6, 8, 4, 6, 8, 6, 3),
     (4, 7, 8, 8, 6, 9, 6),
     (9, 5, 3, 3, 9, 5)]
    """

    # Each bond index appears on exactly two different operands.
    num_bonds = n * reg // 2
    if num_bonds > 0 and n < 2:
        # The no-trace placement loop below could never find a second operand
        # for a bond, so it would spin forever; fail loudly instead. This
        # check consumes no randomness, so seeded results are unchanged.
        raise ValueError(
            "rand_equation needs at least two operands to place {} bond "
            "index(es), got n={}".format(num_bonds, n)
        )

    if seed is not None:
        np.random.seed(seed)

    # total number of indices: bonds (placed twice) plus outputs (placed once)
    num_inds = num_bonds + n_out
    inputs = ["" for _ in range(n)]
    output = []

    size_dict = OrderedDict((get_symbol(i), np.random.randint(d_min, d_max + 1)) for i in range(num_inds))

    # generate a list of indices to place either once or twice
    def gen():
        for i, ix in enumerate(size_dict):
            # generate an outer index
            if i < n_out:
                output.append(ix)
                yield ix
            # generate a bond
            else:
                yield ix
                yield ix

    # add the indices randomly to the inputs
    for i, ix in enumerate(np.random.permutation(list(gen()))):
        # make sure all inputs have at least one index
        if i < n:
            inputs[i] += ix
        else:
            # don't add any traces on same op
            where = np.random.randint(0, n)
            while ix in inputs[where]:
                where = np.random.randint(0, n)

            inputs[where] += ix

    # possibly add the same global dim to every arg
    if global_dim:
        gdim = get_symbol(num_inds)
        size_dict[gdim] = np.random.randint(d_min, d_max + 1)
        for i in range(n):
            inputs[i] += gdim
        # append the symbol as one index rather than extending the list
        # with the characters of a string
        output.append(gdim)

    # randomly transpose the output indices and form equation
    output = "".join(np.random.permutation(output))
    eq = "{}->{}".format(",".join(inputs), output)

    # make the shapes
    shapes = [tuple(size_dict[ix] for ix in op) for op in inputs]

    ret = (eq, shapes)

    if return_size_dict:
        ret += (size_dict, )

    return ret