#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import itertools
import os
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload

import re2
from sqlalchemy import (
    Boolean,
    Column,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PickleType,
    PrimaryKeyConstraint,
    String,
    Text,
    UniqueConstraint,
    and_,
    func,
    or_,
    text,
    update,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import declared_attr, joinedload, relationship, synonym, validates
from sqlalchemy.sql.expression import case, false, select, true

from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred, TaskNotFound
from airflow.listeners.listener import get_listener_manager
from airflow.models import Log
from airflow.models.abstractoperator import NotMapped
from airflow.models.base import Base, StringID
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.tasklog import LogTemplate
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
from airflow.utils import timezone
from airflow.utils.helpers import chunks, is_container, prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, tuple_in_condition, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import NOTSET, DagRunType

if TYPE_CHECKING:
    from datetime import datetime

    from sqlalchemy.orm import Query, Session

    from airflow.models.dag import DAG
    from airflow.models.operator import Operator
    from airflow.serialization.pydantic.dag_run import DagRunPydantic
    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
    from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
    from airflow.typing_compat import Literal
    from airflow.utils.types import ArgNotSet

    CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])

RUN_ID_REGEX = r"^(?:manual|scheduled|dataset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$"
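# For reference, a run ID such as "scheduled__2024-01-01T00:00:00+00:00" matches this pattern.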


class TISchedulingDecision(NamedTuple):
    """Type of return for DagRun.task_instance_scheduling_decisions."""

    tis: list[TI]
    schedulable_tis: list[TI]
    changed_tis: bool
    unfinished_tis: list[TI]
    finished_tis: list[TI]


def _creator_note(val):
    """Creator for the ``note`` association proxy."""
    if isinstance(val, str):
        return DagRunNote(content=val)
    elif isinstance(val, dict):
        return DagRunNote(**val)
    else:
        return DagRunNote(*val)
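
# A minimal usage sketch (assumes an existing ``dag_run`` ORM instance): assigning a plain
# string to ``DagRun.note`` routes through this creator, e.g.
#   dag_run.note = "manually re-triggered after upstream fix"
# which is equivalent to setting ``dag_run.dag_run_note = DagRunNote(content=...)``.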


class DagRun(Base, LoggingMixin):
    """Invocation instance of a DAG.

    A DAG run can be created by the scheduler (i.e. scheduled runs), or by an
    external trigger (i.e. manual runs).
    """

    __tablename__ = "dag_run"

    id = Column(Integer, primary_key=True)
    dag_id = Column(StringID(), nullable=False)
    queued_at = Column(UtcDateTime)
    execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    start_date = Column(UtcDateTime)
    end_date = Column(UtcDateTime)
    _state = Column("state", String(50), default=DagRunState.QUEUED)
    run_id = Column(StringID(), nullable=False)
    creating_job_id = Column(Integer)
    external_trigger = Column(Boolean, default=True)
    run_type = Column(String(50), nullable=False)
    conf = Column(PickleType)
    # These two must be either both NULL or both datetime.
    data_interval_start = Column(UtcDateTime)
    data_interval_end = Column(UtcDateTime)
    # When a scheduler last attempted to schedule TIs for this DagRun
    last_scheduling_decision = Column(UtcDateTime)
    dag_hash = Column(String(32))
    # Foreign key to LogTemplate. DagRun rows created prior to this column's
    # existence have this set to NULL. Later rows automatically populate this on
    # insert to point to the latest LogTemplate entry.
    log_template_id = Column(
        Integer,
        ForeignKey("log_template.id", name="task_instance_log_template_id_fkey", ondelete="NO ACTION"),
        default=select(func.max(LogTemplate.__table__.c.id)),
    )
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
    # Keeps track of the number of times the dagrun has been cleared.
    # This number is incremented only when the DagRun is re-queued,
    # i.e. when the DagRun is cleared.
    clear_number = Column(Integer, default=0, nullable=False, server_default="0")

    # Remove this `if` after upgrading Sphinx-AutoAPI
    if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
        dag: DAG | None
    else:
        dag: DAG | None = None

    __table_args__ = (
        Index("dag_id_state", dag_id, _state),
        UniqueConstraint("dag_id", "execution_date", name="dag_run_dag_id_execution_date_key"),
        UniqueConstraint("dag_id", "run_id", name="dag_run_dag_id_run_id_key"),
        Index("idx_dag_run_dag_id", dag_id),
        Index(
            "idx_dag_run_running_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='running'"),
            sqlite_where=text("state='running'"),
        ),
        # Since MySQL lacks filtered/partial indices, this creates a
        # duplicate index there. Not the end of the world.
        Index(
            "idx_dag_run_queued_dags",
            "state",
            "dag_id",
            postgresql_where=text("state='queued'"),
            sqlite_where=text("state='queued'"),
        ),
    )

    task_instances = relationship(
        TI, back_populates="dag_run", cascade="save-update, merge, delete, delete-orphan"
    )
    dag_model = relationship(
        "DagModel",
        primaryjoin="foreign(DagRun.dag_id) == DagModel.dag_id",
        uselist=False,
        viewonly=True,
    )
    dag_run_note = relationship(
        "DagRunNote",
        back_populates="dag_run",
        uselist=False,
        cascade="all, delete, delete-orphan",
    )
    note = association_proxy("dag_run_note", "content", creator=_creator_note)

    DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
        "scheduler",
        "max_dagruns_per_loop_to_schedule",
        fallback=20,
    )

    def __init__(
        self,
        dag_id: str | None = None,
        run_id: str | None = None,
        queued_at: datetime | None | ArgNotSet = NOTSET,
        execution_date: datetime | None = None,
        start_date: datetime | None = None,
        external_trigger: bool | None = None,
        conf: Any | None = None,
        state: DagRunState | None = None,
        run_type: str | None = None,
        dag_hash: str | None = None,
        creating_job_id: int | None = None,
        data_interval: tuple[datetime, datetime] | None = None,
    ):
        if data_interval is None:
            # Legacy: this only happens for runs created prior to Airflow 2.2.
            self.data_interval_start = self.data_interval_end = None
        else:
            self.data_interval_start, self.data_interval_end = data_interval

        self.dag_id = dag_id
        self.run_id = run_id
        self.execution_date = execution_date
        self.start_date = start_date
        self.external_trigger = external_trigger
        self.conf = conf or {}
        if state is not None:
            self.state = state
        if queued_at is NOTSET:
            self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None
        else:
            self.queued_at = queued_at
        self.run_type = run_type
        self.dag_hash = dag_hash
        self.creating_job_id = creating_job_id
        self.clear_number = 0
        super().__init__()

    def __repr__(self):
        return (
            f"<DagRun {self.dag_id} @ {self.execution_date}: {self.run_id}, state:{self.state}, "
            f"queued_at: {self.queued_at}. externally triggered: {self.external_trigger}>"
        )

    @validates("run_id")
    def validate_run_id(self, key: str, run_id: str) -> str | None:
        if not run_id:
            return None
        regex = airflow_conf.get("scheduler", "allowed_run_id_pattern")
        if not re2.match(regex, run_id) and not re2.match(RUN_ID_REGEX, run_id):
            raise ValueError(
                f"The run_id provided '{run_id}' does not match the pattern '{regex}' or '{RUN_ID_REGEX}'"
            )
        return run_id

    @property
    def stats_tags(self) -> dict[str, str]:
        return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type})

    @property
    def logical_date(self) -> datetime:
        return self.execution_date

    def get_state(self):
        return self._state

    def set_state(self, state: DagRunState) -> None:
        """Change the state of the DagRun.

        Changes to attributes are implemented in accordance with the following table
        (rows represent old states, columns represent new states):

        .. list-table:: State transition matrix
           :header-rows: 1
           :stub-columns: 1

           * -
             - QUEUED
             - RUNNING
             - SUCCESS
             - FAILED
           * - None
             - queued_at = timezone.utcnow()
             - if empty: start_date = timezone.utcnow()
               end_date = None
             - end_date = timezone.utcnow()
             - end_date = timezone.utcnow()
           * - QUEUED
             - queued_at = timezone.utcnow()
             - if empty: start_date = timezone.utcnow()
               end_date = None
             - end_date = timezone.utcnow()
             - end_date = timezone.utcnow()
           * - RUNNING
             - queued_at = timezone.utcnow()
               start_date = None
               end_date = None
             -
             - end_date = timezone.utcnow()
             - end_date = timezone.utcnow()
           * - SUCCESS
             - queued_at = timezone.utcnow()
               start_date = None
               end_date = None
             - start_date = timezone.utcnow()
               end_date = None
             -
             -
           * - FAILED
             - queued_at = timezone.utcnow()
               start_date = None
               end_date = None
             - start_date = timezone.utcnow()
               end_date = None
             -
             -

        """
        if state not in State.dag_states:
            raise ValueError(f"invalid DagRun state: {state}")
        if self._state != state:
            if state == DagRunState.QUEUED:
                self.queued_at = timezone.utcnow()
                self.start_date = None
                self.end_date = None
            if state == DagRunState.RUNNING:
                if self._state in State.finished_dr_states:
                    self.start_date = timezone.utcnow()
                else:
                    self.start_date = self.start_date or timezone.utcnow()
                self.end_date = None
            if self._state in State.unfinished_dr_states or self._state is None:
                if state in State.finished_dr_states:
                    self.end_date = timezone.utcnow()
            self._state = state
        else:
            if state == DagRunState.QUEUED:
                self.queued_at = timezone.utcnow()
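
    # A minimal usage sketch of the transitions documented above (hypothetical ``dag_run``):
    #   dag_run.set_state(DagRunState.RUNNING)  # sets start_date if empty, clears end_date
    #   dag_run.set_state(DagRunState.SUCCESS)  # sets end_date on the now-finished run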

    @declared_attr
    def state(self):
        return synonym("_state", descriptor=property(self.get_state, self.set_state))

    @provide_session
    def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
        """
        Reload the current dagrun from the database.

        :param session: database session
        """
        dr = session.scalars(
            select(DagRun).where(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id)
        ).one()
        self.id = dr.id
        self.state = dr.state

    @classmethod
    @provide_session
    def active_runs_of_dags(
        cls,
        dag_ids: Iterable[str] | None = None,
        only_running: bool = False,
        session: Session = NEW_SESSION,
    ) -> dict[str, int]:
        """Get the number of active dag runs for each dag."""
        query = select(cls.dag_id, func.count("*"))
        if dag_ids is not None:
            # 'set' is called to avoid duplicate dag_ids.
            query = query.where(cls.dag_id.in_(set(dag_ids)))
        if only_running:
            query = query.where(cls.state == DagRunState.RUNNING)
        else:
            query = query.where(cls.state.in_((DagRunState.RUNNING, DagRunState.QUEUED)))
        query = query.group_by(cls.dag_id)
        return dict(iter(session.execute(query)))
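
    # A minimal usage sketch (hypothetical dag_id, run inside a provided session):
    #   counts = DagRun.active_runs_of_dags(dag_ids=["example_dag"], session=session)
    #   counts.get("example_dag", 0)  # -> number of queued + running runs for that dag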

    @classmethod
    def next_dagruns_to_examine(
        cls,
        state: DagRunState,
        session: Session,
        max_number: int | None = None,
    ) -> Query:
        """
        Return the next DagRuns that the scheduler should attempt to schedule.

        This will return zero or more DagRun rows that are row-level-locked with a
        "SELECT ... FOR UPDATE" query. You should ensure that any scheduling decisions are made
        in a single transaction -- as soon as the transaction is committed it will be unlocked.
        """
        from airflow.models.dag import DagModel

        if max_number is None:
            max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

        # TODO: Bake this query, it is run _A lot_
        query = (
            select(cls)
            .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
            .where(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
            .join(DagModel, DagModel.dag_id == cls.dag_id)
            .where(DagModel.is_paused == false(), DagModel.is_active == true())
        )
        if state == DagRunState.QUEUED:
            # For dag runs in the queued state, we check if they have reached the
            # max_active_runs limit, and if so we drop them.
            running_drs = (
                select(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
                .where(DagRun.state == DagRunState.RUNNING)
                .group_by(DagRun.dag_id)
                .subquery()
            )
            query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).where(
                func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs
            )
        query = query.order_by(
            nulls_first(cls.last_scheduling_decision, session=session),
            cls.execution_date,
        )

        if not settings.ALLOW_FUTURE_EXEC_DATES:
            query = query.where(DagRun.execution_date <= func.now())

        return session.scalars(
            with_row_locks(query.limit(max_number), of=cls, session=session, skip_locked=True)
        )

    @classmethod
    @provide_session
    def find(
        cls,
        dag_id: str | list[str] | None = None,
        run_id: Iterable[str] | None = None,
        execution_date: datetime | Iterable[datetime] | None = None,
        state: DagRunState | None = None,
        external_trigger: bool | None = None,
        no_backfills: bool = False,
        run_type: DagRunType | None = None,
        session: Session = NEW_SESSION,
        execution_start_date: datetime | None = None,
        execution_end_date: datetime | None = None,
    ) -> list[DagRun]:
        """
        Return a list of dag runs matching the given search criteria.

        :param dag_id: the dag_id or list of dag_id to find dag runs for
        :param run_id: defines the run id for this dag run
        :param run_type: type of DagRun
        :param execution_date: the execution date
        :param state: the state of the dag run
        :param external_trigger: whether this dag run is externally triggered
        :param no_backfills: return no backfills (True), return all (False).
            Defaults to False
        :param session: database session
        :param execution_start_date: dag runs that were executed from this date
        :param execution_end_date: dag runs that were executed until this date
        """
        qry = select(cls)
        dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
        if dag_ids:
            qry = qry.where(cls.dag_id.in_(dag_ids))

        if is_container(run_id):
            qry = qry.where(cls.run_id.in_(run_id))
        elif run_id is not None:
            qry = qry.where(cls.run_id == run_id)
        if is_container(execution_date):
            qry = qry.where(cls.execution_date.in_(execution_date))
        elif execution_date is not None:
            qry = qry.where(cls.execution_date == execution_date)
        if execution_start_date and execution_end_date:
            qry = qry.where(cls.execution_date.between(execution_start_date, execution_end_date))
        elif execution_start_date:
            qry = qry.where(cls.execution_date >= execution_start_date)
        elif execution_end_date:
            qry = qry.where(cls.execution_date <= execution_end_date)
        if state:
            qry = qry.where(cls.state == state)
        if external_trigger is not None:
            qry = qry.where(cls.external_trigger == external_trigger)
        if run_type:
            qry = qry.where(cls.run_type == run_type)
        if no_backfills:
            qry = qry.where(cls.run_type != DagRunType.BACKFILL_JOB)

        return session.scalars(qry.order_by(cls.execution_date)).all()
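
    # A minimal usage sketch (hypothetical dag_id):
    #   failed_runs = DagRun.find(dag_id="example_dag", state=DagRunState.FAILED)
    #   for run in failed_runs:
    #       print(run.run_id, run.execution_date)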

    @classmethod
    @provide_session
    def find_duplicate(
        cls,
        dag_id: str,
        run_id: str,
        execution_date: datetime,
        session: Session = NEW_SESSION,
    ) -> DagRun | None:
        """
        Return an existing run for the DAG with a specific run_id or execution_date.

        *None* is returned if no such DAG run is found.

        :param dag_id: the dag_id to find duplicates for
        :param run_id: defines the run id for this dag run
        :param execution_date: the execution date
        :param session: database session
        """
        return session.scalars(
            select(cls).where(
                cls.dag_id == dag_id,
                or_(cls.run_id == run_id, cls.execution_date == execution_date),
            )
        ).one_or_none()

    @staticmethod
    def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
        """Generate a Run ID based on Run Type and Execution Date."""
        # _Ensure_ run_type is a DagRunType, not just a string from user code
        return DagRunType(run_type).generate_run_id(execution_date)
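
    # For example, ``DagRun.generate_run_id(DagRunType.MANUAL, logical_date)`` produces a
    # run ID of the form "manual__<logical_date in ISO format>" (cf. RUN_ID_REGEX above).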

    @staticmethod
    @internal_api_call
    @provide_session
    def fetch_task_instances(
        dag_id: str | None = None,
        run_id: str | None = None,
        task_ids: list[str] | None = None,
        state: Iterable[TaskInstanceState | None] | None = None,
        session: Session = NEW_SESSION,
    ) -> list[TI]:
        """Return the task instances for this dag run."""
        tis = (
            select(TI)
            .options(joinedload(TI.dag_run))
            .where(
                TI.dag_id == dag_id,
                TI.run_id == run_id,
            )
        )

        if state:
            if isinstance(state, str):
                tis = tis.where(TI.state == state)
            else:
                # This is required to deal with NULL values.
                if None in state:
                    if all(x is None for x in state):
                        tis = tis.where(TI.state.is_(None))
                    else:
                        not_none_state = (s for s in state if s)
                        tis = tis.where(or_(TI.state.in_(not_none_state), TI.state.is_(None)))
                else:
                    tis = tis.where(TI.state.in_(state))

        if task_ids is not None:
            tis = tis.where(TI.task_id.in_(task_ids))
        return session.scalars(tis).all()
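
    # For example, passing ``state=[None, TaskInstanceState.FAILED]`` matches task instances
    # whose state column is NULL as well as failed ones, per the NULL handling above.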

    @internal_api_call
    def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs, session):
        """Check if the last N dag runs failed, and pause the DAG if so."""
        dag_runs = (
            session.query(DagRun)
            .filter(DagRun.dag_id == dag_id)
            .order_by(DagRun.execution_date.desc())
            .limit(max_consecutive_failed_dag_runs)
            .all()
        )
        # Mark the dag as paused, if needed.
        to_be_paused = len(dag_runs) >= max_consecutive_failed_dag_runs and all(
            dag_run.state == DagRunState.FAILED for dag_run in dag_runs
        )

        if to_be_paused:
            from airflow.models.dag import DagModel

            self.log.info(
                "Pausing DAG %s because last %s DAG runs failed.",
                self.dag_id,
                max_consecutive_failed_dag_runs,
            )
            filter_query = [
                DagModel.dag_id == self.dag_id,
                DagModel.root_dag_id == self.dag_id,  # for sub-dags
            ]
            session.execute(
                update(DagModel)
                .where(or_(*filter_query))
                .values(is_paused=True)
                .execution_options(synchronize_session="fetch")
            )
            session.add(
                Log(
                    event="paused",
                    dag_id=self.dag_id,
                    owner="scheduler",
                    owner_display_name="Scheduler",
                    extra=f"[('dag_id', '{self.dag_id}'), ('is_paused', True)]",
                )
            )
        else:
            self.log.debug(
                "Limit of consecutive failed DAG runs has not been reached; DAG %s will not be paused.",
                self.dag_id,
            )

    @provide_session
    def get_task_instances(
        self,
        state: Iterable[TaskInstanceState | None] | None = None,
        session: Session = NEW_SESSION,
    ) -> list[TI]:
        """
        Return the task instances for this dag run.

        Redirects to the DagRun.fetch_task_instances method.
        This method is kept because it is widely used across the codebase.
        """
        task_ids = self.dag.task_ids if self.dag and self.dag.partial else None
        return DagRun.fetch_task_instances(
            dag_id=self.dag_id, run_id=self.run_id, task_ids=task_ids, state=state, session=session
        )

    @provide_session
    def get_task_instance(
        self,
        task_id: str,
        session: Session = NEW_SESSION,
        *,
        map_index: int = -1,
    ) -> TI | TaskInstancePydantic | None:
        """
        Return the task instance specified by task_id for this dag run.

        :param task_id: the task id
        :param session: SQLAlchemy ORM Session
        """
        return DagRun.fetch_task_instance(
            dag_id=self.dag_id,
            dag_run_id=self.run_id,
            task_id=task_id,
            session=session,
            map_index=map_index,
        )

    @staticmethod
    @provide_session
    def fetch_task_instance(
        dag_id: str,
        dag_run_id: str,
        task_id: str,
        session: Session = NEW_SESSION,
        map_index: int = -1,
    ) -> TI | TaskInstancePydantic | None:
        """
        Return the task instance specified by task_id for this dag run.

        :param dag_id: the DAG id
        :param dag_run_id: the DAG run id
        :param task_id: the task id
        :param session: SQLAlchemy ORM Session
        """
        return session.scalars(
            select(TI).filter_by(dag_id=dag_id, run_id=dag_run_id, task_id=task_id, map_index=map_index)
        ).one_or_none()

    def get_dag(self) -> DAG:
        """
        Return the DAG associated with this DagRun.

        :return: DAG
        """
        if not self.dag:
            raise AirflowException(f"The DAG (.dag) for {self} needs to be set")

        return self.dag

    @staticmethod
    @internal_api_call
    @provide_session
    def get_previous_dagrun(
        dag_run: DagRun | DagRunPydantic, state: DagRunState | None = None, session: Session = NEW_SESSION
    ) -> DagRun | None:
        """
        Return the previous DagRun, if there is one.

        :param dag_run: the dag run
        :param session: SQLAlchemy ORM Session
        :param state: the dag run state
        """
        filters = [
            DagRun.dag_id == dag_run.dag_id,
            DagRun.execution_date < dag_run.execution_date,
        ]
        if state is not None:
            filters.append(DagRun.state == state)
        return session.scalar(select(DagRun).where(*filters).order_by(DagRun.execution_date.desc()).limit(1))

    @staticmethod
    @internal_api_call
    @provide_session
    def get_previous_scheduled_dagrun(
        dag_run_id: int,
        session: Session = NEW_SESSION,
    ) -> DagRun | None:
        """
        Return the previous SCHEDULED DagRun, if there is one.

        :param dag_run_id: the DAG run ID
        :param session: SQLAlchemy ORM Session
        """
        dag_run = session.get(DagRun, dag_run_id)
        return session.scalar(
            select(DagRun)
            .where(
                DagRun.dag_id == dag_run.dag_id,
                DagRun.execution_date < dag_run.execution_date,
                DagRun.run_type != DagRunType.MANUAL,
            )
            .order_by(DagRun.execution_date.desc())
            .limit(1)
        )

    def _tis_for_dagrun_state(self, *, dag, tis):
        """
        Return the collection of tasks that should be considered for evaluation of terminal dag run state.

        Teardown tasks are by default not considered for the purpose of dag run state. But
        users may enable such consideration with on_failure_fail_dagrun.
        """

        def is_effective_leaf(task):
            for down_task_id in task.downstream_task_ids:
                down_task = dag.get_task(down_task_id)
                if not down_task.is_teardown or down_task.on_failure_fail_dagrun:
                    # We found a downstream task that is not ignorable; not a leaf.
                    return False
            # We found no ignorable downstreams;
            # evaluate whether the task is itself ignorable.
            return not task.is_teardown or task.on_failure_fail_dagrun

        leaf_task_ids = {x.task_id for x in dag.tasks if is_effective_leaf(x)}
        if not leaf_task_ids:
            # Can happen if the dag consists exclusively of teardown tasks.
            leaf_task_ids = {x.task_id for x in dag.tasks if not x.downstream_list}
        leaf_tis = {ti for ti in tis if ti.task_id in leaf_task_ids if ti.state != TaskInstanceState.REMOVED}
        return leaf_tis

    @provide_session
    def update_state(
        self, session: Session = NEW_SESSION, execute_callbacks: bool = True
    ) -> tuple[list[TI], DagCallbackRequest | None]:
        """
        Determine the overall state of the DagRun based on the state of its TaskInstances.

        :param session: SQLAlchemy ORM Session
        :param execute_callbacks: Should dag callbacks (success/failure, SLA etc.) be invoked
            directly (default: true) or recorded as a pending request in the ``returned_callback`` property
        :return: Tuple containing tis that can be scheduled in the current loop & `returned_callback` that
            needs to be executed
        """
        # Callback to execute in case of Task Failures
        callback: DagCallbackRequest | None = None

        class _UnfinishedStates(NamedTuple):
            tis: Sequence[TI]

            @classmethod
            def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
                return cls(tis=unfinished_tis)

            @property
            def should_schedule(self) -> bool:
                return (
                    bool(self.tis)
                    and all(not t.task.depends_on_past for t in self.tis)  # type: ignore[union-attr]
                    and all(t.task.max_active_tis_per_dag is None for t in self.tis)  # type: ignore[union-attr]
                    and all(t.task.max_active_tis_per_dagrun is None for t in self.tis)  # type: ignore[union-attr]
                    and all(t.state != TaskInstanceState.DEFERRED for t in self.tis)
                )

            def recalculate(self) -> _UnfinishedStates:
                return self._replace(tis=[t for t in self.tis if t.state in State.unfinished])

        start_dttm = timezone.utcnow()
        self.last_scheduling_decision = start_dttm
        with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"), Stats.timer(
            "dagrun.dependency-check", tags=self.stats_tags
        ):
            dag = self.get_dag()
            info = self.task_instance_scheduling_decisions(session)

            tis = info.tis
            schedulable_tis = info.schedulable_tis
            changed_tis = info.changed_tis
            finished_tis = info.finished_tis
            unfinished = _UnfinishedStates.calculate(info.unfinished_tis)

            if unfinished.should_schedule:
                are_runnable_tasks = schedulable_tis or changed_tis
                # small speed up
                if not are_runnable_tasks:
                    are_runnable_tasks, changed_by_upstream = self._are_premature_tis(
                        unfinished.tis, finished_tis, session
                    )
                    if changed_by_upstream:  # Something changed, we need to recalculate!
                        unfinished = unfinished.recalculate()

        tis_for_dagrun_state = self._tis_for_dagrun_state(dag=dag, tis=tis)

        # If all tasks finished and at least one failed, the run failed.
        if not unfinished.tis and any(x.state in State.failed_states for x in tis_for_dagrun_state):
            self.log.error("Marking run %s failed", self)
            self.set_state(DagRunState.FAILED)
            self.notify_dagrun_state_changed(msg="task_failure")

            if execute_callbacks:
                dag.handle_callback(self, success=False, reason="task_failure", session=session)
            elif dag.has_on_failure_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                    msg="task_failure",
                )

            # Check that max_consecutive_failed_dag_runs has been provided (i.e. is non-zero),
            # and if so, whether the trailing runs are all failures.
            if dag.max_consecutive_failed_dag_runs > 0:
                self.log.debug(
                    "Checking consecutive failed DAG runs for DAG %s, limit is %s",
                    self.dag_id,
                    dag.max_consecutive_failed_dag_runs,
                )
                self._check_last_n_dagruns_failed(dag.dag_id, dag.max_consecutive_failed_dag_runs, session)

        # If all leaves succeeded and no unfinished tasks remain, the run succeeded.
        elif not unfinished.tis and all(x.state in State.success_states for x in tis_for_dagrun_state):
            self.log.info("Marking run %s successful", self)
            self.set_state(DagRunState.SUCCESS)
            self.notify_dagrun_state_changed(msg="success")

            if execute_callbacks:
                dag.handle_callback(self, success=True, reason="success", session=session)
            elif dag.has_on_success_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=False,
                    processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                    msg="success",
                )

        # If *all tasks* are deadlocked, the run failed.
        elif unfinished.should_schedule and not are_runnable_tasks:
            self.log.error("Task deadlock (no runnable tasks); marking run %s failed", self)
            self.set_state(DagRunState.FAILED)
            self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")

            if execute_callbacks:
                dag.handle_callback(self, success=False, reason="all_tasks_deadlocked", session=session)
            elif dag.has_on_failure_callback:
                from airflow.models.dag import DagModel

                dag_model = DagModel.get_dagmodel(dag.dag_id, session)
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    processor_subdir=None if dag_model is None else dag_model.processor_subdir,
                    msg="all_tasks_deadlocked",
                )

        # Finally, if the leaves aren't done, the dag is still running.
        else:
            self.set_state(DagRunState.RUNNING)

        if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
            msg = (
                "DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
                "run_start_date=%s, run_end_date=%s, run_duration=%s, "
                "state=%s, external_trigger=%s, run_type=%s, "
                "data_interval_start=%s, data_interval_end=%s, dag_hash=%s"
            )
            self.log.info(
                msg,
                self.dag_id,
                self.execution_date,
                self.run_id,
                self.start_date,
                self.end_date,
                (
                    (self.end_date - self.start_date).total_seconds()
                    if self.start_date and self.end_date
                    else None
                ),
                self._state,
                self.external_trigger,
                self.run_type,
                self.data_interval_start,
                self.data_interval_end,
                self.dag_hash,
            )
            session.flush()

        self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis)
        self._emit_duration_stats_for_finished_state()

        session.merge(self)
        # We do not flush here for performance reasons (it increases the query count by ~20).

        return schedulable_tis, callback

    @provide_session
    def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
        tis = self.get_task_instances(session=session, state=State.task_states)
        self.log.debug("number of tis for %s: %s task(s)", self, len(tis))

        def _filter_tis_and_exclude_removed(dag: DAG, tis: list[TI]) -> Iterable[TI]:
            """Populate ``ti.task`` while excluding those missing one, marking them as REMOVED."""
            for ti in tis:
                try:
                    ti.task = dag.get_task(ti.task_id)
                except TaskNotFound:
                    if ti.state != TaskInstanceState.REMOVED:
                        self.log.error("Failed to get task for ti %s. Marking it as removed.", ti)
                        ti.state = TaskInstanceState.REMOVED
                        session.flush()
                else:
                    yield ti

        tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))

        unfinished_tis = [t for t in tis if t.state in State.unfinished]
        finished_tis = [t for t in tis if t.state in State.finished]
        if unfinished_tis:
            schedulable_tis = [ut for ut in unfinished_tis if ut.state in SCHEDULEABLE_STATES]
            self.log.debug("number of schedulable tasks for %s: %s task(s)", self, len(schedulable_tis))
            schedulable_tis, changed_tis, expansion_happened = self._get_ready_tis(
                schedulable_tis,
                finished_tis,
                session=session,
            )

            # During expansion, we may change some tis into non-schedulable
            # states, so we need to re-compute.
            if expansion_happened:
                changed_tis = True
                new_unfinished_tis = [t for t in unfinished_tis if t.state in State.unfinished]
                finished_tis.extend(t for t in unfinished_tis if t.state in State.finished)
                unfinished_tis = new_unfinished_tis
        else:
            schedulable_tis = []
            changed_tis = False

        return TISchedulingDecision(
            tis=tis,
            schedulable_tis=schedulable_tis,
            changed_tis=changed_tis,
            unfinished_tis=unfinished_tis,
            finished_tis=finished_tis,
        )

    def notify_dagrun_state_changed(self, msg: str = ""):
        if self.state == DagRunState.RUNNING:
            get_listener_manager().hook.on_dag_run_running(dag_run=self, msg=msg)
        elif self.state == DagRunState.SUCCESS:
            get_listener_manager().hook.on_dag_run_success(dag_run=self, msg=msg)
        elif self.state == DagRunState.FAILED:
            get_listener_manager().hook.on_dag_run_failed(dag_run=self, msg=msg)
        # Deliberately not notifying on QUEUED: we can't get all the state changes
        # on SchedulerJob, BackfillJob or LocalTaskJob, so we don't want to
        # "falsely advertise" that we notify about that.

    def _get_ready_tis(
        self,
        schedulable_tis: list[TI],
        finished_tis: list[TI],
        session: Session,
    ) -> tuple[list[TI], bool, bool]:
        old_states = {}
        ready_tis: list[TI] = []
        changed_tis = False

        if not schedulable_tis:
            return ready_tis, changed_tis, False

        # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
        # `schedulable_tis` in place and have the `for` loop pick them up.)
        additional_tis: list[TI] = []
        dep_context = DepContext(
            flag_upstream_failed=True,
            ignore_unmapped_tasks=True,  # Ignore this Dep, as we will expand it if we can.
            finished_tis=finished_tis,
        )

        def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
            """Try to expand the ti, if needed.

            If the ti needs expansion, newly created task instances are
            returned as well as the original ti.
            The original ti is also modified in-place and assigned the
            ``map_index`` of 0.

            If the ti does not need expansion, either because the task is not
            mapped, or has already been expanded, *None* is returned.
            """
            if TYPE_CHECKING:
                assert ti.task

            if ti.map_index >= 0:  # Already expanded, we're good.
                return None

            from airflow.models.mappedoperator import MappedOperator

            if isinstance(ti.task, MappedOperator):
                # If we get here, it could be that we are moving from non-mapped to mapped
                # after task instance clearing or this ti is not yet expanded. Safe to clear
                # the db references.
                ti.clear_db_references(session=session)
            try:
                expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session)
            except NotMapped:  # Not a mapped task, nothing needed.
                return None
            if expanded_tis:
                return expanded_tis
            return ()

        # Check dependencies.
        expansion_happened = False
        # Set of task ids for which _revise_map_indexes_if_mapped was already done.
        revised_map_index_task_ids = set()
        for schedulable in itertools.chain(schedulable_tis, additional_tis):
            if TYPE_CHECKING:
                assert schedulable.task
            old_state = schedulable.state
            if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
                old_states[schedulable.key] = old_state
                continue
            # If schedulable is not yet expanded, try doing it now. This is
            # called in two places: First and ideally in the mini scheduler at
            # the end of LocalTaskJob, and then as an "expansion of last resort"
            # in the scheduler to ensure that the mapped task is correctly
            # expanded before executed. Also see _revise_map_indexes_if_mapped
            # docstring for additional information.
            new_tis = None
            if schedulable.map_index < 0:
                new_tis = _expand_mapped_task_if_needed(schedulable)
                if new_tis is not None:
                    additional_tis.extend(new_tis)
                    expansion_happened = True
            if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
                # It's enough to revise the map index once per task id;
                # checking the map index for each mapped task significantly slows down scheduling.
                if schedulable.task.task_id not in revised_map_index_task_ids:
                    ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
                    revised_map_index_task_ids.add(schedulable.task.task_id)
                ready_tis.append(schedulable)

        # Check if any ti changed state
        tis_filter = TI.filter_for_tis(old_states)
        if tis_filter is not None:
            fresh_tis = session.scalars(select(TI).where(tis_filter)).all()
            changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)

        return ready_tis, changed_tis, expansion_happened

    def _are_premature_tis(
        self,
        unfinished_tis: Sequence[TI],
        finished_tis: list[TI],
        session: Session,
    ) -> tuple[bool, bool]:
        dep_context = DepContext(
            flag_upstream_failed=True,
            ignore_in_retry_period=True,
            ignore_in_reschedule_period=True,
            finished_tis=finished_tis,
        )
        # There might be runnable tasks that are up for retry and for some reason (retry delay, etc.)
        # are not ready yet, so we set the flags to count them in.
        return (
            any(ut.are_dependencies_met(dep_context=dep_context, session=session) for ut in unfinished_tis),
            dep_context.have_changed_ti_states,
        )

    def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: list[TI]) -> None:
        """Emit the true scheduling delay stats.

        The true scheduling delay stat is defined as the time when the first
        task in the DAG starts minus the expected DAG run datetime.

        This helper method is used in ``update_state`` when the state of the
        DAG run is updated to a completed status (either success or failure).
        It finds the first started task within the DAG, calculates the run's
        expected start time based on the logical date and timetable, and gets
        the delay from the difference of these two values.

        The emitted data may contain outliers (e.g. when the first task was
        cleared, so the second task's start date will be used), but we can get
        rid of the outliers on the stats side through dashboards tooling.

        Note that the stat will only be emitted for scheduler-triggered DAG runs
        (i.e. when ``external_trigger`` is *False* and ``clear_number`` is equal to 0).
        """
        if self.state == DagRunState.RUNNING:
            return
        if self.external_trigger:
            return
        if self.clear_number > 0:
            return
        if not finished_tis:
            return

        try:
            dag = self.get_dag()

            if not dag.timetable.periodic:
                # We can't emit this metric if there is no following schedule to calculate from!
                return

            try:
                first_start_date = min(ti.start_date for ti in finished_tis if ti.start_date)
            except ValueError:  # No start dates at all.
                pass
            else:
                # TODO: Logically, this should be DagRunInfo.run_after, but the
                # information is not stored on a DagRun, only before the actual
                # execution on DagModel.next_dagrun_create_after. We should add
                # a field on DagRun for this instead of relying on the run
                # always happening immediately after the data interval.
                data_interval_end = dag.get_run_data_interval(self).end
                true_delay = first_start_date - data_interval_end
                if true_delay.total_seconds() > 0:
                    Stats.timing(
                        f"dagrun.{dag.dag_id}.first_task_scheduling_delay", true_delay, tags=self.stats_tags
                    )
                    Stats.timing("dagrun.first_task_scheduling_delay", true_delay, tags=self.stats_tags)
        except Exception:
            self.log.warning("Failed to record first_task_scheduling_delay metric:", exc_info=True)

    def _emit_duration_stats_for_finished_state(self):
        if self.state == DagRunState.RUNNING:
            return
        if self.start_date is None:
            self.log.warning("Failed to record duration of %s: start_date is not set.", self)
            return
        if self.end_date is None:
            self.log.warning("Failed to record duration of %s: end_date is not set.", self)
            return

        duration = self.end_date - self.start_date
        timer_params = {"dt": duration, "tags": self.stats_tags}
        Stats.timing(f"dagrun.duration.{self.state}.{self.dag_id}", **timer_params)
        Stats.timing(f"dagrun.duration.{self.state}", **timer_params)
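
    # For example, a successful run of a dag "example_dag" (hypothetical) emits both the
    # "dagrun.duration.success.example_dag" timing and the tagged "dagrun.duration.success" timing.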

    @provide_session
    def verify_integrity(self, *, session: Session = NEW_SESSION) -> None:
        """
        Verify the DagRun by checking for removed tasks or tasks that are not in the database yet.

        It will set state to removed or add the task if required.

        :param session: SQLAlchemy ORM Session
        """
        from airflow.settings import task_instance_mutation_hook

        # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
        # Note: Literal[True, False] instead of bool because otherwise it doesn't correctly find the overload.
        hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, "is_noop", False)

        dag = self.get_dag()
        task_ids = self._check_for_removed_or_restored_tasks(
            dag, task_instance_mutation_hook, session=session
        )

        def task_filter(task: Operator) -> bool:
            return task.task_id not in task_ids and (
                self.is_backfill
                or (task.start_date is None or task.start_date <= self.execution_date)
                and (task.end_date is None or self.execution_date <= task.end_date)
            )

        created_counts: dict[str, int] = defaultdict(int)
        task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)

        # Create the missing tasks, including mapped tasks
        tasks_to_create = (task for task in dag.task_dict.values() if task_filter(task))
        tis_to_create = self._create_tasks(tasks_to_create, task_creator, session=session)
        self._create_task_instances(self.dag_id, tis_to_create, created_counts, hook_is_noop, session=session)

    def _check_for_removed_or_restored_tasks(
        self, dag: DAG, ti_mutation_hook, *, session: Session
    ) -> set[str]:
        """
        Check for removed, restored, or missing tasks.

        :param dag: DAG object corresponding to the dagrun
        :param ti_mutation_hook: task_instance_mutation_hook function
        :param session: SQLAlchemy ORM Session

        :return: Task IDs in the DAG run
        """
        tis = self.get_task_instances(session=session)

        # Check for removed or restored tasks.
        task_ids = set()
        for ti in tis:
            ti_mutation_hook(ti)
            task_ids.add(ti.task_id)
            try:
                task = dag.get_task(ti.task_id)

                should_restore_task = (task is not None) and ti.state == TaskInstanceState.REMOVED
                if should_restore_task:
                    self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                    Stats.incr(f"task_restored_to_dag.{dag.dag_id}", tags=self.stats_tags)
                    # Same metric with tagging
                    Stats.incr("task_restored_to_dag", tags={**self.stats_tags, "dag_id": dag.dag_id})
                    ti.state = None
            except AirflowException:
                if ti.state == TaskInstanceState.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state != DagRunState.RUNNING and not dag.partial:
                    self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", tags=self.stats_tags)
                    # Same metric with tagging
                    Stats.incr("task_removed_from_dag", tags={**self.stats_tags, "dag_id": dag.dag_id})
                    ti.state = TaskInstanceState.REMOVED
                continue

            try:
                num_mapped_tis = task.get_parse_time_mapped_ti_count()
            except NotMapped:
                continue
            except NotFullyPopulated:
                # What if it is _now_ dynamically mapped, but wasn't before?
                try:
                    total_length = task.get_mapped_ti_count(self.run_id, session=session)
                except NotFullyPopulated:
                    # Not all upstreams finished, so we can't tell what should be here. Remove everything.
                    if ti.map_index >= 0:
                        self.log.debug(
                            "Removing the unmapped TI '%s' as the mapping can't be resolved yet", ti
                        )
                        ti.state = TaskInstanceState.REMOVED
                    continue
                # Upstreams finished, check there aren't any extras.
                if ti.map_index >= total_length:
                    self.log.debug(
                        "Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
                        ti,
                        total_length,
                    )
                    ti.state = TaskInstanceState.REMOVED
            else:
                # Check if the number of mapped literals has changed, and we need to mark this TI as removed.
                if ti.map_index >= num_mapped_tis:
                    self.log.debug(
                        "Removing task '%s' as the map_index is longer than the literal mapping list (%s)",
                        ti,
                        num_mapped_tis,
                    )
                    ti.state = TaskInstanceState.REMOVED
                elif ti.map_index < 0:
                    self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
                    ti.state = TaskInstanceState.REMOVED

        return task_ids

    @overload
    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[True],
    ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]: ...

    @overload
    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[False],
    ) -> Callable[[Operator, Iterable[int]], Iterator[TI]]: ...

    def _get_task_creator(
        self,
        created_counts: dict[str, int],
        ti_mutation_hook: Callable,
        hook_is_noop: Literal[True, False],
    ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]:
        """
        Get the task creator function.

        This function also updates the created_counts dictionary with the number of tasks created.

        :param created_counts: Dictionary of task_type -> count of created TIs
        :param ti_mutation_hook: task_instance_mutation_hook function
        :param hook_is_noop: Whether the task_instance_mutation_hook is a noop
        """
        if hook_is_noop:

            def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[str, Any]]:
                created_counts[task.task_type] += 1
                for map_index in indexes:
                    yield TI.insert_mapping(self.run_id, task, map_index=map_index)

            creator = create_ti_mapping

        else:

            def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
                for map_index in indexes:
                    ti = TI(task, run_id=self.run_id, map_index=map_index)
                    ti_mutation_hook(ti)
                    created_counts[ti.operator] += 1
                    yield ti

            creator = create_ti
        return creator

    def _create_tasks(
        self,
        tasks: Iterable[Operator],
        task_creator: Callable[[Operator, Iterable[int]], CreatedTasks],
        *,
        session: Session,
    ) -> CreatedTasks:
        """
        Create missing tasks -- and expand any MappedOperator that _only_ has literals as input.

        :param tasks: Tasks to create jobs for in the DAG run
        :param task_creator: Function to create task instances
        """
        map_indexes: Iterable[int]
        for task in tasks:
            try:
                count = task.get_mapped_ti_count(self.run_id, session=session)
            except (NotMapped, NotFullyPopulated):
                map_indexes = (-1,)
            else:
                if count:
                    map_indexes = range(count)
                else:
                    # Make sure to always create at least one ti; this will be
                    # marked as REMOVED later at runtime.
                    map_indexes = (-1,)
            yield from task_creator(task, map_indexes)

    def _create_task_instances(
        self,
        dag_id: str,
        tasks: Iterator[dict[str, Any]] | Iterator[TI],
        created_counts: dict[str, int],
        hook_is_noop: bool,
        *,
        session: Session,
    ) -> None:
        """
        Create the necessary task instances from the given tasks.

        :param dag_id: DAG ID associated with the dagrun
        :param tasks: the tasks to create the task instances from
        :param created_counts: a dictionary of task types to the count of TIs created by the task creator
        :param hook_is_noop: whether the task_instance_mutation_hook is noop
        :param session: the session to use
        """
        # Fetch the information we need before handling the exception to avoid
        # PendingRollbackError due to the session being invalidated on exception
        # see https://github.com/apache/superset/pull/530
        run_id = self.run_id
        try:
            if hook_is_noop:
                session.bulk_insert_mappings(TI, tasks)
            else:
                session.bulk_save_objects(tasks)

            for task_type, count in created_counts.items():
                Stats.incr(f"task_instance_created_{task_type}", count, tags=self.stats_tags)
                # Same metric with tagging
                Stats.incr("task_instance_created", count, tags={**self.stats_tags, "task_type": task_type})
            session.flush()
        except IntegrityError:
            self.log.info(
                "Hit IntegrityError while creating the TIs for %s - %s",
                dag_id,
                run_id,
                exc_info=True,
            )
            self.log.info("Doing session rollback.")
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()

    def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]:
        """Check if the task increased or reduced in length and handle appropriately.

        Task instances that do not already exist are created and returned if
        possible. Expansion only happens if all upstreams are ready; otherwise
        we delay expansion to the "last resort". See comments at the call site
        for more details.
        """
        from airflow.settings import task_instance_mutation_hook

        try:
            total_length = task.get_mapped_ti_count(self.run_id, session=session)
        except NotMapped:
            return  # Not a mapped task, don't need to do anything.
        except NotFullyPopulated:
            return  # Upstreams not ready, don't need to revise this yet.

        query = session.scalars(
            select(TI.map_index).where(
                TI.dag_id == self.dag_id,
                TI.task_id == task.task_id,
                TI.run_id == self.run_id,
            )
        )
        existing_indexes = set(query)

        removed_indexes = existing_indexes.difference(range(total_length))
        if removed_indexes:
            session.execute(
                update(TI)
                .where(
                    TI.dag_id == self.dag_id,
                    TI.task_id == task.task_id,
                    TI.run_id == self.run_id,
                    TI.map_index.in_(removed_indexes),
                )
                .values(state=TaskInstanceState.REMOVED)
            )
            session.flush()

        for index in range(total_length):
            if index in existing_indexes:
                continue
            ti = TI(task, run_id=self.run_id, map_index=index, state=None)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(task)
            session.flush()
            yield ti

    @staticmethod
    def get_run(session: Session, dag_id: str, execution_date: datetime) -> DagRun | None:
        """
        Get a single DAG Run.

        :meta private:
        :param session: SQLAlchemy ORM Session
        :param dag_id: DAG ID
        :param execution_date: execution date
        :return: the DagRun corresponding to the given dag_id and execution date
            if one exists, None otherwise
        """
        warnings.warn(
            "This method is deprecated. Please use SQLAlchemy directly",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return session.scalar(
            select(DagRun).where(
                DagRun.dag_id == dag_id,
                DagRun.external_trigger == False,  # noqa: E712
                DagRun.execution_date == execution_date,
            )
        )

    @property
    def is_backfill(self) -> bool:
        return self.run_type == DagRunType.BACKFILL_JOB

    @classmethod
    @provide_session
    def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]:
        """Return the latest DagRun for each DAG."""
        subquery = (
            select(cls.dag_id, func.max(cls.execution_date).label("execution_date"))
            .group_by(cls.dag_id)
            .subquery()
        )
        return session.scalars(
            select(cls).join(
                subquery,
                and_(cls.dag_id == subquery.c.dag_id, cls.execution_date == subquery.c.execution_date),
            )
        ).all()
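
    # A minimal usage sketch (run inside a provided session):
    #   latest_by_dag = {run.dag_id: run for run in DagRun.get_latest_runs(session=session)}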

    @provide_session
    def schedule_tis(
        self,
        schedulable_tis: Iterable[TI],
        session: Session = NEW_SESSION,
        max_tis_per_query: int | None = None,
    ) -> int:
        """
        Set the given task instances to the scheduled state.

        Each element of ``schedulable_tis`` should have its ``task`` attribute already set.

        Any EmptyOperator without callbacks or outlets is instead set straight to the success state.

        All the TIs should belong to this DagRun, but this code is in the hot-path, so this is not
        checked -- it is the caller's responsibility to call this function only with TIs from a single
        dag run.
        """
        # Get the list of TI IDs that do not need to be executed; these are
        # tasks using EmptyOperator and without on_execute_callback / on_success_callback.
        dummy_ti_ids = []
        schedulable_ti_ids = []
        for ti in schedulable_tis:
            if TYPE_CHECKING:
                assert ti.task
            if (
                ti.task.inherits_from_empty_operator
                and not ti.task.on_execute_callback
                and not ti.task.on_success_callback
                and not ti.task.outlets
            ):
                dummy_ti_ids.append((ti.task_id, ti.map_index))
            elif (
                ti.task.start_trigger is not None
                and ti.task.next_method is not None
                and not ti.task.on_execute_callback
                and not ti.task.on_success_callback
                and not ti.task.outlets
            ):
                if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
                    ti.try_number += 1
                ti.defer_task(
                    defer=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method),
                    session=session,
                )
            else:
                schedulable_ti_ids.append((ti.task_id, ti.map_index))

        count = 0

        if schedulable_ti_ids:
            schedulable_ti_ids_chunks = chunks(
                schedulable_ti_ids, max_tis_per_query or len(schedulable_ti_ids)
            )
            for schedulable_ti_ids_chunk in schedulable_ti_ids_chunks:
                count += session.execute(
                    update(TI)
                    .where(
                        TI.dag_id == self.dag_id,
                        TI.run_id == self.run_id,
                        tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids_chunk),
                    )
                    .values(
                        state=TaskInstanceState.SCHEDULED,
                        try_number=case(
                            (
                                or_(TI.state.is_(None), TI.state != TaskInstanceState.UP_FOR_RESCHEDULE),
                                TI.try_number + 1,
                            ),
                            else_=TI.try_number,
                        ),
                    )
                    .execution_options(synchronize_session=False)
                ).rowcount

        # Tasks using EmptyOperator should not be executed; mark them as success.
        if dummy_ti_ids:
            dummy_ti_ids_chunks = chunks(dummy_ti_ids, max_tis_per_query or len(dummy_ti_ids))
            for dummy_ti_ids_chunk in dummy_ti_ids_chunks:
                count += session.execute(
                    update(TI)
                    .where(
                        TI.dag_id == self.dag_id,
                        TI.run_id == self.run_id,
                        tuple_in_condition((TI.task_id, TI.map_index), dummy_ti_ids_chunk),
                    )
                    .values(
                        state=TaskInstanceState.SUCCESS,
                        start_date=timezone.utcnow(),
                        end_date=timezone.utcnow(),
                        duration=0,
                    )
                    .execution_options(
                        synchronize_session=False,
                    )
                ).rowcount

        return count
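
    # A minimal usage sketch (assumes a ``decision`` from ``task_instance_scheduling_decisions``
    # and an open session; names are illustrative):
    #   scheduled = dag_run.schedule_tis(decision.schedulable_tis, session=session)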

    @provide_session
    def get_log_template(self, *, session: Session = NEW_SESSION) -> LogTemplate | LogTemplatePydantic:
        return DagRun._get_log_template(log_template_id=self.log_template_id, session=session)

    @staticmethod
    @internal_api_call
    @provide_session
    def _get_log_template(
        log_template_id: int | None, session: Session = NEW_SESSION
    ) -> LogTemplate | LogTemplatePydantic:
        template: LogTemplate | None
        if log_template_id is None:  # DagRun created before LogTemplate introduction.
            template = session.scalar(select(LogTemplate).order_by(LogTemplate.id).limit(1))
        else:
            template = session.get(LogTemplate, log_template_id)
        if template is None:
            raise AirflowException(
                f"No log_template entry found for ID {log_template_id!r}. "
                f"Please make sure you set up the metadatabase correctly."
            )
        return template

    @provide_session
    def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str:
        warnings.warn(
            "This method is deprecated. Please use get_log_template instead.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        return self.get_log_template(session=session).filename


class DagRunNote(Base):
    """For storage of arbitrary notes concerning the dagrun instance."""

    __tablename__ = "dag_run_note"

    user_id = Column(
        Integer,
        ForeignKey("ab_user.id", name="dag_run_note_user_fkey"),
        nullable=True,
    )
    dag_run_id = Column(Integer, primary_key=True, nullable=False)
    content = Column(String(1000).with_variant(Text(1000), "mysql"))
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dag_run = relationship("DagRun", back_populates="dag_run_note")

    __table_args__ = (
        PrimaryKeyConstraint("dag_run_id", name="dag_run_note_pkey"),
        ForeignKeyConstraint(
            (dag_run_id,),
            ["dag_run.id"],
            name="dag_run_note_dr_fkey",
            ondelete="CASCADE",
        ),
    )

    def __init__(self, content, user_id=None):
        self.content = content
        self.user_id = user_id

    def __repr__(self):
        # DagRunNote has no dag_id, run_id, or map_index columns of its own,
        # so identify the note by its primary key.
        return f"<{self.__class__.__name__}: dag_run_id={self.dag_run_id}>"