Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/cropping1d.py: 43%

28 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras cropping layer for 1D input.""" 

16 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src.engine.base_layer import Layer 

21from keras.src.engine.input_spec import InputSpec 

22from keras.src.utils import conv_utils 

23 

24# isort: off 

25from tensorflow.python.util.tf_export import keras_export 

26 

27 

@keras_export("keras.layers.Cropping1D")
class Cropping1D(Layer):
    """Cropping layer for 1D input (e.g. temporal sequence).

    It crops along the time dimension (axis 1).

    Examples:

    >>> input_shape = (2, 3, 2)
    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
    >>> print(x)
    [[[ 0  1]
      [ 2  3]
      [ 4  5]]
     [[ 6  7]
      [ 8  9]
      [10 11]]]
    >>> y = tf.keras.layers.Cropping1D(cropping=1)(x)
    >>> print(y)
    tf.Tensor(
      [[[2 3]]
       [[8 9]]], shape=(2, 1, 2), dtype=int64)

    Args:
      cropping: Int or tuple of int (length 2).
        How many units should be trimmed off at the beginning and end of
        the cropping dimension (axis 1).
        If a single int is provided, the same value will be used for both.

    Input shape:
      3D tensor with shape `(batch_size, axis_to_crop, features)`

    Output shape:
      3D tensor with shape `(batch_size, cropped_axis, features)`
    """

    def __init__(self, cropping=(1, 1), **kwargs):
        super().__init__(**kwargs)
        # Normalized to a 2-tuple `(crop_start, crop_end)`; zero is allowed
        # so that only one end of the axis can be cropped.
        self.cropping = conv_utils.normalize_tuple(
            cropping, 2, "cropping", allow_zero=True
        )
        self.input_spec = InputSpec(ndim=3)

    def compute_output_shape(self, input_shape):
        """Return the output shape for a given 3D `input_shape`.

        Axis 1 shrinks by `cropping[0] + cropping[1]`; the batch and
        feature dimensions are passed through unchanged. If the length of
        axis 1 is unknown (`None`), the output length is also `None`.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        if input_shape[1] is not None:
            length = input_shape[1] - self.cropping[0] - self.cropping[1]
        else:
            length = None
        return tf.TensorShape([input_shape[0], length, input_shape[2]])

    def call(self, inputs):
        """Crop `inputs` along axis 1.

        Raises:
          ValueError: if the length of axis 1 is statically known and the
            total cropping would leave nothing (i.e. it is not strictly less
            than that length).
        """
        if (
            inputs.shape[1] is not None
            and sum(self.cropping) >= inputs.shape[1]
        ):
            raise ValueError(
                "The sum of the cropping values of a Cropping1D layer must "
                "be less than the length of the cropped axis (axis 1). "
                f"Received: inputs.shape={inputs.shape}, and "
                f"cropping={self.cropping}"
            )
        crop_start, crop_end = self.cropping
        # `x[a:-0]` would be an empty slice, so the "no end cropping" case
        # must use an open-ended slice instead.
        if crop_end == 0:
            return inputs[:, crop_start:, :]
        return inputs[:, crop_start:-crop_end, :]

    def get_config(self):
        """Return the layer config, adding `cropping` to the base config."""
        config = {"cropping": self.cropping}
        base_config = super().get_config()
        # Layer-specific keys take precedence over base-config keys.
        return {**base_config, **config}

98