Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/conv_lstm3d.py: 86%
7 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""3D Convolutional LSTM layer."""
18from keras.src.layers.rnn.base_conv_lstm import ConvLSTM
20# isort: off
21from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.ConvLSTM3D")
class ConvLSTM3D(ConvLSTM):
    """3D Convolutional LSTM.

    Similar to an LSTM layer, but the input transformations
    and recurrent transformations are both convolutional.

    Args:
      filters: Integer, the dimensionality of the output space (i.e. the
        number of output filters in the convolution).
      kernel_size: An integer or tuple/list of 3 integers, specifying the
        dimensions of the convolution window.
      strides: An integer or tuple/list of 3 integers, specifying the strides
        of the convolution. Specifying any stride value != 1 is incompatible
        with specifying any `dilation_rate` value != 1.
      padding: One of `"valid"` or `"same"` (case-insensitive). `"valid"`
        means no padding. `"same"` results in padding evenly to the
        left/right or up/down of the input such that output has the same
        height/width dimension as the input.
      data_format: A string, one of `channels_last` (default) or
        `channels_first`. The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape `(batch, time, ...,
        channels)` while `channels_first` corresponds to inputs with shape
        `(batch, time, channels, ...)`. When unspecified, uses
        `image_data_format` value found in your Keras config file at
        `~/.keras/keras.json` (if exists) else 'channels_last'.
        Defaults to 'channels_last'.
      dilation_rate: An integer or tuple/list of 3 integers, specifying the
        dilation rate to use for dilated convolution. Currently, specifying
        any `dilation_rate` value != 1 is incompatible with specifying any
        `strides` value != 1.
      activation: Activation function to use. By default hyperbolic tangent
        activation function is applied (`tanh(x)`).
      recurrent_activation: Activation function to use for the recurrent
        step. Defaults to `"hard_sigmoid"`.
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix, used
        for the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel` weights
        matrix, used for the linear transformation of the recurrent state.
      bias_initializer: Initializer for the bias vector.
      unit_forget_bias: Boolean. If True, add 1 to the bias of the forget
        gate at initialization. Use in combination with
        `bias_initializer="zeros"`. This is recommended in
        [Jozefowicz et al., 2015](
        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
      kernel_regularizer: Regularizer function applied to the `kernel`
        weights matrix.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to the output of
        the layer (its "activation").
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      return_sequences: Boolean. Whether to return the last output in the
        output sequence, or the full sequence. (default False)
      return_state: Boolean. Whether to return the last state in addition to
        the output. (default False)
      go_backwards: Boolean (default False). If True, process the input
        sequence backwards.
      stateful: Boolean (default False). If True, the last state for each
        sample at index i in a batch will be used as initial state for the
        sample of index i in the following batch.
      dropout: Float between 0 and 1. Fraction of the units to drop for the
        linear transformation of the inputs.
      recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
        for the linear transformation of the recurrent state.

    Call arguments:
      inputs: A 6D tensor.
      mask: Binary tensor of shape `(samples, timesteps)` indicating whether
        a given timestep should be masked.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the
        cell when calling it. This is only relevant if `dropout` or
        `recurrent_dropout` are set.
      initial_state: List of initial state tensors to be passed to the first
        call of the cell.

    Input shape:
      - If data_format='channels_first'
          6D tensor with shape: `(samples, time, channels, rows, cols, depth)`
      - If data_format='channels_last'
          6D tensor with shape: `(samples, time, rows, cols, depth, channels)`

    Output shape:
      - If `return_state`: a list of tensors. The first tensor is the output.
        The remaining tensors are the last states,
        each 5D tensor with shape: `(samples, filters, new_rows, new_cols,
        new_depth)` if data_format='channels_first'
        or shape: `(samples, new_rows, new_cols, new_depth, filters)` if
        data_format='channels_last'. `rows`, `cols`, and `depth` values might
        have changed due to padding.
      - If `return_sequences`: 6D tensor with shape: `(samples, timesteps,
        filters, new_rows, new_cols, new_depth)` if
        data_format='channels_first'
        or shape: `(samples, timesteps, new_rows, new_cols, new_depth,
        filters)` if data_format='channels_last'.
      - Else, 5D tensor with shape: `(samples, filters, new_rows, new_cols,
        new_depth)` if data_format='channels_first'
        or shape: `(samples, new_rows, new_cols, new_depth, filters)` if
        data_format='channels_last'.

    Raises:
      ValueError: in case of invalid constructor arguments.

    References:
      - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1)
        (the current implementation does not include the feedback loop on the
        cells output).
    """

    def __init__(
        self,
        filters,
        kernel_size,
        strides=(1, 1, 1),
        padding="valid",
        data_format=None,
        dilation_rate=(1, 1, 1),
        activation="tanh",
        recurrent_activation="hard_sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        unit_forget_bias=True,
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        dropout=0.0,
        recurrent_dropout=0.0,
        **kwargs
    ):
        # Every option is forwarded unchanged to the ConvLSTM base class; the
        # only value this subclass fixes itself is `rank=3`.
        super().__init__(
            rank=3,
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            unit_forget_bias=unit_forget_bias,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            **kwargs
        )