Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/wrappers/framework.py: 30%
274 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Framework of debug wrapper sessions.
17A debug wrapper session is a wrapper around a TensorFlow Python Session.
18The wrapper preserves the Session interface, most importantly the run() method,
19while providing abilities to:
20a) Intercept a run() call to a wrapped session and insert debug tensor watches
21 according to externally-specified debug URLs.
23b) Release control to an external (i.e., non-Session) object before and after
24 the run() call, so that the external object can perform actions such as
25 launching a UI to let users inspect the intermediate tensors and partition
26 graphs from the run() call.
28c) (To be implemented in a future CL) Enter an instruction loop to let an
29 external object (e.g., remote client) launch run() and cont() calls
30 remotely.
32*** The lifetime of a debug wrapper session: ***
341) The wrapper session is created by calling the constructor with a
35 wrapped (normal) session as the argument:
36 wrapper = FooDebugWrapperSession(sess)
37 wherein FooDebugWrapperSession is a concrete subclass implementing the
38 abstract BaseDebugWrapperSession class below.
402) Near the end of the constructor call, the on_session_init() callback is
41 invoked, with an OnSessionInitRequest object as the argument. The object
42 carries the wrapped (normal) session object.
443) The callback handles the request and returns an OnSessionInitResponse
45 object with an action field, directing the wrapper session what to do next.
47If the action field in the OnSessionInitResponse is PROCEED, the constructor
48returns. Control is released back to the caller of the constructor, which can
49invoke run() method of wrapper session with the same syntax as a non-wrapped
50session, e.g.:
51 wrapper.run(fetches, feed_dict=feeds, options=run_options)
53Below, A1 - A2 is the lifetime of a wrapper run() call if the action is
54PROCEED:
56A1) Right at the start of each run() call, the on_run_start() callback is
57 invoked, with an OnRunStartRequest object carrying information such as
58 the fetches, the feed dict, the run options and run metadata used in
59 this run call, along with a count of how many run calls have occurred
60 on this wrapper session. The callback then returns an OnRunStartResponse
61 object, whose action field directs what the wrapper session
62 will actually do for the run() call.
64 If the action is DEBUG_RUN, a debugged (tensor-watched) run will ensue,
65 with the debug URLs supplied in the debug_urls field of the response.
66 These can be file:// or grpc:// URLs, for example.
68 If the action is NON_DEBUG_RUN, a non-debug (normal) run will ensue.
70A2) Right before the run() returns, the on_run_end() callback is invoked,
71 with an OnRunEndRequest object as the argument, which carries information
72 including the actual action performed in the wrapper run() call and the
73 run_metadata from the run() call.
75However, if the action field in OnSessionInitResponse is
76REMOTE_INSTR_LOOP, the constructor will automatically invoke an instruction loop
77that gives the control to a remote caller.
79In the remote instruction loop, the following steps will happen:
81B1) Callback on_instr_start() is invoked. The callback will return an
82 OnInstrStartResponse object with an action field which can order one of
83 the following actions:
84 i) a run() call with fetches, feeds and debug_urls specified.
85 ii) exit the instruction loop.
87B2) The wrapper session carries out the action specified above.
89B3) If still in the instruction loop, the wrapper session invokes the
90 on_instr_end() callback. After the on_instr_end() callback returns, jump
91 back to B1.
93TODO(cais): Implement the instruction loop in B1 - B3.
95"""
97import abc
98import re
99import threading
101from tensorflow.core.protobuf import config_pb2
102from tensorflow.python.client import session
103from tensorflow.python.debug.lib import debug_utils
104from tensorflow.python.framework import errors
105from tensorflow.python.framework import stack
106from tensorflow.python.platform import tf_logging
107from tensorflow.python.training import monitored_session
108from tensorflow.python.util import nest
109from tensorflow.python.util.compat import collections_abc
112# Helper function.
113def _check_type(obj, expected_types):
114 """Check if an object is of the expected type.
116 Args:
117 obj: The object being checked.
118 expected_types: (`type` or an iterable of `type`s) The expected `type`(s)
119 of obj.
121 Raises:
122 TypeError: If obj is not an instance of expected_type.
123 """
124 if not isinstance(obj, expected_types):
125 raise TypeError("Expected type %s; got type %s" %
126 (expected_types, type(obj)))
class OnSessionInitRequest:
  """Request handed to the on-session-init callback.

  A debug-wrapper session builds one of these inside its `__init__` and
  passes it to `on_session_init()`.
  """

  def __init__(self, sess):
    """Create the request.

    Args:
      sess: The TensorFlow session being wrapped; must be a `BaseSession` or a
        `MonitoredSession`.

    Raises:
      TypeError: If `sess` is of neither type.
    """
    accepted_session_types = (session.BaseSession,
                              monitored_session.MonitoredSession)
    _check_type(sess, accepted_session_types)
    self.session = sess
class OnSessionInitAction:
  """String constants naming what to do once a wrapper session is built."""

  # Finish construction normally and return control to the caller of the
  # wrapper session, which then decides what happens next (e.g., run()).
  PROCEED = "proceed"

  # Instead of returning control to the caller, enter a loop that takes
  # instructions from a remote client. A visual debugger (e.g., TensorBoard's)
  # could use this to launch session.run() calls remotely.
  REMOTE_INSTR_LOOP = "remote_instr_loop"
class OnSessionInitResponse:
  """Response returned by an on-session-init callback."""

  def __init__(self, action):
    """Create the response.

    Args:
      action: (`str`) One of the `OnSessionInitAction` values, telling the
        wrapper session what to do after initialization.

    Raises:
      TypeError: If `action` is not a `str`.
    """
    requested_action = action
    _check_type(requested_action, str)
    self.action = requested_action
class OnRunStartRequest:
  """Request handed to an on-run-start callback.

  The debug-wrapper session creates this object inside run(), right after it
  has incremented its run-call counter.
  """

  def __init__(self, fetches, feed_dict, run_options, run_metadata,
               run_call_count, is_callable_runner=False):
    """Constructor of `OnRunStartRequest`.

    Args:
      fetches: Fetch targets of the run() call.
      feed_dict: Feed dictionary of the run() call.
      run_options: `RunOptions` passed to the run() call.
      run_metadata: `RunMetadata` passed to the run() call.
        These four arguments mirror the arguments of `run()` on an unwrapped
        TensorFlow session.
      run_call_count: 1-based number of run() calls made so far, including
        the current one.
      is_callable_runner: (bool) Whether this run executes a runner obtained
        from `Session.make_callable`.
    """
    # Arguments of the intercepted run() call.
    self.fetches, self.feed_dict = fetches, feed_dict
    self.run_options, self.run_metadata = run_options, run_metadata
    # Bookkeeping about this particular call.
    self.run_call_count = run_call_count
    self.is_callable_runner = is_callable_runner
class OnRunStartAction:
  """String constants naming how a single run() call should be executed."""

  DEBUG_RUN = "debug_run"  # One run with debug tensor watches inserted.
  PROFILE_RUN = "profile_run"  # One run with full-trace profiling enabled.
  NON_DEBUG_RUN = "non_debug_run"  # One plain run, no tensor watching.
class OnRunStartResponse:
  """Response from an on-run-start callback.

  The on-run-start callback returns this object to tell the debug-wrapper
  session what action it should actually take for the run() call.
  """

  def __init__(self,
               action,
               debug_urls,
               debug_ops="DebugIdentity",
               node_name_regex_allowlist=None,
               op_type_regex_allowlist=None,
               tensor_dtype_regex_allowlist=None,
               tolerate_debug_op_creation_failures=False):
    """Constructor of `OnRunStartResponse`.

    Args:
      action: (`OnRunStartAction`) the action actually taken by the wrapped
        session for the run() call.
      debug_urls: (`list` of `str`) debug_urls used in watching the tensors
        during the run() call.
      debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the
        debugger.
      node_name_regex_allowlist: Regular-expression allowlist for node
        name.
      op_type_regex_allowlist: Regular-expression allowlist for op type.
      tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
        dtype.
      tolerate_debug_op_creation_failures: Whether debug op creation failures
        are to be tolerated.

    Raises:
      TypeError: If `action` is not a `str` or `debug_urls` is not a `list`.
    """

    _check_type(action, str)
    self.action = action

    _check_type(debug_urls, list)
    self.debug_urls = debug_urls

    self.debug_ops = debug_ops

    self.node_name_regex_allowlist = node_name_regex_allowlist
    self.op_type_regex_allowlist = op_type_regex_allowlist
    self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)
class OnRunEndRequest:
  """Request handed to an on-run-end callback.

  The wrapper session invokes the callback just before its run() returns.
  """

  def __init__(self,
               performed_action,
               run_metadata=None,
               client_graph_def=None,
               tf_error=None):
    """Constructor for `OnRunEndRequest`.

    Args:
      performed_action: (`OnRunStartAction`) The action the debug-wrapper
        session actually performed.
      run_metadata: `RunMetadata` produced by the run() call, if any.
      client_graph_def: (GraphDef) The client-side (Python front end)
        GraphDef, e.g. as returned by `session.graph.as_graph_def()`.
      tf_error: (errors.OpError subtypes) The TensorFlow OpError raised
        during the run, if any.

    Raises:
      TypeError: If `performed_action` is not a `str`, or `run_metadata` is
        neither `None` nor a `RunMetadata`.
    """
    # Validate everything up front, then record the fields.
    _check_type(performed_action, str)
    if run_metadata is not None:
      _check_type(run_metadata, config_pb2.RunMetadata)

    self.performed_action = performed_action
    self.run_metadata = run_metadata
    self.client_graph_def = client_graph_def
    self.tf_error = tf_error
class OnRunEndResponse:
  """Response returned by an on-run-end callback.

  It carries no data yet; it is a placeholder for future fields.
  """

  def __init__(self):
    """Create an empty response."""
class BaseDebugWrapperSession(session.SessionInterface, metaclass=abc.ABCMeta):
  """Base class of debug-wrapper session classes.

  Concrete classes that inherit from this class need to implement the abstract
  methods such as on_session_init, on_run_start and on_run_end.
  """

  def __init__(self, sess, thread_name_filter=None,
               pass_through_operrors=False):
    """Constructor of `BaseDebugWrapperSession`.

    Args:
      sess: An (unwrapped) TensorFlow session instance. It should be a subtype
        of `BaseSession` or `tf.MonitoredSession`.
      thread_name_filter: Regular-expression filter (allowlist) for name(s) of
        thread(s) on which the wrapper session will be active. This regular
        expression is used in a start-anchored fashion on the thread name, i.e.,
        by applying the `match` method of the compiled pattern. The default
        `None` means that the wrapper session will be active on all threads.
        E.g., r"MainThread$", r"QueueRunnerThread.*".
      pass_through_operrors: If True, all captured OpErrors will be
        propagated. By default this captures all OpErrors.

    Raises:
      TypeError: If `sess` is not a `BaseSession` or `MonitoredSession`.
      ValueError: On invalid `OnSessionInitAction` value.
      NotImplementedError: If `on_session_init` requests the not-yet-
        implemented `REMOTE_INSTR_LOOP` action.
    """

    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))

    # The session being wrapped.
    self._sess = sess
    self._thread_name_filter_pattern = (re.compile(thread_name_filter)
                                        if thread_name_filter else None)
    # TODO(cais/kstevens): Unittest this pass through feature.
    self._pass_through_operrors = pass_through_operrors

    # Keeps track of number of run calls that have been performed on this
    # debug-wrapper session. The count can be used for purposes such as
    # displaying the state of the Session in a UI and determining a run
    # number-dependent debug URL.
    self._run_call_count = 0

    # Invoke on-session-init callback.
    response = self.on_session_init(OnSessionInitRequest(self._sess))
    _check_type(response, OnSessionInitResponse)

    if response.action == OnSessionInitAction.PROCEED:
      pass
    elif response.action == OnSessionInitAction.REMOTE_INSTR_LOOP:
      # TODO(cais): Implement REMOTE_INSTR_LOOP
      raise NotImplementedError(
          "OnSessionInitAction REMOTE_INSTR_LOOP has not been "
          "implemented.")
    else:
      raise ValueError(
          "Invalid OnSessionInitAction value: %s" % response.action)

    self._default_session_context_manager = None

    # A cache for callables created from CallableOptions, keyed by the id() of
    # the caller's CallableOptions object. Only the debugging path uses it.
    self._cached_callables_from_options = {}

  @property
  def graph(self):
    return self._sess.graph

  @property
  def graph_def(self):
    return self._sess.graph_def

  @property
  def sess_str(self):
    return self._sess.sess_str

  @property
  def session(self):
    return self._sess

  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None,
          callable_options=None):
    """Wrapper around Session.run() that inserts tensor watch options.

    Args:
      fetches: Same as the `fetches` arg to regular `Session.run()`.
      feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
      options: Same as the `options` arg to regular `Session.run()`.
      run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
      callable_runner: A `callable` returned by `Session.make_callable()`.
        If not `None`, `fetches` and `feed_dict` must both be `None`.
        Mutually exclusive with `callable_options`.
      callable_runner_args: An optional list of arguments to `callable_runner`
        or for `callable_options`. `None` means no arguments.
      callable_options: An instance of `config_pb2.CallableOptions`, to be
        used with `Session._make_callable_from_options()`. Mutually exclusive
        with `callable_runner`.

    Returns:
      Simply forwards the output of the wrapped `Session.run()` call.

    Raises:
      ValueError: On invalid `OnRunStartAction` value. Or if both
        `callable_runner` and `callable_options` are specified, or if either
        of them is specified together with a non-empty `fetches` or
        `feed_dict`.
    """
    if callable_runner and callable_options:
      raise ValueError(
          "callable_runner and callable_options are mutually exclusive, but "
          "are both specified in this call to BaseDebugWrapperSession.run().")

    if callable_runner and (fetches or feed_dict):
      raise ValueError(
          "callable_runner and fetches/feed_dict are mutually exclusive, "
          "but are used simultaneously.")
    elif callable_options and (fetches or feed_dict):
      raise ValueError(
          "callable_options and fetches/feed_dict are mutually exclusive, "
          "but are used simultaneously.")

    if callable_runner_args is None:
      # Unpacking `*None` would raise; treat a missing list as no arguments.
      callable_runner_args = ()

    self.increment_run_call_count()

    def is_empty(x):
      """Check whether a possibly nested structure is empty."""
      if not nest.is_nested(x):
        return False
      if isinstance(x, collections_abc.Mapping):
        return is_empty(list(x.values()))
      for item in x:
        if not is_empty(item):
          return False
      return True

    empty_fetches = is_empty(fetches)
    if empty_fetches:
      tf_logging.info(
          "Due to empty fetches, tfdbg Session wrapper is letting a "
          "Session.run pass through without any debugging actions.")
    if self._is_disabled_thread() or empty_fetches:
      return self._run_without_debugging(
          fetches, feed_dict, options, run_metadata, callable_runner,
          callable_runner_args, callable_options)

    # Invoke on-run-start callback and obtain response.
    run_start_resp = self.on_run_start(
        OnRunStartRequest(fetches, feed_dict, options, run_metadata,
                          self._run_call_count,
                          is_callable_runner=bool(callable_runner)))
    _check_type(run_start_resp, OnRunStartResponse)

    if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
      retvals, run_end_req = self._run_with_debugging(
          run_start_resp, fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)
    elif run_start_resp.action == OnRunStartAction.PROFILE_RUN:
      retvals, run_end_req = self._run_with_profiling(
          run_start_resp, fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)
    elif run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN:
      # Invoke run() method of the wrapped session.
      retvals = self._run_without_debugging(
          fetches, feed_dict, options, run_metadata, callable_runner,
          callable_runner_args, callable_options)

      # Prepare arg for the on-run-end callback.
      run_end_req = OnRunEndRequest(run_start_resp.action)
    else:
      raise ValueError(
          "Invalid OnRunStartAction value: %s" % run_start_resp.action)

    # Invoke on-run-end callback and obtain response.
    run_end_resp = self.on_run_end(run_end_req)
    _check_type(run_end_resp, OnRunEndResponse)
    # Currently run_end_resp is only a placeholder. No action is taken on it.

    return retvals

  def _run_without_debugging(self,
                             fetches,
                             feed_dict,
                             options,
                             run_metadata,
                             callable_runner,
                             callable_runner_args,
                             callable_options):
    """Run the wrapped session (or callable) without any instrumentation.

    A `callable_runner` or `callable_options` callable is invoked with only
    `callable_runner_args`; `options` and `run_metadata` are forwarded only
    on the plain `Session.run()` path.
    """
    if callable_runner:
      return callable_runner(*callable_runner_args)
    elif callable_options:
      # pylint:disable=protected-access
      callable_object = self._sess._make_callable_from_options(
          callable_options)
      # pylint:enable=protected-access
      return callable_object(*callable_runner_args)
    else:
      return self._sess.run(fetches,
                            feed_dict=feed_dict,
                            options=options,
                            run_metadata=run_metadata)

  def _run_with_debugging(self,
                          run_start_resp,
                          fetches,
                          feed_dict,
                          options,
                          run_metadata,
                          callable_runner,
                          callable_runner_args,
                          callable_options):
    """Perform a session.run() or callable with debugging."""
    # Decorate RunOption to fill in debugger tensor watch specifications.
    decorated_run_options = None
    if callable_options:
      callable_options_id = id(callable_options)
      if callable_options_id not in self._cached_callables_from_options:
        # Make a copy of callable_options to avoid mutating it.
        new_callable_options = config_pb2.CallableOptions()
        new_callable_options.CopyFrom(callable_options)
        decorated_run_options = new_callable_options.run_options
    else:
      decorated_run_options = options or config_pb2.RunOptions()

    run_metadata = run_metadata or config_pb2.RunMetadata()

    # When a cached callable is reused, decorated_run_options stays None: the
    # cached callable already carries the watches it was created with.
    if decorated_run_options:
      self._decorate_run_options_for_debug(
          decorated_run_options,
          run_start_resp.debug_urls,
          debug_ops=run_start_resp.debug_ops,
          node_name_regex_allowlist=(run_start_resp.node_name_regex_allowlist),
          op_type_regex_allowlist=run_start_resp.op_type_regex_allowlist,
          tensor_dtype_regex_allowlist=(
              run_start_resp.tensor_dtype_regex_allowlist),
          tolerate_debug_op_creation_failures=(
              run_start_resp.tolerate_debug_op_creation_failures))

    # Invoke the run() method of the wrapped Session. Catch any TensorFlow
    # runtime errors.
    tf_error = None
    try:
      if callable_runner:
        retvals = callable_runner(*callable_runner_args,
                                  options=decorated_run_options,
                                  run_metadata=run_metadata)
      elif callable_options:
        # pylint:disable=protected-access
        if callable_options_id in self._cached_callables_from_options:
          callable_object = self._cached_callables_from_options[
              callable_options_id]
        else:
          callable_object = self._sess._make_callable_from_options(
              new_callable_options)
          self._cached_callables_from_options[
              callable_options_id] = callable_object
        # pylint:enable=protected-access
        retvals = callable_object(
            *callable_runner_args, run_metadata=run_metadata)
      else:
        retvals = self._sess.run(fetches,
                                 feed_dict=feed_dict,
                                 options=decorated_run_options,
                                 run_metadata=run_metadata)
    except errors.OpError as op_error:
      if self._pass_through_operrors:
        raise
      tf_error = op_error
      retvals = op_error

    return retvals, OnRunEndRequest(
        run_start_resp.action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def(),
        tf_error=tf_error)

  def _run_with_profiling(self,
                          run_start_resp,
                          fetches,
                          feed_dict,
                          options,
                          run_metadata,
                          callable_runner,
                          callable_runner_args,
                          callable_options):
    """Perform a session.run() or callable with profiling."""
    # Decorate RunOption to enable full tracing.
    if callable_options:
      # Always copy: profiling never reuses a cached callable, so the copy is
      # needed even when _run_with_debugging has already cached a callable for
      # this CallableOptions object. Copying also avoids mutating the caller's
      # proto.
      new_callable_options = config_pb2.CallableOptions()
      new_callable_options.CopyFrom(callable_options)
      decorated_run_options = new_callable_options.run_options
    else:
      decorated_run_options = options or config_pb2.RunOptions()
    self._decorate_run_options_for_profile(decorated_run_options)

    run_metadata = run_metadata or config_pb2.RunMetadata()
    if callable_runner:
      retvals = callable_runner(*callable_runner_args,
                                options=decorated_run_options,
                                run_metadata=run_metadata)
    elif callable_options:
      # pylint:disable=protected-access
      callable_object = self._sess._make_callable_from_options(
          new_callable_options)
      # pylint:enable=protected-access
      retvals = callable_object(
          *callable_runner_args, run_metadata=run_metadata)
    else:
      retvals = self._sess.run(fetches,
                               feed_dict=feed_dict,
                               options=decorated_run_options,
                               run_metadata=run_metadata)
    return retvals, OnRunEndRequest(
        run_start_resp.action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def())

  def _is_disabled_thread(self):
    """Whether the current thread is excluded by the thread-name filter."""
    thread_name = threading.current_thread().name or ""
    return (self._thread_name_filter_pattern and
            not self._thread_name_filter_pattern.match(thread_name))

  def run_step_fn(self, step_fn):
    return step_fn(
        monitored_session.MonitoredSession.StepContext(self._sess, self.run))

  def partial_run_setup(self, fetches, feeds=None):
    """Sets up the feeds and fetches for partial runs in the session."""
    raise NotImplementedError(
        "partial_run_setup is not implemented for debug-wrapper sessions.")

  def partial_run(self, handle, fetches, feed_dict=None):
    raise NotImplementedError(
        "partial_run is not implemented for debug-wrapper sessions.")

  def list_devices(self, *args, **kwargs):
    return self._sess.list_devices(*args, **kwargs)

  def reset(self, *args, **kwargs):
    return self._sess.reset(*args, **kwargs)

  def make_callable(self,
                    fetches,
                    feed_list=None,
                    accept_options=False):
    """Wrap `Session.make_callable` so calls go through this wrapper's run().

    The underlying runner is always created with `accept_options=True` (the
    `accept_options` argument is ignored), because the debugging and
    profiling paths pass `options` to it.
    """
    runner = self._sess.make_callable(
        fetches, feed_list=feed_list, accept_options=True)
    def wrapped_runner(*runner_args, **kwargs):
      return self.run(None,
                      feed_dict=None,
                      options=kwargs.get("options", None),
                      run_metadata=kwargs.get("run_metadata", None),
                      callable_runner=runner,
                      callable_runner_args=runner_args)
    return wrapped_runner

  def _make_callable_from_options(self, callable_options):
    def wrapped_runner(*feed_values, **kwargs):
      return self.run(None,
                      run_metadata=kwargs.get("run_metadata", None),
                      callable_options=callable_options,
                      callable_runner_args=feed_values)
    return wrapped_runner

  @property
  def run_call_count(self):
    return self._run_call_count

  def increment_run_call_count(self):
    self._run_call_count += 1

  def _is_disk_usage_reset_each_run(self):
    """Indicates whether disk usage is reset after each Session.run.

    Subclasses that clean up the disk usage after every run should
    override this protected method.

    Returns:
      (`bool`) Whether the disk usage amount is reset to zero after
        each Session.run.
    """
    return False

  def _decorate_run_options_for_debug(
      self,
      run_options,
      debug_urls,
      debug_ops="DebugIdentity",
      node_name_regex_allowlist=None,
      op_type_regex_allowlist=None,
      tensor_dtype_regex_allowlist=None,
      tolerate_debug_op_creation_failures=False):
    """Modify a RunOptions object for debug tensor watching.

    Specifies request for outputting partition graphs. Adds
    debug_tensor_watch_opts with proper debug URLs.

    Args:
      run_options: (RunOptions) the modified RunOptions object.
      debug_urls: (list of str) debug URLs to be entered in run_options.
        debug_tensor_watch_opts.
      debug_ops: (str or list of str) debug op(s) to be used by the debugger.
      node_name_regex_allowlist: Regular-expression allowlist for node
        name.
      op_type_regex_allowlist: Regular-expression allowlist for op type.
      tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
        dtype.
      tolerate_debug_op_creation_failures: Whether debug op creation failures
        are to be tolerated.
    """

    run_options.output_partition_graphs = True
    debug_utils.watch_graph(
        run_options,
        self._sess.graph,
        debug_urls=debug_urls,
        debug_ops=debug_ops,
        node_name_regex_allowlist=node_name_regex_allowlist,
        op_type_regex_allowlist=op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
        reset_disk_byte_usage=(self._run_call_count == 1 or
                               self._is_disk_usage_reset_each_run()))

  def _decorate_run_options_for_profile(self, run_options):
    """Modify a RunOptions object for profiling TensorFlow graph execution.

    Args:
      run_options: (RunOptions) the modified RunOptions object.
    """

    run_options.trace_level = config_pb2.RunOptions.FULL_TRACE

  @abc.abstractmethod
  def on_session_init(self, request):
    """Callback invoked during construction of the debug-wrapper session.

    This is a blocking callback.
    The invocation happens right before the constructor ends.

    Args:
      request: (`OnSessionInitRequest`) callback request carrying information
        such as the session being wrapped.

    Returns:
      An instance of `OnSessionInitResponse`.
    """

  @abc.abstractmethod
  def on_run_start(self, request):
    """Callback invoked on run() calls to the debug-wrapper session.

    This is a blocking callback.
    The invocation happens after the wrapper's run() call is entered,
    after an increment of run call counter.

    Args:
      request: (`OnRunStartRequest`) callback request object carrying
        information about the run call such as the fetches, feed dict, run
        options, run metadata, and how many `run()` calls to this wrapper
        session have occurred.

    Returns:
      An instance of `OnRunStartResponse`, carrying information to
        debug URLs used to watch the tensors.
    """

  @abc.abstractmethod
  def on_run_end(self, request):
    """Callback invoked on run() calls to the debug-wrapper session.

    This is a blocking callback.
    The invocation happens right before the wrapper exits its run() call.

    Args:
      request: (`OnRunEndRequest`) callback request object carrying information
        such as the actual action performed by the session wrapper for the
        run() call.

    Returns:
      An instance of `OnRunEndResponse`.
    """

  def as_default(self):
    return stack.default_session(self)

  def __enter__(self):
    if self._default_session_context_manager is None:
      self._default_session_context_manager = self.as_default()
    return self._default_session_context_manager.__enter__()

  def __exit__(self, exec_type, exec_value, exec_tb):
    self._default_session_context_manager.__exit__(
        exec_type, exec_value, exec_tb)

  def __del__(self):
    # `_sess` is missing if __init__ raised before assigning it (e.g. on a
    # wrong session type); don't raise AttributeError during finalization.
    sess = getattr(self, "_sess", None)
    if hasattr(sess, "__del__"):
      sess.__del__()

  def close(self):
    self._sess.close()

  # TODO(cais): Add _node_name_regex_allowlist and
  #   _node_op_type_regex_allowlist.

  def should_stop(self):
    if hasattr(self._sess, "should_stop"):
      return self._sess.should_stop()
    else:
      raise ValueError(
          "The wrapped session %r does not have a method called 'should_stop'. "
          "Do you intend to wrap a tf.MonitoredSession instead?" % self._sess)
class WatchOptions:
  """Type for return values of watch_fn."""

  def __init__(self,
               debug_ops=None,
               node_name_regex_allowlist=None,
               op_type_regex_allowlist=None,
               tensor_dtype_regex_allowlist=None,
               tolerate_debug_op_creation_failures=False):
    """Build a set of debug watch options, as returned by a `watch_fn`.

    Args:
      debug_ops: (`str` or `list of str`) Debug op(s) to use. Any falsy value
        (e.g. `None` or an empty list) selects `["DebugIdentity"]`.
      node_name_regex_allowlist: Regular-expression allowlist for node names,
        e.g., `"(weight_[0-9]+|bias_.*)"`.
      op_type_regex_allowlist: Regular-expression allowlist for node op types,
        e.g., `"(Variable|Add)"`. When this and `node_name_regex_allowlist`
        are both set, a node is watched only if it matches both (logical
        `AND`).
      tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
        data types, e.g., `"^int.*"`. Combined with the two allowlists above
        via logical `AND`.
      tolerate_debug_op_creation_failures: (`bool`) Whether failures to create
        debug ops (e.g., due to dtype incompatibility) are tolerated instead of
        raising exceptions.
    """
    self.debug_ops = debug_ops if debug_ops else ["DebugIdentity"]
    self.node_name_regex_allowlist = node_name_regex_allowlist
    self.op_type_regex_allowlist = op_type_regex_allowlist
    self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)

  def __repr__(self):
    field_values = (
        self.debug_ops,
        self.node_name_regex_allowlist,
        self.op_type_regex_allowlist,
        self.tensor_dtype_regex_allowlist,
        self.tolerate_debug_op_creation_failures,
    )
    return ("WatchOptions(debug_ops=%r, node_name_regex_allowlist=%r, "
            "op_type_regex_allowlist=%r, tensor_dtype_regex_allowlist=%r, "
            "tolerate_debug_op_creation_failures=%r)" % field_values)
class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
  """Base class for non-interactive (i.e., non-CLI) debug wrapper sessions."""

  def __init__(self, sess, watch_fn=None, thread_name_filter=None,
               pass_through_operrors=False):
    """Constructor of NonInteractiveDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      watch_fn: (`Callable`) A Callable that maps the fetches and feeds of a
        debugged `Session.run()` call to `WatchOptions.`
        * Args:
          * `fetches`: the fetches to the `Session.run()` call.
          * `feeds`: the feeds to the `Session.run()` call.

        * Returns:
         (`tf_debug.WatchOptions`) An object containing debug options including
           the debug ops to use, the node names, op types and/or tensor data
           types to watch, etc. See the documentation of `tf_debug.WatchOptions`
           for more details.
      thread_name_filter: Regular-expression allowlist for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      pass_through_operrors: If true, all captured OpErrors will be
        propagated. By default this captures all OpErrors.

    Raises:
      TypeError: If a non-None `watch_fn` is specified and it is not callable.
    """

    BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter,
        pass_through_operrors=pass_through_operrors)

    self._watch_fn = None
    if watch_fn is not None:
      if not callable(watch_fn):
        raise TypeError("watch_fn is not callable")
      self._watch_fn = watch_fn

  def on_session_init(self, request):
    """See doc of BaseDebugWrapperSession.on_session_init."""

    return OnSessionInitResponse(OnSessionInitAction.PROCEED)

  @abc.abstractmethod
  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Abstract method to be implemented by concrete subclasses.

    This method prepares the run-specific debug URL(s).

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`str` or `list` of `str`) Debug URLs to be used in
        this `Session.run()` call.
    """

  def on_run_start(self, request):
    """See doc of BaseDebugWrapperSession.on_run_start."""

    debug_urls, watch_opts = self._prepare_run_watch_config(
        request.fetches, request.feed_dict)

    return OnRunStartResponse(
        OnRunStartAction.DEBUG_RUN,
        debug_urls,
        debug_ops=watch_opts.debug_ops,
        node_name_regex_allowlist=watch_opts.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_opts.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=watch_opts.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            watch_opts.tolerate_debug_op_creation_failures))

  def _prepare_run_watch_config(self, fetches, feed_dict):
    """Get the debug_urls, and node/op allowlists for the current run() call.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`.
      feed_dict: Same as the `feed_dict argument` to `Session.run()`.

    Returns:
      debug_urls: (list of str) Debug URLs for the current run() call. A single
        `str` returned by `prepare_run_debug_urls` is wrapped in a list.
      watch_options: (WatchOptions) The return value of a watch_fn, containing
        options including debug_ops, and allowlists.
    """

    debug_urls = self.prepare_run_debug_urls(fetches, feed_dict)
    if isinstance(debug_urls, str):
      # prepare_run_debug_urls may return a single URL, but
      # OnRunStartResponse requires a list of URLs.
      debug_urls = [debug_urls]
    if self._watch_fn is None:
      watch_options = WatchOptions()
    else:
      watch_options = self._watch_fn(fetches, feed_dict)
      if isinstance(watch_options, tuple):
        # For legacy return type (tuples).
        watch_options = WatchOptions(*watch_options)

    return debug_urls, watch_options

  def on_run_end(self, request):
    """See doc of BaseDebugWrapperSession.on_run_end."""

    return OnRunEndResponse()