Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/backend_config.py: 77%

30 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras backend config API.""" 

16 

17from tensorflow.python.util import dispatch 

18from tensorflow.python.util.tf_export import keras_export 

19 

# Default floating-point dtype name used throughout a session.
# Read by `floatx()`; replaced by `set_floatx()`, which only accepts
# 'float16', 'float32' or 'float64'.
_FLOATX: str = 'float32'

# Default image data format convention.
# Read by `image_data_format()`; replaced by `set_image_data_format()`, which
# only accepts 'channels_last' or 'channels_first'.
_IMAGE_DATA_FORMAT: str = 'channels_last'

# Fuzz factor used in numeric expressions.
# Read by `epsilon()`; replaced by `set_epsilon()`.
_EPSILON: float = 1e-7

28 

29 

@keras_export('keras.backend.epsilon')
@dispatch.add_dispatch_support
def epsilon():
  """Returns the current fuzz factor used in numeric expressions.

  The factor starts out as `1e-7` and can be changed with
  `tf.keras.backend.set_epsilon`.

  Returns:
    A float: the epsilon value currently stored in this module.

  Example:
  >>> tf.keras.backend.epsilon()
  1e-07
  """
  current_value = _EPSILON
  return current_value

43 

44 

@keras_export('keras.backend.set_epsilon')
def set_epsilon(value):
  """Replaces the fuzz factor used in numeric expressions.

  The new value is stored as-is (no validation is performed) and is what
  `tf.keras.backend.epsilon()` returns afterwards.

  Args:
    value: float. The epsilon value to store.

  Example:
  >>> tf.keras.backend.epsilon()
  1e-07
  >>> tf.keras.backend.set_epsilon(1e-5)
  >>> tf.keras.backend.epsilon()
  1e-05
  >>> tf.keras.backend.set_epsilon(1e-7)
  """
  # Rebind the module-level default rather than creating a local name.
  global _EPSILON
  _EPSILON = value

62 

63 

@keras_export('keras.backend.floatx')
def floatx():
  """Returns the name of the default floating-point dtype.

  The name is one of `'float16'`, `'float32'` or `'float64'`. It is
  `'float32'` until changed with `tf.keras.backend.set_floatx`.

  Returns:
    String naming the current default float type.

  Example:
  >>> tf.keras.backend.floatx()
  'float32'
  """
  current_dtype_name = _FLOATX
  return current_dtype_name

78 

79 

@keras_export('keras.backend.set_floatx')
def set_floatx(value):
  """Sets the default float type.

  Note: It is not recommended to set this to float16 for training, as this will
  likely cause numeric stability issues. Instead, mixed precision, which is
  using a mix of float16 and float32, can be used by calling
  `tf.keras.mixed_precision.set_global_policy('mixed_float16')`. See the
  [mixed precision guide](
    https://www.tensorflow.org/guide/keras/mixed_precision) for details.

  Args:
    value: String; `'float16'`, `'float32'`, or `'float64'`.

  Example:
  >>> tf.keras.backend.floatx()
  'float32'
  >>> tf.keras.backend.set_floatx('float64')
  >>> tf.keras.backend.floatx()
  'float64'
  >>> tf.keras.backend.set_floatx('float32')

  Raises:
    ValueError: In case of invalid value, including non-string inputs.
  """
  global _FLOATX
  # Check the type before the set-membership test. Hashing an unhashable
  # argument (e.g. a list) would otherwise raise TypeError instead of the
  # documented ValueError.
  if (not isinstance(value, str) or
      value not in {'float16', 'float32', 'float64'}):
    raise ValueError('Unknown floatx type: ' + str(value))
  _FLOATX = str(value)

109 

110 

@keras_export('keras.backend.image_data_format')
@dispatch.add_dispatch_support
def image_data_format():
  """Returns the current default image data format convention.

  The convention starts out as `'channels_last'`. It can be switched with
  `tf.keras.backend.set_image_data_format`, which only accepts
  `'channels_first'` or `'channels_last'`.

  Returns:
    A string, either `'channels_first'` or `'channels_last'`.

  Example:
  >>> tf.keras.backend.image_data_format()
  'channels_last'
  """
  current_format = _IMAGE_DATA_FORMAT
  return current_format

124 

125 

@keras_export('keras.backend.set_image_data_format')
def set_image_data_format(data_format):
  """Sets the value of the image data format convention.

  Args:
    data_format: string. `'channels_first'` or `'channels_last'`.

  Example:
  >>> tf.keras.backend.image_data_format()
  'channels_last'
  >>> tf.keras.backend.set_image_data_format('channels_first')
  >>> tf.keras.backend.image_data_format()
  'channels_first'
  >>> tf.keras.backend.set_image_data_format('channels_last')

  Raises:
    ValueError: In case of invalid `data_format` value, including non-string
      inputs.
  """
  global _IMAGE_DATA_FORMAT
  # Check the type before the set-membership test. Hashing an unhashable
  # argument (e.g. a list) would otherwise raise TypeError instead of the
  # documented ValueError.
  if (not isinstance(data_format, str) or
      data_format not in {'channels_last', 'channels_first'}):
    raise ValueError('Unknown data_format: ' + str(data_format))
  _IMAGE_DATA_FORMAT = str(data_format)