Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/backend_config.py: 94%

32 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras backend config API.""" 

16 

17import tensorflow.compat.v2 as tf 

18 

19# isort: off 

20from tensorflow.python.util.tf_export import keras_export 

21 

# Name of the default floating-point dtype (what `floatx()` reports).
_FLOATX: str = "float32"

# Small fuzz factor for numeric expressions (what `epsilon()` reports).
_EPSILON: float = 1e-7

# Default image data format; either "channels_last" or "channels_first"
# (what `image_data_format()` reports).
_IMAGE_DATA_FORMAT: str = "channels_last"

30 

31 

@keras_export("keras.backend.epsilon")
@tf.__internal__.dispatch.add_dispatch_support
def epsilon():
    """Get the fuzz factor that Keras uses in numeric expressions.

    The value can be changed with `tf.keras.backend.set_epsilon`.

    Returns:
        The current epsilon, as a float.

    Example:
    >>> tf.keras.backend.epsilon()
    1e-07
    """
    current = _EPSILON
    return current

45 

46 

@keras_export("keras.backend.set_epsilon")
def set_epsilon(value):
    """Replace the fuzz factor used in numeric expressions.

    The given value is stored as-is (no type or range checking is done);
    subsequent calls to `tf.keras.backend.epsilon()` return it.

    Args:
        value: float. The epsilon to use from now on.

    Example:
    >>> tf.keras.backend.epsilon()
    1e-07
    >>> tf.keras.backend.set_epsilon(1e-5)
    >>> tf.keras.backend.epsilon()
    1e-05
    >>> tf.keras.backend.set_epsilon(1e-7)
    """
    global _EPSILON
    new_epsilon = value
    _EPSILON = new_epsilon

64 

65 

@keras_export("keras.backend.floatx")
def floatx():
    """Report the default floating-point type as a string.

    Typical values are `'float16'`, `'float32'` and `'float64'`.

    Returns:
        The name of the current default float type.

    Example:
    >>> tf.keras.backend.floatx()
    'float32'
    """
    dtype_name = _FLOATX
    return dtype_name

80 

81 

@keras_export("keras.backend.set_floatx")
def set_floatx(value):
    """Sets the default float type.

    Note: It is not recommended to set this to float16 for training, as this
    will likely cause numeric stability issues. Instead, mixed precision, which
    is using a mix of float16 and float32, can be used by calling
    `tf.keras.mixed_precision.set_global_policy('mixed_float16')`. See the
    [mixed precision guide](
      https://www.tensorflow.org/guide/keras/mixed_precision) for details.

    Args:
        value: String; `'float16'`, `'float32'`, or `'float64'`.

    Example:
    >>> tf.keras.backend.floatx()
    'float32'
    >>> tf.keras.backend.set_floatx('float64')
    >>> tf.keras.backend.floatx()
    'float64'
    >>> tf.keras.backend.set_floatx('float32')

    Raises:
        ValueError: If `value` is not one of the accepted dtype names,
            including when `value` is unhashable (e.g. a list).
    """
    global _FLOATX
    accepted_dtypes = {"float16", "float32", "float64"}
    # Set membership hashes `value`; an unhashable argument would otherwise
    # escape as a TypeError. Treat it as invalid so callers get the
    # documented ValueError. (A hash-based check is kept on purpose: it only
    # accepts values that hash like the accepted strings.)
    try:
        is_accepted = value in accepted_dtypes
    except TypeError:
        is_accepted = False
    if not is_accepted:
        # Sort so the message is identical across runs: iteration order of
        # a set of strings depends on hash randomization.
        options = ", ".join(repr(name) for name in sorted(accepted_dtypes))
        raise ValueError(
            f"Unknown `floatx` value: {value}. "
            f"Expected one of {{{options}}}"
        )
    _FLOATX = str(value)

115 

116 

@keras_export("keras.backend.image_data_format")
@tf.__internal__.dispatch.add_dispatch_support
def image_data_format():
    """Report the default image data format convention.

    Returns:
        Either `'channels_first'` or `'channels_last'`, as a string.

    Example:
    >>> tf.keras.backend.image_data_format()
    'channels_last'
    """
    data_format = _IMAGE_DATA_FORMAT
    return data_format

130 

131 

@keras_export("keras.backend.set_image_data_format")
def set_image_data_format(data_format):
    """Sets the value of the image data format convention.

    Args:
        data_format: string. `'channels_first'` or `'channels_last'`.

    Example:
    >>> tf.keras.backend.image_data_format()
    'channels_last'
    >>> tf.keras.backend.set_image_data_format('channels_first')
    >>> tf.keras.backend.image_data_format()
    'channels_first'
    >>> tf.keras.backend.set_image_data_format('channels_last')

    Raises:
        ValueError: If `data_format` is not one of the accepted values,
            including when it is unhashable (e.g. a list).
    """
    global _IMAGE_DATA_FORMAT
    accepted_formats = {"channels_last", "channels_first"}
    # Set membership hashes `data_format`; an unhashable argument would
    # otherwise escape as a TypeError. Treat it as invalid so callers get
    # the documented ValueError.
    try:
        is_accepted = data_format in accepted_formats
    except TypeError:
        is_accepted = False
    if not is_accepted:
        # Sort so the message is identical across runs: iteration order of
        # a set of strings depends on hash randomization.
        options = ", ".join(repr(fmt) for fmt in sorted(accepted_formats))
        raise ValueError(
            f"Unknown `data_format`: {data_format}. "
            f"Expected one of {{{options}}}"
        )
    _IMAGE_DATA_FORMAT = str(data_format)

158