Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/feature_column/serialization.py: 23%
115 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""FeatureColumn serialization, deserialization logic."""
17import six
19from tensorflow.python.feature_column import feature_column_v2 as fc_lib
20from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
21from tensorflow.python.ops import init_ops
22from tensorflow.python.util import deprecation
23from tensorflow.python.util import tf_decorator
24from tensorflow.python.util import tf_inspect
25from tensorflow.python.util.tf_export import tf_export
26from tensorflow.tools.docs import doc_controls
# Markdown header attached (via `doc_controls.header`) to the public
# serialization entry points below, steering readers away from
# tf.feature_column.
_FEATURE_COLUMN_DEPRECATION_WARNING = """\
    Warning: tf.feature_column is not recommended for new code. Instead,
    feature preprocessing can be done directly using either [Keras preprocessing
    layers](https://www.tensorflow.org/guide/migrate/migrating_feature_columns)
    or through the one-stop utility [`tf.keras.utils.FeatureSpace`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/FeatureSpace)
    built on top of them. See the [migration guide](https://tensorflow.org/guide/migrate)
    for details.
    """

# Instructions passed to `deprecation.deprecated` for the decorated entry
# points (presumably surfaced when they are called; see the deprecation util).
_FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING = (
    'Use Keras preprocessing layers instead, either directly or via the '
    '`tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has '
    'a functional equivalent in `tf.keras.layers` for feature preprocessing '
    'when training a Keras model.')

# Classes that `deserialize_feature_column` can resolve by their `__name__`
# without the caller supplying them via `custom_objects`. The table is also
# forwarded when rebuilding nested, passively-serialized config items, which
# is why a non-FeatureColumn class (`init_ops.TruncatedNormal`) is listed; it
# cannot be a top-level column because of the FeatureColumn subclass check.
_FEATURE_COLUMNS = [
    fc_lib.BucketizedColumn, fc_lib.CrossedColumn, fc_lib.EmbeddingColumn,
    fc_lib.HashedCategoricalColumn, fc_lib.IdentityCategoricalColumn,
    fc_lib.IndicatorColumn, fc_lib.NumericColumn,
    fc_lib.SequenceCategoricalColumn, fc_lib.SequenceDenseColumn,
    fc_lib.SharedEmbeddingColumn, fc_lib.VocabularyFileCategoricalColumn,
    fc_lib.VocabularyListCategoricalColumn, fc_lib.WeightedCategoricalColumn,
    init_ops.TruncatedNormal, sfc_lib.SequenceNumericColumn
]
@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export(
    '__internal__.feature_column.serialize_feature_column',
    v1=[])
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def serialize_feature_column(fc):
  """Serializes a FeatureColumn or a raw string key.

  This method should only be used to serialize parent FeatureColumns when
  implementing FeatureColumn.get_config(), else serialize_feature_columns()
  is preferable.

  This serialization also keeps information of the FeatureColumn class, so
  deserialization is possible without knowing the class type. For example:

  a = numeric_column('price')
  a.get_config() gives:
  {
    'key': 'price',
    'shape': (1,),
    'default_value': None,
    'dtype': 'float32',
    'normalizer_fn': None
  }
  While serialize_feature_column(a) gives:
  {
    'class_name': 'NumericColumn',
    'config': {
      'key': 'price',
      'shape': (1,),
      'default_value': None,
      'dtype': 'float32',
      'normalizer_fn': None
    }
  }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    Keras serialization for FeatureColumns, leaves string keys unaffected.

  Raises:
    ValueError: If called with input that is neither a string nor a
      FeatureColumn.
  """
  if isinstance(fc, six.string_types):
    # Raw feature keys are already serializable; pass them through.
    return fc
  elif isinstance(fc, fc_lib.FeatureColumn):
    # Record the concrete class name so deserialization can pick the class.
    return {'class_name': fc.__class__.__name__, 'config': fc.get_config()}
  else:
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))
@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export('__internal__.feature_column.deserialize_feature_column', v1=[])
def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Rebuilds a FeatureColumn from the output of `serialize_feature_column`.

  Intended for parent FeatureColumns implementing `from_config()`; for whole
  lists prefer `deserialize_feature_columns()`. Columns are deduplicated
  through `columns_by_name`, keyed by class name plus column name.

  Args:
    config: A dict produced by `serialize_feature_column`, or a string naming
      a raw feature column.
    custom_objects: A dict from custom object name to the associated keras
      serializable object (FeatureColumns, classes or functions).
    columns_by_name: A dict of already-built columns used (and updated) to
      avoid creating duplicates.

  Raises:
    ValueError: If `config` is malformed (e.g. missing expected keys), refers
      to an unknown class, or names a class that is not a FeatureColumn.

  Returns:
    `config` itself when it is a string; otherwise the FeatureColumn for
    `config`, possibly an existing entry of `columns_by_name`.
  """
  # TODO(b/118939620): Simplify code if Keras utils support object deduping.
  if isinstance(config, six.string_types):
    return config

  # Lookup table of this module's known classes; other FeatureColumns must be
  # supplied through custom_objects.
  known_classes = dict(
      (column_cls.__name__, column_cls) for column_cls in _FEATURE_COLUMNS)
  registry = {} if columns_by_name is None else columns_by_name

  column_cls, column_config = _class_and_config_for_serialized_keras_object(
      config,
      module_objects=known_classes,
      custom_objects=custom_objects,
      printable_module_name='feature_column_v2')

  if not issubclass(column_cls, fc_lib.FeatureColumn):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(column_cls))

  # The column must be built even if it turns out to be a duplicate: its name
  # (the dedupe key) is only available on the constructed object.
  candidate = column_cls.from_config(
      column_config,
      custom_objects=custom_objects,
      columns_by_name=registry)

  # Hand back the previously registered column when the key already exists;
  # otherwise register and return the freshly built one.
  dedupe_key = _column_name_with_class_name(candidate)
  return registry.setdefault(dedupe_key, candidate)
def serialize_feature_columns(feature_columns):
  """Serializes every FeatureColumn in a list.

  Each element is passed through `serialize_feature_column`; the resulting
  Keras-style config dicts can be fed to `deserialize_feature_columns` to
  rebuild the original columns.

  Args:
    feature_columns: A list of FeatureColumns.

  Returns:
    A list with the Keras serialization of each input column, in order.

  Raises:
    ValueError: If any element is not a FeatureColumn (or raw string key).
  """
  return list(map(serialize_feature_column, feature_columns))
def deserialize_feature_columns(configs, custom_objects=None):
  """Rebuilds a list of FeatureColumns from `serialize_feature_columns` output.

  All configs share one name registry, so a column referenced by several
  configs is constructed once and reused.

  Args:
    configs: A list of dicts produced by `serialize_feature_columns`.
    custom_objects: A dict from custom object name to the associated keras
      serializable object (FeatureColumns, classes or functions).

  Returns:
    A list of FeatureColumn objects, one per input config, in order.

  Raises:
    ValueError: If any config cannot be deserialized into a FeatureColumn.
  """
  shared_registry = {}
  restored = []
  for serialized in configs:
    restored.append(
        deserialize_feature_column(serialized, custom_objects, shared_registry))
  return restored
212def _column_name_with_class_name(fc):
213 """Returns a unique name for the feature column used during deduping.
215 Without this two FeatureColumns that have the same name and where
216 one wraps the other, such as an IndicatorColumn wrapping a
217 SequenceCategoricalColumn, will fail to deserialize because they will have the
218 same name in columns_by_name, causing the wrong column to be returned.
220 Args:
221 fc: A FeatureColumn.
223 Returns:
224 A unique name as a string.
225 """
226 return fc.__class__.__name__ + ':' + fc.name
def _serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation.

  Decorator wrappers are stripped first. An object with `get_config()` becomes
  a `{'class_name': ..., 'config': ...}` dict whose non-string config values
  are serialized recursively (values that cannot be serialized are kept as
  they are). An object without `get_config()` but with a `__name__` (such as
  a plain function) becomes that name.

  Args:
    instance: The object to serialize.

  Returns:
    `None` if the unwrapped object is `None`, otherwise a config dict or a
    name string.

  Raises:
    ValueError: If the object has neither `get_config` nor `__name__`.
  """
  _, target = tf_decorator.unwrap(instance)
  if target is None:
    return None

  if not hasattr(target, 'get_config'):
    if hasattr(target, '__name__'):
      return target.__name__
    raise ValueError('Cannot serialize', target)

  class_name = target.__class__.__name__
  serialized_config = {}
  for field, value in target.get_config().items():
    if not isinstance(value, six.string_types):
      # Non-string values (e.g. custom functions or classes) are converted to
      # a name or a nested config when possible.
      try:
        converted = _serialize_keras_object(value)
      except ValueError:
        converted = value
      else:
        if isinstance(converted, dict) and not isinstance(value, dict):
          # Flag nested objects so the deserializer knows to rebuild them.
          converted['__passive_serialization__'] = True
      value = converted
    serialized_config[field] = value
  return {'class_name': class_name, 'config': serialized_config}
def _deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  Args:
    identifier: `None`; a config dict with `'class_name'` and `'config'`
      keys; a string naming an entry of `custom_objects` or `module_objects`;
      or an already-deserialized function.
    module_objects: Optional dict from name to class/function, consulted when
      a name is not found in `custom_objects`.
    custom_objects: Optional dict from name to class/function; takes
      precedence over `module_objects`.
    printable_module_name: Label used in error messages.

  Returns:
    `None` for a `None` identifier; otherwise the reconstructed object, a
    fresh no-arg instance of a class named by string, a function named by
    string, or the function passed in.

  Raises:
    ValueError: If `identifier` cannot be interpreted or names an unknown
      object.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # A Keras config dictionary.
    (cls, cls_config) = _class_and_config_for_serialized_keras_object(
        identifier, module_objects, custom_objects, printable_module_name)

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      # Forward custom objects whenever from_config accepts them, whether as a
      # regular or a keyword-only parameter.
      accepted = list(arg_spec.args) + list(arg_spec.kwonlyargs or [])
      if 'custom_objects' in accepted:
        # Hand over a copy so from_config cannot mutate the caller's dict.
        return cls.from_config(
            cls_config, custom_objects=dict(custom_objects or {}))
      return cls.from_config(cls_config)
    # Otherwise `cls` may be a function returning an object; by convention
    # its config holds the function's keyword arguments.
    return cls(**cls_config)

  if isinstance(identifier, six.string_types):
    if custom_objects and identifier in custom_objects:
      obj = custom_objects.get(identifier)
    else:
      # module_objects defaults to None; treat that as an empty table so an
      # unknown name raises the documented ValueError, not AttributeError.
      obj = (module_objects or {}).get(identifier)
      if obj is None:
        raise ValueError('Unknown ' + printable_module_name + ': ' +
                         identifier)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj

  if tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier

  raise ValueError('Could not interpret serialized %s: %s' %
                   (printable_module_name, identifier))
def _class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class and config for a serialized keras object.

  Config entries that were passively serialized (dicts carrying a
  `'__passive_serialization__'` marker) are deserialized, and string entries
  naming a function in `custom_objects` are replaced by that function. The
  caller's `config` is left untouched; the resolved values are written into a
  shallow copy of `config['config']`.

  Args:
    config: A dict with `'class_name'` and `'config'` keys.
    module_objects: Optional dict from name to class/function, consulted when
      a name is not found in `custom_objects`.
    custom_objects: Optional dict from name to class/function.
    printable_module_name: Label used in error messages.

  Returns:
    A `(cls, cls_config)` tuple.

  Raises:
    ValueError: If `config` is malformed or its class name is unknown.
  """
  if (not isinstance(config, dict) or 'class_name' not in config or
      'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = _get_registered_object(
      class_name, custom_objects=custom_objects, module_objects=module_objects)
  if cls is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

  # Work on a copy: writing live objects back into the caller's dict would
  # leave it non-serializable and make a later deserialization of the same
  # config silently reuse these objects (ignoring new custom_objects).
  cls_config = dict(config['config'])

  deserialized_objects = {}
  for key, item in cls_config.items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      deserialized_objects[key] = _deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    elif isinstance(item, six.string_types):
      # Handle custom functions here. When saving functions, only the
      # function's name is stored as a string; if a matching function is in
      # the custom objects, convert the string back to that function.
      # A string field could collide with a custom function name, but this
      # should be rare. Collisions with custom objects cannot happen, since
      # the config of an object is always a dict.
      registered = _get_registered_object(item, custom_objects)
      if tf_inspect.isfunction(registered):
        deserialized_objects[key] = registered
  cls_config.update(deserialized_objects)

  return (cls, cls_config)
353def _get_registered_object(name, custom_objects=None, module_objects=None):
354 if custom_objects and name in custom_objects:
355 return custom_objects[name]
356 elif module_objects and name in module_objects:
357 return module_objects[name]
358 return None