Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_strategy.py: 25%

421 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Class MirroredStrategy implementing tf.distribute.Strategy.""" 

16 

17import copy 

18 

19from tensorflow.python import tf2 

20from tensorflow.python.distribute import collective_util 

21from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib 

22from tensorflow.python.distribute import cross_device_utils 

23from tensorflow.python.distribute import device_util 

24from tensorflow.python.distribute import distribute_lib 

25from tensorflow.python.distribute import distribute_utils 

26from tensorflow.python.distribute import input_lib 

27from tensorflow.python.distribute import input_util 

28from tensorflow.python.distribute import mirrored_run 

29from tensorflow.python.distribute import multi_worker_util 

30from tensorflow.python.distribute import numpy_dataset 

31from tensorflow.python.distribute import reduce_util 

32from tensorflow.python.distribute import values 

33from tensorflow.python.distribute import values_util 

34from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver 

35from tensorflow.python.distribute.v1 import input_lib as input_lib_v1 

36from tensorflow.python.eager import context 

37from tensorflow.python.eager import record 

38from tensorflow.python.framework import config 

39from tensorflow.python.framework import constant_op 

40from tensorflow.python.framework import device as tf_device 

41from tensorflow.python.framework import dtypes 

42from tensorflow.python.framework import ops 

43from tensorflow.python.ops import array_ops 

44from tensorflow.python.ops import control_flow_ops 

45from tensorflow.python.ops import control_flow_util 

46from tensorflow.python.ops import while_loop 

47from tensorflow.python.platform import tf_logging as logging 

48from tensorflow.python.util import nest 

49from tensorflow.python.util.tf_export import tf_export 

50 

51# TODO(josh11b): Replace asserts in this file with if ...: raise ... 

52 

53 

54def _is_device_list_single_worker(devices): 

55 """Checks whether the devices list is for single or multi-worker. 

56 

57 Args: 

58 devices: a list of device strings or tf.config.LogicalDevice objects, for 

59 either local or for remote devices. 

60 

61 Returns: 

62 a boolean indicating whether these device strings are for local or for 

63 remote. 

64 

65 Raises: 

66 ValueError: if device strings are not consistent. 

67 """ 

68 specs = [] 

69 for d in devices: 

70 name = d.name if isinstance(d, context.LogicalDevice) else d 

71 specs.append(tf_device.DeviceSpec.from_string(name)) 

72 num_workers = len({(d.job, d.task, d.replica) for d in specs}) 

73 all_local = all(d.job in (None, "localhost") for d in specs) 

74 any_local = any(d.job in (None, "localhost") for d in specs) 

75 

76 if any_local and not all_local: 

77 raise ValueError("Local device should have only 'localhost' in the job " 

78 "field in device string. " 

79 "E.g. 'job:localhost' in " 

80 "/job:localhost/replica:0/task:0/device:CPU:0" 

81 "Devices cannot have mixed list of device strings " 

82 "containing both localhost and other job types such as " 

83 "worker, ps etc. ") 

84 

85 if num_workers == 1 and not all_local: 

86 if any(d.task is None for d in specs): 

87 raise ValueError("Remote device string must have task specified." 

88 "E.g. 'task:0' in " 

89 "/job:worker/replica:0/task:0/device:CPU:0") 

90 

91 return num_workers == 1 

92 
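# A minimal illustrative sketch of how the helper above classifies device
# lists; the device strings are hypothetical examples.
#
#   >>> local = ["/job:localhost/replica:0/task:0/device:GPU:0",
#   ...          "/job:localhost/replica:0/task:0/device:GPU:1"]
#   >>> remote = ["/job:worker/replica:0/task:0/device:GPU:0",
#   ...           "/job:worker/replica:0/task:1/device:GPU:0"]
#   >>> _is_device_list_single_worker(local)    # one (job, task, replica) group
#   True
#   >>> _is_device_list_single_worker(remote)   # two distinct worker tasks
#   False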

93 

94def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker): 

95 """Returns a device list given a cluster spec.""" 

96 cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) 

97 devices = [] 

98 for task_type in ("chief", "worker"): 

99 for task_id in range(len(cluster_spec.as_dict().get(task_type, []))): 

100 if num_gpus_per_worker == 0: 

101 devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id)) 

102 else: 

103 devices.extend([ 

104 "/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id) 

105 for gpu_id in range(num_gpus_per_worker) 

106 ]) 

107 return devices 

108 
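# A minimal illustrative sketch of the mapping performed above; the cluster
# layout is a hypothetical example with one chief, one worker, and two GPUs
# per worker.
#
#   >>> spec = {"chief": ["host0:2222"], "worker": ["host1:2222"]}
#   >>> _cluster_spec_to_device_list(spec, num_gpus_per_worker=2)
#   ['/job:chief/task:0/device:GPU:0', '/job:chief/task:0/device:GPU:1',
#    '/job:worker/task:0/device:GPU:0', '/job:worker/task:0/device:GPU:1']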

109 

110def _group_device_list(devices): 

111 """Groups the devices list by task_type and task_id. 

112 

113 Args: 

114 devices: a list of device strings for remote devices. 

115 

116 Returns: 

117 a dict mapping each task_type to a list of device lists, one per task_id

118 in ascending order of task_id.

119 """ 

120 assert not _is_device_list_single_worker(devices) 

121 device_dict = {} 

122 

123 for d in devices: 

124 d_spec = tf_device.DeviceSpec.from_string(d) 

125 

126 # Create an entry for the task_type. 

127 if d_spec.job not in device_dict: 

128 device_dict[d_spec.job] = [] 

129 

130 # Fill the device list for task_type until it covers the task_id. 

131 while len(device_dict[d_spec.job]) <= d_spec.task: 

132 device_dict[d_spec.job].append([]) 

133 

134 device_dict[d_spec.job][d_spec.task].append(d) 

135 

136 return device_dict 

137 
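# A short illustrative sketch of the grouping above, using hypothetical
# remote device strings for two worker tasks:
#
#   >>> devices = ["/job:worker/task:0/device:GPU:0",
#   ...            "/job:worker/task:0/device:GPU:1",
#   ...            "/job:worker/task:1/device:GPU:0",
#   ...            "/job:worker/task:1/device:GPU:1"]
#   >>> _group_device_list(devices)
#   {'worker': [['/job:worker/task:0/device:GPU:0',
#                '/job:worker/task:0/device:GPU:1'],
#               ['/job:worker/task:1/device:GPU:0',
#                '/job:worker/task:1/device:GPU:1']]}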

138 

139def _is_gpu_device(device): 

140 return tf_device.DeviceSpec.from_string(device).device_type == "GPU" 

141 

142 

143def _infer_num_gpus_per_worker(devices): 

144 """Infers the number of GPUs on each worker. 

145 

146 Currently to make multi-worker cross device ops work, we need all workers to 

147 have the same number of GPUs. 

148 

149 Args: 

150 devices: a list of device strings, can be either local devices or remote 

151 devices. 

152 

153 Returns: 

154 number of GPUs per worker. 

155 

156 Raises: 

157 ValueError: if workers have different numbers of GPUs, or if GPU indices

158 are not consecutive starting from 0.

159 """ 

160 if _is_device_list_single_worker(devices): 

161 return sum(1 for d in devices if _is_gpu_device(d)) 

162 else: 

163 device_dict = _group_device_list(devices) 

164 num_gpus = None 

165 for _, devices_in_task in device_dict.items(): 

166 for device_in_task in devices_in_task: 

167 if num_gpus is None: 

168 num_gpus = sum(1 for d in device_in_task if _is_gpu_device(d)) 

169 

170 # Verify other workers have the same number of GPUs. 

171 elif num_gpus != sum(1 for d in device_in_task if _is_gpu_device(d)): 

172 raise ValueError("All workers should have the same number of GPUs.") 

173 

174 for d in device_in_task: 

175 d_spec = tf_device.DeviceSpec.from_string(d) 

176 if (d_spec.device_type == "GPU" and 

177 d_spec.device_index >= num_gpus): 

178 raise ValueError("GPU `device_index` on a worker should be " 

179 "consecutive and start from 0.") 

180 return num_gpus 

181 

182 

183def all_local_devices(num_gpus=None): 

184 devices = config.list_logical_devices("GPU") 

185 if num_gpus is not None: 

186 devices = devices[:num_gpus] 

187 return devices or config.list_logical_devices("CPU") 

188 

189 

190def all_devices(): 

191 devices = [] 

192 tfconfig = TFConfigClusterResolver() 

193 if tfconfig.cluster_spec().as_dict(): 

194 devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(), 

195 context.num_gpus()) 

196 return devices if devices else all_local_devices() 

197 
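# A brief illustrative sketch, assuming no TF_CONFIG is set and the machine
# has two logical GPUs: all_devices() falls back to all_local_devices(),
# which returns the logical GPUs (or the CPUs if no GPU is present).
#
#   >>> [d.name for d in all_devices()]
#   ['/device:GPU:0', '/device:GPU:1']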

198 

199@tf_export("distribute.MirroredStrategy", v1=[]) # pylint: disable=g-classes-have-attributes 

200class MirroredStrategy(distribute_lib.Strategy): 

201 """Synchronous training across multiple replicas on one machine. 

202 

203 This strategy is typically used for training on one 

204 machine with multiple GPUs. For TPUs, use 

205 `tf.distribute.TPUStrategy`. To use `MirroredStrategy` with multiple workers, 

206 please refer to `tf.distribute.experimental.MultiWorkerMirroredStrategy`. 

207 

208 For example, a variable created under a `MirroredStrategy` is a 

209 `MirroredVariable`. If no devices are specified in the constructor argument of 

210 the strategy, then it will use all the available GPUs. If no GPUs are found, it 

211 will use the available CPUs. Note that TensorFlow treats all CPUs on a 

212 machine as a single device, and uses threads internally for parallelism. 

213 

214 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

215 >>> with strategy.scope(): 

216 ... x = tf.Variable(1.) 

217 >>> x 

218 MirroredVariable:{ 

219 0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>, 

220 1: <tf.Variable ... shape=() dtype=float32, numpy=1.0> 

221 } 

222 

223 While using distribution strategies, all the variable creation should be done 

224 within the strategy's scope. This will replicate the variables across all the 

225 replicas and keep them in sync using an all-reduce algorithm. 

226 

227 Variables created inside a function that is wrapped with `tf.function`, while 

228 under a `MirroredStrategy` scope, are still `MirroredVariables`. 

229 

230 >>> x = [] 

231 >>> @tf.function # Wrap the function with tf.function. 

232 ... def create_variable(): 

233 ... if not x: 

234 ... x.append(tf.Variable(1.)) 

235 ... return x[0] 

236 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

237 >>> with strategy.scope(): 

238 ... _ = create_variable() 

239 ... print(x[0]) 

240 MirroredVariable:{ 

241 0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>, 

242 1: <tf.Variable ... shape=() dtype=float32, numpy=1.0> 

243 } 

244 

245 `experimental_distribute_dataset` can be used to distribute the dataset across 

246 the replicas when writing your own training loop. If you are using `.fit` and 

247 `.compile` methods available in `tf.keras`, then `tf.keras` will handle the 

248 distribution for you. 

249 

250 For example: 

251 

252 ```python 

253 my_strategy = tf.distribute.MirroredStrategy() 

254 with my_strategy.scope(): 

255 @tf.function 

256 def distribute_train_epoch(dataset): 

257 def replica_fn(input): 

258 # process input and return result 

259 return result 

260 

261 total_result = 0 

262 for x in dataset: 

263 per_replica_result = my_strategy.run(replica_fn, args=(x,)) 

264 total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM, 

265 per_replica_result, axis=None) 

266 return total_result 

267 

268 dist_dataset = my_strategy.experimental_distribute_dataset(dataset) 

269 for _ in range(EPOCHS): 

270 train_result = distribute_train_epoch(dist_dataset) 

271 ``` 

272 

273 Args: 

274 devices: a list of device strings such as `['/gpu:0', '/gpu:1']`. If 

275 `None`, all available GPUs are used. If no GPUs are found, CPU is used. 

276 cross_device_ops: optional, a descendant of `CrossDeviceOps`. If this is not 

277 set, `NcclAllReduce()` will be used by default. One would customize this 

278 if NCCL isn't available or if a special implementation that exploits 

279 the particular hardware is available. 

280 """ 

281 

282 # Only set this in tests. 

283 _collective_key_base = 0 

284 

285 def __init__(self, devices=None, cross_device_ops=None): 

286 extended = MirroredExtended( 

287 self, devices=devices, cross_device_ops=cross_device_ops) 

288 super(MirroredStrategy, self).__init__(extended) 

289 distribute_lib.distribution_strategy_gauge.get_cell("V2").set( 

290 "MirroredStrategy") 

291 
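# A minimal usage sketch of the constructor arguments documented above,
# assuming a machine with at least two GPUs; `HierarchicalCopyAllReduce` is
# one of the public `tf.distribute.CrossDeviceOps` implementations that can
# be supplied when NCCL is not available.
#
#   >>> strategy = tf.distribute.MirroredStrategy(
#   ...     devices=["GPU:0", "GPU:1"],
#   ...     cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
#   >>> strategy.num_replicas_in_sync
#   2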

292 

293@tf_export(v1=["distribute.MirroredStrategy"]) 

294class MirroredStrategyV1(distribute_lib.StrategyV1): # pylint: disable=g-missing-docstring 

295 

296 __doc__ = MirroredStrategy.__doc__ 

297 

298 # Only set this in tests. 

299 _collective_key_base = 0 

300 

301 def __init__(self, devices=None, cross_device_ops=None): 

302 extended = MirroredExtended( 

303 self, devices=devices, cross_device_ops=cross_device_ops) 

304 super(MirroredStrategyV1, self).__init__(extended) 

305 distribute_lib.distribution_strategy_gauge.get_cell("V1").set( 

306 "MirroredStrategy") 

307 

308 

309# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1. 

310class MirroredExtended(distribute_lib.StrategyExtendedV1): 

311 """Implementation of MirroredStrategy.""" 

312 

313 def __init__(self, container_strategy, devices=None, cross_device_ops=None): 

314 super(MirroredExtended, self).__init__(container_strategy) 

315 if context.executing_eagerly(): 

316 if devices and not _is_device_list_single_worker(devices): 

317 raise RuntimeError("In-graph multi-worker training with " 

318 "`MirroredStrategy` is not supported in eager mode.") 

319 else: 

320 if TFConfigClusterResolver().cluster_spec().as_dict(): 

321 # In eager mode, only the single-machine code 

322 # path is supported. 

323 logging.info("Initializing local devices since in-graph multi-worker " 

324 "training with `MirroredStrategy` is not supported in " 

325 "eager mode. TF_CONFIG will be ignored when " 

326 "when initializing `MirroredStrategy`.") 

327 devices = devices or all_local_devices() 

328 else: 

329 devices = devices or all_devices() 

330 

331 assert devices, ("Got an empty `devices` list and unable to recognize " 

332 "any local devices.") 

333 

334 self._collective_key_base = container_strategy._collective_key_base 

335 self._communication_options = collective_util.Options( 

336 implementation=collective_util.CommunicationImplementation.NCCL) 

337 self._cross_device_ops = cross_device_ops 

338 self._initialize_strategy(devices) 

339 

340 # TODO(b/128995245): Enable last partial batch support in graph mode. 

341 if ops.executing_eagerly_outside_functions(): 

342 self.experimental_enable_get_next_as_optional = True 

343 

344 # Flag to turn on VariablePolicy. 

345 self._use_var_policy = False 

346 

347 def _use_merge_call(self): 

348 # We currently only disable merge_call when XLA is used to compile the `fn` 

349 # passed to `strategy.run` and all devices are GPU. 

350 return not control_flow_util.GraphOrParentsInXlaContext( 

351 ops.get_default_graph()) or not all( 

352 [_is_gpu_device(d) for d in self._devices]) 

353 

354 def _initialize_strategy(self, devices): 

355 # The _initialize_strategy method is intended to be used by distribute 

356 # coordinator as well. 

357 assert devices, "Must specify at least one device." 

358 devices = tuple(device_util.resolve(d) for d in devices) 

359 assert len(set(devices)) == len(devices), ( 

360 "No duplicates allowed in `devices` argument: %s" % (devices,)) 

361 

362 self._initialize_single_worker(devices) 

363 

364 self._collective_ops = self._make_collective_ops_with_fallbacks() 

365 # If cross_device_ops is not provided, set it to collective op by default. 

366 if not self._cross_device_ops: 

367 self._cross_device_ops = self._collective_ops 

368 

369 def _make_collective_ops_with_fallbacks(self): 

370 self._collective_keys = cross_device_utils.CollectiveKeys( 

371 group_key_start=1 + self._collective_key_base) 

372 

373 if not ops.executing_eagerly_outside_functions() and any( 

374 "gpu" not in d.lower() for d in self._devices): 

375 # In TF1/Session, fall back to ReductionToOneDevice() if there are 

376 # non-GPU devices or virtual GPUs are used. 

377 return cross_device_ops_lib.ReductionToOneDevice() 

378 

379 # Use ReductionToOneDevice() if mixed devices are used. 

380 if any("cpu" in d.lower() for d in self._devices) and any( 

381 "gpu" in d.lower() for d in self._devices): 

382 return cross_device_ops_lib.ReductionToOneDevice() 

383 

384 if all("cpu" in d.lower() for d in self._devices): 

385 # Use RING collective ops if all devices are CPU. 

386 self._communication_options = collective_util.Options( 

387 implementation=collective_util.CommunicationImplementation.RING) 

388 

389 else: 

390 physical_gpus = context.context().list_physical_devices(device_type="GPU") 

391 logical_gpus = context.context().list_logical_devices(device_type="GPU") 

392 # Use RING collective ops if virtual devices are used. 

393 if len(physical_gpus) < len(logical_gpus): 

394 self._communication_options = collective_util.Options( 

395 implementation=collective_util.CommunicationImplementation.RING) 

396 

397 # If all devices are physical GPU, use NCCL implementation. 

398 return cross_device_ops_lib.CollectiveAllReduce( 

399 devices=self._devices, 

400 group_size=len(self._devices), 

401 options=self._communication_options, 

402 collective_keys=self._collective_keys) 

403 
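  # A compact illustrative summary of the fallback logic above (assuming
  # eager TF2 and no virtual GPUs):
  #   all physical GPUs            -> CollectiveAllReduce with NCCL
  #   mixed CPU and GPU devices    -> ReductionToOneDevice
  #   all CPU devices              -> CollectiveAllReduce with RING
  #   TF1 Session with non-GPU     -> ReductionToOneDevice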

404 def _initialize_single_worker(self, devices): 

405 """Initializes the object for single-worker training.""" 

406 self._devices = tuple(device_util.canonicalize(d) for d in devices) 

407 self._input_workers_devices = ( 

408 (device_util.canonicalize("/device:CPU:0", devices[0]), devices),) 

409 

410 self._host_input_device = numpy_dataset.SingleDevice( 

411 self._input_workers_devices[0][0]) 

412 device_spec = tf_device.DeviceSpec.from_string( 

413 self._input_workers_devices[0][0]) 

414 # Ensures when we enter strategy.scope() we use the correct default device 

415 if device_spec.job is not None and device_spec.job != "localhost": 

416 self._default_device = "/job:%s/replica:%d/task:%d" % ( 

417 device_spec.job, device_spec.replica, device_spec.task) 

418 

419 logging.info("Using MirroredStrategy with devices %r", devices) 

420 

421 def _initialize_multi_worker(self, devices): 

422 """Initializes the object for multi-worker training.""" 

423 device_dict = _group_device_list(devices) 

424 workers = [] 

425 worker_devices = [] 

426 for job in ("chief", "worker"): 

427 for task in range(len(device_dict.get(job, []))): 

428 worker = "/job:%s/task:%d" % (job, task) 

429 workers.append(worker) 

430 worker_devices.append((worker, device_dict[job][task])) 

431 

432 # Setting `_default_device` will add a device scope in the 

433 # distribution.scope. We set the default device to the first worker. When 

434 # users specify device under distribution.scope by 

435 # with tf.device("/cpu:0"): 

436 # ... 

437 # their ops will end up on the CPU device of the first worker, e.g. 

438 # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode. 

439 self._default_device = workers[0] 

440 self._host_input_device = numpy_dataset.SingleDevice(workers[0]) 

441 

442 self._devices = tuple(devices) 

443 self._input_workers_devices = worker_devices 

444 self._is_multi_worker_training = True 

445 

446 if len(workers) > 1: 

447 # Grandfather usage in the legacy tests if they're configured properly. 

448 if (not isinstance(self._cross_device_ops, 

449 cross_device_ops_lib.ReductionToOneDevice) or 

450 self._cross_device_ops._num_between_graph_workers > 1): # pylint: disable=protected-access 

451 raise ValueError( 

452 "In-graph multi-worker training with `MirroredStrategy` is not " 

453 "supported.") 

454 self._inferred_cross_device_ops = self._cross_device_ops 

455 else: 

456 # TODO(yuefengz): make `select_cross_device_ops` work with device strings 

457 # containing job names. 

458 self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce() 

459 

460 logging.info("Using MirroredStrategy with remote devices %r", devices) 

461 

462 def _input_workers_with_options(self, options=None): 

463 if not options: 

464 return input_lib.InputWorkers(self._input_workers_devices) 

465 if (options.experimental_replication_mode == 

466 distribute_lib.InputReplicationMode.PER_REPLICA): 

467 if options.experimental_place_dataset_on_device: 

468 self._input_workers_devices = ( 

469 tuple( 

470 (device_util.canonicalize(d, d), (d,)) for d in self._devices)) 

471 else: 

472 self._input_workers_devices = ( 

473 tuple((device_util.canonicalize("/device:CPU:0", d), (d,)) 

474 for d in self._devices)) 

475 return input_lib.InputWorkers(self._input_workers_devices) 

476 else: 

477 if not options.experimental_fetch_to_device: 

478 return input_lib.InputWorkers([ 

479 (host_device, (host_device,) * len(compute_devices)) 

480 for host_device, compute_devices in self._input_workers_devices 

481 ]) 

482 else: 

483 return input_lib.InputWorkers(self._input_workers_devices) 

484 

485 @property 

486 def _input_workers(self): 

487 return self._input_workers_with_options() 

488 

489 def _get_variable_creator_initial_value(self, 

490 replica_id, 

491 device, 

492 primary_var, 

493 **kwargs): 

494 """Return the initial value for variables on a replica.""" 

495 if replica_id == 0: 

496 return kwargs["initial_value"] 

497 else: 

498 assert primary_var is not None 

499 assert device is not None 

500 assert kwargs is not None 

501 

502 def initial_value_fn(): 

503 if context.executing_eagerly() or ops.inside_function(): 

504 init_value = primary_var.value() 

505 return array_ops.identity(init_value) 

506 else: 

507 with ops.device(device): 

508 init_value = primary_var.initial_value 

509 return array_ops.identity(init_value) 

510 

511 return initial_value_fn 

512 

513 def _create_variable(self, next_creator, **kwargs): 

514 """Create a mirrored variable. See `DistributionStrategy.scope`.""" 

515 colocate_with = kwargs.pop("colocate_with", None) 

516 if colocate_with is None: 

517 devices = self._devices 

518 elif isinstance(colocate_with, numpy_dataset.SingleDevice): 

519 with ops.device(colocate_with.device): 

520 return next_creator(**kwargs) 

521 else: 

522 devices = colocate_with._devices # pylint: disable=protected-access 

523 

524 def _real_mirrored_creator(**kwargs): # pylint: disable=g-missing-docstring 

525 value_list = [] 

526 for i, d in enumerate(devices): 

527 with ops.device(d): 

528 kwargs["initial_value"] = self._get_variable_creator_initial_value( 

529 replica_id=i, 

530 device=d, 

531 primary_var=value_list[0] if value_list else None, 

532 **kwargs) 

533 if i > 0: 

534 # Give replicas meaningful distinct names: 

535 var0name = value_list[0].name.split(":")[0] 

536 # We append a / to variable names created on replicas with id > 0 to 

537 # ensure that we ignore the name scope and instead use the given 

538 # name as the absolute name of the variable. 

539 kwargs["name"] = "%s/replica_%d/" % (var0name, i) 

540 with context.device_policy(context.DEVICE_PLACEMENT_SILENT): 

541 # Don't record operations (e.g. other variable reads) during 

542 # variable creation. 

543 with record.stop_recording(): 

544 v = next_creator(**kwargs) 

545 assert not isinstance(v, values.DistributedVariable) 

546 value_list.append(v) 

547 return value_list 

548 

549 return distribute_utils.create_mirrored_variable( 

550 self._container_strategy(), _real_mirrored_creator, 

551 distribute_utils.VARIABLE_CLASS_MAPPING, 

552 distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs) 

553 
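  # Illustrative note on the naming scheme used by _real_mirrored_creator
  # above: for a variable named "v", replica 0 keeps the name "v" while
  # replica i > 0 is created as "v/replica_<i>", so a two-GPU mirrored
  # variable has components named roughly "v:0" and "v/replica_1:0".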

554 def _validate_colocate_with_variable(self, colocate_with_variable): 

555 distribute_utils.validate_colocate_distributed_variable( 

556 colocate_with_variable, self) 

557 

558 def _make_dataset_iterator(self, dataset): 

559 return input_lib_v1.DatasetIterator( 

560 dataset, 

561 self._input_workers, 

562 self._container_strategy(), 

563 num_replicas_in_sync=self._num_replicas_in_sync) 

564 

565 def _make_input_fn_iterator( 

566 self, 

567 input_fn, 

568 replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): 

569 input_contexts = [] 

570 num_workers = self._input_workers.num_workers 

571 for i in range(num_workers): 

572 input_contexts.append(distribute_lib.InputContext( 

573 num_input_pipelines=num_workers, 

574 input_pipeline_id=i, 

575 num_replicas_in_sync=self._num_replicas_in_sync)) 

576 return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers, 

577 input_contexts, 

578 self._container_strategy()) 

579 

580 def _experimental_distribute_dataset(self, dataset, options): 

581 if (options and options.experimental_replication_mode == 

582 distribute_lib.InputReplicationMode.PER_REPLICA): 

583 raise NotImplementedError( 

584 "InputReplicationMode.PER_REPLICA " 

585 "is only supported in " 

586 "`distribute_datasets_from_function`." 

587 ) 

588 return input_util.get_distributed_dataset( 

589 dataset, 

590 self._input_workers_with_options(options), 

591 self._container_strategy(), 

592 num_replicas_in_sync=self._num_replicas_in_sync, 

593 options=options) 

594 

595 def _experimental_make_numpy_dataset(self, numpy_input, session): 

596 return numpy_dataset.one_host_numpy_dataset( 

597 numpy_input, self._host_input_device, session) 

598 

599 def _distribute_datasets_from_function(self, dataset_fn, options): 

600 input_workers = self._input_workers_with_options(options) 

601 input_contexts = [] 

602 num_workers = input_workers.num_workers 

603 for i in range(num_workers): 

604 input_contexts.append(distribute_lib.InputContext( 

605 num_input_pipelines=num_workers, 

606 input_pipeline_id=i, 

607 num_replicas_in_sync=self._num_replicas_in_sync)) 

608 

609 return input_util.get_distributed_datasets_from_function( 

610 dataset_fn, input_workers, input_contexts, self._container_strategy(), 

611 options) 

612 
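  # A minimal usage sketch of the public API that reaches this method,
  # `tf.distribute.Strategy.distribute_datasets_from_function`, assuming a
  # two-GPU MirroredStrategy and a hypothetical `make_dataset` helper and
  # `GLOBAL_BATCH` constant:
  #
  #   >>> def dataset_fn(input_context):
  #   ...   batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH)
  #   ...   return make_dataset().batch(batch_size)
  #   >>> dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)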

613 def _experimental_distribute_values_from_function(self, value_fn): 

614 per_replica_values = [] 

615 for replica_id in range(self._num_replicas_in_sync): 

616 per_replica_values.append(value_fn( 

617 distribute_lib.ValueContext(replica_id, 

618 self._num_replicas_in_sync))) 

619 return distribute_utils.regroup(per_replica_values, always_wrap=True) 

620 
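  # A short usage sketch of the public wrapper for the method above,
  # `tf.distribute.Strategy.experimental_distribute_values_from_function`,
  # assuming a two-replica MirroredStrategy:
  #
  #   >>> def value_fn(value_context):
  #   ...   return tf.constant(value_context.replica_id_in_sync_group)
  #   >>> per_replica = strategy.experimental_distribute_values_from_function(
  #   ...     value_fn)
  #   >>> strategy.experimental_local_results(per_replica)
  #   (<tf.Tensor ... numpy=0>, <tf.Tensor ... numpy=1>)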

621 # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 

622 def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, 

623 initial_loop_values=None): 

624 if initial_loop_values is None: 

625 initial_loop_values = {} 

626 initial_loop_values = nest.flatten(initial_loop_values) 

627 

628 ctx = input_lib.MultiStepContext() 

629 def body(i, *args): 

630 """A wrapper around `fn` to create the while loop body.""" 

631 del args 

632 fn_result = fn(ctx, iterator.get_next()) 

633 for (name, output) in ctx.last_step_outputs.items(): 

634 # Convert all outputs to tensors, potentially from `DistributedValues`. 

635 ctx.last_step_outputs[name] = self._local_results(output) 

636 flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) 

637 with ops.control_dependencies([fn_result]): 

638 return [i + 1] + flat_last_step_outputs 

639 

640 # We capture the control_flow_context at this point, before we run `fn` 

641 # inside a while_loop. This is useful in cases where we might need to exit 

642 # these contexts and get back to the outer context to do some things, 

643 # e.g. to create an op which should be evaluated only once at the end of the 

644 # loop on the host. One such usage is in creating metrics' value op. 

645 self._outer_control_flow_context = ( 

646 ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access 

647 

648 cond = lambda i, *args: i < iterations 

649 i = constant_op.constant(0) 

650 loop_result = while_loop.while_loop( 

651 cond, 

652 body, [i] + initial_loop_values, 

653 name="", 

654 parallel_iterations=1, 

655 back_prop=False, 

656 swap_memory=False, 

657 return_same_structure=True) 

658 del self._outer_control_flow_context 

659 

660 ctx.run_op = control_flow_ops.group(loop_result) 

661 

662 # Convert the last_step_outputs from a list to the original dict structure 

663 # of last_step_outputs. 

664 last_step_tensor_outputs = loop_result[1:] 

665 last_step_tensor_outputs_dict = nest.pack_sequence_as( 

666 ctx.last_step_outputs, last_step_tensor_outputs) 

667 

668 for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access 

669 output = last_step_tensor_outputs_dict[name] 

670 # For outputs that have already been reduced, wrap them in a Mirrored 

671 # container, else in a PerReplica container. 

672 if reduce_op is None: 

673 last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output) 

674 else: 

675 assert len(output) == 1 

676 last_step_tensor_outputs_dict[name] = output[0] 

677 

678 ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access 

679 return ctx 

680 

681 def _broadcast_to(self, tensor, destinations): 

682 # This is both a fast path for Python constants, and a way to delay 

683 # converting Python values to a tensor until we know what type it 

684 # should be converted to. Otherwise we have trouble with: 

685 # global_step.assign_add(1) 

686 # since the `1` gets broadcast as an int32 but global_step is int64. 

687 if isinstance(tensor, (float, int)): 

688 return tensor 

689 # TODO(josh11b): In eager mode, use one thread per device, or async mode. 

690 if not destinations: 

691 # TODO(josh11b): Use current logical device instead of 0 here. 

692 destinations = self._devices 

693 return self._get_cross_device_ops(tensor).broadcast(tensor, destinations) 

694 

695 def _call_for_each_replica(self, fn, args, kwargs): 

696 return mirrored_run.call_for_each_replica( 

697 self._container_strategy(), fn, args, kwargs) 

698 

699 def _configure(self, 

700 session_config=None, 

701 cluster_spec=None, 

702 task_type=None, 

703 task_id=None): 

704 del task_type, task_id 

705 

706 if session_config: 

707 session_config.CopyFrom(self._update_config_proto(session_config)) 

708 

709 if cluster_spec: 

710 # TODO(yuefengz): remove the following code once cluster_resolver is 

711 # added. 

712 num_gpus_per_worker = _infer_num_gpus_per_worker(self._devices) 

713 multi_worker_devices = _cluster_spec_to_device_list( 

714 cluster_spec, num_gpus_per_worker) 

715 self._initialize_multi_worker(multi_worker_devices) 

716 

717 def _update_config_proto(self, config_proto): 

718 updated_config = copy.deepcopy(config_proto) 

719 updated_config.isolate_session_state = True 

720 return updated_config 

721 

722 def _get_cross_device_ops(self, value): 

723 # Always use CollectiveAllReduce when XLA is enabled, since other cross 

724 # device ops don't have as good support on XLA. 

725 if not self._use_merge_call(): 

726 if not isinstance(self._cross_device_ops, 

727 cross_device_ops_lib.CollectiveAllReduce): 

728 logging.warning( 

729 "Under XLA context, MirroredStrategy uses CollectiveAllReduce op. " 

730 "Although %r is provided to initialize MirroredStrategy, it is " 

731 "ignored in XLA. Please use CollectiveAllReduce(or default option) " 

732 "in the future, since other cross device ops are not well " 

733 "supported on XLA.", self._cross_device_ops 

734 ) 

735 return self._collective_ops 

736 

737 if isinstance(value, values.DistributedValues): 

738 value_int32 = True in { 

739 dtypes.as_dtype(v.dtype) == dtypes.int32 for v in value.values 

740 } 

741 else: 

742 value_int32 = dtypes.as_dtype(value.dtype) == dtypes.int32 

743 

744 if value_int32: 

745 return cross_device_ops_lib.ReductionToOneDevice() 

746 else: 

747 return self._cross_device_ops 

748 

749 def _gather_to_implementation(self, value, destinations, axis, options): 

750 if not isinstance(value, values.DistributedValues): 

751 # ReductionToOneDevice._gather accepts DistributedValues only. 

752 return value 

753 return self._get_cross_device_ops(value)._gather( # pylint: disable=protected-access 

754 value, 

755 destinations=destinations, 

756 axis=axis, 

757 options=self._communication_options.merge(options)) 

758 

759 def _reduce_to(self, reduce_op, value, destinations, options): 

760 if (distribute_utils.is_mirrored(value) and 

761 reduce_op == reduce_util.ReduceOp.MEAN): 

762 return value 

763 assert not distribute_utils.is_mirrored(value) 

764 def get_values(value): 

765 if not isinstance(value, values.DistributedValues): 

766 # This function handles reducing values that are not PerReplica or 

767 # Mirrored values. For example, the same value could be present on all 

768 # replicas, in which case `value` would be a single value, or the value 

769 # could be 0. 

770 return cross_device_ops_lib.reduce_non_distributed_value( 

771 reduce_op, value, destinations, self._num_replicas_in_sync) 

772 

773 if self._use_merge_call() and ( 

774 not cross_device_ops_lib._devices_match(value, destinations) or # pylint: disable=protected-access 

775 any("cpu" in d.lower() 

776 for d in cross_device_ops_lib.get_devices_from(destinations))): 

777 return cross_device_ops_lib.ReductionToOneDevice().reduce( 

778 reduce_op, value, destinations) 

779 return self._get_cross_device_ops(value).reduce( 

780 reduce_op, 

781 value, 

782 destinations=destinations, 

783 options=self._communication_options.merge(options)) 

784 

785 return nest.map_structure(get_values, value) 

786 

787 def _batch_reduce_to(self, reduce_op, value_destination_pairs, options): 

788 cross_device_ops = None 

789 for value, _ in value_destination_pairs: 

790 if cross_device_ops is None: 

791 cross_device_ops = self._get_cross_device_ops(value) 

792 elif cross_device_ops is not self._get_cross_device_ops(value): 

793 raise ValueError("Inputs to batch_reduce_to must be either all on " 

794 "the host or all on the compute devices.") 

795 return cross_device_ops.batch_reduce( 

796 reduce_op, 

797 value_destination_pairs, 

798 options=self._communication_options.merge(options)) 

799 

800 def _update(self, var, fn, args, kwargs, group): 

801 # TODO(josh11b): In eager mode, use one thread per device. 

802 assert isinstance(var, values.DistributedVariable) 

803 updates = [] 

804 for i, v in enumerate(var.values): 

805 name = "update_%d" % i 

806 with ops.device(v.device), \ 

807 distribute_lib.UpdateContext(i), \ 

808 ops.name_scope(name): 

809 # If args and kwargs are not mirrored, the value is returned as is. 

810 updates.append( 

811 fn(v, *distribute_utils.select_replica(i, args), 

812 **distribute_utils.select_replica(i, kwargs))) 

813 return distribute_utils.update_regroup(self, updates, group) 

814 

815 def _replica_ctx_all_reduce(self, reduce_op, value, options=None): 

816 """Implements `StrategyExtendedV2._replica_ctx_all_reduce`.""" 

817 # This implementation avoids using `merge_call` and just launches collective 

818 # ops in one replica. 

819 if options is None: 

820 options = collective_util.Options() 

821 

822 if context.executing_eagerly() or ( 

823 not tf2.enabled()) or self._use_merge_call(): 

824 # In eager mode, falls back to the default implementation that uses 

825 # `merge_call`. Replica functions are running sequentially in eager mode, 

826 # and due to the blocking nature of collective ops, execution will hang if 

827 # collective ops are to be launched sequentially. 

828 return super()._replica_ctx_all_reduce(reduce_op, value, options) 

829 

830 replica_context = distribute_lib.get_replica_context() 

831 assert replica_context, ( 

832 "`StrategyExtended._replica_ctx_all_reduce` must be called in a " 

833 "replica context") 

834 return self._get_cross_device_ops(value)._all_reduce( # pylint: disable=protected-access 

835 reduce_op, 

836 value, 

837 replica_context._replica_id, # pylint: disable=protected-access 

838 options) 

839 
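  # A brief usage sketch of the replica-context all-reduce that this method
  # backs, assuming a two-GPU MirroredStrategy named `strategy`:
  #
  #   >>> @tf.function
  #   ... def step():
  #   ...   def replica_fn():
  #   ...     ctx = tf.distribute.get_replica_context()
  #   ...     return ctx.all_reduce(tf.distribute.ReduceOp.SUM, tf.constant(1.0))
  #   ...   return strategy.run(replica_fn)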

840 def _replica_ctx_update(self, var, fn, args, kwargs, group): 

841 if self._use_merge_call(): 

842 return super()._replica_ctx_update(var, fn, args, kwargs, group) 

843 

844 replica_context = distribute_lib.get_replica_context() 

845 assert replica_context 

846 replica_id = values_util.get_current_replica_id_as_int() 

847 name = "update_%d" % replica_id 

848 

849 if isinstance(var, values.DistributedVariable): 

850 var = var._get_replica(replica_id) # pylint: disable=protected-access 

851 

852 with ops.device(var.device), ops.name_scope(name): 

853 result = fn(var, *args, **kwargs) 

854 return result 

855 

856 def _update_non_slot(self, colocate_with, fn, args, kwargs, group): 

857 assert isinstance(colocate_with, tuple) 

858 # TODO(josh11b): In eager mode, use one thread per device. 

859 updates = [] 

860 for i, d in enumerate(colocate_with): 

861 name = "update_%d" % i 

862 with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name): 

863 updates.append( 

864 fn(*distribute_utils.select_replica(i, args), 

865 **distribute_utils.select_replica(i, kwargs))) 

866 return distribute_utils.update_regroup(self, updates, group) 

867 

868 def read_var(self, replica_local_var): 

869 """Read the aggregate value of a replica-local variable.""" 

870 # pylint: disable=protected-access 

871 if distribute_utils.is_sync_on_read(replica_local_var): 

872 return replica_local_var._get_cross_replica() 

873 assert distribute_utils.is_mirrored(replica_local_var) 

874 return array_ops.identity(replica_local_var._get()) 

875 # pylint: enable=protected-access 

876 

877 def value_container(self, val): 

878 return distribute_utils.value_container(val) 

879 

880 @property 

881 def _num_replicas_in_sync(self): 

882 return len(self._devices) 

883 

884 @property 

885 def worker_devices(self): 

886 return self._devices 

887 

888 @property 

889 def worker_devices_by_replica(self): 

890 return [[d] for d in self._devices] 

891 

892 @property 

893 def parameter_devices(self): 

894 return self.worker_devices 

895 

896 @property 

897 def experimental_between_graph(self): 

898 return False 

899 

900 @property 

901 def experimental_should_init(self): 

902 return True 

903 

904 @property 

905 def should_checkpoint(self): 

906 return True 

907 

908 @property 

909 def should_save_summary(self): 

910 return True 

911 

912 def non_slot_devices(self, var_list): 

913 del var_list 

914 # TODO(josh11b): Should this be the last logical device instead? 

915 return self._devices 

916 

917 # TODO(priyag): Delete this once all strategies use global batch size. 

918 @property 

919 def _global_batch_size(self): 

920 """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. 

921 

922 `make_input_fn_iterator` assumes per-replica batching. 

923 

924 Returns: 

925 Boolean. 

926 """ 

927 return True 

928 

929 def _in_multi_worker_mode(self): 

930 """Whether this strategy indicates working in multi-worker settings.""" 

931 return False 

932 

933 def _get_local_replica_id(self, replica_id_in_sync_group): 

934 return replica_id_in_sync_group 

935 

936 def _get_replica_id_in_sync_group(self, replica_id): 

937 return replica_id