Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/test_combinations.py: 27%

154 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

"""Facilities for creating multiple test combinations.

Here is a simple example for testing various optimizers in Eager and Graph:

class AdditionExample(test.TestCase, parameterized.TestCase):
  @combinations.generate(
     combinations.combine(mode=["graph", "eager"],
                          optimizer=[AdamOptimizer(),
                                     GradientDescentOptimizer()]))
  def testOptimizer(self, optimizer):
    ... f(optimizer)...

This will run `testOptimizer` 4 times with the specified optimizers: 2 in
Eager and 2 in Graph mode.
The test accepts the same parameters as the ones used in `combine()`.
The parameters need to match by name between the `combine()` call and the test
signature, and the test must accept all of them. The exception is a parameter
that a `TestCombination` declares optional and consumes itself; see
`OptionalParameter` for a way to implement optional parameters. (In the example
above, `mode` is absent from the test signature. That only works because a
`TestCombination` given to `generate()` consumes `mode` in this way.)

The `combine()` function creates a cross product of various options. The
`times()` function creates a product of N `combine()`-ed results.

The execution of generated tests can be customized in a number of ways:
- The test can be skipped if it is not running in the correct environment.
- The arguments that are passed to the test can be additionally transformed.
- The test can be run with specific Python context managers.
These behaviors can be customized by providing instances of `TestCombination` to
`generate()`.
"""

45 

46from collections import OrderedDict 

47import contextlib 

48import re 

49import types 

50import unittest 

51 

52from absl.testing import parameterized 

53 

54from tensorflow.python.util import tf_inspect 

55from tensorflow.python.util.tf_export import tf_export 

56 

57 

@tf_export("__internal__.test.combinations.TestCombination", v1=[])
class TestCombination:
  """Hook object that customizes `generate()` and the tests it produces.

  For every generated test combination the steps are:
    1. `should_execute_combination` decides whether the combination runs in
       the current environment. Combinations that should not run are skipped.
    2. For combinations that do run, the `ParameterModifier` instances returned
       by `parameter_modifiers` are applied to the arguments. The resulting
       arguments are then checked against the test method's signature.
    3. The context managers returned by `context_managers` are entered around
       the test method while it runs.

  This base class does nothing at any step. Subclasses override the hooks they
  need.
  """

  def should_execute_combination(self, kwargs):
    """Decides whether this combination of test arguments should run.

    A combination can be skipped when the environment does not satisfy its
    dependencies.

    Args:
      kwargs: Arguments that are passed to the test combination.

    Returns:
      A `(should_execute, reason)` tuple. `should_execute` is False when the
      test should be skipped. In that case `reason` is a textual description
      of why. When the test is going to be executed, `reason` is `None`.
    """
    del kwargs  # The default implementation runs every combination.
    return True, None

  def parameter_modifiers(self):
    """Returns `ParameterModifier` instances that customize the arguments."""
    return list()

  def context_managers(self, kwargs):
    """Returns context managers to run the test combination under.

    The test combination runs under the context managers returned by all
    `TestCombination` instances together.

    Args:
      kwargs: Arguments and their values that are passed to the test
        combination.

    Returns:
      A list of instantiated context managers. The default is empty.
    """
    del kwargs  # No context managers are needed by default.
    return list()

110 

111 

@tf_export("__internal__.test.combinations.ParameterModifier", v1=[])
class ParameterModifier:
  """Customizes how a particular parameter is passed to a test.

  Subclasses override `modified_arguments()` to change parameter values, or to
  remove parameters from the arguments passed to the test method.

  For example, this modifier replaces every negative argument with zero before
  the test case receives it:
  ```
  class NonNegativeParameterModifier(ParameterModifier):

    def modified_arguments(self, kwargs, requested_parameters):
      updates = {}
      for name, value in kwargs.items():
        if value < 0:
          updates[name] = 0
      return updates
  ```
  """

  # Sentinel value. An argument mapped to it in the dict returned by
  # `modified_arguments()` is removed instead of being passed to the test.
  DO_NOT_PASS_TO_THE_TEST = object()

  def __init__(self, parameter_name=None):
    """Creates a modifier, optionally bound to one parameter name.

    Args:
      parameter_name: A `ParameterModifier` may act on a whole class of
        parameters, or on the single parameter with this name. Modifiers are
        deduplicated through `__eq__`/`__hash__` (same type and same
        `parameter_name`). Only one modifier per type/name pair therefore
        takes effect.
    """
    self._parameter_name = parameter_name

  def modified_arguments(self, kwargs, requested_parameters):
    """Computes replacements for user-provided arguments before a test runs.

    Args:
      kwargs: The combined arguments for the test.
      requested_parameters: The parameter names (a list) in the signature of
        the test method.

    Returns:
      A dict of updates to apply to `kwargs`. A key whose value is
      `ParameterModifier.DO_NOT_PASS_TO_THE_TEST` is removed rather than
      passed to the test. The default implementation returns no updates.
    """
    del kwargs, requested_parameters  # Unused by the default implementation.
    return dict()

  def __eq__(self, other):
    """Modifiers are equal when they share both type and `parameter_name`."""
    if other is self:
      return True
    if type(other) is not type(self):
      return False
    return self._parameter_name == other._parameter_name

  def __ne__(self, other):
    return not self.__eq__(other)

  def __hash__(self):
    """Hashes by `parameter_name` when it is set, otherwise by the class.

    This is consistent with `__eq__`, because equal modifiers always share
    both their type and their name.
    """
    name = self._parameter_name
    return hash(name) if name else id(self.__class__)

185 

186 

@tf_export("__internal__.test.combinations.OptionalParameter", v1=[])
class OptionalParameter(ParameterModifier):
  """A parameter that is optional both in `combine()` and in the test signature.

  `OptionalParameter` is usually returned from a `TestCombination`'s
  `parameter_modifiers()`. It lets that `TestCombination` consume a parameter
  given to `combine()` (for example, to build a context from its value)
  without the test method having to declare it.

  Example:

  ```
  class EagerGraphCombination(TestCombination):

    def context_managers(self, kwargs):
      mode = kwargs.pop("mode", None)
      if mode is None:
        return []
      elif mode == "eager":
        return [context.eager_mode()]
      elif mode == "graph":
        return [ops.Graph().as_default(), context.graph_mode()]
      else:
        raise ValueError(
            "'mode' has to be either 'eager' or 'graph', got {}".format(mode))

    def parameter_modifiers(self):
      return [test_combinations.OptionalParameter("mode")]
  ```

  The generated test methods then do not receive "mode", because
  `EagerGraphCombination` consumes it. The exception is a test method that
  declares "mode" in its signature, which still receives it.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    """Drops the parameter unless the test signature asks for it.

    Args:
      kwargs: The combined arguments for the test (unused).
      requested_parameters: Parameter names in the test method's signature.

    Returns:
      `{}` when the test requests the parameter. Otherwise, a single-entry
      dict that maps the parameter name to `DO_NOT_PASS_TO_THE_TEST`.
    """
    del kwargs
    name = self._parameter_name
    if name not in requested_parameters:
      return {name: ParameterModifier.DO_NOT_PASS_TO_THE_TEST}
    return {}

226 

227 

def generate(combinations, test_combinations=()):
  """Decorator that expands a test method or test class over `combinations`.

  Each combination's values are matched by name to the test method's
  parameters. The test must accept every combined parameter. The exception is
  a parameter that a `ParameterModifier` (such as `OptionalParameter`) from
  `test_combinations` keeps from being passed.

  Args:
    combinations: A list of dictionaries created using combine() and times().
    test_combinations: A tuple of `TestCombination` instances that customize
      the execution of generated tests.

  Returns:
    A decorator that causes the test method, or every test method of the test
    class, to run once per combination under the specified conditions.

  Raises:
    ValueError: when a generated test runs and its parameters were not
      accepted by the test method.
  """

  def _alnum_only(text):
    # Keep only the characters that are safe inside a test method name.
    return "".join(ch for ch in text if ch.isalnum())

  def _with_test_names(combos):
    # Attach a "testcase_name" to each combination so the generated tests
    # get names that can be used with --test_filter.
    named = []
    for combination in combos:
      # `combine()` and `times()` return OrderedDicts, which keeps the key
      # order (and therefore the generated names) stable.
      assert isinstance(combination, OrderedDict)
      suffix = "".join(
          "_{}_{}".format(_alnum_only(key),
                          _alnum_only(_get_name(value, idx)))
          for idx, (key, value) in enumerate(combination.items()))
      named.append(
          OrderedDict(
              list(combination.items()) +
              [("testcase_name", "_test{}".format(suffix))]))
    return named

  def decorator(test_method_or_class):
    """Applies the combinations to one test method or to a test class."""
    named_combinations = _with_test_names(combinations)

    if not isinstance(test_method_or_class, type):
      wrapped = _augment_with_special_arguments(
          test_method_or_class, test_combinations=test_combinations)
      return parameterized.named_parameters(*named_combinations)(wrapped)

    class_object = test_method_or_class
    test_method_ids = {}
    class_object._test_method_ids = test_method_ids
    # Iterate over a snapshot, because the class is mutated inside the loop.
    for attr_name, attr in class_object.__dict__.copy().items():
      is_test_method = (
          attr_name.startswith(unittest.TestLoader.testMethodPrefix) and
          isinstance(attr, types.FunctionType))
      if not is_test_method:
        continue
      delattr(class_object, attr_name)
      generated = {}
      parameterized._update_class_dict_for_param_test_case(
          class_object.__name__, generated, test_method_ids, attr_name,
          parameterized._ParameterizedTestIter(
              _augment_with_special_arguments(
                  attr, test_combinations=test_combinations),
              named_combinations, parameterized._NAMED, attr_name))
      for method_name, method in generated.items():
        setattr(class_object, method_name, method)

    return class_object

  return decorator

290 

291 

def _augment_with_special_arguments(test_method, test_combinations):
  """Wraps `test_method` so that the `TestCombination` hooks run around it.

  Args:
    test_method: The undecorated test method. Its signature, as reported by
      `tf_inspect.getfullargspec`, decides which arguments are passed to it.
    test_combinations: `TestCombination` instances. The collection is iterated
      several times per test run, so it must be re-iterable (e.g. a tuple).

  Returns:
    A function `decorated(self, **kwargs)`. When called with one combination's
    arguments, it:
      1. skips the test via `self.skipTest` if any `TestCombination` declines
         to execute the combination;
      2. applies the deduplicated `ParameterModifier`s to the arguments;
      3. raises `ValueError` unless the remaining arguments exactly match the
         test method's parameters;
      4. calls `test_method` inside every context manager returned by the
         `TestCombination`s.
  """

  def decorated(self, **kwargs):
    """A wrapped test method that can treat some arguments in a special way."""
    original_kwargs = kwargs.copy()

    # Skip combinations that are going to be executed in a different testing
    # environment. Each hook gets its own copy, so no hook affects another.
    reasons_to_skip = []
    for combination in test_combinations:
      should_execute, reason = combination.should_execute_combination(
          original_kwargs.copy())
      if not should_execute:
        # The reason string is documented as optional. Without this check,
        # concatenating a None reason would raise TypeError.
        if reason is None:
          reason = "<no reason given by {}>".format(
              type(combination).__name__)
        reasons_to_skip.append(" - {}".format(reason))

    if reasons_to_skip:
      self.skipTest("\n".join(reasons_to_skip))

    # ParameterModifier.__eq__/__hash__ treat modifiers of the same type and
    # parameter name as one, so the set keeps only the first of each.
    customized_parameters = set()
    for combination in test_combinations:
      customized_parameters.update(combination.parameter_modifiers())

    def execute_test_method():
      # Runs the test with the modified arguments. It is called inside all of
      # the `context_managers`.
      requested_parameters = tf_inspect.getfullargspec(test_method).args
      for customized_parameter in customized_parameters:
        updates = customized_parameter.modified_arguments(
            original_kwargs.copy(), requested_parameters)
        for argument, value in updates.items():
          if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
            kwargs.pop(argument, None)
          else:
            kwargs[argument] = value

      # "self" is supplied below rather than through kwargs.
      passed = set(kwargs) | {"self"}
      requested = set(requested_parameters)

      omitted_arguments = requested - passed
      if omitted_arguments:
        raise ValueError("The test requires parameters whose arguments "
                         "were not passed: {} .".format(omitted_arguments))
      missing_arguments = passed - requested
      if missing_arguments:
        raise ValueError("The test does not take parameters that were passed "
                         ": {} .".format(missing_arguments))

      kwargs_to_pass = {
          parameter: self if parameter == "self" else kwargs[parameter]
          for parameter in requested_parameters
      }
      test_method(**kwargs_to_pass)

    # Collect every context manager first, then enter them all in order.
    context_managers = [
        manager
        for combination in test_combinations
        for manager in combination.context_managers(original_kwargs.copy())
    ]
    with contextlib.ExitStack() as context_stack:
      for manager in context_managers:
        context_stack.enter_context(manager)
      execute_test_method()

  return decorated

363 

364 

@tf_export("__internal__.test.combinations.combine", v1=[])
def combine(**kwargs):
  """Generate combinations based on its keyword arguments.

  Two sets of returned combinations can be concatenated using +. Their product
  can be computed using `times()`.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]`
      or `option=the_only_possibility`. Only a `list` (or list subclass) is
      treated as a set of possibilities. Any other value, including a tuple,
      is used as the only possibility.

  Returns:
    a list of `OrderedDict`s, one per combination, with keys sorted by name.
    Each key has one value - one of the corresponding keyword argument values.
    Combinations are ordered like nested loops over the sorted keys, so the
    alphabetically first key varies slowest. With no arguments, the result is
    `[OrderedDict()]`.
  """
  import itertools  # pylint: disable=g-import-not-at-top

  names = sorted(kwargs)
  # Wrap single values so that every option is a list of possibilities.
  choices = [
      kwargs[name] if isinstance(kwargs[name], list) else [kwargs[name]]
      for name in names
  ]
  # itertools.product varies its first iterable slowest. This matches the
  # order of the original recursive construction, and product() with no
  # iterables yields one empty tuple, i.e. [OrderedDict()].
  return [
      OrderedDict(zip(names, values))
      for values in itertools.product(*choices)
  ]

401 

402 

@tf_export("__internal__.test.combinations.times", v1=[])
def times(*combined):
  """Generate a product of N sets of combinations.

  times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4])

  Args:
    *combined: N lists of dictionaries that specify combinations.

  Returns:
    a list of dictionaries for each combination. Within each result, the keys
    from the first input come first, then those from the next input, and so
    on. Given a single input, that list itself is returned.

  Raises:
    ValueError: if no inputs are given, or if some of the inputs have
      overlapping keys.
  """
  # An explicit check instead of `assert`: asserts are stripped under -O,
  # which would turn this into an IndexError below.
  if not combined:
    raise ValueError("times() requires at least one list of combinations.")

  if len(combined) == 1:
    return combined[0]

  first = combined[0]
  rest_combined = times(*combined[1:])

  combined_results = []
  for a in first:
    for b in rest_combined:
      if a.keys() & b.keys():
        raise ValueError("Keys need to not overlap: {} vs {}".format(
            a.keys(), b.keys()))

      combined_results.append(OrderedDict(list(a.items()) + list(b.items())))
  return combined_results

435 

436 

@tf_export("__internal__.test.combinations.NamedObject", v1=[])
class NamedObject:
  """A class that translates an object into a good test name.

  `repr()` of the wrapper returns `name`. `str()` falls back to `__repr__`
  because no `__str__` is defined. Attribute access, calls and iteration are
  forwarded to `obj`. `generate()` builds test names from `str(value)`, so
  wrapping a parameter value in `NamedObject` gives the generated test a
  readable name.
  """

  def __init__(self, name, obj):
    self._name = name
    self._obj = obj

  def __getattr__(self, name):
    # Python calls __getattr__ only when normal lookup fails. If `_name` or
    # `_obj` is missing (e.g. copy/pickle create the instance without
    # __init__ and then probe for __setstate__), forwarding to `self._obj`
    # would re-enter __getattr__("_obj") forever. Raise AttributeError
    # instead so that hasattr() and copy work.
    if name in ("_name", "_obj"):
      raise AttributeError(name)
    return getattr(self._obj, name)

  def __call__(self, *args, **kwargs):
    return self._obj(*args, **kwargs)

  def __iter__(self):
    return self._obj.__iter__()

  def __repr__(self):
    return self._name

456 

457 

458def _get_name(value, index): 

459 return re.sub("0[xX][0-9a-fA-F]+", str(index), str(value))