Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/data_service_ops.py: 30%

215 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python API for executing a tf.data.Dataset using a tf.data service."""

import enum
import functools

from tensorflow.core.protobuf import data_service_pb2
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.data.experimental.service import _pywrap_utils
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.ops.options import AutoShardPolicy
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_AUTO = "AUTO"
COMPRESSION_NONE = None
_PARALLEL_EPOCHS = "parallel_epochs"
_DISTRIBUTED_EPOCH = "distributed_epoch"


@tf_export("data.experimental.service.ShardingPolicy")
class ShardingPolicy(enum.IntEnum):
  """Specifies how to shard data among tf.data service workers.

  OFF: No sharding will be performed. Each worker produces the entire dataset
  without any sharding. With this mode, the best practice is to shuffle the
  dataset nondeterministically so that workers process the dataset in different
  orders. If workers are restarted or join the cluster mid-job, they will begin
  processing the dataset from the beginning.

  DYNAMIC: The input dataset is dynamically split among workers at runtime.
  Each worker gets the next split when it reads data from the dispatcher. Data
  is produced non-deterministically in this mode. Dynamic sharding works well
  with varying-sized tf.data service clusters, e.g., when you need to
  auto-scale your workers. Dynamic sharding provides at-most-once visitation
  guarantees. No examples will be repeated, but some may be missed if a tf.data
  service worker gets restarted while processing a file.

  The following are static sharding policies. The semantics are similar to
  `tf.data.experimental.AutoShardPolicy`. These policies require:
  * The tf.data service cluster is configured with a fixed list of workers
    in DispatcherConfig.
  * Each client only reads from the local tf.data service worker.

  If a worker is restarted while performing static sharding, the worker will
  begin processing its shard again from the beginning.

  FILE: Shards by input files (i.e. each worker will get a fixed set of files
  to process). When this option is selected, make sure that there are at least
  as many files as workers. If there are fewer input files than workers, a
  runtime error will be raised.

  DATA: Shards by elements produced by the dataset. Each worker will process
  the whole dataset and discard the portion that is not for itself. Note that
  for this mode to correctly partition the dataset elements, the dataset needs
  to produce elements in a deterministic order.

  FILE_OR_DATA: Attempts FILE-based sharding, falling back to DATA-based
  sharding on failure.

  HINT: Looks for the presence of `shard(SHARD_HINT, ...)`, which is treated
  as a placeholder to replace with `shard(num_workers, worker_index)`.
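
  A minimal sketch of using HINT (the file pattern and dispatcher address are
  placeholders, and `tf.data.experimental.SHARD_HINT` is assumed to be the
  exported shard-hint constant):

  ```
  dataset = tf.data.Dataset.list_files("/path/to/data/*")  # hypothetical path
  dataset = dataset.shard(tf.data.experimental.SHARD_HINT,
                          tf.data.experimental.SHARD_HINT)
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode=tf.data.experimental.service.ShardingPolicy.HINT,
      service="grpc://dispatcher_address:5000"))  # placeholder address
  ```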

88 """ 

89 

90 # LINT.IfChange(tf_data_service_sharding_policy) 

91 OFF = 0 

92 DYNAMIC = 1 

93 FILE = 2 

94 DATA = 3 

95 FILE_OR_DATA = 4 

96 HINT = 5 

97 # LINT.ThenChange() 

98 

99 def _to_proto(self): 

100 """Converts the policy to ProcessingModeDef proto enum.""" 

101 

102 if self == ShardingPolicy.OFF: 

103 return data_service_pb2.ProcessingModeDef.OFF 

104 if self == ShardingPolicy.DYNAMIC: 

105 return data_service_pb2.ProcessingModeDef.DYNAMIC 

106 if self == ShardingPolicy.FILE: 

107 return data_service_pb2.ProcessingModeDef.FILE 

108 if self == ShardingPolicy.DATA: 

109 return data_service_pb2.ProcessingModeDef.DATA 

110 if self == ShardingPolicy.FILE_OR_DATA: 

111 return data_service_pb2.ProcessingModeDef.FILE_OR_DATA 

112 if self == ShardingPolicy.HINT: 

113 return data_service_pb2.ProcessingModeDef.HINT 

114 raise ValueError(f"Unable to convert sharding policy {self!r} to proto.") 

115 

116 

@tf_export("data.experimental.service.CrossTrainerCache")
class CrossTrainerCache:
  """Options related to the tf.data service cross-trainer cache.

  This is used to enable the cross-trainer cache when distributing a dataset.
  For example:

  ```
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
      service=FLAGS.tf_data_service_address,
      job_name="job",
      cross_trainer_cache=data_service_ops.CrossTrainerCache(
          trainer_id=trainer_id())))
  ```

  For more details, refer to
  https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers.
  """

  def __init__(self, trainer_id):
    """Constructs a CrossTrainerCache.

    Args:
      trainer_id: Each training job has a unique ID. Once a job has consumed
        data, the data remains in the cache and is re-used by jobs with
        different `trainer_id`s. Requests with the same `trainer_id` do not
        re-use data.

    Raises:
      ValueError: If `trainer_id` is empty.
    """
    if not trainer_id:
      raise ValueError(
          "tf.data service cross-trainer cache requires a non-empty trainer ID."
      )
    self.trainer_id = trainer_id

  def _to_proto(self):
    return data_service_pb2.CrossTrainerCacheOptions(trainer_id=self.trainer_id)


def _get_validated_sharding_policy(processing_mode):
  """Validates `processing_mode` and converts it to `ShardingPolicy`."""

  if isinstance(processing_mode, ShardingPolicy):
    return processing_mode
  if processing_mode == _PARALLEL_EPOCHS:
    return ShardingPolicy.OFF
  if processing_mode == _DISTRIBUTED_EPOCH:
    return ShardingPolicy.DYNAMIC

  raise ValueError("tf.data service processing mode should be a "
                   "`tf.data.experimental.service.ShardingPolicy`, "
                   "`\"parallel_epochs\"`, or `\"distributed_epoch\"`. Got "
                   f"{processing_mode!r}.")


def _validate_job_name(job_name):
  if job_name is None:
    return
  if not isinstance(job_name, str):
    raise ValueError("`job_name` must be a string, but `job_name` was of type "
                     f"{type(job_name)}. job_name={job_name}")
  if not job_name:
    raise ValueError("`job_name` must not be empty")


def _validate_compression(compression):
  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
  if compression not in valid_compressions:
    raise ValueError(f"Invalid `compression` argument: {compression}. "
                     f"Must be one of {valid_compressions}.")


def _get_compression_proto(compression):
  if compression == COMPRESSION_AUTO:
    return data_service_pb2.DataServiceMetadata.COMPRESSION_SNAPPY
  if compression == COMPRESSION_NONE:
    return data_service_pb2.DataServiceMetadata.COMPRESSION_OFF
  raise ValueError(f"Invalid `compression` argument: {compression}. "
                   f"Must be one of {[COMPRESSION_AUTO, COMPRESSION_NONE]}.")


def _to_tensor(dataset_id):
  """Converts `dataset_id` to a Tensor."""

  if isinstance(dataset_id, ops.Tensor):
    return dataset_id
  if isinstance(dataset_id, (str, bytes)):
    return ops.convert_to_tensor(
        dataset_id, dtype=dtypes.string, name="dataset_id")
  return ops.convert_to_tensor(
      dataset_id, dtype=dtypes.int64, name="dataset_id")


def _to_string(dataset_id):
  """Converts `dataset_id` to a string."""

  if isinstance(dataset_id, ops.Tensor):
    return (dataset_id if dataset_id.dtype == dtypes.string else
            string_ops.as_string(dataset_id))
  return (dataset_id.decode()
          if isinstance(dataset_id, bytes) else str(dataset_id))


class _DataServiceDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` that reads elements from the tf.data service."""

  def __init__(self,
               dataset_id,
               processing_mode,
               address,
               element_spec,
               protocol,
               data_transfer_protocol,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None,
               cross_trainer_cache=None,
               target_workers="AUTO"):
    """Constructs a _DataServiceDatasetV2.

    Args:
      dataset_id: The dataset id for the dataset to read from.
      processing_mode: A `tf.data.experimental.service.ShardingPolicy`
        specifying how to shard the dataset among tf.data workers. See
        `tf.data.experimental.service.ShardingPolicy` for details. For
        backwards compatibility, `processing_mode` may also be set to the
        strings `"parallel_epochs"` or `"distributed_epoch"`, which are
        respectively equivalent to `ShardingPolicy.OFF` and
        `ShardingPolicy.DYNAMIC`.
      address: The tf.data service address, e.g. "localhost:5000".
      element_spec: The dataset element spec for the dataset to read from.
      protocol: The protocol to use for communicating with the tf.data
        service, e.g. "grpc".
      data_transfer_protocol: (Optional.) The protocol to use for transferring
        data with the tf.data service. By default, data is transferred using
        gRPC.
      job_name: (Optional.) The name of the job. If provided, it must be a
        non-empty string or Tensor. This argument makes it possible for
        multiple datasets to share the same job. The default behavior is that
        the dataset creates anonymous, exclusively owned jobs.
      consumer_index: (Optional.) The index of the consumer in the range from
        `0` to `num_consumers`. Must be specified alongside `num_consumers`.
        When specified, consumers will read from the job in a strict
        round-robin order, instead of the default first-come-first-served
        order.
      num_consumers: (Optional.) The number of consumers which will consume
        from the job. Must be specified alongside `consumer_index`. When
        specified, consumers will read from the job in a strict round-robin
        order, instead of the default first-come-first-served order. When
        `num_consumers` is specified, the dataset must have infinite
        cardinality to prevent a producer from running out of data early and
        causing consumers to go out of sync.
      max_outstanding_requests: (Optional.) A limit on how many elements may
        be requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than
        `element_size` * `max_outstanding_requests` of memory.
      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
        the dispatcher for task changes.
      cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
        provided, dataset iteration will be shared across concurrently running
        trainers. See
        https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
        for details.
      target_workers: (Optional.) Which workers to read from. If `"AUTO"`, the
        tf.data runtime decides which workers to read from. If `"ANY"`, reads
        from any tf.data service workers. If `"LOCAL"`, only reads from local
        in-process tf.data service workers. `"AUTO"` works well for most
        cases, while users can specify other targets. For example, `"LOCAL"`
        helps avoid RPCs and data copies if every TF worker colocates with a
        tf.data service worker. Consumers of a shared job must use the same
        `target_workers`. Defaults to `"AUTO"`.
    """
    if (consumer_index is None) != (num_consumers is None):
      raise ValueError(
          "Must either set both `consumer_index` and `num_consumers`, "
          "or neither. "
          f"consumer_index={consumer_index}, num_consumers={num_consumers}")
    if num_consumers is not None and job_name is None:
      raise ValueError("`job_name` must be set when setting `num_consumers`. "
                       f"num_consumers was set to {num_consumers}.")

    processing_mode_def = data_service_pb2.ProcessingModeDef(
        sharding_policy=_get_validated_sharding_policy(
            processing_mode)._to_proto())
    if job_name is None:
      job_name = ""
    if max_outstanding_requests is None:
      max_outstanding_requests = dataset_ops.AUTOTUNE
    if task_refresh_interval_hint_ms is None:
      task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

    self._dataset_id = _to_tensor(dataset_id)
    self._processing_mode = ops.convert_to_tensor(
        processing_mode_def.SerializeToString(),
        dtype=dtypes.string,
        name="processing_mode")
    self._address = ops.convert_to_tensor(
        address, dtype=dtypes.string, name="address")
    self._protocol = ops.convert_to_tensor(
        protocol, dtype=dtypes.string, name="protocol")
    self._job_name = ops.convert_to_tensor(
        job_name, dtype=dtypes.string, name="job_name")
    self._consumer_index = ops.convert_to_tensor(
        -1 if consumer_index is None else consumer_index,
        dtype=dtypes.int64,
        name="consumer_index")
    self._num_consumers = ops.convert_to_tensor(
        -1 if num_consumers is None else num_consumers,
        dtype=dtypes.int64,
        name="num_consumers")
    self._max_outstanding_requests = ops.convert_to_tensor(
        max_outstanding_requests,
        dtype=dtypes.int64,
        name="max_outstanding_requests")
    self._element_spec = element_spec
    uncompress_func = structured_function.StructuredFunctionWrapper(
        lambda x: compression_ops.uncompress(x, output_spec=element_spec),
        transformation_name="DataServiceDataset.uncompress()",
        input_structure=tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant))
    cross_trainer_cache_options = (
        cross_trainer_cache._to_proto().SerializeToString()
        if cross_trainer_cache else None)

    compat_kwargs = {}
    if data_transfer_protocol is not None:
      compat_kwargs["data_transfer_protocol"] = data_transfer_protocol

    # If `uncompress` is `True`, the dataset will query the servers to find
    # out the actual compression used. It is always set to `True` the first
    # time the graph is built, and set to `False` when serializing, so we will
    # uncompress at most once.
    uncompress = True
    variant_tensor = gen_experimental_dataset_ops.data_service_dataset_v4(
        dataset_id=self._dataset_id,
        processing_mode=self._processing_mode,
        address=self._address,
        protocol=self._protocol,
        job_name=self._job_name,
        consumer_index=self._consumer_index,
        num_consumers=self._num_consumers,
        max_outstanding_requests=self._max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        iteration_counter=(
            gen_experimental_dataset_ops.dummy_iteration_counter()),
        target_workers=target_workers,
        uncompress=uncompress,
        uncompress_fn=uncompress_func.function,
        cross_trainer_cache_options=cross_trainer_cache_options,
        **compat_kwargs,
        **self._flat_structure)
    super(_DataServiceDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` that executes its input through the tf.data service."""

  @functools.wraps(_DataServiceDatasetV2.__init__)
  def __init__(self, dataset_id, processing_mode, address, element_spec,
               protocol, data_transfer_protocol, job_name, consumer_index,
               num_consumers, max_outstanding_requests,
               task_refresh_interval_hint_ms, cross_trainer_cache,
               target_workers):

    self._wrapped = _DataServiceDatasetV2(
        dataset_id=dataset_id,
        processing_mode=processing_mode,
        address=address,
        element_spec=element_spec,
        protocol=protocol,
        data_transfer_protocol=data_transfer_protocol,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        cross_trainer_cache=cross_trainer_cache,
        target_workers=target_workers)
    super(_DataServiceDatasetV1, self).__init__(self._wrapped)


if tf2.enabled():
  _DataServiceDataset = _DataServiceDatasetV2
else:
  _DataServiceDataset = _DataServiceDatasetV1


def _parse_service(service):
  """Converts a tf.data service string into a (protocol, address) tuple.

  Args:
    service: A string in the format "protocol://address" or just "address". If
      the string is only an address, the default protocol will be used.

  Returns:
    The (protocol, address) tuple.
  """
  if not isinstance(service, str):
    raise ValueError("`service` must be a string, but `service` was of type "
                     f"{type(service)}. service={service}")
  if not service:
    raise ValueError("`service` must not be empty")
  parts = service.split("://")
  if len(parts) == 2:
    protocol, address = parts
  elif len(parts) == 1:
    address = parts[0]
    protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
  else:
    raise ValueError("Malformed `service` string has multiple '://': "
                     f"{service}.")
  # TODO(aaudibert): Consider validating reachability of address here.
  return (protocol, address)


def _distribute(processing_mode,
                service,
                job_name=None,
                consumer_index=None,
                num_consumers=None,
                max_outstanding_requests=None,
                task_refresh_interval_hint_ms=None,
                data_transfer_protocol=None,
                compression="AUTO",
                cross_trainer_cache=None,
                target_workers="AUTO"):
  """A transformation that moves dataset processing to the tf.data service.

  This transformation is similar to `distribute`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible for multiple datasets
      to share the same job. The default behavior is that the dataset creates
      anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
      the dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    compression: How to compress the dataset's elements before transferring
      them over the network. "AUTO" leaves the decision of how to compress up
      to the tf.data service runtime. `None` indicates not to compress.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, the
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copies if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  processing_mode = _get_validated_sharding_policy(processing_mode)
  _validate_compression(compression)

  def _apply_fn(dataset):  # pylint: disable=missing-docstring
    dataset_id = _register_dataset(service, dataset, compression=compression)
    return _from_dataset_id(
        processing_mode,
        service,
        dataset_id,
        dataset.element_spec,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        data_transfer_protocol=data_transfer_protocol,
        compression=compression,
        cross_trainer_cache=cross_trainer_cache,
        target_workers=target_workers)

  return _apply_fn


@tf_export("data.experimental.service.distribute")
def distribute(processing_mode,
               service,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               data_transfer_protocol=None,
               compression="AUTO",
               cross_trainer_cache=None,
               target_workers="AUTO"):
  """A transformation that moves dataset processing to the tf.data service.

  When you iterate over a dataset containing the `distribute` transformation,
  the tf.data service creates a "job" which produces data for the dataset
  iteration.

  The tf.data service uses a cluster of workers to prepare data for training
  your model. The `processing_mode` argument to
  `tf.data.experimental.service.distribute` describes how to leverage multiple
  workers to process the input dataset. Currently, there are two processing
  modes to choose from: "distributed_epoch" and "parallel_epochs".

  "distributed_epoch" means that the dataset will be split across all tf.data
  service workers. The dispatcher produces "splits" for the dataset and sends
  them to workers for further processing. For example, if a dataset begins
  with a list of filenames, the dispatcher will iterate through the filenames
  and send the filenames to tf.data workers, which will perform the rest of
  the dataset transformations on those files. "distributed_epoch" is useful
  when your model needs to see each element of the dataset exactly once, or if
  it needs to see the data in a generally-sequential order.
  "distributed_epoch" only works for datasets with splittable sources, such as
  `Dataset.from_tensor_slices`, `Dataset.list_files`, or `Dataset.range`.

  "parallel_epochs" means that the entire input dataset will be processed
  independently by each of the tf.data service workers. For this reason, it is
  important to shuffle data (e.g. filenames) non-deterministically, so that
  each worker will process the elements of the dataset in a different order.
  "parallel_epochs" can be used to distribute datasets that aren't splittable.

  With two workers, "parallel_epochs" will produce every element of the
  dataset twice:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> # Start two workers
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9]

  "distributed_epoch", on the other hand, will still produce each element
  once:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="distributed_epoch", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When using `apply(tf.data.experimental.service.distribute(...))`, the
  dataset before the `apply` transformation executes within the tf.data
  service, while the operations after `apply` happen within the local process.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(5)
  >>> dataset = dataset.map(lambda x: x*x)
  >>> dataset = dataset.apply(
  ...     tf.data.experimental.service.distribute("parallel_epochs",
  ...                                             dispatcher.target))
  >>> dataset = dataset.map(lambda x: x+1)
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [1, 1, 2, 2, 5, 5, 10, 10, 17, 17]

  In the above example, the dataset operations before the call to `distribute`
  will be executed on the tf.data workers, and the elements are provided over
  RPC. The remaining transformations (after the call to `distribute`) will be
  executed locally. The dispatcher and the workers will bind to unused free
  ports (which are chosen at random), in order to communicate with each other.
  However, to bind them to specific ports, the `port` parameter can be passed.

  The `job_name` argument allows jobs to be shared across multiple datasets.
  Instead of each dataset creating its own job, all datasets with the same
  `job_name` will consume from the same job. A new job will be created for
  each iteration of the dataset (with each repetition of `Dataset.repeat`
  counting as a new iteration). Suppose the `DispatchServer` is serving on
  `localhost:5000`, two training workers (in either a single-client or
  multi-client setup) iterate over the below dataset, and there is a single
  tf.data worker:

  ```
  range5_dataset = tf.data.Dataset.range(5)
  dataset = range5_dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "localhost:5000", job_name="my_job_name"))
  for iteration in range(3):
    print(list(dataset))
  ```

  The elements of each job will be split between the two processes, with
  elements being consumed by the processes on a first-come first-served basis.
  One possible result is that process 1 prints

  ```
  [0, 2, 4]
  [0, 1, 3]
  [1]
  ```

  and process 2 prints

  ```
  [1, 3]
  [2, 4]
  [0, 2, 3, 4]
  ```

  Job names must not be re-used across different training jobs within the
  lifetime of the tf.data service. In general, the tf.data service is expected
  to live for the duration of a single training job. To use the tf.data
  service with multiple training jobs, make sure to use different job names to
  avoid conflicts. For example, suppose a training job calls `distribute` with
  `job_name="job"` and reads until end of input. If another independent job
  connects to the same tf.data service and tries to read from
  `job_name="job"`, it will immediately receive end of input, without getting
  any data.

  **Coordinated data read**

  By default, when multiple consumers read from the same job, they receive
  data on a first-come first-served basis. In some use cases, it is
  advantageous to coordinate the consumers so that, at each step, all
  consumers read data from the same worker.

  For example, the tf.data service can be used to coordinate example sizes
  across a cluster during synchronous training, so that during each step all
  replicas train on similar-sized elements. To achieve this, define a dataset
  which generates rounds of `num_consumers` consecutive similar-sized batches,
  then enable coordinated reads by setting `consumer_index` and
  `num_consumers`, as sketched below.

  NOTE: To keep consumers in sync, round robin data consumption requires that
  the dataset have infinite cardinality. You can get this by adding
  `.repeat()` at the end of the dataset definition.
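
  A minimal sketch of enabling coordinated reads (the bucketing helper,
  `consumer_index()` function, and dispatcher address are placeholders; each
  trainer passes its own index):

  ```
  dataset = make_rounds_of_similar_sized_batches(NUM_CONSUMERS)  # hypothetical
  dataset = dataset.repeat()  # infinite cardinality is required
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
      service="grpc://dispatcher_address:5000",  # placeholder address
      job_name="shared_job",  # required when `num_consumers` is set
      consumer_index=consumer_index(),  # this trainer's index in [0, NUM_CONSUMERS)
      num_consumers=NUM_CONSUMERS))
  ```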

  **Keras and Distribution Strategies**

  The dataset produced by the `distribute` transformation can be passed to
  Keras' `Model.fit` or Distribution Strategy's
  `tf.distribute.Strategy.experimental_distribute_dataset` like any other
  `tf.data.Dataset`. We recommend setting a `job_name` on the call to
  `distribute` so that if there are multiple workers, they read data from the
  same job. Note that the autosharding normally performed by
  `experimental_distribute_dataset` will be disabled when setting a
  `job_name`, since sharing the job already results in splitting data across
  the workers. When using a shared job, data will be dynamically balanced
  across workers, so that they reach end of input at about the same time. This
  results in better worker utilization than with autosharding, where each
  worker processes an independent set of files, and some workers may run out
  of data earlier than others.
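
  A sketch of the recommended pattern (the strategy construction and
  dispatcher address are placeholders):

  ```
  strategy = tf.distribute.MirroredStrategy()  # or another strategy
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
      service="grpc://dispatcher_address:5000",  # placeholder address
      job_name="shared_job"))
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
  ```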

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible for multiple datasets
      to share the same job. The default behavior is that the dataset creates
      anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    compression: How to compress the dataset's elements before transferring
      them over the network. "AUTO" leaves the decision of how to compress up
      to the tf.data service runtime. `None` indicates not to compress.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, the
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copies if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  _validate_job_name(job_name)
  return _distribute(
      processing_mode=processing_mode,
      service=service,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol,
      compression=compression,
      cross_trainer_cache=cross_trainer_cache,
      target_workers=target_workers)


def _register_dataset(service, dataset, compression, dataset_id=None):
  """Registers a dataset with the tf.data service.

  This transformation is similar to `register_dataset`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: How to compress the dataset's elements before transferring
      them over the network. "AUTO" leaves the decision of how to compress up
      to the tf.data service runtime. `None` indicates not to compress.
    dataset_id: (Optional.) By default, tf.data service generates a unique
      (string) ID for each registered dataset. If a `dataset_id` is provided,
      it will use the specified ID. If a dataset with a matching ID already
      exists, no new dataset is registered. This is useful if multiple
      training jobs want to (re)use the same dataset for training. In this
      case, they can register the dataset with the same dataset ID.

  Returns:
    A scalar string tensor representing the dataset ID.
  """
  _validate_compression(compression)
  if isinstance(service, tuple):
    protocol, address = service
  else:
    protocol, address = _parse_service(service)
  external_state_policy = dataset.options().experimental_external_state_policy
  if external_state_policy is None:
    external_state_policy = ExternalStatePolicy.WARN

  encoded_spec = None
  if context.executing_eagerly():
    encoded_spec = nested_structure_coder.encode_structure(
        dataset.element_spec).SerializeToString()

  if compression == COMPRESSION_AUTO:
    dataset = dataset.map(
        lambda *x: compression_ops.compress(x),
        num_parallel_calls=dataset_ops.AUTOTUNE)
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access

  metadata = data_service_pb2.DataServiceMetadata(
      element_spec=encoded_spec,
      compression=_get_compression_proto(compression))

  return gen_experimental_dataset_ops.register_dataset_v2(
      dataset._variant_tensor,  # pylint: disable=protected-access
      address=address,
      protocol=protocol,
      external_state_policy=external_state_policy.value,
      requested_dataset_id=dataset_id,
      metadata=metadata.SerializeToString())


@tf_export("data.experimental.service.register_dataset")
def register_dataset(service, dataset, compression="AUTO", dataset_id=None):
  """Registers a dataset with the tf.data service.

  `register_dataset` registers a dataset with the tf.data service so that
  datasets can be created later with
  `tf.data.experimental.service.from_dataset_id`. This is useful when the
  dataset is registered by one process, then used in another process. When the
  same process is both registering and reading from the dataset, it is simpler
  to use `tf.data.experimental.service.distribute` instead.

  If the dataset is already registered with the tf.data service,
  `register_dataset` returns the already-registered dataset's id.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
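
  To let several training jobs (re)use one registration, an explicit ID can be
  supplied; a minimal sketch reusing the `dispatcher` and `dataset` from above
  (the ID string is an arbitrary placeholder):

  ```
  dataset_id = tf.data.experimental.service.register_dataset(
      dispatcher.target, dataset, dataset_id="shared_dataset")
  ```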

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: (Optional.) How to compress the dataset's elements before
      transferring them over the network. "AUTO" leaves the decision of how to
      compress up to the tf.data service runtime. `None` indicates not to
      compress.
    dataset_id: (Optional.) By default, tf.data service generates a unique
      (string) ID for each registered dataset. If a `dataset_id` is provided,
      it will use the specified ID. If a dataset with a matching ID already
      exists, no new dataset is registered. This is useful if multiple
      training jobs want to (re)use the same dataset for training. In this
      case, they can register the dataset with the same dataset ID.

  Returns:
    A scalar string tensor representing the dataset ID.
  """
  return _register_dataset(service, dataset, compression, dataset_id)


def _from_dataset_id(processing_mode,
                     service,
                     dataset_id,
                     element_spec,
                     job_name=None,
                     consumer_index=None,
                     num_consumers=None,
                     max_outstanding_requests=None,
                     task_refresh_interval_hint_ms=None,
                     data_transfer_protocol=None,
                     compression="AUTO",
                     cross_trainer_cache=None,
                     target_workers="AUTO"):
  """Creates a dataset which reads data from the tf.data service.

  This transformation is similar to `from_dataset_id`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type
      of elements produced by the dataset. This argument is only required
      inside a tf.function. Use `tf.data.Dataset.element_spec` to get the
      element spec for a given dataset.
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string or tensor. This argument makes it possible for multiple
      datasets to share the same job. The default behavior is that the dataset
      creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
      the dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    compression: An indication of how the dataset's elements were compressed,
      so that `from_dataset_id` can uncompress them if necessary.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, the
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copies if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """

  def _get_element_spec():
    """Fetches the element spec from the server."""
    data_service_metadata = None
    dataset_id_val = tensor_util.constant_value(dataset_id)
    try:
      data_service_metadata = (
          _pywrap_server_lib.TF_DATA_GetDataServiceMetadataByID(
              dataset_id_val, address, protocol
          )
      )
    except NotImplementedError as err:
      raise ValueError(
          "The tf.data service is running an earlier version of TensorFlow "
          "that requires specifying `element_spec` as an argument to "
          "`from_dataset_id`. Please either supply an element spec or update "
          "the tf.data service to the latest version.") from err
    except RuntimeError:
      # This error results from the dataset ID not being found. A more
      # appropriate error will be raised when the dataset is created.
      pass

    if not data_service_metadata or not data_service_metadata.element_spec:
      dataset_id_val = tensor_util.constant_value(dataset_id)
      raise ValueError(
          f"Failed to fetch element spec for dataset id {dataset_id_val} from "
          "tf.data service. If the dataset was registered in graph mode or "
          "inside a tf.function, the `element_spec` must be specified as an "
          "argument to `from_dataset_id`.")

    struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
    struct_pb.ParseFromString(data_service_metadata.element_spec)
    return nested_structure_coder.decode_proto(struct_pb)

  processing_mode = _get_validated_sharding_policy(processing_mode)
  if isinstance(service, tuple):
    protocol, address = service
  else:
    protocol, address = _parse_service(service)
  _validate_compression(compression)
  if job_name is not None:
    if not isinstance(job_name, str) and not isinstance(job_name, ops.Tensor):
      raise ValueError(
          "`job_name` must be a string or Tensor, but `job_name` was of type "
          f"{type(job_name)}. job_name={job_name}.")

  if not element_spec:
    if not context.executing_eagerly():
      raise ValueError(
          "In graph mode `element_spec` must be provided manually.")
    element_spec = _get_element_spec()

  dataset = _DataServiceDataset(
      dataset_id=dataset_id,
      processing_mode=processing_mode,
      address=address,
      element_spec=element_spec,
      protocol=protocol,
      data_transfer_protocol=data_transfer_protocol,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
      cross_trainer_cache=cross_trainer_cache,
      target_workers=target_workers)

  # Disable autosharding for shared jobs.
  if job_name is not None:
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
    dataset = dataset.with_options(options)
  return dataset


@tf_export("data.experimental.service.from_dataset_id")
def from_dataset_id(processing_mode,
                    service,
                    dataset_id,
                    element_spec=None,
                    job_name=None,
                    consumer_index=None,
                    num_consumers=None,
                    max_outstanding_requests=None,
                    data_transfer_protocol=None,
                    cross_trainer_cache=None,
                    target_workers="AUTO"):
  """Creates a dataset which reads data from the tf.data service.

  This is useful when the dataset is registered by one process, then used in
  another process. When the same process is both registering and reading from
  the dataset, it is simpler to use `tf.data.experimental.service.distribute`
  instead.

  Before using `from_dataset_id`, the dataset must have been registered with
  the tf.data service using `tf.data.experimental.service.register_dataset`.
  `register_dataset` returns a dataset id for the registered dataset. That is
  the `dataset_id` which should be passed to `from_dataset_id`.

  The `element_spec` argument indicates the `tf.TypeSpec`s for the elements
  produced by the dataset. Currently `element_spec` must be explicitly
  specified, and match the dataset registered under `dataset_id`.
  `element_spec` defaults to `None` so that in the future we can support
  automatically discovering the `element_spec` by querying the tf.data
  service.

  `tf.data.experimental.service.distribute` is a convenience method which
  combines `register_dataset` and `from_dataset_id` into a dataset
  transformation. See the documentation for
  `tf.data.experimental.service.distribute` for more detail about how
  `from_dataset_id` works.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
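
  When the dataset is consumed from inside a `tf.function` (for example, in a
  custom training loop), `element_spec` has to be passed explicitly. A minimal
  sketch, reusing the `dispatcher` and `dataset_id` created above and assuming
  the registered dataset produces scalar `int64` elements:

  ```
  @tf.function
  def read_from_service():
    ds = tf.data.experimental.service.from_dataset_id(
        processing_mode="parallel_epochs",
        service=dispatcher.target,
        dataset_id=dataset_id,
        element_spec=tf.TensorSpec(shape=(), dtype=tf.int64))
    return ds.reduce(tf.constant(0, tf.int64), lambda state, x: state + x)
  ```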

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type
      of elements produced by the dataset. This argument is only required
      inside a tf.function. Use `tf.data.Dataset.element_spec` to get the
      element spec for a given dataset.
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible for multiple datasets
      to share the same job. The default behavior is that the dataset creates
      anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, the
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copies if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
  _validate_job_name(job_name)
  if job_name is not None:
    # Prefix `job_name` with the dataset id so that jobs with the same name
    # but different datasets do not collide.
    job_name = string_ops.string_join(
        ["dataset_id=", _to_string(dataset_id), job_name], "/")

  return _from_dataset_id(
      processing_mode=processing_mode,
      service=service,
      dataset_id=dataset_id,
      element_spec=element_spec,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol,
      cross_trainer_cache=cross_trainer_cache,
      target_workers=target_workers)