# tensorflow/python/keras/engine/training_eager_v1.py

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras training and evaluation routines for eager execution."""
# pylint: disable=protected-access

import numpy as np

from tensorflow.python.eager.backprop import GradientTape
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def _eager_loss_fn(outputs, targets, loss_fn, output_name):
  with backend.name_scope(output_name + '_loss'):
    loss = loss_fn(targets, outputs)
  return loss
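
# Example (illustrative, not part of the original source): given hypothetical
# tensors `preds` and `labels` and any callable `loss_fn(targets, outputs)`,
#   loss = _eager_loss_fn(preds, labels, loss_fn, 'dense')
# evaluates the loss under the name scope 'dense_loss'.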


def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None):
  """Calculates the metrics for each output of the given model.

  Args:
    model: The model on which metrics are being calculated.
    outputs: The outputs of the given model.
    targets: The predictions or targets of the given model.
    sample_weights: Optional list of sample weights for each output.
    masks: Optional list of masks for each output.

  Returns:
    The metric results for each output of the model.
  """
  outputs = nest.flatten(outputs)
  targets = nest.flatten(targets)
  # Invoke all (weighted and unweighted) metrics.
  metric_results = []
  if targets:
    # Insert None values corresponding to the targets that need to be skipped
    # on the model.
    if len(model._targets) != len(targets):
      new_targets = [
          None if t is None else targets.pop(0) for t in model._targets
      ]
      targets = new_targets

    metric_results = model._handle_metrics(
        outputs,
        targets=targets,
        sample_weights=sample_weights,
        masks=masks,
        return_weighted_and_unweighted_metrics=True,
        skip_target_masks=model._prepare_skip_target_masks())

  # Add metric results from the `add_metric` metrics.
  metric_results.extend([
      m.result()
      for m in model.metrics
      if m not in model._compile_metric_functions
  ])
  return metric_results
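

# Sketch (illustrative, not part of the original module): driving the metrics
# helper on one forward pass. `model` is assumed to be a v1-compiled Keras
# model; `x` and `y` are hypothetical batch arrays.
def _example_metrics(model, x, y):
  outs = nest.flatten(model(x))  # Flatten to match this module's conventions.
  return _eager_metrics_fn(model, outs, y)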


def _model_loss(model,
                inputs,
                targets,
                output_loss_metrics=None,
                sample_weights=None,
                training=False):
  """Calculates the loss for a given model.

  Args:
    model: The model on which metrics are being calculated.
    inputs: Either a dictionary of inputs to the model or a list of input
      arrays.
    targets: List of target arrays.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.
    sample_weights: Optional list of sample weight arrays.
    training: Whether the model should be run in inference or training mode.

  Returns:
    Returns the model output, total loss, loss value calculated using the
    specified loss function, and masks for each output. The total loss
    includes regularization losses and applies masking and sample weighting
    to the loss value.
  """
  # TODO(psv): Dedup code here with graph mode prepare_total_loss() fn.
  # Used to keep track of the total loss value (stateless).
  # e.g., total_loss = loss_weight_1 * output_1_loss_fn(...) +
  #                    loss_weight_2 * output_2_loss_fn(...) +
  #                    layer losses.
  total_loss = 0
  kwargs = {}
  if model._expects_training_arg:
    kwargs['training'] = training
  if len(inputs) == 1 and not isinstance(inputs, dict):
    inputs = inputs[0]

  # Allow mixed `NumPy` and `EagerTensor` input here.
  if any(
      isinstance(input_t, (np.ndarray, float, int))
      for input_t in nest.flatten(inputs)):
    inputs = nest.map_structure(
        tensor_conversion.convert_to_tensor_v2_with_dispatch, inputs
    )

  outs = model(inputs, **kwargs)
  outs = nest.flatten(outs)

  if targets:
    targets = training_utils_v1.cast_if_floating_dtype_and_mismatch(
        targets, outs)
  # TODO(sallymatson/psv): check if we should do same mismatch fix for weights
  if sample_weights:
    new_sample_weights = []
    for val in sample_weights:
      if val is not None:
        new_sample_weights.append(training_utils_v1.cast_if_floating_dtype(
            tensor_conversion.convert_to_tensor_v2_with_dispatch(val)))
      else:
        new_sample_weights.append(None)
    sample_weights = new_sample_weights

  masks = [getattr(t, '_keras_mask', None) for t in outs]
  targets = nest.flatten(targets)

  # Used to keep track of individual output losses.
  output_losses = []

  with backend.name_scope('loss'):
    loss_fns = [
        loss_fn for loss_fn in model.loss_functions if loss_fn is not None
    ]
    custom_losses = model.losses  # Regularization losses

    if not loss_fns and not custom_losses:
      if training:
        raise ValueError('The model cannot be trained '
                         'because it has no loss to optimize.')
      else:
        raise ValueError('The model cannot be evaluated '
                         'because it has no loss to compute.')

    for i, loss_fn in enumerate(loss_fns):
      weights = sample_weights[i] if sample_weights else None
      mask = masks[i]
      with backend.name_scope(model.output_names[i] + '_loss'):
        if mask is not None:
          mask = math_ops.cast(mask, outs[i].dtype)
          # Update weights with mask.
          if weights is None:
            weights = mask
          else:
            # Update dimensions of weights to match with mask if possible.
            weights = math_ops.cast(weights, outs[i].dtype)
            mask, _, weights = (
                losses_utils.squeeze_or_expand_dimensions(
                    mask, sample_weight=weights))
            weights *= mask

        if hasattr(loss_fn, 'reduction'):
          per_sample_losses = loss_fn.call(targets[i], outs[i])
          weighted_losses = losses_utils.compute_weighted_loss(
              per_sample_losses,
              sample_weight=weights,
              reduction=losses_utils.ReductionV2.NONE)
          loss_reduction = loss_fn.reduction

          # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all
          # compile use cases.
          if loss_reduction == losses_utils.ReductionV2.AUTO:
            loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

          # Compute the stateless loss value.
          output_loss = losses_utils.reduce_weighted_loss(
              weighted_losses, reduction=loss_reduction)
        else:
          # Compute the stateless loss value for a custom loss class.
          # Here we assume that the class takes care of loss reduction
          # because if this class returns a vector value we cannot
          # differentiate between the use case where a custom optimizer
          # expects a vector loss value vs. an unreduced per-sample loss value.
          output_loss = loss_fn(targets[i], outs[i], sample_weight=weights)
          loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

        # If the number of outputs is 1 then we don't append the loss metric
        # associated with each model output. When there are multiple outputs
        # associated with a model, each output's loss is calculated and
        # returned as part of the loss_metrics.
        if len(model.outputs) > 1:
          # Keep track of the stateful output loss result.
          output_losses.append(output_loss_metrics[i](output_loss))

        # Scale output loss for distribution. For custom losses we assume
        # reduction was mean.
        if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
          output_loss = losses_utils.scale_loss_for_distribution(output_loss)
        total_loss += model._loss_weights_list[i] * output_loss

    # Add regularization losses.
    if custom_losses:
      total_loss += losses_utils.scale_loss_for_distribution(
          math_ops.add_n(custom_losses))
  return outs, total_loss, output_losses, masks
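

# Sketch (illustrative, not part of the original module): turning the
# stateless total loss computed above into gradients. This mirrors the core
# of `_process_single_batch` below, minus loss scaling; `model`, `x` and `y`
# are hypothetical.
def _example_gradients(model, x, y):
  with GradientTape() as tape:
    # Record the forward pass so the loss can be differentiated.
    _, total_loss, _, _ = _model_loss(model, x, y, training=True)
  return tape.gradient(total_loss, model.trainable_weights)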


def _process_single_batch(model,
                          inputs,
                          targets,
                          output_loss_metrics=None,
                          sample_weights=None,
                          training=False):
  """Calculate the loss and gradient for one input batch.

  The model weights are updated if training is set to True.

  Args:
    model: Model whose loss has to be calculated.
    inputs: List of input arrays.
    targets: List of target arrays.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.
    sample_weights: Optional list of sample weight arrays.
    training: Boolean indicating whether the model weights should be updated.
      'fit' methods set this to True while 'evaluate' methods set it to False.

  Returns:
    Output of the model, total loss, and the loss and mask
    associated with each output.

  Raises:
    ValueError: If the model has no loss to optimize.
  """
  with backend.eager_learning_phase_scope(1 if training else 0), \
      training_utils.RespectCompiledTrainableState(model):
    with GradientTape() as tape:
      outs, total_loss, output_losses, masks = (
          _model_loss(
              model,
              inputs,
              targets,
              output_loss_metrics=output_loss_metrics,
              sample_weights=sample_weights,
              training=training))
      if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
        scaled_total_loss = model.optimizer.get_scaled_loss(total_loss)
      else:
        scaled_total_loss = total_loss
    if training:
      trainable_weights = model.trainable_weights
      if trainable_weights:
        # TODO(tanzheny) b/132690565: Provide mechanism for user to override
        # model.train_on_batch.
        if hasattr(model, '_backwards'):
          model._backwards(tape, scaled_total_loss)
        else:
          grads = tape.gradient(scaled_total_loss, trainable_weights)
          if isinstance(model.optimizer,
                        loss_scale_optimizer.LossScaleOptimizer):
            grads = model.optimizer.get_unscaled_gradients(grads)
          model.optimizer.apply_gradients(zip(grads, trainable_weights))
      else:
        logging.warning('The list of trainable weights is empty. Make sure '
                        'that you are not setting model.trainable to False '
                        'before compiling the model.')
    return outs, total_loss, output_losses, masks
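

# Sketch (illustrative, not part of the original module): a single evaluation
# step. With `training=False` the loss is still computed under the tape, but
# no gradient update is applied; `model`, `x` and `y` are hypothetical.
def _example_eval_step(model, x, y):
  _, total_loss, _, _ = _process_single_batch(model, x, y, training=False)
  return total_loss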


def train_on_batch(model,
                   inputs,
                   targets,
                   sample_weights=None,
                   output_loss_metrics=None):
  """Calculates the loss and gradient updates for one input batch.

  Args:
    model: Model whose loss has to be calculated.
    inputs: Input batch data.
    targets: Target batch data.
    sample_weights: Sample weight batch data.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.

  Returns:
    Dict with three items:
      'total_loss': list with a single tensor for overall loss,
      'output_losses': list of tensors for loss corresponding to each of the
        model outputs. Could be an empty list when the model has only one
        output.
      'metrics': list of tensors for the metrics specified.
  """
  inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)
  outs, total_loss, output_losses, masks = (
      _process_single_batch(
          model,
          inputs,
          targets,
          sample_weights=sample_weights,
          training=True,
          output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)
  return {'total_loss': total_loss,
          'output_losses': output_losses,
          'metrics': metrics_results}
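

# Usage sketch (illustrative, not part of the original module): a minimal
# training loop driving `train_on_batch` over hypothetical `(x, y)` batches
# on a v1-compiled model.
def _example_train_loop(model, batches):
  for step, (x_batch, y_batch) in enumerate(batches):
    results = train_on_batch(model, x_batch, y_batch)
    logging.info('step %d: total_loss=%s', step, results['total_loss'])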


def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
  """Calculates the loss for one input batch.

  Args:
    model: Model whose loss has to be calculated.
    inputs: Input batch data.
    targets: Target batch data.
    sample_weights: Sample weight batch data.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.

  Returns:
    Dict with three items:
      'total_loss': list with a single tensor for overall loss,
      'output_losses': list of tensors for loss corresponding to each of the
        model outputs. Could be an empty list when the model has only one
        output.
      'metrics': list of tensors for the metrics specified.
  """
  inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)

  with backend.eager_learning_phase_scope(0):
    outs, total_loss, output_losses, masks = (
        _model_loss(
            model,
            inputs,
            targets,
            sample_weights=sample_weights,
            training=False,
            output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)

  return {'total_loss': total_loss,
          'output_losses': output_losses,
          'metrics': metrics_results}
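

# Usage sketch (illustrative, not part of the original module): evaluating
# over hypothetical `(x, y)` batches and averaging the per-batch total loss
# returned by `test_on_batch`.
def _example_evaluate(model, batches):
  batch_losses = []
  for x_batch, y_batch in batches:
    results = test_on_batch(model, x_batch, y_batch)
    batch_losses.append(results['total_loss'][0])  # Single-element list.
  return np.mean(batch_losses)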