Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/wrappers/local_cli_wrapper.py: 16%
234 statements

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Debugger Wrapper Session Consisting of a Local Curses-based CLI."""
import argparse
import os
import sys
import tempfile

# Google-internal import(s).
from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.debug.cli import cli_config
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import profile_analyzer_cli
from tensorflow.python.debug.cli import ui_factory
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import framework
from tensorflow.python.lib.io import file_io

_DUMP_ROOT_PREFIX = "tfdbg_"


# TODO(donglin) Remove use_random_config_path after b/137652456 is fixed.
class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
  """Concrete subclass of BaseDebugWrapperSession implementing a local CLI.

  This class has all the methods that a `session.Session` object has, in order
  to support debugging with minimal code changes. Invoking its `run()` method
  will launch the command-line interface (CLI) of tfdbg.
  """

  def __init__(self,
               sess,
               dump_root=None,
               log_usage=True,
               ui_type="curses",
               thread_name_filter=None,
               config_file_path=False):
    """Constructor of LocalCLIDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      dump_root: (`str`) optional path to the dump root directory. Must be a
        directory that does not exist or an empty directory. If the directory
        does not exist, it will be created by the debugger core during debug
        `run()` calls and removed afterwards. If `None`, the debug dumps will
        be at tfdbg_<random_string> under the system temp directory.
      log_usage: (`bool`) whether the usage of this class is to be logged.
      ui_type: (`str`) requested UI type. Currently supported:
        (curses | readline)
      thread_name_filter: Regular-expression white list for thread name. See
        the doc of `BaseDebugWrapperSession` for details.
      config_file_path: Optional override to the default configuration file
        path, which is at `${HOME}/.tfdbg_config`.

    Raises:
      ValueError: If dump_root is an existing and non-empty directory or if
        dump_root is a file.
    """

    if log_usage:
      pass  # No logging for open-source.

    framework.BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter)

    if not dump_root:
      self._dump_root = tempfile.mkdtemp(prefix=_DUMP_ROOT_PREFIX)
    else:
      dump_root = os.path.expanduser(dump_root)
      if os.path.isfile(dump_root):
        raise ValueError("dump_root path points to a file: %s" % dump_root)
      elif os.path.isdir(dump_root) and os.listdir(dump_root):
        raise ValueError("dump_root path points to a non-empty directory: %s" %
                         dump_root)

      self._dump_root = dump_root

    self._initialize_argparsers()

    # Registered tensor filters.
    self._tensor_filters = {}
    # Register frequently-used filter(s).
    self.add_tensor_filter("has_inf_or_nan", debug_data.has_inf_or_nan)

    # Below are the state variables of this wrapper object.
    # _active_tensor_filter: what (if any) tensor filter is in effect. If such
    #   a filter is in effect, this object will call run() method of the
    #   underlying TensorFlow Session object until the filter passes. This is
    #   activated by the "-f" flag of the "run" command.
    # _run_through_times: keeps track of how many times the wrapper needs to
    #   run through without stopping at the run-end CLI. It is activated by the
    #   "-t" option of the "run" command.
    # _skip_debug: keeps track of whether the current run should be executed
    #   without debugging. It is activated by the "-n" option of the "run"
    #   command.
    #
    # _run_start_response: keeps track of what OnRunStartResponse the wrapper
    #   should return at the next run-start callback. If this information is
    #   unavailable (i.e., is None), the run-start CLI will be launched to ask
    #   the user. This is the case, e.g., right before the first run starts.
    self._active_tensor_filter = None
    self._active_filter_exclude_node_names = None
    self._active_tensor_filter_run_start_response = None
    self._run_through_times = 1
    self._skip_debug = False
    self._run_start_response = None
    self._is_run_start = True
    self._ui_type = ui_type
    self._config = None
    if config_file_path:
      self._config = cli_config.CLIConfig(config_file_path=config_file_path)

  def _is_disk_usage_reset_each_run(self):
    # The dumped tensors are all cleaned up after every Session.run
    # in a command-line wrapper.
    return True

  def _initialize_argparsers(self):
    self._argparsers = {}
    ap = argparse.ArgumentParser(
        description="Run through, with or without debug tensor watching.",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "-t",
        "--times",
        dest="times",
        type=int,
        default=1,
        help="How many Session.run() calls to proceed with.")
    ap.add_argument(
        "-n",
        "--no_debug",
        dest="no_debug",
        action="store_true",
        help="Run through without debug tensor watching.")
    ap.add_argument(
        "-f",
        "--till_filter_pass",
        dest="till_filter_pass",
        type=str,
        default="",
        help="Run until a tensor in the graph passes the specified filter.")
    ap.add_argument(
        "-fenn",
        "--filter_exclude_node_names",
        dest="filter_exclude_node_names",
        type=str,
        default="",
        help="When applying the tensor filter, exclude nodes with names "
        "matching the regular expression. Applicable only if --tensor_filter "
        "or -f is used.")
    ap.add_argument(
        "--node_name_filter",
        dest="node_name_filter",
        type=str,
        default="",
        help="Regular-expression filter for node names to be watched in the "
        "run, e.g., loss, reshape.*")
    ap.add_argument(
        "--op_type_filter",
        dest="op_type_filter",
        type=str,
        default="",
        help="Regular-expression filter for op type to be watched in the run, "
        "e.g., (MatMul|Add), Variable.*")
    ap.add_argument(
        "--tensor_dtype_filter",
        dest="tensor_dtype_filter",
        type=str,
        default="",
        help="Regular-expression filter for tensor dtype to be watched in the "
        "run, e.g., (float32|float64), int.*")
    ap.add_argument(
        "-p",
        "--profile",
        dest="profile",
        action="store_true",
        help="Run and profile TensorFlow graph execution.")
    self._argparsers["run"] = ap

    ap = argparse.ArgumentParser(
        description="Display information about this Session.run() call.",
        usage=argparse.SUPPRESS)
    self._argparsers["run_info"] = ap

    self._argparsers["print_feed"] = command_parser.get_print_tensor_argparser(
        "Print the value of a feed in feed_dict.")

  def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    Args:
      filter_name: (`str`) name of the filter.
      tensor_filter: (`callable`) the filter callable. See the doc string of
        `DebugDumpDir.find()` for more details about its signature.
    """

    self._tensor_filters[filter_name] = tensor_filter
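
  # Sketch of a custom tensor filter (per `DebugDumpDir.find()`, a callable
  # taking a `DebugTensorDatum` and the tensor value and returning a bool).
  # The name and threshold below are arbitrary, for illustration only:
  #
  #   import numpy as np
  #
  #   def large_values(datum, tensor):
  #     return (isinstance(tensor, np.ndarray) and
  #             np.issubdtype(tensor.dtype, np.floating) and
  #             np.any(np.abs(tensor) > 1e6))
  #
  #   wrapper_sess.add_tensor_filter("large_values", large_values)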

  def on_session_init(self, request):
    """Overrides on-session-init callback.

    Args:
      request: An instance of `OnSessionInitRequest`.

    Returns:
      An instance of `OnSessionInitResponse`.
    """

    return framework.OnSessionInitResponse(
        framework.OnSessionInitAction.PROCEED)

  def on_run_start(self, request):
    """Overrides on-run-start callback.

    Args:
      request: An instance of `OnRunStartRequest`.

    Returns:
      An instance of `OnRunStartResponse`.
    """
    self._is_run_start = True
    self._update_run_calls_state(
        request.run_call_count, request.fetches, request.feed_dict,
        is_callable_runner=request.is_callable_runner)

    if self._active_tensor_filter:
      # If we are running until a filter passes, we just need to keep running
      # with the previous `OnRunStartResponse`.
      return self._active_tensor_filter_run_start_response

    self._exit_if_requested_by_user()

    if self._run_call_count > 1 and not self._skip_debug:
      if self._run_through_times > 0:
        # Just run through without debugging.
        return framework.OnRunStartResponse(
            framework.OnRunStartAction.NON_DEBUG_RUN, [])
      elif self._run_through_times == 0:
        # It is the run at which the run-end CLI will be launched: activate
        # debugging.
        return (self._run_start_response or
                framework.OnRunStartResponse(
                    framework.OnRunStartAction.DEBUG_RUN,
                    self._get_run_debug_urls()))

    if self._run_start_response is None:
      self._prep_cli_for_run_start()

      self._run_start_response = self._launch_cli()
      if self._active_tensor_filter:
        self._active_tensor_filter_run_start_response = self._run_start_response
      if self._run_through_times > 1:
        self._run_through_times -= 1

    self._exit_if_requested_by_user()
    return self._run_start_response
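
  # Worked example of the "-t" counter (illustrative): typing `run -t 3` at the
  # run-start CLI sets _run_through_times to 3 (decremented to 2 before the
  # current call proceeds without debugging). Each later on_run_start() call
  # decrements the counter again via _update_run_calls_state(); while it stays
  # above 0 the wrapper returns NON_DEBUG_RUN, and once it reaches 0 the
  # wrapper switches to DEBUG_RUN so the run-end CLI is shown for that call.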

  def _exit_if_requested_by_user(self):
    if self._run_start_response == debugger_cli_common.EXPLICIT_USER_EXIT:
      # Explicit user "exit" command leads to sys.exit(1).
      print(
          "Note: user exited from debugger CLI: Calling sys.exit(1).",
          file=sys.stderr)
      sys.exit(1)

  def _prep_cli_for_run_start(self):
    """Prepare (but not launch) the CLI for run-start."""
    self._run_cli = ui_factory.get_ui(self._ui_type, config=self._config)

    help_intro = debugger_cli_common.RichTextLines([])
    if self._run_call_count == 1:
      # Show logo at the onset of the first run.
      help_intro.extend(cli_shared.get_tfdbg_logo())
      help_intro.extend(debugger_cli_common.get_tensorflow_version_lines())
    help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:"))
    help_intro.extend(self._run_info)

    self._run_cli.set_help_intro(help_intro)

    # Create initial screen output detailing the run.
    self._title = "run-start: " + self._run_description
    self._init_command = "run_info"
    self._title_color = "blue_on_white"

  def on_run_end(self, request):
    """Overrides on-run-end callback.

    Actions taken:
      1) Load the debug dump.
      2) Bring up the Analyzer CLI.

    Args:
      request: An instance of `OnRunEndRequest`.

    Returns:
      An instance of `OnRunEndResponse`.
    """

    self._is_run_start = False
    if request.performed_action == framework.OnRunStartAction.DEBUG_RUN:
      partition_graphs = None
      if request.run_metadata and request.run_metadata.partition_graphs:
        partition_graphs = request.run_metadata.partition_graphs
      elif request.client_graph_def:
        partition_graphs = [request.client_graph_def]

      if request.tf_error and not os.path.isdir(self._dump_root):
        # It is possible that the dump root may not exist due to errors that
        # have occurred prior to graph execution (e.g., invalid device
        # assignments), in which case we will just raise the exception as the
        # unwrapped Session does.
        raise request.tf_error

      debug_dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=partition_graphs)
      debug_dump.set_python_graph(self._sess.graph)

      passed_filter = None
      passed_filter_exclude_node_names = None
      if self._active_tensor_filter:
        if not debug_dump.find(
            self._tensor_filters[self._active_tensor_filter], first_n=1,
            exclude_node_names=self._active_filter_exclude_node_names):
          # No dumped tensor passes the filter in this run. Clean up the dump
          # directory and move on.
          self._remove_dump_root()
          return framework.OnRunEndResponse()
        else:
          # Some dumped tensor(s) from this run passed the filter.
          passed_filter = self._active_tensor_filter
          passed_filter_exclude_node_names = (
              self._active_filter_exclude_node_names)
          self._active_tensor_filter = None
          self._active_filter_exclude_node_names = None

      self._prep_debug_cli_for_run_end(
          debug_dump, request.tf_error, passed_filter,
          passed_filter_exclude_node_names)

      self._run_start_response = self._launch_cli()

      # Clean up the dump generated by this run.
      self._remove_dump_root()
    elif request.performed_action == framework.OnRunStartAction.PROFILE_RUN:
      self._prep_profile_cli_for_run_end(self._sess.graph, request.run_metadata)
      self._run_start_response = self._launch_cli()
    else:
      # No debug information to show following a non-debug run() call.
      self._run_start_response = None

    # Return placeholder response that currently holds no additional
    # information.
    return framework.OnRunEndResponse()

  def _remove_dump_root(self):
    if os.path.isdir(self._dump_root):
      file_io.delete_recursively(self._dump_root)

  def _prep_debug_cli_for_run_end(self,
                                  debug_dump,
                                  tf_error,
                                  passed_filter,
                                  passed_filter_exclude_node_names):
    """Prepare (but not launch) CLI for run-end, with debug dump from the run.

    Args:
      debug_dump: (debug_data.DebugDumpDir) The debug dump directory from this
        run.
      tf_error: (None or OpError) OpError that happened during the run() call
        (if any).
      passed_filter: (None or str) Name of the tensor filter that just passed
        and caused the preparation of this run-end CLI (if any).
      passed_filter_exclude_node_names: (None or str) Regular expression used
        with the tensor filter to exclude ops with names matching the regular
        expression.
    """

    if tf_error:
      help_intro = cli_shared.get_error_intro(tf_error)

      self._init_command = "help"
      self._title_color = "red_on_white"
    else:
      help_intro = None
      self._init_command = "lt"

      self._title_color = "black_on_white"
      if passed_filter is not None:
        # Some dumped tensor(s) from this run passed the filter.
        self._init_command = "lt -f %s" % passed_filter
        if passed_filter_exclude_node_names:
          self._init_command += (" --filter_exclude_node_names %s" %
                                 passed_filter_exclude_node_names)
        self._title_color = "red_on_white"

    self._run_cli = analyzer_cli.create_analyzer_ui(
        debug_dump,
        self._tensor_filters,
        ui_type=self._ui_type,
        on_ui_exit=self._remove_dump_root,
        config=self._config)

    # Get names of all dumped tensors.
    dumped_tensor_names = []
    for datum in debug_dump.dumped_tensor_data:
      dumped_tensor_names.append("%s:%d" %
                                 (datum.node_name, datum.output_slot))

    # Tab completions for command "print_tensors".
    self._run_cli.register_tab_comp_context(["print_tensor", "pt"],
                                            dumped_tensor_names)

    # Tab completion for commands "node_info", "list_inputs" and
    # "list_outputs". The list comprehension is used below because nodes()
    # output can be unicodes and they need to be converted to strs.
    self._run_cli.register_tab_comp_context(
        ["node_info", "ni", "list_inputs", "li", "list_outputs", "lo"],
        [str(node_name) for node_name in debug_dump.nodes()])
    # TODO(cais): Reduce API surface area for aliases vis-a-vis tab
    # completion contexts and registered command handlers.

    self._title = "run-end: " + self._run_description

    if help_intro:
      self._run_cli.set_help_intro(help_intro)

  def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
    self._init_command = "lp"
    self._run_cli = profile_analyzer_cli.create_profiler_ui(
        py_graph, run_metadata, ui_type=self._ui_type,
        config=self._run_cli.config)
    self._title = "run-end (profiler mode): " + self._run_description

  def _launch_cli(self):
    """Launch the interactive command-line interface.

    Returns:
      The OnRunStartResponse specified by the user using the "run" command.
    """

    self._register_this_run_info(self._run_cli)
    response = self._run_cli.run_ui(
        init_command=self._init_command,
        title=self._title,
        title_color=self._title_color)

    return response

  def _run_info_handler(self, args, screen_info=None):
    output = debugger_cli_common.RichTextLines([])

    if self._run_call_count == 1:
      output.extend(cli_shared.get_tfdbg_logo())
      output.extend(debugger_cli_common.get_tensorflow_version_lines())
    output.extend(self._run_info)

    if (not self._is_run_start and
        debugger_cli_common.MAIN_MENU_KEY in output.annotations):
      menu = output.annotations[debugger_cli_common.MAIN_MENU_KEY]
      if "list_tensors" not in menu.captions():
        menu.insert(
            0, debugger_cli_common.MenuItem("list_tensors", "list_tensors"))

    return output

  def _print_feed_handler(self, args, screen_info=None):
    np_printoptions = cli_shared.numpy_printoptions_from_screen_info(
        screen_info)

    if not self._feed_dict:
      return cli_shared.error(
          "The feed_dict of the current run is None or empty.")

    parsed = self._argparsers["print_feed"].parse_args(args)
    tensor_name, tensor_slicing = (
        command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))

    feed_key = None
    feed_value = None
    for key in self._feed_dict:
      key_name = common.get_graph_element_name(key)
      if key_name == tensor_name:
        feed_key = key_name
        feed_value = self._feed_dict[key]
        break

    if feed_key is None:
      return cli_shared.error(
          "The feed_dict of the current run does not contain the key %s" %
          tensor_name)
    else:
      return cli_shared.format_tensor(
          feed_value,
          feed_key + " (feed)",
          np_printoptions,
          print_all=parsed.print_all,
          tensor_slicing=tensor_slicing,
          highlight_options=cli_shared.parse_ranges_highlight(parsed.ranges),
          include_numeric_summary=parsed.numeric_summary)
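
  # Illustrative `print_feed` (alias `pf`) invocations at the run-start CLI;
  # the feed name "input_ph" is hypothetical, and the flag set comes from the
  # print_tensor argparser registered in _initialize_argparsers:
  #
  #   pf input_ph              # Print the value fed for input_ph.
  #   pf input_ph[0, :]        # Print a slice of the fed value.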

  def _run_handler(self, args, screen_info=None):
    """Command handler for "run" command during on-run-start."""

    del screen_info  # Currently unused.

    parsed = self._argparsers["run"].parse_args(args)
    parsed.node_name_filter = parsed.node_name_filter or None
    parsed.op_type_filter = parsed.op_type_filter or None
    parsed.tensor_dtype_filter = parsed.tensor_dtype_filter or None

    if parsed.filter_exclude_node_names and not parsed.till_filter_pass:
      raise ValueError(
          "The --filter_exclude_node_names (or -fenn) flag is valid only if "
          "the --till_filter_pass (or -f) flag is used.")

    if parsed.profile:
      raise debugger_cli_common.CommandLineExit(
          exit_token=framework.OnRunStartResponse(
              framework.OnRunStartAction.PROFILE_RUN, []))

    self._skip_debug = parsed.no_debug
    self._run_through_times = parsed.times

    if parsed.times > 1 or parsed.no_debug:
      # If requested -t times > 1, the very next run will be a non-debug run.
      action = framework.OnRunStartAction.NON_DEBUG_RUN
      debug_urls = []
    else:
      action = framework.OnRunStartAction.DEBUG_RUN
      debug_urls = self._get_run_debug_urls()
    run_start_response = framework.OnRunStartResponse(
        action,
        debug_urls,
        node_name_regex_allowlist=parsed.node_name_filter,
        op_type_regex_allowlist=parsed.op_type_filter,
        tensor_dtype_regex_allowlist=parsed.tensor_dtype_filter)

    if parsed.till_filter_pass:
      # For the run-till-filter-pass (run -f) mode, use the DEBUG_RUN
      # option to access the intermediate tensors, and set the corresponding
      # state flag of the class itself to True.
      if parsed.till_filter_pass in self._tensor_filters:
        action = framework.OnRunStartAction.DEBUG_RUN
        self._active_tensor_filter = parsed.till_filter_pass
        self._active_filter_exclude_node_names = (
            parsed.filter_exclude_node_names)
        self._active_tensor_filter_run_start_response = run_start_response
      else:
        # Handle invalid filter name.
        return debugger_cli_common.RichTextLines(
            ["ERROR: tensor filter \"%s\" does not exist." %
             parsed.till_filter_pass])

    # Raise CommandLineExit exception to cause the CLI to exit.
    raise debugger_cli_common.CommandLineExit(exit_token=run_start_response)

  def _register_this_run_info(self, curses_cli):
    curses_cli.register_command_handler(
        "run",
        self._run_handler,
        self._argparsers["run"].format_help(),
        prefix_aliases=["r"])
    curses_cli.register_command_handler(
        "run_info",
        self._run_info_handler,
        self._argparsers["run_info"].format_help(),
        prefix_aliases=["ri"])
    curses_cli.register_command_handler(
        "print_feed",
        self._print_feed_handler,
        self._argparsers["print_feed"].format_help(),
        prefix_aliases=["pf"])

    if self._tensor_filters:
      # Register tab completion for the filter names.
      curses_cli.register_tab_comp_context(["run", "r"],
                                           list(self._tensor_filters.keys()))
    if self._feed_dict and hasattr(self._feed_dict, "keys"):
      # Register tab completion for feed_dict keys.
      feed_keys = [common.get_graph_element_name(key)
                   for key in self._feed_dict.keys()]
      curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)

  def _get_run_debug_urls(self):
    """Get the debug_urls value for the current run() call.

    Returns:
      debug_urls: (list of str) Debug URLs for the current run() call.
        Currently, the list consists of only one URL that is a file:// URL.
    """

    return ["file://" + self._dump_root]

  def _update_run_calls_state(self,
                              run_call_count,
                              fetches,
                              feed_dict,
                              is_callable_runner=False):
    """Update the internal state with regard to run() call history.

    Args:
      run_call_count: (int) Number of run() calls that have occurred.
      fetches: a node/tensor or a list of node/tensor that are the fetches of
        the run() call. This is the same as the fetches argument to the run()
        call.
      feed_dict: None or a dict. This is the feed_dict argument to the run()
        call.
      is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.
    """

    self._run_call_count = run_call_count
    self._feed_dict = feed_dict
    self._run_description = cli_shared.get_run_short_description(
        run_call_count,
        fetches,
        feed_dict,
        is_callable_runner=is_callable_runner)
    self._run_through_times -= 1

    self._run_info = cli_shared.get_run_start_intro(
        run_call_count,
        fetches,
        feed_dict,
        self._tensor_filters,
        is_callable_runner=is_callable_runner)