1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19import inspect
20from abc import ABC
21from collections.abc import Callable
22from typing import Any
23
24import structlog
25
26from airflow.sdk._shared.module_loading import import_string, is_valid_dotpath
27
# Module-scoped structlog logger; ``getLogger`` is only a CamelCase alias of ``get_logger``.
log = structlog.get_logger(__name__)
29
30
class Callback(ABC):
    """
    Base class for Deadline Alert callbacks.

    Callbacks are used to execute custom logic when a deadline is missed.

    The `callback_callable` can be a Python callable type or a string containing the path to the callable that
    can be used to import the callable. It must be a top-level callable in a module present on the host where
    it will run.

    It will be called with Airflow context and specified kwargs when a deadline is missed.

    :param callback_callable: The callable itself, or a dot path string pointing at it.
    :param kwargs: Extra keyword arguments for the callable. ``context`` is reserved and may not be supplied.
    """

    # Dot path used to import the callable later.
    path: str
    # Extra keyword arguments for the callable (never contains ``context``).
    kwargs: dict

    def __init__(self, callback_callable: Callable | str, kwargs: dict[str, Any] | None = None):
        self.path = self.get_callback_path(callback_callable)
        if kwargs and "context" in kwargs:
            # ``context`` is reserved for the Airflow context supplied when the callback is invoked.
            raise ValueError("context is a reserved kwarg for this class")
        self.kwargs = kwargs or {}

    @classmethod
    def get_callback_path(cls, _callback: str | Callable) -> str:
        """
        Convert callback to a string path that can be used to import it later.

        For a callable, :meth:`verify_callable` is run and the path is built from its ``__module__`` and
        ``__qualname__``. For a string, the stripped dot path is returned; the target is imported on a
        best-effort basis so that an importable-but-unsuitable target is rejected early, while a target that
        cannot be imported here is only logged.

        :param _callback: The callable, or a dot path string pointing at it.
        :return: The dot path of the callable.
        :raises ImportError: If ``_callback`` is neither callable nor a string that looks like a dot path.
        :raises AttributeError: If the string imports to something that is not callable. Subclass
            implementations of :meth:`verify_callable` may raise as well.
        """
        if callable(_callback):
            cls.verify_callable(_callback)

            # TODO: This implementation doesn't support using a lambda function (or any other callable whose
            #   ``__qualname__`` is not importable, such as a nested function) as a callback.
            #   We should consider that in the future, but the addition is non-trivial.
            # Get the reference path to the callable in the form `airflow.models.deadline.get_from_db`
            return f"{_callback.__module__}.{_callback.__qualname__}"

        if not isinstance(_callback, str):
            raise ImportError(f"`{_callback}` doesn't look like a valid dot path.")

        stripped_callback = _callback.strip()
        if not is_valid_dotpath(stripped_callback):
            raise ImportError(f"`{_callback}` doesn't look like a valid dot path.")

        try:
            # The provided callback is a string which appears to be a valid dotpath, attempt to import it.
            callback = import_string(stripped_callback)
        except ImportError as e:
            # Logging here instead of failing because it is possible that the code for the callable
            # exists somewhere other than on the DAG processor. We are making a best effort to validate,
            # but can't rule out that it may be available at runtime even if it can not be imported here.
            log.debug(
                "Callback %s is formatted like a callable dotpath, but could not be imported.\n%s",
                stripped_callback,
                e,
            )
        else:
            # Only reached when the import succeeded, so validation errors are never mistaken for
            # "could not be imported".
            if not callable(callback):
                # The input is a string which can be imported, but is not callable.
                raise AttributeError(f"Provided callback {callback} is not callable.")
            cls.verify_callable(callback)

        return stripped_callback

    @classmethod
    def verify_callable(cls, callback: Callable) -> None:
        """
        Hook for additional verification of the callable during initialization in subclasses.

        The base class accepts any callable; subclasses override this and raise to reject ``callback``.
        """

    @classmethod
    def deserialize(cls, data: dict, version):
        """
        Rebuild a callback from the output of :meth:`serialize`.

        The input mapping is not modified, so the same serialized payload can be deserialized repeatedly.

        :param data: Mapping with a ``path`` key plus the remaining constructor keyword arguments.
        :param version: Serialization version; not used by this implementation.
        """
        # Work on a copy: popping from the caller's dict would corrupt a payload that may be reused.
        fields = dict(data)
        path = fields.pop("path")
        return cls(callback_callable=path, **fields)

    @classmethod
    def serialized_fields(cls) -> tuple[str, ...]:
        """Return the attribute names included by :meth:`serialize` (extended by subclasses)."""
        return ("path", "kwargs")

    def serialize(self) -> dict[str, Any]:
        """Return a dict of the attributes named by :meth:`serialized_fields`."""
        return {f: getattr(self, f) for f in self.serialized_fields()}

    def __eq__(self, other):
        if type(self) is not type(other):
            return NotImplemented
        return self.serialize() == other.serialize()

    def __hash__(self):
        def freeze(value: Any) -> Any:
            # Recursively convert containers into hashable equivalents (order-insensitive for dicts and
            # sets) so kwargs holding lists or nested dicts can still be hashed.
            if isinstance(value, dict):
                return frozenset((key, freeze(item)) for key, item in value.items())
            if isinstance(value, (list, tuple)):
                return tuple(freeze(item) for item in value)
            if isinstance(value, (set, frozenset)):
                return frozenset(freeze(item) for item in value)
            return value

        # __eq__ compares serialize() output, so hashing the frozen serialized payload keeps
        # hash() consistent with ==.
        return hash(freeze(self.serialize()))
121
122
class AsyncCallback(Callback):
    """
    Asynchronous callback that runs in the triggerer.

    The `callback_callable` can be a Python callable type or a string containing the path to the callable that
    can be used to import the callable. It must be a top-level awaitable callable in a module present on the
    triggerer.

    It will be called with Airflow context and specified kwargs when a deadline is missed.

    :param callback_callable: The awaitable callable itself, or a dot path string pointing at it.
    :param kwargs: Extra keyword arguments for the callable.
    """

    def __init__(self, callback_callable: Callable | str, kwargs: dict | None = None):
        # Construction is identical to the base class; only ``verify_callable`` differs.
        super().__init__(callback_callable, kwargs)

    @classmethod
    def verify_callable(cls, callback: Callable) -> None:
        """
        Accept only callables that can be awaited.

        Coroutine functions pass, as does anything that exposes ``__await__``.

        :raises AttributeError: If ``callback`` is neither.
        """
        if inspect.iscoroutinefunction(callback) or hasattr(callback, "__await__"):
            return
        raise AttributeError(f"Provided callback {callback} is not awaitable.")
141
142
class SyncCallback(Callback):
    """
    Synchronous callback that runs in the specified or default executor.

    The `callback_callable` can be a Python callable type or a string containing the path to the callable that
    can be used to import the callable. It must be a top-level callable in a module present on the executor.

    It will be called with Airflow context and specified kwargs when a deadline is missed.

    :param callback_callable: The callable itself, or a dot path string pointing at it.
    :param kwargs: Extra keyword arguments for the callable.
    :param executor: Executor to run the callback in; ``None`` when no specific executor is requested.
    """

    # Executor to run in; left as ``None`` when no specific executor is requested.
    executor: str | None

    def __init__(
        self, callback_callable: Callable | str, kwargs: dict | None = None, executor: str | None = None
    ):
        super().__init__(callback_callable, kwargs)
        self.executor = executor

    @classmethod
    def serialized_fields(cls) -> tuple[str, ...]:
        # Append ``executor`` so it is emitted by serialize() and passed back to __init__ by deserialize().
        base_fields = super().serialized_fields()
        return (*base_fields, "executor")