Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/saving/trace_saveable_util.py: 20%
56 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for tracing save and restore functions for SaveableObjects."""
17from tensorflow.python.eager import def_function
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import tensor_spec
20from tensorflow.python.framework import type_spec
22from tensorflow.python.ops import resource_variable_ops
23from tensorflow.python.training.saving import saveable_object
24from tensorflow.python.training.saving import saveable_object_util
25from tensorflow.python.util import nest
def trace_save_restore_function_map(obj, factory_data_list):
  """Traces all save and restore functions in the provided factory list.

  Args:
    obj: `Trackable` object.
    factory_data_list: List of `_CheckpointFactoryData`.

  Returns:
    Dict mapping attribute names to tuples of concrete save/restore functions.
    Factories that are skipped, or that yield no saveables, have no entry.
  """
  saveable_fns = {}

  # If the object revives as a resource (or TPU/Mirrored) variable, there is
  # no need to trace the save and restore functions for any factory. This
  # check does not depend on the factory, so it is done once, up front.
  if resource_variable_ops.is_resource_variable(obj):
    return saveable_fns

  for factory_data in factory_data_list:
    saveable_factory = factory_data.factory
    attribute_name = factory_data.name

    # Likewise skip factories that are themselves resource variables, and
    # anything that cannot be called to produce saveables.
    if (resource_variable_ops.is_resource_variable(saveable_factory) or
        not callable(saveable_factory)):
      continue

    concrete_save, concrete_restore = (
        _trace_save_restore_functions(saveable_factory, obj))
    # A falsy save function (e.g. the `(None, None)` result returned when
    # validation leaves no saveables) means there is nothing to record.
    if not concrete_save:
      continue
    saveable_fns[attribute_name] = (concrete_save, concrete_restore)
  return saveable_fns
def _trace_save_restore_functions(saveable_factory, obj):
  """Traces save and restore functions for a single saveable factory.

  Args:
    saveable_factory: Callable invoked as `saveable_factory(name=...)` that
      returns a `SaveableObject` or an iterable of `SaveableObject`s.
    obj: `Trackable` object owning the factory. It is passed to
      `saveable_object_util.validate_saveables_for_saved_model`.

  Returns:
    A `(save_function, restore_function)` tuple. If `saveable_factory` is a
    factory for a restored saveable object, these are the functions already
    stored in its `keywords`. Otherwise they are freshly traced concrete
    functions. Returns `(None, None)` if no saveables remain after validation.
  """
  if saveable_object_util.is_factory_for_restored_saveable_object(
      saveable_factory):
    # The factory already carries its save/restore functions; reuse them
    # rather than retracing.
    return (
        saveable_factory.keywords["save_function"],
        saveable_factory.keywords["restore_function"],
    )

  # Filled in as a side effect of tracing `save_fn`, so the SaveableObjects
  # created inside the trace can be inspected afterwards. Kept separate from
  # the validated list that `restore_fn` reads, so a retrace of `save_fn`
  # cannot mutate the restore function's captured state.
  traced_saveables = []

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)]
  )
  def save_fn(checkpoint_key):
    maybe_saveable = saveable_factory(name=checkpoint_key)
    if isinstance(maybe_saveable, saveable_object.SaveableObject):
      maybe_saveable = [maybe_saveable]
    traced_saveables[:] = maybe_saveable

    # Return list of all SaveSpecs created by the factory.
    ret = []
    for saveable in traced_saveables:
      for spec in saveable.specs:
        ret.append({"name": spec.name, "tensor": spec.tensor,
                    "slice_spec": spec.slice_spec})
    return ret

  # Tracing here runs `save_fn` once, which populates `traced_saveables`.
  concrete_save = save_fn.get_concrete_function()

  saveables = saveable_object_util.validate_saveables_for_saved_model(
      traced_saveables, obj)
  if not saveables:
    return None, None

  # Use the SaveSpecs to define the input signature of the restore function.
  # `tensor_structure` holds one list of spec names per saveable, so the flat
  # positional tensors passed to `restore_fn` can be regrouped per saveable.
  restored_type_specs = []
  tensor_structure = []
  for saveable in saveables:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      restored_type_specs.append(type_spec.type_spec_from_value(spec.tensor))
      saveable_tensor_structure.append(spec.name)

  @def_function.function(input_signature=restored_type_specs)
  def restore_fn(*restored_tensors):
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    # Distinct loop name so the `restored_tensors` parameter is not shadowed.
    for saveable, saveable_restored_tensors in zip(
        saveables, structured_restored_tensors):
      saveable.restore(saveable_restored_tensors, restored_shapes=None)
    return 1  # Return dummy tensor

  concrete_restore = restore_fn.get_concrete_function()
  return concrete_save, concrete_restore