Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/up_sampling2d.py: 45%
31 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras upsampling layer for 2D inputs."""
18import tensorflow.compat.v2 as tf
20from keras.src import backend
21from keras.src.engine.base_layer import Layer
22from keras.src.engine.input_spec import InputSpec
23from keras.src.utils import conv_utils
24from keras.src.utils import image_utils
26# isort: off
27from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.UpSampling2D")
class UpSampling2D(Layer):
    """Upsampling layer for 2D inputs.

    Repeats the rows of the input `size[0]` times and its columns
    `size[1]` times.

    Examples:

    >>> input_shape = (2, 2, 1, 3)
    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
    >>> print(x)
    [[[[ 0  1  2]]
      [[ 3  4  5]]]
     [[[ 6  7  8]]
      [[ 9 10 11]]]]
    >>> y = tf.keras.layers.UpSampling2D(size=(1, 2))(x)
    >>> print(y)
    tf.Tensor(
      [[[[ 0  1  2]
         [ 0  1  2]]
        [[ 3  4  5]
         [ 3  4  5]]]
       [[[ 6  7  8]
         [ 6  7  8]]
        [[ 9 10 11]
         [ 9 10 11]]]], shape=(2, 2, 2, 3), dtype=int64)

    Args:
      size: Int, or tuple of 2 integers.
        The upsampling factors for rows and columns. Defaults to `(2, 2)`.
      data_format: A string,
        one of `channels_last` (default) or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch_size, height, width, channels)` while `channels_first`
        corresponds to inputs with shape
        `(batch_size, channels, height, width)`.
        When unspecified, uses the `image_data_format` value found in your
        Keras config file at `~/.keras/keras.json` (if it exists), else
        'channels_last'.
      interpolation: A string, one of `"area"`, `"bicubic"`, `"bilinear"`,
        `"gaussian"`, `"lanczos3"`, `"lanczos5"`, `"mitchellcubic"`,
        `"nearest"`. Defaults to `"nearest"`.

    Input shape:
      4D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch_size, rows, cols, channels)`
      - If `data_format` is `"channels_first"`:
          `(batch_size, channels, rows, cols)`

    Output shape:
      4D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch_size, upsampled_rows, upsampled_cols, channels)`
      - If `data_format` is `"channels_first"`:
          `(batch_size, channels, upsampled_rows, upsampled_cols)`
    """

    def __init__(
        self,
        size=(2, 2),
        data_format=None,
        interpolation="nearest",
        **kwargs,
    ):
        """Store the normalized layer options.

        Args:
          size: Int or pair of ints; passed through
            `conv_utils.normalize_tuple` to obtain a 2-tuple.
          data_format: Data format string or `None`; passed through
            `conv_utils.normalize_data_format`.
          interpolation: Interpolation name; resolved with
            `image_utils.get_interpolation`.
          **kwargs: Forwarded unchanged to the base `Layer` constructor.
        """
        super().__init__(**kwargs)
        # Normalize every user-supplied option once, at construction time,
        # so the other methods can read the attributes directly.
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.size = conv_utils.normalize_tuple(size, 2, "size")
        self.interpolation = image_utils.get_interpolation(interpolation)
        # The layer only accepts rank-4 inputs.
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        """Return the output shape for a given 4D input shape.

        The two spatial dimensions are multiplied by `size[0]` (rows) and
        `size[1]` (columns); a spatial dimension that is `None` stays
        `None`. The batch and channel dimensions are copied as they are.

        Args:
          input_shape: Anything accepted by `tf.TensorShape`.

        Returns:
          A `tf.TensorShape` with four dimensions.
        """
        dims = tf.TensorShape(input_shape).as_list()
        channels_first = self.data_format == "channels_first"
        # Position of the (rows, cols) axes depends on the data format.
        row_axis, col_axis = (2, 3) if channels_first else (1, 2)

        rows = dims[row_axis]
        new_rows = None if rows is None else self.size[0] * rows
        cols = dims[col_axis]
        new_cols = None if cols is None else self.size[1] * cols

        if channels_first:
            return tf.TensorShape([dims[0], dims[1], new_rows, new_cols])
        return tf.TensorShape([dims[0], new_rows, new_cols, dims[3]])

    def call(self, inputs):
        """Upsample `inputs` by delegating to `backend.resize_images`.

        Args:
          inputs: The tensor to resize.

        Returns:
          Whatever `backend.resize_images` returns for the configured
          factors, data format and interpolation.
        """
        height_factor = self.size[0]
        width_factor = self.size[1]
        return backend.resize_images(
            inputs,
            height_factor,
            width_factor,
            self.data_format,
            interpolation=self.interpolation,
        )

    def get_config(self):
        """Return the layer configuration as a dict.

        The result holds the base layer's config entries followed by this
        layer's `size`, `data_format` and `interpolation`; on a key
        collision the value from this layer wins.
        """
        layer_config = {
            "size": self.size,
            "data_format": self.data_format,
            "interpolation": self.interpolation,
        }
        return {**super().get_config(), **layer_config}