Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/abstract_rnn_cell.py: 75%
16 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for RNN cells."""
18from keras.src.engine import base_layer
19from keras.src.layers.rnn import rnn_utils
21# isort: off
22from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.AbstractRNNCell")
class AbstractRNNCell(base_layer.Layer):
    """Abstract object representing an RNN cell.

    See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    for details about the usage of RNN API.

    This is the base class for implementing RNN cells with custom behavior.

    Every `RNNCell` must define the `state_size` and `output_size`
    properties and implement `call` with the signature
    `(output, next_state) = call(input, state)`.

    Examples:

    ```python
    from keras import backend

    class MinimalRNNCell(AbstractRNNCell):

        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            self.units = units

        @property
        def state_size(self):
            return self.units

        def build(self, input_shape):
            self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                          initializer='uniform',
                                          name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.built = True

        def call(self, inputs, states):
            prev_output = states[0]
            h = backend.dot(inputs, self.kernel)
            output = h + backend.dot(prev_output, self.recurrent_kernel)
            return output, output
    ```

    This definition of cell differs from the definition used in the literature.
    In the literature, 'cell' refers to an object with a single scalar output.
    This definition refers to a horizontal array of such units.

    An RNN cell, in the most abstract setting, is anything that has
    a state and performs some operation that takes a matrix of inputs.
    This operation results in an output matrix with `self.output_size` columns.
    If `self.state_size` is an integer, this operation also results in a new
    state matrix with `self.state_size` columns. If `self.state_size` is a
    (possibly nested tuple of) TensorShape object(s), then it should return a
    matching structure of Tensors having shape `[batch_size].concatenate(s)`
    for each `s` in `self.state_size`.
    """

    def call(self, inputs, states):
        """The function that contains the logic for one RNN step calculation.

        Subclasses must override this; the base implementation always raises.

        Args:
          inputs: the input tensor, which is a slice of the overall RNN input
            along the time dimension (usually the second dimension).
          states: the state tensor(s) from the previous step, shaped like
            `(batch, state_size)`. At timestep 0 this is the initial state the
            user specified, or a zero-filled tensor otherwise.

        Returns:
          A tuple of two tensors:
            1. output tensor for the current timestep, with size `output_size`.
            2. state tensor for next step, which has the shape of `state_size`.

        Raises:
          NotImplementedError: always, in this abstract base class.
        """
        raise NotImplementedError

    @property
    def state_size(self):
        """size(s) of state(s) used by this cell.

        It can be represented by an Integer, a TensorShape or a tuple of
        Integers or TensorShapes.

        Raises:
          NotImplementedError: always, in this abstract base class; subclasses
            must override this property.
        """
        raise NotImplementedError

    @property
    def output_size(self):
        """Integer or TensorShape: size of outputs produced by this cell.

        Raises:
          NotImplementedError: always, in this abstract base class; subclasses
            must override this property.
        """
        raise NotImplementedError

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Build the default initial state for this cell.

        Delegates to `rnn_utils.generate_zero_filled_state_for_cell`, passing
        this cell and the arguments below through unchanged.

        Args:
          inputs: optional input tensor for the RNN; forwarded as-is.
          batch_size: optional batch size; forwarded as-is.
          dtype: optional dtype for the state; forwarded as-is.

        Returns:
          Whatever `rnn_utils.generate_zero_filled_state_for_cell` returns for
          this cell (presumably a zero-filled structure matching
          `state_size` - confirm against `rnn_utils`).
        """
        return rnn_utils.generate_zero_filled_state_for_cell(
            self, inputs, batch_size, dtype
        )