# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training helper that checkpoints models and computes summaries."""
import contextlib
import os
import time

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as _summary
from tensorflow.python.training import coordinator
from tensorflow.python.training import saver as saver_mod
from tensorflow.python.training import session_manager as session_manager_mod
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.Supervisor"])
class Supervisor:
  """A training helper that checkpoints models and computes summaries.

  This class is deprecated. Please use
  `tf.compat.v1.train.MonitoredTrainingSession` instead.

  The Supervisor is a small wrapper around a `Coordinator`, a `Saver`,
  and a `SessionManager` that takes care of common needs of TensorFlow
  training programs.

  #### Use for a single program

  ```python
  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a Supervisor that will checkpoint the model in '/tmp/mydir'.
    sv = Supervisor(logdir='/tmp/mydir')
    # Get a TensorFlow session managed by the supervisor.
    with sv.managed_session(FLAGS.master) as sess:
      # Use the session to train the graph.
      while not sv.should_stop():
        sess.run(<my_train_op>)
  ```

  Within the `with sv.managed_session()` block all variables in the graph have
  been initialized. In addition, a few services have been started to
  checkpoint the model and add summaries to the event log.

  If the program crashes and is restarted, the managed session automatically
  reinitializes variables from the most recent checkpoint.

  The supervisor is notified of any exception raised by one of the services.
  After an exception is raised, `should_stop()` returns `True`. In that case
  the training loop should also stop. This is why the training loop has to
  check for `sv.should_stop()`.

  Exceptions that indicate that the training inputs have been exhausted,
  such as `tf.errors.OutOfRangeError`, also cause `sv.should_stop()` to
  return `True` but are not re-raised from the `with` block: they indicate a
  normal termination.

  #### Use for multiple replicas

  To train with replicas you deploy the same program in a `Cluster`.
  One of the tasks must be identified as the *chief*: the task that handles
  initialization, checkpoints, summaries, and recovery. The other tasks
  depend on the *chief* for these services.

  The only change you have to make to the single-program code is to indicate
  whether the program is running as the *chief*.

  ```python
  # Choose a task as the chief. This could be based on server_def.task_index,
  # or job_def.name, or job_def.tasks. It's entirely up to the end user.
  # But there can be only one *chief*.
  is_chief = (server_def.task_index == 0)
  server = tf.distribute.Server(server_def)

  with tf.Graph().as_default():
    ...add operations to the graph...
    # Create a Supervisor that uses log directory on a shared file system.
    # Indicate if you are the 'chief'.
    sv = Supervisor(logdir='/shared_directory/...', is_chief=is_chief)
    # Get a Session in a TensorFlow server on the cluster.
    with sv.managed_session(server.target) as sess:
      # Use the session to train the graph.
      while not sv.should_stop():
        sess.run(<my_train_op>)
  ```

  In the *chief* task, the `Supervisor` works exactly as in the first example
  above. In the other tasks `sv.managed_session()` waits for the model to
  have been initialized before returning a session to the training code. The
  non-chief tasks depend on the chief task for initializing the model.

  If one of the tasks crashes and restarts, `managed_session()`
  checks if the model is initialized. If yes, it just creates a session and
  returns it to the training code that proceeds normally. If the model needs
  to be initialized, the chief task takes care of reinitializing it; the
  other tasks just wait for the model to have been initialized.

  NOTE: This modified program still works fine as a single program.
  The single program marks itself as the chief.

  #### What `master` string to use

  Whether you are running on your machine or in the cluster you can use the
  following values for the --master flag:

  * Specifying `''` requests an in-process session that does not use RPC.

  * Specifying `'local'` requests a session that uses the RPC-based
    "Master interface" to run TensorFlow programs. See
    `tf.train.Server.create_local_server` for details.

  * Specifying `'grpc://hostname:port'` requests a session that uses
    the RPC interface to a specific host, and also allows the in-process
    master to access remote TensorFlow workers. Often, it is
    appropriate to pass `server.target` (for some `tf.distribute.Server`
    named `server`).
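
  For example (a sketch; `server` here is assumed to be a
  `tf.distribute.Server` created as in the replica example above):

  ```python
  # In-process session, no RPC:
  with sv.managed_session('') as sess:
    ...
  # RPC-based session on a specific host, e.g. this task's own server:
  with sv.managed_session(server.target) as sess:
    ...
  ```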

  #### Advanced use

  ##### Launching additional services

  `managed_session()` launches the Checkpoint and Summary services (threads).
  If you need more services to run, you can simply launch them in the block
  controlled by `managed_session()`.

  Example: Start a thread to print losses. We want this thread to run
  every 60 seconds, so we launch it with `sv.loop()`.

  ```python
  ...
  sv = Supervisor(logdir='/tmp/mydir')
  with sv.managed_session(FLAGS.master) as sess:
    sv.loop(60, print_loss, (sess, ))
    while not sv.should_stop():
      sess.run(my_train_op)
  ```

  ##### Launching fewer services

  `managed_session()` launches the "summary" and "checkpoint" threads, which
  use either the optional `summary_op` and `saver` passed to the constructor,
  or default ones created automatically by the supervisor. If you want to run
  your own summary and checkpointing logic, disable these services by passing
  `None` to the `summary_op` and `saver` parameters.

  Example: Create summaries manually every 100 steps in the chief.

  ```python
  # Create a Supervisor with no automatic summaries.
  sv = Supervisor(logdir='/tmp/mydir', is_chief=is_chief, summary_op=None)
  # As summary_op was None, managed_session() does not start the
  # summary thread.
  with sv.managed_session(FLAGS.master) as sess:
    for step in range(1000000):
      if sv.should_stop():
        break
      if is_chief and step % 100 == 0:
        # Create the summary every 100 chief steps.
        sv.summary_computed(sess, sess.run(my_summary_op))
      else:
        # Train normally
        sess.run(my_train_op)
  ```

  ##### Custom model initialization

  `managed_session()` only supports initializing the model by running an
  `init_op` or restoring from the latest checkpoint. If you have special
  initialization needs, see how to specify a `local_init_op` when creating the
  supervisor. You can also use the `SessionManager` directly to create a
  session and check if it could be initialized automatically.
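
  Example: a minimal sketch of custom initialization through `init_fn`
  (`pretrained_saver` and the checkpoint path are hypothetical):

  ```python
  def restore_pretrained(sess):
    # Runs after the optional `init_op`; loads pre-trained weights.
    pretrained_saver.restore(sess, '/tmp/pretrained/model.ckpt')

  sv = Supervisor(logdir='/tmp/mydir', init_fn=restore_pretrained)
  ```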

196 """ 

197 

198 # Value to pass for the 'ready_op', 'init_op', 'summary_op', 'saver', 

199 # and 'global_step' parameters of Supervisor.__init__() to indicate that 

200 # the default behavior should be used. 

201 USE_DEFAULT = 0 

202 

203 @deprecation.deprecated(None, 

204 "Please switch to tf.train.MonitoredTrainingSession") 

205 def __init__(self, 

206 graph=None, 

207 ready_op=USE_DEFAULT, 

208 ready_for_local_init_op=USE_DEFAULT, 

209 is_chief=True, 

210 init_op=USE_DEFAULT, 

211 init_feed_dict=None, 

212 local_init_op=USE_DEFAULT, 

213 logdir=None, 

214 summary_op=USE_DEFAULT, 

215 saver=USE_DEFAULT, 

216 global_step=USE_DEFAULT, 

217 save_summaries_secs=120, 

218 save_model_secs=600, 

219 recovery_wait_secs=30, 

220 stop_grace_secs=120, 

221 checkpoint_basename="model.ckpt", 

222 session_manager=None, 

223 summary_writer=USE_DEFAULT, 

224 init_fn=None, 

225 local_init_run_options=None): 

226 """Create a `Supervisor`. 

227 

228 Args: 

229 graph: A `Graph`. The graph that the model will use. Defaults to the 

230 default `Graph`. The supervisor may add operations to the graph before 

231 creating a session, but the graph should not be modified by the caller 

232 after passing it to the supervisor. 

233 ready_op: 1-D string `Tensor`. This tensor is evaluated by supervisors in 

234 `prepare_or_wait_for_session()` to check if the model is ready to use. 

235 The model is considered ready if it returns an empty array. Defaults to 

236 the tensor returned from `tf.compat.v1.report_uninitialized_variables()` 

237 If `None`, the model is not checked for readiness. 

238 ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by 

239 supervisors in `prepare_or_wait_for_session()` to check if the model is 

240 ready to run the local_init_op. The model is considered ready if it 

241 returns an empty array. Defaults to `None`. If `None`, the model is not 

242 checked for readiness before running local_init_op. 

243 is_chief: If True, create a chief supervisor in charge of initializing and 

244 restoring the model. If False, create a supervisor that relies on a 

245 chief supervisor for inits and restore. 

246 init_op: `Operation`. Used by chief supervisors to initialize the model 

247 when it can not be recovered. Defaults to an `Operation` that 

248 initializes all global variables. If `None`, no initialization is done 

249 automatically unless you pass a value for `init_fn`, see below. 

250 init_feed_dict: A dictionary that maps `Tensor` objects to feed values. 

251 This feed dictionary will be used when `init_op` is evaluated. 

252 local_init_op: `Operation`. Used by all supervisors to run initializations 

253 that should run for every new supervisor instance. By default these are 

254 table initializers and initializers for local variables. If `None`, no 

255 further per supervisor-instance initialization is done automatically. 

256 logdir: A string. Optional path to a directory where to checkpoint the 

257 model and log events for the visualizer. Used by chief supervisors. The 

258 directory will be created if it does not exist. 

259 summary_op: An `Operation` that returns a Summary for the event logs. Used 

260 by chief supervisors if a `logdir` was specified. Defaults to the 

261 operation returned from summary.merge_all(). If `None`, summaries are 

262 not computed automatically. 

263 saver: A Saver object. Used by chief supervisors if a `logdir` was 

264 specified. Defaults to the saved returned by Saver(). If `None`, the 

265 model is not saved automatically. 

266 global_step: An integer Tensor of size 1 that counts steps. The value 

267 from 'global_step' is used in summaries and checkpoint filenames. 

268 Default to the op named 'global_step' in the graph if it exists, is of 

269 rank 1, size 1, and of type tf.int32 or tf.int64. If `None` the global 

270 step is not recorded in summaries and checkpoint files. Used by chief 

271 supervisors if a `logdir` was specified. 

272 save_summaries_secs: Number of seconds between the computation of 

273 summaries for the event log. Defaults to 120 seconds. Pass 0 to 

274 disable summaries. 

275 save_model_secs: Number of seconds between the creation of model 

276 checkpoints. Defaults to 600 seconds. Pass 0 to disable checkpoints. 

277 recovery_wait_secs: Number of seconds between checks that the model is 

278 ready. Used by supervisors when waiting for a chief supervisor to 

279 initialize or restore the model. Defaults to 30 seconds. 

280 stop_grace_secs: Grace period, in seconds, given to running threads to 

281 stop when `stop()` is called. Defaults to 120 seconds. 

282 checkpoint_basename: The basename for checkpoint saving. 

283 session_manager: `SessionManager`, which manages Session creation and 

284 recovery. If it is `None`, a default `SessionManager` will be created 

285 with the set of arguments passed in for backwards compatibility. 

286 summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None` to 

287 indicate that no summaries should be written. 

288 init_fn: Optional callable used to initialize the model. Called after the 

289 optional `init_op` is called. The callable must accept one argument, 

290 the session being initialized. 

291 local_init_run_options: RunOptions to be passed as the SessionManager 

292 local_init_run_options parameter. 

293 

294 Returns: 

295 A `Supervisor`. 

296 

297 Raises: 

298 RuntimeError: If called with eager execution enabled. 

299 

300 @compatibility(eager) 

301 `Supervisor`s are not supported when eager execution is enabled. 

302 @end_compatibility 

303 """ 

    if context.executing_eagerly():
      raise RuntimeError("Supervisors are incompatible with eager execution.")
    # Set default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    with graph.as_default():
      self._init_ready_op(
          ready_op=ready_op, ready_for_local_init_op=ready_for_local_init_op)
      self._init_init_op(init_op=init_op, init_feed_dict=init_feed_dict)
      self._init_local_init_op(local_init_op=local_init_op)
      self._init_saver(saver=saver)
      self._init_summary_op(summary_op=summary_op)
      self._init_global_step(global_step=global_step)
    self._graph = graph
    self._meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True),
        saver_def=self._saver.saver_def if self._saver else None)
    self._is_chief = is_chief
    self._coord = coordinator.Coordinator()
    self._recovery_wait_secs = recovery_wait_secs
    self._stop_grace_secs = stop_grace_secs
    self._init_fn = init_fn
    self._local_init_run_options = local_init_run_options

    # Set all attributes related to checkpointing and writing events to None.
    # Afterwards, set them appropriately for chief supervisors, as these are
    # the only supervisors that can write checkpoints and events.
    self._logdir = None
    self._save_summaries_secs = None
    self._save_model_secs = None
    self._save_path = None
    self._summary_writer = None

    if self._is_chief:
      self._logdir = logdir
      self._save_summaries_secs = save_summaries_secs
      self._save_model_secs = save_model_secs
      if self._logdir:
        self._save_path = os.path.join(self._logdir, checkpoint_basename)
      if summary_writer is Supervisor.USE_DEFAULT:
        if self._logdir:
          self._summary_writer = _summary.FileWriter(self._logdir)
      else:
        self._summary_writer = summary_writer
      self._graph_added_to_summary = False

    self._init_session_manager(session_manager=session_manager)
    self._verify_setup()
    # The graph is not allowed to change anymore.
    graph.finalize()

  def _init_session_manager(self, session_manager=None):
    if session_manager is None:
      self._session_manager = session_manager_mod.SessionManager(
          local_init_op=self._local_init_op,
          ready_op=self._ready_op,
          ready_for_local_init_op=self._ready_for_local_init_op,
          graph=self._graph,
          recovery_wait_secs=self._recovery_wait_secs,
          local_init_run_options=self._local_init_run_options)
    else:
      self._session_manager = session_manager

  def _get_first_op_from_collection(self, key):
    """Returns the first `Operation` from a collection.

    Args:
      key: A string collection key.

    Returns:
      The first Op found in a collection, or `None` if the collection is empty.
    """
    try:
      op_list = ops.get_collection(key)
      if len(op_list) > 1:
        logging.info("Found %d %s operations. Returning the first one.",
                     len(op_list), key)
      if op_list:
        return op_list[0]
    except LookupError:
      pass

    return None

  def _init_ready_op(self,
                     ready_op=USE_DEFAULT,
                     ready_for_local_init_op=USE_DEFAULT):
    """Initializes ready_op.

    Args:
      ready_op: `Tensor` to check if the model is initialized. If it's set to
        USE_DEFAULT, creates an op that checks all the variables are
        initialized.
      ready_for_local_init_op: `Tensor` to check if the model is ready to run
        local_init_op. If it's set to USE_DEFAULT, creates an op that checks
        all the global variables are initialized.
    """
    if ready_op is Supervisor.USE_DEFAULT:
      ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP)
      if ready_op is None:
        ready_op = variables.report_uninitialized_variables()
        ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
    self._ready_op = ready_op

    # ready_for_local_init_op defaults to None for backward compatibility.
    if ready_for_local_init_op is Supervisor.USE_DEFAULT:
      ready_for_local_init_op = self._get_first_op_from_collection(
          ops.GraphKeys.READY_FOR_LOCAL_INIT_OP)
    self._ready_for_local_init_op = ready_for_local_init_op

  def _init_init_op(self, init_op=USE_DEFAULT, init_feed_dict=None):
    """Initializes init_op.

    Args:
      init_op: `Operation` to initialize the variables. If set to USE_DEFAULT,
        create an op that initializes all variables and tables.
      init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
        This feed dictionary will be used when `init_op` is evaluated.
    """
    if init_op is Supervisor.USE_DEFAULT:
      init_op = self._get_first_op_from_collection(ops.GraphKeys.INIT_OP)
      if init_op is None:
        init_op = variables.global_variables_initializer()
        ops.add_to_collection(ops.GraphKeys.INIT_OP, init_op)
    self._init_op = init_op
    self._init_feed_dict = init_feed_dict

  def _init_local_init_op(self, local_init_op=USE_DEFAULT):
    """Initializes local_init_op.

    Args:
      local_init_op: `Operation` run for every new supervisor instance. If set
        to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
        collection. If the collection is empty, create an op that initializes
        all local variables and all tables.
    """
    if local_init_op is Supervisor.USE_DEFAULT:
      local_init_op = self._get_first_op_from_collection(
          ops.GraphKeys.LOCAL_INIT_OP)
      if local_init_op is None:
        op_list = [
            variables.local_variables_initializer(),
            lookup_ops.tables_initializer()
        ]
        if op_list:
          local_init_op = control_flow_ops.group(*op_list)
          ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
    self._local_init_op = local_init_op

  def _init_saver(self, saver=USE_DEFAULT):
    """Initializes saver.

    Args:
      saver: A `Saver` object. If set to USE_DEFAULT, create one that saves
        all the variables.
    """
    if saver is Supervisor.USE_DEFAULT:
      saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS)
      if saver is None and variables.global_variables():
        saver = saver_mod.Saver()
        ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
    self._saver = saver

  def _init_summary_op(self, summary_op=USE_DEFAULT):
    """Initializes summary_op.

    Args:
      summary_op: An Operation that returns a Summary for the event logs. If
        set to USE_DEFAULT, create an op that merges all the summaries.
    """
    if summary_op is Supervisor.USE_DEFAULT:
      summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP)
      if summary_op is None:
        summary_op = _summary.merge_all()
        if summary_op is not None:
          ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op)
    self._summary_op = summary_op

  def _init_global_step(self, global_step=USE_DEFAULT):
    """Initializes global_step.

    Args:
      global_step: An integer Tensor of size 1 that counts steps. If set to
        USE_DEFAULT, creates a global_step tensor.
    """
    if global_step is Supervisor.USE_DEFAULT:
      global_step = self._get_first_op_from_collection(
          ops.GraphKeys.GLOBAL_STEP)
      if global_step is None:
        global_step = self._default_global_step_tensor()
        if global_step is not None:
          ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, global_step)
    self._global_step = global_step

  @property
  def is_chief(self):
    """Return True if this is a chief supervisor.

    Returns:
      A bool.
    """
    return self._is_chief

  @property
  def session_manager(self):
    """Return the SessionManager used by the Supervisor.

    Returns:
      A SessionManager object.
    """
    return self._session_manager

  @property
  def coord(self):
    """Return the Coordinator used by the Supervisor.

    The Coordinator can be useful if you want to run multiple threads
    during your training.

    Returns:
      A Coordinator object.
    """
    return self._coord

  @property
  def init_op(self):
    """Return the Init Op used by the supervisor.

    Returns:
      An Op or `None`.
    """
    return self._init_op

  @property
  def init_feed_dict(self):
    """Return the feed dictionary used when evaluating the `init_op`.

    Returns:
      A feed dictionary or `None`.
    """
    return self._init_feed_dict

  @property
  def ready_op(self):
    """Return the Ready Op used by the supervisor.

    Returns:
      An Op or `None`.
    """
    return self._ready_op

  @property
  def ready_for_local_init_op(self):
    """Return the ready-for-local-init Op used by the supervisor, or `None`."""
    return self._ready_for_local_init_op

  @property
  def summary_writer(self):
    """Return the SummaryWriter used by the chief supervisor.

    Returns:
      A SummaryWriter.
    """
    return self._summary_writer

  @property
  def summary_op(self):
    """Return the Summary Tensor used by the chief supervisor.

    Returns:
      A string Tensor for the summary or `None`.
    """
    return self._summary_op

  @property
  def save_summaries_secs(self):
    """Return the delay between summary computations.

    Returns:
      A number of seconds.
    """
    return self._save_summaries_secs

  @property
  def global_step(self):
    """Return the global_step Tensor used by the supervisor.

    Returns:
      An integer Tensor for the global_step.
    """
    return self._global_step

  @property
  def saver(self):
    """Return the Saver used by the supervisor.

    Returns:
      A Saver object.
    """
    return self._saver

  @property
  def save_model_secs(self):
    """Return the delay between checkpoints.

    Returns:
      A number of seconds.
    """
    return self._save_model_secs

  @property
  def save_path(self):
    """Return the save path used by the supervisor.

    Returns:
      A string.
    """
    return self._save_path

  def _write_graph(self):
    """Writes graph_def to `logdir` and adds it to summary if applicable."""
    assert self._is_chief
    if self._logdir:
      training_util.write_graph(
          self._graph.as_graph_def(add_shapes=True), self._logdir,
          "graph.pbtxt")
    if self._summary_writer and not self._graph_added_to_summary:
      self._summary_writer.add_graph(self._graph)
      self._summary_writer.add_meta_graph(self._meta_graph_def)
      self._graph_added_to_summary = True

  def start_standard_services(self, sess):
    """Start the standard services for 'sess'.

    This starts services in the background. The services started depend
    on the parameters to the constructor and may include:

    - A Summary thread computing summaries every save_summaries_secs.
    - A Checkpoint thread saving the model every save_model_secs.
    - A StepCounter thread measuring step time.
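
    For example, a minimal sketch (assuming `sv` is a chief `Supervisor` with
    a `logdir`, and `sess` was obtained without starting the services):

    ```python
    threads = sv.start_standard_services(sess)
    # ... run the training loop ...
    sv.coord.join(threads)
    ```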

    Args:
      sess: A Session.

    Returns:
      A list of threads that are running the standard services. You can use
      the Supervisor's Coordinator to join these threads with:
        sv.coord.join(<list of threads>)

    Raises:
      RuntimeError: If called with a non-chief Supervisor.
      ValueError: If no `logdir` was passed to the constructor, as the
        services need a log directory.
    """
    if not self._is_chief:
      raise RuntimeError("Only the chief supervisor can start standard "
                         "services, because only chief supervisors can "
                         "write events.")

    if not self._logdir:
      logging.warning("Standard services need a 'logdir' "
                      "passed to the SessionManager")
      return

    if self._global_step is not None and self._summary_writer:
      # Only add the session log if we keep track of the global step.
      # TensorBoard cannot use the START message for purging expired events
      # if there is no step value.
      current_step = training_util.global_step(sess, self._global_step)
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), current_step)

    threads = []
    if self._save_summaries_secs and self._summary_writer:
      if self._summary_op is not None:
        threads.append(SVSummaryThread(self, sess))
      if self._global_step is not None:
        threads.append(SVStepCounterThread(self, sess))
    if self.saver and self._save_model_secs:
      threads.append(SVTimerCheckpointThread(self, sess))
    for t in threads:
      t.start()
    return threads

  def prepare_or_wait_for_session(self,
                                  master="",
                                  config=None,
                                  wait_for_checkpoint=False,
                                  max_wait_secs=7200,
                                  start_standard_services=True):
    """Make sure the model is ready to be used.

    Create a session on 'master', recovering or initializing the model as
    needed, or wait for a session to be ready. If running as the chief
    and `start_standard_services` is set to True, also call the session
    manager to start the standard services.

    Args:
      master: name of the TensorFlow master to use. See the
        `tf.compat.v1.Session` constructor for how this is interpreted.
      config: Optional ConfigProto proto used to configure the session, which
        is passed as-is to create the session.
      wait_for_checkpoint: Whether we should wait for the availability of a
        checkpoint before creating the session. Defaults to False.
      max_wait_secs: Maximum time to wait for the session to become available.
      start_standard_services: Whether to start the standard services and the
        queue runners.

    Returns:
      A Session object that can be used to drive the model.
    """
    # For users who recreate the session with prepare_or_wait_for_session(), we
    # need to clear the coordinator's stop_event so that threads managed by the
    # coordinator can run.
    self._coord.clear_stop()
    if self._summary_writer:
      self._summary_writer.reopen()

    if self._is_chief:
      sess = self._session_manager.prepare_session(
          master,
          init_op=self.init_op,
          saver=self.saver,
          checkpoint_dir=self._logdir,
          wait_for_checkpoint=wait_for_checkpoint,
          max_wait_secs=max_wait_secs,
          config=config,
          init_feed_dict=self._init_feed_dict,
          init_fn=self._init_fn)
      self._write_graph()
      if start_standard_services:
        logging.info("Starting standard services.")
        self.start_standard_services(sess)
    else:
      sess = self._session_manager.wait_for_session(
          master, config=config, max_wait_secs=max_wait_secs)
      if start_standard_services:
        logging.info("Starting queue runners.")
        self.start_queue_runners(sess)
    return sess

  def start_queue_runners(self, sess, queue_runners=None):
    """Start threads for `QueueRunners`.

    Note that the queue runners collected in the graph key `QUEUE_RUNNERS`
    are already started automatically when you create a session with the
    supervisor, so unless you have non-collected queue runners to start,
    you do not need to call this explicitly.

    Args:
      sess: A `Session`.
      queue_runners: A list of `QueueRunners`. If not specified, we'll use the
        list of queue runners gathered in the graph under the key
        `GraphKeys.QUEUE_RUNNERS`.

    Returns:
      The list of threads started for the `QueueRunners`.

    Raises:
      RuntimeError: If called with eager execution enabled.

    @compatibility(eager)
    Queues are not compatible with eager execution. To ingest data when eager
    execution is enabled, use the `tf.data` API.
    @end_compatibility
    """
    if context.executing_eagerly():
      raise RuntimeError("Queues are not compatible with eager execution.")
    if queue_runners is None:
      queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    threads = []
    for qr in queue_runners:
      threads.extend(
          qr.create_threads(sess, coord=self._coord, daemon=True, start=True))
    return threads

  def loop(self, timer_interval_secs, target, args=None, kwargs=None):
    """Start a LooperThread that calls a function periodically.

    If `timer_interval_secs` is None the thread calls `target(*args, **kwargs)`
    repeatedly. Otherwise it calls it every `timer_interval_secs`
    seconds. The thread terminates when a stop is requested.

    The started thread is added to the list of threads managed by the
    supervisor so it does not need to be passed to the `stop()` method.

    Args:
      timer_interval_secs: Number. Time boundaries at which to call `target`.
      target: A callable object.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Returns:
      The started thread.
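
    Example (a sketch; `print_loss` is a hypothetical callable):

    ```python
    # Call print_loss(sess) every 60 seconds until a stop is requested.
    sv.loop(60, print_loss, args=(sess,))
    ```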

796 """ 

797 looper = coordinator.LooperThread( 

798 self._coord, 

799 timer_interval_secs, 

800 target=target, 

801 args=args, 

802 kwargs=kwargs) 

803 looper.start() 

804 return looper 

805 

  def stop(self,
           threads=None,
           close_summary_writer=True,
           ignore_live_threads=False):
    """Stop the services and the coordinator.

    This does not close the session.

    Args:
      threads: Optional list of threads to join with the coordinator. If
        `None`, defaults to the threads running the standard services, the
        threads started for `QueueRunners`, and the threads started by the
        `loop()` method. To wait on additional threads, pass the list in this
        parameter.
      close_summary_writer: Whether to close the `summary_writer`. Defaults to
        `True` if the summary writer was created by the supervisor, `False`
        otherwise.
      ignore_live_threads: If `True`, ignores threads that remain running
        after a grace period when joining threads via the coordinator,
        instead of raising a RuntimeError.
    """
    self._coord.request_stop()
    try:
      # coord.join() re-raises the first reported exception; the "finally"
      # block ensures that we clean up whether or not an exception was
      # reported.
      self._coord.join(
          threads,
          stop_grace_period_secs=self._stop_grace_secs,
          ignore_live_threads=ignore_live_threads)
    finally:
      # Close the writer last, in case one of the running threads was using it.
      if close_summary_writer and self._summary_writer:
        # Stop messages are not logged with event.step,
        # since the session may have already terminated.
        self._summary_writer.add_session_log(SessionLog(status=SessionLog.STOP))
        self._summary_writer.close()
        self._graph_added_to_summary = False

  def request_stop(self, ex=None):
    """Request that the coordinator stop the threads.

    See `Coordinator.request_stop()`.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`. If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
    self._coord.request_stop(ex=ex)

  def should_stop(self):
    """Check if the coordinator was told to stop.

    See `Coordinator.should_stop()`.

    Returns:
      True if the coordinator was told to stop, False otherwise.
    """
    return self._coord.should_stop()

  def stop_on_exception(self):
    """Context manager to stop the supervisor when an exception is raised.

    See `Coordinator.stop_on_exception()`.

    Returns:
      A context manager.
    """
    return self._coord.stop_on_exception()

  def wait_for_stop(self):
    """Block waiting for the coordinator to stop."""
    self._coord.wait_for_stop()

  def summary_computed(self, sess, summary, global_step=None):
    """Indicate that a summary was computed.

    Args:
      sess: A `Session` object.
      summary: A Summary proto, or a string holding a serialized summary proto.
      global_step: Int. global step this summary is associated with. If `None`,
        it will try to fetch the current step.

    Raises:
      TypeError: if 'summary' is not a Summary proto or a string.
      RuntimeError: if the Supervisor was created without a `logdir`.
    """
    if not self._summary_writer:
      raise RuntimeError("Writing a summary requires a summary writer.")
    if global_step is None and self.global_step is not None:
      global_step = training_util.global_step(sess, self.global_step)
    self._summary_writer.add_summary(summary, global_step)

  def _default_global_step_tensor(self):
    """Returns the global_step from the default graph.

    Returns:
      The global step `Tensor` or `None`.
    """
    try:
      gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
      if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
        return gs
      else:
        logging.warning("Found 'global_step' is not an int type: %s", gs.dtype)
        return None
    except KeyError:
      return None

  def _verify_setup(self):
    """Check that all is good.

    Raises:
      ValueError: If something is not good.
    """
    # Not running as chief means that replicas are used.
    # In that case all Variables must have their device set.
    if not self._is_chief:
      for op in self._graph.get_operations():
        if op.type in ["Variable", "VariableV2"] and not op.device:
          raise ValueError("When using replicas, all Variables must have "
                           "their device set: %s" % op)

  # pylint: disable=g-doc-return-or-yield,broad-except
  @contextlib.contextmanager
  def managed_session(self,
                      master="",
                      config=None,
                      start_standard_services=True,
                      close_summary_writer=True):
    """Returns a context manager for a managed session.

    This context manager creates and automatically recovers a session. It
    optionally starts the standard services that handle checkpoints and
    summaries. It monitors exceptions raised from the `with` block or from the
    services and stops the supervisor as needed.

    The context manager is typically used as follows:

    ```python
    def train():
      sv = tf.compat.v1.train.Supervisor(...)
      with sv.managed_session(<master>) as sess:
        for step in range(..):
          if sv.should_stop():
            break
          sess.run(<my training op>)
          ...do other things needed at each training step...
    ```

    An exception raised from the `with` block or one of the service threads is
    raised again when the block exits. This is done after stopping all threads
    and closing the session. For example, an `AbortedError` exception, raised
    in case of preemption of one of the workers in a distributed model, is
    raised again when the block exits.

    If you want to retry the training loop in case of preemption you can do it
    as follows:

    ```python
    def main(...):
      while True:
        try:
          train()
        except tf.errors.Aborted:
          pass
    ```

    As a special case, exceptions used for control flow, such as
    `OutOfRangeError` which reports that input queues are exhausted, are not
    raised again from the `with` block: they indicate a clean termination of
    the training loop and are considered normal termination.

    Args:
      master: name of the TensorFlow master to use. See the
        `tf.compat.v1.Session` constructor for how this is interpreted.
      config: Optional `ConfigProto` proto used to configure the session.
        Passed as-is to create the session.
      start_standard_services: Whether to start the standard services, such as
        checkpoint, summary and step counter.
      close_summary_writer: Whether to close the summary writer when closing
        the session. Defaults to True.

    Returns:
      A context manager that yields a `Session` restored from the latest
      checkpoint or initialized from scratch if no checkpoint exists. The
      session is closed when the `with` block exits.
    """
    try:
      sess = self.prepare_or_wait_for_session(
          master=master,
          config=config,
          start_standard_services=start_standard_services)
      yield sess
    except Exception as e:
      self.request_stop(e)
    finally:
      try:
        # Request all the threads to stop and wait for them to do so. Any
        # exception raised by the threads is raised again from stop().
        # Passing stop_grace_period_secs is for blocked enqueue/dequeue
        # threads which are not checking for `should_stop()`. They
        # will be stopped when we close the session further down.
        self.stop(close_summary_writer=close_summary_writer)
      finally:
        # Close the session to finish up all pending calls. We do not care
        # about exceptions raised when closing. This takes care of
        # blocked enqueue/dequeue calls.
        try:
          sess.close()
        except Exception:
          # Silently ignore exceptions raised by close().
          pass

  # pylint: enable=g-doc-return-or-yield,broad-except


class SVSummaryThread(coordinator.LooperThread):
  """A thread to save summaries on a timer."""

  def __init__(self, sv, sess):
    """Create a SVSummaryThread.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVSummaryThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    if self._sv.global_step is not None:
      summary_strs, global_step = self._sess.run(
          [self._sv.summary_op, self._sv.global_step])
    else:
      summary_strs = self._sess.run(self._sv.summary_op)
      global_step = None
    if self._sv.summary_writer:
      logging.info("Recording summary at step %s.", global_step)
      self._sv.summary_writer.add_summary(summary_strs, global_step)


class SVStepCounterThread(coordinator.LooperThread):
  """Threads to count steps and measure their duration."""

  def __init__(self, sv, sess, step_counter=None):
    """Create a `SVStepCounterThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
      step_counter: A `Tensor` holding the step counter. By default, it uses
        sv.global_step.
    """
    super(SVStepCounterThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess
    self._last_time = 0.0
    self._last_step = 0
    step_counter = sv.global_step if step_counter is None else step_counter
    self._step_counter = step_counter
    self._summary_tag = "%s/sec" % self._step_counter.op.name

  def start_loop(self):
    self._last_time = time.time()
    self._last_step = training_util.global_step(self._sess, self._step_counter)

  def run_loop(self):
    # Count the steps.
    current_step = training_util.global_step(self._sess, self._step_counter)
    added_steps = current_step - self._last_step
    self._last_step = current_step
    # Measure the elapsed time.
    current_time = time.time()
    elapsed_time = current_time - self._last_time
    self._last_time = current_time
    # Report the number of steps done per second.
    if elapsed_time > 0.:
      steps_per_sec = added_steps / elapsed_time
    else:
      steps_per_sec = float("inf")
    summary = Summary(value=[
        Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
    ])
    if self._sv.summary_writer:
      self._sv.summary_writer.add_summary(summary, current_step)
    logging.log_first_n(logging.INFO, "%s: %g", 10, self._summary_tag,
                        steps_per_sec)


class SVTimerCheckpointThread(coordinator.LooperThread):
  """A thread to checkpoint on a timer."""

  def __init__(self, sv, sess):
    """Create a `SVTimerCheckpointThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVTimerCheckpointThread, self).__init__(sv.coord, sv.save_model_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    logging.info("Saving checkpoint to path %s", self._sv.save_path)
    self._sv.saver.save(
        self._sess, self._sv.save_path, global_step=self._sv.global_step)
    if self._sv.summary_writer and self._sv.global_step is not None:
      current_step = training_util.global_step(self._sess,
                                               self._sv.global_step)
      self._sv.summary_writer.add_session_log(
          SessionLog(
              status=SessionLog.CHECKPOINT, checkpoint_path=self._sv.save_path),
          current_step)


# TODO(sherrym): All non-PEP8 compliant names will be deprecated shortly.
setattr(Supervisor, "PrepareSession", Supervisor.prepare_or_wait_for_session)
setattr(Supervisor, "StartQueueRunners", Supervisor.start_queue_runners)
setattr(Supervisor, "StartStandardServices", Supervisor.start_standard_services)
setattr(Supervisor, "Stop", Supervisor.stop)
setattr(Supervisor, "RequestStop", Supervisor.request_stop)
setattr(Supervisor, "Loop", Supervisor.loop)
setattr(Supervisor, "ShouldStop", Supervisor.should_stop)
setattr(Supervisor, "StopOnException", Supervisor.stop_on_exception)
setattr(Supervisor, "WaitForStop", Supervisor.wait_for_stop)
setattr(Supervisor, "SummaryComputed", Supervisor.summary_computed)