Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/save_options.py: 40%

45 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Options for saving SavedModels.""" 

16 

17import enum 

18 

19from tensorflow.python.util import compat 

20from tensorflow.python.util.tf_export import tf_export 

21 

22 

@tf_export("saved_model.experimental.VariablePolicy")
class VariablePolicy(enum.Enum):
  """Enum selecting how variables are handled when a SavedModel is written.

  NONE
    No policy is applied: a distributed variable is saved as a single
    variable, and no device is attached to it.

  SAVE_VARIABLE_DEVICES
    Variables are saved together with their device assignment.
    This is useful for hardcoding devices into a saved model, but it also
    makes the model non-portable when soft device placement is disabled (see
    `tf.config.set_soft_device_placement` for details). `saved_model.load`
    does not fully support this yet; it is mainly meant for readers that
    consume the saved model at a lower API level. In the example below, the
    graph written by `saved_model.save` has the variable devices correctly
    specified:
    ```python
    exported = tf.train.Checkpoint()
    with tf.device('/GPU:0'):
      exported.x_gpu = tf.Variable(1.0)
    with tf.device('/CPU:0'):
      exported.x_cpu = tf.Variable(1.0)
    tf.saved_model.save(exported, export_dir,
        options = tf.saved_model.SaveOptions(
            experimental_variable_policy=
              tf.saved_model.experimental.VariablePolicy.SAVE_VARIABLE_DEVICES))
    ```
    Under this policy a distributed variable is still saved as one variable.

  EXPAND_DISTRIBUTED_VARIABLES
    Distributed variables are saved with information about their components,
    so they can be restored on load, and the saved graph references those
    component variables. This helps when the model is used for training in an
    environment where the original distribution strategy is unavailable.
    Device assignments are also saved under this policy.
  """

  # Member order is kept as-is because it is the enum's iteration order.
  NONE = None

  SAVE_VARIABLE_DEVICES = "save_variable_devices"

  EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables"

  def _save_variable_devices(self):
    """Returns whether device assignments should be written for variables.

    This is true for every policy except `NONE`, including
    `EXPAND_DISTRIBUTED_VARIABLES`.
    """
    # Enum members are singletons, so an identity test is equivalent to `!=`.
    return self is not VariablePolicy.NONE

  def _expand_distributed_variables(self):
    """Returns whether distributed variables should be saved per component."""
    return self is VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES

  @staticmethod
  def from_obj(obj):
    """Coerces `obj` into a `VariablePolicy` member.

    Args:
      obj: `None`, a `VariablePolicy` member, or any object whose `str()`
        form matches one of the policy value strings, ignoring case.

    Returns:
      The matching `VariablePolicy` member. `None` maps to
      `VariablePolicy.NONE`, and a member is returned unchanged.

    Raises:
      ValueError: If `obj` matches none of the policies.
    """
    if obj is None:
      return VariablePolicy.NONE
    if isinstance(obj, VariablePolicy):
      return obj

    wanted = str(obj).lower()
    # NONE's value is None, so a string can never select it here.
    found = next(
        (member for member in VariablePolicy if member.value == wanted),
        None)
    if found is None:
      raise ValueError(f"Received invalid VariablePolicy value: {obj}.")
    return found

87 

88 

@tf_export("saved_model.SaveOptions")
class SaveOptions:
  """Options for saving to SavedModel.

  An instance of this class can be passed as the `options` argument of
  functions that write a SavedModel (`tf.saved_model.save`,
  `tf.keras.models.save_model`).
  """

  # Attributes are declared in __slots__, so instances have no per-instance
  # __dict__. This saves memory and speeds up attribute access. It also means
  # only the names listed here can be set on an instance.
  __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases",
               "experimental_io_device", "experimental_variable_policy",
               "experimental_custom_gradients")

  def __init__(self,
               namespace_whitelist=None,
               save_debug_info=False,
               function_aliases=None,
               experimental_io_device=None,
               experimental_variable_policy=None,
               experimental_custom_gradients=True):
    """Builds the set of options consulted while writing a SavedModel.

    Args:
      namespace_whitelist: List of strings naming the op namespaces that may
        appear in the saved model. When saving an object that uses namespaced
        ops, every namespace it uses must be listed explicitly. Those
        namespaced ops must also be registered with the framework when the
        SavedModel is loaded. If no list is given (`None`), all namespaced ops
        are allowed.
      save_debug_info: Boolean indicating whether debug information is saved.
        When True, a debug/saved_model_debug_info.pb file is written. It holds
        a GraphDebugInfo binary protocol buffer with stack trace information
        for every op and function that is saved.
      function_aliases: Python dict mapping an alias string to an object
        returned by @tf.function. One tf.function can generate many
        ConcreteFunctions. A downstream tool that needs to refer to all the
        concrete functions produced by one tf.function can use this argument
        to map the alias name to all of those concrete function names. A
        falsy value (`None` or an empty dict) is stored as a new empty dict.
        E.g.

        >>> class Adder(tf.Module):
        ...   @tf.function
        ...   def double(self, x):
        ...     return x + x

        >>> model = Adder()
        >>> model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.float32, name="float_input"))
        >>> model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.string, name="string_input"))

        >>> options = tf.saved_model.SaveOptions(
        ...   function_aliases={'double': model.double})
        >>> tf.saved_model.save(model, '/tmp/adder', options=options)

      experimental_io_device: string. Applies in a distributed setting.
        The TensorFlow device used to access the filesystem. If `None` (the
        default), each variable accesses the filesystem from the CPU:0 device
        of the host it is assigned to. If a device is given, it is used for
        all variables instead.

        This is useful, for example, when saving to a local directory such as
        "/tmp" in a distributed setting: pass a device on the host where the
        "/tmp" directory is accessible.
      experimental_variable_policy: The policy applied to variables when
        saving. Either a `saved_model.experimental.VariablePolicy` enum member
        or one of its value strings, compared case-insensitively. See that
        enum's documentation for details. `None` selects
        `VariablePolicy.NONE`.
      experimental_custom_gradients: Boolean. When True, traced gradient
        functions are saved for functions decorated with
        `tf.custom_gradient`. Defaults to `True`.

    Raises:
      TypeError: If `namespace_whitelist` is neither `None` nor a list.
      ValueError: If an entry of `namespace_whitelist` is not a string, or if
        `experimental_variable_policy` is not a recognized policy.
    """
    # The checks that can raise stay in their original order: namespace list
    # first, variable policy last.
    self.namespace_whitelist = _validate_namespace_whitelist(
        namespace_whitelist)
    self.save_debug_info = save_debug_info
    # Each instance gets its own empty dict when no aliases are supplied.
    if not function_aliases:
      function_aliases = {}
    self.function_aliases = function_aliases
    self.experimental_io_device = experimental_io_device
    self.experimental_custom_gradients = experimental_custom_gradients
    self.experimental_variable_policy = VariablePolicy.from_obj(
        experimental_variable_policy)

169 

170 

def _validate_namespace_whitelist(namespace_whitelist):
  """Checks the `namespace_whitelist` option and returns a normalized copy.

  Args:
    namespace_whitelist: `None`, or a `list` whose entries must all be `str`.

  Returns:
    `None` when `namespace_whitelist` is `None`. Otherwise, a new list with
    each entry passed through `compat.as_str`, in the original order.

  Raises:
    TypeError: If `namespace_whitelist` is neither `None` nor a `list`.
    ValueError: If any entry is not a `str`. The first offending entry is
      reported.
  """
  if namespace_whitelist is None:
    return None
  if not isinstance(namespace_whitelist, list):
    raise TypeError("`namespace_whitelist` must be a list of strings. Got: "
                    f"{namespace_whitelist} with type "
                    f"{type(namespace_whitelist)}.")

  def _normalized(ns):
    # Check first, then convert, so a non-string entry raises before any
    # conversion of it is attempted.
    if isinstance(ns, str):
      return compat.as_str(ns)
    raise ValueError("Whitelisted namespace must be a string. Got: "
                     f"{ns} of type {type(ns)}.")

  # The comprehension processes entries left to right, so the first invalid
  # entry stops processing, exactly as an explicit loop would.
  return [_normalized(ns) for ns in namespace_whitelist]