Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/max_unpooling_2d.py: 24%
58 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""MaxUnpooling2D operation."""
17import tensorflow as tf
19from typeguard import typechecked
20from typing import Union, Iterable
22from tensorflow_addons.utils.keras_utils import normalize_tuple
25def _calculate_output_shape(input_shape, pool_size, strides, padding):
26 """Calculates the shape of the unpooled output."""
27 if padding == "VALID":
28 output_shape = (
29 input_shape[0],
30 (input_shape[1] - 1) * strides[0] + pool_size[0],
31 (input_shape[2] - 1) * strides[1] + pool_size[1],
32 input_shape[3],
33 )
34 elif padding == "SAME":
35 output_shape = (
36 input_shape[0],
37 input_shape[1] * strides[0],
38 input_shape[2] * strides[1],
39 input_shape[3],
40 )
41 else:
42 raise ValueError('Padding must be a string from: "SAME", "VALID"')
43 return output_shape
def _max_unpooling_2d(updates, mask, pool_size=(2, 2), strides=(2, 2), padding="SAME"):
    """Scatter max-pooled values back to the positions recorded in `mask`.

    Args:
      updates: Pooled values, indexed as a 4D `(batch, height, width, channels)`
        tensor.
      mask: Argmax positions matching `updates`; cast to int32. Each entry is
        decoded as a flat `(row, col, channel)` offset inside one batch
        element, so the batch index must not be folded into it.
      pool_size: Pool size the max pooling used, `(height, width)`.
      strides: Strides the max pooling used, `(height, width)`.
      padding: "SAME" or "VALID", as used by the max pooling.

    Returns:
      A tensor whose shape comes from `_calculate_output_shape`. It holds each
      value of `updates` at the location given by `mask`. `tf.scatter_nd`
      leaves every other location at zero and sums values that share a
      location.
    """

    def _int_list(values):
        # Serialise a (height, width) tuple as repeated `i: <n>` entries.
        return " ".join(["i: %d" % v for v in values])

    # Text signature attached to the traced function. Presumably this lets
    # graph converters (e.g. TFLite) recognise the op as
    # "addons:MaxUnpooling2D" -- TODO confirm. The exact text is significant.
    signature_parts = (
        'name: "addons:MaxUnpooling2D"',
        'attr { key: "pool_size" value { list {%s} } }' % _int_list(pool_size),
        'attr { key: "strides" value { list {%s} } }' % _int_list(strides),
        'attr { key: "padding" value { s: "%s" } }' % padding,
    )

    @tf.function(experimental_implements=" ".join(signature_parts))
    def func(updates, mask):
        mask = tf.cast(mask, "int32")
        dynamic_shape = tf.shape(updates, out_type="int32")
        # Prefer each static dimension; fall back to the runtime one if unknown.
        input_shape = [
            updates.shape[axis] or dynamic_shape[axis] for axis in range(4)
        ]
        output_shape = _calculate_output_shape(
            input_shape, pool_size, strides, padding
        )
        out_width, out_channels = output_shape[2], output_shape[3]

        ones = tf.ones_like(mask, dtype="int32")

        # Batch coordinate of every element: broadcast a (batch, 1, 1, 1) range.
        batch_column = tf.reshape(
            tf.range(output_shape[0], dtype="int32"),
            shape=tf.concat([[input_shape[0]], [1], [1], [1]], axis=0),
        )
        batch_idx = ones * batch_column
        # Decode the flat per-image argmax into row and column coordinates.
        row_idx = mask // (out_width * out_channels)
        col_idx = (mask // out_channels) % out_width
        # Channel coordinate of every element: broadcast a (channels,) range.
        chan_idx = ones * tf.range(out_channels, dtype="int32")

        # One (b, y, x, c) index row per pooled element, paired with its value.
        n_elems = tf.size(updates)
        coords = tf.stack([batch_idx, row_idx, col_idx, chan_idx])
        scatter_indices = tf.transpose(tf.reshape(coords, [4, n_elems]))
        flat_values = tf.reshape(updates, [n_elems])
        return tf.scatter_nd(scatter_indices, flat_values, output_shape)

    return func(updates, mask)
@tf.keras.utils.register_keras_serializable(package="Addons")
class MaxUnpooling2D(tf.keras.layers.Layer):
    """Unpool the outputs of a maximum pooling operation.

    This layer currently does not support outputs of MaxPoolingWithArgMax in
    the following cases:
    - include_batch_in_index equals true.
    - input_shape is not divisible by strides if padding is "SAME".
    - (input_shape - pool_size) is not divisible by strides if padding is "VALID".
    - The max pooling operation results in duplicate values in updates and mask.

    Args:
      pool_size: The filter that max pooling was performed with; an int or a
        pair of ints. Default: (2, 2).
      strides: The strides that max pooling was performed with; an int or a
        pair of ints. Default: (2, 2).
      padding: The padding that max pooling was performed with, "SAME" or
        "VALID". Default: "SAME".
      **kwargs: Forwarded to `tf.keras.layers.Layer`.

    Call arguments:
      updates: The pooling result from max pooling.
      mask: The argmax result corresponding to the max values in `updates`.

    Input shape:
      4D tensor with shape: `(batch_size, height, width, channel)`.

    Output shape:
      4D tensor with the same shape as the input of the max pooling operation.

    Raises:
      ValueError: At construction, if `padding` is neither "SAME" nor "VALID".
    """

    @typechecked
    def __init__(
        self,
        pool_size: Union[int, Iterable[int]] = (2, 2),
        strides: Union[int, Iterable[int]] = (2, 2),
        padding: str = "SAME",
        **kwargs,
    ):
        # Explicit two-argument super(): `@typechecked` re-wraps this method.
        super(MaxUnpooling2D, self).__init__(**kwargs)

        # Reject bad padding at construction time rather than on first call.
        if padding not in ("SAME", "VALID"):
            raise ValueError('Padding must be a string from: "SAME", "VALID"')

        # Stored as 2-tuples (via normalize_tuple) so an int means (n, n).
        self.pool_size = normalize_tuple(pool_size, 2, "pool_size")
        self.strides = normalize_tuple(strides, 2, "strides")
        self.padding = padding

    def call(self, updates, mask):
        # All the work lives in the module-level helper; the layer only
        # supplies its stored configuration.
        pooling_config = dict(
            pool_size=self.pool_size, strides=self.strides, padding=self.padding
        )
        return _max_unpooling_2d(updates, mask, **pooling_config)

    def compute_output_shape(self, input_shapes):
        # `input_shapes` presumably holds the shapes of (updates, mask), which
        # must match; the second (mask) shape is used -- TODO confirm ordering.
        mask_shape = input_shapes[1]
        return _calculate_output_shape(
            mask_shape, self.pool_size, self.strides, self.padding
        )

    def get_config(self):
        config = super(MaxUnpooling2D, self).get_config()
        # Constructor arguments needed to rebuild the layer from its config.
        config.update(
            pool_size=self.pool_size, strides=self.strides, padding=self.padding
        )
        return config