Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/parameter_server_strategy_v2.py: 23%
271 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Parameter server strategy V2 class.
17This is currently under development and the API is subject to change.
18"""
20import functools
21import os
22import threading
24from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
25from tensorflow.python.distribute import device_util
26from tensorflow.python.distribute import distribute_lib
27from tensorflow.python.distribute import input_lib
28from tensorflow.python.distribute import input_util
29from tensorflow.python.distribute import mirrored_run
30from tensorflow.python.distribute import multi_worker_util
31from tensorflow.python.distribute import parameter_server_strategy
32from tensorflow.python.distribute import ps_values
33from tensorflow.python.distribute import sharded_variable
34from tensorflow.python.distribute import values
35from tensorflow.python.distribute.coordinator import cluster_coordinator
36from tensorflow.python.eager import context
37from tensorflow.python.eager import remote
38from tensorflow.python.framework import config
39from tensorflow.python.framework import device as tf_device
40from tensorflow.python.framework import ops
41from tensorflow.python.framework import tensor_shape
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import resource_variable_ops
44from tensorflow.python.ops import variable_scope as vs
45from tensorflow.python.platform import tf_logging as logging
46from tensorflow.python.trackable import base as trackable
47from tensorflow.python.training import server_lib
48from tensorflow.python.util import keras_deps
49from tensorflow.python.util import nest
50from tensorflow.python.util import tf_inspect
51from tensorflow.python.util.tf_export import tf_export
52from tensorflow.tsl.protobuf import coordination_config_pb2
55ALLOWED_TASK_TYPES = ("chief", "worker", "ps")
56# This sets the coordination service's internal heartbeat timeout. In testing, a
57# value of 1 led to some spurious reports of unavailability, so a higher value
58# is used. Refer to the discussion in b/249134783 for more.
59_HEARTBEAT_TIMEOUT_SECS = 5
62@tf_export(
63 "distribute.experimental.ParameterServerStrategy",
64 "distribute.ParameterServerStrategy",
65 v1=[])
66class ParameterServerStrategyV2(distribute_lib.Strategy):
67 """An multi-worker tf.distribute strategy with parameter servers.
69 Parameter server training is a common data-parallel method to scale up a
70 machine learning model on multiple machines. A parameter server training
71 cluster consists of workers and parameter servers. Variables are created on
72 parameter servers and they are read and updated by workers in each step.
73 By default, workers read and update these variables independently without
74 synchronizing with each other. This configuration is known as
75 asynchronous training.
77 In TensorFlow 2, we recommend an architecture based on central coordination
78 for parameter server training. Each worker and parameter server runs a
79 `tf.distribute.Server`, and on top of that, a coordinator task is responsible
80 for creating resources on workers and parameter servers, dispatching
81 functions, and coordinating the training. The coordinator uses a
82 `tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the
83 cluster, and a `tf.distribute.experimental.ParameterServerStrategy` to define
84 variables on parameter servers and computation on workers.
86 For the training to work, the coordinator dispatches `tf.function`s to be
87 executed on remote workers. Upon receiving requests from the coordinator, a
88 worker executes the `tf.function` by reading the variables from parameter
89 servers, executing the ops, and updating the variables on the parameter
90 servers. Each worker only processes requests from the coordinator,
91 and communicates with parameter servers, without direct interactions with
92 other workers in the cluster.
94 As a result, failures of some workers do not prevent the cluster from
95 continuing the work, and this allows the cluster to train with instances that
96 can be occasionally unavailable (e.g. preemptible or spot instances). The
97 coordinator and parameter servers, however, must be available at all times for
98 the cluster to make progress.
100 Note that the coordinator is not one of the training workers. Instead, it
101 creates resources such as variables and datasets, dispatches `tf.function`s,
102 saves checkpoints and so on. In addition to workers, parameter servers and
103 the coordinator, an optional evaluator can be run on the side that
104 periodically reads the checkpoints saved by the coordinator and runs
105 evaluations against each checkpoint.
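A minimal sketch of such a sidecar evaluator (not part of this API; `checkpoint_dir`, `model`, `eval_metric`, and `eval_dataset` are placeholders) could look like:

```python
checkpoint = tf.train.Checkpoint(model=model)
for latest_ckpt in tf.train.checkpoints_iterator(checkpoint_dir):
  # Restore the latest checkpoint saved by the coordinator; unmatched objects
  # (e.g. the optimizer) are deliberately ignored.
  checkpoint.restore(latest_ckpt).expect_partial()
  for batch_data, labels in eval_dataset:
    eval_metric.update_state(labels, model(batch_data, training=False))
  logging.info('Eval result: %r', eval_metric.result())
  eval_metric.reset_states()
```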
107 `ParameterServerStrategy` is supported with two training APIs: [Custom
108 Training Loop (CTL)]
109 (https://www.tensorflow.org/tutorials/distribute/custom_training)
110 and [Keras Training API, also known as `Model.fit`]
111 (https://www.tensorflow.org/tutorials/distribute/keras). CTL is recommended
112 when users prefer to define the details of their training loop, and
113 `Model.fit` is recommended when users prefer a high-level abstraction and
114 handling of training.
116 When using a CTL, `ParameterServerStrategy` has to work in conjunction with a
117 `tf.distribute.experimental.coordinator.ClusterCoordinator` object.
119 When using `Model.fit`, currently only the
120 `tf.keras.utils.experimental.DatasetCreator` input type is supported.
122 __Example code for coordinator__
124 This section provides code snippets that are intended to be run on the single
125 task that is designated as the coordinator. Note that the `cluster_resolver`,
126 `variable_partitioner`, and `dataset_fn` arguments are explained in the
127 following "Cluster setup", "Variable partitioning", and "Dataset preparation"
128 sections.
130 With a CTL,
132 ```python
133 # Prepare a strategy to use with the cluster and variable partitioning info.
134 strategy = tf.distribute.experimental.ParameterServerStrategy(
135 cluster_resolver=...,
136 variable_partitioner=...)
137 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
138 strategy=strategy)
141 # Prepare a distributed dataset that will place datasets on the workers.
141 distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)
143 with strategy.scope():
144 model = ...
145 optimizer, metrics = ... # Keras optimizer/metrics are great choices
146 checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
147 checkpoint_manager = tf.train.CheckpointManager(
148 checkpoint, checkpoint_dir, max_to_keep=2)
149 # `load_checkpoint` infers initial epoch from `optimizer.iterations`.
150 initial_epoch = load_checkpoint(checkpoint_manager) or 0
152 @tf.function
153 def worker_fn(iterator):
155 def replica_fn(inputs):
156 batch_data, labels = inputs
157 # Calculate gradients, apply gradients, update metrics, etc.
159 strategy.run(replica_fn, args=(next(iterator),))
161 for epoch in range(initial_epoch, num_epoch):
162 distributed_iterator = iter(distributed_dataset) # Reset iterator state.
163 for step in range(steps_per_epoch):
165 # Asynchronously schedule the `worker_fn` to be executed on an arbitrary
166 # worker. This call returns immediately.
167 coordinator.schedule(worker_fn, args=(distributed_iterator,))
169 # `join` blocks until all scheduled `worker_fn`s finish execution. Once it
170 # returns, we can read the metrics and save checkpoints as needed.
171 coordinator.join()
172 logging.info('Metric result: %r', metrics.result())
173 metrics.reset_states()
174 checkpoint_manager.save()
175 ```
177 With `Model.fit`,
179 ```python
180 # Prepare a strategy to use with the cluster and variable partitioning info.
181 strategy = tf.distribute.experimental.ParameterServerStrategy(
182 cluster_resolver=...,
183 variable_partitioner=...)
185 # A dataset function takes an `input_context` and returns a `Dataset`.
186 def dataset_fn(input_context):
187 dataset = tf.data.Dataset.from_tensors(...)
188 return dataset.repeat().shard(...).batch(...).prefetch(...)
190 # With `Model.fit`, a `DatasetCreator` needs to be used.
191 dc = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
193 with strategy.scope():
194 model = ... # Make sure the `Model` is created within scope.
195 model.compile(optimizer="rmsprop", loss="mse", steps_per_execution=..., ...)
197 # Optional callbacks to checkpoint the model, back up the progress, etc.
198 callbacks = [tf.keras.callbacks.ModelCheckpoint(...), ...]
200 # `steps_per_epoch` is required with `ParameterServerStrategy`.
201 model.fit(dc, epochs=..., steps_per_epoch=..., callbacks=callbacks)
202 ```
204 __Example code for worker and parameter servers__
206 In addition to the coordinator, there should be tasks designated as
207 "worker" or "ps". They should run the following code to start a TensorFlow
208 server and wait for the coordinator's requests:
210 ```python
211 # Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
212 # the cluster information. See below "Cluster setup" section.
213 cluster_resolver = ...
215 server = tf.distribute.Server(
216 cluster_resolver.cluster_spec(),
217 job_name=cluster_resolver.task_type,
218 task_index=cluster_resolver.task_id,
219 protocol="grpc")
221 # Block the process that started the server from exiting.
222 server.join()
223 ```
225 __Cluster setup__
227 In order for the tasks in the cluster to know other tasks' addresses,
228 a `tf.distribute.cluster_resolver.ClusterResolver` must be used by the
229 coordinator, workers, and parameter servers. The
230 `tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing
231 the cluster information, as well as the task type and id of the current task.
232 See `tf.distribute.cluster_resolver.ClusterResolver` for more information.
234 If the `TF_CONFIG` environment variable is set, a
235 `tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used as
236 well.
238 Since there are assumptions in
239 `tf.distribute.experimental.ParameterServerStrategy` around the naming of the
240 task types, "chief", "ps", and "worker" should be used in the
241 `tf.distribute.cluster_resolver.ClusterResolver` to refer to the coordinator,
242 parameter servers, and workers, respectively.
244 The following example demonstrates setting `TF_CONFIG` for the task designated
245 as a parameter server (task type "ps") and index 1 (the second task), in a
246 cluster with 1 chief, 2 parameter servers, and 3 workers. Note that it needs
247 to be set before the use of
248 `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
250 Example code for cluster setup:
251 ```python
252 os.environ['TF_CONFIG'] = '''
253 {
254 "cluster": {
255 "chief": ["chief.example.com:2222"],
256 "ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
257 "worker": ["worker0.example.com:2222", "worker1.example.com:2222",
258 "worker2.example.com:2222"]
259 },
260 "task": {
261 "type": "ps",
262 "index": 1
263 }
264 }
265 '''
266 ```
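With `TF_CONFIG` set as above, a `tf.distribute.cluster_resolver.TFConfigClusterResolver` picks up the cluster and task information automatically. A minimal sketch for this "ps" task, mirroring the worker/ps snippet above, could be:

```python
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
# For the `TF_CONFIG` above, task_type == "ps" and task_id == 1.
server = tf.distribute.Server(
    cluster_resolver.cluster_spec(),
    job_name=cluster_resolver.task_type,
    task_index=cluster_resolver.task_id,
    protocol="grpc")
server.join()
```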
268 If you prefer to run the same binary for all tasks, you will need to let the
269 binary branch into different roles at the beginning of the program:
270 ```python
271 # If coordinator, create a strategy and start the training program.
272 if cluster_resolver.task_type == 'chief':
273 strategy = tf.distribute.experimental.ParameterServerStrategy(
274 cluster_resolver)
275 ...
277 # If worker/ps, create a server
278 elif cluster_resolver.task_type in ("worker", "ps"):
279 server = tf.distribute.Server(...)
280 ...
281 ```
282 Alternatively, you can also start several TensorFlow servers in advance and
283 connect to them later. The coordinator can be in the same cluster or on any
284 machine that has connectivity to workers and parameter servers. This is
285 covered in our guide and tutorial.
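One possible sketch of that setup (the addresses below are placeholders): the coordinator describes the already-running servers with a `tf.train.ClusterSpec` wrapped in a `tf.distribute.cluster_resolver.SimpleClusterResolver`:

```python
cluster_spec = tf.train.ClusterSpec({
    "ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
})
cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
    cluster_spec, rpc_layer="grpc")
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
```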
287 __Variable creation with `strategy.scope()`__
289 `tf.distribute.experimental.ParameterServerStrategy` follows the
290 `tf.distribute` API contract where variable creation is expected to be inside
291 the context manager returned by `strategy.scope()`, in order to be correctly
292 placed on parameter servers in a round-robin manner:
294 ```python
295 # In this example, we assume the cluster has 3 parameter servers.
296 strategy = tf.distribute.experimental.ParameterServerStrategy(
297 cluster_resolver=...)
298 coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
299 strategy=strategy)
301 # Variables should be created inside scope to be placed on parameter servers.
302 # If created outside the scope, such as `v1` here, it will be placed on the
303 # coordinator.
304 v1 = tf.Variable(initial_value=0.0)
306 with strategy.scope():
307 v2 = tf.Variable(initial_value=1.0)
308 v3 = tf.Variable(initial_value=2.0)
309 v4 = tf.Variable(initial_value=3.0)
310 v5 = tf.Variable(initial_value=4.0)
312 # v2 through v5 are created in scope and are distributed on parameter servers.
313 # Default placement is round-robin but the order should not be relied on.
314 assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0"
315 assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0"
316 assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0"
317 assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0"
318 ```
320 See `distribute.Strategy.scope` for more information.
322 __Variable partitioning__
324 Having dedicated servers to store variables means being able to divide up, or
325 "shard" the variables across the ps. Partitioning large variable among ps is a
326 commonly used technique to boost training throughput and mitigate memory
327 constraints. It enables parallel computations and updates on different shards
328 of a variable, and often yields better load balancing across parameter
329 servers. Without sharding, models with large variables (e.g., embeddings) that
330 can't fit into one machine's memory would otherwise be unable to train.
332 With `tf.distribute.experimental.ParameterServerStrategy`, if a
333 `variable_partitioner` is provided to `__init__` and certain conditions are
334 satisfied, the resulting variables created in scope are sharded across the
335 parameter servers, in a round-robin fashion. The variable reference returned
336 from `tf.Variable` becomes a type that serves as the container of the sharded
337 variables. One can access the `variables` attribute of this container for the
338 actual variable components. If building a model with `tf.Module` or Keras,
339 the variable components are collected in the `variables`-like attributes.
341 It is recommended to use size-based partitioners like
342 `tf.distribute.experimental.partitioners.MinSizePartitioner` to avoid
343 partitioning small variables, which could have a negative impact on model
344 training speed.
346 ```python
347 # Partition the embedding layer into 2 shards.
348 variable_partitioner = (
349 tf.distribute.experimental.partitioners.MinSizePartitioner(
350 min_shard_bytes=(256 << 10),
351 max_shards = 2))
352 strategy = tf.distribute.experimental.ParameterServerStrategy(
353 cluster_resolver=...,
354 variable_partitioner = variable_partitioner)
355 with strategy.scope():
356 embedding = tf.keras.layers.Embedding(input_dim=1024, output_dim=1024)
357 assert len(embedding.variables) == 2
358 assert isinstance(embedding.variables[0], tf.Variable)
359 assert isinstance(embedding.variables[1], tf.Variable)
360 assert embedding.variables[0].shape == (512, 1024)
361 assert embedding.variables[1].shape == (512, 1024)
362 ```
364 The sharded variable container can be converted to a `Tensor` via
365 `tf.convert_to_tensor`. This means the container can be directly used in most
366 Python Ops where such `Tensor` conversion automatically happens. For example,
367 in the above code snippet, `x * self.w` would implicitly apply the said tensor
368 conversion. Note that such conversion can be expensive, as the variable
369 components need to be transferred from multiple parameter servers to where
370 the value is used.
372 `tf.nn.embedding_lookup` on the other hand doesn't apply the tensor
373 conversion, and performs parallel lookups on the variable components instead.
374 This is crucial to scale up embedding lookups when the embedding table
375 variable is large.
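For illustration, a possible sketch using a variable created under the strategy scope from the partitioning example above (`ids` is a placeholder tensor of row indices):

```python
with strategy.scope():
  w = tf.Variable(tf.zeros((1024, 1024)))  # May be a sharded container.

# Dense usage: the conversion gathers all components from the parameter
# servers into one tensor, which can be expensive for large variables.
full_value = tf.convert_to_tensor(w)

# Sparse usage: the lookup is dispatched to the variable components in
# parallel, without materializing the full table.
rows = tf.nn.embedding_lookup(w, ids)
```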
377 When a partitioned variable is saved to a `SavedModel`, it will be saved as if
378 it were one single variable. This improves serving efficiency by eliminating
379 a number of Ops that handle the partition aspects.
381 Known limitations of variable partitioning:
383 * The number of partitions must not change across checkpoint saving/loading.
385 * After saving partitioned variables to a SavedModel, the SavedModel can't be
386 loaded via `tf.saved_model.load`.
388 * Partitioned variables don't work directly with `tf.GradientTape`; please use
389 the `variables` attribute to get the actual variable components and use
390 them in gradient APIs instead, as sketched below.
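A minimal sketch of this workaround, assuming the variable ends up sharded (`ids` is a placeholder tensor of row indices):

```python
with strategy.scope():
  table = tf.Variable(tf.zeros((1024, 1024)))  # Assumed sharded here.

with tf.GradientTape() as tape:
  embeddings = tf.nn.embedding_lookup(table, ids)
  loss = tf.reduce_sum(embeddings)
# Differentiate with respect to the actual components, not the container.
grads = tape.gradient(loss, table.variables)
```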
392 __Dataset preparation__
394 With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is
395 created in each of the workers to be used for training. This is done by
396 creating a `dataset_fn` that takes no argument and returns a
397 `tf.data.Dataset`, and passing the `dataset_fn` into
398 `tf.distribute.experimental.coordinator.
399 ClusterCoordinator.create_per_worker_dataset`. We recommend that the dataset be
400 shuffled and repeated so that the examples run through the training as evenly
401 as possible.
403 ```python
404 def dataset_fn():
405 filenames = ...
406 dataset = tf.data.Dataset.from_tensor_slices(filenames)
408 # The dataset is recommended to be shuffled and repeated.
409 return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...)
411 coordinator = (
412 tf.distribute.experimental.coordinator.ClusterCoordinator(strategy=...))
413 distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
414 ```
416 __Limitations__
418 * `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental,
419 and the API is subject to further changes.
421 * When using `Model.fit`, `tf.distribute.experimental.ParameterServerStrategy`
422 must be used with a `tf.keras.utils.experimental.DatasetCreator`, and
423 `steps_per_epoch` must be specified.
424 """
426 # pyformat: disable
427 def __init__(self, cluster_resolver, variable_partitioner=None):
428 """Initializes the TF2 parameter server strategy.
430 This initializes the `tf.distribute.experimental.ParameterServerStrategy`
431 object to be ready for use with
432 `tf.distribute.experimental.coordinator.ClusterCoordinator`.
434 Args:
435 cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
436 object.
437 variable_partitioner:
438 a `distribute.experimental.partitioners.Partitioner` that specifies
439 how to partition variables. If `None`, variables will not be
440 partitioned.
442 * Predefined partitioners in `tf.distribute.experimental.partitioners`
443 can be used for this argument. A commonly used partitioner is
444 `MinSizePartitioner(min_shard_bytes = 256 << 10, max_shards = num_ps)`,
445 which allocates at least 256K per shard, and each ps gets at most one
446 shard.
448 * `variable_partitioner` will be called for each variable created under
449 strategy `scope` to instruct how the variable should be partitioned.
450 Variables that have only one partition along the partitioning axis
451 (i.e., no need for partitioning) will be created as a normal `tf.Variable`.
453 * Only the first / outermost axis partitioning is supported.
455 * Div partition strategy is used to partition variables. Assuming we
456 assign consecutive integer ids along the first axis of a variable, then
457 ids are assigned to shards in a contiguous manner, while attempting to
458 keep each shard size identical. If the ids do not evenly divide the
459 number of shards, each of the first several shards will be assigned one
460 more id. For instance, a variable whose first dimension is 13 has 13
461 ids, and they are split across 5 shards as:
462 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]` (see the sketch after this list).
464 * Variables created under `strategy.extended.colocate_vars_with` will
465 not be partitioned.
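An illustrative, pure-Python sketch of the div split described above, for a variable whose first dimension is 13 split across 5 shards:

```python
num_ids, num_shards = 13, 5
base, extra = divmod(num_ids, num_shards)
# The first `extra` shards each get one additional id.
sizes = [base + (1 if i < extra else 0) for i in range(num_shards)]
shards, start = [], 0
for size in sizes:
  shards.append(list(range(start, start + size)))
  start += size
assert shards == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
```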
466 """
467 # pyformat: enable
468 self._cluster_resolver = cluster_resolver
470 self._verify_args_and_config(cluster_resolver)
471 self._cluster_coordinator = None
472 logging.info(
473 "`tf.distribute.experimental.ParameterServerStrategy` is initialized "
474 "with cluster_spec: %s", cluster_resolver.cluster_spec())
476 if os.getenv("TF_PSS_ENABLE_COORDINATION_SERVICE"):
477 self._configure_coordination_service(cluster_resolver.cluster_spec())
478 # TODO(b/167894802): Make coordinator, worker, and ps names customizable.
479 self._connect_to_cluster(coordinator_name="chief")
480 self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
481 variable_partitioner)
482 super(ParameterServerStrategyV2, self).__init__(self._extended)
483 distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
484 "ParameterServerStrategy")
485 self._should_use_with_coordinator = True
486 # Used while constructing distributed iterators.
487 self._canonicalize_devices = False
488 # Used for isinstance checks without having to import this module.
489 self._is_parameter_server_strategy_v2 = True
491 def _configure_coordination_service(self, cluster_spec):
492 if context.context().coordination_service is None:
493 coordinated_jobs = ["worker", "ps"]
494 coordinated_job_config = []
495 for job in coordinated_jobs:
496 if job in cluster_spec.jobs:
497 coordinated_job_config.append(
498 coordination_config_pb2.CoordinatedJob(
499 name=job,
500 num_tasks=cluster_spec.num_tasks(job)))
501 context.context().configure_coordination_service(
502 service_type="standalone",
503 service_leader=multi_worker_util.coordination_leader(
504 cluster_spec),
505 heartbeat_timeout_in_ms=_HEARTBEAT_TIMEOUT_SECS * 1000,
506 allow_new_incarnation_to_reconnect=True)
508 def _connect_to_cluster(self, coordinator_name):
509 if coordinator_name in ["worker", "ps"]:
510 raise ValueError("coordinator name should not be 'worker' or 'ps'.")
511 cluster_spec = self._cluster_resolver.cluster_spec()
512 self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
513 self._num_ps = len(cluster_spec.as_dict().get("ps", ()))
515 device_filters = server_lib.ClusterDeviceFilters()
516 # For any worker, only the devices on ps and coordinator nodes are visible
517 for i in range(self._num_workers):
518 device_filters.set_device_filters(
519 "worker", i, ["/job:ps", "/job:%s" % coordinator_name])
520 # Similarly for any ps, only the devices on workers and coordinator are
521 # visible
522 for i in range(self._num_ps):
523 device_filters.set_device_filters(
524 "ps", i, ["/job:worker", "/job:%s" % coordinator_name])
526 # Allow at most one outstanding RPC for each worker at a certain time. This
527 # is to simplify worker failure handling in the runtime
528 os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False"
530 # Disable async executors to make context.async_wait a no-op. This avoids
531 # sending RPCs to remote workers since the executors used by PSStrategy
532 # are known to be always synchronous.
533 os.environ["TF_PS_DISABLE_ASYNC_EXECUTOR_GLOBALLY"] = "True"
535 logging.info("%s is now connecting to cluster with cluster_spec: %r",
536 self.__class__.__name__, cluster_spec)
537 remote.connect_to_cluster(
538 cluster_spec,
539 job_name=coordinator_name,
540 protocol=self._cluster_resolver.rpc_layer,
541 cluster_device_filters=device_filters)
543 distribute_lib.distribution_strategy_replica_gauge.get_cell(
544 "ps_strategy_num_workers").set(self._num_workers)
545 distribute_lib.distribution_strategy_replica_gauge.get_cell(
546 "ps_strategy_num_ps").set(self._num_ps)
548 def _verify_args_and_config(self, cluster_resolver):
549 if not cluster_resolver.cluster_spec():
550 raise ValueError("Cluster spec must be non-empty in "
551 "`tf.distribute.cluster_resolver.ClusterResolver`.")
552 cluster_spec = cluster_resolver.cluster_spec()
554 # The following checks if the task types are allowed (chief, ps, worker).
555 multi_worker_util._validate_cluster_spec( # pylint: disable=protected-access
556 cluster_spec, cluster_resolver.task_type, cluster_resolver.task_id)
558 if multi_worker_util.task_count(cluster_spec, "ps") < 1:
559 raise ValueError("There must be at least one ps.")
561 if multi_worker_util.task_count(cluster_spec, "worker") < 1:
562 raise ValueError("There must be at least one worker.")
565class ParameterServerStrategyV2Extended(
566 parameter_server_strategy.ParameterServerStrategyExtended):
567 """Extended class for ParameterServerStrategyV2.
569 Please see `tf.distribute.StrategyExtended` doc for more information.
570 """
572 def __init__(self, container_strategy, cluster_resolver,
573 variable_partitioner):
574 """Initialization of ParameterServerStrategyV2Extended."""
575 super(ParameterServerStrategyV2Extended, self).__init__(container_strategy)
576 self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", []))
577 self._num_workers = len(cluster_resolver.cluster_spec().as_dict().get(
578 "worker", []))
579 self._variable_count = 0
581 self._variable_partitioner = variable_partitioner
582 # The following two attrs are to verify that `ParameterServerStrategy`
583 # methods are properly used with a `ClusterCoordinator`.
584 self._used_with_coordinator = False
585 self._being_scheduled = False
586 self._set_num_gpus()
587 distribute_lib.distribution_strategy_replica_gauge.get_cell(
588 "num_gpus_per_worker").set(self._num_gpus_per_worker)
590 # Don't canonicalize the devices here since this code is executed on Chief,
591 # but we want the reduce evaluation to be done on each worker. Placer will
592 # automatically choose the right device based on current context.
593 # TODO(ishark): Use select_cross_device_ops instead.
594 self._cross_device_ops = cross_device_ops_lib.ReductionToOneDevice(
595 reduce_to_device="/device:CPU:0")
596 self._cross_device_ops._canonicalize_devices = False # pylint: disable=protected-access
597 self._allow_run_without_coordinator = False
598 self._coordinator_creation_lock = threading.Lock()
600 def _set_num_gpus(self):
601 devices = config.list_logical_devices("GPU")
602 per_worker_gpus = {}
603 for d in devices:
604 d_spec = tf_device.DeviceSpec.from_string(d.name)
605 if d_spec.device_type == "GPU" and d_spec.job == "worker":
606 # TODO(b/167894802): update if worker name is customizable
607 job_spec = d_spec.replace(device_type=None, device_index=None)
608 per_worker_gpus[job_spec] = per_worker_gpus.get(job_spec, 0) + 1
610 num_gpus = 0
611 for _, count in per_worker_gpus.items():
612 if num_gpus > 0 and count != num_gpus:
613 raise ValueError("Mismatched number of GPUs per worker")
614 num_gpus = count
616 self._num_gpus_per_worker = num_gpus
617 logging.info(f"Number of GPUs on workers: {self._num_gpus_per_worker}")
619 @property
620 def _num_replicas_in_sync(self):
621 return self._num_gpus_per_worker or 1
623 def _create_var_creator(self, next_creator, **kwargs):
624 aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
626 def var_creator(**kwargs):
627 """Create an AggregatingVariable."""
628 # Create and wrap the variable.
629 v = next_creator(**kwargs)
630 wrapped_v = ps_values.CachingVariable(v)
631 wrapped = ps_values.AggregatingVariable(self._container_strategy(),
632 wrapped_v, aggregation)
633 return wrapped
635 if self._num_replicas_in_sync > 1:
636 if aggregation not in (vs.VariableAggregation.NONE,
637 vs.VariableAggregation.SUM,
638 vs.VariableAggregation.MEAN,
639 vs.VariableAggregation.ONLY_FIRST_REPLICA):
640 raise ValueError("Invalid variable aggregation mode: " + aggregation +
641 " for variable: " + kwargs["name"])
642 return var_creator
643 else:
645 def variable_creator_single_replica(**kwargs):
646 v = next_creator(**kwargs)
647 return ps_values.CachingVariable(v)
649 return variable_creator_single_replica
651 def _create_variable(self, next_creator, **kwargs):
652 """Implements StrategyExtendedV2._create_variable.
654 Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` will be
655 created if all of the following criteria are satisfied:
656 1. `self._variable_partitioner` results in more than one partition on the
657 first axis.
658 2. The variable's rank is greater than 0.
659 3. The variable is not colocated with another variable.
660 Otherwise a `Variable` will be created.
662 Args:
663 next_creator: See `variable_scope.variable_creator_scope`; the next
664 creator in the chain.
665 **kwargs: Passed through to the next creator.
667 Returns:
668 A `Variable` or `ShardedVariable`.
669 """
671 var_creator = self._create_var_creator(next_creator, **kwargs)
672 if "colocate_with" in kwargs: # Never partition colocated_with variables.
673 colocate_with = kwargs["colocate_with"]
674 # Clear the variable scope to avoid possible conflicts between device
675 # scope and colocation scope.
676 with ops.device(None):
677 with ops.colocate_with(colocate_with):
678 var = var_creator(**kwargs)
679 logging.debug(
680 "Creating variable (name:%s, shape:%r) that colocates with %s",
681 var.name, var.shape, kwargs["colocate_with"].name)
682 return var
684 if self._variable_partitioner is None:
685 return self._create_variable_round_robin(var_creator, **kwargs)
687 name = kwargs.get("name", None)
688 dtype = kwargs.get("dtype", None)
689 shape = kwargs.get("shape", None)
690 initial_value = kwargs.get("initial_value", None)
691 if initial_value is None:
692 # If we are loading, next_creator will return an UninitializedVariable
693 v = next_creator(**kwargs)
694 if not isinstance(v, resource_variable_ops.UninitializedVariable):
695 raise ValueError(
696 "It looks like you are using `ParameterServerStrategy` with a "
697 "`variable_partitioner`, and trying to create a variable without "
698 "specifying `initial_value`. This is not allowed. Please specify the "
699 "`initial_value`.")
700 elif shape is None or dtype is None:
701 raise ValueError(
702 "It looks like you are trying to load a `SavedModel` using "
703 "`tf.saved_model.load` within a `ParameterServerStrategy` scope, "
704 "but the `SavedModel` is missing shape or dtype information.")
705 else:
706 def initializer(shape, dtype, **kwargs):
707 if "partition_shape" in kwargs:
708 shape = kwargs["partition_shape"]
709 return array_ops.zeros(shape, dtype)
710 initial_value = functools.partial(initializer, shape=shape, dtype=dtype)
712 # Two cases where initial_value can be a callable:
713 # 1. initial_value is passed as a callable, e.g., an `initializer` class.
714 # 2. restoring from checkpoint, initial_value is a
715 # "CheckpointInitialValueCallable".
716 init_from_fn = callable(initial_value)
718 if init_from_fn and (shape is None or dtype is None):
719 init_from_fn = False
720 initial_value = initial_value()
721 if not init_from_fn:
722 # The initial_value is created on the coordinator; it will need to be sent to
723 # ps for variable initialization, which can be inefficient and can
724 # potentially hit the 2GB limit on protobuf serialization.
725 initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
726 dtype = initial_value.dtype
727 shape = initial_value.shape
728 else:
729 shape = tensor_shape.as_shape(shape)
731 if shape.rank == 0: # Skip partitioning rank-0 variable.
732 return self._create_variable_round_robin(var_creator, **kwargs)
734 num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
735 if not num_partitions or num_partitions[0] == 0 or any(
736 v != 1 for v in num_partitions[1:]):
737 raise ValueError(
738 "variable_partitioner must return a list/tuple whose elements are 1"
739 " besides the first element (non-zero), got: %r" % num_partitions)
741 if num_partitions[0] == 1: # no partition
742 return self._create_variable_round_robin(var_creator, **kwargs)
744 # Use "div" partition strategy to partition the variable.
745 num_partitions = min(num_partitions[0], shape[0])
746 base = shape[0] // num_partitions
747 extra = shape[0] % num_partitions
748 # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
749 # offsets: [0, 3, 6, 8, 10]
750 offsets = []
751 for i in range(num_partitions):
752 if i == 0:
753 offsets.append(0)
754 else:
755 prev_shard_size = base + (1 if i - 1 < extra else 0)
756 offsets.append(offsets[i - 1] + prev_shard_size)
757 offsets.append(shape[0])
759 def init_shard_fn(shard_index):
760 if not init_from_fn:
761 logging.log_if(
762 logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and
763 shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
764 return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
765 partition_shape = (offsets[shard_index + 1] -
766 offsets[shard_index],) + shape[1:]
767 partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
768 arg_spec = tf_inspect.getfullargspec(initial_value)
769 if ("shard_info" not in arg_spec.args and
770 "shard_info" not in arg_spec.kwonlyargs):
771 try:
772 value = initial_value(
773 partition_shape=partition_shape,
774 partition_offset=partition_offset)
775 except (TypeError, ValueError):
776 # TypeError: Initializer doesn't accept kwargs
777 # ValueError: Initializer doesn't accept partition kwargs
778 # In both cases we go ahead and create the full value and then slice it.
779 value = initial_value()
781 if value.shape == partition_shape:
782 # Initializer supports partition: value is the partition value.
783 return value
784 else:
785 # Initializer doesn't support partition: value is the full value
786 # and needs to be sliced to get the partition value.
787 logging.log_if(
788 logging.WARN, _INEFFICIENT_INIT_WARNING % name,
789 shard_index == 0 and
790 shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
791 return value[offsets[shard_index]:offsets[shard_index + 1]]
792 else:
793 # For compatibility with `CheckpointInitialValueCallable`.
794 return initial_value(
795 shard_info=trackable.ShardInfo(
796 shape=tensor_shape.as_shape(partition_shape),
797 offset=partition_offset))
799 var_list = []
800 for i in range(num_partitions):
801 kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
802 kwargs["initial_value"] = lambda: init_shard_fn(i)
803 if name is not None:
804 kwargs["name"] = "{}/part_{}".format(name, i)
805 var_list.append(self._create_variable_round_robin(var_creator, **kwargs))
807 result = sharded_variable.ShardedVariable(var_list)
808 return result
810 def _create_variable_round_robin(self, next_creator, **kwargs):
811 # Clear the colocation scope to avoid possible conflicts between device
812 # scope and colocation scope.
813 with ops.colocate_with(None, ignore_existing=True):
814 # Explicitly set CPU:0 device for PS in case variable creation is called
815 # inside replica_fn and the worker is under a GPU:0 device scope.
816 with ops.device("/job:ps/task:%d/device:CPU:0" %
817 (self._variable_count % self._num_ps)):
818 var = next_creator(**kwargs)
819 logging.debug(
820 "Creating variable (name:%s, shape:%r) on "
821 "/job:ps/task:%d/device:CPU:0", var.name, var.shape,
822 (self._variable_count % self._num_ps))
823 self._variable_count += 1
824 return var
826 def _resource_creator_scope(self):
828 with self._coordinator_creation_lock:
829 if not self._container_strategy()._cluster_coordinator: # pylint: disable=protected-access
830 cluster_coordinator.ClusterCoordinator(
831 strategy=self._container_strategy())
833 # TODO(wxinyi): We should warn the user of the inefficiency of creating
834 # `StaticHashTable` inside a `@tf.function`-wrapped `dataset_fn` to be
835 # distributed with `distribute_datasets_from_function` and
836 # `create_per_worker_dataset`. This is because the `dataset_fn` does not
837 # use the same `default_graph` as `scope` to which the
838 # `resource_creator_stack` belongs. Thus, `StaticHashTable` creation inside
839 # `dataset_fn` is not intercepted. And since its resource creation under a
840 # `tf.function` is lifted out, all workers will share the same resource on
841 # the coordinator which incurs worker-coordinator communication overhead.
843 def lookup_creator(next_creator, *args, **kwargs):
844 if keras_deps.get_load_context_function()():
845 return (ps_values.RestoredDistributedTable(
846 self._container_strategy(), lambda: next_creator(*args, **kwargs))) # pylint: disable=protected-access
847 else:
848 return ps_values.DistributedTable(self._container_strategy(),
849 lambda: next_creator(*args, **kwargs)) # pylint: disable=protected-access
851 def restored_lookup_creator(next_creator, *args, **kwargs):
852 return (ps_values.RestoredDistributedTable(
853 self._container_strategy(), lambda: next_creator(*args, **kwargs))) # pylint: disable=protected-access
855 return [
856 ops.resource_creator_scope("StaticHashTable", lookup_creator),
857 ops.resource_creator_scope("RestoredStaticHashTable",
858 restored_lookup_creator)
859 ]
861 def _assert_used_with_cluster_coordinator(self):
862 if (not self._used_with_coordinator and
863 not self._allow_run_without_coordinator):
864 raise NotImplementedError(
865 "`tf.distribute.experimental.ParameterServerStrategy` must be used "
866 "with `tf.distribute.experimental.coordinator.ClusterCoordinator` in "
867 "a custom training loop. If you are using `Model.fit`, please supply "
868 "a dataset function directly to a "
869 "`tf.keras.utils.experimental.DatasetCreator` instead.")
871 def _assert_being_scheduled_by_cluster_coordinator(self):
872 if not self._being_scheduled and not self._allow_run_without_coordinator:
873 logging.warning(
874 "A `tf.distribute.experimental.ParameterServerStrategy` method is "
875 "invoked without using `ClusterCoordinator.schedule`. If you are not "
876 "tracing a tf.function, this method is possibly executed on the "
877 "coordinator, which can be slow. To properly dispatch functions to "
878 "run on workers, methods like `run` or `reduce` should be used "
879 "within a function passed to `tf.distribute.experimental.coordinator."
880 "ClusterCoordinator.schedule`.")
882 # options is not used right now. But we may want to support options while
883 # creating InputWorkers in the future, similar to MirroredStrategy.
884 def _input_workers_with_options(self, options=None):
885 input_workers_devices = (("/device:CPU:0", self.worker_devices),)
886 return input_lib.InputWorkers(
887 input_workers_devices, canonicalize_devices=False)
889 def _experimental_distribute_dataset(self, dataset, options):
890 input_workers_devices = self._input_workers_with_options()
892 # If this DistributedDataset is created outside ClusterCoordinator, i.e.,
893 # outside a tf.function, we don't build its underlying datasets
894 # until it is passed to ClusterCoordinator.create_per_worker_dataset.
895 return input_util.get_distributed_dataset(
896 dataset,
897 input_workers_devices,
898 self._container_strategy(),
899 num_replicas_in_sync=self._num_replicas_in_sync,
900 options=options,
901 build=ops.inside_function()) # will be built by ClusterCoordinator
903 def _distribute_datasets_from_function(self, dataset_fn, options):
904 # There is no synchronization across workers, and thus the number of
905 # input pipelines in sync is only 1 per worker.
906 input_pipeline_id_in_sync = 0
907 num_input_pipelines_in_sync = 1
909 input_context = distribute_lib.InputContext(
910 num_input_pipelines=num_input_pipelines_in_sync,
911 input_pipeline_id=input_pipeline_id_in_sync,
912 num_replicas_in_sync=self._num_replicas_in_sync)
914 # If this DistributedDatasetFromFunction is created outside
915 # ClusterCoordinator, i.e., outside a tf.function, we don't build its
916 # underlying datasets until it is passed to
917 # ClusterCoordinator.create_per_worker_dataset.
918 return input_util.get_distributed_datasets_from_function(
919 dataset_fn,
920 self._input_workers_with_options(options), [input_context],
921 self._container_strategy(),
922 options=options,
923 build=ops.inside_function()) # will be built by ClusterCoordinator
925 @property
926 def worker_devices(self):
927 num_gpus = self._num_gpus_per_worker
928 if num_gpus > 0:
929 compute_devices = tuple("/device:GPU:%d" % (i,) for i in range(num_gpus))
930 else:
931 compute_devices = ("/device:CPU:0",)
932 return compute_devices
934 def _call_for_each_replica(self, fn, args, kwargs):
935 self._assert_being_scheduled_by_cluster_coordinator()
937 return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
938 args, kwargs)
940 def _reduce(self, reduce_op, value):
941 self._assert_being_scheduled_by_cluster_coordinator()
942 dst = device_util.current() or self._default_device or "/device:CPU:0"
943 destinations = device_util.canonicalize_without_job_and_task(dst)
944 result = self._local_results(
945 self.reduce_to(reduce_op, value, destinations))[0]
946 return result
948 def _reduce_to(self, reduce_op, value, destinations, options):
949 self._assert_being_scheduled_by_cluster_coordinator()
951 def get_values(x):
952 if isinstance(x, values.DistributedValues):
953 return self._cross_device_ops.reduce(
954 reduce_op, x, destinations=destinations) # pylint: disable=protected-access
955 return x
957 return nest.map_structure(get_values, value)
960# The warning that will be logged if the way we initialize sharded variables
961# is memory-inefficient.
962_INEFFICIENT_INIT_WARNING = (
963 "Large variable %s is partitioned but not initialized in a "
964 "memory-efficient way. On each shard, the full value is first being "
965 "created and then sliced into smaller values. To reduce the memory "
966 "footprint, explicitly specify `dtype` and `shape` when creating "
967 "variables, and use `tf.initializers` to initialize the variable. "
968 "Note that some initializers (e.g., orthogonal) don't support "
969 "memory-efficient initialization and there is not much you can do here.")
971_LARGE_VARIABLE_NUM_ELEMENTS = 1e9