1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18
19from __future__ import annotations
20
21import os
22import tempfile
23from itertools import islice
24from typing import IO, TYPE_CHECKING
25
26if TYPE_CHECKING:
27 from airflow.typing_compat import Self
28 from airflow.utils.log.file_task_handler import (
29 LogHandlerOutputStream,
30 StructuredLogMessage,
31 StructuredLogStream,
32 )
33
34
class LogStreamAccumulator:
    """
    Memory-efficient log stream accumulator that tracks the total number of lines while preserving the original stream.

    This class captures logs from a stream and stores them in a buffer, flushing them to disk when the buffer
    exceeds a specified threshold. This approach optimizes memory usage while handling large log streams.

    Usage:

    .. code-block:: python

        with LogStreamAccumulator(stream, threshold) as log_accumulator:
            # Get total number of lines captured
            total_lines = log_accumulator.total_lines

            # Retrieve the original stream of logs
            for log in log_accumulator.stream:
                print(log)
    """

    def __init__(
        self,
        stream: LogHandlerOutputStream,
        threshold: int,
    ) -> None:
        """
        Initialize the LogStreamAccumulator.

        Args:
            stream: The input log stream to capture and count.
            threshold: Maximum number of lines to keep in memory before flushing to disk.
                Must be at least 1.

        Raises:
            ValueError: If ``threshold`` is less than 1. A non-positive threshold would
                make the capture loop unable to terminate, since the buffer length can
                never drop below it.
        """
        if threshold < 1:
            raise ValueError(f"threshold must be a positive integer, got {threshold}")
        self._stream = stream
        self._threshold = threshold
        # In-memory tail of the stream; flushed to disk whenever it reaches `threshold` lines.
        self._buffer: list[StructuredLogMessage] = []
        # Number of lines already written to the temporary file.
        self._disk_lines: int = 0
        # Lazily created spill file; stays None while everything fits in the buffer.
        self._tmpfile: IO[str] | None = None

    def _flush_buffer_to_disk(self) -> None:
        """Flush the buffer contents to a temporary file on disk."""
        if self._tmpfile is None:
            # delete=False so the file can be reopened by name later in `stream`;
            # it is removed explicitly in `_cleanup`.
            self._tmpfile = tempfile.NamedTemporaryFile(delete=False, mode="w+", encoding="utf-8")

        self._disk_lines += len(self._buffer)
        # One JSON document per line so the file can be streamed back line by line.
        self._tmpfile.writelines(f"{log.model_dump_json()}\n" for log in self._buffer)
        self._tmpfile.flush()
        self._buffer.clear()

    def _capture(self) -> None:
        """Capture logs from the stream into the buffer, flushing to disk when threshold is reached."""
        while True:
            # `islice` will try to get up to `self._threshold` lines from the stream.
            self._buffer.extend(islice(self._stream, self._threshold))
            # Fewer lines than the threshold means the stream is exhausted; keep the
            # tail in memory and exit the loop.
            if len(self._buffer) < self._threshold:
                break
            self._flush_buffer_to_disk()

    def _cleanup(self) -> None:
        """Clean up the temporary file if it exists."""
        self._buffer.clear()
        if self._tmpfile:
            self._tmpfile.close()
            os.remove(self._tmpfile.name)
            self._tmpfile = None

    @property
    def total_lines(self) -> int:
        """
        Return the total number of lines captured from the stream.

        Returns:
            The sum of lines stored in the buffer and lines written to disk.
        """
        return self._disk_lines + len(self._buffer)

    @property
    def stream(self) -> StructuredLogStream:
        """
        Return the original stream of logs and clean up resources.

        Important: This property automatically cleans up resources after all logs have been yielded.
        Make sure to fully consume the returned generator to ensure proper cleanup.

        Returns:
            A stream of the captured log messages.
        """
        try:
            if not self._tmpfile:
                # if no temporary file was created, return from the buffer
                yield from self._buffer
            else:
                # avoid circular import
                from airflow.utils.log.file_task_handler import StructuredLogMessage

                with open(self._tmpfile.name, encoding="utf-8") as f:
                    yield from (StructuredLogMessage.model_validate_json(line.strip()) for line in f)
                # yield the remaining buffer
                yield from self._buffer
        finally:
            # Ensure cleanup after yielding
            self._cleanup()

    def __enter__(self) -> Self:
        """
        Context manager entry point that initiates log capture.

        Returns:
            Self instance for use in context manager.
        """
        self._capture()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Context manager exit that doesn't perform resource cleanup.

        Note: Resources are not cleaned up here. Cleanup is deferred until the
        ``stream`` property is consumed in full, ensuring all logs are properly
        yielded before cleanup occurs.
        """