Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/pooling/global_average_pooling2d.py: 73%
11 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Global average pooling 2D layer."""
18from keras.src import backend
19from keras.src.layers.pooling.base_global_pooling2d import GlobalPooling2D
21# isort: off
22from tensorflow.python.util.tf_export import keras_export
@keras_export(
    "keras.layers.GlobalAveragePooling2D", "keras.layers.GlobalAvgPool2D"
)
class GlobalAveragePooling2D(GlobalPooling2D):
    """Global average pooling over the two spatial axes of 4D data.

    Each channel is averaged across its height and width, yielding one value
    per channel (or a 1x1 spatial map per channel when `keepdims=True`).

    Examples:

    >>> input_shape = (2, 4, 5, 3)
    >>> x = tf.random.normal(input_shape)
    >>> y = tf.keras.layers.GlobalAveragePooling2D()(x)
    >>> print(y.shape)
    (2, 3)

    Args:
      data_format: A string, either `channels_last` (default) or
        `channels_first`, giving the dimension ordering of the inputs.
        `channels_last` means inputs shaped
        `(batch, height, width, channels)`; `channels_first` means inputs
        shaped `(batch, channels, height, width)`.
        When omitted, the `image_data_format` value from the Keras config
        file at `~/.keras/keras.json` is used; if that was never set, it is
        "channels_last".
      keepdims: A boolean controlling whether the spatial dimensions are
        kept. With `keepdims=False` (default) the spatial dimensions are
        removed, reducing the tensor's rank. With `keepdims=True` they are
        kept with length 1. This mirrors `tf.reduce_mean` / `np.mean`.

    Input shape:
      - With `data_format='channels_last'`:
        4D tensor of shape `(batch_size, rows, cols, channels)`.
      - With `data_format='channels_first'`:
        4D tensor of shape `(batch_size, channels, rows, cols)`.

    Output shape:
      - With `keepdims=False`:
        2D tensor of shape `(batch_size, channels)`.
      - With `keepdims=True`:
        - With `data_format='channels_last'`:
          4D tensor of shape `(batch_size, 1, 1, channels)`
        - With `data_format='channels_first'`:
          4D tensor of shape `(batch_size, channels, 1, 1)`
    """

    def call(self, inputs):
        # `data_format` and `keepdims` are presumably initialized by the
        # GlobalPooling2D base class (not visible here) — confirm there.
        # Any value other than "channels_last" is treated as channels_first,
        # i.e. the spatial axes are 2 and 3.
        channels_last = self.data_format == "channels_last"
        spatial_axes = [1, 2] if channels_last else [2, 3]
        return backend.mean(inputs, axis=spatial_axes, keepdims=self.keepdims)
# Alias: shorter public name for the same layer class (exported above as
# "keras.layers.GlobalAvgPool2D").
GlobalAvgPool2D = GlobalAveragePooling2D