Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saving_utils.py: 19%
147 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utils related to keras model saving."""
17import collections
18import copy
19import os
21from tensorflow.python.eager import def_function
22from tensorflow.python.keras import backend as K
23from tensorflow.python.keras import losses
24from tensorflow.python.keras import optimizer_v1
25from tensorflow.python.keras import optimizers
26from tensorflow.python.keras.engine import base_layer_utils
27from tensorflow.python.keras.utils import generic_utils
28from tensorflow.python.keras.utils import version_utils
29from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.util import nest
def extract_model_metrics(model):
  """Collects the metrics passed to a Keras model's `compile` call.

  Used when converting Keras models to Estimators and SavedModels.

  Args:
    model: A `tf.keras.Model` object.

  Returns:
    A dict keyed by metric name whose values are the metric instances, or
    `None` when the model has no truthy `_compile_metrics` attribute.
  """
  if not getattr(model, '_compile_metrics', None):
    return None
  # TODO(psv/kathywu): use this implementation in model to estimator flow.
  # `model.metrics` is deliberately not used: it would also include metrics
  # that were registered through the `add_metric` API.
  by_name = {}
  for metric in model._compile_metric_functions:  # pylint: disable=protected-access
    by_name[metric.name] = metric
  return by_name
def model_input_signature(model, keep_original_batch_size=False):
  """Returns the input signature of `model` as a single-element list.

  Keras requires tensor inputs to be passed as the first call argument, so the
  signature is always a list wrapping one (possibly nested) structure. For
  example, a model called on {'feature1': <Tensor>, 'feature2': <Tensor>}
  yields [{'feature1': TensorSpec, 'feature2': TensorSpec}].

  Args:
    model: Keras Model object.
    keep_original_batch_size: If True, keep the batch size of the saved spec.
      If False (the default), the spec is requested with a dynamic batch
      dimension, i.e. the batch dim of the returned signature is `None`.

  Returns:
    A list holding either one TensorSpec or one nested structure of
    TensorSpecs (the `training` argument is not included), or `None` when the
    model reports no save spec.
  """
  specs = model._get_save_spec(  # pylint: disable=protected-access
      dynamic_batch=not keep_original_batch_size)
  if specs is None:
    return None
  specs = _enforce_names_consistency(specs)
  # A one-element sequence already has the expected shape. Dicts are not
  # `Sequence`s, so even single-entry dicts get wrapped below.
  already_wrapped = (
      isinstance(specs, collections.abc.Sequence) and len(specs) == 1)
  return specs if already_wrapped else [specs]
def raise_model_input_error(model):
  """Raises a `ValueError` explaining that `model` has no known input shapes.

  Args:
    model: The Keras model that could not be saved; it is interpolated into
      the error message via `str.format`.

  Raises:
    ValueError: Always.
  """
  message = ('Model {} cannot be saved because the input shapes have not been '
             'set. Usually, input shapes are automatically determined from '
             'calling `.fit()` or `.predict()`. To manually set the shapes, '
             'call `model.build(input_shape)`.')
  raise ValueError(message.format(model))
def trace_model_call(model, input_signature=None):
  """Trace the model call to create a tf.function for exporting a Keras model.

  The input signature is resolved in this order: the explicit
  `input_signature` argument; then `model.call.input_signature` if `model.call`
  is already a `def_function.Function`; then `model_input_signature(model)`.

  Args:
    model: A Keras model.
    input_signature: optional, a list of tf.TensorSpec objects specifying the
      inputs to the model.

  Returns:
    A tf.function wrapping the model's call function with input signatures set.
    Calling it returns a flat dict mapping output names to output tensors.

  Raises:
    ValueError: if input signature cannot be inferred from the model.
  """
  if input_signature is None:
    if isinstance(model.call, def_function.Function):
      input_signature = model.call.input_signature

  if input_signature is None:
    input_signature = model_input_signature(model)

  if input_signature is None:
    raise_model_input_error(model)

  @def_function.function(input_signature=input_signature)
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # When given a single input, Keras models will call the model on the tensor
    # rather than a list consisting of the single tensor.
    inputs = args[0] if len(input_signature) == 1 else list(args)

    # Call the model in inference mode (training=False) inside a call context
    # flagged with saving=True.
    with base_layer_utils.call_context().enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True):
      outputs = model(inputs, training=False)

    # Outputs always has to be a flat dict.
    output_names = model.output_names  # Functional Model.
    if output_names is None:  # Subclassed Model.
      from tensorflow.python.keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      output_names = compile_utils.create_pseudo_output_names(outputs)
    outputs = nest.flatten(outputs)
    # NOTE(review): `zip` silently drops trailing names or outputs if the two
    # sequences have different lengths.
    return {name: output for name, output in zip(output_names, outputs)}

  return _wrapped_model
def model_metadata(model, include_optimizer=True, require_config=True):
  """Returns a dictionary containing the model metadata.

  Args:
    model: The Keras model whose metadata is collected.
    include_optimizer: Whether to try to record the training config together
      with the optimizer config.
    require_config: Whether a `NotImplementedError` raised by
      `model.get_config()` should propagate. When False, the error is
      swallowed and 'model_config' records only the class name.

  Returns:
    A dict with 'keras_version', 'backend' and 'model_config' entries, plus a
    'training_config' entry (with a nested 'optimizer_config') when the model
    has an optimizer that is not a `TFOptimizer`, `compile` was called, and
    `include_optimizer` is truthy.

  Raises:
    NotImplementedError: If `model.get_config()` is not implemented and
      `require_config` is True, or if the optimizer is an
      `optimizer_v2.RestoredOptimizer` and would have to be serialized.
  """
  # Function-level imports kept from the original (presumably to avoid import
  # cycles at module load time).
  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.optimizer_v2 import optimizer_v2  # pylint: disable=g-import-not-at-top

  model_config = {'class_name': model.__class__.__name__}
  try:
    model_config['config'] = model.get_config()
  except NotImplementedError as e:
    if require_config:
      raise e

  metadata = {
      'keras_version': str(keras_version),
      'backend': K.backend(),
      'model_config': model_config,
  }

  # Nothing more to record unless there is an optimizer we were asked to save.
  if not (model.optimizer and include_optimizer):
    return metadata

  if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
    logging.warning(
        'TensorFlow optimizers do not '
        'make it possible to access '
        'optimizer attributes or optimizer state '
        'after instantiation. '
        'As a result, we cannot save the optimizer '
        'as part of the model save file. '
        'You will have to compile your model again after loading it. '
        'Prefer using a Keras optimizer instead '
        '(see keras.io/optimizers).')
    return metadata

  if not model._compile_was_called:  # pylint: disable=protected-access
    return metadata

  training_config = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
  training_config.pop('optimizer', None)  # Stored separately, see below.
  metadata['training_config'] = _serialize_nested_config(training_config)

  if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
    raise NotImplementedError(
        'As of now, Optimizers loaded from SavedModel cannot be saved. '
        'If you\'re calling `model.save` or `tf.keras.models.save_model`,'
        ' please set the `include_optimizer` option to `False`. For '
        '`tf.saved_model.save`, delete the optimizer from the model.')

  metadata['training_config']['optimizer_config'] = {
      'class_name':
          generic_utils.get_registered_name(model.optimizer.__class__),
      'config':
          model.optimizer.get_config(),
  }
  return metadata
def should_overwrite(filepath, overwrite):
  """Returns whether the filepath should be overwritten.

  Writing is allowed outright when `overwrite` is truthy or when no regular
  file exists at `filepath`. Otherwise the decision is delegated to
  `ask_to_proceed_with_overwrite(filepath)`.
  """
  if overwrite or not os.path.isfile(filepath):
    return True
  return ask_to_proceed_with_overwrite(filepath)
def compile_args_from_training_config(training_config, custom_objects=None):
  """Return model.compile arguments from training config.

  Args:
    training_config: Dict of saved compile arguments. It must contain
      'optimizer_config'; 'loss', 'metrics', 'weighted_metrics',
      'loss_weights' and 'sample_weight_mode' are optional and default to
      `None` when absent.
    custom_objects: Optional dict of custom classes/functions made available
      (via `generic_utils.CustomObjectScope`) while deserializing.

  Returns:
    A dict with the keys 'optimizer', 'loss', 'metrics', 'weighted_metrics',
    'loss_weights' and 'sample_weight_mode'.

  Raises:
    KeyError: If `training_config` has no 'optimizer_config' entry.
  """
  if custom_objects is None:
    custom_objects = {}

  # All deserialization happens inside the scope so that names listed in
  # `custom_objects` can be resolved.
  with generic_utils.CustomObjectScope(custom_objects):
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config)

    # `_deserialize_nested_config` returns `None` for a `None` config, so a
    # missing entry simply yields `None`.
    loss = _deserialize_nested_config(losses.deserialize,
                                      training_config.get('loss', None))
    metrics = _deserialize_nested_config(_deserialize_metric,
                                         training_config.get('metrics', None))
    weighted_metrics = _deserialize_nested_config(
        _deserialize_metric, training_config.get('weighted_metrics', None))

    # `training_config` is a dict, so the saved value must be looked up by key:
    # `hasattr(training_config, 'sample_weight_mode')` is always False for a
    # dict and would silently discard the saved mode.
    sample_weight_mode = training_config.get('sample_weight_mode', None)
    loss_weights = training_config.get('loss_weights', None)

  return dict(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      weighted_metrics=weighted_metrics,
      loss_weights=loss_weights,
      sample_weight_mode=sample_weight_mode)
241def _deserialize_nested_config(deserialize_fn, config):
242 """Deserializes arbitrary Keras `config` using `deserialize_fn`."""
244 def _is_single_object(obj):
245 if isinstance(obj, dict) and 'class_name' in obj:
246 return True # Serialized Keras object.
247 if isinstance(obj, str):
248 return True # Serialized function or string.
249 return False
251 if config is None:
252 return None
253 if _is_single_object(config):
254 return deserialize_fn(config)
255 elif isinstance(config, dict):
256 return {
257 k: _deserialize_nested_config(deserialize_fn, v)
258 for k, v in config.items()
259 }
260 elif isinstance(config, (tuple, list)):
261 return [_deserialize_nested_config(deserialize_fn, obj) for obj in config]
263 raise ValueError('Saved configuration not understood.')
def _serialize_nested_config(config):
  """Serializes a nested structure of Keras objects.

  Every callable leaf of `config` is replaced with the result of
  `generic_utils.serialize_keras_object`; non-callable leaves are returned
  unchanged. The structure is traversed with `nest.map_structure`.
  """

  def _maybe_serialize(leaf):
    if not callable(leaf):
      return leaf
    return generic_utils.serialize_keras_object(leaf)

  return nest.map_structure(_maybe_serialize, config)
def _deserialize_metric(metric_config):
  """Deserializes a metric config, passing special metric strings through.

  'accuracy', 'acc', 'crossentropy' and 'ce' are returned as-is because
  `compile` special-cases them based on the model's output shape. Anything
  else goes through `metrics.deserialize`.
  """
  from tensorflow.python.keras import metrics as metrics_module  # pylint:disable=g-import-not-at-top
  special_names = ('accuracy', 'acc', 'crossentropy', 'ce')
  if metric_config not in special_names:
    return metrics_module.deserialize(metric_config)
  return metric_config
def _enforce_names_consistency(specs):
  """Enforces that either all specs have names or none do.

  If some, but not all, flattened leaves of `specs` have a non-`None` `name`,
  a structure of deep copies is returned in which each copy that has a `name`
  attribute gets its `_name` set to `None`. Otherwise `specs` is returned
  unchanged.
  """

  def _is_named(spec):
    return getattr(spec, 'name', None) is not None

  def _strip_name(spec):
    spec_copy = copy.deepcopy(spec)
    if hasattr(spec_copy, 'name'):
      # Presumably `name` is a read-only view of `_name` on TensorSpec-like
      # objects; confirm if new spec types are introduced.
      spec_copy._name = None  # pylint:disable=protected-access
    return spec_copy

  named_flags = [_is_named(leaf) for leaf in nest.flatten(specs)]
  if any(named_flags) and not all(named_flags):
    return nest.map_structure(_strip_name, specs)
  return specs
def try_build_compiled_arguments(model):
  """Best-effort build of the model's compiled loss and metric containers.

  Does nothing for v1 layers/models or when `model.outputs` is `None`.
  Otherwise builds `model.compiled_loss` and `model.compiled_metrics` (if not
  already built) against `model.outputs`. A failure while building is logged
  as a warning instead of being raised.

  Args:
    model: A compiled Keras model.
  """
  if version_utils.is_v1_layer_or_model(model) or model.outputs is None:
    return
  try:
    if not model.compiled_loss.built:
      model.compiled_loss.build(model.outputs)
    if not model.compiled_metrics.built:
      model.compiled_metrics.build(model.outputs, model.outputs)
  except Exception:  # pylint: disable=broad-except
    # Deliberately best-effort. `Exception` rather than a bare `except:` lets
    # KeyboardInterrupt and SystemExit propagate instead of being reported
    # as a metrics-building problem.
    logging.warning(
        'Compiled the loaded model, but the compiled metrics have yet to '
        'be built. `model.compile_metrics` will be empty until you train '
        'or evaluate the model.')
def is_hdf5_filepath(filepath):
  """Returns whether `filepath` has an HDF5/Keras file extension.

  Args:
    filepath: A path string, or an `os.PathLike` object such as
      `pathlib.Path`.

  Returns:
    True if the path ends with '.h5', '.keras' or '.hdf5'. The comparison is
    case-sensitive.
  """
  # Convert path-like objects to their string form. Any other input is used
  # untouched, so existing string callers behave exactly as before.
  if isinstance(filepath, os.PathLike):
    filepath = os.fspath(filepath)
  return filepath.endswith(('.h5', '.keras', '.hdf5'))