Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tensor_tracer.py: 18%
953 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ========================================================================
15"""A utility to trace tensor values on TPU."""
17import collections
18import hashlib
19import operator
20import os
21import os.path
22import sys
24import numpy as np
26from tensorflow.core.framework import summary_pb2
27from tensorflow.python.eager import monitoring
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import func_graph
31from tensorflow.python.framework import function
32from tensorflow.python.framework import graph_io
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import tensor_util
35from tensorflow.python.lib.io import file_io
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import array_ops_stack
38from tensorflow.python.ops import cond
39from tensorflow.python.ops import control_flow_case
40from tensorflow.python.ops import control_flow_ops
41from tensorflow.python.ops import control_flow_util
42from tensorflow.python.ops import gen_math_ops
43from tensorflow.python.ops import init_ops
44from tensorflow.python.ops import linalg_ops
45from tensorflow.python.ops import logging_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import nn_impl
48from tensorflow.python.ops import state_ops
49from tensorflow.python.ops import string_ops
50from tensorflow.python.ops import summary_ops_v2 as summary
51from tensorflow.python.ops import variable_scope
52from tensorflow.python.platform import analytics
53from tensorflow.python.platform import gfile
54from tensorflow.python.platform import remote_utils
55from tensorflow.python.platform import tf_logging as logging
56from tensorflow.python.summary import summary_iterator
57from tensorflow.python.tpu import tensor_tracer_flags
58from tensorflow.python.tpu import tensor_tracer_report
59from tensorflow.python.tpu import tpu_replication
60from tensorflow.python.tpu.ops import tpu_ops
61from tensorflow.python.training import training_util
_DEVICE_TYPE_TPU = 'tpu'
_DEVICE_TYPE_CPU = 'cpu'
# Number of tensor elements used by the 'part-tensor' trace mode.
_TRACE_MODE_PART_TENSOR_SIZE = 3

# Reason strings recorded in the report for why an op was or was not traced.
_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
_REASON_CONTROLFLOW_OP = 'not-traced-control-flow-op'
_REASON_IN_CONTROL_FLOW = 'not-traced-in-control-flow'
_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar'
_REASON_SKIP_SCALAR = 'not-traced-scalar'
_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op'
_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch'
_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
_REASON_SCALAR_GET_TRACED = 'traced-scalar'
_REASON_TENSOR_GET_TRACED = 'traced-tensor'
_REASON_USER_INCLUDED = 'traced-user-included'
_REASON_USER_EXCLUDED = 'not-traced-user-excluded'
_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path'
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'

# Prefix that marks a trace output destination as a file path.
_OUTPUT_STREAM_ESCAPE = 'file://'
# Graph collection names used to register traced tensors / summary writers.
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
TENSOR_TRACER_SUMMARY_COLLECTION = 'tensor_tracer_summary_writers'
_TRACE_FILE_NAME = 'trace.all'
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
# Initial value for every compact-trace cache entry (-1.0 marks "unwritten").
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage'
_TT_SNAPSHOT = 'tensor_tracer_snapshot'
_REPLICA_ID_TAG = '#replica-id: '
_SKIP_REPORT_FILE = 'None'  # Do not write report proto if --report_file=None

# Signature names re-exported from tensor_tracer_flags for local use.
_TT_SUMMARY_NORM = tensor_tracer_flags.TT_SUMMARY_NORM
_TT_SUMMARY_MAX = tensor_tracer_flags.TT_SUMMARY_MAX
_TT_SUMMARY_MAX_ABS = tensor_tracer_flags.TT_SUMMARY_MAX_ABS
_TT_SUMMARY_MIN = tensor_tracer_flags.TT_SUMMARY_MIN
_TT_SUMMARY_MEAN = tensor_tracer_flags.TT_SUMMARY_MEAN
_TT_SUMMARY_VAR = tensor_tracer_flags.TT_SUMMARY_VAR
_TT_SUMMARY_SIZE = tensor_tracer_flags.TT_SUMMARY_SIZE
_TT_SUMMARY_SPARSITY = tensor_tracer_flags.TT_SUMMARY_SPARSITY

# Names/keys used for the summary trace mode and its TensorBoard plugin.
_TT_SUMMARY_TAG = 'tensor_tracer_summary'
_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer'
_TT_HOSTCALL_KEY = 'tensor_tracer_host_call'
_TT_EVENT_FILE_SUFFIX = '.tensor_tracer'

_TT_SUMMARY_MAX_QUEUE = 10

# Metrics gauge recording that Tensor Tracer V1 was enabled in this process.
tt_gauge = monitoring.BoolGauge('/tensorflow/api/tensor_tracer/v1',
                                'tensor tracer usage', 'method')
116def _graph_summary_tag(graph):
117 """Generates and returns a summary tag name for the given graph."""
119 if graph is None:
120 raise RuntimeError('graph is None')
121 # The chance of collision with md5 is effectively 0.
122 hash_id = hashlib.md5()
123 hash_id.update(repr(graph).encode('utf-8'))
124 # hexdigest() returns a string.
125 return hash_id.hexdigest()
def set_parameters(tensor_tracer_params=None):
  """Enables tensor tracer and sets its parameters.

  Example usage:
    tensor_tracer_parameters = {'trace_dir': '/usr/tmp/trace_dir',
                                'trace_mode': 'norm',
                                'report_file': '/usr/tmp/trace_dir/report.all'}
    tensor_tracer.set_parameters(tensor_tracer_parameters)

  This only stores the configuration. To debug on CPUs and GPUs the graph
  must additionally be rewritten via:
    tt = tensor_tracer.TensorTracer()
    loss = tt.trace_cpu(tf.get_default_graph(), tensor_fetches=loss)
  On TPUs that call can be skipped, as it is hooked into tpu.rewrite.

  Args:
    tensor_tracer_params: Tensor tracer parameter dictionary. Commonly used
      keys (see tensor_tracer_report.py for all parameters):
      - enable: If set, tensor tracer will be enabled.
      - trace_mode: One of 'summary' (multiple statistics written as
        tensorboard summaries), 'norm' (default; tensor norms written to a
        text file under trace_dir), 'nan-inf' (NaN/Inf existence check),
        'max-abs', 'full-tensor', 'part-tensor', or 'full_tensor_summary'
        (full tensors as binary event files, readable with
        tensor_tracer.read_tensor_tracer_event_file).
      - report_file: Path of the metadata file written during graph
        construction; printed to stdout if unset.
      - trace_dir: Directory for execution traces; written to stderr if
        unset.
      - trace_level: Priority threshold 0..7. trace_level=7 traces every op;
        lower levels progressively skip constants, identities, concat ops,
        additions, etc. Default is 3.
      - excluded_opnames / excluded_optypes: Regexes of op names/types that
        must not be traced.
      - included_opnames / included_optypes: Regexes of op names/types that
        are forced to be traced.
      - flush_summaries: If 1 in summary mode, flushes summaries using
        outside compilation (required with low level APIs).
      Advanced keys: trace_scalar, op_range (in the form '%d:%d'),
      submode ('brief' or 'detailed'), use_fingerprint_subdirectory.
  """
  flag_parts = ['--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE]
  if tensor_tracer_params:
    flag_parts.extend('--%s=%s' % (name, value)
                      for name, value in tensor_tracer_params.items())
  # TTParameters reads its flags back from this environment variable.
  os.environ[tensor_tracer_flags.FLAGS_ENV_VAR] = ' '.join(flag_parts)
def op_priority(op_type):
  """Returns the priority of the op.

  If the priority of the op is k, it will be traced if trace_level>=k.

  Args:
    op_type: String name of the operation type.
  Returns:
    Integer value corresponding the priority of the op.
  """
  # Each entry maps a priority level to the op types traced only at that
  # level or above; checked from least to most interesting.
  priority_table = (
      # Constant-like ops across different steps; traced only if
      # trace_level>=7.
      (7, ('Const', 'Shape', 'BroadcastGradientArgs', 'Range',
           'VariableShape', 'Fill', 'OneHot', 'ShapeN')),
      # Operations without numerical effects; traced if trace_level>=6.
      (6, ('Identity', 'Cast', 'Reshape', 'ExpandDims', 'StopGradient',
           'PreventGradient', 'Squeeze', 'Gather', 'GatherNd')),
      # Operations that merge or slice an input; traced if trace_level>=5.
      (5, ('ConcatV2', 'Concat', 'StridedSlice', 'Slice', 'Pack', 'Tile',
           'CollectivePermute', 'SplitV', 'DynamicPartition')),
      # Operations less likely to provide useful information; traced if
      # trace_level>=4.
      (4, ('Pad', 'RandomUniformInt', 'GreaterEqual')),
      # Add operations that are less likely to create issues; traced if
      # trace_level>=3 (default=3).
      (3, ('Sum', 'AddV2', 'Add', 'AddN', 'BiasAdd', 'CrossReplicaSum')),
      # Sub operations; traced if trace_level>=2.
      (2, ('Neg', 'Sub')),
      # Multiplication and some other operations; traced if trace_level>=1.
      (1, ('Mul', 'Square', 'MatMul', 'RandomUniform', 'Select', 'Maximum',
           'Mean', 'Variance', 'Exp', 'Rsqrt')),
  )
  for priority, op_types in priority_table:
    if op_type in op_types:
      return priority
  # Unclassified op_types default to being traced at level 2 and above.
  return 2
def read_tensor_tracer_event_file(event_file):
  """Reads the event file written by tensor tracer.

  This can be used to read the full tensors written into binary event files by
  by TensorTracer with trace_mode=full_tensor_summary.

  Example usage:
    result_dict_list = tensor_tracer.read_tensor_tracer_event_file(
      event_file_path)
    for result_dict in result_dict_list:
      for step, tensor_dict in result_dict.items():
        for tensor_name, full_tensor_content in tensor_dict.items():
          logging.info(tensor_name, full_tensor_content)

  Args:
    event_file: Path to the event file that contains only tensor tracer events.
  Returns:
    A list of event dictionaries, each of which with the form:
    {step_number: {tensor_name: tensor_content}}. This is a list instead of
    a single event dictionary because it is possible that an event file may
    have multiple event traces, each of them covering the same step ranges.
  Raises:
    ValueError: If an unexpected trace is found.
  """
  # Keeps track of how many times that a step number shows up in these events.
  step_occurrence_count = collections.defaultdict(int)

  # List of step occurrences. The k-th entry collects tensors from the k-th
  # time each step number was seen (i.e., the k-th trace over the same steps).
  step_occurrence_list = []
  for trace_event in summary_iterator.summary_iterator(event_file):
    # First event is an event with file_version: "brain.Event:2"
    if not trace_event.HasField('summary'):
      continue
    # Tensor tracer writes exactly one summary value per event.
    if len(trace_event.summary.value) != 1:
      raise ValueError('Single step contains %d summary values,'
                       ' expected 1.' % len(trace_event.summary.value))
    step = trace_event.step
    step_occurrence_count[step] += 1  # a new occurrence for this step.

    occurrence_idx = step_occurrence_count[step] - 1
    occurrence_size = len(step_occurrence_list)

    if occurrence_idx == occurrence_size:
      # This particular occurrence isn't yet recorded on step_occurrence_list.
      # So append this new occurrence to the end of step_occurrence_list.
      new_occurrence = collections.defaultdict(dict)
      step_occurrence_list.append(new_occurrence)
    else:
      # This particular occurrence must be already recorded on
      # step_occurrence_list (i.e. occurrence_idx < occurrence_size).
      if occurrence_idx > occurrence_size:
        raise ValueError('Unexpected: occurrence_idx (%d) > '
                         'occurrence_size (%d)' % (occurrence_idx,
                                                   occurrence_size))
    tensor_value = trace_event.summary.value[0]
    tensor_name = tensor_value.tag

    # Rebuild the numpy array from the raw bytes using the recorded shape and
    # dtype of the serialized tensor proto.
    real_shape = [d.size for d in tensor_value.tensor.tensor_shape.dim]
    tensor_content = np.frombuffer(
        tensor_value.tensor.tensor_content,
        dtypes.DType(tensor_value.tensor.dtype).as_numpy_dtype()
        ).reshape(real_shape)
    step_occurrence_list[occurrence_idx][step][tensor_name] = tensor_content
  return step_occurrence_list
def trace_tensor(tensor, tracepoint_name=None):
  """Programmatic interface to trace a tensor with Tensor Tracer.

  Tensor Tracer, by default, traces all tensors in the execution. This function
  can be used to limit traced tensors. If this function is called for a subset
  of the tensors, only those will be traced.

  For example, Tensor Tracer will only trace c below.
    c = tf.MatMul(a, b)
    tensor_tracer.trace_tensor(c)
    d = tf.add(c, 1)

  Args:
    tensor: the tensor object for which the tracing is requested.
    tracepoint_name: an optional tensor tracepoint name string. A tracepoint
      name is an Tensor Tracer internal name for the tensor. It is useful when
      comparing equivalent traces from different models that have different
      tensor namings. Equivalent tensors (with different names) can be mapped
      to each other by assigning a common tracepoint_name.

  Returns:
    The provided tensor.
  """
  if tracepoint_name is None:
    tracepoint_name = tensor.name
  # Register the (tensor, tracepoint_name) pair on the tensor's graph.
  # Note: a preceding no-op get_collection() call (which only returned a copy
  # of the collection and had no side effect) was removed;
  # add_to_collection() creates the collection if it does not exist yet.
  tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
                                 (tensor, tracepoint_name))
  return tensor
def keras_layer_tracepoint(layer, checkpoint_name):
  """An interface for adding the tensor outputs of a keras layer.

  Encapsulates trace_tensor.

  Args:
    layer: A keras layer.
    checkpoint_name: a string name for the checkpoint. This name has to be a
      unique name if used within model comparison. The tensors that have the
      same checkpoint identifier is compared in model comparison.

  Returns:
    The provided layer.
  """
  try:
    outputs = layer.output
    if tensor_util.is_tf_type(outputs):
      trace_tensor(outputs, '%s' % (checkpoint_name))
    else:
      for idx, output_tensor in enumerate(outputs):
        # Bug fix: the original code tested `outputs` (the whole container,
        # which is never a tf type on this branch) instead of
        # `output_tensor`, so outputs of multi-output layers were silently
        # never traced.
        if tensor_util.is_tf_type(output_tensor):
          trace_tensor(output_tensor, '%s_%d' % (checkpoint_name, idx))
  except AttributeError:
    # Layer has no `output` attribute (e.g., not yet built); best-effort.
    pass
  except RuntimeError:
    # `layer.output` can raise RuntimeError when the layer is unconnected.
    pass
  return layer
397class TensorTracer:
398 """A software construct for tracing tensor values in a TF graph.
400 This utility is disabled by default. It is hooked into tpu.rewrite, so it can
401 easily be enabled on TPUs by setting the TENSOR_TRACER_FLAGS env variable as
402 below without a code change.
403 export TENSOR_TRACER_FLAGS="--enable=1"
405 Below is the use example to enable it on CPUs or GPUs, or for more advance use
406 cases on TPUs.
408 a = x + 1
409 b = a * 2
410 rs = tf.reduce_sum(b)
411 tensor_tracer.set_parameters({'trace_dir': 'path/to/trace_dir',
412 'report_file: 'path/to/report/file'})
413 tt = tensor_tracer.TensorTracer()
414 if on_tpu:
415 rs = tt.trace_tpu(tf.get_default_graph(),
416 tensor_fetches=rs)
417 else:
418 rs = tt.trace_cpu(tf.get_default_graph(),
419 tensor_fetches=rs)
420 session.run(rs)
422 If it is enabled, it will trace the output tensor values of
423 selected Ops in the graph. It has two outputs: (1) the traces and (2)
424 a report. The traces are dumped to a specified directory during the graph
425 execution, while the report is dumped during the graph construction.
426 By passing options via the env variable, users can change:
427 (1) the trace mode (e.g., detecting NaN/Inf, printing partial or
428 full tensor values)
429 (2) which Ops to be traced (via op.name or op.type)
430 (3) output trace file path.
432 """
433 # The set of graphs that are rewritten by tensor tracer.
434 _traced_graphs = set()
436 @staticmethod
437 def is_enabled():
438 """Returns True if TensorTracer is enabled."""
439 try:
440 enable = tensor_tracer_flags.TTParameters().is_enabled()
441 # Add metrics to determine API usage.
442 if enable: tt_gauge.get_cell('is_enabled').set(True)
443 return enable
444 except (ValueError, RuntimeError) as e:
445 logging.warning(
446 'Tensor Tracer V1 flags processing error encountered in is_enabled '
447 'check. %s', e)
448 # TODO(b/210212559): Find a more robust fix.
449 # Should only produce exception if Tensor Tracer is enabled.
450 return True
452 @staticmethod
453 def check_device_type(device_type):
454 """Checks if the given device type is valid."""
456 if device_type not in (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU):
457 raise ValueError('Invalid device_type "%s"'%device_type)
459 @staticmethod
460 def check_trace_mode(device_type, trace_mode):
461 """Checks if the given trace mode work on the given device type.
463 Args:
464 device_type: Device type, TPU, GPU, CPU.
465 trace_mode: Tensor tracer trace mode.
466 Raises:
467 ValueError: If the given trace mode is not supported for the device.
468 """
469 if trace_mode == tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY:
470 if device_type != _DEVICE_TYPE_TPU:
471 raise ValueError('Device_type "%s" is not yet supported for '
472 'trace mode "%s"' % (device_type, trace_mode))
474 @staticmethod
475 def loop_cond_op(op):
476 return op.type in ('LoopCond', 'RefLoopCond')
478 @staticmethod
479 def while_loop_op(op):
480 """Returns true if op is one of the special ops of in a while loop.
482 Args:
483 op: A tf.Operation.
485 Returns:
486 True if the given op is one of [Switch, Merge, Enter, Exit,
487 NextIteration, LoopCond], which are all building blocks for TF while
488 loops.
489 """
490 return (control_flow_util.IsLoopSwitch(op) or
491 control_flow_util.IsLoopMerge(op) or
492 control_flow_util.IsLoopEnter(op) or
493 control_flow_util.IsLoopExit(op) or
494 TensorTracer.loop_cond_op(op) or
495 op.type in ('RefNextIteration', 'NextIteration'))
497 @staticmethod
498 def control_flow_op(op):
499 """Returns true if op is one of the special ops of in a while loop.
501 Args:
502 op: A tf.Operation.
504 Returns:
505 True if the given op is one of [Switch, Merge, Enter, Exit,
506 NextIteration, LoopCond], which are all building blocks for TF while
507 loops.
508 """
509 return (control_flow_util.IsSwitch(op) or
510 control_flow_util.IsMerge(op))
512 @staticmethod
513 def unsafe_op(op):
514 """Returns True if this op is not safe to be traced."""
516 # Reasons for not including following op types:
517 # Assign: cause incorrect result with CPU tracing.
518 if op.type == 'Assign':
519 return True
520 return False
522 @staticmethod
523 def device_mismatch(device_type, op):
524 if device_type == _DEVICE_TYPE_TPU:
525 # pylint: disable=protected-access
526 return tpu_replication._TPU_REPLICATE_ATTR not in op.node_def.attr
527 # pylint: enable=protected-access
528 return False
530 @staticmethod
531 def unsafe_scalar_trace(op):
532 """Return true if scalar output tensor from Op is not safe to be traced."""
534 # Tracing the following causes cycle in the graph on TPU.
535 if op.type in ('LoopCond', 'Enter', 'Merge', 'Const',
536 'Switch', 'Less', 'ReadVariableOp'):
537 return True
538 # Tracing the following will cause casting-issue
539 # with the norm tracing mode or other compilation issues on CPU.
540 if op.type in ('VarHandleOp', 'IteratorToStringHandle',
541 'IteratorGetNext', 'OneShotIterator',
542 'IteratorV2', 'MakeIterator',
543 'BatchDatasetV2', 'MapDataset',
544 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset',
545 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'):
546 return True
547 return False
  def _is_interesting_op(self, op):
    """Returns True if the op should be traced at the configured trace level.

    Note: despite the historical docstring, True means the op IS interesting:
    its priority (see `op_priority`) does not exceed the trace level.
    """
    return op_priority(op.type) <= self._parameters.trace_level
553 @staticmethod
554 def reason(op_idx, details):
555 """Returns reason why the Op at op_idx is traced or not."""
557 return '%d %s'%(op_idx, details)
  def __init__(self):
    """Initializes a TensorTracer.

    Sets the various member fields from the flags (if given) or the defaults.
    """
    # Replica id; set by _add_replica_id_to_graph() (tensor on TPU,
    # 'unknown' otherwise).
    self._replica_id = None
    self._tt_config = tensor_tracer_report.TensorTracerConfig()
    # All user-tunable parameters, parsed from the flags env variable.
    self._parameters = tensor_tracer_flags.TTParameters()
    # Host-call functions, keyed per graph/call site — populated later.
    self._host_call_fn = {}
    # _cache_variables is a dict (key = graph, value = dicts
    # (key = name, value = tensors))
    self._cache_variables = {}
    # Same layout as _cache_variables, but for historic tensor value caches.
    self._history_value_cache = {}

    # Names of ops that have been traced so far.
    self._traced_op_names = set()
    # tensor_tracer.proto report; only produced by summary trace modes.
    self._report_proto = None
    # _temp_cache_var is a dict (key = graph, value = [])
    self._temp_cache_var = {}
    # Path where the report proto is written (empty until then).
    self._report_proto_path = ''
    self._outmost_context = None
580 def report_proto(self):
581 """Getter for tensor_tracer.proto object for summary and full_tensor_summary modes.
583 Returns:
584 A tensor_tracer.proto object.
585 Raises:
586 ValueError if called before tracing happens, or when trace mode is not
587 summary or full_tensor_summary.
588 """
589 if self._report_proto:
590 return self._report_proto
591 else:
592 raise ValueError('Call to report_proto must be done after tracing.'
593 'Report proto only exists for '
594 'trace_mode=[summary|full_tensor_summary]')
  def report_proto_path(self):
    """Getter for path where tensor_tracer.proto object should be written.

    Returns:
      A string path (empty string if the report has not been written yet).
    """
    return self._report_proto_path
604 def _escape_namescopes(self, variable_name):
605 return variable_name.replace('/', '_').replace(':', '_')
607 def _cache_variable_for_graph(self, graph):
608 if graph not in self._cache_variables:
609 self._cache_variables[graph] = {}
610 return self._cache_variables[graph]
  def _create_or_get_tensor_history_values_cache(self,
                                                 cache_name,
                                                 graph,
                                                 shape=None,
                                                 dtype=dtypes.float32):
    """Creates a variable as the cache to store historic intermediate tensor values.

    Args:
      cache_name: Name to be given to the cache (an instance of tf.variable).
      graph: Tensorflow graph.
      shape: A list of dimensions.
      dtype: Data type of created cache.
    Returns:
      A ref to newly created or existing cache with the given dimensions.
    Raises:
      ValueError:
        (1) If graph is None, or
        (2) shape is None when a new cache needs to be created.
    """
    if graph is None:
      raise ValueError('Invalid graph.')

    if graph not in self._history_value_cache:
      self._history_value_cache[graph] = {}

    if cache_name not in self._history_value_cache[graph]:
      if shape is None:
        raise ValueError('shape must be provided at cache creation.')
      # Entries start at _COMPACT_TRACE_ENTRY_INIT_VALUE (-1.0), cast to -1
      # for integer caches.
      if dtype.is_integer:
        init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE)
      else:
        init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE

      # Create in proper graph and base name_scope.
      with graph.as_default() as g, g.name_scope(None):
        self._history_value_cache[graph][
            cache_name] = variable_scope.get_variable(
                'tt_history' + '_' + self._escape_namescopes(cache_name),
                shape=shape,
                dtype=dtype,
                initializer=init_ops.constant_initializer(init_val),
                trainable=False,
                use_resource=True,
                # Registered as a local variable plus tensor tracer's own
                # storage collection so it can be found/initialized later.
                collections=[
                    _TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES
                ])

    return self._history_value_cache[graph][cache_name]
  def _create_or_get_tensor_values_cache(self, cache_name, graph,
                                         shape=None, dtype=dtypes.float32):
    """Creates a variable as the cache to store intermediate tensor values.

    Args:
      cache_name: Name to be given to the cache (an instance of tf.variable).
      graph: Tensorflow graph.
      shape: A list of dimensions.
      dtype: Data type of created cache.
    Returns:
      A ref to newly created or existing cache with the given dimensions.
    Raises:
      ValueError:
        (1) If graph is None, or
        (2) shape is None when a new cache needs to be created.
    """
    if graph is None:
      raise ValueError('Invalid graph.')

    graph_cache_var = self._cache_variable_for_graph(graph)

    if cache_name not in graph_cache_var:
      if shape is None:
        raise ValueError('shape must be provided at cache creation.')
      # Entries start at _COMPACT_TRACE_ENTRY_INIT_VALUE (-1.0), cast to -1
      # for integer caches.
      if dtype.is_integer:
        init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE)
      else:
        init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE

      # Create in proper graph and base name_scope.
      with graph.as_default() as g, g.name_scope(None):
        graph_cache_var[cache_name] = variable_scope.get_variable(
            _TT_SNAPSHOT + '_' + self._escape_namescopes(cache_name),
            shape=shape, dtype=dtype,
            initializer=init_ops.constant_initializer(init_val),
            trainable=False,
            use_resource=True,
            # Registered as a local variable plus tensor tracer's own storage
            # collection so it can be found/initialized later.
            collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
    return graph_cache_var[cache_name]
  def _add_replica_id_to_graph(self):
    """Adds nodes for computing the replica ID to the graph."""
    if self._tt_config.num_replicas:
      with ops.control_dependencies(None):
        # Uses None as dependency to run outside of TPU graph rewrites.
        self._replica_id = tpu_ops.tpu_replicated_input(
            list(range(self._tt_config.num_replicas)),
            name='tt_replica_id')
    else:
      # num_replicas unset/zero: the replica id cannot be computed.
      self._replica_id = 'unknown'
713 def _inside_op_range(self, idx):
714 """Return True if the given index is inside the selected range."""
716 if idx < self._parameters.op_range[0]:
717 return False
718 return (self._parameters.op_range[1] < 0 or
719 idx <= self._parameters.op_range[1])
721 def _is_user_included_op(self, op):
722 """Checks whether the op is included in the tensor tracer flags.
724 Args:
725 op: tf Operation
726 Returns:
727 True, if the op is included.
728 An op is included if:
729 - Its op name is given in included_opnames
730 - Its op type is given in included_optypes
731 - The op is at most _trace_ops_before_included hops before an included op
732 - The op is at most _trace_ops_after_included hops after an included op
733 """
734 for opname_re in self._parameters.included_opname_re_list:
735 if opname_re.match(op.name):
736 return True
738 for optype_re in self._parameters.included_optype_re_list:
739 if optype_re.match(op.type):
740 return True
741 return False
743 def _is_user_excluded_op(self, op):
744 for opname_re in self._parameters.excluded_opname_re_list:
745 if opname_re.match(op.name):
746 return True
747 for optype_re in self._parameters.excluded_optype_re_list:
748 if optype_re.match(op.type):
749 return True
750 return False
  def _signature_types(self):
    """Returns a dictionary holding the order of signatures in the cache for the selected trace mode."""
    # Single-signature modes: the lone signature lives at cache index 0.
    if self._parameters.trace_mode in set([
        tensor_tracer_flags.TRACE_MODE_NAN_INF,
        tensor_tracer_flags.TRACE_MODE_NORM,
        tensor_tracer_flags.TRACE_MODE_HISTORY,
        tensor_tracer_flags.TRACE_MODE_MAX_ABS]):
      return {self._parameters.trace_mode: 0}
    # Summary mode stores multiple signatures; their order comes from flags.
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
      return self._parameters.summary_signatures
    # Other trace modes keep no signature cache.
    return {}
764 def _num_signature_dimensions(self):
765 return len(self._signature_types())
767 def _use_temp_cache(self):
768 """Returns true if the intermediate values should be stacked instead of being stored in a tf.Variable.
770 Returns:
771 A boolean, denoting whether to use a temporary cache or not.
772 """
773 # If full tensors need to be stored tf.variables, then do not use temp
774 # variables to store them.
775 if self._use_tensor_buffer():
776 return False
777 if self._use_tensor_values_cache():
778 return self._parameters.use_temp_cache_var
779 else:
780 # Temporary caches only replaces tf.Variables caches. If no cache is used
781 # return False.
782 return False
  def _use_tensor_values_cache(self):
    """Returns True if immediate tensors should be first saved to a cache."""
    # Controlled by the --use_compact_trace flag.
    return self._parameters.use_compact_trace
788 def _use_tensor_buffer(self):
789 """Returns true if the whole tensor needs to be cached/buffered in memory."""
790 return (self._parameters.trace_mode ==
791 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
  def _merge_tensor_signatures(self, signatures):
    """Returns a tensor that merges the given signatures.

    Args:
      signatures: A dictionary of the signature updates from signature name to
        a tensor of dimension [1].
    Returns:
      A tensor that concats the signature values in a predefined order.
    Raises:
      ValueError: Unable to merge signatures.
    """
    sorted_update = []
    if self._num_signature_dimensions() > 1:
      # Order the updates by each signature's assigned cache index so the
      # stacked tensor layout matches the cache layout.
      signature_indices = self._signature_types()
      for _, val in sorted(signatures.items(),
                           key=lambda item: signature_indices[item[0]]):
        sorted_update.append(val)
      updates = array_ops_stack.stack(
          sorted_update, axis=0, name='merge_single_op_signatures')
    elif self._num_signature_dimensions() == 1:
      # Avoid stack operation if there is only a single signature.
      (_, val), = signatures.items()
      updates = val
    else:
      raise ValueError('Cannot merge 0 signatures. Check the value passed for '
                       'flag --signatures.')
    return updates
821 def _save_tensor_value_to_tmp_cache(self, cache_idx, updates, graph):
822 """Returns an op that will save the given updates to an entry in the cache.
824 Args:
825 cache_idx: The cache index of the tensor within the cache.
826 updates: A dictionary of the signature updates from signature name to
827 a tensor of dimension [1].
828 graph: A TensorFlow graph.
829 Raises:
830 RuntimeError:
831 (1) graph is not already in self._temp_cache_var, or
832 (2) cache_idx is out of range.
833 """
834 updates = self._merge_tensor_signatures(updates)
835 updates = array_ops.reshape(updates,
836 [self._num_signature_dimensions()])
837 if graph not in self._temp_cache_var:
838 raise RuntimeError('graph is not in self._temp_cache_var')
839 if cache_idx >= len(self._temp_cache_var[graph]):
840 raise RuntimeError('cache_idx (%d) is out of range (%d)' % (
841 cache_idx, len(self._temp_cache_var[graph])))
842 self._temp_cache_var[graph][cache_idx] = updates
844 def _save_tensor_value_to_cache_op(self, cache_idx, updates, graph):
845 """Returns an op that will save the given updates to an entry in the cache.
847 Args:
848 cache_idx: The cache index of the tensor within the cache.
849 updates: A dictionary of the signature updates.
850 graph: A TensorFlow graph.
851 Returns:
852 Cache update operation.
853 """
854 # state_ops.scatter_update allows updates only along the first dimension.
855 # Make a compact array by concatenating different signatures, and update
856 # them all together.
857 updates = self._merge_tensor_signatures(updates)
858 updates = array_ops.reshape(updates,
859 [1, self._num_signature_dimensions()])
860 indices = constant_op.constant([cache_idx])
861 cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph)
862 return state_ops.scatter_update(cache, indices, updates).op
864 def _snapshot_tensor(self, tensor):
865 """Creates a new tf.Variable and a new tf.Operation that assigns the value of the tensor to this variable.
867 Args:
868 tensor: tensor whose values will be stored in a new tf.Variable.
869 Returns:
870 An assignment operation.
871 """
873 snapshot_variable = self._create_or_get_tensor_values_cache(
874 tensor.name, tensor.op.graph,
875 tensor.shape.as_list(), tensor.dtype)
876 return state_ops.assign(snapshot_variable, tensor).op
  def _preprocess_traced_tensor(self, tensor):
    """Computes NAN/Norm/Max on TPUs before sending to CPU.

    Args:
      tensor: The tensor to be traced.
    Returns:
      A tensor that should be input to the trace_function.
    Raises:
      RuntimeError: If the signature is invalid.
    """

    def _detect_nan_inf(tensor):
      """Trace function for detecting any NaN/Inf in the tensor."""

      if tensor.dtype.is_floating:
        mask = math_ops.reduce_any(
            gen_math_ops.logical_or(
                gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor)))
        # Emit [1.0] if any element is NaN/Inf, else [0.0].
        output_tensor = cond.cond(
            mask,
            lambda: constant_op.constant([1.0]),
            lambda: constant_op.constant([0.0]))
      else:
        # Non-floating dtypes cannot hold NaN/Inf.
        output_tensor = constant_op.constant([0.0])
      return output_tensor

    def _compute_signature(tensor, tf_op, cast_to_f32=True):
      # Shared helper: optionally cast to float32, apply the reduction op,
      # and force a statically-scalar result shape.
      if cast_to_f32:
        tensor = math_ops.cast(tensor, dtypes.float32)
      output_tensor = tf_op(tensor)
      # Return type should be scalar. Set it if it does not have the
      # information.
      if not output_tensor.get_shape().is_fully_defined():
        output_tensor = array_ops.reshape(output_tensor, [])
      return output_tensor

    def _show_size(tensor):
      # In order to check the size of a tensor.
      # Not all sizes are known at the compile time, also, different replicas
      # sometimes get different sizes of tensors.
      # Collect it here to be used in merging replica data.
      tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False)
      # Cast to float32, so that it can be placed into same cache with other
      # signatures.
      return math_ops.cast(tsize, dtypes.float32)

    def _show_max(tensor, cast_to_f32=True):
      # returns -inf for empty tensor
      return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32)

    def _show_min(tensor, cast_to_f32=True):
      # returns inf for empty tensor
      return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32)

    def _show_norm(tensor, cast_to_f32=True):
      # returns 0 for empty tensor
      return _compute_signature(tensor, linalg_ops.norm, cast_to_f32)

    def _show_sparsity(tensor, cast_to_f32=True, tolerance=1e-06):
      # returns nan for empty tensor and treats nans as non-zero numbers
      def sparsity_fn(tensor):
        # Values within `tolerance` of zero count as zero; NaNs count as
        # non-zero so they are not hidden inside the sparsity figure.
        non_zeros = math_ops.greater_equal(math_ops.abs(tensor), tolerance)
        nans = math_ops.is_nan(tensor)
        return nn_impl.zero_fraction(math_ops.logical_or(non_zeros, nans))

      return _compute_signature(tensor, sparsity_fn, cast_to_f32)

    def _show_mean_and_variance(tensor, cast_to_f32=True):
      """Returns the mean and variance of the given tensor."""
      if cast_to_f32:
        tensor = math_ops.cast(tensor, dtypes.float32)
      # returns nan for empty tensor
      mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0])
      # The shape has to be 1. Set it if it does not have the information.
      if not mean.get_shape().is_fully_defined():
        mean = array_ops.reshape(mean, [])
      if not var.get_shape().is_fully_defined():
        var = array_ops.reshape(var, [])
      return mean, var

    def _show_max_abs(tensor, cast_to_f32=True):
      return _compute_signature(
          tensor, lambda t: math_ops.reduce_max(math_ops.abs(t)), cast_to_f32)

    # Dispatch on trace mode: each branch returns a dict mapping signature
    # name(s) to the tensor(s) to hand to the trace function.
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
      return {self._parameters.trace_mode: _detect_nan_inf(tensor)}
    if (self._parameters.trace_mode ==
        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
      return {self._parameters.trace_mode: tensor}
    if (self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)):
      return {self._parameters.trace_mode: tensor}
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM:
      return {self._parameters.trace_mode: array_ops.reshape(
          _show_norm(tensor), [1])}
    # HISTORY mode traces the same norm signature as NORM mode.
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_HISTORY:
      return {self._parameters.trace_mode: array_ops.reshape(
          _show_norm(tensor), [1])}
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS:
      return {self._parameters.trace_mode: _show_max_abs(tensor)}

    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
      tensor = math_ops.cast(tensor, dtypes.float32)
      result_dict = {}
      # Call mean and variance computation here to avoid adding the same nodes
      # twice.
      if (_TT_SUMMARY_MEAN in self._signature_types() or
          _TT_SUMMARY_VAR in self._signature_types()):
        mean, variance = _show_mean_and_variance(tensor, cast_to_f32=False)

      # Iterate signatures in their configured index order so the resulting
      # dict matches the cache column layout.
      for signature_name, _ in sorted(self._signature_types().items(),
                                      key=lambda x: x[1]):
        if signature_name == _TT_SUMMARY_NORM:
          signature_result_tensor = _show_norm(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_MAX:
          signature_result_tensor = _show_max(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_MAX_ABS:
          signature_result_tensor = _show_max_abs(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_MIN:
          signature_result_tensor = _show_min(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_SPARSITY:
          signature_result_tensor = _show_sparsity(tensor)
        elif signature_name == _TT_SUMMARY_SIZE:
          signature_result_tensor = _show_size(tensor)
        elif signature_name == _TT_SUMMARY_MEAN:
          signature_result_tensor = mean
        elif signature_name == _TT_SUMMARY_VAR:
          signature_result_tensor = variance
        else:
          raise ValueError('Unknown signature type :%s.' % signature_name)

        result_dict[signature_name] = signature_result_tensor
      return result_dict

    raise RuntimeError(
        'Unsupported signature for trace mode %s.'
        % self._parameters.trace_mode)
  def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order):
    """Makes the tensor tracing function called by outside compilation.

    Args:
      tensor_name: name of the tensor being traced.
      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
    Returns:
      A function to be passed as the first argument to outside compilation.

    Raises:
      RuntimeError: If the trace mode is invalid.
    """

    def _print_tensor(tensor_name, num_elements, tensor, output_tensor):
      """Prints a tensor value to a file.

      Args:
        tensor_name: name of the tensor being traced.
        num_elements: number of elements to print (-1 means print all).
        tensor: the tensor needs to be returned.
        output_tensor: the tensor needs to be printed.

      Returns:
        The same tensor passed via the "tensor" argument.

      Raises:
        ValueError: If tensor_name is not already in
          tensor_trace_order.tensorname_to_cache_idx.
      """

      if self._parameters.is_brief_mode():
        # Brief mode prints only the numeric cache index in place of the
        # (potentially long) tensor name.
        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
          raise ValueError(
              'Tensor %s with name %s is not in the tensorname_to_cache_idx' %
              (tensor, tensor_name))
        msg = '%d' % tensor_trace_order.tensorname_to_cache_idx[tensor_name]
      else:
        msg = '"%s"' % tensor_name

      if self._parameters.trace_dir:
        output_path = os.path.join(
            self._parameters.trace_dir,
            _TRACE_FILE_NAME + self._get_outfile_suffix())
        output_stream = _OUTPUT_STREAM_ESCAPE + output_path
      else:
        # Without a trace directory, trace output goes to stderr.
        output_stream = sys.stderr
      return logging_ops.print_v2(msg, array_ops.shape(output_tensor),
                                  '@', self._replica_id,
                                  '\n', output_tensor, '\n',
                                  summarize=num_elements,
                                  output_stream=output_stream)

    def _show_part_tensor(tensor):
      """Trace function for printing part of the tensor."""

      return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE,
                           tensor, tensor)

    def _show_full_tensor(tensor):
      """Trace function for printing the entire tensor."""

      return _print_tensor(tensor_name, -1, tensor, tensor)

    if (self._parameters.trace_mode ==
        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
      return _show_part_tensor
    # The input tensor has a shape of "[1]" for TRACE_MODE_NAN_INF,
    # TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as related computations are
    # performed within TPUs and only their results are transferred to CPU.
    # Simply, print the full tensor for these trace modes.
    if self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_NAN_INF,
        tensor_tracer_flags.TRACE_MODE_NORM,
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
        tensor_tracer_flags.TRACE_MODE_MAX_ABS,
        tensor_tracer_flags.TRACE_MODE_SUMMARY,
        tensor_tracer_flags.TRACE_MODE_HISTORY
        ):
      return _show_full_tensor

    raise RuntimeError('Full tensor support is not available with trace mode %s'
                       %self._parameters.trace_mode)
  def _is_in_control_flow(self, op):
    """Returns true if the given op is inside a tf.cond branch.

    NOTE(review): the original summary mentioned tf.while_loop as well, but
    the implementation only checks for a cond context via
    control_flow_util.IsInCond; while-loop placement is checked separately by
    _is_in_outmost_while_loop.

    Args:
      op: A tensorflow op that should be checked whether in control flow or not.
    Returns:
      A boolean value whether the op is in control flow or not.
    """
    return control_flow_util.IsInCond(op)
1110 def _is_in_outmost_while_loop(self, op):
1111 """Returns true if the op is at the same level with the training loop.
1113 Returns false if the op is in an inner while loop or if it is outside of the
1114 training loop.
1115 Args:
1116 op: tf.Operation
1118 Returns:
1119 A boolean.
1120 """
1121 ctxt = self._get_op_control_flow_context(op)
1122 outer_while_context = control_flow_util.GetContainingWhileContext(ctxt)
1123 return outer_while_context == control_flow_util.GetContainingWhileContext(
1124 self._outmost_context)
1126 def _should_trace_in_control_flow(self):
1127 """Returns false incase it is not safe to trace ops in tf.cond or tf.while_loop."""
1128 # As different from the other trace modes, TRACE_MODE_OPTIONAL_SUMMARY
1129 # forces the execution of the traced tensors. We should not trace the ops
1130 # that may not be executed due to control flow.
1131 if self._use_temp_cache():
1132 return False
1133 elif self._tt_config.device_type == _DEVICE_TYPE_TPU:
1134 # On TPUs do not trace in control flow unless we use caches to store
1135 # intermediate values as calling outside compilation within an inner loop
1136 # causes errors.
1137 return self._use_tensor_values_cache() or self._use_tensor_buffer()
1138 return True
  def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
    """Returns True if we should not trace Op.

    Filters are applied in a fixed priority order: structural exclusions
    first (while-loop/control-flow/unsafe ops, device mismatch, ops outside
    the execution path, unsafe control-flow placement), then explicit user
    inclusion (which overrides the range/interest filters), then op-range and
    interest filters, then explicit user exclusion. Every decision is
    recorded via report_handler.instrument_op.

    Args:
      op_id: Topological index of the op.
      op: tf.Operation
      ops_in_exec_path: Set of operations that are in the execution path.
      report_handler: An instance of tensor_tracer_report.TTReportHandle.
    Returns:
      True if the op should not be traced, false otherwise.
    """
    if TensorTracer.while_loop_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP))
      return True
    if TensorTracer.control_flow_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_CONTROLFLOW_OP))
      return True
    if TensorTracer.unsafe_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP))
      return True
    if TensorTracer.device_mismatch(self._tt_config.device_type, op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH))
      return True
    if op not in ops_in_exec_path:
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED))
      return True
    # TensorTracer will not trace the operations that are in an inner while loop
    # or tf.cond when a temporary cache is used. Temporary cache adds direct
    # data dependencies to traced operations, and needs a static number of
    # traced operations. For these cases,
    # - We do not know the number of slots required when there are inner while
    #   loops. TensorTracer can only trace the result of a while loop.
    # - We do not know ahead of time which branch of the tf.cond
    #   will be taken, so we avoid introducing data dependencies for the
    #   operations inside a tf.cond.
    # - We also cannot have a data dependency to an operation in a different
    #   while context.
    if self._is_in_control_flow(op) or not self._is_in_outmost_while_loop(op):
      if not self._should_trace_in_control_flow():
        report_handler.instrument_op(
            op, TensorTracer.reason(op_id, _REASON_IN_CONTROL_FLOW))
        return True
    # User inclusion wins over the op-range / interest filters below, but not
    # over the structural exclusions above.
    if self._is_user_included_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
      if tensor_tracer_flags.TT_CHECK_FILTER.value:
        logging.info('USER_INCLUDED op %s', op.name)
      return False

    if not self._inside_op_range(op_id):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE))
      return True
    if not self._is_interesting_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP))
      return True
    if self._is_user_excluded_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
      if tensor_tracer_flags.TT_CHECK_FILTER.value:
        logging.info('USER_EXCLUDED op %s', op.name)
      return True
    return False
  def _skip_tensor(self, op_id, out_tensor, report_handler):
    """Returns True if we should not trace out_tensor.

    Every decision (skip or trace) is recorded with its reason via
    report_handler.instrument_tensor.

    Args:
      op_id: Topological index of the op producing tensor.
      out_tensor: tf.Tensor
      report_handler: An instance of tensor_tracer_report.TTReportHandle.
    Returns:
      True if the tensor should not be traced, false otherwise.
    """

    # Skips a tensor if the tensor has a non-numeric type.
    # Note: we cannot use check_ops.is_numeric_tensor(out_tensor)
    # because it also excludes tensors with dtypes, bool, and
    # float32_ref, which we actually want to trace.
    non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
                                    dtypes.string])
    if out_tensor.dtype in non_numeric_tensor_types:

      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR))
      return True
    # Skip a tensor if it feeds a special while loop op.
    if [consumer for consumer in out_tensor.consumers() if
        TensorTracer.while_loop_op(consumer)]:
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP))
      return True
    # Explicit user inclusion overrides the shape / scalar filters below.
    if self._is_user_included_op(out_tensor.op):
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
      if tensor_tracer_flags.TT_CHECK_FILTER.value:
        logging.info('USER_INCLUDED tensor %s', out_tensor.name)
      return False
    if self._is_user_excluded_op(out_tensor.op):
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
      if tensor_tracer_flags.TT_CHECK_FILTER.value:
        logging.info('USER_EXCLUDED tensor %s', out_tensor.name)
      return True
    if not out_tensor.get_shape().is_fully_defined():
      # If trace mode is nan-inf, norm or max, then the tensor will be reduced
      # to a scalar before the outside compilation call.
      if self._parameters.trace_mode in (
          tensor_tracer_flags.TRACE_MODE_NAN_INF,
          tensor_tracer_flags.TRACE_MODE_NORM,
          tensor_tracer_flags.TRACE_MODE_HISTORY,
          tensor_tracer_flags.TRACE_MODE_MAX_ABS,
          tensor_tracer_flags.TRACE_MODE_SUMMARY
          ):
        report_handler.instrument_tensor(
            out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
        return False
      else:
        # Dynamic shapes cannot be traced for full-tensor style modes.
        report_handler.instrument_tensor(
            out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE))
        return True
    rank = len(out_tensor.shape)
    if rank < 1:
      # scalar
      if self._parameters.trace_scalar_ops:
        if TensorTracer.unsafe_scalar_trace(out_tensor.op):
          report_handler.instrument_tensor(
              out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR))
          return True
        else:
          report_handler.instrument_tensor(
              out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED))
          return False
      else:
        report_handler.instrument_tensor(
            out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR))
        return True
    else:
      # tensor
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
      return False
1289 def _filter_execution_path_operations(self, operations, fetches):
1290 """Returns the set of ops in the execution path to compute given fetches."""
1292 # If no fetch provided, then return all operations.
1293 if fetches is None:
1294 return set(operations)
1295 # Convert to list, if a single element is provided.
1296 if not isinstance(fetches, (list, tuple)):
1297 fetches = [fetches]
1298 # If a tensor is given as fetch, convert it to op.
1299 op_fetches = []
1300 for fetch in fetches:
1301 if isinstance(fetch, ops.Operation):
1302 op_fetches.append(fetch)
1303 elif isinstance(fetch, ops.Tensor):
1304 op_fetches.append(fetch.op)
1305 else:
1306 raise RuntimeError('Given fetch:%s is neither a tensor nor an op.'
1307 %fetch)
1309 execution_path_operations = set(op_fetches)
1310 traverse_stack = list(op_fetches)
1311 while True:
1312 if not traverse_stack:
1313 break
1314 head_op = traverse_stack.pop()
1315 input_ops = [tensor_input.op for tensor_input in head_op.inputs]
1316 input_ops.extend(head_op.control_inputs)
1318 for input_op in input_ops:
1319 if input_op not in execution_path_operations:
1320 # Filter out loop condition operations, tracing them causes a cycle.
1321 # Trace only the loop-body.
1322 if TensorTracer.loop_cond_op(input_op):
1323 continue
1324 execution_path_operations.add(input_op)
1325 traverse_stack.append(input_op)
1326 return execution_path_operations
1328 def _determine_and_instrument_traced_tensors(self, graph_order,
1329 ops_in_exec_path,
1330 tensor_trace_points,
1331 report_handler):
1332 """Determines the tensors to trace and instruments the trace details.
1334 Args:
1335 graph_order: graph_order tuple containing graph (tf.graph), operations
1336 (list of operations), op_to_idx (op id mapping), (tensors) list of
1337 tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether
1338 there is a cycle in the graph), topological_order_or_cycle (list of ops
1339 in topological order or list of ops creating a cycle).
1340 ops_in_exec_path: Set of ops in the execution path.
1341 tensor_trace_points: Collection of programatic tensor trace points.
1342 report_handler: An instance of tensor_tracer_report.TTReportHandle.
1343 Returns:
1344 List of tensors to be traced.
1345 """
1347 traced_tensors = []
1348 checkpoint_operations = set([tensor.op
1349 for (tensor, _) in tensor_trace_points])
1350 for op_id, op in enumerate(graph_order.operations):
1351 if checkpoint_operations and op not in checkpoint_operations:
1352 continue
1353 if self._skip_op(op_id, op, ops_in_exec_path, report_handler):
1354 continue
1355 for i in range(len(op.outputs)):
1356 out_tensor = op.outputs[i]
1357 if not self._skip_tensor(op_id, out_tensor, report_handler):
1358 traced_tensors.append(out_tensor)
1359 return traced_tensors
1361 def _check_trace_files(self):
1362 """Checks if any requirements for trace files are satisfied."""
1364 if not self._parameters.trace_dir:
1365 # traces will be written to stderr. No need to check trace files.
1366 return
1367 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
1368 # Output files are handled by tf.summary operations, no need to precreate
1369 # them.
1370 return
1371 if not gfile.Exists(self._parameters.trace_dir):
1372 file_io.recursive_create_dir(self._parameters.trace_dir)
1373 if not gfile.Exists(self._parameters.trace_dir):
1374 raise RuntimeError('Failed to create trace directory at %s' %
1375 self._parameters.trace_dir)
1377 def _create_temp_cache(self, num_traced_tensors, num_signatures, graph):
1378 """Creates a temporary cache with the given dimensions.
1380 Fills the self._temp_cache_var with num_traced_tensors tf.constant() ops
1381 that have shape of [num_signatures].
1382 Args:
1383 num_traced_tensors: Int, denoting total number of traced tensors.
1384 num_signatures: Int, denoting the number of statistics collected per
1385 tensors.
1386 graph: TensorFlow graph.
1387 """
1388 init_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
1389 dtype=dtypes.float32,
1390 shape=[num_signatures])
1391 self._temp_cache_var[graph] = [
1392 init_value for _ in range(num_traced_tensors)]
1394 def _determine_trace_and_create_report(self, graph, ops_in_exec_path,
1395 graph_summary_tag):
1396 """Work needs to be done prior to TPU or CPU tracing.
1398 Args:
1399 graph: tf.graph
1400 ops_in_exec_path: Set of operations in the execution path.
1401 graph_summary_tag: the summary tag name for the given graph.
1402 Returns:
1403 An instance of tensor_tracer_report.TensorTraceOrder, containing list of
1404 tensors to be traced with their topological order information.
1405 Raises:
1406 RuntimeError: If opname filtering is incorrectly set.
1407 """
1409 self._check_trace_files()
1411 graph_order = tensor_tracer_report.sort_tensors_and_ops(graph)
1412 tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION)
1414 report_handler = tensor_tracer_report.TTReportHandle()
1415 traced_tensors = self._determine_and_instrument_traced_tensors(
1416 graph_order, ops_in_exec_path, tensor_trace_points, report_handler)
1417 logging.info('TensorTracer is tracing %d tensors.', len(traced_tensors))
1418 if traced_tensors and tensor_tracer_flags.TT_CHECK_FILTER.value:
1419 raise RuntimeError('Verify ops being traced by tensor tracer.')
1421 tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
1422 traced_tensors)
1423 num_signatures = self._num_signature_dimensions()
1424 # Create a cache variable if compact_tracing is used.
1425 if num_signatures and self._use_tensor_values_cache():
1426 if self._use_temp_cache():
1427 self._create_temp_cache(len(traced_tensors), num_signatures, graph)
1428 else:
1429 self._create_or_get_tensor_values_cache(
1430 _TT_SUMMARY_TAG, graph, [len(traced_tensors), num_signatures])
1431 if self._parameters.trace_mode in (
1432 tensor_tracer_flags.TRACE_MODE_HISTORY):
1433 self._create_or_get_tensor_history_values_cache(
1434 _TT_SUMMARY_TAG, graph, [len(traced_tensors), num_signatures])
1435 if self._parameters.trace_mode in (
1436 tensor_tracer_flags.TRACE_MODE_SUMMARY,
1437 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY):
1438 self._report_proto = report_handler.create_report_proto(
1439 self._tt_config, self._parameters, tensor_trace_order,
1440 tensor_trace_points, self._signature_types())
1441 if self._parameters.use_fingerprint_subdir:
1442 self._parameters.trace_dir = os.path.join(
1443 self._parameters.trace_dir, self._report_proto.fingerprint)
1444 logging.info('TensorTracer updating trace_dir to %s',
1445 self._parameters.trace_dir)
1446 self._report_proto_path = report_handler.report_proto_path(
1447 self._parameters.trace_dir, graph_summary_tag)
1449 if self._parameters.report_file_path != _SKIP_REPORT_FILE:
1450 report_handler.write_report_proto(self._report_proto_path,
1451 self._report_proto, self._parameters)
1452 else:
1453 if self._parameters.trace_mode not in (
1454 tensor_tracer_flags.TRACE_MODE_HISTORY):
1455 report_handler.create_report(self._tt_config, self._parameters,
1456 tensor_trace_order, tensor_trace_points)
1457 return tensor_trace_order
1459 def _create_host_call(self):
1460 return self._parameters.trace_mode in (
1461 tensor_tracer_flags.TRACE_MODE_SUMMARY,
1462 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
1464 def _inspect_summary_cache(self, cache, replica_id, step_num, output_stream,
1465 tensor_trace_order):
1466 """Generates a print operation to print trace inspection.
1468 Args:
1469 cache: Tensor storing the trace results for the step.
1470 replica_id: Tensor storing the replica id of the running core.
1471 step_num: Step number.
1472 output_stream: Where to print the outputs, e.g., file path, or sys.stderr.
1473 tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
1475 Returns:
1476 The Op to flush the cache to file.
1477 """
1478 def _inspect_tensor(tensor):
1479 """Returns the text to be printed for inspection output."""
1480 if (self._parameters.trace_mode ==
1481 tensor_tracer_flags.TRACE_MODE_NAN_INF):
1482 return cond.cond(
1483 math_ops.greater(tensor, 0.0),
1484 lambda: 'has NaNs/Infs!',
1485 lambda: 'has no NaNs or Infs.')
1486 else:
1487 return tensor
1489 # Check if there are graph operations being profiled.
1490 if not tensor_trace_order.traced_tensors:
1491 logging.warn('Inspect mode has no tensors in the cache to check.')
1492 return control_flow_ops.no_op
1494 # Check if the cache includes any nan or inf
1495 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
1496 # Cache has 1s or 0s if the mode is NaN_INF
1497 step_has_nan_or_inf = math_ops.greater(math_ops.reduce_sum(cache), 0.0)
1498 else:
1499 # Cache has the actual numerics for other modes.
1500 step_has_nan_or_inf = math_ops.reduce_any(
1501 gen_math_ops.logical_or(
1502 gen_math_ops.is_nan(cache), gen_math_ops.is_inf(cache)))
1504 # Summarizing message for each step.
1505 step_error_message = cond.cond(
1506 step_has_nan_or_inf,
1507 lambda: 'NaNs or Infs in the step!',
1508 lambda: 'No numerical issues have been found for the step.')
1510 # No need to print core numbers if the cache is merged already.
1511 if self._parameters.collect_summary_per_core:
1512 stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num, '-->',
1513 step_error_message,
1514 'Printing tensors for mode:%s...' % self._parameters.trace_mode]
1515 else:
1516 stats = ['\n\n', 'step:', step_num, '-->', step_error_message,
1517 'Printing tensors for mode:%s...' % self._parameters.trace_mode]
1519 for tensor_name, cache_idx in sorted(
1520 tensor_trace_order.tensorname_to_cache_idx.items(),
1521 key=lambda item: item[1]):
1522 if self._parameters.collect_summary_per_core:
1523 stats.extend([
1524 '\n', 'core:', replica_id, ',', 'step:', step_num, ',',
1525 tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
1526 else:
1527 stats.extend([
1528 '\n', 'step:', step_num, ',',
1529 tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
1530 return logging_ops.print_v2(*stats, summarize=-1,
1531 output_stream=output_stream)
1533 def _inspect_history_cache(self, cache, replica_id, step_num,
1534 tensor_trace_order):
1535 """Generates a conditional print operation to log differences in tensor values.
1537 Args:
1538 cache: Tensor storing the trace results for the step.
1539 replica_id: Tensor storing the replica id of the running core.
1540 step_num: Step number.
1541 tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
1543 Returns:
1544 The Op to flush the cache to file.
1545 """
1546 # Check if there are graph operations being profiled.
1547 if not tensor_trace_order.traced_tensors:
1548 logging.warn('TT history mode has no tensors in the cache to check.')
1549 return control_flow_ops.no_op
1551 stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num]
1552 diffs = []
1553 for tensor_name, cache_idx in sorted(
1554 tensor_trace_order.tensorname_to_cache_idx.items(),
1555 key=lambda item: item[1]):
1557 tensor_to_write = cache[cache_idx, 0]
1558 snapshot_variable = self._create_or_get_tensor_history_values_cache(
1559 tensor_to_write.name, tensor_to_write.op.graph,
1560 tensor_to_write.shape.as_list(), tensor_to_write.dtype)
1562 with ops.control_dependencies([snapshot_variable]):
1563 old_value = state_ops.assign_add(snapshot_variable, 0.0)
1565 with ops.control_dependencies([old_value]):
1566 new_value = math_ops.cast(tensor_to_write, dtypes.float32)
1567 delta = math_ops.abs(math_ops.subtract(old_value, new_value))
1568 updated = state_ops.assign(snapshot_variable, new_value)
1569 diffs.append(delta)
1570 with ops.control_dependencies([updated]):
1571 new_value_from_var = state_ops.assign_add(snapshot_variable, 0.0)
1573 stats.extend([
1574 '\n', 'core:', replica_id, ',', 'step:', step_num, ',',
1575 tensor_name, '-->', old_value, new_value_from_var, delta])
1577 diff_stack = array_ops_stack.stack(diffs)
1578 step_max = math_ops.reduce_max(diff_stack)
1580 return cond.cond(
1581 math_ops.greater(step_max, tensor_tracer_flags.DELTA_THRESHOLD.value),
1582 lambda: logging_ops.print_v2(*stats, summarize=-1),
1583 lambda: control_flow_ops.no_op()) # pylint: disable=unnecessary-lambda
1585 def _get_outfile_suffix(self):
1586 if remote_utils.is_remote_path(self._parameters.trace_dir):
1587 return remote_utils.get_appendable_file_encoding()
1588 else:
1589 return ''
def _generate_flush_cache_op(self, num_replicas, on_tpu,
                             tensor_trace_order, graph):
  """Generates an Op that will flush the cache to file.

  Args:
    num_replicas: total number of replicas.
    on_tpu: if the graph is executed on TPU.
    tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
    graph: TensorFlow graph.

  Returns:
    The Op to flush the cache to file.
  """

  def _flush_fun(cache, replica_id, step_num):
    """Flushes the cache to a file corresponding to replica_id."""

    def _f(file_index):
      """Generates a func that flushes the cache to a file."""
      def _print_cache():
        """Flushes the cache to a file."""
        replica_str = ('%d' % file_index)
        if self._parameters.trace_dir:
          # Write to <trace_dir>/<prefix><replica_id><suffix>; the escape
          # prefix tells print_v2 to treat the path as a file sink.
          output_path = (os.path.join(self._parameters.trace_dir,
                                      _COMPACT_TRACE_FILE_PREFIX)
                         + replica_str + self._get_outfile_suffix())
          output_stream = _OUTPUT_STREAM_ESCAPE + output_path
        else:
          # No trace_dir configured: dump the cache to stderr instead.
          output_stream = sys.stderr

        new_step_line = _REPLICA_ID_TAG + replica_str
        print_ops = []
        if self._parameters.inspect_trace:
          # Inspection post-processes the cache instead of dumping it raw;
          # it only supports a single signature per tensor.
          if self._num_signature_dimensions() > 1:
            raise ValueError('Inspecting multi signatures are not supported.')
          # NOTE(review): `in (x)` with a single string is a substring test,
          # not a one-element tuple membership test — this fires for any
          # trace_mode that is a substring of TRACE_MODE_HISTORY. Probably a
          # missing trailing comma; confirm intent before changing.
          if self._parameters.trace_mode in (
              tensor_tracer_flags.TRACE_MODE_HISTORY):
            print_ops.append(
                self._inspect_history_cache(
                    cache=cache,
                    replica_id=replica_id,
                    step_num=step_num,
                    tensor_trace_order=tensor_trace_order))
          else:
            print_ops.append(
                self._inspect_summary_cache(
                    cache=cache,
                    replica_id=replica_id,
                    step_num=step_num,
                    output_stream=output_stream,
                    tensor_trace_order=tensor_trace_order))
        else:
          # Raw dump: one print per signature column of the cache.
          for i in range(self._num_signature_dimensions()):
            print_ops.append(logging_ops.print_v2(
                new_step_line, '\n',
                cache[:, i], '\n',
                summarize=-1,
                output_stream=output_stream))
        with ops.control_dependencies(print_ops):
          return constant_op.constant(0).op
      return _print_cache

    def _eq(file_index):
      # Predicate: does this replica own file_index?
      return math_ops.equal(replica_id, file_index)

    flush_op_cases = {}
    flush_op_cases[_eq(0)] = _f(0)
    for i in range(1, num_replicas):
      if on_tpu and not self._parameters.collect_summary_per_core:
        # If this is the case, the cache is already merged for all cores.
        # Only first core flushes the cache.
        flush_op_cases[_eq(i)] = control_flow_ops.no_op
      else:
        flush_op_cases[_eq(i)] = _f(i)
    # Each replica needs to determine where to write their output.
    # To do this, we check if replica_id is 0, then 1, ..., and then
    # num_replicas - 1 statically; and return the corresponding static file
    # name. We cannot simply set the file name in python, as replica_id is
    # only known during tf runtime, and we cannot create dynamic filenames.
    return control_flow_case.case(flush_op_cases, exclusive=True)

  cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph)
  if self._use_temp_cache():
    cache_val = cache
  else:
    cache_val = cache.value()

  if on_tpu:
    # If we do not need to collect traces for all cores, merge and aggregate
    # per core trace.
    if not self._parameters.collect_summary_per_core:
      cache_val = self.merge_caches_on_tpu(cache_val)
      cache_val = self.aggregate_global_cache(cache_val)[0]

    # The flush itself runs on the host via outside compilation; global step
    # is wrapped in identity so it can be passed across the boundary.
    flush_op = tpu_replication.outside_compilation(
        _flush_fun, cache_val, self._replica_id,
        array_ops.identity(training_util.get_or_create_global_step()))
  else:
    global_step = training_util.get_or_create_global_step()
    flush_op = _flush_fun(cache_val, self._replica_id, global_step)

  if self._use_temp_cache():
    with ops.control_dependencies([flush_op]):
      return constant_op.constant(0).op
  else:
    # Re-initialize the local cache variable.
    with ops.control_dependencies([flush_op]):
      reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
                                         dtype=cache.dtype,
                                         shape=cache.shape)
      assign_op = state_ops.assign(cache, reset_value).op
      with ops.control_dependencies([assign_op]):
        return constant_op.constant(0).op
def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu,
                               tensor_trace_order, graph):
  """Flushes the intermediate tensor values in the graph to the cache.

  Args:
    tensor_fetches: list of tensor results returned by the model_fn.
    op_fetches: list of ops that are returned by the model_fn, e.g., train_op.
    on_tpu: if the graph is executed on TPU.
    tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
    graph: TensorFlow graph.

  Returns:
    An identical copy of tensor_fetches.
  """
  if not tensor_trace_order.traced_tensors:
    logging.warn('No tensor values being traced. No flush cache op added.')
    return tensor_fetches
  # Every tracing op must run before the flush, so the flush op depends on
  # all fetched ops and on the producers of all fetched tensors.
  flush_deps = op_fetches + [fetch.op for fetch in tensor_fetches]
  with ops.control_dependencies(flush_deps):
    flush_cache_op = self._generate_flush_cache_op(
        self._tt_config.num_replicas, on_tpu, tensor_trace_order, graph)
    return control_flow_ops.tuple(tensor_fetches,
                                  control_inputs=[flush_cache_op])
1731 def _process_tensor_fetches(self, tensor_fetches):
1732 """Check that tensor_fetches is not empty and have valid tensors."""
1733 # If none or empty list.
1734 if tensor_fetches is None:
1735 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
1736 'None.')
1737 if not isinstance(tensor_fetches, (list, tuple)):
1738 tensor_fetches = [tensor_fetches]
1739 elif not tensor_fetches:
1740 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
1741 'empty list.')
1742 fetches = []
1743 for fetch in tensor_fetches:
1744 if isinstance(fetch, ops.Tensor):
1745 fetches.append(fetch)
1746 else:
1747 raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch)
1748 return fetches
1750 def _process_op_fetches(self, op_fetches):
1751 """Check that op_fetches have valid ops."""
1752 if op_fetches is None:
1753 return []
1755 if not isinstance(op_fetches, (list, tuple)):
1756 op_fetches = [op_fetches]
1758 fetches = []
1759 for fetch in op_fetches:
1760 if isinstance(fetch, ops.Operation):
1761 fetches.append(fetch)
1762 elif isinstance(fetch, ops.Tensor):
1763 fetches.append(fetch.op)
1764 else:
1765 logging.warning('Ignoring the given op_fetch:%s, which is not an op.' %
1766 fetch)
1767 return fetches
def _convert_fetches_to_input_format(self, input_fetches, current_fetches):
  """Changes current_fetches' format, so that it matches input_fetches.

  Args:
    input_fetches: the fetches as originally supplied by the caller; either
      a single tf.Tensor, a tuple, or a list.
    current_fetches: the processed fetches, always a list.

  Returns:
    current_fetches converted back to the container shape of input_fetches:
    a bare tensor, a tuple, or the list itself.

  Raises:
    RuntimeError: if the number of processed fetches does not match the
      number of input fetches.
  """
  if isinstance(input_fetches, ops.Tensor):
    if len(current_fetches) != 1:
      raise RuntimeError('Tensor tracer input/output fetches do not match.')
    return current_fetches[0]
  else:
    # Bug fix: the original compared len(current_fetches) against itself,
    # which is always equal, so a mismatch between input and output fetch
    # counts was silently ignored. Compare against input_fetches instead.
    if len(input_fetches) != len(current_fetches):
      raise RuntimeError('Tensor tracer input/output fetches do not match.')
    elif isinstance(input_fetches, tuple):
      return tuple(current_fetches)
    else:
      return current_fetches
def _get_op_control_flow_context(self, op):
  """Returns the control flow context of the given op.

  Args:
    op: tf.Operation for which the control flow context is requested.

  Returns:
    The op's control flow context; for a LoopExit op, its outer context
    is returned instead.
  """
  # pylint: disable=protected-access
  context = op._control_flow_context
  # pylint: enable=protected-access
  # LoopExit ops logically belong to the enclosing (outer) context.
  if control_flow_util.IsLoopExit(op):
    return context.outer_context
  return context
def merge_caches_on_tpu(self, local_tpu_cache_tensor):
  """Merges the given caches on tpu.

  Args:
    local_tpu_cache_tensor: A local tensor that needs to be merged
      by concatenating data from other tpu cores.
  Returns:
    A merged tf.Tensor.
  """
  replica_count = self._tt_config.num_replicas
  # Replicate the local cache along a new leading (per-core) dimension.
  broadcast_shape = [replica_count] + local_tpu_cache_tensor.shape.as_list()
  broadcasted = array_ops.broadcast_to(local_tpu_cache_tensor,
                                       shape=broadcast_shape)

  if tensor_tracer_flags.TT_SINGLE_CORE_SUMMARIES.value:
    return broadcasted

  # Exchange slices across every core in a single group so each replica
  # ends up holding the caches of all replicas.
  return tpu_ops.all_to_all(
      broadcasted, concat_dimension=0, split_dimension=0,
      split_count=replica_count,
      group_assignment=[list(range(replica_count))])
def aggregate_global_cache(self, global_tt_summary_cache):
  """Merges the given caches on tpu.

  Args:
    global_tt_summary_cache: The global tensor tracer summary cache tensor
      with shape (num_cores, num_traced_tensors, num_traced_signatures). First
      dimension corresponds to core_id, where global_tpu_cache_tensor[i]
      correspond to the local cache from core-i.
  Returns:
    An aggregated tf.Tensor.
  Raises:
    RuntimeError: if there is no aggregate function defined for a signature.
  """
  agg_fn_map = self._parameters.get_signature_to_agg_fn_map()
  signature_idx_map = self._signature_types()
  per_signature_results = []
  # Walk signatures in cache-index order so the aggregated output keeps the
  # same signature layout as the input cache.
  for signature, idx in sorted(signature_idx_map.items(),
                               key=operator.itemgetter(1)):
    if signature not in agg_fn_map:
      raise RuntimeError('No aggregation function is defined for '
                         'signature %s.' % signature)
    # Slice out the (num_cores x num_traced_tensors) plane belonging to this
    # signature and reduce it across the core axis with the signature's own
    # aggregation function.
    per_core_values = global_tt_summary_cache[:, :, idx]
    reduce_fn = agg_fn_map[signature]
    per_signature_results.append(reduce_fn(per_core_values, axis=0))

  # Stacking gives num_signatures x num_traced_tensors; transpose back to
  # num_traced_tensors x num_signatures, then restore a leading core
  # dimension so the result matches the expected
  # num_cores x num_traced_tensors x num_signatures structure.
  stacked = array_ops_stack.stack(per_signature_results)
  return array_ops.expand_dims(array_ops.transpose(stacked), axis=0)
def _prepare_host_call_fn(self, processed_t_fetches,
                          op_fetches, graph, graph_summary_tag):
  """Creates a host call function that will write the cache as tb summary.

  Args:
    processed_t_fetches: List of tensor provided to session.run.
    op_fetches: List of operations provided to session.run.
    graph: TensorFlow graph.
    graph_summary_tag: the summary_tag name for the given graph.
  Raises:
    ValueError if trace_dir is not set.
  """
  if self._parameters.trace_dir is None:
    raise ValueError('Provide a trace_dir for tensor tracer in summary mode. '
                     '--trace_dir=/model/dir')

  def _write_cache(step, event_file_suffix=None, **kwargs):
    """Writes the given caches as tensor summary.

    Args:
      step: Step tensor with dimension [num_cores].
      event_file_suffix: Event filename suffix tensor.
      **kwargs: The dictionary of tensors that needs to be written as
        summaries. Key and value pairs within kwargs correspond to the tag
        name, and tensor content that will be written using summary.write.
        The trace_modes that use this function are:
          - summary: In summary mode, kwargs includes a single (tag, content)
          pair which are, _TT_SUMMARY_TAG and a tf.float32 signature_cache
          variable. The dimension of the signature_cache is:
            num_cores x num_traced_tensors x num_signatures.
          - full_tensor_summary: kwargs will include all traced tensors. Tag
          and content correspond to the name of the tensor, and its actual
          content.
    Returns:
      A tf.Operation that needs to be executed for the host call dependencies.
    """
    file_suffix = _TT_EVENT_FILE_SUFFIX
    if event_file_suffix is not None:
      # Append the per-run suffix so concurrent runs do not collide on the
      # same event file.
      file_suffix = string_ops.string_join([file_suffix, event_file_suffix],
                                           separator='.')
    # TODO(deveci): Parametrize max_queue, so that flushing op can be called
    # less frequently.
    # Setting max_queue to 100 appears to be safe even when the number of
    # iterations are much lower, as the destructor of the writer flushes it.
    summary_write_ops = []
    summary_writer = summary.create_file_writer_v2(
        self._parameters.trace_dir,
        filename_suffix=file_suffix,
        max_queue=_TT_SUMMARY_MAX_QUEUE)
    # Register the writer on the graph so other tooling can find/flush it.
    graph.add_to_collection(
        TENSOR_TRACER_SUMMARY_COLLECTION, summary_writer)

    step_value = step[0]
    dt = step_value.dtype

    # The step parameter to a summary write call must be 64-bit.
    if dt.__ne__(dtypes.int64) and dt.__ne__(
        dtypes.uint64) and dt.__ne__(dtypes.float64):
      step_value = math_ops.cast(step_value, dtypes.int64)

    with summary_writer.as_default():
      summary_metadata = summary_pb2.SummaryMetadata(
          plugin_data=summary_pb2.SummaryMetadata.PluginData(
              plugin_name=_TT_TENSORBOARD_PLUGIN_NAME))
      for key, value in kwargs.items():
        # Check whether we need to compute aggregated statistics that merge
        # all cores statistics.
        if not self._parameters.collect_summary_per_core:
          # Merge only statistics tensor, if it is any other tensor we simply,
          # concatenate them.
          # Also, if there is only a single core (first dim. is 0), then skip
          # aggregation.
          # NOTE(review): the comment above says "first dim. is 0" but the
          # code skips aggregation when the first dim equals 1 — the code
          # looks right (one core => nothing to merge); confirm and fix the
          # wording.
          if key == _TT_SUMMARY_TAG and value.shape.as_list()[0] != 1:
            value = self.aggregate_global_cache(value)
        with ops.control_dependencies([summary_writer.init()]):
          summary_write_ops.append(summary.write(
              _TT_SUMMARY_TAG + '/' + key + '.' + graph_summary_tag,
              value, metadata=summary_metadata,
              step=step_value))
    return control_flow_ops.group(summary_write_ops)

  global_step = training_util.get_or_create_global_step()
  step = array_ops.reshape(global_step, [1])
  self._host_call_fn = {}

  host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches]

  caches_to_write = {}
  with ops.control_dependencies(host_call_deps):
    all_caches = self._cache_variable_for_graph(graph)
    for cache_name, cache_variable in all_caches.items():
      # Increase the cache rank by 1, so that when host call concatenates
      # tensors from different replicas, we can identify them with [core_id].
      new_cache_shape = [1]
      new_cache_shape.extend(cache_variable.shape.as_list())
      cache = array_ops.reshape(cache_variable, new_cache_shape)
      caches_to_write[cache_name] = cache
  # Add step to parameter dictionary.
  caches_to_write['step'] = step
  # Other options without adding step to parameter dictionary are
  # * host_call_fn = (_write_cache(step, caches_to_write)) : fails as it
  #   considers caches_to_write as a single parameter, rather than a keyword
  #   parameters.
  # * host_call_fn = (_write_cache(step, **caches_to_write)) : fails with
  #   a syntax error.
  self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write)
def host_call_deps_and_fn(self):
  """Returns the host-call registry: a dict of name -> (fn, kwargs) pairs."""
  host_calls = self._host_call_fn
  return host_calls
def get_traced_op_names(self):
  """Returns the set of traced op names."""
  traced_names = self._traced_op_names
  return traced_names
def _trace_execution(self, graph,
                     tensor_fetches,
                     op_fetches=None,
                     on_tpu=True):
  """Common tracing function for both CPU and TPUs.

  The caller function should set device_type, num_replicas,
  num_replicas_per_host, num_hosts and replica_id before calling
  _trace_execution.

  Args:
    graph: the graph of Ops executed on the TPU.
    tensor_fetches: a (list,tuple,or a single object) of tensor fetches
      returned by model_fn given to session.run. Function must be provided
      with at least one tensor to fetch.
    op_fetches: A list of op fetches returned by model_fn given to
      session.run. op_fetches and tensor_fetches are used to determine the
      nodes that will be executed. Can be None.
    on_tpu: True if executing on TPU.

  Returns:
    tensor_fetches: an exact copy of tensor_fetches that has additional
      dependencies.
  Raises:
    RuntimeError: If tensor_fetches is None or empty.
  """
  def _cast_unsupported_dtypes(tensor):
    """Casts tensor to a supported type."""

    if tensor.dtype.__eq__(dtypes.int64):
      # outside-compilation doesn't support int64 input yet.
      return math_ops.cast(tensor, dtypes.int32)
    if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
        dtypes.float16):
      # Since host can't handle bf16, convert tensor to f32.
      return math_ops.cast(tensor, dtypes.float32)
    return tensor

  trace_mode = self._parameters.trace_mode
  device_type = self._tt_config.device_type
  # Remember the outermost control flow context so it can be restored after
  # temporarily switching into per-op contexts below.
  # pylint: disable=protected-access
  self._outmost_context = graph._get_control_flow_context()
  # pylint: enable=protected-access

  analytics.track_usage('tensor_tracer', [trace_mode, device_type])
  TensorTracer.check_device_type(device_type)
  TensorTracer.check_trace_mode(device_type, trace_mode)
  # Check in_tensor_fetches, and op_fetches and convert them to lists.
  processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
  op_fetches = self._process_op_fetches(op_fetches)
  all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]

  # Filter out the operations that won't be executed.
  # if fetches=None, then ops_in_exec_path = set(operations)
  exec_op_set = self._filter_execution_path_operations(graph.get_operations(),
                                                       all_fetches)
  graph_summary_tag = _graph_summary_tag(graph)

  # Write report file, and determine the traced tensors.
  tensor_trace_order = self._determine_trace_and_create_report(
      graph, exec_op_set, graph_summary_tag)

  tensor_fetch_set = set(processed_t_fetches)
  tracing_ops = []

  # Sort by op name for a deterministic tracing order across runs.
  sorted_exec_op_list = list(exec_op_set)
  sorted_exec_op_list.sort(key=lambda op: op.name)
  # Trace ops only if they are in the execution path.
  for op in sorted_exec_op_list:
    for i in range(len(op.outputs)):
      out_tensor = op.outputs[i]
      tensor_name = out_tensor.name
      if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
        continue
      self._traced_op_names.add(op.name)
      # Create the list of consumers before calling _preprocess_traced_tensor.
      # Otherwise, adding control input below, will introduce a cycle in the
      # graph.
      consumers = out_tensor.consumers()
      # Not all consumers may be in the exec path. Filter out the consumers
      # to keep the graph simpler.
      consumers = [cop for cop in consumers if cop in exec_op_set]

      # If there is no consumer of the tensor, there is no need to trace it;
      # unless the tensor itself is one of the fetches.
      is_a_fetched_tensor = out_tensor in tensor_fetch_set
      if (not consumers) and (not is_a_fetched_tensor):
        continue

      # Build tracing ops inside the traced op's own control flow context so
      # they are valid control inputs for its consumers.
      op_control_flow_context = self._get_op_control_flow_context(op)
      if op_control_flow_context:
        # pylint: disable=protected-access
        graph._set_control_flow_context(op_control_flow_context)
        # pylint: enable=protected-access

      processed_tensors = self._preprocess_traced_tensor(out_tensor)

      if on_tpu:
        for signature in processed_tensors.keys():
          processed_tensors[signature] = _cast_unsupported_dtypes(
              processed_tensors[signature])

      if self._use_tensor_values_cache():
        # Use a small cache (either temp cache or tf local variable) to store
        # the characteristics of the tensor.
        if self._use_temp_cache():
          cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
          self._save_tensor_value_to_tmp_cache(cache_idx,
                                               processed_tensors,
                                               graph)
          trace_op = None
        else:
          cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
          trace_op = self._save_tensor_value_to_cache_op(cache_idx,
                                                         processed_tensors,
                                                         graph)
      elif self._use_tensor_buffer():
        if len(processed_tensors) != 1:
          raise RuntimeError('Multiple stats are only allowed in compact '
                             'mode.')
        processed_out_tensor = list(processed_tensors.values())[0]
        # Store the whole tensor in a buffer.
        trace_op = self._snapshot_tensor(processed_out_tensor)
      else:

        def tpu_wrap_trace_fn(tensor, out_tensor_name):
          """Wraps the trace_fn with outside compilation if on TPUs."""
          tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name,
                                                        tensor_trace_order)
          if on_tpu:
            return tpu_replication.outside_compilation(
                tensor_trace_fn, tensor)
          else:
            return tensor_trace_fn(tensor)

        if len(processed_tensors) != 1:
          raise RuntimeError('Multiple stats are only allowed in compact '
                             'mode.')
        # Collecting multiple statistics are only supported in the summary
        # mode that uses compact format(self._use_tensor_values_cache = true).
        # Non-compact mode currently allows single stat per tensor.
        processed_out_tensor = next(iter(processed_tensors.values()))
        trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name)

      # Restore the outermost context before wiring control dependencies.
      if op_control_flow_context:
        # pylint: disable=protected-access
        graph._set_control_flow_context(self._outmost_context)
        # pylint: enable=protected-access
      if trace_op:
        if is_a_fetched_tensor:
          tracing_ops.append(trace_op)
          continue
        # Add it to all consumers, as some consumers may not be executed if
        # they are in a control flow.
        for consumer_op in consumers:
          # pylint: disable=protected-access
          consumer_op._add_control_input(trace_op)
          # pylint: enable=protected-access

  # pylint: disable=protected-access
  graph._set_control_flow_context(self._outmost_context)
  # pylint: enable=protected-access
  if tracing_ops:
    # If we are tracing a fetched tensor, their dependency is stored in
    # tracing_ops.
    processed_t_fetches = control_flow_ops.tuple(processed_t_fetches,
                                                 control_inputs=tracing_ops)
  if self._use_tensor_values_cache() or self._use_tensor_buffer():
    if self._use_temp_cache():
      # Create the temporary tf cache variable by concantanating all
      # statistics.
      graph_cache_var = self._cache_variable_for_graph(graph)
      if graph not in self._temp_cache_var:
        raise RuntimeError('graph is not in self._temp_cache_var')
      graph_cache_var[_TT_SUMMARY_TAG] = array_ops_stack.stack(
          self._temp_cache_var[graph], axis=0, name='stack_all_op_signatures')
    if self._create_host_call():
      self._prepare_host_call_fn(processed_t_fetches, op_fetches, graph,
                                 graph_summary_tag)
      if not on_tpu:
        # On CPU there is no host-call mechanism; invoke the cache writer
        # directly and consume its registration.
        write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
        cache_write_op = write_cache(**caches_to_write)
        processed_t_fetches = control_flow_ops.tuple(
            processed_t_fetches, control_inputs=[cache_write_op])
        del self._host_call_fn[_TT_HOSTCALL_KEY]
      elif self._parameters.flush_summaries_with_outside_compile:
        write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
        if (_TT_SUMMARY_TAG in caches_to_write and 'step' in caches_to_write):
          step = caches_to_write['step']
          tensor_tracer_summary = caches_to_write[_TT_SUMMARY_TAG]
          tt_core_summary = self.merge_caches_on_tpu(tensor_tracer_summary[0])
          if not self._parameters.collect_summary_per_core:
            tt_core_summary = self.aggregate_global_cache(tt_core_summary)

          def write_if_core_0(step, replica_id, tt_summary):
            # Only replica 0 writes; the merged summary is identical on
            # every core after merge_caches_on_tpu.
            return cond.cond(
                math_ops.equal(replica_id, 0),
                lambda: write_cache(step=step, event_file_suffix=None,  # pylint: disable=g-long-lambda
                                    tensor_tracer_summary=tt_summary),
                control_flow_ops.no_op)

          write_op = tpu_replication.outside_compilation(
              write_if_core_0,
              step=step,
              replica_id=self._replica_id,
              tt_summary=tt_core_summary)
          processed_t_fetches = control_flow_ops.tuple(
              processed_t_fetches, control_inputs=[write_op])
          del self._host_call_fn[_TT_HOSTCALL_KEY]
        else:
          raise ValueError('Outside compiled flush in only supported for '
                           'summary mode')
    else:
      processed_t_fetches = self._flush_tensor_values_cache(
          processed_t_fetches, op_fetches, on_tpu=on_tpu,
          tensor_trace_order=tensor_trace_order,
          graph=graph)

  # processed_t_fetches is a list at this point. Convert it to the same
  # format as given in tensor_fetches.
  return self._convert_fetches_to_input_format(tensor_fetches,
                                               processed_t_fetches)
def trace_tpu(self, graph,
              tensor_fetches,
              op_fetches=None,
              num_replicas=None,
              num_replicas_per_host=None,
              num_hosts=None):
  """Traces the tensors generated by TPU Ops in a TF graph.

  Args:
    graph: the graph of Ops executed on the TPU.
    tensor_fetches: a (list,tuple,or a single object) of tensor fetches
      returned by model_fn given to session.run. Function must be provided
      with at least one tensor to fetch.
    op_fetches: A list of op fetches returned by model_fn given to
      session.run. op_fetches and tensor_fetches are used to determine the
      nodes that will be executed. Can be None.
    num_replicas: number of replicas used on the TPU.
    num_replicas_per_host: number of replicas per TPU host.
    num_hosts: total number of TPU hosts.

  Returns:
    tensor_fetches: an exact copy of tensor_fetches that has additional
      dependencies.
  """
  # FuncGraphs are not supported; hand back the fetches untouched.
  if isinstance(graph, func_graph.FuncGraph) or isinstance(
      graph, function._FuncGraph):  # pylint: disable=protected-access
    logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
                    'Ignoring tracing.')
    return tensor_fetches

  # The class-level registry guards against rewriting the same graph twice.
  if graph in TensorTracer._traced_graphs:
    logging.warning('Graph is already rewritten with tensor tracer, ignoring '
                    'multiple calls.')
    return tensor_fetches
  else:
    TensorTracer._traced_graphs.add(graph)
  # Reset the parameters in case parameters are changed.
  self._parameters = tensor_tracer_flags.TTParameters()
  self._tt_config.device_type = _DEVICE_TYPE_TPU
  self._tt_config.num_replicas = num_replicas
  self._tt_config.num_replicas_per_host = num_replicas_per_host
  self._tt_config.num_hosts = num_hosts
  if self._tt_config.num_replicas is not None:
    if self._tt_config.num_replicas_per_host is None:
      # Default: assume 8 replicas per host when not specified.
      self._tt_config.num_replicas_per_host = 8
    if self._tt_config.num_hosts is None:
      # Ceiling division: floor quotient plus 1 if there is a remainder
      # (the boolean comparison contributes 0 or 1).
      self._tt_config.num_hosts = (
          num_replicas // self._tt_config.num_replicas_per_host +
          (num_replicas % self._tt_config.num_replicas_per_host > 0))

  if self._parameters.graph_dump_path:
    graph_io.write_graph(graph, self._parameters.graph_dump_path,
                         'graph_before_tt.pbtxt')
  with graph.as_default():
    self._add_replica_id_to_graph()
    tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
                                           on_tpu=True)
  if self._parameters.graph_dump_path:
    graph_io.write_graph(graph, self._parameters.graph_dump_path,
                         'graph_after_tt.pbtxt')
  return tensor_fetches
def trace_cpu(self, graph, tensor_fetches, op_fetches=None):
  """Traces the tensors generated by CPU Ops in a TF graph.

  Args:
    graph: the graph of Ops executed on the CPU.
    tensor_fetches: a (list,tuple,or a single object) of tensor fetches
      returned by model_fn given to session.run. Function must be provided
      with at least one tensor to fetch.
    op_fetches: A list of op fetches returned by model_fn given to
      session.run. op_fetches and tensor_fetches are used to determine the
      nodes that will be executed. Can be None.

  Returns:
    tensor_fetches: an exact copy of tensor_fetches that has additional
      dependencies.
  """
  # FuncGraphs are not supported; hand back the fetches untouched.
  if isinstance(graph, func_graph.FuncGraph) or isinstance(
      graph, function._FuncGraph):  # pylint: disable=protected-access
    logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
                    'Ignoring tracing.')
    return tensor_fetches

  # The class-level registry guards against rewriting the same graph twice.
  if graph in TensorTracer._traced_graphs:
    logging.warning('Graph is already rewritten with tensor tracer, ignoring '
                    'multiple calls.')
    return tensor_fetches
  else:
    TensorTracer._traced_graphs.add(graph)
  # Reset the parameters in case parameters are changed.
  self._parameters = tensor_tracer_flags.TTParameters()

  # CPU tracing runs as a single replica on a single host.
  self._tt_config.device_type = _DEVICE_TYPE_CPU
  self._tt_config.num_replicas = 1
  self._tt_config.num_replicas_per_host = 1
  self._tt_config.num_hosts = 1
  self._replica_id = 0
  if self._parameters.graph_dump_path:
    graph_io.write_graph(graph, self._parameters.graph_dump_path,
                         'graph_before_tt.pbtxt')
  with graph.as_default():
    tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
                                           on_tpu=False)
  if self._parameters.graph_dump_path:
    graph_io.write_graph(graph, self._parameters.graph_dump_path,
                         'graph_after_tt.pbtxt')
  return tensor_fetches