Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/checkpoint_context.py: 61%

41 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Context for saving checkpoint.""" 

16 

17import contextlib 

18import threading 

19 

20 

class PreemptionSaveContext(threading.local):
  """Per-thread flag recording whether a preemption-triggered save is running.

  Because this subclasses `threading.local`, every thread that touches an
  instance gets its own copy of the flag, which starts out False.
  """

  def __init__(self):
    super().__init__()
    self._set_active(False)

  def _set_active(self, active):
    # Single write path for the flag, shared by __init__, enter and exit.
    self._in_preemption_save_context = active

  def enter_preemption_save_context(self):
    """Sets the calling thread's flag to True."""
    self._set_active(True)

  def exit_preemption_save_context(self):
    """Sets the calling thread's flag back to False."""
    self._set_active(False)

  def in_preemption_save_context(self):
    """Returns the calling thread's current flag value."""
    return self._in_preemption_save_context


# Module-wide holder read and written by `preemption_save_context()` and
# `in_preemption_save_context()`.
_preemption_save_context = PreemptionSaveContext()

39 

40 

@contextlib.contextmanager
def preemption_save_context():
  """Marks the calling thread as performing a preemption-triggered save.

  While the `with` body runs, the module-level preemption flag for this thread
  is True. On exit, whether normal or via an exception, the flag is restored to
  the value it had on entry. Nested uses therefore do not clear the flag while
  an enclosing context is still active.

  Yields:
    None.
  """
  was_active = _preemption_save_context.in_preemption_save_context()
  _preemption_save_context.enter_preemption_save_context()
  try:
    yield
  finally:
    # Only the context that actually turned the flag on turns it off; an inner
    # (nested) exit must leave the enclosing context's state intact.
    if not was_active:
      _preemption_save_context.exit_preemption_save_context()

48 

49 

def in_preemption_save_context():
  """Returns the calling thread's preemption-save flag."""
  holder = _preemption_save_context
  return holder.in_preemption_save_context()

52 

53 

class AsyncMetricsContext(threading.local):
  """Per-thread flag used to control metrics recording for async checkpoints.

  Because this subclasses `threading.local`, every thread that touches an
  instance gets its own copy of the flag, which starts out False.
  """

  def __init__(self):
    super().__init__()
    self._set_active(False)

  def _set_active(self, active):
    # Single write path for the flag, shared by __init__, enter and exit.
    self._in_async_metrics_context = active

  def enter_async_metrics_context(self):
    """Sets the calling thread's flag to True."""
    self._set_active(True)

  def exit_async_metrics_context(self):
    """Sets the calling thread's flag back to False."""
    self._set_active(False)

  def in_async_metrics_context(self):
    """Returns the calling thread's current flag value."""
    return self._in_async_metrics_context


# Module-wide holder read and written by `async_metrics_context()` and
# `in_async_metrics_context()`.
_async_metrics_context = AsyncMetricsContext()

73 

74 

@contextlib.contextmanager
def async_metrics_context():
  """Marks the calling thread as running inside the async-metrics context.

  While the `with` body runs, the module-level async-metrics flag for this
  thread is True. On exit, whether normal or via an exception, the flag is
  restored to the value it had on entry. Nested uses therefore do not clear the
  flag while an enclosing context is still active.

  Yields:
    None.
  """
  was_active = _async_metrics_context.in_async_metrics_context()
  _async_metrics_context.enter_async_metrics_context()
  try:
    yield
  finally:
    # Only the context that actually turned the flag on turns it off; an inner
    # (nested) exit must leave the enclosing context's state intact.
    if not was_active:
      _async_metrics_context.exit_async_metrics_context()

82 

83 

def in_async_metrics_context():
  """Returns the calling thread's async-metrics flag."""
  holder = _async_metrics_context
  return holder.in_async_metrics_context()