Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/topology.py: 32%

88 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ====================================== 

15"""Defines the `Topology` class, that describes a TPU fabric topology.""" 

16 

17import numpy as np 

18 

19from tensorflow.core.protobuf.tpu import topology_pb2 

20from tensorflow.python.util.tf_export import tf_export 

21 

22 

23def _tpu_device_name(job, task, device): 

24 """Returns the device name for the TPU `device` on `task` of `job`.""" 

25 if job is None: 

26 return "/task:%d/device:TPU:%d" % (task, device) 

27 else: 

28 return "/job:%s/task:%d/device:TPU:%d" % (job, task, device) 

29 

30 

31def _tpu_host_device_name(job, task): 

32 """Returns the device name for the CPU device on `task` of `job`.""" 

33 if job is None: 

34 return "/task:%d/device:CPU:0" % task 

35 else: 

36 return "/job:%s/task:%d/device:CPU:0" % (job, task) 

37 

38 

@tf_export("tpu.experimental.Topology")
class Topology(object):
  """Describes a set of TPU devices.

  Represents both the shape of the physical mesh, and the mapping between
  TensorFlow TPU devices to physical mesh coordinates.
  """

  def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None):
    """Builds a Topology object.

    If `serialized` is not `None`, the topology is parsed from `serialized` and
    the other arguments are ignored. Otherwise, the topology is computed from
    `mesh_shape` and `device_coordinates`.

    Args:
      serialized: A serialized `TopologyProto`, or `None`. If not `None`, the
        serialized proto is parsed to discover the topology.
      mesh_shape: A sequence of 4 positive integers, or `None`. If not `None`,
        the shape of the TPU topology, in number of cores. Ignored if
        `serialized` is not `None`.
      device_coordinates: A rank 3 numpy array that describes the mapping from
        TensorFlow TPU devices to TPU fabric coordinates, or `None`. If
        specified, array is a rank 3 int32 array with shape
        `[tasks, devices, axis]`.  `tasks` is the number of tasks in the TPU
        cluster, `devices` is the number of TPU devices per task, and `axis` is
        the number of axes in the TPU cluster topology. Each entry gives the
        `axis`-th coordinate in the topology of a task/device pair. TPU
        topologies are 4-dimensional, with dimensions `(x, y, z, core number)`.
        This arg is ignored if `serialized` is not `None`.

    Raises:
      ValueError: If `serialized` does not describe a well-formed topology.
      ValueError: If `serialized` is `None` and `mesh_shape` is not a sequence
        of 4 positive integers.
      ValueError: If `serialized` is `None` and `device_coordinates` is not a
        rank 3 numpy int32 array that describes a valid coordinate mapping.
    """

    self._serialized = serialized

    # The branch is chosen by truthiness, so an empty `serialized` value takes
    # the `mesh_shape` / `device_coordinates` path just like `None` does.
    if serialized:
      self._parse_topology(serialized)
    else:
      self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32)
      self._device_coordinates = np.asarray(device_coordinates, np.int32)
      # Compare the full shape so that scalars and nested sequences are
      # rejected with the documented ValueError.
      if self._mesh_shape.shape != (4,) or np.any(self._mesh_shape < 1):
        raise ValueError("`mesh_shape` must be a sequence of 4 positive "
                         f"entries; got `mesh_shape={self._mesh_shape}`")

      coords_shape = self._device_coordinates.shape
      if (len(coords_shape) != 3 or
          coords_shape[2] != len(self._mesh_shape)):
        # The minor dimension only exists for rank >= 3; reading it
        # unconditionally would raise IndexError while building the message.
        minor_dim = coords_shape[2] if len(coords_shape) > 2 else None
        raise ValueError(
            "`device_coordinates` must be a rank 3 int32 array "
            "with minor dimension equal to the `mesh_shape` rank; "
            "got device_coordinates.shape={} len(device_coordinates.shape)={} "
            "device_coordinates.shape[2]={} mesh_shape={}, len(mesh_shape)={}"
            .format(coords_shape, len(coords_shape), minor_dim,
                    self._mesh_shape, len(self._mesh_shape)))

    self._topology_tasks, self._topology_devices = self._invert_topology()

    # Coordinates of devices that are missing
    self._missing_devices = np.argwhere(self._topology_tasks < 0)

  def _parse_topology(self, serialized):
    """Parses a serialized `TopologyProto` into `self`.

    Sets `self._mesh_shape` and `self._device_coordinates` from the proto.

    Args:
      serialized: A serialized `TopologyProto`.

    Raises:
      ValueError: If the proto's `mesh_shape` does not have 4 positive
        entries, if `num_tasks` or `num_tpu_devices_per_task` is negative, if
        the number of `device_coordinates` entries does not equal
        `num_tasks * num_tpu_devices_per_task * len(mesh_shape)`, or if any
        coordinate is negative.
    """
    proto = topology_pb2.TopologyProto()
    proto.ParseFromString(serialized)

    self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
    if len(self._mesh_shape) != 4 or np.any(self._mesh_shape < 1):
      raise ValueError("`mesh_shape` must be a vector of size 4 with positive "
                       "entries; got {}".format(self._mesh_shape))

    if proto.num_tasks < 0:
      raise ValueError("`num_tasks` must be >= 0; got {}".format(
          proto.num_tasks))
    if proto.num_tpu_devices_per_task < 0:
      raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format(
          proto.num_tpu_devices_per_task))

    mesh_rank = len(proto.mesh_shape)
    expected_coordinates_size = (
        proto.num_tasks * proto.num_tpu_devices_per_task * mesh_rank)
    if len(proto.device_coordinates) != expected_coordinates_size:
      # Report the length of `mesh_shape`, matching the label in the message.
      raise ValueError("`device_coordinates` must have shape num_tasks ({}) * "
                       "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
                       "got shape {}".format(proto.num_tasks,
                                             proto.num_tpu_devices_per_task,
                                             mesh_rank,
                                             len(proto.device_coordinates)))

    coords = np.array(proto.device_coordinates, dtype=np.int32)
    if np.any(coords < 0):
      raise ValueError(
          "All values in `device_coordinates` must be >= 0, got {}"
          .format(coords))
    coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task,
                             mesh_rank))
    self._device_coordinates = coords

  def _invert_topology(self):
    """Inverts a [task,device,axis] topology to coordinate -> task/device maps.

    Coordinates are used directly as array indices and are not range-checked
    here.

    Returns:
      A `(tasks, devices)` pair of int32 arrays of shape `mesh_shape`. Entry
      `[x, y, z, core]` holds the task number (respectively, the device number
      within that task) located at those coordinates, or -1 if no device in
      `device_coordinates` maps there.
    """
    tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32)
    devices = np.full(list(self.mesh_shape), -1, dtype=np.int32)
    for task in range(self.device_coordinates.shape[0]):
      for device in range(self.device_coordinates.shape[1]):
        x, y, z, core = self.device_coordinates[task, device, :]
        tasks[x, y, z, core] = task
        devices[x, y, z, core] = device
    return tasks, devices

  @property
  def mesh_shape(self):
    """A rank 1 int32 array describing the shape of the TPU topology."""
    return self._mesh_shape

  @property
  def mesh_rank(self):
    """Returns the number of dimensions in the mesh."""
    return len(self._mesh_shape)

  @property
  def device_coordinates(self):
    """Describes the mapping from TPU devices to topology coordinates.

    Returns:
      A rank 3 int32 array with shape `[tasks, devices, axis]`.
      `tasks` is the number of tasks in the TPU cluster, `devices` is the number
      of TPU devices per task, and `axis` is the number of axes in the TPU
      cluster topology. Each entry gives the `axis`-th coordinate in the
      topology of a task/device pair. TPU topologies are 4-dimensional, with
      dimensions `(x, y, z, core number)`.
    """
    return self._device_coordinates

  @property
  def missing_devices(self):
    """Array of indices of missing devices."""
    return self._missing_devices

  def task_ordinal_at_coordinates(self, device_coordinates):
    """Returns the TensorFlow task number attached to `device_coordinates`.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.

    Returns:
      Returns the TensorFlow task number that contains the TPU device with those
      physical coordinates.
    """
    return self._topology_tasks[tuple(device_coordinates)]

  def tpu_device_ordinal_at_coordinates(self, device_coordinates):
    """Returns the TensorFlow device number at `device_coordinates`.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.

    Returns:
      Returns the TensorFlow device number, within its task, of the TPU device
      with those physical coordinates.
    """
    return self._topology_devices[tuple(device_coordinates)]

  def cpu_device_name_at_coordinates(self, device_coordinates, job=None):
    """Returns the CPU device attached to a logical core.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.
      job: Optional job name to include in the device string.

    Returns:
      The device string of the CPU on the task found at `device_coordinates`.
    """
    return _tpu_host_device_name(
        job, self._topology_tasks[tuple(device_coordinates)])

  def tpu_device_name_at_coordinates(self, device_coordinates, job=None):
    """Returns the name of the TPU device assigned to a logical core.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.
      job: Optional job name to include in the device string.

    Returns:
      The device string of the TPU device found at `device_coordinates`.
    """
    coords = tuple(device_coordinates)
    return _tpu_device_name(job,
                            self._topology_tasks[coords],
                            self._topology_devices[coords])

  @property
  def num_tasks(self):
    """Returns the number of TensorFlow tasks in the TPU slice."""
    return self._device_coordinates.shape[0]

  @property
  def num_tpus_per_task(self):
    """Returns the number of TPU devices per task in the TPU slice."""
    return self._device_coordinates.shape[1]

  def serialized(self):
    """Returns the serialized form of the topology.

    The value passed to the constructor is returned when it was not `None`;
    otherwise a `TopologyProto` is built from the mesh shape and device
    coordinates, serialized once, and cached for later calls.
    """
    if self._serialized is None:
      proto = topology_pb2.TopologyProto()
      proto.mesh_shape[:] = list(self._mesh_shape)
      proto.num_tasks = self._device_coordinates.shape[0]
      proto.num_tpu_devices_per_task = self._device_coordinates.shape[1]
      proto.device_coordinates.extend(list(self._device_coordinates.flatten()))
      self._serialized = proto.SerializeToString()

    return self._serialized