Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/strategy_combinations.py: 60%

227 statements

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy combinations for combinations.combine()."""

import sys
import unittest
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import cluster_resolver
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import one_device_strategy as one_device_lib
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util as framework_test_util
from tensorflow.python.platform import flags
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import server_lib
from tensorflow.python.util.tf_export import tf_export

_TF_INTERNAL_API_PREFIX = "__internal__.distribute.combinations."

_did_connect_to_cluster = False
_topology = None
CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended)


def _version_chooser(tf1_cls, tf2_cls):

  def creator(*args, **kwargs):
    if tf2.enabled():
      return tf2_cls(*args, **kwargs)
    return tf1_cls(*args, **kwargs)

  return creator


MirroredStrategy = _version_chooser(mirrored_lib.MirroredStrategyV1,
                                    mirrored_lib.MirroredStrategy)
CentralStorageStrategy = _version_chooser(
    central_storage_strategy.CentralStorageStrategyV1,
    central_storage_strategy.CentralStorageStrategy)
OneDeviceStrategy = _version_chooser(one_device_lib.OneDeviceStrategyV1,
                                     one_device_lib.OneDeviceStrategy)
# Only V2 CollectiveAllReduceStrategy combinations are supported.
CollectiveAllReduceStrategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy)
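

# Illustrative sketch (not part of the original module): the aliases above are
# factory functions rather than classes, so call sites stay version-agnostic.
# _example_make_one_device_strategy is a hypothetical helper showing the
# dispatch performed by _version_chooser.
def _example_make_one_device_strategy():
  # Resolves to one_device_lib.OneDeviceStrategy when TF2 behavior is enabled
  # and to one_device_lib.OneDeviceStrategyV1 otherwise.
  return OneDeviceStrategy("/cpu:0")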


# pylint: disable=missing-docstring
def _get_tpu_strategy_creator(steps_per_run,
                              use_single_core=False,
                              enable_packed_variable=False,
                              enable_spmd_xla_paritioning=False,
                              **kwargs):

  def _create_tpu_strategy():
    FLAGS = flags.FLAGS  # pylint: disable=invalid-name
    global _did_connect_to_cluster
    global _topology

    try:
      # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
      # which case we fall back to the values passed as flags.
      resolver = tpu_cluster_resolver.TPUClusterResolver()
      did_automatically_resolve = True
    except ValueError:
      did_automatically_resolve = False

      # These flags will be defined by tpu_test_wrapper.py.
      resolver = tpu_cluster_resolver.TPUClusterResolver(
          tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
          zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
          project=hasattr(FLAGS, "project") and FLAGS.project or None,
      )

    # Only connect once per process, rather than per test method.
    if not _did_connect_to_cluster:
      if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
        remote.connect_to_cluster(resolver)
      _did_connect_to_cluster = True
      _topology = tpu_strategy_util.initialize_tpu_system(resolver)

    device_assignment = None
    if use_single_core:
      device_assignment = device_assignment_lib.DeviceAssignment(
          _topology,
          core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)

    # Steps per run is only supported in TF 1.x.
    if tf2.enabled():
      strategy = tpu_lib.TPUStrategyV2(
          resolver,
          device_assignment,
          experimental_spmd_xla_partitioning=enable_spmd_xla_paritioning,
          **kwargs)
    else:
      strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
                                       device_assignment, **kwargs)
    if enable_packed_variable and enable_spmd_xla_paritioning:
      raise ValueError("Packed Variable is not compatible with SPMD mode")
    strategy._enable_packed_variable_in_eager_mode = enable_packed_variable  # pylint: disable=protected-access
    return strategy

  return _create_tpu_strategy
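

# Illustrative sketch (not part of the original module): the creator is
# deliberately deferred so that TPU discovery and initialization only happen
# when a test actually requests the strategy, e.g. through a NamedDistribution
# defined below. _example_tpu_strategy is a hypothetical helper; invoking the
# returned creator requires access to an actual TPU.
def _example_tpu_strategy():
  creator = _get_tpu_strategy_creator(steps_per_run=2,
                                      enable_packed_variable=True)
  # Nothing touches the TPU until the creator is invoked.
  return creator()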


def _mirrored_strategy_with_collective_key_base(devices):
  required_cpus_nums = sum(
      1
      for d in devices
      if tf_device.DeviceSpec.from_string(d).device_type == "CPU"
  )

  # If the required virtual CPUs are not set up yet, configure the logical
  # devices.
  if required_cpus_nums > len(context.context().list_logical_devices("CPU")):
    context._reset_context()  # pylint: disable=protected-access
    test_util.set_logical_devices_to_at_least("CPU", required_cpus_nums)

  # Increase collective base key to avoid key collision across subtests.
  mirrored_lib.MirroredStrategyV1._collective_key_base += 100000
  mirrored_lib.MirroredStrategy._collective_key_base += 100000
  return MirroredStrategy(devices)


def _mirrored_strategy_with_no_merge_call(devices):
  mirrored_lib.MirroredStrategyV1._collective_key_base += 100000
  mirrored_lib.MirroredStrategy._collective_key_base += 100000
  out = MirroredStrategy(devices)
  # Stub out merge call usage.
  out.extended._use_merge_call = lambda: False  # pylint: disable=protected-access
  return out


def _get_multi_worker_mirrored_creator(required_gpus, use_merge_call=True):

  def _create_multi_worker_mirrored():
    tf_config = cluster_resolver.TFConfigClusterResolver()
    master = tf_config.master()
    if tf_config.rpc_layer:
      # Strip off the rpc_layer prefix.
      master = master[len("%s://" % tf_config.rpc_layer):]
    resolver = cluster_resolver.SimpleClusterResolver(
        cluster_spec=tf_config.cluster_spec(),
        task_type=tf_config.task_type,
        task_id=tf_config.task_id,
        master=master,
        environment=tf_config.environment,
        num_accelerators={"GPU": required_gpus},
        rpc_layer=tf_config.rpc_layer or "grpc",
    )
    # Disable health check and coordination service. We don't have a reliable
    # way to shut down the strategy (and thus the strategy health check or
    # coordination service heartbeat) at the end of a test. Turning on the
    # strategy health check or coordination service heartbeat causes some
    # flakiness since we re-create part of the server when creating a strategy,
    # and our tests are capable of handling failures.
    CollectiveAllReduceExtended._enable_check_health = False  # pylint: disable=protected-access
    context.context().configure_coordination_service(service_type="")
    # Always create the strategy in eager mode so that it starts the server and
    # configures the eager context. The eager context can no longer be
    # configured after initialization.
    with context.eager_mode():
      strategy = CollectiveAllReduceStrategy(cluster_resolver=resolver)

    if not use_merge_call:
      strategy.extended._use_merge_call = lambda: False  # pylint: disable=protected-access
    # TODO(b/152320929): Wait for the cluster before proceeding, otherwise
    # collectives may hang if any worker launches collectives before the chief
    # creates the strategy.
    try:
      multi_process_runner.get_barrier().wait()
    except ValueError:
      # If the creator is called in the main process,
      # multi_process_runner.get_barrier() raises ValueError, which is safe to
      # ignore.
      pass
    return strategy

  def skip_if_cannot_start_grpc_server():
    try:
      return _create_multi_worker_mirrored()
    except errors.UnknownError as e:
      if "Could not start gRPC server" in e.message and (
          len(sys.argv) >= 1 and "bazel" in sys.argv[0]):
        raise unittest.SkipTest("Cannot start std servers.")
      else:
        raise

  return skip_if_cannot_start_grpc_server


# Due to b/195615322, FixedShardsPartitioner will wrongly partition
# RNG state, so we use MinSizePartitioner as the default. Maximum RNG
# state size is int64[3] which is 8 * 3 bytes, so we set
# min_shard_bytes to 8 * 3 + 1.
DEFAULT_PARTITIONER = sharded_variable.MinSizePartitioner(
    min_shard_bytes=8 * 3 + 1, max_shards=2)
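# Worked arithmetic for the choice above (explanatory note, not original code):
# an int64[3] RNG state occupies 3 * 8 = 24 bytes, which is strictly less than
# min_shard_bytes = 25, so MinSizePartitioner should leave it in a single
# shard, while variables larger than 25 bytes may be split into at most
# max_shards = 2 shards.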


def _get_ps_strategy_creator(num_workers,
                             num_ps,
                             required_gpus=0,
                             variable_partitioner=DEFAULT_PARTITIONER):

  def _create_ps_strategy(resolver, variable_partitioner):
    return parameter_server_strategy_v2.ParameterServerStrategyV2(
        resolver, variable_partitioner=variable_partitioner)

  def _create_parameter_server():
    if framework_test_util.is_xla_enabled():
      # To address test failures when running XLA with MultiProcessRunner,
      # continue to use an in-process cluster for XLA tests.
      cluster_def = multi_worker_test_base.create_in_process_cluster(
          num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
      resolver = cluster_resolver.SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_def),
          num_accelerators={"GPU": required_gpus},
          rpc_layer="grpc")
      return _create_ps_strategy(resolver, variable_partitioner)
    else:
      tf_config = cluster_resolver.TFConfigClusterResolver()
      cluster_def = tf_config.cluster_spec().as_dict()
      if not cluster_def:
        # When a MultiProcessRunner cluster is used, the cluster is not created
        # yet when the decorator is called. At test time this method is first
        # invoked via the decorator, before combinations.py sets up the
        # MultiProcessRunner with its worker and ps processes. After setup is
        # done, the subprocesses invoke this method again to get the strategy
        # object, so we return None when the main thread calls it before the
        # cluster exists.
        # Returning None is fine here, since this thread will proceed to create
        # the MultiProcessRunner and invoke the decorated tests inside
        # subprocesses.
        return None
      # MultiProcessRunner is already set up and this method is invoked from a
      # subprocess running the actual test.
      resolver = cluster_resolver.SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_def),
          num_accelerators={"GPU": required_gpus},
          task_type=tf_config.task_type,
          task_id=tf_config.task_id,
          environment=tf_config.environment,
          rpc_layer=tf_config.rpc_layer or "grpc")
      if tf_config.task_type in ("worker", "ps"):
        worker_config = config_pb2.ConfigProto()
        worker_config.inter_op_parallelism_threads = 4  # max num_workers + 1

        try:
          server = server_lib.Server(
              cluster_def,
              job_name=tf_config.task_type,
              task_index=tf_config.task_id,
              protocol="grpc",
              config=worker_config)
        except errors.UnknownError as e:
          if "Could not start gRPC server" in e.message:
            raise unittest.SkipTest("Cannot start std servers.")
          else:
            raise

        # Block the process that starts a server from exiting.
        server.join()

      return _create_ps_strategy(resolver, variable_partitioner)

  return _create_parameter_server


def _deferred_pool_runner(has_chief,
                          num_workers,
                          initializer=None,
                          share_gpu=True):
  """Returns a callable that returns the pool runner.

  It creates the pool runner only upon first invocation. This avoids creating
  it when this file is imported.

  Args:
    has_chief: whether there should be a chief.
    num_workers: the number of workers excluding the chief.
    initializer: initializer of each process.
    share_gpu: whether to share GPU between the workers.

  Returns:
    A callable that returns the runner.
  """

  container = []

  def get_or_create():
    if not container:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=0,
          has_eval=False)
      runner = multi_process_runner.MultiProcessPoolRunner(
          cluster_spec, initializer=initializer, share_gpu=share_gpu)
      container.append(runner)
    return container[0]

  return get_or_create


# We need to create the strategy in the initializer to start the server before
# any test runs.
_two_worker_pool = _deferred_pool_runner(
    has_chief=True,
    num_workers=1,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0))

# Two-worker pool where each worker gets its own GPU. Useful for testing MWMS
# on a single host.
_two_worker_pool_noshare = _deferred_pool_runner(
    has_chief=True,
    num_workers=1,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0),
    share_gpu=False)
_four_worker_pool = _deferred_pool_runner(
    has_chief=True,
    num_workers=3,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0))

# pylint: disable=g-long-lambda
default_strategy = combinations.NamedDistribution(
    "Default",
    distribute_lib._get_default_strategy,  # pylint: disable=protected-access
    required_gpus=None)
one_device_strategy = combinations.NamedDistribution(
    "OneDeviceCPU", lambda: OneDeviceStrategy("/cpu:0"), required_gpus=None)
one_device_strategy_gpu = combinations.NamedDistribution(
    "OneDeviceGPU", lambda: OneDeviceStrategy("/gpu:0"), required_gpus=1)
one_device_strategy_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1CPU",
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/cpu:0"),
    required_gpus=None)
one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1GPU",
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/gpu:0"),
    required_gpus=1)
tpu_strategy = combinations.NamedDistribution(
    "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True)
tpu_strategy_packed_var = combinations.NamedDistribution(
    "TPUPackedVar",
    _get_tpu_strategy_creator(steps_per_run=2, enable_packed_variable=True),
    required_tpu=True)
tpu_strategy_spmd = combinations.NamedDistribution(
    "TPUUseSPMD",
    _get_tpu_strategy_creator(
        steps_per_run=2, enable_spmd_xla_paritioning=True),
    required_tpu=True)
tpu_strategy_one_step = combinations.NamedDistribution(
    "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), required_tpu=True)
tpu_strategy_one_core = combinations.NamedDistribution(
    "TPUOneCore",
    _get_tpu_strategy_creator(steps_per_run=2, use_single_core=True),
    required_tpu=True)
tpu_strategy_one_step_one_core = combinations.NamedDistribution(
    "TPUOneStepOneCore",
    _get_tpu_strategy_creator(steps_per_run=1, use_single_core=True),
    required_tpu=True)
cloud_tpu_strategy = combinations.NamedDistribution(
    "CloudTPU",
    _get_tpu_strategy_creator(steps_per_run=2),
    required_tpu=True,
    use_cloud_tpu=True)
mirrored_strategy_with_one_cpu = combinations.NamedDistribution(
    "Mirrored1CPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:0"]))
mirrored_strategy_with_one_gpu = combinations.NamedDistribution(
    "Mirrored1GPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0"]),
    required_gpus=1)
mirrored_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "MirroredCPUAndGPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
mirrored_strategy_with_two_cpus = combinations.NamedDistribution(
    "Mirrored2CPUs",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:0", "/cpu:1"]),
    required_gpus=0)
mirrored_strategy_with_two_gpus = combinations.NamedDistribution(
    "Mirrored2GPUs",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/gpu:1"]),
    required_gpus=2)
mirrored_strategy_with_two_gpus_no_merge_call = combinations.NamedDistribution(
    "Mirrored2GPUsNoMergeCall",
    lambda: _mirrored_strategy_with_no_merge_call(["/gpu:0", "/gpu:1"]),
    required_physical_gpus=2)
# Tests should call set_virtual_cpus_to_at_least(3) in their setUp methods.
# Deprecated, use mirrored_strategy_with_two_cpus instead.
mirrored_strategy_with_cpu_1_and_2 = combinations.NamedDistribution(
    "Mirrored2CPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:1", "/cpu:2"]))
mirrored_strategy_with_cpu_1_and_2.__doc__ = (
    """Mirrored strategy with 2 virtual CPUs.

    Should set up logical devices before use.
    """)
central_storage_strategy_with_two_gpus = combinations.NamedDistribution(
    "CentralStorage2GPUs",
    lambda: CentralStorageStrategy(["/gpu:0", "/gpu:1"]),
    required_gpus=2)
central_storage_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "CentralStorageCPUAndGPU",
    lambda: CentralStorageStrategy(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
# chief + 1 worker, with CPU.
multi_worker_mirrored_2x1_cpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1CPU",
    _get_multi_worker_mirrored_creator(required_gpus=0),
    has_chief=True,
    num_workers=1,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 1 worker, with 1 GPU each.
multi_worker_mirrored_2x1_gpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1GPU",
    _get_multi_worker_mirrored_creator(required_gpus=1),
    has_chief=True,
    num_workers=1,
    required_gpus=1,
    pool_runner_fn=_two_worker_pool,
    share_gpu=False,
)

# Same as above, but not sharing the GPU between the workers.
multi_worker_mirrored_2x1_gpu_noshare = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1GPUNoShare",
    _get_multi_worker_mirrored_creator(required_gpus=1),
    has_chief=True,
    num_workers=1,
    required_gpus=1,
    pool_runner_fn=_two_worker_pool_noshare,
    share_gpu=False,
)
# chief + 1 worker, with 2 GPUs each.
multi_worker_mirrored_2x2_gpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x2GPU",
    _get_multi_worker_mirrored_creator(required_gpus=2),
    has_chief=True,
    num_workers=1,
    required_gpus=2,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
multi_worker_mirrored_2x2_gpu_no_merge_call = combinations.NamedDistribution(
    "MultiWorkerMirrored2x2GPUNoMergeCall",
    _get_multi_worker_mirrored_creator(required_gpus=2, use_merge_call=False),
    has_chief=True,
    num_workers=1,
    required_physical_gpus=2,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 3 workers, with CPU.
multi_worker_mirrored_4x1_cpu = combinations.NamedDistribution(
    "MultiWorkerMirrored4x1CPU",
    _get_multi_worker_mirrored_creator(required_gpus=0),
    has_chief=True,
    num_workers=3,
    pool_runner_fn=_four_worker_pool,
    no_xla=True,
)

494 

495def parameter_server_strategy_fn(name, 

496 num_workers, 

497 num_ps, 

498 required_gpus=0, 

499 variable_partitioner=DEFAULT_PARTITIONER): 

500 return combinations.NamedDistribution( 

501 name, 

502 _get_ps_strategy_creator( 

503 num_workers=num_workers, 

504 num_ps=num_ps, 

505 required_gpus=required_gpus, 

506 variable_partitioner=variable_partitioner), 

507 required_gpus=required_gpus, 

508 num_workers=num_workers, 

509 has_chief=True, 

510 num_ps=num_ps) 

511 

512 

513parameter_server_strategy_3worker_2ps_cpu = parameter_server_strategy_fn( 

514 "ParameterServer3Worker2PSCPU", num_workers=3, num_ps=2) 

515parameter_server_strategy_1worker_2ps_cpu = parameter_server_strategy_fn( 

516 "ParameterServer1Worker2PSCPU", num_workers=1, num_ps=2) 

517parameter_server_strategy_3worker_2ps_1gpu = parameter_server_strategy_fn( 

518 "ParameterServer3Worker2PS1GPU", num_workers=3, num_ps=2, required_gpus=1) 

519parameter_server_strategy_1worker_2ps_1gpu = parameter_server_strategy_fn( 

520 "ParameterServer1Worker2PS1GPU", num_workers=1, num_ps=2, required_gpus=1) 

521 

522graph_and_eager_modes = ["graph", "eager"] 

523 

524 

525# TODO(crccw): remove after tf-nightly picks up the new API. 

526def set_virtual_cpus_to_at_least(num_virtual_cpus): 

527 test_util.set_logical_devices_to_at_least("CPU", num_virtual_cpus) 

528 

529 

530strategies_minus_tpu = [ 

531 default_strategy, 

532 one_device_strategy, 

533 one_device_strategy_gpu, 

534 mirrored_strategy_with_gpu_and_cpu, 

535 mirrored_strategy_with_two_gpus, 

536 central_storage_strategy_with_gpu_and_cpu, 

537] 

538 

539strategies_minus_default_and_tpu = [ 

540 one_device_strategy, 

541 one_device_strategy_gpu, 

542 mirrored_strategy_with_gpu_and_cpu, 

543 mirrored_strategy_with_two_gpus, 

544] 

545 

546tpu_strategies = [ 

547 tpu_strategy, # steps_per_run=2 

548 tpu_strategy_one_step, 

549 tpu_strategy_packed_var, 

550 cloud_tpu_strategy, 

551] 

552 

553all_strategies_minus_default = strategies_minus_default_and_tpu + tpu_strategies 

554 

555all_strategies = strategies_minus_tpu + tpu_strategies 

556 

557two_replica_strategies = [ 

558 mirrored_strategy_with_gpu_and_cpu, 

559 mirrored_strategy_with_two_gpus, 

560 multi_worker_mirrored_2x1_cpu, 

561 multi_worker_mirrored_2x1_gpu, 

562 tpu_strategy, # steps_per_run=2 

563 tpu_strategy_one_step, 

564 central_storage_strategy_with_gpu_and_cpu, 

565] 

566 

567four_replica_strategies = [ 

568 multi_worker_mirrored_2x2_gpu, 

569 multi_worker_mirrored_4x1_cpu, 

570] 

571 

572# TODO(b/159831907): replace with two_replica_strategies after the tests using 

573# it work with MWMS. 

574multidevice_strategies = [ 

575 mirrored_strategy_with_gpu_and_cpu, 

576 mirrored_strategy_with_two_gpus, 

577 tpu_strategy, # steps_per_run=2 

578 tpu_strategy_one_step 

579] 

580 

581multiworker_strategies = [ 

582 multi_worker_mirrored_2x1_cpu, multi_worker_mirrored_2x1_gpu, 

583 multi_worker_mirrored_2x2_gpu 

584] 

585 

586 

587def strategy_minus_tpu_combinations(): 

588 return combinations.combine( 

589 distribution=strategies_minus_tpu, mode=["graph", "eager"]) 

590 

591 

592def tpu_strategy_combinations(): 

593 return combinations.combine(distribution=tpu_strategies, mode=["graph"]) 

594 

595 

596def all_strategy_combinations(): 

597 return strategy_minus_tpu_combinations() + tpu_strategy_combinations() 

598 

599 

600def all_strategy_minus_default_and_tpu_combinations(): 

601 return combinations.combine( 

602 distribution=[ 

603 one_device_strategy, one_device_strategy_gpu, 

604 mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus 

605 ], 

606 mode=["graph", "eager"]) 

607 

608 

609def all_strategy_combinations_minus_default(): 

610 return (all_strategy_minus_default_and_tpu_combinations() + 

611 tpu_strategy_combinations()) 
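

# Illustrative usage sketch (not part of the original module): test modules
# typically consume the combinations built above through combinations.generate.
# The test class below is hypothetical; only the combine()/generate() pattern
# reflects how these NamedDistribution objects are meant to be used.
#
#   from absl.testing import parameterized
#   from tensorflow.python.distribute import combinations
#   from tensorflow.python.distribute import strategy_combinations
#   from tensorflow.python.platform import test
#
#   class MyDistributedTest(test.TestCase, parameterized.TestCase):
#
#     @combinations.generate(
#         combinations.combine(
#             distribution=strategy_combinations.all_strategies,
#             mode=["eager"]))
#     def testReduce(self, distribution):
#       ...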


tf_export(
    _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_gpu_and_cpu",
    v1=[]).export_constant(__name__,
                           "central_storage_strategy_with_gpu_and_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_two_gpus",
    v1=[]).export_constant(__name__, "central_storage_strategy_with_two_gpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "cloud_tpu_strategy",
    v1=[]).export_constant(__name__, "cloud_tpu_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "default_strategy",
    v1=[]).export_constant(__name__, "default_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_cpu_1_and_2",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_cpu_1_and_2")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_two_cpus",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_two_cpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_gpu_and_cpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_gpu_and_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_one_cpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_one_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_one_gpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_one_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_two_gpus",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_two_gpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_two_gpus_no_merge_call",
    v1=[]).export_constant(__name__,
                           "mirrored_strategy_with_two_gpus_no_merge_call")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_cpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_gpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_gpu_noshare",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_gpu_noshare")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x2_gpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x2_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x2_gpu_no_merge_call",
    v1=[]).export_constant(__name__,
                           "multi_worker_mirrored_2x2_gpu_no_merge_call")
tf_export(
    _TF_INTERNAL_API_PREFIX + "one_device_strategy",
    v1=[]).export_constant(__name__, "one_device_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "one_device_strategy_gpu",
    v1=[]).export_constant(__name__, "one_device_strategy_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy",
    v1=[]).export_constant(__name__, "tpu_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_3worker_2ps_cpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_3worker_2ps_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_1worker_2ps_cpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_1worker_2ps_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_3worker_2ps_1gpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_3worker_2ps_1gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_1worker_2ps_1gpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_1worker_2ps_1gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy_one_core",
    v1=[]).export_constant(__name__, "tpu_strategy_one_core")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy_packed_var",
    v1=[]).export_constant(__name__, "tpu_strategy_packed_var")