Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/repeat_vector.py: 52%
23 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains the RepeatVector layer."""
18import tensorflow.compat.v2 as tf
20from keras.src import backend
21from keras.src.engine.base_layer import Layer
22from keras.src.engine.input_spec import InputSpec
24# isort: off
25from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.RepeatVector")
class RepeatVector(Layer):
    """Repeats the input n times.

    Example:

    ```python
    model = Sequential()
    model.add(Dense(32, input_dim=32))
    # now: model.output_shape == (None, 32)
    # note: `None` is the batch dimension

    model.add(RepeatVector(3))
    # now: model.output_shape == (None, 3, 32)
    ```

    Args:
      n: Integer, repetition factor. Booleans are rejected even though
        `bool` is a subclass of `int` in Python.

    Input shape:
      2D tensor of shape `(num_samples, features)`.

    Output shape:
      3D tensor of shape `(num_samples, n, features)`.
    """

    def __init__(self, n, **kwargs):
        """Create the layer.

        Args:
          n: Integer repetition factor.
          **kwargs: Forwarded unchanged to the base `Layer` constructor.

        Raises:
          TypeError: If `n` is not an `int` (or is a `bool`).
        """
        # Validate before mutating any layer state so a bad `n` fails fast.
        # `bool` subclasses `int`, so exclude it explicitly; otherwise
        # `RepeatVector(True)` would silently mean "repeat once".
        if not isinstance(n, int) or isinstance(n, bool):
            raise TypeError(
                f"Expected an integer value for `n`, got {type(n)}."
            )
        super().__init__(**kwargs)
        self.n = n
        # Declares that inputs must be rank 2.
        self.input_spec = InputSpec(ndim=2)

    def compute_output_shape(self, input_shape):
        """Return `[input_shape[0], n, input_shape[1]]` as a `TensorShape`."""
        input_shape = tf.TensorShape(input_shape).as_list()
        return tf.TensorShape([input_shape[0], self.n, input_shape[1]])

    def call(self, inputs):
        """Repeat `inputs` `self.n` times via `backend.repeat`."""
        return backend.repeat(inputs, self.n)

    def get_config(self):
        """Return the base layer config extended with `n`.

        Keys from this layer's config override same-named base keys.
        """
        config = {"n": self.n}
        base_config = super().get_config()
        return {**base_config, **config}