# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Important value classes relevant to `ClusterCoordinator`.

This is currently under development and the API is subject to change.
"""

import threading

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute.coordinator import remote_value
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as tf_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import type_spec as type_spec_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(yuefengz): create an implementation for resource RemoteValue which needs
# to remember the closure object while a normal RemoteValue doesn't.
class RemoteValueImpl(remote_value.RemoteValue):
  """Implementation of `RemoteValue`."""

  def __init__(self, closure, type_spec):  # pylint: disable=super-init-not-called
    """Initializes a `RemoteValueImpl`.

    Args:
      closure: The closure from which the `RemoteValue` is created.
      type_spec: The type spec for this `RemoteValue` which is used to trace
        functions that take this `RemoteValue` as input.
    """
    self._closure = closure
    self._type_spec = type_spec
    self._values = None
    self._has_fetched_to_local = False
    self._has_fetched_to_local_lock = threading.Lock()
    self._fetched_tensors = None
    self._error = None
    self._status_available_event = threading.Event()
    self._status = remote_value.RemoteValueStatus.NOT_READY

  def _set_aborted(self, error):
    self._status = remote_value.RemoteValueStatus.ABORTED
    self._values = None
    self._error = error

    # Wake up any waiting thread and clear the event.
    self._status_available_event.set()

  def _rebuild_on(self, worker):
    self._status_available_event.clear()
    # TODO(yuefengz): we may need to rebuild its inputs as well.
    self._closure.execute_on(worker)

  def _set_values(self, tensors):
    self._status = remote_value.RemoteValueStatus.READY
    self._values = tensors
    self._error = None
    self._status_available_event.set()

  def _set_error(self, error):
    self._status = remote_value.RemoteValueStatus.READY
    self._values = None
    self._error = error
    self._status_available_event.set()

  def _get_values(self):
    self._status_available_event.wait()
    return self._values

  def _get_error(self):
    self._status_available_event.wait()
    return self._error

  def _wait_and_maybe_error(self):
    self._status_available_event.wait()
    if self._status is remote_value.RemoteValueStatus.ABORTED:
      raise errors.CancelledError(
          None, None,
          "The corresponding function is aborted. Please reschedule the "
          "function.")
    if self._error is not None:
      raise self._error

  def fetch(self):
    # TODO(rchao): Discuss the possibility of letting users perform `numpy`
    # themselves at API graduation.
    return nest.map_structure(
        lambda x: x.numpy() if hasattr(x, "numpy") else x, self.get())

  def get(self):
    self._wait_and_maybe_error()

    with self._has_fetched_to_local_lock:
      if not self._has_fetched_to_local:

        def copy_tensor(composite_tensor_obj):
          """Copy a remote tensor to local (coordinator)."""
          if isinstance(composite_tensor_obj, input_lib.DistributedIterator):
            # A DistributedIterator cannot be copied to local; users should not
            # access that anyway.
            return composite_tensor_obj

          with ops.device("/job:%s" % context.get_server_def().job_name):
            # Copying to local (the coordinator) with `tf.device`.
            return array_ops.identity(composite_tensor_obj)

        if self._values is not None:
          # When `self._values` is `None`, it indicates the associated function
          # does not have a return value.
          self._fetched_tensors = nest.map_structure(copy_tensor, self._values)
        self._has_fetched_to_local = True

    return self._fetched_tensors
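

# Illustrative sketch (added for exposition; `coordinator` and `worker_fn` are
# hypothetical caller-supplied objects): a `RemoteValueImpl` is normally
# produced by `ClusterCoordinator.schedule`, which returns it immediately as a
# placeholder. `get()` blocks on the status event until the value is READY
# (raising if the closure failed or was aborted) and copies the tensors to the
# coordinator; `fetch()` additionally converts them to numpy.
def _example_remote_value_lifecycle(coordinator, worker_fn):
  # `worker_fn` is assumed to be a `tf.function`; `schedule` returns at once.
  remote_result = coordinator.schedule(worker_fn)
  tensors = remote_result.get()  # blocks until READY; tensors on coordinator
  numpy_values = remote_result.fetch()  # same values, converted to numpy
  return tensors, numpy_values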


@tf_export("distribute.experimental.coordinator.PerWorkerValues",
           "distribute.coordinator.PerWorkerValue", v1=[])
class PerWorkerValues(composite_tensor.CompositeTensor):
  """A container that holds a list of values, one value per worker.

  `tf.distribute.experimental.coordinator.PerWorkerValues` contains a
  collection of values, where each of the values is located on its
  corresponding worker, and upon being used as one of the `args` or `kwargs` of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
  value specific to a worker will be passed into the function being executed at
  that corresponding worker.

  Currently, the only supported path to create an object of
  `tf.distribute.experimental.coordinator.PerWorkerValues` is through calling
  `iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned
  distributed dataset instance. The mechanism to create a custom
  `tf.distribute.experimental.coordinator.PerWorkerValues` is not yet
  supported.
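
  Example (an illustrative sketch; assumes the caller has already created a
  `tf.distribute.experimental.coordinator.ClusterCoordinator` named
  `coordinator`):

  ```python
  @tf.function
  def dataset_fn():
    return tf.data.Dataset.range(10).batch(2)

  @tf.function
  def step_fn(iterator):
    return next(iterator)

  per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
  per_worker_iterator = iter(per_worker_dataset)  # a `PerWorkerValues`
  # The iterator shard local to a worker is passed to `step_fn` on whichever
  # worker executes the scheduled function.
  result = coordinator.schedule(step_fn, args=(per_worker_iterator,))
  ```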

156 """ 

157 

158 def __init__(self, values): 

159 for v in values: 

160 if not isinstance(v, remote_value.RemoteValue): 

161 raise AssertionError( 

162 "`PerWorkerValues` should only take `RemoteValue`s.") 

163 self._values = tuple(values) 

164 

165 @property 

166 def _type_spec(self): 

167 return PerWorkerValuesTypeSpec( 

168 self._values[0]._type_spec, # pylint: disable=protected-access 

169 type(self)) 

170 

171 

172class PerWorkerValuesTypeSpec(type_spec_lib.TypeSpec): 

173 """TypeSpec for PerWorkerValues. 

174 

175 It only support tracing a function using a PerWorkerValues. 

176 """ 

177 

178 def __init__(self, value_spec, descendant_type): 

179 assert value_spec 

180 self._value_spec = value_spec 

181 self._descendant_type = descendant_type 

182 

183 def _serialize(self): 

184 return (self._value_spec,) 

185 

186 @property 

187 def value_type(self): 

188 return self._descendant_type 

189 

190 def most_specific_common_supertype(self, others): 

191 raise NotImplementedError( 

192 "most_specific_common_supertype is not implemented") 

193 

194 @property 

195 def _component_specs(self): 

196 return self._value_spec 

197 

198 def _to_components(self, value): 

199 return self._value_spec 

200 

201 def _from_components(self, value): 

202 return value 
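

# Illustrative sketch (added for exposition; the helper and its assert are
# hypothetical): `PerWorkerValuesTypeSpec` wraps the spec of a single
# per-worker value so that a `tf.function` taking a `PerWorkerValues` argument
# can be traced.
def _example_per_worker_values_type_spec():
  # Local imports keep this illustration self-contained; the module itself
  # does not use them.
  from tensorflow.python.framework import dtypes
  from tensorflow.python.framework import tensor_spec

  value_spec = tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)
  spec = PerWorkerValuesTypeSpec(value_spec, PerWorkerValues)
  assert spec.value_type is PerWorkerValues
  return spec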


class PerWorkerDatasetFromDatasetFunction(object):
  """Represents worker-distributed datasets created from dataset function."""

  def __init__(self, dataset_fn, coordinator):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset`.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """

    def disallow_variable_creation(next_creator, **kwargs):
      raise ValueError("Creating variables in `dataset_fn` is not allowed.")

    if isinstance(dataset_fn, def_function.Function):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = dataset_fn.get_concrete_function()
    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
    self._dataset_fn = dataset_fn
    self._coordinator = coordinator
    self._element_spec = None

  def build(self):
    """Trigger dataset creation on workers without creating an iterator.

    Returns:
      A PerWorkerValues object containing a tuple of RemoteValues, themselves
      containing the built Dataset for each worker.
    """

    def _create_per_worker_dataset():
      dataset = self._dataset_fn()
      return dataset

    # pylint: disable=protected-access
    per_worker_dataset = self._coordinator._create_per_worker_resources(
        _create_per_worker_dataset)
    # Hack: overwrite the type_spec of the returned RemoteValues with the
    # dataset function's output type_spec, so that functions taking these
    # RemoteValues as inputs can be traced.
    dataset_fn_output_type_spec = self._dataset_fn.structured_outputs._type_spec
    for dataset_remote_value in per_worker_dataset._values:
      dataset_remote_value._type_spec = dataset_fn_output_type_spec
    return per_worker_dataset

  def __iter__(self):
    # We would like users to create iterators outside `tf.function`s so that we
    # can track them.
    if (not context.executing_eagerly() or
        ops.get_default_graph().building_function):
      raise RuntimeError(
          "__iter__() is not supported inside of tf.function or in graph mode.")

    def _create_per_worker_iterator():
      dataset = self._dataset_fn()
      return iter(dataset)

    # If `__iter__` is called multiple times on the same object, it should
    # create and register the resources only once; object ids are used to
    # distinguish different iterator resources.
    per_worker_iterator = self._coordinator._create_per_worker_resources(
        _create_per_worker_iterator)

    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
    for iterator_remote_value in per_worker_iterator._values:
      iterator_remote_value._type_spec = (
          input_lib.get_iterator_spec_from_dataset(
              self._coordinator.strategy, self._dataset_fn.structured_outputs))

    return PerWorkerDistributedIterator(per_worker_iterator._values)

  @property
  def element_spec(self):
    """The type specification of an element of this dataset.

    This property is subject to change without notice.
    """
    if not isinstance(self._dataset_fn, tf_function.ConcreteFunction):
      raise NotImplementedError(
          "`element_spec` is not supported when the `dataset_fn` is not "
          "a `ConcreteFunction`.")
    return self._dataset_fn.structured_outputs.element_spec
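

# Illustrative sketch (added for exposition; the helper and its `coordinator`
# argument are hypothetical): a plain Python `dataset_fn` is traced into a
# `ConcreteFunction` at construction time, which is what makes `element_spec`
# available, and the creator scope above rejects variable creation inside
# `dataset_fn`.
def _example_per_worker_dataset_fn(coordinator):
  def dataset_fn():
    return dataset_ops.Dataset.range(8).batch(2)

  per_worker_dataset = PerWorkerDatasetFromDatasetFunction(
      dataset_fn, coordinator)
  # Available because `dataset_fn` was traced into a `ConcreteFunction`.
  element_spec = per_worker_dataset.element_spec
  # Creates per-worker iterator resources; must be called eagerly.
  per_worker_iterator = iter(per_worker_dataset)
  return element_spec, per_worker_iterator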


def serialize_dataset_to_graph(dataset):
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
  graph_def = gen_dataset_ops.dataset_to_graph_v2(
      dataset._variant_tensor,  # pylint: disable=protected-access
      external_state_policy=ExternalStatePolicy.WARN.value,
      strip_device_assignment=True)
  return graph_def


class _RemoteDataset(dataset_ops.DatasetSource):
  """Creates a dataset given a graph def."""

  def __init__(self, graph_def, element_spec):
    self._elem_spec = element_spec
    variant_tensor = ged_ops.dataset_from_graph(graph_def)
    super(_RemoteDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._elem_spec


def deserialize_dataset_from_graph(graph_def, element_spec):
  return _RemoteDataset(graph_def, element_spec)
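

# Illustrative sketch (added for exposition; the helper name is hypothetical):
# the two functions above round-trip a dataset through its serialized
# `GraphDef`, which is how a coordinator-built dataset is shipped to workers.
def _example_dataset_round_trip():
  dataset = dataset_ops.Dataset.range(4)
  graph_def = serialize_dataset_to_graph(dataset)
  # The element spec must be supplied explicitly; it is not recoverable from
  # the serialized graph alone.
  return deserialize_dataset_from_graph(graph_def, dataset.element_spec)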


class PerWorkerDatasetFromDataset(PerWorkerDatasetFromDatasetFunction):
  """Represents worker-distributed datasets created from a dataset."""

  def __init__(self, dataset, coordinator):
    """Makes an iterable from the given dataset.

    Under the hood, it creates a dataset_fn which deserializes the dataset from
    a graph.

    Args:
      dataset: A tf.data.Dataset, a DistributedDataset or a
        DistributedDatasetsFromFunction.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """
    if isinstance(dataset, input_lib.DistributedDataset):
      original_dataset = dataset._original_dataset
      serialized = serialize_dataset_to_graph(original_dataset)

      def dataset_fn():
        deserialized = deserialize_dataset_from_graph(
            serialized, original_dataset.element_spec)
        dataset.build(dataset_to_replace=deserialized)
        return dataset
    elif isinstance(dataset, input_lib.DistributedDatasetsFromFunction):
      def dataset_fn():
        dataset.build()
        return dataset
    elif isinstance(dataset, dataset_ops.Dataset):
      serialized = serialize_dataset_to_graph(dataset)

      def dataset_fn():
        return deserialize_dataset_from_graph(serialized, dataset.element_spec)
    else:
      raise ValueError("Unexpected dataset type!")

    super(PerWorkerDatasetFromDataset, self).__init__(dataset_fn, coordinator)


def get_per_worker_dataset(dataset_or_dataset_fn, coordinator):
  """Returns a per-worker dataset from a dataset or a dataset function."""
  if callable(dataset_or_dataset_fn):
    return PerWorkerDatasetFromDatasetFunction(dataset_or_dataset_fn,
                                               coordinator)
  else:
    return PerWorkerDatasetFromDataset(dataset_or_dataset_fn, coordinator)
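

# Illustrative sketch (added for exposition; `coordinator` is assumed to be a
# caller-supplied `ClusterCoordinator`): `get_per_worker_dataset` dispatches on
# whether it receives a callable or a concrete dataset.
def _example_get_per_worker_dataset(coordinator):
  # A callable goes through PerWorkerDatasetFromDatasetFunction.
  from_fn = get_per_worker_dataset(
      lambda: dataset_ops.Dataset.range(4), coordinator)
  # A dataset instance is serialized and goes through
  # PerWorkerDatasetFromDataset.
  from_dataset = get_per_worker_dataset(
      dataset_ops.Dataset.range(4), coordinator)
  return from_fn, from_dataset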


class PerWorkerDistributedIterator(PerWorkerValues):
  """Distributed iterator for `ClusterCoordinator`."""

  def __next__(self):
    return self.get_next()

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    raise NotImplementedError("Iterating over a "
                              "`PerWorkerDistributedIterator` is not "
                              "supported right now.")