Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tensor_tracer_flags.py: 35%

261 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ======================================================================== 

15"""Utilities to handle tensor tracer parameters.""" 

16 

17 

18import os 

19import os.path 

20import re 

21from absl import flags 

22from tensorflow.python.ops import linalg_ops 

23from tensorflow.python.ops import math_ops 

24from tensorflow.python.platform import tf_logging as logging 

25 

# Names of the supported trace modes.
TRACE_MODE_NAN_INF = 'nan-inf'
TRACE_MODE_NORM = 'norm'
TRACE_MODE_MAX_ABS = 'max-abs'
TRACE_MODE_HISTORY = 'history'
TRACE_MODE_PART_TENSOR = 'part-tensor'
TRACE_MODE_FULL_TENSOR = 'full-tensor'

# Summary mode collects a finite set of signatures for each traced tensor
# (such as norm, max, min, mean) and dumps them using tb summaries.
TRACE_MODE_SUMMARY = 'summary'

# Full tensor mode dumps the whole tensor values of the traced tensors, without
# any processing on them, using tb summaries.
TRACE_MODE_FULL_TENSOR_SUMMARY = 'full_tensor_summary'

# Names of the supported submodes.
_SUBMODE_DETAILED = 'detailed'
_SUBMODE_BRIEF = 'brief'

43 

# Patterns recognizing one flag, optionally preceded by whitespace:
#   --name='value'   --name="value"   --name=value   --name
# Group 1 is the flag name, group 2 (when present) is the flag value.
# The name must not contain '=' or whitespace. Excluding whitespace keeps a
# valueless flag from swallowing the flag that follows it (for instance
# "--enable --trace_mode=norm" must yield "enable" and then "trace_mode"), and
# keeps trailing blanks out of the captured name.
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=\s]+)='([^']*)'")
_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)="([^"]*)"')
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)=(\S*)')
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=\s]+)\s*')

48 

# Name of the environment entry that carries the Tensor Tracer flags.
FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'

# Names of the individual flags accepted inside FLAGS_ENV_VAR.
FLAG_NAME_ENABLE = 'enable'
FLAG_NAME_TRACE_MODE = 'trace_mode'
FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar'
FLAG_NAME_SUBMODE = 'submode'
FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
FLAG_NAME_TRACE_LEVEL = 'trace_level'
FLAG_NAME_TRACE_DIR = 'trace_dir'
FLAG_NAME_REPORT_FILE = 'report_file'
FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
FLAG_NAME_OP_RANGE = 'op_range'
# Folder to dump the pre (before tensor tracer updates) and post (after tensor
# tracer updates) graphs.
FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'
FLAG_NAME_SUMMARY_SIGNATURES = 'signatures'
FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core'
FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache'
FLAG_NAME_INSPECT_TRACE = 'inspect_trace'
FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory'
FLAG_FLUSH_SUMMARY = 'flush_summaries'

# Every flag name that may appear in FLAGS_ENV_VAR. The list (and its order) is
# printed verbatim in the error raised for an unknown flag name.
VALID_FLAG_NAMES = [
    FLAG_NAME_ENABLE,
    FLAG_NAME_TRACE_MODE,
    FLAG_NAME_TRACE_SCALAR_OPS,
    FLAG_NAME_SUBMODE,
    FLAG_NAME_EXCLUDED_OPNAMES,
    FLAG_NAME_EXCLUDED_OPTYPES,
    FLAG_NAME_INCLUDED_OPNAMES,
    FLAG_NAME_INCLUDED_OPTYPES,
    FLAG_NAME_TRACE_DIR,
    FLAG_NAME_REPORT_FILE,
    FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
    FLAG_NAME_OP_RANGE,
    FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS,
    FLAG_NAME_TRACE_LEVEL,
    FLAG_NAME_SUMMARY_SIGNATURES,
    FLAG_NAME_SUMMARY_PER_CORE,
    FLAG_NAME_TEMP_CACHE_VAR,
    FLAG_NAME_FINGERPRINT_DIR,
    FLAG_NAME_INSPECT_TRACE,
    FLAG_FLUSH_SUMMARY,
]

88 

# Matches an op range written as "<first>:<last>" (two decimal integers).
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'

# Trace level used when the trace_level flag is not given.
_TT_DEFAULT_TRACE_LEVEL = 3
# Prefix shared by all summary signature names.
_TT_PREFIX = 'tensor_tracer'

# Short (unprefixed) signature names.
_TT_NORM = 'norm'
_TT_MAX = 'max'
_TT_MAX_ABS = 'max-abs'
_TT_MIN = 'min'
_TT_SPARSITY = 'sparsity'
_TT_MEAN = 'mean'
_TT_VAR = 'var'
_TT_SIZE = 'size'

# Full signature names, i.e. "<prefix>_<short name>".
TT_SUMMARY_NORM = _TT_PREFIX + '_' + _TT_NORM
TT_SUMMARY_MAX = _TT_PREFIX + '_' + _TT_MAX
TT_SUMMARY_MAX_ABS = _TT_PREFIX + '_' + _TT_MAX_ABS
TT_SUMMARY_MIN = _TT_PREFIX + '_' + _TT_MIN
TT_SUMMARY_SPARSITY = _TT_PREFIX + '_' + _TT_SPARSITY
TT_SUMMARY_MEAN = _TT_PREFIX + '_' + _TT_MEAN
TT_SUMMARY_VAR = _TT_PREFIX + '_' + _TT_VAR
TT_SUMMARY_SIZE = _TT_PREFIX + '_' + _TT_SIZE

# All supported summary signatures.
TT_SUMMARY_SIGNATURES = (
    TT_SUMMARY_NORM,
    TT_SUMMARY_MAX,
    TT_SUMMARY_MIN,
    TT_SUMMARY_SPARSITY,
    TT_SUMMARY_MEAN,
    TT_SUMMARY_VAR,
    TT_SUMMARY_SIZE,
    TT_SUMMARY_MAX_ABS,
)

116 

FLAGS = flags.FLAGS

# absl flags defined by this module; each DEFINE_* call below runs at import
# time and its return value is kept in a module-level name.
DELTA_THRESHOLD = flags.DEFINE_float(
    'delta_threshold',
    default=0.5,
    help='Log if history based diff crosses this threshold.')

TT_CHECK_FILTER = flags.DEFINE_bool(
    'tt_check_filter',
    default=False,
    help='Terminate early to check op name filtering.')

TT_SINGLE_CORE_SUMMARIES = flags.DEFINE_bool(
    'tt_single_core_summaries',
    default=False,
    help='Report single core metric and avoid aggregation.')

131 

132 

class TTParameters(object):
  """A class that handles the parameters of Tensor Tracer.

  The parameters are parsed from the flag string stored under
  `TENSOR_TRACER_FLAGS` in an environment-like mapping, for example
  `--enable=1 --trace_mode=summary --trace_dir=/tmp/tt`.
  """

  def __init__(self, env=None):
    """Parses and validates all Tensor Tracer flags.

    Args:
      env: optional mapping queried through `.get()` for the flag string and
        for `TEST_UNDECLARED_OUTPUTS_DIR`. `os.environ` is used when None.

    Raises:
      ValueError: if a flag name or a flag value is invalid.
    """
    # Compare against None so that an explicitly passed (possibly empty)
    # mapping is honored instead of silently falling back to os.environ.
    self._env = os.environ if env is None else env
    self._validate_flag_names()
    self.trace_mode = self._get_trace_mode()
    self.submode = self._get_submode()
    self.trace_dir = self._get_trace_dir()
    self.report_file_path = self._get_report_filepath()
    self.op_range = self._get_op_range()
    self.excluded_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPNAMES)
    self.excluded_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPTYPES)

    self.included_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPNAMES)
    self.included_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPTYPES)

    self.trace_scalar_ops = self.is_flag_on(FLAG_NAME_TRACE_SCALAR_OPS)
    self.use_compact_trace = self.trace_mode in (TRACE_MODE_NAN_INF,
                                                 TRACE_MODE_NORM,
                                                 TRACE_MODE_HISTORY,
                                                 TRACE_MODE_MAX_ABS,
                                                 TRACE_MODE_SUMMARY)
    self.use_temp_cache_var = self.is_flag_on(FLAG_NAME_TEMP_CACHE_VAR)
    self.inspect_trace = self.is_flag_on(FLAG_NAME_INSPECT_TRACE)
    self.use_fingerprint_subdir = self.is_flag_on(FLAG_NAME_FINGERPRINT_DIR)

    _, self.graph_dump_path = self.get_flag_value(
        FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS)
    self.trace_level = self._get_flag_int_value(FLAG_NAME_TRACE_LEVEL,
                                                _TT_DEFAULT_TRACE_LEVEL)
    self.summary_signatures = self._get_summary_signatures()
    self.collect_summary_per_core = self.is_flag_on(FLAG_NAME_SUMMARY_PER_CORE)
    # TODO(b/199284834): Will be resolved with referenced bug.
    if self.collect_summary_per_core:
      logging.warning('Aggregate signatures are approximate for mean, variance'
                      ' and sparsity.')
    self.flush_summaries_with_outside_compile = self.is_flag_on(
        FLAG_FLUSH_SUMMARY)
    # Do not produce errors or warnings if Tensor Tracer is not enabled.
    if self.is_enabled():
      self._check_flag_errors()

  def _check_flag_errors(self):
    """Checks flag combinations that are only invalid when tracing is enabled.

    Raises:
      ValueError: if a summary trace mode is used without a trace directory.
    """
    if self.trace_mode in (TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY):
      if not self.trace_dir:
        raise ValueError('trace_dir must be explicitly provided in '
                         'TENSOR_TRACER_FLAGS when summary mode is used.')

  def _get_report_filepath(self):
    """Returns the path of the output report file.

    When use_test_undeclared_outputs_dir is on, the given (relative) report
    file is placed under the directory named by TEST_UNDECLARED_OUTPUTS_DIR.

    Returns:
      The report file path, or None if the flag is not given.

    Raises:
      ValueError: if use_test_undeclared_outputs_dir is on and the report path
        is absolute, or TEST_UNDECLARED_OUTPUTS_DIR is not set.
    """
    found, report_file_path = self.get_flag_value(FLAG_NAME_REPORT_FILE)
    if found and report_file_path and self.use_test_undeclared_outputs_dir():
      if os.path.isabs(report_file_path):
        raise ValueError('If use_test_undeclared_outputs_dir is set, '
                         'report_file_path cannot be an absolute path (%s)'
                         % report_file_path)
      outputs_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
      if outputs_dir is None:
        # os.path.join(None, ...) would fail with an opaque TypeError.
        raise ValueError(
            'The environment variable "%s" must be set when --%s is used.' %
            (_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR,
             FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
      report_file_path = os.path.join(outputs_dir, report_file_path)
    return report_file_path

  def _get_op_range(self):
    """Returns the index range of the ops that are considered for tracing.

    Returns:
      A pair of ints parsed from a "<first>:<last>" flag value, or (-1, -1),
      which means including all ops, when the flag is missing or malformed.
    """
    found, op_range = self.get_flag_value(FLAG_NAME_OP_RANGE)
    match = _OP_RANGE_PAT.match(op_range) if found and op_range else None
    if not match:
      return (-1, -1)  # this means including all ops.
    return (int(match.group(1)), int(match.group(2)))

  def _get_trace_dir(self):
    """Returns the directory the trace files are written to.

    Returns:
      The value of the trace_dir flag, or the value of
      TEST_UNDECLARED_OUTPUTS_DIR when use_test_undeclared_outputs_dir is on.

    Raises:
      ValueError: if both trace_dir and use_test_undeclared_outputs_dir are
        given.
    """
    found, trace_dir = self.get_flag_value(FLAG_NAME_TRACE_DIR)
    if found and trace_dir and self.use_test_undeclared_outputs_dir():
      raise ValueError(
          'Cannot use --%s and --%s at the same time' %
          (FLAG_NAME_TRACE_DIR, FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
    if self.use_test_undeclared_outputs_dir():
      trace_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
    return trace_dir

  def _get_trace_mode(self):
    """Returns the validated trace mode.

    Returns:
      The value of the trace_mode flag; TRACE_MODE_NORM if it is not given.

    Raises:
      ValueError: if the given trace mode is not supported.
    """
    found, trace_mode = self.get_flag_value(FLAG_NAME_TRACE_MODE)
    if not found or not trace_mode:
      trace_mode = TRACE_MODE_NORM
    valid_trace_modes = [
        TRACE_MODE_NAN_INF, TRACE_MODE_PART_TENSOR, TRACE_MODE_FULL_TENSOR,
        TRACE_MODE_NORM, TRACE_MODE_MAX_ABS,
        TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY,
        TRACE_MODE_HISTORY
    ]
    if trace_mode not in valid_trace_modes:
      raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer. '
                       'Valid trace modes are: %s' % (trace_mode,
                                                      valid_trace_modes))
    return trace_mode

  def is_brief_mode(self):
    """Returns True if the submode is 'brief'."""
    return self.submode == _SUBMODE_BRIEF

  def _get_submode(self):
    """Returns the validated submode.

    Returns:
      The value of the submode flag; _SUBMODE_DETAILED if it is not given.

    Raises:
      ValueError: if the given submode is not supported.
    """
    found, submode = self.get_flag_value(FLAG_NAME_SUBMODE)
    if not found or not submode:
      submode = _SUBMODE_DETAILED
    valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF]
    if submode not in valid_submodes:
      raise ValueError('Invalid submode "%s" given to the Tensor_Tracer. '
                       'Valid submodes are: %s' % (submode,
                                                   valid_submodes))
    return submode

  @staticmethod
  def match_next_flag(tt_flags, pos):
    """Returns the match for the next TensorTracer flag.

    Args:
       tt_flags: a string that contains the flags.
       pos: where in flags to start the search.

    Returns:
       A pair where the first element is the regular-expression
       match found and the second element indicates if the match
       has a value.
    """
    # Quoted forms are tried first; the unquoted form would otherwise stop at
    # the first blank inside a quoted value.
    for pattern in (_FLAG_DOUBLE_QUOTE_PAT, _FLAG_SINGLE_QUOTE_PAT,
                    _FLAG_NO_QUOTE_PAT):
      match = pattern.match(tt_flags, pos)
      if match:
        return match, True
    match = _FLAG_NO_EQUAL_PAT.match(tt_flags, pos)
    if match:
      # The flag is found but is not given a value.
      return match, False
    # The flag is not found.
    return None, False

  def _validate_flag_names(self):
    """Validates the names of all flags passed to Tensor Tracer.

    Raises:
      ValueError: if a flag name is not in VALID_FLAG_NAMES.
    """
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return
    pos = 0
    while True:
      match, _ = TTParameters.match_next_flag(tensor_tracer_flags, pos)
      if not match:
        break
      flag_name = match.group(1)
      if flag_name not in VALID_FLAG_NAMES:
        raise ValueError(
            'The flag name "%s" passed via the environment variable "%s" '
            'is invalid. Valid flag names are:'
            '\n%s' % (flag_name, FLAGS_ENV_VAR, VALID_FLAG_NAMES))
      pos = match.end()

  def _supported_signatures(self):
    """Returns a tuple of supported signatures."""
    return TT_SUMMARY_SIGNATURES

  def _get_summary_signatures(self):
    """Verifies and returns the summary signatures.

    Signatures may be given with or without the 'tensor_tracer_' prefix.
    Unknown signatures are skipped with a warning.

    Returns:
      A dictionary of the signature identifiers {signature: index} that will be
      computed when trace_mode is summary.
    """
    signatures = self._flag_value_as_list(FLAG_NAME_SUMMARY_SIGNATURES)
    supported_signatures = self._supported_signatures()

    tt_signatures = []
    for signature in signatures:
      signature_with_prefix = '%s_%s' % (_TT_PREFIX, signature)
      if signature in supported_signatures:
        resolved = signature
      elif signature_with_prefix in supported_signatures:
        resolved = signature_with_prefix
      else:
        logging.warning('Unknown signature:%s. Supported signatures: %s',
                        signature, supported_signatures)
        continue
      # Skip repeats (e.g. 'norm,tensor_tracer_norm'); a duplicate would
      # otherwise leave a gap in the indices assigned below.
      if resolved not in tt_signatures:
        tt_signatures.append(resolved)
    if not tt_signatures:
      # Default case collects max-abs and norm only.
      return {TT_SUMMARY_MAX_ABS: 0, TT_SUMMARY_NORM: 1}
    return {signature: idx for idx, signature in enumerate(tt_signatures)}

  def get_signature_to_agg_fn_map(self):
    """Returns a map that contains the aggregate function for each signature."""
    # TODO(b/199284834): Aggregations are not accurate for mean and sparsity if
    # cores have a different number of elements. Variance uses the maximal core
    # variance.
    return {TRACE_MODE_NORM: linalg_ops.norm,
            TRACE_MODE_HISTORY: math_ops.reduce_max,
            TRACE_MODE_MAX_ABS: math_ops.reduce_max,
            TRACE_MODE_NAN_INF: math_ops.reduce_max,
            TT_SUMMARY_NORM: linalg_ops.norm,
            TT_SUMMARY_MAX: math_ops.reduce_max,
            TT_SUMMARY_MAX_ABS:
                lambda t, axis=0: math_ops.reduce_max(math_ops.abs(t),  # pylint: disable=g-long-lambda
                                                      axis=axis),
            TT_SUMMARY_MIN: math_ops.reduce_min,
            # Exact if each part has the same number of values.
            TT_SUMMARY_SPARSITY: math_ops.reduce_mean,
            TT_SUMMARY_MEAN: math_ops.reduce_mean,
            TT_SUMMARY_VAR: math_ops.reduce_max,  # Simply reduce max variance.
            TT_SUMMARY_SIZE: math_ops.reduce_sum}

  def _flag_value_as_list(self, wanted_flag_name):
    """Returns the string list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated items of the flag value; an empty list if the flag
      is missing or has no (or an empty) value.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    # A flag given without a value yields None here; treat it like a missing
    # flag instead of failing on None.split().
    if not found or not flag_value:
      return []
    return flag_value.split(',')

  def _flag_value_as_int_list(self, wanted_flag_name):
    """Returns the integer list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated items of the flag value converted to int; an empty
      list if the flag is missing, empty, or any item is not an integer.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found or not flag_value:
      return []
    try:
      return [int(int_val) for int_val in flag_value.split(',')]
    except ValueError:
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return []

  def _get_flag_int_value(self, wanted_flag_name, default_value):
    """Returns the int value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.
      default_value: the default value for the flag, if not provided.

    Returns:
      The flag value converted to int, or default_value if the flag is
      missing, has no value, or its value is not an integer.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found:
      return default_value
    try:
      return int(flag_value)
    except (TypeError, ValueError):
      # TypeError covers a flag given without a value (flag_value is None).
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return default_value

  def get_flag_value(self, wanted_flag_name):
    """Returns the value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      A pair where the first element indicates if the flag is
      found and the second element is the value of the flag (None when the
      flag is not found or is given without a value).
    """
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return False, None
    pos = 0
    while True:
      match, has_value = TTParameters.match_next_flag(
          tensor_tracer_flags, pos)
      if not match:
        return False, None
      if match.group(1) == wanted_flag_name:
        return True, match.group(2) if has_value else None
      pos = match.end()

  def _flag_value_to_re_list(self, flag_name):
    """Compiles the comma-separated items of a flag value into a list of REs.

    Args:
      flag_name: the name of the flag we are looking for.

    Returns:
      A list of compiled regular expressions; empty if the flag is missing or
      has no value.
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found or not flag_value:
      return []
    return [re.compile(v) for v in flag_value.split(',')]

  def is_flag_on(self, flag_name):
    """Returns True if the given flag is on.

    A flag is on when it is given without a value, or when its value is one
    of '1', 't', 'true', 'y', 'yes' (case-insensitive).
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found:
      return False
    if flag_value is None:
      return True
    # Depends on the flag value.
    return flag_value.lower() in ('1', 't', 'true', 'y', 'yes')

  def is_enabled(self):
    """Returns True if TensorTracer is enabled."""
    if self.is_flag_on(FLAG_NAME_ENABLE):
      logging.debug('Tensor Tracer is enabled with flags %s.',
                    self._env.get(FLAGS_ENV_VAR))
      return True
    return False

  def use_test_undeclared_outputs_dir(self):
    """Decides the output directory of the report and trace files.

    Returns:
      True if the output files should be written to the
      test-undeclared-outputs-directory defined via an
      env variable.
    """
    return self.is_flag_on(FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)