Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/parameter_server_strategy_v2.py: 23%

271 statements  


1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Parameter server strategy V2 class. 

16 

17This is currently under development and the API is subject to change. 

18""" 

19 

20import functools 

21import os 

22import threading 

23 

24from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib 

25from tensorflow.python.distribute import device_util 

26from tensorflow.python.distribute import distribute_lib 

27from tensorflow.python.distribute import input_lib 

28from tensorflow.python.distribute import input_util 

29from tensorflow.python.distribute import mirrored_run 

30from tensorflow.python.distribute import multi_worker_util 

31from tensorflow.python.distribute import parameter_server_strategy 

32from tensorflow.python.distribute import ps_values 

33from tensorflow.python.distribute import sharded_variable 

34from tensorflow.python.distribute import values 

35from tensorflow.python.distribute.coordinator import cluster_coordinator 

36from tensorflow.python.eager import context 

37from tensorflow.python.eager import remote 

38from tensorflow.python.framework import config 

39from tensorflow.python.framework import device as tf_device 

40from tensorflow.python.framework import ops 

41from tensorflow.python.framework import tensor_shape 

42from tensorflow.python.ops import array_ops 

43from tensorflow.python.ops import resource_variable_ops 

44from tensorflow.python.ops import variable_scope as vs 

45from tensorflow.python.platform import tf_logging as logging 

46from tensorflow.python.trackable import base as trackable 

47from tensorflow.python.training import server_lib 

48from tensorflow.python.util import keras_deps 

49from tensorflow.python.util import nest 

50from tensorflow.python.util import tf_inspect 

51from tensorflow.python.util.tf_export import tf_export 

52from tensorflow.tsl.protobuf import coordination_config_pb2 

53 

54 

55ALLOWED_TASK_TYPES = ("chief", "worker", "ps") 

56# This sets the coordination service's internal heartbeat timeout. In testing, a 

57# value of 1 led to some spurious reports of unavailability, so a higher value 

58# is used. Refer to the discussion in b/249134783 for more. 

59_HEARTBEAT_TIMEOUT_SECS = 5 

60 

61 

62@tf_export( 

63 "distribute.experimental.ParameterServerStrategy", 

64 "distribute.ParameterServerStrategy", 

65 v1=[]) 

66class ParameterServerStrategyV2(distribute_lib.Strategy): 

67 """An multi-worker tf.distribute strategy with parameter servers. 

68 

69 Parameter server training is a common data-parallel method to scale up a 

70 machine learning model on multiple machines. A parameter server training 

71 cluster consists of workers and parameter servers. Variables are created on 

72 parameter servers and they are read and updated by workers in each step. 

73 By default, workers read and update these variables independently without 

74 synchronizing with each other. This configuration is known as asynchronous 

75 training. 

76 

77 In TensorFlow 2, we recommend an architecture based on central coordination 

78 for parameter server training. Each worker and parameter server runs a 

79 `tf.distribute.Server`, and on top of that, a coordinator task is responsible 

80 for creating resources on workers and parameter servers, dispatching 

81 functions, and coordinating the training. The coordinator uses a 

82 `tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the 

83 cluster, and a `tf.distribute.experimental.ParameterServerStrategy` to define 

84 variables on parameter servers and computation on workers. 

85 

86 For the training to work, the coordinator dispatches `tf.function`s to be 

87 executed on remote workers. Upon receiving requests from the coordinator, a 

88 worker executes the `tf.function` by reading the variables from parameter 

89 servers, executing the ops, and updating the variables on the parameter 

90 servers. Each worker only processes requests from the coordinator, and 

91 communicates with parameter servers, without interacting directly with 

92 other workers in the cluster. 

93 

94 As a result, failures of some workers do not prevent the cluster from 

95 continuing the work, and this allows the cluster to train with instances that 

96 can be occasionally unavailable (e.g. preemptible or spot instances). The 

97 coordinator and parameter servers, however, must be available at all times for 

98 the cluster to make progress. 

99 

100 Note that the coordinator is not one of the training workers. Instead, it 

101 creates resources such as variables and datasets, dispatches `tf.function`s, 

102 saves checkpoints and so on. In addition to workers, parameter servers and 

103 the coordinator, an optional evaluator can be run on the side that 

104 periodically reads the checkpoints saved by the coordinator and runs 

105 evaluations against each checkpoint. 

106 
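
  As an illustration only, a minimal sketch of such a sidecar evaluator is shown
  below; `build_model`, `eval_dataset`, and `checkpoint_dir` are user-defined
  placeholders, not part of this API.

  ```python
  # Hypothetical sidecar evaluator task.
  model = build_model()
  model.compile(metrics=["accuracy"])
  checkpoint = tf.train.Checkpoint(model=model)

  # `tf.train.checkpoints_iterator` yields each new checkpoint as it appears.
  for latest_checkpoint in tf.train.checkpoints_iterator(checkpoint_dir):
    checkpoint.restore(latest_checkpoint).expect_partial()
    metrics = model.evaluate(eval_dataset, return_dict=True)
    logging.info("Evaluated %s: %r", latest_checkpoint, metrics)
  ```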

107 `ParameterServerStrategy` is supported with two training APIs: [Custom 

108 Training Loop (CTL)] 

109 (https://www.tensorflow.org/tutorials/distribute/custom_training) 

110 and [Keras Training API, also known as `Model.fit`] 

111 (https://www.tensorflow.org/tutorials/distribute/keras). CTL is recommended 

112 when users prefer to define the details of their training loop, and 

113 `Model.fit` is recommended when users prefer a high-level abstraction and 

114 handling of training. 

115 

116 When using a CTL, `ParameterServerStrategy` has to work in conjunction with a 

117 `tf.distribute.experimental.coordinator.ClusterCoordinator` object. 

118 

119 When using `Model.fit`, currently only the 

120 `tf.keras.utils.experimental.DatasetCreator` input type is supported. 

121 

122 __Example code for coordinator__ 

123 

124 This section provides code snippets that are intended to be run on the single 

125 task that is designated as the coordinator. Note that `cluster_resolver`, 

126 `variable_partitioner`, and `dataset_fn` arguments are explained in the 

127 following "Cluster setup", "Variable partitioning", and "Dataset preparation" 

128 sections. 

129 

130 With a CTL, 

131 

132 ```python 

133 # Prepare a strategy to use with the cluster and variable partitioning info. 

134 strategy = tf.distribute.experimental.ParameterServerStrategy( 

135 cluster_resolver=..., 

136 variable_partitioner=...) 

137 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( 

138 strategy=strategy) 

139 

140 # Prepare a distributed dataset that will place datasets on the workers. 

141 distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...) 

142 

143 with strategy.scope(): 

144 model = ... 

145 optimizer, metrics = ... # Keras optimizer/metrics are great choices 

146 checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer) 

147 checkpoint_manager = tf.train.CheckpointManager( 

148 checkpoint, checkpoint_dir, max_to_keep=2) 

149 # `load_checkpoint` infers initial epoch from `optimizer.iterations`. 

150 initial_epoch = load_checkpoint(checkpoint_manager) or 0 

151 

152 @tf.function 

153 def worker_fn(iterator): 

154 

155 def replica_fn(inputs): 

156 batch_data, labels = inputs 

157 # Calculate gradients, apply gradients, update metrics, etc. 

158 

159 strategy.run(replica_fn, args=(next(iterator),)) 

160 

161 for epoch in range(initial_epoch, num_epoch): 

162 distributed_iterator = iter(distributed_dataset) # Reset iterator state. 

163 for step in range(steps_per_epoch): 

164 

165 # Asynchronously schedule the `worker_fn` to be executed on an arbitrary 

166 # worker. This call returns immediately. 

167 coordinator.schedule(worker_fn, args=(distributed_iterator,)) 

168 

169 # `join` blocks until all scheduled `worker_fn`s finish execution. Once it 

170 # returns, we can read the metrics and save checkpoints as needed. 

171 coordinator.join() 

172 logging.info('Metric result: %r', metrics.result()) 

173 metrics.reset_states() 

174 checkpoint_manager.save() 

175 ``` 

176 

177 With `Model.fit`, 

178 

179 ```python 

180 # Prepare a strategy to use with the cluster and variable partitioning info. 

181 strategy = tf.distribute.experimental.ParameterServerStrategy( 

182 cluster_resolver=..., 

183 variable_partitioner=...) 

184 

185 # A dataset function takes an `input_context` and returns a `Dataset`. 

186 def dataset_fn(input_context): 

187 dataset = tf.data.Dataset.from_tensors(...) 

188 return dataset.repeat().shard(...).batch(...).prefetch(...) 

189 

190 # With `Model.fit`, a `DatasetCreator` needs to be used. 

191 input = tf.keras.utils.experimental.DatasetCreator(dataset_fn=...) 

192 

193 with strategy.scope(): 

194 model = ... # Make sure the `Model` is created within scope. 

195 model.compile(optimizer="rmsprop", loss="mse", steps_per_execution=..., ...) 

196 

197 # Optional callbacks to checkpoint the model, back up the progress, etc. 

198 callbacks = [tf.keras.callbacks.ModelCheckpoint(...), ...] 

199 

200 # `steps_per_epoch` is required with `ParameterServerStrategy`. 

201 model.fit(input, epochs=..., steps_per_epoch=..., callbacks=callbacks) 

202 ``` 

203 

204 __Example code for worker and parameter servers__ 

205 

206 In addition to the coordinator, there should be tasks designated as 

207 "worker" or "ps". They should run the following code to start a TensorFlow 

208 server and wait for the coordinator's requests: 

209 

210 ```python 

211 # Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves 

212 # the cluster information. See below "Cluster setup" section. 

213 cluster_resolver = ... 

214 

215 server = tf.distribute.Server( 

216 cluster_resolver.cluster_spec(), 

217 job_name=cluster_resolver.task_type, 

218 task_index=cluster_resolver.task_id, 

219 protocol="grpc") 

220 

221 # Block the process that started the server from exiting. 

222 server.join() 

223 ``` 

224 

225 __Cluster setup__ 

226 

227 In order for the tasks in the cluster to know each other's addresses, 

228 a `tf.distribute.cluster_resolver.ClusterResolver` must be used by the 

229 coordinator, worker, and ps tasks. The 

230 `tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing 

231 the cluster information, as well as the task type and id of the current task. 

232 See `tf.distribute.cluster_resolver.ClusterResolver` for more information. 

233 

234 If the `TF_CONFIG` environment variable is set, a 

235 `tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used to 

236 read it. 

237 

238 Since there are assumptions in 

239 `tf.distribute.experimental.ParameterServerStrategy` around the naming of the 

240 task types, "chief", "ps", and "worker" should be used in the 

241 `tf.distribute.cluster_resolver.ClusterResolver` to refer to the coordinator, 

242 parameter servers, and workers, respectively. 

243 

244 The following example demonstrates setting `TF_CONFIG` for the task designated 

245 as a parameter server (task type "ps") and index 1 (the second task), in a 

246 cluster with 1 chief, 2 parameter servers, and 3 workers. Note that it needs 

247 to be set before the use of 

248 `tf.distribute.cluster_resolver.TFConfigClusterResolver`. 

249 

250 Example code for cluster setup: 

251 ```python 

252 os.environ['TF_CONFIG'] = ''' 

253 { 

254 "cluster": { 

255 "chief": ["chief.example.com:2222"], 

256 "ps": ["ps0.example.com:2222", "ps1.example.com:2222"], 

257 "worker": ["worker0.example.com:2222", "worker1.example.com:2222", 

258 "worker2.example.com:2222"] 

259 }, 

260 "task": { 

261 "type": "ps", 

262 "index": 1 

263 } 

264 } 

265 ''' 

266 ``` 
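
  With `TF_CONFIG` set as above, each task can then build its resolver; a
  minimal sketch (the assertions below correspond to the "ps" task in the
  example `TF_CONFIG`):

  ```python
  # Reads the cluster spec, task type, and task index from the `TF_CONFIG`
  # environment variable set above.
  cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
  assert cluster_resolver.task_type == "ps"
  assert cluster_resolver.task_id == 1
  ```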

267 

268 If you prefer to run the same binary for all tasks, you will need to let the 

269 binary branch into different roles at the beginning of the program: 

270 ```python 

271 # If coordinator, create a strategy and start the training program. 

272 if cluster_resolver.task_type == 'chief': 

273 strategy = tf.distribute.experimental.ParameterServerStrategy( 

274 cluster_resolver) 

275 ... 

276 

277 # If worker/ps, create a server 

278 elif cluster_resolver.task_type in ("worker", "ps"): 

279 server = tf.distribute.Server(...) 

280 ... 

281 ``` 

282 Alternatively, you can also start the TensorFlow servers in advance and 

283 connect to them later. The coordinator can be in the same cluster or on any 

284 machine that has connectivity to workers and parameter servers. This is 

285 covered in our guide and tutorial. 

286 

287 __Variable creation with `strategy.scope()`__ 

288 

289 `tf.distribute.experimental.ParameterServerStrategy` follows the 

290 `tf.distribute` API contract where variable creation is expected to be inside 

291 the context manager returned by `strategy.scope()`, in order to be correctly 

292 placed on parameter servers in a round-robin manner: 

293 

294 ```python 

295 # In this example, we assume the cluster has 3 ps tasks. 

296 strategy = tf.distribute.experimental.ParameterServerStrategy( 

297 cluster_resolver=...) 

298 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator( 

299 strategy=strategy) 

300 

301 # Variables should be created inside scope to be placed on parameter servers. 

302 # If created outside scope such as `v1` here, it would be placed on the 

303 # coordinator. 

304 v1 = tf.Variable(initial_value=0.0) 

305 

306 with strategy.scope(): 

307 v2 = tf.Variable(initial_value=1.0) 

308 v3 = tf.Variable(initial_value=2.0) 

309 v4 = tf.Variable(initial_value=3.0) 

310 v5 = tf.Variable(initial_value=4.0) 

311 

312 # v2 through v5 are created in scope and are distributed on parameter servers. 

313 # Default placement is round-robin but the order should not be relied on. 

314 assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0" 

315 assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0" 

316 assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0" 

317 assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0" 

318 ``` 

319 

320 See `tf.distribute.Strategy.scope` for more information. 

321 

322 __Variable partitioning__ 

323 

324 Having dedicated servers to store variables means being able to divide up, or 

325 "shard", the variables across the ps. Partitioning a large variable among ps 

326 is a commonly used technique to boost training throughput and mitigate memory 

327 constraints. It enables parallel computations and updates on different shards 

328 of a variable, and often yields better load balancing across parameter 

329 servers. Without sharding, models with large variables (e.g., embeddings) that 

330 can't fit into one machine's memory would be unable to train. 

331 

332 With `tf.distribute.experimental.ParameterServerStrategy`, if a 

333 `variable_partitioner` is provided to `__init__` and certain conditions are 

334 satisfied, the resulting variables created in scope are sharded across the 

335 parameter servers, in a round-robin fashion. The variable reference returned 

336 from `tf.Variable` becomes a type that serves as the container of the sharded 

337 variables. One can access the `variables` attribute of this container to get 

338 the actual variable components. If building a model with `tf.Module` or Keras, 

339 the variable components are collected in `variables`-like attributes. 

340 

341 It is recommended to use size-based partitioners like 

342 `tf.distribute.experimental.partitioners.MinSizePartitioner` to avoid 

343 partitioning small variables, which could have a negative impact on model 

344 training speed. 

345 

346 ```python 

347 # Partition the embedding layer into 2 shards. 

348 variable_partitioner = ( 

349 tf.distribute.experimental.partitioners.MinSizePartitioner( 

350 min_shard_bytes=(256 << 10), 

351 max_shards = 2)) 

352 strategy = tf.distribute.experimental.ParameterServerStrategy( 

353 cluster_resolver=..., 

354 variable_partitioner = variable_partitioner) 

355 with strategy.scope(): 

356 embedding = tf.keras.layers.Embedding(input_dim=1024, output_dim=1024) 

357 assert len(embedding.variables) == 2 

358 assert isinstance(embedding.variables[0], tf.Variable) 

359 assert isinstance(embedding.variables[1], tf.Variable) 

360 assert embedding.variables[0].shape == (512, 1024) 

361 assert embedding.variables[1].shape == (512, 1024) 

362 ``` 

363 

364 The sharded variable container can be converted to a `Tensor` via 

365 `tf.convert_to_tensor`. This means the container can be directly used in most 

366 Python Ops where such `Tensor` conversion automatically happens. For example, 

367 using the sharded container in a matrix multiplication would implicitly apply 

368 this conversion. Note that such a conversion can be expensive, as the variable 

369 components need to be transferred from multiple parameter servers to where 

370 the value is used. 

371 

372 `tf.nn.embedding_lookup`, on the other hand, doesn't apply the tensor 

373 conversion, and performs parallel lookups on the variable components instead. 

374 This is crucial to scale up embedding lookups when the embedding table 

375 variable is large. 

376 
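
  For illustration, a sketch of both access patterns (assuming `strategy` was
  created with a `variable_partitioner`, and `ids` is a batch of lookup
  indices defined elsewhere):

  ```python
  with strategy.scope():
    # With a `variable_partitioner`, `table` is a sharded variable container.
    table = tf.Variable(tf.zeros([1024, 64]))

  # Materializes the full (1024, 64) value wherever it is used; can be costly.
  full_value = tf.convert_to_tensor(table)

  # Performs parallel lookups on the individual shards instead.
  rows = tf.nn.embedding_lookup(table, ids)
  ```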

377 When a partitioned variable is saved to a `SavedModel`, it will be saved as if 

378 it were one single variable. This improves serving efficiency by eliminating 

379 a number of Ops that handle the partition aspects. 

380 

381 Known limitations of variable partitioning: 

382 

383 * Number of partitions must not change across Checkpoint saving/loading. 

384 

385 * After saving partitioned variables to a SavedModel, the SavedModel can't be 

386 loaded via `tf.saved_model.load`. 

387 

388 * Partitioned variables don't directly work with `tf.GradientTape`; please use 

389 the `variables` attribute to get the actual variable components and use 

390 them in gradient APIs instead, as shown in the sketch after this list. 

391 
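
  A sketch of the recommended `tf.GradientTape` workaround, assuming `w` is a
  sharded variable created under `strategy.scope()`, and `x` and `optimizer`
  are a user-defined input batch and optimizer:

  ```python
  with tf.GradientTape() as tape:
    # `tf.convert_to_tensor` concatenates the shards into one dense tensor.
    y = tf.matmul(x, tf.convert_to_tensor(w))
    loss = tf.reduce_sum(y)

  # Differentiate with respect to the individual variable components, then
  # pair each gradient with its component when applying updates.
  gradients = tape.gradient(loss, w.variables)
  optimizer.apply_gradients(zip(gradients, w.variables))
  ```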

392 __Dataset preparation__ 

393 

394 With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is 

395 created on each of the workers to be used for training. This is done by 

396 creating a `dataset_fn` that takes no argument and returns a 

397 `tf.data.Dataset`, and passing the `dataset_fn` into 

398 `tf.distribute.experimental.coordinator. 

399 ClusterCoordinator.create_per_worker_dataset`. We recommend that the dataset be 

400 shuffled and repeated so that the examples run through the training as evenly 

401 as possible. 

402 

403 ```python 

404 def dataset_fn(): 

405 filenames = ... 

406 dataset = tf.data.Dataset.from_tensor_slices(filenames) 

407 

408 # The dataset is recommended to be shuffled and repeated. 

409 return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...) 

410 

411 coordinator = 

412 tf.distribute.experimental.coordinator.ClusterCoordinator(strategy=...) 

413 distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn) 

414 ``` 

415 

416 __Limitations__ 

417 

418 * `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental, 

419 and the API is subject to further changes. 

420 

421 * When using `Model.fit`, `tf.distribute.experimental.ParameterServerStrategy` 

422 must be used with a `tf.keras.utils.experimental.DatasetCreator`, and 

423 `steps_per_epoch` must be specified. 

424 """ 

425 

426 # pyformat: disable 

427 def __init__(self, cluster_resolver, variable_partitioner=None): 

428 """Initializes the TF2 parameter server strategy. 

429 

430 This initializes the `tf.distribute.experimental.ParameterServerStrategy` 

431 object to be ready for use with 

432 `tf.distribute.experimental.coordinator.ClusterCoordinator`. 

433 

434 Args: 

435 cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver` 

436 object. 

437 variable_partitioner: 

438 a `distribute.experimental.partitioners.Partitioner` that specifies 

439 how to partition variables. If `None`, variables will not be 

440 partitioned. 

441 

442 * Predefined partitioners in `tf.distribute.experimental.partitioners` 

443 can be used for this argument. A commonly used partitioner is 

444 `MinSizePartitioner(min_shard_bytes = 256 << 10, max_shards = num_ps)`, 

445 which allocates at least 256K per shard, and each ps gets at most one 

446 shard. 

447 

448 * `variable_partitioner` will be called for each variable created under 

449 strategy `scope` to instruct how the variable should be partitioned. 

450 Variables that have only one partition along the partitioning axis 

451 (i.e., no need for partition) will be created as a normal `tf.Variable`. 

452 

453 * Only the first / outermost axis partitioning is supported. 

454 

455 * The "div" partition strategy is used to partition variables (see also the 

456 sketch at the end of this docstring). Assuming we assign consecutive integer 

457 ids along the first axis of a variable, ids are assigned to shards in a 

458 contiguous manner, while attempting to keep each shard size identical. If 

459 the ids cannot be split evenly across the shards, each of the first several 

460 shards is assigned one more id. For instance, a variable whose first 

461 dimension is 13 has 13 ids, and they are split across 5 shards as: 

462 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`. 

463 

464 * Variables created under `strategy.extended.colocate_vars_with` will 

465 not be partitioned. 
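
    As a purely illustrative sketch of the "div" split described above
    (13 ids across 5 shards):

    ```python
    num_ids, num_shards = 13, 5
    base, extra = divmod(num_ids, num_shards)  # base=2, extra=3
    # The first `extra` shards receive one extra id each.
    sizes = [base + (1 if i < extra else 0) for i in range(num_shards)]
    assert sizes == [3, 3, 3, 2, 2]
    ```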

466 """ 

467 # pyformat: enable 

468 self._cluster_resolver = cluster_resolver 

469 

470 self._verify_args_and_config(cluster_resolver) 

471 self._cluster_coordinator = None 

472 logging.info( 

473 "`tf.distribute.experimental.ParameterServerStrategy` is initialized " 

474 "with cluster_spec: %s", cluster_resolver.cluster_spec()) 

475 

476 if os.getenv("TF_PSS_ENABLE_COORDINATION_SERVICE"): 

477 self._configure_coordination_service(cluster_resolver.cluster_spec()) 

478 # TODO(b/167894802): Make coordinator, worker, and ps names customizable. 

479 self._connect_to_cluster(coordinator_name="chief") 

480 self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver, 

481 variable_partitioner) 

482 super(ParameterServerStrategyV2, self).__init__(self._extended) 

483 distribute_lib.distribution_strategy_gauge.get_cell("V2").set( 

484 "ParameterServerStrategy") 

485 self._should_use_with_coordinator = True 

486 # Used while constructing distributed iterators. 

487 self._canonicalize_devices = False 

488 # Used for isinstance() checks without having to import this module. 

489 self._is_parameter_server_strategy_v2 = True 

490 

491 def _configure_coordination_service(self, cluster_spec): 

492 if context.context().coordination_service is None: 

493 coordinated_jobs = ["worker", "ps"] 

494 coordinated_job_config = [] 

495 for job in coordinated_jobs: 

496 if job in cluster_spec.jobs: 

497 coordinated_job_config.append( 

498 coordination_config_pb2.CoordinatedJob( 

499 name=job, 

500 num_tasks=cluster_spec.num_tasks(job))) 

501 context.context().configure_coordination_service( 

502 service_type="standalone", 

503 service_leader=multi_worker_util.coordination_leader( 

504 cluster_spec), 

505 heartbeat_timeout_in_ms=_HEARTBEAT_TIMEOUT_SECS * 1000, 

506 allow_new_incarnation_to_reconnect=True) 

507 

508 def _connect_to_cluster(self, coordinator_name): 

509 if coordinator_name in ["worker", "ps"]: 

510 raise ValueError("coordinator name should not be 'worker' or 'ps'.") 

511 cluster_spec = self._cluster_resolver.cluster_spec() 

512 self._num_workers = len(cluster_spec.as_dict().get("worker", ())) 

513 self._num_ps = len(cluster_spec.as_dict().get("ps", ())) 

514 

515 device_filters = server_lib.ClusterDeviceFilters() 

516 # For any worker, only the devices on ps and coordinator nodes are visible 

517 for i in range(self._num_workers): 

518 device_filters.set_device_filters( 

519 "worker", i, ["/job:ps", "/job:%s" % coordinator_name]) 

520 # Similarly for any ps, only the devices on workers and coordinator are 

521 # visible 

522 for i in range(self._num_ps): 

523 device_filters.set_device_filters( 

524 "ps", i, ["/job:worker", "/job:%s" % coordinator_name]) 

525 

526 # Allow at most one outstanding RPC for each worker at any given time. This 

527 # is to simplify worker failure handling in the runtime. 

528 os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False" 

529 

530 # Disable async executors to make context.async_wait a no-op. This avoids 

531 # sending RPCs to remote workers since the executors used by PSStrategy 

532 # are known to be always synchronous. 

533 os.environ["TF_PS_DISABLE_ASYNC_EXECUTOR_GLOBALLY"] = "True" 

534 

535 logging.info("%s is now connecting to cluster with cluster_spec: %r", 

536 self.__class__.__name__, cluster_spec) 

537 remote.connect_to_cluster( 

538 cluster_spec, 

539 job_name=coordinator_name, 

540 protocol=self._cluster_resolver.rpc_layer, 

541 cluster_device_filters=device_filters) 

542 

543 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

544 "ps_strategy_num_workers").set(self._num_workers) 

545 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

546 "ps_strategy_num_ps").set(self._num_ps) 

547 

548 def _verify_args_and_config(self, cluster_resolver): 

549 if not cluster_resolver.cluster_spec(): 

550 raise ValueError("Cluster spec must be non-empty in " 

551 "`tf.distribute.cluster_resolver.ClusterResolver`.") 

552 cluster_spec = cluster_resolver.cluster_spec() 

553 

554 # The following checks if the task types are allowed (chief, ps, worker). 

555 multi_worker_util._validate_cluster_spec( # pylint: disable=protected-access 

556 cluster_spec, cluster_resolver.task_type, cluster_resolver.task_id) 

557 

558 if multi_worker_util.task_count(cluster_spec, "ps") < 1: 

559 raise ValueError("There must be at least one ps.") 

560 

561 if multi_worker_util.task_count(cluster_spec, "worker") < 1: 

562 raise ValueError("There must be at least one worker.") 

563 

564 

565class ParameterServerStrategyV2Extended( 

566 parameter_server_strategy.ParameterServerStrategyExtended): 

567 """Extended class for ParameterServerStrategyV2. 

568 

569 Please see `tf.distribute.StrategyExtended` doc for more information. 

570 """ 

571 

572 def __init__(self, container_strategy, cluster_resolver, 

573 variable_partitioner): 

574 """Initialization of ParameterServerStrategyV2Extended.""" 

575 super(ParameterServerStrategyV2Extended, self).__init__(container_strategy) 

576 self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", [])) 

577 self._num_workers = len(cluster_resolver.cluster_spec().as_dict().get( 

578 "worker", [])) 

579 self._variable_count = 0 

580 

581 self._variable_partitioner = variable_partitioner 

582 # The following two attrs are to verify that `ParameterServerStrategy` 

583 # methods are properly used with a `ClusterCoordinator`. 

584 self._used_with_coordinator = False 

585 self._being_scheduled = False 

586 self._set_num_gpus() 

587 distribute_lib.distribution_strategy_replica_gauge.get_cell( 

588 "num_gpus_per_worker").set(self._num_gpus_per_worker) 

589 

590 # Don't canonicalize the devices here since this code is executed on the chief, 

591 # but we want the reduce evaluation to be done on each worker. The placer will 

592 # automatically choose the right device based on the current context. 

593 # TODO(ishark): Use select_cross_device_ops instead. 

594 self._cross_device_ops = cross_device_ops_lib.ReductionToOneDevice( 

595 reduce_to_device="/device:CPU:0") 

596 self._cross_device_ops._canonicalize_devices = False # pylint: disable=protected-access 

597 self._allow_run_without_coordinator = False 

598 self._coordinator_creation_lock = threading.Lock() 

599 

600 def _set_num_gpus(self): 

601 devices = config.list_logical_devices("GPU") 

602 per_worker_gpus = {} 

603 for d in devices: 

604 d_spec = tf_device.DeviceSpec.from_string(d.name) 

605 if d_spec.device_type == "GPU" and d_spec.job == "worker": 

606 # TODO(b/167894802): update if worker name is customizable 

607 job_spec = d_spec.replace(device_type=None, device_index=None) 

608 per_worker_gpus[job_spec] = per_worker_gpus.get(job_spec, 0) + 1 

609 

610 num_gpus = 0 

611 for _, count in per_worker_gpus.items(): 

612 if num_gpus > 0 and count != num_gpus: 

613 raise ValueError("Mismatched number of GPUs per worker") 

614 num_gpus = count 

615 

616 self._num_gpus_per_worker = num_gpus 

617 logging.info(f"Number of GPUs on workers: {self._num_gpus_per_worker}") 

618 

619 @property 

620 def _num_replicas_in_sync(self): 

621 return self._num_gpus_per_worker or 1 

622 

623 def _create_var_creator(self, next_creator, **kwargs): 

624 aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) 

625 

626 def var_creator(**kwargs): 

627 """Create an AggregatingVariable.""" 

628 # Create and wrap the variable. 

629 v = next_creator(**kwargs) 

630 wrapped_v = ps_values.CachingVariable(v) 

631 wrapped = ps_values.AggregatingVariable(self._container_strategy(), 

632 wrapped_v, aggregation) 

633 return wrapped 

634 

635 if self._num_replicas_in_sync > 1: 

636 if aggregation not in (vs.VariableAggregation.NONE, 

637 vs.VariableAggregation.SUM, 

638 vs.VariableAggregation.MEAN, 

639 vs.VariableAggregation.ONLY_FIRST_REPLICA): 

640 raise ValueError("Invalid variable aggregation mode: " + aggregation + 

641 " for variable: " + kwargs["name"]) 

642 return var_creator 

643 else: 

644 

645 def variable_creator_single_replica(**kwargs): 

646 v = next_creator(**kwargs) 

647 return ps_values.CachingVariable(v) 

648 

649 return variable_creator_single_replica 

650 

651 def _create_variable(self, next_creator, **kwargs): 

652 """Implements StrategyExtendedV2._create_variable. 

653 

654 Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` will be 

655 created if all of the following criteria are satisfied: 

656 1. `self._variable_partitioner` results in more than one partition on the 

657 first axis. 

658 2. The variable's rank is greater than 0. 

659 3. The variable is not colocated with another variable. 

660 Otherwise a `Variable` will be created. 

661 

662 Args: 

663 next_creator: See `variable_scope.variable_creator_scope`; the next 

664 creator in the chain. 

665 **kwargs: Passed through to the next creator. 

666 

667 Returns: 

668 A `Variable` or `ShardedVariable`. 

669 """ 

670 

671 var_creator = self._create_var_creator(next_creator, **kwargs) 

672 if "colocate_with" in kwargs: # Never partition colocated_with variables. 

673 colocate_with = kwargs["colocate_with"] 

674 # Clear the variable scope to avoid possible conflicts between device 

675 # scope and colocation scope. 

676 with ops.device(None): 

677 with ops.colocate_with(colocate_with): 

678 var = var_creator(**kwargs) 

679 logging.debug( 

680 "Creating variable (name:%s, shape:%r) that colocates with %s", 

681 var.name, var.shape, kwargs["colocate_with"].name) 

682 return var 

683 

684 if self._variable_partitioner is None: 

685 return self._create_variable_round_robin(var_creator, **kwargs) 

686 

687 name = kwargs.get("name", None) 

688 dtype = kwargs.get("dtype", None) 

689 shape = kwargs.get("shape", None) 

690 initial_value = kwargs.get("initial_value", None) 

691 if initial_value is None: 

692 # If we are loading, next_creator will return an UninitializedVariable 

693 v = next_creator(**kwargs) 

694 if not isinstance(v, resource_variable_ops.UninitializedVariable): 

695 raise ValueError( 

696 "It looks like you are using `ParameterServerStrategy` with a " 

697 "`variable_partitioner`, and trying to create a variable without " 

698 "specifying `initial_value`. This is not allowed. Please specify the " 

699 "`initial_value`.") 

700 elif shape is None or dtype is None: 

701 raise ValueError( 

702 "It looks like you are trying to load a `SavedModel` using " 

703 "`tf.saved_model.load` within a `ParameterServerStrategy` scope, " 

704 "but the `SavedModel` is missing shape or dtype information.") 

705 else: 

706 def initializer(shape, dtype, **kwargs): 

707 if "partition_shape" in kwargs: 

708 shape = kwargs["partition_shape"] 

709 return array_ops.zeros(shape, dtype) 

710 initial_value = functools.partial(initializer, shape=shape, dtype=dtype) 

711 

712 # Two cases where initial_value can be a callable: 

713 # 1. initial_value is passed as a callable, e.g., an `initializer` class. 

714 # 2. When restoring from a checkpoint, initial_value is a 

715 # "CheckpointInitialValueCallable". 

716 init_from_fn = callable(initial_value) 

717 

718 if init_from_fn and (shape is None or dtype is None): 

719 init_from_fn = False 

720 initial_value = initial_value() 

721 if not init_from_fn: 

722 # The initial_value is created on the coordinator; it will need to be sent to 

723 # the ps for variable initialization, which can be inefficient and can 

724 # potentially hit the 2GB limit on protobuf serialization. 

725 initial_value = ops.convert_to_tensor(initial_value, dtype=dtype) 

726 dtype = initial_value.dtype 

727 shape = initial_value.shape 

728 else: 

729 shape = tensor_shape.as_shape(shape) 

730 

731 if shape.rank == 0: # Skip partitioning rank-0 variable. 

732 return self._create_variable_round_robin(var_creator, **kwargs) 

733 

734 num_partitions = self._variable_partitioner(shape=shape, dtype=dtype) 

735 if not num_partitions or num_partitions[0] == 0 or any( 

736 v != 1 for v in num_partitions[1:]): 

737 raise ValueError( 

738 "variable_partitioner must return a list/tuple whose elements are 1" 

739 " besides the first element (non-zero), got: %r" % num_partitions) 

740 

741 if num_partitions[0] == 1: # no partition 

742 return self._create_variable_round_robin(var_creator, **kwargs) 

743 

744 # Use "div" partition strategy to partition the variable. 

745 num_partitions = min(num_partitions[0], shape[0]) 

746 base = shape[0] // num_partitions 

747 extra = shape[0] % num_partitions 

748 # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2] 

749 # offsets: [0, 3, 6, 8, 10] 

750 offsets = [] 

751 for i in range(num_partitions): 

752 if i == 0: 

753 offsets.append(0) 

754 else: 

755 prev_shard_size = base + (1 if i - 1 < extra else 0) 

756 offsets.append(offsets[i - 1] + prev_shard_size) 

757 offsets.append(shape[0]) 

758 

759 def init_shard_fn(shard_index): 

760 if not init_from_fn: 

761 logging.log_if( 

762 logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and 

763 shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS) 

764 return initial_value[offsets[shard_index]:offsets[shard_index + 1]] 

765 partition_shape = (offsets[shard_index + 1] - 

766 offsets[shard_index],) + shape[1:] 

767 partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:]) 

768 arg_spec = tf_inspect.getfullargspec(initial_value) 

769 if ("shard_info" not in arg_spec.args and 

770 "shard_info" not in arg_spec.kwonlyargs): 

771 try: 

772 value = initial_value( 

773 partition_shape=partition_shape, 

774 partition_offset=partition_offset) 

775 except (TypeError, ValueError): 

776 # TypeError: Initializer doesn't accept kwargs 

777 # ValueError: Initializer doesn't accept partition kwargs 

778 # In both cases we go ahead and create the full value, then slice it. 

779 value = initial_value() 

780 

781 if value.shape == partition_shape: 

782 # Initializer supports partition: value is the partition value. 

783 return value 

784 else: 

785 # Initializer doesn't support partition: value is the full value 

786 # and needs to be sliced to get the partition value. 

787 logging.log_if( 

788 logging.WARN, _INEFFICIENT_INIT_WARNING % name, 

789 shard_index == 0 and 

790 shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS) 

791 return value[offsets[shard_index]:offsets[shard_index + 1]] 

792 else: 

793 # For compatibility with `CheckpointInitialValueCallable`. 

794 return initial_value( 

795 shard_info=trackable.ShardInfo( 

796 shape=tensor_shape.as_shape(partition_shape), 

797 offset=partition_offset)) 

798 

799 var_list = [] 

800 for i in range(num_partitions): 

801 kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:] 

802 kwargs["initial_value"] = lambda: init_shard_fn(i) 

803 if name is not None: 

804 kwargs["name"] = "{}/part_{}".format(name, i) 

805 var_list.append(self._create_variable_round_robin(var_creator, **kwargs)) 

806 

807 result = sharded_variable.ShardedVariable(var_list) 

808 return result 

809 

810 def _create_variable_round_robin(self, next_creator, **kwargs): 

811 # Clear the colocation scope to avoid possible conflicts between device 

812 # scope and colocation scope. 

813 with ops.colocate_with(None, ignore_existing=True): 

814 # Explicitly set the CPU:0 device for the PS in case variable creation is 

815 # called inside replica_fn and the worker is under a GPU:0 device scope. 

816 with ops.device("/job:ps/task:%d/device:CPU:0" % 

817 (self._variable_count % self._num_ps)): 

818 var = next_creator(**kwargs) 

819 logging.debug( 

820 "Creating variable (name:%s, shape:%r) on " 

821 "/job:ps/task:%d/device:CPU:0", var.name, var.shape, 

822 (self._variable_count % self._num_ps)) 

823 self._variable_count += 1 

824 return var 

825 

826 def _resource_creator_scope(self): 

827 

828 with self._coordinator_creation_lock: 

829 if not self._container_strategy()._cluster_coordinator: # pylint: disable=protected-access 

830 cluster_coordinator.ClusterCoordinator( 

831 strategy=self._container_strategy()) 

832 

833 # TODO(wxinyi): We should warn the user of the inefficiency of creating 

834 # `StaticHashTable` inside a `@tf.function`-wrapped `dataset_fn` to be 

835 # distributed with `distribute_datasets_from_function` and 

836 # `create_per_worker_dataset`. This is because the `dataset_fn` does not 

837 # use the same `default_graph` as `scope` to which the 

838 # `resource_creator_stack` belongs. Thus, `StaticHashTable` creation inside 

839 # `dataset_fn` is not intercepted. And since its resource creation under a 

840 # `tf.function` is lifted out, all workers will share the same resource on 

841 # the coordinator which incurs worker-coordinator communication overhead. 

842 

843 def lookup_creator(next_creator, *args, **kwargs): 

844 if keras_deps.get_load_context_function()(): 

845 return (ps_values.RestoredDistributedTable( 

846 self._container_strategy(), lambda: next_creator(*args, **kwargs))) # pylint: disable=protected-access 

847 else: 

848 return ps_values.DistributedTable(self._container_strategy(), 

849 lambda: next_creator(*args, **kwargs)) # pylint: disable=protected-access 

850 

851 def restored_lookup_creator(next_creator, *args, **kwargs): 

852 return (ps_values.RestoredDistributedTable( 

853 self._container_strategy(), lambda: next_creator(*args, **kwargs))) # pylint: disable=protected-access 

854 

855 return [ 

856 ops.resource_creator_scope("StaticHashTable", lookup_creator), 

857 ops.resource_creator_scope("RestoredStaticHashTable", 

858 restored_lookup_creator) 

859 ] 

860 

861 def _assert_used_with_cluster_coordinator(self): 

862 if (not self._used_with_coordinator and 

863 not self._allow_run_without_coordinator): 

864 raise NotImplementedError( 

865 "`tf.distribute.experimental.ParameterServerStrategy` must be used " 

866 "with `tf.distribute.experimental.coordinator.ClusterCoordinator` in " 

867 "a custom training loop. If you are using `Model.fit`, please supply " 

868 "a dataset function directly to a " 

869 "`tf.keras.utils.experimental.DatasetCreator` instead.") 

870 

871 def _assert_being_scheduled_by_cluster_coordinator(self): 

872 if not self._being_scheduled and not self._allow_run_without_coordinator: 

873 logging.warning( 

874 "A `tf.distribute.experimental.ParameterServerStrategy` method is " 

875 "invoked without using `ClusterCoordinator.schedule`. If you are not " 

876 "tracing a tf.function, this method is possibly executed on the " 

877 "coordinator, which can be slow. To properly dispatch functions to " 

878 "run on workers, methods like `run` or `reduce` should be used " 

879 "within a function passed to `tf.distribute.experimental.coordinator." 

880 "ClusterCoordinator.schedule`.") 

881 

882 # `options` is not used right now, but we may want to support options while 

883 # creating InputWorkers in the future, similar to MirroredStrategy. 

884 def _input_workers_with_options(self, options=None): 

885 input_workers_devices = (("/device:CPU:0", self.worker_devices),) 

886 return input_lib.InputWorkers( 

887 input_workers_devices, canonicalize_devices=False) 

888 

889 def _experimental_distribute_dataset(self, dataset, options): 

890 input_workers_devices = self._input_workers_with_options() 

891 

892 # If this DistributedDataset is created outside ClusterCoordinator, i.e., 

893 # outside a tf.function, we don't build its underlying datasets 

894 # until it is passed to ClusterCoordinator.create_per_worker_dataset. 

895 return input_util.get_distributed_dataset( 

896 dataset, 

897 input_workers_devices, 

898 self._container_strategy(), 

899 num_replicas_in_sync=self._num_replicas_in_sync, 

900 options=options, 

901 build=ops.inside_function()) # will be built by ClusterCoordinator 

902 

903 def _distribute_datasets_from_function(self, dataset_fn, options): 

904 # There is no synchronization across workers, and thus the number of 

905 # input pipelines in sync is only 1 per worker. 

906 input_pipeline_id_in_sync = 0 

907 num_input_pipelines_in_sync = 1 

908 

909 input_context = distribute_lib.InputContext( 

910 num_input_pipelines=num_input_pipelines_in_sync, 

911 input_pipeline_id=input_pipeline_id_in_sync, 

912 num_replicas_in_sync=self._num_replicas_in_sync) 

913 

914 # If this DistributedDatasetFromFunction is created outside 

915 # ClusterCoordinator, i.e., outside a tf.function, we don't build its 

916 # underlying datasets until it is passed to 

917 # ClusterCoordinator.create_per_worker_dataset. 

918 return input_util.get_distributed_datasets_from_function( 

919 dataset_fn, 

920 self._input_workers_with_options(options), [input_context], 

921 self._container_strategy(), 

922 options=options, 

923 build=ops.inside_function()) # will be built by ClusterCoordinator 

924 

925 @property 

926 def worker_devices(self): 

927 num_gpus = self._num_gpus_per_worker 

928 if num_gpus > 0: 

929 compute_devices = tuple("/device:GPU:%d" % (i,) for i in range(num_gpus)) 

930 else: 

931 compute_devices = ("/device:CPU:0",) 

932 return compute_devices 

933 

934 def _call_for_each_replica(self, fn, args, kwargs): 

935 self._assert_being_scheduled_by_cluster_coordinator() 

936 

937 return mirrored_run.call_for_each_replica(self._container_strategy(), fn, 

938 args, kwargs) 

939 

940 def _reduce(self, reduce_op, value): 

941 self._assert_being_scheduled_by_cluster_coordinator() 

942 dst = device_util.current() or self._default_device or "/device:CPU:0" 

943 destinations = device_util.canonicalize_without_job_and_task(dst) 

944 result = self._local_results( 

945 self.reduce_to(reduce_op, value, destinations))[0] 

946 return result 

947 

948 def _reduce_to(self, reduce_op, value, destinations, options): 

949 self._assert_being_scheduled_by_cluster_coordinator() 

950 

951 def get_values(x): 

952 if isinstance(x, values.DistributedValues): 

953 return self._cross_device_ops.reduce( 

954 reduce_op, x, destinations=destinations) # pylint: disable=protected-access 

955 return x 

956 

957 return nest.map_structure(get_values, value) 

958 

959 

960# The warning that will be logged if the way we initialize sharded variables 

961# is memory-inefficient. 

962_INEFFICIENT_INIT_WARNING = ( 

963 "Large variable %s is partitioned but not initialized in a " 

964 "memory-efficient way. On each shard, the full value is first being " 

965 "created and then sliced into smaller values. To reduce the memory " 

966 "footprint, explicitly specify `dtype` and `shape` when creating " 

967 "variables, and use `tf.initializers` to initialize the variable. " 

968 "Note that some initializers (e.g., orthogonal) don't support " 

969 "memory-efficient initialization and there is not much you can do here.") 

970 

971_LARGE_VARIABLE_NUM_ELEMENTS = 1e9