Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/dagrun.py: 25%

601 statements  

coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import itertools
import os
import warnings
from collections import defaultdict
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload

from sqlalchemy import (
    Boolean,
    Column,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PickleType,
    PrimaryKeyConstraint,
    String,
    Text,
    UniqueConstraint,
    and_,
    func,
    or_,
    text,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, Session, declared_attr, joinedload, relationship, synonym
from sqlalchemy.sql.expression import false, select, true

from airflow import settings
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound
from airflow.listeners.listener import get_listener_manager
from airflow.models.abstractoperator import NotMapped
from airflow.models.base import Base, StringID
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.tasklog import LogTemplate
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
from airflow.typing_compat import Literal
from airflow.utils import timezone
from airflow.utils.helpers import chunks, is_container, prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, tuple_in_condition, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import NOTSET, ArgNotSet, DagRunType

if TYPE_CHECKING:
    from airflow.models.dag import DAG
    from airflow.models.operator import Operator

    CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
    TaskCreator = Callable[[Operator, Iterable[int]], CreatedTasks]


class TISchedulingDecision(NamedTuple):
    """Type of return for DagRun.task_instance_scheduling_decisions."""

    tis: list[TI]
    schedulable_tis: list[TI]
    changed_tis: bool
    unfinished_tis: list[TI]
    finished_tis: list[TI]


def _creator_note(val):
    """Custom creator for the ``note`` association proxy."""
    if isinstance(val, str):
        return DagRunNote(content=val)
    elif isinstance(val, dict):
        return DagRunNote(**val)
    else:
        return DagRunNote(*val)
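
# Illustrative sketch (not part of the original module): the three input forms
# ``_creator_note`` accepts when a value is assigned to ``DagRun.note``; the
# note text and ``user_id`` values below are made up.
#
#     dag_run.note = "re-run after upstream data fix"        # str -> content
#     dag_run.note = {"content": "re-run", "user_id": 42}    # dict -> kwargs
#     dag_run.note = ("re-run", 42)                          # tuple -> args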

class DagRun(Base, LoggingMixin):
    """Invocation instance of a DAG.

    A DAG run can be created by the scheduler (i.e. scheduled runs), or by an
    external trigger (i.e. manual runs).
    """

    __tablename__ = "dag_run"

    id = Column(Integer, primary_key=True)
    dag_id = Column(StringID(), nullable=False)
    queued_at = Column(UtcDateTime)
    execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    start_date = Column(UtcDateTime)
    end_date = Column(UtcDateTime)
    _state = Column("state", String(50), default=State.QUEUED)
    run_id = Column(StringID(), nullable=False)
    creating_job_id = Column(Integer)
    external_trigger = Column(Boolean, default=True)
    run_type = Column(String(50), nullable=False)
    conf = Column(PickleType)
    # These two must be either both NULL or both datetime.
    data_interval_start = Column(UtcDateTime)
    data_interval_end = Column(UtcDateTime)
    # When a scheduler last attempted to schedule TIs for this DagRun.
    last_scheduling_decision = Column(UtcDateTime)
    dag_hash = Column(String(32))
    # Foreign key to LogTemplate. DagRun rows created prior to this column's
    # existence have this set to NULL. Later rows automatically populate this on
    # insert to point to the latest LogTemplate entry.
    log_template_id = Column(
        Integer,
        ForeignKey("log_template.id", name="task_instance_log_template_id_fkey", ondelete="NO ACTION"),
        default=select(func.max(LogTemplate.__table__.c.id)),
    )
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)

    # Remove this `if` after upgrading Sphinx-AutoAPI
    if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
        dag: DAG | None
    else:
        dag: DAG | None = None

    __table_args__ = (
        Index("dag_id_state", dag_id, _state),
        UniqueConstraint("dag_id", "execution_date", name="dag_run_dag_id_execution_date_key"),
        UniqueConstraint("dag_id", "run_id", name="dag_run_dag_id_run_id_key"),
        Index("idx_last_scheduling_decision", last_scheduling_decision),
        Index("idx_dag_run_dag_id", dag_id),
        Index(
            "idx_dag_run_running_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='running'"),
            mssql_where=text("state='running'"),
            sqlite_where=text("state='running'"),
        ),
        # Since MySQL lacks filtered/partial indices, this creates a duplicate
        # index there. Not the end of the world.
        Index(
            "idx_dag_run_queued_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='queued'"),
            mssql_where=text("state='queued'"),
            sqlite_where=text("state='queued'"),
        ),
    )

    task_instances = relationship(
        TI, back_populates="dag_run", cascade="save-update, merge, delete, delete-orphan"
    )
    dag_model = relationship(
        "DagModel",
        primaryjoin="foreign(DagRun.dag_id) == DagModel.dag_id",
        uselist=False,
        viewonly=True,
    )
    dag_run_note = relationship(
        "DagRunNote",
        back_populates="dag_run",
        uselist=False,
        cascade="all, delete, delete-orphan",
    )
    note = association_proxy("dag_run_note", "content", creator=_creator_note)

    DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
        "scheduler",
        "max_dagruns_per_loop_to_schedule",
        fallback=20,
    )

    def __init__(
        self,
        dag_id: str | None = None,
        run_id: str | None = None,
        queued_at: datetime | None | ArgNotSet = NOTSET,
        execution_date: datetime | None = None,
        start_date: datetime | None = None,
        external_trigger: bool | None = None,
        conf: Any | None = None,
        state: DagRunState | None = None,
        run_type: str | None = None,
        dag_hash: str | None = None,
        creating_job_id: int | None = None,
        data_interval: tuple[datetime, datetime] | None = None,
    ):
        if data_interval is None:
            # Legacy: only happens for runs created prior to Airflow 2.2.
            self.data_interval_start = self.data_interval_end = None
        else:
            self.data_interval_start, self.data_interval_end = data_interval

        self.dag_id = dag_id
        self.run_id = run_id
        self.execution_date = execution_date
        self.start_date = start_date
        self.external_trigger = external_trigger
        self.conf = conf or {}
        if state is not None:
            self.state = state
        if queued_at is NOTSET:
            self.queued_at = timezone.utcnow() if state == State.QUEUED else None
        else:
            self.queued_at = queued_at
        self.run_type = run_type
        self.dag_hash = dag_hash
        self.creating_job_id = creating_job_id
        super().__init__()

    def __repr__(self):
        return (
            "<DagRun {dag_id} @ {execution_date}: {run_id}, state:{state}, "
            "queued_at: {queued_at}. externally triggered: {external_trigger}>"
        ).format(
            dag_id=self.dag_id,
            execution_date=self.execution_date,
            run_id=self.run_id,
            state=self.state,
            queued_at=self.queued_at,
            external_trigger=self.external_trigger,
        )

    @property
    def stats_tags(self) -> dict[str, str]:
        return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type})

    @property
    def logical_date(self) -> datetime:
        return self.execution_date

    def get_state(self):
        return self._state

    def set_state(self, state: DagRunState):
        if state not in State.dag_states:
            raise ValueError(f"invalid DagRun state: {state}")
        if self._state != state:
            self._state = state
            self.end_date = timezone.utcnow() if self._state in State.finished else None
            if state == State.QUEUED:
                self.queued_at = timezone.utcnow()

    @declared_attr
    def state(self):
        return synonym("_state", descriptor=property(self.get_state, self.set_state))
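
    # Sketch of the side effects an assignment to ``state`` triggers through
    # the ``state`` synonym above; ``dag_run`` is a hypothetical persisted run.
    #
    #     dag_run.state = DagRunState.QUEUED    # also stamps queued_at
    #     dag_run.state = DagRunState.SUCCESS   # finished state, stamps end_date
    #     dag_run.state = "invalid"             # raises ValueError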

    @provide_session
    def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
        """
        Reloads the current dagrun from the database.

        :param session: database session
        """
        dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one()
        self.id = dr.id
        self.state = dr.state

    @classmethod
    @provide_session
    def active_runs_of_dags(
        cls,
        dag_ids: Iterable[str] | None = None,
        only_running: bool = False,
        session: Session = NEW_SESSION,
    ) -> dict[str, int]:
        """Get the number of active dag runs for each dag."""
        query = session.query(cls.dag_id, func.count("*"))
        if dag_ids is not None:
            # De-duplicate dag_ids with set(); SQLAlchemy's in_() accepts the
            # resulting set directly.
            query = query.filter(cls.dag_id.in_(set(dag_ids)))
        if only_running:
            query = query.filter(cls.state == State.RUNNING)
        else:
            query = query.filter(cls.state.in_([State.RUNNING, State.QUEUED]))
        query = query.group_by(cls.dag_id)
        return {dag_id: count for dag_id, count in query.all()}
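
    # Example (dag ids are hypothetical): counts queued and running runs per
    # DAG, e.g. ``{"etl_daily": 2, "reporting": 1}``; DAGs with no active runs
    # simply have no key in the result.
    #
    #     counts = DagRun.active_runs_of_dags(dag_ids=["etl_daily", "reporting"])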

    @classmethod
    def next_dagruns_to_examine(
        cls,
        state: DagRunState,
        session: Session,
        max_number: int | None = None,
    ) -> Query:
        """
        Return the next DagRuns that the scheduler should attempt to schedule.

        This will return zero or more DagRun rows that are row-level-locked with a
        "SELECT ... FOR UPDATE" query. You should ensure that any scheduling decisions
        are made in a single transaction -- as soon as the transaction is committed,
        the rows will be unlocked.
        """
        from airflow.models.dag import DagModel

        if max_number is None:
            max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

        # TODO: Bake this query, it is run _a lot_
        query = (
            session.query(cls)
            .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
            .filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
            .join(DagModel, DagModel.dag_id == cls.dag_id)
            .filter(DagModel.is_paused == false(), DagModel.is_active == true())
        )
        if state == State.QUEUED:
            # For dag runs in the queued state, we check if they have reached the
            # max_active_runs limit and, if so, we drop them.
            running_drs = (
                session.query(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
                .filter(DagRun.state == DagRunState.RUNNING)
                .group_by(DagRun.dag_id)
                .subquery()
            )
            query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).filter(
                func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs
            )
        query = query.order_by(
            nulls_first(cls.last_scheduling_decision, session=session),
            cls.execution_date,
        )

        if not settings.ALLOW_FUTURE_EXEC_DATES:
            query = query.filter(DagRun.execution_date <= func.now())

        return with_row_locks(
            query.limit(max_number), of=cls, session=session, **skip_locked(session=session)
        )
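
    # Consumer sketch (simplified; ``create_session`` is from
    # airflow.utils.session): the SELECT ... FOR UPDATE row locks are held
    # only until the surrounding transaction commits.
    #
    #     with create_session() as session:
    #         for run in DagRun.next_dagruns_to_examine(DagRunState.QUEUED, session):
    #             ...  # make scheduling decisions inside this transaction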

    @classmethod
    @provide_session
    def find(
        cls,
        dag_id: str | list[str] | None = None,
        run_id: Iterable[str] | None = None,
        execution_date: datetime | Iterable[datetime] | None = None,
        state: DagRunState | None = None,
        external_trigger: bool | None = None,
        no_backfills: bool = False,
        run_type: DagRunType | None = None,
        session: Session = NEW_SESSION,
        execution_start_date: datetime | None = None,
        execution_end_date: datetime | None = None,
    ) -> list[DagRun]:
        """
        Return a list of dag runs matching the given search criteria.

        :param dag_id: the dag_id or list of dag_id to find dag runs for
        :param run_id: defines the run id for this dag run
        :param run_type: type of DagRun
        :param execution_date: the execution date
        :param state: the state of the dag run
        :param external_trigger: whether this dag run is externally triggered
        :param no_backfills: return no backfills (True), return all (False).
            Defaults to False
        :param session: database session
        :param execution_start_date: dag runs that were executed from this date
        :param execution_end_date: dag runs that were executed until this date
        """
        qry = session.query(cls)
        dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
        if dag_ids:
            qry = qry.filter(cls.dag_id.in_(dag_ids))

        if is_container(run_id):
            qry = qry.filter(cls.run_id.in_(run_id))
        elif run_id is not None:
            qry = qry.filter(cls.run_id == run_id)
        if is_container(execution_date):
            qry = qry.filter(cls.execution_date.in_(execution_date))
        elif execution_date is not None:
            qry = qry.filter(cls.execution_date == execution_date)
        if execution_start_date and execution_end_date:
            qry = qry.filter(cls.execution_date.between(execution_start_date, execution_end_date))
        elif execution_start_date:
            qry = qry.filter(cls.execution_date >= execution_start_date)
        elif execution_end_date:
            qry = qry.filter(cls.execution_date <= execution_end_date)
        if state:
            qry = qry.filter(cls.state == state)
        if external_trigger is not None:
            qry = qry.filter(cls.external_trigger == external_trigger)
        if run_type:
            qry = qry.filter(cls.run_type == run_type)
        if no_backfills:
            qry = qry.filter(cls.run_type != DagRunType.BACKFILL_JOB)

        return qry.order_by(cls.execution_date).all()
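
    # Example query (dag id and dates are made up): all successful scheduled
    # runs of "etl_daily" in January 2023, ordered by execution_date.
    #
    #     runs = DagRun.find(
    #         dag_id="etl_daily",
    #         state=DagRunState.SUCCESS,
    #         run_type=DagRunType.SCHEDULED,
    #         execution_start_date=timezone.datetime(2023, 1, 1),
    #         execution_end_date=timezone.datetime(2023, 1, 31),
    #     )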

    @classmethod
    @provide_session
    def find_duplicate(
        cls,
        dag_id: str,
        run_id: str,
        execution_date: datetime,
        session: Session = NEW_SESSION,
    ) -> DagRun | None:
        """
        Return an existing run for the DAG with a specific run_id or execution_date.

        *None* is returned if no such DAG run is found.

        :param dag_id: the dag_id to find duplicates for
        :param run_id: defines the run id for this dag run
        :param execution_date: the execution date
        :param session: database session
        """
        return (
            session.query(cls)
            .filter(
                cls.dag_id == dag_id,
                or_(cls.run_id == run_id, cls.execution_date == execution_date),
            )
            .one_or_none()
        )

    @staticmethod
    def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
        """Generate Run ID based on Run Type and Execution Date."""
        # _Ensure_ run_type is a DagRunType, not just a string from user code
        return DagRunType(run_type).generate_run_id(execution_date)
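
    # Illustrative result; the exact format is defined by
    # DagRunType.generate_run_id (run-type prefix + "__" + ISO timestamp):
    #
    #     DagRun.generate_run_id(DagRunType.MANUAL, timezone.datetime(2023, 6, 7))
    #     # -> "manual__2023-06-07T00:00:00+00:00"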

    @provide_session
    def get_task_instances(
        self,
        state: Iterable[TaskInstanceState | None] | None = None,
        session: Session = NEW_SESSION,
    ) -> list[TI]:
        """Returns the task instances for this dag run."""
        tis = (
            session.query(TI)
            .options(joinedload(TI.dag_run))
            .filter(
                TI.dag_id == self.dag_id,
                TI.run_id == self.run_id,
            )
        )

        if state:
            if isinstance(state, str):
                tis = tis.filter(TI.state == state)
            else:
                # this is required to deal with NULL values
                if State.NONE in state:
                    if all(x is None for x in state):
                        tis = tis.filter(TI.state.is_(None))
                    else:
                        not_none_state = [s for s in state if s]
                        tis = tis.filter(or_(TI.state.in_(not_none_state), TI.state.is_(None)))
                else:
                    tis = tis.filter(TI.state.in_(state))

        if self.dag and self.dag.partial:
            tis = tis.filter(TI.task_id.in_(self.dag.task_ids))
        return tis.all()

    @provide_session
    def get_task_instance(
        self,
        task_id: str,
        session: Session = NEW_SESSION,
        *,
        map_index: int = -1,
    ) -> TI | None:
        """
        Returns the task instance specified by task_id for this dag run.

        :param task_id: the task id
        :param session: Sqlalchemy ORM Session
        """
        return (
            session.query(TI)
            .filter_by(dag_id=self.dag_id, run_id=self.run_id, task_id=task_id, map_index=map_index)
            .one_or_none()
        )

    def get_dag(self) -> DAG:
        """
        Returns the DAG associated with this DagRun.

        :return: DAG
        """
        if not self.dag:
            raise AirflowException(f"The DAG (.dag) for {self} needs to be set")

        return self.dag
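
    # Sketch (task id and index are hypothetical): fetch one expanded instance
    # of a mapped task; map_index=-1 (the default) addresses an unmapped task.
    #
    #     ti = dag_run.get_task_instance("transform", map_index=3)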

    @provide_session
    def get_previous_dagrun(
        self, state: DagRunState | None = None, session: Session = NEW_SESSION
    ) -> DagRun | None:
        """The previous DagRun, if there is one."""
        filters = [
            DagRun.dag_id == self.dag_id,
            DagRun.execution_date < self.execution_date,
        ]
        if state is not None:
            filters.append(DagRun.state == state)
        return session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first()

    @provide_session
    def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> DagRun | None:
        """The previous, SCHEDULED DagRun, if there is one."""
        return (
            session.query(DagRun)
            .filter(
                DagRun.dag_id == self.dag_id,
                DagRun.execution_date < self.execution_date,
                DagRun.run_type != DagRunType.MANUAL,
            )
            .order_by(DagRun.execution_date.desc())
            .first()
        )

    def _tis_for_dagrun_state(self, *, dag, tis):
        """
        Return the collection of tasks that should be considered for evaluation of terminal dag run state.

        Teardown tasks are not considered for the purpose of dag run state by
        default, but users may opt a teardown in with ``on_failure_fail_dagrun``.
        """

        def is_effective_leaf(task):
            for down_task_id in task.downstream_task_ids:
                down_task = dag.get_task(down_task_id)
                if not down_task.is_teardown or down_task.on_failure_fail_dagrun:
                    # we found a downstream task that is not ignorable; not a leaf
                    return False
            # all downstreams (if any) are ignorable, so evaluate whether the
            # task is itself ignorable
            return not task.is_teardown or task.on_failure_fail_dagrun

        leaf_task_ids = {x.task_id for x in dag.tasks if is_effective_leaf(x)}
        if not leaf_task_ids:
            # can happen if the dag is exclusively teardown tasks
            leaf_task_ids = {x.task_id for x in dag.tasks if not x.downstream_list}
        leaf_tis = {ti for ti in tis if ti.task_id in leaf_task_ids if ti.state != TaskInstanceState.REMOVED}
        return leaf_tis
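
    # Worked example for the leaf selection above, assuming the setup/teardown
    # API (``as_teardown``) and hypothetical task ids: with ``work >> cleanup``
    # where ``cleanup`` is a teardown left at on_failure_fail_dagrun=False,
    # ``work`` is the effective leaf, so a failing ``cleanup`` alone does not
    # fail the run.
    #
    #     work = EmptyOperator(task_id="work")
    #     cleanup = EmptyOperator(task_id="cleanup").as_teardown()
    #     work >> cleanup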

    @provide_session
    def update_state(
        self, session: Session = NEW_SESSION, execute_callbacks: bool = True
    ) -> tuple[list[TI], DagCallbackRequest | None]:
        """
        Determine the overall state of the DagRun based on the state of its TaskInstances.

        :param session: Sqlalchemy ORM Session
        :param execute_callbacks: Should dag callbacks (success/failure, SLA etc.) be invoked
            directly (default: true) or recorded as a pending request in the ``returned_callback``
            property
        :return: Tuple containing tis that can be scheduled in the current loop and a
            ``returned_callback`` that needs to be executed
        """
        # Callback to execute in case of Task Failures
        callback: DagCallbackRequest | None = None

        class _UnfinishedStates(NamedTuple):
            tis: Sequence[TI]

            @classmethod
            def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
                return cls(tis=unfinished_tis)

            @property
            def should_schedule(self) -> bool:
                return (
                    bool(self.tis)
                    and all(not t.task.depends_on_past for t in self.tis)
                    and all(t.task.max_active_tis_per_dag is None for t in self.tis)
                    and all(t.task.max_active_tis_per_dagrun is None for t in self.tis)
                    and all(t.state != TaskInstanceState.DEFERRED for t in self.tis)
                )

            def recalculate(self) -> _UnfinishedStates:
                return self._replace(tis=[t for t in self.tis if t.state in State.unfinished])

        start_dttm = timezone.utcnow()
        self.last_scheduling_decision = start_dttm
        with Stats.timer(
            f"dagrun.dependency-check.{self.dag_id}",
            tags=self.stats_tags,
        ):
            dag = self.get_dag()
            info = self.task_instance_scheduling_decisions(session)

            tis = info.tis
            schedulable_tis = info.schedulable_tis
            changed_tis = info.changed_tis
            finished_tis = info.finished_tis
            unfinished = _UnfinishedStates.calculate(info.unfinished_tis)

            if unfinished.should_schedule:
                are_runnable_tasks = schedulable_tis or changed_tis
                # small speed up
                if not are_runnable_tasks:
                    are_runnable_tasks, changed_by_upstream = self._are_premature_tis(
                        unfinished.tis, finished_tis, session
                    )
                    if changed_by_upstream:  # Something changed, we need to recalculate!
                        unfinished = unfinished.recalculate()

        tis_for_dagrun_state = self._tis_for_dagrun_state(dag=dag, tis=tis)

        # if all tasks finished and at least one failed, the run failed
        if not unfinished.tis and any(x.state in State.failed_states for x in tis_for_dagrun_state):
            self.log.error("Marking run %s failed", self)
            self.set_state(DagRunState.FAILED)
            self.notify_dagrun_state_changed(msg="task_failure")

            if execute_callbacks:
                dag.handle_callback(self, success=False, reason="task_failure", session=session)
            elif dag.has_on_failure_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                    msg="task_failure",
                )

        # if all leaves succeeded and no unfinished tasks, the run succeeded
        elif not unfinished.tis and all(x.state in State.success_states for x in tis_for_dagrun_state):
            self.log.info("Marking run %s successful", self)
            self.set_state(DagRunState.SUCCESS)
            self.notify_dagrun_state_changed(msg="success")

            if execute_callbacks:
                dag.handle_callback(self, success=True, reason="success", session=session)
            elif dag.has_on_success_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=False,
                    processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                    msg="success",
                )

        # if *all tasks* are deadlocked, the run failed
        elif unfinished.should_schedule and not are_runnable_tasks:
            self.log.error("Task deadlock (no runnable tasks); marking run %s failed", self)
            self.set_state(DagRunState.FAILED)
            self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")

            if execute_callbacks:
                dag.handle_callback(self, success=False, reason="all_tasks_deadlocked", session=session)
            elif dag.has_on_failure_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                    msg="all_tasks_deadlocked",
                )

        # finally, if the leaves aren't done, the dag is still running
        else:
            self.set_state(DagRunState.RUNNING)

        if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
            msg = (
                "DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
                "run_start_date=%s, run_end_date=%s, run_duration=%s, "
                "state=%s, external_trigger=%s, run_type=%s, "
                "data_interval_start=%s, data_interval_end=%s, dag_hash=%s"
            )
            self.log.info(
                msg,
                self.dag_id,
                self.execution_date,
                self.run_id,
                self.start_date,
                self.end_date,
                (self.end_date - self.start_date).total_seconds()
                if self.start_date and self.end_date
                else None,
                self._state,
                self.external_trigger,
                self.run_type,
                self.data_interval_start,
                self.data_interval_end,
                self.dag_hash,
            )
            session.flush()

        self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis)
        self._emit_duration_stats_for_finished_state()

        session.merge(self)
        # We do not flush here for performance reasons (it increases the query count by ~20).

        return schedulable_tis, callback
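
    # Scheduler-side sketch (simplified; the real callers live in the
    # scheduling jobs): ``update_state`` both advances the run state and hands
    # back work.
    #
    #     schedulable_tis, callback = dag_run.update_state(execute_callbacks=False)
    #     dag_run.schedule_tis(schedulable_tis)
    #     if callback:
    #         ...  # hand the DagCallbackRequest to the DAG processor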

    @provide_session
    def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
        tis = self.get_task_instances(session=session, state=State.task_states)
        self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))

        def _filter_tis_and_exclude_removed(dag: DAG, tis: list[TI]) -> Iterable[TI]:
            """Populate ``ti.task`` while excluding those missing one, marking them as REMOVED."""
            for ti in tis:
                try:
                    ti.task = dag.get_task(ti.task_id)
                except TaskNotFound:
                    if ti.state != State.REMOVED:
                        self.log.error("Failed to get task for ti %s. Marking it as removed.", ti)
                        ti.state = State.REMOVED
                        session.flush()
                else:
                    yield ti

        tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))

        unfinished_tis = [t for t in tis if t.state in State.unfinished]
        finished_tis = [t for t in tis if t.state in State.finished]
        if unfinished_tis:
            schedulable_tis = [ut for ut in unfinished_tis if ut.state in SCHEDULEABLE_STATES]
            self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(schedulable_tis))
            schedulable_tis, changed_tis, expansion_happened = self._get_ready_tis(
                schedulable_tis,
                finished_tis,
                session=session,
            )

            # During expansion, we may change some tis into non-schedulable
            # states, so we need to re-compute.
            if expansion_happened:
                changed_tis = True
                new_unfinished_tis = [t for t in unfinished_tis if t.state in State.unfinished]
                finished_tis.extend(t for t in unfinished_tis if t.state in State.finished)
                unfinished_tis = new_unfinished_tis
        else:
            schedulable_tis = []
            changed_tis = False

        return TISchedulingDecision(
            tis=tis,
            schedulable_tis=schedulable_tis,
            changed_tis=changed_tis,
            unfinished_tis=unfinished_tis,
            finished_tis=finished_tis,
        )

    def notify_dagrun_state_changed(self, msg: str = ""):
        if self.state == DagRunState.RUNNING:
            get_listener_manager().hook.on_dag_run_running(dag_run=self, msg=msg)
        elif self.state == DagRunState.SUCCESS:
            get_listener_manager().hook.on_dag_run_success(dag_run=self, msg=msg)
        elif self.state == DagRunState.FAILED:
            get_listener_manager().hook.on_dag_run_failed(dag_run=self, msg=msg)
        # deliberately not notifying on QUEUED
        # we can't get all the state changes on SchedulerJob, BackfillJob
        # or LocalTaskJob, so we don't want to "falsely advertise" we notify about that

    def _get_ready_tis(
        self,
        schedulable_tis: list[TI],
        finished_tis: list[TI],
        session: Session,
    ) -> tuple[list[TI], bool, bool]:
        old_states = {}
        ready_tis: list[TI] = []
        changed_tis = False

        if not schedulable_tis:
            return ready_tis, changed_tis, False

        # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
        # `schedulable_tis` in place and have the `for` loop pick them up.)
        additional_tis: list[TI] = []
        dep_context = DepContext(
            flag_upstream_failed=True,
            ignore_unmapped_tasks=True,  # Ignore this Dep, as we will expand it if we can.
            finished_tis=finished_tis,
        )

        def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
            """Try to expand the ti, if needed.

            If the ti needs expansion, newly created task instances are
            returned as well as the original ti. The original ti is also
            modified in-place and assigned the ``map_index`` of 0.

            If the ti does not need expansion, either because the task is not
            mapped, or has already been expanded, *None* is returned.
            """
            if ti.map_index >= 0:  # Already expanded, we're good.
                return None

            from airflow.models.mappedoperator import MappedOperator

            if isinstance(ti.task, MappedOperator):
                # If we get here, it could be that we are moving from non-mapped to mapped
                # after task instance clearing or this ti is not yet expanded. Safe to clear
                # the db references.
                ti.clear_db_references(session=session)
            try:
                expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session)
            except NotMapped:  # Not a mapped task, nothing needed.
                return None
            if expanded_tis:
                return expanded_tis
            return ()

        # Check dependencies.
        expansion_happened = False
        # Set of task ids for which _revise_map_indexes_if_mapped was already done.
        revised_map_index_task_ids = set()
        for schedulable in itertools.chain(schedulable_tis, additional_tis):
            old_state = schedulable.state
            if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
                old_states[schedulable.key] = old_state
                continue
            # If schedulable is not yet expanded, try doing it now. This is
            # called in two places: First and ideally in the mini scheduler at
            # the end of LocalTaskJob, and then as an "expansion of last resort"
            # in the scheduler to ensure that the mapped task is correctly
            # expanded before executed. Also see _revise_map_indexes_if_mapped
            # docstring for additional information.
            new_tis = None
            if schedulable.map_index < 0:
                new_tis = _expand_mapped_task_if_needed(schedulable)
                if new_tis is not None:
                    additional_tis.extend(new_tis)
                    expansion_happened = True
            if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
                # It's enough to revise the map index once per task id;
                # checking the map index for each mapped task significantly slows down scheduling.
                if schedulable.task.task_id not in revised_map_index_task_ids:
                    ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
                    revised_map_index_task_ids.add(schedulable.task.task_id)
                ready_tis.append(schedulable)

        # Check if any ti changed state
        tis_filter = TI.filter_for_tis(old_states)
        if tis_filter is not None:
            fresh_tis = session.query(TI).filter(tis_filter).all()
            changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)

        return ready_tis, changed_tis, expansion_happened

    def _are_premature_tis(
        self,
        unfinished_tis: Sequence[TI],
        finished_tis: list[TI],
        session: Session,
    ) -> tuple[bool, bool]:
        dep_context = DepContext(
            flag_upstream_failed=True,
            ignore_in_retry_period=True,
            ignore_in_reschedule_period=True,
            finished_tis=finished_tis,
        )
        # there might be runnable tasks that are up for retry and, for some reason
        # (retry delay, etc.), are not ready yet, so we set the flags to count them in
        return (
            any(ut.are_dependencies_met(dep_context=dep_context, session=session) for ut in unfinished_tis),
            dep_context.have_changed_ti_states,
        )

    def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: list[TI]) -> None:
        """Emit the true scheduling delay stats.

        The true scheduling delay stat is defined as the time when the first
        task in the DAG starts minus the expected DAG run datetime.

        This helper method is used in ``update_state`` when the state of the
        DAG run is updated to a completed status (either success or failure).
        It finds the first started task within the DAG, calculates the run's
        expected start time based on the logical date and timetable, and gets
        the delay from the difference of these two values.

        The emitted data may contain outliers (e.g. when the first task was
        cleared, so the second task's start date will be used), but we can get
        rid of the outliers on the stats side through dashboards tooling.

        Note that the stat will only be emitted for scheduler-triggered DAG runs
        (i.e. when ``external_trigger`` is *False*).
        """
        if self.state == DagRunState.RUNNING:
            return
        if self.external_trigger:
            return
        if not finished_tis:
            return

        try:
            dag = self.get_dag()

            if not dag.timetable.periodic:
                # We can't emit this metric if there is no following schedule to calculate from!
                return

            try:
                first_start_date = min(ti.start_date for ti in finished_tis if ti.start_date)
            except ValueError:  # No start dates at all.
                pass
            else:
                # TODO: Logically, this should be DagRunInfo.run_after, but the
                # information is not stored on a DagRun, only before the actual
                # execution on DagModel.next_dagrun_create_after. We should add
                # a field on DagRun for this instead of relying on the run
                # always happening immediately after the data interval.
                data_interval_end = dag.get_run_data_interval(self).end
                true_delay = first_start_date - data_interval_end
                if true_delay.total_seconds() > 0:
                    Stats.timing(
                        f"dagrun.{dag.dag_id}.first_task_scheduling_delay", true_delay, tags=self.stats_tags
                    )
                    Stats.timing("dagrun.first_task_scheduling_delay", true_delay, tags=self.stats_tags)
        except Exception:
            self.log.warning("Failed to record first_task_scheduling_delay metric:", exc_info=True)

    def _emit_duration_stats_for_finished_state(self):
        if self.state == State.RUNNING:
            return
        if self.start_date is None:
            self.log.warning("Failed to record duration of %s: start_date is not set.", self)
            return
        if self.end_date is None:
            self.log.warning("Failed to record duration of %s: end_date is not set.", self)
            return

        duration = self.end_date - self.start_date
        timer_params = {"dt": duration, "tags": self.stats_tags}
        Stats.timing(f"dagrun.duration.{self.state.value}.{self.dag_id}", **timer_params)
        Stats.timing(f"dagrun.duration.{self.state.value}", **timer_params)
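
    # For a successful run of a hypothetical "etl_daily" DAG, the timers above
    # emit both "dagrun.duration.success.etl_daily" and
    # "dagrun.duration.success", each carrying the stats_tags defined earlier.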

    @provide_session
    def verify_integrity(self, *, session: Session = NEW_SESSION) -> None:
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to REMOVED or add the task if required.

        :param session: Sqlalchemy ORM Session
        """
        from airflow.settings import task_instance_mutation_hook

        # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
        # Note: Literal[True, False] instead of bool because otherwise it doesn't correctly find the overload.
        hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, "is_noop", False)

        dag = self.get_dag()
        task_ids = self._check_for_removed_or_restored_tasks(
            dag, task_instance_mutation_hook, session=session
        )

        def task_filter(task: Operator) -> bool:
            return task.task_id not in task_ids and (
                self.is_backfill
                or task.start_date <= self.execution_date
                and (task.end_date is None or self.execution_date <= task.end_date)
            )

        created_counts: dict[str, int] = defaultdict(int)
        task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)

        # Create the missing tasks, including mapped tasks
        tasks_to_create = (task for task in dag.task_dict.values() if task_filter(task))
        tis_to_create = self._create_tasks(tasks_to_create, task_creator, session=session)
        self._create_task_instances(self.dag_id, tis_to_create, created_counts, hook_is_noop, session=session)

    def _check_for_removed_or_restored_tasks(
        self, dag: DAG, ti_mutation_hook, *, session: Session
    ) -> set[str]:
        """
        Check for removed, restored, or missing tasks.

        :param dag: DAG object corresponding to the dagrun
        :param ti_mutation_hook: task_instance_mutation_hook function
        :param session: Sqlalchemy ORM Session

        :return: Task IDs in the DAG run
        """
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            ti_mutation_hook(ti)
            task_ids.add(ti.task_id)
            try:
                task = dag.get_task(ti.task_id)

                should_restore_task = (task is not None) and ti.state == State.REMOVED
                if should_restore_task:
                    self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                    Stats.incr(f"task_restored_to_dag.{dag.dag_id}", tags=self.stats_tags)
                    # Same metric with tagging
                    Stats.incr("task_restored_to_dag", tags={**self.stats_tags, "dag_id": dag.dag_id})
                    ti.state = State.NONE
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", tags=self.stats_tags)
                    # Same metric with tagging
                    Stats.incr("task_removed_from_dag", tags={**self.stats_tags, "dag_id": dag.dag_id})
                    ti.state = State.REMOVED
                continue

            try:
                num_mapped_tis = task.get_parse_time_mapped_ti_count()
            except NotMapped:
                continue
            except NotFullyPopulated:
                # What if it is _now_ dynamically mapped, but wasn't before?
                try:
                    total_length = task.get_mapped_ti_count(self.run_id, session=session)
                except NotFullyPopulated:
                    # Not all upstreams finished, so we can't tell what should be here. Remove everything.
                    if ti.map_index >= 0:
                        self.log.debug(
                            "Removing the unmapped TI '%s' as the mapping can't be resolved yet", ti
                        )
                        ti.state = State.REMOVED
                    continue
                # Upstreams finished, check there aren't any extras
                if ti.map_index >= total_length:
                    self.log.debug(
                        "Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
                        ti,
                        total_length,
                    )
                    ti.state = State.REMOVED
            else:
                # Check if the number of mapped literals has changed, and we need to mark this TI as removed.
                if ti.map_index >= num_mapped_tis:
                    self.log.debug(
                        "Removing task '%s' as the map_index is longer than the literal mapping list (%s)",
                        ti,
                        num_mapped_tis,
                    )
                    ti.state = State.REMOVED
                elif ti.map_index < 0:
                    self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
                    ti.state = State.REMOVED

        return task_ids

    @overload
    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[True],
    ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]:
        ...

    @overload
    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[False],
    ) -> Callable[[Operator, Iterable[int]], Iterator[TI]]:
        ...

    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[True, False],
    ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]:
        """
        Get the task creator function.

        This function also updates the created_counts dictionary with the number of tasks created.

        :param created_counts: Dictionary of task_type -> count of created TIs
        :param ti_mutation_hook: task_instance_mutation_hook function
        :param hook_is_noop: Whether the task_instance_mutation_hook is a noop
        """
        if hook_is_noop:

            def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[str, Any]]:
                created_counts[task.task_type] += 1
                for map_index in indexes:
                    yield TI.insert_mapping(self.run_id, task, map_index=map_index)

            creator = create_ti_mapping

        else:

            def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
                for map_index in indexes:
                    ti = TI(task, run_id=self.run_id, map_index=map_index)
                    ti_mutation_hook(ti)
                    created_counts[ti.operator] += 1
                    yield ti

            creator = create_ti
        return creator

    def _create_tasks(
        self,
        tasks: Iterable[Operator],
        task_creator: TaskCreator,
        *,
        session: Session,
    ) -> CreatedTasks:
        """
        Create missing tasks -- and expand any MappedOperator that has _only_ literals as input.

        :param tasks: Tasks to create jobs for in the DAG run
        :param task_creator: Function to create task instances
        """
        map_indexes: Iterable[int]
        for task in tasks:
            try:
                count = task.get_mapped_ti_count(self.run_id, session=session)
            except (NotMapped, NotFullyPopulated):
                map_indexes = (-1,)
            else:
                if count:
                    map_indexes = range(count)
                else:
                    # Make sure to always create at least one ti; this will be
                    # marked as REMOVED later at runtime.
                    map_indexes = (-1,)
            yield from task_creator(task, map_indexes)

    def _create_task_instances(
        self,
        dag_id: str,
        tasks: Iterator[dict[str, Any]] | Iterator[TI],
        created_counts: dict[str, int],
        hook_is_noop: bool,
        *,
        session: Session,
    ) -> None:
        """
        Create the necessary task instances from the given tasks.

        :param dag_id: DAG ID associated with the dagrun
        :param tasks: the tasks to create the task instances from
        :param created_counts: a dictionary mapping task type to the number of TIs
            created by the task creator
        :param hook_is_noop: whether the task_instance_mutation_hook is noop
        :param session: the session to use
        """
        # Fetch the information we need before handling the exception to avoid
        # PendingRollbackError due to the session being invalidated on exception
        # see https://github.com/apache/superset/pull/530
        run_id = self.run_id
        try:
            if hook_is_noop:
                session.bulk_insert_mappings(TI, tasks)
            else:
                session.bulk_save_objects(tasks)

            for task_type, count in created_counts.items():
                Stats.incr(f"task_instance_created-{task_type}", count, tags=self.stats_tags)
                # Same metric with tagging
                Stats.incr("task_instance_created", count, tags={**self.stats_tags, "task_type": task_type})
            session.flush()
        except IntegrityError:
            self.log.info(
                "Hit IntegrityError while creating the TIs for %s - %s",
                dag_id,
                run_id,
                exc_info=True,
            )
            self.log.info("Doing session rollback.")
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()

    def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]:
        """Check if task increased or reduced in length and handle appropriately.

        Task instances that do not already exist are created and returned if
        possible. Expansion only happens if all upstreams are ready; otherwise
        we delay expansion to the "last resort". See comments at the call site
        for more details.
        """
        from airflow.settings import task_instance_mutation_hook

        try:
            total_length = task.get_mapped_ti_count(self.run_id, session=session)
        except NotMapped:
            return  # Not a mapped task, don't need to do anything.
        except NotFullyPopulated:
            return  # Upstreams not ready, don't need to revise this yet.

        query = session.query(TI.map_index).filter(
            TI.dag_id == self.dag_id,
            TI.task_id == task.task_id,
            TI.run_id == self.run_id,
        )
        existing_indexes = {i for (i,) in query}

        removed_indexes = existing_indexes.difference(range(total_length))
        if removed_indexes:
            session.query(TI).filter(
                TI.dag_id == self.dag_id,
                TI.task_id == task.task_id,
                TI.run_id == self.run_id,
                TI.map_index.in_(removed_indexes),
            ).update({TI.state: TaskInstanceState.REMOVED})
            session.flush()

        for index in range(total_length):
            if index in existing_indexes:
                continue
            ti = TI(task, run_id=self.run_id, map_index=index, state=None)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(task)
            session.flush()
            yield ti

    @staticmethod
    def get_run(session: Session, dag_id: str, execution_date: datetime) -> DagRun | None:
        """
        Get a single DAG Run.

        :meta private:
        :param session: Sqlalchemy ORM Session
        :param dag_id: DAG ID
        :param execution_date: execution date
        :return: DagRun corresponding to the given dag_id and execution date
            if one exists. None otherwise.
        """
        warnings.warn(
            "This method is deprecated. Please use SQLAlchemy directly",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return (
            session.query(DagRun)
            .filter(
                DagRun.dag_id == dag_id,
                DagRun.external_trigger == False,  # noqa
                DagRun.execution_date == execution_date,
            )
            .first()
        )

    @property
    def is_backfill(self) -> bool:
        return self.run_type == DagRunType.BACKFILL_JOB

    @classmethod
    @provide_session
    def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]:
        """Returns the latest DagRun for each DAG."""
        subquery = (
            session.query(cls.dag_id, func.max(cls.execution_date).label("execution_date"))
            .group_by(cls.dag_id)
            .subquery()
        )
        return (
            session.query(cls)
            .join(
                subquery,
                and_(cls.dag_id == subquery.c.dag_id, cls.execution_date == subquery.c.execution_date),
            )
            .all()
        )
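
    # The join above is the classic greatest-n-per-group pattern; roughly this
    # SQL (illustrative):
    #
    #     SELECT dag_run.*
    #     FROM dag_run
    #     JOIN (
    #         SELECT dag_id, MAX(execution_date) AS execution_date
    #         FROM dag_run GROUP BY dag_id
    #     ) latest ON dag_run.dag_id = latest.dag_id
    #             AND dag_run.execution_date = latest.execution_date;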

    @provide_session
    def schedule_tis(
        self,
        schedulable_tis: Iterable[TI],
        session: Session = NEW_SESSION,
        max_tis_per_query: int | None = None,
    ) -> int:
        """
        Set the given task instances into the scheduled state.

        Each element of ``schedulable_tis`` should have its ``task`` attribute already set.

        Any EmptyOperator without callbacks or outlets is instead set straight to the success state.

        All the TIs should belong to this DagRun, but this code is in the hot-path, so this is not
        checked -- it is the caller's responsibility to call this function only with TIs from a single
        dag run.
        """
        # Get the list of TI IDs that do not need to be executed; these are tasks
        # using EmptyOperator and without on_execute_callback / on_success_callback.
        dummy_ti_ids = []
        schedulable_ti_ids = []
        for ti in schedulable_tis:
            if (
                ti.task.inherits_from_empty_operator
                and not ti.task.on_execute_callback
                and not ti.task.on_success_callback
                and not ti.task.outlets
            ):
                dummy_ti_ids.append(ti.task_id)
            else:
                schedulable_ti_ids.append((ti.task_id, ti.map_index))

        count = 0

        if schedulable_ti_ids:
            schedulable_ti_ids_chunks = chunks(
                schedulable_ti_ids, max_tis_per_query or len(schedulable_ti_ids)
            )
            for schedulable_ti_ids_chunk in schedulable_ti_ids_chunks:
                count += (
                    session.query(TI)
                    .filter(
                        TI.dag_id == self.dag_id,
                        TI.run_id == self.run_id,
                        tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids_chunk),
                    )
                    .update({TI.state: State.SCHEDULED}, synchronize_session=False)
                )

        # Tasks using EmptyOperator should not be executed, mark them as success
        if dummy_ti_ids:
            dummy_ti_ids_chunks = chunks(dummy_ti_ids, max_tis_per_query or len(dummy_ti_ids))
            for dummy_ti_ids_chunk in dummy_ti_ids_chunks:
                count += (
                    session.query(TI)
                    .filter(
                        TI.dag_id == self.dag_id,
                        TI.run_id == self.run_id,
                        TI.task_id.in_(dummy_ti_ids_chunk),
                    )
                    .update(
                        {
                            TI.state: State.SUCCESS,
                            TI.start_date: timezone.utcnow(),
                            TI.end_date: timezone.utcnow(),
                            TI.duration: 0,
                        },
                        synchronize_session=False,
                    )
                )

        return count

    @provide_session
    def get_log_template(self, *, session: Session = NEW_SESSION) -> LogTemplate:
        if self.log_template_id is None:  # DagRun created before LogTemplate introduction.
            template = session.query(LogTemplate).order_by(LogTemplate.id).first()
        else:
            template = session.get(LogTemplate, self.log_template_id)
        if template is None:
            raise AirflowException(
                f"No log_template entry found for ID {self.log_template_id!r}. "
                f"Please make sure you set up the metadatabase correctly."
            )
        return template

    @provide_session
    def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str:
        warnings.warn(
            "This method is deprecated. Please use get_log_template instead.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return self.get_log_template(session=session).filename


class DagRunNote(Base):
    """For storage of arbitrary notes concerning the dagrun instance."""

    __tablename__ = "dag_run_note"

    user_id = Column(Integer, nullable=True)
    dag_run_id = Column(Integer, primary_key=True, nullable=False)
    content = Column(String(1000).with_variant(Text(1000), "mysql"))
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dag_run = relationship("DagRun", back_populates="dag_run_note")

    __table_args__ = (
        PrimaryKeyConstraint("dag_run_id", name="dag_run_note_pkey"),
        ForeignKeyConstraint(
            (dag_run_id,),
            ["dag_run.id"],
            name="dag_run_note_dr_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (user_id,),
            ["ab_user.id"],
            name="dag_run_note_user_fkey",
        ),
    )

    def __init__(self, content, user_id=None):
        self.content = content
        self.user_id = user_id

    def __repr__(self):
        # The attributes referenced here previously (dag_id, dagrun_id, run_id,
        # map_index) are TaskInstance-style fields that DagRunNote does not
        # define; build the repr from columns this model actually has.
        return f"<{self.__class__.__name__}: dag_run_id={self.dag_run_id}, user_id={self.user_id}>"