Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/platform/benchmark.py: 23% (215 statements)
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000


# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities to run benchmarks."""
import math
import numbers
import os
import re
import sys
import time
import types

from absl import app

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.util import test_log_pb2
from tensorflow.python.client import timeline
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


# When a subclass of the Benchmark class is created, it is added to
# the registry automatically.
GLOBAL_BENCHMARK_REGISTRY = set()

# Environment variable that determines whether benchmarks are written.
# See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv.
TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX"

# Environment variable that lets the TensorFlow runtime allocate a new
# threadpool for each benchmark.
OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL"

def _rename_function(f, arg_num, name):
  """Rename the given function so that `name` appears in the stack trace."""
  func_code = f.__code__
  new_code = func_code.replace(co_argcount=arg_num, co_name=name)
  return types.FunctionType(new_code, f.__globals__, name, f.__defaults__,
                            f.__closure__)
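# Illustrative sketch (not part of the original module): `_rename_function`
# rewrites the code object's `co_name`, so a generated benchmark shows up
# under its parameterized name in stack traces, which is what
# `Benchmark._get_name` relies on. The binding `renamed` below is
# hypothetical and the snippet is kept commented out so it never runs:
#
#   renamed = _rename_function(lambda self: None, 1, "benchmark_op__case_1")
#   assert renamed.__code__.co_name == "benchmark_op__case_1"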

def _global_report_benchmark(
    name, iters=None, cpu_time=None, wall_time=None,
    throughput=None, extras=None, metrics=None):
  """Method for recording a benchmark directly.

  Args:
    name: The BenchmarkEntry name.
    iters: (optional) How many iterations were run.
    cpu_time: (optional) Total cpu time in seconds.
    wall_time: (optional) Total wall time in seconds.
    throughput: (optional) Throughput (in MB/s).
    extras: (optional) Dict mapping string keys to additional benchmark info.
    metrics: (optional) A list of dicts representing metrics generated by the
      benchmark. Each dict should contain keys 'name' and 'value'. A dict
      can optionally contain keys 'min_value' and 'max_value'.

  Raises:
    TypeError: if extras is not a dict.
    IOError: if the benchmark output file already exists.
  """
  logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g,"
               "throughput: %g, extras: %s, metrics: %s", name,
               iters if iters is not None else -1,
               wall_time if wall_time is not None else -1,
               cpu_time if cpu_time is not None else -1,
               throughput if throughput is not None else -1,
               str(extras) if extras else "None",
               str(metrics) if metrics else "None")

  entries = test_log_pb2.BenchmarkEntries()
  entry = entries.entry.add()
  entry.name = name
  if iters is not None:
    entry.iters = iters
  if cpu_time is not None:
    entry.cpu_time = cpu_time
  if wall_time is not None:
    entry.wall_time = wall_time
  if throughput is not None:
    entry.throughput = throughput
  if extras is not None:
    if not isinstance(extras, dict):
      raise TypeError("extras must be a dict")
    for (k, v) in extras.items():
      if isinstance(v, numbers.Number):
        entry.extras[k].double_value = v
      else:
        entry.extras[k].string_value = str(v)
  if metrics is not None:
    if not isinstance(metrics, list):
      raise TypeError("metrics must be a list")
    for metric in metrics:
      if "name" not in metric:
        raise TypeError("metric must have a 'name' field")
      if "value" not in metric:
        raise TypeError("metric must have a 'value' field")

      metric_entry = entry.metrics.add()
      metric_entry.name = metric["name"]
      metric_entry.value = metric["value"]
      if "min_value" in metric:
        metric_entry.min_value.value = metric["min_value"]
      if "max_value" in metric:
        metric_entry.max_value.value = metric["max_value"]

  test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
  if test_env is None:
    # Reporting was not requested, just print the proto.
    print(str(entries))
    return

  serialized_entry = entries.SerializeToString()

  mangled_name = name.replace("/", "__")
  output_path = "%s%s" % (test_env, mangled_name)
  if gfile.Exists(output_path):
    raise IOError("File already exists: %s" % output_path)
  with gfile.GFile(output_path, "wb") as out:
    out.write(serialized_entry)
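# Illustrative sketch (not part of the original module): with the reporter
# environment variable set, each entry is serialized to a file whose name is
# the prefix followed by the mangled benchmark name. The values below are
# hypothetical and the snippet is kept commented out:
#
#   os.environ["TEST_REPORT_FILE_PREFIX"] = "/tmp/benchmarks/"
#   _global_report_benchmark("MyBench.benchmark_matmul/gpu", iters=10,
#                            wall_time=0.5)
#   # Writes the serialized BenchmarkEntries proto to
#   # /tmp/benchmarks/MyBench.benchmark_matmul__gpu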

class _BenchmarkRegistrar(type):
  """The Benchmark class registrar. Used by the abstract Benchmark class."""

  def __new__(mcs, clsname, base, attrs):
    newclass = type.__new__(mcs, clsname, base, attrs)
    if not newclass.is_abstract():
      GLOBAL_BENCHMARK_REGISTRY.add(newclass)
    return newclass

@tf_export("__internal__.test.ParameterizedBenchmark", v1=[])
class ParameterizedBenchmark(_BenchmarkRegistrar):
  """Metaclass to generate parameterized benchmarks.

  Use this class as a metaclass and override `_benchmark_parameters` to
  generate multiple benchmark test cases. For example:

  class FooBenchmark(tf.test.Benchmark,
                     metaclass=tf.test.ParameterizedBenchmark):
    # `_benchmark_parameters` is expected to be a list of test cases.
    # Each test case is a tuple whose first item is the test case name,
    # followed by any number of parameters needed for the test case.
    _benchmark_parameters = [
        ('case_1', Foo, 1, 'one'),
        ('case_2', Bar, 2, 'two'),
    ]

    def benchmark_test(self, target_class, int_param, string_param):
      # benchmark test body

  The example above will generate two benchmark test cases:
  "benchmark_test__case_1" and "benchmark_test__case_2".
  """

  def __new__(mcs, clsname, base, attrs):
    param_config_list = attrs["_benchmark_parameters"]

    def create_benchmark_function(original_benchmark, params):
      return lambda self: original_benchmark(self, *params)

    for name in attrs.copy().keys():
      if not name.startswith("benchmark"):
        continue

      original_benchmark = attrs[name]
      del attrs[name]

      for param_config in param_config_list:
        test_name_suffix = param_config[0]
        params = param_config[1:]
        benchmark_name = name + "__" + test_name_suffix
        if benchmark_name in attrs:
          raise Exception(
              "Benchmark named {} already defined.".format(benchmark_name))

        benchmark = create_benchmark_function(original_benchmark, params)
        # Renaming is important because the `report_benchmark` function looks
        # up the function name in the stack trace.
        attrs[benchmark_name] = _rename_function(benchmark, 1, benchmark_name)

    return super().__new__(mcs, clsname, base, attrs)

class Benchmark(metaclass=_BenchmarkRegistrar):
  """Abstract class that provides helper functions for running benchmarks.

  Any class subclassing this one is immediately registered in the global
  benchmark registry.

  Only methods whose names start with the word "benchmark" will be run during
  benchmarking.
  """

  @classmethod
  def is_abstract(cls):
    # mro: (_BenchmarkRegistrar, Benchmark) means this is Benchmark
    return len(cls.mro()) <= 2

  def _get_name(self, overwrite_name=None):
    """Returns full name of class and method calling report_benchmark."""

    # Find the caller method (outermost Benchmark class).
    stack = tf_inspect.stack()
    calling_class = None
    name = None
    for frame in stack[::-1]:
      f_locals = frame[0].f_locals
      f_self = f_locals.get("self", None)
      if isinstance(f_self, Benchmark):
        calling_class = f_self  # Get the outermost stack Benchmark call.
        name = frame[3]  # Get the method name.
        break
    if calling_class is None:
      raise ValueError("Unable to determine calling Benchmark class.")

    # Use the method name, or overwrite_name if provided.
    name = overwrite_name or name
    # Prefix the name with the class name.
    class_name = type(calling_class).__name__
    name = "%s.%s" % (class_name, name)
    return name

  def report_benchmark(
      self,
      iters=None,
      cpu_time=None,
      wall_time=None,
      throughput=None,
      extras=None,
      name=None,
      metrics=None):
    """Report a benchmark.

    Args:
      iters: (optional) How many iterations were run.
      cpu_time: (optional) Median or mean cpu time in seconds.
      wall_time: (optional) Median or mean wall time in seconds.
      throughput: (optional) Throughput (in MB/s).
      extras: (optional) Dict mapping string keys to additional benchmark info.
        Values may be either floats or values that are convertible to strings.
      name: (optional) Override the BenchmarkEntry name with `name`.
        Otherwise it is inferred from the top-level method name.
      metrics: (optional) A list of dicts, where each dict has the keys below:
        name (required), string, metric name
        value (required), double, metric value
        min_value (optional), double, minimum acceptable metric value
        max_value (optional), double, maximum acceptable metric value
    """
    name = self._get_name(overwrite_name=name)
    _global_report_benchmark(
        name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
        throughput=throughput, extras=extras, metrics=metrics)
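# Illustrative sketch (not part of the original module): a subclass method
# whose name starts with "benchmark" reports its results via
# `report_benchmark`, and the entry name is inferred from the class and
# method names (here it would be "SampleBenchmark.benchmark_sum"). The class
# below is hypothetical and kept commented out so it is never registered or
# run:
#
#   class SampleBenchmark(Benchmark):
#
#     def benchmark_sum(self):
#       start = time.time()
#       total = sum(range(1000000))
#       self.report_benchmark(
#           iters=1,
#           wall_time=time.time() - start,
#           metrics=[{"name": "result", "value": total}])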

@tf_export("test.benchmark_config")
def benchmark_config():
  """Returns a tf.compat.v1.ConfigProto for disabling the dependency optimizer.

  Returns:
    A TensorFlow ConfigProto object.
  """
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.dependency_optimization = (
      rewriter_config_pb2.RewriterConfig.OFF)
  return config
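# Illustrative sketch (not part of the original module): the returned proto
# is typically passed to a tf.compat.v1.Session so that the dependency
# optimizer does not rewrite away the ops being timed. Kept commented out;
# the session usage below is only an example:
#
#   config = benchmark_config()
#   assert (config.graph_options.rewrite_options.dependency_optimization ==
#           rewriter_config_pb2.RewriterConfig.OFF)
#   # with tf.compat.v1.Session(config=config) as sess:
#   #   ...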

@tf_export("test.Benchmark")
class TensorFlowBenchmark(Benchmark):
  """Abstract class that provides helpers for TensorFlow benchmarks."""

  def __init__(self):
    # Allow the TensorFlow runtime to allocate a new threadpool with a
    # different number of threads for each new benchmark.
    os.environ[OVERRIDE_GLOBAL_THREADPOOL] = "1"
    super().__init__()

  @classmethod
  def is_abstract(cls):
    # mro: (_BenchmarkRegistrar, Benchmark, TensorFlowBenchmark) means
    # this is TensorFlowBenchmark.
    return len(cls.mro()) <= 3

  def run_op_benchmark(self,
                       sess,
                       op_or_tensor,
                       feed_dict=None,
                       burn_iters=2,
                       min_iters=10,
                       store_trace=False,
                       store_memory_usage=True,
                       name=None,
                       extras=None,
                       mbs=0):
    """Run an op or tensor in the given session. Report the results.

    Args:
      sess: `Session` object to use for timing.
      op_or_tensor: `Operation` or `Tensor` to benchmark.
      feed_dict: A `dict` of values to feed for each op iteration (see the
        `feed_dict` parameter of `Session.run`).
      burn_iters: Number of burn-in iterations to run.
      min_iters: Minimum number of iterations to use for timing.
      store_trace: Boolean, whether to run an extra untimed iteration and
        store the trace of that iteration in the returned extras.
        The trace will be stored as a string in Google Chrome trace format
        in the extras field "full_trace_chrome_format". Note that the trace
        will not be stored in the test_log_pb2.TestResults proto.
      store_memory_usage: Boolean, whether to run an extra untimed iteration,
        calculate memory usage, and store that in extras fields.
      name: (optional) Override the BenchmarkEntry name with `name`.
        Otherwise it is inferred from the top-level method name.
      extras: (optional) Dict mapping string keys to additional benchmark info.
        Values may be either floats or values that are convertible to strings.
      mbs: (optional) The number of megabytes moved by this op, used to
        calculate the ops throughput.

    Returns:
      A `dict` containing the key-value pairs that were passed to
      `report_benchmark`. If the `store_trace` option is used, then
      `full_trace_chrome_format` will be included in the returned dictionary
      even though it is not passed to `report_benchmark` with `extras`.
    """
    for _ in range(burn_iters):
      sess.run(op_or_tensor, feed_dict=feed_dict)

    deltas = [None] * min_iters

    for i in range(min_iters):
      start_time = time.time()
      sess.run(op_or_tensor, feed_dict=feed_dict)
      end_time = time.time()
      delta = end_time - start_time
      deltas[i] = delta

    extras = extras if extras is not None else {}
    unreported_extras = {}
    if store_trace or store_memory_usage:
      run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      sess.run(op_or_tensor, feed_dict=feed_dict,
               options=run_options, run_metadata=run_metadata)
      tl = timeline.Timeline(run_metadata.step_stats)

      if store_trace:
        unreported_extras["full_trace_chrome_format"] = (
            tl.generate_chrome_trace_format())

      if store_memory_usage:
        step_stats_analysis = tl.analyze_step_stats(show_memory=True)
        allocator_maximums = step_stats_analysis.allocator_maximums
        for k, v in allocator_maximums.items():
          extras["allocator_maximum_num_bytes_%s" % k] = v.num_bytes

    def _median(x):
      if not x:
        return -1
      s = sorted(x)
      l = len(x)
      lm1 = l - 1
      return (s[l // 2] + s[lm1 // 2]) / 2.0

    def _mean_and_stdev(x):
      if not x:
        return -1, -1
      l = len(x)
      mean = sum(x) / l
      if l == 1:
        return mean, -1
      variance = sum([(e - mean) * (e - mean) for e in x]) / (l - 1)
      return mean, math.sqrt(variance)

    median_delta = _median(deltas)

    benchmark_values = {
        "iters": min_iters,
        "wall_time": median_delta,
        "extras": extras,
        "name": name,
        "throughput": mbs / median_delta,
    }
    self.report_benchmark(**benchmark_values)

    mean_delta, stdev_delta = _mean_and_stdev(deltas)
    unreported_extras["wall_time_mean"] = mean_delta
    unreported_extras["wall_time_stdev"] = stdev_delta
    benchmark_values["extras"].update(unreported_extras)
    return benchmark_values

  def evaluate(self, tensors):
    """Evaluates tensors and returns numpy values.

    Args:
      tensors: A Tensor or a nested list/tuple of Tensors.

    Returns:
      tensors numpy values.
    """
    sess = ops.get_default_session() or self.cached_session()
    return sess.run(tensors)
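# Illustrative sketch (not part of the original module): a TensorFlowBenchmark
# subclass typically builds a graph, opens a session with `benchmark_config()`,
# and hands the op to `run_op_benchmark`, which times it and calls
# `report_benchmark`. The class below is hypothetical and kept commented out
# so it is never registered or run:
#
#   # (assumes the enclosing test does `import tensorflow as tf`)
#   class SampleMatmulBenchmark(TensorFlowBenchmark):
#
#     def benchmark_matmul(self):
#       with tf.Graph().as_default():
#         a = tf.random.normal([1024, 1024])
#         product = tf.linalg.matmul(a, a)
#         with tf.compat.v1.Session(config=benchmark_config()) as sess:
#           self.run_op_benchmark(sess, product, min_iters=25)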

def _run_benchmarks(regex):
  """Run benchmarks that match regex `regex`.

  This function goes through the global benchmark registry, and matches
  benchmark class and method names of the form
  `module.name.BenchmarkClass.benchmarkMethod` to the given regex.
  If a method matches, it is run.

  Args:
    regex: The string regular expression to match Benchmark classes against.

  Raises:
    ValueError: If no benchmarks were selected by the input regex.
  """
  registry = list(GLOBAL_BENCHMARK_REGISTRY)

  selected_benchmarks = []
  # Match benchmarks in registry against regex.
  for benchmark in registry:
    benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__)
    attrs = dir(benchmark)
    # Don't instantiate the benchmark class unless necessary.
    benchmark_instance = None

    for attr in attrs:
      if not attr.startswith("benchmark"):
        continue
      candidate_benchmark_fn = getattr(benchmark, attr)
      if not callable(candidate_benchmark_fn):
        continue
      full_benchmark_name = "%s.%s" % (benchmark_name, attr)
      if regex == "all" or re.search(regex, full_benchmark_name):
        selected_benchmarks.append(full_benchmark_name)
        # Instantiate the class if it hasn't been instantiated.
        benchmark_instance = benchmark_instance or benchmark()
        # Get the method tied to the class.
        instance_benchmark_fn = getattr(benchmark_instance, attr)
        # Call the instance method.
        instance_benchmark_fn()

  if not selected_benchmarks:
    raise ValueError("No benchmarks matched the pattern: '{}'".format(regex))


def benchmarks_main(true_main, argv=None):
  """Run benchmarks as declared in argv.

  Args:
    true_main: True main function to run if benchmarks are not requested.
    argv: the command line arguments (if None, uses sys.argv).
  """
  if argv is None:
    argv = sys.argv
  found_arg = [
      arg
      for arg in argv
      if arg.startswith("--benchmark_filter=")
      or arg.startswith("-benchmark_filter=")
  ]
  if found_arg:
    # Remove the --benchmark_filter arg from sys.argv.
    argv.remove(found_arg[0])

    regex = found_arg[0].split("=")[1]
    app.run(lambda _: _run_benchmarks(regex), argv=argv)
  else:
    true_main()
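# Illustrative sketch (not part of the original module): a test binary that
# routes its main through `benchmarks_main` runs benchmarks only when a
# filter flag is passed on the command line, e.g.
#
#   python my_benchmark_test.py --benchmark_filter=SampleMatmulBenchmark.*
#
# The file name and `main` function below are hypothetical and the snippet is
# kept commented out:
#
#   def main():
#     ...  # regular test main
#
#   if __name__ == "__main__":
#     benchmarks_main(true_main=main)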