Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/feature_column/serialization.py: 23%

115 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""FeatureColumn serialization, deserialization logic.""" 

16 

17import six 

18 

19from tensorflow.python.feature_column import feature_column_v2 as fc_lib 

20from tensorflow.python.feature_column import sequence_feature_column as sfc_lib 

21from tensorflow.python.ops import init_ops 

22from tensorflow.python.util import deprecation 

23from tensorflow.python.util import tf_decorator 

24from tensorflow.python.util import tf_inspect 

25from tensorflow.python.util.tf_export import tf_export 

26from tensorflow.tools.docs import doc_controls 

27 

# Markdown banner passed to `doc_controls.header` on the public
# (de)serialization entry points of this module.
# NOTE(review): the indentation inside this literal was collapsed in the
# listing it was recovered from; four leading spaces per line are assumed --
# confirm against the upstream module before relying on exact whitespace.
_FEATURE_COLUMN_DEPRECATION_WARNING = """\
    Warning: tf.feature_column is not recommended for new code. Instead,
    feature preprocessing can be done directly using either [Keras preprocessing
    layers](https://www.tensorflow.org/guide/migrate/migrating_feature_columns)
    or through the one-stop utility [`tf.keras.utils.FeatureSpace`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/FeatureSpace)
    built on top of them. See the [migration guide](https://tensorflow.org/guide/migrate)
    for details.
    """

36 

# Instruction text passed to `deprecation.deprecated` on
# `serialize_feature_column`.
_FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING = (
    'Use Keras preprocessing layers instead, either directly or via the '
    '`tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has '
    'a functional equivalent in `tf.keras.layers` for feature preprocessing '
    'when training a Keras model.')

42 

# Classes that `deserialize_feature_column` can resolve from a serialized
# `class_name` without the caller supplying `custom_objects`.
# NOTE(review): `init_ops.TruncatedNormal` is not a FeatureColumn; presumably
# it is listed so that an initializer nested inside a column config can be
# resolved by name -- confirm against the column classes' `from_config`.
_FEATURE_COLUMNS = [
    fc_lib.BucketizedColumn, fc_lib.CrossedColumn, fc_lib.EmbeddingColumn,
    fc_lib.HashedCategoricalColumn, fc_lib.IdentityCategoricalColumn,
    fc_lib.IndicatorColumn, fc_lib.NumericColumn,
    fc_lib.SequenceCategoricalColumn, fc_lib.SequenceDenseColumn,
    fc_lib.SharedEmbeddingColumn, fc_lib.VocabularyFileCategoricalColumn,
    fc_lib.VocabularyListCategoricalColumn, fc_lib.WeightedCategoricalColumn,
    init_ops.TruncatedNormal, sfc_lib.SequenceNumericColumn
]

52 

53 

@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export(
    '__internal__.feature_column.serialize_feature_column',
    v1=[])
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def serialize_feature_column(fc):
  """Serializes a FeatureColumn or a raw string key.

  This method should only be used to serialize parent FeatureColumns when
  implementing FeatureColumn.get_config(), else serialize_feature_columns()
  is preferable.

  The result also records the FeatureColumn class name, so deserialization is
  possible without knowing the class type in advance. For example:

    a = numeric_column('price')
    a.get_config() gives:
    {
        'key': 'price',
        'shape': (1,),
        'default_value': None,
        'dtype': 'float32',
        'normalizer_fn': None
    }
    While serialize_feature_column(a) gives:
    {
        'class_name': 'NumericColumn',
        'config': {
            'key': 'price',
            'shape': (1,),
            'default_value': None,
            'dtype': 'float32',
            'normalizer_fn': None
        }
    }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    A `{'class_name': ..., 'config': ...}` dict for a FeatureColumn; a string
    key is returned unchanged.

  Raises:
    ValueError: if `fc` is neither a string nor a FeatureColumn.
  """
  # Raw feature keys are already serializable; pass them through untouched.
  if isinstance(fc, str):
    return fc
  if isinstance(fc, fc_lib.FeatureColumn):
    return {'class_name': fc.__class__.__name__, 'config': fc.get_config()}
  raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))

105 

106 

107 

@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export('__internal__.feature_column.deserialize_feature_column', v1=[])
def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Deserializes a `config` generated with `serialize_feature_column`.

  This method should only be used to deserialize parent FeatureColumns when
  implementing FeatureColumn.from_config(), else deserialize_feature_columns()
  is preferable.

  Args:
    config: A Dict with the serialization of feature columns acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict[String, FeatureColumn] of existing columns in order
      to avoid duplication. It is updated in place with the returned column
      when no entry for that column exists yet.

  Returns:
    A FeatureColumn corresponding to the input `config`. A string `config` is
    returned unchanged.

  Raises:
    ValueError: if `config` has invalid format (e.g: expected keys missing,
      or refers to unknown classes), or if the resolved object is not a
      FeatureColumn class.
  """
  # TODO(b/118939620): Simplify code if Keras utils support object deduping.
  if isinstance(config, str):
    return config
  # A dict from class_name to class for all FeatureColumns in this module.
  # FeatureColumns not part of the module can be passed as custom_objects.
  module_feature_column_classes = {
      cls.__name__: cls for cls in _FEATURE_COLUMNS
  }
  if columns_by_name is None:
    columns_by_name = {}

  (cls, cls_config) = _class_and_config_for_serialized_keras_object(
      config,
      module_objects=module_feature_column_classes,
      custom_objects=custom_objects,
      printable_module_name='feature_column_v2')

  # `custom_objects` may map the name to a function or other non-class object;
  # check for a class first so that case raises the documented ValueError
  # rather than a TypeError from issubclass().
  if not (isinstance(cls, type) and issubclass(cls, fc_lib.FeatureColumn)):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(cls))

  # Always deserialize the FeatureColumn, in order to get the name.
  new_instance = cls.from_config(
      cls_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)

  # If the name already exists, re-use the column from columns_by_name,
  # (new_instance remains unused).
  return columns_by_name.setdefault(
      _column_name_with_class_name(new_instance), new_instance)

165 

166 

167 

def serialize_feature_columns(feature_columns):
  """Serializes a list of FeatureColumns.

  Returns a list of Keras-style config dicts that represent the input
  FeatureColumns and can be used with `deserialize_feature_columns` for
  reconstructing the original columns.

  Args:
    feature_columns: A list of FeatureColumns.

  Returns:
    Keras serialization for the list of FeatureColumns, one entry per input
    column and in the same order.

  Raises:
    ValueError: if called with input that is not a list of FeatureColumns.
  """
  return [serialize_feature_column(fc) for fc in feature_columns]

185 

186 

def deserialize_feature_columns(configs, custom_objects=None):
  """Deserializes a list of FeatureColumns configs.

  Returns a list of FeatureColumns given a list of config dicts acquired by
  `serialize_feature_columns`.

  Args:
    configs: A list of Dicts with the serialization of feature columns acquired
      by `serialize_feature_columns`.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    FeatureColumn objects corresponding to the input configs, in the same
    order.

  Raises:
    ValueError: if any entry of `configs` is not a valid serialized
      FeatureColumn.
  """
  # One table shared by every config in this call, so a column that appears
  # more than once is returned as the same object each time.
  columns_by_name = {}
  return [
      deserialize_feature_column(c, custom_objects, columns_by_name)
      for c in configs
  ]

210 

211 

def _column_name_with_class_name(fc):
  """Returns a unique name for the feature column used during deduping.

  Without the class-name prefix, two FeatureColumns that share a name because
  one wraps the other (such as an IndicatorColumn wrapping a
  SequenceCategoricalColumn) would collide in `columns_by_name`, causing the
  wrong column to be returned during deserialization.

  Args:
    fc: A FeatureColumn.

  Returns:
    A string of the form '<ClassName>:<fc.name>'.
  """
  return fc.__class__.__name__ + ':' + fc.name

227 

228 

def _serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation.

  Args:
    instance: The object to serialize. It is first passed through
      `tf_decorator.unwrap`, and the unwrapped target is what gets serialized.

  Returns:
    `None` if the unwrapped instance is `None`; a
    `{'class_name': ..., 'config': ...}` dict if the instance has a
    `get_config` method; otherwise the instance's `__name__`.

  Raises:
    ValueError: if the instance has neither `get_config` nor `__name__`.
  """
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None

  if hasattr(instance, 'get_config'):
    name = instance.__class__.__name__
    config = instance.get_config()
    serialization_config = {}
    for key, item in config.items():
      if isinstance(item, str):
        serialization_config[key] = item
        continue

      # Any object of a different type needs to be converted to string or dict
      # for serialization (e.g. custom functions, custom classes). Values that
      # cannot be serialized this way are stored unchanged.
      try:
        serialized_item = _serialize_keras_object(item)
      except ValueError:
        serialization_config[key] = item
        continue

      # Mark dicts produced here (as opposed to values that were already plain
      # dicts) so deserialization knows to rebuild the object.
      if isinstance(serialized_item, dict) and not isinstance(item, dict):
        serialized_item['__passive_serialization__'] = True
      serialization_config[key] = serialized_item

    return {'class_name': name, 'config': serialization_config}
  if hasattr(instance, '__name__'):
    return instance.__name__
  raise ValueError('Cannot serialize', instance)

258 

259 

def _deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  Args:
    identifier: `None`, a `{'class_name': ..., 'config': ...}` dict, the
      registered name of an object, or an already-deserialized function.
    module_objects: Optional dict from name to built-in object.
    custom_objects: Optional dict from name to user-supplied object. For a
      string identifier it is consulted before `module_objects`.
    printable_module_name: Label used in error messages.

  Returns:
    `None` for a `None` identifier. For a dict, the result of the resolved
    class's `from_config` (or of calling the resolved object with the config
    as keyword arguments when it has no `from_config`). For a string, the
    registered object, instantiated with no arguments if it is a class. A
    function identifier is returned as is.

  Raises:
    ValueError: if a string identifier is not registered, or the identifier
      has an unsupported type.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    (cls, cls_config) = _class_and_config_for_serialized_keras_object(
        identifier, module_objects, custom_objects, printable_module_name)

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      if 'custom_objects' in arg_spec.args:
        # Hand over a copy so `from_config` cannot alter the caller's dict.
        return cls.from_config(
            cls_config, custom_objects=dict(custom_objects or {}))
      return cls.from_config(cls_config)
    # Then `cls` may be a function returning a class.
    # in this case by convention `config` holds
    # the kwargs of the function.
    return cls(**cls_config)

  if isinstance(identifier, str):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    else:
      # `module_objects` defaults to None; treat that as an empty registry so
      # an unknown name reports the ValueError below, not an AttributeError.
      obj = (module_objects or {}).get(object_name)
    if obj is None:
      raise ValueError('Unknown ' + printable_module_name + ': ' +
                       object_name)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj

  if tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier

  raise ValueError('Could not interpret serialized %s: %s' %
                   (printable_module_name, identifier))

308 

309 

def _class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class and config for a serialized keras object.

  Args:
    config: A dict with 'class_name' and 'config' entries.
    module_objects: Optional dict from name to built-in object.
    custom_objects: Optional dict from name to user-supplied object; takes
      precedence over `module_objects`.
    printable_module_name: Label used in the "Unknown ..." error message.

  Returns:
    A `(cls, cls_config)` tuple. `cls` is the object registered under
    `config['class_name']`. `cls_config` is a shallow copy of
    `config['config']` in which passively serialized entries and names of
    registered custom functions are replaced by the deserialized objects.

  Raises:
    ValueError: if `config` is not a dict holding both 'class_name' and
      'config', or if the class name is not registered.
  """
  if (not isinstance(config, dict) or 'class_name' not in config or
      'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = _get_registered_object(
      class_name, custom_objects=custom_objects, module_objects=module_objects)
  if cls is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

  # Work on a shallow copy so the caller's serialized config is not rewritten
  # with live objects and stays reusable.
  cls_config = dict(config['config'])

  deserialized_objects = {}
  for key, item in cls_config.items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      deserialized_objects[key] = _deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    elif isinstance(item, str):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the custom
      # objects during deserialization, we convert the string back to the
      # original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      registered = _get_registered_object(item, custom_objects)
      if tf_inspect.isfunction(registered):
        deserialized_objects[key] = registered
  cls_config.update(deserialized_objects)

  return (cls, cls_config)

351 

352 

def _get_registered_object(name, custom_objects=None, module_objects=None):
  """Looks up `name` among the registered objects.

  Args:
    name: The key to look up.
    custom_objects: Optional dict of user-supplied objects; searched first.
    module_objects: Optional dict of built-in objects; searched second.

  Returns:
    The object stored under `name` in the first registry that contains it, or
    `None` if neither registry does.
  """
  for registry in (custom_objects, module_objects):
    if registry and name in registry:
      return registry[name]
  return None