1""" 

2Determines if a contraction can use BLAS or not 

3""" 

4 

5import numpy as np 

6 

7from . import helpers 

8 

9__all__ = ["can_blas", "tensor_blas"] 

10 

11 

def can_blas(inputs, result, idx_removed, shapes=None):
    """
    Checks if we can use a BLAS call.

    Parameters
    ----------
    inputs : list of str
        Specifies the subscripts for summation.
    result : str
        Resulting summation.
    idx_removed : set
        Indices that are removed in the summation.
    shapes : sequence of tuple[int], optional
        If given, also check that none of the removed indices are broadcast
        dimensions.

    Returns
    -------
    type : str or bool
        The type of BLAS call to be used, or False if none.

    Notes
    -----
    We assume several operations are not efficient, such as a transposed
    DDOT, therefore 'ijk,jki->' should prefer einsum. These return the BLAS
    type appended with "/EINSUM" to indicate that they can still be done
    with tensordot if required, e.g. when a backend has no einsum.

    Examples
    --------
    >>> can_blas(['ij', 'jk'], 'ik', set('j'))
    'GEMM'

    >>> can_blas(['ijj', 'jk'], 'ik', set('j'))
    False

    >>> can_blas(['ab', 'cd'], 'abcd', set())
    'OUTER/EINSUM'

    >>> # looks like GEMM but actually 'j' is broadcast:
    >>> can_blas(['ij', 'jk'], 'ik', set('j'), shapes=[(4, 1), (5, 6)])
    False
    """
    # Can only do two
    if len(inputs) != 2:
        return False

    input_left, input_right = inputs

    for c in set(input_left + input_right):
        # can't deal with repeated indices on same input or more than 2 total
        nl, nr = input_left.count(c), input_right.count(c)
        if (nl > 1) or (nr > 1) or (nl + nr > 2):
            return False

        # can't do implicit summation or dimension collapse e.g.
        #     "ab,bc->c" (implicitly sum over 'a')
        #     "ab,ca->ca" (take diagonal of 'a')
        if nl + nr - 1 == int(c in result):
            return False

    # check for broadcast indices e.g.
    #     "ij,jk->ik" (but one of the 'j' dimensions is broadcast up)
    if shapes is not None:
        for c in idx_removed:
            if shapes[0][input_left.find(c)] != shapes[1][input_right.find(c)]:
                return False

    # Prefer einsum if not removing indices
    # (N.B. tensordot outer faster for large arrays?)
    if len(idx_removed) == 0:
        return 'OUTER/EINSUM'

    # Build a few temporaries
    sets = [set(x) for x in inputs]
    keep_left = sets[0] - idx_removed
    keep_right = sets[1] - idx_removed
    rs = len(idx_removed)

    # DDOT
    if inputs[0] == inputs[1]:
        return 'DOT'

    # DDOT doesn't make sense if you have to transpose - prefer einsum
    elif sets[0] == sets[1]:
        return 'DOT/EINSUM'

    # GEMM no transpose
    if input_left[-rs:] == input_right[:rs]:
        return 'GEMM'

    # GEMM transpose both
    elif input_left[:rs] == input_right[-rs:]:
        return 'GEMM'

    # GEMM transpose right
    elif input_left[-rs:] == input_right[-rs:]:
        return 'GEMM'

    # GEMM transpose left
    elif input_left[:rs] == input_right[:rs]:
        return 'GEMM'

    # Einsum is faster than GEMV if we have to copy data
    elif (len(keep_left) == 0) or (len(keep_right) == 0):
        return 'GEMV/EINSUM'

    # Conventional tensordot
    else:
        return 'TDOT'
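
# A minimal illustrative sketch (doctest-style, not part of the original file)
# of how the branches above classify a few concrete contractions; each result
# follows directly from the checks in can_blas:
#
#     >>> can_blas(['ijk', 'ijk'], '', set('ijk'))
#     'DOT'
#     >>> can_blas(['ijk', 'ikj'], '', set('ijk'))
#     'DOT/EINSUM'
#     >>> can_blas(['abc', 'cd'], 'abd', set('c'))
#     'GEMM'
#     >>> can_blas(['ab', 'acb'], 'c', set('ab'))
#     'GEMV/EINSUM'
#     >>> can_blas(['abc', 'bdc'], 'ad', set('bc'))
#     'TDOT'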


def tensor_blas(view_left, input_left, view_right, input_right, index_result, idx_removed):
    """
    Computes the dot product between two tensors; attempts to use np.dot and
    falls back to np.tensordot if that fails.

    Parameters
    ----------
    view_left : array_like
        The left hand view
    input_left : str
        Indices of the left view
    view_right : array_like
        The right hand view
    input_right : str
        Indices of the right view
    index_result : str
        The resulting indices
    idx_removed : set
        Indices removed in the contraction

    Returns
    -------
    type : array
        The result of the BLAS operation.

    Notes
    -----
    Interior function for tensor BLAS.

    This function will attempt to use `np.dot` by iterating through the
    four possible transpose cases. If this fails, all inner and matrix-vector
    operations will be handed off to einsum, while all matrix-matrix operations
    will first copy the data, perform the DGEMM, and then copy the data to the
    required order.

    Examples
    --------

    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4)
    >>> tmp = tensor_blas(a, 'ij', b, 'jk', 'ik', set('j'))
    >>> np.allclose(tmp, np.dot(a, b))
    True

    """

    idx_removed = set(idx_removed)
    keep_left = set(input_left) - idx_removed
    keep_right = set(input_right) - idx_removed

    # We trust this must be called correctly
    dimension_dict = {}
    for i, s in zip(input_left, view_left.shape):
        dimension_dict[i] = s
    for i, s in zip(input_right, view_right.shape):
        dimension_dict[i] = s

    # Do we want to be able to do this?

    # Check for duplicate indices, cannot do einsum('iij,jkk->ik') operations here
    # if (len(set(input_left)) != len(input_left)):
    #     new_inds = ''.join(keep_left) + ''.join(idx_removed)
    #     view_left = np.einsum(input_left + '->' + new_inds, view_left, order='C')
    #     input_left = new_inds

    # if (len(set(input_right)) != len(input_right)):
    #     new_inds = ''.join(idx_removed) + ''.join(keep_right)
    #     view_right = np.einsum(input_right + '->' + new_inds, view_right, order='C')
    #     input_right = new_inds

    # Tensordot guarantees a copy for ndim > 2, so avoid it if possible
    rs = len(idx_removed)
    dim_left = helpers.compute_size_by_dict(keep_left, dimension_dict)
    dim_right = helpers.compute_size_by_dict(keep_right, dimension_dict)
    dim_removed = helpers.compute_size_by_dict(idx_removed, dimension_dict)
    tensor_result = input_left + input_right
    for s in idx_removed:
        tensor_result = tensor_result.replace(s, "")

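    # When the removed indices sit contiguously at the front or back of both
    # inputs, the kept indices can be flattened so the contraction reduces to
    # an ordinary 2D matrix multiply: e.g. 'abc,cd->abd' with shapes (3, 4, 5)
    # and (5, 6) becomes a (12, 5) x (5, 6) np.dot, reshaped back to (3, 4, 6).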

    # This is ugly, but can vastly speed up certain operations
    # Vectordot
    if input_left == input_right:
        new_view = np.dot(view_left.ravel(), view_right.ravel())

    # Matrix multiply
    # No transpose needed
    elif input_left[-rs:] == input_right[:rs]:
        new_view = np.dot(view_left.reshape(dim_left, dim_removed), view_right.reshape(dim_removed, dim_right))

    # Transpose both
    elif input_left[:rs] == input_right[-rs:]:
        new_view = np.dot(view_left.reshape(dim_removed, dim_left).T, view_right.reshape(dim_right, dim_removed).T)

    # Transpose right
    elif input_left[-rs:] == input_right[-rs:]:
        new_view = np.dot(view_left.reshape(dim_left, dim_removed), view_right.reshape(dim_right, dim_removed).T)

    # Transpose left
    elif input_left[:rs] == input_right[:rs]:
        new_view = np.dot(view_left.reshape(dim_removed, dim_left).T, view_right.reshape(dim_removed, dim_right))

    # Conventional tensordot
    else:
        # Find indices to contract over
        left_pos, right_pos = (), ()
        for s in idx_removed:
            left_pos += (input_left.find(s),)
            right_pos += (input_right.find(s),)
        new_view = np.tensordot(view_left, view_right, axes=(left_pos, right_pos))

    # Make sure the resulting shape is correct
    tensor_shape = tuple(dimension_dict[x] for x in tensor_result)
    if new_view.shape != tensor_shape:
        if len(tensor_result) > 0:
            new_view.shape = tensor_shape
        else:
            new_view = np.squeeze(new_view)

    if tensor_result != index_result:
        new_view = np.einsum(tensor_result + '->' + index_result, new_view)

    return new_view
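
# A minimal usage sketch, assuming this module is importable as
# ``opt_einsum.blas`` (illustrative only); the result should agree with the
# equivalent einsum call:
#
#     >>> import numpy as np
#     >>> from opt_einsum.blas import tensor_blas
#     >>> a = np.random.rand(3, 4, 5)
#     >>> b = np.random.rand(5, 6)
#     >>> out = tensor_blas(a, 'abc', b, 'cd', 'abd', set('c'))
#     >>> np.allclose(out, np.einsum('abc,cd->abd', a, b))
#     True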