Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/dagrun.py: 25%


674 statements  

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import itertools
import os
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload

import re2
from sqlalchemy import (
    Boolean,
    Column,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PickleType,
    PrimaryKeyConstraint,
    String,
    Text,
    UniqueConstraint,
    and_,
    func,
    or_,
    text,
    update,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import declared_attr, joinedload, relationship, synonym, validates
from sqlalchemy.sql.expression import case, false, select, true

from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred, TaskNotFound
from airflow.listeners.listener import get_listener_manager
from airflow.models import Log
from airflow.models.abstractoperator import NotMapped
from airflow.models.base import Base, StringID
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.tasklog import LogTemplate
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
from airflow.utils import timezone
from airflow.utils.helpers import chunks, is_container, prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, tuple_in_condition, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import NOTSET, DagRunType

if TYPE_CHECKING:
    from datetime import datetime

    from sqlalchemy.orm import Query, Session

    from airflow.models.dag import DAG
    from airflow.models.operator import Operator
    from airflow.serialization.pydantic.dag_run import DagRunPydantic
    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
    from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
    from airflow.typing_compat import Literal
    from airflow.utils.types import ArgNotSet

    CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])

RUN_ID_REGEX = r"^(?:manual|scheduled|dataset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$"
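# Illustration (hypothetical values, not part of the original module): run IDs
# generated by Airflow itself match this pattern, while custom run IDs only need
# to satisfy the configurable "allowed_run_id_pattern" checked in
# DagRun.validate_run_id below.
#
#     import re2
#     assert re2.match(RUN_ID_REGEX, "scheduled__2024-05-01T00:00:00+00:00")
#     assert re2.match(RUN_ID_REGEX, "my-custom-run") is None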


class TISchedulingDecision(NamedTuple):
    """Type of return for DagRun.task_instance_scheduling_decisions."""

    tis: list[TI]
    schedulable_tis: list[TI]
    changed_tis: bool
    unfinished_tis: list[TI]
    finished_tis: list[TI]


def _creator_note(val):
    """Creator for the ``note`` association proxy."""
    if isinstance(val, str):
        return DagRunNote(content=val)
    elif isinstance(val, dict):
        return DagRunNote(**val)
    else:
        return DagRunNote(*val)
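# Illustrative sketch (hypothetical values): because ``DagRun.note`` is an
# association proxy built on this creator, all of the following assignments
# attach a DagRunNote to a run:
#
#     dag_run.note = "re-run after infra outage"            # str -> content
#     dag_run.note = {"content": "re-run", "user_id": 1}    # dict -> keyword args
#     dag_run.note = ("re-run", 1)                          # iterable -> positional args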


class DagRun(Base, LoggingMixin):
    """Invocation instance of a DAG.

    A DAG run can be created by the scheduler (i.e. scheduled runs), or by an
    external trigger (i.e. manual runs).
    """

    __tablename__ = "dag_run"

    id = Column(Integer, primary_key=True)
    dag_id = Column(StringID(), nullable=False)
    queued_at = Column(UtcDateTime)
    execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    start_date = Column(UtcDateTime)
    end_date = Column(UtcDateTime)
    _state = Column("state", String(50), default=DagRunState.QUEUED)
    run_id = Column(StringID(), nullable=False)
    creating_job_id = Column(Integer)
    external_trigger = Column(Boolean, default=True)
    run_type = Column(String(50), nullable=False)
    conf = Column(PickleType)
    # These two must be either both NULL or both datetime.
    data_interval_start = Column(UtcDateTime)
    data_interval_end = Column(UtcDateTime)
    # When a scheduler last attempted to schedule TIs for this DagRun
    last_scheduling_decision = Column(UtcDateTime)
    dag_hash = Column(String(32))
    # Foreign key to LogTemplate. DagRun rows created prior to this column's
    # existence have this set to NULL. Later rows automatically populate this on
    # insert to point to the latest LogTemplate entry.
    log_template_id = Column(
        Integer,
        ForeignKey("log_template.id", name="task_instance_log_template_id_fkey", ondelete="NO ACTION"),
        default=select(func.max(LogTemplate.__table__.c.id)),
    )
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
    # Keeps track of the number of times the dagrun has been cleared.
    # This number is incremented only when the DagRun is re-queued,
    # i.e. when the DagRun is cleared.
    clear_number = Column(Integer, default=0, nullable=False, server_default="0")

    # Remove this `if` after upgrading Sphinx-AutoAPI
    if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
        dag: DAG | None
    else:
        dag: DAG | None = None

    __table_args__ = (
        Index("dag_id_state", dag_id, _state),
        UniqueConstraint("dag_id", "execution_date", name="dag_run_dag_id_execution_date_key"),
        UniqueConstraint("dag_id", "run_id", name="dag_run_dag_id_run_id_key"),
        Index("idx_dag_run_dag_id", dag_id),
        Index(
            "idx_dag_run_running_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='running'"),
            sqlite_where=text("state='running'"),
        ),
        # Since MySQL lacks filtered/partial indices, this creates a
        # duplicate index on MySQL. Not the end of the world.
        Index(
            "idx_dag_run_queued_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='queued'"),
            sqlite_where=text("state='queued'"),
        ),
    )

    task_instances = relationship(
        TI, back_populates="dag_run", cascade="save-update, merge, delete, delete-orphan"
    )
    dag_model = relationship(
        "DagModel",
        primaryjoin="foreign(DagRun.dag_id) == DagModel.dag_id",
        uselist=False,
        viewonly=True,
    )
    dag_run_note = relationship(
        "DagRunNote",
        back_populates="dag_run",
        uselist=False,
        cascade="all, delete, delete-orphan",
    )
    note = association_proxy("dag_run_note", "content", creator=_creator_note)

    DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
        "scheduler",
        "max_dagruns_per_loop_to_schedule",
        fallback=20,
    )

    def __init__(
        self,
        dag_id: str | None = None,
        run_id: str | None = None,
        queued_at: datetime | None | ArgNotSet = NOTSET,
        execution_date: datetime | None = None,
        start_date: datetime | None = None,
        external_trigger: bool | None = None,
        conf: Any | None = None,
        state: DagRunState | None = None,
        run_type: str | None = None,
        dag_hash: str | None = None,
        creating_job_id: int | None = None,
        data_interval: tuple[datetime, datetime] | None = None,
    ):
        if data_interval is None:
            # Legacy: Only happens for runs created prior to Airflow 2.2.
            self.data_interval_start = self.data_interval_end = None
        else:
            self.data_interval_start, self.data_interval_end = data_interval

        self.dag_id = dag_id
        self.run_id = run_id
        self.execution_date = execution_date
        self.start_date = start_date
        self.external_trigger = external_trigger
        self.conf = conf or {}
        if state is not None:
            self.state = state
        if queued_at is NOTSET:
            self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None
        else:
            self.queued_at = queued_at
        self.run_type = run_type
        self.dag_hash = dag_hash
        self.creating_job_id = creating_job_id
        self.clear_number = 0
        super().__init__()

    def __repr__(self):
        return (
            f"<DagRun {self.dag_id} @ {self.execution_date}: {self.run_id}, state:{self.state}, "
            f"queued_at: {self.queued_at}. externally triggered: {self.external_trigger}>"
        )

    @validates("run_id")
    def validate_run_id(self, key: str, run_id: str) -> str | None:
        if not run_id:
            return None
        regex = airflow_conf.get("scheduler", "allowed_run_id_pattern")
        if not re2.match(regex, run_id) and not re2.match(RUN_ID_REGEX, run_id):
            raise ValueError(
                f"The run_id provided '{run_id}' does not match the pattern '{regex}' or '{RUN_ID_REGEX}'"
            )
        return run_id
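    # Sketch of the validator in action (hypothetical run IDs; the second one
    # assumes a permissive "allowed_run_id_pattern" such as the shipped default):
    #
    #     DagRun(dag_id="demo", run_id="manual__2024-05-01T00:00:00+00:00")  # matches RUN_ID_REGEX
    #     DagRun(dag_id="demo", run_id="nightly_build_42")                   # matches the config pattern
    #     DagRun(dag_id="demo", run_id="bad run id!")                        # raises ValueError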

    @property
    def stats_tags(self) -> dict[str, str]:
        return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type})

    @property
    def logical_date(self) -> datetime:
        return self.execution_date

    def get_state(self):
        return self._state

    def set_state(self, state: DagRunState) -> None:
        """Change the state of the DagRun.

        Changes to attributes are implemented in accordance with the following table
        (rows represent old states, columns represent new states):

        .. list-table:: State transition matrix
           :header-rows: 1
           :stub-columns: 1

           * -
             - QUEUED
             - RUNNING
             - SUCCESS
             - FAILED
           * - None
             - queued_at = timezone.utcnow()
             - if empty: start_date = timezone.utcnow()
               end_date = None
             - end_date = timezone.utcnow()
             - end_date = timezone.utcnow()
           * - QUEUED
             - queued_at = timezone.utcnow()
             - if empty: start_date = timezone.utcnow()
               end_date = None
             - end_date = timezone.utcnow()
             - end_date = timezone.utcnow()
           * - RUNNING
             - queued_at = timezone.utcnow()
               start_date = None
               end_date = None
             -
             - end_date = timezone.utcnow()
             - end_date = timezone.utcnow()
           * - SUCCESS
             - queued_at = timezone.utcnow()
               start_date = None
               end_date = None
             - start_date = timezone.utcnow()
               end_date = None
             -
             -
           * - FAILED
             - queued_at = timezone.utcnow()
               start_date = None
               end_date = None
             - start_date = timezone.utcnow()
               end_date = None
             -
             -

        """
        if state not in State.dag_states:
            raise ValueError(f"invalid DagRun state: {state}")
        if self._state != state:
            if state == DagRunState.QUEUED:
                self.queued_at = timezone.utcnow()
                self.start_date = None
                self.end_date = None
            if state == DagRunState.RUNNING:
                if self._state in State.finished_dr_states:
                    self.start_date = timezone.utcnow()
                else:
                    self.start_date = self.start_date or timezone.utcnow()
                self.end_date = None
            if self._state in State.unfinished_dr_states or self._state is None:
                if state in State.finished_dr_states:
                    self.end_date = timezone.utcnow()
            self._state = state
        else:
            if state == DagRunState.QUEUED:
                self.queued_at = timezone.utcnow()
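    # A minimal sketch of the transition side effects tabulated above
    # (in-memory only; hypothetical run):
    #
    #     dr = DagRun(dag_id="demo", run_id="manual__...", state=DagRunState.QUEUED)
    #     dr.state = DagRunState.RUNNING  # sets start_date if empty, clears end_date
    #     dr.state = DagRunState.SUCCESS  # sets end_date
    #     dr.state = DagRunState.QUEUED   # re-queue: resets queued_at and clears both dates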

    @declared_attr
    def state(self):
        return synonym("_state", descriptor=property(self.get_state, self.set_state))

    @provide_session
    def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
        """
        Reload the current dagrun from the database.

        :param session: database session
        """
        dr = session.scalars(
            select(DagRun).where(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id)
        ).one()
        self.id = dr.id
        self.state = dr.state

    @classmethod
    @provide_session
    def active_runs_of_dags(
        cls,
        dag_ids: Iterable[str] | None = None,
        only_running: bool = False,
        session: Session = NEW_SESSION,
    ) -> dict[str, int]:
        """Get the number of active dag runs for each dag."""
        query = select(cls.dag_id, func.count("*"))
        if dag_ids is not None:
            # Deduplicate the dag_ids first to keep the IN clause minimal.
            query = query.where(cls.dag_id.in_(set(dag_ids)))
        if only_running:
            query = query.where(cls.state == DagRunState.RUNNING)
        else:
            query = query.where(cls.state.in_((DagRunState.RUNNING, DagRunState.QUEUED)))
        query = query.group_by(cls.dag_id)
        return dict(iter(session.execute(query)))

    @classmethod
    def next_dagruns_to_examine(
        cls,
        state: DagRunState,
        session: Session,
        max_number: int | None = None,
    ) -> Query:
        """
        Return the next DagRuns that the scheduler should attempt to schedule.

        This will return zero or more DagRun rows that are row-level-locked with a
        "SELECT ... FOR UPDATE" query. You should ensure that any scheduling decisions
        are made in a single transaction -- as soon as the transaction is committed it
        will be unlocked.
        """
        from airflow.models.dag import DagModel

        if max_number is None:
            max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

        # TODO: Bake this query, it is run _A lot_
        query = (
            select(cls)
            .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
            .where(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
            .join(DagModel, DagModel.dag_id == cls.dag_id)
            .where(DagModel.is_paused == false(), DagModel.is_active == true())
        )
        if state == DagRunState.QUEUED:
            # For dag runs in the queued state, we check if they have reached the
            # max_active_runs limit and, if so, we drop them.
            running_drs = (
                select(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
                .where(DagRun.state == DagRunState.RUNNING)
                .group_by(DagRun.dag_id)
                .subquery()
            )
            query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).where(
                func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs
            )
        query = query.order_by(
            nulls_first(cls.last_scheduling_decision, session=session),
            cls.execution_date,
        )

        if not settings.ALLOW_FUTURE_EXEC_DATES:
            query = query.where(DagRun.execution_date <= func.now())

        return session.scalars(
            with_row_locks(query.limit(max_number), of=cls, session=session, skip_locked=True)
        )

    @classmethod
    @provide_session
    def find(
        cls,
        dag_id: str | list[str] | None = None,
        run_id: Iterable[str] | None = None,
        execution_date: datetime | Iterable[datetime] | None = None,
        state: DagRunState | None = None,
        external_trigger: bool | None = None,
        no_backfills: bool = False,
        run_type: DagRunType | None = None,
        session: Session = NEW_SESSION,
        execution_start_date: datetime | None = None,
        execution_end_date: datetime | None = None,
    ) -> list[DagRun]:
        """
        Return a list of dag runs for the given search criteria.

        :param dag_id: the dag_id or list of dag_id to find dag runs for
        :param run_id: defines the run id for this dag run
        :param run_type: type of DagRun
        :param execution_date: the execution date
        :param state: the state of the dag run
        :param external_trigger: whether this dag run is externally triggered
        :param no_backfills: return no backfills (True), return all (False).
            Defaults to False
        :param session: database session
        :param execution_start_date: dag run that was executed from this date
        :param execution_end_date: dag run that was executed until this date
        """
        qry = select(cls)
        dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
        if dag_ids:
            qry = qry.where(cls.dag_id.in_(dag_ids))

        if is_container(run_id):
            qry = qry.where(cls.run_id.in_(run_id))
        elif run_id is not None:
            qry = qry.where(cls.run_id == run_id)
        if is_container(execution_date):
            qry = qry.where(cls.execution_date.in_(execution_date))
        elif execution_date is not None:
            qry = qry.where(cls.execution_date == execution_date)
        if execution_start_date and execution_end_date:
            qry = qry.where(cls.execution_date.between(execution_start_date, execution_end_date))
        elif execution_start_date:
            qry = qry.where(cls.execution_date >= execution_start_date)
        elif execution_end_date:
            qry = qry.where(cls.execution_date <= execution_end_date)
        if state:
            qry = qry.where(cls.state == state)
        if external_trigger is not None:
            qry = qry.where(cls.external_trigger == external_trigger)
        if run_type:
            qry = qry.where(cls.run_type == run_type)
        if no_backfills:
            qry = qry.where(cls.run_type != DagRunType.BACKFILL_JOB)

        return session.scalars(qry.order_by(cls.execution_date)).all()
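    # Typical queries built on ``find`` (a sketch; session and date values are
    # hypothetical):
    #
    #     failed_runs = DagRun.find(dag_id="demo", state=DagRunState.FAILED)
    #     recent = DagRun.find(
    #         dag_id=["demo", "etl"],
    #         execution_start_date=week_start,
    #         execution_end_date=week_end,
    #         no_backfills=True,
    #     )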

    @classmethod
    @provide_session
    def find_duplicate(
        cls,
        dag_id: str,
        run_id: str,
        execution_date: datetime,
        session: Session = NEW_SESSION,
    ) -> DagRun | None:
        """
        Return an existing run for the DAG with a specific run_id or execution_date.

        *None* is returned if no such DAG run is found.

        :param dag_id: the dag_id to find duplicates for
        :param run_id: defines the run id for this dag run
        :param execution_date: the execution date
        :param session: database session
        """
        return session.scalars(
            select(cls).where(
                cls.dag_id == dag_id,
                or_(cls.run_id == run_id, cls.execution_date == execution_date),
            )
        ).one_or_none()

    @staticmethod
    def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
        """Generate Run ID based on Run Type and Execution Date."""
        # _Ensure_ run_type is a DagRunType, not just a string from user code
        return DagRunType(run_type).generate_run_id(execution_date)
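    # For illustration (hypothetical timestamp):
    #
    #     DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
    #     # -> roughly "manual__" + execution_date.isoformat(),
    #     #    e.g. "manual__2024-05-01T12:34:56+00:00"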

    @staticmethod
    @internal_api_call
    @provide_session
    def fetch_task_instances(
        dag_id: str | None = None,
        run_id: str | None = None,
        task_ids: list[str] | None = None,
        state: Iterable[TaskInstanceState | None] | None = None,
        session: Session = NEW_SESSION,
    ) -> list[TI]:
        """Return the task instances for this dag run."""
        tis = (
            select(TI)
            .options(joinedload(TI.dag_run))
            .where(
                TI.dag_id == dag_id,
                TI.run_id == run_id,
            )
        )

        if state:
            if isinstance(state, str):
                tis = tis.where(TI.state == state)
            else:
                # this is required to deal with NULL values
                if None in state:
                    if all(x is None for x in state):
                        tis = tis.where(TI.state.is_(None))
                    else:
                        not_none_state = (s for s in state if s)
                        tis = tis.where(or_(TI.state.in_(not_none_state), TI.state.is_(None)))
                else:
                    tis = tis.where(TI.state.in_(state))

        if task_ids is not None:
            tis = tis.where(TI.task_id.in_(task_ids))
        return session.scalars(tis).all()

    @internal_api_call
    def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs, session):
        """Check whether the last N consecutive dag runs failed, and pause the DAG if so."""
        dag_runs = (
            session.query(DagRun)
            .filter(DagRun.dag_id == dag_id)
            .order_by(DagRun.execution_date.desc())
            .limit(max_consecutive_failed_dag_runs)
            .all()
        )
        # Mark the dag as paused, if needed.
        to_be_paused = len(dag_runs) >= max_consecutive_failed_dag_runs and all(
            dag_run.state == DagRunState.FAILED for dag_run in dag_runs
        )

        if to_be_paused:
            from airflow.models.dag import DagModel

            self.log.info(
                "Pausing DAG %s because last %s DAG runs failed.",
                self.dag_id,
                max_consecutive_failed_dag_runs,
            )
            filter_query = [
                DagModel.dag_id == self.dag_id,
                DagModel.root_dag_id == self.dag_id,  # for sub-dags
            ]
            session.execute(
                update(DagModel)
                .where(or_(*filter_query))
                .values(is_paused=True)
                .execution_options(synchronize_session="fetch")
            )
            session.add(
                Log(
                    event="paused",
                    dag_id=self.dag_id,
                    owner="scheduler",
                    owner_display_name="Scheduler",
                    extra=f"[('dag_id', '{self.dag_id}'), ('is_paused', True)]",
                )
            )
        else:
            self.log.debug(
                "Limit of consecutive failed DAG runs is not reached; DAG %s will not be paused.",
                self.dag_id,
            )

    @provide_session
    def get_task_instances(
        self,
        state: Iterable[TaskInstanceState | None] | None = None,
        session: Session = NEW_SESSION,
    ) -> list[TI]:
        """
        Return the task instances for this dag run.

        Redirects to the DagRun.fetch_task_instances method.
        Kept because it is widely used across the codebase.
        """
        task_ids = self.dag.task_ids if self.dag and self.dag.partial else None
        return DagRun.fetch_task_instances(
            dag_id=self.dag_id, run_id=self.run_id, task_ids=task_ids, state=state, session=session
        )

    @provide_session
    def get_task_instance(
        self,
        task_id: str,
        session: Session = NEW_SESSION,
        *,
        map_index: int = -1,
    ) -> TI | TaskInstancePydantic | None:
        """
        Return the task instance specified by task_id for this dag run.

        :param task_id: the task id
        :param session: Sqlalchemy ORM Session
        """
        return DagRun.fetch_task_instance(
            dag_id=self.dag_id,
            dag_run_id=self.run_id,
            task_id=task_id,
            session=session,
            map_index=map_index,
        )

    @staticmethod
    @provide_session
    def fetch_task_instance(
        dag_id: str,
        dag_run_id: str,
        task_id: str,
        session: Session = NEW_SESSION,
        map_index: int = -1,
    ) -> TI | TaskInstancePydantic | None:
        """
        Return the task instance specified by task_id for this dag run.

        :param dag_id: the DAG id
        :param dag_run_id: the DAG run id
        :param task_id: the task id
        :param session: Sqlalchemy ORM Session
        """
        return session.scalars(
            select(TI).filter_by(dag_id=dag_id, run_id=dag_run_id, task_id=task_id, map_index=map_index)
        ).one_or_none()

    def get_dag(self) -> DAG:
        """
        Return the DAG associated with this DagRun.

        :return: DAG
        """
        if not self.dag:
            raise AirflowException(f"The DAG (.dag) for {self} needs to be set")

        return self.dag

    @staticmethod
    @internal_api_call
    @provide_session
    def get_previous_dagrun(
        dag_run: DagRun | DagRunPydantic, state: DagRunState | None = None, session: Session = NEW_SESSION
    ) -> DagRun | None:
        """
        Return the previous DagRun, if there is one.

        :param dag_run: the dag run
        :param session: SQLAlchemy ORM Session
        :param state: the dag run state
        """
        filters = [
            DagRun.dag_id == dag_run.dag_id,
            DagRun.execution_date < dag_run.execution_date,
        ]
        if state is not None:
            filters.append(DagRun.state == state)
        return session.scalar(select(DagRun).where(*filters).order_by(DagRun.execution_date.desc()).limit(1))

    @staticmethod
    @internal_api_call
    @provide_session
    def get_previous_scheduled_dagrun(
        dag_run_id: int,
        session: Session = NEW_SESSION,
    ) -> DagRun | None:
        """
        Return the previous SCHEDULED DagRun, if there is one.

        :param dag_run_id: the DAG run ID
        :param session: SQLAlchemy ORM Session
        """
        dag_run = session.get(DagRun, dag_run_id)
        return session.scalar(
            select(DagRun)
            .where(
                DagRun.dag_id == dag_run.dag_id,
                DagRun.execution_date < dag_run.execution_date,
                DagRun.run_type != DagRunType.MANUAL,
            )
            .order_by(DagRun.execution_date.desc())
            .limit(1)
        )

    def _tis_for_dagrun_state(self, *, dag, tis):
        """
        Return the collection of tasks that should be considered for evaluation of terminal dag run state.

        Teardown tasks by default are not considered for the purpose of dag run state. But
        users may enable such consideration with on_failure_fail_dagrun.
        """

        def is_effective_leaf(task):
            for down_task_id in task.downstream_task_ids:
                down_task = dag.get_task(down_task_id)
                if not down_task.is_teardown or down_task.on_failure_fail_dagrun:
                    # we found a downstream task that is not ignorable; not a leaf
                    return False
            # we found no non-ignorable downstreams,
            # so evaluate whether the task is itself ignorable
            return not task.is_teardown or task.on_failure_fail_dagrun

        leaf_task_ids = {x.task_id for x in dag.tasks if is_effective_leaf(x)}
        if not leaf_task_ids:
            # can happen if dag is exclusively teardown tasks
            leaf_task_ids = {x.task_id for x in dag.tasks if not x.downstream_list}
        leaf_tis = {ti for ti in tis if ti.task_id in leaf_task_ids if ti.state != TaskInstanceState.REMOVED}
        return leaf_tis
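    # Worked example of the "effective leaf" rule above: in a DAG shaped like
    # ``work >> teardown`` where ``teardown.is_teardown`` is True and
    # ``on_failure_fail_dagrun`` is False, only ``work`` is an effective leaf,
    # so the terminal run state is decided by ``work`` alone.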

    @provide_session
    def update_state(
        self, session: Session = NEW_SESSION, execute_callbacks: bool = True
    ) -> tuple[list[TI], DagCallbackRequest | None]:
        """
        Determine the overall state of the DagRun based on the state of its TaskInstances.

        :param session: Sqlalchemy ORM Session
        :param execute_callbacks: Should dag callbacks (success/failure, SLA etc.) be invoked
            directly (default: true) or recorded as a pending request in the ``returned_callback`` property
        :return: Tuple containing tis that can be scheduled in the current loop & `returned_callback` that
            needs to be executed
        """
        # Callback to execute in case of Task Failures
        callback: DagCallbackRequest | None = None

        class _UnfinishedStates(NamedTuple):
            tis: Sequence[TI]

            @classmethod
            def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
                return cls(tis=unfinished_tis)

            @property
            def should_schedule(self) -> bool:
                return (
                    bool(self.tis)
                    and all(not t.task.depends_on_past for t in self.tis)  # type: ignore[union-attr]
                    and all(t.task.max_active_tis_per_dag is None for t in self.tis)  # type: ignore[union-attr]
                    and all(t.task.max_active_tis_per_dagrun is None for t in self.tis)  # type: ignore[union-attr]
                    and all(t.state != TaskInstanceState.DEFERRED for t in self.tis)
                )

            def recalculate(self) -> _UnfinishedStates:
                return self._replace(tis=[t for t in self.tis if t.state in State.unfinished])

        start_dttm = timezone.utcnow()
        self.last_scheduling_decision = start_dttm
        with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"), Stats.timer(
            "dagrun.dependency-check", tags=self.stats_tags
        ):
            dag = self.get_dag()
            info = self.task_instance_scheduling_decisions(session)

            tis = info.tis
            schedulable_tis = info.schedulable_tis
            changed_tis = info.changed_tis
            finished_tis = info.finished_tis
            unfinished = _UnfinishedStates.calculate(info.unfinished_tis)

            if unfinished.should_schedule:
                are_runnable_tasks = schedulable_tis or changed_tis
                # small speed up
                if not are_runnable_tasks:
                    are_runnable_tasks, changed_by_upstream = self._are_premature_tis(
                        unfinished.tis, finished_tis, session
                    )
                    if changed_by_upstream:  # Something changed, we need to recalculate!
                        unfinished = unfinished.recalculate()

        tis_for_dagrun_state = self._tis_for_dagrun_state(dag=dag, tis=tis)

        # if all tasks finished and at least one failed, the run failed
        if not unfinished.tis and any(x.state in State.failed_states for x in tis_for_dagrun_state):
            self.log.error("Marking run %s failed", self)
            self.set_state(DagRunState.FAILED)
            self.notify_dagrun_state_changed(msg="task_failure")

            if execute_callbacks:
                dag.handle_callback(self, success=False, reason="task_failure", session=session)
            elif dag.has_on_failure_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                    msg="task_failure",
                )

            # Check whether max_consecutive_failed_dag_runs is configured (i.e. positive)
            # and, if so, whether that many runs in a row have now failed.
            if dag.max_consecutive_failed_dag_runs > 0:
                self.log.debug(
                    "Checking consecutive failed DAG runs for DAG %s, limit is %s",
                    self.dag_id,
                    dag.max_consecutive_failed_dag_runs,
                )
                self._check_last_n_dagruns_failed(dag.dag_id, dag.max_consecutive_failed_dag_runs, session)

        # if all leaves succeeded and no unfinished tasks, the run succeeded
        elif not unfinished.tis and all(x.state in State.success_states for x in tis_for_dagrun_state):
            self.log.info("Marking run %s successful", self)
            self.set_state(DagRunState.SUCCESS)
            self.notify_dagrun_state_changed(msg="success")

            if execute_callbacks:
                dag.handle_callback(self, success=True, reason="success", session=session)
            elif dag.has_on_success_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=False,
                    processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                    msg="success",
                )

        # if *all tasks* are deadlocked, the run failed
        elif unfinished.should_schedule and not are_runnable_tasks:
            self.log.error("Task deadlock (no runnable tasks); marking run %s failed", self)
            self.set_state(DagRunState.FAILED)
            self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")

            if execute_callbacks:
                dag.handle_callback(self, success=False, reason="all_tasks_deadlocked", session=session)
            elif dag.has_on_failure_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                    msg="all_tasks_deadlocked",
                )

        # finally, if the leaves aren't done, the dag is still running
        else:
            self.set_state(DagRunState.RUNNING)

        if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
            msg = (
                "DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
                "run_start_date=%s, run_end_date=%s, run_duration=%s, "
                "state=%s, external_trigger=%s, run_type=%s, "
                "data_interval_start=%s, data_interval_end=%s, dag_hash=%s"
            )
            self.log.info(
                msg,
                self.dag_id,
                self.execution_date,
                self.run_id,
                self.start_date,
                self.end_date,
                (
                    (self.end_date - self.start_date).total_seconds()
                    if self.start_date and self.end_date
                    else None
                ),
                self._state,
                self.external_trigger,
                self.run_type,
                self.data_interval_start,
                self.data_interval_end,
                self.dag_hash,
            )
            session.flush()

        self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis)
        self._emit_duration_stats_for_finished_state()

        session.merge(self)
        # We do not flush here for performance reasons (it increases the query count by ~20).

        return schedulable_tis, callback
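    # A simplified sketch of how a scheduling loop consumes this method
    # (session handling and callback dispatch are elided/hypothetical):
    #
    #     schedulable_tis, callback = dag_run.update_state(session=session, execute_callbacks=False)
    #     if schedulable_tis:
    #         dag_run.schedule_tis(schedulable_tis, session=session)
    #     if callback is not None:
    #         ...  # hand the DagCallbackRequest to the DAG processor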

    @provide_session
    def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
        tis = self.get_task_instances(session=session, state=State.task_states)
        self.log.debug("number of task instances for %s: %s task(s)", self, len(tis))

        def _filter_tis_and_exclude_removed(dag: DAG, tis: list[TI]) -> Iterable[TI]:
            """Populate ``ti.task`` while excluding those missing one, marking them as REMOVED."""
            for ti in tis:
                try:
                    ti.task = dag.get_task(ti.task_id)
                except TaskNotFound:
                    if ti.state != TaskInstanceState.REMOVED:
                        self.log.error("Failed to get task for ti %s. Marking it as removed.", ti)
                        ti.state = TaskInstanceState.REMOVED
                        session.flush()
                else:
                    yield ti

        tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))

        unfinished_tis = [t for t in tis if t.state in State.unfinished]
        finished_tis = [t for t in tis if t.state in State.finished]
        if unfinished_tis:
            schedulable_tis = [ut for ut in unfinished_tis if ut.state in SCHEDULEABLE_STATES]
            self.log.debug("number of schedulable tasks for %s: %s task(s)", self, len(schedulable_tis))
            schedulable_tis, changed_tis, expansion_happened = self._get_ready_tis(
                schedulable_tis,
                finished_tis,
                session=session,
            )

            # During expansion, we may change some tis into non-schedulable
            # states, so we need to re-compute.
            if expansion_happened:
                changed_tis = True
                new_unfinished_tis = [t for t in unfinished_tis if t.state in State.unfinished]
                finished_tis.extend(t for t in unfinished_tis if t.state in State.finished)
                unfinished_tis = new_unfinished_tis
        else:
            schedulable_tis = []
            changed_tis = False

        return TISchedulingDecision(
            tis=tis,
            schedulable_tis=schedulable_tis,
            changed_tis=changed_tis,
            unfinished_tis=unfinished_tis,
            finished_tis=finished_tis,
        )

    def notify_dagrun_state_changed(self, msg: str = ""):
        if self.state == DagRunState.RUNNING:
            get_listener_manager().hook.on_dag_run_running(dag_run=self, msg=msg)
        elif self.state == DagRunState.SUCCESS:
            get_listener_manager().hook.on_dag_run_success(dag_run=self, msg=msg)
        elif self.state == DagRunState.FAILED:
            get_listener_manager().hook.on_dag_run_failed(dag_run=self, msg=msg)
        # deliberately not notifying on QUEUED
        # we can't get all the state changes on SchedulerJob, BackfillJob
        # or LocalTaskJob, so we don't want to "falsely advertise" we notify about that

    def _get_ready_tis(
        self,
        schedulable_tis: list[TI],
        finished_tis: list[TI],
        session: Session,
    ) -> tuple[list[TI], bool, bool]:
        old_states = {}
        ready_tis: list[TI] = []
        changed_tis = False

        if not schedulable_tis:
            return ready_tis, changed_tis, False

        # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
        # `schedulable_tis` in place and have the `for` loop pick them up.)
        additional_tis: list[TI] = []
        dep_context = DepContext(
            flag_upstream_failed=True,
            ignore_unmapped_tasks=True,  # Ignore this Dep, as we will expand it if we can.
            finished_tis=finished_tis,
        )

        def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
            """Try to expand the ti, if needed.

            If the ti needs expansion, newly created task instances are
            returned as well as the original ti.
            The original ti is also modified in-place and assigned the
            ``map_index`` of 0.

            If the ti does not need expansion, either because the task is not
            mapped, or has already been expanded, *None* is returned.
            """
            if TYPE_CHECKING:
                assert ti.task

            if ti.map_index >= 0:  # Already expanded, we're good.
                return None

            from airflow.models.mappedoperator import MappedOperator

            if isinstance(ti.task, MappedOperator):
                # If we get here, it could be that we are moving from non-mapped to mapped
                # after task instance clearing or this ti is not yet expanded. Safe to clear
                # the db references.
                ti.clear_db_references(session=session)
            try:
                expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session)
            except NotMapped:  # Not a mapped task, nothing needed.
                return None
            if expanded_tis:
                return expanded_tis
            return ()

        # Check dependencies.
        expansion_happened = False
        # Set of task ids for which _revise_map_indexes_if_mapped has already been done.
        revised_map_index_task_ids = set()
        for schedulable in itertools.chain(schedulable_tis, additional_tis):
            if TYPE_CHECKING:
                assert schedulable.task
            old_state = schedulable.state
            if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
                old_states[schedulable.key] = old_state
                continue
            # If schedulable is not yet expanded, try doing it now. This is
            # called in two places: First and ideally in the mini scheduler at
            # the end of LocalTaskJob, and then as an "expansion of last resort"
            # in the scheduler to ensure that the mapped task is correctly
            # expanded before executed. Also see _revise_map_indexes_if_mapped
            # docstring for additional information.
            new_tis = None
            if schedulable.map_index < 0:
                new_tis = _expand_mapped_task_if_needed(schedulable)
                if new_tis is not None:
                    additional_tis.extend(new_tis)
                    expansion_happened = True
            if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
                # It's enough to revise map index once per task id,
                # checking the map index for each mapped task significantly slows down scheduling
                if schedulable.task.task_id not in revised_map_index_task_ids:
                    ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
                    revised_map_index_task_ids.add(schedulable.task.task_id)
                ready_tis.append(schedulable)

        # Check if any ti changed state
        tis_filter = TI.filter_for_tis(old_states)
        if tis_filter is not None:
            fresh_tis = session.scalars(select(TI).where(tis_filter)).all()
            changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)

        return ready_tis, changed_tis, expansion_happened

    def _are_premature_tis(
        self,
        unfinished_tis: Sequence[TI],
        finished_tis: list[TI],
        session: Session,
    ) -> tuple[bool, bool]:
        dep_context = DepContext(
            flag_upstream_failed=True,
            ignore_in_retry_period=True,
            ignore_in_reschedule_period=True,
            finished_tis=finished_tis,
        )
        # there might be runnable tasks that are up for retry and for some reason (retry delay, etc.)
        # are not ready yet, so we set the flags to count them in
        return (
            any(ut.are_dependencies_met(dep_context=dep_context, session=session) for ut in unfinished_tis),
            dep_context.have_changed_ti_states,
        )

    def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: list[TI]) -> None:
        """Emit the true scheduling delay stats.

        The true scheduling delay stat is defined as the time when the first
        task in the DAG starts minus the expected DAG run datetime.

        This helper method is used in ``update_state`` when the state of the
        DAG run is updated to a completed status (either success or failure).
        It finds the first started task within the DAG, calculates the run's
        expected start time based on the logical date and timetable, and gets
        the delay from the difference of these two values.

        The emitted data may contain outliers (e.g. when the first task was
        cleared, so the second task's start date will be used), but we can get
        rid of the outliers on the stats side through dashboards tooling.

        Note that the stat will only be emitted for scheduler-triggered DAG runs
        (i.e. when ``external_trigger`` is *False* and ``clear_number`` is equal to 0).
        """
        if self.state == TaskInstanceState.RUNNING:
            return
        if self.external_trigger:
            return
        if self.clear_number > 0:
            return
        if not finished_tis:
            return

        try:
            dag = self.get_dag()

            if not dag.timetable.periodic:
                # We can't emit this metric if there is no following schedule to calculate from!
                return

            try:
                first_start_date = min(ti.start_date for ti in finished_tis if ti.start_date)
            except ValueError:  # No start dates at all.
                pass
            else:
                # TODO: Logically, this should be DagRunInfo.run_after, but the
                # information is not stored on a DagRun, only before the actual
                # execution on DagModel.next_dagrun_create_after. We should add
                # a field on DagRun for this instead of relying on the run
                # always happening immediately after the data interval.
                data_interval_end = dag.get_run_data_interval(self).end
                true_delay = first_start_date - data_interval_end
                if true_delay.total_seconds() > 0:
                    Stats.timing(
                        f"dagrun.{dag.dag_id}.first_task_scheduling_delay", true_delay, tags=self.stats_tags
                    )
                    Stats.timing("dagrun.first_task_scheduling_delay", true_delay, tags=self.stats_tags)
        except Exception:
            self.log.warning("Failed to record first_task_scheduling_delay metric:", exc_info=True)
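    # Worked example (hypothetical times): for a daily run whose
    # data_interval_end is 2024-05-01T00:00:00+00:00 and whose first task
    # instance started at 2024-05-01T00:00:42+00:00,
    #
    #     true_delay = first_start_date - data_interval_end  # 42 seconds, emitted via Stats.timing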

    def _emit_duration_stats_for_finished_state(self):
        if self.state == DagRunState.RUNNING:
            return
        if self.start_date is None:
            self.log.warning("Failed to record duration of %s: start_date is not set.", self)
            return
        if self.end_date is None:
            self.log.warning("Failed to record duration of %s: end_date is not set.", self)
            return

        duration = self.end_date - self.start_date
        timer_params = {"dt": duration, "tags": self.stats_tags}
        Stats.timing(f"dagrun.duration.{self.state}.{self.dag_id}", **timer_params)
        Stats.timing(f"dagrun.duration.{self.state}", **timer_params)

    @provide_session
    def verify_integrity(self, *, session: Session = NEW_SESSION) -> None:
        """
        Verify the DagRun by checking for removed tasks or tasks that are not in the database yet.

        It will set state to removed or add the task if required.

        :param session: Sqlalchemy ORM Session
        """
        from airflow.settings import task_instance_mutation_hook

        # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
        # Note: Literal[True, False] instead of bool because otherwise it doesn't correctly find the overload.
        hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, "is_noop", False)

        dag = self.get_dag()
        task_ids = self._check_for_removed_or_restored_tasks(
            dag, task_instance_mutation_hook, session=session
        )

        def task_filter(task: Operator) -> bool:
            return task.task_id not in task_ids and (
                self.is_backfill
                or (task.start_date is None or task.start_date <= self.execution_date)
                and (task.end_date is None or self.execution_date <= task.end_date)
            )

        created_counts: dict[str, int] = defaultdict(int)
        task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)

        # Create the missing tasks, including mapped tasks
        tasks_to_create = (task for task in dag.task_dict.values() if task_filter(task))
        tis_to_create = self._create_tasks(tasks_to_create, task_creator, session=session)
        self._create_task_instances(self.dag_id, tis_to_create, created_counts, hook_is_noop, session=session)

    def _check_for_removed_or_restored_tasks(
        self, dag: DAG, ti_mutation_hook, *, session: Session
    ) -> set[str]:
        """
        Check for removed, restored, or missing tasks.

        :param dag: DAG object corresponding to the dagrun
        :param ti_mutation_hook: task_instance_mutation_hook function
        :param session: Sqlalchemy ORM Session

        :return: Task IDs in the DAG run
        """
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            ti_mutation_hook(ti)
            task_ids.add(ti.task_id)
            try:
                task = dag.get_task(ti.task_id)

                should_restore_task = (task is not None) and ti.state == TaskInstanceState.REMOVED
                if should_restore_task:
                    self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                    Stats.incr(f"task_restored_to_dag.{dag.dag_id}", tags=self.stats_tags)
                    # Same metric with tagging
                    Stats.incr("task_restored_to_dag", tags={**self.stats_tags, "dag_id": dag.dag_id})
                    ti.state = None
            except AirflowException:
                if ti.state == TaskInstanceState.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state != DagRunState.RUNNING and not dag.partial:
                    self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", tags=self.stats_tags)
                    # Same metric with tagging
                    Stats.incr("task_removed_from_dag", tags={**self.stats_tags, "dag_id": dag.dag_id})
                    ti.state = TaskInstanceState.REMOVED
                continue

            try:
                num_mapped_tis = task.get_parse_time_mapped_ti_count()
            except NotMapped:
                continue
            except NotFullyPopulated:
                # What if it is _now_ dynamically mapped, but wasn't before?
                try:
                    total_length = task.get_mapped_ti_count(self.run_id, session=session)
                except NotFullyPopulated:
                    # Not all upstreams finished, so we can't tell what should be here. Remove everything.
                    if ti.map_index >= 0:
                        self.log.debug(
                            "Removing the unmapped TI '%s' as the mapping can't be resolved yet", ti
                        )
                        ti.state = TaskInstanceState.REMOVED
                    continue
                # Upstreams finished, check there aren't any extras
                if ti.map_index >= total_length:
                    self.log.debug(
                        "Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
                        ti,
                        total_length,
                    )
                    ti.state = TaskInstanceState.REMOVED
            else:
                # Check if the number of mapped literals has changed, and we need to mark this TI as removed.
                if ti.map_index >= num_mapped_tis:
                    self.log.debug(
                        "Removing task '%s' as the map_index is longer than the literal mapping list (%s)",
                        ti,
                        num_mapped_tis,
                    )
                    ti.state = TaskInstanceState.REMOVED
                elif ti.map_index < 0:
                    self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
                    ti.state = TaskInstanceState.REMOVED

        return task_ids

    @overload
    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[True],
    ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]: ...

    @overload
    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[False],
    ) -> Callable[[Operator, Iterable[int]], Iterator[TI]]: ...

    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[True, False],
    ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]:
        """
        Get the task creator function.

        This function also updates the created_counts dictionary with the number of tasks created.

        :param created_counts: Dictionary of task_type -> count of created TIs
        :param ti_mutation_hook: task_instance_mutation_hook function
        :param hook_is_noop: Whether the task_instance_mutation_hook is a noop
        """
        if hook_is_noop:

            def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[str, Any]]:
                created_counts[task.task_type] += 1
                for map_index in indexes:
                    yield TI.insert_mapping(self.run_id, task, map_index=map_index)

            creator = create_ti_mapping

        else:

            def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
                for map_index in indexes:
                    ti = TI(task, run_id=self.run_id, map_index=map_index)
                    ti_mutation_hook(ti)
                    created_counts[ti.operator] += 1
                    yield ti

            creator = create_ti
        return creator

    def _create_tasks(
        self,
        tasks: Iterable[Operator],
        task_creator: Callable[[Operator, Iterable[int]], CreatedTasks],
        *,
        session: Session,
    ) -> CreatedTasks:
        """
        Create missing tasks -- and expand any MappedOperator that _only_ has literals as input.

        :param tasks: Tasks to create jobs for in the DAG run
        :param task_creator: Function to create task instances
        """
        map_indexes: Iterable[int]
        for task in tasks:
            try:
                count = task.get_mapped_ti_count(self.run_id, session=session)
            except (NotMapped, NotFullyPopulated):
                map_indexes = (-1,)
            else:
                if count:
                    map_indexes = range(count)
                else:
                    # Make sure to always create at least one ti; this will be
                    # marked as REMOVED later at runtime.
                    map_indexes = (-1,)
            yield from task_creator(task, map_indexes)

    def _create_task_instances(
        self,
        dag_id: str,
        tasks: Iterator[dict[str, Any]] | Iterator[TI],
        created_counts: dict[str, int],
        hook_is_noop: bool,
        *,
        session: Session,
    ) -> None:
        """
        Create the necessary task instances from the given tasks.

        :param dag_id: DAG ID associated with the dagrun
        :param tasks: the tasks to create the task instances from
        :param created_counts: a dictionary of task_type -> count of TIs created by the task creator
        :param hook_is_noop: whether the task_instance_mutation_hook is noop
        :param session: the session to use
        """
        # Fetch the information we need before handling the exception to avoid
        # PendingRollbackError due to the session being invalidated on exception
        # see https://github.com/apache/superset/pull/530
        run_id = self.run_id
        try:
            if hook_is_noop:
                session.bulk_insert_mappings(TI, tasks)
            else:
                session.bulk_save_objects(tasks)

            for task_type, count in created_counts.items():
                Stats.incr(f"task_instance_created_{task_type}", count, tags=self.stats_tags)
                # Same metric with tagging
                Stats.incr("task_instance_created", count, tags={**self.stats_tags, "task_type": task_type})
            session.flush()
        except IntegrityError:
            self.log.info(
                "Hit IntegrityError while creating the TIs for %s - %s",
                dag_id,
                run_id,
                exc_info=True,
            )
            self.log.info("Doing session rollback.")
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()

    def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]:
        """Check if task increased or reduced in length and handle appropriately.

        Task instances that do not already exist are created and returned if
        possible. Expansion only happens if all upstreams are ready; otherwise
        we delay expansion to the "last resort". See comments at the call site
        for more details.
        """
        from airflow.settings import task_instance_mutation_hook

        try:
            total_length = task.get_mapped_ti_count(self.run_id, session=session)
        except NotMapped:
            return  # Not a mapped task, don't need to do anything.
        except NotFullyPopulated:
            return  # Upstreams not ready, don't need to revise this yet.

        query = session.scalars(
            select(TI.map_index).where(
                TI.dag_id == self.dag_id,
                TI.task_id == task.task_id,
                TI.run_id == self.run_id,
            )
        )
        existing_indexes = set(query)

        removed_indexes = existing_indexes.difference(range(total_length))
        if removed_indexes:
            session.execute(
                update(TI)
                .where(
                    TI.dag_id == self.dag_id,
                    TI.task_id == task.task_id,
                    TI.run_id == self.run_id,
                    TI.map_index.in_(removed_indexes),
                )
                .values(state=TaskInstanceState.REMOVED)
            )
            session.flush()

        for index in range(total_length):
            if index in existing_indexes:
                continue
            ti = TI(task, run_id=self.run_id, map_index=index, state=None)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(task)
            session.flush()
            yield ti

    @staticmethod
    def get_run(session: Session, dag_id: str, execution_date: datetime) -> DagRun | None:
        """
        Get a single DAG Run.

        :meta private:
        :param session: Sqlalchemy ORM Session
        :param dag_id: DAG ID
        :param execution_date: execution date
        :return: DagRun corresponding to the given dag_id and execution date
            if one exists. None otherwise.
        """
        warnings.warn(
            "This method is deprecated. Please use SQLAlchemy directly",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return session.scalar(
            select(DagRun).where(
                DagRun.dag_id == dag_id,
                DagRun.external_trigger == False,  # noqa: E712
                DagRun.execution_date == execution_date,
            )
        )

    @property
    def is_backfill(self) -> bool:
        return self.run_type == DagRunType.BACKFILL_JOB

    @classmethod
    @provide_session
    def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]:
        """Return the latest DagRun for each DAG."""
        subquery = (
            select(cls.dag_id, func.max(cls.execution_date).label("execution_date"))
            .group_by(cls.dag_id)
            .subquery()
        )
        return session.scalars(
            select(cls).join(
                subquery,
                and_(cls.dag_id == subquery.c.dag_id, cls.execution_date == subquery.c.execution_date),
            )
        ).all()

    @provide_session
    def schedule_tis(
        self,
        schedulable_tis: Iterable[TI],
        session: Session = NEW_SESSION,
        max_tis_per_query: int | None = None,
    ) -> int:
        """
        Set the given task instances into the scheduled state.

        Each element of ``schedulable_tis`` should have its ``task`` attribute already set.

        Any EmptyOperator without callbacks or outlets is instead set straight to the success state.

        All the TIs should belong to this DagRun; but because this code is in the hot path, this is
        not checked -- it is the caller's responsibility to call this function only with TIs from a
        single dag run.
        """
        # Get the list of TI IDs that do not need to be executed; these are
        # tasks using EmptyOperator and without on_execute_callback / on_success_callback.
        dummy_ti_ids = []
        schedulable_ti_ids = []
        for ti in schedulable_tis:
            if TYPE_CHECKING:
                assert ti.task
            if (
                ti.task.inherits_from_empty_operator
                and not ti.task.on_execute_callback
                and not ti.task.on_success_callback
                and not ti.task.outlets
            ):
                dummy_ti_ids.append((ti.task_id, ti.map_index))
            elif (
                ti.task.start_trigger is not None
                and ti.task.next_method is not None
                and not ti.task.on_execute_callback
                and not ti.task.on_success_callback
                and not ti.task.outlets
            ):
                if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
                    ti.try_number += 1
                ti.defer_task(
                    defer=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method),
                    session=session,
                )
            else:
                schedulable_ti_ids.append((ti.task_id, ti.map_index))

        count = 0

        if schedulable_ti_ids:
            schedulable_ti_ids_chunks = chunks(
                schedulable_ti_ids, max_tis_per_query or len(schedulable_ti_ids)
            )
            for schedulable_ti_ids_chunk in schedulable_ti_ids_chunks:
                count += session.execute(
                    update(TI)
                    .where(
                        TI.dag_id == self.dag_id,
                        TI.run_id == self.run_id,
                        tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids_chunk),
                    )
                    .values(
                        state=TaskInstanceState.SCHEDULED,
                        try_number=case(
                            (
                                or_(TI.state.is_(None), TI.state != TaskInstanceState.UP_FOR_RESCHEDULE),
                                TI.try_number + 1,
                            ),
                            else_=TI.try_number,
                        ),
                    )
                    .execution_options(synchronize_session=False)
                ).rowcount

        # Tasks using EmptyOperator should not be executed, mark them as success
        if dummy_ti_ids:
            dummy_ti_ids_chunks = chunks(dummy_ti_ids, max_tis_per_query or len(dummy_ti_ids))
            for dummy_ti_ids_chunk in dummy_ti_ids_chunks:
                count += session.execute(
                    update(TI)
                    .where(
                        TI.dag_id == self.dag_id,
                        TI.run_id == self.run_id,
                        tuple_in_condition((TI.task_id, TI.map_index), dummy_ti_ids_chunk),
                    )
                    .values(
                        state=TaskInstanceState.SUCCESS,
                        start_date=timezone.utcnow(),
                        end_date=timezone.utcnow(),
                        duration=0,
                    )
                    .execution_options(
                        synchronize_session=False,
                    )
                ).rowcount

        return count
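    # Sketch of the intended call pattern (see also ``update_state`` above;
    # session handling is elided):
    #
    #     info = dag_run.task_instance_scheduling_decisions(session)
    #     count = dag_run.schedule_tis(info.schedulable_tis, session=session)
    #     session.commit()  # persist the SCHEDULED / SUCCESS updates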

    @provide_session
    def get_log_template(self, *, session: Session = NEW_SESSION) -> LogTemplate | LogTemplatePydantic:
        return DagRun._get_log_template(log_template_id=self.log_template_id, session=session)

    @staticmethod
    @internal_api_call
    @provide_session
    def _get_log_template(
        log_template_id: int | None, session: Session = NEW_SESSION
    ) -> LogTemplate | LogTemplatePydantic:
        template: LogTemplate | None
        if log_template_id is None:  # DagRun created before LogTemplate introduction.
            template = session.scalar(select(LogTemplate).order_by(LogTemplate.id).limit(1))
        else:
            template = session.get(LogTemplate, log_template_id)
        if template is None:
            raise AirflowException(
                f"No log_template entry found for ID {log_template_id!r}. "
                f"Please make sure you set up the metadatabase correctly."
            )
        return template

    @provide_session
    def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str:
        warnings.warn(
            "This method is deprecated. Please use get_log_template instead.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return self.get_log_template(session=session).filename


class DagRunNote(Base):
    """For storage of arbitrary notes concerning the dagrun instance."""

    __tablename__ = "dag_run_note"

    user_id = Column(
        Integer,
        ForeignKey("ab_user.id", name="dag_run_note_user_fkey"),
        nullable=True,
    )
    dag_run_id = Column(Integer, primary_key=True, nullable=False)
    content = Column(String(1000).with_variant(Text(1000), "mysql"))
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dag_run = relationship("DagRun", back_populates="dag_run_note")

    __table_args__ = (
        PrimaryKeyConstraint("dag_run_id", name="dag_run_note_pkey"),
        ForeignKeyConstraint(
            (dag_run_id,),
            ["dag_run.id"],
            name="dag_run_note_dr_fkey",
            ondelete="CASCADE",
        ),
    )

    def __init__(self, content, user_id=None):
        self.content = content
        self.user_id = user_id

    def __repr__(self):
        # Report this model's own columns; DagRunNote has no dag_id/run_id/map_index attributes.
        return f"<{self.__class__.__name__}: dag_run_id={self.dag_run_id}, user_id={self.user_id}>"