Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/maxout.py: 27%

37 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implementing Maxout layer.""" 

16 

17import tensorflow as tf 

18from typeguard import typechecked 

19 

20 

@tf.keras.utils.register_keras_serializable(package="Addons")
class Maxout(tf.keras.layers.Layer):
    """Applies Maxout to the input.

    "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron
    Courville, Yoshua Bengio. https://arxiv.org/abs/1302.4389

    Usually the operation is performed in the filter/channel dimension. This
    can also be used after Dense layers to reduce number of features.

    The `axis` dimension of size `num_channels` is reshaped to
    `(num_channels // num_units, num_units)`, and the maximum is taken over the
    first of those two dimensions. As a result, output unit `j` is the maximum
    of input features `j, j + num_units, j + 2 * num_units, ...`.

    Args:
      num_units: Specifies how many features will remain after maxout
        in the `axis` dimension (usually channel). Must be positive, and
        the number of features along `axis` must be a multiple of it.
      axis: The dimension where max pooling will be performed. Default is the
        last dimension. Negative values count from the end.

    Input shape:
      nD tensor with shape: `(batch_size, ..., axis_dim, ...)`.

    Output shape:
      nD tensor with shape: `(batch_size, ..., num_units, ...)`.
    """

    @typechecked
    def __init__(self, num_units: int, axis: int = -1, **kwargs):
        """Stores the layer configuration.

        Raises:
          ValueError: if `num_units` is not positive. Otherwise this would only
            surface later in `call` as a ZeroDivisionError or a reshape error.
        """
        super().__init__(**kwargs)
        if num_units <= 0:
            raise ValueError(
                "num_units must be a positive integer, got {}".format(num_units)
            )
        self.num_units = num_units
        self.axis = axis

    def call(self, inputs):
        """Computes the maxout of `inputs` along `self.axis`.

        Raises:
          ValueError: if `self.axis` is out of range for the input rank, or if
            the statically known size of that dimension is not a multiple of
            `num_units`.
        """
        inputs = tf.convert_to_tensor(inputs)
        static_shape = inputs.get_shape().as_list()
        rank = len(static_shape)

        # Normalize and validate the axis explicitly. An `assert` would be
        # stripped under `python -O`, and indexing with a bad axis would only
        # give an uninformative IndexError.
        axis = self.axis + rank if self.axis < 0 else self.axis
        if not 0 <= axis < rank:
            raise ValueError(
                "Invalid axis {} for input of rank {}".format(self.axis, rank)
            )

        # Dealing with batches with arbitrary sizes: substitute the runtime
        # size for every dimension that is unknown at graph-construction time.
        dynamic_shape = tf.shape(inputs)
        shape = [
            dynamic_shape[i] if dim is None else dim
            for i, dim in enumerate(static_shape)
        ]

        num_channels = shape[axis]
        # The divisibility check is only possible when the size is static. For
        # a dynamic size, a mismatch surfaces as an error from `tf.reshape`.
        if not isinstance(num_channels, tf.Tensor) and num_channels % self.num_units:
            raise ValueError(
                "number of features({}) is not "
                "a multiple of num_units({})".format(num_channels, self.num_units)
            )

        # Split the `axis` dimension into (k, num_units) and reduce over k.
        k = num_channels // self.num_units
        expand_shape = shape[:axis] + [k, self.num_units] + shape[axis + 1 :]

        return tf.math.reduce_max(
            tf.reshape(inputs, expand_shape), axis, keepdims=False
        )

    def compute_output_shape(self, input_shape):
        """Returns `input_shape` with the `axis` dimension set to `num_units`.

        If the input rank is unknown, the unknown shape is returned unchanged,
        since the position of `axis` cannot be resolved.
        """
        input_shape = tf.TensorShape(input_shape)
        if input_shape.rank is None:
            return input_shape
        output_shape = input_shape.as_list()
        output_shape[self.axis] = self.num_units
        return tf.TensorShape(output_shape)

    def get_config(self):
        """Returns the base layer config extended with `num_units` and `axis`."""
        config = {"num_units": self.num_units, "axis": self.axis}
        base_config = super().get_config()
        return {**base_config, **config}