Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/distribute/distributed_training_utils_v1.py: 18% of 451 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities related to distributed training.""" 

16# pylint:disable=protected-access 

17 

18import functools 

19 

20import numpy as np 

21 

22from tensorflow.python.data.ops import dataset_ops 

23from tensorflow.python.data.ops import iterator_ops 

24from tensorflow.python.distribute import reduce_util 

25from tensorflow.python.eager import context 

26from tensorflow.python.eager import def_function 

27from tensorflow.python.framework import dtypes 

28from tensorflow.python.framework import ops 

29from tensorflow.python.framework import sparse_tensor 

30from tensorflow.python.framework import tensor_util 

31from tensorflow.python.keras import backend 

32from tensorflow.python.keras import callbacks 

33from tensorflow.python.keras import metrics as metrics_module 

34from tensorflow.python.keras import optimizers 

35from tensorflow.python.keras.distribute import distribute_coordinator_utils as dc 

36from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils 

37from tensorflow.python.keras.engine import training_utils_v1 

38from tensorflow.python.keras.optimizer_v2 import optimizer_v2 

39from tensorflow.python.keras.utils import tf_contextlib 

40from tensorflow.python.keras.utils.mode_keys import ModeKeys 

41from tensorflow.python.ops import array_ops 

42from tensorflow.python.ops import control_flow_ops 

43from tensorflow.python.ops import math_ops 

44from tensorflow.python.ops import sparse_ops 

45from tensorflow.python.ops import variable_v1 

46from tensorflow.python.ops.ragged import ragged_tensor 

47from tensorflow.python.platform import tf_logging as logging 

48from tensorflow.python.util import nest 

49 

50 

51def set_weights(distribution_strategy, dist_model, weights): 

52 """Sets the weights of the replicated models. 

53 

54 The weights of the replicated models are set to the weights of the original 

55 model. The weights of the replicated model are Mirrored variables and hence 

56 we need to use the `update` call within a DistributionStrategy scope. 

57 

58 Args: 

59 distribution_strategy: DistributionStrategy used to distribute training 

60 and validation. 

61 dist_model: The replicated models on the different devices. 

62 weights: The weights of the original model. 

63 """ 

64 assign_ops = [] 

65 for layer in dist_model.layers: 

66 num_param = len(layer.weights) 

67 layer_weights = weights[:num_param] 

68 for sw, w in zip(layer.weights, layer_weights): 

69 if ops.executing_eagerly_outside_functions(): 

70 sw.assign(w) 

71 else: 

72 assign_ops.append(distribution_strategy.unwrap(sw.assign(w))) 

73 weights = weights[num_param:] 

74 

75 if not ops.executing_eagerly_outside_functions(): 

76 backend.get_session(assign_ops).run(assign_ops) 

77 
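# Example: a minimal sketch of the per-layer weight slicing above, written
# with public TF 2.x Keras APIs in eager mode. `original` and `replica` are
# hypothetical stand-ins for the user's model and one replicated model.
import numpy as np
import tensorflow as tf

original = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,)),
                                tf.keras.layers.Dense(2)])
replica = tf.keras.models.clone_model(original)
replica.build((None, 8))

flat_weights = original.get_weights()   # [kernel0, bias0, kernel1, bias1]
for layer in replica.layers:
  num_param = len(layer.weights)
  # Consume the flat weight list layer by layer and assign onto the replica's
  # variables -- the eager branch of `set_weights` above.
  for var, value in zip(layer.weights, flat_weights[:num_param]):
    var.assign(value)
  flat_weights = flat_weights[num_param:]

np.testing.assert_allclose(replica.get_weights()[0], original.get_weights()[0])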

78 

79def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs, 

80 grouped_updates=None, grouped_session_args=None, 

81 with_loss_tensor=False): 

82 """Unwrap the list of values contained in the PerReplica parameters. 

83 

84 This function calls `flatten_per_replica_values` to parse each of the input 

85 parameters into a list of values on the different devices. If we set 

86 `with_loss_tensor` to be True, we also call `reduce` on the list of losses on 

87 the different devices to give us one loss tensor. 

88 

89 Args: 

90 distribution_strategy: DistributionStrategy used to distribute training and 

91 validation. 

92 grouped_inputs: PerReplica inputs returned from the train or test function 

93 that we ran on each device. 

94 grouped_outputs: PerReplica outputs returned from the train or test function 

95 that we ran on each device. 

96 grouped_updates: PerReplica updates returned from the train or test function 

97 that we ran on each device. 

98 grouped_session_args: PerReplica session args returned from the train or 

99 test function that we ran on each device. 

100 with_loss_tensor: Boolean that indicates if we need to add the reduced loss 

101 tensor as one of the outputs. 

102 

103 Returns: 

104 Values of each of the PerReplica parameters. 

105 

106 """ 

107 # Unwrap per device values returned from each model's train function. 

108 # This will be used to construct the main train function. 

109 all_inputs = flatten_per_replica_values(distribution_strategy, 

110 grouped_inputs) 

111 all_outputs = unwrap_outputs(distribution_strategy, grouped_outputs, 

112 with_loss_tensor) 

113 

114 if grouped_updates: 

115 all_updates = flatten_per_replica_values(distribution_strategy, 

116 grouped_updates) 

117 else: 

118 all_updates = None 

119 

120 all_session_args = {} 

121 if grouped_session_args: 

122 grouped_feed_dict = grouped_session_args.get('feed_dict') 

123 if grouped_feed_dict: 

124 all_session_args['feed_dict'] = flatten_per_replica_values( 

125 distribution_strategy, grouped_feed_dict) 

126 

127 grouped_fetches = grouped_session_args.get('fetches') 

128 if grouped_fetches: 

129 all_session_args['fetches'] = flatten_per_replica_values( 

130 distribution_strategy, grouped_fetches) 

131 

132  # TODO(priyag): Return only non-empty/non-None values

133 return all_inputs, all_outputs, all_updates, all_session_args 

134 

135 

136def unwrap_output_dict(strategy, grouped_outputs, mode): 

137 """Unwrap the list of outputs contained in the PerReplica parameters.""" 

138 if mode == ModeKeys.PREDICT: 

139 return flatten_per_replica_values(strategy, grouped_outputs) 

140 

141  # In the case of fit/eval, grouped_outputs is a dict, whereas in predict the

142  # output has the same structure as the model output, so the two cases need

143  # to be treated differently.

144 total_loss = strategy.reduce(reduce_util.ReduceOp.SUM, 

145 grouped_outputs['total_loss'][0], axis=None) 

146 output_losses = flatten_per_replica_values(strategy, 

147 grouped_outputs['output_losses']) 

148 metrics = flatten_per_replica_values(strategy, 

149 grouped_outputs['metrics']) 

150 batch_size = strategy.reduce(reduce_util.ReduceOp.SUM, 

151 grouped_outputs['batch_size'], axis=None) 

152 if (backend.is_tpu_strategy(strategy) and 

153 ops.executing_eagerly_outside_functions()): 

154 # Choose 1 value per replica in the TPU case since all replicas produce the 

155 # same output. 

156    # We only do this in eager mode for now, since this function is used in

157    # both graph and eager mode; the graph path does not use experimental_run

158    # yet, so this slicing would need to be removed once the graph code path

159    # converges as well.

160 output_losses = output_losses[::strategy.num_replicas_in_sync] 

161 metrics = metrics[::strategy.num_replicas_in_sync] 

162 return {'total_loss': [total_loss], 

163 'output_losses': output_losses, 

164 'metrics': metrics, 

165 'batch_size': batch_size} 

166 
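# Example: a tiny illustration of the TPU de-duplication above. With two
# replicas producing identical values, keeping every
# `num_replicas_in_sync`-th element leaves exactly one value per output.
num_replicas_in_sync = 2
output_losses = [0.5, 0.5, 1.2, 1.2]          # two outputs, duplicated per replica
metrics = [0.9, 0.9]
print(output_losses[::num_replicas_in_sync])  # [0.5, 1.2]
print(metrics[::num_replicas_in_sync])        # [0.9]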

167 

168def unwrap_outputs(distribution_strategy, grouped_outputs, 

169 with_loss_tensor=False): 

170 """Unwrap the list of outputs contained in the PerReplica parameters. 

171 

172 This function calls `flatten_per_replica_values` to parse each of the input 

173 parameters into a list of outputs on the different devices. If we set 

174 `with_loss_tensor` to be True, we also call `reduce` on the list of losses on 

175 the different devices to give us one loss tensor. 

176 

177 Args: 

178 distribution_strategy: DistributionStrategy used to distribute training and 

179 validation. 

180 grouped_outputs: PerReplica outputs returned from the train or test function 

181 that we ran on each device. 

182 with_loss_tensor: Boolean that indicates if we need to add the reduced loss 

183 tensor as one of the outputs. 

184 

185 Returns: 

186 Values of each of the PerReplica outputs. 

187 

188 """ 

189 if not with_loss_tensor: 

190 return flatten_per_replica_values(distribution_strategy, 

191 grouped_outputs) 

192 

193 if not isinstance(grouped_outputs, list): 

194 grouped_outputs = [grouped_outputs] 

195 # reduce loss tensor before adding it to the list of fetches 

196 loss = distribution_strategy.reduce(reduce_util.ReduceOp.SUM, 

197 grouped_outputs[0], axis=None) 

198 all_outputs = flatten_per_replica_values(distribution_strategy, 

199 grouped_outputs[1:]) 

200 if (backend.is_tpu_strategy(distribution_strategy) and 

201 ops.executing_eagerly_outside_functions()): 

202 # Choose 1 value per replica in the TPU case since all replicas produce the 

203 # same output. 

204    # We only do this in eager mode for now, since this function is used in

205    # both graph and eager mode; the graph path does not use experimental_run

206    # yet, so this slicing would need to be removed once the graph code path

207    # converges as well.

208 all_outputs = all_outputs[::distribution_strategy.num_replicas_in_sync] 

209 return [loss] + all_outputs 

210 
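# Example: a hedged sketch of the same unwrap-and-reduce pattern with public
# TF 2.x APIs instead of the internal helpers above: the per-replica loss is
# summed across replicas, and the remaining per-replica outputs are flattened
# into a plain list. `step_fn` and `dist_x` are illustrative names.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step_fn(x):
  loss = tf.reduce_sum(x)
  predictions = x * 2.0
  return loss, predictions

dist_x = strategy.experimental_distribute_values_from_function(
    lambda ctx: tf.ones([2]) * float(ctx.replica_id_in_sync_group + 1))
per_replica_loss, per_replica_preds = strategy.run(step_fn, args=(dist_x,))

loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
all_preds = list(strategy.experimental_local_results(per_replica_preds))
print(loss.numpy(), [p.numpy() for p in all_preds])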

211 

212def flatten_per_replica_values(distribution_strategy, per_replica_values): 

213 """Unwraps and flattens a nest of PerReplica parameters. 

214 

215 PerReplica values have one value associated with each device. Each entry in 

216 the PerReplica dict has a device `key` and the corresponding value on the 

217 device as the `value`. In this function we take a PerReplica value or a list 

218 of PerReplica values and return all the values in the PerReplica dict. 

219 

220 Args: 

221 distribution_strategy: DistributionStrategy used to distribute training and 

222 validation. 

223 per_replica_values: List of PerReplica object or a single PerReplica object. 

224 

225 Returns: 

226 List of values of all the PerReplica objects. 

227 

228 """ 

229 # pylint: disable=g-complex-comprehension 

230 # This function takes a PerReplica object or a list of PerReplica objects and 

231 # returns all the values associated with it. 

232 return [e for flattened in nest.flatten(per_replica_values) 

233 for e in distribution_strategy.unwrap(flattened)] 

234 

235 

236def validate_callbacks(input_callbacks, optimizer): 

237 """Validate whether given callbacks are supported by DistributionStrategy. 

238 

239 Args: 

240 input_callbacks: List of callbacks passed by the user to fit. 

241 optimizer: Optimizer instance used to train the model. 

242 

243 Raises: 

244    ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of

245      the callbacks passed while the optimizer is not a Keras Optimizer V2.

246      (`write_grads` on the TensorBoard callback does not raise; it only logs

247      a warning and is forced to `False`.)

248 """ 

249 if input_callbacks: 

250 for callback in input_callbacks: 

251 if isinstance(callback, (callbacks.LearningRateScheduler, 

252 callbacks.ReduceLROnPlateau)): 

253 

254 if not isinstance(optimizer, optimizer_v2.OptimizerV2): 

255 raise ValueError('You must specify a Keras Optimizer V2 when using ' 

256 '%s callback with DistributionStrategy.' % callback) 

257 

258 # If users want to use the TensorBoard callback they cannot use certain 

259 # features of the callback that involve accessing model attributes and 

260 # running ops. 

261 if isinstance(callback, callbacks.TensorBoard): 

262 if getattr(callback, 'write_grads', False): 

263 logging.warning( 

264 UserWarning( 

265 '`write_grads` in the TensorBoard callback is not supported ' 

266 'when using DistributionStrategy. Setting `write_grads` ' 

267 'to `False`.')) 

268 callback.write_grads = False 

269 

270 

271def validate_distributed_dataset_inputs(distribution_strategy, x, y, 

272 sample_weights=None): 

273 """Validate all the components of a DistributedValue Dataset input. 

274 

275 Args: 

276 distribution_strategy: The current DistributionStrategy used to call 

277 `fit`/`evaluate`. 

278 x: Input Dataset DistributedValue object. For example, when we use 

279 `MirroredStrategy` this is a PerReplica object with a tensor for each 

280 device set in the dict. x can also be a tuple or dict. The keys of the 

281 dict should match the names of the input layers of the model. 

282 y: Target Dataset DistributedValue object. For example, when we use 

283 `MirroredStrategy` this is a PerReplica object with a tensor for each 

284 device set in the dict. y can also be a tuple or dict. The keys of the 

285 dict should match the names of the output layers of the model. 

286 sample_weights: Sample weights Dataset DistributedValue object. For example, 

287 when we use `MirroredStrategy` this is a PerReplica object with a tensor 

288 for each device set in the dict. 

289 

290 Returns: 

291 The unwrapped values list of the x and y DistributedValues inputs. 

292 

293 Raises: 

294 ValueError: If x and y do not have support for being evaluated as tensors. 

295 or if x and y contain elements that are not tensors or if x and y 

296 contain elements that have a shape or dtype mismatch. 

297 """ 

298 # If the input and target used to call the model are not dataset tensors, 

299 # we need to raise an error. When using a DistributionStrategy, the input 

300 # and targets to a model should be from a `tf.data.Dataset`. 

301 

302 # If each element of x and y are not tensors, we cannot standardize and 

303 # validate the input and targets. 

304 x_values_list = validate_per_replica_inputs(distribution_strategy, x) 

305 

306 if y is not None: 

307 y_values_list = validate_per_replica_inputs(distribution_strategy, y) 

308 else: 

309 y_values_list = None 

310 

311 if sample_weights is not None: 

312 sample_weights_list = validate_per_replica_inputs(distribution_strategy, 

313 sample_weights) 

314 else: 

315 sample_weights_list = None 

316 

317 # Return the unwrapped values to avoid calling `unwrap` a second time. 

318 return x_values_list, y_values_list, sample_weights_list 

319 

320 

321def validate_per_replica_inputs(distribution_strategy, x): 

322 """Validates PerReplica dataset input list. 

323 

324 Args: 

325 distribution_strategy: The current DistributionStrategy used to call 

326 `fit`, `evaluate` and `predict`. 

327 x: A list of PerReplica objects that represent the input or 

328 target values. 

329 

330 Returns: 

331 List containing the first element of each of the PerReplica objects in 

332 the input list. 

333 

334 Raises: 

335 ValueError: If any of the objects in the `per_replica_list` is not a tensor. 

336 

337 """ 

338 # Convert the inputs and targets into a list of PerReplica objects. 

339 per_replica_list = nest.flatten(x, expand_composites=True) 

340 x_values_list = [] 

341 for x in per_replica_list: 

342 # At this point x should contain only tensors. 

343 x_values = distribution_strategy.unwrap(x) 

344 for value in x_values: 

345 if not tensor_util.is_tf_type(value): 

346        raise ValueError('Dataset inputs to the model should be tensors; '

347                         'instead they are of type {}'.format(type(value)))

348 

349 if not context.executing_eagerly(): 

350 # Validate that the shape and dtype of all the elements in x are the same. 

351 validate_all_tensor_shapes(x, x_values) 

352 validate_all_tensor_types(x, x_values) 

353 

354 x_values_list.append(x_values[0]) 

355 return x_values_list 

356 
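# Example: a sketch (public TF 2.x APIs) of pulling the per-replica components
# out of a distributed dataset and checking that every local value is a
# tensor, mirroring the validation above for the v1 code path.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 4])).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

per_replica_x = next(iter(dist_dataset))
for component in tf.nest.flatten(per_replica_x, expand_composites=True):
  for value in strategy.experimental_local_results(component):
    assert tf.is_tensor(value), type(value)
print('all per-replica components are tensors')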

357 

358def validate_all_tensor_types(x, x_values): 

359 x_dtype = x_values[0].dtype 

360 for i in range(1, len(x_values)): 

361 if x_dtype != x_values[i].dtype: 

362 raise ValueError('Input tensor dtypes do not match for distributed tensor' 

363 ' inputs {}'.format(x)) 

364 

365 

366def validate_all_tensor_shapes(x, x_values): 

367  # Validate that all the elements in x have the same shape

368 x_shape = x_values[0].shape.as_list() 

369 for i in range(1, len(x_values)): 

370 if x_shape != x_values[i].shape.as_list(): 

371 raise ValueError('Input tensor shapes do not match for distributed tensor' 

372 ' inputs {}'.format(x)) 

373 

374 

375def _wait_for_variable_initialization(session): 

376 """Utility to wait for variables to be initialized.""" 

377 all_variables = backend._get_variables(backend.get_graph()) # pylint: disable=protected-access 

378 candidate_vars = [] 

379 for v in all_variables: 

380 if not getattr(v, '_keras_initialized', False): 

381 candidate_vars.append(v) 

382 

383 if not candidate_vars: 

384 return 

385 

386 while True: 

387 is_initialized = session.run( 

388 [variable_v1.is_variable_initialized(v) for v in candidate_vars]) 

389 uninitialized_vars = [] 

390 for flag, v in zip(is_initialized, candidate_vars): 

391 if not flag: 

392 uninitialized_vars.append(v) 

393 v._keras_initialized = True # pylint: disable=protected-access 

394 if not uninitialized_vars: 

395 break 

396 

397 

398def init_restore_or_wait_for_variables(): 

399 """Initialize or restore variables or wait for variables to be initialized.""" 

400 backend._initialize_variables(backend._get_session()) # pylint: disable=protected-access 

401 

402 

403def validate_inputs(x, y): 

404 """Validate inputs when using DistributionStrategy. 

405 

406 Args: 

407 x: Model Inputs. 

408 y: Model Targets. 

409 

410 Raises: 

411    ValueError: if `x` or `y` is an Iterator. Inputs must be a

412      `tf.data.Dataset` or a numpy array (e.g. when using MirroredStrategy).

413 """ 

414 if (isinstance(x, iterator_ops.Iterator) or 

415 isinstance(y, iterator_ops.Iterator)): 

416 raise ValueError('`DistributionStrategy` does not support inputs of type ' 

417 'Iterator. You must pass a `tf.data.Dataset` object or a ' 

418 'numpy array as input.') 

419 

420 

421def is_dataset_shape_fully_defined(dataset): 

422  """Returns whether a dataset's shapes are fully defined (no partial batch)."""

423 shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)) 

424 unknown_shapes = [s for s in shapes if not s.is_fully_defined()] 

425 return not unknown_shapes 

426 
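# Example: a small sketch of the check above. A batched dataset only has a
# fully defined batch dimension when `drop_remainder=True`, i.e. when no
# partial final batch can be produced.
import tensorflow as tf

ds_partial = tf.data.Dataset.range(10).batch(3)                   # last batch has 1 element
ds_full = tf.data.Dataset.range(10).batch(3, drop_remainder=True)

def shapes_fully_defined(dataset):
  return all(spec.shape.is_fully_defined()
             for spec in tf.nest.flatten(dataset.element_spec))

print(shapes_fully_defined(ds_partial))  # False: a partial batch is possible
print(shapes_fully_defined(ds_full))     # True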

427 

428def process_batch_and_step_size(strategy, 

429 inputs, 

430 batch_size, 

431 steps_per_epoch, 

432 mode, 

433 validation_split=0.): 

434 """Process the batch size and step size based on input and dist strategy.""" 

435 first_x_value = nest.flatten(inputs)[0] 

436 if isinstance(first_x_value, np.ndarray): 

437 num_samples = first_x_value.shape[0] 

438 if validation_split and 0. < validation_split < 1.: 

439 num_samples = int(num_samples * (1 - validation_split)) 

440 # Until support for partial batch is implemented across all 

441 # functions and distribution strategy, we pass `mode` to selectively 

442 # relax the constraint to consume all the training samples. 

443 steps_per_epoch, batch_size = get_input_params( 

444 strategy, num_samples, steps_per_epoch, batch_size, mode=mode) 

445 return batch_size, steps_per_epoch 

446 

447 

448def get_input_params(distribution_strategy, 

449 num_samples, 

450 steps, 

451 batch_size, 

452 mode=None): 

453 """Calculate the number of batches and steps/steps_per_epoch. 

454 

455 Args: 

456 distribution_strategy: The DistributionStrategy used to compile the model. 

457 num_samples: The number of samples from which we determine the batch size 

458 and steps. 

459 steps: The specified number of steps. 

460 batch_size: The specified batch_size. 

461    mode: ModeKey representing whether input will be used for training,

462      evaluation, or prediction. This is used to relax the constraint of

463      consuming all the training samples, for compatibility until partial

464      batches are supported. If None, partial batches are not allowed.

465 

466 Returns: 

467    steps: The steps or steps_per_epoch argument, depending on whether the

468      user is calling `fit`, `evaluate` or `predict`. If partial batches are

469      allowed for the given mode, the samples need not be consumed completely.

470 batch_size: The batch size to be used in model iterations. 

471 

472 Raises: 

473 ValueError: If the number of batches or steps evaluates to 0. 

474 

475 """ 

476 # TODO(b/118776054): Use global batch size for Keras/DS support. 

477 # Currently this is only supported in TPUStrategy and CoreMirroredStrategy. 

478 use_per_replica_batch = not dist_utils.global_batch_size_supported( 

479 distribution_strategy) 

480 

481 # TODO(b/128995245): In eager mode, uneven batch sizes are allowed except for 

482 # `fit()` on TPUStrategy. 

483 # In graph mode, the zero batch case in batch norm is not handled due to 

484 # XLA-GPU regression. Uneven batch sizes are not allowed except 

485 # for `test()` and `predict()` on TPUStrategy. 

486 if context.executing_eagerly(): 

487 allow_partial_batch = ( 

488 mode != ModeKeys.TRAIN or 

489 not backend.is_tpu_strategy(distribution_strategy)) 

490 else: 

491 allow_partial_batch = ( 

492 mode == ModeKeys.TRAIN or 

493 ((mode == ModeKeys.PREDICT or mode == ModeKeys.TEST) and 

494 backend.is_tpu_strategy(distribution_strategy))) 

495 

496 if steps is None: 

497 if batch_size is None: 

498      # If neither the batch size nor the number of steps is set, we choose

499      # the global batch size as the minimum of the number of samples and 32;

500      # 32 is chosen for backward compatibility.

501 global_batch_size = min(num_samples, 32) 

502 else: 

503 # If the user provided the batch size we need to handle the case 

504 # between different strategies that use the global/per-replica batch size 

505 global_batch_size = batch_size 

506 if use_per_replica_batch: 

507 global_batch_size *= distribution_strategy.num_replicas_in_sync 

508 if allow_partial_batch: 

509 steps = np.ceil(num_samples / global_batch_size).astype(int) 

510 else: 

511 if num_samples % global_batch_size: 

512 raise ValueError('The number of samples %s is not divisible by ' 

513 'batch size %s.' % (num_samples, global_batch_size)) 

514 steps = num_samples // global_batch_size 

515 else: 

516 if batch_size is None: 

517 # We calculate the batch size based on the number of steps specified 

518 if num_samples % steps: 

519 raise ValueError('The number of samples %s is not divisible by ' 

520 'steps %s. Please change the number of steps to a ' 

521 'value that can consume all the samples' % ( 

522 num_samples, steps)) 

523 global_batch_size = num_samples // steps 

524 else: 

525 # If the user provided the batch size we need to handle the case 

526 # between different strategies that use the global/per-replica batch size 

527 global_batch_size = batch_size 

528 if use_per_replica_batch: 

529 global_batch_size *= distribution_strategy.num_replicas_in_sync 

530 

531 min_num_samples = global_batch_size * steps 

532 if allow_partial_batch: 

533 min_num_samples = global_batch_size * (steps-1) + 1 if steps > 1 else 0 

534 

535 if num_samples < min_num_samples: 

536 raise ValueError('Number of samples %s is less than samples required ' 

537 'for specified batch_size %s and steps %s' % ( 

538 num_samples, global_batch_size, steps)) 

539 

540 # We need to return the per replica or global batch size based on the strategy 

541 if use_per_replica_batch: 

542 if global_batch_size % distribution_strategy.num_replicas_in_sync: 

543 raise ValueError( 

544 'The batch size (%s) could not be sharded evenly across the sync ' 

545 'replicas (%s) in the distribution strategy.' % ( 

546 global_batch_size, distribution_strategy.num_replicas_in_sync)) 

547 batch_size = global_batch_size // distribution_strategy.num_replicas_in_sync 

548 else: 

549 batch_size = global_batch_size 

550 

551 return steps, batch_size 

552 
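# Example: the batch/steps arithmetic above worked through with hypothetical
# numbers, assuming a strategy with 4 replicas that uses per-replica batch
# sizes (so the global batch size is batch_size * num_replicas_in_sync).
import numpy as np

num_samples, batch_size, num_replicas_in_sync = 1000, 64, 4
global_batch_size = batch_size * num_replicas_in_sync              # 256

# Partial batches allowed (e.g. eager predict): round the step count up.
steps_partial_ok = int(np.ceil(num_samples / global_batch_size))   # 4

# Partial batches not allowed: num_samples must divide evenly, otherwise the
# function above raises a ValueError.
divides_evenly = num_samples % global_batch_size == 0              # False

print(global_batch_size, steps_partial_ok, divides_evenly)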

553 

554def get_batch_dimension(iterator): 

555 shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(iterator)) 

556 # Take the batch size from the first element, as it should be the same for 

557 # all. 

558 dims = shapes[0].dims 

559 return dims[0] if dims else None 

560 

561 

562def get_iterator(dataset, distribution_strategy): 

563 with distribution_strategy.scope(): 

564 iterator = distribution_strategy.make_dataset_iterator(dataset) 

565 initialize_iterator(iterator, distribution_strategy) 

566 return iterator 

567 

568 

569def initialize_iterator(iterator, distribution_strategy): 

570 with distribution_strategy.scope(): 

571 init_op = control_flow_ops.group(iterator.initializer) 

572 if not context.executing_eagerly(): 

573 backend.get_session((init_op,)).run(init_op) 

574 

575 

576def _get_input_from_iterator(iterator, model): 

577 """Get elements from the iterator and verify the input shape and type.""" 

578 next_element = iterator.get_next() 

579 

580  # `len(nest.flatten(x))` does not count empty elements such as {}:

581  # len(nest.flatten([[0, 1, 2], {}])) is 3, not 4. `next_element` is flattened

582  # in `_prepare_feed_values` to work around that; empty elements are filtered

583  # out as part of the flattening.

584 if len(nest.flatten(next_element)) == len(model.inputs): 

585 x = next_element 

586 y = None 

587 sample_weights = None 

588 elif len(nest.flatten(next_element)) == (len(model.inputs) + 

589 len(model.outputs)): 

590 x, y = next_element 

591 sample_weights = None 

592 else: 

593 x, y, sample_weights = next_element 

594 

595 # Validate that all the elements in x and y are of the same type and shape. 

596 validate_distributed_dataset_inputs( 

597 model._distribution_strategy, x, y, sample_weights) 

598 return x, y, sample_weights 

599 
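# Example: a hedged sketch of how the element structure is interpreted above.
# The number of flattened tensors decides whether the iterator yields only
# inputs, (inputs, targets), or (inputs, targets, sample_weights);
# `num_model_inputs`/`num_model_outputs` stand in for len(model.inputs) and
# len(model.outputs).
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    (tf.ones([8, 4]), tf.zeros([8, 1]))).batch(2)
element = next(iter(ds))

num_model_inputs, num_model_outputs = 1, 1
flat_len = len(tf.nest.flatten(element))
if flat_len == num_model_inputs:
  x, y, sample_weights = element, None, None
elif flat_len == num_model_inputs + num_model_outputs:
  x, y = element
  sample_weights = None
else:
  x, y, sample_weights = element
print(x.shape, y.shape, sample_weights)   # (2, 4) (2, 1) None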

600 

601def _prepare_feed_values(model, inputs, targets, sample_weights, mode): 

602 """Prepare feed values to the model execution function. 

603 

604 Args: 

605 model: Model to prepare feed values for. 

606 inputs: List or dict of model inputs. 

607 targets: Optional list of model targets. 

608 sample_weights: Optional list of sample weight arrays. 

609 mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. 

610 

611 Returns: 

612 Feed values for the model in the given mode. 

613 """ 

614 strategy = model._distribution_strategy 

615 inputs, targets, sample_weights = _get_input_from_iterator(inputs, model) 

616 if backend.is_tpu_strategy(strategy): 

617 if sample_weights is not None: 

618 raise ValueError('TPUStrategy does not support sample weights.') 

619 

620  # When the inputs are a dict, flatten them in the same order as the input

621  # layers so that the data is fed into the input layers in the correct

622  # order.

623 if isinstance(inputs, dict): 

624 inputs = [inputs[key] for key in model._feed_input_names] 

625 if is_distributing_by_cloning(model): 

626 inputs = flatten_per_replica_values(strategy, inputs) 

627 targets = flatten_per_replica_values(strategy, targets) 

628 # Expand 1-dimensional inputs. 

629    # TODO(b/124535720): Remove once this standardize-data logic is shared

630    # with the main flow.

631 inputs, targets = nest.map_structure( 

632 training_utils_v1.standardize_single_array, (inputs, targets)) 

633 else: 

634 inputs = training_utils_v1.ModelInputs(inputs).as_list() 

635 

636 if mode == ModeKeys.PREDICT: 

637 sample_weights = [] 

638 targets = [] 

639 elif sample_weights is not None and is_distributing_by_cloning(model): 

640 if context.executing_eagerly() and not model._compile_distribution: 

641 raise NotImplementedError('`sample_weight` is not supported when using ' 

642 'tf.distribute.Strategy in eager mode and ' 

643 'cloning=True.') 

644 sample_weights = flatten_per_replica_values(strategy, sample_weights) 

645 

646 ins = [inputs, targets, sample_weights] 

647 return tuple(ins) 

648 

649 

650def is_distributing_by_cloning(model): 

651 """Decide whether this model is going to be distributed via cloning. 

652 

653 We are going to distribute the model by cloning in graph mode. 

654 

655 Args: 

656 model: Keras model to distribute. 

657 

658 Returns: 

659 True if the `model` is going to be distributed using cloning and False 

660 otherwise. 

661 """ 

662 if (backend.is_tpu_strategy(model._distribution_strategy) and 

663 context.executing_eagerly): # b/137580852 

664 return False 

665 elif ops.executing_eagerly_outside_functions(): 

666 return bool(model._compile_distribution) 

667 return True 

668 

669 

670def _custom_compile_for_predict(model): 

671 """Custom compile for TPU predict mode.""" 

672 if not model.built: 

673 # Model is not compilable because it does not know its number of inputs 

674 # and outputs, nor their shapes and names. We will compile after the first 

675 # time the model gets called on training data. 

676 return 

677 model._is_compiled = True 

678 model.total_loss = None 

679 model.train_function = None 

680 model.test_function = None 

681 model.predict_function = None 

682 

683 

684def _build_network_on_replica(model, mode, inputs=None, targets=None): 

685 """Build an updated model on replicas. 

686 

687 We create a new Keras model while sharing the variables from the old graph. 

688  Building a new sub-graph is required since the original Keras model creates

689  placeholders for the input and the output that are not accessible until we

690  call iterator.get_next() inside the step_fn for `fit`/`evaluate`/`predict`.

691

692  The sharing of weights and layers between the old and the new model

693  guarantees that we're using Strategy variables and that any updates on

694  either model are reflected correctly in callbacks and loop iterations.

695 

696 We need to make sure we share the optimizers between the old and the new model 

697 as well so that optimizer state is not lost if the user is running fit 

698 multiple times. 

699 

700 Args: 

701 model: Model to be replicated across Replicas 

702 mode: Which of fit/eval/predict is building the distributed network 

703 inputs: Input variables to be passed to the model 

704 targets: Target tensor to be passed to model.compile 

705 

706 Returns: 

707 A new model with shared layers with the old model. 

708 """ 

709 # Need to do imports here since we run into a circular dependency error. 

710 from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top 

711 from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top 

712 

713  # We rely on the internal methods to avoid exposing `share_weights` in the

714  # public API.

715 if isinstance(model, sequential.Sequential): 

716 updated_model = models._clone_sequential_model( 

717 model, input_tensors=inputs, layer_fn=models.share_weights) 

718 else: 

719 updated_model = models._clone_functional_model( 

720 model, input_tensors=inputs, layer_fn=models.share_weights) 

721 # Callable losses added directly to a functional Model need to be added 

722 # here. 

723 updated_model._callable_losses = model._callable_losses 

724 

725  # Recast all low-precision outputs back to float32 since we only cast the

726  # inputs to bfloat16, not the targets. This is done so that we can preserve

727  # precision when calculating the loss value.

728 def _upcast_low_precision_outputs(output): 

729 if output.dtype == dtypes.bfloat16: 

730 return math_ops.cast(output, dtypes.float32) 

731 else: 

732 return output 

733 updated_model.outputs = [_upcast_low_precision_outputs(o) 

734 for o in updated_model.outputs] 

735 

736 if isinstance(targets, tuple): 

737 targets = nest.flatten(targets) 

738 

739 if mode == ModeKeys.PREDICT and inputs is not None: # TPU predict case 

740 _custom_compile_for_predict(updated_model) 

741 else: 

742 updated_model.compile( 

743 model.optimizer, 

744 model.loss, 

745 metrics=metrics_module.clone_metrics(model._compile_metrics), 

746 loss_weights=model.loss_weights, 

747 sample_weight_mode=model.sample_weight_mode, 

748 weighted_metrics=metrics_module.clone_metrics( 

749 model._compile_weighted_metrics), 

750 target_tensors=targets) 

751 return updated_model 

752 

753 

754def _build_distributed_network(model, strategy, mode, inputs=None, 

755 targets=None): 

756 """Create a cloned model on each replica.""" 

757 with backend.get_graph().as_default(), strategy.scope(): 

758 distributed_model = strategy.extended.call_for_each_replica( 

759 _build_network_on_replica, 

760 args=(model, mode, inputs, targets)) 

761 set_distributed_model(model, mode, distributed_model) 

762 

763 

764def _clone_and_build_model(model, mode, inputs=None, targets=None): 

765 """Clone and build the given keras_model.""" 

766 # We need to set the import here since we run into a circular dependency 

767 # error. 

768 from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top 

769 cloned_model = models.clone_model(model, input_tensors=inputs) 

770 

771 # Compile and build model. 

772 if isinstance(model.optimizer, optimizers.TFOptimizer): 

773 optimizer = model.optimizer 

774 else: 

775 optimizer_config = model.optimizer.get_config() 

776 optimizer = model.optimizer.__class__.from_config(optimizer_config) 

777 

778  # Recast all low-precision outputs back to float32 since we only cast the

779  # inputs to bfloat16, not the targets. This is done so that we can preserve

780  # precision when calculating the loss value.

781 def _upcast_low_precision_outputs(output): 

782 if output.dtype == dtypes.bfloat16: 

783 return math_ops.cast(output, dtypes.float32) 

784 else: 

785 return output 

786 cloned_model.outputs = [_upcast_low_precision_outputs(o) 

787 for o in cloned_model.outputs] 

788 

789 if isinstance(targets, tuple): 

790 targets = nest.flatten(targets) 

791 if mode == ModeKeys.PREDICT and inputs is not None: # TPU predict case 

792 _custom_compile_for_predict(cloned_model) 

793 else: 

794 cloned_model.compile( 

795 optimizer, 

796 model.loss, 

797 metrics=metrics_module.clone_metrics(model._compile_metrics), 

798 loss_weights=model.loss_weights, 

799 sample_weight_mode=model.sample_weight_mode, 

800 weighted_metrics=metrics_module.clone_metrics( 

801 model._compile_weighted_metrics), 

802 target_tensors=targets) 

803 return cloned_model 

804 

805 

806def clone_model_on_replicas(model, strategy, mode, inputs=None, targets=None): 

807 """Create a cloned model on each replica.""" 

808 with backend.get_graph().as_default(), strategy.scope(): 

809 distributed_model = strategy.extended.call_for_each_replica( 

810 _clone_and_build_model, args=(model, mode, inputs, targets)) 

811 set_distributed_model(model, mode, distributed_model) 

812 if mode == ModeKeys.TRAIN: 

813 model._make_callback_model(distributed_model) 

814 

815 

816def _make_execution_function(model, mode): 

817 """Makes or reuses function to run one step of distributed model execution.""" 

818 if is_distributing_by_cloning(model): 

819 return _make_execution_function_with_cloning(model, mode) 

820 

821 distributed_function = get_distributed_function(model, mode) 

822 if distributed_function: 

823 return distributed_function 

824 

825 distribution_function = _make_execution_function_without_cloning(model, mode) 

826 set_distributed_function(model, mode, distribution_function) 

827 return distribution_function 

828 

829 

830def _make_execution_function_without_cloning(model, mode): 

831 """Creates a function to run one step of distributed model execution.""" 

832 strategy = model._distribution_strategy 

833 

834 with strategy.scope(): 

835 per_replica_function = _make_replica_execution_function(model, mode) 

836 

837 def distributed_function(input_fn): 

838 """A single step of the distributed execution across replicas.""" 

839 x, y, sample_weights = input_fn() 

840 # Call `Model.{train,test,predict}_on_batch` on every replica passing 

841 # PerReplicas as arguments. On every replica inside this call, each 

842 # PerReplica object will return the value for that replica. The outputs 

843 # are PerReplicas too. 

844 outputs = strategy.run(per_replica_function, args=(x, y, sample_weights)) 

845 # Out of PerReplica outputs reduce or pick values to return. 

846 all_outputs = unwrap_outputs( 

847 strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT)) 

848 return all_outputs 

849 

850 if not model.run_eagerly: 

851 distributed_function = def_function.function(distributed_function) 

852 def execution_function(input_fn): 

853 # `numpy` translates Tensors to values in Eager mode. 

854 return [out.numpy() for out in distributed_function(input_fn)] 

855 else: 

856 execution_function = distributed_function 

857 

858 return execution_function 

859 
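# Example: the standard TF 2.x pattern that the no-cloning path above mirrors,
# written as a sketch with public APIs (not the internal Keras function
# itself): wrap the per-replica step in `tf.function`, run it on every replica
# with `strategy.run`, and reduce the per-replica losses. The global batch
# size of 8 is an assumption of this sketch.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  optimizer = tf.keras.optimizers.SGD()
  loss_fn = tf.keras.losses.MeanSquaredError(
      reduction=tf.keras.losses.Reduction.SUM)

@tf.function
def distributed_step(x, y):
  def step_fn(x, y):
    with tf.GradientTape() as tape:
      loss = loss_fn(y, model(x, training=True)) / 8.0  # scale by global batch
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
  per_replica_loss = strategy.run(step_fn, args=(x, y))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

print(distributed_step(tf.ones([8, 4]), tf.zeros([8, 1])).numpy())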

860 

861def _make_replica_execution_function(model, mode): 

862 """A single step of the distributed execution on a replica.""" 

863 if mode == ModeKeys.TRAIN: 

864 func = model.train_on_batch 

865 elif mode == ModeKeys.TEST: 

866 func = model.test_on_batch 

867 else: 

868 

869 def predict_on_batch(x, y=None, sample_weights=None): 

870 del y, sample_weights 

871 return model.predict_on_batch(x) 

872 

873 func = predict_on_batch 

874 

875 if mode != ModeKeys.PREDICT: 

876 # `reset_metrics` is set to False to maintain stateful metrics across 

877 # batch-level calls. 

878 func = functools.partial(func, reset_metrics=False) 

879 

880 return func 

881 

882 

883def _make_replicated_models_with_cloning(model, mode): 

884 """Build models on each replica.""" 

885 strategy = model._distribution_strategy 

886 

887 # If distributed_model is not built, create one for `mode`. 

888 if model._compile_distribution: 

889 clone_model_on_replicas(model, strategy, mode) 

890 else: 

891 _build_distributed_network(model, strategy, mode) 

892 

893 

894def _make_execution_function_with_cloning(model, mode): 

895 """Clones or re-uses models to run one step of distributed model execution.""" 

896 distributed_model = get_distributed_model(model, mode) 

897  # TODO(b/134069401): Create a cache for the distributed model and exec

898  # function that incorporates additional attributes in the cache key beyond

899  # just the mode.

900 # If distributed model for a particular `mode` is already built, use the 

901  # `_distributed_function` on that distributed model.

902 # If you have updated the sample_weight_mode on the model, then you will need 

903 # to recompile metrics and recreate the execution function. This is indicated 

904 # by the `_recompile_exec_function` property. 

905  if (distributed_model and hasattr(distributed_model, '_distributed_function')

906 and not (hasattr(distributed_model, '_recompile_exec_function') and 

907 distributed_model._recompile_exec_function)): 

908 return distributed_model._distributed_function 

909 

910 if not distributed_model: 

911 _make_replicated_models_with_cloning(model, mode) 

912 distributed_model = get_distributed_model(model, mode) 

913 assert distributed_model 

914 

915 # Also create an execution function on that distributed model. 

916 if context.executing_eagerly(): 

917 distributed_function = _make_eager_execution_function(model, mode) 

918 else: 

919 distributed_function = _make_graph_execution_function(model, mode) 

920 

921 # We cache the distributed execution function on the model since creating 

922 # distributed models and execution functions are expensive. 

923 distributed_model._distributed_function = distributed_function 

924 distributed_model._recompile_exec_function = False 

925 return distributed_function 

926 

927 

928def _make_graph_execution_function(model, mode): 

929 """Makes function to run one step of distributed model in graph mode.""" 

930 

931 def _per_replica_function(model): 

932 f = model._make_execution_function(mode) 

933 return (f.inputs, f.outputs, f.updates_op, f.session_kwargs) 

934 

935 strategy = model._distribution_strategy 

936 with strategy.scope(): 

937    # Create train ops on each of the devices when we call

938    # `_per_replica_function`.

939 (grouped_inputs, grouped_outputs, grouped_updates, 

940 grouped_session_args) = strategy.extended.call_for_each_replica( 

941 _per_replica_function, args=(get_distributed_model(model, mode),)) 

942 

943 # Initialize the variables in the replicated model. This is necessary for 

944 # multi-worker training because on some workers, initialization is not 

945 # needed. This method does initialization or waiting for initialization 

946 # according to the context object of distribute coordinator. 

947 init_restore_or_wait_for_variables() 

948 

949 # Unwrap all the per device values returned from `call_for_each_replica`. 

950 # Unwrapping per device values gives you a list of values that can be 

951 # used to construct a new train function that is composed of update ops on 

952 # all the devices over which the model is distributed. 

953 (all_inputs, all_outputs, all_updates, all_session_args) = unwrap_values( 

954 strategy, 

955 grouped_inputs, 

956 grouped_outputs, 

957 grouped_updates, 

958 grouped_session_args, 

959 with_loss_tensor=(mode != ModeKeys.PREDICT)) 

960 

961 return backend.function( 

962 all_inputs, 

963 all_outputs, 

964 updates=all_updates, 

965 name='distributed_{}_function'.format(mode), 

966 **all_session_args) 

967 

968 

969def _make_eager_execution_function(model, mode): 

970 """Makes function to run one step of distributed model eager execution.""" 

971 def _per_replica_function(model): 

972 f = model._make_execution_function(mode) 

973 return (f.inputs, f.outputs) 

974 

975 # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using 

976 # the global one. 

977 strategy = model._distribution_strategy 

978 global_graph = backend.get_graph() 

979 

980 with global_graph.as_default(), strategy.scope(): 

981 # First we gather the relevant portions of the model across all replicas. 

982 # `backend._scratch_graph(global_graph)` signals to Keras that it should not 

983 # lift to a separate graph when creating the per-replica functions. 

984 with backend._scratch_graph(global_graph): 

985      # Create train ops on each of the devices when we call

986      # `_per_replica_function`.

987 grouped = strategy.extended.call_for_each_replica( 

988 _per_replica_function, args=(get_distributed_model(model, mode),)) 

989 grouped_inputs, grouped_outputs = grouped 

990 

991 # Unwrap all the per device values returned from `call_for_each_replica`. 

992 # Unwrapping per device values gives you a list of values that can be 

993 # used to construct a new train function that is composed of 

994 # inputs/outputs on all the devices over which the model is distributed. 

995 (all_inputs, all_outputs, _, _) = unwrap_values( 

996 strategy, 

997 grouped_inputs, 

998 grouped_outputs, 

999 with_loss_tensor=(mode != ModeKeys.PREDICT)) 

1000 

1001 # Finally, a joint Keras function is created; this one will be created in 

1002 # a separate FuncGraph. 

1003 return backend.function( 

1004 all_inputs, 

1005 all_outputs, 

1006 name='eager_distributed_{}_function'.format(mode)) 

1007 

1008 

1009def _copy_weights_to_distributed_model(original_model, mode): 

1010 """Copies weights from original model to distributed models.""" 

1011 strategy = original_model._distribution_strategy 

1012 distributed_model = get_distributed_model(original_model, mode) 

1013 if strategy: 

1014 # Copy the weights from the original model to each of the replicated 

1015 # models. 

1016 orig_model_weights = original_model.get_weights() 

1017 first_model = strategy.unwrap(distributed_model)[0] 

1018 set_weights(strategy, first_model, orig_model_weights) 

1019 

1020 

1021def _copy_weights_to_original_model(model, mode): 

1022 """Copies weights from first distributed model back to original model.""" 

1023 if model._distribution_strategy and mode == ModeKeys.TRAIN: 

1024 distributed_model = get_distributed_model(model, mode) 

1025 updated_weights = model._distribution_strategy.unwrap( 

1026 distributed_model)[0].get_weights() 

1027 model.set_weights(updated_weights) 

1028 

1029 

1030def _per_replica_aggregate_batch(strategy, batch_outs, model, mode): 

1031 """Aggregates the per-replica batch-level outputs from a distributed step.""" 

1032 if strategy is not None and mode == ModeKeys.PREDICT: 

1033 total_batch_outs = [] 

1034 for i in range(len(model.outputs)): 

1035 num_replicas = strategy.num_replicas_in_sync 

1036 nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas] 

1037 total_batch_outs.append( 

1038 concat_along_batch_dimension(nest.flatten(nested_outs))) 

1039 return total_batch_outs 

1040 return batch_outs 

1041 

1042 

1043def _reset_metrics(model): 

1044 if model._distribution_strategy: 

1045 for mode in [ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT]: 

1046 distributed_model = get_distributed_model(model, mode) 

1047 if distributed_model: 

1048 first_model = model._distribution_strategy.unwrap(distributed_model)[0] 

1049 first_model.reset_metrics() 

1050 

1051 

1052def get_distributed_model(model, mode): 

1053 key = _generate_cache_key(mode) 

1054 return model._distributed_model_cache.get(key, None) 

1055 

1056 

1057def set_distributed_model(model, mode, distributed_model): 

1058 key = _generate_cache_key(mode) 

1059 model._distributed_model_cache[key] = distributed_model 

1060 

1061 

1062def get_distributed_function(model, mode): 

1063 key = _generate_cache_key(mode) 

1064 return model._distributed_function_cache.get(key, None) 

1065 

1066 

1067def set_distributed_function(model, mode, distributed_function): 

1068 key = _generate_cache_key(mode) 

1069 model._distributed_function_cache[key] = distributed_function 

1070 

1071 

1072def _generate_cache_key(mode): 

1073 key = hash(mode) 

1074 return key 

1075 
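# Example: a toy sketch of the per-mode caching scheme used by the helpers
# above, with a hypothetical stand-in object instead of a Keras model.
class _FakeModel(object):

  def __init__(self):
    self._distributed_model_cache = {}

toy_model = _FakeModel()
toy_model._distributed_model_cache[hash('train')] = 'replicated train model'
print(toy_model._distributed_model_cache.get(hash('train')))    # cache hit
print(toy_model._distributed_model_cache.get(hash('predict')))  # None: build lazily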

1076 

1077@tf_contextlib.contextmanager 

1078def distributed_scope(strategy, learning_phase): 

1079 with strategy.scope(), backend.learning_phase_scope(learning_phase): 

1080 yield 

1081 

1082 

1083def is_current_worker_chief(): 

1084 return dc.get_current_worker_context().is_chief 

1085 

1086 

1087def filter_distributed_callbacks(callbacks_list, model): 

1088 """Filter Callbacks based on the worker context when running multi-worker. 

1089 

1090 Args: 

1091 callbacks_list: A list of `Callback` instances. 

1092 model: Keras model instance. 

1093 

1094 Returns: 

1095 The list of `Callback` instances that should be run on this worker. 

1096 """ 

1097 

1098 if not model._in_multi_worker_mode(): 

1099 raise ValueError( 

1100 'filter_distributed_callbacks() should only be called when Keras ' 

1101 'is in multi worker mode.') 

1102 

1103 callbacks_list = callbacks_list or [] 

1104 if not [ 

1105 c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint) 

1106 ]: 

1107    # TODO(rchao): Consider providing a ModelCheckpoint here if the user

1108    # fails to provide one (possibly with a tempfile directory).

1109 logging.warning('ModelCheckpoint callback is not provided. ' 

1110 'Workers will need to restart training if any fails.') 

1111 

1112 if callbacks_list is None or is_current_worker_chief(): 

1113 return callbacks_list 

1114 

1115 # Some Callbacks should only run on the chief worker. 

1116 return [ 

1117 callback for callback in callbacks_list if not callback._chief_worker_only 

1118 ] # pylint: disable=protected-access 

1119 
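# Example: a self-contained sketch of chief-only callback filtering, using
# stub classes in place of Keras callbacks; `_chief_worker_only` mirrors the
# private flag consulted above.
class _StubCallback(object):

  def __init__(self, name, chief_worker_only):
    self.name = name
    self._chief_worker_only = chief_worker_only

callback_stubs = [_StubCallback('ModelCheckpoint', chief_worker_only=True),
                  _StubCallback('History', chief_worker_only=False)]

def filter_for_worker(cbs, is_chief):
  if is_chief:
    return cbs
  return [c for c in cbs if not c._chief_worker_only]

print([c.name for c in filter_for_worker(callback_stubs, is_chief=False)])  # ['History']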

1120 

1121def _update_sample_weight_modes(model, mode, sample_weights): 

1122 """Update sample_weight_mode of the distributed model.""" 

1123 if is_distributing_by_cloning(model): 

1124 distributed_model = get_distributed_model(model, mode) 

1125 if not distributed_model: 

1126 _make_replicated_models_with_cloning(model, mode) 

1127 distributed_model = get_distributed_model(model, mode) 

1128 distributed_model._recompile_exec_function = any( 

1129 [e.sample_weights_mismatch() for e in model._training_endpoints]) 

1130 

1131 if sample_weights: 

1132 distributed_models = flatten_per_replica_values( 

1133 model._distribution_strategy, distributed_model) 

1134 # sample_weights is a tuple of 1 list where the number of elements in the 

1135 # list is equal to the number of replicas in sync. 

1136 sample_weights = sample_weights[0] 

1137 if sample_weights and None not in sample_weights: 

1138 for m, sw in zip(distributed_models, sample_weights): 

1139 m._update_sample_weight_modes(sample_weights=[sw]) 

1140 

1141 

1142def concat_along_batch_dimension(outputs): 

1143 """Concats prediction outputs along the batch dimension.""" 

1144 if isinstance(outputs[0], sparse_tensor.SparseTensor): 

1145 return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs) 

1146 if isinstance(outputs[0], ragged_tensor.RaggedTensor): 

1147 return array_ops.concat(outputs, axis=0) 

1148 return np.concatenate(outputs)
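# Example: a minimal sketch of batch-dimension concatenation for per-replica
# predict outputs, covering the dense numpy case plus the sparse variant
# handled above (public TF 2.x APIs).
import numpy as np
import tensorflow as tf

replica_outputs = [np.arange(6).reshape(2, 3), np.arange(6, 12).reshape(2, 3)]
print(np.concatenate(replica_outputs).shape)   # (4, 3): replicas stacked on batch axis

sp = tf.sparse.from_dense(tf.eye(2))
print(tf.sparse.concat(axis=0, sp_inputs=[sp, sp]).dense_shape.numpy())  # [4 2]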