Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/repeat_vector.py: 52%

23 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains the RepeatVector layer.""" 

16 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src import backend 

21from keras.src.engine.base_layer import Layer 

22from keras.src.engine.input_spec import InputSpec 

23 

24# isort: off 

25from tensorflow.python.util.tf_export import keras_export 

26 

27 

@keras_export("keras.layers.RepeatVector")
class RepeatVector(Layer):
    """Tiles a 2D input `n` times along a new second axis.

    Example:

    ```python
    model = Sequential()
    model.add(Dense(32, input_dim=32))
    # now: model.output_shape == (None, 32)
    # note: `None` is the batch dimension

    model.add(RepeatVector(3))
    # now: model.output_shape == (None, 3, 32)
    ```

    Args:
      n: Integer, how many copies of each input row to produce. A
        non-`int` value raises `TypeError` at construction time.
    Input shape: 2D tensor of shape `(num_samples, features)`.
    Output shape: 3D tensor of shape `(num_samples, n, features)`.
    """

    def __init__(self, n, **kwargs):
        super().__init__(**kwargs)
        self.n = n
        # Reject anything that is not a Python int before declaring the
        # expected input rank.
        if isinstance(n, int):
            self.input_spec = InputSpec(ndim=2)
            return
        raise TypeError(
            f"Expected an integer value for `n`, got {type(n)}."
        )

    def compute_output_shape(self, input_shape):
        # Insert the repetition count between the batch and feature dims.
        dims = tf.TensorShape(input_shape).as_list()
        batch_dim = dims[0]
        feature_dim = dims[1]
        return tf.TensorShape([batch_dim, self.n, feature_dim])

    def call(self, inputs):
        # The actual tiling is delegated to the backend helper.
        repeated = backend.repeat(inputs, self.n)
        return repeated

    def get_config(self):
        # Parent config first, then this layer's own `n` entry.
        base_config = super().get_config()
        return {**base_config, "n": self.n}

70