Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/mixed_precision/policy.py: 29%

126 statements  


# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the Policy class for mixed precision training."""

import contextlib

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine import base_layer_utils
from keras.src.mixed_precision import device_compatibility_check
from keras.src.mixed_precision import loss_scale_optimizer
from keras.src.saving import serialization_lib

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.mixed_precision.Policy", v1=[])
class Policy:
    """A dtype policy for a Keras layer.

    A dtype policy determines a layer's computation and variable dtypes. Each
    layer has a policy. Policies can be passed to the `dtype` argument of
    layer constructors, or a global policy can be set with
    `tf.keras.mixed_precision.set_global_policy`.

    Args:
      name: The policy name, which determines the compute and variable dtypes.
        Can be any dtype name, such as `'float32'` or `'float64'`, which
        causes both the compute and variable dtypes to be that dtype. Can also
        be the string `'mixed_float16'` or `'mixed_bfloat16'`, which causes
        the compute dtype to be float16 or bfloat16 and the variable dtype to
        be float32.

    Typically you only need to interact with dtype policies when using mixed
    precision, which is the use of float16 or bfloat16 for computations and
    float32 for variables. This is why the term `mixed_precision` appears in
    the API name. Mixed precision can be enabled by passing `'mixed_float16'`
    or `'mixed_bfloat16'` to `tf.keras.mixed_precision.set_global_policy`. See
    [the mixed precision
    guide](https://www.tensorflow.org/guide/keras/mixed_precision) for more
    information on how to use mixed precision.

    >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
    >>> layer1 = tf.keras.layers.Dense(10)
    >>> layer1.dtype_policy  # `layer1` will automatically use mixed precision
    <Policy "mixed_float16">
    >>> # Can optionally override layer to use float32
    >>> # instead of mixed precision.
    >>> layer2 = tf.keras.layers.Dense(10, dtype='float32')
    >>> layer2.dtype_policy
    <Policy "float32">
    >>> # Set policy back to initial float32 for future examples.
    >>> tf.keras.mixed_precision.set_global_policy('float32')

    In the example above, passing `dtype='float32'` to the layer is equivalent
    to passing `dtype=tf.keras.mixed_precision.Policy('float32')`. In general,
    passing a dtype policy name to a layer is equivalent to passing the
    corresponding policy, so it is never necessary to explicitly construct a
    `Policy` object.

    Note: `Model.compile` will automatically wrap an optimizer with a
    `tf.keras.mixed_precision.LossScaleOptimizer` if you use the
    `'mixed_float16'` policy. If you use a custom training loop instead of
    calling `Model.compile`, you should explicitly use a
    `tf.keras.mixed_precision.LossScaleOptimizer` to avoid numeric underflow
    with float16.
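
    A minimal sketch of that explicit wrapping for a custom training loop
    (the `SGD` optimizer is an arbitrary, illustrative choice):

    >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(
    ...     tf.keras.optimizers.SGD())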

    ### How a layer uses its policy's compute dtype

    A layer casts its inputs to its compute dtype. This causes the layer's
    computations and output to also be in the compute dtype. For example:

    >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
    >>> # `layer`'s policy defaults to float32.
    >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
    >>> layer.compute_dtype  # Equivalent to layer.dtype_policy.compute_dtype
    'float32'
    >>> # `layer` casts its inputs to its compute dtype and does computations
    >>> # in that dtype.
    >>> y = layer(x)
    >>> y.dtype
    tf.float32

    Note that the base `tf.keras.layers.Layer` class inserts the casts. If
    subclassing your own layer, you do not have to insert any casts.

    Currently, only tensors in the first argument to the layer's `call` method
    are cast (although this will likely be changed in a future minor release).
    For example:

    >>> class MyLayer(tf.keras.layers.Layer):
    ...   # Bug! `b` will not be cast.
    ...   def call(self, a, b):
    ...     return a + 1., b + 1.
    >>> a = tf.constant(1., dtype="float32")
    >>> b = tf.constant(1., dtype="float32")
    >>> layer = MyLayer(dtype="float64")
    >>> x, y = layer(a, b)
    >>> x.dtype
    tf.float64
    >>> y.dtype
    tf.float32

    If writing your own layer with multiple inputs, you should either
    explicitly cast other tensors to `self.compute_dtype` in `call` or accept
    all tensors in the first argument as a list, as in the sketch below.
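
    A hedged sketch of the explicit-cast fix (the name `MyFixedLayer` is
    illustrative):

    >>> class MyFixedLayer(tf.keras.layers.Layer):
    ...   def call(self, a, b):
    ...     # Cast the second argument manually; only `a` is cast
    ...     # automatically by the base layer.
    ...     b = tf.cast(b, self.compute_dtype)
    ...     return a + 1., b + 1.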

    The casting only occurs in TensorFlow 2. If
    `tf.compat.v1.disable_v2_behavior()` has been called, you can enable the
    casting behavior with
    `tf.compat.v1.keras.layers.enable_v2_dtype_behavior()`.

    ### How a layer uses its policy's variable dtype

    The default dtype of variables created by
    `tf.keras.layers.Layer.add_weight` is the layer's policy's variable dtype.

    If a layer's compute and variable dtypes differ, `add_weight` will wrap
    floating-point variables with a special wrapper called an
    `AutoCastVariable`. `AutoCastVariable` is identical to the original
    variable except it casts itself to the layer's compute dtype when used
    within `Layer.call`. This means if you are writing a layer, you do not
    have to explicitly cast the variables to the layer's compute dtype. For
    example:

    >>> class SimpleDense(tf.keras.layers.Layer):
    ...
    ...   def build(self, input_shape):
    ...     # With mixed precision, self.kernel is a float32 AutoCastVariable
    ...     self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
    ...
    ...   def call(self, inputs):
    ...     # With mixed precision, self.kernel will be cast to float16
    ...     return tf.linalg.matmul(inputs, self.kernel)
    ...
    >>> layer = SimpleDense(dtype='mixed_float16')
    >>> y = layer(tf.ones((10, 10)))
    >>> y.dtype
    tf.float16
    >>> layer.kernel.dtype
    tf.float32

    A layer author can prevent a variable from being wrapped with an
    `AutoCastVariable` by passing `experimental_autocast=False` to
    `add_weight`, which is useful if the float32 value of the variable must be
    accessed within the layer.
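
    A minimal sketch of opting out (the weight name `'scale'` and scalar
    shape are illustrative):

    >>> class NoAutoCastLayer(tf.keras.layers.Layer):
    ...   def build(self, input_shape):
    ...     # Stays a plain float32 variable even under mixed precision.
    ...     self.scale = self.add_weight(
    ...         'scale', shape=(), experimental_autocast=False)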

    ### How to write a layer that supports mixed precision and float64

    For the most part, layers will automatically support mixed precision and
    float64 without any additional work, because the base layer automatically
    casts inputs, creates variables of the correct type, and, in the case of
    mixed precision, wraps variables with `AutoCastVariables`.

    The primary case where you need extra work to support mixed precision or
    float64 is when you create a new tensor, such as with `tf.ones` or
    `tf.random.normal`. In such cases, you must create the tensor of the
    correct dtype. For example, if you call `tf.random.normal`, you must pass
    the compute dtype, which is the dtype the inputs have been cast to:

    >>> class AddRandom(tf.keras.layers.Layer):
    ...
    ...   def call(self, inputs):
    ...     # We must pass `dtype=inputs.dtype`, otherwise a TypeError may
    ...     # occur when adding `inputs` to `rand`.
    ...     rand = tf.random.normal(shape=inputs.shape, dtype=inputs.dtype)
    ...     return inputs + rand
    >>> layer = AddRandom(dtype='mixed_float16')
    >>> y = layer(x)
    >>> y.dtype
    tf.float16

    If you did not pass `dtype=inputs.dtype` to `tf.random.normal`, a
    `TypeError` would have occurred. This is because `tf.random.normal`'s
    dtype defaults to `"float32"`, but the input dtype is float16. You cannot
    add a float32 tensor with a float16 tensor.
    """

    def __init__(self, name):
        if isinstance(name, tf.DType):
            raise TypeError(
                "'name' must be a string, not a DType. "
                f"Instead, pass DType.name. Received: name={name.name}"
            )
        elif not isinstance(name, str):
            raise TypeError(f"'name' must be a string, but got: {name}")
        self._name = name
        self._compute_dtype, self._variable_dtype = self._parse_name(name)
        if name in ("mixed_float16", "mixed_bfloat16"):
            device_compatibility_check.log_device_compatibility_check(name)

    def _parse_name(self, name):
        """Parses a Policy name into a compute and variable dtype.

        Args:
          name: The name of the policy.

        Returns:
          The (compute_dtype, variable_dtype) pair.
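
        For example, `'mixed_float16'` parses to `('float16', 'float32')`,
        while a plain dtype name such as `'float64'` parses to
        `('float64', 'float64')`.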

        """
        if name.endswith("_float32_vars"):
            error_msg = (
                "Policies ending in '_float32_vars' have been removed "
                "from TensorFlow."
            )
            if name in ("infer_float32_vars", "infer_with_float32_vars"):
                error_msg += (
                    " Please use the 'mixed_float16' or 'mixed_bfloat16' "
                    "policy instead."
                )
            elif name == "float16_with_float32_vars":
                error_msg += " Please use the 'mixed_float16' policy instead."
            elif name == "bfloat16_with_float32_vars":
                error_msg += " Please use the 'mixed_bfloat16' policy instead."
            error_msg += f" Got policy name: '{name}'"
            raise ValueError(error_msg)

        if name == "mixed_float16":
            return "float16", "float32"
        elif name == "mixed_bfloat16":
            return "bfloat16", "float32"
        elif name == "_infer":
            # The "_infer" policy exists only for compatibility with TF 1,
            # where "_infer" is the default. It matches the behavior of TF 1
            # before policies were introduced: the computation and variable
            # dtype are inferred from the first input the first time the
            # layer is called. Once the layer is called for the first time,
            # the layer's policy will change to the dtype of the first input,
            # and it will no longer have the "_infer" policy.
            #
            # The infer policy should be considered an implementation detail
            # and may be removed in the future.
            return None, None

        try:
            dtype = tf.as_dtype(name).name
        except TypeError:
            raise ValueError(
                f"Cannot convert value {name} to a mixed precision Policy. "
                "Valid policies include 'mixed_float16', 'mixed_bfloat16', "
                "and the name of any dtype such as 'float32'."
            )
        return dtype, dtype

    @property
    def variable_dtype(self):
        """The variable dtype of this policy.

        This is the dtype layers will create their variables in, unless a
        layer explicitly chooses a different dtype. If this is different from
        `Policy.compute_dtype`, layers will cast variables to the compute
        dtype to avoid type errors.

        Variable regularizers are run in the variable dtype, not the compute
        dtype.
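
        For example (illustrative):

        >>> tf.keras.mixed_precision.Policy('mixed_bfloat16').variable_dtype
        'float32'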

        Returns:
          The variable dtype of this policy, as a string.
        """
        return self._variable_dtype

    @property
    def compute_dtype(self):
        """The compute dtype of this policy.

        This is the dtype layers will do their computations in. Typically
        layers output tensors with the compute dtype as well.

        Note that even if the compute dtype is float16 or bfloat16, hardware
        devices may not do individual adds, multiplies, and other fundamental
        operations in float16 or bfloat16, but instead may do some of them in
        float32 for numeric stability. The compute dtype is the dtype of the
        inputs and outputs of the TensorFlow ops that the layer executes.
        Internally, many TensorFlow ops will do certain internal calculations
        in float32 or some other device-internal intermediate format with
        higher precision than float16/bfloat16, to increase numeric stability.

        For example, a `tf.keras.layers.Dense` layer, when run on a GPU with a
        float16 compute dtype, will pass float16 inputs to `tf.linalg.matmul`.
        But `tf.linalg.matmul` will use float32 intermediate math. The
        performance benefit of float16 is still apparent, due to increased
        memory bandwidth and the fact that modern GPUs have specialized
        hardware for computing matmuls on float16 inputs while still keeping
        intermediate computations in float32.
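
        For example (illustrative):

        >>> tf.keras.mixed_precision.Policy('mixed_bfloat16').compute_dtype
        'bfloat16'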

        Returns:
          The compute dtype of this policy, as a string.
        """
        return self._compute_dtype

    @property
    def name(self):
        """Returns the name of this policy."""
        return self._name

    def __repr__(self):
        return f'<Policy "{self._name}">'

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        del custom_objects
        if "loss_scale" in config:
            config = config.copy()
            # Policy.get_config in TensorFlow 2.3 and below had a loss_scale.
            # We silently drop it.
            del config["loss_scale"]
        return cls(**config)
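
    # For example, `Policy.from_config({"name": "mixed_bfloat16"})` is
    # equivalent to `Policy("mixed_bfloat16")`; a legacy "loss_scale" key
    # from a TF <= 2.3 config is silently dropped.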


# The current global policy in effect. If None, it means the current value of
# floatx should be used as the policy if the V2 dtype behavior is enabled,
# or "_infer" otherwise.
# TODO(reedwm): Make this thread local?
_global_policy = None


@keras_export("keras.mixed_precision.global_policy", v1=[])
def global_policy():
    """Returns the global dtype policy.

    The global policy is the default `tf.keras.mixed_precision.Policy` used
    for layers, if no policy is passed to the layer constructor. If no policy
    has been set with `keras.mixed_precision.set_global_policy`, this will
    return a policy constructed from `tf.keras.backend.floatx()` (floatx
    defaults to float32).

    >>> tf.keras.mixed_precision.global_policy()
    <Policy "float32">
    >>> tf.keras.layers.Dense(10).dtype_policy  # Defaults to the global policy
    <Policy "float32">

    If TensorFlow 2 behavior has been disabled with
    `tf.compat.v1.disable_v2_behavior()`, this will instead return a special
    "_infer" policy which infers the dtype from the dtype of the first input
    the first time the layer is called. This behavior matches the behavior
    that existed in TensorFlow 1.

    See `tf.keras.mixed_precision.Policy` for more information on policies.

    Returns:
      The global Policy.
    """
    if _global_policy is None:
        if base_layer_utils.v2_dtype_behavior_enabled():
            return Policy(backend.floatx())
        else:
            return Policy("_infer")
    return _global_policy


def _check_if_mixed_precision_graph_rewrite_is_enabled(policy):
    if tf.__internal__.train.is_mixed_precision_graph_rewrite_enabled():
        raise ValueError(
            'The global dtype policy cannot be set to "{policy.name}", '
            "because the mixed precision graph rewrite has already been "
            "enabled.\n"
            "At most, one of the following can be called:\n\n"
            "  1. tf.compat.v1.train.enable_mixed_precision_graph_rewrite() "
            "(You called this first)\n"
            "  2. tf.keras.mixed_precision.set_global_policy() with a mixed "
            "precision policy (You called this second)\n\n"
            "You called both functions, which is an error, because both "
            "functions enable you to use mixed precision. If in doubt which "
            "function to use, use the second, as it supports Eager execution "
            "and is more customizable.".format(policy=policy)
        )


@keras_export("keras.mixed_precision.set_global_policy", v1=[])
def set_global_policy(policy):
    """Sets the global dtype policy.

    The global policy is the default `tf.keras.mixed_precision.Policy` used
    for layers, if no policy is passed to the layer constructor.

    >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
    >>> tf.keras.mixed_precision.global_policy()
    <Policy "mixed_float16">
    >>> tf.keras.layers.Dense(10).dtype_policy
    <Policy "mixed_float16">
    >>> # Global policy is not used if a policy
    >>> # is directly passed to constructor
    >>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
    <Policy "float64">
    >>> tf.keras.mixed_precision.set_global_policy('float32')

    If no global policy is set, layers will instead default to a Policy
    constructed from `tf.keras.backend.floatx()`.

    To use mixed precision, the global policy should be set to
    `'mixed_float16'` or `'mixed_bfloat16'`, so that every layer uses a 16-bit
    compute dtype and float32 variable dtype by default.

    Only floating point policies can be set as the global policy, such as
    `'float32'` and `'mixed_float16'`. Non-floating point policies such as
    `'int32'` and `'complex64'` cannot be set as the global policy because
    most layers do not support such policies.

    See `tf.keras.mixed_precision.Policy` for more information.

    Args:
      policy: A Policy, or a string that will be converted to a Policy. Can
        also be None, in which case the global policy will be constructed from
        `tf.keras.backend.floatx()`.
    """
    global _global_policy
    if not base_layer_utils.v2_dtype_behavior_enabled():
        raise ValueError(
            "The global policy can only be set in TensorFlow 2 or if "
            "V2 dtype behavior has been set. To enable V2 dtype "
            "behavior, call "
            '"tf.compat.v1.keras.layers.enable_v2_dtype_behavior()"'
        )
    if policy is not None and not isinstance(policy, Policy):
        policy = Policy(policy)
    is_mixed_policy = (
        policy is not None and policy.compute_dtype != policy.variable_dtype
    )
    if is_mixed_policy:
        _check_if_mixed_precision_graph_rewrite_is_enabled(policy)
    if (
        policy is not None
        and policy.compute_dtype is not None
        and not tf.as_dtype(policy.compute_dtype).is_floating
    ):
        raise ValueError(
            "set_global_policy can only be used to set the global "
            'policy to floating-point policies, such as "float32" and '
            f'"mixed_float16", but got policy: {policy.name}'
        )
    _global_policy = policy
    tf.__internal__.train.set_using_mixed_precision_policy(is_mixed_policy)


# TODO(reedwm): Make this thread local
@contextlib.contextmanager
def policy_scope(policy):
    """A context manager that sets the global Policy while it is active.

    Args:
      policy: A Policy, or a string that will be converted to a Policy.

    Yields:
      Nothing.
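
    A minimal usage sketch (illustrative):

    >>> with policy_scope('mixed_bfloat16'):
    ...     print(global_policy().name)
    mixed_bfloat16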

    """
    old_policy = _global_policy
    try:
        set_global_policy(policy)
        yield
    finally:
        set_global_policy(old_policy)


def get_policy(identifier):
    if isinstance(identifier, Policy):
        dtype_policy = identifier
    elif isinstance(identifier, dict):
        dtype_policy = deserialize(identifier)
    elif isinstance(identifier, str) and identifier in (
        "mixed_float16",
        "mixed_bfloat16",
    ):
        # The isinstance check is required since np.dtype raises an error if
        # compared to a non-dtype string.
        dtype_policy = Policy(identifier)
    elif identifier:
        dtype_policy = Policy(tf.as_dtype(identifier).name)
    else:
        dtype_policy = global_policy()
    if (
        dtype_policy.name == "mixed_float16"
        and not loss_scale_optimizer.strategy_supports_loss_scaling()
    ):
        # Although it is only loss scaling that is unsupported on certain
        # strategies, we disallow the 'mixed_float16' policy with unsupported
        # strategies entirely, to avoid confusion. This is because
        # 'mixed_float16' requires loss scaling for numeric stability.
        strategy = tf.distribute.get_strategy()
        raise ValueError(
            "Mixed precision is not supported with the "
            f"tf.distribute.Strategy: {strategy.__class__.__name__}. "
            "Either stop using mixed precision by removing the use of "
            f"the {dtype_policy.name} policy or "
            "use a different Strategy, e.g. a MirroredStrategy."
        )
    return dtype_policy
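
# For example, get_policy("float32") returns a <Policy "float32">, while
# get_policy(None), or any other falsy identifier, falls back to the global
# policy.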


def _is_convertible_to_dtype(dtype):
    try:
        tf.as_dtype(dtype)
        return True
    except TypeError:
        return False
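
# For example, _is_convertible_to_dtype("float32") is True, while
# _is_convertible_to_dtype("mixed_float16") is False, since mixed policy
# names are not plain dtype names.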


def _policy_equivalent_to_dtype(policy):
    """Returns True if the Policy is equivalent to a single dtype.

    A policy is equivalent to a single dtype if the policy's compute and
    variable dtypes are the same and the policy's type is Policy and not a
    subclass of Policy.

    The "_infer" policy is considered equivalent to a single dtype.

    Args:
      policy: A Policy.

    Returns:
      True, if the policy is equivalent to a single dtype.
    """
    # We use type() instead of isinstance because a subclass of Policy is
    # never equivalent to a dtype.
    return type(policy) is Policy and (
        policy.name == "_infer" or _is_convertible_to_dtype(policy.name)
    )
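
# For example, Policy("float32") is equivalent to the single dtype "float32",
# while Policy("mixed_float16") is not equivalent to any single dtype, since
# its compute and variable dtypes differ.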


def serialize(policy):
    if _policy_equivalent_to_dtype(policy):
        # We return either None or the policy name for compatibility with
        # older versions of Keras. If the policy name is returned, it is a
        # dtype string such as 'float32'.
        return None if policy.name == "_infer" else policy.name
    return serialization_lib.serialize_keras_object(policy)
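
# For example, serialize(Policy("float32")) returns the plain string
# "float32", while a Policy subclass is serialized as a full Keras object
# config via serialization_lib.serialize_keras_object.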


def deserialize(config, custom_objects=None):
    if isinstance(config, str) and _is_convertible_to_dtype(config):
        return Policy(config)
    if config is None:
        return Policy("_infer")
    # PolicyV1 was an old version of Policy that was removed. Deserializing
    # it turns it into a (non-V1) Policy.
    module_objects = {"Policy": Policy, "PolicyV1": Policy}
    return serialization_lib.deserialize_keras_object(
        config,
        module_objects=module_objects,
        custom_objects=custom_objects,
        printable_module_name="dtype policy",
    )
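
# For example, deserialize("float32") reconstructs a <Policy "float32">,
# while deserialize(None) reconstructs the "_infer" policy.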
