Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/up_sampling3d.py: 42%
31 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras upsampling layer for 3D inputs."""
18import tensorflow.compat.v2 as tf
20from keras.src import backend
21from keras.src.engine.base_layer import Layer
22from keras.src.engine.input_spec import InputSpec
23from keras.src.utils import conv_utils
25# isort: off
26from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.UpSampling3D")
class UpSampling3D(Layer):
    """Upsampling layer for 3D inputs.

    Repeats each of the three spatial dimensions of the data by `size[0]`,
    `size[1]` and `size[2]` respectively. These are axes 1-3 for
    `channels_last` inputs and axes 2-4 for `channels_first` inputs. The
    batch and channel axes are left unchanged.

    Examples:

    >>> input_shape = (2, 1, 2, 1, 3)
    >>> x = tf.constant(1, shape=input_shape)
    >>> y = tf.keras.layers.UpSampling3D(size=2)(x)
    >>> print(y.shape)
    (2, 2, 4, 2, 3)

    Args:
      size: Int, or tuple of 3 integers.
        The upsampling factors for dim1, dim2 and dim3.
      data_format: A string,
        one of `channels_last` (default) or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
        while `channels_first` corresponds to inputs with shape
        `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
        When unspecified (`None`), uses the `image_data_format` value found
        in your Keras config file at `~/.keras/keras.json` (if it exists),
        otherwise 'channels_last'.

    Input shape:
      5D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch_size, dim1, dim2, dim3, channels)`
      - If `data_format` is `"channels_first"`:
          `(batch_size, channels, dim1, dim2, dim3)`

    Output shape:
      5D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch_size, upsampled_dim1, upsampled_dim2, upsampled_dim3,
          channels)`
      - If `data_format` is `"channels_first"`:
          `(batch_size, channels, upsampled_dim1, upsampled_dim2,
          upsampled_dim3)`
    """

    def __init__(self, size=(2, 2, 2), data_format=None, **kwargs):
        self.data_format = conv_utils.normalize_data_format(data_format)
        # Normalizes an int or a 3-sequence into a 3-tuple of factors.
        self.size = conv_utils.normalize_tuple(size, 3, "size")
        self.input_spec = InputSpec(ndim=5)
        super().__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        """Return the static output shape for `input_shape`.

        Each spatial axis is multiplied by its factor from `self.size`.
        Axes whose length is unknown (`None`) stay unknown. The batch and
        channel axes are copied through unchanged.
        """
        output_shape = tf.TensorShape(input_shape).as_list()
        # The spatial axes sit after the channel axis for channels_first
        # and directly after the batch axis for channels_last.
        if self.data_format == "channels_first":
            spatial_axes = (2, 3, 4)
        else:
            spatial_axes = (1, 2, 3)
        for axis, factor in zip(spatial_axes, self.size):
            if output_shape[axis] is not None:
                output_shape[axis] = factor * output_shape[axis]
        return tf.TensorShape(output_shape)

    def call(self, inputs):
        """Upsample `inputs` by repeating entries along the spatial axes."""
        return backend.resize_volumes(
            inputs, self.size[0], self.size[1], self.size[2], self.data_format
        )

    def get_config(self):
        """Return the layer config: base config plus `size`/`data_format`."""
        config = {"size": self.size, "data_format": self.data_format}
        base_config = super().get_config()
        # Layer-specific keys override any same-named base keys.
        return {**base_config, **config}