1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19import logging
20from datetime import datetime, timedelta
21from typing import TYPE_CHECKING, cast
22
23from airflow.models.deadline import DeadlineReferenceType, ReferenceModels
24from airflow.sdk.definitions.callback import AsyncCallback, Callback
25from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
26from airflow.serialization.serde import deserialize, serialize
27
28if TYPE_CHECKING:
29 from collections.abc import Callable
30 from typing import TypeAlias
31
# Module-level logger; register_custom_reference reports each registration through it.
logger = logging.getLogger(__name__)

# Alias for a tuple of BaseDeadlineReference subclasses, the shape of each
# DeadlineReference.TYPES grouping. TypeAlias is only imported under TYPE_CHECKING;
# this is safe at runtime because `from __future__ import annotations` keeps the
# annotation unevaluated (only the right-hand side is executed).
DeadlineReferenceTypes: TypeAlias = tuple[type[ReferenceModels.BaseDeadlineReference], ...]
35
36
class DeadlineAlertFields:
    """
    Dictionary keys shared by DeadlineAlert's serializer and deserializer.

    Keeping the key names in one place means ``serialize_deadline_alert`` and
    ``deserialize_deadline_alert`` cannot drift apart.
    """

    # Key under which the serialized deadline reference is stored.
    REFERENCE = "reference"
    # Key under which the interval is stored (as a number of seconds).
    INTERVAL = "interval"
    # Key under which the serialized callback is stored.
    CALLBACK = "callback"
48
49
class DeadlineAlert:
    """
    Store Deadline values needed to calculate the need-by timestamp and the callback information.

    :param reference: The deadline reference the need-by timestamp is calculated from
        (for example ``DeadlineReference.DAGRUN_LOGICAL_DATE``).
    :param interval: The interval applied relative to ``reference``.
    :param callback: Callback information for this deadline. Only ``AsyncCallback``
        instances are currently supported.
    :raises ValueError: If ``callback`` is not an ``AsyncCallback`` instance.
    """

    def __init__(
        self,
        reference: DeadlineReferenceType,
        interval: timedelta,
        callback: Callback,
    ):
        self.reference = reference
        self.interval = interval

        # Only asynchronous callbacks are accepted for now; reject anything else up front.
        if not isinstance(callback, AsyncCallback):
            raise ValueError(f"Callbacks of type {type(callback).__name__} are not currently supported")
        self.callback = callback

    def __eq__(self, other: object) -> bool:
        """
        Compare two alerts by reference type, interval and callback.

        NOTE(review): only the reference's *type* is compared, not its parameters, so for
        example two ``FIXED_DATETIME`` references with different datetimes compare equal.
        """
        if not isinstance(other, DeadlineAlert):
            return NotImplemented
        # Exact type match (rather than isinstance) keeps equality symmetric and
        # consistent with __hash__, which hashes the reference's type name.
        return (
            type(self.reference) is type(other.reference)
            and self.interval == other.interval
            and self.callback == other.callback
        )

    def __hash__(self) -> int:
        # Hash the same components __eq__ compares: reference type, interval, callback.
        return hash(
            (
                type(self.reference).__name__,
                self.interval,
                self.callback,
            )
        )

    def serialize_deadline_alert(self) -> dict:
        """
        Return the data in a format that BaseSerialization can handle.

        The interval is stored as a float number of seconds, the reference via its own
        ``serialize_reference()``, and the callback via ``airflow.serialization.serde.serialize``.
        """
        return {
            Encoding.TYPE: DAT.DEADLINE_ALERT,
            Encoding.VAR: {
                DeadlineAlertFields.REFERENCE: self.reference.serialize_reference(),
                DeadlineAlertFields.INTERVAL: self.interval.total_seconds(),
                DeadlineAlertFields.CALLBACK: serialize(self.callback),
            },
        }

    @classmethod
    def deserialize_deadline_alert(cls, encoded_data: dict) -> DeadlineAlert:
        """
        Deserialize a DeadlineAlert from serialized data.

        Accepts either the full envelope produced by ``serialize_deadline_alert`` or just
        its inner ``Encoding.VAR`` mapping.

        :raises KeyError: If a required field is missing from the data.
        """
        data = encoded_data.get(Encoding.VAR, encoded_data)

        # The reference payload carries its own type tag, which selects the class to rebuild.
        reference_data = data[DeadlineAlertFields.REFERENCE]
        reference_type = reference_data[ReferenceModels.REFERENCE_TYPE_FIELD]

        reference_class = ReferenceModels.get_reference_class(reference_type)
        reference = reference_class.deserialize_reference(reference_data)

        return cls(
            reference=reference,
            interval=timedelta(seconds=data[DeadlineAlertFields.INTERVAL]),
            callback=cast("Callback", deserialize(data[DeadlineAlertFields.CALLBACK])),
        )
111
112
class DeadlineReference:
    """
    The public interface class for all DeadlineReference options.

    This class provides a unified interface for working with Deadlines, supporting both
    calculated deadlines (which fetch values from the database) and fixed deadlines
    (which return a predefined datetime).

    ------
    Usage:
    ------

    1. Example deadline references:
        fixed = DeadlineReference.FIXED_DATETIME(datetime(2025, 5, 4))
        logical = DeadlineReference.DAGRUN_LOGICAL_DATE
        queued = DeadlineReference.DAGRUN_QUEUED_AT

    2. Using in a DAG (the callback must be an AsyncCallback instance; DeadlineAlert
       rejects any other callback type with a ValueError):
        DAG(
            dag_id='dag_with_deadline',
            deadline=DeadlineAlert(
                reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
                interval=timedelta(hours=1),
                callback=hello_callback,
            )
        )

    3. Evaluating deadlines will ignore unexpected parameters:
        # For deadlines requiring parameters:
        deadline = DeadlineReference.DAGRUN_LOGICAL_DATE
        deadline.evaluate_with(dag_id=dag.dag_id)

        # For deadlines with no required parameters:
        deadline = DeadlineReference.FIXED_DATETIME(datetime(2025, 5, 4))
        deadline.evaluate_with()

    4. Registering a custom reference: see ``register_custom_reference`` and the
       module-level ``deadline_reference`` decorator.
    """

    class TYPES:
        """Collection of DeadlineReference types for type checking."""

        # Deadlines that should be created when the DagRun is created.
        DAGRUN_CREATED: DeadlineReferenceTypes = (
            ReferenceModels.DagRunLogicalDateDeadline,
            ReferenceModels.FixedDatetimeDeadline,
            ReferenceModels.AverageRuntimeDeadline,
        )

        # Deadlines that should be created when the DagRun is queued.
        DAGRUN_QUEUED: DeadlineReferenceTypes = (ReferenceModels.DagRunQueuedAtDeadline,)

        # All DagRun-related deadline types.
        DAGRUN: DeadlineReferenceTypes = DAGRUN_CREATED + DAGRUN_QUEUED

    # Imported into the class namespace so the factory methods below can reach it as
    # ``cls.ReferenceModels``.
    from airflow.models.deadline import ReferenceModels

    DAGRUN_LOGICAL_DATE: DeadlineReferenceType = ReferenceModels.DagRunLogicalDateDeadline()
    DAGRUN_QUEUED_AT: DeadlineReferenceType = ReferenceModels.DagRunQueuedAtDeadline()

    @classmethod
    def AVERAGE_RUNTIME(cls, max_runs: int = 0, min_runs: int | None = None) -> DeadlineReferenceType:
        """
        Return a ``ReferenceModels.AverageRuntimeDeadline`` built from ``max_runs`` and ``min_runs``.

        :param max_runs: First argument to AverageRuntimeDeadline; ``0`` (the default) is
            replaced by ``AverageRuntimeDeadline.DEFAULT_LIMIT``.
        :param min_runs: Second argument to AverageRuntimeDeadline; defaults to the resolved
            ``max_runs``.
        """
        if max_runs == 0:
            max_runs = cls.ReferenceModels.AverageRuntimeDeadline.DEFAULT_LIMIT
        if min_runs is None:
            min_runs = max_runs
        return cls.ReferenceModels.AverageRuntimeDeadline(max_runs, min_runs)

    @classmethod
    def FIXED_DATETIME(cls, datetime: datetime) -> DeadlineReferenceType:
        """
        Return a ``ReferenceModels.FixedDatetimeDeadline`` for the given datetime.

        The parameter name shadows the ``datetime`` class inside this method; it is kept
        because callers may pass it by keyword.
        """
        return cls.ReferenceModels.FixedDatetimeDeadline(datetime)

    # TODO: Remove this once other deadline types exist.
    # This is a temporary reference type used only in tests to verify that
    # dag.has_dagrun_deadline() returns false if the dag has a non-dagrun deadline type.
    # It should be replaced with a real non-dagrun deadline type when one is available.
    _TEMPORARY_TEST_REFERENCE = type(
        "TemporaryTestDeadlineForTypeChecking",
        (DeadlineReferenceType,),
        {"_evaluate_with": lambda self, **kwargs: datetime.now()},
    )()

    @classmethod
    def register_custom_reference(
        cls,
        reference_class: type[ReferenceModels.BaseDeadlineReference],
        deadline_reference_type: DeadlineReferenceTypes | None = None,
    ) -> type[ReferenceModels.BaseDeadlineReference]:
        """
        Register a custom deadline reference class.

        The class is exposed as ``ReferenceModels.<ClassName>`` and an instance of it as
        ``DeadlineReference.<ClassName>``. The class is appended to the requested
        ``DeadlineReference.TYPES`` grouping, and ``TYPES.DAGRUN`` is then recomputed.

        :param reference_class: The custom reference class inheriting from BaseDeadlineReference
        :param deadline_reference_type: A DeadlineReference.TYPES grouping for when the deadline
            should be evaluated (``DAGRUN_CREATED`` or ``DAGRUN_QUEUED``); defaults to
            DeadlineReference.TYPES.DAGRUN_CREATED. A tuple read from one of those attributes
            earlier is still accepted after other references have been registered.
        :return: ``reference_class`` unchanged, so this can back a class decorator.
        :raises ValueError: If ``reference_class`` does not inherit from BaseDeadlineReference,
            or ``deadline_reference_type`` is not a supported grouping. In either case nothing
            is registered.
        """
        # Default to DAGRUN_CREATED if no deadline_reference_type specified
        if deadline_reference_type is None:
            deadline_reference_type = cls.TYPES.DAGRUN_CREATED

        # Validate the reference class inherits from BaseDeadlineReference
        if not issubclass(reference_class, ReferenceModels.BaseDeadlineReference):
            raise ValueError(f"{reference_class.__name__} must inherit from BaseDeadlineReference")

        def _selects(group: DeadlineReferenceTypes) -> bool:
            # Groupings only ever grow by appending, so a tuple captured from TYPES earlier
            # (e.g. by a reused ``deadline_reference(...)`` decorator) is a non-empty prefix
            # of the current value. A plain identity check would reject such a snapshot as
            # soon as any other reference had been registered. The built-in groupings start
            # with different classes, so a prefix selects at most one of them.
            if not isinstance(deadline_reference_type, tuple) or not deadline_reference_type:
                return False
            return group[: len(deadline_reference_type)] == deadline_reference_type

        # Resolve the target grouping before touching any shared state, so an invalid
        # request cannot leave a half-registered reference behind.
        group_name = next(
            (name for name in ("DAGRUN_CREATED", "DAGRUN_QUEUED") if _selects(getattr(cls.TYPES, name))),
            None,
        )
        if group_name is None:
            raise ValueError(
                f"Invalid deadline reference type {deadline_reference_type}; "
                "must be a valid DeadlineReference.TYPES option."
            )

        # Instantiate first: if the constructor fails, nothing has been registered yet.
        instance = reference_class()

        # Register the new reference with ReferenceModels and DeadlineReference for discoverability
        setattr(ReferenceModels, reference_class.__name__, reference_class)
        setattr(cls, reference_class.__name__, instance)
        logger.info("Registered DeadlineReference %s", reference_class.__name__)

        # Add to the selected grouping; skip if already present so re-registering the
        # same class does not create duplicate entries.
        group = getattr(cls.TYPES, group_name)
        if reference_class not in group:
            setattr(cls.TYPES, group_name, group + (reference_class,))

        # Refresh the combined DAGRUN tuple
        cls.TYPES.DAGRUN = cls.TYPES.DAGRUN_CREATED + cls.TYPES.DAGRUN_QUEUED

        return reference_class
236
237
def deadline_reference(
    deadline_reference_type: DeadlineReferenceTypes | None = None,
) -> Callable[[type[ReferenceModels.BaseDeadlineReference]], type[ReferenceModels.BaseDeadlineReference]]:
    """
    Build a class decorator that registers a custom deadline reference.

    The returned decorator forwards the class to
    ``DeadlineReference.register_custom_reference`` together with
    ``deadline_reference_type`` and hands the same class back, so the decorated
    name still refers to the original class.

    :param deadline_reference_type: Optional DeadlineReference.TYPES grouping; when
        omitted, ``register_custom_reference`` applies its own default.
    :return: A decorator that takes and returns a BaseDeadlineReference subclass.

    Usage:
        @deadline_reference()
        class MyCustomReference(ReferenceModels.BaseDeadlineReference):
            # By default, evaluate_with will be called when a new dagrun is created.
            def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
                # Put your business logic here
                return some_datetime

        @deadline_reference(DeadlineReference.TYPES.DAGRUN_QUEUED)
        class MyQueuedRef(ReferenceModels.BaseDeadlineReference):
            # Optionally, you can specify when you want it calculated by providing a DeadlineReference.TYPES
            def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
                # Put your business logic here
                return some_datetime
    """

    def _register(
        target_cls: type[ReferenceModels.BaseDeadlineReference],
    ) -> type[ReferenceModels.BaseDeadlineReference]:
        # Registration happens purely for its side effects; the class itself is returned untouched.
        DeadlineReference.register_custom_reference(target_cls, deadline_reference_type)
        return target_cls

    return _register