Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/serialized_attributes.py: 50%

80 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Helper classes that list&validate all attributes to serialize to SavedModel. 

16""" 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src.saving.legacy.saved_model import constants 

21from keras.src.saving.legacy.saved_model import order_preserving_set as ops 

22from keras.src.saving.legacy.saved_model import save_impl 

23from keras.src.utils.generic_utils import LazyLoader 

24 

25# TODO(b/134426265): Switch back to single-quotes to match the rest of the file 

26# once the issue with copybara is fixed. 

27 

# Deferred references to Keras engine modules, built with `LazyLoader`.
# NOTE(review): presumably the real import is postponed until first attribute
# access (e.g. to break import cycles with the saving code) — confirm against
# `keras.src.utils.generic_utils.LazyLoader`.
base_layer, training_lib, metrics, base_rnn = (
    LazyLoader(local_name, globals(), module_path)
    for local_name, module_path in (
        ("base_layer", "keras.src.engine.base_layer"),
        ("training_lib", "keras.src.engine.training"),
        ("metrics", "keras.src.metrics"),
        ("base_rnn", "keras.src.layers.rnn.base_rnn"),
    )
)

32 

33 

class SerializedAttributes:
    """Class that tracks and validates all serialization attributes.

    Keras models contain many Python-defined components. For example, the
    `trainable_variables` property lists the model's trainable variables by
    recursively retrieving the trainable variables from each of the child
    layers. Another example is model.call, a python function that calls child
    layers and adds ops to the backend graph.

    Only Tensorflow checkpointable objects and functions can be serialized to
    SavedModel. Serializing a Keras model as-is results in a checkpointable
    object that does not resemble a Keras model at all. Thus, extra
    checkpointable objects and functions must be created during serialization.

    **Defining new serialized attributes**
    Child classes should be defined using:
      SerializedAttributes.with_attributes(
          'name', checkpointable_objects=[...],
          functions=[...], copy_from=[...])
    This class is used to cache generated checkpointable objects and functions,
    ensuring that new objects and functions are generated a single time.

    **Usage during serialization**
    Each Layer/Model object should have a corresponding instance of
    SerializedAttributes. Create a new instance by calling
    `SerializedAttributes.new(obj)`. Objects and functions may be saved using
    `.set_and_validate_objects` / `.set_and_validate_functions`.
    The properties `.checkpointable_objects` and `.functions` return the cached
    values.

    **Adding/changing attributes to save to SavedModel**
    1. Change the call to `SerializedAttributes.with_attributes` in the correct
       class:
       - CommonEndpoints: Base attributes to be added during serialization. If
         these attributes are present in a Trackable object, it can be
         deserialized to a Keras Model.
       - LayerAttributes: Attributes to serialize for Layer objects.
       - ModelAttributes: Attributes to serialize for Model objects.
    2. Update class docstring
    3. Update arguments to any calls to `set_and_validate_*`. For example, if
       `call_raw_tensors` is added to the ModelAttributes function list, then
       a `call_raw_tensors` function should be passed to
       `set_and_validate_functions`.

    **Common endpoints vs other attributes**
    Only common endpoints are attached directly to the root object.
    Keras-specific attributes are saved to a separate trackable object with the
    name "keras_api". The number of objects attached to the root is limited
    because any naming conflicts will cause user code to break.

    Another reason is that this will only affect users who call
    `tf.saved_model.load` instead of `tf.keras.models.load_model`. These are
    advanced users who are likely to have defined their own tf.functions and
    trackable objects. The added Keras-specific attributes are kept out of the
    way in the "keras_api" namespace.

    Properties defined in this class may be used to filter out keras-specific
    attributes:
    - `functions_to_serialize`: Returns dict of functions to attach to the root
      object.
    - `objects_to_serialize`: Returns dict of objects to attach to the root
      object (including the separate trackable object containing
      keras-specific attributes).

    All changes to the serialized attributes must be backwards-compatible, so
    attributes should not be removed or modified without sufficient
    justification.
    """

    @staticmethod
    def with_attributes(
        name, checkpointable_objects=None, functions=None, copy_from=None
    ):
        """Creates a subclass with all attributes as specified in the arguments.

        Args:
          name: Name of subclass
          checkpointable_objects: Iterable of names of checkpointable objects
            to be serialized in the SavedModel. Not modified.
          functions: Iterable of names of functions to be serialized in the
            SavedModel. Not modified.
          copy_from: List of other SerializedAttributes subclasses. The returned
            class will copy checkpoint objects/functions from each subclass,
            appending them after the names given explicitly.

        Returns:
          Child class with attributes as defined in the `checkpointable_objects`
          and `functions` lists.
        """
        # Build fresh lists: the `.extend()` calls below must never mutate the
        # caller's arguments.
        checkpointable_objects = list(checkpointable_objects or [])
        functions = list(functions or [])

        if copy_from is not None:
            for cls in copy_from:
                checkpointable_objects.extend(cls.all_checkpointable_objects)
                functions.extend(cls.all_functions)

        # OrderPreservingSets are used here to guarantee serialization
        # determinism of Keras objects.
        classdict = {
            "all_checkpointable_objects": ops.OrderPreservingSet(
                checkpointable_objects
            ),
            "all_functions": ops.OrderPreservingSet(functions),
        }
        return type(name, (SerializedAttributes,), classdict)

    @staticmethod
    def new(obj):
        """Returns a new SerializedAttributes object matching `obj`'s type.

        The generic `Layer` check must stay last. Presumably `Model`, `Metric`
        and `RNN` are `Layer` subclasses and would otherwise be caught early.

        Raises:
          TypeError: If `obj` matches none of the supported Keras types.
        """
        if isinstance(obj, training_lib.Model):
            return ModelAttributes()
        elif isinstance(obj, metrics.Metric):
            return MetricAttributes()
        elif isinstance(obj, base_rnn.RNN):
            return RNNAttributes()
        elif isinstance(obj, base_layer.Layer):
            return LayerAttributes()
        else:
            raise TypeError(
                "Internal error during serialization. Expected Keras "
                f"Layer object. Received: {obj} "
                f"(of type {type(obj)})"
            )

    def __init__(self):
        # Name -> value caches filled by the `set_and_validate_*` methods.
        self._object_dict = {}
        self._function_dict = {}
        # Holds every saved attribute; exposed under `constants.KERAS_ATTR`
        # by `objects_to_serialize`.
        self._keras_trackable = tf.__internal__.tracking.AutoTrackable()

    @staticmethod
    def _unwrap_layer_call(fn):
        """Returns the TensorFlow `Function` wrapped by a LayerCall, else `fn`."""
        return fn.wrapped_call if isinstance(fn, save_impl.LayerCall) else fn

    @property
    def functions(self):
        """Returns dictionary of all functions whose value is not `None`."""
        return {
            key: value
            for key, value in self._function_dict.items()
            if value is not None
        }

    @property
    def checkpointable_objects(self):
        """Returns dictionary of all checkpointable objects that are not
        `None`."""
        return {
            key: value
            for key, value in self._object_dict.items()
            if value is not None
        }

    @property
    def functions_to_serialize(self):
        """Returns functions to attach to the root object during
        serialization.

        Only names listed in `CommonEndpoints.all_functions` are kept, and
        `save_impl.LayerCall` values are replaced by their wrapped call.
        """
        return {
            key: self._unwrap_layer_call(fn)
            for key, fn in self.functions.items()
            if key in CommonEndpoints.all_functions
        }

    @property
    def objects_to_serialize(self):
        """Returns objects to attach to the root object during serialization.

        Includes the common-endpoint objects plus the Keras-specific trackable,
        stored under `constants.KERAS_ATTR`.
        """
        objects = {
            key: value
            for key, value in self.checkpointable_objects.items()
            if key in CommonEndpoints.all_checkpointable_objects
        }
        objects[constants.KERAS_ATTR] = self._keras_trackable
        return objects

    def set_and_validate_functions(self, function_dict):
        """Saves function dictionary, and validates dictionary values.

        Keys are processed in `all_functions` order. Each accepted value is
        cached and also set (unwrapped, if it is a LayerCall) on the
        Keras-specific trackable. Entries accepted before a failure stay
        stored.

        Args:
          function_dict: Dict that must contain every name in `all_functions`.
            Values may be `None` (not all functions are required), a
            `tf.function`, a `ConcreteFunction`, or a `save_impl.LayerCall`.

        Returns:
          Dict of cached functions whose values are not `None`.

        Raises:
          ValueError: If a required name is missing, or a value is neither
            `None` nor one of the accepted function types.
        """
        for key in self.all_functions:
            if key not in function_dict:
                raise ValueError(
                    f"Function {key} missing from serialized "
                    "tf.function dictionary."
                )
            fn = function_dict[key]
            # Not all functions are required, so `None` is accepted.
            if fn is not None and not isinstance(
                fn,
                (
                    tf.__internal__.function.Function,
                    tf.types.experimental.ConcreteFunction,
                    save_impl.LayerCall,
                ),
            ):
                raise ValueError(
                    "The tf.function dictionary contained a non-function "
                    f"object: {fn} (for key {key}). Only "
                    "tf.function instances or ConcreteFunction instances "
                    "should be passed."
                )
            self._function_dict[key] = fn
            setattr(self._keras_trackable, key, self._unwrap_layer_call(fn))
        return self.functions

    def set_and_validate_objects(self, object_dict):
        """Saves objects to a dictionary, and validates the values.

        Keys are processed in `all_checkpointable_objects` order. Each accepted
        object is cached and also set on the Keras-specific trackable. Entries
        accepted before a failure stay stored.

        Args:
          object_dict: Dict that must contain every name in
            `all_checkpointable_objects`; each value must be a
            `tf.__internal__.tracking.Trackable`.

        Returns:
          Dict of cached checkpointable objects whose values are not `None`.

        Raises:
          ValueError: If a required name is missing or a value is not
            trackable.
        """
        for key in self.all_checkpointable_objects:
            if key not in object_dict:
                raise ValueError(
                    f"Object {key} missing from serialized object dictionary."
                )
            obj = object_dict[key]
            if not isinstance(obj, tf.__internal__.tracking.Trackable):
                raise ValueError(
                    "The object dictionary contained a non-trackable "
                    f"object: {obj} (for key {key}). "
                    "Only trackable objects are "
                    "allowed, such as Keras layers/models or "
                    "tf.Module instances."
                )
            self._object_dict[key] = obj
            setattr(self._keras_trackable, key, obj)
        return self.checkpointable_objects

262 

263 

class CommonEndpoints(
    SerializedAttributes.with_attributes(
        name="CommonEndpoints",
        checkpointable_objects=[
            "variables",
            "trainable_variables",
            "regularization_losses",
        ],
        functions=[
            "__call__",
            "call_and_return_all_conditional_losses",
            "_default_save_signature",
        ],
    )
):
    """Endpoints shared by every model that Keras can load back.

    Checkpointable objects:
      variables: Every variable of the model and of its sublayers.
      trainable_variables: Every trainable variable of the model and of its
        sublayers.
      regularization_losses: Every unconditional loss (one that does not
        depend on the inputs) of the model and of its sublayers.

    Functions:
      __call__: Takes inputs and returns the outputs of the model's call
        function.
      call_and_return_all_conditional_losses: Returns the tuple
        (call function outputs, list of every input-dependent loss).
      _default_save_signature: Traced model call function; included only when
        the top-level exported object is a Keras model.
    """

294 

295 

class LayerAttributes(
    SerializedAttributes.with_attributes(
        name="LayerAttributes",
        checkpointable_objects=[
            "non_trainable_variables",
            "layers",
            "metrics",
            "layer_regularization_losses",
            "layer_metrics",
        ],
        functions=[
            "call_and_return_conditional_losses",
            "activity_regularizer_fn",
        ],
        copy_from=[CommonEndpoints],
    )
):
    """Checkpointable objects and functions saved to SavedModel for a Layer.

    Every CommonEndpoints attribute is included, plus:

    Checkpointable objects:
      non_trainable_variables: Non-trainable variables of the layer and of its
        sublayers.
      layers: Every sublayer.
      metrics: Every metric of the layer and of its sublayers.
      layer_regularization_losses: Losses owned by this layer alone.
      layer_metrics: Metrics owned by this layer.

    Functions:
      call_and_return_conditional_losses: Takes inputs and returns the tuple
        (call function outputs, list of input-dependent losses). The activity
        regularizer loss is excluded from that list and kept separate, so a
        deserialized Layer can define a different activity regularizer.
      activity_regularizer_fn: Callable that returns the activity regularizer
        loss.
    """

331 

332 

class ModelAttributes(
    SerializedAttributes.with_attributes(
        name="ModelAttributes",
        copy_from=[LayerAttributes],
    )
):
    """Checkpointable objects and functions saved to SavedModel for a Model.

    Adds nothing of its own: every attribute comes from LayerAttributes, which
    itself includes the CommonEndpoints attributes.
    """

    # TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`,
    # which list all losses and metrics defined by `model.compile`.

346 

347 

class MetricAttributes(
    SerializedAttributes.with_attributes("MetricAttributes", ["variables"], [])
):
    """Attributes attached to a Metric object when it is saved to SavedModel.

    Checkpointable objects:
      variables: Every variable of the metric.

    No functions are serialized for metrics.
    """

362 

363 

class RNNAttributes(
    SerializedAttributes.with_attributes(
        name="RNNAttributes",
        checkpointable_objects=["states"],
        copy_from=[LayerAttributes],
    )
):
    """Checkpointable objects and functions saved to SavedModel for an RNN.

    Every LayerAttributes attribute (CommonEndpoints included) is saved,
    plus:

    Checkpointable objects:
      states: The RNN's state variables.
    """

377