Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/feature_column/base_feature_layer.py: 27%

82 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""This API defines FeatureColumn abstraction.""" 

16 

17# This file was originally under tf/python/feature_column, and was moved to 

18# Keras package in order to remove the reverse dependency from TF to Keras. 

19 

20from __future__ import absolute_import 

21from __future__ import division 

22from __future__ import print_function 

23 

24import collections 

25import re 

26 

27import tensorflow.compat.v2 as tf 

28 

29from keras.src.engine.base_layer import Layer 

30from keras.src.saving import serialization_lib 

31 

32 

class _BaseFeaturesLayer(Layer):
    """Base class for DenseFeatures and SequenceFeatures.

    Defines common methods and helpers.

    Args:
      feature_columns: An iterable containing the FeatureColumns to use as
        inputs to your model.
      expected_column_type: Expected class for provided feature columns.
      trainable: Boolean, whether the layer's variables will be updated via
        gradient descent during training.
      name: Name to give to the DenseFeatures.
      partitioner: Optional `tf.compat.v1` variable partitioner applied to the
        variable scope in which column state is created.
      **kwargs: Keyword arguments to construct a layer.

    Raises:
      ValueError: if an item in `feature_columns` doesn't match
        `expected_column_type`.
    """

    def __init__(
        self,
        feature_columns,
        expected_column_type,
        trainable,
        name,
        partitioner=None,
        **kwargs
    ):
        super().__init__(name=name, trainable=trainable, **kwargs)
        # Normalization also validates the input (non-empty, unique names)
        # and returns the columns sorted by name for deterministic ordering.
        self._feature_columns = _normalize_feature_columns(feature_columns)
        self._state_manager = tf.__internal__.feature_column.StateManager(
            self, self.trainable
        )
        self._partitioner = partitioner
        for column in self._feature_columns:
            if not isinstance(column, expected_column_type):
                raise ValueError(
                    "Items of feature_columns must be a {}. "
                    "You can wrap a categorical column with an "
                    "embedding_column or indicator_column. Given: {}".format(
                        expected_column_type, column
                    )
                )

    def build(self, _):
        # Each column creates its state (e.g. embedding variables) inside a
        # per-column variable scope nested under this layer's scope, so
        # variable names stay unique and partitioning applies uniformly.
        for column in self._feature_columns:
            with tf.compat.v1.variable_scope(
                self.name, partitioner=self._partitioner
            ):
                with tf.compat.v1.variable_scope(
                    _sanitize_column_name_for_variable_scope(column.name)
                ):
                    column.create_state(self._state_manager)
        super().build(None)

    def _target_shape(self, input_shape, num_elements):
        """Computes expected output shape of the dense tensor of the layer.

        NOTE: renamed from `_output_shape` so the abstract hook matches the
        `self._target_shape(...)` calls in `compute_output_shape` and
        `_process_dense_tensor`; the old name was never invoked. Subclasses
        must override this method.

        Args:
          input_shape: Tensor or array with batch shape.
          num_elements: Size of the last dimension of the output.

        Returns:
          Tuple with output shape.
        """
        raise NotImplementedError("Calling an abstract method.")

    def compute_output_shape(self, input_shape):
        # The layer's output width is the concatenation of every column's
        # flattened variable shape.
        total_elements = 0
        for column in self._feature_columns:
            total_elements += column.variable_shape.num_elements()
        return self._target_shape(input_shape, total_elements)

    def _process_dense_tensor(self, column, tensor):
        """Reshapes the dense tensor output of a column based on expected shape.

        Args:
          column: A DenseColumn or SequenceDenseColumn object.
          tensor: A dense tensor obtained from the same column.

        Returns:
          Reshaped dense tensor.
        """
        num_elements = column.variable_shape.num_elements()
        target_shape = self._target_shape(tf.shape(tensor), num_elements)
        return tf.reshape(tensor, shape=target_shape)

    def _verify_and_concat_tensors(self, output_tensors):
        """Verifies and concatenates the dense output of several columns."""
        _verify_static_batch_size_equality(
            output_tensors, self._feature_columns
        )
        return tf.concat(output_tensors, -1)

    def get_config(self):
        # Feature columns and the partitioner are serialized so the layer can
        # be reconstructed via `from_config`.
        column_configs = [
            tf.__internal__.feature_column.serialize_feature_column(fc)
            for fc in self._feature_columns
        ]
        config = {"feature_columns": column_configs}
        config["partitioner"] = serialization_lib.serialize_keras_object(
            self._partitioner
        )

        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config_cp = config.copy()
        # `columns_by_name` lets shared columns deserialize to one object.
        columns_by_name = {}
        config_cp["feature_columns"] = [
            tf.__internal__.feature_column.deserialize_feature_column(
                c, custom_objects, columns_by_name
            )
            for c in config["feature_columns"]
        ]
        config_cp["partitioner"] = serialization_lib.deserialize_keras_object(
            config["partitioner"], custom_objects
        )

        return cls(**config_cp)

155 

156 

157def _sanitize_column_name_for_variable_scope(name): 

158 """Sanitizes user-provided feature names for use as variable scopes.""" 

159 invalid_char = re.compile("[^A-Za-z0-9_.\\-]") 

160 return invalid_char.sub("_", name) 

161 

162 

def _verify_static_batch_size_equality(tensors, columns):
    """Verify equality between static batch sizes.

    Args:
      tensors: iterable of input tensors.
      columns: Corresponding feature columns.

    Raises:
      ValueError: in case of mismatched batch sizes.
    """
    expected_batch_size = None
    # NOTE: fixed the misspelled `bath_size*` locals and replaced the
    # `range(len(...))` loop with `enumerate`; behavior is unchanged.
    for i, tensor in enumerate(tensors):
        # batch_size is a Dimension object; its `.value` is None when the
        # static batch size is unknown, in which case the tensor is skipped.
        batch_size = tf.compat.v1.Dimension(
            tf.compat.dimension_value(tensor.shape[0])
        )
        if batch_size.value is not None:
            if expected_batch_size is None:
                # Remember which column fixed the expected batch size so the
                # error message can name both mismatching columns.
                batch_size_column_index = i
                expected_batch_size = batch_size
            elif not expected_batch_size.is_compatible_with(batch_size):
                raise ValueError(
                    "Batch size (first dimension) of each feature must be "
                    "same. Batch size of columns ({}, {}): ({}, {})".format(
                        columns[batch_size_column_index].name,
                        columns[i].name,
                        expected_batch_size,
                        batch_size,
                    )
                )

193 

194 

def _normalize_feature_columns(feature_columns):
    """Normalizes the `feature_columns` input.

    This method converts the `feature_columns` to list type as best as it
    can. In addition, verifies the type and other parts of feature_columns,
    required by downstream library.

    Args:
      feature_columns: The raw feature columns, usually passed by users.

    Returns:
      The normalized feature column list.

    Raises:
      ValueError: for any invalid inputs, such as empty, duplicated names,
        etc.
    """
    fc_type = tf.__internal__.feature_column.FeatureColumn

    # Accept a single column by wrapping it into a one-element list.
    if isinstance(feature_columns, fc_type):
        feature_columns = [feature_columns]

    # Materialize one-shot iterators so the columns can be traversed twice.
    if isinstance(feature_columns, collections.abc.Iterator):
        feature_columns = list(feature_columns)

    # A dict would silently iterate over its keys; reject it explicitly.
    if isinstance(feature_columns, dict):
        raise ValueError("Expected feature_columns to be iterable, found dict.")

    for column in feature_columns:
        if not isinstance(column, fc_type):
            raise ValueError(
                "Items of feature_columns must be a FeatureColumn. "
                "Given (type {}): {}.".format(type(column), column)
            )

    if not feature_columns:
        raise ValueError("feature_columns must not be empty.")

    # Column names must be unique: duplicates usually mean two columns were
    # derived from the same base feature.
    seen = {}
    for column in feature_columns:
        if column.name in seen:
            raise ValueError(
                "Duplicate feature column name found for columns: {} "
                "and {}. This usually means that these columns refer to "
                "same base feature. Either one must be discarded or a "
                "duplicated but renamed item must be inserted in "
                "features dict.".format(column, seen[column.name])
            )
        seen[column.name] = column

    return sorted(feature_columns, key=lambda col: col.name)

243