1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import copy
21import functools
22import itertools
23import json
24import logging
25import os
26import sys
27import warnings
28import weakref
29from collections import abc, defaultdict, deque
30from collections.abc import Callable, Collection, Iterable, MutableSet
31from datetime import datetime, timedelta
32from inspect import signature
33from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, Union, cast, overload
34from urllib.parse import urlsplit
35from uuid import UUID
36
37import attrs
38import jinja2
39from dateutil.relativedelta import relativedelta
40
41from airflow import settings
42from airflow.sdk import TaskInstanceState, TriggerRule
43from airflow.sdk.bases.operator import BaseOperator
44from airflow.sdk.bases.timetable import BaseTimetable
45from airflow.sdk.definitions._internal.node import validate_key
46from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
47from airflow.sdk.definitions.asset import AssetAll, BaseAsset
48from airflow.sdk.definitions.context import Context
49from airflow.sdk.definitions.deadline import DeadlineAlert
50from airflow.sdk.definitions.param import DagParam, ParamsDict
51from airflow.sdk.definitions.timetables.assets import AssetTriggeredTimetable
52from airflow.sdk.definitions.timetables.simple import ContinuousTimetable, NullTimetable, OnceTimetable
53from airflow.sdk.exceptions import (
54 AirflowDagCycleException,
55 DuplicateTaskIdFound,
56 FailFastDagInvalidTriggerRule,
57 ParamValidationError,
58 RemovedInAirflow4Warning,
59 TaskNotFound,
60)
61
62if TYPE_CHECKING:
63 from re import Pattern
64 from typing import TypeAlias
65
66 from pendulum.tz.timezone import FixedTimezone, Timezone
67 from typing_extensions import Self, TypeIs
68
69 from airflow.models.taskinstance import TaskInstance as SchedulerTaskInstance
70 from airflow.sdk.definitions.decorators import TaskDecoratorCollection
71 from airflow.sdk.definitions.edges import EdgeInfoType
72 from airflow.sdk.definitions.mappedoperator import MappedOperator
73 from airflow.sdk.definitions.taskgroup import TaskGroup
74 from airflow.sdk.execution_time.supervisor import TaskRunResult
75 from airflow.timetables.base import DataInterval, Timetable as CoreTimetable
76
77 Operator: TypeAlias = BaseOperator | MappedOperator
78
79log = logging.getLogger(__name__)
80
81TAG_MAX_LEN = 100
82
83__all__ = [
84 "DAG",
85 "dag",
86]
87
88FINISHED_STATES = frozenset(
89 [
90 TaskInstanceState.SUCCESS,
91 TaskInstanceState.FAILED,
92 TaskInstanceState.SKIPPED,
93 TaskInstanceState.UPSTREAM_FAILED,
94 TaskInstanceState.REMOVED,
95 ]
96)
97
98DagStateChangeCallback = Callable[[Context], None]
99ScheduleInterval = None | str | timedelta | relativedelta
100
101ScheduleArg = Union[ScheduleInterval, BaseTimetable, "CoreTimetable", BaseAsset, Collection[BaseAsset]]
102
103
104_DAG_HASH_ATTRS = frozenset(
105 {
106 "dag_id",
107 "task_ids",
108 "start_date",
109 "end_date",
110 "fileloc",
111 "template_searchpath",
112 "last_loaded",
113 "schedule",
114 # TODO: Task-SDK: we should be hashing on timetable now, not schedule!
115 # "timetable",
116 }
117)
118
119
120def _is_core_timetable(schedule: ScheduleArg) -> TypeIs[CoreTimetable]:
121 try:
122 from airflow.timetables.base import Timetable
123 except ImportError:
124 return False
125 return isinstance(schedule, Timetable)
126
127
128def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTimezone) -> BaseTimetable:
129 """Create a Timetable instance from a plain ``schedule`` value."""
130 from airflow.sdk.configuration import conf as airflow_conf
131 from airflow.sdk.definitions.timetables.interval import (
132 CronDataIntervalTimetable,
133 DeltaDataIntervalTimetable,
134 )
135 from airflow.sdk.definitions.timetables.trigger import CronTriggerTimetable, DeltaTriggerTimetable
136
137 if interval is None:
138 return NullTimetable()
139 if interval == "@once":
140 return OnceTimetable()
141 if interval == "@continuous":
142 return ContinuousTimetable()
143 if isinstance(interval, timedelta | relativedelta):
        if airflow_conf.getboolean("scheduler", "create_delta_data_intervals"):
145 return DeltaDataIntervalTimetable(interval)
146 return DeltaTriggerTimetable(interval)
147 if isinstance(interval, str):
148 if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"):
149 return CronDataIntervalTimetable(interval, timezone)
150 return CronTriggerTimetable(interval, timezone=timezone)
151 raise ValueError(f"{interval!r} is not a valid schedule.")
152
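# Illustrative mapping of plain ``schedule`` values to timetables. The cron and
# delta branches depend on the ``create_cron_data_intervals`` /
# ``create_delta_data_intervals`` options; the trigger variants are shown here:
#
#   _create_timetable(None, tz)               -> NullTimetable()
#   _create_timetable("@once", tz)            -> OnceTimetable()
#   _create_timetable("0 0 * * *", tz)        -> CronTriggerTimetable("0 0 * * *", timezone=tz)
#   _create_timetable(timedelta(days=1), tz)  -> DeltaTriggerTimetable(timedelta(days=1))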
153
154def _config_bool_factory(section: str, key: str) -> Callable[[], bool]:
155 from airflow.sdk.configuration import conf
156
157 return functools.partial(conf.getboolean, section, key)
158
159
160def _config_int_factory(section: str, key: str) -> Callable[[], int]:
161 from airflow.sdk.configuration import conf
162
163 return functools.partial(conf.getint, section, key)
164
165
166def _convert_params(val: abc.MutableMapping | None, self_: DAG) -> ParamsDict:
167 """
168 Convert the plain dict into a ParamsDict.
169
170 This will also merge in params from default_args
171 """
172 val = val or {}
173
174 # merging potentially conflicting default_args['params'] into params
175 if "params" in self_.default_args:
176 val.update(self_.default_args["params"])
177 del self_.default_args["params"]
178
179 params = ParamsDict(val)
180 object.__setattr__(self_, "params", params)
181
182 return params
183
184
185def _convert_str_to_tuple(val: str | Iterable[str] | None) -> Iterable[str] | None:
186 if isinstance(val, str):
187 return (val,)
188 return val
189
190
191def _convert_tags(tags: Collection[str] | None) -> MutableSet[str]:
192 return set(tags or [])
193
194
195def _convert_access_control(access_control):
196 if access_control is None:
197 return None
198 updated_access_control = {}
199 for role, perms in access_control.items():
200 updated_access_control[role] = updated_access_control.get(role, {})
201 if isinstance(perms, set | list):
202 # Support for old-style access_control where only the actions are specified
203 updated_access_control[role]["DAGs"] = set(perms)
204 else:
205 updated_access_control[role] = perms
206 return updated_access_control
207
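# Illustrative conversion performed above (hypothetical role name): the old-style
#   {"viewer": {"can_read"}}
# becomes
#   {"viewer": {"DAGs": {"can_read"}}}
# while dicts already keyed by resource name are passed through unchanged.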
208
209def _convert_deadline(deadline: list[DeadlineAlert] | DeadlineAlert | None) -> list[DeadlineAlert] | None:
210 """Convert deadline parameter to a list of DeadlineAlert objects."""
211 if deadline is None:
212 return None
213 if isinstance(deadline, DeadlineAlert):
214 return [deadline]
215 return list(deadline)
216
217
218def _convert_doc_md(doc_md: str | None) -> str | None:
219 if doc_md is None:
220 return doc_md
221
222 if doc_md.endswith(".md"):
223 try:
224 with open(doc_md) as fh:
225 return fh.read()
226 except FileNotFoundError:
227 return doc_md
228
229 return doc_md
230
231
232def _all_after_dag_id_to_kw_only(cls, fields: list[attrs.Attribute]):
233 i = iter(fields)
234 f = next(i)
235 if f.name != "dag_id":
236 raise RuntimeError("dag_id was not the first field")
237 yield f
238
239 for f in i:
240 yield f.evolve(kw_only=True)
241
242
243if TYPE_CHECKING:
244 # Given this attrs field:
245 #
246 # default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.copy)
247 #
248 # mypy ignores the type of the attrs and works out the type as the converter function. However it doesn't
249 # cope with generics properly and errors with 'incompatible type "dict[str, object]"; expected "_T"'
250 #
251 # https://github.com/python/mypy/issues/8625
252 def dict_copy(_: dict[str, Any]) -> dict[str, Any]: ...
253else:
254 dict_copy = copy.copy
255
256
257def _default_start_date(instance: DAG):
258 # Find start date inside default_args for compat with Airflow 2.
259 from airflow.sdk import timezone
260
261 if date := instance.default_args.get("start_date"):
262 if not isinstance(date, datetime):
263 date = timezone.parse(date)
264 instance.default_args["start_date"] = date
265 return date
266 return None
267
268
269def _default_dag_display_name(instance: DAG) -> str:
270 return instance.dag_id
271
272
273def _default_fileloc() -> str:
274 # Skip over this frame, and the 'attrs generated init'
275 back = sys._getframe().f_back
276 if not back or not (back := back.f_back):
277 # We expect two frames back, if not we don't know where we are
278 return ""
279 return back.f_code.co_filename if back else ""
280
281
282def _default_task_group(instance: DAG) -> TaskGroup:
283 from airflow.sdk.definitions.taskgroup import TaskGroup
284
285 return TaskGroup.create_root(dag=instance)
286
287
288# TODO: Task-SDK: look at re-enabling slots after we remove pickling
289@attrs.define(repr=False, field_transformer=_all_after_dag_id_to_kw_only, slots=False)
290class DAG:
291 """
292 A dag is a collection of tasks with directional dependencies.
293
    A dag also has a schedule, a start date and an end date (optional). For each schedule
    (say daily or hourly), the DAG needs to run each individual task as its dependencies
296 are met. Certain tasks have the property of depending on their own past, meaning that
297 they can't run until their previous schedule (and upstream tasks) are completed.
298
299 Dags essentially act as namespaces for tasks. A task_id can only be
300 added once to a Dag.
301
    Note that if you plan to use time zones, all the dates provided should be pendulum
303 dates. See :ref:`timezone_aware_dags`.
304
305 .. versionadded:: 2.4
306 The *schedule* argument to specify either time-based scheduling logic
307 (timetable), or dataset-driven triggers.
308
309 .. versionchanged:: 3.0
310 The default value of *schedule* has been changed to *None* (no schedule).
311 The previous default was ``timedelta(days=1)``.
312
313 :param dag_id: The id of the DAG; must consist exclusively of alphanumeric
314 characters, dashes, dots and underscores (all ASCII)
    :param description: The description for the DAG, e.g. to be shown on the webserver
316 :param schedule: If provided, this defines the rules according to which DAG
317 runs are scheduled. Possible values include a cron expression string,
318 timedelta object, Timetable, or list of Asset objects.
319 See also :external:doc:`howto/timetable`.
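
        **Example** (illustrative dag ids and values; ``Asset`` is importable from ``airflow.sdk``)::

            DAG(dag_id="daily", schedule="0 0 * * *", start_date=datetime(2025, 1, 1))
            DAG(dag_id="every_six_hours", schedule=timedelta(hours=6), start_date=datetime(2025, 1, 1))
            DAG(dag_id="asset_driven", schedule=[Asset("s3://bucket/key")])
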
320 :param start_date: The timestamp from which the scheduler will
321 attempt to backfill. If this is not provided, backfilling must be done
322 manually with an explicit time range.
    :param end_date: A date beyond which your DAG won't run; leave as None
324 for open-ended scheduling.
325 :param template_searchpath: This list of folders (non-relative)
        defines where Jinja will look for your templates. Order matters.
        Note that Jinja/Airflow includes the path of your DAG file by
        default.
329 :param template_undefined: Template undefined type.
330 :param user_defined_macros: a dictionary of macros that will be exposed
331 in your jinja templates. For example, passing ``dict(foo='bar')``
332 to this argument allows you to ``{{ foo }}`` in all jinja
333 templates related to this DAG. Note that you can pass any
334 type of object here.
335 :param user_defined_filters: a dictionary of filters that will be exposed
336 in your jinja templates. For example, passing
337 ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows
338 you to ``{{ 'world' | hello }}`` in all jinja templates related to
339 this DAG.
340 :param default_args: A dictionary of default parameters to be used
341 as constructor keyword parameters when initialising operators.
        Note that operators accept the same arguments directly, and arguments
        passed to an operator take precedence over the values defined here:
        if your dict contains `'depends_on_past': True` here and an operator
        is called with `depends_on_past=False`, the actual value will be
        `False`. See the example below.
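
        **Example** (illustrative)::

            dag = DAG(dag_id="example", default_args={"retries": 3})
            # A task created in this dag without an explicit ``retries`` argument
            # inherits 3; a task instantiated with ``retries=0`` keeps 0.
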
346 :param params: a dictionary of DAG level parameters that are made
347 accessible in templates, namespaced under `params`. These
348 params can be overridden at the task level.
349 :param max_active_tasks: the number of task instances allowed to run
350 concurrently
    :param max_active_runs: maximum number of active DAG runs; beyond this
352 number of DAG runs in a running state, the scheduler won't create
353 new active DAG runs
354 :param max_consecutive_failed_dag_runs: (experimental) maximum number of consecutive failed DAG runs,
        beyond which the scheduler will disable the DAG
356 :param dagrun_timeout: Specify the duration a DagRun should be allowed to run before it times out or
357 fails. Task instances that are running when a DagRun is timed out will be marked as skipped.
358 :param sla_miss_callback: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with DeadlineAlerts in 3.1
359 :param deadline: An optional DeadlineAlert for the Dag.
360 :param catchup: Perform scheduler catchup (or only run latest)? Defaults to False
361 :param on_failure_callback: A function or list of functions to be called when a DagRun of this dag fails.
362 A context dictionary is passed as a single parameter to this function.
363 :param on_success_callback: Much like the ``on_failure_callback`` except
364 that it is executed when the dag succeeds.
365 :param access_control: Specify optional DAG-level actions, e.g.,
366 "{'role1': {'can_read'}, 'role2': {'can_read', 'can_edit', 'can_delete'}}"
367 or it can specify the resource name if there is a DAGs Run resource, e.g.,
368 "{'role1': {'DAG Runs': {'can_create'}}, 'role2': {'DAGs': {'can_read', 'can_edit', 'can_delete'}}"
369 :param is_paused_upon_creation: Specifies if the dag is paused when created for the first time.
370 If the dag exists already, this flag will be ignored. If this optional parameter
371 is not specified, the global config setting will be used.
372 :param jinja_environment_kwargs: additional configuration options to be passed to Jinja
373 ``Environment`` for template rendering
374
375 **Example**: to avoid Jinja from removing a trailing newline from template strings ::
376
377 DAG(
378 dag_id="my-dag",
379 jinja_environment_kwargs={
380 "keep_trailing_newline": True,
381 # some other jinja2 Environment options here
382 },
383 )
384
385 **See**: `Jinja Environment documentation
386 <https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.Environment>`_
387
388 :param render_template_as_native_obj: If True, uses a Jinja ``NativeEnvironment``
389 to render templates as native Python types. If False, a Jinja
390 ``Environment`` is used to render templates as string values.
    :param tags: List of tags to help filter Dags in the UI.
    :param owner_links: Dict of owners and their links, which will be clickable in the Dags view UI.
393 Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link.
394 e.g: ``{"dag_owner": "https://airflow.apache.org/"}``
395 :param auto_register: Automatically register this DAG when it is used in a ``with`` block
    :param fail_fast: Fails currently running tasks when a task in the Dag fails.
        **Warning**: A fail-fast dag can only have tasks with the default trigger rule ("all_success").
        An exception will be thrown if any task in a fail-fast dag has a non-default trigger rule.
399 :param dag_display_name: The display name of the Dag which appears on the UI.
400 """
401
402 __serialized_fields: ClassVar[frozenset[str]]
403
404 # Note: mypy gets very confused about the use of `@${attr}.default` for attrs without init=False -- and it
405 # doesn't correctly track/notice that they have default values (it gives errors about `Missing positional
406 # argument "description" in call to "DAG"`` etc), so for init=True args we use the `default=Factory()`
407 # style
408
409 def __rich_repr__(self):
410 yield "dag_id", self.dag_id
411 yield "schedule", self.schedule
412 yield "#tasks", len(self.tasks)
413
414 __rich_repr__.angular = True # type: ignore[attr-defined]
415
416 # NOTE: When updating arguments here, please also keep arguments in @dag()
417 # below in sync. (Search for 'def dag(' in this file.)
418 dag_id: str = attrs.field(kw_only=False, validator=lambda i, a, v: validate_key(v))
419 description: str | None = attrs.field(
420 default=None,
421 validator=attrs.validators.optional(attrs.validators.instance_of(str)),
422 )
423 default_args: dict[str, Any] = attrs.field(
424 factory=dict, validator=attrs.validators.instance_of(dict), converter=dict_copy
425 )
426 start_date: datetime | None = attrs.field(
427 default=attrs.Factory(_default_start_date, takes_self=True),
428 )
429
430 end_date: datetime | None = None
431 timezone: FixedTimezone | Timezone = attrs.field(init=False)
432 schedule: ScheduleArg = attrs.field(default=None, on_setattr=attrs.setters.frozen)
433 timetable: BaseTimetable | CoreTimetable = attrs.field(init=False)
434 template_searchpath: str | Iterable[str] | None = attrs.field(
435 default=None, converter=_convert_str_to_tuple
436 )
    # TODO: Task-SDK: Work out how to not import jinja2 until we need it! It's expensive
438 template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined
439 user_defined_macros: dict | None = None
440 user_defined_filters: dict | None = None
441 max_active_tasks: int = attrs.field(
442 factory=_config_int_factory("core", "max_active_tasks_per_dag"),
443 converter=attrs.converters.default_if_none( # type: ignore[misc]
444 # attrs only supports named callables or lambdas, but partial works
445 # OK here too. This is a false positive from attrs's Mypy plugin.
446 factory=_config_int_factory("core", "max_active_tasks_per_dag"),
447 ),
448 validator=attrs.validators.instance_of(int),
449 )
450 max_active_runs: int = attrs.field(
451 factory=_config_int_factory("core", "max_active_runs_per_dag"),
452 converter=attrs.converters.default_if_none( # type: ignore[misc]
453 # attrs only supports named callables or lambdas, but partial works
454 # OK here too. This is a false positive from attrs's Mypy plugin.
455 factory=_config_int_factory("core", "max_active_runs_per_dag"),
456 ),
457 validator=attrs.validators.instance_of(int),
458 )
459 max_consecutive_failed_dag_runs: int = attrs.field(
460 factory=_config_int_factory("core", "max_consecutive_failed_dag_runs_per_dag"),
461 converter=attrs.converters.default_if_none( # type: ignore[misc]
462 # attrs only supports named callables or lambdas, but partial works
463 # OK here too. This is a false positive from attrs's Mypy plugin.
464 factory=_config_int_factory("core", "max_consecutive_failed_dag_runs_per_dag"),
465 ),
466 validator=attrs.validators.instance_of(int),
467 )
468 dagrun_timeout: timedelta | None = attrs.field(
469 default=None,
470 validator=attrs.validators.optional(attrs.validators.instance_of(timedelta)),
471 )
472 deadline: list[DeadlineAlert] | DeadlineAlert | None = attrs.field(
473 default=None,
474 converter=_convert_deadline,
475 validator=attrs.validators.optional(
476 attrs.validators.deep_iterable(
477 member_validator=attrs.validators.instance_of(DeadlineAlert),
478 iterable_validator=attrs.validators.instance_of(list),
479 )
480 ),
481 )
482
483 sla_miss_callback: None = attrs.field(default=None)
484 catchup: bool = attrs.field(
485 factory=_config_bool_factory("scheduler", "catchup_by_default"),
486 )
487 on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
488 on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
489 doc_md: str | None = attrs.field(default=None, converter=_convert_doc_md)
490 params: ParamsDict = attrs.field(
491 # mypy doesn't really like passing the Converter object
492 default=None,
493 converter=attrs.Converter(_convert_params, takes_self=True), # type: ignore[misc, call-overload]
494 )
495 access_control: dict[str, dict[str, Collection[str]]] | None = attrs.field(
496 default=None,
497 converter=attrs.Converter(_convert_access_control), # type: ignore[misc, call-overload]
498 )
499 is_paused_upon_creation: bool | None = None
500 jinja_environment_kwargs: dict | None = None
501 render_template_as_native_obj: bool = attrs.field(default=False, converter=bool)
502 tags: MutableSet[str] = attrs.field(factory=set, converter=_convert_tags)
503 owner_links: dict[str, str] = attrs.field(factory=dict)
504 auto_register: bool = attrs.field(default=True, converter=bool)
505 fail_fast: bool = attrs.field(default=False, converter=bool)
506 dag_display_name: str = attrs.field(
507 default=attrs.Factory(_default_dag_display_name, takes_self=True),
508 validator=attrs.validators.instance_of(str),
509 )
510
511 task_dict: dict[str, Operator] = attrs.field(factory=dict, init=False)
512
513 task_group: TaskGroup = attrs.field(
514 on_setattr=attrs.setters.frozen, default=attrs.Factory(_default_task_group, takes_self=True)
515 )
516
517 fileloc: str = attrs.field(init=False, factory=_default_fileloc)
518 relative_fileloc: str | None = attrs.field(init=False, default=None)
519 partial: bool = attrs.field(init=False, default=False)
520
521 edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(init=False, factory=dict)
522
523 has_on_success_callback: bool = attrs.field(init=False)
524 has_on_failure_callback: bool = attrs.field(init=False)
525 disable_bundle_versioning: bool = attrs.field(
526 factory=_config_bool_factory("dag_processor", "disable_bundle_versioning")
527 )
528
529 # TODO (GH-52141): This is never used in the sdk dag (it only makes sense
530 # after this goes through the dag processor), but various parts of the code
531 # depends on its existence. We should remove this after completely splitting
532 # DAG classes in the SDK and scheduler.
533 last_loaded: datetime | None = attrs.field(init=False, default=None)
534
535 def __attrs_post_init__(self):
536 from airflow.sdk import timezone
537
538 # Apply the timezone we settled on to start_date, end_date if it wasn't supplied
539 if isinstance(_start_date := self.default_args.get("start_date"), str):
540 self.default_args["start_date"] = timezone.parse(_start_date, timezone=self.timezone)
541 if isinstance(_end_date := self.default_args.get("end_date"), str):
542 self.default_args["end_date"] = timezone.parse(_end_date, timezone=self.timezone)
543
544 self.start_date = timezone.convert_to_utc(self.start_date)
545 self.end_date = timezone.convert_to_utc(self.end_date)
546 if start_date := self.default_args.get("start_date", None):
547 self.default_args["start_date"] = timezone.convert_to_utc(start_date)
548 if end_date := self.default_args.get("end_date", None):
549 self.default_args["end_date"] = timezone.convert_to_utc(end_date)
550 if self.access_control is not None:
551 warnings.warn(
552 "The airflow.security.permissions module is deprecated; please see https://airflow.apache.org/docs/apache-airflow/stable/security/deprecated_permissions.html",
553 RemovedInAirflow4Warning,
554 stacklevel=2,
555 )
556 if (
557 active_runs_limit := self.timetable.active_runs_limit
558 ) is not None and active_runs_limit < self.max_active_runs:
559 raise ValueError(
560 f"Invalid max_active_runs: {type(self.timetable)} "
561 f"requires max_active_runs <= {active_runs_limit}"
562 )
563
564 @params.validator
565 def _validate_params(self, _, params: ParamsDict):
566 """
        Validate Param values when the Dag has a schedule defined.

        Raise an exception if there are any Params which cannot be resolved by their schema definition.
570 """
571 if not self.timetable or not self.timetable.can_be_scheduled:
572 return
573
574 try:
575 params.validate()
576 except ParamValidationError as pverr:
577 raise ValueError(
578 f"Dag {self.dag_id!r} is not allowed to define a Schedule, "
579 "as there are required params without default values, or the default values are not valid."
580 ) from pverr
581
582 @catchup.validator
583 def _validate_catchup(self, _, catchup: bool):
584 requires_automatic_backfilling = self.timetable.can_be_scheduled and catchup
585 if requires_automatic_backfilling and not ("start_date" in self.default_args or self.start_date):
586 raise ValueError("start_date is required when catchup=True")
587
588 @tags.validator
589 def _validate_tags(self, _, tags: Collection[str]):
590 if tags and any(len(tag) > TAG_MAX_LEN for tag in tags):
591 raise ValueError(f"tag cannot be longer than {TAG_MAX_LEN} characters")
592
593 @max_active_runs.validator
594 def _validate_max_active_runs(self, _, max_active_runs):
595 if self.timetable.active_runs_limit is not None:
596 if self.timetable.active_runs_limit < self.max_active_runs:
597 raise ValueError(
598 f"Invalid max_active_runs: {type(self.timetable).__name__} "
599 f"requires max_active_runs <= {self.timetable.active_runs_limit}"
600 )
601
602 @timetable.default
603 def _default_timetable(instance: DAG) -> BaseTimetable | CoreTimetable:
604 schedule = instance.schedule
605 # TODO: Once
606 # delattr(self, "schedule")
607 if _is_core_timetable(schedule):
608 return schedule
609 if isinstance(schedule, BaseTimetable):
610 return schedule
611 if isinstance(schedule, BaseAsset):
612 return AssetTriggeredTimetable(schedule)
613 if isinstance(schedule, Collection) and not isinstance(schedule, str):
614 if not all(isinstance(x, BaseAsset) for x in schedule):
615 raise ValueError(
616 "All elements in 'schedule' should be either assets, asset references, or asset aliases"
617 )
618 return AssetTriggeredTimetable(AssetAll(*schedule))
619 return _create_timetable(schedule, instance.timezone)
620
621 @timezone.default
622 def _extract_tz(instance):
623 import pendulum
624
625 from airflow.sdk import timezone
626
627 start_date = instance.start_date or instance.default_args.get("start_date")
628
629 if start_date:
630 if not isinstance(start_date, datetime):
631 start_date = timezone.parse(start_date)
632 tzinfo = start_date.tzinfo or settings.TIMEZONE
633 tz = pendulum.instance(start_date, tz=tzinfo).timezone
634 else:
635 tz = settings.TIMEZONE
636
637 return tz
638
639 @has_on_success_callback.default
640 def _has_on_success_callback(self) -> bool:
641 return self.on_success_callback is not None
642
643 @has_on_failure_callback.default
644 def _has_on_failure_callback(self) -> bool:
645 return self.on_failure_callback is not None
646
647 @sla_miss_callback.validator
648 def _validate_sla_miss_callback(self, _, value):
649 if value is not None:
650 warnings.warn(
651 "The SLA feature is removed in Airflow 3.0, and replaced with a Deadline Alerts in >=3.1",
652 stacklevel=2,
653 )
654 return value
655
656 def __repr__(self):
657 return f"<DAG: {self.dag_id}>"
658
659 def __eq__(self, other: Self | Any):
660 # TODO: This subclassing behaviour seems wrong, but it's what Airflow has done for ~ever.
661 if type(self) is not type(other):
662 return False
663 return all(getattr(self, c, None) == getattr(other, c, None) for c in _DAG_HASH_ATTRS)
664
665 def __ne__(self, other: Any):
666 return not self == other
667
668 def __lt__(self, other):
669 return self.dag_id < other.dag_id
670
671 def __hash__(self):
672 hash_components: list[Any] = [type(self)]
673 for c in _DAG_HASH_ATTRS:
674 # If it is a list, convert to tuple because lists can't be hashed
675 if isinstance(getattr(self, c, None), list):
676 val = tuple(getattr(self, c))
677 else:
678 val = getattr(self, c, None)
679 try:
680 hash(val)
681 hash_components.append(val)
682 except TypeError:
683 hash_components.append(repr(val))
684 return hash(tuple(hash_components))
685
686 def __enter__(self) -> Self:
687 from airflow.sdk.definitions._internal.contextmanager import DagContext
688
689 DagContext.push(self)
690 return self
691
692 def __exit__(self, _type, _value, _tb):
693 from airflow.sdk.definitions._internal.contextmanager import DagContext
694
695 _ = DagContext.pop()
696
697 def validate(self):
698 """
699 Validate the Dag has a coherent setup.
700
701 This is called by the Dag bag before bagging the Dag.
702 """
703 self.timetable.validate()
704 self.validate_setup_teardown()
705
706 # We validate owner links on set, but since it's a dict it could be mutated without calling the
707 # setter. Validate again here
708 self._validate_owner_links(None, self.owner_links)
709
710 def validate_setup_teardown(self):
711 """
712 Validate that setup and teardown tasks are configured properly.
713
714 :meta private:
715 """
716 for task in self.tasks:
717 if task.is_setup:
718 for down_task in task.downstream_list:
719 if not down_task.is_teardown and down_task.trigger_rule != TriggerRule.ALL_SUCCESS:
720 # todo: we can relax this to allow out-of-scope tasks to have other trigger rules
721 # this is required to ensure consistent behavior of dag
722 # when clearing an indirect setup
723 raise ValueError("Setup tasks must be followed with trigger rule ALL_SUCCESS.")
724
725 def param(self, name: str, default: Any = NOTSET) -> DagParam:
726 """
727 Return a DagParam object for current dag.
728
729 :param name: dag parameter name.
730 :param default: fallback value for dag parameter.
731 :return: DagParam instance for specified name and current dag.
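
        **Example** (illustrative dag id and param name)::

            with DAG(dag_id="example", params={"threshold": 10}) as dag:
                threshold = dag.param("threshold")
                # Pass ``threshold`` to an operator argument; it resolves to the
                # run-supplied value (or the default 10) at runtime.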
732 """
733 return DagParam(current_dag=self, name=name, default=default)
734
735 @property
736 def tasks(self) -> list[Operator]:
737 return list(self.task_dict.values())
738
739 @tasks.setter
740 def tasks(self, val):
741 raise AttributeError("DAG.tasks can not be modified. Use dag.add_task() instead.")
742
743 @property
744 def task_ids(self) -> list[str]:
745 return list(self.task_dict)
746
747 @property
748 def teardowns(self) -> list[Operator]:
749 return [task for task in self.tasks if getattr(task, "is_teardown", None)]
750
751 @property
752 def tasks_upstream_of_teardowns(self) -> list[Operator]:
753 upstream_tasks = [t.upstream_list for t in self.teardowns]
754 return [val for sublist in upstream_tasks for val in sublist if not getattr(val, "is_teardown", None)]
755
756 @property
757 def folder(self) -> str:
758 """Folder location of where the Dag object is instantiated."""
759 return os.path.dirname(self.fileloc)
760
761 @property
762 def owner(self) -> str:
763 """
        Return all owners found in Dag tasks.

        :return: Comma-separated string of owners in Dag tasks
767 """
768 return ", ".join({t.owner for t in self.tasks})
769
770 def resolve_template_files(self):
771 for t in self.tasks:
772 # TODO: TaskSDK: move this on to BaseOperator and remove the check?
773 if hasattr(t, "resolve_template_files"):
774 t.resolve_template_files()
775
776 def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment:
777 """Build a Jinja2 environment."""
778 from airflow.sdk.definitions._internal.templater import NativeEnvironment, SandboxedEnvironment
779
780 # Collect directories to search for template files
781 searchpath = [self.folder]
782 if self.template_searchpath:
783 searchpath += self.template_searchpath
784
785 # Default values (for backward compatibility)
786 jinja_env_options = {
787 "loader": jinja2.FileSystemLoader(searchpath),
788 "undefined": self.template_undefined,
789 "extensions": ["jinja2.ext.do"],
790 "cache_size": 0,
791 }
792 if self.jinja_environment_kwargs:
793 jinja_env_options.update(self.jinja_environment_kwargs)
794 env: jinja2.Environment
795 if self.render_template_as_native_obj and not force_sandboxed:
796 env = NativeEnvironment(**jinja_env_options)
797 else:
798 env = SandboxedEnvironment(**jinja_env_options)
799
800 # Add any user defined items. Safe to edit globals as long as no templates are rendered yet.
801 # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals
802 if self.user_defined_macros:
803 env.globals.update(self.user_defined_macros)
804 if self.user_defined_filters:
805 env.filters.update(self.user_defined_filters)
806
807 return env
808
809 def set_dependency(self, upstream_task_id, downstream_task_id):
810 """Set dependency between two tasks that already have been added to the Dag using add_task()."""
811 self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id))
812
813 @property
814 def roots(self) -> list[Operator]:
815 """Return nodes with no parents. These are first to execute and are called roots or root nodes."""
816 return [task for task in self.tasks if not task.upstream_list]
817
818 @property
819 def leaves(self) -> list[Operator]:
820 """Return nodes with no children. These are last to execute and are called leaves or leaf nodes."""
821 return [task for task in self.tasks if not task.downstream_list]
822
823 def topological_sort(self):
824 """
        Sort tasks in topological order, such that a task comes after any of its upstream dependencies.

        Deprecated in favour of ``task_group.topological_sort``.
828 """
829 from airflow.sdk.definitions.taskgroup import TaskGroup
830
831 # TODO: Remove in RemovedInAirflow3Warning
832 def nested_topo(group):
833 for node in group.topological_sort():
834 if isinstance(node, TaskGroup):
835 yield from nested_topo(node)
836 else:
837 yield node
838
839 return tuple(nested_topo(self.task_group))
840
841 def __deepcopy__(self, memo: dict[int, Any]):
842 # Switcharoo to go around deepcopying objects coming through the
843 # backdoor
844 cls = self.__class__
845 result = cls.__new__(cls)
846 memo[id(self)] = result
847 for k, v in self.__dict__.items():
848 if k not in ("user_defined_macros", "user_defined_filters", "_log"):
849 object.__setattr__(result, k, copy.deepcopy(v, memo))
850
851 result.user_defined_macros = self.user_defined_macros
852 result.user_defined_filters = self.user_defined_filters
853 if hasattr(self, "_log"):
854 result._log = self._log # type: ignore[attr-defined]
855 return result
856
857 def partial_subset(
858 self,
859 task_ids: str | Iterable[str],
860 include_downstream=False,
861 include_upstream=True,
862 include_direct_upstream=False,
863 depth: int | None = None,
864 ):
865 """
        Return a subset of the current dag based on the given task ids.

        The returned dag is a deep copy of the current dag containing the
        matched tasks, plus upstream and downstream neighbours depending on the
        flags passed. A string ``task_ids`` matches every task whose task_id
        contains it as a substring; a collection of task_ids is matched
        exactly. See the example below.

        :param task_ids: Either a collection of task_ids, or a single string task_id
873 :param include_downstream: Include all downstream tasks of matched
874 tasks, in addition to matched tasks.
875 :param include_upstream: Include all upstream tasks of matched tasks,
876 in addition to matched tasks.
877 :param include_direct_upstream: Include all tasks directly upstream of matched
878 and downstream (if include_downstream = True) tasks
879 :param depth: Maximum number of levels to traverse in the upstream/downstream
880 direction. If None, traverses all levels. Must be non-negative.
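
        **Example** (illustrative task ids)::

            # Deep copy containing every task whose task_id contains "extract",
            # plus all of their upstream tasks.
            subset = dag.partial_subset("extract", include_upstream=True)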
881 """
882 from airflow.sdk.definitions.mappedoperator import MappedOperator
883
884 def is_task(obj) -> TypeGuard[Operator]:
885 return isinstance(obj, BaseOperator | MappedOperator)
886
887 # deep-copying self.task_dict and self.task_group takes a long time, and we don't want all
888 # the tasks anyway, so we copy the tasks manually later
889 memo = {id(self.task_dict): None, id(self.task_group): None}
890 dag = copy.deepcopy(self, memo)
891
892 if isinstance(task_ids, str):
893 matched_tasks = [t for t in self.tasks if task_ids in t.task_id]
894 else:
895 matched_tasks = [t for t in self.tasks if t.task_id in task_ids]
896
897 also_include_ids: set[str] = set()
898 for t in matched_tasks:
899 if include_downstream:
900 for rel in t.get_flat_relatives(upstream=False, depth=depth):
901 also_include_ids.add(rel.task_id)
902 if rel not in matched_tasks: # if it's in there, we're already processing it
903 # need to include setups and teardowns for tasks that are in multiple
904 # non-collinear setup/teardown paths
905 if not rel.is_setup and not rel.is_teardown:
906 also_include_ids.update(
907 x.task_id for x in rel.get_upstreams_only_setups_and_teardowns()
908 )
909 if include_upstream:
910 also_include_ids.update(x.task_id for x in t.get_upstreams_follow_setups(depth=depth))
911 else:
912 if not t.is_setup and not t.is_teardown:
913 also_include_ids.update(x.task_id for x in t.get_upstreams_only_setups_and_teardowns())
914 if t.is_setup and not include_downstream:
915 also_include_ids.update(x.task_id for x in t.downstream_list if x.is_teardown)
916
917 also_include: list[Operator] = [self.task_dict[x] for x in also_include_ids]
918 direct_upstreams: list[Operator] = []
919 if include_direct_upstream:
920 for t in itertools.chain(matched_tasks, also_include):
921 direct_upstreams.extend(u for u in t.upstream_list if is_task(u))
922
923 # Make sure to not recursively deepcopy the dag or task_group while copying the task.
924 # task_group is reset later
925 def _deepcopy_task(t) -> Operator:
926 memo.setdefault(id(t.task_group), None)
927 return copy.deepcopy(t, memo)
928
929 # Compiling the unique list of tasks that made the cut
930 dag.task_dict = {
931 t.task_id: _deepcopy_task(t)
932 for t in itertools.chain(matched_tasks, also_include, direct_upstreams)
933 }
934
935 def filter_task_group(group, parent_group):
936 """Exclude tasks not included in the partial dag from the given TaskGroup."""
937 # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy
938 # and then manually deep copy the instances. (memo argument to deepcopy only works for instances
939 # of classes, not "native" properties of an instance)
940 copied = copy.copy(group)
941
942 memo[id(group.children)] = {}
943 if parent_group:
944 memo[id(group.parent_group)] = parent_group
945 for attr in type(group).__slots__:
946 value = getattr(group, attr)
947 value = copy.deepcopy(value, memo)
948 object.__setattr__(copied, attr, value)
949
950 proxy = weakref.proxy(copied)
951
952 for child in group.children.values():
953 if is_task(child):
954 if child.task_id in dag.task_dict:
955 task = copied.children[child.task_id] = dag.task_dict[child.task_id]
956 task.task_group = proxy
957 else:
958 copied.used_group_ids.discard(child.task_id)
959 else:
960 filtered_child = filter_task_group(child, proxy)
961
962 # Only include this child TaskGroup if it is non-empty.
963 if filtered_child.children:
964 copied.children[child.group_id] = filtered_child
965
966 return copied
967
968 object.__setattr__(dag, "task_group", filter_task_group(self.task_group, None))
969
970 # Removing upstream/downstream references to tasks and TaskGroups that did not make
971 # the cut.
972 groups = dag.task_group.get_task_group_dict()
973 for g in groups.values():
974 g.upstream_group_ids.intersection_update(groups)
975 g.downstream_group_ids.intersection_update(groups)
976 g.upstream_task_ids.intersection_update(dag.task_dict)
977 g.downstream_task_ids.intersection_update(dag.task_dict)
978
979 for t in dag.tasks:
980 # Removing upstream/downstream references to tasks that did not
981 # make the cut
982 t.upstream_task_ids.intersection_update(dag.task_dict)
983 t.downstream_task_ids.intersection_update(dag.task_dict)
984
985 dag.partial = len(dag.tasks) < len(self.tasks)
986
987 return dag
988
989 def has_task(self, task_id: str):
990 return task_id in self.task_dict
991
992 def has_task_group(self, task_group_id: str) -> bool:
993 return task_group_id in self.task_group_dict
994
995 @functools.cached_property
996 def task_group_dict(self):
997 return {k: v for k, v in self.task_group.get_task_group_dict().items() if k is not None}
998
999 def get_task(self, task_id: str) -> Operator:
1000 if task_id in self.task_dict:
1001 return self.task_dict[task_id]
1002 raise TaskNotFound(f"Task {task_id} not found")
1003
1004 @property
1005 def task(self) -> TaskDecoratorCollection:
1006 from airflow.sdk.definitions.decorators import task
1007
1008 return cast("TaskDecoratorCollection", functools.partial(task, dag=self))
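
    # Illustrative use of the ``task`` property above as a decorator (``dag`` is a
    # DAG instance; the task name is made up):
    #
    #   @dag.task
    #   def extract():
    #       return {"rows": 42}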
1009
1010 def add_task(self, task: Operator) -> None:
1011 """
1012 Add a task to the Dag.
1013
1014 :param task: the task you want to add
1015 """
1016 # FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)
1017
1018 from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext
1019
1020 # if the task has no start date, assign it the same as the Dag
1021 if not task.start_date:
1022 task.start_date = self.start_date
1023 # otherwise, the task will start on the later of its own start date and
1024 # the Dag's start date
1025 elif self.start_date:
1026 task.start_date = max(task.start_date, self.start_date)
1027
1028 # if the task has no end date, assign it the same as the dag
1029 if not task.end_date:
1030 task.end_date = self.end_date
1031 # otherwise, the task will end on the earlier of its own end date and
1032 # the Dag's end date
1033 elif task.end_date and self.end_date:
1034 task.end_date = min(task.end_date, self.end_date)
1035
1036 task_id = task.node_id
1037 if not task.task_group:
1038 task_group = TaskGroupContext.get_current(self)
1039 if task_group:
1040 task_id = task_group.child_id(task_id)
1041 task_group.add(task)
1042
1043 if (
1044 task_id in self.task_dict and self.task_dict[task_id] is not task
1045 ) or task_id in self.task_group.used_group_ids:
1046 raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG")
1047 self.task_dict[task_id] = task
1048
1049 task.dag = self
1050 # Add task_id to used_group_ids to prevent group_id and task_id collisions.
1051 self.task_group.used_group_ids.add(task_id)
1052
1053 FailFastDagInvalidTriggerRule.check(fail_fast=self.fail_fast, trigger_rule=task.trigger_rule)
1054
1055 def add_tasks(self, tasks: Iterable[Operator]) -> None:
1056 """
1057 Add a list of tasks to the Dag.
1058
        :param tasks: a list of tasks you want to add
1060 """
1061 for task in tasks:
1062 self.add_task(task)
1063
1064 def _remove_task(self, task_id: str) -> None:
1065 # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this
1066 # doesn't guard against that
1067 task = self.task_dict.pop(task_id)
1068 tg = getattr(task, "task_group", None)
1069 if tg:
1070 tg._remove(task)
1071
1072 def check_cycle(self) -> None:
1073 """
1074 Check to see if there are any cycles in the Dag.
1075
1076 :raises AirflowDagCycleException: If cycle is found in the Dag.
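
        **Example** (illustrative; ``t1``-``t3`` are tasks already added to the dag)::

            t1 >> t2 >> t3 >> t1  # introduces a cycle
            dag.check_cycle()  # raises AirflowDagCycleException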
1077 """
1078 # default of int is 0 which corresponds to CYCLE_NEW
1079 CYCLE_NEW = 0
1080 CYCLE_IN_PROGRESS = 1
1081 CYCLE_DONE = 2
1082
1083 visited: dict[str, int] = defaultdict(int)
1084 path_stack: deque[str] = deque()
1085 task_dict = self.task_dict
1086
1087 def _check_adjacent_tasks(task_id, current_task):
1088 """Return first untraversed child task, else None if all tasks traversed."""
1089 for adjacent_task in current_task.get_direct_relative_ids():
1090 if visited[adjacent_task] == CYCLE_IN_PROGRESS:
1091 msg = f"Cycle detected in Dag: {self.dag_id}. Faulty task: {task_id}"
1092 raise AirflowDagCycleException(msg)
1093 if visited[adjacent_task] == CYCLE_NEW:
1094 return adjacent_task
1095 return None
1096
1097 for dag_task_id in self.task_dict.keys():
1098 if visited[dag_task_id] == CYCLE_DONE:
1099 continue
1100 path_stack.append(dag_task_id)
1101 while path_stack:
1102 current_task_id = path_stack[-1]
1103 if visited[current_task_id] == CYCLE_NEW:
1104 visited[current_task_id] = CYCLE_IN_PROGRESS
1105 task = task_dict[current_task_id]
1106 child_to_check = _check_adjacent_tasks(current_task_id, task)
1107 if not child_to_check:
1108 visited[current_task_id] = CYCLE_DONE
1109 path_stack.pop()
1110 else:
1111 path_stack.append(child_to_check)
1112
1113 def cli(self):
1114 """Exposes a CLI specific to this Dag."""
1115 self.check_cycle()
1116
1117 from airflow.cli import cli_parser
1118
1119 parser = cli_parser.get_parser(dag_parser=True)
1120 args = parser.parse_args()
1121 args.func(args, self)
1122
1123 @classmethod
1124 def get_serialized_fields(cls):
1125 """Stringified Dags and operators contain exactly these fields."""
1126 return cls.__serialized_fields
1127
1128 def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
1129 """Return edge information for the given pair of tasks or an empty edge if there is no information."""
1130 empty = cast("EdgeInfoType", {})
1131 if self.edge_info:
1132 return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty)
1133 return empty
1134
1135 def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType):
1136 """
1137 Set the given edge information on the Dag.
1138
1139 Note that this will overwrite, rather than merge with, existing info.
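
        **Example** (illustrative task ids and label)::

            dag.set_edge_info("extract", "load", {"label": "rows to load"})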
1140 """
1141 self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info
1142
1143 @owner_links.validator
1144 def _validate_owner_links(self, _, owner_links):
1145 wrong_links = {}
1146
1147 for owner, link in owner_links.items():
1148 result = urlsplit(link)
1149 if result.scheme == "mailto":
                # netloc does not exist for 'mailto' links, so we check that the path was parsed
1151 if not result.path:
1152 wrong_links[result.path] = link
1153 elif not result.scheme or not result.netloc:
1154 wrong_links[owner] = link
1155 if wrong_links:
1156 raise ValueError(
1157 "Wrong link format was used for the owner. Use a valid link \n"
1158 f"Bad formatted links are: {wrong_links}"
1159 )
1160
1161 def test(
1162 self,
1163 run_after: datetime | None = None,
1164 logical_date: datetime | None | ArgNotSet = NOTSET,
1165 run_conf: dict[str, Any] | None = None,
1166 conn_file_path: str | None = None,
1167 variable_file_path: str | None = None,
1168 use_executor: bool = False,
1169 mark_success_pattern: Pattern | str | None = None,
1170 ):
1171 """
1172 Execute one single DagRun for a given Dag and logical date.
1173
        :param run_after: the datetime before which the Dag cannot run.
1175 :param logical_date: logical date for the Dag run
1176 :param run_conf: configuration to pass to newly created dagrun
1177 :param conn_file_path: file path to a connection file in either yaml or json
1178 :param variable_file_path: file path to a variable file in either yaml or json
1179 :param use_executor: if set, uses an executor to test the Dag
1180 :param mark_success_pattern: regex of task_ids to mark as success instead of running
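
        **Example** (illustrative; typically placed at the bottom of a dag file)::

            if __name__ == "__main__":
                dag.test(run_conf={"my_param": "some_value"})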
1181 """
1182 import re
1183 import time
1184 from contextlib import ExitStack
1185 from unittest.mock import patch
1186
1187 from airflow import settings
1188 from airflow.models.dagrun import DagRun, get_or_create_dagrun
1189 from airflow.sdk import DagRunState, timezone
1190 from airflow.serialization.definitions.dag import SerializedDAG
1191 from airflow.serialization.encoders import coerce_to_core_timetable
1192 from airflow.serialization.serialized_objects import DagSerialization
1193 from airflow.utils.types import DagRunTriggeredByType, DagRunType
1194
1195 exit_stack = ExitStack()
1196
1197 if conn_file_path or variable_file_path:
1198 backend_kwargs = {}
1199 if conn_file_path:
1200 backend_kwargs["connections_file_path"] = conn_file_path
1201 if variable_file_path:
1202 backend_kwargs["variables_file_path"] = variable_file_path
1203
1204 exit_stack.enter_context(
1205 patch.dict(
1206 os.environ,
1207 {
1208 "AIRFLOW__SECRETS__BACKEND": "airflow.secrets.local_filesystem.LocalFilesystemBackend",
1209 "AIRFLOW__SECRETS__BACKEND_KWARGS": json.dumps(backend_kwargs),
1210 },
1211 )
1212 )
1213
1214 if settings.Session is None:
1215 raise RuntimeError("Session not configured. Call configure_orm() first.")
1216 session = settings.Session()
1217
1218 with exit_stack:
1219 self.validate()
1220 scheduler_dag = DagSerialization.deserialize_dag(DagSerialization.serialize_dag(self))
1221
1222 # Allow users to explicitly pass None. If it isn't set, we default to current time.
1223 logical_date = logical_date if is_arg_set(logical_date) else timezone.utcnow()
1224
1225 log.debug("Clearing existing task instances for logical date %s", logical_date)
1226 # TODO: Replace with calling client.dag_run.clear in Execution API at some point
1227 SerializedDAG.clear_dags(
1228 dags=[scheduler_dag],
1229 start_date=logical_date,
1230 end_date=logical_date,
1231 dag_run_state=False,
1232 )
1233
1234 log.debug("Getting dagrun for dag %s", self.dag_id)
1235 logical_date = timezone.coerce_datetime(logical_date)
1236 run_after = timezone.coerce_datetime(run_after) or timezone.coerce_datetime(timezone.utcnow())
1237 if logical_date is None:
1238 data_interval: DataInterval | None = None
1239 else:
1240 timetable = coerce_to_core_timetable(self.timetable)
1241 data_interval = timetable.infer_manual_data_interval(run_after=logical_date)
1242 from airflow.models.dag_version import DagVersion
1243
1244 version = DagVersion.get_version(self.dag_id)
1245 if not version:
1246 from airflow.dag_processing.bundles.manager import DagBundlesManager
1247 from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db
1248 from airflow.sdk.definitions._internal.dag_parsing_context import (
1249 _airflow_parsing_context_manager,
1250 )
1251
1252 manager = DagBundlesManager()
1253 manager.sync_bundles_to_db(session=session)
1254 session.commit()
1255 # sync all bundles? or use the dags-folder bundle?
1256 # What if the test dag is in a different bundle?
1257 for bundle in manager.get_all_dag_bundles():
1258 if not bundle.is_initialized:
1259 bundle.initialize()
1260 with _airflow_parsing_context_manager(dag_id=self.dag_id):
1261 dagbag = DagBag(
1262 dag_folder=bundle.path, bundle_path=bundle.path, include_examples=False
1263 )
1264 sync_bag_to_db(dagbag, bundle.name, bundle.version)
1265 version = DagVersion.get_version(self.dag_id)
1266 if version:
1267 break
1268
            # Preserve callback functions from the original Dag since they're lost during serialization.
            # Yes, it is a hack for now! It is a tradeoff for code simplicity: without it, we would need
            # the "Scheduler Dag" (serialized dag) for the scheduler bits -- dep checks, scheduling TIs --
            # and the real dag to get and run callbacks without having to load the dag model.
1274
1275 # Scheduler DAG shouldn't have these attributes, but assigning them
1276 # here is an easy hack to get this test() thing working.
1277 scheduler_dag.on_success_callback = self.on_success_callback # type: ignore[attr-defined, union-attr]
1278 scheduler_dag.on_failure_callback = self.on_failure_callback # type: ignore[attr-defined, union-attr]
1279
1280 dr: DagRun = get_or_create_dagrun(
1281 dag=scheduler_dag,
1282 start_date=logical_date or run_after,
1283 logical_date=logical_date,
1284 data_interval=data_interval,
1285 run_after=run_after,
1286 run_id=DagRun.generate_run_id(
1287 run_type=DagRunType.MANUAL,
1288 logical_date=logical_date,
1289 run_after=run_after,
1290 ),
1291 session=session,
1292 conf=run_conf,
1293 triggered_by=DagRunTriggeredByType.TEST,
1294 triggering_user_name="dag_test",
1295 )
1296 # Start a mock span so that one is present and not started downstream. We
1297 # don't care about otel in dag.test and starting the span during dagrun update
1298 # is not functioning properly in this context anyway.
1299 dr.start_dr_spans_if_needed(tis=[])
1300
1301 log.debug("starting dagrun")
1302 # Instead of starting a scheduler, we run the minimal loop possible to check
1303 # for task readiness and dependency management.
1306
1307 # ``Dag.test()`` works in two different modes depending on ``use_executor``:
1308 # - if ``use_executor`` is False, runs the task locally with no executor using ``_run_task``
1309 # - if ``use_executor`` is True, sends workloads to the executor with
1310 # ``BaseExecutor.queue_workload``
1311 if use_executor:
1312 from airflow.executors.base_executor import ExecutorLoader
1313
1314 executor = ExecutorLoader.get_default_executor()
1315 executor.start()
1316
1317 while dr.state == DagRunState.RUNNING:
1318 session.expire_all()
1319 schedulable_tis, _ = dr.update_state(session=session)
1320 for s in schedulable_tis:
1321 if s.state != TaskInstanceState.UP_FOR_RESCHEDULE:
1322 s.try_number += 1
1323 s.state = TaskInstanceState.SCHEDULED
1324 s.scheduled_dttm = timezone.utcnow()
1325 session.commit()
1326 # triggerer may mark tasks scheduled so we read from DB
1327 all_tis = set(dr.get_task_instances(session=session))
1328 scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED}
1329 ids_unrunnable = {x for x in all_tis if x.state not in FINISHED_STATES} - scheduled_tis
1330 if not scheduled_tis and ids_unrunnable:
1331 log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable)
1332 time.sleep(1)
1333
1334 for ti in scheduled_tis:
1335 task = self.task_dict[ti.task_id]
1336
1337 mark_success = (
1338 re.compile(mark_success_pattern).fullmatch(ti.task_id) is not None
1339 if mark_success_pattern is not None
1340 else False
1341 )
1342
1343 if use_executor:
1344 if executor.has_task(ti):
1345 continue
1346
1347 from pathlib import Path
1348
1349 from airflow.executors import workloads
1350 from airflow.executors.base_executor import ExecutorLoader
1351 from airflow.executors.workloads import BundleInfo
1352
1353 workload = workloads.ExecuteTask.make(
1354 ti,
1355 dag_rel_path=Path(self.fileloc),
1356 generator=executor.jwt_generator,
1357 sentry_integration=executor.sentry_integration,
1358 # For the system test/debug purpose, we use the default bundle which uses
1359 # local file system. If it turns out to be a feature people want, we could
1360 # plumb the Bundle to use as a parameter to dag.test
1361 bundle_info=BundleInfo(name="dags-folder"),
1362 )
1363 executor.queue_workload(workload, session=session)
1364 ti.state = TaskInstanceState.QUEUED
1365 session.commit()
1366 else:
1367 # Run the task locally
1368 try:
1369 if mark_success:
1370 ti.set_state(TaskInstanceState.SUCCESS)
1371 log.info("[DAG TEST] Marking success for %s on %s", task, ti.logical_date)
1372 else:
1373 _run_task(ti=ti, task=task, run_triggerer=True)
1374 except Exception:
1375 log.exception("Task failed; ti=%s", ti)
1376 if use_executor:
1377 executor.heartbeat()
1378 session.expire_all()
1379
1380 from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
1381 from airflow.models.dagbag import DBDagBag
1382
1383 SchedulerJobRunner.process_executor_events(
1384 executor=executor, job_id=None, scheduler_dag_bag=DBDagBag(), session=session
1385 )
1386 if use_executor:
1387 executor.end()
1388 return dr
1389
1390
1391def _run_task(
1392 *,
1393 ti: SchedulerTaskInstance,
1394 task: Operator,
1395 run_triggerer: bool = False,
1396) -> TaskRunResult | None:
1397 """
1398 Run a single task instance, and push result to Xcom for downstream tasks.
1399
    Bypasses a lot of extra steps used in `task.run` to keep local runs as fast as
    possible. This function is only meant as a helper for the `dag.test` function.
1402 """
1403 from airflow.sdk._shared.module_loading import import_string
1404 from airflow.sdk.serde import deserialize, serialize
1405 from airflow.utils.session import create_session
1406
1407 taskrun_result: TaskRunResult | None
1408 log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index)
1409 while True:
1410 try:
1411 log.info("[DAG TEST] running task %s", ti)
1412
1413 from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceSDK
1414 from airflow.sdk.execution_time.comms import DeferTask
1415 from airflow.sdk.execution_time.supervisor import run_task_in_process
1416 from airflow.serialization.serialized_objects import create_scheduler_operator
1417
1418 # The API Server expects the task instance to be in QUEUED state before
1419 # it is run.
1420 ti.set_state(TaskInstanceState.QUEUED)
            task_sdk_ti = TaskInstanceSDK(
                id=UUID(str(ti.id)),
                task_id=ti.task_id,
                dag_id=ti.dag_id,
                run_id=ti.run_id,
                try_number=ti.try_number,
                map_index=ti.map_index,
                dag_version_id=UUID(str(ti.dag_version_id)),
            )

            taskrun_result = run_task_in_process(ti=task_sdk_ti, task=task)
            msg = taskrun_result.msg
            ti.set_state(taskrun_result.ti.state)
            ti.task = create_scheduler_operator(taskrun_result.ti.task)

            if ti.state == TaskInstanceState.DEFERRED and isinstance(msg, DeferTask) and run_triggerer:
                # The API Server expects the task instance to be in QUEUED state before
                # resuming from deferral.
                ti.set_state(TaskInstanceState.QUEUED)

                log.info("[DAG TEST] running trigger inline")
                # trigger_kwargs need to be deserialized before being passed to the
                # trigger class since they are in serde-encoded format.
                # The ignore is needed to convince mypy that trigger_kwargs is a dict
                # or a str, because it is unable to infer JsonValue.
                kwargs = deserialize(msg.trigger_kwargs)  # type: ignore[type-var]
                if TYPE_CHECKING:
                    assert isinstance(kwargs, dict)
                trigger = import_string(msg.classpath)(**kwargs)
                event = _run_inline_trigger(trigger, task_sdk_ti)
                ti.next_method = msg.next_method
                ti.next_kwargs = {"event": serialize(event.payload)} if event else msg.next_kwargs
                log.info("[DAG TEST] Trigger completed")

                # Set the state to SCHEDULED so that the task can be resumed.
                with create_session() as session:
                    ti.state = TaskInstanceState.SCHEDULED
                    session.add(ti)
                continue

            break
        except Exception:
            log.exception("[DAG TEST] Error running task %s", ti)
            if ti.state not in FINISHED_STATES:
                ti.set_state(TaskInstanceState.FAILED)
                taskrun_result = None
                break
            raise

    log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index)
    return taskrun_result


def _run_inline_trigger(trigger, task_sdk_ti):
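    """Run ``trigger`` to completion in-process via the test supervisor and return the resulting event, if any."""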
    from airflow.sdk.execution_time.supervisor import InProcessTestSupervisor

    return InProcessTestSupervisor.run_trigger_in_process(trigger=trigger, ti=task_sdk_ti)


# Since we define all the attributes of the class with attrs, we can compute this statically at parse time
DAG._DAG__serialized_fields = frozenset(a.name for a in attrs.fields(DAG)) - {  # type: ignore[attr-defined]
    "schedule_asset_references",
    "schedule_asset_alias_references",
    "task_outlet_asset_references",
    "_old_context_manager_dags",
    "safe_dag_id",
    "last_loaded",
    "user_defined_filters",
    "user_defined_macros",
    "partial",
    "params",
    "_log",
    "task_dict",
    "template_searchpath",
    "sla_miss_callback",
    "on_success_callback",
    "on_failure_callback",
    "template_undefined",
    "jinja_environment_kwargs",
    # has_on_*_callback are only stored if the value is True, as the default is False
    "has_on_success_callback",
    "has_on_failure_callback",
    "auto_register",
    "schedule",
}

if TYPE_CHECKING:
    # NOTE: Please keep the list of arguments in sync with DAG.__init__.
    # Only exception: dag_id here should have a default value, but not in DAG.
    @overload
    def dag(
        dag_id: str = "",
        *,
        description: str | None = None,
        schedule: ScheduleArg = None,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        template_searchpath: str | Iterable[str] | None = None,
        template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
        user_defined_macros: dict | None = None,
        user_defined_filters: dict | None = None,
        default_args: dict[str, Any] | None = None,
        max_active_tasks: int = ...,
        max_active_runs: int = ...,
        max_consecutive_failed_dag_runs: int = ...,
        dagrun_timeout: timedelta | None = None,
        catchup: bool = ...,
        on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
        on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
        deadline: list[DeadlineAlert] | DeadlineAlert | None = None,
        doc_md: str | None = None,
        params: ParamsDict | dict[str, Any] | None = None,
        access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None,
        is_paused_upon_creation: bool | None = None,
        jinja_environment_kwargs: dict | None = None,
        render_template_as_native_obj: bool = False,
        tags: Collection[str] | None = None,
        owner_links: dict[str, str] | None = None,
        auto_register: bool = True,
        fail_fast: bool = False,
        dag_display_name: str | None = None,
        disable_bundle_versioning: bool = False,
    ) -> Callable[[Callable], Callable[..., DAG]]:
        """
        Python dag decorator that wraps a function into an Airflow Dag.

        Accepts the same keyword arguments as the ``DAG`` constructor and can be used to
        parameterize Dags.

        :param dag_args: Arguments for the DAG object.
        :param dag_kwargs: Keyword arguments for the DAG object.
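
        Example (a minimal sketch; it assumes ``dag`` and ``task`` are importable from
        ``airflow.sdk``, and the dag and task names below are made up for illustration)::

            from airflow.sdk import dag, task

            @dag(schedule=None, catchup=False)
            def example_pipeline():
                @task
                def hello():
                    print("hello")

                hello()

            example_pipeline()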
        """

    @overload
    def dag(func: Callable[..., DAG]) -> Callable[..., DAG]:
        """Python dag decorator to use without any arguments."""


def dag(dag_id_or_func=None, __DAG_class=DAG, __warnings_stacklevel_delta=2, **decorator_kwargs):
    from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack

    # TODO: Task-SDK: remove __DAG_class
    # __DAG_class is a temporary hack to allow the dag decorator in airflow.models.dag to continue to
    # return SchedulerDag objects
    DAG = __DAG_class

    def wrapper(f: Callable) -> Callable[..., DAG]:
        # Determine dag_id: prefer an explicit keyword argument, then a positional string,
        # and finally fall back to the decorated function's name.
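        # For illustration (names are made up): ``@dag`` uses the function name,
        # ``@dag("my_dag")`` and ``@dag(dag_id="my_dag")`` both use "my_dag", with the
        # keyword form taking precedence if both are given.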
        if "dag_id" in decorator_kwargs:
            dag_id = decorator_kwargs.pop("dag_id", "")
        elif isinstance(dag_id_or_func, str) and dag_id_or_func.strip():
            dag_id = dag_id_or_func
        else:
            dag_id = f.__name__

        @functools.wraps(f)
        def factory(*args, **kwargs):
            # Generate the signature of the decorated function and bind the arguments when called.
            # We do this to extract the parameters so we can annotate them on the DAG object.
            # In addition, this fails with a TypeError, as expected, if any args/kwargs are missing.
            f_sig = signature(f).bind(*args, **kwargs)
            # Apply defaults to capture default values if set.
            f_sig.apply_defaults()

            # Initialize Dag with bound arguments
            with DAG(dag_id, **decorator_kwargs) as dag_obj:
                # Set Dag documentation from function documentation if it exists and doc_md is not set.
                if f.__doc__ and not dag_obj.doc_md:
                    dag_obj.doc_md = f.__doc__

                # Generate a DagParam for each function arg/kwarg and use it when calling the function.
                # All args/kwargs of the function become DagParam objects that are resolved at execution time.
                f_kwargs = {}
                for name, value in f_sig.arguments.items():
                    f_kwargs[name] = dag_obj.param(name, value)

                # set file location to caller source path
                back = sys._getframe().f_back
                dag_obj.fileloc = back.f_code.co_filename if back else ""

                # Invoke function to create operators in the Dag scope.
                f(**f_kwargs)

            # Return dag object such that it's accessible in Globals.
            return dag_obj

        # Ensure that warnings from inside DAG() are emitted from the caller, not here
        fixup_decorator_warning_stack(factory)
        return factory

    if callable(dag_id_or_func) and not isinstance(dag_id_or_func, str):
        return wrapper(dag_id_or_func)

    return wrapper