Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/test_combinations.py: 27%
154 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Facilities for creating multiple test combinations.
17Here is a simple example for testing various optimizers in Eager and Graph:
19class AdditionExample(test.TestCase, parameterized.TestCase):
20 @combinations.generate(
21 combinations.combine(mode=["graph", "eager"],
22 optimizer=[AdamOptimizer(),
23 GradientDescentOptimizer()]))
24 def testOptimizer(self, optimizer):
25 ... f(optimizer)...
27This will run `testOptimizer` 4 times with the specified optimizers: 2 in
28Eager and 2 in Graph mode.
29The test is going to accept the same parameters as the ones used in `combine()`.
30The parameters need to match by name between the `combine()` call and the test
31signature. It is necessary to accept all parameters. See `OptionalParameter`
32for a way to implement optional parameters.
The `combine()` function creates a cross product of various options, and the
`times()` function creates a product of N `combine()`-ed results.
38The execution of generated tests can be customized in a number of ways:
39- The test can be skipped if it is not running in the correct environment.
40- The arguments that are passed to the test can be additionally transformed.
41- The test can be run with specific Python context managers.
42These behaviors can be customized by providing instances of `TestCombination` to
43`generate()`.
44"""
46from collections import OrderedDict
47import contextlib
48import re
49import types
50import unittest
52from absl.testing import parameterized
54from tensorflow.python.util import tf_inspect
55from tensorflow.python.util.tf_export import tf_export
@tf_export("__internal__.test.combinations.TestCombination", v1=[])
class TestCombination:
  """Customize the behavior of `generate()` and the tests that it executes.

  Executing a test combination proceeds in three steps:
    1. `should_execute_combination` decides whether the combination is run in
       the current environment at all.
    2. If it is run, the arguments for all combined parameters are validated.
       Arguments that need special handling are processed by the
       `ParameterModifier` instances returned from `parameter_modifiers`.
    3. The test itself runs inside the context managers returned from
       `context_managers`.

  This base implementation always executes, modifies no parameters and
  installs no context managers; subclasses override the hooks they need.
  """

  def should_execute_combination(self, kwargs):
    """Decide whether this combination of test arguments should be executed.

    A combination whose dependencies are not satisfied by the environment can
    be skipped.

    Args:
      kwargs: Arguments that are passed to the test combination.

    Returns:
      A `(should_execute, reason)` tuple. `should_execute` is False when the
      test should be skipped, and `reason` then gives a textual description of
      why. When the test is going to be executed, `reason` is `None`.
    """
    del kwargs  # Unused by the default implementation.
    return True, None

  def parameter_modifiers(self):
    """Return the `ParameterModifier` instances that customize the arguments."""
    return list()

  def context_managers(self, kwargs):
    """Return the context managers to run the test combination under.

    The test runs under the union of the context managers returned by every
    `TestCombination` passed to `generate()`.

    Args:
      kwargs: Arguments and their values that are passed to the test
        combination.

    Returns:
      A list of instantiated context managers.
    """
    del kwargs  # Unused by the default implementation.
    return list()
@tf_export("__internal__.test.combinations.ParameterModifier", v1=[])
class ParameterModifier:
  """Customizes the behavior of a particular parameter.

  Subclasses override `modified_arguments()` to adjust the parameters they care
  about, e.g. to change the value of a parameter or to drop it from the params
  passed to the test case.

  The example below replaces every negative argument with zero before it is
  passed to the test case:
  ```
  class NonNegativeParameterModifier(ParameterModifier):

    def modified_arguments(self, kwargs, requested_parameters):
      updates = {}
      for name, value in kwargs.items():
        if value < 0:
          updates[name] = 0
      return updates
  ```
  """

  # Sentinel: mapping an argument to this value in the dict returned from
  # `modified_arguments()` removes that argument instead of passing it on.
  DO_NOT_PASS_TO_THE_TEST = object()

  def __init__(self, parameter_name=None):
    """Construct a parameter modifier, optionally bound to one parameter.

    Args:
      parameter_name: A `ParameterModifier` may operate on a class of
        parameters or on the parameter with a particular name. Only modifiers
        that have a unique type or were constructed with a unique
        `parameter_name` are executed; see `__eq__` and `__hash__`.
    """
    self._parameter_name = parameter_name

  def modified_arguments(self, kwargs, requested_parameters):
    """Compute replacements for user-provided arguments before a test runs.

    Args:
      kwargs: The combined arguments for the test.
      requested_parameters: The parameters that are defined in the signature
        of the test method.

    Returns:
      A dictionary of updates to `kwargs`. Keys whose value is
      `ParameterModifier.DO_NOT_PASS_TO_THE_TEST` are deleted and not passed
      to the test.
    """
    del kwargs, requested_parameters  # Unused by the default implementation.
    return {}

  def __eq__(self, other):
    """Modifiers are equal when they share a type and a `parameter_name`."""
    if self is other:
      return True
    return (type(self) is type(other) and
            self._parameter_name == other._parameter_name)

  def __ne__(self, other):
    return not self.__eq__(other)

  def __hash__(self):
    """Hash by `parameter_name` when it is set, otherwise by the class.

    This agrees with `__eq__`: equal modifiers have the same type and the same
    `parameter_name`, so they always produce the same hash.
    """
    parameter_name = self._parameter_name
    return hash(parameter_name) if parameter_name else id(self.__class__)
@tf_export("__internal__.test.combinations.OptionalParameter", v1=[])
class OptionalParameter(ParameterModifier):
  """A parameter that is optional in `combine()` and in the test signature.

  `OptionalParameter` is usually returned from a `TestCombination`'s
  `parameter_modifiers()`. It lets the `TestCombination` consume a parameter
  given to `combine()` (for instance to build a context from its value)
  without requiring every test method to declare that parameter.

  Example:

  ```
  class EagerGraphCombination(TestCombination):

    def context_managers(self, kwargs):
      mode = kwargs.pop("mode", None)
      if mode is None:
        return []
      elif mode == "eager":
        return [context.eager_mode()]
      elif mode == "graph":
        return [ops.Graph().as_default(), context.graph_mode()]
      else:
        raise ValueError(
            "'mode' has to be either 'eager' or 'graph', got {}".format(mode))

    def parameter_modifiers(self):
      return [test_combinations.OptionalParameter("mode")]
  ```

  In the generated test cases, "mode" is only passed to the test method if the
  method's signature asks for it; otherwise it is consumed by
  `EagerGraphCombination` alone.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    """Drop the parameter unless the test signature requests it."""
    if self._parameter_name not in requested_parameters:
      return {self._parameter_name: ParameterModifier.DO_NOT_PASS_TO_THE_TEST}
    return {}
def generate(combinations, test_combinations=()):
  """A decorator for generating combinations of a test method or a test class.

  Parameters of the test method must match by name to get the corresponding
  value of the combination. Tests must accept all parameters that are passed
  other than the ones that are `OptionalParameter`.

  Args:
    combinations: a list of `OrderedDict`s created using `combine()` and
      `times()`.
    test_combinations: a tuple of `TestCombination` instances that customize
      the execution of generated tests.

  Returns:
    a decorator that will cause the test method or the test class to be run
    under the specified conditions.

  Raises:
    ValueError: if any parameters were not accepted by the test method. This
      is raised when a generated test runs, not when decorating.
  """

  def with_test_names(combination_list):
    """Return copies of the combinations extended with a `testcase_name`."""
    named = []
    for combination in combination_list:
      # `combine()` and `times()` build OrderedDicts so that the key order,
      # and therefore the generated test name, is stable.
      assert isinstance(combination, OrderedDict)
      # Only alphanumerics are kept so the names work with --test_filter.
      suffix = ""
      for position, (key, value) in enumerate(combination.items()):
        clean_key = "".join(filter(str.isalnum, key))
        clean_value = "".join(filter(str.isalnum, _get_name(value, position)))
        suffix += "_{}_{}".format(clean_key, clean_value)
      entries = list(combination.items())
      entries.append(("testcase_name", "_test{}".format(suffix)))
      named.append(OrderedDict(entries))
    return named

  def decorator(test_method_or_class):
    """The decorator to be returned."""
    named_combinations = with_test_names(combinations)

    if not isinstance(test_method_or_class, type):
      # A single test method: parameterize it directly.
      wrapped = _augment_with_special_arguments(
          test_method_or_class, test_combinations=test_combinations)
      return parameterized.named_parameters(*named_combinations)(wrapped)

    # A whole test class: replace each test method with its generated copies.
    class_object = test_method_or_class
    test_method_ids = {}
    class_object._test_method_ids = test_method_ids
    # Iterate over a snapshot because the class dict is mutated in the loop.
    for attr_name, attr in class_object.__dict__.copy().items():
      is_test_function = (
          attr_name.startswith(unittest.TestLoader.testMethodPrefix) and
          isinstance(attr, types.FunctionType))
      if not is_test_function:
        continue
      delattr(class_object, attr_name)
      generated = {}
      parameterized._update_class_dict_for_param_test_case(
          class_object.__name__, generated, test_method_ids, attr_name,
          parameterized._ParameterizedTestIter(
              _augment_with_special_arguments(
                  attr, test_combinations=test_combinations),
              named_combinations, parameterized._NAMED, attr_name))
      for generated_name, generated_method in generated.items():
        setattr(class_object, generated_name, generated_method)

    return class_object

  return decorator
292def _augment_with_special_arguments(test_method, test_combinations):
293 def decorated(self, **kwargs):
294 """A wrapped test method that can treat some arguments in a special way."""
295 original_kwargs = kwargs.copy()
297 # Skip combinations that are going to be executed in a different testing
298 # environment.
299 reasons_to_skip = []
300 for combination in test_combinations:
301 should_execute, reason = combination.should_execute_combination(
302 original_kwargs.copy())
303 if not should_execute:
304 reasons_to_skip.append(" - " + reason)
306 if reasons_to_skip:
307 self.skipTest("\n".join(reasons_to_skip))
309 customized_parameters = []
310 for combination in test_combinations:
311 customized_parameters.extend(combination.parameter_modifiers())
312 customized_parameters = set(customized_parameters)
314 # The function for running the test under the total set of
315 # `context_managers`:
316 def execute_test_method():
317 requested_parameters = tf_inspect.getfullargspec(test_method).args
318 for customized_parameter in customized_parameters:
319 for argument, value in customized_parameter.modified_arguments(
320 original_kwargs.copy(), requested_parameters).items():
321 if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
322 kwargs.pop(argument, None)
323 else:
324 kwargs[argument] = value
326 omitted_arguments = set(requested_parameters).difference(
327 set(list(kwargs.keys()) + ["self"]))
328 if omitted_arguments:
329 raise ValueError("The test requires parameters whose arguments "
330 "were not passed: {} .".format(omitted_arguments))
331 missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
332 set(requested_parameters))
333 if missing_arguments:
334 raise ValueError("The test does not take parameters that were passed "
335 ": {} .".format(missing_arguments))
337 kwargs_to_pass = {}
338 for parameter in requested_parameters:
339 if parameter == "self":
340 kwargs_to_pass[parameter] = self
341 else:
342 kwargs_to_pass[parameter] = kwargs[parameter]
343 test_method(**kwargs_to_pass)
345 # Install `context_managers` before running the test:
346 context_managers = []
347 for combination in test_combinations:
348 for manager in combination.context_managers(
349 original_kwargs.copy()):
350 context_managers.append(manager)
352 if hasattr(contextlib, "nested"): # Python 2
353 # TODO(isaprykin): Switch to ExitStack when contextlib2 is available.
354 with contextlib.nested(*context_managers):
355 execute_test_method()
356 else: # Python 3
357 with contextlib.ExitStack() as context_stack:
358 for manager in context_managers:
359 context_stack.enter_context(manager)
360 execute_test_method()
362 return decorated
@tf_export("__internal__.test.combinations.combine", v1=[])
def combine(**kwargs):
  """Generate combinations based on its keyword arguments.

  Two sets of returned combinations can be concatenated using +. Their product
  can be computed using `times()`.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]`
      or `option=the_only_possibility`. Only a `list` is expanded into
      possibilities; any other value (including a tuple) is a single one.

  Returns:
    a list of `OrderedDict`s, one for each combination, whose keys are the
    keyword argument names in sorted order. Each key has one value - one of
    the corresponding keyword argument values.
  """
  # Extend a single empty combination one option at a time, in sorted key
  # order. Every existing partial combination is extended by every value of
  # the next option, so alphabetically earlier options vary slowest:
  # combine(a=[1, 2], b=[3, 4]) -> (a=1,b=3), (a=1,b=4), (a=2,b=3), (a=2,b=4).
  partial_combinations = [[]]
  for option in sorted(kwargs):
    possibilities = kwargs[option]
    if not isinstance(possibilities, list):
      possibilities = [possibilities]
    partial_combinations = [
        partial + [(option, possibility)]
        for partial in partial_combinations
        for possibility in possibilities
    ]
  return [OrderedDict(pairs) for pairs in partial_combinations]
@tf_export("__internal__.test.combinations.times", v1=[])
def times(*combined):
  """Generate a product of N sets of combinations.

  times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4])

  Args:
    *combined: N lists of dictionaries that specify combinations. At least
      one list must be given.

  Returns:
    a list of dictionaries for each combination. Combinations from earlier
    arguments vary slowest; a single argument is returned unchanged.

  Raises:
    ValueError: if some of the inputs have overlapping keys.
  """
  assert combined

  # Fold from the right: the product of the trailing lists is built first,
  # then each earlier list is crossed with it. This reproduces both the
  # result order and the order in which overlapping keys are detected.
  product = combined[-1]
  for left in reversed(combined[:-1]):
    merged_results = []
    for a in left:
      for b in product:
        if not set(a.keys()).isdisjoint(b.keys()):
          raise ValueError("Keys need to not overlap: {} vs {}".format(
              a.keys(), b.keys()))
        merged = OrderedDict(a)
        merged.update(b)
        merged_results.append(merged)
    product = merged_results
  return product
@tf_export("__internal__.test.combinations.NamedObject", v1=[])
class NamedObject:
  """A class that translates an object into a good test name.

  `generate()` derives test names from `str()` of each combination value. A
  `NamedObject` reports `name` as its repr (and therefore as its `str()`),
  while attribute access, calls and iteration are forwarded to `obj`.
  """

  def __init__(self, name, obj):
    self._name = name
    self._obj = obj

  def __getattr__(self, name):
    # Only reached when normal attribute lookup fails. An instance created
    # without running `__init__` (as `copy.copy`, `copy.deepcopy` and `pickle`
    # do before restoring `__dict__`) has no `_obj`, so forwarding would
    # re-enter this method forever and raise RecursionError. Raise a regular
    # AttributeError instead so `hasattr()` probes such as `__setstate__` work.
    if name in ("_name", "_obj"):
      raise AttributeError("'{}' object has no attribute '{}'".format(
          type(self).__name__, name))
    return getattr(self._obj, name)

  def __call__(self, *args, **kwargs):
    return self._obj(*args, **kwargs)

  def __iter__(self):
    return self._obj.__iter__()

  def __repr__(self):
    return self._name
458def _get_name(value, index):
459 return re.sub("0[xX][0-9a-fA-F]+", str(index), str(value))