Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_sharding.py: 22%

106 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Helper library for sharding during TPU compilation."""

from tensorflow.python.framework import tensor_shape

_DEFAULT_NUMBER_OF_SHARDS = 1
_DEFAULT_SHARD_DIMENSION = 0


# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
  """An object used to hold the sharding policy for a Tensor."""

  def __init__(self):
    self._number_of_shards = None
    self._number_of_partitions = 1
    self._shard_dimension = None
    self._frozen = False

  def __str__(self):
    if self.number_of_shards is None or self.shard_dimension is None:
      return "ShardingPolicy(unset)"
    else:
      return ("ShardingPolicy(%d shards dimension %d)" %
              (self.number_of_shards, self.shard_dimension))

  def _fill_default_values(self):
    if self._number_of_shards is None:
      self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
    if self._shard_dimension is None:
      self._shard_dimension = tensor_shape.as_dimension(
          _DEFAULT_SHARD_DIMENSION)

  def freeze(self):
    """Prevents further modification to the sharding policy.

    Any values that have not been set when freeze is called are set to
    defaults. If the ShardingPolicy is already frozen, this is a no-op.
    """
    if not self._frozen:
      self._fill_default_values()
      self._frozen = True
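
  # A minimal usage sketch (a hypothetical illustration, not part of the
  # original module): freezing an unset policy fills in the defaults
  # declared at the top of this file.
  #
  #   policy = ShardingPolicy()
  #   policy.freeze()
  #   policy.number_of_shards  # -> 1, from _DEFAULT_NUMBER_OF_SHARDS
  #   policy.shard_dimension   # -> Dimension(0), from _DEFAULT_SHARD_DIMENSION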

  @property
  def number_of_shards(self):
    """Returns the number of shards in the policy or None if unspecified."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards for the current policy.

    If the policy has been frozen then number_of_shards must match the
    existing setting.

    Args:
      number_of_shards: The number of shards to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_shards
        differs from the frozen value; or if number_of_shards <= 0.
    """
    if self._frozen:
      if self._number_of_shards != number_of_shards:
        raise ValueError(
            f"Can't set sharding policy to use {number_of_shards} shards since "
            f"it has been frozen to use {self._number_of_shards}")
    else:
      if number_of_shards > 0:
        self._number_of_shards = number_of_shards
      else:
        raise ValueError(
            f"Can't set sharding policy to use {number_of_shards} shards; "
            "value must be > 0")

  @property
  def number_of_partitions(self):
    """Returns the number of partitions of the policy or None if unspecified."""
    return self._number_of_partitions

  def set_number_of_partitions(self, number_of_partitions):
    """Sets the number of partitions for the current policy.

    If the policy has been frozen then number_of_partitions must match the
    existing setting.

    Args:
      number_of_partitions: The number of partitions to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_partitions
        differs from the frozen value.
    """
    if self._frozen:
      if self._number_of_partitions != number_of_partitions:
        raise ValueError(
            f"Can't set number_of_partitions to {number_of_partitions} since "
            f"it has been frozen to use {self._number_of_partitions}.")
    else:
      self._number_of_partitions = number_of_partitions

  @property
  def shard_dimension(self):
    """Returns the shard dimension of the policy or None if unspecified."""
    return self._shard_dimension

  def set_shard_dimension(self, shard_dimension):
    """Sets the shard dimension for the current policy.

    If the policy has been frozen then shard_dimension must match the
    existing setting.

    Args:
      shard_dimension: The shard dimension to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and shard_dimension
        differs from the frozen value, or if shard_dimension can't be
        interpreted as a Dimension.
    """
    if self._frozen:
      if self._shard_dimension != shard_dimension:
        raise ValueError(
            "Can't set shard dimension to %d since it has been frozen to "
            "use %d." % (shard_dimension, self._shard_dimension))
    else:
      self._shard_dimension = tensor_shape.as_dimension(shard_dimension)

  def merge(self, other):
    """Merges another sharding policy into the current policy.

    Args:
      other: The policy to merge into this one.

    Raises:
      ValueError: If this policy has been frozen and the merge conflicts with
        the frozen policy.
    """
    if other.number_of_shards is not None:
      self.set_number_of_shards(other.number_of_shards)
    if other.shard_dimension is not None:
      self.set_shard_dimension(other.shard_dimension)
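
  # Hypothetical sketch (not in the original source) of merge's behavior:
  # unset fields are filled in from the other policy, while conflicting
  # values on a frozen policy raise ValueError via the setters above.
  #
  #   a = ShardingPolicy()
  #   a.set_number_of_shards(4)
  #   b = ShardingPolicy()
  #   b.set_shard_dimension(0)
  #   a.merge(b)
  #   str(a)  # -> "ShardingPolicy(4 shards dimension 0)"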

  def get_unpartitioned_shape(self, shape):
    """Returns the shape of an unpartitioned Tensor.

    When given the shape of a 'sharded-size' Tensor, returns the full
    shape of its unpartitioned Tensor.

    Args:
      shape: The shape of the sharded Tensor.

    Returns:
      The shape of the unpartitioned version of the Tensor.

    Raises:
      ValueError: If shape has an unknown size for the sharded dimension.
    """
    shape = tensor_shape.as_shape(shape)
    dims = shape.as_list()
    if (self._shard_dimension is None or self._number_of_partitions is None or
        not dims):
      return None
    if dims[self._shard_dimension] is None:
      raise ValueError(f"Shape {shape.as_list()} must have a fixed size for "
                       f"dimension {self._shard_dimension} that is known.")
    if self._number_of_partitions > 1:
      dims[self._shard_dimension] *= self._number_of_partitions
    return tensor_shape.as_shape(dims)
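
  # Hypothetical sketch (not in the original source): with 2 partitions
  # along dimension 0, a per-partition shape of [4, 16] corresponds to an
  # unpartitioned shape of [8, 16].
  #
  #   policy = ShardingPolicy()
  #   policy.set_shard_dimension(0)
  #   policy.set_number_of_partitions(2)
  #   policy.get_unpartitioned_shape([4, 16])  # -> TensorShape([8, 16])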

  def get_sharded_shape(self, shape, shard_index=None):
    """Returns the shape of a shard of a full Tensor.

    When given the shape of a 'full-size' Tensor, returns the shape of
    the sub-Tensor after it has been sharded. Returns None if the policy
    has not yet been set.

    Args:
      shape: The shape of the full-size Tensor to be sharded.
      shard_index: The index of the shard whose shape should be returned.
        shard_index can be None for sharding policies that use the same
        shape for every shard.

    Returns:
      The shape of the sharded version of the Tensor.

    Raises:
      ValueError: If shard_index is None when shards are of different
        shapes; or shard_index is not None and not
        0 <= shard_index < number_of_shards; or shape does not have at
        least self.shard_dimension + 1 dimensions; or the value of
        shape's shard dimension is not a multiple of
        self.number_of_shards.
    """
    if self._shard_dimension is None or self._number_of_shards is None:
      # Don't raise an error if the config is unset.
      return None
    if shard_index is not None:
      if shard_index < 0 or shard_index >= self.number_of_shards:
        raise ValueError(
            f"Requested shard_index {shard_index}, but shard_index must be in "
            f"[0,{self._number_of_shards}).")
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError(f"Shape {shape} must be a known shape.")
    if ndims <= self._shard_dimension:
      raise ValueError(
          f"Shape {shape.as_list()} does not contain shard_dimension "
          f"{self._shard_dimension}")
    dims = shape.as_list()
    if dims[self._shard_dimension] is None:
      raise ValueError(
          f"Shape {shape.as_list()} must have a fixed size for dimension "
          f"{self._shard_dimension} that is known at construction time.")
    if (dims[self._shard_dimension] % self._number_of_shards) != 0:
      raise ValueError(
          f"Shape {shape.as_list()} cannot be sharded {self._number_of_shards} "
          f"ways along dimension {self._shard_dimension}")
    dims[self._shard_dimension] //= self._number_of_shards
    return tensor_shape.TensorShape(dims)
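
  # Hypothetical sketch (not in the original source): a full shape [8, 16]
  # sharded 4 ways along dimension 0 yields a per-shard shape of [2, 16].
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(4)
  #   policy.set_shard_dimension(0)
  #   policy.get_sharded_shape([8, 16])  # -> TensorShape([2, 16])
  #   policy.get_sharded_shape([9, 16])  # Raises ValueError: 9 % 4 != 0.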

  def _unshard_shape(self, shape):
    """Returns the unsharded shape that would generate a given sharded shape.

    Args:
      shape: The sharded shape to unshard.

    Returns:
      The unsharded shape.

    Raises:
      ValueError: If shape is unknown or does not contain
        self.shard_dimension.
      TypeError: If shape is not convertible to a TensorShape.
    """
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError(f"Shape {shape} must be statically known.")
    if ndims <= self._shard_dimension:
      raise ValueError(f"Shape {shape.as_list()} does not contain "
                       f"shard_dimension {self._shard_dimension}. "
                       f"Rank is too small.")
    dims = shape.as_list()
    dims[self._shard_dimension] *= self._number_of_shards
    return tensor_shape.TensorShape(dims)

  def get_unsharded_shape(self, shapes):
    """Returns the shape of an unsharded Tensor given a list of shards.

    When given a list of shapes of shards, returns the shape of the
    unsharded Tensor that would generate the shards. Sets defaults for the
    policy if number_of_shards or shard_dimension is None.

    Args:
      shapes: The shapes of the Tensor shards to be combined.

    Returns:
      The shape of the unsharded version of the Tensor.

    Raises:
      ValueError: If shapes is not a list of length
        self.number_of_shards; or any element of shapes is not a valid
        shape consistent with the sharding policy; or the list of
        shapes is not a valid sharding of a full shape.
      TypeError: If an element of shapes is not convertible to a
        TensorShape.
    """
    self._fill_default_values()
    if len(shapes) != self.number_of_shards:
      raise ValueError(
          f"Shapes {shapes} is length {len(shapes)} but must be a list of "
          f"length number_of_shards={self.number_of_shards}")
    unsharded_shapes = [self._unshard_shape(s) for s in shapes]
    for i in range(self.number_of_shards - 1):
      if not unsharded_shapes[i].is_compatible_with(
          unsharded_shapes[self.number_of_shards - 1]):
        raise ValueError(
            f"Sharded shapes {shapes} are not consistent shards of a full shape "
            f"sharded {self.number_of_shards} ways along "
            f"dimension {self.shard_dimension}.")
    return unsharded_shapes[0]
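
  # Hypothetical sketch (not in the original source): four shard shapes of
  # [2, 16] along dimension 0 recombine into a full shape of [8, 16].
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(4)
  #   policy.set_shard_dimension(0)
  #   policy.get_unsharded_shape([[2, 16]] * 4)  # -> TensorShape([8, 16])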