# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CSR Sparse Matrix Gradients."""

from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops


@ops.RegisterGradient("DenseToCSRSparseMatrix")
def _DenseToCSRSparseMatrixGrad(op, grad):
  """Gradient for dense_to_csr_sparse_matrix op."""
  grad_values = (
      sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          grad, type=op.get_attr("T")))
  # inputs to fw op were: params, indices.
  return (grad_values, None)


@ops.RegisterGradient("CSRSparseMatrixToDense")
def _CSRSparseMatrixToDenseGrad(op, grad):
  """Gradient for csr_sparse_matrix_to_dense op."""
  coo_sparse_tensor = sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
      op.inputs[0], type=grad.dtype)
  return sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      indices=coo_sparse_tensor.indices,
      values=array_ops.gather_nd(grad, coo_sparse_tensor.indices),
      dense_shape=grad.shape)
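
# A minimal sketch of the idea above (hypothetical 2x2 example): if the
# forward input had nonzeros at [[0, 0], [1, 1]], then for an upstream dense
# gradient grad = [[g00, g01], [g10, g11]] the backward pass gathers only
# [g00, g11] and repacks them into a CSR matrix with the same sparsity
# pattern; gradients at structurally-zero positions are dropped.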


@ops.RegisterGradient("SparseTensorToCSRSparseMatrix")
def _SparseTensorToCSRSparseMatrixGrad(op, grad):
  """Gradient for sparse_tensor_to_csr_sparse_matrix op."""
  grad_values = sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
      grad, type=op.get_attr("T")).values
  return (None, grad_values, None)


@ops.RegisterGradient("CSRSparseMatrixToSparseTensor")
def _CSRSparseMatrixToSparseTensorGrad(op, *grads):
  """Gradient for csr_sparse_matrix_to_sparse_tensor op."""
  return sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      indices=op.outputs[0], values=grads[1], dense_shape=op.outputs[2])
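
# Note: the forward op emits (indices, values, dense_shape); only the values
# output is differentiable, so grads[1] (the values gradient) is repacked into
# a CSR matrix using the forward op's own output indices and shape.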


ops.NotDifferentiable("SparseMatrixNNZ")

ops.NotDifferentiable("SparseMatrixZeros")


def _PruneSparseTensor(unpruned, pruned_pattern):
  """Helper function to prune COO sparse tensor.

  Given two sparse tensors 'unpruned' and 'pruned_pattern', generates another
  sparse tensor with indices and values from 'unpruned' only if its indices
  also occur in 'pruned_pattern'.

  Args:
    unpruned: COO matrix with unpruned indices.
    pruned_pattern: COO matrix with pruned pattern.

  TODO(tabakg): This is far from optimal. Consider a C++ implementation.

  Returns:
    Indices, values, and dense_shape of the pruned matrix.
  """
  pruned_indices = sparse_ops.sparse_reshape(
      pruned_pattern, shape=(-1,)).indices[..., 0]
  unpruned_indices = sparse_ops.sparse_reshape(
      unpruned, shape=(-1,)).indices[..., 0]
  best_match = array_ops.searchsorted(unpruned_indices, pruned_indices)
  keep_indices = array_ops.gather(
      best_match,
      array_ops.where(
          math_ops.equal(
              array_ops.gather(unpruned_indices, best_match), pruned_indices)))
  return (array_ops.gather_nd(unpruned.indices, keep_indices),
          array_ops.gather_nd(unpruned.values, keep_indices),
          pruned_pattern.dense_shape)
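
# Worked example of the pruning logic (hypothetical values, assuming both
# inputs are ordered 2x2 COO matrices):
#
#   unpruned:       indices [[0, 0], [0, 1], [1, 1]], values [1., 2., 3.]
#   pruned_pattern: indices [[0, 1], [1, 1]]
#
# After reshaping to 1-D, the linearized indices are [0, 1, 3] for 'unpruned'
# and [1, 3] for 'pruned_pattern'. searchsorted([0, 1, 3], [1, 3]) = [1, 2],
# and both candidates match exactly, so the pruned result keeps indices
# [[0, 1], [1, 1]] with values [2., 3.].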


def _PruneCSRMatrix(unpruned, pruned_pattern):
  """TODO(tabakg): Consider re-writing in C++."""
  _, dtype = sparse_csr_matrix_ops.dense_shape_and_type(pruned_pattern)
  coo_unpruned = sparse_tensor.SparseTensor(
      *sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
          unpruned, type=dtype))
  coo_pruned_pattern = sparse_tensor.SparseTensor(
      *sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
          pruned_pattern, type=dtype))
  return sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      *_PruneSparseTensor(coo_unpruned, coo_pruned_pattern))


@ops.RegisterGradient("SparseMatrixAdd")
def _SparseMatrixAddGrad(op, grad):
  """Gradient for sparse_matrix_add op."""
  # input to sparse_matrix_add is (a, b, alpha, beta)
  # with a, b CSR and alpha, beta scalars.
  # output is: alpha * a + beta * b.

  # d(alpha*A + beta*B)/dA . grad = alpha * grad
  # d(alpha*A + beta*B)/dB . grad = beta * grad

  # May have gotten the transposes wrong below.
  # d(alpha*A + beta*B)/dalpha . grad = tr(A' . grad)

  # For now, only implement gradients w.r.t. A and B.
  # TODO(ebrevdo): Implement reduce_sum for SparseMatrix so that we
  # can implement gradients w.r.t. alpha and beta.
  (a_csr, b_csr, alpha, beta) = op.inputs
  return (sparse_csr_matrix_ops.sparse_matrix_mul(
      _PruneCSRMatrix(grad, a_csr), alpha),
          sparse_csr_matrix_ops.sparse_matrix_mul(
              _PruneCSRMatrix(grad, b_csr), beta), None, None)
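
# Worked example (illustrative): with C = 2 * A + 3 * B, an upstream gradient
# dL/dC contributes dL/dA = 2 * dL/dC and dL/dB = 3 * dL/dC. Since the output
# pattern of C is the union of the patterns of A and B, each gradient is first
# pruned back to the sparsity pattern of the corresponding input and then
# scaled.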


def _PrunedDenseMatrixMultiplication(a,
                                     b,
                                     indices,
                                     transpose_a=False,
                                     adjoint_a=False,
                                     transpose_b=False,
                                     adjoint_b=False):
  """Multiplies two dense matrices at selected indices.

  The two inputs `a` and `b` must have matching rank (2 or 3). If the rank is
  3, the first dimension is the batch dimension. The last two dimensions must
  be compatible for matrix multiplication.

  TODO(tabakg): Consider a C++ implementation. There is also a more efficient
  way to handle transposes here.

  Args:
    a: The left dense matrix (or batched matrices).
    b: The right dense matrix (or batched matrices).
    indices: The selected output indices where values should be produced. Other
      indices will be pruned (not computed in the first place). Indices are
      specified as a tensor of shape (length, rank), where length is the number
      of entries and rank is the rank of the dense inputs (2 or 3).
    transpose_a: Whether to transpose a.
    adjoint_a: Whether to take the conjugate transpose of a.
    transpose_b: Whether to transpose b.
    adjoint_b: Whether to take the conjugate transpose of b.

  Returns:
    A CSR matrix.
  """
  transpose_a = transpose_a or adjoint_a
  transpose_b = transpose_b or adjoint_b

  a = math_ops.conj(a) if adjoint_a else a
  b = math_ops.conj(b) if adjoint_b else b

  rank = len(a.shape)
  dense_shape = (a.shape[-1] if transpose_a else a.shape[-2],
                 b.shape[-2] if transpose_b else b.shape[-1])
  if rank == 2:
    rows = indices[:, 0]
    cols = indices[:, 1]
    transpose = array_ops.transpose
    gather_op = array_ops.gather
  elif rank == 3:
    dense_shape = (a.shape[0],) + dense_shape
    rows = indices[:, :2]
    cols = array_ops_stack.stack([indices[:, 0], indices[:, 2]], axis=1)
    transpose = lambda x: array_ops.transpose(x, perm=[0, 2, 1])
    gather_op = array_ops.gather_nd

  a_rows = gather_op(transpose(a) if transpose_a else a, indices=rows)
  b_cols = gather_op(b if transpose_b else transpose(b), indices=cols)
  values = math_ops.reduce_sum(a_rows * b_cols, axis=1)

  return sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      indices=indices, values=values, dense_shape=dense_shape)
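
# Usage sketch (hypothetical shapes): for rank-2 inputs a of shape [m, k] and
# b of shape [k, n], passing indices = [[0, 0], [2, 3]] gathers row a[0, :]
# against column b[:, 0] and row a[2, :] against column b[:, 3], computing
# only C[0, 0] and C[2, 3] of C = a . b and returning them as a CSR matrix
# with dense_shape [m, n].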


@ops.RegisterGradient("SparseMatrixTranspose")
def _SparseMatrixTransposeGrad(op, grad):
  """Gradient for sparse_matrix_transpose op."""
  return sparse_csr_matrix_ops.sparse_matrix_transpose(
      grad, type=op.get_attr("type"), conjugate=op.get_attr("conjugate"))


@ops.RegisterGradient("SparseMatrixSoftmax")
def _SparseMatrixSoftmaxGrad(op, grad_softmax):
  """Gradient for sparse_matrix_softmax op."""
  softmax = op.outputs[0]
  return sparse_csr_matrix_ops.sparse_matrix_softmax_grad(
      softmax, grad_softmax, type=op.get_attr("type"))


@ops.RegisterGradient("SparseMatrixMatMul")
def _SparseMatrixMatMulGrad(op, grad):
  """Gradient for sparse_matrix_mat_mul op."""
  # input to sparse_matrix_mat_mul is (A, B) with CSR A and dense B.
  # Output is dense:
  #   C = opA(A) . opB(B) if transpose_output = false
  #   C = (opA(A) . opB(B))' = opB(B)' . opA(A)' if transpose_output = true.
  # where opA = transpose if transpose_a = True else identity
  # and opB = transpose if transpose_b = True else identity

  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  adj_a = op.get_attr("adjoint_a")
  adj_b = op.get_attr("adjoint_b")
  transpose_output = op.get_attr("transpose_output")
  conjugate_output = op.get_attr("conjugate_output")
  a = op.inputs[0]  # sparse matrix
  b = op.inputs[1]  # dense matrix
  conj = math_ops.conj
  sparse_matmul = sparse_csr_matrix_ops.sparse_matrix_mat_mul

  def matmul(x, y, **kwargs):  # pylint: disable=invalid-name
    return _PrunedDenseMatrixMultiplication(
        x,
        y,
        indices=sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
            a, type=x.dtype).indices,
        **kwargs)

  if conjugate_output:
    grad = conj(grad)
  if not transpose_output:
    # C = opA(A) . opB(B)
    if not adj_a and not adj_b:
      a = conj(a)
      b = conj(b)
      if not t_a:
        grad_a = matmul(grad, b, transpose_b=not t_b)
      else:
        grad_a = matmul(b, grad, transpose_a=t_b, transpose_b=True)
      grad_b = sparse_matmul(a, grad, transpose_a=not t_a, transpose_output=t_b)
    elif not t_a and not t_b:
      if not adj_a:
        grad_a = matmul(grad, b, adjoint_b=not adj_b)
      else:
        grad_a = matmul(b, grad, adjoint_a=adj_b, adjoint_b=True)
      grad_b = sparse_matmul(
          a,
          grad,
          adjoint_a=not adj_a,
          transpose_output=adj_b,
          conjugate_output=adj_b)
    elif adj_a and t_b:
      grad_a = matmul(b, grad, transpose_a=True, adjoint_b=True)
      grad_b = sparse_matmul(a, grad, transpose_output=True)
    elif t_a and adj_b:
      grad_a = matmul(b, grad, transpose_a=True, transpose_b=True)
      grad_b = sparse_matmul(
          conj(a), grad, transpose_output=True, conjugate_output=True)
  else:
    # C = (opA(A) . opB(B))' = opB(B)' . opA(A)'
    if not adj_a and not adj_b:
      a = conj(a)
      b = conj(b)
      if not t_a:
        grad_a = matmul(grad, b, transpose_a=True, transpose_b=not t_b)
      else:
        grad_a = matmul(b, grad, transpose_a=t_b)
      grad_b = sparse_matmul(
          a, grad, transpose_a=not t_a, transpose_b=True, transpose_output=t_b)
    elif not t_a and not t_b:
      if not adj_a:
        grad_a = matmul(grad, b, transpose_a=True, adjoint_b=not adj_b)
      else:
        grad_a = matmul(b, conj(grad), adjoint_a=adj_b)
      grad_b = sparse_matmul(
          a,
          grad,
          adjoint_a=not adj_a,
          transpose_b=True,
          transpose_output=adj_b,
          conjugate_output=adj_b)
    elif adj_a and t_b:
      grad_a = matmul(b, conj(grad), transpose_a=True)
      grad_b = sparse_matmul(a, grad, transpose_b=True, transpose_output=True)
    elif t_a and adj_b:
      grad_a = matmul(b, grad, transpose_a=True)
      grad_b = sparse_matmul(a, grad, adjoint_b=True, transpose_output=True)

  return (grad_a, grad_b)
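
# Sanity check for the simplest case (all transpose/adjoint flags False, real
# dtype): C = A . B, so dL/dA = dL/dC . B^T computed only at the sparsity
# pattern of A (via the pruned dense matmul above), and dL/dB = A^T . dL/dC
# as a dense matrix.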


@ops.RegisterGradient("SparseMatrixSparseMatMul")
def _SparseMatrixSparseMatMulGrad(op, grad):
  """Gradient for sparse_matrix_sparse_mat_mul op."""
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  adj_a = op.get_attr("adjoint_a")
  adj_b = op.get_attr("adjoint_b")
  dtype = op.get_attr("type")

  # input to sparse_matrix_sparse_mat_mul is (A, B) with CSR A and B.
  # Output is CSR:
  #   C = opA(A) . opB(B)
  # where opA = transpose if transpose_a = True else identity
  # and opB = transpose if transpose_b = True else identity
  a = op.inputs[0]
  b = op.inputs[1]
  conj = math_ops.conj
  matmul = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul
  if not t_a and not t_b:
    if not adj_a:
      if not adj_b:
        grad_a = matmul(grad, b, adjoint_b=True, type=dtype)
        grad_b = matmul(a, grad, adjoint_a=True, type=dtype)
      else:
        grad_a = matmul(grad, b, type=dtype)
        grad_b = matmul(grad, a, adjoint_a=True, type=dtype)
    else:
      if not adj_b:
        grad_a = matmul(b, grad, adjoint_b=True, type=dtype)
        grad_b = matmul(a, grad, type=dtype)
      else:
        grad_a = matmul(b, grad, adjoint_a=True, adjoint_b=True, type=dtype)
        grad_b = matmul(grad, a, adjoint_a=True, adjoint_b=True, type=dtype)
  elif not adj_a and not adj_b:
    if not t_a and t_b:
      grad_a = matmul(grad, conj(b), type=dtype)
      grad_b = matmul(grad, conj(a), transpose_a=True, type=dtype)
    elif t_a and not t_b:
      grad_a = matmul(conj(b), grad, transpose_b=True, type=dtype)
      grad_b = matmul(conj(a), grad, type=dtype)
    else:
      grad_a = matmul(b, grad, adjoint_a=True, transpose_b=True, type=dtype)
      grad_b = matmul(grad, a, transpose_a=True, adjoint_b=True, type=dtype)
  elif adj_a and t_b:
    grad_a = matmul(b, grad, transpose_a=True, adjoint_b=True, type=dtype)
    grad_b = matmul(grad, a, transpose_a=True, transpose_b=True, type=dtype)
  elif t_a and adj_b:
    grad_a = matmul(b, grad, transpose_a=True, transpose_b=True, type=dtype)
    grad_b = matmul(grad, a, adjoint_a=True, transpose_b=True, type=dtype)

  # TODO(tabakg): There should be a C++ function for sparse-sparse
  # multiplication with pre-determined indices, instead of pruning after the
  # multiplication.
  return (_PruneCSRMatrix(grad_a, a), _PruneCSRMatrix(grad_b, b))
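
# Sanity check for the simplest case (all flags False): C = A . B with both
# operands CSR, so dL/dA = dL/dC . B^H and dL/dB = A^H . dL/dC. Each result is
# then pruned to the sparsity pattern of the corresponding input, since a
# sparse-sparse product generally has a different (larger) pattern than the
# input it is the gradient of.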


@ops.RegisterGradient("SparseMatrixMul")
def _SparseMatrixMulGrad(op, grad):
  """Gradient for sparse_matrix_mul op."""
  # input to sparse_matrix_mul is (A, B) with CSR A and dense B.
  # Output is CSR:
  #   C = A .* B
  del op
  del grad
  raise NotImplementedError