Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/models/deadline.py: 42%

211 statements  

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID

import uuid6
from sqlalchemy import Boolean, ForeignKey, Index, Integer, Uuid, and_, func, inspect, select, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Mapped, mapped_column, relationship

from airflow._shared.observability.metrics.stats import Stats
from airflow._shared.timezones import timezone
from airflow.models.base import Base
from airflow.models.callback import Callback
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name

if TYPE_CHECKING:
    from sqlalchemy.orm import Session
    from sqlalchemy.sql import ColumnElement

    from airflow.models.callback import CallbackDefinitionProtocol
    from airflow.models.deadline_alert import DeadlineAlert


logger = logging.getLogger(__name__)

CALLBACK_METRICS_PREFIX = "deadline_alerts"


class classproperty:
    """
    Decorator that converts a method with a single cls argument into a property.

    Mypy won't let us use both @property and @classmethod together, so this is a
    workaround that combines the two.

    Usage:

        class Circle:
            def __init__(self, radius):
                self.radius = radius

            @classproperty
            def pi(cls):
                return 3.14159

        print(Circle.pi)  # Outputs: 3.14159
    """

    def __init__(self, method):
        self.method = method

    def __get__(self, instance, cls=None):
        return self.method(cls)


class Deadline(Base):
    """A Deadline is a 'need-by' date which triggers a callback if the provided time has passed."""

    __tablename__ = "deadline"

    id: Mapped[UUID] = mapped_column(Uuid(), primary_key=True, default=uuid6.uuid7)
    created_at: Mapped[datetime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow)
    last_updated_at: Mapped[datetime] = mapped_column(
        UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow
    )

    # If the Deadline Alert is for a DAG, store the DAG run ID from the dag_run.
    dagrun_id: Mapped[int | None] = mapped_column(
        Integer, ForeignKey("dag_run.id", ondelete="CASCADE"), nullable=True
    )
    dagrun = relationship("DagRun", back_populates="deadlines")

    # The time after which the Deadline has passed and the callback should be triggered.
    deadline_time: Mapped[datetime] = mapped_column(UtcDateTime, nullable=False)

    # Whether the deadline has been marked as missed by the scheduler.
    missed: Mapped[bool] = mapped_column(Boolean, nullable=False)

    # Callback that will run when this deadline is missed.
    callback_id: Mapped[UUID] = mapped_column(
        Uuid(), ForeignKey("callback.id", ondelete="CASCADE"), nullable=False
    )
    callback = relationship("Callback", uselist=False, cascade="all, delete-orphan", single_parent=True)

    # The DeadlineAlert that generated this deadline.
    deadline_alert_id: Mapped[UUID | None] = mapped_column(
        Uuid(), ForeignKey("deadline_alert.id", ondelete="SET NULL"), nullable=True
    )
    deadline_alert: Mapped[DeadlineAlert | None] = relationship("DeadlineAlert")

    __table_args__ = (Index("deadline_missed_deadline_time_idx", missed, deadline_time, unique=False),)

    def __init__(
        self,
        deadline_time: datetime,
        callback: CallbackDefinitionProtocol,
        dagrun_id: int,
        deadline_alert_id: UUID | None,
        dag_id: str | None = None,
    ):
        super().__init__()
        self.deadline_time = deadline_time
        self.dagrun_id = dagrun_id
        self.missed = False
        self.callback = Callback.create_from_sdk_def(
            callback_def=callback, prefix=CALLBACK_METRICS_PREFIX, dag_id=dag_id
        )
        self.deadline_alert_id = deadline_alert_id
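
    # Hedged usage sketch: constructing and persisting a Deadline. Names like `my_callback_def`,
    # `dag_run`, and `session` are hypothetical stand-ins, not defined in this module.
    #
    #     deadline = Deadline(
    #         deadline_time=timezone.utcnow() + timedelta(hours=1),
    #         callback=my_callback_def,  # any CallbackDefinitionProtocol implementation
    #         dagrun_id=dag_run.id,
    #         deadline_alert_id=None,
    #         dag_id=dag_run.dag_id,
    #     )
    #     session.add(deadline)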

    def __repr__(self):
        def _determine_resource() -> tuple[str, str]:
            """Determine the type of resource based on which values are present."""
            if self.dagrun_id:
                # The deadline is for a Dag run:
                return "DagRun", f"Dag: {self.dagrun.dag_id} Run: {self.dagrun_id}"

            return "Unknown", ""

        resource_type, resource_details = _determine_resource()

        return (
            f"[{resource_type} Deadline] "
            f"created at {self.created_at}, "
            f"{resource_details}, "
            f"needed by {self.deadline_time} "
            f"or run: {self.callback}"
        )

    @classmethod
    def prune_deadlines(cls, *, session: Session, conditions: dict[Mapped, Any]) -> int:
        """
        Remove deadlines from the table which match the provided conditions and return the number removed.

        NOTE: This should only be used to remove deadlines which are associated with
        successful events (DagRuns, etc.). If the deadline was missed, it will be
        handled by the scheduler.

        :param conditions: Dictionary of conditions to evaluate against.
        :param session: Session to use.
        """
        from airflow.models import DagRun  # Avoids circular import

        # Assemble the filter conditions.
        filter_conditions = [column == value for column, value in conditions.items()]
        if not filter_conditions:
            return 0

        try:
            # Get deadlines which match the provided conditions and their associated DagRuns.
            deadline_dagrun_pairs = session.execute(
                select(Deadline, DagRun).join(DagRun).where(and_(*filter_conditions))
            ).all()

        except AttributeError as e:
            logger.exception("Error resolving deadlines: %s", e)
            raise

        if not deadline_dagrun_pairs:
            return 0

        deleted_count = 0
        dagruns_to_refresh = set()

        for deadline, dagrun in deadline_dagrun_pairs:
            if dagrun.end_date is not None and dagrun.end_date <= deadline.deadline_time:
                # The DagRun finished before the Deadline passed:
                session.delete(deadline)
                Stats.incr(
                    "deadline_alerts.deadline_not_missed",
                    tags={"dag_id": dagrun.dag_id, "dagrun_id": dagrun.run_id},
                )
                deleted_count += 1
                dagruns_to_refresh.add(dagrun)
        session.flush()

        logger.debug("%d deadline records were deleted matching the conditions %s", deleted_count, conditions)

        # Refresh any affected DAG runs.
        for dagrun in dagruns_to_refresh:
            session.refresh(dagrun)

        return deleted_count
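
    # Hedged call-site sketch: `session` and `dag_run` are assumed to exist at the caller.
    # Conditions are keyed by mapped columns and combined with AND:
    #
    #     removed = Deadline.prune_deadlines(
    #         session=session,
    #         conditions={Deadline.dagrun_id: dag_run.id},
    #     )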

    def handle_miss(self, session: Session):
        """Handle a missed deadline by queueing the callback."""

        def get_simple_context():
            from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
            from airflow.models import DagRun

            # TODO: Use the TaskAPI from within Triggerer to fetch full context instead of sending this
            # context from the scheduler

            # Fetch the DagRun from the database again to avoid errors when self.dagrun's relationship
            # fields are not in the current session.
            dagrun = session.get(DagRun, self.dagrun_id)

            return {
                "dag_run": DAGRunResponse.model_validate(dagrun).model_dump(mode="json"),
                "deadline": {"id": self.id, "deadline_time": self.deadline_time},
            }

        self.callback.data["kwargs"] = self.callback.data["kwargs"] | {"context": get_simple_context()}
        self.missed = True
        self.callback.queue()
        session.add(self)
        Stats.incr(
            "deadline_alerts.deadline_missed",
            tags={"dag_id": self.dagrun.dag_id, "dagrun_id": self.dagrun.run_id},
        )


class ReferenceModels:
    """
    Store the implementations for the different Deadline References.

    After adding the implementations here, all DeadlineReferences should be added
    to the user interface in airflow.sdk.definitions.deadline.DeadlineReference.
    """

    REFERENCE_TYPE_FIELD = "reference_type"

    @classmethod
    def get_reference_class(cls, reference_name: str) -> type[BaseDeadlineReference]:
        """
        Get a reference class by its name.

        :param reference_name: The name of the reference class to find.
        """
        try:
            return next(
                ref_class
                for name, ref_class in vars(cls).items()
                if isinstance(ref_class, type)
                and issubclass(ref_class, cls.BaseDeadlineReference)
                and ref_class.__name__ == reference_name
            )
        except StopIteration:
            raise ValueError(f"No reference class found with name: {reference_name}")
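
    # Illustrative lookup by name (uses the reference classes defined later in this class):
    #
    #     ref_cls = ReferenceModels.get_reference_class("FixedDatetimeDeadline")
    #     assert ref_cls is ReferenceModels.FixedDatetimeDeadline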

    class BaseDeadlineReference(LoggingMixin, ABC):
        """Base class for all Deadline implementations."""

        # Set of required kwargs - subclasses should override this.
        required_kwargs: set[str] = set()

        @classproperty
        def reference_name(cls: Any) -> str:
            return cls.__name__

        def evaluate_with(self, *, session: Session, interval: timedelta, **kwargs: Any) -> datetime | None:
            """Validate the provided kwargs and evaluate this deadline with the given conditions."""
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in self.required_kwargs}

            if missing_kwargs := self.required_kwargs - filtered_kwargs.keys():
                raise ValueError(
                    f"{self.__class__.__name__} is missing required parameters: {', '.join(missing_kwargs)}"
                )

            if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
                self.log.debug("Ignoring unexpected parameters: %s", ", ".join(extra_kwargs))

            base_time = self._evaluate_with(session=session, **filtered_kwargs)
            return base_time + interval if base_time is not None else None
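
        # Hedged evaluation sketch: `session` and the dag_id/run_id values are illustrative.
        #
        #     ref = ReferenceModels.DagRunQueuedAtDeadline()
        #     deadline_time = ref.evaluate_with(
        #         session=session,
        #         interval=timedelta(minutes=30),
        #         dag_id="example_dag",
        #         run_id="example_run",
        #     )  # -> queued_at + 30 minutes, or None if no base time is available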

        @abstractmethod
        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            """Must be implemented by subclasses to perform the actual evaluation."""
            raise NotImplementedError

        @classmethod
        def deserialize_reference(cls, reference_data: dict):
            """
            Deserialize a reference type from its dictionary representation.

            While the base implementation doesn't use reference_data, this parameter is required
            for subclasses that need additional data for initialization (like FixedDatetimeDeadline
            which needs a datetime value).

            :param reference_data: Dictionary containing serialized reference data.
                Always includes a 'reference_type' field, and may include additional
                fields needed by specific reference implementations.
            """
            return cls()

        def serialize_reference(self) -> dict:
            """
            Serialize this reference type into a dictionary representation.

            This method assumes that the reference doesn't require any additional data.
            Override this method in subclasses if additional data is needed for serialization.
            """
            return {ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name}

    @dataclass
    class FixedDatetimeDeadline(BaseDeadlineReference):
        """A deadline that always returns a fixed datetime."""

        _datetime: datetime

        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            return self._datetime

        def serialize_reference(self) -> dict:
            return {
                ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
                "datetime": self._datetime.timestamp(),
            }

        @classmethod
        def deserialize_reference(cls, reference_data: dict):
            return cls(_datetime=timezone.from_timestamp(reference_data["datetime"]))
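
        # Hedged round-trip sketch (the epoch value is illustrative):
        #
        #     ref = ReferenceModels.FixedDatetimeDeadline(
        #         _datetime=timezone.from_timestamp(1735689600.0)
        #     )
        #     data = ref.serialize_reference()
        #     # -> {"reference_type": "FixedDatetimeDeadline", "datetime": 1735689600.0}
        #     same = ReferenceModels.FixedDatetimeDeadline.deserialize_reference(data)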

    class DagRunLogicalDateDeadline(BaseDeadlineReference):
        """A deadline that returns a DagRun's logical date."""

        required_kwargs = {"dag_id", "run_id"}

        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            from airflow.models import DagRun

            return _fetch_from_db(DagRun.logical_date, session=session, **kwargs)

    class DagRunQueuedAtDeadline(BaseDeadlineReference):
        """A deadline that returns when a DagRun was queued."""

        required_kwargs = {"dag_id", "run_id"}

        @provide_session
        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            from airflow.models import DagRun

            return _fetch_from_db(DagRun.queued_at, session=session, **kwargs)

    @dataclass
    class AverageRuntimeDeadline(BaseDeadlineReference):
        """A deadline that calculates the average runtime from past DAG runs."""

        DEFAULT_LIMIT = 10
        max_runs: int
        min_runs: int | None = None
        required_kwargs = {"dag_id"}

        def __post_init__(self):
            if self.min_runs is None:
                self.min_runs = self.max_runs
            if self.min_runs < 1:
                raise ValueError("min_runs must be at least 1")

        @provide_session
        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
            from airflow.models import DagRun

            dag_id = kwargs["dag_id"]

            # Get database dialect to use appropriate time difference calculation
            dialect = get_dialect_name(session)

            # Create database-specific expression for calculating duration in seconds
            duration_expr: ColumnElement[Any]
            if dialect == "postgresql":
                duration_expr = func.extract("epoch", DagRun.end_date - DagRun.start_date)
            elif dialect == "mysql":
                # Use TIMESTAMPDIFF to get exact seconds like PostgreSQL EXTRACT(epoch FROM ...)
                duration_expr = func.timestampdiff(text("SECOND"), DagRun.start_date, DagRun.end_date)
            elif dialect == "sqlite":
                duration_expr = (func.julianday(DagRun.end_date) - func.julianday(DagRun.start_date)) * 86400
            else:
                raise ValueError(f"Unsupported database dialect: {dialect}")

            # Query for completed DAG runs with both start and end dates
            # Order by logical_date descending to get most recent runs first
            query = (
                select(duration_expr)
                .filter(DagRun.dag_id == dag_id, DagRun.start_date.isnot(None), DagRun.end_date.isnot(None))
                .order_by(DagRun.logical_date.desc())
            )

            # Apply max_runs
            query = query.limit(self.max_runs)

            # Get all durations and calculate average
            durations: Sequence = session.execute(query).scalars().all()

            if len(durations) < cast("int", self.min_runs):
                logger.info(
                    "Only %d completed DAG runs found for dag_id: %s (need %d), skipping deadline creation",
                    len(durations),
                    dag_id,
                    self.min_runs,
                )
                return None
            # Convert to float to handle Decimal types from MySQL while preserving precision
            # Use Decimal arithmetic for higher precision, then convert to float
            from decimal import Decimal

            decimal_durations = [Decimal(str(d)) for d in durations]
            avg_seconds = float(sum(decimal_durations) / len(decimal_durations))
            logger.info(
                "Average runtime for dag_id %s (from %d runs): %.2f seconds",
                dag_id,
                len(durations),
                avg_seconds,
            )
            return timezone.utcnow() + timedelta(seconds=avg_seconds)
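
        # Worked illustration (numbers are hypothetical): with the two most recent runs taking
        # 280s and 320s, avg_seconds is 300.0, so the returned base time is utcnow() + 300s;
        # evaluate_with() then adds the caller-provided interval on top of that base time.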

        def serialize_reference(self) -> dict:
            return {
                ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
                "max_runs": self.max_runs,
                "min_runs": self.min_runs,
            }

        @classmethod
        def deserialize_reference(cls, reference_data: dict):
            max_runs = reference_data.get("max_runs", cls.DEFAULT_LIMIT)
            min_runs = reference_data.get("min_runs", max_runs)
            if min_runs < 1:
                raise ValueError("min_runs must be at least 1")
            return cls(
                max_runs=max_runs,
                min_runs=min_runs,
            )


DeadlineReferenceType = ReferenceModels.BaseDeadlineReference


@provide_session
def _fetch_from_db(model_reference: Mapped, session=None, **conditions) -> datetime | None:
    """
    Fetch a datetime value from the database using the provided model reference and filtering conditions.

    For example, to fetch a TaskInstance's start_date:

        _fetch_from_db(
            TaskInstance.start_date, dag_id='example_dag', task_id='example_task', run_id='example_run'
        )

    This generates SQL equivalent to:

        SELECT start_date
        FROM task_instance
        WHERE dag_id = 'example_dag'
            AND task_id = 'example_task'
            AND run_id = 'example_run'

    :param model_reference: SQLAlchemy Column to select (e.g., DagRun.logical_date, TaskInstance.start_date)
    :param conditions: Filtering conditions applied as equality comparisons in the WHERE clause.
        Multiple conditions are combined with AND.
    :param session: SQLAlchemy session (auto-provided by decorator)
    """
    query = select(model_reference)

    for key, value in conditions.items():
        inspected = inspect(model_reference)
        if inspected is not None:
            query = query.where(getattr(inspected.class_, key) == value)

    compiled_query = query.compile(compile_kwargs={"literal_binds": True})
    pretty_query = "\n ".join(str(compiled_query).splitlines())
    logger.debug(
        "Executing query:\n %r\nAs SQL:\n %s",
        query,
        pretty_query,
    )

    try:
        result = session.scalar(query)
    except SQLAlchemyError:
        logger.exception("Database query failed.")
        raise

    if result is None:
        message = f"No matching record found in the database for query:\n {pretty_query}"
        logger.error(message)
        raise ValueError(message)

    return result