Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/saving_api.py: 21%

87 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Public API surface for saving APIs.""" 

16 

17import os 

18import warnings 

19import zipfile 

20 

21import tensorflow.compat.v2 as tf 

22from tensorflow.python.util.tf_export import keras_export 

23 

24from keras.src.saving import saving_lib 

25from keras.src.saving.legacy import save as legacy_sm_saving_lib 

26from keras.src.utils import io_utils 

27 

28try: 

29 import h5py 

30except ImportError: 

31 h5py = None 

32 

33 

@keras_export("keras.saving.save_model", "keras.models.save_model")
def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
    """Saves a model as a `.keras` archive, TensorFlow SavedModel or HDF5 file.

    See the [Serialization and Saving guide](
        https://keras.io/guides/serialization_and_saving/) for details.

    Args:
        model: Keras model instance to be saved.
        filepath: `str` or `pathlib.Path` object. Path where to save the model.
        overwrite: Whether we should overwrite any existing model at the target
            location, or instead ask the user via an interactive prompt.
        save_format: Either `"keras"`, `"tf"`, `"h5"`,
            indicating whether to save the model
            in the native Keras format (`.keras`),
            in the TensorFlow SavedModel format (referred to as "SavedModel"
            below), or in the legacy HDF5 format (`.h5`).
            The aliases `"keras_v3"`, `"tensorflow"` and `"hdf5"` are also
            accepted. When not set, the format is inferred from `filepath`:
            a `.keras` suffix selects the native Keras format (HDF5 if v3
            saving is disabled), a `.h5`/`.hdf5` suffix or an `h5py.File`
            selects HDF5, and anything else defaults to `"tf"` in TF 2.X and
            `"h5"` in TF 1.X.

    SavedModel format arguments:
        include_optimizer: Only applied to SavedModel and legacy HDF5 formats.
            If False, do not save the optimizer state. Defaults to True.
        signatures: Only applies to SavedModel format. Signatures to save
            with the SavedModel. See the `signatures` argument in
            `tf.saved_model.save` for details.
        options: Only applies to SavedModel format.
            `tf.saved_model.SaveOptions` object that specifies SavedModel
            saving options.
        save_traces: Only applies to SavedModel format. When enabled, the
            SavedModel will store the function traces for each layer. This
            can be disabled, so that only the configs of each layer are stored.
            Defaults to `True`. Disabling this will decrease serialization time
            and reduce file size, but it requires that all custom layers/models
            implement a `get_config()` method.

    Returns:
        `None` when saving in the native Keras format (or when the user
        declines to overwrite an existing file); otherwise whatever the
        legacy SavedModel/HDF5 saver returns.

    Raises:
        ValueError: If extra keyword arguments are passed while saving in the
            native Keras format, or if `save_format` is not recognized.

    Example:

    ```python
    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(5, input_shape=(3,)),
        tf.keras.layers.Softmax()])
    model.save("model.keras")
    loaded_model = tf.keras.saving.load_model("model.keras")
    x = tf.random.uniform((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that `model.save()` is an alias for `tf.keras.saving.save_model()`.

    The saved file contains:

        - The model's configuration (architecture)
        - The model's weights
        - The model's optimizer's state (if any)

    Thus models can be reinstantiated in the exact same state, without any of
    the code used for model definition or training.

    Note that the model weights may have different scoped names after being
    loaded. Scoped names include the model/layer names, such as
    `"dense_1/kernel:0"`. It is recommended that you use the layer properties to
    access specific variables, e.g. `model.get_layer("dense_1").kernel`.

    __SavedModel serialization format__

    With `save_format="tf"`, the model and all trackable objects attached
    to it (e.g. layers and variables) are saved as a TensorFlow SavedModel.
    The model config, weights, and optimizer are included in the SavedModel.
    Additionally, for every Keras layer attached to the model, the SavedModel
    stores:

        * The config and metadata -- e.g. name, dtype, trainable status
        * Traced call and loss functions, which are stored as TensorFlow
          subgraphs.

    The traced functions allow the SavedModel format to save and load custom
    layers without the original class definition.

    You can choose to not save the traced functions by disabling the
    `save_traces` option. This will decrease the time it takes to save the model
    and the amount of disk space occupied by the output SavedModel. If you
    disable this option, then you _must_ provide all custom class definitions
    when loading the model. See the `custom_objects` argument in
    `tf.keras.saving.load_model`.
    """
    save_format = get_save_format(filepath, save_format)

    # Deprecation warnings
    if save_format == "h5":
        warnings.warn(
            "You are saving your model as an HDF5 file via `model.save()`. "
            "This file format is considered legacy. "
            "We recommend using instead the native Keras format, "
            "e.g. `model.save('my_model.keras')`.",
            stacklevel=2,
        )

    if save_format != "keras":
        # Legacy case: SavedModel or HDF5; extra kwargs are forwarded.
        return legacy_sm_saving_lib.save_model(
            model,
            filepath,
            overwrite=overwrite,
            save_format=save_format,
            **kwargs,
        )

    # Reject unsupported arguments *before* the interactive overwrite prompt,
    # so the user is never asked to confirm an overwrite for a call that is
    # going to fail anyway.
    if kwargs:
        raise ValueError(
            "The following argument(s) are not supported "
            f"with the native Keras format: {list(kwargs.keys())}"
        )

    # If file exists and should not be overwritten, ask the user first.
    try:
        exists = os.path.exists(filepath)
    except TypeError:
        # `filepath` is not something `os.path.exists` accepts.
        exists = False
    if exists and not overwrite:
        proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            return
    saving_lib.save_model(model, filepath)

156 

157 

@keras_export("keras.saving.load_model", "keras.models.load_model")
def load_model(
    filepath, custom_objects=None, compile=True, safe_mode=True, **kwargs
):
    """Loads a model saved via `model.save()`.

    A local path ending in `.keras` that is a zip archive is loaded with the
    native Keras loader. A remote path (per `saving_lib.is_remote_path`) that
    is not a directory is first copied into a local temporary directory; if
    that copy is a zip archive it is loaded with the native loader as well.
    Everything else is handed to the legacy SavedModel/HDF5 loader.

    Args:
        filepath: `str` or `pathlib.Path` object, path to the saved model file.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.
        compile: Boolean, whether to compile the model after loading.
        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
            When `safe_mode=False`, loading an object has the potential to
            trigger arbitrary code execution. This argument is only
            applicable to the Keras v3 model format. Defaults to True.

    SavedModel format arguments:
        options: Only applies to SavedModel format.
            Optional `tf.saved_model.LoadOptions` object that specifies
            SavedModel loading options.

    Returns:
        A Keras model instance. If the original model was compiled,
        and the argument `compile=True` is set, then the returned model
        will be compiled. Otherwise, the model will be left uncompiled.

    Raises:
        ValueError: If extra keyword arguments are given while loading a
            native Keras (`.keras` zip) archive.

    Example:

    ```python
    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(5, input_shape=(3,)),
        tf.keras.layers.Softmax()])
    model.save("model.keras")
    loaded_model = tf.keras.saving.load_model("model.keras")
    x = tf.random.uniform((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that the model variables may have different name values
    (`var.name` property, e.g. `"dense_1/kernel:0"`) after being reloaded.
    It is recommended that you use layer attributes to
    access specific variables, e.g. `model.get_layer("dense_1").kernel`.
    """
    has_keras_suffix = str(filepath).endswith(".keras")
    is_keras_zip = has_keras_suffix and zipfile.is_zipfile(filepath)

    # Remote single files cannot be probed in place, so mirror them into a
    # local temp dir and check whether the copy is a zip archive.
    needs_local_copy = (
        saving_lib.is_remote_path(filepath)
        and not tf.io.gfile.isdir(filepath)
        and not is_keras_zip
    )
    if needs_local_copy:
        temp_target = os.path.join(
            saving_lib.get_temp_dir(), os.path.basename(filepath)
        )
        tf.io.gfile.copy(filepath, temp_target, overwrite=True)
        if zipfile.is_zipfile(temp_target):
            filepath, is_keras_zip = temp_target, True

    if not is_keras_zip:
        # Legacy case: SavedModel directory or HDF5 file.
        return legacy_sm_saving_lib.load_model(
            filepath, custom_objects=custom_objects, compile=compile, **kwargs
        )

    if kwargs:
        raise ValueError(
            "The following argument(s) are not supported "
            f"with the native Keras format: {list(kwargs.keys())}"
        )
    return saving_lib.load_model(
        filepath,
        custom_objects=custom_objects,
        compile=compile,
        safe_mode=safe_mode,
    )

241 

242 

def save_weights(model, filepath, overwrite=True, **kwargs):
    """Saves only the weights of `model` to `filepath`.

    Paths ending in `.weights.h5` go through `saving_lib.save_weights_only`.
    On that path any extra `kwargs` are ignored, and an existing file is only
    replaced after interactive confirmation when `overwrite` is False. Every
    other path is delegated to the legacy saver, with `overwrite` and
    `kwargs` forwarded.

    Args:
        model: Keras model whose weights are saved.
        filepath: Destination path.
        overwrite: Whether an existing file may be replaced without asking.
        **kwargs: Extra arguments for the legacy weights saver.
    """
    if not str(filepath).endswith(".weights.h5"):
        # Legacy path: `overwrite` is forwarded to the legacy saver.
        legacy_sm_saving_lib.save_weights(
            model, filepath, overwrite=overwrite, **kwargs
        )
        return

    try:
        target_exists = os.path.exists(filepath)
    except TypeError:
        # `filepath` is not something `os.path.exists` accepts; treat as absent.
        target_exists = False

    if target_exists and not overwrite:
        if not io_utils.ask_to_proceed_with_overwrite(filepath):
            return
    saving_lib.save_weights_only(model, filepath)

259 

260 

def load_weights(model, filepath, skip_mismatch=False, **kwargs):
    """Loads weights into `model` from `filepath`.

    A `.keras` path that is a zip archive, or any `.weights.h5` path, is read
    with `saving_lib.load_weights_only`; extra `kwargs` are ignored there and
    `None` is returned. Anything else is delegated to the legacy loader, whose
    result is returned.

    Args:
        model: Keras model receiving the weights.
        filepath: Source path.
        skip_mismatch: Forwarded to the underlying loader.
        **kwargs: Extra arguments for the legacy weights loader.
    """
    path_str = str(filepath)
    is_keras_archive = path_str.endswith(".keras") and zipfile.is_zipfile(
        filepath
    )
    if is_keras_archive or path_str.endswith(".weights.h5"):
        saving_lib.load_weights_only(
            model, filepath, skip_mismatch=skip_mismatch
        )
        return None

    return legacy_sm_saving_lib.load_weights(
        model, filepath, skip_mismatch=skip_mismatch, **kwargs
    )

274 

275 

def get_save_format(filepath, save_format):
    """Resolves the effective serialization format used by `save_model`.

    Args:
        filepath: Destination path (`str`, `pathlib.Path`) or an open
            `h5py.File`.
        save_format: Requested format, or a falsy value to infer it from
            `filepath`. Accepted values: `"keras_v3"`, `"keras"`, `"h5"`,
            `"hdf5"`, `"tf"`, `"tensorflow"`.

    Returns:
        One of `"keras"`, `"h5"` or `"tf"`.

    Raises:
        ValueError: If `save_format` is set but not one of the accepted values.
    """
    if save_format:
        if save_format == "keras_v3":
            return "keras"
        if save_format == "keras":
            # Native format only when v3 saving is enabled; else fall back.
            return "keras" if saving_lib.saving_v3_enabled() else "h5"
        if save_format in ("h5", "hdf5"):
            return "h5"
        if save_format in ("tf", "tensorflow"):
            return "tf"

        raise ValueError(
            "Unknown `save_format` argument. Expected one of "
            "'keras', 'tf', or 'h5'. "
            f"Received: save_format={save_format}"
        )

    # No save format specified: infer from filepath.
    path_str = str(filepath)
    if path_str.endswith(".keras"):
        return "keras" if saving_lib.saving_v3_enabled() else "h5"

    if path_str.endswith((".h5", ".hdf5")):
        return "h5"

    if h5py is not None and isinstance(filepath, h5py.File):
        return "h5"

    # No recognizable file format: default to TF in TF2 and h5 in TF1.
    return "tf" if tf.__internal__.tf2.enabled() else "h5"

316