Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saving_utils.py: 18%
148 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utils related to keras model saving."""
17import copy
18import os
20import tensorflow.compat.v2 as tf
22import keras.src as keras
23from keras.src import backend
24from keras.src import losses
25from keras.src import optimizers
26from keras.src.engine import base_layer_utils
27from keras.src.optimizers import optimizer_v1
28from keras.src.saving.legacy import serialization
29from keras.src.utils import version_utils
30from keras.src.utils.io_utils import ask_to_proceed_with_overwrite
32# isort: off
33from tensorflow.python.platform import tf_logging as logging
def extract_model_metrics(model):
    """Collect the metrics given to `compile` into a name -> metric dict.

    Used when converting Keras models to Estimators and SavedModels.

    Args:
      model: A `tf.keras.Model` object.

    Returns:
      A dict mapping each entry of `model._compile_metric_functions` by its
      `name` to the metric object itself, or `None` when the model has no
      (or an empty/falsy) `_compile_metrics` attribute.
    """
    # TODO(psv/kathywu): use this implementation in model to estimator flow.
    # `model.metrics` is deliberately not used: it would also include metrics
    # registered through the `add_metric` API, which must be excluded here.
    if not getattr(model, "_compile_metrics", None):
        return None
    metric_map = {}
    for metric in model._compile_metric_functions:
        metric_map[metric.name] = metric
    return metric_map
def model_call_inputs(model, keep_original_batch_size=False):
    """Inspect model to get its input signature.

    The model's input signature is a list with a single (possibly-nested)
    object, because Keras requires tensor inputs to be passed as the first
    argument. For example, a model with input
    {'feature1': <Tensor>, 'feature2': <Tensor>} has the input signature
    [{'feature1': TensorSpec, 'feature2': TensorSpec}].

    Args:
      model: Keras Model object.
      keep_original_batch_size: If `True`, keep the original batch size in
        the returned specs. If `False` (default), the batch dimension is
        requested as dynamic, i.e. `None`.

    Returns:
      The `(args, kwargs)` TensorSpecs of the model call function inputs as
      given by `model.save_spec` (with names made consistent), or
      `(None, None)` when `model.save_spec` returns `None`.
      `kwargs` does not contain the `training` argument.
    """
    dynamic_batch = not keep_original_batch_size
    specs = model.save_spec(dynamic_batch=dynamic_batch)
    if specs is not None:
        # Either every spec keeps its name or none does.
        return _enforce_names_consistency(specs)
    return None, None
def raise_model_input_error(model):
    """Raise a descriptive error for a model whose call cannot be traced.

    A `Sequential` model gets a message about the missing input shape. Any
    other model is assumed to be subclassed and gets a message that also
    covers an undefined forward pass.

    Args:
      model: The Keras model that could not be saved.

    Raises:
      ValueError: always.
    """
    if isinstance(model, keras.models.Sequential):
        raise ValueError(
            f"Model {model} cannot be saved because the input shape is not "
            "available. Please specify an input shape either by calling "
            "`build(input_shape)` directly, or by calling the model on actual "
            "data using `Model()`, `Model.fit()`, or `Model.predict()`."
        )

    # If the model is not a `Sequential`, it is intended to be a subclassed
    # model.
    # NOTE: adjacent string literals are concatenated verbatim, so each
    # sentence boundary needs an explicit trailing space.
    raise ValueError(
        f"Model {model} cannot be saved either because the input shape is not "
        "available or because the forward pass of the model is not defined. "
        "To define a forward pass, please override `Model.call()`. To specify "
        "an input shape, either call `build(input_shape)` directly, or call "
        "the model on actual data using `Model()`, `Model.fit()`, or "
        "`Model.predict()`. If you have a custom training step, please make "
        "sure to invoke the forward pass in train step through "
        "`Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`."
    )
def trace_model_call(model, input_signature=None):
    """Trace the model call to create a tf.function for exporting a Keras model.

    Args:
      model: A Keras model.
      input_signature: optional, a list of tf.TensorSpec objects specifying the
        inputs to the model. When omitted, it is taken from `model.call` if
        that is already a `tf.function`, otherwise inferred from the model.

    Returns:
      The concrete function obtained from a tf.function wrapping the model's
      call. The concrete function returns a flat dict mapping output names to
      output tensors.

    Raises:
      ValueError: if input signature cannot be inferred from the model.
    """
    if input_signature is None and isinstance(
        model.call, tf.__internal__.function.Function
    ):
        # A `tf.function`-decorated `call` already carries a signature.
        input_signature = model.call.input_signature

    if not input_signature:
        call_args, call_kwargs = model_call_inputs(model)
    else:
        call_args, call_kwargs = input_signature, {}

    if call_args is None:
        raise_model_input_error(model)

    @tf.function
    def _wrapped_model(*args, **kwargs):
        """A concrete tf.function that wraps the model's call function."""
        # Force inference mode for the exported graph.
        args, kwargs = model._call_spec.set_arg_value(
            "training", False, args, kwargs, inputs_in_args=True
        )

        with base_layer_utils.call_context().enter(
            model, inputs=None, build_graph=False, training=False, saving=True
        ):
            outputs = model(*args, **kwargs)

        # Outputs always has to be a flat dict.
        names = model.output_names  # Functional Model.
        if names is None:  # Subclassed Model.
            from keras.src.engine import compile_utils

            # Pseudo names are derived from the still-nested outputs.
            names = compile_utils.create_pseudo_output_names(outputs)
        return dict(zip(names, tf.nest.flatten(outputs)))

    return _wrapped_model.get_concrete_function(*call_args, **call_kwargs)
def model_metadata(model, include_optimizer=True, require_config=True):
    """Returns a dictionary containing the model metadata.

    Args:
      model: The Keras model to describe.
      include_optimizer: Whether compile settings (including the optimizer)
        may be recorded under the "training_config" key.
      require_config: If `True`, a `NotImplementedError` raised by
        `model.get_config()` is propagated; otherwise the "config" entry is
        simply left out of "model_config".

    Returns:
      A dict with "keras_version", "backend" and "model_config" entries. When
      `include_optimizer` is set, the model has a (truthy) optimizer that is
      not a `TFOptimizer`, and `compile` was called, it also has a
      "training_config" entry whose "optimizer_config" holds the optimizer's
      registered name and config.

    Raises:
      NotImplementedError: if `model.get_config()` is not implemented and
        `require_config` is `True`, or if the optimizer was loaded from a
        SavedModel (`RestoredOptimizer`).
    """
    from keras.src import __version__ as keras_version
    from keras.src.optimizers.legacy import optimizer_v2

    model_config = {"class_name": model.__class__.__name__}
    try:
        model_config["config"] = model.get_config()
    except NotImplementedError as e:
        if require_config:
            raise e

    metadata = {
        "keras_version": str(keras_version),
        "backend": backend.backend(),
        "model_config": model_config,
    }
    if not (model.optimizer and include_optimizer):
        return metadata

    if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
        # Nothing can be serialized for a wrapped native TF optimizer.
        logging.warning(
            "TensorFlow optimizers do not "
            "make it possible to access "
            "optimizer attributes or optimizer state "
            "after instantiation. "
            "As a result, we cannot save the optimizer "
            "as part of the model save file. "
            "You will have to compile your model again after loading it. "
            "Prefer using a Keras optimizer instead "
            "(see keras.io/optimizers)."
        )
        return metadata

    if not model._compile_was_called:
        return metadata

    compile_args = model._get_compile_args(user_metrics=False)
    compile_args.pop("optimizer", None)  # Handled separately below.
    metadata["training_config"] = _serialize_nested_config(compile_args)

    if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
        raise NotImplementedError(
            "Optimizers loaded from a SavedModel cannot be saved. "
            "If you are calling `model.save` or "
            "`tf.keras.models.save_model`, "
            "please set the `include_optimizer` option to `False`. For "
            "`tf.saved_model.save`, "
            "delete the optimizer from the model."
        )

    metadata["training_config"]["optimizer_config"] = {
        "class_name": keras.utils.get_registered_name(
            model.optimizer.__class__
        ),
        "config": model.optimizer.get_config(),
    }
    return metadata
def should_overwrite(filepath, overwrite):
    """Decide whether it is fine to write to `filepath`.

    Args:
      filepath: Destination path.
      overwrite: If truthy, writing is always allowed without asking.

    Returns:
      `True` when `overwrite` is set or no regular file exists at `filepath`;
      otherwise the result of `ask_to_proceed_with_overwrite(filepath)`.
    """
    if overwrite or not os.path.isfile(filepath):
        return True
    # An existing file would be clobbered: defer to the user.
    return ask_to_proceed_with_overwrite(filepath)
def compile_args_from_training_config(training_config, custom_objects=None):
    """Return model.compile arguments from training config.

    Args:
      training_config: Dict of saved compile settings. It must contain an
        "optimizer_config" entry. "loss", "metrics", "weighted_metrics",
        "loss_weights" and "sample_weight_mode" are optional and default to
        `None` when absent.
      custom_objects: Optional dict of custom objects made available through
        `CustomObjectScope` while deserializing.

    Returns:
      A dict with the keys `optimizer`, `loss`, `metrics`,
      `weighted_metrics`, `loss_weights` and `sample_weight_mode`.

    Raises:
      KeyError: if "optimizer_config" is missing from `training_config`.
    """
    if custom_objects is None:
        custom_objects = {}

    with keras.utils.CustomObjectScope(custom_objects):
        optimizer_config = training_config["optimizer_config"]
        optimizer = optimizers.deserialize(optimizer_config)

        # Recover losses.
        loss = None
        loss_config = training_config.get("loss", None)
        if loss_config is not None:
            loss = _deserialize_nested_config(losses.deserialize, loss_config)

        # Recover metrics.
        metrics = None
        metrics_config = training_config.get("metrics", None)
        if metrics_config is not None:
            metrics = _deserialize_nested_config(
                _deserialize_metric, metrics_config
            )

        # Recover weighted metrics.
        weighted_metrics = None
        weighted_metrics_config = training_config.get("weighted_metrics", None)
        if weighted_metrics_config is not None:
            weighted_metrics = _deserialize_nested_config(
                _deserialize_metric, weighted_metrics_config
            )

        # `training_config` is a plain dict, so this must be a key lookup:
        # `hasattr(training_config, ...)` inspects attributes, never finds
        # the key, and would always discard a saved value.
        sample_weight_mode = training_config.get("sample_weight_mode", None)
        loss_weights = training_config.get("loss_weights", None)

    return dict(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        weighted_metrics=weighted_metrics,
        loss_weights=loss_weights,
        sample_weight_mode=sample_weight_mode,
    )
273def _deserialize_nested_config(deserialize_fn, config):
274 """Deserializes arbitrary Keras `config` using `deserialize_fn`."""
276 def _is_single_object(obj):
277 if isinstance(obj, dict) and "class_name" in obj:
278 return True # Serialized Keras object.
279 if isinstance(obj, str):
280 return True # Serialized function or string.
281 return False
283 if config is None:
284 return None
285 if _is_single_object(config):
286 return deserialize_fn(config)
287 elif isinstance(config, dict):
288 return {
289 k: _deserialize_nested_config(deserialize_fn, v)
290 for k, v in config.items()
291 }
292 elif isinstance(config, (tuple, list)):
293 return [
294 _deserialize_nested_config(deserialize_fn, obj) for obj in config
295 ]
297 raise ValueError(
298 "Saved configuration not understood. Configuration should be a "
299 f"dictionary, string, tuple or list. Received: config={config}."
300 )
def _serialize_nested_config(config):
    """Serialize every callable leaf of a nested structure of Keras objects.

    Callable leaves go through `serialization.serialize_keras_object`. All
    other leaves are kept as they are. The structure itself is rebuilt with
    `tf.nest.map_structure`.
    """

    def _serialize_leaf(leaf):
        if not callable(leaf):
            return leaf
        return serialization.serialize_keras_object(leaf)

    return tf.nest.map_structure(_serialize_leaf, config)
def _deserialize_metric(metric_config):
    """Deserialize metrics, leaving special strings untouched.

    The strings "accuracy", "acc", "crossentropy" and "ce" are returned
    unchanged. Anything else is passed to `keras.src.metrics.deserialize`.
    """
    from keras.src import metrics as metrics_module

    # `compile` resolves accuracy / cross-entropy shorthands itself, based on
    # model output shape, so they must stay plain strings here.
    passthrough = ("accuracy", "acc", "crossentropy", "ce")
    if metric_config not in passthrough:
        return metrics_module.deserialize(metric_config)
    return metric_config
def _enforce_names_consistency(specs):
    """Enforces that either all specs have names or none do.

    A `None` spec counts as named. If only some of the flattened specs are
    named, a structure of deep copies is returned with every `_name` cleared.
    Otherwise `specs` is returned as is.
    """

    def _has_name(spec):
        if spec is None:
            return True
        return hasattr(spec, "name") and spec.name is not None

    def _clear_name(spec):
        # Copy first so the caller's spec objects are never mutated.
        cleared = copy.deepcopy(spec)
        if hasattr(cleared, "name"):
            cleared._name = None
        return cleared

    named_flags = [_has_name(spec) for spec in tf.nest.flatten(specs)]
    if any(named_flags) and not all(named_flags):
        return tf.nest.map_structure(_clear_name, specs)
    return specs
def try_build_compiled_arguments(model):
    """Best-effort build of a loaded model's compiled loss and metrics.

    Does nothing for v1 layers/models or when `model.outputs` is `None`.
    Otherwise it builds `model.compiled_loss` and `model.compiled_metrics`
    if they are not built yet. The model outputs are passed for both build
    arguments of the metrics container. Any error is logged as a warning
    rather than raised.

    Args:
      model: The (freshly loaded and compiled) Keras model.
    """
    if version_utils.is_v1_layer_or_model(model) or model.outputs is None:
        return
    try:
        if not model.compiled_loss.built:
            model.compiled_loss.build(model.outputs)
        if not model.compiled_metrics.built:
            model.compiled_metrics.build(model.outputs, model.outputs)
    except Exception:
        # Deliberately best-effort: a build failure must not break loading.
        # `Exception` (not a bare `except:`) lets KeyboardInterrupt and
        # SystemExit propagate.
        logging.warning(
            "Compiled the loaded model, but the compiled metrics have "
            "yet to be built. `model.compile_metrics` will be empty "
            "until you train or evaluate the model."
        )
def is_hdf5_filepath(filepath):
    """Return whether `filepath` ends with ".h5", ".keras" or ".hdf5".

    The suffix comparison is case-sensitive.
    """
    return filepath.endswith((".h5", ".keras", ".hdf5"))