# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ShardedVariable class."""
import copy
import math
from typing import Sequence
import weakref

import numpy as np

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices as indexed_slices_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given `shape` and returns the partition results.

    Example of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions)  # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.Dtype` indicating the type of the partition value.
      axis: The axis to partition along. Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where the i-th value corresponds to the i-th axis.
    """
    raise NotImplementedError
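
# Editor's note: the sketch below is illustrative and not part of the original
# TensorFlow source. It shows one way to implement the `Partitioner` contract
# above; the class name `RowChunkPartitioner` and its parameter are
# hypothetical. It assumes `tf` is imported and `Partitioner` refers to the
# class defined above.
#
#   class RowChunkPartitioner(Partitioner):
#     """Splits the given axis into shards of at most `rows_per_shard` rows."""
#
#     def __init__(self, rows_per_shard):
#       self._rows_per_shard = rows_per_shard
#
#     def __call__(self, shape, dtype, axis=0):
#       del dtype  # Partitioning here depends only on the shape.
#       result = [1] * len(shape)
#       result[axis] = max(
#           1, -(-shape.dims[axis].value // self._rows_per_shard))  # ceil div
#       return result
#
#   RowChunkPartitioner(rows_per_shard=4)(tf.TensorShape([10, 3]), tf.float32)
#   # -> [3, 1]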


@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> partitions
  [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)

  """

  def __init__(self, num_shards):
    """Creates a new `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype
    result = [1] * len(shape)
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result


@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least `min_shard_bytes`, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  `max_shards`.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if min_shard_bytes < 1:
      raise ValueError('Argument `min_shard_bytes` must be positive. '
                       f'Received: {min_shard_bytes}')
    if max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received: {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)


@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: `int`, the maximum number of shards to create; takes
        precedence over `max_shard_bytes`.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if max_shard_bytes < 1:
      raise ValueError('Argument `max_shard_bytes` must be positive. '
                       f'Received {max_shard_bytes}')
    if max_shards and max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)


class ShardedVariableSpec(type_spec.TypeSpec):
  """Type specification for a `ShardedVariable`."""

  __slots__ = ['_variable_specs']

  value_type = property(lambda self: ShardedVariable)

  def __init__(self, *variable_specs):
    self._variable_specs = tuple(variable_specs)

  def _serialize(self):
    return self._variable_specs

  @property
  def _component_specs(self):
    return self._variable_specs

  def _to_components(self, value):
    return value.variables

  def _from_components(self, variables):
    return ShardedVariable(variables)
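
# Editor's note: an illustrative sketch, not part of the original source. It
# shows how the spec above decomposes a `ShardedVariable` into its component
# shard variables and rebuilds it; `_to_components` and `_from_components` are
# the private TypeSpec hooks defined above.
#
#   sv = ShardedVariable([tf.Variable([1., 2.]), tf.Variable([3.])])
#   spec = sv._type_spec                      # ShardedVariableSpec of two specs
#   shards = spec._to_components(sv)          # -> the two shard Variables
#   rebuilt = spec._from_components(shards)   # -> an equivalent ShardedVariable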


class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUEmbeddingVariable can't be a CompositeTensor.

  def __init__(self, variables, name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    if not isinstance(variables, Sequence) or not variables or any(
        not isinstance(v, variables_lib.Variable) for v in variables):
      raise TypeError('Argument `variables` should be a non-empty list of '
                      f'`variables.Variable`s. Received {variables}')

    var_dtypes = {v.dtype for v in variables}
    if len(var_dtypes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}')

    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same shapes '
          'except for the first axis. '
          f'Received shapes: {[v.shape for v in variables]}')
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] +
                                           first_var.shape.as_list()[1:])

    for v in variables:
      v._sharded_container = weakref.ref(self)

    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError(
          '`SaveSliceInfo` should not be set for any element of argument '
          '`variables`. `ShardedVariable` will infer `SaveSliceInfo` according '
          'to the order of the elements in `variables`. '
          f'Received save slice info {save_slice_info}')

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name,
        trainable=self._variables[0].trainable,
        synchronization=variables_lib.VariableSynchronization.NONE,
        aggregation=variables_lib.VariableAggregation.NONE)
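
  # Editor's illustration (not part of the original source): for the docstring
  # example above, with shard first dimensions 10, 15 and 5, `__init__` infers
  #   self._var_offsets == [[0, 0], [10, 0], [25, 0]]
  #   self._shape       == TensorShape([30, 100])
  # i.e. shards are laid out contiguously along axis 0 and offsets on all other
  # axes are 0.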

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to `Tensor.__getitem__`. Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing
        of the sharded variable.

    Returns:
      The appropriate slice of tensor based on `slice_spec`.

    Raises:
      IndexError: If a slice index is out of bound.
      TypeError: If `slice_spec` contains a Tensor.
    """
    # TODO(b/177482728): Support tensor input.
    # TODO(b/177482728): Support slice assign, similar to variable slice assign.

    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError(
            f'ShardedVariable: slice index {s} of dimension 0 out of bounds.')
      for i in range(len(self._variables)):
        if i == len(self._variables) - 1 or (s >= self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]
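
  # Editor's note: a small usage sketch for `__getitem__` above, not part of
  # the original source; assumes an eager TF 2.x session.
  #
  #   sv = ShardedVariable([tf.Variable([0., 1., 2.]), tf.Variable([3., 4.])])
  #   sv[1:4]    # -> [1., 2., 3.]: the slice spans both shards, concatenated.
  #   sv[-1]     # -> 4.0: negative indices are normalized first.
  #   sv[::-1]   # -> [4., 3., 2., 1., 0.]: per-shard slices, then reversed.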

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

    `ShardedVariable` only supports first dimension partitioning, thus
    `slice_spec` must be for the first dimension.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing for
      each component variable. None means no slicing.

      For example, given component variables:
        v0 = [0, 1, 2]
        v1 = [3, 4, 5]
        v2 = [6, 7, 8, 9]

      If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
        v0[returned[0]] = [0, 1, 2]
        v1[returned[1]] = [3, 4, 5]
        v2[returned[2]] = [6, 7, 8, 9]
      If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
        v0[returned[0]] = [2]
        v1[returned[1]] = [5]
        returned[2] == None
      If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
        returned[0] == None
        v1[returned[1]] = [5]
        v2[returned[2]] = [9, 7]
    """
    if isinstance(slice_spec.start, ops.Tensor) or isinstance(
        slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.')

    result = []
    # Normalize start, stop and step.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative index, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local spec is determined
    # by using whatever is smaller between the global slice end and the
    # variable range end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))
      result.reverse()

    return result

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object."""
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_add(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_sub(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self
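
  # Editor's note: an illustrative sketch of the assign ops above, not part of
  # the original source. The full-shape `value`/`delta` is sliced at each
  # shard's offset and assigned shard by shard:
  #
  #   sv = ShardedVariable([tf.Variable([1., 2.]), tf.Variable([3., 4., 5.])])
  #   sv.assign(tf.constant([10., 20., 30., 40., 50.]))
  #   # shard 0 -> [10., 20.],  shard 1 -> [30., 40., 50.]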

  def _decompose_indices(self, indices):
    """Decompose global 1D indices into a list of per-variable indices."""
    if indices.shape.rank != 1:
      raise ValueError(
          'ShardedVariable: indices must be 1D Tensor for sparse operations. '
          f'Received shape: {indices.shape}')

    base = self._shape[0] // len(self._variables)
    extra = self._shape[0] % len(self._variables)

    # Assert that the sharding conforms to "div" sharding.
    expect_first_dim = [base] * len(self._variables)
    for i in range(extra):
      expect_first_dim[i] = expect_first_dim[i] + 1
    actual_first_dim = [v.shape.as_list()[0] for v in self._variables]
    if expect_first_dim != actual_first_dim:
      raise NotImplementedError(
          'scatter_xxx ops are not supported in a ShardedVariable that does '
          'not conform to "div" sharding')

    # For an index that falls into a partition with an extra 1, the assignment
    # is `index // (base + 1)` (no less than `(index - extra) // base`).
    # For an index that falls into a partition without an extra 1, the
    # assignment is `(index - extra) // base` (no less than
    # `index // (base + 1)`).
    #
    # Example:
    #   base = 10, extra = 2, partitions: [0, 11), [11, 22), [22, 32)
    #   index = 10 -> partition_assignment = 0
    #   index = 22 -> partition_assignment = 2
    partition_assignments = math_ops.maximum(indices // (base + 1),
                                             (indices - extra) // base)
    local_indices = array_ops.where(partition_assignments < extra,
                                    indices % (base + 1),
                                    (indices - extra) % base)
    # For whatever reason `dynamic_partition` only supports int32.
    partition_assignments = math_ops.cast(partition_assignments, dtypes.int32)
    per_var_indices = data_flow_ops.dynamic_partition(local_indices,
                                                      partition_assignments,
                                                      len(self._variables))

    return per_var_indices, partition_assignments

  def _decompose_indexed_slices(self, indexed_slices):
    """Decompose a global `IndexedSlices` into a list of per-variable ones."""
    per_var_indices, partition_assignments = self._decompose_indices(
        indexed_slices.indices)
    per_var_values = data_flow_ops.dynamic_partition(indexed_slices.values,
                                                     partition_assignments,
                                                     len(self._variables))

    return [
        indexed_slices_lib.IndexedSlices(
            values=per_var_values[i], indices=per_var_indices[i])
        for i in range(len(self._variables))
    ]
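
  # Editor's worked example for `_decompose_indices` above (not part of the
  # original source). With 10 rows split over 3 shards, "div" sharding gives
  # base = 3, extra = 1, shard sizes [4, 3, 3] and partitions [0, 4), [4, 7),
  # [7, 10). Then, for instance:
  #   index 2 -> max(2 // 4, (2 - 1) // 3) = 0, local index 2 % 4       = 2
  #   index 5 -> max(5 // 4, (5 - 1) // 3) = 1, local index (5 - 1) % 3 = 1
  #   index 8 -> max(8 // 4, (8 - 1) // 3) = 2, local index (8 - 1) % 3 = 1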

  # ==================== scatter ops implementations ======================== #

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_add."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_add(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_div."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_div(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_max."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_max(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_min."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_min(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_mul."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_mul(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_sub."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_sub(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_update."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_update(per_var_sparse_delta[i], name=new_name)
    return self

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.batch_scatter_update."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.batch_scatter_update(per_var_sparse_delta[i], name=new_name)
    return self

  # ================== scatter ops implementations END ====================== #
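
  # Editor's note: an illustrative scatter sketch, not part of the original
  # source. The scatter ops above take a `tf.IndexedSlices` with *global*
  # indices, which `_decompose_indexed_slices` splits into per-shard updates
  # (this requires "div" sharding):
  #
  #   sv = ShardedVariable([tf.Variable([1., 2.]), tf.Variable([3., 4.])])
  #   delta = tf.IndexedSlices(values=tf.constant([10., 20.]),
  #                            indices=tf.constant([0, 3]))
  #   sv.scatter_add(delta)
  #   # shard 0 -> [11., 2.],  shard 1 -> [3., 24.]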

  def sparse_read(self, indices, name=None):
    """Implements tf.Variable.sparse_read."""
    per_var_indices, _ = self._decompose_indices(indices)
    result = []
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      result.append(v.sparse_read(per_var_indices[i], name=new_name))
    return array_ops.concat(result, axis=0)

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
      saveables = []
      dims = len(self._variables[0].shape)
      var_offset = [0 for _ in range(dims)]
      for v in self._variables:
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
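
  # Editor's note (illustrative, not part of the original source): for three
  # shards of shapes [10, 100], [15, 100] and [5, 100], the factory above
  # produces one `SaveableObject` per shard whose slice specs are roughly
  #   '30 100 0,10:0,100', '30 100 10,15:0,100', '30 100 25,5:0,100'
  # i.e. the full shape followed by per-axis "offset,length" pairs, which lets
  # a checkpoint be restored into a different number of shards.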

  def _export_to_saved_model_graph(self, object_map, tensor_map,
                                   options, **kwargs):
    """For implementing `Trackable`."""
    resource_list = []
    for v in self._variables + [self._saving_variable]:
      resource_list.extend(v._export_to_saved_model_graph(  # pylint:disable=protected-access
          object_map, tensor_map, options, **kwargs))
    object_map[self] = ShardedVariable([object_map[self._saving_variable]],
                                       name=self.name)
    return resource_list

  @property
  def _unique_id(self):
    # String-replace to ensure uniqueness for checkpoint tracking.
    return self.variables[0]._unique_id.replace('part_0', 'sharded')  # pylint: disable=protected-access

  @property
  def _distribute_strategy(self):
    return self.variables[0]._distribute_strategy  # pylint: disable=protected-access

  @property
  def _shared_name(self):
    return self._name

  @property
  def is_sharded_variable(self):
    return True

  def numpy(self):
    """Copies the values in this ShardedVariable to a NumPy array.

    First converts to a single Tensor using the registered conversion function,
    which concatenates the shards, then uses Tensor.numpy() to convert to
    a NumPy array.

    Returns:
      A NumPy array of the same shape and dtype.
    """
    return _var_to_tensor(self).numpy()


@tf_export('__internal__.distribute.ShardedVariable', v1=[])
class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for `Variables` that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (e.g., multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Objects of this class can be saved to SavedModel format using
  `tf.saved_model.save`. The SavedModel can be used by programs like TF serving
  APIs. It is not yet supported to load the SavedModel with
  `tf.saved_model.load`.

  Since a `ShardedVariable` can be saved and then restored with a different
  number of shards depending on the restore environment (for example, TF
  serving APIs would restore to one shard for serving efficiency), when using
  `ShardedVariable` in a tf.function one should generally not assume it has the
  same number of shards across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...                     signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @classmethod
  def _overload_all_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      if operator == '__getitem__':
        continue

      cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):
    """Delegate an operator overload to `ops.Tensor`."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(_var_to_tensor(v), *args, **kwargs)

    setattr(cls, operator, _operator)
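
  # Editor's note: illustrative only, not part of the original source. After
  # `_overload_all_operators()` runs at module import (see the bottom of this
  # file), Tensor operators work on a ShardedVariable by first converting it
  # with `_var_to_tensor`, e.g.:
  #
  #   sv = ShardedVariable([tf.Variable([1., 2.]), tf.Variable([3.])])
  #   sv + 1.0   # -> [2., 3., 4.]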

  def __tf_experimental_restore_capture__(self, concrete_function,
                                          internal_capture):
    # Avoid restoring captures for functions that use ShardedVariable - the
    # layer will be recreated during Keras model loading.
    # TODO(jmullenbach): support loading models with ShardedVariables using
    # tf.saved_model.load
    return None

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    return True

  def _write_object_proto(self, proto, options):
    resource_variable_ops.write_object_proto_for_resource_variable(
        self._saving_variable, proto, options, enforce_naming=False)


def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor`."""
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
        'Incompatible type conversion requested to type {!r} for variable '
        'of type {!r}'.format(dtype.name, var.dtype.name))
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # We use the op dispatch mechanism to override embedding_lookup ops when
  # called with ShardedVariable. This requires embedding_lookup ops to raise
  # TypeError when called with ShardedVariable. However, since ShardedVariable
  # can be converted to a tensor via concat, embedding_lookup ops would
  # silently do the conversion and never raise a TypeError. To be able to
  # properly raise a TypeError, the name scope is used to detect if this method
  # is called within an embedding_lookup op.
  # NOTE: This doesn't work in eager mode since the op name scope is always
  # cleared in eager. This also breaks if the user sets the name of the
  # embedding_lookup op to something that doesn't contain the string
  # "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this, which should not rely
  # on the name scope.
  if 'embedding_lookup' in ops.get_name_scope():
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  return array_ops.concat(var.variables, axis=0)
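
# Editor's note: an illustrative sketch, not part of the original source. Once
# the conversion function is registered below, a ShardedVariable converts to
# the concatenation of its shards along axis 0:
#
#   sv = ShardedVariable([tf.Variable([1., 2.]), tf.Variable([3.])])
#   tf.convert_to_tensor(sv)   # -> [1., 2., 3.]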


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
tensor_conversion_registry.register_tensor_conversion_function(
    ShardedVariable, _var_to_tensor)

ShardedVariable._overload_all_operators()  # pylint: disable=protected-access


# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)
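
# Editor's note: an illustrative sketch, not part of the original source. The
# dispatch above forwards the lookup to the list of shard variables, so rows
# can be gathered without concatenating the shards. Assuming the shards follow
# the contiguous ("div") layout and TF2's `tf.nn.embedding_lookup` (which uses
# "div" partitioning internally):
#
#   sv = ShardedVariable([tf.Variable([[1., 1.], [2., 2.]]),
#                         tf.Variable([[3., 3.]])])
#   tf.nn.embedding_lookup(sv, tf.constant([0, 2]))
#   # -> [[1., 1.], [3., 3.]]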


# Separately override safe_embedding_lookup_sparse, to avoid conversion of
# ShardedVariable to tensor.
@dispatch.dispatch_for_api(embedding_ops.safe_embedding_lookup_sparse)
def safe_embedding_lookup_sparse(
    embedding_weights: ShardedVariable,
    sparse_ids,
    sparse_weights=None,
    combiner='mean',
    default_id=None,
    name=None,
    partition_strategy='div',
    max_norm=None,
    allow_fast_lookup=False,
):
  """Pass the individual shard variables as a list."""
  return embedding_ops.safe_embedding_lookup_sparse(
      embedding_weights.variables,
      sparse_ids,
      sparse_weights=sparse_weights,
      combiner=combiner,
      default_id=default_id,
      name=name,
      partition_strategy=partition_strategy,
      max_norm=max_norm,
      allow_fast_lookup=allow_fast_lookup)