Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/session_manager.py: 18%

152 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Training helper that checkpoints models and creates session.""" 

16import time 

17 

18import numpy as np 

19from tensorflow.python.checkpoint import checkpoint_management 

20from tensorflow.python.client import session 

21from tensorflow.python.distribute import distribute_lib 

22from tensorflow.python.framework import errors 

23from tensorflow.python.framework import ops 

24from tensorflow.python.platform import tf_logging as logging 

25from tensorflow.python.util.tf_export import tf_export 

26 

27 

def _maybe_name(obj):
  """Returns a printable name for `obj`, for use in error messages.

  Args:
    obj: Object to get the name of. May be `None`.

  Returns:
    The string "None" when `obj` is `None`, `obj.name` when the object has a
    `name` attribute, and a "<no name for ...>" placeholder mentioning the
    object's type otherwise.
  """
  if obj is None:
    return "None"
  if hasattr(obj, "name"):
    return obj.name
  return "<no name for %s>" % type(obj)

43 

44 

def _restore_checkpoint_and_maybe_run_saved_model_initializers(
    sess, saver, path):
  """Runs any SavedModel initializers, then restores the checkpoint at `path`.

  NOTE: "SavedModel" here means SavedModels loaded through the load_v2 API
  (which does not require the `sess` argument).

  Resources loaded from a SavedModel are not restored by `saver.restore`, so
  the ops in the "saved_model_initializers" collection have to be run as well.
  That collection is part of the MetaGraph's default_init_op, so it is already
  called by MonitoredSession as long as the saver doesn't restore any
  checkpoints from the working dir.

  Args:
    sess: Session used to run the initializers and the restore.
    saver: Object whose `restore(sess, path)` loads the checkpoint values.
    path: Checkpoint path handed to `saver.restore`.
  """
  initializers = ops.get_collection("saved_model_initializers")
  if initializers:
    sess.run(initializers)

  # Order matters: the SavedModel init restores variables from the SavedModel
  # variables directory, so the saver has to run afterwards for the checkpoint
  # values to win. Initializing/restoring twice is not ideal, but there is no
  # other way to do it.
  saver.restore(sess, path)

67 

68 

@tf_export(v1=["train.SessionManager"])
class SessionManager:
  """Training helper that restores from checkpoint and creates session.

  This class is a small wrapper that takes care of session creation and
  checkpoint recovery. It also provides functions to facilitate
  coordination among multiple training threads or processes.

  * Checkpointing trained variables as the training progresses.
  * Initializing variables on startup, restoring them from the most recent
    checkpoint after a crash, or wait for checkpoints to become available.

  ### Usage:

  ```python
  with tf.Graph().as_default():
     ...add operations to the graph...
    # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `prepare_session()` initializes or restores a model. It requires `init_op`
  and `saver` as an argument.

  A second process could wait for the model to be ready by doing the following:

  ```python
  with tf.Graph().as_default():
     ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `wait_for_session()` waits for a model to be initialized by other processes.

  """

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None,
               local_init_feed_dict=None):
    """Creates a SessionManager.

    The `local_init_op` is an `Operation` that is run always after a new session
    was created. If `None`, this step is skipped.

    The `ready_op` is an `Operation` used to check if the model is ready.  The
    model is considered ready if that operation returns an empty 1D string
    tensor. If the operation returns a non empty 1D string tensor, the elements
    are concatenated and used to indicate to the user why the model is not
    ready.

    The `ready_for_local_init_op` is an `Operation` used to check if the model
    is ready to run local_init_op.  The model is considered ready if that
    operation returns an empty 1D string tensor. If the operation returns a non
    empty 1D string tensor, the elements are concatenated and used to indicate
    to the user why the model is not ready.

    If `ready_op` is `None`, the model is not checked for readiness.

    `recovery_wait_secs` is the number of seconds between checks that
    the model is ready.  It is used by processes to wait for a model to
    be initialized or restored.  Defaults to 30 seconds.

    Args:
      local_init_op: An `Operation` run immediately after session creation.
         Usually used to initialize tables and local variables.
      ready_op: An `Operation` to check if the model is initialized.
      ready_for_local_init_op: An `Operation` to check if the model is ready
         to run local_init_op.
      graph: The `Graph` that the model will use. Defaults to the current
         default graph.
      recovery_wait_secs: Seconds between checks for the model to be ready.
      local_init_run_options: RunOptions to be passed to session.run when
        executing the local_init_op.
      local_init_feed_dict: Optional session feed dictionary to use when running
        the local_init_op.

    Raises:
      ValueError: If ready_for_local_init_op is not None but local_init_op is
        None
    """
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    self._target = None
    self._local_init_run_options = local_init_run_options
    self._local_init_feed_dict = local_init_feed_dict
    if ready_for_local_init_op is not None and local_init_op is None:
      # The 1-tuple keeps the formatting correct even if the argument itself
      # happens to be a tuple.
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       (ready_for_local_init_op,))

  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          checkpoint_filename_with_path=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.

    The mutually exclusive checkpoint arguments are validated before any
    session is created, so an invalid call does not leave an open session
    behind.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    self._target = master

    # Reject contradictory arguments up front: doing this after the session
    # has been created would leak that session when the error is raised.
    if checkpoint_dir and checkpoint_filename_with_path:
      raise ValueError("Can not provide both checkpoint_dir and "
                       "checkpoint_filename_with_path.")

    # This is required so that we initialize the TPU device before
    # restoring from checkpoint since we'll be placing variables on the device
    # and TPUInitialize wipes out the memory of the device.
    strategy = distribute_lib.get_strategy()
    if strategy and hasattr(strategy.extended,
                            "_experimental_initialize_system"):
      strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

    sess = session.Session(self._target, graph=self._graph, config=config)

    # If either saver or checkpoint_* is not specified, cannot restore. Just
    # return.
    if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
      return sess, False

    if checkpoint_filename_with_path:
      _restore_checkpoint_and_maybe_run_saved_model_initializers(
          sess, saver, checkpoint_filename_with_path)
      return sess, True

    # Waits up until max_wait_secs for checkpoint to become available.
    wait_time = 0
    ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if wait_for_checkpoint and wait_time < max_wait_secs:
        logging.info("Waiting for checkpoint to be available.")
        time.sleep(self._recovery_wait_secs)
        wait_time += self._recovery_wait_secs
        ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
      else:
        return sess, False

    # Loads the checkpoint.
    _restore_checkpoint_and_maybe_run_saved_model_initializers(
        sess, saver, ckpt.model_checkpoint_path)
    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    return sess, True

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """Creates a `Session`. Makes sure the model is ready to be used.

    Creates a `Session` on 'master'. If a `saver` object is passed in, and
    `checkpoint_dir` points to a directory containing valid checkpoint
    files, then it will try to recover the model from checkpoint. If
    no checkpoint files are available, and `wait_for_checkpoint` is
    `True`, then the process would check every `recovery_wait_secs`,
    up to `max_wait_secs`, for recovery to succeed.

    If the model cannot be recovered successfully then it is initialized by
    running the `init_op` and calling `init_fn` if they are provided.
    The `local_init_op` is also run after init_op and init_fn, regardless of
    whether the model was recovered successfully, but only if
    `ready_for_local_init_op` passes.

    If the model is recovered from a checkpoint it is assumed that all
    global variables have been initialized, in particular neither `init_op`
    nor `init_fn` will be executed.

    It is an error if the model cannot be recovered and no `init_op`
    or `init_fn` or `local_init_op` are passed.

    Args:
      master: `String` representation of the TensorFlow master to use.
      init_op: Optional `Operation` used to initialize the model.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.
      init_feed_dict: Optional dictionary that maps `Tensor` objects to feed
        values.  This feed dictionary is passed to the session `run()` call when
        running the init op.
      init_fn: Optional callable used to initialize the model. Called after the
        optional `init_op` is called.  The callable must accept one argument,
        the session being initialized.

    Returns:
      A `Session` object that can be used to drive the model.

    Raises:
      RuntimeError: If the model cannot be initialized or recovered.
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    if not is_loaded_from_checkpoint:
      # Nothing was restored, so the caller-provided initializers must do the
      # whole job.
      if init_op is None and not init_fn and self._local_init_op is None:
        raise RuntimeError("Model is not initialized and no init_op or "
                           "init_fn or local_init_op was given")
      if init_op is not None:
        sess.run(init_op, feed_dict=init_feed_dict)
      if init_fn:
        init_fn(sess)

    local_init_success, msg = self._try_run_local_init_op(sess)
    if not local_init_success:
      raise RuntimeError(
          "Init operations did not make model ready for local_init.  "
          "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op),
                                                   init_fn,
                                                   msg))

    is_ready, msg = self._model_ready(sess)
    if not is_ready:
      raise RuntimeError(
          "Init operations did not make model ready.  "
          "Init op: %s, init fn: %s, local_init_op: %s, error: %s" %
          (_maybe_name(init_op), init_fn, self._local_init_op, msg))
    return sess

  def recover_session(self,
                      master,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None):
    """Creates a `Session`, recovering if possible.

    Creates a new session on 'master'.  If the session is not initialized
    and can be recovered from a checkpoint, recover it.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, initialized) where 'initialized' is `True` if
      the session could be recovered and initialized, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)

    # Always try to run local_init_op
    local_init_success, msg = self._try_run_local_init_op(sess)

    if not is_loaded_from_checkpoint:
      # Do not need to run checks for readiness
      return sess, False

    restoring_file = checkpoint_dir or checkpoint_filename_with_path
    if not local_init_success:
      logging.info(
          "Restoring model from %s did not make model ready for local init:"
          " %s", restoring_file, msg)
      return sess, False

    is_ready, msg = self._model_ready(sess)
    if not is_ready:
      logging.info("Restoring model from %s did not make model ready: %s",
                   restoring_file, msg)
      return sess, False

    logging.info("Restored model from %s", restoring_file)
    return sess, is_loaded_from_checkpoint

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready.

    Creates a new `Session` on 'master'.  Waits for the model to be
    initialized or recovered from a checkpoint.  It's expected that
    another thread or process will make the model ready, and that this
    is intended to be used by threads/processes that participate in a
    distributed training configuration where a different thread/process
    is responsible for initializing or recovering the model being trained.

    NB: The amount of time this method waits for the session is bounded
    by max_wait_secs. By default, this function will wait indefinitely.

    Args:
      master: `String` representation of the TensorFlow master to use.
      config: Optional ConfigProto proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
        `None` is treated as no limit.

    Returns:
      A `Session` for which the local-init step succeeded and the readiness
      check passed.

    Raises:
      tf.DeadlineExceededError: if the session is not available after
        max_wait_secs.
    """
    self._target = master

    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)

    while True:
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      local_init_success, not_ready_local_msg = self._try_run_local_init_op(
          sess)
      if local_init_success:
        # Successful if local_init_op is None, or ready_for_local_init_op passes
        is_ready, not_ready_msg = self._model_ready(sess)
        if is_ready:
          return sess

      # Not ready yet: drop this session and retry with a fresh one.
      self._safe_close(sess)

      # Do we have enough time left to try again?
      remaining_secs_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_secs_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))

      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)

  def _safe_close(self, sess):
    """Closes a session without raising an exception.

    Just like sess.close() but ignores exceptions.

    Args:
      sess: A `Session`.
    """
    # pylint: disable=broad-except
    try:
      sess.close()
    except Exception:
      # Intentionally not logging to avoid user complaints that
      # they get cryptic errors.  We really do not care that Close
      # fails.
      pass
    # pylint: enable=broad-except

  def _model_ready(self, sess):
    """Checks if the model is ready or not.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready and False
      otherwise, and msg is `None` if the model is ready, a `String` with the
      reason why it is not ready otherwise.
    """
    return _ready(self._ready_op, sess, "Model not ready")

  def _model_ready_for_local_init(self, sess):
    """Checks if the model is ready to run local_init_op.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready to run
      local_init_op and False otherwise, and msg is `None` if the model is
      ready to run local_init_op, a `String` with the reason why it is not ready
      otherwise.
    """
    return _ready(self._ready_for_local_init_op, sess,
                  "Model not ready for local init")

  def _try_run_local_init_op(self, sess):
    """Tries to run _local_init_op, if not None, and is ready for local init.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_successful, msg), where is_successful is True if
      _local_init_op is None, or we ran _local_init_op, and False otherwise;
      and msg is a `String` with the reason why the model was not ready to run
      local init.
    """
    if self._local_init_op is None:
      # Nothing to run counts as success.
      return True, None

    is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
    if not is_ready_for_local_init:
      return False, msg

    logging.info("Running local_init_op.")
    sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict,
             options=self._local_init_run_options)
    logging.info("Done running local_init_op.")
    return True, None

534 

535 

def _ready(op, sess, msg):
  """Checks if the model is ready or not, as determined by op.

  The model is considered ready if `op` is `None`, or if running it yields
  `None`, an int32 array (kept for backward compatibility), or an empty array.
  Any other result is reported as the list of not-yet-initialized variables.
  The result is flattened before being rendered, and elements that are not
  `bytes` are rendered with `str()`, so a custom ready op returning something
  other than a 1-D string tensor still produces a message instead of an
  unrelated exception.

  Args:
    op: An op, either _ready_op or _ready_for_local_init_op, which defines the
      readiness of the model.
    sess: A `Session`.
    msg: A message to log to warning if not ready

  Returns:
    A tuple (is_ready, msg), where is_ready is True if ready and False
    otherwise, and msg is `None` if the model is ready, a `String` with the
    reason why it is not ready otherwise.

  Raises:
    errors.FailedPreconditionError: If running `op` fails with a precondition
      error whose message does not mention "uninitialized".
  """
  if op is None:
    return True, None

  try:
    ready_value = sess.run(op)
  except errors.FailedPreconditionError as e:
    if "uninitialized" not in str(e):
      logging.warning("%s : error [%s]", msg, str(e))
      raise
    # An "uninitialized" precondition failure simply means "not ready yet".
    return False, str(e)

  # The model is considered ready if ready_op returns an empty 1-D tensor.
  # Also compare to `None` and dtype being int32 for backward
  # compatibility.
  if (ready_value is None or ready_value.dtype == np.int32 or
      ready_value.size == 0):
    return True, None

  # TODO(sherrym): If a custom ready_op returns other types of tensor,
  # or strings other than variable names, this message could be
  # confusing.
  non_initialized_varnames = ", ".join(
      item.decode("utf-8", errors="replace") if isinstance(item, bytes)
      else str(item)
      for item in np.ravel(ready_value))
  return False, "Variables not initialized: " + non_initialized_varnames

573 

574 

class _CountDownTimer:
  """A timer that tracks a duration since creation.

  Elapsed time is measured with the monotonic clock, so adjustments to the
  system wall clock cannot shorten or extend the countdown.
  """

  __slots__ = ["_start_time_secs", "_duration_secs"]

  def __init__(self, duration_secs):
    """Starts the countdown.

    Args:
      duration_secs: Length of the countdown in seconds. May be
        `float("inf")` for a countdown that never expires.
    """
    self._start_time_secs = time.monotonic()
    self._duration_secs = duration_secs

  def secs_remaining(self):
    """Returns the seconds left in the countdown, never less than 0."""
    elapsed_secs = time.monotonic() - self._start_time_secs
    return max(0, self._duration_secs - elapsed_secs)