Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/zero_padding3d.py: 25%

53 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras zero-padding layer for 3D input.""" 

16 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src import backend 

21from keras.src.engine.base_layer import Layer 

22from keras.src.engine.input_spec import InputSpec 

23from keras.src.utils import conv_utils 

24 

25# isort: off 

26from tensorflow.python.util.tf_export import keras_export 

27 

28 

@keras_export("keras.layers.ZeroPadding3D")
class ZeroPadding3D(Layer):
    """Zero-padding layer for 3D data (spatial or spatio-temporal).

    Examples:

    >>> input_shape = (1, 1, 2, 2, 3)
    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
    >>> y = tf.keras.layers.ZeroPadding3D(padding=2)(x)
    >>> print(y.shape)
    (1, 5, 6, 6, 3)

    Args:
      padding: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
        - If int: the same symmetric padding
          is applied to all three spatial dimensions.
        - If tuple of 3 ints:
          interpreted as three different
          symmetric padding values, one per spatial dimension:
          `(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)`.
        - If tuple of 3 tuples of 2 ints:
          interpreted as
          `((left_dim1_pad, right_dim1_pad), (left_dim2_pad,
            right_dim2_pad), (left_dim3_pad, right_dim3_pad))`
      data_format: A string,
        one of `channels_last` (default) or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
        while `channels_first` corresponds to inputs with shape
        `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
        When unspecified, uses
        `image_data_format` value found in your Keras config file at
         `~/.keras/keras.json` (if exists) else 'channels_last'.
        Defaults to 'channels_last'.

    Input shape:
      5D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch_size, first_axis_to_pad, second_axis_to_pad,
          third_axis_to_pad, depth)`
      - If `data_format` is `"channels_first"`:
          `(batch_size, depth, first_axis_to_pad, second_axis_to_pad,
          third_axis_to_pad)`

    Output shape:
      5D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch_size, first_padded_axis, second_padded_axis,
          third_padded_axis, depth)`
      - If `data_format` is `"channels_first"`:
          `(batch_size, depth, first_padded_axis, second_padded_axis,
          third_padded_axis)`
    """

    def __init__(self, padding=(1, 1, 1), data_format=None, **kwargs):
        """Normalize `padding` to `((l1, r1), (l2, r2), (l3, r3))`.

        Raises:
          ValueError: if `padding` is neither an int nor a length-3
            sequence.
        """
        super().__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        if isinstance(padding, int):
            self.padding = (
                (padding, padding),
                (padding, padding),
                (padding, padding),
            )
        elif hasattr(padding, "__len__"):
            if len(padding) != 3:
                raise ValueError(
                    f"`padding` should have 3 elements. Received: {padding}."
                )
            # Each entry may be a single int (symmetric) or a pair; the
            # helper expands/validates it into a 2-tuple.
            self.padding = tuple(
                conv_utils.normalize_tuple(
                    entry, 2, f"{ordinal} entry of padding", allow_zero=True
                )
                for entry, ordinal in zip(padding, ("1st", "2nd", "3rd"))
            )
        else:
            raise ValueError(
                "`padding` should be either an int, "
                "a tuple of 3 ints "
                "(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad), "
                "or a tuple of 3 tuples of 2 ints "
                "((left_dim1_pad, right_dim1_pad),"
                " (left_dim2_pad, right_dim2_pad),"
                " (left_dim3_pad, right_dim3_pad)). "
                f"Received: {padding}."
            )
        self.input_spec = InputSpec(ndim=5)

    def compute_output_shape(self, input_shape):
        """Return the padded 5D `tf.TensorShape` for `input_shape`.

        Unknown (None) spatial dimensions stay None; batch and channel
        dimensions are passed through unchanged.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        if self.data_format == "channels_first":
            spatial_axes = (2, 3, 4)
        elif self.data_format == "channels_last":
            spatial_axes = (1, 2, 3)
        else:
            # NOTE(review): presumably unreachable, since data_format comes
            # from conv_utils.normalize_data_format; kept to match the
            # implicit None return of the two-branch form.
            return None
        # Take exactly the five axes of the 5D layout; indexing raises
        # IndexError for lower-rank shapes.
        output_shape = [input_shape[axis] for axis in range(5)]
        for axis, (pad_before, pad_after) in zip(spatial_axes, self.padding):
            size = input_shape[axis]
            output_shape[axis] = (
                None if size is None else size + pad_before + pad_after
            )
        return tf.TensorShape(output_shape)

    def call(self, inputs):
        """Zero-pad `inputs` via `backend.spatial_3d_padding`."""
        return backend.spatial_3d_padding(
            inputs, padding=self.padding, data_format=self.data_format
        )

    def get_config(self):
        """Return the base layer config extended with padding/data_format."""
        config = {"padding": self.padding, "data_format": self.data_format}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

165