Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/compile_utils.py: 16%

376 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Utilities for `Model.compile`.""" 

17 

18 

19import copy 

20 

21import tensorflow.compat.v2 as tf 

22 

23from keras.src import losses as losses_mod 

24from keras.src import metrics as metrics_mod 

25from keras.src.saving import saving_lib 

26from keras.src.utils import generic_utils 

27from keras.src.utils import losses_utils 

28from keras.src.utils import tf_utils 

29 

30 

class Container:
    """Shared base for the loss / metric containers used by `Model.compile`.

    Subclasses must implement `_should_broadcast` and `_copy_object`.
    """

    def __init__(self, output_names=None):
        self._output_names = output_names

    def build(self, y_pred):
        """Fills in pseudo output names when none were supplied."""
        if self._output_names is not None:
            return
        # Subclassed Models have no output names, so names such as
        # 'output_1' are synthesized for use in `Metric` names.
        self._output_names = create_pseudo_output_names(y_pred)

    def _conform_to_outputs(self, outputs, struct):
        """Reshapes `struct` so that it mirrors the structure of `outputs`.

        Three conversions are applied, in order:

        (1) A dict is turned into a list of outputs via the output names.
        (2) Keys missing from a dict get `None` values.
        (3) A single, non-nested item is repeated for every output.

        Args:
          outputs: Model predictions.
          struct: Arbitrary nested structure (e.g. of labels, sample_weights,
            losses, or metrics).

        Returns:
          `struct` conformed to the structure of `outputs`.
        """
        conformed = map_to_output_names(outputs, self._output_names, struct)
        conformed = map_missing_dict_keys(outputs, conformed)
        # One object may be shared by all outputs.
        if not tf.nest.is_nested(conformed) and tf.nest.is_nested(outputs):
            return tf.nest.map_structure(lambda _: conformed, outputs)
        return conformed

    def _maybe_broadcast_to_outputs(self, outputs, objects):
        """Repeats a Loss / Metric spec for every output when appropriate.

        NOTE: Only intended for Metrics / Losses, never for y_true or
        sample_weight.

        Args:
          outputs: Model predictions.
          objects: Arbitrary nested structure (e.g. of losses or metrics).

        Returns:
          `objects` unchanged if the subclass declines to broadcast it,
          otherwise a structure matching `outputs` in which every leaf holds
          `objects` (copied per output when there are several outputs).
        """
        if not self._should_broadcast(objects):
            return objects

        # With several Model outputs each one gets its own copy so that the
        # Metric / Loss state is kept separate; with a single output the
        # user-supplied object itself is used.
        copy_per_output = len(tf.nest.flatten(outputs)) > 1

        def _entry_for_output(_):
            if not copy_per_output:
                return objects
            return tf.nest.map_structure(self._copy_object, objects)

        return tf.nest.map_structure(_entry_for_output, outputs)

    def _should_broadcast(self, objects):
        """Whether `objects` should be applied to every output (abstract)."""
        raise NotImplementedError

    def _copy_object(self, obj):
        """Returns a per-output copy of `obj` (abstract)."""
        raise NotImplementedError

102 

103 

class LossesContainer(Container):
    """A container class for losses passed to `Model.compile()`.

    Args:
      losses: Struct of loss function(s). See `Model.compile()` doc for more
        information.
      loss_weights: Weights of the losses contributions of different model
        outputs. See `Model.compile()` doc for more information.
      output_names: List of string. Per-output metric names.
      total_loss_mean: A `keras.metrics.Mean` instance that is used to track the
        mean of all losses (including compiled and regularization losses).
    """

    def __init__(
        self, losses, loss_weights=None, output_names=None, total_loss_mean=None
    ):
        super(LossesContainer, self).__init__(output_names=output_names)

        # Keep user-supplied values untouched for recompiling and serialization.
        self._user_losses = losses
        self._user_loss_weights = loss_weights

        # Working copies; `build()` replaces them with flat lists.
        self._losses = losses
        self._loss_weights = loss_weights
        self._per_output_metrics = None  # Per-output losses become metrics.

        # Mean of the total loss.
        self._total_loss_mean = total_loss_mean or metrics_mod.Mean(name="loss")
        self._built = False

    def get_config(self):
        """Returns a serializable config of the losses and the loss tracker.

        `None` entries (outputs without a loss) are omitted from `losses`.
        """
        # `self._losses` may still hold the raw user structure (a single
        # string, a dict keyed by output name, ...) if `build()` has not run
        # yet. Flatten into a local instead of reassigning the attribute:
        # overwriting it here would discard the dict keys / single-loss form
        # that a later `build()` needs in order to map losses onto outputs.
        flat_losses = tf.nest.flatten(self._losses)
        return {
            "losses": [
                saving_lib.serialize_keras_object(obj)
                for obj in flat_losses
                if obj is not None
            ],
            "total_loss_mean": saving_lib.serialize_keras_object(
                self._total_loss_mean
            ),
        }

    @classmethod
    def from_config(cls, config):
        """Returns the `LossesContainer` instance given the `config`.

        List values are deserialized element-wise; every other value is
        deserialized directly.
        """
        deserialized_config = {}
        for key, value in config.items():
            if isinstance(value, list):
                deserialized_config[key] = [
                    saving_lib.deserialize_keras_object(item) for item in value
                ]
            else:
                deserialized_config[key] = saving_lib.deserialize_keras_object(
                    value
                )
        return cls(**deserialized_config)

    @property
    def metrics(self):
        """Per-output loss metrics.

        Empty before `build()`; afterwards the total-loss `Mean` followed by
        every non-`None` per-output loss metric.
        """
        if not self._built:
            return []
        per_output_metrics = [
            metric_obj
            for metric_obj in tf.nest.flatten(self._per_output_metrics)
            if metric_obj is not None
        ]
        return [self._total_loss_mean] + per_output_metrics

    def build(self, y_pred):
        """One-time setup of loss objects.

        Broadcasts and conforms `losses` / `loss_weights` to the structure of
        `y_pred`, converts losses to `Loss` objects, flattens both into lists
        and creates the per-output loss metrics.
        """
        super(LossesContainer, self).build(y_pred)

        self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses)
        self._losses = self._conform_to_outputs(y_pred, self._losses)
        self._losses = tf.nest.map_structure(
            self._get_loss_object, self._losses
        )
        self._losses = tf.nest.flatten(self._losses)

        self._loss_weights = self._maybe_broadcast_to_outputs(
            y_pred, self._loss_weights
        )
        self._loss_weights = self._conform_to_outputs(
            y_pred, self._loss_weights
        )
        self._loss_weights = tf.nest.flatten(self._loss_weights)

        self._create_metrics()
        self._built = True

    @property
    def built(self):
        return self._built

    def _create_metrics(self):
        """Creates per-output loss metrics, but only for multi-output Models."""
        if len(self._output_names) == 1:
            self._per_output_metrics = [None]
        else:
            self._per_output_metrics = []
            for loss_obj, output_name in zip(self._losses, self._output_names):
                if loss_obj is None:
                    self._per_output_metrics.append(None)
                else:
                    self._per_output_metrics.append(
                        metrics_mod.Mean(output_name + "_loss")
                    )

    def __call__(
        self, y_true, y_pred, sample_weight=None, regularization_losses=None
    ):
        """Computes the overall loss.

        Also updates the total-loss `Mean` and any per-output loss metrics.

        Args:
          y_true: An arbitrary structure of Tensors representing the ground
            truth.
          y_pred: An arbitrary structure of Tensors representing a Model's
            outputs.
          sample_weight: An arbitrary structure of Tensors representing the
            per-sample loss weights. If one Tensor is passed, it is used for all
            losses. If multiple Tensors are passed, the structure should match
            `y_pred`.
          regularization_losses: Additional losses to be added to the total
            loss.

        Returns:
          The total loss as a `tf.Tensor`, or `None` if no loss results.
        """
        y_true = self._conform_to_outputs(y_pred, y_true)
        sample_weight = self._conform_to_outputs(y_pred, sample_weight)

        if not self._built:
            self.build(y_pred)

        y_pred = tf.nest.flatten(y_pred)
        y_true = tf.nest.flatten(y_true)
        sample_weight = tf.nest.flatten(sample_weight)

        loss_values = []  # Used for gradient calculation.
        total_loss_mean_values = []  # Used for loss metric calculation.
        batch_dim = None
        zip_args = (
            y_true,
            y_pred,
            sample_weight,
            self._losses,
            self._loss_weights,
            self._per_output_metrics,
        )
        for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
            if (
                y_t is None or loss_obj is None
            ):  # Ok to have no loss for an output.
                continue

            y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
            sw = losses_utils.apply_mask(y_p, sw, losses_utils.get_mask(y_p))
            loss_value = loss_obj(y_t, y_p, sample_weight=sw)

            total_loss_mean_value = loss_value
            # Correct for the `Mean` loss metrics counting each replica as a
            # batch.
            if loss_obj.reduction == losses_utils.ReductionV2.SUM:
                total_loss_mean_value *= (
                    tf.distribute.get_strategy().num_replicas_in_sync
                )

            # The batch size of the first output with a loss is used as the
            # weight of every metric update below.
            if batch_dim is None:
                if tf_utils.is_ragged(y_t):
                    batch_dim = y_t.nrows()
                else:
                    batch_dim = tf.shape(y_t)[0]

            # Per-output metrics track the loss before `loss_weight` is
            # applied.
            if metric_obj is not None:
                metric_obj.update_state(
                    total_loss_mean_value, sample_weight=batch_dim
                )

            if loss_weight is not None:
                loss_value *= loss_weight
                total_loss_mean_value *= loss_weight

            if (
                loss_obj.reduction
                == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
                or loss_obj.reduction == losses_utils.ReductionV2.AUTO
            ):
                loss_value = losses_utils.scale_loss_for_distribution(
                    loss_value
                )

            loss_values.append(loss_value)
            total_loss_mean_values.append(total_loss_mean_value)

        if regularization_losses:
            regularization_losses = losses_utils.cast_losses_to_common_dtype(
                regularization_losses
            )
            reg_loss = tf.add_n(regularization_losses)
            total_loss_mean_values.append(reg_loss)
            loss_values.append(
                losses_utils.scale_loss_for_distribution(reg_loss)
            )

        if loss_values:
            total_loss_mean_values = losses_utils.cast_losses_to_common_dtype(
                total_loss_mean_values
            )
            total_total_loss_mean_value = tf.add_n(total_loss_mean_values)
            self._total_loss_mean.update_state(
                total_total_loss_mean_value, sample_weight=batch_dim
            )

            loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
            total_loss = tf.add_n(loss_values)
            return total_loss
        else:
            return None

    def reset_state(self):
        """Resets the state of loss metrics (no-op before `build()`)."""
        if not self._built:
            return
        metrics = [self._total_loss_mean] + tf.nest.flatten(
            self._per_output_metrics
        )
        for metric_obj in metrics:
            if metric_obj is not None:
                metric_obj.reset_state()

    def _get_loss_object(self, loss):
        """Returns a `Loss` object.

        Converts the user-supplied loss to a `Loss` object. Also allows
        `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

        Args:
          loss: A string, function, or `Loss` object.

        Returns:
          A `Loss` object, or `None` if `loss` is `None`.

        Raises:
          ValueError: If `loss` does not resolve to a `Loss` and no name can
            be derived for it.
        """
        if loss is None:
            return None  # Ok to have no loss for an output.

        loss = losses_mod.get(loss)
        if not isinstance(loss, losses_mod.Loss):
            loss_name = get_custom_object_name(loss)
            if loss_name is None:
                raise ValueError(f"Loss should be a callable, received: {loss}")
            loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
        loss._allow_sum_over_batch_size = True
        return loss

    def _should_broadcast(self, obj):
        # Any single, non-nested loss spec applies to every output.
        return not tf.nest.is_nested(obj)

    def _copy_object(self, obj):
        return obj  # Losses don't need to be copied.

367 

368 

class MetricsContainer(Container):
    """A container class for metrics passed to `Model.compile`."""

    def __init__(
        self,
        metrics=None,
        weighted_metrics=None,
        output_names=None,
        from_serialized=False,
    ):
        """Initializes a container for metrics.

        Arguments:
          metrics: see the `metrics` argument from `tf.keras.Model.compile`.
          weighted_metrics: see the `weighted_metrics` argument from
            `tf.keras.Model.compile`.
          output_names: A list of strings of names of outputs for the model.
          from_serialized: Whether the model being compiled is from a serialized
            model. Used to avoid redundantly applying pre-processing renaming
            steps.

        Raises:
          ValueError: If the same `Metric` instance appears more than once
            across `metrics` and `weighted_metrics`.
        """
        super(MetricsContainer, self).__init__(output_names=output_names)

        self._check_duplicated_metrics(metrics, weighted_metrics)
        # Keep user-supplied values untouched for recompiling and serialization.
        self._user_metrics = metrics
        self._user_weighted_metrics = weighted_metrics

        # Working copies; `build()` replaces them with per-output lists.
        self._metrics = metrics
        self._weighted_metrics = weighted_metrics
        self._built = False

        self._from_serialized = from_serialized

    def _check_duplicated_metrics(self, metrics, weighted_metrics):
        """Raise error when user provided metrics have any duplications.

        Note that metrics are stateful containers: a metric instance shared
        between `metrics` and `weighted_metrics` would be updated twice and
        report the wrong value.

        Args:
          metrics: User provided metrics list.
          weighted_metrics: User provided weighted metrics list.

        Raises:
          ValueError, when duplicated metrics instance discovered in user
            provided metrics and weighted metrics.
        """
        seen = set()
        duplicated = []
        for x in tf.nest.flatten(metrics) + tf.nest.flatten(weighted_metrics):
            # We only check metrics object. The string and function objects
            # will be converted to unique Metric instance.
            if not isinstance(x, metrics_mod.Metric):
                continue
            if x in seen:
                duplicated.append(x)
            seen.add(x)

        if duplicated:
            raise ValueError(
                "Found duplicated metrics object in the user provided "
                "metrics and weighted metrics. This will cause the same "
                "metric object to be updated multiple times, and report "
                "wrong results. \n"
                f"Duplicated items: {duplicated}"
            )

    @property
    def metrics(self):
        """All metrics in this container (empty before `build()`)."""
        if not self._built:
            return []
        return self._metrics_in_order

    @property
    def unweighted_metrics(self):
        """Metrics in the container that should not be passed sample_weight.

        `None` before `build()`.
        """
        if not self._built:
            return None
        return tf.nest.flatten(self._metrics)

    @property
    def weighted_metrics(self):
        """Metrics in this container that should be passed `sample_weight`.

        `None` before `build()`.
        """
        if not self._built:
            return None
        return tf.nest.flatten(self._weighted_metrics)

    def build(self, y_pred, y_true):
        """One-time setup of metric objects.

        Broadcasts and conforms both metric structures to `y_pred`, converts
        every entry to a `Metric` object, flattens them to one list per
        output, optionally renames them, and caches the reporting order.
        """
        super(MetricsContainer, self).build(y_pred)

        self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
        self._metrics = self._conform_to_outputs(y_pred, self._metrics)

        self._weighted_metrics = self._maybe_broadcast_to_outputs(
            y_pred, self._weighted_metrics
        )
        self._weighted_metrics = self._conform_to_outputs(
            y_pred, self._weighted_metrics
        )

        # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
        y_pred = tf.__internal__.nest.list_to_tuple(y_pred)
        y_true = tf.__internal__.nest.list_to_tuple(y_true)
        self._metrics = tf.__internal__.nest.list_to_tuple(self._metrics)
        self._weighted_metrics = tf.__internal__.nest.list_to_tuple(
            self._weighted_metrics
        )

        # Convert to `Metric` objects, potentially disambiguating based on
        # output properties.
        self._metrics = tf.__internal__.nest.map_structure_up_to(
            y_pred, self._get_metric_objects, self._metrics, y_true, y_pred
        )
        self._weighted_metrics = tf.__internal__.nest.map_structure_up_to(
            y_pred,
            self._get_metric_objects,
            self._weighted_metrics,
            y_true,
            y_pred,
        )

        self._metrics = tf.__internal__.nest.flatten_up_to(
            y_pred, self._metrics, check_types=False
        )
        self._weighted_metrics = tf.__internal__.nest.flatten_up_to(
            y_pred, self._weighted_metrics, check_types=False
        )

        # Assumes metrics, weighted_metrics have been flattened up to outputs.
        #
        # If we are loading a model that has been already serialized, we do not
        # want to re-apply any pre-processing metric renaming steps.
        if not self._from_serialized:
            self._set_metric_names()
        self._create_ordered_metrics()
        self._built = True

    @property
    def built(self):
        return self._built

    def _set_metric_names(self):
        """Sets unique metric names.

        For multi-output models, the output name is prepended to each metric
        name. A weighted metric whose name would collide with an existing one
        gets a "weighted_" marker.

        Raises:
          ValueError: If two metrics still end up with the same name.
        """
        metric_names = set()
        is_multi_output = len(self._output_names) > 1
        zip_args = (self._output_names, self._metrics, self._weighted_metrics)
        for output_name, output_metrics, weighted_output_metrics in zip(
            *zip_args
        ):
            for m in output_metrics:
                if m is None:
                    continue
                if is_multi_output:
                    m._name = output_name + "_" + m._name
                if m._name in metric_names:
                    raise ValueError(
                        f"Found two metrics with the same name: {m._name}. "
                        "All the metrics added to the model need to have "
                        "unique names."
                    )
                metric_names.add(m._name)

            for wm in weighted_output_metrics:
                if wm is None:
                    continue
                if is_multi_output:
                    if output_name + "_" + wm._name in metric_names:
                        wm._name = output_name + "_weighted_" + wm._name
                    else:
                        wm._name = output_name + "_" + wm._name
                elif wm._name in metric_names:
                    wm._name = "weighted_" + wm._name

                if wm._name in metric_names:
                    raise ValueError(
                        "Found two weighted metrics with the same name: "
                        f"{wm._name}. All the metrics added to the model need "
                        "to have unique names."
                    )
                metric_names.add(wm._name)

    def _create_ordered_metrics(self):
        """Cache the flat order needed when return metrics, for backcompat.

        Per output: unweighted metrics first, then weighted ones, skipping
        `None` entries.
        """
        self._metrics_in_order = []
        for output_metrics, output_weighted_metrics in zip(
            self._metrics, self._weighted_metrics
        ):
            for m in tf.nest.flatten(output_metrics):
                if m is not None:
                    self._metrics_in_order.append(m)
            for wm in tf.nest.flatten(output_weighted_metrics):
                if wm is not None:
                    self._metrics_in_order.append(wm)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Updates the state of per-output metrics.

        Unweighted metrics receive only the output mask as `sample_weight`;
        weighted metrics receive `sample_weight` combined with that mask.

        Args:
          y_true: Labels, in any structure that conforms to `y_pred`.
          y_pred: Model outputs.
          sample_weight: Optional sample weights conforming to `y_pred`.
        """
        y_true = self._conform_to_outputs(y_pred, y_true)
        sample_weight = self._conform_to_outputs(y_pred, sample_weight)

        if not self._built:
            self.build(y_pred, y_true)

        y_pred = tf.nest.flatten(y_pred)
        y_true = tf.nest.flatten(y_true) if y_true is not None else []
        sample_weight = tf.nest.flatten(sample_weight)

        zip_args = (
            y_true,
            y_pred,
            sample_weight,
            self._metrics,
            self._weighted_metrics,
        )
        for y_t, y_p, sw, metric_objs, weighted_metric_objs in zip(*zip_args):
            # Ok to have no metrics for an output.
            if y_t is None or (
                all(m is None for m in metric_objs)
                and all(wm is None for wm in weighted_metric_objs)
            ):
                continue

            y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
            mask = losses_utils.get_mask(y_p)
            sw = losses_utils.apply_mask(y_p, sw, mask)

            for metric_obj in metric_objs:
                if metric_obj is None:
                    continue
                metric_obj.update_state(y_t, y_p, sample_weight=mask)

            for weighted_metric_obj in weighted_metric_objs:
                if weighted_metric_obj is None:
                    continue
                weighted_metric_obj.update_state(y_t, y_p, sample_weight=sw)

    def reset_state(self):
        """Resets the state of all `Metric`s in this container."""
        if self._built:
            metrics = self._metrics_in_order
        else:
            # If the user supplied `Metric` objects directly, we should
            # reset those. This could also contain `str`s or `function`s
            # though.
            metrics = tf.nest.flatten(self._user_metrics) + tf.nest.flatten(
                self._user_weighted_metrics
            )

        for metric_obj in metrics:
            if isinstance(metric_obj, metrics_mod.Metric):
                metric_obj.reset_state()

    def _get_metric_objects(self, metrics, y_t, y_p):
        """Convert user-supplied metrics to `Metric` objects."""
        metrics = tf.nest.flatten(metrics)
        return [self._get_metric_object(m, y_t, y_p) for m in metrics]

    def _get_metric_object(self, metric, y_t, y_p):
        """Converts user-supplied metric to a `Metric` object.

        The aliases "accuracy"/"acc" and "crossentropy"/"ce" resolve to the
        binary, sparse categorical, or categorical variant based on the
        static shapes of `y_t` and `y_p`. Non-`Metric` results are wrapped in
        a `MeanMetricWrapper`.

        Args:
          metric: A string, function, or `Metric` object.
          y_t: Sample of label.
          y_p: Sample of output.

        Returns:
          A `Metric` object, or `None` if `metric` is `None`.

        Raises:
          ValueError: If `metric` does not resolve to a `Metric` and no name
            can be derived for it.
        """
        if metric is None:
            return None  # Ok to have no metric for an output.

        # Convenience feature for selecting b/t binary, categorical,
        # and sparse categorical.
        metric_key = str(metric).lower()
        if metric_key not in ["accuracy", "acc", "crossentropy", "ce"]:
            metric_obj = metrics_mod.get(metric)
        else:
            y_t_rank = len(y_t.shape.as_list())
            y_p_rank = len(y_p.shape.as_list())
            y_t_last_dim = y_t.shape.as_list()[-1]
            y_p_last_dim = y_p.shape.as_list()[-1]

            is_binary = y_p_last_dim == 1
            # NOTE(review): when the ranks are equal, `y_t_last_dim == 1` and
            # `y_p`'s last dimension is statically unknown (`None`),
            # `y_p_last_dim > 1` raises TypeError — confirm whether such
            # outputs can reach this path.
            is_sparse_categorical = y_t_rank < y_p_rank or (
                y_t_last_dim == 1 and y_p_last_dim > 1
            )

            if metric_key in ["accuracy", "acc"]:
                if is_binary:
                    metric_obj = metrics_mod.binary_accuracy
                elif is_sparse_categorical:
                    metric_obj = metrics_mod.sparse_categorical_accuracy
                else:
                    metric_obj = metrics_mod.categorical_accuracy
            else:
                if is_binary:
                    metric_obj = metrics_mod.binary_crossentropy
                elif is_sparse_categorical:
                    metric_obj = metrics_mod.sparse_categorical_crossentropy
                else:
                    metric_obj = metrics_mod.categorical_crossentropy

        if isinstance(metric_obj, losses_mod.Loss):
            metric_obj._allow_sum_over_batch_size = True

        if not isinstance(metric_obj, metrics_mod.Metric):
            if isinstance(metric, str):
                metric_name = metric
            else:
                metric_name = get_custom_object_name(metric)
                if metric_name is None:
                    raise ValueError(
                        f"Metric should be a callable, received: {metric}"
                    )

            metric_obj = metrics_mod.MeanMetricWrapper(
                metric_obj, name=metric_name
            )

        return metric_obj

    def _should_broadcast(self, obj):
        # e.g. 'mse'.
        if not tf.nest.is_nested(obj):
            return True
        # e.g. ['mse'] or ['mse', 'mae'].
        return isinstance(obj, (list, tuple)) and not any(
            tf.nest.is_nested(o) for o in obj
        )

    def _copy_object(self, obj):
        # Stateful `Metric`s are re-created from their config so each output
        # gets independent state.
        if isinstance(obj, metrics_mod.Metric):
            return obj.__class__.from_config(obj.get_config())
        return obj  # Can be a function or `None`.

709 

710 

def create_pseudo_output_names(outputs):
    """Builds default output names for a subclassed Model.

    Args:
      outputs: The Model's (possibly nested) outputs.

    Returns:
      A flat list of pseudo names such as `['output_1', 'output_2']`.
    """
    output_prefix = "output_"
    return _create_pseudo_names(outputs, prefix=output_prefix)

714 

715 

def create_pseudo_input_names(inputs):
    """Builds default input names for a subclassed Model.

    Args:
      inputs: The Model's (possibly nested) inputs.

    Returns:
      A flat list of pseudo names such as `['input_1', 'input_2']`.
    """
    input_prefix = "input_"
    return _create_pseudo_names(inputs, prefix=input_prefix)

719 

720 

def _create_pseudo_names(tensors, prefix):
    """Creates pseudo {input | output} names for subclassed Models.

    Warning: this function should only be used to define default
    names for `Metrics` and `SavedModel`. No other use cases should
    rely on a `Model`'s input or output names.

    Example with dict:

    `{'a': [x1, x2], 'b': x3}` becomes:
    `['a_1', 'a_2', 'b']`

    Example with list:

    `[x, y]` becomes:
    `['output_1', 'output_2']`

    Args:
      tensors: `Model`'s outputs or inputs.
      prefix: 'output_' for outputs, 'input_' for inputs.

    Returns:
      Flattened list of pseudo names.
    """

    def _shift_to_one_based(component):
        # Integer path components are positional indices; shift them so
        # numbering starts at 1 ("output_1", not "output_0").
        if isinstance(component, int):
            return component + 1
        return component

    raw_paths = list(tf.__internal__.nest.yield_flat_paths(tensors))
    paths = tf.nest.map_structure(_shift_to_one_based, raw_paths)

    names = []
    for path in paths:
        if not path:
            # Empty path: a single, non-nested output.
            names.append(prefix + "1")
            continue
        joined = "_".join(str(component) for component in path)
        # Only paths that begin with a positional index get the prefix;
        # paths that begin with a dict key are used as-is.
        names.append(prefix + joined if isinstance(path[0], int) else joined)
    return names

764 

765 

def map_to_output_names(y_pred, output_names, struct):
    """Maps a dict to a list using `output_names` as keys.

    This is a convenience feature only. When a `Model`'s outputs
    are a list, you can specify per-output losses and metrics as
    a dict, where the keys are the output names. If you specify
    per-output losses and metrics via the same structure as the
    `Model`'s outputs (recommended), no mapping is performed.

    For the Functional API, the output names are the names of the
    last layer of each output. For the Subclass API, the output names
    are determined by `create_pseudo_output_names` (For example:
    `['output_1', 'output_2']` for a list of outputs).

    This mapping preserves backwards compatibility for `compile` and
    `fit`.

    Args:
      y_pred: Sample outputs of the Model, to determine if this convenience
        feature should be applied (`struct` is returned unmodified if `y_pred`
        isn't a flat list).
      output_names: List. The names of the outputs of the Model.
      struct: The structure to map. Never modified in place.

    Returns:
      `struct` mapped to a list in same order as `output_names` (or the
      single mapped value when there is exactly one name), or `struct`
      unchanged when no mapping applies.

    Raises:
      ValueError: If `struct` is a dict containing keys that match no name
        in `output_names`.
    """
    single_output = not tf.nest.is_nested(y_pred)
    outputs_are_flat_list = (
        not single_output
        and isinstance(y_pred, (list, tuple))
        and not any(tf.nest.is_nested(y_p) for y_p in y_pred)
    )

    if not (single_output or outputs_are_flat_list) or not isinstance(
        struct, dict
    ):
        # Nested outputs, or `struct` is not a dict: nothing to map.
        return struct

    output_names = output_names or create_pseudo_output_names(y_pred)
    # Pop matched keys from a shallow copy. The caller's dict stays intact,
    # which lets the error below report what was actually passed, and
    # whatever is left afterwards is exactly the set of unmatched keys.
    unmatched = copy.copy(struct)
    new_struct = [unmatched.pop(name, None) for name in output_names]
    if unmatched:
        raise ValueError(
            "Found unexpected losses or metrics that do not correspond "
            f"to any Model output: {unmatched.keys()}. "
            f"Valid model output names: {output_names}. "
            f"Received struct is: {struct}."
        )
    if len(new_struct) == 1:
        return new_struct[0]
    return new_struct

816 

817 

def map_missing_dict_keys(y_pred, struct):
    """Pads a dict `struct` with `None` for every key of `y_pred` it lacks.

    Only applies when both `y_pred` and `struct` are dicts; otherwise
    `struct` is returned as-is. The input dict is never mutated — a shallow
    copy is padded and returned.
    """
    if isinstance(y_pred, dict) and isinstance(struct, dict):
        padded = copy.copy(struct)
        for key in y_pred:
            if key not in padded:
                padded[key] = None
        return padded
    return struct

827 

828 

def match_dtype_and_rank(y_t, y_p, sw):
    """Aligns labels and sample weights with the predictions.

    Rank: when `y_p` has rank 2, a rank-1 `y_t` (and likewise a rank-1
    `sw`) gets a trailing axis of size 1.

    Dtype: `y_t` is cast to `y_p.dtype` when both are floating or both are
    integer (custom loss functions often don't cast themselves); `sw`, when
    given, is always cast to `y_p.dtype`.

    Returns:
      The `(y_t, y_p, sw)` triple; `y_p` itself is returned unchanged.
    """
    preds_are_rank_2 = y_p.shape.rank == 2

    if preds_are_rank_2 and y_t.shape.rank == 1:
        y_t = tf.expand_dims(y_t, axis=-1)
    if sw is not None and preds_are_rank_2 and sw.shape.rank == 1:
        sw = tf.expand_dims(sw, axis=-1)

    both_floating = y_t.dtype.is_floating and y_p.dtype.is_floating
    both_integer = y_t.dtype.is_integer and y_p.dtype.is_integer
    if both_floating or both_integer:
        y_t = tf.cast(y_t, y_p.dtype)

    if sw is not None:
        sw = tf.cast(sw, y_p.dtype)
    return y_t, y_p, sw

848 

849 

def get_custom_object_name(obj):
    """Picks a display name for a custom loss or metric callable.

    Lookup order: a `name` attribute (so `Loss` instances can be used as
    metrics), then `__name__` (plain functions), then the snake_cased class
    name (callable class instances).

    Args:
      obj: Custom loss or metric callable.

    Returns:
      Name to use, or `None` if the object was not recognized.
    """
    if hasattr(obj, "name"):
        return obj.name
    if hasattr(obj, "__name__"):
        return obj.__name__
    if hasattr(obj, "__class__"):
        return generic_utils.to_snake_case(obj.__class__.__name__)
    # Unrecognized object.
    return None

867