# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for N-D convolutional LSTM layers."""

import tensorflow.compat.v2 as tf

from keras.src import activations
from keras.src import backend
from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.engine import base_layer
from keras.src.layers.rnn.base_conv_rnn import ConvRNN
from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
from keras.src.utils import conv_utils


class ConvLSTMCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
    """Cell class for the ConvLSTM layer.

    Args:
      rank: Integer, rank of the convolution, e.g. 2 for 2D convolutions.
      filters: Integer, the dimensionality of the output space (i.e. the
        number of output filters in the convolution).
      kernel_size: An integer or tuple/list of n integers, specifying the
        dimensions of the convolution window.
      strides: An integer or tuple/list of n integers, specifying the strides
        of the convolution. Specifying any stride value != 1 is incompatible
        with specifying any `dilation_rate` value != 1.
      padding: One of `"valid"` or `"same"` (case-insensitive). `"valid"` means
        no padding. `"same"` results in padding evenly to the left/right or
        up/down of the input such that the output has the same height/width
        dimension as the input.
      data_format: A string, one of `channels_last` or `channels_first`.
        Defaults to the `image_data_format` value found in your Keras config
        file at `~/.keras/keras.json` (if it exists), else `channels_last`.
      dilation_rate: An integer or tuple/list of n integers, specifying the
        dilation rate to use for dilated convolution. Currently, specifying
        any `dilation_rate` value != 1 is incompatible with specifying any
        `strides` value != 1.
      activation: Activation function to use. If you don't specify anything,
        no activation is applied (i.e. "linear" activation: `a(x) = x`).
      recurrent_activation: Activation function to use for the recurrent step.
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix, used
        for the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel` weights
        matrix, used for the linear transformation of the recurrent state.
      bias_initializer: Initializer for the bias vector.
      unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate
        at initialization. Use in combination with `bias_initializer="zeros"`.
        This is recommended in [Jozefowicz et al., 2015](
        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      dropout: Float between 0 and 1. Fraction of the units to drop for the
        linear transformation of the inputs.
      recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
        for the linear transformation of the recurrent state.
    Call arguments:
      inputs: A (2 + `rank`)-D tensor.
      states: List of state tensors corresponding to the previous timestep.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. Only relevant when `dropout` or
        `recurrent_dropout` is used.
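
    Example (a minimal usage sketch, not part of the original source; the
    shapes assume `rank=2`, `channels_last` data, and the default `"valid"`
    padding):

    >>> inputs = tf.random.normal((4, 10, 32, 32, 3))  # batch, time, h, w, c
    >>> layer = ConvRNN(2, ConvLSTMCell(rank=2, filters=8, kernel_size=3))
    >>> outputs = layer(inputs)  # last output, shape (4, 30, 30, 8)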
90 """

    def __init__(
        self,
        rank,
        filters,
        kernel_size,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
        activation="tanh",
        recurrent_activation="hard_sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        unit_forget_bias=True,
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.rank = rank
        if self.rank > 3:
            raise ValueError(
                f"Rank {rank} convolutions are not currently "
                f"implemented. Received: rank={rank}"
            )
        self.filters = filters
        self.kernel_size = conv_utils.normalize_tuple(
            kernel_size, self.rank, "kernel_size"
        )
        self.strides = conv_utils.normalize_tuple(
            strides, self.rank, "strides", allow_zero=True
        )
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(
            dilation_rate, self.rank, "dilation_rate"
        )
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1.0, max(0.0, dropout))
        self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
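        # The state is a pair (h, c); `state_size` holds only their channel
        # counts, the spatial dimensions are inferred from the input by
        # `ConvRNN` at build time.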
        self.state_size = (self.filters, self.filters)

    def build(self, input_shape):
        super().build(input_shape)
        if self.data_format == "channels_first":
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError(
                "The channel dimension of the inputs (last axis) should be "
                "defined. Found None. Full input shape received: "
                f"input_shape={input_shape}"
            )
        input_dim = input_shape[channel_axis]
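        # The input and recurrent kernels each stack the four gate kernels
        # (i, f, c, o) along the last axis, hence `filters * 4`.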
        self.kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
        recurrent_kernel_shape = self.kernel_size + (
            self.filters,
            self.filters * 4,
        )

        self.kernel = self.add_weight(
            shape=self.kernel_shape,
            initializer=self.kernel_initializer,
            name="kernel",
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.recurrent_kernel = self.add_weight(
            shape=recurrent_kernel_shape,
            initializer=self.recurrent_initializer,
            name="recurrent_kernel",
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
        )

        if self.use_bias:
            if self.unit_forget_bias:
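                # The bias layout is [i | f | c | o], each slice of length
                # `filters`; initialize the forget-gate slice to ones (per
                # Jozefowicz et al., 2015) and the remaining slices with
                # `bias_initializer`.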

                def bias_initializer(_, *args, **kwargs):
                    return backend.concatenate(
                        [
                            self.bias_initializer(
                                (self.filters,), *args, **kwargs
                            ),
                            initializers.get("ones")(
                                (self.filters,), *args, **kwargs
                            ),
                            self.bias_initializer(
                                (self.filters * 2,), *args, **kwargs
                            ),
                        ]
                    )

            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(
                shape=(self.filters * 4,),
                name="bias",
                initializer=bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, states, training=None):
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state
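
        # `DropoutRNNCellMixin` generates (and reuses across timesteps)
        # one dropout mask per gate, hence count=4.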
        # dropout matrices for input units
        dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
        # dropout matrices for recurrent units
        rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
            h_tm1, training, count=4
        )

        if 0 < self.dropout < 1.0:
            inputs_i = inputs * dp_mask[0]
            inputs_f = inputs * dp_mask[1]
            inputs_c = inputs * dp_mask[2]
            inputs_o = inputs * dp_mask[3]
        else:
            inputs_i = inputs
            inputs_f = inputs
            inputs_c = inputs
            inputs_o = inputs

        if 0 < self.recurrent_dropout < 1.0:
            h_tm1_i = h_tm1 * rec_dp_mask[0]
            h_tm1_f = h_tm1 * rec_dp_mask[1]
            h_tm1_c = h_tm1 * rec_dp_mask[2]
            h_tm1_o = h_tm1 * rec_dp_mask[3]
        else:
            h_tm1_i = h_tm1
            h_tm1_f = h_tm1
            h_tm1_c = h_tm1
            h_tm1_o = h_tm1

        (kernel_i, kernel_f, kernel_c, kernel_o) = tf.split(
            self.kernel, 4, axis=self.rank + 1
        )
        (
            recurrent_kernel_i,
            recurrent_kernel_f,
            recurrent_kernel_c,
            recurrent_kernel_o,
        ) = tf.split(self.recurrent_kernel, 4, axis=self.rank + 1)

        if self.use_bias:
            bias_i, bias_f, bias_c, bias_o = tf.split(self.bias, 4)
        else:
            bias_i, bias_f, bias_c, bias_o = None, None, None, None

        x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
        x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
        x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
        x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
        h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
        h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
        h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
        h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)
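
        # Standard LSTM gate equations (Shi et al., 2015), with the matrix
        # products replaced by convolutions: i, f and o are gates squashed
        # by `recurrent_activation`; c is the new carry state and h the
        # output, computed per spatial location.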
        i = self.recurrent_activation(x_i + h_i)
        f = self.recurrent_activation(x_f + h_f)
        c = f * c_tm1 + i * self.activation(x_c + h_c)
        o = self.recurrent_activation(x_o + h_o)
        h = o * self.activation(c)
        return h, [h, c]

    @property
    def _conv_func(self):
        if self.rank == 1:
            return backend.conv1d
        if self.rank == 2:
            return backend.conv2d
        if self.rank == 3:
            return backend.conv3d

    def input_conv(self, x, w, b=None, padding="valid"):
        conv_out = self._conv_func(
            x,
            w,
            strides=self.strides,
            padding=padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )
        if b is not None:
            conv_out = backend.bias_add(
                conv_out, b, data_format=self.data_format
            )
        return conv_out

    def recurrent_conv(self, x, w):
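        # The recurrent convolution must preserve the spatial shape of the
        # state, so it always uses stride 1 and "same" padding regardless
        # of the layer's configured `strides` and `padding`.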
        strides = conv_utils.normalize_tuple(
            1, self.rank, "strides", allow_zero=True
        )
        conv_out = self._conv_func(
            x, w, strides=strides, padding="same", data_format=self.data_format
        )
        return conv_out

    def get_config(self):
        config = {
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            "data_format": self.data_format,
            "dilation_rate": self.dilation_rate,
            "activation": activations.serialize(self.activation),
            "recurrent_activation": activations.serialize(
                self.recurrent_activation
            ),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "unit_forget_bias": self.unit_forget_bias,
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


class ConvLSTM(ConvRNN):
    """Abstract N-D convolutional LSTM layer (used as an implementation base).

    Similar to an LSTM layer, but the input transformations
    and recurrent transformations are both convolutional.

    Args:
      rank: Integer, rank of the convolution, e.g. 2 for 2D convolutions.
      filters: Integer, the dimensionality of the output space
        (i.e. the number of output filters in the convolution).
      kernel_size: An integer or tuple/list of n integers, specifying the
        dimensions of the convolution window.
      strides: An integer or tuple/list of n integers,
        specifying the strides of the convolution.
        Specifying any stride value != 1 is incompatible with specifying
        any `dilation_rate` value != 1.
      padding: One of `"valid"` or `"same"` (case-insensitive).
        `"valid"` means no padding. `"same"` results in padding evenly to
        the left/right or up/down of the input such that the output has the
        same height/width dimension as the input.
      data_format: A string,
        one of `channels_last` or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch, time, ..., channels)`
        while `channels_first` corresponds to
        inputs with shape `(batch, time, channels, ...)`.
        Defaults to the `image_data_format` value found in your Keras
        config file at `~/.keras/keras.json` (if it exists),
        else `channels_last`.
      dilation_rate: An integer or tuple/list of n integers, specifying
        the dilation rate to use for dilated convolution.
        Currently, specifying any `dilation_rate` value != 1 is
        incompatible with specifying any `strides` value != 1.
      activation: Activation function to use.
        By default the hyperbolic tangent activation function is applied
        (`tanh(x)`).
      recurrent_activation: Activation function to use
        for the recurrent step.
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix,
        used for the linear transformation of the recurrent state.
      bias_initializer: Initializer for the bias vector.
      unit_forget_bias: Boolean.
        If True, add 1 to the bias of the forget gate at initialization.
        Use in combination with `bias_initializer="zeros"`.
        This is recommended in [Jozefowicz et al., 2015](
        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
      kernel_regularizer: Regularizer function applied to
        the `kernel` weights matrix.
      recurrent_regularizer: Regularizer function applied to
        the `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to the output of
        the layer (its "activation").
      kernel_constraint: Constraint function applied to
        the `kernel` weights matrix.
      recurrent_constraint: Constraint function applied to
        the `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      return_sequences: Boolean. Whether to return the last output
        in the output sequence, or the full sequence. (default False)
      return_state: Boolean. Whether to return the last state
        in addition to the output. (default False)
      go_backwards: Boolean (default False).
        If True, process the input sequence backwards.
      stateful: Boolean (default False). If True, the last state
        for each sample at index i in a batch will be used as initial
        state for the sample of index i in the following batch.
      dropout: Float between 0 and 1.
        Fraction of the units to drop for
        the linear transformation of the inputs.
      recurrent_dropout: Float between 0 and 1.
        Fraction of the units to drop for
        the linear transformation of the recurrent state.
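
    Example (an illustrative sketch, not part of the original source;
    concrete layers subclass `ConvLSTM` by fixing `rank`, which is
    essentially how the public `ConvLSTM2D` layer is defined):

    >>> class MyConvLSTM2D(ConvLSTM):
    ...     def __init__(self, filters, kernel_size, **kwargs):
    ...         super().__init__(2, filters, kernel_size, **kwargs)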
438 """

    def __init__(
        self,
        rank,
        filters,
        kernel_size,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
        activation="tanh",
        recurrent_activation="hard_sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        unit_forget_bias=True,
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        dropout=0.0,
        recurrent_dropout=0.0,
        **kwargs,
    ):
        cell = ConvLSTMCell(
            rank=rank,
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            unit_forget_bias=unit_forget_bias,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            name="conv_lstm_cell",
            dtype=kwargs.get("dtype"),
        )
        super().__init__(
            rank,
            cell,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            **kwargs,
        )
        self.activity_regularizer = regularizers.get(activity_regularizer)

    def call(self, inputs, mask=None, training=None, initial_state=None):
        return super().call(
            inputs, mask=mask, training=training, initial_state=initial_state
        )
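
    # The read-only properties below simply expose the wrapped cell's
    # hyperparameters on the layer itself.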

    @property
    def filters(self):
        return self.cell.filters

    @property
    def kernel_size(self):
        return self.cell.kernel_size

    @property
    def strides(self):
        return self.cell.strides

    @property
    def padding(self):
        return self.cell.padding

    @property
    def data_format(self):
        return self.cell.data_format

    @property
    def dilation_rate(self):
        return self.cell.dilation_rate

    @property
    def activation(self):
        return self.cell.activation

    @property
    def recurrent_activation(self):
        return self.cell.recurrent_activation

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def kernel_initializer(self):
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        return self.cell.bias_initializer

    @property
    def unit_forget_bias(self):
        return self.cell.unit_forget_bias

    @property
    def kernel_regularizer(self):
        return self.cell.kernel_regularizer

    @property
    def recurrent_regularizer(self):
        return self.cell.recurrent_regularizer

    @property
    def bias_regularizer(self):
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        return self.cell.recurrent_constraint

    @property
    def bias_constraint(self):
        return self.cell.bias_constraint

    @property
    def dropout(self):
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        return self.cell.recurrent_dropout

    def get_config(self):
        config = {
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            "data_format": self.data_format,
            "dilation_rate": self.dilation_rate,
            "activation": activations.serialize(self.activation),
            "recurrent_activation": activations.serialize(
                self.recurrent_activation
            ),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "unit_forget_bias": self.unit_forget_bias,
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
        }
        base_config = super().get_config()
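        # The cell is reconstructed in `__init__` from these arguments, so
        # it is dropped from the serialized config.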
        del base_config["cell"]
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        return cls(**config)