Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/mixed_precision/policy.py: 32%

142 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains the Policy class for mixed precision training.""" 

16 

17import contextlib 

18 

19from tensorflow.python.framework import dtypes 

20from tensorflow.python.keras import backend 

21from tensorflow.python.keras.engine import base_layer_utils 

22from tensorflow.python.keras.mixed_precision import device_compatibility_check 

23from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module 

24from tensorflow.python.keras.utils import generic_utils 

25from tensorflow.python.platform import tf_logging 

26from tensorflow.python.training.experimental import mixed_precision_global_state 

27from tensorflow.python.util.tf_export import keras_export 

28 

29 

# pylint: disable=g-classes-have-attributes
@keras_export('keras.mixed_precision.Policy', v1=[])
class Policy(object):
  """A dtype policy for a Keras layer.

  A dtype policy determines a layer's computation and variable dtypes. Each
  layer has a policy. Policies can be passed to the `dtype` argument of layer
  constructors, or a global policy can be set with
  `tf.keras.mixed_precision.set_global_policy`.

  Args:
    name: The policy name, which determines the compute and variable dtypes. Can
      be any dtype name, such as `'float32'` or `'float64'`, which causes both
      the compute and variable dtypes to be that dtype. Can also be the string
      `'mixed_float16'` or `'mixed_bfloat16'`, which causes the compute dtype to
      be float16 or bfloat16 and the variable dtype to be float32.

  Typically you only need to interact with dtype policies when using mixed
  precision, which is the use of float16 or bfloat16 for computations and
  float32 for variables. This is why the term `mixed_precision` appears in the
  API name. Mixed precision can be enabled by passing `'mixed_float16'` or
  `'mixed_bfloat16'` to `tf.keras.mixed_precision.set_global_policy`. See [the
  mixed precision guide](https://www.tensorflow.org/guide/keras/mixed_precision)
  for more information on how to use mixed precision.

  >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
  >>> layer1 = tf.keras.layers.Dense(10)
  >>> layer1.dtype_policy  # `layer1` will automatically use mixed precision
  <Policy "mixed_float16">
  >>> # Can optionally override layer to use float32 instead of mixed precision.
  >>> layer2 = tf.keras.layers.Dense(10, dtype='float32')
  >>> layer2.dtype_policy
  <Policy "float32">
  >>> # Set policy back to initial float32 for future examples.
  >>> tf.keras.mixed_precision.set_global_policy('float32')

  In the example above, passing `dtype='float32'` to the layer is equivalent to
  passing `dtype=tf.keras.mixed_precision.Policy('float32')`. In general,
  passing a dtype policy name to a layer is equivalent to passing the
  corresponding policy, so it is never necessary to explicitly construct a
  `Policy` object.

  Note: `Model.compile` will automatically wrap an optimizer with a
  `tf.keras.mixed_precision.LossScaleOptimizer` if you use the `'mixed_float16'`
  policy. If you use a custom training loop instead of calling `Model.compile`,
  you should explicitly use a `tf.keras.mixed_precision.LossScaleOptimizer` to
  avoid numeric underflow with float16.

  ### How a layer uses its policy's compute dtype

  A layer casts its inputs to its compute dtype. This causes the layer's
  computations and output to also be in the compute dtype. For example:

  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> # `layer`'s policy defaults to float32.
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> layer.compute_dtype  # Equivalent to layer.dtype_policy.compute_dtype
  'float32'
  >>> # `layer` casts its inputs to its compute dtype and does computations in
  >>> # that dtype.
  >>> y = layer(x)
  >>> y.dtype
  tf.float32

  Note that the base `tf.keras.layers.Layer` class inserts the casts. If
  subclassing your own layer, you do not have to insert any casts.

  Currently, only tensors in the first argument to the layer's `call` method are
  cast (although this will likely be changed in a future minor release). For
  example:

  >>> class MyLayer(tf.keras.layers.Layer):
  ...   # Bug! `b` will not be cast.
  ...   def call(self, a, b):
  ...     return a + 1., b + 1.
  >>> a = tf.constant(1., dtype="float32")
  >>> b = tf.constant(1., dtype="float32")
  >>> layer = MyLayer(dtype="float64")
  >>> x, y = layer(a, b)
  >>> x.dtype
  tf.float64
  >>> y.dtype
  tf.float32

  If writing your own layer with multiple inputs, you should either explicitly
  cast other tensors to `self.compute_dtype` in `call` or accept all tensors in
  the first argument as a list.

  The casting only occurs in TensorFlow 2. If
  `tf.compat.v1.disable_v2_behavior()` has been called, you can enable the
  casting behavior with `tf.compat.v1.keras.layers.enable_v2_dtype_behavior()`.

  ### How a layer uses its policy's variable dtype

  The default dtype of variables created by `tf.keras.layers.Layer.add_weight`
  is the layer's policy's variable dtype.

  If a layer's compute and variable dtypes differ, `add_weight` will wrap
  floating-point variables with a special wrapper called an `AutoCastVariable`.
  `AutoCastVariable` is identical to the original variable except it casts
  itself to the layer's compute dtype when used within `Layer.call`. This means
  if you are writing a layer, you do not have to explicitly cast the variables
  to the layer's compute dtype. For example:

  >>> class SimpleDense(tf.keras.layers.Layer):
  ...
  ...   def build(self, input_shape):
  ...     # With mixed precision, self.kernel is a float32 AutoCastVariable
  ...     self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
  ...
  ...   def call(self, inputs):
  ...     # With mixed precision, self.kernel will be cast to float16
  ...     return tf.linalg.matmul(inputs, self.kernel)
  ...
  >>> layer = SimpleDense(dtype='mixed_float16')
  >>> y = layer(tf.ones((10, 10)))
  >>> y.dtype
  tf.float16
  >>> layer.kernel.dtype
  tf.float32

  A layer author can prevent a variable from being wrapped with an
  `AutoCastVariable` by passing `experimental_autocast=False` to `add_weight`,
  which is useful if the float32 value of the variable must be accessed within
  the layer.

  ### How to write a layer that supports mixed precision and float64.

  For the most part, layers will automatically support mixed precision and
  float64 without any additional work, due to the fact the base layer
  automatically casts inputs, creates variables of the correct type, and in the
  case of mixed precision, wraps variables with `AutoCastVariables`.

  The primary case where you need extra work to support mixed precision or
  float64 is when you create a new tensor, such as with `tf.ones` or
  `tf.random.normal`. In such cases, you must create the tensor of the correct
  dtype. For example, if you call `tf.random.normal`, you must pass the compute
  dtype, which is the dtype the inputs have been cast to:

  >>> class AddRandom(tf.keras.layers.Layer):
  ...
  ...   def call(self, inputs):
  ...     # We must pass `dtype=inputs.dtype`, otherwise a TypeError may
  ...     # occur when adding `inputs` to `rand`.
  ...     rand = tf.random.normal(shape=inputs.shape, dtype=inputs.dtype)
  ...     return inputs + rand
  >>> layer = AddRandom(dtype='mixed_float16')
  >>> y = layer(x)
  >>> y.dtype
  tf.float16

  If you did not pass `dtype=inputs.dtype` to `tf.random.normal`, a
  `TypeError` would have occurred. This is because `tf.random.normal`'s
  dtype defaults to `"float32"`, but the input dtype is float16. You cannot add
  a float32 tensor with a float16 tensor.
  """

  def __init__(self, name):
    """Creates a policy from its name.

    Args:
      name: The policy name as a string. See the class docstring for the
        accepted values.

    Raises:
      TypeError: If `name` is a `DType` or is otherwise not a string.
      ValueError: If `name` is not a valid policy name.
    """
    if isinstance(name, dtypes.DType):
      raise TypeError("'name' must be a string, not a DType. "
                      "Instead, pass DType.name. Got: %s" % (name.name,))
    elif not isinstance(name, str):
      raise TypeError("'name' must be a string, but got: %s" % (name,))
    self._name = name
    self._compute_dtype, self._variable_dtype = self._parse_name(name)
    # Both mixed policies are forwarded to the device compatibility check. The
    # second name was previously misspelled 'mixed_bloat16', so
    # 'mixed_bfloat16' never reached this call.
    if name in ('mixed_float16', 'mixed_bfloat16'):
      device_compatibility_check.log_device_compatibility_check(name)

  def _parse_name(self, name):
    """Parses a Policy name into a compute and variable dtype.

    Args:
      name: The name of the policy.

    Returns:
      The (compute_dtype, variable_dtype) pair. Both are None for the special
      "_infer" policy.

    Raises:
      ValueError: If `name` is a removed '*_float32_vars' policy, or is neither
        a mixed policy name nor convertible to a dtype.
    """
    if name.endswith('_float32_vars'):
      error_msg = ('Policies ending in \'_float32_vars\' have been removed '
                   'from TensorFlow.')
      if name in ('infer_float32_vars', 'infer_with_float32_vars'):
        error_msg += (' Please use the \'mixed_float16\' or \'mixed_bfloat16\' '
                      'policy instead.')
      elif name == 'float16_with_float32_vars':
        error_msg += (' Please use the \'mixed_float16\' policy instead.')
      elif name == 'bfloat16_with_float32_vars':
        error_msg += (' Please use the \'mixed_bfloat16\' policy instead.')
      error_msg += ' Got policy name: \'%s\'' % name
      raise ValueError(error_msg)

    if name == 'mixed_float16':
      return 'float16', 'float32'
    elif name == 'mixed_bfloat16':
      return 'bfloat16', 'float32'
    elif name == '_infer':
      # The "_infer" policy exists only for compatibility with TF 1, where
      # "_infer" is the default. The behavior matches the behavior of TF 1's
      # behavior before policies were introduced. With "_infer", the computation
      # and variable dtype are inferred from the first input the first time the
      # layer is called. Once the layer is called for the first time, the
      # layer's policy will change to the dtype of the first input, and it will
      # no longer have the "_infer" policy.
      #
      # The infer policy should be considered an implementation detail and may
      # be removed in the future.
      return None, None

    try:
      dtype = dtypes.as_dtype(name).name
    except TypeError:
      error = ("Cannot convert value %s to a mixed precision Policy. "
               "Valid policies include 'mixed_float16', 'mixed_bfloat16', "
               "and the name of any dtype such as 'float32'." % (name,))
      raise ValueError(error)
    return dtype, dtype

  @property
  def variable_dtype(self):
    """The variable dtype of this policy.

    This is the dtype layers will create their variables in, unless a layer
    explicitly chooses a different dtype. If this is different than
    `Policy.compute_dtype`, Layers will cast variables to the compute dtype to
    avoid type errors.

    Variable regularizers are run in the variable dtype, not the compute dtype.

    Returns:
      The variable dtype of this policy, as a string.
    """
    return self._variable_dtype

  @property
  def compute_dtype(self):
    """The compute dtype of this policy.

    This is the dtype layers will do their computations in. Typically layers
    output tensors with the compute dtype as well.

    Note that even if the compute dtype is float16 or bfloat16, hardware devices
    may not do individual adds, multiplies, and other fundamental operations in
    float16 or bfloat16, but instead may do some of them in float32 for numeric
    stability. The compute dtype is the dtype of the inputs and outputs of the
    TensorFlow ops that the layer executes. Internally, many TensorFlow ops will
    do certain internal calculations in float32 or some other device-internal
    intermediate format with higher precision than float16/bfloat16, to increase
    numeric stability.

    For example, a `tf.keras.layers.Dense` layer, when run on a GPU with a
    float16 compute dtype, will pass float16 inputs to `tf.linalg.matmul`. But,
    `tf.linalg.matmul` will use float32 intermediate math. The performance
    benefit of float16 is still apparent, due to increased memory bandwidth and
    the fact modern GPUs have specialized hardware for computing matmuls on
    float16 inputs while still keeping intermediate computations in float32.

    Returns:
      The compute dtype of this policy, as a string.
    """
    return self._compute_dtype

  @property
  def name(self):
    """Returns the name of this policy."""
    return self._name

  def __repr__(self):
    return '<Policy "%s">' % self._name

  def get_config(self):
    """Returns a serializable config containing only the policy name."""
    return {'name': self.name}

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a Policy from a config produced by `get_config`.

    Args:
      config: A dict with a 'name' key. A legacy 'loss_scale' key is accepted
        and ignored.
      custom_objects: Unused; accepted for API compatibility.

    Returns:
      A new Policy.
    """
    del custom_objects
    if 'loss_scale' in config:
      config = config.copy()
      # Policy.get_config in TensorFlow 2.3 and below had a loss_scale. We
      # silently drop it.
      del config['loss_scale']
    return cls(**config)

310 

311 

@keras_export('keras.mixed_precision.experimental.Policy', v1=[])
class PolicyV1(Policy):
  """A deprecated dtype policy for a Keras layer.

  Warning: This class is deprecated and will be removed soon. Use the
  non-experimental `tf.keras.mixed_precision.Policy` class instead.

  Unlike the non-experimental class, this class carries a `loss_scale` field.
  The loss scale is consumed only by `tf.keras.Model.compile`, which wraps the
  optimizer in a `LossScaleOptimizer` unless it already is one. With the
  non-experimental Policy class, `Model.compile` instead performs that wrapping
  when `Policy.name` is "mixed_float16".

  When objects holding an experimental policy are deserialized with functions
  such as `tf.keras.utils.deserialize_keras_object`, the policy comes back as
  the non-experimental `tf.keras.mixed_precision.Policy` and the loss scale is
  silently dropped. This keeps SavedModels produced with an experimental policy
  loadable after the experimental policy is removed.
  """

  def __init__(self, name, loss_scale='auto'):
    """Constructs the policy.

    `name` alone selects the compute dtype, the variable dtype and the default
    loss scale; it has no other effect on the Policy. The compute and variable
    dtypes cannot be passed directly.

    Args:
      name: A string, one of:
        * Any dtype name, such as 'float32' or 'float64'. The variable and
          compute dtypes are both that dtype.
        * 'mixed_float16' or 'mixed_bfloat16': the compute dtype is float16 or
          bfloat16 and the variable dtype is float32. 'mixed_float16' uses a
          dynamic loss scale by default. These policies are used for mixed
          precision training.
      loss_scale: A `tf.compat.v1.mixed_precision.LossScale`, an int (which
        uses a `FixedLossScale`), the string "dynamic" (which uses a
        `DynamicLossScale`), or None (no loss scale). Defaults to `"auto"`,
        which resolves to `"dynamic"` when `name` is `"mixed_float16"` and to no
        loss scale otherwise. Only `tf.keras.Model`s, not layers, use the loss
        scale, and only during `Model.fit`, `Model.train_on_batch` and similar
        methods.
    """
    super(PolicyV1, self).__init__(name)
    # Remember whether the caller left the loss scale at its default, so that
    # get_config can omit it.
    self._using_default_loss_scale = False
    if loss_scale == 'auto':
      self._using_default_loss_scale = True
      loss_scale = 'dynamic' if name == 'mixed_float16' else None
    else:
      loss_scale_is_useful = self._compute_dtype in (None, 'float16')
      if loss_scale and not loss_scale_is_useful:
        tf_logging.warning(
            'Creating a Policy with a loss scale is only useful for '
            'float16 policies. You passed loss_scale=%r for policy '
            '%s. Consider not passing any loss_scale instead.' %
            (loss_scale, name))
    self._loss_scale = keras_loss_scale_module.get(loss_scale)

  @property
  def loss_scale(self):
    """The loss scale attached to this Policy.

    Returns:
      A `tf.compat.v1.mixed_precision.experimental.LossScale`, or None.
    """
    return self._loss_scale

  def __repr__(self):
    return '<PolicyV1 "{}", loss_scale={}>'.format(self._name, self.loss_scale)

  def get_config(self):
    """Returns the config; 'loss_scale' appears only if it was given explicitly."""
    config = {'name': self.name}
    if self._using_default_loss_scale:
      # Leaving the default out lets the loss scale config format change
      # without breaking users who rely on the default.
      return config
    config['loss_scale'] = keras_loss_scale_module.serialize(self.loss_scale)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a PolicyV1, deserializing a dict-valued 'loss_scale' if present."""
    serialized_loss_scale = config.get('loss_scale')
    if isinstance(serialized_loss_scale, dict):
      config = dict(config)
      config['loss_scale'] = keras_loss_scale_module.deserialize(
          serialized_loss_scale, custom_objects=custom_objects)
    return cls(**config)

403 

404 

# Process-wide dtype policy installed by set_global_policy(). A value of None
# means no policy was set explicitly: global_policy() then builds one from
# backend.floatx() when V2 dtype behavior is enabled, and returns the "_infer"
# policy otherwise.
# TODO(reedwm): Make this thread local?
_global_policy = None

410 

411 

@keras_export('keras.mixed_precision.global_policy',
              'keras.mixed_precision.experimental.global_policy', v1=[])
def global_policy():
  """Returns the global dtype policy.

  Layers use the global policy as their default `tf.keras.mixed_precision.Policy`
  when no policy is passed to the layer constructor. If no policy has been set
  with `tf.keras.mixed_precision.set_global_policy`, the returned policy is
  constructed from `tf.keras.backend.floatx()` (floatx defaults to float32).

  >>> tf.keras.mixed_precision.global_policy()
  <Policy "float32">
  >>> tf.keras.layers.Dense(10).dtype_policy  # Defaults to the global policy
  <Policy "float32">

  If TensorFlow 2 behavior has been disabled with
  `tf.compat.v1.disable_v2_behavior()`, a special "_infer" policy is returned
  instead. It infers the dtype from the first input the first time the layer
  is called, matching the behavior that existed in TensorFlow 1.

  See `tf.keras.mixed_precision.Policy` for more information on policies.

  Returns:
    The global Policy.
  """
  if _global_policy is not None:
    return _global_policy
  if not base_layer_utils.v2_dtype_behavior_enabled():
    return Policy('_infer')
  return Policy(backend.floatx())

444 

445 

def _check_if_mixed_precision_graph_rewrite_is_enabled(policy):
  """Raises if the mixed precision graph rewrite is already enabled.

  Args:
    policy: The mixed Policy the caller is trying to install globally. Its name
      is included in the error message.

  Raises:
    ValueError: If the graph rewrite has already been enabled.
  """
  if not mixed_precision_global_state.is_mixed_precision_graph_rewrite_enabled():
    return
  message_template = (
      'The global dtype policy cannot be set to "{policy.name}", because the '
      'mixed precision graph rewrite has already been enabled.\n'
      'At most, one of the following can be called:\n\n'
      ' 1. tf.compat.v1.train.enable_mixed_precision_graph_rewrite() '
      '(You called this first)\n'
      ' 2. tf.keras.mixed_precision.experimental.set_global_policy() with a '
      'mixed precision policy (You called this second)\n\n'
      'You called both functions, which is an error, because both functions '
      'enable you to use mixed precision. If in doubt which function to use, '
      'use the second, as it supports Eager execution and is more '
      'customizable.')
  raise ValueError(message_template.format(policy=policy))

460 

461 

@keras_export('keras.mixed_precision.set_global_policy',
              'keras.mixed_precision.experimental.set_global_policy', v1=[])
def set_global_policy(policy):
  """Sets the global dtype policy.

  Layers use the global policy as their default `tf.keras.mixed_precision.Policy`
  when no policy is passed to the layer constructor.

  >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
  >>> tf.keras.mixed_precision.global_policy()
  <Policy "mixed_float16">
  >>> tf.keras.layers.Dense(10).dtype_policy
  <Policy "mixed_float16">
  >>> # Global policy is not used if a policy is directly passed to constructor
  >>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
  <Policy "float64">
  >>> tf.keras.mixed_precision.set_global_policy('float32')

  If no global policy is set, layers default to a Policy constructed from
  `tf.keras.backend.floatx()`.

  To use mixed precision, set the global policy to `'mixed_float16'` or
  `'mixed_bfloat16'`, so that every layer uses a 16-bit compute dtype and a
  float32 variable dtype by default.

  Only floating-point policies, such as `'float32'` and `'mixed_float16'`, can
  be set as the global policy. Non-floating-point policies such as `'int32'`
  and `'complex64'` are rejected because most layers do not support them.

  See `tf.keras.mixed_precision.Policy` for more information.

  Args:
    policy: A Policy, or a string that will be converted to a Policy. Can also
      be None, in which case the global policy will be constructed from
      `tf.keras.backend.floatx()`.

  Raises:
    ValueError: If V2 dtype behavior is disabled, if a mixed policy is set
      after the mixed precision graph rewrite was enabled, or if the policy's
      compute dtype is not floating point.
  """
  global _global_policy
  if not base_layer_utils.v2_dtype_behavior_enabled():
    raise ValueError('The global policy can only be set in TensorFlow 2 or if '
                     'V2 dtype behavior has been set. To enable V2 dtype '
                     'behavior, call '
                     '"tf.compat.v1.keras.layers.enable_v2_dtype_behavior()"')
  if policy is None:
    is_mixed_policy = False
  else:
    if not isinstance(policy, Policy):
      policy = Policy(policy)
    # A policy is "mixed" when it computes and stores variables in different
    # dtypes.
    is_mixed_policy = policy.compute_dtype != policy.variable_dtype
    if is_mixed_policy:
      _check_if_mixed_precision_graph_rewrite_is_enabled(policy)
    compute_dtype = policy.compute_dtype
    # The "_infer" policy has a None compute dtype and is allowed through.
    if (compute_dtype is not None and
        not dtypes.as_dtype(compute_dtype).is_floating):
      raise ValueError('set_global_policy can only be used to set the global '
                       'policy to floating-point policies, such as "float32" and '
                       '"mixed_float16", but got policy: %s'
                       % (policy.name,))
  _global_policy = policy
  mixed_precision_global_state.set_using_mixed_precision_policy(is_mixed_policy)

519 

520 

# TODO(reedwm): Make this thread local
@contextlib.contextmanager
def policy_scope(policy):
  """Temporarily installs `policy` as the global Policy.

  The previously installed global policy (possibly None) is restored on exit,
  including when the body raises.

  Args:
    policy: A Policy, or a string that will be converted to a Policy.

  Yields:
    Nothing.
  """
  previous_policy = _global_policy
  try:
    set_global_policy(policy)
    yield
  finally:
    set_global_policy(previous_policy)

538 

539 

def _is_convertible_to_dtype(dtype):
  """Returns True if `dtypes.as_dtype` accepts `dtype` without a TypeError."""
  try:
    dtypes.as_dtype(dtype)
  except TypeError:
    return False
  else:
    return True

546 

547 

def _policy_equivalent_to_dtype(policy):
  """Returns True if the Policy can be represented by a single dtype.

  That is the case when the policy is exactly of type Policy (not a subclass
  such as PolicyV1), its config holds nothing but its name, and that name is
  either "_infer" or convertible to a dtype.

  Args:
    policy: A Policy.

  Returns:
    True, if the policy is equivalent to a single dtype.
  """
  # An exact type check on purpose: a subclass of Policy is never equivalent to
  # a dtype.
  if type(policy) is not Policy:  # pylint: disable=unidiomatic-typecheck
    return False
  if list(policy.get_config()) != ['name']:
    return False
  return policy.name == '_infer' or _is_convertible_to_dtype(policy.name)

568 

569 

def serialize(policy):
  """Serializes a Policy.

  A policy equivalent to a plain dtype is returned as its dtype name, or None
  for "_infer". This keeps the output compatible with older versions of Keras.
  Any other policy goes through `generic_utils.serialize_keras_object`.
  """
  if not _policy_equivalent_to_dtype(policy):
    return generic_utils.serialize_keras_object(policy)
  if policy.name == '_infer':
    return None
  return policy.name

577 

578 

def deserialize(config, custom_objects=None):
  """Inverse of `serialize`: builds a Policy from a dtype name, None or config.

  Args:
    config: None (the "_infer" policy), a dtype name string, or a serialized
      Keras object config.
    custom_objects: Optional mapping forwarded to
      `generic_utils.deserialize_keras_object`.

  Returns:
    A Policy.
  """
  if config is None:
    return Policy('_infer')
  if isinstance(config, str) and _is_convertible_to_dtype(config):
    return Policy(config)
  # 'PolicyV1' is intentionally mapped to Policy. Policy.from_config drops any
  # legacy 'loss_scale' entry, so experimental policies come back as plain ones.
  return generic_utils.deserialize_keras_object(
      config,
      module_objects={'Policy': Policy, 'PolicyV1': Policy},
      custom_objects=custom_objects,
      printable_module_name='dtype policy')