Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/models/callback.py: 55%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

132 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19from datetime import datetime 

20from enum import Enum 

21from importlib import import_module 

22from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable 

23 

24import structlog 

25import uuid6 

26from sqlalchemy import ForeignKey, Integer, String, Text 

27from sqlalchemy.orm import Mapped, relationship 

28from sqlalchemy_utils import UUIDType 

29 

30from airflow._shared.timezones import timezone 

31from airflow.models import Base 

32from airflow.observability.stats import Stats 

33from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, mapped_column 

34 

35if TYPE_CHECKING: 

36 from sqlalchemy.orm import Session 

37 

38 from airflow.callbacks.callback_requests import CallbackRequest 

39 from airflow.triggers.base import TriggerEvent 

40 

# Structured logger scoped to this module's import name.
log = structlog.get_logger(__name__)

42 

43 

class CallbackState(str, Enum):
    """Every state a callback can be in over its lifetime."""

    PENDING = "pending"
    QUEUED = "queued"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"

    def __str__(self) -> str:
        # Render the bare value (e.g. "success") instead of "CallbackState.SUCCESS".
        state_value: str = self.value
        return state_value

55 

56 

# Callback has been handed off (queued) or is currently executing.
ACTIVE_STATES = frozenset({CallbackState.QUEUED, CallbackState.RUNNING})
# Callback has finished, either successfully or with a failure.
TERMINAL_STATES = frozenset({CallbackState.SUCCESS, CallbackState.FAILED})

59 

60 

class CallbackType(str, Enum):
    """
    Kinds of callbacks, one per place a callback can run.

    The value doubles as the polymorphic identity, so it selects which class is
    instantiated when a row is deserialized.
    """

    TRIGGERER = "triggerer"
    EXECUTOR = "executor"
    DAG_PROCESSOR = "dag_processor"

71 

72 

class CallbackFetchMethod(str, Enum):
    """How a callback's callable is located when it is time to run it."""

    # Reserved for when Dag Processor callbacks (on_success_callback/on_failure_callback)
    # move to executors; the callable would then be read off the Dag.
    DAG_ATTRIBUTE = "dag_attribute"

    # Used by deadline callbacks, which locate their callable via an import path.
    IMPORT_PATH = "import_path"

81 

82 

class CallbackDefinitionProtocol(Protocol):
    """Structural interface that a Task SDK callback definition satisfies."""

    def serialize(self) -> dict[str, Any]:
        """Return this definition as a plain dictionary."""

89 

90 

@runtime_checkable
class ImportPathCallbackDefProtocol(CallbackDefinitionProtocol, Protocol):
    """Callback definition whose callable is located through an import path."""

    # Import path of the callback callable.
    path: str
    # Keyword arguments handed to the callback.
    kwargs: dict

97 

98 

@runtime_checkable
class ImportPathExecutorCallbackDefProtocol(ImportPathCallbackDefProtocol, Protocol):
    """Import-path callback definition that also names the executor it should run on."""

    # Executor to run the callback on; None presumably selects the default executor (verify with callers).
    executor: str | None

104 

105 

106class Callback(Base): 

107 """Base class for callbacks.""" 

108 

109 __tablename__ = "callback" 

110 

111 id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7) 

112 

113 # This is used by SQLAlchemy to be able to deserialize DB rows to subclasses 

114 __mapper_args__ = { 

115 "polymorphic_identity": "callback", 

116 "polymorphic_on": "type", 

117 } 

118 type: Mapped[str] = mapped_column(String(20), nullable=False) 

119 

120 # Method used to fetch the callback, of type: CallbackFetchMethod 

121 fetch_method: Mapped[str] = mapped_column(String(20), nullable=False) 

122 

123 # Used by subclasses to store information about how to run the callback 

124 data: Mapped[dict] = mapped_column(ExtendedJSON, nullable=False) 

125 

126 # State of the Callback of type: CallbackState. Can be null for instances of DagProcessorCallback. 

127 state: Mapped[str | None] = mapped_column(String(10)) 

128 

129 # Return value of the callback if successful, otherwise exception details 

130 output: Mapped[str | None] = mapped_column(Text, nullable=True) 

131 

132 # Used for prioritization. Higher weight -> higher priority 

133 priority_weight: Mapped[int] = mapped_column(Integer, nullable=False) 

134 

135 # Creation time of the callback 

136 created_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) 

137 

138 # Used for callbacks of type CallbackType.TRIGGERER 

139 trigger_id: Mapped[int] = mapped_column(Integer, ForeignKey("trigger.id"), nullable=True) 

140 trigger = relationship("Trigger", back_populates="callback", uselist=False) 

141 

142 def __init__(self, priority_weight: int = 1, prefix: str = "", **kwargs): 

143 """ 

144 Initialize a Callback. This is the base class so it shouldn't usually need to be initialized. 

145 

146 :param priority_weight: Priority for callback execution (higher value -> higher priority) 

147 :param prefix: Optional prefix for metric names 

148 :param kwargs: Additional data emitted in metric tags 

149 """ 

150 self.state = CallbackState.PENDING 

151 self.priority_weight = priority_weight 

152 self.data = kwargs # kwargs can be used to include additional info in metric tags 

153 if prefix: 

154 self.data["prefix"] = prefix 

155 

156 def queue(self): 

157 self.state = CallbackState.QUEUED 

158 

159 def get_metric_info(self, status: CallbackState, result: Any) -> dict: 

160 tags = {"result": result, **self.data} 

161 tags.pop("prefix", None) 

162 

163 if "kwargs" in tags: 

164 # Remove the context (if exists) to keep the tags simple 

165 tags["kwargs"] = {k: v for k, v in tags["kwargs"].items() if k != "context"} 

166 

167 prefix = self.data.get("prefix", "") 

168 name = f"{prefix}.callback_{status}" if prefix else f"callback_{status}" 

169 

170 return {"stat": name, "tags": tags} 

171 

172 @staticmethod 

173 def create_from_sdk_def(callback_def: CallbackDefinitionProtocol, **kwargs) -> Callback: 

174 # Cannot check actual type using isinstance() because that would require SDK import 

175 match type(callback_def).__name__: 

176 case "AsyncCallback": 

177 if TYPE_CHECKING: 

178 assert isinstance(callback_def, ImportPathCallbackDefProtocol) 

179 return TriggererCallback(callback_def, **kwargs) 

180 

181 case "SyncCallback": 

182 if TYPE_CHECKING: 

183 assert isinstance(callback_def, ImportPathExecutorCallbackDefProtocol) 

184 return ExecutorCallback(callback_def, fetch_method=CallbackFetchMethod.IMPORT_PATH, **kwargs) 

185 

186 case _: 

187 raise ValueError(f"Cannot handle Callback of type {type(callback_def)}") 

188 

189 

class TriggererCallback(Callback):
    """Callbacks that run on the Triggerer (must be async)."""

    __mapper_args__ = {"polymorphic_identity": CallbackType.TRIGGERER}

    def __init__(self, callback_def: ImportPathCallbackDefProtocol, **kwargs):
        """
        Initialize a TriggererCallback from a callback definition.

        :param callback_def: Callback definition with path and kwargs
        :param kwargs: Passed to parent Callback.__init__ (see base class for details)
        """
        super().__init__(**kwargs)
        self.fetch_method = CallbackFetchMethod.IMPORT_PATH
        self.data |= callback_def.serialize()

    def __repr__(self):
        return f"{self.data['path']}({self.data['kwargs'] or ''}) on a triggerer"

    def queue(self):
        """Attach a CallbackTrigger for this callback's path/kwargs, then mark it QUEUED."""
        from airflow.models.trigger import Trigger
        from airflow.triggers.callback import CallbackTrigger

        self.trigger = Trigger.from_object(
            CallbackTrigger(
                callback_path=self.data["path"],
                callback_kwargs=self.data["kwargs"],
            )
        )
        super().queue()

    def handle_event(self, event: TriggerEvent, session: Session):
        """
        Apply a status update emitted by the callback's trigger.

        On a terminal status the trigger is detached, the payload body is stored as
        ``output`` and a metric is emitted. Events with a missing or unknown status are
        logged and otherwise ignored.

        :param event: Event whose payload carries the status (and optionally a body)
        :param session: Session the updated callback is added to
        """
        from airflow.triggers.callback import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY

        status = event.payload.get(PAYLOAD_STATUS_KEY)
        # Check the type first: an unhashable payload value (e.g. a list) would otherwise
        # raise TypeError on the frozenset membership test instead of being logged.
        if not isinstance(status, str) or status not in (ACTIVE_STATES | TERMINAL_STATES):
            log.error("Unexpected event received: %s", event.payload)
            return

        self.state = status
        if status in TERMINAL_STATES:
            self.trigger = None
            self.output = event.payload.get(PAYLOAD_BODY_KEY)
            Stats.incr(**self.get_metric_info(status, self.output))

        session.add(self)

234 

235 

class ExecutorCallback(Callback):
    """Callbacks that run on the executor."""

    __mapper_args__ = {"polymorphic_identity": CallbackType.EXECUTOR}

    def __init__(
        self, callback_def: ImportPathExecutorCallbackDefProtocol, fetch_method: CallbackFetchMethod, **kwargs
    ):
        """
        Initialize an ExecutorCallback from a callback definition and fetch method.

        :param callback_def: Callback definition with path, kwargs, and executor
        :param fetch_method: Method to fetch the callback at runtime
        :param kwargs: Passed to parent Callback.__init__ (see base class for details)
        """
        super().__init__(**kwargs)
        self.fetch_method = fetch_method
        self.data |= callback_def.serialize()

    def __repr__(self):
        # ``executor`` is typed ``str | None`` on the SDK definition, so the key may be
        # present with a None value; treat that the same as a missing key.
        executor = self.data.get("executor") or "default"
        return f"{self.data['path']}({self.data['kwargs'] or ''}) on {executor} executor"

257 

258 

259class DagProcessorCallback(Callback): 

260 """Used to store Dag Processor's callback requests in the DB.""" 

261 

262 __mapper_args__ = {"polymorphic_identity": CallbackType.DAG_PROCESSOR} 

263 

264 def __init__(self, priority_weight: int, callback: CallbackRequest): 

265 """Initialize a DagProcessorCallback from a callback request.""" 

266 super().__init__(priority_weight=priority_weight) 

267 

268 self.fetch_method = CallbackFetchMethod.DAG_ATTRIBUTE 

269 self.state = None 

270 self.data |= {"req_class": callback.__class__.__name__, "req_data": callback.to_json()} 

271 

272 def get_callback_request(self) -> CallbackRequest: 

273 module = import_module("airflow.callbacks.callback_requests") 

274 callback_request_class = getattr(module, self.data["req_class"]) 

275 # Get the function (from the instance) that we need to call 

276 from_json = getattr(callback_request_class, "from_json") 

277 return from_json(self.data["req_data"])