# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CSR Sparse Matrix Gradients."""

from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops


@ops.RegisterGradient("DenseToCSRSparseMatrix")
def _DenseToCSRSparseMatrixGrad(op, grad):
  """Gradient for dense_to_csr_sparse_matrix op."""
  grad_values = (
      sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          grad, type=op.get_attr("T")))
  # inputs to fw op were: params, indices.
  return (grad_values, None)


@ops.RegisterGradient("CSRSparseMatrixToDense")
def _CSRSparseMatrixToDenseGrad(op, grad):
  """Gradient for csr_sparse_matrix_to_dense op."""
  coo_sparse_tensor = sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
      op.inputs[0], type=grad.dtype)
  return sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      indices=coo_sparse_tensor.indices,
      values=array_ops.gather_nd(grad, coo_sparse_tensor.indices),
      dense_shape=grad.shape)


@ops.RegisterGradient("SparseTensorToCSRSparseMatrix")
def _SparseTensorToCSRSparseMatrixGrad(op, grad):
  """Gradient for sparse_tensor_to_csr_sparse_matrix op."""
  grad_values = sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
      grad, type=op.get_attr("T")).values
  return (None, grad_values, None)


@ops.RegisterGradient("CSRSparseMatrixToSparseTensor")
def _CSRSparseMatrixToSparseTensorGrad(op, *grads):
  """Gradient for csr_sparse_matrix_to_sparse_tensor op."""
  return sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      indices=op.outputs[0], values=grads[1], dense_shape=op.outputs[2])
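
# Each conversion gradient above maps the incoming gradient back to the
# input's format (dense <-> COO <-> CSR), restricted to the input's sparsity
# pattern where one exists; only the values are differentiable, so index and
# shape inputs receive None.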


ops.NotDifferentiable("SparseMatrixNNZ")

ops.NotDifferentiable("SparseMatrixZeros")


def _PruneSparseTensor(unpruned, pruned_pattern):
  """Helper function to prune a COO sparse tensor.

  Given two sparse tensors `unpruned` and `pruned_pattern`, generates another
  sparse tensor with indices and values from `unpruned`, keeping only the
  entries whose indices also occur in `pruned_pattern`.

  Args:
    unpruned: COO matrix with unpruned indices.
    pruned_pattern: COO matrix with pruned pattern.

  TODO(tabakg): This is far from optimal. Consider a C++ implementation.

  Returns:
    Indices, values, and dense_shape of the pruned matrix.
  """
  pruned_indices = sparse_ops.sparse_reshape(
      pruned_pattern, shape=(-1,)).indices[..., 0]
  unpruned_indices = sparse_ops.sparse_reshape(
      unpruned, shape=(-1,)).indices[..., 0]
  best_match = array_ops.searchsorted(unpruned_indices, pruned_indices)
  keep_indices = array_ops.gather(
      best_match,
      array_ops.where(
          math_ops.equal(
              array_ops.gather(unpruned_indices, best_match), pruned_indices)))
  return (array_ops.gather_nd(unpruned.indices, keep_indices),
          array_ops.gather_nd(unpruned.values, keep_indices),
          pruned_pattern.dense_shape)
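
# A small worked example of the pruning above (hypothetical values): for a
# 2 x 2 matrix with unpruned indices [[0, 0], [0, 1], [1, 1]] and values
# [10, 20, 30], and pruned_pattern indices [[0, 1], [1, 1]], the flattened
# linear indices are [0, 1, 3] and [1, 3]. searchsorted yields best_match
# [1, 2], and both gathered entries match exactly, so the result keeps
# indices [[0, 1], [1, 1]] and values [20, 30].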


def _PruneCSRMatrix(unpruned, pruned_pattern):
  """Prunes a CSR matrix to the sparsity pattern of `pruned_pattern`.

  TODO(tabakg): Consider re-writing in C++.
  """
  _, dtype = sparse_csr_matrix_ops.dense_shape_and_type(pruned_pattern)
  coo_unpruned = sparse_tensor.SparseTensor(
      *sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
          unpruned, type=dtype))
  coo_pruned_pattern = sparse_tensor.SparseTensor(
      *sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
          pruned_pattern, type=dtype))
  return sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      *_PruneSparseTensor(coo_unpruned, coo_pruned_pattern))


@ops.RegisterGradient("SparseMatrixAdd")
def _SparseMatrixAddGrad(op, grad):
  """Gradient for sparse_matrix_add op."""
  # input to sparse_matrix_add is (a, b, alpha, beta)
  # with a, b CSR matrices and alpha, beta scalars.
  # output is: alpha * a + beta * b

  # d(alpha*A + beta*B)/dA . grad = alpha * grad

  # May have gotten the transposes wrong below.
  # d(alpha*A + beta*B)/dalpha . grad = tr(A' . grad)

  # For now, only implement gradients w.r.t. A and B.
  # TODO(ebrevdo): Implement reduce_sum for SparseMatrix so that we
  # can implement gradients w.r.t. alpha and beta.
  (a_csr, b_csr, alpha, beta) = op.inputs
  return (sparse_csr_matrix_ops.sparse_matrix_mul(
      _PruneCSRMatrix(grad, a_csr), alpha),
          sparse_csr_matrix_ops.sparse_matrix_mul(
              _PruneCSRMatrix(grad, b_csr), beta), None, None)
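
# Sanity check for the gradients above: C = alpha * A + beta * B is
# elementwise in the stored values, so dC/dA scales the incoming gradient by
# alpha and dC/dB scales it by beta, each restricted to the corresponding
# input's sparsity pattern (hence the _PruneCSRMatrix calls).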


def _PrunedDenseMatrixMultiplication(a,
                                     b,
                                     indices,
                                     transpose_a=False,
                                     adjoint_a=False,
                                     transpose_b=False,
                                     adjoint_b=False):
  """Multiplies two dense matrices at selected indices.

  The two inputs `a` and `b` must have matching rank (2 or 3). If the rank is
  3, the first dimension is the batch dimension. The last two dimensions must
  be compatible for matrix multiplication.

  TODO(tabakg): Consider C++ implementation. There is also a more efficient way
  to handle transposes here.

  Args:
    a: The left dense matrix (or batched matrices).
    b: The right dense matrix (or batched matrices).
    indices: The selected output indices where values should be produced. Other
      indices will be pruned (not computed in the first place). Indices are
      specified as a tensor of shape (length, rank), where length is the number
      of entries and rank is the rank of the dense inputs (2 or 3).
    transpose_a: Whether to transpose a.
    adjoint_a: Whether to take the conjugate transpose of a.
    transpose_b: Whether to transpose b.
    adjoint_b: Whether to take the conjugate transpose of b.

  Returns:
    A CSR matrix.
  """
  transpose_a = transpose_a or adjoint_a
  transpose_b = transpose_b or adjoint_b

  a = math_ops.conj(a) if adjoint_a else a
  b = math_ops.conj(b) if adjoint_b else b

  rank = len(a.shape)
  dense_shape = (a.shape[-1] if transpose_a else a.shape[-2],
                 b.shape[-2] if transpose_b else b.shape[-1])
  if rank == 2:
    rows = indices[:, 0]
    cols = indices[:, 1]
    transpose = array_ops.transpose
    gather_op = array_ops.gather
  elif rank == 3:
    dense_shape = (a.shape[0],) + dense_shape
    rows = indices[:, :2]
    cols = array_ops_stack.stack([indices[:, 0], indices[:, 2]], axis=1)
    transpose = lambda x: array_ops.transpose(x, perm=[0, 2, 1])
    gather_op = array_ops.gather_nd

  # Gather only the needed rows of opA(a) and columns of opB(b), then
  # contract each gathered pair to produce one output value per index.
  a_rows = gather_op(transpose(a) if transpose_a else a, indices=rows)
  b_cols = gather_op(b if transpose_b else transpose(b), indices=cols)
  values = math_ops.reduce_sum(a_rows * b_cols, axis=1)

  return sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      indices=indices, values=values, dense_shape=dense_shape)
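
# Up to the pruning, the function above is equivalent to the dense reference
# below (an illustrative sketch, never materialized in practice because only
# the gathered rows and columns are computed):
#
#   dense = math_ops.matmul(a, b, transpose_a=transpose_a,
#                           adjoint_a=adjoint_a, transpose_b=transpose_b,
#                           adjoint_b=adjoint_b)
#   values = array_ops.gather_nd(dense, indices)
#
# so the cost scales with the number of selected indices rather than with the
# full output size.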


@ops.RegisterGradient("SparseMatrixTranspose")
def _SparseMatrixTransposeGrad(op, grad):
  """Gradient for sparse_matrix_transpose op."""
  return sparse_csr_matrix_ops.sparse_matrix_transpose(
      grad, type=op.get_attr("type"), conjugate=op.get_attr("conjugate"))


@ops.RegisterGradient("SparseMatrixSoftmax")
def _SparseMatrixSoftmaxGrad(op, grad_softmax):
  """Gradient for sparse_matrix_softmax op."""
  softmax = op.outputs[0]
  return sparse_csr_matrix_ops.sparse_matrix_softmax_grad(
      softmax, grad_softmax, type=op.get_attr("type"))


@ops.RegisterGradient("SparseMatrixMatMul")
def _SparseMatrixMatMulGrad(op, grad):
  """Gradient for sparse_matrix_mat_mul op."""
  # input to sparse_matrix_mat_mul is (A, B) with CSR A and dense B.
  # Output is dense:
  #   C = opA(A) . opB(B) if transpose_output = False
  #   C = (opA(A) . opB(B))' = opB(B)' . opA(A)' if transpose_output = True
  # where opA = transpose if transpose_a = True else identity
  # and opB = transpose if transpose_b = True else identity.

  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  adj_a = op.get_attr("adjoint_a")
  adj_b = op.get_attr("adjoint_b")
  transpose_output = op.get_attr("transpose_output")
  conjugate_output = op.get_attr("conjugate_output")
  a = op.inputs[0]  # sparse matrix
  b = op.inputs[1]  # dense matrix
  conj = math_ops.conj
  sparse_matmul = sparse_csr_matrix_ops.sparse_matrix_mat_mul

  def matmul(x, y, **kwargs):  # pylint: disable=invalid-name
    return _PrunedDenseMatrixMultiplication(
        x,
        y,
        indices=sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
            a, type=x.dtype).indices,
        **kwargs)

  if conjugate_output:
    grad = conj(grad)
  if not transpose_output:
    # C = opA(A) . opB(B)
    if not adj_a and not adj_b:
      a = conj(a)
      b = conj(b)
      if not t_a:
        grad_a = matmul(grad, b, transpose_b=not t_b)
      else:
        grad_a = matmul(b, grad, transpose_a=t_b, transpose_b=True)
      grad_b = sparse_matmul(a, grad, transpose_a=not t_a, transpose_output=t_b)
    elif not t_a and not t_b:
      if not adj_a:
        grad_a = matmul(grad, b, adjoint_b=not adj_b)
      else:
        grad_a = matmul(b, grad, adjoint_a=adj_b, adjoint_b=True)
      grad_b = sparse_matmul(
          a,
          grad,
          adjoint_a=not adj_a,
          transpose_output=adj_b,
          conjugate_output=adj_b)
    elif adj_a and t_b:
      grad_a = matmul(b, grad, transpose_a=True, adjoint_b=True)
      grad_b = sparse_matmul(a, grad, transpose_output=True)
    elif t_a and adj_b:
      grad_a = matmul(b, grad, transpose_a=True, transpose_b=True)
      grad_b = sparse_matmul(
          conj(a), grad, transpose_output=True, conjugate_output=True)
  else:
    # C = (opA(A) . opB(B))' = opB(B)' . opA(A)'
    if not adj_a and not adj_b:
      a = conj(a)
      b = conj(b)
      if not t_a:
        grad_a = matmul(grad, b, transpose_a=True, transpose_b=not t_b)
      else:
        grad_a = matmul(b, grad, transpose_a=t_b)
      grad_b = sparse_matmul(
          a, grad, transpose_a=not t_a, transpose_b=True, transpose_output=t_b)
    elif not t_a and not t_b:
      if not adj_a:
        grad_a = matmul(grad, b, transpose_a=True, adjoint_b=not adj_b)
      else:
        grad_a = matmul(b, conj(grad), adjoint_a=adj_b)
      grad_b = sparse_matmul(
          a,
          grad,
          adjoint_a=not adj_a,
          transpose_b=True,
          transpose_output=adj_b,
          conjugate_output=adj_b)
    elif adj_a and t_b:
      grad_a = matmul(b, conj(grad), transpose_a=True)
      grad_b = sparse_matmul(a, grad, transpose_b=True, transpose_output=True)
    elif t_a and adj_b:
      grad_a = matmul(b, grad, transpose_a=True)
      grad_b = sparse_matmul(a, grad, adjoint_b=True, transpose_output=True)

  return (grad_a, grad_b)
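
# All the branches above follow from the two basic identities for
# C = A . B with incoming dense gradient dC:
#   dA = dC . B^H  (computed only at A's sparsity pattern via `matmul` above)
#   dB = A^H . dC
# with the transpose/adjoint/transpose_output attributes pushed through these
# two products case by case.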


@ops.RegisterGradient("SparseMatrixSparseMatMul")
def _SparseMatrixSparseMatMulGrad(op, grad):
  """Gradient for sparse_matrix_sparse_mat_mul op."""
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  adj_a = op.get_attr("adjoint_a")
  adj_b = op.get_attr("adjoint_b")
  dtype = op.get_attr("type")

  # input to sparse_matrix_sparse_mat_mul is (A, B) with CSR A and B.
  # Output is CSR:
  #   C = opA(A) . opB(B)
  # where opA = transpose if transpose_a = True else identity
  # and opB = transpose if transpose_b = True else identity.
  a = op.inputs[0]
  b = op.inputs[1]
  conj = math_ops.conj
  matmul = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul
  if not t_a and not t_b:
    if not adj_a:
      if not adj_b:
        grad_a = matmul(grad, b, adjoint_b=True, type=dtype)
        grad_b = matmul(a, grad, adjoint_a=True, type=dtype)
      else:
        grad_a = matmul(grad, b, type=dtype)
        grad_b = matmul(grad, a, adjoint_a=True, type=dtype)
    else:
      if not adj_b:
        grad_a = matmul(b, grad, adjoint_b=True, type=dtype)
        grad_b = matmul(a, grad, type=dtype)
      else:
        grad_a = matmul(b, grad, adjoint_a=True, adjoint_b=True, type=dtype)
        grad_b = matmul(grad, a, adjoint_a=True, adjoint_b=True, type=dtype)
  elif not adj_a and not adj_b:
    if not t_a and t_b:
      grad_a = matmul(grad, conj(b), type=dtype)
      grad_b = matmul(grad, conj(a), transpose_a=True, type=dtype)
    elif t_a and not t_b:
      grad_a = matmul(conj(b), grad, transpose_b=True, type=dtype)
      grad_b = matmul(conj(a), grad, type=dtype)
    else:
      grad_a = matmul(b, grad, adjoint_a=True, transpose_b=True, type=dtype)
      grad_b = matmul(grad, a, transpose_a=True, adjoint_b=True, type=dtype)
  elif adj_a and t_b:
    grad_a = matmul(b, grad, transpose_a=True, adjoint_b=True, type=dtype)
    grad_b = matmul(grad, a, transpose_a=True, transpose_b=True, type=dtype)
  elif t_a and adj_b:
    grad_a = matmul(b, grad, transpose_a=True, transpose_b=True, type=dtype)
    grad_b = matmul(grad, a, adjoint_a=True, transpose_b=True, type=dtype)

  # TODO(tabakg): There should be a C++ function for sparse-sparse
  # multiplication with pre-determined indices, instead of pruning after the
  # multiplication.
  return (_PruneCSRMatrix(grad_a, a), _PruneCSRMatrix(grad_b, b))
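
# The same identities dA = dC . B^H and dB = A^H . dC underlie the branches
# above; since A and B are both CSR here, the results must additionally be
# pruned back to the inputs' sparsity patterns, which is what the
# _PruneCSRMatrix calls do.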


@ops.RegisterGradient("SparseMatrixMul")
def _SparseMatrixMulGrad(op, grad):
  """Gradient for sparse_matrix_mul op."""
  # input to sparse_matrix_mul is (A, B) with CSR A and dense B.
  # Output is CSR:
  #   C = A .* B
  del op
  del grad
  raise NotImplementedError
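
# A possible gradient, if it were implemented: for C = A .* B, dA would be
# grad .* B restricted to A's sparsity pattern, and dB a densified grad .* A
# (reduced over any broadcast dimensions of B).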