Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/bfloat16.py: 42%

26 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15 

16"""Helper context for running models with bfloat16.""" 

17 

18from typing import Generator, Optional, Text 

19 

20from tensorflow.python.framework import dtypes 

21from tensorflow.python.ops import math_ops 

22from tensorflow.python.ops import variable_scope 

23from tensorflow.python.util import tf_contextlib 

24from tensorflow.python.util.tf_export import tf_export 

25 

26 

27def _get_custom_getter(): 

28 """Returns a custom getter that this class's methods must be called under. 

29 

30 All methods of this class must be called under a variable scope that was 

31 passed this custom getter. Example: 

32 

33 ```python 

34 network = ConvNetBuilder(...) 

35 with tf.compat.v1.variable_scope('cg', 

36 custom_getter=network.get_custom_getter()): 

37 network.conv(...) 

38 # Call more methods of network here 

39 ``` 

40 

41 Currently, this custom getter only does anything if self.use_tf_layers is 

42 True. In that case, it causes variables to be stored as dtype 

43 self.variable_type, then casted to the requested dtype, instead of directly 

44 storing the variable as the requested dtype. 

45 """ 

46 

47 def inner_custom_getter(getter, *args, **kwargs): 

48 """Custom getter that forces variables to have type self.variable_type.""" 

49 cast_to_bfloat16 = False 

50 requested_dtype = kwargs['dtype'] 

51 if requested_dtype == dtypes.bfloat16: 

52 # Only change the variable dtype if doing so does not decrease variable 

53 # precision. 

54 kwargs['dtype'] = dtypes.float32 

55 cast_to_bfloat16 = True 

56 var = getter(*args, **kwargs) 

57 # This if statement is needed to guard the cast, because batch norm 

58 # assigns directly to the return value of this custom getter. The cast 

59 # makes the return value not a variable so it cannot be assigned. Batch 

60 # norm variables are always in fp32 so this if statement is never 

61 # triggered for them. 

62 if cast_to_bfloat16: 

63 var = math_ops.cast(var, dtypes.bfloat16) 

64 return var 

65 

66 return inner_custom_getter 

67 

68 

@tf_export(v1=['tpu.bfloat16_scope'])
@tf_contextlib.contextmanager
def bfloat16_scope(
    name: Optional[Text] = None
) -> Generator[variable_scope.variable_scope, None, None]:
  """Opens a variable scope whose bfloat16 variables are stored as float32.

  Variables requested through `get_variable` with dtype bfloat16 inside this
  scope are created as float32 and handed back cast to bfloat16. The cast is
  done by the custom getter from `_get_custom_getter`.

  Args:
    name: Name to use for the scope. `None` is treated as the empty string.

  Yields:
    The variable scope entered by this context manager.
  """
  # An empty name keeps the caller's current scope path unchanged.
  scope_name = '' if name is None else name
  bf16_getter = _get_custom_getter()
  with variable_scope.variable_scope(
      scope_name, custom_getter=bf16_getter) as entered_scope:
    yield entered_scope