Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/basic_session_run_hooks.py: 26% (487 statements)

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some common SessionRunHook classes.

Note that the symbols that are exported to v1 tf.train namespace are also
exported to v2 in tf.estimator namespace. See
https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
"""

import os
import time

import numpy as np

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import timeline
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export

_HOOKS = "hooks"
_STEPS_PER_RUN_VAR = "steps_per_run"


class _HookTimer:
  """Base timer for determining when Hooks should trigger.

  Should not be instantiated directly.
  """

  def __init__(self):
    pass

  def reset(self):
    """Resets the timer."""
    pass

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step."""
    raise NotImplementedError

  def update_last_triggered_step(self, step):
    """Update the last triggered time and step number.

    Args:
      step: The current step.

    Returns:
      A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the
      number of seconds between the current trigger and the last one (a
      float), and `elapsed_steps` is the number of steps between the current
      trigger and the last one. Both values will be set to `None` on the first
      trigger.
    """
    raise NotImplementedError

  def last_triggered_step(self):
    """Returns the last triggered time step or None if never triggered."""
    raise NotImplementedError


@tf_export(v1=["train.SecondOrStepTimer"])
class SecondOrStepTimer(_HookTimer):
  """Timer that triggers at most once every N seconds or once every N steps.

  This symbol is also exported to v2 in tf.estimator namespace. See
  https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
  """

  def __init__(self, every_secs=None, every_steps=None):
    self.reset()
    self._every_secs = every_secs
    self._every_steps = every_steps

    if self._every_secs is None and self._every_steps is None:
      raise ValueError("Either every_secs or every_steps should be provided.")
    if (self._every_secs is not None) and (self._every_steps is not None):
      raise ValueError("Can not provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    self._last_triggered_step = None
    self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step.

    Args:
      step: Training step to trigger on.

    Returns:
      True if the difference between the current time and the time of the last
      trigger exceeds `every_secs`, or if the difference between the current
      step and the last triggered step exceeds `every_steps`. False otherwise.
    """
    if self._last_triggered_step is None:
      return True

    if self._last_triggered_step == step:
      return False

    if self._every_secs is not None:
      if time.time() >= self._last_triggered_time + self._every_secs:
        return True

    if self._every_steps is not None:
      if step >= self._last_triggered_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    current_time = time.time()
    if self._last_triggered_time is None:
      elapsed_secs = None
      elapsed_steps = None
    else:
      elapsed_secs = current_time - self._last_triggered_time
      elapsed_steps = step - self._last_triggered_step

    self._last_triggered_time = current_time
    self._last_triggered_step = step
    return (elapsed_secs, elapsed_steps)

  def last_triggered_step(self):
    return self._last_triggered_step


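# Illustrative sketch (not part of the original module): how the trigger logic
# above behaves for a step-based timer. The first call always triggers; after
# that it triggers only once `every_steps` steps have elapsed since the last
# recorded trigger.
#
#   timer = SecondOrStepTimer(every_steps=10)
#   timer.should_trigger_for_step(1)      # True: never triggered before.
#   timer.update_last_triggered_step(1)   # -> (None, None) on first trigger.
#   timer.should_trigger_for_step(5)      # False: only 4 steps elapsed.
#   timer.should_trigger_for_step(11)     # True: 10 steps elapsed.
#   timer.update_last_triggered_step(11)  # -> (elapsed_secs, 10).
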

class NeverTriggerTimer(_HookTimer):
  """Timer that never triggers."""

  def should_trigger_for_step(self, step):
    _ = step
    return False

  def update_last_triggered_step(self, step):
    _ = step
    return (None, None)

  def last_triggered_step(self):
    return None


@tf_export(v1=["train.LoggingTensorHook"])
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.

  The tensors will be printed to the log, with `INFO` severity. If you are not
  seeing the logs, you might want to add the following line after your imports:

  ```python
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  ```

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.

  @compatibility(TF2)
  Please check this [notebook][notebook] on how to migrate the API to TF2.

  [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb

  @end_compatibility
  """

  def __init__(self,
               tensors,
               every_n_iter=None,
               every_n_secs=None,
               at_end=False,
               formatter=None):
    """Initializes a `LoggingTensorHook`.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names, or
        `iterable` of tensors/tensor names.
      every_n_iter: `int`, print the values of `tensors` once every N local
        steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once every
        N seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
        provided.
      at_end: `bool` specifying whether to print the values of `tensors` at
        the end of the run.
      formatter: function, takes dict of `tag`->`Tensor` and returns a string.
        If `None` uses default printing all tensors.

    Raises:
      ValueError: if `every_n_iter` is non-positive.
    """
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    if not isinstance(tensors, dict):
      self._tag_order = tensors
      tensors = {item: item for item in tensors}
    else:
      self._tag_order = sorted(tensors.keys())
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
            every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given.
    self._current_tensors = {
        tag: _as_graph_element(tensor)
        for (tag, tensor) in self._tensors.items()
    }

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
    if self._formatter:
      logging.info(self._formatter(tensor_values))
    else:
      stats = []
      for tag in self._tag_order:
        stats.append("%s = %s" % (tag, tensor_values[tag]))
      if elapsed_secs is not None:
        logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
      else:
        logging.info("%s", ", ".join(stats))
    np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)


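# Illustrative usage (not part of the original module): a minimal sketch of
# wiring the hook into a graph-mode TF1 training loop. `loss_op`,
# `global_step`, and `train_op` are hypothetical placeholders.
#
#   tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
#   hook = LoggingTensorHook({"loss": loss_op, "step": global_step},
#                            every_n_iter=100)
#   with tf.compat.v1.train.MonitoredTrainingSession(hooks=[hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)
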

def get_or_create_steps_per_run_variable():
  """Gets or creates the steps_per_run variable.

  In Estimator, the user-provided computation, the model_fn, is wrapped
  inside a tf.while_loop for peak performance. The iterations of the loop are
  specified by this variable, which adjusts its value on the CPU after each
  device program execution and before the next execution.

  The purpose of using a variable, rather than a constant, is to allow
  Estimator to adapt the device training iterations according to the final
  steps specified by users. For example, if the user sets steps_per_run to 4
  and steps to 10 in Estimator.train(), the steps_per_run variable will have
  the following value before each training run.

  - 1st execution: steps_per_run = 4
  - 2nd execution: steps_per_run = 4
  - 3rd execution: steps_per_run = 2

  As model_fn increases the global step once per train_op invocation, the
  global step is 10 after all executions, matching the steps=10 input passed
  in by users.

  Returns:
    A TF non-trainable resource variable.

  Raises:
    RuntimeError: If multiple steps_per_run variables are found.
  """
  graph = ops.get_default_graph()
  collection_name = "{}_{}".format(_HOOKS, _STEPS_PER_RUN_VAR)
  steps_per_run_vars = graph.get_collection(collection_name)
  if len(steps_per_run_vars) == 1:
    return steps_per_run_vars[0]
  elif len(steps_per_run_vars) > 1:
    raise RuntimeError("Multiple steps_per_run_var in collection.")

  with variable_scope.variable_scope(_HOOKS, reuse=variable_scope.AUTO_REUSE):
    return variable_scope.get_variable(
        _STEPS_PER_RUN_VAR,
        initializer=init_ops.ones_initializer(),
        shape=[],
        dtype=dtypes.int32,
        trainable=False,
        collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
        use_resource=True)

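# Worked example (illustrative, not part of the original module): with
# steps_per_run=4 and a remaining budget of steps=10, each run executes
# min(remaining_steps, steps_per_run) steps, giving the 4, 4, 2 schedule
# described in the docstring above.
#
#   remaining = 10
#   for _ in range(3):
#     run_steps = min(remaining, 4)   # -> 4, 4, 2
#     remaining -= run_steps          # -> 6, 2, 0
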

class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None, steps_per_run=1):
    """Initializes a `MultiStepStopAtStepHook`.

    This hook requests stop after either a number of steps have been executed
    or a last step has been reached. Only one of the two options can be
    specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.

    In Estimator, the user-provided computation, the model_fn, is wrapped
    inside a tf.while_loop for peak performance. The steps_per_run variable
    determines the number of iterations of the loop before returning to the
    CPU.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.
      steps_per_run: Number of steps executed per run call.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    if steps_per_run is None or steps_per_run < 1:
      raise ValueError("steps_per_run should be greater than 0")
    self._num_steps = num_steps
    self._last_step = last_step
    self._steps_per_run_initial_value = steps_per_run

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")
    self._steps_per_run_variable = get_or_create_steps_per_run_variable()

  def _update_steps_per_run_variable(self, global_step, session):
    steps = min(self._last_step - global_step,
                self._steps_per_run_initial_value)
    self._steps_per_run_variable.load(steps, session=session)

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = global_step + self._num_steps
    self._update_steps_per_run_variable(global_step, session)

  def after_run(self, run_context, run_values):
    # Global step cannot be retrieved via SessionRunArgs and before_run due to
    # a race condition in hook execution.
    global_step = run_context.session.run(self._global_step_tensor)
    if global_step >= self._last_step:
      run_context.request_stop()
    else:
      self._update_steps_per_run_variable(global_step, run_context.session)


@tf_export(v1=["train.StopAtStepHook"])
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step.

  @compatibility(TF2)
  Please check this [notebook][notebook] on how to migrate the API to TF2.

  [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb

  @end_compatibility
  """

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    This hook requests stop after either a number of steps have been executed
    or a last step has been reached. Only one of the two options can be
    specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    if self._last_step is None:
      global_step = session.run(self._global_step_tensor)
      self._last_step = global_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check the latest global step to ensure that the targeted last step is
      # reached. The global_step read tensor holds the value of the global
      # step before the operation ran, so we cannot tell whether the current
      # session.run incremented the global step; re-read it here to check.
      step = run_context.session.run(self._global_step_tensor)
      if step >= self._last_step:
        run_context.request_stop()


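# Illustrative usage (not part of the original module): a minimal sketch,
# assuming a graph that already defines `train_op` and a global step created
# via tf.compat.v1.train.get_or_create_global_step().
#
#   hook = StopAtStepHook(num_steps=1000)
#   with tf.compat.v1.train.MonitoredTrainingSession(hooks=[hook]) as sess:
#     while not sess.should_stop():   # stops once 1000 steps have run
#       sess.run(train_op)
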

@tf_export(v1=["train.CheckpointSaverListener"])
class CheckpointSaverListener:
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook`
  is triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')
      if decided_to_stop_training():
        return True

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.estimator.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.compat.v1.train.MonitoredTrainingSession(
      chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply take some action after every
  checkpoint save. It is also possible for the listener to use its own
  schedule to act less frequently, e.g. based on global_step_value. In this
  case, implementors should implement the `end()` method to handle actions
  related to the last checkpoint save. But the listener should not act twice
  if `after_save()` already handled this last checkpoint save.

  A `CheckpointSaverListener` can request training to be stopped by returning
  True in `after_save`. Please note that, in a replicated distributed training
  setting, only `chief` should use this behavior. Otherwise each worker will
  do its own evaluation, which may be wasteful of resources.
  """

  def begin(self):
    pass

  def before_save(self, session, global_step_value):
    pass

  def after_save(self, session, global_step_value):
    pass

  def end(self, session, global_step_value):
    pass



@tf_export(v1=["train.CheckpointSaverHook"])
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None,
               save_graph_def=True):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used
        for callbacks that run immediately before or after this hook saves
        the checkpoint.
      save_graph_def: Whether to save the GraphDef and MetaGraphDef to
        `checkpoint_dir`. The GraphDef is saved after the session is created
        as `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
        `model.ckpt-*.meta`.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._listeners = listeners or []
    # Set a sufficiently high default so that it never skips checking the
    # actual global step counter -- unless the user overrides it with the
    # right value for steps_per_run.
    self._steps_per_run = 1000000
    self._save_graph_def = save_graph_def

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._save_graph_def:
      # Write the graph and saver_def once, when the session is created. We
      # cannot do this in begin, since we let other hooks change the graph and
      # add variables in begin. The graph is finalized after all begin calls.
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, "graph.pbtxt")
      saver_def = self._get_saver().saver_def if self._get_saver() else None
      graph = ops.get_default_graph()
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # Get the real value after the train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        if self._save(run_context.session, global_step):
          run_context.request_stop()

  def end(self, session):
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint, returns should_stop."""
    logging.info("Calling checkpoint listeners before saving checkpoint %d...",
                 step)
    for l in self._listeners:
      l.before_save(session, step)

    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._get_saver().save(session, self._save_path, global_step=step,
                           write_meta_graph=self._save_graph_def)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)
    logging.info("Calling checkpoint listeners after saving checkpoint %d...",
                 step)
    should_stop = False
    for l in self._listeners:
      if l.after_save(session, step):
        logging.info(
            "A CheckpointSaverListener requested that training be stopped. "
            "listener: {}".format(l))
        should_stop = True
    return should_stop

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]


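# Illustrative usage (not part of the original module): a minimal sketch,
# assuming a graph with a global step and a tf.compat.v1.train.Saver in the
# SAVERS collection (a MonitoredTrainingSession scaffold provides one by
# default). `train_op` is a hypothetical placeholder.
#
#   hook = CheckpointSaverHook(checkpoint_dir="/tmp/ckpts", save_steps=500)
#   with tf.compat.v1.train.MonitoredTrainingSession(
#       chief_only_hooks=[hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)   # a checkpoint is written every 500 global steps
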

@tf_export(v1=["train.StepCounterHook"])
class StepCounterHook(session_run_hook.SessionRunHook):
  """Hook that counts steps per second."""

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):

    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(
        every_steps=every_n_steps, every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._steps_per_run = 1

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[
          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
      ])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # Get the real value after the train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has been increased. Here, we do not use
    # the timer.last_triggered_step as the timer might record a different
    # global step value such that the comparison could be unreliable. For
    # simplicity, we just compare the stale_global_step with the previously
    # recorded version.
    if stale_global_step == self._last_global_step:
      # Log a warning the first 5 times we observe that the global step has
      # not been increased. For some Optimizers, the global step is not
      # increased each time by design. For example, SyncReplicaOptimizer
      # doesn't increase the global step in the worker's main train step.
      logging.log_first_n(
          logging.WARN,
          "It seems that global step (tf.train.get_global_step) has not "
          "been increased. Current value (could be stable): %s vs previous "
          "value: %s. You could increase the global step by passing "
          "tf.train.get_global_step() to Optimizer.apply_gradients or "
          "Optimizer.minimize.", 5, stale_global_step, self._last_global_step)

    self._last_global_step = stale_global_step


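# Illustrative sketch (not part of the original module): the rate reported
# above is simply elapsed_steps / elapsed_time between two timer triggers.
#
#   hook = StepCounterHook(every_n_steps=100, output_dir="/tmp/summaries")
#   # If 100 steps elapse in 2.5 seconds between triggers, the hook logs and
#   # records a value of 40 (100 / 2.5) under the "<global_step_op>/sec" tag.
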

@tf_export(v1=["train.NanLossDuringTrainingError"])
class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


@tf_export(v1=["train.NanTensorHook"])
class NanTensorHook(session_run_hook.SessionRunHook):
  """Monitors the loss tensor and stops training if loss is NaN.

  Can either fail with an exception or just stop training.
  """

  def __init__(self, loss_tensor, fail_on_nan_loss=True):
    """Initializes a `NanTensorHook`.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      fail_on_nan_loss: `bool`, whether to raise an exception when loss is
        NaN.
    """
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    if np.isnan(run_values.results):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we request stop without an exception.
        run_context.request_stop()


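# Illustrative usage (not part of the original module): a minimal sketch,
# assuming `loss` is the scalar loss tensor of a hypothetical graph.
#
#   hook = NanTensorHook(loss, fail_on_nan_loss=False)
#   with tf.compat.v1.train.MonitoredTrainingSession(hooks=[hook]) as sess:
#     while not sess.should_stop():   # stops quietly if loss becomes NaN
#       sess.run(train_op)
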

@tf_export(v1=["train.SummarySaverHook"])
class SummarySaverHook(session_run_hook.SessionRunHook):
  """Saves summaries every N steps."""

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Initializes a `SummarySaverHook`.

    Args:
      save_steps: `int`, save summaries every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int`, save summaries every N seconds.
      output_dir: `string`, the directory to save the summaries to. Only used
        if no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was
        passed, one will be created accordingly.
      scaffold: `Scaffold` to get summary_op if it's not provided.
      summary_op: `Tensor` of type `string` containing the serialized
        `Summary` protocol buffer, or a list of such `Tensor`s. These are most
        likely output by TF summary methods like `tf.compat.v1.summary.scalar`
        or `tf.compat.v1.summary.merge_all`. A single tensor can be passed in
        directly; more than one must be passed in as a list.

    Raises:
      ValueError: Exactly one of scaffold or summary_op should be set.
    """
    if ((scaffold is None and summary_op is None) or
        (scaffold is not None and summary_op is not None)):
      raise ValueError(
          "Exactly one of scaffold or summary_op must be provided.")
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      if self._get_summary_op() is not None:
        requests["summary"] = self._get_summary_op()

    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return

    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      if "summary" in run_values.results:
        for summary in run_values.results["summary"]:
          self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    if self._summary_writer:
      self._summary_writer.flush()

  def _get_summary_op(self):
    """Fetches the summary op either from self._summary_op or self._scaffold.

    Returns:
      Returns a list of summary `Tensor`.
    """
    summary_op = None
    if self._summary_op is not None:
      summary_op = self._summary_op
    elif self._scaffold.summary_op is not None:
      summary_op = self._scaffold.summary_op

    if summary_op is None:
      return None

    if not isinstance(summary_op, list):
      return [summary_op]
    return summary_op


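# Illustrative usage (not part of the original module): a minimal sketch,
# assuming the graph defines summaries via tf.compat.v1.summary.* calls.
# `loss` and `train_op` are hypothetical placeholders.
#
#   tf.compat.v1.summary.scalar("loss", loss)
#   hook = SummarySaverHook(
#       save_steps=100,
#       output_dir="/tmp/summaries",
#       summary_op=tf.compat.v1.summary.merge_all())
#   with tf.compat.v1.train.MonitoredTrainingSession(hooks=[hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)
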

@tf_export(v1=["train.GlobalStepWaiterHook"])
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
  """Delays execution until global step reaches `wait_until_step`.

  This hook delays execution until the global step reaches `wait_until_step`.
  It is used to gradually start workers in distributed settings. One example
  usage would be setting `wait_until_step=int(K*log(task_id+1))`, assuming
  that task_id=0 is the chief.
  """

  def __init__(self, wait_until_step):
    """Initializes a `GlobalStepWaiterHook`.

    Args:
      wait_until_step: an `int`, the global step to wait for before starting.
    """
    self._wait_until_step = wait_until_step

  def begin(self):
    self._worker_is_started = False
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use _GlobalStepWaiterHook.")

  def before_run(self, run_context):
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      if current_step - last_logged_step > 1000:
        logging.info(
            "Waiting for global step %d before starting training. "
            "Current step is %d.", self._wait_until_step, current_step)
        last_logged_step = current_step
      time.sleep(0.5)


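# Worked example (illustrative, not part of the original module): the
# staggered-start schedule suggested in the docstring above, with a
# hypothetical scale factor K=100.
#
#   import math
#   K = 100
#   for task_id in range(4):
#     wait = int(K * math.log(task_id + 1))   # -> 0, 69, 109, 138
#     # Each worker would then use GlobalStepWaiterHook(wait_until_step=wait).
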

@tf_export(v1=["train.FinalOpsHook"])
class FinalOpsHook(session_run_hook.SessionRunHook):
  """A hook which evaluates `Tensors` at the end of a session."""

  def __init__(self, final_ops, final_ops_feed_dict=None):
    """Initializes `FinalOpsHook` with ops to run at the end of the session.

    Args:
      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
        names to `Tensors`.
      final_ops_feed_dict: A feed dictionary to use when running `final_ops`.
    """
    self._final_ops = final_ops
    self._final_ops_feed_dict = final_ops_feed_dict
    self._final_ops_values = None

  @property
  def final_ops_values(self):
    return self._final_ops_values

  def end(self, session):
    if self._final_ops is not None:
      try:
        self._final_ops_values = session.run(
            self._final_ops, feed_dict=self._final_ops_feed_dict)
      except (errors.OutOfRangeError, StopIteration) as e:
        logging.warning(
            "An OutOfRangeError or StopIteration exception is raised by the "
            "code in FinalOpsHook. This typically means the ops run by the "
            "FinalOpsHook have a dependency back to some input source, which "
            "should not happen. For example, for metrics in "
            "tf.estimator.Estimator, all metrics functions return two ops: "
            "`value_op` and `update_op`. Estimator.evaluate calls the "
            "`update_op` for each batch of the data in the input source and, "
            "once it is exhausted, it calls the `value_op` to get the metric "
            "values. The `value_op` here should depend only on reading "
            "variables, rather than reading another batch from the input. "
            "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers "
            "another data read, which ends in OutOfRangeError/StopIteration. "
            "Please fix that.")
        raise e


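# Illustrative usage (not part of the original module): a minimal sketch,
# assuming `accuracy_value_op` is a hypothetical tensor that only reads
# variables (no input-pipeline dependency, per the warning above).
#
#   hook = FinalOpsHook(final_ops={"accuracy": accuracy_value_op})
#   with tf.compat.v1.train.MonitoredSession(hooks=[hook]) as sess:
#     ...
#   # After the session closes, the evaluated values are available:
#   print(hook.final_ops_values["accuracy"])
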

@tf_export(v1=["train.FeedFnHook"])
class FeedFnHook(session_run_hook.SessionRunHook):
  """Runs `feed_fn` and sets the `feed_dict` accordingly."""

  def __init__(self, feed_fn):
    """Initializes a `FeedFnHook`.

    Args:
      feed_fn: function that takes no arguments and returns `dict` of `Tensor`
        to feed.
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return session_run_hook.SessionRunArgs(
        fetches=None, feed_dict=self.feed_fn())


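# Illustrative usage (not part of the original module): a minimal sketch,
# assuming `x_placeholder` is a placeholder in a hypothetical graph and
# `next_batch()` yields numpy arrays.
#
#   def feed_fn():
#     return {x_placeholder: next_batch()}
#
#   hook = FeedFnHook(feed_fn)   # feed_fn is called before every run() call
#   with tf.compat.v1.train.MonitoredSession(hooks=[hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)
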

@tf_export(v1=["train.ProfilerHook"])
class ProfilerHook(session_run_hook.SessionRunHook):
  """Captures CPU/GPU profiling information every N steps or seconds.

  This produces files called "timeline-<step>.json", which are in Chrome
  Trace format.

  For more information see:
  https://github.com/catapult-project/catapult/blob/master/tracing/README.md
  """

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir="",
               show_dataflow=True,
               show_memory=False):
    """Initializes a hook that takes periodic profiling snapshots.

    The `options.run_metadata` argument of `tf.Session.run` is used to collect
    metadata about execution. This hook sets the metadata and dumps it in
    Chrome Trace format.

    Args:
      save_steps: `int`, save profile traces every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int` or `float`, save profile traces every N seconds.
      output_dir: `string`, the directory to save the profile traces to.
        Defaults to the current directory.
      show_dataflow: `bool`, if True, add flow events to the trace connecting
        producers and consumers of tensors.
      show_memory: `bool`, if True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
    """
    self._output_file = os.path.join(output_dir, "timeline-{}.json")
    self._file_writer = SummaryWriterCache.get(output_dir)
    self._show_dataflow = show_dataflow
    self._show_memory = show_memory
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)

  def begin(self):
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    self._request_summary = (
        self._next_step is not None and
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (
        config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
        if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results["global_step"]
    if self._next_step is None:
      # Update the timer so that it does not activate until N steps or seconds
      # have passed.
      self._timer.update_last_triggered_step(stale_global_step)
    global_step = stale_global_step + 1
    if self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step, self._output_file.format(global_step),
                 run_values.run_metadata.step_stats)
      self._file_writer.add_run_metadata(run_values.run_metadata,
                                         "step_%d" % global_step)

    self._next_step = global_step + 1

  def _save(self, step, save_path, step_stats):
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(
          trace.generate_chrome_trace_format(
              show_dataflow=self._show_dataflow,
              show_memory=self._show_memory))


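# Illustrative usage (not part of the original module): a minimal sketch. The
# resulting timeline-<step>.json files can be loaded in Chrome via
# chrome://tracing.
#
#   hook = ProfilerHook(save_steps=1000, output_dir="/tmp/profiles")
#   with tf.compat.v1.train.MonitoredTrainingSession(hooks=[hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)
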

def _as_graph_element(obj):
  """Retrieves Graph element."""
  graph = ops.get_default_graph()
  if not isinstance(obj, str):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have graph attribute that is equal "
                       "to current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (i.e. it is a single-output operation).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element