Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/mixed_precision/policy.py: 29%
126 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains the Policy class for mixed precision training."""
17import contextlib
19import tensorflow.compat.v2 as tf
21from keras.src import backend
22from keras.src.engine import base_layer_utils
23from keras.src.mixed_precision import device_compatibility_check
24from keras.src.mixed_precision import loss_scale_optimizer
25from keras.src.saving import serialization_lib
27# isort: off
28from tensorflow.python.util.tf_export import keras_export
31@keras_export("keras.mixed_precision.Policy", v1=[])
32class Policy:
33 """A dtype policy for a Keras layer.
35 A dtype policy determines a layer's computation and variable dtypes. Each
36 layer has a policy. Policies can be passed to the `dtype` argument of layer
37 constructors, or a global policy can be set with
38 `tf.keras.mixed_precision.set_global_policy`.
40 Args:
41 name: The policy name, which determines the compute and variable dtypes.
42 Can be any dtype name, such as `'float32'` or `'float64'`, which causes
43 both the compute and variable dtypes to be that dtype. Can also be the
44 string `'mixed_float16'` or `'mixed_bfloat16'`, which causes the compute
45 dtype to be float16 or bfloat16 and the variable dtype to be float32.
47 Typically you only need to interact with dtype policies when using mixed
48 precision, which is the use of float16 or bfloat16 for computations and
49 float32 for variables. This is why the term `mixed_precision` appears in the
50 API name. Mixed precision can be enabled by passing `'mixed_float16'` or
51 `'mixed_bfloat16'` to `tf.keras.mixed_precision.set_global_policy`. See [the
52 mixed precision
53 guide](https://www.tensorflow.org/guide/keras/mixed_precision) for more
54 information on how to use mixed precision.
56 >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
57 >>> layer1 = tf.keras.layers.Dense(10)
58 >>> layer1.dtype_policy # `layer1` will automatically use mixed precision
59 <Policy "mixed_float16">
60 >>> # Can optionally override layer to use float32
61 >>> # instead of mixed precision.
62 >>> layer2 = tf.keras.layers.Dense(10, dtype='float32')
63 >>> layer2.dtype_policy
64 <Policy "float32">
65 >>> # Set policy back to initial float32 for future examples.
66 >>> tf.keras.mixed_precision.set_global_policy('float32')
68 In the example above, passing `dtype='float32'` to the layer is equivalent
69 to passing `dtype=tf.keras.mixed_precision.Policy('float32')`. In general,
70 passing a dtype policy name to a layer is equivalent to passing the
71 corresponding policy, so it is never necessary to explicitly construct a
72 `Policy` object.
74 Note: `Model.compile` will automatically wrap an optimizer with a
75 `tf.keras.mixed_precision.LossScaleOptimizer` if you use the
76 `'mixed_float16'` policy. If you use a custom training loop instead of
77 calling `Model.compile`, you should explicitly use a
78 `tf.keras.mixed_precision.LossScaleOptimizer` to avoid numeric underflow
79 with float16.
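For example, in a custom training loop the wrapping might look like this
(a minimal sketch; `loss`, `tape`, and `vars` stand in for objects from
your own training step):

>>> opt = tf.keras.mixed_precision.LossScaleOptimizer(
...     tf.keras.optimizers.SGD())
>>> # Inside the training step:
>>> #   scaled_loss = opt.get_scaled_loss(loss)
>>> #   scaled_grads = tape.gradient(scaled_loss, vars)
>>> #   grads = opt.get_unscaled_gradients(scaled_grads)
>>> #   opt.apply_gradients(zip(grads, vars))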
81 ### How a layer uses its policy's compute dtype
83 A layer casts its inputs to its compute dtype. This causes the layer's
84 computations and output to also be in the compute dtype. For example:
86 >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
87 >>> # `layer`'s policy defaults to float32.
88 >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
89 >>> layer.compute_dtype # Equivalent to layer.dtype_policy.compute_dtype
90 'float32'
91 >>> # `layer` casts its inputs to its compute dtype and does computations in
92 >>> # that dtype.
93 >>> y = layer(x)
94 >>> y.dtype
95 tf.float32
97 Note that the base `tf.keras.layers.Layer` class inserts the casts. If
98 subclassing your own layer, you do not have to insert any casts.
100 Currently, only tensors in the first argument to the layer's `call` method
101 are cast (although this will likely be changed in a future minor release).
102 For example:
104 >>> class MyLayer(tf.keras.layers.Layer):
105 ... # Bug! `b` will not be cast.
106 ... def call(self, a, b):
107 ... return a + 1., b + 1.
108 >>> a = tf.constant(1., dtype="float32")
109 >>> b = tf.constant(1., dtype="float32")
110 >>> layer = MyLayer(dtype="float64")
111 >>> x, y = layer(a, b)
112 >>> x.dtype
113 tf.float64
114 >>> y.dtype
115 tf.float32
117 If writing your own layer with multiple inputs, you should either explicitly
118 cast other tensors to `self.compute_dtype` in `call` or accept all tensors
119 in the first argument as a list.
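For example, the buggy layer above could be fixed by casting its second
argument explicitly (a minimal sketch; the class name is illustrative):

>>> class MyFixedLayer(tf.keras.layers.Layer):
...     def call(self, a, b):
...         # Cast inputs beyond the first argument manually.
...         b = tf.cast(b, self.compute_dtype)
...         return a + 1., b + 1.
>>> layer = MyFixedLayer(dtype="float64")
>>> x, y = layer(tf.constant(1., dtype="float32"),
...              tf.constant(1., dtype="float32"))
>>> y.dtype
tf.float64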
121 The casting only occurs in TensorFlow 2. If
122 `tf.compat.v1.disable_v2_behavior()` has been called, you can enable the
123 casting behavior with
124 `tf.compat.v1.keras.layers.enable_v2_dtype_behavior()`.
126 ### How a layer uses its policy's variable dtype
128 The default dtype of variables created by `tf.keras.layers.Layer.add_weight`
129 is the layer's policy's variable dtype.
131 If a layer's compute and variable dtypes differ, `add_weight` will wrap
132 floating-point variables with a special wrapper called an
133 `AutoCastVariable`. `AutoCastVariable` is identical to the original
134 variable except it casts itself to the layer's compute dtype when used
135 within `Layer.call`. This means if you are writing a layer, you do not have
136 to explicitly cast the variables to the layer's compute dtype. For example:
138 >>> class SimpleDense(tf.keras.layers.Layer):
139 ...
140 ... def build(self, input_shape):
141 ... # With mixed precision, self.kernel is a float32 AutoCastVariable
142 ... self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
143 ...
144 ... def call(self, inputs):
145 ... # With mixed precision, self.kernel will be cast to float16
146 ... return tf.linalg.matmul(inputs, self.kernel)
147 ...
148 >>> layer = SimpleDense(dtype='mixed_float16')
149 >>> y = layer(tf.ones((10, 10)))
150 >>> y.dtype
151 tf.float16
152 >>> layer.kernel.dtype
153 tf.float32
155 A layer author can prevent a variable from being wrapped with an
156 `AutoCastVariable` by passing `experimental_autocast=False` to `add_weight`,
157 which is useful if the float32 value of the variable must be accessed within
158 the layer.
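A minimal sketch (the layer and weight names are illustrative):

>>> class NonAutoCastLayer(tf.keras.layers.Layer):
...     def build(self, input_shape):
...         # Stays a plain float32 variable even under a mixed policy.
...         self.scale = self.add_weight(
...             'scale', shape=(), experimental_autocast=False)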
160 ### How to write a layer that supports mixed precision and float64.
162 For the most part, layers will automatically support mixed precision and
163 float64 without any additional work, because the base layer
164 automatically casts inputs, creates variables of the correct type, and in
165 the case of mixed precision, wraps variables with `AutoCastVariables`.
167 The primary case where you need extra work to support mixed precision or
168 float64 is when you create a new tensor, such as with `tf.ones` or
169 `tf.random.normal`. In such cases, you must create the tensor of the correct
170 dtype. For example, if you call `tf.random.normal`, you must pass the
171 compute dtype, which is the dtype the inputs have been cast to:
173 >>> class AddRandom(tf.keras.layers.Layer):
174 ...
175 ... def call(self, inputs):
176 ... # We must pass `dtype=inputs.dtype`, otherwise a TypeError may
177 ... # occur when adding `inputs` to `rand`.
178 ... rand = tf.random.normal(shape=inputs.shape, dtype=inputs.dtype)
179 ... return inputs + rand
180 >>> layer = AddRandom(dtype='mixed_float16')
181 >>> y = layer(x)
182 >>> y.dtype
183 tf.float16
185 If you did not pass `dtype=inputs.dtype` to `tf.random.normal`, a
186 `TypeError` would have occurred. This is because `tf.random.normal`'s
187 dtype defaults to `"float32"`, but the input dtype is float16. You cannot
188 add a float32 tensor to a float16 tensor.
189 """
191 def __init__(self, name):
192 if isinstance(name, tf.DType):
193 raise TypeError(
194 "'name' must be a string, not a DType. "
195 f"Instead, pass DType.name. Received: name={name.name}"
196 )
197 elif not isinstance(name, str):
198 raise TypeError(f"'name' must be a string, but got: {name}")
199 self._name = name
200 self._compute_dtype, self._variable_dtype = self._parse_name(name)
201 if name in ("mixed_float16", "mixed_bfloat16"):
202 device_compatibility_check.log_device_compatibility_check(name)
204 def _parse_name(self, name):
205 """Parses a Policy name into a compute and variable dtype.
207 Args:
208 name: The name of the policy.
210 Returns:
211 The (compute_dtype, variable_dtype) pair.
212 """
213 if name.endswith("_float32_vars"):
214 error_msg = (
215 "Policies ending in '_float32_vars' have been removed "
216 "from TensorFlow."
217 )
218 if name in ("infer_float32_vars", "infer_with_float32_vars"):
219 error_msg += (
220 " Please use the 'mixed_float16' or 'mixed_bfloat16' "
221 "policy instead."
222 )
223 elif name == "float16_with_float32_vars":
224 error_msg += " Please use the 'mixed_float16' policy instead."
225 elif name == "bfloat16_with_float32_vars":
226 error_msg += " Please use the 'mixed_bfloat16' policy instead."
227 error_msg += f" Got policy name: '{name}'"
228 raise ValueError(error_msg)
230 if name == "mixed_float16":
231 return "float16", "float32"
232 elif name == "mixed_bfloat16":
233 return "bfloat16", "float32"
234 elif name == "_infer":
235 # The "_infer" policy exists only for compatibility with TF 1, where
236 # "_infer" is the default. The behavior matches the behavior of TF
237 1 before policies were introduced. With "_infer", the
238 # computation and variable dtype are inferred from the first input
239 # the first time the layer is called. Once the layer is called for
240 # the first time, the layer's policy will change to the dtype of the
241 # first input, and it will no longer have the "_infer" policy.
242 #
243 # The infer policy should be considered an implementation detail and
244 # may be removed in the future.
245 return None, None
247 try:
248 dtype = tf.as_dtype(name).name
249 except TypeError:
250 raise ValueError(
251 f"Cannot convert value {name} to a mixed precision Policy. "
252 "Valid policies include 'mixed_float16', 'mixed_bfloat16', "
253 "and the name of any dtype such as 'float32'."
254 )
255 return dtype, dtype
257 @property
258 def variable_dtype(self):
259 """The variable dtype of this policy.
261 This is the dtype layers will create their variables in, unless a layer
262 explicitly chooses a different dtype. If this is different from
263 `Policy.compute_dtype`, layers will cast variables to the compute dtype
264 to avoid type errors.
266 Variable regularizers are run in the variable dtype, not the compute
267 dtype.
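For example:

>>> tf.keras.mixed_precision.Policy('mixed_float16').variable_dtype
'float32'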
269 Returns:
270 The variable dtype of this policy, as a string.
271 """
272 return self._variable_dtype
274 @property
275 def compute_dtype(self):
276 """The compute dtype of this policy.
278 This is the dtype layers will do their computations in. Typically layers
279 output tensors with the compute dtype as well.
281 Note that even if the compute dtype is float16 or bfloat16, hardware
282 devices may not do individual adds, multiplies, and other fundamental
283 operations in float16 or bfloat16, but instead may do some of them in
284 float32 for numeric stability. The compute dtype is the dtype of the
285 inputs and outputs of the TensorFlow ops that the layer executes.
286 Internally, many TensorFlow ops will do certain internal calculations in
287 float32 or some other device-internal intermediate format with higher
288 precision than float16/bfloat16, to increase numeric stability.
290 For example, a `tf.keras.layers.Dense` layer, when run on a GPU with a
291 float16 compute dtype, will pass float16 inputs to `tf.linalg.matmul`.
292 But `tf.linalg.matmul` will use float32 intermediate math. The
293 performance benefit of float16 is still apparent, due to increased
294 memory bandwidth and the fact that modern GPUs have specialized hardware for
295 computing matmuls on float16 inputs while still keeping intermediate
296 computations in float32.
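For example:

>>> tf.keras.mixed_precision.Policy('mixed_float16').compute_dtype
'float16'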
298 Returns:
299 The compute dtype of this policy, as a string.
300 """
301 return self._compute_dtype
303 @property
304 def name(self):
305 """Returns the name of this policy."""
306 return self._name
308 def __repr__(self):
309 return f'<Policy "{self._name}">'
311 def get_config(self):
312 return {"name": self.name}
314 @classmethod
315 def from_config(cls, config, custom_objects=None):
316 del custom_objects
317 if "loss_scale" in config:
318 config = config.copy()
319 # Policy.get_config in TensorFlow 2.3 and below had a loss_scale. We
320 # silently drop it.
321 del config["loss_scale"]
322 return cls(**config)
325# The current global policy in effect. If None, it means the current value of
326# floatx should be used as the policy if the V2 dtype behavior is enabled,
327# or "_infer" otherwise.
328# TODO(reedwm): Make this thread local?
329_global_policy = None
332@keras_export("keras.mixed_precision.global_policy", v1=[])
333def global_policy():
334 """Returns the global dtype policy.
336 The global policy is the default `tf.keras.mixed_precision.Policy` used for
337 layers, if no policy is passed to the layer constructor. If no policy has
338 been set with `keras.mixed_precision.set_global_policy`, this will return a
339 policy constructed from `tf.keras.backend.floatx()` (floatx defaults to
340 float32).
342 >>> tf.keras.mixed_precision.global_policy()
343 <Policy "float32">
344 >>> tf.keras.layers.Dense(10).dtype_policy # Defaults to the global policy
345 <Policy "float32">
347 If TensorFlow 2 behavior has been disabled with
348 `tf.compat.v1.disable_v2_behavior()`, this will instead return a special
349 "_infer" policy which infers the dtype from the dtype of the first input the
350 first time the layer is called. This behavior matches the behavior that
351 existed in TensorFlow 1.
353 See `tf.keras.mixed_precision.Policy` for more information on policies.
355 Returns:
356 The global Policy.
357 """
358 if _global_policy is None:
359 if base_layer_utils.v2_dtype_behavior_enabled():
360 return Policy(backend.floatx())
361 else:
362 return Policy("_infer")
363 return _global_policy
366def _check_if_mixed_precision_graph_rewrite_is_enabled(policy):
367 if tf.__internal__.train.is_mixed_precision_graph_rewrite_enabled():
368 raise ValueError(
369 'The global dtype policy cannot be set to "{policy.name}", because '
370 "the mixed precision graph rewrite has already been enabled.\n"
371 "At most, one of the following can be called:\n\n"
372 " 1. tf.compat.v1.train.enable_mixed_precision_graph_rewrite() "
373 "(You called this first)\n"
374 " 2. tf.keras.mixed_precision.set_global_policy() with a mixed "
375 "precision policy (You called this second)\n\n"
376 "You called both functions, which is an error, because both "
377 "functions enable you to use mixed precision. If in doubt which "
378 "function to use, use the second, as it supports Eager execution "
379 "and is more customizable.".format(policy=policy)
380 )
383@keras_export("keras.mixed_precision.set_global_policy", v1=[])
384def set_global_policy(policy):
385 """Sets the global dtype policy.
387 The global policy is the default `tf.keras.mixed_precision.Policy` used for
388 layers, if no policy is passed to the layer constructor.
390 >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
391 >>> tf.keras.mixed_precision.global_policy()
392 <Policy "mixed_float16">
393 >>> tf.keras.layers.Dense(10).dtype_policy
394 <Policy "mixed_float16">
395 >>> # Global policy is not used if a policy
396 >>> # is directly passed to constructor
397 >>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
398 <Policy "float64">
399 >>> tf.keras.mixed_precision.set_global_policy('float32')
401 If no global policy is set, layers will instead default to a Policy
402 constructed from `tf.keras.backend.floatx()`.
404 To use mixed precision, the global policy should be set to `'mixed_float16'`
405 or `'mixed_bfloat16'`, so that every layer uses a 16-bit compute dtype and
406 float32 variable dtype by default.
408 Only floating point policies can be set as the global policy, such as
409 `'float32'` and `'mixed_float16'`. Non-floating point policies such as
410 `'int32'` and `'complex64'` cannot be set as the global policy because most
411 layers do not support such policies.
413 See `tf.keras.mixed_precision.Policy` for more information.
415 Args:
416 policy: A Policy, or a string that will be converted to a Policy. Can also
417 be None, in which case the global policy will be constructed from
418 `tf.keras.backend.floatx()`.
419 """
420 global _global_policy
421 if not base_layer_utils.v2_dtype_behavior_enabled():
422 raise ValueError(
423 "The global policy can only be set in TensorFlow 2 or if "
424 "V2 dtype behavior has been set. To enable V2 dtype "
425 "behavior, call "
426 '"tf.compat.v1.keras.layers.enable_v2_dtype_behavior()"'
427 )
428 if policy is not None and not isinstance(policy, Policy):
429 policy = Policy(policy)
430 is_mixed_policy = (
431 policy is not None and policy.compute_dtype != policy.variable_dtype
432 )
433 if is_mixed_policy:
434 _check_if_mixed_precision_graph_rewrite_is_enabled(policy)
435 if (
436 policy is not None
437 and policy.compute_dtype is not None
438 and not tf.as_dtype(policy.compute_dtype).is_floating
439 ):
440 raise ValueError(
441 "set_global_policy can only be used to set the global "
442 'policy to floating-point policies, such as "float32" and '
443 f'"mixed_float16", but got policy: {policy.name}'
444 )
445 _global_policy = policy
446 tf.__internal__.train.set_using_mixed_precision_policy(is_mixed_policy)
449# TODO(reedwm): Make this thread local
450@contextlib.contextmanager
451def policy_scope(policy):
452 """A context manager that sets the global Policy under it.
454 Args:
455 policy: A Policy, or a string that will be converted to a Policy.
457 Yields:
458 Nothing.
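Example (a minimal sketch; `policy_scope` is internal to Keras, so user
code should prefer `tf.keras.mixed_precision.set_global_policy`):

>>> with policy_scope('mixed_float16'):
...     layer = tf.keras.layers.Dense(10)
>>> layer.dtype_policy
<Policy "mixed_float16">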
459 """
460 old_policy = _global_policy
461 try:
462 set_global_policy(policy)
463 yield
464 finally:
465 set_global_policy(old_policy)
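# Converts `identifier` (a Policy, a serialized config dict, a dtype or
# policy string, or None) into a Policy. None falls back to the global
# policy. Raises if 'mixed_float16' is requested under a tf.distribute
# strategy that does not support loss scaling.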
468def get_policy(identifier):
469 if isinstance(identifier, Policy):
470 dtype_policy = identifier
471 elif isinstance(identifier, dict):
472 dtype_policy = deserialize(identifier)
473 elif isinstance(identifier, str) and identifier in (
474 "mixed_float16",
475 "mixed_bfloat16",
476 ):
477 # The isinstance check is required since np.dtype raises an error if
478 # compared to a non-dtype string.
479 dtype_policy = Policy(identifier)
480 elif identifier:
481 dtype_policy = Policy(tf.as_dtype(identifier).name)
482 else:
483 dtype_policy = global_policy()
484 if (
485 dtype_policy.name == "mixed_float16"
486 and not loss_scale_optimizer.strategy_supports_loss_scaling()
487 ):
488 # Although it is only loss scaling that is unsupported with certain
489 # strategies, to avoid confusion we disallow the 'mixed_float16'
490 # policy with unsupported strategies. This is because 'mixed_float16'
491 # requires loss scaling for numeric stability.
492 strategy = tf.distribute.get_strategy()
493 raise ValueError(
494 "Mixed precision is not supported with the "
495 f"tf.distribute.Strategy: {strategy.__class__.__name__}. "
496 "Either stop using mixed precision by removing the use of "
497 f"the {dtype_policy.name} policy or "
498 "use a different Strategy, e.g. a MirroredStrategy."
499 )
500 return dtype_policy
503def _is_convertible_to_dtype(dtype):
504 try:
505 tf.as_dtype(dtype)
506 return True
507 except TypeError:
508 return False
511def _policy_equivalent_to_dtype(policy):
512 """Returns True if the Policy is equivalent to a single dtype.
514 A policy is equivalent to a single dtype if the policy's compute and
515 variable dtypes are the same and the policy's type is Policy and not a
516 subclass of Policy.
518 The "_infer" policy is considered equivalent to a single dtype.
520 Args:
521 policy: A Policy.
523 Returns:
524 True, if the policy is equivalent to a single dtype.
525 """
526 # We use type() instead of isinstance because a subclass of Policy is never
527 # equivalent to a dtype.
528 return type(policy) == Policy and (
529 policy.name == "_infer" or _is_convertible_to_dtype(policy.name)
530 )
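# For a plain Policy equivalent to a dtype, `serialize` returns just the
# dtype string (e.g. 'float32'); otherwise it returns a Keras object
# config via `serialization_lib.serialize_keras_object`.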
533def serialize(policy):
534 if _policy_equivalent_to_dtype(policy):
535 # We return either None or the policy name for compatibility with older
536 # versions of Keras. If the policy name is returned, it is a dtype
537 # string such as 'float32'.
538 return None if policy.name == "_infer" else policy.name
539 return serialization_lib.serialize_keras_object(policy)
542def deserialize(config, custom_objects=None):
543 if isinstance(config, str) and _is_convertible_to_dtype(config):
544 return Policy(config)
545 if config is None:
546 return Policy("_infer")
547 # PolicyV1 was an old version of Policy that was removed. Deserializing it
548 # turns it into a (non-V1) Policy.
549 module_objects = {"Policy": Policy, "PolicyV1": Policy}
550 return serialization_lib.deserialize_keras_object(
551 config,
552 module_objects=module_objects,
553 custom_objects=custom_objects,
554 printable_module_name="dtype policy",
555 )