Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/debug/cli/command_parser.py: 13%
204 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Command parsing module for TensorFlow Debugger (tfdbg)."""
16import argparse
17import ast
18import re
19import sys
# Matches a bracket-enclosed stretch such as "[1, 2]"; the body may not contain
# a closing bracket, so nested brackets are not recognized.
_BRACKETS_PATTERN = re.compile(r"\[[^\]]*\]")
# Matches a stretch enclosed in a pair of double quotes or a pair of single
# quotes; the body may not contain the enclosing quote character.
_QUOTES_PATTERN = re.compile(r"(\"[^\"]*\"|\'[^\']*\')")
# Matches a run of one or more whitespace characters.
_WHITESPACE_PATTERN = re.compile(r"\s+")

# Matches an optionally signed decimal number with an optional fractional part
# and an optional exponent, e.g. "3", "-2.5", ".98", "1e-8".
_NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?")
class Interval(object):
  """Represents an interval between a start and end value.

  Attributes:
    start: Lower bound of the interval.
    start_included: (bool) Whether `start` itself belongs to the interval.
    end: Upper bound of the interval.
    end_included: (bool) Whether `end` itself belongs to the interval.
  """

  # Instances define value-based equality and are mutable, so they are
  # deliberately unhashable.
  __hash__ = None

  def __init__(self, start, start_included, end, end_included):
    self.start = start
    self.start_included = start_included
    self.end = end
    self.end_included = end_included

  def contains(self, value):
    """Check whether a value lies inside the interval.

    Args:
      value: The value to test; it must be comparable with `start` and `end`.

    Returns:
      (bool) True if `value` is within the bounds, honoring whether each bound
      is inclusive or exclusive.
    """
    if value < self.start or (value == self.start and
                              not self.start_included):
      return False
    if value > self.end or (value == self.end and not self.end_included):
      return False
    return True

  def __eq__(self, other):
    """Compare two intervals by their bounds and inclusiveness flags."""
    if not isinstance(other, Interval):
      # Let Python fall back to its default handling instead of raising
      # AttributeError on objects that lack the interval attributes.
      return NotImplemented
    return (self.start == other.start and
            self.start_included == other.start_included and
            self.end == other.end and
            self.end_included == other.end_included)

  def __repr__(self):
    return "Interval(start=%r, start_included=%r, end=%r, end_included=%r)" % (
        self.start, self.start_included, self.end, self.end_included)
def parse_command(command):
  """Parse command string into a list of arguments.

  - Disregards whitespace inside quotes (double or single) and brackets.
  - Strips paired leading and trailing quotes in arguments.
  - Splits the command at whitespace.

  Nested quotes and brackets are not handled.

  Args:
    command: (str) Input command.

  Returns:
    (list of str) List of arguments. An empty or all-whitespace command
    yields an empty list.
  """
  command = command.strip()
  if not command:
    return []

  whitespaces_intervals = [
      f.span() for f in _WHITESPACE_PATTERN.finditer(command)
  ]
  if not whitespaces_intervals:
    # A command without any whitespace is returned verbatim as its only
    # argument (enclosing quotes, if any, are kept).
    return [command]

  # Stretches in which whitespace must not split the command. Built once
  # instead of on every loop iteration.
  protected_intervals = (
      [f.span() for f in _BRACKETS_PATTERN.finditer(command)] +
      [f.span() for f in _QUOTES_PATTERN.finditer(command)])

  arguments = []
  idx0 = 0
  # The sentinel (len(command), None) flushes the final argument.
  for start, end in whitespaces_intervals + [(len(command), None)]:
    # Skip whitespace stretches enclosed in brackets or quotes.
    if any(lo < start < hi for lo, hi in protected_intervals):
      continue

    argument = command[idx0:start]

    # Strip leading and trailing quotes only if they are truly paired. The
    # length check keeps a lone quote character from being treated as its own
    # partner and collapsing into an empty argument.
    if (len(argument) >= 2 and argument[0] == argument[-1] and
        argument[0] in ("\"", "'")):
      argument = argument[1:-1]
    arguments.append(argument)
    idx0 = end

  return arguments
def extract_output_file_path(args):
  """Extract output file path from command arguments.

  Recognized redirect forms at the end of the arguments are:
  `... > path`, `... >path`, `... arg>path` and `... arg> path`.
  A trailing `>N`-style argument that parses as an interval and directly
  follows a flag (an argument starting with "-") is treated as the flag's
  value, not as a redirect. An argument in which ">" is directly preceded by
  "=" is likewise not treated as a redirect.

  The list passed in is never modified in place; a new list is returned
  whenever the arguments change.

  Args:
    args: (list of str) command arguments.

  Returns:
    (list of str) Command arguments with the output file path part stripped.
    (str or None) Output file path (if any).

  Raises:
    SyntaxError: If there is no file path after the last ">" character.
  """
  if args and args[-1].endswith(">"):
    raise SyntaxError("Redirect file path is empty")
  elif args and args[-1].startswith(">"):
    # Only the interval parsing can raise ValueError, so keep the try minimal.
    try:
      _parse_interval(args[-1])
      is_interval = True
    except ValueError:
      is_interval = False

    if is_interval and len(args) > 1 and args[-2].startswith("-"):
      # E.g. `-t >100ms`: the ">..." item is the value of the preceding flag.
      output_file_path = None
    else:
      output_file_path = args[-1][1:]
      args = args[:-1]
  elif len(args) > 1 and args[-2] == ">":
    output_file_path = args[-1]
    args = args[:-2]
  elif args and args[-1].count(">") == 1:
    gt_index = args[-1].index(">")
    if gt_index > 0 and args[-1][gt_index - 1] == "=":
      output_file_path = None
    else:
      output_file_path = args[-1][gt_index + 1:]
      # Build a new list rather than assigning into the caller's list.
      args = args[:-1] + [args[-1][:gt_index]]
  elif len(args) > 1 and args[-2].endswith(">"):
    output_file_path = args[-1]
    args = args[:-1]
    args[-1] = args[-1][:-1]
  else:
    output_file_path = None

  return args, output_file_path
def parse_tensor_name_with_slicing(in_str):
  """Split a tensor name from an optional trailing slicing string.

  A slicing suffix is recognized only when the input contains exactly one "["
  and ends with "]".

  Args:
    in_str: (str) Tensor name, optionally followed by a slicing string, e.g.
      "hidden/weights/Variable:0" or "hidden/weights/Variable:0[1, :]".

  Returns:
    (str) name of the tensor
    (str) slicing string including its brackets, or "" if there is none.
  """
  has_slicing = in_str.endswith("]") and in_str.count("[") == 1
  if not has_slicing:
    return in_str, ""

  bracket_pos = in_str.find("[")
  return in_str[:bracket_pos], in_str[bracket_pos:]
def validate_slicing_string(slicing_string):
  """Validate a slicing string.

  Check if the input string consists of an opening bracket, one or more
  digits, commas, colons or whitespace characters, and a closing bracket.

  Args:
    slicing_string: (str) Input slicing string to be validated.

  Returns:
    (bool) True if and only if the slicing string is valid.
  """
  # fullmatch anchors at the true end of the string; with `$` a trailing
  # newline after the closing bracket would also be accepted.
  return re.fullmatch(r"\[[\d,\s:]+\]", slicing_string) is not None
186def _parse_slices(slicing_string):
187 """Construct a tuple of slices from the slicing string.
189 The string must be a valid slicing string.
191 Args:
192 slicing_string: (str) Input slicing string to be parsed.
194 Returns:
195 tuple(slice1, slice2, ...)
197 Raises:
198 ValueError: If tensor_slicing is not a valid numpy ndarray slicing str.
199 """
200 parsed = []
201 for slice_string in slicing_string[1:-1].split(","):
202 indices = slice_string.split(":")
203 if len(indices) == 1:
204 parsed.append(int(indices[0].strip()))
205 elif 2 <= len(indices) <= 3:
206 parsed.append(
207 slice(*[
208 int(index.strip()) if index.strip() else None for index in indices
209 ]))
210 else:
211 raise ValueError("Invalid tensor-slicing string.")
212 return tuple(parsed)
def parse_indices(indices_string):
  """Turn a comma-separated string of integers into a list of ints.

  For example, "[1, 2, 3]" and "1,2,3" both yield [1, 2, 3].

  Args:
    indices_string: (str) a string representing indices. Can optionally be
      surrounded by a pair of brackets. Whitespace anywhere is ignored.

  Returns:
    (list of int): Parsed indices.
  """
  # Remove all whitespace first so the bracket check sees the real ends.
  compact = re.sub(r"\s+", "", indices_string)

  # Drop one enclosing pair of brackets, if present.
  if compact[:1] == "[" and compact[-1:] == "]":
    compact = compact[1:-1]

  return list(map(int, compact.split(",")))
def parse_ranges(range_string):
  """Parse a string representing numerical range(s).

  Args:
    range_string: (str) A string representing a numerical range or a list of
      them. For example:
        "[-1.0,1.0]", "[-inf, 0]", "[[-inf, -1.0], [1.0, inf]]"
      A bare comma-separated pair such as "0,1e-8" is accepted as a single
      range as well.

  Returns:
    (list of list of float) A list of numerical ranges parsed from the input
    string. Empty input (or an empty list literal) yields an empty list.
    Occurrences of "inf" are represented by the largest finite float.

  Raises:
    ValueError: If the input doesn't represent a range or a list of ranges.
  """
  range_string = range_string.strip()
  if not range_string:
    return []

  if "inf" in range_string:
    # ast.literal_eval cannot evaluate the bare name `inf`, so substitute the
    # largest finite float literal for it.
    range_string = re.sub(r"inf", repr(sys.float_info.max), range_string)

  ranges = ast.literal_eval(range_string)
  if isinstance(ranges, tuple):
    # "0,1e-8" evaluates to a tuple; treat it like the bracketed form.
    ranges = list(ranges)
  # A flat list is a single range; wrap it. The emptiness check keeps "[]"
  # from raising IndexError on ranges[0].
  if isinstance(ranges, list) and ranges and not isinstance(ranges[0], list):
    ranges = [ranges]

  # Verify that ranges is a list of list of numbers.
  for item in ranges:
    if len(item) != 2:
      raise ValueError("Incorrect number of elements in range")
    elif not isinstance(item[0], (int, float)):
      raise ValueError("Incorrect type in the 1st element of range: %s" %
                       type(item[0]))
    elif not isinstance(item[1], (int, float)):
      # Report the type of the offending (second) element.
      raise ValueError("Incorrect type in the 2nd element of range: %s" %
                       type(item[1]))

  return ranges
def parse_memory_interval(interval_str):
  """Convert a human-readable memory interval to an `Interval` in bytes.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[10kB, 20kB]", "<100M", ">100G"). Only the units "kB", "MB", "GB"
      are supported. The "B" character at the end of a size may be omitted.

  Returns:
    `Interval` object where start and end are in bytes. A missing start
    defaults to 0 and a missing end to infinity.

  Raises:
    ValueError: if the input is not valid.
  """
  parsed = _parse_interval(interval_str)

  # One-sided intervals leave one bound unset; fall back to the widest value.
  lower = parse_readable_size_str(parsed.start) if parsed.start else 0
  upper = (
      parse_readable_size_str(parsed.end) if parsed.end else float("inf"))

  if lower > upper:
    raise ValueError(
        "Invalid interval %s. Start of interval must be less than or equal "
        "to end of interval." % interval_str)
  return Interval(lower, parsed.start_included, upper, parsed.end_included)
def parse_time_interval(interval_str):
  """Convert a human-readable time interval to an `Interval` in microseconds.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[10us, 20us]", "<100s", ">100ms"). Supported time suffixes are
      us, ms, s.

  Returns:
    `Interval` object where start and end are in microseconds. A missing start
    defaults to 0 and a missing end to infinity.

  Raises:
    ValueError: if the input is not valid.
  """
  parsed = _parse_interval(interval_str)

  # One-sided intervals leave one bound unset; fall back to the widest value.
  lower = parse_readable_time_str(parsed.start) if parsed.start else 0
  upper = (
      parse_readable_time_str(parsed.end) if parsed.end else float("inf"))

  if lower > upper:
    raise ValueError(
        "Invalid interval %s. Start must be before end of interval." %
        interval_str)
  return Interval(lower, parsed.start_included, upper, parsed.end_included)
def _parse_interval(interval_str):
  """Convert a human-readable interval to an `Interval` of raw strings.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[1M, 2M]", "<100k", ">100ms"). The items following the ">", "<",
      ">=" and "<=" signs have to start with a number (e.g., 3.0, -2, .98).
      The same requirement applies to the items in the parentheses or brackets.

  Returns:
    Interval object whose start and end are the unparsed item strings; start
    is None for "<N"/"<=N" and end is None for ">N"/">=N".

  Raises:
    ValueError: if the input is not valid.
  """
  text = interval_str.strip()

  # One-sided forms: (prefix, bounds the end?, bound included?, error text).
  # Two-character prefixes come first so that "<=" is not taken for "<".
  one_sided_forms = (
      ("<=", True, True, "Invalid value string after <= in '%s'"),
      ("<", True, False, "Invalid value string after < in '%s'"),
      (">=", False, True, "Invalid value string after >= in '%s'"),
      (">", False, False, "Invalid value string after > in '%s'"),
  )
  for prefix, bounds_end, included, error_template in one_sided_forms:
    if not text.startswith(prefix):
      continue
    bound = text[len(prefix):].strip()
    if not _NUMBER_PATTERN.match(bound):
      raise ValueError(error_template % text)
    if bounds_end:
      return Interval(start=None, start_included=False,
                      end=bound, end_included=included)
    return Interval(start=bound, start_included=included,
                    end=None, end_included=False)

  # Two-sided form: "[min, max]", "(min, max)" or a mix of both delimiters.
  if not (text.startswith(("[", "(")) and text.endswith(("]", ")"))):
    raise ValueError(
        "Invalid interval format: %s. Valid formats are: [min, max], "
        "(min, max), <max, >min" % text)

  items = text[1:-1].split(",")
  if len(items) != 2:
    raise ValueError(
        "Incorrect interval format: %s. Interval should specify two values: "
        "[min, max] or (min, max)." % text)

  start_item = items[0].strip()
  if not _NUMBER_PATTERN.match(start_item):
    raise ValueError("Invalid first item in interval: '%s'" % start_item)
  end_item = items[1].strip()
  if not _NUMBER_PATTERN.match(end_item):
    raise ValueError("Invalid second item in interval: '%s'" % end_item)

  # Square brackets mark an inclusive bound, parentheses an exclusive one.
  return Interval(start=start_item,
                  start_included=(text[0] == "["),
                  end=end_item,
                  end_included=(text[-1] == "]"))
def parse_readable_size_str(size_str):
  """Convert a human-readable size string to a number of bytes.

  Only the units "kB", "MB", "GB" are supported. The "B" character at the end
  of the input `str` may be omitted. A value without a unit must consist of
  digits only.

  Args:
    size_str: (`str`) A human-readable str representing a number of bytes
      (e.g., "0", "1023", "1.1kB", "24 MB", "23GB", "100 G").

  Returns:
    (`int`) The parsed number of bytes.

  Raises:
    ValueError: on failure to parse the input `size_str`.
  """
  text = size_str.strip()
  if text.endswith("B"):
    text = text[:-1]

  # Plain byte count without a unit.
  if text.isdigit():
    return int(text)

  # Units are powers of 1024.
  multipliers = {"k": 1024, "M": 1048576, "G": 1073741824}
  multiplier = multipliers.get(text[-1:])
  if multiplier is None:
    raise ValueError("Failed to parsed human-readable byte size str: \"%s\"" %
                     text)
  return int(float(text[:-1]) * multiplier)
def parse_readable_time_str(time_str):
  """Parse a time string in the format N, Nus, Nms or Ns into microseconds.

  Args:
    time_str: (`str`) string consisting of a numeric time value optionally
      followed by 'us', 'ms', or 's' suffix. If suffix is not specified,
      value is assumed to be in microseconds. (e.g. 100us, 8ms, 5s, 100).

  Returns:
    (`int`) Microseconds value, truncated toward zero.

  Raises:
    ValueError: If the numeric part cannot be parsed or is negative.
  """

  def non_negative(number_text):
    number = float(number_text)
    if number < 0:
      raise ValueError(
          "Invalid time %s. Time value must be positive." % number_text)
    return number

  text = time_str.strip()
  # "us" and "ms" must be tried before the bare "s" suffix they also end with.
  for suffix, micros_per_unit in (("us", 1), ("ms", 1e3), ("s", 1e6)):
    if text.endswith(suffix):
      return int(non_negative(text[:-len(suffix)]) * micros_per_unit)
  return int(non_negative(text))
def evaluate_tensor_slice(tensor, tensor_slicing):
  """Apply a validated slicing string to a tensor value.

  Args:
    tensor: (numpy ndarray) The tensor value.
    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
      None, no slicing will be performed on the tensor.

  Returns:
    (numpy ndarray) The sliced tensor, or `tensor` itself if `tensor_slicing`
    is None.

  Raises:
    ValueError: If tensor_slicing is not a valid numpy ndarray slicing str.
  """
  # Honor the documented contract: None means "no slicing". Without this
  # guard the validation below would fail with a TypeError on None.
  if tensor_slicing is None:
    return tensor

  if not validate_slicing_string(tensor_slicing):
    raise ValueError("Invalid tensor-slicing string.")

  return tensor[_parse_slices(tensor_slicing)]
def get_print_tensor_argparser(description):
  """Build the ArgumentParser shared by commands that print tensor values.

  Examples of such commands include print_tensor and print_feed.

  Args:
    description: Description of the ArgumentParser.

  Returns:
    An instance of argparse.ArgumentParser.
  """
  parser = argparse.ArgumentParser(
      description=description, usage=argparse.SUPPRESS)

  # (flags, keyword options) for each argument, in registration order.
  argument_specs = (
      (("tensor_name",), dict(
          type=str,
          help="Name of the tensor, followed by any slicing indices, "
          "e.g., hidden1/Wx_plus_b/MatMul:0, "
          "hidden1/Wx_plus_b/MatMul:0[1, :]")),
      (("-n", "--number"), dict(
          dest="number",
          type=int,
          default=-1,
          help="0-based dump number for the specified tensor. "
          "Required for tensor with multiple dumps.")),
      (("-r", "--ranges"), dict(
          dest="ranges",
          type=str,
          default="",
          help="Numerical ranges to highlight tensor elements in. "
          "Examples: -r 0,1e-8, -r [-0.1,0.1], "
          "-r \"[[-inf, -0.1], [0.1, inf]]\"")),
      (("-a", "--all"), dict(
          dest="print_all",
          action="store_true",
          help="Print the tensor in its entirety, i.e., do not use ellipses.")),
      (("-s", "--numeric_summary"), dict(
          action="store_true",
          help="Include summary for non-empty tensors of numeric (int*, float*, "
          "complex*) and Boolean types.")),
      (("-w", "--write_path"), dict(
          type=str,
          default="",
          help="Path of the numpy file to write the tensor data to, using "
          "numpy.save().")),
  )
  for flags, options in argument_specs:
    parser.add_argument(*flags, **options)
  return parser