# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating TensorFlow servers."""

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import device_filters_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
  """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
      protocol buffer, or a `tf.train.ClusterSpec` object, describing the
      server to be defined and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server
      is a member. Defaults to the value in `server_or_cluster_def`, if
      specified.
    task_index: (Optional.) Specifies the task index of the server in its
      job. Defaults to the value in `server_or_cluster_def`, if specified.
      Otherwise defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the
      value in `server_or_cluster_def`, if specified. Otherwise defaults to
      `"grpc"`.
    config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
      configuration options for all sessions that run on this server.

  Returns:
    A `tf.train.ServerDef`.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
    ValueError: If an argument is not specified and cannot be inferred.
  """
  server_def = tensorflow_server_pb2.ServerDef()
  if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
    server_def.MergeFrom(server_or_cluster_def)
    if job_name is not None:
      server_def.job_name = job_name
    if task_index is not None:
      server_def.task_index = task_index
    if protocol is not None:
      server_def.protocol = protocol
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  else:
    try:
      cluster_spec = ClusterSpec(server_or_cluster_def)
    except TypeError:
      raise TypeError("Could not convert `server_or_cluster_def` to a "
                      "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
    if job_name is None:
      if len(cluster_spec.jobs) == 1:
        job_name = cluster_spec.jobs[0]
      else:
        raise ValueError("Must specify an explicit `job_name`.")
    if task_index is None:
      task_indices = cluster_spec.task_indices(job_name)
      if len(task_indices) == 1:
        task_index = task_indices[0]
      else:
        raise ValueError("Must specify an explicit `task_index`.")
    if protocol is None:
      protocol = "grpc"

    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_spec.as_cluster_def(),
        job_name=job_name,
        task_index=task_index,
        protocol=protocol)
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  return server_def
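
# Example (illustrative sketch, not part of the original module): for a
# single-job, single-task cluster, `_make_server_def` can infer both
# `job_name` and `task_index`, and falls back to the default protocol:
#
#   spec = ClusterSpec({"local": ["localhost:2222"]})
#   server_def = _make_server_def(spec, job_name=None, task_index=None,
#                                 protocol=None, config=None)
#   # server_def.job_name == "local"; server_def.task_index == 0
#   # server_def.protocol == "grpc"
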

@tf_export("distribute.Server", v1=["distribute.Server", "train.Server"])
@deprecation.deprecated_endpoints("train.Server")
class Server:
  """An in-process TensorFlow server, for use in distributed training.

  A `tf.distribute.Server` instance encapsulates a set of devices and a
  `tf.compat.v1.Session` target that can participate in distributed training.
  A server belongs to a cluster (specified by a `tf.train.ClusterSpec`), and
  corresponds to a particular task in a named job. The server can communicate
  with any other server in the same cluster.
  """

  def __init__(self,
               server_or_cluster_def,
               job_name=None,
               task_index=None,
               protocol=None,
               config=None,
               start=True):
    """Creates a new server with the given definition.

    The `job_name`, `task_index`, and `protocol` arguments are optional, and
    override any information provided in `server_or_cluster_def`.

    Args:
      server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
        protocol buffer, or a `tf.train.ClusterSpec` object, describing the
        server to be created and/or the cluster of which it is a member.
      job_name: (Optional.) Specifies the name of the job of which the server
        is a member. Defaults to the value in `server_or_cluster_def`, if
        specified.
      task_index: (Optional.) Specifies the task index of the server in its
        job. Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise defaults to 0 if the server's job has only one task.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the
        value in `server_or_cluster_def`, if specified. Otherwise defaults to
        `"grpc"`.
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
    """
    self._server_def = _make_server_def(server_or_cluster_def, job_name,
                                        task_index, protocol, config)
    self._server = c_api.TF_NewServer(self._server_def.SerializeToString())
    if start:
      self.start()
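
  # Example (illustrative, not from the original source): starting the only
  # "worker" task of a small two-job cluster. The hostnames are placeholders.
  #
  #   cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222"],
  #                                   "ps": ["ps0.example.com:2222"]})
  #   server = tf.distribute.Server(cluster, job_name="worker", task_index=0)
  #   server.join()  # Blocks until the server shuts down.
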

  def __del__(self):
    # At shutdown, `errors` may have been garbage collected.
    if errors is not None:
      exception = errors.UnimplementedError
    else:
      exception = Exception
    try:
      c_api.TF_ServerStop(self._server)
      # Clean shutdown of servers is not yet implemented, so
      # we leak instead of calling c_api.TF_DeleteServer here.
      # See:
      # https://github.com/tensorflow/tensorflow/blob/0495317a6e9dd4cac577b9d5cf9525e62b571018/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h#L73
    except AttributeError:
      # At shutdown, `c_api` may have been garbage collected.
      pass
    except exception:
      pass
    self._server = None

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the TensorFlow server.
    """
    c_api.TF_ServerStart(self._server)

  def join(self):
    """Blocks until the server has shut down.

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the TensorFlow server.
    """
    c_api.TF_ServerJoin(self._server)

  @property
  def server_def(self):
    """Returns the `tf.train.ServerDef` for this server.

    Returns:
      A `tf.train.ServerDef` protocol buffer that describes the configuration
      of this server.
    """
    return self._server_def

  @property
  def target(self):
    """Returns the target for a `tf.compat.v1.Session` to connect to this server.

    To create a `tf.compat.v1.Session` that connects to this server, use the
    following snippet:

    ```python
    server = tf.distribute.Server(...)
    with tf.compat.v1.Session(server.target):
      # ...
    ```

    Returns:
      A string containing a session target for this server.
    """
    return c_api.TF_ServerTarget(self._server)

  @staticmethod
  def create_local_server(config=None, start=True):
    """Creates a new single-process cluster running on the local host.

    This method is a convenience wrapper for creating a
    `tf.distribute.Server` with a `tf.train.ServerDef` that specifies a
    single-process cluster containing a single task in a job called
    `"localhost"`.

    Args:
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Returns:
      A local `tf.distribute.Server`.
    """
    # Specifying port 0 means that the OS will choose a free port for the
    # server.
    return Server({"localhost": ["localhost:0"]},
                  protocol="grpc",
                  config=config,
                  start=start)
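
# Example (illustrative, not from the original source): a local server binds
# to a free port chosen by the OS, and its `target` can be handed directly to
# a `tf.compat.v1.Session`:
#
#   server = tf.distribute.Server.create_local_server()
#   with tf.compat.v1.Session(server.target) as sess:
#     ...
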

@tf_export("train.ClusterSpec")
class ClusterSpec:
  """Represents a cluster as a set of "tasks", organized into "jobs".

  A `tf.train.ClusterSpec` represents the set of processes that
  participate in a distributed TensorFlow computation. Every
  `tf.distribute.Server` is constructed in a particular cluster.

  To create a cluster with two jobs and five tasks, you specify the
  mapping from job names to lists of network addresses (typically
  hostname-port pairs).

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222",
                                             "worker2.example.com:2222"],
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```

  Each job may also be specified as a sparse mapping from task indices
  to network addresses. This enables a server to be configured without
  needing to know the identity of (for example) all other worker
  tasks:

  ```python
  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```
  """

  def __init__(self, cluster):
    """Creates a `ClusterSpec`.

    Args:
      cluster: A dictionary mapping one or more job names to (i) a list of
        network addresses, or (ii) a dictionary mapping integer task indices
        to network addresses; or a `tf.train.ClusterDef` protocol buffer or
        another `tf.train.ClusterSpec`.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping strings to lists
        of strings, and not a `tf.train.ClusterDef` protobuf.
    """
    if isinstance(cluster, dict):
      self._cluster_spec = {}
      for job_name, tasks in cluster.items():
        if isinstance(tasks, (list, tuple)):
          job_tasks = {i: task for i, task in enumerate(tasks)}
        elif isinstance(tasks, dict):
          job_tasks = {int(i): task for i, task in tasks.items()}
        else:
          raise TypeError("The tasks for job %r must be a list or a "
                          "dictionary from integers to strings." % job_name)
        self._cluster_spec[job_name] = job_tasks
      self._make_cluster_def()
    elif isinstance(cluster, cluster_pb2.ClusterDef):
      self._cluster_def = cluster
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()
        }
    elif isinstance(cluster, ClusterSpec):
      self._cluster_def = cluster_pb2.ClusterDef()
      self._cluster_def.MergeFrom(cluster.as_cluster_def())
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()
        }
    else:
      raise TypeError("`cluster` must be a dictionary mapping one or more "
                      "job names to lists of network addresses, or a "
                      "`ClusterDef` protocol buffer")
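
  # Example (illustrative, not from the original source): the three accepted
  # constructor inputs should produce equivalent specs.
  #
  #   spec_a = ClusterSpec({"worker": ["w0:2222", "w1:2222"]})  # dict
  #   spec_b = ClusterSpec(spec_a.as_cluster_def())             # ClusterDef
  #   spec_c = ClusterSpec(spec_a)                              # ClusterSpec
  #   assert spec_a == spec_b == spec_c
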

  def __bool__(self):
    return bool(self._cluster_spec)

  # Python 2.x
  __nonzero__ = __bool__

  def __eq__(self, other):
    return self._cluster_spec == other

  def __ne__(self, other):
    return self._cluster_spec != other

  def __repr__(self):
    key_values = self.as_dict()
    string_items = [
        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)
    ]
    return "ClusterSpec({" + ", ".join(string_items) + "})"

  def as_dict(self):
    """Returns a dictionary from job names to their tasks.

    For each job, if the task index space is dense, the corresponding
    value will be a list of network addresses; otherwise it will be a
    dictionary mapping (sparse) task indices to the corresponding
    addresses.

    Returns:
      A dictionary mapping job names to lists or dictionaries
      describing the tasks in those jobs.
    """
    ret = {}
    for job in self.jobs:
      task_indices = self.task_indices(job)
      if len(task_indices) == 0:
        ret[job] = {}
        continue
      if max(task_indices) + 1 == len(task_indices):
        # Return a list because the task indices are dense. This
        # matches the behavior of `as_dict()` before support for
        # sparse jobs was added.
        ret[job] = self.job_tasks(job)
      else:
        ret[job] = {i: self.task_address(job, i) for i in task_indices}
    return ret
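
  # Example (illustrative, not from the original source): dense task indices
  # come back as a list, sparse ones as a dict.
  #
  #   ClusterSpec({"worker": ["w0:2222", "w1:2222"]}).as_dict()
  #   # => {'worker': ['w0:2222', 'w1:2222']}
  #   ClusterSpec({"worker": {1: "w1:2222"}}).as_dict()
  #   # => {'worker': {1: 'w1:2222'}}
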

  def as_cluster_def(self):
    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
    return self._cluster_def

  @property
  def jobs(self):
    """Returns a list of job names in this cluster.

    Returns:
      A list of strings, corresponding to the names of jobs in this cluster.
    """
    return list(self._cluster_spec.keys())

  def num_tasks(self, job_name):
    """Returns the number of tasks defined in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      The number of tasks defined in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return len(job)

  def task_indices(self, job_name):
    """Returns a list of valid task indices in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of valid task indices in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return list(sorted(job.keys()))

  def task_address(self, job_name, task_index):
    """Returns the address of the given task in the given job.

    Args:
      job_name: The string name of a job in this cluster.
      task_index: A non-negative integer.

    Returns:
      The address of the given task in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
        or no task with index `task_index` is defined in that job.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    try:
      return job[task_index]
    except KeyError:
      raise ValueError("No task with index %r in job %r" %
                       (task_index, job_name))
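
  # Example (illustrative, not from the original source): the accessors on a
  # sparsely defined job.
  #
  #   spec = ClusterSpec({"worker": {1: "w1:2222", 3: "w3:2222"}})
  #   spec.num_tasks("worker")        # => 2
  #   spec.task_indices("worker")     # => [1, 3]
  #   spec.task_address("worker", 3)  # => 'w3:2222'
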

  def job_tasks(self, job_name):
    """Returns a mapping from task ID to address in the given job.

    NOTE: For backwards compatibility, this method returns a list. If
    the given job was defined with a sparse set of task indices, the
    length of this list may not reflect the number of tasks defined in
    this job. Use the `tf.train.ClusterSpec.num_tasks` method
    to find the number of tasks defined in a particular job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of task addresses, where the index in the list
      corresponds to the task index of each task. The list may contain
      `None` if the job was defined with a sparse set of task indices.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    ret = [None for _ in range(max(job.keys()) + 1)]
    for i, task in job.items():
      ret[i] = task
    return ret
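
  # Example (illustrative, not from the original source): for a sparse job,
  # `job_tasks` pads the missing indices with `None`.
  #
  #   spec = ClusterSpec({"worker": {1: "w1:2222", 3: "w3:2222"}})
  #   spec.job_tasks("worker")  # => [None, 'w1:2222', None, 'w3:2222']
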

  def _make_cluster_def(self):
    """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.

    Raises:
      TypeError: If `cluster_spec` is not a dictionary mapping strings to
        lists of strings.
    """
    self._cluster_def = cluster_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._cluster_spec.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      job_def = self._cluster_def.job.add()
      job_def.name = job_name

      for i, task_address in sorted(tasks.items()):
        try:
          task_address = compat.as_bytes(task_address)
        except TypeError:
          raise TypeError("Task address %r must be bytes or unicode" %
                          task_address)
        job_def.tasks[i] = task_address


@tf_export("config.experimental.ClusterDeviceFilters")
class ClusterDeviceFilters:
  """Represents a collection of device filters for remote workers in a cluster.

  NOTE: this is an experimental API and subject to changes.

  Sets device filters for specific jobs and tasks. For each remote worker,
  the device filters are a list of strings. When any filters are present,
  the remote worker will ignore all devices which do not match any of its
  filters. Each filter can be partially specified, e.g. "/job:ps",
  "/job:worker/replica:3", etc. Note that a device is always visible to the
  worker it is located on.

  For example, to set the device filters for a parameter server cluster:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  For remote tasks that have no device filters specified, all devices will
  be visible to them.
  """

  def __init__(self):
    # `_device_filters` is a dict mapping job names to job device filters.
    # Job device filters further map task IDs to task device filters.
    # Task device filters are a list of strings, each one a device filter.
    self._device_filters = {}

    # Serialized protobuf for cluster device filters.
    self._cluster_device_filters = None

  def set_device_filters(self, job_name, task_index, device_filters):
    """Sets the device filters for the given job name and task ID."""
    assert all(isinstance(df, str) for df in device_filters)
    self._device_filters.setdefault(job_name, {})
    self._device_filters[job_name][task_index] = [df for df in device_filters]
    # The data was updated, so invalidate the serialized proto cache.
    self._cluster_device_filters = None

  def _as_cluster_device_filters(self):
    """Returns a serialized protobuf of cluster device filters."""
    if self._cluster_device_filters:
      return self._cluster_device_filters

    self._make_cluster_device_filters()
    return self._cluster_device_filters

  def _make_cluster_device_filters(self):
    """Creates `ClusterDeviceFilters` proto based on the `_device_filters`.

    Raises:
      TypeError: If `_device_filters` is not a dictionary mapping strings to
        a map of task indices and device filters.
    """
    self._cluster_device_filters = device_filters_pb2.ClusterDeviceFilters()

    # Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._device_filters.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      jdf = self._cluster_device_filters.jobs.add()
      jdf.name = job_name

      for i, task_device_filters in sorted(tasks.items()):
        for tdf in task_device_filters:
          try:
            tdf = compat.as_bytes(tdf)
          except TypeError:
            raise TypeError("Device filter %r must be bytes or unicode" % tdf)
          jdf.tasks[i].device_filters.append(tdf)