Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/saving_api.py: 21%
87 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Public API surface for saving APIs."""
17import os
18import warnings
19import zipfile
21import tensorflow.compat.v2 as tf
22from tensorflow.python.util.tf_export import keras_export
24from keras.src.saving import saving_lib
25from keras.src.saving.legacy import save as legacy_sm_saving_lib
26from keras.src.utils import io_utils
28try:
29 import h5py
30except ImportError:
31 h5py = None
@keras_export("keras.saving.save_model", "keras.models.save_model")
def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
    """Saves a model as a TensorFlow SavedModel or HDF5 file.

    See the [Serialization and Saving guide](
    https://keras.io/guides/serialization_and_saving/) for details.

    Args:
        model: Keras model instance to be saved.
        filepath: `str` or `pathlib.Path` object. Path where to save the model.
        overwrite: Whether we should overwrite any existing model at the target
            location, or instead ask the user via an interactive prompt.
        save_format: Either `"keras"`, `"tf"`, `"h5"`,
            indicating whether to save the model
            in the native Keras format (`.keras`),
            in the TensorFlow SavedModel format (referred to as "SavedModel"
            below), or in the legacy HDF5 format (`.h5`).
            When omitted, the format is inferred from `filepath`:
            - a `.keras` suffix selects the native Keras format (or HDF5
              when Keras v3 saving is not enabled);
            - a `.h5`/`.hdf5` suffix or an `h5py.File` object selects HDF5;
            - otherwise it defaults to `"tf"` in TF 2.X, and `"h5"` in TF 1.X.

    SavedModel format arguments:
        include_optimizer: Only applied to SavedModel and legacy HDF5 formats.
            If False, do not save the optimizer state. Defaults to True.
        signatures: Only applies to SavedModel format. Signatures to save
            with the SavedModel. See the `signatures` argument in
            `tf.saved_model.save` for details.
        options: Only applies to SavedModel format.
            `tf.saved_model.SaveOptions` object that specifies SavedModel
            saving options.
        save_traces: Only applies to SavedModel format. When enabled, the
            SavedModel will store the function traces for each layer. This
            can be disabled, so that only the configs of each layer are stored.
            Defaults to `True`. Disabling this will decrease serialization time
            and reduce file size, but it requires that all custom layers/models
            implement a `get_config()` method.

    Raises:
        ValueError: If `save_format` is not recognized, or if any extra
            keyword arguments are passed when saving in the native Keras
            format.

    Example:

    ```python
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(5, input_shape=(3,)),
        tf.keras.layers.Softmax()])
    model.save("model.keras")
    loaded_model = tf.keras.saving.load_model("model.keras")
    x = tf.random.uniform((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that `model.save()` is an alias for `tf.keras.saving.save_model()`.

    The SavedModel or HDF5 file contains:

    - The model's configuration (architecture)
    - The model's weights
    - The model's optimizer's state (if any)

    Thus models can be reinstantiated in the exact same state, without any of
    the code used for model definition or training.

    Note that the model weights may have different scoped names after being
    loaded. Scoped names include the model/layer names, such as
    `"dense_1/kernel:0"`. It is recommended that you use the layer properties to
    access specific variables, e.g. `model.get_layer("dense_1").kernel`.

    __SavedModel serialization format__

    With `save_format="tf"`, the model and all trackable objects attached
    to it (e.g. layers and variables) are saved as a TensorFlow SavedModel.
    The model config, weights, and optimizer are included in the SavedModel.
    Additionally, for every Keras layer attached to the model, the SavedModel
    stores:

    * The config and metadata -- e.g. name, dtype, trainable status
    * Traced call and loss functions, which are stored as TensorFlow
      subgraphs.

    The traced functions allow the SavedModel format to save and load custom
    layers without the original class definition.

    You can choose to not save the traced functions by disabling the
    `save_traces` option. This will decrease the time it takes to save the model
    and the amount of disk space occupied by the output SavedModel. If you
    disable this option, then you _must_ provide all custom class definitions
    when loading the model. See the `custom_objects` argument in
    `tf.keras.saving.load_model`.
    """
    save_format = get_save_format(filepath, save_format)

    # Deprecation warnings
    if save_format == "h5":
        warnings.warn(
            "You are saving your model as an HDF5 file via `model.save()`. "
            "This file format is considered legacy. "
            "We recommend using instead the native Keras format, "
            "e.g. `model.save('my_model.keras')`.",
            stacklevel=2,
        )

    if save_format != "keras":
        # Legacy case: SavedModel ("tf") or HDF5 ("h5").
        return legacy_sm_saving_lib.save_model(
            model,
            filepath,
            overwrite=overwrite,
            save_format=save_format,
            **kwargs,
        )

    # Reject unsupported arguments before any interactive prompt, so the user
    # is never asked to confirm an overwrite for a call that would fail anyway.
    if kwargs:
        raise ValueError(
            "The following argument(s) are not supported "
            f"with the native Keras format: {list(kwargs.keys())}"
        )

    # If file exists and should not be overwritten.
    try:
        exists = os.path.exists(filepath)
    except TypeError:
        # `os.path.exists` raises TypeError for non-path-like arguments;
        # treat those as "nothing to overwrite".
        exists = False
    if exists and not overwrite:
        proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            # User declined the overwrite: leave the existing file untouched.
            return
    saving_lib.save_model(model, filepath)
@keras_export("keras.saving.load_model", "keras.models.load_model")
def load_model(
    filepath, custom_objects=None, compile=True, safe_mode=True, **kwargs
):
    """Loads a model saved via `model.save()`.

    Args:
        filepath: `str` or `pathlib.Path` object, path to the saved model file.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.
        compile: Boolean, whether to compile the model after loading.
        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
            When `safe_mode=False`, loading an object has the potential to
            trigger arbitrary code execution. This argument is only
            applicable to the Keras v3 model format. Defaults to True.

    SavedModel format arguments:
        options: Only applies to SavedModel format.
            Optional `tf.saved_model.LoadOptions` object that specifies
            SavedModel loading options.

    Returns:
        A Keras model instance. If the original model was compiled,
        and the argument `compile=True` is set, then the returned model
        will be compiled. Otherwise, the model will be left uncompiled.

    Raises:
        ValueError: If extra keyword arguments are passed while loading a
            native Keras (`.keras` zip) file.

    Example:

    ```python
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(5, input_shape=(3,)),
        tf.keras.layers.Softmax()])
    model.save("model.keras")
    loaded_model = tf.keras.saving.load_model("model.keras")
    x = tf.random.uniform((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that the model variables may have different name values
    (`var.name` property, e.g. `"dense_1/kernel:0"`) after being reloaded.
    It is recommended that you use layer attributes to
    access specific variables, e.g. `model.get_layer("dense_1").kernel`.
    """
    # Native Keras archives are `.keras` files that are valid zip archives.
    is_keras_zip = str(filepath).endswith(".keras") and zipfile.is_zipfile(
        filepath
    )

    # A remote, non-directory path that was not recognized as a Keras zip is
    # copied to a local temp location so it can be inspected as a zip file.
    needs_local_copy = (
        saving_lib.is_remote_path(filepath)
        and not tf.io.gfile.isdir(filepath)
        and not is_keras_zip
    )
    if needs_local_copy:
        staged_path = os.path.join(
            saving_lib.get_temp_dir(), os.path.basename(filepath)
        )
        tf.io.gfile.copy(filepath, staged_path, overwrite=True)
        # Only switch to the local copy when it is actually a zip archive;
        # otherwise the original (remote) path goes to the legacy loader.
        # NOTE(review): the staged copy is never removed here.
        if zipfile.is_zipfile(staged_path):
            filepath, is_keras_zip = staged_path, True

    if not is_keras_zip:
        # Legacy case.
        return legacy_sm_saving_lib.load_model(
            filepath, custom_objects=custom_objects, compile=compile, **kwargs
        )

    if kwargs:
        raise ValueError(
            "The following argument(s) are not supported "
            f"with the native Keras format: {list(kwargs.keys())}"
        )
    return saving_lib.load_model(
        filepath,
        custom_objects=custom_objects,
        compile=compile,
        safe_mode=safe_mode,
    )
def save_weights(model, filepath, overwrite=True, **kwargs):
    """Saves the weights of `model` to `filepath`.

    Paths ending in `.weights.h5` are written by
    `saving_lib.save_weights_only`. Any other path is delegated to the legacy
    saver together with `overwrite` and `kwargs`.

    Args:
        model: Model whose weights are saved.
        filepath: Destination path.
        overwrite: When False and the destination exists, the user is asked
            interactively whether to overwrite it (native path only; the
            legacy saver receives the flag directly).
        **kwargs: Forwarded to the legacy saver; ignored on the native
            `.weights.h5` path.
    """
    if not str(filepath).endswith(".weights.h5"):
        legacy_sm_saving_lib.save_weights(
            model, filepath, overwrite=overwrite, **kwargs
        )
        return

    try:
        target_exists = os.path.exists(filepath)
    except TypeError:
        # Non-path-like `filepath`: treat as nothing to overwrite.
        target_exists = False

    # Prompt only when an existing file would be clobbered without consent.
    if (
        target_exists
        and not overwrite
        and not io_utils.ask_to_proceed_with_overwrite(filepath)
    ):
        return
    saving_lib.save_weights_only(model, filepath)
def load_weights(model, filepath, skip_mismatch=False, **kwargs):
    """Loads weights into `model` from `filepath`.

    Zipped `.keras` archives and `.weights.h5` files are handled by
    `saving_lib.load_weights_only`, which ignores `kwargs` and returns None.
    Every other path is delegated to the legacy loader, whose result is
    returned.

    Args:
        model: Model to load the weights into.
        filepath: Source path.
        skip_mismatch: Forwarded to the underlying loader.
        **kwargs: Forwarded to the legacy loader only.
    """
    path_str = str(filepath)
    native_format = (
        path_str.endswith(".keras") and zipfile.is_zipfile(filepath)
    ) or path_str.endswith(".weights.h5")

    if native_format:
        saving_lib.load_weights_only(
            model, filepath, skip_mismatch=skip_mismatch
        )
        return None

    return legacy_sm_saving_lib.load_weights(
        model, filepath, skip_mismatch=skip_mismatch, **kwargs
    )
def get_save_format(filepath, save_format):
    """Resolves which serialization format `save_model` should use.

    Args:
        filepath: Destination passed to `save_model` (`str`, `pathlib.Path`
            or an `h5py.File` object).
        save_format: Optional user-supplied format. Accepted values are
            `"keras"`, `"keras_v3"`, `"h5"`, `"hdf5"`, `"tf"` and
            `"tensorflow"`. When falsy, the format is inferred from
            `filepath`.

    Returns:
        One of `"keras"`, `"h5"` or `"tf"`.

    Raises:
        ValueError: If `save_format` is given but not recognized.
    """
    if save_format:
        if save_format == "keras_v3":
            return "keras"
        if save_format == "keras":
            # "keras" maps to the native format only when Keras v3 saving
            # is enabled; otherwise it falls back to HDF5.
            if saving_lib.saving_v3_enabled():
                return "keras"
            return "h5"
        if save_format in ("h5", "hdf5"):
            return "h5"
        if save_format in ("tf", "tensorflow"):
            return "tf"

        raise ValueError(
            "Unknown `save_format` argument. Expected one of "
            "'keras', 'tf', or 'h5'. "
            f"Received: save_format={save_format}"
        )

    # No save format specified: infer from filepath.
    path_str = str(filepath)
    if path_str.endswith(".keras"):
        if saving_lib.saving_v3_enabled():
            return "keras"
        return "h5"

    if path_str.endswith((".h5", ".hdf5")):
        return "h5"

    if h5py is not None and isinstance(filepath, h5py.File):
        return "h5"

    # No recognizable file format: default to TF in TF2 and h5 in TF1.
    if tf.__internal__.tf2.enabled():
        return "tf"
    return "h5"