Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/base_cudnn_rnn.py: 26%
69 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for recurrent layers backed by cuDNN."""
18import tensorflow.compat.v2 as tf
20from keras.src import backend
21from keras.src.engine.input_spec import InputSpec
22from keras.src.layers.rnn.base_rnn import RNN
class _CuDNNRNN(RNN):
    """Private base class for CuDNNGRU and CuDNNLSTM layers.

    Args:
      return_sequences: Boolean. Whether to return the last output
        in the output sequence, or the full sequence.
      return_state: Boolean. Whether to return the last state
        in addition to the output.
      go_backwards: Boolean (default False).
        If True, process the input sequence backwards and return the
        reversed sequence.
      stateful: Boolean (default False). If True, the last state
        for each sample at index i in a batch will be used as initial
        state for the sample of index i in the following batch.
      time_major: Boolean (default False). If true, the inputs and outputs
        will be in shape `(timesteps, batch, ...)`, whereas in the False
        case, it will be `(batch, timesteps, ...)`.
    """

    def __init__(
        self,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        time_major=False,
        **kwargs
    ):
        # We invoke the base layer's initializer directly here (skipping
        # RNN.__init__) because we do not want to create an RNN cell
        # instance.
        super(RNN, self).__init__(**kwargs)
        self.return_sequences = return_sequences
        self.return_state = return_state
        self.go_backwards = go_backwards
        self.stateful = stateful
        self.time_major = time_major
        # cuDNN kernels cannot handle padding masks.
        self.supports_masking = False
        # Inputs are always rank-3: (batch, time, features) or
        # (time, batch, features) depending on `time_major`.
        self.input_spec = [InputSpec(ndim=3)]
        # NOTE(review): `self.cell` must already be available here, i.e.
        # subclasses are expected to expose it (e.g. as a property) before
        # calling this initializer — confirm against CuDNNGRU/CuDNNLSTM.
        if hasattr(self.cell.state_size, "__len__"):
            state_size = self.cell.state_size
        else:
            state_size = [self.cell.state_size]
        self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
        self.constants_spec = None
        self._states = None
        self._num_constants = 0
        self._vector_shape = tf.constant([-1])

    def call(self, inputs, mask=None, training=None, initial_state=None):
        """Runs the cuDNN-backed recurrent computation.

        Args:
          inputs: Rank-3 tensor, or a list of `[inputs] + initial_states`.
          mask: Must be None; masking is unsupported for cuDNN RNNs.
          training: Unused here; accepted for Layer API compatibility.
          initial_state: Optional list of initial state tensors.

        Returns:
          The output tensor, or `[output] + states` if `return_state`.

        Raises:
          ValueError: If a mask is passed, or if the number of provided
            initial states does not match the layer's state count.
        """
        if isinstance(mask, list):
            mask = mask[0]
        if mask is not None:
            raise ValueError("Masking is not supported for CuDNN RNNs.")

        # input shape: `(samples, time (padded with zeros), input_dim)`
        # note that the .build() method of subclasses MUST define
        # self.input_spec and self.state_spec with complete input shapes.
        if isinstance(inputs, list):
            # Initial states were passed packed together with the inputs.
            initial_state = inputs[1:]
            inputs = inputs[0]
        elif initial_state is not None:
            # Explicit initial state wins over stateful/default states.
            pass
        elif self.stateful:
            initial_state = self.states
        else:
            initial_state = self.get_initial_state(inputs)

        if len(initial_state) != len(self.states):
            raise ValueError(
                "Layer has "
                + str(len(self.states))
                + " states but was passed "
                + str(len(initial_state))
                + " initial states."
            )

        if self.go_backwards:
            # Reverse time axis.
            inputs = backend.reverse(inputs, 1)
        # `_process_batch` is implemented by the concrete subclass and does
        # the actual cuDNN op invocation.
        output, states = self._process_batch(inputs, initial_state)

        if self.stateful:
            # Carry the final states over to the next batch.
            updates = [
                tf.compat.v1.assign(self_state, state)
                for self_state, state in zip(self.states, states)
            ]
            self.add_update(updates)

        if self.return_state:
            return [output] + states
        else:
            return output

    def get_config(self):
        """Returns the layer config for serialization."""
        config = {
            "return_sequences": self.return_sequences,
            "return_state": self.return_state,
            "go_backwards": self.go_backwards,
            "stateful": self.stateful,
            "time_major": self.time_major,
        }
        # Skip RNN.get_config: this layer has no user-visible cell to
        # serialize.
        base_config = super(RNN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        """Creates the layer from its config dict."""
        return cls(**config)

    @property
    def trainable_weights(self):
        """The kernel, recurrent kernel and bias, when built and trainable."""
        if self.trainable and self.built:
            return [self.kernel, self.recurrent_kernel, self.bias]
        return []

    @property
    def non_trainable_weights(self):
        """The same three weights, reported here when not trainable."""
        if not self.trainable and self.built:
            return [self.kernel, self.recurrent_kernel, self.bias]
        return []

    @property
    def losses(self):
        # Bypass RNN.losses, which would consult the (nonexistent) cell.
        return super(RNN, self).losses

    def get_losses_for(self, inputs=None):
        # Bypass RNN.get_losses_for for the same reason as `losses`.
        return super(RNN, self).get_losses_for(inputs=inputs)