Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/cli/profile_analyzer_cli.py: 16%
299 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Formats and displays profiling information."""
17import argparse
18import os
19import re
21import numpy as np
23from tensorflow.python.debug.cli import cli_shared
24from tensorflow.python.debug.cli import command_parser
25from tensorflow.python.debug.cli import debugger_cli_common
26from tensorflow.python.debug.cli import ui_factory
27from tensorflow.python.debug.lib import profiling
28from tensorflow.python.debug.lib import source_utils
# Short alias for the rich-text line type, which this module builds constantly.
RL = debugger_cli_common.RichLine

# Identifiers accepted by `list_profile --sort_by`. They also serve as the
# per-column sort ids of the list_profile table (listed here in column order).
SORT_OPS_BY_OP_NAME = "node"
SORT_OPS_BY_OP_TYPE = "op_type"
SORT_OPS_BY_START_TIME = "start_time"
SORT_OPS_BY_OP_TIME = "op_time"
SORT_OPS_BY_EXEC_TIME = "exec_time"
SORT_OPS_BY_LINE = "line"

# Long flag names (also used as argparse `dest`s) of the regex filters that
# both the `list_profile` and `print_source` commands accept.
_DEVICE_NAME_FILTER_FLAG = "device_name_filter"
_NODE_NAME_FILTER_FLAG = "node_name_filter"
_OP_TYPE_FILTER_FLAG = "op_type_filter"
class ProfileDataTableView(object):
  """Table View of profiling data.

  Each row corresponds to one `ProfileDatum`. The six columns are: node name,
  op type, start time, op time, exec time and source location
  ("Filename:Lineno(function)").
  """

  def __init__(self, profile_datum_list, time_unit=cli_shared.TIME_UNIT_US):
    """Constructor.

    Args:
      profile_datum_list: List of `ProfileDatum` objects.
      time_unit: Unit in which op and exec times are rendered; must be one of
        cli_shared.TIME_UNITS.
    """
    self._profile_datum_list = profile_datum_list

    def _to_readable(micros):
      return cli_shared.time_to_readable_str(micros, force_time_unit=time_unit)

    # Start times are kept as raw values; op/exec times are pre-rendered.
    self.formatted_start_time = [d.start_time for d in profile_datum_list]
    self.formatted_op_time = [_to_readable(d.op_time)
                              for d in profile_datum_list]
    self.formatted_exec_time = [
        _to_readable(d.node_exec_stats.all_end_rel_micros)
        for d in profile_datum_list
    ]

    # (header text, sort id) for every column, in display order.
    column_specs = [
        ("Node", SORT_OPS_BY_OP_NAME),
        ("Op Type", SORT_OPS_BY_OP_TYPE),
        ("Start Time (us)", SORT_OPS_BY_START_TIME),
        ("Op Time (%s)" % time_unit, SORT_OPS_BY_OP_TIME),
        ("Exec Time (%s)" % time_unit, SORT_OPS_BY_EXEC_TIME),
        ("Filename:Lineno(function)", SORT_OPS_BY_LINE),
    ]
    self._column_names = [header for header, _ in column_specs]
    self._column_sort_ids = [sort_id for _, sort_id in column_specs]

  def value(self,
            row,
            col,
            device_name_filter=None,
            node_name_filter=None,
            op_type_filter=None):
    """Get the content of a cell of the table.

    Args:
      row: (int) row index.
      col: (int) column index.
      device_name_filter: Regular expression to filter by device name.
      node_name_filter: Regular expression to filter by node name.
      op_type_filter: Regular expression to filter by op type.

    Returns:
      A debugger_cli_common.RichLine object representing the content of the
      cell. Cells of the source-location column (col 5) carry a clickable
      MenuItem running a `ps` command that forwards the non-empty filters.

    Raises:
      IndexError: if the row or column index is out of range.
    """
    if col == 0:
      return RL(self._profile_datum_list[row].node_exec_stats.node_name,
                font_attr=None)
    if col == 1:
      return RL(self._profile_datum_list[row].op_type, font_attr=None)
    if col == 2:
      return RL(str(self.formatted_start_time[row]), font_attr=None)
    if col == 3:
      return RL(str(self.formatted_op_time[row]), font_attr=None)
    if col == 4:
      return RL(str(self.formatted_exec_time[row]), font_attr=None)
    if col != 5:
      raise IndexError("Invalid column index %d." % col)

    # Source-location column: clicking it opens the file via `ps`, keeping
    # whichever filters are currently active.
    datum = self._profile_datum_list[row]
    flag_values = (
        (_DEVICE_NAME_FILTER_FLAG, device_name_filter),
        (_NODE_NAME_FILTER_FLAG, node_name_filter),
        (_OP_TYPE_FILTER_FLAG, op_type_filter),
    )
    command = "ps" + "".join(
        " --%s %s" % (flag, flag_value)
        for flag, flag_value in flag_values if flag_value)
    command += " %s --init_line %d" % (datum.file_path, datum.line_number)
    source_menu_item = debugger_cli_common.MenuItem(None, command)
    return RL(datum.file_line_func, font_attr=source_menu_item)

  def row_count(self):
    """Number of data rows (one per `ProfileDatum`)."""
    return len(self._profile_datum_list)

  def column_count(self):
    """Number of columns in the table."""
    return len(self._column_names)

  def column_names(self):
    """Header texts of the columns, in display order."""
    return self._column_names

  def column_sort_id(self, col):
    """SORT_OPS_BY_* identifier associated with column `col`."""
    return self._column_sort_ids[col]
142def _list_profile_filter(
143 profile_datum,
144 node_name_regex,
145 file_path_regex,
146 op_type_regex,
147 op_time_interval,
148 exec_time_interval,
149 min_lineno=-1,
150 max_lineno=-1):
151 """Filter function for list_profile command.
153 Args:
154 profile_datum: A `ProfileDatum` object.
155 node_name_regex: Regular expression pattern object to filter by name.
156 file_path_regex: Regular expression pattern object to filter by file path.
157 op_type_regex: Regular expression pattern object to filter by op type.
158 op_time_interval: `Interval` for filtering op time.
159 exec_time_interval: `Interval` for filtering exec time.
160 min_lineno: Lower bound for 1-based line number, inclusive.
161 If <= 0, has no effect.
162 max_lineno: Upper bound for 1-based line number, exclusive.
163 If <= 0, has no effect.
164 # TODO(cais): Maybe filter by function name.
166 Returns:
167 True iff profile_datum should be included.
168 """
169 if node_name_regex and not node_name_regex.match(
170 profile_datum.node_exec_stats.node_name):
171 return False
172 if file_path_regex:
173 if (not profile_datum.file_path or
174 not file_path_regex.match(profile_datum.file_path)):
175 return False
176 if (min_lineno > 0 and profile_datum.line_number and
177 profile_datum.line_number < min_lineno):
178 return False
179 if (max_lineno > 0 and profile_datum.line_number and
180 profile_datum.line_number >= max_lineno):
181 return False
182 if (profile_datum.op_type is not None and op_type_regex and
183 not op_type_regex.match(profile_datum.op_type)):
184 return False
185 if op_time_interval is not None and not op_time_interval.contains(
186 profile_datum.op_time):
187 return False
188 if exec_time_interval and not exec_time_interval.contains(
189 profile_datum.node_exec_stats.all_end_rel_micros):
190 return False
191 return True
def _list_profile_sort_key(profile_datum, sort_by):
  """Get a profile_datum property to sort by in list_profile command.

  Args:
    profile_datum: A `ProfileDatum` object.
    sort_by: (string) indicates a value to sort by; expected to be one of the
      SORT_OPS_BY_* constants. Any value not matched below (including
      SORT_OPS_BY_START_TIME) sorts by start time.

  Returns:
    profile_datum property to sort by.
  """
  # Checked in order with `==`, falling back to the start time.
  key_getters = (
      (SORT_OPS_BY_OP_NAME, lambda d: d.node_exec_stats.node_name),
      (SORT_OPS_BY_OP_TYPE, lambda d: d.op_type),
      (SORT_OPS_BY_LINE, lambda d: d.file_line_func),
      (SORT_OPS_BY_OP_TIME, lambda d: d.op_time),
      (SORT_OPS_BY_EXEC_TIME,
       lambda d: d.node_exec_stats.all_end_rel_micros),
  )
  for sort_id, getter in key_getters:
    if sort_by == sort_id:
      return getter(profile_datum)
  return profile_datum.node_exec_stats.all_start_micros
class ProfileAnalyzer(object):
  """Analyzer for profiling data.

  Provides the `list_profile` and `print_source` command handlers. They
  combine the per-device step stats of a `RunMetadata` proto with the Python
  creation tracebacks of the ops in a `tf.Graph`.
  """

  def __init__(self, graph, run_metadata):
    """ProfileAnalyzer constructor.

    Args:
      graph: (tf.Graph) Python graph object.
      run_metadata: A `RunMetadata` protobuf object.

    Raises:
      ValueError: If run_metadata is None (or otherwise falsy).
    """
    self._graph = graph
    if not run_metadata:
      raise ValueError("No RunMetadata passed for profile analysis.")
    self._run_metadata = run_metadata
    self._arg_parsers = {
        "list_profile": self._create_list_profile_parser(),
        "print_source": self._create_print_source_parser(),
    }

  @staticmethod
  def _create_list_profile_parser():
    """Build the argument parser of the `list_profile` command."""
    ap = argparse.ArgumentParser(
        description="List nodes profile information.",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "-d",
        "--%s" % _DEVICE_NAME_FILTER_FLAG,
        dest=_DEVICE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="filter device name by regex.")
    ap.add_argument(
        "-n",
        "--%s" % _NODE_NAME_FILTER_FLAG,
        dest=_NODE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="filter node name by regex.")
    ap.add_argument(
        "-t",
        "--%s" % _OP_TYPE_FILTER_FLAG,
        dest=_OP_TYPE_FILTER_FLAG,
        type=str,
        default="",
        help="filter op type by regex.")
    # TODO(annarev): allow file filtering at non-stack top position.
    ap.add_argument(
        "-f",
        "--file_path_filter",
        dest="file_path_filter",
        type=str,
        default="",
        help="filter by file name at the top position of node's creation "
        "stack that does not belong to TensorFlow library.")
    ap.add_argument(
        "--min_lineno",
        dest="min_lineno",
        type=int,
        default=-1,
        help="(Inclusive) lower bound for 1-based line number in source file. "
        "If <= 0, has no effect.")
    ap.add_argument(
        "--max_lineno",
        dest="max_lineno",
        type=int,
        default=-1,
        help="(Exclusive) upper bound for 1-based line number in source file. "
        "If <= 0, has no effect.")
    ap.add_argument(
        "-e",
        "--execution_time",
        dest="execution_time",
        type=str,
        default="",
        help="Filter by execution time interval "
        "(includes compute plus pre- and post-processing time). "
        "Supported units are s, ms and us (default). "
        "E.g. -e >100s, -e <100, -e [100us,1000ms]")
    ap.add_argument(
        "-o",
        "--op_time",
        dest="op_time",
        type=str,
        default="",
        help="Filter by op time interval (only includes compute time). "
        "Supported units are s, ms and us (default). "
        # The examples must use this option's own flag (-o), not -e.
        "E.g. -o >100s, -o <100, -o [100us,1000ms]")
    ap.add_argument(
        "-s",
        "--sort_by",
        dest="sort_by",
        type=str,
        default=SORT_OPS_BY_START_TIME,
        help=("the field to sort the data by: (%s)" %
              " | ".join([SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
                          SORT_OPS_BY_START_TIME, SORT_OPS_BY_OP_TIME,
                          SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE])))
    ap.add_argument(
        "-r",
        "--reverse",
        dest="reverse",
        action="store_true",
        help="sort the data in reverse (descending) order")
    ap.add_argument(
        "--time_unit",
        dest="time_unit",
        type=str,
        default=cli_shared.TIME_UNIT_US,
        help="Time unit (" + " | ".join(cli_shared.TIME_UNITS) + ")")
    return ap

  @staticmethod
  def _create_print_source_parser():
    """Build the argument parser of the `print_source` command."""
    ap = argparse.ArgumentParser(
        description="Print a Python source file with line-level profile "
        "information",
        usage=argparse.SUPPRESS)
    ap.add_argument(
        "source_file_path",
        type=str,
        help="Path to the source_file_path")
    ap.add_argument(
        "--cost_type",
        type=str,
        choices=["exec_time", "op_time"],
        default="exec_time",
        help="Type of cost to display")
    ap.add_argument(
        "--time_unit",
        dest="time_unit",
        type=str,
        default=cli_shared.TIME_UNIT_US,
        help="Time unit (" + " | ".join(cli_shared.TIME_UNITS) + ")")
    ap.add_argument(
        "-d",
        "--%s" % _DEVICE_NAME_FILTER_FLAG,
        dest=_DEVICE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="Filter device name by regex.")
    ap.add_argument(
        "-n",
        "--%s" % _NODE_NAME_FILTER_FLAG,
        dest=_NODE_NAME_FILTER_FLAG,
        type=str,
        default="",
        help="Filter node name by regex.")
    ap.add_argument(
        "-t",
        "--%s" % _OP_TYPE_FILTER_FLAG,
        dest=_OP_TYPE_FILTER_FLAG,
        type=str,
        default="",
        help="Filter op type by regex.")
    ap.add_argument(
        "--init_line",
        dest="init_line",
        type=int,
        default=0,
        help="The 1-based line number to scroll to initially.")
    return ap

  def list_profile(self, args, screen_info=None):
    """Command handler for list_profile.

    List per-operation profile information, one table per device whose name
    passes the device-name filter.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols.

    Returns:
      Output text lines as a RichTextLines object.
    """
    screen_cols = 80
    if screen_info and "cols" in screen_info:
      screen_cols = screen_info["cols"]

    parsed = self._arg_parsers["list_profile"].parse_args(args)
    op_time_interval = (command_parser.parse_time_interval(parsed.op_time)
                        if parsed.op_time else None)
    exec_time_interval = (
        command_parser.parse_time_interval(parsed.execution_time)
        if parsed.execution_time else None)
    node_name_regex = (re.compile(parsed.node_name_filter)
                       if parsed.node_name_filter else None)
    file_path_regex = (re.compile(parsed.file_path_filter)
                       if parsed.file_path_filter else None)
    op_type_regex = (re.compile(parsed.op_type_filter)
                     if parsed.op_type_filter else None)
    device_name_regex = (re.compile(parsed.device_name_filter)
                         if parsed.device_name_filter else None)

    def sort_key(datum):
      return _list_profile_sort_key(datum, parsed.sort_by)

    output = debugger_cli_common.RichTextLines([""])
    data_generator = self._get_profile_data_generator()
    all_device_stats = self._run_metadata.step_stats.dev_stats
    device_count = len(all_device_stats)
    for index, device_stats in enumerate(all_device_stats):
      if device_name_regex and not device_name_regex.match(device_stats.device):
        continue
      profile_data = sorted(
          (datum for datum in data_generator(device_stats)
           if _list_profile_filter(
               datum, node_name_regex, file_path_regex, op_type_regex,
               op_time_interval, exec_time_interval,
               min_lineno=parsed.min_lineno, max_lineno=parsed.max_lineno)),
          key=sort_key,
          reverse=parsed.reverse)
      output.extend(
          self._get_list_profile_lines(
              device_stats.device, index, device_count,
              profile_data, parsed.sort_by, parsed.reverse, parsed.time_unit,
              device_name_filter=parsed.device_name_filter,
              node_name_filter=parsed.node_name_filter,
              op_type_filter=parsed.op_type_filter,
              screen_cols=screen_cols))
    return output

  def _get_profile_data_generator(self):
    """Get function that generates `ProfileDatum` objects.

    Each op is attributed to a source location taken from its creation
    traceback. Frames are walked from the last entry backwards, and the first
    one that source_utils does not classify as TensorFlow library code wins.
    If every frame is library code, the first traceback entry is used. Ops
    with an empty traceback get no source location, so the generator falls
    back to "" / 0 / "".

    Returns:
      A function that takes one device's step stats and generates
      `ProfileDatum` objects for its nodes.
    """
    node_to_file_path = {}
    node_to_line_number = {}
    node_to_func_name = {}
    node_to_op_type = {}
    for op in self._graph.get_operations():
      node_to_op_type[op.name] = op.type
      creation_frame = None
      for trace_entry in reversed(op.traceback):
        creation_frame = trace_entry
        if not source_utils.guess_is_tensorflow_py_library(trace_entry[0]):
          break
      if creation_frame is None:
        # Empty traceback: previously the loop variables were read anyway,
        # raising UnboundLocalError or reusing the previous op's location.
        continue
      node_to_file_path[op.name] = creation_frame[0]
      node_to_line_number[op.name] = creation_frame[1]
      node_to_func_name[op.name] = creation_frame[2]

    def profile_data_generator(device_step_stats):
      """Yield a `ProfileDatum` per node of one device's step stats."""
      for node_stats in device_step_stats.node_stats:
        node_name = node_stats.node_name
        if node_name in ("_SOURCE", "_SINK"):
          continue
        yield profiling.ProfileDatum(
            device_step_stats.device,
            node_stats,
            node_to_file_path.get(node_name, ""),
            node_to_line_number.get(node_name, 0),
            node_to_func_name.get(node_name, ""),
            node_to_op_type.get(node_name, ""))
    return profile_data_generator

  def _get_list_profile_lines(
      self, device_name, device_index, device_count,
      profile_datum_list, sort_by, sort_reverse, time_unit,
      device_name_filter=None, node_name_filter=None, op_type_filter=None,
      screen_cols=80):
    """Get `RichTextLines` object for list_profile command for a given device.

    Args:
      device_name: (string) Device name.
      device_index: (int) Device index.
      device_count: (int) Number of devices.
      profile_datum_list: List of `ProfileDatum` objects.
      sort_by: (string) Sort identifier the data is currently sorted by (one
        of the SORT_OPS_BY_* constants). The matching column header links to
        the reverse sort.
      sort_reverse: (bool) Whether to sort in descending instead of default
        (ascending) order.
      time_unit: time unit, must be in cli_shared.TIME_UNITS.
      device_name_filter: Regular expression to filter by device name.
      node_name_filter: Regular expression to filter by node name.
      op_type_filter: Regular expression to filter by op type.
      screen_cols: (int) Number of columns available on the screen (i.e.,
        available screen width).

    Returns:
      `RichTextLines` object containing a table that displays profiling
      information for each op.
    """
    profile_data = ProfileDataTableView(profile_datum_list, time_unit=time_unit)
    num_columns = profile_data.column_count()

    # Render each cell once; used both to size the columns and to draw rows.
    cells = [
        [profile_data.value(row, col,
                            device_name_filter=device_name_filter,
                            node_name_filter=node_name_filter,
                            op_type_filter=op_type_filter)
         for col in range(num_columns)]
        for row in range(profile_data.row_count())]

    # Totals row, aligned with the table columns Node, Op Type, Start Time,
    # Op Time, Exec Time. The Start Time cell is left blank; previously it
    # was missing, which shifted both totals one column to the left.
    total_op_time = sum(datum.op_time for datum in profile_datum_list)
    total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
                          for datum in profile_datum_list)
    device_total_row = [
        "Device Total", "", "",
        cli_shared.time_to_readable_str(total_op_time,
                                        force_time_unit=time_unit),
        cli_shared.time_to_readable_str(total_exec_time,
                                        force_time_unit=time_unit)]

    # Column width = widest of header, total cell and data cells, + margin.
    column_widths = [len(name) for name in profile_data.column_names()]
    for col, total_cell in enumerate(device_total_row):
      column_widths[col] = max(column_widths[col], len(total_cell))
    for row_cells in cells:
      for col, cell in enumerate(row_cells):
        column_widths[col] = max(column_widths[col], len(cell))
    column_widths = [width + 2 for width in column_widths]

    # Device banner.
    output = [RL("-" * screen_cols)]
    output.append(RL("Device %d of %d: %s" % (
        device_index + 1, device_count, device_name)))
    output.append(RL())

    # Headers: clicking one sorts by that column. If the column is already
    # the ascending sort key, the link requests reverse order instead.
    base_command = "list_profile"
    header = RL()
    for col, column_name in enumerate(profile_data.column_names()):
      sort_id = profile_data.column_sort_id(col)
      command = "%s -s %s" % (base_command, sort_id)
      if sort_by == sort_id and not sort_reverse:
        command += " -r"
      head_menu_item = debugger_cli_common.MenuItem(None, command)
      header += RL(column_name, font_attr=[head_menu_item, "bold"])
      header += RL(" " * (column_widths[col] - len(column_name)))
    output.append(header)

    # Data rows.
    for row_cells in cells:
      new_row = RL()
      for col, cell in enumerate(row_cells):
        new_row += cell
        new_row += RL(" " * (column_widths[col] - len(cell)))
      output.append(new_row)

    # Stat totals.
    totals_text = "".join(
        ("{:<%d}" % width).format(total_cell)
        for width, total_cell in zip(column_widths, device_total_row))
    output.append(RL())
    output.append(RL(totals_text))
    return debugger_cli_common.rich_text_lines_from_rich_line_list(output)

  def _measure_list_profile_column_widths(self, profile_data):
    """Determine the maximum column widths for each data list.

    Args:
      profile_data: A `ProfileDataTableView`.

    Returns:
      List of column widths in the same order as columns in data.
    """
    num_columns = len(profile_data.column_names())
    widths = [len(column_name) for column_name in profile_data.column_names()]
    for row in range(profile_data.row_count()):
      for col in range(num_columns):
        # Previously called the nonexistent `row_values`; use the table's
        # cell accessor instead.
        widths[col] = max(widths[col], len(profile_data.value(row, col)) + 2)
    return widths

  _LINE_COST_ATTR = cli_shared.COLOR_CYAN
  _LINE_NUM_ATTR = cli_shared.COLOR_YELLOW
  _NUM_NODES_HEAD = "#nodes"
  _NUM_EXECS_SUB_HEAD = "(#execs)"
  _LINENO_HEAD = "lineno"
  _SOURCE_HEAD = "source"

  def print_source(self, args, screen_info=None):
    """Print a Python source file with line-level profile information.

    Args:
      args: Command-line arguments, excluding the command prefix, as a list of
        str.
      screen_info: Optional dict input containing screen information such as
        cols. Unused.

    Returns:
      Output text lines as a RichTextLines object.
    """
    del screen_info

    parsed = self._arg_parsers["print_source"].parse_args(args)
    # Expand "~" once, so that annotation, source loading and the per-line
    # `lp` commands all refer to the same file. Previously only the
    # annotation step was expanded.
    source_file_path = os.path.expanduser(parsed.source_file_path)

    device_name_regex = (re.compile(parsed.device_name_filter)
                         if parsed.device_name_filter else None)

    profile_data = []
    data_generator = self._get_profile_data_generator()
    for device_stats in self._run_metadata.step_stats.dev_stats:
      if device_name_regex and not device_name_regex.match(device_stats.device):
        continue
      profile_data.extend(data_generator(device_stats))

    source_annotation = source_utils.annotate_source_against_profile(
        profile_data,
        source_file_path,
        node_name_filter=parsed.node_name_filter,
        op_type_filter=parsed.op_type_filter)
    if not source_annotation:
      return debugger_cli_common.RichTextLines(
          ["The source file %s does not contain any profile information for "
           "the previous Session run under the following "
           "filters:" % parsed.source_file_path,
           " --%s: %s" % (_DEVICE_NAME_FILTER_FLAG, parsed.device_name_filter),
           " --%s: %s" % (_NODE_NAME_FILTER_FLAG, parsed.node_name_filter),
           " --%s: %s" % (_OP_TYPE_FILTER_FLAG, parsed.op_type_filter)])

    max_total_cost = 0
    for line_index in source_annotation:
      max_total_cost = max(
          max_total_cost,
          self._get_total_cost(source_annotation[line_index],
                               parsed.cost_type))

    source_lines, line_num_width = source_utils.load_source(source_file_path)

    cost_bar_max_length = 10
    total_cost_head = parsed.cost_type
    column_widths = {
        "cost_bar": cost_bar_max_length + 3,
        "total_cost": len(total_cost_head) + 3,
        "num_nodes_execs": len(self._NUM_EXECS_SUB_HEAD) + 1,
        "line_number": line_num_width,
    }

    head = RL(
        " " * column_widths["cost_bar"] +
        total_cost_head +
        " " * (column_widths["total_cost"] - len(total_cost_head)) +
        self._NUM_NODES_HEAD +
        " " * (column_widths["num_nodes_execs"] - len(self._NUM_NODES_HEAD)),
        font_attr=self._LINE_COST_ATTR)
    head += RL(self._LINENO_HEAD, font_attr=self._LINE_NUM_ATTR)
    sub_head = RL(
        " " * (column_widths["cost_bar"] +
               column_widths["total_cost"]) +
        self._NUM_EXECS_SUB_HEAD +
        " " * (column_widths["num_nodes_execs"] -
               len(self._NUM_EXECS_SUB_HEAD)) +
        " " * column_widths["line_number"],
        font_attr=self._LINE_COST_ATTR)
    sub_head += RL(self._SOURCE_HEAD, font_attr="bold")
    lines = [head, sub_head]

    # Loop-invariant: regex pinning the `lp` drill-down to this file.
    file_path_filter = re.escape(source_file_path) + "$"
    blank_annotation_width = sum(
        width for col_name, width in column_widths.items()
        if col_name != "line_number")

    output_annotations = {}
    for lineno, line in enumerate(source_lines, start=1):
      if lineno in source_annotation:
        annotation = source_annotation[lineno]
        line_cost = self._get_total_cost(annotation, parsed.cost_type)
        cost_bar = self._render_normalized_cost_bar(
            line_cost, max_total_cost, cost_bar_max_length)
        annotated_line = cost_bar
        annotated_line += " " * (column_widths["cost_bar"] - len(cost_bar))

        total_cost = RL(cli_shared.time_to_readable_str(
            line_cost, force_time_unit=parsed.time_unit),
                        font_attr=self._LINE_COST_ATTR)
        total_cost += " " * (column_widths["total_cost"] - len(total_cost))
        annotated_line += total_cost

        # The node/exec count links to `lp` restricted to this line, keeping
        # the active filters.
        command = "lp --file_path_filter %s --min_lineno %d --max_lineno %d" % (
            file_path_filter, lineno, lineno + 1)
        if parsed.device_name_filter:
          command += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                   parsed.device_name_filter)
        if parsed.node_name_filter:
          command += " --%s %s" % (_NODE_NAME_FILTER_FLAG,
                                   parsed.node_name_filter)
        if parsed.op_type_filter:
          command += " --%s %s" % (_OP_TYPE_FILTER_FLAG,
                                   parsed.op_type_filter)
        menu_item = debugger_cli_common.MenuItem(None, command)
        num_nodes_execs = RL("%d(%d)" % (annotation.node_count,
                                         annotation.node_exec_count),
                             font_attr=[self._LINE_COST_ATTR, menu_item])
        num_nodes_execs += " " * (
            column_widths["num_nodes_execs"] - len(num_nodes_execs))
        annotated_line += num_nodes_execs
      else:
        annotated_line = RL(" " * blank_annotation_width)

      line_num_column = RL(" L%d" % (lineno), self._LINE_NUM_ATTR)
      line_num_column += " " * (
          column_widths["line_number"] - len(line_num_column))
      annotated_line += line_num_column
      annotated_line += line
      lines.append(annotated_line)

      if parsed.init_line == lineno:
        output_annotations[
            debugger_cli_common.INIT_SCROLL_POS_KEY] = len(lines) - 1

    return debugger_cli_common.rich_text_lines_from_rich_line_list(
        lines, annotations=output_annotations)

  def _get_total_cost(self, aggregated_profile, cost_type):
    """Select the total cost of `aggregated_profile` named by `cost_type`.

    Raises:
      ValueError: If cost_type is neither "exec_time" nor "op_time".
    """
    if cost_type == "exec_time":
      return aggregated_profile.total_exec_time
    elif cost_type == "op_time":
      return aggregated_profile.total_op_time
    else:
      raise ValueError("Unsupported cost type: %s" % cost_type)

  def _render_normalized_cost_bar(self, cost, max_cost, length):
    """Render a text bar representing a normalized cost.

    Args:
      cost: the absolute value of the cost.
      max_cost: the maximum cost value to normalize the absolute cost with.
        If it is not positive (e.g. every cost is zero), the minimum bar is
        drawn instead of dividing by zero.
      length: (int) length of the cost bar, in number of characters, excluding
        the brackets on the two ends.

    Returns:
      An instance of debugger_cli_common.RichLine.
    """
    if max_cost > 0:
      num_ticks = int(np.ceil(float(cost) / max_cost * length))
    else:
      num_ticks = 0
    num_ticks = num_ticks or 1  # Minimum is 1 tick.
    output = RL("[", font_attr=self._LINE_COST_ATTR)
    output += RL("|" * num_ticks + " " * (length - num_ticks),
                 font_attr=["bold", self._LINE_COST_ATTR])
    output += RL("]", font_attr=self._LINE_COST_ATTR)
    return output

  def get_help(self, handler_name):
    """Return the argparse help text of command `handler_name`."""
    return self._arg_parsers[handler_name].format_help()
def create_profiler_ui(graph,
                       run_metadata,
                       ui_type="curses",
                       on_ui_exit=None,
                       config=None):
  """Create a debugger UI serving profile commands for a graph and RunMetadata.

  The UI is obtained from `ui_factory.get_ui`. It has `list_profile` (prefix
  alias `lp`) and `print_source` (prefix alias `ps`) registered, both backed
  by a `ProfileAnalyzer`.

  Args:
    graph: Python `Graph` object.
    run_metadata: A `RunMetadata` protobuf object.
    ui_type: (str) requested UI type, e.g., "curses", "readline".
    on_ui_exit: (`Callable`) the callback to be called when the UI exits.
    config: An instance of `cli_config.CLIConfig`. Currently unused.

  Returns:
    (base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer
      commands and tab-completions registered.
  """
  del config  # Currently unused.

  analyzer = ProfileAnalyzer(graph, run_metadata)
  cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit)

  registrations = (
      ("list_profile", analyzer.list_profile, ["lp"]),
      ("print_source", analyzer.print_source, ["ps"]),
  )
  for command_name, handler, aliases in registrations:
    cli.register_command_handler(
        command_name,
        handler,
        analyzer.get_help(command_name),
        prefix_aliases=aliases)

  return cli