# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Register flops statistics for various TensorFlow operations."""
import numpy as np

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops

# List of all ops which have implemented flops statistics.
IMPLEMENTED_OPS = set([
    # Unary ops
    "Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
    "L2Loss", "Softmax",
    # Binary ops
    "Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
    "GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
    "SquaredDifference", "AddV2",
    # Reduction ops
    "Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
    # Convolution and pooling
    "AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput",
    "Conv2DBackpropFilter",
    # Other ops
    "AddN", "MatMul",
    # Ops implemented in core tensorflow:
    "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D",
])
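
# Note: ops absent from this set simply have no "flops" statistic registered;
# the profiler then reports no flops for them rather than raising an error.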


def _zero_flops(graph, node):
  """Returns zero flops."""
  del graph, node  # graph and node are unused
  return ops.OpStats("flops", 0)


def _list_product(lst):
  """Computes the product of the elements of a list."""
  result = 1
  for item in lst:
    result *= item
  return result
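
# Illustrative sketch (hypothetical values): an NHWC 3x3 pooling window has
# ksize [1, 3, 3, 1], so _list_product([1, 3, 3, 1]) == 9 -- the number of
# elements feeding each pooled output.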

################################################################################
# Unary operations
################################################################################


def _unary_op_flops(graph, node, ops_per_element=1):
  """Common code which computes flops for unary operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  return ops.OpStats("flops", in_shape.num_elements() * ops_per_element)


@ops.RegisterStatistics("Reciprocal", "flops")
def _reciprocal_flops(graph, node):
  """Compute flops for Reciprocal operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Square", "flops")
def _square_flops(graph, node):
  """Compute flops for Square operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Rsqrt", "flops")
def _rsqrt_flops(graph, node):
  """Compute flops for Rsqrt operation."""
  # Rsqrt(x) = 1 / sqrt(x)
  return _unary_op_flops(graph, node, ops_per_element=2)


@ops.RegisterStatistics("Log", "flops")
def _log_flops(graph, node):
  """Compute flops for Log operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("Neg", "flops")
def _neg_flops(graph, node):
  """Compute flops for Neg operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignSub", "flops")
def _assign_sub_flops(graph, node):
  """Compute flops for AssignSub operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("AssignAdd", "flops")
def _assign_add_flops(graph, node):
  """Compute flops for AssignAdd operation."""
  return _unary_op_flops(graph, node)


@ops.RegisterStatistics("L2Loss", "flops")
def _l2_loss_flops(graph, node):
  """Compute flops for L2Loss operation."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  # TensorFlow uses an inefficient implementation with (3*N - 1) flops;
  # an optimal implementation needs only 2*N flops.
  return ops.OpStats("flops", in_shape.num_elements() * 3 - 1)
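
# Worked example (assumed shape): an input of shape [32, 128] has N = 4096,
# so the reported estimate is 3 * 4096 - 1 = 12287 flops.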


@ops.RegisterStatistics("Softmax", "flops")
def _softmax_flops(graph, node):
  """Compute flops for Softmax operation."""
  # Softmax implementation:
  #
  # Approximate flops breakdown:
  #   2*n -- compute shifted logits
  #   n   -- exp of shifted logits
  #   2*n -- compute softmax from exp of shifted logits
  return _unary_op_flops(graph, node, ops_per_element=5)
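
# Worked example (assumed shape): logits of shape [32, 10] give
# 5 * 320 = 1600 flops under this breakdown.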

################################################################################
# Binary operations
################################################################################


def _binary_per_element_op_flops(graph, node, ops_per_element=1):
  """Common code which computes flops for binary operations."""
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  return ops.OpStats("flops", out_shape.num_elements() * ops_per_element)
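
# The output shape is used (rather than an input shape) so broadcasting is
# accounted for. Worked example (assumed shapes): Add on [32, 128] + [128]
# broadcasts to an output of [32, 128], giving 4096 * 1 = 4096 flops.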


@ops.RegisterStatistics("Add", "flops")
@ops.RegisterStatistics("AddV2", "flops")
def _add_flops(graph, node):
  """Compute flops for Add operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Sub", "flops")
def _sub_flops(graph, node):
  """Compute flops for Sub operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Mul", "flops")
def _mul_flops(graph, node):
  """Compute flops for Mul operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RealDiv", "flops")
def _real_div_flops(graph, node):
  """Compute flops for RealDiv operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Maximum", "flops")
def _maximum_flops(graph, node):
  """Compute flops for Maximum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Minimum", "flops")
def _minimum_flops(graph, node):
  """Compute flops for Minimum operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Pow", "flops")
def _pow_flops(graph, node):
  """Compute flops for Pow operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("RsqrtGrad", "flops")
def _rsqrt_grad_flops(graph, node):
  """Compute flops for RsqrtGrad operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=4)


@ops.RegisterStatistics("GreaterEqual", "flops")
def _greater_equal_flops(graph, node):
  """Compute flops for GreaterEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Greater", "flops")
def _greater_flops(graph, node):
  """Compute flops for Greater operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("LessEqual", "flops")
def _less_equal_flops(graph, node):
  """Compute flops for LessEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Less", "flops")
def _less_flops(graph, node):
  """Compute flops for Less operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("Equal", "flops")
def _equal_flops(graph, node):
  """Compute flops for Equal operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("NotEqual", "flops")
def _not_equal_flops(graph, node):
  """Compute flops for NotEqual operation."""
  return _binary_per_element_op_flops(graph, node)


@ops.RegisterStatistics("SquaredDifference", "flops")
def _squared_difference_flops(graph, node):
  """Compute flops for SquaredDifference operation."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=2)

################################################################################
# Reduction ops
################################################################################


def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0):
  """Common code which computes flops for reduction operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  num_flops = (in_shape.num_elements() * reduce_flops
               + out_shape.num_elements() * (finalize_flops - reduce_flops))
  return ops.OpStats("flops", num_flops)
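
# Worked example (assumed shapes): Mean over axis 1 of a [32, 128] tensor has
# in = 4096 and out = 32 elements, so with reduce_flops=1, finalize_flops=1
# the estimate is 4096 * 1 + 32 * (1 - 1) = 4096 flops; the same Sum
# (finalize_flops=0) yields 4096 - 32 = 4064 flops, i.e. 127 adds per output.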


@ops.RegisterStatistics("Mean", "flops")
def _mean_flops(graph, node):
  """Compute flops for Mean operation."""
  # reduction - sum, finalization - divide
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=1)


@ops.RegisterStatistics("Sum", "flops")
def _sum_flops(graph, node):
  """Compute flops for Sum operation."""
  # reduction - sum, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMax", "flops")
def _arg_max_flops(graph, node):
  """Compute flops for ArgMax operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("ArgMin", "flops")
def _arg_min_flops(graph, node):
  """Compute flops for ArgMin operation."""
  # reduction - comparison, no finalization
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)


@ops.RegisterStatistics("BiasAddGrad", "flops")
def _bias_add_grad_flops(graph, node):
  """Compute flops for BiasAddGrad operation."""
  # BiasAddGrad is essentially a reduce-sum plus reshaping, so its flops are
  # computed the same way as for "Sum".
  return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)

################################################################################
# Convolution and pooling
# Note: all flops statistics are implemented only for NHWC data format
################################################################################


def _verify_conv_data_format(node):
  """Verifies data format for pooling and convolutional operations."""
  # TODO(xpan): P1: Support NCHW
  if node.attr["data_format"].s != b"NHWC":
    raise ValueError("Only NHWC format is supported in flops computations")


def _pool_flops(graph, node):
  """Common code which computes flops for pooling operations."""
  # Compute flops for average and max pooling.
  _verify_conv_data_format(node)
  #
  # Pooling declaration:
  #   Inputs:
  #     - value
  #   Outputs:
  #     - output
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  #
  # Pooling implementation:
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  return ops.OpStats("flops", kernel_area * out_shape.num_elements())
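
# Worked example (assumed shapes): MaxPool with ksize [1, 3, 3, 1] and an
# output of shape [1, 112, 112, 64] costs 9 * 802816 = 7225344 flops -- one
# comparison per kernel element per output element.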


@ops.RegisterStatistics("AvgPool", "flops")
def _avg_pool_flops(graph, node):
  """Compute flops for AvgPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("MaxPool", "flops")
def _max_pool_flops(graph, node):
  """Compute flops for MaxPool operation."""
  return _pool_flops(graph, node)


@ops.RegisterStatistics("AvgPoolGrad", "flops")
def _avg_pool_grad_flops(graph, node):
  """Compute flops for AvgPoolGrad operation."""
  _verify_conv_data_format(node)
  # Pooling gradient implementation:
  out_backprop_shape = graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  out_backprop_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  # TensorFlow multiplies each element of the pooling window by a coefficient
  # and then sums them up, so we have 2 flops per element; a more optimal
  # implementation would do the division once, after the summation.
  return ops.OpStats("flops",
                     kernel_area * out_backprop_shape.num_elements() * 2)


@ops.RegisterStatistics("MaxPoolGrad", "flops")
def _max_pool_grad_flops(graph, node):
  """Compute flops for MaxPoolGrad operation."""
  _verify_conv_data_format(node)
  #
  # MaxPoolGrad declaration:
  #   Inputs:
  #     - orig_input  -- original input tensor (of max_pool)
  #     - orig_output -- original output tensor (of max_pool)
  #     - grad        -- gradient with respect to output of max_pool
  #   Outputs:
  #     - output      -- gradient with respect to input of max_pool
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  # It computes MaxPool first, then one flop per element of the original
  # output.
  #
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  orig_out_shape = graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  orig_out_shape.assert_is_fully_defined()
  max_pool_ops = kernel_area * orig_out_shape.num_elements()
  return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements())
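
# Worked example (assumed shapes): ksize [1, 2, 2, 1] with an original output
# of shape [1, 56, 56, 64] (200704 elements) gives
# 4 * 200704 + 200704 = 1003520 flops.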


@ops.RegisterStatistics("Conv2DBackpropInput", "flops")
def _conv_2d_backprop_input_flops(graph, node):
  """Compute flops for Conv2DBackpropInput operation."""
  # Formula:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  # Where:
  #  image_x_dim, image_y_dim and input_depth --- size of input to source
  #    (no backprop) convolution, in other words they are sizes of backprop
  #    output.
  #  output_depth --- number of filters in the original convolution, thus
  #    depth of backprop input.
  #  kernel_x_dim and kernel_y_dim --- sizes of filter in spatial dimension
  #  image_x_stride and image_y_stride --- strides of the convolution
  #
  _verify_conv_data_format(node)
  # out_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * out_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (out_shape.dims[-1].value * strides_product)))
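
# Worked example (assumed shapes): out_shape [1, 28, 28, 64], kernel
# [3, 3, 64, 128], unit strides: 2 * 50176 * 73728 / (64 * 1) = 115605504
# flops, which matches 1 * 28 * 28 * 3 * 3 * 64 * 128 * 2 from the formula.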


@ops.RegisterStatistics("Conv2DBackpropFilter", "flops")
def _conv_2d_backprop_filter_flops(graph, node):
  """Compute flops for Conv2DBackpropFilter operation."""
  # Formula same as for Conv2DBackpropInput:
  #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
  #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
  #
  _verify_conv_data_format(node)
  # image_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  image_shape = graph_util.tensor_shape_from_node_def_name(
      graph, node.input[0])
  image_shape.assert_is_fully_defined()
  # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  kernel_shape.assert_is_fully_defined()
  # strides
  strides_shape = list(node.attr["strides"].list.i)
  strides_product = strides_shape[1] * strides_shape[2]
  return ops.OpStats("flops",
                     (2 * image_shape.num_elements()
                      * kernel_shape.num_elements()
                      / (image_shape.dims[-1].value * strides_product)))

################################################################################
# Other ops
################################################################################


@ops.RegisterStatistics("AddN", "flops")
def _add_n_flops(graph, node):
  """Compute flops for AddN operation."""
  if not node.input:
    return _zero_flops(graph, node)
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  return ops.OpStats("flops", in_shape.num_elements() * (len(node.input) - 1))

@ops.RegisterStatistics("MatMul", "flops")
def _calc_mat_mul_flops(graph, node):
  """Calculates the compute resources needed for MatMul."""
  transpose_a = node.attr["transpose_a"].b
  a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  a_shape.assert_is_fully_defined()
  if transpose_a:
    k = int(a_shape[0])
  else:
    k = int(a_shape[1])
  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  output_shape.assert_is_fully_defined()
  output_count = np.prod(output_shape.as_list())
  return ops.OpStats("flops", (k * output_count * 2))
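
# Worked example (assumed shapes): a [32, 128] times [128, 64] MatMul has
# k = 128 and 32 * 64 = 2048 output elements, so the estimate is
# 128 * 2048 * 2 = 524288 flops (one multiply and one add per inner step).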


@ops.RegisterStatistics("BatchMatMul", "flops")
@ops.RegisterStatistics("BatchMatMulV2", "flops")
@ops.RegisterStatistics("BatchMatMulV3", "flops")
def _calc_batch_mat_mul_flops(graph, node):
  """Calculates the compute resources needed for BatchMatMul."""
  transpose_a = node.attr["transpose_a"].b
  a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  a_shape.assert_is_fully_defined()
  if transpose_a:
    k = int(a_shape[-2])
  else:
    k = int(a_shape[-1])
  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  output_shape.assert_is_fully_defined()
  output_count = np.prod(output_shape.as_list())
  return ops.OpStats("flops", (k * output_count * 2))
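
# A minimal usage sketch (assumptions: TF1-style graph mode and the public
# profiler API, which consumes the statistics registered above):
#
#   import tensorflow as tf
#
#   g = tf.Graph()
#   with g.as_default():
#     a = tf.ones([32, 128])
#     b = tf.ones([128, 64])
#     c = tf.matmul(a, b)
#   opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
#   stats = tf.compat.v1.profiler.profile(g, options=opts)
#   # stats.total_float_ops should include the MatMul's 524288 flops.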