Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/layers/convolutional_recurrent.py: 24%
339 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=g-classes-have-attributes
17"""Convolutional-recurrent layers."""
19import numpy as np
21from tensorflow.python.keras import activations
22from tensorflow.python.keras import backend
23from tensorflow.python.keras import constraints
24from tensorflow.python.keras import initializers
25from tensorflow.python.keras import regularizers
26from tensorflow.python.keras.engine.base_layer import Layer
27from tensorflow.python.keras.engine.input_spec import InputSpec
28from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
29from tensorflow.python.keras.layers.recurrent import RNN
30from tensorflow.python.keras.utils import conv_utils
31from tensorflow.python.keras.utils import generic_utils
32from tensorflow.python.keras.utils import tf_utils
33from tensorflow.python.ops import array_ops
34from tensorflow.python.util.tf_export import keras_export
class ConvRNN2D(RNN):
  """Base class for convolutional-recurrent layers.

  Args:
    cell: A RNN cell instance. A RNN cell is a class that has:
      - a `call(input_at_t, states_at_t)` method, returning
        `(output_at_t, states_at_t_plus_1)`. The call method of the
        cell can also take the optional argument `constants`, see
        section "Note on passing external constants" below.
      - a `state_size` attribute. This can be a single integer
        (single state) in which case it is
        the number of channels of the recurrent state
        (which should be the same as the number of channels of the cell
        output). This can also be a list/tuple of integers
        (one size per state). In this case, the first entry
        (`state_size[0]`) should be the same as
        the size of the cell output.
      In addition, this layer reads the convolution attributes
      `data_format`, `filters`, `kernel_size`, `strides`, `padding`,
      `dilation_rate` and `kernel_shape` from the cell and calls its
      `input_conv` method, so the cell must provide them (as
      `ConvLSTM2DCell` does).
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    unroll: Boolean (default False). Must be False: unrolling is not
      supported for convolutional RNNs and raises a `TypeError`.
    input_shape: Use this argument to specify the shape of the
      input when this layer is the first one in a model.

  Call arguments:
    inputs: A 5D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is for use with cells that use dropout.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
    constants: List of constant tensors to be passed to the cell at each
      timestep.

  Input shape:
    5D tensor with shape:
    `(samples, timesteps, channels, rows, cols)`
    if data_format='channels_first' or 5D tensor with shape:
    `(samples, timesteps, rows, cols, channels)`
    if data_format='channels_last'.

  Output shape:
    - If `return_state`: a list of tensors. The first tensor is
      the output. The remaining tensors are the last states (one per
      entry of `cell.state_size`),
      each 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.
      `rows` and `cols` values might have changed due to padding.
    - If `return_sequences`: 5D tensor with shape:
      `(samples, timesteps, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 5D tensor with shape:
      `(samples, timesteps, new_rows, new_cols, filters)`
      if data_format='channels_last'.
    - Else, 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.

  Masking:
    This layer supports masking for input data with a variable number
    of timesteps.

  Note on using statefulness in RNNs:
    You can set RNN layers to be 'stateful', which means that the states
    computed for the samples in one batch will be reused as initial states
    for the samples in the next batch. This assumes a one-to-one mapping
    between samples in different successive batches.
    To enable statefulness:
      - Specify `stateful=True` in the layer constructor.
      - Specify a fixed batch size for your model, by passing
        - If sequential model:
          `batch_input_shape=(...)` to the first layer in your model.
        - If functional model with 1 or more Input layers:
          `batch_shape=(...)` to all the first layers in your model.
        This is the expected shape of your inputs
        *including the batch size*.
        It should be a tuple of integers,
        e.g. `(32, 10, 100, 100, 32)`.
        Note that the number of rows and columns should be specified
        too.
      - Specify `shuffle=False` when calling fit().
    To reset the states of your model, call `.reset_states()` on either
    a specific layer, or on your entire model.

  Note on specifying the initial state of RNNs:
    You can specify the initial state of RNN layers symbolically by
    calling them with the keyword argument `initial_state`. The value of
    `initial_state` should be a tensor or list of tensors representing
    the initial state of the RNN layer.
    You can specify the initial state of RNN layers numerically by
    calling `reset_states` with the keyword argument `states`. The value of
    `states` should be a numpy array or list of numpy arrays representing
    the initial state of the RNN layer.

  Note on passing external constants to RNNs:
    You can pass "external" constants to the cell using the `constants`
    keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
    requires that the `cell.call` method accepts the same keyword argument
    `constants`. Such constants can be used to condition the cell
    transformation on additional static inputs (not changing over time),
    a.k.a. an attention mechanism.
  """

  def __init__(self,
               cell,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               unroll=False,
               **kwargs):
    if unroll:
      raise TypeError('Unrolling isn\'t possible with '
                      'convolutional RNNs.')
    if isinstance(cell, (list, tuple)):
      # The StackedConvRNN2DCells isn't implemented yet.
      raise TypeError('It is not possible at the moment to '
                      'stack convolutional cells.')
    super(ConvRNN2D, self).__init__(cell,
                                    return_sequences,
                                    return_state,
                                    go_backwards,
                                    stateful,
                                    unroll,
                                    **kwargs)
    self.input_spec = [InputSpec(ndim=5)]
    self.states = None
    self._num_constants = None

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Computes the output shape(s) of the layer.

    Args:
      input_shape: Shape of the 5D input, or a list whose first entry is
        that shape (when initial states / constants are passed as well).

    Returns:
      The output shape. When `return_state` is True, a list made of the
      output shape followed by one 4D shape per cell state.
    """
    if isinstance(input_shape, list):
      input_shape = input_shape[0]

    cell = self.cell
    if cell.data_format == 'channels_first':
      rows = input_shape[3]
      cols = input_shape[4]
    elif cell.data_format == 'channels_last':
      rows = input_shape[2]
      cols = input_shape[3]
    rows = conv_utils.conv_output_length(rows,
                                         cell.kernel_size[0],
                                         padding=cell.padding,
                                         stride=cell.strides[0],
                                         dilation=cell.dilation_rate[0])
    cols = conv_utils.conv_output_length(cols,
                                         cell.kernel_size[1],
                                         padding=cell.padding,
                                         stride=cell.strides[1],
                                         dilation=cell.dilation_rate[1])

    if cell.data_format == 'channels_first':
      output_shape = input_shape[:2] + (cell.filters, rows, cols)
      state_shape = (input_shape[0], cell.filters, rows, cols)
    elif cell.data_format == 'channels_last':
      output_shape = input_shape[:2] + (rows, cols, cell.filters)
      state_shape = (input_shape[0], rows, cols, cell.filters)

    if not self.return_sequences:
      # Drop the time axis.
      output_shape = output_shape[:1] + output_shape[2:]

    if self.return_state:
      # `call` returns one tensor per cell state, so report exactly that
      # many state shapes rather than assuming an LSTM-style (h, c) pair.
      if hasattr(cell.state_size, '__len__'):
        num_states = len(cell.state_size)
      else:
        num_states = 1
      output_shape = [output_shape] + [state_shape] * num_states
    return output_shape

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    """Builds the wrapped cell and sets `input_spec` / `state_spec`."""
    # Note input_shape will be list of shapes of initial states and
    # constants if these are passed in __call__.
    if self._num_constants is not None:
      constants_shape = input_shape[-self._num_constants:]  # pylint: disable=E1130
    else:
      constants_shape = None

    if isinstance(input_shape, list):
      input_shape = input_shape[0]

    # The batch dimension is only pinned for stateful layers; time is free.
    batch_size = input_shape[0] if self.stateful else None
    self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:5])

    # allow cell (if layer) to build before we set or validate state_spec
    if isinstance(self.cell, Layer):
      # Per-step input: drop the time axis.
      step_input_shape = (input_shape[0],) + input_shape[2:]
      if constants_shape is not None:
        self.cell.build([step_input_shape] + constants_shape)
      else:
        self.cell.build(step_input_shape)

    # set or validate state_spec
    if hasattr(self.cell.state_size, '__len__'):
      state_size = list(self.cell.state_size)
    else:
      state_size = [self.cell.state_size]

    if self.state_spec is not None:
      # initial_state was passed in call, check compatibility
      if self.cell.data_format == 'channels_first':
        ch_dim = 1
      elif self.cell.data_format == 'channels_last':
        ch_dim = 3
      if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
        raise ValueError(
            'An initial_state was passed that is not compatible with '
            '`cell.state_size`. Received `state_spec`={}; '
            'However `cell.state_size` is '
            '{}'.format([spec.shape for spec in self.state_spec],
                        self.cell.state_size))
    else:
      if self.cell.data_format == 'channels_first':
        self.state_spec = [InputSpec(shape=(None, dim, None, None))
                           for dim in state_size]
      elif self.cell.data_format == 'channels_last':
        self.state_spec = [InputSpec(shape=(None, None, None, dim))
                           for dim in state_size]
    if self.stateful:
      self.reset_states()
    self.built = True

  def get_initial_state(self, inputs):
    """Returns all-zero initial states, one per entry of `cell.state_size`.

    The zeros are produced by running the cell's input convolution over a
    zero tensor, so they have the post-convolution spatial size and
    `cell.filters` channels.
    """
    # Zeros shaped like the 5D `inputs`...
    initial_state = backend.zeros_like(inputs)
    # ...collapsed over the time axis into a 4D zero tensor.
    initial_state = backend.sum(initial_state, axis=1)
    shape = list(self.cell.kernel_shape)
    shape[-1] = self.cell.filters
    initial_state = self.cell.input_conv(initial_state,
                                         array_ops.zeros(tuple(shape),
                                                         initial_state.dtype),
                                         padding=self.cell.padding)

    if hasattr(self.cell.state_size, '__len__'):
      return [initial_state for _ in self.cell.state_size]
    else:
      return [initial_state]

  def call(self,
           inputs,
           mask=None,
           training=None,
           initial_state=None,
           constants=None):
    """Runs the cell over the time axis of `inputs` via `backend.rnn`."""
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    inputs, initial_state, constants = self._process_inputs(
        inputs, initial_state, constants)

    if isinstance(mask, list):
      mask = mask[0]
    timesteps = backend.int_shape(inputs)[1]

    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
      kwargs['training'] = training

    if constants:
      if not generic_utils.has_arg(self.cell.call, 'constants'):
        raise ValueError('RNN cell does not support constants')

      def step(inputs, states):
        # backend.rnn appends the constants after the recurrent states.
        constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
        states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
        return self.cell.call(inputs, states, constants=constants, **kwargs)
    else:
      def step(inputs, states):
        return self.cell.call(inputs, states, **kwargs)

    last_output, outputs, states = backend.rnn(step,
                                               inputs,
                                               initial_state,
                                               constants=constants,
                                               go_backwards=self.go_backwards,
                                               mask=mask,
                                               input_length=timesteps)
    if self.stateful:
      updates = [
          backend.update(self_state, state)
          for self_state, state in zip(self.states, states)
      ]
      self.add_update(updates)

    if self.return_sequences:
      output = outputs
    else:
      output = last_output

    if self.return_state:
      if not isinstance(states, (list, tuple)):
        states = [states]
      else:
        states = list(states)
      return [output] + states
    else:
      return output

  def reset_states(self, states=None):
    """Resets the stored states of a stateful layer.

    Args:
      states: Optional numpy array or list of numpy arrays to load as the new
        states. When None, existing states are set to zeros (or created as
        zeros on first use).

    Raises:
      AttributeError: If the layer is not stateful.
      ValueError: If the batch size (or rows/cols) is unknown, or if
        `states` has the wrong number of entries or wrong shapes.
    """
    if not self.stateful:
      raise AttributeError('Layer must be stateful.')
    input_shape = self.input_spec[0].shape
    state_shape = self.compute_output_shape(input_shape)
    if self.return_state:
      state_shape = state_shape[0]
    if self.return_sequences:
      # Drop the time axis to get the 4D per-step shape.
      state_shape = state_shape[:1].concatenate(state_shape[2:])
    if None in state_shape:
      raise ValueError('If a RNN is stateful, it needs to know '
                       'its batch size. Specify the batch size '
                       'of your input tensors: \n'
                       '- If using a Sequential model, '
                       'specify the batch size by passing '
                       'a `batch_input_shape` '
                       'argument to your first layer.\n'
                       '- If using the functional API, specify '
                       'the batch size by passing a '
                       '`batch_shape` argument to your Input layer.\n'
                       'The same thing goes for the number of rows and '
                       'columns.')

    # helper function
    def get_tuple_shape(nb_channels):
      """Returns `state_shape` with its channel axis set to `nb_channels`."""
      result = list(state_shape)
      if self.cell.data_format == 'channels_first':
        result[1] = nb_channels
      elif self.cell.data_format == 'channels_last':
        result[3] = nb_channels
      else:
        raise KeyError
      return tuple(result)

    # initialize state if None
    if self.states[0] is None:
      if hasattr(self.cell.state_size, '__len__'):
        self.states = [backend.zeros(get_tuple_shape(dim))
                       for dim in self.cell.state_size]
      else:
        self.states = [backend.zeros(get_tuple_shape(self.cell.state_size))]
    elif states is None:
      if hasattr(self.cell.state_size, '__len__'):
        for state, dim in zip(self.states, self.cell.state_size):
          backend.set_value(state, np.zeros(get_tuple_shape(dim)))
      else:
        backend.set_value(self.states[0],
                          np.zeros(get_tuple_shape(self.cell.state_size)))
    else:
      if not isinstance(states, (list, tuple)):
        states = [states]
      if len(states) != len(self.states):
        raise ValueError('Layer ' + self.name + ' expects ' +
                         str(len(self.states)) + ' states, ' +
                         'but it received ' + str(len(states)) +
                         ' state values. Input received: ' + str(states))
      for index, (value, state) in enumerate(zip(states, self.states)):
        if hasattr(self.cell.state_size, '__len__'):
          dim = self.cell.state_size[index]
        else:
          dim = self.cell.state_size
        if value.shape != get_tuple_shape(dim):
          raise ValueError('State ' + str(index) +
                           ' is incompatible with layer ' +
                           self.name + ': expected shape=' +
                           str(get_tuple_shape(dim)) +
                           ', found shape=' + str(value.shape))
        # TODO(anjalisridhar): consider batch calls to `set_value`.
        backend.set_value(state, value)
class ConvLSTM2DCell(DropoutRNNCellMixin, Layer):
  """Cell class for the ConvLSTM2D layer.

  Args:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of 2 integers, specifying the
      dimensions of the convolution window.
    strides: An integer or tuple/list of 2 integers,
      specifying the strides of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1. Strides apply only to the input
      convolution; the recurrent convolution always uses stride 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input. Applies only to the input
      convolution; the recurrent convolution always uses `"same"` padding.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: An integer or tuple/list of 2 integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
      Applies only to the input convolution.
    activation: Activation function applied to the candidate cell update
      and to the new cell state before the output gate. Defaults to
      hyperbolic tangent (`tanh`). If `None` is passed, no activation is
      applied (ie. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function used for the input, forget
      and output gates. Defaults to `hard_sigmoid`.
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean.
      If True, add 1 to the bias of the forget gate at initialization.
      Use in combination with `bias_initializer="zeros"`.
      Only has an effect when `use_bias` is True.
      This is recommended in [Jozefowicz et al., 2015](
      http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1 (values outside are clipped).
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1 (values outside are clipped).
      Fraction of the units to drop for
      the linear transformation of the recurrent state.

  Call arguments:
    inputs: A 4D tensor.
    states: List of state tensors corresponding to the previous timestep,
      `[h_tm1, c_tm1]` (previous output and previous carry state).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. Only relevant when `dropout` or
      `recurrent_dropout` is used.

  Returns:
    A tuple `(h, [h, c])`: the new output and the new state list.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
    super(ConvLSTM2DCell, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2,
                                                    'dilation_rate')
    self.activation = activations.get(activation)
    self.recurrent_activation = activations.get(recurrent_activation)
    self.use_bias = use_bias

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.unit_forget_bias = unit_forget_bias

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    # Clip dropout rates into [0, 1].
    self.dropout = min(1., max(0., dropout))
    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
    # Two states, (h, c), each with `filters` channels.
    self.state_size = (self.filters, self.filters)

  def build(self, input_shape):
    """Creates the kernel, recurrent kernel and (optional) bias weights.

    The last axis of every weight stacks the four gates in the order
    input, forget, cell candidate, output (the order `call` splits them).

    Raises:
      ValueError: If the channel dimension of `input_shape` is None.
    """
    if self.data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = -1
    if input_shape[channel_axis] is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found `None`.')
    input_dim = input_shape[channel_axis]
    kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
    # Kept on the instance; ConvRNN2D.get_initial_state reads it.
    self.kernel_shape = kernel_shape
    recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)

    self.kernel = self.add_weight(shape=kernel_shape,
                                  initializer=self.kernel_initializer,
                                  name='kernel',
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
    self.recurrent_kernel = self.add_weight(
        shape=recurrent_kernel_shape,
        initializer=self.recurrent_initializer,
        name='recurrent_kernel',
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.use_bias:
      if self.unit_forget_bias:

        def bias_initializer(_, *args, **kwargs):
          # [input | forget (ones) | candidate + output] slices.
          return backend.concatenate([
              self.bias_initializer((self.filters,), *args, **kwargs),
              initializers.get('ones')((self.filters,), *args, **kwargs),
              self.bias_initializer((self.filters * 2,), *args, **kwargs),
          ])
      else:
        bias_initializer = self.bias_initializer
      self.bias = self.add_weight(
          shape=(self.filters * 4,),
          name='bias',
          initializer=bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs, states, training=None):
    """Computes one ConvLSTM step; returns `(h, [h, c])`."""
    h_tm1 = states[0]  # previous memory state
    c_tm1 = states[1]  # previous carry state

    # dropout matrices for input units
    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
    # dropout matrices for recurrent units
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=4)

    if 0 < self.dropout < 1.:
      inputs_i = inputs * dp_mask[0]
      inputs_f = inputs * dp_mask[1]
      inputs_c = inputs * dp_mask[2]
      inputs_o = inputs * dp_mask[3]
    else:
      inputs_i = inputs
      inputs_f = inputs
      inputs_c = inputs
      inputs_o = inputs

    if 0 < self.recurrent_dropout < 1.:
      h_tm1_i = h_tm1 * rec_dp_mask[0]
      h_tm1_f = h_tm1 * rec_dp_mask[1]
      h_tm1_c = h_tm1 * rec_dp_mask[2]
      h_tm1_o = h_tm1 * rec_dp_mask[3]
    else:
      h_tm1_i = h_tm1
      h_tm1_f = h_tm1
      h_tm1_c = h_tm1
      h_tm1_o = h_tm1

    # Split the stacked gate weights along the output-channel axis.
    (kernel_i, kernel_f,
     kernel_c, kernel_o) = array_ops.split(self.kernel, 4, axis=3)
    (recurrent_kernel_i,
     recurrent_kernel_f,
     recurrent_kernel_c,
     recurrent_kernel_o) = array_ops.split(self.recurrent_kernel, 4, axis=3)

    if self.use_bias:
      bias_i, bias_f, bias_c, bias_o = array_ops.split(self.bias, 4)
    else:
      bias_i, bias_f, bias_c, bias_o = None, None, None, None

    x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
    x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
    x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
    x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
    h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
    h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
    h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
    h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)

    i = self.recurrent_activation(x_i + h_i)
    f = self.recurrent_activation(x_f + h_f)
    c = f * c_tm1 + i * self.activation(x_c + h_c)
    o = self.recurrent_activation(x_o + h_o)
    h = o * self.activation(c)
    return h, [h, c]

  def input_conv(self, x, w, b=None, padding='valid'):
    """Convolves the input with the cell's strides/dilation, adding `b` if set."""
    conv_out = backend.conv2d(x, w, strides=self.strides,
                              padding=padding,
                              data_format=self.data_format,
                              dilation_rate=self.dilation_rate)
    if b is not None:
      conv_out = backend.bias_add(conv_out, b,
                                  data_format=self.data_format)
    return conv_out

  def recurrent_conv(self, x, w):
    """Convolves the recurrent state with stride 1 and 'same' padding."""
    conv_out = backend.conv2d(x, w, strides=(1, 1),
                              padding='same',
                              data_format=self.data_format)
    return conv_out

  def get_config(self):
    """Returns the serializable constructor arguments merged with the base config."""
    config = {'filters': self.filters,
              'kernel_size': self.kernel_size,
              'strides': self.strides,
              'padding': self.padding,
              'data_format': self.data_format,
              'dilation_rate': self.dilation_rate,
              'activation': activations.serialize(self.activation),
              'recurrent_activation': activations.serialize(
                  self.recurrent_activation),
              'use_bias': self.use_bias,
              'kernel_initializer': initializers.serialize(
                  self.kernel_initializer),
              'recurrent_initializer': initializers.serialize(
                  self.recurrent_initializer),
              'bias_initializer': initializers.serialize(self.bias_initializer),
              'unit_forget_bias': self.unit_forget_bias,
              'kernel_regularizer': regularizers.serialize(
                  self.kernel_regularizer),
              'recurrent_regularizer': regularizers.serialize(
                  self.recurrent_regularizer),
              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
              'kernel_constraint': constraints.serialize(
                  self.kernel_constraint),
              'recurrent_constraint': constraints.serialize(
                  self.recurrent_constraint),
              'bias_constraint': constraints.serialize(self.bias_constraint),
              'dropout': self.dropout,
              'recurrent_dropout': self.recurrent_dropout}
    base_config = super(ConvLSTM2DCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
696@keras_export('keras.layers.ConvLSTM2D')
697class ConvLSTM2D(ConvRNN2D):
698 """2D Convolutional LSTM layer.
700 A convolutional LSTM is similar to an LSTM, but the input transformations
701 and recurrent transformations are both convolutional. This layer is typically
702 used to process timeseries of images (i.e. video-like data).
704 It is known to perform well for weather data forecasting,
705 using inputs that are timeseries of 2D grids of sensor values.
706 It isn't usually applied to regular video data, due to its high computational
707 cost.
709 Args:
710 filters: Integer, the dimensionality of the output space
711 (i.e. the number of output filters in the convolution).
712 kernel_size: An integer or tuple/list of n integers, specifying the
713 dimensions of the convolution window.
714 strides: An integer or tuple/list of n integers,
715 specifying the strides of the convolution.
716 Specifying any stride value != 1 is incompatible with specifying
717 any `dilation_rate` value != 1.
718 padding: One of `"valid"` or `"same"` (case-insensitive).
719 `"valid"` means no padding. `"same"` results in padding evenly to
720 the left/right or up/down of the input such that output has the same
721 height/width dimension as the input.
722 data_format: A string,
723 one of `channels_last` (default) or `channels_first`.
724 The ordering of the dimensions in the inputs.
725 `channels_last` corresponds to inputs with shape
726 `(batch, time, ..., channels)`
727 while `channels_first` corresponds to
728 inputs with shape `(batch, time, channels, ...)`.
729 It defaults to the `image_data_format` value found in your
730 Keras config file at `~/.keras/keras.json`.
731 If you never set it, then it will be "channels_last".
732 dilation_rate: An integer or tuple/list of n integers, specifying
733 the dilation rate to use for dilated convolution.
734 Currently, specifying any `dilation_rate` value != 1 is
735 incompatible with specifying any `strides` value != 1.
736 activation: Activation function to use.
737 By default hyperbolic tangent activation function is applied
738 (`tanh(x)`).
739 recurrent_activation: Activation function to use
740 for the recurrent step.
741 use_bias: Boolean, whether the layer uses a bias vector.
742 kernel_initializer: Initializer for the `kernel` weights matrix,
743 used for the linear transformation of the inputs.
744 recurrent_initializer: Initializer for the `recurrent_kernel`
745 weights matrix,
746 used for the linear transformation of the recurrent state.
747 bias_initializer: Initializer for the bias vector.
748 unit_forget_bias: Boolean.
749 If True, add 1 to the bias of the forget gate at initialization.
750 Use in combination with `bias_initializer="zeros"`.
751 This is recommended in [Jozefowicz et al., 2015](
752 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
753 kernel_regularizer: Regularizer function applied to
754 the `kernel` weights matrix.
755 recurrent_regularizer: Regularizer function applied to
756 the `recurrent_kernel` weights matrix.
757 bias_regularizer: Regularizer function applied to the bias vector.
758 activity_regularizer: Regularizer function applied to.
759 kernel_constraint: Constraint function applied to
760 the `kernel` weights matrix.
761 recurrent_constraint: Constraint function applied to
762 the `recurrent_kernel` weights matrix.
763 bias_constraint: Constraint function applied to the bias vector.
764 return_sequences: Boolean. Whether to return the last output
765 in the output sequence, or the full sequence. (default False)
766 return_state: Boolean Whether to return the last state
767 in addition to the output. (default False)
768 go_backwards: Boolean (default False).
769 If True, process the input sequence backwards.
770 stateful: Boolean (default False). If True, the last state
771 for each sample at index i in a batch will be used as initial
772 state for the sample of index i in the following batch.
773 dropout: Float between 0 and 1.
774 Fraction of the units to drop for
775 the linear transformation of the inputs.
776 recurrent_dropout: Float between 0 and 1.
777 Fraction of the units to drop for
778 the linear transformation of the recurrent state.
780 Call arguments:
781 inputs: A 5D float tensor (see input shape description below).
782 mask: Binary tensor of shape `(samples, timesteps)` indicating whether
783 a given timestep should be masked.
784 training: Python boolean indicating whether the layer should behave in
785 training mode or in inference mode. This argument is passed to the cell
786 when calling it. This is only relevant if `dropout` or `recurrent_dropout`
787 are set.
788 initial_state: List of initial state tensors to be passed to the first
789 call of the cell.
791 Input shape:
792 - If data_format='channels_first'
793 5D tensor with shape:
794 `(samples, time, channels, rows, cols)`
795 - If data_format='channels_last'
796 5D tensor with shape:
797 `(samples, time, rows, cols, channels)`
799 Output shape:
800 - If `return_state`: a list of tensors. The first tensor is
801 the output. The remaining tensors are the last states,
802 each 4D tensor with shape:
803 `(samples, filters, new_rows, new_cols)`
804 if data_format='channels_first'
805 or 4D tensor with shape:
806 `(samples, new_rows, new_cols, filters)`
807 if data_format='channels_last'.
808 `rows` and `cols` values might have changed due to padding.
809 - If `return_sequences`: 5D tensor with shape:
810 `(samples, timesteps, filters, new_rows, new_cols)`
811 if data_format='channels_first'
812 or 5D tensor with shape:
813 `(samples, timesteps, new_rows, new_cols, filters)`
814 if data_format='channels_last'.
815 - Else, 4D tensor with shape:
816 `(samples, filters, new_rows, new_cols)`
817 if data_format='channels_first'
818 or 4D tensor with shape:
819 `(samples, new_rows, new_cols, filters)`
820 if data_format='channels_last'.
822 Raises:
823 ValueError: in case of invalid constructor arguments.
825 References:
826 - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1)
(the current implementation does not include the feedback loop on the
cell's output).
830 Example:
832 ```python
833 steps = 10
834 height = 32
835 width = 32
836 input_channels = 3
837 output_channels = 6
839 inputs = tf.keras.Input(shape=(steps, height, width, input_channels))
840 layer = tf.keras.layers.ConvLSTM2D(filters=output_channels, kernel_size=3)
841 outputs = layer(inputs)
842 ```
843 """
def __init__(self,
             filters,
             kernel_size,
             strides=(1, 1),
             padding='valid',
             data_format=None,
             dilation_rate=(1, 1),
             activation='tanh',
             recurrent_activation='hard_sigmoid',
             use_bias=True,
             kernel_initializer='glorot_uniform',
             recurrent_initializer='orthogonal',
             bias_initializer='zeros',
             unit_forget_bias=True,
             kernel_regularizer=None,
             recurrent_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             recurrent_constraint=None,
             bias_constraint=None,
             return_sequences=False,
             return_state=False,
             go_backwards=False,
             stateful=False,
             dropout=0.,
             recurrent_dropout=0.,
             **kwargs):
  """Creates the ConvLSTM2D layer around a freshly built `ConvLSTM2DCell`.

  All cell-level hyper-parameters (convolution geometry, activations,
  initializers, regularizers, constraints and dropout rates) are passed to a
  new `ConvLSTM2DCell`. The sequence-level options (`return_sequences`,
  `return_state`, `go_backwards`, `stateful`) and any extra `kwargs` go to
  the RNN base class together with that cell.

  Args:
    See the class docstring for the meaning of each argument.
    **kwargs: Extra base-layer keyword arguments. If a `dtype` entry is
      present, it is also given to the cell; the full `kwargs` are still
      forwarded to the base class.
  """
  cell_settings = dict(
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      data_format=data_format,
      dilation_rate=dilation_rate,
      activation=activation,
      recurrent_activation=recurrent_activation,
      use_bias=use_bias,
      kernel_initializer=kernel_initializer,
      recurrent_initializer=recurrent_initializer,
      bias_initializer=bias_initializer,
      unit_forget_bias=unit_forget_bias,
      kernel_regularizer=kernel_regularizer,
      recurrent_regularizer=recurrent_regularizer,
      bias_regularizer=bias_regularizer,
      kernel_constraint=kernel_constraint,
      recurrent_constraint=recurrent_constraint,
      bias_constraint=bias_constraint,
      dropout=dropout,
      recurrent_dropout=recurrent_dropout,
      # The cell is created here, so it must be given the layer dtype (if one
      # was supplied) explicitly.
      dtype=kwargs.get('dtype'))
  lstm_cell = ConvLSTM2DCell(**cell_settings)

  rnn_base = super(ConvLSTM2D, self)
  rnn_base.__init__(lstm_cell,
                    return_sequences=return_sequences,
                    return_state=return_state,
                    go_backwards=go_backwards,
                    stateful=stateful,
                    **kwargs)
  # `activity_regularizer` is not a cell argument; it is resolved and
  # attached to this layer only after the base class has been initialized.
  self.activity_regularizer = regularizers.get(activity_regularizer)
def call(self, inputs, mask=None, training=None, initial_state=None):
  """Runs the recurrence by delegating to the RNN base-class `call`.

  This override only spells out the accepted call arguments. Each one is
  forwarded unchanged.

  Args:
    inputs: A 5D float tensor (see the class docstring for its layout).
    mask: Optional `(samples, timesteps)` mask tensor.
    training: Optional Python boolean selecting training or inference mode.
    initial_state: Optional list of initial state tensors for the cell.

  Returns:
    Whatever the RNN base-class `call` returns for these arguments.
  """
  forwarded = dict(mask=mask, training=training, initial_state=initial_state)
  return super(ConvLSTM2D, self).call(inputs, **forwarded)
@property
def filters(self):
  """Read-only view of the wrapped cell's `filters` setting."""
  wrapped = self.cell
  return wrapped.filters
@property
def kernel_size(self):
  """Read-only view of the wrapped cell's `kernel_size` setting."""
  wrapped = self.cell
  return wrapped.kernel_size
@property
def strides(self):
  """Read-only view of the wrapped cell's `strides` setting."""
  wrapped = self.cell
  return wrapped.strides
@property
def padding(self):
  """Read-only view of the wrapped cell's `padding` setting."""
  wrapped = self.cell
  return wrapped.padding
@property
def data_format(self):
  """Read-only view of the wrapped cell's `data_format` setting."""
  wrapped = self.cell
  return wrapped.data_format
@property
def dilation_rate(self):
  """Read-only view of the wrapped cell's `dilation_rate` setting."""
  wrapped = self.cell
  return wrapped.dilation_rate
@property
def activation(self):
  """Read-only view of the activation stored on the wrapped cell."""
  wrapped = self.cell
  return wrapped.activation
@property
def recurrent_activation(self):
  """Read-only view of the recurrent activation stored on the wrapped cell."""
  wrapped = self.cell
  return wrapped.recurrent_activation
@property
def use_bias(self):
  """Read-only view of the wrapped cell's `use_bias` flag."""
  wrapped = self.cell
  return wrapped.use_bias
@property
def kernel_initializer(self):
  """Read-only view of the wrapped cell's `kernel_initializer`."""
  wrapped = self.cell
  return wrapped.kernel_initializer
@property
def recurrent_initializer(self):
  """Read-only view of the wrapped cell's `recurrent_initializer`."""
  wrapped = self.cell
  return wrapped.recurrent_initializer
@property
def bias_initializer(self):
  """Read-only view of the wrapped cell's `bias_initializer`."""
  wrapped = self.cell
  return wrapped.bias_initializer
@property
def unit_forget_bias(self):
  """Read-only view of the wrapped cell's `unit_forget_bias` flag."""
  wrapped = self.cell
  return wrapped.unit_forget_bias
@property
def kernel_regularizer(self):
  """Read-only view of the wrapped cell's `kernel_regularizer`."""
  wrapped = self.cell
  return wrapped.kernel_regularizer
@property
def recurrent_regularizer(self):
  """Read-only view of the wrapped cell's `recurrent_regularizer`."""
  wrapped = self.cell
  return wrapped.recurrent_regularizer
@property
def bias_regularizer(self):
  """Read-only view of the wrapped cell's `bias_regularizer`."""
  wrapped = self.cell
  return wrapped.bias_regularizer
@property
def kernel_constraint(self):
  """Constraint for the `kernel` weights matrix, read from the wrapped cell."""
  wrapped = self.cell
  return wrapped.kernel_constraint
@property
def recurrent_constraint(self):
  """Constraint for the `recurrent_kernel` matrix, read from the wrapped cell."""
  wrapped = self.cell
  return wrapped.recurrent_constraint
@property
def bias_constraint(self):
  """Constraint for the bias vector, read from the wrapped cell."""
  wrapped = self.cell
  return wrapped.bias_constraint
@property
def dropout(self):
  """Input-transformation dropout rate, read from the wrapped cell."""
  wrapped = self.cell
  return wrapped.dropout
@property
def recurrent_dropout(self):
  """Recurrent-state dropout rate, read from the wrapped cell."""
  wrapped = self.cell
  return wrapped.recurrent_dropout
def get_config(self):
  """Returns the serializable configuration of this layer.

  The cell's hyper-parameters are read through this layer's delegating
  properties and stored flat, as constructor keyword arguments. The nested
  `'cell'` entry produced by the RNN base class is therefore dropped.

  Returns:
    A dict containing the RNN base-layer config (without `'cell'`) merged
    with this layer's constructor arguments. Where a key appears in both,
    the value defined here wins.
  """
  config = {'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
            'data_format': self.data_format,
            'dilation_rate': self.dilation_rate,
            'activation': activations.serialize(self.activation),
            'recurrent_activation': activations.serialize(
                self.recurrent_activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(
                self.kernel_initializer),
            'recurrent_initializer': initializers.serialize(
                self.recurrent_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'unit_forget_bias': self.unit_forget_bias,
            'kernel_regularizer': regularizers.serialize(
                self.kernel_regularizer),
            'recurrent_regularizer': regularizers.serialize(
                self.recurrent_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(
                self.activity_regularizer),
            'kernel_constraint': constraints.serialize(
                self.kernel_constraint),
            'recurrent_constraint': constraints.serialize(
                self.recurrent_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
            'dropout': self.dropout,
            'recurrent_dropout': self.recurrent_dropout}
  base_config = super(ConvLSTM2D, self).get_config()
  # `__init__` rebuilds the cell from the flat arguments above. A leftover
  # `'cell'` key would be forwarded by `from_config` (`cls(**config)`) via
  # `**kwargs` to the RNN constructor, which already receives the cell
  # positionally. Drop it, tolerating its absence.
  base_config.pop('cell', None)
  # Base keys first, then this layer's keys; later keys override earlier ones.
  return {**base_config, **config}
@classmethod
def from_config(cls, config):
  """Re-creates a layer from a dict produced by `get_config`.

  Every entry of `config` is passed to the constructor as a keyword argument.

  Args:
    config: Mapping of constructor argument names to values.

  Returns:
    A new instance of `cls`.
  """
  layer = cls(**config)
  return layer