Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/experimental/mixed_precision.py: 50%

38 statements  

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains functions to use mixed precision with the graph rewrite."""

from tensorflow.python.framework import config
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale_optimizer as loss_scale_optimizer_v1
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# A mapping between optimizers and (wrapper_fn, wrapper_cls) pairs. wrapper_cls
# is a loss scale optimizer class, and wrapper_fn is a function that takes in
# an optimizer and LossScale and returns a wrapper_cls instance.
_REGISTERED_WRAPPER_OPTIMIZER_CLS = {
    optimizer.Optimizer:
        (loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer,) * 2,
}

@tf_export('__internal__.mixed_precision.register_loss_scale_wrapper', v1=[])
def register_loss_scale_wrapper(optimizer_cls, wrapper_fn, wrapper_cls=None):
  """Registers a loss scale optimizer wrapper.

  `tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite`
  automatically wraps an optimizer with an optimizer wrapper that performs loss
  scaling. This function registers a `(base_cls, wrapper_fn, wrapper_cls)`
  triple that is used by `enable_mixed_precision_graph_rewrite`, where
  `wrapper_fn` is called to create a `wrapper_cls` instance that wraps an
  `optimizer_cls` instance.

  Args:
    optimizer_cls: A base optimizer class, e.g.
      `tf.keras.optimizers.Optimizer`.
    wrapper_fn: A function that takes in arguments "optimizer" and
      "loss_scale", and returns a loss scale optimizer of type "wrapper_cls"
      that wraps "optimizer".
    wrapper_cls: A loss scale optimizer class. Defaults to `wrapper_fn`, in
      which case `wrapper_fn` should be a loss scale optimizer class whose
      constructor takes in arguments "optimizer" and "loss_scale".
  """
  _REGISTERED_WRAPPER_OPTIMIZER_CLS[optimizer_cls] = (
      wrapper_fn, wrapper_cls or wrapper_fn)
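
# Illustrative sketch (not part of the original module): how a higher-level
# framework might call `register_loss_scale_wrapper` so that
# `enable_mixed_precision_graph_rewrite` knows how to wrap its optimizers.
# `MyOptimizer` and `MyLossScaleOptimizer` are hypothetical placeholder names.
#
#   def _my_wrapper_fn(optimizer, loss_scale):
#     # Build a MyLossScaleOptimizer around the user's optimizer.
#     return MyLossScaleOptimizer(optimizer, loss_scale)
#
#   register_loss_scale_wrapper(MyOptimizer, _my_wrapper_fn,
#                               MyLossScaleOptimizer)
#
# If the wrapper class's constructor already accepts (optimizer, loss_scale),
# the class itself can be passed as `wrapper_fn`, in which case `wrapper_cls`
# defaults to it:
#
#   register_loss_scale_wrapper(MyOptimizer, MyLossScaleOptimizer)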

def _wrap_optimizer(opt, loss_scale):
  """Wraps an optimizer with a LossScaleOptimizer."""

  for _, wrapper_optimizer in _REGISTERED_WRAPPER_OPTIMIZER_CLS.values():
    if isinstance(opt, wrapper_optimizer):
      raise ValueError('"opt" must not already be an instance of a {cls}. '
                       '`enable_mixed_precision_graph_rewrite` will '
                       'automatically wrap the optimizer with a '
                       '{cls}.'
                       .format(cls=wrapper_optimizer.__name__))

  for optimizer_cls, (wrapper_fn, _) in (
      _REGISTERED_WRAPPER_OPTIMIZER_CLS.items()):
    if isinstance(opt, optimizer_cls):
      return wrapper_fn(opt, loss_scale)

  raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
                   'tf.keras.optimizers.Optimizer, but got: %s' % opt)

@deprecation.deprecated_endpoints(
    'train.experimental.enable_mixed_precision_graph_rewrite')
@tf_export(v1=['mixed_precision.enable_mixed_precision_graph_rewrite',
               'train.experimental.enable_mixed_precision_graph_rewrite'])
def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
  """Enable mixed precision via a graph rewrite.

  Mixed precision is the use of both float32 and float16 data types when
  training a model to improve performance. This is achieved via a graph rewrite
  operation and a loss-scale optimizer.

  Performing arithmetic operations in float16 takes advantage of specialized
  processing units, such as NVIDIA Tensor Cores, for much higher arithmetic
  throughput. However, due to the smaller representable range, performing the
  entire training with float16 can result in gradient underflow, that is, small
  gradient values becoming zeroes. Instead, performing only select arithmetic
  operations in float16 results in higher throughput and decreased training
  time when using compatible hardware accelerators while also reducing memory
  usage, typically without sacrificing model accuracy.

  Note: While the mixed precision rewrite changes the datatype of various
  layers throughout the model, the same accuracy reached in float32 is
  expected. If a `NaN` gradient occurs with dynamic loss scaling, the model
  update for that batch is skipped. In this case, the global step count is not
  incremented, and the `LossScaleOptimizer` attempts to decrease the loss
  scaling value to avoid `NaN` values in subsequent iterations. This approach
  has been shown to achieve the same accuracy as float32 and, in most cases,
  better training throughput.

  Example:

  ```python
  model = tf.keras.models.Sequential([
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(64, activation='softmax'),
  ])

  opt = tf.keras.optimizers.SGD()
  opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
  model.compile(loss="mse", optimizer=opt)

  x_train = np.random.random((1024, 64))
  y_train = np.random.random((1024, 64))
  model.fit(x_train, y_train)
  ```

  Calling `enable_mixed_precision_graph_rewrite(opt)` enables the graph rewrite
  operation before computing gradients. The function additionally returns an
  `Optimizer` (`opt`) wrapped with a `LossScaleOptimizer`. This prevents
  underflow in the float16 tensors during the backward pass. An optimizer of
  type `tf.train.Optimizer` or `tf.keras.optimizers.Optimizer` must be passed
  to this function, which will then be wrapped to use loss scaling.

  The graph rewrite operation changes the `dtype` of certain operations in the
  graph from float32 to float16. There are several categories of operations
  that are either included or excluded by this rewrite operation. The following
  categories of Ops are defined inside corresponding functions under the class
  `AutoMixedPrecisionLists` in
  <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/
  core/grappler/optimizers/auto_mixed_precision_lists.h">
  auto_mixed_precision_lists.h</a>:

  * `ClearList`: Ops that do not have numerically significant adverse effects.
    E.g. `ArgMax` and `Floor`.
  * `AllowList`: Ops that are considered numerically safe for execution in
    float16, and thus are always converted. E.g. `Conv2D`.
  * `DenyList`: Ops that are numerically unsafe to execute in float16 and
    can negatively affect downstream nodes. E.g. `Softmax`.
  * `GrayList`: Ops that are considered numerically safe for execution in
    float16 unless downstream from a DenyList Op. E.g. `Add` and `AvgPool`.

  When this function is used, gradients should only be computed and applied
  with the returned optimizer, either by calling `opt.minimize()` or
  `opt.compute_gradients()` followed by `opt.apply_gradients()`.
  Gradients should not be computed with `tf.gradients` or `tf.GradientTape`.
  This is because the returned optimizer will apply loss scaling, and
  `tf.gradients` or `tf.GradientTape` will not. If you do directly use
  `tf.gradients` or `tf.GradientTape`, your model may not converge due to
  float16 underflow problems.

  When eager execution is enabled, the mixed precision graph rewrite is only
  enabled within `tf.function`s, as outside `tf.function`s, there is no graph.

  For NVIDIA GPUs with Tensor cores, as a general performance guide, dimensions
  (such as batch size, input size, output size, and channel counts)
  should be powers of two if under 256, or otherwise divisible by 8 if above
  256. For more information, check out the
  [NVIDIA Deep Learning Performance Guide](
  https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html).

  Currently, mixed precision is only enabled on NVIDIA Tensor Core GPUs with
  Compute Capability 7.0 and above (Volta, Turing, or newer architectures). The
  parts of the graph on CPUs and TPUs are untouched by the graph rewrite.

  Raises:
    `ValueError`, if the `tf.keras.mixed_precision` API is also used by calling
    `tf.keras.mixed_precision.set_global_policy`. Only one mixed precision
    API can be used.

  Args:
    opt: An instance of a `tf.keras.optimizers.Optimizer` or a
      `tf.train.Optimizer`.
    loss_scale: Either an int/float, the string `"dynamic"`, or an instance of
      a `tf.mixed_precision.experimental.LossScale`. The loss scale to use. It
      is recommended to keep this as its default value of `"dynamic"`, which
      will adjust the scaling automatically to prevent `Inf` or `NaN` values.

  Returns:
    A version of `opt` that will use loss scaling to prevent underflow.
  """
  if mixed_precision_global_state.is_using_mixed_precision_policy():
    raise ValueError(
        'The mixed precision graph rewrite cannot be enabled, because the '
        'global Keras dtype Policy has been set to a mixed precision policy. '
        'At most, one of the following can be called:\n\n'
        '  1. tf.keras.mixed_precision.set_global_policy() with a mixed '
        'precision policy (You called this first)\n\n'
        '  2. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
        '(You called this second)\n'
        'You called both functions, which is an error, because both functions '
        'enable you to use mixed precision. If in doubt which function to use, '
        'use the first, as it supports Eager execution and is more '
        'customizable.')

  if mixed_precision_global_state.non_mixed_precision_session_created():
    # TODO(reedwm): Give the stacktrace of the existing Sessions. And if the
    # Sessions have already been closed, do not raise this error message.
    tf_logging.warn('You already have existing Sessions that do not use mixed '
                    'precision. enable_mixed_precision_graph_rewrite() will '
                    'not affect these Sessions.')
  opt = _wrap_optimizer(opt, loss_scale)
  config.set_optimizer_experimental_options({'auto_mixed_precision': True})
  mixed_precision_global_state.set_mixed_precision_graph_rewrite_enabled(True)
  return opt
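
# Illustrative sketch (not part of the original module), assuming TF 1.x-style
# graph execution: gradients are computed and applied only through the
# optimizer returned by `enable_mixed_precision_graph_rewrite`, never through
# `tf.gradients` directly, so that loss scaling is actually applied.
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#
#   x = tf.placeholder(tf.float32, shape=[None, 64])
#   y = tf.placeholder(tf.float32, shape=[None, 1])
#   pred = tf.layers.dense(x, 1)
#   loss = tf.losses.mean_squared_error(y, pred)
#
#   opt = tf.train.GradientDescentOptimizer(0.01)
#   # Wrap with loss scaling; 'dynamic' is the recommended default.
#   opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
#       opt, loss_scale='dynamic')
#   grads_and_vars = opt.compute_gradients(loss)
#   train_op = opt.apply_gradients(grads_and_vars)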

@deprecation.deprecated_endpoints(
    'train.experimental.disable_mixed_precision_graph_rewrite')
@tf_export(v1=['mixed_precision.disable_mixed_precision_graph_rewrite',
               'train.experimental.disable_mixed_precision_graph_rewrite'])
def disable_mixed_precision_graph_rewrite_v1():
  """Disables the mixed precision graph rewrite.

  After this is called, the mixed precision graph rewrite will no longer run
  for new Sessions, and so float32 operations will no longer be converted to
  float16 in such Sessions. However, any existing Sessions will continue to
  have the graph rewrite enabled if they were created after
  `enable_mixed_precision_graph_rewrite` was called but before
  `disable_mixed_precision_graph_rewrite` was called.

  This does not undo the effects of loss scaling. Any optimizers wrapped with a
  LossScaleOptimizer will continue to do loss scaling, although this loss
  scaling will no longer be useful if the optimizer is used in new Sessions, as
  the graph rewrite no longer converts the graph to use float16.

  This function is useful for unit testing. A unit test can test using the
  mixed precision graph rewrite, then disable it so future unit tests continue
  using float32. If this is done, unit tests should not share a single session,
  as `enable_mixed_precision_graph_rewrite` and
  `disable_mixed_precision_graph_rewrite` have no effect on existing sessions.
  """
  # We only have a separate V1 version of this function, because the V1
  # docstring mentions sessions.
  if (not
      mixed_precision_global_state.is_mixed_precision_graph_rewrite_enabled()):
    tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
                    'precision is already disabled.')
  config.set_optimizer_experimental_options({'auto_mixed_precision': False})
  mixed_precision_global_state.set_mixed_precision_graph_rewrite_enabled(False)
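
# Illustrative sketch (not part of the original module): a unit test can enable
# the rewrite, run its checks in a fresh Session, and then disable the rewrite
# so Sessions created by later tests run in float32. `MyRewriteTest` is a
# hypothetical test case name.
#
#   import tensorflow.compat.v1 as tf
#
#   class MyRewriteTest(tf.test.TestCase):
#
#     def test_graph_rewrite(self):
#       opt = tf.train.GradientDescentOptimizer(1.0)
#       opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
#       try:
#         with tf.Session():
#           pass  # ... build a graph and assert on its behavior ...
#       finally:
#         # Disable so that Sessions created by later tests use float32.
#         tf.train.experimental.disable_mixed_precision_graph_rewrite()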