Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/saving_lib.py: 19%
326 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python-based idempotent model-saving functionality."""
17import datetime
18import io
19import json
20import os
21import re
22import tempfile
23import threading
24import warnings
25import zipfile
27import numpy as np
28import tensorflow.compat.v2 as tf
30import keras.src as keras
31from keras.src import losses
32from keras.src.engine import base_layer
33from keras.src.optimizers import optimizer
34from keras.src.saving.serialization_lib import ObjectSharingScope
35from keras.src.saving.serialization_lib import deserialize_keras_object
36from keras.src.saving.serialization_lib import serialize_keras_object
37from keras.src.utils import generic_utils
38from keras.src.utils import io_utils
40try:
41 import h5py
42except ImportError:
43 h5py = None
# isort: off

# Names of the members written inside the `.keras` zip archive.
_CONFIG_FILENAME = "config.json"
_METADATA_FILENAME = "metadata.json"
_VARS_FNAME = "model.weights"  # Will become e.g. "model.weights.h5"
_ASSETS_DIRNAME = "assets"

# A temporary flag to enable the new idempotent saving framework.
# Thread-local so concurrent save/load calls don't clobber each other's flag.
_SAVING_V3_ENABLED = threading.local()
_SAVING_V3_ENABLED.value = True
# Attributes that must not be visited when walking a trackable's `dir()`:
# internal bookkeeping, auto-generated API aliases, and aggregate properties
# whose contents are already reached through other attributes.
# (Fix: the original listed "_non_trainable_weights" twice.)
ATTR_SKIPLIST = frozenset(
    {
        "_callable_losses",
        "_captured_weight_regularizer",
        "_checkpoint_dependencies",
        "_deferred_dependencies",
        "_eager_losses",
        "_inbound_nodes",
        "_inbound_nodes_value",
        "_output_layers",
        "_input_layers",
        "_keras_api_names",
        "_keras_api_names_v1",
        "_name_based_restores",
        "_non_trainable_weights",
        "_outbound_nodes",
        "_outbound_nodes_value",
        "_saved_model_arg_spec",
        "_self_name_based_restores",
        "_self_saveable_object_factories",
        "_self_tracked_trackables",
        "_saved_model_inputs_spec",
        "_self_unconditional_checkpoint_dependencies",
        "_self_unconditional_deferred_dependencies",
        "_self_unconditional_dependency_names",
        "_tf_api_names",
        "_tf_api_names_v1",
        "_trainable_weights",
        "_unconditional_checkpoint_dependencies",
        "_unconditional_dependency_names",
        "_updates",
        "_layer_call_argspecs",
        "inbound_nodes",
        "outbound_nodes",
        "input_shape",
        "output_shape",
        "submodules",
        "weights",
        "non_trainable_weights",
        "trainable_weights",
        "variables",
        "non_trainable_variables",
        "trainable_variables",
        "updates",  # Would raise a warning if visited.
        "state_updates",  # Would raise a warning if visited.
    }
)
def save_model(model, filepath, weights_format="h5"):
    """Save a zip-archive representing a Keras model to the given filepath.

    The zip-based archive contains the following structure:

    - JSON-based configuration file (config.json): Records of model, layer,
      and other trackables' configuration.
    - NPZ-based trackable state files, found in respective directories, such
      as model/states.npz, model/dense_layer/states.npz, etc.
    - Metadata file.

    The states of Keras trackables (layers, optimizers, loss, and metrics) are
    automatically saved as long as they can be discovered through the
    attributes returned by `dir(Model)`. Typically, the state includes the
    variables associated with the trackable, but some specially purposed
    layers may contain more such as the vocabularies stored in the hashmaps.
    The trackables define how their states are saved by exposing
    `save_state()` and `load_state()` APIs.

    For the case of layer states, the variables will be visited as long as
    they are either 1) referenced via layer attributes, or 2) referenced via a
    container (list, tuple, or dict), and the container is referenced via a
    layer attribute.

    Args:
        model: The Keras model to save.
        filepath: Target path; must end in `.keras`.
        weights_format: `"h5"` (default) or `"npz"`.

    Raises:
        ValueError: If `filepath` lacks the `.keras` extension or
            `weights_format` is unknown.
        ImportError: If `weights_format="h5"` but h5py is not installed.
    """
    filepath = str(filepath)
    if not filepath.endswith(".keras"):
        raise ValueError(
            "Invalid `filepath` argument: expected a `.keras` extension. "
            f"Received: filepath={filepath}"
        )
    if weights_format == "h5" and h5py is None:
        raise ImportError("h5py must be installed in order to save a model.")

    if not model.built:
        warnings.warn(
            "You are saving a model that has not yet been built. "
            "It might not contain any weights yet. "
            "Consider building the model first by calling it "
            "on some data.",
            stacklevel=2,
        )
    # Remember the caller's flag so it can be restored afterwards.
    saving_v3_enabled_value = getattr(_SAVING_V3_ENABLED, "value", False)
    _SAVING_V3_ENABLED.value = True

    with ObjectSharingScope():
        serialized_model_dict = serialize_keras_object(model)
    config_json = json.dumps(serialized_model_dict)
    metadata_json = json.dumps(
        {
            "keras_version": keras.__version__,
            "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
        }
    )
    # TODO(rameshsampath): Need a better logic for local vs remote path
    if is_remote_path(filepath):
        # Remote path. Zip to local drive and copy to remote
        zip_filepath = os.path.join(get_temp_dir(), "tmp_model.keras")
    else:
        zip_filepath = filepath
    try:
        with zipfile.ZipFile(zip_filepath, "w") as zf:
            with zf.open(_METADATA_FILENAME, "w") as f:
                f.write(metadata_json.encode())
            with zf.open(_CONFIG_FILENAME, "w") as f:
                f.write(config_json.encode())

            if weights_format == "h5":
                weights_store = H5IOStore(
                    _VARS_FNAME + ".h5", archive=zf, mode="w"
                )
            elif weights_format == "npz":
                weights_store = NpzIOStore(
                    _VARS_FNAME + ".npz", archive=zf, mode="w"
                )
            else:
                raise ValueError(
                    "Unknown `weights_format` argument. "
                    "Expected 'h5' or 'npz'. "
                    f"Received: weights_format={weights_format}"
                )

            asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="w")

            _save_state(
                model,
                weights_store=weights_store,
                assets_store=asset_store,
                inner_path="",
                visited_trackables=set(),
            )
            weights_store.close()
            asset_store.close()

        if is_remote_path(filepath):
            # Using tf.io.gfile context manager doesn't close zip file when
            # writing to GCS. Hence writing to local and copying to filepath.
            tf.io.gfile.copy(zip_filepath, filepath, overwrite=True)
            os.remove(zip_filepath)
    finally:
        # Fix: the original had a no-op `except Exception as e: raise e`,
        # which only obscured tracebacks; `finally` alone restores the flag.
        _SAVING_V3_ENABLED.value = saving_v3_enabled_value
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
    """Load a zip archive representing a Keras model.

    Args:
        filepath: Path to a `.keras` archive.
        custom_objects: Optional dict mapping names to custom classes/functions.
        compile: If False, the model's compile config is discarded.
        safe_mode: Forwarded to `deserialize_keras_object` (guards lambda
            deserialization).

    Returns:
        The deserialized model with its state restored.

    Raises:
        ValueError: If `filepath` lacks the `.keras` extension or the archive
            contains no recognized weights file.
    """
    filepath = str(filepath)
    if not filepath.endswith(".keras"):
        raise ValueError(
            "Invalid filename: expected a `.keras` extension. "
            f"Received: filepath={filepath}"
        )

    saving_v3_enabled_value = getattr(_SAVING_V3_ENABLED, "value", False)
    _SAVING_V3_ENABLED.value = True

    try:
        with tf.io.gfile.GFile(
            filepath, mode="r+b"
        ) as gfile_handle, zipfile.ZipFile(gfile_handle, "r") as zf:
            with zf.open(_CONFIG_FILENAME, "r") as f:
                config_json = f.read()

            # Note: we should NOT use a custom JSON decoder. Anything that
            # needs custom decoding must be handled in
            # deserialize_keras_object.
            config_dict = json.loads(config_json)
            if not compile:
                # Disable compilation
                config_dict["compile_config"] = None
            # Construct the model from the configuration file in the archive.
            with ObjectSharingScope():
                model = deserialize_keras_object(
                    config_dict, custom_objects, safe_mode=safe_mode
                )

            all_filenames = zf.namelist()
            if _VARS_FNAME + ".h5" in all_filenames:
                weights_store = H5IOStore(
                    _VARS_FNAME + ".h5", archive=zf, mode="r"
                )
            elif _VARS_FNAME + ".npz" in all_filenames:
                weights_store = NpzIOStore(
                    _VARS_FNAME + ".npz", archive=zf, mode="r"
                )
            else:
                raise ValueError(
                    f"Expected a {_VARS_FNAME}.h5 or {_VARS_FNAME}.npz file."
                )

            # Config + metadata + weights are always present; any extra
            # member means there are assets to restore.
            if len(all_filenames) > 3:
                asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="r")
            else:
                asset_store = None

            _load_state(
                model,
                weights_store=weights_store,
                assets_store=asset_store,
                inner_path="",
                visited_trackables=set(),
            )
            weights_store.close()
            if asset_store:
                asset_store.close()

        # Fix: replaces the original's no-op `except Exception as e: raise e`
        # plus `else: return model`; the `finally` still restores the flag.
        return model
    finally:
        _SAVING_V3_ENABLED.value = saving_v3_enabled_value
def save_weights_only(model, filepath):
    """Save only the weights of a model to a target filepath (.weights.h5).

    Note: only supports h5 for now.
    """
    # TODO: if h5 filepath is remote, create the file in a temporary directory
    # then upload it
    filepath = str(filepath)
    if not filepath.endswith(".weights.h5"):
        raise ValueError(
            "Invalid `filepath` argument: expected a `.weights.h5` extension. "
            f"Received: filepath={filepath}"
        )
    # Stream all variable state straight into the h5 file, then flush it.
    store = H5IOStore(filepath, mode="w")
    _save_state(
        model,
        weights_store=store,
        assets_store=None,
        inner_path="",
        visited_trackables=set(),
    )
    store.close()
def load_weights_only(model, filepath, skip_mismatch=False):
    """Load the weights of a model from a filepath (.keras or .weights.h5).

    Note: only supports h5 for now.

    Args:
        model: Model whose variables will be populated in place.
        filepath: Path ending in `.weights.h5` or `.keras`.
        skip_mismatch: If True, weight-loading failures are downgraded to
            warnings instead of raised.

    Raises:
        ValueError: If `filepath` has neither supported extension.
    """
    temp_dir = None
    archive = None
    filepath = str(filepath)
    if filepath.endswith(".weights.h5"):
        # TODO: download file if h5 filepath is remote
        weights_store = H5IOStore(filepath, mode="r")
    elif filepath.endswith(".keras"):
        archive = zipfile.ZipFile(filepath, "r")
        weights_store = H5IOStore(
            _VARS_FNAME + ".h5", archive=archive, mode="r"
        )
    else:
        # Fix: the original fell through with `weights_store` unbound,
        # producing a confusing UnboundLocalError instead of a clear message.
        raise ValueError(
            "Invalid `filepath` argument: expected a `.weights.h5` or "
            f"`.keras` extension. Received: filepath={filepath}"
        )

    _load_state(
        model,
        weights_store=weights_store,
        assets_store=None,
        inner_path="",
        skip_mismatch=skip_mismatch,
        visited_trackables=set(),
    )
    weights_store.close()
    if temp_dir and tf.io.gfile.exists(temp_dir):
        tf.io.gfile.rmtree(temp_dir)
    if archive:
        archive.close()
def is_remote_path(filepath):
    """Return True if `filepath` targets CNS/CFS/GCS or any URI-schemed store."""
    return re.match(r"^(/cns|/cfs|/gcs|.*://).*$", str(filepath)) is not None
def _write_to_zip_recursively(zipfile_to_save, system_path, zip_path):
    """Mirror `system_path` (file or directory tree) into the open zip."""
    if tf.io.gfile.isdir(system_path):
        # Directory: recurse into each entry, preserving relative layout.
        for entry in tf.io.gfile.listdir(system_path):
            _write_to_zip_recursively(
                zipfile_to_save,
                tf.io.gfile.join(system_path, entry),
                tf.io.gfile.join(zip_path, entry),
            )
    else:
        # Plain file: store it at its archive-relative path.
        zipfile_to_save.write(system_path, zip_path)
def _walk_trackable(trackable):
    """Yield `(attr_name, value)` pairs worth visiting on `trackable`.

    Dunder attributes and everything in ATTR_SKIPLIST are skipped; attribute
    access that raises is silently ignored (some Keras properties raise when
    touched outside a build context).
    """
    for attr_name in dir(trackable):
        if attr_name in ATTR_SKIPLIST or attr_name.startswith("__"):
            continue
        try:
            value = getattr(trackable, attr_name)
        except Exception:
            # Avoid raising the exception when visiting the attributes.
            continue
        yield attr_name, value
def _save_state(
    trackable, weights_store, assets_store, inner_path, visited_trackables
):
    """Recursively save the state of `trackable` and its children.

    `visited_trackables` holds ids of already-saved objects so that shared
    layers are only persisted once.
    """
    # Skip anything we have already persisted on this walk.
    if id(trackable) in visited_trackables:
        return

    if weights_store and hasattr(trackable, "save_own_variables"):
        trackable.save_own_variables(weights_store.make(inner_path))
    if assets_store and hasattr(trackable, "save_assets"):
        trackable.save_assets(assets_store.make(inner_path))

    visited_trackables.add(id(trackable))

    # Walk child attributes: trackables recurse directly, containers are
    # unpacked element by element.
    for attr_name, child in _walk_trackable(trackable):
        child_path = tf.io.gfile.join(inner_path, attr_name)
        if _is_keras_trackable(child):
            _save_state(
                child,
                weights_store,
                assets_store,
                inner_path=child_path,
                visited_trackables=visited_trackables,
            )
        elif isinstance(child, (list, dict, tuple, set)):
            _save_container_state(
                child,
                weights_store,
                assets_store,
                inner_path=child_path,
                visited_trackables=visited_trackables,
            )
def _try_restore(restore_fn, store, inner_path, trackable, skip_mismatch, what):
    """Call `restore_fn(store.get(inner_path))`, optionally tolerating errors.

    With `skip_mismatch`, any exception (from the store lookup or the restore
    itself) is downgraded to a warning so the rest of the model still loads.
    `what` names the payload ("weights" or "assets") for the warning text.
    """
    if skip_mismatch:
        try:
            restore_fn(store.get(inner_path))
        except Exception as e:
            warnings.warn(
                f"Could not load {what} in object {trackable}. "
                "Skipping object. "
                f"Exception encountered: {e}",
                stacklevel=2,
            )
    else:
        restore_fn(store.get(inner_path))


def _load_state(
    trackable,
    weights_store,
    assets_store,
    inner_path,
    skip_mismatch=False,
    visited_trackables=None,
):
    """Recursively restore the state of `trackable` and its children.

    Args:
        trackable: Object whose variables/assets are restored in place.
        weights_store: Store providing variable groups per inner path.
        assets_store: Optional store providing asset dirs per inner path.
        inner_path: Slash-joined attribute path of `trackable` in the model.
        skip_mismatch: If True, per-object load failures become warnings.
        visited_trackables: Optional id-set preventing duplicate visits.
    """
    if visited_trackables and id(trackable) in visited_trackables:
        return

    # The two restore paths shared identical skip_mismatch handling; it now
    # lives in _try_restore.
    if hasattr(trackable, "load_own_variables") and weights_store:
        _try_restore(
            trackable.load_own_variables,
            weights_store,
            inner_path,
            trackable,
            skip_mismatch,
            "weights",
        )
    if hasattr(trackable, "load_assets") and assets_store:
        _try_restore(
            trackable.load_assets,
            assets_store,
            inner_path,
            trackable,
            skip_mismatch,
            "assets",
        )

    if visited_trackables is not None:
        visited_trackables.add(id(trackable))

    # Recursively load states for Keras trackables such as layers/optimizers.
    for child_attr, child_obj in _walk_trackable(trackable):
        if _is_keras_trackable(child_obj):
            _load_state(
                child_obj,
                weights_store,
                assets_store,
                inner_path=tf.io.gfile.join(inner_path, child_attr),
                skip_mismatch=skip_mismatch,
                visited_trackables=visited_trackables,
            )
        elif isinstance(child_obj, (list, dict, tuple, set)):
            _load_container_state(
                child_obj,
                weights_store,
                assets_store,
                inner_path=tf.io.gfile.join(inner_path, child_attr),
                skip_mismatch=skip_mismatch,
                visited_trackables=visited_trackables,
            )
def _save_container_state(
    container, weights_store, assets_store, inner_path, visited_trackables
):
    """Save the state of every Keras trackable found inside `container`."""
    name_counts = {}
    elements = container.values() if isinstance(container, dict) else container
    for element in elements:
        if not _is_keras_trackable(element):
            continue
        # Do NOT address the trackable via `trackable.name`, since
        # names are usually autogenerated and thus not reproducible
        # (i.e. they may vary across two instances of the same model).
        base = generic_utils.to_snake_case(element.__class__.__name__)
        seen = name_counts.get(base)
        if seen is None:
            name_counts[base] = 0
            name = base
        else:
            name_counts[base] = seen + 1
            name = f"{base}_{seen + 1}"
        _save_state(
            element,
            weights_store,
            assets_store,
            inner_path=tf.io.gfile.join(inner_path, name),
            visited_trackables=visited_trackables,
        )
def _load_container_state(
    container,
    weights_store,
    assets_store,
    inner_path,
    skip_mismatch,
    visited_trackables,
):
    """Restore the state of every Keras trackable found inside `container`."""
    name_counts = {}
    elements = container.values() if isinstance(container, dict) else container
    for element in elements:
        if not _is_keras_trackable(element):
            continue
        # Class-derived names mirror the scheme used at save time, with a
        # numeric suffix disambiguating repeated classes.
        base = generic_utils.to_snake_case(element.__class__.__name__)
        seen = name_counts.get(base)
        if seen is None:
            name_counts[base] = 0
            name = base
        else:
            name_counts[base] = seen + 1
            name = f"{base}_{seen + 1}"
        _load_state(
            element,
            weights_store,
            assets_store,
            inner_path=tf.io.gfile.join(inner_path, name),
            skip_mismatch=skip_mismatch,
            visited_trackables=visited_trackables,
        )
class DiskIOStore:
    """Asset store backed by disk storage.

    If `archive` is specified, then `root_path` refers to the filename
    inside the archive.

    If `archive` is not specified, then `root_path` refers to the full path of
    the target directory.
    """

    def __init__(self, root_path, archive=None, mode=None):
        self.mode = mode
        self.root_path = root_path
        self.archive = archive
        self.tmp_dir = None
        if archive is not None:
            # Archive-backed: stage contents through a scratch directory.
            self.tmp_dir = get_temp_dir()
            if mode == "r":
                archive.extractall(path=self.tmp_dir)
            self.working_dir = tf.io.gfile.join(self.tmp_dir, root_path)
            if mode == "w":
                tf.io.gfile.makedirs(self.working_dir)
        elif mode == "r":
            # Plain directory, read-only: use it in place.
            self.working_dir = root_path
        else:
            # Plain directory, write mode: build it in a scratch location.
            self.tmp_dir = get_temp_dir()
            self.working_dir = tf.io.gfile.join(self.tmp_dir, root_path)
            tf.io.gfile.makedirs(self.working_dir)

    def make(self, path):
        """Return (creating if needed) the directory for `path`."""
        if not path:
            return self.working_dir
        full_path = tf.io.gfile.join(self.working_dir, path)
        if not tf.io.gfile.exists(full_path):
            tf.io.gfile.makedirs(full_path)
        return full_path

    def get(self, path):
        """Return the directory for `path`, or None if it does not exist."""
        if not path:
            return self.working_dir
        full_path = tf.io.gfile.join(self.working_dir, path)
        return full_path if tf.io.gfile.exists(full_path) else None

    def close(self):
        """Flush staged files into the archive (write mode) and clean up."""
        if self.archive and self.mode == "w":
            _write_to_zip_recursively(
                self.archive, self.working_dir, self.root_path
            )
        if self.tmp_dir and tf.io.gfile.exists(self.tmp_dir):
            tf.io.gfile.rmtree(self.tmp_dir)
class H5IOStore:
    """Numerical variable store backed by HDF5.

    If `archive` is specified, then `root_path` refers to the filename
    inside the archive.

    If `archive` is not specified, then `root_path` refers to the path of
    the h5 file on disk.
    """

    def __init__(self, root_path, archive=None, mode="r"):
        self.root_path = root_path
        self.mode = mode
        self.archive = archive
        self.io_file = None

        if archive is None:
            # Direct on-disk h5 file.
            self.h5_file = h5py.File(root_path, mode=mode)
        else:
            # Archive-backed: buffer writes in memory, stream reads from
            # the zip member.
            if mode == "w":
                self.io_file = io.BytesIO()
            else:
                self.io_file = archive.open(root_path, "r")
            self.h5_file = h5py.File(self.io_file, mode=mode)

    def make(self, path):
        """Create and return the "vars" group for `path`."""
        parent = self.h5_file if not path else self.h5_file.create_group(path)
        return parent.create_group("vars")

    def get(self, path):
        """Return the "vars" group for `path`, or {} if absent."""
        if not path:
            return self.h5_file["vars"]
        if path in self.h5_file:
            entry = self.h5_file[path]
            if "vars" in entry:
                return entry["vars"]
        return {}

    def close(self):
        """Close the h5 file, committing buffered bytes to the archive."""
        self.h5_file.close()
        if self.archive and self.mode == "w":
            self.archive.writestr(self.root_path, self.io_file.getvalue())
        if self.io_file is not None:
            self.io_file.close()
class NpzIOStore:
    def __init__(self, root_path, archive=None, mode="r"):
        """Numerical variable store backed by NumPy.savez/load.

        If `archive` is specified, then `root_path` refers to the filename
        inside the archive.

        If `archive` is not specified, then `root_path` refers to the path of
        the npz file on disk.
        """
        self.root_path = root_path
        self.mode = mode
        self.archive = archive
        if mode == "w":
            # Variables accumulate in memory; written out on `close`.
            self.contents = {}
        else:
            if self.archive:
                self.f = archive.open(root_path, mode="r")
            else:
                self.f = open(root_path, mode="rb")
            self.contents = np.load(self.f, allow_pickle=True)

    def make(self, path):
        """Create and return a fresh dict to hold variables for `path`."""
        if not path:
            self.contents["__root__"] = {}
            return self.contents["__root__"]
        self.contents[path] = {}
        return self.contents[path]

    def get(self, path):
        """Return the stored variable dict for `path` ({} when absent)."""
        if not path:
            if "__root__" in self.contents:
                return dict(self.contents["__root__"])
            return {}
        if path in self.contents:
            # Dicts are stored as 0-d object arrays; unwrap them.
            return self.contents[path].tolist()
        return {}

    def close(self):
        """Flush pending writes (write mode) and release the file handle.

        Fix: the original only closed `self.f` in write mode, leaking the
        file descriptor opened for reading on every load.
        """
        if self.mode == "w":
            if self.archive:
                self.f = self.archive.open(
                    self.root_path, mode="w", force_zip64=True
                )
            else:
                self.f = open(self.root_path, mode="wb")
            np.savez(self.f, **self.contents)
        if hasattr(self, "f") and not self.f.closed:
            self.f.close()
def get_temp_dir():
    """Create a fresh temporary directory, verifying it is writable."""
    temp_dir = tempfile.mkdtemp()
    # Smoke-test that files can actually be created in the new directory.
    with tempfile.TemporaryFile(dir=temp_dir):
        pass
    return temp_dir
def _is_keras_trackable(obj):
    """Whether `obj` is a Keras object whose state is saved/loaded here."""
    from keras.src.metrics import base_metric  # To avoid circular import

    trackable_types = (
        base_layer.Layer,
        optimizer.Optimizer,
        base_metric.Metric,
        losses.Loss,
    )
    return isinstance(obj, trackable_types)
def saving_v3_enabled():
    """Return this thread's saving-v3 flag (True when unset)."""
    flag = getattr(_SAVING_V3_ENABLED, "value", True)
    return flag
699# Some debugging utilities.
702def _print_h5_file(h5_file, prefix="", action=None):
703 if not prefix:
704 print(f"Keras weights file ({h5_file}) {action}:")
705 if not hasattr(h5_file, "keys"):
706 return
707 for key in h5_file.keys():
708 print(f"...{prefix}{key}")
709 _print_h5_file(h5_file[key], prefix=prefix + "...")
def _print_zip_file(zipfile, action):
    """Print a listing of a Keras model zip archive (debugging aid)."""
    io_utils.print_msg(f"Keras model archive {action}:")
    # Same as `ZipFile.printdir()` except for using Keras' printing utility.
    header = "%-46s %19s %12s" % ("File Name", "Modified ", "Size")
    io_utils.print_msg(header)
    for info in zipfile.filelist:
        timestamp = "%d-%02d-%02d %02d:%02d:%02d" % info.date_time[:6]
        row = "%-46s %s %12d" % (info.filename, timestamp, info.file_size)
        io_utils.print_msg(row)