Coverage report for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/ps_values.py: 42% of 443 statements covered (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000).
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values for PS."""

import contextlib
import copy
import threading
import weakref

import numpy as np

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.distribute.coordinator import coordinator_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util.lazy_loader import LazyLoader

load_context = LazyLoader(
    "load_context", globals(),
    "tensorflow.python.keras.saving.saved_model.load_context"
)

TRACKABLE_RESOURCE_METHODS = [
    "_create_resource", "_initialize", "_destroy_resource"
]
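
# --- Hedged sketch (illustration only; not part of the original file) --------
# LazyLoader defers the actual import until the first attribute access, which
# is how the keras `load_context` module above can be referenced without
# importing keras at module-import time. A minimal sketch, assuming LazyLoader
# accepts (local_name, parent_module_globals, name) as used above; `json` here
# is just a stand-in target module and `_lazy_loader_sketch` is a hypothetical
# name.
def _lazy_loader_sketch():
  lazy_json = LazyLoader("lazy_json", globals(), "json")
  # The real `json` module is only imported when an attribute is first used.
  return lazy_json.dumps({"lazy": True})
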

# Variable used in PSStrategy TF 1, TF2 and CentralStorageStrategy.
class AggregatingVariable(resource_variable_ops.BaseResourceVariable,
                          core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation
  def __deepcopy__(self, memo):
    """Perform a deepcopy of the `AggregatingVariable`.

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the `AggregatingVariable`. To avoid confusion
    with the behavior of deepcopy on a regular `Variable` (which does
    copy into new devices), we only allow a deepcopy of an
    `AggregatingVariable` within its originating strategy scope.

    Args:
      memo: The memoization object for `deepcopy`.

    Returns:
      A deep copy of the current `AggregatingVariable`.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
      v = copy.deepcopy(self._v, memo)

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        v=v,
        aggregation=self._aggregation)

    memo[id(self)] = copied_variable

    return copied_variable
  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if distribute_lib.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it
        # in an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = distribute_lib.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                  "use_locking": use_locking,
                  "name": name,
                  "read_value": read_value
              })

        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  @property
  def op(self):
    return self._v.op

  def value(self):
    return self._v.value()

  def read_value(self):
    return self._v.read_value()

  def sparse_read(self, indices, name=None):
    return self._v.sparse_read(indices, name=name)

  def eval(self, session=None):
    return self._v.eval(session)

  @property
  def graph(self):
    return self._v.graph

  @property
  def device(self):
    return self._v.device

  @property
  def shape(self):
    return self._v.shape

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def synchronization(self):
    return self._v.synchronization

  @property
  def name(self):
    return self._v.name

  @property
  def trainable(self):
    return self._v.trainable

  @property
  def dtype(self):
    return self._v.dtype

  # TODO(josh11b): Test saving & restoring.
  def _gather_saveables_for_checkpoint(self):
    if isinstance(self._v, CachingVariable):
      return self._v._gather_saveables_for_checkpoint()  # pylint:disable=protected-access
    return {trackable.VARIABLE_VALUE_KEY: self._v}
  def _export_to_saved_model_graph(self, object_map, tensor_map,
                                   options, **kwargs):
    """For implementing `Trackable`."""
    # By delegating this method to the wrapped variable, a SavedModel with
    # AggregatingVariable is identical to a SavedModel with normal variables.
    resource_list = self._v._export_to_saved_model_graph(object_map, tensor_map,  # pylint:disable=protected-access
                                                         options, **kwargs)
    object_map[self] = object_map[self._v]
    return resource_list
  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._v + o

  def __radd__(self, o):
    return o + self._v

  def __sub__(self, o):
    return self._v - o

  def __rsub__(self, o):
    return o - self._v

  def __mul__(self, o):
    return self._v * o

  def __rmul__(self, o):
    return o * self._v

  def __truediv__(self, o):
    return self._v / o

  def __rtruediv__(self, o):
    return o / self._v

  def __floordiv__(self, o):
    return self._v // o

  def __rfloordiv__(self, o):
    return o // self._v

  def __mod__(self, o):
    return self._v % o

  def __rmod__(self, o):
    return o % self._v

  def __lt__(self, o):
    return self._v < o

  def __le__(self, o):
    return self._v <= o

  def __gt__(self, o):
    return self._v > o

  def __ge__(self, o):
    return self._v >= o

  def __and__(self, o):
    return self._v & o

  def __rand__(self, o):
    return o & self._v

  def __or__(self, o):
    return self._v | o

  def __ror__(self, o):
    return o | self._v

  def __xor__(self, o):
    return self._v ^ o

  def __rxor__(self, o):
    return o ^ self._v

  def __getitem__(self, o):
    return self._v[o]

  def __pow__(self, o, modulo=None):
    return pow(self._v, o, modulo)

  def __rpow__(self, o):
    return pow(o, self._v)

  def __invert__(self):
    return ~self._v

  def __neg__(self):
    return -self._v

  def __abs__(self):
    return abs(self._v)

  def __div__(self, o):
    try:
      return self._v.__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._v.__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._v.__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._v.__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
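
# --- Hedged usage sketch (illustration only; not part of the original file) --
# AggregatingVariable is normally created for you by a strategy rather than
# constructed directly: reads delegate to the wrapped variable, while assigns
# made in replica context are first aggregated according to `aggregation`.
# The sketch below assumes that variables created under a
# CentralStorageStrategy scope with multiple compute devices are wrapped this
# way; the function name `_aggregating_variable_sketch` is hypothetical.
def _aggregating_variable_sketch():
  import tensorflow as tf

  strategy = tf.distribute.experimental.CentralStorageStrategy()
  with strategy.scope():
    # With SUM aggregation, per-replica assign_add contributions are summed
    # before being applied to the single parameter-server copy.
    v = tf.Variable(0.0, aggregation=tf.VariableAggregation.SUM)

  @tf.function
  def step():
    def replica_fn():
      v.assign_add(1.0)
    strategy.run(replica_fn)

  step()
  return v.read_value()  # each replica contributed 1.0, aggregated by SUM
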

class CachingVariable(resource_variable_ops.BaseResourceVariable, core.Tensor):
  """A wrapper around a variable that caches read value locally."""

  def __init__(self, v):
    self._v = v
    self._cache = None
    self._current_new_cache_scope_count = 0

  def get(self):
    return self._v

  def __getattr__(self, name):
    return getattr(self._v, name)

  def read_value(self):
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    return self._v.read_value()

  def sparse_read(self, indices, name=None):
    return self._v.sparse_read(indices, name=name)

  def cached_read_value(self):
    if (distribute_utils.caching_scope_local.new_cache_scope_count >
        self._current_new_cache_scope_count):
      self._current_new_cache_scope_count += 1
      self._cache = None

    with ops.device("CPU:0"):
      if self._cache is not None:
        return self._cache
      else:
        self._cache = array_ops.identity(self._v)
        return self._cache

  def assign_sub(self, *args, **kwargs):
    return self._v.assign_sub(*args, **kwargs)

  def assign_add(self, *args, **kwargs):
    return self._v.assign_add(*args, **kwargs)

  def assign(self, *args, **kwargs):
    return self._v.assign(*args, **kwargs)

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  @property
  def op(self):
    return self._v.op

  def value(self):
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    return self._v.value()

  def eval(self, session=None):
    return self._v.eval(session)

  @property
  def graph(self):
    return self._v.graph

  @property
  def device(self):
    return self._v.device

  @property
  def shape(self):
    return self._v.shape

  @property
  def synchronization(self):
    return self._v.synchronization

  @property
  def name(self):
    return self._v.name

  @property
  def trainable(self):
    return self._v.trainable

  @property
  def dtype(self):
    return self._v.dtype

  @property
  def constraint(self):
    return self._v.constraint

  def __array__(self, dtype=None):
    return np.asarray(self.numpy(), dtype=dtype)

  def __complex__(self):
    return complex(self.value().numpy())

  def __int__(self):
    return int(self.value().numpy())

  def __float__(self):
    return float(self.value().numpy())

  def numpy(self):
    if context.executing_eagerly():
      return self.read_value().numpy()
    else:
      raise NotImplementedError(
          "numpy() is only available when eager execution is enabled.")

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=False)  # pylint: disable=protected-access

  @classmethod
  def _overload_overloadable_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      # Overloading __eq__ or __ne__ does not work as expected.
      if operator == "__eq__" or operator == "__ne__":
        continue
      cls._tensor_overload_operator(operator)

  @classmethod
  def _tensor_overload_operator(cls, operator):
    """Delegate an operator overload to `ops.Tensor`."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(v.value(), *args, **kwargs)  # pylint: disable=protected-access

    setattr(cls, operator, _operator)

  def _gather_saveables_for_checkpoint(self):
    return {trackable.VARIABLE_VALUE_KEY: self._v}
  def _export_to_saved_model_graph(self, object_map, tensor_map,
                                   options, **kwargs):
    """For implementing `Trackable`."""
    # By delegating this method to the wrapped variable, a SavedModel with
    # CachingVariable is identical to a SavedModel with normal variables.
    resource_list = self._v._export_to_saved_model_graph(object_map, tensor_map,  # pylint:disable=protected-access
                                                         options, **kwargs)
    object_map[self] = object_map[self._v]
    return resource_list
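
# --- Hedged sketch (illustration only; not part of the original file) --------
# The cache-invalidation idea behind `cached_read_value` above: a global
# "new cache scope" counter is bumped whenever a fresh caching scope opens, and
# each wrapper drops its local cache the first time it notices the counter has
# moved past the value it last saw. The toy below (`_ToyCachingReader`, a
# hypothetical name) reproduces that bookkeeping in plain Python, without
# TensorFlow.
def _caching_scope_sketch():
  global_scope_count = [0]  # stands in for caching_scope_local.new_cache_scope_count

  class _ToyCachingReader:

    def __init__(self, read_fn):
      self._read_fn = read_fn
      self._cache = None
      self._seen_scope_count = 0

    def cached_read(self):
      # Invalidate the cache once per newly opened caching scope.
      if global_scope_count[0] > self._seen_scope_count:
        self._seen_scope_count += 1
        self._cache = None
      if self._cache is None:
        self._cache = self._read_fn()
      return self._cache

  reads = {"n": 0}

  def expensive_read():
    reads["n"] += 1
    return reads["n"]

  reader = _ToyCachingReader(expensive_read)
  assert reader.cached_read() == reader.cached_read() == 1  # served from cache
  global_scope_count[0] += 1  # a new caching scope opens
  assert reader.cached_read() == 2  # cache was invalidated, re-read once
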

# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype, name, as_ref)  # pylint: disable=protected-access


tensor_conversion_registry.register_tensor_conversion_function(
    AggregatingVariable, _tensor_conversion_aggregate)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_caching(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype, name, as_ref)  # pylint: disable=protected-access


tensor_conversion_registry.register_tensor_conversion_function(
    CachingVariable, _tensor_conversion_caching)

CachingVariable._overload_overloadable_operators()  # pylint: disable=protected-access
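
# --- Hedged sketch (illustration only; not part of the original file) --------
# `_overload_overloadable_operators` above installs operator methods on the
# wrapper class so that e.g. `caching_var + 1` is computed on the (possibly
# cached) value. The toy below (`_ToyBox`, a hypothetical name) shows the same
# setattr-based delegation on a plain wrapper, forwarding a few operators to
# the wrapped int.
def _operator_delegation_sketch():

  class _ToyBox:

    def __init__(self, value):
      self._value = value

    def value(self):
      return self._value

  def _install_operator(cls, operator):
    wrapped_operator = getattr(int, operator)

    def _operator(box, *args, **kwargs):
      # Delegate to the wrapped value, mirroring _tensor_overload_operator.
      return wrapped_operator(box.value(), *args, **kwargs)

    setattr(cls, operator, _operator)

  for op_name in ("__add__", "__mul__", "__lt__"):
    _install_operator(_ToyBox, op_name)

  box = _ToyBox(3)
  assert box + 4 == 7
  assert box * 2 == 6
  assert (box < 10) is True
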

class DistributedTable(lookup_ops.StaticHashTable):
  """A distributed StaticHashTable for ParameterServerStrategy.

  An instance of DistributedTable has copies of a StaticHashTable and its
  resource handle on the coordinator and on each worker, created at
  DistributedTable instance initialization time with initializers on each
  worker. Users can call methods on a DistributedTable as if it were a
  StaticHashTable, which leads to execution with the resource local to the
  consumer worker (or the coordinator, if calling from the coordinator). This
  implementation relies on the fact that the methods of StaticHashTable are
  queried with the resource handle (instead of the python object).

  Currently, at saving time, a DistributedTable is saved as a StaticHashTable
  on the coordinator, and restoring a DistributedTable from SavedModel is not
  supported.
  """
  def __init__(self, strategy, wrapped_creator):
    distribute_lib.distribution_strategy_input_api_counter.get_cell(
        self.__class__.__name__, "PSSDistributedLookupTable").increase_by(1)
    self._coordinator_instance = wrapped_creator()
    self._wrapped_creator = wrapped_creator
    self._coordinator = strategy._cluster_coordinator
    # self._distributed_table is a RemoteValue mapping worker_index to
    # RemoteValue that wraps a resource handle on the worker
    self._distributed_table = None
    self._distributed_table_creation_lock = threading.Lock()

    if not save_context.in_save_context():
      self._maybe_build_distributed_table()

  def __getattr__(self, attr):
    # This allows copy.copy(DistributedTable), e.g. at saving time.
    # (DistributedVariable uses the same fix.) When copying an object, copy.copy
    # doesn't invoke its __init__ method, instead it makes a new empty object,
    # then copies the attributes over. copy.copy looks for attributes like
    # "__setstate__" in case the object implements its custom unpickling. Since
    # DistributedTable doesn't have those attributes defined, __getattr__ will
    # be invoked, which tries to access the `_coordinator_instance` attribute.
    # But that doesn't exist either because this is an empty object, and again
    # __getattr__ is invoked, leading to an infinite recursion.
    if attr == "_coordinator_instance":
      raise AttributeError()

    if attr in self._coordinator_instance.__dict__:
      attr_value = self._coordinator_instance.__dict__[attr]
      if callable(attr_value):

        def wrapper(*args, **kwargs):
          return attr_value(self, *args, **kwargs)

        return wrapper
      elif isinstance(attr_value, property):
        return attr_value
      else:
        return getattr(self._coordinator_instance, attr)
    else:
      return getattr(self._coordinator_instance, attr)

  def resource_handle_call_time_value(self):
    """Returns a closure to run for a resource handle at call time and its spec.

    This function is called in self.resource_handle to create a placeholder
    which returns a resource handle on some worker or on the coordinator.
    """

    def closure():
      # function to be evaluated at function call time, returning a nest of
      # tensors compatible with `spec`.
      dispatch_context = coordinator_context.get_current_dispatch_context()
      if dispatch_context:
        remote_value = self._distributed_table._values[  # pylint: disable=protected-access
            dispatch_context.worker_index]
        ret = dispatch_context.maybe_get_remote_value(remote_value)
        return ret
      else:
        return self._coordinator_instance.resource_handle

    return closure, tensor_spec.TensorSpec([], dtype=dtypes.resource)
  def _maybe_build_distributed_table(self):
    """Creates table objects and resources on each worker if they haven't been created."""
    with self._distributed_table_creation_lock:
      if not self._distributed_table:

        def create_copy():
          new_table = self._wrapped_creator()
          ret = new_table.resource_handle
          return ret

        self._distributed_table = (
            self._coordinator._create_per_worker_resources(create_copy))  # pylint: disable=protected-access
  @property
  def resource_handle(self):
    if context.executing_eagerly() or save_context.in_save_context():
      return self._coordinator_instance.resource_handle
    else:
      self._maybe_build_distributed_table()
      closure, spec = self.resource_handle_call_time_value()
      return ops.get_default_graph().capture_call_time_value(
          closure,
          spec,
          default_value=self._coordinator_instance.resource_handle)

  @property
  def is_distributed_table(self):
    return True

  def __tf_experimental_restore_capture__(
      self, concrete_function, internal_capture):
    closure, spec = self.resource_handle_call_time_value()
    concrete_function.graph.replace_capture_with_deferred_capture(
        self._coordinator_instance.resource_handle,
        closure,
        spec,
        default_value=self._coordinator_instance.resource_handle,
        placeholder=internal_capture)
    return concrete_function.graph.deferred_external_captures[-1]
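
# --- Hedged sketch (illustration only; not part of the original file) --------
# The dispatch idea behind `resource_handle_call_time_value` above: at call
# time, return the resource belonging to whichever worker is executing the
# closure, and fall back to the coordinator's copy otherwise. The toy below
# models the per-worker copies as a plain list and the dispatch context as a
# thread-local holding a worker index; all names (`_ToyDistributedResource`,
# `_dispatch`) are hypothetical.
def _per_worker_dispatch_sketch():
  dispatch_local = threading.local()

  class _ToyDistributedResource:

    def __init__(self, coordinator_copy, per_worker_copies):
      self._coordinator_copy = coordinator_copy
      self._per_worker_copies = per_worker_copies

    def handle(self):
      worker_index = getattr(dispatch_local, "worker_index", None)
      if worker_index is not None:
        return self._per_worker_copies[worker_index]
      return self._coordinator_copy

  def _dispatch(resource, worker_index):
    # Emulates running a closure under a worker's dispatch context.
    dispatch_local.worker_index = worker_index
    try:
      return resource.handle()
    finally:
      dispatch_local.worker_index = None

  resource = _ToyDistributedResource("coordinator", ["worker-0", "worker-1"])
  assert resource.handle() == "coordinator"
  assert _dispatch(resource, 1) == "worker-1"
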

_local_resource_restore_context = threading.local()


def get_current_local_resource_restore_context():
  try:
    return _local_resource_restore_context.current
  except AttributeError:
    return None


@contextlib.contextmanager
def with_local_resource_restore_context(instance):
  previous_context = getattr(_local_resource_restore_context, "current", None)
  _local_resource_restore_context.current = LocalResourceRestoreContext(
      instance)
  yield
  _local_resource_restore_context.current = previous_context

class LocalResourceRestoreContext(object):
  """Class holding information about a distributed instance, e.g. StaticHashTable.

  Pairing use with the context manager `with_local_resource_restore_context`
  allows operations under that context manager to conveniently get information
  about a component of the `RestoredDistributedTable` (and other restored
  distributed `CapturableResource`s, if we support their distribution in the
  future), instead of looking it up from the mapping of worker-to-resource
  handle. This is especially useful when we know which instance the operations
  should execute with and the mapping is not available yet.
  """

  def __init__(self, instance):
    self.instance = instance
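
# --- Hedged sketch (illustration only; not part of the original file) --------
# How the thread-local restore context above is meant to be used: code that
# builds a per-worker table enters the context with that table, and closures
# executed underneath can recover it via
# `get_current_local_resource_restore_context()` instead of consulting the
# (possibly not-yet-built) worker-to-resource mapping. The function name
# `_restore_context_sketch` and the placeholder instance are hypothetical.
def _restore_context_sketch():
  placeholder_instance = object()  # stands in for a table being restored

  assert get_current_local_resource_restore_context() is None
  with with_local_resource_restore_context(placeholder_instance):
    current = get_current_local_resource_restore_context()
    assert current is not None and current.instance is placeholder_instance
  assert get_current_local_resource_restore_context() is None
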

class RestoredDistributedTable(DistributedTable):
  """A restored and distributed StaticHashTable for ParameterServerStrategy."""

  def __init__(self, strategy, wrapped_creator):
    # Wait for all resource functions to have been set before building the
    # table.
    self._has_resource_functions = threading.Condition()
    super().__init__(strategy, wrapped_creator)

  def resource_handle_call_time_value(self):
    """Returns a closure to run for a resource handle at call time and its spec.

    This function is called in self.resource_handle to create a placeholder
    which returns a resource handle on some worker or on the coordinator.
    """

    def closure():
      # function to be evaluated at function call time, returning a nest of
      # tensors compatible with `spec`.
      dispatch_context = coordinator_context.get_current_dispatch_context()
      if dispatch_context:
        local_resource_restore_context = (
            get_current_local_resource_restore_context())

        # A LocalResourceRestoreContext is entered in the process of remote
        # table creation and initialization if we're in the process of loading
        # from a SavedModel. A LocalResourceRestoreContext carries the
        # information regarding which table is being created and initialized. In
        # order to initialize a table, we need the restored `_initialize`
        # function, which captures this closure as table resource. And when this
        # closure is executed, we will read the table info from the
        # LocalResourceRestoreContext and return its handle, rather than
        # following the normal procedure of fetching from
        # `self._distributed_table`, because we're still in the middle of
        # building `self._distributed_table`.
        if local_resource_restore_context:
          remote_value = local_resource_restore_context.instance.resource_handle
        else:
          remote_value = self._distributed_table._values[  # pylint: disable=protected-access
              dispatch_context.worker_index]

        ret = dispatch_context.maybe_get_remote_value(remote_value)
        return ret
      else:
        return self._coordinator_instance.resource_handle

    return closure, tensor_spec.TensorSpec(shape=(), dtype=dtypes.resource)
  def __setattr__(self, name, value):
    if name in TRACKABLE_RESOURCE_METHODS:
      # When a StaticHashTable is loaded with `tf.saved_model.load`, it becomes
      # a RestoredResource with dummy `_create_resource`, `_initialize`, and
      # `_destroy_resource` methods. Similarly, when loaded with
      # `tf.keras.models.load_model`, its initializer becomes a dummy one. In
      # both cases, these methods need to be set to some RestoredFunctions
      # through `__setattr__`. Thus we need to store and set these methods for
      # the distributed tables (a.k.a. `self._distributed_table`) on the
      # workers too, besides setting them for the coordinator instance.
      # However, we cannot set them at this point, since the distributed tables
      # have not been created yet. We store them in `_restored_function` and
      # set them on the distributed tables when they're created in
      # `self._maybe_build_distributed_table.create_copy`.
      if not hasattr(self, "_restored_function"):
        self._restored_function = {}
      self._restored_function[name] = value
      if all(method in self._restored_function
             for method in TRACKABLE_RESOURCE_METHODS):
        with self._has_resource_functions:
          self._has_resource_functions.notify_all()
      return self._coordinator_instance.__setattr__(name, value)
    else:
      return super(RestoredDistributedTable, self).__setattr__(name, value)
  def _create_resource(self):
    """A function that creates a resource handle for a table on coordinator."""
    return self._coordinator_instance._create_resource()  # pylint: disable=protected-access

  def _initialize(self):
    """A function that initializes the resource."""
    return self._coordinator_instance._initialize()  # pylint: disable=protected-access

  def _destroy_resource(self):
    """A function that destroys the resource."""
    return self._coordinator_instance._destroy_resource()  # pylint: disable=protected-access
  def _maybe_build_distributed_table(self):
    """Creates table objects and resources on each worker if they haven't been created."""
    with self._distributed_table_creation_lock:
      if not self._distributed_table:

        def create_copy():
          new_table = self._wrapped_creator()
          # Wait until all resource functions are available before setting them
          # on new_table.
          with self._has_resource_functions:
            while not hasattr(self, "_restored_function") or any(
                method not in self._restored_function
                for method in TRACKABLE_RESOURCE_METHODS):
              self._has_resource_functions.wait()

          if hasattr(self, "_restored_function"):
            with with_local_resource_restore_context(new_table):
              for name, tf_function in self._restored_function.items():
                setattr(new_table, name, tf_function)
              init_op = new_table._initialize()  # pylint: disable=protected-access
              if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)

          ret = new_table.resource_handle
          return ret

        self._distributed_table = (
            self._coordinator._create_per_worker_resources(create_copy))  # pylint: disable=protected-access
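
# --- Hedged sketch (illustration only; not part of the original file) --------
# The synchronization used by `RestoredDistributedTable` above: per-worker
# table creation blocks on a `threading.Condition` until all restored resource
# functions have been recorded, then proceeds. The toy below reproduces that
# handshake with a plain dict and two threads; the names are hypothetical and
# the set of required methods mirrors TRACKABLE_RESOURCE_METHODS.
def _condition_wait_sketch():
  required = ("_create_resource", "_initialize", "_destroy_resource")
  restored_function = {}
  has_resource_functions = threading.Condition()
  result = []

  def create_copy():
    # Consumer: wait until every required function has been set.
    with has_resource_functions:
      while any(method not in restored_function for method in required):
        has_resource_functions.wait()
    result.append(sorted(restored_function))

  worker = threading.Thread(target=create_copy)
  worker.start()

  # Producer: record the functions one by one, notifying once complete (this
  # mirrors RestoredDistributedTable.__setattr__).
  for method in required:
    restored_function[method] = lambda: None
    if all(m in restored_function for m in required):
      with has_resource_functions:
        has_resource_functions.notify_all()

  worker.join()
  assert result == [sorted(required)]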