# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains AutoCastVariable, a variable which automatically casts itself."""

import threading

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.types import core

# _autocast_dtype.dtype is the dtype AutoCastVariables should be cast to, or
# None if AutoCastVariables should not be cast.
_autocast_dtype = threading.local()
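# Illustrative sketch of the thread-local contract (hedged; `some_float_dtype`
# stands in for any floating-point DType): each thread sees its own `dtype`
# attribute, so enabling autocasting on one thread does not affect others.
#
#   _autocast_dtype.dtype = some_float_dtype  # visible only on this thread
#   getattr(_autocast_dtype, 'dtype', None)   # -> some_float_dtype here;
#                                             #    None on any other thread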


def numpy_text(tensor, is_repr=False):
  """Human readable representation of a tensor's numpy value."""
  if tensor.dtype.is_numpy_compatible:
    # pylint: disable=protected-access
    text = repr(tensor._numpy()) if is_repr else str(tensor._numpy())
    # pylint: enable=protected-access
  else:
    text = '<unprintable>'
  if '\n' in text:
    text = '\n' + text
  return text
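
# Hedged usage sketch for `numpy_text` (assumes an eager tensor whose dtype is
# numpy-compatible; `tf` is the public TensorFlow module):
#
#   numpy_text(tf.constant(1.0))                # -> '1.0'
#   numpy_text(tf.constant(1.0), is_repr=True)  # -> '1.0' (numpy repr)
#
# Multi-line numpy output is prefixed with a newline so it embeds cleanly in a
# larger repr string.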


class AutoCastVariable(variables.Variable, core.Tensor):
  """Variable that will cast itself to a different dtype in applicable contexts.

  This class wraps a floating-point `tf.Variable`. It emulates the variable
  interface and delegates to the wrapped variable, but it additionally will cast
  the wrapped variable under an `enable_auto_cast_variables(dtype)` context
  manager.

  For example:

  >>> v = tf.Variable(1.0, dtype=tf.float32)
  >>> v = AutoCastVariable(v)
  >>> tf.identity(v).dtype
  tf.float32
  >>> with enable_auto_cast_variables(tf.float16):
  ...   tf.identity(v).dtype
  tf.float16

  The purpose of this class is to allow Keras layers to create variables in
  float32, and automatically cast them to float16 or bfloat16 when the layer is
  called.
  """

  def __init__(self, variable):
    """Creates an AutoCastVariable instance.

    Args:
      variable: A floating-point resource variable to wrap.

    Raises:
      ValueError: If `variable` is not a floating-point resource variable.
    """
    if not isinstance(variable, variables.Variable):
      raise ValueError('variable must be of type tf.ResourceVariable, but got: '
                       '%s' % variable)
    if not variable.dtype.is_floating:
      raise ValueError('variable must be a floating point variable but has '
                       'type: %s' % variable.dtype.name)
    self._variable = variable
    # 'delegate' means AutoCastVariable.op returns self._variable.op, which
    # will raise an AttributeError in Eager (as intended). If set to any other
    # value, AutoCastVariable.op returns that value instead, which is used to
    # set the op attribute in AutoCastVariable.assign().
    self._op = 'delegate'

  def _should_cast(self):
    """Returns True if this variable should be cast when accessed."""
    autocast_dtype = getattr(_autocast_dtype, 'dtype', None)
    return autocast_dtype is not None and self.dtype != autocast_dtype

  @property
  def dtype(self):
    """The dtype of the underlying variable, before any casts are done."""
    return self._variable.dtype

  @property
  def true_dtype(self):
    """Deprecated alias of `dtype`."""
    return self._variable.dtype

  @property
  def _cast_dtype(self):
    dtype = getattr(_autocast_dtype, 'dtype', None)
    return dtype or self._variable.dtype

  def value(self):
    val = self._variable.value()
    if not self._should_cast():
      return val
    return math_ops.cast(val, self._cast_dtype)

  def read_value(self):
    val = self._variable.read_value()
    return math_ops.cast(val, self._cast_dtype)

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    val = self._variable.sparse_read(indices, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def gather_nd(self, indices, name=None):
    """Gather slices of the variable into a Tensor."""
    val = self._variable.gather_nd(indices, name=name)
    return math_ops.cast(val, self._cast_dtype)

  def __getattr__(self, name):
    return getattr(self._variable, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts this variable to a tensor."""
    if as_ref:
      # This ValueError should not occur in practice since it is impossible to
      # pass as_ref=True using public APIs.
      raise ValueError('Cannot convert AutoCastVariable to a tensor if '
                       'as_ref=True is passed to convert_to_tensor')
    if not self._should_cast():
      return tensor_conversion.convert_to_tensor_v2_with_dispatch(
          self._variable, dtype=dtype, name=name
      )
    if dtype is not None and not dtype.is_compatible_with(self._cast_dtype):
      raise ValueError(
          'Incompatible type conversion requested to type {!r} for '
          'AutoCastVariable which is cast to type {!r}'.format(
              dtype.name, self._cast_dtype.name))
    val = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        self._variable, dtype=self._variable.dtype, name=name
    )
    return math_ops.cast(val, self._cast_dtype)
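
  # Hedged conversion sketch: under `enable_auto_cast_variables(tf.float16)`,
  # converting a float32 AutoCastVariable yields a float16 tensor, while
  # requesting an incompatible dtype raises the ValueError above:
  #
  #   v = AutoCastVariable(tf.Variable(1.0))
  #   with enable_auto_cast_variables(tf.float16):
  #     tf.convert_to_tensor(v).dtype        # -> tf.float16
  #     tf.convert_to_tensor(v, tf.float64)  # -> ValueError
  #   tf.convert_to_tensor(v).dtype          # -> tf.float32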

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def __repr__(self):
    if context.executing_eagerly() and not self._in_graph_mode:
      repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, '
                  'numpy={np_repr}>')
      return repr_str.format(
          v=self, np_repr=numpy_text(self.read_value(), is_repr=True))
    else:
      repr_str = ("<AutoCastVariable '{v.name}' shape={v.shape} "
                  'dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>')
      return repr_str.format(v=self)

  # Method delegations: We delegate the following methods to self._variable.
  # Each of these methods simply calls the same method on self._variable. The
  # base Variable raises NotImplementedError for most of these, so we must
  # override them.
  #
  # We do not define the following methods from Variable for the following
  # reasons:
  #   * 'count_up_to': This method only applies to int variables, which cannot
  #     be wrapped with an AutoCastVariable.
  #   * 'ref': Instead we inherit the definition from Variable.
  #     If we defined it and delegated to self._variable, the ref of an
  #     AutoCastVariable would be the same as the ref of the underlying
  #     variable, which would be strange as they are different Python objects.

  def set_shape(self, shape):
    return self._variable.set_shape(shape)

  @property
  def trainable(self):
    return self._variable.trainable

  @property
  def synchronization(self):
    return self._variable.synchronization

  @property
  def aggregation(self):
    return self._variable.aggregation

  def eval(self, session=None):
    return self._variable.eval(session)

  def initialized_value(self):
    return self._variable.initialized_value()

  @property
  def initial_value(self):
    return self._variable.initial_value

  @property
  def constraint(self):
    return self._variable.constraint

  def _apply_assign_update(self,
                           update_fn,
                           value,
                           use_locking=None,
                           name=None,
                           read_value=True):
    # TODO(b/146181571): This logic can be simplified once
    # DistributedVariable.assign returns a DistributedVariable. Currently for
    # MirroredStrategy, it returns a Mirrored value.
    if ops.executing_eagerly_outside_functions():
      assign_op = update_fn(value, use_locking, name, False)
      if read_value:
        # We create a new AutoCastVariable with the same underlying tf.Variable.
        # The new AutoCastVariable is identical except the 'op' attribute is
        # defined. This matches the behavior of tf.Variable.assign.
        var = create_autocast_variable(self._variable)
        var._op = assign_op  # pylint:disable=protected-access
        return var
      return assign_op

    # Fall back to wrapping the returned variable in graph mode if possible.
    assign_var = update_fn(value, use_locking, name, read_value)
    if read_value and resource_variable_ops.is_resource_variable(assign_var):
      return create_autocast_variable(assign_var)
    return assign_var

  def _apply_update(self, update_fn, *args, **kwargs):
    update_var = update_fn(*args, **kwargs)
    if ops.executing_eagerly_outside_functions():
      return self

    # Fall back to wrapping the returned variable in graph mode if possible.
    if resource_variable_ops.is_resource_variable(update_var):
      return create_autocast_variable(update_var)
    return update_var
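
  # Hedged sketch of the eager assign behavior: `assign` returns a fresh
  # AutoCastVariable wrapping the same underlying tf.Variable, with its `op`
  # attribute set to the assign op (mirroring tf.Variable.assign):
  #
  #   v = AutoCastVariable(tf.Variable(1.0))
  #   w = v.assign(2.0)  # eagerly: a new AutoCastVariable, not a plain tensor
  #   w is v             # -> False (new wrapper, same underlying variable)
  #   v.read_value()     # -> 2.0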

  def assign(self, value, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign, value, use_locking,
                                     name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign_add, delta,
                                     use_locking, name, read_value)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    return self._apply_assign_update(self._variable.assign_sub, delta,
                                     use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_sub, sparse_delta,
                              use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_add, sparse_delta,
                              use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_max, sparse_delta,
                              use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_min, sparse_delta,
                              use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_mul, sparse_delta,
                              use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_div, sparse_delta,
                              use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.scatter_update, sparse_delta,
                              use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    return self._apply_update(self._variable.batch_scatter_update,
                              sparse_delta, use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_sub, indices, updates,
                              name)

  def scatter_nd_add(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_add, indices, updates,
                              name)

  def scatter_nd_update(self, indices, updates, name=None):
    return self._apply_update(self._variable.scatter_nd_update, indices,
                              updates, name)

  def load(self, value, session=None):
    return self._variable.load(value, session)

  @property
  def name(self):
    return self._variable.name

  @property
  def _shared_name(self):
    return self._variable._shared_name  # pylint:disable=protected-access

  @property
  def initializer(self):
    return self._variable.initializer

  @property
  def device(self):
    return self._variable.device

  @property
  def op(self):
    if self._op == 'delegate':
      return self._variable.op
    return self._op

  def _as_graph_element(self):
    graph_element = self._variable._as_graph_element()  # pylint:disable=protected-access
    if graph_element is None:
      return self._op
    return graph_element

  @property
  def graph(self):
    return self._variable.graph

  @property
  def shape(self):
    return self._variable.shape

  def get_shape(self):
    return self._variable.get_shape()

  def _gather_saveables_for_checkpoint(self):
    # By delegating this method to the wrapped variable, checkpoints with
    # AutoCastVariables are identical to checkpoints with normal variables.
    # Therefore models checkpointed with AutoCastVariables can be restored on
    # models with normal variables, and vice versa.
    return self._variable._gather_saveables_for_checkpoint()  # pylint:disable=protected-access

  def _export_to_saved_model_graph(self, object_map, tensor_map, options,
                                   **kwargs):
    # By delegating this method to the wrapped variable, SavedModels with
    # AutoCastVariables are identical to SavedModels with normal variables.
    resource_list = self._variable._export_to_saved_model_graph(  # pylint:disable=protected-access
        object_map, tensor_map, options, **kwargs)
    object_map[self] = object_map[self._variable]
    return resource_list

  # TODO(reedwm): Maybe encode the fact that the variable is an
  # AutoCastVariable in to_proto().
  def to_proto(self, export_scope=None):
    return self._variable.to_proto(export_scope)

  def from_proto(self, variable_def, import_scope=None):
    return self._variable.from_proto(variable_def, import_scope)

  # Delegate the private attributes _handle_name and _initializer_op to
  # self._variable. SavedModel sets these attributes when loading a model. For
  # example, it sets _handle_name here:
  # https://github.com/tensorflow/tensorflow/blob/db26bd574fa95b5bdd53c08463dd19407cc0297e/tensorflow/python/keras/saving/saved_model/load.py#L211
  # We need to expose these attributes on AutoCastVariable as well for
  # SavedModel to work properly.
  # TODO(reedwm/kathywu): Find a better way to support SavedModel. Exposing
  # private attributes is hacky and difficult to maintain.
  @property
  def _handle_name(self):
    return self._variable._handle_name  # pylint: disable=protected-access

  @_handle_name.setter
  def _handle_name(self, handle_name):
    self._variable._handle_name = handle_name  # pylint: disable=protected-access

  @property
  def _initializer_op(self):
    return self._variable._initializer_op  # pylint: disable=protected-access

  @_initializer_op.setter
  def _initializer_op(self, initializer_op):
    self._variable._initializer_op = initializer_op  # pylint: disable=protected-access

  # Operator overloads:
  # Note we only overload operators that support floating-point types, as
  # non-float variables cannot be wrapped with an AutoCastVariable.
  # Also note: We call read_value() instead of value(), because value() causes
  # gradients not to work properly when TPUStrategy is used: b/143380936
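  #
  # Hedged example of the overloads below: arithmetic reads (and, under an
  # autocast context, casts) the value first, so the result dtype follows the
  # cast dtype:
  #
  #   v = AutoCastVariable(tf.Variable(1.0))
  #   (v + 1.0).dtype                  # -> tf.float32
  #   with enable_auto_cast_variables(tf.float16):
  #     (v + 1.0).dtype                # -> tf.float16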

  def __add__(self, o):
    return self.read_value() + o

  def __radd__(self, o):
    return o + self.read_value()

  def __sub__(self, o):
    return self.read_value() - o

  def __rsub__(self, o):
    return o - self.read_value()

  def __mul__(self, o):
    return self.read_value() * o

  def __rmul__(self, o):
    return o * self.read_value()

  def __truediv__(self, o):
    return self.read_value() / o

  def __rtruediv__(self, o):
    return o / self.read_value()

  def __floordiv__(self, o):
    return self.read_value() // o

  def __rfloordiv__(self, o):
    return o // self.read_value()

  def __mod__(self, o):
    return self.read_value() % o

  def __rmod__(self, o):
    return o % self.read_value()

  def __lt__(self, o):
    return self.read_value() < o

  def __le__(self, o):
    return self.read_value() <= o

  def __gt__(self, o):
    return self.read_value() > o

  def __ge__(self, o):
    return self.read_value() >= o

  def __getitem__(self, o):
    return self.read_value()[o]

  def __pow__(self, o, modulo=None):
    return pow(self.read_value(), o, modulo)

  def __rpow__(self, o):
    return pow(o, self.read_value())

  def __neg__(self):
    return -self.read_value()  # pylint: disable=invalid-unary-operand-type

  def __abs__(self):
    return abs(self.read_value())

  def __div__(self, o):
    try:
      return self.read_value().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self.read_value().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self.read_value().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self.read_value().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented


tensor_conversion_registry.register_tensor_conversion_function(
    AutoCastVariable, AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access


def create_autocast_variable(variable):
  """Creates an AutoCastVariable that wraps another variable.

  This typically just returns `AutoCastVariable(variable)`. But, if the variable
  is a DistributedVariable or one of its subclasses, we instead dynamically
  create a class that subclasses from both AutoCastVariable and
  variable.__class__. This is so the returned variable will still pass
  `isinstance(variable, variable.__class__)`, which is required for
  DistributedVariable and its subclasses to work properly.

  Args:
    variable: A floating-point resource variable to wrap.

  Returns:
    An AutoCastVariable that wraps the variable.
  """
  if not distributed_training_utils.is_distributed_variable(variable):
    return AutoCastVariable(variable)

  class AutoCastDistributedVariable(AutoCastVariable, variable.__class__):
    """An AutoCastVariable that also subclasses from variable.__class__.

    variable.__class__ is either a DistributedVariable or an
    AggregatingVariable.
    """

    def __repr__(self):
      # pylint: disable=missing-format-attribute
      return ('<AutoCastDistributedVariable dtype={v.dtype.name} '
              'dtype_to_cast_to={v._cast_dtype.name} '
              'inner_variable={v._variable}>'
             ).format(v=self)
      # pylint: enable=missing-format-attribute

  return AutoCastDistributedVariable(variable)
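
# Hedged sketch of the class-mixing behavior above: for a distributed
# variable, the wrapper comes from a dynamically created subclass, so both
# isinstance checks hold (`distributed_var` is a placeholder for a variable
# created under a tf.distribute strategy):
#
#   wrapped = create_autocast_variable(distributed_var)
#   isinstance(wrapped, AutoCastVariable)           # -> True
#   isinstance(wrapped, distributed_var.__class__)  # -> True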


class enable_auto_cast_variables(object):  # pylint:disable=invalid-name
  """Context manager which enables the autocasting of `AutoCastVariable`s.

  Under this context manager, `AutoCastVariable`s will be cast to `dtype` if
  `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
  """

  __slots__ = ['_dtype', '_prev_dtype']

  def __init__(self, dtype):
    if dtype and not dtype.is_floating:
      dtype = None
    self._dtype = dtype

  def __enter__(self):
    self._prev_dtype = getattr(_autocast_dtype, 'dtype', None)
    _autocast_dtype.dtype = self._dtype

  def __exit__(self, type_arg, value_arg, traceback_arg):
    _autocast_dtype.dtype = self._prev_dtype
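
# Hedged end-to-end usage sketch (mirrors the class doctest above; `tf` is
# assumed to be the public TensorFlow module):
#
#   v = create_autocast_variable(tf.Variable(1.0, dtype=tf.float32))
#   with enable_auto_cast_variables(tf.float16):
#     tf.identity(v).dtype  # -> tf.float16
#   with enable_auto_cast_variables(tf.int32):
#     tf.identity(v).dtype  # -> tf.float32 (non-floating dtype disables cast)
#   tf.identity(v).dtype    # -> tf.float32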