# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy combinations for combinations.combine()."""
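
# Illustrative usage sketch (not part of the original module): the
# NamedDistribution objects defined below are typically consumed from a test
# via `combinations.generate`, roughly like
#
#   from tensorflow.python.distribute import combinations
#   from tensorflow.python.distribute import strategy_combinations
#
#   @combinations.generate(
#       combinations.combine(
#           distribution=strategy_combinations.all_strategies,
#           mode=["eager"]))
#   def test_something(self, distribution):  # hypothetical test method
#     with distribution.scope():
#       ...  # build and run the replicated computation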

import sys
import unittest
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import cluster_resolver
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import one_device_strategy as one_device_lib
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util as framework_test_util
from tensorflow.python.platform import flags
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import server_lib
from tensorflow.python.util.tf_export import tf_export

_TF_INTERNAL_API_PREFIX = "__internal__.distribute.combinations."

_did_connect_to_cluster = False
_topology = None
CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended)


def _version_chooser(tf1_cls, tf2_cls):

  def creator(*args, **kwargs):
    if tf2.enabled():
      return tf2_cls(*args, **kwargs)
    return tf1_cls(*args, **kwargs)

  return creator


MirroredStrategy = _version_chooser(mirrored_lib.MirroredStrategyV1,
                                    mirrored_lib.MirroredStrategy)
CentralStorageStrategy = _version_chooser(
    central_storage_strategy.CentralStorageStrategyV1,
    central_storage_strategy.CentralStorageStrategy)
OneDeviceStrategy = _version_chooser(one_device_lib.OneDeviceStrategyV1,
                                     one_device_lib.OneDeviceStrategy)
# Only V2 CollectiveAllReduceStrategy combinations are supported.
CollectiveAllReduceStrategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy)


# pylint: disable=missing-docstring
def _get_tpu_strategy_creator(steps_per_run,
                              use_single_core=False,
                              enable_packed_variable=False,
                              enable_spmd_xla_partitioning=False,
                              **kwargs):

  def _create_tpu_strategy():
    FLAGS = flags.FLAGS  # pylint: disable=invalid-name
    global _did_connect_to_cluster
    global _topology

    try:
      # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
      # which case we fall back to the values passed as flags.
      resolver = tpu_cluster_resolver.TPUClusterResolver()
      did_automatically_resolve = True
    except ValueError:
      did_automatically_resolve = False

      # These flags will be defined by tpu_test_wrapper.py.
      resolver = tpu_cluster_resolver.TPUClusterResolver(
          tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
          zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
          project=hasattr(FLAGS, "project") and FLAGS.project or None,
      )

    # Only connect once per process, rather than per test method.
    if not _did_connect_to_cluster:
      if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
        remote.connect_to_cluster(resolver)
        _did_connect_to_cluster = True
      _topology = tpu_strategy_util.initialize_tpu_system(resolver)

    device_assignment = None
    if use_single_core:
      device_assignment = device_assignment_lib.DeviceAssignment(
          _topology,
          core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)

    # steps_per_run is only supported in TF 1.x.
    if tf2.enabled():
      strategy = tpu_lib.TPUStrategyV2(
          resolver,
          device_assignment,
          experimental_spmd_xla_partitioning=enable_spmd_xla_partitioning,
          **kwargs)
    else:
      strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
                                       device_assignment, **kwargs)
    if enable_packed_variable and enable_spmd_xla_partitioning:
      raise ValueError("Packed variables are not compatible with SPMD mode.")
    strategy._enable_packed_variable_in_eager_mode = enable_packed_variable  # pylint: disable=protected-access
    return strategy

  return _create_tpu_strategy
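
# Illustrative note (not part of the original module): the factory returned by
# _get_tpu_strategy_creator is invoked lazily by the TPU NamedDistribution
# entries defined further below, roughly
#
#   create_tpu_strategy = _get_tpu_strategy_creator(steps_per_run=2)
#   strategy = create_tpu_strategy()  # connects/initializes the TPU system
#
# so importing this file alone does not touch any TPU.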


def _mirrored_strategy_with_collective_key_base(devices):
  required_cpus_nums = sum(
      1
      for d in devices
      if tf_device.DeviceSpec.from_string(d).device_type == "CPU"
  )

  # If the required virtual CPUs are not set up yet, configure the logical
  # devices.
  if required_cpus_nums > len(context.context().list_logical_devices("CPU")):
    context._reset_context()  # pylint: disable=protected-access
    test_util.set_logical_devices_to_at_least("CPU", required_cpus_nums)

  # Increase collective base key to avoid key collision across subtests.
  mirrored_lib.MirroredStrategyV1._collective_key_base += 100000
  mirrored_lib.MirroredStrategy._collective_key_base += 100000
  return MirroredStrategy(devices)


def _mirrored_strategy_with_no_merge_call(devices):
  mirrored_lib.MirroredStrategyV1._collective_key_base += 100000
  mirrored_lib.MirroredStrategy._collective_key_base += 100000
  out = MirroredStrategy(devices)
  # Stub out merge call usage.
  out.extended._use_merge_call = lambda: False  # pylint: disable=protected-access
  return out


def _get_multi_worker_mirrored_creator(required_gpus, use_merge_call=True):

  def _create_multi_worker_mirrored():
    tf_config = cluster_resolver.TFConfigClusterResolver()
    master = tf_config.master()
    if tf_config.rpc_layer:
      # Strip off the rpc_layer prefix.
      master = master[len("%s://" % tf_config.rpc_layer):]
    resolver = cluster_resolver.SimpleClusterResolver(
        cluster_spec=tf_config.cluster_spec(),
        task_type=tf_config.task_type,
        task_id=tf_config.task_id,
        master=master,
        environment=tf_config.environment,
        num_accelerators={"GPU": required_gpus},
        rpc_layer=tf_config.rpc_layer or "grpc",
    )
    # Disable health check and coordination service. We don't have a reliable
    # way to shut down the strategy (and thus the strategy health check or
    # coordination service heartbeat) at the end of a test. Turning on the
    # strategy health check or coordination service heartbeat causes some
    # flakiness since we re-create part of the server when creating a strategy,
    # and our tests are capable of handling failures.
    CollectiveAllReduceExtended._enable_check_health = False  # pylint: disable=protected-access
    context.context().configure_coordination_service(service_type="")
    # Always create the strategy in eager mode so that it starts the server and
    # configures the eager context. The eager context can no longer be
    # configured after initialization.
    with context.eager_mode():
      strategy = CollectiveAllReduceStrategy(cluster_resolver=resolver)

    if not use_merge_call:
      strategy.extended._use_merge_call = lambda: False  # pylint: disable=protected-access
    # TODO(b/152320929): Wait for the cluster before proceeding, otherwise
    # collectives may hang if any worker launches collectives before the chief
    # creates the strategy.
    try:
      multi_process_runner.get_barrier().wait()
    except ValueError:
      # If the creator is called in the main process,
      # multi_process_runner.get_barrier() raises ValueError, which is safe to
      # ignore.
      pass
    return strategy

  def skip_if_cannot_start_grpc_server():
    try:
      return _create_multi_worker_mirrored()
    except errors.UnknownError as e:
      if "Could not start gRPC server" in e.message and (
          len(sys.argv) >= 1 and "bazel" in sys.argv[0]):
        raise unittest.SkipTest("Cannot start std servers.")
      else:
        raise

  return skip_if_cannot_start_grpc_server
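
# Illustrative note (not part of the original module): the creator returned by
# _get_multi_worker_mirrored_creator serves two roles below -- as the
# `initializer` of a MultiProcessPoolRunner, so each worker process builds the
# strategy (and starts its gRPC server) up front, and as the strategy factory
# of the MultiWorkerMirrored* NamedDistribution entries, e.g.
#
#   creator = _get_multi_worker_mirrored_creator(required_gpus=0)
#   strategy = creator()  # inside a worker process with TF_CONFIG set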


# Due to b/195615322, FixedShardsPartitioner will wrongly partition
# RNG state, so we use MinSizePartitioner as the default. Maximum RNG
# state size is int64[3] which is 8 * 3 bytes, so we set
# min_shard_bytes to 8 * 3 + 1.
DEFAULT_PARTITIONER = sharded_variable.MinSizePartitioner(
    min_shard_bytes=8 * 3 + 1, max_shards=2)


def _get_ps_strategy_creator(num_workers,
                             num_ps,
                             required_gpus=0,
                             variable_partitioner=DEFAULT_PARTITIONER):

  def _create_ps_strategy(resolver, variable_partitioner):
    return parameter_server_strategy_v2.ParameterServerStrategyV2(
        resolver, variable_partitioner=variable_partitioner)

  def _create_parameter_server():
    if framework_test_util.is_xla_enabled():
      # To address test failures resulting from running XLA with
      # MultiProcessRunner, continue to use an in-process cluster for XLA
      # tests.
      cluster_def = multi_worker_test_base.create_in_process_cluster(
          num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
      resolver = cluster_resolver.SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_def),
          num_accelerators={"GPU": required_gpus},
          rpc_layer="grpc")
      return _create_ps_strategy(resolver, variable_partitioner)
    else:
      tf_config = cluster_resolver.TFConfigClusterResolver()
      cluster_def = tf_config.cluster_spec().as_dict()
      if not cluster_def:
        # When a MultiProcessRunner cluster is used, the cluster is not yet
        # created when the decorator is called. When the test runs, this method
        # is first invoked via the decorator before combinations.py sets up the
        # MultiProcessRunner with the worker and ps tasks. After setup is done,
        # the subprocesses invoke this method again to get the strategy object.
        # We return None when the main thread invokes this method before the
        # cluster is set up.
        # Returning None is fine here, since the main thread will proceed to
        # create the MultiProcessRunner and invoke the decorated tests inside
        # subprocesses.
        return None
      # MultiProcessRunner is already set up and this method is invoked from a
      # subprocess running the actual test.
      resolver = cluster_resolver.SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_def),
          num_accelerators={"GPU": required_gpus},
          task_type=tf_config.task_type,
          task_id=tf_config.task_id,
          environment=tf_config.environment,
          rpc_layer=tf_config.rpc_layer or "grpc")
      if tf_config.task_type in ("worker", "ps"):
        worker_config = config_pb2.ConfigProto()
        worker_config.inter_op_parallelism_threads = 4  # max num_workers + 1

        try:
          server = server_lib.Server(
              cluster_def,
              job_name=tf_config.task_type,
              task_index=tf_config.task_id,
              protocol="grpc",
              config=worker_config)
        except errors.UnknownError as e:
          if "Could not start gRPC server" in e.message:
            raise unittest.SkipTest("Cannot start std servers.")
          else:
            raise

        # Block the process that started a server from exiting.
        server.join()

      return _create_ps_strategy(resolver, variable_partitioner)

  return _create_parameter_server


def _deferred_pool_runner(has_chief,
                          num_workers,
                          initializer=None,
                          share_gpu=True):
  """Returns a callable that returns the pool runner.

  It creates the pool runner only upon first invocation. This avoids creating
  it when this file is imported.

  Args:
    has_chief: whether there should be a chief.
    num_workers: the number of workers excluding the chief.
    initializer: initializer of each process.
    share_gpu: whether to share GPU between the workers.

  Returns:
    A callable that returns the runner.
  """

  container = []

  def get_or_create():
    if not container:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=0,
          has_eval=False)
      runner = multi_process_runner.MultiProcessPoolRunner(
          cluster_spec, initializer=initializer, share_gpu=share_gpu)
      container.append(runner)
    return container[0]

  return get_or_create


# We need to create the strategy in the initializer to start the server before
# any test runs.
_two_worker_pool = _deferred_pool_runner(
    has_chief=True,
    num_workers=1,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0))

# Two-worker pool where each worker gets its own GPU. Useful for testing MWMS
# on a single host.
_two_worker_pool_noshare = _deferred_pool_runner(
    has_chief=True,
    num_workers=1,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0),
    share_gpu=False)
_four_worker_pool = _deferred_pool_runner(
    has_chief=True,
    num_workers=3,
    initializer=_get_multi_worker_mirrored_creator(required_gpus=0))

# pylint: disable=g-long-lambda
default_strategy = combinations.NamedDistribution(
    "Default",
    distribute_lib._get_default_strategy,  # pylint: disable=protected-access
    required_gpus=None)
one_device_strategy = combinations.NamedDistribution(
    "OneDeviceCPU", lambda: OneDeviceStrategy("/cpu:0"), required_gpus=None)
one_device_strategy_gpu = combinations.NamedDistribution(
    "OneDeviceGPU", lambda: OneDeviceStrategy("/gpu:0"), required_gpus=1)
one_device_strategy_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1CPU",
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/cpu:0"),
    required_gpus=None)
one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1GPU",
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/gpu:0"),
    required_gpus=1)
tpu_strategy = combinations.NamedDistribution(
    "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True)
tpu_strategy_packed_var = combinations.NamedDistribution(
    "TPUPackedVar",
    _get_tpu_strategy_creator(steps_per_run=2, enable_packed_variable=True),
    required_tpu=True)
tpu_strategy_spmd = combinations.NamedDistribution(
    "TPUUseSPMD",
    _get_tpu_strategy_creator(
        steps_per_run=2, enable_spmd_xla_partitioning=True),
    required_tpu=True)
tpu_strategy_one_step = combinations.NamedDistribution(
    "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), required_tpu=True)
tpu_strategy_one_core = combinations.NamedDistribution(
    "TPUOneCore",
    _get_tpu_strategy_creator(steps_per_run=2, use_single_core=True),
    required_tpu=True)
tpu_strategy_one_step_one_core = combinations.NamedDistribution(
    "TPUOneStepOneCore",
    _get_tpu_strategy_creator(steps_per_run=1, use_single_core=True),
    required_tpu=True)
cloud_tpu_strategy = combinations.NamedDistribution(
    "CloudTPU",
    _get_tpu_strategy_creator(steps_per_run=2),
    required_tpu=True,
    use_cloud_tpu=True)
mirrored_strategy_with_one_cpu = combinations.NamedDistribution(
    "Mirrored1CPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:0"]))
mirrored_strategy_with_one_gpu = combinations.NamedDistribution(
    "Mirrored1GPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0"]),
    required_gpus=1)
mirrored_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "MirroredCPUAndGPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
mirrored_strategy_with_two_cpus = combinations.NamedDistribution(
    "Mirrored2CPUs",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:0", "/cpu:1"]),
    required_gpus=0)
mirrored_strategy_with_two_gpus = combinations.NamedDistribution(
    "Mirrored2GPUs",
    lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/gpu:1"]),
    required_gpus=2)
mirrored_strategy_with_two_gpus_no_merge_call = combinations.NamedDistribution(
    "Mirrored2GPUsNoMergeCall",
    lambda: _mirrored_strategy_with_no_merge_call(["/gpu:0", "/gpu:1"]),
    required_physical_gpus=2)
# Should call set_virtual_cpus_to_at_least(3) in your test's setUp methods.
# Deprecated, use mirrored_strategy_with_two_cpus instead.
mirrored_strategy_with_cpu_1_and_2 = combinations.NamedDistribution(
    "Mirrored2CPU",
    lambda: _mirrored_strategy_with_collective_key_base(["/cpu:1", "/cpu:2"]))
mirrored_strategy_with_cpu_1_and_2.__doc__ = (
    """Mirrored strategy with 2 virtual CPUs.

    Should set up logical devices before use.
    """)
central_storage_strategy_with_two_gpus = combinations.NamedDistribution(
    "CentralStorage2GPUs",
    lambda: CentralStorageStrategy(["/gpu:0", "/gpu:1"]),
    required_gpus=2)
central_storage_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "CentralStorageCPUAndGPU",
    lambda: CentralStorageStrategy(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
# chief + 1 worker, with CPU.
multi_worker_mirrored_2x1_cpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1CPU",
    _get_multi_worker_mirrored_creator(required_gpus=0),
    has_chief=True,
    num_workers=1,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 1 worker, with 1 GPU each.
multi_worker_mirrored_2x1_gpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1GPU",
    _get_multi_worker_mirrored_creator(required_gpus=1),
    has_chief=True,
    num_workers=1,
    required_gpus=1,
    pool_runner_fn=_two_worker_pool,
    share_gpu=False,
)

# Same as above, but not sharing the GPU between the workers.
multi_worker_mirrored_2x1_gpu_noshare = combinations.NamedDistribution(
    "MultiWorkerMirrored2x1GPUNoShare",
    _get_multi_worker_mirrored_creator(required_gpus=1),
    has_chief=True,
    num_workers=1,
    required_gpus=1,
    pool_runner_fn=_two_worker_pool_noshare,
    share_gpu=False,
)
# chief + 1 worker, with 2 GPUs each.
multi_worker_mirrored_2x2_gpu = combinations.NamedDistribution(
    "MultiWorkerMirrored2x2GPU",
    _get_multi_worker_mirrored_creator(required_gpus=2),
    has_chief=True,
    num_workers=1,
    required_gpus=2,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
multi_worker_mirrored_2x2_gpu_no_merge_call = combinations.NamedDistribution(
    "MultiWorkerMirrored2x2GPUNoMergeCall",
    _get_multi_worker_mirrored_creator(required_gpus=2, use_merge_call=False),
    has_chief=True,
    num_workers=1,
    required_physical_gpus=2,
    pool_runner_fn=_two_worker_pool,
    no_xla=True,
)
# chief + 3 workers, with CPU.
multi_worker_mirrored_4x1_cpu = combinations.NamedDistribution(
    "MultiWorkerMirrored4x1CPU",
    _get_multi_worker_mirrored_creator(required_gpus=0),
    has_chief=True,
    num_workers=3,
    pool_runner_fn=_four_worker_pool,
    no_xla=True,
)


def parameter_server_strategy_fn(name,
                                 num_workers,
                                 num_ps,
                                 required_gpus=0,
                                 variable_partitioner=DEFAULT_PARTITIONER):
  return combinations.NamedDistribution(
      name,
      _get_ps_strategy_creator(
          num_workers=num_workers,
          num_ps=num_ps,
          required_gpus=required_gpus,
          variable_partitioner=variable_partitioner),
      required_gpus=required_gpus,
      num_workers=num_workers,
      has_chief=True,
      num_ps=num_ps)


parameter_server_strategy_3worker_2ps_cpu = parameter_server_strategy_fn(
    "ParameterServer3Worker2PSCPU", num_workers=3, num_ps=2)
parameter_server_strategy_1worker_2ps_cpu = parameter_server_strategy_fn(
    "ParameterServer1Worker2PSCPU", num_workers=1, num_ps=2)
parameter_server_strategy_3worker_2ps_1gpu = parameter_server_strategy_fn(
    "ParameterServer3Worker2PS1GPU", num_workers=3, num_ps=2, required_gpus=1)
parameter_server_strategy_1worker_2ps_1gpu = parameter_server_strategy_fn(
    "ParameterServer1Worker2PS1GPU", num_workers=1, num_ps=2, required_gpus=1)

graph_and_eager_modes = ["graph", "eager"]


# TODO(crccw): remove after tf-nightly picks up the new API.
def set_virtual_cpus_to_at_least(num_virtual_cpus):
  test_util.set_logical_devices_to_at_least("CPU", num_virtual_cpus)


strategies_minus_tpu = [
    default_strategy,
    one_device_strategy,
    one_device_strategy_gpu,
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    central_storage_strategy_with_gpu_and_cpu,
]

strategies_minus_default_and_tpu = [
    one_device_strategy,
    one_device_strategy_gpu,
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
]

tpu_strategies = [
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step,
    tpu_strategy_packed_var,
    cloud_tpu_strategy,
]

all_strategies_minus_default = strategies_minus_default_and_tpu + tpu_strategies

all_strategies = strategies_minus_tpu + tpu_strategies

two_replica_strategies = [
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    multi_worker_mirrored_2x1_cpu,
    multi_worker_mirrored_2x1_gpu,
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step,
    central_storage_strategy_with_gpu_and_cpu,
]

four_replica_strategies = [
    multi_worker_mirrored_2x2_gpu,
    multi_worker_mirrored_4x1_cpu,
]

# TODO(b/159831907): replace with two_replica_strategies after the tests using
# it work with MWMS.
multidevice_strategies = [
    mirrored_strategy_with_gpu_and_cpu,
    mirrored_strategy_with_two_gpus,
    tpu_strategy,  # steps_per_run=2
    tpu_strategy_one_step
]

multiworker_strategies = [
    multi_worker_mirrored_2x1_cpu, multi_worker_mirrored_2x1_gpu,
    multi_worker_mirrored_2x2_gpu
]


def strategy_minus_tpu_combinations():
  return combinations.combine(
      distribution=strategies_minus_tpu, mode=["graph", "eager"])


def tpu_strategy_combinations():
  return combinations.combine(distribution=tpu_strategies, mode=["graph"])


def all_strategy_combinations():
  return strategy_minus_tpu_combinations() + tpu_strategy_combinations()
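
# Illustrative usage sketch (not part of the original module): tests typically
# apply these helpers through the `combinations.generate` decorator, e.g.
#
#   @combinations.generate(all_strategy_combinations())
#   def test_reduce(self, distribution):  # hypothetical test method
#     ...
#
# which parameterizes the test over every named distribution above, in both
# graph and eager modes where applicable.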


def all_strategy_minus_default_and_tpu_combinations():
  return combinations.combine(
      distribution=[
          one_device_strategy, one_device_strategy_gpu,
          mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
      ],
      mode=["graph", "eager"])


def all_strategy_combinations_minus_default():
  return (all_strategy_minus_default_and_tpu_combinations() +
          tpu_strategy_combinations())


tf_export(
    _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_gpu_and_cpu",
    v1=[]).export_constant(__name__,
                           "central_storage_strategy_with_gpu_and_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_two_gpus",
    v1=[]).export_constant(__name__, "central_storage_strategy_with_two_gpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "cloud_tpu_strategy",
    v1=[]).export_constant(__name__, "cloud_tpu_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "default_strategy",
    v1=[]).export_constant(__name__, "default_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_cpu_1_and_2",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_cpu_1_and_2")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_two_cpus",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_two_cpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_gpu_and_cpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_gpu_and_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_one_cpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_one_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_one_gpu",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_one_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_two_gpus",
    v1=[]).export_constant(__name__, "mirrored_strategy_with_two_gpus")
tf_export(
    _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_two_gpus_no_merge_call",
    v1=[]).export_constant(__name__,
                           "mirrored_strategy_with_two_gpus_no_merge_call")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_cpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_gpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_gpu_noshare",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_gpu_noshare")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x2_gpu",
    v1=[]).export_constant(__name__, "multi_worker_mirrored_2x2_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x2_gpu_no_merge_call",
    v1=[]).export_constant(__name__,
                           "multi_worker_mirrored_2x2_gpu_no_merge_call")
tf_export(
    _TF_INTERNAL_API_PREFIX + "one_device_strategy",
    v1=[]).export_constant(__name__, "one_device_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "one_device_strategy_gpu",
    v1=[]).export_constant(__name__, "one_device_strategy_gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy",
    v1=[]).export_constant(__name__, "tpu_strategy")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_3worker_2ps_cpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_3worker_2ps_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_1worker_2ps_cpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_1worker_2ps_cpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_3worker_2ps_1gpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_3worker_2ps_1gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "parameter_server_strategy_1worker_2ps_1gpu",
    v1=[]).export_constant(__name__,
                           "parameter_server_strategy_1worker_2ps_1gpu")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy_one_core",
    v1=[]).export_constant(__name__, "tpu_strategy_one_core")
tf_export(
    _TF_INTERNAL_API_PREFIX + "tpu_strategy_packed_var",
    v1=[]).export_constant(__name__, "tpu_strategy_packed_var")