Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/one_device_strategy.py: 46% of 168 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tf.distribute.Strategy for running on a single device."""

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import while_loop
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(josh11b): Do we wrap values in types to generate errors if you are
# doing something that won't work with other DistributionStrategy
# implementations?


@tf_export("distribute.OneDeviceStrategy", v1=[])
class OneDeviceStrategy(distribute_lib.Strategy):
  """A distribution strategy for running on a single device.

  Using this strategy will place any variables created in its scope on the
  specified device. Input distributed through this strategy will be
  prefetched to the specified device. Moreover, any functions called via
  `strategy.run` will also be placed on the specified device.

  Typical usage of this strategy could be testing your code with the
  tf.distribute.Strategy API before switching to other strategies which
  actually distribute to multiple devices/machines.

  For example:
  ```
  strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

  with strategy.scope():
    v = tf.Variable(1.0)
    print(v.device)  # /job:localhost/replica:0/task:0/device:GPU:0

  def step_fn(x):
    return x * 2

  result = 0
  for i in range(10):
    result += strategy.run(step_fn, args=(i,))
  print(result)  # 90
  ```
  """

  def __init__(self, device):
    """Creates a `OneDeviceStrategy`.

    Args:
      device: Device string identifier for the device on which the variables
        should be placed. See class docs for more details on how the device is
        used. Examples: "/cpu:0", "/gpu:0", "/device:CPU:0", "/device:GPU:0"
    """
    super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "OneDeviceStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
    """Distributes a tf.data.Dataset instance provided via dataset.

    In this case, there is only one device, so this is only a thin wrapper
    around the input dataset. It will, however, prefetch the input data to the
    specified device. The returned distributed dataset can be iterated over
    similarly to regular datasets.

    NOTE: Currently, the user cannot add any more transformations to a
    distributed dataset.

    Example:
    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    dataset = tf.data.Dataset.range(10).batch(2)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for x in dist_dataset:
      print(x)  # [0, 1], [2, 3],...
    ```
    Args:
      dataset: `tf.data.Dataset` to be prefetched to device.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
    Returns:
      A "distributed `Dataset`" that the caller can iterate over.
    """
    return super(OneDeviceStrategy, self).experimental_distribute_dataset(
        dataset, options)

  def distribute_datasets_from_function(
      self,
      dataset_fn,  # pylint: disable=useless-super-delegation
      options=None):
    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.

    `dataset_fn` will be called once for each worker in the strategy. In this
    case, we only have one worker and one device so `dataset_fn` is called
    once.

    The `dataset_fn` should take a `tf.distribute.InputContext` instance where
    information about batching and input replication can be accessed:

    ```
    def dataset_fn(input_context):
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
      return d.shard(
          input_context.num_input_pipelines, input_context.input_pipeline_id)

    inputs = strategy.distribute_datasets_from_function(dataset_fn)

    for batch in inputs:
      replica_results = strategy.run(replica_fn, args=(batch,))
    ```

    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
    per-replica batch size, unlike `experimental_distribute_dataset`, which uses
    the global batch size. This may be computed using
    `input_context.get_per_replica_batch_size`.

    Args:
      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
        returning a `tf.data.Dataset`.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A "distributed `Dataset`", which the caller can iterate over like regular
      datasets.
    """
    return super(OneDeviceStrategy,
                 self).distribute_datasets_from_function(dataset_fn, options)

  def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
    """Returns the list of all local per-replica values contained in `value`.

    In `OneDeviceStrategy`, the `value` is always expected to be a single
    value, so the result is just the value in a tuple.
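
    For example (an illustrative sketch; the device string and the value
    returned by `run` are arbitrary):

    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
    result = strategy.run(lambda: tf.constant(1.0))
    strategy.experimental_local_results(result)
    # (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,)
    ```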

    Args:
      value: A value returned by `experimental_run()`, `run()`,
        `extended.call_for_each_replica()`, or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
    """
    return super(OneDeviceStrategy, self).experimental_local_results(value)

  def run(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
    """Run `fn` on each replica, with the given arguments.

    In `OneDeviceStrategy`, `fn` is simply called within a device scope for the
    given device, with the provided arguments.
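
    For example (an illustrative sketch; the device string and inputs are
    arbitrary):

    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")

    def step_fn(x):
      return x * 2

    strategy.run(step_fn, args=(tf.constant(3),))
    # tf.Tensor(6, shape=(), dtype=int32)
    ```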

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Return value from running `fn`.
    """
    return super(OneDeviceStrategy, self).run(fn, args, kwargs, options)

  def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
    """Reduce `value` across replicas.

    In `OneDeviceStrategy`, there is only one replica, so if axis=None, value
    is simply returned. If axis is specified as something other than None,
    such as axis=0, value is reduced along that axis and returned.

    Example:
    ```
    t = tf.range(10)

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=None).numpy()
    # result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=0).numpy()
    # result: 45
    ```

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: A "per replica" value, e.g. returned by `run` to
        be combined into a single tensor.
      axis: Specifies the dimension to reduce along within each
        replica's tensor. Should typically be set to the batch dimension, or
        `None` to only reduce across replicas (e.g. if the tensor has no batch
        dimension).

    Returns:
      A `Tensor`.
    """
    return super(OneDeviceStrategy, self).reduce(reduce_op, value, axis)

  def scope(self):  # pylint: disable=useless-super-delegation
    """Returns a context manager selecting this Strategy as current.

    Inside a `with strategy.scope():` code block, this thread
    will use a variable creator set by `strategy`, and will
    enter its "cross-replica context".

    In `OneDeviceStrategy`, all variables created inside `strategy.scope()`
    will be on `device` specified at strategy construction time.
    See example in the docs for this class.

    Returns:
      A context manager to use for creating variables with this strategy.
    """
    return super(OneDeviceStrategy, self).scope()


@tf_export(v1=["distribute.OneDeviceStrategy"])  # pylint: disable=empty-docstring
class OneDeviceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = OneDeviceStrategy.__doc__.replace(
      "For example:\n  ```",
      "For example:\n  ```\n  tf.enable_eager_execution()")

  def __init__(self, device):
    super(OneDeviceStrategyV1, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "OneDeviceStrategy")
  __init__.__doc__ = OneDeviceStrategy.__init__.__doc__


# TODO(josh11b): Switch to V2 after callers have been updated to only V2 APIs.
class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of OneDeviceStrategy."""

  def __init__(self, container_strategy, device):
    super(OneDeviceExtended, self).__init__(container_strategy)
    self._device = device_util.resolve(device)
    self._input_device = device_util.get_host_for_device(self._device)

  def _input_workers_with_options(self, options=None):
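    # If no options are given, or fetching to the compute device is requested,
    # iterator outputs are placed on `self._device`; otherwise the data stays
    # on the input host device.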
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers([(self._input_device, (self._device,))])
    else:
      return input_lib.InputWorkers([(self._input_device,
                                      (self._input_device,))])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _create_variable(self, next_creator, **kwargs):
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      with ops.device(self._device):
        return next_creator(**kwargs)
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      with ops.colocate_with(colocate_with):
        return next_creator(**kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterator from dataset without splitting the batch."""
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    return input_lib_v1.DatasetIterator(dataset, self._input_workers,
                                        self._container_strategy())

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              [distribute_lib.InputContext()],
                                              self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._input_device), session)

  def _broadcast_to(self, tensor, destinations):
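    # With only one device there is nothing to broadcast to, so the input
    # tensor is returned unchanged.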
    del destinations
    return tensor

  def _experimental_distribute_dataset(self, dataset, options):
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_util.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        options=options)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    return input_util.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
        [distribute_lib.InputContext()],
        self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    # TODO(b/137795644): This should return a PerReplica value but other
    # methods like run in OneDeviceStrategy need to be modified
    # to do the same.
    return value_fn(distribute_lib.ValueContext())

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, e.g.
    # to create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    # TODO(priyag): Use max_iterations instead of an explicit counter.
    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = while_loop.while_loop(
        cond,
        body, [i] + initial_loop_values,
        name="",
        parallel_iterations=1,
        back_prop=False,
        swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    strategy = self._container_strategy()
    with ops.device(self._device), _OneDeviceReplicaContext(strategy):
      return fn(*args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations, options):
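    # There is only one replica, so the per-replica value is already the
    # reduced value; `reduce_op`, `destinations` and `options` are ignored.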
    del reduce_op, destinations, options
    return value

  def _gather_to_implementation(self, value, destinations, axis, options):
    del destinations, axis, options
    return value

  def _update(self, var, fn, args, kwargs, group):
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with
    with ops.device(self._device), distribute_lib.UpdateContext(self._device):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    return array_ops.identity(replica_local_var)

  def _local_results(self, value):
    return (value,)

  def value_container(self, value):
    return value

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  @property
  def _num_replicas_in_sync(self):
    return 1

  @property
  def worker_devices(self):
    return (self._device,)

  @property
  def parameter_devices(self):
    return (self._device,)

  def non_slot_devices(self, var_list):
    del var_list
    return (self._device,)

  @property
  def experimental_should_init(self):
    return True

  @property
  def experimental_between_graph(self):
    return False

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for OneDeviceStrategy."""
    return True

  @property
  def _support_per_replica_values(self):
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group


class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for OneDeviceStrategy."""

  def __init__(self, strategy):
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=0)

  @property
  def devices(self):
    return self._strategy.extended.worker_devices