Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/cli/command_parser.py: 13%

204 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Command parsing module for TensorFlow Debugger (tfdbg).""" 

16import argparse 

17import ast 

18import re 

19import sys 

20 

21 

22_BRACKETS_PATTERN = re.compile(r"\[[^\]]*\]") 

23_QUOTES_PATTERN = re.compile(r"(\"[^\"]*\"|\'[^\']*\')") 

24_WHITESPACE_PATTERN = re.compile(r"\s+") 

25 

26_NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?") 

27 

28 

class Interval(object):
  """Represents an interval between a start and end value.

  Each endpoint is independently inclusive or exclusive, as given by
  `start_included` / `end_included`.
  """

  def __init__(self, start, start_included, end, end_included):
    self.start = start
    self.start_included = start_included
    self.end = end
    self.end_included = end_included

  def contains(self, value):
    """Returns True iff `value` lies inside the interval.

    Args:
      value: A value comparable with `start` and `end` via `<`, `>`, `==`.

    Returns:
      (bool) False if `value` is below `start` (or equal to an exclusive
      `start`) or above `end` (or equal to an exclusive `end`), else True.
    """
    if value < self.start or (value == self.start and not self.start_included):
      return False
    if value > self.end or (value == self.end and not self.end_included):
      return False
    return True

  def __eq__(self, other):
    # For foreign types defer to Python's default comparison instead of
    # raising AttributeError on `other.start` (e.g. `interval == None`).
    if not isinstance(other, Interval):
      return NotImplemented
    return (self.start == other.start and
            self.start_included == other.start_included and
            self.end == other.end and
            self.end_included == other.end_included)

  # Defining __eq__ already disables hashing in Python 3; made explicit here
  # because the attributes are mutable.
  __hash__ = None

  def __repr__(self):
    return "%s(start=%r, start_included=%r, end=%r, end_included=%r)" % (
        type(self).__name__, self.start, self.start_included, self.end,
        self.end_included)

50 

51 

def parse_command(command):
  """Parse command string into a list of arguments.

  - Disregards whitespace inside brackets and inside double or single quotes.
  - Strips a paired leading and trailing quote (double or single) from each
    argument; a lone quote character is kept as-is.
  - Splits the command at whitespace.

  Nested quotes and brackets are not handled.

  Args:
    command: (str) Input command.

  Returns:
    (list of str) List of arguments.
  """
  command = command.strip()
  if not command:
    return []

  brackets_intervals = [f.span() for f in _BRACKETS_PATTERN.finditer(command)]
  quotes_intervals = [f.span() for f in _QUOTES_PATTERN.finditer(command)]
  whitespaces_intervals = [
      f.span() for f in _WHITESPACE_PATTERN.finditer(command)
  ]
  protected_intervals = brackets_intervals + quotes_intervals

  arguments = []
  idx0 = 0
  # The trailing (len(command), None) sentinel flushes the final argument. It
  # also handles whitespace-free commands, so they get the same quote
  # stripping as every other argument.
  for start, end in whitespaces_intervals + [(len(command), None)]:
    # Skip whitespace stretches enclosed in brackets or quotes.
    if any(interval[0] < start < interval[1]
           for interval in protected_intervals):
      continue
    argument = command[idx0:start]
    # Strip surrounding quotes only when they form a real pair; without the
    # length check a single '"' would collapse to an empty string.
    if (len(argument) >= 2 and argument[0] == argument[-1] and
        argument[0] in "\"'"):
      argument = argument[1:-1]
    arguments.append(argument)
    idx0 = end

  return arguments

98 

99 

def extract_output_file_path(args):
  """Extract output file path from command arguments.

  Recognized redirect forms at the end of the argument list:
    - "... >path". This is not a redirect when "path" parses as an interval
      (e.g. ">10ms") and the previous argument starts with "-".
    - "... > path"
    - "... arg>path", with exactly one ">" that is not preceded by "=".
    - "... arg> path"

  Args:
    args: (list of str) command arguments. The list is not modified.

  Returns:
    (list of str) Command arguments with the output file path part stripped.
    (str or None) Output file path (if any).

  Raises:
    SyntaxError: If there is no file path after the last ">" character.
  """
  if args and args[-1].endswith(">"):
    raise SyntaxError("Redirect file path is empty")
  elif args and args[-1].startswith(">"):
    # A token like ">10ms" may be an interval value for a preceding flag
    # rather than a redirect.
    try:
      _parse_interval(args[-1])
      is_interval = True
    except ValueError:
      is_interval = False
    if is_interval and len(args) > 1 and args[-2].startswith("-"):
      output_file_path = None
    else:
      output_file_path = args[-1][1:]
      args = args[:-1]
  elif len(args) > 1 and args[-2] == ">":
    output_file_path = args[-1]
    args = args[:-2]
  elif args and args[-1].count(">") == 1:
    gt_index = args[-1].index(">")
    if gt_index > 0 and args[-1][gt_index - 1] == "=":
      # "=>" presumably marks a "flag=>value" argument, not a redirect.
      output_file_path = None
    else:
      output_file_path = args[-1][gt_index + 1:]
      # Build a new list rather than assigning into the caller's list.
      args = args[:-1] + [args[-1][:gt_index]]
  elif len(args) > 1 and args[-2].endswith(">"):
    output_file_path = args[-1]
    args = args[:-2] + [args[-2][:-1]]
  else:
    output_file_path = None

  return args, output_file_path

145 

146 

def parse_tensor_name_with_slicing(in_str):
  """Split a tensor name from an optional trailing slicing string.

  A slicing suffix is recognized only when the input contains exactly one "["
  and ends with "]"; everything from that "[" onward is the slicing string.

  Args:
    in_str: (str) Tensor name, optionally followed by a slicing string, e.g.
      "hidden/weights/Variable:0" or "hidden/weights/Variable:0[1, :]".

  Returns:
    (str) name of the tensor.
    (str) slicing string, or "" when no slicing suffix is recognized.
  """
  has_slicing_suffix = in_str.endswith("]") and in_str.count("[") == 1
  if not has_slicing_suffix:
    return in_str, ""

  name, open_bracket, remainder = in_str.partition("[")
  return name, open_bracket + remainder

168 

169 

def validate_slicing_string(slicing_string):
  """Validate a slicing string.

  Checks that the whole string is a bracketed, non-empty sequence of digits,
  commas, colons and whitespace. This is the subset of numpy-style array
  slicing that this module can parse; negative indices are not accepted.

  Args:
    slicing_string: (str) Input slicing string to be validated.

  Returns:
    (bool) True if and only if the slicing string is valid.
  """
  # fullmatch rather than search(r"^...$"): "$" also matches just before a
  # trailing newline, which would wrongly accept e.g. "[1]\n".
  return re.fullmatch(r"\[[\d,\s:]+\]", slicing_string) is not None

184 

185 

186def _parse_slices(slicing_string): 

187 """Construct a tuple of slices from the slicing string. 

188 

189 The string must be a valid slicing string. 

190 

191 Args: 

192 slicing_string: (str) Input slicing string to be parsed. 

193 

194 Returns: 

195 tuple(slice1, slice2, ...) 

196 

197 Raises: 

198 ValueError: If tensor_slicing is not a valid numpy ndarray slicing str. 

199 """ 

200 parsed = [] 

201 for slice_string in slicing_string[1:-1].split(","): 

202 indices = slice_string.split(":") 

203 if len(indices) == 1: 

204 parsed.append(int(indices[0].strip())) 

205 elif 2 <= len(indices) <= 3: 

206 parsed.append( 

207 slice(*[ 

208 int(index.strip()) if index.strip() else None for index in indices 

209 ])) 

210 else: 

211 raise ValueError("Invalid tensor-slicing string.") 

212 return tuple(parsed) 

213 

214 

def parse_indices(indices_string):
  """Parse a comma-separated list of integer indices.

  All whitespace is ignored, and a single pair of enclosing brackets is
  optional: "[1, 2, 3]" and "1,2,3" both yield [1, 2, 3].

  Args:
    indices_string: (str) a string representing indices.

  Returns:
    (list of int): Parsed indices.

  Raises:
    ValueError: If any comma-separated field is not an integer (including an
      empty field, e.g. from "[]").
  """
  compact = re.sub(r"\s+", "", indices_string)
  is_bracketed = compact.startswith("[") and compact.endswith("]")
  body = compact[1:-1] if is_bracketed else compact
  return list(map(int, body.split(",")))

237 

238 

def parse_ranges(range_string):
  """Parse a string representing numerical range(s).

  Args:
    range_string: (str) A string representing a numerical range or a list of
      them. For example:
        "[-1.0,1.0]", "[-inf, 0]", "[[-inf, -1.0], [1.0, inf]]", "0,1e-8"
      Every occurrence of "inf" is replaced by `sys.float_info.max`.

  Returns:
    (list of list of float) A list of numerical ranges parsed from the input
    string. An empty list for blank input or "[]".

  Raises:
    ValueError: If the input doesn't represent a range or a list of ranges.
      Malformed Python literals may also surface as ValueError or SyntaxError
      from `ast.literal_eval`.
  """
  range_string = range_string.strip()
  if not range_string:
    return []

  if "inf" in range_string:
    range_string = re.sub(r"inf", repr(sys.float_info.max), range_string)

  ranges = ast.literal_eval(range_string)
  # Unbracketed input such as "0,1e-8" evaluates to a tuple, not a list.
  if isinstance(ranges, tuple):
    ranges = list(ranges)
  if not isinstance(ranges, list):
    raise ValueError(
        "Expected a range or a list of ranges, got: %s" % type(ranges))
  if not ranges:
    return []
  if not isinstance(ranges[0], list):
    ranges = [ranges]

  # Verify that ranges is a list of [number, number] pairs.
  for item in ranges:
    if not isinstance(item, (list, tuple)):
      raise ValueError("Incorrect type of range: %s" % type(item))
    if len(item) != 2:
      raise ValueError("Incorrect number of elements in range")
    if not isinstance(item[0], (int, float)):
      raise ValueError("Incorrect type in the 1st element of range: %s" %
                       type(item[0]))
    if not isinstance(item[1], (int, float)):
      raise ValueError("Incorrect type in the 2nd element of range: %s" %
                       type(item[1]))

  return ranges

278 

279 

def parse_memory_interval(interval_str):
  """Convert a human-readable memory interval to an `Interval` in bytes.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[10kB, 20kB]", "<100M", ">100G"). Endpoints are plain byte
      counts or use the units "kB", "MB", "GB"; the trailing "B" may be
      omitted.

  Returns:
    `Interval` object where start and end are in bytes. A missing lower
    bound becomes 0 and a missing upper bound becomes float("inf").

  Raises:
    ValueError: if the input is not valid, or if its start exceeds its end.
  """
  endpoints = _parse_interval(interval_str)
  start_bytes = (
      parse_readable_size_str(endpoints.start) if endpoints.start else 0)
  end_bytes = (
      parse_readable_size_str(endpoints.end) if endpoints.end
      else float("inf"))

  if start_bytes > end_bytes:
    raise ValueError(
        "Invalid interval %s. Start of interval must be less than or equal "
        "to end of interval." % interval_str)
  return Interval(start_bytes, endpoints.start_included,
                  end_bytes, endpoints.end_included)

308 

309 

def parse_time_interval(interval_str):
  """Convert a human-readable time interval to an `Interval` in microseconds.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[10us, 20us]", "<100s", ">100ms"). Supported time suffixes are
      us, ms, s; a bare number is taken as microseconds.

  Returns:
    `Interval` object where start and end are in microseconds. A missing
    lower bound becomes 0 and a missing upper bound becomes float("inf").

  Raises:
    ValueError: if the input is not valid, or if its start exceeds its end.
  """
  endpoints = _parse_interval(interval_str)
  start_us = (
      parse_readable_time_str(endpoints.start) if endpoints.start else 0)
  end_us = (
      parse_readable_time_str(endpoints.end) if endpoints.end
      else float("inf"))

  if start_us > end_us:
    raise ValueError(
        "Invalid interval %s. Start must be before end of interval." %
        interval_str)
  return Interval(start_us, endpoints.start_included,
                  end_us, endpoints.end_included)

337 

338 

def _parse_interval(interval_str):
  """Split a human-readable interval into string endpoints.

  Accepted forms:
    - one-sided: "<=N", "<N", ">=N", ">N";
    - two-sided: "[a, b]", "(a, b)", "[a, b)", "(a, b]", where a square
      bracket marks an inclusive endpoint.
  Every endpoint has to start with a number (e.g., 3.0, -2, .98). Any unit
  text after the number (e.g. "k", "ms") is kept for the caller to interpret.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[1M, 2M]", "<100k", ">100ms").

  Returns:
    Interval object whose start/end are the endpoint strings. start is None
    for "<"/"<=" forms, and end is None for ">"/">=" forms.

  Raises:
    ValueError: if the input is not valid.
  """
  text = interval_str.strip()

  # (operator, is_upper_bound, inclusive). Two-character operators come
  # first so "<=5" is not read as "<" followed by "=5".
  one_sided_forms = (
      ("<=", True, True),
      ("<", True, False),
      (">=", False, True),
      (">", False, False),
  )
  for operator, is_upper_bound, inclusive in one_sided_forms:
    if not text.startswith(operator):
      continue
    value = text[len(operator):].strip()
    if not _NUMBER_PATTERN.match(value):
      raise ValueError(
          "Invalid value string after %s in '%s'" % (operator, text))
    if is_upper_bound:
      return Interval(start=None, start_included=False,
                      end=value, end_included=inclusive)
    return Interval(start=value, start_included=inclusive,
                    end=None, end_included=False)

  if not (text.startswith(("[", "(")) and text.endswith(("]", ")"))):
    raise ValueError(
        "Invalid interval format: %s. Valid formats are: [min, max], "
        "(min, max), <max, >min" % text)
  endpoints = text[1:-1].split(",")
  if len(endpoints) != 2:
    raise ValueError(
        "Incorrect interval format: %s. Interval should specify two values: "
        "[min, max] or (min, max)." % text)

  start_item, end_item = (item.strip() for item in endpoints)
  if not _NUMBER_PATTERN.match(start_item):
    raise ValueError("Invalid first item in interval: '%s'" % start_item)
  if not _NUMBER_PATTERN.match(end_item):
    raise ValueError("Invalid second item in interval: '%s'" % end_item)

  return Interval(start=start_item,
                  start_included=text.startswith("["),
                  end=end_item,
                  end_included=text.endswith("]"))

403 

404 

def parse_readable_size_str(size_str):
  """Convert a human-readable size string to a number of bytes.

  A plain string of digits is taken as bytes. Otherwise the last character
  must be one of the binary units "k" (1024), "M" (1024**2) or "G" (1024**3),
  optionally followed by "B". The value before the unit may be fractional.

  Args:
    size_str: (`str`) A human-readable str representing a number of bytes
      (e.g., "0", "1023", "1.1kB", "24 MB", "23GB", "100 G").

  Returns:
    (`int`) The parsed number of bytes, truncated toward zero.

  Raises:
    ValueError: on failure to parse the input `size_str`.
  """
  unit_multipliers = {"k": 1024, "M": 1048576, "G": 1073741824}

  text = size_str.strip()
  if text.endswith("B"):
    text = text[:-1]

  if text.isdigit():
    return int(text)

  unit = text[-1:]
  if unit in unit_multipliers:
    return int(float(text[:-1]) * unit_multipliers[unit])

  raise ValueError("Failed to parsed human-readable byte size str: \"%s\"" %
                   text)

437 

438 

def parse_readable_time_str(time_str):
  """Parses a time string in the format N, Nus, Nms, Ns.

  Args:
    time_str: (`str`) string consisting of a time value optionally followed
      by 'us', 'ms', or 's' suffix. If suffix is not specified, value is
      assumed to be in microseconds. (e.g. 100us, 8ms, 5s, 100).

  Returns:
    (`int`) Microseconds value, truncated toward zero.

  Raises:
    ValueError: if the numeric part is not a float or is negative.
  """
  def _non_negative_float(text):
    number = float(text)
    if number < 0:
      raise ValueError(
          "Invalid time %s. Time value must be positive." % text)
    return number

  stripped = time_str.strip()
  # (suffix, microseconds per unit). "us" and "ms" must be tried before "s"
  # because both also end with "s".
  for suffix, scale in (("us", 1), ("ms", 1e3), ("s", 1e6)):
    if stripped.endswith(suffix):
      return int(_non_negative_float(stripped[:-len(suffix)]) * scale)
  return int(_non_negative_float(stripped))

465 

466 

def evaluate_tensor_slice(tensor, tensor_slicing):
  """Call eval on the slicing of a tensor, with validation.

  Args:
    tensor: The tensor value; presumably a numpy ndarray (any object that
      supports indexing by a tuple of ints/slices works).
    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
      None, no slicing will be performed on the tensor.

  Returns:
    The sliced tensor, or `tensor` itself when `tensor_slicing` is None.

  Raises:
    ValueError: If tensor_slicing is not a valid numpy ndarray slicing str.
  """
  # Honor the documented None contract; validate_slicing_string(None) would
  # otherwise raise TypeError inside the regex call.
  if tensor_slicing is None:
    return tensor

  if not validate_slicing_string(tensor_slicing):
    raise ValueError("Invalid tensor-slicing string.")

  return tensor[_parse_slices(tensor_slicing)]

488 

489 

490def get_print_tensor_argparser(description): 

491 """Get an ArgumentParser for a command that prints tensor values. 

492 

493 Examples of such commands include print_tensor and print_feed. 

494 

495 Args: 

496 description: Description of the ArgumentParser. 

497 

498 Returns: 

499 An instance of argparse.ArgumentParser. 

500 """ 

501 

502 ap = argparse.ArgumentParser( 

503 description=description, usage=argparse.SUPPRESS) 

504 ap.add_argument( 

505 "tensor_name", 

506 type=str, 

507 help="Name of the tensor, followed by any slicing indices, " 

508 "e.g., hidden1/Wx_plus_b/MatMul:0, " 

509 "hidden1/Wx_plus_b/MatMul:0[1, :]") 

510 ap.add_argument( 

511 "-n", 

512 "--number", 

513 dest="number", 

514 type=int, 

515 default=-1, 

516 help="0-based dump number for the specified tensor. " 

517 "Required for tensor with multiple dumps.") 

518 ap.add_argument( 

519 "-r", 

520 "--ranges", 

521 dest="ranges", 

522 type=str, 

523 default="", 

524 help="Numerical ranges to highlight tensor elements in. " 

525 "Examples: -r 0,1e-8, -r [-0.1,0.1], " 

526 "-r \"[[-inf, -0.1], [0.1, inf]]\"") 

527 ap.add_argument( 

528 "-a", 

529 "--all", 

530 dest="print_all", 

531 action="store_true", 

532 help="Print the tensor in its entirety, i.e., do not use ellipses.") 

533 ap.add_argument( 

534 "-s", 

535 "--numeric_summary", 

536 action="store_true", 

537 help="Include summary for non-empty tensors of numeric (int*, float*, " 

538 "complex*) and Boolean types.") 

539 ap.add_argument( 

540 "-w", 

541 "--write_path", 

542 type=str, 

543 default="", 

544 help="Path of the numpy file to write the tensor data to, using " 

545 "numpy.save().") 

546 return ap