Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/cli/profile_analyzer_cli.py: 16%

299 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Formats and displays profiling information.""" 

16 

17import argparse 

18import os 

19import re 

20 

21import numpy as np 

22 

23from tensorflow.python.debug.cli import cli_shared 

24from tensorflow.python.debug.cli import command_parser 

25from tensorflow.python.debug.cli import debugger_cli_common 

26from tensorflow.python.debug.cli import ui_factory 

27from tensorflow.python.debug.lib import profiling 

28from tensorflow.python.debug.lib import source_utils 

29 

# Shorthand for the rich-text line type used to build the CLI output below.
RL = debugger_cli_common.RichLine

# Sort identifiers: values accepted by the `-s/--sort_by` flag of list_profile
# and attached to the table column headers as their sort ids.
SORT_OPS_BY_OP_NAME = "node"
SORT_OPS_BY_OP_TYPE = "op_type"
SORT_OPS_BY_OP_TIME = "op_time"
SORT_OPS_BY_EXEC_TIME = "exec_time"
SORT_OPS_BY_START_TIME = "start_time"
SORT_OPS_BY_LINE = "line"

# Names of the filter flags shared by the list_profile and print_source
# argument parsers; also used when rebuilding commands for clickable menu items.
_DEVICE_NAME_FILTER_FLAG = "device_name_filter"
_NODE_NAME_FILTER_FLAG = "node_name_filter"
_OP_TYPE_FILTER_FLAG = "op_type_filter"

42 

43 

class ProfileDataTableView(object):
  """Presents a list of `ProfileDatum` objects as a table of rich-text cells."""

  def __init__(self, profile_datum_list, time_unit=cli_shared.TIME_UNIT_US):
    """Constructor.

    Args:
      profile_datum_list: List of `ProfileDatum` objects, one per table row.
      time_unit: Unit in which op and exec times are rendered; must be in
        cli_shared.TIME_UNITS.
    """

    def _readable(micros):
      return cli_shared.time_to_readable_str(micros, force_time_unit=time_unit)

    self._profile_datum_list = profile_datum_list

    # The time columns are rendered once up front; `value` only looks them up.
    self.formatted_start_time = [d.start_time for d in profile_datum_list]
    self.formatted_op_time = [_readable(d.op_time) for d in profile_datum_list]
    self.formatted_exec_time = [
        _readable(d.node_exec_stats.all_end_rel_micros)
        for d in profile_datum_list]

    # (header text, sort identifier) for each column, in display order.
    columns = [
        ("Node", SORT_OPS_BY_OP_NAME),
        ("Op Type", SORT_OPS_BY_OP_TYPE),
        ("Start Time (us)", SORT_OPS_BY_START_TIME),
        ("Op Time (%s)" % time_unit, SORT_OPS_BY_OP_TIME),
        ("Exec Time (%s)" % time_unit, SORT_OPS_BY_EXEC_TIME),
        ("Filename:Lineno(function)", SORT_OPS_BY_LINE),
    ]
    self._column_names = [header for header, _ in columns]
    self._column_sort_ids = [sort_id for _, sort_id in columns]

  def value(self,
            row,
            col,
            device_name_filter=None,
            node_name_filter=None,
            op_type_filter=None):
    """Get the content of a cell of the table.

    Args:
      row: (int) row index.
      col: (int) column index.
      device_name_filter: Regular expression to filter by device name.
      node_name_filter: Regular expression to filter by node name.
      op_type_filter: Regular expression to filter by op type.

    Returns:
      A debugger_cli_common.RichLine object representing the content of the
      cell. Only the last column (source location) carries a clickable
      MenuItem, which issues a "ps" command for the row's source file.

    Raises:
      IndexError: if row index is out of range or col is not a valid column.
    """
    if col == 5:
      # Source-location cell: clicking it opens the file at the right line,
      # carrying over whichever filters are currently active.
      datum = self._profile_datum_list[row]
      pieces = ["ps"]
      active_filters = (
          (_DEVICE_NAME_FILTER_FLAG, device_name_filter),
          (_NODE_NAME_FILTER_FLAG, node_name_filter),
          (_OP_TYPE_FILTER_FLAG, op_type_filter),
      )
      for flag, pattern in active_filters:
        if pattern:
          pieces.append(" --%s %s" % (flag, pattern))
      pieces.append(
          " %s --init_line %d" % (datum.file_path, datum.line_number))
      menu_item = debugger_cli_common.MenuItem(None, "".join(pieces))
      return RL(datum.file_line_func, font_attr=menu_item)

    if col == 0:
      text = self._profile_datum_list[row].node_exec_stats.node_name
    elif col == 1:
      text = self._profile_datum_list[row].op_type
    elif col == 2:
      text = str(self.formatted_start_time[row])
    elif col == 3:
      text = str(self.formatted_op_time[row])
    elif col == 4:
      text = str(self.formatted_exec_time[row])
    else:
      raise IndexError("Invalid column index %d." % col)

    # Plain cells carry no menu item.
    return RL(text, font_attr=None)

  def row_count(self):
    """Returns the number of data rows (one per profile datum)."""
    return len(self._profile_datum_list)

  def column_count(self):
    """Returns the number of columns in the table."""
    return len(self._column_names)

  def column_names(self):
    """Returns the list of column header strings, in display order."""
    return self._column_names

  def column_sort_id(self, col):
    """Returns the sort identifier associated with column index `col`."""
    return self._column_sort_ids[col]

140 

141 

142def _list_profile_filter( 

143 profile_datum, 

144 node_name_regex, 

145 file_path_regex, 

146 op_type_regex, 

147 op_time_interval, 

148 exec_time_interval, 

149 min_lineno=-1, 

150 max_lineno=-1): 

151 """Filter function for list_profile command. 

152 

153 Args: 

154 profile_datum: A `ProfileDatum` object. 

155 node_name_regex: Regular expression pattern object to filter by name. 

156 file_path_regex: Regular expression pattern object to filter by file path. 

157 op_type_regex: Regular expression pattern object to filter by op type. 

158 op_time_interval: `Interval` for filtering op time. 

159 exec_time_interval: `Interval` for filtering exec time. 

160 min_lineno: Lower bound for 1-based line number, inclusive. 

161 If <= 0, has no effect. 

162 max_lineno: Upper bound for 1-based line number, exclusive. 

163 If <= 0, has no effect. 

164 # TODO(cais): Maybe filter by function name. 

165 

166 Returns: 

167 True iff profile_datum should be included. 

168 """ 

169 if node_name_regex and not node_name_regex.match( 

170 profile_datum.node_exec_stats.node_name): 

171 return False 

172 if file_path_regex: 

173 if (not profile_datum.file_path or 

174 not file_path_regex.match(profile_datum.file_path)): 

175 return False 

176 if (min_lineno > 0 and profile_datum.line_number and 

177 profile_datum.line_number < min_lineno): 

178 return False 

179 if (max_lineno > 0 and profile_datum.line_number and 

180 profile_datum.line_number >= max_lineno): 

181 return False 

182 if (profile_datum.op_type is not None and op_type_regex and 

183 not op_type_regex.match(profile_datum.op_type)): 

184 return False 

185 if op_time_interval is not None and not op_time_interval.contains( 

186 profile_datum.op_time): 

187 return False 

188 if exec_time_interval and not exec_time_interval.contains( 

189 profile_datum.node_exec_stats.all_end_rel_micros): 

190 return False 

191 return True 

192 

193 

def _list_profile_sort_key(profile_datum, sort_by):
  """Pick the property of a profile datum that list_profile sorts on.

  Args:
    profile_datum: A `ProfileDatum` object.
    sort_by: (string) indicates a value to sort by.
      Must be one of SORT_BY* constants.

  Returns:
    profile_datum property to sort by. Any value of `sort_by` other than the
    name, op type, line, op time and exec time identifiers falls back to the
    node's start time.
  """
  # (sort identifier, accessor) pairs, checked in order.
  accessors = (
      (SORT_OPS_BY_OP_NAME, lambda datum: datum.node_exec_stats.node_name),
      (SORT_OPS_BY_OP_TYPE, lambda datum: datum.op_type),
      (SORT_OPS_BY_LINE, lambda datum: datum.file_line_func),
      (SORT_OPS_BY_OP_TIME, lambda datum: datum.op_time),
      (SORT_OPS_BY_EXEC_TIME,
       lambda datum: datum.node_exec_stats.all_end_rel_micros),
  )
  for sort_id, accessor in accessors:
    if sort_by == sort_id:
      return accessor(profile_datum)

  # Default: sort by start time.
  return profile_datum.node_exec_stats.all_start_micros

217 

218 

219class ProfileAnalyzer(object): 

220 """Analyzer for profiling data.""" 

221 

def __init__(self, graph, run_metadata):
  """ProfileAnalyzer constructor.

  Stores the graph and run metadata and builds the argument parsers for the
  `list_profile` and `print_source` commands, kept in `self._arg_parsers`
  under those command names.

  Args:
    graph: (tf.Graph) Python graph object.
    run_metadata: A `RunMetadata` protobuf object.

  Raises:
    ValueError: If run_metadata is None (or otherwise falsy).
  """
  self._graph = graph
  if not run_metadata:
    raise ValueError("No RunMetadata passed for profile analysis.")
  self._run_metadata = run_metadata
  self._arg_parsers = {}

  # Parser for the list_profile command.
  ap = argparse.ArgumentParser(
      description="List nodes profile information.",
      usage=argparse.SUPPRESS)
  ap.add_argument(
      "-d",
      "--%s" % _DEVICE_NAME_FILTER_FLAG,
      dest=_DEVICE_NAME_FILTER_FLAG,
      type=str,
      default="",
      help="filter device name by regex.")
  ap.add_argument(
      "-n",
      "--%s" % _NODE_NAME_FILTER_FLAG,
      dest=_NODE_NAME_FILTER_FLAG,
      type=str,
      default="",
      help="filter node name by regex.")
  ap.add_argument(
      "-t",
      "--%s" % _OP_TYPE_FILTER_FLAG,
      dest=_OP_TYPE_FILTER_FLAG,
      type=str,
      default="",
      help="filter op type by regex.")
  # TODO(annarev): allow file filtering at non-stack top position.
  ap.add_argument(
      "-f",
      "--file_path_filter",
      dest="file_path_filter",
      type=str,
      default="",
      help="filter by file name at the top position of node's creation "
      "stack that does not belong to TensorFlow library.")
  ap.add_argument(
      "--min_lineno",
      dest="min_lineno",
      type=int,
      default=-1,
      help="(Inclusive) lower bound for 1-based line number in source file. "
      "If <= 0, has no effect.")
  ap.add_argument(
      "--max_lineno",
      dest="max_lineno",
      type=int,
      default=-1,
      help="(Exclusive) upper bound for 1-based line number in source file. "
      "If <= 0, has no effect.")
  ap.add_argument(
      "-e",
      "--execution_time",
      dest="execution_time",
      type=str,
      default="",
      help="Filter by execution time interval "
      "(includes compute plus pre- and post -processing time). "
      "Supported units are s, ms and us (default). "
      "E.g. -e >100s, -e <100, -e [100us,1000ms]")
  # The examples below use this option's own short flag (-o); they previously
  # showed -e, which is the execution-time filter.
  ap.add_argument(
      "-o",
      "--op_time",
      dest="op_time",
      type=str,
      default="",
      help="Filter by op time interval (only includes compute time). "
      "Supported units are s, ms and us (default). "
      "E.g. -o >100s, -o <100, -o [100us,1000ms]")
  ap.add_argument(
      "-s",
      "--sort_by",
      dest="sort_by",
      type=str,
      default=SORT_OPS_BY_START_TIME,
      help=("the field to sort the data by: (%s)" %
            " | ".join([SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
                        SORT_OPS_BY_START_TIME, SORT_OPS_BY_OP_TIME,
                        SORT_OPS_BY_EXEC_TIME, SORT_OPS_BY_LINE])))
  ap.add_argument(
      "-r",
      "--reverse",
      dest="reverse",
      action="store_true",
      help="sort the data in reverse (descending) order")
  ap.add_argument(
      "--time_unit",
      dest="time_unit",
      type=str,
      default=cli_shared.TIME_UNIT_US,
      help="Time unit (" + " | ".join(cli_shared.TIME_UNITS) + ")")

  self._arg_parsers["list_profile"] = ap

  # Parser for the print_source command.
  ap = argparse.ArgumentParser(
      description="Print a Python source file with line-level profile "
      "information",
      usage=argparse.SUPPRESS)
  ap.add_argument(
      "source_file_path",
      type=str,
      help="Path to the source_file_path")
  ap.add_argument(
      "--cost_type",
      type=str,
      choices=["exec_time", "op_time"],
      default="exec_time",
      help="Type of cost to display")
  ap.add_argument(
      "--time_unit",
      dest="time_unit",
      type=str,
      default=cli_shared.TIME_UNIT_US,
      help="Time unit (" + " | ".join(cli_shared.TIME_UNITS) + ")")
  ap.add_argument(
      "-d",
      "--%s" % _DEVICE_NAME_FILTER_FLAG,
      dest=_DEVICE_NAME_FILTER_FLAG,
      type=str,
      default="",
      help="Filter device name by regex.")
  ap.add_argument(
      "-n",
      "--%s" % _NODE_NAME_FILTER_FLAG,
      dest=_NODE_NAME_FILTER_FLAG,
      type=str,
      default="",
      help="Filter node name by regex.")
  ap.add_argument(
      "-t",
      "--%s" % _OP_TYPE_FILTER_FLAG,
      dest=_OP_TYPE_FILTER_FLAG,
      type=str,
      default="",
      help="Filter op type by regex.")
  ap.add_argument(
      "--init_line",
      dest="init_line",
      type=int,
      default=0,
      help="The 1-based line number to scroll to initially.")

  self._arg_parsers["print_source"] = ap

377 

def list_profile(self, args, screen_info=None):
  """Command handler for list_profile.

  Lists per-operation profile information, one table per device whose name
  passes the device-name filter.

  Args:
    args: Command-line arguments, excluding the command prefix, as a list of
      str.
    screen_info: Optional dict input containing screen information such as
      cols.

  Returns:
    Output text lines as a RichTextLines object.
  """

  def _optional_regex(pattern):
    # An empty filter string means "no filtering".
    return re.compile(pattern) if pattern else None

  if screen_info and "cols" in screen_info:
    screen_cols = screen_info["cols"]
  else:
    screen_cols = 80

  parsed = self._arg_parsers["list_profile"].parse_args(args)

  op_time_interval = None
  if parsed.op_time:
    op_time_interval = command_parser.parse_time_interval(parsed.op_time)
  exec_time_interval = None
  if parsed.execution_time:
    exec_time_interval = command_parser.parse_time_interval(
        parsed.execution_time)
  node_name_regex = _optional_regex(parsed.node_name_filter)
  file_path_regex = _optional_regex(parsed.file_path_filter)
  op_type_regex = _optional_regex(parsed.op_type_filter)

  output = debugger_cli_common.RichTextLines([""])
  device_name_regex = _optional_regex(parsed.device_name_filter)
  data_generator = self._get_profile_data_generator()
  all_device_stats = self._run_metadata.step_stats.dev_stats
  device_count = len(all_device_stats)

  for index, device_stats in enumerate(all_device_stats):
    if device_name_regex and not device_name_regex.match(device_stats.device):
      continue

    selected = [
        datum for datum in data_generator(device_stats)
        if _list_profile_filter(
            datum, node_name_regex, file_path_regex, op_type_regex,
            op_time_interval, exec_time_interval,
            min_lineno=parsed.min_lineno, max_lineno=parsed.max_lineno)]
    selected.sort(
        key=lambda datum: _list_profile_sort_key(datum, parsed.sort_by),
        reverse=parsed.reverse)

    device_lines = self._get_list_profile_lines(
        device_stats.device, index, device_count,
        selected, parsed.sort_by, parsed.reverse, parsed.time_unit,
        device_name_filter=parsed.device_name_filter,
        node_name_filter=parsed.node_name_filter,
        op_type_filter=parsed.op_type_filter,
        screen_cols=screen_cols)
    output.extend(device_lines)

  return output

436 

def _get_profile_data_generator(self):
  """Get function that generates `ProfileDatum` objects.

  Walks every op in the graph once and records, per op name, its op type and
  the source location chosen from its creation traceback: the first frame,
  scanning the traceback in reverse, that is not guessed to be TensorFlow
  library code. If every frame is library code, the last frame scanned is
  used.

  Returns:
    A function that takes a device's step stats and generates one
    `ProfileDatum` per node stat, skipping the "_SOURCE" and "_SINK" nodes.
    Nodes with no recorded source location get "" / 0 / "" for file path,
    line number and function name.
  """
  node_to_file_path = {}
  node_to_line_number = {}
  node_to_func_name = {}
  node_to_op_type = {}
  for op in self._graph.get_operations():
    node_to_op_type[op.name] = op.type

    # Reset per op so that an op with an empty traceback neither raises a
    # NameError nor silently inherits the previous op's source location.
    chosen_entry = None
    for trace_entry in reversed(op.traceback):
      chosen_entry = trace_entry
      if not source_utils.guess_is_tensorflow_py_library(trace_entry[0]):
        break
    if chosen_entry is None:
      # No traceback: leave the location unset; the generator's defaults apply.
      continue

    node_to_file_path[op.name] = chosen_entry[0]
    node_to_line_number[op.name] = chosen_entry[1]
    node_to_func_name[op.name] = chosen_entry[2]

  def profile_data_generator(device_step_stats):
    for node_stats in device_step_stats.node_stats:
      if node_stats.node_name == "_SOURCE" or node_stats.node_name == "_SINK":
        continue
      yield profiling.ProfileDatum(
          device_step_stats.device,
          node_stats,
          node_to_file_path.get(node_stats.node_name, ""),
          node_to_line_number.get(node_stats.node_name, 0),
          node_to_func_name.get(node_stats.node_name, ""),
          node_to_op_type.get(node_stats.node_name, ""))
  return profile_data_generator

471 

def _get_list_profile_lines(
    self, device_name, device_index, device_count,
    profile_datum_list, sort_by, sort_reverse, time_unit,
    device_name_filter=None, node_name_filter=None, op_type_filter=None,
    screen_cols=80):
  """Get `RichTextLines` object for list_profile command for a given device.

  The output consists of a separator line, a "Device i of n" line, a header
  row whose column names are clickable sort commands, one row per profile
  datum, and a "Device Total" row with the summed op and exec times.

  Args:
    device_name: (string) Device name.
    device_index: (int) Device index.
    device_count: (int) Number of devices.
    profile_datum_list: List of `ProfileDatum` objects.
    sort_by: (string) Identifier of column to sort. Sort identifier
        must match value of SORT_OPS_BY_OP_NAME, SORT_OPS_BY_OP_TYPE,
        SORT_OPS_BY_START_TIME, SORT_OPS_BY_OP_TIME, SORT_OPS_BY_EXEC_TIME
        or SORT_OPS_BY_LINE.
    sort_reverse: (bool) Whether to sort in descending instead of default
        (ascending) order.
    time_unit: time unit, must be in cli_shared.TIME_UNITS.
    device_name_filter: Regular expression to filter by device name.
    node_name_filter: Regular expression to filter by node name.
    op_type_filter: Regular expression to filter by op type.
    screen_cols: (int) Number of columns available on the screen (i.e.,
      available screen width).

  Returns:
    `RichTextLines` object containing a table that displays profiling
    information for each op.
  """
  table = ProfileDataTableView(profile_datum_list, time_unit=time_unit)
  column_names = table.column_names()
  num_cols = table.column_count()

  # Calculate total time early to calculate column widths.
  total_op_time = sum(datum.op_time for datum in profile_datum_list)
  total_exec_time = sum(datum.node_exec_stats.all_end_rel_micros
                        for datum in profile_datum_list)
  # One entry per leading column (Node, Op Type, Start Time, Op Time,
  # Exec Time). The blank Start Time entry keeps each total underneath the
  # column it sums.
  device_total_row = [
      "Device Total", "", "",
      cli_shared.time_to_readable_str(total_op_time,
                                      force_time_unit=time_unit),
      cli_shared.time_to_readable_str(total_exec_time,
                                      force_time_unit=time_unit)]

  # Render every cell once; the cells are reused for measuring and output.
  cells = [
      [table.value(row,
                   col,
                   device_name_filter=device_name_filter,
                   node_name_filter=node_name_filter,
                   op_type_filter=op_type_filter)
       for col in range(num_cols)]
      for row in range(table.row_count())]

  # Calculate column widths.
  column_widths = [len(column_name) for column_name in column_names]
  for col, total_text in enumerate(device_total_row):
    column_widths[col] = max(column_widths[col], len(total_text))
  for col in range(len(column_widths)):
    for row_cells in cells:
      column_widths[col] = max(column_widths[col], len(row_cells[col]))
    column_widths[col] += 2  # add margin between columns

  # Add device name.
  output = [RL("-" * screen_cols)]
  device_row = "Device %d of %d: %s" % (
      device_index + 1, device_count, device_name)
  output.append(RL(device_row))
  output.append(RL())

  # Add headers. Clicking a header sorts by that column; clicking the column
  # that is already sorted ascending flips it to descending.
  base_command = "list_profile"
  header_row = RL()
  for col, column_name in enumerate(column_names):
    sort_id = table.column_sort_id(col)
    command = "%s -s %s" % (base_command, sort_id)
    if sort_by == sort_id and not sort_reverse:
      command += " -r"
    head_menu_item = debugger_cli_common.MenuItem(None, command)
    header_row += RL(column_name, font_attr=[head_menu_item, "bold"])
    header_row += RL(" " * (column_widths[col] - len(column_name)))
  output.append(header_row)

  # Add data rows.
  for row_cells in cells:
    new_row = RL()
    for col, cell in enumerate(row_cells):
      new_row += cell
      new_row += RL(" " * (column_widths[col] - len(cell)))
    output.append(new_row)

  # Add stat totals.
  totals_text = "".join(
      ("{:<%d}" % width).format(text)
      for width, text in zip(column_widths, device_total_row))
  output.append(RL())
  output.append(RL(totals_text))
  return debugger_cli_common.rich_text_lines_from_rich_line_list(output)

572 

573 def _measure_list_profile_column_widths(self, profile_data): 

574 """Determine the maximum column widths for each data list. 

575 

576 Args: 

577 profile_data: list of ProfileDatum objects. 

578 

579 Returns: 

580 List of column widths in the same order as columns in data. 

581 """ 

582 num_columns = len(profile_data.column_names()) 

583 widths = [len(column_name) for column_name in profile_data.column_names()] 

584 for row in range(profile_data.row_count()): 

585 for col in range(num_columns): 

586 widths[col] = max( 

587 widths[col], len(str(profile_data.row_values(row)[col])) + 2) 

588 return widths 

589 

# Display constants used by print_source and _render_normalized_cost_bar.
# Font attribute for the cost bar, total cost and node/exec count columns.
_LINE_COST_ATTR = cli_shared.COLOR_CYAN
# Font attribute for the line-number column.
_LINE_NUM_ATTR = cli_shared.COLOR_YELLOW
# Heading and sub-heading of the node/exec count column.
_NUM_NODES_HEAD = "#nodes"
_NUM_EXECS_SUB_HEAD = "(#execs)"
# Headings of the line-number and source-text columns.
_LINENO_HEAD = "lineno"
_SOURCE_HEAD = "source"

596 

def print_source(self, args, screen_info=None):
  """Print a Python source file with line-level profile information.

  Every source line is prefixed with a line number. Lines that have profile
  data additionally get a normalized cost bar, the total cost, and a
  clickable "nodes(execs)" count that issues an `lp` command restricted to
  that line.

  Args:
    args: Command-line arguments, excluding the command prefix, as a list of
      str.
    screen_info: Optional dict input containing screen information such as
      cols. Unused.

  Returns:
    Output text lines as a RichTextLines object. If no profile data matches
    the file under the given filters, a short message listing the filters is
    returned instead.
  """
  del screen_info

  parsed = self._arg_parsers["print_source"].parse_args(args)

  # Expand "~" once and use the same path for annotating, loading and the
  # generated file-path filter, so that "~/x.py" is handled consistently.
  source_file_path = os.path.expanduser(parsed.source_file_path)

  device_name_regex = (re.compile(parsed.device_name_filter)
                       if parsed.device_name_filter else None)

  profile_data = []
  data_generator = self._get_profile_data_generator()
  device_count = len(self._run_metadata.step_stats.dev_stats)
  for index in range(device_count):
    device_stats = self._run_metadata.step_stats.dev_stats[index]
    if device_name_regex and not device_name_regex.match(device_stats.device):
      continue
    profile_data.extend(data_generator(device_stats))

  source_annotation = source_utils.annotate_source_against_profile(
      profile_data,
      source_file_path,
      node_name_filter=parsed.node_name_filter,
      op_type_filter=parsed.op_type_filter)
  if not source_annotation:
    return debugger_cli_common.RichTextLines(
        ["The source file %s does not contain any profile information for "
         "the previous Session run under the following "
         "filters:" % parsed.source_file_path,
         " --%s: %s" % (_DEVICE_NAME_FILTER_FLAG, parsed.device_name_filter),
         " --%s: %s" % (_NODE_NAME_FILTER_FLAG, parsed.node_name_filter),
         " --%s: %s" % (_OP_TYPE_FILTER_FLAG, parsed.op_type_filter)])

  # The largest per-line cost normalizes the cost bars.
  max_total_cost = 0
  for line_index in source_annotation:
    total_cost = self._get_total_cost(source_annotation[line_index],
                                      parsed.cost_type)
    max_total_cost = max(max_total_cost, total_cost)

  source_lines, line_num_width = source_utils.load_source(source_file_path)

  cost_bar_max_length = 10
  total_cost_head = parsed.cost_type
  column_widths = {
      "cost_bar": cost_bar_max_length + 3,
      "total_cost": len(total_cost_head) + 3,
      "num_nodes_execs": len(self._NUM_EXECS_SUB_HEAD) + 1,
      "line_number": line_num_width,
  }
  # Width of the blank prefix on lines that carry no profile information.
  blank_prefix_width = sum(
      width for col_name, width in column_widths.items()
      if col_name != "line_number")

  head = RL(
      " " * column_widths["cost_bar"] +
      total_cost_head +
      " " * (column_widths["total_cost"] - len(total_cost_head)) +
      self._NUM_NODES_HEAD +
      " " * (column_widths["num_nodes_execs"] - len(self._NUM_NODES_HEAD)),
      font_attr=self._LINE_COST_ATTR)
  head += RL(self._LINENO_HEAD, font_attr=self._LINE_NUM_ATTR)
  sub_head = RL(
      " " * (column_widths["cost_bar"] +
             column_widths["total_cost"]) +
      self._NUM_EXECS_SUB_HEAD +
      " " * (column_widths["num_nodes_execs"] -
             len(self._NUM_EXECS_SUB_HEAD)) +
      " " * column_widths["line_number"],
      font_attr=self._LINE_COST_ATTR)
  sub_head += RL(self._SOURCE_HEAD, font_attr="bold")
  lines = [head, sub_head]

  # Anchored pattern matching exactly this file, for the per-line lp command.
  file_path_filter = re.escape(source_file_path) + "$"

  output_annotations = {}
  for i, line in enumerate(source_lines):
    lineno = i + 1
    if lineno in source_annotation:
      annotation = source_annotation[lineno]
      line_cost = self._get_total_cost(annotation, parsed.cost_type)

      cost_bar = self._render_normalized_cost_bar(
          line_cost, max_total_cost, cost_bar_max_length)
      annotated_line = cost_bar
      annotated_line += " " * (column_widths["cost_bar"] - len(cost_bar))

      total_cost = RL(cli_shared.time_to_readable_str(
          line_cost,
          force_time_unit=parsed.time_unit),
                      font_attr=self._LINE_COST_ATTR)
      total_cost += " " * (column_widths["total_cost"] - len(total_cost))
      annotated_line += total_cost

      # Clicking the node count lists the ops created on this line only
      # (max_lineno is exclusive), with the active filters carried over.
      command = "lp --file_path_filter %s --min_lineno %d --max_lineno %d" % (
          file_path_filter, lineno, lineno + 1)
      if parsed.device_name_filter:
        command += " --%s %s" % (_DEVICE_NAME_FILTER_FLAG,
                                 parsed.device_name_filter)
      if parsed.node_name_filter:
        command += " --%s %s" % (_NODE_NAME_FILTER_FLAG,
                                 parsed.node_name_filter)
      if parsed.op_type_filter:
        command += " --%s %s" % (_OP_TYPE_FILTER_FLAG,
                                 parsed.op_type_filter)
      menu_item = debugger_cli_common.MenuItem(None, command)
      num_nodes_execs = RL("%d(%d)" % (annotation.node_count,
                                       annotation.node_exec_count),
                           font_attr=[self._LINE_COST_ATTR, menu_item])
      num_nodes_execs += " " * (
          column_widths["num_nodes_execs"] - len(num_nodes_execs))
      annotated_line += num_nodes_execs
    else:
      annotated_line = RL(" " * blank_prefix_width)

    line_num_column = RL(" L%d" % (lineno), self._LINE_NUM_ATTR)
    line_num_column += " " * (
        column_widths["line_number"] - len(line_num_column))
    annotated_line += line_num_column
    annotated_line += line
    lines.append(annotated_line)

    if parsed.init_line == lineno:
      # Ask the UI to scroll to the output line that was just appended.
      output_annotations[
          debugger_cli_common.INIT_SCROLL_POS_KEY] = len(lines) - 1

  return debugger_cli_common.rich_text_lines_from_rich_line_list(
      lines, annotations=output_annotations)

731 

732 def _get_total_cost(self, aggregated_profile, cost_type): 

733 if cost_type == "exec_time": 

734 return aggregated_profile.total_exec_time 

735 elif cost_type == "op_time": 

736 return aggregated_profile.total_op_time 

737 else: 

738 raise ValueError("Unsupported cost type: %s" % cost_type) 

739 

def _render_normalized_cost_bar(self, cost, max_cost, length):
  """Render a text bar representing a normalized cost.

  The bar looks like "[|||       ]": `length` characters between the
  brackets, of which a number proportional to cost / max_cost are ticks.

  Args:
    cost: the absolute value of the cost.
    max_cost: the maximum cost value to normalize the absolute cost with.
      If it is not positive (e.g. every profiled cost is 0), the minimum bar
      of one tick is rendered instead of dividing by zero.
    length: (int) length of the cost bar, in number of characters, excluding
      the brackets on the two ends.

  Returns:
    An instance of debugger_cli_common.RichLine.
  """
  if max_cost > 0:
    num_ticks = int(np.ceil(float(cost) / max_cost * length))
  else:
    num_ticks = 0
  # At least 1 tick, and never more ticks than fit between the brackets.
  num_ticks = max(1, min(num_ticks, length))
  output = RL("[", font_attr=self._LINE_COST_ATTR)
  output += RL("|" * num_ticks + " " * (length - num_ticks),
               font_attr=["bold", self._LINE_COST_ATTR])
  output += RL("]", font_attr=self._LINE_COST_ATTR)
  return output

759 

def get_help(self, handler_name):
  """Return the formatted help text of the parser for `handler_name`.

  Args:
    handler_name: (str) Key of a parser in `self._arg_parsers`.

  Returns:
    The result of calling `format_help()` on that parser.
  """
  parser = self._arg_parsers[handler_name]
  return parser.format_help()

762 

763 

def create_profiler_ui(graph,
                       run_metadata,
                       ui_type="curses",
                       on_ui_exit=None,
                       config=None):
  """Build a profiler UI for a `tf.Graph` and its `RunMetadata`.

  A `ProfileAnalyzer` is created for the inputs and its two command handlers
  are registered on the UI: `list_profile` (alias `lp`) and `print_source`
  (alias `ps`).

  Args:
    graph: Python `Graph` object.
    run_metadata: A `RunMetadata` protobuf object.
    ui_type: (str) requested UI type, e.g., "curses", "readline".
    on_ui_exit: (`Callable`) the callback to be called when the UI exits.
    config: An instance of `cli_config.CLIConfig`.

  Returns:
    (base_ui.BaseUI) A BaseUI subtype object with a set of standard analyzer
      commands and tab-completions registered.
  """
  del config  # Currently unused.

  analyzer = ProfileAnalyzer(graph, run_metadata)
  cli = ui_factory.get_ui(ui_type, on_ui_exit=on_ui_exit)

  # (command name, handler, prefix alias), registered in this order.
  commands = (
      ("list_profile", analyzer.list_profile, "lp"),
      ("print_source", analyzer.print_source, "ps"),
  )
  for command_name, handler, alias in commands:
    cli.register_command_handler(
        command_name,
        handler,
        analyzer.get_help(command_name),
        prefix_aliases=[alias])

  return cli