# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras model saving code."""

import os

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.saving import object_registration
from keras.src.saving.legacy import hdf5_format
from keras.src.saving.legacy import saving_utils
from keras.src.saving.legacy import serialization
from keras.src.saving.legacy.saved_model import load as saved_model_load
from keras.src.saving.legacy.saved_model import load_context
from keras.src.saving.legacy.saved_model import save as saved_model_save
from keras.src.saving.legacy.saved_model.utils import keras_option_scope
from keras.src.utils import io_utils
from keras.src.utils import traceback_utils

try:
    import h5py
except ImportError:
    h5py = None


@traceback_utils.filter_traceback
def save_model(
    model,
    filepath,
    overwrite=True,
    include_optimizer=True,
    save_format=None,
    signatures=None,
    options=None,
    save_traces=True,
):
    """Saves a model as a TensorFlow SavedModel or HDF5 file.

    See the [Serialization and Saving
    guide](https://keras.io/guides/serialization_and_saving/) for details.

    Usage:

    >>> model = tf.keras.Sequential([
    ...     tf.keras.layers.Dense(5, input_shape=(3,)),
    ...     tf.keras.layers.Softmax()])
    >>> model.save('/tmp/model')
    >>> loaded_model = tf.keras.models.load_model('/tmp/model')
    >>> x = tf.random.uniform((10, 3))
    >>> assert np.allclose(model.predict(x), loaded_model.predict(x))

    Note that `model.save()` is an alias for `tf.keras.models.save_model()`.

    The SavedModel or HDF5 file contains:

    - the model's configuration (topology)
    - the model's weights
    - the model's optimizer's state (if any)

    Thus models can be reinstantiated in the exact same state, without any of
    the code used for model definition or training.

    Note that the model weights may have different scoped names after being
    loaded. Scoped names include the model/layer names, such as
    `"dense_1/kernel:0"`. It is recommended that you use the layer properties
    to access specific variables, e.g. `model.get_layer("dense_1").kernel`.

    __SavedModel serialization format__

    Keras SavedModel uses `tf.saved_model.save` to save the model and all
    trackable objects attached to the model (e.g. layers and variables). The
    model config, weights, and optimizer are saved in the SavedModel.
    Additionally, for every Keras layer attached to the model, the SavedModel
    stores:

    * the config and metadata -- e.g. name, dtype, trainable status
    * traced call and loss functions, which are stored as TensorFlow
      subgraphs.

    The traced functions allow the SavedModel format to save and load custom
    layers without the original class definition.

    You can choose to not save the traced functions by disabling the
    `save_traces` option. This will decrease the time it takes to save the
    model and the amount of disk space occupied by the output SavedModel. If
    you disable `save_traces`, then you _must_ provide all custom class
    definitions when loading the model. See the `custom_objects` argument in
    `tf.keras.models.load_model`.

    Args:
        model: Keras model instance to be saved.
        filepath: One of the following:
            - String or `pathlib.Path` object, path where to save the model
            - `h5py.File` object where to save the model
        overwrite: Whether we should overwrite any existing model at the
            target location, or instead ask the user with a manual prompt.
        include_optimizer: If True, save the optimizer's state together with
            the model.
        save_format: Either 'tf' or 'h5', indicating whether to save the
            model to TensorFlow SavedModel or HDF5. Defaults to 'tf' in TF
            2.X, and 'h5' in TF 1.X.
        signatures: Signatures to save with the SavedModel. Applicable to the
            'tf' format only. Please see the `signatures` argument in
            `tf.saved_model.save` for details.
        options: (only applies to SavedModel format)
            `tf.saved_model.SaveOptions` object that specifies options for
            saving to SavedModel.
        save_traces: (only applies to SavedModel format) When enabled, the
            SavedModel will store the function traces for each layer. This
            can be disabled, so that only the configs of each layer are
            stored. Defaults to `True`. Disabling this will decrease
            serialization time and reduce file size, but it requires that
            all custom layers/models implement a `get_config()` method.

    Raises:
        ImportError: If the save format is HDF5, and h5py is not available.
    """

    from keras.src.engine import sequential

    default_format = "tf" if tf.__internal__.tf2.enabled() else "h5"
    save_format = save_format or default_format

    filepath = io_utils.path_to_string(filepath)

    # If the user has not already called fit or built the underlying metrics,
    # we should do that before saving to ensure the metric names have all
    # appropriate name transformations applied.
    saving_utils.try_build_compiled_arguments(model)

    if (
        save_format == "h5"
        or (h5py is not None and isinstance(filepath, h5py.File))
        or saving_utils.is_hdf5_filepath(filepath)
    ):
        # TODO(b/130258301): add utility method for detecting model type.
        if not model._is_graph_network and not isinstance(
            model, sequential.Sequential
        ):
            raise NotImplementedError(
                "Saving the model to HDF5 format requires the model to be a "
                "Functional model or a Sequential model. It does not work for "
                "subclassed models, because such models are defined via the "
                "body of a Python method, which isn't safely serializable. "
                "Consider saving to the TensorFlow SavedModel format (by "
                'setting save_format="tf") or using `save_weights`.'
            )
        hdf5_format.save_model_to_hdf5(
            model, filepath, overwrite, include_optimizer
        )
    else:
        with serialization.SharedObjectSavingScope():
            with keras_option_scope(
                save_traces=save_traces, in_tf_saved_model_scope=True
            ):
                saved_model_save.save(
                    model,
                    filepath,
                    overwrite,
                    include_optimizer,
                    signatures,
                    options,
                    save_traces,
                )

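# A minimal usage sketch for `save_model` (illustrative only; it assumes a
# writable /tmp directory and the trivial Sequential model from the docstring
# example above):
#
#     model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(3,))])
#     save_model(model, "/tmp/model", save_format="tf")     # SavedModel dir
#     save_model(model, "/tmp/model.h5", save_format="h5")  # single HDF5 file
#     save_model(model, "/tmp/model", save_traces=False)    # configs only;
#     # custom layers must then implement `get_config()` to be reloadable.

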

@traceback_utils.filter_traceback
def load_model(filepath, custom_objects=None, compile=True, options=None):
    """Loads a model saved via `model.save()`.

    Usage:

    >>> model = tf.keras.Sequential([
    ...     tf.keras.layers.Dense(5, input_shape=(3,)),
    ...     tf.keras.layers.Softmax()])
    >>> model.save('/tmp/model')
    >>> loaded_model = tf.keras.models.load_model('/tmp/model')
    >>> x = tf.random.uniform((10, 3))
    >>> assert np.allclose(model.predict(x), loaded_model.predict(x))

    Note that the model weights may have different scoped names after being
    loaded. Scoped names include the model/layer names, such as
    `"dense_1/kernel:0"`. It is recommended that you use the layer properties
    to access specific variables, e.g. `model.get_layer("dense_1").kernel`.

    Args:
        filepath: One of the following:
            - String or `pathlib.Path` object, path to the saved model
            - `h5py.File` object from which to load the model
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.
        compile: Boolean, whether to compile the model
            after loading.
        options: Optional `tf.saved_model.LoadOptions` object that specifies
            options for loading from SavedModel.

    Returns:
        A Keras model instance. If the original model was compiled, and saved
        with the optimizer, then the returned model will be compiled.
        Otherwise, the model will be left uncompiled. In the case that an
        uncompiled model is returned, a warning is displayed if the `compile`
        argument is set to `True`.

    Raises:
        ImportError: if loading from an hdf5 file and h5py is not available.
        IOError: In case of an invalid savefile.
    """
    with serialization.SharedObjectLoadingScope():
        custom_objects = custom_objects or {}
        tlco = object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
        gco = object_registration._GLOBAL_CUSTOM_OBJECTS
        custom_objects = {**custom_objects, **tlco, **gco}
        with object_registration.CustomObjectScope(custom_objects):
            with keras_option_scope(
                save_traces=False, in_tf_saved_model_scope=True
            ):
                with load_context.load_context(options):
                    filepath_str = io_utils.path_to_string(filepath)
                    if isinstance(filepath_str, str):
                        if not tf.io.gfile.exists(filepath_str):
                            raise IOError(
                                f"No file or directory found at {filepath_str}"
                            )

                        if tf.io.gfile.isdir(filepath_str):
                            return saved_model_load.load(
                                filepath_str, compile, options
                            )
                        else:
                            if h5py is None:
                                raise ImportError(
                                    "Filepath looks like a hdf5 file but "
                                    "h5py is not available."
                                    f" filepath={filepath_str}"
                                )
                            return hdf5_format.load_model_from_hdf5(
                                tf.io.gfile.GFile(filepath_str, mode="rb"),
                                custom_objects,
                                compile,
                            )
                    elif h5py is not None and isinstance(filepath, h5py.File):
                        return hdf5_format.load_model_from_hdf5(
                            filepath, custom_objects, compile
                        )

    raise IOError(
        "Unable to load model. Filepath is not an hdf5 file (or h5py is not "
        f"available) or SavedModel. Received: filepath={filepath}"
    )

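# A minimal usage sketch for `load_model` with `custom_objects` (illustrative
# only; `MyCustomLayer` is a hypothetical user-defined layer class, and the
# path is a placeholder):
#
#     loaded = load_model(
#         "/tmp/model",
#         custom_objects={"MyCustomLayer": MyCustomLayer},
#         compile=False,  # skip restoring the compile/optimizer state
#     )

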

def save_weights(
    model, filepath, overwrite=True, save_format=None, options=None
):
    """Saves all layer weights.

    Either saves in HDF5 or in TensorFlow format based on the `save_format`
    argument.

    When saving in HDF5 format, the weight file has:
        - `layer_names` (attribute), a list of strings
            (ordered names of model layers).
        - For every layer, a `group` named `layer.name`
            - For every such layer group, a group attribute `weight_names`,
                a list of strings
                (ordered names of weights tensor of the layer).
            - For every weight in the layer, a dataset
                storing the weight value, named after the weight tensor.

    When saving in TensorFlow format, all objects referenced by the network
    are saved in the same format as `tf.train.Checkpoint`, including any
    `Layer` instances or `Optimizer` instances assigned to object
    attributes. For networks constructed from inputs and outputs using
    `tf.keras.Model(inputs, outputs)`, `Layer` instances used by the network
    are tracked/saved automatically. For user-defined classes which inherit
    from `tf.keras.Model`, `Layer` instances must be assigned to object
    attributes, typically in the constructor. See the documentation of
    `tf.train.Checkpoint` and `tf.keras.Model` for details.

    While the formats are the same, do not mix `save_weights` and
    `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should
    be loaded using `Model.load_weights`. Checkpoints saved using
    `tf.train.Checkpoint.save` should be restored using the corresponding
    `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
    `save_weights` for training checkpoints.

    The TensorFlow format matches objects and variables by starting at a
    root object, `self` for `save_weights`, and greedily matching attribute
    names. For `Model.save` this is the `Model`, and for `Checkpoint.save`
    this is the `Checkpoint` even if the `Checkpoint` has a model attached.
    This means saving a `tf.keras.Model` using `save_weights` and loading
    into a `tf.train.Checkpoint` with a `Model` attached (or vice versa)
    will not match the `Model`'s variables. See the
    [guide to training checkpoints](
    https://www.tensorflow.org/guide/checkpoint) for details on
    the TensorFlow format.

    Args:
        filepath: String or PathLike, path to the file to save the weights
            to. When saving in TensorFlow format, this is the prefix used
            for checkpoint files (multiple files are generated). Note that
            the '.h5' suffix causes weights to be saved in HDF5 format.
        overwrite: Whether to silently overwrite any existing file at the
            target location, or provide the user with a manual prompt.
        save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
            '.keras' will default to HDF5 if `save_format` is `None`.
            Otherwise, `None` defaults to 'tf'.
        options: Optional `tf.train.CheckpointOptions` object that specifies
            options for saving weights.

    Raises:
        ImportError: If `h5py` is not available when attempting to save in
            HDF5 format.
    """
    model._assert_weights_created()
    filepath = io_utils.path_to_string(filepath)
    filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
    if save_format is None:
        if filepath_is_h5:
            save_format = "h5"
        else:
            save_format = "tf"
    else:
        user_format = save_format.lower().strip()
        if user_format in ("tensorflow", "tf"):
            save_format = "tf"
        elif user_format in ("hdf5", "h5", "keras"):
            save_format = "h5"
        else:
            raise ValueError(
                f"Unknown format. Received: `save_format`={save_format}. "
                'Was expecting one of {"tf", "h5"}.'
            )
    if save_format == "tf" and filepath_is_h5:
        raise ValueError(
            'save_weights got save_format="tf"/"tensorflow", but the '
            f"filepath ({filepath}) looks like an HDF5 file. "
            'Omit the ".h5"/".keras" when saving in TensorFlow format.'
        )

    if save_format == "h5" and h5py is None:
        raise ImportError(
            "`save_weights` requires h5py when saving in HDF5, but h5py is "
            "not available. Try installing the h5py package."
        )
    if save_format == "tf":
        check_filepath = filepath + ".index"
    else:
        check_filepath = filepath
    # If file exists and should not be overwritten:
    if not overwrite and os.path.isfile(check_filepath):
        proceed = io_utils.ask_to_proceed_with_overwrite(check_filepath)
        if not proceed:
            return
    if save_format == "h5":
        with h5py.File(filepath, "w") as f:
            hdf5_format.save_weights_to_hdf5_group(f, model)
    else:
        if not tf.executing_eagerly():
            # Call `get_session` to initialize any uninitialized variables.
            backend.get_session()
        model._checkpoint.write(filepath, options=options)

        # Record this checkpoint so it's visible from
        # tf.train.latest_checkpoint.
        tf.__internal__.train.update_checkpoint_state(
            save_dir=os.path.dirname(filepath),
            model_checkpoint_path=filepath,
            save_relative_paths=True,
            all_model_checkpoint_paths=[filepath],
        )

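# A minimal usage sketch for `save_weights` (illustrative only; the paths are
# placeholders):
#
#     save_weights(model, "/tmp/ckpt")        # TF checkpoint prefix: writes
#                                             # /tmp/ckpt.index and data files
#     save_weights(model, "/tmp/weights.h5")  # '.h5' suffix selects HDF5

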

def load_weights(
    model, filepath, by_name=False, skip_mismatch=False, options=None
):
    """Loads all layer weights, either from a SavedModel or H5 weights file.

    If `by_name` is False, weights are loaded based on the network's
    topology. This means the architecture should be the same as when the
    weights were saved. Note that layers that don't have weights are not
    taken into account in the topological ordering, so adding or removing
    layers is fine as long as they don't have weights.

    If `by_name` is True, weights are loaded into layers only if they share
    the same name. This is useful for fine-tuning or transfer-learning
    models where some of the layers have changed.

    Only topological loading (`by_name=False`) is supported when loading
    weights from the TensorFlow format. Note that topological loading
    differs slightly between TensorFlow and HDF5 formats for user-defined
    classes inheriting from `tf.keras.Model`: HDF5 loads based on a
    flattened list of weights, while the TensorFlow format loads based on
    the object-local names of attributes to which layers are assigned in
    the `Model`'s constructor.

    Args:
        filepath: String, path to the weights file to load. For weight files
            in TensorFlow format, this is the file prefix (the same as was
            passed to `save_weights`). This can also be a path to a
            SavedModel saved from `model.save`.
        by_name: Boolean, whether to load weights by name or by topological
            order. Only topological loading is supported for weight files in
            TensorFlow format.
        skip_mismatch: Boolean, whether to skip loading of layers where
            there is a mismatch in the number of weights, or a mismatch in
            the shape of the weights (only valid when `by_name=True`).
        options: Optional `tf.train.CheckpointOptions` object that specifies
            options for loading weights.

    Returns:
        When loading a weight file in TensorFlow format, returns the same
        status object as `tf.train.Checkpoint.restore`. When graph building,
        restore ops are run automatically as soon as the network is built
        (on first call for user-defined classes inheriting from `Model`,
        immediately if it is already built).

        When loading weights in HDF5 format, returns `None`.

    Raises:
        ImportError: If `h5py` is not available and the weight file is in
            HDF5 format.
        ValueError: If `skip_mismatch` is set to `True` when `by_name` is
            `False`.
    """
    if backend.is_tpu_strategy(model._distribution_strategy):
        if model._distribution_strategy.extended.steps_per_run > 1 and (
            not saving_utils.is_hdf5_filepath(filepath)
        ):
            spr = model._distribution_strategy.extended.steps_per_run
            raise ValueError(
                "Load weights is not implemented with TPUStrategy "
                "with `steps_per_run` greater than 1. The "
                f"`steps_per_run` is {spr}"
            )
    if skip_mismatch and not by_name:
        raise ValueError(
            "When calling model.load_weights, skip_mismatch can only be "
            "set to True when by_name is True."
        )

    filepath, save_format = _detect_save_format(filepath)
    if save_format == "tf":
        status = model._checkpoint.read(filepath, options)
        if by_name:
            raise NotImplementedError(
                "Weights may only be loaded based on topology into Models "
                "when loading TensorFlow-formatted weights "
                "(got by_name=True to load_weights)."
            )
        if not tf.executing_eagerly():
            session = backend.get_session()
            # Restore existing variables (if any) immediately, and set up a
            # streaming restore for any variables created in the future.
            tf.__internal__.tracking.streaming_restore(
                status=status, session=session
            )
        status.assert_nontrivial_match()
    else:
        status = None
        if h5py is None:
            raise ImportError(
                "`load_weights` requires the h5py package when loading "
                "weights from HDF5. Try installing h5py."
            )
        if not model._is_graph_network and not model.built:
            raise ValueError(
                "Unable to load weights saved in HDF5 format into a "
                "subclassed Model which has not created its variables yet. "
                "Call the Model first, then load the weights."
            )
        model._assert_weights_created()
        with h5py.File(filepath, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            if by_name:
                hdf5_format.load_weights_from_hdf5_group_by_name(
                    f, model, skip_mismatch
                )
            else:
                hdf5_format.load_weights_from_hdf5_group(f, model)

        # Perform any layer defined finalization of the layer state.
        for layer in model.layers:
            layer.finalize_state()
    return status

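# A minimal usage sketch for `load_weights` (illustrative only; the paths are
# placeholders and must point at files written by `save_weights`):
#
#     load_weights(model, "/tmp/ckpt")            # TF format, by topology
#     load_weights(model, "/tmp/weights.h5",
#                  by_name=True,                  # HDF5 only, by layer name
#                  skip_mismatch=True)            # valid only with by_name

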

def _detect_save_format(filepath):
    """Returns path to weights file and save format."""

    filepath = io_utils.path_to_string(filepath)
    if saving_utils.is_hdf5_filepath(filepath):
        return filepath, "h5"

    # Filepath could be a TensorFlow checkpoint file prefix or SavedModel
    # directory. It's possible for filepath to be both a prefix and directory.
    # Prioritize checkpoint over SavedModel.
    if _is_readable_tf_checkpoint(filepath):
        save_format = "tf"
    elif tf.saved_model.contains_saved_model(filepath):
        ckpt_path = os.path.join(
            filepath,
            tf.saved_model.VARIABLES_DIRECTORY,
            tf.saved_model.VARIABLES_FILENAME,
        )
        if _is_readable_tf_checkpoint(ckpt_path):
            filepath = ckpt_path
            save_format = "tf"
        else:
            raise ValueError(
                f"Unable to load weights. filepath {filepath} appears to be "
                "a SavedModel directory, but checkpoint either doesn't "
                "exist, or is incorrectly formatted."
            )
    else:
        # Not a TensorFlow checkpoint. This filepath is likely an H5 file that
        # doesn't have the hdf5/keras extensions.
        save_format = "h5"
    return filepath, save_format

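# Illustrative expectations for `_detect_save_format` (assuming
# `saving_utils.is_hdf5_filepath` recognizes the '.h5'/'.keras' suffixes and
# that the other paths exist on disk):
#
#     _detect_save_format("weights.h5")        # -> ("weights.h5", "h5")
#     _detect_save_format("/tmp/ckpt")         # readable checkpoint -> "tf"
#     _detect_save_format("/tmp/saved_model")  # SavedModel dir -> path to its
#                                              #   variables checkpoint, "tf"

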

def _is_readable_tf_checkpoint(filepath):
    try:
        tf.compat.v1.train.NewCheckpointReader(filepath)
        return True
    except tf.errors.DataLossError:
        # The checkpoint is not readable in TensorFlow format.
        return False


# Inject the load_model function to keras_deps to remove the dependency
# from TFLite to Keras.
tf.__internal__.register_load_model_function(load_model)