Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/platform/benchmark.py: 23%
215 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16"""Utilities to run benchmarks."""
17import math
18import numbers
19import os
20import re
21import sys
22import time
23import types
25from absl import app
27from tensorflow.core.protobuf import config_pb2
28from tensorflow.core.protobuf import rewriter_config_pb2
29from tensorflow.core.util import test_log_pb2
30from tensorflow.python.client import timeline
31from tensorflow.python.framework import ops
32from tensorflow.python.platform import gfile
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.util import tf_inspect
35from tensorflow.python.util.tf_export import tf_export

# When a subclass of the Benchmark class is created, it is added to
# the registry automatically.
GLOBAL_BENCHMARK_REGISTRY = set()

# Environment variable that determines whether benchmarks are written.
# See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv.
TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX"

# Environment variable that lets the TensorFlow runtime allocate a new
# threadpool for each benchmark.
OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL"
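
# Illustrative note (not part of the original module): when the
# TEST_REPORT_FILE_PREFIX environment variable is set, each reported benchmark
# entry is serialized below that prefix; slashes in the benchmark name are
# mangled to "__" by _global_report_benchmark below. A hypothetical setup:
#
#   os.environ["TEST_REPORT_FILE_PREFIX"] = "/tmp/benchmarks/"
#   # An entry named "MyBenchmark.benchmark_add" is then written to
#   # "/tmp/benchmarks/MyBenchmark.benchmark_add".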


def _rename_function(f, arg_num, name):
  """Rename the given function so that `name` appears in the stack trace."""
  func_code = f.__code__
  new_code = func_code.replace(co_argcount=arg_num, co_name=name)
  return types.FunctionType(new_code, f.__globals__, name, f.__defaults__,
                            f.__closure__)
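
# Illustrative sketch (assumption: Python 3.8+, where CodeType.replace is
# available, as used above): _rename_function rebuilds a function around a
# copy of its code object with a new co_name, so the new name is what shows
# up in stack traces and in the stack-based lookup done by report_benchmark.
#
#   def _example(self):  # hypothetical one-argument function
#     pass
#
#   renamed = _rename_function(_example, 1, "benchmark_add__case_1")
#   assert renamed.__name__ == "benchmark_add__case_1"
#   assert renamed.__code__.co_name == "benchmark_add__case_1"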


def _global_report_benchmark(
    name, iters=None, cpu_time=None, wall_time=None,
    throughput=None, extras=None, metrics=None):
  """Method for recording a benchmark directly.

  Args:
    name: The BenchmarkEntry name.
    iters: (optional) How many iterations were run.
    cpu_time: (optional) Total cpu time in seconds.
    wall_time: (optional) Total wall time in seconds.
    throughput: (optional) Throughput (in MB/s).
    extras: (optional) Dict mapping string keys to additional benchmark info.
    metrics: (optional) A list of dicts representing metrics generated by the
      benchmark. Each dict should contain keys 'name' and 'value'. A dict
      can optionally contain keys 'min_value' and 'max_value'.

  Raises:
    TypeError: if extras is not a dict.
    IOError: if the benchmark output file already exists.
  """
  logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g, "
               "throughput: %g, extras: %s, metrics: %s", name,
               iters if iters is not None else -1,
               wall_time if wall_time is not None else -1,
               cpu_time if cpu_time is not None else -1,
               throughput if throughput is not None else -1,
               str(extras) if extras else "None",
               str(metrics) if metrics else "None")

  entries = test_log_pb2.BenchmarkEntries()
  entry = entries.entry.add()
  entry.name = name
  if iters is not None:
    entry.iters = iters
  if cpu_time is not None:
    entry.cpu_time = cpu_time
  if wall_time is not None:
    entry.wall_time = wall_time
  if throughput is not None:
    entry.throughput = throughput
  if extras is not None:
    if not isinstance(extras, dict):
      raise TypeError("extras must be a dict")
    for (k, v) in extras.items():
      if isinstance(v, numbers.Number):
        entry.extras[k].double_value = v
      else:
        entry.extras[k].string_value = str(v)
  if metrics is not None:
    if not isinstance(metrics, list):
      raise TypeError("metrics must be a list")
    for metric in metrics:
      if "name" not in metric:
        raise TypeError("metric must have a 'name' field")
      if "value" not in metric:
        raise TypeError("metric must have a 'value' field")

      metric_entry = entry.metrics.add()
      metric_entry.name = metric["name"]
      metric_entry.value = metric["value"]
      if "min_value" in metric:
        metric_entry.min_value.value = metric["min_value"]
      if "max_value" in metric:
        metric_entry.max_value.value = metric["max_value"]

  test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
  if test_env is None:
    # Reporting was not requested, just print the proto.
    print(str(entries))
    return

  serialized_entry = entries.SerializeToString()

  mangled_name = name.replace("/", "__")
  output_path = "%s%s" % (test_env, mangled_name)
  if gfile.Exists(output_path):
    raise IOError("File already exists: %s" % output_path)
  with gfile.GFile(output_path, "wb") as out:
    out.write(serialized_entry)
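
# Illustrative sketch of a direct call (hypothetical values): without
# TEST_REPORT_FILE_PREFIX in the environment the BenchmarkEntries proto is
# printed; with it, the serialized proto is written to
# "<prefix><name with '/' replaced by '__'>".
#
#   _global_report_benchmark(
#       name="MyBenchmark.benchmark_matmul",
#       iters=10,
#       wall_time=0.023,
#       extras={"device": "cpu"},
#       metrics=[{"name": "latency_ms", "value": 23.0, "max_value": 50.0}])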


class _BenchmarkRegistrar(type):
  """The Benchmark class registrar. Used by abstract Benchmark class."""

  def __new__(mcs, clsname, base, attrs):
    newclass = type.__new__(mcs, clsname, base, attrs)
    if not newclass.is_abstract():
      GLOBAL_BENCHMARK_REGISTRY.add(newclass)
    return newclass
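
# Illustrative sketch: because _BenchmarkRegistrar is a metaclass, registration
# happens as a side effect of class creation. A hypothetical concrete subclass
# of Benchmark (defined below) is registered automatically:
#
#   class _ExampleBenchmark(Benchmark):
#     def benchmark_nothing(self):
#       self.report_benchmark(iters=1, wall_time=0.0)
#
#   assert _ExampleBenchmark in GLOBAL_BENCHMARK_REGISTRY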


@tf_export("__internal__.test.ParameterizedBenchmark", v1=[])
class ParameterizedBenchmark(_BenchmarkRegistrar):
  """Metaclass to generate parameterized benchmarks.

  Use this class as a metaclass and override `_benchmark_parameters` to
  generate multiple benchmark test cases. For example:

  class FooBenchmark(tf.test.Benchmark,
                     metaclass=tf.test.ParameterizedBenchmark):
    # `_benchmark_parameters` is expected to be a list of test cases.
    # Each test case is a tuple whose first item is the test case name,
    # followed by any number of parameters needed for the test case.
    _benchmark_parameters = [
        ('case_1', Foo, 1, 'one'),
        ('case_2', Bar, 2, 'two'),
    ]

    def benchmark_test(self, target_class, int_param, string_param):
      # benchmark test body

  The example above will generate two benchmark test cases:
  "benchmark_test__case_1" and "benchmark_test__case_2".
  """

  def __new__(mcs, clsname, base, attrs):
    param_config_list = attrs["_benchmark_parameters"]

    def create_benchmark_function(original_benchmark, params):
      return lambda self: original_benchmark(self, *params)

    for name in attrs.copy().keys():
      if not name.startswith("benchmark"):
        continue

      original_benchmark = attrs[name]
      del attrs[name]

      for param_config in param_config_list:
        test_name_suffix = param_config[0]
        params = param_config[1:]
        benchmark_name = name + "__" + test_name_suffix
        if benchmark_name in attrs:
          raise Exception(
              "Benchmark named {} already defined.".format(benchmark_name))

        benchmark = create_benchmark_function(original_benchmark, params)
        # Renaming is important because the `report_benchmark` function looks
        # up the function name in the stack trace.
        attrs[benchmark_name] = _rename_function(benchmark, 1, benchmark_name)

    return super().__new__(mcs, clsname, base, attrs)
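
# Illustrative sketch (hypothetical class): given
#
#   class MatmulBenchmark(Benchmark, metaclass=ParameterizedBenchmark):
#     _benchmark_parameters = [("small", 8), ("large", 512)]
#
#     def benchmark_matmul(self, size):
#       ...
#
# the metaclass above removes `benchmark_matmul` and installs
# `benchmark_matmul__small` and `benchmark_matmul__large`, each a renamed
# wrapper that calls the original method with its bound parameters.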


class Benchmark(metaclass=_BenchmarkRegistrar):
  """Abstract class that provides helper functions for running benchmarks.

  Any class subclassing this one is immediately registered in the global
  benchmark registry.

  Only methods whose names start with the word "benchmark" will be run during
  benchmarking.
  """

  @classmethod
  def is_abstract(cls):
    # An mro of (Benchmark, object), i.e. length <= 2, means this is the
    # abstract Benchmark base class itself rather than a concrete subclass.
    return len(cls.mro()) <= 2

  def _get_name(self, overwrite_name=None):
    """Returns full name of class and method calling report_benchmark."""

    # Find the caller method (outermost Benchmark class).
    stack = tf_inspect.stack()
    calling_class = None
    name = None
    for frame in stack[::-1]:
      f_locals = frame[0].f_locals
      f_self = f_locals.get("self", None)
      if isinstance(f_self, Benchmark):
        calling_class = f_self  # Get the outermost stack Benchmark call.
        name = frame[3]  # Get the method name.
        break
    if calling_class is None:
      raise ValueError("Unable to determine calling Benchmark class.")

    # Use the method name, or overwrite_name if provided.
    name = overwrite_name or name
    # Prefix the name with the class name.
    class_name = type(calling_class).__name__
    name = "%s.%s" % (class_name, name)
    return name
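
  # Illustrative note: for a hypothetical subclass and method
  #
  #   class MyBenchmark(TensorFlowBenchmark):
  #     def benchmark_add(self):
  #       self.report_benchmark(iters=1, wall_time=0.0)
  #
  # _get_name() finds the outermost stack frame whose `self` is a Benchmark
  # and returns "MyBenchmark.benchmark_add" (or "MyBenchmark.<overwrite_name>"
  # when a name override is supplied).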

  def report_benchmark(
      self,
      iters=None,
      cpu_time=None,
      wall_time=None,
      throughput=None,
      extras=None,
      name=None,
      metrics=None):
    """Report a benchmark.

    Args:
      iters: (optional) How many iterations were run.
      cpu_time: (optional) Median or mean cpu time in seconds.
      wall_time: (optional) Median or mean wall time in seconds.
      throughput: (optional) Throughput (in MB/s).
      extras: (optional) Dict mapping string keys to additional benchmark info.
        Values may be either floats or values that are convertible to strings.
      name: (optional) Override the BenchmarkEntry name with `name`.
        Otherwise it is inferred from the top-level method name.
      metrics: (optional) A list of dicts, where each dict has the keys below:
        name (required), string, metric name
        value (required), double, metric value
        min_value (optional), double, minimum acceptable metric value
        max_value (optional), double, maximum acceptable metric value
    """
    name = self._get_name(overwrite_name=name)
    _global_report_benchmark(
        name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
        throughput=throughput, extras=extras, metrics=metrics)
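
  # Illustrative sketch of a call from inside a benchmark method (values are
  # hypothetical):
  #
  #   self.report_benchmark(
  #       iters=100,
  #       wall_time=0.5,  # median seconds per timed run
  #       extras={"batch_size": 32},
  #       metrics=[{"name": "examples_per_sec", "value": 6400.0}])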


@tf_export("test.benchmark_config")
def benchmark_config():
  """Returns a tf.compat.v1.ConfigProto for disabling the dependency optimizer.

  Returns:
    A TensorFlow ConfigProto object.
  """
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.dependency_optimization = (
      rewriter_config_pb2.RewriterConfig.OFF)
  return config
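
# Illustrative usage (a sketch, assuming the public `tf` API): build the
# timing session with this config so that Grappler's dependency optimizer
# does not rewrite the graph being timed.
#
#   with tf.compat.v1.Session(config=tf.test.benchmark_config()) as sess:
#     ...  # pass `sess` to run_op_benchmark below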


@tf_export("test.Benchmark")
class TensorFlowBenchmark(Benchmark):
  """Abstract class that provides helpers for TensorFlow benchmarks."""

  def __init__(self):
    # Allow the TensorFlow runtime to allocate a new threadpool with a
    # different number of threads for each new benchmark.
    os.environ[OVERRIDE_GLOBAL_THREADPOOL] = "1"
    super().__init__()

  @classmethod
  def is_abstract(cls):
    # An mro of (TensorFlowBenchmark, Benchmark, object), i.e. length <= 3,
    # means this is TensorFlowBenchmark itself rather than a concrete subclass.
    return len(cls.mro()) <= 3

  def run_op_benchmark(self,
                       sess,
                       op_or_tensor,
                       feed_dict=None,
                       burn_iters=2,
                       min_iters=10,
                       store_trace=False,
                       store_memory_usage=True,
                       name=None,
                       extras=None,
                       mbs=0):
    """Run an op or tensor in the given session. Report the results.

    Args:
      sess: `Session` object to use for timing.
      op_or_tensor: `Operation` or `Tensor` to benchmark.
      feed_dict: A `dict` of values to feed for each op iteration (see the
        `feed_dict` parameter of `Session.run`).
      burn_iters: Number of burn-in iterations to run.
      min_iters: Minimum number of iterations to use for timing.
      store_trace: Boolean, whether to run an extra untimed iteration and
        store the trace of that iteration in the returned extras.
        The trace will be stored as a string in Google Chrome trace format
        in the extras field "full_trace_chrome_format". Note that the trace
        will not be stored in the test_log_pb2.TestResults proto.
      store_memory_usage: Boolean, whether to run an extra untimed iteration,
        calculate memory usage, and store that in extras fields.
      name: (optional) Override the BenchmarkEntry name with `name`.
        Otherwise it is inferred from the top-level method name.
      extras: (optional) Dict mapping string keys to additional benchmark info.
        Values may be either floats or values that are convertible to strings.
      mbs: (optional) The number of megabytes moved by this op, used to
        calculate the op's throughput.

    Returns:
      A `dict` containing the key-value pairs that were passed to
      `report_benchmark`. If the `store_trace` option is used, then
      `full_trace_chrome_format` will be included in the returned dictionary
      even though it is not passed to `report_benchmark` with `extras`.
    """
    for _ in range(burn_iters):
      sess.run(op_or_tensor, feed_dict=feed_dict)

    deltas = [None] * min_iters

    for i in range(min_iters):
      start_time = time.time()
      sess.run(op_or_tensor, feed_dict=feed_dict)
      end_time = time.time()
      delta = end_time - start_time
      deltas[i] = delta

    extras = extras if extras is not None else {}
    unreported_extras = {}
    if store_trace or store_memory_usage:
      run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      sess.run(op_or_tensor, feed_dict=feed_dict,
               options=run_options, run_metadata=run_metadata)
      tl = timeline.Timeline(run_metadata.step_stats)

      if store_trace:
        unreported_extras["full_trace_chrome_format"] = (
            tl.generate_chrome_trace_format())

      if store_memory_usage:
        step_stats_analysis = tl.analyze_step_stats(show_memory=True)
        allocator_maximums = step_stats_analysis.allocator_maximums
        for k, v in allocator_maximums.items():
          extras["allocator_maximum_num_bytes_%s" % k] = v.num_bytes

    def _median(x):
      if not x:
        return -1
      s = sorted(x)
      l = len(x)
      lm1 = l - 1
      return (s[l // 2] + s[lm1 // 2]) / 2.0

    def _mean_and_stdev(x):
      if not x:
        return -1, -1
      l = len(x)
      mean = sum(x) / l
      if l == 1:
        return mean, -1
      variance = sum([(e - mean) * (e - mean) for e in x]) / (l - 1)
      return mean, math.sqrt(variance)

    median_delta = _median(deltas)

    benchmark_values = {
        "iters": min_iters,
        "wall_time": median_delta,
        "extras": extras,
        "name": name,
        "throughput": mbs / median_delta
    }
    self.report_benchmark(**benchmark_values)

    mean_delta, stdev_delta = _mean_and_stdev(deltas)
    unreported_extras["wall_time_mean"] = mean_delta
    unreported_extras["wall_time_stdev"] = stdev_delta
    benchmark_values["extras"].update(unreported_extras)
    return benchmark_values
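
  # Illustrative sketch of run_op_benchmark in a graph-mode benchmark
  # (assumes the public `tf` API; names and sizes are hypothetical):
  #
  #   class MatmulBenchmark(tf.test.Benchmark):
  #     def benchmark_matmul(self):
  #       with tf.Graph().as_default(), tf.compat.v1.Session(
  #           config=tf.test.benchmark_config()) as sess:
  #         a = tf.random.normal([1024, 1024])
  #         product = tf.matmul(a, a)
  #         self.run_op_benchmark(sess, product, min_iters=25,
  #                               name="matmul_1024")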

  def evaluate(self, tensors):
    """Evaluates tensors and returns numpy values.

    Args:
      tensors: A Tensor or a nested list/tuple of Tensors.

    Returns:
      The numpy values of the tensors.
    """
    sess = ops.get_default_session() or self.cached_session()
    return sess.run(tensors)
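
  # Illustrative note: `evaluate` uses the active default session if there is
  # one and otherwise falls back to `self.cached_session()` (available when
  # the benchmark also extends `tf.test.TestCase`). Hypothetical usage:
  #
  #   a_np, b_np = self.evaluate([tensor_a, tensor_b])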


def _run_benchmarks(regex):
  """Run benchmarks that match regex `regex`.

  This function goes through the global benchmark registry, and matches
  benchmark class and method names of the form
  `module.name.BenchmarkClass.benchmarkMethod` to the given regex.
  If a method matches, it is run.

  Args:
    regex: The string regular expression to match Benchmark classes against.

  Raises:
    ValueError: If no benchmarks were selected by the input regex.
  """
  registry = list(GLOBAL_BENCHMARK_REGISTRY)

  selected_benchmarks = []
  # Match benchmarks in registry against regex.
  for benchmark in registry:
    benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__)
    attrs = dir(benchmark)
    # Don't instantiate the benchmark class unless necessary.
    benchmark_instance = None

    for attr in attrs:
      if not attr.startswith("benchmark"):
        continue
      candidate_benchmark_fn = getattr(benchmark, attr)
      if not callable(candidate_benchmark_fn):
        continue
      full_benchmark_name = "%s.%s" % (benchmark_name, attr)
      if regex == "all" or re.search(regex, full_benchmark_name):
        selected_benchmarks.append(full_benchmark_name)
        # Instantiate the class if it hasn't been instantiated.
        benchmark_instance = benchmark_instance or benchmark()
        # Get the method tied to the class.
        instance_benchmark_fn = getattr(benchmark_instance, attr)
        # Call the instance method.
        instance_benchmark_fn()

  if not selected_benchmarks:
    raise ValueError("No benchmarks matched the pattern: '{}'".format(regex))
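
# Illustrative sketch: for a registered class in a hypothetical module
# "my_benchmarks", the names matched against `regex` look like
# "my_benchmarks.MatmulBenchmark.benchmark_matmul", so
#
#   _run_benchmarks(r"MatmulBenchmark\.benchmark_matmul")
#
# runs exactly that method, while _run_benchmarks("all") runs every registered
# benchmark method.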


def benchmarks_main(true_main, argv=None):
  """Run benchmarks as declared in argv.

  Args:
    true_main: True main function to run if benchmarks are not requested.
    argv: the command line arguments (if None, uses sys.argv).
  """
  if argv is None:
    argv = sys.argv
  found_arg = [
      arg
      for arg in argv
      if arg.startswith("--benchmark_filter=")
      or arg.startswith("-benchmark_filter=")
  ]
  if found_arg:
    # Remove the --benchmark_filter arg from sys.argv.
    argv.remove(found_arg[0])

    # Split only on the first "=" so that regexes containing "=" are kept.
    regex = found_arg[0].split("=", 1)[1]
    app.run(lambda _: _run_benchmarks(regex), argv=argv)
  else:
    true_main()
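
# Illustrative usage from a benchmark script (hypothetical file name):
#
#   # my_benchmarks.py
#   if __name__ == "__main__":
#     benchmarks_main(true_main=lambda: None)
#
# Running `python my_benchmarks.py --benchmark_filter=MatmulBenchmark` strips
# the flag from argv and runs the matching registered benchmarks; without the
# flag, `true_main` runs instead.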