# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Important value classes relevant to `ClusterCoordinator`.
17This is currently under development and the API is subject to change.
18"""

import threading

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute.coordinator import remote_value
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as tf_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import type_spec as type_spec_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(yuefengz): create an implementation for resource RemoteValue which needs
# to remember the closure object while a normal RemoteValue doesn't.
class RemoteValueImpl(remote_value.RemoteValue):
  """Implementation of `RemoteValue`."""

  def __init__(self, closure, type_spec):  # pylint: disable=super-init-not-called
    """Initializes a `RemoteValueImpl`.

    Args:
      closure: The closure from which the `RemoteValue` is created.
      type_spec: The type spec for this `RemoteValue` which is used to trace
        functions that take this `RemoteValue` as input.
    """
    self._closure = closure
    self._type_spec = type_spec
    self._values = None
    self._has_fetched_to_local = False
    self._has_fetched_to_local_lock = threading.Lock()
    self._fetched_tensors = None
    self._error = None
    self._status_available_event = threading.Event()
    self._status = remote_value.RemoteValueStatus.NOT_READY

  def _set_aborted(self, error):
    self._status = remote_value.RemoteValueStatus.ABORTED
    self._values = None
    self._error = error

    # Wake up any waiting thread and clear the event.
    self._status_available_event.set()

  def _rebuild_on(self, worker):
    self._status_available_event.clear()
    # TODO(yuefengz): we may need to rebuild its inputs as well.
    self._closure.execute_on(worker)

  def _set_values(self, tensors):
    self._status = remote_value.RemoteValueStatus.READY
    self._values = tensors
    self._error = None
    self._status_available_event.set()

  def _set_error(self, error):
    self._status = remote_value.RemoteValueStatus.READY
    self._values = None
    self._error = error
    self._status_available_event.set()

  def _get_values(self):
    self._status_available_event.wait()
    return self._values

  def _get_error(self):
    self._status_available_event.wait()
    return self._error

  def _wait_and_maybe_error(self):
    self._status_available_event.wait()
    if self._status is remote_value.RemoteValueStatus.ABORTED:
      raise errors.CancelledError(
          None, None,
          "The corresponding function is aborted. Please reschedule the "
          "function.")
    if self._error is not None:
      raise self._error

  def fetch(self):
    # TODO(rchao): Discuss the possibility of letting users perform `numpy`
    # themselves at API graduation.
    return nest.map_structure(
        lambda x: x.numpy() if hasattr(x, "numpy") else x, self.get())

  def get(self):
    self._wait_and_maybe_error()

    with self._has_fetched_to_local_lock:
      if not self._has_fetched_to_local:

        def copy_tensor(composite_tensor_obj):
          """Copy a remote tensor to local (coordinator)."""
          if isinstance(composite_tensor_obj, input_lib.DistributedIterator):
            # A DistributedIterator cannot be copied to local; users should not
            # access that anyway.
            return composite_tensor_obj

          with ops.device("/job:%s" % context.get_server_def().job_name):
            # Copying to local (the coordinator) with `tf.device`.
            return array_ops.identity(composite_tensor_obj)

        if self._values is not None:
          # When `self._values` is `None`, it indicates the associated function
          # does not have a return value.
          self._fetched_tensors = nest.map_structure(copy_tensor, self._values)
        self._has_fetched_to_local = True

    return self._fetched_tensors
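

# A minimal usage sketch for fetching a `RemoteValue` produced by
# `ClusterCoordinator.schedule` (illustrative only; assumes a working
# parameter server cluster and a `ClusterCoordinator` named `coordinator`):
#
#   @tf.function
#   def worker_fn():
#     return tf.constant(3.0)
#
#   remote_value = coordinator.schedule(worker_fn)
#   result = remote_value.fetch()  # Blocks until the value is ready, then
#                                  # copies it to the coordinator as numpy.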


@tf_export("distribute.experimental.coordinator.PerWorkerValues",
           "distribute.coordinator.PerWorkerValue", v1=[])
class PerWorkerValues(composite_tensor.CompositeTensor):
  """A container that holds a list of values, one value per worker.

  `tf.distribute.experimental.coordinator.PerWorkerValues` contains a
  collection of values, where each of the values is located on its
  corresponding worker, and upon being used as one of the `args` or `kwargs` of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
  value specific to a worker will be passed into the function being executed at
  that corresponding worker.

  Currently, the only supported path to create an object of
  `tf.distribute.experimental.coordinator.PerWorkerValues` is through calling
  `iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned
  distributed dataset instance. The mechanism to create a custom
  `tf.distribute.experimental.coordinator.PerWorkerValues` is not yet supported.
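
  For example, a minimal sketch (assuming a `ClusterCoordinator` named
  `coordinator` has already been created for a parameter server cluster):

  ```python
  @tf.function
  def dataset_fn():
    return tf.data.Dataset.range(10).batch(2)

  per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
  per_worker_iterator = iter(per_worker_dataset)  # a `PerWorkerValues`
  ```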
156 """

  def __init__(self, values):
    for v in values:
      if not isinstance(v, remote_value.RemoteValue):
        raise AssertionError(
            "`PerWorkerValues` should only take `RemoteValue`s.")
    self._values = tuple(values)

  @property
  def _type_spec(self):
    return PerWorkerValuesTypeSpec(
        self._values[0]._type_spec,  # pylint: disable=protected-access
        type(self))


class PerWorkerValuesTypeSpec(type_spec_lib.TypeSpec):
  """TypeSpec for PerWorkerValues.

  It only supports tracing a function using a PerWorkerValues.
  """

  def __init__(self, value_spec, descendant_type):
    assert value_spec
    self._value_spec = value_spec
    self._descendant_type = descendant_type

  def _serialize(self):
    return (self._value_spec,)

  @property
  def value_type(self):
    return self._descendant_type

  def most_specific_common_supertype(self, others):
    raise NotImplementedError(
        "most_specific_common_supertype is not implemented")

  @property
  def _component_specs(self):
    return self._value_spec

  def _to_components(self, value):
    return self._value_spec

  def _from_components(self, value):
    return value


class PerWorkerDatasetFromDatasetFunction(object):
  """Represents worker-distributed datasets created from dataset function."""

  def __init__(self, dataset_fn, coordinator):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset`.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """

    def disallow_variable_creation(next_creator, **kwargs):
      raise ValueError("Creating variables in `dataset_fn` is not allowed.")

    if isinstance(dataset_fn, def_function.Function):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = dataset_fn.get_concrete_function()
    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
    self._dataset_fn = dataset_fn
    self._coordinator = coordinator
    self._element_spec = None

  def build(self):
    """Trigger dataset creation on workers without creating an iterator.

    Returns:
      A `PerWorkerValues` object containing a tuple of `RemoteValue`s, each
      holding the built `Dataset` for one worker.
    """
    def _create_per_worker_dataset():
      dataset = self._dataset_fn()
      return dataset

    # pylint: disable=protected-access
    per_worker_dataset = self._coordinator._create_per_worker_resources(
        _create_per_worker_dataset)
    # hack type_spec of RemoteValues
    dataset_fn_output_type_spec = self._dataset_fn.structured_outputs._type_spec
    for dataset_remote_value in per_worker_dataset._values:
      dataset_remote_value._type_spec = dataset_fn_output_type_spec
    return per_worker_dataset

  def __iter__(self):
    # We would like users to create iterators outside `tf.function`s so that we
    # can track them.
    if (not context.executing_eagerly() or
        ops.get_default_graph().building_function):
      raise RuntimeError(
          "__iter__() is not supported inside of tf.function or in graph mode.")

    def _create_per_worker_iterator():
      dataset = self._dataset_fn()
      return iter(dataset)

    # If PerWorkerDatasetFromDatasetFunction.__iter__ is called multiple times
    # on the same object, it should create and register the resources only
    # once. Object ids are used to distinguish different iterator resources.
    per_worker_iterator = self._coordinator._create_per_worker_resources(
        _create_per_worker_iterator)

    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
    for iterator_remote_value in per_worker_iterator._values:
      iterator_remote_value._type_spec = (
          input_lib.get_iterator_spec_from_dataset(
              self._coordinator.strategy, self._dataset_fn.structured_outputs))

    return PerWorkerDistributedIterator(per_worker_iterator._values)

  @property
  def element_spec(self):
    """The type specification of an element of this dataset.

    This property is subject to change without notice.
    """
    if not isinstance(self._dataset_fn, tf_function.ConcreteFunction):
      raise NotImplementedError(
          "`element_spec` is not supported when the `dataset_fn` is not "
          "a `ConcreteFunction`.")
    return self._dataset_fn.structured_outputs.element_spec
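

# A sketch of the typical flow for `PerWorkerDatasetFromDatasetFunction`
# (illustrative only; assumes an existing `ClusterCoordinator` named
# `coordinator`):
#
#   per_worker_dataset = PerWorkerDatasetFromDatasetFunction(
#       lambda: dataset_ops.Dataset.range(4), coordinator)
#   per_worker_iterator = iter(per_worker_dataset)
#   # `per_worker_iterator` can now be passed as an argument to functions
#   # scheduled via `coordinator.schedule(...)`.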


def serialize_dataset_to_graph(dataset):
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
  graph_def = gen_dataset_ops.dataset_to_graph_v2(
      dataset._variant_tensor,  # pylint: disable=protected-access
      external_state_policy=ExternalStatePolicy.WARN.value,
      strip_device_assignment=True)
  return graph_def


class _RemoteDataset(dataset_ops.DatasetSource):
  """Creates a dataset given a graph def."""

  def __init__(self, graph_def, element_spec):
    self._elem_spec = element_spec
    variant_tensor = ged_ops.dataset_from_graph(graph_def)
    super(_RemoteDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._elem_spec


def deserialize_dataset_from_graph(graph_def, element_spec):
  return _RemoteDataset(graph_def, element_spec)
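

# The two helpers above form a serialize/deserialize round trip: a dataset is
# serialized to a `GraphDef` on the coordinator and rebuilt elsewhere (e.g. on
# a worker). A rough sketch (illustrative only):
#
#   dataset = dataset_ops.Dataset.range(8)
#   graph_def = serialize_dataset_to_graph(dataset)
#   rebuilt = deserialize_dataset_from_graph(graph_def, dataset.element_spec)
#   # `rebuilt` yields the same elements as `dataset`.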


class PerWorkerDatasetFromDataset(PerWorkerDatasetFromDatasetFunction):
  """Represents worker-distributed datasets created from a dataset."""

  def __init__(self, dataset, coordinator):
320 """Makes an iterable from datasets created by the given dataset.
322 It creates a dataset_fn which deserializes a dataset from a graph under the
323 hood.

    Args:
      dataset: A `tf.data.Dataset`, a `DistributedDataset` or a
        `DistributedDatasetsFromFunction`.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """
    if isinstance(dataset, input_lib.DistributedDataset):
      original_dataset = dataset._original_dataset
      serialized = serialize_dataset_to_graph(original_dataset)

      def dataset_fn():
        deserialized = deserialize_dataset_from_graph(
            serialized, original_dataset.element_spec)
        dataset.build(dataset_to_replace=deserialized)
        return dataset
    elif isinstance(dataset, input_lib.DistributedDatasetsFromFunction):
      def dataset_fn():
        dataset.build()
        return dataset
    elif isinstance(dataset, dataset_ops.Dataset):
      serialized = serialize_dataset_to_graph(dataset)

      def dataset_fn():
        return deserialize_dataset_from_graph(serialized, dataset.element_spec)
    else:
      raise ValueError("Unexpected dataset type!")

    super(PerWorkerDatasetFromDataset, self).__init__(dataset_fn, coordinator)


def get_per_worker_dataset(dataset_or_dataset_fn, coordinator):
  """Returns a per-worker dataset from a dataset or a dataset function."""
  if callable(dataset_or_dataset_fn):
    return PerWorkerDatasetFromDatasetFunction(dataset_or_dataset_fn,
                                               coordinator)
  else:
    return PerWorkerDatasetFromDataset(dataset_or_dataset_fn, coordinator)
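

# Usage sketch for `get_per_worker_dataset` (illustrative only; assumes an
# existing `ClusterCoordinator` named `coordinator`): a callable is wrapped as
# a dataset function, while a concrete dataset is serialized and rebuilt on
# each worker.
#
#   per_worker_dataset = get_per_worker_dataset(
#       lambda: dataset_ops.Dataset.range(4), coordinator)
#   per_worker_iterator = iter(per_worker_dataset)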


class PerWorkerDistributedIterator(PerWorkerValues):
  """Distributed iterator for `ClusterCoordinator`."""

  def __next__(self):
    return self.get_next()

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    raise NotImplementedError("Iterating over an `AsyncDistributedIterator` "
                              "is not supported right now.")