# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains AutoCastVariable, a variable which automatically casts itself."""

import threading
from typing import Optional

import tensorflow.compat.v2 as tf

from keras.src.distribute import distributed_training_utils

# _autocast_dtype.dtype is the dtype AutoCastVariables should be cast to, or
# None if AutoCastVariables should not be cast.
_autocast_dtype = threading.local()


def numpy_text(tensor, is_repr=False):
    """Human readable representation of a tensor's numpy value."""
    if tensor.dtype.is_numpy_compatible:
        text = repr(tensor._numpy()) if is_repr else str(tensor._numpy())
    else:
        text = "<unprintable>"
    if "\n" in text:
        text = "\n" + text
    return text


class AutoCastVariableSpec(tf.types.experimental.TraceType):
    """TraceType for AutoCastVariable, used when tracing with tf.function.

    This class implements the Type for AutoCastVariable used in tracing.
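
    A minimal sketch of the contract this type implements (an illustrative
    addition, not part of the original docstring; it assumes eager execution
    and uses only the methods defined below):

    >>> v = AutoCastVariable(tf.Variable(1.0))
    >>> spec = AutoCastVariableSpec(v)
    >>> spec.placeholder_value() is v
    True
    >>> spec.is_subtype_of(spec)
    True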

46 """ 

47 

48 def __init__(self, value): 

49 self._value = value 

50 

51 def is_subtype_of(self, other) -> bool: 

52 """If the other spec is the same as `self`, return True.""" 

53 return self == other 

54 

55 def most_specific_common_supertype(self, others): 

56 """`self` is the common supertype if all input types match it.""" 

57 return self if all(self == other for other in others) else None 

58 

59 def placeholder_value(self, placeholder_context=None): 

60 """Use the AutoCastVariable value itself as a placeholder.""" 

61 return self._value 

62 

63 def _to_tensors(self, value): 

64 return [] 

65 

66 def __hash__(self) -> int: 

67 return hash(id(self._value)) 

68 

69 def __eq__(self, other) -> bool: 

70 return self is other 

71 

72 

73class AutoCastVariable(tf.Variable, tf.__internal__.types.Tensor): 

74 """Variable that casts itself to a different dtype in applicable contexts. 

75 

76 This class wraps a floating-point `tf.Variable`. It emulates the variable 

77 interface and delegates to the wrapped variable, but it additionally will 

78 cast the wrapped variable under an `enable_auto_cast_variables(dtype)` 

79 context manager. 

80 

81 For example: 

82 

83 >>> v = tf.Variable(1.0, dtype=tf.float32) 

84 >>> v = AutoCastVariable(v) 

85 >>> tf.identity(v).dtype 

86 tf.float32 

87 >>> with enable_auto_cast_variables(tf.float16): 

88 ... tf.identity(v).dtype 

89 tf.float16 
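
    Ops that implicitly read the variable, such as the overloaded arithmetic
    operators defined below, see the same cast (a small illustrative addition
    to the example above):

    >>> with enable_auto_cast_variables(tf.float16):
    ...     (v + 1.0).dtype
    tf.float16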

    The purpose of this class is to allow Keras layers to create variables in
    float32, and automatically cast them to float16 or bfloat16 when the layer
    is called.
    """

    def __init__(self, variable):
        """Creates an AutoCastVariable instance.

        Args:
            variable: A floating-point resource variable to wrap.

        Raises:
            ValueError: If `variable` is not a floating-point resource
                variable.
        """
        if not isinstance(variable, tf.Variable):
            raise ValueError(
                "variable must be of type tf.ResourceVariable, but got: %s"
                % variable
            )
        if not variable.dtype.is_floating:
            raise ValueError(
                "variable must be a floating point variable but has type: %s"
                % variable.dtype.name
            )
        self._variable = variable
        # 'delegate' means AutoCastVariable.op returns self._variable.op,
        # which will raise an AttributeError in Eager (as intended). If set to
        # any other value, AutoCastVariable.op returns that value instead,
        # which is used to set the op attribute in AutoCastVariable.assign().
        self._op = "delegate"

    def _should_cast(self):
        """Returns True if this variable should be cast when accessed."""
        autocast_dtype = getattr(_autocast_dtype, "dtype", None)
        return autocast_dtype is not None and self.dtype != autocast_dtype

    @property
    def dtype(self):
        """The dtype of the underlying variable, before any casts are done."""
        return self._variable.dtype

    @property
    def true_dtype(self):
        """Deprecated alias of `dtype`."""
        return self._variable.dtype

    @property
    def _cast_dtype(self):
        dtype = getattr(_autocast_dtype, "dtype", None)
        return dtype or self._variable.dtype

    def value(self):
        val = self._variable.value()
        if not self._should_cast():
            return val
        return tf.cast(val, self._cast_dtype)

    def read_value(self):
        val = self._variable.read_value()
        return tf.cast(val, self._cast_dtype)

    def sparse_read(self, indices, name=None):
        """Reads the value of this variable sparsely, using `gather`."""
        val = self._variable.sparse_read(indices, name=name)
        return tf.cast(val, self._cast_dtype)

    def gather_nd(self, indices, name=None):
        """Gather slices of the variable into a Tensor."""
        val = self._variable.gather_nd(indices, name=name)
        return tf.cast(val, self._cast_dtype)

    def __getattr__(self, name):
        return getattr(self._variable, name)

    def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
        """Converts this variable to a tensor."""
        if as_ref:
            # This ValueError should not occur in practice since it is
            # impossible to pass as_ref=True using public APIs.
            raise ValueError(
                "Cannot convert AutoCastVariable to a tensor if "
                "as_ref=True is passed to convert_to_tensor"
            )
        if not self._should_cast():
            return tf.convert_to_tensor(self._variable, dtype=dtype, name=name)
        if dtype is not None and not dtype.is_compatible_with(
            self._cast_dtype
        ):
            raise ValueError(
                "Incompatible type conversion requested to type {!r} for "
                "AutoCastVariable which is casted to type {!r}".format(
                    dtype.name, self._cast_dtype.name
                )
            )
        val = tf.convert_to_tensor(
            self._variable, dtype=self._variable.dtype, name=name
        )
        return tf.cast(val, self._cast_dtype)

    def __tf_tensor__(
        self,
        dtype: Optional[tf.dtypes.DType] = None,
        name: Optional[str] = None,
    ) -> tf.Tensor:
        return self._dense_var_to_tensor(dtype=dtype, name=name)

    def _should_act_as_resource_variable(self):
        """Pass resource_variable_ops.is_resource_variable check."""
        pass

    def __repr__(self):
        if tf.executing_eagerly() and not self._in_graph_mode:
            repr_str = (
                "<AutoCastVariable '{v.name}' shape={v.shape} "
                "dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, "
                "numpy={np_repr}>"
            )
            return repr_str.format(
                v=self, np_repr=numpy_text(self.read_value(), is_repr=True)
            )
        else:
            repr_str = (
                "<AutoCastVariable '{v.name}' shape={v.shape} "
                "dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>"
            )
            return repr_str.format(v=self)

    # Method delegations: We delegate the following methods to self._variable.
    # Each of these methods simply calls the same method on self._variable.
    # The base Variable raises NotImplementedError for most of these, so we
    # must override them.
    #
    # We do not define the following methods from Variable for the following
    # reasons:
    # * 'count_up_to': This method only applies to int variables, which cannot
    #   be wrapped with an AutoCastVariable.
    # * 'ref': Instead we inherit the definition from Variable.
    #   If we defined and delegated to Variable, the ref of an
    #   AutoCastVariable would be the same as the ref of the underlying
    #   variable, which would be strange as they are different Python objects.

    def set_shape(self, shape):
        return self._variable.set_shape(shape)

    @property
    def trainable(self):
        return self._variable.trainable

    @property
    def synchronization(self):
        return self._variable.synchronization

    @property
    def aggregation(self):
        return self._variable.aggregation

    def eval(self, session=None):
        return self._variable.eval(session)

    def initialized_value(self):
        return self._variable.initialized_value()

    @property
    def initial_value(self):
        return self._variable.initial_value

    @property
    def constraint(self):
        return self._variable.constraint

    def _apply_assign_update(
        self, update_fn, value, use_locking=None, name=None, read_value=True
    ):
        # TODO(b/146181571): This logic can be simplified once
        # DistributedVariable.assign returns a DistributedVariable. Currently
        # for MirroredStrategy, it returns a Mirrored value.
        if tf.compat.v1.executing_eagerly_outside_functions():
            assign_op = update_fn(value, use_locking, name, False)
            if read_value:
                # We create a new AutoCastVariable with the same underlying
                # tf.Variable. The new AutoCastVariable is identical except
                # the 'op' attribute is defined. This matches the behavior of
                # tf.Variable.assign.
                var = create_autocast_variable(self._variable)
                var._op = assign_op
                return var
            return assign_op

        # Fallback to wrapping the returned variable in graph mode if possible
        assign_var = update_fn(value, use_locking, name, read_value)
        if read_value and tf.__internal__.ops.is_resource_variable(assign_var):
            return create_autocast_variable(assign_var)
        return assign_var

    def _apply_update(self, update_fn, *args, **kwargs):
        update_var = update_fn(*args, **kwargs)
        if tf.compat.v1.executing_eagerly_outside_functions():
            return self

        # Fallback to wrapping the returned variable in graph mode if possible
        if tf.__internal__.ops.is_resource_variable(update_var):
            return create_autocast_variable(update_var)
        return update_var

    def assign(self, value, use_locking=None, name=None, read_value=True):
        return self._apply_assign_update(
            self._variable.assign, value, use_locking, name, read_value
        )

    def assign_add(self, delta, use_locking=None, name=None, read_value=True):
        return self._apply_assign_update(
            self._variable.assign_add, delta, use_locking, name, read_value
        )

    def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
        return self._apply_assign_update(
            self._variable.assign_sub, delta, use_locking, name, read_value
        )

    def scatter_sub(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_sub, sparse_delta, use_locking, name
        )

    def scatter_add(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_add, sparse_delta, use_locking, name
        )

    def scatter_max(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_max, sparse_delta, use_locking, name
        )

    def scatter_min(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_min, sparse_delta, use_locking, name
        )

    def scatter_mul(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_mul, sparse_delta, use_locking, name
        )

    def scatter_div(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_div, sparse_delta, use_locking, name
        )

    def scatter_update(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_update, sparse_delta, use_locking, name
        )

    def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.batch_scatter_update,
            sparse_delta,
            use_locking,
            name,
        )

    def scatter_nd_sub(self, indices, updates, name=None):
        return self._apply_update(
            self._variable.scatter_nd_sub, indices, updates, name
        )

    def scatter_nd_add(self, indices, updates, name=None):
        return self._apply_update(
            self._variable.scatter_nd_add, indices, updates, name
        )

    def scatter_nd_update(self, indices, updates, name=None):
        return self._apply_update(
            self._variable.scatter_nd_update, indices, updates, name
        )

    def load(self, value, session=None):
        return self._variable.load(value, session)

    @property
    def name(self):
        return self._variable.name

    @property
    def _shared_name(self):
        return self._variable._shared_name

    @property
    def initializer(self):
        return self._variable.initializer

    @property
    def device(self):
        return self._variable.device

    @property
    def op(self):
        if self._op == "delegate":
            return self._variable.op
        return self._op

    def _as_graph_element(self):
        graph_element = self._variable._as_graph_element()
        if graph_element is None:
            return self._op
        return graph_element

    @property
    def graph(self):
        return self._variable.graph

    @property
    def shape(self):
        return self._variable.shape

    def get_shape(self):
        return self._variable.get_shape()

    def __tf_tracing_type__(self, context):
        return AutoCastVariableSpec(self)

    def _gather_saveables_for_checkpoint(self):
        # By delegating this method to the wrapped variable, checkpoints with
        # AutoCastVariables are identical to checkpoints with normal
        # variables. Therefore models checkpointed with AutoCastVariables can
        # be restored on models with normal variables, and vice versa.
        return self._variable._gather_saveables_for_checkpoint()

    def _export_to_saved_model_graph(
        self, object_map, tensor_map, options, **kwargs
    ):
        # By delegating this method to the wrapped variable, SavedModels with
        # AutoCastVariables are identical to SavedModels with normal
        # variables.
        resource_list = self._variable._export_to_saved_model_graph(
            object_map, tensor_map, options, **kwargs
        )
        object_map[self] = object_map[self._variable]
        return resource_list

    # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable
    # in to_proto().
    def to_proto(self, export_scope=None):
        return self._variable.to_proto(export_scope)

    def from_proto(self, variable_def, import_scope=None):
        return self._variable.from_proto(variable_def, import_scope)

    # Delegate the private attributes _handle_name and _initializer_op to
    # self._variable. SavedModel sets these attributes when loading a model.
    # For example, it sets _handle_name here:
    # https://github.com/tensorflow/tensorflow/blob/db26bd574fa95b5bdd53c08463dd19407cc0297e/tensorflow/python/keras/saving/saved_model/load.py#L211
    # We need to expose these attributes on AutoCastVariable as well for
    # SavedModel to work properly.
    # TODO(reedwm/kathywu): Find a better way to support SavedModel. Exposing
    # private attributes is hacky and difficult to maintain.
    @property
    def _handle_name(self):
        return self._variable._handle_name

    @_handle_name.setter
    def _handle_name(self, handle_name):
        self._variable._handle_name = handle_name

    @property
    def _initializer_op(self):
        return self._variable._initializer_op

    @_initializer_op.setter
    def _initializer_op(self, initializer_op):
        self._variable._initializer_op = initializer_op

    # Operator overloads:
    # Note we only overload operators that support floating-point types, as
    # non-float variables cannot be wrapped with an AutoCastVariable.
    # Also note: We call read_value() instead of value(), because value()
    # causes gradients not to work properly when TPUStrategy is used:
    # b/143380936

    def __add__(self, o):
        return self.read_value() + o

    def __radd__(self, o):
        return o + self.read_value()

    def __sub__(self, o):
        return self.read_value() - o

    def __rsub__(self, o):
        return o - self.read_value()

    def __mul__(self, o):
        return self.read_value() * o

    def __rmul__(self, o):
        return o * self.read_value()

    def __truediv__(self, o):
        return self.read_value() / o

    def __rtruediv__(self, o):
        return o / self.read_value()

    def __floordiv__(self, o):
        return self.read_value() // o

    def __rfloordiv__(self, o):
        return o // self.read_value()

    def __mod__(self, o):
        return self.read_value() % o

    def __rmod__(self, o):
        return o % self.read_value()

    def __lt__(self, o):
        return self.read_value() < o

    def __le__(self, o):
        return self.read_value() <= o

    def __gt__(self, o):
        return self.read_value() > o

    def __ge__(self, o):
        return self.read_value() >= o

    def __getitem__(self, o):
        return self.read_value()[o]

    def __pow__(self, o, modulo=None):
        return pow(self.read_value(), o, modulo)

    def __rpow__(self, o):
        return pow(o, self.read_value())

    def __neg__(self):
        return -self.read_value()

    def __abs__(self):
        return abs(self.read_value())

    def __div__(self, o):
        try:
            return self.read_value().__div__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented

    def __rdiv__(self, o):
        try:
            return self.read_value().__rdiv__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented

    def __matmul__(self, o):
        try:
            return self.read_value().__matmul__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented

    def __rmatmul__(self, o):
        try:
            return self.read_value().__rmatmul__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented


tf.register_tensor_conversion_function(
    AutoCastVariable, AutoCastVariable._dense_var_to_tensor
)


def create_autocast_variable(variable):
    """Creates an AutoCastVariable that wraps another variable.

    This typically just returns `AutoCastVariable(variable)`. But, if the
    variable is a DistributedVariable or one of its subclasses, we instead
    dynamically create a class that subclasses from both AutoCastVariable and
    variable.__class__. This is so the returned variable will still pass
    `isinstance(variable, variable.__class__)`, which is required for
    DistributedVariables and its subclasses to work properly.

    Args:
        variable: A floating-point resource variable to wrap.

    Returns:
        An AutoCastVariable that wraps the variable.
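
    For a plain (non-distributed) variable, a minimal sketch of the behavior
    (an illustrative addition, not part of the original docstring; assumes
    eager execution):

    >>> v = tf.Variable(1.0, dtype=tf.float32)
    >>> auto_v = create_autocast_variable(v)
    >>> isinstance(auto_v, AutoCastVariable)
    True
    >>> auto_v.dtype
    tf.float32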

580 """ 

581 if not distributed_training_utils.is_distributed_variable(variable): 

582 return AutoCastVariable(variable) 

583 

584 class AutoCastDistributedVariable(AutoCastVariable, variable.__class__): 

585 """An AutoCastVariable that also subclasses from variable.__class__. 

586 

587 variable.__class__ is either a DistributedVariable or an 

588 AggregatingVariable. 

589 """ 

590 

591 def __repr__(self): 

592 

593 return ( 

594 "<AutoCastDistributedVariable dtype={v.dtype.name} " 

595 "dtype_to_cast_to={v._cast_dtype.name} " 

596 "inner_variable={v._variable}>" 

597 ).format(v=self) 

598 

599 return AutoCastDistributedVariable(variable) 

600 

601 

602class enable_auto_cast_variables: 

603 """Context manager which enables the autocasting of `AutoCastVariable`s. 

604 

605 Under this context manager, `AutoCastVariable`s will be cast to `dtype` if 

606 `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast. 
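
    For example (a minimal sketch mirroring the `AutoCastVariable` docstring
    above; the non-floating-point case is added for illustration):

    >>> v = create_autocast_variable(tf.Variable(1.0, dtype=tf.float32))
    >>> with enable_auto_cast_variables(tf.float16):
    ...     tf.identity(v).dtype
    tf.float16
    >>> with enable_auto_cast_variables(tf.int32):
    ...     tf.identity(v).dtype
    tf.float32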

607 """ 

608 

609 __slots__ = ["_dtype", "_prev_dtype"] 

610 

611 def __init__(self, dtype): 

612 if dtype and not dtype.is_floating: 

613 dtype = None 

614 self._dtype = dtype 

615 

616 def __enter__(self): 

617 self._prev_dtype = getattr(_autocast_dtype, "dtype", None) 

618 _autocast_dtype.dtype = self._dtype 

619 

620 def __exit__(self, type_arg, value_arg, traceback_arg): 

621 _autocast_dtype.dtype = self._prev_dtype 

622