# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Wrapper layer to apply every temporal slice of an input."""
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine.base_layer import Layer
from keras.src.engine.input_spec import InputSpec
from keras.src.layers.rnn.base_wrapper import Wrapper
from keras.src.utils import generic_utils
from keras.src.utils import layer_utils
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.TimeDistributed")
class TimeDistributed(Wrapper):
    """This wrapper allows you to apply a layer to every temporal slice of
    an input.

    Every input should be at least 3D, and dimension one of the first input
    will be considered to be the temporal dimension.

    Consider a batch of 32 video samples, where each sample is a 128x128 RGB
    image with `channels_last` data format, across 10 timesteps.
    The batch input shape is `(32, 10, 128, 128, 3)`.

    You can then use `TimeDistributed` to apply the same `Conv2D` layer to
    each of the 10 timesteps, independently:

    >>> inputs = tf.keras.Input(shape=(10, 128, 128, 3))
    >>> conv_2d_layer = tf.keras.layers.Conv2D(64, (3, 3))
    >>> outputs = tf.keras.layers.TimeDistributed(conv_2d_layer)(inputs)
    >>> outputs.shape
    TensorShape([None, 10, 126, 126, 64])

    Because `TimeDistributed` applies the same instance of `Conv2D` to each
    of the timesteps, the same set of weights is used at each timestep.

    Args:
      layer: a `tf.keras.layers.Layer` instance.

    Call arguments:
      inputs: Input tensor of shape (batch, time, ...), or nested tensors,
        each of which has shape (batch, time, ...).
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the
        wrapped layer (only if the layer supports this argument).
      mask: Binary tensor of shape `(samples, timesteps)` indicating whether
        a given timestep should be masked. This argument is passed to the
        wrapped layer (only if the layer supports this argument).

    Raises:
      ValueError: If not initialized with a `tf.keras.layers.Layer` instance.
    """
    def __init__(self, layer, **kwargs):
        if not isinstance(layer, Layer):
            raise ValueError(
                "Please initialize `TimeDistributed` layer with a "
                f"`tf.keras.layers.Layer` instance. Received: {layer}"
            )
        super().__init__(layer, **kwargs)
        self.supports_masking = True

        # It is safe to use the fast, reshape-based approach with all of our
        # built-in Layers.
        self._always_use_reshape = layer_utils.is_builtin_layer(
            layer
        ) and not getattr(layer, "stateful", False)
    def _get_shape_tuple(self, init_tuple, tensor, start_idx):
        """Finds non-specific dimensions in the static shapes.

        The static shapes are replaced with the corresponding dynamic shapes
        of the tensor.

        Args:
          init_tuple: a tuple, the first part of the output shape
          tensor: the tensor from which to get the (static and dynamic)
            shapes as the last part of the output shape
          start_idx: int, which indicates the first dimension to take from
            the static shape of the tensor

        Returns:
          The new shape with the first part from `init_tuple` and the last
          part from `tensor.shape`, where every `None` is replaced by the
          corresponding dimension from `tf.shape(tensor)`.
        """
        # replace all None in int_shape by backend.shape
        int_shape = backend.int_shape(tensor)[start_idx:]
        if not any(s is None for s in int_shape):
            return init_tuple + int_shape
        shape = backend.shape(tensor)
        int_shape = list(int_shape)
        for i, s in enumerate(int_shape):
            if s is None:
                int_shape[i] = shape[start_idx + i]
        return init_tuple + tuple(int_shape)
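
    # Illustrative sketch (hedged, assumed shapes): for `x` with static
    # shape (None, 10, None, 3),
    #
    #   self._get_shape_tuple((-1,), x, 2)
    #
    # returns `(-1, tf.shape(x)[2], 3)`: the static `None` past `start_idx`
    # is replaced by the matching dynamic dimension, while known dimensions
    # stay as Python ints.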
    def _remove_timesteps(self, dims):
        dims = dims.as_list()
        return tf.TensorShape([dims[0]] + dims[2:])
    def build(self, input_shape):
        input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
        input_dims = tf.nest.flatten(
            tf.nest.map_structure(lambda x: x.ndims, input_shape)
        )
        if any(dim < 3 for dim in input_dims):
            raise ValueError(
                "`TimeDistributed` Layer should be passed an `input_shape` "
                f"with at least 3 dimensions, received: {input_shape}"
            )
        # Don't enforce the batch or time dimension.
        self.input_spec = tf.nest.map_structure(
            lambda x: InputSpec(shape=[None, None] + x.as_list()[2:]),
            input_shape,
        )
        child_input_shape = tf.nest.map_structure(
            self._remove_timesteps, input_shape
        )
        child_input_shape = tf_utils.convert_shapes(child_input_shape)
        super().build(tuple(child_input_shape))
        self.built = True
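
    # For example (hedged sketch): `build(tf.TensorShape([None, 10, 8]))`
    # sets `self.input_spec` to `InputSpec(shape=[None, None, 8])` and
    # builds the wrapped layer with the per-timestep shape `(None, 8)`.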
    def compute_output_shape(self, input_shape):
        input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)

        child_input_shape = tf.nest.map_structure(
            self._remove_timesteps, input_shape
        )
        child_output_shape = self.layer.compute_output_shape(
            child_input_shape
        )
        child_output_shape = tf_utils.convert_shapes(
            child_output_shape, to_tuples=False
        )
        timesteps = tf_utils.convert_shapes(input_shape)
        timesteps = tf.nest.flatten(timesteps)[1]

        def insert_timesteps(dims):
            dims = dims.as_list()
            return tf.TensorShape([dims[0], timesteps] + dims[1:])

        return tf.nest.map_structure(insert_timesteps, child_output_shape)
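
    # Hedged example: with a wrapped `Dense(4)` and input shape
    # `(None, 10, 8)`, the child shape `(None, 8)` maps to `(None, 4)`,
    # and re-inserting the 10 timesteps yields `(None, 10, 4)`.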
    def call(self, inputs, training=None, mask=None):
        kwargs = {}
        if generic_utils.has_arg(self.layer.call, "training"):
            kwargs["training"] = training

        input_shape = tf.nest.map_structure(
            lambda x: tf.TensorShape(backend.int_shape(x)), inputs
        )
        batch_size = tf_utils.convert_shapes(input_shape)
        batch_size = tf.nest.flatten(batch_size)[0]
        if batch_size and not self._always_use_reshape:
            inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
            is_ragged_input = row_lengths is not None
            input_length = tf_utils.convert_shapes(input_shape)
            input_length = tf.nest.flatten(input_length)[1]

            # batch size matters, use rnn-based implementation
            def step(x, _):
                output = self.layer(x, **kwargs)
                return output, []

            _, outputs, _ = backend.rnn(
                step,
                inputs,
                initial_states=[],
                input_length=row_lengths[0]
                if is_ragged_input
                else input_length,
                mask=mask,
                unroll=False,
            )

            y = tf.nest.map_structure(
                lambda output: backend.maybe_convert_to_ragged(
                    is_ragged_input, output, row_lengths
                ),
                outputs,
            )
        else:
            # No batch size specified, therefore the layer will be able
            # to process batches of any size.
            # We can go with reshape-based implementation for performance.
            is_ragged_input = tf.nest.map_structure(
                lambda x: isinstance(x, tf.RaggedTensor), inputs
            )
            is_ragged_input = tf.nest.flatten(is_ragged_input)
            if all(is_ragged_input):
                input_values = tf.nest.map_structure(lambda x: x.values, inputs)
                input_row_lengths = tf.nest.map_structure(
                    lambda x: x.nested_row_lengths()[0], inputs
                )
                y = self.layer(input_values, **kwargs)
                y = tf.nest.map_structure(
                    tf.RaggedTensor.from_row_lengths, y, input_row_lengths
                )
            elif any(is_ragged_input):
                raise ValueError(
                    "All inputs have to be either ragged or not, "
                    f"but not mixed. Received: {inputs}"
                )
            else:
                input_length = tf_utils.convert_shapes(input_shape)
                input_length = tf.nest.flatten(input_length)[1]
                if not input_length:
                    input_length = tf.nest.map_structure(
                        lambda x: tf.shape(x)[1], inputs
                    )
                    input_length = generic_utils.to_list(
                        tf.nest.flatten(input_length)
                    )[0]

                inner_input_shape = tf.nest.map_structure(
                    lambda x: self._get_shape_tuple((-1,), x, 2), inputs
                )
                # Shape: (num_samples * timesteps, ...). And track the
                # transformation in self._input_map.
                inputs = tf.__internal__.nest.map_structure_up_to(
                    inputs, tf.reshape, inputs, inner_input_shape
                )
                # (num_samples * timesteps, ...)
                if (
                    generic_utils.has_arg(self.layer.call, "mask")
                    and mask is not None
                ):
                    inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
                    kwargs["mask"] = backend.reshape(mask, inner_mask_shape)

                y = self.layer(inputs, **kwargs)

                # Reconstruct the output shape by re-splitting the 0th
                # dimension back into (num_samples, timesteps, ...)
                # We use batch_size when available so that the 0th dimension
                # is set in the static shape of the reshaped output
                reshape_batch_size = batch_size if batch_size else -1
                output_shape = tf.nest.map_structure(
                    lambda tensor: self._get_shape_tuple(
                        (reshape_batch_size, input_length), tensor, 1
                    ),
                    y,
                )
                y = tf.__internal__.nest.map_structure_up_to(
                    y, tf.reshape, y, output_shape
                )

        return y
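
    # A minimal sketch of the fast (reshape) path, assuming a built-in
    # wrapped layer and fully defined inner dimensions:
    #
    #   >>> x = tf.ones((2, 5, 8))            # (batch, time, features)
    #   >>> td = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(4))
    #   >>> td(x).shape
    #   TensorShape([2, 5, 4])
    #
    # Internally the input is reshaped to (batch * time, features), the
    # wrapped layer is applied once, and the result is reshaped back.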
    def compute_mask(self, inputs, mask=None):
        """Computes an output mask tensor for the TimeDistributed layer.

        This is based on the inputs, mask, and the inner layer.
        If batch size is specified:
          Simply return the input `mask`. (An rnn-based implementation with
          more than one rnn input is required but not supported in tf.keras
          yet.)
        Otherwise we call `compute_mask` of the inner layer at each time
        step. If the output mask at each time step is not `None`:
          (E.g., inner layer is Masking or RNN)
          Concatenate all of them and return the concatenation.
        If the output mask at each time step is `None` and the input mask is
        not `None`: (E.g., inner layer is Dense)
          Reduce the input_mask to 2 dimensions and return it.
        Otherwise (both the output mask and the input mask are `None`):
          (E.g., `mask` is not used at all)
          Return `None`.

        Args:
          inputs: Tensor with shape [batch size, timesteps, ...] indicating
            the input to TimeDistributed. If static shape information is
            available for "batch size", `mask` is returned unmodified.
          mask: Either None (indicating no masking) or a Tensor indicating
            the input mask for TimeDistributed. The shape can be static or
            dynamic.

        Returns:
          Either None (no masking), or a [batch size, timesteps, ...] Tensor
          with an output mask for the TimeDistributed layer: the shape
          beyond the second dimension is the value of the input mask shape
          (if the computed output mask is `None`), the shape beyond the
          first dimension is the value of the mask shape (if mask is not
          None), or the shape beyond the first dimension is the value of
          the computed output shape.
        """
        # cases need to call the layer.compute_mask when input_mask is None:
        # Masking layer and Embedding layer with mask_zero
        input_shape = tf.nest.map_structure(
            lambda x: tf.TensorShape(backend.int_shape(x)), inputs
        )
        input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
        batch_size = tf_utils.convert_shapes(input_shape)
        batch_size = tf.nest.flatten(batch_size)[0]
        is_ragged_input = tf.nest.map_structure(
            lambda x: isinstance(x, tf.RaggedTensor), inputs
        )
        is_ragged_input = generic_utils.to_list(
            tf.nest.flatten(is_ragged_input)
        )
        if batch_size and not self._always_use_reshape or any(is_ragged_input):
            # batch size matters, we currently do not handle mask explicitly,
            # or if the layer always uses the reshape approach, or the input
            # is a ragged tensor.
            return mask
        inner_mask = mask
        if inner_mask is not None:
            inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
            inner_mask = backend.reshape(inner_mask, inner_mask_shape)
        inner_input_shape = tf.nest.map_structure(
            lambda tensor: self._get_shape_tuple((-1,), tensor, 2), inputs
        )
        inner_inputs = tf.__internal__.nest.map_structure_up_to(
            inputs, tf.reshape, inputs, inner_input_shape
        )
        output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
        if output_mask is None:
            if mask is None:
                return None
            # input_mask is not None, and output_mask is None:
            # we should return a not-None mask
            output_mask = mask
            for _ in range(2, len(backend.int_shape(mask))):
                output_mask = backend.any(output_mask, axis=-1)
        else:
            # output_mask is not None. We need to reshape it
            input_length = tf_utils.convert_shapes(input_shape)
            input_length = tf.nest.flatten(input_length)[1]
            if not input_length:
                input_length = tf.nest.map_structure(
                    lambda x: backend.shape(x)[1], inputs
                )
                input_length = tf.nest.flatten(input_length)[0]
            reshape_batch_size = batch_size if batch_size else -1
            output_mask_shape = self._get_shape_tuple(
                (reshape_batch_size, input_length), output_mask, 1
            )
            output_mask = backend.reshape(output_mask, output_mask_shape)
        return output_mask
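

if __name__ == "__main__":
    # Hedged, illustrative usage sketch; not part of the original module.
    # It shows the wrapper sharing one set of Dense weights across all
    # timesteps of a (batch, time, features) input.
    inputs = tf.keras.Input(shape=(5, 8))
    outputs = TimeDistributed(tf.keras.layers.Dense(4))(inputs)
    model = tf.keras.Model(inputs, outputs)
    print(outputs.shape)  # (None, 5, 4)
    # A single kernel and bias are reused at every timestep.
    print(len(model.trainable_weights))  # 2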