Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/test_util.py: 30%
149 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities."""

import collections
import dataclasses
import functools
import io
import itertools
import threading

from absl import app

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.util import nest

try:
  import objgraph  # pylint:disable=g-import-not-at-top
except ImportError:
  objgraph = None


@dataclasses.dataclass
class TestClusterParams:
  cluster: dict
  max_num_worker: int
  max_num_ps: int


def get_cluster_def(cluster_params, num_workers, num_ps):
  if (num_workers > cluster_params.max_num_worker or
      num_ps > cluster_params.max_num_ps):
    raise ValueError("Requesting more servers than the maximum, adjust "
                     "cluster params' max_num_ps and max_num_worker")
  if cluster_params.cluster is None:
    cluster_params.cluster = multi_worker_test_base.create_in_process_cluster(
        num_workers=cluster_params.max_num_worker,
        num_ps=cluster_params.max_num_ps)
  return {
      "worker": cluster_params.cluster["worker"][:num_workers],
      "ps": cluster_params.cluster["ps"][:num_ps],
  }
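

# Illustrative usage sketch (hypothetical helper, not part of the original
# module): tests usually keep one shared TestClusterParams across test cases
# so the in-process cluster is created only once and then sliced per case.
def _example_two_worker_cluster_def():
  """Returns a cluster def with 2 workers and 1 ps from a cached cluster."""
  params = TestClusterParams(cluster=None, max_num_worker=2, max_num_ps=3)
  return get_cluster_def(params, num_workers=2, num_ps=1)

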
def gather(strategy, value):
  """Gathers value from all workers.

  This is intended for tests before we implement an official all-gather API.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure of n-dim `tf.distribute.DistributedValues` of
      `tf.Tensor`, or a `tf.Tensor` if the strategy only has one replica.
      Cannot contain `tf.sparse.SparseTensor`.

  Returns:
    a (n+1)-dim `tf.Tensor`.
  """
  return nest.map_structure(functools.partial(_gather, strategy), value)


def _gather(strategy, value):
  """Gathers a single value."""
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    value = values.PerReplica([ops.convert_to_tensor(value)])
  if not isinstance(strategy.extended,
                    collective_all_reduce_strategy.CollectiveAllReduceExtended):
    return array_ops_stack.stack(value._values)
  assert len(strategy.extended.worker_devices) == len(value._values)
  inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
  return strategy.gather(values.PerReplica(inputs), axis=0)
  # pylint: enable=protected-access
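

# Illustrative sketch (hypothetical helper, not part of the original module):
# stacks the per-replica outputs of `strategy.run` into a single tensor with
# an extra leading replica axis.
def _example_gather(strategy):
  per_replica = strategy.run(lambda: array_ops.zeros([2]))
  return gather(strategy, per_replica)  # shape: (num_replicas, 2)

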
def set_logical_devices_to_at_least(device, num):
  """Ensures there are at least `num` logical devices of type `device`."""
  if num < 1:
    raise ValueError("`num` must be at least 1, not %r" % (num,))
  physical_devices = config.list_physical_devices(device)
  if not physical_devices:
    raise RuntimeError("No {} found".format(device))
  if len(physical_devices) >= num:
    return
  # By default each physical device corresponds to one logical device. We
  # create multiple logical devices for the last physical device so that we
  # have `num` logical devices.
  num = num - len(physical_devices) + 1
  logical_devices = []
  for _ in range(num):
    if device.upper() == "GPU":
      logical_devices.append(
          context.LogicalDeviceConfiguration(memory_limit=2048))
    else:
      logical_devices.append(context.LogicalDeviceConfiguration())
  # Create logical devices from the last device since sometimes the first GPU
  # is the primary graphics card and may have less memory available.
  config.set_logical_device_configuration(physical_devices[-1], logical_devices)


def _set_logical_devices():
  if config.list_physical_devices("GPU"):
    set_logical_devices_to_at_least("GPU", 2)
  if config.list_physical_devices("CPU"):
    set_logical_devices_to_at_least("CPU", 2)


def main(enable_v2_behavior=True, config_logical_devices=True):
  """All-in-one main function for tf.distribute tests."""
  if config_logical_devices:
    app.call_after_init(_set_logical_devices)
  if enable_v2_behavior:
    v2_compat.enable_v2_behavior()
  else:
    v2_compat.disable_v2_behavior()
  multi_process_runner.test_main()
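

# Illustrative usage sketch (hypothetical test file, not part of this module):
# a tf.distribute test module typically delegates to main() instead of
# tf.test.main() so that logical devices and the multi-process runner are set
# up before the test cases run, e.g.
#
#   from tensorflow.python.distribute import test_util
#   ...
#   if __name__ == "__main__":
#     test_util.main()

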
def _op_dependencies(op):
  """Returns the data and control dependencies of a tf.Operation combined."""
  deps = []
  for node in itertools.chain(op.inputs, op.control_inputs):
    if isinstance(node, ops.Tensor):
      node = node.op
    assert isinstance(node, ops.Operation)
    deps.append(node)
  return deps


def topological_sort_operations(operations):
  """Topologically sorts a list of operations.

  This does a topological sort of the operations in a graph. The edges include
  both data dependencies and control dependencies. Note that the edge goes from
  an operation to its dependencies.

  The sort is intentionally unstable, reversing orders of operations and
  dependencies on ties.

  Args:
    operations: a list of tf.Operation in the same graph.

  Returns:
    A map from a tf.Operation to its topological order.
  """
  in_degrees = collections.OrderedDict()
  for op in reversed(operations):
    if op not in in_degrees:
      in_degrees[op] = 0
    for next_op in reversed(_op_dependencies(op)):
      in_degrees[next_op] = in_degrees.get(next_op, 0) + 1
  nexts = []
  for op, in_degree in in_degrees.items():
    if in_degree == 0:
      nexts.append(op)
  order = {}
  next_order = 0
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    order[op] = next_order
    next_order += 1
    for next_op in reversed(_op_dependencies(op)):
      in_degrees[next_op] -= 1
      if in_degrees[next_op] == 0:
        nexts.append(next_op)
  assert len(order) == len(operations)
  return order
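

# Illustrative sketch (hypothetical helper, not part of the original module):
# sorts the ops of a tiny graph. Because the edge direction is from an
# operation to its dependencies, the consumer `c` receives a smaller order
# than the constant ops it depends on.
def _example_topological_sort():
  with ops.Graph().as_default() as g:
    a = array_ops.zeros([2])
    b = array_ops.ones([2])
    c = a + b
    order = topological_sort_operations(g.get_operations())
    assert order[c.op] < order[a.op]
    assert order[c.op] < order[b.op]
    return order

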
def _exists_dependency(start, end):
  """Returns whether there exists a dependency chain from start to end."""
  nexts = [start]
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    for next_op in _op_dependencies(op):
      if next_op == end:
        return True
      nexts.append(next_op)
  return False


def assert_sequential_execution(order, operations):
  """Asserts there's a deterministic execution order between the operations.

  Args:
    order: a map from a tf.Operation to its topological order.
    operations: a list of operations that should be executed sequentially. It
      can be given in any order.
  """
  # Topological ordering guarantees that, if there's a dependency from N_a to
  # N_b, then order[N_a] < order[N_b]. If there does exist a path of
  # dependencies among the operations, it always goes from an operation with a
  # smaller topological order to one with a larger topological order.
  # Therefore, we only need to sort the operations by their topological orders,
  # and verify that there's a path of dependency between adjacent pairs.
  operations = sorted(operations, key=lambda op: order[op])
  for i in range(len(operations) - 1):
    if not _exists_dependency(operations[i], operations[i + 1]):
      print(operations[i].graph.as_graph_def())
      raise AssertionError(
          "No dependency between {} and {}. Graph is dumped to stdout.".format(
              operations[i].name, operations[i + 1].name))
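

# Illustrative sketch (hypothetical helper, not part of the original module):
# verifies that a control dependency forces `second` to run after `first`.
def _example_assert_sequential():
  with ops.Graph().as_default() as g:
    first = array_ops.identity(array_ops.zeros([]))
    with ops.control_dependencies([first]):
      second = array_ops.identity(array_ops.ones([]))
    order = topological_sort_operations(g.get_operations())
    assert_sequential_execution(order, [first.op, second.op])

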
def get_running_threads():
  """Returns a set of all running thread names."""
  running_threads = set()
  for thread in threading.enumerate():
    if thread.name is not None:
      running_threads.add(thread.name)
  return running_threads


def has_thread(prefix, running_threads):
  """Returns whether any thread in `running_threads` starts with `prefix`.

  Args:
    prefix: The prefix of the expected thread name.
    running_threads: A collection of the running thread names.
  """
  for thread in running_threads:
    if thread.startswith(prefix):
      return True
  return False


def show_backref(target, max_depth=3):
  """Returns a dot graph of all the objects that are referencing the target.

  An object reference graph is useful for debugging memory leaks such as
  circular references. objgraph provides a better visualization of the memory
  graph than most Python built-in utilities like gc.get_referrers(), whose
  output is often not human-readable.

  The dot graph is written to a string IO object and can be rendered with
  graphviz on the operating system, e.g. dot -Tpng {$dot_graph} -o output.png

  Args:
    target: The target object for the memory graph.
    max_depth: The maximum depth of the graph. By default 3 layers of
      references are used. Increasing this a lot may result in the graph
      growing too big.

  Returns:
    A string that contains the object reference graph.

  Raises:
    NotImplementedError: if objgraph is not installed.
  """
  if objgraph is None:
    raise NotImplementedError("objgraph is not installed.")
  string_io = io.StringIO()
  objgraph.show_backrefs(target, max_depth=max_depth, output=string_io)
  graph = string_io.getvalue()
  string_io.close()
  return graph
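

# Illustrative sketch (hypothetical helper and path, not part of the original
# module): writes the backreference graph of a leaked object to a .dot file
# that can then be rendered with `dot -Tpng leak.dot -o leak.png`.
def _example_dump_backrefs(leaked_object, path="/tmp/leak.dot"):
  with open(path, "w") as f:
    f.write(show_backref(leaked_object, max_depth=3))

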
def create_per_replica(strategy, value_list):
  """Creates a PerReplica of Tensors from the value_list."""
  if len(strategy.extended.worker_devices) != len(value_list):
    raise ValueError(
        "the length of values must be the same as the number of worker devices")
  tensors = []
  for device, value in zip(strategy.extended.worker_devices, value_list):
    with ops.device(device):
      tensors.append(ops.convert_to_tensor(value))
  return values.PerReplica(tensors)
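

# Illustrative sketch (hypothetical helper, not part of the original module):
# builds a PerReplica with one distinct value per replica, e.g. to feed
# different per-replica inputs into `strategy.run`.
def _example_per_replica_inputs(strategy):
  num_replicas = len(strategy.extended.worker_devices)
  return create_per_replica(strategy, list(range(num_replicas)))

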
def is_tpu_strategy(strategy):
  """Returns whether the strategy is a TPU strategy."""
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))


def reset_context():
  """Resets eager context."""
  context._reset_context()  # pylint: disable=protected-access