Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/mixed_precision/loss_scale.py: 30%

20 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains keras-specific LossScale functionality. 

16 

17These functions cannot be in the non-keras loss_scale.py file since they depend 

18on keras, and files outside of keras should not depend on files inside keras. 

19""" 

20 

21from tensorflow.python.keras.utils import generic_utils 

22from tensorflow.python.training.experimental import loss_scale as loss_scale_module 

23 

24 

def serialize(loss_scale):
  """Converts a loss scale object into a Keras-serializable config.

  Args:
    loss_scale: The loss scale object to serialize.

  Returns:
    The value produced by `generic_utils.serialize_keras_object` for
    `loss_scale`.
  """
  serialized_config = generic_utils.serialize_keras_object(loss_scale)
  return serialized_config

27 

28 

def deserialize(config, custom_objects=None):
  """Rebuilds a loss scale object from a serialized Keras config.

  The built-in `FixedLossScale` and `DynamicLossScale` classes are resolvable
  by class name; anything else must be supplied through `custom_objects`.

  Args:
    config: The serialized config to deserialize.
    custom_objects: Optional mapping passed straight through to
      `generic_utils.deserialize_keras_object`.

  Returns:
    The object constructed by `generic_utils.deserialize_keras_object`.
  """
  # Class-name -> class lookup table for the built-in loss scale types.
  builtin_loss_scales = dict(
      FixedLossScale=loss_scale_module.FixedLossScale,
      DynamicLossScale=loss_scale_module.DynamicLossScale,
  )
  return generic_utils.deserialize_keras_object(
      config,
      module_objects=builtin_loss_scales,
      custom_objects=custom_objects,
      printable_module_name='loss scale')

41 

42 

def get(identifier):
  """Get a loss scale object.

  Args:
    identifier: One of:
      * a dict, which is deserialized with `deserialize`;
      * an int or float, which is wrapped in
        `loss_scale_module.FixedLossScale`;
      * the string 'dynamic', which yields a default
        `loss_scale_module.DynamicLossScale()`;
      * an existing `loss_scale_module.LossScale`, returned unchanged;
      * None, returned unchanged.

  Returns:
    A loss scale object, or None if `identifier` is None.

  Raises:
    ValueError: If `identifier` is not one of the accepted forms above.
  """
  if isinstance(identifier, dict):
    return deserialize(identifier)

  # NOTE(review): bool is a subclass of int, so True/False are accepted here
  # as fixed scales of 1/0 -- presumably unintended, kept for compatibility.
  if isinstance(identifier, (int, float)):
    return loss_scale_module.FixedLossScale(identifier)
  # Only compare strings against 'dynamic'. Objects whose __eq__ returns a
  # non-bool (e.g. array-likes) must not be evaluated in a boolean context.
  if isinstance(identifier, str) and identifier == 'dynamic':
    return loss_scale_module.DynamicLossScale()
  if isinstance(identifier, loss_scale_module.LossScale):
    return identifier
  if identifier is None:
    return None
  # Wrap in a 1-tuple. Formatting with a bare tuple `identifier` would raise
  # TypeError from the % operator instead of the intended ValueError.
  raise ValueError('Could not interpret loss scale identifier: %s' %
                   (identifier,))