Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/distribute/distributed_training_utils_v1.py: 15%

436 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities related to distributed training.""" 

16 

17import functools 

18 

19import numpy as np 

20import tensorflow.compat.v2 as tf 

21 

22from keras.src import backend 

23from keras.src import callbacks 

24from keras.src import metrics as metrics_module 

25from keras.src import optimizers 

26from keras.src.distribute import distribute_coordinator_utils as dc 

27from keras.src.distribute import distributed_training_utils as dist_utils 

28from keras.src.engine import training_utils_v1 

29from keras.src.optimizers.legacy import optimizer_v2 

30from keras.src.utils import tf_contextlib 

31from keras.src.utils.mode_keys import ModeKeys 

32 

33# isort: off 

34from tensorflow.python.platform import tf_logging as logging 

35 

36 

def set_weights(distribution_strategy, dist_model, weights):
    """Copies a flat list of weights onto the variables of `dist_model`.

    `weights` is consumed in order: each layer of `dist_model` takes as many
    entries as it has variables. When executing eagerly outside functions,
    every assignment is applied directly. Otherwise the assign ops are
    unwrapped through `distribution_strategy`, collected, and run together in
    the backend session.

    Args:
      distribution_strategy: DistributionStrategy whose `unwrap` is applied to
        each assign op when not executing eagerly.
      dist_model: Model whose layers' variables receive the values.
      weights: Flat sequence of weight values, ordered layer by layer.
    """
    eager = tf.compat.v1.executing_eagerly_outside_functions()
    pending_ops = []
    offset = 0
    for layer in dist_model.layers:
        layer_vars = layer.weights
        values = weights[offset : offset + len(layer_vars)]
        offset += len(layer_vars)
        for variable, value in zip(layer_vars, values):
            assigned = variable.assign(value)
            if not eager:
                pending_ops.append(distribution_strategy.unwrap(assigned))

    if not eager:
        backend.get_session(pending_ops).run(pending_ops)

63 

64 

def unwrap_values(
    distribution_strategy,
    grouped_inputs,
    grouped_outputs,
    grouped_updates=None,
    grouped_session_args=None,
    with_loss_tensor=False,
):
    """Flattens the PerReplica inputs, outputs, updates and session args.

    Inputs, updates and session arguments go through
    `flatten_per_replica_values`; outputs go through `unwrap_outputs`, which
    receives `with_loss_tensor` unchanged.

    Args:
      distribution_strategy: DistributionStrategy used to distribute training
        and validation.
      grouped_inputs: PerReplica inputs returned from the train or test
        function that was run on each device.
      grouped_outputs: PerReplica outputs returned from the train or test
        function that was run on each device.
      grouped_updates: Optional PerReplica updates returned from the train or
        test function.
      grouped_session_args: Optional mapping of PerReplica session args; only
        the "feed_dict" and "fetches" entries are read.
      with_loss_tensor: Boolean forwarded to `unwrap_outputs`.

    Returns:
      A tuple `(all_inputs, all_outputs, all_updates, all_session_args)`.
      `all_updates` is None when `grouped_updates` is falsy, and
      `all_session_args` only holds the keys whose grouped value was truthy.
    """
    flat_inputs = flatten_per_replica_values(
        distribution_strategy, grouped_inputs
    )
    flat_outputs = unwrap_outputs(
        distribution_strategy, grouped_outputs, with_loss_tensor
    )
    flat_updates = (
        flatten_per_replica_values(distribution_strategy, grouped_updates)
        if grouped_updates
        else None
    )

    session_args = {}
    if grouped_session_args:
        for key in ("feed_dict", "fetches"):
            grouped_value = grouped_session_args.get(key)
            if grouped_value:
                session_args[key] = flatten_per_replica_values(
                    distribution_strategy, grouped_value
                )

    # TODO(priyag): Return only non empty/None values
    return flat_inputs, flat_outputs, flat_updates, session_args

130 

131 

def unwrap_output_dict(strategy, grouped_outputs, mode):
    """Unwraps the outputs held in PerReplica form for the given mode.

    In predict mode `grouped_outputs` is flattened as a whole. For fit/eval it
    is read as a dict with the keys "total_loss", "output_losses", "metrics"
    and "batch_size": the total loss and batch size are sum-reduced across
    replicas, while the output losses and metrics are flattened.

    Args:
      strategy: DistributionStrategy used to reduce and unwrap the values.
      grouped_outputs: PerReplica outputs; a dict for fit/eval, otherwise the
        structure handed to `flatten_per_replica_values`.
      mode: One of the `ModeKeys` values.

    Returns:
      The flattened list in predict mode, otherwise a dict with the same four
      keys as the input.
    """
    if mode == ModeKeys.PREDICT:
        return flatten_per_replica_values(strategy, grouped_outputs)

    sum_op = tf.distribute.ReduceOp.SUM
    unwrapped = {
        "total_loss": [
            strategy.reduce(
                sum_op, grouped_outputs["total_loss"][0], axis=None
            )
        ],
        "output_losses": flatten_per_replica_values(
            strategy, grouped_outputs["output_losses"]
        ),
        "metrics": flatten_per_replica_values(
            strategy, grouped_outputs["metrics"]
        ),
        "batch_size": strategy.reduce(
            sum_op, grouped_outputs["batch_size"], axis=None
        ),
    }

    tpu_in_eager = (
        backend.is_tpu_strategy(strategy)
        and tf.compat.v1.executing_eagerly_outside_functions()
    )
    if tpu_in_eager:
        # All replicas produce the same output on TPU, so keep one value per
        # group of replicas. This is limited to eager mode because the graph
        # code path does not go through experimental_run yet.
        stride = strategy.num_replicas_in_sync
        for key in ("output_losses", "metrics"):
            unwrapped[key] = unwrapped[key][::stride]
    return unwrapped

168 

169 

def unwrap_outputs(
    distribution_strategy, grouped_outputs, with_loss_tensor=False
):
    """Flattens PerReplica outputs, optionally reducing a leading loss.

    Without `with_loss_tensor` the outputs are simply flattened. With it, the
    first grouped output is treated as the loss and sum-reduced across
    replicas, the remaining outputs are flattened, and the reduced loss is
    placed in front of them.

    Args:
      distribution_strategy: DistributionStrategy used to distribute training
        and validation.
      grouped_outputs: PerReplica outputs returned from the train or test
        function; a non-list value is wrapped in a list when
        `with_loss_tensor` is set.
      with_loss_tensor: Boolean indicating that the first output is a loss
        that must be reduced.

    Returns:
      A list with the values of each of the PerReplica outputs.
    """
    if not with_loss_tensor:
        return flatten_per_replica_values(
            distribution_strategy, grouped_outputs
        )

    if not isinstance(grouped_outputs, list):
        grouped_outputs = [grouped_outputs]

    # Reduce the loss tensor before adding it to the list of fetches.
    reduced_loss = distribution_strategy.reduce(
        tf.distribute.ReduceOp.SUM, grouped_outputs[0], axis=None
    )
    remaining = flatten_per_replica_values(
        distribution_strategy, grouped_outputs[1:]
    )

    tpu_in_eager = (
        backend.is_tpu_strategy(distribution_strategy)
        and tf.compat.v1.executing_eagerly_outside_functions()
    )
    if tpu_in_eager:
        # All replicas produce the same output on TPU, so keep one value per
        # group of replicas. This is limited to eager mode because the graph
        # code path does not go through experimental_run yet.
        remaining = remaining[:: distribution_strategy.num_replicas_in_sync]
    return [reduced_loss, *remaining]

218 

219 

def flatten_per_replica_values(distribution_strategy, per_replica_values):
    """Unwraps and flattens a nest of PerReplica values into one list.

    The nest is flattened with `tf.nest.flatten`; every resulting component is
    unwrapped through `distribution_strategy.unwrap`, and the unwrapped values
    are concatenated in that order.

    Args:
      distribution_strategy: DistributionStrategy used to distribute training
        and validation.
      per_replica_values: A single PerReplica object or a nest of them.

    Returns:
      List of the values of all the PerReplica objects.
    """
    flat_values = []
    for component in tf.nest.flatten(per_replica_values):
        flat_values.extend(distribution_strategy.unwrap(component))
    return flat_values

246 

247 

def validate_callbacks(input_callbacks, optimizer):
    """Checks that the given callbacks can be used with DistributionStrategy.

    Args:
      input_callbacks: List of callbacks passed by the user to fit; may be
        None or empty, in which case nothing is checked.
      optimizer: Optimizer instance used to train the model.

    Raises:
      ValueError: If a `LearningRateScheduler` or `ReduceLROnPlateau` callback
        is passed while `optimizer` is not an `OptimizerV2` instance.

    A TensorBoard callback with a truthy `write_grads` attribute does not
    raise: a warning is logged and `write_grads` is reset to False.
    """
    if not input_callbacks:
        return

    lr_callback_types = (
        callbacks.LearningRateScheduler,
        callbacks.ReduceLROnPlateau,
    )
    for callback in input_callbacks:
        if isinstance(callback, lr_callback_types) and not isinstance(
            optimizer, optimizer_v2.OptimizerV2
        ):
            raise ValueError(
                "You must specify a Keras Optimizer V2 when using "
                "%s callback with DistributionStrategy." % callback
            )

        # The TensorBoard callback may be used, but not the features that
        # involve accessing model attributes and running ops.
        if isinstance(callback, callbacks.TensorBoard) and getattr(
            callback, "write_grads", False
        ):
            logging.warning(
                UserWarning(
                    "`write_grads` in the TensorBoard callback is not "
                    "supported when using DistributionStrategy. "
                    "Setting `write_grads` to `False`."
                )
            )
            callback.write_grads = False

287 

288 

def validate_distributed_dataset_inputs(
    distribution_strategy, x, y, sample_weights=None
):
    """Validates every component of a DistributedValue dataset input.

    `x` is always passed to `validate_per_replica_inputs`; `y` and
    `sample_weights` are validated the same way only when they are not None.

    Args:
      distribution_strategy: The current DistributionStrategy used to call
        `fit`/`evaluate`.
      x: Input Dataset DistributedValue object (may also be a tuple or dict
        of such objects).
      y: Target Dataset DistributedValue object, or None.
      sample_weights: Sample weights Dataset DistributedValue object, or None.

    Returns:
      A tuple `(x_values, y_values, sample_weight_values)` holding the
      unwrapped value lists, with None for components that were not given.

    Raises:
      ValueError: Propagated from `validate_per_replica_inputs` when a
        component fails validation.
    """

    def _validated_or_none(component):
        if component is None:
            return None
        return validate_per_replica_inputs(distribution_strategy, component)

    x_values = validate_per_replica_inputs(distribution_strategy, x)
    # Hand back the unwrapped values so callers need not unwrap again.
    return (
        x_values,
        _validated_or_none(y),
        _validated_or_none(sample_weights),
    )

339 

340 

def validate_per_replica_inputs(distribution_strategy, x):
    """Validates a nest of PerReplica dataset inputs.

    Each flattened component is unwrapped and every unwrapped value must be a
    tensor. When not executing eagerly, the values of a component must also
    agree on shape and dtype.

    Args:
      distribution_strategy: The current DistributionStrategy used to call
        `fit`, `evaluate` and `predict`.
      x: A nest of PerReplica objects that represent the input or target
        values.

    Returns:
      List containing the first unwrapped value of each flattened component.

    Raises:
      ValueError: If any unwrapped value is not a tensor, or (when not
        executing eagerly) if shapes or dtypes differ within a component.
    """
    first_replica_values = []
    for per_replica in tf.nest.flatten(x):
        replica_values = distribution_strategy.unwrap(per_replica)
        for candidate in replica_values:
            if tf.is_tensor(candidate):
                continue
            raise ValueError(
                "Dataset input to the model should be tensors instead "
                "they are of type {}".format(type(candidate))
            )

        if not tf.executing_eagerly():
            # Shape first, then dtype, for all values of this component.
            validate_all_tensor_shapes(per_replica, replica_values)
            validate_all_tensor_types(per_replica, replica_values)

        first_replica_values.append(replica_values[0])
    return first_replica_values

380 

381 

def validate_all_tensor_types(x, x_values):
    """Checks that every entry of `x_values` has the dtype of the first one.

    Args:
      x: The distributed input the values came from; only used in the error
        message.
      x_values: Non-empty sequence of objects exposing a `dtype` attribute.

    Raises:
      ValueError: If any later entry's dtype differs from the first entry's.
    """
    reference_dtype = x_values[0].dtype
    if any(reference_dtype != value.dtype for value in x_values[1:]):
        raise ValueError(
            "Input tensor dtypes do not match for distributed tensor"
            " inputs {}".format(x)
        )

390 

391 

def validate_all_tensor_shapes(x, x_values):
    """Checks that every entry of `x_values` has the shape of the first one.

    Shapes are compared through `shape.as_list()`.

    Args:
      x: The distributed input the values came from; only used in the error
        message.
      x_values: Non-empty sequence of objects whose `shape` supports
        `as_list()`.

    Raises:
      ValueError: If any later entry's shape list differs from the first
        entry's.
    """
    reference_shape = x_values[0].shape.as_list()
    for replica_value in x_values[1:]:
        if reference_shape != replica_value.shape.as_list():
            raise ValueError(
                "Input tensor shapes do not match for distributed tensor"
                " inputs {}".format(x)
            )

401 

402 

def _wait_for_variable_initialization(session):
    """Polls `session` until every pending backend variable is initialized.

    The variables of the backend graph that do not yet carry a truthy
    `_keras_initialized` attribute are polled repeatedly with
    `tf.compat.v1.is_variable_initialized`. The function returns once a poll
    reports all of them as initialized, or immediately when nothing is
    pending.

    A variable is tagged `_keras_initialized = True` only after a poll has
    confirmed it. Previously the tag was also set on variables that the poll
    had just reported as uninitialized, so an interrupted wait could leave
    them wrongly marked.

    Args:
      session: Session used to evaluate the initialization checks.
    """
    candidate_vars = [
        v
        for v in backend._get_variables(backend.get_graph())
        if not getattr(v, "_keras_initialized", False)
    ]
    if not candidate_vars:
        return

    # NOTE(review): this loop polls back to back with no delay between
    # iterations, as the original did.
    while True:
        is_initialized = session.run(
            [tf.compat.v1.is_variable_initialized(v) for v in candidate_vars]
        )
        all_ready = True
        for flag, v in zip(is_initialized, candidate_vars):
            if flag:
                v._keras_initialized = True
            else:
                all_ready = False
        if all_ready:
            return

425 

426 

def init_restore_or_wait_for_variables():
    """Runs the backend's variable initialization on the backend session.

    The only step performed here is `backend._initialize_variables` on the
    session returned by `backend._get_session()`. No explicit restore or wait
    logic is visible in this function, despite its name.
    """
    session = backend._get_session()
    backend._initialize_variables(session)

431 

432 

def validate_inputs(x, y):
    """Rejects Iterator inputs when using DistributionStrategy.

    Args:
      x: Model inputs.
      y: Model targets.

    Raises:
      ValueError: If either `x` or `y` is a `tf.compat.v1.data.Iterator`.
    """
    iterator_type = tf.compat.v1.data.Iterator
    if any(isinstance(arg, iterator_type) for arg in (x, y)):
        raise ValueError(
            "`DistributionStrategy` does not support inputs of type "
            "Iterator. You must pass a `tf.data.Dataset` object or a "
            "numpy array as input."
        )

452 

453 

def is_dataset_shape_fully_defined(dataset):
    """Returns whether all output shapes of `dataset` are fully defined.

    Args:
      dataset: Object accepted by `tf.compat.v1.data.get_output_shapes`.

    Returns:
      True when every flattened output shape reports `is_fully_defined()`,
      False as soon as one does not.
    """
    output_shapes = tf.nest.flatten(
        tf.compat.v1.data.get_output_shapes(dataset)
    )
    return all(shape.is_fully_defined() for shape in output_shapes)

459 

460 

def process_batch_and_step_size(
    strategy, inputs, batch_size, steps_per_epoch, mode, validation_split=0.0
):
    """Derives the batch size and step count from numpy inputs.

    Only the first flattened element of `inputs` is inspected. When it is not
    a numpy array, `batch_size` and `steps_per_epoch` are returned unchanged.
    Otherwise its leading dimension is taken as the sample count, scaled down
    by `validation_split` when that lies strictly between 0 and 1, and passed
    to `get_input_params`.

    Args:
      strategy: DistributionStrategy forwarded to `get_input_params`.
      inputs: Nest of model inputs.
      batch_size: User-specified batch size, or None.
      steps_per_epoch: User-specified number of steps, or None.
      mode: `ModeKeys` value forwarded to `get_input_params`.
      validation_split: Fraction of the samples held out for validation.

    Returns:
      A tuple `(batch_size, steps_per_epoch)`.
    """
    leading_input = tf.nest.flatten(inputs)[0]
    if not isinstance(leading_input, np.ndarray):
        return batch_size, steps_per_epoch

    sample_count = leading_input.shape[0]
    if validation_split and 0.0 < validation_split < 1.0:
        sample_count = int(sample_count * (1 - validation_split))

    # `mode` is forwarded so the requirement to consume all training samples
    # can be relaxed selectively until partial batches are supported
    # everywhere.
    steps_per_epoch, batch_size = get_input_params(
        strategy, sample_count, steps_per_epoch, batch_size, mode=mode
    )
    return batch_size, steps_per_epoch

477 

478 

def get_input_params(
    distribution_strategy, num_samples, steps, batch_size, mode=None
):
    """Calculates the number of steps and the batch size to iterate with.

    Args:
      distribution_strategy: The DistributionStrategy used to compile the
        model.
      num_samples: The number of samples from which the batch size and steps
        are determined.
      steps: The specified number of steps, or None.
      batch_size: The specified batch size, or None.
      mode: ModeKey telling whether the input is used for training,
        evaluation, or prediction. Together with the execution mode and the
        strategy type it decides whether a final partial batch is tolerated.

    Returns:
      steps: The steps or steps_per_epoch value. When `steps` was None it is
        computed from `num_samples` and the global batch size.
      batch_size: The batch size to use in model iterations. This is the
        per-replica size when the strategy does not support global batch
        sizes, and the global size otherwise.

    Raises:
      ValueError: If `num_samples` is not divisible by the global batch size
        while partial batches are not allowed, if it is not divisible by
        `steps` when only `steps` is given, if it is smaller than what
        `batch_size` and `steps` require, or if the global batch size cannot
        be split evenly across the replicas.
    """
    # TODO(b/118776054): Use global batch size for Keras/DS support.
    # Currently this is only supported in TPUStrategy and CoreMirroredStrategy.
    per_replica_batch = not dist_utils.global_batch_size_supported(
        distribution_strategy
    )

    # TODO(b/128995245): In eager mode, uneven batch sizes are allowed except
    # for `fit()` on TPUStrategy.
    # In graph mode, the zero batch case in batch norm is not handled due to
    # XLA-GPU regression. Uneven batch sizes are not allowed except
    # for `test()` and `predict()` on TPUStrategy.
    if tf.executing_eagerly():
        allow_partial_batch = not (
            mode == ModeKeys.TRAIN
            and backend.is_tpu_strategy(distribution_strategy)
        )
    else:
        allow_partial_batch = mode == ModeKeys.TRAIN or (
            mode in (ModeKeys.PREDICT, ModeKeys.TEST)
            and backend.is_tpu_strategy(distribution_strategy)
        )

    def _as_global(user_batch_size):
        # A user-provided size is per replica for strategies without global
        # batch support, so scale it up to the global size.
        if per_replica_batch:
            return user_batch_size * distribution_strategy.num_replicas_in_sync
        return user_batch_size

    if steps is None:
        if batch_size is None:
            # Neither value is set: use the smaller of the sample count and
            # 32, where 32 is kept for backward compatibility.
            global_batch_size = min(num_samples, 32)
        else:
            global_batch_size = _as_global(batch_size)

        if allow_partial_batch:
            steps = np.ceil(num_samples / global_batch_size).astype(int)
        elif num_samples % global_batch_size:
            raise ValueError(
                "The number of samples %s is not divisible by "
                "batch size %s." % (num_samples, global_batch_size)
            )
        else:
            steps = num_samples // global_batch_size
    elif batch_size is None:
        # Only the step count is given: derive the batch size from it.
        if num_samples % steps:
            raise ValueError(
                "The number of samples %s is not divisible by "
                "steps %s. Please change the number of steps to a "
                "value that can consume all the samples"
                % (num_samples, steps)
            )
        global_batch_size = num_samples // steps
    else:
        global_batch_size = _as_global(batch_size)

        if not allow_partial_batch:
            min_num_samples = global_batch_size * steps
        elif steps > 1:
            # The last step may be partial but needs at least one sample.
            min_num_samples = global_batch_size * (steps - 1) + 1
        else:
            min_num_samples = 0

        if num_samples < min_num_samples:
            raise ValueError(
                "Number of samples %s is less than samples required "
                "for specified batch_size %s and steps %s"
                % (num_samples, global_batch_size, steps)
            )

    # Return the per replica or global batch size depending on the strategy.
    if not per_replica_batch:
        return steps, global_batch_size

    num_replicas = distribution_strategy.num_replicas_in_sync
    if global_batch_size % num_replicas:
        raise ValueError(
            "The batch size (%s) could not be sharded evenly across the "
            "sync replicas (%s) in the distribution strategy."
            % (
                global_batch_size,
                num_replicas,
            )
        )
    return steps, global_batch_size // num_replicas

600 

601 

def get_batch_dimension(iterator):
    """Returns the leading dimension of the iterator's first output shape.

    Only the first flattened output shape is consulted, since the batch size
    should be the same for all of them.

    Args:
      iterator: Object accepted by `tf.compat.v1.data.get_output_shapes`.

    Returns:
      The first entry of that shape's `dims`, or None when `dims` is falsy.
    """
    output_shapes = tf.nest.flatten(
        tf.compat.v1.data.get_output_shapes(iterator)
    )
    first_dims = output_shapes[0].dims
    if first_dims:
        return first_dims[0]
    return None

608 

609 

def get_iterator(dataset, distribution_strategy):
    """Creates and initializes a distributed iterator over `dataset`.

    Args:
      dataset: Dataset handed to
        `distribution_strategy.make_dataset_iterator`.
      distribution_strategy: Strategy whose scope the iterator is created in.

    Returns:
      The iterator, after `initialize_iterator` has been called on it.
    """
    with distribution_strategy.scope():
        dataset_iterator = distribution_strategy.make_dataset_iterator(
            dataset
        )
    initialize_iterator(dataset_iterator, distribution_strategy)
    return dataset_iterator

615 

616 

def initialize_iterator(iterator, distribution_strategy):
    """Groups the iterator's initializer and runs it when in graph mode.

    The grouped init op is always built inside the strategy scope. It is run
    through the backend session only when not executing eagerly.

    Args:
      iterator: Iterator exposing an `initializer` attribute.
      distribution_strategy: Strategy whose scope the init op is built in.
    """
    with distribution_strategy.scope():
        init_op = tf.group(iterator.initializer)
        if tf.executing_eagerly():
            return
        backend.get_session((init_op,)).run(init_op)

622 

623 

def _get_input_from_iterator(iterator, model):
    """Gets the next element from the iterator and validates its components.

    The number of flattened components decides how the element is split. If
    it equals the number of model inputs, the element is the inputs alone. If
    it equals inputs plus outputs, the element is `(x, y)`. Anything else is
    unpacked as `(x, y, sample_weights)`.

    Args:
      iterator: Iterator whose `get_next()` yields the element.
      model: Model providing `inputs`, `outputs` and `_distribution_strategy`.

    Returns:
      A tuple `(x, y, sample_weights)`, with None for absent parts.
    """
    element = iterator.get_next()

    # `tf.nest.flatten` does not count empty elements such as {}:
    # len(nest.flatten([[0,1,2], {}])) is 3 and not 4. The element gets
    # flattened again in `_prepare_feed_values`, which filters those out.
    flat_count = len(tf.nest.flatten(element))
    input_count = len(model.inputs)

    y = None
    sample_weights = None
    if flat_count == input_count:
        x = element
    elif flat_count == input_count + len(model.outputs):
        x, y = element
    else:
        x, y, sample_weights = element

    # Validate that all the elements in x and y are of the same type and
    # shape.
    validate_distributed_dataset_inputs(
        model._distribution_strategy, x, y, sample_weights
    )
    return x, y, sample_weights

649 

650 

def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
    """Prepares feed values for the model execution function.

    `inputs` is handed to `_get_input_from_iterator`, and the inputs, targets
    and sample weights it returns replace the arguments of the same name.

    Args:
      model: Model to prepare feed values for.
      inputs: Iterator the next element is pulled from.
      targets: Overwritten by the targets read from the iterator.
      sample_weights: Overwritten by the sample weights read from the
        iterator.
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

    Returns:
      A tuple `(inputs, targets, sample_weights)` for the given mode.

    Raises:
      ValueError: If sample weights are present with a TPU strategy.
      NotImplementedError: If sample weights are present while distributing
        by cloning, executing eagerly, and `model._compile_distribution` is
        falsy.
    """
    strategy = model._distribution_strategy
    inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
    if backend.is_tpu_strategy(strategy) and sample_weights is not None:
        raise ValueError("TPUStrategy does not support sample weights.")

    # Dict inputs are flattened in the order of the input layers so the data
    # reach the input layers in the correct order.
    if isinstance(inputs, dict):
        inputs = [inputs[name] for name in model._feed_input_names]

    if is_distributing_by_cloning(model):
        inputs = flatten_per_replica_values(strategy, inputs)
        targets = flatten_per_replica_values(strategy, targets)
        # Expand 1-dimensional inputs.
        # TODO(b/124535720): Remove once this standardize data logic is
        # shared with main flow.
        inputs, targets = tf.nest.map_structure(
            training_utils_v1.standardize_single_array, (inputs, targets)
        )
    else:
        inputs = training_utils_v1.ModelInputs(inputs).as_list()

    if mode == ModeKeys.PREDICT:
        targets, sample_weights = [], []
    elif sample_weights is not None and is_distributing_by_cloning(model):
        if tf.executing_eagerly() and not model._compile_distribution:
            raise NotImplementedError(
                "`sample_weight` is not supported when using "
                "tf.distribute.Strategy in eager mode and "
                "cloning=True."
            )
        sample_weights = flatten_per_replica_values(strategy, sample_weights)

    return (inputs, targets, sample_weights)

701 

702 

def is_distributing_by_cloning(model):
    """Decides whether this model is going to be distributed via cloning.

    Args:
      model: Keras model to distribute.

    Returns:
      False for a TPU strategy (see the note below). Otherwise
      `bool(model._compile_distribution)` when executing eagerly outside
      functions, and True in the remaining case.
    """
    strategy = model._distribution_strategy
    # NOTE(review): `tf.executing_eagerly` is referenced here without being
    # called, so that operand is a function object and always truthy. The
    # branch is therefore taken for every TPU strategy, eager or not. This is
    # left unchanged on purpose: calling it would change the result for TPU
    # strategies outside eager mode. Confirm the intent (b/137580852) before
    # changing it.
    if backend.is_tpu_strategy(strategy) and tf.executing_eagerly:
        return False
    if tf.compat.v1.executing_eagerly_outside_functions():
        return bool(model._compile_distribution)
    return True

723 

724 

725def _custom_compile_for_predict(model): 

726 """Custom compile for TPU predict mode.""" 

727 if not model.built: 

728 # Model is not compilable because it does not know its number of inputs 

729 # and outputs, nor their shapes and names. We will compile after the 

730 # first time the model gets called on training data. 

731 return 

732 model._is_compiled = True 

733 model.total_loss = None 

734 model.train_function = None 

735 model.test_function = None 

736 model.predict_function = None 

737 

738 

def _build_network_on_replica(model, mode, inputs=None, targets=None):
    """Builds an updated model on a replica, sharing layers with `model`.

    A new Keras model is created while the variables of the old graph are
    shared. A new sub-graph is needed because the original model creates
    placeholders for the input and the output that are not accessible until
    `iterator.get_next()` is called inside the step function for
    `fit`/`evaluate`/`predict`.

    Because weights and layers are shared through `models.share_weights`,
    updates on either model are visible to the other. The original optimizer
    instance is passed to `compile` as well, so its state is kept if the user
    runs fit several times.

    Args:
      model: Model to be replicated across replicas.
      mode: Which of fit/eval/predict is building the distributed network.
      inputs: Input tensors to be passed to the cloned model.
      targets: Target tensors to be passed to `compile`.

    Returns:
      A new model with shared layers with the old model.
    """
    # Imported here because a module-level import runs into a circular
    # dependency.
    from keras.src import models
    from keras.src.engine import sequential

    # Internal clone helpers are used so that share_weights does not have to
    # be part of the public API.
    is_sequential = isinstance(model, sequential.Sequential)
    if is_sequential:
        clone_fn = models._clone_sequential_model
    else:
        clone_fn = models._clone_functional_model
    updated_model = clone_fn(
        model, input_tensors=inputs, layer_fn=models.share_weights
    )
    if not is_sequential:
        # Callable losses added directly to a functional Model need to be
        # carried over.
        updated_model._callable_losses = model._callable_losses

    # Cast bfloat16 outputs back to float32: only the inputs were cast to
    # bfloat16, not the targets, and this keeps precision for the loss.
    updated_model.outputs = [
        tf.cast(out, tf.float32) if out.dtype == tf.bfloat16 else out
        for out in updated_model.outputs
    ]

    if isinstance(targets, tuple):
        targets = tf.nest.flatten(targets)

    tpu_predict = mode == ModeKeys.PREDICT and inputs is not None
    if tpu_predict:
        _custom_compile_for_predict(updated_model)
    else:
        updated_model.compile(
            model.optimizer,
            model.loss,
            metrics=metrics_module.clone_metrics(model._compile_metrics),
            loss_weights=model.loss_weights,
            sample_weight_mode=model.sample_weight_mode,
            weighted_metrics=metrics_module.clone_metrics(
                model._compile_weighted_metrics
            ),
            target_tensors=targets,
        )
    return updated_model

813 

814 

def _build_distributed_network(
    model, strategy, mode, inputs=None, targets=None
):
    """Creates a cloned model on each replica and registers it on `model`.

    `_build_network_on_replica` is invoked once per replica inside the
    backend graph and the strategy scope. The result is stored through
    `set_distributed_model` for the given mode.

    Args:
      model: Model to replicate.
      strategy: DistributionStrategy providing the scope and replica calls.
      mode: `ModeKeys` value the distributed model is built for.
      inputs: Optional input tensors forwarded to the per-replica build.
      targets: Optional target tensors forwarded to the per-replica build.
    """
    with backend.get_graph().as_default():
        with strategy.scope():
            per_replica_model = strategy.extended.call_for_each_replica(
                _build_network_on_replica,
                args=(model, mode, inputs, targets),
            )
            set_distributed_model(model, mode, per_replica_model)

824 

825 

def _clone_and_build_model(model, mode, inputs=None, targets=None):
    """Clone `model` and compile the clone with the original's settings.

    Args:
        model: Keras model to clone.
        mode: `ModeKeys` value the clone is built for.
        inputs: Optional input tensors to build the clone on.
        targets: Optional target tensors; a tuple is flattened before being
            passed to `compile`.

    Returns:
        The cloned (and compiled) model.
    """
    # Imported lazily: a module-level import hits a circular dependency.
    from keras.src import models

    clone = models.clone_model(model, input_tensors=inputs)

    # A `TFOptimizer` is reused as is; any other optimizer is re-created from
    # its own config.
    if isinstance(model.optimizer, optimizers.TFOptimizer):
        optimizer = model.optimizer
    else:
        config = model.optimizer.get_config()
        optimizer = model.optimizer.__class__.from_config(config)

    # Only the inputs were cast to bfloat16, not the targets, so bfloat16
    # outputs are cast back to float32 to preserve precision when the loss is
    # calculated.
    clone.outputs = [
        tf.cast(out, tf.float32) if out.dtype == tf.bfloat16 else out
        for out in clone.outputs
    ]

    if isinstance(targets, tuple):
        targets = tf.nest.flatten(targets)

    if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
        _custom_compile_for_predict(clone)
        return clone

    clone.compile(
        optimizer,
        model.loss,
        metrics=metrics_module.clone_metrics(model._compile_metrics),
        loss_weights=model.loss_weights,
        sample_weight_mode=model.sample_weight_mode,
        weighted_metrics=metrics_module.clone_metrics(
            model._compile_weighted_metrics
        ),
        target_tensors=targets,
    )
    return clone

871 

872 

def clone_model_on_replicas(model, strategy, mode, inputs=None, targets=None):
    """Clone and compile `model` on every replica and cache the result.

    Args:
        model: Keras model to clone.
        strategy: Distribution strategy supplying the replicas.
        mode: `ModeKeys` value the clones are built for.
        inputs: Optional input tensors forwarded to `_clone_and_build_model`.
        targets: Optional target tensors forwarded to
            `_clone_and_build_model`.
    """
    clone_args = (model, mode, inputs, targets)
    with backend.get_graph().as_default():
        with strategy.scope():
            per_replica_model = strategy.extended.call_for_each_replica(
                _clone_and_build_model, args=clone_args
            )
            set_distributed_model(model, mode, per_replica_model)

    # Only the training mode needs the callback model to be set up.
    if mode == ModeKeys.TRAIN:
        model._make_callback_model(per_replica_model)

882 

883 

def _make_execution_function(model, mode):
    """Make, or fetch from the cache, the one-step distributed function.

    Models distributed by cloning are delegated to
    `_make_execution_function_with_cloning`. Otherwise the function cached on
    `model` for `mode` is returned, and it is built and cached first when
    absent.

    Args:
        model: Keras model being distributed.
        mode: `ModeKeys` value selecting train/test/predict.

    Returns:
        The execution function for `mode`.
    """
    if is_distributing_by_cloning(model):
        return _make_execution_function_with_cloning(model, mode)

    step_function = get_distributed_function(model, mode)
    if not step_function:
        step_function = _make_execution_function_without_cloning(model, mode)
        set_distributed_function(model, mode, step_function)
    return step_function

899 

900 

def _make_execution_function_without_cloning(model, mode):
    """Create the function that runs one distributed step without cloning.

    Args:
        model: Keras model carrying `_distribution_strategy`.
        mode: `ModeKeys` value selecting train/test/predict.

    Returns:
        A callable taking `input_fn`. When `model.run_eagerly` is set it is
        the plain Python step; otherwise it wraps a `tf.function` and converts
        every output with `.numpy()`.
    """
    strategy = model._distribution_strategy

    with strategy.scope():
        replica_step = _make_replica_execution_function(model, mode)

        def distributed_function(input_fn):
            """Run a single step on all replicas and unwrap the outputs."""
            x, y, sample_weights = input_fn()
            # Each replica receives its own slice of the PerReplica
            # arguments; the results come back as PerReplica values as well.
            per_replica_outs = strategy.run(
                replica_step, args=(x, y, sample_weights)
            )
            # Reduce or pick the values to hand back to the caller.
            return unwrap_outputs(
                strategy,
                per_replica_outs,
                with_loss_tensor=(mode != ModeKeys.PREDICT),
            )

        if model.run_eagerly:
            return distributed_function

        traced_step = tf.function(distributed_function)

        def execution_function(input_fn):
            # `numpy` translates Tensors to values in Eager mode.
            return [out.numpy() for out in traced_step(input_fn)]

        return execution_function

935 

936 

def _make_replica_execution_function(model, mode):
    """Pick the per-replica `*_on_batch` callable for `mode`.

    Args:
        model: Keras model whose batch-level methods are used.
        mode: `ModeKeys` value selecting train/test/predict.

    Returns:
        A callable accepting `(x, y, sample_weights)`. For every mode other
        than `ModeKeys.PREDICT` it is a `functools.partial` that pins
        `reset_metrics=False`.
    """
    if mode == ModeKeys.TRAIN:
        step_fn = model.train_on_batch
    elif mode == ModeKeys.TEST:
        step_fn = model.test_on_batch
    else:

        def predict_on_batch(x, y=None, sample_weights=None):
            # Prediction ignores targets and sample weights; they are only
            # accepted so all modes share one call signature.
            del y, sample_weights
            return model.predict_on_batch(x)

        step_fn = predict_on_batch

    if mode == ModeKeys.PREDICT:
        return step_fn

    # `reset_metrics=False` keeps stateful metrics accumulating across the
    # batch-level calls.
    return functools.partial(step_fn, reset_metrics=False)

957 

958 

def _make_replicated_models_with_cloning(model, mode):
    """Create the per-replica models for `mode`.

    Uses `clone_model_on_replicas` when `model._compile_distribution` is
    truthy and `_build_distributed_network` otherwise.

    Args:
        model: Keras model carrying `_distribution_strategy`.
        mode: `ModeKeys` value the replicated models are built for.
    """
    strategy = model._distribution_strategy
    if not model._compile_distribution:
        _build_distributed_network(model, strategy, mode)
        return
    clone_model_on_replicas(model, strategy, mode)

968 

969 

def _make_execution_function_with_cloning(model, mode):
    """Clone or re-use models to run one step of distributed execution.

    The execution function is cached on the distributed model under
    `_distributed_function`. It is re-used unless the distributed model's
    `_recompile_exec_function` flag is set, which signals that the
    sample_weight_mode changed and the function must be rebuilt.

    Args:
        model: Keras model being distributed by cloning.
        mode: `ModeKeys` value selecting train/test/predict.

    Returns:
        The distributed execution function for `mode`.
    """
    distributed_model = get_distributed_model(model, mode)
    # TODO(b/134069401): Create a cache for the distributed model and exec
    # function that incorporates additional attributes to be part of the cache
    # key than just the mode.
    #
    # The attribute probed here must be the one assigned at the bottom of
    # this function (`_distributed_function`); probing a differently spelled
    # name means the cache is never hit and the function is rebuilt each time.
    if (
        distributed_model
        and hasattr(distributed_model, "_distributed_function")
        and not getattr(distributed_model, "_recompile_exec_function", False)
    ):
        return distributed_model._distributed_function

    if not distributed_model:
        _make_replicated_models_with_cloning(model, mode)
        distributed_model = get_distributed_model(model, mode)
    assert distributed_model

    # Also create an execution function on that distributed model.
    if tf.executing_eagerly():
        distributed_function = _make_eager_execution_function(model, mode)
    else:
        distributed_function = _make_graph_execution_function(model, mode)

    # We cache the distributed execution function on the model since creating
    # distributed models and execution functions are expensive.
    distributed_model._distributed_function = distributed_function
    distributed_model._recompile_exec_function = False
    return distributed_function

1008 

1009 

def _make_graph_execution_function(model, mode):
    """Build the function running one distributed step in graph mode.

    Args:
        model: Keras model carrying `_distribution_strategy`.
        mode: `ModeKeys` value selecting train/test/predict.

    Returns:
        The `backend.function` assembled from the unwrapped per-replica
        inputs, outputs, updates and session arguments.
    """

    def _replica_step_parts(replica_model):
        step_fn = replica_model._make_execution_function(mode)
        return (
            step_fn.inputs,
            step_fn.outputs,
            step_fn.updates_op,
            step_fn.session_kwargs,
        )

    strategy = model._distribution_strategy
    with strategy.scope():
        # Calling the per-replica function creates the ops on each device.
        per_replica_parts = strategy.extended.call_for_each_replica(
            _replica_step_parts, args=(get_distributed_model(model, mode),)
        )
        (
            grouped_inputs,
            grouped_outputs,
            grouped_updates,
            grouped_session_args,
        ) = per_replica_parts

        # Initialize the variables in the replicated model. This is necessary
        # for multi-worker training because on some workers, initialization is
        # not needed. This call either initializes or waits for
        # initialization, according to the distribute coordinator context.
        init_restore_or_wait_for_variables()

        # Unwrapping the per-device values yields flat lists from which a
        # single function spanning all devices can be constructed.
        unwrapped = unwrap_values(
            strategy,
            grouped_inputs,
            grouped_outputs,
            grouped_updates,
            grouped_session_args,
            with_loss_tensor=(mode != ModeKeys.PREDICT),
        )
        all_inputs, all_outputs, all_updates, all_session_args = unwrapped

        return backend.function(
            all_inputs,
            all_outputs,
            updates=all_updates,
            name=f"distributed_{mode}_function",
            **all_session_args,
        )

1063 

1064 

def _make_eager_execution_function(model, mode):
    """Build the function running one distributed step under eager execution.

    Args:
        model: Keras model carrying `_distribution_strategy`.
        mode: `ModeKeys` value selecting train/test/predict.

    Returns:
        The `backend.function` assembled from the unwrapped per-replica
        inputs and outputs.
    """

    def _replica_step_io(replica_model):
        step_fn = replica_model._make_execution_function(mode)
        return (step_fn.inputs, step_fn.outputs)

    # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of
    # using the global one.
    strategy = model._distribution_strategy
    global_graph = backend.get_graph()

    with global_graph.as_default(), strategy.scope():
        # Gather the relevant portions of the model across all replicas.
        # `backend._scratch_graph(global_graph)` signals to Keras that it
        # should not lift to a separate graph when creating the per-replica
        # functions.
        with backend._scratch_graph(global_graph):
            # Calling the per-replica function creates the ops on each device.
            grouped_inputs, grouped_outputs = (
                strategy.extended.call_for_each_replica(
                    _replica_step_io,
                    args=(get_distributed_model(model, mode),),
                )
            )

            # Unwrapping the per-device values yields flat lists from which a
            # single function spanning all devices can be constructed.
            unwrapped = unwrap_values(
                strategy,
                grouped_inputs,
                grouped_outputs,
                with_loss_tensor=(mode != ModeKeys.PREDICT),
            )
            (all_inputs, all_outputs, _, _) = unwrapped

        # Finally, a joint Keras function is created; this one will be created
        # in a separate FuncGraph.
        return backend.function(
            all_inputs,
            all_outputs,
            name=f"eager_distributed_{mode}_function",
        )

1110 

1111 

def _copy_weights_to_distributed_model(original_model, mode):
    """Push the original model's weights onto its distributed model.

    Does nothing beyond the cache lookup when the model has no distribution
    strategy.

    Args:
        original_model: Keras model providing the weights.
        mode: `ModeKeys` value selecting which distributed model to update.
    """
    strategy = original_model._distribution_strategy
    distributed_model = get_distributed_model(original_model, mode)
    if not strategy:
        return

    source_weights = original_model.get_weights()
    # `set_weights` is handed the first unwrapped replica model.
    primary_replica = strategy.unwrap(distributed_model)[0]
    set_weights(strategy, primary_replica, source_weights)

1122 

1123 

1124def _copy_weights_to_original_model(model, mode): 

1125 """Copies weights from first distributed model back to original model.""" 

1126 if model._distribution_strategy and mode == ModeKeys.TRAIN: 

1127 distributed_model = get_distributed_model(model, mode) 

1128 updated_weights = model._distribution_strategy.unwrap( 

1129 distributed_model 

1130 )[0].get_weights() 

1131 model.set_weights(updated_weights) 

1132 

1133 

1134def _per_replica_aggregate_batch(strategy, batch_outs, model, mode): 

1135 """Aggregates the per-replica batch-level outputs from a distributed 

1136 step.""" 

1137 if strategy is not None and mode == ModeKeys.PREDICT: 

1138 total_batch_outs = [] 

1139 for i in range(len(model.outputs)): 

1140 num_replicas = strategy.num_replicas_in_sync 

1141 nested_outs = batch_outs[ 

1142 i * num_replicas : i * num_replicas + num_replicas 

1143 ] 

1144 total_batch_outs.append( 

1145 concat_along_batch_dimension(tf.nest.flatten(nested_outs)) 

1146 ) 

1147 return total_batch_outs 

1148 return batch_outs 

1149 

1150 

1151def _reset_metrics(model): 

1152 if model._distribution_strategy: 

1153 for mode in [ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT]: 

1154 distributed_model = get_distributed_model(model, mode) 

1155 if distributed_model: 

1156 first_model = model._distribution_strategy.unwrap( 

1157 distributed_model 

1158 )[0] 

1159 first_model.reset_metrics() 

1160 

1161 

def get_distributed_model(model, mode):
    """Return the distributed model cached on `model` for `mode`, or `None`."""
    return model._distributed_model_cache.get(_generate_cache_key(mode))

1165 

1166 

def set_distributed_model(model, mode, distributed_model):
    """Cache `distributed_model` on `model` under the key for `mode`."""
    cache_key = _generate_cache_key(mode)
    model._distributed_model_cache[cache_key] = distributed_model

1170 

1171 

def get_distributed_function(model, mode):
    """Return the function cached on `model` for `mode`, or `None`."""
    return model._distributed_function_cache.get(_generate_cache_key(mode))

1175 

1176 

def set_distributed_function(model, mode, distributed_function):
    """Cache `distributed_function` on `model` under the key for `mode`."""
    cache_key = _generate_cache_key(mode)
    model._distributed_function_cache[cache_key] = distributed_function

1180 

1181 

1182def _generate_cache_key(mode): 

1183 key = hash(mode) 

1184 return key 

1185 

1186 

@tf_contextlib.contextmanager
def distributed_scope(strategy, learning_phase):
    """Enter `strategy.scope()` and then the backend learning-phase scope.

    Args:
        strategy: Distribution strategy whose scope is entered first.
        learning_phase: Value passed to `backend.learning_phase_scope`.

    Yields:
        Nothing; control returns to the caller with both scopes active.
    """
    with strategy.scope():
        with backend.learning_phase_scope(learning_phase):
            yield

1191 

1192 

def is_current_worker_chief():
    """Return the `is_chief` attribute of the current worker context."""
    worker_context = dc.get_current_worker_context()
    return worker_context.is_chief

1195 

1196 

def filter_distributed_callbacks(callbacks_list, model):
    """Filter Callbacks based on the worker context when running multi-worker.

    Args:
        callbacks_list: A list of `Callback` instances. `None` (or any other
            falsy value) is treated as an empty list.
        model: Keras model instance.

    Returns:
        The list of `Callback` instances that should be run on this worker.
        The chief worker gets the full list; other workers get only the
        callbacks whose `_chief_worker_only` attribute is falsy.

    Raises:
        ValueError: If `model` is not in multi-worker mode.
    """

    if not model._in_multi_worker_mode():
        raise ValueError(
            "filter_distributed_callbacks() should only be called when Keras "
            "is in multi worker mode."
        )

    # After this normalization `callbacks_list` can never be `None`, so no
    # later `is None` check is needed.
    callbacks_list = callbacks_list or []
    has_checkpoint = any(
        isinstance(c, callbacks.ModelCheckpoint) for c in callbacks_list
    )
    if not has_checkpoint:
        # TODO(rchao): Consider providing a ModelCheckpoint here if the user
        # fails to (possibly with tempfile directory).
        logging.warning(
            "ModelCheckpoint callback is not provided. "
            "Workers will need to restart training if any fails."
        )

    if is_current_worker_chief():
        return callbacks_list

    # Some Callbacks should only run on the chief worker.
    return [
        callback
        for callback in callbacks_list
        if not callback._chief_worker_only
    ]

1234 

1235 

def _update_sample_weight_modes(model, mode, sample_weights):
    """Update sample_weight_mode of the distributed model.

    Only acts on models distributed by cloning. The distributed model is
    created when missing and its `_recompile_exec_function` flag is set to
    whether any training endpoint reports a sample-weight mismatch.

    Args:
        model: Keras model being distributed.
        mode: `ModeKeys` value selecting which distributed model to update.
        sample_weights: Falsy, or a sequence whose first element holds the
            per-replica sample weights.
    """
    if not is_distributing_by_cloning(model):
        return

    distributed_model = get_distributed_model(model, mode)
    if not distributed_model:
        _make_replicated_models_with_cloning(model, mode)
        distributed_model = get_distributed_model(model, mode)

    mismatches = [
        endpoint.sample_weights_mismatch()
        for endpoint in model._training_endpoints
    ]
    distributed_model._recompile_exec_function = any(mismatches)

    if not sample_weights:
        return

    replica_models = flatten_per_replica_values(
        model._distribution_strategy, distributed_model
    )
    # sample_weights is a tuple of 1 list where the number of elements in the
    # list is equal to the number of replicas in sync.
    per_replica_weights = sample_weights[0]
    if per_replica_weights and None not in per_replica_weights:
        for replica_model, weight in zip(replica_models, per_replica_weights):
            replica_model._update_sample_weight_modes(sample_weights=[weight])

1257 

1258 

def concat_along_batch_dimension(outputs):
    """Concatenate prediction outputs along the batch dimension.

    The concatenation routine is chosen from the type of the first element.

    Args:
        outputs: Non-empty sequence of per-replica outputs.

    Returns:
        The result of `tf.sparse.concat` for sparse tensors, `tf.concat` for
        ragged tensors, and `np.concatenate` for anything else.
    """
    first = outputs[0]
    if isinstance(first, tf.SparseTensor):
        return tf.sparse.concat(axis=0, sp_inputs=outputs)
    if isinstance(first, tf.RaggedTensor):
        return tf.concat(outputs, axis=0)
    return np.concatenate(outputs)

1266