Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/maxout.py: 27%
37 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementing Maxout layer."""
17import tensorflow as tf
18from typeguard import typechecked
@tf.keras.utils.register_keras_serializable(package="Addons")
class Maxout(tf.keras.layers.Layer):
    """Applies Maxout to the input.

    "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron
    Courville, Yoshua Bengio. https://arxiv.org/abs/1302.4389

    Usually the operation is performed in the filter/channel dimension. This
    can also be used after Dense layers to reduce the number of features.

    Args:
      num_units: Specifies how many features will remain after maxout
        in the `axis` dimension (usually channel). Must be a positive
        integer that divides the number of features along `axis`.
      axis: The dimension where max pooling will be performed. Default is the
        last dimension. Negative values count from the end.

    Input shape:
      nD tensor with shape: `(batch_size, ..., axis_dim, ...)`.

    Output shape:
      nD tensor with shape: `(batch_size, ..., num_units, ...)`.

    Raises:
      ValueError: if `num_units` is not positive, if `axis` is out of range
        for the input rank, or if a statically known `axis` dimension is not
        a multiple of `num_units`.
    """

    @typechecked
    def __init__(self, num_units: int, axis: int = -1, **kwargs):
        super().__init__(**kwargs)
        # Reject non-positive values early: 0 would divide by zero in `call`
        # and negative values would produce a meaningless reshape.
        if num_units <= 0:
            raise ValueError(
                "num_units must be a positive integer, got {}".format(num_units)
            )
        self.num_units = num_units
        self.axis = axis

    def call(self, inputs):
        """Split `axis` into groups of `num_units` features and take the max.

        The `axis` dimension of size `C` is reshaped into `(C // num_units,
        num_units)` and reduced with `max` over the first of those two dims.
        """
        inputs = tf.convert_to_tensor(inputs)
        static_shape = inputs.get_shape().as_list()
        rank = len(static_shape)

        # Normalize and validate the axis with an explicit exception; an
        # `assert` would be stripped under `python -O`.
        axis = self.axis + rank if self.axis < 0 else self.axis
        if not 0 <= axis < rank:
            raise ValueError(
                "Invalid axis: {} for input of rank {}".format(self.axis, rank)
            )

        # Dealing with batches with arbitrary sizes: use the runtime size for
        # every dimension that is unknown at trace time.
        dynamic_shape = tf.shape(inputs)
        shape = [
            dynamic_shape[i] if dim is None else dim
            for i, dim in enumerate(static_shape)
        ]

        num_channels = shape[axis]
        # Divisibility can only be checked here when the channel dimension is
        # static; for a dynamic one, tf.reshape below fails at run time instead.
        if not isinstance(num_channels, tf.Tensor) and num_channels % self.num_units:
            raise ValueError(
                "number of features({}) is not "
                "a multiple of num_units({})".format(num_channels, self.num_units)
            )

        # Replace the channel dim with (k, num_units), then reduce over k.
        k = num_channels // self.num_units
        expand_shape = shape[:axis] + [k, self.num_units] + shape[axis + 1 :]

        outputs = tf.math.reduce_max(
            tf.reshape(inputs, expand_shape), axis, keepdims=False
        )
        return outputs

    def compute_output_shape(self, input_shape):
        """Return `input_shape` with the `axis` dimension set to `num_units`."""
        input_shape = tf.TensorShape(input_shape).as_list()
        input_shape[self.axis] = self.num_units
        return tf.TensorShape(input_shape)

    def get_config(self):
        """Return the layer config, including `num_units` and `axis`."""
        config = {"num_units": self.num_units, "axis": self.axis}
        base_config = super().get_config()
        return {**base_config, **config}