Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/experimental/mixed_precision_global_state.py: 68%

19 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains global variables related to mixed precision. 

16 

17This is not part of mixed_precision.py to avoid a circular dependency. 

18mixed_precision.py depends on Session, and Session depends on this file. 

19""" 

20 

21from tensorflow.python.util.tf_export import tf_export 

22 

# Flag recording whether the mixed precision graph rewrite was turned on through
# `enable_mixed_precision_graph_rewrite`. It is used to switch on
# auto_mixed_precision in the ConfigProtos handed to Sessions.
_mixed_precision_graph_rewrite_is_enabled = False


# Flag that becomes True once a Session is created while the mixed precision
# graph rewrite is still off. It is used to emit a warning if mixed precision is
# turned on after such a Session already exists.
_non_mixed_precision_session_created = False

# Flag recording whether the global tf.keras.mixed_precision.Policy is a mixed
# precision policy. It is used to raise an error when a mixed Policy and the
# graph rewrite are combined.
_using_mixed_precision_policy = False

38 

39 

@tf_export('__internal__.train.is_mixed_precision_graph_rewrite_enabled', v1=[])
def is_mixed_precision_graph_rewrite_enabled():
  """Reports whether the mixed precision graph rewrite has been turned on.

  Returns:
    The current value of the module-level
    `_mixed_precision_graph_rewrite_is_enabled` flag (initially `False`).
  """
  rewrite_enabled = _mixed_precision_graph_rewrite_is_enabled
  return rewrite_enabled

43 

44 

def set_mixed_precision_graph_rewrite_enabled(enabled):
  """Records whether the mixed precision graph rewrite is turned on.

  Args:
    enabled: Value stored as-is (no conversion to `bool`) in the module-level
      `_mixed_precision_graph_rewrite_is_enabled` flag.
  """
  # Rebind the module-level flag rather than creating a local variable.
  global _mixed_precision_graph_rewrite_is_enabled
  _mixed_precision_graph_rewrite_is_enabled = enabled

48 

49 

def non_mixed_precision_session_created():
  """Reports whether a Session was created without the graph rewrite enabled.

  Returns:
    The current value of the module-level
    `_non_mixed_precision_session_created` flag.
  """
  session_created = _non_mixed_precision_session_created
  return session_created

52 

53 

def set_non_mixed_precision_session_created(created):
  """Records whether a Session was created without the graph rewrite enabled.

  Args:
    created: Value stored as-is (no conversion to `bool`) in the module-level
      `_non_mixed_precision_session_created` flag.
  """
  # Rebind the module-level flag rather than creating a local variable.
  global _non_mixed_precision_session_created
  _non_mixed_precision_session_created = created

57 

58 

def is_using_mixed_precision_policy():
  """Reports whether the global Keras Policy was flagged as mixed precision.

  Returns:
    The current value of the module-level `_using_mixed_precision_policy`
    flag.
  """
  using_policy = _using_mixed_precision_policy
  return using_policy

61 

62 

@tf_export('__internal__.train.set_using_mixed_precision_policy', v1=[])
def set_using_mixed_precision_policy(is_using):
  """Records whether the global Keras Policy uses mixed precision.

  Args:
    is_using: Value stored as-is (no conversion to `bool`) in the module-level
      `_using_mixed_precision_policy` flag.
  """
  # Rebind the module-level flag rather than creating a local variable.
  global _using_mixed_precision_policy
  _using_mixed_precision_policy = is_using