# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains AutoCastVariable, a variable which automatically casts itself."""

import threading
from typing import Optional

import tensorflow.compat.v2 as tf

from keras.src.distribute import distributed_training_utils

# _autocast_dtype.dtype is the dtype AutoCastVariables should be cast to, or
# None if AutoCastVariables should not be cast.
_autocast_dtype = threading.local()
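
# The gate above is per-thread: setting `_autocast_dtype.dtype` on one thread
# does not affect variable reads on other threads. A minimal sketch (normally
# the attribute is set via `enable_auto_cast_variables` below, not directly):
#
#     _autocast_dtype.dtype = tf.float16  # this thread now casts reads
#     # A worker thread started here still sees dtype=None (no casting).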


def numpy_text(tensor, is_repr=False):
    """Human readable representation of a tensor's numpy value."""
    if tensor.dtype.is_numpy_compatible:
        text = repr(tensor._numpy()) if is_repr else str(tensor._numpy())
    else:
        text = "<unprintable>"
    if "\n" in text:
        text = "\n" + text
    return text


class AutoCastVariableSpec(tf.types.experimental.TraceType):
    """TraceType for AutoCastVariableSpec for tracing with tf.function.

    This class implements the Type for AutoCastVariable used in tracing.
    """

    def __init__(self, value):
        self._value = value

    def is_subtype_of(self, other) -> bool:
        """If the other spec is the same as `self`, return True."""
        return self == other

    def most_specific_common_supertype(self, others):
        """`self` is the common supertype if all input types match it."""
        return self if all(self == other for other in others) else None

    def placeholder_value(self, placeholder_context=None):
        """Use the AutoCastVariable value itself as a placeholder."""
        return self._value

    def _to_tensors(self, value):
        return []

    def __hash__(self) -> int:
        return hash(id(self._value))

    def __eq__(self, other) -> bool:
        return self is other
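

# Because `__eq__` above is identity-based, two distinct AutoCastVariables
# never compare equal as trace types, so `tf.function` retraces per variable
# object. A minimal sketch (`f`, `v1`, and `v2` are hypothetical):
#
#     @tf.function
#     def f(v):
#         return v + 1.0
#
#     f(v1)  # traces once for v1
#     f(v2)  # retraces, even if v2 matches v1's shape and dtype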


class AutoCastVariable(tf.Variable, tf.__internal__.types.Tensor):
    """Variable that casts itself to a different dtype in applicable contexts.

    This class wraps a floating-point `tf.Variable`. It emulates the variable
    interface and delegates to the wrapped variable, but it additionally will
    cast the wrapped variable under an `enable_auto_cast_variables(dtype)`
    context manager.

    For example:

    >>> v = tf.Variable(1.0, dtype=tf.float32)
    >>> v = AutoCastVariable(v)
    >>> tf.identity(v).dtype
    tf.float32
    >>> with enable_auto_cast_variables(tf.float16):
    ...     tf.identity(v).dtype
    tf.float16

    The purpose of this class is to allow Keras layers to create variables in
    float32, and automatically cast them to float16 or bfloat16 when the layer
    is called.
    """

    def __init__(self, variable):
        """Creates an AutoCastVariable instance.

        Args:
          variable: A floating-point resource variable to wrap.

        Raises:
          ValueError: If `variable` is not a floating-point resource variable.
        """
        if not isinstance(variable, tf.Variable):
            raise ValueError(
                "variable must be of type tf.ResourceVariable, but got: %s"
                % variable
            )
        if not variable.dtype.is_floating:
            raise ValueError(
                "variable must be a floating point variable but has type: %s"
                % variable.dtype.name
            )
        self._variable = variable
        # 'delegate' means AutoCastVariable.op returns self._variable.op,
        # which will raise an AttributeError in Eager (as intended). If set
        # to any other value, AutoCastVariable.op returns that value instead,
        # which is used to set the op attribute in AutoCastVariable.assign().
        self._op = "delegate"

    def _should_cast(self):
        """Returns True if this variable should be cast when accessed."""
        autocast_dtype = getattr(_autocast_dtype, "dtype", None)
        return autocast_dtype is not None and self.dtype != autocast_dtype

    @property
    def dtype(self):
        """The dtype of the underlying variable, before any casts are done."""
        return self._variable.dtype

    @property
    def true_dtype(self):
        """Deprecated alias of `dtype`."""
        return self._variable.dtype

    @property
    def _cast_dtype(self):
        dtype = getattr(_autocast_dtype, "dtype", None)
        return dtype or self._variable.dtype

    def value(self):
        val = self._variable.value()
        if not self._should_cast():
            return val
        return tf.cast(val, self._cast_dtype)

    def read_value(self):
        val = self._variable.read_value()
        return tf.cast(val, self._cast_dtype)

    def sparse_read(self, indices, name=None):
        """Reads the value of this variable sparsely, using `gather`."""
        val = self._variable.sparse_read(indices, name=name)
        return tf.cast(val, self._cast_dtype)

    def gather_nd(self, indices, name=None):
        """Gather slices of the variable into a Tensor."""
        val = self._variable.gather_nd(indices, name=name)
        return tf.cast(val, self._cast_dtype)

    def __getattr__(self, name):
        return getattr(self._variable, name)
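
    # Attributes not defined on this wrapper resolve on the wrapped variable
    # via `__getattr__` above. A minimal sketch (assumes `v` wraps a resource
    # variable, whose `handle` attribute is not defined on AutoCastVariable):
    #
    #     v = AutoCastVariable(tf.Variable(1.0))
    #     v.handle  # delegates to v._variable.handle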

    def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
        """Converts this variable to a tensor."""
        if as_ref:
            # This ValueError should not occur in practice since it is
            # impossible to pass as_ref=True using public APIs.
            raise ValueError(
                "Cannot convert AutoCastVariable to a tensor if "
                "as_ref=True is passed to convert_to_tensor"
            )
        if not self._should_cast():
            return tf.convert_to_tensor(self._variable, dtype=dtype, name=name)
        if dtype is not None and not dtype.is_compatible_with(self._cast_dtype):
            raise ValueError(
                "Incompatible type conversion requested to type {!r} for "
                "AutoCastVariable which is casted to type {!r}".format(
                    dtype.name, self._cast_dtype.name
                )
            )
        val = tf.convert_to_tensor(
            self._variable, dtype=self._variable.dtype, name=name
        )
        return tf.cast(val, self._cast_dtype)
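
    # Under autocast, `tf.convert_to_tensor(v)` therefore yields a tensor of
    # the cast dtype, not the storage dtype. A minimal sketch (`v` is a
    # hypothetical AutoCastVariable stored as float32):
    #
    #     with enable_auto_cast_variables(tf.float16):
    #         t = tf.convert_to_tensor(v)  # t.dtype == tf.float16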

    def __tf_tensor__(
        self,
        dtype: Optional[tf.dtypes.DType] = None,
        name: Optional[str] = None,
    ) -> tf.Tensor:
        return self._dense_var_to_tensor(dtype=dtype, name=name)

    def _should_act_as_resource_variable(self):
        """Pass resource_variable_ops.is_resource_variable check."""
        pass

    def __repr__(self):
        if tf.executing_eagerly() and not self._in_graph_mode:
            repr_str = (
                "<AutoCastVariable '{v.name}' shape={v.shape} "
                "dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, "
                "numpy={np_repr}>"
            )
            return repr_str.format(
                v=self, np_repr=numpy_text(self.read_value(), is_repr=True)
            )
        else:
            repr_str = (
                "<AutoCastVariable '{v.name}' shape={v.shape} "
                "dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>"
            )
            return repr_str.format(v=self)

    # Method delegations: We delegate the following methods to self._variable.
    # Each of these methods simply calls the same method on self._variable.
    # The base Variable raises NotImplementedError for most of these, so we
    # must override them.
    #
    # We do not define the following methods from Variable for the following
    # reasons:
    # * 'count_up_to': This method only applies to int variables, which
    #   cannot be wrapped with an AutoCastVariable.
    # * 'ref': Instead we inherit the definition from Variable. If we defined
    #   it and delegated to the wrapped variable, the ref of an
    #   AutoCastVariable would be the same as the ref of the underlying
    #   variable, which would be strange as they are different Python objects.

    def set_shape(self, shape):
        # Note: `set_shape` is a bound method on the wrapped variable, so it
        # must not be passed `self` as an extra argument.
        return self._variable.set_shape(shape)

    @property
    def trainable(self):
        return self._variable.trainable

    @property
    def synchronization(self):
        return self._variable.synchronization

    @property
    def aggregation(self):
        return self._variable.aggregation

    def eval(self, session=None):
        return self._variable.eval(session)

    def initialized_value(self):
        return self._variable.initialized_value()

    @property
    def initial_value(self):
        return self._variable.initial_value

    @property
    def constraint(self):
        return self._variable.constraint

    def _apply_assign_update(
        self, update_fn, value, use_locking=None, name=None, read_value=True
    ):
        # TODO(b/146181571): This logic can be simplified once
        # DistributedVariable.assign returns a DistributedVariable. Currently
        # for MirroredStrategy, it returns a Mirrored value.
        if tf.compat.v1.executing_eagerly_outside_functions():
            assign_op = update_fn(value, use_locking, name, False)
            if read_value:
                # We create a new AutoCastVariable with the same underlying
                # tf.Variable. The new AutoCastVariable is identical except
                # the 'op' attribute is defined. This matches the behavior of
                # tf.Variable.assign.
                var = create_autocast_variable(self._variable)
                var._op = assign_op
                return var
            return assign_op

        # Fallback to wrapping the returned variable in graph mode if possible
        assign_var = update_fn(value, use_locking, name, read_value)
        if read_value and tf.__internal__.ops.is_resource_variable(assign_var):
            return create_autocast_variable(assign_var)
        return assign_var
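
    # Consequence in eager mode (a minimal sketch; `v` is hypothetical):
    # `v.assign(2.0)` returns a *new* AutoCastVariable wrapping the same
    # storage, with `op` set to the assign op:
    #
    #     v2 = v.assign(2.0)
    #     v2 is v                       # False: a fresh wrapper object
    #     v2._variable is v._variable   # True: same underlying tf.Variable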

    def _apply_update(self, update_fn, *args, **kwargs):
        update_var = update_fn(*args, **kwargs)
        if tf.compat.v1.executing_eagerly_outside_functions():
            return self

        # Fallback to wrapping the returned variable in graph mode if possible
        if tf.__internal__.ops.is_resource_variable(update_var):
            return create_autocast_variable(update_var)
        return update_var

    def assign(self, value, use_locking=None, name=None, read_value=True):
        return self._apply_assign_update(
            self._variable.assign, value, use_locking, name, read_value
        )

    def assign_add(self, delta, use_locking=None, name=None, read_value=True):
        return self._apply_assign_update(
            self._variable.assign_add, delta, use_locking, name, read_value
        )

    def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
        return self._apply_assign_update(
            self._variable.assign_sub, delta, use_locking, name, read_value
        )

    def scatter_sub(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_sub, sparse_delta, use_locking, name
        )

    def scatter_add(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_add, sparse_delta, use_locking, name
        )

    def scatter_max(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_max, sparse_delta, use_locking, name
        )

    def scatter_min(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_min, sparse_delta, use_locking, name
        )

    def scatter_mul(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_mul, sparse_delta, use_locking, name
        )

    def scatter_div(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_div, sparse_delta, use_locking, name
        )

    def scatter_update(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_update, sparse_delta, use_locking, name
        )

    def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.batch_scatter_update, sparse_delta, use_locking, name
        )

    def scatter_nd_sub(self, indices, updates, name=None):
        return self._apply_update(
            self._variable.scatter_nd_sub, indices, updates, name
        )

    def scatter_nd_add(self, indices, updates, name=None):
        return self._apply_update(
            self._variable.scatter_nd_add, indices, updates, name
        )

    def scatter_nd_update(self, indices, updates, name=None):
        return self._apply_update(
            self._variable.scatter_nd_update, indices, updates, name
        )
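
    # In eager mode, each scatter method above returns `self` (see
    # `_apply_update`), mirroring tf.Variable's in-place semantics. A minimal
    # sketch:
    #
    #     v = create_autocast_variable(tf.Variable([1.0, 2.0]))
    #     out = v.scatter_add(
    #         tf.IndexedSlices(tf.constant([5.0]), tf.constant([0]))
    #     )
    #     out is v  # True in eager mode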

    def load(self, value, session=None):
        return self._variable.load(value, session)

    @property
    def name(self):
        return self._variable.name

    @property
    def _shared_name(self):
        return self._variable._shared_name

    @property
    def initializer(self):
        return self._variable.initializer

    @property
    def device(self):
        return self._variable.device

    @property
    def op(self):
        if self._op == "delegate":
            return self._variable.op
        return self._op

    def _as_graph_element(self):
        graph_element = self._variable._as_graph_element()
        if graph_element is None:
            return self._op
        return graph_element

    @property
    def graph(self):
        return self._variable.graph

    @property
    def shape(self):
        return self._variable.shape

    def get_shape(self):
        return self._variable.get_shape()

    def __tf_tracing_type__(self, context):
        return AutoCastVariableSpec(self)

    def _gather_saveables_for_checkpoint(self):
        # By delegating this method to the wrapped variable, checkpoints with
        # AutoCastVariables are identical to checkpoints with normal
        # variables. Therefore models checkpointed with AutoCastVariables can
        # be restored on models with normal variables, and vice versa.
        return self._variable._gather_saveables_for_checkpoint()

    def _export_to_saved_model_graph(
        self, object_map, tensor_map, options, **kwargs
    ):
        # By delegating this method to the wrapped variable, SavedModels with
        # AutoCastVariables are identical to SavedModels with normal
        # variables.
        resource_list = self._variable._export_to_saved_model_graph(
            object_map, tensor_map, options, **kwargs
        )
        object_map[self] = object_map[self._variable]
        return resource_list

    # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable
    # in to_proto().
    def to_proto(self, export_scope=None):
        return self._variable.to_proto(export_scope)

    def from_proto(self, variable_def, import_scope=None):
        return self._variable.from_proto(variable_def, import_scope)

    # Delegate the private attributes _handle_name and _initializer_op to
    # self._variable. SavedModel sets these attributes when loading a model.
    # For example, it sets _handle_name here:
    # https://github.com/tensorflow/tensorflow/blob/db26bd574fa95b5bdd53c08463dd19407cc0297e/tensorflow/python/keras/saving/saved_model/load.py#L211
    # We need to expose these attributes on AutoCastVariable as well for
    # SavedModel to work properly.
    # TODO(reedwm/kathywu): Find a better way to support SavedModel. Exposing
    # private attributes is hacky and difficult to maintain.
    @property
    def _handle_name(self):
        return self._variable._handle_name

    @_handle_name.setter
    def _handle_name(self, handle_name):
        self._variable._handle_name = handle_name

    @property
    def _initializer_op(self):
        return self._variable._initializer_op

    @_initializer_op.setter
    def _initializer_op(self, initializer_op):
        self._variable._initializer_op = initializer_op

    # Operator overloads:
    # Note we only overload operators that support floating-point types, as
    # non-float variables cannot be wrapped with an AutoCastVariable.
    # Also note: We call read_value() instead of value(), because value()
    # causes gradients not to work properly when TPUStrategy is used:
    # b/143380936

    def __add__(self, o):
        return self.read_value() + o

    def __radd__(self, o):
        return o + self.read_value()

    def __sub__(self, o):
        return self.read_value() - o

    def __rsub__(self, o):
        return o - self.read_value()

    def __mul__(self, o):
        return self.read_value() * o

    def __rmul__(self, o):
        return o * self.read_value()

    def __truediv__(self, o):
        return self.read_value() / o

    def __rtruediv__(self, o):
        return o / self.read_value()

    def __floordiv__(self, o):
        return self.read_value() // o

    def __rfloordiv__(self, o):
        return o // self.read_value()

    def __mod__(self, o):
        return self.read_value() % o

    def __rmod__(self, o):
        return o % self.read_value()

    def __lt__(self, o):
        return self.read_value() < o

    def __le__(self, o):
        return self.read_value() <= o

    def __gt__(self, o):
        return self.read_value() > o

    def __ge__(self, o):
        return self.read_value() >= o

    def __getitem__(self, o):
        return self.read_value()[o]

    def __pow__(self, o, modulo=None):
        return pow(self.read_value(), o, modulo)

    def __rpow__(self, o):
        return pow(o, self.read_value())

    def __neg__(self):
        return -self.read_value()

    def __abs__(self):
        return abs(self.read_value())

    def __div__(self, o):
        try:
            return self.read_value().__div__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented

    def __rdiv__(self, o):
        try:
            return self.read_value().__rdiv__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented

    def __matmul__(self, o):
        try:
            return self.read_value().__matmul__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented

    def __rmatmul__(self, o):
        try:
            return self.read_value().__rmatmul__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented
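
    # Because these overloads read through `read_value()`, arithmetic on an
    # AutoCastVariable under autocast happens in the cast dtype. A minimal
    # sketch (`v` is a hypothetical AutoCastVariable stored as float32):
    #
    #     with enable_auto_cast_variables(tf.float16):
    #         (v + 1.0).dtype  # tf.float16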


tf.register_tensor_conversion_function(
    AutoCastVariable, AutoCastVariable._dense_var_to_tensor
)


def create_autocast_variable(variable):
    """Creates an AutoCastVariable that wraps another variable.

    This typically just returns `AutoCastVariable(variable)`. But, if the
    variable is a DistributedVariable or one of its subclasses, we instead
    dynamically create a class that subclasses from both AutoCastVariable and
    variable.__class__. This is so the returned variable will still pass
    `isinstance(variable, variable.__class__)`, which is required for
    DistributedVariables and its subclasses to work properly.

    Args:
      variable: A floating-point resource variable to wrap.

    Returns:
      An AutoCastVariable that wraps the variable.
    """
    if not distributed_training_utils.is_distributed_variable(variable):
        return AutoCastVariable(variable)

    class AutoCastDistributedVariable(AutoCastVariable, variable.__class__):
        """An AutoCastVariable that also subclasses from variable.__class__.

        variable.__class__ is either a DistributedVariable or an
        AggregatingVariable.
        """

        def __repr__(self):
            return (
                "<AutoCastDistributedVariable dtype={v.dtype.name} "
                "dtype_to_cast_to={v._cast_dtype.name} "
                "inner_variable={v._variable}>"
            ).format(v=self)

    return AutoCastDistributedVariable(variable)
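

# A minimal sketch of the distributed path (assumes tf.distribute is usable
# in the current environment; names are illustrative). The dynamically
# created class passes both isinstance checks:
#
#     strategy = tf.distribute.MirroredStrategy()
#     with strategy.scope():
#         dv = tf.Variable(1.0)  # a DistributedVariable under the strategy
#     acv = create_autocast_variable(dv)
#     isinstance(acv, AutoCastVariable)  # True
#     isinstance(acv, type(dv))          # True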


class enable_auto_cast_variables:
    """Context manager which enables the autocasting of `AutoCastVariable`s.

    Under this context manager, `AutoCastVariable`s will be cast to `dtype`
    if `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be
    cast.
    """

    __slots__ = ["_dtype", "_prev_dtype"]

    def __init__(self, dtype):
        if dtype and not dtype.is_floating:
            dtype = None
        self._dtype = dtype

    def __enter__(self):
        self._prev_dtype = getattr(_autocast_dtype, "dtype", None)
        _autocast_dtype.dtype = self._dtype

    def __exit__(self, type_arg, value_arg, traceback_arg):
        _autocast_dtype.dtype = self._prev_dtype
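

# A minimal end-to-end sketch tying the pieces together (assumes eager
# execution; runs only when this file is executed directly, not on import):
if __name__ == "__main__":
    _v = create_autocast_variable(tf.Variable(1.0, dtype=tf.float32))
    print(tf.identity(_v).dtype)  # tf.float32: no autocast dtype is set
    with enable_auto_cast_variables(tf.float16):
        print(tf.identity(_v).dtype)  # tf.float16: reads are cast
    print(_v.dtype)  # still tf.float32: the storage dtype is unchanged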