Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tensor_tracer.py: 18%

953 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ======================================================================== 

15"""A utility to trace tensor values on TPU.""" 

16 

17import collections 

18import hashlib 

19import operator 

20import os 

21import os.path 

22import sys 

23 

24import numpy as np 

25 

26from tensorflow.core.framework import summary_pb2 

27from tensorflow.python.eager import monitoring 

28from tensorflow.python.framework import constant_op 

29from tensorflow.python.framework import dtypes 

30from tensorflow.python.framework import func_graph 

31from tensorflow.python.framework import function 

32from tensorflow.python.framework import graph_io 

33from tensorflow.python.framework import ops 

34from tensorflow.python.framework import tensor_util 

35from tensorflow.python.lib.io import file_io 

36from tensorflow.python.ops import array_ops 

37from tensorflow.python.ops import array_ops_stack 

38from tensorflow.python.ops import cond 

39from tensorflow.python.ops import control_flow_case 

40from tensorflow.python.ops import control_flow_ops 

41from tensorflow.python.ops import control_flow_util 

42from tensorflow.python.ops import gen_math_ops 

43from tensorflow.python.ops import init_ops 

44from tensorflow.python.ops import linalg_ops 

45from tensorflow.python.ops import logging_ops 

46from tensorflow.python.ops import math_ops 

47from tensorflow.python.ops import nn_impl 

48from tensorflow.python.ops import state_ops 

49from tensorflow.python.ops import string_ops 

50from tensorflow.python.ops import summary_ops_v2 as summary 

51from tensorflow.python.ops import variable_scope 

52from tensorflow.python.platform import analytics 

53from tensorflow.python.platform import gfile 

54from tensorflow.python.platform import remote_utils 

55from tensorflow.python.platform import tf_logging as logging 

56from tensorflow.python.summary import summary_iterator 

57from tensorflow.python.tpu import tensor_tracer_flags 

58from tensorflow.python.tpu import tensor_tracer_report 

59from tensorflow.python.tpu import tpu_replication 

60from tensorflow.python.tpu.ops import tpu_ops 

61from tensorflow.python.training import training_util 

62 

63_DEVICE_TYPE_TPU = 'tpu' 

64_DEVICE_TYPE_CPU = 'cpu' 

65_TRACE_MODE_PART_TENSOR_SIZE = 3 

66 

67_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' 

68_REASON_UNSAFE_OP = 'not-traced-unsafe-op' 

69_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op' 

70_REASON_CONTROLFLOW_OP = 'not-traced-control-flow-op' 

71_REASON_IN_CONTROL_FLOW = 'not-traced-in-control-flow' 

72_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' 

73_REASON_SKIP_SCALAR = 'not-traced-scalar' 

74_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op' 

75_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch' 

76_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' 

77_REASON_SCALAR_GET_TRACED = 'traced-scalar' 

78_REASON_TENSOR_GET_TRACED = 'traced-tensor' 

79_REASON_USER_INCLUDED = 'traced-user-included' 

80_REASON_USER_EXCLUDED = 'not-traced-user-excluded' 

81_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path' 

82_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' 

83_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op' 

84 

85_OUTPUT_STREAM_ESCAPE = 'file://' 

86_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' 

87TENSOR_TRACER_SUMMARY_COLLECTION = 'tensor_tracer_summary_writers' 

88_TRACE_FILE_NAME = 'trace.all' 

89_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.' 

90_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0 

91_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage' 

92_TT_SNAPSHOT = 'tensor_tracer_snapshot' 

93_REPLICA_ID_TAG = '#replica-id: ' 

94_SKIP_REPORT_FILE = 'None' # Do not write report proto if --report_file=None 

95 

96_TT_SUMMARY_NORM = tensor_tracer_flags.TT_SUMMARY_NORM 

97_TT_SUMMARY_MAX = tensor_tracer_flags.TT_SUMMARY_MAX 

98_TT_SUMMARY_MAX_ABS = tensor_tracer_flags.TT_SUMMARY_MAX_ABS 

99_TT_SUMMARY_MIN = tensor_tracer_flags.TT_SUMMARY_MIN 

100_TT_SUMMARY_MEAN = tensor_tracer_flags.TT_SUMMARY_MEAN 

101_TT_SUMMARY_VAR = tensor_tracer_flags.TT_SUMMARY_VAR 

102_TT_SUMMARY_SIZE = tensor_tracer_flags.TT_SUMMARY_SIZE 

103_TT_SUMMARY_SPARSITY = tensor_tracer_flags.TT_SUMMARY_SPARSITY 

104 

105_TT_SUMMARY_TAG = 'tensor_tracer_summary' 

106_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer' 

107_TT_HOSTCALL_KEY = 'tensor_tracer_host_call' 

108_TT_EVENT_FILE_SUFFIX = '.tensor_tracer' 

109 

110_TT_SUMMARY_MAX_QUEUE = 10 

111 

112tt_gauge = monitoring.BoolGauge('/tensorflow/api/tensor_tracer/v1', 

113 'tensor tracer usage', 'method') 

114 

115 

116def _graph_summary_tag(graph): 

117 """Generates and returns a summary tag name for the given graph.""" 

118 

119 if graph is None: 

120 raise RuntimeError('graph is None') 

121 # The chance of collision with md5 is effectively 0. 

122 hash_id = hashlib.md5() 

123 hash_id.update(repr(graph).encode('utf-8')) 

124 # hexdigest() returns a string. 

125 return hash_id.hexdigest() 

126 

127 

def set_parameters(tensor_tracer_params=None):
  """Enables tensor tracer and sets its parameters.

  Example usage:
    tensor_tracer_parameters = {'trace_dir': '/usr/tmp/trace_dir',
                                'trace_mode': 'norm',
                                'report_file': '/usr/tmp/trace_dir/report.all'}
    tensor_tracer.set_parameters(tensor_tracer_parameters)

  This sets up the parameters for tensor tracer. A call to tensor tracer as
  below is necessary to enable debugging on CPUs and GPUs. On TPUs below can
  be skipped as this call is hooked into tpu.rewrite.
    tt = tensor_tracer.TensorTracer()
    loss = tt.trace_cpu(tf.get_default_graph(), tensor_fetches=loss)

  Args:
    tensor_tracer_params: Tensor tracer parameter dictionary. Below gives
      examples of these parameters: See tensor_tracer_report.py for all
      parameters.
        - enable: If set, tensor tracer will be enabled. Calling
          enable_tensor_tracer automatically adds this parameters.
        - trace_mode: The trace_mode to be used by tensor tracer. These
          include:
          - summary: Collects multiple statistics for traced tensors, and
            writes them a summary file that can be visualized using
            tensorboard. This mode currently only works for TPUEstimator. It
            can be also be used for other models, but outfeed must be handled
            by the user.
          - norm: Collects norm of each traced tensor and writes them into a
            text file pointed by 'trace_dir' flag. (Default mode).
          - nan-inf: Checks the existence of NaNs and Infs in the tensor, and
            writes a boolean value to a text file pointed by 'trace_dir'
            flag. Note that 'norm' mode can also capture this information
            with more numerical info.
          - max-abs: Collects the absolute max for each traced tensors and
            writes it into a text file pointed by 'trace_dir' flag.
          - full-tensor: Writes the full tensor content of the traced tensors
            into a text file pointed by 'trace_dir' flag.
          - part-tensor: Writes a part of the tensor content of the traced
            tensors into a text file pointed by 'trace_dir' flag.
          - full_tensor_summary: Writes the full tensors as binary event
            files. The outputs can be read using: trace =
            tensor_tracer.read_tensor_tracer_event_file(event_file_path)
        - report_file: Path to the metadata file that is written during graph
          construction. If not set, metadata will be printed to stdout during
          graph construction.
        - trace_dir: Path where the execution traces will be written during
          the graph execution. If not set, trace will be printed to stderr.
        - trace_level: Tensor tracer aims to trace everything it can. This
          introduces some overhead on graph execution and graph compilation
          times. Using trace_level parameter, it is possible to trace
          operations based on their priorities. For example, -
          trace_level=7 is the highest trace_level, in which every op is
          traced. - trace_level=6 will skip constant operations such as
          tf.constant. - trace_level=5 will skip less important ops such as
          tf.identities. - The default trace_level=3, that will skip concat
          ops, or random number generators. - To reduce the graph compile
          time overhead, trace_level can be set to 0, that will skip
          additions, and substractions, and multiplications as well.
        - excluded_opnames: If set, any matching op name will not be traced.
          excluded_opnames can be set as a regular expression. E.g,
          excluded_opnames=.* will exclude everything.
        - excluded_optypes: If set, any matching op type will not be traced.
          excluded_optypes can be set as a regular expression. E.g,
          excluded_optypes=.* will exclude everything.
          excluded_optypes=MatMul will exclude all MatMul ops from tracing.
        - included_opnames: If set, any matching op name will be forced to be
          traced. included_opnames can be set as a regular expression. E.g,
          '--included_opnames=some_op --excluded_opname=*.' will only trace
          some_op.
        - included_optypes: If set, any matching op type will be forced to be
          traced. included_optypes can be set as a regular expression. E.g,
          '--included_optypes=some_op_type --excluded_optypes=*.' will trace
          only the ops with type 'some_op_type'
        - flush_summaries: If summary mode is used, flush_summaries=1 will
          flush summaries using outside compilation. Note that, if used with
          low level APIs, flush_summaries=1 is necessary to obtain results.
      Advanced Flags:
        - trace_scalar: Scalar values are not traced by default. If this
          flag is set, scalar values will also be traced.
        - op_range: In the form of '%d:%d' that limits the tracing to the
          ops within this limit. --op_range='5:10' will trace only the ops
          that have topological order between 5-10.
        - submode: 'brief' or 'detailed'. If the trace mode is not compact,
          brief mode will print only the id of each traced tensor to save
          some space. 'detailed' mode prints the full tensor name.
        - use_fingerprint_subdirectory: The trace directory will be chosen
          as using the fingerprint of the trace metadata under the provided
          trace_dir.
  """
  # Build the flag string as a list first, then join; the result is the
  # space-separated '--key=value' string the flags parser expects.
  flag_parts = ['--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE]
  for key, value in (tensor_tracer_params or {}).items():
    flag_parts.append('--%s=%s' % (key, value))
  os.environ[tensor_tracer_flags.FLAGS_ENV_VAR] = ' '.join(flag_parts)

221 

222 

def op_priority(op_type):
  """Returns the priority of the op.

  If the priority of the op is k, it will be traced if trace_level>=k.

  Args:
    op_type: String name of the operation type.
  Returns:
    Integer value corresponding the priority of the op.
  """
  # Priority classes, from least interesting (7) to most interesting (1).
  priority_classes = (
      # Lowest priority ops, e.g., constant ops across different steps.
      # They will be traced only if trace_level>=7.
      (7, ('Const', 'Shape', 'BroadcastGradientArgs', 'Range',
           'VariableShape', 'Fill', 'OneHot', 'ShapeN')),
      # Operations without numerical effects; traced only if trace_level>=6.
      (6, ('Identity', 'Cast', 'Reshape', 'ExpandDims', 'StopGradient',
           'PreventGradient', 'Squeeze', 'Gather', 'GatherNd')),
      # Operations that merge or slice an input; traced if trace_level>=5.
      (5, ('ConcatV2', 'Concat', 'StridedSlice', 'Slice', 'Pack', 'Tile',
           'CollectivePermute', 'SplitV', 'DynamicPartition')),
      # Operations less likely to provide useful information; traced if
      # trace_level>=4.
      (4, ('Pad', 'RandomUniformInt', 'GreaterEqual')),
      # Add operations that are less likely to create any issues; traced if
      # trace_level>=3 (default=3).
      (3, ('Sum', 'AddV2', 'Add', 'AddN', 'BiasAdd', 'CrossReplicaSum')),
      # Sub operations that are less likely to create any issues; traced if
      # trace_level>=2.
      (2, ('Neg', 'Sub')),
      # Multiplication and some other operations; traced if trace_level>=1.
      (1, ('Mul', 'Square', 'MatMul', 'RandomUniform', 'Select',
           'Maximum', 'Mean', 'Variance', 'Exp', 'Rsqrt')),
  )
  for priority, op_types in priority_classes:
    if op_type in op_types:
      return priority
  # Unclassified op_types default to being traced at level 2 and above.
  return 2

266 

267 

def read_tensor_tracer_event_file(event_file):
  """Reads the event file written by tensor tracer.

  This can be used to read the full tensors written into binary event files
  by TensorTracer with trace_mode=full_tensor_summary.

  Example usage:
    result_dict_list = tensor_tracer.read_tensor_tracer_event_file(
      event_file_path)
    for result_dict in result_dict_list:
      for step, tensor_dict in result_dict.items():
        for tensor_name, full_tensor_content in tensor_dict.items():
          logging.info(tensor_name, full_tensor_content)

  Args:
    event_file: Path to the event file that contains only tensor tracer
      events.
  Returns:
    A list of event dictionaries, each of which with the form:
    {step_number: {tensor_name: tensor_content}}. This is a list instead of
    a single event dictionary because it is possible that an event file may
    have multiple event traces, each of them covering the same step ranges.
  Raises:
    ValueError: If an unexpected trace is found.
  """
  # How many times each step number has been seen so far; the i-th sighting
  # of a step belongs to occurrence i.
  step_seen_count = collections.defaultdict(int)
  # One dict per occurrence (trace pass): step -> {tensor_name: content}.
  occurrences = []

  for event in summary_iterator.summary_iterator(event_file):
    # The first event only carries file_version: "brain.Event:2"; skip it.
    if not event.HasField('summary'):
      continue
    num_values = len(event.summary.value)
    if num_values != 1:
      raise ValueError('Single step contains %d summary values,'
                       ' expected 1.' % num_values)
    step = event.step
    step_seen_count[step] += 1  # a new occurrence for this step.

    occurrence_idx = step_seen_count[step] - 1
    if occurrence_idx == len(occurrences):
      # First time we need this occurrence slot; append a fresh dict.
      occurrences.append(collections.defaultdict(dict))
    elif occurrence_idx > len(occurrences):
      # Should be impossible given the counting above; guard anyway.
      raise ValueError('Unexpected: occurrence_idx (%d) > '
                       'occurrence_size (%d)' % (occurrence_idx,
                                                 len(occurrences)))
    value = event.summary.value[0]
    shape = [d.size for d in value.tensor.tensor_shape.dim]
    np_dtype = dtypes.DType(value.tensor.dtype).as_numpy_dtype()
    occurrences[occurrence_idx][step][value.tag] = np.frombuffer(
        value.tensor.tensor_content, np_dtype).reshape(shape)
  return occurrences

334 

335 

def trace_tensor(tensor, tracepoint_name=None):
  """Programmatic interface to trace a tensor with Tensor Tracer.

  Tensor Tracer, by default, traces all tensors in the execution. This
  function can be used to limit traced tensors. If this function is called
  for a subset of the tensors, only those will be traced.

  For example, Tensor Tracer will only trace c below.
    c = tf.MatMul(a, b)
    tensor_tracer.trace_tensor(c)
    d = tf.add(c, 1)

  Args:
    tensor: the tensor object for which the tracing is requested.
    tracepoint_name: an optional tensor tracepoint name string. A tracepoint
      name is an Tensor Tracer internal name for the tensor. It is useful
      when comparing equivalent traces from different models that have
      different tensor namings. Equivalent tensors (with different names)
      can be mapped to each other by assigning a common tracepoint_name.

  Returns:
    The provided tensor.
  """
  if tracepoint_name is None:
    tracepoint_name = tensor.name
  graph = tensor.graph
  # Touch the collection first so it exists on the graph, then record the
  # (tensor, tracepoint_name) pair for later tracing.
  graph.get_collection(_TENSOR_TRACER_COLLECTION)
  graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
                          (tensor, tracepoint_name))
  return tensor

364 

365 

def keras_layer_tracepoint(layer, checkpoint_name):
  """An interface for adding the tensor outputs of a keras layer.

  Encapsulates trace_tensor.

  Args:
    layer: A keras layer.
    checkpoint_name: a string name for the checkpoint. This name has to be a
    unique name if used within model comparison. The tensors that have the same
    checkpoint identifier is compared in model comparison.

  Returns:
    The provided layer.
  """
  try:
    outputs = layer.output
    if tensor_util.is_tf_type(outputs):
      trace_tensor(outputs, '%s' % (checkpoint_name))
    else:
      # `outputs` is a collection of tensors. BUG FIX: the original code
      # tested `is_tf_type(outputs)` (the container) inside the loop instead
      # of each `output_tensor`, so elements of multi-output layers were
      # never traced. Check each element instead.
      for idx, output_tensor in enumerate(outputs):
        if tensor_util.is_tf_type(output_tensor):
          trace_tensor(output_tensor, '%s_%d' % (checkpoint_name, idx))
  except (AttributeError, RuntimeError):
    # Layers without an `output` attribute (or whose output access raises
    # before the layer is built) are silently skipped, as before.
    pass
  return layer

395 

396 

class TensorTracer:
  """A software construct for tracing tensor values in a TF graph.

  This utility is disabled by default. It is hooked into tpu.rewrite, so it
  can easily be enabled on TPUs by setting the TENSOR_TRACER_FLAGS env
  variable as below without a code change.
    export TENSOR_TRACER_FLAGS="--enable=1"

  Below is the use example to enable it on CPUs or GPUs, or for more advanced
  use cases on TPUs.

    a = x + 1
    b = a * 2
    rs = tf.reduce_sum(b)
    tensor_tracer.set_parameters({'trace_dir': 'path/to/trace_dir',
                                  'report_file': 'path/to/report/file'})
    tt = tensor_tracer.TensorTracer()
    if on_tpu:
      rs = tt.trace_tpu(tf.get_default_graph(),
                        tensor_fetches=rs)
    else:
      rs = tt.trace_cpu(tf.get_default_graph(),
                        tensor_fetches=rs)
    session.run(rs)

  If it is enabled, it will trace the output tensor values of
  selected Ops in the graph. It has two outputs: (1) the traces and (2)
  a report. The traces are dumped to a specified directory during the graph
  execution, while the report is dumped during the graph construction.
  By passing options via the env variable, users can change:
    (1) the trace mode (e.g., detecting NaN/Inf, printing partial or
        full tensor values)
    (2) which Ops to be traced (via op.name or op.type)
    (3) output trace file path.
  """
  # The set of graphs that are rewritten by tensor tracer.
  _traced_graphs = set()

435 

436 @staticmethod 

437 def is_enabled(): 

438 """Returns True if TensorTracer is enabled.""" 

439 try: 

440 enable = tensor_tracer_flags.TTParameters().is_enabled() 

441 # Add metrics to determine API usage. 

442 if enable: tt_gauge.get_cell('is_enabled').set(True) 

443 return enable 

444 except (ValueError, RuntimeError) as e: 

445 logging.warning( 

446 'Tensor Tracer V1 flags processing error encountered in is_enabled ' 

447 'check. %s', e) 

448 # TODO(b/210212559): Find a more robust fix. 

449 # Should only produce exception if Tensor Tracer is enabled. 

450 return True 

451 

452 @staticmethod 

453 def check_device_type(device_type): 

454 """Checks if the given device type is valid.""" 

455 

456 if device_type not in (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU): 

457 raise ValueError('Invalid device_type "%s"'%device_type) 

458 

459 @staticmethod 

460 def check_trace_mode(device_type, trace_mode): 

461 """Checks if the given trace mode work on the given device type. 

462 

463 Args: 

464 device_type: Device type, TPU, GPU, CPU. 

465 trace_mode: Tensor tracer trace mode. 

466 Raises: 

467 ValueError: If the given trace mode is not supported for the device. 

468 """ 

469 if trace_mode == tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY: 

470 if device_type != _DEVICE_TYPE_TPU: 

471 raise ValueError('Device_type "%s" is not yet supported for ' 

472 'trace mode "%s"' % (device_type, trace_mode)) 

473 

474 @staticmethod 

475 def loop_cond_op(op): 

476 return op.type in ('LoopCond', 'RefLoopCond') 

477 

478 @staticmethod 

479 def while_loop_op(op): 

480 """Returns true if op is one of the special ops of in a while loop. 

481 

482 Args: 

483 op: A tf.Operation. 

484 

485 Returns: 

486 True if the given op is one of [Switch, Merge, Enter, Exit, 

487 NextIteration, LoopCond], which are all building blocks for TF while 

488 loops. 

489 """ 

490 return (control_flow_util.IsLoopSwitch(op) or 

491 control_flow_util.IsLoopMerge(op) or 

492 control_flow_util.IsLoopEnter(op) or 

493 control_flow_util.IsLoopExit(op) or 

494 TensorTracer.loop_cond_op(op) or 

495 op.type in ('RefNextIteration', 'NextIteration')) 

496 

497 @staticmethod 

498 def control_flow_op(op): 

499 """Returns true if op is one of the special ops of in a while loop. 

500 

501 Args: 

502 op: A tf.Operation. 

503 

504 Returns: 

505 True if the given op is one of [Switch, Merge, Enter, Exit, 

506 NextIteration, LoopCond], which are all building blocks for TF while 

507 loops. 

508 """ 

509 return (control_flow_util.IsSwitch(op) or 

510 control_flow_util.IsMerge(op)) 

511 

512 @staticmethod 

513 def unsafe_op(op): 

514 """Returns True if this op is not safe to be traced.""" 

515 

516 # Reasons for not including following op types: 

517 # Assign: cause incorrect result with CPU tracing. 

518 if op.type == 'Assign': 

519 return True 

520 return False 

521 

522 @staticmethod 

523 def device_mismatch(device_type, op): 

524 if device_type == _DEVICE_TYPE_TPU: 

525 # pylint: disable=protected-access 

526 return tpu_replication._TPU_REPLICATE_ATTR not in op.node_def.attr 

527 # pylint: enable=protected-access 

528 return False 

529 

530 @staticmethod 

531 def unsafe_scalar_trace(op): 

532 """Return true if scalar output tensor from Op is not safe to be traced.""" 

533 

534 # Tracing the following causes cycle in the graph on TPU. 

535 if op.type in ('LoopCond', 'Enter', 'Merge', 'Const', 

536 'Switch', 'Less', 'ReadVariableOp'): 

537 return True 

538 # Tracing the following will cause casting-issue 

539 # with the norm tracing mode or other compilation issues on CPU. 

540 if op.type in ('VarHandleOp', 'IteratorToStringHandle', 

541 'IteratorGetNext', 'OneShotIterator', 

542 'IteratorV2', 'MakeIterator', 

543 'BatchDatasetV2', 'MapDataset', 

544 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', 

545 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'): 

546 return True 

547 return False 

548 

549 def _is_interesting_op(self, op): 

550 """Returns True if the given op is not an interesting one to be traced.""" 

551 return op_priority(op.type) <= self._parameters.trace_level 

552 

553 @staticmethod 

554 def reason(op_idx, details): 

555 """Returns reason why the Op at op_idx is traced or not.""" 

556 

557 return '%d %s'%(op_idx, details) 

558 

559 def __init__(self): 

560 """Initializes a TensorTracer. 

561 

562 Sets the various member fields from the flags (if given) or the defaults. 

563 """ 

564 self._replica_id = None 

565 self._tt_config = tensor_tracer_report.TensorTracerConfig() 

566 self._parameters = tensor_tracer_flags.TTParameters() 

567 self._host_call_fn = {} 

568 # _cache_variables is a dict (key = graph, value = dicts 

569 # (key = name, value = tensors)) 

570 self._cache_variables = {} 

571 self._history_value_cache = {} 

572 

573 self._traced_op_names = set() 

574 self._report_proto = None 

575 # _temp_cache_var is a dict (key = graph, value = []) 

576 self._temp_cache_var = {} 

577 self._report_proto_path = '' 

578 self._outmost_context = None 

579 

580 def report_proto(self): 

581 """Getter for tensor_tracer.proto object for summary and full_tensor_summary modes. 

582 

583 Returns: 

584 A tensor_tracer.proto object. 

585 Raises: 

586 ValueError if called before tracing happens, or when trace mode is not 

587 summary or full_tensor_summary. 

588 """ 

589 if self._report_proto: 

590 return self._report_proto 

591 else: 

592 raise ValueError('Call to report_proto must be done after tracing.' 

593 'Report proto only exists for ' 

594 'trace_mode=[summary|full_tensor_summary]') 

595 

596 def report_proto_path(self): 

597 """Getter for path where tensor_tracer.proto object should be written. 

598 

599 Returns: 

600 A string path. 

601 """ 

602 return self._report_proto_path 

603 

604 def _escape_namescopes(self, variable_name): 

605 return variable_name.replace('/', '_').replace(':', '_') 

606 

607 def _cache_variable_for_graph(self, graph): 

608 if graph not in self._cache_variables: 

609 self._cache_variables[graph] = {} 

610 return self._cache_variables[graph] 

611 

612 def _create_or_get_tensor_history_values_cache(self, 

613 cache_name, 

614 graph, 

615 shape=None, 

616 dtype=dtypes.float32): 

617 """Creates a variable as the cache to store historic intermediate tensor values. 

618 

619 Args: 

620 cache_name: Name to be given to the cache (an instance of tf.variable). 

621 graph: Tensorflow graph. 

622 shape: A list of dimensions. 

623 dtype: Data type of created cache. 

624 Returns: 

625 A ref to newly created or existing cache with the given dimensions. 

626 Raises: 

627 ValueError: 

628 (1) If graph is None, or 

629 (2) shape is None when a new cache needs to be created. 

630 """ 

631 if graph is None: 

632 raise ValueError('Invalid graph.') 

633 

634 if graph not in self._history_value_cache: 

635 self._history_value_cache[graph] = {} 

636 

637 if cache_name not in self._history_value_cache[graph]: 

638 if shape is None: 

639 raise ValueError('shape must be provided at cache creation.') 

640 if dtype.is_integer: 

641 init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE) 

642 else: 

643 init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE 

644 

645 # Create in proper graph and base name_scope. 

646 with graph.as_default() as g, g.name_scope(None): 

647 self._history_value_cache[graph][ 

648 cache_name] = variable_scope.get_variable( 

649 'tt_history' + '_' + self._escape_namescopes(cache_name), 

650 shape=shape, 

651 dtype=dtype, 

652 initializer=init_ops.constant_initializer(init_val), 

653 trainable=False, 

654 use_resource=True, 

655 collections=[ 

656 _TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES 

657 ]) 

658 

659 return self._history_value_cache[graph][cache_name] 

660 

661 def _create_or_get_tensor_values_cache(self, cache_name, graph, 

662 shape=None, dtype=dtypes.float32): 

663 """Creates a variable as the cache to store intermediate tensor values. 

664 

665 Args: 

666 cache_name: Name to be given to the cache (an instance of tf.variable). 

667 graph: Tensorflow graph. 

668 shape: A list of dimensions. 

669 dtype: Data type of created cache. 

670 Returns: 

671 A ref to newly created or existing cache with the given dimensions. 

672 Raises: 

673 ValueError: 

674 (1) If graph is None, or 

675 (2) shape is None when a new cache needs to be created. 

676 """ 

677 if graph is None: 

678 raise ValueError('Invalid graph.') 

679 

680 graph_cache_var = self._cache_variable_for_graph(graph) 

681 

682 if cache_name not in graph_cache_var: 

683 if shape is None: 

684 raise ValueError('shape must be provided at cache creation.') 

685 if dtype.is_integer: 

686 init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE) 

687 else: 

688 init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE 

689 

690 # Create in proper graph and base name_scope. 

691 with graph.as_default() as g, g.name_scope(None): 

692 graph_cache_var[cache_name] = variable_scope.get_variable( 

693 _TT_SNAPSHOT + '_' + self._escape_namescopes(cache_name), 

694 shape=shape, dtype=dtype, 

695 initializer=init_ops.constant_initializer(init_val), 

696 trainable=False, 

697 use_resource=True, 

698 collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES]) 

699 return graph_cache_var[cache_name] 

700 

701 def _add_replica_id_to_graph(self): 

702 """Adds nodes for computing the replica ID to the graph.""" 

703 

704 if self._tt_config.num_replicas: 

705 with ops.control_dependencies(None): 

706 # Uses None as dependency to run outside of TPU graph rewrites. 

707 self._replica_id = tpu_ops.tpu_replicated_input( 

708 list(range(self._tt_config.num_replicas)), 

709 name='tt_replica_id') 

710 else: 

711 self._replica_id = 'unknown' 

712 

713 def _inside_op_range(self, idx): 

714 """Return True if the given index is inside the selected range.""" 

715 

716 if idx < self._parameters.op_range[0]: 

717 return False 

718 return (self._parameters.op_range[1] < 0 or 

719 idx <= self._parameters.op_range[1]) 

720 

721 def _is_user_included_op(self, op): 

722 """Checks whether the op is included in the tensor tracer flags. 

723 

724 Args: 

725 op: tf Operation 

726 Returns: 

727 True, if the op is included. 

728 An op is included if: 

729 - Its op name is given in included_opnames 

730 - Its op type is given in included_optypes 

731 - The op is at most _trace_ops_before_included hops before an included op 

732 - The op is at most _trace_ops_after_included hops after an included op 

733 """ 

734 for opname_re in self._parameters.included_opname_re_list: 

735 if opname_re.match(op.name): 

736 return True 

737 

738 for optype_re in self._parameters.included_optype_re_list: 

739 if optype_re.match(op.type): 

740 return True 

741 return False 

742 

743 def _is_user_excluded_op(self, op): 

744 for opname_re in self._parameters.excluded_opname_re_list: 

745 if opname_re.match(op.name): 

746 return True 

747 for optype_re in self._parameters.excluded_optype_re_list: 

748 if optype_re.match(op.type): 

749 return True 

750 return False 

751 

752 def _signature_types(self): 

753 """Returns a dictionary holding the order of signatures in the cache for the selected trace mode.""" 

754 if self._parameters.trace_mode in set([ 

755 tensor_tracer_flags.TRACE_MODE_NAN_INF, 

756 tensor_tracer_flags.TRACE_MODE_NORM, 

757 tensor_tracer_flags.TRACE_MODE_HISTORY, 

758 tensor_tracer_flags.TRACE_MODE_MAX_ABS]): 

759 return {self._parameters.trace_mode: 0} 

760 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: 

761 return self._parameters.summary_signatures 

762 return {} 

763 

764 def _num_signature_dimensions(self): 

765 return len(self._signature_types()) 

766 

def _use_temp_cache(self):
  """Returns true if the intermediate values should be stacked instead of being stored in a tf.Variable.

  Returns:
    A boolean, denoting whether to use a temporary cache or not.
  """
  # Full tensors stored in tf.Variables are never moved to temp variables.
  if self._use_tensor_buffer():
    return False
  # A temporary cache only ever replaces a tf.Variable cache; when no
  # values cache is in use there is nothing to replace.
  if not self._use_tensor_values_cache():
    return False
  return self._parameters.use_temp_cache_var

783 

def _use_tensor_values_cache(self):
  """Returns True if immediate tensors should be first saved to a cache."""
  # Driven directly by the use_compact_trace flag.
  return self._parameters.use_compact_trace

787 

def _use_tensor_buffer(self):
  """Returns true if the whole tensor needs to be cached/buffered in memory."""
  # Only the full-tensor-summary mode keeps entire tensors rather than
  # scalar signatures, so only that mode needs a buffer.
  return (self._parameters.trace_mode ==
          tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)

792 

def _merge_tensor_signatures(self, signatures):
  """Returns a tensor that merges the given signatures.

  Args:
    signatures: A dictionary of the signature updates from signature name to
      a tensor of dimension [1].
  Returns:
    A tensor that concats the signature values in a predefined order.
  Raises:
    ValueError: Unable to merge signatures.
  """
  num_dims = self._num_signature_dimensions()
  if num_dims > 1:
    cache_order = self._signature_types()
    # Order the updates by each signature's cache position before stacking.
    ordered_values = [
        value for _, value in sorted(
            signatures.items(), key=lambda item: cache_order[item[0]])
    ]
    return array_ops_stack.stack(
        ordered_values, axis=0, name='merge_single_op_signatures')
  if num_dims == 1:
    # A stack op is unnecessary when there is exactly one signature.
    (_, value), = signatures.items()
    return value
  raise ValueError('Cannot merge 0 signatures. Check the value passed for '
                   'flag --signatures.')

820 

def _save_tensor_value_to_tmp_cache(self, cache_idx, updates, graph):
  """Stores the merged signature updates into an entry of the temp cache.

  Args:
    cache_idx: The cache index of the tensor within the cache.
    updates: A dictionary of the signature updates from signature name to
      a tensor of dimension [1].
    graph: A TensorFlow graph.
  Raises:
    RuntimeError:
      (1) graph is not already in self._temp_cache_var, or
      (2) cache_idx is out of range.
  """
  merged = self._merge_tensor_signatures(updates)
  merged = array_ops.reshape(merged, [self._num_signature_dimensions()])
  if graph not in self._temp_cache_var:
    raise RuntimeError('graph is not in self._temp_cache_var')
  cache_entries = self._temp_cache_var[graph]
  if cache_idx >= len(cache_entries):
    raise RuntimeError('cache_idx (%d) is out of range (%d)' % (
        cache_idx, len(cache_entries)))
  cache_entries[cache_idx] = merged

843 

def _save_tensor_value_to_cache_op(self, cache_idx, updates, graph):
  """Returns an op that will save the given updates to an entry in the cache.

  Args:
    cache_idx: The cache index of the tensor within the cache.
    updates: A dictionary of the signature updates.
    graph: A TensorFlow graph.
  Returns:
    Cache update operation.
  """
  # scatter_update can only write along the first dimension, so the
  # individual signatures are merged into one compact [1, num_signatures]
  # row and written in a single update.
  merged = self._merge_tensor_signatures(updates)
  merged = array_ops.reshape(merged,
                             [1, self._num_signature_dimensions()])
  row_index = constant_op.constant([cache_idx])
  cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph)
  return state_ops.scatter_update(cache, row_index, merged).op

863 

def _snapshot_tensor(self, tensor):
  """Creates a new tf.Variable and an op assigning the tensor's value to it.

  Args:
    tensor: tensor whose values will be stored in a new tf.Variable.
  Returns:
    An assignment operation.
  """
  cache_variable = self._create_or_get_tensor_values_cache(
      tensor.name,
      tensor.op.graph,
      tensor.shape.as_list(),
      tensor.dtype)
  return state_ops.assign(cache_variable, tensor).op

877 

def _preprocess_traced_tensor(self, tensor):
  """Computes NAN/Norm/Max on TPUs before sending to CPU.

  Args:
    tensor: The tensor to be traced.
  Returns:
    A tensor that should be input to the trace_function.
  Raises:
    RuntimeError: If the signature is invalid.
  """

  def _detect_nan_inf(tensor):
    """Trace function for detecting any NaN/Inf in the tensor."""

    if tensor.dtype.is_floating:
      mask = math_ops.reduce_any(
          gen_math_ops.logical_or(
              gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor)))
      output_tensor = cond.cond(
          mask,
          lambda: constant_op.constant([1.0]),
          lambda: constant_op.constant([0.0]))
    else:
      # Non-floating dtypes cannot hold NaN/Inf; emit a constant "clean".
      output_tensor = constant_op.constant([0.0])
    return output_tensor

  def _compute_signature(tensor, tf_op, cast_to_f32=True):
    # Applies tf_op to the (optionally float32-cast) tensor and guarantees
    # a scalar-shaped result.
    if cast_to_f32:
      tensor = math_ops.cast(tensor, dtypes.float32)
    output_tensor = tf_op(tensor)
    # Return type should be scalar. Set it if it does not have the
    # information.
    if not output_tensor.get_shape().is_fully_defined():
      output_tensor = array_ops.reshape(output_tensor, [])
    return output_tensor

  def _show_size(tensor):
    # In order to check the size of a tensor.
    # Not all sizes are known at the compile time, also, different replicas
    # sometimes get different sizes of tensors.
    # Collect it here to be used in merging replica data.
    tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False)
    # Cast to float32, so that it can be placed into same cache with other
    # signatures.
    return math_ops.cast(tsize, dtypes.float32)

  def _show_max(tensor, cast_to_f32=True):
    # returns -inf for empty tensor
    return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32)

  def _show_min(tensor, cast_to_f32=True):
    # returns inf for empty tensor
    return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32)

  def _show_norm(tensor, cast_to_f32=True):
    # returns 0 for empty tensor
    return _compute_signature(tensor, linalg_ops.norm, cast_to_f32)

  def _show_sparsity(tensor, cast_to_f32=True, tolerance=1e-06):
    # returns nan for empty tensor and treats nans as non-zero numbers
    def sparsity_fn(tensor):
      non_zeros = math_ops.greater_equal(math_ops.abs(tensor), tolerance)
      nans = math_ops.is_nan(tensor)
      return nn_impl.zero_fraction(math_ops.logical_or(non_zeros, nans))

    return _compute_signature(tensor, sparsity_fn, cast_to_f32)

  def _show_mean_and_variance(tensor, cast_to_f32=True):
    """Returns the mean and variance of the given tensor."""
    if cast_to_f32:
      tensor = math_ops.cast(tensor, dtypes.float32)
    # returns nan for empty tensor
    mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0])
    # The shape has to be 1. Set it if it does not have the information.
    if not mean.get_shape().is_fully_defined():
      mean = array_ops.reshape(mean, [])
    if not var.get_shape().is_fully_defined():
      var = array_ops.reshape(var, [])
    return mean, var

  def _show_max_abs(tensor, cast_to_f32=True):
    return _compute_signature(
        tensor, lambda t: math_ops.reduce_max(math_ops.abs(t)), cast_to_f32)

  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
    return {self._parameters.trace_mode: _detect_nan_inf(tensor)}
  if (self._parameters.trace_mode ==
      tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
    return {self._parameters.trace_mode: tensor}
  if (self._parameters.trace_mode in (
      tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
      tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)):
    return {self._parameters.trace_mode: tensor}
  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM:
    return {self._parameters.trace_mode: array_ops.reshape(
        _show_norm(tensor), [1])}
  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_HISTORY:
    # History mode uses the same norm signature as norm mode.
    return {self._parameters.trace_mode: array_ops.reshape(
        _show_norm(tensor), [1])}
  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS:
    return {self._parameters.trace_mode: _show_max_abs(tensor)}

  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
    tensor = math_ops.cast(tensor, dtypes.float32)
    result_dict = {}
    # Call mean and variance computation here to avoid adding the same nodes
    # twice.
    if (_TT_SUMMARY_MEAN in self._signature_types() or
        _TT_SUMMARY_VAR in self._signature_types()):
      mean, variance = _show_mean_and_variance(tensor, cast_to_f32=False)

    for signature_name, _ in sorted(self._signature_types().items(),
                                    key=lambda x: x[1]):
      if signature_name == _TT_SUMMARY_NORM:
        signature_result_tensor = _show_norm(tensor, cast_to_f32=False)
      elif signature_name == _TT_SUMMARY_MAX:
        signature_result_tensor = _show_max(tensor, cast_to_f32=False)
      elif signature_name == _TT_SUMMARY_MAX_ABS:
        signature_result_tensor = _show_max_abs(tensor, cast_to_f32=False)
      elif signature_name == _TT_SUMMARY_MIN:
        signature_result_tensor = _show_min(tensor, cast_to_f32=False)
      elif signature_name == _TT_SUMMARY_SPARSITY:
        signature_result_tensor = _show_sparsity(tensor)
      elif signature_name == _TT_SUMMARY_SIZE:
        signature_result_tensor = _show_size(tensor)
      elif signature_name == _TT_SUMMARY_MEAN:
        signature_result_tensor = mean
      elif signature_name == _TT_SUMMARY_VAR:
        signature_result_tensor = variance
      else:
        raise ValueError('Unknown signature type :%s.' % signature_name)

      result_dict[signature_name] = signature_result_tensor
    return result_dict

  raise RuntimeError(
      'Unsupported signature for trace mode %s.'
      % self._parameters.trace_mode)

1016 

def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order):
  """Makes the tensor tracing function called by outside compilation.

  Args:
    tensor_name: name of the tensor being traced.
    tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
  Returns:
    A function to be passed as the first argument to outside compilation.

  Raises:
    RuntimeError: If the trace mode is invalid.
  """

  def _print_tensor(tensor_name, num_elements, tensor, output_tensor):
    """Prints a tensor value to a file.

    Args:
      tensor_name: name of the tensor being traced.
      num_elements: number of elements to print (-1 means print all).
      tensor: the tensor needs to be returned.
      output_tensor: the tensor needs to be printed.

    Returns:
      The same tensor passed via the "tensor" argument.

    Raises:
      ValueError: If tensor_name is not already in
        tensor_trace_order.tensorname_to_cache_idx.
    """

    if self._parameters.is_brief_mode():
      if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
        raise ValueError(
            'Tensor %s with name %s is not in the tensorname_to_cache_idx' %
            (tensor, tensor_name))
      # Brief mode prints the cache index instead of the full tensor name.
      msg = '%d' % tensor_trace_order.tensorname_to_cache_idx[tensor_name]
    else:
      msg = '"%s"' % tensor_name

    if self._parameters.trace_dir:
      output_path = os.path.join(
          self._parameters.trace_dir,
          _TRACE_FILE_NAME + self._get_outfile_suffix())
      output_stream = _OUTPUT_STREAM_ESCAPE + output_path
    else:
      # No trace directory configured; traces fall back to stderr.
      output_stream = sys.stderr
    return logging_ops.print_v2(msg, array_ops.shape(output_tensor),
                                '@', self._replica_id,
                                '\n', output_tensor, '\n',
                                summarize=num_elements,
                                output_stream=output_stream)

  def _show_part_tensor(tensor):
    """Trace function for printing part of the tensor."""

    return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE,
                         tensor, tensor)

  def _show_full_tensor(tensor):
    """Trace function for printing the entire tensor."""

    return _print_tensor(tensor_name, -1, tensor, tensor)

  if (self._parameters.trace_mode ==
      tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
    return _show_part_tensor
  # The input tensor has a shape of "[1]" for TRACE_MODE_NAN_INF,
  # TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as related computations are
  # performed within TPUs and only their results are transferred to CPU.
  # Simply, print the full tensor for these trace modes.
  if self._parameters.trace_mode in (
      tensor_tracer_flags.TRACE_MODE_NAN_INF,
      tensor_tracer_flags.TRACE_MODE_NORM,
      tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
      tensor_tracer_flags.TRACE_MODE_MAX_ABS,
      tensor_tracer_flags.TRACE_MODE_SUMMARY,
      tensor_tracer_flags.TRACE_MODE_HISTORY
      ):
    return _show_full_tensor

  raise RuntimeError('Full tensor support is not available with trace mode %s'
                     %self._parameters.trace_mode)

1099 

def _is_in_control_flow(self, op):
  """Returns true if the given op is inside a tf.cond or in tf.while_loop.

  Args:
    op: A tensorflow op that should be checked whether in control flow or not.
  Returns:
    A boolean value whether the op is in control flow or not.
  """
  # NOTE(review): only tf.cond containment is checked here (IsInCond);
  # while-loop nesting appears to be handled separately by
  # _is_in_outmost_while_loop — confirm the docstring's while_loop claim.
  return control_flow_util.IsInCond(op)

1109 

def _is_in_outmost_while_loop(self, op):
  """Returns true if the op is at the same level with the training loop.

  Returns false if the op is in an inner while loop or if it is outside of
  the training loop.
  Args:
    op: tf.Operation

  Returns:
    A boolean.
  """
  op_context = self._get_op_control_flow_context(op)
  op_while_context = control_flow_util.GetContainingWhileContext(op_context)
  training_loop_context = control_flow_util.GetContainingWhileContext(
      self._outmost_context)
  return op_while_context == training_loop_context

1125 

def _should_trace_in_control_flow(self):
  """Returns false incase it is not safe to trace ops in tf.cond or tf.while_loop."""
  # The temporary cache forces the execution of traced tensors, so ops
  # that may be skipped by control flow must not be traced with it.
  if self._use_temp_cache():
    return False
  if self._tt_config.device_type == _DEVICE_TYPE_TPU:
    # On TPUs, tracing inside control flow is only safe when intermediate
    # values go through a cache or buffer, as calling outside compilation
    # within an inner loop causes errors.
    return self._use_tensor_values_cache() or self._use_tensor_buffer()
  return True

1139 

def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
  """Returns True if we should not trace Op.

  Args:
    op_id: Topological index of the op.
    op: tf.Operation
    ops_in_exec_path: Set of operations that are in the execution path.
    report_handler: An instance of tensor_tracer_report.TTReportHandle.
  Returns:
    True if the op should not be traced, false otherwise.
  """
  # Hard exclusions first: while-loop/control-flow/unsafe ops, wrong
  # device, or ops not on the execution path are never traced.
  if TensorTracer.while_loop_op(op):
    report_handler.instrument_op(
        op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP))
    return True
  if TensorTracer.control_flow_op(op):
    report_handler.instrument_op(
        op, TensorTracer.reason(op_id, _REASON_CONTROLFLOW_OP))
    return True
  if TensorTracer.unsafe_op(op):
    report_handler.instrument_op(
        op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP))
    return True
  if TensorTracer.device_mismatch(self._tt_config.device_type, op):
    report_handler.instrument_op(
        op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH))
    return True
  if op not in ops_in_exec_path:
    report_handler.instrument_op(
        op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED))
    return True
  # TensorTracer will not trace the operations that are in an inner while loop
  # or tf.cond when a temporary cache is used. Temporary cache adds direct
  # data dependencies to traced operations, and needs a static number of
  # traced operations. For these cases,
  # - We do not know the number of slots required when there are inner while
  #   loops. TensorTracer can only trace the result of a while loop.
  # - We do not know ahead of time which branch of the tf.cond
  #   will be taken, so we avoid introducing data dependencies for the
  #   operations inside a tf.cond.
  # - We also cannot have a data dependency to an operation in a different
  #   while context.
  if self._is_in_control_flow(op) or not self._is_in_outmost_while_loop(op):
    if not self._should_trace_in_control_flow():
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_IN_CONTROL_FLOW))
      return True
  # Explicit user inclusion overrides the op-range / interesting-op /
  # user-exclusion checks below: an included op is always traced.
  if self._is_user_included_op(op):
    report_handler.instrument_op(
        op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
    if tensor_tracer_flags.TT_CHECK_FILTER.value:
      logging.info('USER_INCLUDED op %s', op.name)
    return False

  if not self._inside_op_range(op_id):
    report_handler.instrument_op(
        op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE))
    return True
  if not self._is_interesting_op(op):
    report_handler.instrument_op(
        op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP))
    return True
  if self._is_user_excluded_op(op):
    report_handler.instrument_op(
        op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
    if tensor_tracer_flags.TT_CHECK_FILTER.value:
      logging.info('USER_EXCLUDED op %s', op.name)
    return True
  return False

1209 

def _skip_tensor(self, op_id, out_tensor, report_handler):
  """Returns True if we should not trace out_tensor.

  Args:
    op_id: Topological index of the op producing tensor.
    out_tensor: tf.Tensor
    report_handler: An instance of tensor_tracer_report.TTReportHandle.
  Returns:
    True if the tensor should not be traced, false otherwise.
  """

  # Skips a tensor if the tensor has a non-numeric type.
  # Note: we cannot use check_ops.is_numeric_tensor(out_tensor)
  # because it also excludes tensors with dtypes, bool, and
  # float32_ref, which we actually want to trace.
  non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
                                  dtypes.string])
  if out_tensor.dtype in non_numeric_tensor_types:

    report_handler.instrument_tensor(
        out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR))
    return True
  # Skip a tensor if it feeds a special while loop op.
  if [consumer for consumer in out_tensor.consumers() if
      TensorTracer.while_loop_op(consumer)]:
    report_handler.instrument_tensor(
        out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP))
    return True
  # Explicit user inclusion overrides the exclusion and shape-based checks
  # below.
  if self._is_user_included_op(out_tensor.op):
    report_handler.instrument_tensor(
        out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
    if tensor_tracer_flags.TT_CHECK_FILTER.value:
      logging.info('USER_INCLUDED tensor %s', out_tensor.name)
    return False
  if self._is_user_excluded_op(out_tensor.op):
    report_handler.instrument_tensor(
        out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
    if tensor_tracer_flags.TT_CHECK_FILTER.value:
      logging.info('USER_EXCLUDED tensor %s', out_tensor.name)
    return True
  if not out_tensor.get_shape().is_fully_defined():
    # If trace mode is nan-inf, norm or max, then the tensor will be reduced
    # to a scalar before the outside compilation call.
    if self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_NAN_INF,
        tensor_tracer_flags.TRACE_MODE_NORM,
        tensor_tracer_flags.TRACE_MODE_HISTORY,
        tensor_tracer_flags.TRACE_MODE_MAX_ABS,
        tensor_tracer_flags.TRACE_MODE_SUMMARY
        ):
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
      return False
    else:
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE))
      return True
  rank = len(out_tensor.shape)
  if rank < 1:
    # scalar
    if self._parameters.trace_scalar_ops:
      if TensorTracer.unsafe_scalar_trace(out_tensor.op):
        report_handler.instrument_tensor(
            out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR))
        return True
      else:
        report_handler.instrument_tensor(
            out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED))
        return False
    else:
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR))
      return True
  else:
    # tensor
    report_handler.instrument_tensor(
        out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
    return False

1288 

def _filter_execution_path_operations(self, operations, fetches):
  """Returns the set of ops in the execution path to compute given fetches.

  Args:
    operations: All operations of the graph under consideration.
    fetches: The fetched operation(s)/tensor(s): a single item, a
      list/tuple of items, or None.
  Returns:
    The set of operations reachable from the fetches through data inputs
    and control inputs. If fetches is None, all operations are returned.
  Raises:
    RuntimeError: If a fetch is neither a tensor nor an op.
  """
  # If no fetch provided, then return all operations.
  if fetches is None:
    return set(operations)
  # Convert to list, if a single element is provided.
  if not isinstance(fetches, (list, tuple)):
    fetches = [fetches]
  # If a tensor is given as fetch, convert it to op.
  op_fetches = []
  for fetch in fetches:
    if isinstance(fetch, ops.Operation):
      op_fetches.append(fetch)
    elif isinstance(fetch, ops.Tensor):
      op_fetches.append(fetch.op)
    else:
      raise RuntimeError('Given fetch:%s is neither a tensor nor an op.'
                         %fetch)

  # Depth-first traversal over data and control inputs of the fetches.
  execution_path_operations = set(op_fetches)
  traverse_stack = list(op_fetches)
  # Idiomatic loop condition instead of `while True` + break.
  while traverse_stack:
    head_op = traverse_stack.pop()
    input_ops = [tensor_input.op for tensor_input in head_op.inputs]
    input_ops.extend(head_op.control_inputs)

    for input_op in input_ops:
      if input_op not in execution_path_operations:
        # Filter out loop condition operations, tracing them causes a cycle.
        # Trace only the loop-body.
        if TensorTracer.loop_cond_op(input_op):
          continue
        execution_path_operations.add(input_op)
        traverse_stack.append(input_op)
  return execution_path_operations

1327 

def _determine_and_instrument_traced_tensors(self, graph_order,
                                             ops_in_exec_path,
                                             tensor_trace_points,
                                             report_handler):
  """Determines the tensors to trace and instruments the trace details.

  Args:
    graph_order: graph_order tuple containing graph (tf.graph), operations
      (list of operations), op_to_idx (op id mapping), (tensors) list of
      tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether
      there is a cycle in the graph), topological_order_or_cycle (list of ops
      in topological order or list of ops creating a cycle).
    ops_in_exec_path: Set of ops in the execution path.
    tensor_trace_points: Collection of programatic tensor trace points.
    report_handler: An instance of tensor_tracer_report.TTReportHandle.
  Returns:
    List of tensors to be traced.
  """
  traced_tensors = []
  # When programmatic trace points exist, only their producer ops are
  # considered for tracing.
  checkpoint_operations = set(tensor.op
                              for (tensor, _) in tensor_trace_points)
  for op_id, op in enumerate(graph_order.operations):
    if checkpoint_operations and op not in checkpoint_operations:
      continue
    if self._skip_op(op_id, op, ops_in_exec_path, report_handler):
      continue
    # Iterate outputs directly instead of indexing via range(len(...)).
    for out_tensor in op.outputs:
      if not self._skip_tensor(op_id, out_tensor, report_handler):
        traced_tensors.append(out_tensor)
  return traced_tensors

1360 

def _check_trace_files(self):
  """Checks if any requirements for trace files are satisfied."""
  trace_dir = self._parameters.trace_dir
  if not trace_dir:
    # traces will be written to stderr. No need to check trace files.
    return
  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
    # Output files are handled by tf.summary operations, no need to
    # precreate them.
    return
  if not gfile.Exists(trace_dir):
    file_io.recursive_create_dir(trace_dir)
    if not gfile.Exists(trace_dir):
      raise RuntimeError('Failed to create trace directory at %s' %
                         trace_dir)

1376 

def _create_temp_cache(self, num_traced_tensors, num_signatures, graph):
  """Creates a temporary cache with the given dimensions.

  Fills the self._temp_cache_var with num_traced_tensors tf.constant() ops
  that have shape of [num_signatures].
  Args:
    num_traced_tensors: Int, denoting total number of traced tensors.
    num_signatures: Int, denoting the number of statistics collected per
      tensors.
    graph: TensorFlow graph.
  """
  entry_init = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
                                    dtype=dtypes.float32,
                                    shape=[num_signatures])
  # Every slot references the same init constant; entries are replaced
  # individually as tensors get traced.
  self._temp_cache_var[graph] = [entry_init] * num_traced_tensors

1393 

def _determine_trace_and_create_report(self, graph, ops_in_exec_path,
                                       graph_summary_tag):
  """Work needs to be done prior to TPU or CPU tracing.

  Args:
    graph: tf.graph
    ops_in_exec_path: Set of operations in the execution path.
    graph_summary_tag: the summary tag name for the given graph.
  Returns:
    An instance of tensor_tracer_report.TensorTraceOrder, containing list of
    tensors to be traced with their topological order information.
  Raises:
    RuntimeError: If opname filtering is incorrectly set.
  """

  self._check_trace_files()

  graph_order = tensor_tracer_report.sort_tensors_and_ops(graph)
  tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION)

  report_handler = tensor_tracer_report.TTReportHandle()
  traced_tensors = self._determine_and_instrument_traced_tensors(
      graph_order, ops_in_exec_path, tensor_trace_points, report_handler)
  logging.info('TensorTracer is tracing %d tensors.', len(traced_tensors))
  if traced_tensors and tensor_tracer_flags.TT_CHECK_FILTER.value:
    raise RuntimeError('Verify ops being traced by tensor tracer.')

  tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
                                                             traced_tensors)
  num_signatures = self._num_signature_dimensions()
  # Create a cache variable if compact_tracing is used.
  if num_signatures and self._use_tensor_values_cache():
    if self._use_temp_cache():
      self._create_temp_cache(len(traced_tensors), num_signatures, graph)
    else:
      self._create_or_get_tensor_values_cache(
          _TT_SUMMARY_TAG, graph, [len(traced_tensors), num_signatures])
      # Bug fix: the original used `trace_mode in (TRACE_MODE_HISTORY)`;
      # the parentheses do not form a tuple, so that was a substring test
      # on a string. Use equality for a single-mode comparison.
      if (self._parameters.trace_mode ==
          tensor_tracer_flags.TRACE_MODE_HISTORY):
        self._create_or_get_tensor_history_values_cache(
            _TT_SUMMARY_TAG, graph, [len(traced_tensors), num_signatures])
  if self._parameters.trace_mode in (
      tensor_tracer_flags.TRACE_MODE_SUMMARY,
      tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY):
    self._report_proto = report_handler.create_report_proto(
        self._tt_config, self._parameters, tensor_trace_order,
        tensor_trace_points, self._signature_types())
    if self._parameters.use_fingerprint_subdir:
      self._parameters.trace_dir = os.path.join(
          self._parameters.trace_dir, self._report_proto.fingerprint)
      logging.info('TensorTracer updating trace_dir to %s',
                   self._parameters.trace_dir)
    self._report_proto_path = report_handler.report_proto_path(
        self._parameters.trace_dir, graph_summary_tag)

    if self._parameters.report_file_path != _SKIP_REPORT_FILE:
      report_handler.write_report_proto(self._report_proto_path,
                                        self._report_proto, self._parameters)
  else:
    # Same substring-test bug fix as above: compare with inequality rather
    # than `not in (X)` on a single string.
    if (self._parameters.trace_mode !=
        tensor_tracer_flags.TRACE_MODE_HISTORY):
      report_handler.create_report(self._tt_config, self._parameters,
                                   tensor_trace_order, tensor_trace_points)
  return tensor_trace_order

1458 

def _create_host_call(self):
  """Returns True when results are emitted via summary host calls."""
  summary_modes = (tensor_tracer_flags.TRACE_MODE_SUMMARY,
                   tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
  return self._parameters.trace_mode in summary_modes

1463 

def _inspect_summary_cache(self, cache, replica_id, step_num, output_stream,
                           tensor_trace_order):
  """Generates a print operation to print trace inspection.

  Args:
    cache: Tensor storing the trace results for the step.
    replica_id: Tensor storing the replica id of the running core.
    step_num: Step number.
    output_stream: Where to print the outputs, e.g., file path, or sys.stderr.
    tensor_trace_order: TensorTraceOrder object holding tensorname to id map.

  Returns:
    The Op to flush the cache to file.
  """
  def _inspect_tensor(tensor):
    """Returns the text to be printed for inspection output."""
    if (self._parameters.trace_mode ==
        tensor_tracer_flags.TRACE_MODE_NAN_INF):
      # Cache entries are 1.0/0.0 flags in NAN_INF mode; translate them to
      # a human-readable message.
      return cond.cond(
          math_ops.greater(tensor, 0.0),
          lambda: 'has NaNs/Infs!',
          lambda: 'has no NaNs or Infs.')
    else:
      return tensor

  # Check if there are graph operations being profiled.
  if not tensor_trace_order.traced_tensors:
    logging.warn('Inspect mode has no tensors in the cache to check.')
    # NOTE(review): this returns the no_op function itself rather than a
    # constructed op; confirm callers expect a callable here.
    return control_flow_ops.no_op

  # Check if the cache includes any nan or inf
  if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
    # Cache has 1s or 0s if the mode is NaN_INF
    step_has_nan_or_inf = math_ops.greater(math_ops.reduce_sum(cache), 0.0)
  else:
    # Cache has the actual numerics for other modes.
    step_has_nan_or_inf = math_ops.reduce_any(
        gen_math_ops.logical_or(
            gen_math_ops.is_nan(cache), gen_math_ops.is_inf(cache)))

  # Summarizing message for each step.
  step_error_message = cond.cond(
      step_has_nan_or_inf,
      lambda: 'NaNs or Infs in the step!',
      lambda: 'No numerical issues have been found for the step.')

  # No need to print core numbers if the cache is merged already.
  if self._parameters.collect_summary_per_core:
    stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num, '-->',
             step_error_message,
             'Printing tensors for mode:%s...' % self._parameters.trace_mode]
  else:
    stats = ['\n\n', 'step:', step_num, '-->', step_error_message,
             'Printing tensors for mode:%s...' % self._parameters.trace_mode]

  # Emit one line per traced tensor, in cache-index order.
  for tensor_name, cache_idx in sorted(
      tensor_trace_order.tensorname_to_cache_idx.items(),
      key=lambda item: item[1]):
    if self._parameters.collect_summary_per_core:
      stats.extend([
          '\n', 'core:', replica_id, ',', 'step:', step_num, ',',
          tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
    else:
      stats.extend([
          '\n', 'step:', step_num, ',',
          tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
  return logging_ops.print_v2(*stats, summarize=-1,
                              output_stream=output_stream)

1532 

  def _inspect_history_cache(self, cache, replica_id, step_num,
                             tensor_trace_order):
    """Generates a conditional print operation to log differences in tensor values.

    For every traced tensor, compares the current step's value against a
    snapshot variable holding the previous step's value, and prints all
    old/new/delta triples only when the largest delta in the step exceeds
    DELTA_THRESHOLD.

    Args:
      cache: Tensor storing the trace results for the step.
      replica_id: Tensor storing the replica id of the running core.
      step_num: Step number.
      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.

    Returns:
      The Op to flush the cache to file.
    """
    # Check if there are graph operations being profiled.
    if not tensor_trace_order.traced_tensors:
      logging.warn('TT history mode has no tensors in the cache to check.')
      return control_flow_ops.no_op

    stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num]
    diffs = []
    # Iterate tensors in cache-index order so the printed rows line up with
    # the cache layout.
    for tensor_name, cache_idx in sorted(
        tensor_trace_order.tensorname_to_cache_idx.items(),
        key=lambda item: item[1]):

      tensor_to_write = cache[cache_idx, 0]
      # Per-tensor variable holding the value observed on the previous step.
      snapshot_variable = self._create_or_get_tensor_history_values_cache(
          tensor_to_write.name, tensor_to_write.op.graph,
          tensor_to_write.shape.as_list(), tensor_to_write.dtype)

      # Read the previous value via assign_add(0.0) so the read is an op that
      # can be ordered explicitly in the control-dependency chain.
      with ops.control_dependencies([snapshot_variable]):
        old_value = state_ops.assign_add(snapshot_variable, 0.0)

      with ops.control_dependencies([old_value]):
        new_value = math_ops.cast(tensor_to_write, dtypes.float32)
        delta = math_ops.abs(math_ops.subtract(old_value, new_value))
        # Overwrite the snapshot only after the old value has been read.
        updated = state_ops.assign(snapshot_variable, new_value)
        diffs.append(delta)
      with ops.control_dependencies([updated]):
        # Re-read after the update so the printed "new" value reflects what
        # was actually stored in the variable.
        new_value_from_var = state_ops.assign_add(snapshot_variable, 0.0)

      stats.extend([
          '\n', 'core:', replica_id, ',', 'step:', step_num, ',',
          tensor_name, '-->', old_value, new_value_from_var, delta])

    # Print only when the largest per-tensor change exceeds the threshold.
    diff_stack = array_ops_stack.stack(diffs)
    step_max = math_ops.reduce_max(diff_stack)

    return cond.cond(
        math_ops.greater(step_max, tensor_tracer_flags.DELTA_THRESHOLD.value),
        lambda: logging_ops.print_v2(*stats, summarize=-1),
        lambda: control_flow_ops.no_op())  # pylint: disable=unnecessary-lambda

1584 

1585 def _get_outfile_suffix(self): 

1586 if remote_utils.is_remote_path(self._parameters.trace_dir): 

1587 return remote_utils.get_appendable_file_encoding() 

1588 else: 

1589 return '' 

1590 

  def _generate_flush_cache_op(self, num_replicas, on_tpu,
                               tensor_trace_order, graph):
    """Generates an Op that will flush the cache to file.

    Args:
      num_replicas: total number of replicas.
      on_tpu: if the graph is executed on TPU.
      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
      graph: TensorFlow graph.

    Returns:
      The Op to flush the cache to file.
    """

    def _flush_fun(cache, replica_id, step_num):
      """Flushes the cache to a file corresponding to replica_id."""

      def _f(file_index):
        """Generates a func that flushes the cache to a file."""
        def _print_cache():
          """Flushes the cache to a file."""
          # file_index is a Python int captured at graph-construction time, so
          # each branch writes to its own statically-named file.
          replica_str = ('%d' % file_index)
          if self._parameters.trace_dir:
            output_path = (os.path.join(self._parameters.trace_dir,
                                        _COMPACT_TRACE_FILE_PREFIX)
                           + replica_str + self._get_outfile_suffix())
            output_stream = _OUTPUT_STREAM_ESCAPE + output_path
          else:
            output_stream = sys.stderr

          new_step_line = _REPLICA_ID_TAG + replica_str
          print_ops = []
          if self._parameters.inspect_trace:
            if self._num_signature_dimensions() > 1:
              raise ValueError('Inspecting multi signatures are not supported.')
            # NOTE(review): `(tensor_tracer_flags.TRACE_MODE_HISTORY)` is not a
            # tuple, so this `in` is a substring test against the mode string;
            # presumably `== TRACE_MODE_HISTORY` was intended — confirm.
            if self._parameters.trace_mode in (
                tensor_tracer_flags.TRACE_MODE_HISTORY):
              print_ops.append(
                  self._inspect_history_cache(
                      cache=cache,
                      replica_id=replica_id,
                      step_num=step_num,
                      tensor_trace_order=tensor_trace_order))
            else:
              print_ops.append(
                  self._inspect_summary_cache(
                      cache=cache,
                      replica_id=replica_id,
                      step_num=step_num,
                      output_stream=output_stream,
                      tensor_trace_order=tensor_trace_order))
          else:
            # Raw dump: one print per signature column of the cache.
            for i in range(self._num_signature_dimensions()):
              print_ops.append(logging_ops.print_v2(
                  new_step_line, '\n',
                  cache[:, i], '\n',
                  summarize=-1,
                  output_stream=output_stream))
          with ops.control_dependencies(print_ops):
            return constant_op.constant(0).op
        return _print_cache

      def _eq(file_index):
        # Predicate for the case() branch: does this replica own file_index?
        return math_ops.equal(replica_id, file_index)

      flush_op_cases = {}
      flush_op_cases[_eq(0)] = _f(0)
      for i in range(1, num_replicas):
        if on_tpu and not self._parameters.collect_summary_per_core:
          # If this is the case, the cache is already merged for all cores.
          # Only first core flushes the cache.
          flush_op_cases[_eq(i)] = control_flow_ops.no_op
        else:
          flush_op_cases[_eq(i)] = _f(i)
      # Each replica needs to determine where to write their output.
      # To do this, we check if replica_id is 0, then 1, ..., and then
      # num_replicas - 1 statically; and return the corresponding static file
      # name. We cannot simply set the file name in python, as replica_id is
      # only known during tf runtime, and we cannot create dynamic filenames.
      return control_flow_case.case(flush_op_cases, exclusive=True)

    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph)
    # Temp caches are plain tensors; variable caches need .value() to read.
    if self._use_temp_cache():
      cache_val = cache
    else:
      cache_val = cache.value()

    if on_tpu:
      # If we do not need to collect traces for all cores, merge and aggregate
      # per core trace.
      if not self._parameters.collect_summary_per_core:
        cache_val = self.merge_caches_on_tpu(cache_val)
        cache_val = self.aggregate_global_cache(cache_val)[0]

      # Printing/file IO must happen on the host, hence outside compilation.
      flush_op = tpu_replication.outside_compilation(
          _flush_fun, cache_val, self._replica_id,
          array_ops.identity(training_util.get_or_create_global_step()))
    else:
      global_step = training_util.get_or_create_global_step()
      flush_op = _flush_fun(cache_val, self._replica_id, global_step)

    if self._use_temp_cache():
      with ops.control_dependencies([flush_op]):
        return constant_op.constant(0).op
    else:
      # Re-initialize the local cache variable.
      with ops.control_dependencies([flush_op]):
        reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
                                           dtype=cache.dtype,
                                           shape=cache.shape)
        assign_op = state_ops.assign(cache, reset_value).op
        with ops.control_dependencies([assign_op]):
          return constant_op.constant(0).op

1704 

1705 def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu, 

1706 tensor_trace_order, graph): 

1707 """Flushes the intermediate tensor values in the graph to the cache. 

1708 

1709 Args: 

1710 tensor_fetches: list of tensor results returned by the model_fn. 

1711 op_fetches: list of ops that are returned by the model_fn, e.g., train_op. 

1712 on_tpu: if the graph is executed on TPU. 

1713 tensor_trace_order: TensorTraceOrder object holding tensorname to id map. 

1714 graph: TensorFlow graph. 

1715 

1716 Returns: 

1717 An identical copy of tensor_fetches. 

1718 """ 

1719 # Add a dependency to op and tensor fetches to make sure that all tracing 

1720 # ops are executed before flushing trace results. 

1721 if not tensor_trace_order.traced_tensors: 

1722 logging.warn('No tensor values being traced. No flush cache op added.') 

1723 return tensor_fetches 

1724 with ops.control_dependencies(op_fetches + 

1725 [tensor.op for tensor in tensor_fetches]): 

1726 flush_cache_op = self._generate_flush_cache_op( 

1727 self._tt_config.num_replicas, on_tpu, tensor_trace_order, graph) 

1728 return control_flow_ops.tuple(tensor_fetches, 

1729 control_inputs=[flush_cache_op]) 

1730 

1731 def _process_tensor_fetches(self, tensor_fetches): 

1732 """Check that tensor_fetches is not empty and have valid tensors.""" 

1733 # If none or empty list. 

1734 if tensor_fetches is None: 

1735 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 

1736 'None.') 

1737 if not isinstance(tensor_fetches, (list, tuple)): 

1738 tensor_fetches = [tensor_fetches] 

1739 elif not tensor_fetches: 

1740 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 

1741 'empty list.') 

1742 fetches = [] 

1743 for fetch in tensor_fetches: 

1744 if isinstance(fetch, ops.Tensor): 

1745 fetches.append(fetch) 

1746 else: 

1747 raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch) 

1748 return fetches 

1749 

1750 def _process_op_fetches(self, op_fetches): 

1751 """Check that op_fetches have valid ops.""" 

1752 if op_fetches is None: 

1753 return [] 

1754 

1755 if not isinstance(op_fetches, (list, tuple)): 

1756 op_fetches = [op_fetches] 

1757 

1758 fetches = [] 

1759 for fetch in op_fetches: 

1760 if isinstance(fetch, ops.Operation): 

1761 fetches.append(fetch) 

1762 elif isinstance(fetch, ops.Tensor): 

1763 fetches.append(fetch.op) 

1764 else: 

1765 logging.warning('Ignoring the given op_fetch:%s, which is not an op.' % 

1766 fetch) 

1767 return fetches 

1768 

1769 def _convert_fetches_to_input_format(self, input_fetches, current_fetches): 

1770 """Changes current_fetches' format, so that it matches input_fetches.""" 

1771 if isinstance(input_fetches, ops.Tensor): 

1772 if len(current_fetches) != 1: 

1773 raise RuntimeError('Tensor tracer input/output fetches do not match.') 

1774 return current_fetches[0] 

1775 else: 

1776 if len(current_fetches) != len(current_fetches): 

1777 raise RuntimeError('Tensor tracer input/output fetches do not match.') 

1778 elif isinstance(input_fetches, tuple): 

1779 return tuple(current_fetches) 

1780 else: 

1781 return current_fetches 

1782 

1783 def _get_op_control_flow_context(self, op): 

1784 """Returns the control flow of the given op. 

1785 

1786 Args: 

1787 op: tf.Operation for which the control flow context is requested. 

1788 Returns: 

1789 op_control_flow_context: which the is control flow context of the given 

1790 op. If the operation type is LoopExit, returns the outer control flow 

1791 context. 

1792 """ 

1793 # pylint: disable=protected-access 

1794 op_control_flow_context = op._control_flow_context 

1795 # pylint: enable=protected-access 

1796 if control_flow_util.IsLoopExit(op): 

1797 op_control_flow_context = op_control_flow_context.outer_context 

1798 return op_control_flow_context 

1799 

1800 def merge_caches_on_tpu(self, local_tpu_cache_tensor): 

1801 """Merges the given caches on tpu. 

1802 

1803 Args: 

1804 local_tpu_cache_tensor: A local tensor that needs to be merged 

1805 by concanting data from other tpu cores. 

1806 Returns: 

1807 A merged tf.Tensor. 

1808 """ 

1809 x = array_ops.broadcast_to( 

1810 local_tpu_cache_tensor, 

1811 shape=[self._tt_config.num_replicas] + 

1812 local_tpu_cache_tensor.shape.as_list()) 

1813 

1814 if tensor_tracer_flags.TT_SINGLE_CORE_SUMMARIES.value: 

1815 return x 

1816 

1817 return tpu_ops.all_to_all( 

1818 x, concat_dimension=0, split_dimension=0, 

1819 split_count=self._tt_config.num_replicas, 

1820 group_assignment=[list(range(self._tt_config.num_replicas))]) 

1821 

1822 def aggregate_global_cache(self, global_tt_summary_cache): 

1823 """Merges the given caches on tpu. 

1824 

1825 Args: 

1826 global_tt_summary_cache: The global tensor tracer summary cache tensor 

1827 with shape (num_cores, num_traced_tensors, num_traced_signatures). First 

1828 dimension corresponds to core_id, where global_tpu_cache_tensor[i] 

1829 correspond to the local cache from core-i. 

1830 Returns: 

1831 An aggregated tf.Tensor. 

1832 Raises: 

1833 RuntimeError: if there is no aggregate function defined for a signature. 

1834 """ 

1835 

1836 # Merge only statistics tensor, if it is any other tensor we simply, 

1837 # concatenate them. 

1838 agg_fn_map = self._parameters.get_signature_to_agg_fn_map() 

1839 signature_idx_map = self._signature_types() 

1840 aggregation_result = [] 

1841 for signature, idx in sorted(signature_idx_map.items(), 

1842 key=operator.itemgetter(1)): 

1843 if signature not in agg_fn_map: 

1844 raise RuntimeError('No aggregation function is defined for ' 

1845 'signature %s.' % signature) 

1846 # The dimensions of the statistics tensor is 

1847 # num_cores x num_traced_tensors x num_signatures 

1848 # value[:,:,idx] will return the portion of the tensor related 

1849 # to signature. 

1850 signature_tensor = global_tt_summary_cache[:, :, idx] 

1851 # Merge it along the first (core) axis. 

1852 agg_fn = agg_fn_map[signature] 

1853 agg_tensor = agg_fn(signature_tensor, axis=0) 

1854 aggregation_result.append(agg_tensor) 

1855 # Merge results corresponding to different signatures 

1856 

1857 merged_signatures = array_ops_stack.stack(aggregation_result) 

1858 # merged_signatures has dimensions 

1859 # num_signatures x num_traced_tensors, transpose it so that it 

1860 # will match with the original structure 

1861 # num_traced_tensors x num_signatures. 

1862 transposed_signatures = array_ops.transpose(merged_signatures) 

1863 # Expand 1 more dimension so that it will match with the expected 

1864 # structure num_cores x num_traced_tensors x num_signatures. 

1865 return array_ops.expand_dims(transposed_signatures, axis=0) 

1866 

  def _prepare_host_call_fn(self, processed_t_fetches,
                            op_fetches, graph, graph_summary_tag):
    """Creates a host call function that will write the cache as tb summary.

    Stores (write_fn, kwargs) under _TT_HOSTCALL_KEY in self._host_call_fn;
    the caller decides when/where to invoke it (directly on CPU, or as a TPU
    host call).

    Args:
      processed_t_fetches: List of tensor provided to session.run.
      op_fetches: List of operations provided to session.run.
      graph: TensorFlow graph.
      graph_summary_tag: the summary_tag name for the given graph.
    Raises:
      ValueError if trace_dir is not set.
    """
    if self._parameters.trace_dir is None:
      raise ValueError('Provide a trace_dir for tensor tracer in summary mode. '
                       '--trace_dir=/model/dir')

    def _write_cache(step, event_file_suffix=None, **kwargs):
      """Writes the given caches as tensor summary.

      Args:
        step: Step tensor with dimension [num_cores].
        event_file_suffix: Event filename suffix tensor.
        **kwargs: The dictionary of tensors that needs to be written as
          summaries. Key and value pairs within kwargs correspond to the tag
          name, and tensor content that will be written using summary.write.
          The trace_modes that use this function are:
            - summary: In summary mode, kwargs includes a single (tag, content)
            pair which are, _TT_SUMMARY_TAG and a tf.float32 signature_cache
            variable. The dimension of the signature_cache is:
              num_cores x num_traced_tensors x num_signatures.
            - full_tensor_summary: kwargs will include all traced tensors. Tag
            and content correspond to the name of the tensor, and its actual
            content.
      Returns:
        A tf.Operation that needs to be executed for the host call dependencies.
      """
      file_suffix = _TT_EVENT_FILE_SUFFIX
      if event_file_suffix is not None:
        file_suffix = string_ops.string_join([file_suffix, event_file_suffix],
                                             separator='.')
      # TODO(deveci): Parametrize max_queue, so that flushing op can be called
      # less frequently.
      # Setting max_queue to 100 appears to be safe even when the number of
      # iterations are much lower, as the destructor of the writer flushes it.
      summary_write_ops = []
      summary_writer = summary.create_file_writer_v2(
          self._parameters.trace_dir,
          filename_suffix=file_suffix,
          max_queue=_TT_SUMMARY_MAX_QUEUE)
      # Keep the writer reachable from the graph so it is not garbage
      # collected before the summaries are flushed.
      graph.add_to_collection(
          TENSOR_TRACER_SUMMARY_COLLECTION, summary_writer)

      step_value = step[0]
      dt = step_value.dtype

      # The step parameter to a summary write call must be 64-bit.
      if dt.__ne__(dtypes.int64) and dt.__ne__(
          dtypes.uint64) and dt.__ne__(dtypes.float64):
        step_value = math_ops.cast(step_value, dtypes.int64)

      with summary_writer.as_default():
        summary_metadata = summary_pb2.SummaryMetadata(
            plugin_data=summary_pb2.SummaryMetadata.PluginData(
                plugin_name=_TT_TENSORBOARD_PLUGIN_NAME))
        for key, value in kwargs.items():
          # Check whether we need to compute aggregated statistics that merge
          # all cores statistics.
          if not self._parameters.collect_summary_per_core:
            # Merge only statistics tensor, if it is any other tensor we simply,
            # concatenate them.
            # Also, if there is only a single core (first dim. is 0), then skip
            # aggregation.
            if key == _TT_SUMMARY_TAG and value.shape.as_list()[0] != 1:
              value = self.aggregate_global_cache(value)
          # Ensure the writer is initialized before each write op runs.
          with ops.control_dependencies([summary_writer.init()]):
            summary_write_ops.append(summary.write(
                _TT_SUMMARY_TAG + '/' + key + '.' + graph_summary_tag,
                value, metadata=summary_metadata,
                step=step_value))
      return control_flow_ops.group(summary_write_ops)

    global_step = training_util.get_or_create_global_step()
    step = array_ops.reshape(global_step, [1])
    self._host_call_fn = {}

    # The cache snapshots must run after every fetch has executed.
    host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches]

    caches_to_write = {}
    with ops.control_dependencies(host_call_deps):
      all_caches = self._cache_variable_for_graph(graph)
      for cache_name, cache_variable in all_caches.items():
        # Increase the cache rank by 1, so that when host call concatenates
        # tensors from different replicas, we can identify them with [core_id].
        new_cache_shape = [1]
        new_cache_shape.extend(cache_variable.shape.as_list())
        cache = array_ops.reshape(cache_variable, new_cache_shape)
        caches_to_write[cache_name] = cache
    # Add step to parameter dictionary.
    caches_to_write['step'] = step
    # Other options without adding step to parameter dictionary are
    # * host_call_fn = (_write_cache(step, caches_to_write)) : fails as it
    #   considers caches_to_write as a single parameter, rather than a keyword
    #   parameters.
    # * host_call_fn = (_write_cache(step, **caches_to_write)) : fails with
    #   a syntax error.
    self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write)

1973 

1974 def host_call_deps_and_fn(self): 

1975 return self._host_call_fn 

1976 

1977 def get_traced_op_names(self): 

1978 """Returns the set of traced op names.""" 

1979 return self._traced_op_names 

1980 

  def _trace_execution(self, graph,
                       tensor_fetches,
                       op_fetches=None,
                       on_tpu=True):
    """Common tracing function for both CPU and TPUs.

    The caller function should set device_type, num_replicas,
    num_replicas_per_host, num_hosts and replica_id before calling
    _trace_execution.


    Args:
      graph: the graph of Ops executed on the TPU.
      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
        returned by model_fn given to session.run. Function must be provided
        with as least one tensor to fetch.
      op_fetches: A list of op fetches returned by model_fn given to
        session.run. op_fetches and tensor_fetches are used to determine the
        nodes that will be executed. Can be None.
      on_tpu: True if executing on TPU.

    Returns:
      tensor_fetches: an exact copy of tensor_fetches that has additional
        dependencies.
    Raises:
      RuntimeError: If tensor_fetches is None or empty.
    """
    def _cast_unsupported_dtypes(tensor):
      """Casts tensor to a supported type."""

      if tensor.dtype.__eq__(dtypes.int64):
        # outside-compilation doesn't support int64 input yet.
        return math_ops.cast(tensor, dtypes.int32)
      if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
          dtypes.float16):
        # Since host can't handle bf16, convert tensor to f32.
        return math_ops.cast(tensor, dtypes.float32)
      return tensor

    trace_mode = self._parameters.trace_mode
    device_type = self._tt_config.device_type
    # Remember the graph's outermost control flow context so we can restore
    # it after temporarily entering per-op contexts below.
    # pylint: disable=protected-access
    self._outmost_context = graph._get_control_flow_context()
    # pylint: enable=protected-access

    analytics.track_usage('tensor_tracer', [trace_mode, device_type])
    TensorTracer.check_device_type(device_type)
    TensorTracer.check_trace_mode(device_type, trace_mode)
    # Check in_tensor_fetches, and op_fetches and convert them to lists.
    processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
    op_fetches = self._process_op_fetches(op_fetches)
    all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]

    # Filter out the operations that won't be executed.
    # if fetches=None, then ops_in_exec_path = set(operations)
    exec_op_set = self._filter_execution_path_operations(graph.get_operations(),
                                                         all_fetches)
    graph_summary_tag = _graph_summary_tag(graph)

    # Write report file, and determine the traced tensors.
    tensor_trace_order = self._determine_trace_and_create_report(
        graph, exec_op_set, graph_summary_tag)

    tensor_fetch_set = set(processed_t_fetches)
    tracing_ops = []

    # Sort by name for a deterministic tracing order across runs.
    sorted_exec_op_list = list(exec_op_set)
    sorted_exec_op_list.sort(key=lambda op: op.name)
    # Trace ops only if they are in the execution path.
    for op in sorted_exec_op_list:
      for i in range(len(op.outputs)):
        out_tensor = op.outputs[i]
        tensor_name = out_tensor.name
        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
          continue
        self._traced_op_names.add(op.name)
        # Create the list of consumers before calling _preprocess_traced_tensor.
        # Otherwise, adding control input below, will introduce a cycle in the
        # graph.
        consumers = out_tensor.consumers()
        # Not all consumers may be in the exec path. Filter out the consumers
        # to keep the graph simpler.
        consumers = [cop for cop in consumers if cop in exec_op_set]

        # If there is no consumer of the tensor, there is no need to trace it;
        # unless the tensor itself is one of the fetches.
        is_a_fetched_tensor = out_tensor in tensor_fetch_set
        if (not consumers) and (not is_a_fetched_tensor):
          continue

        # Trace ops must be created inside the same control flow context as
        # the op being traced, otherwise the graph is invalid.
        op_control_flow_context = self._get_op_control_flow_context(op)
        if op_control_flow_context:
          # pylint: disable=protected-access
          graph._set_control_flow_context(op_control_flow_context)
          # pylint: enable=protected-access

        processed_tensors = self._preprocess_traced_tensor(out_tensor)

        if on_tpu:
          for signature in processed_tensors.keys():
            processed_tensors[signature] = _cast_unsupported_dtypes(
                processed_tensors[signature])

        if self._use_tensor_values_cache():
          # Use a small cache (either temp cache or tf local variable) to store
          # the characteristics of the tensor.
          if self._use_temp_cache():
            cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
            self._save_tensor_value_to_tmp_cache(cache_idx,
                                                 processed_tensors,
                                                 graph)
            # Temp-cache writes are pure Python bookkeeping; there is no graph
            # op to wire in as a control dependency.
            trace_op = None
          else:
            cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
            trace_op = self._save_tensor_value_to_cache_op(cache_idx,
                                                           processed_tensors,
                                                           graph)
        elif self._use_tensor_buffer():
          if len(processed_tensors) != 1:
            raise RuntimeError('Multiple stats are only allowed in compact '
                               'mode.')
          processed_out_tensor = list(processed_tensors.values())[0]
          # Store the whole tensor in a buffer.
          trace_op = self._snapshot_tensor(processed_out_tensor)
        else:

          def tpu_wrap_trace_fn(tensor, out_tensor_name):
            """Wraps the trace_fn with outside compilation if on TPUs."""
            tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name,
                                                          tensor_trace_order)
            if on_tpu:
              return tpu_replication.outside_compilation(
                  tensor_trace_fn, tensor)
            else:
              return tensor_trace_fn(tensor)

          if len(processed_tensors) != 1:
            raise RuntimeError('Multiple stats are only allowed in compact '
                               'mode.')
          # Collecting multiple statistics are only supported in the summary
          # mode that uses compact format(self._use_tensor_values_cache = true).
          # Non-compact mode currently allows single stat per tensor.
          processed_out_tensor = next(iter(processed_tensors.values()))
          trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name)

        # Restore the outermost context before processing the next op.
        if op_control_flow_context:
          # pylint: disable=protected-access
          graph._set_control_flow_context(self._outmost_context)
          # pylint: enable=protected-access
        if trace_op:
          if is_a_fetched_tensor:
            tracing_ops.append(trace_op)
            continue
          # Add it to all consumers, as some consumers may not be executed if
          # they are in a control flow.
          for consumer_op in consumers:
            # pylint: disable=protected-access
            consumer_op._add_control_input(trace_op)
            # pylint: enable=protected-access

    # pylint: disable=protected-access
    graph._set_control_flow_context(self._outmost_context)
    # pylint: enable=protected-access
    if tracing_ops:
      # If we are tracing a fetched tensor, their dependency is stored in
      # tracing_ops.
      processed_t_fetches = control_flow_ops.tuple(processed_t_fetches,
                                                   control_inputs=tracing_ops)
    if self._use_tensor_values_cache() or self._use_tensor_buffer():
      if self._use_temp_cache():
        # Create the temporary tf cache variable by concantanating all
        # statistics.
        graph_cache_var = self._cache_variable_for_graph(graph)
        if graph not in self._temp_cache_var:
          raise RuntimeError('graph is not in self._temp_cache_var')
        graph_cache_var[_TT_SUMMARY_TAG] = array_ops_stack.stack(
            self._temp_cache_var[graph], axis=0, name='stack_all_op_signatures')
      if self._create_host_call():
        self._prepare_host_call_fn(processed_t_fetches, op_fetches, graph,
                                   graph_summary_tag)
        if not on_tpu:
          # On CPU there is no host call mechanism; invoke the cache writer
          # directly and make the fetches depend on it.
          write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
          cache_write_op = write_cache(**caches_to_write)
          processed_t_fetches = control_flow_ops.tuple(
              processed_t_fetches, control_inputs=[cache_write_op])
          del self._host_call_fn[_TT_HOSTCALL_KEY]
        elif self._parameters.flush_summaries_with_outside_compile:
          write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
          if (_TT_SUMMARY_TAG in caches_to_write and 'step' in caches_to_write):
            step = caches_to_write['step']
            tensor_tracer_summary = caches_to_write[_TT_SUMMARY_TAG]
            tt_core_summary = self.merge_caches_on_tpu(tensor_tracer_summary[0])
            if not self._parameters.collect_summary_per_core:
              tt_core_summary = self.aggregate_global_cache(tt_core_summary)

            def write_if_core_0(step, replica_id, tt_summary):
              # After merging, every core holds the same data; only core 0
              # writes to avoid duplicate summaries.

              return cond.cond(
                  math_ops.equal(replica_id, 0),
                  lambda: write_cache(step=step, event_file_suffix=None,  # pylint: disable=g-long-lambda
                                      tensor_tracer_summary=tt_summary),
                  control_flow_ops.no_op)

            write_op = tpu_replication.outside_compilation(
                write_if_core_0,
                step=step,
                replica_id=self._replica_id,
                tt_summary=tt_core_summary)
            processed_t_fetches = control_flow_ops.tuple(
                processed_t_fetches, control_inputs=[write_op])
            del self._host_call_fn[_TT_HOSTCALL_KEY]
          else:
            raise ValueError('Outside compiled flush in only supported for '
                             'summary mode')
      else:
        processed_t_fetches = self._flush_tensor_values_cache(
            processed_t_fetches, op_fetches, on_tpu=on_tpu,
            tensor_trace_order=tensor_trace_order,
            graph=graph)

    # processed_t_fetches is a list at this point. Convert it to the same
    # format as given in tensor_fetches.
    return self._convert_fetches_to_input_format(tensor_fetches,
                                                 processed_t_fetches)

2205 

  def trace_tpu(self, graph,
                tensor_fetches,
                op_fetches=None,
                num_replicas=None,
                num_replicas_per_host=None,
                num_hosts=None):
    """Traces the tensors generated by TPU Ops in a TF graph.

    Args:
      graph: the graph of Ops executed on the TPU.
      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
        returned by model_fn given to session.run. Function must be provided
        with as least one tensor to fetch.
      op_fetches: A list of op fetches returned by model_fn given to
        session.run. op_fetches and tensor_fetches are used to determine the
        nodes that will be executed. Can be None.
      num_replicas: number of replicas used on the TPU.
      num_replicas_per_host: number of replicas per TPU host.
      num_hosts: total number of TPU hosts.

    Returns:
      tensor_fetches: an exact copy of tensor_fetches that has additional
        dependencies.
    """
    # FuncGraphs cannot be rewritten by tensor tracer; return unchanged.
    if isinstance(graph, func_graph.FuncGraph) or isinstance(
        graph, function._FuncGraph):  # pylint: disable=protected-access
      logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
                      'Ignoring tracing.')
      return tensor_fetches

    # Each graph is rewritten at most once (tracked at class level).
    if graph in TensorTracer._traced_graphs:
      logging.warning('Graph is already rewritten with tensor tracer, ignoring '
                      'multiple calls.')
      return tensor_fetches
    else:
      TensorTracer._traced_graphs.add(graph)
    # Reset the parameters in case parameters are changed.
    self._parameters = tensor_tracer_flags.TTParameters()
    self._tt_config.device_type = _DEVICE_TYPE_TPU
    self._tt_config.num_replicas = num_replicas
    self._tt_config.num_replicas_per_host = num_replicas_per_host
    self._tt_config.num_hosts = num_hosts
    if self._tt_config.num_replicas is not None:
      if self._tt_config.num_replicas_per_host is None:
        # Default of 8 replicas per host when not specified.
        self._tt_config.num_replicas_per_host = 8
      if self._tt_config.num_hosts is None:
        # Ceiling division: the boolean remainder-check adds 1 when
        # num_replicas does not divide evenly.
        self._tt_config.num_hosts = (
            num_replicas // self._tt_config.num_replicas_per_host +
            (num_replicas % self._tt_config.num_replicas_per_host > 0))

    # Optionally dump the graph before and after the rewrite for debugging.
    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,
                           'graph_before_tt.pbtxt')
    with graph.as_default():
      self._add_replica_id_to_graph()
      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
                                             on_tpu=True)
    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,
                           'graph_after_tt.pbtxt')
    return tensor_fetches

2267 

def trace_cpu(self, graph, tensor_fetches, op_fetches=None):
  """Rewrites a CPU graph so that traced tensor values are materialized.

  Args:
    graph: the graph of Ops executed on the CPU.
    tensor_fetches: a (list, tuple, or a single object) of tensor fetches
      returned by model_fn given to session.run. Function must be provided
      with as least one tensor to fetch.
    op_fetches: A list of op fetches returned by model_fn given to
      session.run. op_fetches and tensor_fetches are used to determine the
      nodes that will be executed. Can be None.

  Returns:
    tensor_fetches: an exact copy of tensor_fetches that has additional
      dependencies.
  """
  # Guard: neither the modern nor the legacy FuncGraph flavor is traceable.
  if isinstance(graph, (func_graph.FuncGraph, function._FuncGraph)):  # pylint: disable=protected-access
    logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
                    'Ignoring tracing.')
    return tensor_fetches

  # Guard: rewrite any given graph at most once.
  if graph in TensorTracer._traced_graphs:
    logging.warning('Graph is already rewritten with tensor tracer, ignoring '
                    'multiple calls.')
    return tensor_fetches
  TensorTracer._traced_graphs.add(graph)

  # Re-read the tracing flags in case they changed since construction.
  self._parameters = tensor_tracer_flags.TTParameters()

  # A CPU run is modeled as a single replica on a single host.
  self._replica_id = 0
  self._tt_config.device_type = _DEVICE_TYPE_CPU
  self._tt_config.num_replicas = 1
  self._tt_config.num_replicas_per_host = 1
  self._tt_config.num_hosts = 1

  # Optionally dump the graph before and after the rewrite for debugging.
  if self._parameters.graph_dump_path:
    graph_io.write_graph(graph, self._parameters.graph_dump_path,
                         'graph_before_tt.pbtxt')
  with graph.as_default():
    tensor_fetches = self._trace_execution(
        graph, tensor_fetches, op_fetches, on_tpu=False)
  if self._parameters.graph_dump_path:
    graph_io.write_graph(graph, self._parameters.graph_dump_path,
                         'graph_after_tt.pbtxt')
  return tensor_fetches