Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/spatial_pyramid_pooling.py: 24%
55 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Spatial Pyramid Pooling layers"""
17import tensorflow as tf
18from tensorflow_addons.layers.adaptive_pooling import AdaptiveAveragePooling2D
19import tensorflow_addons.utils.keras_utils as conv_utils
21from typeguard import typechecked
22from typing import Union, Iterable
@tf.keras.utils.register_keras_serializable(package="Addons")
class SpatialPyramidPooling2D(tf.keras.layers.Layer):
    """Performs Spatial Pyramid Pooling.

    See [Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition](https://arxiv.org/pdf/1406.4729.pdf).

    Spatial Pyramid Pooling generates a fixed-length representation
    regardless of input size/scale. It is typically used before a layer
    that requires a constant input shape, for example before a Dense Layer.

    For each pyramid level, the input's trailing rows/columns are cropped so
    that its height and width are multiples of that level's bin size, then
    it is average-pooled to exactly `(pooled_rows, pooled_cols)` cells.

    Args:
      bins: Either a collection of integers or a collection of collections of 2 integers.
        Each element in the inner collection must contain 2 integers, (pooled_rows, pooled_cols)
        For example, providing [1, 3, 5] or [[1, 1], [3, 3], [5, 5]] performs pooling
        using three different pooling layers, having outputs with dimensions 1x1, 3x3 and 5x5 respectively.
        These are flattened along height and width to give an output of shape
        [batch_size, (1 + 9 + 25), channels] = [batch_size, 35, channels].
      data_format: A string,
        one of `channels_last` (default) or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch, height, width, channels)` while `channels_first`
        corresponds to inputs with shape `(batch, channels, height, width)`.

    Input shape:
      - If `data_format='channels_last'`:
        4D tensor with shape `(batch_size, height, width, channels)`.
      - If `data_format='channels_first'`:
        4D tensor with shape `(batch_size, channels, height, width)`.

    Output shape:
      The output is the pooled image, flattened across its height and width
      - If `data_format='channels_last'`:
        3D tensor with shape `(batch_size, num_bins, channels)`.
      - If `data_format='channels_first'`:
        3D tensor with shape `(batch_size, channels, num_bins)`.
    """

    @typechecked
    def __init__(
        self,
        bins: Union[Iterable[int], Iterable[Iterable[int]]],
        data_format=None,
        *args,
        **kwargs,
    ):
        # Initialize the Keras base layer before assigning any attributes,
        # as recommended for custom layers.
        super().__init__(*args, **kwargs)
        # Normalize every bin spec to a (pooled_rows, pooled_cols) pair; per
        # the class docstring, a scalar n is equivalent to (n, n).
        self.bins = [
            conv_utils.normalize_tuple(bin_size, 2, "bin") for bin_size in bins
        ]
        self.data_format = conv_utils.normalize_data_format(data_format)
        # One adaptive average-pooling layer per pyramid level, in bin order.
        self.pool_layers = [
            AdaptiveAveragePooling2D(bin_size, self.data_format)
            for bin_size in self.bins
        ]

    def call(self, inputs, **kwargs):
        """Pools `inputs` at every pyramid level and concatenates the results.

        Args:
          inputs: 4D tensor laid out according to `self.data_format`.

        Returns:
          3D tensor of shape `(batch, num_bins, channels)` for
          `channels_last`, or `(batch, channels, num_bins)` for
          `channels_first`.
        """
        channels_last = self.data_format == "channels_last"
        # Axis positions of height/width and channels for this layout.
        height_axis, width_axis = (1, 2) if channels_last else (2, 3)
        channel_axis = -1 if channels_last else 1

        dynamic_input_shape = tf.shape(inputs)
        batch_size = dynamic_input_shape[0]
        height = dynamic_input_shape[height_axis]
        width = dynamic_input_shape[width_axis]
        # Prefer the static channel count so the output keeps a known channel
        # dimension; fall back to the dynamic size when it is unknown.
        channels = inputs.shape[channel_axis]
        if channels is None:
            channels = dynamic_input_shape[channel_axis]

        outputs = []
        for (bin_rows, bin_cols), pool_layer in zip(self.bins, self.pool_layers):
            # Crop trailing rows/cols so the spatial size divides evenly by
            # the bin size before adaptive pooling.
            new_height = height - height % bin_rows
            new_width = width - width % bin_cols
            num_cells = bin_rows * bin_cols
            if channels_last:
                cropped = inputs[:, :new_height, :new_width, :]
                pooled = pool_layer(cropped)
                pooled = tf.reshape(pooled, [batch_size, num_cells, channels])
            else:
                cropped = inputs[:, :, :new_height, :new_width]
                pooled = pool_layer(cropped)
                pooled = tf.reshape(pooled, [batch_size, channels, num_cells])
            outputs.append(pooled)

        # Stack the pyramid levels along the flattened-cells axis.
        return tf.concat(outputs, axis=1 if channels_last else 2)

    def compute_output_shape(self, input_shape):
        """Returns the static output shape for a 4D `input_shape`.

        The middle/last dimension is the total number of pooled cells,
        i.e. the sum of `pooled_rows * pooled_cols` over all bins.
        """
        # Plain Python ints (not tf ops) so this also works in graph mode,
        # where a symbolic tensor cannot be used as a TensorShape dimension.
        pooled_shape = sum(bin_rows * bin_cols for bin_rows, bin_cols in self.bins)
        if self.data_format == "channels_last":
            return tf.TensorShape([input_shape[0], pooled_shape, input_shape[-1]])
        return tf.TensorShape([input_shape[0], input_shape[1], pooled_shape])

    def get_config(self):
        """Returns the layer config, including `bins` and `data_format`."""
        config = {"bins": self.bins, "data_format": self.data_format}
        base_config = super().get_config()
        return {**base_config, **config}