Coverage for airflow/models/deadline.py: 41% (203 statements)


# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, cast

import uuid6
from sqlalchemy import Boolean, ForeignKey, Index, Integer, and_, func, inspect, select, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Mapped, relationship
from sqlalchemy_utils import UUIDType

from airflow._shared.timezones import timezone
from airflow.models.base import Base
from airflow.models.callback import Callback, CallbackDefinitionProtocol
from airflow.observability.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column

if TYPE_CHECKING:
    from sqlalchemy.orm import Session
    from sqlalchemy.sql import ColumnElement


logger = logging.getLogger(__name__)

CALLBACK_METRICS_PREFIX = "deadline_alerts"


class classproperty:
    """
    Decorator that converts a method with a single cls argument into a property.

    Mypy won't let us use both @property and @classmethod together; this is a
    workaround that combines the two.

    Usage:

        class Circle:
            def __init__(self, radius):
                self.radius = radius

            @classproperty
            def pi(cls):
                return 3.14159

        print(Circle.pi)  # Outputs: 3.14159
    """

    def __init__(self, method):
        self.method = method

    def __get__(self, instance, cls=None):
        return self.method(cls)


class Deadline(Base):
    """A Deadline is a 'need-by' date which triggers a callback if the provided time has passed."""

    __tablename__ = "deadline"

    id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)

    # If the Deadline Alert is for a DAG, store the DAG run ID from the dag_run.
    dagrun_id: Mapped[int | None] = mapped_column(
        Integer, ForeignKey("dag_run.id", ondelete="CASCADE"), nullable=True
    )
    dagrun = relationship("DagRun", back_populates="deadlines")

    # The time after which the Deadline has passed and the callback should be triggered.
    deadline_time: Mapped[datetime] = mapped_column(UtcDateTime, nullable=False)

    # Whether the deadline has been marked as missed by the scheduler.
    missed: Mapped[bool] = mapped_column(Boolean, nullable=False)

    # Callback that will run when this deadline is missed.
    callback_id: Mapped[str] = mapped_column(
        UUIDType(binary=False), ForeignKey("callback.id", ondelete="CASCADE"), nullable=False
    )
    callback = relationship("Callback", uselist=False, cascade="all, delete-orphan", single_parent=True)

    __table_args__ = (Index("deadline_missed_deadline_time_idx", missed, deadline_time, unique=False),)

    def __init__(
        self,
        deadline_time: datetime,
        callback: CallbackDefinitionProtocol,
        dagrun_id: int,
        dag_id: str | None = None,
    ):
        super().__init__()
        self.deadline_time = deadline_time
        self.dagrun_id = dagrun_id
        self.missed = False
        self.callback = Callback.create_from_sdk_def(
            callback_def=callback, prefix=CALLBACK_METRICS_PREFIX, dag_id=dag_id
        )
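    # Construction sketch (illustrative only, not part of this module): assuming a
    # callback definition satisfying CallbackDefinitionProtocol and an active session,
    # a deadline one hour out for a given run might be created roughly like this.
    # `my_callback_def` and `dag_run` are hypothetical names:
    #
    #     deadline = Deadline(
    #         deadline_time=timezone.utcnow() + timedelta(hours=1),
    #         callback=my_callback_def,  # hypothetical CallbackDefinitionProtocol instance
    #         dagrun_id=dag_run.id,
    #         dag_id=dag_run.dag_id,
    #     )
    #     session.add(deadline)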

    def __repr__(self):
        def _determine_resource() -> tuple[str, str]:
            """Determine the type of resource based on which values are present."""
            if self.dagrun_id:
                # The deadline is for a Dag run:
                return "DagRun", f"Dag: {self.dagrun.dag_id} Run: {self.dagrun_id}"

            return "Unknown", ""

        resource_type, resource_details = _determine_resource()

        return (
            f"[{resource_type} Deadline] {resource_details} needed by "
            f"{self.deadline_time} or run: {self.callback}"
        )

    @classmethod
    def prune_deadlines(cls, *, session: Session, conditions: dict[Mapped, Any]) -> int:
        """
        Remove deadlines from the table which match the provided conditions and return the number removed.

        NOTE: This should only be used to remove deadlines which are associated with
        successful events (DagRuns, etc.). If the deadline was missed, it will be
        handled by the scheduler.

        :param conditions: Dictionary of conditions to evaluate against.
        :param session: Session to use.
        """
        from airflow.models import DagRun  # Avoids a circular import.

        # Assemble the filter conditions.
        filter_conditions = [column == value for column, value in conditions.items()]
        if not filter_conditions:
            return 0

        try:
            # Get deadlines which match the provided conditions and their associated DagRuns.
            deadline_dagrun_pairs = session.execute(
                select(Deadline, DagRun).join(DagRun).where(and_(*filter_conditions))
            ).all()
        except AttributeError:
            logger.exception("Error resolving deadlines")
            raise

        if not deadline_dagrun_pairs:
            return 0

        deleted_count = 0
        dagruns_to_refresh = set()

        for deadline, dagrun in deadline_dagrun_pairs:
            if dagrun.end_date <= deadline.deadline_time:
                # The DagRun finished before the Deadline passed, so the deadline is moot.
                session.delete(deadline)
                Stats.incr(
                    "deadline_alerts.deadline_not_missed",
                    tags={"dag_id": dagrun.dag_id, "dagrun_id": dagrun.run_id},
                )
                deleted_count += 1
                dagruns_to_refresh.add(dagrun)
        session.flush()

        logger.debug("%d deadline records were deleted matching the conditions %s", deleted_count, conditions)

        # Refresh any affected DAG runs.
        for dagrun in dagruns_to_refresh:
            session.refresh(dagrun)

        return deleted_count
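    # Usage sketch (assumed, for illustration): when a DagRun completes successfully,
    # a caller can clear the deadlines that run beat by matching on the dagrun_id column:
    #
    #     removed = Deadline.prune_deadlines(
    #         session=session,
    #         conditions={Deadline.dagrun_id: dag_run.id},
    #     )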

    def handle_miss(self, session: Session):
        """Handle a missed deadline by queueing the callback."""

        def get_simple_context():
            from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
            from airflow.models import DagRun

            # TODO: Use the TaskAPI from within the Triggerer to fetch the full context
            # instead of sending this context from the scheduler.

            # Fetch the DagRun from the database again to avoid errors when self.dagrun's
            # relationship fields are not in the current session.
            dagrun = session.get(DagRun, self.dagrun_id)

            return {
                "dag_run": DAGRunResponse.model_validate(dagrun).model_dump(mode="json"),
                "deadline": {"id": self.id, "deadline_time": self.deadline_time},
            }

        self.callback.data["kwargs"] = self.callback.data["kwargs"] | {"context": get_simple_context()}
        self.missed = True
        self.callback.queue()
        session.add(self)
        Stats.incr(
            "deadline_alerts.deadline_missed",
            tags={"dag_id": self.dagrun.dag_id, "dagrun_id": self.dagrun.run_id},
        )
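    # Illustrative flow (not part of this module): a scheduler sweep over overdue,
    # not-yet-missed deadlines might look roughly like this, relying on the
    # (missed, deadline_time) index defined above:
    #
    #     overdue = session.scalars(
    #         select(Deadline).where(~Deadline.missed, Deadline.deadline_time < timezone.utcnow())
    #     )
    #     for deadline in overdue:
    #         deadline.handle_miss(session=session)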


class ReferenceModels:
    """
    Store the implementations for the different Deadline References.

    After adding the implementations here, all DeadlineReferences should be added
    to the user interface in airflow.sdk.definitions.deadline.DeadlineReference.
    """

    REFERENCE_TYPE_FIELD = "reference_type"

    @classmethod
    def get_reference_class(cls, reference_name: str) -> type[BaseDeadlineReference]:
        """
        Get a reference class by its name.

        :param reference_name: The name of the reference class to find.
        """
        try:
            return next(
                ref_class
                for name, ref_class in vars(cls).items()
                if isinstance(ref_class, type)
                and issubclass(ref_class, cls.BaseDeadlineReference)
                and ref_class.__name__ == reference_name
            )
        except StopIteration:
            raise ValueError(f"No reference class found with name: {reference_name}")
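    # For example, looking up a reference implementation by its serialized name:
    #
    #     ref_class = ReferenceModels.get_reference_class("FixedDatetimeDeadline")
    #     assert ref_class is ReferenceModels.FixedDatetimeDeadline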

    class BaseDeadlineReference(LoggingMixin, ABC):
        """Base class for all Deadline implementations."""

        # Set of required kwargs - subclasses should override this.
        required_kwargs: set[str] = set()

        @classproperty
        def reference_name(cls: Any) -> str:
            return cls.__name__

        def evaluate_with(self, *, session: Session, interval: timedelta, **kwargs: Any) -> datetime | None:
            """Validate the provided kwargs and evaluate this deadline with the given conditions."""
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in self.required_kwargs}

            if missing_kwargs := self.required_kwargs - filtered_kwargs.keys():
                raise ValueError(
                    f"{self.__class__.__name__} is missing required parameters: {', '.join(missing_kwargs)}"
                )

            if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
                self.log.debug("Ignoring unexpected parameters: %s", ", ".join(extra_kwargs))

            base_time = self._evaluate_with(session=session, **filtered_kwargs)
            return base_time + interval if base_time is not None else None
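        # Illustrative call (names assumed): extra kwargs are dropped, missing required
        # kwargs raise ValueError, and the interval is added to whatever base time the
        # subclass resolves:
        #
        #     deadline_time = ref.evaluate_with(
        #         session=session,
        #         interval=timedelta(minutes=30),
        #         dag_id="example_dag",
        #         run_id="example_run",
        #     )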

        @abstractmethod
        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            """Must be implemented by subclasses to perform the actual evaluation."""
            raise NotImplementedError

        @classmethod
        def deserialize_reference(cls, reference_data: dict):
            """
            Deserialize a reference type from its dictionary representation.

            While the base implementation doesn't use reference_data, this parameter is required
            for subclasses that need additional data for initialization (like FixedDatetimeDeadline,
            which needs a datetime value).

            :param reference_data: Dictionary containing serialized reference data.
                Always includes a 'reference_type' field, and may include additional
                fields needed by specific reference implementations.
            """
            return cls()

        def serialize_reference(self) -> dict:
            """
            Serialize this reference type into a dictionary representation.

            This method assumes that the reference doesn't require any additional data.
            Override this method in subclasses if additional data is needed for serialization.
            """
            return {ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name}

    @dataclass
    class FixedDatetimeDeadline(BaseDeadlineReference):
        """A deadline that always returns a fixed datetime."""

        _datetime: datetime

        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            return self._datetime

        def serialize_reference(self) -> dict:
            return {
                ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
                "datetime": self._datetime.timestamp(),
            }

        @classmethod
        def deserialize_reference(cls, reference_data: dict):
            return cls(_datetime=timezone.from_timestamp(reference_data["datetime"]))
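        # Round-trip sketch: serialize_reference() stores the datetime as a POSIX
        # timestamp, so deserialization restores an equivalent point in time:
        #
        #     ref = ReferenceModels.FixedDatetimeDeadline(_datetime=timezone.utcnow())
        #     data = ref.serialize_reference()
        #     restored = ReferenceModels.FixedDatetimeDeadline.deserialize_reference(data)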

    class DagRunLogicalDateDeadline(BaseDeadlineReference):
        """A deadline that returns a DagRun's logical date."""

        required_kwargs = {"dag_id", "run_id"}

        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            from airflow.models import DagRun

            return _fetch_from_db(DagRun.logical_date, session=session, **kwargs)

    class DagRunQueuedAtDeadline(BaseDeadlineReference):
        """A deadline that returns when a DagRun was queued."""

        required_kwargs = {"dag_id", "run_id"}

        @provide_session
        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            from airflow.models import DagRun

            return _fetch_from_db(DagRun.queued_at, session=session, **kwargs)

    @dataclass
    class AverageRuntimeDeadline(BaseDeadlineReference):
        """A deadline that calculates the average runtime from past DAG runs."""

        DEFAULT_LIMIT = 10
        max_runs: int
        min_runs: int | None = None
        required_kwargs = {"dag_id"}

        def __post_init__(self):
            if self.min_runs is None:
                self.min_runs = self.max_runs
            if self.min_runs < 1:
                raise ValueError("min_runs must be at least 1")
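        # For example (illustrative): AverageRuntimeDeadline(max_runs=5) averages the
        # five most recent completed runs and, because min_runs defaults to max_runs,
        # evaluates to None until at least five completed runs exist.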

        @provide_session
        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            from airflow.models import DagRun

            dag_id = kwargs["dag_id"]

            # Get the database dialect to pick the appropriate time-difference calculation.
            dialect = get_dialect_name(session)

            # Create a database-specific expression for calculating duration in seconds.
            duration_expr: ColumnElement[Any]
            if dialect == "postgresql":
                duration_expr = func.extract("epoch", DagRun.end_date - DagRun.start_date)
            elif dialect == "mysql":
                # Use TIMESTAMPDIFF to get exact seconds, like PostgreSQL's EXTRACT(epoch FROM ...).
                duration_expr = func.timestampdiff(text("SECOND"), DagRun.start_date, DagRun.end_date)
            elif dialect == "sqlite":
                # julianday() returns fractional days, so scale by 86400 seconds per day.
                duration_expr = (func.julianday(DagRun.end_date) - func.julianday(DagRun.start_date)) * 86400
            else:
                raise ValueError(f"Unsupported database dialect: {dialect}")

            # Query completed DAG runs which have both start and end dates, ordering by
            # logical_date descending to get the most recent runs first.
            query = (
                select(duration_expr)
                .filter(DagRun.dag_id == dag_id, DagRun.start_date.isnot(None), DagRun.end_date.isnot(None))
                .order_by(DagRun.logical_date.desc())
            )

            # Consider at most max_runs runs.
            query = query.limit(self.max_runs)

            # Get all durations and calculate the average.
            durations = session.execute(query).scalars().all()

            if len(durations) < cast("int", self.min_runs):
                logger.info(
                    "Only %d completed DAG runs found for dag_id: %s (need %d), skipping deadline creation",
                    len(durations),
                    dag_id,
                    self.min_runs,
                )
                return None

            # Convert to float to handle Decimal types from MySQL while preserving precision:
            # use Decimal arithmetic for the sum, then convert the result to float.
            from decimal import Decimal

            decimal_durations = [Decimal(str(d)) for d in durations]
            avg_seconds = float(sum(decimal_durations) / len(decimal_durations))
            logger.info(
                "Average runtime for dag_id %s (from %d runs): %.2f seconds",
                dag_id,
                len(durations),
                avg_seconds,
            )
            return timezone.utcnow() + timedelta(seconds=avg_seconds)
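        # Worked example (hypothetical numbers): with max_runs=3 and completed runs of
        # 60, 90, and 120 seconds, avg_seconds is (60 + 90 + 120) / 3 = 90.0, so the
        # returned base time is utcnow() + 90 seconds (before evaluate_with adds its interval).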

        def serialize_reference(self) -> dict:
            return {
                ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
                "max_runs": self.max_runs,
                "min_runs": self.min_runs,
            }

        @classmethod
        def deserialize_reference(cls, reference_data: dict):
            max_runs = reference_data.get("max_runs", cls.DEFAULT_LIMIT)
            min_runs = reference_data.get("min_runs", max_runs)
            if min_runs < 1:
                raise ValueError("min_runs must be at least 1")
            return cls(
                max_runs=max_runs,
                min_runs=min_runs,
            )
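        # Serialization sketch: AverageRuntimeDeadline(max_runs=10) round-trips as
        #
        #     {"reference_type": "AverageRuntimeDeadline", "max_runs": 10, "min_runs": 10}
        #
        # and deserialize_reference falls back to DEFAULT_LIMIT when max_runs is absent.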


DeadlineReferenceType = ReferenceModels.BaseDeadlineReference


@provide_session
def _fetch_from_db(model_reference: Mapped, session=None, **conditions) -> datetime | None:
    """
    Fetch a datetime value from the database using the provided model reference and filtering conditions.

    For example, to fetch a TaskInstance's start_date:

        _fetch_from_db(
            TaskInstance.start_date, dag_id='example_dag', task_id='example_task', run_id='example_run'
        )

    This generates SQL equivalent to:

        SELECT start_date
        FROM task_instance
        WHERE dag_id = 'example_dag'
            AND task_id = 'example_task'
            AND run_id = 'example_run'

    :param model_reference: SQLAlchemy Column to select (e.g., DagRun.logical_date, TaskInstance.start_date)
    :param conditions: Filtering conditions applied as equality comparisons in the WHERE clause.
        Multiple conditions are combined with AND.
    :param session: SQLAlchemy session (auto-provided by decorator)
    """
    query = select(model_reference)

    # The inspection result is the same for every condition, so compute it once.
    inspected = inspect(model_reference)
    for key, value in conditions.items():
        if inspected is not None:
            query = query.where(getattr(inspected.class_, key) == value)

    compiled_query = query.compile(compile_kwargs={"literal_binds": True})
    pretty_query = "\n    ".join(str(compiled_query).splitlines())
    logger.debug(
        "Executing query:\n    %r\nAs SQL:\n    %s",
        query,
        pretty_query,
    )

    try:
        result = session.scalar(query)
    except SQLAlchemyError:
        logger.exception("Database query failed.")
        raise

    if result is None:
        message = f"No matching record found in the database for query:\n    {pretty_query}"
        logger.error(message)
        raise ValueError(message)

    return result