Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/blas.py: 6%
77 statements
1"""
2Determines if a contraction can use BLAS or not
3"""
5import numpy as np
7from . import helpers
9__all__ = ["can_blas", "tensor_blas"]
def can_blas(inputs, result, idx_removed, shapes=None):
    """
    Checks if we can use a BLAS call.

    Parameters
    ----------
    inputs : list of str
        Specifies the subscripts for summation.
    result : str
        Resulting summation.
    idx_removed : set
        Indices that are removed in the summation
    shapes : sequence of tuple[int], optional
        If given, check also that none of the indices are broadcast dimensions.

    Returns
    -------
    type : str or bool
        The type of BLAS call to be used or False if none.

    Notes
    -----
    We assume several operations are not efficient such as a transposed
    DDOT, therefore 'ijk,jki->' should prefer einsum. These return the blas
    type appended with "/EINSUM" to differentiate when they can still be done
    with tensordot if required, e.g. when a backend has no einsum.

    Examples
    --------
    >>> can_blas(['ij', 'jk'], 'ik', set('j'))
    'GEMM'

    >>> can_blas(['ijj', 'jk'], 'ik', set('j'))
    False

    >>> can_blas(['ab', 'cd'], 'abcd', set())
    'OUTER/EINSUM'

    >>> # looks like GEMM but actually 'j' is broadcast:
    >>> can_blas(['ij', 'jk'], 'ik', set('j'), shapes=[(4, 1), (5, 6)])
    False
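
    >>> # an extra illustrative example (not in the original docstring): a
    >>> # transposed dot, as described in the Notes, prefers einsum:
    >>> can_blas(['ijk', 'jki'], '', set('ijk'))
    'DOT/EINSUM'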
53 """
    # Can only do two
    if len(inputs) != 2:
        return False

    input_left, input_right = inputs

    for c in set(input_left + input_right):
        # can't deal with repeated indices on same input or more than 2 total
        nl, nr = input_left.count(c), input_right.count(c)
        if (nl > 1) or (nr > 1) or (nl + nr > 2):
            return False

        # can't do implicit summation or dimension collapse e.g.
        #     "ab,bc->c" (implicitly sum over 'a')
        #     "ab,ca->ca" (take diagonal of 'a')
        if nl + nr - 1 == int(c in result):
            return False

    # check for broadcast indices e.g.:
    #     "ij,jk->ik" (but one of the 'j' dimensions is broadcast up)
    if shapes is not None:
        for c in idx_removed:
            if shapes[0][input_left.find(c)] != shapes[1][input_right.find(c)]:
                return False

    # Prefer einsum if not removing indices
    # (N.B. tensordot outer faster for large arrays?)
    if len(idx_removed) == 0:
        return 'OUTER/EINSUM'

    # Build a few temporaries
    sets = [set(x) for x in inputs]
    keep_left = sets[0] - idx_removed
    keep_right = sets[1] - idx_removed
    rs = len(idx_removed)

    # DDOT
    if inputs[0] == inputs[1]:
        return 'DOT'

    # DDOT doesn't make sense if you have to transpose - prefer einsum
    elif sets[0] == sets[1]:
        return 'DOT/EINSUM'
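
    # The four checks below ask whether the `rs` contracted indices form a
    # matching contiguous block at the front or back of each input string.
    # If so, each operand can be reshaped into a 2-D matrix (kept indices on
    # one axis, contracted block on the other) and the contraction performed
    # as a single GEMM, possibly transposing one or both operands.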

    # GEMM no transpose
    if input_left[-rs:] == input_right[:rs]:
        return 'GEMM'

    # GEMM transpose both
    elif input_left[:rs] == input_right[-rs:]:
        return 'GEMM'

    # GEMM transpose right
    elif input_left[-rs:] == input_right[-rs:]:
        return 'GEMM'

    # GEMM transpose left
    elif input_left[:rs] == input_right[:rs]:
        return 'GEMM'

    # Einsum is faster than vectordot if we have to copy
    elif (len(keep_left) == 0) or (len(keep_right) == 0):
        return 'GEMV/EINSUM'

    # Conventional tensordot
    else:
        return 'TDOT'
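

# A minimal illustration (not part of the original module) of how a caller
# might act on can_blas's return value; `fallback_einsum` and `do_tensordot`
# are hypothetical placeholders for whatever the caller actually uses:
#
#     blas_type = can_blas(['ij', 'jk'], 'ik', set('j'))
#     if blas_type is False:
#         out = fallback_einsum()   # no BLAS-friendly structure
#     elif blas_type.endswith('/EINSUM'):
#         out = fallback_einsum()   # einsum preferred, tensordot still legal
#     else:
#         out = do_tensordot()      # 'DOT', 'GEMM' or 'TDOT'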


def tensor_blas(view_left, input_left, view_right, input_right, index_result, idx_removed):
    """
    Computes the dot product between two tensors, attempts to use np.dot and
    then tensordot if that fails.

    Parameters
    ----------
    view_left : array_like
        The left hand view
    input_left : str
        Indices of the left view
    view_right : array_like
        The right hand view
    input_right : str
        Indices of the right view
    index_result : str
        The resulting indices
    idx_removed : set
        Indices removed in the contraction

    Returns
    -------
    type : array
        The result of the BLAS operation.

    Notes
    -----
    Interior function for tensor BLAS.

    This function will attempt to use `np.dot` by iterating through the
    four possible transpose cases. If this fails, all inner and matrix-vector
    operations will be handed off to einsum, while all matrix-matrix operations
    will first copy the data, perform the DGEMM, and then copy the data to the
    required order.

    Examples
    --------
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4)
    >>> tmp = tensor_blas(a, 'ij', b, 'jk', 'ik', set('j'))
    >>> np.allclose(tmp, np.dot(a, b))
    True
    """

    idx_removed = set(idx_removed)
    keep_left = set(input_left) - idx_removed
    keep_right = set(input_right) - idx_removed

    # We trust this must be called correctly
    dimension_dict = {}
    for i, s in zip(input_left, view_left.shape):
        dimension_dict[i] = s
    for i, s in zip(input_right, view_right.shape):
        dimension_dict[i] = s
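    # (dimension_dict maps each index label to its extent, e.g. {'i': 4, 'j': 4, 'k': 4}
    # for the docstring example; shared labels are assumed to have consistent sizes.)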

    # Do we want to be able to do this?

    # Check for duplicate indices, cannot do einsum('iij,jkk->ik') operations here
    # if (len(set(input_left)) != len(input_left)):
    #     new_inds = ''.join(keep_left) + ''.join(idx_removed)
    #     view_left = np.einsum(input_left + '->' + new_inds, view_left, order='C')
    #     input_left = new_inds

    # if (len(set(input_right)) != len(input_right)):
    #     new_inds = ''.join(idx_removed) + ''.join(keep_right)
    #     view_right = np.einsum(input_right + '->' + new_inds, view_right, order='C')
    #     input_right = new_inds

    # Tensordot guarantees a copy for ndim > 2, so we should skip it if possible
    rs = len(idx_removed)
    dim_left = helpers.compute_size_by_dict(keep_left, dimension_dict)
    dim_right = helpers.compute_size_by_dict(keep_right, dimension_dict)
    dim_removed = helpers.compute_size_by_dict(idx_removed, dimension_dict)
    tensor_result = input_left + input_right
    for s in idx_removed:
        tensor_result = tensor_result.replace(s, "")
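    # (tensor_result is the concatenated index string with the contracted labels
    #  dropped, e.g. 'ij' + 'jk' with 'j' removed gives 'ik'.)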

    # This is ugly, but can vastly speed up certain operations
    # Vectordot
    if input_left == input_right:
        new_view = np.dot(view_left.ravel(), view_right.ravel())
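
    # Each branch below collapses both operands to 2-D matrices, grouping the
    # kept indices on one axis and the `rs` removed indices on the other, so the
    # whole contraction becomes a single np.dot (GEMM) call; which operand gets
    # transposed depends on whether the removed block sits at the front or back
    # of its index string.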

    # Matrix multiply
    # No transpose needed
    elif input_left[-rs:] == input_right[:rs]:
        new_view = np.dot(view_left.reshape(dim_left, dim_removed), view_right.reshape(dim_removed, dim_right))

    # Transpose both
    elif input_left[:rs] == input_right[-rs:]:
        new_view = np.dot(view_left.reshape(dim_removed, dim_left).T, view_right.reshape(dim_right, dim_removed).T)

    # Transpose right
    elif input_left[-rs:] == input_right[-rs:]:
        new_view = np.dot(view_left.reshape(dim_left, dim_removed), view_right.reshape(dim_right, dim_removed).T)

    # Transpose left
    elif input_left[:rs] == input_right[:rs]:
        new_view = np.dot(view_left.reshape(dim_removed, dim_left).T, view_right.reshape(dim_removed, dim_right))

    # Conventional tensordot
    else:
        # Find indices to contract over
        left_pos, right_pos = (), ()
        for s in idx_removed:
            left_pos += (input_left.find(s), )
            right_pos += (input_right.find(s), )
        new_view = np.tensordot(view_left, view_right, axes=(left_pos, right_pos))

    # Make sure the resulting shape is correct
    tensor_shape = tuple(dimension_dict[x] for x in tensor_result)
    if new_view.shape != tensor_shape:
        if len(tensor_result) > 0:
            new_view.shape = tensor_shape
        else:
            new_view = np.squeeze(new_view)

    if tensor_result != index_result:
        new_view = np.einsum(tensor_result + '->' + index_result, new_view)

    return new_view
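

# A small illustrative check (not from the original module) that the tensordot
# fallback path agrees with einsum for a 3-D contraction; left as a comment
# rather than executable code:
#
#     a = np.random.rand(2, 3, 4)
#     b = np.random.rand(4, 3, 5)
#     out = tensor_blas(a, 'ijk', b, 'kjl', 'il', set('jk'))
#     assert np.allclose(out, np.einsum('ijk,kjl->il', a, b))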