Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/profiler/tfprof_logger.py: 16%
112 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Logging tensorflow::tfprof::OpLogProto.
17OpLogProto is used to add extra model information for offline analysis.
18"""
19import os
20import sys
22from tensorflow.core.profiler import tfprof_log_pb2
23from tensorflow.python.eager import context
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.platform import gfile
27from tensorflow.python.profiler.internal import flops_registry # pylint: disable=unused-import
28from tensorflow.python.util.tf_export import tf_export
# OpLogEntry type tag given to ops found in the TRAINABLE_VARIABLES collection.
TRAINABLE_VARIABLES = '_trainable_variables'
# Statistic name passed to ops.get_stats_for_node_def; its value is recorded
# as an OpLogEntry's float_ops.
REGISTERED_FLOP_STATS = 'flops'
34def _fill_missing_graph_shape(graph, run_meta):
35 """Fill Tensor shapes in 'graph' with run time shape from 'run_meta'."""
36 for dev_stat in run_meta.step_stats.dev_stats:
37 for node_stat in dev_stat.node_stats:
38 if not node_stat.output:
39 continue
40 try:
41 op = graph.get_operation_by_name(node_stat.node_name)
42 except KeyError as e:
43 # Graph doesn't contains the node_stat, usually RecvTensor.
44 continue
45 if len(node_stat.output) != len(op.outputs):
46 # For example, conditional op has only 1 output at run time.
47 continue
48 for (i, node_stat_out) in enumerate(node_stat.output):
49 if op.outputs[i].get_shape().is_fully_defined():
50 continue
51 node_stat_dims = node_stat_out.tensor_description.shape.dim
52 node_stat_shape = tensor_shape.TensorShape(
53 [d.size for d in node_stat_dims])
54 try:
55 op.outputs[i].set_shape(op.outputs[i].get_shape().merge_with(
56 node_stat_shape))
57 except ValueError as e:
58 sys.stderr.write('Node %s incompatible shapes: %s.\n' %
59 (node_stat.node_name, e))
60 return graph
63def _str_id(s, str_to_id):
64 """Maps string to id."""
65 num = str_to_id.get(s, None)
66 if num is None:
67 num = len(str_to_id)
68 str_to_id[s] = num
69 return num
72def _get_logged_ops(graph, run_meta=None, add_trace=True,
73 add_trainable_var=True):
74 """Extract trainable model parameters and FLOPs for ops from a Graph.
76 Args:
77 graph: tf.Graph.
78 run_meta: RunMetadata proto used to complete shape information.
79 add_trace: Whether to add op trace information.
80 add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op
81 type '_trainable_variables'.
82 Returns:
83 logged_ops: dict mapping from op_name to OpLogEntry.
84 string_to_id: dict mapping from string to id.
85 """
86 if run_meta:
87 graph = _fill_missing_graph_shape(graph, run_meta)
89 op_missing_shape = 0
90 logged_ops = {}
91 string_to_id = {}
92 string_to_id['none'] = len(string_to_id)
93 # TODO(xpan): Work with Profiler more efficiently.
94 for op in graph.get_operations():
95 try:
96 stats = ops.get_stats_for_node_def(
97 graph, op.node_def, REGISTERED_FLOP_STATS)
98 except ValueError:
99 # Catch Exception When shape is incomplete. Skip it.
100 op_missing_shape += 1
101 stats = None
103 entry = tfprof_log_pb2.OpLogEntry()
104 entry.name = op.name
105 add_entry = False
106 if stats and stats.value:
107 entry.float_ops = int(stats.value)
108 add_entry = True
110 if add_trace:
111 if op.traceback:
112 for filename, lineno, funcname, line in op.traceback:
113 trace = entry.code_def.traces.add()
114 trace.file_id = _str_id(filename, string_to_id) if filename else 0
115 trace.lineno = lineno if lineno else -1
116 trace.function_id = _str_id(funcname, string_to_id) if funcname else 0
117 trace.line_id = _str_id(line, string_to_id) if line else 0
118 # TODO(slebedev): remove this unused field from the proto.
119 trace.func_start_line = -1
120 add_entry = True
122 if add_entry:
123 logged_ops[entry.name] = entry
125 if add_trainable_var:
126 for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
127 if v.op.name not in logged_ops:
128 entry = tfprof_log_pb2.OpLogEntry()
129 entry.name = v.op.name
130 entry.types.append(TRAINABLE_VARIABLES)
131 logged_ops[entry.name] = entry
132 else:
133 logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)
135 if op_missing_shape > 0 and not run_meta:
136 sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' %
137 op_missing_shape)
138 return logged_ops, string_to_id
def merge_default_with_oplog(graph, op_log=None, run_meta=None,
                             add_trace=True, add_trainable_var=True):
  """Merge the tfprof default extra info with caller's op_log.

  Entries computed from `graph` are combined with the entries of `op_log`
  by op name. For a name present in both, the graph entry's types are
  appended. Its float_ops and code_def are used only where the caller's
  entry has none. The caller's `op_log` itself is never modified.

  Note: the result's id_to_string is built only from the strings interned
  while scanning `graph`; `op_log.id_to_string` is not copied. Trace ids in
  caller-provided entries may therefore not resolve — TODO(review) confirm
  this is intended.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
        default graph.
    op_log: OpLogProto proto.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op
        type '_trainable_variables'.
  Returns:
    tmp_op_log: Merged OpLogProto proto. It is empty when no graph is
      available (e.g. eager execution with `graph` None).
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()

  tmp_op_log = tfprof_log_pb2.OpLogProto()
  if not graph:
    return tmp_op_log

  logged_ops, string_to_id = _get_logged_ops(
      graph, run_meta, add_trace=add_trace, add_trainable_var=add_trainable_var)

  if not op_log:
    tmp_op_log.log_entries.extend(logged_ops.values())
  else:
    all_ops = {}
    for entry in op_log.log_entries:
      # Merge into a copy: the entries of op_log.log_entries belong to the
      # caller, and mutating them would make repeated calls with the same
      # op_log accumulate duplicate types.
      entry_copy = tfprof_log_pb2.OpLogEntry()
      entry_copy.CopyFrom(entry)
      all_ops[entry.name] = entry_copy
    for op_name, entry in logged_ops.items():
      merged = all_ops.get(op_name)
      if merged is None:
        all_ops[op_name] = entry
        continue
      merged.types.extend(entry.types)
      if entry.float_ops > 0 and merged.float_ops == 0:
        merged.float_ops = entry.float_ops
      if entry.code_def.traces and not merged.code_def.traces:
        merged.code_def.MergeFrom(entry.code_def)
    tmp_op_log.log_entries.extend(all_ops.values())

  for s, i in string_to_id.items():
    tmp_op_log.id_to_string[i] = s
  return tmp_op_log
@tf_export(v1=['profiler.write_op_log'])
def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
  """Write 'op_log', merged with default model information, to `log_dir`.

  The written log is the caller's 'op_log' merged with information derived
  from 'graph':
    * ops in tf.compat.v1.trainable_variables() get the op type
      '_trainable_variables';
    * ops with op.RegisterStatistics() defined get 'flops' statistics. These
      depend on the Tensor shapes in 'graph', which may be incomplete; when
      'run_meta' is given, its run time shapes are used to complete them on a
      best-effort basis.
  The serialized OpLogProto is written to a file named 'tfprof_log' inside
  `log_dir`.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
        default graph.
    log_dir: directory to write the log file.
    op_log: (Optional) OpLogProto proto to be written. If not provided, a new
        one is created.
    run_meta: (Optional) RunMetadata proto that helps flops computation using
        run time shape information.
    add_trace: Whether to add python code trace information.
        Used to support "code" view.
  """
  needs_default_graph = not graph and not context.executing_eagerly()
  if needs_default_graph:
    graph = ops.get_default_graph()
  merged_log = merge_default_with_oplog(
      graph, op_log=op_log, run_meta=run_meta, add_trace=add_trace)

  log_path = os.path.join(log_dir, 'tfprof_log')
  with gfile.Open(log_path, 'w') as log_file:
    log_file.write(merged_log.SerializeToString())