Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tensor_tracer_flags.py: 35%
261 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ========================================================================
15"""Utilities to handle tensor tracer parameters."""
18import os
19import os.path
20import re
21from absl import flags
22from tensorflow.python.ops import linalg_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.platform import tf_logging as logging
# Trace modes that record a reduced signature of each traced tensor.
TRACE_MODE_NAN_INF = 'nan-inf'
TRACE_MODE_NORM = 'norm'
TRACE_MODE_MAX_ABS = 'max-abs'
TRACE_MODE_HISTORY = 'history'
# Summary mode collects a finite set of signatures for each traced tensor
# (such as norm, max, min, mean) and dumps them using tb summaries.
TRACE_MODE_SUMMARY = 'summary'

# Trace modes that deal with tensor contents.
TRACE_MODE_PART_TENSOR = 'part-tensor'
TRACE_MODE_FULL_TENSOR = 'full-tensor'
# Full tensor summary mode dumps the whole values of the traced tensors,
# without any processing on them, using tb summaries.
TRACE_MODE_FULL_TENSOR_SUMMARY = 'full_tensor_summary'

# Accepted values of the `submode` flag.
_SUBMODE_DETAILED = 'detailed'
_SUBMODE_BRIEF = 'brief'
# Regular expressions that tokenize the TENSOR_TRACER_FLAGS string into
# `--name='value'`, `--name="value"`, `--name=value` and bare `--name` forms.
# Flag names may not contain '=' or whitespace. Excluding whitespace matters:
# otherwise a bare flag followed by another flag (`--enable --trace_mode=norm`)
# is swallowed into a single bogus name ("enable --trace_mode"), and trailing
# whitespace after a bare flag ends up inside the name ("enable ").
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=\s]+)='([^']*)'")
_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)="([^"]*)"')
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)=(\S*)')
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=\s]+)\s*')
# Environment variable holding the Tensor Tracer flag string.
FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'

# General switches.
FLAG_NAME_ENABLE = 'enable'
FLAG_NAME_TRACE_MODE = 'trace_mode'
FLAG_NAME_SUBMODE = 'submode'
FLAG_NAME_TRACE_LEVEL = 'trace_level'
FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar'
FLAG_NAME_OP_RANGE = 'op_range'

# Op filtering by name / type (comma separated regular expressions).
FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'

# Output locations.
FLAG_NAME_TRACE_DIR = 'trace_dir'
FLAG_NAME_REPORT_FILE = 'report_file'
FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory'
# Folder where the graphs before and after the tensor tracer updates are
# dumped.
FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'

# Summary collection.
FLAG_NAME_SUMMARY_SIGNATURES = 'signatures'
FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core'
FLAG_FLUSH_SUMMARY = 'flush_summaries'

# Miscellaneous.
FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache'
FLAG_NAME_INSPECT_TRACE = 'inspect_trace'

# Every flag name accepted in FLAGS_ENV_VAR. The order is kept stable because
# the list is printed verbatim when an unknown flag is rejected.
VALID_FLAG_NAMES = [
    FLAG_NAME_ENABLE,
    FLAG_NAME_TRACE_MODE,
    FLAG_NAME_TRACE_SCALAR_OPS,
    FLAG_NAME_SUBMODE,
    FLAG_NAME_EXCLUDED_OPNAMES,
    FLAG_NAME_EXCLUDED_OPTYPES,
    FLAG_NAME_INCLUDED_OPNAMES,
    FLAG_NAME_INCLUDED_OPTYPES,
    FLAG_NAME_TRACE_DIR,
    FLAG_NAME_REPORT_FILE,
    FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
    FLAG_NAME_OP_RANGE,
    FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS,
    FLAG_NAME_TRACE_LEVEL,
    FLAG_NAME_SUMMARY_SIGNATURES,
    FLAG_NAME_SUMMARY_PER_CORE,
    FLAG_NAME_TEMP_CACHE_VAR,
    FLAG_NAME_FINGERPRINT_DIR,
    FLAG_NAME_INSPECT_TRACE,
    FLAG_FLUSH_SUMMARY,
]
# `--op_range=<start>:<end>` syntax.
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
# Environment variable naming the test-undeclared-outputs directory.
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'

# Used when --trace_level is not given.
_TT_DEFAULT_TRACE_LEVEL = 3

# Short signature names; the public TT_SUMMARY_* names prefix them with
# `tensor_tracer_`.
_TT_PREFIX = 'tensor_tracer'
_TT_NORM = 'norm'
_TT_MAX = 'max'
_TT_MAX_ABS = 'max-abs'
_TT_MIN = 'min'
_TT_SPARSITY = 'sparsity'
_TT_MEAN = 'mean'
_TT_VAR = 'var'
_TT_SIZE = 'size'

TT_SUMMARY_NORM = _TT_PREFIX + '_' + _TT_NORM
TT_SUMMARY_MAX = _TT_PREFIX + '_' + _TT_MAX
TT_SUMMARY_MAX_ABS = _TT_PREFIX + '_' + _TT_MAX_ABS
TT_SUMMARY_MIN = _TT_PREFIX + '_' + _TT_MIN
TT_SUMMARY_SPARSITY = _TT_PREFIX + '_' + _TT_SPARSITY
TT_SUMMARY_MEAN = _TT_PREFIX + '_' + _TT_MEAN
TT_SUMMARY_VAR = _TT_PREFIX + '_' + _TT_VAR
TT_SUMMARY_SIZE = _TT_PREFIX + '_' + _TT_SIZE

# All supported summary signatures, in a fixed order.
TT_SUMMARY_SIGNATURES = (
    TT_SUMMARY_NORM,
    TT_SUMMARY_MAX,
    TT_SUMMARY_MIN,
    TT_SUMMARY_SPARSITY,
    TT_SUMMARY_MEAN,
    TT_SUMMARY_VAR,
    TT_SUMMARY_SIZE,
    TT_SUMMARY_MAX_ABS,
)
FLAGS = flags.FLAGS

# Command-line flags registered with absl (separate from the
# TENSOR_TRACER_FLAGS environment variable parsed by TTParameters).
DELTA_THRESHOLD = flags.DEFINE_float(
    'delta_threshold', 0.5,
    'Log if history based diff crosses this threshold.')

TT_CHECK_FILTER = flags.DEFINE_bool(
    'tt_check_filter', False,
    'Terminate early to check op name filtering.')

TT_SINGLE_CORE_SUMMARIES = flags.DEFINE_bool(
    'tt_single_core_summaries', False,
    'Report single core metric and avoid aggregation.')
class TTParameters(object):
  """A class that handles the parameters of Tensor Tracer.

  Flags are read from the string stored under FLAGS_ENV_VAR
  (TENSOR_TRACER_FLAGS), either in the process environment or in the mapping
  passed to the constructor. Each flag is written as `--name`, `--name=value`,
  `--name='value'` or `--name="value"`. All settings are parsed once in
  __init__ and exposed as public attributes.
  """

  def __init__(self, env=None):
    """Parses and validates all Tensor Tracer flags.

    Args:
      env: optional mapping consulted instead of os.environ for
        TENSOR_TRACER_FLAGS and TEST_UNDECLARED_OUTPUTS_DIR. When None (the
        default), os.environ is used. An explicitly passed empty mapping is
        honored and means "no flags set".

    Raises:
      ValueError: if an unknown flag name, trace mode or submode is given; if
        --trace_dir and --use_test_undeclared_outputs_dir are combined; if
        --report_file cannot be resolved against the test outputs directory;
        or if Tensor Tracer is enabled in a summary mode without --trace_dir.
      re.error: if an included/excluded op name/type filter is not a valid
        regular expression.
    """
    # Only None falls back to the process environment; a caller-supplied empty
    # dict (e.g. in tests) must not silently pick up the real environment.
    self._env = os.environ if env is None else env
    self._validate_flag_names()
    self.trace_mode = self._get_trace_mode()
    self.submode = self._get_submode()
    self.trace_dir = self._get_trace_dir()
    self.report_file_path = self._get_report_filepath()
    self.op_range = self._get_op_range()
    self.excluded_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPNAMES)
    self.excluded_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPTYPES)

    self.included_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPNAMES)
    self.included_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPTYPES)

    self.trace_scalar_ops = self.is_flag_on(FLAG_NAME_TRACE_SCALAR_OPS)
    self.use_compact_trace = self.trace_mode in (TRACE_MODE_NAN_INF,
                                                 TRACE_MODE_NORM,
                                                 TRACE_MODE_HISTORY,
                                                 TRACE_MODE_MAX_ABS,
                                                 TRACE_MODE_SUMMARY)
    self.use_temp_cache_var = self.is_flag_on(FLAG_NAME_TEMP_CACHE_VAR)
    self.inspect_trace = self.is_flag_on(FLAG_NAME_INSPECT_TRACE)
    self.use_fingerprint_subdir = self.is_flag_on(FLAG_NAME_FINGERPRINT_DIR)

    _, self.graph_dump_path = self.get_flag_value(
        FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS)
    self.trace_level = self._get_flag_int_value(FLAG_NAME_TRACE_LEVEL,
                                                _TT_DEFAULT_TRACE_LEVEL)
    self.summary_signatures = self._get_summary_signatures()
    self.collect_summary_per_core = self.is_flag_on(FLAG_NAME_SUMMARY_PER_CORE)
    # TODO(b/199284834): Will be resolved with referenced bug.
    if self.collect_summary_per_core:
      logging.warning('Aggregate signatures are approximate for mean, variance'
                      ' and sparsity.')
    self.flush_summaries_with_outside_compile = self.is_flag_on(
        FLAG_FLUSH_SUMMARY)
    # Do not produce errors or warnings if Tensor Tracer is not enabled.
    if self.is_enabled():
      self._check_flag_errors()

  def _check_flag_errors(self):
    """Raises ValueError if a summary trace mode is used without --trace_dir."""
    if self.trace_mode in (TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY):
      if not self.trace_dir:
        raise ValueError('trace_dir must be explicitly provided in '
                         'TENSOR_TRACER_FLAGS when summary mode is used.')

  def _get_report_filepath(self):
    """Returns the path of the output report file.

    When --use_test_undeclared_outputs_dir is on and --report_file is given,
    the report path must be relative and is joined onto the directory named by
    TEST_UNDECLARED_OUTPUTS_DIR. Otherwise the raw flag value is returned
    (None when the flag is missing or has no value).

    Raises:
      ValueError: if the report path is absolute, or the outputs-directory
        environment variable is not defined, while
        use_test_undeclared_outputs_dir is set.
    """
    found, report_file_path = self.get_flag_value(FLAG_NAME_REPORT_FILE)
    if found and report_file_path and self.use_test_undeclared_outputs_dir():
      if os.path.isabs(report_file_path):
        raise ValueError('If use_test_undeclared_outputs_dir is set, '
                         'report_file_path cannot be an absolute path (%s)'
                         % report_file_path)
      outputs_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
      if outputs_dir is None:
        # Previously surfaced as an opaque TypeError from os.path.join.
        raise ValueError('use_test_undeclared_outputs_dir is set but the %s '
                         'environment variable is not defined.'
                         % _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
      report_file_path = os.path.join(outputs_dir, report_file_path)
    return report_file_path

  def _get_op_range(self):
    """Returns the (start, end) index range of the ops considered for tracing.

    The flag value is expected to start with `<start>:<end>`. (-1, -1) is
    returned when the flag is missing, empty or malformed; it means that all
    ops are included.
    """
    found, op_range = self.get_flag_value(FLAG_NAME_OP_RANGE)
    match = _OP_RANGE_PAT.match(op_range) if found and op_range else None
    if not match:
      return (-1, -1)  # This means including all ops.
    return (int(match.group(1)), int(match.group(2)))

  def _get_trace_dir(self):
    """Returns the directory where traces are written (may be None).

    Raises:
      ValueError: if --trace_dir and --use_test_undeclared_outputs_dir are
        both given.
    """
    found, trace_dir = self.get_flag_value(FLAG_NAME_TRACE_DIR)
    use_outputs_dir = self.use_test_undeclared_outputs_dir()
    if found and trace_dir and use_outputs_dir:
      raise ValueError(
          'Cannot use --%s and --%s at the same time' %
          (FLAG_NAME_TRACE_DIR, FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
    if use_outputs_dir:
      trace_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
    return trace_dir

  def _get_trace_mode(self):
    """Returns the validated trace mode; defaults to TRACE_MODE_NORM.

    Raises:
      ValueError: if the given trace mode is not supported.
    """
    found, trace_mode = self.get_flag_value(FLAG_NAME_TRACE_MODE)
    if not found or not trace_mode:
      trace_mode = TRACE_MODE_NORM
    valid_trace_modes = [
        TRACE_MODE_NAN_INF, TRACE_MODE_PART_TENSOR, TRACE_MODE_FULL_TENSOR,
        TRACE_MODE_NORM, TRACE_MODE_MAX_ABS,
        TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY,
        TRACE_MODE_HISTORY
    ]
    if trace_mode not in valid_trace_modes:
      raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer. '
                       'Valid trace modes are: %s' % (trace_mode,
                                                      valid_trace_modes))
    return trace_mode

  def is_brief_mode(self):
    """Returns True if the submode is 'brief'."""
    return self.submode == _SUBMODE_BRIEF

  def _get_submode(self):
    """Returns the validated submode; defaults to 'detailed'.

    Raises:
      ValueError: if the given submode is not supported.
    """
    found, submode = self.get_flag_value(FLAG_NAME_SUBMODE)
    if not found or not submode:
      submode = _SUBMODE_DETAILED
    valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF]
    if submode not in valid_submodes:
      raise ValueError('Invalid submode "%s" given to the Tensor_Tracer. '
                       'Valid submodes are: %s' % (submode,
                                                   valid_submodes))
    return submode

  @staticmethod
  def match_next_flag(tt_flags, pos):
    """Returns the match for the next TensorTracer flag.

    Args:
      tt_flags: a string that contains the flags.
      pos: where in flags to start the search.

    Returns:
      A pair where the first element is the regular-expression
      match found (or None) and the second element indicates if the match
      has a value.
    """
    # Quoted forms are tried before the unquoted one so that values containing
    # whitespace are captured whole.
    match = _FLAG_DOUBLE_QUOTE_PAT.match(tt_flags, pos)
    if match:
      return match, True
    match = _FLAG_SINGLE_QUOTE_PAT.match(tt_flags, pos)
    if match:
      return match, True
    match = _FLAG_NO_QUOTE_PAT.match(tt_flags, pos)
    if match:
      return match, True
    match = _FLAG_NO_EQUAL_PAT.match(tt_flags, pos)
    if match:
      # The flag is found but is not given a value.
      return match, False
    # The flag is not found.
    return None, False

  def _validate_flag_names(self):
    """Raises ValueError if any flag in TENSOR_TRACER_FLAGS is unknown."""
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return
    pos = 0
    while True:
      match, _ = TTParameters.match_next_flag(tensor_tracer_flags, pos)
      if not match:
        break
      flag_name = match.group(1)
      if flag_name not in VALID_FLAG_NAMES:
        raise ValueError(
            'The flag name "%s" passed via the environment variable "%s" '
            'is invalid. Valid flag names are:'
            '\n%s' % (flag_name, FLAGS_ENV_VAR, VALID_FLAG_NAMES))
      pos = match.end()

  def _supported_signatures(self):
    """Returns a tuple of supported signatures."""
    return TT_SUMMARY_SIGNATURES

  def _get_summary_signatures(self):
    """Verifies and returns the summary signatures.

    Each requested signature may be given with or without the
    `tensor_tracer_` prefix. Unknown signatures are logged and skipped, and
    repeated signatures are kept only once.

    Returns:
      A dictionary of the signature identifiers {signature: index} that will be
      computed when trace_mode is summary. Indices are contiguous from 0.
      Defaults to max-abs and norm when no valid signature is requested.
    """
    signatures = self._flag_value_as_list(FLAG_NAME_SUMMARY_SIGNATURES)
    supported_signatures = self._supported_signatures()

    tt_signatures = []
    for signature in signatures:
      signature_with_prefix = '%s_%s' % (_TT_PREFIX, signature)
      if signature in supported_signatures:
        resolved = signature
      elif signature_with_prefix in supported_signatures:
        resolved = signature_with_prefix
      else:
        logging.warning('Unknown signature:%s. Supported signatures: %s',
                        signature, supported_signatures)
        continue
      # Duplicates would otherwise leave gaps in the index range (e.g.
      # "norm,norm" produced {norm: 1}).
      if resolved not in tt_signatures:
        tt_signatures.append(resolved)
    if not tt_signatures:
      # Default case collects max-abs and norm only.
      return {TT_SUMMARY_MAX_ABS: 0, TT_SUMMARY_NORM: 1}
    return {signature: idx for idx, signature in enumerate(tt_signatures)}

  def get_signature_to_agg_fn_map(self):
    """Returns a map that contains the aggregate function for each signature."""
    # TODO(b/199284834): Aggregations are not accurate for mean and sparsity if
    # cores have a different number of elements. Variance uses the maximal core
    # variance.
    return {TRACE_MODE_NORM: linalg_ops.norm,
            TRACE_MODE_HISTORY: math_ops.reduce_max,
            TRACE_MODE_MAX_ABS: math_ops.reduce_max,
            TRACE_MODE_NAN_INF: math_ops.reduce_max,
            TT_SUMMARY_NORM: linalg_ops.norm,
            TT_SUMMARY_MAX: math_ops.reduce_max,
            TT_SUMMARY_MAX_ABS:
                lambda t, axis=0: math_ops.reduce_max(math_ops.abs(t),  # pylint: disable=g-long-lambda
                                                      axis=axis),
            TT_SUMMARY_MIN: math_ops.reduce_min,
            # Exact if each part has the same number of values.
            TT_SUMMARY_SPARSITY: math_ops.reduce_mean,
            TT_SUMMARY_MEAN: math_ops.reduce_mean,
            TT_SUMMARY_VAR: math_ops.reduce_max,  # Simply reduce max variance.
            TT_SUMMARY_SIZE: math_ops.reduce_sum}

  def _flag_value_as_list(self, wanted_flag_name):
    """Returns the string list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated flag value split into a list; an empty list when the
      flag is missing or given without a value.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found:
      return []
    if flag_value is None:
      # `--name` without `=value`; previously tripped an assert.
      logging.warning('Flag %s was given without a value; ignoring it.',
                      wanted_flag_name)
      return []
    return flag_value.split(',')

  def _flag_value_as_int_list(self, wanted_flag_name):
    """Returns the integer list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated flag value converted to ints; an empty list when the
      flag is missing, empty, or contains a non-integer entry (a warning is
      logged in the latter case).
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found or not flag_value:
      return []
    try:
      return [int(int_val) for int_val in flag_value.split(',')]
    except ValueError:
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return []

  def _get_flag_int_value(self, wanted_flag_name, default_value):
    """Returns the int value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.
      default_value: the default value for the flag, if not provided.

    Returns:
      The flag value as an int, or default_value when the flag is missing,
      has no value, or cannot be converted (a warning is logged then).
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found:
      return default_value
    try:
      return int(flag_value)
    except (TypeError, ValueError):
      # TypeError covers a flag given without a value (flag_value is None).
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return default_value

  def get_flag_value(self, wanted_flag_name):
    """Returns the value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      A pair where the first element indicates if the flag is
      found and the second element is the value of the flag (None when the
      flag is absent or given without a value). If a flag appears more than
      once, the first occurrence wins.
    """
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return False, None
    pos = 0
    while True:
      match, has_value = TTParameters.match_next_flag(
          tensor_tracer_flags, pos)
      if not match:
        return False, None
      if match.group(1) == wanted_flag_name:
        return True, (match.group(2) if has_value else None)
      # Every pattern consumes at least "--x", so pos strictly advances.
      pos = match.end()

  def _flag_value_to_re_list(self, flag_name):
    """Compiles the comma-separated regexes of a flag into a list.

    Returns an empty list when the flag is missing or has no value.
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found or not flag_value:
      return []
    return [re.compile(pattern) for pattern in flag_value.split(',')]

  def is_flag_on(self, flag_name):
    """Returns True if the given flag is on.

    A flag given without a value counts as on; otherwise its value must be
    one of 1/t/true/y/yes (case-insensitive).
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found:
      return False
    if flag_value is None:
      return True
    # Depends on the flag value.
    flag_value = flag_value.lower()
    enabled = flag_value in ['1', 't', 'true', 'y', 'yes']
    return enabled

  def is_enabled(self):
    """Returns True if TensorTracer is enabled."""
    if self.is_flag_on(FLAG_NAME_ENABLE):
      logging.debug('Tensor Tracer is enabled with flags %s.',
                    self._env.get(FLAGS_ENV_VAR))
      return True
    else:
      return False

  def use_test_undeclared_outputs_dir(self):
    """Decides the output directory of the report and trace files.

    Args:
      None.

    Returns:
      True if the output files should be written to the
      test-undeclared-outputs-directory defined via an
      env variable.
    """
    return self.is_flag_on(FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)