Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/service/server_lib.py: 43% (94 statements)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating dataset servers."""

import collections

# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
from tensorflow.core.protobuf import service_config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.data.experimental.service import _pywrap_utils
from tensorflow.python.util.tf_export import tf_export


def _get_time_or_placeholder(value):
  """Modifies time-based config values to account for special behaviors."""

  # Servers interpret time values of 0 to mean "choose a reasonable
  # default". However, the Python API uses `None` for this, and allows 0 as a
  # normal value. To account for this, if a user explicitly configures the
  # interval/timeout to 0, we interpret it to mean "a very small number", and
  # replace it with 1.
  if value == 0:
    return 1
  # `None` indicates that the user wants to leave the behavior to the runtime.
  if value is None:
    return 0
  return value
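
# A minimal sketch of the mapping implemented above (illustrative comment, not
# part of the original module): an explicit 0 is bumped to 1, `None` becomes
# the runtime placeholder 0, and any other value passes through unchanged.
#
#   _get_time_or_placeholder(0)     # -> 1 ("a very small number")
#   _get_time_or_placeholder(None)  # -> 0 (runtime picks a default)
#   _get_time_or_placeholder(5000)  # -> 5000 (passes through)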


@tf_export("data.experimental.service.DispatcherConfig")
class DispatcherConfig(
    collections.namedtuple(
        "DispatcherConfig",
        [
            "port",
            "protocol",
            "work_dir",
            "fault_tolerant_mode",
            "worker_addresses",
            "job_gc_check_interval_ms",
            "job_gc_timeout_ms",
            "worker_timeout_ms",
        ],
    )
):
  """Configuration class for tf.data service dispatchers.

  Fields:
    port: Specifies the port to bind to. A value of 0 indicates that the server
      may bind to any available port.
    protocol: The protocol to use for communicating with the tf.data service,
      e.g. "grpc".
    work_dir: A directory to store dispatcher state in. This argument is
      required for the dispatcher to be able to recover from restarts.
    fault_tolerant_mode: Whether the dispatcher should write its state to a
      journal so that it can recover from restarts. Dispatcher state, including
      registered datasets and created jobs, is synchronously written to the
      journal before responding to RPCs. If `True`, `work_dir` must also be
      specified.
    worker_addresses: If the job uses auto-sharding, it needs to specify a fixed
      list of worker addresses that will register with the dispatcher. The
      worker addresses should be in the format `"host"` or `"host:port"`, where
      `"port"` is an integer, named port, or `%port%` to match any port.
    job_gc_check_interval_ms: How often the dispatcher should scan through to
      delete old and unused jobs, in milliseconds. If not set, the runtime will
      select a reasonable default. A higher value will reduce load on the
      dispatcher, while a lower value will reduce the time it takes for the
      dispatcher to garbage collect expired jobs.
    job_gc_timeout_ms: How long a job needs to be unused before it becomes a
      candidate for garbage collection, in milliseconds. A value of -1 indicates
      that jobs should never be garbage collected. If not set, the runtime will
      select a reasonable default. A higher value will cause jobs to stay around
      longer with no consumers. This is useful if there is a large gap in
      time between when consumers read from the job. A lower value will reduce
      the time it takes to reclaim the resources from expired jobs.
    worker_timeout_ms: How long to wait for a worker to heartbeat before
      considering it missing. If not set, the runtime will select a reasonable
      default.
  """

  def __new__(
      cls,
      port=0,
      protocol=None,
      work_dir=None,
      fault_tolerant_mode=False,
      worker_addresses=None,
      job_gc_check_interval_ms=None,
      job_gc_timeout_ms=None,
      worker_timeout_ms=None,
  ):
    if protocol is None:
      protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
    job_gc_check_interval_ms = _get_time_or_placeholder(
        job_gc_check_interval_ms)
    job_gc_timeout_ms = _get_time_or_placeholder(job_gc_timeout_ms)
    return super().__new__(
        cls,
        port,
        protocol,
        work_dir,
        fault_tolerant_mode,
        worker_addresses,
        job_gc_check_interval_ms,
        job_gc_timeout_ms,
        worker_timeout_ms,
    )
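
# Illustrative usage sketch (comment only, not part of the original module;
# the work_dir path is a hypothetical example):
#
#   config = tf.data.experimental.service.DispatcherConfig(
#       port=5050,
#       work_dir="/tmp/dispatcher_work_dir",
#       fault_tolerant_mode=True,
#       job_gc_timeout_ms=0)  # explicit 0 becomes 1 via _get_time_or_placeholder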


@tf_export("data.experimental.service.DispatchServer", v1=[])
class DispatchServer:
  """An in-process tf.data service dispatch server.

  A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
  `tf.data.experimental.service.WorkerServer`s. When the workers start, they
  register themselves with the dispatcher.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When starting a dedicated tf.data dispatch process, call join() after
  starting the server to block until the server terminates.

  ```
  dispatcher = tf.data.experimental.service.DispatchServer(
      tf.data.experimental.service.DispatcherConfig(port=5050))
  dispatcher.join()
  ```

  Call stop() to gracefully terminate the dispatcher. The server automatically
  stops when all references to it have been deleted.

  To start a `DispatchServer` in fault-tolerant mode, set `work_dir` and
  `fault_tolerant_mode` like below:

  ```
  dispatcher = tf.data.experimental.service.DispatchServer(
      tf.data.experimental.service.DispatcherConfig(
          port=5050,
          work_dir="gs://my-bucket/dispatcher/work_dir",
          fault_tolerant_mode=True))
  ```
  """

  def __init__(self, config=None, start=True):
    """Creates a new dispatch server.

    Args:
      config: (Optional.) A `tf.data.experimental.service.DispatcherConfig`
        configuration. If `None`, the dispatcher will use default
        configuration values.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to True.
    """
    config = config or DispatcherConfig()
    if config.fault_tolerant_mode and not config.work_dir:
      raise ValueError(
          "Cannot enable fault tolerant mode without configuring a work dir. "
          "Make sure to set `work_dir` in the `config` object passed to "
          "`DispatchServer`.")
    self._config = config
    if isinstance(config, service_config_pb2.DispatcherConfig):
      config_proto = config
    else:
      config_proto = service_config_pb2.DispatcherConfig(
          port=config.port,
          protocol=config.protocol,
          work_dir=config.work_dir,
          fault_tolerant_mode=config.fault_tolerant_mode,
          worker_addresses=config.worker_addresses,
          job_gc_check_interval_ms=config.job_gc_check_interval_ms,
          job_gc_timeout_ms=config.job_gc_timeout_ms,
          worker_timeout_ms=config.worker_timeout_ms,
      )
    self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
        config_proto.SerializeToString())
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer(start=False)
    >>> dispatcher.start()

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated dispatch process.

    ```
    dispatcher = tf.data.experimental.service.DispatchServer(
        tf.data.experimental.service.DispatcherConfig(port=5050))
    dispatcher.join()
    ```

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  def stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._stop()

  @property
  def target(self):
    """Returns a target that can be used to connect to the server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer()
    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
    ...     processing_mode="parallel_epochs", service=dispatcher.target))

    The returned string will be in the form protocol://address, e.g.
    "grpc://localhost:5050".
    """
    return "{0}://localhost:{1}".format(self._config.protocol,
                                        self._server.bound_port())
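
  # Illustrative sketch (comment only, not part of the original module):
  # clients commonly split the scheme off `target` to recover the raw
  # host:port address, as the class docstring above does:
  #
  #   dispatcher = tf.data.experimental.service.DispatchServer()
  #   address = dispatcher.target.split("://")[1]  # e.g. "localhost:5050"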

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_workers(self):
    """Returns the number of workers registered with the dispatcher."""
    return self._server.num_workers()

  def _snapshot_streams(self, path):
    """Returns information about all the streams for a snapshot."""
    return self._server.snapshot_streams(path)
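
# Restart sketch for fault-tolerant mode (illustrative comment, not part of
# the original module; the work_dir path is hypothetical). Re-creating the
# dispatcher with the same journaled work_dir lets it recover registered
# datasets and created jobs, per the DispatcherConfig documentation above:
#
#   config = tf.data.experimental.service.DispatcherConfig(
#       port=5050,
#       work_dir="/tmp/dispatcher_work_dir",
#       fault_tolerant_mode=True)
#   dispatcher = tf.data.experimental.service.DispatchServer(config)
#   # ...the process restarts...
#   dispatcher = tf.data.experimental.service.DispatchServer(config)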


@tf_export("data.experimental.service.WorkerConfig")
class WorkerConfig(
    collections.namedtuple("WorkerConfig", [
        "dispatcher_address", "worker_address", "port", "protocol",
        "heartbeat_interval_ms", "dispatcher_timeout_ms",
        "data_transfer_protocol", "data_transfer_address"
    ])):
291 """Configuration class for tf.data service dispatchers.

  Fields:
    dispatcher_address: Specifies the address of the dispatcher.
    worker_address: Specifies the address of the worker server. This address is
      passed to the dispatcher so that the dispatcher can tell clients how to
      connect to this worker.
    port: Specifies the port to bind to. A value of 0 indicates that the worker
      can bind to any available port.
    protocol: A string indicating the protocol to be used by the worker to
      connect to the dispatcher. E.g. "grpc".
    heartbeat_interval_ms: How often the worker should heartbeat to the
      dispatcher, in milliseconds. If not set, the runtime will select a
      reasonable default. A higher value will reduce the load on the dispatcher,
      while a lower value will reduce the time it takes to reclaim resources
      from finished jobs.
    dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the
      dispatcher before giving up and reporting an error. Defaults to 1 hour.
    data_transfer_protocol: A string indicating the protocol to be used by the
      worker to transfer data to the client. E.g. "grpc".
    data_transfer_address: A string indicating the data transfer address of the
      worker server.
  """

  def __new__(cls,
              dispatcher_address,
              worker_address=None,
              port=0,
              protocol=None,
              heartbeat_interval_ms=None,
              dispatcher_timeout_ms=None,
              data_transfer_protocol=None,
              data_transfer_address=None):
    if worker_address is None:
      worker_address = "localhost:%port%"
    if protocol is None:
      protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
    if data_transfer_address is None:
      data_transfer_address = "localhost:%port%"
    heartbeat_interval_ms = _get_time_or_placeholder(heartbeat_interval_ms)
    dispatcher_timeout_ms = _get_time_or_placeholder(dispatcher_timeout_ms)

    return super(WorkerConfig,
                 cls).__new__(cls, dispatcher_address, worker_address, port,
                              protocol, heartbeat_interval_ms,
                              dispatcher_timeout_ms, data_transfer_protocol,
                              data_transfer_address)
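
# Illustrative usage sketch (comment only, not part of the original module):
# only dispatcher_address is required; worker_address defaults to
# "localhost:%port%", where %port% is substituted with the bound port.
#
#   config = tf.data.experimental.service.WorkerConfig(
#       dispatcher_address="localhost:5050",
#       port=5051,
#       heartbeat_interval_ms=None)  # None -> 0 -> runtime picks a default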


@tf_export("data.experimental.service.WorkerServer", v1=[])
class WorkerServer:
  """An in-process tf.data service worker server.

  A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
  processing for user-defined datasets, and provides the resulting elements over
  RPC. A worker is associated with a single
  `tf.data.experimental.service.DispatchServer`.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When starting a dedicated tf.data worker process, call join() after starting
  the worker to block until the worker terminates.

  ```
  worker = tf.data.experimental.service.WorkerServer(
      tf.data.experimental.service.WorkerConfig(
          dispatcher_address="localhost:5050", port=5051))
  worker.join()
  ```

  Call stop() to gracefully terminate the worker. The worker automatically stops
  when all references to it have been deleted.
  """

  def __init__(self, config, start=True):
    """Creates a new worker server.

    Args:
      config: A `tf.data.experimental.service.WorkerConfig` configuration.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to True.
    """
    if config.dispatcher_address is None:
      raise ValueError(
          "Must specify a `dispatcher_address` in the `config` passed "
          "to `WorkerServer`.")
    if isinstance(config, service_config_pb2.WorkerConfig):
      config_proto = config
    else:
      config_proto = service_config_pb2.WorkerConfig(
          dispatcher_address=config.dispatcher_address,
          worker_address=config.worker_address,
          port=config.port,
          protocol=config.protocol,
          heartbeat_interval_ms=config.heartbeat_interval_ms,
          dispatcher_timeout_ms=config.dispatcher_timeout_ms,
          data_transfer_protocol=config.data_transfer_protocol,
          data_transfer_address=config.data_transfer_address)
    self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
        config_proto.SerializeToString())
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated worker process.

    ```
    worker_server = tf.data.experimental.service.WorkerServer(
        tf.data.experimental.service.WorkerConfig(
            dispatcher_address="localhost:5050", port=5051))
    worker_server.join()
    ```

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  def stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._stop()

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_tasks(self):
    """Returns the number of tasks currently being executed on the worker."""
    return self._server.num_tasks()

  def _snapshot_task_progresses(self):
    """Returns the progresses of the snapshot tasks currently being executed.

    Returns:
      An `Iterable[common_pb2.SnapshotTaskProgress]`.
    """
    return self._server.snapshot_task_progresses()
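
# End-to-end lifecycle sketch for a dedicated worker process (illustrative
# comment, not part of the original module; assumes a dispatcher is already
# listening at localhost:5050):
#
#   import tensorflow as tf
#
#   worker = tf.data.experimental.service.WorkerServer(
#       tf.data.experimental.service.WorkerConfig(
#           dispatcher_address="localhost:5050", port=5051),
#       start=False)
#   worker.start()  # raises tf.errors.OpError (or a subclass) on failure
#   worker.join()   # blocks until the server shuts down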