Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/rnn_grad.py: 55%
11 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gradients for (block) GRU/LSTM operators."""
16from tensorflow.python.framework import ops
17from tensorflow.python.ops import gen_rnn_ops
def _block_lstm_grad(op, *grads):
  """Compute input gradients for a BlockLSTM / BlockLSTMV2 op.

  Args:
    op: The forward op whose gradient is requested. Its 9 inputs are unpacked
      as (seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b) and its 7
      outputs as (i, cs, f, o, ci, co, h).
    *grads: One incoming gradient per forward output (exactly 7). Only the
      gradients for `cs` (position 1) and `h` (position 6) are used; the
      others are ignored.

  Returns:
    A 9-tuple in the same order as `op.inputs`: `None` for `seq_len_max`,
    followed by the 8 gradients produced by `gen_rnn_ops.block_lstm_grad`.
  """
  # Forward-pass tensors; the fixed-arity unpacking also rejects ops with an
  # unexpected number of inputs/outputs.
  (seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs
  (i, cs, f, o, ci, co, h) = op.outputs
  # The backward kernel only consumes the cell-state and hidden-state grads.
  (_, cs_grad, _, _, _, _, h_grad) = grads

  kernel_args = dict(
      seq_len_max=seq_len_max,
      x=x,
      cs_prev=cs_prev,
      h_prev=h_prev,
      w=w,
      wci=wci,
      wcf=wcf,
      wco=wco,
      b=b,
      i=i,
      cs=cs,
      f=f,
      o=o,
      ci=ci,
      co=co,
      h=h,
      cs_grad=cs_grad,
      h_grad=h_grad,
      # Mirror the forward op's peephole setting in the backward kernel.
      use_peephole=op.get_attr("use_peephole"),
  )
  x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, wco_grad, b_grad = (
      gen_rnn_ops.block_lstm_grad(**kernel_args))

  # seq_len_max gets no gradient, hence the leading None.
  return (None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
          wco_grad, b_grad)
# BlockLSTM and BlockLSTMV2 share the same gradient function; register it for
# each op type, in that order.
for _lstm_op_type in ("BlockLSTM", "BlockLSTMV2"):
  ops.RegisterGradient(_lstm_op_type)(_block_lstm_grad)
del _lstm_op_type  # Keep the module namespace free of the loop variable.