Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/profiler/tfprof_logger.py: 16%

112 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Logging tensorflow::tfprof::OpLogProto. 

16 

17OpLogProto is used to add extra model information for offline analysis. 

18""" 

19import os 

20import sys 

21 

22from tensorflow.core.profiler import tfprof_log_pb2 

23from tensorflow.python.eager import context 

24from tensorflow.python.framework import ops 

25from tensorflow.python.framework import tensor_shape 

26from tensorflow.python.platform import gfile 

27from tensorflow.python.profiler.internal import flops_registry # pylint: disable=unused-import 

28from tensorflow.python.util.tf_export import tf_export 

29 

# Op type tag appended to the OpLogEntry of every op that backs a variable in
# the graph's TRAINABLE_VARIABLES collection.
TRAINABLE_VARIABLES = '_trainable_variables'
# Statistic type requested from ops.get_stats_for_node_def to obtain FLOPs.
REGISTERED_FLOP_STATS = 'flops'

32 

33 

34def _fill_missing_graph_shape(graph, run_meta): 

35 """Fill Tensor shapes in 'graph' with run time shape from 'run_meta'.""" 

36 for dev_stat in run_meta.step_stats.dev_stats: 

37 for node_stat in dev_stat.node_stats: 

38 if not node_stat.output: 

39 continue 

40 try: 

41 op = graph.get_operation_by_name(node_stat.node_name) 

42 except KeyError as e: 

43 # Graph doesn't contains the node_stat, usually RecvTensor. 

44 continue 

45 if len(node_stat.output) != len(op.outputs): 

46 # For example, conditional op has only 1 output at run time. 

47 continue 

48 for (i, node_stat_out) in enumerate(node_stat.output): 

49 if op.outputs[i].get_shape().is_fully_defined(): 

50 continue 

51 node_stat_dims = node_stat_out.tensor_description.shape.dim 

52 node_stat_shape = tensor_shape.TensorShape( 

53 [d.size for d in node_stat_dims]) 

54 try: 

55 op.outputs[i].set_shape(op.outputs[i].get_shape().merge_with( 

56 node_stat_shape)) 

57 except ValueError as e: 

58 sys.stderr.write('Node %s incompatible shapes: %s.\n' % 

59 (node_stat.node_name, e)) 

60 return graph 

61 

62 

63def _str_id(s, str_to_id): 

64 """Maps string to id.""" 

65 num = str_to_id.get(s, None) 

66 if num is None: 

67 num = len(str_to_id) 

68 str_to_id[s] = num 

69 return num 

70 

71 

72def _get_logged_ops(graph, run_meta=None, add_trace=True, 

73 add_trainable_var=True): 

74 """Extract trainable model parameters and FLOPs for ops from a Graph. 

75 

76 Args: 

77 graph: tf.Graph. 

78 run_meta: RunMetadata proto used to complete shape information. 

79 add_trace: Whether to add op trace information. 

80 add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op 

81 type '_trainable_variables'. 

82 Returns: 

83 logged_ops: dict mapping from op_name to OpLogEntry. 

84 string_to_id: dict mapping from string to id. 

85 """ 

86 if run_meta: 

87 graph = _fill_missing_graph_shape(graph, run_meta) 

88 

89 op_missing_shape = 0 

90 logged_ops = {} 

91 string_to_id = {} 

92 string_to_id['none'] = len(string_to_id) 

93 # TODO(xpan): Work with Profiler more efficiently. 

94 for op in graph.get_operations(): 

95 try: 

96 stats = ops.get_stats_for_node_def( 

97 graph, op.node_def, REGISTERED_FLOP_STATS) 

98 except ValueError: 

99 # Catch Exception When shape is incomplete. Skip it. 

100 op_missing_shape += 1 

101 stats = None 

102 

103 entry = tfprof_log_pb2.OpLogEntry() 

104 entry.name = op.name 

105 add_entry = False 

106 if stats and stats.value: 

107 entry.float_ops = int(stats.value) 

108 add_entry = True 

109 

110 if add_trace: 

111 if op.traceback: 

112 for filename, lineno, funcname, line in op.traceback: 

113 trace = entry.code_def.traces.add() 

114 trace.file_id = _str_id(filename, string_to_id) if filename else 0 

115 trace.lineno = lineno if lineno else -1 

116 trace.function_id = _str_id(funcname, string_to_id) if funcname else 0 

117 trace.line_id = _str_id(line, string_to_id) if line else 0 

118 # TODO(slebedev): remove this unused field from the proto. 

119 trace.func_start_line = -1 

120 add_entry = True 

121 

122 if add_entry: 

123 logged_ops[entry.name] = entry 

124 

125 if add_trainable_var: 

126 for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES): 

127 if v.op.name not in logged_ops: 

128 entry = tfprof_log_pb2.OpLogEntry() 

129 entry.name = v.op.name 

130 entry.types.append(TRAINABLE_VARIABLES) 

131 logged_ops[entry.name] = entry 

132 else: 

133 logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES) 

134 

135 if op_missing_shape > 0 and not run_meta: 

136 sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' % 

137 op_missing_shape) 

138 return logged_ops, string_to_id 

139 

140 

def merge_default_with_oplog(graph, op_log=None, run_meta=None,
                             add_trace=True, add_trainable_var=True):
  """Combine tfprof's default per-op info with the caller's op_log.

  Entries from 'op_log' take precedence: default float_ops and code traces
  are only copied into an existing entry when that entry has none, while
  default types are always appended to it.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
      default graph.
    op_log: OpLogProto proto.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op
      type '_trainable_variables'.
  Returns:
    tmp_op_log: Merged OpLogProto proto. It is empty when no graph is
      available.
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()

  tmp_op_log = tfprof_log_pb2.OpLogProto()
  if not graph:
    return tmp_op_log

  logged_ops, string_to_id = _get_logged_ops(
      graph, run_meta, add_trace=add_trace,
      add_trainable_var=add_trainable_var)

  if op_log:
    # Start from the caller's entries, keyed by op name.
    merged = {entry.name: entry for entry in op_log.log_entries}
    for op_name, default_entry in logged_ops.items():
      if op_name not in merged:
        merged[op_name] = default_entry
        continue
      existing = merged[op_name]
      existing.types.extend(default_entry.types)
      if default_entry.float_ops > 0 and existing.float_ops == 0:
        existing.float_ops = default_entry.float_ops
      if default_entry.code_def.traces and not existing.code_def.traces:
        existing.code_def.MergeFrom(default_entry.code_def)
    tmp_op_log.log_entries.extend(merged.values())
  else:
    tmp_op_log.log_entries.extend(logged_ops.values())

  # Store the inverse of the string table so that ids can be resolved.
  for text, text_id in string_to_id.items():
    tmp_op_log.id_to_string[text_id] = text
  return tmp_op_log

186 

187 

@tf_export(v1=['profiler.write_op_log'])
def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
  """Write 'op_log', enriched with default model information, to 'log_dir'.

  Ops in tf.compat.v1.trainable_variables() are assigned an op type called
  '_trainable_variables'.
  'flops' statistics are logged for ops with op.RegisterStatistics() defined.
  The flops calculation depends on the Tensor shapes defined in 'graph',
  which might not be complete. 'run_meta', if provided, completes the shape
  information with best effort.

  The serialized log is written to a file named 'tfprof_log' in 'log_dir'.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
      default graph.
    log_dir: directory to write the log file.
    op_log: (Optional) OpLogProto proto to be written. If not provided, a new
      one is created.
    run_meta: (Optional) RunMetadata proto that helps flops computation using
      run time shape information.
    add_trace: Whether to add python code trace information.
      Used to support "code" view.
  """
  if not (graph or context.executing_eagerly()):
    graph = ops.get_default_graph()
  merged_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace)

  log_path = os.path.join(log_dir, 'tfprof_log')
  with gfile.Open(log_path, 'w') as log_file:
    log_file.write(merged_log.SerializeToString())