Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/pooling/base_pooling1d.py: 28%
39 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Private base class for pooling 1D layers."""
18import tensorflow.compat.v2 as tf
20from keras.src import backend
21from keras.src.engine.base_layer import Layer
22from keras.src.engine.input_spec import InputSpec
23from keras.src.utils import conv_utils
class Pooling1D(Layer):
    """Shared implementation of 1D pooling on top of a 2D pooling function.

    This class exists purely for code reuse between concrete 1D pooling
    layers; it is not part of the public API.

    A 3-D input gets a dummy spatial axis of size 1. The 2-D
    `pool_function` is then applied with a window of `(pool_size, 1)`, and
    the dummy axis is squeezed away again.

    Args:
      pool_function: Callable performing the 2-D pooling, e.g.
        `tf.nn.max_pool2d`. It is invoked as
        `pool_function(x, window, strides=..., padding=..., data_format=...)`
        where `data_format` is this layer's normalized `data_format` string.
      pool_size: An int, or a tuple/list holding one int, giving the length
        of the pooling window.
      strides: An int, or a tuple/list holding one int, giving the pooling
        stride. When `None`, `pool_size` is used instead.
      padding: Either 'valid' or 'same' (case-insensitive).
      data_format: Either `channels_last` (the default), meaning inputs of
        shape `(batch, steps, features)`, or `channels_first`, meaning inputs
        of shape `(batch, features, steps)`. When `None`, the value of
        `backend.image_data_format()` is used.
      name: Optional name for the layer.
    """

    def __init__(
        self,
        pool_function,
        pool_size,
        strides,
        padding="valid",
        data_format="channels_last",
        name=None,
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        # Resolve the "use the default" sentinels before normalizing.
        data_format = (
            backend.image_data_format() if data_format is None else data_format
        )
        strides = pool_size if strides is None else strides

        self.pool_function = pool_function
        self.pool_size = conv_utils.normalize_tuple(pool_size, 1, "pool_size")
        self.strides = conv_utils.normalize_tuple(
            strides, 1, "strides", allow_zero=True
        )
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = conv_utils.normalize_data_format(data_format)
        # Only rank-3 inputs are accepted.
        self.input_spec = InputSpec(ndim=3)

    def call(self, inputs):
        """Pool `inputs` along the steps axis.

        The extra size-1 axis is placed right after the steps axis: axis 2
        for `channels_last` and axis 3 otherwise. The 2-D window and strides
        are therefore `(n, 1)`.
        """
        dummy_axis = 3 if self.data_format != "channels_last" else 2
        expanded = tf.expand_dims(inputs, dummy_axis)
        window = self.pool_size + (1,)
        step = self.strides + (1,)
        pooled = self.pool_function(
            expanded,
            window,
            strides=step,
            padding=self.padding,
            data_format=self.data_format,
        )
        return tf.squeeze(pooled, dummy_axis)

    def compute_output_shape(self, input_shape):
        """Return the pooled shape as a `tf.TensorShape`.

        Only the steps dimension changes. Its new length comes from
        `conv_utils.conv_output_length` with this layer's window, padding
        and stride. The batch and feature dimensions are passed through.
        """
        dims = tf.TensorShape(input_shape).as_list()
        is_channels_first = self.data_format == "channels_first"
        if is_channels_first:
            features, steps = dims[1], dims[2]
        else:
            steps, features = dims[1], dims[2]
        new_steps = conv_utils.conv_output_length(
            steps, self.pool_size[0], self.padding, self.strides[0]
        )
        batch = dims[0]
        out_dims = (
            [batch, features, new_steps]
            if is_channels_first
            else [batch, new_steps, features]
        )
        return tf.TensorShape(out_dims)

    def get_config(self):
        """Return the base layer config extended with the pooling settings.

        `pool_function` is intentionally not included. Pooling entries take
        precedence over any identically named base entries.
        """
        own_config = {
            "strides": self.strides,
            "pool_size": self.pool_size,
            "padding": self.padding,
            "data_format": self.data_format,
        }
        base_config = super().get_config()
        return {**base_config, **own_config}