Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/test_util.py: 30% of 149 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities."""

import collections
import dataclasses
import functools
import io
import itertools
import threading

from absl import app

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.util import nest

try:
  import objgraph  # pylint:disable=g-import-not-at-top
except ImportError:
  objgraph = None


@dataclasses.dataclass
class TestClusterParams:
  cluster: dict
  max_num_worker: int
  max_num_ps: int

def get_cluster_def(cluster_params, num_workers, num_ps):
  if (num_workers > cluster_params.max_num_worker or
      num_ps > cluster_params.max_num_ps):
    raise ValueError("Requesting more servers than the maximum, adjust "
                     "cluster params' max_num_ps and max_num_worker")
  if cluster_params.cluster is None:
    cluster_params.cluster = multi_worker_test_base.create_in_process_cluster(
        num_workers=cluster_params.max_num_worker,
        num_ps=cluster_params.max_num_ps)
  return {
      "worker": cluster_params.cluster["worker"][:num_workers],
      "ps": cluster_params.cluster["ps"][:num_ps],
  }
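
# Illustrative usage sketch (not part of the original module): a test binary can
# share one in-process cluster across test cases and slice it per case. The
# parameter values below are hypothetical.
#
#   _cluster_params = TestClusterParams(cluster=None, max_num_worker=2,
#                                       max_num_ps=3)
#   cluster_def = get_cluster_def(_cluster_params, num_workers=2, num_ps=1)
#   # cluster_def["worker"] has 2 addresses and cluster_def["ps"] has 1; the
#   # full cluster is created lazily on first use and reused afterwards.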

def gather(strategy, value):
  """Gathers value from all workers.

  This is intended for tests before we implement an official all-gather API.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure of n-dim `tf.distribute.DistributedValues` of
      `tf.Tensor`, or of a `tf.Tensor` if the strategy only has one replica.
      Cannot contain tf.sparse.SparseTensor.

  Returns:
    a (n+1)-dim `tf.Tensor`.
  """
  return nest.map_structure(functools.partial(_gather, strategy), value)
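
# Illustrative usage sketch (not part of the original module): comparing
# per-replica results on the host. The strategy and replica function below are
# hypothetical; with two replicas each returning a [2] tensor, the gathered
# result has shape [2, 2].
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   per_replica = strategy.run(lambda: tf.constant([1.0, 2.0]))
#   gathered = gather(strategy, per_replica)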

def _gather(strategy, value):
  """Gathers a single value."""
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    value = values.PerReplica([ops.convert_to_tensor(value)])
  if not isinstance(strategy.extended,
                    collective_all_reduce_strategy.CollectiveAllReduceExtended):
    return array_ops_stack.stack(value._values)
  assert len(strategy.extended.worker_devices) == len(value._values)
  inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
  return strategy.gather(values.PerReplica(inputs), axis=0)
  # pylint: enable=protected-access

def set_logical_devices_to_at_least(device, num):
  """Creates at least `num` logical devices of the given device type."""
  if num < 1:
    raise ValueError("`num` must be at least 1, not %r" % (num,))
  physical_devices = config.list_physical_devices(device)
  if not physical_devices:
    raise RuntimeError("No {} found".format(device))
  if len(physical_devices) >= num:
    return
  # By default each physical device corresponds to one logical device. We create
  # multiple logical devices for the last physical device so that we have `num`
  # logical devices.
  num = num - len(physical_devices) + 1
  logical_devices = []
  for _ in range(num):
    if device.upper() == "GPU":
      logical_devices.append(
          context.LogicalDeviceConfiguration(memory_limit=2048))
    else:
      logical_devices.append(context.LogicalDeviceConfiguration())
  # Create logical devices from the last device since sometimes the first GPU
  # is the primary graphics card and may have less memory available.
  config.set_logical_device_configuration(physical_devices[-1], logical_devices)
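
# Illustrative usage sketch (not part of the original module): a test that needs
# two devices on a single-GPU machine. Virtual devices must be configured before
# the runtime is initialized, so this is normally called at program startup
# (see `main` below).
#
#   set_logical_devices_to_at_least("GPU", 2)
#   assert len(config.list_logical_devices("GPU")) >= 2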

def _set_logical_devices():
  if config.list_physical_devices("GPU"):
    set_logical_devices_to_at_least("GPU", 2)
  if config.list_physical_devices("CPU"):
    set_logical_devices_to_at_least("CPU", 2)


def main(enable_v2_behavior=True, config_logical_devices=True):
  """All-in-one main function for tf.distribute tests."""
  if config_logical_devices:
    app.call_after_init(_set_logical_devices)
  if enable_v2_behavior:
    v2_compat.enable_v2_behavior()
  else:
    v2_compat.disable_v2_behavior()
  multi_process_runner.test_main()
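
# Illustrative usage sketch (not part of the original module): a tf.distribute
# test file typically delegates to this entry point instead of tf.test.main()
# so that logical devices and the multi-process runner are set up. The test
# file name below is hypothetical.
#
#   # my_strategy_test.py
#   from tensorflow.python.distribute import test_util
#
#   if __name__ == "__main__":
#     test_util.main()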

def _op_dependencies(op):
  """Returns the data and control dependencies of a tf.Operation combined."""
  deps = []
  for node in itertools.chain(op.inputs, op.control_inputs):
    if isinstance(node, ops.Tensor):
      node = node.op
    assert isinstance(node, ops.Operation)
    deps.append(node)
  return deps

def topological_sort_operations(operations):
  """Topologically sorts a list of operations.

  This does a topological sort of the operations in a graph. The edges include
  both data dependencies and control dependencies. Note that the edge goes from
  an operation to its dependencies.

  The sort is intentionally unstable, reversing orders of operations and
  dependencies on ties.

  Args:
    operations: a list of tf.Operation in the same graph.

  Returns:
    A map from a tf.Operation to its topological order.
  """
  in_degrees = collections.OrderedDict()
  for op in reversed(operations):
    if op not in in_degrees:
      in_degrees[op] = 0
    for next_op in reversed(_op_dependencies(op)):
      in_degrees[next_op] = in_degrees.get(next_op, 0) + 1
  nexts = []
  for op, in_degree in in_degrees.items():
    if in_degree == 0:
      nexts.append(op)
  order = {}
  next_order = 0
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    order[op] = next_order
    next_order += 1
    for next_op in reversed(_op_dependencies(op)):
      in_degrees[next_op] -= 1
      if in_degrees[next_op] == 0:
        nexts.append(next_op)
  assert len(order) == len(operations)
  return order
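
# Illustrative sketch (not part of the original module): because edges point
# from an operation to its dependencies, a consumer gets a *smaller* order than
# the ops it depends on; e.g. if C depends on B and B depends on A, then
# order[C] < order[B] < order[A]. A test typically feeds in every operation of
# a graph (the concrete function below is hypothetical):
#
#   graph = my_concrete_function.graph
#   order = topological_sort_operations(graph.get_operations())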

def _exists_dependency(start, end):
  """Returns whether there exists a dependency chain from start to end."""
  nexts = [start]
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    for next_op in _op_dependencies(op):
      if next_op == end:
        return True
      nexts.append(next_op)
  return False

def assert_sequential_execution(order, operations):
  """Asserts there's a deterministic execution order between the operations.

  Args:
    order: a map from a tf.Operation to its topological order.
    operations: a list of operations that should be executed sequentially. It
      can be given in any order.
  """
  # Topological ordering guarantees that, if there's a dependency from N_a to
  # N_b, then order[N_a] < order[N_b]. If there exists a path of dependencies
  # among the operations, it always goes from an operation with a smaller
  # topological order to one with a larger topological order. Therefore, we only
  # need to sort the operations by their topological orders, and verify that
  # there's a path of dependency between adjacent pairs.
  operations = sorted(operations, key=lambda op: order[op])
  for i in range(len(operations) - 1):
    if not _exists_dependency(operations[i], operations[i + 1]):
      print(operations[i].graph.as_graph_def())
      raise AssertionError(
          "No dependency between {} and {}. Graph is dumped to stdout.".format(
              operations[i].name, operations[i + 1].name))
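
# Illustrative usage sketch (not part of the original module): checking that
# collective ops were chained rather than left free to run concurrently. The
# graph and op type below are hypothetical.
#
#   order = topological_sort_operations(graph.get_operations())
#   collectives = [op for op in graph.get_operations()
#                  if op.type == "CollectiveReduceV2"]
#   assert_sequential_execution(order, collectives)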

def get_running_threads():
  """Returns a set of all running thread names."""
  running_threads = set()
  for thread in threading.enumerate():
    if thread.name is not None:
      running_threads.add(thread.name)
  return running_threads


def has_thread(prefix, running_threads):
  """Returns whether any thread in `running_threads` is prefixed with `prefix`.

  Args:
    prefix: The prefix of the expected thread name.
    running_threads: A collection of the running thread names.
  """
  for thread in running_threads:
    if thread.startswith(prefix):
      return True
  return False

def show_backref(target, max_depth=3):
  """Returns a dot graph of all the objects that are referencing the target.

  An object reference graph is useful for debugging memory leaks such as
  circular references. objgraph provides a better visualization of the memory
  graph than most Python built-in utilities like gc.get_referrers(), whose
  output is often not human-readable.

  The dot graph will be written to a string IO object, and can be rendered with
  graphviz on the operating system, e.g. `dot -Tpng {$dot_graph} -o output.png`.

  Args:
    target: The target object for the memory graph.
    max_depth: The maximum depth of the graph. By default 3 layers of references
      are used. Increasing this a lot may result in the graph growing too big.

  Returns:
    A string that contains the object reference graph.

  Raises:
    NotImplementedError: if objgraph is not installed.
  """
  if objgraph is None:
    raise NotImplementedError("objgraph is not installed.")
  string_io = io.StringIO()
  objgraph.show_backrefs(target, max_depth=max_depth, output=string_io)
  graph = string_io.getvalue()
  string_io.close()
  return graph
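
# Illustrative usage sketch (not part of the original module): dumping the
# reference graph of a suspected leak and rendering it offline. The object and
# file names below are hypothetical.
#
#   dot_graph = show_backref(leaked_object, max_depth=4)
#   with open("/tmp/backrefs.dot", "w") as f:
#     f.write(dot_graph)
#   # then on the command line: dot -Tpng /tmp/backrefs.dot -o backrefs.png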

def create_per_replica(strategy, value_list):
  """Creates a PerReplica of Tensors from the value_list."""
  if len(strategy.extended.worker_devices) != len(value_list):
    raise ValueError(
        "the length of values must be the same as the number of worker devices")
  tensors = []
  for device, value in zip(strategy.extended.worker_devices, value_list):
    with ops.device(device):
      tensors.append(ops.convert_to_tensor(value))
  return values.PerReplica(tensors)
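
# Illustrative usage sketch (not part of the original module): hand-building a
# PerReplica input with one value per worker device. The strategy below is
# hypothetical and must have exactly as many worker devices as values.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   per_replica = create_per_replica(strategy, [1.0, 2.0])
#   total = strategy.reduce("SUM", per_replica, axis=None)  # == 3.0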

def is_tpu_strategy(strategy):
  """Returns whether the strategy is a TPU strategy."""
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))


def reset_context():
  """Resets eager context."""
  context._reset_context()  # pylint: disable=protected-access