Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/dag.py: 31%

1753 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import asyncio 

21import copy 

22import functools 

23import itertools 

24import logging 

25import os 

26import pathlib 

27import pickle 

28import sys 

29import time 

30import traceback 

31import warnings 

32import weakref 

33from collections import abc, defaultdict, deque 

34from contextlib import ExitStack 

35from datetime import datetime, timedelta 

36from inspect import signature 

37from typing import ( 

38 TYPE_CHECKING, 

39 Any, 

40 Callable, 

41 Collection, 

42 Container, 

43 Generator, 

44 Iterable, 

45 Iterator, 

46 List, 

47 Pattern, 

48 Sequence, 

49 Union, 

50 cast, 

51 overload, 

52) 

53from urllib.parse import urlsplit 

54 

55import jinja2 

56import pendulum 

57import re2 

58import sqlalchemy_jsonfield 

59from dateutil.relativedelta import relativedelta 

60from sqlalchemy import ( 

61 Boolean, 

62 Column, 

63 ForeignKey, 

64 Index, 

65 Integer, 

66 String, 

67 Text, 

68 and_, 

69 case, 

70 func, 

71 not_, 

72 or_, 

73 select, 

74 update, 

75) 

76from sqlalchemy.ext.associationproxy import association_proxy 

77from sqlalchemy.orm import backref, joinedload, load_only, relationship 

78from sqlalchemy.sql import Select, expression 

79 

80import airflow.templates 

81from airflow import settings, utils 

82from airflow.api_internal.internal_api_call import internal_api_call 

83from airflow.configuration import conf as airflow_conf, secrets_backend_list 

84from airflow.datasets import BaseDataset, Dataset, DatasetAll 

85from airflow.datasets.manager import dataset_manager 

86from airflow.exceptions import ( 

87 AirflowDagInconsistent, 

88 AirflowException, 

89 DuplicateTaskIdFound, 

90 FailStopDagInvalidTriggerRule, 

91 ParamValidationError, 

92 RemovedInAirflow3Warning, 

93 TaskDeferred, 

94 TaskNotFound, 

95) 

96from airflow.jobs.job import run_job 

97from airflow.models.abstractoperator import AbstractOperator, TaskStateChangeCallback 

98from airflow.models.base import Base, StringID 

99from airflow.models.baseoperator import BaseOperator 

100from airflow.models.dagcode import DagCode 

101from airflow.models.dagpickle import DagPickle 

102from airflow.models.dagrun import RUN_ID_REGEX, DagRun 

103from airflow.models.dataset import DatasetDagRunQueue, DatasetModel 

104from airflow.models.param import DagParam, ParamsDict 

105from airflow.models.taskinstance import ( 

106 Context, 

107 TaskInstance, 

108 TaskInstanceKey, 

109 clear_task_instances, 

110) 

111from airflow.secrets.local_filesystem import LocalFilesystemBackend 

112from airflow.security import permissions 

113from airflow.settings import json 

114from airflow.stats import Stats 

115from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable 

116from airflow.timetables.datasets import DatasetOrTimeSchedule 

117from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable 

118from airflow.timetables.simple import ( 

119 ContinuousTimetable, 

120 DatasetTriggeredTimetable, 

121 NullTimetable, 

122 OnceTimetable, 

123) 

124from airflow.timetables.trigger import CronTriggerTimetable 

125from airflow.utils import timezone 

126from airflow.utils.dag_cycle_tester import check_cycle 

127from airflow.utils.dates import cron_presets, date_range as utils_date_range 

128from airflow.utils.decorators import fixup_decorator_warning_stack 

129from airflow.utils.helpers import at_most_one, exactly_one, validate_key 

130from airflow.utils.log.logging_mixin import LoggingMixin 

131from airflow.utils.session import NEW_SESSION, provide_session 

132from airflow.utils.sqlalchemy import ( 

133 Interval, 

134 UtcDateTime, 

135 lock_rows, 

136 tuple_in_condition, 

137 with_row_locks, 

138) 

139from airflow.utils.state import DagRunState, State, TaskInstanceState 

140from airflow.utils.trigger_rule import TriggerRule 

141from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType 

142 

143if TYPE_CHECKING: 

144 from types import ModuleType 

145 

146 from pendulum.tz.timezone import FixedTimezone, Timezone 

147 from sqlalchemy.orm.query import Query 

148 from sqlalchemy.orm.session import Session 

149 

150 from airflow.decorators import TaskDecoratorCollection 

151 from airflow.models.dagbag import DagBag 

152 from airflow.models.operator import Operator 

153 from airflow.models.slamiss import SlaMiss 

154 from airflow.serialization.pydantic.dag import DagModelPydantic 

155 from airflow.serialization.pydantic.dag_run import DagRunPydantic 

156 from airflow.typing_compat import Literal 

157 from airflow.utils.task_group import TaskGroup 

158 

159 # This is a workaround because mypy doesn't work with hybrid_property 

160 # TODO: remove this hack and move hybrid_property back to main import block 

161 # See https://github.com/python/mypy/issues/4430 

162 hybrid_property = property 

163else: 

164 from sqlalchemy.ext.hybrid import hybrid_property 

165 

166log = logging.getLogger(__name__) 

167 

168DEFAULT_VIEW_PRESETS = ["grid", "graph", "duration", "gantt", "landing_times"] 

169ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"] 

170 

171TAG_MAX_LEN = 100 

172 

173DagStateChangeCallback = Callable[[Context], None] 

174ScheduleInterval = Union[None, str, timedelta, relativedelta] 

175 

176# FIXME: Ideally this should be Union[Literal[NOTSET], ScheduleInterval], 

177# but Mypy cannot handle that right now. Track progress of PEP 661 for progress. 

178# See also: https://discuss.python.org/t/9126/7 

179ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval] 

180ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, BaseDataset, Collection["Dataset"]] 

181 

182SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None] 

183 

184# Backward compatibility: If neither schedule_interval nor timetable is 

185# *provided by the user*, default to a one-day interval. 

186DEFAULT_SCHEDULE_INTERVAL = timedelta(days=1) 

187 

188 

189class InconsistentDataInterval(AirflowException): 

190 """Exception raised when a model populates data interval fields incorrectly. 

191 

192 The data interval fields should either both be None (for runs scheduled 

193 prior to AIP-39), or both be datetime (for runs scheduled after AIP-39 is 

194 implemented). This is raised if exactly one of the fields is None. 

195 """ 

196 

197 _template = ( 

198 "Inconsistent {cls}: {start[0]}={start[1]!r}, {end[0]}={end[1]!r}, " 

199 "they must be either both None or both datetime" 

200 ) 

201 

202 def __init__(self, instance: Any, start_field_name: str, end_field_name: str) -> None: 

203 self._class_name = type(instance).__name__ 

204 self._start_field = (start_field_name, getattr(instance, start_field_name)) 

205 self._end_field = (end_field_name, getattr(instance, end_field_name)) 

206 

207 def __str__(self) -> str: 

208 return self._template.format(cls=self._class_name, start=self._start_field, end=self._end_field) 

209 

210 

211def _get_model_data_interval( 

212 instance: Any, 

213 start_field_name: str, 

214 end_field_name: str, 

215) -> DataInterval | None: 

216 start = timezone.coerce_datetime(getattr(instance, start_field_name)) 

217 end = timezone.coerce_datetime(getattr(instance, end_field_name)) 

218 if start is None: 

219 if end is not None: 

220 raise InconsistentDataInterval(instance, start_field_name, end_field_name) 

221 return None 

222 elif end is None: 

223 raise InconsistentDataInterval(instance, start_field_name, end_field_name) 

224 return DataInterval(start, end) 

225 

226 
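A minimal sketch (not part of dag.py) of the both-or-neither rule enforced above, assuming an Airflow 2.x environment where these module-level helpers are importable; the fake run object is a stand-in for a DagRun:

from datetime import datetime, timezone as dt_timezone
from types import SimpleNamespace

from airflow.models.dag import InconsistentDataInterval, _get_model_data_interval

fake_run = SimpleNamespace(
    data_interval_start=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
    data_interval_end=None,  # exactly one side missing -> inconsistent
)
try:
    _get_model_data_interval(fake_run, "data_interval_start", "data_interval_end")
except InconsistentDataInterval as err:
    print(err)  # "... they must be either both None or both datetime"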

227def create_timetable(interval: ScheduleIntervalArg, timezone: Timezone | FixedTimezone) -> Timetable: 

228 """Create a Timetable instance from a ``schedule_interval`` argument.""" 

229 if interval is NOTSET: 

230 return DeltaDataIntervalTimetable(DEFAULT_SCHEDULE_INTERVAL) 

231 if interval is None: 

232 return NullTimetable() 

233 if interval == "@once": 

234 return OnceTimetable() 

235 if interval == "@continuous": 

236 return ContinuousTimetable() 

237 if isinstance(interval, (timedelta, relativedelta)): 

238 return DeltaDataIntervalTimetable(interval) 

239 if isinstance(interval, str): 

240 if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"): 

241 return CronDataIntervalTimetable(interval, timezone) 

242 else: 

243 return CronTriggerTimetable(interval, timezone=timezone) 

244 raise ValueError(f"{interval!r} is not a valid schedule_interval.") 

245 

246 
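A sketch (not part of dag.py) of how the dispatch above maps common ``schedule_interval`` values to timetables, assuming an Airflow 2.x install; the cron case depends on the ``scheduler.create_cron_data_intervals`` setting:

from datetime import timedelta

import pendulum

from airflow.models.dag import create_timetable
from airflow.utils.types import NOTSET

utc = pendulum.timezone("UTC")
print(type(create_timetable(NOTSET, utc)).__name__)              # DeltaDataIntervalTimetable (1-day default)
print(type(create_timetable(None, utc)).__name__)                # NullTimetable
print(type(create_timetable("@once", utc)).__name__)             # OnceTimetable
print(type(create_timetable(timedelta(hours=6), utc)).__name__)  # DeltaDataIntervalTimetable
print(type(create_timetable("0 3 * * *", utc)).__name__)         # CronDataIntervalTimetable or CronTriggerTimetable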

247def get_last_dagrun(dag_id, session, include_externally_triggered=False): 

248 """ 

249 Return the last dag run for a dag, None if there was none. 

250 

251 Last dag run can be any type of run e.g. scheduled or backfilled. 

252 Overridden DagRuns are ignored. 

253 """ 

254 DR = DagRun 

255 query = select(DR).where(DR.dag_id == dag_id) 

256 if not include_externally_triggered: 

257 query = query.where(DR.external_trigger == expression.false()) 

258 query = query.order_by(DR.execution_date.desc()) 

259 return session.scalar(query.limit(1)) 

260 

261 
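A usage sketch (not part of dag.py), assuming a configured metadata database and an existing dag_id ("example_dag" is hypothetical):

from airflow.models.dag import get_last_dagrun
from airflow.utils.session import create_session

with create_session() as session:
    last_run = get_last_dagrun("example_dag", session=session, include_externally_triggered=True)
    if last_run is not None:
        print(last_run.run_id, last_run.execution_date)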

262def get_dataset_triggered_next_run_info( 

263 dag_ids: list[str], *, session: Session 

264) -> dict[str, dict[str, int | str]]: 

265 """ 

266 Get next run info for a list of dag_ids. 

267 

268 Given a list of dag_ids, get a string representing how close any dataset-triggered DAGs are to 

269 their next run, e.g. "1 of 2 datasets updated". 

270 """ 

271 from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ, DatasetModel 

272 

273 return { 

274 x.dag_id: { 

275 "uri": x.uri, 

276 "ready": x.ready, 

277 "total": x.total, 

278 } 

279 for x in session.execute( 

280 select( 

281 DagScheduleDatasetReference.dag_id, 

282 # This is a dirty hack to workaround group by requiring an aggregate, 

283 # since grouping by dataset is not what we want to do here...but it works 

284 case((func.count() == 1, func.max(DatasetModel.uri)), else_="").label("uri"), 

285 func.count().label("total"), 

286 func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)).label("ready"), 

287 ) 

288 .join( 

289 DDRQ, 

290 and_( 

291 DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id, 

292 DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id, 

293 ), 

294 isouter=True, 

295 ) 

296 .join(DatasetModel, DatasetModel.id == DagScheduleDatasetReference.dataset_id) 

297 .group_by(DagScheduleDatasetReference.dag_id) 

298 .where(DagScheduleDatasetReference.dag_id.in_(dag_ids)) 

299 ).all() 

300 } 

301 

302 
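A sketch (not part of dag.py) of the mapping returned above, for a hypothetical dataset-triggered DAG that waits on two datasets and has seen one update:

example_result = {
    "my_dataset_dag": {
        "uri": "",     # empty when the DAG depends on more than one dataset
        "ready": 1,    # datasets already queued for this DAG
        "total": 2,    # datasets the DAG is scheduled on
    },
}
# The UI renders this as "1 of 2 datasets updated".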

303def _triggerer_is_healthy(): 

304 from airflow.jobs.triggerer_job_runner import TriggererJobRunner 

305 

306 job = TriggererJobRunner.most_recent_job() 

307 return job and job.is_alive() 

308 

309 

310@internal_api_call 

311@provide_session 

312def _create_orm_dagrun( 

313 dag, 

314 dag_id, 

315 run_id, 

316 logical_date, 

317 start_date, 

318 external_trigger, 

319 conf, 

320 state, 

321 run_type, 

322 dag_hash, 

323 creating_job_id, 

324 data_interval, 

325 session, 

326): 

327 run = DagRun( 

328 dag_id=dag_id, 

329 run_id=run_id, 

330 execution_date=logical_date, 

331 start_date=start_date, 

332 external_trigger=external_trigger, 

333 conf=conf, 

334 state=state, 

335 run_type=run_type, 

336 dag_hash=dag_hash, 

337 creating_job_id=creating_job_id, 

338 data_interval=data_interval, 

339 ) 

340 session.add(run) 

341 session.flush() 

342 run.dag = dag 

343 # create the associated task instances 

344 # state is None at the moment of creation 

345 run.verify_integrity(session=session) 

346 return run 

347 

348 

349@functools.total_ordering 

350class DAG(LoggingMixin): 

351 """ 

352 A dag (directed acyclic graph) is a collection of tasks with directional dependencies. 

353 

354 A dag also has a schedule, a start date and an end date (optional). For each schedule 

355 (say daily or hourly), the DAG needs to run each individual task as its dependencies 

356 are met. Certain tasks have the property of depending on their own past, meaning that 

357 they can't run until their previous schedule (and upstream tasks) are completed. 

358 

359 DAGs essentially act as namespaces for tasks. A task_id can only be 

360 added once to a DAG. 

361 

362 Note that if you plan to use time zones all the dates provided should be pendulum 

363 dates. See :ref:`timezone_aware_dags`. 

364 

365 .. versionadded:: 2.4 

366 The *schedule* argument to specify either time-based scheduling logic 

367 (timetable), or dataset-driven triggers. 

368 

369 .. deprecated:: 2.4 

370 The arguments *schedule_interval* and *timetable*. Their functionalities 

371 are merged into the new *schedule* argument. 

372 

373 :param dag_id: The id of the DAG; must consist exclusively of alphanumeric 

374 characters, dashes, dots and underscores (all ASCII) 

375 :param description: The description for the DAG to e.g. be shown on the webserver 

376 :param schedule: Defines the rules according to which DAG runs are scheduled. Can 

377 accept cron string, timedelta object, Timetable, or list of Dataset objects. 

378 If this is not provided, the DAG will be set to the default 

379 schedule ``timedelta(days=1)``. See also :doc:`/howto/timetable`. 

380 :param start_date: The timestamp from which the scheduler will 

381 attempt to backfill 

382 :param end_date: A date beyond which your DAG won't run, leave to None 

383 for open-ended scheduling 

384 :param template_searchpath: This list of folders (non-relative) 

385 defines where jinja will look for your templates. Order matters. 

386 Note that jinja/airflow includes the path of your DAG file by 

387 default 

388 :param template_undefined: Template undefined type. 

389 :param user_defined_macros: a dictionary of macros that will be exposed 

390 in your jinja templates. For example, passing ``dict(foo='bar')`` 

391 to this argument allows you to ``{{ foo }}`` in all jinja 

392 templates related to this DAG. Note that you can pass any 

393 type of object here. 

394 :param user_defined_filters: a dictionary of filters that will be exposed 

395 in your jinja templates. For example, passing 

396 ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows 

397 you to ``{{ 'world' | hello }}`` in all jinja templates related to 

398 this DAG. 

399 :param default_args: A dictionary of default parameters to be used 

400 as constructor keyword parameters when initialising operators. 

401 Note that arguments passed explicitly to an operator take precedence over the 

402 values defined here, meaning that if your dict contains `'depends_on_past': True` 

403 here and `'depends_on_past': False` is passed explicitly in the operator's 

404 call, the actual value will be `False`. 

405 :param params: a dictionary of DAG level parameters that are made 

406 accessible in templates, namespaced under `params`. These 

407 params can be overridden at the task level. 

408 :param max_active_tasks: the number of task instances allowed to run 

409 concurrently 

410 :param max_active_runs: maximum number of active DAG runs; beyond this 

411 number of DAG runs in a running state, the scheduler won't create 

412 new active DAG runs 

413 :param max_consecutive_failed_dag_runs: (experimental) maximum number of consecutive failed DAG runs, 

414 beyond this the scheduler will disable the DAG 

415 :param dagrun_timeout: specify how long a DagRun should be up before 

416 timing out / failing, so that new DagRuns can be created. 

417 :param sla_miss_callback: specify a function or list of functions to call when reporting SLA 

418 timeouts. See :ref:`sla_miss_callback<concepts:sla_miss_callback>` for 

419 more information about the function signature and parameters that are 

420 passed to the callback. 

421 :param default_view: Specify DAG default view (grid, graph, duration, 

422 gantt, landing_times), default grid 

423 :param orientation: Specify DAG orientation in graph view (LR, TB, RL, BT), default LR 

424 :param catchup: Perform scheduler catchup (or only run latest)? Defaults to True 

425 :param on_failure_callback: A function or list of functions to be called when a DagRun of this dag fails. 

426 A context dictionary is passed as a single parameter to this function. 

427 :param on_success_callback: Much like the ``on_failure_callback`` except 

428 that it is executed when the dag succeeds. 

429 :param access_control: Specify optional DAG-level actions, e.g., 

430 "{'role1': {'can_read'}, 'role2': {'can_read', 'can_edit', 'can_delete'}}" 

431 :param is_paused_upon_creation: Specifies if the dag is paused when created for the first time. 

432 If the dag exists already, this flag will be ignored. If this optional parameter 

433 is not specified, the global config setting will be used. 

434 :param jinja_environment_kwargs: additional configuration options to be passed to Jinja 

435 ``Environment`` for template rendering 

436 

437 **Example**: to avoid Jinja from removing a trailing newline from template strings :: 

438 

439 DAG( 

440 dag_id="my-dag", 

441 jinja_environment_kwargs={ 

442 "keep_trailing_newline": True, 

443 # some other jinja2 Environment options here 

444 }, 

445 ) 

446 

447 **See**: `Jinja Environment documentation 

448 <https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.Environment>`_ 

449 

450 :param render_template_as_native_obj: If True, uses a Jinja ``NativeEnvironment`` 

451 to render templates as native Python types. If False, a Jinja 

452 ``Environment`` is used to render templates as string values. 

453 :param tags: List of tags to help filter DAGs in the UI. 

454 :param owner_links: Dict of owners and their links, which will be clickable in the DAGs view UI. 

455 Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link. 

456 e.g: {"dag_owner": "https://airflow.apache.org/"} 

457 :param auto_register: Automatically register this DAG when it is used in a ``with`` block 

458 :param fail_stop: Fails currently running tasks when a task in the DAG fails. 

459 **Warning**: A fail-stop dag can only have tasks with the default trigger rule ("all_success"). 

460 An exception will be thrown if any task in a fail-stop dag has a non-default trigger rule. 

461 :param dag_display_name: The display name of the DAG which appears on the UI. 

462 """ 

463 
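A minimal usage sketch of the constructor documented above (not part of dag.py; the dag_id, schedule, and operator are illustrative):

from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.empty import EmptyOperator

with DAG(
    dag_id="example_minimal",
    schedule=timedelta(days=1),
    start_date=datetime(2024, 1, 1),
    catchup=False,
    default_args={"retries": 2},  # explicit operator arguments still take precedence
    tags=["example"],
) as dag:
    EmptyOperator(task_id="noop")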

464 _comps = { 

465 "dag_id", 

466 "task_ids", 

467 "parent_dag", 

468 "start_date", 

469 "end_date", 

470 "schedule_interval", 

471 "fileloc", 

472 "template_searchpath", 

473 "last_loaded", 

474 } 

475 

476 __serialized_fields: frozenset[str] | None = None 

477 

478 fileloc: str 

479 """ 

480 File path that needs to be imported to load this DAG or subdag. 

481 

482 This may not be an actual file on disk in the case when this DAG is loaded 

483 from a ZIP file or other DAG distribution format. 

484 """ 

485 

486 parent_dag: DAG | None = None # Gets set when DAGs are loaded 

487 

488 # NOTE: When updating arguments here, please also keep arguments in @dag() 

489 # below in sync. (Search for 'def dag(' in this file.) 

490 def __init__( 

491 self, 

492 dag_id: str, 

493 description: str | None = None, 

494 schedule: ScheduleArg = NOTSET, 

495 schedule_interval: ScheduleIntervalArg = NOTSET, 

496 timetable: Timetable | None = None, 

497 start_date: datetime | None = None, 

498 end_date: datetime | None = None, 

499 full_filepath: str | None = None, 

500 template_searchpath: str | Iterable[str] | None = None, 

501 template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, 

502 user_defined_macros: dict | None = None, 

503 user_defined_filters: dict | None = None, 

504 default_args: dict | None = None, 

505 concurrency: int | None = None, 

506 max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"), 

507 max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"), 

508 max_consecutive_failed_dag_runs: int = airflow_conf.getint( 

509 "core", "max_consecutive_failed_dag_runs_per_dag" 

510 ), 

511 dagrun_timeout: timedelta | None = None, 

512 sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None, 

513 default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(), 

514 orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"), 

515 catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"), 

516 on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

517 on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

518 doc_md: str | None = None, 

519 params: abc.MutableMapping | None = None, 

520 access_control: dict | None = None, 

521 is_paused_upon_creation: bool | None = None, 

522 jinja_environment_kwargs: dict | None = None, 

523 render_template_as_native_obj: bool = False, 

524 tags: list[str] | None = None, 

525 owner_links: dict[str, str] | None = None, 

526 auto_register: bool = True, 

527 fail_stop: bool = False, 

528 dag_display_name: str | None = None, 

529 ): 

530 from airflow.utils.task_group import TaskGroup 

531 

532 if tags and any(len(tag) > TAG_MAX_LEN for tag in tags): 

533 raise AirflowException(f"tag cannot be longer than {TAG_MAX_LEN} characters") 

534 

535 self.owner_links = owner_links or {} 

536 self.user_defined_macros = user_defined_macros 

537 self.user_defined_filters = user_defined_filters 

538 if default_args and not isinstance(default_args, dict): 

539 raise TypeError("default_args must be a dict") 

540 self.default_args = copy.deepcopy(default_args or {}) 

541 params = params or {} 

542 

543 # merging potentially conflicting default_args['params'] into params 

544 if "params" in self.default_args: 

545 params.update(self.default_args["params"]) 

546 del self.default_args["params"] 

547 

548 # check self.params and convert them into ParamsDict 

549 self.params = ParamsDict(params) 

550 

551 if full_filepath: 

552 warnings.warn( 

553 "Passing full_filepath to DAG() is deprecated and has no effect", 

554 RemovedInAirflow3Warning, 

555 stacklevel=2, 

556 ) 

557 

558 validate_key(dag_id) 

559 

560 self._dag_id = dag_id 

561 self._dag_display_property_value = dag_display_name 

562 

563 if concurrency: 

564 # TODO: Remove in Airflow 3.0 

565 warnings.warn( 

566 "The 'concurrency' parameter is deprecated. Please use 'max_active_tasks'.", 

567 RemovedInAirflow3Warning, 

568 stacklevel=2, 

569 ) 

570 max_active_tasks = concurrency 

571 self._max_active_tasks = max_active_tasks 

572 self._pickle_id: int | None = None 

573 

574 self._description = description 

575 # set file location to caller source path 

576 back = sys._getframe().f_back 

577 self.fileloc = back.f_code.co_filename if back else "" 

578 self.task_dict: dict[str, Operator] = {} 

579 

580 # set timezone from start_date 

581 tz = None 

582 if start_date and start_date.tzinfo: 

583 tzinfo = None if start_date.tzinfo else settings.TIMEZONE 

584 tz = pendulum.instance(start_date, tz=tzinfo).timezone 

585 elif date := self.default_args.get("start_date"): 

586 if not isinstance(date, datetime): 

587 date = timezone.parse(date) 

588 self.default_args["start_date"] = date 

589 start_date = date 

590 

591 tzinfo = None if date.tzinfo else settings.TIMEZONE 

592 tz = pendulum.instance(date, tz=tzinfo).timezone 

593 self.timezone: Timezone | FixedTimezone = tz or settings.TIMEZONE 

594 

595 # Apply the timezone we settled on to end_date if it wasn't supplied 

596 if isinstance(_end_date := self.default_args.get("end_date"), str): 

597 self.default_args["end_date"] = timezone.parse(_end_date, timezone=self.timezone) 

598 

599 self.start_date = timezone.convert_to_utc(start_date) 

600 self.end_date = timezone.convert_to_utc(end_date) 

601 

602 # also convert tasks 

603 if "start_date" in self.default_args: 

604 self.default_args["start_date"] = timezone.convert_to_utc(self.default_args["start_date"]) 

605 if "end_date" in self.default_args: 

606 self.default_args["end_date"] = timezone.convert_to_utc(self.default_args["end_date"]) 

607 

608 # sort out DAG's scheduling behavior 

609 scheduling_args = [schedule_interval, timetable, schedule] 

610 

611 has_scheduling_args = any(a is not NOTSET and bool(a) for a in scheduling_args) 

612 has_empty_start_date = not ("start_date" in self.default_args or self.start_date) 

613 

614 if has_scheduling_args and has_empty_start_date: 

615 raise ValueError("DAG is missing the start_date parameter") 

616 

617 if not at_most_one(*scheduling_args): 

618 raise ValueError("At most one allowed for args 'schedule_interval', 'timetable', and 'schedule'.") 

619 if schedule_interval is not NOTSET: 

620 warnings.warn( 

621 "Param `schedule_interval` is deprecated and will be removed in a future release. " 

622 "Please use `schedule` instead. ", 

623 RemovedInAirflow3Warning, 

624 stacklevel=2, 

625 ) 

626 if timetable is not None: 

627 warnings.warn( 

628 "Param `timetable` is deprecated and will be removed in a future release. " 

629 "Please use `schedule` instead. ", 

630 RemovedInAirflow3Warning, 

631 stacklevel=2, 

632 ) 

633 

634 self.timetable: Timetable 

635 self.schedule_interval: ScheduleInterval 

636 self.dataset_triggers: BaseDataset | None = None 

637 if isinstance(schedule, BaseDataset): 

638 self.dataset_triggers = schedule 

639 elif isinstance(schedule, Collection) and not isinstance(schedule, str): 

640 if not all(isinstance(x, Dataset) for x in schedule): 

641 raise ValueError("All elements in 'schedule' should be datasets") 

642 self.dataset_triggers = DatasetAll(*schedule) 

643 elif isinstance(schedule, Timetable): 

644 timetable = schedule 

645 elif schedule is not NOTSET and not isinstance(schedule, BaseDataset): 

646 schedule_interval = schedule 

647 

648 if isinstance(schedule, DatasetOrTimeSchedule): 

649 self.timetable = schedule 

650 self.dataset_triggers = self.timetable.datasets 

651 self.schedule_interval = self.timetable.summary 

652 elif self.dataset_triggers: 

653 self.timetable = DatasetTriggeredTimetable() 

654 self.schedule_interval = self.timetable.summary 

655 elif timetable: 

656 self.timetable = timetable 

657 self.schedule_interval = self.timetable.summary 

658 else: 

659 if isinstance(schedule_interval, ArgNotSet): 

660 schedule_interval = DEFAULT_SCHEDULE_INTERVAL 

661 self.schedule_interval = schedule_interval 

662 self.timetable = create_timetable(schedule_interval, self.timezone) 

663 

664 if isinstance(template_searchpath, str): 

665 template_searchpath = [template_searchpath] 

666 self.template_searchpath = template_searchpath 

667 self.template_undefined = template_undefined 

668 self.last_loaded: datetime = timezone.utcnow() 

669 self.safe_dag_id = dag_id.replace(".", "__dot__") 

670 self.max_active_runs = max_active_runs 

671 self.max_consecutive_failed_dag_runs = max_consecutive_failed_dag_runs 

672 if self.max_consecutive_failed_dag_runs == 0: 

673 self.max_consecutive_failed_dag_runs = airflow_conf.getint( 

674 "core", "max_consecutive_failed_dag_runs_per_dag" 

675 ) 

676 if self.max_consecutive_failed_dag_runs < 0: 

677 raise AirflowException( 

678 f"Invalid max_consecutive_failed_dag_runs: {self.max_consecutive_failed_dag_runs}. " 

679 f"Requires max_consecutive_failed_dag_runs >= 0" 

680 ) 

681 if self.timetable.active_runs_limit is not None: 

682 if self.timetable.active_runs_limit < self.max_active_runs: 

683 raise AirflowException( 

684 f"Invalid max_active_runs: {type(self.timetable)} " 

685 f"requires max_active_runs <= {self.timetable.active_runs_limit}" 

686 ) 

687 self.dagrun_timeout = dagrun_timeout 

688 self.sla_miss_callback = sla_miss_callback 

689 if default_view in DEFAULT_VIEW_PRESETS: 

690 self._default_view: str = default_view 

691 elif default_view == "tree": 

692 warnings.warn( 

693 "`default_view` of 'tree' has been renamed to 'grid' -- please update your DAG", 

694 RemovedInAirflow3Warning, 

695 stacklevel=2, 

696 ) 

697 self._default_view = "grid" 

698 else: 

699 raise AirflowException( 

700 f"Invalid value for dag.default_view: only " 

701 f"{DEFAULT_VIEW_PRESETS} are supported, but got {default_view}" 

702 ) 

703 if orientation in ORIENTATION_PRESETS: 

704 self.orientation = orientation 

705 else: 

706 raise AirflowException( 

707 f"Invalid value for dag.orientation: only " 

708 f"{ORIENTATION_PRESETS} are supported, but got {orientation}" 

709 ) 

710 self.catchup: bool = catchup 

711 

712 self.partial: bool = False 

713 self.on_success_callback = on_success_callback 

714 self.on_failure_callback = on_failure_callback 

715 

716 # Keeps track of any extra edge metadata (sparse; will not contain all 

717 # edges, so do not iterate over it for that). Outer key is upstream 

718 # task ID, inner key is downstream task ID. 

719 self.edge_info: dict[str, dict[str, EdgeInfoType]] = {} 

720 

721 # To keep it in parity with Serialized DAGs 

722 # and identify if DAG has on_*_callback without actually storing them in Serialized JSON 

723 self.has_on_success_callback: bool = self.on_success_callback is not None 

724 self.has_on_failure_callback: bool = self.on_failure_callback is not None 

725 

726 self._access_control = DAG._upgrade_outdated_dag_access_control(access_control) 

727 self.is_paused_upon_creation = is_paused_upon_creation 

728 self.auto_register = auto_register 

729 

730 self.fail_stop: bool = fail_stop 

731 

732 self.jinja_environment_kwargs = jinja_environment_kwargs 

733 self.render_template_as_native_obj = render_template_as_native_obj 

734 

735 self.doc_md = self.get_doc_md(doc_md) 

736 

737 self.tags = tags or [] 

738 self._task_group = TaskGroup.create_root(self) 

739 self.validate_schedule_and_params() 

740 wrong_links = dict(self.iter_invalid_owner_links()) 

741 if wrong_links: 

742 raise AirflowException( 

743 "Wrong link format was used for the owner. Use a valid link \n" 

744 f"Badly formatted links are: {wrong_links}" 

745 ) 

746 

747 # this will only be set at serialization time 

748 # its only use is for determining the relative 

749 # fileloc based only on the serialized dag 

750 self._processor_dags_folder = None 

751 

752 def get_doc_md(self, doc_md: str | None) -> str | None: 

753 if doc_md is None: 

754 return doc_md 

755 

756 env = self.get_template_env(force_sandboxed=True) 

757 

758 if not doc_md.endswith(".md"): 

759 template = jinja2.Template(doc_md) 

760 else: 

761 try: 

762 template = env.get_template(doc_md) 

763 except jinja2.exceptions.TemplateNotFound: 

764 return f""" 

765 # Templating Error! 

766 Not able to find the template file: `{doc_md}`. 

767 """ 

768 

769 return template.render() 

770 

771 def _check_schedule_interval_matches_timetable(self) -> bool: 

772 """Check ``schedule_interval`` and ``timetable`` match. 

773 

774 This is done as a part of the DAG validation done before it's bagged, to 

775 guard against the DAG's ``timetable`` (or ``schedule_interval``) from 

776 being changed after it's created, e.g. 

777 

778 .. code-block:: python 

779 

780 dag1 = DAG("d1", timetable=MyTimetable()) 

781 dag1.schedule_interval = "@once" 

782 

783 dag2 = DAG("d2", schedule="@once") 

784 dag2.timetable = MyTimetable() 

785 

786 Validation is done by creating a timetable and check its summary matches 

787 ``schedule_interval``. The logic is not bullet-proof, especially if a 

788 custom timetable does not provide a useful ``summary``. But this is the 

789 best we can do. 

790 """ 

791 if self.schedule_interval == self.timetable.summary: 

792 return True 

793 try: 

794 timetable = create_timetable(self.schedule_interval, self.timezone) 

795 except ValueError: 

796 return False 

797 return timetable.summary == self.timetable.summary 

798 

799 def validate(self): 

800 """Validate the DAG has a coherent setup. 

801 

802 This is called by the DAG bag before bagging the DAG. 

803 """ 

804 if not self._check_schedule_interval_matches_timetable(): 

805 raise AirflowDagInconsistent( 

806 f"inconsistent schedule: timetable {self.timetable.summary!r} " 

807 f"does not match schedule_interval {self.schedule_interval!r}", 

808 ) 

809 self.validate_schedule_and_params() 

810 self.timetable.validate() 

811 self.validate_setup_teardown() 

812 

813 def validate_setup_teardown(self): 

814 """ 

815 Validate that setup and teardown tasks are configured properly. 

816 

817 :meta private: 

818 """ 

819 for task in self.tasks: 

820 if task.is_setup: 

821 for down_task in task.downstream_list: 

822 if not down_task.is_teardown and down_task.trigger_rule != TriggerRule.ALL_SUCCESS: 

823 # todo: we can relax this to allow out-of-scope tasks to have other trigger rules 

824 # this is required to ensure consistent behavior of dag 

825 # when clearing an indirect setup 

826 raise ValueError("Setup tasks must be followed with trigger rule ALL_SUCCESS.") 

827 FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule) 

828 

829 def __repr__(self): 

830 return f"<DAG: {self.dag_id}>" 

831 

832 def __eq__(self, other): 

833 if type(self) == type(other): 

834 # Use getattr() instead of __dict__ as __dict__ doesn't return 

835 # correct values for properties. 

836 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps) 

837 return False 

838 

839 def __ne__(self, other): 

840 return not self == other 

841 

842 def __lt__(self, other): 

843 return self.dag_id < other.dag_id 

844 

845 def __hash__(self): 

846 hash_components = [type(self)] 

847 for c in self._comps: 

848 # task_ids returns a list and lists can't be hashed 

849 if c == "task_ids": 

850 val = tuple(self.task_dict) 

851 else: 

852 val = getattr(self, c, None) 

853 try: 

854 hash(val) 

855 hash_components.append(val) 

856 except TypeError: 

857 hash_components.append(repr(val)) 

858 return hash(tuple(hash_components)) 

859 

860 # Context Manager ----------------------------------------------- 

861 def __enter__(self): 

862 DagContext.push_context_managed_dag(self) 

863 return self 

864 

865 def __exit__(self, _type, _value, _tb): 

866 DagContext.pop_context_managed_dag() 

867 

868 # /Context Manager ---------------------------------------------- 

869 

870 @staticmethod 

871 def _upgrade_outdated_dag_access_control(access_control=None): 

872 """ 

873 Look for outdated dag level actions in DAG access_controls and replace them with updated actions. 

874 

875 For example, in DAG access_control {'role1': {'can_dag_read'}} 'can_dag_read' 

876 will be replaced with 'can_read', in {'role2': {'can_dag_read', 'can_dag_edit'}} 

877 'can_dag_edit' will be replaced with 'can_edit', etc. 

878 """ 

879 if access_control is None: 

880 return None 

881 new_perm_mapping = { 

882 permissions.DEPRECATED_ACTION_CAN_DAG_READ: permissions.ACTION_CAN_READ, 

883 permissions.DEPRECATED_ACTION_CAN_DAG_EDIT: permissions.ACTION_CAN_EDIT, 

884 } 

885 updated_access_control = {} 

886 for role, perms in access_control.items(): 

887 updated_access_control[role] = {new_perm_mapping.get(perm, perm) for perm in perms} 

888 

889 if access_control != updated_access_control: 

890 warnings.warn( 

891 "The 'can_dag_read' and 'can_dag_edit' permissions are deprecated. " 

892 "Please use 'can_read' and 'can_edit', respectively.", 

893 RemovedInAirflow3Warning, 

894 stacklevel=3, 

895 ) 

896 

897 return updated_access_control 

898 
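A sketch (not part of dag.py) of the permission upgrade performed by the helper above:

from airflow.models.dag import DAG

legacy = {"viewer": {"can_dag_read"}, "editor": {"can_dag_read", "can_dag_edit"}}
upgraded = DAG._upgrade_outdated_dag_access_control(legacy)
print(upgraded)  # {"viewer": {"can_read"}, "editor": {"can_read", "can_edit"}}
# A RemovedInAirflow3Warning is emitted because deprecated actions were present.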

899 def date_range( 

900 self, 

901 start_date: pendulum.DateTime, 

902 num: int | None = None, 

903 end_date: datetime | None = None, 

904 ) -> list[datetime]: 

905 message = "`DAG.date_range()` is deprecated." 

906 if num is not None: 

907 warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2) 

908 with warnings.catch_warnings(): 

909 warnings.simplefilter("ignore", RemovedInAirflow3Warning) 

910 return utils_date_range( 

911 start_date=start_date, num=num, delta=self.normalized_schedule_interval 

912 ) 

913 message += " Please use `DAG.iter_dagrun_infos_between(..., align=False)` instead." 

914 warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2) 

915 if end_date is None: 

916 coerced_end_date = timezone.utcnow() 

917 else: 

918 coerced_end_date = end_date 

919 it = self.iter_dagrun_infos_between(start_date, pendulum.instance(coerced_end_date), align=False) 

920 return [info.logical_date for info in it] 

921 

922 def is_fixed_time_schedule(self): 

923 """Figures out if the schedule has a fixed time (e.g. 3 AM every day). 

924 

925 Detection is done by "peeking" the next two cron trigger times; if the 

926 two times have the same minute and hour value, the schedule is fixed, 

927 and we *don't* need to perform the DST fix. 

928 

929 This assumes DST happens on whole minute changes (e.g. 12:59 -> 12:00). 

930 

931 Do not try to understand what this actually means. It is old logic that 

932 should not be used anywhere. 

933 """ 

934 warnings.warn( 

935 "`DAG.is_fixed_time_schedule()` is deprecated.", 

936 category=RemovedInAirflow3Warning, 

937 stacklevel=2, 

938 ) 

939 

940 from airflow.timetables._cron import CronMixin 

941 

942 if not isinstance(self.timetable, CronMixin): 

943 return True 

944 

945 from croniter import croniter 

946 

947 cron = croniter(self.timetable._expression) 

948 next_a = cron.get_next(datetime) 

949 next_b = cron.get_next(datetime) 

950 return next_b.minute == next_a.minute and next_b.hour == next_a.hour 

951 

952 def following_schedule(self, dttm): 

953 """ 

954 Calculate the following schedule for this dag in UTC. 

955 

956 :param dttm: utc datetime 

957 :return: utc datetime 

958 """ 

959 warnings.warn( 

960 "`DAG.following_schedule()` is deprecated. Use `DAG.next_dagrun_info(restricted=False)` instead.", 

961 category=RemovedInAirflow3Warning, 

962 stacklevel=2, 

963 ) 

964 data_interval = self.infer_automated_data_interval(timezone.coerce_datetime(dttm)) 

965 next_info = self.next_dagrun_info(data_interval, restricted=False) 

966 if next_info is None: 

967 return None 

968 return next_info.data_interval.start 

969 

970 def previous_schedule(self, dttm): 

971 from airflow.timetables.interval import _DataIntervalTimetable 

972 

973 warnings.warn( 

974 "`DAG.previous_schedule()` is deprecated.", 

975 category=RemovedInAirflow3Warning, 

976 stacklevel=2, 

977 ) 

978 if not isinstance(self.timetable, _DataIntervalTimetable): 

979 return None 

980 return self.timetable._get_prev(timezone.coerce_datetime(dttm)) 

981 

982 def get_next_data_interval(self, dag_model: DagModel) -> DataInterval | None: 

983 """Get the data interval of the next scheduled run. 

984 

985 For compatibility, this method infers the data interval from the DAG's 

986 schedule if the run does not have an explicit one set, which is possible 

987 for runs created prior to AIP-39. 

988 

989 This function is private to Airflow core and should not be depended on as a 

990 part of the Python API. 

991 

992 :meta private: 

993 """ 

994 if self.dag_id != dag_model.dag_id: 

995 raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {dag_model.dag_id}") 

996 if dag_model.next_dagrun is None: # Next run not scheduled. 

997 return None 

998 data_interval = dag_model.next_dagrun_data_interval 

999 if data_interval is not None: 

1000 return data_interval 

1001 

1002 # Compatibility: A run was scheduled without an explicit data interval. 

1003 # This means the run was scheduled before AIP-39 implementation. Try to 

1004 # infer from the logical date. 

1005 return self.infer_automated_data_interval(dag_model.next_dagrun) 

1006 

1007 def get_run_data_interval(self, run: DagRun | DagRunPydantic) -> DataInterval: 

1008 """Get the data interval of this run. 

1009 

1010 For compatibility, this method infers the data interval from the DAG's 

1011 schedule if the run does not have an explicit one set, which is possible for 

1012 runs created prior to AIP-39. 

1013 

1014 This function is private to Airflow core and should not be depended on as a 

1015 part of the Python API. 

1016 

1017 :meta private: 

1018 """ 

1019 if run.dag_id is not None and run.dag_id != self.dag_id: 

1020 raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {run.dag_id}") 

1021 data_interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end") 

1022 if data_interval is not None: 

1023 return data_interval 

1024 # Compatibility: runs created before AIP-39 implementation don't have an 

1025 # explicit data interval. Try to infer from the logical date. 

1026 return self.infer_automated_data_interval(run.execution_date) 

1027 

1028 def infer_automated_data_interval(self, logical_date: datetime) -> DataInterval: 

1029 """Infer a data interval for a run against this DAG. 

1030 

1031 This method is used to bridge runs created prior to AIP-39 

1032 implementation, which do not have an explicit data interval. Therefore, 

1033 this method only considers ``schedule_interval`` values valid prior to 

1034 Airflow 2.2. 

1035 

1036 DO NOT call this method if there is a known data interval. 

1037 

1038 :meta private: 

1039 """ 

1040 timetable_type = type(self.timetable) 

1041 if issubclass(timetable_type, (NullTimetable, OnceTimetable, DatasetTriggeredTimetable)): 

1042 return DataInterval.exact(timezone.coerce_datetime(logical_date)) 

1043 start = timezone.coerce_datetime(logical_date) 

1044 if issubclass(timetable_type, CronDataIntervalTimetable): 

1045 end = cast(CronDataIntervalTimetable, self.timetable)._get_next(start) 

1046 elif issubclass(timetable_type, DeltaDataIntervalTimetable): 

1047 end = cast(DeltaDataIntervalTimetable, self.timetable)._get_next(start) 

1048 # Contributors: When the exception below is raised, you might want to 

1049 # add an 'elif' block here to handle custom timetables. Stop! The bug 

1050 # you're looking for is instead at when the DAG run (represented by 

1051 # logical_date) was created. See GH-31969 for an example: 

1052 # * Wrong fix: GH-32074 (modifies this function). 

1053 # * Correct fix: GH-32118 (modifies the DAG run creation code). 

1054 else: 

1055 raise ValueError(f"Not a valid timetable: {self.timetable!r}") 

1056 return DataInterval(start, end) 

1057 
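A sketch (not part of dag.py) of the inference above for a daily cron DAG, assuming the default ``scheduler.create_cron_data_intervals`` setting so the cron string yields a CronDataIntervalTimetable:

import pendulum

from airflow.models.dag import DAG

daily_dag = DAG("example_daily", schedule="0 0 * * *", start_date=pendulum.datetime(2021, 1, 1))
interval = daily_dag.infer_automated_data_interval(pendulum.datetime(2021, 6, 3))
print(interval.start, interval.end)  # 2021-06-03T00:00:00+00:00 2021-06-04T00:00:00+00:00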

1058 def next_dagrun_info( 

1059 self, 

1060 last_automated_dagrun: None | datetime | DataInterval, 

1061 *, 

1062 restricted: bool = True, 

1063 ) -> DagRunInfo | None: 

1064 """Get information about the next DagRun of this dag after ``last_automated_dagrun``. 

1065 

1066 This calculates what time interval the next DagRun should operate on 

1067 (its execution date) and when it can be scheduled, according to the 

1068 dag's timetable, start_date, end_date, etc. This doesn't check max 

1069 active run or any other "max_active_tasks" type limits, but only 

1070 performs calculations based on the various date and interval fields of 

1071 this dag and its tasks. 

1072 

1073 :param last_automated_dagrun: The ``max(execution_date)`` of 

1074 existing "automated" DagRuns for this dag (scheduled or backfill, 

1075 but not manual). 

1076 :param restricted: If set to *False* (default is *True*), ignore 

1077 ``start_date``, ``end_date``, and ``catchup`` specified on the DAG 

1078 or tasks. 

1079 :return: DagRunInfo of the next dagrun, or None if a dagrun is not 

1080 going to be scheduled. 

1081 """ 

1082 # Never schedule a subdag. It will be scheduled by its parent dag. 

1083 if self.is_subdag: 

1084 return None 

1085 

1086 data_interval = None 

1087 if isinstance(last_automated_dagrun, datetime): 

1088 warnings.warn( 

1089 "Passing a datetime to DAG.next_dagrun_info is deprecated. Use a DataInterval instead.", 

1090 RemovedInAirflow3Warning, 

1091 stacklevel=2, 

1092 ) 

1093 data_interval = self.infer_automated_data_interval( 

1094 timezone.coerce_datetime(last_automated_dagrun) 

1095 ) 

1096 else: 

1097 data_interval = last_automated_dagrun 

1098 if restricted: 

1099 restriction = self._time_restriction 

1100 else: 

1101 restriction = TimeRestriction(earliest=None, latest=None, catchup=True) 

1102 try: 

1103 info = self.timetable.next_dagrun_info( 

1104 last_automated_data_interval=data_interval, 

1105 restriction=restriction, 

1106 ) 

1107 except Exception: 

1108 self.log.exception( 

1109 "Failed to fetch run info after data interval %s for DAG %r", 

1110 data_interval, 

1111 self.dag_id, 

1112 ) 

1113 info = None 

1114 return info 

1115 
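A sketch (not part of dag.py) of asking for the first run of a freshly defined DAG (no previous automated run), reusing the hypothetical ``daily_dag`` from the earlier sketch; exact values depend on start date, catchup, and timetable:

info = daily_dag.next_dagrun_info(None)
if info is not None:
    print(info.logical_date)   # 2021-01-01T00:00:00+00:00
    print(info.data_interval)  # DataInterval(start=2021-01-01..., end=2021-01-02...)
    print(info.run_after)      # 2021-01-02T00:00:00+00:00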

1116 def next_dagrun_after_date(self, date_last_automated_dagrun: pendulum.DateTime | None): 

1117 warnings.warn( 

1118 "`DAG.next_dagrun_after_date()` is deprecated. Please use `DAG.next_dagrun_info()` instead.", 

1119 category=RemovedInAirflow3Warning, 

1120 stacklevel=2, 

1121 ) 

1122 if date_last_automated_dagrun is None: 

1123 data_interval = None 

1124 else: 

1125 data_interval = self.infer_automated_data_interval(date_last_automated_dagrun) 

1126 info = self.next_dagrun_info(data_interval) 

1127 if info is None: 

1128 return None 

1129 return info.run_after 

1130 

1131 @functools.cached_property 

1132 def _time_restriction(self) -> TimeRestriction: 

1133 start_dates = [t.start_date for t in self.tasks if t.start_date] 

1134 if self.start_date is not None: 

1135 start_dates.append(self.start_date) 

1136 earliest = None 

1137 if start_dates: 

1138 earliest = timezone.coerce_datetime(min(start_dates)) 

1139 latest = self.end_date 

1140 end_dates = [t.end_date for t in self.tasks if t.end_date] 

1141 if len(end_dates) == len(self.tasks): # no task has a null end_date 

1142 if self.end_date is not None: 

1143 end_dates.append(self.end_date) 

1144 if end_dates: 

1145 latest = timezone.coerce_datetime(max(end_dates)) 

1146 return TimeRestriction(earliest, latest, self.catchup) 

1147 

1148 def iter_dagrun_infos_between( 

1149 self, 

1150 earliest: pendulum.DateTime | None, 

1151 latest: pendulum.DateTime, 

1152 *, 

1153 align: bool = True, 

1154 ) -> Iterable[DagRunInfo]: 

1155 """Yield DagRunInfo using this DAG's timetable between given interval. 

1156 

1157 DagRunInfo instances are yielded if their ``logical_date`` is not earlier 

1158 than ``earliest``, nor later than ``latest``. The instances are ordered 

1159 by their ``logical_date`` from earliest to latest. 

1160 

1161 If ``align`` is ``False``, the first run will happen immediately on 

1162 ``earliest``, even if it does not fall on the logical timetable schedule. 

1163 The default is ``True``, but subdags will ignore this value and always 

1164 behave as if this is set to ``False`` for backward compatibility. 

1165 

1166 Example: A DAG is scheduled to run every midnight (``0 0 * * *``). If 

1167 ``earliest`` is ``2021-06-03 23:00:00``, the first DagRunInfo would be 

1168 ``2021-06-03 23:00:00`` if ``align=False``, and ``2021-06-04 00:00:00`` 

1169 if ``align=True``. 

1170 """ 

1171 if earliest is None: 

1172 earliest = self._time_restriction.earliest 

1173 if earliest is None: 

1174 raise ValueError("earliest was None and we had no value in time_restriction to fallback on") 

1175 earliest = timezone.coerce_datetime(earliest) 

1176 latest = timezone.coerce_datetime(latest) 

1177 

1178 restriction = TimeRestriction(earliest, latest, catchup=True) 

1179 

1180 # HACK: Sub-DAGs are currently scheduled differently. For example, say 

1181 # the schedule is @daily and start is 2021-06-03 22:16:00, a top-level 

1182 # DAG should be first scheduled to run on midnight 2021-06-04, but a 

1183 # sub-DAG should be first scheduled to run RIGHT NOW. We can change 

1184 # this, but since sub-DAGs are going away in 3.0 anyway, let's keep 

1185 # compatibility for now and remove this entirely later. 

1186 if self.is_subdag: 

1187 align = False 

1188 

1189 try: 

1190 info = self.timetable.next_dagrun_info( 

1191 last_automated_data_interval=None, 

1192 restriction=restriction, 

1193 ) 

1194 except Exception: 

1195 self.log.exception( 

1196 "Failed to fetch run info after data interval %s for DAG %r", 

1197 None, 

1198 self.dag_id, 

1199 ) 

1200 info = None 

1201 

1202 if info is None: 

1203 # No runs to be scheduled between the user-supplied timeframe. But 

1204 # if align=False, "invent" a data interval for the timeframe itself. 

1205 if not align: 

1206 yield DagRunInfo.interval(earliest, latest) 

1207 return 

1208 

1209 # If align=False and earliest does not fall on the timetable's logical 

1210 # schedule, "invent" a data interval for it. 

1211 if not align and info.logical_date != earliest: 

1212 yield DagRunInfo.interval(earliest, info.data_interval.start) 

1213 

1214 # Generate naturally according to schedule. 

1215 while info is not None: 

1216 yield info 

1217 try: 

1218 info = self.timetable.next_dagrun_info( 

1219 last_automated_data_interval=info.data_interval, 

1220 restriction=restriction, 

1221 ) 

1222 except Exception: 

1223 self.log.exception( 

1224 "Failed to fetch run info after data interval %s for DAG %r", 

1225 info.data_interval if info else "<NONE>", 

1226 self.dag_id, 

1227 ) 

1228 break 

1229 
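A sketch (not part of dag.py) of the ``align`` behaviour described in the docstring, again using the hypothetical midnight-scheduled ``daily_dag``:

import pendulum

runs = daily_dag.iter_dagrun_infos_between(
    pendulum.datetime(2021, 6, 3, 23), pendulum.datetime(2021, 6, 10), align=True
)
first = next(iter(runs))
print(first.logical_date)  # 2021-06-04T00:00; with align=False it would be 2021-06-03T23:00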

1230 def get_run_dates(self, start_date, end_date=None) -> list: 

1231 """ 

1232 Return a list of dates within the interval received as parameters, using this dag's schedule interval. 

1233 

1234 Returned dates can be used for execution dates. 

1235 

1236 :param start_date: The start date of the interval. 

1237 :param end_date: The end date of the interval. Defaults to ``timezone.utcnow()``. 

1238 :return: A list of dates within the interval following the dag's schedule. 

1239 """ 

1240 warnings.warn( 

1241 "`DAG.get_run_dates()` is deprecated. Please use `DAG.iter_dagrun_infos_between()` instead.", 

1242 category=RemovedInAirflow3Warning, 

1243 stacklevel=2, 

1244 ) 

1245 earliest = timezone.coerce_datetime(start_date) 

1246 if end_date is None: 

1247 latest = pendulum.now(timezone.utc) 

1248 else: 

1249 latest = timezone.coerce_datetime(end_date) 

1250 return [info.logical_date for info in self.iter_dagrun_infos_between(earliest, latest)] 

1251 

1252 def normalize_schedule(self, dttm): 

1253 warnings.warn( 

1254 "`DAG.normalize_schedule()` is deprecated.", 

1255 category=RemovedInAirflow3Warning, 

1256 stacklevel=2, 

1257 ) 

1258 with warnings.catch_warnings(): 

1259 warnings.simplefilter("ignore", RemovedInAirflow3Warning) 

1260 following = self.following_schedule(dttm) 

1261 if not following: # in case of @once 

1262 return dttm 

1263 with warnings.catch_warnings(): 

1264 warnings.simplefilter("ignore", RemovedInAirflow3Warning) 

1265 previous_of_following = self.previous_schedule(following) 

1266 if previous_of_following != dttm: 

1267 return following 

1268 return dttm 

1269 

1270 @provide_session 

1271 def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False): 

1272 return get_last_dagrun( 

1273 self.dag_id, session=session, include_externally_triggered=include_externally_triggered 

1274 ) 

1275 

1276 @provide_session 

1277 def has_dag_runs(self, session=NEW_SESSION, include_externally_triggered=True) -> bool: 

1278 return ( 

1279 get_last_dagrun( 

1280 self.dag_id, session=session, include_externally_triggered=include_externally_triggered 

1281 ) 

1282 is not None 

1283 ) 

1284 

1285 @property 

1286 def dag_id(self) -> str: 

1287 return self._dag_id 

1288 

1289 @dag_id.setter 

1290 def dag_id(self, value: str) -> None: 

1291 self._dag_id = value 

1292 

1293 @property 

1294 def is_subdag(self) -> bool: 

1295 return self.parent_dag is not None 

1296 

1297 @property 

1298 def full_filepath(self) -> str: 

1299 """Full file path to the DAG. 

1300 

1301 :meta private: 

1302 """ 

1303 warnings.warn( 

1304 "DAG.full_filepath is deprecated in favour of fileloc", 

1305 RemovedInAirflow3Warning, 

1306 stacklevel=2, 

1307 ) 

1308 return self.fileloc 

1309 

1310 @full_filepath.setter 

1311 def full_filepath(self, value) -> None: 

1312 warnings.warn( 

1313 "DAG.full_filepath is deprecated in favour of fileloc", 

1314 RemovedInAirflow3Warning, 

1315 stacklevel=2, 

1316 ) 

1317 self.fileloc = value 

1318 

1319 @property 

1320 def concurrency(self) -> int: 

1321 # TODO: Remove in Airflow 3.0 

1322 warnings.warn( 

1323 "The 'DAG.concurrency' attribute is deprecated. Please use 'DAG.max_active_tasks'.", 

1324 RemovedInAirflow3Warning, 

1325 stacklevel=2, 

1326 ) 

1327 return self._max_active_tasks 

1328 

1329 @concurrency.setter 

1330 def concurrency(self, value: int): 

1331 self._max_active_tasks = value 

1332 

1333 @property 

1334 def max_active_tasks(self) -> int: 

1335 return self._max_active_tasks 

1336 

1337 @max_active_tasks.setter 

1338 def max_active_tasks(self, value: int): 

1339 self._max_active_tasks = value 

1340 

1341 @property 

1342 def access_control(self): 

1343 return self._access_control 

1344 

1345 @access_control.setter 

1346 def access_control(self, value): 

1347 self._access_control = DAG._upgrade_outdated_dag_access_control(value) 

1348 

1349 @property 

1350 def dag_display_name(self) -> str: 

1351 return self._dag_display_property_value or self._dag_id 

1352 

1353 @property 

1354 def description(self) -> str | None: 

1355 return self._description 

1356 

1357 @property 

1358 def default_view(self) -> str: 

1359 return self._default_view 

1360 

1361 @property 

1362 def pickle_id(self) -> int | None: 

1363 return self._pickle_id 

1364 

1365 @pickle_id.setter 

1366 def pickle_id(self, value: int) -> None: 

1367 self._pickle_id = value 

1368 

1369 def param(self, name: str, default: Any = NOTSET) -> DagParam: 

1370 """ 

1371 Return a DagParam object for current dag. 

1372 

1373 :param name: dag parameter name. 

1374 :param default: fallback value for dag parameter. 

1375 :return: DagParam instance for specified name and current dag. 

1376 """ 

1377 return DagParam(current_dag=self, name=name, default=default) 

1378 
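A sketch (not part of dag.py) of wiring a DAG-level param into a task; the dag_id, param name, and callable are hypothetical:

import pendulum

from airflow.models.dag import DAG
from airflow.operators.python import PythonOperator

with DAG(
    "example_params",
    start_date=pendulum.datetime(2024, 1, 1),
    schedule=None,
    params={"env": "staging"},
) as pdag:
    env = pdag.param("env")  # DagParam placeholder, resolved from params / dag_run.conf at runtime
    PythonOperator(task_id="report", python_callable=print, op_args=[env])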

1379 @property 

1380 def tasks(self) -> list[Operator]: 

1381 return list(self.task_dict.values()) 

1382 

1383 @tasks.setter 

1384 def tasks(self, val): 

1385 raise AttributeError("DAG.tasks can not be modified. Use dag.add_task() instead.") 

1386 

1387 @property 

1388 def task_ids(self) -> list[str]: 

1389 return list(self.task_dict) 

1390 

1391 @property 

1392 def teardowns(self) -> list[Operator]: 

1393 return [task for task in self.tasks if getattr(task, "is_teardown", None)] 

1394 

1395 @property 

1396 def tasks_upstream_of_teardowns(self) -> list[Operator]: 

1397 upstream_tasks = [t.upstream_list for t in self.teardowns] 

1398 return [val for sublist in upstream_tasks for val in sublist if not getattr(val, "is_teardown", None)] 

1399 

1400 @property 

1401 def task_group(self) -> TaskGroup: 

1402 return self._task_group 

1403 

1404 @property 

1405 def filepath(self) -> str: 

1406 """Relative file path to the DAG. 

1407 

1408 :meta private: 

1409 """ 

1410 warnings.warn( 

1411 "filepath is deprecated, use relative_fileloc instead", 

1412 RemovedInAirflow3Warning, 

1413 stacklevel=2, 

1414 ) 

1415 return str(self.relative_fileloc) 

1416 

1417 @property 

1418 def relative_fileloc(self) -> pathlib.Path: 

1419 """File location of the importable dag 'file' relative to the configured DAGs folder.""" 

1420 path = pathlib.Path(self.fileloc) 

1421 try: 

1422 rel_path = path.relative_to(self._processor_dags_folder or settings.DAGS_FOLDER) 

1423 if rel_path == pathlib.Path("."): 

1424 return path 

1425 else: 

1426 return rel_path 

1427 except ValueError: 

1428 # Not relative to DAGS_FOLDER. 

1429 return path 

1430 

1431 @property 

1432 def folder(self) -> str: 

1433 """Folder location of where the DAG object is instantiated.""" 

1434 return os.path.dirname(self.fileloc) 

1435 

1436 @property 

1437 def owner(self) -> str: 

1438 """ 

1439 Return a comma-separated string of all owners found in DAG tasks. 

1440 

1441 :return: Comma separated list of owners in DAG tasks 

1442 """ 

1443 return ", ".join({t.owner for t in self.tasks}) 

1444 

1445 @property 

1446 def allow_future_exec_dates(self) -> bool: 

1447 return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_be_scheduled 

1448 

1449 @provide_session 

1450 def get_concurrency_reached(self, session=NEW_SESSION) -> bool: 

1451 """Return a boolean indicating whether the max_active_tasks limit for this DAG has been reached.""" 

1452 TI = TaskInstance 

1453 total_tasks = session.scalar( 

1454 select(func.count(TI.task_id)).where( 

1455 TI.dag_id == self.dag_id, 

1456 TI.state == TaskInstanceState.RUNNING, 

1457 ) 

1458 ) 

1459 return total_tasks >= self.max_active_tasks 

1460 

1461 @property 

1462 def concurrency_reached(self): 

1463 """Use `airflow.models.DAG.get_concurrency_reached`, this attribute is deprecated.""" 

1464 warnings.warn( 

1465 "This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method.", 

1466 RemovedInAirflow3Warning, 

1467 stacklevel=2, 

1468 ) 

1469 return self.get_concurrency_reached() 

1470 

1471 @provide_session 

1472 def get_is_active(self, session=NEW_SESSION) -> bool | None: 

1473 """Return a boolean indicating whether this DAG is active.""" 

1474 return session.scalar(select(DagModel.is_active).where(DagModel.dag_id == self.dag_id)) 

1475 

1476 @provide_session 

1477 def get_is_paused(self, session=NEW_SESSION) -> bool | None: 

1478 """Return a boolean indicating whether this DAG is paused.""" 

1479 return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == self.dag_id)) 

1480 

1481 @property 

1482 def is_paused(self): 

1483 """Use `airflow.models.DAG.get_is_paused`, this attribute is deprecated.""" 

1484 warnings.warn( 

1485 "This attribute is deprecated. Please use `airflow.models.DAG.get_is_paused` method.", 

1486 RemovedInAirflow3Warning, 

1487 stacklevel=2, 

1488 ) 

1489 return self.get_is_paused() 

1490 

1491 @property 

1492 def normalized_schedule_interval(self) -> ScheduleInterval: 

1493 warnings.warn( 

1494 "DAG.normalized_schedule_interval() is deprecated.", 

1495 category=RemovedInAirflow3Warning, 

1496 stacklevel=2, 

1497 ) 

1498 if isinstance(self.schedule_interval, str) and self.schedule_interval in cron_presets: 

1499 _schedule_interval: ScheduleInterval = cron_presets.get(self.schedule_interval) 

1500 elif self.schedule_interval == "@once": 

1501 _schedule_interval = None 

1502 else: 

1503 _schedule_interval = self.schedule_interval 

1504 return _schedule_interval 

1505 

1506 @staticmethod 

1507 @internal_api_call 

1508 @provide_session 

1509 def fetch_callback( 

1510 dag: DAG, 

1511 dag_run_id: str, 

1512 success: bool = True, 

1513 reason: str | None = None, 

1514 *, 

1515 session: Session = NEW_SESSION, 

1516 ) -> tuple[list[TaskStateChangeCallback], Context] | None: 

1517 """ 

1518 Fetch the appropriate callbacks depending on the value of success. 

1519 

1520 This method gets the context of a single TaskInstance part of this DagRun and returns it along 

1521 the list of callbacks. 

1522 

1523 :param dag: DAG object 

1524 :param dag_run_id: The DAG run ID 

1525 :param success: Flag to specify if failure or success callback should be called 

1526 :param reason: Completion reason 

1527 :param session: Database session 

1528 """ 

1529 callbacks = dag.on_success_callback if success else dag.on_failure_callback 

1530 if callbacks: 

1531 dagrun = DAG.fetch_dagrun(dag_id=dag.dag_id, run_id=dag_run_id, session=session) 

1532 callbacks = callbacks if isinstance(callbacks, list) else [callbacks] 

1533 tis = dagrun.get_task_instances(session=session) 

1534 # tis from a dagrun may not be a part of dag.partial_subset, 

1535 # since dag.partial_subset is a subset of the dag. 

1536 # This ensures that we will only use the accessible TI 

1537 # context for the callback. 

1538 if dag.partial: 

1539 tis = [ti for ti in tis if ti.state != State.NONE] 

1540 # filter out removed tasks 

1541 tis = [ti for ti in tis if ti.state != TaskInstanceState.REMOVED] 

1542 ti = tis[-1] # use the last TaskInstance of the DagRun for the callback context 

1543 ti.task = dag.get_task(ti.task_id) 

1544 context = ti.get_template_context(session=session) 

1545 context["reason"] = reason 

1546 return callbacks, context 

1547 return None 

1548 

1549 @provide_session 

1550 def handle_callback(self, dagrun: DagRun, success=True, reason=None, session=NEW_SESSION): 

1551 """ 

1552 Triggers on_failure_callback or on_success_callback as appropriate. 

1553 

1554 This method gets the context of a single TaskInstance part of this DagRun 

1555 and passes that to the callable along with a 'reason', primarily to 

1556 differentiate DagRun failures. 

1557 

1558 .. note:: The logs end up in 

1559 ``$AIRFLOW_HOME/logs/scheduler/latest/PROJECT/DAG_FILE.py.log`` 

1560 

1561 :param dagrun: DagRun object 

1562 :param success: Flag to specify if failure or success callback should be called 

1563 :param reason: Completion reason 

1564 :param session: Database session 

1565 """ 

1566 callbacks, context = DAG.fetch_callback( 

1567 dag=self, dag_run_id=dagrun.run_id, success=success, reason=reason, session=session 

1568 ) or (None, None) 

1569 

1570 DAG.execute_callback(callbacks, context, self.dag_id) 

1571 

1572 @classmethod 

1573 def execute_callback(cls, callbacks: list[Callable] | None, context: Context | None, dag_id: str): 

1574 """ 

1575 Triggers the callbacks with the given context. 

1576 

1577 :param callbacks: List of callbacks to call 

1578 :param context: Context to pass to all callbacks 

1579 :param dag_id: The dag_id of the DAG to find. 

1580 """ 

1581 if callbacks and context: 

1582 for callback in callbacks: 

1583 cls.logger().info("Executing dag callback function: %s", callback) 

1584 try: 

1585 callback(context) 

1586 except Exception: 

1587 cls.logger().exception("failed to invoke dag state update callback") 

1588 Stats.incr("dag.callback_exceptions", tags={"dag_id": dag_id}) 

1589 
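# Illustrative sketch (dag_id and task are assumptions): a DAG-level failure
# callback. handle_callback()/execute_callback() above invoke it with a task
# instance's template context, to which fetch_callback() adds a "reason" key.
import pendulum

from airflow.models.dag import DAG
from airflow.operators.bash import BashOperator


def _notify_failure(context):
    # A plain template context plus the "reason" entry set by fetch_callback().
    print(f"DAG run failed: reason={context.get('reason')}, run_id={context.get('run_id')}")


with DAG(
    dag_id="callback_example",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
    on_failure_callback=_notify_failure,
) as callback_dag:
    BashOperator(task_id="always_fails", bash_command="exit 1")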

1590 def get_active_runs(self): 

1591 """ 

1592 Return a list of execution dates for the currently running dag runs. 

1593 

1594 :return: List of execution dates 

1595 """ 

1596 runs = DagRun.find(dag_id=self.dag_id, state=DagRunState.RUNNING) 

1597 

1598 active_dates = [] 

1599 for run in runs: 

1600 active_dates.append(run.execution_date) 

1601 

1602 return active_dates 

1603 

1604 @provide_session 

1605 def get_num_active_runs(self, external_trigger=None, only_running=True, session=NEW_SESSION): 

1606 """ 

1607 Return the number of active "running" dag runs. 

1608 

1609 :param external_trigger: True for externally triggered active dag runs 

1610 :param session: Database session 

1611 :return: the number of active dag runs 

1612 """ 

1613 query = select(func.count()).where(DagRun.dag_id == self.dag_id) 

1614 if only_running: 

1615 query = query.where(DagRun.state == DagRunState.RUNNING) 

1616 else: 

1617 query = query.where(DagRun.state.in_({DagRunState.RUNNING, DagRunState.QUEUED})) 

1618 

1619 if external_trigger is not None: 

1620 query = query.where( 

1621 DagRun.external_trigger == (expression.true() if external_trigger else expression.false()) 

1622 ) 

1623 

1624 return session.scalar(query) 

1625 

1626 @staticmethod 

1627 @internal_api_call 

1628 @provide_session 

1629 def fetch_dagrun( 

1630 dag_id: str, 

1631 execution_date: datetime | None = None, 

1632 run_id: str | None = None, 

1633 session: Session = NEW_SESSION, 

1634 ) -> DagRun | DagRunPydantic: 

1635 """ 

1636 Return the dag run for a given execution date or run_id if it exists, otherwise none. 

1637 

1638 :param dag_id: The dag_id of the DAG to find. 

1639 :param execution_date: The execution date of the DagRun to find. 

1640 :param run_id: The run_id of the DagRun to find. 

1641 :param session: Database session 

1642 :return: The DagRun if found, otherwise None. 

1643 """ 

1644 if not (execution_date or run_id): 

1645 raise TypeError("You must provide either the execution_date or the run_id") 

1646 query = select(DagRun) 

1647 if execution_date: 

1648 query = query.where(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date) 

1649 if run_id: 

1650 query = query.where(DagRun.dag_id == dag_id, DagRun.run_id == run_id) 

1651 return session.scalar(query) 

1652 

1653 @provide_session 

1654 def get_dagrun( 

1655 self, 

1656 execution_date: datetime | None = None, 

1657 run_id: str | None = None, 

1658 session: Session = NEW_SESSION, 

1659 ) -> DagRun | DagRunPydantic: 

1660 return DAG.fetch_dagrun( 

1661 dag_id=self.dag_id, execution_date=execution_date, run_id=run_id, session=session 

1662 ) 

1663 
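# Illustrative sketch: looking up a single DagRun with get_dagrun(). This
# assumes an initialized Airflow metadata database and that a DAG with
# dag_id "example_dag" has a run with the given run_id; both are assumptions.
from airflow.models.dagbag import DagBag

example_dag = DagBag(read_dags_from_db=True).get_dag("example_dag")
if example_dag is not None:
    run = example_dag.get_dagrun(run_id="manual__2024-01-01T00:00:00+00:00")
    if run is not None:
        print(run.run_id, run.state, run.execution_date)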

1664 @provide_session 

1665 def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION): 

1666 """ 

1667 Return the list of dag runs between start_date (inclusive) and end_date (inclusive). 

1668 

1669 :param start_date: The starting execution date of the DagRun to find. 

1670 :param end_date: The ending execution date of the DagRun to find. 

1671 :param session: Database session 

1672 :return: The list of DagRuns found. 

1673 """ 

1674 dagruns = session.scalars( 

1675 select(DagRun).where( 

1676 DagRun.dag_id == self.dag_id, 

1677 DagRun.execution_date >= start_date, 

1678 DagRun.execution_date <= end_date, 

1679 ) 

1680 ).all() 

1681 

1682 return dagruns 

1683 

1684 @provide_session 

1685 def get_latest_execution_date(self, session: Session = NEW_SESSION) -> pendulum.DateTime | None: 

1686 """Return the latest date for which at least one dag run exists.""" 

1687 return session.scalar(select(func.max(DagRun.execution_date)).where(DagRun.dag_id == self.dag_id)) 

1688 

1689 @property 

1690 def latest_execution_date(self): 

1691 """Use `airflow.models.DAG.get_latest_execution_date`, this attribute is deprecated.""" 

1692 warnings.warn( 

1693 "This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.", 

1694 RemovedInAirflow3Warning, 

1695 stacklevel=2, 

1696 ) 

1697 return self.get_latest_execution_date() 

1698 

1699 @property 

1700 def subdags(self): 

1701 """Return a list of the subdag objects associated to this DAG.""" 

1702 # Check SubDag for class but don't check class directly 

1703 from airflow.operators.subdag import SubDagOperator 

1704 

1705 subdag_lst = [] 

1706 for task in self.tasks: 

1707 if ( 

1708 isinstance(task, SubDagOperator) 

1709 or 

1710 # TODO remove in Airflow 2.0 

1711 type(task).__name__ == "SubDagOperator" 

1712 or task.task_type == "SubDagOperator" 

1713 ): 

1714 subdag_lst.append(task.subdag) 

1715 subdag_lst += task.subdag.subdags 

1716 return subdag_lst 

1717 

1718 def resolve_template_files(self): 

1719 for t in self.tasks: 

1720 t.resolve_template_files() 

1721 

1722 def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment: 

1723 """Build a Jinja2 environment.""" 

1724 # Collect directories to search for template files 

1725 searchpath = [self.folder] 

1726 if self.template_searchpath: 

1727 searchpath += self.template_searchpath 

1728 

1729 # Default values (for backward compatibility) 

1730 jinja_env_options = { 

1731 "loader": jinja2.FileSystemLoader(searchpath), 

1732 "undefined": self.template_undefined, 

1733 "extensions": ["jinja2.ext.do"], 

1734 "cache_size": 0, 

1735 } 

1736 if self.jinja_environment_kwargs: 

1737 jinja_env_options.update(self.jinja_environment_kwargs) 

1738 env: jinja2.Environment 

1739 if self.render_template_as_native_obj and not force_sandboxed: 

1740 env = airflow.templates.NativeEnvironment(**jinja_env_options) 

1741 else: 

1742 env = airflow.templates.SandboxedEnvironment(**jinja_env_options) 

1743 

1744 # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. 

1745 # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals 

1746 if self.user_defined_macros: 

1747 env.globals.update(self.user_defined_macros) 

1748 if self.user_defined_filters: 

1749 env.filters.update(self.user_defined_filters) 

1750 

1751 return env 

1752 
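# Illustrative sketch: user_defined_macros and user_defined_filters set on the
# DAG below (names are assumptions) become globals and filters of the Jinja
# environment built by get_template_env().
import pendulum

from airflow.models.dag import DAG

env_dag = DAG(
    dag_id="template_env_example",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
    user_defined_macros={"project": "demo"},
    user_defined_filters={"shout": lambda s: s.upper()},
)
env = env_dag.get_template_env()
print(env.from_string("{{ project | shout }}").render())  # prints "DEMO"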

1753 def set_dependency(self, upstream_task_id, downstream_task_id): 

1754 """Set dependency between two tasks that already have been added to the DAG using add_task().""" 

1755 self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id)) 

1756 
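# Illustrative sketch: add_task() plus set_dependency() as an alternative to
# the >> operator. The dag_id and task_ids are assumptions for this example.
import pendulum

from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator

dep_dag = DAG(
    dag_id="set_dependency_example",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
)
dep_dag.add_task(EmptyOperator(task_id="extract"))
dep_dag.add_task(EmptyOperator(task_id="load"))
dep_dag.set_dependency("extract", "load")  # "extract" must finish before "load"
print(dep_dag.get_task("load").upstream_task_ids)  # prints {'extract'}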

1757 @provide_session 

1758 def get_task_instances_before( 

1759 self, 

1760 base_date: datetime, 

1761 num: int, 

1762 *, 

1763 session: Session = NEW_SESSION, 

1764 ) -> list[TaskInstance]: 

1765 """Get ``num`` task instances before (including) ``base_date``. 

1766 

1767 The returned list may contain exactly ``num`` task instances 

1768 corresponding to any DagRunType. It can have fewer if there are 

1769 fewer than ``num`` scheduled DAG runs before ``base_date``. 

1770 """ 

1771 execution_dates: list[Any] = session.execute( 

1772 select(DagRun.execution_date) 

1773 .where( 

1774 DagRun.dag_id == self.dag_id, 

1775 DagRun.execution_date <= base_date, 

1776 ) 

1777 .order_by(DagRun.execution_date.desc()) 

1778 .limit(num) 

1779 ).all() 

1780 

1781 if not execution_dates: 

1782 return self.get_task_instances(start_date=base_date, end_date=base_date, session=session) 

1783 

1784 min_date: datetime | None = execution_dates[-1]._mapping.get( 

1785 "execution_date" 

1786 ) # getting the last value from the list 

1787 

1788 return self.get_task_instances(start_date=min_date, end_date=base_date, session=session) 

1789 

1790 @provide_session 

1791 def get_task_instances( 

1792 self, 

1793 start_date: datetime | None = None, 

1794 end_date: datetime | None = None, 

1795 state: list[TaskInstanceState] | None = None, 

1796 session: Session = NEW_SESSION, 

1797 ) -> list[TaskInstance]: 

1798 if not start_date: 

1799 start_date = (timezone.utcnow() - timedelta(30)).replace( 

1800 hour=0, minute=0, second=0, microsecond=0 

1801 ) 

1802 

1803 query = self._get_task_instances( 

1804 task_ids=None, 

1805 start_date=start_date, 

1806 end_date=end_date, 

1807 run_id=None, 

1808 state=state or (), 

1809 include_subdags=False, 

1810 include_parentdag=False, 

1811 include_dependent_dags=False, 

1812 exclude_task_ids=(), 

1813 session=session, 

1814 ) 

1815 return session.scalars(cast(Select, query).order_by(DagRun.execution_date)).all() 

1816 

1817 @overload 

1818 def _get_task_instances( 

1819 self, 

1820 *, 

1821 task_ids: Collection[str | tuple[str, int]] | None, 

1822 start_date: datetime | None, 

1823 end_date: datetime | None, 

1824 run_id: str | None, 

1825 state: TaskInstanceState | Sequence[TaskInstanceState], 

1826 include_subdags: bool, 

1827 include_parentdag: bool, 

1828 include_dependent_dags: bool, 

1829 exclude_task_ids: Collection[str | tuple[str, int]] | None, 

1830 session: Session, 

1831 dag_bag: DagBag | None = ..., 

1832 ) -> Iterable[TaskInstance]: ... # pragma: no cover 

1833 

1834 @overload 

1835 def _get_task_instances( 

1836 self, 

1837 *, 

1838 task_ids: Collection[str | tuple[str, int]] | None, 

1839 as_pk_tuple: Literal[True], 

1840 start_date: datetime | None, 

1841 end_date: datetime | None, 

1842 run_id: str | None, 

1843 state: TaskInstanceState | Sequence[TaskInstanceState], 

1844 include_subdags: bool, 

1845 include_parentdag: bool, 

1846 include_dependent_dags: bool, 

1847 exclude_task_ids: Collection[str | tuple[str, int]] | None, 

1848 session: Session, 

1849 dag_bag: DagBag | None = ..., 

1850 recursion_depth: int = ..., 

1851 max_recursion_depth: int = ..., 

1852 visited_external_tis: set[TaskInstanceKey] = ..., 

1853 ) -> set[TaskInstanceKey]: ... # pragma: no cover 

1854 

1855 def _get_task_instances( 

1856 self, 

1857 *, 

1858 task_ids: Collection[str | tuple[str, int]] | None, 

1859 as_pk_tuple: Literal[True, None] = None, 

1860 start_date: datetime | None, 

1861 end_date: datetime | None, 

1862 run_id: str | None, 

1863 state: TaskInstanceState | Sequence[TaskInstanceState], 

1864 include_subdags: bool, 

1865 include_parentdag: bool, 

1866 include_dependent_dags: bool, 

1867 exclude_task_ids: Collection[str | tuple[str, int]] | None, 

1868 session: Session, 

1869 dag_bag: DagBag | None = None, 

1870 recursion_depth: int = 0, 

1871 max_recursion_depth: int | None = None, 

1872 visited_external_tis: set[TaskInstanceKey] | None = None, 

1873 ) -> Iterable[TaskInstance] | set[TaskInstanceKey]: 

1874 TI = TaskInstance 

1875 

1876 # If we are looking at subdags/dependent dags we want to avoid UNION calls 

1877 # in SQL (it doesn't play nice with fields that have no equality operator, 

1878 # like JSON types), we instead build our result set separately. 

1879 # 

1880 # This will be empty if we are only looking at one dag, in which case 

1881 # we can return the filtered TI query object directly. 

1882 result: set[TaskInstanceKey] = set() 

1883 

1884 # Do we want full objects, or just the primary columns? 

1885 if as_pk_tuple: 

1886 tis = select(TI.dag_id, TI.task_id, TI.run_id, TI.map_index) 

1887 else: 

1888 tis = select(TaskInstance) 

1889 tis = tis.join(TaskInstance.dag_run) 

1890 

1891 if include_subdags: 

1892 # Crafting the right filter for dag_id and task_ids combo 

1893 conditions = [] 

1894 for dag in [*self.subdags, self]: 

1895 conditions.append( 

1896 (TaskInstance.dag_id == dag.dag_id) & TaskInstance.task_id.in_(dag.task_ids) 

1897 ) 

1898 tis = tis.where(or_(*conditions)) 

1899 elif self.partial: 

1900 tis = tis.where(TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids)) 

1901 else: 

1902 tis = tis.where(TaskInstance.dag_id == self.dag_id) 

1903 if run_id: 

1904 tis = tis.where(TaskInstance.run_id == run_id) 

1905 if start_date: 

1906 tis = tis.where(DagRun.execution_date >= start_date) 

1907 if task_ids is not None: 

1908 tis = tis.where(TaskInstance.ti_selector_condition(task_ids)) 

1909 

1910 # This allows the allow_trigger_in_future config to take effect, rather than mandating exec_date <= UTC 

1911 if end_date or not self.allow_future_exec_dates: 

1912 end_date = end_date or timezone.utcnow() 

1913 tis = tis.where(DagRun.execution_date <= end_date) 

1914 

1915 if state: 

1916 if isinstance(state, (str, TaskInstanceState)): 

1917 tis = tis.where(TaskInstance.state == state) 

1918 elif len(state) == 1: 

1919 tis = tis.where(TaskInstance.state == state[0]) 

1920 else: 

1921 # this is required to deal with NULL values 

1922 if None in state: 

1923 if all(x is None for x in state): 

1924 tis = tis.where(TaskInstance.state.is_(None)) 

1925 else: 

1926 not_none_state = [s for s in state if s] 

1927 tis = tis.where( 

1928 or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None)) 

1929 ) 

1930 else: 

1931 tis = tis.where(TaskInstance.state.in_(state)) 

1932 

1933 # Next, get any of them from our parent DAG (if there is one) 

1934 if include_parentdag and self.parent_dag is not None: 

1935 if visited_external_tis is None: 

1936 visited_external_tis = set() 

1937 

1938 p_dag = self.parent_dag.partial_subset( 

1939 task_ids_or_regex=r"^{}$".format(self.dag_id.split(".")[1]), 

1940 include_upstream=False, 

1941 include_downstream=True, 

1942 ) 

1943 result.update( 

1944 p_dag._get_task_instances( 

1945 task_ids=task_ids, 

1946 start_date=start_date, 

1947 end_date=end_date, 

1948 run_id=None, 

1949 state=state, 

1950 include_subdags=include_subdags, 

1951 include_parentdag=False, 

1952 include_dependent_dags=include_dependent_dags, 

1953 as_pk_tuple=True, 

1954 exclude_task_ids=exclude_task_ids, 

1955 session=session, 

1956 dag_bag=dag_bag, 

1957 recursion_depth=recursion_depth, 

1958 max_recursion_depth=max_recursion_depth, 

1959 visited_external_tis=visited_external_tis, 

1960 ) 

1961 ) 

1962 

1963 if include_dependent_dags: 

1964 # Recursively find external tasks indicated by ExternalTaskMarker 

1965 from airflow.sensors.external_task import ExternalTaskMarker 

1966 

1967 query = tis 

1968 if as_pk_tuple: 

1969 all_tis = session.execute(query).all() 

1970 condition = TI.filter_for_tis(TaskInstanceKey(*cols) for cols in all_tis) 

1971 if condition is not None: 

1972 query = select(TI).where(condition) 

1973 

1974 if visited_external_tis is None: 

1975 visited_external_tis = set() 

1976 

1977 external_tasks = session.scalars(query.where(TI.operator == ExternalTaskMarker.__name__)) 

1978 

1979 for ti in external_tasks: 

1980 ti_key = ti.key.primary 

1981 if ti_key in visited_external_tis: 

1982 continue 

1983 

1984 visited_external_tis.add(ti_key) 

1985 

1986 task: ExternalTaskMarker = cast(ExternalTaskMarker, copy.copy(self.get_task(ti.task_id))) 

1987 ti.task = task 

1988 

1989 if max_recursion_depth is None: 

1990 # Maximum recursion depth allowed is the recursion_depth of the first 

1991 # ExternalTaskMarker in the tasks to be visited. 

1992 max_recursion_depth = task.recursion_depth 

1993 

1994 if recursion_depth + 1 > max_recursion_depth: 

1995 # Prevent cycles or accidents. 

1996 raise AirflowException( 

1997 f"Maximum recursion depth {max_recursion_depth} reached for " 

1998 f"{ExternalTaskMarker.__name__} {ti.task_id}. " 

1999 f"Attempted to clear too many tasks or there may be a cyclic dependency." 

2000 ) 

2001 ti.render_templates() 

2002 external_tis = session.scalars( 

2003 select(TI) 

2004 .join(TI.dag_run) 

2005 .where( 

2006 TI.dag_id == task.external_dag_id, 

2007 TI.task_id == task.external_task_id, 

2008 DagRun.execution_date == pendulum.parse(task.execution_date), 

2009 ) 

2010 ) 

2011 

2012 for tii in external_tis: 

2013 if not dag_bag: 

2014 from airflow.models.dagbag import DagBag 

2015 

2016 dag_bag = DagBag(read_dags_from_db=True) 

2017 external_dag = dag_bag.get_dag(tii.dag_id, session=session) 

2018 if not external_dag: 

2019 raise AirflowException(f"Could not find dag {tii.dag_id}") 

2020 downstream = external_dag.partial_subset( 

2021 task_ids_or_regex=[tii.task_id], 

2022 include_upstream=False, 

2023 include_downstream=True, 

2024 ) 

2025 result.update( 

2026 downstream._get_task_instances( 

2027 task_ids=None, 

2028 run_id=tii.run_id, 

2029 start_date=None, 

2030 end_date=None, 

2031 state=state, 

2032 include_subdags=include_subdags, 

2033 include_dependent_dags=include_dependent_dags, 

2034 include_parentdag=False, 

2035 as_pk_tuple=True, 

2036 exclude_task_ids=exclude_task_ids, 

2037 dag_bag=dag_bag, 

2038 session=session, 

2039 recursion_depth=recursion_depth + 1, 

2040 max_recursion_depth=max_recursion_depth, 

2041 visited_external_tis=visited_external_tis, 

2042 ) 

2043 ) 

2044 

2045 if result or as_pk_tuple: 

2046 # Only execute the `ti` query if we have also collected some other results (i.e. subdags etc.) 

2047 if as_pk_tuple: 

2048 tis_query = session.execute(tis).all() 

2049 result.update(TaskInstanceKey(**cols._mapping) for cols in tis_query) 

2050 else: 

2051 result.update(ti.key for ti in session.scalars(tis)) 

2052 

2053 if exclude_task_ids is not None: 

2054 result = { 

2055 task 

2056 for task in result 

2057 if task.task_id not in exclude_task_ids 

2058 and (task.task_id, task.map_index) not in exclude_task_ids 

2059 } 

2060 

2061 if as_pk_tuple: 

2062 return result 

2063 if result: 

2064 # We've been asked for objects, let's combine it all back into a result set 

2065 ti_filters = TI.filter_for_tis(result) 

2066 if ti_filters is not None: 

2067 tis = select(TI).where(ti_filters) 

2068 elif exclude_task_ids is None: 

2069 pass # Disable filter if not set. 

2070 elif isinstance(next(iter(exclude_task_ids), None), str): 

2071 tis = tis.where(TI.task_id.notin_(exclude_task_ids)) 

2072 else: 

2073 tis = tis.where(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids))) 

2074 

2075 return tis 

2076 

2077 @provide_session 

2078 def set_task_instance_state( 

2079 self, 

2080 *, 

2081 task_id: str, 

2082 map_indexes: Collection[int] | None = None, 

2083 execution_date: datetime | None = None, 

2084 run_id: str | None = None, 

2085 state: TaskInstanceState, 

2086 upstream: bool = False, 

2087 downstream: bool = False, 

2088 future: bool = False, 

2089 past: bool = False, 

2090 commit: bool = True, 

2091 session=NEW_SESSION, 

2092 ) -> list[TaskInstance]: 

2093 """ 

2094 Set the state of a TaskInstance and clear downstream tasks in failed or upstream_failed state. 

2095 

2096 :param task_id: Task ID of the TaskInstance 

2097 :param map_indexes: Only set TaskInstance if its map_index matches. 

2098 If None (default), all mapped TaskInstances of the task are set. 

2099 :param execution_date: Execution date of the TaskInstance 

2100 :param run_id: The run_id of the TaskInstance 

2101 :param state: State to set the TaskInstance to 

2102 :param upstream: Include all upstream tasks of the given task_id 

2103 :param downstream: Include all downstream tasks of the given task_id 

2104 :param future: Include all future TaskInstances of the given task_id 

2105 :param commit: Commit changes 

2106 :param past: Include all past TaskInstances of the given task_id 

2107 """ 

2108 from airflow.api.common.mark_tasks import set_state 

2109 

2110 if not exactly_one(execution_date, run_id): 

2111 raise ValueError("Exactly one of execution_date or run_id must be provided") 

2112 

2113 task = self.get_task(task_id) 

2114 task.dag = self 

2115 

2116 tasks_to_set_state: list[Operator | tuple[Operator, int]] 

2117 if map_indexes is None: 

2118 tasks_to_set_state = [task] 

2119 else: 

2120 tasks_to_set_state = [(task, map_index) for map_index in map_indexes] 

2121 

2122 altered = set_state( 

2123 tasks=tasks_to_set_state, 

2124 execution_date=execution_date, 

2125 run_id=run_id, 

2126 upstream=upstream, 

2127 downstream=downstream, 

2128 future=future, 

2129 past=past, 

2130 state=state, 

2131 commit=commit, 

2132 session=session, 

2133 ) 

2134 

2135 if not commit: 

2136 return altered 

2137 

2138 # Clear downstream tasks that are in failed/upstream_failed state to resume them. 

2139 # Flush the session so that the tasks marked success are reflected in the db. 

2140 session.flush() 

2141 subdag = self.partial_subset( 

2142 task_ids_or_regex={task_id}, 

2143 include_downstream=True, 

2144 include_upstream=False, 

2145 ) 

2146 

2147 if execution_date is None: 

2148 dag_run = session.scalars( 

2149 select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id) 

2150 ).one() # Raises an error if not found 

2151 resolve_execution_date = dag_run.execution_date 

2152 else: 

2153 resolve_execution_date = execution_date 

2154 

2155 end_date = resolve_execution_date if not future else None 

2156 start_date = resolve_execution_date if not past else None 

2157 

2158 subdag.clear( 

2159 start_date=start_date, 

2160 end_date=end_date, 

2161 include_subdags=True, 

2162 include_parentdag=True, 

2163 only_failed=True, 

2164 session=session, 

2165 # Exclude the task itself from being cleared 

2166 exclude_task_ids=frozenset({task_id}), 

2167 ) 

2168 

2169 return altered 

2170 

2171 @provide_session 

2172 def set_task_group_state( 

2173 self, 

2174 *, 

2175 group_id: str, 

2176 execution_date: datetime | None = None, 

2177 run_id: str | None = None, 

2178 state: TaskInstanceState, 

2179 upstream: bool = False, 

2180 downstream: bool = False, 

2181 future: bool = False, 

2182 past: bool = False, 

2183 commit: bool = True, 

2184 session: Session = NEW_SESSION, 

2185 ) -> list[TaskInstance]: 

2186 """ 

2187 Set TaskGroup to the given state and clear downstream tasks in failed or upstream_failed state. 

2188 

2189 :param group_id: The group_id of the TaskGroup 

2190 :param execution_date: Execution date of the TaskInstance 

2191 :param run_id: The run_id of the TaskInstance 

2192 :param state: State to set the TaskInstance to 

2193 :param upstream: Include all upstream tasks of the given task_id 

2194 :param downstream: Include all downstream tasks of the given task_id 

2195 :param future: Include all future TaskInstances of the given task_id 

2196 :param commit: Commit changes 

2197 :param past: Include all past TaskInstances of the given task_id 

2198 :param session: Database session 

2199 """ 

2200 from airflow.api.common.mark_tasks import set_state 

2201 

2202 if not exactly_one(execution_date, run_id): 

2203 raise ValueError("Exactly one of execution_date or run_id must be provided") 

2204 

2205 tasks_to_set_state: list[BaseOperator | tuple[BaseOperator, int]] = [] 

2206 task_ids: list[str] = [] 

2207 

2208 if execution_date is None: 

2209 dag_run = session.scalars( 

2210 select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id) 

2211 ).one() # Raises an error if not found 

2212 resolve_execution_date = dag_run.execution_date 

2213 else: 

2214 resolve_execution_date = execution_date 

2215 

2216 end_date = resolve_execution_date if not future else None 

2217 start_date = resolve_execution_date if not past else None 

2218 

2219 task_group_dict = self.task_group.get_task_group_dict() 

2220 task_group = task_group_dict.get(group_id) 

2221 if task_group is None: 

2222 raise ValueError(f"TaskGroup {group_id} could not be found") 

2223 tasks_to_set_state = [task for task in task_group.iter_tasks() if isinstance(task, BaseOperator)] 

2224 task_ids = [task.task_id for task in task_group.iter_tasks()] 

2225 dag_runs_query = select(DagRun.id).where(DagRun.dag_id == self.dag_id) 

2226 if start_date is None and end_date is None: 

2227 dag_runs_query = dag_runs_query.where(DagRun.execution_date == start_date) 

2228 else: 

2229 if start_date is not None: 

2230 dag_runs_query = dag_runs_query.where(DagRun.execution_date >= start_date) 

2231 if end_date is not None: 

2232 dag_runs_query = dag_runs_query.where(DagRun.execution_date <= end_date) 

2233 

2234 with lock_rows(dag_runs_query, session): 

2235 altered = set_state( 

2236 tasks=tasks_to_set_state, 

2237 execution_date=execution_date, 

2238 run_id=run_id, 

2239 upstream=upstream, 

2240 downstream=downstream, 

2241 future=future, 

2242 past=past, 

2243 state=state, 

2244 commit=commit, 

2245 session=session, 

2246 ) 

2247 if not commit: 

2248 return altered 

2249 

2250 # Clear downstream tasks that are in failed/upstream_failed state to resume them. 

2251 # Flush the session so that the tasks marked success are reflected in the db. 

2252 session.flush() 

2253 task_subset = self.partial_subset( 

2254 task_ids_or_regex=task_ids, 

2255 include_downstream=True, 

2256 include_upstream=False, 

2257 ) 

2258 

2259 task_subset.clear( 

2260 start_date=start_date, 

2261 end_date=end_date, 

2262 include_subdags=True, 

2263 include_parentdag=True, 

2264 only_failed=True, 

2265 session=session, 

2266 # Exclude the task from the current group from being cleared 

2267 exclude_task_ids=frozenset(task_ids), 

2268 ) 

2269 

2270 return altered 

2271 

2272 @property 

2273 def roots(self) -> list[Operator]: 

2274 """Return nodes with no parents. These are first to execute and are called roots or root nodes.""" 

2275 return [task for task in self.tasks if not task.upstream_list] 

2276 

2277 @property 

2278 def leaves(self) -> list[Operator]: 

2279 """Return nodes with no children. These are last to execute and are called leaves or leaf nodes.""" 

2280 return [task for task in self.tasks if not task.downstream_list] 

2281 

2282 def topological_sort(self, include_subdag_tasks: bool = False): 

2283 """ 

2284 Sort tasks in topological order, such that a task comes after any of its upstream dependencies. 

2285 

2286 Deprecated in place of ``task_group.topological_sort`` 

2287 """ 

2288 from airflow.utils.task_group import TaskGroup 

2289 

2290 def nested_topo(group): 

2291 for node in group.topological_sort(_include_subdag_tasks=include_subdag_tasks): 

2292 if isinstance(node, TaskGroup): 

2293 yield from nested_topo(node) 

2294 else: 

2295 yield node 

2296 

2297 return tuple(nested_topo(self.task_group)) 

2298 
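# Illustrative sketch: roots, leaves and topological_sort() on a small DAG
# (dag_id and task_ids are assumptions).
import pendulum

from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator

with DAG(
    dag_id="topo_example",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
) as topo_dag:
    start = EmptyOperator(task_id="start")
    middle = EmptyOperator(task_id="middle")
    end = EmptyOperator(task_id="end")
    start >> middle >> end

print([t.task_id for t in topo_dag.roots])               # ['start']
print([t.task_id for t in topo_dag.leaves])              # ['end']
print([t.task_id for t in topo_dag.topological_sort()])  # ['start', 'middle', 'end']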

2299 @provide_session 

2300 def set_dag_runs_state( 

2301 self, 

2302 state: DagRunState = DagRunState.RUNNING, 

2303 session: Session = NEW_SESSION, 

2304 start_date: datetime | None = None, 

2305 end_date: datetime | None = None, 

2306 dag_ids: list[str] | None = None, 

2307 ) -> None: 

2308 warnings.warn( 

2309 "This method is deprecated and will be removed in a future version.", 

2310 RemovedInAirflow3Warning, 

2311 stacklevel=3, 

2312 ) 

2313 dag_ids = dag_ids or [self.dag_id] 

2314 query = update(DagRun).where(DagRun.dag_id.in_(dag_ids)) 

2315 if start_date: 

2316 query = query.where(DagRun.execution_date >= start_date) 

2317 if end_date: 

2318 query = query.where(DagRun.execution_date <= end_date) 

2319 session.execute(query.values(state=state).execution_options(synchronize_session="fetch")) 

2320 

2321 @provide_session 

2322 def clear( 

2323 self, 

2324 task_ids: Collection[str | tuple[str, int]] | None = None, 

2325 start_date: datetime | None = None, 

2326 end_date: datetime | None = None, 

2327 only_failed: bool = False, 

2328 only_running: bool = False, 

2329 confirm_prompt: bool = False, 

2330 include_subdags: bool = True, 

2331 include_parentdag: bool = True, 

2332 dag_run_state: DagRunState = DagRunState.QUEUED, 

2333 dry_run: bool = False, 

2334 session: Session = NEW_SESSION, 

2335 get_tis: bool = False, 

2336 recursion_depth: int = 0, 

2337 max_recursion_depth: int | None = None, 

2338 dag_bag: DagBag | None = None, 

2339 exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), 

2340 ) -> int | Iterable[TaskInstance]: 

2341 """ 

2342 Clear a set of task instances associated with the current dag for a specified date range. 

2343 

2344 :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear 

2345 :param start_date: The minimum execution_date to clear 

2346 :param end_date: The maximum execution_date to clear 

2347 :param only_failed: Only clear failed tasks 

2348 :param only_running: Only clear running tasks. 

2349 :param confirm_prompt: Ask for confirmation 

2350 :param include_subdags: Clear tasks in subdags and clear external tasks 

2351 indicated by ExternalTaskMarker 

2352 :param include_parentdag: Clear tasks in the parent dag of the subdag. 

2353 :param dag_run_state: state to set DagRun to. If set to False, dagrun state will not 

2354 be changed. 

2355 :param dry_run: Find the tasks to clear but don't clear them. 

2356 :param session: The sqlalchemy session to use 

2357 :param dag_bag: The DagBag used to find the dags subdags (Optional) 

2358 :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``) 

2359 tuples that should not be cleared 

2360 """ 

2361 if get_tis: 

2362 warnings.warn( 

2363 "Passing `get_tis` to dag.clear() is deprecated. Use `dry_run` parameter instead.", 

2364 RemovedInAirflow3Warning, 

2365 stacklevel=2, 

2366 ) 

2367 dry_run = True 

2368 

2369 if recursion_depth: 

2370 warnings.warn( 

2371 "Passing `recursion_depth` to dag.clear() is deprecated.", 

2372 RemovedInAirflow3Warning, 

2373 stacklevel=2, 

2374 ) 

2375 if max_recursion_depth: 

2376 warnings.warn( 

2377 "Passing `max_recursion_depth` to dag.clear() is deprecated.", 

2378 RemovedInAirflow3Warning, 

2379 stacklevel=2, 

2380 ) 

2381 

2382 state: list[TaskInstanceState] = [] 

2383 if only_failed: 

2384 state += [TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED] 

2385 if only_running: 

2386 # Yes, having `+=` doesn't make sense, but this was the existing behaviour 

2387 state += [TaskInstanceState.RUNNING] 

2388 

2389 tis = self._get_task_instances( 

2390 task_ids=task_ids, 

2391 start_date=start_date, 

2392 end_date=end_date, 

2393 run_id=None, 

2394 state=state, 

2395 include_subdags=include_subdags, 

2396 include_parentdag=include_parentdag, 

2397 include_dependent_dags=include_subdags, # compat, yes this is not a typo 

2398 session=session, 

2399 dag_bag=dag_bag, 

2400 exclude_task_ids=exclude_task_ids, 

2401 ) 

2402 

2403 if dry_run: 

2404 return session.scalars(tis).all() 

2405 

2406 tis = session.scalars(tis).all() 

2407 

2408 count = len(list(tis)) 

2409 do_it = True 

2410 if count == 0: 

2411 return 0 

2412 if confirm_prompt: 

2413 ti_list = "\n".join(str(t) for t in tis) 

2414 question = f"You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]" 

2415 do_it = utils.helpers.ask_yesno(question) 

2416 

2417 if do_it: 

2418 clear_task_instances( 

2419 list(tis), 

2420 session, 

2421 dag=self, 

2422 dag_run_state=dag_run_state, 

2423 ) 

2424 else: 

2425 count = 0 

2426 print("Cancelled, nothing was cleared.") 

2427 

2428 session.flush() 

2429 return count 

2430 
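# Illustrative sketch: a dry-run clear() over a date range, which returns the
# matching task instances without changing them. Assumes an initialized
# metadata database and an "example_dag" with failed runs in the interval.
import pendulum

from airflow.models.dagbag import DagBag

target_dag = DagBag(read_dags_from_db=True).get_dag("example_dag")
if target_dag is not None:
    candidates = target_dag.clear(
        start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
        end_date=pendulum.datetime(2024, 1, 31, tz="UTC"),
        only_failed=True,
        dry_run=True,  # only list what would be cleared
    )
    for ti in candidates:
        print(ti.dag_id, ti.task_id, ti.run_id, ti.state)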

2431 @classmethod 

2432 def clear_dags( 

2433 cls, 

2434 dags, 

2435 start_date=None, 

2436 end_date=None, 

2437 only_failed=False, 

2438 only_running=False, 

2439 confirm_prompt=False, 

2440 include_subdags=True, 

2441 include_parentdag=False, 

2442 dag_run_state=DagRunState.QUEUED, 

2443 dry_run=False, 

2444 ): 

2445 all_tis = [] 

2446 for dag in dags: 

2447 tis = dag.clear( 

2448 start_date=start_date, 

2449 end_date=end_date, 

2450 only_failed=only_failed, 

2451 only_running=only_running, 

2452 confirm_prompt=False, 

2453 include_subdags=include_subdags, 

2454 include_parentdag=include_parentdag, 

2455 dag_run_state=dag_run_state, 

2456 dry_run=True, 

2457 ) 

2458 all_tis.extend(tis) 

2459 

2460 if dry_run: 

2461 return all_tis 

2462 

2463 count = len(all_tis) 

2464 do_it = True 

2465 if count == 0: 

2466 print("Nothing to clear.") 

2467 return 0 

2468 if confirm_prompt: 

2469 ti_list = "\n".join(str(t) for t in all_tis) 

2470 question = f"You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]" 

2471 do_it = utils.helpers.ask_yesno(question) 

2472 

2473 if do_it: 

2474 for dag in dags: 

2475 dag.clear( 

2476 start_date=start_date, 

2477 end_date=end_date, 

2478 only_failed=only_failed, 

2479 only_running=only_running, 

2480 confirm_prompt=False, 

2481 include_subdags=include_subdags, 

2482 dag_run_state=dag_run_state, 

2483 dry_run=False, 

2484 ) 

2485 else: 

2486 count = 0 

2487 print("Cancelled, nothing was cleared.") 

2488 return count 

2489 

2490 def __deepcopy__(self, memo): 

2491 # Switcharoo to go around deepcopying objects coming through the 

2492 # backdoor 

2493 cls = self.__class__ 

2494 result = cls.__new__(cls) 

2495 memo[id(self)] = result 

2496 for k, v in self.__dict__.items(): 

2497 if k not in ("user_defined_macros", "user_defined_filters", "_log"): 

2498 setattr(result, k, copy.deepcopy(v, memo)) 

2499 

2500 result.user_defined_macros = self.user_defined_macros 

2501 result.user_defined_filters = self.user_defined_filters 

2502 if hasattr(self, "_log"): 

2503 result._log = self._log 

2504 return result 

2505 

2506 def sub_dag(self, *args, **kwargs): 

2507 """Use `airflow.models.DAG.partial_subset`, this method is deprecated.""" 

2508 warnings.warn( 

2509 "This method is deprecated and will be removed in a future version. Please use partial_subset", 

2510 RemovedInAirflow3Warning, 

2511 stacklevel=2, 

2512 ) 

2513 return self.partial_subset(*args, **kwargs) 

2514 

2515 def partial_subset( 

2516 self, 

2517 task_ids_or_regex: str | Pattern | Iterable[str], 

2518 include_downstream=False, 

2519 include_upstream=True, 

2520 include_direct_upstream=False, 

2521 ): 

2522 """ 

2523 Return a subset of the current dag based on regex matching one or more tasks. 

2524 

2525 Returns a subset of the current dag as a deep copy of the current dag 

2526 based on a regex that should match one or many tasks, and includes 

2527 upstream and downstream neighbours based on the flag passed. 

2528 

2529 :param task_ids_or_regex: Either a list of task_ids, or a regex to 

2530 match against task ids (as a string, or compiled regex pattern). 

2531 :param include_downstream: Include all downstream tasks of matched 

2532 tasks, in addition to matched tasks. 

2533 :param include_upstream: Include all upstream tasks of matched tasks, 

2534 in addition to matched tasks. 

2535 :param include_direct_upstream: Include all tasks directly upstream of the matched 

2536 tasks (and of their downstream tasks, if include_downstream=True) 

2537 """ 

2538 from airflow.models.baseoperator import BaseOperator 

2539 from airflow.models.mappedoperator import MappedOperator 

2540 

2541 # deep-copying self.task_dict and self._task_group takes a long time, and we don't want all 

2542 # the tasks anyway, so we copy the tasks manually later 

2543 memo = {id(self.task_dict): None, id(self._task_group): None} 

2544 dag = copy.deepcopy(self, memo) # type: ignore 

2545 

2546 if isinstance(task_ids_or_regex, (str, Pattern)): 

2547 matched_tasks = [t for t in self.tasks if re2.findall(task_ids_or_regex, t.task_id)] 

2548 else: 

2549 matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex] 

2550 

2551 also_include_ids: set[str] = set() 

2552 for t in matched_tasks: 

2553 if include_downstream: 

2554 for rel in t.get_flat_relatives(upstream=False): 

2555 also_include_ids.add(rel.task_id) 

2556 if rel not in matched_tasks: # if it's in there, we're already processing it 

2557 # need to include setups and teardowns for tasks that are in multiple 

2558 # non-collinear setup/teardown paths 

2559 if not rel.is_setup and not rel.is_teardown: 

2560 also_include_ids.update( 

2561 x.task_id for x in rel.get_upstreams_only_setups_and_teardowns() 

2562 ) 

2563 if include_upstream: 

2564 also_include_ids.update(x.task_id for x in t.get_upstreams_follow_setups()) 

2565 else: 

2566 if not t.is_setup and not t.is_teardown: 

2567 also_include_ids.update(x.task_id for x in t.get_upstreams_only_setups_and_teardowns()) 

2568 if t.is_setup and not include_downstream: 

2569 also_include_ids.update(x.task_id for x in t.downstream_list if x.is_teardown) 

2570 

2571 also_include: list[Operator] = [self.task_dict[x] for x in also_include_ids] 

2572 direct_upstreams: list[Operator] = [] 

2573 if include_direct_upstream: 

2574 for t in itertools.chain(matched_tasks, also_include): 

2575 upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator))) 

2576 direct_upstreams.extend(upstream) 

2577 

2578 # Compiling the unique list of tasks that made the cut 

2579 # Make sure to not recursively deepcopy the dag or task_group while copying the task. 

2580 # task_group is reset later 

2581 def _deepcopy_task(t) -> Operator: 

2582 memo.setdefault(id(t.task_group), None) 

2583 return copy.deepcopy(t, memo) 

2584 

2585 dag.task_dict = { 

2586 t.task_id: _deepcopy_task(t) 

2587 for t in itertools.chain(matched_tasks, also_include, direct_upstreams) 

2588 } 

2589 

2590 def filter_task_group(group, parent_group): 

2591 """Exclude tasks not included in the subdag from the given TaskGroup.""" 

2592 # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy 

2593 # and then manually deep copy the instances. (memo argument to deepcopy only works for instances 

2594 # of classes, not "native" properties of an instance) 

2595 copied = copy.copy(group) 

2596 

2597 memo[id(group.children)] = {} 

2598 if parent_group: 

2599 memo[id(group.parent_group)] = parent_group 

2600 for attr, value in copied.__dict__.items(): 

2601 if id(value) in memo: 

2602 value = memo[id(value)] 

2603 else: 

2604 value = copy.deepcopy(value, memo) 

2605 copied.__dict__[attr] = value 

2606 

2607 proxy = weakref.proxy(copied) 

2608 

2609 for child in group.children.values(): 

2610 if isinstance(child, AbstractOperator): 

2611 if child.task_id in dag.task_dict: 

2612 task = copied.children[child.task_id] = dag.task_dict[child.task_id] 

2613 task.task_group = proxy 

2614 else: 

2615 copied.used_group_ids.discard(child.task_id) 

2616 else: 

2617 filtered_child = filter_task_group(child, proxy) 

2618 

2619 # Only include this child TaskGroup if it is non-empty. 

2620 if filtered_child.children: 

2621 copied.children[child.group_id] = filtered_child 

2622 

2623 return copied 

2624 

2625 dag._task_group = filter_task_group(self.task_group, None) 

2626 

2627 # Removing upstream/downstream references to tasks and TaskGroups that did not make 

2628 # the cut. 

2629 subdag_task_groups = dag.task_group.get_task_group_dict() 

2630 for group in subdag_task_groups.values(): 

2631 group.upstream_group_ids.intersection_update(subdag_task_groups) 

2632 group.downstream_group_ids.intersection_update(subdag_task_groups) 

2633 group.upstream_task_ids.intersection_update(dag.task_dict) 

2634 group.downstream_task_ids.intersection_update(dag.task_dict) 

2635 

2636 for t in dag.tasks: 

2637 # Removing upstream/downstream references to tasks that did not 

2638 # make the cut 

2639 t.upstream_task_ids.intersection_update(dag.task_dict) 

2640 t.downstream_task_ids.intersection_update(dag.task_dict) 

2641 

2642 if len(dag.tasks) < len(self.tasks): 

2643 dag.partial = True 

2644 

2645 return dag 

2646 
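# Illustrative sketch: partial_subset() carving out one task and everything
# downstream of it (dag_id and task_ids are assumptions).
import pendulum

from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator

with DAG(
    dag_id="subset_example",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
) as full_dag:
    a = EmptyOperator(task_id="a")
    b = EmptyOperator(task_id="b")
    c = EmptyOperator(task_id="c")
    a >> b >> c

subset = full_dag.partial_subset(
    task_ids_or_regex=["b"],
    include_upstream=False,
    include_downstream=True,
)
print(sorted(subset.task_ids))  # ['b', 'c']
print(subset.partial)           # True, since the subset has fewer tasks than full_dag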

2647 def has_task(self, task_id: str): 

2648 return task_id in self.task_dict 

2649 

2650 def has_task_group(self, task_group_id: str) -> bool: 

2651 return task_group_id in self.task_group_dict 

2652 

2653 @functools.cached_property 

2654 def task_group_dict(self): 

2655 return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None} 

2656 

2657 def get_task(self, task_id: str, include_subdags: bool = False) -> Operator: 

2658 if task_id in self.task_dict: 

2659 return self.task_dict[task_id] 

2660 if include_subdags: 

2661 for dag in self.subdags: 

2662 if task_id in dag.task_dict: 

2663 return dag.task_dict[task_id] 

2664 raise TaskNotFound(f"Task {task_id} not found") 

2665 
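# Illustrative sketch: get_task() raises TaskNotFound for unknown task ids,
# so either guard with has_task() or catch the exception (names assumed).
import pendulum

from airflow.exceptions import TaskNotFound
from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator

with DAG(
    dag_id="lookup_example",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
) as lookup_dag:
    EmptyOperator(task_id="known")

print(lookup_dag.has_task("known"))  # True
try:
    lookup_dag.get_task("missing")
except TaskNotFound as err:
    print(err)  # Task missing not found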

2666 def pickle_info(self): 

2667 d = {} 

2668 d["is_picklable"] = True 

2669 try: 

2670 dttm = timezone.utcnow() 

2671 pickled = pickle.dumps(self) 

2672 d["pickle_len"] = len(pickled) 

2673 d["pickling_duration"] = str(timezone.utcnow() - dttm) 

2674 except Exception as e: 

2675 self.log.debug(e) 

2676 d["is_picklable"] = False 

2677 d["stacktrace"] = traceback.format_exc() 

2678 return d 

2679 

2680 @provide_session 

2681 def pickle(self, session=NEW_SESSION) -> DagPickle: 

2682 dag = session.scalar(select(DagModel).where(DagModel.dag_id == self.dag_id).limit(1)) 

2683 dp = None 

2684 if dag and dag.pickle_id: 

2685 dp = session.scalar(select(DagPickle).where(DagPickle.id == dag.pickle_id).limit(1)) 

2686 if not dp or dp.pickle != self: 

2687 dp = DagPickle(dag=self) 

2688 session.add(dp) 

2689 self.last_pickled = timezone.utcnow() 

2690 session.commit() 

2691 self.pickle_id = dp.id 

2692 

2693 return dp 

2694 

2695 def tree_view(self) -> None: 

2696 """Print an ASCII tree representation of the DAG.""" 

2697 for tmp in self._generate_tree_view(): 

2698 print(tmp) 

2699 

2700 def _generate_tree_view(self) -> Generator[str, None, None]: 

2701 def get_downstream(task, level=0) -> Generator[str, None, None]: 

2702 yield (" " * level * 4) + str(task) 

2703 level += 1 

2704 for tmp_task in sorted(task.downstream_list, key=lambda x: x.task_id): 

2705 yield from get_downstream(tmp_task, level) 

2706 

2707 for t in sorted(self.roots, key=lambda x: x.task_id): 

2708 yield from get_downstream(t) 

2709 

2710 def get_tree_view(self) -> str: 

2711 """Return an ASCII tree representation of the DAG.""" 

2712 rst = "" 

2713 for tmp in self._generate_tree_view(): 

2714 rst += tmp + "\n" 

2715 return rst 

2716 
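# Illustrative sketch: get_tree_view() renders the task hierarchy as indented
# text, one line per task occurrence (dag_id and task_ids are assumptions).
import pendulum

from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator

with DAG(
    dag_id="tree_example",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
) as tree_dag:
    root = EmptyOperator(task_id="root")
    left = EmptyOperator(task_id="left")
    right = EmptyOperator(task_id="right")
    root >> [left, right]

# Prints something like:
#   <Task(EmptyOperator): root>
#       <Task(EmptyOperator): left>
#       <Task(EmptyOperator): right>
print(tree_dag.get_tree_view())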

2717 @property 

2718 def task(self) -> TaskDecoratorCollection: 

2719 from airflow.decorators import task 

2720 

2721 return cast("TaskDecoratorCollection", functools.partial(task, dag=self)) 

2722 

2723 def add_task(self, task: Operator) -> None: 

2724 """ 

2725 Add a task to the DAG. 

2726 

2727 :param task: the task you want to add 

2728 """ 

2729 FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule) 

2730 

2731 from airflow.utils.task_group import TaskGroupContext 

2732 

2733 # if the task has no start date, assign it the same as the DAG 

2734 if not task.start_date: 

2735 task.start_date = self.start_date 

2736 # otherwise, the task will start on the later of its own start date and 

2737 # the DAG's start date 

2738 elif self.start_date: 

2739 task.start_date = max(task.start_date, self.start_date) 

2740 

2741 # if the task has no end date, assign it the same as the dag 

2742 if not task.end_date: 

2743 task.end_date = self.end_date 

2744 # otherwise, the task will end on the earlier of its own end date and 

2745 # the DAG's end date 

2746 elif task.end_date and self.end_date: 

2747 task.end_date = min(task.end_date, self.end_date) 

2748 

2749 task_id = task.task_id 

2750 if not task.task_group: 

2751 task_group = TaskGroupContext.get_current_task_group(self) 

2752 if task_group: 

2753 task_id = task_group.child_id(task_id) 

2754 task_group.add(task) 

2755 

2756 if ( 

2757 task_id in self.task_dict and self.task_dict[task_id] is not task 

2758 ) or task_id in self._task_group.used_group_ids: 

2759 raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG") 

2760 else: 

2761 self.task_dict[task_id] = task 

2762 task.dag = self 

2763 # Add task_id to used_group_ids to prevent group_id and task_id collisions. 

2764 self._task_group.used_group_ids.add(task_id) 

2765 

2766 self.task_count = len(self.task_dict) 

2767 

2768 def add_tasks(self, tasks: Iterable[Operator]) -> None: 

2769 """ 

2770 Add a list of tasks to the DAG. 

2771 

2772 :param tasks: a list of tasks you want to add 

2773 """ 

2774 for task in tasks: 

2775 self.add_task(task) 

2776 

2777 def _remove_task(self, task_id: str) -> None: 

2778 # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this 

2779 # doesn't guard against that 

2780 task = self.task_dict.pop(task_id) 

2781 tg = getattr(task, "task_group", None) 

2782 if tg: 

2783 tg._remove(task) 

2784 

2785 self.task_count = len(self.task_dict) 

2786 

2787 def run( 

2788 self, 

2789 start_date=None, 

2790 end_date=None, 

2791 mark_success=False, 

2792 local=False, 

2793 executor=None, 

2794 donot_pickle=airflow_conf.getboolean("core", "donot_pickle"), 

2795 ignore_task_deps=False, 

2796 ignore_first_depends_on_past=True, 

2797 pool=None, 

2798 delay_on_limit_secs=1.0, 

2799 verbose=False, 

2800 conf=None, 

2801 rerun_failed_tasks=False, 

2802 run_backwards=False, 

2803 run_at_least_once=False, 

2804 continue_on_failures=False, 

2805 disable_retry=False, 

2806 ): 

2807 """ 

2808 Run the DAG. 

2809 

2810 :param start_date: the start date of the range to run 

2811 :param end_date: the end date of the range to run 

2812 :param mark_success: True to mark jobs as succeeded without running them 

2813 :param local: True to run the tasks using the LocalExecutor 

2814 :param executor: The executor instance to run the tasks 

2815 :param donot_pickle: True to avoid pickling DAG object and send to workers 

2816 :param ignore_task_deps: True to ignore upstream task dependencies when running tasks 

2817 :param ignore_first_depends_on_past: True to ignore depends_on_past 

2818 dependencies for the first set of tasks only 

2819 :param pool: Resource pool to use 

2820 :param delay_on_limit_secs: Time in seconds to wait before next attempt to run 

2821 dag run when max_active_runs limit has been reached 

2822 :param verbose: Make logging output more verbose 

2823 :param conf: user defined dictionary passed from CLI 

2824 :param rerun_failed_tasks: if True, clear and rerun task instances that previously failed in the range 

2825 :param run_backwards: if True, run the dag runs in reverse chronological order (newest logical date first) 

2826 :param run_at_least_once: If true, always run the DAG at least once even 

2827 if no logical run exists within the time range. 

2828 """ 

2829 from airflow.jobs.backfill_job_runner import BackfillJobRunner 

2830 

2831 if not executor and local: 

2832 from airflow.executors.local_executor import LocalExecutor 

2833 

2834 executor = LocalExecutor() 

2835 elif not executor: 

2836 from airflow.executors.executor_loader import ExecutorLoader 

2837 

2838 executor = ExecutorLoader.get_default_executor() 

2839 from airflow.jobs.job import Job 

2840 

2841 job = Job(executor=executor) 

2842 job_runner = BackfillJobRunner( 

2843 job=job, 

2844 dag=self, 

2845 start_date=start_date, 

2846 end_date=end_date, 

2847 mark_success=mark_success, 

2848 donot_pickle=donot_pickle, 

2849 ignore_task_deps=ignore_task_deps, 

2850 ignore_first_depends_on_past=ignore_first_depends_on_past, 

2851 pool=pool, 

2852 delay_on_limit_secs=delay_on_limit_secs, 

2853 verbose=verbose, 

2854 conf=conf, 

2855 rerun_failed_tasks=rerun_failed_tasks, 

2856 run_backwards=run_backwards, 

2857 run_at_least_once=run_at_least_once, 

2858 continue_on_failures=continue_on_failures, 

2859 disable_retry=disable_retry, 

2860 ) 

2861 run_job(job=job, execute_callable=job_runner._execute) 

2862 

2863 def cli(self): 

2864 """Exposes a CLI specific to this DAG.""" 

2865 check_cycle(self) 

2866 

2867 from airflow.cli import cli_parser 

2868 

2869 parser = cli_parser.get_parser(dag_parser=True) 

2870 args = parser.parse_args() 

2871 args.func(args, self) 

2872 

2873 @provide_session 

2874 def test( 

2875 self, 

2876 execution_date: datetime | None = None, 

2877 run_conf: dict[str, Any] | None = None, 

2878 conn_file_path: str | None = None, 

2879 variable_file_path: str | None = None, 

2880 session: Session = NEW_SESSION, 

2881 ) -> DagRun: 

2882 """ 

2883 Execute one single DagRun for a given DAG and execution date. 

2884 

2885 :param execution_date: execution date for the DAG run 

2886 :param run_conf: configuration to pass to newly created dagrun 

2887 :param conn_file_path: file path to a connection file in either yaml or json 

2888 :param variable_file_path: file path to a variable file in either yaml or json 

2889 :param session: database connection (optional) 

2890 """ 

2891 

2892 def add_logger_if_needed(ti: TaskInstance): 

2893 """Add a formatted logger to the task instance. 

2894 

2895 This allows all logs to surface to the command line, instead of into 

2896 a task file. Since this is a local test run, it is much better for 

2897 the user to see logs in the command line, rather than needing to 

2898 search for a log file. 

2899 

2900 :param ti: The task instance that will receive a logger. 

2901 """ 

2902 format = logging.Formatter("[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s") 

2903 handler = logging.StreamHandler(sys.stdout) 

2904 handler.level = logging.INFO 

2905 handler.setFormatter(format) 

2906 # only add log handler once 

2907 if not any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers): 

2908 self.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id) 

2909 ti.log.addHandler(handler) 

2910 

2911 exit_stack = ExitStack() 

2912 if conn_file_path or variable_file_path: 

2913 local_secrets = LocalFilesystemBackend( 

2914 variables_file_path=variable_file_path, connections_file_path=conn_file_path 

2915 ) 

2916 secrets_backend_list.insert(0, local_secrets) 

2917 exit_stack.callback(lambda: secrets_backend_list.pop(0)) 

2918 

2919 with exit_stack: 

2920 execution_date = execution_date or timezone.utcnow() 

2921 self.validate() 

2922 self.log.debug("Clearing existing task instances for execution date %s", execution_date) 

2923 self.clear( 

2924 start_date=execution_date, 

2925 end_date=execution_date, 

2926 dag_run_state=False, # type: ignore 

2927 session=session, 

2928 ) 

2929 self.log.debug("Getting dagrun for dag %s", self.dag_id) 

2930 logical_date = timezone.coerce_datetime(execution_date) 

2931 data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date) 

2932 dr: DagRun = _get_or_create_dagrun( 

2933 dag=self, 

2934 start_date=execution_date, 

2935 execution_date=execution_date, 

2936 run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date), 

2937 session=session, 

2938 conf=run_conf, 

2939 data_interval=data_interval, 

2940 ) 

2941 

2942 tasks = self.task_dict 

2943 self.log.debug("starting dagrun") 

2944 # Instead of starting a scheduler, we run the minimal loop possible to check 

2945 # for task readiness and dependency management. This is notably faster 

2946 # than creating a BackfillJob and allows us to surface logs to the user 

2947 while dr.state == DagRunState.RUNNING: 

2948 session.expire_all() 

2949 schedulable_tis, _ = dr.update_state(session=session) 

2950 for s in schedulable_tis: 

2951 if s.state != TaskInstanceState.UP_FOR_RESCHEDULE: 

2952 s.try_number += 1 

2953 s.state = TaskInstanceState.SCHEDULED 

2954 session.commit() 

2955 # the triggerer may mark tasks as scheduled, so we read from the DB 

2956 all_tis = set(dr.get_task_instances(session=session)) 

2957 scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED} 

2958 ids_unrunnable = {x for x in all_tis if x.state not in State.finished} - scheduled_tis 

2959 if not scheduled_tis and ids_unrunnable: 

2960 self.log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable) 

2961 time.sleep(1) 

2962 triggerer_running = _triggerer_is_healthy() 

2963 for ti in scheduled_tis: 

2964 try: 

2965 add_logger_if_needed(ti) 

2966 ti.task = tasks[ti.task_id] 

2967 _run_task(ti=ti, inline_trigger=not triggerer_running, session=session) 

2968 except Exception: 

2969 self.log.exception("Task failed; ti=%s", ti) 

2970 return dr 

2971 
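A minimal sketch of how `test()` above is typically used for local debugging (illustrative, not part of this module): it runs the whole DAG in-process without a scheduler, optionally reading connections and variables from local files. The file names and conf values are hypothetical.

    # Assuming `example_dag` is a DAG object defined earlier in the same file.
    if __name__ == "__main__":
        example_dag.test(
            run_conf={"my_param": 42},           # passed to the created DagRun as conf
            conn_file_path="connections.yaml",   # hypothetical local secrets files
            variable_file_path="variables.yaml",
        )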

2972 @provide_session 

2973 def create_dagrun( 

2974 self, 

2975 state: DagRunState, 

2976 execution_date: datetime | None = None, 

2977 run_id: str | None = None, 

2978 start_date: datetime | None = None, 

2979 external_trigger: bool | None = False, 

2980 conf: dict | None = None, 

2981 run_type: DagRunType | None = None, 

2982 session: Session = NEW_SESSION, 

2983 dag_hash: str | None = None, 

2984 creating_job_id: int | None = None, 

2985 data_interval: tuple[datetime, datetime] | None = None, 

2986 ): 

2987 """ 

2988 Create a dag run from this dag including the tasks associated with this dag. 

2989 

2990 Returns the dag run. 

2991 

2992 :param run_id: defines the run id for this dag run 

2993 :param run_type: type of DagRun 

2994 :param execution_date: the execution date of this dag run 

2995 :param state: the state of the dag run 

2996 :param start_date: the date this dag run should be evaluated 

2997 :param external_trigger: whether this dag run is externally triggered 

2998 :param conf: Dict containing configuration/parameters to pass to the DAG 

2999 :param creating_job_id: id of the job creating this DagRun 

3000 :param session: database session 

3001 :param dag_hash: Hash of Serialized DAG 

3002 :param data_interval: Data interval of the DagRun 

3003 """ 

3004 logical_date = timezone.coerce_datetime(execution_date) 

3005 

3006 if data_interval and not isinstance(data_interval, DataInterval): 

3007 data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval)) 

3008 

3009 if data_interval is None and logical_date is not None: 

3010 warnings.warn( 

3011 "Calling `DAG.create_dagrun()` without an explicit data interval is deprecated", 

3012 RemovedInAirflow3Warning, 

3013 stacklevel=3, 

3014 ) 

3015 if run_type == DagRunType.MANUAL: 

3016 data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date) 

3017 else: 

3018 data_interval = self.infer_automated_data_interval(logical_date) 

3019 

3020 if run_type is None or isinstance(run_type, DagRunType): 

3021 pass 

3022 elif isinstance(run_type, str): # Compatibility: run_type used to be a str. 

3023 run_type = DagRunType(run_type) 

3024 else: 

3025 raise ValueError(f"`run_type` should be a DagRunType, not {type(run_type)}") 

3026 

3027 if run_id: # Infer run_type from run_id if needed. 

3028 if not isinstance(run_id, str): 

3029 raise ValueError(f"`run_id` should be a str, not {type(run_id)}") 

3030 inferred_run_type = DagRunType.from_run_id(run_id) 

3031 if run_type is None: 

3032 # No explicit type given, use the inferred type. 

3033 run_type = inferred_run_type 

3034 elif run_type == DagRunType.MANUAL and inferred_run_type != DagRunType.MANUAL: 

3035 # Prevent a manual run from using an ID that looks like a scheduled run. 

3036 raise ValueError( 

3037 f"A {run_type.value} DAG run cannot use ID {run_id!r} since it " 

3038 f"is reserved for {inferred_run_type.value} runs" 

3039 ) 

3040 elif run_type and logical_date is not None: # Generate run_id from run_type and execution_date. 

3041 run_id = self.timetable.generate_run_id( 

3042 run_type=run_type, logical_date=logical_date, data_interval=data_interval 

3043 ) 

3044 else: 

3045 raise AirflowException( 

3046 "Creating DagRun needs either `run_id` or both `run_type` and `execution_date`" 

3047 ) 

3048 

3049 regex = airflow_conf.get("scheduler", "allowed_run_id_pattern") 

3050 

3051 if run_id and not re2.match(RUN_ID_REGEX, run_id): 

3052 if not regex.strip() or not re2.match(regex.strip(), run_id): 

3053 raise AirflowException( 

3054 f"The provided run ID '{run_id}' is invalid. It does not match either " 

3055 f"the configured pattern: '{regex}' or the built-in pattern: '{RUN_ID_REGEX}'" 

3056 ) 

3057 

3058 # create a copy of params before validating 

3059 copied_params = copy.deepcopy(self.params) 

3060 copied_params.update(conf or {}) 

3061 copied_params.validate() 

3062 

3063 run = _create_orm_dagrun( 

3064 dag=self, 

3065 dag_id=self.dag_id, 

3066 run_id=run_id, 

3067 logical_date=logical_date, 

3068 start_date=start_date, 

3069 external_trigger=external_trigger, 

3070 conf=conf, 

3071 state=state, 

3072 run_type=run_type, 

3073 dag_hash=dag_hash, 

3074 creating_job_id=creating_job_id, 

3075 data_interval=data_interval, 

3076 session=session, 

3077 ) 

3078 return run 

3079 
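An illustrative sketch (assuming `dag` is an already-constructed DAG object) of calling `create_dagrun()` with an explicit data interval, which avoids the deprecation warning emitted above when the interval is omitted. This is normally done by the scheduler or the trigger API, not user code.

    from airflow.utils import timezone
    from airflow.utils.session import create_session
    from airflow.utils.state import DagRunState
    from airflow.utils.types import DagRunType

    logical_date = timezone.coerce_datetime(timezone.utcnow())
    with create_session() as session:
        dag_run = dag.create_dagrun(
            state=DagRunState.QUEUED,
            execution_date=logical_date,
            run_type=DagRunType.MANUAL,  # run_id is generated from run_type + execution_date
            data_interval=dag.timetable.infer_manual_data_interval(run_after=logical_date),
            external_trigger=True,
            conf={"triggered_by": "example"},
            session=session,
        )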

3080 @classmethod 

3081 @provide_session 

3082 def bulk_sync_to_db( 

3083 cls, 

3084 dags: Collection[DAG], 

3085 session=NEW_SESSION, 

3086 ): 

3087 """Use `airflow.models.DAG.bulk_write_to_db`, this method is deprecated.""" 

3088 warnings.warn( 

3089 "This method is deprecated and will be removed in a future version. Please use bulk_write_to_db", 

3090 RemovedInAirflow3Warning, 

3091 stacklevel=2, 

3092 ) 

3093 return cls.bulk_write_to_db(dags=dags, session=session) 

3094 

3095 @classmethod 

3096 @provide_session 

3097 def bulk_write_to_db( 

3098 cls, 

3099 dags: Collection[DAG], 

3100 processor_subdir: str | None = None, 

3101 session=NEW_SESSION, 

3102 ): 

3103 """ 

3104 Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB. 

3105 

3106 Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a SubDagOperator. 

3107 

3108 :param dags: the DAG objects to save to the DB 

3109 :return: None 

3110 """ 

3111 if not dags: 

3112 return 

3113 

3114 log.info("Sync %s DAGs", len(dags)) 

3115 dag_by_ids = {dag.dag_id: dag for dag in dags} 

3116 

3117 dag_ids = set(dag_by_ids) 

3118 query = ( 

3119 select(DagModel) 

3120 .options(joinedload(DagModel.tags, innerjoin=False)) 

3121 .where(DagModel.dag_id.in_(dag_ids)) 

3122 .options(joinedload(DagModel.schedule_dataset_references)) 

3123 .options(joinedload(DagModel.task_outlet_dataset_references)) 

3124 ) 

3125 query = with_row_locks(query, of=DagModel, session=session) 

3126 orm_dags: list[DagModel] = session.scalars(query).unique().all() 

3127 existing_dags: dict[str, DagModel] = {x.dag_id: x for x in orm_dags} 

3128 missing_dag_ids = dag_ids.difference(existing_dags.keys()) 

3129 

3130 for missing_dag_id in missing_dag_ids: 

3131 orm_dag = DagModel(dag_id=missing_dag_id) 

3132 dag = dag_by_ids[missing_dag_id] 

3133 if dag.is_paused_upon_creation is not None: 

3134 orm_dag.is_paused = dag.is_paused_upon_creation 

3135 orm_dag.tags = [] 

3136 log.info("Creating ORM DAG for %s", dag.dag_id) 

3137 session.add(orm_dag) 

3138 orm_dags.append(orm_dag) 

3139 

3140 latest_runs: dict[str, DagRun] = {} 

3141 num_active_runs: dict[str, int] = {} 

3142 # Skip these queries entirely if no DAGs can be scheduled to save time. 

3143 if any(dag.timetable.can_be_scheduled for dag in dags): 

3144 # Get the latest automated dag run for each existing dag as a single query (avoid n+1 query) 

3145 query = cls._get_latest_runs_stmt(dags=list(existing_dags.keys())) 

3146 latest_runs = {run.dag_id: run for run in session.scalars(query)} 

3147 

3148 # Get number of active dagruns for all dags we are processing as a single query. 

3149 num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session) 

3150 

3151 filelocs = [] 

3152 

3153 for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id): 

3154 dag = dag_by_ids[orm_dag.dag_id] 

3155 filelocs.append(dag.fileloc) 

3156 if dag.is_subdag: 

3157 orm_dag.is_subdag = True 

3158 orm_dag.fileloc = dag.parent_dag.fileloc # type: ignore 

3159 orm_dag.root_dag_id = dag.parent_dag.dag_id # type: ignore 

3160 orm_dag.owners = dag.parent_dag.owner # type: ignore 

3161 else: 

3162 orm_dag.is_subdag = False 

3163 orm_dag.fileloc = dag.fileloc 

3164 orm_dag.owners = dag.owner 

3165 orm_dag.is_active = True 

3166 orm_dag.has_import_errors = False 

3167 orm_dag.last_parsed_time = timezone.utcnow() 

3168 orm_dag.default_view = dag.default_view 

3169 orm_dag._dag_display_property_value = dag._dag_display_property_value 

3170 orm_dag.description = dag.description 

3171 orm_dag.max_active_tasks = dag.max_active_tasks 

3172 orm_dag.max_active_runs = dag.max_active_runs 

3173 orm_dag.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs 

3174 orm_dag.has_task_concurrency_limits = any( 

3175 t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None 

3176 for t in dag.tasks 

3177 ) 

3178 orm_dag.schedule_interval = dag.schedule_interval 

3179 orm_dag.timetable_description = dag.timetable.description 

3180 if (dataset_triggers := dag.dataset_triggers) is None: 

3181 orm_dag.dataset_expression = None 

3182 else: 

3183 orm_dag.dataset_expression = dataset_triggers.as_expression() 

3184 

3185 orm_dag.processor_subdir = processor_subdir 

3186 

3187 last_automated_run: DagRun | None = latest_runs.get(dag.dag_id) 

3188 if last_automated_run is None: 

3189 last_automated_data_interval = None 

3190 else: 

3191 last_automated_data_interval = dag.get_run_data_interval(last_automated_run) 

3192 if num_active_runs.get(dag.dag_id, 0) >= orm_dag.max_active_runs: 

3193 orm_dag.next_dagrun_create_after = None 

3194 else: 

3195 orm_dag.calculate_dagrun_date_fields(dag, last_automated_data_interval) 

3196 

3197 dag_tags = set(dag.tags or {}) 

3198 orm_dag_tags = list(orm_dag.tags or []) 

3199 for orm_tag in orm_dag_tags: 

3200 if orm_tag.name not in dag_tags: 

3201 session.delete(orm_tag) 

3202 orm_dag.tags.remove(orm_tag) 

3203 orm_tag_names = {t.name for t in orm_dag_tags} 

3204 for dag_tag in dag_tags: 

3205 if dag_tag not in orm_tag_names: 

3206 dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id) 

3207 orm_dag.tags.append(dag_tag_orm) 

3208 session.add(dag_tag_orm) 

3209 

3210 orm_dag_links = orm_dag.dag_owner_links or [] 

3211 for orm_dag_link in orm_dag_links: 

3212 if orm_dag_link not in dag.owner_links: 

3213 session.delete(orm_dag_link) 

3214 for owner_name, owner_link in dag.owner_links.items(): 

3215 dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link) 

3216 session.add(dag_owner_orm) 

3217 

3218 DagCode.bulk_sync_to_db(filelocs, session=session) 

3219 

3220 from airflow.datasets import Dataset 

3221 from airflow.models.dataset import ( 

3222 DagScheduleDatasetReference, 

3223 DatasetModel, 

3224 TaskOutletDatasetReference, 

3225 ) 

3226 

3227 dag_references = defaultdict(set) 

3228 outlet_references = defaultdict(set) 

3229 # We can't use a set here as we want to preserve order 

3230 outlet_datasets: dict[DatasetModel, None] = {} 

3231 input_datasets: dict[DatasetModel, None] = {} 

3232 

3233 # here we go through dags and tasks to check for dataset references 

3234 # if there are now none and previously there were some, we delete them 

3235 # if there are now *any*, we add them to the above data structures, and 

3236 # later we'll persist them to the database. 

3237 for dag in dags: 

3238 curr_orm_dag = existing_dags.get(dag.dag_id) 

3239 if dag.dataset_triggers is None: 

3240 if curr_orm_dag and curr_orm_dag.schedule_dataset_references: 

3241 curr_orm_dag.schedule_dataset_references = [] 

3242 else: 

3243 for _, dataset in dag.dataset_triggers.iter_datasets(): 

3244 dag_references[dag.dag_id].add(dataset.uri) 

3245 input_datasets[DatasetModel.from_public(dataset)] = None 

3246 curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references 

3247 for task in dag.tasks: 

3248 dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)] 

3249 if not dataset_outlets: 

3250 if curr_outlet_references: 

3251 this_task_outlet_refs = [ 

3252 x 

3253 for x in curr_outlet_references 

3254 if x.dag_id == dag.dag_id and x.task_id == task.task_id 

3255 ] 

3256 for ref in this_task_outlet_refs: 

3257 curr_outlet_references.remove(ref) 

3258 for d in dataset_outlets: 

3259 outlet_references[(task.dag_id, task.task_id)].add(d.uri) 

3260 outlet_datasets[DatasetModel.from_public(d)] = None 

3261 all_datasets = outlet_datasets 

3262 all_datasets.update(input_datasets) 

3263 

3264 # store datasets 

3265 stored_datasets: dict[str, DatasetModel] = {} 

3266 new_datasets: list[DatasetModel] = [] 

3267 for dataset in all_datasets: 

3268 stored_dataset = session.scalar( 

3269 select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1) 

3270 ) 

3271 if stored_dataset: 

3272 # Some datasets may have been previously unreferenced, and therefore orphaned by the 

3273 # scheduler. But if we're here, then we have found that dataset again in our DAGs, which 

3274 # means that it is no longer an orphan, so set is_orphaned to False. 

3275 stored_dataset.is_orphaned = expression.false() 

3276 stored_datasets[stored_dataset.uri] = stored_dataset 

3277 else: 

3278 new_datasets.append(dataset) 

3279 dataset_manager.create_datasets(dataset_models=new_datasets, session=session) 

3280 stored_datasets.update({dataset.uri: dataset for dataset in new_datasets}) 

3281 

3282 del new_datasets 

3283 del all_datasets 

3284 

3285 # reconcile dag-schedule-on-dataset references 

3286 for dag_id, uri_list in dag_references.items(): 

3287 dag_refs_needed = { 

3288 DagScheduleDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id) 

3289 for uri in uri_list 

3290 } 

3291 dag_refs_stored = set( 

3292 existing_dags.get(dag_id) 

3293 and existing_dags.get(dag_id).schedule_dataset_references # type: ignore 

3294 or [] 

3295 ) 

3296 dag_refs_to_add = {x for x in dag_refs_needed if x not in dag_refs_stored} 

3297 session.bulk_save_objects(dag_refs_to_add) 

3298 for obj in dag_refs_stored - dag_refs_needed: 

3299 session.delete(obj) 

3300 

3301 existing_task_outlet_refs_dict = defaultdict(set) 

3302 for dag_id, orm_dag in existing_dags.items(): 

3303 for todr in orm_dag.task_outlet_dataset_references: 

3304 existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr) 

3305 

3306 # reconcile task-outlet-dataset references 

3307 for (dag_id, task_id), uri_list in outlet_references.items(): 

3308 task_refs_needed = { 

3309 TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id) 

3310 for uri in uri_list 

3311 } 

3312 task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)] 

3313 task_refs_to_add = {x for x in task_refs_needed if x not in task_refs_stored} 

3314 session.bulk_save_objects(task_refs_to_add) 

3315 for obj in task_refs_stored - task_refs_needed: 

3316 session.delete(obj) 

3317 

3318 # Issue SQL/finish "Unit of Work", but let @provide_session commit (or, if passed a session, let the caller 

3319 # decide when to commit). 

3320 session.flush() 

3321 

3322 for dag in dags: 

3323 cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session) 

3324 

3325 @classmethod 

3326 def _get_latest_runs_stmt(cls, dags: list[str]) -> Select: 

3327 """ 

3328 Build a select statement to retrieve the last automated run for each dag. 

3329 

3330 :param dags: dags to query 

3331 """ 

3332 if len(dags) == 1: 

3334 # Index-optimized fast path to avoid a more complicated & slower GROUP BY query plan 

3334 existing_dag_id = dags[0] 

3335 last_automated_runs_subq = ( 

3336 select(func.max(DagRun.execution_date).label("max_execution_date")) 

3337 .where( 

3338 DagRun.dag_id == existing_dag_id, 

3339 DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)), 

3340 ) 

3341 .scalar_subquery() 

3342 ) 

3343 query = select(DagRun).where( 

3344 DagRun.dag_id == existing_dag_id, DagRun.execution_date == last_automated_runs_subq 

3345 ) 

3346 else: 

3347 last_automated_runs_subq = ( 

3348 select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date")) 

3349 .where( 

3350 DagRun.dag_id.in_(dags), 

3351 DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)), 

3352 ) 

3353 .group_by(DagRun.dag_id) 

3354 .subquery() 

3355 ) 

3356 query = select(DagRun).where( 

3357 DagRun.dag_id == last_automated_runs_subq.c.dag_id, 

3358 DagRun.execution_date == last_automated_runs_subq.c.max_execution_date, 

3359 ) 

3360 return query.options( 

3361 load_only( 

3362 DagRun.dag_id, 

3363 DagRun.execution_date, 

3364 DagRun.data_interval_start, 

3365 DagRun.data_interval_end, 

3366 ) 

3367 ) 

3368 

3369 @provide_session 

3370 def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION): 

3371 """ 

3372 Save attributes about this DAG to the DB. 

3373 

3374 Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a SubDagOperator. 

3375 

3376 :return: None 

3377 """ 

3378 self.bulk_write_to_db([self], processor_subdir=processor_subdir, session=session) 

3379 
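For illustration only (normally the DAG file processor performs this): `sync_to_db()` above is a thin wrapper over `bulk_write_to_db()`, so syncing several DAGs at once in one session looks roughly like the sketch below. `dag_a` and `dag_b` are assumed to be DAG objects parsed from DAG files.

    from airflow.utils.session import create_session

    with create_session() as session:
        DAG.bulk_write_to_db([dag_a, dag_b], session=session)

    # Equivalent for a single DAG:
    dag_a.sync_to_db()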

3380 def get_default_view(self): 

3381 """Allow backward compatible jinja2 templates.""" 

3382 if self.default_view is None: 

3383 return airflow_conf.get("webserver", "dag_default_view").lower() 

3384 else: 

3385 return self.default_view 

3386 

3387 @staticmethod 

3388 @provide_session 

3389 def deactivate_unknown_dags(active_dag_ids, session=NEW_SESSION): 

3390 """ 

3391 Given a list of known DAGs, deactivate any other DAGs that are marked as active in the ORM. 

3392 

3393 :param active_dag_ids: list of DAG IDs that are active 

3394 :return: None 

3395 """ 

3396 if not active_dag_ids: 

3397 return 

3398 for dag in session.scalars(select(DagModel).where(~DagModel.dag_id.in_(active_dag_ids))).all(): 

3399 dag.is_active = False 

3400 session.merge(dag) 

3401 session.commit() 

3402 

3403 @staticmethod 

3404 @provide_session 

3405 def deactivate_stale_dags(expiration_date, session=NEW_SESSION): 

3406 """ 

3407 Deactivate any DAGs that were last touched by the scheduler before the expiration date. 

3408 

3409 These DAGs were likely deleted. 

3410 

3411 :param expiration_date: set inactive DAGs that were touched before this time 

3412 :return: None 

3413 """ 

3414 for dag in session.scalars( 

3415 select(DagModel).where(DagModel.last_parsed_time < expiration_date, DagModel.is_active) 

3416 ): 

3417 log.info( 

3418 "Deactivating DAG ID %s since it was last touched by the scheduler at %s", 

3419 dag.dag_id, 

3420 dag.last_parsed_time.isoformat(), 

3421 ) 

3422 dag.is_active = False 

3423 session.merge(dag) 

3424 session.commit() 

3425 

3426 @staticmethod 

3427 @provide_session 

3428 def get_num_task_instances(dag_id, run_id=None, task_ids=None, states=None, session=NEW_SESSION) -> int: 

3429 """ 

3430 Return the number of task instances in the given DAG. 

3431 

3432 :param session: ORM session 

3433 :param dag_id: ID of the DAG to get the task concurrency of 

3434 :param run_id: ID of the DAG run to get the task concurrency of 

3435 :param task_ids: A list of valid task IDs for the given DAG 

3436 :param states: A list of states to filter by if supplied 

3437 :return: The number of running tasks 

3438 """ 

3439 qry = select(func.count(TaskInstance.task_id)).where( 

3440 TaskInstance.dag_id == dag_id, 

3441 ) 

3442 if run_id: 

3443 qry = qry.where( 

3444 TaskInstance.run_id == run_id, 

3445 ) 

3446 if task_ids: 

3447 qry = qry.where( 

3448 TaskInstance.task_id.in_(task_ids), 

3449 ) 

3450 

3451 if states: 

3452 if None in states: 

3453 if all(x is None for x in states): 

3454 qry = qry.where(TaskInstance.state.is_(None)) 

3455 else: 

3456 not_none_states = [state for state in states if state] 

3457 qry = qry.where( 

3458 or_(TaskInstance.state.in_(not_none_states), TaskInstance.state.is_(None)) 

3459 ) 

3460 else: 

3461 qry = qry.where(TaskInstance.state.in_(states)) 

3462 return session.scalar(qry) 

3463 
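A quick hedged example of the helper above; the dag_id and task_ids are hypothetical. Passing `None` among the states counts task instances that have no state yet, per the branch shown in the query construction.

    from airflow.utils.state import TaskInstanceState

    running_count = DAG.get_num_task_instances(
        dag_id="example_dag",
        task_ids=["extract", "load"],
        states=[TaskInstanceState.RUNNING, None],  # include task instances with no state yet
    )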

3464 @classmethod 

3465 def get_serialized_fields(cls): 

3466 """Stringified DAGs and operators contain exactly these fields.""" 

3467 if not cls.__serialized_fields: 

3468 exclusion_list = { 

3469 "parent_dag", 

3470 "schedule_dataset_references", 

3471 "task_outlet_dataset_references", 

3472 "_old_context_manager_dags", 

3473 "safe_dag_id", 

3474 "last_loaded", 

3475 "user_defined_filters", 

3476 "user_defined_macros", 

3477 "partial", 

3478 "params", 

3479 "_pickle_id", 

3480 "_log", 

3481 "task_dict", 

3482 "template_searchpath", 

3483 "sla_miss_callback", 

3484 "on_success_callback", 

3485 "on_failure_callback", 

3486 "template_undefined", 

3487 "jinja_environment_kwargs", 

3488 # has_on_*_callback are only stored if the value is True, as the default is False 

3489 "has_on_success_callback", 

3490 "has_on_failure_callback", 

3491 "auto_register", 

3492 "fail_stop", 

3493 } 

3494 cls.__serialized_fields = frozenset(vars(DAG(dag_id="test"))) - exclusion_list 

3495 return cls.__serialized_fields 

3496 

3497 def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: 

3498 """Return edge information for the given pair of tasks or an empty edge if there is no information.""" 

3499 # Note - older serialized DAGs may not have edge_info being a dict at all 

3500 empty = cast(EdgeInfoType, {}) 

3501 if self.edge_info: 

3502 return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty) 

3503 else: 

3504 return empty 

3505 

3506 def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType): 

3507 """ 

3508 Set the given edge information on the DAG. 

3509 

3510 Note that this will overwrite, rather than merge with, existing info. 

3511 """ 

3512 self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info 

3513 
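A small sketch of the edge-info API above, assuming the usual `{"label": ...}` shape stored by the `Label` edge modifier; the task ids are illustrative.

    dag.set_edge_info("extract", "load", {"label": "validated rows"})
    assert dag.get_edge_info("extract", "load") == {"label": "validated rows"}
    assert dag.get_edge_info("load", "archive") == {}  # no info recorded for this edge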

3514 def validate_schedule_and_params(self): 

3515 """ 

3516 Validate Param values when the DAG has schedule defined. 

3517 

3518 Raise an exception if there are any Params which cannot be resolved by their schema definition. 

3519 """ 

3520 if not self.timetable.can_be_scheduled: 

3521 return 

3522 

3523 try: 

3524 self.params.validate() 

3525 except ParamValidationError as pverr: 

3526 raise AirflowException( 

3527 "DAG is not allowed to define a Schedule, " 

3528 "if there are any required params without default values or default values are not valid." 

3529 ) from pverr 

3530 

3531 def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]: 

3532 """ 

3533 Parse a given link, and verify whether it is a valid URL or a 'mailto' link. 

3534 

3535 Returns an iterator of invalid (owner, link) pairs. 

3536 """ 

3537 for owner, link in self.owner_links.items(): 

3538 result = urlsplit(link) 

3539 if result.scheme == "mailto": 

3540 # netloc does not exist for a 'mailto' link, so we check that the path is parsed 

3541 if not result.path: 

3542 yield owner, link 

3543 elif not result.scheme or not result.netloc: 

3544 yield owner, link 

3545 

3546 

3547class DagTag(Base): 

3548 """A tag name per dag, to allow quick filtering in the DAG view.""" 

3549 

3550 __tablename__ = "dag_tag" 

3551 name = Column(String(TAG_MAX_LEN), primary_key=True) 

3552 dag_id = Column( 

3553 StringID(), 

3554 ForeignKey("dag.dag_id", name="dag_tag_dag_id_fkey", ondelete="CASCADE"), 

3555 primary_key=True, 

3556 ) 

3557 

3558 __table_args__ = (Index("idx_dag_tag_dag_id", dag_id),) 

3559 

3560 def __repr__(self): 

3561 return self.name 

3562 

3563 

3564class DagOwnerAttributes(Base): 

3565 """ 

3566 Table defining different owner attributes. 

3567 

3568 For example, a link for an owner that will be passed as a hyperlink to the "DAGs" view. 

3569 """ 

3570 

3571 __tablename__ = "dag_owner_attributes" 

3572 dag_id = Column( 

3573 StringID(), 

3574 ForeignKey("dag.dag_id", name="dag.dag_id", ondelete="CASCADE"), 

3575 nullable=False, 

3576 primary_key=True, 

3577 ) 

3578 owner = Column(String(500), primary_key=True, nullable=False) 

3579 link = Column(String(500), nullable=False) 

3580 

3581 def __repr__(self): 

3582 return f"<DagOwnerAttributes: dag_id={self.dag_id}, owner={self.owner}, link={self.link}>" 

3583 

3584 @classmethod 

3585 def get_all(cls, session) -> dict[str, dict[str, str]]: 

3586 dag_links: dict = defaultdict(dict) 

3587 for obj in session.scalars(select(cls)): 

3588 dag_links[obj.dag_id].update({obj.owner: obj.link}) 

3589 return dag_links 

3590 

3591 

3592class DagModel(Base): 

3593 """Table containing DAG properties.""" 

3594 

3595 __tablename__ = "dag" 

3596 """ 

3597 These items are stored in the database for state-related information 

3598 """ 

3599 dag_id = Column(StringID(), primary_key=True) 

3600 root_dag_id = Column(StringID()) 

3601 # A DAG can be paused from the UI / DB 

3602 # Set this default value of is_paused based on a configuration value! 

3603 is_paused_at_creation = airflow_conf.getboolean("core", "dags_are_paused_at_creation") 

3604 is_paused = Column(Boolean, default=is_paused_at_creation) 

3605 # Whether the DAG is a subdag 

3606 is_subdag = Column(Boolean, default=False) 

3607 # Whether that DAG was seen on the last DagBag load 

3608 is_active = Column(Boolean, default=False) 

3609 # Last time the scheduler started 

3610 last_parsed_time = Column(UtcDateTime) 

3611 # Last time this DAG was pickled 

3612 last_pickled = Column(UtcDateTime) 

3613 # Time when the DAG last received a refresh signal 

3614 # (e.g. the DAG's "refresh" button was clicked in the web UI) 

3615 last_expired = Column(UtcDateTime) 

3616 # Whether (one of) the schedulers is scheduling this DAG at the moment 

3617 scheduler_lock = Column(Boolean) 

3618 # Foreign key to the latest pickle_id 

3619 pickle_id = Column(Integer) 

3620 # The location of the file containing the DAG object 

3621 # Note: Do not depend on fileloc pointing to a file; in the case of a 

3622 # packaged DAG, it will point to the subpath of the DAG within the 

3623 # associated zip. 

3624 fileloc = Column(String(2000)) 

3625 # The base directory used by Dag Processor that parsed this dag. 

3626 processor_subdir = Column(String(2000), nullable=True) 

3627 # String representing the owners 

3628 owners = Column(String(2000)) 

3629 # Display name of the dag 

3630 _dag_display_property_value = Column("dag_display_name", String(2000), nullable=True) 

3631 # Description of the dag 

3632 description = Column(Text) 

3633 # Default view of the DAG inside the webserver 

3634 default_view = Column(String(25)) 

3635 # Schedule interval 

3636 schedule_interval = Column(Interval) 

3637 # Timetable/Schedule Interval description 

3638 timetable_description = Column(String(1000), nullable=True) 

3639 # Dataset expression based on dataset triggers 

3640 dataset_expression = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True) 

3641 # Tags for view filter 

3642 tags = relationship("DagTag", cascade="all, delete, delete-orphan", backref=backref("dag")) 

3643 # Dag owner links for DAGs view 

3644 dag_owner_links = relationship( 

3645 "DagOwnerAttributes", cascade="all, delete, delete-orphan", backref=backref("dag") 

3646 ) 

3647 

3648 max_active_tasks = Column(Integer, nullable=False) 

3649 max_active_runs = Column(Integer, nullable=True) 

3650 max_consecutive_failed_dag_runs = Column(Integer, nullable=False) 

3651 

3652 has_task_concurrency_limits = Column(Boolean, nullable=False) 

3653 has_import_errors = Column(Boolean(), default=False, server_default="0") 

3654 

3655 # The logical date of the next dag run. 

3656 next_dagrun = Column(UtcDateTime) 

3657 

3658 # Must be either both NULL or both datetime. 

3659 next_dagrun_data_interval_start = Column(UtcDateTime) 

3660 next_dagrun_data_interval_end = Column(UtcDateTime) 

3661 

3662 # Earliest time at which this ``next_dagrun`` can be created. 

3663 next_dagrun_create_after = Column(UtcDateTime) 

3664 

3665 __table_args__ = ( 

3666 Index("idx_root_dag_id", root_dag_id, unique=False), 

3667 Index("idx_next_dagrun_create_after", next_dagrun_create_after, unique=False), 

3668 ) 

3669 

3670 parent_dag = relationship( 

3671 "DagModel", remote_side=[dag_id], primaryjoin=root_dag_id == dag_id, foreign_keys=[root_dag_id] 

3672 ) 

3673 schedule_dataset_references = relationship( 

3674 "DagScheduleDatasetReference", 

3675 back_populates="dag", 

3676 cascade="all, delete, delete-orphan", 

3677 ) 

3678 schedule_datasets = association_proxy("schedule_dataset_references", "dataset") 

3679 task_outlet_dataset_references = relationship( 

3680 "TaskOutletDatasetReference", 

3681 cascade="all, delete, delete-orphan", 

3682 ) 

3683 NUM_DAGS_PER_DAGRUN_QUERY = airflow_conf.getint( 

3684 "scheduler", "max_dagruns_to_create_per_loop", fallback=10 

3685 ) 

3686 

3687 def __init__(self, concurrency=None, **kwargs): 

3688 super().__init__(**kwargs) 

3689 if self.max_active_tasks is None: 

3690 if concurrency: 

3691 warnings.warn( 

3692 "The 'DagModel.concurrency' parameter is deprecated. Please use 'max_active_tasks'.", 

3693 RemovedInAirflow3Warning, 

3694 stacklevel=2, 

3695 ) 

3696 self.max_active_tasks = concurrency 

3697 else: 

3698 self.max_active_tasks = airflow_conf.getint("core", "max_active_tasks_per_dag") 

3699 

3700 if self.max_active_runs is None: 

3701 self.max_active_runs = airflow_conf.getint("core", "max_active_runs_per_dag") 

3702 

3703 if self.max_consecutive_failed_dag_runs is None: 

3704 self.max_consecutive_failed_dag_runs = airflow_conf.getint( 

3705 "core", "max_consecutive_failed_dag_runs_per_dag" 

3706 ) 

3707 

3708 if self.has_task_concurrency_limits is None: 

3709 # Be safe -- this will be updated later once the DAG is parsed 

3710 self.has_task_concurrency_limits = True 

3711 

3712 def __repr__(self): 

3713 return f"<DAG: {self.dag_id}>" 

3714 

3715 @property 

3716 def next_dagrun_data_interval(self) -> DataInterval | None: 

3717 return _get_model_data_interval( 

3718 self, 

3719 "next_dagrun_data_interval_start", 

3720 "next_dagrun_data_interval_end", 

3721 ) 

3722 

3723 @next_dagrun_data_interval.setter 

3724 def next_dagrun_data_interval(self, value: tuple[datetime, datetime] | None) -> None: 

3725 if value is None: 

3726 self.next_dagrun_data_interval_start = self.next_dagrun_data_interval_end = None 

3727 else: 

3728 self.next_dagrun_data_interval_start, self.next_dagrun_data_interval_end = value 

3729 

3730 @property 

3731 def timezone(self): 

3732 return settings.TIMEZONE 

3733 

3734 @staticmethod 

3735 @provide_session 

3736 def get_dagmodel(dag_id: str, session: Session = NEW_SESSION) -> DagModel | None: 

3737 return session.get( 

3738 DagModel, 

3739 dag_id, 

3740 options=[joinedload(DagModel.parent_dag)], 

3741 ) 

3742 

3743 @classmethod 

3744 @internal_api_call 

3745 @provide_session 

3746 def get_current(cls, dag_id: str, session=NEW_SESSION) -> DagModel | DagModelPydantic: 

3747 return session.scalar(select(cls).where(cls.dag_id == dag_id)) 

3748 

3749 @provide_session 

3750 def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False): 

3751 return get_last_dagrun( 

3752 self.dag_id, session=session, include_externally_triggered=include_externally_triggered 

3753 ) 

3754 

3755 def get_is_paused(self, *, session: Session | None = None) -> bool: 

3756 """Provide interface compatibility to 'DAG'.""" 

3757 return self.is_paused 

3758 

3759 def get_is_active(self, *, session: Session | None = None) -> bool: 

3760 """Provide interface compatibility to 'DAG'.""" 

3761 return self.is_active 

3762 

3763 @staticmethod 

3764 @internal_api_call 

3765 @provide_session 

3766 def get_paused_dag_ids(dag_ids: list[str], session: Session = NEW_SESSION) -> set[str]: 

3767 """ 

3768 Given a list of dag_ids, get a set of Paused Dag Ids. 

3769 

3770 :param dag_ids: List of Dag ids 

3771 :param session: ORM Session 

3772 :return: Paused Dag_ids 

3773 """ 

3774 paused_dag_ids = session.execute( 

3775 select(DagModel.dag_id) 

3776 .where(DagModel.is_paused == expression.true()) 

3777 .where(DagModel.dag_id.in_(dag_ids)) 

3778 ) 

3779 

3780 paused_dag_ids = {paused_dag_id for (paused_dag_id,) in paused_dag_ids} 

3781 return paused_dag_ids 

3782 

3783 def get_default_view(self) -> str: 

3784 """Get the Default DAG View, returns the default config value if DagModel does not have a value.""" 

3785 # This is for backwards-compatibility with old dags that don't have None as default_view 

3786 return self.default_view or airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower() 

3787 

3788 @property 

3789 def safe_dag_id(self): 

3790 return self.dag_id.replace(".", "__dot__") 

3791 

3792 @property 

3793 def relative_fileloc(self) -> pathlib.Path | None: 

3794 """File location of the importable dag 'file' relative to the configured DAGs folder.""" 

3795 if self.fileloc is None: 

3796 return None 

3797 path = pathlib.Path(self.fileloc) 

3798 try: 

3799 return path.relative_to(settings.DAGS_FOLDER) 

3800 except ValueError: 

3801 # Not relative to DAGS_FOLDER. 

3802 return path 

3803 

3804 @provide_session 

3805 def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=NEW_SESSION) -> None: 

3806 """ 

3807 Pause/Un-pause a DAG. 

3808 

3809 :param is_paused: Is the DAG paused 

3810 :param including_subdags: whether to include the DAG's subdags 

3811 :param session: session 

3812 """ 

3813 filter_query = [ 

3814 DagModel.dag_id == self.dag_id, 

3815 ] 

3816 if including_subdags: 

3817 filter_query.append(DagModel.root_dag_id == self.dag_id) 

3818 session.execute( 

3819 update(DagModel) 

3820 .where(or_(*filter_query)) 

3821 .values(is_paused=is_paused) 

3822 .execution_options(synchronize_session="fetch") 

3823 ) 

3824 session.commit() 

3825 
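An illustrative way to use the methods above together to pause a DAG programmatically (the dag_id is hypothetical):

    from airflow.models.dag import DagModel

    dag_model = DagModel.get_dagmodel("example_dag")
    if dag_model is not None and not dag_model.get_is_paused():
        dag_model.set_is_paused(True)  # also pauses subdags by default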

3826 @hybrid_property 

3827 def dag_display_name(self) -> str: 

3828 return self._dag_display_property_value or self.dag_id 

3829 

3830 @classmethod 

3831 @internal_api_call 

3832 @provide_session 

3833 def deactivate_deleted_dags( 

3834 cls, 

3835 alive_dag_filelocs: Container[str], 

3836 processor_subdir: str, 

3837 session: Session = NEW_SESSION, 

3838 ) -> None: 

3839 """ 

3840 Set ``is_active=False`` on the DAGs for which the DAG files have been removed. 

3841 

3842 :param alive_dag_filelocs: file paths of alive DAGs 

3843 :param processor_subdir: dag processor subdir 

3844 :param session: ORM Session 

3845 """ 

3846 log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__) 

3847 dag_models = session.scalars( 

3848 select(cls).where( 

3849 cls.fileloc.is_not(None), 

3850 or_( 

3851 cls.processor_subdir.is_(None), 

3852 cls.processor_subdir == processor_subdir, 

3853 ), 

3854 ) 

3855 ) 

3856 

3857 for dag_model in dag_models: 

3858 if dag_model.fileloc not in alive_dag_filelocs: 

3859 dag_model.is_active = False 

3860 

3861 @classmethod 

3862 def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]: 

3863 """ 

3864 Return (and lock) a list of Dag objects that are due to create a new DagRun. 

3865 

3866 This will return a resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query, 

3867 so you should ensure that any scheduling decisions are made in a single transaction -- as soon as the 

3868 transaction is committed it will be unlocked. 

3869 """ 

3870 from airflow.models.serialized_dag import SerializedDagModel 

3871 

3872 def dag_ready(dag_id: str, cond: BaseDataset, statuses: dict) -> bool | None: 

3873 # if dag was serialized before 2.9 and we *just* upgraded, 

3874 # we may be dealing with an old version. In that case, 

3875 # just wait for the dag to be reserialized. 

3876 try: 

3877 return cond.evaluate(statuses) 

3878 except AttributeError: 

3879 log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id) 

3880 return None 

3881 

3882 # this loads all the DDRQ records... may need to limit the number of dags 

3883 all_records = session.scalars(select(DatasetDagRunQueue)).all() 

3884 by_dag = defaultdict(list) 

3885 for r in all_records: 

3886 by_dag[r.target_dag_id].append(r) 

3887 del all_records 

3888 dag_statuses = {} 

3889 for dag_id, records in by_dag.items(): 

3890 dag_statuses[dag_id] = {x.dataset.uri: True for x in records} 

3891 ser_dags = session.scalars( 

3892 select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys())) 

3893 ).all() 

3894 for ser_dag in ser_dags: 

3895 dag_id = ser_dag.dag_id 

3896 statuses = dag_statuses[dag_id] 

3897 if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses): 

3898 del by_dag[dag_id] 

3899 del dag_statuses[dag_id] 

3900 del dag_statuses 

3901 dataset_triggered_dag_info = {} 

3902 for dag_id, records in by_dag.items(): 

3903 times = sorted(x.created_at for x in records) 

3904 dataset_triggered_dag_info[dag_id] = (times[0], times[-1]) 

3905 del by_dag 

3906 dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys()) 

3907 if dataset_triggered_dag_ids: 

3908 exclusion_list = set( 

3909 session.scalars( 

3910 select(DagModel.dag_id) 

3911 .join(DagRun.dag_model) 

3912 .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING))) 

3913 .where(DagModel.dag_id.in_(dataset_triggered_dag_ids)) 

3914 .group_by(DagModel.dag_id) 

3915 .having(func.count() >= func.max(DagModel.max_active_runs)) 

3916 ) 

3917 ) 

3918 if exclusion_list: 

3919 dataset_triggered_dag_ids -= exclusion_list 

3920 dataset_triggered_dag_info = { 

3921 k: v for k, v in dataset_triggered_dag_info.items() if k not in exclusion_list 

3922 } 

3923 

3924 # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs 

3925 query = ( 

3926 select(cls) 

3927 .where( 

3928 cls.is_paused == expression.false(), 

3929 cls.is_active == expression.true(), 

3930 cls.has_import_errors == expression.false(), 

3931 or_( 

3932 cls.next_dagrun_create_after <= func.now(), 

3933 cls.dag_id.in_(dataset_triggered_dag_ids), 

3934 ), 

3935 ) 

3936 .order_by(cls.next_dagrun_create_after) 

3937 .limit(cls.NUM_DAGS_PER_DAGRUN_QUERY) 

3938 ) 

3939 

3940 return ( 

3941 session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)), 

3942 dataset_triggered_dag_info, 

3943 ) 

3944 

3945 def calculate_dagrun_date_fields( 

3946 self, 

3947 dag: DAG, 

3948 last_automated_dag_run: None | datetime | DataInterval, 

3949 ) -> None: 

3950 """ 

3951 Calculate ``next_dagrun`` and ``next_dagrun_create_after``. 

3952 

3953 :param dag: The DAG object 

3954 :param last_automated_dag_run: DataInterval (or datetime) of most recent run of this dag, or none 

3955 if not yet scheduled. 

3956 """ 

3957 last_automated_data_interval: DataInterval | None 

3958 if isinstance(last_automated_dag_run, datetime): 

3959 warnings.warn( 

3960 "Passing a datetime to `DagModel.calculate_dagrun_date_fields` is deprecated. " 

3961 "Provide a data interval instead.", 

3962 RemovedInAirflow3Warning, 

3963 stacklevel=2, 

3964 ) 

3965 last_automated_data_interval = dag.infer_automated_data_interval(last_automated_dag_run) 

3966 else: 

3967 last_automated_data_interval = last_automated_dag_run 

3968 next_dagrun_info = dag.next_dagrun_info(last_automated_data_interval) 

3969 if next_dagrun_info is None: 

3970 self.next_dagrun_data_interval = self.next_dagrun = self.next_dagrun_create_after = None 

3971 else: 

3972 self.next_dagrun_data_interval = next_dagrun_info.data_interval 

3973 self.next_dagrun = next_dagrun_info.logical_date 

3974 self.next_dagrun_create_after = next_dagrun_info.run_after 

3975 

3976 log.info( 

3977 "Setting next_dagrun for %s to %s, run_after=%s", 

3978 dag.dag_id, 

3979 self.next_dagrun, 

3980 self.next_dagrun_create_after, 

3981 ) 

3982 

3983 @provide_session 

3984 def get_dataset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, int | str] | None: 

3985 if self.schedule_interval != "Dataset": 

3986 return None 

3987 return get_dataset_triggered_next_run_info([self.dag_id], session=session)[self.dag_id] 

3988 

3989 

3990# NOTE: Please keep the list of arguments in sync with DAG.__init__. 

3991# Only exception: dag_id here should have a default value, but not in DAG. 

3992def dag( 

3993 dag_id: str = "", 

3994 description: str | None = None, 

3995 schedule: ScheduleArg = NOTSET, 

3996 schedule_interval: ScheduleIntervalArg = NOTSET, 

3997 timetable: Timetable | None = None, 

3998 start_date: datetime | None = None, 

3999 end_date: datetime | None = None, 

4000 full_filepath: str | None = None, 

4001 template_searchpath: str | Iterable[str] | None = None, 

4002 template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, 

4003 user_defined_macros: dict | None = None, 

4004 user_defined_filters: dict | None = None, 

4005 default_args: dict | None = None, 

4006 concurrency: int | None = None, 

4007 max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"), 

4008 max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"), 

4009 max_consecutive_failed_dag_runs: int = airflow_conf.getint( 

4010 "core", "max_consecutive_failed_dag_runs_per_dag" 

4011 ), 

4012 dagrun_timeout: timedelta | None = None, 

4013 sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None, 

4014 default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(), 

4015 orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"), 

4016 catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"), 

4017 on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

4018 on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

4019 doc_md: str | None = None, 

4020 params: abc.MutableMapping | None = None, 

4021 access_control: dict | None = None, 

4022 is_paused_upon_creation: bool | None = None, 

4023 jinja_environment_kwargs: dict | None = None, 

4024 render_template_as_native_obj: bool = False, 

4025 tags: list[str] | None = None, 

4026 owner_links: dict[str, str] | None = None, 

4027 auto_register: bool = True, 

4028 fail_stop: bool = False, 

4029 dag_display_name: str | None = None, 

4030) -> Callable[[Callable], Callable[..., DAG]]: 

4031 """ 

4032 Python dag decorator which wraps a function into an Airflow DAG. 

4033 

4034 The decorated function's args/kwargs become DAG params, so the decorator can be used to parameterize DAGs. 

4035 

4036 :param dag_args: Arguments for DAG object 

4037 :param dag_kwargs: Kwargs for DAG object. 

4038 """ 

4039 

4040 def wrapper(f: Callable) -> Callable[..., DAG]: 

4041 @functools.wraps(f) 

4042 def factory(*args, **kwargs): 

4043 # Generate signature for decorated function and bind the arguments when called 

4044 # we do this to extract parameters, so we can annotate them on the DAG object. 

4045 # In addition, this fails if we are missing any args/kwargs with TypeError as expected. 

4046 f_sig = signature(f).bind(*args, **kwargs) 

4047 # Apply defaults to capture default values if set. 

4048 f_sig.apply_defaults() 

4049 

4050 # Initialize DAG with bound arguments 

4051 with DAG( 

4052 dag_id or f.__name__, 

4053 description=description, 

4054 schedule_interval=schedule_interval, 

4055 timetable=timetable, 

4056 start_date=start_date, 

4057 end_date=end_date, 

4058 full_filepath=full_filepath, 

4059 template_searchpath=template_searchpath, 

4060 template_undefined=template_undefined, 

4061 user_defined_macros=user_defined_macros, 

4062 user_defined_filters=user_defined_filters, 

4063 default_args=default_args, 

4064 concurrency=concurrency, 

4065 max_active_tasks=max_active_tasks, 

4066 max_active_runs=max_active_runs, 

4067 max_consecutive_failed_dag_runs=max_consecutive_failed_dag_runs, 

4068 dagrun_timeout=dagrun_timeout, 

4069 sla_miss_callback=sla_miss_callback, 

4070 default_view=default_view, 

4071 orientation=orientation, 

4072 catchup=catchup, 

4073 on_success_callback=on_success_callback, 

4074 on_failure_callback=on_failure_callback, 

4075 doc_md=doc_md, 

4076 params=params, 

4077 access_control=access_control, 

4078 is_paused_upon_creation=is_paused_upon_creation, 

4079 jinja_environment_kwargs=jinja_environment_kwargs, 

4080 render_template_as_native_obj=render_template_as_native_obj, 

4081 tags=tags, 

4082 schedule=schedule, 

4083 owner_links=owner_links, 

4084 auto_register=auto_register, 

4085 fail_stop=fail_stop, 

4086 dag_display_name=dag_display_name, 

4087 ) as dag_obj: 

4088 # Set DAG documentation from function documentation if it exists and doc_md is not set. 

4089 if f.__doc__ and not dag_obj.doc_md: 

4090 dag_obj.doc_md = f.__doc__ 

4091 

4092 # Generate DAGParam for each function arg/kwarg and replace it for calling the function. 

4093 # All args/kwargs for function will be DAGParam object and replaced on execution time. 

4094 f_kwargs = {} 

4095 for name, value in f_sig.arguments.items(): 

4096 f_kwargs[name] = dag_obj.param(name, value) 

4097 

4098 # set file location to caller source path 

4099 back = sys._getframe().f_back 

4100 dag_obj.fileloc = back.f_code.co_filename if back else "" 

4101 

4102 # Invoke function to create operators in the DAG scope. 

4103 f(**f_kwargs) 

4104 

4105 # Return dag object such that it's accessible in Globals. 

4106 return dag_obj 

4107 

4108 # Ensure that warnings from inside DAG() are emitted from the caller, not here 

4109 fixup_decorator_warning_stack(factory) 

4110 return factory 

4111 

4112 return wrapper 

4113 
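A hedged usage sketch for the decorator defined above (it is also re-exported as `airflow.decorators.dag`); the pipeline name and values are illustrative. Note how the function's argument becomes a DAG param that is resolved at run time when passed into a task.

    import pendulum
    from airflow.decorators import dag, task

    @dag(
        schedule=None,
        start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
        catchup=False,
        tags=["example"],
    )
    def example_taskflow(multiplier: int = 2):
        @task
        def times(value: int, factor: int) -> int:
            return value * factor

        times(21, multiplier)  # `multiplier` is a DagParam here, resolved at run time

    example_taskflow_dag = example_taskflow()  # the factory returns the DAG object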

4114 

4115STATICA_HACK = True 

4116globals()["kcah_acitats"[::-1].upper()] = False 

4117if STATICA_HACK: # pragma: no cover 

4118 from airflow.models.serialized_dag import SerializedDagModel 

4119 

4120 DagModel.serialized_dag = relationship(SerializedDagModel) 

4121 """:sphinx-autoapi-skip:""" 

4122 

4123 

4124class DagContext: 

4125 """ 

4126 DAG context is used to keep the current DAG when DAG is used as ContextManager. 

4127 

4128 You can use DAG as context: 

4129 

4130 .. code-block:: python 

4131 

4132 with DAG( 

4133 dag_id="example_dag", 

4134 default_args=default_args, 

4135 schedule="0 0 * * *", 

4136 dagrun_timeout=timedelta(minutes=60), 

4137 ) as dag: 

4138 ... 

4139 

4140 If you do this, the context stores the DAG and, whenever a new task is created, it will use 

4141 the stored DAG as the parent DAG. 

4142 

4143 """ 

4144 

4145 _context_managed_dags: deque[DAG] = deque() 

4146 autoregistered_dags: set[tuple[DAG, ModuleType]] = set() 

4147 current_autoregister_module_name: str | None = None 

4148 

4149 @classmethod 

4150 def push_context_managed_dag(cls, dag: DAG): 

4151 cls._context_managed_dags.appendleft(dag) 

4152 

4153 @classmethod 

4154 def pop_context_managed_dag(cls) -> DAG | None: 

4155 dag = cls._context_managed_dags.popleft() 

4156 

4157 # In a few cases around serialization we explicitly push None in to the stack 

4158 if cls.current_autoregister_module_name is not None and dag and dag.auto_register: 

4159 mod = sys.modules[cls.current_autoregister_module_name] 

4160 cls.autoregistered_dags.add((dag, mod)) 

4161 

4162 return dag 

4163 

4164 @classmethod 

4165 def get_current_dag(cls) -> DAG | None: 

4166 try: 

4167 return cls._context_managed_dags[0] 

4168 except IndexError: 

4169 return None 

4170 

4171 

4172def _run_inline_trigger(trigger): 

4173 async def _run_inline_trigger_main(): 

4174 async for event in trigger.run(): 

4175 return event 

4176 

4177 return asyncio.run(_run_inline_trigger_main()) 

4178 

4179 

4180def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session): 

4181 """ 

4182 Run a single task instance, and push result to Xcom for downstream tasks. 

4183 

4184 Bypasses a lot of extra steps used in `task.run` to keep our local test run as fast as 

4185 possible. This function is only meant to be used as a helper for `dag.test`. 

4186 

4187 Args: 

4188 ti: TaskInstance to run 

4189 """ 

4190 log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index) 

4191 while True: 

4192 try: 

4193 log.info("[DAG TEST] running task %s", ti) 

4194 ti._run_raw_task(session=session, raise_on_defer=inline_trigger) 

4195 break 

4196 except TaskDeferred as e: 

4197 log.info("[DAG TEST] running trigger in line") 

4198 event = _run_inline_trigger(e.trigger) 

4199 ti.next_method = e.method_name 

4200 ti.next_kwargs = {"event": event.payload} if event else e.kwargs 

4201 log.info("[DAG TEST] Trigger completed") 

4202 session.merge(ti) 

4203 session.commit() 

4204 log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index) 

4205 

4206 

4207def _get_or_create_dagrun( 

4208 dag: DAG, 

4209 conf: dict[Any, Any] | None, 

4210 start_date: datetime, 

4211 execution_date: datetime, 

4212 run_id: str, 

4213 session: Session, 

4214 data_interval: tuple[datetime, datetime] | None = None, 

4215) -> DagRun: 

4216 """Create a DAG run, replacing an existing instance if needed to prevent collisions. 

4217 

4218 This function is only meant to be used by :meth:`DAG.test` as a helper function. 

4219 

4220 :param dag: DAG to be used to find run. 

4221 :param conf: Configuration to pass to newly created run. 

4222 :param start_date: Start date of new run. 

4223 :param execution_date: Logical date for finding an existing run. 

4224 :param run_id: Run ID for the new DAG run. 

4225 

4226 :return: The newly created DAG run. 

4227 """ 

4228 log.info("dagrun id: %s", dag.dag_id) 

4229 dr: DagRun = session.scalar( 

4230 select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date) 

4231 ) 

4232 if dr: 

4233 session.delete(dr) 

4234 session.commit() 

4235 dr = dag.create_dagrun( 

4236 state=DagRunState.RUNNING, 

4237 execution_date=execution_date, 

4238 run_id=run_id, 

4239 start_date=start_date or execution_date, 

4240 session=session, 

4241 conf=conf, 

4242 data_interval=data_interval, 

4243 ) 

4244 log.info("created dagrun %s", dr) 

4245 return dr