Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/checkpoint_context.py: 61%
41 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Context for saving checkpoint."""
17import contextlib
18import threading
class PreemptionSaveContext(threading.local):
  """Per-thread flag recording whether a preemption checkpoint save is active.

  Because this subclasses `threading.local`, every thread sees its own copy of
  the flag. `__init__` is re-run the first time a new thread touches the
  object, so each thread starts with the flag cleared.
  """

  def __init__(self):
    super(PreemptionSaveContext, self).__init__()
    # Each thread begins outside of any preemption save.
    self._in_preemption_save_context = False

  def in_preemption_save_context(self):
    """Returns the calling thread's current flag value."""
    return self._in_preemption_save_context

  def enter_preemption_save_context(self):
    """Sets the calling thread's flag to True."""
    self._in_preemption_save_context = True

  def exit_preemption_save_context(self):
    """Resets the calling thread's flag to False."""
    self._in_preemption_save_context = False


# Module-wide singleton. Its state is per-thread because the class subclasses
# `threading.local`.
_preemption_save_context = PreemptionSaveContext()
@contextlib.contextmanager
def preemption_save_context():
  """Marks the enclosed code, on the calling thread, as a preemption save.

  While inside the `with` block, `in_preemption_save_context()` returns True
  on the calling thread. The flag is reset on exit even if the body raises.

  Nested uses are supported. Only the outermost context clears the flag, so
  leaving an inner context does not wrongly report that the thread has left
  the enclosing one.

  Yields:
    None.
  """
  # Record whether an enclosing context has already set the flag, so that this
  # (inner) context does not clear it prematurely on exit.
  was_active = _preemption_save_context.in_preemption_save_context()
  _preemption_save_context.enter_preemption_save_context()
  try:
    yield
  finally:
    if not was_active:
      _preemption_save_context.exit_preemption_save_context()
def in_preemption_save_context():
  """Reports whether the calling thread is inside `preemption_save_context()`.

  Returns:
    The calling thread's flag from the module-level `PreemptionSaveContext`
    instance.
  """
  context = _preemption_save_context
  return context.in_preemption_save_context()
class AsyncMetricsContext(threading.local):
  """Per-thread flag used to control metrics recording for async checkpoints.

  Because this subclasses `threading.local`, every thread sees its own copy of
  the flag. `__init__` is re-run the first time a new thread touches the
  object, so each thread starts with the flag cleared.
  """

  def __init__(self):
    super(AsyncMetricsContext, self).__init__()
    # Each thread begins outside of any async-metrics context.
    self._in_async_metrics_context = False

  def in_async_metrics_context(self):
    """Returns the calling thread's current flag value."""
    return self._in_async_metrics_context

  def enter_async_metrics_context(self):
    """Sets the calling thread's flag to True."""
    self._in_async_metrics_context = True

  def exit_async_metrics_context(self):
    """Resets the calling thread's flag to False."""
    self._in_async_metrics_context = False


# Module-wide singleton. Its state is per-thread because the class subclasses
# `threading.local`.
_async_metrics_context = AsyncMetricsContext()
@contextlib.contextmanager
def async_metrics_context():
  """Marks the enclosed code, on the calling thread, as an async-metrics scope.

  While inside the `with` block, `in_async_metrics_context()` returns True on
  the calling thread. The flag is reset on exit even if the body raises.

  Nested uses are supported. Only the outermost context clears the flag, so
  leaving an inner context does not wrongly report that the thread has left
  the enclosing one.

  Yields:
    None.
  """
  # Record whether an enclosing context has already set the flag, so that this
  # (inner) context does not clear it prematurely on exit.
  was_active = _async_metrics_context.in_async_metrics_context()
  _async_metrics_context.enter_async_metrics_context()
  try:
    yield
  finally:
    if not was_active:
      _async_metrics_context.exit_async_metrics_context()
def in_async_metrics_context():
  """Reports whether the calling thread is inside `async_metrics_context()`.

  Returns:
    The calling thread's flag from the module-level `AsyncMetricsContext`
    instance.
  """
  context = _async_metrics_context
  return context.in_async_metrics_context()