Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/cli/cli_shared.py: 21%

173 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Shared functions and classes for tfdbg command-line interface.""" 

16import math 

17 

18import numpy as np 

19 

20from tensorflow.python.debug.cli import command_parser 

21from tensorflow.python.debug.cli import debugger_cli_common 

22from tensorflow.python.debug.cli import tensor_format 

23from tensorflow.python.debug.lib import common 

24from tensorflow.python.framework import ops 

25from tensorflow.python.ops import variables 

26from tensorflow.python.platform import gfile 

27 

# Short alias for the rich-text line class; the helpers in this module use it
# to build styled output lines.
28RL = debugger_cli_common.RichLine 

29 

30# Default threshold number of elements above which ellipses will be used 

31# when printing the value of the tensor. 

32DEFAULT_NDARRAY_DISPLAY_THRESHOLD = 2000 

33 

# Color names. In this module they are passed as the font attribute of a rich
# line (error() uses COLOR_RED for its message).
34COLOR_BLACK = "black" 

35COLOR_BLUE = "blue" 

36COLOR_CYAN = "cyan" 

37COLOR_GRAY = "gray" 

38COLOR_GREEN = "green" 

39COLOR_MAGENTA = "magenta" 

40COLOR_RED = "red" 

41COLOR_WHITE = "white" 

42COLOR_YELLOW = "yellow" 

43 

# Time-unit suffixes. The position of a unit in TIME_UNITS is significant:
# time_to_readable_str() divides a microsecond value by 10**(3 * index) to
# express it in that unit, so the list must stay ordered from smallest to
# largest unit.
44TIME_UNIT_US = "us" 

45TIME_UNIT_MS = "ms" 

46TIME_UNIT_S = "s" 

47TIME_UNITS = [TIME_UNIT_US, TIME_UNIT_MS, TIME_UNIT_S] 

48 

49 

def bytes_to_readable_str(num_bytes, include_b=False):
  """Render a byte count as a short, human-readable string.

  Counts below 1024 are printed as plain integers. Larger counts are divided
  by the matching power of 1024 and printed with two decimals followed by the
  unit letter k, M or G.

  Args:
    num_bytes: (`int` or None) Number of bytes.
    include_b: (`bool`) Whether to append the letter B after the unit.

  Returns:
    (`str`) The formatted byte count, including its unit. A `None` input
      yields the string "None" (never followed by "B").
  """
  if num_bytes is None:
    return str(num_bytes)

  suffix = "B" if include_b else ""
  if num_bytes < 1024:
    return "%d" % num_bytes + suffix

  # (exclusive upper bound, divisor, unit letter), smallest unit first.
  scaled_units = (
      (1048576, 1024.0, "k"),
      (1073741824, 1048576.0, "M"),
  )
  for upper_bound, divisor, unit in scaled_units:
    if num_bytes < upper_bound:
      return "%.2f%s" % (num_bytes / divisor, unit) + suffix

  # Everything at or above 1024**3 is reported in G.
  return "%.2fG" % (num_bytes / 1073741824.0) + suffix

78 

79 

def time_to_readable_str(value_us, force_time_unit=None):
  """Convert time value to human-readable string.

  Without `force_time_unit`, the largest unit in TIME_UNITS that keeps the
  magnitude at or above 1 is chosen (never smaller than the first unit and
  never larger than the last) and the value is printed with 3 significant
  digits. With `force_time_unit`, the value is converted to that unit and
  printed with up to 10 significant digits.

  Args:
    value_us: time value in microseconds. A falsy value (`None`, 0) yields
      "0".
    force_time_unit: force the output to use the specified time unit. Must be
      in TIME_UNITS.

  Returns:
    Human-readable string representation of the time value.

  Raises:
    ValueError: if force_time_unit value is not in TIME_UNITS.
  """
  if not value_us:
    return "0"

  if force_time_unit:
    if force_time_unit not in TIME_UNITS:
      raise ValueError("Invalid time unit: %s" % force_time_unit)
    order = TIME_UNITS.index(force_time_unit)
    time_unit = force_time_unit
    return "{:.10g}{}".format(value_us / math.pow(10.0, 3 * order), time_unit)

  # math.log10 is exact at powers of ten, whereas math.log(x, 10) is not
  # (math.log(1000, 10) == 2.9999999999999996), which used to put values such
  # as 1000 us or 1e6 us one unit too low.
  # abs() keeps negative durations from raising a math domain error.
  magnitude = int(math.log10(abs(value_us)) / 3)
  # Clamp on both sides: a negative magnitude (sub-nanosecond input) would
  # otherwise index TIME_UNITS from the end and report seconds.
  order = max(0, min(len(TIME_UNITS) - 1, magnitude))
  time_unit = TIME_UNITS[order]
  return "{:.3g}{}".format(value_us / math.pow(10.0, 3 * order), time_unit)

106 

107 

def parse_ranges_highlight(ranges_string):
  """Build highlight options from a ranges string.

  Args:
    ranges_string: (str) A string representing a numerical range or a list of
      numerical ranges. See the help info of the -r flag of the print_tensor
      command for more details.

  Returns:
    An instance of tensor_format.HighlightOptions whose filter marks the
      elements falling inside any of the parsed (inclusive) ranges, or `None`
      if `ranges_string` is empty or `None`.
  """
  if not ranges_string:
    return None

  parsed_ranges = command_parser.parse_ranges(ranges_string)

  def ranges_filter(x):
    # Start with nothing selected, then OR in the mask of each closed
    # interval [lower, upper].
    mask = np.zeros(x.shape, dtype=bool)
    for lower, upper in parsed_ranges:
      within = np.logical_and(x >= lower, x <= upper)
      mask = np.logical_or(mask, within)
    return mask

  return tensor_format.HighlightOptions(
      ranges_filter, description=ranges_string)

136 

137 

def numpy_printoptions_from_screen_info(screen_info):
  """Derive numpy print options from screen information.

  Args:
    screen_info: A mapping that may contain a "cols" entry, or a falsy value
      (e.g. `None`) when no screen information is available.

  Returns:
    (dict) `{"linewidth": screen_info["cols"]}` if "cols" is present,
      otherwise an empty dict.
  """
  has_cols = bool(screen_info) and "cols" in screen_info
  return {"linewidth": screen_info["cols"]} if has_cols else {}

143 

144 

def format_tensor(tensor,
                  tensor_name,
                  np_printoptions,
                  print_all=False,
                  tensor_slicing=None,
                  highlight_options=None,
                  include_numeric_summary=False,
                  write_path=None):
  """Generate formatted str to represent a tensor or its slices.

  Args:
    tensor: (numpy ndarray) The tensor value.
    tensor_name: (str) Name of the tensor, e.g., the tensor's debug watch key.
    np_printoptions: (dict) Numpy tensor formatting options. The dict is not
      modified; a copy with the "threshold" entry set is passed on to the
      formatter. `None` is treated as an empty dict.
    print_all: (bool) Whether the tensor is to be displayed in its entirety,
      instead of printing ellipses, even if its number of elements exceeds
      the default numpy display threshold.
      (Note: Even if this is set to true, the screen output can still be cut
       off by the UI frontend if it consist of more lines than the frontend
       can handle.)
    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
      None, no slicing will be performed on the tensor.
    highlight_options: (tensor_format.HighlightOptions) options to highlight
      elements of the tensor. See the doc of tensor_format.format_tensor()
      for more details.
    include_numeric_summary: Whether a text summary of the numeric values (if
      applicable) will be included.
    write_path: A path to save the tensor value (after any slicing) to
      (optional). `numpy.save()` is used to save the value.

  Returns:
    An instance of `debugger_cli_common.RichTextLines` representing the
    (potentially sliced) tensor.
  """

  if tensor_slicing:
    # Validate the indexing.
    value = command_parser.evaluate_tensor_slice(tensor, tensor_slicing)
    sliced_name = tensor_name + tensor_slicing
  else:
    value = tensor
    sliced_name = tensor_name

  auxiliary_message = None
  if write_path:
    with gfile.Open(write_path, "wb") as output_file:
      np.save(output_file, value)
    # The file is stat'ed only after the `with` block has closed it, so the
    # reported size reflects everything that was written.
    line = debugger_cli_common.RichLine("Saved value to: ")
    line += debugger_cli_common.RichLine(write_path, font_attr="bold")
    line += " (%sB)" % bytes_to_readable_str(gfile.Stat(write_path).length)
    auxiliary_message = debugger_cli_common.rich_text_lines_from_rich_line_list(
        [line, debugger_cli_common.RichLine("")])

  # Work on a private copy so the caller's options dict is not silently
  # modified by the threshold override below.
  printoptions = dict(np_printoptions) if np_printoptions else {}
  if print_all:
    printoptions["threshold"] = value.size
  else:
    printoptions["threshold"] = DEFAULT_NDARRAY_DISPLAY_THRESHOLD

  return tensor_format.format_tensor(
      value,
      sliced_name,
      include_metadata=True,
      include_numeric_summary=include_numeric_summary,
      auxiliary_message=auxiliary_message,
      np_printoptions=printoptions,
      highlight_options=highlight_options)

211 

212 

def error(msg):
  """Wrap an error message into screen output.

  Args:
    msg: (str) The error message.

  Returns:
    (debugger_cli_common.RichTextLines) A single line reading "ERROR: <msg>",
      carrying the COLOR_RED font attribute.
  """
  error_line = RL("ERROR: " + msg, COLOR_RED)
  return debugger_cli_common.rich_text_lines_from_rich_line_list([error_line])

226 

227 

def _recommend_command(command, description, indent=2, create_link=False):
  """Format a recommended command together with its description.

  The result has two lines: the indented command (in bold) followed by a
  colon, then the description indented one step further.

  Args:
    command: (str) The command to recommend.
    description: (str) A description of what the command does.
    indent: (int) How many spaces to indent in the beginning.
    create_link: (bool) Whether a command link is to be applied to the command
      string.

  Returns:
    (RichTextLines) Formatted text (with font attributes) for recommending the
      command.
  """
  padding = " " * indent

  # The command is always bold; with create_link it additionally carries a
  # menu item that invokes the command.
  command_attr = "bold"
  if create_link:
    command_attr = [debugger_cli_common.MenuItem("", command), command_attr]

  heading = RL(padding) + RL(command, command_attr) + ":"
  detail = padding + " " + description
  return debugger_cli_common.rich_text_lines_from_rich_line_list(
      [heading, detail])

254 

255 

def get_tfdbg_logo():
  """Make an ASCII representation of the tfdbg logo.

  Returns:
    (debugger_cli_common.RichTextLines) The logo rows, with one empty line
      before and one after.
  """
  # NOTE(review): the art rows below have unequal lengths, so the letters do
  # not line up column-wise - the inner padding may have been lost; confirm
  # against the upstream file before relying on the rendering.
  logo_rows = [
      "",
      "TTTTTT FFFF DDD BBBB GGG ",
      " TT F D D B B G ",
      " TT FFF D D BBBB G GG",
      " TT F D D B B G G",
      " TT F DDD BBBB GGG ",
      "",
  ]
  return debugger_cli_common.RichTextLines(logo_rows)

269 

270 

# Separator line drawn above and below the run summary in the run-start intro.
_HORIZONTAL_BAR = "======================================"


def get_run_start_intro(run_call_count,
                        fetches,
                        feed_dict,
                        tensor_filters,
                        is_callable_runner=False):
  """Generate formatted intro for run-start UI.

  The intro lists the fetches and feeds of the upcoming run, the commands
  available to proceed, and the registered tensor filters. A main menu with
  "run" and "exit" items is attached to the returned object's annotations.

  Args:
    run_call_count: (int) Run call counter.
    fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
      for more details.
    feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
      for more details.
    tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
      callable.
    is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.

  Returns:
    (RichTextLines) Formatted intro message about the `Session.run()` call.
  """

  fetch_lines = common.get_flattened_names(fetches)

  if not feed_dict:
    feed_dict_lines = [debugger_cli_common.RichLine(" (Empty)")]
  else:
    feed_dict_lines = []
    for feed_key in feed_dict:
      feed_key_name = common.get_graph_element_name(feed_key)
      feed_dict_line = debugger_cli_common.RichLine(" ")
      # Surround the name string with quotes, because feed_key_name may contain
      # spaces in some cases, e.g., SparseTensors.
      feed_dict_line += debugger_cli_common.RichLine(
          feed_key_name,
          debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name))
      feed_dict_lines.append(feed_dict_line)
  # Both branches above produce a plain list of rich lines; convert it so it
  # can be passed to out.extend() below.
  feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list(
      feed_dict_lines)

  out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR)
  if is_callable_runner:
    out.append("Running a runner returned by Session.make_callable()")
  else:
    out.append("Session.run() call #%d:" % run_call_count)
    out.append("")
    out.append("Fetch(es):")
    out.extend(debugger_cli_common.RichTextLines(
        [" " + line for line in fetch_lines]))
    out.append("")
    out.append("Feed dict:")
    out.extend(feed_dict_lines)
  out.append(_HORIZONTAL_BAR)
  out.append("")
  out.append("Select one of the following commands to proceed ---->")

  out.extend(
      _recommend_command(
          "run",
          "Execute the run() call with debug tensor-watching",
          create_link=True))
  out.extend(
      _recommend_command(
          "run -n",
          "Execute the run() call without debug tensor-watching",
          create_link=True))
  # The two commands below take a user-supplied argument, so they are shown
  # without a clickable link.
  out.extend(
      _recommend_command(
          "run -t <T>",
          "Execute run() calls (T - 1) times without debugging, then "
          "execute run() once more with debugging and drop back to the CLI"))
  out.extend(
      _recommend_command(
          "run -f <filter_name>",
          "Keep executing run() calls until a dumped tensor passes a given, "
          "registered filter (conditional breakpoint mode)"))

  more_lines = [" Registered filter(s):"]
  if tensor_filters:
    for filter_name in tensor_filters:
      # Each filter name links to the command that runs with that filter.
      command_menu_node = debugger_cli_common.MenuItem(
          "", "run -f %s" % filter_name)
      more_lines.append(RL(" * ") + RL(filter_name, command_menu_node))
  else:
    more_lines.append(" (None)")

  out.extend(
      debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines))

  out.append("")

  # NOTE(review): the link text "help." followed by the literal "." renders
  # as "help.." - confirm whether the doubled period is intended.
  out.append_rich_line(RL("For more details, see ") +
                       RL("help.", debugger_cli_common.MenuItem("", "help")) +
                       ".")
  out.append("")

  # Make main menu for the run-start intro.
  menu = debugger_cli_common.Menu()
  menu.append(debugger_cli_common.MenuItem("run", "run"))
  menu.append(debugger_cli_common.MenuItem("exit", "exit"))
  out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu

  return out

379 

380 

def get_run_short_description(run_call_count,
                              fetches,
                              feed_dict,
                              is_callable_runner=False):
  """Summarize a run() call in one short line.

  Args:
    run_call_count: (int) Run call counter.
    fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
      for more details.
    feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
      for more details.
    is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.

  Returns:
    (str) A short description of the run() call, including information about
      the fetch(es) and feed(s). For a callable runner, a fixed string that
      mentions neither.
  """
  if is_callable_runner:
    return "runner from make_callable()"

  pieces = ["run #%d: " % run_call_count]

  # Fetch part: a lone graph element is named; anything else is only counted.
  single_fetch_types = (ops.Tensor, ops.Operation, variables.Variable)
  if isinstance(fetches, single_fetch_types):
    pieces.append("1 fetch (%s); " % common.get_graph_element_name(fetches))
  else:
    # Could be (nested) list, tuple, dict or namedtuple.
    num_fetches = len(common.get_flattened_names(fetches))
    fetch_template = "%d fetches; " if num_fetches > 1 else "%d fetch; "
    pieces.append(fetch_template % num_fetches)

  # Feed part: a single feed is named; several feeds are only counted.
  if not feed_dict:
    pieces.append("0 feeds")
  elif len(feed_dict) == 1:
    for key in feed_dict:
      if isinstance(key, str) or not hasattr(key, "name"):
        shown_key = key
      else:
        shown_key = key.name
      pieces.append("1 feed (%s)" % shown_key)
  else:
    pieces.append("%d feeds" % len(feed_dict))

  return "".join(pieces)

427 

428 

def get_error_intro(tf_error):
  """Build the intro shown after a TensorFlow run-time error.

  Args:
    tf_error: (errors.OpError) TensorFlow run-time error object.

  Returns:
    (RichTextLines) Formatted intro message about the run-time OpError. When
      the failing op's name can be read from `tf_error.op.name`, sample
      commands for debugging are included; otherwise a warning is shown
      in their place.
  """
  # Either attribute may be missing; in that case the op name is unknown.
  failing_op = getattr(tf_error, "op", None)
  op_name = getattr(failing_op, "name", None)

  out = debugger_cli_common.rich_text_lines_from_rich_line_list([
      "--------------------------------------",
      RL("!!! An error occurred during the run !!!", "blink"),
      "",
  ])

  if op_name is None:
    out.extend(debugger_cli_common.RichTextLines([
        "WARNING: Cannot determine the name of the op that caused the error."]))
  else:
    out.extend(debugger_cli_common.RichTextLines(
        ["You may use the following commands to debug:"]))
    suggestions = (
        ("ni -a -d -t %s" % op_name,
         "Inspect information about the failing op."),
        ("li -r %s" % op_name,
         "List inputs to the failing op, recursively."),
        ("lt",
         "List all tensors dumped during the failing run() call."),
    )
    for command, description in suggestions:
      out.extend(_recommend_command(command, description, create_link=True))

  error_details = [
      "",
      "Op name: %s" % op_name,
      "Error type: " + str(type(tf_error)),
      "",
      "Details:",
      str(tf_error),
      "",
      "--------------------------------------",
      "",
  ]
  out.extend(debugger_cli_common.RichTextLines(error_details))

  return out