Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/parameter_server_strategy.py: 29%
305 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing a multi-worker parameter server tf.distribute strategy."""

import copy

from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_LOCAL_CPU = "/device:CPU:0"


@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  # pylint: disable=missing-docstring
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
50 """An asynchronous multi-worker parameter server tf.distribute strategy.
52 This strategy requires two roles: workers and parameter servers. Variables and
53 updates to those variables will be assigned to parameter servers and other
54 operations are assigned to workers.
56 When each worker has more than one GPU, operations will be replicated on all
57 GPUs. Even though operations may be replicated, variables are not and each
58 worker shares a common view for which parameter server a variable is assigned
59 to.
61 By default it uses `TFConfigClusterResolver` to detect configurations for
62 multi-worker training. This requires a 'TF_CONFIG' environment variable and
63 the 'TF_CONFIG' must have a cluster spec.
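
  For example, a minimal 'TF_CONFIG' for a cluster with two workers and one
  parameter server might look like the following (host addresses are
  illustrative):

  ```
  os.environ["TF_CONFIG"] = json.dumps({
      "cluster": {
          "worker": ["host1:2222", "host2:2222"],
          "ps": ["host3:2222"]
      },
      "task": {"type": "worker", "index": 0}
  })
  ```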

  This class assumes each worker is running the same code independently, while
  parameter servers are running a standard server. This means that while each
  worker will synchronously compute a single gradient update across all GPUs,
  updates between workers proceed asynchronously. Operations that occur only on
  the first replica (such as incrementing the global step) will occur on the
  first replica *of every worker*.

  You are expected to call `call_for_each_replica(fn, ...)` for any
  operations which potentially can be replicated across replicas (i.e. multiple
  GPUs), even if there is only a CPU or a single GPU. When defining the `fn`,
  extra caution needs to be taken:

  1) It is generally not recommended to open a device scope under the
  strategy's scope. A device scope (i.e. calling `tf.device`) will be merged
  with or override the device for operations, but will not change the device
  for variables.

  2) It is also not recommended to open a colocation scope (i.e. calling
  `tf.compat.v1.colocate_with`) under the strategy's scope. For colocating
  variables, use `strategy.extended.colocate_vars_with` instead. Colocating ops
  may create device assignment conflicts.

  Note: This strategy only works with the Estimator API. Pass an instance of
  this strategy to the `train_distribute` argument when you create the
  `RunConfig`. This instance of `RunConfig` should then be passed to the
  `Estimator` instance on which `train_and_evaluate` is called.

  For example:
  ```
  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(train_distribute=strategy)
  estimator = tf.estimator.Estimator(config=run_config)
  tf.estimator.train_and_evaluate(estimator, ...)
  ```
100 """
102 def __init__(self, cluster_resolver=None):
103 """Initializes this strategy with an optional `cluster_resolver`.
105 Args:
106 cluster_resolver: Optional
107 `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
108 `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
109 """
110 if cluster_resolver is None:
111 cluster_resolver = TFConfigClusterResolver()
112 super(ParameterServerStrategyV1, self).__init__(
113 ParameterServerStrategyExtended(
114 self, cluster_resolver=cluster_resolver))
115 distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
116 "ParameterServerStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).experimental_distribute_dataset(dataset=dataset,
                                                       options=options)

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).distribute_datasets_from_function(
                     dataset_fn=dataset_fn, options=options)

  def run(self, fn, args=(), kwargs=None, options=None):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).run(
        fn, args=args, kwargs=kwargs, options=options)

  def scope(self):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).scope()

  def _raise_pss_error_if_eager(self):
    if context.executing_eagerly():
      raise NotImplementedError(
          "`tf.compat.v1.distribute.experimental.ParameterServerStrategy` "
          "currently only works with the tf.Estimator API")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of ParameterServerStrategy and CentralStorageStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=None,
               compute_devices=None,
               parameter_device=None):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(
        cluster_resolver=cluster_resolver,
        compute_devices=compute_devices,
        parameter_device=parameter_device)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

  def _initialize_strategy(self,
                           cluster_resolver=None,
                           compute_devices=None,
                           parameter_device=None):
    if cluster_resolver and cluster_resolver.cluster_spec():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(
          compute_devices, parameter_device, cluster_resolver=cluster_resolver)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initialize devices for multiple workers.

    It creates variable devices and compute devices. Variables and operations
    will be assigned to them respectively. We have one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.
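
    For example, with three `ps` tasks, consecutively created variables would
    typically be placed on `/job:ps/task:0`, `/job:ps/task:1`,
    `/job:ps/task:2`, then `/job:ps/task:0` again, and so on.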

    Args:
      cluster_resolver: a descendant of `ClusterResolver` object.

    Raises:
      ValueError: if the cluster doesn't have ps jobs.
    """
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs
    # in some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for configure method.
    self._num_gpus_per_worker = num_gpus

    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if not task_type or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`")
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    # Define compute devices, which is a list of device strings, one for each
    # replica. When there are GPUs, operations are replicated on these GPUs.
    # Otherwise, operations are placed on the CPU.
    if num_gpus > 0:
      compute_devices = tuple(
          "%s/device:GPU:%d" % (self._worker_device, i)
          for i in range(num_gpus))
    else:
      compute_devices = (self._worker_device,)

    self._compute_devices = [
        device_util.canonicalize(d) for d in compute_devices]

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from `replica_device_setter` are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    # TODO(yuefengz): merge the logic of replica_device_setter into this
    # class.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    if num_ps_replicas == 0:
      raise ValueError("The cluster spec needs to have `ps` jobs.")
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,
        worker_device=self._worker_device,
        merge_devices=True,
        cluster=cluster_spec)

    # The `_parameter_devices` is needed for the `parameter_devices` property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end
    # up on other workers.
    self._default_device = self._worker_device

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    logging.info(
        "Multi-worker ParameterServerStrategy with "
        "cluster_spec = %r, task_type = %r, task_id = %r, "
        "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
        "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
        num_ps_replicas, self._is_chief, self._compute_devices,
        self._variable_device)

  # TODO(yuefengz): get rid of cluster_resolver argument when contrib's
  # version no longer depends on this class.
  def _initialize_local(self,
                        compute_devices,
                        parameter_device,
                        cluster_resolver=None):
    """Initialize local devices for training."""
    self._worker_device = device_util.canonicalize("/device:CPU:0")
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    if compute_devices is None:
      if not cluster_resolver:
        num_gpus = context.num_gpus()
      else:
        num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
      # Save the num_gpus_per_worker for configure method which is used by the
      # contrib version.
      self._num_gpus_per_worker = num_gpus

      compute_devices = device_util.local_devices_from_num_gpus(num_gpus)

    compute_devices = [device_util.canonicalize(d) for d in compute_devices]

    if parameter_device is None:
      # If there is only one GPU, put everything on that GPU. Otherwise, place
      # variables on CPU.
      if len(compute_devices) == 1:
        parameter_device = compute_devices[0]
      else:
        parameter_device = _LOCAL_CPU

    self._variable_device = parameter_device
    self._compute_devices = compute_devices
    self._parameter_devices = (parameter_device,)
    self._is_chief = True
    self._cluster_spec = None
    self._task_type = None
    self._task_id = None

    logging.info(
        "ParameterServerStrategy (CentralStorageStrategy if you are using a "
        "single machine) with compute_devices = %r, variable_device = %r",
        compute_devices, self._variable_device)

  def _input_workers_with_options(self, options=None):
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers(
          [(self._worker_device, self._compute_devices)])
    else:
      return input_lib.InputWorkers(
          [(self._worker_device,
            (self._worker_device,) * len(self._compute_devices))])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _experimental_distribute_dataset(self, dataset, options):
    return input_util.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _make_dataset_iterator(self, dataset):
    return input_lib_v1.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset to each local GPU."""
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1
    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              [input_context],
                                              self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._input_host_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1

    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)

    return input_util.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options), [input_context],
        self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(
          value_fn(distribute_lib.ValueContext(replica_id,
                                               self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if not cross_device_ops_lib.check_destinations(destinations):
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._compute_devices
    return self._cross_device_ops.broadcast(tensor, destinations)

  def _allow_variable_partition(self):
    return not context.executing_eagerly()

  def _create_var_creator(self, next_creator, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode: %s for "
                         "variable: %s" % (aggregation, kwargs["name"]))

      def var_creator(**kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(**kwargs)
        wrapped = ps_values.AggregatingVariable(self._container_strategy(), v,
                                                aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set "trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped

      return var_creator
    else:
      return next_creator

  # TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS will go
  # through this creator, such as "MutableHashTable".
  def _create_variable(self, next_creator, **kwargs):
    var_creator = self._create_var_creator(next_creator, **kwargs)

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(**kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(**kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):
        return var_creator(**kwargs)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

  def _verify_destinations_not_different_worker(self, destinations):
    if not self._cluster_spec:
      return
    if destinations is None:
      return
    for d in cross_device_ops_lib.get_devices_from(destinations):
      d_spec = tf_device.DeviceSpec.from_string(d)
      if d_spec.job == self._task_type and d_spec.task != self._task_id:
        raise ValueError(
            "Cannot reduce to another worker: %r, current worker is %r" %
            (d, self._worker_device))

  def _gather_to_implementation(self, value, destinations, axis,
                                options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      return value
    return self._cross_device_ops._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=options)

  def _reduce_to(self, reduce_op, value, destinations, options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations, options=options)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    for _, destinations in value_destination_pairs:
      self._verify_destinations_not_different_worker(destinations)
    return self._cross_device_ops.batch_reduce(reduce_op,
                                               value_destination_pairs,
                                               options)

  def _select_single_value(self, structured):
    """Select any single value in `structured`."""

    def _select_fn(x):  # pylint: disable=g-missing-docstring
      if isinstance(x, values.Mirrored) or isinstance(x, values.PerReplica):
        return x._primary  # pylint: disable=protected-access
      else:
        return x

    return nest.map_structure(_select_fn, structured)

  def _update(self, var, fn, args, kwargs, group):
    if isinstance(var, ps_values.AggregatingVariable):
      var = var.get()
    if not resource_variable_ops.is_resource_variable(var):
      raise ValueError(
          "You can not update `var` %r. It must be a Variable." % var)
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
      result = fn(var, *self._select_single_value(args),
                  **self._select_single_value(kwargs))
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  # TODO(yuefengz): does it need to call _select_single_value?
  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    with ops.device(
        colocate_with.device), distribute_lib.UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def value_container(self, val):
    if (hasattr(val, "_aggregating_container") and
        not isinstance(val, ps_values.AggregatingVariable)):
      wrapper = val._aggregating_container()  # pylint: disable=protected-access
      if wrapper is not None:
        return wrapper
    return val

  def read_var(self, var):
    # No need to distinguish between normal variables and replica-local
    # variables.
    return array_ops.identity(var)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the strategy class with `cluster_spec`.

    The strategy object will be re-initialized if `cluster_spec` is passed to
    `configure` but was not passed when instantiating the strategy.

    Args:
      session_config: Session config object.
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type.
      task_id: the current task id.

    Raises:
      ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
        not.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker})
      self._initialize_multi_worker(cluster_resolver)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    if not self._cluster_spec:
      updated_config.isolate_session_state = True
      return updated_config

    updated_config.isolate_session_state = False

    assert self._task_type
    assert self._task_id is not None

    # The device filters prevent communication between workers.
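    # For example (illustrative), a "worker" task with task_id 1 ends up with
    # device_filters = ["/job:worker/task:1", "/job:ps"], i.e. it can see its
    # own devices and the parameter servers, but not other workers.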
    del updated_config.device_filters[:]
    if self._task_type in ["chief", "worker"]:
      updated_config.device_filters.extend(
          ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
    elif self._task_type == "evaluator":
      updated_config.device_filters.append(
          "/job:%s/task:%d" % (self._task_type, self._task_id))
    return updated_config

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return self._cluster_spec is not None

  @property
  def _num_replicas_in_sync(self):
    return len(self._compute_devices)

  @property
  def worker_devices(self):
    return self._compute_devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._compute_devices]

  @property
  def parameter_devices(self):
    return self._parameter_devices

  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  @property
  def experimental_between_graph(self):
    # TODO(yuefengz): Should this return False in the local case?
    return True

  @property
  def experimental_should_init(self):
    return self._is_chief

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id