Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/blas.py: 9%

80 statements  


1""" 

2Determines if a contraction can use BLAS or not 

3""" 

4 

5from typing import List, Sequence, Tuple, Union 

6 

7import numpy as np 

8 

9from . import helpers 

10from .typing import ArrayIndexType 

11 

12__all__ = ["can_blas", "tensor_blas"] 

13 

14 

15def can_blas( 

16 inputs: List[str], 

17 result: str, 

18 idx_removed: ArrayIndexType, 

19 shapes: Sequence[Tuple[int]] = None, 

20) -> Union[str, bool]: 

21 """ 

22 Checks if we can use a BLAS call. 

23 

24 Parameters 

25 ---------- 

26 inputs : list of str 

27 Specifies the subscripts for summation. 

28 result : str 

29 Resulting summation. 

30 idx_removed : set 

31 Indices that are removed in the summation 

32 shapes : sequence of tuple[int], optional 

33 If given, check also that none of the indices are broadcast dimensions. 

34 

35 Returns 

36 ------- 

37 type : str or bool 

38 The type of BLAS call to be used or False if none. 

39 

40 Notes 

41 ----- 

42 We assume several operations are not efficient such as a transposed 

43 DDOT, therefore 'ijk,jki->' should prefer einsum. These return the blas 

44 type appended with "/EINSUM" to differentiate when they can still be done 

45 with tensordot if required, e.g. when a backend has no einsum. 

46 

47 Examples 

48 -------- 

49 >>> can_blas(['ij', 'jk'], 'ik', set('j')) 

50 'GEMM' 

51 

52 >>> can_blas(['ijj', 'jk'], 'ik', set('j')) 

53 False 

54 

55 >>> can_blas(['ab', 'cd'], 'abcd', set()) 

56 'OUTER/EINSUM' 

57 

58 >>> # looks like GEMM but actually 'j' is broadcast: 

59 >>> can_blas(['ij', 'jk'], 'ik', set('j'), shapes=[(4, 1), (5, 6)]) 

60 False 

61 """ 

    # Can only do two
    if len(inputs) != 2:
        return False

    input_left, input_right = inputs

    for c in set(input_left + input_right):
        # can't deal with repeated indices on same input or more than 2 total
        nl, nr = input_left.count(c), input_right.count(c)
        if (nl > 1) or (nr > 1) or (nl + nr > 2):
            return False

        # can't do implicit summation or dimension collapse e.g.
        #     "ab,bc->c" (implicitly sum over 'a')
        #     "ab,ca->ca" (take diagonal of 'a')
        if nl + nr - 1 == int(c in result):
            return False

    # check for broadcast indices e.g.:
    #     "ij,jk->ik" (but one of the 'j' dimensions is broadcast up)
    if shapes is not None:
        for c in idx_removed:
            if shapes[0][input_left.find(c)] != shapes[1][input_right.find(c)]:
                return False

    # Prefer einsum if not removing indices
    # (N.B. tensordot outer faster for large arrays?)
    if len(idx_removed) == 0:
        return "OUTER/EINSUM"

91 

92 # Build a few temporaries 

93 sets = [set(x) for x in inputs] 

94 keep_left = sets[0] - idx_removed 

95 keep_right = sets[1] - idx_removed 

96 rs = len(idx_removed) 

97 

98 # DDOT 

99 if inputs[0] == inputs[1]: 

100 return "DOT" 

101 

102 # DDOT does not make sense if you have to transpose - prefer einsum 

103 elif sets[0] == sets[1]: 

104 return "DOT/EINSUM" 

105 

106 # GEMM no transpose 

107 if input_left[-rs:] == input_right[:rs]: 

108 return "GEMM" 

109 

110 # GEMM transpose both 

111 elif input_left[:rs] == input_right[-rs:]: 

112 return "GEMM" 

113 

114 # GEMM transpose right 

115 elif input_left[-rs:] == input_right[-rs:]: 

116 return "GEMM" 

117 

118 # GEMM transpose left 

119 elif input_left[:rs] == input_right[:rs]: 

120 return "GEMM" 

121 

122 # Einsum is faster than vectordot if we have to copy 

123 elif (len(keep_left) == 0) or (len(keep_right) == 0): 

124 return "GEMV/EINSUM" 

125 

126 # Conventional tensordot 

127 else: 

128 return "TDOT" 

129 

130 
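# Illustrative helper (an editorial sketch, not part of the original module):
# it spells out how the branches above classify a few made-up contractions,
# assuming only the `can_blas` function defined above.
def _can_blas_examples() -> None:  # pragma: no cover - illustrative only
    assert can_blas(["abk", "kcd"], "abcd", set("k")) == "GEMM"  # no transpose
    assert can_blas(["kab", "cdk"], "abcd", set("k")) == "GEMM"  # transpose both
    assert can_blas(["abk", "cdk"], "abcd", set("k")) == "GEMM"  # transpose right
    assert can_blas(["kab", "kcd"], "abcd", set("k")) == "GEMM"  # transpose left
    assert can_blas(["ijk", "jki"], "", set("ijk")) == "DOT/EINSUM"  # transposed DDOT
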

def tensor_blas(
    view_left: np.ndarray,
    input_left: str,
    view_right: np.ndarray,
    input_right: str,
    index_result: str,
    idx_removed: ArrayIndexType,
) -> np.ndarray:
    """
    Computes the dot product between two tensors; attempts to use np.dot and
    falls back to tensordot if that fails.

    Parameters
    ----------
    view_left : array_like
        The left-hand view.
    input_left : str
        Indices of the left view.
    view_right : array_like
        The right-hand view.
    input_right : str
        Indices of the right view.
    index_result : str
        The resulting indices.
    idx_removed : set
        Indices removed in the contraction.

    Returns
    -------
    type : array
        The result of the BLAS operation.

    Notes
    -----
    Interior function for tensor BLAS.

    This function will attempt to use `np.dot` by iterating through the
    four possible transpose cases. If this fails, all inner and matrix-vector
    operations will be handed off to einsum, while all matrix-matrix operations
    will first copy the data, perform the DGEMM, and then copy the data to the
    required order.

    Examples
    --------

    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4)
    >>> tmp = tensor_blas(a, 'ij', b, 'jk', 'ik', set('j'))
    >>> np.allclose(tmp, np.dot(a, b))

    """

    idx_removed = frozenset(idx_removed)
    keep_left = frozenset(input_left) - idx_removed
    keep_right = frozenset(input_right) - idx_removed

    # We trust that this function is called correctly
    dimension_dict = {}
    for i, s in zip(input_left, view_left.shape):
        dimension_dict[i] = s
    for i, s in zip(input_right, view_right.shape):
        dimension_dict[i] = s

    # Do we want to be able to do this?

    # Check for duplicate indices, cannot do einsum('iij,jkk->ik') operations here
    # if (len(set(input_left)) != len(input_left)):
    #     new_inds = ''.join(keep_left) + ''.join(idx_removed)
    #     view_left = np.einsum(input_left + '->' + new_inds, view_left, order='C')
    #     input_left = new_inds

    # if (len(set(input_right)) != len(input_right)):
    #     new_inds = ''.join(idx_removed) + ''.join(keep_right)
    #     view_right = np.einsum(input_right + '->' + new_inds, view_right, order='C')
    #     input_right = new_inds

    # Tensordot guarantees a copy for ndim > 2, so avoid it if possible
    rs = len(idx_removed)
    dim_left = helpers.compute_size_by_dict(keep_left, dimension_dict)
    dim_right = helpers.compute_size_by_dict(keep_right, dimension_dict)
    dim_removed = helpers.compute_size_by_dict(idx_removed, dimension_dict)
    tensor_result = input_left + input_right
    for sidx in idx_removed:
        tensor_result = tensor_result.replace(sidx, "")
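    # Illustrative note (an editorial addition, not original source): for a
    # made-up case 'abk,kcd->abcd' with view_left.shape == (2, 3, 4) and
    # view_right.shape == (4, 5, 6), we get dim_left = 2 * 3 = 6,
    # dim_removed = 4 and dim_right = 5 * 6 = 30, so the "no transpose" branch
    # below computes np.dot(view_left.reshape(6, 4), view_right.reshape(4, 30))
    # as a single (6, 30) GEMM and later reshapes it to the (2, 3, 5, 6)
    # result 'abcd'.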

    # This is ugly, but can vastly speed up certain operations
    # Vectordot
    if input_left == input_right:
        new_view = np.dot(view_left.ravel(), view_right.ravel())

    # Matrix multiply
    # No transpose needed
    elif input_left[-rs:] == input_right[:rs]:
        new_view = np.dot(
            view_left.reshape(dim_left, dim_removed),
            view_right.reshape(dim_removed, dim_right),
        )

    # Transpose both
    elif input_left[:rs] == input_right[-rs:]:
        new_view = np.dot(
            view_left.reshape(dim_removed, dim_left).T,
            view_right.reshape(dim_right, dim_removed).T,
        )

    # Transpose right
    elif input_left[-rs:] == input_right[-rs:]:
        new_view = np.dot(
            view_left.reshape(dim_left, dim_removed),
            view_right.reshape(dim_right, dim_removed).T,
        )

    # Transpose left
    elif input_left[:rs] == input_right[:rs]:
        new_view = np.dot(
            view_left.reshape(dim_removed, dim_left).T,
            view_right.reshape(dim_removed, dim_right),
        )

    # Conventional tensordot
    else:
        # Find indices to contract over
        left_pos: Tuple[int, ...] = ()
        right_pos: Tuple[int, ...] = ()
        for fidx in idx_removed:
            left_pos += (input_left.find(fidx),)
            right_pos += (input_right.find(fidx),)
        new_view = np.tensordot(view_left, view_right, axes=(left_pos, right_pos))

259 

260 # Make sure the resulting shape is correct 

261 tensor_shape = tuple(dimension_dict[x] for x in tensor_result) 

262 if new_view.shape != tensor_shape: 

263 if len(tensor_result) > 0: 

264 new_view.shape = tensor_shape 

265 else: 

266 new_view = np.squeeze(new_view) 

267 

268 if tensor_result != index_result: 

269 new_view = np.einsum(tensor_result + "->" + index_result, new_view) 

270 

271 return new_view
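
# Usage sketch (an editorial addition, not part of the original module): a
# contraction whose removed indices are not a matching leading/trailing block
# in both operands falls through to the conventional tensordot branch and
# should agree with a plain einsum. Shapes and subscripts are made-up examples.
def _tensor_blas_tdot_example() -> None:  # pragma: no cover - illustrative only
    a = np.random.rand(2, 3, 4)  # indices 'ijk'
    b = np.random.rand(3, 2, 5)  # indices 'jil'
    out = tensor_blas(a, "ijk", b, "jil", "kl", set("ij"))
    assert np.allclose(out, np.einsum("ijk,jil->kl", a, b))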