1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19import logging
20from abc import ABC, abstractmethod
21from collections.abc import Sequence
22from dataclasses import dataclass
23from datetime import datetime, timedelta
24from typing import TYPE_CHECKING, Any, cast
25from uuid import UUID
26
27import uuid6
28from sqlalchemy import Boolean, ForeignKey, Index, Integer, Uuid, and_, func, inspect, select, text
29from sqlalchemy.exc import SQLAlchemyError
30from sqlalchemy.orm import Mapped, mapped_column, relationship
31
32from airflow._shared.observability.metrics.stats import Stats
33from airflow._shared.timezones import timezone
34from airflow.models.base import Base
35from airflow.models.callback import Callback, CallbackDefinitionProtocol
36from airflow.utils.log.logging_mixin import LoggingMixin
37from airflow.utils.session import provide_session
38from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name
39
40if TYPE_CHECKING:
41 from sqlalchemy.orm import Session
42 from sqlalchemy.sql import ColumnElement
43
44 from airflow.models.callback import CallbackDefinitionProtocol
45 from airflow.models.deadline_alert import DeadlineAlert
46
47
48logger = logging.getLogger(__name__)
49
50CALLBACK_METRICS_PREFIX = "deadline_alerts"
51
52
53class classproperty:
54 """
55 Decorator that converts a method with a single cls argument into a property.
56
57 Mypy won't let us use both @property and @classmethod together, this is a workaround
58 to combine the two.
59
60 Usage:
61
62 class Circle:
63 def __init__(self, radius):
64 self.radius = radius
65
66 @classproperty
67 def pi(cls):
68 return 3.14159
69
70 print(Circle.pi) # Outputs: 3.14159
71 """
72
73 def __init__(self, method):
74 self.method = method
75
76 def __get__(self, instance, cls=None):
77 return self.method(cls)
78
79
80class Deadline(Base):
81 """A Deadline is a 'need-by' date which triggers a callback if the provided time has passed."""
82
83 __tablename__ = "deadline"
84
85 id: Mapped[UUID] = mapped_column(Uuid(), primary_key=True, default=uuid6.uuid7)
86 created_at: Mapped[datetime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow)
87 last_updated_at: Mapped[datetime] = mapped_column(
88 UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow
89 )
90
91 # If the Deadline Alert is for a DAG, store the DAG run ID from the dag_run.
92 dagrun_id: Mapped[int | None] = mapped_column(
93 Integer, ForeignKey("dag_run.id", ondelete="CASCADE"), nullable=True
94 )
95 dagrun = relationship("DagRun", back_populates="deadlines")
96
97 # The time after which the Deadline has passed and the callback should be triggered.
98 deadline_time: Mapped[datetime] = mapped_column(UtcDateTime, nullable=False)
99
100 # Whether the deadline has been marked as missed by the scheduler
101 missed: Mapped[bool] = mapped_column(Boolean, nullable=False)
102
103 # Callback that will run when this deadline is missed
104 callback_id: Mapped[UUID] = mapped_column(
105 Uuid(), ForeignKey("callback.id", ondelete="CASCADE"), nullable=False
106 )
107 callback = relationship("Callback", uselist=False, cascade="all, delete-orphan", single_parent=True)
108
109 # The DeadlineAlert that generated this deadline
110 deadline_alert_id: Mapped[UUID | None] = mapped_column(
111 Uuid(), ForeignKey("deadline_alert.id", ondelete="SET NULL"), nullable=True
112 )
113 deadline_alert: Mapped[DeadlineAlert | None] = relationship("DeadlineAlert")
114
115 __table_args__ = (Index("deadline_missed_deadline_time_idx", missed, deadline_time, unique=False),)
116
117 def __init__(
118 self,
119 deadline_time: datetime,
120 callback: CallbackDefinitionProtocol,
121 dagrun_id: int,
122 deadline_alert_id: UUID | None,
123 dag_id: str | None = None,
124 ):
125 super().__init__()
126 self.deadline_time = deadline_time
127 self.dagrun_id = dagrun_id
128 self.missed = False
129 self.callback = Callback.create_from_sdk_def(
130 callback_def=callback, prefix=CALLBACK_METRICS_PREFIX, dag_id=dag_id
131 )
132 self.deadline_alert_id = deadline_alert_id
133
134 def __repr__(self):
135 def _determine_resource() -> tuple[str, str]:
136 """Determine the type of resource based on which values are present."""
137 if self.dagrun_id:
138 # The deadline is for a Dag run:
139 return "DagRun", f"Dag: {self.dagrun.dag_id} Run: {self.dagrun_id}"
140
141 return "Unknown", ""
142
143 resource_type, resource_details = _determine_resource()
144
145 return (
146 f"[{resource_type} Deadline] "
147 f"created at {self.created_at}, "
148 f"{resource_details}, "
149 f"needed by {self.deadline_time} "
150 f"or run: {self.callback}"
151 )
152
153 @classmethod
154 def prune_deadlines(cls, *, session: Session, conditions: dict[Mapped, Any]) -> int:
155 """
156 Remove deadlines from the table which match the provided conditions and return the number removed.
157
158 NOTE: This should only be used to remove deadlines which are associated with
159 successful events (DagRuns, etc). If the deadline was missed, it will be
160 handled by the scheduler.
161
162 :param conditions: Dictionary of conditions to evaluate against.
163 :param session: Session to use.
164 """
165 from airflow.models import DagRun # Avoids circular import
166
167 # Assemble the filter conditions.
168 filter_conditions = [column == value for column, value in conditions.items()]
169 if not filter_conditions:
170 return 0
171
172 try:
173 # Get deadlines which match the provided conditions and their associated DagRuns.
174 deadline_dagrun_pairs = session.execute(
175 select(Deadline, DagRun).join(DagRun).where(and_(*filter_conditions))
176 ).all()
177
178 except AttributeError as e:
179 logger.exception("Error resolving deadlines: %s", e)
180 raise
181
182 if not deadline_dagrun_pairs:
183 return 0
184
185 deleted_count = 0
186 dagruns_to_refresh = set()
187
188 for deadline, dagrun in deadline_dagrun_pairs:
189 if dagrun.end_date is not None and dagrun.end_date <= deadline.deadline_time:
190 # If the DagRun finished before the Deadline:
191 session.delete(deadline)
192 Stats.incr(
193 "deadline_alerts.deadline_not_missed",
194 tags={"dag_id": dagrun.dag_id, "dagrun_id": dagrun.run_id},
195 )
196 deleted_count += 1
197 dagruns_to_refresh.add(dagrun)
198 session.flush()
199
200 logger.debug("%d deadline records were deleted matching the conditions %s", deleted_count, conditions)
201
202 # Refresh any affected DAG runs.
203 for dagrun in dagruns_to_refresh:
204 session.refresh(dagrun)
205
206 return deleted_count
207
208 def handle_miss(self, session: Session):
209 """Handle a missed deadline by queueing the callback."""
210
211 def get_simple_context():
212 from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
213 from airflow.models import DagRun
214
215 # TODO: Use the TaskAPI from within Triggerer to fetch full context instead of sending this context
216 # from the scheduler
217
218 # Fetch the DagRun from the database again to avoid errors when self.dagrun's relationship fields
219 # are not in the current session.
220 dagrun = session.get(DagRun, self.dagrun_id)
221
222 return {
223 "dag_run": DAGRunResponse.model_validate(dagrun).model_dump(mode="json"),
224 "deadline": {"id": self.id, "deadline_time": self.deadline_time},
225 }
226
227 self.callback.data["kwargs"] = self.callback.data["kwargs"] | {"context": get_simple_context()}
228 self.missed = True
229 self.callback.queue()
230 session.add(self)
231 Stats.incr(
232 "deadline_alerts.deadline_missed",
233 tags={"dag_id": self.dagrun.dag_id, "dagrun_id": self.dagrun.run_id},
234 )
235
236
237class ReferenceModels:
238 """
239 Store the implementations for the different Deadline References.
240
241 After adding the implementations here, all DeadlineReferences should be added
242 to the user interface in airflow.sdk.definitions.deadline.DeadlineReference
243 """
244
245 REFERENCE_TYPE_FIELD = "reference_type"
246
247 @classmethod
248 def get_reference_class(cls, reference_name: str) -> type[BaseDeadlineReference]:
249 """
250 Get a reference class by its name.
251
252 :param reference_name: The name of the reference class to find
253 """
254 try:
255 return next(
256 ref_class
257 for name, ref_class in vars(cls).items()
258 if isinstance(ref_class, type)
259 and issubclass(ref_class, cls.BaseDeadlineReference)
260 and ref_class.__name__ == reference_name
261 )
262 except StopIteration:
263 raise ValueError(f"No reference class found with name: {reference_name}")
264
265 class BaseDeadlineReference(LoggingMixin, ABC):
266 """Base class for all Deadline implementations."""
267
268 # Set of required kwargs - subclasses should override this.
269 required_kwargs: set[str] = set()
270
271 @classproperty
272 def reference_name(cls: Any) -> str:
273 return cls.__name__
274
275 def evaluate_with(self, *, session: Session, interval: timedelta, **kwargs: Any) -> datetime | None:
276 """Validate the provided kwargs and evaluate this deadline with the given conditions."""
277 filtered_kwargs = {k: v for k, v in kwargs.items() if k in self.required_kwargs}
278
279 if missing_kwargs := self.required_kwargs - filtered_kwargs.keys():
280 raise ValueError(
281 f"{self.__class__.__name__} is missing required parameters: {', '.join(missing_kwargs)}"
282 )
283
284 if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
285 self.log.debug("Ignoring unexpected parameters: %s", ", ".join(extra_kwargs))
286
287 base_time = self._evaluate_with(session=session, **filtered_kwargs)
288 return base_time + interval if base_time is not None else None
289
290 @abstractmethod
291 def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
292 """Must be implemented by subclasses to perform the actual evaluation."""
293 raise NotImplementedError
294
295 @classmethod
296 def deserialize_reference(cls, reference_data: dict):
297 """
298 Deserialize a reference type from its dictionary representation.
299
300 While the base implementation doesn't use reference_data, this parameter is required
301 for subclasses that need additional data for initialization (like FixedDatetimeDeadline
302 which needs a datetime value).
303
304 :param reference_data: Dictionary containing serialized reference data.
305 Always includes a 'reference_type' field, and may include additional
306 fields needed by specific reference implementations.
307 """
308 return cls()
309
310 def serialize_reference(self) -> dict:
311 """
312 Serialize this reference type into a dictionary representation.
313
314 This method assumes that the reference doesn't require any additional data.
315 Override this method in subclasses if additional data is needed for serialization.
316 """
317 return {ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name}
318
319 @dataclass
320 class FixedDatetimeDeadline(BaseDeadlineReference):
321 """A deadline that always returns a fixed datetime."""
322
323 _datetime: datetime
324
325 def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
326 return self._datetime
327
328 def serialize_reference(self) -> dict:
329 return {
330 ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
331 "datetime": self._datetime.timestamp(),
332 }
333
334 @classmethod
335 def deserialize_reference(cls, reference_data: dict):
336 return cls(_datetime=timezone.from_timestamp(reference_data["datetime"]))
337
338 class DagRunLogicalDateDeadline(BaseDeadlineReference):
339 """A deadline that returns a DagRun's logical date."""
340
341 required_kwargs = {"dag_id", "run_id"}
342
343 def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
344 from airflow.models import DagRun
345
346 return _fetch_from_db(DagRun.logical_date, session=session, **kwargs)
347
348 class DagRunQueuedAtDeadline(BaseDeadlineReference):
349 """A deadline that returns when a DagRun was queued."""
350
351 required_kwargs = {"dag_id", "run_id"}
352
353 @provide_session
354 def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
355 from airflow.models import DagRun
356
357 return _fetch_from_db(DagRun.queued_at, session=session, **kwargs)
358
359 @dataclass
360 class AverageRuntimeDeadline(BaseDeadlineReference):
361 """A deadline that calculates the average runtime from past DAG runs."""
362
363 DEFAULT_LIMIT = 10
364 max_runs: int
365 min_runs: int | None = None
366 required_kwargs = {"dag_id"}
367
368 def __post_init__(self):
369 if self.min_runs is None:
370 self.min_runs = self.max_runs
371 if self.min_runs < 1:
372 raise ValueError("min_runs must be at least 1")
373
374 @provide_session
375 def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
376 from airflow.models import DagRun
377
378 dag_id = kwargs["dag_id"]
379
380 # Get database dialect to use appropriate time difference calculation
381 dialect = get_dialect_name(session)
382
383 # Create database-specific expression for calculating duration in seconds
384 duration_expr: ColumnElement[Any]
385 if dialect == "postgresql":
386 duration_expr = func.extract("epoch", DagRun.end_date - DagRun.start_date)
387 elif dialect == "mysql":
388 # Use TIMESTAMPDIFF to get exact seconds like PostgreSQL EXTRACT(epoch FROM ...)
389 duration_expr = func.timestampdiff(text("SECOND"), DagRun.start_date, DagRun.end_date)
390 elif dialect == "sqlite":
391 duration_expr = (func.julianday(DagRun.end_date) - func.julianday(DagRun.start_date)) * 86400
392 else:
393 raise ValueError(f"Unsupported database dialect: {dialect}")
394
395 # Query for completed DAG runs with both start and end dates
396 # Order by logical_date descending to get most recent runs first
397 query = (
398 select(duration_expr)
399 .filter(DagRun.dag_id == dag_id, DagRun.start_date.isnot(None), DagRun.end_date.isnot(None))
400 .order_by(DagRun.logical_date.desc())
401 )
402
403 # Apply max_runs
404 query = query.limit(self.max_runs)
405
406 # Get all durations and calculate average
407 durations: Sequence = session.execute(query).scalars().all()
408
409 if len(durations) < cast("int", self.min_runs):
410 logger.info(
411 "Only %d completed DAG runs found for dag_id: %s (need %d), skipping deadline creation",
412 len(durations),
413 dag_id,
414 self.min_runs,
415 )
416 return None
417 # Convert to float to handle Decimal types from MySQL while preserving precision
418 # Use Decimal arithmetic for higher precision, then convert to float
419 from decimal import Decimal
420
421 decimal_durations = [Decimal(str(d)) for d in durations]
422 avg_seconds = float(sum(decimal_durations) / len(decimal_durations))
423 logger.info(
424 "Average runtime for dag_id %s (from %d runs): %.2f seconds",
425 dag_id,
426 len(durations),
427 avg_seconds,
428 )
429 return timezone.utcnow() + timedelta(seconds=avg_seconds)
430
431 def serialize_reference(self) -> dict:
432 return {
433 ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
434 "max_runs": self.max_runs,
435 "min_runs": self.min_runs,
436 }
437
438 @classmethod
439 def deserialize_reference(cls, reference_data: dict):
440 max_runs = reference_data.get("max_runs", cls.DEFAULT_LIMIT)
441 min_runs = reference_data.get("min_runs", max_runs)
442 if min_runs < 1:
443 raise ValueError("min_runs must be at least 1")
444 return cls(
445 max_runs=max_runs,
446 min_runs=min_runs,
447 )
448
449
450DeadlineReferenceType = ReferenceModels.BaseDeadlineReference
451
452
453@provide_session
454def _fetch_from_db(model_reference: Mapped, session=None, **conditions) -> datetime | None:
455 """
456 Fetch a datetime value from the database using the provided model reference and filtering conditions.
457
458 For example, to fetch a TaskInstance's start_date:
459 _fetch_from_db(
460 TaskInstance.start_date, dag_id='example_dag', task_id='example_task', run_id='example_run'
461 )
462
463 This generates SQL equivalent to:
464 SELECT start_date
465 FROM task_instance
466 WHERE dag_id = 'example_dag'
467 AND task_id = 'example_task'
468 AND run_id = 'example_run'
469
470 :param model_reference: SQLAlchemy Column to select (e.g., DagRun.logical_date, TaskInstance.start_date)
471 :param conditions: Filtering conditions applied as equality comparisons in the WHERE clause.
472 Multiple conditions are combined with AND.
473 :param session: SQLAlchemy session (auto-provided by decorator)
474 """
475 query = select(model_reference)
476
477 for key, value in conditions.items():
478 inspected = inspect(model_reference)
479 if inspected is not None:
480 query = query.where(getattr(inspected.class_, key) == value)
481
482 compiled_query = query.compile(compile_kwargs={"literal_binds": True})
483 pretty_query = "\n ".join(str(compiled_query).splitlines())
484 logger.debug(
485 "Executing query:\n %r\nAs SQL:\n %s",
486 query,
487 pretty_query,
488 )
489
490 try:
491 result = session.scalar(query)
492 except SQLAlchemyError:
493 logger.exception("Database query failed.")
494 raise
495
496 if result is None:
497 message = f"No matching record found in the database for query:\n {pretty_query}"
498 logger.error(message)
499 raise ValueError(message)
500
501 return result