1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19from datetime import datetime
20from enum import Enum
21from importlib import import_module
22from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
23
24import structlog
25import uuid6
26from sqlalchemy import ForeignKey, Integer, String, Text
27from sqlalchemy.orm import Mapped, relationship
28from sqlalchemy_utils import UUIDType
29
30from airflow._shared.timezones import timezone
31from airflow.models import Base
32from airflow.observability.stats import Stats
33from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, mapped_column
34
35if TYPE_CHECKING:
36 from sqlalchemy.orm import Session
37
38 from airflow.callbacks.callback_requests import CallbackRequest
39 from airflow.triggers.base import TriggerEvent
40
# Module-level structured logger, named after this module.
log = structlog.get_logger(__name__)
42
43
class CallbackState(str, Enum):
    """Enumeration of the lifecycle states a callback can be in."""

    PENDING = "pending"
    QUEUED = "queued"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"

    def __str__(self) -> str:
        """Render the member as its plain string value (e.g. ``"queued"``), not the enum repr."""
        return f"{self.value}"
55
56
# States in which a callback is still being worked on.
ACTIVE_STATES = frozenset({CallbackState.QUEUED, CallbackState.RUNNING})
# States in which a callback has finished and will not change again.
TERMINAL_STATES = frozenset({CallbackState.SUCCESS, CallbackState.FAILED})
59
60
class CallbackType(str, Enum):
    """
    The kind of component a callback belongs to.

    The stored value is used to decide which ``Callback`` subclass to
    instantiate when deserializing rows from the database.
    """

    TRIGGERER = "triggerer"
    EXECUTOR = "executor"
    DAG_PROCESSOR = "dag_processor"
71
72
class CallbackFetchMethod(str, Enum):
    """How the callable behind a callback is located at runtime."""

    # For future use once Dag Processor callbacks (on_success_callback/on_failure_callback) get moved to executors
    DAG_ATTRIBUTE = "dag_attribute"

    # For deadline callbacks since they import callbacks through the import path
    IMPORT_PATH = "import_path"
81
82
class CallbackDefinitionProtocol(Protocol):
    """Structural type implemented by Task SDK callback definitions."""

    def serialize(self) -> dict[str, Any]:
        """Return a dictionary representation of this callback definition."""
89
90
@runtime_checkable
class ImportPathCallbackDefProtocol(CallbackDefinitionProtocol, Protocol):
    """Structural type for callback definitions resolved via an import path."""

    # Dotted import path of the callable, and the keyword arguments to invoke it with.
    path: str
    kwargs: dict
97
98
@runtime_checkable
class ImportPathExecutorCallbackDefProtocol(ImportPathCallbackDefProtocol, Protocol):
    """Import-path callback definition that additionally names the executor it should run on."""

    # Name of the executor to run on, or None for the default executor.
    executor: str | None
104
105
class Callback(Base):
    """
    Base class for callbacks.

    Each row of the ``callback`` table is one callback to execute. Subclasses are
    mapped polymorphically on the ``type`` column and define where the callback runs.
    """

    __tablename__ = "callback"

    # Primary key: a time-ordered UUID (uuid7) generated when the row is created.
    id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)

    # This is used by SQLAlchemy to be able to deserialize DB rows to subclasses
    __mapper_args__ = {
        "polymorphic_identity": "callback",
        "polymorphic_on": "type",
    }
    type: Mapped[str] = mapped_column(String(20), nullable=False)

    # Method used to fetch the callback, of type: CallbackFetchMethod
    fetch_method: Mapped[str] = mapped_column(String(20), nullable=False)

    # Used by subclasses to store information about how to run the callback
    data: Mapped[dict] = mapped_column(ExtendedJSON, nullable=False)

    # State of the Callback of type: CallbackState. Can be null for instances of DagProcessorCallback.
    state: Mapped[str | None] = mapped_column(String(10))

    # Return value of the callback if successful, otherwise exception details
    output: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Used for prioritization. Higher weight -> higher priority
    priority_weight: Mapped[int] = mapped_column(Integer, nullable=False)

    # Creation time of the callback
    created_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)

    # Used for callbacks of type CallbackType.TRIGGERER
    trigger_id: Mapped[int] = mapped_column(Integer, ForeignKey("trigger.id"), nullable=True)
    trigger = relationship("Trigger", back_populates="callback", uselist=False)

    def __init__(self, priority_weight: int = 1, prefix: str = "", **kwargs):
        """
        Initialize a Callback. This is the base class so it shouldn't usually need to be initialized.

        :param priority_weight: Priority for callback execution (higher value -> higher priority)
        :param prefix: Optional prefix for metric names
        :param kwargs: Additional data emitted in metric tags
        """
        self.state = CallbackState.PENDING
        self.priority_weight = priority_weight
        self.data = kwargs  # kwargs can be used to include additional info in metric tags
        if prefix:
            self.data["prefix"] = prefix

    def queue(self):
        """Mark this callback as queued; subclasses extend this with the actual hand-off."""
        self.state = CallbackState.QUEUED

    def get_metric_info(self, status: CallbackState, result: Any) -> dict:
        """
        Build the keyword arguments for a ``Stats.incr`` call reporting this callback's outcome.

        :param status: State interpolated into the metric name (via its string value)
        :param result: Callback result (or error details) included in the metric tags
        :return: Dict with ``stat`` (metric name) and ``tags`` entries
        """
        tags = {"result": result, **self.data}
        # "prefix" only influences the metric name, never the tags.
        tags.pop("prefix", None)

        if "kwargs" in tags:
            # Remove the context (if exists) to keep the tags simple
            tags["kwargs"] = {k: v for k, v in tags["kwargs"].items() if k != "context"}

        prefix = self.data.get("prefix", "")
        # e.g. "callback_success", or "<prefix>.callback_success" when a prefix was given.
        name = f"{prefix}.callback_{status}" if prefix else f"callback_{status}"

        return {"stat": name, "tags": tags}

    @staticmethod
    def create_from_sdk_def(callback_def: CallbackDefinitionProtocol, **kwargs) -> Callback:
        """
        Create the matching Callback subclass for a Task SDK callback definition.

        :param callback_def: Task SDK callback definition, dispatched on its class name
        :param kwargs: Passed through to the subclass constructor
        :raises ValueError: If the definition's class name is not recognized
        """
        # Cannot check actual type using isinstance() because that would require SDK import
        match type(callback_def).__name__:
            case "AsyncCallback":
                if TYPE_CHECKING:
                    assert isinstance(callback_def, ImportPathCallbackDefProtocol)
                return TriggererCallback(callback_def, **kwargs)

            case "SyncCallback":
                if TYPE_CHECKING:
                    assert isinstance(callback_def, ImportPathExecutorCallbackDefProtocol)
                return ExecutorCallback(callback_def, fetch_method=CallbackFetchMethod.IMPORT_PATH, **kwargs)

            case _:
                raise ValueError(f"Cannot handle Callback of type {type(callback_def)}")
188
189
class TriggererCallback(Callback):
    """Callbacks that run on the Triggerer (must be async)."""

    __mapper_args__ = {"polymorphic_identity": CallbackType.TRIGGERER}

    def __init__(self, callback_def: ImportPathCallbackDefProtocol, **kwargs):
        """
        Initialize a TriggererCallback from a callback definition.

        :param callback_def: Callback definition with path and kwargs
        :param kwargs: Passed to parent Callback.__init__ (see base class for details)
        """
        super().__init__(**kwargs)
        self.fetch_method = CallbackFetchMethod.IMPORT_PATH
        # Merge the serialized definition (at minimum "path" and "kwargs") into ``data``.
        self.data |= callback_def.serialize()

    def __repr__(self):
        """E.g. ``my.mod.func({'a': 1}) on a triggerer``."""
        return f"{self.data['path']}({self.data['kwargs'] or ''}) on a triggerer"

    def queue(self):
        """Create the backing Trigger row for this callback, then mark it queued."""
        # Imported locally to avoid circular imports at module load time.
        from airflow.models.trigger import Trigger
        from airflow.triggers.callback import CallbackTrigger

        self.trigger = Trigger.from_object(
            CallbackTrigger(
                callback_path=self.data["path"],
                callback_kwargs=self.data["kwargs"],
            )
        )
        super().queue()

    def handle_event(self, event: TriggerEvent, session: Session):
        """
        Update this callback's state from an event emitted by the triggerer.

        :param event: Event whose payload carries the new status and, on completion, the output body
        :param session: Database session used to persist the updated callback
        """
        from airflow.triggers.callback import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY

        # Only act on payloads that carry a recognized (active or terminal) status.
        if (status := event.payload.get(PAYLOAD_STATUS_KEY)) and status in (ACTIVE_STATES | TERMINAL_STATES):
            self.state = status
            if status in TERMINAL_STATES:
                # The trigger is no longer needed once the callback has finished.
                self.trigger = None
                self.output = event.payload.get(PAYLOAD_BODY_KEY)
                Stats.incr(**self.get_metric_info(status, self.output))

            session.add(self)
        else:
            log.error("Unexpected event received: %s", event.payload)
234
235
class ExecutorCallback(Callback):
    """Callbacks that run on the executor."""

    __mapper_args__ = {"polymorphic_identity": CallbackType.EXECUTOR}

    def __init__(
        self, callback_def: ImportPathExecutorCallbackDefProtocol, fetch_method: CallbackFetchMethod, **kwargs
    ):
        """
        Initialize an ExecutorCallback from a callback definition and fetch method.

        :param callback_def: Callback definition with path, kwargs, and executor
        :param fetch_method: Method to fetch the callback at runtime
        :param kwargs: Passed to parent Callback.__init__ (see base class for details)
        """
        super().__init__(**kwargs)
        self.fetch_method = fetch_method
        # Merge the serialized definition (expected to include "path"/"kwargs"/"executor") into ``data``.
        self.data |= callback_def.serialize()

    def __repr__(self):
        """E.g. ``my.mod.func({'a': 1}) on default executor``."""
        return f"{self.data['path']}({self.data['kwargs'] or ''}) on {self.data.get('executor', 'default')} executor"
257
258
class DagProcessorCallback(Callback):
    """Used to store Dag Processor's callback requests in the DB."""

    __mapper_args__ = {"polymorphic_identity": CallbackType.DAG_PROCESSOR}

    def __init__(self, priority_weight: int, callback: CallbackRequest):
        """
        Initialize a DagProcessorCallback from a callback request.

        :param priority_weight: Priority for callback execution (higher value -> higher priority)
        :param callback: The callback request to persist; its class name and JSON form are stored
        """
        super().__init__(priority_weight=priority_weight)

        self.fetch_method = CallbackFetchMethod.DAG_ATTRIBUTE
        # Dag Processor callbacks do not use the state machine (see the ``state``
        # column comment on the base class), so the state is explicitly cleared.
        self.state = None
        # Store enough information to reconstruct the request later (see get_callback_request).
        self.data |= {"req_class": callback.__class__.__name__, "req_data": callback.to_json()}

    def get_callback_request(self) -> CallbackRequest:
        """
        Reconstruct the original CallbackRequest from the serialized ``data``.

        :return: A new instance of the stored request class, built via its ``from_json`` hook
        """
        module = import_module("airflow.callbacks.callback_requests")
        # The class name is dynamic, so getattr() is needed here; the method
        # lookup below is a constant and uses direct attribute access (ruff B009).
        callback_request_class = getattr(module, self.data["req_class"])
        return callback_request_class.from_json(self.data["req_data"])