Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/collective_all_reduce_strategy.py: 27%

406 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Class CollectiveAllReduceStrategy implementing DistributionStrategy.""" 

16 

17import copy 

18import threading 

19import time 

20import weakref 

21 

22from tensorflow.core.protobuf import rewriter_config_pb2 

23from tensorflow.core.protobuf import tensorflow_server_pb2 

24from tensorflow.python.distribute import collective_util 

25from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib 

26from tensorflow.python.distribute import cross_device_utils 

27from tensorflow.python.distribute import device_util 

28from tensorflow.python.distribute import distribute_lib 

29from tensorflow.python.distribute import distribute_utils 

30from tensorflow.python.distribute import input_lib 

31from tensorflow.python.distribute import input_util 

32from tensorflow.python.distribute import mirrored_strategy 

33from tensorflow.python.distribute import multi_worker_util 

34from tensorflow.python.distribute import numpy_dataset 

35from tensorflow.python.distribute import reduce_util 

36from tensorflow.python.distribute import values 

37from tensorflow.python.distribute.cluster_resolver import ClusterResolver 

38from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver 

39from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver 

40from tensorflow.python.distribute.v1 import input_lib as input_lib_v1 

41from tensorflow.python.eager import context 

42from tensorflow.python.framework import device as tf_device 

43from tensorflow.python.framework import errors 

44from tensorflow.python.framework import ops 

45from tensorflow.python.ops import array_ops 

46from tensorflow.python.ops import collective_ops 

47from tensorflow.python.ops import control_flow_util 

48from tensorflow.python.platform import tf_logging as logging 

49from tensorflow.python.tpu import tpu_strategy_util 

50from tensorflow.python.trackable import base 

51from tensorflow.python.util import deprecation 

52from tensorflow.python.util.tf_export import tf_export 

53from tensorflow.tsl.protobuf import coordination_config_pb2 

54 

55 

56# pylint: disable=line-too-long 

57@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[]) 

58class CollectiveAllReduceStrategy(distribute_lib.Strategy): 

59 """A distribution strategy for synchronous training on multiple workers. 

60 

61 This strategy implements synchronous distributed training across multiple 

62 workers, each with potentially multiple GPUs. Similar to 

63 `tf.distribute.MirroredStrategy`, it replicates all variables and computations 

64 to each local device. The difference is that it uses a distributed collective 

65 implementation (e.g. all-reduce), so that multiple workers can work together. 

66 

67 You need to launch your program on each worker and configure 

68 `cluster_resolver` correctly. For example, if you are using 

69 `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to 

70 have its corresponding `task_type` and `task_id` set in the `TF_CONFIG` 

71 environment variable. An example TF_CONFIG on worker-0 of a two-worker cluster 

72 is: 

73 

74 ``` 

75 TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }' 

76 ``` 
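
 The same configuration can also be set from Python before the strategy is
 created, since `TFConfigClusterResolver` reads the `TF_CONFIG` environment
 variable. A minimal sketch (the addresses and the task index are placeholders):

 ```
 import json, os

 os.environ["TF_CONFIG"] = json.dumps({
     "cluster": {"worker": ["localhost:12345", "localhost:23456"]},
     "task": {"type": "worker", "index": 0},  # use index 1 on the second worker
 })
 strategy = tf.distribute.MultiWorkerMirroredStrategy()
 ```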

77 

78 Your program runs on each worker as-is. Note that collectives require each 

79 worker to participate. All `tf.distribute` and non-`tf.distribute` APIs may use 

80 collectives internally, e.g. checkpointing and saving, since reading a 

81 `tf.Variable` with `tf.VariableSynchronization.ON_READ` all-reduces the value. 

82 Therefore it's recommended to run exactly the same program on each worker. 

83 Dispatching based on `task_type` or `task_id` of the worker is error-prone. 

84 

85 `cluster_resolver.num_accelerators()` determines the number of GPUs the 

86 strategy uses. If it's zero, the strategy uses the CPU. All workers need to 

87 use the same number of devices, otherwise the behavior is undefined. 
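
 If you need explicit control over the device count, you can pass your own
 resolver. A sketch using `tf.distribute.cluster_resolver.SimpleClusterResolver`
 (the cluster spec and GPU count below are illustrative):

 ```
 cluster_spec = tf.train.ClusterSpec(
     {"worker": ["localhost:12345", "localhost:23456"]})
 resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
     cluster_spec, task_type="worker", task_id=0,
     num_accelerators={"GPU": 2})
 strategy = tf.distribute.MultiWorkerMirroredStrategy(cluster_resolver=resolver)
 ```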

88 

89 This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy` 

90 instead. 

91 

92 After setting up TF_CONFIG, using this strategy is similar to using 

93 `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`. 

94 

95 ``` 

96 strategy = tf.distribute.MultiWorkerMirroredStrategy() 

97 

98 with strategy.scope(): 

99 model = tf.keras.Sequential([ 

100 tf.keras.layers.Dense(2, input_shape=(5,)), 

101 ]) 

102 optimizer = tf.keras.optimizers.SGD(learning_rate=0.1) 

103 

104 def dataset_fn(ctx): 

105 x = np.random.random((2, 5)).astype(np.float32) 

106 y = np.random.randint(2, size=(2, 1)) 

107 dataset = tf.data.Dataset.from_tensor_slices((x, y)) 

108 return dataset.repeat().batch(1, drop_remainder=True) 

109 dist_dataset = strategy.distribute_datasets_from_function(dataset_fn) 

110 

111 model.compile() 

112 model.fit(dist_dataset) 

113 ``` 

114 

115 You can also write your own training loop: 

116 

117 ``` 

118 @tf.function 

119 def train_step(iterator): 

120 

121 def step_fn(inputs): 

122 features, labels = inputs 

123 with tf.GradientTape() as tape: 

124 logits = model(features, training=True) 

125 loss = tf.keras.losses.sparse_categorical_crossentropy( 

126 labels, logits) 

127 

128 grads = tape.gradient(loss, model.trainable_variables) 

129 optimizer.apply_gradients(zip(grads, model.trainable_variables)) 

130 

131 strategy.run(step_fn, args=(next(iterator),)) 

132 

133 for _ in range(NUM_STEP): 

134 train_step(iterator) 

135 ``` 

136 

137 See 

138 [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) 

139 for a detailed tutorial. 

140 

141 __Saving__ 

142 

143 You need to save and checkpoint on all workers instead of just one. This is 

144 because variables with synchronization=ON_READ trigger aggregation during 

145 saving. It's recommended to save to a different path on each worker to avoid 

146 race conditions. Each worker saves the same thing. See 

147 [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading) 

148 tutorial for examples. 
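
 A sketch of per-worker save paths consistent with that tutorial (the directory
 layout and the chief check below are illustrative, not the only valid scheme):

 ```
 resolver = strategy.cluster_resolver
 task_type, task_id = resolver.task_type, resolver.task_id
 is_chief = task_type is None or task_type == "chief" or (
     task_type == "worker" and task_id == 0)
 save_path = "/tmp/model" if is_chief else f"/tmp/workertemp_{task_id}/model"
 model.save(save_path)
 ```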

149 

150 __Known Issues__ 

151 

152 * `tf.distribute.cluster_resolver.TFConfigClusterResolver` does not return the 

153 correct number of accelerators. The strategy uses all available GPUs if 

154 `cluster_resolver` is `tf.distribute.cluster_resolver.TFConfigClusterResolver` 

155 or `None`. 

156 * In eager mode, the strategy needs to be created before calling any other 

157 TensorFlow API. 

158 

159 """ 

160 # pylint: enable=line-too-long 

161 

162 # TODO(anjalisridhar): Update our guides with examples showing how we can use 

163 # the cluster_resolver argument. 

164 

165 # The starting number for collective keys. This should only be set in tests. 

166 _collective_key_base = 0 

167 

168 def __init__(self, 

169 cluster_resolver=None, 

170 communication_options=None): 

171 """Creates the strategy. 

172 

173 Args: 

174 cluster_resolver: optional 

175 `tf.distribute.cluster_resolver.ClusterResolver`. If `None`, 

176 `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used. 

177 communication_options: optional 

178 `tf.distribute.experimental.CommunicationOptions`. This configures the 

179 default options for cross-device communication. It can be overridden by 

180 options provided to the communication APIs like 

181 `tf.distribute.ReplicaContext.all_reduce`. See 

182 `tf.distribute.experimental.CommunicationOptions` for details. 
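
 For example, a sketch that prefers NCCL for cross-device all-reduces (NCCL
 only takes effect when GPUs are in use):

 ```
 options = tf.distribute.experimental.CommunicationOptions(
     implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
 strategy = tf.distribute.MultiWorkerMirroredStrategy(
     communication_options=options)
 ```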

183 """ 

184 if communication_options is None: 

185 communication_options = collective_util.Options() 

186 super(CollectiveAllReduceStrategy, self).__init__( 

187 CollectiveAllReduceExtended( 

188 self, 

189 cluster_resolver=cluster_resolver, 

190 communication_options=communication_options)) 

191 

192 distribute_lib.distribution_strategy_gauge.get_cell("V2").set( 

193 "MultiWorkerMirroredStrategy") 

194 # pylint: disable=protected-access 

195 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

196 "num_workers").set(self.extended._num_workers) 

197 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

198 "num_replicas_per_worker").set(self.extended._num_devices_per_worker) 

199 

200 @classmethod 

201 def _from_local_devices(cls, devices, communication_options=None): 

202 """A convenience method to create an object with a list of devices.""" 

203 obj = cls(communication_options=communication_options) 

204 obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access 

205 return obj 

206 

207 @property 

208 def cluster_resolver(self): 

209 """Returns the cluster resolver associated with this strategy. 

210 

211 As a multi-worker strategy, `tf.distribute.MultiWorkerMirroredStrategy` 

212 provides the associated `tf.distribute.cluster_resolver.ClusterResolver`. If 

213 the user provides one in `__init__`, that instance is returned; if the user 

214 does not, a default `TFConfigClusterResolver` is provided. 
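
 For instance, the task identity can be read back from it (a small sketch; the
 values depend on the cluster configuration):

 ```
 strategy = tf.distribute.MultiWorkerMirroredStrategy()
 resolver = strategy.cluster_resolver
 print(resolver.task_type, resolver.task_id)
 ```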

215 """ 

216 return self.extended._cluster_resolver # pylint: disable=protected-access 

217 

218 

219class _CollectiveAllReduceStrategyExperimentalMeta(type): 

220 

221 @classmethod 

222 def __instancecheck__(cls, instance): 

223 # This is to make isinstance(tf.distribute.MultiWorkerMirroredStrategy(), 

224 # tf.distribute.experimental.MultiWorkerMirroredStrategy) return True. Some 

225 # libraries perform such a check. 

226 return isinstance(instance, CollectiveAllReduceStrategy) 

227 

228 

229@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[]) 

230class _CollectiveAllReduceStrategyExperimental( 

231 CollectiveAllReduceStrategy, 

232 metaclass=_CollectiveAllReduceStrategyExperimentalMeta): 

233 

234 __doc__ = CollectiveAllReduceStrategy.__doc__ 

235 

236 @deprecation.deprecated( 

237 None, "use distribute.MultiWorkerMirroredStrategy instead") 

238 def __init__(self, 

239 communication=collective_util.CommunicationImplementation.AUTO, 

240 cluster_resolver=None): 

241 """Creates the strategy. 

242 

243 Args: 

244 communication: optional 

245 `tf.distribute.experimental.CommunicationImplementation`. This is a hint 

246 on the preferred collective communication implementation. Possible 

247 values include `AUTO`, `RING`, and `NCCL`. 

248 cluster_resolver: optional 

249 `tf.distribute.cluster_resolver.ClusterResolver`. If `None`, 

250 `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used. 
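
 A sketch of this deprecated constructor; prefer
 `tf.distribute.MultiWorkerMirroredStrategy` with `CommunicationOptions` for
 new code:

 ```
 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
     communication=tf.distribute.experimental.CommunicationImplementation.RING)
 ```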

251 """ 

252 communication_options = collective_util.Options( 

253 implementation=communication) 

254 super(_CollectiveAllReduceStrategyExperimental, 

255 self).__init__(cluster_resolver, communication_options) 

256 

257 @classmethod 

258 def _from_local_devices( 

259 cls, 

260 devices, 

261 communication=collective_util.CommunicationImplementation.AUTO): 

262 """A convenience method to create an object with a list of devices.""" 

263 obj = cls(communication) 

264 obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access 

265 return obj 

266 

267 

268_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__ 

269 

270 

271@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"]) # pylint: disable=missing-docstring 

272class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1): 

273 

274 __doc__ = CollectiveAllReduceStrategy.__doc__ 

275 

276 # The starting number for collective keys. This should only be set in tests. 

277 _collective_key_base = 0 

278 

279 def __init__(self, 

280 communication=collective_util.CommunicationImplementation.AUTO, 

281 cluster_resolver=None): 

282 """Initializes the object.""" 

283 communication_options = collective_util.Options( 

284 implementation=communication) 

285 super(CollectiveAllReduceStrategyV1, self).__init__( 

286 CollectiveAllReduceExtended( 

287 self, 

288 cluster_resolver=cluster_resolver, 

289 communication_options=communication_options)) 

290 distribute_lib.distribution_strategy_gauge.get_cell("V1").set( 

291 "MultiWorkerMirroredStrategy") 

292 # pylint: disable=protected-access 

293 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

294 "num_workers").set(self.extended._num_workers) 

295 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

296 "num_gpu_per_worker").set( 

297 self.extended._num_devices_per_worker 

298 if self.extended._local_device_type == "GPU" 

299 else 0) 

300 

301 

302def _is_gpu_device(device): 

303 return tf_device.DeviceSpec.from_string(device).device_type == "GPU" 

304 

305 

306class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): 

307 """Implementation of CollectiveAllReduceStrategy.""" 

308 

309 # Whether to periodically check the health of the cluster. If any worker is not 

310 # reachable, collectives are aborted and the user program should get a 

311 # tf.errors.UnavailableError. Restarting the program is required to recover. 

312 _enable_check_health = True 

313 # Check health interval in seconds. 

314 _check_health_interval = 30 

315 # Timeout in seconds for the first health check. The first health check needs 

316 # to wait for the cluster, which may take longer. 

317 _check_health_initial_timeout = 0 

318 # Number of times to retry before considering a peer down. 

319 _check_health_retry_limit = 3 

320 # Timeout in seconds for each health check. 

321 _check_health_timeout = 10 

322 

323 def __init__(self, container_strategy, cluster_resolver, 

324 communication_options, devices=None): 

325 if not isinstance(communication_options, collective_util.Options): 

326 raise ValueError("communication_options must be an instance of " 

327 "tf.distribute.experimental.CommunicationOptions") 

328 if cluster_resolver and devices: 

329 raise ValueError( 

330 "cluster_resolver and devices cannot be set at the same time") 

331 

332 self._cluster_resolver = cluster_resolver or TFConfigClusterResolver() 

333 if not isinstance(self._cluster_resolver, ClusterResolver): 

334 raise ValueError("cluster_resolver must be an instance of " 

335 "tf.distribute.cluster_resolver.ClusterResolver") 

336 distribute_lib.StrategyExtendedV1.__init__(self, container_strategy) 

337 self._communication_options = communication_options 

338 self._collective_key_base = container_strategy._collective_key_base # pylint: disable=protected-access 

339 self._initialize_strategy(self._cluster_resolver, devices=devices) 

340 self._cfer_fn_cache = weakref.WeakKeyDictionary() 

341 self.experimental_enable_get_next_as_optional = True 

342 assert isinstance(self._cross_device_ops, 

343 cross_device_ops_lib.CollectiveAllReduce) 

344 

345 def _use_merge_call(self): 

346 # We currently only disable merge_call when XLA is used to compile the `fn` 

347 # passed to `strategy.run` and all devices are GPU. 

348 return not control_flow_util.GraphOrParentsInXlaContext( 

349 ops.get_default_graph()) or not all( 

350 [_is_gpu_device(d) for d in self._devices]) 

351 

352 def _initialize_strategy(self, cluster_resolver, devices): 

353 # If devices are provided or cluster_spec is not specified, initialize a 

354 # single worker. Otherwise initialize multiple workers. 

355 if devices or not cluster_resolver.cluster_spec().as_dict(): 

356 self._initialize_local(cluster_resolver, devices=devices) 

357 else: 

358 self._initialize_multi_worker(cluster_resolver) 

359 

360 def _initialize_local_devices(self, cluster_resolver, worker_device): 

361 # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in 

362 # some cases. 

363 if isinstance(cluster_resolver, TFConfigClusterResolver): 

364 num_gpus = context.num_gpus() 

365 num_tpus = 0 

366 else: 

367 num_gpus = cluster_resolver.num_accelerators().get("GPU", 0) 

368 num_tpus = cluster_resolver.num_accelerators().get("TPU", 0) 

369 

370 if num_gpus: 

371 local_device_type = "GPU" 

372 num_local_devices = num_gpus 

373 elif num_tpus: 

374 local_device_type = "TPU" 

375 num_local_devices = num_tpus 

376 else: 

377 local_device_type = "CPU" 

378 num_local_devices = 1 

379 local_devices = tuple( 

380 f"{worker_device}/device:{local_device_type}:{i}" 

381 for i in range(num_local_devices)) 

382 return local_devices, local_device_type 

383 

384 def _initialize_local(self, cluster_resolver, devices=None): 

385 """Initializes the object for local training.""" 

386 self._is_chief = True 

387 self._num_workers = 1 

388 

389 if ops.executing_eagerly_outside_functions(): 

390 try: 

391 context.context().configure_collective_ops( 

392 scoped_allocator_enabled_ops=("CollectiveReduce",)) 

393 except RuntimeError: 

394 logging.warning("Collective ops is not configured at program startup. " 

395 "Some performance features may not be enabled.") 

396 self._collective_ops_configured = True 

397 

398 if devices: 

399 local_devices = devices 

400 if "GPU" in devices[0]: 

401 local_device_type = "GPU" 

402 elif "TPU" in devices[0]: 

403 local_device_type = "TPU" 

404 else: 

405 local_device_type = "CPU" 

406 else: 

407 local_devices, local_device_type = self._initialize_local_devices( 

408 cluster_resolver, worker_device="") 

409 

410 self._worker_device = device_util.canonicalize("/device:CPU:0") 

411 self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) 

412 

413 self._collective_keys = cross_device_utils.CollectiveKeys( 

414 group_key_start=1 + self._collective_key_base) 

415 self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( 

416 devices=local_devices, 

417 group_size=len(local_devices), 

418 options=self._communication_options, 

419 collective_keys=self._collective_keys) 

420 # CrossDeviceOps for per host tensors. 

421 self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( 

422 devices=[self._worker_device], 

423 group_size=self._num_workers, 

424 options=self._communication_options, 

425 collective_keys=self._collective_keys) 

426 super(CollectiveAllReduceExtended, self)._initialize_single_worker( 

427 local_devices) 

428 

429 self._cluster_spec = None 

430 self._task_type = None 

431 self._task_id = None 

432 self._id_in_cluster = 0 

433 

434 # This is a marker to tell whether we are running with a standalone client or 

435 # an independent worker. Right now, with a standalone client, the strategy 

436 # object is created as a local strategy and then turned into a multi-worker 

437 # strategy via a configure call. 

438 self._local_or_standalone_client_mode = True 

439 

440 # Save the num_devices_per_worker and rpc_layer for configure method. 

441 self._num_devices_per_worker = len(local_devices) 

442 self._local_device_type = local_device_type 

443 self._rpc_layer = cluster_resolver.rpc_layer 

444 self._warn_nccl_no_gpu() 

445 

446 logging.info( 

447 "Single-worker MultiWorkerMirroredStrategy with local_devices " 

448 "= %r, communication = %s", local_devices, 

449 self._communication_options.implementation) 

450 

451 def _initialize_multi_worker(self, cluster_resolver): 

452 """Initializes the object for multi-worker training.""" 

453 cluster_spec = multi_worker_util.normalize_cluster_spec( 

454 cluster_resolver.cluster_spec()) 

455 task_type = cluster_resolver.task_type 

456 task_id = cluster_resolver.task_id 

457 if task_type is None or task_id is None: 

458 raise ValueError("When `cluster_spec` is given, you must also specify " 

459 "`task_type` and `task_id`.") 

460 self._cluster_spec = cluster_spec 

461 self._task_type = task_type 

462 self._task_id = task_id 

463 self._id_in_cluster = multi_worker_util.id_in_cluster( 

464 self._cluster_spec, self._task_type, self._task_id) 

465 

466 self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type) 

467 if not self._num_workers: 

468 raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found " 

469 "in `cluster_spec`.") 

470 

471 self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, 

472 task_id) 

473 

474 self._worker_device = "/job:%s/task:%d" % (task_type, task_id) 

475 self._host_input_device = numpy_dataset.SingleDevice(self._worker_device) 

476 

477 if (ops.executing_eagerly_outside_functions() and 

478 not getattr(self, "_local_or_standalone_client_mode", False)): 

479 context.context().configure_collective_ops( 

480 collective_leader=multi_worker_util.collective_leader( 

481 cluster_spec, task_type, task_id), 

482 scoped_allocator_enabled_ops=("CollectiveReduce",), 

483 device_filters=("/job:%s/task:%d" % (task_type, task_id),)) 

484 self._collective_ops_configured = True 

485 if context.context().coordination_service is None: 

486 coordinated_jobs = ["chief", "worker"] 

487 if task_type in coordinated_jobs: 

488 coordinated_job_config = [] 

489 for job in coordinated_jobs: 

490 if job in cluster_spec.jobs: 

491 coordinated_job_config.append( 

492 coordination_config_pb2.CoordinatedJob( 

493 name=job, 

494 num_tasks=cluster_spec.num_tasks(job))) 

495 context.context().configure_coordination_service( 

496 service_type="standalone", 

497 service_leader=multi_worker_util.coordination_leader( 

498 cluster_spec), 

499 coordinated_jobs=coordinated_job_config) 

500 

501 # Starting a std server in eager mode and in independent worker mode. 

502 if (context.executing_eagerly() and 

503 not getattr(self, "_std_server_started", False) and 

504 not getattr(self, "_local_or_standalone_client_mode", False)): 

505 # Checking _local_or_standalone_client_mode as well because we should not 

506 # create the std server in standalone client mode. 

507 config_proto = copy.deepcopy(context.context().config) 

508 config_proto = self._update_config_proto(config_proto) 

509 

510 # If coordination service is enabled, use its internal heartbeat to detect 

511 # peer failures instead of the Python-level health check. 

512 if config_proto.experimental.coordination_config.service_type: 

513 self._enable_check_health = False 

514 

515 if hasattr(cluster_resolver, "port"): 

516 port = cluster_resolver.port 

517 else: 

518 port = 0 

519 server_def = tensorflow_server_pb2.ServerDef( 

520 cluster=cluster_spec.as_cluster_def(), 

521 default_session_config=config_proto, 

522 job_name=task_type, 

523 task_index=task_id, 

524 protocol=cluster_resolver.rpc_layer or "grpc", 

525 port=port) 

526 context.context().enable_collective_ops(server_def) 

527 self._std_server_started = True 

528 # The `ensure_initialized` is needed before calling 

529 # `context.context().devices()`. 

530 context.context().ensure_initialized() 

531 logging.info( 

532 "Enabled multi-worker collective ops with available devices: %r", 

533 context.context().devices()) 

534 

535 # TODO(yuefengz): The `num_gpus` is only for this particular task. It 

536 # assumes all workers have the same number of GPUs. We should remove this 

537 # assumption by querying all tasks for their numbers of GPUs. 

538 # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in 

539 # some cases. 

540 local_devices, local_device_type = self._initialize_local_devices( 

541 cluster_resolver, self._worker_device) 

542 if local_device_type == "TPU": 

543 tpu_strategy_util.initialize_tpu_system() 

544 

545 self._collective_keys = cross_device_utils.CollectiveKeys( 

546 group_key_start=1 + self._collective_key_base) 

547 self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( 

548 devices=local_devices, 

549 group_size=len(local_devices) * self._num_workers, 

550 options=self._communication_options, 

551 collective_keys=self._collective_keys) 

552 # CrossDeviceOps for per host tensors. 

553 self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( 

554 devices=[self._worker_device], 

555 group_size=self._num_workers, 

556 options=self._communication_options, 

557 collective_keys=self._collective_keys) 

558 super(CollectiveAllReduceExtended, self)._initialize_single_worker( 

559 local_devices) 

560 

561 # Add a default device so that ops without specified devices will not end up 

562 # on other workers. 

563 self._default_device = "/job:%s/task:%d" % (task_type, task_id) 

564 

565 # Save the num_devices_per_worker and rpc_layer for configure method. 

566 self._num_devices_per_worker = len(local_devices) 

567 self._local_device_type = local_device_type 

568 self._rpc_layer = cluster_resolver.rpc_layer 

569 self._warn_nccl_no_gpu() 

570 

571 if self._enable_check_health and context.executing_eagerly(): 

572 self._start_check_health_thread() 

573 else: 

574 logging.info("Check health not enabled.") 

575 

576 logging.info( 

577 "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, " 

578 "task_id = %r, num_workers = %r, local_devices = %r, " 

579 "communication = %s", cluster_spec.as_dict(), task_type, task_id, 

580 self._num_workers, local_devices, 

581 self._communication_options.implementation) 

582 

583 def __del__(self): 

584 self._stop_check_health_thread() 

585 

586 def _input_workers_with_options(self, options=None): 

587 host_device = device_util.get_host_for_device(self._worker_device) 

588 if not options or options.experimental_fetch_to_device: 

589 return input_lib.InputWorkers([(host_device, self.worker_devices)]) 

590 else: 

591 return input_lib.InputWorkers([( 

592 host_device, 

593 [device_util.get_host_for_device(worker) for worker in 

594 self.worker_devices])]) 

595 

596 @property 

597 def _input_workers(self): 

598 return self._input_workers_with_options() 

599 

600 def _get_variable_creator_initial_value(self, 

601 replica_id, 

602 device, 

603 primary_var, 

604 **kwargs): 

605 if replica_id == 0: # First replica on each worker. 

606 assert device is not None 

607 assert primary_var is None 

608 

609 def initial_value_fn(): # pylint: disable=g-missing-docstring 

610 # Only the first device participates in the broadcast of initial values. 

611 group_key = self._collective_keys.get_group_key([device]) 

612 group_size = self._num_workers 

613 collective_instance_key = ( 

614 self._collective_keys.get_instance_key(group_key, device)) 

615 

616 with ops.device(device): 

617 initial_value = kwargs["initial_value"] 

618 if callable(initial_value): 

619 initial_value = initial_value() 

620 if isinstance(initial_value, base.CheckpointInitialValue): 

621 initial_value = initial_value.wrapped_value 

622 assert not callable(initial_value) 

623 initial_value = ops.convert_to_tensor( 

624 initial_value, dtype=kwargs.get("dtype", None)) 

625 

626 if self._num_workers > 1: 

627 if self._is_chief: 

628 bcast_send = collective_ops.broadcast_send( 

629 initial_value, initial_value.shape, initial_value.dtype, 

630 group_size, group_key, collective_instance_key) 

631 with ops.control_dependencies([bcast_send]): 

632 return array_ops.identity(initial_value) 

633 else: 

634 return collective_ops.broadcast_recv(initial_value.shape, 

635 initial_value.dtype, 

636 group_size, group_key, 

637 collective_instance_key) 

638 return initial_value 

639 

640 return initial_value_fn 

641 else: 

642 return super(CollectiveAllReduceExtended, 

643 self)._get_variable_creator_initial_value( 

644 replica_id=replica_id, 

645 device=device, 

646 primary_var=primary_var, 

647 **kwargs) 

648 

649 def _make_input_context(self): 

650 input_context = distribute_lib.InputContext( 

651 num_input_pipelines=self._num_workers, 

652 input_pipeline_id=self._id_in_cluster, 

653 num_replicas_in_sync=self._num_replicas_in_sync) 

654 return input_context 

655 

656 def _experimental_distribute_dataset(self, dataset, options): 

657 if (options and options.experimental_replication_mode == 

658 distribute_lib.InputReplicationMode.PER_REPLICA): 

659 raise NotImplementedError( 

660 "InputReplicationMode.PER_REPLICA " 

661 "is only supported in " 

662 "`distribute_datasets_from_function` " 

663 "of tf.distribute.MirroredStrategy" 

664 ) 

665 input_context = self._make_input_context() 

666 return input_util.get_distributed_dataset( 

667 dataset, 

668 self._input_workers_with_options(options), 

669 self._container_strategy(), 

670 num_replicas_in_sync=self._num_replicas_in_sync, 

671 input_context=input_context, 

672 options=options) 

673 

674 def _distribute_datasets_from_function(self, dataset_fn, options): 

675 if (options and options.experimental_replication_mode == 

676 distribute_lib.InputReplicationMode.PER_REPLICA): 

677 raise NotImplementedError( 

678 "InputReplicationMode.PER_REPLICA " 

679 "is only supported in " 

680 "`distribute_datasets_from_function` " 

681 "of tf.distribute.MirroredStrategy") 

682 input_context = self._make_input_context() 

683 return input_util.get_distributed_datasets_from_function( 

684 dataset_fn=dataset_fn, 

685 input_workers=self._input_workers_with_options(options), 

686 input_contexts=[input_context], 

687 strategy=self._container_strategy(), 

688 options=options) 

689 

690 def _experimental_distribute_values_from_function(self, value_fn): 

691 per_replica_values = [] 

692 num_local_replicas = len(self.worker_devices) 

693 for local_replica_id in range(num_local_replicas): 

694 replica_id = (self._id_in_cluster * num_local_replicas + 

695 local_replica_id) 

696 value_context = distribute_lib.ValueContext( 

697 replica_id, self._num_replicas_in_sync) 

698 per_replica_values.append(value_fn(value_context)) 

699 return distribute_utils.regroup(per_replica_values, always_wrap=True) 

700 

701 def _make_dataset_iterator(self, dataset): 

702 """Distributes the dataset to each local GPU.""" 

703 input_context = self._make_input_context() 

704 return input_lib_v1.DatasetIterator( 

705 dataset, 

706 self._input_workers, 

707 self._container_strategy(), 

708 num_replicas_in_sync=self._num_replicas_in_sync, 

709 input_context=input_context) 

710 

711 def _make_input_fn_iterator( 

712 self, 

713 input_fn, 

714 replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): 

715 """Distributes the input function to each local GPU.""" 

716 input_context = self._make_input_context() 

717 return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers, 

718 [input_context], 

719 self._container_strategy()) 

720 

721 def _configure(self, 

722 session_config=None, 

723 cluster_spec=None, 

724 task_type=None, 

725 task_id=None): 

726 """Configures the object. 

727 

728 Args: 

729 session_config: a `tf.compat.v1.ConfigProto` 

730 cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the 

731 cluster configurations. 

732 task_type: the current task type, such as "worker". 

733 task_id: the current task id. 

734 

735 Raises: 

736 ValueError: if `task_type` is not in the `cluster_spec`. 

737 """ 

738 if cluster_spec: 

739 cluster_resolver = SimpleClusterResolver( 

740 cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), 

741 task_type=task_type, 

742 task_id=task_id, 

743 num_accelerators={ 

744 self._local_device_type: self._num_devices_per_worker}, 

745 rpc_layer=self._rpc_layer) 

746 self._initialize_multi_worker(cluster_resolver) 

747 assert isinstance(self._cross_device_ops, 

748 cross_device_ops_lib.CollectiveAllReduce) 

749 

750 if session_config: 

751 session_config.CopyFrom(self._update_config_proto(session_config)) 

752 

753 def _update_config_proto(self, config_proto): 

754 updated_config = copy.deepcopy(config_proto) 

755 # Enable the scoped allocator optimization for CollectiveOps. This 

756 # optimization converts many small all-reduces into fewer larger 

757 # all-reduces. 

758 rewrite_options = updated_config.graph_options.rewrite_options 

759 rewrite_options.scoped_allocator_optimization = ( 

760 rewriter_config_pb2.RewriterConfig.ON) 

761 # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op = 

762 # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we 

763 # clear and then append. 

764 del rewrite_options.scoped_allocator_opts.enable_op[:] 

765 rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce") 

766 

767 if (not ops.executing_eagerly_outside_functions() and 

768 self._communication_options.implementation == 

769 collective_util.CommunicationImplementation.NCCL): 

770 updated_config.experimental.collective_nccl = True 

771 

772 if not self._cluster_spec: 

773 return updated_config 

774 

775 assert self._task_type 

776 assert self._task_id is not None 

777 

778 # Collective group leader is needed for collective ops to coordinate 

779 # workers. 

780 updated_config.experimental.collective_group_leader = ( 

781 multi_worker_util.collective_leader(self._cluster_spec, self._task_type, 

782 self._task_id)) 

783 

784 # The device filters prevent communication between workers. 

785 del updated_config.device_filters[:] 

786 updated_config.device_filters.append( 

787 "/job:%s/task:%d" % (self._task_type, self._task_id)) 

788 

789 return updated_config 

790 

791 def _get_cross_device_ops(self, value): 

792 # CollectiveAllReduce works on a predefined set of devices. In most cases 

793 # they should be the compute devices, but certain use cases may reduce host 

794 # tensors as well (e.g. early stopping). We infer the cross_device_ops to 

795 # use based on the number of devices, since inputs don't always have device 

796 annotations. The one for the compute devices is preferred since we can potentially 

797 # leverage NCCL. 

798 if isinstance(value, values.DistributedValues): 

799 num_devices = len(value._values) # pylint: disable=protected-access 

800 else: 

801 num_devices = 1 

802 if num_devices == len(self.worker_devices): 

803 return self._cross_device_ops 

804 else: 

805 return self._host_cross_device_ops 

806 

807 def _gather_to_implementation(self, value, destinations, axis, options): 

808 return self._get_cross_device_ops(value)._gather( # pylint: disable=protected-access 

809 value, 

810 destinations=destinations, 

811 axis=axis, 

812 options=options) 

813 

814 def _reduce_to(self, reduce_op, value, destinations, options): 

815 if (isinstance(value, values.Mirrored) and 

816 reduce_op == reduce_util.ReduceOp.MEAN): 

817 return value 

818 assert not isinstance(value, values.Mirrored) 

819 

820 if (isinstance(value, values.DistributedValues) and 

821 len(self.worker_devices) == 1): 

822 value = value.values[0] 

823 

824 # When there are multiple workers, we need to reduce across workers using 

825 # collective ops. 

826 if (not isinstance(value, values.DistributedValues) and 

827 self._num_workers == 1): 

828 # This function handles reducing values that are not PerReplica or 

829 # Mirrored values. For example, the same value could be present on all 

830 # replicas, in which case `value` would be a single value, or `value` could 

831 # be 0. 

832 return cross_device_ops_lib.reduce_non_distributed_value( 

833 reduce_op, value, destinations, len(self.worker_devices)) 

834 return self._get_cross_device_ops(value).reduce( 

835 reduce_op, 

836 value, 

837 destinations=destinations, 

838 options=self._communication_options.merge(options)) 

839 

840 def _replica_ctx_all_reduce(self, reduce_op, value, options=None): 

841 """Implements `StrategyExtendedV2._replica_ctx_all_reduce`.""" 

842 # This implementation avoids using `merge_call` and just launches collective 

843 # ops in one replica. 

844 if options is None: 

845 options = collective_util.Options() 

846 

847 if context.executing_eagerly(): 

848 # In eager mode, fall back to the default implementation that uses 

849 # `merge_call`. Replica functions run sequentially in eager mode, 

850 # and due to the blocking nature of collective ops, execution will hang if 

851 # collective ops are to be launched sequentially. 

852 return super()._replica_ctx_all_reduce(reduce_op, value, options) 

853 

854 replica_context = distribute_lib.get_replica_context() 

855 assert replica_context, ( 

856 "`StrategyExtended._replica_ctx_all_reduce` must be called in a " 

857 "replica context") 

858 return self._cross_device_ops._all_reduce( # pylint: disable=protected-access 

859 reduce_op, 

860 value, 

861 replica_context._replica_id, # pylint: disable=protected-access 

862 options) 

863 

864 def _check_health(self): 

865 while True: 

866 if self._check_health_thread_should_stop.is_set(): 

867 return 

868 for job in self._cluster_spec.jobs: 

869 for task_id in range(self._cluster_spec.num_tasks(job)): 

870 peer = "/job:{}/replica:0/task:{}".format(job, task_id) 

871 attempts = 0 

872 while True: 

873 attempts += 1 

874 try: 

875 context.context().check_collective_ops_peer_health( 

876 peer, timeout_in_ms=self._check_health_timeout * 1000) 

877 # If check_collective_ops_peer_health doesn't raise an Exception, 

878 # the peer is healthy. 

879 break 

880 except (errors.UnavailableError, errors.FailedPreconditionError, 

881 errors.DeadlineExceededError) as e: 

882 # TODO(b/151232436): Always raise UnavailableError when a peer 

883 # fails. Now there could be many kinds of errors: 

884 # - Unavailable: when the peer is not reachable, e.g. it's down. 

885 # - FailedPrecondition: when the peer has restarted. 

886 if attempts < self._check_health_retry_limit: 

887 logging.warning("%s seems down, retrying %d/%d", peer, attempts, 

888 self._check_health_retry_limit) 

889 continue 

890 logging.error( 

891 "Cluster check alive failed, %s is down, " 

892 "aborting collectives: %s", peer, e) 

893 context.context().abort_collective_ops( 

894 errors.UNAVAILABLE, 

895 "cluster check alive failed, {} is down".format(peer)) 

896 return 

897 except Exception as e: # pylint: disable=broad-except 

898 logging.error("Unexpected exception in check alive: %s", e) 

899 context.context().abort_collective_ops( 

900 errors.INTERNAL, 

901 "unexecpted exception in check alive: %s" % e) 

902 return 

903 time.sleep(self._check_health_interval) 

904 

905 def _start_check_health_thread(self): 

906 # Use a dummy all-reduce as a barrier to wait for all workers to be up, 

907 # otherwise the check health may fail immediately. 

908 

909 # Use array_ops.identity to create the dummy tensor so that we have a new 

910 # Tensor. If we use a constant, it may be cached on a /job:localhost 

911 # device, which will cause code that relies on tensor.device to error. 

912 # 

913 # TODO(b/151232436): change to an explicit barrier if we have it. 

914 dummy_value = array_ops.identity([]) 

915 logging.info("Waiting for the cluster, timeout = %s", 

916 self._check_health_initial_timeout or "inf") 

917 try: 

918 self._host_cross_device_ops.reduce( 

919 reduce_util.ReduceOp.SUM, 

920 dummy_value, 

921 dummy_value, 

922 options=collective_util.Options( 

923 timeout_seconds=self._check_health_initial_timeout, 

924 implementation=collective_util.CommunicationImplementation.RING)) 

925 if context.is_async(): 

926 context.async_wait() 

927 except errors.DeadlineExceededError: 

928 raise RuntimeError( 

929 "Timeout waiting for the cluster, timeout is %d seconds" % 

930 self._check_health_initial_timeout) 

931 logging.info("Cluster is ready.") 

932 self._check_health_thread_should_stop = threading.Event() 

933 # Start the thread as a daemon to avoid it blocking the program from exiting. 

934 # We try our best to shut down the thread, but __del__ is not guaranteed to be 

935 # called when the program exits. 

936 self._check_health_thread = threading.Thread( 

937 target=self._check_health, 

938 daemon=True) 

939 self._check_health_thread.start() 

940 

941 def _stop_check_health_thread(self): 

942 if getattr(self, "_check_health_thread", None): 

943 logging.info("stopping check health thread") 

944 self._check_health_thread_should_stop.set() 

945 self._check_health_thread.join() 

946 self._check_health_thread = None 

947 logging.info("check health thread stopped") 

948 

949 def _warn_nccl_no_gpu(self): 

950 if ((self._communication_options.implementation == 

951 collective_util.CommunicationImplementation.NCCL) and 

952 self._local_device_type != "GPU"): 

953 logging.warning("Enabled NCCL communication but no GPUs detected/" 

954 "specified.") 

955 

956 def _in_multi_worker_mode(self): 

957 """Whether this strategy indicates working in multi-worker settings.""" 

958 return self._num_workers > 1 

959 

960 @property 

961 def experimental_between_graph(self): 

962 return True 

963 

964 @property 

965 def experimental_should_init(self): 

966 return True 

967 

968 @property 

969 def should_checkpoint(self): 

970 return self._is_chief 

971 

972 @property 

973 def should_save_summary(self): 

974 return self._is_chief 

975 

976 @property 

977 def _num_replicas_in_sync(self): 

978 return len(self.worker_devices) * self._num_workers 

979 

980 # TODO(priyag): Delete this once all strategies use global batch size. 

981 @property 

982 def _global_batch_size(self): 

983 """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. 

984 

985 `make_input_fn_iterator` assumes per-replica batching. 

986 

987 Returns: 

988 Boolean. 

989 """ 

990 return True 

991 

992 def _get_replica_id_in_sync_group(self, replica_id): 

993 return self._id_in_cluster * len(self.worker_devices) + replica_id 

994 

995 def _get_local_replica_id(self, replica_id_in_sync_group): 

996 return (replica_id_in_sync_group - 

997 self._id_in_cluster * len(self.worker_devices)) 

998 

999 def __deepcopy__(self, memo): 

1000 # We check the check-health thread instead of whether we are in eager mode 

1001 # to limit backward incompatibility. 

1002 if hasattr(self, "_check_health_thread"): 

1003 raise ValueError( 

1004 "MultiWorkerMirroredStrategy cannot be deep copied in eager mode. " 

1005 "If you're using Estimator and see this error message, call " 

1006 "tf.compat.v1.disable_eager_execution() at the beginning of your " 

1007 "program") 

1008 # Otherwise, do a regular deepcopy. 

1009 cls = self.__class__ 

1010 result = cls.__new__(cls) 

1011 memo[id(self)] = result 

1012 for k, v in self.__dict__.items(): 

1013 setattr(result, k, copy.deepcopy(v, memo)) 

1014 return result