Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/serialized_attributes.py: 51%

81 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Helper classes that list&validate all attributes to serialize to SavedModel. 

16""" 

17 

18from tensorflow.python.eager import def_function 

19from tensorflow.python.keras.saving.saved_model import constants 

20from tensorflow.python.keras.saving.saved_model import save_impl 

21from tensorflow.python.keras.utils.generic_utils import LazyLoader 

22from tensorflow.python.trackable import base as trackable 

23from tensorflow.python.trackable.autotrackable import AutoTrackable 

24 

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
# Module-level handles to Keras modules, wrapped in LazyLoader.
# NOTE(review): presumably LazyLoader defers the real import until first
# attribute access (e.g. to break import cycles with these engine modules) --
# confirm against generic_utils.LazyLoader.
base_layer = LazyLoader("base_layer", globals(),
                        "tensorflow.python.keras.engine.base_layer")
training_lib = LazyLoader("training_lib", globals(),
                          "tensorflow.python.keras.engine.training")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
recurrent = LazyLoader("recurrent", globals(),
                       "tensorflow.python.keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes

40 

41 

class SerializedAttributes(object):
  """Class that tracks and validates all serialization attributes.

  Keras models contain many Python-defined components. For example, the
  trainable_variables property lists the model's trainable variables by
  recursively retrieving the trainable variables from each of the child layers.
  Another example is model.call, a python function that calls child layers and
  adds ops to the backend graph.

  Only TensorFlow checkpointable objects and functions can be serialized to
  SavedModel. Serializing a Keras model as-is results in a checkpointable object
  that does not resemble a Keras model at all. Thus, extra checkpointable
  objects and functions must be created during serialization.

  **Defining new serialized attributes**
  Child classes should be defined using:
    SerializedAttributes.with_attributes(
        'name', checkpointable_objects=[...], functions=[...], copy_from=[...])
  This class is used to cache generated checkpointable objects and functions,
  ensuring that new objects and functions are generated a single time.

  **Usage during serialization**
  Each Layer/Model object should have a corresponding instance of
  SerializedAttributes. Create a new instance by calling
  `SerializedAttributes.new(obj)`. Objects and functions may be saved using
  `.set_and_validate_objects`/`.set_and_validate_functions`.
  The properties `.checkpointable_objects` and `.functions` return the cached
  values.

  **Adding/changing attributes to save to SavedModel**
  1. Change the call to `SerializedAttributes.with_attributes` in the correct
     class:
     - CommonEndpoints: Base attributes to be added during serialization. If
       these attributes are present in a Trackable object, it can be
       deserialized to a Keras Model.
     - LayerAttributes: Attributes to serialize for Layer objects.
     - ModelAttributes: Attributes to serialize for Model objects.
  2. Update class docstring
  3. Update arguments to any calls to `set_and_validate_*`. For example, if
     `call_raw_tensors` is added to the ModelAttributes function list, then
     a `call_raw_tensors` function should be passed to
     `set_and_validate_functions`.

  **Common endpoints vs other attributes**
  Only common endpoints are attached directly to the root object. Keras-specific
  attributes are saved to a separate trackable object with the name "keras_api".
  The number of objects attached to the root is limited because any naming
  conflicts will cause user code to break.

  Another reason is that this will only affect users who call
  `tf.saved_model.load` instead of `tf.keras.models.load_model`. These are
  advanced users who are likely to have defined their own tf.functions and
  trackable objects. The added Keras-specific attributes are kept out of the way
  in the "keras_api" namespace.

  Properties defined in this class may be used to filter out keras-specific
  attributes:
  - `functions_to_serialize`: Returns dict of functions to attach to the root
      object.
  - `objects_to_serialize`: Returns dict of objects to attach to the root
      object (including the separate trackable object containing
      keras-specific attributes).

  All changes to the serialized attributes must be backwards-compatible, so
  attributes should not be removed or modified without sufficient justification.
  """

  @staticmethod
  def with_attributes(
      name, checkpointable_objects=None, functions=None, copy_from=None):
    """Creates a subclass with all attributes as specified in the arguments.

    Args:
      name: Name of subclass
      checkpointable_objects: Iterable of names of checkpointable objects to be
        serialized in the SavedModel. Not modified by this call.
      functions: Iterable of names of functions to be serialized in the
        SavedModel. Not modified by this call.
      copy_from: List of other SerializedAttributes subclasses. The returned
        class will copy checkpoint objects/functions from each subclass.

    Returns:
      Child class with attributes as defined in the `checkpointable_objects`
      and `functions` lists, plus those of every class in `copy_from`.
    """
    # Copy into fresh lists: the `copy_from` attributes are appended below, and
    # extending the caller's own list objects would silently mutate them.
    checkpointable_objects = list(checkpointable_objects or [])
    functions = list(functions or [])

    if copy_from is not None:
      for cls in copy_from:
        checkpointable_objects.extend(cls.all_checkpointable_objects)
        functions.extend(cls.all_functions)

    classdict = {
        'all_checkpointable_objects': set(checkpointable_objects),
        'all_functions': set(functions)}
    return type(name, (SerializedAttributes,), classdict)

  @staticmethod
  def new(obj):
    """Returns a new SerializedAttribute object.

    Args:
      obj: The Keras object being serialized.

    Returns:
      An instance of the attributes class matching the type of `obj`.

    Raises:
      TypeError: If `obj` is not one of the recognised Keras types.
    """
    # More specific types are checked first; base_layer.Layer is the fallback.
    if isinstance(obj, training_lib.Model):
      return ModelAttributes()
    elif isinstance(obj, metrics.Metric):
      return MetricAttributes()
    elif isinstance(obj, recurrent.RNN):
      return RNNAttributes()
    elif isinstance(obj, base_layer.Layer):
      return LayerAttributes()
    else:
      raise TypeError('Internal error during serialization: Expected Keras '
                      'Layer object, got {} of type {}'.format(obj, type(obj)))

  def __init__(self):
    # Cached values keyed by attribute name; entries may be None.
    self._object_dict = {}
    self._function_dict = {}
    # Holds every validated attribute; exposed under constants.KERAS_ATTR by
    # `objects_to_serialize`.
    self._keras_trackable = AutoTrackable()

  @property
  def functions(self):
    """Returns dictionary of all functions (entries set to None are omitted)."""
    return {key: value for key, value in self._function_dict.items()
            if value is not None}

  @property
  def checkpointable_objects(self):
    """Returns dictionary of all checkpointable objects (None entries omitted)."""
    return {key: value for key, value in self._object_dict.items()
            if value is not None}

  @property
  def functions_to_serialize(self):
    """Returns functions to attach to the root object during serialization.

    Only CommonEndpoints functions are included; LayerCall wrappers are
    replaced by their `wrapped_call`.
    """
    functions = {}
    for key, v in self.functions.items():
      if key in CommonEndpoints.all_functions:
        functions[key] = (v.wrapped_call if isinstance(v, save_impl.LayerCall)
                          else v)
    return functions

  @property
  def objects_to_serialize(self):
    """Returns objects to attach to the root object during serialization.

    Includes the CommonEndpoints objects plus the Keras-specific trackable
    under the `constants.KERAS_ATTR` key.
    """
    objects = {key: value for key, value in self.checkpointable_objects.items()
               if key in CommonEndpoints.all_checkpointable_objects}
    objects[constants.KERAS_ATTR] = self._keras_trackable
    return objects

  def set_and_validate_functions(self, function_dict):
    """Saves function dictionary, and validates dictionary values.

    Args:
      function_dict: Dict mapping every name in `all_functions` to a
        `def_function.Function`, a `save_impl.LayerCall`, or None.

    Returns:
      The `functions` property after the update.

    Raises:
      ValueError: If a required key is missing or a value is not a function.
    """
    for key in self.all_functions:
      if key not in function_dict:
        raise ValueError('Function {} missing from serialized function dict.'
                         .format(key))
      fn = function_dict[key]
      # Not all functions are required, so None is accepted.
      if fn is not None and not isinstance(
          fn, (def_function.Function, save_impl.LayerCall)):
        raise ValueError(
            'Function dictionary contained a non-function object: {} (for key'
            ' {})'.format(fn, key))
      self._function_dict[key] = fn

      # Extract TensorFlow `Function` from LayerCall.
      tf_fn = fn.wrapped_call if isinstance(fn, save_impl.LayerCall) else fn
      setattr(self._keras_trackable, key, tf_fn)
    return self.functions

  def set_and_validate_objects(self, object_dict):
    """Saves objects to a dictionary, and validates the values.

    Args:
      object_dict: Dict mapping every name in `all_checkpointable_objects` to a
        `trackable.Trackable`.

    Returns:
      The `checkpointable_objects` property after the update.

    Raises:
      ValueError: If a required key is missing or a value is not trackable.
    """
    for key in self.all_checkpointable_objects:
      if key not in object_dict:
        raise ValueError(
            'Object {} missing from serialized object dict.'.format(key))
      obj = object_dict[key]
      if not isinstance(obj, trackable.Trackable):
        raise ValueError(
            'Object dictionary contained a non-trackable object: {} (for key'
            ' {})'.format(obj, key))
      self._object_dict[key] = obj
      setattr(self._keras_trackable, key, obj)
    return self.checkpointable_objects

224 

225 

class CommonEndpoints(
    SerializedAttributes.with_attributes(
        'CommonEndpoints',
        functions=[
            '__call__',
            'call_and_return_all_conditional_losses',
            '_default_save_signature',
        ],
        checkpointable_objects=[
            'variables',
            'trainable_variables',
            'regularization_losses',
        ],
    )):
  """Endpoints shared by every object that Keras can load back as a model.

  Checkpointable objects:
    variables: All variables of the model and of its sublayers.
    trainable_variables: All trainable variables of the model and of its
      sublayers.
    regularization_losses: All unconditional losses (losses that do not depend
      on the inputs) of the model and of its sublayers.

  Functions:
    __call__: Takes inputs and returns the outputs of the model's call
      function.
    call_and_return_all_conditional_losses: Returns a tuple of (call function
      outputs, list of all losses that depend on the inputs).
    _default_save_signature: Traced model call function. Only included when the
      top-level exported object is a Keras model.
  """

246 """ 

247 

248 

class LayerAttributes(
    SerializedAttributes.with_attributes(
        'LayerAttributes',
        copy_from=[CommonEndpoints],
        functions=[
            'call_and_return_conditional_losses',
            'activity_regularizer_fn',
        ],
        checkpointable_objects=[
            'non_trainable_variables',
            'layers',
            'metrics',
            'layer_regularization_losses',
            'layer_metrics',
        ],
    )):
  """Checkpointable objects and functions saved to the SavedModel for Layers.

  Every CommonEndpoints attribute is copied in as well.

  Checkpointable objects:
    non_trainable_variables: Non-trainable variables of the layer and of its
      sublayers.
    layers: All sublayers.
    metrics: All metrics of the layer and of its sublayers.
    layer_regularization_losses: Losses owned only by this layer.
    layer_metrics: Metrics owned by this layer.

  Functions:
    call_and_return_conditional_losses: Takes inputs and returns a tuple of
      (outputs of the call function, list of input-dependent losses). The loss
      list leaves out the activity regularizer, which is kept separate so that
      the deserialized Layer can define a different activity regularizer.
    activity_regularizer_fn: Callable that returns the activity regularizer
      loss.
  """

273 

274 

class ModelAttributes(
    SerializedAttributes.with_attributes(
        'ModelAttributes', copy_from=[LayerAttributes])):
  """Checkpointable objects and functions saved to the SavedModel for Models.

  Models define no attributes of their own yet: every attribute is copied from
  LayerAttributes, which in turn includes the CommonEndpoints ones.
  """
  # TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which
  # list all losses and metrics defined by `model.compile`.

285 

286 

class MetricAttributes(SerializedAttributes.with_attributes(
    'MetricAttributes', functions=[], checkpointable_objects=['variables'])):
  """Attributes attached to Metric objects when they are saved to SavedModel.

  Checkpointable objects:
    variables: All variables of the metric.

  No functions are serialized for metrics.
  """

299 

300 

class RNNAttributes(
    SerializedAttributes.with_attributes(
        'RNNAttributes',
        copy_from=[LayerAttributes],
        checkpointable_objects=['states'],
    )):
  """Checkpointable objects and functions saved to the SavedModel for RNNs.

  Every LayerAttributes attribute (CommonEndpoints included) is copied in,
  plus:
    states: The RNN's state variables.
  """

311