Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/values.py: 33% (884 statements)

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Various classes representing distributed values.""" 

16 

17import copy 

18from typing import Optional 

19import weakref 

20 

21from tensorflow.core.protobuf import struct_pb2 

22from tensorflow.python.distribute import device_util 

23from tensorflow.python.distribute import distribute_lib 

24from tensorflow.python.distribute import packed_distributed_variable as packed 

25from tensorflow.python.distribute import reduce_util 

26from tensorflow.python.distribute import values_util 

27from tensorflow.python.eager import context 

28from tensorflow.python.eager import record 

29from tensorflow.python.framework import composite_tensor 

30from tensorflow.python.framework import dtypes 

31from tensorflow.python.framework import ops 

32from tensorflow.python.framework import tensor_conversion_registry 

33from tensorflow.python.framework import tensor_util 

34from tensorflow.python.framework import type_spec 

35from tensorflow.python.ops import array_ops 

36from tensorflow.python.ops import control_flow_ops 

37from tensorflow.python.ops import math_ops 

38from tensorflow.python.ops import resource_variable_ops 

39from tensorflow.python.ops import variable_scope as vs 

40from tensorflow.python.ops import variables as variables_lib 

41from tensorflow.python.saved_model import nested_structure_coder 

42from tensorflow.python.trackable import base as trackable 

43from tensorflow.python.training.saving import saveable_object 

44from tensorflow.python.types import core 

45from tensorflow.python.types import distribute as ds_types 

46from tensorflow.python.types import trace 

47 

48 

49def _on_write_update_replica(var, update_fn, value, **kwargs): 

50 """Updates variables with ON_WRITE synchronization in replica context.""" 

51 if var.aggregation == vs.VariableAggregation.NONE: 

52 return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access 

53 

54 if not distribute_lib.get_strategy().extended._use_merge_call(): # pylint: disable=protected-access 

55 # Don't allow MEAN with non-float dtype, since it may cause unexpected

56 # precision loss. Python3 and NumPy automatically upcast integers to 

57 # float in division, but we should always preserve the type. 

58 if var.aggregation == vs.VariableAggregation.MEAN and ( 

59 not var.dtype.is_floating) and tensor_util.is_tf_type(value): 

60 raise ValueError( 

61 "Cannot update non-float variables with " 

62 "tf.VariableAggregation.MEAN aggregation in replica context. " 

63 "Either change the variable dtype to float or update it in " 

64 "cross-replica context.") 

65 

66 aggregated_value = apply_aggregation_replica_context( 

67 value, var.aggregation, var) 

68 values_util.mark_as_unsaveable() 

69 

70 return distribute_lib.get_replica_context()._update( # pylint: disable=protected-access 

71 var, 

72 update_fn, 

73 args=(aggregated_value,), 

74 kwargs=kwargs, 

75 group=True) 

76 

77 else: 

78 

79 def merge_fn(strategy, value, **kwargs): 

80 """Aggregate values and update all variables in cross replica context.""" 

81 # Don't allow MEAN with non-float dtype, since it may cause unexpected

82 # precision loss. Python3 and NumPy automatically upcast integers to 

83 # float in division, but we should always preserve the type. 

84 # 

85 # Note that to be backward compatible we allow the case when the value 

86 # is *always* the same on each replica, i.e., value is not a

87 # PerReplica. Refer to regroup() to see how values are grouped. 

88 if var.aggregation == vs.VariableAggregation.MEAN and ( 

89 not var.dtype.is_floating) and isinstance(value, PerReplica): 

90 raise ValueError( 

91 "Cannot update non-float variables with " 

92 "tf.VariableAggregation.MEAN aggregation in replica context. " 

93 "Either change the variable dtype to float or update it in " 

94 "cross-replica context.") 

95 

96 assert strategy == var.distribute_strategy 

97 v = values_util.apply_aggregation(strategy, value, var.aggregation, var) 

98 return var._update_cross_replica(update_fn, v, **kwargs) # pylint: disable=protected-access 

99 

100 return distribute_lib.get_replica_context().merge_call( 

101 merge_fn, args=(value,), kwargs=kwargs) 

102 
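# Illustrative sketch (not part of the original module): how an ON_WRITE update
# issued from replica context is aggregated before being written, using only
# public `tf.distribute` APIs. The device list is an assumption for
# illustration.
#
#   import tensorflow as tf
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     # AUTO/ON_WRITE synchronization is the default for MirroredStrategy.
#     v = tf.Variable(0.0, aggregation=tf.VariableAggregation.SUM)
#
#   def step():
#     # Each replica contributes 1.0; with SUM aggregation the deltas are
#     # combined across replicas before the synchronized write.
#     v.assign_add(1.0)
#
#   strategy.run(step)   # v now reads 2.0 with two replicas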

103 

104def apply_aggregation_replica_context(value, aggregation, destinations): 

105 """Aggregate `value` to `destinations` as specified by `aggregation`.""" 

106 # If it is a Python literal, return it without aggregation.

107 if isinstance(value, DistributedValues): 

108 raise TypeError( 

109 "Cannot use DistributedValues to update variables in replica context.") 

110 if not tensor_util.is_tf_type(value): 

111 return value 

112 

113 if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 

114 # Switch to cross-replica context to broadcast 

115 def merge_fn(strategy, value): 

116 return strategy.extended.broadcast_to( 

117 strategy.experimental_local_results(value)[0], 

118 destinations=destinations) 

119 

120 return distribute_lib.get_replica_context().merge_call( 

121 merge_fn, args=(value,)) 

122 

123 else: 

124 reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation) 

125 aggregated_value = distribute_lib.get_strategy( # pylint: disable=protected-access 

126 ).extended._replica_ctx_all_reduce(reduce_op, value) 

127 return aggregated_value 

128 
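# Illustrative note (not part of the original module): a minimal sketch of how
# a variable aggregation mode maps onto the reduce op used above, relying on
# the `reduce_util` and `variable_scope` imports at the top of this file.
#
#   reduce_util.ReduceOp.from_variable_aggregation(
#       vs.VariableAggregation.SUM)    # -> reduce_util.ReduceOp.SUM
#   reduce_util.ReduceOp.from_variable_aggregation(
#       vs.VariableAggregation.MEAN)   # -> reduce_util.ReduceOp.MEAN
#   # ONLY_FIRST_REPLICA is handled separately above: the first replica's value
#   # is broadcast to the destinations instead of being reduced.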

129 

130class DistributedValues(ds_types.DistributedValues): 

131 """Base class for representing distributed values.""" 

132 

133 def __init__(self, values): 

134 """Should only be called by subclass __init__.""" 

135 self._values = tuple(values) 

136 

137 def _get(self): 

138 """Returns the value for the current device or raises a ValueError.""" 

139 replica_id = values_util.get_current_replica_id_as_int() 

140 if replica_id is None: 

141 return self._get_cross_replica() 

142 else: 

143 return self._values[replica_id] 

144 

145 def _get_cross_replica(self): 

146 raise NotImplementedError( 

147 "DistributedValues._get_cross_replica should be implemented by " 

148 "sub-classes which support cross-replica accesses.") 

149 

150 def _get_on_device_or_primary(self): 

151 """Returns value in same replica or device if possible, else the _primary.""" 

152 replica_id = values_util.get_current_replica_id_as_int() 

153 if replica_id is None: 

154 # Try to find a value on the current device. 

155 current_device = device_util.canonicalize(device_util.current()) 

156 for value in self._values: 

157 if device_util.canonicalize(value.device) == current_device: 

158 return value 

159 return self._primary 

160 else: 

161 return self._values[replica_id] 

162 

163 @property 

164 def _primary(self): 

165 """Returns a representative component.""" 

166 return self._values[0] 

167 

168 @property 

169 def _devices(self): 

170 return tuple(v.device for v in self._values) 

171 

172 def __str__(self): 

173 debug_str = ",\n".join( 

174 " %d: %s" % (i, v) for i, v in enumerate(self._values)) 

175 return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str) 

176 

177 def __repr__(self): 

178 debug_repr = ",\n".join( 

179 " %d: %r" % (i, v) for i, v in enumerate(self._values)) 

180 return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr) 

181 
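# Illustrative sketch (not part of the original module): user code usually
# obtains a `DistributedValues` instance through the public `tf.distribute`
# API rather than by constructing one directly. The device list is an
# assumption for illustration.
#
#   import tensorflow as tf
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#
#   def value_fn(ctx):
#     # One component per replica, derived from the replica id.
#     return tf.constant(float(ctx.replica_id_in_sync_group))
#
#   distributed_values = (
#       strategy.experimental_distribute_values_from_function(value_fn))
#   # Per-component access goes through
#   # strategy.experimental_local_results(distributed_values).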

182 

183# NOTE(josh11b,apassos): It would be great if we could inspect the values this was 

184# initialized with and use that to generate the overloaded operators here. 

185# Unfortunately, Python's rules for special methods don't allow this, see 

186# https://docs.python.org/3/reference/datamodel.html#special-method-names 

187# "if a class defines a method named __getitem__(), and x is an instance of 

188# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)." 

189# In particular, these special methods don't go through __getattr__, and 

190# it will only use those methods if they are defined in the class, not the 

191# object. 

192class DistributedDelegate(DistributedValues): 

193 """A map from device to values; acts as the same type as the values.""" 

194 

195 def __getattr__(self, name): 

196 # The '_use_resource_variables' attribute and attrs starting with '_self' are used

197 # for restoring the saved_model proto, and '_attribute_sentinel' is used for 

198 # Layer tracking. At the point these attrs are queried, the variable has not 

199 # been initialized. Thus it should not query those of the underlying 

200 # components. 

201 if name.startswith("_self_") or name in ("_use_resource_variables", 

202 "_attribute_sentinel", 

203 "_distributed_container"): 

204 return super(DistributedDelegate, self).__getattr__(name) 

205 

206 # This allows copy.copy(DistributedDelegate). When copying an object, 

207 # copy.copy doesn't invoke its __init__ method, instead it makes a new 

208 # empty object, then copies the attributes over. copy.copy looks for 

209 # attributes like "__getstate__" in case the object implements its custom 

210 # copying. Since DistributedDelegate doesn't have those attributes defined, 

211 # __getattr__ will be invoked, which tries to access "_values" attributes, 

212 # but that doesn't exist either because this is an empty object, and again 

213 # __getattr__ is invoked, leading to an infinite recursion. 

214 if name == "_values": 

215 raise AttributeError() 

216 

217 # TODO(priyag): This needs to be made robust against pitfalls from mixed use of

218 # __getattr__ and @property. See b/120402273. 

219 return getattr(self._get(), name) 

220 

221 @property 

222 def values(self): 

223 """Returns the per replica values.""" 

224 return self._values 

225 

226 def _get_as_operand(self): 

227 """Returns the value for operations for the current device. 

228 

229 Some implementations, e.g. `TPUMirroredVariable`, are not able to return the 

230 value type within a replica context. They can, however, return a value that 

231 can be used by the operations below. 

232 """ 

233 return self._get() 

234 

235 # pylint: disable=multiple-statements 

236 def __add__(self, o): 

237 return self._get_as_operand() + o 

238 

239 def __radd__(self, o): 

240 return o + self._get_as_operand() 

241 

242 def __sub__(self, o): 

243 return self._get_as_operand() - o 

244 

245 def __rsub__(self, o): 

246 return o - self._get_as_operand() 

247 

248 def __mul__(self, o): 

249 return self._get_as_operand() * o 

250 

251 def __rmul__(self, o): 

252 return o * self._get_as_operand() 

253 

254 def __truediv__(self, o): 

255 return self._get_as_operand() / o 

256 

257 def __rtruediv__(self, o): 

258 return o / self._get_as_operand() 

259 

260 def __floordiv__(self, o): 

261 return self._get_as_operand() // o 

262 

263 def __rfloordiv__(self, o): 

264 return o // self._get_as_operand() 

265 

266 def __mod__(self, o): 

267 return self._get_as_operand() % o 

268 

269 def __rmod__(self, o): 

270 return o % self._get_as_operand() 

271 

272 def __lt__(self, o): 

273 return self._get_as_operand() < o 

274 

275 def __le__(self, o): 

276 return self._get_as_operand() <= o 

277 

278 def __gt__(self, o): 

279 return self._get_as_operand() > o 

280 

281 def __ge__(self, o): 

282 return self._get_as_operand() >= o 

283 

284 def __and__(self, o): 

285 return self._get_as_operand() & o 

286 

287 def __rand__(self, o): 

288 return o & self._get_as_operand() 

289 

290 def __or__(self, o): 

291 return self._get_as_operand() | o 

292 

293 def __ror__(self, o): 

294 return o | self._get_as_operand() 

295 

296 def __xor__(self, o): 

297 return self._get_as_operand() ^ o 

298 

299 def __rxor__(self, o): 

300 return o ^ self._get_as_operand() 

301 

302 def __getitem__(self, o): 

303 return self._get_as_operand()[o] 

304 

305 def __pow__(self, o, modulo=None): 

306 return pow(self._get_as_operand(), o, modulo) 

307 

308 def __rpow__(self, o): 

309 return pow(o, self._get_as_operand()) 

310 

311 def __invert__(self): 

312 return ~self._get_as_operand() 

313 

314 def __neg__(self): 

315 return -self._get_as_operand() 

316 

317 def __abs__(self): 

318 return abs(self._get_as_operand()) 

319 

320 def __div__(self, o): 

321 try: 

322 return self._get_as_operand().__div__(o) 

323 except AttributeError: 

324 # See https://docs.python.org/3/library/constants.html#NotImplemented 

325 return NotImplemented 

326 

327 def __rdiv__(self, o): 

328 try: 

329 return self._get_as_operand().__rdiv__(o) 

330 except AttributeError: 

331 # See https://docs.python.org/3/library/constants.html#NotImplemented 

332 return NotImplemented 

333 

334 def __matmul__(self, o): 

335 try: 

336 return self._get_as_operand().__matmul__(o) 

337 except AttributeError: 

338 # See https://docs.python.org/3/library/constants.html#NotImplemented 

339 return NotImplemented 

340 

341 def __rmatmul__(self, o): 

342 try: 

343 return self._get_as_operand().__rmatmul__(o) 

344 except AttributeError: 

345 # See https://docs.python.org/3/library/constants.html#NotImplemented 

346 return NotImplemented 

347 

348 # TODO(josh11b): Even more operator overloads. 

349 
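# Illustrative note (not part of the original module): because of the overloads
# above, arithmetic on a `DistributedDelegate` (for example a distributed
# variable) inside `strategy.run` delegates to the component on the current
# replica. A minimal sketch, with `strategy` as an assumption:
#
#   with strategy.scope():
#     v = tf.Variable(1.0)
#
#   def replica_fn():
#     # `v + 1.0` resolves v via _get_as_operand(), so each replica computes
#     # with its local copy of the variable.
#     return v + 1.0
#
#   per_replica_result = strategy.run(replica_fn)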

350 

351class PerReplica(DistributedValues, composite_tensor.CompositeTensor, 

352 ds_types.PerReplica): 

353 """Holds a map from replica to unsynchronized values.""" 

354 

355 @property 

356 def _type_spec(self): 

357 return PerReplicaSpec( 

358 *(type_spec.type_spec_from_value(v) for v in self._values)) 

359 

360 @property 

361 def values(self): 

362 """Returns the per replica values.""" 

363 return self._values 

364 
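# Illustrative sketch (not part of the original module): `strategy.run`
# typically returns a `PerReplica` when the replicas produce different values.
# The device list is an assumption for illustration.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#
#   def replica_fn():
#     return tf.distribute.get_replica_context().replica_id_in_sync_group
#
#   result = strategy.run(replica_fn)   # PerReplica, one value per replica
#   components = strategy.experimental_local_results(result)   # tuple of tensors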

365 

366def _per_replica_to_tensor(var, dtype=None, name=None, as_ref=False): 

367 """Converts a `PerReplica` to a `Tensor`.""" 

368 del name 

369 if dtype is not None and not dtype.is_compatible_with(var.dtype): 

370 raise ValueError( 

371 "Incompatible type conversion requested to type {!r} for variable " 

372 "of type {!r}".format(dtype.name, var.dtype.name)) 

373 if as_ref: 

374 raise NotImplementedError( 

375 "PerReplica doesn't support being used as a reference.") 

376 if (distribute_lib.in_cross_replica_context() or 

377 not distribute_lib.has_strategy()): 

378 raise ValueError("It looks like you are using a PerReplica object while " 

379 "not inside a replica context, which is not supported. " 

380 "Try running your op or function inside a replica context " 

381 "by using `strategy.run`") 

382 else: 

383 replica_id = values_util.get_current_replica_id_as_int() 

384 return var.values[replica_id] 

385 

386 # Register a conversion function that converts a PerReplica to a tensor in

387 # replica context and provides a useful error message in other contexts.

388tensor_conversion_registry.register_tensor_conversion_function( 

389 PerReplica, _per_replica_to_tensor) 

390 
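# Illustrative note (not part of the original module): the registration above
# makes implicit conversion of a `PerReplica` to a tensor work only inside a
# replica context; in cross-replica context the value must be reduced or
# unpacked explicitly. A minimal sketch, with `strategy` and `per_replica` as
# assumptions:
#
#   # Cross-replica context: reduce or unpack instead of converting.
#   total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
#   parts = strategy.experimental_local_results(per_replica)
#
#   # Replica context (inside strategy.run): implicit conversion picks the
#   # component belonging to the current replica.
#   tensor = tf.convert_to_tensor(per_replica)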

391 

392class PerReplicaSpec(type_spec.TypeSpec): 

393 """Type specification for a `PerReplica`.""" 

394 

395 __slots__ = ["_value_specs"] 

396 

397 value_type = property(lambda self: PerReplica) 

398 

399 def __init__(self, *value_specs): 

400 self._value_specs = tuple(value_specs) 

401 

402 def _serialize(self): 

403 return self._value_specs 

404 

405 @property 

406 def _component_specs(self): 

407 return self._value_specs 

408 

409 def _to_components(self, value): 

410 replica_context = distribute_lib.get_replica_context() 

411 if replica_context is not None and replica_context.num_replicas_in_sync > 1: 

412 raise ValueError( 

413 "Flattening a PerReplica to components is not supported in replica " 

414 "context.") 

415 return value._values # pylint: disable=protected-access 

416 

417 def _from_components(self, tensor_list): 

418 return PerReplica(tensor_list) 

419 

420 

421nested_structure_coder.register_codec( 

422 nested_structure_coder.BuiltInTypeSpecCodec( 

423 PerReplicaSpec, struct_pb2.TypeSpecProto.PER_REPLICA_SPEC 

424 ) 

425) 

426 
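# Illustrative note (not part of the original module): registering the codec
# lets `PerReplica` values round-trip through `tf.function` and SavedModel
# machinery as composite tensors. A minimal sketch of the spec itself, with
# `per_replica` as an assumption:
#
#   spec = tf.type_spec_from_value(per_replica)
#   # -> PerReplicaSpec(TensorSpec(...), TensorSpec(...)), one component spec
#   #    per replica, as produced by PerReplica._type_spec above.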

427 

428# Note that unlike PerReplica, Mirrored values inherit from 

429# DistributedDelegate and so can be used directly in cross-replica mode. 

430# TODO(tomhennigan) Should this extend CompositeTensor? 

431class Mirrored(DistributedDelegate, ds_types.Mirrored): 

432 """Holds a map from replica to values which are kept in sync.""" 

433 

434 def _get_cross_replica(self): 

435 return self._get_on_device_or_primary() 

436 

437 def _as_graph_element(self): 

438 obj = self._get() 

439 conv_fn = getattr(obj, "_as_graph_element", None) 

440 if conv_fn and callable(conv_fn): 

441 return conv_fn() 

442 return obj 

443 

444 def _is_mirrored(self): 

445 return True 

446 

447 

448class DistributedVarOp(object): 

449 """A class that looks like `tf.Operation`.""" 

450 

451 def __init__(self, name, graph, traceback, typ): 

452 self.name = name 

453 self.graph = graph 

454 self.traceback = traceback 

455 self.type = typ 

456 

457 def __eq__(self, o): 

458 if not isinstance(o, self.__class__): 

459 raise NotImplementedError 

460 return (self.name == o.name and self.graph == o.graph and 

461 self.traceback == o.traceback and self.type == o.type) 

462 

463 def __hash__(self): 

464 return hash((self.name, self.graph, tuple(self.traceback), self.type)) 

465 

466 

467# TODO(b/209081027): Remove this once Variable is a CompositeTensor. 

468class DistributedVariableTraceType(trace.TraceType): 

469 """TraceType of DistributedVariable objects.""" 

470 

471 def __init__(self, distributed_variable): 

472 self.distributed_variable = distributed_variable 

473 self.components = (tuple(distributed_variable.shape.as_list()), 

474 distributed_variable.dtype) 

475 

476 def is_subtype_of(self, other): 

477 return self == other 

478 

479 def most_specific_common_supertype(self, others): 

480 return self if all(self == other for other in others) else None 

481 

482 def placeholder_value(self, placeholder_context=None): 

483 return self.distributed_variable 

484 

485 def _to_tensors(self, value): 

486 return [] 

487 

488 def __hash__(self) -> int: 

489 return hash(self.components) 

490 

491 def __eq__(self, other) -> bool: 

492 if not isinstance(other, DistributedVariableTraceType): 

493 return False 

494 

495 return self.components == other.components 

496 

497 

498class DistributedVariable(DistributedDelegate, variables_lib.Variable, 

499 core.Tensor): 

500 """Holds a map from replica to variables.""" 

501 

502 def __init__(self, strategy, values, aggregation, var_policy=None): 

503 if (aggregation == variables_lib.VariableAggregation.MEAN and 

504 not values[0].dtype.is_floating): 

505 raise ValueError( 

506 "creating distributed tf.Variable with aggregation=MEAN and a " 

507 "non-floating dtype is not supported, please use a different " 

508 "aggregation or dtype") 

509 self._distribute_strategy = strategy 

510 self._aggregation = aggregation 

511 super(DistributedVariable, self).__init__(values) 

512 self._common_name = self._primary.name.split(":")[0] 

513 # Use a weakref to make it easy to map from the contained values 

514 # to the container without introducing a reference cycle. 

515 for v in values: 

516 # ResourceVariable is a CompositeTensor. Attributes added to 

517 # CompositeTensors will get lost through tf.nest packing and unpacking. 

518 if isinstance(v, composite_tensor.CompositeTensor) and hasattr( 

519 v, "handle"): 

520 v.handle._distributed_container = weakref.ref(self) # pylint: disable=protected-access 

521 else: 

522 v._distributed_container = weakref.ref(self) # pylint: disable=protected-access 

523 

524 # Packed variable is used to reduce the overhead of function execution. 

525 # For a DistributedVariable, only one variable handle is captured into a 

526 # function graph. It's only supported in eager mode. 

527 if ops.executing_eagerly_outside_functions() and getattr( 

528 strategy, "_enable_packed_variable_in_eager_mode", False): 

529 name = "%s/packed/" % self._common_name 

530 if hasattr(values[0], "_vars"): 

531 # Handle when the resource variables are "nested" underneath another 

532 # layer of values, e.g., TPUReplicatedVariable, by packing them all

533 # together and pushing the packed var down a level 

534 # pylint: disable=protected-access 

535 packed_var = packed.PackedDistributedVariable( 

536 sum((value._vars for value in values), []), name=name) 

537 for value in values: 

538 value._packed_var = packed_var 

539 self._packed_var = None 

540 # pylint: enable=protected-access 

541 else: 

542 self._packed_var = packed.PackedDistributedVariable(values, name=name) 

543 else: 

544 self._packed_var = None 

545 

546 # tf.keras keeps track of variables initialized using this attribute. When 

547 # tf.keras gets the default session, it initializes all uninitialized vars. 

548 # We need to make _keras_initialized a member of DistributedVariable because 

549 # without this it will use `__getattr__` which will delegate to a component 

550 # variable. 

551 self._keras_initialized = False 

552 # Typically, a `DistributedVariable`'s initializer is composed of the 

553 # initializers of the component variables. However, in some cases, such as

554 # when restoring from a checkpoint, we may set the _initializer_op 

555 # property on the entire `DistributedVariable`. 

556 self._initializer_op = None 

557 # Set a VariablePolicy which decides how we replicate/aggregate the given 

558 # variable. 

559 self._policy = var_policy 

560 

561 def __deepcopy__(self, memo): 

562 """Perform a deepcopy of the `DistributedVariable`. 

563 

564 Unlike the deepcopy of a regular tf.Variable, this keeps the original 

565 strategy and devices of the `DistributedVariable`. To avoid confusion 

566 with the behavior of deepcopy on a regular `Variable` (which does 

567 copy into new devices), we only allow a deepcopy of a `DistributedVariable` 

568 within its originating strategy scope. 

569 

570 Args: 

571 memo: The memoization object for `deepcopy`. 

572 

573 Returns: 

574 A deep copy of the current `DistributedVariable`. 

575 

576 Raises: 

577 RuntimeError: If trying to deepcopy into a different strategy. 

578 """ 

579 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy): 

580 new_values = [] 

581 

582 for value in self._values: 

583 with ops.device(value.device): 

584 new_values.append(copy.deepcopy(value, memo)) 

585 

586 copied_variable = type(self)( 

587 strategy=self._distribute_strategy, 

588 values=new_values, 

589 aggregation=self._aggregation, 

590 var_policy=copy.deepcopy(self._policy, memo)) 

591 

592 memo[id(self)] = copied_variable 

593 

594 return copied_variable 

595 
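# Illustrative sketch (not part of the original module): per __deepcopy__
# above, a distributed variable can be deep-copied while its originating
# strategy (an assumption here) is current, keeping the same devices.
#
#   import copy
#
#   with strategy.scope():
#     v = tf.Variable([1.0, 2.0])
#     v_copy = copy.deepcopy(v)   # same strategy and devices as `v`
#   # Deep-copying while a different strategy's scope is active raises
#   # RuntimeError, as documented above.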

596 def _use_packed_variable(self): 

597 # Don't use packed variable when under a SaveContext to avoid explicit 

598 # device placement on variable consuming ops. 

599 return self._packed_var is not None and ( 

600 not values_util.is_saving_non_distributed()) 

601 

602 def is_initialized(self, name=None): 

603 """Identifies if all the component variables are initialized. 

604 

605 Args: 

606 name: Name of the final `logical_and` op. 

607 

608 Returns: 

609 The op that evaluates to True or False depending on if all the 

610 component variables are initialized. 

611 """ 

612 if values_util.is_saving_non_distributed(): 

613 return self._primary.is_initialized() 

614 if self._use_packed_variable(): 

615 return self._packed_var.is_initialized() 

616 result = self._primary.is_initialized() 

617 # We iterate through the list of values except the last one to allow us to 

618 # name the final `logical_and` op the same name that is passed by the user 

619 # to the `is_initialized` op. For distributed variables, the 

620 # `is_initialized` op is a `logical_and` op. 

621 for v in self._values[1:-1]: 

622 result = math_ops.logical_and(result, v.is_initialized()) 

623 result = math_ops.logical_and( 

624 result, self._values[-1].is_initialized(), name=name) 

625 return result 

626 

627 @property 

628 def initializer(self): 

629 if values_util.is_saving_non_distributed(): 

630 return self._primary.initializer 

631 if self._initializer_op: 

632 init_op = self._initializer_op 

633 else: 

634 # return grouped ops of all the var initializations of component values of 

635 # the mirrored variable 

636 init_op = control_flow_ops.group( 

637 tuple(v.initializer for v in self._values)) 

638 return init_op 

639 

640 def initialized_value(self): 

641 return self._get_on_device_or_primary().initialized_value() 

642 

643 def _is_mirrored(self): 

644 return (self._policy is not None) and (self._policy._is_mirrored()) # pylint: disable=protected-access 

645 

646 @property 

647 def initial_value(self): 

648 return self._get_on_device_or_primary().initial_value 

649 

650 @property 

651 def constraint(self): 

652 return self._primary.constraint 

653 

654 @property 

655 def graph(self): 

656 return self._primary.graph 

657 

658 @property 

659 def _shared_name(self): 

660 return self._common_name 

661 

662 @property 

663 def _unique_id(self): 

664 return self._primary._unique_id # pylint: disable=protected-access 

665 

666 @property 

667 def _graph_key(self): 

668 """Lets Optimizers know which graph this variable is from.""" 

669 return self._primary._graph_key # pylint: disable=protected-access 

670 

671 @property 

672 def name(self): 

673 return self._primary.name 

674 

675 @property 

676 def dtype(self): 

677 return self._primary.dtype 

678 

679 @property 

680 def shape(self): 

681 return self._primary.shape 

682 

683 @property 

684 def synchronization(self): 

685 return self._primary.synchronization 

686 

687 @property 

688 def aggregation(self): 

689 return self._aggregation 

690 

691 @property 

692 def _packed_variable(self): 

693 if self._use_packed_variable(): 

694 return self._packed_var 

695 return None 

696 

697 @property 

698 def handle(self): 

699 if values_util.is_saving_non_distributed(): 

700 return self._primary.handle 

701 replica_id = values_util.get_current_replica_id_as_int() 

702 if replica_id is None: 

703 raise ValueError( 

704 "DistributedVariable.handle is not available outside the replica " 

705 "context or a `tf.distribute.Strategy.update()` call.") 

706 else: 

707 if self._use_packed_variable(): 

708 return self._packed_var.handle 

709 return self._values[replica_id].handle 

710 

711 def eval(self, session=None): 

712 return self._get_on_device_or_primary().eval(session) 

713 

714 @property 

715 def _save_slice_info(self): 

716 return self._primary._save_slice_info # pylint: disable=protected-access 

717 

718 def _get_save_slice_info(self): 

719 return self._primary._get_save_slice_info() # pylint: disable=protected-access 

720 

721 def _set_save_slice_info(self, save_slice_info): 

722 for v in self._values: 

723 v._set_save_slice_info(save_slice_info) # pylint: disable=protected-access 

724 

725 @property 

726 def device(self): 

727 return self._get_on_device_or_primary().device 

728 

729 @property 

730 def trainable(self): 

731 return self._primary.trainable 

732 

733 @property 

734 def distribute_strategy(self): 

735 return self._distribute_strategy 

736 

737 def get_shape(self): 

738 return self._primary.get_shape() 

739 

740 def to_proto(self, export_scope=None): 

741 return self._primary.to_proto(export_scope=export_scope) 

742 

743 @property 

744 def op(self): 

745 if values_util.is_saving_non_distributed(): 

746 return self._primary.op 

747 # We want cross-replica code that does some var.op.X calls 

748 # to work (even if the current device isn't in self._devices), but 

749 # other uses of var.op in a cross-replica context to fail. 

750 if distribute_lib.in_cross_replica_context(): 

751 return DistributedVarOp(self._primary.op.name, self._primary.op.graph, 

752 self._primary.op.traceback, self._primary.op.type) 

753 return self._get().op 

754 

755 @property 

756 def _in_graph_mode(self): 

757 return self._primary._in_graph_mode # pylint: disable=protected-access 

758 

759 def _get_replica(self, replica_id): 

760 """Returns the value on a device with the given replica_id.""" 

761 value = self._values[replica_id] 

762 if self._use_packed_variable(): 

763 return self._packed_var.on_device(value.device) 

764 else: 

765 return value 

766 

767 def _get(self): 

768 """Returns the value for the current device or raises a ValueError.""" 

769 if values_util.is_saving_non_distributed(): 

770 return self._primary 

771 replica_id = values_util.get_current_replica_id_as_int() 

772 if replica_id is None: 

773 return self._get_cross_replica() 

774 else: 

775 return self._get_replica(replica_id) 

776 

777 def _get_on_device_or_primary(self): 

778 """Returns value in same replica or device if possible, else the _primary.""" 

779 if values_util.is_saving_non_distributed(): 

780 return self._primary 

781 replica_id = values_util.get_current_replica_id_as_int() 

782 if replica_id is None: 

783 # Try to find a value on the current device. 

784 current_device = device_util.canonicalize(device_util.current()) 

785 for i, value in enumerate(self._values): 

786 if device_util.canonicalize(value.device) == current_device: 

787 return self._get_replica(i) 

788 return self._get_replica(0) 

789 else: 

790 return self._get_replica(replica_id) 

791 

792 def read_value(self): 

793 if values_util.is_saving_non_distributed(): 

794 return self._primary.read_value() 

795 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy): 

796 return array_ops.identity(self._get()) 

797 

798 def value(self): 

799 if values_util.is_saving_non_distributed(): 

800 return self._primary.value() 

801 if self._policy: 

802 return self._policy.value(self) 

803 return self._get_on_device_or_primary().value() 

804 

805 def numpy(self): 

806 if context.executing_eagerly(): 

807 return self.read_value().numpy() 

808 else: 

809 raise NotImplementedError("DistributedVariable.numpy() is only available " 

810 "when eager execution is enabled.") 

811 

812 def assign_sub(self, value, use_locking=False, name=None, read_value=True): 

813 if values_util.is_saving_non_distributed(): 

814 return self._primary.assign_sub(value, use_locking, name, read_value) 

815 if self._policy: 

816 return self._policy.assign_sub( 

817 self, 

818 value, 

819 use_locking=use_locking, 

820 name=name, 

821 read_value=read_value) 

822 return values_util.on_write_assign_sub( 

823 self, value, use_locking=use_locking, name=name, read_value=read_value) 

824 

825 def assign_add(self, value, use_locking=False, name=None, read_value=True): 

826 if values_util.is_saving_non_distributed(): 

827 return self._primary.assign_add(value, use_locking, name, read_value) 

828 if self._policy: 

829 return self._policy.assign_add( 

830 self, 

831 value, 

832 use_locking=use_locking, 

833 name=name, 

834 read_value=read_value) 

835 return values_util.on_write_assign_add( 

836 self, value, use_locking=use_locking, name=name, read_value=read_value) 

837 

838 def assign(self, value, use_locking=False, name=None, read_value=True): 

839 if values_util.is_saving_non_distributed(): 

840 return self._primary.assign(value, use_locking, name, read_value) 

841 if self._policy: 

842 return self._policy.assign( 

843 self, 

844 value, 

845 use_locking=use_locking, 

846 name=name, 

847 read_value=read_value) 

848 return values_util.on_write_assign( 

849 self, value, use_locking=use_locking, name=name, read_value=read_value) 

850 

851 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 

852 if values_util.is_saving_non_distributed(): 

853 return self._primary.scatter_sub(sparse_delta, use_locking, name) 

854 if self._policy: 

855 return self._policy.scatter_sub( 

856 self, sparse_delta, use_locking=use_locking, name=name) 

857 return values_util.scatter_sub( 

858 self, sparse_delta, use_locking=use_locking, name=name) 

859 

860 def scatter_add(self, sparse_delta, use_locking=False, name=None): 

861 if values_util.is_saving_non_distributed(): 

862 return self._primary.scatter_add(sparse_delta, use_locking, name) 

863 if self._policy: 

864 return self._policy.scatter_add( 

865 self, sparse_delta, use_locking=use_locking, name=name) 

866 return values_util.scatter_add( 

867 self, sparse_delta, use_locking=use_locking, name=name) 

868 

869 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 

870 if values_util.is_saving_non_distributed(): 

871 return self._primary.scatter_mul(sparse_delta, use_locking, name) 

872 if self._policy: 

873 return self._policy.scatter_mul( 

874 self, sparse_delta, use_locking=use_locking, name=name) 

875 return values_util.scatter_mul( 

876 self, sparse_delta, use_locking=use_locking, name=name) 

877 

878 def scatter_div(self, sparse_delta, use_locking=False, name=None): 

879 if values_util.is_saving_non_distributed(): 

880 return self._primary.scatter_div(sparse_delta, use_locking, name) 

881 if self._policy: 

882 return self._policy.scatter_div( 

883 self, sparse_delta, use_locking=use_locking, name=name) 

884 return values_util.scatter_div( 

885 self, sparse_delta, use_locking=use_locking, name=name) 

886 

887 def scatter_min(self, sparse_delta, use_locking=False, name=None): 

888 if values_util.is_saving_non_distributed(): 

889 return self._primary.scatter_min(sparse_delta, use_locking, name) 

890 if self._policy: 

891 return self._policy.scatter_min( 

892 self, sparse_delta, use_locking=use_locking, name=name) 

893 return values_util.scatter_min( 

894 self, sparse_delta, use_locking=use_locking, name=name) 

895 

896 def scatter_max(self, sparse_delta, use_locking=False, name=None): 

897 if values_util.is_saving_non_distributed(): 

898 return self._primary.scatter_max(sparse_delta, use_locking, name) 

899 if self._policy: 

900 return self._policy.scatter_max( 

901 self, sparse_delta, use_locking=use_locking, name=name) 

902 return values_util.scatter_max( 

903 self, sparse_delta, use_locking=use_locking, name=name) 

904 

905 def scatter_update(self, sparse_delta, use_locking=False, name=None): 

906 if values_util.is_saving_non_distributed(): 

907 return self._primary.scatter_update(sparse_delta, use_locking, name) 

908 if self._policy: 

909 return self._policy.scatter_update( 

910 self, sparse_delta, use_locking=use_locking, name=name) 

911 return values_util.scatter_update( 

912 self, sparse_delta, use_locking=use_locking, name=name) 

913 

914 def __tf_tracing_type__(self, _): 

915 return DistributedVariableTraceType(self) 

916 

917 def _gather_saveables_for_checkpoint(self): 

918 """Overrides Trackable method. 

919 

920 This allows both name-based and object-based save and restore of 

921 DistributedVariables. 

922 

923 Returns: 

924 A dictionary mapping attribute names to `SaveableObject` factories. 

925 """ 

926 

927 def _saveable_factory(name=self._common_name): 

928 return _DistributedVariableSaveable(self, self._primary, name) 

929 

930 return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} 

931 

932 def _as_graph_element(self): 

933 if values_util.is_saving_non_distributed(): 

934 return self._primary._as_graph_element() # pylint: disable=protected-access 

935 if self._policy: 

936 return self._policy._as_graph_element(self) # pylint: disable=protected-access 

937 

938 raise NotImplementedError( 

939 "DistributedVariable._as_graph_element requires a valid " 

940 "VariablePolicy. Please set the policy via the `var_policy` argument " 

941 "in the constructor, or override this method in sub-classes which " 

942 "support cross-replica accesses.") 

943 

944 def _get_cross_replica(self): 

945 if values_util.is_saving_non_distributed(): 

946 return self._primary 

947 if self._policy: 

948 return self._policy._get_cross_replica(self) # pylint: disable=protected-access 

949 

950 raise NotImplementedError( 

951 "DistributedVariable._get_cross_replica requires a valid " 

952 "VariablePolicy. Please set the policy via the `var_policy` argument " 

953 "in the constructor, or override this method in sub-classes which " 

954 "support cross-replica accesses.") 

955 

956 def _update_cross_replica(self, update_fn, value, **kwargs): 

957 """Applies updates across replicas. 

958 

959 Args: 

960 update_fn: A callable to pass to `strategy.extended.update` to update the 

961 variable. It should have the same signature as `Variable.assign()`.

962 value: value to be passed to `update_fn`. 

963 **kwargs: remaining arguments to `update_fn`. 

964 

965 Returns: 

966 Updated variable or `tf.Operation`. 

967 """ 

968 values_util.mark_as_unsaveable() 

969 return self.distribute_strategy.extended.update( 

970 self, update_fn, args=(value,), kwargs=kwargs, group=True) 

971 

972 def _update_replica(self, update_fn, value, **kwargs): 

973 """Applies updates in one replica. 

974 

975 Args: 

976 update_fn: A callable to update the variable. It should have the same

977 signature as `Variable.assign()`. 

978 value: value to be passed to `update_fn`. 

979 **kwargs: remaining arguments to `update_fn`. 

980 

981 Returns: 

982 Updated variable or `tf.Operation`. 

983 """ 

984 if self._policy: 

985 return self._policy._update_replica(self, update_fn, value, **kwargs) # pylint: disable=protected-access 

986 raise NotImplementedError( 

987 "DistributedVariable._update_replica requires a valid VariablePolicy. " 

988 "Please set the policy via the `var_policy` argument in the " 

989 "constructor, or override this method in sub-classes which support " 

990 "cross-replica accesses.") 

991 

992 def _update(self, update_fn, value, **kwargs): 

993 """Applies updates depending on the context. 

994 

995 The method calls `_update_replica` in replica context, 

996 `_update_cross_replica` in cross replica context, and `update_fn` in update 

997 context. 

998 

999 If `read_value` is True, the method returns the updated Variable. If 

1000 `read_value` is False, the method returns the update `tf.Operation`. 

1001 

1002 Args: 

1003 update_fn: A callable to pass to `strategy.extended.update` to update the 

1004 variable. It should have the same signature as `Variable.assign()`. 

1005 value: value to be passed to `update_fn`. 

1006 **kwargs: keyword arguments to `update_fn`. 

1007 

1008 Returns: 

1009 Updated variable or `tf.Operation`. 

1010 

1011 """ 

1012 if values_util.is_saving_non_distributed(): 

1013 return update_fn(self._primary, value, **kwargs) 

1014 with distribute_lib.enter_or_assert_strategy(self.distribute_strategy): 

1015 if distribute_lib.in_cross_replica_context(): 

1016 update_replica_id = distribute_lib.get_update_replica_id() 

1017 if update_replica_id is not None: 

1018 replica_value = self._get_replica(update_replica_id) 

1019 return update_fn(replica_value, value, **kwargs) 

1020 return self._update_cross_replica(update_fn, value, **kwargs) 

1021 else: 

1022 values_util.assert_replica_context(self.distribute_strategy) 

1023 return self._update_replica(update_fn, value, **kwargs) 

1024 
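# Illustrative sketch (not part of the original module): in cross-replica
# context the update above routes through `strategy.extended.update`, which
# can also be invoked directly. The names below are assumptions.
#
#   def scaled_assign(var, value):
#     return var.assign(value * 0.5)
#
#   with strategy.scope():
#     v = tf.Variable(1.0)
#   # Runs `scaled_assign` once per component, on that component's device.
#   strategy.extended.update(v, scaled_assign, args=(tf.constant(2.0),))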

1025 def _should_act_as_resource_variable(self): 

1026 """Pass resource_variable_ops.is_resource_variable check.""" 

1027 pass 

1028 

1029 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 

1030 """Converts a variable to a tensor.""" 

1031 if values_util.is_saving_non_distributed(): 

1032 return ops.convert_to_tensor( 

1033 self._primary, dtype=dtype, name=name, as_ref=as_ref) 

1034 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy): 

1035 return ops.convert_to_tensor( 

1036 self._get(), dtype=dtype, name=name, as_ref=as_ref) 

1037 

1038 def __tf_tensor__(self, 

1039 dtype: Optional[dtypes.DType] = None, 

1040 name: Optional[str] = None) -> ops.Tensor: 

1041 return self._dense_var_to_tensor(dtype, name) 

1042 

1043 def _export_to_saved_model_graph(self, 

1044 object_map=None, 

1045 tensor_map=None, 

1046 options=None, 

1047 **kwargs): 

1048 # Initialize for self._primary first, so that obj_map[self._primary] and 

1049 # resource_map[self._primary.handle] contain mapped values. 

1050 resource_list = self._primary._export_to_saved_model_graph( # pylint:disable=protected-access 

1051 object_map=object_map, 

1052 tensor_map=tensor_map, 

1053 options=options, 

1054 **kwargs) 

1055 for v in [v for v in self._values if v != self._primary]: 

1056 if (options.experimental_variable_policy # pylint:disable=protected-access 

1057 ._expand_distributed_variables()): 

1058 resource_list.extend( 

1059 v._export_to_saved_model_graph( # pylint:disable=protected-access 

1060 object_map=object_map, 

1061 tensor_map=tensor_map, 

1062 options=options, 

1063 **kwargs)) # pylint:disable=protected-access 

1064 else: 

1065 object_map[v] = object_map[self._primary] 

1066 tensor_map[v.handle] = tensor_map[self._primary.handle] 

1067 resource_list.append(v.handle) 

1068 object_map[self] = object_map[self._primary] 

1069 tensor_map[self] = tensor_map[self._primary.handle] 

1070 resource_list.append(self) 

1071 if self._packed_var is not None: 

1072 tensor_map[self._packed_var.packed_handle] = tensor_map[ 

1073 self._primary.handle] 

1074 resource_list.append(self._packed_var.packed_handle) 

1075 return resource_list 

1076 

1077 def _write_object_proto(self, proto, options): 

1078 """Update a SavedObject proto for the caller. 

1079 

1080 If a DistributedVariable object supports this method, it will be called when 

1081 saving with a pre-built `SavedObject` proto representing the object, plus an 

1082 instance of `SaveOptions`. This method is then free to modify that proto 

1083 instance. 

1084 

1085 `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally

1086 write out information about their components to the 

1087 `experimental_distributed_variable_components` field of a 

1088 `SavedVariable` (depending on the `SaveOptions` variable policy). 

1089 

1090 Args: 

1091 proto: A pre-built `SavedObject` proto for this object. It is assumed this 

1092 will be a `SavedVariable` instance. 

1093 options: A `SaveOptions` instance. 

1094 """ 

1095 resource_variable_ops.write_object_proto_for_resource_variable( 

1096 self, proto, options) 

1097 if self._is_mirrored(): 

1098 values_util.write_object_proto(self, proto, options) 

1099 

1100 @property 

1101 def is_distributed_variable(self): 

1102 return True 

1103 

1104 def __tf_experimental_restore_capture__( 

1105 self, concrete_function, internal_capture): 

1106 graph = concrete_function.graph 

1107 # Add given distributed variable to captures with given placeholder. 

1108 graph.replace_capture(self, internal_capture) 

1109 record.record_operation( 

1110 "captured_value", [internal_capture], [self], 

1111 backward_function=lambda x: [x], 

1112 forward_function=lambda x: [x]) 

1113 return self 

1114 

1115 

1116# We extend from `saveable_object.SaveableObject` instead of 

1117# `saveable_object_util.ResourceVariableSaveable` since we need to read the 

1118# value of ONREAD variables when saving. `SaveableObject` provides a way to 

1119# specify the function to run to get the value of the variable or tensor at 

1120# saving time. We can use this for both ON_READ and ON_WRITE variables. 

1121# TODO(b/164586507): Consolidate ON_WRITE and ON_READ saving/restoring logic 

1122# if possible. 

1123class _DistributedVariableSaveable(saveable_object.SaveableObject): 

1124 """Class for defining how to restore a DistributedVariable.""" 

1125 

1126 def __init__(self, distributed_variable, primary_variable, name): 

1127 self._distributed_variable = distributed_variable 

1128 if not self._distributed_variable._policy: 

1129 raise ValueError( 

1130 "The VariablePolicy of the argument `distributed_variable` must be " 

1131 "set to create a _DistributedVariableSaveable. Please set it via " 

1132 "the `var_policy` argument in the constructor of DistributedVariable." 

1133 ) 

1134 tensor, spec = distributed_variable._policy.get_saveable( 

1135 distributed_variable, primary_variable, name) 

1136 super(_DistributedVariableSaveable, self).__init__(tensor, spec, name) 

1137 

1138 def restore(self, restored_tensors, restored_shapes): 

1139 """Restore the same value into all variables.""" 

1140 tensor, = restored_tensors 

1141 return self._distributed_variable._policy.get_restore_ops( # pylint: disable=protected-access 

1142 self._distributed_variable, tensor) 

1143 

1144 

1145class _MirroredSaveable(saveable_object.SaveableObject): 

1146 """Class for defining how to restore a MirroredVariable.""" 

1147 

1148 def __init__(self, mirrored_variable, primary_variable, name): 

1149 self._mirrored_variable = mirrored_variable 

1150 tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable, 

1151 primary_variable, name) 

1152 super(_MirroredSaveable, self).__init__(tensor, spec, name) 

1153 

1154 def restore(self, restored_tensors, restored_shapes): 

1155 """Restore the same value into all variables.""" 

1156 tensor, = restored_tensors 

1157 return values_util.get_on_write_restore_ops(self._mirrored_variable, tensor) 

1158 

1159 

1160class MirroredVariable(DistributedVariable, Mirrored): 

1161 """Holds a map from replica to variables whose values are kept in sync.""" 

1162 

1163 def _is_mirrored(self): 

1164 return Mirrored._is_mirrored(self) # Use correct parent class. 

1165 

1166 def _update_replica(self, update_fn, value, **kwargs): 

1167 return _on_write_update_replica(self, update_fn, value, **kwargs) 

1168 

1169 def scatter_min(self, *args, **kwargs): 

1170 if values_util.is_saving_non_distributed(): 

1171 return self._primary.scatter_min(*args, **kwargs) 

1172 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and 

1173 self._aggregation != vs.VariableAggregation.NONE): 

1174 raise NotImplementedError( 

1175 values_util.scatter_error_msg.format( 

1176 op_name="scatter_min", aggregation=self._aggregation)) 

1177 return super(MirroredVariable, self).scatter_min(*args, **kwargs) 

1178 

1179 def scatter_max(self, *args, **kwargs): 

1180 if values_util.is_saving_non_distributed(): 

1181 return self._primary.scatter_max(*args, **kwargs) 

1182 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and 

1183 self._aggregation != vs.VariableAggregation.NONE): 

1184 raise NotImplementedError( 

1185 values_util.scatter_error_msg.format( 

1186 op_name="scatter_max", aggregation=self._aggregation)) 

1187 return super(MirroredVariable, self).scatter_max(*args, **kwargs) 

1188 

1189 def scatter_update(self, *args, **kwargs): 

1190 if values_util.is_saving_non_distributed(): 

1191 return self._primary.scatter_update(*args, **kwargs) 

1192 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and 

1193 self._aggregation != vs.VariableAggregation.NONE): 

1194 raise NotImplementedError( 

1195 values_util.scatter_error_msg.format( 

1196 op_name="scatter_update", aggregation=self._aggregation)) 

1197 return super(MirroredVariable, self).scatter_update(*args, **kwargs) 

1198 

1199 def _get_cross_replica(self): 

1200 # Return identity, to avoid directly exposing the variable to the user and 

1201 # allowing it to be modified by mistake. 

1202 return array_ops.identity(Mirrored._get_cross_replica(self)) 

1203 

1204 def _as_graph_element(self): 

1205 return self._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access 

1206 

1207 def _gather_saveables_for_checkpoint(self): 

1208 """Overrides Trackable method. 

1209 

1210 This allows both name-based and object-based save and restore of 

1211 MirroredVariables. 

1212 

1213 Returns: 

1214 A dictionary mapping attribute names to `SaveableObject` factories. 

1215 """ 

1216 

1217 def _saveable_factory(name=self._common_name): 

1218 return _MirroredSaveable(self, self._primary, name) 

1219 

1220 return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} 

1221 

1222 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 

1223 """Converts a variable to a tensor.""" 

1224 # TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ 

1225 # and ON_WRITE. 

1226 # Try to avoid assignments to and other mutations of MirroredVariable 

1227 # state except through a DistributionStrategy.extended.update() or any of 

1228 # the `assign*` and `scatter*` calls. 

1229 if as_ref: 

1230 # A TF 1.x case where the variable is a boolean variable and used like: 

1231 # tf.cond(v, true_fn, false_fn). 

1232 raise ValueError( 

1233 "You may be using variable created under distribute strategy in TF " 

1234 "1.x control flows. Try explicitly converting the variable to Tensor " 

1235 "using variable.read_value(), or switch to TF 2.x.") 

1236 return ops.convert_to_tensor( 

1237 self._get(), dtype=dtype, name=name, as_ref=as_ref) 

1238 
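# Illustrative sketch (not part of the original module): variables created
# under `MirroredStrategy.scope()` are MirroredVariables. The device list is
# an assumption for illustration.
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   with strategy.scope():
#     v = tf.Variable(1.0)   # one component per replica, kept in sync
#
#   v.assign(3.0)            # cross-replica assign updates every component
#   v.read_value()           # cross-replica read returns an identity of the
#                            # local (or primary) component, per
#                            # _get_cross_replica above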

1239 

1240class _SyncOnReadSaveable(saveable_object.SaveableObject): 

1241 """Class for defining how to restore a SyncOnReadVariable.""" 

1242 

1243 def __init__(self, sync_on_read_variable, name): 

1244 self._sync_on_read_variable = sync_on_read_variable 

1245 tensor, spec = values_util.get_on_read_saveable( 

1246 sync_on_read_variable, sync_on_read_variable._primary, name) 

1247 

1248 super(_SyncOnReadSaveable, self).__init__(tensor, spec, name) 

1249 

1250 def restore(self, restored_tensors, restored_shapes): 

1251 """Restore the same value into all variables.""" 

1252 tensor, = restored_tensors 

1253 return values_util.get_on_read_restore_ops( 

1254 self._sync_on_read_variable, tensor, 

1255 self._sync_on_read_variable.aggregation) 

1256 

1257 

1258class SyncOnReadVariable(DistributedVariable): 

1259 """Holds a map from replica to variables whose values are reduced on save.""" 

1260 

1261 def _update_replica(self, update_fn, value, **kwargs): 

1262 return update_fn(self._get_on_device_or_primary(), value, **kwargs) 

1263 

1264 def _get(self): 

1265 """Returns the value of SyncOnReadVariable based on surrounding context. 

1266 

1267 If called under a non-default replica-context, returns the corresponding 

1268 variable on that replica. 

1269 If called under default replica-context or cross-replica context, returns 

1270 the synced value. 

1271 """ 

1272 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy): 

1273 return super(SyncOnReadVariable, self)._get() 

1274 

1275 # TODO(b/154017756): Make assign behavior in cross-replica context consistent

1276 # with MirroredVariable. 

1277 def assign_sub(self, value, use_locking=False, name=None, read_value=True): 

1278 if values_util.is_saving_non_distributed(): 

1279 return self._primary.assign_sub(value, use_locking, name, read_value) 

1280 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy): 

1281 if (distribute_lib.in_cross_replica_context() and 

1282 not values_util.in_replica_update_context()): 

1283 values_util.mark_as_unsaveable() 

1284 return values_util.on_read_assign_sub_cross_replica( 

1285 self, value, read_value=read_value) 

1286 else: 

1287 return super(SyncOnReadVariable, 

1288 self).assign_sub(value, use_locking, name, read_value) 

1289 

1290 def assign_add(self, value, use_locking=False, name=None, read_value=True): 

1291 if values_util.is_saving_non_distributed(): 

1292 return self._primary.assign_add(value, use_locking, name, read_value) 

1293 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy): 

1294 if (distribute_lib.in_cross_replica_context() and 

1295 not values_util.in_replica_update_context()): 

1296 values_util.mark_as_unsaveable() 

1297 return values_util.on_read_assign_add_cross_replica( 

1298 self, value, read_value=read_value) 

1299 else: 

1300 return super(SyncOnReadVariable, 

1301 self).assign_add(value, use_locking, name, read_value) 

1302 

1303 def assign(self, value, use_locking=False, name=None, read_value=True): 

1304 if values_util.is_saving_non_distributed(): 

1305 return self._primary.assign(value, use_locking, name, read_value) 

1306 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy): 

1307 if (distribute_lib.in_cross_replica_context() and 

1308 not values_util.in_replica_update_context()): 

1309 values_util.mark_as_unsaveable() 

1310 return values_util.on_read_assign_cross_replica( 

1311 self, value, read_value=read_value) 

1312 else: 

1313 return super(SyncOnReadVariable, self).assign(value, use_locking, name, 

1314 read_value) 

1315 

1316 def _scatter_not_implemented(self, method): 

1317 raise NotImplementedError( 

1318 f"Variables with `synchronization=ON_READ` doesn't support `{method}`") 

1319 

1320 def scatter_sub(self, *args, **kwargs): 

1321 if values_util.is_saving_non_distributed(): 

1322 return self._primary.scatter_sub(*args, **kwargs) 

1323 self._scatter_not_implemented("scatter_sub") 

1324 

1325 def scatter_add(self, *args, **kwargs): 

1326 if values_util.is_saving_non_distributed(): 

1327 return self._primary.scatter_add(*args, **kwargs) 

1328 self._scatter_not_implemented("scatter_add") 

1329 

1330 def scatter_mul(self, *args, **kwargs): 

1331 if values_util.is_saving_non_distributed(): 

1332 return self._primary.scatter_mul(*args, **kwargs) 

1333 self._scatter_not_implemented("scatter_mul") 

1334 

1335 def scatter_div(self, *args, **kwargs): 

1336 if values_util.is_saving_non_distributed(): 

1337 return self._primary.scatter_div(*args, **kwargs) 

1338 self._scatter_not_implemented("scatter_div") 

1339 

1340 def scatter_min(self, *args, **kwargs): 

1341 if values_util.is_saving_non_distributed(): 

1342 return self._primary.scatter_min(*args, **kwargs) 

1343 self._scatter_not_implemented("scatter_min") 

1344 

1345 def scatter_max(self, *args, **kwargs): 

1346 if values_util.is_saving_non_distributed(): 

1347 return self._primary.scatter_max(*args, **kwargs) 

1348 self._scatter_not_implemented("scatter_max") 

1349 

1350 def scatter_update(self, *args, **kwargs): 

1351 if values_util.is_saving_non_distributed(): 

1352 return self._primary.scatter_update(*args, **kwargs) 

1353 self._scatter_not_implemented("scatter_update") 

1354 

1355 def value(self): 

1356 if distribute_lib.in_variable_sync_on_read_context(): 

1357 raise NotImplementedError( 

1358 "call `variable.value()` inside variable_sync_on_read_context is not " 

1359 "supported") 

1360 if values_util.is_saving_non_distributed(): 

1361 return self._primary.value() 

1362 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy): 

1363 if (distribute_lib.in_cross_replica_context() and 

1364 not values_util.in_replica_update_context()): 

1365 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 

1366 return self._get_replica(0).value() 

1367 return self._get_cross_replica() 

1368 else: 

1369 # _get_on_device_or_primary() returns a Variable. 

1370 return self._get_on_device_or_primary().value() 

1371 

1372 def read_value(self): 

1373 if distribute_lib.in_variable_sync_on_read_context(): 

1374 raise NotImplementedError( 

1375 "call `variable.read_value()` inside variable_sync_on_read_context is" 

1376 " not supported") 

1377 return super().read_value() 

1378 

1379 def _get_cross_replica(self): 

1380 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 

1381 # Consider returning a tensor value here to make the return value of 

1382 # _get_cross_replica consistent. 

1383 return self._get_replica(0) 

1384 if self._aggregation == vs.VariableAggregation.SUM: 

1385 values_util.mark_as_unsaveable() 

1386 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy): 

1387 return self._distribute_strategy.reduce( 

1388 reduce_util.ReduceOp.from_variable_aggregation(self._aggregation), 

1389 self, 

1390 axis=None) 

1391 

1392 def _as_graph_element(self): 

1393 if values_util.is_saving_non_distributed(): 

1394 return self._primary._as_graph_element() # pylint: disable=protected-access 

1395 # pylint: disable=protected-access 

1396 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy): 

1397 if distribute_lib.in_cross_replica_context(): 

1398 return ops.convert_to_tensor(self._get_cross_replica()) 

1399 return self._get()._as_graph_element() 

1400 

1401 def _gather_saveables_for_checkpoint(self): 

1402 """Overrides Trackable method. 

1403 

1404 This allows both name-based and object-based save and restore of 

1405 `SyncOnReadVariable`s. 

1406 

1407 Returns: 

1408 A dictionary mapping attribute names to `SaveableObject` factories. 

1409 """ 

1410 

1411 def _saveable_factory(name=self._common_name): 

1412 return _SyncOnReadSaveable(self, name) 

1413 

1414 return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} 

1415 

1416 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 

1417 """Converts a SyncOnReadVariable to a tensor.""" 

1418 if values_util.is_saving_non_distributed(): 

1419 return ops.convert_to_tensor( 

1420 self._primary, dtype=dtype, name=name, as_ref=as_ref) 

1421 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy): 

1422 replica_context = distribute_lib.get_replica_context() 

1423 if (replica_context is not None and 

1424 distribute_lib.in_variable_sync_on_read_context()): 

1425 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 

1426 return ops.convert_to_tensor( 

1427 self._get_replica(0), dtype=dtype, name=name, as_ref=as_ref) 

1428 if self._aggregation == vs.VariableAggregation.SUM: 

1429 values_util.mark_as_unsaveable() 

1430 # pylint: disable=protected-access 

1431 reduced = ( 

1432 replica_context.strategy.extended._replica_ctx_all_reduce( 

1433 reduce_util.ReduceOp.from_variable_aggregation( 

1434 self._aggregation), 

1435 self._get().read_value())) 

1436 return ops.convert_to_tensor( 

1437 reduced, dtype=dtype, name=name, as_ref=as_ref) 

1438 

1439 return ops.convert_to_tensor( 

1440 self._get(), dtype=dtype, name=name, as_ref=as_ref) 
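# A minimal usage sketch (not part of this module; a single-CPU MirroredStrategy
# is assumed purely for illustration) of how a sync-on-read variable is written
# in replica context and reduced when read back in cross-replica context:
def _example_sync_on_read_counter():  # hypothetical helper, never called here
  import tensorflow as tf
  strategy = tf.distribute.MirroredStrategy(devices=["/cpu:0"])
  with strategy.scope():
    counter = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)
  # Each replica updates only its own copy; no cross-device traffic happens here.
  strategy.run(lambda: counter.assign_add(1.0))
  # Reading in cross-replica context goes through _get_cross_replica(), which
  # reduces the per-replica copies according to the SUM aggregation.
  return counter.read_value()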

1441 

1442 

1443# Register conversion functions which read the value of the variable, 

1444# allowing instances of these classes to be used as tensors. 

1445# DistributedVariable 

1446def _tensor_conversion_distributed_var(var, 

1447 dtype=None, 

1448 name=None, 

1449 as_ref=False): 

1450 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 

1451 

1452 

1453tensor_conversion_registry.register_tensor_conversion_function( 

1454 DistributedVariable, _tensor_conversion_distributed_var) 

1455 

1456 

1457# MirroredVariables 

1458def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): 

1459 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 

1460 

1461 

1462tensor_conversion_registry.register_tensor_conversion_function( 

1463 MirroredVariable, _tensor_conversion_mirrored) 

1464 

1465 

1466# Mirrored Values 

1467def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False): 

1468 return ops.convert_to_tensor( 

1469 value._get(), dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 

1470 

1471 

1472tensor_conversion_registry.register_tensor_conversion_function( 

1473 Mirrored, _tensor_conversion_mirrored_val) 

1474 

1475 

1476# SyncOnReadVariables 

1477def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False): 

1478 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access 

1479 

1480 

1481tensor_conversion_registry.register_tensor_conversion_function( 

1482 SyncOnReadVariable, _tensor_conversion_sync_on_read) 
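# A minimal sketch (again assuming a single-CPU MirroredStrategy for
# illustration) of what the registrations above enable: a distributed variable
# can be passed anywhere a tensor is expected, with tf.convert_to_tensor
# dispatching to the _dense_var_to_tensor implementations registered here.
def _example_mirrored_variable_as_tensor():  # hypothetical helper, never called here
  import tensorflow as tf
  strategy = tf.distribute.MirroredStrategy(devices=["/cpu:0"])
  with strategy.scope():
    w = tf.Variable(2.0)  # created as a MirroredVariable under MirroredStrategy
  return tf.convert_to_tensor(w) * 3.0  # converts w, then multiplies: 6.0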

1483 

1484 

1485class VariablePolicy(object): 

1486 """Policy defining synchronization and aggregation of a distributed variable. 

1487 

1488 Given `synchronization` and `aggregation` parameters set on a `tf.Variable` 

1489 during variable creation within `tf.distribute` scope, `tf.distribute` creates 

1490 an appropriate policy object and assigns it to the distributed variable. All 

1491 variable operations are delegated to the respective policy object. 

1492 """ 

1493 

1494 def __init__(self, aggregation): 

1495 self._aggregation = aggregation 

1496 

1497 def value(self): 

1498 raise NotImplementedError( 

1499 "VariablePolicy.value should be overriden by sub-classes.") 

1500 

1501 def _is_mirrored(self): 

1502 raise NotImplementedError( 

1503 "VariablePolicy._is_mirrored should be overriden by sub-classes.") 

1504 

1505 def _as_graph_element(self, _): 

1506 raise NotImplementedError( 

1507 "VariablePolicy._as_graph_element should be overriden by sub-classes.") 

1508 

1509 def _get_cross_replica(self, var): 

1510 raise NotImplementedError( 

1511 "VariablePolicy._get_cross_replica should be overriden by sub-classes.") 

1512 

1513 def _update_replica(self, var, update_fn, value, **kwargs): 

1514 raise NotImplementedError( 

1515 "VariablePolicy._update_replica should be overriden by sub-classes.") 

1516 

1517 

1518class OnReadPolicy(VariablePolicy): 

1519 """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization. 

1520 

1521 This policy is created when `synchronization` is set to 

1522 `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the 

1523 values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`, 

1524 `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in `tf.distribute` 

1525 scope. 

1526 """ 

1527 

1528 def _is_mirrored(self): 

1529 return False 

1530 

1531 def value(self, var): 

1532 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy): 

1533 if (distribute_lib.in_cross_replica_context() and 

1534 not values_util.in_replica_update_context()): 

1535 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 

1536 return var._get_replica(0).value() # pylint: disable=protected-access 

1537 return var._get_cross_replica() # pylint: disable=protected-access 

1538 else: 

1539 return var._get_on_device_or_primary().value() # pylint: disable=protected-access 

1540 

1541 def _as_graph_element(self, var): 

1542 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy): 

1543 if distribute_lib.in_cross_replica_context(): 

1544 return ops.convert_to_tensor(var._get_cross_replica()) # pylint: disable=protected-access 

1545 return var._get()._as_graph_element() # pylint: disable=protected-access 

1546 

1547 def _get_cross_replica(self, var): 

1548 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: 

1549 return var._get_replica(0) # pylint: disable=protected-access 

1550 if self._aggregation == vs.VariableAggregation.SUM: 

1551 values_util.mark_as_unsaveable() 

1552 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy): 

1553 return var.distribute_strategy.reduce( 

1554 reduce_util.ReduceOp.from_variable_aggregation(self._aggregation), 

1555 var, 

1556 axis=None) 

1557 

1558 def _update_replica(self, var, update_fn, value, **kwargs): 

1559 return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access 

1560 

1561 def _scatter_not_implemented(self, method): 

1562 raise NotImplementedError(f"ON_READ variables doesn't support `{method}` " 

1563 "in cross replica context") 

1564 

1565 def assign_sub(self, 

1566 var, 

1567 value, 

1568 use_locking=False, 

1569 name=None, 

1570 read_value=True): 

1571 """Subtracts a value from this variable.""" 

1572 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy): 

1573 if (distribute_lib.in_cross_replica_context() and 

1574 not values_util.in_replica_update_context()): 

1575 values_util.mark_as_unsaveable() 

1576 return values_util.on_read_assign_sub_cross_replica( 

1577 var, value, read_value=read_value) 

1578 else: 

1579 return values_util.on_write_assign_sub( 

1580 var, 

1581 value, 

1582 use_locking=use_locking, 

1583 name=name, 

1584 read_value=read_value) 

1585 

1586 def assign_add(self, 

1587 var, 

1588 value, 

1589 use_locking=False, 

1590 name=None, 

1591 read_value=True): 

1592 """Adds a value to this variable.""" 

1593 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy): 

1594 if (distribute_lib.in_cross_replica_context() and 

1595 not values_util.in_replica_update_context()): 

1596 values_util.mark_as_unsaveable() 

1597 return values_util.on_read_assign_add_cross_replica( 

1598 var, value, read_value=read_value) 

1599 else: 

1600 return values_util.on_write_assign_add( 

1601 var, 

1602 value, 

1603 use_locking=use_locking, 

1604 name=name, 

1605 read_value=read_value) 

1606 

1607 def assign(self, var, value, use_locking=False, name=None, read_value=True): 

1608 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy): 

1609 if (distribute_lib.in_cross_replica_context() and 

1610 not values_util.in_replica_update_context()): 

1611 values_util.mark_as_unsaveable() 

1612 return values_util.on_read_assign_cross_replica( 

1613 var, value, read_value=read_value) 

1614 else: 

1615 return values_util.on_write_assign( 

1616 var, 

1617 value, 

1618 use_locking=use_locking, 

1619 name=name, 

1620 read_value=read_value) 

1621 

1622 def scatter_sub(self, *args, **kwargs): 

1623 del args, kwargs 

1624 self._scatter_not_implemented("scatter_sub") 

1625 

1626 def scatter_add(self, *args, **kwargs): 

1627 del args, kwargs 

1628 self._scatter_not_implemented("scatter_add") 

1629 

1630 def scatter_mul(self, *args, **kwargs): 

1631 del args, kwargs 

1632 self._scatter_not_implemented("scatter_mul") 

1633 

1634 def scatter_div(self, *args, **kwargs): 

1635 del args, kwargs 

1636 self._scatter_not_implemented("scatter_div") 

1637 

1638 def scatter_min(self, *args, **kwargs): 

1639 del args, kwargs 

1640 self._scatter_not_implemented("scatter_min") 

1641 

1642 def scatter_max(self, *args, **kwargs): 

1643 del args, kwargs 

1644 self._scatter_not_implemented("scatter_max") 

1645 

1646 def scatter_update(self, *args, **kwargs): 

1647 del args, kwargs 

1648 self._scatter_not_implemented("scatter_update") 

1649 

1650 def get_saveable(self, var, primary_var, name): 

1651 """Create a saveable object for the given variable.""" 

1652 return values_util.get_on_read_saveable(var, primary_var, name) 

1653 

1654 def get_restore_ops(self, var, tensor): 

1655 """Restore the same value into all variables.""" 

1656 return values_util.get_on_read_restore_ops(var, tensor, self._aggregation) 
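# A minimal sketch (single-CPU MirroredStrategy, for illustration only) of the
# two assignment paths handled by OnReadPolicy: assignment in cross-replica
# context goes through the on_read_assign_*_cross_replica helpers (and marks
# the variable unsaveable), while assignment inside strategy.run targets the
# local replica's copy:
def _example_on_read_assignment_paths():  # hypothetical helper, never called here
  import tensorflow as tf
  strategy = tf.distribute.MirroredStrategy(devices=["/cpu:0"])
  with strategy.scope():
    acc = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.MEAN)
  acc.assign(5.0)                            # cross-replica assignment
  strategy.run(lambda: acc.assign_add(1.0))  # per-replica assignment
  return acc.read_value()                    # MEAN-reduced across replicas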

1657 

1658 

1659class OnWritePolicy(VariablePolicy): 

1660 """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization. 

1661 

1662 This policy is created when `synchronization` is set to 

1663 `tf.VariableSynchronization.ON_WRITE` or `tf.VariableSynchronization.AUTO` 

1664 (together with any supported `aggregation`) when creating a `tf.Variable` in 

1665 `tf.distribute` scope. 

1666 """ 

1667 

1668 def _is_mirrored(self): 

1669 return True 

1670 

1671 def value(self, var): 

1672 return var._get_on_device_or_primary().value() # pylint: disable=protected-access 

1673 

1674 def _as_graph_element(self, var): 

1675 return var._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access 

1676 

1677 def _get_cross_replica(self, var): 

1678 # Return identity, to avoid directly exposing the variable to the user and 

1679 # allowing it to be modified by mistake. 

1680 return array_ops.identity(var._get_on_device_or_primary()) # pylint: disable=protected-access 

1681 

1682 def _update_replica(self, var, update_fn, value, **kwargs): 

1683 if var.aggregation == variables_lib.VariableAggregation.NONE: 

1684 return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access 

1685 return _on_write_update_replica(var, update_fn, value, **kwargs) 

1686 

1687 def assign(self, var, value, use_locking=False, name=None, read_value=True): 

1688 return values_util.on_write_assign( 

1689 var, value, use_locking=use_locking, name=name, read_value=read_value) 

1690 

1691 def assign_add(self, 

1692 var, 

1693 value, 

1694 use_locking=False, 

1695 name=None, 

1696 read_value=True): 

1697 return values_util.on_write_assign_add( 

1698 var, value, use_locking=use_locking, name=name, read_value=read_value) 

1699 

1700 def assign_sub(self, 

1701 var, 

1702 value, 

1703 use_locking=False, 

1704 name=None, 

1705 read_value=True): 

1706 return values_util.on_write_assign_sub( 

1707 var, value, use_locking=use_locking, name=name, read_value=read_value) 

1708 

1709 def scatter_sub(self, var, sparse_delta, use_locking=False, name=None): 

1710 return values_util.scatter_sub( 

1711 var, sparse_delta, use_locking=use_locking, name=name) 

1712 

1713 def scatter_add(self, var, sparse_delta, use_locking=False, name=None): 

1714 return values_util.scatter_add( 

1715 var, sparse_delta, use_locking=use_locking, name=name) 

1716 

1717 def scatter_mul(self, var, sparse_delta, use_locking=False, name=None): 

1718 return values_util.scatter_mul( 

1719 var, sparse_delta, use_locking=use_locking, name=name) 

1720 

1721 def scatter_div(self, var, sparse_delta, use_locking=False, name=None): 

1722 return values_util.scatter_div( 

1723 var, sparse_delta, use_locking=use_locking, name=name) 

1724 

1725 def scatter_min(self, var, sparse_delta, use_locking=False, name=None): 

1726 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and 

1727 self._aggregation != vs.VariableAggregation.NONE): 

1728 raise NotImplementedError( 

1729 values_util.scatter_error_msg.format( 

1730 op_name="scatter_min", aggregation=self._aggregation)) 

1731 return values_util.scatter_min( 

1732 var, sparse_delta, use_locking=use_locking, name=name) 

1733 

1734 def scatter_max(self, var, sparse_delta, use_locking=False, name=None): 

1735 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and 

1736 self._aggregation != vs.VariableAggregation.NONE): 

1737 raise NotImplementedError( 

1738 values_util.scatter_error_msg.format( 

1739 op_name="scatter_max", aggregation=self._aggregation)) 

1740 return values_util.scatter_max( 

1741 var, sparse_delta, use_locking=use_locking, name=name) 

1742 

1743 def scatter_update(self, var, sparse_delta, use_locking=False, name=None): 

1744 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and 

1745 self._aggregation != vs.VariableAggregation.NONE): 

1746 raise NotImplementedError( 

1747 values_util.scatter_error_msg.format( 

1748 op_name="scatter_update", aggregation=self._aggregation)) 

1749 return values_util.scatter_update( 

1750 var, sparse_delta, use_locking=use_locking, name=name) 

1751 

1752 def get_saveable(self, var, primary_var, name): 

1753 """Saveable ops for AUTO variables.""" 

1754 return values_util.get_on_write_saveable(var, primary_var, name) 

1755 

1756 def get_restore_ops(self, var, tensor): 

1757 return values_util.get_on_write_restore_ops(var, tensor) 
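# A minimal sketch (single-CPU MirroredStrategy, for illustration only) of an
# ON_WRITE/MEAN variable: when replicas propose different updates inside
# strategy.run, _update_replica (via _on_write_update_replica) aggregates the
# per-replica values and writes the same result to every copy:
def _example_on_write_mean_update():  # hypothetical helper, never called here
  import tensorflow as tf
  strategy = tf.distribute.MirroredStrategy(devices=["/cpu:0"])
  with strategy.scope():
    w = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_WRITE,
        aggregation=tf.VariableAggregation.MEAN)

  def replica_fn():
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    w.assign_add(tf.cast(replica_id, tf.float32))  # differs per replica

  strategy.run(replica_fn)
  return w.read_value()  # every replica holds the MEAN of the proposed updates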

1758 

1759 

1760class PerWorkerResource(): 

1761 """A per-worker CapturableResource class for non-ParameterServer strategy. 

1762 

1763 Resources that populate `host_to_resources` should be instances of classes 

1764 subclassing CapturableResource, although currently this is only used and 

1765 tested for StaticHashTable with TPUStrategy. 

1766 """ 

1767 

1768 def __init__(self, strategy, host_to_resources): 

1769 distribute_lib.distribution_strategy_input_api_counter.get_cell( 

1770 "PerWorkerResource", "TPUDistributedLookupTable").increase_by(1) 

1771 self._strategy = strategy 

1772 self._host_to_resources = host_to_resources 

1773 

1774 def __getattribute__(self, name): 

1775 if name not in ("__init__", "__getattribute__", "_host_to_resources", 

1776 "_strategy", "local_resource"): 

1777 return getattr(self.local_resource(), name) 

1778 return super(PerWorkerResource, self).__getattribute__(name) 

1779 

1780 def __setattr__(self, name, value): 

1781 if name not in ("_strategy", "_host_to_resources"): 

1782 return setattr(self.local_resource(), name, value) 

1783 return super(PerWorkerResource, self).__setattr__(name, value) 

1784 

1785 def local_resource(self): 

1786 """Returns the resource on the local worker.""" 

1787 current_device = device_util.canonicalize(device_util.current()) 

1788 host_device = device_util.canonicalize( 

1789 device_util.get_host_for_device(current_device)) 

1790 return self._host_to_resources.get( 

1791 host_device, 

1792 self._host_to_resources[next(iter(self._host_to_resources))])
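# A minimal sketch (the device string and helper name are assumptions, not part
# of this module) of wiring one StaticHashTable per worker host into a
# PerWorkerResource; attribute access on the wrapper forwards to the table on
# the local worker, falling back to the first entry if the current host is not
# in the mapping:
def _example_per_worker_table(strategy):  # hypothetical helper, never called here
  import tensorflow as tf
  keys = tf.constant([1, 2], dtype=tf.int64)
  vals = tf.constant([10, 20], dtype=tf.int64)
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys, vals), default_value=-1)
  host_to_resources = {
      "/job:localhost/replica:0/task:0/device:CPU:0": table,  # assumed host key
  }
  return PerWorkerResource(strategy, host_to_resources)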