# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for convolutional-recurrent layers."""

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine import base_layer
from keras.src.engine.input_spec import InputSpec
from keras.src.layers.rnn.base_rnn import RNN
from keras.src.utils import conv_utils
from keras.src.utils import generic_utils
from keras.src.utils import tf_utils


class ConvRNN(RNN):
    """N-dimensional base class for convolutional-recurrent layers.

    Args:
      rank: Integer, rank of the convolution, e.g. "2" for 2D convolutions.
      cell: An RNN cell instance. An RNN cell is a class that has:
        - a `call(input_at_t, states_at_t)` method, returning
          `(output_at_t, states_at_t_plus_1)`. The call method of the cell
          can also take the optional argument `constants`; see section
          "Note on passing external constants" below.
        - a `state_size` attribute. This can be a single integer (single
          state), in which case it is the number of channels of the
          recurrent state (which should be the same as the number of
          channels of the cell output). It can also be a list/tuple of
          integers (one size per state), in which case the first entry
          (`state_size[0]`) should be the same as the size of the cell
          output.
      return_sequences: Boolean. Whether to return the last output in the
        output sequence, or the full sequence.
      return_state: Boolean. Whether to return the last state in addition
        to the output.
      go_backwards: Boolean (default False). If True, process the input
        sequence backwards and return the reversed sequence.
      stateful: Boolean (default False). If True, the last state for each
        sample at index i in a batch will be used as the initial state for
        the sample of index i in the following batch.
      input_shape: Use this argument to specify the shape of the input when
        this layer is the first one in a model.
    Call arguments:
      inputs: A (3 + `rank`)D tensor (see "Input shape" below).
      mask: Binary tensor of shape `(samples, timesteps)` indicating whether
        a given timestep should be masked.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the
        cell when calling it. This is for use with cells that use dropout.
      initial_state: List of initial state tensors to be passed to the first
        call of the cell.
      constants: List of constant tensors to be passed to the cell at each
        timestep.
    Input shape:
      (3 + `rank`)D tensor with shape `(samples, timesteps, channels,
      img_dimensions...)` if data_format='channels_first', or shape
      `(samples, timesteps, img_dimensions..., channels)` if
      data_format='channels_last'.
    Output shape:
      - If `return_state`: a list of tensors. The first tensor is the
        output. The remaining tensors are the last states, each a
        (2 + `rank`)D tensor with shape `(samples, filters,
        new_img_dimensions...)` if data_format='channels_first', or shape
        `(samples, new_img_dimensions..., filters)` if
        data_format='channels_last'. img_dimension values might have
        changed due to padding.
      - If `return_sequences`: (3 + `rank`)D tensor with shape `(samples,
        timesteps, filters, new_img_dimensions...)` if
        data_format='channels_first', or shape `(samples, timesteps,
        new_img_dimensions..., filters)` if data_format='channels_last'.
      - Else, (2 + `rank`)D tensor with shape `(samples, filters,
        new_img_dimensions...)` if data_format='channels_first', or shape
        `(samples, new_img_dimensions..., filters)` if
        data_format='channels_last'.
    Masking: This layer supports masking for input data with a variable
      number of timesteps.
    Note on using statefulness in RNNs: You can set RNN layers to be
      'stateful', which means that the states computed for the samples in
      one batch will be reused as initial states for the samples in the
      next batch. This assumes a one-to-one mapping between samples in
      different successive batches.
      To enable statefulness:
        - Specify `stateful=True` in the layer constructor.
        - Specify a fixed batch size for your model, by passing
          - if sequential model: `batch_input_shape=(...)` to the first
            layer in your model;
          - if functional model with 1 or more Input layers:
            `batch_shape=(...)` to all the first layers in your model.
          This is the expected shape of your inputs *including the batch
          size*. It should be a tuple of integers, e.g. `(32, 10, 100,
          100, 32)` for a rank-2 convolution. Note that the image
          dimensions should be specified too.
        - Specify `shuffle=False` when calling `fit()`.
      To reset the states of your model, call `.reset_states()` on either
      a specific layer, or on your entire model.
    Note on specifying the initial state of RNNs: You can specify the
      initial state of RNN layers symbolically by calling them with the
      keyword argument `initial_state`. The value of `initial_state` should
      be a tensor or list of tensors representing the initial state of the
      RNN layer. You can specify the initial state of RNN layers
      numerically by calling `reset_states` with the keyword argument
      `states`. The value of `states` should be a numpy array or list of
      numpy arrays representing the initial state of the RNN layer.
    Note on passing external constants to RNNs: You can pass "external"
      constants to the cell using the `constants` keyword argument of the
      `RNN.__call__` (as well as `RNN.call`) method. This requires that the
      `cell.call` method accepts the same keyword argument `constants`.
      Such constants can be used to condition the cell transformation on
      additional static inputs (not changing over time), a.k.a. an
      attention mechanism.
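
    Example (illustrative, using the public `ConvLSTM2D` rank-2 subclass
      of this base class):

      >>> import tensorflow as tf
      >>> layer = tf.keras.layers.ConvLSTM2D(
      ...     filters=16, kernel_size=(3, 3), padding="same",
      ...     return_sequences=True)
      >>> inputs = tf.random.normal((2, 10, 64, 64, 3))
      >>> outputs = layer(inputs)
      >>> outputs.shape
      TensorShape([2, 10, 64, 64, 16])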
122 """

    def __init__(
        self,
        rank,
        cell,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        unroll=False,
        **kwargs,
    ):
        if unroll:
            raise TypeError(
                "Unrolling is not possible with convolutional RNNs. "
                f"Received: unroll={unroll}"
            )
        if isinstance(cell, (list, tuple)):
            # The StackedConvRNN3DCells isn't implemented yet.
            raise TypeError(
                "It is not possible at the moment to "
                "stack convolutional cells. Only pass a single cell "
                "instance as the `cell` argument. Received: "
                f"cell={cell}"
            )
        super().__init__(
            cell,
            return_sequences,
            return_state,
            go_backwards,
            stateful,
            unroll,
            **kwargs,
        )
        self.rank = rank
        # Sequence inputs are (batch, time, *spatial_dims, channels), i.e.
        # rank + 3 dimensions in total.
        self.input_spec = [InputSpec(ndim=rank + 3)]
        self.states = None
        self._num_constants = None

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        cell = self.cell
        if cell.data_format == "channels_first":
            img_dims = input_shape[3:]
        elif cell.data_format == "channels_last":
            img_dims = input_shape[2:-1]

        norm_img_dims = tuple(
            conv_utils.conv_output_length(
                img_dims[idx],
                cell.kernel_size[idx],
                padding=cell.padding,
                stride=cell.strides[idx],
                dilation=cell.dilation_rate[idx],
            )
            for idx in range(len(img_dims))
        )

        if cell.data_format == "channels_first":
            output_shape = input_shape[:2] + (cell.filters,) + norm_img_dims
        elif cell.data_format == "channels_last":
            output_shape = input_shape[:2] + norm_img_dims + (cell.filters,)

        if not self.return_sequences:
            output_shape = output_shape[:1] + output_shape[2:]

        if self.return_state:
            output_shape = [output_shape]
            if cell.data_format == "channels_first":
                output_shape += [
                    (input_shape[0], cell.filters) + norm_img_dims
                    for _ in range(2)
                ]
            elif cell.data_format == "channels_last":
                output_shape += [
                    (input_shape[0],) + norm_img_dims + (cell.filters,)
                    for _ in range(2)
                ]
        return output_shape
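
    # Worked example for the shape logic above (illustrative numbers, not
    # from the original source): with rank=2, a channels_last input of
    # shape (batch, time, 32, 32, 3), filters=8, kernel_size=(3, 3),
    # padding="valid", strides=(1, 1) and dilation_rate=(1, 1),
    # conv_output_length(32, 3, padding="valid", stride=1, dilation=1)
    # returns 30, so the full-sequence output shape is
    # (batch, time, 30, 30, 8) and the last-output shape is
    # (batch, 30, 30, 8).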

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        # Note input_shape will be list of shapes of initial states and
        # constants if these are passed in __call__.
        if self._num_constants is not None:
            constants_shape = input_shape[-self._num_constants :]
        else:
            constants_shape = None

        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        batch_size = input_shape[0] if self.stateful else None
        self.input_spec[0] = InputSpec(
            shape=(batch_size, None) + input_shape[2 : self.rank + 3]
        )

        # allow cell (if layer) to build before we set or validate state_spec
        if isinstance(self.cell, base_layer.Layer):
            step_input_shape = (input_shape[0],) + input_shape[2:]
            if constants_shape is not None:
                self.cell.build([step_input_shape] + constants_shape)
            else:
                self.cell.build(step_input_shape)

        # set or validate state_spec
        if hasattr(self.cell.state_size, "__len__"):
            state_size = list(self.cell.state_size)
        else:
            state_size = [self.cell.state_size]

        if self.state_spec is not None:
            # initial_state was passed in call, check compatibility
            if self.cell.data_format == "channels_first":
                ch_dim = 1
            elif self.cell.data_format == "channels_last":
                ch_dim = self.rank + 1
            if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
                raise ValueError(
                    "An `initial_state` was passed that is not compatible "
                    "with `cell.state_size`. Received state shapes "
                    f"{[spec.shape for spec in self.state_spec]}. "
                    f"However `cell.state_size` is {self.cell.state_size}"
                )
        else:
            img_dims = tuple(None for _ in range(self.rank))
            if self.cell.data_format == "channels_first":
                self.state_spec = [
                    InputSpec(shape=(None, dim) + img_dims)
                    for dim in state_size
                ]
            elif self.cell.data_format == "channels_last":
                self.state_spec = [
                    InputSpec(shape=(None,) + img_dims + (dim,))
                    for dim in state_size
                ]
        if self.stateful:
            self.reset_states()
        self.built = True
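
    # For example (illustrative, not from the original source): building a
    # rank=2, channels_last layer around a single-state cell with
    # state_size=8, with no `initial_state` passed in `__call__`, leaves
    # state_spec as [InputSpec(shape=(None, None, None, 8))], i.e. only
    # the channel dimension of each state is pinned down at build time.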

    def get_initial_state(self, inputs):
        # (samples, timesteps, img_dims..., filters)
        initial_state = backend.zeros_like(inputs)
        # (samples, img_dims..., filters)
        initial_state = backend.sum(initial_state, axis=1)
        shape = list(self.cell.kernel_shape)
        shape[-1] = self.cell.filters
        # Convolving the zero tensor with a zero kernel yields a zero state
        # whose spatial dimensions already account for the cell's padding
        # and strides.
        initial_state = self.cell.input_conv(
            initial_state,
            tf.zeros(tuple(shape), initial_state.dtype),
            padding=self.cell.padding,
        )

        if hasattr(self.cell.state_size, "__len__"):
            return [initial_state for _ in self.cell.state_size]
        else:
            return [initial_state]

    def call(
        self,
        inputs,
        mask=None,
        training=None,
        initial_state=None,
        constants=None,
    ):
        # note that the .build() method of subclasses MUST define
        # self.input_spec and self.state_spec with complete input shapes.
        inputs, initial_state, constants = self._process_inputs(
            inputs, initial_state, constants
        )

        if isinstance(mask, list):
            mask = mask[0]
        timesteps = backend.int_shape(inputs)[1]

        kwargs = {}
        if generic_utils.has_arg(self.cell.call, "training"):
            kwargs["training"] = training

        if constants:
            if not generic_utils.has_arg(self.cell.call, "constants"):
                raise ValueError(
                    f"RNN cell {self.cell} does not support constants. "
                    f"Received: constants={constants}"
                )

            def step(inputs, states):
                constants = states[-self._num_constants :]
                states = states[: -self._num_constants]
                return self.cell.call(
                    inputs, states, constants=constants, **kwargs
                )

        else:

            def step(inputs, states):
                return self.cell.call(inputs, states, **kwargs)

        last_output, outputs, states = backend.rnn(
            step,
            inputs,
            initial_state,
            constants=constants,
            go_backwards=self.go_backwards,
            mask=mask,
            input_length=timesteps,
            return_all_outputs=self.return_sequences,
        )
        if self.stateful:
            updates = [
                backend.update(self_state, state)
                for self_state, state in zip(self.states, states)
            ]
            self.add_update(updates)

        if self.return_sequences:
            output = outputs
        else:
            output = last_output

        if self.return_state:
            if not isinstance(states, (list, tuple)):
                states = [states]
            else:
                states = list(states)
            return [output] + states
        return output

    def reset_states(self, states=None):
        if not self.stateful:
            raise AttributeError("Layer must be stateful.")
        input_shape = self.input_spec[0].shape
        state_shape = self.compute_output_shape(input_shape)
        if self.return_state:
            state_shape = state_shape[0]
        if self.return_sequences:
            state_shape = state_shape[:1].concatenate(state_shape[2:])
        if None in state_shape:
            raise ValueError(
                "If a RNN is stateful, it needs to know "
                "its batch size. Specify the batch size "
                "of your input tensors: \n"
                "- If using a Sequential model, "
                "specify the batch size by passing "
                "a `batch_input_shape` "
                "argument to your first layer.\n"
                "- If using the functional API, specify "
                "the batch size by passing a "
                "`batch_shape` argument to your Input layer.\n"
                "The same thing goes for the number of rows and "
                "columns."
            )

        # helper function
        def get_tuple_shape(nb_channels):
            result = list(state_shape)
            if self.cell.data_format == "channels_first":
                result[1] = nb_channels
            elif self.cell.data_format == "channels_last":
                result[self.rank + 1] = nb_channels
            else:
                raise KeyError(
                    "Cell data format must be one of "
                    '{"channels_first", "channels_last"}. Received: '
                    f"cell.data_format={self.cell.data_format}"
                )
            return tuple(result)

        # initialize state if None
        if self.states[0] is None:
            if hasattr(self.cell.state_size, "__len__"):
                self.states = [
                    backend.zeros(get_tuple_shape(dim))
                    for dim in self.cell.state_size
                ]
            else:
                self.states = [
                    backend.zeros(get_tuple_shape(self.cell.state_size))
                ]
        elif states is None:
            if hasattr(self.cell.state_size, "__len__"):
                for state, dim in zip(self.states, self.cell.state_size):
                    backend.set_value(state, np.zeros(get_tuple_shape(dim)))
            else:
                backend.set_value(
                    self.states[0],
                    np.zeros(get_tuple_shape(self.cell.state_size)),
                )
        else:
            if not isinstance(states, (list, tuple)):
                states = [states]
            if len(states) != len(self.states):
                raise ValueError(
                    f"Layer {self.name} expects {len(self.states)} states, "
                    f"but it received {len(states)} state values. "
                    f"States received: {states}"
                )
            for index, (value, state) in enumerate(zip(states, self.states)):
                if hasattr(self.cell.state_size, "__len__"):
                    dim = self.cell.state_size[index]
                else:
                    dim = self.cell.state_size
                if value.shape != get_tuple_shape(dim):
                    raise ValueError(
                        f"State {index} is incompatible with layer "
                        f"{self.name}: expected shape={get_tuple_shape(dim)}, "
                        f"found shape={value.shape}"
                    )
                backend.set_value(state, value)
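

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: it
    # exercises the ConvRNN machinery through `ConvLSTM2D`, the public
    # rank-2 subclass of this base class.
    import tensorflow as tf

    layer = tf.keras.layers.ConvLSTM2D(
        filters=4,
        kernel_size=(3, 3),
        padding="same",
        return_state=True,
    )
    x = tf.random.normal((2, 5, 16, 16, 3))  # (batch, time, H, W, C)
    # With return_state=True and return_sequences=False, the layer returns
    # [last_output, h_state, c_state].
    output, h_state, c_state = layer(x)
    print(output.shape)   # (2, 16, 16, 4)
    print(h_state.shape)  # (2, 16, 16, 4)
    print(c_state.shape)  # (2, 16, 16, 4)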