Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/combinations.py: 37%

259 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""This module customizes `test_combinations` for `tf.distribute.Strategy`. 

16 

17Additionally it provides `generate()`, `combine()` and `times()` with 

18`tf.distribute.Strategy` customizations as a default. 

19""" 

20 

21import collections 

22import copy 

23import re 

24import sys 

25import types 

26import unittest 

27 

28from absl import app 

29import six 

30 

31 

32from tensorflow.python.client import session 

33from tensorflow.python.distribute import collective_all_reduce_strategy 

34from tensorflow.python.distribute import distribute_lib 

35from tensorflow.python.distribute import multi_process_runner 

36from tensorflow.python.distribute import multi_worker_test_base 

37from tensorflow.python.eager import context 

38from tensorflow.python.eager import def_function 

39from tensorflow.python.framework import combinations as framework_combinations 

40from tensorflow.python.framework import config 

41from tensorflow.python.framework import ops 

42from tensorflow.python.framework import test_combinations as combinations_lib 

43from tensorflow.python.framework import test_util 

44from tensorflow.python.platform import flags 

45from tensorflow.python.platform import tf_logging as logging 

46from tensorflow.python.util import tf_decorator 

47from tensorflow.python.util import tf_inspect 

48from tensorflow.python.util.tf_export import tf_export 

49 

50 

# TODO(rchao): Rename `distribution` parameter to `strategy` or
# `distribute_strategy` in all tests.
class DistributionParameter(combinations_lib.ParameterModifier):
  """Replaces `NamedDistribution` arguments with the strategies they wrap.

  Every keyword argument whose value is a `NamedDistribution` is mapped to the
  object returned by that wrapper's `strategy` property; other arguments are
  not included in the result.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    # `use_var_policy` is a temporary switch for testing the variable policy
    # rollout. When it is truthy it is copied onto every produced strategy's
    # `extended._use_var_policy` attribute.
    var_policy_flag = kwargs.get("use_var_policy", None)
    replaced = {}
    for arg_name, arg_value in kwargs.items():
      if not isinstance(arg_value, NamedDistribution):
        continue
      built_strategy = arg_value.strategy
      if var_policy_flag:
        built_strategy.extended._use_var_policy = var_policy_flag
      replaced[arg_name] = built_strategy
    return replaced

73 

74 

class ClusterParameters(combinations_lib.ParameterModifier):
  """Adds cluster parameters if a `NamedDistribution` has it.

  The cluster parameters are `has_chief`, `num_workers`, `runner`, `share_gpu`
  and `num_ps`. They are taken from the `NamedDistribution` in the combination
  when there is one, and from the combination's own keyword arguments (falling
  back to defaults) otherwise.

  It needs to be before DistributionParameter (which converts
  `NamedDistribution` arguments into strategies), since this modifier reads the
  `NamedDistribution` objects themselves.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    """Returns the cluster parameters requested by the test.

    Args:
      kwargs: the keyword arguments of the current combination.
      requested_parameters: names of the parameters the test accepts.

    Returns:
      A dict mapping each requested cluster parameter to its value.

    Raises:
      ValueError: if several `NamedDistribution`s are given and any of them is
        multi worker, or if explicit `has_chief`/`num_workers` arguments
        disagree with the `NamedDistribution`.
    """
    strategy = None
    for v in kwargs.values():
      if not isinstance(v, NamedDistribution):
        continue
      # A multi worker NamedDistribution defines the cluster layout, so it
      # cannot be mixed with another NamedDistribution. Check both the one seen
      # earlier and the current one, so the result doesn't depend on order.
      if strategy is not None and (
          _num_total_workers(strategy.has_chief, strategy.num_workers) > 1 or
          _num_total_workers(v.has_chief, v.num_workers) > 1):
        raise ValueError("Only support one NamedDistribution for multi worker "
                         "tests.")
      strategy = v

    if strategy is not None:
      has_chief = strategy.has_chief
      num_workers = strategy.num_workers
      runner = strategy.runner
      share_gpu = strategy.share_gpu
      num_ps = strategy.num_ps
      if "has_chief" in kwargs and kwargs["has_chief"] != has_chief:
        raise ValueError(
            "both has_chief and strategy specified but are not compatible")
      if "num_workers" in kwargs and kwargs["num_workers"] != num_workers:
        raise ValueError(
            "both num_workers and strategy specified but are not compatible")
    else:
      has_chief = kwargs.get("has_chief", False)
      num_workers = kwargs.get("num_workers", 1)
      runner = kwargs.get("runner", None)
      share_gpu = kwargs.get("share_gpu", True)
      num_ps = kwargs.get("num_ps", 0)

    # Always set cluster parameters if they're requested, so that generate()
    # works when there's no strategy in the combinations.
    update = {}
    if "has_chief" in requested_parameters:
      update["has_chief"] = has_chief
    if "num_workers" in requested_parameters:
      update["num_workers"] = num_workers
    if "runner" in requested_parameters:
      update["runner"] = runner
    if "share_gpu" in requested_parameters:
      update["share_gpu"] = share_gpu
    if "num_ps" in requested_parameters:
      update["num_ps"] = num_ps
    return update

124 

125 

class DistributionCombination(combinations_lib.TestCombination):
  """Sets up distribution strategy for tests.

  Skips strategy combinations marked `no_xla=True` when XLA is enabled, and
  installs the modifiers that turn `NamedDistribution` arguments into
  strategies and accept an optional `use_var_policy` argument.
  """

  def should_execute_combination(self, kwargs):
    if test_util.is_xla_enabled():
      for value in kwargs.values():
        if isinstance(value, NamedDistribution) and value.no_xla:
          return (
              False,
              "n/a: skipping strategy combination with no_xla=True in XLA tests")
    return (True, None)

  def parameter_modifiers(self):
    distribution_modifier = DistributionParameter()
    var_policy_modifier = combinations_lib.OptionalParameter("use_var_policy")
    return [distribution_modifier, var_policy_modifier]

144 

145 

class ClusterCombination(combinations_lib.TestCombination):
  """Sets up multi worker tests.

  Contributes a single `ClusterParameters` modifier, which resolves the
  `has_chief`, `num_workers`, `runner`, `share_gpu` and `num_ps` arguments.
  """

  def parameter_modifiers(self):
    modifiers = [ClusterParameters()]
    return modifiers

151 

152 

class GPUCombination(combinations_lib.TestCombination):
  """Enable tests to request GPU hardware and skip non-GPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_gpus` argument is supported. GPU hardware is
  required, if its value is `True` or > 0. Alternatively `required_physical_gpus`
  may be used; only one of the two should be set.

  Attributes:
    GPU_TEST: The environment is considered to have GPU hardware available if
      the name of the program ends with "test_gpu", "test_2gpu", "test_xla_gpu"
      or "test_xla_2gpu". Holds the (truthy) `re.Match` object in that case,
      and a falsy value otherwise.
  """
  GPU_TEST = False
  if sys.argv:
    GPU_TEST = re.search(r"(test_2?gpu|test_xla_2?gpu)$", sys.argv[0])

  def should_execute_combination(self, kwargs):
    """Decides whether a combination can run with the available GPUs.

    Args:
      kwargs: the keyword arguments of the current combination.

    Returns:
      A `(should_execute, reason)` tuple; `reason` is None when the combination
      runs and a human readable skip message otherwise.

    Raises:
      ValueError: if `required_gpus` is combined with `NamedDistribution`
        arguments, or if both `required_gpus` and `required_physical_gpus` are
        set.
    """
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    required_gpus = kwargs.get("required_gpus", 0)
    required_physical_gpus = kwargs.get("required_physical_gpus", 0)

    if distributions and required_gpus:
      raise ValueError("Do not use `required_gpus` and arguments of type "
                       "NamedDistribution together.")

    # Any GPU requirement, logical or physical, counts towards the total.
    number_of_required_gpus = max(
        [required_gpus] + [required_physical_gpus] +
        [d.required_physical_gpus or 0 for d in distributions] +
        [d.required_gpus or 0 for d in distributions])
    number_of_required_physical_gpus = max(
        [required_physical_gpus] +
        [d.required_physical_gpus or 0 for d in distributions])

    if required_physical_gpus and required_gpus:
      raise ValueError("Only one of `required_physical_gpus`(number of physical"
                       " GPUs required) and `required_gpus`(total number of "
                       "GPUs required) should be set. ")
    if not number_of_required_gpus and GPUCombination.GPU_TEST:
      return (False, "Test that doesn't require GPUs.")
    if (number_of_required_gpus > 0 and
        context.num_gpus() < number_of_required_gpus):
      return (False, ("Only {} of {} required GPUs are available.".format(
          context.num_gpus(), number_of_required_gpus)))
    num_physical_gpus = len(config.list_physical_devices("GPU"))
    if number_of_required_physical_gpus > num_physical_gpus:
      # Report the device count and the effective requirement (which may come
      # from a NamedDistribution), not the raw device list and the kwarg.
      return (False,
              ("Only {} of {} required physical GPUs are available.".format(
                  num_physical_gpus, number_of_required_physical_gpus)))
    return (True, None)

  def parameter_modifiers(self):
    return [combinations_lib.OptionalParameter("required_gpus"),
            combinations_lib.OptionalParameter("required_physical_gpus")]

210 

211 

class TPUCombination(combinations_lib.TestCombination):
  """Allow to request TPU hardware and skip non-TPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_tpus` parameter is supported (the older
  `required_tpu` spelling is still accepted). TPU hardware is required, if its
  argument is `True` or > 0.

  Optionally, the `use_cloud_tpu` parameter is supported. If TPU hardware is
  required by `required_tpus`, it specifically must be a Cloud TPU (specified
  with `--tpu`) if `use_cloud_tpu` is `True`.

  Attributes:
    TPU_TEST: The environment is considered to have TPU hardware available if
      the name of the program contains "test_tpu".
  """

  TPU_TEST = False
  if sys.argv:
    TPU_TEST = "test_tpu" in sys.argv[0]

  def should_execute_combination(self, kwargs):
    named = [
        value for value in kwargs.values()
        if isinstance(value, NamedDistribution)
    ]
    # TODO(isaprykin): Migrate all tests away from using 'required_tpu' in favor
    # of 'required_tpus'.
    if "required_tpus" in kwargs and "required_tpu" in kwargs:
      raise ValueError("Do not use `required_tpu`. Both `required_tpus` and "
                       "`required_tpu` were specified.")
    required_tpus = (kwargs.get("required_tpus", None) or
                     kwargs.get("required_tpu", None))

    if named and required_tpus:
      raise ValueError("Do not use `required_tpus` and arguments of type "
                       "NamedDistribution together.")

    # TODO(isaprykin): Add support for a particular number of TPUs. Right now
    # it's binary.
    number_of_required_tpus = max(
        [required_tpus or 0] + [d.required_tpu or 0 for d in named])
    use_cloud_tpu = (bool(kwargs.get("use_cloud_tpu")) or
                     any(d.use_cloud_tpu for d in named))
    tpu = ""
    if hasattr(flags.FLAGS, "tpu"):
      tpu = flags.FLAGS.tpu or ""

    # The TPU requirement must match whether this is a TPU test binary.
    if number_of_required_tpus:
      if not TPUCombination.TPU_TEST:
        return (False, "Test requires a TPU, but it's not available.")
    elif TPUCombination.TPU_TEST:
      return (False, "Test that doesn't require TPUs.")

    # The Cloud TPU request must match whether `--tpu` was given.
    if use_cloud_tpu:
      if not tpu:
        return (False, "Test requires a Cloud TPU, but none specified.")
    elif tpu:
      return (False, "Test requires local TPU, but Cloud TPU specified.")
    return (True, None)

  def parameter_modifiers(self):
    optional_names = ("required_tpus", "required_tpu", "use_cloud_tpu")
    return [combinations_lib.OptionalParameter(name) for name in optional_names]

274 

275 

class NamedDistribution(object):
  """Wraps a `tf.distribute.Strategy` and adds a name for test titles."""

  def __init__(self,
               name,
               distribution_fn,
               required_gpus=None,
               required_physical_gpus=0,
               required_tpu=False,
               use_cloud_tpu=False,
               has_chief=False,
               num_workers=1,
               num_ps=0,
               share_gpu=True,
               pool_runner_fn=None,
               no_xla=False):
    """Initialize NamedDistribution.

    Args:
      name: Name that will be a part of the name of the test case; also what
        `repr()` returns.
      distribution_fn: A callable that creates a `tf.distribute.Strategy`. It
        is called anew every time the `strategy` property is read.
      required_gpus: The number of GPUs that the strategy requires. Only one of
        `required_gpus` and `required_physical_gpus` should be set.
      required_physical_gpus: Number of physical GPUs required. Only one of
        `required_gpus` and `required_physical_gpus` should be set.
      required_tpu: Whether the strategy requires TPU.
      use_cloud_tpu: Whether the strategy requires cloud TPU.
      has_chief: Whether the strategy requires a chief worker.
      num_workers: The number of workers that the strategy requires.
      num_ps: The number of parameter servers.
      share_gpu: Whether to share GPUs among workers.
      pool_runner_fn: An optional callable that returns a MultiProcessPoolRunner
        to run the test. It is called every time `runner` is read.
      no_xla: Whether to skip in XLA tests.
    """
    super().__init__()
    self._name = name
    self._distribution_fn = distribution_fn
    self._pool_runner_fn = pool_runner_fn
    # Hardware requirements.
    self.required_gpus = required_gpus
    self.required_physical_gpus = required_physical_gpus
    self.required_tpu = required_tpu
    self.use_cloud_tpu = use_cloud_tpu
    # Cluster layout.
    self.has_chief = has_chief
    self.num_workers = num_workers
    self.num_ps = num_ps
    self.share_gpu = share_gpu
    self.no_xla = no_xla

  @property
  def runner(self):
    return None if self._pool_runner_fn is None else self._pool_runner_fn()

  @property
  def strategy(self):
    return self._distribution_fn()

  def __repr__(self):
    return self._name

337 

338 

# `tf_function` / `no_tf_function` allow adding combinations that run a
# function both as a tf.function and eagerly:
#
#   @combinations.generate(
#       combinations.combine(
#           tf_function=[combinations.tf_function, combinations.no_tf_function]
#       )
#   )
#   def testXXX(tf_function):
#     @tf_function
#     def foo():
#       tf.add(1., 1.)
#
#     foo()
tf_function, no_tf_function = (
    combinations_lib.NamedObject("TfFunction", def_function.function),
    combinations_lib.NamedObject("NoTfFunction", lambda f: f),
)

355 

356 

def concat(*combined):
  """Concatenates several sequences of combinations into one new list.

  Args:
    *combined: iterables of combinations.

  Returns:
    A new list holding the elements of every argument, in order.
  """
  return [combination for group in combined for combination in group]

363 

364 

@tf_export("__internal__.distribute.combinations.generate", v1=[])
def generate(combinations, test_combinations=()):
  # pylint: disable=g-doc-args,g-doc-return-or-yield
  """Distributed adapter of `tf.__internal__.test.combinations.generate`.

  All tests with distributed strategy should use this one instead of
  `tf.__internal__.test.combinations.generate`. On top of the generic version
  it adds strategy combinations, GPU/TPU requirements and multi worker
  support.

  See `tf.__internal__.test.combinations.generate` for usage.
  """
  # pylint: enable=g-doc-args,g-doc-return-or-yield
  default_combinations = (
      framework_combinations.EagerGraphCombination(),
      framework_combinations.TFVersionCombination(),
      ClusterCombination(),
      DistributionCombination(),
      GPUCombination(),
      TPUCombination(),
  )
  # Our multi worker wrapping is applied first and
  # framework.test_combinations.generate on top of it. The order is important
  # since framework.test_combinations.generate must apply all parameter
  # modifiers before the multi worker wrapper sees the arguments.
  combination_decorator = combinations_lib.generate(
      combinations, test_combinations=default_combinations + test_combinations)

  def decorator(test_method_or_class):
    if not isinstance(test_method_or_class, type):
      # A single test method.
      return combination_decorator(_multi_worker_test(test_method_or_class))
    # A test class: wrap every test method with _multi_worker_test. Iterate
    # over a snapshot since the class dict is modified in the loop.
    prefix = unittest.TestLoader.testMethodPrefix
    for attr_name, attr_value in list(test_method_or_class.__dict__.items()):
      if (attr_name.startswith(prefix) and
          isinstance(attr_value, types.FunctionType)):
        setattr(test_method_or_class, attr_name, _multi_worker_test(attr_value))
    return combination_decorator(test_method_or_class)

  return decorator

405 

406 

# Re-export the generic combination helpers so callers only need this module.
combine, times, NamedObject = (
    combinations_lib.combine,
    combinations_lib.times,
    combinations_lib.NamedObject,
)


# Identifies whether we're in the main process (False) or a worker process
# (True). `_multi_worker_test` decoration behaves differently in the main
# process and the worker processes. See the documentation of
# _multi_worker_test for detail.
_running_in_worker = False

416 

417 

@tf_export("__internal__.distribute.combinations.in_main_process", v1=[])
def in_main_process():
  """Reports whether the current process is the main test process.

  Use it to guard test-environment preparation that should only happen in the
  main process, not in spawned worker processes.

  Returns:
    A boolean: False once this process has been marked as a worker, True
    otherwise.
  """
  if _running_in_worker:
    return False
  return True

429 

430 

class TestEnvironment(object):
  """Holds the test environment information.

  Tests should modify the attributes of the instance returned by `env()` in the
  main process if needed; the instance is passed to the worker processes each
  time a test case is run. Setting any attribute outside the main process
  raises ValueError.
  """

  def __init__(self):
    self.tf_data_service_dispatcher = None
    # Note that this includes GPUs that may not be visible to the current
    # worker. (The attribute name keeps its historical spelling.)
    self.total_phsyical_gpus = None

  def __setattr__(self, name, value):
    if in_main_process():
      super().__setattr__(name, value)
      return
    raise ValueError(
        "combinations.env() should only be modified in the main process. "
        "Condition your code on combinations.in_main_process().")


_env = TestEnvironment()

454 

455 

@tf_export("__internal__.distribute.combinations.env", v1=[])
def env():
  """Returns the object holding the test environment information.

  Tests should modify this in the main process if needed, and it will be passed
  to the worker processes each time a test case is run.

  Returns:
    a TestEnvironment object.
  """
  current_env = _env
  return current_env

467 

468 

def _set_total_phsyical_gpus():
  """Records the number of physical GPUs in `env()`; main process only."""
  if not in_main_process():
    return
  gpu_devices = context.context().list_physical_devices("GPU")
  env().total_phsyical_gpus = len(gpu_devices)

473 

474 

# This is needed in case CUDA is lazily loaded: the GPU count is recorded only
# once absl has finished initializing the program.
app.call_after_init(_set_total_phsyical_gpus)


# Outcome that `_test_runner` returns from a worker process: `status` is one of
# "failure", "skipped" or "ok", and `message` is the associated text or None.
_TestResult = collections.namedtuple("_TestResult", "status message")

480 

481 

def _test_runner(test_id, test_env):
  """Executes the test with the given test_id.

  This is a simple wrapper around TestRunner to be used with
  multi_process_runner. Similar to test.main(), but it executes only one test,
  specified by test_id. Failure/error/skip text is also printed to stdout so
  multi_process_runner can stream it back to the main process.

  Args:
    test_id: TestCase.id()
    test_env: a TestEnvironment object.

  Returns:
    A `_TestResult` whose status is "failure", "skipped" or "ok".
  """
  global _running_in_worker, _env
  # Worker processes stay workers for their whole lifetime, so the flag is
  # never restored.
  _running_in_worker = True
  _env = test_env
  suite = unittest.defaultTestLoader.loadTestsFromName(test_id)
  outcome = unittest.TextTestRunner().run(suite)
  # Expected failures are reported as failures so the main process fails as
  # expected; errors are folded in as failures to simplify the handling.
  problems = outcome.failures + outcome.expectedFailures + outcome.errors
  if problems:
    status, message = "failure", problems[0][1]
  elif outcome.skipped:
    status, message = "skipped", outcome.skipped[0][1]
  else:
    # Unexpected successes count as OK so the main process test succeeds too.
    status, message = "ok", None
  if message:
    print(message)
  return _TestResult(status=status, message=message)

522 

523 

def _multi_worker_test(test_method):
  """Decorates test_method so that it runs in each worker.

  `multi_process_runner` is used to simulate multiple workers. This decoration
  is executed both in the main process and in the worker processes, and
  behaves differently in each. In the main process it spawns subprocesses and
  runs the test on each of them. In a worker process it executes the test like
  a normal test, e.g. setUp()/tearDown() are called before/after the test.

  Args:
    test_method: a function which must be a test method.

  Returns:
    Decorated `test_method`. The decorated function additionally accepts the
    `has_chief`, `num_workers`, `num_ps`, `share_gpu` and `runner` arguments.
  """

  def wrapper(self, has_chief, num_workers, num_ps, share_gpu, runner,
              **kwargs):
    # Run the test body right here when the cluster has a single worker, when
    # this process already is a worker, or for PS combinations under XLA
    # (which use an in-process cluster).
    run_in_this_process = (
        _num_total_workers(has_chief, num_workers) == 1 or
        _running_in_worker or
        (test_util.is_xla_enabled() and num_ps > 0))
    if run_in_this_process:
      # For MultiWorkerMirroredStrategy (CollectiveAllReduceStrategy) in graph
      # mode, a session connected to the local server is installed; multi
      # worker graph mode tests cannot use their own graphs or sessions,
      # including self.cached_session(). Since existing tests may already do
      # so, the session is only installed for multi worker tests.
      with _multi_worker_session(kwargs):
        test_method(self, **kwargs)
      return

    # Main process: run the whole *test* (not just test_method) in each
    # subprocess, so that setUp()/tearDown() and every decoration apply. The
    # conceptual call stack is:
    #   [main process]test.main()
    #   [main process]test_runner.run(test)
    #   [main process]wrapper by combinations.generate()
    #   [main process]_multi_worker_test.decorator()
    #     # A sub process goes through the same code path as the main
    #     # process.
    #     [sub process]_test_runner()
    #     [sub process]test_runner.run(test)
    #     [sub process]wrapper by combinations.generate()
    #     [sub process]_multi_worker_test.decorator()
    #       # _running_in_worker is True
    #       [sub process]test_method()
    test_id = self.id()
    if runner:
      results = runner.run(_test_runner, args=(test_id, _env))
    else:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=num_ps,
          has_eval=False)
      ephemeral_runner = multi_process_runner.MultiProcessRunner(
          _test_runner,
          cluster_spec,
          share_gpu=share_gpu,
          args=(test_id, _env),
          dependence_on_chief=has_chief)
      ephemeral_runner.start()
      results = ephemeral_runner.join().return_value

    skip_reason = None
    for outcome in results:
      if outcome.status == "failure":
        # The worker that produced a result is unknown, so fail on the first
        # failure.
        self.fail(outcome.message)
        break
      if outcome.status == "skipped":
        # Remember the reason but keep scanning: another process may have
        # failed instead.
        skip_reason = outcome.message
    if skip_reason is not None:
      self.skipTest(skip_reason)

  argspec = tf_inspect.getfullargspec(test_method)
  cluster_args = ["has_chief", "num_workers", "num_ps", "share_gpu", "runner"]
  decorator_argspec = argspec._replace(args=(argspec.args or []) + cluster_args)
  return tf_decorator.make_decorator(
      test_method, wrapper, decorator_argspec=decorator_argspec)

617 

618 

619def _num_total_workers(has_chief, num_workers): 

620 """Returns the number of workers including the chief.""" 

621 if has_chief: 

622 return num_workers + 1 

623 return num_workers 

624 

625 

def _multi_worker_session(kwargs):
  """Returns a context manager entering a MultiWorkerMirroredStrategy session.

  Args:
    kwargs: a dict. Keyword arguments passed to the test.

  Returns:
    A context manager. If MultiWorkerMirroredStrategy is the one and only one
    strategy in kwargs and it's in graph mode, it's the session that is
    configured for that strategy. Otherwise, it's a no-op context manager.
  """
  found = [
      value for value in kwargs.values()
      if isinstance(value, distribute_lib.StrategyBase)
  ]
  if len(found) > 1:
    logging.warning(
        "The test uses multiple strategies. Skipping "
        "entering a session that is configured for the strategy.")
    return ops.NullContextmanager()
  strategy = found[0] if found else None
  if context.executing_eagerly() or not isinstance(
      strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy):
    return ops.NullContextmanager()
  sess_config = strategy.update_config_proto(
      copy.deepcopy(context.context().config))
  target = strategy.cluster_resolver.master()
  return session.Session(config=sess_config, target=target).as_default()