Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/save.py: 18%
138 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras model saving code."""

import os

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.saving import object_registration
from keras.src.saving.legacy import hdf5_format
from keras.src.saving.legacy import saving_utils
from keras.src.saving.legacy import serialization
from keras.src.saving.legacy.saved_model import load as saved_model_load
from keras.src.saving.legacy.saved_model import load_context
from keras.src.saving.legacy.saved_model import save as saved_model_save
from keras.src.saving.legacy.saved_model.utils import keras_option_scope
from keras.src.utils import io_utils
from keras.src.utils import traceback_utils

try:
    import h5py
except ImportError:
    h5py = None


@traceback_utils.filter_traceback
def save_model(
    model,
    filepath,
    overwrite=True,
    include_optimizer=True,
    save_format=None,
    signatures=None,
    options=None,
    save_traces=True,
):
50 """Saves a model as a TensorFlow SavedModel or HDF5 file.
52 See the [Serialization and Saving
53 guide](https://keras.io/guides/serialization_and_saving/) for details.
55 Usage:
57 >>> model = tf.keras.Sequential([
58 ... tf.keras.layers.Dense(5, input_shape=(3,)),
59 ... tf.keras.layers.Softmax()])
60 >>> model.save('/tmp/model')
61 >>> loaded_model = tf.keras.models.load_model('/tmp/model')
62 >>> x = tf.random.uniform((10, 3))
63 >>> assert np.allclose(model.predict(x), loaded_model.predict(x))
65 Note that `model.save()` is an alias for `tf.keras.models.save_model()`.
67 The SavedModel and HDF5 file contains:
69 - the model's configuration (topology)
70 - the model's weights
71 - the model's optimizer's state (if any)
73 Thus models can be reinstantiated in the exact same state, without any of
74 the code used for model definition or training.
76 Note that the model weights may have different scoped names after being
77 loaded. Scoped names include the model/layer names, such as
78 `"dense_1/kernel:0"`. It is recommended that you use the layer properties to
79 access specific variables, e.g. `model.get_layer("dense_1").kernel`.
81 __SavedModel serialization format__
83 Keras SavedModel uses `tf.saved_model.save` to save the model and all
84 trackable objects attached to the model (e.g. layers and variables). The
85 model config, weights, and optimizer are saved in the SavedModel.
86 Additionally, for every Keras layer attached to the model, the SavedModel
87 stores:
89 * the config and metadata -- e.g. name, dtype, trainable status
90 * traced call and loss functions, which are stored as TensorFlow
91 subgraphs.
93 The traced functions allow the SavedModel format to save and load custom
94 layers without the original class definition.
96 You can choose to not save the traced functions by disabling the
97 `save_traces` option. This will decrease the time it takes to save the model
98 and the amount of disk space occupied by the output SavedModel. If you
99 enable this option, then you _must_ provide all custom class definitions
100 when loading the model. See the `custom_objects` argument in
101 `tf.keras.models.load_model`.
103 Args:
104 model: Keras model instance to be saved.
105 filepath: One of the following:
106 - String or `pathlib.Path` object, path where to save the model
107 - `h5py.File` object where to save the model
108 overwrite: Whether we should overwrite any existing model at the target
109 location, or instead ask the user with a manual prompt.
110 include_optimizer: If True, save optimizer's state together.
111 save_format: Either 'tf' or 'h5', indicating whether to save the model
112 to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, and 'h5'
113 in TF 1.X.
114 signatures: Signatures to save with the SavedModel. Applicable to the
115 'tf' format only. Please see the `signatures` argument in
116 `tf.saved_model.save` for details.
117 options: (only applies to SavedModel format)
118 `tf.saved_model.SaveOptions` object that specifies options for saving
119 to SavedModel.
120 save_traces: (only applies to SavedModel format) When enabled, the
121 SavedModel will store the function traces for each layer. This
122 can be disabled, so that only the configs of each layer are stored.
123 Defaults to `True`. Disabling this will decrease serialization time
124 and reduce file size, but it requires that all custom layers/models
125 implement a `get_config()` method.
127 Raises:
128 ImportError: If save format is hdf5, and h5py is not available.
129 """

    from keras.src.engine import sequential

    default_format = "tf" if tf.__internal__.tf2.enabled() else "h5"
    save_format = save_format or default_format

    filepath = io_utils.path_to_string(filepath)

    # If the user has not already called fit or built the underlying metrics,
    # we should do that before saving to ensure the metric names have all
    # appropriate name transformations applied.
    saving_utils.try_build_compiled_arguments(model)

    if (
        save_format == "h5"
        or (h5py is not None and isinstance(filepath, h5py.File))
        or saving_utils.is_hdf5_filepath(filepath)
    ):
        # TODO(b/130258301): add utility method for detecting model type.
        if not model._is_graph_network and not isinstance(
            model, sequential.Sequential
        ):
            raise NotImplementedError(
                "Saving the model to HDF5 format requires the model to be a "
                "Functional model or a Sequential model. It does not work for "
                "subclassed models, because such models are defined via the "
                "body of a Python method, which isn't safely serializable. "
                "Consider saving to the TensorFlow SavedModel format (by "
                'setting save_format="tf") or using `save_weights`.'
            )
        hdf5_format.save_model_to_hdf5(
            model, filepath, overwrite, include_optimizer
        )
    else:
        with serialization.SharedObjectSavingScope():
            with keras_option_scope(
                save_traces=save_traces, in_tf_saved_model_scope=True
            ):
                saved_model_save.save(
                    model,
                    filepath,
                    overwrite,
                    include_optimizer,
                    signatures,
                    options,
                    save_traces,
                )
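

# Illustrative usage sketch for `save_model` above (not part of the module's
# public surface; wrapped in an underscore-prefixed helper so nothing runs at
# import time). The /tmp paths are hypothetical, and the example relies only
# on the TF2 behavior documented in the docstring.
def _example_save_model_formats():
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(5, input_shape=(3,)),
            tf.keras.layers.Softmax(),
        ]
    )
    # "tf": writes a SavedModel directory (config, weights, traced functions).
    save_model(model, "/tmp/example_saved_model", save_format="tf")
    # "h5": writes a single HDF5 file; requires h5py and a Functional or
    # Sequential model.
    save_model(model, "/tmp/example_model.h5", save_format="h5")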


@traceback_utils.filter_traceback
def load_model(filepath, custom_objects=None, compile=True, options=None):
    """Loads a model saved via `model.save()`.

    Usage:

    >>> model = tf.keras.Sequential([
    ...     tf.keras.layers.Dense(5, input_shape=(3,)),
    ...     tf.keras.layers.Softmax()])
    >>> model.save('/tmp/model')
    >>> loaded_model = tf.keras.models.load_model('/tmp/model')
    >>> x = tf.random.uniform((10, 3))
    >>> assert np.allclose(model.predict(x), loaded_model.predict(x))

    Note that the model weights may have different scoped names after being
    loaded. Scoped names include the model/layer names, such as
    `"dense_1/kernel:0"`. It is recommended that you use the layer properties
    to access specific variables, e.g. `model.get_layer("dense_1").kernel`.

    Args:
        filepath: One of the following:
            - String or `pathlib.Path` object, path to the saved model
            - `h5py.File` object from which to load the model
        custom_objects: Optional dictionary mapping names (strings) to custom
            classes or functions to be considered during deserialization.
        compile: Boolean, whether to compile the model after loading.
        options: Optional `tf.saved_model.LoadOptions` object that specifies
            options for loading from SavedModel.

    Returns:
        A Keras model instance. If the original model was compiled, and saved
        with the optimizer, then the returned model will be compiled.
        Otherwise, the model will be left uncompiled. In the case that an
        uncompiled model is returned, a warning is displayed if the `compile`
        argument is set to `True`.

    Raises:
        ImportError: If loading from an HDF5 file and h5py is not available.
        IOError: In case of an invalid savefile.
    """
    with serialization.SharedObjectLoadingScope():
        custom_objects = custom_objects or {}
        tlco = object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
        gco = object_registration._GLOBAL_CUSTOM_OBJECTS
        custom_objects = {**custom_objects, **tlco, **gco}
        with object_registration.CustomObjectScope(custom_objects):
            with keras_option_scope(
                save_traces=False, in_tf_saved_model_scope=True
            ):
                with load_context.load_context(options):
                    filepath_str = io_utils.path_to_string(filepath)
                    if isinstance(filepath_str, str):
                        if not tf.io.gfile.exists(filepath_str):
                            raise IOError(
                                f"No file or directory found at {filepath_str}"
                            )

                        if tf.io.gfile.isdir(filepath_str):
                            return saved_model_load.load(
                                filepath_str, compile, options
                            )
                        else:
                            if h5py is None:
                                raise ImportError(
                                    "Filepath looks like an hdf5 file but "
                                    "h5py is not available."
                                    f" filepath={filepath_str}"
                                )
                            return hdf5_format.load_model_from_hdf5(
                                tf.io.gfile.GFile(filepath_str, mode="rb"),
                                custom_objects,
                                compile,
                            )
                    elif h5py is not None and isinstance(filepath, h5py.File):
                        return hdf5_format.load_model_from_hdf5(
                            filepath, custom_objects, compile
                        )

    raise IOError(
        "Unable to load model. Filepath is not an hdf5 file (or h5py is not "
        f"available) or SavedModel. Received: filepath={filepath}"
    )
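

# Illustrative sketch for `load_model` above, showing the `custom_objects`
# argument. `MyDense` and the /tmp path are hypothetical; a custom class must
# be supplied via `custom_objects` (or registered) when deserializing an HDF5
# model, or a SavedModel written with `save_traces=False`.
def _example_load_model_custom_objects():
    class MyDense(tf.keras.layers.Dense):
        pass

    model = tf.keras.Sequential([MyDense(5, input_shape=(3,))])
    save_model(model, "/tmp/example_custom.h5", save_format="h5")
    restored = load_model(
        "/tmp/example_custom.h5", custom_objects={"MyDense": MyDense}
    )
    assert isinstance(restored.layers[0], MyDense)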


def save_weights(
    model, filepath, overwrite=True, save_format=None, options=None
):
    """Saves all layer weights.

    Either saves in HDF5 or in TensorFlow format based on the `save_format`
    argument.

    When saving in HDF5 format, the weight file has:
        - `layer_names` (attribute), a list of strings
            (ordered names of model layers).
        - For every layer, a `group` named `layer.name`
            - For every such layer group, a group attribute `weight_names`,
                a list of strings
                (ordered names of weights tensor of the layer).
            - For every weight in the layer, a dataset
                storing the weight value, named after the weight tensor.

    When saving in TensorFlow format, all objects referenced by the network
    are saved in the same format as `tf.train.Checkpoint`, including any
    `Layer` instances or `Optimizer` instances assigned to object
    attributes. For networks constructed from inputs and outputs using
    `tf.keras.Model(inputs, outputs)`, `Layer` instances used by the network
    are tracked/saved automatically. For user-defined classes which inherit
    from `tf.keras.Model`, `Layer` instances must be assigned to object
    attributes, typically in the constructor. See the documentation of
    `tf.train.Checkpoint` and `tf.keras.Model` for details.

    While the formats are the same, do not mix `save_weights` and
    `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should
    be loaded using `Model.load_weights`. Checkpoints saved using
    `tf.train.Checkpoint.save` should be restored using the corresponding
    `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
    `save_weights` for training checkpoints.

    The TensorFlow format matches objects and variables by starting at a
    root object, `self` for `save_weights`, and greedily matching attribute
    names. For `Model.save` this is the `Model`, and for `Checkpoint.save`
    this is the `Checkpoint` even if the `Checkpoint` has a model attached.
    This means saving a `tf.keras.Model` using `save_weights` and loading
    into a `tf.train.Checkpoint` with a `Model` attached (or vice versa)
    will not match the `Model`'s variables. See the
    [guide to training checkpoints](
    https://www.tensorflow.org/guide/checkpoint) for details on
    the TensorFlow format.

    Args:
        filepath: String or PathLike, path to the file to save the weights
            to. When saving in TensorFlow format, this is the prefix used
            for checkpoint files (multiple files are generated). Note that
            the '.h5' suffix causes weights to be saved in HDF5 format.
        overwrite: Whether to silently overwrite any existing file at the
            target location, or provide the user with a manual prompt.
        save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
            '.keras' will default to HDF5 if `save_format` is `None`.
            Otherwise `None` defaults to 'tf'.
        options: Optional `tf.train.CheckpointOptions` object that specifies
            options for saving weights.

    Raises:
        ImportError: If `h5py` is not available when attempting to save in
            HDF5 format.
    """
    model._assert_weights_created()
    filepath = io_utils.path_to_string(filepath)
    filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
    if save_format is None:
        if filepath_is_h5:
            save_format = "h5"
        else:
            save_format = "tf"
    else:
        user_format = save_format.lower().strip()
        if user_format in ("tensorflow", "tf"):
            save_format = "tf"
        elif user_format in ("hdf5", "h5", "keras"):
            save_format = "h5"
        else:
            raise ValueError(
                f"Unknown format. Received: `save_format`={save_format}. "
                'Was expecting one of {"tf", "h5"}.'
            )
    if save_format == "tf" and filepath_is_h5:
        raise ValueError(
            'save_weights got save_format="tf"/"tensorflow", but the '
            f"filepath ({filepath}) looks like an HDF5 file. "
            'Omit the ".h5"/".keras" when saving in TensorFlow format.'
        )

    if save_format == "h5" and h5py is None:
        raise ImportError(
            "`save_weights` requires h5py when saving in hdf5, but h5py is "
            "not available. Try installing h5py package."
        )
    if save_format == "tf":
        check_filepath = filepath + ".index"
    else:
        check_filepath = filepath
    # If file exists and should not be overwritten:
    if not overwrite and os.path.isfile(check_filepath):
        proceed = io_utils.ask_to_proceed_with_overwrite(check_filepath)
        if not proceed:
            return
    if save_format == "h5":
        with h5py.File(filepath, "w") as f:
            hdf5_format.save_weights_to_hdf5_group(f, model)
    else:
        if not tf.executing_eagerly():
            # Call `get_session` to initialize any uninitialized variables.
            backend.get_session()
        model._checkpoint.write(filepath, options=options)

        # Record this checkpoint so it's visible from
        # tf.train.latest_checkpoint.
        tf.__internal__.train.update_checkpoint_state(
            save_dir=os.path.dirname(filepath),
            model_checkpoint_path=filepath,
            save_relative_paths=True,
            all_model_checkpoint_paths=[filepath],
        )
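

# Illustrative sketch for `save_weights` above. The /tmp paths are
# hypothetical. In "tf" format the filepath is a checkpoint *prefix* and
# several files are written (e.g. ".index" and ".data-00000-of-00001");
# a ".h5" suffix switches to a single HDF5 file instead.
def _example_save_weights_formats():
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(5, input_shape=(3,)),
            tf.keras.layers.Softmax(),
        ]
    )
    save_weights(model, "/tmp/example_weights_ckpt")  # TF checkpoint prefix
    save_weights(model, "/tmp/example_weights.h5")  # HDF5, picked by suffix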


def load_weights(
    model, filepath, by_name=False, skip_mismatch=False, options=None
):
    """Loads all layer weights, either from a SavedModel or H5 weights file.

    If `by_name` is False, weights are loaded based on the network's
    topology. This means the architecture should be the same as when the
    weights were saved. Note that layers that don't have weights are not
    taken into account in the topological ordering, so adding or removing
    layers is fine as long as they don't have weights.

    If `by_name` is True, weights are loaded into layers only if they share
    the same name. This is useful for fine-tuning or transfer-learning
    models where some of the layers have changed.

    Only topological loading (`by_name=False`) is supported when loading
    weights from the TensorFlow format. Note that topological loading
    differs slightly between TensorFlow and HDF5 formats for user-defined
    classes inheriting from `tf.keras.Model`: HDF5 loads based on a
    flattened list of weights, while the TensorFlow format loads based on
    the object-local names of attributes to which layers are assigned in the
    `Model`'s constructor.

    Args:
        filepath: String, path to the weights file to load. For weight files
            in TensorFlow format, this is the file prefix (the same as was
            passed to `save_weights`). This can also be a path to a
            SavedModel saved from `model.save`.
        by_name: Boolean, whether to load weights by name or by topological
            order. Only topological loading is supported for weight files in
            TensorFlow format.
        skip_mismatch: Boolean, whether to skip loading of layers where
            there is a mismatch in the number of weights, or a mismatch in
            the shape of the weight (only valid when `by_name=True`).
        options: Optional `tf.train.CheckpointOptions` object that specifies
            options for loading weights.

    Returns:
        When loading a weight file in TensorFlow format, returns the same
        status object as `tf.train.Checkpoint.restore`. When graph building,
        restore ops are run automatically as soon as the network is built
        (on first call for user-defined classes inheriting from `Model`,
        immediately if it is already built).

        When loading weights in HDF5 format, returns `None`.

    Raises:
        ImportError: If `h5py` is not available and the weight file is in
            HDF5 format.
        ValueError: If `skip_mismatch` is set to `True` when `by_name` is
            `False`.
    """
    if backend.is_tpu_strategy(model._distribution_strategy):
        if model._distribution_strategy.extended.steps_per_run > 1 and (
            not saving_utils.is_hdf5_filepath(filepath)
        ):
            spr = model._distribution_strategy.extended.steps_per_run
            raise ValueError(
                "Load weights is not implemented with TPUStrategy "
                "with `steps_per_run` greater than 1. The "
                f"`steps_per_run` is {spr}"
            )
    if skip_mismatch and not by_name:
        raise ValueError(
            "When calling model.load_weights, skip_mismatch can only be "
            "set to True when by_name is True."
        )

    filepath, save_format = _detect_save_format(filepath)
    if save_format == "tf":
        status = model._checkpoint.read(filepath, options)
        if by_name:
            raise NotImplementedError(
                "Weights may only be loaded based on topology into Models "
                "when loading TensorFlow-formatted weights "
                "(got by_name=True to load_weights)."
            )
        if not tf.executing_eagerly():
            session = backend.get_session()
            # Restore existing variables (if any) immediately, and set up a
            # streaming restore for any variables created in the future.
            tf.__internal__.tracking.streaming_restore(
                status=status, session=session
            )
        status.assert_nontrivial_match()
    else:
        status = None
        if h5py is None:
            raise ImportError(
                "`load_weights` requires h5py package when loading weights "
                "from HDF5. Try installing h5py."
            )
        if not model._is_graph_network and not model.built:
            raise ValueError(
                "Unable to load weights saved in HDF5 format into a "
                "subclassed Model which has not created its variables yet. "
                "Call the Model first, then load the weights."
            )
        model._assert_weights_created()
        with h5py.File(filepath, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            if by_name:
                hdf5_format.load_weights_from_hdf5_group_by_name(
                    f, model, skip_mismatch
                )
            else:
                hdf5_format.load_weights_from_hdf5_group(f, model)

    # Perform any layer defined finalization of the layer state.
    for layer in model.layers:
        layer.finalize_state()
    return status
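

# Illustrative sketch for `load_weights` above: by-name loading from an HDF5
# weights file into a model whose head layer has changed (transfer-learning
# style). Layer names and the /tmp path are hypothetical.
def _example_load_weights_by_name():
    base = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(8, input_shape=(3,), name="feature"),
            tf.keras.layers.Dense(2, name="old_head"),
        ]
    )
    save_weights(base, "/tmp/example_base_weights.h5")

    # Only the "feature" layer shares a name with the saved file, so only its
    # weights are restored; `skip_mismatch=True` would additionally tolerate
    # same-named layers whose weight shapes differ.
    new = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(8, input_shape=(3,), name="feature"),
            tf.keras.layers.Dense(4, name="new_head"),
        ]
    )
    load_weights(new, "/tmp/example_base_weights.h5", by_name=True)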


def _detect_save_format(filepath):
    """Returns path to weights file and save format."""

    filepath = io_utils.path_to_string(filepath)
    if saving_utils.is_hdf5_filepath(filepath):
        return filepath, "h5"

    # Filepath could be a TensorFlow checkpoint file prefix or SavedModel
    # directory. It's possible for filepath to be both a prefix and directory.
    # Prioritize checkpoint over SavedModel.
    if _is_readable_tf_checkpoint(filepath):
        save_format = "tf"
    elif tf.saved_model.contains_saved_model(filepath):
        ckpt_path = os.path.join(
            filepath,
            tf.saved_model.VARIABLES_DIRECTORY,
            tf.saved_model.VARIABLES_FILENAME,
        )
        if _is_readable_tf_checkpoint(ckpt_path):
            filepath = ckpt_path
            save_format = "tf"
        else:
            raise ValueError(
                "Unable to load weights. filepath {} appears to be a "
                "SavedModel directory, but checkpoint either doesn't "
                "exist, or is incorrectly formatted.".format(filepath)
            )
    else:
        # Not a TensorFlow checkpoint. This filepath is likely an H5 file that
        # doesn't have the hdf5/keras extensions.
        save_format = "h5"
    return filepath, save_format
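

# Illustrative mapping for `_detect_save_format` (hypothetical paths; the
# actual result depends on what exists on disk):
#   ".h5" / ".hdf5" / ".keras" suffix -> (filepath, "h5")
#   readable TF checkpoint prefix     -> (filepath, "tf")
#   SavedModel directory              -> (<dir>/variables/variables, "tf")
#   anything else                     -> (filepath, "h5") as a fallback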


def _is_readable_tf_checkpoint(filepath):
    try:
        tf.compat.v1.train.NewCheckpointReader(filepath)
        return True
    except tf.errors.DataLossError:
        # The checkpoint is not readable in TensorFlow format.
        return False


# Inject the load_model function to keras_deps to remove the dependency
# from TFLite to Keras.
tf.__internal__.register_load_model_function(load_model)