Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/zero_padding3d.py: 25%
53 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras zero-padding layer for 3D input."""
18import tensorflow.compat.v2 as tf
20from keras.src import backend
21from keras.src.engine.base_layer import Layer
22from keras.src.engine.input_spec import InputSpec
23from keras.src.utils import conv_utils
25# isort: off
26from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.ZeroPadding3D")
class ZeroPadding3D(Layer):
    """Zero-padding layer for 3D data (spatial or spatio-temporal).

    Examples:

    >>> input_shape = (1, 1, 2, 2, 3)
    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
    >>> y = tf.keras.layers.ZeroPadding3D(padding=2)(x)
    >>> print(y.shape)
    (1, 5, 6, 6, 3)

    Args:
      padding: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
        - If int: the same symmetric padding
          is applied to all three spatial dimensions.
        - If tuple of 3 ints:
          interpreted as three different
          symmetric padding values, one per spatial dimension:
          `(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)`.
        - If tuple of 3 tuples of 2 ints:
          interpreted as
          `((left_dim1_pad, right_dim1_pad), (left_dim2_pad,
            right_dim2_pad), (left_dim3_pad, right_dim3_pad))`
      data_format: A string,
        one of `channels_last` (default) or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
        while `channels_first` corresponds to inputs with shape
        `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
        When unspecified, uses
        `image_data_format` value found in your Keras config file at
         `~/.keras/keras.json` (if exists) else 'channels_last'.
        Defaults to 'channels_last'.

    Input shape:
      5D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch_size, first_axis_to_pad, second_axis_to_pad,
            third_axis_to_pad, depth)`
      - If `data_format` is `"channels_first"`:
          `(batch_size, depth, first_axis_to_pad, second_axis_to_pad,
            third_axis_to_pad)`

    Output shape:
      5D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch_size, first_padded_axis, second_padded_axis,
            third_padded_axis, depth)`
      - If `data_format` is `"channels_first"`:
          `(batch_size, depth, first_padded_axis, second_padded_axis,
            third_padded_axis)`
    """

    def __init__(self, padding=(1, 1, 1), data_format=None, **kwargs):
        """Normalize `padding` into three `(before, after)` pairs.

        Args:
          padding: Int, sequence of 3 ints, or sequence of 3 pairs of ints.
          data_format: `"channels_last"`, `"channels_first"` or `None`;
            passed through `conv_utils.normalize_data_format`.
          **kwargs: Forwarded unchanged to the base `Layer` constructor.

        Raises:
          ValueError: If `padding` is a sized object whose length is not 3,
            or is neither an int nor a sized object.
        """
        super().__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        if isinstance(padding, int):
            # A single int pads every spatial dimension symmetrically.
            self.padding = ((padding, padding),) * 3
        elif hasattr(padding, "__len__"):
            if len(padding) != 3:
                raise ValueError(
                    f"`padding` should have 3 elements. Received: {padding}."
                )
            # Each entry may itself be an int or a pair; `normalize_tuple`
            # turns both forms into a 2-tuple. The names are only used in
            # the error messages `normalize_tuple` produces.
            entry_names = (
                "1st entry of padding",
                "2nd entry of padding",
                "3rd entry of padding",
            )
            self.padding = tuple(
                conv_utils.normalize_tuple(
                    padding[index], 2, entry_name, allow_zero=True
                )
                for index, entry_name in enumerate(entry_names)
            )
        else:
            raise ValueError(
                "`padding` should be either an int, "
                "a tuple of 3 ints "
                "(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad), "
                "or a tuple of 3 tuples of 2 ints "
                "((left_dim1_pad, right_dim1_pad),"
                " (left_dim2_pad, right_dim2_pad),"
                " (left_dim3_pad, right_dim3_pad)). "
                f"Received: {padding}."
            )
        # The layer only accepts rank-5 inputs.
        self.input_spec = InputSpec(ndim=5)

    def compute_output_shape(self, input_shape):
        """Return the input shape with each spatial axis grown by its padding.

        Args:
          input_shape: Anything accepted by `tf.TensorShape`.

        Returns:
          A `tf.TensorShape` in which every known spatial dimension has been
          increased by the sum of its two padding amounts. Unknown (`None`)
          dimensions stay unknown. `None` is returned when `data_format` is
          not one of the two supported layouts.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        # Index of the first padded axis: channels sit at axis 1 for
        # `channels_first`, pushing the spatial axes one slot to the right.
        if self.data_format == "channels_first":
            first_spatial_axis = 2
        elif self.data_format == "channels_last":
            first_spatial_axis = 1
        else:
            # Kept for backward compatibility: an unrecognised layout has
            # never produced a shape here.
            return None

        output_shape = list(input_shape)
        for offset, axis_padding in enumerate(self.padding):
            axis = first_spatial_axis + offset
            if output_shape[axis] is not None:
                output_shape[axis] += axis_padding[0] + axis_padding[1]
        return tf.TensorShape(output_shape)

    def call(self, inputs):
        """Pad `inputs` by delegating to `backend.spatial_3d_padding`."""
        return backend.spatial_3d_padding(
            inputs, padding=self.padding, data_format=self.data_format
        )

    def get_config(self):
        """Return the base layer config extended with this layer's settings.

        Keys defined here (`padding`, `data_format`) take precedence over
        same-named keys from the base config.
        """
        config = {"padding": self.padding, "data_format": self.data_format}
        base_config = super().get_config()
        return {**base_config, **config}
164 return dict(list(base_config.items()) + list(config.items()))