Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/timed_threads.py: 41%
39 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Thread utilities."""
17import abc
18import threading
20from absl import logging
21from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.utils.TimedThread", v1=[])
class TimedThread:
    """Time-based interval Threads.

    Runs a timed thread every x seconds. It can be used to run a threaded
    function alongside model training or any other snippet of code.

    The first call to `on_interval` happens as soon as the thread starts;
    subsequent calls are separated by a wait of `interval` seconds. An
    exception escaping `on_interval` terminates the thread, so wrap the body
    in `try`/`except` if the thread must keep running.

    Args:
        interval: The interval, in seconds, to wait between calls to the
            `on_interval` function.
        **kwargs: additional args that are passed to `threading.Thread`. By
            default, `Thread` is started as a `daemon` thread unless
            overridden by the user in `kwargs`.

    Examples:

    ```python
    class TimedLogIterations(keras.utils.TimedThread):
        def __init__(self, model, interval):
            self.model = model
            super().__init__(interval)

        def on_interval(self):
            # Logs Optimizer iterations every x seconds
            try:
                opt_iterations = self.model.optimizer.iterations.numpy()
                print(f"Optimizer Iterations: {opt_iterations}")
            except Exception as e:
                print(str(e))  # To prevent thread from getting killed

    # `start` and `stop` the `TimedThread` manually. If the `on_interval` call
    # requires access to `model` or other objects, override `__init__` method.
    # Wrap it in a `try-except` to handle exceptions and `stop` the thread run.
    timed_logs = TimedLogIterations(model=model, interval=5)
    timed_logs.start()
    try:
        model.fit(...)
    finally:
        timed_logs.stop()

    # Alternatively, run the `TimedThread` in a context manager
    with TimedLogIterations(model=model, interval=5):
        model.fit(...)

    # If the timed thread instance needs access to callback events,
    # subclass both `TimedThread` and `Callback`. Note that when calling
    # `super`, they will have to be called for each parent class if both of
    # them have the method that needs to be run. Also, note that `Callback`
    # has access to `model` as an attribute and need not be explicitly
    # provided.
    class LogThreadCallback(
        keras.utils.TimedThread, keras.callbacks.Callback
    ):
        def __init__(self, interval):
            self._epoch = 0
            keras.utils.TimedThread.__init__(self, interval)
            keras.callbacks.Callback.__init__(self)

        def on_interval(self):
            if self._epoch:
                opt_iter = self.model.optimizer.iterations.numpy()
                logging.info(
                    f"Epoch: {self._epoch}, Opt Iteration: {opt_iter}"
                )

        def on_epoch_begin(self, epoch, logs=None):
            self._epoch = epoch

    with LogThreadCallback(interval=5) as thread_callback:
        # It's required to pass `thread_callback` to also `callbacks` arg of
        # `model.fit` to be triggered on callback events.
        model.fit(..., callbacks=[thread_callback])
    ```
    """

    def __init__(self, interval, **kwargs):
        # Seconds to wait between consecutive `on_interval` calls.
        self.interval = interval
        # `daemon` is pulled out of `kwargs` so it can default to True while
        # still being overridable by the caller.
        self.daemon = kwargs.pop("daemon", True)
        # Remaining keyword arguments are forwarded to `threading.Thread`.
        self.thread_kwargs = kwargs
        # Both are created lazily by `start()`.
        self.thread = None
        self.thread_stop_event = None

    def _call_on_interval(self):
        """Thread target: calls `on_interval` until the stop event is set."""
        while not self.thread_stop_event.is_set():
            self.on_interval()
            # `Event.wait` doubles as an interruptible sleep: it returns
            # early as soon as `stop()` sets the event.
            self.thread_stop_event.wait(self.interval)

    def start(self):
        """Creates and starts the thread run.

        If a previously started thread is still alive, a warning is logged
        and no new thread is created.
        """
        if self.thread and self.thread.is_alive():
            logging.warning("Thread is already running.")
            return
        self.thread = threading.Thread(
            target=self._call_on_interval,
            daemon=self.daemon,
            **self.thread_kwargs,
        )
        # A fresh event per run, so a thread started after `stop()` does not
        # see the already-set event of the previous run.
        self.thread_stop_event = threading.Event()
        self.thread.start()

    def stop(self):
        """Stops the thread run.

        Only signals the stop event; it does not join the thread, so the
        thread may still be alive when this method returns. Calling `stop`
        before `start` is a no-op.
        """
        if self.thread_stop_event:
            self.thread_stop_event.set()

    def is_alive(self):
        """Returns True if thread is running. Otherwise returns False."""
        if self.thread:
            return self.thread.is_alive()
        return False

    def __enter__(self):
        # Starts the thread in context manager
        self.start()
        return self

    def __exit__(self, *args, **kwargs):
        # Stops the thread run. Returns None, so exceptions raised inside
        # the `with` body are not suppressed.
        self.stop()

    # NOTE: the class does not use `abc.ABCMeta`, so this decorator does not
    # block instantiation; the `NotImplementedError` below is what signals a
    # missing override at call time.
    @abc.abstractmethod
    def on_interval(self):
        """User-defined behavior that is called in the thread."""
        raise NotImplementedError(
            "Runs every x interval seconds. Needs to be "
            "implemented in subclasses of `TimedThread`"
        )