Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/mixed_precision/loss_scale.py: 30%
20 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Contains keras-specific LossScale functionality.

These functions cannot be in the non-keras loss_scale.py file since they depend
on keras, and files outside of keras should not depend on files inside keras.
"""
21from tensorflow.python.keras.utils import generic_utils
22from tensorflow.python.training.experimental import loss_scale as loss_scale_module
def serialize(loss_scale):
  """Turn a loss scale object into its Keras-serializable form.

  Args:
    loss_scale: The loss scale object to serialize.

  Returns:
    The value produced by `generic_utils.serialize_keras_object` for
    `loss_scale`.
  """
  serialized_form = generic_utils.serialize_keras_object(loss_scale)
  return serialized_form
def deserialize(config, custom_objects=None):
  """Rebuild a loss scale object from its serialized Keras config.

  Args:
    config: The serialized loss scale config to rebuild.
    custom_objects: Optional mapping of names to user-defined classes. It is
      forwarded unchanged to `generic_utils.deserialize_keras_object`.

  Returns:
    The object produced by `generic_utils.deserialize_keras_object`.
  """
  # The built-in loss scale classes that can be resolved by name.
  builtin_classes = {
      'FixedLossScale': loss_scale_module.FixedLossScale,
      'DynamicLossScale': loss_scale_module.DynamicLossScale,
  }
  return generic_utils.deserialize_keras_object(
      config,
      module_objects=builtin_classes,
      custom_objects=custom_objects,
      printable_module_name='loss scale')
def get(identifier):
  """Get a loss scale object.

  Args:
    identifier: One of the following:
      - a dict, deserialized via `deserialize` (no custom objects);
      - an int or float (bools included, since `bool` subclasses `int`),
        wrapped in `FixedLossScale`;
      - the string 'dynamic', which yields a default `DynamicLossScale`;
      - an existing `LossScale` instance, returned unchanged;
      - None, returned unchanged.

  Returns:
    A loss scale object, or None if `identifier` is None.

  Raises:
    ValueError: If `identifier` is not one of the forms above.
  """
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, (int, float)):
    return loss_scale_module.FixedLossScale(identifier)
  if identifier == 'dynamic':
    return loss_scale_module.DynamicLossScale()
  if isinstance(identifier, loss_scale_module.LossScale):
    return identifier
  if identifier is None:
    return None
  # Wrap in a 1-tuple. A bare tuple identifier would otherwise be unpacked as
  # the format arguments and raise TypeError instead of this ValueError.
  raise ValueError('Could not interpret loss scale identifier: %s' %
                   (identifier,))