Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/cudnn_rnn_grad.py: 50%
18 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Gradients for CudnnRNN operators."""
16from tensorflow.python.framework import ops
17from tensorflow.python.ops import gen_cudnn_rnn_ops
@ops.RegisterGradient("CudnnRNN")
def _cudnn_rnn_backward(op, *grads):
  """Gradients for the CudnnRNN op.

  Args:
    op: the forward `CudnnRNN` op being differentiated.
    *grads: gradients flowing into the op's outputs
      (output, output_h, output_c, ...).

  Returns:
    The result of `cudnn_rnn_backprop`: gradients with respect to the
    op's inputs (input, input_h, input_c, params).

  Raises:
    ValueError: if the forward op was created with `is_training=False`
      (no reserve space was kept, so gradients cannot be computed).
  """
  if not op.get_attr("is_training"):
    raise ValueError(
        "To use CudnnRNN in gradients, is_training must be set to True.")
  # Forward the RNN configuration attributes from the forward op verbatim.
  attr_names = ("dropout", "seed", "seed2", "rnn_mode", "input_mode",
                "direction")
  forwarded_attrs = {name: op.get_attr(name) for name in attr_names}
  return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
      input=op.inputs[0],
      input_h=op.inputs[1],
      input_c=op.inputs[2],
      params=op.inputs[3],
      output=op.outputs[0],
      output_h=op.outputs[1],
      output_c=op.outputs[2],
      output_backprop=grads[0],
      output_h_backprop=grads[1],
      output_c_backprop=grads[2],
      reserve_space=op.outputs[3],
      **forwarded_attrs)
@ops.RegisterGradient("CudnnRNNV2")
def _cudnn_rnn_backward_v2(op, *grad):
  """Gradients for the CudnnRNNV2 op.

  Args:
    op: the forward `CudnnRNNV2` op being differentiated.
    *grad: gradients flowing into the op's outputs
      (output, output_h, output_c, ...).

  Returns:
    The result of `cudnn_rnn_backprop_v2`: gradients with respect to the
    op's inputs (input, input_h, input_c, params).

  Raises:
    ValueError: if the forward op was created with `is_training=False`
      (no reserve space was kept, so gradients cannot be computed).
  """
  if not op.get_attr("is_training"):
    raise ValueError(
        "To use CudnnRNNV2 in gradients, is_training must be set to True.")
  # Forward the RNN configuration attributes from the forward op verbatim.
  attr_names = ("dropout", "seed", "seed2", "rnn_mode", "input_mode",
                "direction")
  forwarded_attrs = {name: op.get_attr(name) for name in attr_names}
  return gen_cudnn_rnn_ops.cudnn_rnn_backprop_v2(
      input=op.inputs[0],
      input_h=op.inputs[1],
      input_c=op.inputs[2],
      params=op.inputs[3],
      output=op.outputs[0],
      output_h=op.outputs[1],
      output_c=op.outputs[2],
      output_backprop=grad[0],
      output_h_backprop=grad[1],
      output_c_backprop=grad[2],
      reserve_space=op.outputs[3],
      host_reserved=op.outputs[4],
      **forwarded_attrs)
@ops.RegisterGradient("CudnnRNNV3")
def _cudnn_rnn_backwardv3(op, *grads):
  """Gradients for the CudnnRNNV3 op.

  Args:
    op: the forward `CudnnRNNV3` op being differentiated.
    *grads: gradients flowing into the op's outputs
      (output, output_h, output_c, ...).

  Returns:
    Gradients with respect to the op's inputs
    (input, input_h, input_c, params), followed by `None` for the
    integer `sequence_lengths` input, which has no gradient.

  Raises:
    ValueError: if the forward op was created with `is_training=False`
      (no reserve space was kept, so gradients cannot be computed).
  """
  if not op.get_attr("is_training"):
    raise ValueError(
        "To use CudnnRNNV3 in gradients, is_training must be set to True.")
  # Forward the RNN configuration attributes from the forward op verbatim.
  attr_names = ("dropout", "seed", "seed2", "time_major", "num_proj",
                "rnn_mode", "input_mode", "direction")
  forwarded_attrs = {name: op.get_attr(name) for name in attr_names}
  input_grads = gen_cudnn_rnn_ops.cudnn_rnn_backprop_v3(
      input=op.inputs[0],
      input_h=op.inputs[1],
      input_c=op.inputs[2],
      params=op.inputs[3],
      sequence_lengths=op.inputs[4],
      output=op.outputs[0],
      output_h=op.outputs[1],
      output_c=op.outputs[2],
      output_backprop=grads[0],
      output_h_backprop=grads[1],
      output_c_backprop=grads[2],
      reserve_space=op.outputs[3],
      host_reserved=op.outputs[4],
      **forwarded_attrs)
  # The op has five inputs; sequence_lengths (inputs[4]) gets no gradient.
  return input_grads + (None,)