Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py: 27%

474 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# pylint: disable=g-bad-file-header 

2# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15# ============================================================================== 

16"""A wrapper of Session API which runs hooks.""" 

17 

18import abc 

19import os 

20 

21from tensorflow.core.protobuf import config_pb2 

22from tensorflow.python.checkpoint import checkpoint as trackable_util 

23from tensorflow.python.checkpoint import graph_view 

24from tensorflow.python.distribute import distribute_coordinator_context 

25from tensorflow.python.framework import errors 

26from tensorflow.python.framework import ops 

27from tensorflow.python.ops import array_ops 

28from tensorflow.python.ops import control_flow_ops 

29from tensorflow.python.ops import lookup_ops 

30from tensorflow.python.ops import resources 

31from tensorflow.python.ops import variables 

32from tensorflow.python.platform import tf_logging as logging 

33from tensorflow.python.summary import summary 

34from tensorflow.python.training import basic_session_run_hooks 

35from tensorflow.python.training import coordinator 

36from tensorflow.python.training import queue_runner 

37from tensorflow.python.training import saver as training_saver 

38from tensorflow.python.training import session_manager as sm 

39from tensorflow.python.training import session_run_hook 

40from tensorflow.python.util import function_utils 

41from tensorflow.python.util.tf_export import tf_export 

42 

43# The list of exceptions that we should recover from. Exceptions not in this 

44# list may terminate the job. 

45_PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError) 

46 

47# Value that indicates no value was provided. 

48USE_DEFAULT = object() 

49 

50 

51@tf_export(v1=['train.Scaffold']) 

52class Scaffold: 

53 """Structure to create or gather pieces commonly needed to train a model. 

54 

55 When you build a model for training you usually need ops to initialize 

56 variables, a `Saver` to checkpoint them, an op to collect summaries for 

57 the visualizer, and so on. 

58 

59 Various libraries built on top of the core TensorFlow library take care of 

60 creating some or all of these pieces and storing them in well known 

61 collections in the graph. The `Scaffold` class helps pick these pieces from 

62 the graph collections, creating and adding them to the collections if needed. 

63 

64 If you call the scaffold constructor without any arguments, it will pick 

65 pieces from the collections, creating default ones if needed when 

66 `scaffold.finalize()` is called. You can pass arguments to the constructor to 

67 provide your own pieces. Pieces that you pass to the constructor are not 

68 added to the graph collections. 

69 

70 The following pieces are directly accessible as attributes of the `Scaffold` 

71 object: 

72 

73 * `saver`: A `tf.compat.v1.train.Saver` object taking care of saving the 

74 variables. 

75 Picked from and stored into the `SAVERS` collection in the graph by default. 

76 * `init_op`: An op to run to initialize the variables. Picked from and 

77 stored into the `INIT_OP` collection in the graph by default. 

78 * `ready_op`: An op to verify that the variables are initialized. Picked 

79 from and stored into the `READY_OP` collection in the graph by default. 

80 * `ready_for_local_init_op`: An op to verify that global state has been 

81 initialized and it is alright to run `local_init_op`. Picked from and 

82 stored into the `READY_FOR_LOCAL_INIT_OP` collection in the graph by 

83 default. This is needed when the initialization of local variables depends 

84 on the values of global variables. 

85 * `local_init_op`: An op to initialize the local variables. Picked 

86 from and stored into the `LOCAL_INIT_OP` collection in the graph by default. 

87 * `summary_op`: An op to run and merge the summaries in the graph. Picked 

88 from and stored into the `SUMMARY_OP` collection in the graph by default. 

89 

90 You can also pass the following additional pieces to the constructor: 

91 

92 * `init_feed_dict`: A session feed dictionary that should be used when 

93 running the init op. 

94 * `init_fn`: A callable to run after the init op to perform additional 

95 initializations. The callable will be called as 

96 `init_fn(scaffold, session)`. 

97 

98 """ 

99 

100 def __init__(self, 

101 init_op=None, 

102 init_feed_dict=None, 

103 init_fn=None, 

104 ready_op=None, 

105 ready_for_local_init_op=None, 

106 local_init_op=None, 

107 summary_op=None, 

108 saver=None, 

109 copy_from_scaffold=None, 

110 local_init_feed_dict=None): 

111 """Create a scaffold. 

112 

113 Args: 

114 init_op: Optional op for initializing variables. 

115 init_feed_dict: Optional session feed dictionary to use when running the 

116 init_op. 

117 init_fn: Optional function to use to initialize the model after running 

118 the init_op. Will be called as `init_fn(scaffold, session)`. 

119 ready_op: Optional op to verify that the variables are initialized. Must 

120 return an empty 1D string tensor when the variables are initialized, or 

121 a non-empty 1D string tensor listing the names of the non-initialized 

122 variables. 

123 ready_for_local_init_op: Optional op to verify that the global variables 

124 are initialized and `local_init_op` can be run. Must return an empty 1D 

125 string tensor when the global variables are initialized, or a non-empty 

126 1D string tensor listing the names of the non-initialized global 

127 variables. 

128 local_init_op: Optional op to initialize local variables. 

129 summary_op: Optional op to gather all summaries. Must return a scalar 

130 string tensor containing a serialized `Summary` proto. 

131 saver: Optional `tf.compat.v1.train.Saver` object to use to save and 

132 restore variables. May also be a `tf.train.Checkpoint` object, in which 

133 case object-based checkpoints are saved. This will also load some 

134 object-based checkpoints saved from elsewhere, but that loading may be 

135 fragile since it uses fixed keys rather than performing a full 

136 graph-based match. For example if a variable has two paths from the 

137 `Checkpoint` object because two `Model` objects share the `Layer` object 

138 that owns it, removing one `Model` may change the keys and break 

139 checkpoint loading through this API, whereas a graph-based match would 

140 match the variable through the other `Model`. 

141 copy_from_scaffold: Optional scaffold object to copy fields from. Its 

142 fields will be overwritten by the provided fields in this function. 

143 local_init_feed_dict: Optional session feed dictionary to use when running 

144 the local_init_op. 

145 """ 

146 if copy_from_scaffold is not None: 

147 if not isinstance(copy_from_scaffold, Scaffold): 

148 raise TypeError('copy_from_scaffold is not a Scaffold instance.') 

149 # We need _coalesce since Tensor is not converted to bool automatically, 

150 # so the common idiom of (a or b) does not work. 

151 coalesce = lambda a, b: a if a is not None else b 

152 init_op = coalesce(init_op, copy_from_scaffold.init_op) 

153 init_feed_dict = coalesce(init_feed_dict, 

154 copy_from_scaffold.init_feed_dict) 

155 # Use the original init_fn provided by the user to init the new Scaffold. 

156 init_fn = coalesce(init_fn, copy_from_scaffold._user_init_fn) # pylint: disable=protected-access 

157 ready_op = coalesce(ready_op, copy_from_scaffold.ready_op) 

158 ready_for_local_init_op = coalesce( 

159 ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op) 

160 local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op) 

161 local_init_feed_dict = coalesce(local_init_feed_dict, 

162 copy_from_scaffold.local_init_feed_dict) 

163 summary_op = coalesce(summary_op, copy_from_scaffold.summary_op) 

164 saver = coalesce(saver, copy_from_scaffold.saver) 

165 

166 # NOTE(touts): modifying the init function to be passed the scaffold is a 

167 # hack to make it easy to find the saver. Is there a better way? 

168 self._user_init_fn = init_fn 

169 if init_fn: 

170 self._init_fn = lambda sess: init_fn(self, sess) 

171 else: 

172 self._init_fn = None 

173 

174 self._init_op = init_op 

175 self._init_feed_dict = init_feed_dict 

176 self._ready_op = ready_op 

177 self._ready_for_local_init_op = ready_for_local_init_op 

178 self._local_init_op = local_init_op 

179 self._local_init_feed_dict = local_init_feed_dict 

180 self._summary_op = summary_op 

181 self._saver = saver 

182 

183 def finalize(self): 

184 """Creates operations if needed and finalizes the graph.""" 

185 if self._init_op is None: 

186 

187 def default_init_op(): 

188 return control_flow_ops.group( 

189 variables.global_variables_initializer(), 

190 resources.initialize_resources(resources.shared_resources()), 

191 ops.get_collection('saved_model_initializers')) 

192 

193 self._init_op = Scaffold.get_or_default('init_op', ops.GraphKeys.INIT_OP, 

194 default_init_op) 

195 if self._ready_op is None: 

196 

197 def default_ready_op(): 

198 return array_ops.concat([ 

199 variables.report_uninitialized_variables(), 

200 resources.report_uninitialized_resources() 

201 ], 0) 

202 

203 self._ready_op = Scaffold.get_or_default('ready_op', 

204 ops.GraphKeys.READY_OP, 

205 default_ready_op) 

206 if self._ready_for_local_init_op is None: 

207 

208 def default_ready_for_local_init_op(): 

209 return array_ops.concat([ 

210 variables.report_uninitialized_variables( 

211 variables.global_variables()), 

212 resources.report_uninitialized_resources( 

213 resources.shared_resources()) 

214 ], 0) 

215 

216 self._ready_for_local_init_op = Scaffold.get_or_default( 

217 'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP, 

218 default_ready_for_local_init_op) 

219 if self._local_init_op is None: 

220 self._local_init_op = Scaffold.get_or_default( 

221 'local_init_op', ops.GraphKeys.LOCAL_INIT_OP, 

222 Scaffold.default_local_init_op) 

223 if self._summary_op is None: 

224 self._summary_op = Scaffold.get_or_default('summary_op', 

225 ops.GraphKeys.SUMMARY_OP, 

226 summary.merge_all) 

227 # pylint: disable=g-long-lambda 

228 if self._saver is None: 

229 self._saver = training_saver._get_saver_or_default() # pylint: disable=protected-access 

230 # pylint: enable=g-long-lambda 

231 if isinstance(self._saver, trackable_util.Checkpoint): 

232 self._saver = training_saver.Saver( 

233 var_list=graph_view.ObjectGraphView( 

234 self._saver).frozen_saveable_objects(), 

235 sharded=True) 

236 else: 

237 self._saver.build() 

238 

239 ops.get_default_graph().finalize() 

240 logging.info('Graph was finalized.') 

241 return self 

242 

243 @property 

244 def init_fn(self): 

245 return self._init_fn 

246 

247 @property 

248 def init_op(self): 

249 return self._init_op 

250 

251 @property 

252 def ready_op(self): 

253 return self._ready_op 

254 

255 @property 

256 def ready_for_local_init_op(self): 

257 return self._ready_for_local_init_op 

258 

259 @property 

260 def local_init_op(self): 

261 return self._local_init_op 

262 

263 @property 

264 def local_init_feed_dict(self): 

265 return self._local_init_feed_dict 

266 

267 @property 

268 def summary_op(self): 

269 return self._summary_op 

270 

271 @property 

272 def saver(self): 

273 return self._saver 

274 

275 @property 

276 def init_feed_dict(self): 

277 return self._init_feed_dict 

278 

279 @staticmethod 

280 def get_or_default(arg_name, collection_key, default_constructor): 

281 """Get from cache or create a default operation.""" 

282 elements = ops.get_collection(collection_key) 

283 if elements: 

284 if len(elements) > 1: 

285 raise RuntimeError( 

286 'More than one item in the collection "%s". ' 

287 'Please indicate which one to use by passing it to ' 

288 'the tf.Scaffold constructor as: ' 

289 'tf.Scaffold(%s=item to use)', collection_key, arg_name) 

290 return elements[0] 

291 op = default_constructor() 

292 if op is not None: 

293 ops.add_to_collection(collection_key, op) 

294 return op 

295 

296 @staticmethod 

297 def default_local_init_op(): 

298 """Returns an op that groups the default local init ops. 

299 

300 This op is used during session initialization when a Scaffold is 

301 initialized without specifying the local_init_op arg. It includes 

302 `tf.compat.v1.local_variables_initializer`, 

303 `tf.compat.v1.tables_initializer`, and also 

304 initializes local session resources. 

305 

306 Returns: 

307 The default Scaffold local init op. 

308 """ 

309 return control_flow_ops.group( 

310 variables.local_variables_initializer(), 

311 lookup_ops.tables_initializer(), 

312 resources.initialize_resources(resources.local_resources())) 

313 

314 

315def _create_monitored_session_with_worker_context( 

316 worker_context, # pylint: disable=missing-docstring 

317 scaffold, 

318 checkpoint_dir=None, 

319 hooks=None, 

320 chief_only_hooks=None, 

321 save_checkpoint_secs=None, 

322 save_summaries_steps=None, 

323 save_summaries_secs=None, 

324 config=None, 

325 stop_grace_period_secs=120, 

326 log_step_count_steps=100, 

327 max_wait_secs=7200, 

328 save_checkpoint_steps=None, 

329 summary_dir=None, 

330 save_graph_def=True): 

331 all_hooks = [] 

332 if hooks: 

333 all_hooks.extend(hooks) 

334 if chief_only_hooks and worker_context.is_chief: 

335 all_hooks.extend(chief_only_hooks) 

336 

337 # We need to call save or summary ops on all workers since these ops may 

338 # contain collective ops, only running save ops on some workers would make 

339 # collective ops hang. Therefore on those workers that don't need to actually 

340 # write checkpoints or summaries, we let them write to a temp directory. 

341 # pylint: disable=protected-access 

342 if type( 

343 worker_context._strategy).__name__ in ('CollectiveAllReduceStrategy', 

344 'CollectiveAllReduceStrategyV1', 

345 'MultiWorkerMirroredStrategy'): 

346 if worker_context.task_type: 

347 tmpdir = 'tmp_%s_%d' % (worker_context.task_type, worker_context.task_id) 

348 else: 

349 tmpdir = 'tmp' 

350 

351 if save_checkpoint_secs: 

352 logging.warning('Collective ops may deadlock with ' 

353 '`save_checkpoints_secs` please use ' 

354 '`save_checkpoint_steps` instead. Clearing ' 

355 '`save_checkpoint_secs` and setting ' 

356 '`save_checkpoint_steps` to 1000 now.') 

357 save_checkpoint_secs = None 

358 save_checkpoint_steps = 1000 

359 if save_summaries_secs: 

360 logging.warning('Collective ops may run out of sync with' 

361 '`save_summaries_secs`, please use ' 

362 '`save_summaries_steps` instead.') 

363 else: 

364 tmpdir = None 

365 

366 summary_dir = summary_dir or checkpoint_dir 

367 if summary_dir and log_step_count_steps and log_step_count_steps > 0: 

368 if worker_context.should_save_summary: 

369 all_hooks.append( 

370 basic_session_run_hooks.StepCounterHook( 

371 output_dir=summary_dir, every_n_steps=log_step_count_steps)) 

372 elif tmpdir: 

373 all_hooks.append( 

374 basic_session_run_hooks.StepCounterHook( 

375 output_dir=os.path.join(summary_dir, tmpdir), 

376 every_n_steps=log_step_count_steps)) 

377 

378 if (((save_summaries_steps and save_summaries_steps > 0) or 

379 (save_summaries_secs and save_summaries_secs > 0)) and summary_dir): 

380 if worker_context.should_save_summary: 

381 all_hooks.append( 

382 basic_session_run_hooks.SummarySaverHook( 

383 scaffold=scaffold, 

384 save_steps=save_summaries_steps, 

385 save_secs=save_summaries_secs, 

386 output_dir=summary_dir)) 

387 elif tmpdir: 

388 all_hooks.append( 

389 basic_session_run_hooks.SummarySaverHook( 

390 scaffold=scaffold, 

391 save_steps=save_summaries_steps, 

392 save_secs=save_summaries_secs, 

393 output_dir=os.path.join(summary_dir, tmpdir))) 

394 

395 if (((save_checkpoint_secs and save_checkpoint_secs > 0) or 

396 (save_checkpoint_steps and save_checkpoint_steps > 0)) and 

397 checkpoint_dir): 

398 if worker_context.should_checkpoint: 

399 all_hooks.append( 

400 basic_session_run_hooks.CheckpointSaverHook( 

401 checkpoint_dir, 

402 save_steps=save_checkpoint_steps, 

403 save_secs=save_checkpoint_secs, 

404 scaffold=scaffold, 

405 save_graph_def=save_graph_def)) 

406 elif tmpdir: 

407 all_hooks.append( 

408 basic_session_run_hooks.CheckpointSaverHook( 

409 os.path.join(checkpoint_dir, tmpdir), 

410 save_steps=save_checkpoint_steps, 

411 save_secs=save_checkpoint_secs, 

412 scaffold=scaffold, 

413 save_graph_def=save_graph_def)) 

414 

415 logging.info('all_hooks %r', all_hooks) 

416 session_creator = worker_context.session_creator( 

417 scaffold, 

418 config=config, 

419 checkpoint_dir=checkpoint_dir, 

420 max_wait_secs=max_wait_secs) 

421 return MonitoredSession( 

422 session_creator=session_creator, 

423 hooks=all_hooks, 

424 stop_grace_period_secs=stop_grace_period_secs) 

425 

426 

427@tf_export(v1=['train.MonitoredTrainingSession']) 

428def MonitoredTrainingSession( 

429 master='', # pylint: disable=invalid-name 

430 is_chief=True, 

431 checkpoint_dir=None, 

432 scaffold=None, 

433 hooks=None, 

434 chief_only_hooks=None, 

435 save_checkpoint_secs=USE_DEFAULT, 

436 save_summaries_steps=USE_DEFAULT, 

437 save_summaries_secs=USE_DEFAULT, 

438 config=None, 

439 stop_grace_period_secs=120, 

440 log_step_count_steps=100, 

441 max_wait_secs=7200, 

442 save_checkpoint_steps=USE_DEFAULT, 

443 summary_dir=None, 

444 save_graph_def=True): 

445 """Creates a `MonitoredSession` for training. 

446 

447 For a chief, this utility sets proper session initializer/restorer. It also 

448 creates hooks related to checkpoint and summary saving. For workers, this 

449 utility sets proper session creator which waits for the chief to 

450 initialize/restore. Please check `tf.compat.v1.train.MonitoredSession` for 

451 more 

452 information. 

453 

454 @compatibility(TF2) 

455 This API is not compatible with eager execution and `tf.function`. To migrate 

456 to TF2, rewrite the code to be compatible with eager execution. Check the 

457 [migration 

458 guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls) 

459 on replacing `Session.run` calls. In Keras, session hooks can be replaced by 

460 Callbacks e.g. [logging hook notebook]( 

461 https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb) 

462 For more details please read [Better 

463 performance with tf.function](https://www.tensorflow.org/guide/function). 

464 @end_compatibility 

465 

466 Args: 

467 master: `String` the TensorFlow master to use. 

468 is_chief: If `True`, it will take care of initialization and recovery the 

469 underlying TensorFlow session. If `False`, it will wait on a chief to 

470 initialize or recover the TensorFlow session. 

471 checkpoint_dir: A string. Optional path to a directory where to restore 

472 variables. 

473 scaffold: A `Scaffold` used for gathering or building supportive ops. If not 

474 specified, a default one is created. It's used to finalize the graph. 

475 hooks: Optional list of `SessionRunHook` objects. 

476 chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks if 

477 `is_chief==True`, ignore otherwise. 

478 save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved 

479 using a default checkpoint saver. If both `save_checkpoint_steps` and 

480 `save_checkpoint_secs` are set to `None`, then the default checkpoint 

481 saver isn't used. If both are provided, then only `save_checkpoint_secs` 

482 is used. Default 600. 

483 save_summaries_steps: The frequency, in number of global steps, that the 

484 summaries are written to disk using a default summary saver. If both 

485 `save_summaries_steps` and `save_summaries_secs` are set to `None`, then 

486 the default summary saver isn't used. Default 100. 

487 save_summaries_secs: The frequency, in secs, that the summaries are written 

488 to disk using a default summary saver. If both `save_summaries_steps` and 

489 `save_summaries_secs` are set to `None`, then the default summary saver 

490 isn't used. Default not enabled. 

491 config: an instance of `tf.compat.v1.ConfigProto` proto used to configure 

492 the session. It's the `config` argument of constructor of 

493 `tf.compat.v1.Session`. 

494 stop_grace_period_secs: Number of seconds given to threads to stop after 

495 `close()` has been called. 

496 log_step_count_steps: The frequency, in number of global steps, that the 

497 global step/sec is logged. 

498 max_wait_secs: Maximum time workers should wait for the session to become 

499 available. This should be kept relatively short to help detect incorrect 

500 code, but sometimes may need to be increased if the chief takes a while to 

501 start up. 

502 save_checkpoint_steps: The frequency, in number of global steps, that a 

503 checkpoint is saved using a default checkpoint saver. If both 

504 `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then 

505 the default checkpoint saver isn't used. If both are provided, then only 

506 `save_checkpoint_secs` is used. Default not enabled. 

507 summary_dir: A string. Optional path to a directory where to save 

508 summaries. If None, checkpoint_dir is used instead. 

509 save_graph_def: Whether to save the GraphDef and MetaGraphDef to 

510 `checkpoint_dir`. The GraphDef is saved after the session is created as 

511 `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as 

512 `model.ckpt-*.meta`. 

513 

514 Returns: 

515 A `MonitoredSession` object. 

516 """ 

517 if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT: 

518 save_summaries_steps = 100 

519 save_summaries_secs = None 

520 elif save_summaries_secs == USE_DEFAULT: 

521 save_summaries_secs = None 

522 elif save_summaries_steps == USE_DEFAULT: 

523 save_summaries_steps = None 

524 

525 if (save_checkpoint_steps == USE_DEFAULT and 

526 save_checkpoint_secs == USE_DEFAULT): 

527 save_checkpoint_steps = None 

528 save_checkpoint_secs = 600 

529 elif save_checkpoint_secs == USE_DEFAULT: 

530 save_checkpoint_secs = None 

531 elif save_checkpoint_steps == USE_DEFAULT: 

532 save_checkpoint_steps = None 

533 

534 scaffold = scaffold or Scaffold() 

535 worker_context = distribute_coordinator_context.get_current_worker_context() 

536 

537 if worker_context: 

538 return _create_monitored_session_with_worker_context( 

539 worker_context, 

540 scaffold, 

541 checkpoint_dir=checkpoint_dir, 

542 hooks=hooks, 

543 chief_only_hooks=chief_only_hooks, 

544 save_checkpoint_secs=save_checkpoint_secs, 

545 save_summaries_steps=save_summaries_steps, 

546 save_summaries_secs=save_summaries_secs, 

547 config=config, 

548 stop_grace_period_secs=stop_grace_period_secs, 

549 log_step_count_steps=log_step_count_steps, 

550 max_wait_secs=max_wait_secs, 

551 save_checkpoint_steps=save_checkpoint_steps, 

552 summary_dir=summary_dir, 

553 save_graph_def=save_graph_def) 

554 

555 if not is_chief: 

556 session_creator = WorkerSessionCreator( 

557 scaffold=scaffold, 

558 master=master, 

559 config=config, 

560 max_wait_secs=max_wait_secs) 

561 return MonitoredSession( 

562 session_creator=session_creator, 

563 hooks=hooks or [], 

564 stop_grace_period_secs=stop_grace_period_secs) 

565 

566 all_hooks = [] 

567 if chief_only_hooks: 

568 all_hooks.extend(chief_only_hooks) 

569 session_creator = ChiefSessionCreator( 

570 scaffold=scaffold, 

571 checkpoint_dir=checkpoint_dir, 

572 master=master, 

573 config=config) 

574 

575 summary_dir = summary_dir or checkpoint_dir 

576 if summary_dir: 

577 if log_step_count_steps and log_step_count_steps > 0: 

578 all_hooks.append( 

579 basic_session_run_hooks.StepCounterHook( 

580 output_dir=summary_dir, every_n_steps=log_step_count_steps)) 

581 

582 if (save_summaries_steps and 

583 save_summaries_steps > 0) or (save_summaries_secs and 

584 save_summaries_secs > 0): 

585 all_hooks.append( 

586 basic_session_run_hooks.SummarySaverHook( 

587 scaffold=scaffold, 

588 save_steps=save_summaries_steps, 

589 save_secs=save_summaries_secs, 

590 output_dir=summary_dir)) 

591 

592 if checkpoint_dir: 

593 if (save_checkpoint_secs and 

594 save_checkpoint_secs > 0) or (save_checkpoint_steps and 

595 save_checkpoint_steps > 0): 

596 all_hooks.append( 

597 basic_session_run_hooks.CheckpointSaverHook( 

598 checkpoint_dir, 

599 save_steps=save_checkpoint_steps, 

600 save_secs=save_checkpoint_secs, 

601 scaffold=scaffold, 

602 save_graph_def=save_graph_def)) 

603 

604 if hooks: 

605 all_hooks.extend(hooks) 

606 return MonitoredSession( 

607 session_creator=session_creator, 

608 hooks=all_hooks, 

609 stop_grace_period_secs=stop_grace_period_secs) 

610 

611 

612@tf_export(v1=['train.SessionCreator']) 

613class SessionCreator(metaclass=abc.ABCMeta): 

614 """A factory for tf.Session.""" 

615 

616 @abc.abstractmethod 

617 def create_session(self): 

618 raise NotImplementedError( 

619 'create_session is not implemented for {}.'.format(self)) 

620 

621 

622@tf_export(v1=['train.ChiefSessionCreator']) 

623class ChiefSessionCreator(SessionCreator): 

624 """Creates a tf.compat.v1.Session for a chief.""" 

625 

626 def __init__(self, 

627 scaffold=None, 

628 master='', 

629 config=None, 

630 checkpoint_dir=None, 

631 checkpoint_filename_with_path=None): 

632 """Initializes a chief session creator. 

633 

634 Args: 

635 scaffold: A `Scaffold` used for gathering or building supportive ops. If 

636 not specified a default one is created. It's used to finalize the graph. 

637 master: `String` representation of the TensorFlow master to use. 

638 config: `ConfigProto` proto used to configure the session. 

639 checkpoint_dir: A string. Optional path to a directory where to restore 

640 variables. 

641 checkpoint_filename_with_path: Full file name path to the checkpoint file. 

642 """ 

643 self._checkpoint_dir = checkpoint_dir 

644 self._checkpoint_filename_with_path = checkpoint_filename_with_path 

645 self._scaffold = scaffold or Scaffold() 

646 self._session_manager = None 

647 self._master = master 

648 self._config = config 

649 

650 def _get_session_manager(self): 

651 """Gets or creates a SessionManager.""" 

652 if self._session_manager: 

653 return self._session_manager 

654 

655 self._session_manager = sm.SessionManager( 

656 local_init_op=self._scaffold.local_init_op, 

657 local_init_feed_dict=self._scaffold.local_init_feed_dict, 

658 ready_op=self._scaffold.ready_op, 

659 ready_for_local_init_op=self._scaffold.ready_for_local_init_op, 

660 graph=ops.get_default_graph()) 

661 return self._session_manager 

662 

663 def create_session(self): 

664 self._scaffold.finalize() 

665 return self._get_session_manager().prepare_session( 

666 self._master, 

667 saver=self._scaffold.saver, 

668 checkpoint_dir=self._checkpoint_dir, 

669 checkpoint_filename_with_path=self._checkpoint_filename_with_path, 

670 config=self._config, 

671 init_op=self._scaffold.init_op, 

672 init_feed_dict=self._scaffold.init_feed_dict, 

673 init_fn=self._scaffold.init_fn) 

674 

675 

676@tf_export(v1=['train.WorkerSessionCreator']) 

677class WorkerSessionCreator(SessionCreator): 

678 """Creates a tf.compat.v1.Session for a worker.""" 

679 

680 def __init__(self, 

681 scaffold=None, 

682 master='', 

683 config=None, 

684 max_wait_secs=30 * 60): 

685 """Initializes a worker session creator. 

686 

687 Args: 

688 scaffold: A `Scaffold` used for gathering or building supportive ops. If 

689 not specified a default one is created. It's used to finalize the graph. 

690 master: `String` representation of the TensorFlow master to use. 

691 config: `ConfigProto` proto used to configure the session. 

692 max_wait_secs: Maximum time to wait for the session to become available. 

693 """ 

694 self._scaffold = scaffold or Scaffold() 

695 self._session_manager = None 

696 self._master = master 

697 self._config = config 

698 self._max_wait_secs = max_wait_secs 

699 

700 def _get_session_manager(self): 

701 """Gets or creates a SessionManager.""" 

702 if self._session_manager: 

703 return self._session_manager 

704 

705 self._session_manager = sm.SessionManager( 

706 local_init_op=self._scaffold.local_init_op, 

707 local_init_feed_dict=self._scaffold.local_init_feed_dict, 

708 ready_op=self._scaffold.ready_op, 

709 ready_for_local_init_op=self._scaffold.ready_for_local_init_op, 

710 graph=ops.get_default_graph()) 

711 return self._session_manager 

712 

713 def create_session(self): 

714 self._scaffold.finalize() 

715 return self._get_session_manager().wait_for_session( 

716 self._master, config=self._config, max_wait_secs=self._max_wait_secs) 

717 

718 

719class _MonitoredSession: 

720 """See `MonitoredSession` or `SingularMonitoredSession`.""" 

721 

722 def __init__(self, 

723 session_creator, 

724 hooks, 

725 should_recover, 

726 stop_grace_period_secs=120): 

727 """Sets up a Monitored or Hooked Session. 

728 

729 Args: 

730 session_creator: A factory object to create session. Typically a 

731 `ChiefSessionCreator` or a `WorkerSessionCreator`. 

732 hooks: An iterable of `SessionRunHook' objects. 

733 should_recover: A bool. Indicates whether to recover from `AbortedError` 

734 and `UnavailableError` or not. 

735 stop_grace_period_secs: Number of seconds given to threads to stop after 

736 `close()` has been called. 

737 """ 

738 self._graph_was_finalized = ops.get_default_graph().finalized 

739 self._hooks = hooks or [] 

740 for h in self._hooks: 

741 h.begin() 

742 

743 worker_context = distribute_coordinator_context.get_current_worker_context() 

744 if not session_creator and worker_context: 

745 session_creator = worker_context.session_creator() 

746 

747 # Create the session. 

748 self._coordinated_creator = self._CoordinatedSessionCreator( 

749 session_creator=session_creator or ChiefSessionCreator(), 

750 hooks=self._hooks, 

751 stop_grace_period_secs=stop_grace_period_secs) 

752 if should_recover: 

753 self._sess = _RecoverableSession(self._coordinated_creator) 

754 else: 

755 self._sess = self._coordinated_creator.create_session() 

756 

757 @property 

758 def graph(self): 

759 """The graph that was launched in this session.""" 

760 if self._tf_sess() is None: 

761 return None 

762 return self._tf_sess().graph 

763 

764 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 

765 """Run ops in the monitored session. 

766 

767 This method is completely compatible with the `tf.Session.run()` method. 

768 

769 Args: 

770 fetches: Same as `tf.Session.run()`. 

771 feed_dict: Same as `tf.Session.run()`. 

772 options: Same as `tf.Session.run()`. 

773 run_metadata: Same as `tf.Session.run()`. 

774 

775 Returns: 

776 Same as `tf.Session.run()`. 

777 """ 

778 return self._sess.run( 

779 fetches, 

780 feed_dict=feed_dict, 

781 options=options, 

782 run_metadata=run_metadata) 

783 

784 def run_step_fn(self, step_fn): 

785 """Run ops using a step function. 

786 

787 Args: 

788 step_fn: A function or a method with a single argument of type 

789 `StepContext`. The function may use methods of the argument to perform 

790 computations with access to a raw session. The returned value of the 

791 `step_fn` will be returned from `run_step_fn`, unless a stop is 

792 requested. In that case, the next `should_stop` call will return True. 

793 Example usage: 

794 ```python 

795 with tf.Graph().as_default(): 

796 c = tf.compat.v1.placeholder(dtypes.float32) 

797 v = tf.add(c, 4.0) 

798 w = tf.add(c, 0.5) 

799 def step_fn(step_context): 

800 a = step_context.session.run(fetches=v, feed_dict={c: 0.5}) 

801 if a <= 4.5: 

802 step_context.request_stop() 

803 return step_context.run_with_hooks(fetches=w, 

804 feed_dict={c: 0.1}) 

805 

806 with tf.MonitoredSession() as session: 

807 while not session.should_stop(): 

808 a = session.run_step_fn(step_fn) 

809 ``` 

810 Hooks interact with the `run_with_hooks()` call inside the 

811 `step_fn` as they do with a `MonitoredSession.run` call. 

812 

813 Returns: 

814 Returns the returned value of `step_fn`. 

815 

816 Raises: 

817 StopIteration: if `step_fn` has called `request_stop()`. It may be 

818 caught by `with tf.MonitoredSession()` to close the session. 

819 ValueError: if `step_fn` doesn't have a single argument called 

820 `step_context`. It may also optionally have `self` for cases when it 

821 belongs to an object. 

822 """ 

823 step_fn_arguments = function_utils.fn_args(step_fn) 

824 if step_fn_arguments != ('step_context',) and step_fn_arguments != ( 

825 'self', 

826 'step_context', 

827 ): 

828 raise ValueError( 

829 '`step_fn` may either have one `step_context` argument, or' 

830 ' `self` and `step_context` arguments if it\'s an instance' 

831 ' method. Got {} instead.'.format(step_fn_arguments)) 

832 

833 # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`. 

834 # Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be 

835 # `_CoordinatedSession.run` downstream in either case. This allows 

836 # `_PREEMPTION_ERRORS` to propage from within `step_fn` to 

837 # `_RecoverableSession.run_step_fn`. 

838 return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None) 

839 

840 class StepContext: 

841 """Control flow instrument for the `step_fn` from `run_step_fn()`. 

842 

843 Users of `step_fn` may perform `run()` calls without running hooks 

844 by accessing the `session`. A `run()` call with hooks may be performed 

845 using `run_with_hooks()`. Computation flow can be interrupted using 

846 `request_stop()`. 

847 """ 

848 

849 def __init__(self, session, run_with_hooks_fn): 

850 """Initializes the `step_context` argument for a `step_fn` invocation. 

851 

852 Args: 

853 session: An instance of `tf.compat.v1.Session`. 

854 run_with_hooks_fn: A function for running fetches and hooks. 

855 """ 

856 self._session = session 

857 self._run_with_hooks_fn = run_with_hooks_fn 

858 

859 @property 

860 def session(self): 

861 return self._session 

862 

863 def run_with_hooks(self, *args, **kwargs): 

864 """Same as `MonitoredSession.run`. Accepts the same arguments.""" 

865 return self._run_with_hooks_fn(*args, **kwargs) 

866 

867 def request_stop(self): 

868 """Exit the training loop by causing `should_stop()` to return `True`. 

869 

870 Causes `step_fn` to exit by raising an exception. 

871 

872 Raises: 

873 StopIteration 

874 """ 

875 raise StopIteration('step_fn has requested the iterations to stop.') 

876 

877 def should_stop(self): 

878 return self._sess is None or self._sess.should_stop() 

879 

880 def close(self): 

881 self._close_internal() 

882 

883 def __enter__(self): 

884 return self 

885 

886 def __exit__(self, exception_type, exception_value, traceback): 

887 if exception_type in [errors.OutOfRangeError, StopIteration]: 

888 exception_type = None 

889 self._close_internal(exception_type) 

890 # __exit__ should return True to suppress an exception. 

891 return exception_type is None 

892 

893 class _CoordinatedSessionCreator(SessionCreator): 

894 """Factory for _CoordinatedSession.""" 

895 

896 def __init__(self, session_creator, hooks, stop_grace_period_secs): 

897 self._session_creator = session_creator 

898 self._hooks = hooks 

899 self.coord = None 

900 self.tf_sess = None 

901 self._stop_grace_period_secs = stop_grace_period_secs 

902 

903 def create_session(self): 

904 """Creates a coordinated session.""" 

905 # Keep the tf_sess for unit testing. 

906 self.tf_sess = self._session_creator.create_session() 

907 # We don't want coordinator to suppress any exception. 

908 self.coord = coordinator.Coordinator(clean_stop_exception_types=[]) 

909 if ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS): 

910 queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord) 

911 # Inform the hooks that a new session has been created. 

912 for hook in self._hooks: 

913 hook.after_create_session(self.tf_sess, self.coord) 

914 return _CoordinatedSession( 

915 _HookedSession(self.tf_sess, self._hooks), self.coord, 

916 self._stop_grace_period_secs) 

917 

918 def _close_internal(self, exception_type=None): 

919 try: 

920 if not exception_type: 

921 for h in self._hooks: 

922 h.end(self._coordinated_creator.tf_sess) 

923 finally: 

924 try: 

925 if self._sess is None: 

926 raise RuntimeError('Session is already closed.') 

927 self._sess.close() 

928 finally: 

929 self._sess = None 

930 self._coordinated_creator.tf_sess = None 

931 self._coordinated_creator.coord = None 

932 if not self._graph_was_finalized: 

933 ops.get_default_graph()._unsafe_unfinalize() # pylint: disable=protected-access 

934 

935 def _is_closed(self): 

936 """Return True if the monitored session is closed. 

937 

938 For tests only. 

939 

940 Returns: 

941 A boolean. 

942 """ 

943 return self._coordinated_creator.tf_sess is None 

944 

945 def _tf_sess(self): 

946 """Return underlying tf.compat.v1.Session object. 

947 

948 Warning: accessing the returned object in user code is likely to cause races 

949 or "flaky tests". 

950 

951 Returns: 

952 A tf.compat.v1.Session object. 

953 """ 

954 return self._coordinated_creator.tf_sess 

955 

956 

957@tf_export(v1=['train.MonitoredSession']) 

958class MonitoredSession(_MonitoredSession): 

959 """Session-like object that handles initialization, recovery and hooks. 

960 

961 Example usage: 

962 

963 ```python 

964 saver_hook = CheckpointSaverHook(...) 

965 summary_hook = SummarySaverHook(...) 

966 with MonitoredSession(session_creator=ChiefSessionCreator(...), 

967 hooks=[saver_hook, summary_hook]) as sess: 

968 while not sess.should_stop(): 

969 sess.run(train_op) 

970 ``` 

971 

972 Initialization: At creation time the monitored session does following things 

973 in given order: 

974 

975 * calls `hook.begin()` for each given hook 

976 * finalizes the graph via `scaffold.finalize()` 

977 * create session 

978 * initializes the model via initialization ops provided by `Scaffold` 

979 * restores variables if a checkpoint exists 

980 * launches queue runners 

981 * calls `hook.after_create_session()` 

982 

983 Run: When `run()` is called, the monitored session does following things: 

984 

985 * calls `hook.before_run()` 

986 * calls TensorFlow `session.run()` with merged fetches and feed_dict 

987 * calls `hook.after_run()` 

988 * returns result of `session.run()` asked by user 

989 * if `AbortedError` or `UnavailableError` occurs, it recovers or 

990 reinitializes the session before executing the run() call again 

991 

992 

993 Exit: At the `close()`, the monitored session does following things in order: 

994 

995 * calls `hook.end()` 

996 * closes the queue runners and the session 

997 * suppresses `OutOfRange` error which indicates that all inputs have been 

998 processed if the monitored_session is used as a context 

999 

1000 How to set `tf.compat.v1.Session` arguments: 

1001 

1002 * In most cases you can set session arguments as follows: 

1003 

1004 ```python 

1005 MonitoredSession( 

1006 session_creator=ChiefSessionCreator(master=..., config=...)) 

1007 ``` 

1008 

1009 * In distributed setting for a non-chief worker, you can use following: 

1010 

1011 ```python 

1012 MonitoredSession( 

1013 session_creator=WorkerSessionCreator(master=..., config=...)) 

1014 ``` 

1015 

1016 See `MonitoredTrainingSession` for an example usage based on chief or worker. 

1017 

1018 Note: This is not a `tf.compat.v1.Session`. For example, it cannot do 

1019 following: 

1020 

1021 * it cannot be set as default session. 

1022 * it cannot be sent to saver.save. 

1023 * it cannot be sent to tf.train.start_queue_runners. 

1024 

1025 @compatibility(TF2) 

1026 This API is not compatible with eager execution and `tf.function`. To migrate 

1027 to TF2, rewrite the code to be compatible with eager execution. Check the 

1028 [migration 

1029 guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls) 

1030 on replacing `Session.run` calls. In Keras, session hooks can be replaced by 

1031 Callbacks e.g. [logging hook notebook]( 

1032 https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb) 

1033 For more details please read [Better 

1034 performance with tf.function](https://www.tensorflow.org/guide/function). 

1035 @end_compatibility 

1036 

1037 Args: 

1038 session_creator: A factory object to create session. Typically a 

1039 `ChiefSessionCreator` which is the default one. 

1040 hooks: An iterable of `SessionRunHook' objects. 

1041 

1042 Returns: 

1043 A MonitoredSession object. 

1044 """ 

1045 

1046 def __init__(self, 

1047 session_creator=None, 

1048 hooks=None, 

1049 stop_grace_period_secs=120): 

1050 super(MonitoredSession, self).__init__( 

1051 session_creator, 

1052 hooks, 

1053 should_recover=True, 

1054 stop_grace_period_secs=stop_grace_period_secs) 

1055 

1056 

1057@tf_export(v1=['train.SingularMonitoredSession']) 

1058class SingularMonitoredSession(_MonitoredSession): 

1059 """Session-like object that handles initialization, restoring, and hooks. 

1060 

1061 Please note that this utility is not recommended for distributed settings. 

1062 For distributed settings, please use `tf.compat.v1.train.MonitoredSession`. 

1063 The 

1064 differences between `MonitoredSession` and `SingularMonitoredSession` are: 

1065 

1066 * `MonitoredSession` handles `AbortedError` and `UnavailableError` for 

1067 distributed settings, but `SingularMonitoredSession` does not. 

1068 * `MonitoredSession` can be created in `chief` or `worker` modes. 

1069 `SingularMonitoredSession` is always created as `chief`. 

1070 * You can access the raw `tf.compat.v1.Session` object used by 

1071 `SingularMonitoredSession`, whereas in MonitoredSession the raw session is 

1072 private. This can be used: 

1073 - To `run` without hooks. 

1074 - To save and restore. 

1075 * All other functionality is identical. 

1076 

1077 Example usage: 

1078 ```python 

1079 saver_hook = CheckpointSaverHook(...) 

1080 summary_hook = SummarySaverHook(...) 

1081 with SingularMonitoredSession(hooks=[saver_hook, summary_hook]) as sess: 

1082 while not sess.should_stop(): 

1083 sess.run(train_op) 

1084 ``` 

1085 

1086 Initialization: At creation time the hooked session does following things 

1087 in given order: 

1088 

1089 * calls `hook.begin()` for each given hook 

1090 * finalizes the graph via `scaffold.finalize()` 

1091 * create session 

1092 * initializes the model via initialization ops provided by `Scaffold` 

1093 * restores variables if a checkpoint exists 

1094 * launches queue runners 

1095 

1096 Run: When `run()` is called, the hooked session does following things: 

1097 

1098 * calls `hook.before_run()` 

1099 * calls TensorFlow `session.run()` with merged fetches and feed_dict 

1100 * calls `hook.after_run()` 

1101 * returns result of `session.run()` asked by user 

1102 

1103 Exit: At the `close()`, the hooked session does following things in order: 

1104 

1105 * calls `hook.end()` 

1106 * closes the queue runners and the session 

1107 * suppresses `OutOfRange` error which indicates that all inputs have been 

1108 processed if the `SingularMonitoredSession` is used as a context. 

1109 

1110 @compatibility(TF2) 

1111 This API is not compatible with eager execution and `tf.function`. To migrate 

1112 to TF2, rewrite the code to be compatible with eager execution. Check the 

1113 [migration 

1114 guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls) 

1115 on replacing `Session.run` calls. In Keras, session hooks can be replaced by 

1116 Callbacks e.g. [logging hook notebook]( 

1117 https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb) 

1118 For more details please read [Better 

1119 performance with tf.function](https://www.tensorflow.org/guide/function). 

1120 @end_compatibility 

1121 """ 

1122 

1123 def __init__(self, 

1124 hooks=None, 

1125 scaffold=None, 

1126 master='', 

1127 config=None, 

1128 checkpoint_dir=None, 

1129 stop_grace_period_secs=120, 

1130 checkpoint_filename_with_path=None): 

1131 """Creates a SingularMonitoredSession. 

1132 

1133 Args: 

1134 hooks: An iterable of `SessionRunHook' objects. 

1135 scaffold: A `Scaffold` used for gathering or building supportive ops. If 

1136 not specified a default one is created. It's used to finalize the graph. 

1137 master: `String` representation of the TensorFlow master to use. 

1138 config: `ConfigProto` proto used to configure the session. 

1139 checkpoint_dir: A string. Optional path to a directory where to restore 

1140 variables. 

1141 stop_grace_period_secs: Number of seconds given to threads to stop after 

1142 `close()` has been called. 

1143 checkpoint_filename_with_path: A string. Optional path to a checkpoint 

1144 file from which to restore variables. 

1145 """ 

1146 session_creator = ChiefSessionCreator( 

1147 scaffold=scaffold, 

1148 master=master, 

1149 config=config, 

1150 checkpoint_dir=checkpoint_dir, 

1151 checkpoint_filename_with_path=checkpoint_filename_with_path) 

1152 super(SingularMonitoredSession, self).__init__( 

1153 session_creator, 

1154 hooks, 

1155 should_recover=False, 

1156 stop_grace_period_secs=stop_grace_period_secs) 

1157 

1158 def raw_session(self): 

1159 """Returns underlying `TensorFlow.Session` object.""" 

1160 return self._tf_sess() 

1161 

1162 

1163class _WrappedSession: 

1164 """Wrapper around a `tf.compat.v1.Session`. 

1165 

1166 This wrapper is used as a base class for various session wrappers 

1167 that provide additional functionality such as monitoring, coordination, 

1168 and recovery. 

1169 

1170 In addition to the methods exported by `SessionInterface` the wrapper 

1171 provides a method to check for stop and never raises exceptions from 

1172 calls to `close()`. 

1173 """ 

1174 

1175 def __init__(self, sess): 

1176 """Creates a `_WrappedSession`. 

1177 

1178 Args: 

1179 sess: A `tf.compat.v1.Session` or `_WrappedSession` object. The wrapped 

1180 session. 

1181 """ 

1182 self._sess = sess 

1183 self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession) 

1184 

1185 @property 

1186 def graph(self): 

1187 return self._sess.graph 

1188 

1189 @property 

1190 def sess_str(self): 

1191 return self._sess.sess_str 

1192 

1193 def should_stop(self): 

1194 """Return true if this session should not be used anymore. 

1195 

1196 Always return True if the session was closed. 

1197 

1198 Returns: 

1199 True if the session should stop, False otherwise. 

1200 """ 

1201 if self._check_stop(): 

1202 return True 

1203 if self._sess: 

1204 return self._wrapped_is_stoppable and self._sess.should_stop() 

1205 return True 

1206 

1207 def _check_stop(self): 

1208 """Hook for subclasses to provide their own stop condition. 

1209 

1210 Returns: 

1211 True if the session should stop, False otherwise. 

1212 """ 

1213 return False 

1214 

1215 def close(self): 

1216 if self._sess: 

1217 try: 

1218 self._sess.close() 

1219 except _PREEMPTION_ERRORS as e: 

1220 logging.error( 

1221 'An error occurred when attempting to close the ' 

1222 'session. This may be due to a preemption in a ' 

1223 'connected worker or parameter server. Error: %s', e) 

1224 finally: 

1225 self._sess = None 

1226 

1227 def run(self, *args, **kwargs): 

1228 return self._sess.run(*args, **kwargs) 

1229 

1230 def run_step_fn(self, step_fn, raw_session, run_with_hooks): 

1231 # `_RecoverableSession` sets `run_with_hooks` to `_CoordinatedSession.run`. 

1232 # It is `None` when called from `_CoordinatedSession`. In that case 

1233 # `self.run` is `_CoordinatedSession.run`. 

1234 run_with_hooks = run_with_hooks or self.run 

1235 return step_fn(_MonitoredSession.StepContext(raw_session, run_with_hooks)) 

1236 

1237 

1238class _RecoverableSession(_WrappedSession): 

1239 """A wrapped session that recreates a session upon certain kinds of errors. 

1240 

1241 The constructor is passed a SessionCreator object, not a session. 

1242 

1243 Calls to `run()` are delegated to the wrapped session. If a call raises the 

1244 exception `tf.errors.AbortedError` or `tf.errors.UnavailableError`, the 

1245 wrapped session is closed, and a new one is created by calling the factory 

1246 again. 

1247 """ 

1248 

1249 def __init__(self, sess_creator): 

1250 """Create a new `_RecoverableSession`. 

1251 

1252 The value returned by calling `sess_creator.create_session()` will be the 

1253 session wrapped by this recoverable session. 

1254 

1255 Args: 

1256 sess_creator: A 'SessionCreator' to be wrapped by recoverable. 

1257 """ 

1258 self._sess_creator = sess_creator 

1259 _WrappedSession.__init__(self, self._create_session()) 

1260 

1261 def _create_session(self): 

1262 while True: 

1263 try: 

1264 return self._sess_creator.create_session() 

1265 except _PREEMPTION_ERRORS as e: 

1266 logging.info( 

1267 'An error was raised while a session was being created. ' 

1268 'This may be due to a preemption of a connected worker ' 

1269 'or parameter server. A new session will be created. ' 

1270 'This error may also occur due to a gRPC failure caused ' 

1271 'by high memory or network bandwidth usage in the ' 

1272 'parameter servers. If this error occurs repeatedly, try ' 

1273 'increasing the number of parameter servers assigned to ' 

1274 'the job. Error: %s', e) 

1275 

1276 def _check_stop(self): 

1277 try: 

1278 if self._sess: 

1279 return self._sess._check_stop() # pylint: disable=protected-access 

1280 else: 

1281 return True 

1282 except _PREEMPTION_ERRORS as e: 

1283 logging.info( 

1284 'An error was raised while considering whether the ' 

1285 'session is complete. This may be due to a preemption in ' 

1286 'a connected worker or parameter server. The current ' 

1287 'session will be closed and a new session will be ' 

1288 'created. This error may also occur due to a gRPC failure ' 

1289 'caused by high memory or network bandwidth usage in the ' 

1290 'parameter servers. If this error occurs repeatedly, try ' 

1291 'increasing the number of parameter servers assigned to ' 

1292 'the job. Error: %s', e) 

1293 self.close() 

1294 self._sess = self._create_session() 

1295 # Since we have just recreated the session, the overall computation should 

1296 # not stop: 

1297 return False 

1298 except Exception: # pylint: disable=broad-except 

1299 # `should_stop` should return True instead of raising an exception. 

1300 return True 

1301 

1302 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 

1303 while True: 

1304 try: 

1305 if not self._sess: 

1306 self._sess = self._create_session() 

1307 return self._sess.run( 

1308 fetches, 

1309 feed_dict=feed_dict, 

1310 options=options, 

1311 run_metadata=run_metadata) 

1312 except _PREEMPTION_ERRORS as e: 

1313 logging.info( 

1314 'An error was raised. This may be due to a preemption in ' 

1315 'a connected worker or parameter server. The current ' 

1316 'session will be closed and a new session will be ' 

1317 'created. This error may also occur due to a gRPC failure ' 

1318 'caused by high memory or network bandwidth usage in the ' 

1319 'parameter servers. If this error occurs repeatedly, try ' 

1320 'increasing the number of parameter servers assigned to ' 

1321 'the job. Error: %s', e) 

1322 self.close() 

1323 self._sess = None 

1324 

1325 def run_step_fn(self, step_fn, raw_session, run_with_hooks): 

1326 while True: 

1327 try: 

1328 if not self._sess: 

1329 self._sess = self._create_session() 

1330 

1331 run_with_hooks = self._sess.run 

1332 return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks) 

1333 except _PREEMPTION_ERRORS as e: 

1334 logging.info( 

1335 'An error was raised. This may be due to a preemption in ' 

1336 'a connected worker or parameter server. The current ' 

1337 'session will be closed and a new session will be ' 

1338 'created. This error may also occur due to a gRPC failure ' 

1339 'caused by high memory or network bandwidth usage in the ' 

1340 'parameter servers. If this error occurs repeatedly, try ' 

1341 'increasing the number of parameter servers assigned to ' 

1342 'the job. Error: %s', e) 

1343 self.close() 

1344 self._sess = None 

1345 

1346 

1347class _CoordinatedSession(_WrappedSession): 

1348 """A wrapped session that works with a `tf.Coordinator`. 

1349 

1350 Calls to `run()` are delegated to the wrapped session. If a call 

1351 raises an exception, the exception is reported to the coordinator. 

1352 

1353 In addition, after each call to `run()` this session ask the coordinator if 

1354 the session should stop. In that case it will join all the threads 

1355 registered with the coordinator before returning. 

1356 

1357 If the coordinator was requested to stop with an exception, that exception 

1358 will be re-raised from the call to `run()`. 

1359 """ 

1360 

1361 def __init__(self, sess, coord, stop_grace_period_secs=120): 

1362 """Create a new `_CoordinatedSession`. 

1363 

1364 Args: 

1365 sess: A `tf.compat.v1.Session` object. The wrapped session. 

1366 coord: A `tf.train.Coordinator` object. 

1367 stop_grace_period_secs: Number of seconds given to threads to stop after 

1368 `close()` has been called. 

1369 """ 

1370 _WrappedSession.__init__(self, sess) 

1371 self._coord = coord 

1372 self._stop_grace_period_secs = stop_grace_period_secs 

1373 

1374 def _check_stop(self): 

1375 # If the coordinator was asked to stop due to an exception, then it needs 

1376 # to be propagated to this stack. 

1377 self._coord.raise_requested_exception() 

1378 # At this point, no exceptions are recorded in the coordinator. 

1379 return self._coord.should_stop() 

1380 

1381 def close(self): 

1382 self._coord.request_stop() 

1383 try: 

1384 self._coord.join( 

1385 stop_grace_period_secs=self._stop_grace_period_secs, 

1386 ignore_live_threads=True) 

1387 finally: 

1388 try: 

1389 _WrappedSession.close(self) 

1390 except Exception: # pylint: disable=broad-except 

1391 # We intentionally suppress exceptions from the close() here since 

1392 # useful exceptions are already reported by join(). 

1393 pass 

1394 

1395 def run(self, *args, **kwargs): 

1396 try: 

1397 return self._sess.run(*args, **kwargs) 

1398 except _PREEMPTION_ERRORS: 

1399 raise 

1400 except Exception as original_exception: # pylint: disable=broad-except 

1401 # A non-preemption error could have been caused by a preemption error 

1402 # in the coordinator. If this is the case, raise that exception instead, 

1403 # since it's the root cause. Otherwise, stick to the `original_exception`. 

1404 try: 

1405 self._coord.raise_requested_exception() 

1406 except _PREEMPTION_ERRORS: 

1407 raise 

1408 except Exception: # pylint: disable=broad-except 

1409 raise original_exception from None 

1410 else: 

1411 raise 

1412 

1413 

1414class _HookedSession(_WrappedSession): 

1415 """A _WrappedSession that calls hooks during calls to run(). 

1416 

1417 The list of hooks to call is passed in the constructor. Before each call 

1418 to `run()` the session calls the `before_run()` method of the hooks, which 

1419 can return additional ops or tensors to run. These are added to the arguments 

1420 of the call to `run()`. 

1421 

1422 When the `run()` call finishes, the session calls the `after_run()` methods of 

1423 the hooks, passing the values returned by the `run()` call corresponding to 

1424 the ops and tensors that each hook requested. 

1425 

1426 If any call to the hooks, requests stop via run_context the session will be 

1427 marked as needing to stop and its `should_stop()` method will now return 

1428 `True`. 

1429 """ 

1430 

1431 def __init__(self, sess, hooks): 

1432 """Initializes a _HookedSession object. 

1433 

1434 Args: 

1435 sess: A `tf.compat.v1.Session` or a `_WrappedSession` object. 

1436 hooks: An iterable of `SessionRunHook' objects. 

1437 """ 

1438 

1439 _WrappedSession.__init__(self, sess) 

1440 self._hooks = hooks 

1441 self._should_stop = False 

1442 

1443 def _check_stop(self): 

1444 """See base class.""" 

1445 return self._should_stop 

1446 

1447 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 

1448 """See base class.""" 

1449 if self.should_stop(): 

1450 raise RuntimeError('Run called even after should_stop requested.') 

1451 

1452 actual_fetches = {'caller': fetches} 

1453 

1454 run_context = session_run_hook.SessionRunContext( 

1455 original_args=session_run_hook.SessionRunArgs(fetches, feed_dict), 

1456 session=self._sess) 

1457 

1458 options = options or config_pb2.RunOptions() 

1459 feed_dict = self._call_hook_before_run(run_context, actual_fetches, 

1460 feed_dict, options) 

1461 

1462 # Do session run. 

1463 run_metadata = run_metadata or config_pb2.RunMetadata() 

1464 outputs = _WrappedSession.run( 

1465 self, 

1466 fetches=actual_fetches, 

1467 feed_dict=feed_dict, 

1468 options=options, 

1469 run_metadata=run_metadata) 

1470 

1471 for hook in self._hooks: 

1472 hook.after_run( 

1473 run_context, 

1474 session_run_hook.SessionRunValues( 

1475 results=outputs[hook] if hook in outputs else None, 

1476 options=options, 

1477 run_metadata=run_metadata)) 

1478 self._should_stop = self._should_stop or run_context.stop_requested 

1479 

1480 return outputs['caller'] 

1481 

1482 def _call_hook_before_run(self, run_context, fetch_dict, user_feed_dict, 

1483 options): 

1484 """Calls hooks.before_run and handles requests from hooks.""" 

1485 hook_feeds = {} 

1486 for hook in self._hooks: 

1487 request = hook.before_run(run_context) 

1488 if request is not None: 

1489 if request.fetches is not None: 

1490 fetch_dict[hook] = request.fetches 

1491 if request.feed_dict: 

1492 self._raise_if_feeds_intersects(hook_feeds, request.feed_dict, 

1493 'Same tensor is fed by two hooks.') 

1494 hook_feeds.update(request.feed_dict) 

1495 if request.options: 

1496 self._merge_run_options(options, request.options) 

1497 

1498 if not hook_feeds: 

1499 return user_feed_dict 

1500 

1501 if not user_feed_dict: 

1502 return hook_feeds 

1503 

1504 self._raise_if_feeds_intersects( 

1505 user_feed_dict, hook_feeds, 

1506 'Same tensor is fed by a SessionRunHook and user.') 

1507 hook_feeds.update(user_feed_dict) 

1508 return hook_feeds 

1509 

1510 def _raise_if_feeds_intersects(self, feeds1, feeds2, message): 

1511 intersection = set(feeds1.keys()) & set(feeds2.keys()) 

1512 if intersection: 

1513 raise RuntimeError(message + ' Conflict(s): ' + str(list(intersection))) 

1514 

1515 def _merge_run_options(self, options, incoming_options): 

1516 """Merge two instances of RunOptions into the first one. 

1517 

1518 During the merger, the numerical fields including trace_level, 

1519 timeout_in_ms, inter_op_thread_pool are set to the larger one of the two. 

1520 The boolean value is set to the logical OR of the two. 

1521 debug_tensor_watch_opts of the original options is extended with that from 

1522 the incoming one. 

1523 

1524 Args: 

1525 options: The options to merge into. 

1526 incoming_options: The options to be merged into the first argument. 

1527 """ 

1528 options.trace_level = max(options.trace_level, incoming_options.trace_level) 

1529 options.timeout_in_ms = max(options.timeout_in_ms, 

1530 incoming_options.timeout_in_ms) 

1531 options.inter_op_thread_pool = max(options.inter_op_thread_pool, 

1532 incoming_options.inter_op_thread_pool) 

1533 options.output_partition_graphs = max( 

1534 options.output_partition_graphs, 

1535 incoming_options.output_partition_graphs) 

1536 options.debug_options.debug_tensor_watch_opts.extend( 

1537 incoming_options.debug_options.debug_tensor_watch_opts) 

1538 options.debug_options.reset_disk_byte_usage = ( 

1539 options.debug_options.reset_disk_byte_usage or 

1540 incoming_options.debug_options.reset_disk_byte_usage) 

1541 options.report_tensor_allocations_upon_oom = ( 

1542 options.report_tensor_allocations_upon_oom or 

1543 incoming_options.report_tensor_allocations_upon_oom)