# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python API for executing a tf.data.Dataset using a tf.data service."""

import enum
import functools

from tensorflow.core.protobuf import data_service_pb2
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.data.experimental.service import _pywrap_utils
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.ops.options import AutoShardPolicy
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_AUTO = "AUTO"
COMPRESSION_NONE = None
_PARALLEL_EPOCHS = "parallel_epochs"
_DISTRIBUTED_EPOCH = "distributed_epoch"


@tf_export("data.experimental.service.ShardingPolicy")
class ShardingPolicy(enum.IntEnum):
  """Specifies how to shard data among tf.data service workers.

  OFF: No sharding will be performed. Each worker produces the entire dataset
  without any sharding. With this mode, the best practice is to shuffle the
  dataset nondeterministically so that workers process the dataset in
  different orders. If workers are restarted or join the cluster mid-job, they
  will begin processing the dataset from the beginning.

  DYNAMIC: The input dataset is dynamically split among workers at runtime.
  Each worker gets the next split when it reads data from the dispatcher. Data
  is produced non-deterministically in this mode. Dynamic sharding works well
  with varying-sized tf.data service clusters, e.g., when you need to
  auto-scale your workers. Dynamic sharding provides at-most-once visitation
  guarantees. No examples will be repeated, but some may be missed if a
  tf.data service worker gets restarted while processing a file.

  The following are static sharding policies. The semantics are similar to
  `tf.data.experimental.AutoShardPolicy`. These policies require:
  * The tf.data service cluster is configured with a fixed list of workers
    in DispatcherConfig.
  * Each client only reads from the local tf.data service worker.

  If a worker is restarted while performing static sharding, the worker will
  begin processing its shard again from the beginning.

  FILE: Shards by input files (i.e. each worker will get a fixed set of files
  to process). When this option is selected, make sure that there are at least
  as many files as workers. If there are fewer input files than workers, a
  runtime error will be raised.

  DATA: Shards by elements produced by the dataset. Each worker will process
  the whole dataset and discard the portion that is not for itself. Note that
  for this mode to correctly partition the dataset elements, the dataset needs
  to produce elements in a deterministic order.

  FILE_OR_DATA: Attempts FILE-based sharding, falling back to DATA-based
  sharding on failure.

  HINT: Looks for the presence of `shard(SHARD_HINT, ...)`, which is treated
  as a placeholder to replace with `shard(num_workers, worker_index)`.
  """

  # LINT.IfChange(tf_data_service_sharding_policy)
  OFF = 0
  DYNAMIC = 1
  FILE = 2
  DATA = 3
  FILE_OR_DATA = 4
  HINT = 5
  # LINT.ThenChange()

  def _to_proto(self):
    """Converts the policy to ProcessingModeDef proto enum."""

    if self == ShardingPolicy.OFF:
      return data_service_pb2.ProcessingModeDef.OFF
    if self == ShardingPolicy.DYNAMIC:
      return data_service_pb2.ProcessingModeDef.DYNAMIC
    if self == ShardingPolicy.FILE:
      return data_service_pb2.ProcessingModeDef.FILE
    if self == ShardingPolicy.DATA:
      return data_service_pb2.ProcessingModeDef.DATA
    if self == ShardingPolicy.FILE_OR_DATA:
      return data_service_pb2.ProcessingModeDef.FILE_OR_DATA
    if self == ShardingPolicy.HINT:
      return data_service_pb2.ProcessingModeDef.HINT
    raise ValueError(f"Unable to convert sharding policy {self!r} to proto.")
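

# Illustrative sketch, not part of the original module: how a `ShardingPolicy`
# value maps to its `ProcessingModeDef` proto enum. The LINT block above keeps
# the two enums in sync, so the mapping is one-to-one.
def _example_sharding_policy_to_proto():
  # `_to_proto()` returns the matching proto enum value for each policy.
  policy = ShardingPolicy.DYNAMIC
  assert policy._to_proto() == data_service_pb2.ProcessingModeDef.DYNAMIC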


@tf_export("data.experimental.service.CrossTrainerCache")
class CrossTrainerCache:
  """Options related to the tf.data service cross trainer cache.

  This is used to enable cross-trainer cache when distributing a dataset. For
  example:

  ```
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
      service=FLAGS.tf_data_service_address,
      job_name="job",
      cross_trainer_cache=data_service_ops.CrossTrainerCache(
          trainer_id=trainer_id())))
  ```

  For more details, refer to
  https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers.
  """

  def __init__(self, trainer_id):
    """Constructs a CrossTrainerCache.

    Args:
      trainer_id: Each training job has a unique ID. Once a job has consumed
        data, the data remains in the cache and is re-used by jobs with
        different `trainer_id`s. Requests with the same `trainer_id` do not
        re-use data.

    Raises:
      ValueError if `trainer_id` is empty.
    """
    if not trainer_id:
      raise ValueError(
          "tf.data service cross-trainer cache requires a non-empty trainer "
          "ID.")
    self.trainer_id = trainer_id

  def _to_proto(self):
    return data_service_pb2.CrossTrainerCacheOptions(trainer_id=self.trainer_id)
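

# Hedged usage sketch, not part of the original module: constructing a
# `CrossTrainerCache` and serializing its options, mirroring what
# `_DataServiceDatasetV2.__init__` does below. The trainer ID here is a
# made-up placeholder.
def _example_cross_trainer_cache():
  cache = CrossTrainerCache(trainer_id="trainer_0")
  # An empty `trainer_id` would raise ValueError in the constructor.
  return cache._to_proto().SerializeToString()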


def _get_validated_sharding_policy(processing_mode):
  """Validates `processing_mode` and converts it to ShardingPolicy."""

  if isinstance(processing_mode, ShardingPolicy):
    return processing_mode
  if processing_mode == _PARALLEL_EPOCHS:
    return ShardingPolicy.OFF
  if processing_mode == _DISTRIBUTED_EPOCH:
    return ShardingPolicy.DYNAMIC

  raise ValueError("tf.data service processing mode should be a "
                   "`tf.data.experimental.service.ShardingPolicy`, "
                   "`\"parallel_epochs\"`, or `\"distributed_epoch\"`. Got "
                   f"{processing_mode!r}.")


def _validate_job_name(job_name):
  if job_name is None:
    return
  if not isinstance(job_name, str):
    raise ValueError("`job_name` must be a string, but `job_name` was of type "
                     f"{type(job_name)}. job_name={job_name}")
  if not job_name:
    raise ValueError("`job_name` must not be empty")


def _validate_compression(compression):
  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
  if compression not in valid_compressions:
    raise ValueError(f"Invalid `compression` argument: {compression}. "
                     f"Must be one of {valid_compressions}.")


def _get_compression_proto(compression):
  if compression == COMPRESSION_AUTO:
    return data_service_pb2.DataServiceMetadata.COMPRESSION_SNAPPY
  if compression == COMPRESSION_NONE:
    return data_service_pb2.DataServiceMetadata.COMPRESSION_OFF
  raise ValueError(f"Invalid `compression` argument: {compression}. "
                   f"Must be one of {[COMPRESSION_AUTO, COMPRESSION_NONE]}.")


def _to_tensor(dataset_id):
  """Converts `dataset_id` to Tensor."""

  if isinstance(dataset_id, ops.Tensor):
    return dataset_id
  if isinstance(dataset_id, (str, bytes)):
    return ops.convert_to_tensor(
        dataset_id, dtype=dtypes.string, name="dataset_id")
  return ops.convert_to_tensor(
      dataset_id, dtype=dtypes.int64, name="dataset_id")


def _to_string(dataset_id):
  """Converts `dataset_id` to string."""

  if isinstance(dataset_id, ops.Tensor):
    return (dataset_id if dataset_id.dtype == dtypes.string else
            string_ops.as_string(dataset_id))
  return (dataset_id.decode()
          if isinstance(dataset_id, bytes) else str(dataset_id))
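

# Illustration, not part of the original module: `_to_tensor` infers the
# tensor dtype from the Python type of `dataset_id`, while `_to_string` goes
# the other way when building job names. Assumes eager execution.
def _example_dataset_id_conversions():
  assert _to_tensor("ds-1").dtype == dtypes.string  # str/bytes -> tf.string.
  assert _to_tensor(42).dtype == dtypes.int64       # ints -> tf.int64.
  assert _to_string(b"ds-1") == "ds-1"              # bytes are decoded.
  assert _to_string(42) == "42"                     # non-tensors -> str().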


class _DataServiceDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` that reads elements from the tf.data service."""

  def __init__(self,
               dataset_id,
               processing_mode,
               address,
               element_spec,
               protocol,
               data_transfer_protocol,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None,
               cross_trainer_cache=None,
               target_workers="AUTO"):
    """Constructs a _DataServiceDatasetV2.

    Args:
      dataset_id: The dataset id for the dataset to read from.
      processing_mode: A `tf.data.experimental.service.ShardingPolicy`
        specifying how to shard the dataset among tf.data workers. See
        `tf.data.experimental.service.ShardingPolicy` for details. For
        backwards compatibility, `processing_mode` may also be set to the
        strings `"parallel_epochs"` or `"distributed_epoch"`, which are
        respectively equivalent to `ShardingPolicy.OFF` and
        `ShardingPolicy.DYNAMIC`.
      address: The tf.data service address, e.g. "localhost:5000".
      element_spec: The dataset element spec for the dataset to read from.
      protocol: The protocol to use for communicating with the tf.data
        service, e.g. "grpc".
      data_transfer_protocol: (Optional.) The protocol to use for transferring
        data with the tf.data service. By default, data is transferred using
        gRPC.
      job_name: (Optional.) The name of the job. If provided, it must be a
        non-empty string or Tensor. This argument makes it possible for
        multiple datasets to share the same job. The default behavior is that
        the dataset creates anonymous, exclusively owned jobs.
      consumer_index: (Optional.) The index of the consumer in the range from
        `0` to `num_consumers`. Must be specified alongside `num_consumers`.
        When specified, consumers will read from the job in a strict
        round-robin order, instead of the default first-come-first-served
        order.
      num_consumers: (Optional.) The number of consumers which will consume
        from the job. Must be specified alongside `consumer_index`. When
        specified, consumers will read from the job in a strict round-robin
        order, instead of the default first-come-first-served order. When
        `num_consumers` is specified, the dataset must have infinite
        cardinality to prevent a producer from running out of data early and
        causing consumers to go out of sync.
      max_outstanding_requests: (Optional.) A limit on how many elements may
        be requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than
        `element_size` * `max_outstanding_requests` of memory.
      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
        the dispatcher for task changes.
      cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
        provided, dataset iteration will be shared across concurrently running
        trainers. See
        https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
        for details.
      target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
        tf.data runtime decides which workers to read from. If `"ANY"`, reads
        from any tf.data service workers. If `"LOCAL"`, only reads from local
        in-process tf.data service workers. `"AUTO"` works well for most
        cases, while users can specify other targets. For example, `"LOCAL"`
        helps avoid RPCs and data copy if every TF worker colocates with a
        tf.data service worker. Consumers of a shared job must use the same
        `target_workers`. Defaults to `"AUTO"`.
    """
    if (consumer_index is None) != (num_consumers is None):
      raise ValueError(
          "Must either set both `consumer_index` and `num_consumers`, "
          "or neither. "
          f"consumer_index={consumer_index}, num_consumers={num_consumers}")
    if num_consumers is not None and job_name is None:
      raise ValueError("`job_name` must be set when setting `num_consumers`. "
                       f"num_consumers was set to {num_consumers}.")

    processing_mode_def = data_service_pb2.ProcessingModeDef(
        sharding_policy=_get_validated_sharding_policy(
            processing_mode)._to_proto())
    if job_name is None:
      job_name = ""
    if max_outstanding_requests is None:
      max_outstanding_requests = dataset_ops.AUTOTUNE
    if task_refresh_interval_hint_ms is None:
      task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

    self._dataset_id = _to_tensor(dataset_id)
    self._processing_mode = ops.convert_to_tensor(
        processing_mode_def.SerializeToString(),
        dtype=dtypes.string,
        name="processing_mode")
    self._address = ops.convert_to_tensor(
        address, dtype=dtypes.string, name="address")
    self._protocol = ops.convert_to_tensor(
        protocol, dtype=dtypes.string, name="protocol")
    self._job_name = ops.convert_to_tensor(
        job_name, dtype=dtypes.string, name="job_name")
    self._consumer_index = ops.convert_to_tensor(
        -1 if consumer_index is None else consumer_index,
        dtype=dtypes.int64,
        name="consumer_index")
    self._num_consumers = ops.convert_to_tensor(
        -1 if num_consumers is None else num_consumers,
        dtype=dtypes.int64,
        name="num_consumers")
    self._max_outstanding_requests = ops.convert_to_tensor(
        max_outstanding_requests,
        dtype=dtypes.int64,
        name="max_outstanding_requests")
    self._element_spec = element_spec
    uncompress_func = structured_function.StructuredFunctionWrapper(
        lambda x: compression_ops.uncompress(x, output_spec=element_spec),
        transformation_name="DataServiceDataset.uncompress()",
        input_structure=tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant))
    cross_trainer_cache_options = (
        cross_trainer_cache._to_proto().SerializeToString()
        if cross_trainer_cache else None)

    compat_kwargs = {}
    if data_transfer_protocol is not None:
      compat_kwargs["data_transfer_protocol"] = data_transfer_protocol

    # If `uncompress` is `True`, the dataset will query the servers to find
    # out the actual compression used. It is always set to `True` the first
    # time the graph is built, and set to `False` when serializing, so we will
    # uncompress at most once.
    uncompress = True
    variant_tensor = gen_experimental_dataset_ops.data_service_dataset_v4(
        dataset_id=self._dataset_id,
        processing_mode=self._processing_mode,
        address=self._address,
        protocol=self._protocol,
        job_name=self._job_name,
        consumer_index=self._consumer_index,
        num_consumers=self._num_consumers,
        max_outstanding_requests=self._max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        iteration_counter=(
            gen_experimental_dataset_ops.dummy_iteration_counter()),
        target_workers=target_workers,
        uncompress=uncompress,
        uncompress_fn=uncompress_func.function,
        cross_trainer_cache_options=cross_trainer_cache_options,
        **compat_kwargs,
        **self._flat_structure)
    super(_DataServiceDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` that executes its input through the tf.data service."""

  @functools.wraps(_DataServiceDatasetV2.__init__)
  def __init__(self, dataset_id, processing_mode, address, element_spec,
               protocol, data_transfer_protocol, job_name, consumer_index,
               num_consumers, max_outstanding_requests,
               task_refresh_interval_hint_ms, cross_trainer_cache,
               target_workers):

    self._wrapped = _DataServiceDatasetV2(
        dataset_id=dataset_id,
        processing_mode=processing_mode,
        address=address,
        element_spec=element_spec,
        protocol=protocol,
        data_transfer_protocol=data_transfer_protocol,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        cross_trainer_cache=cross_trainer_cache,
        target_workers=target_workers)
    super(_DataServiceDatasetV1, self).__init__(self._wrapped)


if tf2.enabled():
  _DataServiceDataset = _DataServiceDatasetV2
else:
  _DataServiceDataset = _DataServiceDatasetV1


def _parse_service(service):
  """Converts a tf.data service string into a (protocol, address) tuple.

  Args:
    service: A string in the format "protocol://address" or just "address". If
      the string is only an address, the default protocol will be used.

  Returns:
    The (protocol, address) tuple.
  """
  if not isinstance(service, str):
    raise ValueError("`service` must be a string, but `service` was of type "
                     f"{type(service)}. service={service}")
  if not service:
    raise ValueError("`service` must not be empty")
  parts = service.split("://")
  if len(parts) == 2:
    protocol, address = parts
  elif len(parts) == 1:
    address = parts[0]
    protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
  else:
    raise ValueError("Malformed `service` string has multiple '://': "
                     f"{service}.")
  # TODO(aaudibert): Consider validating reachability of address here.
  return (protocol, address)
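

# Illustration, not part of the original module: how `_parse_service` splits a
# service string. When no protocol is given, the default comes from the pywrap
# binding above (typically "grpc", though that is an assumption here).
def _example_parse_service():
  assert _parse_service("grpc://localhost:5000") == ("grpc", "localhost:5000")
  protocol, address = _parse_service("localhost:5000")  # Protocol defaulted.
  assert address == "localhost:5000"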


def _distribute(processing_mode,
                service,
                job_name=None,
                consumer_index=None,
                num_consumers=None,
                max_outstanding_requests=None,
                task_refresh_interval_hint_ms=None,
                data_transfer_protocol=None,
                compression="AUTO",
                cross_trainer_cache=None,
                target_workers="AUTO"):
  """A transformation that moves dataset processing to the tf.data service.

  This transformation is similar to `distribute`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy`
      specifying how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible for multiple datasets
      to share the same job. The default behavior is that the dataset creates
      anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
      the dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    compression: How to compress the dataset's elements before transferring
      them over the network. "AUTO" leaves the decision of how to compress up
      to the tf.data service runtime. `None` indicates not to compress.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copy if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  processing_mode = _get_validated_sharding_policy(processing_mode)
  _validate_compression(compression)

  def _apply_fn(dataset):  # pylint: disable=missing-docstring
    dataset_id = _register_dataset(service, dataset, compression=compression)
    return _from_dataset_id(
        processing_mode,
        service,
        dataset_id,
        dataset.element_spec,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        data_transfer_protocol=data_transfer_protocol,
        compression=compression,
        cross_trainer_cache=cross_trainer_cache,
        target_workers=target_workers)

  return _apply_fn
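

# Sketch, not part of the original module, of how the transformation returned
# by `_distribute` is meant to be consumed: `dataset.apply` calls the returned
# `_apply_fn` with the upstream dataset. The dispatcher address below is a
# placeholder; applying the transform requires a reachable tf.data service.
def _example_apply_distribute(dataset):
  transform = _distribute(
      processing_mode=ShardingPolicy.OFF,
      service="grpc://localhost:5000")  # Hypothetical dispatcher address.
  return dataset.apply(transform)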


@tf_export("data.experimental.service.distribute")
def distribute(processing_mode,
               service,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               data_transfer_protocol=None,
               compression="AUTO",
               cross_trainer_cache=None,
               target_workers="AUTO"):
  """A transformation that moves dataset processing to the tf.data service.

  When you iterate over a dataset containing the `distribute` transformation,
  the tf.data service creates a "job" which produces data for the dataset
  iteration.

  The tf.data service uses a cluster of workers to prepare data for training
  your model. The `processing_mode` argument to
  `tf.data.experimental.service.distribute` describes how to leverage multiple
  workers to process the input dataset. Currently, there are two processing
  modes to choose from: "distributed_epoch" and "parallel_epochs".

  "distributed_epoch" means that the dataset will be split across all tf.data
  service workers. The dispatcher produces "splits" for the dataset and sends
  them to workers for further processing. For example, if a dataset begins
  with a list of filenames, the dispatcher will iterate through the filenames
  and send the filenames to tf.data workers, which will perform the rest of
  the dataset transformations on those files. "distributed_epoch" is useful
  when your model needs to see each element of the dataset exactly once, or if
  it needs to see the data in a generally-sequential order. "distributed_epoch"
  only works for datasets with splittable sources, such as
  `Dataset.from_tensor_slices`, `Dataset.list_files`, or `Dataset.range`.

  "parallel_epochs" means that the entire input dataset will be processed
  independently by each of the tf.data service workers. For this reason, it is
  important to shuffle data (e.g. filenames) non-deterministically, so that
  each worker will process the elements of the dataset in a different order.
  "parallel_epochs" can be used to distribute datasets that aren't splittable.

  With two workers, "parallel_epochs" will produce every element of the
  dataset twice:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> # Start two workers
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9]

  "distributed_epoch", on the other hand, will still produce each element
  once:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="distributed_epoch", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When using `apply(tf.data.experimental.service.distribute(...))`, the
  dataset before the `apply` transformation executes within the tf.data
  service, while the operations after `apply` happen within the local process.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(5)
  >>> dataset = dataset.map(lambda x: x*x)
  >>> dataset = dataset.apply(
  ...    tf.data.experimental.service.distribute("parallel_epochs",
  ...                                            dispatcher.target))
  >>> dataset = dataset.map(lambda x: x+1)
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [1, 1, 2, 2, 5, 5, 10, 10, 17, 17]

  In the above example, the dataset operations (before applying the
  `distribute` function on the elements) will be executed on the tf.data
  workers, and the elements are provided over RPC. The remaining
  transformations (after the call to `distribute`) will be executed locally.
  The dispatcher and the workers will bind to unused ports (chosen at random)
  in order to communicate with each other. However, to bind them to specific
  ports, the `port` parameter can be passed.

  The `job_name` argument allows jobs to be shared across multiple
  datasets. Instead of each dataset creating its own job, all
  datasets with the same `job_name` will consume from the same job. A new job
  will be created for each iteration of the dataset (with each repetition of
  `Dataset.repeat` counting as a new iteration). Suppose the `DispatchServer`
  is serving on `localhost:5000` and two training workers (in either a single
  client or multi-client setup) iterate over the below dataset, and there is a
  single tf.data worker:

  ```
  range5_dataset = tf.data.Dataset.range(5)
  dataset = range5_dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "localhost:5000", job_name="my_job_name"))
  for iteration in range(3):
    print(list(dataset))
  ```

  The elements of each job will be split between the two processes, with
  elements being consumed by the processes on a first-come first-served basis.
  One possible result is that process 1 prints

  ```
  [0, 2, 4]
  [0, 1, 3]
  [1]
  ```

  and process 2 prints

  ```
  [1, 3]
  [2, 4]
  [0, 2, 3, 4]
  ```

  Job names must not be re-used across different training jobs within the
  lifetime of the tf.data service. In general, the tf.data service is expected
  to live for the duration of a single training job.
  To use the tf.data service with multiple training jobs, make sure to use
  different job names to avoid conflicts. For example, suppose a training job
  calls `distribute` with `job_name="job"` and reads until end of input. If
  another independent job connects to the same tf.data service and tries to
  read from `job_name="job"`, it will immediately receive end of input,
  without getting any data.

  **Coordinated data read**

  By default, when multiple consumers read from the same job, they receive
  data on a first-come first-served basis. In some use cases, it is
  advantageous to coordinate the consumers so that, at each step, they read
  data from the same worker.

  For example, the tf.data service can be used to coordinate example sizes
  across a cluster during synchronous training, so that during each step all
  replicas train on similar-sized elements. To achieve this, define a dataset
  which generates rounds of `num_consumers` consecutive similar-sized batches,
  then enable coordinated reads by setting `consumer_index` and
  `num_consumers`.

  NOTE: To keep consumers in sync, round robin data consumption requires that
  the dataset have infinite cardinality. You can get this by adding
  `.repeat()` at the end of the dataset definition.

  **Keras and Distribution Strategies**

  The dataset produced by the `distribute` transformation can be passed to
  Keras' `Model.fit` or Distribution Strategy's
  `tf.distribute.Strategy.experimental_distribute_dataset` like any other
  `tf.data.Dataset`. We recommend setting a `job_name` on the call to
  `distribute` so that if there are multiple workers, they read data from the
  same job. Note that the autosharding normally performed by
  `experimental_distribute_dataset` will be disabled when setting a
  `job_name`, since sharing the job already results in splitting data across
  the workers. When using a shared job, data will be dynamically balanced
  across workers, so that they reach end of input at about the same time. This
  results in better worker utilization than with autosharding, where each
  worker processes an independent set of files, and some workers may run out
  of data earlier than others.

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy`
      specifying how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible for multiple datasets
      to share the same job. The default behavior is that the dataset creates
      anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    compression: How to compress the dataset's elements before transferring
      them over the network. "AUTO" leaves the decision of how to compress up
      to the tf.data service runtime. `None` indicates not to compress.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copy if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  _validate_job_name(job_name)
  return _distribute(
      processing_mode=processing_mode,
      service=service,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol,
      compression=compression,
      cross_trainer_cache=cross_trainer_cache,
      target_workers=target_workers)


def _register_dataset(service, dataset, compression, dataset_id=None):
  """Registers a dataset with the tf.data service.

  This transformation is similar to `register_dataset`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: How to compress the dataset's elements before transferring
      them over the network. "AUTO" leaves the decision of how to compress up
      to the tf.data service runtime. `None` indicates not to compress.
    dataset_id: (Optional.) By default, tf.data service generates a unique
      (string) ID for each registered dataset. If a `dataset_id` is provided,
      it will use the specified ID. If a dataset with a matching ID already
      exists, no new dataset is registered. This is useful if multiple
      training jobs want to (re)use the same dataset for training. In this
      case, they can register the dataset with the same dataset ID.

  Returns:
    A scalar string tensor representing the dataset ID.
  """
  _validate_compression(compression)
  if isinstance(service, tuple):
    protocol, address = service
  else:
    protocol, address = _parse_service(service)
  external_state_policy = dataset.options().experimental_external_state_policy
  if external_state_policy is None:
    external_state_policy = ExternalStatePolicy.WARN

  encoded_spec = None
  if context.executing_eagerly():
    encoded_spec = nested_structure_coder.encode_structure(
        dataset.element_spec).SerializeToString()

  if compression == COMPRESSION_AUTO:
    dataset = dataset.map(
        lambda *x: compression_ops.compress(x),
        num_parallel_calls=dataset_ops.AUTOTUNE)
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access

  metadata = data_service_pb2.DataServiceMetadata(
      element_spec=encoded_spec,
      compression=_get_compression_proto(compression))

  return gen_experimental_dataset_ops.register_dataset_v2(
      dataset._variant_tensor,  # pylint: disable=protected-access
      address=address,
      protocol=protocol,
      external_state_policy=external_state_policy.value,
      requested_dataset_id=dataset_id,
      metadata=metadata.SerializeToString())


@tf_export("data.experimental.service.register_dataset")
def register_dataset(service, dataset, compression="AUTO", dataset_id=None):
  """Registers a dataset with the tf.data service.

  `register_dataset` registers a dataset with the tf.data service so that
  datasets can be created later with
  `tf.data.experimental.service.from_dataset_id`. This is useful when the
  dataset is registered by one process, then used in another process. When the
  same process is both registering and reading from the dataset, it is simpler
  to use `tf.data.experimental.service.distribute` instead.

  If the dataset is already registered with the tf.data service,
  `register_dataset` returns the already-registered dataset's id.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: (Optional.) How to compress the dataset's elements before
      transferring them over the network. "AUTO" leaves the decision of how to
      compress up to the tf.data service runtime. `None` indicates not to
      compress.
    dataset_id: (Optional.) By default, tf.data service generates a unique
      (string) ID for each registered dataset. If a `dataset_id` is provided,
      it will use the specified ID. If a dataset with a matching ID already
      exists, no new dataset is registered. This is useful if multiple
      training jobs want to (re)use the same dataset for training. In this
      case, they can register the dataset with the same dataset ID.

  Returns:
    A scalar string tensor representing the dataset ID.
  """
  return _register_dataset(service, dataset, compression, dataset_id)


def _from_dataset_id(processing_mode,
                     service,
                     dataset_id,
                     element_spec,
                     job_name=None,
                     consumer_index=None,
                     num_consumers=None,
                     max_outstanding_requests=None,
                     task_refresh_interval_hint_ms=None,
                     data_transfer_protocol=None,
                     compression="AUTO",
                     cross_trainer_cache=None,
                     target_workers="AUTO"):
  """Creates a dataset which reads data from the tf.data service.

  This transformation is similar to `from_dataset_id`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy`
      specifying how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type
      of elements produced by the dataset. This argument is only required
      inside a tf.function. Use `tf.data.Dataset.element_spec` to get the
      element spec for a given dataset.
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string or tensor. This argument makes it possible for multiple
      datasets to share the same job. The default behavior is that the dataset
      creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
      the dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    compression: An indication of how the dataset's elements were compressed,
      so that `from_dataset_id` can uncompress them if necessary.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copy if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """

  def _get_element_spec():
    """Fetches the element spec from the server."""
    data_service_metadata = None
    dataset_id_val = tensor_util.constant_value(dataset_id)
    try:
      data_service_metadata = (
          _pywrap_server_lib.TF_DATA_GetDataServiceMetadataByID(
              dataset_id_val, address, protocol
          )
      )
    except NotImplementedError as err:
      raise ValueError(
          "The tf.data service is running an earlier version of TensorFlow "
          "that requires specifying `element_spec` as an argument to "
          "`from_dataset_id`. Please either supply an element spec or update "
          "the tf.data service to the latest version.") from err
    except RuntimeError:
      # This error results from the dataset ID not being found. A more
      # appropriate error will be raised when the dataset is created.
      pass

    if not data_service_metadata or not data_service_metadata.element_spec:
      raise ValueError(
          f"Failed to fetch element spec for dataset id {dataset_id_val} from "
          "tf.data service. If the dataset was registered in graph mode or "
          "inside a tf.function, the `element_spec` must be specified as an "
          "argument to `from_dataset_id`.")

    struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
    struct_pb.ParseFromString(data_service_metadata.element_spec)
    return nested_structure_coder.decode_proto(struct_pb)

  processing_mode = _get_validated_sharding_policy(processing_mode)
  if isinstance(service, tuple):
    protocol, address = service
  else:
    protocol, address = _parse_service(service)
  _validate_compression(compression)
  if job_name is not None:
    if not isinstance(job_name, str) and not isinstance(job_name, ops.Tensor):
      raise ValueError(
          "`job_name` must be a string or Tensor, but `job_name` was of type "
          f"{type(job_name)}. job_name={job_name}.")

  if not element_spec:
    if not context.executing_eagerly():
      raise ValueError(
          "In graph mode `element_spec` must be provided manually.")
    element_spec = _get_element_spec()

  dataset = _DataServiceDataset(
      dataset_id=dataset_id,
      processing_mode=processing_mode,
      address=address,
      element_spec=element_spec,
      protocol=protocol,
      data_transfer_protocol=data_transfer_protocol,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
      cross_trainer_cache=cross_trainer_cache,
      target_workers=target_workers)

  # Disable autosharding for shared jobs.
  if job_name is not None:
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
    dataset = dataset.with_options(options)
  return dataset


@tf_export("data.experimental.service.from_dataset_id")
def from_dataset_id(processing_mode,
                    service,
                    dataset_id,
                    element_spec=None,
                    job_name=None,
                    consumer_index=None,
                    num_consumers=None,
                    max_outstanding_requests=None,
                    data_transfer_protocol=None,
                    cross_trainer_cache=None,
                    target_workers="AUTO"):
  """Creates a dataset which reads data from the tf.data service.

  This is useful when the dataset is registered by one process, then used in
  another process. When the same process is both registering and reading from
  the dataset, it is simpler to use `tf.data.experimental.service.distribute`
  instead.

  Before using `from_dataset_id`, the dataset must have been registered with
  the tf.data service using `tf.data.experimental.service.register_dataset`.
  `register_dataset` returns a dataset id for the registered dataset. That is
  the `dataset_id` which should be passed to `from_dataset_id`.

  The `element_spec` argument indicates the `tf.TypeSpec`s for the elements
  produced by the dataset. In graph mode or inside a `tf.function`,
  `element_spec` must be explicitly specified and must match the dataset
  registered under `dataset_id`. In eager mode, it may be left as `None`, in
  which case it is discovered automatically by querying the tf.data service.

  `tf.data.experimental.service.distribute` is a convenience method which
  combines `register_dataset` and `from_dataset_id` into a dataset
  transformation. See the documentation for
  `tf.data.experimental.service.distribute` for more detail about how
  `from_dataset_id` works.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy`
      specifying how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type
      of elements produced by the dataset. This argument is only required
      inside a tf.function. Use `tf.data.Dataset.element_spec` to get the
      element spec for a given dataset.
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible for multiple datasets
      to share the same job. The default behavior is that the dataset creates
      anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copy if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
  _validate_job_name(job_name)
  if job_name is not None:
    job_name = string_ops.string_join(
        ["dataset_id=", _to_string(dataset_id), job_name], "/")

  return _from_dataset_id(
      processing_mode=processing_mode,
      service=service,
      dataset_id=dataset_id,
      element_spec=element_spec,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol,
      cross_trainer_cache=cross_trainer_cache,
      target_workers=target_workers)