Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/conv_lstm1d.py: 86%
7 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""1D Convolutional LSTM layer."""
18from keras.src.layers.rnn.base_conv_lstm import ConvLSTM
20# isort: off
21from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.ConvLSTM1D")
class ConvLSTM1D(ConvLSTM):
    """1D Convolutional LSTM.

    Similar to an LSTM layer, but the input transformations
    and recurrent transformations are both convolutional.

    Args:
      filters: Integer, the dimensionality of the output space (i.e. the number
        of output filters in the convolution).
      kernel_size: An integer or tuple/list of 1 integer, specifying the
        length of the 1D convolution window.
      strides: An integer or tuple/list of 1 integer, specifying the stride of
        the convolution. Specifying any stride value != 1 is incompatible with
        specifying any `dilation_rate` value != 1.
      padding: One of `"valid"` or `"same"` (case-insensitive). `"valid"` means
        no padding. `"same"` results in padding evenly to the left/right of
        the input such that the output has the same `rows` dimension as the
        input (when `strides=1`).
      data_format: A string, one of `channels_last` (default) or
        `channels_first`. The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape `(batch, time, ...,
        channels)` while `channels_first` corresponds to inputs with shape
        `(batch, time, channels, ...)`. When unspecified, uses the
        `image_data_format` value found in your Keras config file at
        `~/.keras/keras.json` (if it exists), else `'channels_last'`.
        Defaults to `'channels_last'`.
      dilation_rate: An integer or tuple/list of 1 integer, specifying the
        dilation rate to use for dilated convolution. Currently, specifying any
        `dilation_rate` value != 1 is incompatible with specifying any
        `strides` value != 1.
      activation: Activation function to use. By default the hyperbolic
        tangent activation function is applied (`tanh(x)`).
      recurrent_activation: Activation function to use for the recurrent step.
        Defaults to `"hard_sigmoid"`.
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix, used for
        the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel` weights
        matrix, used for the linear transformation of the recurrent state.
      bias_initializer: Initializer for the bias vector.
      unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate
        at initialization. Use in combination with `bias_initializer="zeros"`.
        This is recommended in [Jozefowicz et al., 2015](
        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to the output of the
        layer (its "activation").
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      return_sequences: Boolean. Whether to return the last output in the
        output sequence, or the full sequence. (default False)
      return_state: Boolean. Whether to return the last state in addition to
        the output. (default False)
      go_backwards: Boolean (default False). If True, process the input
        sequence backwards.
      stateful: Boolean (default False). If True, the last state for each
        sample at index i in a batch will be used as initial state for the
        sample of index i in the following batch.
      dropout: Float between 0 and 1. Fraction of the units to drop for the
        linear transformation of the inputs.
      recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
        for the linear transformation of the recurrent state.

    Call arguments:
      inputs: A 4D tensor.
      mask: Binary tensor of shape `(samples, timesteps)` indicating whether a
        given timestep should be masked.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the
        cell when calling it. This is only relevant if `dropout` or
        `recurrent_dropout` are set.
      initial_state: List of initial state tensors to be passed to the first
        call of the cell.

    Input shape:
      - If `data_format='channels_first'`:
        4D tensor with shape `(samples, time, channels, rows)`.
      - If `data_format='channels_last'`:
        4D tensor with shape `(samples, time, rows, channels)`.

    Output shape:
      - If `return_state`: a list of tensors. The first tensor is the output.
        The remaining tensors are the last states,
        each 3D tensor with shape `(samples, filters, new_rows)` if
        `data_format='channels_first'`,
        or shape `(samples, new_rows, filters)` if
        `data_format='channels_last'`.
        `rows` values might have changed due to padding.
      - If `return_sequences`: 4D tensor with shape `(samples, timesteps,
        filters, new_rows)` if `data_format='channels_first'`,
        or shape `(samples, timesteps, new_rows, filters)` if
        `data_format='channels_last'`.
      - Else, 3D tensor with shape `(samples, filters, new_rows)` if
        `data_format='channels_first'`,
        or shape `(samples, new_rows, filters)` if
        `data_format='channels_last'`.

    Raises:
      ValueError: in case of invalid constructor arguments.

    References:
      - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1)
        (the current implementation does not include the feedback loop on the
        cells output).
    """

    def __init__(
        self,
        filters,
        kernel_size,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
        activation="tanh",
        recurrent_activation="hard_sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        unit_forget_bias=True,
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        dropout=0.0,
        recurrent_dropout=0.0,
        **kwargs
    ):
        # This subclass only pins the spatial rank to 1; every other argument
        # is forwarded unchanged to the shared `ConvLSTM` base class.
        super().__init__(
            rank=1,
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            unit_forget_bias=unit_forget_bias,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            **kwargs
        )