Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py: 42%
38 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras model saving code."""
17from tensorflow.python import tf2
18from tensorflow.python.keras.saving import hdf5_format
19from tensorflow.python.keras.saving import saving_utils
20from tensorflow.python.keras.saving.saved_model import load as saved_model_load
21from tensorflow.python.keras.saving.saved_model import load_context
22from tensorflow.python.keras.saving.saved_model import save as saved_model_save
23from tensorflow.python.keras.utils import generic_utils
24from tensorflow.python.keras.utils.io_utils import path_to_string
25from tensorflow.python.util.tf_export import keras_export
27# pylint: disable=g-import-not-at-top
28try:
29 import h5py
30except ImportError:
31 h5py = None
32# pylint: enable=g-import-not-at-top
@keras_export('keras.models.save_model')
def save_model(model,
               filepath,
               overwrite=True,
               include_optimizer=True,
               save_format=None,
               signatures=None,
               options=None,
               save_traces=True):
  # pylint: disable=line-too-long
  """Saves a model as a TensorFlow SavedModel or HDF5 file.

  See the [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
  for details.

  Usage:

  >>> model = tf.keras.Sequential([
  ...     tf.keras.layers.Dense(5, input_shape=(3,)),
  ...     tf.keras.layers.Softmax()])
  >>> model.save('/tmp/model')
  >>> loaded_model = tf.keras.models.load_model('/tmp/model')
  >>> x = tf.random.uniform((10, 3))
  >>> assert np.allclose(model.predict(x), loaded_model.predict(x))

  The SavedModel and HDF5 files contain:

  - the model's configuration (topology)
  - the model's weights
  - the model's optimizer's state (if any)

  Thus models can be reinstantiated in the exact same state, without any of the
  code used for model definition or training.

  Note that the model weights may have different scoped names after being
  loaded. Scoped names include the model/layer names, such as
  `"dense_1/kernel:0"`. It is recommended that you use the layer properties to
  access specific variables, e.g. `model.get_layer("dense_1").kernel`.

  __SavedModel serialization format__

  Keras SavedModel uses `tf.saved_model.save` to save the model and all
  trackable objects attached to the model (e.g. layers and variables). The model
  config, weights, and optimizer are saved in the SavedModel. Additionally, for
  every Keras layer attached to the model, the SavedModel stores:

    * the config and metadata -- e.g. name, dtype, trainable status
    * traced call and loss functions, which are stored as TensorFlow subgraphs.

  The traced functions allow the SavedModel format to save and load custom
  layers without the original class definition.

  You can choose to not save the traced functions by disabling the `save_traces`
  option. This will decrease the time it takes to save the model and the
  amount of disk space occupied by the output SavedModel. If you disable this
  option, then you _must_ provide all custom class definitions when loading
  the model. See the `custom_objects` argument in `tf.keras.models.load_model`.

  Args:
    model: Keras model instance to be saved.
    filepath: One of the following:
      - String or `pathlib.Path` object, path where to save the model
      - `h5py.File` object where to save the model
    overwrite: Whether we should overwrite any existing model at the target
      location, or instead ask the user with a manual prompt.
    include_optimizer: If True, save optimizer's state together.
    save_format: Either 'tf' or 'h5', indicating whether to save the model
      to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, and 'h5'
      in TF 1.X. The HDF5 format is also used, regardless of this argument,
      when `filepath` is an `h5py.File` or a path that
      `saving_utils.is_hdf5_filepath` recognizes as HDF5.
    signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
      format only. Please see the `signatures` argument in
      `tf.saved_model.save` for details.
    options: (only applies to SavedModel format) `tf.saved_model.SaveOptions`
      object that specifies options for saving to SavedModel.
    save_traces: (only applies to SavedModel format) When enabled, the
      SavedModel will store the function traces for each layer. This
      can be disabled, so that only the configs of each layer are stored.
      Defaults to `True`. Disabling this will decrease serialization time and
      reduce file size, but it requires that all custom layers/models
      implement a `get_config()` method.

  Raises:
    ImportError: If save format is hdf5, and h5py is not available.
    NotImplementedError: If the model is saved in HDF5 format but is neither a
      Functional model nor a `Sequential` model (e.g. a subclassed model).
  """
  # pylint: enable=line-too-long
  # Imported lazily, presumably to avoid a circular import between the saving
  # and engine packages.
  from tensorflow.python.keras.engine import sequential  # pylint: disable=g-import-not-at-top

  # An explicit `save_format` wins; otherwise pick the default for the
  # TF major version currently active.
  default_format = 'tf' if tf2.enabled() else 'h5'
  save_format = save_format or default_format

  # Normalizes `pathlib.Path`-like inputs to strings; other objects (such as
  # an `h5py.File`) are left for the checks below.
  filepath = path_to_string(filepath)

  # If the user has not already called fit or built the underlying metrics, we
  # should do that before saving to ensure the metric names have all
  # appropriate name transformations applied.
  saving_utils.try_build_compiled_arguments(model)

  # HDF5 is chosen when requested explicitly, when handed an open h5py file,
  # or when the path itself looks like an HDF5 file.
  if (save_format == 'h5' or
      (h5py is not None and isinstance(filepath, h5py.File)) or
      saving_utils.is_hdf5_filepath(filepath)):
    # TODO(b/130258301): add utility method for detecting model type.
    if (not model._is_graph_network and  # pylint:disable=protected-access
        not isinstance(model, sequential.Sequential)):
      raise NotImplementedError(
          'Saving the model to HDF5 format requires the model to be a '
          'Functional model or a Sequential model. It does not work for '
          'subclassed models, because such models are defined via the body of '
          'a Python method, which isn\'t safely serializable. Consider saving '
          'to the Tensorflow SavedModel format (by setting save_format="tf") '
          'or using `save_weights`.')
    hdf5_format.save_model_to_hdf5(
        model, filepath, overwrite, include_optimizer)
  else:
    with generic_utils.SharedObjectSavingScope():
      saved_model_save.save(model, filepath, overwrite, include_optimizer,
                            signatures, options, save_traces)
@keras_export('keras.models.load_model')
def load_model(filepath, custom_objects=None, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads a Keras model that was written by `model.save()`.

  Usage:

  >>> model = tf.keras.Sequential([
  ...     tf.keras.layers.Dense(5, input_shape=(3,)),
  ...     tf.keras.layers.Softmax()])
  >>> model.save('/tmp/model')
  >>> loaded_model = tf.keras.models.load_model('/tmp/model')
  >>> x = tf.random.uniform((10, 3))
  >>> assert np.allclose(model.predict(x), loaded_model.predict(x))

  After loading, weight variables may carry different scoped names (names
  that embed the model/layer names, such as `"dense_1/kernel:0"`). Prefer
  reaching specific variables through layer properties, e.g.
  `model.get_layer("dense_1").kernel`.

  Args:
    filepath: Where to read the model from. Either:
      - a string or `pathlib.Path` object pointing at the saved model, or
      - an open `h5py.File` object containing the model.
    custom_objects: Optional dict mapping names (strings) to custom classes
      or functions that should be considered during deserialization.
    compile: Boolean; whether the model should be compiled after loading.
    options: Optional `tf.saved_model.LoadOptions` object carrying options
      for loading from a SavedModel.

  Returns:
    A Keras model instance. If the original model was compiled, and saved with
    the optimizer, then the returned model will be compiled. Otherwise, the
    model will be left uncompiled. In the case that an uncompiled model is
    returned, a warning is displayed if the `compile` argument is set to
    `True`.

  Raises:
    ImportError: if loading from an hdf5 file and h5py is not available.
    IOError: In case of an invalid savefile.
  """
  # The three scopes are entered left to right and exited in reverse, exactly
  # as if they were nested `with` statements.
  with generic_utils.SharedObjectLoadingScope(), \
      generic_utils.CustomObjectScope(custom_objects or {}), \
      load_context.load_context(options):
    # HDF5 loading is only attempted when h5py imported successfully and the
    # input is either an open h5py file or something h5py identifies as HDF5.
    from_hdf5 = h5py is not None and (
        isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))
    if from_hdf5:
      return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
                                              compile)

    # Otherwise the input must reduce to a string path, handed to the
    # SavedModel loader.
    path_str = path_to_string(filepath)
    if not isinstance(path_str, str):
      raise IOError(
          'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
          'available) or SavedModel.')
    return saved_model_load.load(path_str, compile, options)