Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saving_utils.py: 18% (148 statements)

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""

import copy
import os

import tensorflow.compat.v2 as tf

import keras.src as keras
from keras.src import backend
from keras.src import losses
from keras.src import optimizers
from keras.src.engine import base_layer_utils
from keras.src.optimizers import optimizer_v1
from keras.src.saving.legacy import serialization
from keras.src.utils import version_utils
from keras.src.utils.io_utils import ask_to_proceed_with_overwrite

# isort: off
from tensorflow.python.platform import tf_logging as logging

def extract_model_metrics(model):
    """Convert metrics from a Keras model `compile` API to dictionary.

    This is used for converting Keras models to Estimators and SavedModels.

    Args:
        model: A `tf.keras.Model` object.

    Returns:
        Dictionary mapping metric names to metric instances. May return `None`
        if the model does not contain any metrics.
    """
    if getattr(model, "_compile_metrics", None):
        # TODO(psv/kathywu): use this implementation in model to estimator
        # flow. We are not using model.metrics here because we want to exclude
        # the metrics added using `add_metric` API.
        return {m.name: m for m in model._compile_metric_functions}
    return None
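
# Illustrative usage sketch (not part of the original module). Assuming a
# hypothetical compiled `tf.keras.Model` named `model`, the helper returns the
# compile-time metric objects keyed by name, or `None` when the model does not
# expose the legacy `_compile_metrics` attribute:
#
#     model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
#     model.compile(optimizer="sgd", loss="mse", metrics=["mae"])
#     metrics = extract_model_metrics(model)
#     # Either a dict such as {"mae": <metric instance>} or None, depending on
#     # whether the legacy compile attributes are populated.
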

def model_call_inputs(model, keep_original_batch_size=False):
    """Inspect model to get its input signature.

    The model's input signature is a list with a single (possibly-nested)
    object. This is due to the Keras-enforced restriction that tensor inputs
    must be passed in as the first argument.

    For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
    will have input signature:
    [{'feature1': TensorSpec, 'feature2': TensorSpec}]

    Args:
        model: Keras Model object.
        keep_original_batch_size: A boolean indicating whether we want to keep
            using the original batch size or set it to None. Default is
            `False`, which means that the batch dim of the returned input
            signature will always be set to `None`.

    Returns:
        A tuple containing `(args, kwargs)` TensorSpecs of the model call
        function inputs. `kwargs` does not contain the `training` argument.
    """
    input_specs = model.save_spec(dynamic_batch=not keep_original_batch_size)
    if input_specs is None:
        return None, None
    input_specs = _enforce_names_consistency(input_specs)
    return input_specs
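
# Illustrative sketch (an assumption, not part of the original module): for a
# hypothetical functional model with a single input, the returned
# `(args, kwargs)` pair wraps `tf.TensorSpec`s derived from `model.save_spec`:
#
#     inputs = tf.keras.Input(shape=(3,), batch_size=8, name="x")
#     outputs = tf.keras.layers.Dense(1)(inputs)
#     model = tf.keras.Model(inputs, outputs)
#     args, kwargs = model_call_inputs(model)
#     # `args` nests TensorSpecs for the first call argument; by default the
#     # batch dimension of each spec is relaxed to None. Passing
#     # keep_original_batch_size=True preserves the original batch size (8).
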

def raise_model_input_error(model):
    if isinstance(model, keras.models.Sequential):
        raise ValueError(
            f"Model {model} cannot be saved because the input shape is not "
            "available. Please specify an input shape either by calling "
            "`build(input_shape)` directly, or by calling the model on actual "
            "data using `Model()`, `Model.fit()`, or `Model.predict()`."
        )

    # If the model is not a `Sequential`, it is intended to be a subclassed
    # model.
    raise ValueError(
        f"Model {model} cannot be saved either because the input shape is not "
        "available or because the forward pass of the model is not defined. "
        "To define a forward pass, please override `Model.call()`. To specify "
        "an input shape, either call `build(input_shape)` directly, or call "
        "the model on actual data using `Model()`, `Model.fit()`, or "
        "`Model.predict()`. If you have a custom training step, please make "
        "sure to invoke the forward pass in train step through "
        "`Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`."
    )


def trace_model_call(model, input_signature=None):
    """Trace the model call to create a tf.function for exporting a Keras
    model.

    Args:
        model: A Keras model.
        input_signature: optional, a list of tf.TensorSpec objects specifying
            the inputs to the model.

    Returns:
        A tf.function wrapping the model's call function with input signatures
        set.

    Raises:
        ValueError: if input signature cannot be inferred from the model.
    """
    if input_signature is None:
        if isinstance(model.call, tf.__internal__.function.Function):
            input_signature = model.call.input_signature

    if input_signature:
        model_args = input_signature
        model_kwargs = {}
    else:
        model_args, model_kwargs = model_call_inputs(model)

        if model_args is None:
            raise_model_input_error(model)

    @tf.function
    def _wrapped_model(*args, **kwargs):
        """A concrete tf.function that wraps the model's call function."""
        (args, kwargs,) = model._call_spec.set_arg_value(
            "training", False, args, kwargs, inputs_in_args=True
        )

        with base_layer_utils.call_context().enter(
            model, inputs=None, build_graph=False, training=False, saving=True
        ):
            outputs = model(*args, **kwargs)

        # Outputs always has to be a flat dict.
        output_names = model.output_names  # Functional Model.
        if output_names is None:  # Subclassed Model.
            from keras.src.engine import compile_utils

            output_names = compile_utils.create_pseudo_output_names(outputs)
        outputs = tf.nest.flatten(outputs)
        return {name: output for name, output in zip(output_names, outputs)}

    return _wrapped_model.get_concrete_function(*model_args, **model_kwargs)
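
# Illustrative sketch (assumption, not part of the original module): the
# concrete function produced here can be passed to `tf.saved_model.save` as a
# serving signature for a hypothetical built and compiled `model`:
#
#     concrete_fn = trace_model_call(model)
#     tf.saved_model.save(model, "/tmp/exported_model", signatures=concrete_fn)
#     # The traced function calls the model with `training=False` and returns
#     # a flat dict keyed by the model's output names.
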

def model_metadata(model, include_optimizer=True, require_config=True):
    """Returns a dictionary containing the model metadata."""
    from keras.src import __version__ as keras_version
    from keras.src.optimizers.legacy import optimizer_v2

    model_config = {"class_name": model.__class__.__name__}
    try:
        model_config["config"] = model.get_config()
    except NotImplementedError as e:
        if require_config:
            raise e

    metadata = dict(
        keras_version=str(keras_version),
        backend=backend.backend(),
        model_config=model_config,
    )
    if model.optimizer and include_optimizer:
        if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
            logging.warning(
                "TensorFlow optimizers do not "
                "make it possible to access "
                "optimizer attributes or optimizer state "
                "after instantiation. "
                "As a result, we cannot save the optimizer "
                "as part of the model save file. "
                "You will have to compile your model again after loading it. "
                "Prefer using a Keras optimizer instead "
                "(see keras.io/optimizers)."
            )
        elif model._compile_was_called:
            training_config = model._get_compile_args(user_metrics=False)
            training_config.pop("optimizer", None)  # Handled separately.
            metadata["training_config"] = _serialize_nested_config(
                training_config
            )
            if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
                raise NotImplementedError(
                    "Optimizers loaded from a SavedModel cannot be saved. "
                    "If you are calling `model.save` or "
                    "`tf.keras.models.save_model`, "
                    "please set the `include_optimizer` option to `False`. For "
                    "`tf.saved_model.save`, "
                    "delete the optimizer from the model."
                )
            else:
                optimizer_config = {
                    "class_name": keras.utils.get_registered_name(
                        model.optimizer.__class__
                    ),
                    "config": model.optimizer.get_config(),
                }
            metadata["training_config"]["optimizer_config"] = optimizer_config
    return metadata
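
# Illustrative sketch (assumption, not part of the original module): for a
# hypothetical compiled model, the returned dict carries the Keras version,
# backend name, serialized model config, and, when an optimizer is attached,
# a serialized training config:
#
#     metadata = model_metadata(model, include_optimizer=True)
#     # metadata["keras_version"], metadata["backend"],
#     # metadata["model_config"]["class_name"],
#     # metadata.get("training_config", {}).get("optimizer_config")
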

def should_overwrite(filepath, overwrite):
    """Returns whether the filepath should be overwritten."""
    # If file exists and should not be overwritten.
    if not overwrite and os.path.isfile(filepath):
        return ask_to_proceed_with_overwrite(filepath)
    return True
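
# Illustrative sketch (assumption, not part of the original module): saving
# code typically gates file writes on this helper, which only prompts the user
# (via `ask_to_proceed_with_overwrite`) when `overwrite=False` and the target
# file already exists. The path below is hypothetical:
#
#     if should_overwrite("model.h5", overwrite=False):
#         ...  # proceed with writing the file
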

def compile_args_from_training_config(training_config, custom_objects=None):
    """Return model.compile arguments from training config."""
    if custom_objects is None:
        custom_objects = {}

    with keras.utils.CustomObjectScope(custom_objects):
        optimizer_config = training_config["optimizer_config"]
        optimizer = optimizers.deserialize(optimizer_config)

        # Recover losses.
        loss = None
        loss_config = training_config.get("loss", None)
        if loss_config is not None:
            loss = _deserialize_nested_config(losses.deserialize, loss_config)

        # Recover metrics.
        metrics = None
        metrics_config = training_config.get("metrics", None)
        if metrics_config is not None:
            metrics = _deserialize_nested_config(
                _deserialize_metric, metrics_config
            )

        # Recover weighted metrics.
        weighted_metrics = None
        weighted_metrics_config = training_config.get("weighted_metrics", None)
        if weighted_metrics_config is not None:
            weighted_metrics = _deserialize_nested_config(
                _deserialize_metric, weighted_metrics_config
            )

        # `training_config` is a plain dict, so use a key lookup (not
        # `hasattr`, which is always False on a dict) to recover an optional
        # `sample_weight_mode`.
        sample_weight_mode = training_config.get("sample_weight_mode", None)
        loss_weights = training_config["loss_weights"]

        return dict(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            weighted_metrics=weighted_metrics,
            loss_weights=loss_weights,
            sample_weight_mode=sample_weight_mode,
        )
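
# Illustrative sketch (assumption, not part of the original module): the
# `training_config` dict produced by `model_metadata` above can be turned back
# into `Model.compile` keyword arguments and re-applied to a restored model
# (`restored_model` is hypothetical here):
#
#     training_config = model_metadata(model)["training_config"]
#     compile_args = compile_args_from_training_config(training_config)
#     restored_model.compile(**compile_args)
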

def _deserialize_nested_config(deserialize_fn, config):
    """Deserializes arbitrary Keras `config` using `deserialize_fn`."""

    def _is_single_object(obj):
        if isinstance(obj, dict) and "class_name" in obj:
            return True  # Serialized Keras object.
        if isinstance(obj, str):
            return True  # Serialized function or string.
        return False

    if config is None:
        return None
    if _is_single_object(config):
        return deserialize_fn(config)
    elif isinstance(config, dict):
        return {
            k: _deserialize_nested_config(deserialize_fn, v)
            for k, v in config.items()
        }
    elif isinstance(config, (tuple, list)):
        return [
            _deserialize_nested_config(deserialize_fn, obj) for obj in config
        ]

    raise ValueError(
        "Saved configuration not understood. Configuration should be a "
        f"dictionary, string, tuple or list. Received: config={config}."
    )
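
# Illustrative sketch (assumption, not part of the original module): the
# helper walks dicts and lists and applies `deserialize_fn` only to leaves
# that look like serialized Keras objects or plain strings, e.g. a per-output
# loss mapping:
#
#     loss_config = {
#         "out_a": "mse",
#         "out_b": {"class_name": "MeanAbsoluteError", "config": {}},
#     }
#     loss = _deserialize_nested_config(losses.deserialize, loss_config)
#     # -> {"out_a": <mse loss/function>, "out_b": <MeanAbsoluteError>}
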

def _serialize_nested_config(config):
    """Serializes a nested structure of Keras objects."""

    def _serialize_fn(obj):
        if callable(obj):
            return serialization.serialize_keras_object(obj)
        return obj

    return tf.nest.map_structure(_serialize_fn, config)
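
# Illustrative sketch (assumption, not part of the original module): callables
# in the nested structure are serialized via the legacy serialization module,
# while plain values pass through unchanged:
#
#     config = {"loss": losses.mean_squared_error, "loss_weights": [1.0]}
#     serialized = _serialize_nested_config(config)
#     # Roughly {"loss": "mean_squared_error", "loss_weights": [1.0]}; the
#     # exact serialized form of the callable depends on `serialization`.
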

def _deserialize_metric(metric_config):
    """Deserialize metrics, leaving special strings untouched."""
    from keras.src import metrics as metrics_module

    if metric_config in ["accuracy", "acc", "crossentropy", "ce"]:
        # Do not deserialize accuracy and cross-entropy strings as we have
        # special case handling for these in compile, based on model output
        # shape.
        return metric_config
    return metrics_module.deserialize(metric_config)
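
# Illustrative sketch (assumption, not part of the original module): the
# shorthand strings are passed through so `compile` can resolve them against
# the model's output shape, while everything else is deserialized:
#
#     _deserialize_metric("acc")  # -> "acc" (left untouched)
#     _deserialize_metric({"class_name": "AUC", "config": {}})  # -> AUC metric
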

def _enforce_names_consistency(specs):
    """Enforces that either all specs have names or none do."""

    def _has_name(spec):
        return spec is None or (hasattr(spec, "name") and spec.name is not None)

    def _clear_name(spec):
        spec = copy.deepcopy(spec)
        if hasattr(spec, "name"):
            spec._name = None
        return spec

    flat_specs = tf.nest.flatten(specs)
    name_inconsistency = any(_has_name(s) for s in flat_specs) and not all(
        _has_name(s) for s in flat_specs
    )

    if name_inconsistency:
        specs = tf.nest.map_structure(_clear_name, specs)
    return specs
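
# Illustrative sketch (assumption, not part of the original module): when some
# specs are named and others are not, all names are cleared so the saved input
# signature stays consistent:
#
#     specs = [
#         tf.TensorSpec(shape=(None, 3), dtype=tf.float32, name="x"),
#         tf.TensorSpec(shape=(None, 1), dtype=tf.float32),  # unnamed
#     ]
#     consistent = _enforce_names_consistency(specs)
#     # Both returned specs now have name=None.
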

def try_build_compiled_arguments(model):
    if (
        not version_utils.is_v1_layer_or_model(model)
        and model.outputs is not None
    ):
        try:
            if not model.compiled_loss.built:
                model.compiled_loss.build(model.outputs)
            if not model.compiled_metrics.built:
                model.compiled_metrics.build(model.outputs, model.outputs)
        except:  # noqa: E722
            logging.warning(
                "Compiled the loaded model, but the compiled metrics have "
                "yet to be built. `model.compile_metrics` will be empty "
                "until you train or evaluate the model."
            )


def is_hdf5_filepath(filepath):
    return (
        filepath.endswith(".h5")
        or filepath.endswith(".keras")
        or filepath.endswith(".hdf5")
    )
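
# Illustrative sketch (assumption, not part of the original module): this
# legacy check treats the ".keras" extension the same as the HDF5 extensions:
#
#     is_hdf5_filepath("weights.h5")       # -> True
#     is_hdf5_filepath("model.keras")      # -> True
#     is_hdf5_filepath("saved_model_dir")  # -> False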