Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/rnn/nas_cell.py: 23%
77 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implements NAS Cell."""
17import tensorflow as tf
18from typeguard import typechecked
20from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
21from tensorflow_addons.utils.types import (
22 FloatTensorLike,
23 TensorLike,
24 Initializer,
25)
26from typing import Optional
@tf.keras.utils.register_keras_serializable(package="Addons")
class NASCell(AbstractRNNCell):
    """Neural Architecture Search (NAS) recurrent network cell.

    This implements the recurrent cell from the paper:

      https://arxiv.org/abs/1611.01578

    Barret Zoph and Quoc V. Le.
    "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.

    The class uses an optional projection layer.

    The recurrent state is the list `[c, m]`:
    - `c` is the cell (carry) state, of size `units`.
    - `m` is the hidden (memory) state, which is also the cell's output. Its
      size is `projection` when a projection is configured and `units`
      otherwise.

    Example:

    >>> inputs = np.random.random([30,23,9]).astype(np.float32)
    >>> NASCell = tfa.rnn.NASCell(4)
    >>> rnn = tf.keras.layers.RNN(NASCell, return_sequences=True, return_state=True)
    >>> outputs, carry_state, memory_state = rnn(inputs)
    >>> outputs.shape
    TensorShape([30, 23, 4])
    >>> carry_state.shape
    TensorShape([30, 4])
    >>> memory_state.shape
    TensorShape([30, 4])
    """

    # NAS cell's architecture base: the number of parallel branches the
    # input and recurrent projections are split into.
    _NAS_BASE = 8

    @typechecked
    def __init__(
        self,
        units: TensorLike,
        projection: Optional[FloatTensorLike] = None,
        use_bias: bool = False,
        kernel_initializer: Initializer = "glorot_uniform",
        recurrent_initializer: Initializer = "glorot_uniform",
        projection_initializer: Initializer = "glorot_uniform",
        bias_initializer: Initializer = "zeros",
        **kwargs,
    ):
        """Initialize the parameters for a NAS cell.

        Args:
            units: int, The number of units in the NAS cell.
            projection: (optional) int, The output dimensionality for the
                projection matrices. If None, no projection is performed.
            use_bias: (optional) bool, If True then use biases within the cell.
                This is False by default.
            kernel_initializer: Initializer for kernel weight.
            recurrent_initializer: Initializer for recurrent kernel weight.
            projection_initializer: Initializer for projection weight, used when
                projection is not None.
            bias_initializer: Initializer for bias, used when use_bias is True.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(**kwargs)
        self.units = units
        self.projection = projection
        self.use_bias = use_bias
        self.kernel_initializer = kernel_initializer
        self.recurrent_initializer = recurrent_initializer
        self.projection_initializer = projection_initializer
        self.bias_initializer = bias_initializer

        # State is [c, m]; m (and the output) is projected down when a
        # projection is configured.
        if projection is not None:
            self._state_size = [units, projection]
            self._output_size = projection
        else:
            self._state_size = [units, units]
            self._output_size = units

    @property
    def state_size(self):
        """Sizes of the `[c, m]` state tensors."""
        return self._state_size

    @property
    def output_size(self):
        """Size of the cell output (`projection` if set, else `units`)."""
        return self._output_size

    def build(self, inputs_shape):
        """Create the cell's weights.

        Args:
            inputs_shape: Shape of one time step of input. It must be rank 2,
                and its last dimension must be statically known.

        Raises:
            ValueError: If `inputs_shape` is not rank 2 or its last dimension
                is unknown.
        """
        input_size = tf.compat.dimension_value(
            tf.TensorShape(inputs_shape).with_rank(2)[1]
        )
        if input_size is None:
            raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

        # Variables for the NAS cell. `recurrent_kernel` is all matrices
        # multiplying the hidden state and `kernel` is all matrices multiplying
        # the inputs.
        self.recurrent_kernel = self.add_weight(
            name="recurrent_kernel",
            shape=[self.output_size, self._NAS_BASE * self.units],
            initializer=self.recurrent_initializer,
        )
        self.kernel = self.add_weight(
            name="kernel",
            shape=[input_size, self._NAS_BASE * self.units],
            initializer=self.kernel_initializer,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=[self._NAS_BASE * self.units],
                initializer=self.bias_initializer,
            )
        # Projection layer if specified
        if self.projection is not None:
            self.projection_weights = self.add_weight(
                name="projection_weights",
                shape=[self.units, self.projection],
                initializer=self.projection_initializer,
            )

        self.built = True

    def call(self, inputs, state):
        """Run one step of NAS Cell.

        Args:
            inputs: input Tensor, 2D, `[batch, input_size]`.
            state: A list of two 2-D state Tensors `[c_prev, m_prev]`, with
                column sizes `units` and `output_size` respectively.

        Returns:
            A tuple containing:
            - A `2-D, [batch x output_dim]`, Tensor representing the output of
              the NAS Cell after reading `inputs` when previous state was
              `state`.
              Here output_dim is:
                 projection if projection was set, units otherwise.
            - A list `[new_c, new_m]` representing the new state of NAS Cell
              after reading `inputs` when the previous state was `state`.
              Same type and shape(s) as `state`.
        """
        sigmoid = tf.math.sigmoid
        tanh = tf.math.tanh
        relu = tf.nn.relu

        c_prev, m_prev = state

        m_matrix = tf.matmul(m_prev, self.recurrent_kernel)
        inputs_matrix = tf.matmul(inputs, self.kernel)

        # The bias (when enabled) is applied only on the recurrent path; the
        # input path is summed with it inside each branch below.
        if self.use_bias:
            m_matrix = tf.nn.bias_add(m_matrix, self.bias)

        # The NAS cell branches into 8 different splits for both the hidden
        # state and the input
        m_matrix_splits = tf.split(
            axis=1, num_or_size_splits=self._NAS_BASE, value=m_matrix
        )
        inputs_matrix_splits = tf.split(
            axis=1, num_or_size_splits=self._NAS_BASE, value=inputs_matrix
        )

        # First layer
        layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
        layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
        layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
        # Branch 3 combines input and hidden multiplicatively, not additively.
        layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
        layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
        layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
        layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
        layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])

        # Second layer
        l2_0 = tanh(layer1_0 * layer1_1)
        l2_1 = tanh(layer1_2 + layer1_3)
        l2_2 = tanh(layer1_4 * layer1_5)
        l2_3 = sigmoid(layer1_6 + layer1_7)

        # Inject the previous cell state
        l2_0 = tanh(l2_0 + c_prev)

        # Third layer: the new cell state is taken directly from this product,
        # before the final nonlinearity.
        new_c = l2_0 * l2_1
        l3_1 = tanh(l2_2 + l2_3)

        # Final layer
        new_m = tanh(new_c * l3_1)

        # Projection layer if specified
        if self.projection is not None:
            new_m = tf.matmul(new_m, self.projection_weights)

        return new_m, [new_c, new_m]

    def get_config(self):
        """Return the cell's constructor arguments merged with the base config."""
        config = {
            "units": self.units,
            "projection": self.projection,
            "use_bias": self.use_bias,
            "kernel_initializer": self.kernel_initializer,
            "recurrent_initializer": self.recurrent_initializer,
            "bias_initializer": self.bias_initializer,
            "projection_initializer": self.projection_initializer,
        }
        base_config = super().get_config()
        return {**base_config, **config}