1"""
2Determines if a contraction can use BLAS or not
3"""

from typing import List, Optional, Sequence, Tuple, Union

import numpy as np

from . import helpers
from .typing import ArrayIndexType

__all__ = ["can_blas", "tensor_blas"]


def can_blas(
    inputs: List[str],
    result: str,
    idx_removed: ArrayIndexType,
    shapes: Optional[Sequence[Tuple[int, ...]]] = None,
) -> Union[str, bool]:
21 """
22 Checks if we can use a BLAS call.
24 Parameters
25 ----------
26 inputs : list of str
27 Specifies the subscripts for summation.
28 result : str
29 Resulting summation.
30 idx_removed : set
31 Indices that are removed in the summation
32 shapes : sequence of tuple[int], optional
33 If given, check also that none of the indices are broadcast dimensions.
35 Returns
36 -------
37 type : str or bool
38 The type of BLAS call to be used or False if none.
40 Notes
41 -----
42 We assume several operations are not efficient such as a transposed
43 DDOT, therefore 'ijk,jki->' should prefer einsum. These return the blas
44 type appended with "/EINSUM" to differentiate when they can still be done
45 with tensordot if required, e.g. when a backend has no einsum.
47 Examples
48 --------
49 >>> can_blas(['ij', 'jk'], 'ik', set('j'))
50 'GEMM'
52 >>> can_blas(['ijj', 'jk'], 'ik', set('j'))
53 False
55 >>> can_blas(['ab', 'cd'], 'abcd', set())
56 'OUTER/EINSUM'
58 >>> # looks like GEMM but actually 'j' is broadcast:
59 >>> can_blas(['ij', 'jk'], 'ik', set('j'), shapes=[(4, 1), (5, 6)])
60 False
61 """
    # Can only do two
    if len(inputs) != 2:
        return False

    input_left, input_right = inputs

    for c in set(input_left + input_right):
        # can't deal with repeated indices on same input or more than 2 total
        nl, nr = input_left.count(c), input_right.count(c)
        if (nl > 1) or (nr > 1) or (nl + nr > 2):
            return False

        # can't do implicit summation or dimension collapse e.g.
        #     "ab,bc->c" (implicitly sum over 'a')
        #     "ab,ca->ca" (take diagonal of 'a')
        if nl + nr - 1 == int(c in result):
            return False

    # check for broadcast indices e.g:
    #     "ij,jk->ik" (but one of the 'j' dimensions is broadcast up)
    if shapes is not None:
        for c in idx_removed:
            if shapes[0][input_left.find(c)] != shapes[1][input_right.find(c)]:
                return False

    # Prefer einsum if not removing indices
    # (N.B. tensordot outer faster for large arrays?)
    if len(idx_removed) == 0:
        return "OUTER/EINSUM"

    # Build a few temporaries
    sets = [set(x) for x in inputs]
    keep_left = sets[0] - idx_removed
    keep_right = sets[1] - idx_removed
    rs = len(idx_removed)

    # DDOT
    if inputs[0] == inputs[1]:
        return "DOT"

    # DDOT does not make sense if you have to transpose - prefer einsum
    elif sets[0] == sets[1]:
        return "DOT/EINSUM"

    # GEMM no transpose
    if input_left[-rs:] == input_right[:rs]:
        return "GEMM"

    # GEMM transpose both
    elif input_left[:rs] == input_right[-rs:]:
        return "GEMM"

    # GEMM transpose right
    elif input_left[-rs:] == input_right[-rs:]:
        return "GEMM"

    # GEMM transpose left
    elif input_left[:rs] == input_right[:rs]:
        return "GEMM"

    # Einsum is faster than vectordot if we have to copy
    elif (len(keep_left) == 0) or (len(keep_right) == 0):
        return "GEMV/EINSUM"

    # Conventional tensordot
    else:
        return "TDOT"


def tensor_blas(
    view_left: np.ndarray,
    input_left: str,
    view_right: np.ndarray,
    input_right: str,
    index_result: str,
    idx_removed: ArrayIndexType,
) -> np.ndarray:
139 """
140 Computes the dot product between two tensors, attempts to use np.dot and
141 then tensordot if that fails.
143 Parameters
144 ----------
145 view_left : array_like
146 The left hand view
147 input_left : str
148 Indices of the left view
149 view_right : array_like
150 The right hand view
151 input_right : str
152 Indices of the right view
153 index_result : str
154 The resulting indices
155 idx_removed : set
156 Indices removed in the contraction
158 Returns
159 -------
160 type : array
161 The resulting BLAS operation.
163 Notes
164 -----
165 Interior function for tensor BLAS.

    This function will attempt to use `np.dot` by iterating through the
    four possible transpose cases. If this fails, all inner and matrix-vector
    operations will be handed off to einsum while all matrix-matrix operations will
    first copy the data, perform the DGEMM, and then copy the data to the required
    order.

    Examples
    --------

    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4)
    >>> tmp = tensor_blas(a, 'ij', b, 'jk', 'ik', set('j'))
    >>> np.allclose(tmp, np.dot(a, b))
    True
    """
    idx_removed = frozenset(idx_removed)
    keep_left = frozenset(input_left) - idx_removed
    keep_right = frozenset(input_right) - idx_removed

    # We trust this must be called correctly
    dimension_dict = {}
    for i, s in zip(input_left, view_left.shape):
        dimension_dict[i] = s
    for i, s in zip(input_right, view_right.shape):
        dimension_dict[i] = s

    # Do we want to be able to do this?

    # Check for duplicate indices, cannot do einsum('iij,jkk->ik') operations here
    # if (len(set(input_left)) != len(input_left)):
    #     new_inds = ''.join(keep_left) + ''.join(idx_removed)
    #     view_left = np.einsum(input_left + '->' + new_inds, view_left, order='C')
    #     input_left = new_inds

    # if (len(set(input_right)) != len(input_right)):
    #     new_inds = ''.join(idx_removed) + ''.join(keep_right)
    #     view_right = np.einsum(input_right + '->' + new_inds, view_right, order='C')
    #     input_right = new_inds

    # Tensordot guarantees a copy for ndim > 2, so avoid it if possible
    rs = len(idx_removed)
    dim_left = helpers.compute_size_by_dict(keep_left, dimension_dict)
    dim_right = helpers.compute_size_by_dict(keep_right, dimension_dict)
    dim_removed = helpers.compute_size_by_dict(idx_removed, dimension_dict)
    tensor_result = input_left + input_right
    for sidx in idx_removed:
        tensor_result = tensor_result.replace(sidx, "")

    # This is ugly, but can vastly speed up certain operations
    # Vectordot
    if input_left == input_right:
        new_view = np.dot(view_left.ravel(), view_right.ravel())

    # Matrix multiply
    # No transpose needed
    elif input_left[-rs:] == input_right[:rs]:
        new_view = np.dot(
            view_left.reshape(dim_left, dim_removed),
            view_right.reshape(dim_removed, dim_right),
        )

    # Transpose both
    elif input_left[:rs] == input_right[-rs:]:
        new_view = np.dot(
            view_left.reshape(dim_removed, dim_left).T,
            view_right.reshape(dim_right, dim_removed).T,
        )

    # Transpose right
    elif input_left[-rs:] == input_right[-rs:]:
        new_view = np.dot(
            view_left.reshape(dim_left, dim_removed),
            view_right.reshape(dim_right, dim_removed).T,
        )

    # Transpose left
    elif input_left[:rs] == input_right[:rs]:
        new_view = np.dot(
            view_left.reshape(dim_removed, dim_left).T,
            view_right.reshape(dim_removed, dim_right),
        )

    # Conventional tensordot
    else:
        # Find indices to contract over
        left_pos: Tuple[int, ...] = ()
        right_pos: Tuple[int, ...] = ()
        for fidx in idx_removed:
            left_pos += (input_left.find(fidx),)
            right_pos += (input_right.find(fidx),)
        new_view = np.tensordot(view_left, view_right, axes=(left_pos, right_pos))

    # Make sure the resulting shape is correct
    tensor_shape = tuple(dimension_dict[x] for x in tensor_result)
    if new_view.shape != tensor_shape:
        if len(tensor_result) > 0:
            new_view.shape = tensor_shape
        else:
            new_view = np.squeeze(new_view)

    if tensor_result != index_result:
        new_view = np.einsum(tensor_result + "->" + index_result, new_view)

    return new_view
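

if __name__ == "__main__":
    # A minimal usage sketch (shapes chosen arbitrarily for illustration): classify a
    # plain matrix multiply, then dispatch it through tensor_blas and check the result.
    a = np.random.rand(3, 4)
    b = np.random.rand(4, 5)

    # 'ij,jk->ik' removes 'j'; the trailing 'j' of the left operand meets the
    # leading 'j' of the right one, so the no-transpose GEMM branch applies.
    print(can_blas(["ij", "jk"], "ik", set("j")))  # 'GEMM'

    out = tensor_blas(a, "ij", b, "jk", "ik", set("j"))
    print(np.allclose(out, a @ b))  # True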