#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import itertools
import os
import warnings
from collections import defaultdict
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload

from sqlalchemy import (
    Boolean,
    Column,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PickleType,
    PrimaryKeyConstraint,
    String,
    Text,
    UniqueConstraint,
    and_,
    func,
    or_,
    text,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import joinedload, relationship, synonym
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import false, select, true

from airflow import settings
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound
from airflow.listeners.listener import get_listener_manager
from airflow.models.abstractoperator import NotMapped
from airflow.models.base import Base, StringID
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.tasklog import LogTemplate
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
from airflow.typing_compat import Literal
from airflow.utils import timezone
from airflow.utils.helpers import is_container
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, tuple_in_condition, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import NOTSET, ArgNotSet, DagRunType

if TYPE_CHECKING:
    from airflow.models.dag import DAG
    from airflow.models.operator import Operator

    CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
    TaskCreator = Callable[[Operator, Iterable[int]], CreatedTasks]


class TISchedulingDecision(NamedTuple):
    """Type of return for DagRun.task_instance_scheduling_decisions"""

    tis: list[TI]
    schedulable_tis: list[TI]
    changed_tis: bool
    unfinished_tis: list[TI]
    finished_tis: list[TI]


def _creator_note(val):
    """Custom creator for the ``note`` association proxy."""
    if isinstance(val, str):
        return DagRunNote(content=val)
    elif isinstance(val, dict):
        return DagRunNote(**val)
    else:
        return DagRunNote(*val)
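

# Illustration only (not part of the upstream module): a minimal sketch of the
# three value shapes ``_creator_note`` accepts when a value is assigned through
# the ``DagRun.note`` association proxy below. ``dag_run`` is assumed to be a
# session-attached DagRun instance.
def _example_set_note(dag_run: DagRun) -> None:
    dag_run.note = "cleared by ops"                 # str   -> DagRunNote(content=val)
    dag_run.note = {"content": "x", "user_id": 42}  # dict  -> DagRunNote(**val)
    dag_run.note = ("y", 42)                        # tuple -> DagRunNote(*val)
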

class DagRun(Base, LoggingMixin):
    """
    A DagRun describes an instance of a DAG. It can be created
    by the scheduler (for regular runs) or by an external trigger.
    """

    __tablename__ = "dag_run"

    id = Column(Integer, primary_key=True)
    dag_id = Column(StringID(), nullable=False)
    queued_at = Column(UtcDateTime)
    execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    start_date = Column(UtcDateTime)
    end_date = Column(UtcDateTime)
    _state = Column("state", String(50), default=State.QUEUED)
    run_id = Column(StringID(), nullable=False)
    creating_job_id = Column(Integer)
    external_trigger = Column(Boolean, default=True)
    run_type = Column(String(50), nullable=False)
    conf = Column(PickleType)
    # These two must be either both NULL or both datetime.
    data_interval_start = Column(UtcDateTime)
    data_interval_end = Column(UtcDateTime)
    # When a scheduler last attempted to schedule TIs for this DagRun
    last_scheduling_decision = Column(UtcDateTime)
    dag_hash = Column(String(32))
    # Foreign key to LogTemplate. DagRun rows created prior to this column's
    # existence have this set to NULL. Later rows automatically populate this on
    # insert to point to the latest LogTemplate entry.
    log_template_id = Column(
        Integer,
        ForeignKey("log_template.id", name="task_instance_log_template_id_fkey", ondelete="NO ACTION"),
        default=select([func.max(LogTemplate.__table__.c.id)]),
    )
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)

    # Remove this `if` after upgrading Sphinx-AutoAPI
    if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
        dag: DAG | None
    else:
        dag: DAG | None = None

    __table_args__ = (
        Index("dag_id_state", dag_id, _state),
        UniqueConstraint("dag_id", "execution_date", name="dag_run_dag_id_execution_date_key"),
        UniqueConstraint("dag_id", "run_id", name="dag_run_dag_id_run_id_key"),
        Index("idx_last_scheduling_decision", last_scheduling_decision),
        Index("idx_dag_run_dag_id", dag_id),
        Index(
            "idx_dag_run_running_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='running'"),
            mssql_where=text("state='running'"),
            sqlite_where=text("state='running'"),
        ),
        # since mysql lacks filtered/partial indices, this creates a
        # duplicate index on mysql. Not the end of the world.
        Index(
            "idx_dag_run_queued_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='queued'"),
            mssql_where=text("state='queued'"),
            sqlite_where=text("state='queued'"),
        ),
    )

    task_instances = relationship(
        TI, back_populates="dag_run", cascade="save-update, merge, delete, delete-orphan"
    )
    dag_model = relationship(
        "DagModel",
        primaryjoin="foreign(DagRun.dag_id) == DagModel.dag_id",
        uselist=False,
        viewonly=True,
    )
    dag_run_note = relationship("DagRunNote", back_populates="dag_run", uselist=False)
    note = association_proxy("dag_run_note", "content", creator=_creator_note)

    DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
        "scheduler",
        "max_dagruns_per_loop_to_schedule",
        fallback=20,
    )

    def __init__(
        self,
        dag_id: str | None = None,
        run_id: str | None = None,
        queued_at: datetime | None | ArgNotSet = NOTSET,
        execution_date: datetime | None = None,
        start_date: datetime | None = None,
        external_trigger: bool | None = None,
        conf: Any | None = None,
        state: DagRunState | None = None,
        run_type: str | None = None,
        dag_hash: str | None = None,
        creating_job_id: int | None = None,
        data_interval: tuple[datetime, datetime] | None = None,
    ):
        if data_interval is None:
            # Legacy: Only happens for runs created prior to Airflow 2.2.
            self.data_interval_start = self.data_interval_end = None
        else:
            self.data_interval_start, self.data_interval_end = data_interval

        self.dag_id = dag_id
        self.run_id = run_id
        self.execution_date = execution_date
        self.start_date = start_date
        self.external_trigger = external_trigger
        self.conf = conf or {}
        if state is not None:
            self.state = state
        if queued_at is NOTSET:
            self.queued_at = timezone.utcnow() if state == State.QUEUED else None
        else:
            self.queued_at = queued_at
        self.run_type = run_type
        self.dag_hash = dag_hash
        self.creating_job_id = creating_job_id
        super().__init__()

    def __repr__(self):
        return (
            "<DagRun {dag_id} @ {execution_date}: {run_id}, state:{state}, "
            "queued_at: {queued_at}. externally triggered: {external_trigger}>"
        ).format(
            dag_id=self.dag_id,
            execution_date=self.execution_date,
            run_id=self.run_id,
            state=self.state,
            queued_at=self.queued_at,
            external_trigger=self.external_trigger,
        )

    @property
    def logical_date(self) -> datetime:
        return self.execution_date

    def get_state(self):
        return self._state

    def set_state(self, state: DagRunState):
        if state not in State.dag_states:
            raise ValueError(f"invalid DagRun state: {state}")
        if self._state != state:
            self._state = state
            self.end_date = timezone.utcnow() if self._state in State.finished else None
            if state == State.QUEUED:
                self.queued_at = timezone.utcnow()

    @declared_attr
    def state(self):
        return synonym("_state", descriptor=property(self.get_state, self.set_state))
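
    # Behavioral sketch (illustrative comment, not upstream code): the ``state``
    # synonym above routes assignments through ``set_state``, so, for example:
    #
    #     dag_run.state = DagRunState.QUEUED   # also refreshes dag_run.queued_at
    #     dag_run.state = DagRunState.SUCCESS  # a finished state; sets end_date
    #     dag_run.state = "not-a-state"        # raises ValueError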

    @provide_session
    def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
        """
        Reload the current dagrun from the database.

        :param session: database session
        """
        dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one()
        self.id = dr.id
        self.state = dr.state

    @classmethod
    @provide_session
    def active_runs_of_dags(cls, dag_ids=None, only_running=False, session=None) -> dict[str, int]:
        """Get the number of active dag runs for each dag."""
        query = session.query(cls.dag_id, func.count("*"))
        if dag_ids is not None:
            # 'set' called to avoid duplicate dag_ids, but converted back to 'list'
            # because SQLAlchemy doesn't accept a set here.
            query = query.filter(cls.dag_id.in_(list(set(dag_ids))))
        if only_running:
            query = query.filter(cls.state == State.RUNNING)
        else:
            query = query.filter(cls.state.in_([State.RUNNING, State.QUEUED]))
        query = query.group_by(cls.dag_id)
        return {dag_id: count for dag_id, count in query.all()}

    @classmethod
    def next_dagruns_to_examine(
        cls,
        state: DagRunState,
        session: Session,
        max_number: int | None = None,
    ) -> list[DagRun]:
        """
        Return the next DagRuns that the scheduler should attempt to schedule.

        This will return zero or more DagRun rows that are row-level-locked with a
        "SELECT ... FOR UPDATE" query; you should ensure that any scheduling decisions
        are made in a single transaction -- as soon as the transaction is committed it
        will be unlocked.
        """
        from airflow.models.dag import DagModel

        if max_number is None:
            max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

        # TODO: Bake this query, it is run _A lot_
        query = (
            session.query(cls)
            .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
            .filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
            .join(DagModel, DagModel.dag_id == cls.dag_id)
            .filter(DagModel.is_paused == false(), DagModel.is_active == true())
        )
        if state == State.QUEUED:
            # For dag runs in the queued state, we check if they have reached the max_active_runs limit
            # and if so we drop them
            running_drs = (
                session.query(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
                .filter(DagRun.state == DagRunState.RUNNING)
                .group_by(DagRun.dag_id)
                .subquery()
            )
            query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).filter(
                func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs
            )
        query = query.order_by(
            nulls_first(cls.last_scheduling_decision, session=session),
            cls.execution_date,
        )

        if not settings.ALLOW_FUTURE_EXEC_DATES:
            query = query.filter(DagRun.execution_date <= func.now())

        return with_row_locks(
            query.limit(max_number), of=cls, session=session, **skip_locked(session=session)
        )
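
    # Usage sketch (illustrative comment; the caller shown follows the
    # docstring's locking contract and is assumed, not copied from a specific
    # call site): the rows returned above are locked with SELECT ... FOR UPDATE,
    # so keep all scheduling decisions inside one transaction:
    #
    #     with create_session() as session:
    #         for run in DagRun.next_dagruns_to_examine(DagRunState.QUEUED, session):
    #             ...  # make scheduling decisions while the row locks are held
    #     # locks are released when the transaction commits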

    @classmethod
    @provide_session
    def find(
        cls,
        dag_id: str | list[str] | None = None,
        run_id: Iterable[str] | None = None,
        execution_date: datetime | Iterable[datetime] | None = None,
        state: DagRunState | None = None,
        external_trigger: bool | None = None,
        no_backfills: bool = False,
        run_type: DagRunType | None = None,
        session: Session = NEW_SESSION,
        execution_start_date: datetime | None = None,
        execution_end_date: datetime | None = None,
    ) -> list[DagRun]:
        """
        Return a list of dag runs for the given search criteria.

        :param dag_id: the dag_id or list of dag_id to find dag runs for
        :param run_id: defines the run id for this dag run
        :param run_type: type of DagRun
        :param execution_date: the execution date
        :param state: the state of the dag run
        :param external_trigger: whether this dag run is externally triggered
        :param no_backfills: return no backfills (True), return all (False).
            Defaults to False
        :param session: database session
        :param execution_start_date: only include dag runs executed on or after this date
        :param execution_end_date: only include dag runs executed on or before this date
        """
        qry = session.query(cls)
        dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
        if dag_ids:
            qry = qry.filter(cls.dag_id.in_(dag_ids))

        if is_container(run_id):
            qry = qry.filter(cls.run_id.in_(run_id))
        elif run_id is not None:
            qry = qry.filter(cls.run_id == run_id)
        if is_container(execution_date):
            qry = qry.filter(cls.execution_date.in_(execution_date))
        elif execution_date is not None:
            qry = qry.filter(cls.execution_date == execution_date)
        if execution_start_date and execution_end_date:
            qry = qry.filter(cls.execution_date.between(execution_start_date, execution_end_date))
        elif execution_start_date:
            qry = qry.filter(cls.execution_date >= execution_start_date)
        elif execution_end_date:
            qry = qry.filter(cls.execution_date <= execution_end_date)
        if state:
            qry = qry.filter(cls.state == state)
        if external_trigger is not None:
            qry = qry.filter(cls.external_trigger == external_trigger)
        if run_type:
            qry = qry.filter(cls.run_type == run_type)
        if no_backfills:
            qry = qry.filter(cls.run_type != DagRunType.BACKFILL_JOB)

        return qry.order_by(cls.execution_date).all()
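
    # Usage sketch (illustrative comment; the dag id and date names are
    # assumed): the filters in ``find`` compose, so a window query looks like
    #
    #     runs = DagRun.find(
    #         dag_id="example_dag",
    #         external_trigger=True,
    #         execution_start_date=window_start,
    #         execution_end_date=window_end,
    #     )  # ordered by execution_date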

    @classmethod
    @provide_session
    def find_duplicate(
        cls,
        dag_id: str,
        run_id: str,
        execution_date: datetime,
        session: Session = NEW_SESSION,
    ) -> DagRun | None:
        """
        Return an existing run for the DAG with a specific run_id or execution_date.

        *None* is returned if no such DAG run is found.

        :param dag_id: the dag_id to find duplicates for
        :param run_id: defines the run id for this dag run
        :param execution_date: the execution date
        :param session: database session
        """
        return (
            session.query(cls)
            .filter(
                cls.dag_id == dag_id,
                or_(cls.run_id == run_id, cls.execution_date == execution_date),
            )
            .one_or_none()
        )

    @staticmethod
    def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
        """Generate Run ID based on Run Type and Execution Date"""
        # _Ensure_ run_type is a DagRunType, not just a string from user code
        return DagRunType(run_type).generate_run_id(execution_date)
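
    # Example (illustrative comment): DagRunType.generate_run_id builds ids of
    # the form "<run_type>__<execution_date.isoformat()>", so one would expect
    #
    #     DagRun.generate_run_id(DagRunType.SCHEDULED, timezone.datetime(2022, 1, 1))
    #     # -> "scheduled__2022-01-01T00:00:00+00:00"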

    @provide_session
    def get_task_instances(
        self,
        state: Iterable[TaskInstanceState | None] | None = None,
        session: Session = NEW_SESSION,
    ) -> list[TI]:
        """Return the task instances for this dag run."""
        tis = (
            session.query(TI)
            .options(joinedload(TI.dag_run))
            .filter(
                TI.dag_id == self.dag_id,
                TI.run_id == self.run_id,
            )
        )

        if state:
            if isinstance(state, str):
                tis = tis.filter(TI.state == state)
            else:
                # this is required to deal with NULL values
                if State.NONE in state:
                    if all(x is None for x in state):
                        tis = tis.filter(TI.state.is_(None))
                    else:
                        not_none_state = [s for s in state if s]
                        tis = tis.filter(or_(TI.state.in_(not_none_state), TI.state.is_(None)))
                else:
                    tis = tis.filter(TI.state.in_(state))

        if self.dag and self.dag.partial:
            tis = tis.filter(TI.task_id.in_(self.dag.task_ids))
        return tis.all()

    @provide_session
    def get_task_instance(
        self,
        task_id: str,
        session: Session = NEW_SESSION,
        *,
        map_index: int = -1,
    ) -> TI | None:
        """
        Return the task instance specified by task_id for this dag run.

        :param task_id: the task id
        :param session: Sqlalchemy ORM Session
        """
        return (
            session.query(TI)
            .filter_by(dag_id=self.dag_id, run_id=self.run_id, task_id=task_id, map_index=map_index)
            .one_or_none()
        )

    def get_dag(self) -> DAG:
        """
        Return the DAG associated with this DagRun.

        :return: DAG
        """
        if not self.dag:
            raise AirflowException(f"The DAG (.dag) for {self} needs to be set")

        return self.dag

    @provide_session
    def get_previous_dagrun(
        self, state: DagRunState | None = None, session: Session = NEW_SESSION
    ) -> DagRun | None:
        """The previous DagRun, if there is one."""
        filters = [
            DagRun.dag_id == self.dag_id,
            DagRun.execution_date < self.execution_date,
        ]
        if state is not None:
            filters.append(DagRun.state == state)
        return session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first()

    @provide_session
    def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> DagRun | None:
        """The previous SCHEDULED DagRun, if there is one."""
        return (
            session.query(DagRun)
            .filter(
                DagRun.dag_id == self.dag_id,
                DagRun.execution_date < self.execution_date,
                DagRun.run_type != DagRunType.MANUAL,
            )
            .order_by(DagRun.execution_date.desc())
            .first()
        )

    @provide_session
    def update_state(
        self, session: Session = NEW_SESSION, execute_callbacks: bool = True
    ) -> tuple[list[TI], DagCallbackRequest | None]:
        """
        Determine the overall state of the DagRun based on the state
        of its TaskInstances.

        :param session: Sqlalchemy ORM Session
        :param execute_callbacks: Should dag callbacks (success/failure, SLA etc.) be invoked
            directly (default: true) or recorded as a pending request in the ``returned_callback`` property
        :return: Tuple containing tis that can be scheduled in the current loop & `returned_callback` that
            needs to be executed
        """
        # Callback to execute in case of Task Failures
        callback: DagCallbackRequest | None = None

        class _UnfinishedStates(NamedTuple):
            tis: Sequence[TI]

            @classmethod
            def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
                return cls(tis=unfinished_tis)

            @property
            def should_schedule(self) -> bool:
                return (
                    bool(self.tis)
                    and all(not t.task.depends_on_past for t in self.tis)
                    and all(t.task.max_active_tis_per_dag is None for t in self.tis)
                    and all(t.state != TaskInstanceState.DEFERRED for t in self.tis)
                )

            def recalculate(self) -> _UnfinishedStates:
                return self._replace(tis=[t for t in self.tis if t.state in State.unfinished])

        start_dttm = timezone.utcnow()
        self.last_scheduling_decision = start_dttm
        with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"):
            dag = self.get_dag()
            info = self.task_instance_scheduling_decisions(session)

            tis = info.tis
            schedulable_tis = info.schedulable_tis
            changed_tis = info.changed_tis
            finished_tis = info.finished_tis
            unfinished = _UnfinishedStates.calculate(info.unfinished_tis)

            if unfinished.should_schedule:
                are_runnable_tasks = schedulable_tis or changed_tis
                # small speed up
                if not are_runnable_tasks:
                    are_runnable_tasks, changed_by_upstream = self._are_premature_tis(
                        unfinished.tis, finished_tis, session
                    )
                    if changed_by_upstream:  # Something changed, we need to recalculate!
                        unfinished = unfinished.recalculate()

        leaf_task_ids = {t.task_id for t in dag.leaves}
        leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids if ti.state != TaskInstanceState.REMOVED]

        # if all roots finished and at least one failed, the run failed
        if not unfinished.tis and any(leaf_ti.state in State.failed_states for leaf_ti in leaf_tis):
            self.log.error("Marking run %s failed", self)
            self.set_state(DagRunState.FAILED)
            self.notify_dagrun_state_changed(msg="task_failure")

            if execute_callbacks:
                dag.handle_callback(self, success=False, reason="task_failure", session=session)
            elif dag.has_on_failure_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    processor_subdir=dag_model.processor_subdir,
                    msg="task_failure",
                )

        # if all leaves succeeded and no unfinished tasks, the run succeeded
        elif not unfinished.tis and all(leaf_ti.state in State.success_states for leaf_ti in leaf_tis):
            self.log.info("Marking run %s successful", self)
            self.set_state(DagRunState.SUCCESS)
            self.notify_dagrun_state_changed(msg="success")

            if execute_callbacks:
                dag.handle_callback(self, success=True, reason="success", session=session)
            elif dag.has_on_success_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=False,
                    processor_subdir=dag_model.processor_subdir,
                    msg="success",
                )

        # if *all tasks* are deadlocked, the run failed
        elif unfinished.should_schedule and not are_runnable_tasks:
            self.log.error("Task deadlock (no runnable tasks); marking run %s failed", self)
            self.set_state(DagRunState.FAILED)
            self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")

            if execute_callbacks:
                dag.handle_callback(self, success=False, reason="all_tasks_deadlocked", session=session)
            elif dag.has_on_failure_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    processor_subdir=dag_model.processor_subdir,
                    msg="all_tasks_deadlocked",
                )

        # finally, if the roots aren't done, the dag is still running
        else:
            self.set_state(DagRunState.RUNNING)

        if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
            msg = (
                "DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
                "run_start_date=%s, run_end_date=%s, run_duration=%s, "
                "state=%s, external_trigger=%s, run_type=%s, "
                "data_interval_start=%s, data_interval_end=%s, dag_hash=%s"
            )
            self.log.info(
                msg,
                self.dag_id,
                self.execution_date,
                self.run_id,
                self.start_date,
                self.end_date,
                (self.end_date - self.start_date).total_seconds()
                if self.start_date and self.end_date
                else None,
                self._state,
                self.external_trigger,
                self.run_type,
                self.data_interval_start,
                self.data_interval_end,
                self.dag_hash,
            )
            session.flush()

        self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis)
        self._emit_duration_stats_for_finished_state()

        session.merge(self)
        # We do not flush here for performance reasons (it increases the query count by ~20).

        return schedulable_tis, callback
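
    # Caller sketch (illustrative comment; mirrors how a scheduler-style caller
    # would consume the return value, names assumed): with
    # ``execute_callbacks=False`` the pending callback request is returned
    # instead of being run inline:
    #
    #     schedulable_tis, callback = dag_run.update_state(
    #         session=session, execute_callbacks=False
    #     )
    #     if callback is not None:
    #         ...  # hand the DagCallbackRequest off to the DAG processor
    #     dag_run.schedule_tis(schedulable_tis, session=session)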

    @provide_session
    def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
        tis = self.get_task_instances(session=session, state=State.task_states)
        self.log.debug("number of tis for %s: %s task(s)", self, len(tis))

        def _filter_tis_and_exclude_removed(dag: DAG, tis: list[TI]) -> Iterable[TI]:
            """Populate ``ti.task``, excluding (and marking as REMOVED) TIs whose task is missing."""
            for ti in tis:
                try:
                    ti.task = dag.get_task(ti.task_id)
                except TaskNotFound:
                    if ti.state != State.REMOVED:
                        self.log.error("Failed to get task for ti %s. Marking it as removed.", ti)
                        ti.state = State.REMOVED
                        session.flush()
                else:
                    yield ti

        tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))

        unfinished_tis = [t for t in tis if t.state in State.unfinished]
        finished_tis = [t for t in tis if t.state in State.finished]
        if unfinished_tis:
            schedulable_tis = [ut for ut in unfinished_tis if ut.state in SCHEDULEABLE_STATES]
            self.log.debug("number of schedulable tasks for %s: %s task(s)", self, len(schedulable_tis))
            schedulable_tis, changed_tis, expansion_happened = self._get_ready_tis(
                schedulable_tis,
                finished_tis,
                session=session,
            )

            # During expansion we may change some tis into non-schedulable
            # states, so we need to re-compute.
            if expansion_happened:
                changed_tis = True
                new_unfinished_tis = [t for t in unfinished_tis if t.state in State.unfinished]
                finished_tis.extend(t for t in unfinished_tis if t.state in State.finished)
                unfinished_tis = new_unfinished_tis
        else:
            schedulable_tis = []
            changed_tis = False

        return TISchedulingDecision(
            tis=tis,
            schedulable_tis=schedulable_tis,
            changed_tis=changed_tis,
            unfinished_tis=unfinished_tis,
            finished_tis=finished_tis,
        )

    def notify_dagrun_state_changed(self, msg: str = ""):
        if self.state == DagRunState.RUNNING:
            get_listener_manager().hook.on_dag_run_running(dag_run=self, msg=msg)
        elif self.state == DagRunState.SUCCESS:
            get_listener_manager().hook.on_dag_run_success(dag_run=self, msg=msg)
        elif self.state == DagRunState.FAILED:
            get_listener_manager().hook.on_dag_run_failed(dag_run=self, msg=msg)
        # deliberately not notifying on QUEUED
        # we can't get all the state changes on SchedulerJob, BackfillJob
        # or LocalTaskJob, so we don't want to "falsely advertise" we notify about that

    def _get_ready_tis(
        self,
        schedulable_tis: list[TI],
        finished_tis: list[TI],
        session: Session,
    ) -> tuple[list[TI], bool, bool]:
        old_states = {}
        ready_tis: list[TI] = []
        changed_tis = False

        if not schedulable_tis:
            return ready_tis, changed_tis, False

        # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
        # `schedulable_tis` in place and have the `for` loop pick them up.)
        additional_tis: list[TI] = []
        dep_context = DepContext(
            flag_upstream_failed=True,
            ignore_unmapped_tasks=True,  # Ignore this Dep, as we will expand it if we can.
            finished_tis=finished_tis,
        )

        def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
            """Try to expand the ti, if needed.

            If the ti needs expansion, newly created task instances are
            returned as well as the original ti.
            The original ti is also modified in-place and assigned the
            ``map_index`` of 0.

            If the ti does not need expansion, either because the task is not
            mapped, or has already been expanded, *None* is returned.
            """
            if ti.map_index >= 0:  # Already expanded, we're good.
                return None
            try:
                expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session)
            except NotMapped:  # Not a mapped task, nothing needed.
                return None
            if expanded_tis:
                return expanded_tis
            return ()

        # Check dependencies.
        expansion_happened = False
        for schedulable in itertools.chain(schedulable_tis, additional_tis):
            old_state = schedulable.state
            if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
                old_states[schedulable.key] = old_state
                continue
            # If schedulable is not yet expanded, try doing it now. This is
            # called in two places: First and ideally in the mini scheduler at
            # the end of LocalTaskJob, and then as an "expansion of last resort"
            # in the scheduler to ensure that the mapped task is correctly
            # expanded before executed. Also see _revise_map_indexes_if_mapped
            # docstring for additional information.
            new_tis = None
            if schedulable.map_index < 0:
                new_tis = _expand_mapped_task_if_needed(schedulable)
                if new_tis is not None:
                    additional_tis.extend(new_tis)
                    expansion_happened = True
            if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
                ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
                ready_tis.append(schedulable)

        # Check if any ti changed state
        tis_filter = TI.filter_for_tis(old_states)
        if tis_filter is not None:
            fresh_tis = session.query(TI).filter(tis_filter).all()
            changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)

        return ready_tis, changed_tis, expansion_happened

    def _are_premature_tis(
        self,
        unfinished_tis: Sequence[TI],
        finished_tis: list[TI],
        session: Session,
    ) -> tuple[bool, bool]:
        dep_context = DepContext(
            flag_upstream_failed=True,
            ignore_in_retry_period=True,
            ignore_in_reschedule_period=True,
            finished_tis=finished_tis,
        )
        # there might be runnable tasks that are up for retry and, for some reason (retry delay, etc.),
        # are not ready yet, so we set the flags to count them in
        return (
            any(ut.are_dependencies_met(dep_context=dep_context, session=session) for ut in unfinished_tis),
            dep_context.have_changed_ti_states,
        )

    def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: list[TI]) -> None:
        """
        Emit the true scheduling delay stat, which is defined as the time when
        the first task in the DAG starts minus the expected DAG run datetime.

        This helper is used by ``update_state`` when the state of the DagRun is
        updated to a completed status (either success or failure). It finds the
        first started task within the DAG, calculates the expected DagRun start
        time (based on dag.execution_date & dag.timetable), and subtracts the
        two values to get the delay. The emitted data may contain outliers
        (e.g. when the first task was cleared, so the second task's start_date
        will be used), but we can get rid of the outliers on the stats side
        through the dashboard tooling built on top.

        Note: the stat will only be emitted if the DagRun is a
        scheduler-triggered one (i.e. external_trigger is False).
        """
        if self.state == State.RUNNING:
            return
        if self.external_trigger:
            return
        if not finished_tis:
            return

        try:
            dag = self.get_dag()

            if not dag.timetable.periodic:
                # We can't emit this metric if there is no following schedule to calculate from!
                return

            ordered_tis_by_start_date = [ti for ti in finished_tis if ti.start_date]
            ordered_tis_by_start_date.sort(key=lambda ti: ti.start_date, reverse=False)
            first_start_date = ordered_tis_by_start_date[0].start_date if ordered_tis_by_start_date else None
            if first_start_date:
                # TODO: Logically, this should be DagRunInfo.run_after, but the
                # information is not stored on a DagRun, only before the actual
                # execution on DagModel.next_dagrun_create_after. We should add
                # a field on DagRun for this instead of relying on the run
                # always happening immediately after the data interval.
                data_interval_end = dag.get_run_data_interval(self).end
                true_delay = first_start_date - data_interval_end
                if true_delay.total_seconds() > 0:
                    Stats.timing(f"dagrun.{dag.dag_id}.first_task_scheduling_delay", true_delay)
        except Exception:
            self.log.warning("Failed to record first_task_scheduling_delay metric:", exc_info=True)

    def _emit_duration_stats_for_finished_state(self):
        if self.state == State.RUNNING:
            return
        if self.start_date is None:
            self.log.warning("Failed to record duration of %s: start_date is not set.", self)
            return
        if self.end_date is None:
            self.log.warning("Failed to record duration of %s: end_date is not set.", self)
            return

        duration = self.end_date - self.start_date
        if self.state == State.SUCCESS:
            Stats.timing(f"dagrun.duration.success.{self.dag_id}", duration)
        elif self.state == State.FAILED:
            Stats.timing(f"dagrun.duration.failed.{self.dag_id}", duration)

    @provide_session
    def verify_integrity(self, *, session: Session = NEW_SESSION) -> None:
        """
        Verify the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: Sqlalchemy ORM Session
        """
        from airflow.settings import task_instance_mutation_hook

        # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
        # Note: Literal[True, False] instead of bool because otherwise it doesn't correctly find the overload.
        hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, "is_noop", False)

        dag = self.get_dag()
        task_ids = self._check_for_removed_or_restored_tasks(
            dag, task_instance_mutation_hook, session=session
        )

        def task_filter(task: Operator) -> bool:
            return task.task_id not in task_ids and (
                self.is_backfill
                or task.start_date <= self.execution_date
                and (task.end_date is None or self.execution_date <= task.end_date)
            )

        created_counts: dict[str, int] = defaultdict(int)
        task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)

        # Create the missing tasks, including mapped tasks
        tasks_to_create = (task for task in dag.task_dict.values() if task_filter(task))
        tis_to_create = self._create_tasks(tasks_to_create, task_creator, session=session)
        self._create_task_instances(self.dag_id, tis_to_create, created_counts, hook_is_noop, session=session)

    def _check_for_removed_or_restored_tasks(
        self, dag: DAG, ti_mutation_hook, *, session: Session
    ) -> set[str]:
        """
        Check for removed, restored, or missing tasks.

        :param dag: DAG object corresponding to the dagrun
        :param ti_mutation_hook: task_instance_mutation_hook function
        :param session: Sqlalchemy ORM Session

        :return: Task IDs in the DAG run
        """
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            ti_mutation_hook(ti)
            task_ids.add(ti.task_id)
            try:
                task = dag.get_task(ti.task_id)

                should_restore_task = (task is not None) and ti.state == State.REMOVED
                if should_restore_task:
                    self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                    Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.NONE
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.REMOVED
                continue

            try:
                num_mapped_tis = task.get_parse_time_mapped_ti_count()
            except NotMapped:
                continue
            except NotFullyPopulated:
                # What if it is _now_ dynamically mapped, but wasn't before?
                try:
                    total_length = task.get_mapped_ti_count(self.run_id, session=session)
                except NotFullyPopulated:
                    # Not all upstreams finished, so we can't tell what should be here. Remove everything.
                    if ti.map_index >= 0:
                        self.log.debug(
                            "Removing the unmapped TI '%s' as the mapping can't be resolved yet", ti
                        )
                        ti.state = State.REMOVED
                    continue
                # Upstreams finished, check there aren't any extras
                if ti.map_index >= total_length:
                    self.log.debug(
                        "Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
                        ti,
                        total_length,
                    )
                    ti.state = State.REMOVED
            else:
                # Check if the number of mapped literals has changed and we need to mark this TI as removed.
                if ti.map_index >= num_mapped_tis:
                    self.log.debug(
                        "Removing task '%s' as the map_index is longer than the literal mapping list (%s)",
                        ti,
                        num_mapped_tis,
                    )
                    ti.state = State.REMOVED
                elif ti.map_index < 0:
                    self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
                    ti.state = State.REMOVED

        return task_ids

    @overload
    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[True],
    ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]:
        ...

    @overload
    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[False],
    ) -> Callable[[Operator, Iterable[int]], Iterator[TI]]:
        ...

    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[True, False],
    ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]:
        """
        Get the task creator function.

        This function also updates the created_counts dictionary with the number of tasks created.

        :param created_counts: Dictionary of task_type -> count of created TIs
        :param ti_mutation_hook: task_instance_mutation_hook function
        :param hook_is_noop: Whether the task_instance_mutation_hook is a noop
        """
        if hook_is_noop:

            def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[str, Any]]:
                created_counts[task.task_type] += 1
                for map_index in indexes:
                    yield TI.insert_mapping(self.run_id, task, map_index=map_index)

            creator = create_ti_mapping

        else:

            def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
                for map_index in indexes:
                    ti = TI(task, run_id=self.run_id, map_index=map_index)
                    ti_mutation_hook(ti)
                    created_counts[ti.operator] += 1
                    yield ti

            creator = create_ti
        return creator

    def _create_tasks(
        self,
        tasks: Iterable[Operator],
        task_creator: TaskCreator,
        *,
        session: Session,
    ) -> CreatedTasks:
        """
        Create missing tasks -- and expand any MappedOperator that has _only_ literals as input.

        :param tasks: Tasks to create jobs for in the DAG run
        :param task_creator: Function to create task instances
        """
        map_indexes: Iterable[int]
        for task in tasks:
            try:
                count = task.get_mapped_ti_count(self.run_id, session=session)
            except (NotMapped, NotFullyPopulated):
                map_indexes = (-1,)
            else:
                if count:
                    map_indexes = range(count)
                else:
                    # Make sure to always create at least one ti; this will be
                    # marked as REMOVED later at runtime.
                    map_indexes = (-1,)
            yield from task_creator(task, map_indexes)

    def _create_task_instances(
        self,
        dag_id: str,
        tasks: Iterator[dict[str, Any]] | Iterator[TI],
        created_counts: dict[str, int],
        hook_is_noop: bool,
        *,
        session: Session,
    ) -> None:
        """
        Create the necessary task instances from the given tasks.

        :param dag_id: DAG ID associated with the dagrun
        :param tasks: the tasks to create the task instances from
        :param created_counts: a dictionary of task_type -> count of TIs created by the task creator
        :param hook_is_noop: whether the task_instance_mutation_hook is noop
        :param session: the session to use
        """
        # Fetch the information we need before handling the exception to avoid
        # PendingRollbackError due to the session being invalidated on exception
        # see https://github.com/apache/superset/pull/530
        run_id = self.run_id
        try:
            if hook_is_noop:
                session.bulk_insert_mappings(TI, tasks)
            else:
                session.bulk_save_objects(tasks)

            for task_type, count in created_counts.items():
                Stats.incr(f"task_instance_created-{task_type}", count)
            session.flush()
        except IntegrityError:
            self.log.info(
                "Hit IntegrityError while creating the TIs for %s - %s",
                dag_id,
                run_id,
                exc_info=True,
            )
            self.log.info("Doing session rollback.")
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()

    def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]:
        """Check if task increased or reduced in length and handle appropriately.

        Task instances that do not already exist are created and returned if
        possible. Expansion only happens if all upstreams are ready; otherwise
        we delay expansion to the "last resort". See comments at the call site
        for more details.
        """
        from airflow.settings import task_instance_mutation_hook

        try:
            total_length = task.get_mapped_ti_count(self.run_id, session=session)
        except NotMapped:
            return  # Not a mapped task, don't need to do anything.
        except NotFullyPopulated:
            return  # Upstreams not ready, don't need to revise this yet.

        query = session.query(TI.map_index).filter(
            TI.dag_id == self.dag_id,
            TI.task_id == task.task_id,
            TI.run_id == self.run_id,
        )
        existing_indexes = {i for (i,) in query}

        removed_indexes = existing_indexes.difference(range(total_length))
        if removed_indexes:
            session.query(TI).filter(
                TI.dag_id == self.dag_id,
                TI.task_id == task.task_id,
                TI.run_id == self.run_id,
                TI.map_index.in_(removed_indexes),
            ).update({TI.state: TaskInstanceState.REMOVED})
            session.flush()

        for index in range(total_length):
            if index in existing_indexes:
                continue
            ti = TI(task, run_id=self.run_id, map_index=index, state=None)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(task)
            session.flush()
            yield ti

    @staticmethod
    def get_run(session: Session, dag_id: str, execution_date: datetime) -> DagRun | None:
        """
        Get a single DAG Run.

        :meta private:
        :param session: Sqlalchemy ORM Session
        :param dag_id: DAG ID
        :param execution_date: execution date
        :return: DagRun corresponding to the given dag_id and execution date
            if one exists. None otherwise.
        """
        warnings.warn(
            "This method is deprecated. Please use SQLAlchemy directly",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return (
            session.query(DagRun)
            .filter(
                DagRun.dag_id == dag_id,
                DagRun.external_trigger == False,  # noqa
                DagRun.execution_date == execution_date,
            )
            .first()
        )

    @property
    def is_backfill(self) -> bool:
        return self.run_type == DagRunType.BACKFILL_JOB

    @classmethod
    @provide_session
    def get_latest_runs(cls, session=None) -> list[DagRun]:
        """Return the latest DagRun for each DAG."""
        subquery = (
            session.query(cls.dag_id, func.max(cls.execution_date).label("execution_date"))
            .group_by(cls.dag_id)
            .subquery()
        )
        return (
            session.query(cls)
            .join(
                subquery,
                and_(cls.dag_id == subquery.c.dag_id, cls.execution_date == subquery.c.execution_date),
            )
            .all()
        )

    @provide_session
    def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SESSION) -> int:
        """
        Set the given task instances into the scheduled state.

        Each element of ``schedulable_tis`` should have its ``task`` attribute already set.

        Any EmptyOperator without callbacks or outlets is instead set straight to the success state.

        All the TIs should belong to this DagRun, but this code is in the hot-path, this is not checked -- it
        is the caller's responsibility to call this function only with TIs from a single dag run.
        """
        # Get list of TI IDs that do not need to be executed, these are
        # tasks using EmptyOperator and without on_execute_callback / on_success_callback
        dummy_ti_ids = []
        schedulable_ti_ids = []
        for ti in schedulable_tis:
            if (
                ti.task.inherits_from_empty_operator
                and not ti.task.on_execute_callback
                and not ti.task.on_success_callback
                and not ti.task.outlets
            ):
                dummy_ti_ids.append(ti.task_id)
            else:
                schedulable_ti_ids.append((ti.task_id, ti.map_index))

        count = 0

        if schedulable_ti_ids:
            count += (
                session.query(TI)
                .filter(
                    TI.dag_id == self.dag_id,
                    TI.run_id == self.run_id,
                    tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids),
                )
                .update({TI.state: State.SCHEDULED}, synchronize_session=False)
            )

        # Tasks using EmptyOperator should not be executed, mark them as success
        if dummy_ti_ids:
            count += (
                session.query(TI)
                .filter(
                    TI.dag_id == self.dag_id,
                    TI.run_id == self.run_id,
                    TI.task_id.in_(dummy_ti_ids),
                )
                .update(
                    {
                        TI.state: State.SUCCESS,
                        TI.start_date: timezone.utcnow(),
                        TI.end_date: timezone.utcnow(),
                        TI.duration: 0,
                    },
                    synchronize_session=False,
                )
            )

        return count

    @provide_session
    def get_log_template(self, *, session: Session = NEW_SESSION) -> LogTemplate:
        if self.log_template_id is None:  # DagRun created before LogTemplate introduction.
            template = session.query(LogTemplate).order_by(LogTemplate.id).first()
        else:
            template = session.query(LogTemplate).get(self.log_template_id)
        if template is None:
            raise AirflowException(
                f"No log_template entry found for ID {self.log_template_id!r}. "
                f"Please make sure you set up the metadatabase correctly."
            )
        return template

    @provide_session
    def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str:
        warnings.warn(
            "This method is deprecated. Please use get_log_template instead.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return self.get_log_template(session=session).filename


class DagRunNote(Base):
    """For storage of arbitrary notes concerning the dagrun instance."""

    __tablename__ = "dag_run_note"

    user_id = Column(Integer, nullable=True)
    dag_run_id = Column(Integer, primary_key=True, nullable=False)
    content = Column(String(1000).with_variant(Text(1000), "mysql"))
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dag_run = relationship("DagRun", back_populates="dag_run_note")

    __table_args__ = (
        PrimaryKeyConstraint("dag_run_id", name="dag_run_note_pkey"),
        ForeignKeyConstraint(
            (dag_run_id,),
            ["dag_run.id"],
            name="dag_run_note_dr_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (user_id,),
            ["ab_user.id"],
            name="dag_run_note_user_fkey",
        ),
    )

    def __init__(self, content, user_id=None):
        self.content = content
        self.user_id = user_id

    def __repr__(self):
        # The attributes referenced previously (dag_id, run_id, map_index) belong to
        # TaskInstance-style models and do not exist on DagRunNote, so the repr would
        # raise AttributeError; only columns this model defines are used here.
        return f"<{self.__class__.__name__}: dag_run_id={self.dag_run_id}, user_id={self.user_id}>"