Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/dag.py: 31%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import asyncio
21import copy
22import functools
23import itertools
24import logging
25import os
26import pathlib
27import pickle
28import sys
29import time
30import traceback
31import warnings
32import weakref
33from collections import abc, defaultdict, deque
34from contextlib import ExitStack
35from datetime import datetime, timedelta
36from inspect import signature
37from typing import (
38 TYPE_CHECKING,
39 Any,
40 Callable,
41 Collection,
42 Container,
43 Generator,
44 Iterable,
45 Iterator,
46 List,
47 Pattern,
48 Sequence,
49 Union,
50 cast,
51 overload,
52)
53from urllib.parse import urlsplit
55import jinja2
56import pendulum
57import re2
58import sqlalchemy_jsonfield
59from dateutil.relativedelta import relativedelta
60from sqlalchemy import (
61 Boolean,
62 Column,
63 ForeignKey,
64 Index,
65 Integer,
66 String,
67 Text,
68 and_,
69 case,
70 func,
71 not_,
72 or_,
73 select,
74 update,
75)
76from sqlalchemy.ext.associationproxy import association_proxy
77from sqlalchemy.orm import backref, joinedload, load_only, relationship
78from sqlalchemy.sql import Select, expression
80import airflow.templates
81from airflow import settings, utils
82from airflow.api_internal.internal_api_call import internal_api_call
83from airflow.configuration import conf as airflow_conf, secrets_backend_list
84from airflow.datasets import BaseDataset, Dataset, DatasetAll
85from airflow.datasets.manager import dataset_manager
86from airflow.exceptions import (
87 AirflowDagInconsistent,
88 AirflowException,
89 DuplicateTaskIdFound,
90 FailStopDagInvalidTriggerRule,
91 ParamValidationError,
92 RemovedInAirflow3Warning,
93 TaskDeferred,
94 TaskNotFound,
95)
96from airflow.jobs.job import run_job
97from airflow.models.abstractoperator import AbstractOperator, TaskStateChangeCallback
98from airflow.models.base import Base, StringID
99from airflow.models.baseoperator import BaseOperator
100from airflow.models.dagcode import DagCode
101from airflow.models.dagpickle import DagPickle
102from airflow.models.dagrun import RUN_ID_REGEX, DagRun
103from airflow.models.dataset import DatasetDagRunQueue, DatasetModel
104from airflow.models.param import DagParam, ParamsDict
105from airflow.models.taskinstance import (
106 Context,
107 TaskInstance,
108 TaskInstanceKey,
109 clear_task_instances,
110)
111from airflow.secrets.local_filesystem import LocalFilesystemBackend
112from airflow.security import permissions
113from airflow.settings import json
114from airflow.stats import Stats
115from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
116from airflow.timetables.datasets import DatasetOrTimeSchedule
117from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
118from airflow.timetables.simple import (
119 ContinuousTimetable,
120 DatasetTriggeredTimetable,
121 NullTimetable,
122 OnceTimetable,
123)
124from airflow.timetables.trigger import CronTriggerTimetable
125from airflow.utils import timezone
126from airflow.utils.dag_cycle_tester import check_cycle
127from airflow.utils.dates import cron_presets, date_range as utils_date_range
128from airflow.utils.decorators import fixup_decorator_warning_stack
129from airflow.utils.helpers import at_most_one, exactly_one, validate_key
130from airflow.utils.log.logging_mixin import LoggingMixin
131from airflow.utils.session import NEW_SESSION, provide_session
132from airflow.utils.sqlalchemy import (
133 Interval,
134 UtcDateTime,
135 lock_rows,
136 tuple_in_condition,
137 with_row_locks,
138)
139from airflow.utils.state import DagRunState, State, TaskInstanceState
140from airflow.utils.trigger_rule import TriggerRule
141from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
143if TYPE_CHECKING:
144 from types import ModuleType
146 from pendulum.tz.timezone import FixedTimezone, Timezone
147 from sqlalchemy.orm.query import Query
148 from sqlalchemy.orm.session import Session
150 from airflow.decorators import TaskDecoratorCollection
151 from airflow.models.dagbag import DagBag
152 from airflow.models.operator import Operator
153 from airflow.models.slamiss import SlaMiss
154 from airflow.serialization.pydantic.dag import DagModelPydantic
155 from airflow.serialization.pydantic.dag_run import DagRunPydantic
156 from airflow.typing_compat import Literal
157 from airflow.utils.task_group import TaskGroup
159 # This is a workaround because mypy doesn't work with hybrid_property
160 # TODO: remove this hack and move hybrid_property back to main import block
161 # See https://github.com/python/mypy/issues/4430
162 hybrid_property = property
163else:
164 from sqlalchemy.ext.hybrid import hybrid_property
166log = logging.getLogger(__name__)
168DEFAULT_VIEW_PRESETS = ["grid", "graph", "duration", "gantt", "landing_times"]
169ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"]
171TAG_MAX_LEN = 100
173DagStateChangeCallback = Callable[[Context], None]
174ScheduleInterval = Union[None, str, timedelta, relativedelta]
176# FIXME: Ideally this should be Union[Literal[NOTSET], ScheduleInterval],
177# but Mypy cannot handle that right now. Track progress of PEP 661 for progress.
178# See also: https://discuss.python.org/t/9126/7
179ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval]
180ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, BaseDataset, Collection["Dataset"]]
182SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None]
184# Backward compatibility: If neither schedule_interval nor timetable is
185# *provided by the user*, default to a one-day interval.
186DEFAULT_SCHEDULE_INTERVAL = timedelta(days=1)
189class InconsistentDataInterval(AirflowException):
190 """Exception raised when a model populates data interval fields incorrectly.
192 The data interval fields should either both be None (for runs scheduled
193 prior to AIP-39), or both be datetime (for runs scheduled after AIP-39 is
194 implemented). This is raised if exactly one of the fields is None.
195 """
197 _template = (
198 "Inconsistent {cls}: {start[0]}={start[1]!r}, {end[0]}={end[1]!r}, "
199 "they must be either both None or both datetime"
200 )
202 def __init__(self, instance: Any, start_field_name: str, end_field_name: str) -> None:
203 self._class_name = type(instance).__name__
204 self._start_field = (start_field_name, getattr(instance, start_field_name))
205 self._end_field = (end_field_name, getattr(instance, end_field_name))
207 def __str__(self) -> str:
208 return self._template.format(cls=self._class_name, start=self._start_field, end=self._end_field)
211def _get_model_data_interval(
212 instance: Any,
213 start_field_name: str,
214 end_field_name: str,
215) -> DataInterval | None:
216 start = timezone.coerce_datetime(getattr(instance, start_field_name))
217 end = timezone.coerce_datetime(getattr(instance, end_field_name))
218 if start is None:
219 if end is not None:
220 raise InconsistentDataInterval(instance, start_field_name, end_field_name)
221 return None
222 elif end is None:
223 raise InconsistentDataInterval(instance, start_field_name, end_field_name)
224 return DataInterval(start, end)
227def create_timetable(interval: ScheduleIntervalArg, timezone: Timezone | FixedTimezone) -> Timetable:
228 """Create a Timetable instance from a ``schedule_interval`` argument."""
229 if interval is NOTSET:
230 return DeltaDataIntervalTimetable(DEFAULT_SCHEDULE_INTERVAL)
231 if interval is None:
232 return NullTimetable()
233 if interval == "@once":
234 return OnceTimetable()
235 if interval == "@continuous":
236 return ContinuousTimetable()
237 if isinstance(interval, (timedelta, relativedelta)):
238 return DeltaDataIntervalTimetable(interval)
239 if isinstance(interval, str):
240 if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"):
241 return CronDataIntervalTimetable(interval, timezone)
242 else:
243 return CronTriggerTimetable(interval, timezone=timezone)
244 raise ValueError(f"{interval!r} is not a valid schedule_interval.")
247def get_last_dagrun(dag_id, session, include_externally_triggered=False):
248 """
249 Return the last dag run for a dag, None if there was none.
251 Last dag run can be any type of run e.g. scheduled or backfilled.
252 Overridden DagRuns are ignored.
253 """
254 DR = DagRun
255 query = select(DR).where(DR.dag_id == dag_id)
256 if not include_externally_triggered:
257 query = query.where(DR.external_trigger == expression.false())
258 query = query.order_by(DR.execution_date.desc())
259 return session.scalar(query.limit(1))
262def get_dataset_triggered_next_run_info(
263 dag_ids: list[str], *, session: Session
264) -> dict[str, dict[str, int | str]]:
265 """
266 Get next run info for a list of dag_ids.
268 Given a list of dag_ids, get string representing how close any that are dataset triggered are
269 their next run, e.g. "1 of 2 datasets updated".
270 """
271 from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ, DatasetModel
273 return {
274 x.dag_id: {
275 "uri": x.uri,
276 "ready": x.ready,
277 "total": x.total,
278 }
279 for x in session.execute(
280 select(
281 DagScheduleDatasetReference.dag_id,
282 # This is a dirty hack to workaround group by requiring an aggregate,
283 # since grouping by dataset is not what we want to do here...but it works
284 case((func.count() == 1, func.max(DatasetModel.uri)), else_="").label("uri"),
285 func.count().label("total"),
286 func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)).label("ready"),
287 )
288 .join(
289 DDRQ,
290 and_(
291 DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id,
292 DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id,
293 ),
294 isouter=True,
295 )
296 .join(DatasetModel, DatasetModel.id == DagScheduleDatasetReference.dataset_id)
297 .group_by(DagScheduleDatasetReference.dag_id)
298 .where(DagScheduleDatasetReference.dag_id.in_(dag_ids))
299 ).all()
300 }
303def _triggerer_is_healthy():
304 from airflow.jobs.triggerer_job_runner import TriggererJobRunner
306 job = TriggererJobRunner.most_recent_job()
307 return job and job.is_alive()
310@internal_api_call
311@provide_session
312def _create_orm_dagrun(
313 dag,
314 dag_id,
315 run_id,
316 logical_date,
317 start_date,
318 external_trigger,
319 conf,
320 state,
321 run_type,
322 dag_hash,
323 creating_job_id,
324 data_interval,
325 session,
326):
327 run = DagRun(
328 dag_id=dag_id,
329 run_id=run_id,
330 execution_date=logical_date,
331 start_date=start_date,
332 external_trigger=external_trigger,
333 conf=conf,
334 state=state,
335 run_type=run_type,
336 dag_hash=dag_hash,
337 creating_job_id=creating_job_id,
338 data_interval=data_interval,
339 )
340 session.add(run)
341 session.flush()
342 run.dag = dag
343 # create the associated task instances
344 # state is None at the moment of creation
345 run.verify_integrity(session=session)
346 return run
349@functools.total_ordering
350class DAG(LoggingMixin):
351 """
352 A dag (directed acyclic graph) is a collection of tasks with directional dependencies.
354 A dag also has a schedule, a start date and an end date (optional). For each schedule,
355 (say daily or hourly), the DAG needs to run each individual tasks as their dependencies
356 are met. Certain tasks have the property of depending on their own past, meaning that
357 they can't run until their previous schedule (and upstream tasks) are completed.
359 DAGs essentially act as namespaces for tasks. A task_id can only be
360 added once to a DAG.
362 Note that if you plan to use time zones all the dates provided should be pendulum
363 dates. See :ref:`timezone_aware_dags`.
365 .. versionadded:: 2.4
366 The *schedule* argument to specify either time-based scheduling logic
367 (timetable), or dataset-driven triggers.
369 .. deprecated:: 2.4
370 The arguments *schedule_interval* and *timetable*. Their functionalities
371 are merged into the new *schedule* argument.
373 :param dag_id: The id of the DAG; must consist exclusively of alphanumeric
374 characters, dashes, dots and underscores (all ASCII)
375 :param description: The description for the DAG to e.g. be shown on the webserver
376 :param schedule: Defines the rules according to which DAG runs are scheduled. Can
377 accept cron string, timedelta object, Timetable, or list of Dataset objects.
378 If this is not provided, the DAG will be set to the default
379 schedule ``timedelta(days=1)``. See also :doc:`/howto/timetable`.
380 :param start_date: The timestamp from which the scheduler will
381 attempt to backfill
382 :param end_date: A date beyond which your DAG won't run, leave to None
383 for open-ended scheduling
384 :param template_searchpath: This list of folders (non-relative)
385 defines where jinja will look for your templates. Order matters.
386 Note that jinja/airflow includes the path of your DAG file by
387 default
388 :param template_undefined: Template undefined type.
389 :param user_defined_macros: a dictionary of macros that will be exposed
390 in your jinja templates. For example, passing ``dict(foo='bar')``
391 to this argument allows you to ``{{ foo }}`` in all jinja
392 templates related to this DAG. Note that you can pass any
393 type of object here.
394 :param user_defined_filters: a dictionary of filters that will be exposed
395 in your jinja templates. For example, passing
396 ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows
397 you to ``{{ 'world' | hello }}`` in all jinja templates related to
398 this DAG.
399 :param default_args: A dictionary of default parameters to be used
400 as constructor keyword parameters when initialising operators.
401 Note that operators have the same hook, and precede those defined
402 here, meaning that if your dict contains `'depends_on_past': True`
403 here and `'depends_on_past': False` in the operator's call
404 `default_args`, the actual value will be `False`.
405 :param params: a dictionary of DAG level parameters that are made
406 accessible in templates, namespaced under `params`. These
407 params can be overridden at the task level.
408 :param max_active_tasks: the number of task instances allowed to run
409 concurrently
410 :param max_active_runs: maximum number of active DAG runs, beyond this
411 number of DAG runs in a running state, the scheduler won't create
412 new active DAG runs
413 :param max_consecutive_failed_dag_runs: (experimental) maximum number of consecutive failed DAG runs,
414 beyond this the scheduler will disable the DAG
415 :param dagrun_timeout: specify how long a DagRun should be up before
416 timing out / failing, so that new DagRuns can be created.
417 :param sla_miss_callback: specify a function or list of functions to call when reporting SLA
418 timeouts. See :ref:`sla_miss_callback<concepts:sla_miss_callback>` for
419 more information about the function signature and parameters that are
420 passed to the callback.
421 :param default_view: Specify DAG default view (grid, graph, duration,
422 gantt, landing_times), default grid
423 :param orientation: Specify DAG orientation in graph view (LR, TB, RL, BT), default LR
424 :param catchup: Perform scheduler catchup (or only run latest)? Defaults to True
425 :param on_failure_callback: A function or list of functions to be called when a DagRun of this dag fails.
426 A context dictionary is passed as a single parameter to this function.
427 :param on_success_callback: Much like the ``on_failure_callback`` except
428 that it is executed when the dag succeeds.
429 :param access_control: Specify optional DAG-level actions, e.g.,
430 "{'role1': {'can_read'}, 'role2': {'can_read', 'can_edit', 'can_delete'}}"
431 :param is_paused_upon_creation: Specifies if the dag is paused when created for the first time.
432 If the dag exists already, this flag will be ignored. If this optional parameter
433 is not specified, the global config setting will be used.
434 :param jinja_environment_kwargs: additional configuration options to be passed to Jinja
435 ``Environment`` for template rendering
437 **Example**: to avoid Jinja from removing a trailing newline from template strings ::
439 DAG(
440 dag_id="my-dag",
441 jinja_environment_kwargs={
442 "keep_trailing_newline": True,
443 # some other jinja2 Environment options here
444 },
445 )
447 **See**: `Jinja Environment documentation
448 <https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.Environment>`_
450 :param render_template_as_native_obj: If True, uses a Jinja ``NativeEnvironment``
451 to render templates as native Python types. If False, a Jinja
452 ``Environment`` is used to render templates as string values.
453 :param tags: List of tags to help filtering DAGs in the UI.
454 :param owner_links: Dict of owners and their links, that will be clickable on the DAGs view UI.
455 Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link.
456 e.g: {"dag_owner": "https://airflow.apache.org/"}
457 :param auto_register: Automatically register this DAG when it is used in a ``with`` block
458 :param fail_stop: Fails currently running tasks when task in DAG fails.
459 **Warning**: A fail stop dag can only have tasks with the default trigger rule ("all_success").
460 An exception will be thrown if any task in a fail stop dag has a non default trigger rule.
461 :param dag_display_name: The display name of the DAG which appears on the UI.
462 """
464 _comps = {
465 "dag_id",
466 "task_ids",
467 "parent_dag",
468 "start_date",
469 "end_date",
470 "schedule_interval",
471 "fileloc",
472 "template_searchpath",
473 "last_loaded",
474 }
476 __serialized_fields: frozenset[str] | None = None
478 fileloc: str
479 """
480 File path that needs to be imported to load this DAG or subdag.
482 This may not be an actual file on disk in the case when this DAG is loaded
483 from a ZIP file or other DAG distribution format.
484 """
486 parent_dag: DAG | None = None # Gets set when DAGs are loaded
488 # NOTE: When updating arguments here, please also keep arguments in @dag()
489 # below in sync. (Search for 'def dag(' in this file.)
490 def __init__(
491 self,
492 dag_id: str,
493 description: str | None = None,
494 schedule: ScheduleArg = NOTSET,
495 schedule_interval: ScheduleIntervalArg = NOTSET,
496 timetable: Timetable | None = None,
497 start_date: datetime | None = None,
498 end_date: datetime | None = None,
499 full_filepath: str | None = None,
500 template_searchpath: str | Iterable[str] | None = None,
501 template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
502 user_defined_macros: dict | None = None,
503 user_defined_filters: dict | None = None,
504 default_args: dict | None = None,
505 concurrency: int | None = None,
506 max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"),
507 max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"),
508 max_consecutive_failed_dag_runs: int = airflow_conf.getint(
509 "core", "max_consecutive_failed_dag_runs_per_dag"
510 ),
511 dagrun_timeout: timedelta | None = None,
512 sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None,
513 default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(),
514 orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"),
515 catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"),
516 on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
517 on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
518 doc_md: str | None = None,
519 params: abc.MutableMapping | None = None,
520 access_control: dict | None = None,
521 is_paused_upon_creation: bool | None = None,
522 jinja_environment_kwargs: dict | None = None,
523 render_template_as_native_obj: bool = False,
524 tags: list[str] | None = None,
525 owner_links: dict[str, str] | None = None,
526 auto_register: bool = True,
527 fail_stop: bool = False,
528 dag_display_name: str | None = None,
529 ):
530 from airflow.utils.task_group import TaskGroup
532 if tags and any(len(tag) > TAG_MAX_LEN for tag in tags):
533 raise AirflowException(f"tag cannot be longer than {TAG_MAX_LEN} characters")
535 self.owner_links = owner_links or {}
536 self.user_defined_macros = user_defined_macros
537 self.user_defined_filters = user_defined_filters
538 if default_args and not isinstance(default_args, dict):
539 raise TypeError("default_args must be a dict")
540 self.default_args = copy.deepcopy(default_args or {})
541 params = params or {}
543 # merging potentially conflicting default_args['params'] into params
544 if "params" in self.default_args:
545 params.update(self.default_args["params"])
546 del self.default_args["params"]
548 # check self.params and convert them into ParamsDict
549 self.params = ParamsDict(params)
551 if full_filepath:
552 warnings.warn(
553 "Passing full_filepath to DAG() is deprecated and has no effect",
554 RemovedInAirflow3Warning,
555 stacklevel=2,
556 )
558 validate_key(dag_id)
560 self._dag_id = dag_id
561 self._dag_display_property_value = dag_display_name
563 if concurrency:
564 # TODO: Remove in Airflow 3.0
565 warnings.warn(
566 "The 'concurrency' parameter is deprecated. Please use 'max_active_tasks'.",
567 RemovedInAirflow3Warning,
568 stacklevel=2,
569 )
570 max_active_tasks = concurrency
571 self._max_active_tasks = max_active_tasks
572 self._pickle_id: int | None = None
574 self._description = description
575 # set file location to caller source path
576 back = sys._getframe().f_back
577 self.fileloc = back.f_code.co_filename if back else ""
578 self.task_dict: dict[str, Operator] = {}
580 # set timezone from start_date
581 tz = None
582 if start_date and start_date.tzinfo:
583 tzinfo = None if start_date.tzinfo else settings.TIMEZONE
584 tz = pendulum.instance(start_date, tz=tzinfo).timezone
585 elif date := self.default_args.get("start_date"):
586 if not isinstance(date, datetime):
587 date = timezone.parse(date)
588 self.default_args["start_date"] = date
589 start_date = date
591 tzinfo = None if date.tzinfo else settings.TIMEZONE
592 tz = pendulum.instance(date, tz=tzinfo).timezone
593 self.timezone: Timezone | FixedTimezone = tz or settings.TIMEZONE
595 # Apply the timezone we settled on to end_date if it wasn't supplied
596 if isinstance(_end_date := self.default_args.get("end_date"), str):
597 self.default_args["end_date"] = timezone.parse(_end_date, timezone=self.timezone)
599 self.start_date = timezone.convert_to_utc(start_date)
600 self.end_date = timezone.convert_to_utc(end_date)
602 # also convert tasks
603 if "start_date" in self.default_args:
604 self.default_args["start_date"] = timezone.convert_to_utc(self.default_args["start_date"])
605 if "end_date" in self.default_args:
606 self.default_args["end_date"] = timezone.convert_to_utc(self.default_args["end_date"])
608 # sort out DAG's scheduling behavior
609 scheduling_args = [schedule_interval, timetable, schedule]
611 has_scheduling_args = any(a is not NOTSET and bool(a) for a in scheduling_args)
612 has_empty_start_date = not ("start_date" in self.default_args or self.start_date)
614 if has_scheduling_args and has_empty_start_date:
615 raise ValueError("DAG is missing the start_date parameter")
617 if not at_most_one(*scheduling_args):
618 raise ValueError("At most one allowed for args 'schedule_interval', 'timetable', and 'schedule'.")
619 if schedule_interval is not NOTSET:
620 warnings.warn(
621 "Param `schedule_interval` is deprecated and will be removed in a future release. "
622 "Please use `schedule` instead. ",
623 RemovedInAirflow3Warning,
624 stacklevel=2,
625 )
626 if timetable is not None:
627 warnings.warn(
628 "Param `timetable` is deprecated and will be removed in a future release. "
629 "Please use `schedule` instead. ",
630 RemovedInAirflow3Warning,
631 stacklevel=2,
632 )
634 self.timetable: Timetable
635 self.schedule_interval: ScheduleInterval
636 self.dataset_triggers: BaseDataset | None = None
637 if isinstance(schedule, BaseDataset):
638 self.dataset_triggers = schedule
639 elif isinstance(schedule, Collection) and not isinstance(schedule, str):
640 if not all(isinstance(x, Dataset) for x in schedule):
641 raise ValueError("All elements in 'schedule' should be datasets")
642 self.dataset_triggers = DatasetAll(*schedule)
643 elif isinstance(schedule, Timetable):
644 timetable = schedule
645 elif schedule is not NOTSET and not isinstance(schedule, BaseDataset):
646 schedule_interval = schedule
648 if isinstance(schedule, DatasetOrTimeSchedule):
649 self.timetable = schedule
650 self.dataset_triggers = self.timetable.datasets
651 self.schedule_interval = self.timetable.summary
652 elif self.dataset_triggers:
653 self.timetable = DatasetTriggeredTimetable()
654 self.schedule_interval = self.timetable.summary
655 elif timetable:
656 self.timetable = timetable
657 self.schedule_interval = self.timetable.summary
658 else:
659 if isinstance(schedule_interval, ArgNotSet):
660 schedule_interval = DEFAULT_SCHEDULE_INTERVAL
661 self.schedule_interval = schedule_interval
662 self.timetable = create_timetable(schedule_interval, self.timezone)
664 if isinstance(template_searchpath, str):
665 template_searchpath = [template_searchpath]
666 self.template_searchpath = template_searchpath
667 self.template_undefined = template_undefined
668 self.last_loaded: datetime = timezone.utcnow()
669 self.safe_dag_id = dag_id.replace(".", "__dot__")
670 self.max_active_runs = max_active_runs
671 self.max_consecutive_failed_dag_runs = max_consecutive_failed_dag_runs
672 if self.max_consecutive_failed_dag_runs == 0:
673 self.max_consecutive_failed_dag_runs = airflow_conf.getint(
674 "core", "max_consecutive_failed_dag_runs_per_dag"
675 )
676 if self.max_consecutive_failed_dag_runs < 0:
677 raise AirflowException(
678 f"Invalid max_consecutive_failed_dag_runs: {self.max_consecutive_failed_dag_runs}."
679 f"Requires max_consecutive_failed_dag_runs >= 0"
680 )
681 if self.timetable.active_runs_limit is not None:
682 if self.timetable.active_runs_limit < self.max_active_runs:
683 raise AirflowException(
684 f"Invalid max_active_runs: {type(self.timetable)} "
685 f"requires max_active_runs <= {self.timetable.active_runs_limit}"
686 )
687 self.dagrun_timeout = dagrun_timeout
688 self.sla_miss_callback = sla_miss_callback
689 if default_view in DEFAULT_VIEW_PRESETS:
690 self._default_view: str = default_view
691 elif default_view == "tree":
692 warnings.warn(
693 "`default_view` of 'tree' has been renamed to 'grid' -- please update your DAG",
694 RemovedInAirflow3Warning,
695 stacklevel=2,
696 )
697 self._default_view = "grid"
698 else:
699 raise AirflowException(
700 f"Invalid values of dag.default_view: only support "
701 f"{DEFAULT_VIEW_PRESETS}, but get {default_view}"
702 )
703 if orientation in ORIENTATION_PRESETS:
704 self.orientation = orientation
705 else:
706 raise AirflowException(
707 f"Invalid values of dag.orientation: only support "
708 f"{ORIENTATION_PRESETS}, but get {orientation}"
709 )
710 self.catchup: bool = catchup
712 self.partial: bool = False
713 self.on_success_callback = on_success_callback
714 self.on_failure_callback = on_failure_callback
716 # Keeps track of any extra edge metadata (sparse; will not contain all
717 # edges, so do not iterate over it for that). Outer key is upstream
718 # task ID, inner key is downstream task ID.
719 self.edge_info: dict[str, dict[str, EdgeInfoType]] = {}
721 # To keep it in parity with Serialized DAGs
722 # and identify if DAG has on_*_callback without actually storing them in Serialized JSON
723 self.has_on_success_callback: bool = self.on_success_callback is not None
724 self.has_on_failure_callback: bool = self.on_failure_callback is not None
726 self._access_control = DAG._upgrade_outdated_dag_access_control(access_control)
727 self.is_paused_upon_creation = is_paused_upon_creation
728 self.auto_register = auto_register
730 self.fail_stop: bool = fail_stop
732 self.jinja_environment_kwargs = jinja_environment_kwargs
733 self.render_template_as_native_obj = render_template_as_native_obj
735 self.doc_md = self.get_doc_md(doc_md)
737 self.tags = tags or []
738 self._task_group = TaskGroup.create_root(self)
739 self.validate_schedule_and_params()
740 wrong_links = dict(self.iter_invalid_owner_links())
741 if wrong_links:
742 raise AirflowException(
743 "Wrong link format was used for the owner. Use a valid link \n"
744 f"Bad formatted links are: {wrong_links}"
745 )
747 # this will only be set at serialization time
748 # it's only use is for determining the relative
749 # fileloc based only on the serialize dag
750 self._processor_dags_folder = None
752 def get_doc_md(self, doc_md: str | None) -> str | None:
753 if doc_md is None:
754 return doc_md
756 env = self.get_template_env(force_sandboxed=True)
758 if not doc_md.endswith(".md"):
759 template = jinja2.Template(doc_md)
760 else:
761 try:
762 template = env.get_template(doc_md)
763 except jinja2.exceptions.TemplateNotFound:
764 return f"""
765 # Templating Error!
766 Not able to find the template file: `{doc_md}`.
767 """
769 return template.render()
771 def _check_schedule_interval_matches_timetable(self) -> bool:
772 """Check ``schedule_interval`` and ``timetable`` match.
774 This is done as a part of the DAG validation done before it's bagged, to
775 guard against the DAG's ``timetable`` (or ``schedule_interval``) from
776 being changed after it's created, e.g.
778 .. code-block:: python
780 dag1 = DAG("d1", timetable=MyTimetable())
781 dag1.schedule_interval = "@once"
783 dag2 = DAG("d2", schedule="@once")
784 dag2.timetable = MyTimetable()
786 Validation is done by creating a timetable and check its summary matches
787 ``schedule_interval``. The logic is not bullet-proof, especially if a
788 custom timetable does not provide a useful ``summary``. But this is the
789 best we can do.
790 """
791 if self.schedule_interval == self.timetable.summary:
792 return True
793 try:
794 timetable = create_timetable(self.schedule_interval, self.timezone)
795 except ValueError:
796 return False
797 return timetable.summary == self.timetable.summary
799 def validate(self):
800 """Validate the DAG has a coherent setup.
802 This is called by the DAG bag before bagging the DAG.
803 """
804 if not self._check_schedule_interval_matches_timetable():
805 raise AirflowDagInconsistent(
806 f"inconsistent schedule: timetable {self.timetable.summary!r} "
807 f"does not match schedule_interval {self.schedule_interval!r}",
808 )
809 self.validate_schedule_and_params()
810 self.timetable.validate()
811 self.validate_setup_teardown()
813 def validate_setup_teardown(self):
814 """
815 Validate that setup and teardown tasks are configured properly.
817 :meta private:
818 """
819 for task in self.tasks:
820 if task.is_setup:
821 for down_task in task.downstream_list:
822 if not down_task.is_teardown and down_task.trigger_rule != TriggerRule.ALL_SUCCESS:
823 # todo: we can relax this to allow out-of-scope tasks to have other trigger rules
824 # this is required to ensure consistent behavior of dag
825 # when clearing an indirect setup
826 raise ValueError("Setup tasks must be followed with trigger rule ALL_SUCCESS.")
827 FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)
829 def __repr__(self):
830 return f"<DAG: {self.dag_id}>"
832 def __eq__(self, other):
833 if type(self) == type(other):
834 # Use getattr() instead of __dict__ as __dict__ doesn't return
835 # correct values for properties.
836 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
837 return False
839 def __ne__(self, other):
840 return not self == other
842 def __lt__(self, other):
843 return self.dag_id < other.dag_id
845 def __hash__(self):
846 hash_components = [type(self)]
847 for c in self._comps:
848 # task_ids returns a list and lists can't be hashed
849 if c == "task_ids":
850 val = tuple(self.task_dict)
851 else:
852 val = getattr(self, c, None)
853 try:
854 hash(val)
855 hash_components.append(val)
856 except TypeError:
857 hash_components.append(repr(val))
858 return hash(tuple(hash_components))
860 # Context Manager -----------------------------------------------
861 def __enter__(self):
862 DagContext.push_context_managed_dag(self)
863 return self
865 def __exit__(self, _type, _value, _tb):
866 DagContext.pop_context_managed_dag()
868 # /Context Manager ----------------------------------------------
870 @staticmethod
871 def _upgrade_outdated_dag_access_control(access_control=None):
872 """
873 Look for outdated dag level actions in DAG access_controls and replace them with updated actions.
875 For example, in DAG access_control {'role1': {'can_dag_read'}} 'can_dag_read'
876 will be replaced with 'can_read', in {'role2': {'can_dag_read', 'can_dag_edit'}}
877 'can_dag_edit' will be replaced with 'can_edit', etc.
878 """
879 if access_control is None:
880 return None
881 new_perm_mapping = {
882 permissions.DEPRECATED_ACTION_CAN_DAG_READ: permissions.ACTION_CAN_READ,
883 permissions.DEPRECATED_ACTION_CAN_DAG_EDIT: permissions.ACTION_CAN_EDIT,
884 }
885 updated_access_control = {}
886 for role, perms in access_control.items():
887 updated_access_control[role] = {new_perm_mapping.get(perm, perm) for perm in perms}
889 if access_control != updated_access_control:
890 warnings.warn(
891 "The 'can_dag_read' and 'can_dag_edit' permissions are deprecated. "
892 "Please use 'can_read' and 'can_edit', respectively.",
893 RemovedInAirflow3Warning,
894 stacklevel=3,
895 )
897 return updated_access_control
899 def date_range(
900 self,
901 start_date: pendulum.DateTime,
902 num: int | None = None,
903 end_date: datetime | None = None,
904 ) -> list[datetime]:
905 message = "`DAG.date_range()` is deprecated."
906 if num is not None:
907 warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2)
908 with warnings.catch_warnings():
909 warnings.simplefilter("ignore", RemovedInAirflow3Warning)
910 return utils_date_range(
911 start_date=start_date, num=num, delta=self.normalized_schedule_interval
912 )
913 message += " Please use `DAG.iter_dagrun_infos_between(..., align=False)` instead."
914 warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2)
915 if end_date is None:
916 coerced_end_date = timezone.utcnow()
917 else:
918 coerced_end_date = end_date
919 it = self.iter_dagrun_infos_between(start_date, pendulum.instance(coerced_end_date), align=False)
920 return [info.logical_date for info in it]
922 def is_fixed_time_schedule(self):
923 """Figures out if the schedule has a fixed time (e.g. 3 AM every day).
925 Detection is done by "peeking" the next two cron trigger time; if the
926 two times have the same minute and hour value, the schedule is fixed,
927 and we *don't* need to perform the DST fix.
929 This assumes DST happens on whole minute changes (e.g. 12:59 -> 12:00).
931 Do not try to understand what this actually means. It is old logic that
932 should not be used anywhere.
933 """
934 warnings.warn(
935 "`DAG.is_fixed_time_schedule()` is deprecated.",
936 category=RemovedInAirflow3Warning,
937 stacklevel=2,
938 )
940 from airflow.timetables._cron import CronMixin
942 if not isinstance(self.timetable, CronMixin):
943 return True
945 from croniter import croniter
947 cron = croniter(self.timetable._expression)
948 next_a = cron.get_next(datetime)
949 next_b = cron.get_next(datetime)
950 return next_b.minute == next_a.minute and next_b.hour == next_a.hour
952 def following_schedule(self, dttm):
953 """
954 Calculate the following schedule for this dag in UTC.
956 :param dttm: utc datetime
957 :return: utc datetime
958 """
959 warnings.warn(
960 "`DAG.following_schedule()` is deprecated. Use `DAG.next_dagrun_info(restricted=False)` instead.",
961 category=RemovedInAirflow3Warning,
962 stacklevel=2,
963 )
964 data_interval = self.infer_automated_data_interval(timezone.coerce_datetime(dttm))
965 next_info = self.next_dagrun_info(data_interval, restricted=False)
966 if next_info is None:
967 return None
968 return next_info.data_interval.start
970 def previous_schedule(self, dttm):
971 from airflow.timetables.interval import _DataIntervalTimetable
973 warnings.warn(
974 "`DAG.previous_schedule()` is deprecated.",
975 category=RemovedInAirflow3Warning,
976 stacklevel=2,
977 )
978 if not isinstance(self.timetable, _DataIntervalTimetable):
979 return None
980 return self.timetable._get_prev(timezone.coerce_datetime(dttm))
982 def get_next_data_interval(self, dag_model: DagModel) -> DataInterval | None:
983 """Get the data interval of the next scheduled run.
985 For compatibility, this method infers the data interval from the DAG's
986 schedule if the run does not have an explicit one set, which is possible
987 for runs created prior to AIP-39.
989 This function is private to Airflow core and should not be depended on as a
990 part of the Python API.
992 :meta private:
993 """
994 if self.dag_id != dag_model.dag_id:
995 raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {dag_model.dag_id}")
996 if dag_model.next_dagrun is None: # Next run not scheduled.
997 return None
998 data_interval = dag_model.next_dagrun_data_interval
999 if data_interval is not None:
1000 return data_interval
1002 # Compatibility: A run was scheduled without an explicit data interval.
1003 # This means the run was scheduled before AIP-39 implementation. Try to
1004 # infer from the logical date.
1005 return self.infer_automated_data_interval(dag_model.next_dagrun)
1007 def get_run_data_interval(self, run: DagRun | DagRunPydantic) -> DataInterval:
1008 """Get the data interval of this run.
1010 For compatibility, this method infers the data interval from the DAG's
1011 schedule if the run does not have an explicit one set, which is possible for
1012 runs created prior to AIP-39.
1014 This function is private to Airflow core and should not be depended on as a
1015 part of the Python API.
1017 :meta private:
1018 """
1019 if run.dag_id is not None and run.dag_id != self.dag_id:
1020 raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {run.dag_id}")
1021 data_interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end")
1022 if data_interval is not None:
1023 return data_interval
1024 # Compatibility: runs created before AIP-39 implementation don't have an
1025 # explicit data interval. Try to infer from the logical date.
1026 return self.infer_automated_data_interval(run.execution_date)
1028 def infer_automated_data_interval(self, logical_date: datetime) -> DataInterval:
1029 """Infer a data interval for a run against this DAG.
1031 This method is used to bridge runs created prior to AIP-39
1032 implementation, which do not have an explicit data interval. Therefore,
1033 this method only considers ``schedule_interval`` values valid prior to
1034 Airflow 2.2.
1036 DO NOT call this method if there is a known data interval.
1038 :meta private:
1039 """
1040 timetable_type = type(self.timetable)
1041 if issubclass(timetable_type, (NullTimetable, OnceTimetable, DatasetTriggeredTimetable)):
1042 return DataInterval.exact(timezone.coerce_datetime(logical_date))
1043 start = timezone.coerce_datetime(logical_date)
1044 if issubclass(timetable_type, CronDataIntervalTimetable):
1045 end = cast(CronDataIntervalTimetable, self.timetable)._get_next(start)
1046 elif issubclass(timetable_type, DeltaDataIntervalTimetable):
1047 end = cast(DeltaDataIntervalTimetable, self.timetable)._get_next(start)
1048 # Contributors: When the exception below is raised, you might want to
1049 # add an 'elif' block here to handle custom timetables. Stop! The bug
1050 # you're looking for is instead at when the DAG run (represented by
1051 # logical_date) was created. See GH-31969 for an example:
1052 # * Wrong fix: GH-32074 (modifies this function).
1053 # * Correct fix: GH-32118 (modifies the DAG run creation code).
1054 else:
1055 raise ValueError(f"Not a valid timetable: {self.timetable!r}")
1056 return DataInterval(start, end)
1058 def next_dagrun_info(
1059 self,
1060 last_automated_dagrun: None | datetime | DataInterval,
1061 *,
1062 restricted: bool = True,
1063 ) -> DagRunInfo | None:
1064 """Get information about the next DagRun of this dag after ``date_last_automated_dagrun``.
1066 This calculates what time interval the next DagRun should operate on
1067 (its execution date) and when it can be scheduled, according to the
1068 dag's timetable, start_date, end_date, etc. This doesn't check max
1069 active run or any other "max_active_tasks" type limits, but only
1070 performs calculations based on the various date and interval fields of
1071 this dag and its tasks.
1073 :param last_automated_dagrun: The ``max(execution_date)`` of
1074 existing "automated" DagRuns for this dag (scheduled or backfill,
1075 but not manual).
1076 :param restricted: If set to *False* (default is *True*), ignore
1077 ``start_date``, ``end_date``, and ``catchup`` specified on the DAG
1078 or tasks.
1079 :return: DagRunInfo of the next dagrun, or None if a dagrun is not
1080 going to be scheduled.
1081 """
1082 # Never schedule a subdag. It will be scheduled by its parent dag.
1083 if self.is_subdag:
1084 return None
1086 data_interval = None
1087 if isinstance(last_automated_dagrun, datetime):
1088 warnings.warn(
1089 "Passing a datetime to DAG.next_dagrun_info is deprecated. Use a DataInterval instead.",
1090 RemovedInAirflow3Warning,
1091 stacklevel=2,
1092 )
1093 data_interval = self.infer_automated_data_interval(
1094 timezone.coerce_datetime(last_automated_dagrun)
1095 )
1096 else:
1097 data_interval = last_automated_dagrun
1098 if restricted:
1099 restriction = self._time_restriction
1100 else:
1101 restriction = TimeRestriction(earliest=None, latest=None, catchup=True)
1102 try:
1103 info = self.timetable.next_dagrun_info(
1104 last_automated_data_interval=data_interval,
1105 restriction=restriction,
1106 )
1107 except Exception:
1108 self.log.exception(
1109 "Failed to fetch run info after data interval %s for DAG %r",
1110 data_interval,
1111 self.dag_id,
1112 )
1113 info = None
1114 return info
1116 def next_dagrun_after_date(self, date_last_automated_dagrun: pendulum.DateTime | None):
1117 warnings.warn(
1118 "`DAG.next_dagrun_after_date()` is deprecated. Please use `DAG.next_dagrun_info()` instead.",
1119 category=RemovedInAirflow3Warning,
1120 stacklevel=2,
1121 )
1122 if date_last_automated_dagrun is None:
1123 data_interval = None
1124 else:
1125 data_interval = self.infer_automated_data_interval(date_last_automated_dagrun)
1126 info = self.next_dagrun_info(data_interval)
1127 if info is None:
1128 return None
1129 return info.run_after
1131 @functools.cached_property
1132 def _time_restriction(self) -> TimeRestriction:
1133 start_dates = [t.start_date for t in self.tasks if t.start_date]
1134 if self.start_date is not None:
1135 start_dates.append(self.start_date)
1136 earliest = None
1137 if start_dates:
1138 earliest = timezone.coerce_datetime(min(start_dates))
1139 latest = self.end_date
1140 end_dates = [t.end_date for t in self.tasks if t.end_date]
1141 if len(end_dates) == len(self.tasks): # not exists null end_date
1142 if self.end_date is not None:
1143 end_dates.append(self.end_date)
1144 if end_dates:
1145 latest = timezone.coerce_datetime(max(end_dates))
1146 return TimeRestriction(earliest, latest, self.catchup)
1148 def iter_dagrun_infos_between(
1149 self,
1150 earliest: pendulum.DateTime | None,
1151 latest: pendulum.DateTime,
1152 *,
1153 align: bool = True,
1154 ) -> Iterable[DagRunInfo]:
1155 """Yield DagRunInfo using this DAG's timetable between given interval.
1157 DagRunInfo instances yielded if their ``logical_date`` is not earlier
1158 than ``earliest``, nor later than ``latest``. The instances are ordered
1159 by their ``logical_date`` from earliest to latest.
1161 If ``align`` is ``False``, the first run will happen immediately on
1162 ``earliest``, even if it does not fall on the logical timetable schedule.
1163 The default is ``True``, but subdags will ignore this value and always
1164 behave as if this is set to ``False`` for backward compatibility.
1166 Example: A DAG is scheduled to run every midnight (``0 0 * * *``). If
1167 ``earliest`` is ``2021-06-03 23:00:00``, the first DagRunInfo would be
1168 ``2021-06-03 23:00:00`` if ``align=False``, and ``2021-06-04 00:00:00``
1169 if ``align=True``.
1170 """
1171 if earliest is None:
1172 earliest = self._time_restriction.earliest
1173 if earliest is None:
1174 raise ValueError("earliest was None and we had no value in time_restriction to fallback on")
1175 earliest = timezone.coerce_datetime(earliest)
1176 latest = timezone.coerce_datetime(latest)
1178 restriction = TimeRestriction(earliest, latest, catchup=True)
1180 # HACK: Sub-DAGs are currently scheduled differently. For example, say
1181 # the schedule is @daily and start is 2021-06-03 22:16:00, a top-level
1182 # DAG should be first scheduled to run on midnight 2021-06-04, but a
1183 # sub-DAG should be first scheduled to run RIGHT NOW. We can change
1184 # this, but since sub-DAGs are going away in 3.0 anyway, let's keep
1185 # compatibility for now and remove this entirely later.
1186 if self.is_subdag:
1187 align = False
1189 try:
1190 info = self.timetable.next_dagrun_info(
1191 last_automated_data_interval=None,
1192 restriction=restriction,
1193 )
1194 except Exception:
1195 self.log.exception(
1196 "Failed to fetch run info after data interval %s for DAG %r",
1197 None,
1198 self.dag_id,
1199 )
1200 info = None
1202 if info is None:
1203 # No runs to be scheduled between the user-supplied timeframe. But
1204 # if align=False, "invent" a data interval for the timeframe itself.
1205 if not align:
1206 yield DagRunInfo.interval(earliest, latest)
1207 return
1209 # If align=False and earliest does not fall on the timetable's logical
1210 # schedule, "invent" a data interval for it.
1211 if not align and info.logical_date != earliest:
1212 yield DagRunInfo.interval(earliest, info.data_interval.start)
1214 # Generate naturally according to schedule.
1215 while info is not None:
1216 yield info
1217 try:
1218 info = self.timetable.next_dagrun_info(
1219 last_automated_data_interval=info.data_interval,
1220 restriction=restriction,
1221 )
1222 except Exception:
1223 self.log.exception(
1224 "Failed to fetch run info after data interval %s for DAG %r",
1225 info.data_interval if info else "<NONE>",
1226 self.dag_id,
1227 )
1228 break
1230 def get_run_dates(self, start_date, end_date=None) -> list:
1231 """
1232 Return a list of dates between the interval received as parameter using this dag's schedule interval.
1234 Returned dates can be used for execution dates.
1236 :param start_date: The start date of the interval.
1237 :param end_date: The end date of the interval. Defaults to ``timezone.utcnow()``.
1238 :return: A list of dates within the interval following the dag's schedule.
1239 """
1240 warnings.warn(
1241 "`DAG.get_run_dates()` is deprecated. Please use `DAG.iter_dagrun_infos_between()` instead.",
1242 category=RemovedInAirflow3Warning,
1243 stacklevel=2,
1244 )
1245 earliest = timezone.coerce_datetime(start_date)
1246 if end_date is None:
1247 latest = pendulum.now(timezone.utc)
1248 else:
1249 latest = timezone.coerce_datetime(end_date)
1250 return [info.logical_date for info in self.iter_dagrun_infos_between(earliest, latest)]
1252 def normalize_schedule(self, dttm):
1253 warnings.warn(
1254 "`DAG.normalize_schedule()` is deprecated.",
1255 category=RemovedInAirflow3Warning,
1256 stacklevel=2,
1257 )
1258 with warnings.catch_warnings():
1259 warnings.simplefilter("ignore", RemovedInAirflow3Warning)
1260 following = self.following_schedule(dttm)
1261 if not following: # in case of @once
1262 return dttm
1263 with warnings.catch_warnings():
1264 warnings.simplefilter("ignore", RemovedInAirflow3Warning)
1265 previous_of_following = self.previous_schedule(following)
1266 if previous_of_following != dttm:
1267 return following
1268 return dttm
1270 @provide_session
1271 def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False):
1272 return get_last_dagrun(
1273 self.dag_id, session=session, include_externally_triggered=include_externally_triggered
1274 )
1276 @provide_session
1277 def has_dag_runs(self, session=NEW_SESSION, include_externally_triggered=True) -> bool:
1278 return (
1279 get_last_dagrun(
1280 self.dag_id, session=session, include_externally_triggered=include_externally_triggered
1281 )
1282 is not None
1283 )
1285 @property
1286 def dag_id(self) -> str:
1287 return self._dag_id
1289 @dag_id.setter
1290 def dag_id(self, value: str) -> None:
1291 self._dag_id = value
1293 @property
1294 def is_subdag(self) -> bool:
1295 return self.parent_dag is not None
1297 @property
1298 def full_filepath(self) -> str:
1299 """Full file path to the DAG.
1301 :meta private:
1302 """
1303 warnings.warn(
1304 "DAG.full_filepath is deprecated in favour of fileloc",
1305 RemovedInAirflow3Warning,
1306 stacklevel=2,
1307 )
1308 return self.fileloc
1310 @full_filepath.setter
1311 def full_filepath(self, value) -> None:
1312 warnings.warn(
1313 "DAG.full_filepath is deprecated in favour of fileloc",
1314 RemovedInAirflow3Warning,
1315 stacklevel=2,
1316 )
1317 self.fileloc = value
1319 @property
1320 def concurrency(self) -> int:
1321 # TODO: Remove in Airflow 3.0
1322 warnings.warn(
1323 "The 'DAG.concurrency' attribute is deprecated. Please use 'DAG.max_active_tasks'.",
1324 RemovedInAirflow3Warning,
1325 stacklevel=2,
1326 )
1327 return self._max_active_tasks
1329 @concurrency.setter
1330 def concurrency(self, value: int):
1331 self._max_active_tasks = value
1333 @property
1334 def max_active_tasks(self) -> int:
1335 return self._max_active_tasks
1337 @max_active_tasks.setter
1338 def max_active_tasks(self, value: int):
1339 self._max_active_tasks = value
1341 @property
1342 def access_control(self):
1343 return self._access_control
1345 @access_control.setter
1346 def access_control(self, value):
1347 self._access_control = DAG._upgrade_outdated_dag_access_control(value)
1349 @property
1350 def dag_display_name(self) -> str:
1351 return self._dag_display_property_value or self._dag_id
1353 @property
1354 def description(self) -> str | None:
1355 return self._description
1357 @property
1358 def default_view(self) -> str:
1359 return self._default_view
1361 @property
1362 def pickle_id(self) -> int | None:
1363 return self._pickle_id
1365 @pickle_id.setter
1366 def pickle_id(self, value: int) -> None:
1367 self._pickle_id = value
1369 def param(self, name: str, default: Any = NOTSET) -> DagParam:
1370 """
1371 Return a DagParam object for current dag.
1373 :param name: dag parameter name.
1374 :param default: fallback value for dag parameter.
1375 :return: DagParam instance for specified name and current dag.
1376 """
1377 return DagParam(current_dag=self, name=name, default=default)
1379 @property
1380 def tasks(self) -> list[Operator]:
1381 return list(self.task_dict.values())
1383 @tasks.setter
1384 def tasks(self, val):
1385 raise AttributeError("DAG.tasks can not be modified. Use dag.add_task() instead.")
1387 @property
1388 def task_ids(self) -> list[str]:
1389 return list(self.task_dict)
1391 @property
1392 def teardowns(self) -> list[Operator]:
1393 return [task for task in self.tasks if getattr(task, "is_teardown", None)]
1395 @property
1396 def tasks_upstream_of_teardowns(self) -> list[Operator]:
1397 upstream_tasks = [t.upstream_list for t in self.teardowns]
1398 return [val for sublist in upstream_tasks for val in sublist if not getattr(val, "is_teardown", None)]
1400 @property
1401 def task_group(self) -> TaskGroup:
1402 return self._task_group
1404 @property
1405 def filepath(self) -> str:
1406 """Relative file path to the DAG.
1408 :meta private:
1409 """
1410 warnings.warn(
1411 "filepath is deprecated, use relative_fileloc instead",
1412 RemovedInAirflow3Warning,
1413 stacklevel=2,
1414 )
1415 return str(self.relative_fileloc)
1417 @property
1418 def relative_fileloc(self) -> pathlib.Path:
1419 """File location of the importable dag 'file' relative to the configured DAGs folder."""
1420 path = pathlib.Path(self.fileloc)
1421 try:
1422 rel_path = path.relative_to(self._processor_dags_folder or settings.DAGS_FOLDER)
1423 if rel_path == pathlib.Path("."):
1424 return path
1425 else:
1426 return rel_path
1427 except ValueError:
1428 # Not relative to DAGS_FOLDER.
1429 return path
1431 @property
1432 def folder(self) -> str:
1433 """Folder location of where the DAG object is instantiated."""
1434 return os.path.dirname(self.fileloc)
1436 @property
1437 def owner(self) -> str:
1438 """
1439 Return list of all owners found in DAG tasks.
1441 :return: Comma separated list of owners in DAG tasks
1442 """
1443 return ", ".join({t.owner for t in self.tasks})
1445 @property
1446 def allow_future_exec_dates(self) -> bool:
1447 return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_be_scheduled
1449 @provide_session
1450 def get_concurrency_reached(self, session=NEW_SESSION) -> bool:
1451 """Return a boolean indicating whether the max_active_tasks limit for this DAG has been reached."""
1452 TI = TaskInstance
1453 total_tasks = session.scalar(
1454 select(func.count(TI.task_id)).where(
1455 TI.dag_id == self.dag_id,
1456 TI.state == TaskInstanceState.RUNNING,
1457 )
1458 )
1459 return total_tasks >= self.max_active_tasks
1461 @property
1462 def concurrency_reached(self):
1463 """Use `airflow.models.DAG.get_concurrency_reached`, this attribute is deprecated."""
1464 warnings.warn(
1465 "This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method.",
1466 RemovedInAirflow3Warning,
1467 stacklevel=2,
1468 )
1469 return self.get_concurrency_reached()
1471 @provide_session
1472 def get_is_active(self, session=NEW_SESSION) -> None:
1473 """Return a boolean indicating whether this DAG is active."""
1474 return session.scalar(select(DagModel.is_active).where(DagModel.dag_id == self.dag_id))
1476 @provide_session
1477 def get_is_paused(self, session=NEW_SESSION) -> None:
1478 """Return a boolean indicating whether this DAG is paused."""
1479 return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == self.dag_id))
1481 @property
1482 def is_paused(self):
1483 """Use `airflow.models.DAG.get_is_paused`, this attribute is deprecated."""
1484 warnings.warn(
1485 "This attribute is deprecated. Please use `airflow.models.DAG.get_is_paused` method.",
1486 RemovedInAirflow3Warning,
1487 stacklevel=2,
1488 )
1489 return self.get_is_paused()
1491 @property
1492 def normalized_schedule_interval(self) -> ScheduleInterval:
1493 warnings.warn(
1494 "DAG.normalized_schedule_interval() is deprecated.",
1495 category=RemovedInAirflow3Warning,
1496 stacklevel=2,
1497 )
1498 if isinstance(self.schedule_interval, str) and self.schedule_interval in cron_presets:
1499 _schedule_interval: ScheduleInterval = cron_presets.get(self.schedule_interval)
1500 elif self.schedule_interval == "@once":
1501 _schedule_interval = None
1502 else:
1503 _schedule_interval = self.schedule_interval
1504 return _schedule_interval
1506 @staticmethod
1507 @internal_api_call
1508 @provide_session
1509 def fetch_callback(
1510 dag: DAG,
1511 dag_run_id: str,
1512 success: bool = True,
1513 reason: str | None = None,
1514 *,
1515 session: Session = NEW_SESSION,
1516 ) -> tuple[list[TaskStateChangeCallback], Context] | None:
1517 """
1518 Fetch the appropriate callbacks depending on the value of success.
1520 This method gets the context of a single TaskInstance part of this DagRun and returns it along
1521 the list of callbacks.
1523 :param dag: DAG object
1524 :param dag_run_id: The DAG run ID
1525 :param success: Flag to specify if failure or success callback should be called
1526 :param reason: Completion reason
1527 :param session: Database session
1528 """
1529 callbacks = dag.on_success_callback if success else dag.on_failure_callback
1530 if callbacks:
1531 dagrun = DAG.fetch_dagrun(dag_id=dag.dag_id, run_id=dag_run_id, session=session)
1532 callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
1533 tis = dagrun.get_task_instances(session=session)
1534 # tis from a dagrun may not be a part of dag.partial_subset,
1535 # since dag.partial_subset is a subset of the dag.
1536 # This ensures that we will only use the accessible TI
1537 # context for the callback.
1538 if dag.partial:
1539 tis = [ti for ti in tis if not ti.state == State.NONE]
1540 # filter out removed tasks
1541 tis = [ti for ti in tis if ti.state != TaskInstanceState.REMOVED]
1542 ti = tis[-1] # get first TaskInstance of DagRun
1543 ti.task = dag.get_task(ti.task_id)
1544 context = ti.get_template_context(session=session)
1545 context["reason"] = reason
1546 return callbacks, context
1547 return None
1549 @provide_session
1550 def handle_callback(self, dagrun: DagRun, success=True, reason=None, session=NEW_SESSION):
1551 """
1552 Triggers on_failure_callback or on_success_callback as appropriate.
1554 This method gets the context of a single TaskInstance part of this DagRun
1555 and passes that to the callable along with a 'reason', primarily to
1556 differentiate DagRun failures.
1558 .. note: The logs end up in
1559 ``$AIRFLOW_HOME/logs/scheduler/latest/PROJECT/DAG_FILE.py.log``
1561 :param dagrun: DagRun object
1562 :param success: Flag to specify if failure or success callback should be called
1563 :param reason: Completion reason
1564 :param session: Database session
1565 """
1566 callbacks, context = DAG.fetch_callback(
1567 dag=self, dag_run_id=dagrun.run_id, success=success, reason=reason, session=session
1568 ) or (None, None)
1570 DAG.execute_callback(callbacks, context, self.dag_id)
1572 @classmethod
1573 def execute_callback(cls, callbacks: list[Callable] | None, context: Context | None, dag_id: str):
1574 """
1575 Triggers the callbacks with the given context.
1577 :param callbacks: List of callbacks to call
1578 :param context: Context to pass to all callbacks
1579 :param dag_id: The dag_id of the DAG to find.
1580 """
1581 if callbacks and context:
1582 for callback in callbacks:
1583 cls.logger().info("Executing dag callback function: %s", callback)
1584 try:
1585 callback(context)
1586 except Exception:
1587 cls.logger().exception("failed to invoke dag state update callback")
1588 Stats.incr("dag.callback_exceptions", tags={"dag_id": dag_id})
1590 def get_active_runs(self):
1591 """
1592 Return a list of dag run execution dates currently running.
1594 :return: List of execution dates
1595 """
1596 runs = DagRun.find(dag_id=self.dag_id, state=DagRunState.RUNNING)
1598 active_dates = []
1599 for run in runs:
1600 active_dates.append(run.execution_date)
1602 return active_dates
1604 @provide_session
1605 def get_num_active_runs(self, external_trigger=None, only_running=True, session=NEW_SESSION):
1606 """
1607 Return the number of active "running" dag runs.
1609 :param external_trigger: True for externally triggered active dag runs
1610 :param session:
1611 :return: number greater than 0 for active dag runs
1612 """
1613 query = select(func.count()).where(DagRun.dag_id == self.dag_id)
1614 if only_running:
1615 query = query.where(DagRun.state == DagRunState.RUNNING)
1616 else:
1617 query = query.where(DagRun.state.in_({DagRunState.RUNNING, DagRunState.QUEUED}))
1619 if external_trigger is not None:
1620 query = query.where(
1621 DagRun.external_trigger == (expression.true() if external_trigger else expression.false())
1622 )
1624 return session.scalar(query)
1626 @staticmethod
1627 @internal_api_call
1628 @provide_session
1629 def fetch_dagrun(
1630 dag_id: str,
1631 execution_date: datetime | None = None,
1632 run_id: str | None = None,
1633 session: Session = NEW_SESSION,
1634 ) -> DagRun | DagRunPydantic:
1635 """
1636 Return the dag run for a given execution date or run_id if it exists, otherwise none.
1638 :param dag_id: The dag_id of the DAG to find.
1639 :param execution_date: The execution date of the DagRun to find.
1640 :param run_id: The run_id of the DagRun to find.
1641 :param session:
1642 :return: The DagRun if found, otherwise None.
1643 """
1644 if not (execution_date or run_id):
1645 raise TypeError("You must provide either the execution_date or the run_id")
1646 query = select(DagRun)
1647 if execution_date:
1648 query = query.where(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
1649 if run_id:
1650 query = query.where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
1651 return session.scalar(query)
1653 @provide_session
1654 def get_dagrun(
1655 self,
1656 execution_date: datetime | None = None,
1657 run_id: str | None = None,
1658 session: Session = NEW_SESSION,
1659 ) -> DagRun | DagRunPydantic:
1660 return DAG.fetch_dagrun(
1661 dag_id=self.dag_id, execution_date=execution_date, run_id=run_id, session=session
1662 )
1664 @provide_session
1665 def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION):
1666 """
1667 Return the list of dag runs between start_date (inclusive) and end_date (inclusive).
1669 :param start_date: The starting execution date of the DagRun to find.
1670 :param end_date: The ending execution date of the DagRun to find.
1671 :param session:
1672 :return: The list of DagRuns found.
1673 """
1674 dagruns = session.scalars(
1675 select(DagRun).where(
1676 DagRun.dag_id == self.dag_id,
1677 DagRun.execution_date >= start_date,
1678 DagRun.execution_date <= end_date,
1679 )
1680 ).all()
1682 return dagruns
1684 @provide_session
1685 def get_latest_execution_date(self, session: Session = NEW_SESSION) -> pendulum.DateTime | None:
1686 """Return the latest date for which at least one dag run exists."""
1687 return session.scalar(select(func.max(DagRun.execution_date)).where(DagRun.dag_id == self.dag_id))
1689 @property
1690 def latest_execution_date(self):
1691 """Use `airflow.models.DAG.get_latest_execution_date`, this attribute is deprecated."""
1692 warnings.warn(
1693 "This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.",
1694 RemovedInAirflow3Warning,
1695 stacklevel=2,
1696 )
1697 return self.get_latest_execution_date()
1699 @property
1700 def subdags(self):
1701 """Return a list of the subdag objects associated to this DAG."""
1702 # Check SubDag for class but don't check class directly
1703 from airflow.operators.subdag import SubDagOperator
1705 subdag_lst = []
1706 for task in self.tasks:
1707 if (
1708 isinstance(task, SubDagOperator)
1709 or
1710 # TODO remove in Airflow 2.0
1711 type(task).__name__ == "SubDagOperator"
1712 or task.task_type == "SubDagOperator"
1713 ):
1714 subdag_lst.append(task.subdag)
1715 subdag_lst += task.subdag.subdags
1716 return subdag_lst
1718 def resolve_template_files(self):
1719 for t in self.tasks:
1720 t.resolve_template_files()
1722 def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment:
1723 """Build a Jinja2 environment."""
1724 # Collect directories to search for template files
1725 searchpath = [self.folder]
1726 if self.template_searchpath:
1727 searchpath += self.template_searchpath
1729 # Default values (for backward compatibility)
1730 jinja_env_options = {
1731 "loader": jinja2.FileSystemLoader(searchpath),
1732 "undefined": self.template_undefined,
1733 "extensions": ["jinja2.ext.do"],
1734 "cache_size": 0,
1735 }
1736 if self.jinja_environment_kwargs:
1737 jinja_env_options.update(self.jinja_environment_kwargs)
1738 env: jinja2.Environment
1739 if self.render_template_as_native_obj and not force_sandboxed:
1740 env = airflow.templates.NativeEnvironment(**jinja_env_options)
1741 else:
1742 env = airflow.templates.SandboxedEnvironment(**jinja_env_options)
1744 # Add any user defined items. Safe to edit globals as long as no templates are rendered yet.
1745 # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals
1746 if self.user_defined_macros:
1747 env.globals.update(self.user_defined_macros)
1748 if self.user_defined_filters:
1749 env.filters.update(self.user_defined_filters)
1751 return env
1753 def set_dependency(self, upstream_task_id, downstream_task_id):
1754 """Set dependency between two tasks that already have been added to the DAG using add_task()."""
1755 self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id))
1757 @provide_session
1758 def get_task_instances_before(
1759 self,
1760 base_date: datetime,
1761 num: int,
1762 *,
1763 session: Session = NEW_SESSION,
1764 ) -> list[TaskInstance]:
1765 """Get ``num`` task instances before (including) ``base_date``.
1767 The returned list may contain exactly ``num`` task instances
1768 corresponding to any DagRunType. It can have less if there are
1769 less than ``num`` scheduled DAG runs before ``base_date``.
1770 """
1771 execution_dates: list[Any] = session.execute(
1772 select(DagRun.execution_date)
1773 .where(
1774 DagRun.dag_id == self.dag_id,
1775 DagRun.execution_date <= base_date,
1776 )
1777 .order_by(DagRun.execution_date.desc())
1778 .limit(num)
1779 ).all()
1781 if not execution_dates:
1782 return self.get_task_instances(start_date=base_date, end_date=base_date, session=session)
1784 min_date: datetime | None = execution_dates[-1]._mapping.get(
1785 "execution_date"
1786 ) # getting the last value from the list
1788 return self.get_task_instances(start_date=min_date, end_date=base_date, session=session)
1790 @provide_session
1791 def get_task_instances(
1792 self,
1793 start_date: datetime | None = None,
1794 end_date: datetime | None = None,
1795 state: list[TaskInstanceState] | None = None,
1796 session: Session = NEW_SESSION,
1797 ) -> list[TaskInstance]:
1798 if not start_date:
1799 start_date = (timezone.utcnow() - timedelta(30)).replace(
1800 hour=0, minute=0, second=0, microsecond=0
1801 )
1803 query = self._get_task_instances(
1804 task_ids=None,
1805 start_date=start_date,
1806 end_date=end_date,
1807 run_id=None,
1808 state=state or (),
1809 include_subdags=False,
1810 include_parentdag=False,
1811 include_dependent_dags=False,
1812 exclude_task_ids=(),
1813 session=session,
1814 )
1815 return session.scalars(cast(Select, query).order_by(DagRun.execution_date)).all()
1817 @overload
1818 def _get_task_instances(
1819 self,
1820 *,
1821 task_ids: Collection[str | tuple[str, int]] | None,
1822 start_date: datetime | None,
1823 end_date: datetime | None,
1824 run_id: str | None,
1825 state: TaskInstanceState | Sequence[TaskInstanceState],
1826 include_subdags: bool,
1827 include_parentdag: bool,
1828 include_dependent_dags: bool,
1829 exclude_task_ids: Collection[str | tuple[str, int]] | None,
1830 session: Session,
1831 dag_bag: DagBag | None = ...,
1832 ) -> Iterable[TaskInstance]: ... # pragma: no cover
1834 @overload
1835 def _get_task_instances(
1836 self,
1837 *,
1838 task_ids: Collection[str | tuple[str, int]] | None,
1839 as_pk_tuple: Literal[True],
1840 start_date: datetime | None,
1841 end_date: datetime | None,
1842 run_id: str | None,
1843 state: TaskInstanceState | Sequence[TaskInstanceState],
1844 include_subdags: bool,
1845 include_parentdag: bool,
1846 include_dependent_dags: bool,
1847 exclude_task_ids: Collection[str | tuple[str, int]] | None,
1848 session: Session,
1849 dag_bag: DagBag | None = ...,
1850 recursion_depth: int = ...,
1851 max_recursion_depth: int = ...,
1852 visited_external_tis: set[TaskInstanceKey] = ...,
1853 ) -> set[TaskInstanceKey]: ... # pragma: no cover
1855 def _get_task_instances(
1856 self,
1857 *,
1858 task_ids: Collection[str | tuple[str, int]] | None,
1859 as_pk_tuple: Literal[True, None] = None,
1860 start_date: datetime | None,
1861 end_date: datetime | None,
1862 run_id: str | None,
1863 state: TaskInstanceState | Sequence[TaskInstanceState],
1864 include_subdags: bool,
1865 include_parentdag: bool,
1866 include_dependent_dags: bool,
1867 exclude_task_ids: Collection[str | tuple[str, int]] | None,
1868 session: Session,
1869 dag_bag: DagBag | None = None,
1870 recursion_depth: int = 0,
1871 max_recursion_depth: int | None = None,
1872 visited_external_tis: set[TaskInstanceKey] | None = None,
1873 ) -> Iterable[TaskInstance] | set[TaskInstanceKey]:
1874 TI = TaskInstance
1876 # If we are looking at subdags/dependent dags we want to avoid UNION calls
1877 # in SQL (it doesn't play nice with fields that have no equality operator,
1878 # like JSON types), we instead build our result set separately.
1879 #
1880 # This will be empty if we are only looking at one dag, in which case
1881 # we can return the filtered TI query object directly.
1882 result: set[TaskInstanceKey] = set()
1884 # Do we want full objects, or just the primary columns?
1885 if as_pk_tuple:
1886 tis = select(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
1887 else:
1888 tis = select(TaskInstance)
1889 tis = tis.join(TaskInstance.dag_run)
1891 if include_subdags:
1892 # Crafting the right filter for dag_id and task_ids combo
1893 conditions = []
1894 for dag in [*self.subdags, self]:
1895 conditions.append(
1896 (TaskInstance.dag_id == dag.dag_id) & TaskInstance.task_id.in_(dag.task_ids)
1897 )
1898 tis = tis.where(or_(*conditions))
1899 elif self.partial:
1900 tis = tis.where(TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids))
1901 else:
1902 tis = tis.where(TaskInstance.dag_id == self.dag_id)
1903 if run_id:
1904 tis = tis.where(TaskInstance.run_id == run_id)
1905 if start_date:
1906 tis = tis.where(DagRun.execution_date >= start_date)
1907 if task_ids is not None:
1908 tis = tis.where(TaskInstance.ti_selector_condition(task_ids))
1910 # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC
1911 if end_date or not self.allow_future_exec_dates:
1912 end_date = end_date or timezone.utcnow()
1913 tis = tis.where(DagRun.execution_date <= end_date)
1915 if state:
1916 if isinstance(state, (str, TaskInstanceState)):
1917 tis = tis.where(TaskInstance.state == state)
1918 elif len(state) == 1:
1919 tis = tis.where(TaskInstance.state == state[0])
1920 else:
1921 # this is required to deal with NULL values
1922 if None in state:
1923 if all(x is None for x in state):
1924 tis = tis.where(TaskInstance.state.is_(None))
1925 else:
1926 not_none_state = [s for s in state if s]
1927 tis = tis.where(
1928 or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None))
1929 )
1930 else:
1931 tis = tis.where(TaskInstance.state.in_(state))
1933 # Next, get any of them from our parent DAG (if there is one)
1934 if include_parentdag and self.parent_dag is not None:
1935 if visited_external_tis is None:
1936 visited_external_tis = set()
1938 p_dag = self.parent_dag.partial_subset(
1939 task_ids_or_regex=r"^{}$".format(self.dag_id.split(".")[1]),
1940 include_upstream=False,
1941 include_downstream=True,
1942 )
1943 result.update(
1944 p_dag._get_task_instances(
1945 task_ids=task_ids,
1946 start_date=start_date,
1947 end_date=end_date,
1948 run_id=None,
1949 state=state,
1950 include_subdags=include_subdags,
1951 include_parentdag=False,
1952 include_dependent_dags=include_dependent_dags,
1953 as_pk_tuple=True,
1954 exclude_task_ids=exclude_task_ids,
1955 session=session,
1956 dag_bag=dag_bag,
1957 recursion_depth=recursion_depth,
1958 max_recursion_depth=max_recursion_depth,
1959 visited_external_tis=visited_external_tis,
1960 )
1961 )
1963 if include_dependent_dags:
1964 # Recursively find external tasks indicated by ExternalTaskMarker
1965 from airflow.sensors.external_task import ExternalTaskMarker
1967 query = tis
1968 if as_pk_tuple:
1969 all_tis = session.execute(query).all()
1970 condition = TI.filter_for_tis(TaskInstanceKey(*cols) for cols in all_tis)
1971 if condition is not None:
1972 query = select(TI).where(condition)
1974 if visited_external_tis is None:
1975 visited_external_tis = set()
1977 external_tasks = session.scalars(query.where(TI.operator == ExternalTaskMarker.__name__))
1979 for ti in external_tasks:
1980 ti_key = ti.key.primary
1981 if ti_key in visited_external_tis:
1982 continue
1984 visited_external_tis.add(ti_key)
1986 task: ExternalTaskMarker = cast(ExternalTaskMarker, copy.copy(self.get_task(ti.task_id)))
1987 ti.task = task
1989 if max_recursion_depth is None:
1990 # Maximum recursion depth allowed is the recursion_depth of the first
1991 # ExternalTaskMarker in the tasks to be visited.
1992 max_recursion_depth = task.recursion_depth
1994 if recursion_depth + 1 > max_recursion_depth:
1995 # Prevent cycles or accidents.
1996 raise AirflowException(
1997 f"Maximum recursion depth {max_recursion_depth} reached for "
1998 f"{ExternalTaskMarker.__name__} {ti.task_id}. "
1999 f"Attempted to clear too many tasks or there may be a cyclic dependency."
2000 )
2001 ti.render_templates()
2002 external_tis = session.scalars(
2003 select(TI)
2004 .join(TI.dag_run)
2005 .where(
2006 TI.dag_id == task.external_dag_id,
2007 TI.task_id == task.external_task_id,
2008 DagRun.execution_date == pendulum.parse(task.execution_date),
2009 )
2010 )
2012 for tii in external_tis:
2013 if not dag_bag:
2014 from airflow.models.dagbag import DagBag
2016 dag_bag = DagBag(read_dags_from_db=True)
2017 external_dag = dag_bag.get_dag(tii.dag_id, session=session)
2018 if not external_dag:
2019 raise AirflowException(f"Could not find dag {tii.dag_id}")
2020 downstream = external_dag.partial_subset(
2021 task_ids_or_regex=[tii.task_id],
2022 include_upstream=False,
2023 include_downstream=True,
2024 )
2025 result.update(
2026 downstream._get_task_instances(
2027 task_ids=None,
2028 run_id=tii.run_id,
2029 start_date=None,
2030 end_date=None,
2031 state=state,
2032 include_subdags=include_subdags,
2033 include_dependent_dags=include_dependent_dags,
2034 include_parentdag=False,
2035 as_pk_tuple=True,
2036 exclude_task_ids=exclude_task_ids,
2037 dag_bag=dag_bag,
2038 session=session,
2039 recursion_depth=recursion_depth + 1,
2040 max_recursion_depth=max_recursion_depth,
2041 visited_external_tis=visited_external_tis,
2042 )
2043 )
2045 if result or as_pk_tuple:
2046 # Only execute the `ti` query if we have also collected some other results (i.e. subdags etc.)
2047 if as_pk_tuple:
2048 tis_query = session.execute(tis).all()
2049 result.update(TaskInstanceKey(**cols._mapping) for cols in tis_query)
2050 else:
2051 result.update(ti.key for ti in session.scalars(tis))
2053 if exclude_task_ids is not None:
2054 result = {
2055 task
2056 for task in result
2057 if task.task_id not in exclude_task_ids
2058 and (task.task_id, task.map_index) not in exclude_task_ids
2059 }
2061 if as_pk_tuple:
2062 return result
2063 if result:
2064 # We've been asked for objects, lets combine it all back in to a result set
2065 ti_filters = TI.filter_for_tis(result)
2066 if ti_filters is not None:
2067 tis = select(TI).where(ti_filters)
2068 elif exclude_task_ids is None:
2069 pass # Disable filter if not set.
2070 elif isinstance(next(iter(exclude_task_ids), None), str):
2071 tis = tis.where(TI.task_id.notin_(exclude_task_ids))
2072 else:
2073 tis = tis.where(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
2075 return tis
2077 @provide_session
2078 def set_task_instance_state(
2079 self,
2080 *,
2081 task_id: str,
2082 map_indexes: Collection[int] | None = None,
2083 execution_date: datetime | None = None,
2084 run_id: str | None = None,
2085 state: TaskInstanceState,
2086 upstream: bool = False,
2087 downstream: bool = False,
2088 future: bool = False,
2089 past: bool = False,
2090 commit: bool = True,
2091 session=NEW_SESSION,
2092 ) -> list[TaskInstance]:
2093 """
2094 Set the state of a TaskInstance and clear downstream tasks in failed or upstream_failed state.
2096 :param task_id: Task ID of the TaskInstance
2097 :param map_indexes: Only set TaskInstance if its map_index matches.
2098 If None (default), all mapped TaskInstances of the task are set.
2099 :param execution_date: Execution date of the TaskInstance
2100 :param run_id: The run_id of the TaskInstance
2101 :param state: State to set the TaskInstance to
2102 :param upstream: Include all upstream tasks of the given task_id
2103 :param downstream: Include all downstream tasks of the given task_id
2104 :param future: Include all future TaskInstances of the given task_id
2105 :param commit: Commit changes
2106 :param past: Include all past TaskInstances of the given task_id
2107 """
2108 from airflow.api.common.mark_tasks import set_state
2110 if not exactly_one(execution_date, run_id):
2111 raise ValueError("Exactly one of execution_date or run_id must be provided")
2113 task = self.get_task(task_id)
2114 task.dag = self
2116 tasks_to_set_state: list[Operator | tuple[Operator, int]]
2117 if map_indexes is None:
2118 tasks_to_set_state = [task]
2119 else:
2120 tasks_to_set_state = [(task, map_index) for map_index in map_indexes]
2122 altered = set_state(
2123 tasks=tasks_to_set_state,
2124 execution_date=execution_date,
2125 run_id=run_id,
2126 upstream=upstream,
2127 downstream=downstream,
2128 future=future,
2129 past=past,
2130 state=state,
2131 commit=commit,
2132 session=session,
2133 )
2135 if not commit:
2136 return altered
2138 # Clear downstream tasks that are in failed/upstream_failed state to resume them.
2139 # Flush the session so that the tasks marked success are reflected in the db.
2140 session.flush()
2141 subdag = self.partial_subset(
2142 task_ids_or_regex={task_id},
2143 include_downstream=True,
2144 include_upstream=False,
2145 )
2147 if execution_date is None:
2148 dag_run = session.scalars(
2149 select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id)
2150 ).one() # Raises an error if not found
2151 resolve_execution_date = dag_run.execution_date
2152 else:
2153 resolve_execution_date = execution_date
2155 end_date = resolve_execution_date if not future else None
2156 start_date = resolve_execution_date if not past else None
2158 subdag.clear(
2159 start_date=start_date,
2160 end_date=end_date,
2161 include_subdags=True,
2162 include_parentdag=True,
2163 only_failed=True,
2164 session=session,
2165 # Exclude the task itself from being cleared
2166 exclude_task_ids=frozenset({task_id}),
2167 )
2169 return altered
2171 @provide_session
2172 def set_task_group_state(
2173 self,
2174 *,
2175 group_id: str,
2176 execution_date: datetime | None = None,
2177 run_id: str | None = None,
2178 state: TaskInstanceState,
2179 upstream: bool = False,
2180 downstream: bool = False,
2181 future: bool = False,
2182 past: bool = False,
2183 commit: bool = True,
2184 session: Session = NEW_SESSION,
2185 ) -> list[TaskInstance]:
2186 """
2187 Set TaskGroup to the given state and clear downstream tasks in failed or upstream_failed state.
2189 :param group_id: The group_id of the TaskGroup
2190 :param execution_date: Execution date of the TaskInstance
2191 :param run_id: The run_id of the TaskInstance
2192 :param state: State to set the TaskInstance to
2193 :param upstream: Include all upstream tasks of the given task_id
2194 :param downstream: Include all downstream tasks of the given task_id
2195 :param future: Include all future TaskInstances of the given task_id
2196 :param commit: Commit changes
2197 :param past: Include all past TaskInstances of the given task_id
2198 :param session: new session
2199 """
2200 from airflow.api.common.mark_tasks import set_state
2202 if not exactly_one(execution_date, run_id):
2203 raise ValueError("Exactly one of execution_date or run_id must be provided")
2205 tasks_to_set_state: list[BaseOperator | tuple[BaseOperator, int]] = []
2206 task_ids: list[str] = []
2208 if execution_date is None:
2209 dag_run = session.scalars(
2210 select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id)
2211 ).one() # Raises an error if not found
2212 resolve_execution_date = dag_run.execution_date
2213 else:
2214 resolve_execution_date = execution_date
2216 end_date = resolve_execution_date if not future else None
2217 start_date = resolve_execution_date if not past else None
2219 task_group_dict = self.task_group.get_task_group_dict()
2220 task_group = task_group_dict.get(group_id)
2221 if task_group is None:
2222 raise ValueError("TaskGroup {group_id} could not be found")
2223 tasks_to_set_state = [task for task in task_group.iter_tasks() if isinstance(task, BaseOperator)]
2224 task_ids = [task.task_id for task in task_group.iter_tasks()]
2225 dag_runs_query = select(DagRun.id).where(DagRun.dag_id == self.dag_id)
2226 if start_date is None and end_date is None:
2227 dag_runs_query = dag_runs_query.where(DagRun.execution_date == start_date)
2228 else:
2229 if start_date is not None:
2230 dag_runs_query = dag_runs_query.where(DagRun.execution_date >= start_date)
2231 if end_date is not None:
2232 dag_runs_query = dag_runs_query.where(DagRun.execution_date <= end_date)
2234 with lock_rows(dag_runs_query, session):
2235 altered = set_state(
2236 tasks=tasks_to_set_state,
2237 execution_date=execution_date,
2238 run_id=run_id,
2239 upstream=upstream,
2240 downstream=downstream,
2241 future=future,
2242 past=past,
2243 state=state,
2244 commit=commit,
2245 session=session,
2246 )
2247 if not commit:
2248 return altered
2250 # Clear downstream tasks that are in failed/upstream_failed state to resume them.
2251 # Flush the session so that the tasks marked success are reflected in the db.
2252 session.flush()
2253 task_subset = self.partial_subset(
2254 task_ids_or_regex=task_ids,
2255 include_downstream=True,
2256 include_upstream=False,
2257 )
2259 task_subset.clear(
2260 start_date=start_date,
2261 end_date=end_date,
2262 include_subdags=True,
2263 include_parentdag=True,
2264 only_failed=True,
2265 session=session,
2266 # Exclude the task from the current group from being cleared
2267 exclude_task_ids=frozenset(task_ids),
2268 )
2270 return altered
2272 @property
2273 def roots(self) -> list[Operator]:
2274 """Return nodes with no parents. These are first to execute and are called roots or root nodes."""
2275 return [task for task in self.tasks if not task.upstream_list]
2277 @property
2278 def leaves(self) -> list[Operator]:
2279 """Return nodes with no children. These are last to execute and are called leaves or leaf nodes."""
2280 return [task for task in self.tasks if not task.downstream_list]
2282 def topological_sort(self, include_subdag_tasks: bool = False):
2283 """
2284 Sorts tasks in topographical order, such that a task comes after any of its upstream dependencies.
2286 Deprecated in place of ``task_group.topological_sort``
2287 """
2288 from airflow.utils.task_group import TaskGroup
2290 def nested_topo(group):
2291 for node in group.topological_sort(_include_subdag_tasks=include_subdag_tasks):
2292 if isinstance(node, TaskGroup):
2293 yield from nested_topo(node)
2294 else:
2295 yield node
2297 return tuple(nested_topo(self.task_group))
2299 @provide_session
2300 def set_dag_runs_state(
2301 self,
2302 state: DagRunState = DagRunState.RUNNING,
2303 session: Session = NEW_SESSION,
2304 start_date: datetime | None = None,
2305 end_date: datetime | None = None,
2306 dag_ids: list[str] | None = None,
2307 ) -> None:
2308 warnings.warn(
2309 "This method is deprecated and will be removed in a future version.",
2310 RemovedInAirflow3Warning,
2311 stacklevel=3,
2312 )
2313 dag_ids = dag_ids or [self.dag_id]
2314 query = update(DagRun).where(DagRun.dag_id.in_(dag_ids))
2315 if start_date:
2316 query = query.where(DagRun.execution_date >= start_date)
2317 if end_date:
2318 query = query.where(DagRun.execution_date <= end_date)
2319 session.execute(query.values(state=state).execution_options(synchronize_session="fetch"))
2321 @provide_session
2322 def clear(
2323 self,
2324 task_ids: Collection[str | tuple[str, int]] | None = None,
2325 start_date: datetime | None = None,
2326 end_date: datetime | None = None,
2327 only_failed: bool = False,
2328 only_running: bool = False,
2329 confirm_prompt: bool = False,
2330 include_subdags: bool = True,
2331 include_parentdag: bool = True,
2332 dag_run_state: DagRunState = DagRunState.QUEUED,
2333 dry_run: bool = False,
2334 session: Session = NEW_SESSION,
2335 get_tis: bool = False,
2336 recursion_depth: int = 0,
2337 max_recursion_depth: int | None = None,
2338 dag_bag: DagBag | None = None,
2339 exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(),
2340 ) -> int | Iterable[TaskInstance]:
2341 """
2342 Clear a set of task instances associated with the current dag for a specified date range.
2344 :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear
2345 :param start_date: The minimum execution_date to clear
2346 :param end_date: The maximum execution_date to clear
2347 :param only_failed: Only clear failed tasks
2348 :param only_running: Only clear running tasks.
2349 :param confirm_prompt: Ask for confirmation
2350 :param include_subdags: Clear tasks in subdags and clear external tasks
2351 indicated by ExternalTaskMarker
2352 :param include_parentdag: Clear tasks in the parent dag of the subdag.
2353 :param dag_run_state: state to set DagRun to. If set to False, dagrun state will not
2354 be changed.
2355 :param dry_run: Find the tasks to clear but don't clear them.
2356 :param session: The sqlalchemy session to use
2357 :param dag_bag: The DagBag used to find the dags subdags (Optional)
2358 :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``)
2359 tuples that should not be cleared
2360 """
2361 if get_tis:
2362 warnings.warn(
2363 "Passing `get_tis` to dag.clear() is deprecated. Use `dry_run` parameter instead.",
2364 RemovedInAirflow3Warning,
2365 stacklevel=2,
2366 )
2367 dry_run = True
2369 if recursion_depth:
2370 warnings.warn(
2371 "Passing `recursion_depth` to dag.clear() is deprecated.",
2372 RemovedInAirflow3Warning,
2373 stacklevel=2,
2374 )
2375 if max_recursion_depth:
2376 warnings.warn(
2377 "Passing `max_recursion_depth` to dag.clear() is deprecated.",
2378 RemovedInAirflow3Warning,
2379 stacklevel=2,
2380 )
2382 state: list[TaskInstanceState] = []
2383 if only_failed:
2384 state += [TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED]
2385 if only_running:
2386 # Yes, having `+=` doesn't make sense, but this was the existing behaviour
2387 state += [TaskInstanceState.RUNNING]
2389 tis = self._get_task_instances(
2390 task_ids=task_ids,
2391 start_date=start_date,
2392 end_date=end_date,
2393 run_id=None,
2394 state=state,
2395 include_subdags=include_subdags,
2396 include_parentdag=include_parentdag,
2397 include_dependent_dags=include_subdags, # compat, yes this is not a typo
2398 session=session,
2399 dag_bag=dag_bag,
2400 exclude_task_ids=exclude_task_ids,
2401 )
2403 if dry_run:
2404 return session.scalars(tis).all()
2406 tis = session.scalars(tis).all()
2408 count = len(list(tis))
2409 do_it = True
2410 if count == 0:
2411 return 0
2412 if confirm_prompt:
2413 ti_list = "\n".join(str(t) for t in tis)
2414 question = f"You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]"
2415 do_it = utils.helpers.ask_yesno(question)
2417 if do_it:
2418 clear_task_instances(
2419 list(tis),
2420 session,
2421 dag=self,
2422 dag_run_state=dag_run_state,
2423 )
2424 else:
2425 count = 0
2426 print("Cancelled, nothing was cleared.")
2428 session.flush()
2429 return count
2431 @classmethod
2432 def clear_dags(
2433 cls,
2434 dags,
2435 start_date=None,
2436 end_date=None,
2437 only_failed=False,
2438 only_running=False,
2439 confirm_prompt=False,
2440 include_subdags=True,
2441 include_parentdag=False,
2442 dag_run_state=DagRunState.QUEUED,
2443 dry_run=False,
2444 ):
2445 all_tis = []
2446 for dag in dags:
2447 tis = dag.clear(
2448 start_date=start_date,
2449 end_date=end_date,
2450 only_failed=only_failed,
2451 only_running=only_running,
2452 confirm_prompt=False,
2453 include_subdags=include_subdags,
2454 include_parentdag=include_parentdag,
2455 dag_run_state=dag_run_state,
2456 dry_run=True,
2457 )
2458 all_tis.extend(tis)
2460 if dry_run:
2461 return all_tis
2463 count = len(all_tis)
2464 do_it = True
2465 if count == 0:
2466 print("Nothing to clear.")
2467 return 0
2468 if confirm_prompt:
2469 ti_list = "\n".join(str(t) for t in all_tis)
2470 question = f"You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]"
2471 do_it = utils.helpers.ask_yesno(question)
2473 if do_it:
2474 for dag in dags:
2475 dag.clear(
2476 start_date=start_date,
2477 end_date=end_date,
2478 only_failed=only_failed,
2479 only_running=only_running,
2480 confirm_prompt=False,
2481 include_subdags=include_subdags,
2482 dag_run_state=dag_run_state,
2483 dry_run=False,
2484 )
2485 else:
2486 count = 0
2487 print("Cancelled, nothing was cleared.")
2488 return count
2490 def __deepcopy__(self, memo):
2491 # Switcharoo to go around deepcopying objects coming through the
2492 # backdoor
2493 cls = self.__class__
2494 result = cls.__new__(cls)
2495 memo[id(self)] = result
2496 for k, v in self.__dict__.items():
2497 if k not in ("user_defined_macros", "user_defined_filters", "_log"):
2498 setattr(result, k, copy.deepcopy(v, memo))
2500 result.user_defined_macros = self.user_defined_macros
2501 result.user_defined_filters = self.user_defined_filters
2502 if hasattr(self, "_log"):
2503 result._log = self._log
2504 return result
2506 def sub_dag(self, *args, **kwargs):
2507 """Use `airflow.models.DAG.partial_subset`, this method is deprecated."""
2508 warnings.warn(
2509 "This method is deprecated and will be removed in a future version. Please use partial_subset",
2510 RemovedInAirflow3Warning,
2511 stacklevel=2,
2512 )
2513 return self.partial_subset(*args, **kwargs)
2515 def partial_subset(
2516 self,
2517 task_ids_or_regex: str | Pattern | Iterable[str],
2518 include_downstream=False,
2519 include_upstream=True,
2520 include_direct_upstream=False,
2521 ):
2522 """
2523 Return a subset of the current dag based on regex matching one or more tasks.
2525 Returns a subset of the current dag as a deep copy of the current dag
2526 based on a regex that should match one or many tasks, and includes
2527 upstream and downstream neighbours based on the flag passed.
2529 :param task_ids_or_regex: Either a list of task_ids, or a regex to
2530 match against task ids (as a string, or compiled regex pattern).
2531 :param include_downstream: Include all downstream tasks of matched
2532 tasks, in addition to matched tasks.
2533 :param include_upstream: Include all upstream tasks of matched tasks,
2534 in addition to matched tasks.
2535 :param include_direct_upstream: Include all tasks directly upstream of matched
2536 and downstream (if include_downstream = True) tasks
2537 """
2538 from airflow.models.baseoperator import BaseOperator
2539 from airflow.models.mappedoperator import MappedOperator
2541 # deep-copying self.task_dict and self._task_group takes a long time, and we don't want all
2542 # the tasks anyway, so we copy the tasks manually later
2543 memo = {id(self.task_dict): None, id(self._task_group): None}
2544 dag = copy.deepcopy(self, memo) # type: ignore
2546 if isinstance(task_ids_or_regex, (str, Pattern)):
2547 matched_tasks = [t for t in self.tasks if re2.findall(task_ids_or_regex, t.task_id)]
2548 else:
2549 matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex]
2551 also_include_ids: set[str] = set()
2552 for t in matched_tasks:
2553 if include_downstream:
2554 for rel in t.get_flat_relatives(upstream=False):
2555 also_include_ids.add(rel.task_id)
2556 if rel not in matched_tasks: # if it's in there, we're already processing it
2557 # need to include setups and teardowns for tasks that are in multiple
2558 # non-collinear setup/teardown paths
2559 if not rel.is_setup and not rel.is_teardown:
2560 also_include_ids.update(
2561 x.task_id for x in rel.get_upstreams_only_setups_and_teardowns()
2562 )
2563 if include_upstream:
2564 also_include_ids.update(x.task_id for x in t.get_upstreams_follow_setups())
2565 else:
2566 if not t.is_setup and not t.is_teardown:
2567 also_include_ids.update(x.task_id for x in t.get_upstreams_only_setups_and_teardowns())
2568 if t.is_setup and not include_downstream:
2569 also_include_ids.update(x.task_id for x in t.downstream_list if x.is_teardown)
2571 also_include: list[Operator] = [self.task_dict[x] for x in also_include_ids]
2572 direct_upstreams: list[Operator] = []
2573 if include_direct_upstream:
2574 for t in itertools.chain(matched_tasks, also_include):
2575 upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator)))
2576 direct_upstreams.extend(upstream)
2578 # Compiling the unique list of tasks that made the cut
2579 # Make sure to not recursively deepcopy the dag or task_group while copying the task.
2580 # task_group is reset later
2581 def _deepcopy_task(t) -> Operator:
2582 memo.setdefault(id(t.task_group), None)
2583 return copy.deepcopy(t, memo)
2585 dag.task_dict = {
2586 t.task_id: _deepcopy_task(t)
2587 for t in itertools.chain(matched_tasks, also_include, direct_upstreams)
2588 }
2590 def filter_task_group(group, parent_group):
2591 """Exclude tasks not included in the subdag from the given TaskGroup."""
2592 # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy
2593 # and then manually deep copy the instances. (memo argument to deepcopy only works for instances
2594 # of classes, not "native" properties of an instance)
2595 copied = copy.copy(group)
2597 memo[id(group.children)] = {}
2598 if parent_group:
2599 memo[id(group.parent_group)] = parent_group
2600 for attr, value in copied.__dict__.items():
2601 if id(value) in memo:
2602 value = memo[id(value)]
2603 else:
2604 value = copy.deepcopy(value, memo)
2605 copied.__dict__[attr] = value
2607 proxy = weakref.proxy(copied)
2609 for child in group.children.values():
2610 if isinstance(child, AbstractOperator):
2611 if child.task_id in dag.task_dict:
2612 task = copied.children[child.task_id] = dag.task_dict[child.task_id]
2613 task.task_group = proxy
2614 else:
2615 copied.used_group_ids.discard(child.task_id)
2616 else:
2617 filtered_child = filter_task_group(child, proxy)
2619 # Only include this child TaskGroup if it is non-empty.
2620 if filtered_child.children:
2621 copied.children[child.group_id] = filtered_child
2623 return copied
2625 dag._task_group = filter_task_group(self.task_group, None)
2627 # Removing upstream/downstream references to tasks and TaskGroups that did not make
2628 # the cut.
2629 subdag_task_groups = dag.task_group.get_task_group_dict()
2630 for group in subdag_task_groups.values():
2631 group.upstream_group_ids.intersection_update(subdag_task_groups)
2632 group.downstream_group_ids.intersection_update(subdag_task_groups)
2633 group.upstream_task_ids.intersection_update(dag.task_dict)
2634 group.downstream_task_ids.intersection_update(dag.task_dict)
2636 for t in dag.tasks:
2637 # Removing upstream/downstream references to tasks that did not
2638 # make the cut
2639 t.upstream_task_ids.intersection_update(dag.task_dict)
2640 t.downstream_task_ids.intersection_update(dag.task_dict)
2642 if len(dag.tasks) < len(self.tasks):
2643 dag.partial = True
2645 return dag
2647 def has_task(self, task_id: str):
2648 return task_id in self.task_dict
2650 def has_task_group(self, task_group_id: str) -> bool:
2651 return task_group_id in self.task_group_dict
2653 @functools.cached_property
2654 def task_group_dict(self):
2655 return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None}
2657 def get_task(self, task_id: str, include_subdags: bool = False) -> Operator:
2658 if task_id in self.task_dict:
2659 return self.task_dict[task_id]
2660 if include_subdags:
2661 for dag in self.subdags:
2662 if task_id in dag.task_dict:
2663 return dag.task_dict[task_id]
2664 raise TaskNotFound(f"Task {task_id} not found")
2666 def pickle_info(self):
2667 d = {}
2668 d["is_picklable"] = True
2669 try:
2670 dttm = timezone.utcnow()
2671 pickled = pickle.dumps(self)
2672 d["pickle_len"] = len(pickled)
2673 d["pickling_duration"] = str(timezone.utcnow() - dttm)
2674 except Exception as e:
2675 self.log.debug(e)
2676 d["is_picklable"] = False
2677 d["stacktrace"] = traceback.format_exc()
2678 return d
2680 @provide_session
2681 def pickle(self, session=NEW_SESSION) -> DagPickle:
2682 dag = session.scalar(select(DagModel).where(DagModel.dag_id == self.dag_id).limit(1))
2683 dp = None
2684 if dag and dag.pickle_id:
2685 dp = session.scalar(select(DagPickle).where(DagPickle.id == dag.pickle_id).limit(1))
2686 if not dp or dp.pickle != self:
2687 dp = DagPickle(dag=self)
2688 session.add(dp)
2689 self.last_pickled = timezone.utcnow()
2690 session.commit()
2691 self.pickle_id = dp.id
2693 return dp
2695 def tree_view(self) -> None:
2696 """Print an ASCII tree representation of the DAG."""
2697 for tmp in self._generate_tree_view():
2698 print(tmp)
2700 def _generate_tree_view(self) -> Generator[str, None, None]:
2701 def get_downstream(task, level=0) -> Generator[str, None, None]:
2702 yield (" " * level * 4) + str(task)
2703 level += 1
2704 for tmp_task in sorted(task.downstream_list, key=lambda x: x.task_id):
2705 yield from get_downstream(tmp_task, level)
2707 for t in sorted(self.roots, key=lambda x: x.task_id):
2708 yield from get_downstream(t)
2710 def get_tree_view(self) -> str:
2711 """Return an ASCII tree representation of the DAG."""
2712 rst = ""
2713 for tmp in self._generate_tree_view():
2714 rst += tmp + "\n"
2715 return rst
2717 @property
2718 def task(self) -> TaskDecoratorCollection:
2719 from airflow.decorators import task
2721 return cast("TaskDecoratorCollection", functools.partial(task, dag=self))
2723 def add_task(self, task: Operator) -> None:
2724 """
2725 Add a task to the DAG.
2727 :param task: the task you want to add
2728 """
2729 FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)
2731 from airflow.utils.task_group import TaskGroupContext
2733 # if the task has no start date, assign it the same as the DAG
2734 if not task.start_date:
2735 task.start_date = self.start_date
2736 # otherwise, the task will start on the later of its own start date and
2737 # the DAG's start date
2738 elif self.start_date:
2739 task.start_date = max(task.start_date, self.start_date)
2741 # if the task has no end date, assign it the same as the dag
2742 if not task.end_date:
2743 task.end_date = self.end_date
2744 # otherwise, the task will end on the earlier of its own end date and
2745 # the DAG's end date
2746 elif task.end_date and self.end_date:
2747 task.end_date = min(task.end_date, self.end_date)
2749 task_id = task.task_id
2750 if not task.task_group:
2751 task_group = TaskGroupContext.get_current_task_group(self)
2752 if task_group:
2753 task_id = task_group.child_id(task_id)
2754 task_group.add(task)
2756 if (
2757 task_id in self.task_dict and self.task_dict[task_id] is not task
2758 ) or task_id in self._task_group.used_group_ids:
2759 raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG")
2760 else:
2761 self.task_dict[task_id] = task
2762 task.dag = self
2763 # Add task_id to used_group_ids to prevent group_id and task_id collisions.
2764 self._task_group.used_group_ids.add(task_id)
2766 self.task_count = len(self.task_dict)
2768 def add_tasks(self, tasks: Iterable[Operator]) -> None:
2769 """
2770 Add a list of tasks to the DAG.
2772 :param tasks: a lit of tasks you want to add
2773 """
2774 for task in tasks:
2775 self.add_task(task)
2777 def _remove_task(self, task_id: str) -> None:
2778 # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this
2779 # doesn't guard against that
2780 task = self.task_dict.pop(task_id)
2781 tg = getattr(task, "task_group", None)
2782 if tg:
2783 tg._remove(task)
2785 self.task_count = len(self.task_dict)
2787 def run(
2788 self,
2789 start_date=None,
2790 end_date=None,
2791 mark_success=False,
2792 local=False,
2793 executor=None,
2794 donot_pickle=airflow_conf.getboolean("core", "donot_pickle"),
2795 ignore_task_deps=False,
2796 ignore_first_depends_on_past=True,
2797 pool=None,
2798 delay_on_limit_secs=1.0,
2799 verbose=False,
2800 conf=None,
2801 rerun_failed_tasks=False,
2802 run_backwards=False,
2803 run_at_least_once=False,
2804 continue_on_failures=False,
2805 disable_retry=False,
2806 ):
2807 """
2808 Run the DAG.
2810 :param start_date: the start date of the range to run
2811 :param end_date: the end date of the range to run
2812 :param mark_success: True to mark jobs as succeeded without running them
2813 :param local: True to run the tasks using the LocalExecutor
2814 :param executor: The executor instance to run the tasks
2815 :param donot_pickle: True to avoid pickling DAG object and send to workers
2816 :param ignore_task_deps: True to skip upstream tasks
2817 :param ignore_first_depends_on_past: True to ignore depends_on_past
2818 dependencies for the first set of tasks only
2819 :param pool: Resource pool to use
2820 :param delay_on_limit_secs: Time in seconds to wait before next attempt to run
2821 dag run when max_active_runs limit has been reached
2822 :param verbose: Make logging output more verbose
2823 :param conf: user defined dictionary passed from CLI
2824 :param rerun_failed_tasks:
2825 :param run_backwards:
2826 :param run_at_least_once: If true, always run the DAG at least once even
2827 if no logical run exists within the time range.
2828 """
2829 from airflow.jobs.backfill_job_runner import BackfillJobRunner
2831 if not executor and local:
2832 from airflow.executors.local_executor import LocalExecutor
2834 executor = LocalExecutor()
2835 elif not executor:
2836 from airflow.executors.executor_loader import ExecutorLoader
2838 executor = ExecutorLoader.get_default_executor()
2839 from airflow.jobs.job import Job
2841 job = Job(executor=executor)
2842 job_runner = BackfillJobRunner(
2843 job=job,
2844 dag=self,
2845 start_date=start_date,
2846 end_date=end_date,
2847 mark_success=mark_success,
2848 donot_pickle=donot_pickle,
2849 ignore_task_deps=ignore_task_deps,
2850 ignore_first_depends_on_past=ignore_first_depends_on_past,
2851 pool=pool,
2852 delay_on_limit_secs=delay_on_limit_secs,
2853 verbose=verbose,
2854 conf=conf,
2855 rerun_failed_tasks=rerun_failed_tasks,
2856 run_backwards=run_backwards,
2857 run_at_least_once=run_at_least_once,
2858 continue_on_failures=continue_on_failures,
2859 disable_retry=disable_retry,
2860 )
2861 run_job(job=job, execute_callable=job_runner._execute)
2863 def cli(self):
2864 """Exposes a CLI specific to this DAG."""
2865 check_cycle(self)
2867 from airflow.cli import cli_parser
2869 parser = cli_parser.get_parser(dag_parser=True)
2870 args = parser.parse_args()
2871 args.func(args, self)
2873 @provide_session
2874 def test(
2875 self,
2876 execution_date: datetime | None = None,
2877 run_conf: dict[str, Any] | None = None,
2878 conn_file_path: str | None = None,
2879 variable_file_path: str | None = None,
2880 session: Session = NEW_SESSION,
2881 ) -> DagRun:
2882 """
2883 Execute one single DagRun for a given DAG and execution date.
2885 :param execution_date: execution date for the DAG run
2886 :param run_conf: configuration to pass to newly created dagrun
2887 :param conn_file_path: file path to a connection file in either yaml or json
2888 :param variable_file_path: file path to a variable file in either yaml or json
2889 :param session: database connection (optional)
2890 """
2892 def add_logger_if_needed(ti: TaskInstance):
2893 """Add a formatted logger to the task instance.
2895 This allows all logs to surface to the command line, instead of into
2896 a task file. Since this is a local test run, it is much better for
2897 the user to see logs in the command line, rather than needing to
2898 search for a log file.
2900 :param ti: The task instance that will receive a logger.
2901 """
2902 format = logging.Formatter("[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s")
2903 handler = logging.StreamHandler(sys.stdout)
2904 handler.level = logging.INFO
2905 handler.setFormatter(format)
2906 # only add log handler once
2907 if not any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers):
2908 self.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id)
2909 ti.log.addHandler(handler)
2911 exit_stack = ExitStack()
2912 if conn_file_path or variable_file_path:
2913 local_secrets = LocalFilesystemBackend(
2914 variables_file_path=variable_file_path, connections_file_path=conn_file_path
2915 )
2916 secrets_backend_list.insert(0, local_secrets)
2917 exit_stack.callback(lambda: secrets_backend_list.pop(0))
2919 with exit_stack:
2920 execution_date = execution_date or timezone.utcnow()
2921 self.validate()
2922 self.log.debug("Clearing existing task instances for execution date %s", execution_date)
2923 self.clear(
2924 start_date=execution_date,
2925 end_date=execution_date,
2926 dag_run_state=False, # type: ignore
2927 session=session,
2928 )
2929 self.log.debug("Getting dagrun for dag %s", self.dag_id)
2930 logical_date = timezone.coerce_datetime(execution_date)
2931 data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date)
2932 dr: DagRun = _get_or_create_dagrun(
2933 dag=self,
2934 start_date=execution_date,
2935 execution_date=execution_date,
2936 run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
2937 session=session,
2938 conf=run_conf,
2939 data_interval=data_interval,
2940 )
2942 tasks = self.task_dict
2943 self.log.debug("starting dagrun")
2944 # Instead of starting a scheduler, we run the minimal loop possible to check
2945 # for task readiness and dependency management. This is notably faster
2946 # than creating a BackfillJob and allows us to surface logs to the user
2947 while dr.state == DagRunState.RUNNING:
2948 session.expire_all()
2949 schedulable_tis, _ = dr.update_state(session=session)
2950 for s in schedulable_tis:
2951 if s.state != TaskInstanceState.UP_FOR_RESCHEDULE:
2952 s.try_number += 1
2953 s.state = TaskInstanceState.SCHEDULED
2954 session.commit()
2955 # triggerer may mark tasks scheduled so we read from DB
2956 all_tis = set(dr.get_task_instances(session=session))
2957 scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED}
2958 ids_unrunnable = {x for x in all_tis if x.state not in State.finished} - scheduled_tis
2959 if not scheduled_tis and ids_unrunnable:
2960 self.log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable)
2961 time.sleep(1)
2962 triggerer_running = _triggerer_is_healthy()
2963 for ti in scheduled_tis:
2964 try:
2965 add_logger_if_needed(ti)
2966 ti.task = tasks[ti.task_id]
2967 _run_task(ti=ti, inline_trigger=not triggerer_running, session=session)
2968 except Exception:
2969 self.log.exception("Task failed; ti=%s", ti)
2970 return dr
2972 @provide_session
2973 def create_dagrun(
2974 self,
2975 state: DagRunState,
2976 execution_date: datetime | None = None,
2977 run_id: str | None = None,
2978 start_date: datetime | None = None,
2979 external_trigger: bool | None = False,
2980 conf: dict | None = None,
2981 run_type: DagRunType | None = None,
2982 session: Session = NEW_SESSION,
2983 dag_hash: str | None = None,
2984 creating_job_id: int | None = None,
2985 data_interval: tuple[datetime, datetime] | None = None,
2986 ):
2987 """
2988 Create a dag run from this dag including the tasks associated with this dag.
2990 Returns the dag run.
2992 :param run_id: defines the run id for this dag run
2993 :param run_type: type of DagRun
2994 :param execution_date: the execution date of this dag run
2995 :param state: the state of the dag run
2996 :param start_date: the date this dag run should be evaluated
2997 :param external_trigger: whether this dag run is externally triggered
2998 :param conf: Dict containing configuration/parameters to pass to the DAG
2999 :param creating_job_id: id of the job creating this DagRun
3000 :param session: database session
3001 :param dag_hash: Hash of Serialized DAG
3002 :param data_interval: Data interval of the DagRun
3003 """
3004 logical_date = timezone.coerce_datetime(execution_date)
3006 if data_interval and not isinstance(data_interval, DataInterval):
3007 data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval))
3009 if data_interval is None and logical_date is not None:
3010 warnings.warn(
3011 "Calling `DAG.create_dagrun()` without an explicit data interval is deprecated",
3012 RemovedInAirflow3Warning,
3013 stacklevel=3,
3014 )
3015 if run_type == DagRunType.MANUAL:
3016 data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date)
3017 else:
3018 data_interval = self.infer_automated_data_interval(logical_date)
3020 if run_type is None or isinstance(run_type, DagRunType):
3021 pass
3022 elif isinstance(run_type, str): # Compatibility: run_type used to be a str.
3023 run_type = DagRunType(run_type)
3024 else:
3025 raise ValueError(f"`run_type` should be a DagRunType, not {type(run_type)}")
3027 if run_id: # Infer run_type from run_id if needed.
3028 if not isinstance(run_id, str):
3029 raise ValueError(f"`run_id` should be a str, not {type(run_id)}")
3030 inferred_run_type = DagRunType.from_run_id(run_id)
3031 if run_type is None:
3032 # No explicit type given, use the inferred type.
3033 run_type = inferred_run_type
3034 elif run_type == DagRunType.MANUAL and inferred_run_type != DagRunType.MANUAL:
3035 # Prevent a manual run from using an ID that looks like a scheduled run.
3036 raise ValueError(
3037 f"A {run_type.value} DAG run cannot use ID {run_id!r} since it "
3038 f"is reserved for {inferred_run_type.value} runs"
3039 )
3040 elif run_type and logical_date is not None: # Generate run_id from run_type and execution_date.
3041 run_id = self.timetable.generate_run_id(
3042 run_type=run_type, logical_date=logical_date, data_interval=data_interval
3043 )
3044 else:
3045 raise AirflowException(
3046 "Creating DagRun needs either `run_id` or both `run_type` and `execution_date`"
3047 )
3049 regex = airflow_conf.get("scheduler", "allowed_run_id_pattern")
3051 if run_id and not re2.match(RUN_ID_REGEX, run_id):
3052 if not regex.strip() or not re2.match(regex.strip(), run_id):
3053 raise AirflowException(
3054 f"The provided run ID '{run_id}' is invalid. It does not match either "
3055 f"the configured pattern: '{regex}' or the built-in pattern: '{RUN_ID_REGEX}'"
3056 )
3058 # create a copy of params before validating
3059 copied_params = copy.deepcopy(self.params)
3060 copied_params.update(conf or {})
3061 copied_params.validate()
3063 run = _create_orm_dagrun(
3064 dag=self,
3065 dag_id=self.dag_id,
3066 run_id=run_id,
3067 logical_date=logical_date,
3068 start_date=start_date,
3069 external_trigger=external_trigger,
3070 conf=conf,
3071 state=state,
3072 run_type=run_type,
3073 dag_hash=dag_hash,
3074 creating_job_id=creating_job_id,
3075 data_interval=data_interval,
3076 session=session,
3077 )
3078 return run
3080 @classmethod
3081 @provide_session
3082 def bulk_sync_to_db(
3083 cls,
3084 dags: Collection[DAG],
3085 session=NEW_SESSION,
3086 ):
3087 """Use `airflow.models.DAG.bulk_write_to_db`, this method is deprecated."""
3088 warnings.warn(
3089 "This method is deprecated and will be removed in a future version. Please use bulk_write_to_db",
3090 RemovedInAirflow3Warning,
3091 stacklevel=2,
3092 )
3093 return cls.bulk_write_to_db(dags=dags, session=session)
3095 @classmethod
3096 @provide_session
3097 def bulk_write_to_db(
3098 cls,
3099 dags: Collection[DAG],
3100 processor_subdir: str | None = None,
3101 session=NEW_SESSION,
3102 ):
3103 """
3104 Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB.
3106 Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a SubDagOperator.
3108 :param dags: the DAG objects to save to the DB
3109 :return: None
3110 """
3111 if not dags:
3112 return
3114 log.info("Sync %s DAGs", len(dags))
3115 dag_by_ids = {dag.dag_id: dag for dag in dags}
3117 dag_ids = set(dag_by_ids)
3118 query = (
3119 select(DagModel)
3120 .options(joinedload(DagModel.tags, innerjoin=False))
3121 .where(DagModel.dag_id.in_(dag_ids))
3122 .options(joinedload(DagModel.schedule_dataset_references))
3123 .options(joinedload(DagModel.task_outlet_dataset_references))
3124 )
3125 query = with_row_locks(query, of=DagModel, session=session)
3126 orm_dags: list[DagModel] = session.scalars(query).unique().all()
3127 existing_dags: dict[str, DagModel] = {x.dag_id: x for x in orm_dags}
3128 missing_dag_ids = dag_ids.difference(existing_dags.keys())
3130 for missing_dag_id in missing_dag_ids:
3131 orm_dag = DagModel(dag_id=missing_dag_id)
3132 dag = dag_by_ids[missing_dag_id]
3133 if dag.is_paused_upon_creation is not None:
3134 orm_dag.is_paused = dag.is_paused_upon_creation
3135 orm_dag.tags = []
3136 log.info("Creating ORM DAG for %s", dag.dag_id)
3137 session.add(orm_dag)
3138 orm_dags.append(orm_dag)
3140 latest_runs: dict[str, DagRun] = {}
3141 num_active_runs: dict[str, int] = {}
3142 # Skip these queries entirely if no DAGs can be scheduled to save time.
3143 if any(dag.timetable.can_be_scheduled for dag in dags):
3144 # Get the latest automated dag run for each existing dag as a single query (avoid n+1 query)
3145 query = cls._get_latest_runs_stmt(dags=list(existing_dags.keys()))
3146 latest_runs = {run.dag_id: run for run in session.scalars(query)}
3148 # Get number of active dagruns for all dags we are processing as a single query.
3149 num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session)
3151 filelocs = []
3153 for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id):
3154 dag = dag_by_ids[orm_dag.dag_id]
3155 filelocs.append(dag.fileloc)
3156 if dag.is_subdag:
3157 orm_dag.is_subdag = True
3158 orm_dag.fileloc = dag.parent_dag.fileloc # type: ignore
3159 orm_dag.root_dag_id = dag.parent_dag.dag_id # type: ignore
3160 orm_dag.owners = dag.parent_dag.owner # type: ignore
3161 else:
3162 orm_dag.is_subdag = False
3163 orm_dag.fileloc = dag.fileloc
3164 orm_dag.owners = dag.owner
3165 orm_dag.is_active = True
3166 orm_dag.has_import_errors = False
3167 orm_dag.last_parsed_time = timezone.utcnow()
3168 orm_dag.default_view = dag.default_view
3169 orm_dag._dag_display_property_value = dag._dag_display_property_value
3170 orm_dag.description = dag.description
3171 orm_dag.max_active_tasks = dag.max_active_tasks
3172 orm_dag.max_active_runs = dag.max_active_runs
3173 orm_dag.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
3174 orm_dag.has_task_concurrency_limits = any(
3175 t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
3176 for t in dag.tasks
3177 )
3178 orm_dag.schedule_interval = dag.schedule_interval
3179 orm_dag.timetable_description = dag.timetable.description
3180 if (dataset_triggers := dag.dataset_triggers) is None:
3181 orm_dag.dataset_expression = None
3182 else:
3183 orm_dag.dataset_expression = dataset_triggers.as_expression()
3185 orm_dag.processor_subdir = processor_subdir
3187 last_automated_run: DagRun | None = latest_runs.get(dag.dag_id)
3188 if last_automated_run is None:
3189 last_automated_data_interval = None
3190 else:
3191 last_automated_data_interval = dag.get_run_data_interval(last_automated_run)
3192 if num_active_runs.get(dag.dag_id, 0) >= orm_dag.max_active_runs:
3193 orm_dag.next_dagrun_create_after = None
3194 else:
3195 orm_dag.calculate_dagrun_date_fields(dag, last_automated_data_interval)
3197 dag_tags = set(dag.tags or {})
3198 orm_dag_tags = list(orm_dag.tags or [])
3199 for orm_tag in orm_dag_tags:
3200 if orm_tag.name not in dag_tags:
3201 session.delete(orm_tag)
3202 orm_dag.tags.remove(orm_tag)
3203 orm_tag_names = {t.name for t in orm_dag_tags}
3204 for dag_tag in dag_tags:
3205 if dag_tag not in orm_tag_names:
3206 dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id)
3207 orm_dag.tags.append(dag_tag_orm)
3208 session.add(dag_tag_orm)
3210 orm_dag_links = orm_dag.dag_owner_links or []
3211 for orm_dag_link in orm_dag_links:
3212 if orm_dag_link not in dag.owner_links:
3213 session.delete(orm_dag_link)
3214 for owner_name, owner_link in dag.owner_links.items():
3215 dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link)
3216 session.add(dag_owner_orm)
3218 DagCode.bulk_sync_to_db(filelocs, session=session)
3220 from airflow.datasets import Dataset
3221 from airflow.models.dataset import (
3222 DagScheduleDatasetReference,
3223 DatasetModel,
3224 TaskOutletDatasetReference,
3225 )
3227 dag_references = defaultdict(set)
3228 outlet_references = defaultdict(set)
3229 # We can't use a set here as we want to preserve order
3230 outlet_datasets: dict[DatasetModel, None] = {}
3231 input_datasets: dict[DatasetModel, None] = {}
3233 # here we go through dags and tasks to check for dataset references
3234 # if there are now None and previously there were some, we delete them
3235 # if there are now *any*, we add them to the above data structures, and
3236 # later we'll persist them to the database.
3237 for dag in dags:
3238 curr_orm_dag = existing_dags.get(dag.dag_id)
3239 if dag.dataset_triggers is None:
3240 if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
3241 curr_orm_dag.schedule_dataset_references = []
3242 else:
3243 for _, dataset in dag.dataset_triggers.iter_datasets():
3244 dag_references[dag.dag_id].add(dataset.uri)
3245 input_datasets[DatasetModel.from_public(dataset)] = None
3246 curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
3247 for task in dag.tasks:
3248 dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)]
3249 if not dataset_outlets:
3250 if curr_outlet_references:
3251 this_task_outlet_refs = [
3252 x
3253 for x in curr_outlet_references
3254 if x.dag_id == dag.dag_id and x.task_id == task.task_id
3255 ]
3256 for ref in this_task_outlet_refs:
3257 curr_outlet_references.remove(ref)
3258 for d in dataset_outlets:
3259 outlet_references[(task.dag_id, task.task_id)].add(d.uri)
3260 outlet_datasets[DatasetModel.from_public(d)] = None
3261 all_datasets = outlet_datasets
3262 all_datasets.update(input_datasets)
3264 # store datasets
3265 stored_datasets: dict[str, DatasetModel] = {}
3266 new_datasets: list[DatasetModel] = []
3267 for dataset in all_datasets:
3268 stored_dataset = session.scalar(
3269 select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1)
3270 )
3271 if stored_dataset:
3272 # Some datasets may have been previously unreferenced, and therefore orphaned by the
3273 # scheduler. But if we're here, then we have found that dataset again in our DAGs, which
3274 # means that it is no longer an orphan, so set is_orphaned to False.
3275 stored_dataset.is_orphaned = expression.false()
3276 stored_datasets[stored_dataset.uri] = stored_dataset
3277 else:
3278 new_datasets.append(dataset)
3279 dataset_manager.create_datasets(dataset_models=new_datasets, session=session)
3280 stored_datasets.update({dataset.uri: dataset for dataset in new_datasets})
3282 del new_datasets
3283 del all_datasets
3285 # reconcile dag-schedule-on-dataset references
3286 for dag_id, uri_list in dag_references.items():
3287 dag_refs_needed = {
3288 DagScheduleDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id)
3289 for uri in uri_list
3290 }
3291 dag_refs_stored = set(
3292 existing_dags.get(dag_id)
3293 and existing_dags.get(dag_id).schedule_dataset_references # type: ignore
3294 or []
3295 )
3296 dag_refs_to_add = {x for x in dag_refs_needed if x not in dag_refs_stored}
3297 session.bulk_save_objects(dag_refs_to_add)
3298 for obj in dag_refs_stored - dag_refs_needed:
3299 session.delete(obj)
3301 existing_task_outlet_refs_dict = defaultdict(set)
3302 for dag_id, orm_dag in existing_dags.items():
3303 for todr in orm_dag.task_outlet_dataset_references:
3304 existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr)
3306 # reconcile task-outlet-dataset references
3307 for (dag_id, task_id), uri_list in outlet_references.items():
3308 task_refs_needed = {
3309 TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id)
3310 for uri in uri_list
3311 }
3312 task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)]
3313 task_refs_to_add = {x for x in task_refs_needed if x not in task_refs_stored}
3314 session.bulk_save_objects(task_refs_to_add)
3315 for obj in task_refs_stored - task_refs_needed:
3316 session.delete(obj)
3318 # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller
3319 # decide when to commit
3320 session.flush()
3322 for dag in dags:
3323 cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session)
3325 @classmethod
3326 def _get_latest_runs_stmt(cls, dags: list[str]) -> Select:
3327 """
3328 Build a select statement for retrieve the last automated run for each dag.
3330 :param dags: dags to query
3331 """
3332 if len(dags) == 1:
3333 # Index optimized fast path to avoid more complicated & slower groupby queryplan
3334 existing_dag_id = dags[0]
3335 last_automated_runs_subq = (
3336 select(func.max(DagRun.execution_date).label("max_execution_date"))
3337 .where(
3338 DagRun.dag_id == existing_dag_id,
3339 DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
3340 )
3341 .scalar_subquery()
3342 )
3343 query = select(DagRun).where(
3344 DagRun.dag_id == existing_dag_id, DagRun.execution_date == last_automated_runs_subq
3345 )
3346 else:
3347 last_automated_runs_subq = (
3348 select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
3349 .where(
3350 DagRun.dag_id.in_(dags),
3351 DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
3352 )
3353 .group_by(DagRun.dag_id)
3354 .subquery()
3355 )
3356 query = select(DagRun).where(
3357 DagRun.dag_id == last_automated_runs_subq.c.dag_id,
3358 DagRun.execution_date == last_automated_runs_subq.c.max_execution_date,
3359 )
3360 return query.options(
3361 load_only(
3362 DagRun.dag_id,
3363 DagRun.execution_date,
3364 DagRun.data_interval_start,
3365 DagRun.data_interval_end,
3366 )
3367 )
3369 @provide_session
3370 def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION):
3371 """
3372 Save attributes about this DAG to the DB.
3374 Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a SubDagOperator.
3376 :return: None
3377 """
3378 self.bulk_write_to_db([self], processor_subdir=processor_subdir, session=session)
3380 def get_default_view(self):
3381 """Allow backward compatible jinja2 templates."""
3382 if self.default_view is None:
3383 return airflow_conf.get("webserver", "dag_default_view").lower()
3384 else:
3385 return self.default_view
3387 @staticmethod
3388 @provide_session
3389 def deactivate_unknown_dags(active_dag_ids, session=NEW_SESSION):
3390 """
3391 Given a list of known DAGs, deactivate any other DAGs that are marked as active in the ORM.
3393 :param active_dag_ids: list of DAG IDs that are active
3394 :return: None
3395 """
3396 if not active_dag_ids:
3397 return
3398 for dag in session.scalars(select(DagModel).where(~DagModel.dag_id.in_(active_dag_ids))).all():
3399 dag.is_active = False
3400 session.merge(dag)
3401 session.commit()
3403 @staticmethod
3404 @provide_session
3405 def deactivate_stale_dags(expiration_date, session=NEW_SESSION):
3406 """
3407 Deactivate any DAGs that were last touched by the scheduler before the expiration date.
3409 These DAGs were likely deleted.
3411 :param expiration_date: set inactive DAGs that were touched before this time
3412 :return: None
3413 """
3414 for dag in session.scalars(
3415 select(DagModel).where(DagModel.last_parsed_time < expiration_date, DagModel.is_active)
3416 ):
3417 log.info(
3418 "Deactivating DAG ID %s since it was last touched by the scheduler at %s",
3419 dag.dag_id,
3420 dag.last_parsed_time.isoformat(),
3421 )
3422 dag.is_active = False
3423 session.merge(dag)
3424 session.commit()
3426 @staticmethod
3427 @provide_session
3428 def get_num_task_instances(dag_id, run_id=None, task_ids=None, states=None, session=NEW_SESSION) -> int:
3429 """
3430 Return the number of task instances in the given DAG.
3432 :param session: ORM session
3433 :param dag_id: ID of the DAG to get the task concurrency of
3434 :param run_id: ID of the DAG run to get the task concurrency of
3435 :param task_ids: A list of valid task IDs for the given DAG
3436 :param states: A list of states to filter by if supplied
3437 :return: The number of running tasks
3438 """
3439 qry = select(func.count(TaskInstance.task_id)).where(
3440 TaskInstance.dag_id == dag_id,
3441 )
3442 if run_id:
3443 qry = qry.where(
3444 TaskInstance.run_id == run_id,
3445 )
3446 if task_ids:
3447 qry = qry.where(
3448 TaskInstance.task_id.in_(task_ids),
3449 )
3451 if states:
3452 if None in states:
3453 if all(x is None for x in states):
3454 qry = qry.where(TaskInstance.state.is_(None))
3455 else:
3456 not_none_states = [state for state in states if state]
3457 qry = qry.where(
3458 or_(TaskInstance.state.in_(not_none_states), TaskInstance.state.is_(None))
3459 )
3460 else:
3461 qry = qry.where(TaskInstance.state.in_(states))
3462 return session.scalar(qry)
3464 @classmethod
3465 def get_serialized_fields(cls):
3466 """Stringified DAGs and operators contain exactly these fields."""
3467 if not cls.__serialized_fields:
3468 exclusion_list = {
3469 "parent_dag",
3470 "schedule_dataset_references",
3471 "task_outlet_dataset_references",
3472 "_old_context_manager_dags",
3473 "safe_dag_id",
3474 "last_loaded",
3475 "user_defined_filters",
3476 "user_defined_macros",
3477 "partial",
3478 "params",
3479 "_pickle_id",
3480 "_log",
3481 "task_dict",
3482 "template_searchpath",
3483 "sla_miss_callback",
3484 "on_success_callback",
3485 "on_failure_callback",
3486 "template_undefined",
3487 "jinja_environment_kwargs",
3488 # has_on_*_callback are only stored if the value is True, as the default is False
3489 "has_on_success_callback",
3490 "has_on_failure_callback",
3491 "auto_register",
3492 "fail_stop",
3493 }
3494 cls.__serialized_fields = frozenset(vars(DAG(dag_id="test"))) - exclusion_list
3495 return cls.__serialized_fields
3497 def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
3498 """Return edge information for the given pair of tasks or an empty edge if there is no information."""
3499 # Note - older serialized DAGs may not have edge_info being a dict at all
3500 empty = cast(EdgeInfoType, {})
3501 if self.edge_info:
3502 return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty)
3503 else:
3504 return empty
3506 def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType):
3507 """
3508 Set the given edge information on the DAG.
3510 Note that this will overwrite, rather than merge with, existing info.
3511 """
3512 self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info
3514 def validate_schedule_and_params(self):
3515 """
3516 Validate Param values when the DAG has schedule defined.
3518 Raise exception if there are any Params which can not be resolved by their schema definition.
3519 """
3520 if not self.timetable.can_be_scheduled:
3521 return
3523 try:
3524 self.params.validate()
3525 except ParamValidationError as pverr:
3526 raise AirflowException(
3527 "DAG is not allowed to define a Schedule, "
3528 "if there are any required params without default values or default values are not valid."
3529 ) from pverr
3531 def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]:
3532 """
3533 Parse a given link, and verifies if it's a valid URL, or a 'mailto' link.
3535 Returns an iterator of invalid (owner, link) pairs.
3536 """
3537 for owner, link in self.owner_links.items():
3538 result = urlsplit(link)
3539 if result.scheme == "mailto":
3540 # netloc is not existing for 'mailto' link, so we are checking that the path is parsed
3541 if not result.path:
3542 yield result.path, link
3543 elif not result.scheme or not result.netloc:
3544 yield owner, link
3547class DagTag(Base):
3548 """A tag name per dag, to allow quick filtering in the DAG view."""
3550 __tablename__ = "dag_tag"
3551 name = Column(String(TAG_MAX_LEN), primary_key=True)
3552 dag_id = Column(
3553 StringID(),
3554 ForeignKey("dag.dag_id", name="dag_tag_dag_id_fkey", ondelete="CASCADE"),
3555 primary_key=True,
3556 )
3558 __table_args__ = (Index("idx_dag_tag_dag_id", dag_id),)
3560 def __repr__(self):
3561 return self.name
3564class DagOwnerAttributes(Base):
3565 """
3566 Table defining different owner attributes.
3568 For example, a link for an owner that will be passed as a hyperlink to the "DAGs" view.
3569 """
3571 __tablename__ = "dag_owner_attributes"
3572 dag_id = Column(
3573 StringID(),
3574 ForeignKey("dag.dag_id", name="dag.dag_id", ondelete="CASCADE"),
3575 nullable=False,
3576 primary_key=True,
3577 )
3578 owner = Column(String(500), primary_key=True, nullable=False)
3579 link = Column(String(500), nullable=False)
3581 def __repr__(self):
3582 return f"<DagOwnerAttributes: dag_id={self.dag_id}, owner={self.owner}, link={self.link}>"
3584 @classmethod
3585 def get_all(cls, session) -> dict[str, dict[str, str]]:
3586 dag_links: dict = defaultdict(dict)
3587 for obj in session.scalars(select(cls)):
3588 dag_links[obj.dag_id].update({obj.owner: obj.link})
3589 return dag_links
3592class DagModel(Base):
3593 """Table containing DAG properties."""
3595 __tablename__ = "dag"
3596 """
3597 These items are stored in the database for state related information
3598 """
3599 dag_id = Column(StringID(), primary_key=True)
3600 root_dag_id = Column(StringID())
3601 # A DAG can be paused from the UI / DB
3602 # Set this default value of is_paused based on a configuration value!
3603 is_paused_at_creation = airflow_conf.getboolean("core", "dags_are_paused_at_creation")
3604 is_paused = Column(Boolean, default=is_paused_at_creation)
3605 # Whether the DAG is a subdag
3606 is_subdag = Column(Boolean, default=False)
3607 # Whether that DAG was seen on the last DagBag load
3608 is_active = Column(Boolean, default=False)
3609 # Last time the scheduler started
3610 last_parsed_time = Column(UtcDateTime)
3611 # Last time this DAG was pickled
3612 last_pickled = Column(UtcDateTime)
3613 # Time when the DAG last received a refresh signal
3614 # (e.g. the DAG's "refresh" button was clicked in the web UI)
3615 last_expired = Column(UtcDateTime)
3616 # Whether (one of) the scheduler is scheduling this DAG at the moment
3617 scheduler_lock = Column(Boolean)
3618 # Foreign key to the latest pickle_id
3619 pickle_id = Column(Integer)
3620 # The location of the file containing the DAG object
3621 # Note: Do not depend on fileloc pointing to a file; in the case of a
3622 # packaged DAG, it will point to the subpath of the DAG within the
3623 # associated zip.
3624 fileloc = Column(String(2000))
3625 # The base directory used by Dag Processor that parsed this dag.
3626 processor_subdir = Column(String(2000), nullable=True)
3627 # String representing the owners
3628 owners = Column(String(2000))
3629 # Display name of the dag
3630 _dag_display_property_value = Column("dag_display_name", String(2000), nullable=True)
3631 # Description of the dag
3632 description = Column(Text)
3633 # Default view of the DAG inside the webserver
3634 default_view = Column(String(25))
3635 # Schedule interval
3636 schedule_interval = Column(Interval)
3637 # Timetable/Schedule Interval description
3638 timetable_description = Column(String(1000), nullable=True)
3639 # Dataset expression based on dataset triggers
3640 dataset_expression = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
3641 # Tags for view filter
3642 tags = relationship("DagTag", cascade="all, delete, delete-orphan", backref=backref("dag"))
3643 # Dag owner links for DAGs view
3644 dag_owner_links = relationship(
3645 "DagOwnerAttributes", cascade="all, delete, delete-orphan", backref=backref("dag")
3646 )
3648 max_active_tasks = Column(Integer, nullable=False)
3649 max_active_runs = Column(Integer, nullable=True)
3650 max_consecutive_failed_dag_runs = Column(Integer, nullable=False)
3652 has_task_concurrency_limits = Column(Boolean, nullable=False)
3653 has_import_errors = Column(Boolean(), default=False, server_default="0")
3655 # The logical date of the next dag run.
3656 next_dagrun = Column(UtcDateTime)
3658 # Must be either both NULL or both datetime.
3659 next_dagrun_data_interval_start = Column(UtcDateTime)
3660 next_dagrun_data_interval_end = Column(UtcDateTime)
3662 # Earliest time at which this ``next_dagrun`` can be created.
3663 next_dagrun_create_after = Column(UtcDateTime)
3665 __table_args__ = (
3666 Index("idx_root_dag_id", root_dag_id, unique=False),
3667 Index("idx_next_dagrun_create_after", next_dagrun_create_after, unique=False),
3668 )
3670 parent_dag = relationship(
3671 "DagModel", remote_side=[dag_id], primaryjoin=root_dag_id == dag_id, foreign_keys=[root_dag_id]
3672 )
3673 schedule_dataset_references = relationship(
3674 "DagScheduleDatasetReference",
3675 back_populates="dag",
3676 cascade="all, delete, delete-orphan",
3677 )
3678 schedule_datasets = association_proxy("schedule_dataset_references", "dataset")
3679 task_outlet_dataset_references = relationship(
3680 "TaskOutletDatasetReference",
3681 cascade="all, delete, delete-orphan",
3682 )
3683 NUM_DAGS_PER_DAGRUN_QUERY = airflow_conf.getint(
3684 "scheduler", "max_dagruns_to_create_per_loop", fallback=10
3685 )
3687 def __init__(self, concurrency=None, **kwargs):
3688 super().__init__(**kwargs)
3689 if self.max_active_tasks is None:
3690 if concurrency:
3691 warnings.warn(
3692 "The 'DagModel.concurrency' parameter is deprecated. Please use 'max_active_tasks'.",
3693 RemovedInAirflow3Warning,
3694 stacklevel=2,
3695 )
3696 self.max_active_tasks = concurrency
3697 else:
3698 self.max_active_tasks = airflow_conf.getint("core", "max_active_tasks_per_dag")
3700 if self.max_active_runs is None:
3701 self.max_active_runs = airflow_conf.getint("core", "max_active_runs_per_dag")
3703 if self.max_consecutive_failed_dag_runs is None:
3704 self.max_consecutive_failed_dag_runs = airflow_conf.getint(
3705 "core", "max_consecutive_failed_dag_runs_per_dag"
3706 )
3708 if self.has_task_concurrency_limits is None:
3709 # Be safe -- this will be updated later once the DAG is parsed
3710 self.has_task_concurrency_limits = True
3712 def __repr__(self):
3713 return f"<DAG: {self.dag_id}>"
3715 @property
3716 def next_dagrun_data_interval(self) -> DataInterval | None:
3717 return _get_model_data_interval(
3718 self,
3719 "next_dagrun_data_interval_start",
3720 "next_dagrun_data_interval_end",
3721 )
3723 @next_dagrun_data_interval.setter
3724 def next_dagrun_data_interval(self, value: tuple[datetime, datetime] | None) -> None:
3725 if value is None:
3726 self.next_dagrun_data_interval_start = self.next_dagrun_data_interval_end = None
3727 else:
3728 self.next_dagrun_data_interval_start, self.next_dagrun_data_interval_end = value
3730 @property
3731 def timezone(self):
3732 return settings.TIMEZONE
3734 @staticmethod
3735 @provide_session
3736 def get_dagmodel(dag_id: str, session: Session = NEW_SESSION) -> DagModel | None:
3737 return session.get(
3738 DagModel,
3739 dag_id,
3740 options=[joinedload(DagModel.parent_dag)],
3741 )
3743 @classmethod
3744 @internal_api_call
3745 @provide_session
3746 def get_current(cls, dag_id: str, session=NEW_SESSION) -> DagModel | DagModelPydantic:
3747 return session.scalar(select(cls).where(cls.dag_id == dag_id))
3749 @provide_session
3750 def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False):
3751 return get_last_dagrun(
3752 self.dag_id, session=session, include_externally_triggered=include_externally_triggered
3753 )
3755 def get_is_paused(self, *, session: Session | None = None) -> bool:
3756 """Provide interface compatibility to 'DAG'."""
3757 return self.is_paused
3759 def get_is_active(self, *, session: Session | None = None) -> bool:
3760 """Provide interface compatibility to 'DAG'."""
3761 return self.is_active
3763 @staticmethod
3764 @internal_api_call
3765 @provide_session
3766 def get_paused_dag_ids(dag_ids: list[str], session: Session = NEW_SESSION) -> set[str]:
3767 """
3768 Given a list of dag_ids, get a set of Paused Dag Ids.
3770 :param dag_ids: List of Dag ids
3771 :param session: ORM Session
3772 :return: Paused Dag_ids
3773 """
3774 paused_dag_ids = session.execute(
3775 select(DagModel.dag_id)
3776 .where(DagModel.is_paused == expression.true())
3777 .where(DagModel.dag_id.in_(dag_ids))
3778 )
3780 paused_dag_ids = {paused_dag_id for (paused_dag_id,) in paused_dag_ids}
3781 return paused_dag_ids
3783 def get_default_view(self) -> str:
3784 """Get the Default DAG View, returns the default config value if DagModel does not have a value."""
3785 # This is for backwards-compatibility with old dags that don't have None as default_view
3786 return self.default_view or airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower()
3788 @property
3789 def safe_dag_id(self):
3790 return self.dag_id.replace(".", "__dot__")
3792 @property
3793 def relative_fileloc(self) -> pathlib.Path | None:
3794 """File location of the importable dag 'file' relative to the configured DAGs folder."""
3795 if self.fileloc is None:
3796 return None
3797 path = pathlib.Path(self.fileloc)
3798 try:
3799 return path.relative_to(settings.DAGS_FOLDER)
3800 except ValueError:
3801 # Not relative to DAGS_FOLDER.
3802 return path
3804 @provide_session
3805 def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=NEW_SESSION) -> None:
3806 """
3807 Pause/Un-pause a DAG.
3809 :param is_paused: Is the DAG paused
3810 :param including_subdags: whether to include the DAG's subdags
3811 :param session: session
3812 """
3813 filter_query = [
3814 DagModel.dag_id == self.dag_id,
3815 ]
3816 if including_subdags:
3817 filter_query.append(DagModel.root_dag_id == self.dag_id)
3818 session.execute(
3819 update(DagModel)
3820 .where(or_(*filter_query))
3821 .values(is_paused=is_paused)
3822 .execution_options(synchronize_session="fetch")
3823 )
3824 session.commit()
3826 @hybrid_property
3827 def dag_display_name(self) -> str:
3828 return self._dag_display_property_value or self.dag_id
3830 @classmethod
3831 @internal_api_call
3832 @provide_session
3833 def deactivate_deleted_dags(
3834 cls,
3835 alive_dag_filelocs: Container[str],
3836 processor_subdir: str,
3837 session: Session = NEW_SESSION,
3838 ) -> None:
3839 """
3840 Set ``is_active=False`` on the DAGs for which the DAG files have been removed.
3842 :param alive_dag_filelocs: file paths of alive DAGs
3843 :param processor_subdir: dag processor subdir
3844 :param session: ORM Session
3845 """
3846 log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__)
3847 dag_models = session.scalars(
3848 select(cls).where(
3849 cls.fileloc.is_not(None),
3850 or_(
3851 cls.processor_subdir.is_(None),
3852 cls.processor_subdir == processor_subdir,
3853 ),
3854 )
3855 )
3857 for dag_model in dag_models:
3858 if dag_model.fileloc not in alive_dag_filelocs:
3859 dag_model.is_active = False
3861 @classmethod
3862 def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]:
3863 """
3864 Return (and lock) a list of Dag objects that are due to create a new DagRun.
3866 This will return a resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query,
3867 you should ensure that any scheduling decisions are made in a single transaction -- as soon as the
3868 transaction is committed it will be unlocked.
3869 """
3870 from airflow.models.serialized_dag import SerializedDagModel
3872 def dag_ready(dag_id: str, cond: BaseDataset, statuses: dict) -> bool | None:
3873 # if dag was serialized before 2.9 and we *just* upgraded,
3874 # we may be dealing with old version. In that case,
3875 # just wait for the dag to be reserialized.
3876 try:
3877 return cond.evaluate(statuses)
3878 except AttributeError:
3879 log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
3880 return None
3882 # this loads all the DDRQ records.... may need to limit num dags
3883 all_records = session.scalars(select(DatasetDagRunQueue)).all()
3884 by_dag = defaultdict(list)
3885 for r in all_records:
3886 by_dag[r.target_dag_id].append(r)
3887 del all_records
3888 dag_statuses = {}
3889 for dag_id, records in by_dag.items():
3890 dag_statuses[dag_id] = {x.dataset.uri: True for x in records}
3891 ser_dags = session.scalars(
3892 select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
3893 ).all()
3894 for ser_dag in ser_dags:
3895 dag_id = ser_dag.dag_id
3896 statuses = dag_statuses[dag_id]
3897 if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses):
3898 del by_dag[dag_id]
3899 del dag_statuses[dag_id]
3900 del dag_statuses
3901 dataset_triggered_dag_info = {}
3902 for dag_id, records in by_dag.items():
3903 times = sorted(x.created_at for x in records)
3904 dataset_triggered_dag_info[dag_id] = (times[0], times[-1])
3905 del by_dag
3906 dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
3907 if dataset_triggered_dag_ids:
3908 exclusion_list = set(
3909 session.scalars(
3910 select(DagModel.dag_id)
3911 .join(DagRun.dag_model)
3912 .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
3913 .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
3914 .group_by(DagModel.dag_id)
3915 .having(func.count() >= func.max(DagModel.max_active_runs))
3916 )
3917 )
3918 if exclusion_list:
3919 dataset_triggered_dag_ids -= exclusion_list
3920 dataset_triggered_dag_info = {
3921 k: v for k, v in dataset_triggered_dag_info.items() if k not in exclusion_list
3922 }
3924 # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs
3925 query = (
3926 select(cls)
3927 .where(
3928 cls.is_paused == expression.false(),
3929 cls.is_active == expression.true(),
3930 cls.has_import_errors == expression.false(),
3931 or_(
3932 cls.next_dagrun_create_after <= func.now(),
3933 cls.dag_id.in_(dataset_triggered_dag_ids),
3934 ),
3935 )
3936 .order_by(cls.next_dagrun_create_after)
3937 .limit(cls.NUM_DAGS_PER_DAGRUN_QUERY)
3938 )
3940 return (
3941 session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)),
3942 dataset_triggered_dag_info,
3943 )
3945 def calculate_dagrun_date_fields(
3946 self,
3947 dag: DAG,
3948 last_automated_dag_run: None | datetime | DataInterval,
3949 ) -> None:
3950 """
3951 Calculate ``next_dagrun`` and `next_dagrun_create_after``.
3953 :param dag: The DAG object
3954 :param last_automated_dag_run: DataInterval (or datetime) of most recent run of this dag, or none
3955 if not yet scheduled.
3956 """
3957 last_automated_data_interval: DataInterval | None
3958 if isinstance(last_automated_dag_run, datetime):
3959 warnings.warn(
3960 "Passing a datetime to `DagModel.calculate_dagrun_date_fields` is deprecated. "
3961 "Provide a data interval instead.",
3962 RemovedInAirflow3Warning,
3963 stacklevel=2,
3964 )
3965 last_automated_data_interval = dag.infer_automated_data_interval(last_automated_dag_run)
3966 else:
3967 last_automated_data_interval = last_automated_dag_run
3968 next_dagrun_info = dag.next_dagrun_info(last_automated_data_interval)
3969 if next_dagrun_info is None:
3970 self.next_dagrun_data_interval = self.next_dagrun = self.next_dagrun_create_after = None
3971 else:
3972 self.next_dagrun_data_interval = next_dagrun_info.data_interval
3973 self.next_dagrun = next_dagrun_info.logical_date
3974 self.next_dagrun_create_after = next_dagrun_info.run_after
3976 log.info(
3977 "Setting next_dagrun for %s to %s, run_after=%s",
3978 dag.dag_id,
3979 self.next_dagrun,
3980 self.next_dagrun_create_after,
3981 )
3983 @provide_session
3984 def get_dataset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, int | str] | None:
3985 if self.schedule_interval != "Dataset":
3986 return None
3987 return get_dataset_triggered_next_run_info([self.dag_id], session=session)[self.dag_id]
3990# NOTE: Please keep the list of arguments in sync with DAG.__init__.
3991# Only exception: dag_id here should have a default value, but not in DAG.
3992def dag(
3993 dag_id: str = "",
3994 description: str | None = None,
3995 schedule: ScheduleArg = NOTSET,
3996 schedule_interval: ScheduleIntervalArg = NOTSET,
3997 timetable: Timetable | None = None,
3998 start_date: datetime | None = None,
3999 end_date: datetime | None = None,
4000 full_filepath: str | None = None,
4001 template_searchpath: str | Iterable[str] | None = None,
4002 template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
4003 user_defined_macros: dict | None = None,
4004 user_defined_filters: dict | None = None,
4005 default_args: dict | None = None,
4006 concurrency: int | None = None,
4007 max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"),
4008 max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"),
4009 max_consecutive_failed_dag_runs: int = airflow_conf.getint(
4010 "core", "max_consecutive_failed_dag_runs_per_dag"
4011 ),
4012 dagrun_timeout: timedelta | None = None,
4013 sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None,
4014 default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(),
4015 orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"),
4016 catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"),
4017 on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
4018 on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
4019 doc_md: str | None = None,
4020 params: abc.MutableMapping | None = None,
4021 access_control: dict | None = None,
4022 is_paused_upon_creation: bool | None = None,
4023 jinja_environment_kwargs: dict | None = None,
4024 render_template_as_native_obj: bool = False,
4025 tags: list[str] | None = None,
4026 owner_links: dict[str, str] | None = None,
4027 auto_register: bool = True,
4028 fail_stop: bool = False,
4029 dag_display_name: str | None = None,
4030) -> Callable[[Callable], Callable[..., DAG]]:
4031 """
4032 Python dag decorator which wraps a function into an Airflow DAG.
4034 Accepts kwargs for operator kwarg. Can be used to parameterize DAGs.
4036 :param dag_args: Arguments for DAG object
4037 :param dag_kwargs: Kwargs for DAG object.
4038 """
4040 def wrapper(f: Callable) -> Callable[..., DAG]:
4041 @functools.wraps(f)
4042 def factory(*args, **kwargs):
4043 # Generate signature for decorated function and bind the arguments when called
4044 # we do this to extract parameters, so we can annotate them on the DAG object.
4045 # In addition, this fails if we are missing any args/kwargs with TypeError as expected.
4046 f_sig = signature(f).bind(*args, **kwargs)
4047 # Apply defaults to capture default values if set.
4048 f_sig.apply_defaults()
4050 # Initialize DAG with bound arguments
4051 with DAG(
4052 dag_id or f.__name__,
4053 description=description,
4054 schedule_interval=schedule_interval,
4055 timetable=timetable,
4056 start_date=start_date,
4057 end_date=end_date,
4058 full_filepath=full_filepath,
4059 template_searchpath=template_searchpath,
4060 template_undefined=template_undefined,
4061 user_defined_macros=user_defined_macros,
4062 user_defined_filters=user_defined_filters,
4063 default_args=default_args,
4064 concurrency=concurrency,
4065 max_active_tasks=max_active_tasks,
4066 max_active_runs=max_active_runs,
4067 max_consecutive_failed_dag_runs=max_consecutive_failed_dag_runs,
4068 dagrun_timeout=dagrun_timeout,
4069 sla_miss_callback=sla_miss_callback,
4070 default_view=default_view,
4071 orientation=orientation,
4072 catchup=catchup,
4073 on_success_callback=on_success_callback,
4074 on_failure_callback=on_failure_callback,
4075 doc_md=doc_md,
4076 params=params,
4077 access_control=access_control,
4078 is_paused_upon_creation=is_paused_upon_creation,
4079 jinja_environment_kwargs=jinja_environment_kwargs,
4080 render_template_as_native_obj=render_template_as_native_obj,
4081 tags=tags,
4082 schedule=schedule,
4083 owner_links=owner_links,
4084 auto_register=auto_register,
4085 fail_stop=fail_stop,
4086 dag_display_name=dag_display_name,
4087 ) as dag_obj:
4088 # Set DAG documentation from function documentation if it exists and doc_md is not set.
4089 if f.__doc__ and not dag_obj.doc_md:
4090 dag_obj.doc_md = f.__doc__
4092 # Generate DAGParam for each function arg/kwarg and replace it for calling the function.
4093 # All args/kwargs for function will be DAGParam object and replaced on execution time.
4094 f_kwargs = {}
4095 for name, value in f_sig.arguments.items():
4096 f_kwargs[name] = dag_obj.param(name, value)
4098 # set file location to caller source path
4099 back = sys._getframe().f_back
4100 dag_obj.fileloc = back.f_code.co_filename if back else ""
4102 # Invoke function to create operators in the DAG scope.
4103 f(**f_kwargs)
4105 # Return dag object such that it's accessible in Globals.
4106 return dag_obj
4108 # Ensure that warnings from inside DAG() are emitted from the caller, not here
4109 fixup_decorator_warning_stack(factory)
4110 return factory
4112 return wrapper
4115STATICA_HACK = True
4116globals()["kcah_acitats"[::-1].upper()] = False
4117if STATICA_HACK: # pragma: no cover
4118 from airflow.models.serialized_dag import SerializedDagModel
4120 DagModel.serialized_dag = relationship(SerializedDagModel)
4121 """:sphinx-autoapi-skip:"""
4124class DagContext:
4125 """
4126 DAG context is used to keep the current DAG when DAG is used as ContextManager.
4128 You can use DAG as context:
4130 .. code-block:: python
4132 with DAG(
4133 dag_id="example_dag",
4134 default_args=default_args,
4135 schedule="0 0 * * *",
4136 dagrun_timeout=timedelta(minutes=60),
4137 ) as dag:
4138 ...
4140 If you do this the context stores the DAG and whenever new task is created, it will use
4141 such stored DAG as the parent DAG.
4143 """
4145 _context_managed_dags: deque[DAG] = deque()
4146 autoregistered_dags: set[tuple[DAG, ModuleType]] = set()
4147 current_autoregister_module_name: str | None = None
4149 @classmethod
4150 def push_context_managed_dag(cls, dag: DAG):
4151 cls._context_managed_dags.appendleft(dag)
4153 @classmethod
4154 def pop_context_managed_dag(cls) -> DAG | None:
4155 dag = cls._context_managed_dags.popleft()
4157 # In a few cases around serialization we explicitly push None in to the stack
4158 if cls.current_autoregister_module_name is not None and dag and dag.auto_register:
4159 mod = sys.modules[cls.current_autoregister_module_name]
4160 cls.autoregistered_dags.add((dag, mod))
4162 return dag
4164 @classmethod
4165 def get_current_dag(cls) -> DAG | None:
4166 try:
4167 return cls._context_managed_dags[0]
4168 except IndexError:
4169 return None
4172def _run_inline_trigger(trigger):
4173 async def _run_inline_trigger_main():
4174 async for event in trigger.run():
4175 return event
4177 return asyncio.run(_run_inline_trigger_main())
4180def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session):
4181 """
4182 Run a single task instance, and push result to Xcom for downstream tasks.
4184 Bypasses a lot of extra steps used in `task.run` to keep our local running as fast as
4185 possible. This function is only meant for the `dag.test` function as a helper function.
4187 Args:
4188 ti: TaskInstance to run
4189 """
4190 log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index)
4191 while True:
4192 try:
4193 log.info("[DAG TEST] running task %s", ti)
4194 ti._run_raw_task(session=session, raise_on_defer=inline_trigger)
4195 break
4196 except TaskDeferred as e:
4197 log.info("[DAG TEST] running trigger in line")
4198 event = _run_inline_trigger(e.trigger)
4199 ti.next_method = e.method_name
4200 ti.next_kwargs = {"event": event.payload} if event else e.kwargs
4201 log.info("[DAG TEST] Trigger completed")
4202 session.merge(ti)
4203 session.commit()
4204 log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index)
4207def _get_or_create_dagrun(
4208 dag: DAG,
4209 conf: dict[Any, Any] | None,
4210 start_date: datetime,
4211 execution_date: datetime,
4212 run_id: str,
4213 session: Session,
4214 data_interval: tuple[datetime, datetime] | None = None,
4215) -> DagRun:
4216 """Create a DAG run, replacing an existing instance if needed to prevent collisions.
4218 This function is only meant to be used by :meth:`DAG.test` as a helper function.
4220 :param dag: DAG to be used to find run.
4221 :param conf: Configuration to pass to newly created run.
4222 :param start_date: Start date of new run.
4223 :param execution_date: Logical date for finding an existing run.
4224 :param run_id: Run ID for the new DAG run.
4226 :return: The newly created DAG run.
4227 """
4228 log.info("dagrun id: %s", dag.dag_id)
4229 dr: DagRun = session.scalar(
4230 select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
4231 )
4232 if dr:
4233 session.delete(dr)
4234 session.commit()
4235 dr = dag.create_dagrun(
4236 state=DagRunState.RUNNING,
4237 execution_date=execution_date,
4238 run_id=run_id,
4239 start_date=start_date or execution_date,
4240 session=session,
4241 conf=conf,
4242 data_interval=data_interval,
4243 )
4244 log.info("created dagrun %s", dr)
4245 return dr