Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/profiler/internal/flops_registry.py: 45%
204 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Register flops statistics for various TensorFlow operations."""

import numpy as np

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops


# List of all ops which have implemented flops statistics.
IMPLEMENTED_OPS = set([
    # Unary ops
    "Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
    "L2Loss", "Softmax",
    # Binary ops
    "Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
    "GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
    "SquaredDifference", "AddV2",
    # Reduction ops
    "Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
    # Convolution and pooling
    "AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput",
    "Conv2DBackpropFilter",
    # Other ops
    "AddN", "MatMul",
    # Ops implemented in core tensorflow:
    "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D",
])


def _zero_flops(graph, node):
  """Returns zero flops."""
  del graph, node  # graph and node are unused
  return ops.OpStats("flops", 0)


def _list_product(lst):
  """Computes the product of the elements of a list."""
  result = 1
  for item in lst:
    result *= item
  return result


################################################################################
# Unary operations
################################################################################
def _unary_op_flops(graph, node, ops_per_element=1):
  """Common code which computes flops for unary operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  return ops.OpStats("flops", in_shape.num_elements() * ops_per_element)


@ops.RegisterStatistics("Reciprocal", "flops")
def _reciprocal_flops(graph, node):
  """Compute flops for Reciprocal operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Square", "flops")
def _square_flops(graph, node):
  """Compute flops for Square operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Rsqrt", "flops")
def _rsqrt_flops(graph, node):
  """Compute flops for Rsqrt operation."""
  # Rsqrt(x) = 1 / sqrt(x)
  return _unary_op_flops(graph, node, ops_per_element=2)


@ops.RegisterStatistics("Log", "flops")
def _log_flops(graph, node):
  """Compute flops for Log operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Neg", "flops")
def _neg_flops(graph, node):
  """Compute flops for Neg operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignSub", "flops")
def _assign_sub_flops(graph, node):
  """Compute flops for AssignSub operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignAdd", "flops")
def _assign_add_flops(graph, node):
  """Compute flops for AssignAdd operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("L2Loss", "flops")
def _l2_loss_flops(graph, node):
  """Compute flops for L2Loss operation."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  # TensorFlow uses an inefficient implementation with (3*N - 1) flops;
  # an optimal implementation needs only 2*N flops.
  return ops.OpStats("flops", in_shape.num_elements() * 3 - 1)
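# Worked example (illustrative shapes, chosen only for clarity): for an input
# with N = 1024 elements the statistic above is 3 * 1024 - 1 = 3071 flops,
# whereas an optimal sum-of-squares (N multiplies, N - 1 adds, one final
# scaling by 0.5) would need about 2 * 1024 flops.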
@ops.RegisterStatistics("Softmax", "flops")
def _softmax_flops(graph, node):
  """Compute flops for Softmax operation."""
  # Softmax implementation:
  #
  # Approximate flops breakdown:
  #   2*n -- compute shifted logits
  #   n   -- exp of shifted logits
  #   2*n -- compute softmax from exp of shifted logits
  return _unary_op_flops(graph, node, ops_per_element=5)
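# Worked example (illustrative shapes, chosen only for clarity): a Softmax over
# logits of shape [32, 1000] has n = 32000 elements, so the statistic above is
# 5 * 32000 = 160000 flops.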
################################################################################
# Binary operations
################################################################################
def _binary_per_element_op_flops(graph, node, ops_per_element=1):
  """Common code which computes flops for binary operations."""
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  return ops.OpStats("flops", out_shape.num_elements() * ops_per_element)


@ops.RegisterStatistics("Add", "flops")
@ops.RegisterStatistics("AddV2", "flops")
def _add_flops(graph, node):
  """Compute flops for Add operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Sub", "flops")
def _sub_flops(graph, node):
  """Compute flops for Sub operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Mul", "flops")
def _mul_flops(graph, node):
  """Compute flops for Mul operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RealDiv", "flops")
def _real_div_flops(graph, node):
  """Compute flops for RealDiv operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Maximum", "flops")
def _maximum_flops(graph, node):
  """Compute flops for Maximum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Minimum", "flops")
def _minimum_flops(graph, node):
  """Compute flops for Minimum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Pow", "flops")
def _pow_flops(graph, node):
  """Compute flops for Pow operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RsqrtGrad", "flops")
def _rsqrt_grad_flops(graph, node):
  """Compute flops for RsqrtGrad operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=4)


@ops.RegisterStatistics("GreaterEqual", "flops")
def _greater_equal_flops(graph, node):
  """Compute flops for GreaterEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Greater", "flops")
def _greater_flops(graph, node):
  """Compute flops for Greater operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("LessEqual", "flops")
def _less_equal_flops(graph, node):
  """Compute flops for LessEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Less", "flops")
def _less_flops(graph, node):
  """Compute flops for Less operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Equal", "flops")
def _equal_flops(graph, node):
  """Compute flops for Equal operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("NotEqual", "flops")
def _not_equal_flops(graph, node):
  """Compute flops for NotEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("SquaredDifference", "flops")
def _squared_difference_flops(graph, node):
  """Compute flops for SquaredDifference operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=2)


################################################################################
# Reduction ops
################################################################################
def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0):
  """Common code which computes flops for reduction operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  num_flops = (in_shape.num_elements() * reduce_flops
               + out_shape.num_elements() * (finalize_flops - reduce_flops))
  return ops.OpStats("flops", num_flops)
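# Worked example (illustrative shapes, chosen only for clarity): a Mean that
# reduces a [32, 100] tensor to shape [32] uses reduce_flops=1 and
# finalize_flops=1, giving 3200 * 1 + 32 * (1 - 1) = 3200 flops; a Sum over the
# same shapes (finalize_flops=0) gives 3200 - 32 = 3168 flops.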
@ops.RegisterStatistics("Mean", "flops")
def _mean_flops(graph, node):
  """Compute flops for Mean operation."""
  # reduction - sum, finalization - divide
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=1)


@ops.RegisterStatistics("Sum", "flops")
def _sum_flops(graph, node):
  """Compute flops for Sum operation."""
  # reduction - sum, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMax", "flops")
def _arg_max_flops(graph, node):
  """Compute flops for ArgMax operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMin", "flops")
def _arg_min_flops(graph, node):
  """Compute flops for ArgMin operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("BiasAddGrad", "flops")
def _bias_add_grad_flops(graph, node):
  """Compute flops for BiasAddGrad operation."""
  # BiasAddGrad is essentially a reduce sum plus a reshape, so its flops are
  # computed the same way as for "Sum".
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


################################################################################
# Convolution and pooling
# Note: all flops statistics are implemented only for NHWC data format
################################################################################
def _verify_conv_data_format(node):
  """Verifies data format for pooling and convolutional operations."""
  # TODO(xpan): P1: Support NCHW
  if node.attr["data_format"].s != b"NHWC":
    raise ValueError("Only NHWC format is supported in flops computations")


def _pool_flops(graph, node):
  """Common code which computes flops for pooling operations."""
  # Compute flops for average and max pooling.
  _verify_conv_data_format(node)
  #
  # Pooling declaration:
  #   Inputs:
  #     - value
  #   Outputs:
  #     - output
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  #
  # Pooling implementation:
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  return ops.OpStats("flops", kernel_area * out_shape.num_elements())
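# Worked example (illustrative shapes, chosen only for clarity): a MaxPool with
# ksize = [1, 2, 2, 1] (kernel_area = 4) producing an output of shape
# [32, 112, 112, 64] is reported as 4 * 32 * 112 * 112 * 64 = 102,760,448
# flops.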
@ops.RegisterStatistics("AvgPool", "flops")
def _avg_pool_flops(graph, node):
  """Compute flops for AvgPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("MaxPool", "flops")
def _max_pool_flops(graph, node):
  """Compute flops for MaxPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("AvgPoolGrad", "flops")
def _avg_pool_grad_flops(graph, node):
  """Compute flops for AvgPoolGrad operation."""
  _verify_conv_data_format(node)
  # Pooling gradient implementation:
  out_backprop_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                                  node.input[1])
  out_backprop_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  # TensorFlow multiplies each element of the pooling window by a coefficient
  # and then sums them up, which gives 2 flops per element. A more optimal
  # implementation would divide once after the summation.
  return ops.OpStats("flops",
                     kernel_area * out_backprop_shape.num_elements() * 2)


@ops.RegisterStatistics("MaxPoolGrad", "flops")
def _max_pool_grad_flops(graph, node):
  """Compute flops for MaxPoolGrad operation."""
  _verify_conv_data_format(node)
  #
  # MaxPoolGrad declaration:
  #   Inputs:
  #     - orig_input  -- original input tensor (of max_pool)
  #     - orig_output -- original output tensor (of max_pool)
  #     - grad        -- gradient with respect to output of max_pool
  #   Outputs:
  #     - output -- gradient with respect to input of max_pool
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  # It computes MaxPool first, then adds one flop per element of the original
  # output.
  #
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  orig_out_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                              node.input[1])
  orig_out_shape.assert_is_fully_defined()
  max_pool_ops = kernel_area * orig_out_shape.num_elements()
  return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements())
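# Worked example (illustrative shapes, chosen only for clarity): a MaxPoolGrad
# with ksize = [1, 2, 2, 1] (kernel_area = 4) and an original output of shape
# [32, 112, 112, 64] (25,690,112 elements) is reported as
# 4 * 25,690,112 + 25,690,112 = 128,450,560 flops.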
@ops.RegisterStatistics("Conv2DBackpropInput", "flops")
def _conv_2d_backprop_input_flops(graph, node):
  """Compute flops for Conv2DBackpropInput operation."""
  # Formula:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  # Where:
  #  image_x_dim, image_y_dim and input_depth --- size of input to source (no
  #    backprop) convolution, in other words they are sizes of backprop output.
  #  output_depth --- number of filters in the original convolution, thus
  #    depth of backprop input.
  #  kernel_x_dim and kernel_y_dim --- sizes of filter in spatial dimension
  #  image_x_stride and image_y_stride --- strides of the convolution
  #
  _verify_conv_data_format(node)
  # out_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                            node.input[1])
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * out_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (out_shape.dims[-1].value * strides_product)))
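# Worked example (illustrative shapes, chosen only for clarity): with
# out_shape = [32, 56, 56, 64], kernel_shape = [3, 3, 64, 128] and strides
# [1, 1, 1, 1], the statistic above is
# 2 * (32*56*56*64) * (3*3*64*128) / (64 * 1) = 14,797,504,512 flops,
# i.e. about 1.5e10.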
@ops.RegisterStatistics("Conv2DBackpropFilter", "flops")
def _conv_2d_backprop_filter_flops(graph, node):
  """Compute flops for Conv2DBackpropFilter operation."""
  # Formula same as for Conv2DBackpropInput:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  _verify_conv_data_format(node)
  # image_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  image_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  image_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * image_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (image_shape.dims[-1].value * strides_product)))


################################################################################
# Other ops
################################################################################
@ops.RegisterStatistics("AddN", "flops")
def _add_n_flops(graph, node):
  """Compute flops for AddN operation."""
  if not node.input:
    return _zero_flops(graph, node)
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  return ops.OpStats("flops", in_shape.num_elements() * (len(node.input) - 1))


@ops.RegisterStatistics("MatMul", "flops")
def _calc_mat_mul_flops(graph, node):
  """Calculates the compute resources needed for MatMul."""
  transpose_a = node.attr["transpose_a"].b
  a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  a_shape.assert_is_fully_defined()
  if transpose_a:
    k = int(a_shape[0])
  else:
    k = int(a_shape[1])
  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  output_shape.assert_is_fully_defined()
  output_count = np.prod(output_shape.as_list())
  return ops.OpStats("flops", (k * output_count * 2))
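# Worked example (illustrative shapes, chosen only for clarity): a MatMul of a
# [128, 512] matrix with a [512, 256] matrix (transpose_a=False, so k = 512)
# produces 128 * 256 = 32,768 output elements, giving
# 512 * 32,768 * 2 = 33,554,432 flops.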
@ops.RegisterStatistics("BatchMatMul", "flops")
@ops.RegisterStatistics("BatchMatMulV2", "flops")
@ops.RegisterStatistics("BatchMatMulV3", "flops")
def _calc_batch_mat_mul_flops(graph, node):
  """Calculates the compute resources needed for BatchMatMul."""
  transpose_a = node.attr["transpose_a"].b
  a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  a_shape.assert_is_fully_defined()
  if transpose_a:
    k = int(a_shape[-2])
  else:
    k = int(a_shape[-1])
  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  output_shape.assert_is_fully_defined()
  output_count = np.prod(output_shape.as_list())
  return ops.OpStats("flops", (k * output_count * 2))
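# Usage sketch (hypothetical snippet; assumes the TF1-style profiler API, which
# consumes the statistics registered above). Requesting float operations for a
# graph looks roughly like this:
#
#   import numpy as np
#   import tensorflow as tf
#
#   g = tf.Graph()
#   with g.as_default():
#     a = tf.constant(np.random.rand(128, 512), dtype=tf.float32)
#     b = tf.constant(np.random.rand(512, 256), dtype=tf.float32)
#     tf.linalg.matmul(a, b)
#   opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
#   stats = tf.compat.v1.profiler.profile(graph=g, options=opts)
#   print(stats.total_float_ops)  # 33,554,432 for the MatMul above.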