Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/dag.py: 26%


745 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import copy 

21import functools 

22import itertools 

23import json 

24import logging 

25import os 

26import sys 

27import warnings 

28import weakref 

29from collections import abc, defaultdict, deque 

30from collections.abc import Callable, Collection, Iterable, MutableSet 

31from datetime import datetime, timedelta 

32from inspect import signature 

33from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, Union, cast, overload 

34from urllib.parse import urlsplit 

35from uuid import UUID 

36 

37import attrs 

38import jinja2 

39from dateutil.relativedelta import relativedelta 

40 

41from airflow import settings 

42from airflow.sdk import TaskInstanceState, TriggerRule 

43from airflow.sdk.bases.operator import BaseOperator 

44from airflow.sdk.bases.timetable import BaseTimetable 

45from airflow.sdk.definitions._internal.node import validate_key 

46from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set 

47from airflow.sdk.definitions.asset import AssetAll, BaseAsset 

48from airflow.sdk.definitions.context import Context 

49from airflow.sdk.definitions.deadline import DeadlineAlert 

50from airflow.sdk.definitions.param import DagParam, ParamsDict 

51from airflow.sdk.definitions.timetables.assets import AssetTriggeredTimetable 

52from airflow.sdk.definitions.timetables.simple import ContinuousTimetable, NullTimetable, OnceTimetable 

53from airflow.sdk.exceptions import ( 

54 AirflowDagCycleException, 

55 DuplicateTaskIdFound, 

56 FailFastDagInvalidTriggerRule, 

57 ParamValidationError, 

58 RemovedInAirflow4Warning, 

59 TaskNotFound, 

60) 

61 

62if TYPE_CHECKING: 

63 from re import Pattern 

64 from typing import TypeAlias 

65 

66 from pendulum.tz.timezone import FixedTimezone, Timezone 

67 from typing_extensions import Self, TypeIs 

68 

69 from airflow.models.taskinstance import TaskInstance as SchedulerTaskInstance 

70 from airflow.sdk.definitions.decorators import TaskDecoratorCollection 

71 from airflow.sdk.definitions.edges import EdgeInfoType 

72 from airflow.sdk.definitions.mappedoperator import MappedOperator 

73 from airflow.sdk.definitions.taskgroup import TaskGroup 

74 from airflow.sdk.execution_time.supervisor import TaskRunResult 

75 from airflow.timetables.base import DataInterval, Timetable as CoreTimetable 

76 

77 Operator: TypeAlias = BaseOperator | MappedOperator 

78 

79log = logging.getLogger(__name__) 

80 

81TAG_MAX_LEN = 100 

82 

83__all__ = [ 

84 "DAG", 

85 "dag", 

86] 

87 

88FINISHED_STATES = frozenset( 

89 [ 

90 TaskInstanceState.SUCCESS, 

91 TaskInstanceState.FAILED, 

92 TaskInstanceState.SKIPPED, 

93 TaskInstanceState.UPSTREAM_FAILED, 

94 TaskInstanceState.REMOVED, 

95 ] 

96) 

97 

98DagStateChangeCallback = Callable[[Context], None] 

99ScheduleInterval = None | str | timedelta | relativedelta 

100 

101ScheduleArg = Union[ScheduleInterval, BaseTimetable, "CoreTimetable", BaseAsset, Collection[BaseAsset]] 

102 

103 

104_DAG_HASH_ATTRS = frozenset( 

105 { 

106 "dag_id", 

107 "task_ids", 

108 "start_date", 

109 "end_date", 

110 "fileloc", 

111 "template_searchpath", 

112 "last_loaded", 

113 "schedule", 

114 # TODO: Task-SDK: we should be hashing on timetable now, not schedule! 

115 # "timetable", 

116 } 

117) 

118 

119 

120def _is_core_timetable(schedule: ScheduleArg) -> TypeIs[CoreTimetable]: 

121 try: 

122 from airflow.timetables.base import Timetable 

123 except ImportError: 

124 return False 

125 return isinstance(schedule, Timetable) 

126 

127 

128def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTimezone) -> BaseTimetable: 

129 """Create a Timetable instance from a plain ``schedule`` value.""" 

130 from airflow.sdk.configuration import conf as airflow_conf 

131 from airflow.sdk.definitions.timetables.interval import ( 

132 CronDataIntervalTimetable, 

133 DeltaDataIntervalTimetable, 

134 ) 

135 from airflow.sdk.definitions.timetables.trigger import CronTriggerTimetable, DeltaTriggerTimetable 

136 

137 if interval is None: 

138 return NullTimetable() 

139 if interval == "@once": 

140 return OnceTimetable() 

141 if interval == "@continuous": 

142 return ContinuousTimetable() 

143 if isinstance(interval, timedelta | relativedelta): 

144 if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"): 

145 return DeltaDataIntervalTimetable(interval) 

146 return DeltaTriggerTimetable(interval) 

147 if isinstance(interval, str): 

148 if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"): 

149 return CronDataIntervalTimetable(interval, timezone) 

150 return CronTriggerTimetable(interval, timezone=timezone) 

151 raise ValueError(f"{interval!r} is not a valid schedule.") 
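
# Illustrative note (not part of the original module): depending on the
# ``[scheduler] create_cron_data_intervals`` setting, plain ``schedule`` values map as:
#   None              -> NullTimetable()
#   "@once"           -> OnceTimetable()
#   "@continuous"     -> ContinuousTimetable()
#   timedelta(days=1) -> DeltaDataIntervalTimetable(...) or DeltaTriggerTimetable(...)
#   "0 0 * * *"       -> CronDataIntervalTimetable(..., timezone) or CronTriggerTimetable(..., timezone=...)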

152 

153 

154def _config_bool_factory(section: str, key: str) -> Callable[[], bool]: 

155 from airflow.sdk.configuration import conf 

156 

157 return functools.partial(conf.getboolean, section, key) 

158 

159 

160def _config_int_factory(section: str, key: str) -> Callable[[], int]: 

161 from airflow.sdk.configuration import conf 

162 

163 return functools.partial(conf.getint, section, key) 

164 

165 

166def _convert_params(val: abc.MutableMapping | None, self_: DAG) -> ParamsDict: 

167 """ 

168 Convert the plain dict into a ParamsDict. 

169 

170 This will also merge in params from default_args 

171 """ 

172 val = val or {} 

173 

174 # merging potentially conflicting default_args['params'] into params 

175 if "params" in self_.default_args: 

176 val.update(self_.default_args["params"]) 

177 del self_.default_args["params"] 

178 

179 params = ParamsDict(val) 

180 object.__setattr__(self_, "params", params) 

181 

182 return params 

183 

184 

185def _convert_str_to_tuple(val: str | Iterable[str] | None) -> Iterable[str] | None: 

186 if isinstance(val, str): 

187 return (val,) 

188 return val 

189 

190 

191def _convert_tags(tags: Collection[str] | None) -> MutableSet[str]: 

192 return set(tags or []) 

193 

194 

195def _convert_access_control(access_control): 

196 if access_control is None: 

197 return None 

198 updated_access_control = {} 

199 for role, perms in access_control.items(): 

200 updated_access_control[role] = updated_access_control.get(role, {}) 

201 if isinstance(perms, set | list): 

202 # Support for old-style access_control where only the actions are specified 

203 updated_access_control[role]["DAGs"] = set(perms) 

204 else: 

205 updated_access_control[role] = perms 

206 return updated_access_control 

207 

208 

209def _convert_deadline(deadline: list[DeadlineAlert] | DeadlineAlert | None) -> list[DeadlineAlert] | None: 

210 """Convert deadline parameter to a list of DeadlineAlert objects.""" 

211 if deadline is None: 

212 return None 

213 if isinstance(deadline, DeadlineAlert): 

214 return [deadline] 

215 return list(deadline) 

216 

217 

218def _convert_doc_md(doc_md: str | None) -> str | None: 

219 if doc_md is None: 

220 return doc_md 

221 

222 if doc_md.endswith(".md"): 

223 try: 

224 with open(doc_md) as fh: 

225 return fh.read() 

226 except FileNotFoundError: 

227 return doc_md 

228 

229 return doc_md 

230 

231 

232def _all_after_dag_id_to_kw_only(cls, fields: list[attrs.Attribute]): 

233 i = iter(fields) 

234 f = next(i) 

235 if f.name != "dag_id": 

236 raise RuntimeError("dag_id was not the first field") 

237 yield f 

238 

239 for f in i: 

240 yield f.evolve(kw_only=True) 

241 

242 

243if TYPE_CHECKING: 

244 # Given this attrs field: 

245 # 

246 # default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.copy) 

247 # 

248 # mypy ignores the type of the attrs and works out the type as the converter function. However it doesn't 

249 # cope with generics properly and errors with 'incompatible type "dict[str, object]"; expected "_T"' 

250 # 

251 # https://github.com/python/mypy/issues/8625 

252 def dict_copy(_: dict[str, Any]) -> dict[str, Any]: ... 

253else: 

254 dict_copy = copy.copy 

255 

256 

257def _default_start_date(instance: DAG): 

258 # Find start date inside default_args for compat with Airflow 2. 

259 from airflow.sdk import timezone 

260 

261 if date := instance.default_args.get("start_date"): 

262 if not isinstance(date, datetime): 

263 date = timezone.parse(date) 

264 instance.default_args["start_date"] = date 

265 return date 

266 return None 

267 

268 

269def _default_dag_display_name(instance: DAG) -> str: 

270 return instance.dag_id 

271 

272 

273def _default_fileloc() -> str: 

274 # Skip over this frame, and the 'attrs generated init' 

275 back = sys._getframe().f_back 

276 if not back or not (back := back.f_back): 

277 # We expect two frames back, if not we don't know where we are 

278 return "" 

279 return back.f_code.co_filename if back else "" 

280 

281 

282def _default_task_group(instance: DAG) -> TaskGroup: 

283 from airflow.sdk.definitions.taskgroup import TaskGroup 

284 

285 return TaskGroup.create_root(dag=instance) 

286 

287 

288# TODO: Task-SDK: look at re-enabling slots after we remove pickling 

289@attrs.define(repr=False, field_transformer=_all_after_dag_id_to_kw_only, slots=False) 

290class DAG: 

291 """ 

292 A dag is a collection of tasks with directional dependencies. 

293 

294 A dag also has a schedule, a start date and an end date (optional). For each schedule 

295 (say daily or hourly), the DAG needs to run each individual task as its dependencies 

296 are met. Certain tasks have the property of depending on their own past, meaning that 

297 they can't run until their previous schedule (and upstream tasks) are completed. 

298 

299 Dags essentially act as namespaces for tasks. A task_id can only be 

300 added once to a Dag. 

301 

302 Note that if you plan to use time zones all the dates provided should be pendulum 

303 dates. See :ref:`timezone_aware_dags`. 

304 

305 .. versionadded:: 2.4 

306 The *schedule* argument to specify either time-based scheduling logic 

307 (timetable), or dataset-driven triggers. 

308 

309 .. versionchanged:: 3.0 

310 The default value of *schedule* has been changed to *None* (no schedule). 

311 The previous default was ``timedelta(days=1)``. 

312 

313 :param dag_id: The id of the DAG; must consist exclusively of alphanumeric 

314 characters, dashes, dots and underscores (all ASCII) 

315 :param description: The description for the DAG to e.g. be shown on the webserver 

316 :param schedule: If provided, this defines the rules according to which DAG 

317 runs are scheduled. Possible values include a cron expression string, 

318 timedelta object, Timetable, or list of Asset objects. 

319 See also :external:doc:`howto/timetable`. 

320 :param start_date: The timestamp from which the scheduler will 

321 attempt to backfill. If this is not provided, backfilling must be done 

322 manually with an explicit time range. 

323 :param end_date: A date beyond which your DAG won't run, leave to None 

324 for open-ended scheduling. 

325 :param template_searchpath: This list of folders (non-relative) 

326 defines where jinja will look for your templates. Order matters. 

327 Note that jinja/airflow includes the path of your DAG file by 

328 default 

329 :param template_undefined: Template undefined type. 

330 :param user_defined_macros: a dictionary of macros that will be exposed 

331 in your jinja templates. For example, passing ``dict(foo='bar')`` 

332 to this argument allows you to ``{{ foo }}`` in all jinja 

333 templates related to this DAG. Note that you can pass any 

334 type of object here. 

335 :param user_defined_filters: a dictionary of filters that will be exposed 

336 in your jinja templates. For example, passing 

337 ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows 

338 you to ``{{ 'world' | hello }}`` in all jinja templates related to 

339 this DAG. 

340 :param default_args: A dictionary of default parameters to be used 

341 as constructor keyword parameters when initialising operators. 

342 Note that operators have the same hook, and precede those defined 

343 here, meaning that if your dict contains `'depends_on_past': True` 

344 here and `'depends_on_past': False` in the operator's call 

345 `default_args`, the actual value will be `False`. 

346 :param params: a dictionary of DAG level parameters that are made 

347 accessible in templates, namespaced under `params`. These 

348 params can be overridden at the task level. 

349 :param max_active_tasks: the number of task instances allowed to run 

350 concurrently 

351 :param max_active_runs: maximum number of active DAG runs, beyond this 

352 number of DAG runs in a running state, the scheduler won't create 

353 new active DAG runs 

354 :param max_consecutive_failed_dag_runs: (experimental) maximum number of consecutive failed DAG runs, 

355 beyond this the scheduler will disable the DAG 

356 :param dagrun_timeout: Specify the duration a DagRun should be allowed to run before it times out or 

357 fails. Task instances that are running when a DagRun is timed out will be marked as skipped. 

358 :param sla_miss_callback: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with DeadlineAlerts in 3.1 

359 :param deadline: An optional DeadlineAlert for the Dag. 

360 :param catchup: Perform scheduler catchup (or only run latest)? Defaults to False 

361 :param on_failure_callback: A function or list of functions to be called when a DagRun of this dag fails. 

362 A context dictionary is passed as a single parameter to this function. 

363 :param on_success_callback: Much like the ``on_failure_callback`` except 

364 that it is executed when the dag succeeds. 

365 :param access_control: Specify optional DAG-level actions, e.g., 

366 "{'role1': {'can_read'}, 'role2': {'can_read', 'can_edit', 'can_delete'}}" 

367 or it can specify the resource name if there is a DAGs Run resource, e.g., 

368 "{'role1': {'DAG Runs': {'can_create'}}, 'role2': {'DAGs': {'can_read', 'can_edit', 'can_delete'}}" 

369 :param is_paused_upon_creation: Specifies if the dag is paused when created for the first time. 

370 If the dag exists already, this flag will be ignored. If this optional parameter 

371 is not specified, the global config setting will be used. 

372 :param jinja_environment_kwargs: additional configuration options to be passed to Jinja 

373 ``Environment`` for template rendering 

374 

375 **Example**: to avoid Jinja from removing a trailing newline from template strings :: 

376 

377 DAG( 

378 dag_id="my-dag", 

379 jinja_environment_kwargs={ 

380 "keep_trailing_newline": True, 

381 # some other jinja2 Environment options here 

382 }, 

383 ) 

384 

385 **See**: `Jinja Environment documentation 

386 <https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.Environment>`_ 

387 

388 :param render_template_as_native_obj: If True, uses a Jinja ``NativeEnvironment`` 

389 to render templates as native Python types. If False, a Jinja 

390 ``Environment`` is used to render templates as string values. 

391 :param tags: List of tags to help filtering Dags in the UI. 

392 :param owner_links: Dict of owners and their links, that will be clickable on the Dags view UI. 

393 Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link. 

394 e.g: ``{"dag_owner": "https://airflow.apache.org/"}`` 

395 :param auto_register: Automatically register this DAG when it is used in a ``with`` block 

396 :param fail_fast: Fails currently running tasks when a task in the Dag fails. 

397 **Warning**: A fail-fast dag can only have tasks with the default trigger rule ("all_success"). 

398 An exception will be thrown if any task in a fail-fast dag has a non-default trigger rule. 

399 :param dag_display_name: The display name of the Dag which appears on the UI. 
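
**Example**: a minimal Dag definition (illustrative values; assumes ``pendulum`` and an
operator such as ``EmptyOperator`` are imported) ::

    with DAG(
        dag_id="example_dag",
        schedule="@daily",
        start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
        catchup=False,
    ):
        EmptyOperator(task_id="start")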

400 """ 

401 

402 __serialized_fields: ClassVar[frozenset[str]] 

403 

404 # Note: mypy gets very confused about the use of `@${attr}.default` for attrs without init=False -- and it 

405 # doesn't correctly track/notice that they have default values (it gives errors about `Missing positional 

406 # argument "description" in call to "DAG"`` etc), so for init=True args we use the `default=Factory()` 

407 # style 

408 

409 def __rich_repr__(self): 

410 yield "dag_id", self.dag_id 

411 yield "schedule", self.schedule 

412 yield "#tasks", len(self.tasks) 

413 

414 __rich_repr__.angular = True # type: ignore[attr-defined] 

415 

416 # NOTE: When updating arguments here, please also keep arguments in @dag() 

417 # below in sync. (Search for 'def dag(' in this file.) 

418 dag_id: str = attrs.field(kw_only=False, validator=lambda i, a, v: validate_key(v)) 

419 description: str | None = attrs.field( 

420 default=None, 

421 validator=attrs.validators.optional(attrs.validators.instance_of(str)), 

422 ) 

423 default_args: dict[str, Any] = attrs.field( 

424 factory=dict, validator=attrs.validators.instance_of(dict), converter=dict_copy 

425 ) 

426 start_date: datetime | None = attrs.field( 

427 default=attrs.Factory(_default_start_date, takes_self=True), 

428 ) 

429 

430 end_date: datetime | None = None 

431 timezone: FixedTimezone | Timezone = attrs.field(init=False) 

432 schedule: ScheduleArg = attrs.field(default=None, on_setattr=attrs.setters.frozen) 

433 timetable: BaseTimetable | CoreTimetable = attrs.field(init=False) 

434 template_searchpath: str | Iterable[str] | None = attrs.field( 

435 default=None, converter=_convert_str_to_tuple 

436 ) 

437 # TODO: Task-SDK: Work out how to not import jinja2 until we need it! It's expensive 

438 template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined 

439 user_defined_macros: dict | None = None 

440 user_defined_filters: dict | None = None 

441 max_active_tasks: int = attrs.field( 

442 factory=_config_int_factory("core", "max_active_tasks_per_dag"), 

443 converter=attrs.converters.default_if_none( # type: ignore[misc] 

444 # attrs only supports named callables or lambdas, but partial works 

445 # OK here too. This is a false positive from attrs's Mypy plugin. 

446 factory=_config_int_factory("core", "max_active_tasks_per_dag"), 

447 ), 

448 validator=attrs.validators.instance_of(int), 

449 ) 

450 max_active_runs: int = attrs.field( 

451 factory=_config_int_factory("core", "max_active_runs_per_dag"), 

452 converter=attrs.converters.default_if_none( # type: ignore[misc] 

453 # attrs only supports named callables or lambdas, but partial works 

454 # OK here too. This is a false positive from attrs's Mypy plugin. 

455 factory=_config_int_factory("core", "max_active_runs_per_dag"), 

456 ), 

457 validator=attrs.validators.instance_of(int), 

458 ) 

459 max_consecutive_failed_dag_runs: int = attrs.field( 

460 factory=_config_int_factory("core", "max_consecutive_failed_dag_runs_per_dag"), 

461 converter=attrs.converters.default_if_none( # type: ignore[misc] 

462 # attrs only supports named callables or lambdas, but partial works 

463 # OK here too. This is a false positive from attrs's Mypy plugin. 

464 factory=_config_int_factory("core", "max_consecutive_failed_dag_runs_per_dag"), 

465 ), 

466 validator=attrs.validators.instance_of(int), 

467 ) 

468 dagrun_timeout: timedelta | None = attrs.field( 

469 default=None, 

470 validator=attrs.validators.optional(attrs.validators.instance_of(timedelta)), 

471 ) 

472 deadline: list[DeadlineAlert] | DeadlineAlert | None = attrs.field( 

473 default=None, 

474 converter=_convert_deadline, 

475 validator=attrs.validators.optional( 

476 attrs.validators.deep_iterable( 

477 member_validator=attrs.validators.instance_of(DeadlineAlert), 

478 iterable_validator=attrs.validators.instance_of(list), 

479 ) 

480 ), 

481 ) 

482 

483 sla_miss_callback: None = attrs.field(default=None) 

484 catchup: bool = attrs.field( 

485 factory=_config_bool_factory("scheduler", "catchup_by_default"), 

486 ) 

487 on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None 

488 on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None 

489 doc_md: str | None = attrs.field(default=None, converter=_convert_doc_md) 

490 params: ParamsDict = attrs.field( 

491 # mypy doesn't really like passing the Converter object 

492 default=None, 

493 converter=attrs.Converter(_convert_params, takes_self=True), # type: ignore[misc, call-overload] 

494 ) 

495 access_control: dict[str, dict[str, Collection[str]]] | None = attrs.field( 

496 default=None, 

497 converter=attrs.Converter(_convert_access_control), # type: ignore[misc, call-overload] 

498 ) 

499 is_paused_upon_creation: bool | None = None 

500 jinja_environment_kwargs: dict | None = None 

501 render_template_as_native_obj: bool = attrs.field(default=False, converter=bool) 

502 tags: MutableSet[str] = attrs.field(factory=set, converter=_convert_tags) 

503 owner_links: dict[str, str] = attrs.field(factory=dict) 

504 auto_register: bool = attrs.field(default=True, converter=bool) 

505 fail_fast: bool = attrs.field(default=False, converter=bool) 

506 dag_display_name: str = attrs.field( 

507 default=attrs.Factory(_default_dag_display_name, takes_self=True), 

508 validator=attrs.validators.instance_of(str), 

509 ) 

510 

511 task_dict: dict[str, Operator] = attrs.field(factory=dict, init=False) 

512 

513 task_group: TaskGroup = attrs.field( 

514 on_setattr=attrs.setters.frozen, default=attrs.Factory(_default_task_group, takes_self=True) 

515 ) 

516 

517 fileloc: str = attrs.field(init=False, factory=_default_fileloc) 

518 relative_fileloc: str | None = attrs.field(init=False, default=None) 

519 partial: bool = attrs.field(init=False, default=False) 

520 

521 edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(init=False, factory=dict) 

522 

523 has_on_success_callback: bool = attrs.field(init=False) 

524 has_on_failure_callback: bool = attrs.field(init=False) 

525 disable_bundle_versioning: bool = attrs.field( 

526 factory=_config_bool_factory("dag_processor", "disable_bundle_versioning") 

527 ) 

528 

529 # TODO (GH-52141): This is never used in the sdk dag (it only makes sense 

530 # after this goes through the dag processor), but various parts of the code 

531 # depends on its existence. We should remove this after completely splitting 

532 # DAG classes in the SDK and scheduler. 

533 last_loaded: datetime | None = attrs.field(init=False, default=None) 

534 

535 def __attrs_post_init__(self): 

536 from airflow.sdk import timezone 

537 

538 # Apply the timezone we settled on to start_date, end_date if it wasn't supplied 

539 if isinstance(_start_date := self.default_args.get("start_date"), str): 

540 self.default_args["start_date"] = timezone.parse(_start_date, timezone=self.timezone) 

541 if isinstance(_end_date := self.default_args.get("end_date"), str): 

542 self.default_args["end_date"] = timezone.parse(_end_date, timezone=self.timezone) 

543 

544 self.start_date = timezone.convert_to_utc(self.start_date) 

545 self.end_date = timezone.convert_to_utc(self.end_date) 

546 if start_date := self.default_args.get("start_date", None): 

547 self.default_args["start_date"] = timezone.convert_to_utc(start_date) 

548 if end_date := self.default_args.get("end_date", None): 

549 self.default_args["end_date"] = timezone.convert_to_utc(end_date) 

550 if self.access_control is not None: 

551 warnings.warn( 

552 "The airflow.security.permissions module is deprecated; please see https://airflow.apache.org/docs/apache-airflow/stable/security/deprecated_permissions.html", 

553 RemovedInAirflow4Warning, 

554 stacklevel=2, 

555 ) 

556 if ( 

557 active_runs_limit := self.timetable.active_runs_limit 

558 ) is not None and active_runs_limit < self.max_active_runs: 

559 raise ValueError( 

560 f"Invalid max_active_runs: {type(self.timetable)} " 

561 f"requires max_active_runs <= {active_runs_limit}" 

562 ) 

563 

564 @params.validator 

565 def _validate_params(self, _, params: ParamsDict): 

566 """ 

567 Validate Param values when the Dag has schedule defined. 

568 

569 Raise exception if there are any Params which can not be resolved by their schema definition. 

570 """ 

571 if not self.timetable or not self.timetable.can_be_scheduled: 

572 return 

573 

574 try: 

575 params.validate() 

576 except ParamValidationError as pverr: 

577 raise ValueError( 

578 f"Dag {self.dag_id!r} is not allowed to define a Schedule, " 

579 "as there are required params without default values, or the default values are not valid." 

580 ) from pverr 

581 

582 @catchup.validator 

583 def _validate_catchup(self, _, catchup: bool): 

584 requires_automatic_backfilling = self.timetable.can_be_scheduled and catchup 

585 if requires_automatic_backfilling and not ("start_date" in self.default_args or self.start_date): 

586 raise ValueError("start_date is required when catchup=True") 

587 

588 @tags.validator 

589 def _validate_tags(self, _, tags: Collection[str]): 

590 if tags and any(len(tag) > TAG_MAX_LEN for tag in tags): 

591 raise ValueError(f"tag cannot be longer than {TAG_MAX_LEN} characters") 

592 

593 @max_active_runs.validator 

594 def _validate_max_active_runs(self, _, max_active_runs): 

595 if self.timetable.active_runs_limit is not None: 

596 if self.timetable.active_runs_limit < self.max_active_runs: 

597 raise ValueError( 

598 f"Invalid max_active_runs: {type(self.timetable).__name__} " 

599 f"requires max_active_runs <= {self.timetable.active_runs_limit}" 

600 ) 

601 

602 @timetable.default 

603 def _default_timetable(instance: DAG) -> BaseTimetable | CoreTimetable: 

604 schedule = instance.schedule 

605 # TODO: Once 

606 # delattr(self, "schedule") 

607 if _is_core_timetable(schedule): 

608 return schedule 

609 if isinstance(schedule, BaseTimetable): 

610 return schedule 

611 if isinstance(schedule, BaseAsset): 

612 return AssetTriggeredTimetable(schedule) 

613 if isinstance(schedule, Collection) and not isinstance(schedule, str): 

614 if not all(isinstance(x, BaseAsset) for x in schedule): 

615 raise ValueError( 

616 "All elements in 'schedule' should be either assets, asset references, or asset aliases" 

617 ) 

618 return AssetTriggeredTimetable(AssetAll(*schedule)) 

619 return _create_timetable(schedule, instance.timezone) 

620 

621 @timezone.default 

622 def _extract_tz(instance): 

623 import pendulum 

624 

625 from airflow.sdk import timezone 

626 

627 start_date = instance.start_date or instance.default_args.get("start_date") 

628 

629 if start_date: 

630 if not isinstance(start_date, datetime): 

631 start_date = timezone.parse(start_date) 

632 tzinfo = start_date.tzinfo or settings.TIMEZONE 

633 tz = pendulum.instance(start_date, tz=tzinfo).timezone 

634 else: 

635 tz = settings.TIMEZONE 

636 

637 return tz 

638 

639 @has_on_success_callback.default 

640 def _has_on_success_callback(self) -> bool: 

641 return self.on_success_callback is not None 

642 

643 @has_on_failure_callback.default 

644 def _has_on_failure_callback(self) -> bool: 

645 return self.on_failure_callback is not None 

646 

647 @sla_miss_callback.validator 

648 def _validate_sla_miss_callback(self, _, value): 

649 if value is not None: 

650 warnings.warn( 

651 "The SLA feature is removed in Airflow 3.0, and replaced with a Deadline Alerts in >=3.1", 

652 stacklevel=2, 

653 ) 

654 return value 

655 

656 def __repr__(self): 

657 return f"<DAG: {self.dag_id}>" 

658 

659 def __eq__(self, other: Self | Any): 

660 # TODO: This subclassing behaviour seems wrong, but it's what Airflow has done for ~ever. 

661 if type(self) is not type(other): 

662 return False 

663 return all(getattr(self, c, None) == getattr(other, c, None) for c in _DAG_HASH_ATTRS) 

664 

665 def __ne__(self, other: Any): 

666 return not self == other 

667 

668 def __lt__(self, other): 

669 return self.dag_id < other.dag_id 

670 

671 def __hash__(self): 

672 hash_components: list[Any] = [type(self)] 

673 for c in _DAG_HASH_ATTRS: 

674 # If it is a list, convert to tuple because lists can't be hashed 

675 if isinstance(getattr(self, c, None), list): 

676 val = tuple(getattr(self, c)) 

677 else: 

678 val = getattr(self, c, None) 

679 try: 

680 hash(val) 

681 hash_components.append(val) 

682 except TypeError: 

683 hash_components.append(repr(val)) 

684 return hash(tuple(hash_components)) 

685 

686 def __enter__(self) -> Self: 

687 from airflow.sdk.definitions._internal.contextmanager import DagContext 

688 

689 DagContext.push(self) 

690 return self 

691 

692 def __exit__(self, _type, _value, _tb): 

693 from airflow.sdk.definitions._internal.contextmanager import DagContext 

694 

695 _ = DagContext.pop() 

696 

697 def validate(self): 

698 """ 

699 Validate the Dag has a coherent setup. 

700 

701 This is called by the Dag bag before bagging the Dag. 

702 """ 

703 self.timetable.validate() 

704 self.validate_setup_teardown() 

705 

706 # We validate owner links on set, but since it's a dict it could be mutated without calling the 

707 # setter. Validate again here 

708 self._validate_owner_links(None, self.owner_links) 

709 

710 def validate_setup_teardown(self): 

711 """ 

712 Validate that setup and teardown tasks are configured properly. 

713 

714 :meta private: 

715 """ 

716 for task in self.tasks: 

717 if task.is_setup: 

718 for down_task in task.downstream_list: 

719 if not down_task.is_teardown and down_task.trigger_rule != TriggerRule.ALL_SUCCESS: 

720 # todo: we can relax this to allow out-of-scope tasks to have other trigger rules 

721 # this is required to ensure consistent behavior of dag 

722 # when clearing an indirect setup 

723 raise ValueError("Setup tasks must be followed with trigger rule ALL_SUCCESS.") 

724 

725 def param(self, name: str, default: Any = NOTSET) -> DagParam: 

726 """ 

727 Return a DagParam object for current dag. 

728 

729 :param name: dag parameter name. 

730 :param default: fallback value for dag parameter. 

731 :return: DagParam instance for specified name and current dag. 
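
**Example**: obtain a reference to a Dag-level param that resolves at run time
(illustrative param name) ::

    value = dag.param("my_param", default=42)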

732 """ 

733 return DagParam(current_dag=self, name=name, default=default) 

734 

735 @property 

736 def tasks(self) -> list[Operator]: 

737 return list(self.task_dict.values()) 

738 

739 @tasks.setter 

740 def tasks(self, val): 

741 raise AttributeError("DAG.tasks can not be modified. Use dag.add_task() instead.") 

742 

743 @property 

744 def task_ids(self) -> list[str]: 

745 return list(self.task_dict) 

746 

747 @property 

748 def teardowns(self) -> list[Operator]: 

749 return [task for task in self.tasks if getattr(task, "is_teardown", None)] 

750 

751 @property 

752 def tasks_upstream_of_teardowns(self) -> list[Operator]: 

753 upstream_tasks = [t.upstream_list for t in self.teardowns] 

754 return [val for sublist in upstream_tasks for val in sublist if not getattr(val, "is_teardown", None)] 

755 

756 @property 

757 def folder(self) -> str: 

758 """Folder location of where the Dag object is instantiated.""" 

759 return os.path.dirname(self.fileloc) 

760 

761 @property 

762 def owner(self) -> str: 

763 """ 

764 Return all owners found in Dag tasks, as a comma-separated string. 

765 

766 :return: Comma separated list of owners in Dag tasks 

767 """ 

768 return ", ".join({t.owner for t in self.tasks}) 

769 

770 def resolve_template_files(self): 

771 for t in self.tasks: 

772 # TODO: TaskSDK: move this on to BaseOperator and remove the check? 

773 if hasattr(t, "resolve_template_files"): 

774 t.resolve_template_files() 

775 

776 def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment: 

777 """Build a Jinja2 environment.""" 

778 from airflow.sdk.definitions._internal.templater import NativeEnvironment, SandboxedEnvironment 

779 

780 # Collect directories to search for template files 

781 searchpath = [self.folder] 

782 if self.template_searchpath: 

783 searchpath += self.template_searchpath 

784 

785 # Default values (for backward compatibility) 

786 jinja_env_options = { 

787 "loader": jinja2.FileSystemLoader(searchpath), 

788 "undefined": self.template_undefined, 

789 "extensions": ["jinja2.ext.do"], 

790 "cache_size": 0, 

791 } 

792 if self.jinja_environment_kwargs: 

793 jinja_env_options.update(self.jinja_environment_kwargs) 

794 env: jinja2.Environment 

795 if self.render_template_as_native_obj and not force_sandboxed: 

796 env = NativeEnvironment(**jinja_env_options) 

797 else: 

798 env = SandboxedEnvironment(**jinja_env_options) 

799 

800 # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. 

801 # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals 

802 if self.user_defined_macros: 

803 env.globals.update(self.user_defined_macros) 

804 if self.user_defined_filters: 

805 env.filters.update(self.user_defined_filters) 

806 

807 return env 

808 

809 def set_dependency(self, upstream_task_id, downstream_task_id): 

810 """Set dependency between two tasks that already have been added to the Dag using add_task().""" 

811 self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id)) 

812 

813 @property 

814 def roots(self) -> list[Operator]: 

815 """Return nodes with no parents. These are first to execute and are called roots or root nodes.""" 

816 return [task for task in self.tasks if not task.upstream_list] 

817 

818 @property 

819 def leaves(self) -> list[Operator]: 

820 """Return nodes with no children. These are last to execute and are called leaves or leaf nodes.""" 

821 return [task for task in self.tasks if not task.downstream_list] 

822 

823 def topological_sort(self): 

824 """ 

825 Sort tasks in topological order, such that a task comes after any of its upstream dependencies. 

826 

827 Deprecated in place of ``task_group.topological_sort`` 

828 """ 

829 from airflow.sdk.definitions.taskgroup import TaskGroup 

830 

831 # TODO: Remove in RemovedInAirflow3Warning 

832 def nested_topo(group): 

833 for node in group.topological_sort(): 

834 if isinstance(node, TaskGroup): 

835 yield from nested_topo(node) 

836 else: 

837 yield node 

838 

839 return tuple(nested_topo(self.task_group)) 

840 

841 def __deepcopy__(self, memo: dict[int, Any]): 

842 # Switcharoo to go around deepcopying objects coming through the 

843 # backdoor 

844 cls = self.__class__ 

845 result = cls.__new__(cls) 

846 memo[id(self)] = result 

847 for k, v in self.__dict__.items(): 

848 if k not in ("user_defined_macros", "user_defined_filters", "_log"): 

849 object.__setattr__(result, k, copy.deepcopy(v, memo)) 

850 

851 result.user_defined_macros = self.user_defined_macros 

852 result.user_defined_filters = self.user_defined_filters 

853 if hasattr(self, "_log"): 

854 result._log = self._log # type: ignore[attr-defined] 

855 return result 

856 

857 def partial_subset( 

858 self, 

859 task_ids: str | Iterable[str], 

860 include_downstream=False, 

861 include_upstream=True, 

862 include_direct_upstream=False, 

863 ): 

864 """ 

865 Return a subset of the current dag based on task IDs matching one or more tasks. 

866 

867 Returns a subset of the current dag as a deep copy of the current dag 

868 based on the given task IDs (a string matches any task_id containing it), and includes 

869 upstream and downstream neighbours based on the flags passed. 

870 

871 :param task_ids: Either a list of task_ids, or a string task_id 

872 :param include_downstream: Include all downstream tasks of matched 

873 tasks, in addition to matched tasks. 

874 :param include_upstream: Include all upstream tasks of matched tasks, 

875 in addition to matched tasks. 

876 :param include_direct_upstream: Include all tasks directly upstream of matched 

877 and downstream (if include_downstream = True) tasks 
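
**Example**: keep ``task_a`` plus everything downstream of it (illustrative task id) ::

    subset = dag.partial_subset("task_a", include_downstream=True, include_upstream=False)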

878 """ 

879 from airflow.sdk.definitions.mappedoperator import MappedOperator 

880 

881 def is_task(obj) -> TypeGuard[Operator]: 

882 return isinstance(obj, BaseOperator | MappedOperator) 

883 

884 # deep-copying self.task_dict and self.task_group takes a long time, and we don't want all 

885 # the tasks anyway, so we copy the tasks manually later 

886 memo = {id(self.task_dict): None, id(self.task_group): None} 

887 dag = copy.deepcopy(self, memo) 

888 

889 if isinstance(task_ids, str): 

890 matched_tasks = [t for t in self.tasks if task_ids in t.task_id] 

891 else: 

892 matched_tasks = [t for t in self.tasks if t.task_id in task_ids] 

893 

894 also_include_ids: set[str] = set() 

895 for t in matched_tasks: 

896 if include_downstream: 

897 for rel in t.get_flat_relatives(upstream=False): 

898 also_include_ids.add(rel.task_id) 

899 if rel not in matched_tasks: # if it's in there, we're already processing it 

900 # need to include setups and teardowns for tasks that are in multiple 

901 # non-collinear setup/teardown paths 

902 if not rel.is_setup and not rel.is_teardown: 

903 also_include_ids.update( 

904 x.task_id for x in rel.get_upstreams_only_setups_and_teardowns() 

905 ) 

906 if include_upstream: 

907 also_include_ids.update(x.task_id for x in t.get_upstreams_follow_setups()) 

908 else: 

909 if not t.is_setup and not t.is_teardown: 

910 also_include_ids.update(x.task_id for x in t.get_upstreams_only_setups_and_teardowns()) 

911 if t.is_setup and not include_downstream: 

912 also_include_ids.update(x.task_id for x in t.downstream_list if x.is_teardown) 

913 

914 also_include: list[Operator] = [self.task_dict[x] for x in also_include_ids] 

915 direct_upstreams: list[Operator] = [] 

916 if include_direct_upstream: 

917 for t in itertools.chain(matched_tasks, also_include): 

918 direct_upstreams.extend(u for u in t.upstream_list if is_task(u)) 

919 

920 # Make sure to not recursively deepcopy the dag or task_group while copying the task. 

921 # task_group is reset later 

922 def _deepcopy_task(t) -> Operator: 

923 memo.setdefault(id(t.task_group), None) 

924 return copy.deepcopy(t, memo) 

925 

926 # Compiling the unique list of tasks that made the cut 

927 dag.task_dict = { 

928 t.task_id: _deepcopy_task(t) 

929 for t in itertools.chain(matched_tasks, also_include, direct_upstreams) 

930 } 

931 

932 def filter_task_group(group, parent_group): 

933 """Exclude tasks not included in the partial dag from the given TaskGroup.""" 

934 # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy 

935 # and then manually deep copy the instances. (memo argument to deepcopy only works for instances 

936 # of classes, not "native" properties of an instance) 

937 copied = copy.copy(group) 

938 

939 memo[id(group.children)] = {} 

940 if parent_group: 

941 memo[id(group.parent_group)] = parent_group 

942 for attr in type(group).__slots__: 

943 value = getattr(group, attr) 

944 value = copy.deepcopy(value, memo) 

945 object.__setattr__(copied, attr, value) 

946 

947 proxy = weakref.proxy(copied) 

948 

949 for child in group.children.values(): 

950 if is_task(child): 

951 if child.task_id in dag.task_dict: 

952 task = copied.children[child.task_id] = dag.task_dict[child.task_id] 

953 task.task_group = proxy 

954 else: 

955 copied.used_group_ids.discard(child.task_id) 

956 else: 

957 filtered_child = filter_task_group(child, proxy) 

958 

959 # Only include this child TaskGroup if it is non-empty. 

960 if filtered_child.children: 

961 copied.children[child.group_id] = filtered_child 

962 

963 return copied 

964 

965 object.__setattr__(dag, "task_group", filter_task_group(self.task_group, None)) 

966 

967 # Removing upstream/downstream references to tasks and TaskGroups that did not make 

968 # the cut. 

969 groups = dag.task_group.get_task_group_dict() 

970 for g in groups.values(): 

971 g.upstream_group_ids.intersection_update(groups) 

972 g.downstream_group_ids.intersection_update(groups) 

973 g.upstream_task_ids.intersection_update(dag.task_dict) 

974 g.downstream_task_ids.intersection_update(dag.task_dict) 

975 

976 for t in dag.tasks: 

977 # Removing upstream/downstream references to tasks that did not 

978 # make the cut 

979 t.upstream_task_ids.intersection_update(dag.task_dict) 

980 t.downstream_task_ids.intersection_update(dag.task_dict) 

981 

982 dag.partial = len(dag.tasks) < len(self.tasks) 

983 

984 return dag 

985 

986 def has_task(self, task_id: str): 

987 return task_id in self.task_dict 

988 

989 def has_task_group(self, task_group_id: str) -> bool: 

990 return task_group_id in self.task_group_dict 

991 

992 @functools.cached_property 

993 def task_group_dict(self): 

994 return {k: v for k, v in self.task_group.get_task_group_dict().items() if k is not None} 

995 

996 def get_task(self, task_id: str) -> Operator: 

997 if task_id in self.task_dict: 

998 return self.task_dict[task_id] 

999 raise TaskNotFound(f"Task {task_id} not found") 

1000 

1001 @property 

1002 def task(self) -> TaskDecoratorCollection: 

1003 from airflow.sdk.definitions.decorators import task 

1004 

1005 return cast("TaskDecoratorCollection", functools.partial(task, dag=self)) 

1006 

1007 def add_task(self, task: Operator) -> None: 

1008 """ 

1009 Add a task to the Dag. 

1010 

1011 :param task: the task you want to add 

1012 """ 

1013 # FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule) 

1014 

1015 from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext 

1016 

1017 # if the task has no start date, assign it the same as the Dag 

1018 if not task.start_date: 

1019 task.start_date = self.start_date 

1020 # otherwise, the task will start on the later of its own start date and 

1021 # the Dag's start date 

1022 elif self.start_date: 

1023 task.start_date = max(task.start_date, self.start_date) 

1024 

1025 # if the task has no end date, assign it the same as the dag 

1026 if not task.end_date: 

1027 task.end_date = self.end_date 

1028 # otherwise, the task will end on the earlier of its own end date and 

1029 # the Dag's end date 

1030 elif task.end_date and self.end_date: 

1031 task.end_date = min(task.end_date, self.end_date) 

1032 

1033 task_id = task.node_id 

1034 if not task.task_group: 

1035 task_group = TaskGroupContext.get_current(self) 

1036 if task_group: 

1037 task_id = task_group.child_id(task_id) 

1038 task_group.add(task) 

1039 

1040 if ( 

1041 task_id in self.task_dict and self.task_dict[task_id] is not task 

1042 ) or task_id in self.task_group.used_group_ids: 

1043 raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG") 

1044 self.task_dict[task_id] = task 

1045 

1046 task.dag = self 

1047 # Add task_id to used_group_ids to prevent group_id and task_id collisions. 

1048 self.task_group.used_group_ids.add(task_id) 

1049 

1050 FailFastDagInvalidTriggerRule.check(fail_fast=self.fail_fast, trigger_rule=task.trigger_rule) 

1051 

1052 def add_tasks(self, tasks: Iterable[Operator]) -> None: 

1053 """ 

1054 Add a list of tasks to the Dag. 

1055 

1056 :param tasks: a list of tasks you want to add 
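
**Example** (illustrative operators; assumes ``EmptyOperator`` is imported) ::

    dag.add_tasks([EmptyOperator(task_id="first"), EmptyOperator(task_id="second")])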

1057 """ 

1058 for task in tasks: 

1059 self.add_task(task) 

1060 

1061 def _remove_task(self, task_id: str) -> None: 

1062 # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this 

1063 # doesn't guard against that 

1064 task = self.task_dict.pop(task_id) 

1065 tg = getattr(task, "task_group", None) 

1066 if tg: 

1067 tg._remove(task) 

1068 

1069 def check_cycle(self) -> None: 

1070 """ 

1071 Check to see if there are any cycles in the Dag. 

1072 

1073 :raises AirflowDagCycleException: If cycle is found in the Dag. 

1074 """ 

1075 # default of int is 0 which corresponds to CYCLE_NEW 

1076 CYCLE_NEW = 0 

1077 CYCLE_IN_PROGRESS = 1 

1078 CYCLE_DONE = 2 

1079 

1080 visited: dict[str, int] = defaultdict(int) 

1081 path_stack: deque[str] = deque() 

1082 task_dict = self.task_dict 

1083 

1084 def _check_adjacent_tasks(task_id, current_task): 

1085 """Return first untraversed child task, else None if all tasks traversed.""" 

1086 for adjacent_task in current_task.get_direct_relative_ids(): 

1087 if visited[adjacent_task] == CYCLE_IN_PROGRESS: 

1088 msg = f"Cycle detected in Dag: {self.dag_id}. Faulty task: {task_id}" 

1089 raise AirflowDagCycleException(msg) 

1090 if visited[adjacent_task] == CYCLE_NEW: 

1091 return adjacent_task 

1092 return None 

1093 

1094 for dag_task_id in self.task_dict.keys(): 

1095 if visited[dag_task_id] == CYCLE_DONE: 

1096 continue 

1097 path_stack.append(dag_task_id) 

1098 while path_stack: 

1099 current_task_id = path_stack[-1] 

1100 if visited[current_task_id] == CYCLE_NEW: 

1101 visited[current_task_id] = CYCLE_IN_PROGRESS 

1102 task = task_dict[current_task_id] 

1103 child_to_check = _check_adjacent_tasks(current_task_id, task) 

1104 if not child_to_check: 

1105 visited[current_task_id] = CYCLE_DONE 

1106 path_stack.pop() 

1107 else: 

1108 path_stack.append(child_to_check) 

1109 

1110 def cli(self): 

1111 """Exposes a CLI specific to this Dag.""" 

1112 self.check_cycle() 

1113 

1114 from airflow.cli import cli_parser 

1115 

1116 parser = cli_parser.get_parser(dag_parser=True) 

1117 args = parser.parse_args() 

1118 args.func(args, self) 

1119 

1120 @classmethod 

1121 def get_serialized_fields(cls): 

1122 """Stringified Dags and operators contain exactly these fields.""" 

1123 return cls.__serialized_fields 

1124 

1125 def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: 

1126 """Return edge information for the given pair of tasks or an empty edge if there is no information.""" 

1127 empty = cast("EdgeInfoType", {}) 

1128 if self.edge_info: 

1129 return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty) 

1130 return empty 

1131 

1132 def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType): 

1133 """ 

1134 Set the given edge information on the Dag. 

1135 

1136 Note that this will overwrite, rather than merge with, existing info. 
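
**Example** (illustrative task ids and label) ::

    dag.set_edge_info("extract", "load", {"label": "raw rows"})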

1137 """ 

1138 self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info 

1139 

1140 @owner_links.validator 

1141 def _validate_owner_links(self, _, owner_links): 

1142 wrong_links = {} 

1143 

1144 for owner, link in owner_links.items(): 

1145 result = urlsplit(link) 

1146 if result.scheme == "mailto": 

1147 # netloc does not exist for 'mailto' links, so we check that the path is parsed instead 

1148 if not result.path: 

1149 wrong_links[result.path] = link 

1150 elif not result.scheme or not result.netloc: 

1151 wrong_links[owner] = link 

1152 if wrong_links: 

1153 raise ValueError( 

1154 "Wrong link format was used for the owner. Use a valid link \n" 

1155 f"Bad formatted links are: {wrong_links}" 

1156 ) 

1157 

1158 def test( 

1159 self, 

1160 run_after: datetime | None = None, 

1161 logical_date: datetime | None | ArgNotSet = NOTSET, 

1162 run_conf: dict[str, Any] | None = None, 

1163 conn_file_path: str | None = None, 

1164 variable_file_path: str | None = None, 

1165 use_executor: bool = False, 

1166 mark_success_pattern: Pattern | str | None = None, 

1167 ): 

1168 """ 

1169 Execute one single DagRun for a given Dag and logical date. 

1170 

1171 :param run_after: the datetime before which the Dag cannot run. 

1172 :param logical_date: logical date for the Dag run 

1173 :param run_conf: configuration to pass to newly created dagrun 

1174 :param conn_file_path: file path to a connection file in either yaml or json 

1175 :param variable_file_path: file path to a variable file in either yaml or json 

1176 :param use_executor: if set, uses an executor to test the Dag 

1177 :param mark_success_pattern: regex of task_ids to mark as success instead of running 
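
**Example**: run a Dag locally from a script (illustrative conf values) ::

    if __name__ == "__main__":
        dag.test(run_conf={"my_param": "value"})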

1178 """ 

1179 import re 

1180 import time 

1181 from contextlib import ExitStack 

1182 from unittest.mock import patch 

1183 

1184 from airflow import settings 

1185 from airflow.models.dagrun import DagRun, get_or_create_dagrun 

1186 from airflow.sdk import DagRunState, timezone 

1187 from airflow.serialization.definitions.dag import SerializedDAG 

1188 from airflow.serialization.encoders import coerce_to_core_timetable 

1189 from airflow.serialization.serialized_objects import DagSerialization 

1190 from airflow.utils.types import DagRunTriggeredByType, DagRunType 

1191 

1192 exit_stack = ExitStack() 

1193 

1194 if conn_file_path or variable_file_path: 

1195 backend_kwargs = {} 

1196 if conn_file_path: 

1197 backend_kwargs["connections_file_path"] = conn_file_path 

1198 if variable_file_path: 

1199 backend_kwargs["variables_file_path"] = variable_file_path 

1200 

1201 exit_stack.enter_context( 

1202 patch.dict( 

1203 os.environ, 

1204 { 

1205 "AIRFLOW__SECRETS__BACKEND": "airflow.secrets.local_filesystem.LocalFilesystemBackend", 

1206 "AIRFLOW__SECRETS__BACKEND_KWARGS": json.dumps(backend_kwargs), 

1207 }, 

1208 ) 

1209 ) 

1210 

1211 if settings.Session is None: 

1212 raise RuntimeError("Session not configured. Call configure_orm() first.") 

1213 session = settings.Session() 

1214 

1215 with exit_stack: 

1216 self.validate() 

1217 scheduler_dag = DagSerialization.deserialize_dag(DagSerialization.serialize_dag(self)) 

1218 

1219 # Allow users to explicitly pass None. If it isn't set, we default to current time. 

1220 logical_date = logical_date if is_arg_set(logical_date) else timezone.utcnow() 

1221 

1222 log.debug("Clearing existing task instances for logical date %s", logical_date) 

1223 # TODO: Replace with calling client.dag_run.clear in Execution API at some point 

1224 SerializedDAG.clear_dags( 

1225 dags=[scheduler_dag], 

1226 start_date=logical_date, 

1227 end_date=logical_date, 

1228 dag_run_state=False, 

1229 ) 

1230 

1231 log.debug("Getting dagrun for dag %s", self.dag_id) 

1232 logical_date = timezone.coerce_datetime(logical_date) 

1233 run_after = timezone.coerce_datetime(run_after) or timezone.coerce_datetime(timezone.utcnow()) 

1234 if logical_date is None: 

1235 data_interval: DataInterval | None = None 

1236 else: 

1237 timetable = coerce_to_core_timetable(self.timetable) 

1238 data_interval = timetable.infer_manual_data_interval(run_after=logical_date) 

1239 from airflow.models.dag_version import DagVersion 

1240 

1241 version = DagVersion.get_version(self.dag_id) 

1242 if not version: 

1243 from airflow.dag_processing.bundles.manager import DagBundlesManager 

1244 from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db 

1245 from airflow.sdk.definitions._internal.dag_parsing_context import ( 

1246 _airflow_parsing_context_manager, 

1247 ) 

1248 

1249 manager = DagBundlesManager() 

1250 manager.sync_bundles_to_db(session=session) 

1251 session.commit() 

1252 # sync all bundles? or use the dags-folder bundle? 

1253 # What if the test dag is in a different bundle? 

1254 for bundle in manager.get_all_dag_bundles(): 

1255 if not bundle.is_initialized: 

1256 bundle.initialize() 

1257 with _airflow_parsing_context_manager(dag_id=self.dag_id): 

1258 dagbag = DagBag( 

1259 dag_folder=bundle.path, bundle_path=bundle.path, include_examples=False 

1260 ) 

1261 sync_bag_to_db(dagbag, bundle.name, bundle.version) 

1262 version = DagVersion.get_version(self.dag_id) 

1263 if version: 

1264 break 

1265 

1266 # Preserve callback functions from original Dag since they're lost during serialization 

1267 # and yes it is a hack for now! It is a tradeoff for code simplicity. 

1268 # Without it, we need "Scheduler Dag" (Serialized dag) for the scheduler bits 

1269 # -- dep check, scheduling tis 

1270 # and need real dag to get and run callbacks without having to load the dag model 

1271 

1272 # Scheduler DAG shouldn't have these attributes, but assigning them 

1273 # here is an easy hack to get this test() thing working. 

1274 scheduler_dag.on_success_callback = self.on_success_callback # type: ignore[attr-defined, union-attr] 

1275 scheduler_dag.on_failure_callback = self.on_failure_callback # type: ignore[attr-defined, union-attr] 

1276 

1277 dr: DagRun = get_or_create_dagrun( 

1278 dag=scheduler_dag, 

1279 start_date=logical_date or run_after, 

1280 logical_date=logical_date, 

1281 data_interval=data_interval, 

1282 run_after=run_after, 

1283 run_id=DagRun.generate_run_id( 

1284 run_type=DagRunType.MANUAL, 

1285 logical_date=logical_date, 

1286 run_after=run_after, 

1287 ), 

1288 session=session, 

1289 conf=run_conf, 

1290 triggered_by=DagRunTriggeredByType.TEST, 

1291 triggering_user_name="dag_test", 

1292 ) 

1293 # Start a mock span so that one is present and not started downstream. We 

1294 # don't care about otel in dag.test and starting the span during dagrun update 

1295 # is not functioning properly in this context anyway. 

1296 dr.start_dr_spans_if_needed(tis=[]) 

1297 

1298 log.debug("starting dagrun") 

1299 # Instead of starting a scheduler, we run the minimal loop possible to check 

1300 # for task readiness and dependency management. 


1303 

1304 # ``Dag.test()`` works in two different modes depending on ``use_executor``: 

1305 # - if ``use_executor`` is False, runs the task locally with no executor using ``_run_task`` 

1306 # - if ``use_executor`` is True, sends workloads to the executor with 

1307 # ``BaseExecutor.queue_workload`` 

1308 if use_executor: 

1309 from airflow.executors.base_executor import ExecutorLoader 

1310 

1311 executor = ExecutorLoader.get_default_executor() 

1312 executor.start() 

1313 

1314 while dr.state == DagRunState.RUNNING: 

1315 session.expire_all() 

1316 schedulable_tis, _ = dr.update_state(session=session) 

1317 for s in schedulable_tis: 

1318 if s.state != TaskInstanceState.UP_FOR_RESCHEDULE: 

1319 s.try_number += 1 

1320 s.state = TaskInstanceState.SCHEDULED 

1321 s.scheduled_dttm = timezone.utcnow() 

1322 session.commit() 

1323 # triggerer may mark tasks scheduled so we read from DB 

1324 all_tis = set(dr.get_task_instances(session=session)) 

1325 scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED} 

1326 ids_unrunnable = {x for x in all_tis if x.state not in FINISHED_STATES} - scheduled_tis 

1327 if not scheduled_tis and ids_unrunnable: 

1328 log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable) 

1329 time.sleep(1) 

1330 

1331 for ti in scheduled_tis: 

1332 task = self.task_dict[ti.task_id] 

1333 

1334 mark_success = ( 

1335 re.compile(mark_success_pattern).fullmatch(ti.task_id) is not None 

1336 if mark_success_pattern is not None 

1337 else False 

1338 ) 

1339 

1340 if use_executor: 

1341 if executor.has_task(ti): 

1342 continue 

1343 

1344 from pathlib import Path 

1345 

1346 from airflow.executors import workloads 

1347 from airflow.executors.base_executor import ExecutorLoader 

1348 from airflow.executors.workloads import BundleInfo 

1349 

1350 workload = workloads.ExecuteTask.make( 

1351 ti, 

1352 dag_rel_path=Path(self.fileloc), 

1353 generator=executor.jwt_generator, 

1354 sentry_integration=executor.sentry_integration, 

1355 # For system test/debug purposes, we use the default bundle, which uses the

1356 # local file system. If this turns out to be a feature people want, we could

1357 # expose the bundle to use as a parameter to dag.test().

1358 bundle_info=BundleInfo(name="dags-folder"), 

1359 ) 

1360 executor.queue_workload(workload, session=session) 

1361 ti.state = TaskInstanceState.QUEUED 

1362 session.commit() 

1363 else: 

1364 # Run the task locally 

1365 try: 

1366 if mark_success: 

1367 ti.set_state(TaskInstanceState.SUCCESS) 

1368 log.info("[DAG TEST] Marking success for %s on %s", task, ti.logical_date) 

1369 else: 

1370 _run_task(ti=ti, task=task, run_triggerer=True) 

1371 except Exception: 

1372 log.exception("Task failed; ti=%s", ti) 

1373 if use_executor: 

1374 executor.heartbeat() 

1375 session.expire_all() 

1376 

1377 from airflow.jobs.scheduler_job_runner import SchedulerJobRunner 

1378 from airflow.models.dagbag import DBDagBag 

1379 

1380 SchedulerJobRunner.process_executor_events( 

1381 executor=executor, job_id=None, scheduler_dag_bag=DBDagBag(), session=session 

1382 ) 

1383 if use_executor: 

1384 executor.end() 

1385 return dr 

1386 

1387 
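# Illustrative usage sketch: a minimal example of the two ``Dag.test()`` modes described
# inside the method above -- running tasks locally with no executor, or queuing workloads
# on the configured default executor. The dag id, task names, and argument values are
# hypothetical; only ``use_executor``, ``mark_success_pattern``, and ``run_conf`` come
# from the parameters the method body reads, and the full signature may accept more.
def _example_dag_test_usage():
    from airflow.sdk import DAG, task

    with DAG("example_dag_test") as example_dag:

        @task
        def extract():
            return {"rows": 3}

        @task
        def load(payload):
            print(payload)

        load(extract())

    # Mode 1: no executor -- schedulable task instances run in-process via ``_run_task``.
    example_dag.test(run_conf={"env": "ci"})

    # Mode 2: workloads are queued on the configured default executor instead; task ids
    # matching ``mark_success_pattern`` are marked success without being run.
    example_dag.test(use_executor=True, mark_success_pattern=r"load")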

1388def _run_task( 

1389 *, 

1390 ti: SchedulerTaskInstance, 

1391 task: Operator, 

1392 run_triggerer: bool = False, 

1393) -> TaskRunResult | None: 

1394 """ 

1395 Run a single task instance, and push the result to XCom for downstream tasks.

1396 

1397 Bypasses a lot of extra steps used in `task.run` to keep local runs as fast as

1398 possible. This function is only meant as a helper for the `dag.test` function.

1399 """ 

1400 from airflow.sdk._shared.module_loading import import_string 

1401 

1402 taskrun_result: TaskRunResult | None 

1403 log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index) 

1404 while True: 

1405 try: 

1406 log.info("[DAG TEST] running task %s", ti) 

1407 

1408 from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceSDK 

1409 from airflow.sdk.execution_time.comms import DeferTask 

1410 from airflow.sdk.execution_time.supervisor import run_task_in_process 

1411 from airflow.serialization.serialized_objects import create_scheduler_operator 

1412 

1413 # The API Server expects the task instance to be in QUEUED state before 

1414 # it is run. 

1415 ti.set_state(TaskInstanceState.QUEUED) 

1416 task_sdk_ti = TaskInstanceSDK( 

1417 id=UUID(str(ti.id)), 

1418 task_id=ti.task_id, 

1419 dag_id=ti.dag_id, 

1420 run_id=ti.run_id, 

1421 try_number=ti.try_number, 

1422 map_index=ti.map_index, 

1423 dag_version_id=UUID(str(ti.dag_version_id)), 

1424 ) 

1425 

1426 taskrun_result = run_task_in_process(ti=task_sdk_ti, task=task) 

1427 msg = taskrun_result.msg 

1428 ti.set_state(taskrun_result.ti.state) 

1429 ti.task = create_scheduler_operator(taskrun_result.ti.task) 

1430 

1431 if ti.state == TaskInstanceState.DEFERRED and isinstance(msg, DeferTask) and run_triggerer: 

1432 from airflow.sdk.serde import deserialize, serialize 

1433 from airflow.utils.session import create_session 

1434 

1435 # API Server expects the task instance to be in QUEUED state before 

1436 # resuming from deferral. 

1437 ti.set_state(TaskInstanceState.QUEUED) 

1438 

1439 log.info("[DAG TEST] running trigger in line") 

1440 # trigger_kwargs must be deserialized before being passed to the trigger class, since they arrive in serde-encoded form

1441 kwargs = deserialize(msg.trigger_kwargs) # type: ignore[type-var] # needed to convince mypy that trigger_kwargs is a dict or a str, because it is unable to infer JsonValue

1442 if TYPE_CHECKING: 

1443 assert isinstance(kwargs, dict) 

1444 trigger = import_string(msg.classpath)(**kwargs) 

1445 event = _run_inline_trigger(trigger, task_sdk_ti) 

1446 ti.next_method = msg.next_method 

1447 ti.next_kwargs = {"event": serialize(event.payload)} if event else msg.next_kwargs 

1448 log.info("[DAG TEST] Trigger completed") 

1449 

1450 # Set the state to SCHEDULED so that the task can be resumed. 

1451 with create_session() as session: 

1452 ti.state = TaskInstanceState.SCHEDULED 

1453 session.add(ti) 

1454 continue 

1455 

1456 break 

1457 except Exception: 

1458 log.exception("[DAG TEST] Error running task %s", ti) 

1459 if ti.state not in FINISHED_STATES: 

1460 ti.set_state(TaskInstanceState.FAILED) 

1461 taskrun_result = None 

1462 break 

1463 raise 

1464 

1465 log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index) 

1466 return taskrun_result 

1467 
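# Illustrative sketch of the trigger re-hydration pattern used in ``_run_task`` above:
# the trigger class is imported from the classpath carried in the DeferTask message and
# instantiated with the serde-decoded kwargs. The classpath and kwargs below are
# hypothetical stand-ins, assuming ``import_string`` resolves dotted paths of the form
# ``module.attribute``.
def _example_trigger_rehydration():
    from airflow.sdk._shared.module_loading import import_string

    classpath = "datetime.timedelta"  # stands in for msg.classpath
    kwargs = {"seconds": 30}  # stands in for deserialize(msg.trigger_kwargs)

    # Same pattern as ``import_string(msg.classpath)(**kwargs)`` in _run_task above.
    return import_string(classpath)(**kwargs)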

1468 

1469def _run_inline_trigger(trigger, task_sdk_ti): 

1470 from airflow.sdk.execution_time.supervisor import InProcessTestSupervisor 

1471 

1472 return InProcessTestSupervisor.run_trigger_in_process(trigger=trigger, ti=task_sdk_ti) 

1473 
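# Conceptual sketch of what "running a trigger inline" amounts to: iterating the
# trigger's async generator until it yields its first event. The real path goes through
# ``InProcessTestSupervisor.run_trigger_in_process``; this is a minimal stand-alone
# approximation, assuming the trigger follows the usual ``BaseTrigger.run()``
# async-generator contract.
def _example_first_trigger_event(trigger):
    import asyncio

    async def _first_event():
        # The first yielded TriggerEvent is what resumes the deferred task.
        async for event in trigger.run():
            return event
        return None

    return asyncio.run(_first_event())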

1474 

1475# Since we define all the attributes of the class with attrs, we can compute this statically at parse time 

1476DAG._DAG__serialized_fields = frozenset(a.name for a in attrs.fields(DAG)) - { # type: ignore[attr-defined] 

1477 "schedule_asset_references", 

1478 "schedule_asset_alias_references", 

1479 "task_outlet_asset_references", 

1480 "_old_context_manager_dags", 

1481 "safe_dag_id", 

1482 "last_loaded", 

1483 "user_defined_filters", 

1484 "user_defined_macros", 

1485 "partial", 

1486 "params", 

1487 "_log", 

1488 "task_dict", 

1489 "template_searchpath", 

1490 "sla_miss_callback", 

1491 "on_success_callback", 

1492 "on_failure_callback", 

1493 "template_undefined", 

1494 "jinja_environment_kwargs", 

1495 # has_on_*_callback are only stored if the value is True, as the default is False 

1496 "has_on_success_callback", 

1497 "has_on_failure_callback", 

1498 "auto_register", 

1499 "schedule", 

1500} 

1501 
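# Minimal sketch of how the frozenset above is assembled: ``attrs.fields()`` lists every
# attribute declared on an attrs class, and unwanted names are dropped with a plain set
# difference. The class and field names below are hypothetical.
def _example_attrs_field_enumeration():
    import attrs

    @attrs.define
    class Example:
        kept: int = 0
        also_kept: str = ""
        excluded: bool = False

    serialized = frozenset(a.name for a in attrs.fields(Example)) - {"excluded"}
    assert serialized == {"kept", "also_kept"}
    return serialized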

1502if TYPE_CHECKING: 

1503 # NOTE: Please keep the list of arguments in sync with DAG.__init__. 

1504 # Only exception: dag_id here should have a default value, but not in DAG. 

1505 @overload 

1506 def dag( 

1507 dag_id: str = "", 

1508 *, 

1509 description: str | None = None, 

1510 schedule: ScheduleArg = None, 

1511 start_date: datetime | None = None, 

1512 end_date: datetime | None = None, 

1513 template_searchpath: str | Iterable[str] | None = None, 

1514 template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, 

1515 user_defined_macros: dict | None = None, 

1516 user_defined_filters: dict | None = None, 

1517 default_args: dict[str, Any] | None = None, 

1518 max_active_tasks: int = ..., 

1519 max_active_runs: int = ..., 

1520 max_consecutive_failed_dag_runs: int = ..., 

1521 dagrun_timeout: timedelta | None = None, 

1522 catchup: bool = ..., 

1523 on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

1524 on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

1525 deadline: list[DeadlineAlert] | DeadlineAlert | None = None, 

1526 doc_md: str | None = None, 

1527 params: ParamsDict | dict[str, Any] | None = None, 

1528 access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None, 

1529 is_paused_upon_creation: bool | None = None, 

1530 jinja_environment_kwargs: dict | None = None, 

1531 render_template_as_native_obj: bool = False, 

1532 tags: Collection[str] | None = None, 

1533 owner_links: dict[str, str] | None = None, 

1534 auto_register: bool = True, 

1535 fail_fast: bool = False, 

1536 dag_display_name: str | None = None, 

1537 disable_bundle_versioning: bool = False, 

1538 ) -> Callable[[Callable], Callable[..., DAG]]: 

1539 """ 

1540 Python dag decorator which wraps a function into an Airflow Dag.

1541 

1542 Accepts keyword arguments that are forwarded to the ``DAG`` constructor, so it can be used to parameterize Dags.

1543 

1544 :param dag_id: Dag id; if empty, it defaults to the name of the decorated function.

1545 

1546 """

1547 

1548 @overload 

1549 def dag(func: Callable[..., DAG]) -> Callable[..., DAG]: 

1550 """Python dag decorator to use without any arguments.""" 

1551 

1552 

1553def dag(dag_id_or_func=None, __DAG_class=DAG, __warnings_stacklevel_delta=2, **decorator_kwargs): 

1554 from airflow.sdk.definitions._internal.decorators import fixup_decorator_warning_stack 

1555 

1556 # TODO: Task-SDK: remove __DAG_class 

1557 # __DAG_class is a temporary hack to allow the dag decorator in airflow.models.dag to continue to 

1558 # return SchedulerDag objects 

1559 DAG = __DAG_class 

1560 

1561 def wrapper(f: Callable) -> Callable[..., DAG]: 

1562 # Determine dag_id: prioritize the keyword arg, then a positional string, then fall back to the function name

1563 if "dag_id" in decorator_kwargs: 

1564 dag_id = decorator_kwargs.pop("dag_id", "") 

1565 elif isinstance(dag_id_or_func, str) and dag_id_or_func.strip(): 

1566 dag_id = dag_id_or_func 

1567 else: 

1568 dag_id = f.__name__ 

1569 

1570 @functools.wraps(f) 

1571 def factory(*args, **kwargs): 

1572 # Generate signature for decorated function and bind the arguments when called 

1573 # we do this to extract parameters, so we can annotate them on the DAG object. 

1574 # It also raises a TypeError, as expected, if any required args/kwargs are missing.

1575 f_sig = signature(f).bind(*args, **kwargs) 

1576 # Apply defaults to capture default values if set. 

1577 f_sig.apply_defaults() 

1578 

1579 # Initialize Dag with bound arguments 

1580 with DAG(dag_id, **decorator_kwargs) as dag_obj: 

1581 # Set Dag documentation from function documentation if it exists and doc_md is not set. 

1582 if f.__doc__ and not dag_obj.doc_md: 

1583 dag_obj.doc_md = f.__doc__ 

1584 

1585 # Generate a DagParam for each function arg/kwarg and substitute it when calling the function.

1586 # All args/kwargs of the function become DagParam objects that are resolved at execution time.

1587 f_kwargs = {} 

1588 for name, value in f_sig.arguments.items(): 

1589 f_kwargs[name] = dag_obj.param(name, value) 

1590 

1591 # set file location to caller source path 

1592 back = sys._getframe().f_back 

1593 dag_obj.fileloc = back.f_code.co_filename if back else "" 

1594 

1595 # Invoke function to create operators in the Dag scope. 

1596 f(**f_kwargs) 

1597 

1598 # Return dag object such that it's accessible in Globals. 

1599 return dag_obj 

1600 

1601 # Ensure that warnings from inside DAG() are emitted from the caller, not here 

1602 fixup_decorator_warning_stack(factory) 

1603 return factory 

1604 

1605 if callable(dag_id_or_func) and not isinstance(dag_id_or_func, str): 

1606 return wrapper(dag_id_or_func) 

1607 

1608 return wrapper
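# Illustrative sketch of the decorator defined above: the dag id is resolved in priority
# order (``dag_id`` keyword, then a positional string, then the function name), the
# function's parameters become DagParam objects, and calling the wrapped function builds
# and returns the Dag. The dag ids, task, and parameter names below are hypothetical.
def _example_dag_decorator_usage():
    from airflow.sdk import task

    @dag(schedule=None)  # no dag_id given, so it falls back to "my_pipeline"
    def my_pipeline(source: str = "s3"):
        @task
        def ingest(src):
            print(f"ingesting from {src}")

        ingest(source)  # ``source`` is replaced by a DagParam at parse time

    built = my_pipeline()  # calling the factory builds and returns the Dag object

    @dag("explicit_id")  # a positional string takes priority over the function name
    def another_pipeline():
        pass

    return built, another_pipeline()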