Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/wrappers/framework.py: 30%

274 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Framework of debug wrapper sessions. 

16 

17A debug wrapper session is a wrapper around a TensorFlow Python Session. 

18The wrapper preserves the Session interface, most importantly the run() method, 

19while providing abilities to: 

20a) Intercept a run() call to a wrapped session and insert debug tensor watches 

21 according to externally-specified debug URLs. 

22 

23b) Release control to an external (i.e., non-Session) object before and after 

24 the run() call, so that the external object can perform actions such as 

25 launching a UI to let users inspect the intermediate tensors and partition 

26 graphs from the run() call. 

27 

28c) (To be implemented in a future CL) Enter an instruction loop to let an 

29 external object (e.g., remote client) launch run() and cont() calls 

30 remotely. 

31 

32*** The lifetime of a debug wrapper session: *** 

33 

341) The wrapper session is created by calling the constructor with a 

35 wrapped (normal) session as the argument: 

36 wrapper = FooDebugWrapperSession(sess) 

37 wherein FooDebugWrapperSession is a concrete subclass implementing the 

38 abstract BaseDebugWrapperSession class below. 

39 

402) Near the end of the constructor call, the on_session_init() callback is 

41 invoked, with a OnSessionInitRequest object as the argument. The object 

42 carries the wrapped (normal) session object. 

43 

443) The callback handles the request and returns a OnSessionInitResponse 

45 object with an action field, directing the wrapper session what to do next. 

46 

47If the action field in the OnSessionInitResponse is PROCEED, the constructor 

48returns. Control is released back to the caller of the constructor, which can 

49invoke run() method of wrapper session with the same syntax as a non-wrapped 

50session, e.g.,: 

51 wrapper.run(fetches, feed_dict=feeds, options=run_options) 

52 

53Below, A1 - A2 is the lifetime of a wrapper run() call if the action is 

54PROCEED: 

55 

56A1) Right at the start of each run() call, the on_run_start() callback is 

57 invoked, with an OnRunStartRequest object carrying information such as 

58 the fetches, the feed dict, the run options and run metadata used in 

59 this run call, along with a count of how many run calls have occurred 

60 on this wrapper session. The callback then returns an OnRunStartResponse 

61 object, of which the action field directs what the wrapper session 

62 will actually do for the run() call. 

63 

64 If the action is DEBUG_RUN, a debugged (tensor-watched) run will ensue, 

65 with the debug URLs supplied in the debug_urls field of the response. 

66 These can be file:// or grpc:// URLs, for example. 

67 

68 If the action is NON_DEBUG_RUN, a non-debug (normal) run will ensue. 

69 

70A2) Right before the run() returns, the on_run_end() callback is invoked, 

71 with an OnRunEndRequest object as the argument, which carries information 

72 including the actual action performed in the wrapper run() call and the 

73 run_metadata from the run() call. 

74 

75However, if the action field in OnSessionInitResponse is 

76REMOTE_INSTR_LOOP, the constructor will automatically invoke an instruction loop 

77that gives the control to a remote caller. 

78 

79In the remote instruction loop, the following steps will happen: 

80 

81B1) Callback on_instr_start() is invoked. The callback will return an 

82 OnInstrStartResponse object with an action field which can order one of 

83 the following actions: 

84 i) a run() call with fetches, feeds and debug_urls specified. 

85 ii) exit the instruction loop. 

86 

87B2) The wrapper session carries out the action specified above. 

88 

89B3) If still in the instruction loop, the wrapper session invokes the 

90 on_instr_end() callback. After the on_instr_end() callback returns, jump 

91 back to B1. 

92 

93TODO(cais): Implement the instruction loop in B1 - B3. 

94 

95""" 

96 

97import abc 

98import re 

99import threading 

100 

101from tensorflow.core.protobuf import config_pb2 

102from tensorflow.python.client import session 

103from tensorflow.python.debug.lib import debug_utils 

104from tensorflow.python.framework import errors 

105from tensorflow.python.framework import stack 

106from tensorflow.python.platform import tf_logging 

107from tensorflow.python.training import monitored_session 

108from tensorflow.python.util import nest 

109from tensorflow.python.util.compat import collections_abc 

110 

111 

112# Helper function. 

113def _check_type(obj, expected_types): 

114 """Check if an object is of the expected type. 

115 

116 Args: 

117 obj: The object being checked. 

118 expected_types: (`type` or an iterable of `type`s) The expected `type`(s) 

119 of obj. 

120 

121 Raises: 

122 TypeError: If obj is not an instance of expected_type. 

123 """ 

124 if not isinstance(obj, expected_types): 

125 raise TypeError("Expected type %s; got type %s" % 

126 (expected_types, type(obj))) 

127 

128 

class OnSessionInitRequest:
  """Payload handed to an on-session-init callback.

  The callback that receives this object is invoked from the __init__ call of
  a debug-wrapper session.
  """

  def __init__(self, sess):
    """Creates the request.

    Args:
      sess: The TensorFlow session object carried by this request. It must be
        an instance of `session.BaseSession` or
        `monitored_session.MonitoredSession`.

    Raises:
      TypeError: If `sess` is of neither accepted type.
    """
    accepted_session_types = (
        session.BaseSession, monitored_session.MonitoredSession)
    _check_type(sess, accepted_session_types)
    self.session = sess

144 

145 

class OnSessionInitAction:
  """Enum-like string constants naming what to do on session init."""

  # Finish the wrapper session initialization without any special action.
  # Whatever happens next is up to the caller of the wrapper session, which
  # can, for instance, call run().
  PROCEED = "proceed"

  # Hand control to a remote client: rather than letting the caller of the
  # wrapper session decide the next actions, enter a loop that receives
  # instructions remotely. The TensorBoard visual debugger, for example, can
  # use this action to launch session.run() calls remotely.
  REMOTE_INSTR_LOOP = "remote_instr_loop"

160 

161 

class OnSessionInitResponse:
  """Value returned by an on-session-init callback."""

  def __init__(self, action):
    """Creates the response.

    Args:
      action: (`OnSessionInitAction`) The debugger action to take on session
        init, given as one of the string constants of that class.

    Raises:
      TypeError: If `action` is not a `str`.
    """
    _check_type(action, str)
    self.action = action

173 

174 

class OnRunStartRequest:
  """Payload handed to an on-run-start callback.

  The callback is invoked inside a run() call of the debug-wrapper session,
  right after the run() call counter has been incremented.
  """

  def __init__(self, fetches, feed_dict, run_options, run_metadata,
               run_call_count, is_callable_runner=False):
    """Creates an `OnRunStartRequest`.

    The first four arguments are identical to the input arguments of the
    run() method of a non-wrapped TensorFlow session.

    Args:
      fetches: Fetch targets of the run() call.
      feed_dict: The feed dictionary to the run() call.
      run_options: RunOptions input to the run() call.
      run_metadata: RunMetadata input to the run() call.
      run_call_count: 1-based count of how many run calls (including this
        one) have been invoked.
      is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.
    """
    # Inputs of the wrapped run() call, stored verbatim.
    self.fetches, self.feed_dict = fetches, feed_dict
    self.run_options, self.run_metadata = run_options, run_metadata

    # Bookkeeping supplied by the wrapper session.
    self.run_call_count = run_call_count
    self.is_callable_runner = is_callable_runner

204 

205 

class OnRunStartAction:
  """Enum-like string constants naming what to do when a run() call starts."""

  # A single run with debug tensor-watching enabled.
  DEBUG_RUN = "debug_run"

  # A single run under the profiler.
  PROFILE_RUN = "profile_run"

  # A run with no debug tensor-watching.
  NON_DEBUG_RUN = "non_debug_run"

217 

218 

219 

class OnRunStartResponse:
  """Response from an on-run-start callback.

  Through this object the callback specifies which action the debug-wrapper
  session actually takes for the run() call.
  """

  def __init__(self,
               action,
               debug_urls,
               debug_ops="DebugIdentity",
               node_name_regex_allowlist=None,
               op_type_regex_allowlist=None,
               tensor_dtype_regex_allowlist=None,
               tolerate_debug_op_creation_failures=False):
    """Creates an `OnRunStartResponse`.

    Args:
      action: (`OnRunStartAction`) the action actually taken by the wrapped
        session for the run() call.
      debug_urls: (`list` of `str`) debug_urls used in watching the tensors
        during the run() call.
      debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the
        debugger.
      node_name_regex_allowlist: Regular-expression allowlist for node
        name.
      op_type_regex_allowlist: Regular-expression allowlist for op type.
      tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
        dtype.
      tolerate_debug_op_creation_failures: Whether debug op creation failures
        are to be tolerated.

    Raises:
      TypeError: If `action` is not a `str` or `debug_urls` is not a `list`.
    """
    # Only these two arguments are type-checked; each is validated
    # immediately before it is stored.
    _check_type(action, str)
    self.action = action
    _check_type(debug_urls, list)
    self.debug_urls = debug_urls

    # The remaining watch settings are stored as given.
    self.debug_ops = debug_ops
    self.node_name_regex_allowlist = node_name_regex_allowlist
    self.op_type_regex_allowlist = op_type_regex_allowlist
    self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)

266 

267 

class OnRunEndRequest:
  """Payload handed to an on-run-end callback.

  The callback is invoked immediately before the wrapped run() call ends.
  """

  def __init__(self,
               performed_action,
               run_metadata=None,
               client_graph_def=None,
               tf_error=None):
    """Creates an `OnRunEndRequest`.

    Args:
      performed_action: (`OnRunStartAction`) Actually-performed action by the
        debug-wrapper session.
      run_metadata: run_metadata output from the run() call (if any).
      client_graph_def: (GraphDef) GraphDef from the client side, i.e., from
        the python front end of TensorFlow. Can be obtained with
        session.graph.as_graph_def().
      tf_error: (errors.OpError subtypes) TensorFlow OpError that occurred
        during the run (if any).

    Raises:
      TypeError: If `performed_action` is not a `str`, or if `run_metadata`
        is given but is not a `config_pb2.RunMetadata`.
    """
    _check_type(performed_action, str)
    self.performed_action = performed_action

    # run_metadata is optional; it is only type-checked when supplied.
    metadata_supplied = run_metadata is not None
    if metadata_supplied:
      _check_type(run_metadata, config_pb2.RunMetadata)
    self.run_metadata = run_metadata

    self.client_graph_def = client_graph_def
    self.tf_error = tf_error

300 

301 

class OnRunEndResponse:
  """Value returned by an on-run-end callback."""

  def __init__(self):
    """Creates the response, which currently is only a placeholder."""

309 

310 

class BaseDebugWrapperSession(session.SessionInterface, metaclass=abc.ABCMeta):
  """Base class of debug-wrapper session classes.

  Concrete classes that inherit from this class need to implement the abstract
  methods on_session_init, on_run_start and on_run_end.
  """

  def __init__(self, sess, thread_name_filter=None,
               pass_through_operrors=False):
    """Constructor of `BaseDebugWrapperSession`.

    Args:
      sess: An (unwrapped) TensorFlow session instance. It should be a subtype
        of `BaseSession` or `tf.MonitoredSession`.
      thread_name_filter: Regular-expression filter (allowlist) for name(s) of
        thread(s) on which the wrapper session will be active. This regular
        expression is used in a start-anchored fashion on the thread name, i.e.,
        by applying the `match` method of the compiled pattern. The default
        `None` means that the wrapper session will be active on all threads.
        E.g., r"MainThread$", r"QueueRunnerThread.*".
      pass_through_operrors: If True, an OpError raised during a debugged run
        is re-raised to the caller. By default (False) the error is captured:
        it is returned in place of the run's return values and is handed to
        the on-run-end callback.

    Raises:
      TypeError: If `sess` is not of an accepted session type, or if the
        on-session-init callback does not return an `OnSessionInitResponse`.
      ValueError: On invalid `OnSessionInitAction` value.
      NotImplementedError: If the on-session-init callback requests the
        `REMOTE_INSTR_LOOP` action, which has not been implemented.
    """

    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))

    # The session being wrapped.
    self._sess = sess
    self._thread_name_filter_pattern = (re.compile(thread_name_filter)
                                        if thread_name_filter else None)
    # TODO(cais/kstevens): Unittest this pass through feature.
    self._pass_through_operrors = pass_through_operrors

    # Keeps track of number of run calls that have been performed on this
    # debug-wrapper session. The count can be used for purposes such as
    # displaying the state of the Session in a UI and determining a run
    # number-dependent debug URL.
    self._run_call_count = 0

    # Invoke on-session-init callback.
    response = self.on_session_init(OnSessionInitRequest(self._sess))
    _check_type(response, OnSessionInitResponse)

    if response.action == OnSessionInitAction.PROCEED:
      pass
    elif response.action == OnSessionInitAction.REMOTE_INSTR_LOOP:
      # TODO(cais): Implement REMOTE_INSTR_LOOP
      raise NotImplementedError(
          "OnSessionInitAction REMOTE_INSTR_LOOP has not been "
          "implemented.")
    else:
      raise ValueError(
          "Invalid OnSessionInitAction value: %s" % response.action)

    # Created lazily by __enter__ and reused by __exit__.
    self._default_session_context_manager = None

    # A cache for callables created from CallableOptions, keyed by the id()
    # of the CallableOptions object. Only debugged runs populate it.
    self._cached_callables_from_options = {}

  @property
  def graph(self):
    """The graph of the wrapped session."""
    return self._sess.graph

  @property
  def graph_def(self):
    """The `graph_def` attribute of the wrapped session."""
    return self._sess.graph_def

  @property
  def sess_str(self):
    """The `sess_str` attribute of the wrapped session."""
    return self._sess.sess_str

  @property
  def session(self):
    """The wrapped session object itself."""
    return self._sess

  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None,
          callable_options=None):
    """Wrapper around Session.run() that inserts tensor watch options.

    Args:
      fetches: Same as the `fetches` arg to regular `Session.run()`.
      feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
      options: Same as the `options` arg to regular `Session.run()`.
      run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
      callable_runner: A `callable` returned by `Session.make_callable()`.
        If not `None`, `fetches` and `feed_dict` must both be `None`.
        Mutually exclusive with `callable_options`.
      callable_runner_args: An optional list of arguments to `callable_runner`
        or for `callable_options`. `None` is treated as no arguments.
      callable_options: An instance of `config_pb2.CallableOptions`, to be
        used with `Session._make_callable_from_options()`. Mutually exclusive
        with `callable_runner`.

    Returns:
      Simply forwards the output of the wrapped `Session.run()` call. For a
      debugged run in which an OpError was captured (see
      `pass_through_operrors`), the OpError itself is returned.

    Raises:
      ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
        and `callable_options` are both specified. Or if either of them is
        specified together with a non-empty `fetches` or `feed_dict`.
      TypeError: If the on-run-start or on-run-end callback returns an object
        of the wrong type.
    """
    if callable_runner and callable_options:
      raise ValueError(
          "callable_runner and callable_options are mutually exclusive, but "
          "are both specified in this call to BaseDebugWrapperSession.run().")

    if callable_runner and (fetches or feed_dict):
      raise ValueError(
          "callable_runner and fetches/feed_dict are mutually exclusive, "
          "but are used simultaneously.")
    elif callable_options and (fetches or feed_dict):
      raise ValueError(
          "callable_options and fetches/feed_dict are mutually exclusive, "
          "but are used simultaneously.")

    # The argument list is documented as optional; without this, unpacking
    # `*None` below would raise a TypeError unrelated to the caller's intent.
    if callable_runner_args is None:
      callable_runner_args = ()

    self.increment_run_call_count()

    def is_empty(x):
      """Check whether a possibly nested structure is empty."""
      if not nest.is_nested(x):
        # A leaf (non-nested) value counts as content.
        return False
      if isinstance(x, collections_abc.Mapping):
        return is_empty(list(x.values()))
      for item in x:
        if not is_empty(item):
          return False
      return True

    empty_fetches = is_empty(fetches)
    if empty_fetches:
      tf_logging.info(
          "Due to empty fetches, tfdbg Session wrapper is letting a "
          "Session.run pass through without any debugging actions.")
    if self._is_disabled_thread() or empty_fetches:
      # Pass-through: no callbacks are invoked and nothing is decorated.
      if callable_runner:
        return callable_runner(*callable_runner_args)
      elif callable_options:
        # pylint:disable=protected-access
        return self._sess._make_callable_from_options(
            callable_options)(*callable_runner_args)
        # pylint:enable=protected-access
      else:
        return self._sess.run(fetches,
                              feed_dict=feed_dict,
                              options=options,
                              run_metadata=run_metadata)

    # Invoke on-run-start callback and obtain response.
    run_start_resp = self.on_run_start(
        OnRunStartRequest(fetches, feed_dict, options, run_metadata,
                          self._run_call_count,
                          is_callable_runner=bool(callable_runner)))
    _check_type(run_start_resp, OnRunStartResponse)

    if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
      retvals, run_end_req = self._run_with_debugging(
          run_start_resp, fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)
    elif run_start_resp.action == OnRunStartAction.PROFILE_RUN:
      retvals, run_end_req = self._run_with_profiling(
          run_start_resp, fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)
    elif run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN:
      # Invoke run() method of the wrapped session.
      if callable_runner:
        retvals = callable_runner(*callable_runner_args)
      elif callable_options:
        # pylint:disable=protected-access
        callable_object = self._sess._make_callable_from_options(
            callable_options)
        # pylint:enable=protected-access
        retvals = callable_object(*callable_runner_args)
      else:
        retvals = self._sess.run(
            fetches,
            feed_dict=feed_dict,
            options=options,
            run_metadata=run_metadata)

      # Prepare arg for the on-run-end callback.
      run_end_req = OnRunEndRequest(run_start_resp.action)
    else:
      raise ValueError(
          "Invalid OnRunStartAction value: %s" % run_start_resp.action)

    # Invoke on-run-end callback and obtain response.
    run_end_resp = self.on_run_end(run_end_req)
    _check_type(run_end_resp, OnRunEndResponse)
    # Currently run_end_resp is only a placeholder. No action is taken on it.

    return retvals

  def _run_with_debugging(self,
                          run_start_resp,
                          fetches,
                          feed_dict,
                          options,
                          run_metadata,
                          callable_runner,
                          callable_runner_args,
                          callable_options):
    """Perform a session.run() or callable with debugging.

    Args:
      run_start_resp: The `OnRunStartResponse` that requested the debug run;
        supplies the debug URLs, debug ops and allowlists.
      fetches: Same as in `run()`.
      feed_dict: Same as in `run()`.
      options: Same as in `run()`.
      run_metadata: Same as in `run()`. A new `RunMetadata` is created if
        this is falsy.
      callable_runner: Same as in `run()`.
      callable_runner_args: Same as in `run()`.
      callable_options: Same as in `run()`.

    Returns:
      A tuple `(retvals, run_end_req)`: the return values of the run (or the
      captured OpError) and the `OnRunEndRequest` for the on-run-end callback.

    Raises:
      errors.OpError: If the run raises one and `pass_through_operrors` was
        set in the constructor.
    """
    # Decorate RunOption to fill in debugger tensor watch specifications.
    decorated_run_options = None
    if callable_options:
      callable_options_id = id(callable_options)
      if callable_options_id not in self._cached_callables_from_options:
        # Make a copy of callable_options to avoid mutating it.
        new_callable_options = config_pb2.CallableOptions()
        new_callable_options.CopyFrom(callable_options)
        decorated_run_options = new_callable_options.run_options
      # Otherwise a callable built from already-decorated options is cached,
      # so there is nothing to decorate and decorated_run_options stays None.
    else:
      decorated_run_options = options or config_pb2.RunOptions()

    run_metadata = run_metadata or config_pb2.RunMetadata()

    if decorated_run_options:
      self._decorate_run_options_for_debug(
          decorated_run_options,
          run_start_resp.debug_urls,
          debug_ops=run_start_resp.debug_ops,
          node_name_regex_allowlist=(run_start_resp.node_name_regex_allowlist),
          op_type_regex_allowlist=run_start_resp.op_type_regex_allowlist,
          tensor_dtype_regex_allowlist=(
              run_start_resp.tensor_dtype_regex_allowlist),
          tolerate_debug_op_creation_failures=(
              run_start_resp.tolerate_debug_op_creation_failures))

    # Invoke the run() method of the wrapped Session. Catch any TensorFlow
    # runtime errors.
    tf_error = None
    try:
      if callable_runner:
        retvals = callable_runner(*callable_runner_args,
                                  options=decorated_run_options,
                                  run_metadata=run_metadata)
      elif callable_options:
        # pylint:disable=protected-access
        if callable_options_id in self._cached_callables_from_options:
          callable_object = self._cached_callables_from_options[
              callable_options_id]
        else:
          callable_object = self._sess._make_callable_from_options(
              new_callable_options)
          self._cached_callables_from_options[
              callable_options_id] = callable_object
        # pylint:enable=protected-access
        retvals = callable_object(
            *callable_runner_args, run_metadata=run_metadata)
      else:
        retvals = self._sess.run(fetches,
                                 feed_dict=feed_dict,
                                 options=decorated_run_options,
                                 run_metadata=run_metadata)
    except errors.OpError as op_error:
      if self._pass_through_operrors:
        raise
      # Captured: the error stands in for the return values and is also
      # reported to the on-run-end callback.
      tf_error = op_error
      retvals = op_error

    return retvals, OnRunEndRequest(
        run_start_resp.action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def(),
        tf_error=tf_error)

  def _run_with_profiling(self,
                          run_start_resp,
                          fetches,
                          feed_dict,
                          options,
                          run_metadata,
                          callable_runner,
                          callable_runner_args,
                          callable_options):
    """Perform a session.run() or callable with profiling.

    Args:
      run_start_resp: The `OnRunStartResponse` that requested the profile run.
      fetches: Same as in `run()`.
      feed_dict: Same as in `run()`.
      options: Same as in `run()`.
      run_metadata: Same as in `run()`. A new `RunMetadata` is created if
        this is falsy.
      callable_runner: Same as in `run()`.
      callable_runner_args: Same as in `run()`.
      callable_options: Same as in `run()`.

    Returns:
      A tuple `(retvals, run_end_req)`: the return values of the run and the
      `OnRunEndRequest` for the on-run-end callback.
    """
    # Decorate RunOption to request a full trace.
    if callable_options:
      # Always work on a private copy to avoid mutating callable_options.
      # The callables cached by debugged runs are built from options
      # decorated for debugging, so they cannot be reused here; consulting
      # that cache used to leave both the run options and the copied options
      # unset whenever the same CallableOptions had been debug-run before.
      new_callable_options = config_pb2.CallableOptions()
      new_callable_options.CopyFrom(callable_options)
      decorated_run_options = new_callable_options.run_options
    else:
      decorated_run_options = options or config_pb2.RunOptions()
    self._decorate_run_options_for_profile(decorated_run_options)

    run_metadata = run_metadata or config_pb2.RunMetadata()
    if callable_runner:
      retvals = callable_runner(*callable_runner_args,
                                options=decorated_run_options,
                                run_metadata=run_metadata)
    elif callable_options:
      # pylint:disable=protected-access
      callable_object = self._sess._make_callable_from_options(
          new_callable_options)
      # pylint:enable=protected-access
      retvals = callable_object(
          *callable_runner_args, run_metadata=run_metadata)
    else:
      retvals = self._sess.run(fetches,
                               feed_dict=feed_dict,
                               options=decorated_run_options,
                               run_metadata=run_metadata)
    return retvals, OnRunEndRequest(
        run_start_resp.action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def())

  def _is_disabled_thread(self):
    """Tells whether the wrapper is inactive on the current thread.

    Returns:
      A truthy value if a thread-name filter was given and the current
      thread's name does not match it; a falsy value otherwise.
    """
    thread_name = threading.current_thread().name or ""
    return (self._thread_name_filter_pattern and
            not self._thread_name_filter_pattern.match(thread_name))

  def run_step_fn(self, step_fn):
    """Calls `step_fn` with a StepContext whose run function is this run()."""
    return step_fn(
        monitored_session.MonitoredSession.StepContext(self._sess, self.run))

  def partial_run_setup(self, fetches, feeds=None):
    """Sets up the feeds and fetches for partial runs in the session."""
    raise NotImplementedError(
        "partial_run_setup is not implemented for debug-wrapper sessions.")

  def partial_run(self, handle, fetches, feed_dict=None):
    """Not supported by debug-wrapper sessions; always raises."""
    raise NotImplementedError(
        "partial_run is not implemented for debug-wrapper sessions.")

  def list_devices(self, *args, **kwargs):
    """Forwards to `list_devices` of the wrapped session."""
    return self._sess.list_devices(*args, **kwargs)

  def reset(self, *args, **kwargs):
    """Forwards to `reset` of the wrapped session."""
    return self._sess.reset(*args, **kwargs)

  def make_callable(self,
                    fetches,
                    feed_list=None,
                    accept_options=False):
    """Returns a callable that routes through this wrapper's run().

    Args:
      fetches: Forwarded to `make_callable` of the wrapped session.
      feed_list: Forwarded to `make_callable` of the wrapped session.
      accept_options: Not used. The underlying callable is always created
        with `accept_options=True` so that the wrapper can pass it run
        options and run metadata.

    Returns:
      A function accepting the runner's positional arguments plus optional
      `options` and `run_metadata` keyword arguments.
    """
    runner = self._sess.make_callable(
        fetches, feed_list=feed_list, accept_options=True)
    def wrapped_runner(*runner_args, **kwargs):
      return self.run(None,
                      feed_dict=None,
                      options=kwargs.get("options", None),
                      run_metadata=kwargs.get("run_metadata", None),
                      callable_runner=runner,
                      callable_runner_args=runner_args)
    return wrapped_runner

  def _make_callable_from_options(self, callable_options):
    """Returns a callable that routes `callable_options` through run().

    Args:
      callable_options: A `config_pb2.CallableOptions` instance.

    Returns:
      A function accepting the feed values positionally plus an optional
      `run_metadata` keyword argument.
    """
    def wrapped_runner(*feed_values, **kwargs):
      return self.run(None,
                      run_metadata=kwargs.get("run_metadata", None),
                      callable_options=callable_options,
                      callable_runner_args=feed_values)
    return wrapped_runner

  @property
  def run_call_count(self):
    """Number of run() calls counted so far on this wrapper session."""
    return self._run_call_count

  def increment_run_call_count(self):
    """Increments the run() call counter by one."""
    self._run_call_count += 1

  def _is_disk_usage_reset_each_run(self):
    """Indicates whether disk usage is reset after each Session.run.

    Subclasses that clean up the disk usage after every run should
    override this protected method.

    Returns:
      (`bool`) Whether the disk usage amount is reset to zero after
        each Session.run.
    """
    return False

  def _decorate_run_options_for_debug(
      self,
      run_options,
      debug_urls,
      debug_ops="DebugIdentity",
      node_name_regex_allowlist=None,
      op_type_regex_allowlist=None,
      tensor_dtype_regex_allowlist=None,
      tolerate_debug_op_creation_failures=False):
    """Modify a RunOptions object for debug tensor watching.

    Specifies request for outputting partition graphs. Adds
    debug_tensor_watch_opts with proper debug URLs.

    Args:
      run_options: (RunOptions) the modified RunOptions object.
      debug_urls: (list of str) debug URLs to be entered in run_options.
        debug_tensor_watch_opts.
      debug_ops: (str or list of str) debug op(s) to be used by the debugger.
      node_name_regex_allowlist: Regular-expression allowlist for node
        name.
      op_type_regex_allowlist: Regular-expression allowlist for op type.
      tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
        dtype.
      tolerate_debug_op_creation_failures: Whether debug op creation failures
        are to be tolerated.
    """

    run_options.output_partition_graphs = True
    debug_utils.watch_graph(
        run_options,
        self._sess.graph,
        debug_urls=debug_urls,
        debug_ops=debug_ops,
        node_name_regex_allowlist=node_name_regex_allowlist,
        op_type_regex_allowlist=op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
        # Reset on the first counted run, or on every run if the subclass
        # says so.
        reset_disk_byte_usage=(self._run_call_count == 1 or
                               self._is_disk_usage_reset_each_run()))

  def _decorate_run_options_for_profile(self, run_options):
    """Modify a RunOptions object for profiling TensorFlow graph execution.

    Args:
      run_options: (RunOptions) the modified RunOptions object.
    """

    run_options.trace_level = config_pb2.RunOptions.FULL_TRACE

  @abc.abstractmethod
  def on_session_init(self, request):
    """Callback invoked during construction of the debug-wrapper session.

    This is a blocking callback.
    The invocation happens right before the constructor ends.

    Args:
      request: (`OnSessionInitRequest`) callback request carrying information
        such as the session being wrapped.

    Returns:
      An instance of `OnSessionInitResponse`.
    """

  @abc.abstractmethod
  def on_run_start(self, request):
    """Callback invoked on run() calls to the debug-wrapper session.

    This is a blocking callback.
    The invocation happens after the wrapper's run() call is entered,
    after an increment of run call counter.

    Args:
      request: (`OnRunStartRequest`) callback request object carrying
        information about the run call such as the fetches, feed dict, run
        options, run metadata, and how many `run()` calls to this wrapper
        session have occurred.

    Returns:
      An instance of `OnRunStartResponse`, carrying information to
        debug URLs used to watch the tensors.
    """

  @abc.abstractmethod
  def on_run_end(self, request):
    """Callback invoked on run() calls to the debug-wrapper session.

    This is a blocking callback.
    The invocation happens right before the wrapper exits its run() call.

    Args:
      request: (`OnRunEndRequest`) callback request object carrying information
        such as the actual action performed by the session wrapper for the
        run() call.

    Returns:
      An instance of `OnRunEndResponse`.
    """

  def as_default(self):
    """Returns a context manager from `stack.default_session` for this object."""
    return stack.default_session(self)

  def __enter__(self):
    # The context manager is created once and reused on later entries.
    if self._default_session_context_manager is None:
      self._default_session_context_manager = self.as_default()
    return self._default_session_context_manager.__enter__()

  def __exit__(self, exec_type, exec_value, exec_tb):
    self._default_session_context_manager.__exit__(
        exec_type, exec_value, exec_tb)

  def __del__(self):
    # __del__ also runs on instances whose __init__ raised before `_sess`
    # was assigned (e.g. on a failed type check), so do not assume the
    # attribute exists.
    sess = getattr(self, "_sess", None)
    if sess is not None and hasattr(sess, "__del__"):
      sess.__del__()

  def close(self):
    """Closes the wrapped session."""
    self._sess.close()

  # TODO(cais): Add _node_name_regex_allowlist and
  #   _node_op_type_regex_allowlist.

  def should_stop(self):
    """Forwards to `should_stop` of the wrapped session.

    Returns:
      The return value of the wrapped session's `should_stop()`.

    Raises:
      ValueError: If the wrapped session has no `should_stop` attribute.
    """
    if hasattr(self._sess, "should_stop"):
      return self._sess.should_stop()
    else:
      raise ValueError(
          "The wrapped session %r does not have a method called 'should_stop'. "
          "Do you intend to wrap a tf.MonitoredSession instead?" % self._sess)

826 

827 

class WatchOptions:
  """Debug watch options; the type of the values returned by a `watch_fn`."""

  def __init__(self,
               debug_ops=None,
               node_name_regex_allowlist=None,
               op_type_regex_allowlist=None,
               tensor_dtype_regex_allowlist=None,
               tolerate_debug_op_creation_failures=False):
    """Creates a `WatchOptions` object.

    The three allowlists combine in a logical `AND` relation: when more than
    one is set, a node is included if and only if it hits all of them.

    Args:
      debug_ops: (`str` or `list of str`) Debug ops to be used. A falsy value
        selects the default, `["DebugIdentity"]`.
      node_name_regex_allowlist: Regular-expression allowlist for node_name,
        e.g., `"(weight_[0-9]+|bias_.*)"`
      op_type_regex_allowlist: Regular-expression allowlist for the op type of
        nodes, e.g., `"(Variable|Add)"`.
      tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor
        data type, e.g., `"^int.*"`.
      tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
        failures (e.g., due to dtype incompatibility) are to be tolerated by
        not throwing exceptions.
    """
    self.debug_ops = debug_ops if debug_ops else ["DebugIdentity"]
    self.node_name_regex_allowlist = node_name_regex_allowlist
    self.op_type_regex_allowlist = op_type_regex_allowlist
    self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)

  def __repr__(self):
    # Rendered as a constructor-style call, in constructor argument order.
    named_values = (
        ("debug_ops", self.debug_ops),
        ("node_name_regex_allowlist", self.node_name_regex_allowlist),
        ("op_type_regex_allowlist", self.op_type_regex_allowlist),
        ("tensor_dtype_regex_allowlist", self.tensor_dtype_regex_allowlist),
        ("tolerate_debug_op_creation_failures",
         self.tolerate_debug_op_creation_failures),
    )
    rendered = ", ".join("%s=%r" % pair for pair in named_values)
    return "WatchOptions(%s)" % rendered

876 

877 

class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
  """Base class for non-interactive (i.e., non-CLI) debug wrapper sessions."""

  def __init__(self, sess, watch_fn=None, thread_name_filter=None,
               pass_through_operrors=False):
    """Constructor of NonInteractiveDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      watch_fn: (`Callable`) A Callable that maps the fetches and feeds of a
        debugged `Session.run()` call to `WatchOptions`.
        * Args:
          * `fetches`: the fetches to the `Session.run()` call.
          * `feeds`: the feeds to the `Session.run()` call.

        * Returns:
          (`tf_debug.WatchOptions`) An object containing debug options
          including the debug ops to use, the node names, op types and/or
          tensor data types to watch, etc. See the documentation of
          `tf_debug.WatchOptions` for more details. A legacy `tuple` return
          value is also accepted and is expanded positionally into
          `WatchOptions`.
      thread_name_filter: Regular-expression allowlist for threads on which
        the wrapper session will be active. See doc of
        `BaseDebugWrapperSession` for more details.
      pass_through_operrors: If true, all captured OpErrors will be
        propagated. By default this captures all OpErrors.

    Raises:
      TypeError: If a non-None `watch_fn` is specified and it is not callable.
    """

    BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter,
        pass_through_operrors=pass_through_operrors)

    # None means "no custom watch_fn": default WatchOptions are used per run.
    self._watch_fn = None
    if watch_fn is not None:
      if not callable(watch_fn):
        raise TypeError("watch_fn is not callable")
      self._watch_fn = watch_fn

  def on_session_init(self, request):
    """See doc of BaseDebugWrapperSession.on_session_init.

    This implementation always responds with `OnSessionInitAction.PROCEED`.
    """

    return OnSessionInitResponse(OnSessionInitAction.PROCEED)

  @abc.abstractmethod
  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Abstract method to be implemented by concrete subclasses.

    This method prepares the run-specific debug URL(s).

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`str` or `list` of `str`) Debug URLs to be used in
        this `Session.run()` call.
    """

  def on_run_start(self, request):
    """See doc of BaseDebugWrapperSession.on_run_start.

    This implementation always requests `OnRunStartAction.DEBUG_RUN`, using
    the debug URLs and watch options computed for the current run() call.
    """

    debug_urls, watch_opts = self._prepare_run_watch_config(
        request.fetches, request.feed_dict)

    return OnRunStartResponse(
        OnRunStartAction.DEBUG_RUN,
        debug_urls,
        debug_ops=watch_opts.debug_ops,
        node_name_regex_allowlist=watch_opts.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_opts.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=watch_opts.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            watch_opts.tolerate_debug_op_creation_failures))

  def _prepare_run_watch_config(self, fetches, feed_dict):
    """Get the debug_urls, and node/op allowlists for the current run() call.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`.
      feed_dict: Same as the `feed_dict` argument to `Session.run()`.

    Returns:
      debug_urls: (str or list of str) Debug URLs for the current run() call,
        as returned by `prepare_run_debug_urls()`.
      watch_options: (WatchOptions) The return value of a watch_fn, containing
        options including debug_ops, and allowlists. Default `WatchOptions`
        are used when no watch_fn was supplied.
    """

    debug_urls = self.prepare_run_debug_urls(fetches, feed_dict)
    if self._watch_fn is None:
      watch_options = WatchOptions()
    else:
      watch_options = self._watch_fn(fetches, feed_dict)
      if isinstance(watch_options, tuple):
        # For legacy return type (tuples): expand positionally into
        # WatchOptions.
        watch_options = WatchOptions(*watch_options)

    return debug_urls, watch_options

  def on_run_end(self, request):
    """See doc of BaseDebugWrapperSession.on_run_end."""

    return OnRunEndResponse()