Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/rnn/abstract_rnn_cell.py: 39%
28 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for RNN cells.
17Adapted from legacy github.com/keras-team/tf-keras.
18"""
20import tensorflow as tf
def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
    """Build a zero-filled initial state for ``cell``.

    When ``inputs`` is given, the batch dimension is read from it at runtime
    (``tf.shape(inputs)[0]``) and the dtype is taken from ``inputs.dtype``;
    the explicit ``batch_size`` and ``dtype`` arguments are then ignored.
    Otherwise the explicit arguments are forwarded unchanged.

    The state structure comes from ``cell.state_size``.
    """
    if inputs is None:
        size_source, type_source = batch_size, dtype
    else:
        # Derived from the actual inputs, taking precedence over the arguments.
        size_source, type_source = tf.shape(inputs)[0], inputs.dtype
    return _generate_zero_filled_state(size_source, cell.state_size, type_source)
30def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
31 """Generate a zero filled tensor with shape [batch_size, state_size]."""
32 if batch_size_tensor is None or dtype is None:
33 raise ValueError(
34 "batch_size and dtype cannot be None while constructing initial state: "
35 "batch_size={}, dtype={}".format(batch_size_tensor, dtype)
36 )
38 def create_zeros(unnested_state_size):
39 flat_dims = tf.TensorShape(unnested_state_size).as_list()
40 init_state_size = [batch_size_tensor] + flat_dims
41 return tf.zeros(init_state_size, dtype=dtype)
43 if tf.nest.is_nested(state_size):
44 return tf.nest.map_structure(create_zeros, state_size)
45 else:
46 return create_zeros(state_size)
class AbstractRNNCell(tf.keras.layers.Layer):
    """Abstract object representing an RNN cell.

    This is a base class for implementing RNN cells with custom behavior.

    Every `RNNCell` must have the properties below and implement `call` with
    the signature `(output, next_state) = call(input, state)`.

    Examples:

    ```python
    class MinimalRNNCell(AbstractRNNCell):

        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            self.units = units

        @property
        def state_size(self):
            return self.units

        @property
        def output_size(self):
            return self.units

        def build(self, input_shape):
            self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                          initializer='uniform',
                                          name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.built = True

        def call(self, inputs, states):
            prev_output = states[0]
            h = tf.matmul(inputs, self.kernel)
            output = h + tf.matmul(prev_output, self.recurrent_kernel)
            return output, output
    ```

    This definition of cell differs from the definition used in the literature.
    In the literature, 'cell' refers to an object with a single scalar output.
    This definition refers to a horizontal array of such units.

    An RNN cell, in the most abstract setting, is anything that has
    a state and performs some operation that takes a matrix of inputs.
    This operation results in an output matrix with `self.output_size` columns.
    If `self.state_size` is an integer, this operation also results in a new
    state matrix with `self.state_size` columns. If `self.state_size` is a
    (possibly nested tuple of) TensorShape object(s), then it should return a
    matching structure of Tensors having shape `[batch_size].concatenate(s)`
    for each `s` in `self.state_size`.
    """

    def call(self, inputs, states):
        """The function that contains the logic for one RNN step calculation.

        Subclasses must override this; the base implementation always raises.

        Args:
            inputs: the input tensor, which is a slice from the overall RNN
                input along the time dimension (usually the second dimension).
            states: the state tensor from previous step, which has the same
                shape as `(batch, state_size)`. In the case of timestep 0, it
                will be the initial state user specified, or zero filled
                tensor otherwise.

        Returns:
            A tuple of two tensors:
                1. output tensor for the current timestep, with size
                   `output_size`.
                2. state tensor for next step, which has the shape of
                   `state_size`.

        Raises:
            NotImplementedError: always, in this abstract base class.
        """
        raise NotImplementedError("Abstract method")

    @property
    def state_size(self):
        """size(s) of state(s) used by this cell.

        It can be represented by an Integer, a TensorShape or a tuple of
        Integers or TensorShapes. Subclasses must override this property; the
        base implementation raises `NotImplementedError`.
        """
        raise NotImplementedError("Abstract method")

    @property
    def output_size(self):
        """Integer or TensorShape: size of outputs produced by this cell.

        Subclasses must override this property; the base implementation
        raises `NotImplementedError`.
        """
        raise NotImplementedError("Abstract method")

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Return a zero-filled initial state matching `self.state_size`.

        Args:
            inputs: Optional input tensor. When given, its leading dimension
                (`tf.shape(inputs)[0]`) is used as the batch size and its
                dtype as the state dtype, overriding `batch_size` and `dtype`.
            batch_size: Batch size; only used when `inputs` is None.
            dtype: dtype of the state tensors; only used when `inputs` is None.

        Returns:
            A zero tensor of shape `[batch_size] + state_size`, or a nested
            structure of such tensors when `self.state_size` is nested.

        Raises:
            ValueError: If `inputs` is None and `batch_size` or `dtype` is None.
        """
        return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)