Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/timed_threads.py: 41%

39 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Thread utilities.""" 

16 

17import abc 

18import threading 

19 

20from absl import logging 

21from tensorflow.python.util.tf_export import keras_export 

22 

23 

@keras_export("keras.utils.TimedThread", v1=[])
class TimedThread:
    """Time-based interval Threads.

    Runs a timed thread every x seconds. It can be used to run a threaded
    function alongside model training or any other snippet of code.

    Args:
        interval: The interval, in seconds, to wait between calls to the
            `on_interval` function.
        **kwargs: additional args that are passed to `threading.Thread`. By
            default, `Thread` is started as a `daemon` thread unless
            overridden by the user in `kwargs`.

    Examples:

    ```python
    class TimedLogIterations(keras.utils.TimedThread):
        def __init__(self, model, interval):
            self.model = model
            super().__init__(interval)

        def on_interval(self):
            # Logs Optimizer iterations every x seconds
            try:
                opt_iterations = self.model.optimizer.iterations.numpy()
                print(f"Optimizer Iterations: {opt_iterations}")
            except Exception as e:
                print(str(e))  # To prevent thread from getting killed

    # `start` and `stop` the `TimedThread` manually. If the `on_interval` call
    # requires access to `model` or other objects, override `__init__` method.
    # Wrap it in a `try-except` to handle exceptions and `stop` the thread run.
    timed_logs = TimedLogIterations(model=model, interval=5)
    timed_logs.start()
    try:
        model.fit(...)
    finally:
        timed_logs.stop()

    # Alternatively, run the `TimedThread` in a context manager
    with TimedLogIterations(model=model, interval=5):
        model.fit(...)

    # If the timed thread instance needs access to callback events,
    # subclass both `TimedThread` and `Callback`. Note that when calling
    # `super`, they will have to be called for each parent class if both of
    # them have the method that needs to be run. Also, note that `Callback`
    # has access to `model` as an attribute and need not be explicitly
    # provided.
    class LogThreadCallback(
        keras.utils.TimedThread, keras.callbacks.Callback
    ):
        def __init__(self, interval):
            self._epoch = 0
            keras.utils.TimedThread.__init__(self, interval)
            keras.callbacks.Callback.__init__(self)

        def on_interval(self):
            if self._epoch:
                opt_iter = self.model.optimizer.iterations.numpy()
                logging.info(
                    f"Epoch: {self._epoch}, Opt Iteration: {opt_iter}"
                )

        def on_epoch_begin(self, epoch, logs=None):
            self._epoch = epoch

    with LogThreadCallback(interval=5) as thread_callback:
        # It's required to pass `thread_callback` to also `callbacks` arg of
        # `model.fit` to be triggered on callback events.
        model.fit(..., callbacks=[thread_callback])
    ```
    """

    def __init__(self, interval, **kwargs):
        self.interval = interval
        # `daemon` is pulled out so it can default to True; every other
        # keyword is forwarded unchanged to `threading.Thread` in `start()`.
        self.daemon = kwargs.pop("daemon", True)
        self.thread_kwargs = kwargs
        # Both are created by `start()`; `None` means "never started".
        self.thread = None
        self.thread_stop_event = None

    def _call_on_interval(self, stop_event=None):
        """Thread target: calls `on_interval` until `stop_event` is set.

        Args:
            stop_event: The `threading.Event` that belongs to this particular
                thread run. Falls back to `self.thread_stop_event` when not
                given.
        """
        if stop_event is None:
            stop_event = self.thread_stop_event
        # Runs indefinitely once thread is started. `wait` returns as soon as
        # the event is set, so `stop()` does not have to sit out a full
        # `interval` before the loop condition is re-checked.
        while not stop_event.is_set():
            self.on_interval()
            stop_event.wait(self.interval)

    def start(self):
        """Creates and starts the thread run.

        If a previous run is still alive and has not been asked to stop, a
        warning is logged and nothing else happens. A run that has already
        been `stop()`-ped but is still finishing its current `on_interval`
        call does not block a restart: it keeps its own (already set) stop
        event and exits by itself, while a new thread with a new event is
        started here.
        """
        stop_requested = (
            self.thread_stop_event is not None
            and self.thread_stop_event.is_set()
        )
        if self.thread and self.thread.is_alive() and not stop_requested:
            logging.warning("Thread is already running.")
            return
        stop_event = threading.Event()
        # The event is bound into the target through a closure so each run
        # watches its own event, even after `start()` replaces
        # `self.thread_stop_event` for a later run.
        thread = threading.Thread(
            target=lambda: self._call_on_interval(stop_event),
            daemon=self.daemon,
            **self.thread_kwargs
        )
        # Update state only after `Thread(...)` succeeded, so invalid
        # `thread_kwargs` leave the previous state untouched.
        self.thread_stop_event = stop_event
        self.thread = thread
        thread.start()

    def stop(self):
        """Stops the thread run.

        This only signals the current run's stop event; it does not join the
        thread. An `on_interval` call that is already in progress runs to
        completion before the thread exits.
        """
        if self.thread_stop_event:
            self.thread_stop_event.set()

    def is_alive(self):
        """Returns True if thread is running. Otherwise returns False."""
        if self.thread:
            return self.thread.is_alive()
        return False

    def __enter__(self):
        # Starts the thread in context manager
        self.start()
        return self

    def __exit__(self, *args, **kwargs):
        # Stops the thread run. Returns None, so exceptions raised inside the
        # `with` body still propagate.
        self.stop()

    # NOTE: `TimedThread` does not use `abc.ABCMeta`, so this decorator is not
    # enforced at instantiation time. The `NotImplementedError` below is the
    # effective guard, and it is raised inside the worker thread.
    @abc.abstractmethod
    def on_interval(self):
        """User-defined behavior that is called in the thread."""
        raise NotImplementedError(
            "Runs every x interval seconds. Needs to be "
            "implemented in subclasses of `TimedThread`"
        )

149