Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/feature_column/base_feature_layer.py: 27%
82 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""This API defines FeatureColumn abstraction."""
17# This file was originally under tf/python/feature_column, and was moved to
18# Keras package in order to remove the reverse dependency from TF to Keras.
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
24import collections
25import re
27import tensorflow.compat.v2 as tf
29from keras.src.engine.base_layer import Layer
30from keras.src.saving import serialization_lib
class _BaseFeaturesLayer(Layer):
    """Base class for DenseFeatures and SequenceFeatures.

    Defines common methods and helpers.

    Args:
      feature_columns: An iterable containing the FeatureColumns to use as
        inputs to your model.
      expected_column_type: Expected class for provided feature columns.
      trainable: Boolean, whether the layer's variables will be updated via
        gradient descent during training.
      name: Name to give to the DenseFeatures.
      partitioner: Optional partitioner passed to the TF1 variable scope
        opened while creating per-column state in `build`.
      **kwargs: Keyword arguments to construct a layer.

    Raises:
      ValueError: if an item in `feature_columns` doesn't match
        `expected_column_type`.
    """

    def __init__(
        self,
        feature_columns,
        expected_column_type,
        trainable,
        name,
        partitioner=None,
        **kwargs
    ):
        super().__init__(name=name, trainable=trainable, **kwargs)
        self._feature_columns = _normalize_feature_columns(feature_columns)
        # The state manager owns the variables created on behalf of each
        # feature column (embeddings etc.) and ties them to this layer.
        self._state_manager = tf.__internal__.feature_column.StateManager(
            self, self.trainable
        )
        self._partitioner = partitioner
        for column in self._feature_columns:
            if not isinstance(column, expected_column_type):
                raise ValueError(
                    "Items of feature_columns must be a {}. "
                    "You can wrap a categorical column with an "
                    "embedding_column or indicator_column. Given: {}".format(
                        expected_column_type, column
                    )
                )

    def build(self, _):
        # Create each column's state under a per-column variable scope so
        # variable names stay unique and partitioning applies uniformly.
        for column in self._feature_columns:
            with tf.compat.v1.variable_scope(
                self.name, partitioner=self._partitioner
            ):
                with tf.compat.v1.variable_scope(
                    _sanitize_column_name_for_variable_scope(column.name)
                ):
                    column.create_state(self._state_manager)
        super().build(None)

    def _target_shape(self, input_shape, num_elements):
        """Computes expected output shape of the dense tensor of the layer.

        NOTE: fixed from `_output_shape` — `compute_output_shape` and
        `_process_dense_tensor` both call `self._target_shape`, so the
        abstract hook must carry that name for subclass overrides to work.

        Args:
          input_shape: Tensor or array with batch shape.
          num_elements: Size of the last dimension of the output.

        Returns:
          Tuple with output shape.
        """
        raise NotImplementedError("Calling an abstract method.")

    def compute_output_shape(self, input_shape):
        # Dense outputs of all columns are concatenated on the last axis,
        # so the total width is the sum of each column's element count.
        total_elements = 0
        for column in self._feature_columns:
            total_elements += column.variable_shape.num_elements()
        return self._target_shape(input_shape, total_elements)

    def _process_dense_tensor(self, column, tensor):
        """Reshapes the dense tensor output of a column based on expected
        shape.

        Args:
          column: A DenseColumn or SequenceDenseColumn object.
          tensor: A dense tensor obtained from the same column.

        Returns:
          Reshaped dense tensor.
        """
        num_elements = column.variable_shape.num_elements()
        target_shape = self._target_shape(tf.shape(tensor), num_elements)
        return tf.reshape(tensor, shape=target_shape)

    def _verify_and_concat_tensors(self, output_tensors):
        """Verifies and concatenates the dense output of several columns."""
        _verify_static_batch_size_equality(
            output_tensors, self._feature_columns
        )
        return tf.concat(output_tensors, -1)

    def get_config(self):
        """Returns a serializable config including all feature columns."""
        column_configs = [
            tf.__internal__.feature_column.serialize_feature_column(fc)
            for fc in self._feature_columns
        ]
        config = {"feature_columns": column_configs}
        config["partitioner"] = serialization_lib.serialize_keras_object(
            self._partitioner
        )

        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Reconstructs the layer from a `get_config` dictionary."""
        config_cp = config.copy()
        # columns_by_name lets deserialization share column instances that
        # refer to the same base feature.
        columns_by_name = {}
        config_cp["feature_columns"] = [
            tf.__internal__.feature_column.deserialize_feature_column(
                c, custom_objects, columns_by_name
            )
            for c in config["feature_columns"]
        ]
        config_cp["partitioner"] = serialization_lib.deserialize_keras_object(
            config["partitioner"], custom_objects
        )

        return cls(**config_cp)
157def _sanitize_column_name_for_variable_scope(name):
158 """Sanitizes user-provided feature names for use as variable scopes."""
159 invalid_char = re.compile("[^A-Za-z0-9_.\\-]")
160 return invalid_char.sub("_", name)
def _verify_static_batch_size_equality(tensors, columns):
    """Verify equality between static batch sizes.

    Args:
      tensors: iterable of input tensors.
      columns: Corresponding feature columns.

    Raises:
      ValueError: in case of mismatched batch sizes.
    """
    # First column index whose tensor had a known static batch size; used
    # for the error message when a later column disagrees.
    batch_size_column_index = None
    expected_batch_size = None
    for i, tensor in enumerate(tensors):
        # batch_size is a Dimension object; value is None when the static
        # batch size is unknown, in which case the tensor is skipped.
        batch_size = tf.compat.v1.Dimension(
            tf.compat.dimension_value(tensor.shape[0])
        )
        if batch_size.value is not None:
            if expected_batch_size is None:
                batch_size_column_index = i
                expected_batch_size = batch_size
            elif not expected_batch_size.is_compatible_with(batch_size):
                raise ValueError(
                    "Batch size (first dimension) of each feature must be "
                    "same. Batch size of columns ({}, {}): ({}, {})".format(
                        columns[batch_size_column_index].name,
                        columns[i].name,
                        expected_batch_size,
                        batch_size,
                    )
                )
def _normalize_feature_columns(feature_columns):
    """Normalizes the `feature_columns` input.

    Converts `feature_columns` to list type as best as it can, and verifies
    the type and other properties required by the downstream library.

    Args:
      feature_columns: The raw feature columns, usually passed by users.

    Returns:
      The normalized feature column list.

    Raises:
      ValueError: for any invalid inputs, such as empty, duplicated names,
        etc.
    """
    # A single column is wrapped into a one-element list.
    if isinstance(
        feature_columns, tf.__internal__.feature_column.FeatureColumn
    ):
        feature_columns = [feature_columns]

    # Materialize iterators so the columns can be traversed multiple times.
    if isinstance(feature_columns, collections.abc.Iterator):
        feature_columns = list(feature_columns)

    if isinstance(feature_columns, dict):
        raise ValueError("Expected feature_columns to be iterable, found dict.")

    # Type-check every item before any other validation.
    for column in feature_columns:
        if not isinstance(column, tf.__internal__.feature_column.FeatureColumn):
            raise ValueError(
                "Items of feature_columns must be a FeatureColumn. "
                "Given (type {}): {}.".format(type(column), column)
            )

    if not feature_columns:
        raise ValueError("feature_columns must not be empty.")

    # Reject duplicate names: two columns with the same name would refer to
    # the same base feature.
    seen_by_name = {}
    for column in feature_columns:
        if column.name in seen_by_name:
            raise ValueError(
                "Duplicate feature column name found for columns: {} "
                "and {}. This usually means that these columns refer to "
                "same base feature. Either one must be discarded or a "
                "duplicated but renamed item must be inserted in "
                "features dict.".format(column, seen_by_name[column.name])
            )
        seen_by_name[column.name] = column

    # Deterministic ordering for downstream consumers.
    return sorted(feature_columns, key=lambda col: col.name)