Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/saving/trace_saveable_util.py: 20%

56 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for tracing save and restore functions for SaveableObjects.""" 

16 

17from tensorflow.python.eager import def_function 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import tensor_spec 

20from tensorflow.python.framework import type_spec 

21 

22from tensorflow.python.ops import resource_variable_ops 

23from tensorflow.python.training.saving import saveable_object 

24from tensorflow.python.training.saving import saveable_object_util 

25from tensorflow.python.util import nest 

26 

27 

def trace_save_restore_function_map(obj, factory_data_list):
  """Traces all save and restore functions in the provided factory list.

  Args:
    obj: `Trackable` object.
    factory_data_list: List of `_CheckpointFactoryData`.

  Returns:
    Dict mapping attribute names to tuples of concrete save/restore functions.
    Factories that are resource variables, are not callable, or whose traced
    save function is falsy (e.g. `None` when no saveables survive validation)
    are omitted.
  """
  # If the object revives as a resource (or TPU/Mirrored) variable, there is
  # no need to trace any save and restore functions. This does not depend on
  # the individual factories, so check it once instead of once per factory.
  if resource_variable_ops.is_resource_variable(obj):
    return {}

  saveable_fns = {}
  for factory_data in factory_data_list:
    saveable_factory = factory_data.factory

    # Factories that are themselves resource variables, or that cannot be
    # called to produce SaveableObjects, need no traced functions.
    if (resource_variable_ops.is_resource_variable(saveable_factory) or
        not callable(saveable_factory)):
      continue

    concrete_save, concrete_restore = (
        _trace_save_restore_functions(saveable_factory, obj))
    if not concrete_save:
      continue
    saveable_fns[factory_data.name] = (concrete_save, concrete_restore)
  return saveable_fns

57 

58 

def _trace_save_restore_functions(saveable_factory, obj):
  """Traces save and restore functions for a single saveable factory.

  Args:
    saveable_factory: Callable invoked as `saveable_factory(name=key)` that
      returns either a single `SaveableObject` or an iterable of them. If
      `saveable_object_util.is_factory_for_restored_saveable_object` accepts
      it, its `keywords["save_function"]` and `keywords["restore_function"]`
      are returned directly and nothing is traced.
    obj: The `Trackable` owning the factory; passed to
      `saveable_object_util.validate_saveables_for_saved_model`.

  Returns:
    A `(concrete_save, concrete_restore)` tuple. The traced save function
    takes a scalar string checkpoint key and returns a list of dicts with
    "name", "tensor" and "slice_spec" entries, one per SaveSpec. The traced
    restore function takes one tensor per SaveSpec, in the same order.
    Returns `(None, None)` when validation leaves no saveables.
  """
  if saveable_object_util.is_factory_for_restored_saveable_object(
      saveable_factory):
    return (
        saveable_factory.keywords["save_function"],
        saveable_factory.keywords["restore_function"],
    )

  # Store the saveables in a data structure accessible to both the save and
  # restore functions. It is filled in as a side effect of tracing `save_fn`.
  saveables = []

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def save_fn(checkpoint_key):
    maybe_saveable = saveable_factory(name=checkpoint_key)
    if isinstance(maybe_saveable, saveable_object.SaveableObject):
      maybe_saveable = [maybe_saveable]
    # Slice assignment mutates the enclosing list in place (rather than
    # binding a new local), so the traced objects are visible after tracing.
    saveables[:] = maybe_saveable

    # Return list of all SaveSpecs created by the factory.
    ret = []
    for saveable in saveables:
      for spec in saveable.specs:
        ret.append({"name": spec.name, "tensor": spec.tensor,
                    "slice_spec": spec.slice_spec})
    return ret

  concrete_save = save_fn.get_concrete_function()

  # The SaveableObjects are produced when `save_fn` is traced.
  saveables = saveable_object_util.validate_saveables_for_saved_model(
      saveables, obj)
  if not saveables:
    return None, None

  # Use the SaveSpecs to define the input signature of the restore function.
  # `tensor_structure` holds one list of spec names per saveable so that the
  # flat restore arguments can be regrouped per saveable.
  restored_type_specs = []
  tensor_structure = []
  for saveable in saveables:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      restored_type_specs.append(type_spec.type_spec_from_value(spec.tensor))
      saveable_tensor_structure.append(spec.name)

  @def_function.function(input_signature=restored_type_specs)
  def restore_fn(*restored_tensors):
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    # Use a distinct loop name so the `restored_tensors` argument is not
    # shadowed by the per-saveable group.
    for saveable, saveable_tensors in zip(saveables,
                                          structured_restored_tensors):
      saveable.restore(saveable_tensors, restored_shapes=None)
    return 1  # Return dummy tensor

  concrete_restore = restore_fn.get_concrete_function()
  return concrete_save, concrete_restore