# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating dataset servers."""

import collections

# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
from tensorflow.core.protobuf import service_config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.data.experimental.service import _pywrap_utils
from tensorflow.python.util.tf_export import tf_export


def _get_time_or_placeholder(value):
  """Modifies time-based config values to account for special behaviors."""

  # Servers interpret time values of 0 to mean "choose a reasonable
  # default". However, the Python API uses `None` for this, and allows 0 as a
  # normal value. To account for this, if a user explicitly configures the
  # interval/timeout to 0, we interpret it to mean "a very small number", and
  # replace it with 1.
  if value == 0:
    return 1
  # `None` indicates that the user wants to leave the behavior to the runtime.
  if value is None:
    return 0
  return value
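# For illustration, the resulting mapping:
#
#   _get_time_or_placeholder(0)     # -> 1 (treated as "a very small number")
#   _get_time_or_placeholder(None)  # -> 0 (runtime chooses the default)
#   _get_time_or_placeholder(5000)  # -> 5000 (regular values pass through)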

@tf_export("data.experimental.service.DispatcherConfig")
class DispatcherConfig(
    collections.namedtuple(
        "DispatcherConfig",
        [
            "port",
            "protocol",
            "work_dir",
            "fault_tolerant_mode",
            "worker_addresses",
            "job_gc_check_interval_ms",
            "job_gc_timeout_ms",
            "worker_timeout_ms",
        ],
    )
):
  """Configuration class for tf.data service dispatchers.

  Fields:
    port: Specifies the port to bind to. A value of 0 indicates that the server
      may bind to any available port.
    protocol: The protocol to use for communicating with the tf.data service,
      e.g. "grpc".
    work_dir: A directory to store dispatcher state in. This argument is
      required for the dispatcher to be able to recover from restarts.
    fault_tolerant_mode: Whether the dispatcher should write its state to a
      journal so that it can recover from restarts. Dispatcher state, including
      registered datasets and created jobs, is synchronously written to the
      journal before responding to RPCs. If `True`, `work_dir` must also be
      specified.
    worker_addresses: If the job uses auto-sharding, it needs to specify a fixed
      list of worker addresses that will register with the dispatcher. The
      worker addresses should be in the format `"host"` or `"host:port"`, where
      `"port"` is an integer, named port, or `%port%` to match any port.
    job_gc_check_interval_ms: How often the dispatcher should scan through to
      delete old and unused jobs, in milliseconds. If not set, the runtime will
      select a reasonable default. A higher value will reduce load on the
      dispatcher, while a lower value will reduce the time it takes for the
      dispatcher to garbage collect expired jobs.
    job_gc_timeout_ms: How long a job needs to be unused before it becomes a
      candidate for garbage collection, in milliseconds. A value of -1 indicates
      that jobs should never be garbage collected. If not set, the runtime will
      select a reasonable default. A higher value will cause jobs to stay around
      longer with no consumers, which is useful when there are long gaps between
      reads from the job by its consumers. A lower value will reduce the time it
      takes to reclaim the resources from expired jobs.
    worker_timeout_ms: How long to wait for a worker to heartbeat before
      considering it missing. If not set, the runtime will select a reasonable
      default.
  """

  def __new__(
      cls,
      port=0,
      protocol=None,
      work_dir=None,
      fault_tolerant_mode=False,
      worker_addresses=None,
      job_gc_check_interval_ms=None,
      job_gc_timeout_ms=None,
      worker_timeout_ms=None,
  ):
    if protocol is None:
      protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
    job_gc_check_interval_ms = _get_time_or_placeholder(
        job_gc_check_interval_ms)
    job_gc_timeout_ms = _get_time_or_placeholder(job_gc_timeout_ms)
    return super().__new__(
        cls,
        port,
        protocol,
        work_dir,
        fault_tolerant_mode,
        worker_addresses,
        job_gc_check_interval_ms,
        job_gc_timeout_ms,
        worker_timeout_ms,
    )
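# For example, a fault-tolerant dispatcher configuration might look like the
# sketch below (the port and work_dir values are illustrative; work_dir is
# taken from the DispatchServer docstring example further down):
#
#   config = DispatcherConfig(
#       port=5050,
#       work_dir="gs://my-bucket/dispatcher/work_dir",
#       fault_tolerant_mode=True)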

@tf_export("data.experimental.service.DispatchServer", v1=[])
class DispatchServer:
  """An in-process tf.data service dispatch server.

  A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
  `tf.data.experimental.service.WorkerServer`s. When the workers start, they
  register themselves with the dispatcher.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When starting a dedicated tf.data dispatch process, use join() to block
  after starting up the server, until the server terminates.

  ```
  dispatcher = tf.data.experimental.service.DispatchServer(
      tf.data.experimental.service.DispatcherConfig(port=5050))
  dispatcher.join()
  ```

  Call stop() to gracefully terminate the dispatcher. The server automatically
  stops when all references to it have been deleted.

  To start a `DispatchServer` in fault-tolerant mode, set `work_dir` and
  `fault_tolerant_mode` like below:

  ```
  dispatcher = tf.data.experimental.service.DispatchServer(
      tf.data.experimental.service.DispatcherConfig(
          port=5050,
          work_dir="gs://my-bucket/dispatcher/work_dir",
          fault_tolerant_mode=True))
  ```
  """

  def __init__(self, config=None, start=True):
    """Creates a new dispatch server.

    Args:
      config: (Optional.) A `tf.data.experimental.service.DispatcherConfig`
        configuration. If `None`, the dispatcher will use default
        configuration values.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to True.
    """
    config = config or DispatcherConfig()
    if config.fault_tolerant_mode and not config.work_dir:
      raise ValueError(
          "Cannot enable fault tolerant mode without configuring a work dir. "
          "Make sure to set `work_dir` in the `config` object passed to "
          "`DispatchServer`.")
    self._config = config
    if isinstance(config, service_config_pb2.DispatcherConfig):
      config_proto = config
    else:
      config_proto = service_config_pb2.DispatcherConfig(
          port=config.port,
          protocol=config.protocol,
          work_dir=config.work_dir,
          fault_tolerant_mode=config.fault_tolerant_mode,
          worker_addresses=config.worker_addresses,
          job_gc_check_interval_ms=config.job_gc_check_interval_ms,
          job_gc_timeout_ms=config.job_gc_timeout_ms,
          worker_timeout_ms=config.worker_timeout_ms,
      )
    self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
        config_proto.SerializeToString())
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer(start=False)
    >>> dispatcher.start()

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated dispatch process.

    ```
    dispatcher = tf.data.experimental.service.DispatchServer(
        tf.data.experimental.service.DispatcherConfig(port=5050))
    dispatcher.join()
    ```

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  def stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._stop()

  @property
  def target(self):
    """Returns a target that can be used to connect to the server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer()
    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
    ...     processing_mode="parallel_epochs", service=dispatcher.target))

    The returned string will be in the form protocol://address, e.g.
    "grpc://localhost:5050".
    """
    return "{0}://localhost:{1}".format(self._config.protocol,
                                        self._server.bound_port())

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_workers(self):
    """Returns the number of workers registered with the dispatcher."""
    return self._server.num_workers()

  def _snapshot_streams(self, path):
    """Returns information about all the streams for a snapshot."""
    return self._server.snapshot_streams(path)

@tf_export("data.experimental.service.WorkerConfig")
class WorkerConfig(
    collections.namedtuple("WorkerConfig", [
        "dispatcher_address", "worker_address", "port", "protocol",
        "heartbeat_interval_ms", "dispatcher_timeout_ms",
        "data_transfer_protocol", "data_transfer_address"
    ])):
  """Configuration class for tf.data service workers.

  Fields:
    dispatcher_address: Specifies the address of the dispatcher.
    worker_address: Specifies the address of the worker server. This address is
      passed to the dispatcher so that the dispatcher can tell clients how to
      connect to this worker.
    port: Specifies the port to bind to. A value of 0 indicates that the worker
      can bind to any available port.
    protocol: A string indicating the protocol to be used by the worker to
      connect to the dispatcher. E.g. "grpc".
    heartbeat_interval_ms: How often the worker should heartbeat to the
      dispatcher, in milliseconds. If not set, the runtime will select a
      reasonable default. A higher value will reduce the load on the
      dispatcher, while a lower value will reduce the time it takes to reclaim
      resources from finished jobs.
    dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the
      dispatcher before giving up and reporting an error. Defaults to 1 hour.
    data_transfer_protocol: A string indicating the protocol to be used by the
      worker to transfer data to the client. E.g. "grpc".
    data_transfer_address: A string indicating the data transfer address of the
      worker server.
  """

  def __new__(cls,
              dispatcher_address,
              worker_address=None,
              port=0,
              protocol=None,
              heartbeat_interval_ms=None,
              dispatcher_timeout_ms=None,
              data_transfer_protocol=None,
              data_transfer_address=None):
    if worker_address is None:
      worker_address = "localhost:%port%"
    if protocol is None:
      protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
    if data_transfer_address is None:
      data_transfer_address = "localhost:%port%"
    heartbeat_interval_ms = _get_time_or_placeholder(heartbeat_interval_ms)
    dispatcher_timeout_ms = _get_time_or_placeholder(dispatcher_timeout_ms)

    return super(WorkerConfig,
                 cls).__new__(cls, dispatcher_address, worker_address, port,
                              protocol, heartbeat_interval_ms,
                              dispatcher_timeout_ms, data_transfer_protocol,
                              data_transfer_address)
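# A worker only needs the dispatcher's address; everything else has a usable
# default. A minimal sketch (the address below is hypothetical):
#
#   config = WorkerConfig(dispatcher_address="localhost:5050")
#   # worker_address and data_transfer_address default to "localhost:%port%",
#   # and protocol defaults to the runtime's choice, e.g. "grpc".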

@tf_export("data.experimental.service.WorkerServer", v1=[])
class WorkerServer:
  """An in-process tf.data service worker server.

  A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
  processing for user-defined datasets, and provides the resulting elements
  over RPC. A worker is associated with a single
  `tf.data.experimental.service.DispatchServer`.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When starting a dedicated tf.data worker process, use join() to block
  after starting up the worker, until the worker terminates.

  ```
  worker = tf.data.experimental.service.WorkerServer(
      tf.data.experimental.service.WorkerConfig(
          dispatcher_address="localhost:5050", port=5051))
  worker.join()
  ```

  Call stop() to gracefully terminate the worker. The worker automatically
  stops when all references to it have been deleted.
  """

  def __init__(self, config, start=True):
    """Creates a new worker server.

    Args:
      config: A `tf.data.experimental.service.WorkerConfig` configuration.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to True.
    """
    if config.dispatcher_address is None:
      raise ValueError(
          "Must specify a `dispatcher_address` in the `config` passed "
          "to `WorkerServer`.")
    if isinstance(config, service_config_pb2.WorkerConfig):
      config_proto = config
    else:
      config_proto = service_config_pb2.WorkerConfig(
          dispatcher_address=config.dispatcher_address,
          worker_address=config.worker_address,
          port=config.port,
          protocol=config.protocol,
          heartbeat_interval_ms=config.heartbeat_interval_ms,
          dispatcher_timeout_ms=config.dispatcher_timeout_ms,
          data_transfer_protocol=config.data_transfer_protocol,
          data_transfer_address=config.data_transfer_address)
    self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
        config_proto.SerializeToString())
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated worker process.

    ```
    worker_server = tf.data.experimental.service.WorkerServer(
        tf.data.experimental.service.WorkerConfig(
            dispatcher_address="localhost:5050", port=5051))
    worker_server.join()
    ```

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  def stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._stop()

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_tasks(self):
    """Returns the number of tasks currently being executed on the worker."""
    return self._server.num_tasks()

  def _snapshot_task_progresses(self):
    """Returns the progresses of the snapshot tasks currently being executed.

    Returns:
      An `Iterable[common_pb2.SnapshotTaskProgress]`.
    """
    return self._server.snapshot_task_progresses()
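
# Putting the pieces together: a minimal in-process cluster with an explicit
# shutdown, mirroring the doctests above (illustrative sketch only):
#
#   dispatcher = DispatchServer()
#   worker = WorkerServer(
#       WorkerConfig(dispatcher_address=dispatcher.target.split("://")[1]))
#   # ... run tf.data jobs against dispatcher.target ...
#   worker.stop()
#   dispatcher.stop()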