Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/rnn_grad.py: 55%

11 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Gradients for (block) GRU/LSTM operators.""" 

16from tensorflow.python.framework import ops 

17from tensorflow.python.ops import gen_rnn_ops 

18 

19 

def _block_lstm_grad(op, *grads):
  """Backpropagates through a BlockLSTM-family forward op.

  The forward op's inputs and outputs are passed straight to the fused
  `gen_rnn_ops.block_lstm_grad` kernel, together with the incoming gradients
  for the `cs` and `h` outputs.

  Args:
    op: The forward op. Its 9 inputs are read in the order
      (seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b), and its 7
      outputs in the order (i, cs, f, o, ci, co, h).
    *grads: One upstream gradient per forward output, in output order.

  Returns:
    A 9-tuple aligned with `op.inputs`. It starts with None for `seq_len_max`
    (no gradient is produced for it), followed by the 8 gradients returned by
    the grad kernel.
  """
  (seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs
  (i, cs, f, o, ci, co, h) = op.outputs
  # Only the upstream gradients for outputs 1 (`cs`) and 6 (`h`) are handed
  # to the kernel; the gradients for the remaining outputs are dropped here.
  (_, cs_grad, _, _, _, _, h_grad) = grads

  # Everything the kernel needs: forward inputs, forward outputs, the two
  # upstream gradients, and the peephole flag copied from the forward op.
  kernel_args = dict(
      seq_len_max=seq_len_max,
      x=x,
      cs_prev=cs_prev,
      h_prev=h_prev,
      w=w,
      wci=wci,
      wcf=wcf,
      wco=wco,
      b=b,
      i=i,
      cs=cs,
      f=f,
      o=o,
      ci=ci,
      co=co,
      h=h,
      cs_grad=cs_grad,
      h_grad=h_grad,
      use_peephole=op.get_attr("use_peephole"),
  )
  # NOTE(review): this same kernel is also used for "BlockLSTMV2" (see the
  # registrations below). Confirm that the V2 op's w/b gate ordering matches
  # what block_lstm_grad assumes.
  (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, wco_grad,
   b_grad) = gen_rnn_ops.block_lstm_grad(**kernel_args)

  return (None,
          x_grad, cs_prev_grad, h_prev_grad,
          w_grad, wci_grad, wcf_grad, wco_grad,
          b_grad)

48 

49 

# The original BlockLSTM op and BlockLSTMV2 are both wired to the single
# Python gradient function defined above, registered in this order.
for _lstm_op_type in ("BlockLSTM", "BlockLSTMV2"):
  ops.RegisterGradient(_lstm_op_type)(_block_lstm_grad)
# Drop the loop variable so the module namespace matches the two-call form.
del _lstm_op_type