Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/bfloat16.py: 42%
26 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
16"""Helper context for running models with bfloat16."""
18from typing import Generator, Optional, Text
20from tensorflow.python.framework import dtypes
21from tensorflow.python.ops import math_ops
22from tensorflow.python.ops import variable_scope
23from tensorflow.python.util import tf_contextlib
24from tensorflow.python.util.tf_export import tf_export
27def _get_custom_getter():
28 """Returns a custom getter that this class's methods must be called under.
30 All methods of this class must be called under a variable scope that was
31 passed this custom getter. Example:
33 ```python
34 network = ConvNetBuilder(...)
35 with tf.compat.v1.variable_scope('cg',
36 custom_getter=network.get_custom_getter()):
37 network.conv(...)
38 # Call more methods of network here
39 ```
41 Currently, this custom getter only does anything if self.use_tf_layers is
42 True. In that case, it causes variables to be stored as dtype
43 self.variable_type, then casted to the requested dtype, instead of directly
44 storing the variable as the requested dtype.
45 """
47 def inner_custom_getter(getter, *args, **kwargs):
48 """Custom getter that forces variables to have type self.variable_type."""
49 cast_to_bfloat16 = False
50 requested_dtype = kwargs['dtype']
51 if requested_dtype == dtypes.bfloat16:
52 # Only change the variable dtype if doing so does not decrease variable
53 # precision.
54 kwargs['dtype'] = dtypes.float32
55 cast_to_bfloat16 = True
56 var = getter(*args, **kwargs)
57 # This if statement is needed to guard the cast, because batch norm
58 # assigns directly to the return value of this custom getter. The cast
59 # makes the return value not a variable so it cannot be assigned. Batch
60 # norm variables are always in fp32 so this if statement is never
61 # triggered for them.
62 if cast_to_bfloat16:
63 var = math_ops.cast(var, dtypes.bfloat16)
64 return var
66 return inner_custom_getter
@tf_export(v1=['tpu.bfloat16_scope'])
@tf_contextlib.contextmanager
def bfloat16_scope(
    name: Optional[Text] = None
) -> Generator[variable_scope.variable_scope, None, None]:
  """Context manager that opens a variable scope with the bfloat16 getter.

  Inside the scope, variables requested as bfloat16 through `get_variable`
  are routed through the custom getter built by `_get_custom_getter()`, so
  they are created as float32 and read back as bfloat16.

  Args:
    name: Name to use for the scope. `None` is treated as the empty string.

  Yields:
    The variable scope that was entered.
  """
  scope_name = '' if name is None else name
  bfloat16_getter = _get_custom_getter()
  with variable_scope.variable_scope(
      scope_name, custom_getter=bfloat16_getter) as scope:
    yield scope