Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saving_utils.py: 19%

147 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""

import collections
import copy
import os

from tensorflow.python.eager import def_function
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest



def extract_model_metrics(model):
  """Convert metrics from a Keras model's `compile` API to a dictionary.

  This is used for converting Keras models to Estimators and SavedModels.

  Args:
    model: A `tf.keras.Model` object.

  Returns:
    Dictionary mapping metric names to metric instances. May return `None` if
    the model does not contain any metrics.
  """
  if getattr(model, '_compile_metrics', None):
    # TODO(psv/kathywu): use this implementation in model to estimator flow.
    # We are not using model.metrics here because we want to exclude the
    # metrics added using the `add_metric` API.
    return {m.name: m for m in model._compile_metric_functions}  # pylint: disable=protected-access
  return None


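# Illustrative sketch (not part of the original module): how
# `extract_model_metrics` might be called after `compile`. It assumes
# `import tensorflow as tf`; the `Sequential` model, layer sizes, and metric
# names below are made up for demonstration only.
#
#   import tensorflow as tf
#   model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
#   model.compile(optimizer='sgd', loss='mse', metrics=['mae'])
#   compile_metrics = extract_model_metrics(model)
#   # Either a dict mapping metric names to metric objects, or None if the
#   # model tracks no compile metrics.

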

def model_input_signature(model, keep_original_batch_size=False):
  """Inspect model to get its input signature.

  The model's input signature is a list with a single (possibly-nested) object.
  This is due to the Keras-enforced restriction that tensor inputs must be
  passed in as the first argument.

  For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
  will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]

  Args:
    model: Keras Model object.
    keep_original_batch_size: A boolean indicating whether we want to keep
      using the original batch size or set it to None. Default is `False`,
      which means that the batch dim of the returned input signature will
      always be set to `None`.

  Returns:
    A list containing either a single TensorSpec or an object with nested
    TensorSpecs. This list does not contain the `training` argument.
  """
  input_specs = model._get_save_spec(dynamic_batch=not keep_original_batch_size)  # pylint: disable=protected-access
  if input_specs is None:
    return None
  input_specs = _enforce_names_consistency(input_specs)
  # Return a list with a single element as the model's input signature.
  if isinstance(input_specs,
                collections.abc.Sequence) and len(input_specs) == 1:
    # Note that the isinstance check filters out single-element dictionaries,
    # which should also be wrapped as a single-element list.
    return input_specs
  else:
    return [input_specs]


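# Illustrative sketch (not part of the original module): the shape of the
# value `model_input_signature` returns. Assumes `import tensorflow as tf`;
# the functional model below is made up to show the single-element-list
# wrapping.
#
#   import tensorflow as tf
#   inp = tf.keras.Input(shape=(28, 28), name='images')
#   model = tf.keras.Model(inp, tf.keras.layers.Flatten()(inp))
#   sig = model_input_signature(model)
#   # With the default keep_original_batch_size=False the batch dimension is
#   # dynamic, e.g. [TensorSpec(shape=(None, 28, 28), dtype=tf.float32,
#   # name='images')].

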

def raise_model_input_error(model):
  raise ValueError(
      'Model {} cannot be saved because the input shapes have not been '
      'set. Usually, input shapes are automatically determined from calling'
      ' `.fit()` or `.predict()`. To manually set the shapes, call '
      '`model.build(input_shape)`.'.format(model))



def trace_model_call(model, input_signature=None):
  """Trace the model call to create a tf.function for exporting a Keras model.

  Args:
    model: A Keras model.
    input_signature: optional, a list of tf.TensorSpec objects specifying the
      inputs to the model.

  Returns:
    A tf.function wrapping the model's call function with input signatures set.

  Raises:
    ValueError: if input signature cannot be inferred from the model.
  """
  if input_signature is None:
    if isinstance(model.call, def_function.Function):
      input_signature = model.call.input_signature

  if input_signature is None:
    input_signature = model_input_signature(model)

  if input_signature is None:
    raise_model_input_error(model)

  @def_function.function(input_signature=input_signature)
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # When given a single input, Keras models will call the model on the tensor
    # rather than a list consisting of the single tensor.
    inputs = args[0] if len(input_signature) == 1 else list(args)

    with base_layer_utils.call_context().enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True):
      outputs = model(inputs, training=False)

    # The outputs always have to be returned as a flat dict.
    output_names = model.output_names  # Functional Model.
    if output_names is None:  # Subclassed Model.
      from tensorflow.python.keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      output_names = compile_utils.create_pseudo_output_names(outputs)
    outputs = nest.flatten(outputs)
    return {name: output for name, output in zip(output_names, outputs)}

  return _wrapped_model


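# Illustrative sketch (not part of the original module): exporting a built
# model through the traced call. Assumes `import tensorflow as tf`; the model
# definition and export directory below are made up for demonstration only.
#
#   import tensorflow as tf
#   model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
#   serving_fn = trace_model_call(model)
#   concrete_fn = serving_fn.get_concrete_function()
#   # The traced function returns a flat dict keyed by output name, e.g.
#   #   concrete_fn(tf.zeros((1, 3))) -> {'dense': <tf.Tensor ...>}
#   tf.saved_model.save(model, '/tmp/exported_model', signatures=concrete_fn)

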

def model_metadata(model, include_optimizer=True, require_config=True):
  """Returns a dictionary containing the model metadata."""
  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.optimizer_v2 import optimizer_v2  # pylint: disable=g-import-not-at-top

  model_config = {'class_name': model.__class__.__name__}
  try:
    model_config['config'] = model.get_config()
  except NotImplementedError as e:
    if require_config:
      raise e

  metadata = dict(
      keras_version=str(keras_version),
      backend=K.backend(),
      model_config=model_config)
  if model.optimizer and include_optimizer:
    if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
      logging.warning(
          'TensorFlow optimizers do not '
          'make it possible to access '
          'optimizer attributes or optimizer state '
          'after instantiation. '
          'As a result, we cannot save the optimizer '
          'as part of the model save file. '
          'You will have to compile your model again after loading it. '
          'Prefer using a Keras optimizer instead '
          '(see keras.io/optimizers).')
    elif model._compile_was_called:  # pylint: disable=protected-access
      training_config = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
      training_config.pop('optimizer', None)  # Handled separately.
      metadata['training_config'] = _serialize_nested_config(training_config)
      if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
        raise NotImplementedError(
            'As of now, Optimizers loaded from SavedModel cannot be saved. '
            'If you\'re calling `model.save` or `tf.keras.models.save_model`,'
            ' please set the `include_optimizer` option to `False`. For '
            '`tf.saved_model.save`, delete the optimizer from the model.')
      else:
        optimizer_config = {
            'class_name':
                generic_utils.get_registered_name(model.optimizer.__class__),
            'config':
                model.optimizer.get_config()
        }
        metadata['training_config']['optimizer_config'] = optimizer_config
  return metadata


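# Illustrative sketch (not part of the original module): the rough shape of
# the metadata dict for a compiled model. The keys follow the code above; the
# concrete values shown are rough placeholders, not guaranteed output.
#
#   import tensorflow as tf
#   model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
#   model.compile('adam', loss='mse')
#   meta = model_metadata(model)
#   # {
#   #   'keras_version': '2.x.y',
#   #   'backend': 'tensorflow',
#   #   'model_config': {'class_name': 'Sequential', 'config': {...}},
#   #   'training_config': {..., 'optimizer_config': {'class_name': 'Adam',
#   #                                                 'config': {...}}},
#   # }

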

def should_overwrite(filepath, overwrite):
  """Returns whether the filepath should be overwritten."""
  # If file exists and should not be overwritten.
  if not overwrite and os.path.isfile(filepath):
    return ask_to_proceed_with_overwrite(filepath)
  return True


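# Illustrative sketch (not part of the original module): `should_overwrite`
# only prompts the user when the target file already exists and `overwrite`
# is False; otherwise it returns True immediately. The path below is made up
# for demonstration only.
#
#   if should_overwrite('/tmp/model.h5', overwrite=False):
#     ...  # safe to (re)write /tmp/model.h5

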

def compile_args_from_training_config(training_config, custom_objects=None):
  """Return model.compile arguments from training config."""
  if custom_objects is None:
    custom_objects = {}

  with generic_utils.CustomObjectScope(custom_objects):
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config)

    # Recover losses.
    loss = None
    loss_config = training_config.get('loss', None)
    if loss_config is not None:
      loss = _deserialize_nested_config(losses.deserialize, loss_config)

    # Recover metrics.
    metrics = None
    metrics_config = training_config.get('metrics', None)
    if metrics_config is not None:
      metrics = _deserialize_nested_config(_deserialize_metric, metrics_config)

    # Recover weighted metrics.
    weighted_metrics = None
    weighted_metrics_config = training_config.get('weighted_metrics', None)
    if weighted_metrics_config is not None:
      weighted_metrics = _deserialize_nested_config(_deserialize_metric,
                                                    weighted_metrics_config)

    sample_weight_mode = training_config['sample_weight_mode'] if hasattr(
        training_config, 'sample_weight_mode') else None
    loss_weights = training_config['loss_weights']

  return dict(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      weighted_metrics=weighted_metrics,
      loss_weights=loss_weights,
      sample_weight_mode=sample_weight_mode)


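# Illustrative sketch (not part of the original module): restoring compile
# state from a saved training config, for example one read back from an HDF5
# attribute. `saved_training_config` and `my_loss` are hypothetical names; the
# dict is assumed to carry the keys consumed above ('optimizer_config',
# 'loss', 'metrics', 'weighted_metrics', 'loss_weights', ...).
#
#   compile_args = compile_args_from_training_config(
#       saved_training_config, custom_objects={'my_loss': my_loss})
#   model.compile(**compile_args)

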

def _deserialize_nested_config(deserialize_fn, config):
  """Deserializes arbitrary Keras `config` using `deserialize_fn`."""

  def _is_single_object(obj):
    if isinstance(obj, dict) and 'class_name' in obj:
      return True  # Serialized Keras object.
    if isinstance(obj, str):
      return True  # Serialized function or string.
    return False

  if config is None:
    return None
  if _is_single_object(config):
    return deserialize_fn(config)
  elif isinstance(config, dict):
    return {
        k: _deserialize_nested_config(deserialize_fn, v)
        for k, v in config.items()
    }
  elif isinstance(config, (tuple, list)):
    return [_deserialize_nested_config(deserialize_fn, obj) for obj in config]

  raise ValueError('Saved configuration not understood.')


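# Illustrative sketch (not part of the original module): the recursion above
# mirrors the config structure, so a dict of lists of serialized losses comes
# back as a dict of lists of loss objects. The config literal below is made up
# for demonstration only.
#
#   loss_config = {'out_a': 'mse',
#                  'out_b': [{'class_name': 'BinaryCrossentropy',
#                             'config': {}}]}
#   loss = _deserialize_nested_config(losses.deserialize, loss_config)
#   # -> {'out_a': <mean_squared_error fn>, 'out_b': [<BinaryCrossentropy>]}

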

def _serialize_nested_config(config):
  """Serializes a nested structure of Keras objects."""

  def _serialize_fn(obj):
    if callable(obj):
      return generic_utils.serialize_keras_object(obj)
    return obj

  return nest.map_structure(_serialize_fn, config)


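# Illustrative sketch (not part of the original module): callables inside the
# structure are serialized, everything else passes through unchanged. The
# inputs below are made up for demonstration only.
#
#   _serialize_nested_config({'loss': losses.MeanSquaredError(), 'weight': 0.5})
#   # -> {'loss': {'class_name': 'MeanSquaredError', 'config': {...}},
#   #     'weight': 0.5}

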

def _deserialize_metric(metric_config):
  """Deserialize metrics, leaving special strings untouched."""
  from tensorflow.python.keras import metrics as metrics_module  # pylint:disable=g-import-not-at-top
  if metric_config in ['accuracy', 'acc', 'crossentropy', 'ce']:
    # Do not deserialize accuracy and cross-entropy strings as we have special
    # case handling for these in compile, based on model output shape.
    return metric_config
  return metrics_module.deserialize(metric_config)


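# Illustrative sketch (not part of the original module): strings that get
# special handling in `compile` pass through untouched, while other configs
# are resolved to metric objects. The inputs below are made up for
# demonstration only.
#
#   _deserialize_metric('accuracy')   # -> 'accuracy' (left for compile to map)
#   _deserialize_metric({'class_name': 'AUC', 'config': {}})  # -> AUC instance

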

def _enforce_names_consistency(specs):
  """Enforces that either all specs have names or none do."""

  def _has_name(spec):
    return hasattr(spec, 'name') and spec.name is not None

  def _clear_name(spec):
    spec = copy.deepcopy(spec)
    if hasattr(spec, 'name'):
      spec._name = None  # pylint:disable=protected-access
    return spec

  flat_specs = nest.flatten(specs)
  name_inconsistency = (
      any(_has_name(s) for s in flat_specs) and
      not all(_has_name(s) for s in flat_specs))

  if name_inconsistency:
    specs = nest.map_structure(_clear_name, specs)
  return specs


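# Illustrative sketch (not part of the original module): if only some specs
# carry names, all names are cleared so the saved signature stays consistent.
# Assumes `import tensorflow as tf`; the TensorSpec values below are made up
# for demonstration only.
#
#   import tensorflow as tf
#   specs = [tf.TensorSpec((None, 4), tf.float32, name='features'),
#            tf.TensorSpec((None, 1), tf.float32)]        # unnamed
#   _enforce_names_consistency(specs)
#   # -> both specs come back with name=None

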

def try_build_compiled_arguments(model):
  if (not version_utils.is_v1_layer_or_model(model) and
      model.outputs is not None):
    try:
      if not model.compiled_loss.built:
        model.compiled_loss.build(model.outputs)
      if not model.compiled_metrics.built:
        model.compiled_metrics.build(model.outputs, model.outputs)
    except:  # pylint: disable=bare-except
      logging.warning(
          'Compiled the loaded model, but the compiled metrics have yet to '
          'be built. `model.compile_metrics` will be empty until you train '
          'or evaluate the model.')


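# Illustrative sketch (not part of the original module): after a compiled
# model is restored, this helper eagerly builds the compiled loss/metric
# containers when the model's outputs are known. Assumes
# `import tensorflow as tf`; the saved-model path below is made up.
#
#   import tensorflow as tf
#   restored = tf.keras.models.load_model('/tmp/saved_model')
#   try_build_compiled_arguments(restored)  # no-op for V1 layers/models

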

def is_hdf5_filepath(filepath):
  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
          filepath.endswith('.hdf5'))

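
# Illustrative sketch (not part of the original module): the check is purely
# suffix-based, so the example paths below are made up for demonstration only.
#
#   is_hdf5_filepath('weights.h5')            # True
#   is_hdf5_filepath('model.keras')           # True (treated as HDF5 here)
#   is_hdf5_filepath('/tmp/exported_model')   # False (SavedModel directory)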