# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import abc
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from datetime import timedelta
from typing import Annotated, Any

import structlog
from pydantic import (
    BaseModel,
    Discriminator,
    JsonValue,
    Tag,
    model_serializer,
)

from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState

log = structlog.get_logger(logger_name=__name__)


@dataclass
class StartTriggerArgs:
    """Arguments required to start task execution from the triggerer."""

    trigger_cls: str
    next_method: str
    trigger_kwargs: dict[str, Any] | None = None
    next_kwargs: dict[str, Any] | None = None
    timeout: timedelta | None = None
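
# Example (illustrative sketch): an operator that starts execution from the
# triggerer can expose these arguments. The classpath and method name below
# are hypothetical, not part of this module.
#
#     start_trigger_args = StartTriggerArgs(
#         trigger_cls="my_package.triggers.WaitTrigger",
#         next_method="execute_complete",
#         trigger_kwargs={"seconds": 30},
#         next_kwargs=None,
#         timeout=None,
#     )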


class BaseTrigger(abc.ABC, LoggingMixin):
    """
    Base class for all triggers.

    A trigger has two contexts it can exist in:

    - Inside an Operator, when it's passed to TaskDeferred
    - Actively running in a trigger worker

    We use the same class for both situations, and rely on all Trigger classes
    to be able to return the arguments (encodable with Airflow-JSON) that will
    let them be re-instantiated elsewhere.
    """

    def __init__(self, **kwargs):
        # These values are set by the triggerer when preparing to run the
        # instance; when run, they are injected into the logger record.
        self.task_instance = None
        self.trigger_id = None

    def _set_context(self, context):
        """Part of LoggingMixin and used mainly for configuration of task logging; not used for triggers."""
        raise NotImplementedError

    @abc.abstractmethod
    def serialize(self) -> tuple[str, dict[str, Any]]:
        """
        Return the information needed to reconstruct this Trigger.

        :return: Tuple of (class path, keyword arguments needed to re-instantiate).
        """
        raise NotImplementedError("Triggers must implement serialize()")

    @abc.abstractmethod
    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Run the trigger in an asynchronous context.

        The trigger should yield an Event whenever it wants to fire off
        an event, and return None if it is finished. Single-event triggers
        should thus yield and then immediately return.

        If it yields, it is likely that it will be resumed very quickly,
        but it may not be (e.g. if the workload is being moved to another
        triggerer process, or a multi-event trigger was being used for a
        single-event task defer).

        In either case, Trigger classes should assume they will be persisted,
        and then rely on cleanup() being called when they are no longer needed.
        """
        raise NotImplementedError("Triggers must implement run()")
        yield  # To convince Mypy this is an async iterator.

    async def cleanup(self) -> None:
        """
        Clean up the trigger.

        Called when the trigger is no longer needed, and it's being removed
        from the active triggerer process.

        This method follows the async/await pattern to allow the cleanup to run
        in the triggerer's main event loop. Exceptions raised by the cleanup method
        are ignored, so if you would like to be able to debug them and be notified
        that the cleanup failed, you should wrap your code in a try/except block
        and handle it appropriately (in an async-compatible way).
        """

    @staticmethod
    def repr(classpath: str, kwargs: dict[str, Any]):
        kwargs_str = ", ".join(f"{k}={v}" for k, v in kwargs.items())
        return f"<{classpath} {kwargs_str}>"

    def __repr__(self) -> str:
        classpath, kwargs = self.serialize()
        return self.repr(classpath, kwargs)
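
# Example (illustrative sketch): a minimal single-event trigger implementing
# the contract above. ``WaitTrigger`` and its classpath are hypothetical,
# not part of this module.
#
#     import asyncio
#
#     class WaitTrigger(BaseTrigger):
#         def __init__(self, seconds: int = 5, **kwargs):
#             super().__init__(**kwargs)
#             self.seconds = seconds
#
#         def serialize(self) -> tuple[str, dict[str, Any]]:
#             # Classpath plus kwargs, so the triggerer can re-instantiate it.
#             return ("my_package.triggers.WaitTrigger", {"seconds": self.seconds})
#
#         async def run(self) -> AsyncIterator[TriggerEvent]:
#             await asyncio.sleep(self.seconds)
#             # Single-event trigger: yield once, then return.
#             yield TriggerEvent({"status": "done"})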


class BaseEventTrigger(BaseTrigger):
    """
    Base class for triggers used to schedule DAGs based on external events.

    ``BaseEventTrigger`` is a subclass of ``BaseTrigger`` designed to identify triggers compatible with
    event-driven scheduling.
    """

    @staticmethod
    def hash(classpath: str, kwargs: dict[str, Any]) -> int:
        """
        Return the hash of the trigger classpath and kwargs. This is used to uniquely identify a trigger.

        We do not want to have this logic in ``BaseTrigger`` because, when used to defer tasks, two triggers
        can have the same classpath and kwargs. This is not true for event-driven scheduling.
        """
        from airflow.serialization.serialized_objects import BaseSerialization

        return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
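
# Example (illustrative): identical classpath/kwargs pairs hash identically,
# which is how event-driven scheduling identifies a trigger uniquely. The
# classpath below is hypothetical.
#
#     h1 = BaseEventTrigger.hash("my_package.triggers.FileTrigger", {"path": "/data"})
#     h2 = BaseEventTrigger.hash("my_package.triggers.FileTrigger", {"path": "/data"})
#     assert h1 == h2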


class TriggerEvent(BaseModel):
    """
    Something that a trigger can fire when its conditions are met.

    Events must have a uniquely identifying value that would be the same
    wherever the trigger is run; this is to ensure that if the same trigger
    is being run in two locations (for HA reasons) we can deduplicate its
    events.
    """

    payload: Any = None
    """
    The payload for the event to send back to the task.

    Must be natively JSON-serializable, or registered with the Airflow serialization code.
    """

    def __init__(self, payload, **kwargs):
        super().__init__(payload=payload, **kwargs)

    def __repr__(self) -> str:
        return f"TriggerEvent<{self.payload!r}>"
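
# Example (illustrative): because ``TriggerEvent`` is a Pydantic model, events
# with equal payloads compare equal, consistent with the deduplication
# requirement described above.
#
#     assert TriggerEvent({"run_id": "abc"}) == TriggerEvent({"run_id": "abc"})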


class BaseTaskEndEvent(TriggerEvent):
    """
    Base event class to end the task without resuming on the worker.

    :meta private:
    """

    task_instance_state: TaskInstanceState
    xcoms: dict[str, JsonValue] | None = None

    def __init__(self, *, xcoms: dict[str, JsonValue] | None = None, **kwargs) -> None:
        """
        Initialize the class with the specified parameters.

        :param xcoms: A dictionary of XComs or None.
        :param kwargs: Additional keyword arguments.
        """
        if "payload" in kwargs:
            raise ValueError("Param 'payload' not supported for this class.")
        # Yes, this is _odd_. It's to support both the user-facing constructor
        # (``TaskSuccessEvent(xcoms=...)``) and deserialization by Pydantic.
        state = kwargs.pop("task_instance_state", self.__pydantic_fields__["task_instance_state"].default)
        super().__init__(payload=str(state), task_instance_state=state, **kwargs)
        self.xcoms = xcoms

    @model_serializer
    def ser_model(self) -> dict[str, Any]:
        # We need to customize the serialized form so it works with the custom
        # constructor above, which we keep to make the interface to this class "nice".
        return {"task_instance_state": self.task_instance_state, "xcoms": self.xcoms}


class TaskSuccessEvent(BaseTaskEndEvent):
    """Yield this event in order to end the task successfully."""

    task_instance_state: TaskInstanceState = TaskInstanceState.SUCCESS


class TaskFailedEvent(BaseTaskEndEvent):
    """Yield this event in order to end the task with failure."""

    task_instance_state: TaskInstanceState = TaskInstanceState.FAILED


class TaskSkippedEvent(BaseTaskEndEvent):
    """Yield this event in order to end the task with status 'skipped'."""

    task_instance_state: TaskInstanceState = TaskInstanceState.SKIPPED
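
# Example (illustrative): yield one of the task-end events from a trigger's
# ``run()`` to finish the deferred task directly, optionally pushing XComs.
# The XCom key below is only an example.
#
#     async def run(self) -> AsyncIterator[TriggerEvent]:
#         ...
#         yield TaskSuccessEvent(xcoms={"my_key": "done"})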


def trigger_event_discriminator(v):
    if isinstance(v, dict):
        return v.get("task_instance_state", "_event_")
    if isinstance(v, TriggerEvent):
        return getattr(v, "task_instance_state", "_event_")


DiscrimatedTriggerEvent = Annotated[
    Annotated[TriggerEvent, Tag("_event_")]
    | Annotated[TaskSuccessEvent, Tag(TaskInstanceState.SUCCESS)]
    | Annotated[TaskFailedEvent, Tag(TaskInstanceState.FAILED)]
    | Annotated[TaskSkippedEvent, Tag(TaskInstanceState.SKIPPED)],
    Discriminator(trigger_event_discriminator),
]
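
# Example (illustrative): validating with the discriminated union picks the
# concrete event class from ``task_instance_state``; plain events fall back
# to the ``_event_`` tag.
#
#     from pydantic import TypeAdapter
#
#     adapter = TypeAdapter(DiscrimatedTriggerEvent)
#     event = adapter.validate_python({"task_instance_state": "success", "xcoms": None})
#     assert isinstance(event, TaskSuccessEvent)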