Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/tpu_values.py: 32%

260 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Various classes representing TPU distributed values. 

16 

17Note that the tests are in values_test.py.

18 

19""" 

20 

21from tensorflow.python.distribute import packed_distributed_variable as packed 

22from tensorflow.python.distribute import tpu_replicated_variable 

23from tensorflow.python.distribute import tpu_util 

24from tensorflow.python.distribute import values 

25from tensorflow.python.distribute import values_util 

26from tensorflow.python.eager import context 

27from tensorflow.python.eager import tape 

28from tensorflow.python.framework import ops 

29from tensorflow.python.ops import gen_resource_variable_ops 

30from tensorflow.python.ops import math_ops 

31from tensorflow.python.ops import variable_scope 

32 

33 

34_scatter_error_msg = ("{op_name} is only supported for a distributed "

35 "variable (a variable created within a "

36 "`tf.distribute.Strategy` scope) with NONE "

37 "aggregation, got: {aggregation}.")

38 

39 

40class TPUVariableMixin(object): 

41 """Mixin for TPU variables.""" 

42 

43 def __init__(self, *args, **kwargs): 

44 super(TPUVariableMixin, self).__init__(*args, **kwargs) 

45 

46 # Handle ID is needed for `get_replicated_var_handle` to cache the variables 

47 # correctly since in eager mode different variables can have the same name. 

48 if ops.executing_eagerly_outside_functions(): 

49 self._handle_id = self._common_name + "_" + str(id(self._primary)) 

50 else: 

51 self._handle_id = self._common_name 
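
The eager-mode branch above exists because variable names are not unique keys under eager execution. A minimal, TPU-free sketch of that fact (the variable names are illustrative):

import tensorflow as tf

a = tf.Variable(1.0, name="v")
b = tf.Variable(2.0, name="v")   # a distinct variable with the same user-visible name
print(a.name, b.name)            # both print "v:0" in eager mode
print(id(a) != id(b))            # True: id() disambiguates, mirroring the _handle_id scheme above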

52 

53 def __getattr__(self, name): 

54 if tpu_util.enclosing_tpu_context() is None: 

55 return super(TPUVariableMixin, self).__getattr__(name) 

56 else: 

57 raise AttributeError( 

58 f"`TPUVariableMixin.{name}` not accessible within a TPU context.") 

59 

60 def get(self): 

61 if tpu_util.enclosing_tpu_context() is None: 

62 return super(TPUVariableMixin, self).get() 

63 else: 

64 raise NotImplementedError( 

65 "`TPUVariableMixin.get()` is not supported within a TPU context.") 

66 

67 def _get_as_operand(self): 

68 return self.read_value() 

69 

70 @property 

71 def handle(self): 

72 """The handle by which this variable can be accessed.""" 

73 # If we're in a tpu.rewrite(), return the replicated handle. 

74 tpu_context = tpu_util.enclosing_tpu_context() 

75 if tpu_context is None or context.executing_eagerly(): 

76 var = self._get_on_device_or_primary() 

77 if isinstance(var, packed.PackedVarAndDevice): 

78 return var.on_device_handle() 

79 else: 

80 return var.handle 

81 else: 

82 is_packed = self._packed_var is not None 

83 val = self._values 

84 if is_packed: 

85 val = [self._packed_var] 

86 

87 return tpu_context.get_replicated_var_handle(self._common_name, 

88 self._handle_id, val, 

89 self._is_mirrored(), 

90 is_packed) 
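
A hedged sketch of the two branches, assuming the `strategy` from the setup sketch near the top of this file: outside a TPU context `handle` is a concrete per-device resource handle, while inside `strategy.run` reads and writes go through the replicated handle returned by `get_replicated_var_handle`.

with strategy.scope():
    w = tf.Variable(3.0)

print(w.handle.device)        # outside a TPU context: a concrete device, e.g. ".../device:TPU:0"

@tf.function
def read():
    return w.read_value()     # inside the replica context: served via the replicated handle

strategy.run(read)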

91 

92 @property 

93 def device(self): 

94 return self.handle.device 

95 

96 def _read_variable_op(self): 

97 """Reads the value of this variable.""" 

98 if self.trainable: 

99 tape.variable_accessed(self) 

100 

101 handle = self.handle 

102 if getattr(handle, "is_packed", False): 

103 # Add a device scope for a packed variable handle. 

104 with ops.device(self._get_on_device_or_primary().device): 

105 return gen_resource_variable_ops.read_variable_op(handle, self.dtype) 

106 else: 

107 return gen_resource_variable_ops.read_variable_op(handle, self.dtype) 

108 

109 def read_value(self): 

110 if tpu_util.enclosing_tpu_context() is None: 

111 return super(TPUVariableMixin, self).read_value() 

112 else: 

113 return self._read_variable_op() 

114 

115 def value(self): 

116 if tpu_util.enclosing_tpu_context() is None: 

117 return super(TPUVariableMixin, self).value() 

118 else: 

119 return self._read_variable_op() 

120 

121 def _as_graph_element(self): 

122 if tpu_util.enclosing_tpu_context() is None: 

123 return super(TPUVariableMixin, self)._as_graph_element() # pylint: disable=protected-access 

124 else: 

125 return None 

126 

127 @property 

128 def op(self): 

129 if values_util.is_saving_non_distributed(): 

130 return self._primary.op 

131 return values.DistributedVarOp(self._primary.op.name, 

132 self._primary.op.graph, 

133 self._primary.op.traceback, 

134 self._primary.op.type) 

135 

136 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 

137 """Converts a variable to a tensor.""" 

138 # pylint: disable=protected-access 

139 if tpu_util.enclosing_tpu_context() is None: 

140 return super(TPUVariableMixin, self)._dense_var_to_tensor( 

141 dtype=dtype, name=name, as_ref=as_ref) 

142 # pylint: enable=protected-access 

143 elif dtype is not None and dtype != self.dtype: 

144 return math_ops.cast(self.read_value(), dtype) 

145 else: 

146 return self.handle if as_ref else self.read_value() 
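
This converter is what lets the distributed variable participate in ordinary ops; the dtype-cast branch only applies inside a TPU context. A small sketch, again assuming the `strategy` setup above:

with strategy.scope():
    x = tf.Variable([1.0, 2.0])

t = tf.convert_to_tensor(x)   # outside a TPU context: delegates to the parent class (a read)
y = x + 1.0                   # implicit conversion also routes through _dense_var_to_tensor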

147 

148 

149class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable): 

150 """DistributedVariable subclass for TPUStrategy.""" 

151 

152 def assign_sub(self, value, use_locking=False, name=None, read_value=True): 

153 if values_util.is_saving_non_distributed(): 

154 return self._primary.assign_sub(value, use_locking, name, read_value) 

155 return self._policy.assign_sub( 

156 self, value, use_locking=use_locking, name=name, read_value=read_value) 

157 

158 def assign_add(self, value, use_locking=False, name=None, read_value=True): 

159 if values_util.is_saving_non_distributed(): 

160 return self._primary.assign_add(value, use_locking, name, read_value) 

161 return self._policy.assign_add( 

162 self, value, use_locking=use_locking, name=name, read_value=read_value) 

163 

164 def assign(self, value, use_locking=False, name=None, read_value=True): 

165 if values_util.is_saving_non_distributed(): 

166 return self._primary.assign(value, use_locking, name, read_value) 

167 return self._policy.assign( 

168 self, value, use_locking=use_locking, name=name, read_value=read_value) 

169 

170 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 

171 if values_util.is_saving_non_distributed(): 

172 return self._primary.scatter_sub(sparse_delta, use_locking, name) 

173 return self._policy.scatter_sub( 

174 self, sparse_delta, use_locking=use_locking, name=name) 

175 

176 def scatter_add(self, sparse_delta, use_locking=False, name=None): 

177 if values_util.is_saving_non_distributed(): 

178 return self._primary.scatter_add(sparse_delta, use_locking, name) 

179 return self._policy.scatter_add( 

180 self, sparse_delta, use_locking=use_locking, name=name) 

181 

182 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 

183 if values_util.is_saving_non_distributed(): 

184 return self._primary.scatter_mul(sparse_delta, use_locking, name) 

185 return self._policy.scatter_mul( 

186 self, sparse_delta, use_locking=use_locking, name=name) 

187 

188 def scatter_div(self, sparse_delta, use_locking=False, name=None): 

189 if values_util.is_saving_non_distributed(): 

190 return self._primary.scatter_div(sparse_delta, use_locking, name) 

191 return self._policy.scatter_div( 

192 self, sparse_delta, use_locking=use_locking, name=name) 

193 

194 def scatter_min(self, sparse_delta, use_locking=False, name=None): 

195 if values_util.is_saving_non_distributed(): 

196 return self._primary.scatter_min(sparse_delta, use_locking, name) 

197 return self._policy.scatter_min( 

198 self, sparse_delta, use_locking=use_locking, name=name) 

199 

200 def scatter_max(self, sparse_delta, use_locking=False, name=None): 

201 if values_util.is_saving_non_distributed(): 

202 return self._primary.scatter_max(sparse_delta, use_locking, name) 

203 return self._policy.scatter_max( 

204 self, sparse_delta, use_locking=use_locking, name=name) 

205 

206 def scatter_update(self, sparse_delta, use_locking=False, name=None): 

207 if values_util.is_saving_non_distributed(): 

208 return self._primary.scatter_update(sparse_delta, use_locking, name) 

209 return self._policy.scatter_update( 

210 self, sparse_delta, use_locking=use_locking, name=name) 
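
Every method above simply forwards to the variable's update policy (see TPUOnWritePolicy / TPUOnReadPolicy below). A hedged usage sketch for the scatter path, assuming the `strategy` setup above and a variable created with NONE aggregation:

with strategy.scope():
    table = tf.Variable(tf.zeros([4, 2]), aggregation=tf.VariableAggregation.NONE)

delta = tf.IndexedSlices(values=tf.ones([2, 2]), indices=tf.constant([0, 3]))
table.scatter_add(delta)   # forwarded to the policy's scatter_add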

211 

212 

213class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable): 

214 """Holds a map from replica to TPU variables whose values are kept in sync.""" 

215 

216 def _is_replicated_or_sharded_to_logical_cores(self): 

217 """Returns whether each of the underlying variables is replicated or sharded to logical cores. 

218 

219 If True, the handles of the underlying variables are not available outside a 

220 TPU context. 

221 """ 

222 return isinstance(self._primary, 

223 tpu_replicated_variable.TPUReplicatedVariable) 

224 

225 @property 

226 def device(self): 

227 if (self._is_replicated_or_sharded_to_logical_cores() and 

228 tpu_util.enclosing_tpu_context() is None): 

229 return self._primary.device 

230 return super(TPUMirroredVariable, self).device 

231 

232 def assign_sub(self, value, use_locking=False, name=None, read_value=True): 

233 tpu_context = tpu_util.enclosing_tpu_context() 

234 if (self._is_replicated_or_sharded_to_logical_cores() and 

235 tpu_context is None): 

236 assign_sub_fn = lambda v, *a, **ka: v.assign_sub(*a, **ka) 

237 return self._update( 

238 update_fn=assign_sub_fn, 

239 value=value, 

240 use_locking=use_locking, 

241 name=name, 

242 read_value=read_value) 

243 

244 if (tpu_context and 

245 self.aggregation == variable_scope.VariableAggregation.NONE): 

246 return tpu_util.make_raw_assign_fn( 

247 gen_resource_variable_ops.assign_sub_variable_op)( 

248 self, 

249 value=value, 

250 use_locking=use_locking, 

251 name=name, 

252 read_value=read_value) 

253 return assign_sub( 

254 self, value, use_locking=use_locking, name=name, read_value=read_value) 

255 

256 def assign_add(self, value, use_locking=False, name=None, read_value=True): 

257 tpu_context = tpu_util.enclosing_tpu_context() 

258 if (self._is_replicated_or_sharded_to_logical_cores() and 

259 tpu_context is None): 

260 assign_add_fn = lambda v, *a, **ka: v.assign_add(*a, **ka) 

261 return self._update( 

262 update_fn=assign_add_fn, 

263 value=value, 

264 use_locking=use_locking, 

265 name=name, 

266 read_value=read_value) 

267 

268 if (tpu_context and 

269 self.aggregation == variable_scope.VariableAggregation.NONE): 

270 return tpu_util.make_raw_assign_fn( 

271 gen_resource_variable_ops.assign_add_variable_op)( 

272 self, 

273 value=value, 

274 use_locking=use_locking, 

275 name=name, 

276 read_value=read_value) 

277 return assign_add( 

278 self, value, use_locking=use_locking, name=name, read_value=read_value) 

279 

280 def assign(self, value, use_locking=False, name=None, read_value=True): 

281 tpu_context = tpu_util.enclosing_tpu_context() 

282 if (self._is_replicated_or_sharded_to_logical_cores() and 

283 tpu_context is None): 

284 assign_fn = lambda v, *a, **ka: v.assign(*a, **ka) 

285 return self._update( 

286 update_fn=assign_fn, 

287 value=value, 

288 use_locking=use_locking, 

289 name=name, 

290 read_value=read_value) 

291 

292 if (tpu_util.enclosing_tpu_context() and 

293 self.aggregation == variable_scope.VariableAggregation.NONE): 

294 return tpu_util.make_raw_assign_fn( 

295 gen_resource_variable_ops.assign_variable_op)( 

296 self, 

297 value=value, 

298 use_locking=use_locking, 

299 name=name, 

300 read_value=read_value) 

301 return assign( 

302 self, value, use_locking=use_locking, name=name, read_value=read_value) 
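
A sketch of the fast path above, assuming the `strategy` setup near the top of the file: an assignment issued inside `strategy.run` to a variable with NONE aggregation uses the raw resource assign op on the replicated handle; other aggregations fall back to the module-level `assign()` helper defined below.

with strategy.scope():
    counter = tf.Variable(0.0, aggregation=tf.VariableAggregation.NONE)

@tf.function
def step():
    counter.assign(5.0)   # raw assign_variable_op on the replicated handle

strategy.run(step)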

303 

304 def scatter_sub(self, *args, **kwargs): 

305 if values_util.is_saving_non_distributed(): 

306 return self._primary.scatter_sub(*args, **kwargs) 

307 raise NotImplementedError 

308 

309 def scatter_add(self, *args, **kwargs): 

310 if values_util.is_saving_non_distributed(): 

311 return self._primary.scatter_add(*args, **kwargs) 

312 raise NotImplementedError 

313 

314 def scatter_max(self, *args, **kwargs): 

315 if values_util.is_saving_non_distributed(): 

316 return self._primary.scatter_max(*args, **kwargs) 

317 raise NotImplementedError 

318 

319 def scatter_min(self, *args, **kwargs): 

320 if values_util.is_saving_non_distributed(): 

321 return self._primary.scatter_min(*args, **kwargs) 

322 raise NotImplementedError 

323 

324 def scatter_mul(self, *args, **kwargs): 

325 if values_util.is_saving_non_distributed(): 

326 return self._primary.scatter_mul(*args, **kwargs) 

327 raise NotImplementedError 

328 

329 def scatter_div(self, *args, **kwargs): 

330 if values_util.is_saving_non_distributed(): 

331 return self._primary.scatter_div(*args, **kwargs) 

332 raise NotImplementedError 

333 

334 def scatter_update(self, *args, **kwargs): 

335 if values_util.is_saving_non_distributed(): 

336 return self._primary.scatter_update(*args, **kwargs) 

337 raise NotImplementedError 

338 

339 

340class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable): 

341 """Holds a map from replica to variables whose values are reduced on save.""" 

342 

343 def assign_sub(self, *args, **kwargs): 

344 if tpu_util.enclosing_tpu_context() is None: 

345 return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs) 

346 else: 

347 return tpu_util.make_raw_assign_fn( 

348 gen_resource_variable_ops.assign_sub_variable_op)(self, *args, 

349 **kwargs) 

350 

351 def assign_add(self, *args, **kwargs): 

352 if tpu_util.enclosing_tpu_context() is None: 

353 return values.SyncOnReadVariable.assign_add(self, *args, **kwargs) 

354 else: 

355 return tpu_util.make_raw_assign_fn( 

356 gen_resource_variable_ops.assign_add_variable_op)(self, *args, 

357 **kwargs) 

358 

359 def assign(self, *args, **kwargs): 

360 if tpu_util.enclosing_tpu_context() is None: 

361 return values.SyncOnReadVariable.assign(self, *args, **kwargs) 

362 else: 

363 return tpu_util.make_raw_assign_fn( 

364 gen_resource_variable_ops.assign_variable_op)(self, *args, **kwargs) 
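
Sync-on-read variables typically back metric accumulators and batch-norm statistics. A hedged sketch of how one is created and updated per replica, assuming the `strategy` setup above:

with strategy.scope():
    total = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)

@tf.function
def accumulate():
    total.assign_add(1.0)   # per-replica raw assign inside the TPU context

strategy.run(accumulate)
print(total.read_value())   # components are aggregated (summed here) across replicas on read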

365 

366 

367# Common methods shared by OnWrite and Mirrored variables.

368def assign_sub(var, value, use_locking=False, name=None, read_value=True): 

369 assign_sub_fn = tpu_util.make_raw_assign_fn( 

370 gen_resource_variable_ops.assign_sub_variable_op) 

371 return var._update( # pylint: disable=protected-access 

372 update_fn=assign_sub_fn, 

373 value=value, 

374 use_locking=use_locking, 

375 name=name, 

376 read_value=read_value) 

377 

378 

379def assign_add(var, value, use_locking=False, name=None, read_value=True): 

380 assign_add_fn = tpu_util.make_raw_assign_fn( 

381 gen_resource_variable_ops.assign_add_variable_op) 

382 return var._update( # pylint: disable=protected-access 

383 update_fn=assign_add_fn, 

384 value=value, 

385 use_locking=use_locking, 

386 name=name, 

387 read_value=read_value) 

388 

389 

390def assign(var, value, use_locking=False, name=None, read_value=True): 

391 assign_fn = tpu_util.make_raw_assign_fn( 

392 gen_resource_variable_ops.assign_variable_op) 

393 return var._update( # pylint: disable=protected-access 

394 update_fn=assign_fn, 

395 value=value, 

396 use_locking=use_locking, 

397 name=name, 

398 read_value=read_value) 
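
All three helpers funnel the update through `var._update`, which is what gives mirrored variables their write-to-every-component semantics in cross-replica context. The following is a conceptual sketch only, not the real `_update` implementation, and the `var.values` attribute access is an assumption made for illustration:

def _update_sketch(var, update_fn, value, **kwargs):
    # Conceptual: apply the raw update to every component variable, one per device.
    results = []
    for component in var.values:          # component variables of the distributed variable (assumed accessor)
        with tf.device(component.device):
            results.append(update_fn(component, value, **kwargs))
    return results[0]                     # the real code returns a value consistent with read_value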

399 

400 

401class TPUOnWritePolicy(values.OnWritePolicy): 

402 """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization. 

403 

404 This policy is created when `synchronization` is set to 

405 `tf.VariableSynchronization.AUTO` or `tf.VariableSynchronization.ON_WRITE`. 

406 """ 

407 

408 def assign_sub(self, 

409 var, 

410 value, 

411 use_locking=False, 

412 name=None, 

413 read_value=True): 

414 if (tpu_util.enclosing_tpu_context() and 

415 var.aggregation == variable_scope.VariableAggregation.NONE): 

416 return tpu_util.make_raw_assign_fn( 

417 gen_resource_variable_ops.assign_sub_variable_op)( 

418 var, 

419 value=value, 

420 use_locking=use_locking, 

421 name=name, 

422 read_value=read_value) 

423 return assign_sub( 

424 var, value, use_locking=use_locking, name=name, read_value=read_value) 

425 

426 def assign_add(self, 

427 var, 

428 value, 

429 use_locking=False, 

430 name=None, 

431 read_value=True): 

432 if (tpu_util.enclosing_tpu_context() and 

433 var.aggregation == variable_scope.VariableAggregation.NONE): 

434 return tpu_util.make_raw_assign_fn( 

435 gen_resource_variable_ops.assign_add_variable_op)( 

436 var, 

437 value=value, 

438 use_locking=use_locking, 

439 name=name, 

440 read_value=read_value) 

441 return assign_add( 

442 var, value, use_locking=use_locking, name=name, read_value=read_value) 

443 

444 def assign(self, var, value, use_locking=False, name=None, read_value=True): 

445 if (tpu_util.enclosing_tpu_context() and 

446 var.aggregation == variable_scope.VariableAggregation.NONE): 

447 return tpu_util.make_raw_assign_fn( 

448 gen_resource_variable_ops.assign_variable_op)( 

449 var, 

450 value=value, 

451 use_locking=use_locking, 

452 name=name, 

453 read_value=read_value) 

454 return assign( 

455 var, value, use_locking=use_locking, name=name, read_value=read_value) 

456 

457 def _scatter_xxx(self, 

458 raw_scatter_xxx_fn,

459 op_name, 

460 var, 

461 sparse_delta, 

462 use_locking=False, 

463 name=None): 

464 scatter_xxx_fn = tpu_util.make_raw_scatter_xxx_fn(raw_scatter_xxx_fn)

465 if tpu_util.enclosing_tpu_context(): 

466 if self._aggregation != variable_scope.VariableAggregation.NONE: 

467 raise NotImplementedError( 

468 _scatter_error_msg.format( 

469 op_name=op_name, aggregation=self._aggregation)) 

470 return scatter_xxx_fn(

471 var, sparse_delta=sparse_delta, use_locking=use_locking, name=name) 

472 else: 

473 return var._update( # pylint: disable=protected-access 

474 update_fn=scatter_xxx_fn,

475 value=sparse_delta, 

476 use_locking=use_locking, 

477 name=name) 

478 

479 def scatter_sub(self, var, sparse_delta, use_locking=False, name=None): 

480 return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_sub, 

481 "scatter_sub", var, sparse_delta, use_locking, 

482 name) 

483 

484 def scatter_add(self, var, sparse_delta, use_locking=False, name=None): 

485 return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_add, 

486 "scatter_add", var, sparse_delta, use_locking, 

487 name) 

488 

489 def scatter_max(self, var, sparse_delta, use_locking=False, name=None): 

490 return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_max, 

491 "scatter_max", var, sparse_delta, use_locking, 

492 name) 

493 

494 def scatter_min(self, var, sparse_delta, use_locking=False, name=None): 

495 return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_min, 

496 "scatter_min", var, sparse_delta, use_locking, 

497 name) 

498 

499 def scatter_mul(self, var, sparse_delta, use_locking=False, name=None): 

500 return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_mul, 

501 "scatter_mul", var, sparse_delta, use_locking, 

502 name) 

503 

504 def scatter_div(self, var, sparse_delta, use_locking=False, name=None): 

505 return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_div, 

506 "scatter_div", var, sparse_delta, use_locking, 

507 name) 

508 

509 def scatter_update(self, var, sparse_delta, use_locking=False, name=None): 

510 return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_update, 

511 "scatter_update", var, sparse_delta, use_locking, 

512 name) 

513 

514 

515class TPUOnReadPolicy(values.OnReadPolicy): 

516 """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization. 

517 

518 This policy is created when `synchronization` is set to 

519 `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the 

520 values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`, 

521 `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in `tf.distribute`

522 scope. 

523 """ 

524 

525 def assign_sub(self, var, *args, **kwargs): 

526 if tpu_util.enclosing_tpu_context() is None: 

527 return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs) 

528 else: 

529 return tpu_util.make_raw_assign_fn( 

530 gen_resource_variable_ops.assign_sub_variable_op)(var, *args, 

531 **kwargs) 

532 

533 def assign_add(self, var, *args, **kwargs): 

534 if tpu_util.enclosing_tpu_context() is None: 

535 return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs) 

536 else: 

537 return tpu_util.make_raw_assign_fn( 

538 gen_resource_variable_ops.assign_add_variable_op)(var, *args, 

539 **kwargs) 

540 

541 def assign(self, var, *args, **kwargs): 

542 if tpu_util.enclosing_tpu_context() is None: 

543 return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs) 

544 else: 

545 return tpu_util.make_raw_assign_fn( 

546 gen_resource_variable_ops.assign_variable_op)(var, *args, **kwargs) 

547 

548 def scatter_sub(self, *args, **kwargs): 

549 raise NotImplementedError 

550 

551 def scatter_add(self, *args, **kwargs): 

552 raise NotImplementedError 

553 

554 def scatter_max(self, *args, **kwargs): 

555 raise NotImplementedError 

556 

557 def scatter_min(self, *args, **kwargs): 

558 raise NotImplementedError 

559 

560 def scatter_mul(self, *args, **kwargs): 

561 raise NotImplementedError 

562 

563 def scatter_div(self, *args, **kwargs): 

564 raise NotImplementedError 

565 

566 def scatter_update(self, *args, **kwargs): 

567 raise NotImplementedError