# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains AutoCastVariable, a variable which automatically casts itself."""

import threading

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.types import core


# _autocast_dtype.dtype is the dtype AutoCastVariables should be cast to, or
# None if AutoCastVariables should not be cast.
_autocast_dtype = threading.local()
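# A sketch of the protocol (illustrative only; in practice this thread-local
# is set exclusively by the `enable_auto_cast_variables` context manager
# defined at the bottom of this file, never assigned directly):
#   _autocast_dtype.dtype = <a floating-point DType>  # reads will now cast
#   _autocast_dtype.dtype = None                      # casting disabled again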


def numpy_text(tensor, is_repr=False):
  """Human readable representation of a tensor's numpy value.
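
  A minimal doctest sketch (assumes eager execution; `tf` here refers to the
  public TensorFlow API, as in the other doctests in this file):

  >>> numpy_text(tf.constant(1.0))
  '1.0'
  """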
  if tensor.dtype.is_numpy_compatible:
    # pylint: disable=protected-access
    text = repr(tensor._numpy()) if is_repr else str(tensor._numpy())
    # pylint: enable=protected-access
  else:
    text = '<unprintable>'
  if '\n' in text:
    text = '\n' + text
  return text


class AutoCastVariable(variables.Variable, core.Tensor):
  """Variable that will cast itself to a different dtype in applicable contexts.

  This class wraps a floating-point `tf.Variable`. It emulates the variable
  interface and delegates to the wrapped variable, but it additionally will cast
  the wrapped variable under an `enable_auto_cast_variables(dtype)` context
  manager.

  For example:

  >>> v = tf.Variable(1.0, dtype=tf.float32)
  >>> v = AutoCastVariable(v)
  >>> tf.identity(v).dtype
  tf.float32
  >>> with enable_auto_cast_variables(tf.float16):
  ...   tf.identity(v).dtype
  tf.float16

  The purpose of this class is to allow Keras layers to create variables in
  float32, and automatically cast them to float16 or bfloat16 when the layer is
  called.
  """

  def __init__(self, variable):
    """Creates an AutoCastVariable instance.

    Args:
      variable: A floating-point resource variable to wrap.

    Raises:
      ValueError: If `variable` is not a floating-point resource variable.
    """
    if not isinstance(variable, variables.Variable):
      raise ValueError('variable must be of type tf.ResourceVariable, but got: '
                       '%s' % variable)
    if not variable.dtype.is_floating:
      raise ValueError('variable must be a floating point variable but has '
                       'type: %s' % variable.dtype.name)
    self._variable = variable
    # 'delegate' means AutoCastVariable.op returns self._variable.op, which
    # will raise an AttributeError in Eager mode (as intended). If set to any
    # other value, AutoCastVariable.op returns that value instead, which is
    # used to set the op attribute in AutoCastVariable.assign().
    self._op = 'delegate'

  def _should_cast(self):
    """Returns True if this variable should be cast when accessed."""
    autocast_dtype = getattr(_autocast_dtype, 'dtype', None)
    return autocast_dtype is not None and self.dtype != autocast_dtype

  @property
  def dtype(self):
    """The dtype of the underlying variable, before any casts are done."""
    return self._variable.dtype

  @property
  def true_dtype(self):
    """Deprecated alias of `dtype`."""
    return self._variable.dtype

  @property
  def _cast_dtype(self):
    dtype = getattr(_autocast_dtype, 'dtype', None)
    return dtype or self._variable.dtype

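  # Note: `value()` skips the cast entirely when no autocast dtype is active,
  # while `read_value()` always calls `math_ops.cast`; in that case
  # `_cast_dtype` equals the variable's own dtype, so the cast is a no-op.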
  def value(self):
    val = self._variable.value()
    if not self._should_cast():
      return val
    return math_ops.cast(val, self._cast_dtype)

  def read_value(self):
    val = self._variable.read_value()
    return math_ops.cast(val, self._cast_dtype)

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    val = self._variable.sparse_read(indices, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def gather_nd(self, indices, name=None):
    """Gather slices of the variable into a Tensor."""
    val = self._variable.gather_nd(indices, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def __getattr__(self, name):
    return getattr(self._variable, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts this variable to a tensor."""
    if as_ref:
      # This ValueError should not occur in practice since it is impossible to
      # pass as_ref=True using public APIs.
      raise ValueError('Cannot convert AutoCastVariable to a tensor if '
                       'as_ref=True is passed to convert_to_tensor')
    if not self._should_cast():
      return tensor_conversion.convert_to_tensor_v2_with_dispatch(
          self._variable, dtype=dtype, name=name
      )
    if dtype is not None and not dtype.is_compatible_with(self._cast_dtype):
      raise ValueError(
          'Incompatible type conversion requested to type {!r} for '
          'AutoCastVariable which is cast to type {!r}'.format(
              dtype.name, self._cast_dtype.name))
    val = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        self._variable, dtype=self._variable.dtype, name=name
    )
    return math_ops.cast(val, self._cast_dtype)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def __repr__(self):
    if context.executing_eagerly() and not self._in_graph_mode:
      repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, '
                  'numpy={np_repr}>')
      return repr_str.format(
          v=self, np_repr=numpy_text(self.read_value(), is_repr=True))
    else:
      repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>')
      return repr_str.format(v=self)

  # Method delegations: We delegate the following methods to self._variable.
  # Each of these methods simply calls the same method on self._variable. The
  # base Variable raises NotImplementedError for most of these, so we must
  # override them.
  #
  # We do not define the following methods from Variable for the following
  # reasons:
  #   * 'count_up_to': This method only applies to int variables, which cannot
  #     be wrapped with an AutoCastVariable.
  #   * 'ref': Instead we inherit the definition from Variable.
  #     If we defined and delegated to Variable, the ref of an AutoCastVariable
  #     would be the same as the ref of the underlying variable, which would be
  #     strange as they are different Python objects.

  def set_shape(self, shape):
    return self._variable.set_shape(shape)

  @property
  def trainable(self):
    return self._variable.trainable

  @property
  def synchronization(self):
    return self._variable.synchronization

  @property
  def aggregation(self):
    return self._variable.aggregation

  def eval(self, session=None):
    return self._variable.eval(session)

  def initialized_value(self):
    return self._variable.initialized_value()

  @property
  def initial_value(self):
    return self._variable.initial_value

  @property
  def constraint(self):
    return self._variable.constraint

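  # The two helpers below implement the assign/scatter update protocol. A
  # minimal usage sketch (illustrative; `tf` refers to the public TensorFlow
  # API as in the doctests above):
  #   v = create_autocast_variable(tf.Variable(1.))
  #   v.assign(2.)            # eager: a fresh AutoCastVariable with `op` set
  #   v.scatter_add(delta)    # eager: returns `v` itself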
  def _apply_assign_update(self,
                           update_fn,
                           value,
                           use_locking=None,
                           name=None,
                           read_value=True):
    # TODO(b/146181571): This logic can be simplified once
    # DistributedVariable.assign returns a DistributedVariable. Currently for
    # MirroredStrategy, it returns a Mirrored value.
    if ops.executing_eagerly_outside_functions():
      assign_op = update_fn(value, use_locking, name, False)
      if read_value:
        # We create a new AutoCastVariable with the same underlying tf.Variable.
        # The new AutoCastVariable is identical except the 'op' attribute is
        # defined. This matches the behavior of tf.Variable.assign.
        var = create_autocast_variable(self._variable)
        var._op = assign_op  # pylint:disable=protected-access
        return var
      return assign_op

    # Fallback to wrapping the returned variable in graph mode if possible
    assign_var = update_fn(value, use_locking, name, read_value)
    if read_value and resource_variable_ops.is_resource_variable(assign_var):
      return create_autocast_variable(assign_var)
    return assign_var

  def _apply_update(self, update_fn, *args, **kwargs):
    update_var = update_fn(*args, **kwargs)
    if ops.executing_eagerly_outside_functions():
      return self

    # Fallback to wrapping the returned variable in graph mode if possible
    if resource_variable_ops.is_resource_variable(update_var):
      return create_autocast_variable(update_var)
    return update_var

  def assign(self, value, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign, value, use_locking,
                                     name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign_add, delta,
                                     use_locking, name, read_value)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign_sub, delta,
                                     use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_sub, sparse_delta,
                              use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_add, sparse_delta,
                              use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_max, sparse_delta,
                              use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_min, sparse_delta,
                              use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_mul, sparse_delta,
                              use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_div, sparse_delta,
                              use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_update, sparse_delta,
                              use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.batch_scatter_update, sparse_delta,
                              use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_sub, indices, updates,
                              name)

  def scatter_nd_add(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_add, indices, updates,
                              name)

  def scatter_nd_update(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_update, indices,
                              updates, name)

  def load(self, value, session=None):
    return self._variable.load(value, session)

  @property
  def name(self):
    return self._variable.name

  @property
  def _shared_name(self):
    return self._variable._shared_name  # pylint:disable=protected-access

  @property
  def initializer(self):
    return self._variable.initializer

  @property
  def device(self):
    return self._variable.device

  @property
  def op(self):
    if self._op == 'delegate':
      return self._variable.op
    return self._op

  def _as_graph_element(self):
    graph_element = self._variable._as_graph_element()  # pylint:disable=protected-access
    if graph_element is None:
      return self._op
    return graph_element

  @property
  def graph(self):
    return self._variable.graph

  @property
  def shape(self):
    return self._variable.shape

  def get_shape(self):
    return self._variable.get_shape()

  def _gather_saveables_for_checkpoint(self):
    # By delegating this method to the wrapped variable, checkpoints with
    # AutoCastVariables are identical to checkpoints with normal variables.
    # Therefore models checkpointed with AutoCastVariables can be restored on
    # models with normal variables, and vice versa.
    return self._variable._gather_saveables_for_checkpoint()  # pylint:disable=protected-access

  def _export_to_saved_model_graph(self, object_map, tensor_map, options,
                                   **kwargs):
    # By delegating this method to the wrapped variable, SavedModels with
    # AutoCastVariables are identical to SavedModels with normal variables.
    resource_list = self._variable._export_to_saved_model_graph(  # pylint:disable=protected-access
        object_map, tensor_map, options, **kwargs)
    object_map[self] = object_map[self._variable]
    return resource_list

  # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in
  # to_proto().
  def to_proto(self, export_scope=None):
    return self._variable.to_proto(export_scope)

  def from_proto(self, variable_def, import_scope=None):
    return self._variable.from_proto(variable_def, import_scope)

  # Delegate the private attributes _handle_name and _initializer_op to
  # self._variable. SavedModel sets these attributes when loading a model. For
  # example, it sets _handle_name here:
  # https://github.com/tensorflow/tensorflow/blob/db26bd574fa95b5bdd53c08463dd19407cc0297e/tensorflow/python/keras/saving/saved_model/load.py#L211
  # We need to expose these attributes on AutoCastVariable as well for
  # SavedModel to work properly.
  # TODO(reedwm/kathywu): Find a better way to support SavedModel. Exposing
  # private attributes is hacky and difficult to maintain.
  @property
  def _handle_name(self):
    return self._variable._handle_name  # pylint: disable=protected-access

  @_handle_name.setter
  def _handle_name(self, handle_name):
    self._variable._handle_name = handle_name  # pylint: disable=protected-access

  @property
  def _initializer_op(self):
    return self._variable._initializer_op  # pylint: disable=protected-access

  @_initializer_op.setter
  def _initializer_op(self, initializer_op):
    self._variable._initializer_op = initializer_op  # pylint: disable=protected-access

  # Operator overloads:
  # Note we only overload operators that support floating-point types, as
  # non-float variables cannot be wrapped with an AutoCastVariable.
  # Also note: We call read_value() instead of value(), because value() causes
  # gradients not to work properly when TPUStrategy is used: b/143380936
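  # For example (sketch): under `enable_auto_cast_variables(tf.float16)`,
  # `v + 1.` reads the variable as float16 and returns a float16 tensor.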

  def __add__(self, o):
    return self.read_value() + o

  def __radd__(self, o):
    return o + self.read_value()

  def __sub__(self, o):
    return self.read_value() - o

  def __rsub__(self, o):
    return o - self.read_value()

  def __mul__(self, o):
    return self.read_value() * o

  def __rmul__(self, o):
    return o * self.read_value()

  def __truediv__(self, o):
    return self.read_value() / o

  def __rtruediv__(self, o):
    return o / self.read_value()

  def __floordiv__(self, o):
    return self.read_value() // o

  def __rfloordiv__(self, o):
    return o // self.read_value()

  def __mod__(self, o):
    return self.read_value() % o

  def __rmod__(self, o):
    return o % self.read_value()

  def __lt__(self, o):
    return self.read_value() < o

  def __le__(self, o):
    return self.read_value() <= o

  def __gt__(self, o):
    return self.read_value() > o

  def __ge__(self, o):
    return self.read_value() >= o

  def __getitem__(self, o):
    return self.read_value()[o]

  def __pow__(self, o, modulo=None):
    return pow(self.read_value(), o, modulo)

  def __rpow__(self, o):
    return pow(o, self.read_value())

  def __neg__(self):
    return -self.read_value()  # pylint: disable=invalid-unary-operand-type

  def __abs__(self):
    return abs(self.read_value())

  def __div__(self, o):
    try:
      return self.read_value().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self.read_value().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self.read_value().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self.read_value().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented


tensor_conversion_registry.register_tensor_conversion_function(
    AutoCastVariable, AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access
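# Illustrative effect of the registration above (a sketch; `tf` refers to the
# public TensorFlow API):
#   v = create_autocast_variable(tf.Variable(1.0))
#   t = tf.convert_to_tensor(v)  # dispatches to _dense_var_to_tensor
#   # t.dtype is float32 outside any autocast context, and the autocast dtype
#   # inside an `enable_auto_cast_variables(...)` block.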


def create_autocast_variable(variable):
  """Creates an AutoCastVariable that wraps another variable.

  This typically just returns `AutoCastVariable(variable)`. But, if the variable
  is a DistributedVariable or one of its subclasses, we instead dynamically
  create a class that subclasses from both AutoCastVariable and
  variable.__class__. This is so the returned variable will still pass
  `isinstance(variable, variable.__class__)`, which is required for
  DistributedVariable and its subclasses to work properly.
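
  For example (an illustrative doctest, assuming the public `tf` API as in
  the other doctests in this file):

  >>> v = create_autocast_variable(tf.Variable(1.0))
  >>> isinstance(v, AutoCastVariable)
  True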

  Args:
    variable: A floating-point resource variable to wrap.

  Returns:
    An AutoCastVariable that wraps the variable.
  """
  if not distributed_training_utils.is_distributed_variable(variable):
    return AutoCastVariable(variable)

  class AutoCastDistributedVariable(AutoCastVariable, variable.__class__):
    """An AutoCastVariable that also subclasses from variable.__class__.

    variable.__class__ is either a DistributedVariable or an
    AggregatingVariable.
    """

    def __repr__(self):
      # pylint: disable=missing-format-attribute
      return ('<AutoCastDistributedVariable dtype={v.dtype.name} '
              'dtype_to_cast_to={v._cast_dtype.name} '
              'inner_variable={v._variable}>'
             ).format(v=self)
      # pylint: enable=missing-format-attribute

  return AutoCastDistributedVariable(variable)


class enable_auto_cast_variables(object):  # pylint:disable=invalid-name
  """Context manager which enables the autocasting of `AutoCastVariable`s.

  Under this context manager, `AutoCastVariable`s will be cast to `dtype` if
  `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
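
  For example (an illustrative doctest; `tf` refers to the public TensorFlow
  API, as in the other doctests in this file):

  >>> v = AutoCastVariable(tf.Variable(1.0, dtype=tf.float32))
  >>> with enable_auto_cast_variables(tf.float16):
  ...   tf.identity(v).dtype
  tf.float16
  >>> tf.identity(v).dtype
  tf.float32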
  """

  __slots__ = ['_dtype', '_prev_dtype']

  def __init__(self, dtype):
    if dtype and not dtype.is_floating:
      dtype = None
    self._dtype = dtype

  def __enter__(self):
    self._prev_dtype = getattr(_autocast_dtype, 'dtype', None)
    _autocast_dtype.dtype = self._dtype

  def __exit__(self, type_arg, value_arg, traceback_arg):
    _autocast_dtype.dtype = self._prev_dtype