1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19import logging
20from abc import ABC, abstractmethod
21from dataclasses import dataclass
22from datetime import datetime, timedelta
23from typing import TYPE_CHECKING, Any, cast
24
25import uuid6
26from sqlalchemy import Boolean, ForeignKey, Index, Integer, and_, func, inspect, select, text
27from sqlalchemy.exc import SQLAlchemyError
28from sqlalchemy.orm import Mapped, relationship
29from sqlalchemy_utils import UUIDType
30
31from airflow._shared.timezones import timezone
32from airflow.models.base import Base
33from airflow.models.callback import Callback, CallbackDefinitionProtocol
34from airflow.observability.stats import Stats
35from airflow.utils.log.logging_mixin import LoggingMixin
36from airflow.utils.session import provide_session
37from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column
38
39if TYPE_CHECKING:
40 from sqlalchemy.orm import Session
41 from sqlalchemy.sql import ColumnElement
42
43
44logger = logging.getLogger(__name__)
45
46CALLBACK_METRICS_PREFIX = "deadline_alerts"
47
48
class classproperty:
    """
    Descriptor that exposes a method taking a single ``cls`` argument as a read-only class attribute.

    Mypy won't let us use both @property and @classmethod together, this is a workaround
    to combine the two.

    Usage:

        class Circle:
            def __init__(self, radius):
                self.radius = radius

            @classproperty
            def pi(cls):
                return 3.14159

        print(Circle.pi)  # Outputs: 3.14159
    """

    def __init__(self, method):
        self.method = method

    def __get__(self, obj, owner=None):
        # Descriptor protocol: ignore the instance and invoke the wrapped
        # method with the owning class, mimicking a class-level property.
        return self.method(owner)
74
75
class Deadline(Base):
    """A Deadline is a 'need-by' date which triggers a callback if the provided time has passed."""

    __tablename__ = "deadline"

    # Primary key; uuid7 ids are time-ordered, so newer deadlines sort after older ones.
    id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)

    # If the Deadline Alert is for a DAG, store the DAG run ID from the dag_run.
    # ON DELETE CASCADE: deleting the DagRun removes its deadlines.
    dagrun_id: Mapped[int | None] = mapped_column(
        Integer, ForeignKey("dag_run.id", ondelete="CASCADE"), nullable=True
    )
    dagrun = relationship("DagRun", back_populates="deadlines")

    # The time after which the Deadline has passed and the callback should be triggered.
    deadline_time: Mapped[datetime] = mapped_column(UtcDateTime, nullable=False)

    # Whether the deadline has been marked as missed by the scheduler
    missed: Mapped[bool] = mapped_column(Boolean, nullable=False)

    # Callback that will run when this deadline is missed
    callback_id: Mapped[str] = mapped_column(
        UUIDType(binary=False), ForeignKey("callback.id", ondelete="CASCADE"), nullable=False
    )
    # single_parent + delete-orphan: each Callback row is owned by exactly one
    # Deadline and is removed when its Deadline is deleted.
    callback = relationship("Callback", uselist=False, cascade="all, delete-orphan", single_parent=True)

    # Composite index over (missed, deadline_time) to support scanning for
    # not-yet-missed deadlines whose time has passed.
    __table_args__ = (Index("deadline_missed_deadline_time_idx", missed, deadline_time, unique=False),)

    def __init__(
        self,
        deadline_time: datetime,
        callback: CallbackDefinitionProtocol,
        dagrun_id: int,
        dag_id: str | None = None,
    ) -> None:
        """
        Create a Deadline for a DagRun, materializing its SDK callback definition.

        :param deadline_time: Time after which the deadline counts as missed.
        :param callback: SDK callback definition to run if the deadline is missed.
        :param dagrun_id: ID of the DagRun this deadline is attached to.
        :param dag_id: Optional DAG id, forwarded to the Callback for metrics tagging.
        """
        super().__init__()
        self.deadline_time = deadline_time
        self.dagrun_id = dagrun_id
        # New deadlines always start as not-missed; the scheduler flips this later.
        self.missed = False
        # Convert the SDK-level callback definition into a persisted Callback row.
        self.callback = Callback.create_from_sdk_def(
            callback_def=callback, prefix=CALLBACK_METRICS_PREFIX, dag_id=dag_id
        )

    def __repr__(self) -> str:
        """Return a human-readable summary of the deadline's resource, time, and callback."""

        def _determine_resource() -> tuple[str, str]:
            """Determine the type of resource based on which values are present."""
            if self.dagrun_id:
                # The deadline is for a Dag run:
                return "DagRun", f"Dag: {self.dagrun.dag_id} Run: {self.dagrun_id}"

            return "Unknown", ""

        resource_type, resource_details = _determine_resource()

        return (
            f"[{resource_type} Deadline] {resource_details} needed by "
            f"{self.deadline_time} or run: {self.callback}"
        )

    @classmethod
    def prune_deadlines(cls, *, session: Session, conditions: dict[Mapped, Any]) -> int:
        """
        Remove deadlines from the table which match the provided conditions and return the number removed.

        NOTE: This should only be used to remove deadlines which are associated with
        successful events (DagRuns, etc). If the deadline was missed, it will be
        handled by the scheduler.

        :param conditions: Dictionary of conditions to evaluate against.
        :param session: Session to use.
        """
        from airflow.models import DagRun  # Avoids circular import

        # Assemble the filter conditions.
        filter_conditions = [column == value for column, value in conditions.items()]
        if not filter_conditions:
            return 0

        try:
            # Get deadlines which match the provided conditions and their associated DagRuns.
            deadline_dagrun_pairs = session.execute(
                select(Deadline, DagRun).join(DagRun).where(and_(*filter_conditions))
            ).all()

        except AttributeError as e:
            logger.exception("Error resolving deadlines: %s", e)
            raise

        if not deadline_dagrun_pairs:
            return 0

        deleted_count = 0
        dagruns_to_refresh = set()

        for deadline, dagrun in deadline_dagrun_pairs:
            if dagrun.end_date <= deadline.deadline_time:
                # If the DagRun finished before the Deadline:
                session.delete(deadline)
                Stats.incr(
                    "deadline_alerts.deadline_not_missed",
                    tags={"dag_id": dagrun.dag_id, "dagrun_id": dagrun.run_id},
                )
                deleted_count += 1
                dagruns_to_refresh.add(dagrun)
            # NOTE(review): flush runs once per pair, even when nothing was deleted
            # this iteration — confirm whether a single post-loop flush would suffice.
            session.flush()

        logger.debug("%d deadline records were deleted matching the conditions %s", deleted_count, conditions)

        # Refresh any affected DAG runs.
        for dagrun in dagruns_to_refresh:
            session.refresh(dagrun)

        return deleted_count

    def handle_miss(self, session: Session) -> None:
        """
        Handle a missed deadline by queueing the callback.

        Marks this deadline as missed, injects a minimal DagRun context into the
        callback kwargs, queues the callback, and emits a metrics counter.

        :param session: Session used to re-fetch the DagRun and persist this row.
        """

        def get_simple_context():
            from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
            from airflow.models import DagRun

            # TODO: Use the TaskAPI from within Triggerer to fetch full context instead of sending this context
            # from the scheduler

            # Fetch the DagRun from the database again to avoid errors when self.dagrun's relationship fields
            # are not in the current session.
            dagrun = session.get(DagRun, self.dagrun_id)

            return {
                "dag_run": DAGRunResponse.model_validate(dagrun).model_dump(mode="json"),
                "deadline": {"id": self.id, "deadline_time": self.deadline_time},
            }

        # Merge the context into the stored kwargs; `|` builds a new dict so the
        # mutation is visible to SQLAlchemy change tracking.
        self.callback.data["kwargs"] = self.callback.data["kwargs"] | {"context": get_simple_context()}
        self.missed = True
        self.callback.queue()
        session.add(self)
        Stats.incr(
            "deadline_alerts.deadline_missed",
            tags={"dag_id": self.dagrun.dag_id, "dagrun_id": self.dagrun.run_id},
        )
216
217
class ReferenceModels:
    """
    Store the implementations for the different Deadline References.

    After adding the implementations here, all DeadlineReferences should be added
    to the user interface in airflow.sdk.definitions.deadline.DeadlineReference
    """

    # Dictionary key under which serialize_reference() records the concrete class name.
    REFERENCE_TYPE_FIELD = "reference_type"

    @classmethod
    def get_reference_class(cls, reference_name: str) -> type[BaseDeadlineReference]:
        """
        Get a reference class by its name.

        :param reference_name: The name of the reference class to find
        :raises ValueError: If no nested BaseDeadlineReference subclass matches the name.
        """
        try:
            # Scan the classes nested on ReferenceModels for a matching subclass.
            return next(
                ref_class
                for name, ref_class in vars(cls).items()
                if isinstance(ref_class, type)
                and issubclass(ref_class, cls.BaseDeadlineReference)
                and ref_class.__name__ == reference_name
            )
        except StopIteration:
            raise ValueError(f"No reference class found with name: {reference_name}")

    class BaseDeadlineReference(LoggingMixin, ABC):
        """Base class for all Deadline implementations."""

        # Set of required kwargs - subclasses should override this.
        required_kwargs: set[str] = set()

        @classproperty
        def reference_name(cls: Any) -> str:
            # The class name doubles as the serialized identifier (see serialize_reference).
            return cls.__name__

        def evaluate_with(self, *, session: Session, interval: timedelta, **kwargs: Any) -> datetime | None:
            """
            Validate the provided kwargs and evaluate this deadline with the given conditions.

            :param session: Database session passed through to the concrete evaluation.
            :param interval: Offset added to the evaluated base time.
            :param kwargs: Candidate parameters; only those listed in ``required_kwargs``
                are forwarded, extras are logged at DEBUG and dropped.
            :raises ValueError: If any required parameter is missing.
            """
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in self.required_kwargs}

            if missing_kwargs := self.required_kwargs - filtered_kwargs.keys():
                raise ValueError(
                    f"{self.__class__.__name__} is missing required parameters: {', '.join(missing_kwargs)}"
                )

            if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
                self.log.debug("Ignoring unexpected parameters: %s", ", ".join(extra_kwargs))

            # A None base time means the deadline cannot be evaluated yet
            # (e.g. AverageRuntimeDeadline with too few completed runs).
            base_time = self._evaluate_with(session=session, **filtered_kwargs)
            return base_time + interval if base_time is not None else None

        @abstractmethod
        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            """Must be implemented by subclasses to perform the actual evaluation."""
            raise NotImplementedError

        @classmethod
        def deserialize_reference(cls, reference_data: dict):
            """
            Deserialize a reference type from its dictionary representation.

            While the base implementation doesn't use reference_data, this parameter is required
            for subclasses that need additional data for initialization (like FixedDatetimeDeadline
            which needs a datetime value).

            :param reference_data: Dictionary containing serialized reference data.
                Always includes a 'reference_type' field, and may include additional
                fields needed by specific reference implementations.
            """
            return cls()

        def serialize_reference(self) -> dict:
            """
            Serialize this reference type into a dictionary representation.

            This method assumes that the reference doesn't require any additional data.
            Override this method in subclasses if additional data is needed for serialization.
            """
            return {ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name}

    @dataclass
    class FixedDatetimeDeadline(BaseDeadlineReference):
        """A deadline that always returns a fixed datetime."""

        # The fixed point in time this reference always evaluates to.
        _datetime: datetime

        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            """Return the configured datetime; session and kwargs are unused."""
            return self._datetime

        def serialize_reference(self) -> dict:
            """Serialize with the datetime as a POSIX timestamp so it round-trips through JSON."""
            return {
                ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
                "datetime": self._datetime.timestamp(),
            }

        @classmethod
        def deserialize_reference(cls, reference_data: dict):
            """Rebuild the fixed datetime from its serialized POSIX timestamp."""
            return cls(_datetime=timezone.from_timestamp(reference_data["datetime"]))

    class DagRunLogicalDateDeadline(BaseDeadlineReference):
        """A deadline that returns a DagRun's logical date."""

        required_kwargs = {"dag_id", "run_id"}

        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            """Fetch the logical_date of the DagRun matching dag_id/run_id."""
            from airflow.models import DagRun

            return _fetch_from_db(DagRun.logical_date, session=session, **kwargs)

    class DagRunQueuedAtDeadline(BaseDeadlineReference):
        """A deadline that returns when a DagRun was queued."""

        required_kwargs = {"dag_id", "run_id"}

        # NOTE(review): this subclass wraps _evaluate_with in @provide_session while
        # DagRunLogicalDateDeadline does not — confirm whether the decorator is needed
        # here, since callers pass the session explicitly via evaluate_with.
        @provide_session
        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            """Fetch the queued_at time of the DagRun matching dag_id/run_id."""
            from airflow.models import DagRun

            return _fetch_from_db(DagRun.queued_at, session=session, **kwargs)

    @dataclass
    class AverageRuntimeDeadline(BaseDeadlineReference):
        """A deadline that calculates the average runtime from past DAG runs."""

        # Fallback max_runs used when deserializing payloads that omit it.
        DEFAULT_LIMIT = 10
        # Maximum number of most-recent completed runs to average over.
        max_runs: int
        # Minimum completed runs required to produce a deadline; defaults to max_runs.
        min_runs: int | None = None
        required_kwargs = {"dag_id"}

        def __post_init__(self):
            # Default min_runs to max_runs, then validate the effective value.
            if self.min_runs is None:
                self.min_runs = self.max_runs
            if self.min_runs < 1:
                raise ValueError("min_runs must be at least 1")

        @provide_session
        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            """
            Return now + the average duration of recent completed runs of the DAG.

            Returns None (no deadline) when fewer than min_runs completed runs exist.
            """
            from airflow.models import DagRun

            dag_id = kwargs["dag_id"]

            # Get database dialect to use appropriate time difference calculation
            dialect = get_dialect_name(session)

            # Create database-specific expression for calculating duration in seconds
            duration_expr: ColumnElement[Any]
            if dialect == "postgresql":
                duration_expr = func.extract("epoch", DagRun.end_date - DagRun.start_date)
            elif dialect == "mysql":
                # Use TIMESTAMPDIFF to get exact seconds like PostgreSQL EXTRACT(epoch FROM ...)
                duration_expr = func.timestampdiff(text("SECOND"), DagRun.start_date, DagRun.end_date)
            elif dialect == "sqlite":
                # julianday() yields fractional days; 86400 converts days to seconds.
                duration_expr = (func.julianday(DagRun.end_date) - func.julianday(DagRun.start_date)) * 86400
            else:
                raise ValueError(f"Unsupported database dialect: {dialect}")

            # Query for completed DAG runs with both start and end dates
            # Order by logical_date descending to get most recent runs first
            query = (
                select(duration_expr)
                .filter(DagRun.dag_id == dag_id, DagRun.start_date.isnot(None), DagRun.end_date.isnot(None))
                .order_by(DagRun.logical_date.desc())
            )

            # Apply max_runs
            query = query.limit(self.max_runs)

            # Get all durations and calculate average
            durations = session.execute(query).scalars().all()

            if len(durations) < cast("int", self.min_runs):
                logger.info(
                    "Only %d completed DAG runs found for dag_id: %s (need %d), skipping deadline creation",
                    len(durations),
                    dag_id,
                    self.min_runs,
                )
                return None
            # Convert to float to handle Decimal types from MySQL while preserving precision
            # Use Decimal arithmetic for higher precision, then convert to float
            from decimal import Decimal

            decimal_durations = [Decimal(str(d)) for d in durations]
            avg_seconds = float(sum(decimal_durations) / len(decimal_durations))
            logger.info(
                "Average runtime for dag_id %s (from %d runs): %.2f seconds",
                dag_id,
                len(durations),
                avg_seconds,
            )
            return timezone.utcnow() + timedelta(seconds=avg_seconds)

        def serialize_reference(self) -> dict:
            """Serialize the run-count configuration alongside the reference type."""
            return {
                ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
                "max_runs": self.max_runs,
                "min_runs": self.min_runs,
            }

        @classmethod
        def deserialize_reference(cls, reference_data: dict):
            """Rebuild from serialized data, defaulting missing counts (see DEFAULT_LIMIT)."""
            max_runs = reference_data.get("max_runs", cls.DEFAULT_LIMIT)
            min_runs = reference_data.get("min_runs", max_runs)
            # Validate here as well: __post_init__ only checks a None-defaulted min_runs.
            if min_runs < 1:
                raise ValueError("min_runs must be at least 1")
            return cls(
                max_runs=max_runs,
                min_runs=min_runs,
            )
429
430
431DeadlineReferenceType = ReferenceModels.BaseDeadlineReference
432
433
@provide_session
def _fetch_from_db(model_reference: Mapped, session=None, **conditions) -> datetime | None:
    """
    Fetch a datetime value from the database using the provided model reference and filtering conditions.

    For example, to fetch a TaskInstance's start_date:
        _fetch_from_db(
            TaskInstance.start_date, dag_id='example_dag', task_id='example_task', run_id='example_run'
        )

    This generates SQL equivalent to:
        SELECT start_date
        FROM task_instance
        WHERE dag_id = 'example_dag'
        AND task_id = 'example_task'
        AND run_id = 'example_run'

    :param model_reference: SQLAlchemy Column to select (e.g., DagRun.logical_date, TaskInstance.start_date)
    :param conditions: Filtering conditions applied as equality comparisons in the WHERE clause.
        Multiple conditions are combined with AND.
    :param session: SQLAlchemy session (auto-provided by decorator)
    :raises ValueError: If no matching record is found.
    :raises SQLAlchemyError: If executing the query fails.
    """
    query = select(model_reference)

    if conditions:
        # inspect() is loop-invariant: resolve the column's owning mapped class once,
        # not once per condition. If the column cannot be inspected, conditions are
        # skipped entirely (matching previous behavior for a None inspection result).
        inspected = inspect(model_reference)
        if inspected is not None:
            for key, value in conditions.items():
                query = query.where(getattr(inspected.class_, key) == value)

    # Rendering the query with literal bind values is comparatively expensive, so
    # only do it when the DEBUG message will actually be emitted; the error path
    # below renders it lazily on demand.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "Executing query:\n    %r\nAs SQL:\n    %s",
            query,
            _render_pretty_sql(query),
        )

    try:
        result = session.scalar(query)
    except SQLAlchemyError:
        logger.exception("Database query failed.")
        raise

    if result is None:
        message = f"No matching record found in the database for query:\n    {_render_pretty_sql(query)}"
        logger.error(message)
        raise ValueError(message)

    return result


def _render_pretty_sql(query) -> str:
    """Render *query* with literal bind values, indented for multi-line log output."""
    compiled_query = query.compile(compile_kwargs={"literal_binds": True})
    return "\n    ".join(str(compiled_query).splitlines())