Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/sharded_variable.py: 32%

379 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""ShardedVariable class.""" 

16import copy 

17import math 

18from typing import Sequence 

19import weakref 

20 

21import numpy as np 

22 

23from tensorflow.python.framework import composite_tensor 

24from tensorflow.python.framework import constant_op 

25from tensorflow.python.framework import dtypes 

26from tensorflow.python.framework import indexed_slices as indexed_slices_lib 

27from tensorflow.python.framework import ops 

28from tensorflow.python.framework import tensor_conversion_registry 

29from tensorflow.python.framework import tensor_shape 

30from tensorflow.python.framework import type_spec 

31from tensorflow.python.ops import array_ops 

32from tensorflow.python.ops import data_flow_ops 

33from tensorflow.python.ops import embedding_ops 

34from tensorflow.python.ops import math_ops 

35from tensorflow.python.ops import partitioned_variables 

36from tensorflow.python.ops import resource_variable_ops 

37from tensorflow.python.ops import variables as variables_lib 

38from tensorflow.python.saved_model import save_context 

39from tensorflow.python.trackable import base as trackable 

40from tensorflow.python.training.saving import saveable_object_util 

41from tensorflow.python.util import dispatch 

42from tensorflow.python.util.tf_export import tf_export 

43 

44 

@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner:
  """Abstract base class that every partitioner derives from.

  A concrete partitioner overrides `__call__` using the following signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Decides how many partitions `shape` gets along each of its axes.

    Example using a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions)  # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.Dtype` indicating the type of the partition value.
      axis: The axis to partition along. Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where the i-th value corresponds to the i-th axis.

    Raises:
      NotImplementedError: Always, in this base class; subclasses override
        this method.
    """
    raise NotImplementedError()

80 

81 

@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that splits one axis into a fixed number of shards.

  The requested shard count is capped by the size of the partitioned axis;
  every other axis is left unpartitioned (one partition).

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)

  """

  def __init__(self, num_shards):
    """Builds a `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    """Returns the per-axis partition counts for `shape`.

    Args:
      shape: the shape to partition; it is measured with `len()` and read
        through `shape.dims[axis].value`.
      dtype: unused; accepted to satisfy the `Partitioner` signature.
      axis: the axis to partition along. Defaults to 0.

    Returns:
      A list with one entry per axis: `min(num_shards, size of axis)` at
      position `axis` and 1 everywhere else.
    """
    del dtype  # The shard count does not depend on the element type.
    rank = len(shape)
    shards_on_axis = min(self._num_shards, shape.dims[axis].value)
    partitions = [1 for _ in range(rank)]
    partitions[axis] = shards_on_axis
    return partitions

112 

113 

@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that guarantees a minimum size for every shard.

  Each shard holds at least `min_shard_bytes`. Within that constraint the
  partitioner creates as many shards as it can (i.e. it keeps shards as small
  as possible), never exceeding the upper bound `max_shards`.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Builds a `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.

    Raises:
      ValueError: If any of the three arguments is smaller than 1.
    """
    if min_shard_bytes < 1:
      message = ('Argument `min_shard_bytes` must be positive. '
                 f'Received: {min_shard_bytes}')
      raise ValueError(message)
    if max_shards < 1:
      message = ('Argument `max_shards` must be positive. '
                 f'Received: {max_shards}')
      raise ValueError(message)
    if bytes_per_string < 1:
      message = ('Argument `bytes_per_string` must be positive. '
                 f'Received: {bytes_per_string}')
      raise ValueError(message)

    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    """Partitions `shape` by delegating to `min_max_variable_partitioner`.

    Args:
      shape: the shape to partition.
      dtype: the dtype of the value being partitioned.
      axis: the axis to partition along. Defaults to 0.

    Returns:
      Whatever the partitioner built by
      `partitioned_variables.min_max_variable_partitioner` returns for
      `(shape, dtype)`.
    """
    partition_fn = partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)
    return partition_fn(shape, dtype)

168 

169 

@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: The maximum number of shards in `int` created taking
        precedence over `max_shard_bytes`. `None` (the default) means no limit.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.

    Raises:
      ValueError: If `max_shard_bytes` or `bytes_per_string` is smaller than 1,
        or if `max_shards` is given (not `None`) and smaller than 1.
    """
    if max_shard_bytes < 1:
      raise ValueError('Argument `max_shard_bytes` must be positive. '
                       f'Received {max_shard_bytes}')
    # Compare against None explicitly: a plain truthiness test would let
    # `max_shards=0` skip this check even though 0 is not a positive count.
    if max_shards is not None and max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    """Partitions `shape` by delegating to `variable_axis_size_partitioner`.

    Args:
      shape: the shape to partition.
      dtype: the dtype of the value being partitioned.
      axis: the axis to partition along. Defaults to 0.

    Returns:
      Whatever the partitioner built by
      `partitioned_variables.variable_axis_size_partitioner` returns for
      `(shape, dtype)`.
    """
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)

229 

230 

class ShardedVariableSpec(type_spec.TypeSpec):
  """`TypeSpec` describing a `ShardedVariable` by the specs of its shards."""

  __slots__ = ['_variable_specs']

  def __init__(self, *variable_specs):
    """Stores the given per-shard specs, in order, as a tuple."""
    self._variable_specs = tuple(variable_specs)

  @property
  def value_type(self):
    """The Python type this spec describes: `ShardedVariable`."""
    return ShardedVariable

  @property
  def _component_specs(self):
    """The per-shard specs this spec was constructed with."""
    return self._variable_specs

  def _serialize(self):
    """Returns the tuple of per-shard specs as the serialized form."""
    return self._variable_specs

  def _to_components(self, value):
    """Returns `value.variables` as the components of `value`."""
    return value.variables

  def _from_components(self, variables):
    """Builds a `ShardedVariable` from the given component variables."""
    return ShardedVariable(variables)

253 

254 

class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable.

  Treats a sequence of variables as consecutive shards, along axis 0, of one
  larger logical variable, and implements slicing, assignment, scatter ops,
  sparse reads and the checkpoint / SavedModel hooks on top of those shards.
  """

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUEmbeddingVariable can't be a CompositeTensor.

  def __init__(self, variables, name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".

    Raises:
      TypeError: If `variables` is not a non-empty sequence of
        `variables.Variable` instances.
      ValueError: If the variables differ in dtype, differ in shape on any axis
        other than the first, or any of them already has save slice info set.
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    if not isinstance(variables, Sequence) or not variables or any(
        not isinstance(v, variables_lib.Variable) for v in variables):
      raise TypeError('Argument `variables` should be a non-empty list of '
                      f'`variables.Variable`s. Received {variables}')

    var_dtypes = {v.dtype for v in variables}
    if len(var_dtypes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}')

    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same shapes '
          'except for the first axis. '
          f'Received shapes: {[v.shape for v in variables]}')
    # The logical shape stacks all shards along axis 0.
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] +
                                           first_var.shape.as_list()[1:])

    # Each shard keeps a weak back-reference to this container.
    for v in variables:
      v._sharded_container = weakref.ref(self)

    # `_var_offsets[i]` is the start position of shard `i` inside the logical
    # variable, with one entry per axis.
    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError(
          '`SaveSliceInfo` should not be set for all elements in argument '
          '`variables`. `ShardedVariable` will infer `SaveSliceInfo` according '
          'to the order of the elements `variables`. '
          f'Received save slice info {save_slice_info}')

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name,
        trainable=self._variables[0].trainable,
        synchronization=variables_lib.VariableSynchronization.NONE,
        aggregation=variables_lib.VariableAggregation.NONE)

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to `Tensor.__getitem__`. Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing of
        the sharded variable.

    Returns:
      The appropriate slice of tensor based on `slice_spec`.

    Raises:
      IndexError: If a slice index is out of bound.
      TypeError: If `slice_spec` contains Tensor.
    """

    # TODO(b/177482728): Support tensor input.
    # TODO(b/177482728): Support slice assign, similar to variable slice assign.

    # Boolean masks are applied to the fully concatenated tensor.
    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    # Only the first-axis entry needs shard-aware handling; the remaining
    # entries are forwarded unchanged to each shard.
    s = slice_spec[0]
    if isinstance(s, slice):
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      # Shards were visited in ascending order; a negative step walks the
      # logical variable backwards, so the pieces must be flipped.
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError(
            f'ShardedVariable: slice index {s} of dimension 0 out of bounds.')
      for i in range(len(self._variables)):
        # Shard `i` owns the half-open range [offset_i, offset_{i+1}). The
        # lower bound must be inclusive: otherwise an index equal to a shard's
        # start offset would fall through to the last shard and be read with a
        # negative local index.
        if i == len(self._variables) - 1 or (s >= self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

    `ShardedVariable` only supports first dimension partitioning, thus
    `slice_spec` must be for first dimension.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing for
      each component variable. None means no slicing.

      For example, given component variables:
        v0 = [0, 1, 2]
        v1 = [3, 4, 5]
        v2 = [6, 7, 8, 9]

      If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
        v0[returned[0]] = [0, 1, 2]
        v1[returned[1]] = [3, 4, 5]
        v2[returned[2]] = [6, 7, 8, 9]
      If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
        v0[returned[0]] = [2]
        v1[returned[1]] = [5]
        returned[2] == None
      If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
        returned[0] == None
        v1[returned[1]] = [5]
        v2[returned[2]] = [9, 7]

    Raises:
      TypeError: If start, stop or step of `slice_spec` is a Tensor.
      ValueError: If the step of `slice_spec` is zero.
    """
    if isinstance(slice_spec.start, ops.Tensor) or isinstance(
        slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.')

    result = []
    # Normalize start, stop and step.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative index, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local_spec is determined
    # by using whatever is smaller between global slice end and variable range
    # end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      # Walk the shards from last to first, mirroring the direction of the
      # slice; the result is flipped afterwards to restore shard order.
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      result.reverse()

    return result

  @property
  def _type_spec(self):
    """A `ShardedVariableSpec` built from the shape and dtype of each shard."""
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object.

    Inside a save context this is instead a one-element list holding the
    full-shape saving variable.
    """
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Assigns to every shard its slice of the full-shape `value`.

    `use_locking`, `name` and `read_value` are accepted but not forwarded to
    the shards. Returns `self`.
    """
    for i, v in enumerate(self._variables):
      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    """Adds to every shard its slice of the full-shape `delta`.

    `use_locking`, `name` and `read_value` are accepted but not forwarded to
    the shards. Returns `self`.
    """
    for i, v in enumerate(self._variables):
      v.assign_add(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    """Subtracts from every shard its slice of the full-shape `delta`.

    `use_locking`, `name` and `read_value` are accepted but not forwarded to
    the shards. Returns `self`.
    """
    for i, v in enumerate(self._variables):
      v.assign_sub(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def _decompose_indices(self, indices):
    """Decompose a global 1D indices into a list of per-variable indices.

    Args:
      indices: a rank-1 tensor of indices into the first axis of the logical
        variable.

    Returns:
      A tuple `(per_var_indices, partition_assignments)`: the shard-local
      indices grouped per shard, and the int32 shard id chosen for each entry
      of `indices`.

    Raises:
      ValueError: If `indices` is not rank 1.
      NotImplementedError: If the shard sizes do not follow "div" sharding.
    """
    if indices.shape.rank != 1:
      raise ValueError(
          'ShardedVariable: indices must be 1D Tensor for sparse operations. '
          f'Received shape: {indices.shape}')

    base = self._shape[0] // len(self._variables)
    extra = self._shape[0] % len(self._variables)

    # Assert that sharding conforms to "div" sharding
    expect_first_dim = [base] * len(self._variables)
    for i in range(extra):
      expect_first_dim[i] = expect_first_dim[i] + 1
    actual_first_dim = [v.shape.as_list()[0] for v in self._variables]
    if expect_first_dim != actual_first_dim:
      raise NotImplementedError(
          'scater_xxx ops are not supported in ShardedVariale that does not '
          'conform to "div" sharding')

    # For index that falls into the partition that has extra 1, assignment is
    # `index // (base + 1)` (no less than `(indices - extra) // base`)
    # For index that falls into the partition that doesn't have extra 1,
    # assignment is `(indices - extra) // base` (no less than
    # `indices // (base + 1)`)
    #
    # Example:
    #   base = 10, extra = 2, partitions: [0, 11), [11, 22), [22, 32)
    #   index = 10 -> partition_assignment = 0
    #   index = 22 -> partition_assignment = 2
    partition_assignments = math_ops.maximum(indices // (base + 1),
                                             (indices - extra) // base)
    local_indices = array_ops.where(partition_assignments < extra,
                                    indices % (base + 1),
                                    (indices - extra) % base)
    # For whatever reason `dynamic_partition` only supports int32
    partition_assignments = math_ops.cast(partition_assignments, dtypes.int32)
    per_var_indices = data_flow_ops.dynamic_partition(local_indices,
                                                      partition_assignments,
                                                      len(self._variables))

    return per_var_indices, partition_assignments

  def _decompose_indexed_slices(self, indexed_slices):
    """Decompose a global `IndexedSlices` into a list of per-variable ones."""
    per_var_indices, partition_assignments = self._decompose_indices(
        indexed_slices.indices)
    per_var_values = data_flow_ops.dynamic_partition(indexed_slices.values,
                                                     partition_assignments,
                                                     len(self._variables))

    return [
        indexed_slices_lib.IndexedSlices(
            values=per_var_values[i], indices=per_var_indices[i])
        for i in range(len(self._variables))
    ]

  # ==================== scatter ops implementations ======================== #

  @staticmethod
  def _shard_op_name(name, shard_index):
    """Returns `<name>/part_<shard_index>`, or None when `name` is None."""
    if name is None:
      return None
    return '{}/part_{}'.format(name, shard_index)

  def _scatter_on_shards(self, method_name, sparse_delta, name):
    """Calls the scatter method `method_name` on each shard with its slices.

    `sparse_delta` is split into one `IndexedSlices` per shard, and each shard's
    method receives its piece together with a per-shard op name. Returns
    `self`.
    """
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      getattr(v, method_name)(
          per_var_sparse_delta[i], name=self._shard_op_name(name, i))
    return self

  # In all scatter ops below, `use_locking` is accepted but not forwarded to
  # the shards.

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_add."""
    return self._scatter_on_shards('scatter_add', sparse_delta, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_div."""
    return self._scatter_on_shards('scatter_div', sparse_delta, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_max."""
    return self._scatter_on_shards('scatter_max', sparse_delta, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_min."""
    return self._scatter_on_shards('scatter_min', sparse_delta, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_mul."""
    return self._scatter_on_shards('scatter_mul', sparse_delta, name)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_sub."""
    return self._scatter_on_shards('scatter_sub', sparse_delta, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_update."""
    return self._scatter_on_shards('scatter_update', sparse_delta, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.batch_scatter_update."""
    return self._scatter_on_shards('batch_scatter_update', sparse_delta, name)

  # ================== scatter ops implementations END ====================== #

  def sparse_read(self, indices, name=None):
    """Implements tf.Variable.sparse_read."""
    per_var_indices, _ = self._decompose_indices(indices)
    result = []
    for i, v in enumerate(self._variables):
      result.append(
          v.sparse_read(per_var_indices[i], name=self._shard_op_name(name, i)))
    # Per-shard reads are concatenated in shard order.
    return array_ops.concat(result, axis=0)

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
      saveables = []
      dims = len(self._variables[0].shape)
      var_offset = [0 for _ in range(dims)]
      for v in self._variables:
        # Each shard is saved as a slice of the full variable at `var_offset`.
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _export_to_saved_model_graph(self, object_map, tensor_map,
                                   options, **kwargs):
    """For implementing `Trackable`."""
    resource_list = []
    # Build a fresh list rather than using `+`: `_variables` may be any
    # sequence (e.g. a tuple), and `tuple + list` would raise a TypeError.
    for v in [*self._variables, self._saving_variable]:
      resource_list.extend(v._export_to_saved_model_graph(  # pylint:disable=protected-access
          object_map, tensor_map, options, **kwargs))
    # In the exported graph this object is represented by a single-shard
    # `ShardedVariable` wrapping the mapped saving variable.
    object_map[self] = ShardedVariable([object_map[self._saving_variable]],
                                       name=self.name)
    return resource_list

  @property
  def _unique_id(self):
    # String-replace to ensure uniqueness for checkpoint tracking
    return self.variables[0]._unique_id.replace('part_0', 'sharded')  # pylint: disable=protected-access

  @property
  def _distribute_strategy(self):
    return self.variables[0]._distribute_strategy  # pylint: disable=protected-access

  @property
  def _shared_name(self):
    return self._name

  @property
  def is_sharded_variable(self):
    return True

  def numpy(self):
    """Copies the values in this ShardedVariable to a NumPy array.

    First converts to a single Tensor using the registered conversion function,
    which concatenates the shards, then uses Tensor.numpy() to convert to
    a NumPy array.

    Returns:
      A NumPy array of the same shape and dtype.
    """
    return _var_to_tensor(self).numpy()

768 

769 

@tf_export('__internal__.distribute.ShardedVariable', v1=[])
class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A group of `Variable`s that together act as the shards of one variable.

  Some variables (for example, large embeddings) are too big to fit on a
  single device and have to be split across several. This class holds a list
  of smaller variables that can each be stored on a separate device (e.g.
  multiple parameter servers), and saves and restores them as though they were
  a single larger variable.

  A `ShardedVariable` saved with one number of shards can be restored from the
  checkpoint into a different number of shards.

  Instances can be exported in SavedModel format with `tf.saved_model.save`,
  and the result can be consumed by programs such as TF serving APIs. Loading
  such a SavedModel with `tf.saved_model.load` is not supported yet.

  Because the number of shards may differ between save and restore (TF serving
  APIs, for instance, restore into one shard for serving efficiency), code
  that uses a `ShardedVariable` inside a `tf.function` should generally not
  assume the shard count stays the same across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    """Builds a `ShardedVariableSpec` from one `VariableSpec` per shard."""
    shard_specs = [
        resource_variable_ops.VariableSpec(shard.shape, shard.dtype)
        for shard in self._variables
    ]
    return ShardedVariableSpec(*shard_specs)

  @classmethod
  def _overload_all_operators(cls):
    """Installs an overload for every overloadable `ops.Tensor` operator.

    `__getitem__` is the one operator that is deliberately left out.
    """
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      if operator != '__getitem__':
        cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):
    """Sets `operator` on `cls` to forward to the `ops.Tensor` implementation.

    Args:
      operator: Name of the operator attribute to look up on `ops.Tensor` and
        to set on `cls`.
    """
    delegate = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      # Operate on the concatenated tensor rather than on the shard container.
      dense = _var_to_tensor(v)
      return delegate(dense, *args, **kwargs)

    setattr(cls, operator, _operator)

  def __tf_experimental_restore_capture__(self, concrete_function,
                                          internal_capture):
    """Returns `None` so that no capture is restored for this object."""
    # Avoid restoring captures for functions that use ShardedVariable - the
    # layer will be recreated during Keras model loading
    # TODO(jmullenbach): support loading models with ShardedVariables using
    # tf.saved_model.load
    del concrete_function, internal_capture  # Unused.
    return None

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    acts_as_resource_variable = True
    return acts_as_resource_variable

  def _write_object_proto(self, proto, options):
    """Writes the object proto on behalf of `self._saving_variable`."""
    saving_variable = self._saving_variable
    resource_variable_ops.write_object_proto_for_resource_variable(
        saving_variable, proto, options, enforce_naming=False)

859 

860 

def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor` by concatenating its shards.

  Args:
    var: The `ShardedVariable` to convert.
    dtype: Optional requested dtype. When given, it must be compatible with
      `var.dtype`.
    name: Ignored.
    as_ref: Must be falsy; conversion to a reference is not supported.

  Returns:
    The concatenation of `var.variables` along axis 0.

  Raises:
    ValueError: If `dtype` is given and is not compatible with `var.dtype`.
    NotImplementedError: If `as_ref` is truthy.
    TypeError: If the current name scope contains the string
      "embedding_lookup".
  """
  del name  # Unused.
  if not (dtype is None or dtype.is_compatible_with(var.dtype)):
    raise ValueError(
        f'Incompatible type conversion requested to type {dtype.name!r} for '
        f'variable of type {var.dtype.name!r}')
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # embedding_lookup ops are overridden for ShardedVariable through the op
  # dispatch mechanism, which only kicks in when the original op raises a
  # TypeError. Because a ShardedVariable can be converted to a tensor via
  # concat, those ops would otherwise silently perform the conversion and
  # never raise. The name scope is therefore inspected to detect that this
  # function is being called from inside an embedding_lookup op, so that a
  # TypeError can be raised.
  # NOTE: This doesn't work in eager mode since op namescope is always cleared
  # in eager. This also breaks if user sets the name of embedding_lookup op
  # with something that doesn't contain str "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this, which should not rely
  # on namescope.
  current_scope = ops.get_name_scope()
  if 'embedding_lookup' in current_scope:
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  shards = var.variables
  return array_ops.concat(shards, axis=0)

888 

889 

# Register `_var_to_tensor` as the tensor conversion function for
# `ShardedVariable`. It reads the value of the variable by concatenating its
# shards, which allows instances of the class to be used as tensors.
tensor_conversion_registry.register_tensor_conversion_function(
    ShardedVariable, _var_to_tensor)

# Install the operator overloads on the class at import time. Every operator
# in `ops.Tensor.OVERLOADABLE_OPERATORS` except `__getitem__` is forwarded to
# the `ops.Tensor` implementation, applied to the converted tensor.
ShardedVariable._overload_all_operators()  # pylint: disable=protected-access

896 

897 

# Override the behavior of embedding_lookup(sharded_variable, ...) so that the
# lookup runs on the individual shard variables.
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  """Calls `embedding_ops.embedding_lookup` on the shards of `params`.

  When `params` is a list, only its first element is used. All remaining
  arguments are forwarded unchanged.
  """
  sharded = params[0] if isinstance(params, list) else params
  shards = sharded.variables
  return embedding_ops.embedding_lookup(shards, ids, partition_strategy, name,
                                        validate_indices, max_norm)

911 

912 

# safe_embedding_lookup_sparse gets its own override so that the
# ShardedVariable is never converted to a tensor.
@dispatch.dispatch_for_api(embedding_ops.safe_embedding_lookup_sparse)
def safe_embedding_lookup_sparse(
    embedding_weights: ShardedVariable,
    sparse_ids,
    sparse_weights=None,
    combiner='mean',
    default_id=None,
    name=None,
    partition_strategy='div',
    max_norm=None,
    allow_fast_lookup=False,
):
  """Forwards to the original op with the shard variables passed as a list."""
  shards = embedding_weights.variables
  # Everything except the weights themselves is passed through untouched.
  passthrough = dict(
      sparse_weights=sparse_weights,
      combiner=combiner,
      default_id=default_id,
      name=name,
      partition_strategy=partition_strategy,
      max_norm=max_norm,
      allow_fast_lookup=allow_fast_lookup,
  )
  return embedding_ops.safe_embedding_lookup_sparse(
      shards, sparse_ids, **passthrough)