Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/ps_values.py: 42% of 443 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values for PS."""

import contextlib
import copy
import threading
import weakref

import numpy as np

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.distribute.coordinator import coordinator_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util.lazy_loader import LazyLoader

load_context = LazyLoader(
    "load_context", globals(),
    "tensorflow.python.keras.saving.saved_model.load_context"
)

TRACKABLE_RESOURCE_METHODS = [
    "_create_resource", "_initialize", "_destroy_resource"
]


# Variable used in PSStrategy TF 1, TF 2 and CentralStorageStrategy.
class AggregatingVariable(resource_variable_ops.BaseResourceVariable,
                          core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def __deepcopy__(self, memo):
    """Perform a deepcopy of the `AggregatingVariable`.

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the `AggregatingVariable`. To avoid confusion
    with the behavior of deepcopy on a regular `Variable` (which does copy
    into new devices), we only allow a deepcopy of an `AggregatingVariable`
    within its originating strategy scope.

    Args:
      memo: The memoization object for `deepcopy`.

    Returns:
      A deep copy of the current `AggregatingVariable`.

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
      v = copy.deepcopy(self._v, memo)

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        v=v,
        aggregation=self._aggregation)

    memo[id(self)] = copied_variable

    return copied_variable

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if distribute_lib.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross-replica context, so wrap
        # it in an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = distribute_lib.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about
        # how we handle the different use cases can be found in the _reduce
        # method. We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                  "use_locking": use_locking,
                  "name": name,
                  "read_value": read_value
              })

        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  @property
  def op(self):
    return self._v.op

  def value(self):
    return self._v.value()

  def read_value(self):
    return self._v.read_value()

  def sparse_read(self, indices, name=None):
    return self._v.sparse_read(indices, name=name)

  def eval(self, session=None):
    return self._v.eval(session)

  @property
  def graph(self):
    return self._v.graph

  @property
  def device(self):
    return self._v.device

  @property
  def shape(self):
    return self._v.shape

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def synchronization(self):
    return self._v.synchronization

  @property
  def name(self):
    return self._v.name

  @property
  def trainable(self):
    return self._v.trainable

  @property
  def dtype(self):
    return self._v.dtype

  # TODO(josh11b): Test saving & restoring.
  def _gather_saveables_for_checkpoint(self):
    if isinstance(self._v, CachingVariable):
      return self._v._gather_saveables_for_checkpoint()  # pylint:disable=protected-access
    return {trackable.VARIABLE_VALUE_KEY: self._v}

  def _export_to_saved_model_graph(self, object_map, tensor_map,
                                   options, **kwargs):
    """For implementing `Trackable`."""
    # By delegating this method to the wrapped variable, SavedModels with
    # AggregatingVariable are identical to SavedModels with normal variables.
    resource_list = self._v._export_to_saved_model_graph(object_map, tensor_map,  # pylint:disable=protected-access
                                                         options, **kwargs)
    object_map[self] = object_map[self._v]
    return resource_list

  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._v + o

  def __radd__(self, o):
    return o + self._v

  def __sub__(self, o):
    return self._v - o

  def __rsub__(self, o):
    return o - self._v

  def __mul__(self, o):
    return self._v * o

  def __rmul__(self, o):
    return o * self._v

  def __truediv__(self, o):
    return self._v / o

  def __rtruediv__(self, o):
    return o / self._v

  def __floordiv__(self, o):
    return self._v // o

  def __rfloordiv__(self, o):
    return o // self._v

  def __mod__(self, o):
    return self._v % o

  def __rmod__(self, o):
    return o % self._v

  def __lt__(self, o):
    return self._v < o

  def __le__(self, o):
    return self._v <= o

  def __gt__(self, o):
    return self._v > o

  def __ge__(self, o):
    return self._v >= o

  def __and__(self, o):
    return self._v & o

  def __rand__(self, o):
    return o & self._v

  def __or__(self, o):
    return self._v | o

  def __ror__(self, o):
    return o | self._v

  def __xor__(self, o):
    return self._v ^ o

  def __rxor__(self, o):
    return o ^ self._v

  def __getitem__(self, o):
    return self._v[o]

  def __pow__(self, o, modulo=None):
    return pow(self._v, o, modulo)

  def __rpow__(self, o):
    return pow(o, self._v)

  def __invert__(self):
    return ~self._v

  def __neg__(self):
    return -self._v

  def __abs__(self):
    return abs(self._v)

  def __div__(self, o):
    try:
      return self._v.__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._v.__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._v.__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._v.__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access

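# ---------------------------------------------------------------------------
# Editor's note: the block below is an illustrative sketch only, not part of
# the original module. It assumes the enclosing strategy wraps variables in
# `AggregatingVariable` (as CentralStorageStrategy and ParameterServerStrategy
# do for variables with a non-NONE aggregation) and shows how replica-context
# `assign_add` calls are reduced via `_assign_func`/`merge_fn` before a single
# update reaches the wrapped variable. The helper name is hypothetical.
def _example_aggregating_assign():  # pragma: no cover - illustrative only
  import tensorflow as tf

  strategy = tf.distribute.experimental.CentralStorageStrategy()
  with strategy.scope():
    # With SUM aggregation the strategy typically returns an
    # AggregatingVariable wrapping one resource variable on the parameter
    # device.
    v = tf.Variable(0.0, aggregation=tf.VariableAggregation.SUM)

  def replica_fn():
    # Called in replica context: the per-replica values are aggregated
    # (summed here) and then applied once to the underlying variable.
    v.assign_add(1.0)

  strategy.run(replica_fn)
  return v.read_value()  # equals the number of replicas in sync, given SUM
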

class CachingVariable(resource_variable_ops.BaseResourceVariable, core.Tensor):
  """A wrapper around a variable that caches read value locally."""

  def __init__(self, v):
    self._v = v
    self._cache = None
    self._current_new_cache_scope_count = 0

  def get(self):
    return self._v

  def __getattr__(self, name):
    return getattr(self._v, name)

  def read_value(self):
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    return self._v.read_value()

  def sparse_read(self, indices, name=None):
    return self._v.sparse_read(indices, name=name)

  def cached_read_value(self):
    if (distribute_utils.caching_scope_local.new_cache_scope_count >
        self._current_new_cache_scope_count):
      self._current_new_cache_scope_count += 1
      self._cache = None

    with ops.device("CPU:0"):
      if self._cache is not None:
        return self._cache
      else:
        self._cache = array_ops.identity(self._v)
        return self._cache

  def assign_sub(self, *args, **kwargs):
    return self._v.assign_sub(*args, **kwargs)

  def assign_add(self, *args, **kwargs):
    return self._v.assign_add(*args, **kwargs)

  def assign(self, *args, **kwargs):
    return self._v.assign(*args, **kwargs)

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  @property
  def op(self):
    return self._v.op

  def value(self):
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    return self._v.value()

  def eval(self, session=None):
    return self._v.eval(session)

  @property
  def graph(self):
    return self._v.graph

  @property
  def device(self):
    return self._v.device

  @property
  def shape(self):
    return self._v.shape

  @property
  def synchronization(self):
    return self._v.synchronization

  @property
  def name(self):
    return self._v.name

  @property
  def trainable(self):
    return self._v.trainable

  @property
  def dtype(self):
    return self._v.dtype

  @property
  def constraint(self):
    return self._v.constraint

  def __array__(self, dtype=None):
    return np.asarray(self.numpy(), dtype=dtype)

  def __complex__(self):
    return complex(self.value().numpy())

  def __int__(self):
    return int(self.value().numpy())

  def __float__(self):
    return float(self.value().numpy())

  def numpy(self):
    if context.executing_eagerly():
      return self.read_value().numpy()
    else:
      raise NotImplementedError(
          "numpy() is only available when eager execution is enabled.")

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    if distribute_utils.caching_scope_local.in_caching_scope():
      return self.cached_read_value()
    return self._v._dense_var_to_tensor(dtype=dtype, name=name, as_ref=False)  # pylint: disable=protected-access

  @classmethod
  def _overload_overloadable_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      # Overloading __eq__ or __ne__ does not work as expected.
      if operator == "__eq__" or operator == "__ne__":
        continue
      cls._tensor_overload_operator(operator)

  @classmethod
  def _tensor_overload_operator(cls, operator):
    """Delegate an operator overload to `ops.Tensor`."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(v.value(), *args, **kwargs)  # pylint: disable=protected-access

    setattr(cls, operator, _operator)

  def _gather_saveables_for_checkpoint(self):
    return {trackable.VARIABLE_VALUE_KEY: self._v}

  def _export_to_saved_model_graph(self, object_map, tensor_map,
                                   options, **kwargs):
    """For implementing `Trackable`."""
    # By delegating this method to the wrapped variable, SavedModels with
    # CachingVariable are identical to SavedModels with normal variables.
    resource_list = self._v._export_to_saved_model_graph(object_map, tensor_map,  # pylint:disable=protected-access
                                                         options, **kwargs)
    object_map[self] = object_map[self._v]
    return resource_list

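# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original module.
# It demonstrates the observable behavior of `CachingVariable`: once
# `cached_read_value()` has materialized a CPU-side snapshot, later cached
# reads return that stale snapshot until a new caching scope bumps
# `new_cache_scope_count`. In normal use the cache path is selected by
# `distribute_utils.caching_scope_local.in_caching_scope()` rather than by
# calling `cached_read_value()` directly; the helper name is hypothetical.
def _example_caching_variable_read():  # pragma: no cover - illustrative only
  v = resource_variable_ops.ResourceVariable(1.0)
  cached = CachingVariable(v)

  first = cached.cached_read_value()   # materializes the CPU:0 cache (1.0)
  v.assign(2.0)
  second = cached.cached_read_value()  # still the cached snapshot (1.0)
  fresh = cached.read_value()          # outside a caching scope: reads 2.0
  return first, second, fresh
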

# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype, name, as_ref)  # pylint: disable=protected-access


tensor_conversion_registry.register_tensor_conversion_function(
    AggregatingVariable, _tensor_conversion_aggregate)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_caching(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype, name, as_ref)  # pylint: disable=protected-access


tensor_conversion_registry.register_tensor_conversion_function(
    CachingVariable, _tensor_conversion_caching)

CachingVariable._overload_overloadable_operators()  # pylint: disable=protected-access

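# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original module.
# With the conversion functions registered above, wrapped variables can be
# passed anywhere a tensor is expected: `ops.convert_to_tensor` dispatches to
# `_dense_var_to_tensor`, which reads the wrapped (or cached) value, and the
# operator overloads installed above delegate to `ops.Tensor`. The helper
# name is hypothetical.
def _example_tensor_conversion():  # pragma: no cover - illustrative only
  v = resource_variable_ops.ResourceVariable([1.0, 2.0])
  cached = CachingVariable(v)
  # Both expressions yield plain tensors backed by the variable's current
  # value.
  as_tensor = ops.convert_to_tensor(cached)
  doubled = cached * 2.0
  return as_tensor, doubled
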

class DistributedTable(lookup_ops.StaticHashTable):
  """A distributed StaticHashTable for ParameterServerStrategy.

  An instance of DistributedTable has copies of a StaticHashTable and its
  resource handle on the coordinator and on each worker, created at
  DistributedTable instance initialization time with initializers on each
  worker. Users can call methods on a DistributedTable as if it were a
  StaticHashTable, which leads to execution with the resource local to the
  consumer worker (or the coordinator, if calling from the coordinator). This
  implementation relies on the fact that the methods of StaticHashTable are
  queried with the resource handle (instead of the python object).

  Currently, at saving time, a DistributedTable is saved as a StaticHashTable
  on the coordinator, and restoring a DistributedTable from SavedModel is not
  supported.
  """

  def __init__(self, strategy, wrapped_creator):
    distribute_lib.distribution_strategy_input_api_counter.get_cell(
        self.__class__.__name__, "PSSDistributedLookupTable").increase_by(1)
    self._coordinator_instance = wrapped_creator()
    self._wrapped_creator = wrapped_creator
    self._coordinator = strategy._cluster_coordinator
    # self._distributed_table is a PerWorker mapping from worker_index to a
    # RemoteValue that wraps a resource handle on that worker.
    self._distributed_table = None
    self._distributed_table_creation_lock = threading.Lock()

    if not save_context.in_save_context():
      self._maybe_build_distributed_table()

  def __getattr__(self, attr):
    # This allows copy.copy(DistributedTable), e.g. at saving time.
    # (DistributedVariable uses the same fix.) When copying an object,
    # copy.copy doesn't invoke its __init__ method; instead it makes a new
    # empty object, then copies the attributes over. copy.copy looks for
    # attributes like "__setstate__" in case the object implements its custom
    # unpickling. Since DistributedTable doesn't have those attributes
    # defined, __getattr__ will be invoked, which tries to access the
    # `_coordinator_instance` attribute. But that doesn't exist either because
    # this is an empty object, and again __getattr__ is invoked, leading to an
    # infinite recursion.
    if attr == "_coordinator_instance":
      raise AttributeError()

    if attr in self._coordinator_instance.__dict__:
      attr_value = self._coordinator_instance.__dict__[attr]
      if callable(attr_value):

        def wrapper(*args, **kwargs):
          return attr_value(self, *args, **kwargs)

        return wrapper
      elif isinstance(attr_value, property):
        return attr_value
      else:
        return getattr(self._coordinator_instance, attr)
    else:
      return getattr(self._coordinator_instance, attr)

  def resource_handle_call_time_value(self):
    """Returns a closure to run for a resource handle at call time and its spec.

    This function is called in self.resource_handle to create a placeholder
    which returns a resource handle on some worker or on the coordinator.
    """

    def closure():
      # Function to be evaluated at function call time, returning a nest of
      # tensors compatible with `spec`.
      dispatch_context = coordinator_context.get_current_dispatch_context()
      if dispatch_context:
        remote_value = self._distributed_table._values[  # pylint: disable=protected-access
            dispatch_context.worker_index]
        ret = dispatch_context.maybe_get_remote_value(remote_value)
        return ret

      else:
        return self._coordinator_instance.resource_handle

    return closure, tensor_spec.TensorSpec([], dtype=dtypes.resource)

  def _maybe_build_distributed_table(self):
    """Creates table objects and resources on each worker if not yet created."""
    with self._distributed_table_creation_lock:
      if not self._distributed_table:

        def create_copy():
          new_table = self._wrapped_creator()
          ret = new_table.resource_handle
          return ret

        self._distributed_table = (
            self._coordinator._create_per_worker_resources(create_copy))  # pylint: disable=protected-access

  @property
  def resource_handle(self):
    if context.executing_eagerly() or save_context.in_save_context():
      return self._coordinator_instance.resource_handle
    else:
      self._maybe_build_distributed_table()
      closure, spec = self.resource_handle_call_time_value()
      return ops.get_default_graph().capture_call_time_value(
          closure,
          spec,
          default_value=self._coordinator_instance.resource_handle)

  @property
  def is_distributed_table(self):
    return True

  def __tf_experimental_restore_capture__(
      self, concrete_function, internal_capture):
    closure, spec = self.resource_handle_call_time_value()
    concrete_function.graph.replace_capture_with_deferred_capture(
        self._coordinator_instance.resource_handle,
        closure,
        spec,
        default_value=self._coordinator_instance.resource_handle,
        placeholder=internal_capture)
    return concrete_function.graph.deferred_external_captures[-1]

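# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original module.
# Under `ParameterServerStrategy`, a `tf.lookup.StaticHashTable` created in
# the strategy scope is wrapped as a `DistributedTable`, so lookups dispatched
# through a `ClusterCoordinator` resolve against a per-worker copy of the
# table rather than the coordinator's copy. Cluster setup (`cluster_resolver`)
# is assumed to exist and is elided; the helper name is hypothetical.
def _example_distributed_table(cluster_resolver):  # pragma: no cover
  import tensorflow as tf

  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver)
  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
      strategy)

  with strategy.scope():
    # Wrapped as DistributedTable; each worker lazily receives its own copy.
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys=[1, 2], values=[10, 20]),
        default_value=-1)

  @tf.function
  def lookup_fn():
    # Executed on a worker: `table.resource_handle` resolves to that worker's
    # local resource handle via the call-time closure defined above.
    return table.lookup(tf.constant([1, 3]))

  remote_result = coordinator.schedule(lookup_fn)
  return remote_result.fetch()  # expected [10, -1]
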

_local_resource_restore_context = threading.local()


def get_current_local_resource_restore_context():
  try:
    return _local_resource_restore_context.current
  except AttributeError:
    return None


@contextlib.contextmanager
def with_local_resource_restore_context(instance):
  previous_context = getattr(_local_resource_restore_context, "current", None)
  _local_resource_restore_context.current = LocalResourceRestoreContext(
      instance)
  yield
  _local_resource_restore_context.current = previous_context


class LocalResourceRestoreContext(object):
  """Class holding information of a distributed instance, e.g. StaticHashTable.

  Paired with the context manager `with_local_resource_restore_context`, this
  allows operations under that context manager to conveniently get information
  about a component of the `RestoredDistributedTable` (and of other restored
  distributed `CapturableResource`s, should we support their distribution in
  the future), instead of looking it up from the worker-to-resource-handle
  mapping. This is especially useful when we know which instance the
  operations should execute with and the mapping is not available yet.
  """

  def __init__(self, instance):
    self.instance = instance

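# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original module.
# The thread-local context above is consulted by
# `RestoredDistributedTable.resource_handle_call_time_value` (below): while a
# restored per-worker table is being created and initialized, the closure
# returns the handle of the instance named by the active context instead of
# consulting the not-yet-built `self._distributed_table` mapping. The helper
# name is hypothetical and assumes no enclosing restore context is active.
def _example_local_resource_restore_context(table_being_restored):  # pragma: no cover
  assert get_current_local_resource_restore_context() is None
  with with_local_resource_restore_context(table_being_restored):
    ctx = get_current_local_resource_restore_context()
    assert ctx.instance is table_being_restored
  # On exit the previous context (here: None) is restored.
  assert get_current_local_resource_restore_context() is None
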

class RestoredDistributedTable(DistributedTable):
  """A restored and distributed StaticHashTable for ParameterServerStrategy."""

  def __init__(self, strategy, wrapped_creator):
    # Wait for all resource functions to have been set before building the
    # table.
    self._has_resource_functions = threading.Condition()
    super().__init__(strategy, wrapped_creator)

  def resource_handle_call_time_value(self):
    """Returns a closure to run for a resource handle at call time and its spec.

    This function is called in self.resource_handle to create a placeholder
    which returns a resource handle on some worker or on the coordinator.
    """

    def closure():
      # Function to be evaluated at function call time, returning a nest of
      # tensors compatible with `spec`.
      dispatch_context = coordinator_context.get_current_dispatch_context()
      if dispatch_context:
        local_resource_restore_context = (
            get_current_local_resource_restore_context())

        # A LocalResourceRestoreContext is entered in the process of remote
        # table creation and initialization if we're in the process of loading
        # from a SavedModel. A LocalResourceRestoreContext carries the
        # information regarding which table is being created and initialized.
        # In order to initialize a table, we need the restored `_initialize`
        # function, which captures this closure as the table resource. And
        # when this closure is executed, we read the table info from the
        # LocalResourceRestoreContext and return its handle, rather than
        # following the normal procedure of fetching from
        # `self._distributed_table`, because we're still in the middle of
        # building `self._distributed_table`.
        if local_resource_restore_context:
          remote_value = local_resource_restore_context.instance.resource_handle

        else:
          remote_value = self._distributed_table._values[  # pylint: disable=protected-access
              dispatch_context.worker_index]

        ret = dispatch_context.maybe_get_remote_value(remote_value)
        return ret

      else:
        return self._coordinator_instance.resource_handle

    return closure, tensor_spec.TensorSpec(shape=(), dtype=dtypes.resource)

  def __setattr__(self, name, value):
    if name in TRACKABLE_RESOURCE_METHODS:
      # When a StaticHashTable is loaded with `tf.saved_model.load`, it
      # becomes a RestoredResource with dummy `_create_resource`,
      # `_initialize`, and `_destroy_resource` methods. Similarly, when loaded
      # with `tf.keras.models.load_model`, its initializer becomes a dummy
      # one. In both cases, these methods need to be set to some
      # RestoredFunctions through `__setattr__`. Thus we need to store and set
      # these methods for the distributed tables (a.k.a.
      # `self._distributed_table`) on the workers too, besides setting them
      # for the coordinator instance. However, we cannot set them at this
      # point, since the distributed tables have not been created. We store
      # them in `_restored_function` and set them on the distributed tables
      # when they're created in `self._maybe_build_distributed_table.create_copy`.
      if not hasattr(self, "_restored_function"):
        self._restored_function = {}
      self._restored_function[name] = value
      if all(method in self._restored_function
             for method in TRACKABLE_RESOURCE_METHODS):
        with self._has_resource_functions:
          self._has_resource_functions.notify_all()
      return self._coordinator_instance.__setattr__(name, value)
    else:
      return super(RestoredDistributedTable, self).__setattr__(name, value)

  def _create_resource(self):
    """A function that creates a resource handle for a table on coordinator."""
    return self._coordinator_instance._create_resource()  # pylint: disable=protected-access

  def _initialize(self):
    """A function that initializes the resource."""
    return self._coordinator_instance._initialize()  # pylint: disable=protected-access

  def _destroy_resource(self):
    """A function that destroys the resource."""
    return self._coordinator_instance._destroy_resource()  # pylint: disable=protected-access

  def _maybe_build_distributed_table(self):
    """Creates table objects and resources on each worker if not yet created."""
    with self._distributed_table_creation_lock:
      if not self._distributed_table:

        def create_copy():
          new_table = self._wrapped_creator()
          # Wait until all resource functions are available before setting
          # them on new_table.
          with self._has_resource_functions:
            while not hasattr(self, "_restored_function") or any(
                method not in self._restored_function
                for method in TRACKABLE_RESOURCE_METHODS):
              self._has_resource_functions.wait()

          if hasattr(self, "_restored_function"):
            with with_local_resource_restore_context(new_table):
              for name, tf_function in self._restored_function.items():
                setattr(new_table, name, tf_function)
              init_op = new_table._initialize()  # pylint: disable=protected-access
              if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)

          ret = new_table.resource_handle
          return ret

        self._distributed_table = (
            self._coordinator._create_per_worker_resources(create_copy))  # pylint: disable=protected-access
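

# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original module.
# Rough sequence for the restore path implemented above, assuming a SavedModel
# containing a StaticHashTable is loaded under ParameterServerStrategy: the
# loader assigns the restored `_create_resource` / `_initialize` /
# `_destroy_resource` functions through `__setattr__`, which releases the
# `_has_resource_functions` condition so that `create_copy` can build and
# initialize one table copy per worker. Names and arguments are hypothetical.
def _example_restore_under_ps(cluster_resolver, saved_model_dir):  # pragma: no cover
  import tensorflow as tf

  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver)
  with strategy.scope():
    # Tables inside the loaded object are wrapped as RestoredDistributedTable;
    # per-worker copies are created lazily once the restored resource
    # functions have been set.
    loaded = tf.saved_model.load(saved_model_dir)
  return loaded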