Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/dag.py: 32%

1516 statements  

coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections 

21import copy 

22import functools 

23import itertools 

24import logging 

25import os 

26import pathlib 

27import pickle 

28import re 

29import sys 

30import traceback 

31import warnings 

32import weakref 

33from collections import deque 

34from datetime import datetime, timedelta 

35from inspect import signature 

36from typing import ( 

37 TYPE_CHECKING, 

38 Any, 

39 Callable, 

40 Collection, 

41 Deque, 

42 Iterable, 

43 Iterator, 

44 List, 

45 Sequence, 

46 Union, 

47 cast, 

48 overload, 

49) 

50from urllib.parse import urlsplit 

51 

52import jinja2 

53import pendulum 

54from dateutil.relativedelta import relativedelta 

55from pendulum.tz.timezone import Timezone 

56from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, and_, case, func, not_, or_ 

57from sqlalchemy.ext.associationproxy import association_proxy 

58from sqlalchemy.orm import backref, joinedload, relationship 

59from sqlalchemy.orm.query import Query 

60from sqlalchemy.orm.session import Session 

61from sqlalchemy.sql import expression 

62 

63import airflow.templates 

64from airflow import settings, utils 

65from airflow.compat.functools import cached_property 

66from airflow.configuration import conf, secrets_backend_list 

67from airflow.exceptions import ( 

68 AirflowDagInconsistent, 

69 AirflowException, 

70 AirflowSkipException, 

71 DuplicateTaskIdFound, 

72 RemovedInAirflow3Warning, 

73 TaskNotFound, 

74) 

75from airflow.models.abstractoperator import AbstractOperator 

76from airflow.models.base import Base, StringID 

77from airflow.models.dagcode import DagCode 

78from airflow.models.dagpickle import DagPickle 

79from airflow.models.dagrun import DagRun 

80from airflow.models.operator import Operator 

81from airflow.models.param import DagParam, ParamsDict 

82from airflow.models.taskinstance import Context, TaskInstance, TaskInstanceKey, clear_task_instances 

83from airflow.secrets.local_filesystem import LocalFilesystemBackend 

84from airflow.security import permissions 

85from airflow.stats import Stats 

86from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable 

87from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable 

88from airflow.timetables.simple import DatasetTriggeredTimetable, NullTimetable, OnceTimetable 

89from airflow.typing_compat import Literal 

90from airflow.utils import timezone 

91from airflow.utils.dag_cycle_tester import check_cycle 

92from airflow.utils.dates import cron_presets, date_range as utils_date_range 

93from airflow.utils.decorators import fixup_decorator_warning_stack 

94from airflow.utils.file import correct_maybe_zipped 

95from airflow.utils.helpers import at_most_one, exactly_one, validate_key 

96from airflow.utils.log.logging_mixin import LoggingMixin 

97from airflow.utils.session import NEW_SESSION, provide_session 

98from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks 

99from airflow.utils.state import DagRunState, State, TaskInstanceState 

100from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType 

101 

102if TYPE_CHECKING: 

103 from types import ModuleType 

104 

105 from airflow.datasets import Dataset 

106 from airflow.decorators import TaskDecoratorCollection 

107 from airflow.models.dagbag import DagBag 

108 from airflow.models.slamiss import SlaMiss 

109 from airflow.utils.task_group import TaskGroup 

110 

111 

112log = logging.getLogger(__name__) 

113 

114DEFAULT_VIEW_PRESETS = ["grid", "graph", "duration", "gantt", "landing_times"] 

115ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"] 

116 

117TAG_MAX_LEN = 100 

118 

119DagStateChangeCallback = Callable[[Context], None] 

120ScheduleInterval = Union[None, str, timedelta, relativedelta] 

121 

122# FIXME: Ideally this should be Union[Literal[NOTSET], ScheduleInterval], 

123# but Mypy cannot handle that right now. Track PEP 661 for progress. 

124# See also: https://discuss.python.org/t/9126/7 

125ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval] 

126ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, Collection["Dataset"]] 

127 

128SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None] 

129 

130 

131# Backward compatibility: If neither schedule_interval nor timetable is 

132# *provided by the user*, default to a one-day interval. 

133DEFAULT_SCHEDULE_INTERVAL = timedelta(days=1) 

134 

135 

136class InconsistentDataInterval(AirflowException): 

137 """Exception raised when a model populates data interval fields incorrectly. 

138 

139 The data interval fields should either both be None (for runs scheduled 

140 prior to AIP-39), or both be datetime (for runs scheduled after AIP-39 is 

141 implemented). This is raised if exactly one of the fields is None. 

142 """ 

143 

144 _template = ( 

145 "Inconsistent {cls}: {start[0]}={start[1]!r}, {end[0]}={end[1]!r}, " 

146 "they must be either both None or both datetime" 

147 ) 

148 

149 def __init__(self, instance: Any, start_field_name: str, end_field_name: str) -> None: 

150 self._class_name = type(instance).__name__ 

151 self._start_field = (start_field_name, getattr(instance, start_field_name)) 

152 self._end_field = (end_field_name, getattr(instance, end_field_name)) 

153 

154 def __str__(self) -> str: 

155 return self._template.format(cls=self._class_name, start=self._start_field, end=self._end_field) 

156 

157 

158def _get_model_data_interval( 

159 instance: Any, 

160 start_field_name: str, 

161 end_field_name: str, 

162) -> DataInterval | None: 

163 start = timezone.coerce_datetime(getattr(instance, start_field_name)) 

164 end = timezone.coerce_datetime(getattr(instance, end_field_name)) 

165 if start is None: 

166 if end is not None: 

167 raise InconsistentDataInterval(instance, start_field_name, end_field_name) 

168 return None 

169 elif end is None: 

170 raise InconsistentDataInterval(instance, start_field_name, end_field_name) 

171 return DataInterval(start, end) 

172 
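A minimal sketch of how the helper above behaves, using a ``SimpleNamespace`` stand-in for a model instance (an illustrative assumption, not part of this module): both fields set yields a ``DataInterval``, both ``None`` yields ``None``, and a mixed pair raises ``InconsistentDataInterval``.

    from datetime import datetime, timezone as dt_tz
    from types import SimpleNamespace

    run = SimpleNamespace(
        data_interval_start=datetime(2022, 1, 1, tzinfo=dt_tz.utc),
        data_interval_end=datetime(2022, 1, 2, tzinfo=dt_tz.utc),
    )
    _get_model_data_interval(run, "data_interval_start", "data_interval_end")
    # -> DataInterval(start=2022-01-01T00:00:00+00:00, end=2022-01-02T00:00:00+00:00)

    run.data_interval_end = None
    _get_model_data_interval(run, "data_interval_start", "data_interval_end")
    # -> raises InconsistentDataInterval, because exactly one of the two fields is None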

173 

174def create_timetable(interval: ScheduleIntervalArg, timezone: Timezone) -> Timetable: 

175 """Create a Timetable instance from a ``schedule_interval`` argument.""" 

176 if interval is NOTSET: 

177 return DeltaDataIntervalTimetable(DEFAULT_SCHEDULE_INTERVAL) 

178 if interval is None: 

179 return NullTimetable() 

180 if interval == "@once": 

181 return OnceTimetable() 

182 if isinstance(interval, (timedelta, relativedelta)): 

183 return DeltaDataIntervalTimetable(interval) 

184 if isinstance(interval, str): 

185 return CronDataIntervalTimetable(interval, timezone) 

186 raise ValueError(f"{interval!r} is not a valid schedule_interval.") 

187 
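A brief sketch of how ``create_timetable`` maps legacy ``schedule_interval`` values onto timetable classes, assuming an environment where this module and ``pendulum`` are importable.

    from datetime import timedelta

    import pendulum

    utc = pendulum.timezone("UTC")
    create_timetable(NOTSET, utc)               # DeltaDataIntervalTimetable(timedelta(days=1)), the default
    create_timetable(None, utc)                 # NullTimetable() -- never scheduled automatically
    create_timetable("@once", utc)              # OnceTimetable() -- a single run
    create_timetable("0 0 * * *", utc)          # CronDataIntervalTimetable, daily at midnight
    create_timetable(timedelta(hours=6), utc)   # DeltaDataIntervalTimetable with a 6-hour interval
    create_timetable(object(), utc)             # raises ValueError -- not a valid schedule_interval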

188 

189def get_last_dagrun(dag_id, session, include_externally_triggered=False): 

190 """ 

191 Returns the last dag run for a dag, None if there was none. 

192 Last dag run can be any type of run, e.g. scheduled or backfilled. 

193 Overridden DagRuns are ignored. 

194 """ 

195 DR = DagRun 

196 query = session.query(DR).filter(DR.dag_id == dag_id) 

197 if not include_externally_triggered: 

198 query = query.filter(DR.external_trigger == expression.false()) 

199 query = query.order_by(DR.execution_date.desc()) 

200 return query.first() 

201 
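A usage sketch for the helper above; ``create_session`` comes from ``airflow.utils.session`` and the DAG id is a hypothetical example.

    from airflow.utils.session import create_session

    with create_session() as session:
        last_run = get_last_dagrun("example_dag", session=session)
        if last_run is not None:
            # Scheduled and backfilled runs are considered; externally triggered
            # runs are excluded unless include_externally_triggered=True.
            print(last_run.run_id, last_run.execution_date)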

202 

203def get_dataset_triggered_next_run_info( 

204 dag_ids: list[str], *, session: Session 

205) -> dict[str, dict[str, int | str]]: 

206 """ 

207 Given a list of dag_ids, get a string representing how close any dataset-triggered DAGs are to 

208 their next run, e.g. "1 of 2 datasets updated" 

209 """ 

210 from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ, DatasetModel 

211 

212 return { 

213 x.dag_id: { 

214 "uri": x.uri, 

215 "ready": x.ready, 

216 "total": x.total, 

217 } 

218 for x in session.query( 

219 DagScheduleDatasetReference.dag_id, 

220 # This is a dirty hack to workaround group by requiring an aggregate, since grouping by dataset 

221 # is not what we want to do here...but it works 

222 case((func.count() == 1, func.max(DatasetModel.uri)), else_="").label("uri"), 

223 func.count().label("total"), 

224 func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)).label("ready"), 

225 ) 

226 .join( 

227 DDRQ, 

228 and_( 

229 DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id, 

230 DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id, 

231 ), 

232 isouter=True, 

233 ) 

234 .join( 

235 DatasetModel, 

236 DatasetModel.id == DagScheduleDatasetReference.dataset_id, 

237 ) 

238 .group_by( 

239 DagScheduleDatasetReference.dag_id, 

240 ) 

241 .filter(DagScheduleDatasetReference.dag_id.in_(dag_ids)) 

242 .all() 

243 } 

244 
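An illustrative call and a hypothetical result, showing the shape the web UI turns into text such as "1 of 2 datasets updated"; the DAG ids, URIs, and counts are assumptions.

    from airflow.utils.session import create_session

    with create_session() as session:
        info = get_dataset_triggered_next_run_info(["dag_a", "dag_b"], session=session)
    # Possible shape of the result:
    # {
    #     "dag_a": {"uri": "s3://bucket/data.csv", "ready": 1, "total": 1},
    #     "dag_b": {"uri": "", "ready": 1, "total": 2},  # uri is blank when more than one dataset is involved
    # }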

245 

246@functools.total_ordering 

247class DAG(LoggingMixin): 

248 """ 

249 A dag (directed acyclic graph) is a collection of tasks with directional 

250 dependencies. A dag also has a schedule, a start date and an end date 

251 (optional). For each schedule (say daily or hourly), the DAG needs to run 

252 each individual task as its dependencies are met. Certain tasks have 

253 the property of depending on their own past, meaning that they can't run 

254 until their previous schedule (and upstream tasks) are completed. 

255 

256 DAGs essentially act as namespaces for tasks. A task_id can only be 

257 added once to a DAG. 

258 

259 Note that if you plan to use time zones all the dates provided should be pendulum 

260 dates. See :ref:`timezone_aware_dags`. 

261 

262 .. versionadded:: 2.4 

263 The *schedule* argument to specify either time-based scheduling logic 

264 (timetable), or dataset-driven triggers. 

265 

266 .. deprecated:: 2.4 

267 The arguments *schedule_interval* and *timetable*. Their functionalities 

268 are merged into the new *schedule* argument. 

269 

270 :param dag_id: The id of the DAG; must consist exclusively of alphanumeric 

271 characters, dashes, dots and underscores (all ASCII) 

272 :param description: The description for the DAG to e.g. be shown on the webserver 

273 :param schedule: Defines the rules according to which DAG runs are scheduled. Can 

274 accept cron string, timedelta object, Timetable, or list of Dataset objects. 

275 See also :doc:`/howto/timetable`. 

276 :param start_date: The timestamp from which the scheduler will 

277 attempt to backfill 

278 :param end_date: A date beyond which your DAG won't run, leave to None 

279 for open ended scheduling 

280 :param template_searchpath: This list of folders (non relative) 

281 defines where jinja will look for your templates. Order matters. 

282 Note that jinja/airflow includes the path of your DAG file by 

283 default 

284 :param template_undefined: Template undefined type. 

285 :param user_defined_macros: a dictionary of macros that will be exposed 

286 in your jinja templates. For example, passing ``dict(foo='bar')`` 

287 to this argument allows you to ``{{ foo }}`` in all jinja 

288 templates related to this DAG. Note that you can pass any 

289 type of object here. 

290 :param user_defined_filters: a dictionary of filters that will be exposed 

291 in your jinja templates. For example, passing 

292 ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows 

293 you to ``{{ 'world' | hello }}`` in all jinja templates related to 

294 this DAG. 

295 :param default_args: A dictionary of default parameters to be used 

296 as constructor keyword parameters when initialising operators. 

297 Note that operators have the same hook, and precede those defined 

298 here, meaning that if your dict contains `'depends_on_past': True` 

299 here and `'depends_on_past': False` in the operator's call 

300 `default_args`, the actual value will be `False`. 

301 :param params: a dictionary of DAG level parameters that are made 

302 accessible in templates, namespaced under `params`. These 

303 params can be overridden at the task level. 

304 :param max_active_tasks: the number of task instances allowed to run 

305 concurrently 

306 :param max_active_runs: maximum number of active DAG runs, beyond this 

307 number of DAG runs in a running state, the scheduler won't create 

308 new active DAG runs 

309 :param dagrun_timeout: specify how long a DagRun should be up before 

310 timing out / failing, so that new DagRuns can be created. The timeout 

311 is only enforced for scheduled DagRuns. 

312 :param sla_miss_callback: specify a function or list of functions to call when reporting SLA 

313 timeouts. See :ref:`sla_miss_callback<concepts:sla_miss_callback>` for 

314 more information about the function signature and parameters that are 

315 passed to the callback. 

316 :param default_view: Specify DAG default view (grid, graph, duration, 

317 gantt, landing_times), default grid 

318 :param orientation: Specify DAG orientation in graph view (LR, TB, RL, BT), default LR 

319 :param catchup: Perform scheduler catchup (or only run latest)? Defaults to True 

320 :param on_failure_callback: A function or list of functions to be called when a DagRun of this dag fails. 

321 A context dictionary is passed as a single parameter to this function. 

322 :param on_success_callback: Much like the ``on_failure_callback`` except 

323 that it is executed when the dag succeeds. 

324 :param access_control: Specify optional DAG-level actions, e.g., 

325 "{'role1': {'can_read'}, 'role2': {'can_read', 'can_edit', 'can_delete'}}" 

326 :param is_paused_upon_creation: Specifies if the dag is paused when created for the first time. 

327 If the dag exists already, this flag will be ignored. If this optional parameter 

328 is not specified, the global config setting will be used. 

329 :param jinja_environment_kwargs: additional configuration options to be passed to Jinja 

330 ``Environment`` for template rendering 

331 

332 **Example**: to avoid Jinja from removing a trailing newline from template strings :: 

333 

334 DAG(dag_id='my-dag', 

335 jinja_environment_kwargs={ 

336 'keep_trailing_newline': True, 

337 # some other jinja2 Environment options here 

338 } 

339 ) 

340 

341 **See**: `Jinja Environment documentation 

342 <https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.Environment>`_ 

343 

344 :param render_template_as_native_obj: If True, uses a Jinja ``NativeEnvironment`` 

345 to render templates as native Python types. If False, a Jinja 

346 ``Environment`` is used to render templates as string values. 

347 :param tags: List of tags to help filtering DAGs in the UI. 

348 :param owner_links: Dict of owners and their links, that will be clickable on the DAGs view UI. 

349 Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link. 

350 e.g: {"dag_owner": "https://airflow.apache.org/"} 

351 :param auto_register: Automatically register this DAG when it is used in a ``with`` block 

352 """ 

353 

354 _comps = { 

355 "dag_id", 

356 "task_ids", 

357 "parent_dag", 

358 "start_date", 

359 "end_date", 

360 "schedule_interval", 

361 "fileloc", 

362 "template_searchpath", 

363 "last_loaded", 

364 } 

365 

366 __serialized_fields: frozenset[str] | None = None 

367 

368 fileloc: str 

369 """ 

370 File path that needs to be imported to load this DAG or subdag. 

371 

372 This may not be an actual file on disk in the case when this DAG is loaded 

373 from a ZIP file or other DAG distribution format. 

374 """ 

375 

376 parent_dag: DAG | None = None # Gets set when DAGs are loaded 

377 

378 # NOTE: When updating arguments here, please also keep arguments in @dag() 

379 # below in sync. (Search for 'def dag(' in this file.) 

380 def __init__( 

381 self, 

382 dag_id: str, 

383 description: str | None = None, 

384 schedule: ScheduleArg = NOTSET, 

385 schedule_interval: ScheduleIntervalArg = NOTSET, 

386 timetable: Timetable | None = None, 

387 start_date: datetime | None = None, 

388 end_date: datetime | None = None, 

389 full_filepath: str | None = None, 

390 template_searchpath: str | Iterable[str] | None = None, 

391 template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, 

392 user_defined_macros: dict | None = None, 

393 user_defined_filters: dict | None = None, 

394 default_args: dict | None = None, 

395 concurrency: int | None = None, 

396 max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"), 

397 max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"), 

398 dagrun_timeout: timedelta | None = None, 

399 sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None, 

400 default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(), 

401 orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"), 

402 catchup: bool = conf.getboolean("scheduler", "catchup_by_default"), 

403 on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

404 on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

405 doc_md: str | None = None, 

406 params: dict | None = None, 

407 access_control: dict | None = None, 

408 is_paused_upon_creation: bool | None = None, 

409 jinja_environment_kwargs: dict | None = None, 

410 render_template_as_native_obj: bool = False, 

411 tags: list[str] | None = None, 

412 owner_links: dict[str, str] | None = None, 

413 auto_register: bool = True, 

414 ): 

415 from airflow.utils.task_group import TaskGroup 

416 

417 if tags and any(len(tag) > TAG_MAX_LEN for tag in tags): 

418 raise AirflowException(f"tag cannot be longer than {TAG_MAX_LEN} characters") 

419 

420 self.owner_links = owner_links if owner_links else {} 

421 self.user_defined_macros = user_defined_macros 

422 self.user_defined_filters = user_defined_filters 

423 if default_args and not isinstance(default_args, dict): 

424 raise TypeError("default_args must be a dict") 

425 self.default_args = copy.deepcopy(default_args or {}) 

426 params = params or {} 

427 

428 # merging potentially conflicting default_args['params'] into params 

429 if "params" in self.default_args: 

430 params.update(self.default_args["params"]) 

431 del self.default_args["params"] 

432 

433 # check self.params and convert them into ParamsDict 

434 self.params = ParamsDict(params) 

435 

436 if full_filepath: 

437 warnings.warn( 

438 "Passing full_filepath to DAG() is deprecated and has no effect", 

439 RemovedInAirflow3Warning, 

440 stacklevel=2, 

441 ) 

442 

443 validate_key(dag_id) 

444 

445 self._dag_id = dag_id 

446 if concurrency: 

447 # TODO: Remove in Airflow 3.0 

448 warnings.warn( 

449 "The 'concurrency' parameter is deprecated. Please use 'max_active_tasks'.", 

450 RemovedInAirflow3Warning, 

451 stacklevel=2, 

452 ) 

453 max_active_tasks = concurrency 

454 self._max_active_tasks = max_active_tasks 

455 self._pickle_id: int | None = None 

456 

457 self._description = description 

458 # set file location to caller source path 

459 back = sys._getframe().f_back 

460 self.fileloc = back.f_code.co_filename if back else "" 

461 self.task_dict: dict[str, Operator] = {} 

462 

463 # set timezone from start_date 

464 tz = None 

465 if start_date and start_date.tzinfo: 

466 tzinfo = None if start_date.tzinfo else settings.TIMEZONE 

467 tz = pendulum.instance(start_date, tz=tzinfo).timezone 

468 elif "start_date" in self.default_args and self.default_args["start_date"]: 

469 date = self.default_args["start_date"] 

470 if not isinstance(date, datetime): 

471 date = timezone.parse(date) 

472 self.default_args["start_date"] = date 

473 start_date = date 

474 

475 tzinfo = None if date.tzinfo else settings.TIMEZONE 

476 tz = pendulum.instance(date, tz=tzinfo).timezone 

477 self.timezone = tz or settings.TIMEZONE 

478 

479 # Apply the timezone we settled on to end_date if it wasn't supplied 

480 if "end_date" in self.default_args and self.default_args["end_date"]: 

481 if isinstance(self.default_args["end_date"], str): 

482 self.default_args["end_date"] = timezone.parse( 

483 self.default_args["end_date"], timezone=self.timezone 

484 ) 

485 

486 self.start_date = timezone.convert_to_utc(start_date) 

487 self.end_date = timezone.convert_to_utc(end_date) 

488 

489 # also convert tasks 

490 if "start_date" in self.default_args: 

491 self.default_args["start_date"] = timezone.convert_to_utc(self.default_args["start_date"]) 

492 if "end_date" in self.default_args: 

493 self.default_args["end_date"] = timezone.convert_to_utc(self.default_args["end_date"]) 

494 

495 # sort out DAG's scheduling behavior 

496 scheduling_args = [schedule_interval, timetable, schedule] 

497 if not at_most_one(*scheduling_args): 

498 raise ValueError("At most one allowed for args 'schedule_interval', 'timetable', and 'schedule'.") 

499 if schedule_interval is not NOTSET: 

500 warnings.warn( 

501 "Param `schedule_interval` is deprecated and will be removed in a future release. " 

502 "Please use `schedule` instead. ", 

503 RemovedInAirflow3Warning, 

504 stacklevel=2, 

505 ) 

506 if timetable is not None: 

507 warnings.warn( 

508 "Param `timetable` is deprecated and will be removed in a future release. " 

509 "Please use `schedule` instead. ", 

510 RemovedInAirflow3Warning, 

511 stacklevel=2, 

512 ) 

513 

514 self.timetable: Timetable 

515 self.schedule_interval: ScheduleInterval 

516 self.dataset_triggers: Collection[Dataset] = [] 

517 

518 if isinstance(schedule, Collection) and not isinstance(schedule, str): 

519 from airflow.datasets import Dataset 

520 

521 if not all(isinstance(x, Dataset) for x in schedule): 

522 raise ValueError("All elements in 'schedule' should be datasets") 

523 self.dataset_triggers = list(schedule) 

524 elif isinstance(schedule, Timetable): 

525 timetable = schedule 

526 elif schedule is not NOTSET: 

527 schedule_interval = schedule 

528 

529 if self.dataset_triggers: 

530 self.timetable = DatasetTriggeredTimetable() 

531 self.schedule_interval = self.timetable.summary 

532 elif timetable: 

533 self.timetable = timetable 

534 self.schedule_interval = self.timetable.summary 

535 else: 

536 if isinstance(schedule_interval, ArgNotSet): 

537 schedule_interval = DEFAULT_SCHEDULE_INTERVAL 

538 self.schedule_interval = schedule_interval 

539 self.timetable = create_timetable(schedule_interval, self.timezone) 

540 

541 if isinstance(template_searchpath, str): 

542 template_searchpath = [template_searchpath] 

543 self.template_searchpath = template_searchpath 

544 self.template_undefined = template_undefined 

545 self.last_loaded = timezone.utcnow() 

546 self.safe_dag_id = dag_id.replace(".", "__dot__") 

547 self.max_active_runs = max_active_runs 

548 self.dagrun_timeout = dagrun_timeout 

549 self.sla_miss_callback = sla_miss_callback 

550 if default_view in DEFAULT_VIEW_PRESETS: 

551 self._default_view: str = default_view 

552 elif default_view == "tree": 

553 warnings.warn( 

554 "`default_view` of 'tree' has been renamed to 'grid' -- please update your DAG", 

555 RemovedInAirflow3Warning, 

556 stacklevel=2, 

557 ) 

558 self._default_view = "grid" 

559 else: 

560 raise AirflowException( 

561 f"Invalid value of dag.default_view: only supports " 

562 f"{DEFAULT_VIEW_PRESETS}, but got {default_view}" 

563 ) 

564 if orientation in ORIENTATION_PRESETS: 

565 self.orientation = orientation 

566 else: 

567 raise AirflowException( 

568 f"Invalid value of dag.orientation: only supports " 

569 f"{ORIENTATION_PRESETS}, but got {orientation}" 

570 ) 

571 self.catchup = catchup 

572 

573 self.partial = False 

574 self.on_success_callback = on_success_callback 

575 self.on_failure_callback = on_failure_callback 

576 

577 # Keeps track of any extra edge metadata (sparse; will not contain all 

578 # edges, so do not iterate over it for that). Outer key is upstream 

579 # task ID, inner key is downstream task ID. 

580 self.edge_info: dict[str, dict[str, EdgeInfoType]] = {} 

581 

582 # To keep it in parity with Serialized DAGs 

583 # and identify if DAG has on_*_callback without actually storing them in Serialized JSON 

584 self.has_on_success_callback = self.on_success_callback is not None 

585 self.has_on_failure_callback = self.on_failure_callback is not None 

586 

587 self._access_control = DAG._upgrade_outdated_dag_access_control(access_control) 

588 self.is_paused_upon_creation = is_paused_upon_creation 

589 self.auto_register = auto_register 

590 

591 self.jinja_environment_kwargs = jinja_environment_kwargs 

592 self.render_template_as_native_obj = render_template_as_native_obj 

593 

594 self.doc_md = self.get_doc_md(doc_md) 

595 

596 self.tags = tags or [] 

597 self._task_group = TaskGroup.create_root(self) 

598 self.validate_schedule_and_params() 

599 wrong_links = dict(self.iter_invalid_owner_links()) 

600 if wrong_links: 

601 raise AirflowException( 

602 "Wrong link format was used for the owner. Use a valid link \n" 

603 f"Badly formatted links are: {wrong_links}" 

604 ) 

605 

606 # this will only be set at serialization time 

607 # its only use is for determining the relative 

608 # fileloc based only on the serialized DAG 

609 self._processor_dags_folder = None 

610 

611 def get_doc_md(self, doc_md: str | None) -> str | None: 

612 if doc_md is None: 

613 return doc_md 

614 

615 env = self.get_template_env(force_sandboxed=True) 

616 

617 if not doc_md.endswith(".md"): 

618 template = jinja2.Template(doc_md) 

619 else: 

620 try: 

621 template = env.get_template(doc_md) 

622 except jinja2.exceptions.TemplateNotFound: 

623 return f""" 

624 # Templating Error! 

625 Not able to find the template file: `{doc_md}`. 

626 """ 

627 

628 return template.render() 

629 

630 def _check_schedule_interval_matches_timetable(self) -> bool: 

631 """Check ``schedule_interval`` and ``timetable`` match. 

632 

633 This is done as a part of the DAG validation done before it's bagged, to 

634 guard against the DAG's ``timetable`` (or ``schedule_interval``) 

635 being changed after it's created, e.g. 

636 

637 .. code-block:: python 

638 

639 dag1 = DAG("d1", timetable=MyTimetable()) 

640 dag1.schedule_interval = "@once" 

641 

642 dag2 = DAG("d2", schedule="@once") 

643 dag2.timetable = MyTimetable() 

644 

645 Validation is done by creating a timetable and checking that its summary matches 

646 ``schedule_interval``. The logic is not bullet-proof, especially if a 

647 custom timetable does not provide a useful ``summary``. But this is the 

648 best we can do. 

649 """ 

650 if self.schedule_interval == self.timetable.summary: 

651 return True 

652 try: 

653 timetable = create_timetable(self.schedule_interval, self.timezone) 

654 except ValueError: 

655 return False 

656 return timetable.summary == self.timetable.summary 

657 

658 def validate(self): 

659 """Validate the DAG has a coherent setup. 

660 

661 This is called by the DAG bag before bagging the DAG. 

662 """ 

663 if not self._check_schedule_interval_matches_timetable(): 

664 raise AirflowDagInconsistent( 

665 f"inconsistent schedule: timetable {self.timetable.summary!r} " 

666 f"does not match schedule_interval {self.schedule_interval!r}", 

667 ) 

668 self.params.validate() 

669 self.timetable.validate() 

670 

671 def __repr__(self): 

672 return f"<DAG: {self.dag_id}>" 

673 

674 def __eq__(self, other): 

675 if type(self) == type(other): 

676 # Use getattr() instead of __dict__ as __dict__ doesn't return 

677 # correct values for properties. 

678 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps) 

679 return False 

680 

681 def __ne__(self, other): 

682 return not self == other 

683 

684 def __lt__(self, other): 

685 return self.dag_id < other.dag_id 

686 

687 def __hash__(self): 

688 hash_components = [type(self)] 

689 for c in self._comps: 

690 # task_ids returns a list and lists can't be hashed 

691 if c == "task_ids": 

692 val = tuple(self.task_dict.keys()) 

693 else: 

694 val = getattr(self, c, None) 

695 try: 

696 hash(val) 

697 hash_components.append(val) 

698 except TypeError: 

699 hash_components.append(repr(val)) 

700 return hash(tuple(hash_components)) 

701 

702 # Context Manager ----------------------------------------------- 

703 def __enter__(self): 

704 DagContext.push_context_managed_dag(self) 

705 return self 

706 

707 def __exit__(self, _type, _value, _tb): 

708 DagContext.pop_context_managed_dag() 

709 

710 # /Context Manager ---------------------------------------------- 

711 

712 @staticmethod 

713 def _upgrade_outdated_dag_access_control(access_control=None): 

714 """ 

715 Looks for outdated dag level actions (can_dag_read and can_dag_edit) in DAG 

716 access_controls (for example, {'role1': {'can_dag_read'}, 'role2': {'can_dag_read', 'can_dag_edit'}}) 

717 and replaces them with updated actions (can_read and can_edit). 

718 """ 

719 if not access_control: 

720 return None 

721 new_perm_mapping = { 

722 permissions.DEPRECATED_ACTION_CAN_DAG_READ: permissions.ACTION_CAN_READ, 

723 permissions.DEPRECATED_ACTION_CAN_DAG_EDIT: permissions.ACTION_CAN_EDIT, 

724 } 

725 updated_access_control = {} 

726 for role, perms in access_control.items(): 

727 updated_access_control[role] = {new_perm_mapping.get(perm, perm) for perm in perms} 

728 

729 if access_control != updated_access_control: 

730 warnings.warn( 

731 "The 'can_dag_read' and 'can_dag_edit' permissions are deprecated. " 

732 "Please use 'can_read' and 'can_edit', respectively.", 

733 RemovedInAirflow3Warning, 

734 stacklevel=3, 

735 ) 

736 

737 return updated_access_control 

738 
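A small sketch of the permission upgrade above (role names are hypothetical): deprecated dag-level actions are rewritten, and a ``RemovedInAirflow3Warning`` is emitted when anything changed.

    legacy = {"viewer": {"can_dag_read"}, "editor": {"can_dag_read", "can_dag_edit"}}
    DAG._upgrade_outdated_dag_access_control(legacy)
    # -> {"viewer": {"can_read"}, "editor": {"can_read", "can_edit"}}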

739 def date_range( 

740 self, 

741 start_date: pendulum.DateTime, 

742 num: int | None = None, 

743 end_date: datetime | None = None, 

744 ) -> list[datetime]: 

745 message = "`DAG.date_range()` is deprecated." 

746 if num is not None: 

747 warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2) 

748 with warnings.catch_warnings(): 

749 warnings.simplefilter("ignore", RemovedInAirflow3Warning) 

750 return utils_date_range( 

751 start_date=start_date, num=num, delta=self.normalized_schedule_interval 

752 ) 

753 message += " Please use `DAG.iter_dagrun_infos_between(..., align=False)` instead." 

754 warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2) 

755 if end_date is None: 

756 coerced_end_date = timezone.utcnow() 

757 else: 

758 coerced_end_date = end_date 

759 it = self.iter_dagrun_infos_between(start_date, pendulum.instance(coerced_end_date), align=False) 

760 return [info.logical_date for info in it] 

761 

762 def is_fixed_time_schedule(self): 

763 warnings.warn( 

764 "`DAG.is_fixed_time_schedule()` is deprecated.", 

765 category=RemovedInAirflow3Warning, 

766 stacklevel=2, 

767 ) 

768 try: 

769 return not self.timetable._should_fix_dst 

770 except AttributeError: 

771 return True 

772 

773 def following_schedule(self, dttm): 

774 """ 

775 Calculates the following schedule for this dag in UTC. 

776 

777 :param dttm: utc datetime 

778 :return: utc datetime 

779 """ 

780 warnings.warn( 

781 "`DAG.following_schedule()` is deprecated. Use `DAG.next_dagrun_info(restricted=False)` instead.", 

782 category=RemovedInAirflow3Warning, 

783 stacklevel=2, 

784 ) 

785 data_interval = self.infer_automated_data_interval(timezone.coerce_datetime(dttm)) 

786 next_info = self.next_dagrun_info(data_interval, restricted=False) 

787 if next_info is None: 

788 return None 

789 return next_info.data_interval.start 

790 

791 def previous_schedule(self, dttm): 

792 from airflow.timetables.interval import _DataIntervalTimetable 

793 

794 warnings.warn( 

795 "`DAG.previous_schedule()` is deprecated.", 

796 category=RemovedInAirflow3Warning, 

797 stacklevel=2, 

798 ) 

799 if not isinstance(self.timetable, _DataIntervalTimetable): 

800 return None 

801 return self.timetable._get_prev(timezone.coerce_datetime(dttm)) 

802 

803 def get_next_data_interval(self, dag_model: DagModel) -> DataInterval | None: 

804 """Get the data interval of the next scheduled run. 

805 

806 For compatibility, this method infers the data interval from the DAG's 

807 schedule if the run does not have an explicit one set, which is possible 

808 for runs created prior to AIP-39. 

809 

810 This function is private to Airflow core and should not be depended on as 

811 part of the Python API. 

812 

813 :meta private: 

814 """ 

815 if self.dag_id != dag_model.dag_id: 

816 raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {dag_model.dag_id}") 

817 if dag_model.next_dagrun is None: # Next run not scheduled. 

818 return None 

819 data_interval = dag_model.next_dagrun_data_interval 

820 if data_interval is not None: 

821 return data_interval 

822 

823 # Compatibility: A run was scheduled without an explicit data interval. 

824 # This means the run was scheduled before AIP-39 implementation. Try to 

825 # infer from the logical date. 

826 return self.infer_automated_data_interval(dag_model.next_dagrun) 

827 

828 def get_run_data_interval(self, run: DagRun) -> DataInterval: 

829 """Get the data interval of this run. 

830 

831 For compatibility, this method infers the data interval from the DAG's 

832 schedule if the run does not have an explicit one set, which is possible for 

833 runs created prior to AIP-39. 

834 

835 This function is private to Airflow core and should not be depended on as 

836 part of the Python API. 

837 

838 :meta private: 

839 """ 

840 if run.dag_id is not None and run.dag_id != self.dag_id: 

841 raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {run.dag_id}") 

842 data_interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end") 

843 if data_interval is not None: 

844 return data_interval 

845 # Compatibility: runs created before AIP-39 implementation don't have an 

846 # explicit data interval. Try to infer from the logical date. 

847 return self.infer_automated_data_interval(run.execution_date) 

848 

849 def infer_automated_data_interval(self, logical_date: datetime) -> DataInterval: 

850 """Infer a data interval for a run against this DAG. 

851 

852 This method is used to bridge runs created prior to AIP-39 

853 implementation, which do not have an explicit data interval. Therefore, 

854 this method only considers ``schedule_interval`` values valid prior to 

855 Airflow 2.2. 

856 

857 DO NOT use this method if there is a known data interval. 

858 """ 

859 timetable_type = type(self.timetable) 

860 if issubclass(timetable_type, (NullTimetable, OnceTimetable)): 

861 return DataInterval.exact(timezone.coerce_datetime(logical_date)) 

862 start = timezone.coerce_datetime(logical_date) 

863 if issubclass(timetable_type, CronDataIntervalTimetable): 

864 end = cast(CronDataIntervalTimetable, self.timetable)._get_next(start) 

865 elif issubclass(timetable_type, DeltaDataIntervalTimetable): 

866 end = cast(DeltaDataIntervalTimetable, self.timetable)._get_next(start) 

867 else: 

868 raise ValueError(f"Not a valid timetable: {self.timetable!r}") 

869 return DataInterval(start, end) 

870 
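A sketch of the inference above for a cron-scheduled DAG (the DAG id and dates are illustrative): the logical date becomes the interval start and the next cron tick becomes the end.

    import pendulum

    daily = DAG("example_daily", schedule="0 0 * * *", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"))
    interval = daily.infer_automated_data_interval(pendulum.datetime(2021, 6, 3, tz="UTC"))
    # interval.start == 2021-06-03T00:00:00+00:00
    # interval.end   == 2021-06-04T00:00:00+00:00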

871 def next_dagrun_info( 

872 self, 

873 last_automated_dagrun: None | datetime | DataInterval, 

874 *, 

875 restricted: bool = True, 

876 ) -> DagRunInfo | None: 

877 """Get information about the next DagRun of this dag after ``last_automated_dagrun``. 

878 

879 This calculates what time interval the next DagRun should operate on 

880 (its execution date) and when it can be scheduled, according to the 

881 dag's timetable, start_date, end_date, etc. This doesn't check max 

882 active run or any other "max_active_tasks" type limits, but only 

883 performs calculations based on the various date and interval fields of 

884 this dag and its tasks. 

885 

886 :param last_automated_dagrun: The ``max(execution_date)`` of 

887 existing "automated" DagRuns for this dag (scheduled or backfill, 

888 but not manual). 

889 :param restricted: If set to *False* (default is *True*), ignore 

890 ``start_date``, ``end_date``, and ``catchup`` specified on the DAG 

891 or tasks. 

892 :return: DagRunInfo of the next dagrun, or None if a dagrun is not 

893 going to be scheduled. 

894 """ 

895 # Never schedule a subdag. It will be scheduled by its parent dag. 

896 if self.is_subdag: 

897 return None 

898 

899 data_interval = None 

900 if isinstance(last_automated_dagrun, datetime): 

901 warnings.warn( 

902 "Passing a datetime to DAG.next_dagrun_info is deprecated. Use a DataInterval instead.", 

903 RemovedInAirflow3Warning, 

904 stacklevel=2, 

905 ) 

906 data_interval = self.infer_automated_data_interval( 

907 timezone.coerce_datetime(last_automated_dagrun) 

908 ) 

909 else: 

910 data_interval = last_automated_dagrun 

911 if restricted: 

912 restriction = self._time_restriction 

913 else: 

914 restriction = TimeRestriction(earliest=None, latest=None, catchup=True) 

915 try: 

916 info = self.timetable.next_dagrun_info( 

917 last_automated_data_interval=data_interval, 

918 restriction=restriction, 

919 ) 

920 except Exception: 

921 self.log.exception( 

922 "Failed to fetch run info after data interval %s for DAG %r", 

923 data_interval, 

924 self.dag_id, 

925 ) 

926 info = None 

927 return info 

928 
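A sketch of asking for the first run of a brand-new DAG (the DAG id is illustrative and catchup is assumed to be enabled); passing ``None`` means no automated run exists yet.

    import pendulum

    dag = DAG("example_daily", schedule="@daily", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"))
    info = dag.next_dagrun_info(None)
    if info is not None:
        # With catchup enabled, the first data interval is 2021-01-01 -> 2021-01-02
        # and the run becomes schedulable once that interval has passed (run_after).
        print(info.logical_date, info.data_interval, info.run_after)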

929 def next_dagrun_after_date(self, date_last_automated_dagrun: pendulum.DateTime | None): 

930 warnings.warn( 

931 "`DAG.next_dagrun_after_date()` is deprecated. Please use `DAG.next_dagrun_info()` instead.", 

932 category=RemovedInAirflow3Warning, 

933 stacklevel=2, 

934 ) 

935 if date_last_automated_dagrun is None: 

936 data_interval = None 

937 else: 

938 data_interval = self.infer_automated_data_interval(date_last_automated_dagrun) 

939 info = self.next_dagrun_info(data_interval) 

940 if info is None: 

941 return None 

942 return info.run_after 

943 

944 @cached_property 

945 def _time_restriction(self) -> TimeRestriction: 

946 start_dates = [t.start_date for t in self.tasks if t.start_date] 

947 if self.start_date is not None: 

948 start_dates.append(self.start_date) 

949 earliest = None 

950 if start_dates: 

951 earliest = timezone.coerce_datetime(min(start_dates)) 

952 latest = self.end_date 

953 end_dates = [t.end_date for t in self.tasks if t.end_date] 

954 if len(end_dates) == len(self.tasks): # no task has a null end_date 

955 if self.end_date is not None: 

956 end_dates.append(self.end_date) 

957 if end_dates: 

958 latest = timezone.coerce_datetime(max(end_dates)) 

959 return TimeRestriction(earliest, latest, self.catchup) 

960 

961 def iter_dagrun_infos_between( 

962 self, 

963 earliest: pendulum.DateTime | None, 

964 latest: pendulum.DateTime, 

965 *, 

966 align: bool = True, 

967 ) -> Iterable[DagRunInfo]: 

968 """Yield DagRunInfo using this DAG's timetable between given interval. 

969 

970 DagRunInfo instances are yielded if their ``logical_date`` is not earlier 

971 than ``earliest``, nor later than ``latest``. The instances are ordered 

972 by their ``logical_date`` from earliest to latest. 

973 

974 If ``align`` is ``False``, the first run will happen immediately on 

975 ``earliest``, even if it does not fall on the logical timetable schedule. 

976 The default is ``True``, but subdags will ignore this value and always 

977 behave as if this is set to ``False`` for backward compatibility. 

978 

979 Example: A DAG is scheduled to run every midnight (``0 0 * * *``). If 

980 ``earliest`` is ``2021-06-03 23:00:00``, the first DagRunInfo would be 

981 ``2021-06-03 23:00:00`` if ``align=False``, and ``2021-06-04 00:00:00`` 

982 if ``align=True``. 

983 """ 

984 if earliest is None: 

985 earliest = self._time_restriction.earliest 

986 if earliest is None: 

987 raise ValueError("earliest was None and we had no value in time_restriction to fallback on") 

988 earliest = timezone.coerce_datetime(earliest) 

989 latest = timezone.coerce_datetime(latest) 

990 

991 restriction = TimeRestriction(earliest, latest, catchup=True) 

992 

993 # HACK: Sub-DAGs are currently scheduled differently. For example, say 

994 # the schedule is @daily and start is 2021-06-03 22:16:00, a top-level 

995 # DAG should be first scheduled to run on midnight 2021-06-04, but a 

996 # sub-DAG should be first scheduled to run RIGHT NOW. We can change 

997 # this, but since sub-DAGs are going away in 3.0 anyway, let's keep 

998 # compatibility for now and remove this entirely later. 

999 if self.is_subdag: 

1000 align = False 

1001 

1002 try: 

1003 info = self.timetable.next_dagrun_info( 

1004 last_automated_data_interval=None, 

1005 restriction=restriction, 

1006 ) 

1007 except Exception: 

1008 self.log.exception( 

1009 "Failed to fetch run info after data interval %s for DAG %r", 

1010 None, 

1011 self.dag_id, 

1012 ) 

1013 info = None 

1014 

1015 if info is None: 

1016 # No runs to be scheduled within the user-supplied timeframe. But 

1017 # if align=False, "invent" a data interval for the timeframe itself. 

1018 if not align: 

1019 yield DagRunInfo.interval(earliest, latest) 

1020 return 

1021 

1022 # If align=False and earliest does not fall on the timetable's logical 

1023 # schedule, "invent" a data interval for it. 

1024 if not align and info.logical_date != earliest: 

1025 yield DagRunInfo.interval(earliest, info.data_interval.start) 

1026 

1027 # Generate naturally according to schedule. 

1028 while info is not None: 

1029 yield info 

1030 try: 

1031 info = self.timetable.next_dagrun_info( 

1032 last_automated_data_interval=info.data_interval, 

1033 restriction=restriction, 

1034 ) 

1035 except Exception: 

1036 self.log.exception( 

1037 "Failed to fetch run info after data interval %s for DAG %r", 

1038 info.data_interval if info else "<NONE>", 

1039 self.dag_id, 

1040 ) 

1041 break 

1042 
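A sketch mirroring the docstring example above (the DAG id and dates are illustrative): with the default ``align=True`` the first info lands on the next midnight, while ``align=False`` also yields an interval starting at ``earliest`` itself.

    import pendulum

    dag = DAG("example_daily", schedule="0 0 * * *", start_date=pendulum.datetime(2021, 1, 1, tz="UTC"))
    infos = list(
        dag.iter_dagrun_infos_between(
            pendulum.datetime(2021, 6, 3, 23, 0, tz="UTC"),
            pendulum.datetime(2021, 6, 5, tz="UTC"),
        )
    )
    # align=True (default): logical dates 2021-06-04 and 2021-06-05
    # align=False: an extra leading interval starting at 2021-06-03 23:00 is yielded first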

1043 def get_run_dates(self, start_date, end_date=None) -> list: 

1044 """ 

1045 Returns a list of dates within the interval received as parameters, using this 

1046 dag's schedule interval. Returned dates can be used for execution dates. 

1047 

1048 :param start_date: The start date of the interval. 

1049 :param end_date: The end date of the interval. Defaults to ``timezone.utcnow()``. 

1050 :return: A list of dates within the interval following the dag's schedule. 

1051 """ 

1052 warnings.warn( 

1053 "`DAG.get_run_dates()` is deprecated. Please use `DAG.iter_dagrun_infos_between()` instead.", 

1054 category=RemovedInAirflow3Warning, 

1055 stacklevel=2, 

1056 ) 

1057 earliest = timezone.coerce_datetime(start_date) 

1058 if end_date is None: 

1059 latest = pendulum.now(timezone.utc) 

1060 else: 

1061 latest = timezone.coerce_datetime(end_date) 

1062 return [info.logical_date for info in self.iter_dagrun_infos_between(earliest, latest)] 

1063 

1064 def normalize_schedule(self, dttm): 

1065 warnings.warn( 

1066 "`DAG.normalize_schedule()` is deprecated.", 

1067 category=RemovedInAirflow3Warning, 

1068 stacklevel=2, 

1069 ) 

1070 with warnings.catch_warnings(): 

1071 warnings.simplefilter("ignore", RemovedInAirflow3Warning) 

1072 following = self.following_schedule(dttm) 

1073 if not following: # in case of @once 

1074 return dttm 

1075 with warnings.catch_warnings(): 

1076 warnings.simplefilter("ignore", RemovedInAirflow3Warning) 

1077 previous_of_following = self.previous_schedule(following) 

1078 if previous_of_following != dttm: 

1079 return following 

1080 return dttm 

1081 

1082 @provide_session 

1083 def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False): 

1084 return get_last_dagrun( 

1085 self.dag_id, session=session, include_externally_triggered=include_externally_triggered 

1086 ) 

1087 

1088 @provide_session 

1089 def has_dag_runs(self, session=NEW_SESSION, include_externally_triggered=True) -> bool: 

1090 return ( 

1091 get_last_dagrun( 

1092 self.dag_id, session=session, include_externally_triggered=include_externally_triggered 

1093 ) 

1094 is not None 

1095 ) 

1096 

1097 @property 

1098 def dag_id(self) -> str: 

1099 return self._dag_id 

1100 

1101 @dag_id.setter 

1102 def dag_id(self, value: str) -> None: 

1103 self._dag_id = value 

1104 

1105 @property 

1106 def is_subdag(self) -> bool: 

1107 return self.parent_dag is not None 

1108 

1109 @property 

1110 def full_filepath(self) -> str: 

1111 """:meta private:""" 

1112 warnings.warn( 

1113 "DAG.full_filepath is deprecated in favour of fileloc", 

1114 RemovedInAirflow3Warning, 

1115 stacklevel=2, 

1116 ) 

1117 return self.fileloc 

1118 

1119 @full_filepath.setter 

1120 def full_filepath(self, value) -> None: 

1121 warnings.warn( 

1122 "DAG.full_filepath is deprecated in favour of fileloc", 

1123 RemovedInAirflow3Warning, 

1124 stacklevel=2, 

1125 ) 

1126 self.fileloc = value 

1127 

1128 @property 

1129 def concurrency(self) -> int: 

1130 # TODO: Remove in Airflow 3.0 

1131 warnings.warn( 

1132 "The 'DAG.concurrency' attribute is deprecated. Please use 'DAG.max_active_tasks'.", 

1133 RemovedInAirflow3Warning, 

1134 stacklevel=2, 

1135 ) 

1136 return self._max_active_tasks 

1137 

1138 @concurrency.setter 

1139 def concurrency(self, value: int): 

1140 self._max_active_tasks = value 

1141 

1142 @property 

1143 def max_active_tasks(self) -> int: 

1144 return self._max_active_tasks 

1145 

1146 @max_active_tasks.setter 

1147 def max_active_tasks(self, value: int): 

1148 self._max_active_tasks = value 

1149 

1150 @property 

1151 def access_control(self): 

1152 return self._access_control 

1153 

1154 @access_control.setter 

1155 def access_control(self, value): 

1156 self._access_control = DAG._upgrade_outdated_dag_access_control(value) 

1157 

1158 @property 

1159 def description(self) -> str | None: 

1160 return self._description 

1161 

1162 @property 

1163 def default_view(self) -> str: 

1164 return self._default_view 

1165 

1166 @property 

1167 def pickle_id(self) -> int | None: 

1168 return self._pickle_id 

1169 

1170 @pickle_id.setter 

1171 def pickle_id(self, value: int) -> None: 

1172 self._pickle_id = value 

1173 

1174 def param(self, name: str, default: Any = NOTSET) -> DagParam: 

1175 """ 

1176 Return a DagParam object for current dag. 

1177 

1178 :param name: dag parameter name. 

1179 :param default: fallback value for dag parameter. 

1180 :return: DagParam instance for specified name and current dag. 

1181 """ 

1182 return DagParam(current_dag=self, name=name, default=default) 

1183 

1184 @property 

1185 def tasks(self) -> list[Operator]: 

1186 return list(self.task_dict.values()) 

1187 

1188 @tasks.setter 

1189 def tasks(self, val): 

1190 raise AttributeError("DAG.tasks can not be modified. Use dag.add_task() instead.") 

1191 

1192 @property 

1193 def task_ids(self) -> list[str]: 

1194 return list(self.task_dict.keys()) 

1195 

1196 @property 

1197 def task_group(self) -> TaskGroup: 

1198 return self._task_group 

1199 

1200 @property 

1201 def filepath(self) -> str: 

1202 """:meta private:""" 

1203 warnings.warn( 

1204 "filepath is deprecated, use relative_fileloc instead", 

1205 RemovedInAirflow3Warning, 

1206 stacklevel=2, 

1207 ) 

1208 return str(self.relative_fileloc) 

1209 

1210 @property 

1211 def relative_fileloc(self) -> pathlib.Path: 

1212 """File location of the importable dag 'file' relative to the configured DAGs folder.""" 

1213 path = pathlib.Path(self.fileloc) 

1214 try: 

1215 rel_path = path.relative_to(self._processor_dags_folder or settings.DAGS_FOLDER) 

1216 if rel_path == pathlib.Path("."): 

1217 return path 

1218 else: 

1219 return rel_path 

1220 except ValueError: 

1221 # Not relative to DAGS_FOLDER. 

1222 return path 

1223 

1224 @property 

1225 def folder(self) -> str: 

1226 """Folder location of where the DAG object is instantiated.""" 

1227 return os.path.dirname(self.fileloc) 

1228 

1229 @property 

1230 def owner(self) -> str: 

1231 """ 

1232 Return list of all owners found in DAG tasks. 

1233 

1234 :return: Comma separated list of owners in DAG tasks 

1235 """ 

1236 return ", ".join({t.owner for t in self.tasks}) 

1237 

1238 @property 

1239 def allow_future_exec_dates(self) -> bool: 

1240 return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_run 

1241 

1242 @provide_session 

1243 def get_concurrency_reached(self, session=NEW_SESSION) -> bool: 

1244 """ 

1245 Returns a boolean indicating whether the max_active_tasks limit for this DAG 

1246 has been reached 

1247 """ 

1248 TI = TaskInstance 

1249 qry = session.query(func.count(TI.task_id)).filter( 

1250 TI.dag_id == self.dag_id, 

1251 TI.state == State.RUNNING, 

1252 ) 

1253 return qry.scalar() >= self.max_active_tasks 

1254 

1255 @property 

1256 def concurrency_reached(self): 

1257 """This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method.""" 

1258 warnings.warn( 

1259 "This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method.", 

1260 RemovedInAirflow3Warning, 

1261 stacklevel=2, 

1262 ) 

1263 return self.get_concurrency_reached() 

1264 

1265 @provide_session 

1266 def get_is_active(self, session=NEW_SESSION) -> bool: 

1267 """Returns a boolean indicating whether this DAG is active""" 

1268 return session.query(DagModel.is_active).filter(DagModel.dag_id == self.dag_id).scalar() 

1269 

1270 @provide_session 

1271 def get_is_paused(self, session=NEW_SESSION) -> bool: 

1272 """Returns a boolean indicating whether this DAG is paused""" 

1273 return session.query(DagModel.is_paused).filter(DagModel.dag_id == self.dag_id).scalar() 

1274 

1275 @property 

1276 def is_paused(self): 

1277 """This attribute is deprecated. Please use `airflow.models.DAG.get_is_paused` method.""" 

1278 warnings.warn( 

1279 "This attribute is deprecated. Please use `airflow.models.DAG.get_is_paused` method.", 

1280 RemovedInAirflow3Warning, 

1281 stacklevel=2, 

1282 ) 

1283 return self.get_is_paused() 

1284 

1285 @property 

1286 def normalized_schedule_interval(self) -> ScheduleInterval: 

1287 warnings.warn( 

1288 "DAG.normalized_schedule_interval() is deprecated.", 

1289 category=RemovedInAirflow3Warning, 

1290 stacklevel=2, 

1291 ) 

1292 if isinstance(self.schedule_interval, str) and self.schedule_interval in cron_presets: 

1293 _schedule_interval: ScheduleInterval = cron_presets.get(self.schedule_interval) 

1294 elif self.schedule_interval == "@once": 

1295 _schedule_interval = None 

1296 else: 

1297 _schedule_interval = self.schedule_interval 

1298 return _schedule_interval 

1299 

1300 @provide_session 

1301 def handle_callback(self, dagrun, success=True, reason=None, session=NEW_SESSION): 

1302 """ 

1303 Triggers the appropriate callback depending on the value of success, namely the 

1304 on_failure_callback or on_success_callback. This method gets the context of a 

1305 single TaskInstance part of this DagRun and passes that to the callable along 

1306 with a 'reason', primarily to differentiate DagRun failures. 

1307 

1308 .. note: The logs end up in 

1309 ``$AIRFLOW_HOME/logs/scheduler/latest/PROJECT/DAG_FILE.py.log`` 

1310 

1311 :param dagrun: DagRun object 

1312 :param success: Flag to specify if failure or success callback should be called 

1313 :param reason: Completion reason 

1314 :param session: Database session 

1315 """ 

1316 callbacks = self.on_success_callback if success else self.on_failure_callback 

1317 if callbacks: 

1318 callbacks = callbacks if isinstance(callbacks, list) else [callbacks] 

1319 tis = dagrun.get_task_instances(session=session) 

1320 ti = tis[-1] # get the last TaskInstance of the DagRun 

1321 ti.task = self.get_task(ti.task_id) 

1322 context = ti.get_template_context(session=session) 

1323 context.update({"reason": reason}) 

1324 for callback in callbacks: 

1325 self.log.info("Executing dag callback function: %s", callback) 

1326 try: 

1327 callback(context) 

1328 except Exception: 

1329 self.log.exception("failed to invoke dag state update callback") 

1330 Stats.incr("dag.callback_exceptions") 

1331 
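A sketch of wiring DAG-level callbacks that ``handle_callback`` would invoke with the run's template context plus the ``reason`` key added above (names and dates are illustrative).

    import pendulum

    def notify_failure(context):
        # context is a regular task context dict, with "reason" added by handle_callback
        print("DAG run failed:", context["dag_run"].run_id, "-", context["reason"])

    dag = DAG(
        "example_callbacks",
        schedule="@daily",
        start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
        on_failure_callback=[notify_failure],
    )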

1332 def get_active_runs(self): 

1333 """ 

1334 Returns a list of dag run execution dates currently running 

1335 

1336 :return: List of execution dates 

1337 """ 

1338 runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING) 

1339 

1340 active_dates = [] 

1341 for run in runs: 

1342 active_dates.append(run.execution_date) 

1343 

1344 return active_dates 

1345 

1346 @provide_session 

1347 def get_num_active_runs(self, external_trigger=None, only_running=True, session=NEW_SESSION): 

1348 """ 

1349 Returns the number of active "running" dag runs 

1350 

1351 :param external_trigger: True for externally triggered active dag runs 

1352 :param session: 

1353 :return: number greater than 0 for active dag runs 

1354 """ 

1355 # .count() is inefficient 

1356 query = session.query(func.count()).filter(DagRun.dag_id == self.dag_id) 

1357 if only_running: 

1358 query = query.filter(DagRun.state == State.RUNNING) 

1359 else: 

1360 query = query.filter(DagRun.state.in_({State.RUNNING, State.QUEUED})) 

1361 

1362 if external_trigger is not None: 

1363 query = query.filter( 

1364 DagRun.external_trigger == (expression.true() if external_trigger else expression.false()) 

1365 ) 

1366 

1367 return query.scalar() 

1368 

1369 @provide_session 

1370 def get_dagrun( 

1371 self, 

1372 execution_date: datetime | None = None, 

1373 run_id: str | None = None, 

1374 session: Session = NEW_SESSION, 

1375 ): 

1376 """ 

1377 Returns the dag run for a given execution date or run_id if it exists, otherwise 

1378 None. 

1379 

1380 :param execution_date: The execution date of the DagRun to find. 

1381 :param run_id: The run_id of the DagRun to find. 

1382 :param session: 

1383 :return: The DagRun if found, otherwise None. 

1384 """ 

1385 if not (execution_date or run_id): 

1386 raise TypeError("You must provide either the execution_date or the run_id") 

1387 query = session.query(DagRun) 

1388 if execution_date: 

1389 query = query.filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == execution_date) 

1390 if run_id: 

1391 query = query.filter(DagRun.dag_id == self.dag_id, DagRun.run_id == run_id) 

1392 return query.first() 

1393 

1394 @provide_session 

1395 def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION): 

1396 """ 

1397 Returns the list of dag runs between start_date (inclusive) and end_date (inclusive). 

1398 

1399 :param start_date: The starting execution date of the DagRun to find. 

1400 :param end_date: The ending execution date of the DagRun to find. 

1401 :param session: 

1402 :return: The list of DagRuns found. 

1403 """ 

1404 dagruns = ( 

1405 session.query(DagRun) 

1406 .filter( 

1407 DagRun.dag_id == self.dag_id, 

1408 DagRun.execution_date >= start_date, 

1409 DagRun.execution_date <= end_date, 

1410 ) 

1411 .all() 

1412 ) 

1413 

1414 return dagruns 

1415 

1416 @provide_session 

1417 def get_latest_execution_date(self, session: Session = NEW_SESSION) -> pendulum.DateTime | None: 

1418 """Returns the latest date for which at least one dag run exists""" 

1419 return session.query(func.max(DagRun.execution_date)).filter(DagRun.dag_id == self.dag_id).scalar() 

1420 

1421 @property 

1422 def latest_execution_date(self): 

1423 """This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.""" 

1424 warnings.warn( 

1425 "This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.", 

1426 RemovedInAirflow3Warning, 

1427 stacklevel=2, 

1428 ) 

1429 return self.get_latest_execution_date() 

1430 

1431 @property 

1432 def subdags(self): 

1433 """Returns a list of the subdag objects associated to this DAG""" 

1434 # Identify SubDagOperator tasks without relying solely on an isinstance check

1435 from airflow.operators.subdag import SubDagOperator 

1436 

1437 subdag_lst = [] 

1438 for task in self.tasks: 

1439 if ( 

1440 isinstance(task, SubDagOperator) 

1441 or 

1442 # TODO remove in Airflow 2.0 

1443 type(task).__name__ == "SubDagOperator" 

1444 or task.task_type == "SubDagOperator" 

1445 ): 

1446 subdag_lst.append(task.subdag) 

1447 subdag_lst += task.subdag.subdags 

1448 return subdag_lst 

1449 

1450 def resolve_template_files(self): 

1451 for t in self.tasks: 

1452 t.resolve_template_files() 

1453 

1454 def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment: 

1455 """Build a Jinja2 environment.""" 

1456 # Collect directories to search for template files 

1457 searchpath = [self.folder] 

1458 if self.template_searchpath: 

1459 searchpath += self.template_searchpath 

1460 

1461 # Default values (for backward compatibility) 

1462 jinja_env_options = { 

1463 "loader": jinja2.FileSystemLoader(searchpath), 

1464 "undefined": self.template_undefined, 

1465 "extensions": ["jinja2.ext.do"], 

1466 "cache_size": 0, 

1467 } 

1468 if self.jinja_environment_kwargs: 

1469 jinja_env_options.update(self.jinja_environment_kwargs) 

1470 env: jinja2.Environment 

1471 if self.render_template_as_native_obj and not force_sandboxed: 

1472 env = airflow.templates.NativeEnvironment(**jinja_env_options) 

1473 else: 

1474 env = airflow.templates.SandboxedEnvironment(**jinja_env_options) 

1475 

1476 # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. 

1477 # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals 

1478 if self.user_defined_macros: 

1479 env.globals.update(self.user_defined_macros) 

1480 if self.user_defined_filters: 

1481 env.filters.update(self.user_defined_filters) 

1482 

1483 return env 

1484 
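As a sketch, the environment built above can be used directly to render strings against the macros and filters registered on the DAG; the ``shout`` macro below is purely illustrative:

    import pendulum
    from airflow.models.dag import DAG

    dag = DAG(
        dag_id="example_template_env",
        start_date=pendulum.datetime(2022, 1, 1, tz="UTC"),
        schedule=None,
        user_defined_macros={"shout": lambda s: s.upper()},
    )
    env = dag.get_template_env()
    print(env.from_string("{{ shout('hello') }}").render())  # -> HELLO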

1485 def set_dependency(self, upstream_task_id, downstream_task_id): 

1486 """ 

1487 Simple utility method to set dependency between two tasks that 

1488 have already been added to the DAG using add_task()

1489 """ 

1490 self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id)) 

1491 
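A sketch of wiring tasks through ``add_task`` and ``set_dependency`` instead of the usual ``>>`` operator; the task ids are illustrative:

    import pendulum
    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator

    dag = DAG(dag_id="example_set_dependency", start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), schedule=None)
    dag.add_task(EmptyOperator(task_id="extract"))
    dag.add_task(EmptyOperator(task_id="load"))
    # Equivalent to dag.get_task("extract") >> dag.get_task("load")
    dag.set_dependency("extract", "load")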

1492 @provide_session 

1493 def get_task_instances_before( 

1494 self, 

1495 base_date: datetime, 

1496 num: int, 

1497 *, 

1498 session: Session = NEW_SESSION, 

1499 ) -> list[TaskInstance]: 

1500 """Get ``num`` task instances before (including) ``base_date``. 

1501 

1502 The returned list may contain exactly ``num`` task instances. It can 

1503 have fewer if there are fewer than ``num`` scheduled DAG runs before 

1504 ``base_date``, or more if there are manual task runs within the 

1505 requested period; those do not count toward ``num``. 

1506 """ 

1507 min_date: datetime | None = ( 

1508 session.query(DagRun.execution_date) 

1509 .filter( 

1510 DagRun.dag_id == self.dag_id, 

1511 DagRun.execution_date <= base_date, 

1512 DagRun.run_type != DagRunType.MANUAL, 

1513 ) 

1514 .order_by(DagRun.execution_date.desc()) 

1515 .offset(num) 

1516 .limit(1) 

1517 .scalar() 

1518 ) 

1519 if min_date is None: 

1520 min_date = timezone.utc_epoch() 

1521 return self.get_task_instances(start_date=min_date, end_date=base_date, session=session) 

1522 

1523 @provide_session 

1524 def get_task_instances( 

1525 self, 

1526 start_date: datetime | None = None, 

1527 end_date: datetime | None = None, 

1528 state: list[TaskInstanceState] | None = None, 

1529 session: Session = NEW_SESSION, 

1530 ) -> list[TaskInstance]: 

1531 if not start_date: 

1532 start_date = (timezone.utcnow() - timedelta(30)).replace( 

1533 hour=0, minute=0, second=0, microsecond=0 

1534 ) 

1535 query = self._get_task_instances( 

1536 task_ids=None, 

1537 start_date=start_date, 

1538 end_date=end_date, 

1539 run_id=None, 

1540 state=state or (), 

1541 include_subdags=False, 

1542 include_parentdag=False, 

1543 include_dependent_dags=False, 

1544 exclude_task_ids=(), 

1545 session=session, 

1546 ) 

1547 return cast(Query, query).order_by(DagRun.execution_date).all() 

1548 
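A sketch of pulling failed task instances for a recent window; the dates are illustrative and a metadata database is required:

    import pendulum
    from airflow.models.dag import DAG
    from airflow.utils.state import TaskInstanceState

    dag = DAG(dag_id="example_tis", start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), schedule=None)
    failed = dag.get_task_instances(
        start_date=pendulum.datetime(2022, 12, 1, tz="UTC"),
        end_date=pendulum.datetime(2022, 12, 24, tz="UTC"),
        state=[TaskInstanceState.FAILED],
    )
    for ti in failed:
        print(ti.task_id, ti.run_id, ti.state)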

1549 @overload 

1550 def _get_task_instances( 

1551 self, 

1552 *, 

1553 task_ids: Collection[str | tuple[str, int]] | None, 

1554 start_date: datetime | None, 

1555 end_date: datetime | None, 

1556 run_id: str | None, 

1557 state: TaskInstanceState | Sequence[TaskInstanceState], 

1558 include_subdags: bool, 

1559 include_parentdag: bool, 

1560 include_dependent_dags: bool, 

1561 exclude_task_ids: Collection[str | tuple[str, int]] | None, 

1562 session: Session, 

1563 dag_bag: DagBag | None = ..., 

1564 ) -> Iterable[TaskInstance]: 

1565 ... # pragma: no cover 

1566 

1567 @overload 

1568 def _get_task_instances( 

1569 self, 

1570 *, 

1571 task_ids: Collection[str | tuple[str, int]] | None, 

1572 as_pk_tuple: Literal[True], 

1573 start_date: datetime | None, 

1574 end_date: datetime | None, 

1575 run_id: str | None, 

1576 state: TaskInstanceState | Sequence[TaskInstanceState], 

1577 include_subdags: bool, 

1578 include_parentdag: bool, 

1579 include_dependent_dags: bool, 

1580 exclude_task_ids: Collection[str | tuple[str, int]] | None, 

1581 session: Session, 

1582 dag_bag: DagBag | None = ..., 

1583 recursion_depth: int = ..., 

1584 max_recursion_depth: int = ..., 

1585 visited_external_tis: set[TaskInstanceKey] = ..., 

1586 ) -> set[TaskInstanceKey]: 

1587 ... # pragma: no cover 

1588 

1589 def _get_task_instances( 

1590 self, 

1591 *, 

1592 task_ids: Collection[str | tuple[str, int]] | None, 

1593 as_pk_tuple: Literal[True, None] = None, 

1594 start_date: datetime | None, 

1595 end_date: datetime | None, 

1596 run_id: str | None, 

1597 state: TaskInstanceState | Sequence[TaskInstanceState], 

1598 include_subdags: bool, 

1599 include_parentdag: bool, 

1600 include_dependent_dags: bool, 

1601 exclude_task_ids: Collection[str | tuple[str, int]] | None, 

1602 session: Session, 

1603 dag_bag: DagBag | None = None, 

1604 recursion_depth: int = 0, 

1605 max_recursion_depth: int | None = None, 

1606 visited_external_tis: set[TaskInstanceKey] | None = None, 

1607 ) -> Iterable[TaskInstance] | set[TaskInstanceKey]: 

1608 TI = TaskInstance 

1609 

1610 # If we are looking at subdags/dependent dags we want to avoid UNION calls 

1611 # in SQL (it doesn't play nice with fields that have no equality operator, 

1612 # like JSON types), we instead build our result set separately. 

1613 # 

1614 # This will be empty if we are only looking at one dag, in which case 

1615 # we can return the filtered TI query object directly. 

1616 result: set[TaskInstanceKey] = set() 

1617 

1618 # Do we want full objects, or just the primary columns? 

1619 if as_pk_tuple: 

1620 tis = session.query(TI.dag_id, TI.task_id, TI.run_id, TI.map_index) 

1621 else: 

1622 tis = session.query(TaskInstance) 

1623 tis = tis.join(TaskInstance.dag_run) 

1624 

1625 if include_subdags: 

1626 # Crafting the right filter for dag_id and task_ids combo 

1627 conditions = [] 

1628 for dag in self.subdags + [self]: 

1629 conditions.append( 

1630 (TaskInstance.dag_id == dag.dag_id) & TaskInstance.task_id.in_(dag.task_ids) 

1631 ) 

1632 tis = tis.filter(or_(*conditions)) 

1633 elif self.partial: 

1634 tis = tis.filter(TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids)) 

1635 else: 

1636 tis = tis.filter(TaskInstance.dag_id == self.dag_id) 

1637 if run_id: 

1638 tis = tis.filter(TaskInstance.run_id == run_id) 

1639 if start_date: 

1640 tis = tis.filter(DagRun.execution_date >= start_date) 

1641 if task_ids is not None: 

1642 tis = tis.filter(TaskInstance.ti_selector_condition(task_ids)) 

1643 

1644 # This allows the allow_trigger_in_future config to take effect, rather than mandating exec_date <= UTC 

1645 if end_date or not self.allow_future_exec_dates: 

1646 end_date = end_date or timezone.utcnow() 

1647 tis = tis.filter(DagRun.execution_date <= end_date) 

1648 

1649 if state: 

1650 if isinstance(state, (str, TaskInstanceState)): 

1651 tis = tis.filter(TaskInstance.state == state) 

1652 elif len(state) == 1: 

1653 tis = tis.filter(TaskInstance.state == state[0]) 

1654 else: 

1655 # this is required to deal with NULL values 

1656 if None in state: 

1657 if all(x is None for x in state): 

1658 tis = tis.filter(TaskInstance.state.is_(None)) 

1659 else: 

1660 not_none_state = [s for s in state if s] 

1661 tis = tis.filter( 

1662 or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None)) 

1663 ) 

1664 else: 

1665 tis = tis.filter(TaskInstance.state.in_(state)) 

1666 

1667 # Next, get any of them from our parent DAG (if there is one) 

1668 if include_parentdag and self.parent_dag is not None: 

1669 

1670 if visited_external_tis is None: 

1671 visited_external_tis = set() 

1672 

1673 p_dag = self.parent_dag.partial_subset( 

1674 task_ids_or_regex=r"^{}$".format(self.dag_id.split(".")[1]), 

1675 include_upstream=False, 

1676 include_downstream=True, 

1677 ) 

1678 result.update( 

1679 p_dag._get_task_instances( 

1680 task_ids=task_ids, 

1681 start_date=start_date, 

1682 end_date=end_date, 

1683 run_id=None, 

1684 state=state, 

1685 include_subdags=include_subdags, 

1686 include_parentdag=False, 

1687 include_dependent_dags=include_dependent_dags, 

1688 as_pk_tuple=True, 

1689 exclude_task_ids=exclude_task_ids, 

1690 session=session, 

1691 dag_bag=dag_bag, 

1692 recursion_depth=recursion_depth, 

1693 max_recursion_depth=max_recursion_depth, 

1694 visited_external_tis=visited_external_tis, 

1695 ) 

1696 ) 

1697 

1698 if include_dependent_dags: 

1699 # Recursively find external tasks indicated by ExternalTaskMarker 

1700 from airflow.sensors.external_task import ExternalTaskMarker 

1701 

1702 query = tis 

1703 if as_pk_tuple: 

1704 condition = TI.filter_for_tis(TaskInstanceKey(*cols) for cols in tis.all()) 

1705 if condition is not None: 

1706 query = session.query(TI).filter(condition) 

1707 

1708 if visited_external_tis is None: 

1709 visited_external_tis = set() 

1710 

1711 for ti in query.filter(TI.operator == ExternalTaskMarker.__name__): 

1712 ti_key = ti.key.primary 

1713 if ti_key in visited_external_tis: 

1714 continue 

1715 

1716 visited_external_tis.add(ti_key) 

1717 

1718 task: ExternalTaskMarker = cast(ExternalTaskMarker, copy.copy(self.get_task(ti.task_id))) 

1719 ti.task = task 

1720 

1721 if max_recursion_depth is None: 

1722 # Maximum recursion depth allowed is the recursion_depth of the first 

1723 # ExternalTaskMarker in the tasks to be visited. 

1724 max_recursion_depth = task.recursion_depth 

1725 

1726 if recursion_depth + 1 > max_recursion_depth: 

1727 # Prevent cycles or accidents. 

1728 raise AirflowException( 

1729 f"Maximum recursion depth {max_recursion_depth} reached for " 

1730 f"{ExternalTaskMarker.__name__} {ti.task_id}. " 

1731 f"Attempted to clear too many tasks or there may be a cyclic dependency." 

1732 ) 

1733 ti.render_templates() 

1734 external_tis = ( 

1735 session.query(TI) 

1736 .join(TI.dag_run) 

1737 .filter( 

1738 TI.dag_id == task.external_dag_id, 

1739 TI.task_id == task.external_task_id, 

1740 DagRun.execution_date == pendulum.parse(task.execution_date), 

1741 ) 

1742 ) 

1743 

1744 for tii in external_tis: 

1745 if not dag_bag: 

1746 from airflow.models.dagbag import DagBag 

1747 

1748 dag_bag = DagBag(read_dags_from_db=True) 

1749 external_dag = dag_bag.get_dag(tii.dag_id, session=session) 

1750 if not external_dag: 

1751 raise AirflowException(f"Could not find dag {tii.dag_id}") 

1752 downstream = external_dag.partial_subset( 

1753 task_ids_or_regex=[tii.task_id], 

1754 include_upstream=False, 

1755 include_downstream=True, 

1756 ) 

1757 result.update( 

1758 downstream._get_task_instances( 

1759 task_ids=None, 

1760 run_id=tii.run_id, 

1761 start_date=None, 

1762 end_date=None, 

1763 state=state, 

1764 include_subdags=include_subdags, 

1765 include_dependent_dags=include_dependent_dags, 

1766 include_parentdag=False, 

1767 as_pk_tuple=True, 

1768 exclude_task_ids=exclude_task_ids, 

1769 dag_bag=dag_bag, 

1770 session=session, 

1771 recursion_depth=recursion_depth + 1, 

1772 max_recursion_depth=max_recursion_depth, 

1773 visited_external_tis=visited_external_tis, 

1774 ) 

1775 ) 

1776 

1777 if result or as_pk_tuple: 

1778 # Only execute the `ti` query if we have also collected some other results (i.e. subdags etc.) 

1779 if as_pk_tuple: 

1780 result.update(TaskInstanceKey(**cols._mapping) for cols in tis.all()) 

1781 else: 

1782 result.update(ti.key for ti in tis) 

1783 

1784 if exclude_task_ids is not None: 

1785 result = { 

1786 task 

1787 for task in result 

1788 if task.task_id not in exclude_task_ids 

1789 and (task.task_id, task.map_index) not in exclude_task_ids 

1790 } 

1791 

1792 if as_pk_tuple: 

1793 return result 

1794 if result: 

1795 # We've been asked for objects, so let's combine it all back into a result set 

1796 ti_filters = TI.filter_for_tis(result) 

1797 if ti_filters is not None: 

1798 tis = session.query(TI).filter(ti_filters) 

1799 elif exclude_task_ids is None: 

1800 pass # Disable filter if not set. 

1801 elif isinstance(next(iter(exclude_task_ids), None), str): 

1802 tis = tis.filter(TI.task_id.notin_(exclude_task_ids)) 

1803 else: 

1804 tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids))) 

1805 

1806 return tis 

1807 

1808 @provide_session 

1809 def set_task_instance_state( 

1810 self, 

1811 *, 

1812 task_id: str, 

1813 map_indexes: Collection[int] | None = None, 

1814 execution_date: datetime | None = None, 

1815 run_id: str | None = None, 

1816 state: TaskInstanceState, 

1817 upstream: bool = False, 

1818 downstream: bool = False, 

1819 future: bool = False, 

1820 past: bool = False, 

1821 commit: bool = True, 

1822 session=NEW_SESSION, 

1823 ) -> list[TaskInstance]: 

1824 """ 

1825 Set the state of a TaskInstance to the given state, and clear its downstream tasks that are 

1826 in failed or upstream_failed state. 

1827 

1828 :param task_id: Task ID of the TaskInstance 

1829 :param map_indexes: Only set TaskInstance if its map_index matches. 

1830 If None (default), all mapped TaskInstances of the task are set. 

1831 :param execution_date: Execution date of the TaskInstance 

1832 :param run_id: The run_id of the TaskInstance 

1833 :param state: State to set the TaskInstance to 

1834 :param upstream: Include all upstream tasks of the given task_id 

1835 :param downstream: Include all downstream tasks of the given task_id 

1836 :param future: Include all future TaskInstances of the given task_id 

1837 :param commit: Commit changes 

1838 :param past: Include all past TaskInstances of the given task_id 

1839 """ 

1840 from airflow.api.common.mark_tasks import set_state 

1841 

1842 if not exactly_one(execution_date, run_id): 

1843 raise ValueError("Exactly one of execution_date or run_id must be provided") 

1844 

1845 task = self.get_task(task_id) 

1846 task.dag = self 

1847 

1848 tasks_to_set_state: list[Operator | tuple[Operator, int]] 

1849 if map_indexes is None: 

1850 tasks_to_set_state = [task] 

1851 else: 

1852 tasks_to_set_state = [(task, map_index) for map_index in map_indexes] 

1853 

1854 altered = set_state( 

1855 tasks=tasks_to_set_state, 

1856 execution_date=execution_date, 

1857 run_id=run_id, 

1858 upstream=upstream, 

1859 downstream=downstream, 

1860 future=future, 

1861 past=past, 

1862 state=state, 

1863 commit=commit, 

1864 session=session, 

1865 ) 

1866 

1867 if not commit: 

1868 return altered 

1869 

1870 # Clear downstream tasks that are in failed/upstream_failed state to resume them. 

1871 # Flush the session so that the tasks marked success are reflected in the db. 

1872 session.flush() 

1873 subdag = self.partial_subset( 

1874 task_ids_or_regex={task_id}, 

1875 include_downstream=True, 

1876 include_upstream=False, 

1877 ) 

1878 

1879 if execution_date is None: 

1880 dag_run = ( 

1881 session.query(DagRun).filter(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id).one() 

1882 ) # Raises an error if not found 

1883 resolve_execution_date = dag_run.execution_date 

1884 else: 

1885 resolve_execution_date = execution_date 

1886 

1887 end_date = resolve_execution_date if not future else None 

1888 start_date = resolve_execution_date if not past else None 

1889 

1890 subdag.clear( 

1891 start_date=start_date, 

1892 end_date=end_date, 

1893 include_subdags=True, 

1894 include_parentdag=True, 

1895 only_failed=True, 

1896 session=session, 

1897 # Exclude the task itself from being cleared 

1898 exclude_task_ids={task_id}, 

1899 ) 

1900 

1901 return altered 

1902 
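A sketch of marking one task instance successful for a given run and letting the method clear its failed/upstream_failed downstream tasks; the task id and run_id are illustrative and must exist in the metadata database:

    import pendulum
    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator
    from airflow.utils.state import TaskInstanceState

    with DAG(dag_id="example_set_state", start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), schedule=None) as dag:
        EmptyOperator(task_id="load")

    altered = dag.set_task_instance_state(
        task_id="load",
        run_id="scheduled__2022-12-24T00:00:00+00:00",  # illustrative; must exist in the DB
        state=TaskInstanceState.SUCCESS,
    )
    print(f"{len(altered)} task instance(s) updated")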

1903 @property 

1904 def roots(self) -> list[Operator]: 

1905 """Return nodes with no parents. These are first to execute and are called roots or root nodes.""" 

1906 return [task for task in self.tasks if not task.upstream_list] 

1907 

1908 @property 

1909 def leaves(self) -> list[Operator]: 

1910 """Return nodes with no children. These are last to execute and are called leaves or leaf nodes.""" 

1911 return [task for task in self.tasks if not task.downstream_list] 

1912 

1913 def topological_sort(self, include_subdag_tasks: bool = False): 

1914 """ 

1915 Sorts tasks in topological order, such that a task comes after any of its 

1916 upstream dependencies. 

1917 

1918 Deprecated in favor of ``task_group.topological_sort`` 

1919 """ 

1920 from airflow.utils.task_group import TaskGroup 

1921 

1922 def nested_topo(group): 

1923 for node in group.topological_sort(_include_subdag_tasks=include_subdag_tasks): 

1924 if isinstance(node, TaskGroup): 

1925 yield from nested_topo(node) 

1926 else: 

1927 yield node 

1928 

1929 return tuple(nested_topo(self.task_group)) 

1930 
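A sketch of the (deprecated) DAG-level sort; ``dag.task_group.topological_sort()`` is the preferred replacement, and the task ids are illustrative:

    import pendulum
    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG(dag_id="example_topo", start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), schedule=None) as dag:
        EmptyOperator(task_id="a") >> EmptyOperator(task_id="b") >> EmptyOperator(task_id="c")

    print([t.task_id for t in dag.topological_sort()])  # ['a', 'b', 'c']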

1931 @provide_session 

1932 def set_dag_runs_state( 

1933 self, 

1934 state: str = State.RUNNING, 

1935 session: Session = NEW_SESSION, 

1936 start_date: datetime | None = None, 

1937 end_date: datetime | None = None, 

1938 dag_ids: list[str] = [], 

1939 ) -> None: 

1940 warnings.warn( 

1941 "This method is deprecated and will be removed in a future version.", 

1942 RemovedInAirflow3Warning, 

1943 stacklevel=3, 

1944 ) 

1945 dag_ids = dag_ids or [self.dag_id] 

1946 query = session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids)) 

1947 if start_date: 

1948 query = query.filter(DagRun.execution_date >= start_date) 

1949 if end_date: 

1950 query = query.filter(DagRun.execution_date <= end_date) 

1951 query.update({DagRun.state: state}, synchronize_session="fetch") 

1952 

1953 @provide_session 

1954 def clear( 

1955 self, 

1956 task_ids: Collection[str | tuple[str, int]] | None = None, 

1957 start_date: datetime | None = None, 

1958 end_date: datetime | None = None, 

1959 only_failed: bool = False, 

1960 only_running: bool = False, 

1961 confirm_prompt: bool = False, 

1962 include_subdags: bool = True, 

1963 include_parentdag: bool = True, 

1964 dag_run_state: DagRunState = DagRunState.QUEUED, 

1965 dry_run: bool = False, 

1966 session: Session = NEW_SESSION, 

1967 get_tis: bool = False, 

1968 recursion_depth: int = 0, 

1969 max_recursion_depth: int | None = None, 

1970 dag_bag: DagBag | None = None, 

1971 exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), 

1972 ) -> int | Iterable[TaskInstance]: 

1973 """ 

1974 Clears a set of task instances associated with the current dag for 

1975 a specified date range. 

1976 

1977 :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear 

1978 :param start_date: The minimum execution_date to clear 

1979 :param end_date: The maximum execution_date to clear 

1980 :param only_failed: Only clear failed tasks 

1981 :param only_running: Only clear running tasks. 

1982 :param confirm_prompt: Ask for confirmation 

1983 :param include_subdags: Clear tasks in subdags and clear external tasks 

1984 indicated by ExternalTaskMarker 

1985 :param include_parentdag: Clear tasks in the parent dag of the subdag. 

1986 :param dag_run_state: state to set DagRun to. If set to False, dagrun state will not 

1987 be changed. 

1988 :param dry_run: Find the tasks to clear but don't clear them. 

1989 :param session: The sqlalchemy session to use 

1990 :param dag_bag: The DagBag used to find the dags subdags (Optional) 

1991 :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``) 

1992 tuples that should not be cleared 

1993 """ 

1994 if get_tis: 

1995 warnings.warn( 

1996 "Passing `get_tis` to dag.clear() is deprecated. Use `dry_run` parameter instead.", 

1997 RemovedInAirflow3Warning, 

1998 stacklevel=2, 

1999 ) 

2000 dry_run = True 

2001 

2002 if recursion_depth: 

2003 warnings.warn( 

2004 "Passing `recursion_depth` to dag.clear() is deprecated.", 

2005 RemovedInAirflow3Warning, 

2006 stacklevel=2, 

2007 ) 

2008 if max_recursion_depth: 

2009 warnings.warn( 

2010 "Passing `max_recursion_depth` to dag.clear() is deprecated.", 

2011 RemovedInAirflow3Warning, 

2012 stacklevel=2, 

2013 ) 

2014 

2015 state = [] 

2016 if only_failed: 

2017 state += [State.FAILED, State.UPSTREAM_FAILED] 

2018 if only_running: 

2019 # Yes, having `+=` doesn't make sense, but this was the existing behaviour 

2020 state += [State.RUNNING] 

2021 

2022 tis = self._get_task_instances( 

2023 task_ids=task_ids, 

2024 start_date=start_date, 

2025 end_date=end_date, 

2026 run_id=None, 

2027 state=state, 

2028 include_subdags=include_subdags, 

2029 include_parentdag=include_parentdag, 

2030 include_dependent_dags=include_subdags, # compat, yes this is not a typo 

2031 session=session, 

2032 dag_bag=dag_bag, 

2033 exclude_task_ids=exclude_task_ids, 

2034 ) 

2035 

2036 if dry_run: 

2037 return tis 

2038 

2039 tis = list(tis) 

2040 

2041 count = len(tis) 

2042 do_it = True 

2043 if count == 0: 

2044 return 0 

2045 if confirm_prompt: 

2046 ti_list = "\n".join(str(t) for t in tis) 

2047 question = ( 

2048 "You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]" 

2049 ).format(count=count, ti_list=ti_list) 

2050 do_it = utils.helpers.ask_yesno(question) 

2051 

2052 if do_it: 

2053 clear_task_instances( 

2054 tis, 

2055 session, 

2056 dag=self, 

2057 dag_run_state=dag_run_state, 

2058 ) 

2059 else: 

2060 count = 0 

2061 print("Cancelled, nothing was cleared.") 

2062 

2063 session.flush() 

2064 return count 

2065 
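A sketch of previewing what ``clear`` would touch with ``dry_run=True`` before actually clearing; the dates are illustrative and a metadata database is required:

    import pendulum
    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG(dag_id="example_clear", start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), schedule=None) as dag:
        EmptyOperator(task_id="noop")

    window_start = pendulum.datetime(2022, 12, 20, tz="UTC")
    window_end = pendulum.datetime(2022, 12, 24, tz="UTC")
    # Preview which failed task instances would be cleared, without touching them.
    tis = list(dag.clear(start_date=window_start, end_date=window_end, only_failed=True, dry_run=True))
    print(f"{len(tis)} failed task instance(s) would be cleared")
    # Actually clear them; the return value is the number of task instances cleared.
    cleared = dag.clear(start_date=window_start, end_date=window_end, only_failed=True)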

2066 @classmethod 

2067 def clear_dags( 

2068 cls, 

2069 dags, 

2070 start_date=None, 

2071 end_date=None, 

2072 only_failed=False, 

2073 only_running=False, 

2074 confirm_prompt=False, 

2075 include_subdags=True, 

2076 include_parentdag=False, 

2077 dag_run_state=DagRunState.QUEUED, 

2078 dry_run=False, 

2079 ): 

2080 all_tis = [] 

2081 for dag in dags: 

2082 tis = dag.clear( 

2083 start_date=start_date, 

2084 end_date=end_date, 

2085 only_failed=only_failed, 

2086 only_running=only_running, 

2087 confirm_prompt=False, 

2088 include_subdags=include_subdags, 

2089 include_parentdag=include_parentdag, 

2090 dag_run_state=dag_run_state, 

2091 dry_run=True, 

2092 ) 

2093 all_tis.extend(tis) 

2094 

2095 if dry_run: 

2096 return all_tis 

2097 

2098 count = len(all_tis) 

2099 do_it = True 

2100 if count == 0: 

2101 print("Nothing to clear.") 

2102 return 0 

2103 if confirm_prompt: 

2104 ti_list = "\n".join(str(t) for t in all_tis) 

2105 question = f"You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]" 

2106 do_it = utils.helpers.ask_yesno(question) 

2107 

2108 if do_it: 

2109 for dag in dags: 

2110 dag.clear( 

2111 start_date=start_date, 

2112 end_date=end_date, 

2113 only_failed=only_failed, 

2114 only_running=only_running, 

2115 confirm_prompt=False, 

2116 include_subdags=include_subdags, 

2117 dag_run_state=dag_run_state, 

2118 dry_run=False, 

2119 ) 

2120 else: 

2121 count = 0 

2122 print("Cancelled, nothing was cleared.") 

2123 return count 

2124 

2125 def __deepcopy__(self, memo): 

2126 # Switcharoo to go around deepcopying objects coming through the 

2127 # backdoor 

2128 cls = self.__class__ 

2129 result = cls.__new__(cls) 

2130 memo[id(self)] = result 

2131 for k, v in self.__dict__.items(): 

2132 if k not in ("user_defined_macros", "user_defined_filters", "_log"): 

2133 setattr(result, k, copy.deepcopy(v, memo)) 

2134 

2135 result.user_defined_macros = self.user_defined_macros 

2136 result.user_defined_filters = self.user_defined_filters 

2137 if hasattr(self, "_log"): 

2138 result._log = self._log 

2139 return result 

2140 

2141 def sub_dag(self, *args, **kwargs): 

2142 """This method is deprecated in favor of partial_subset""" 

2143 warnings.warn( 

2144 "This method is deprecated and will be removed in a future version. Please use partial_subset", 

2145 RemovedInAirflow3Warning, 

2146 stacklevel=2, 

2147 ) 

2148 return self.partial_subset(*args, **kwargs) 

2149 

2150 def partial_subset( 

2151 self, 

2152 task_ids_or_regex: str | re.Pattern | Iterable[str], 

2153 include_downstream=False, 

2154 include_upstream=True, 

2155 include_direct_upstream=False, 

2156 ): 

2157 """ 

2158 Returns a subset of the current dag as a deep copy of the current dag 

2159 based on a regex that should match one or many tasks, and includes 

2160 upstream and downstream neighbours based on the flag passed. 

2161 

2162 :param task_ids_or_regex: Either a list of task_ids, or a regex to 

2163 match against task ids (as a string, or compiled regex pattern). 

2164 :param include_downstream: Include all downstream tasks of matched 

2165 tasks, in addition to matched tasks. 

2166 :param include_upstream: Include all upstream tasks of matched tasks, 

2167 in addition to matched tasks. 

2168 :param include_direct_upstream: Include all tasks directly upstream of matched 

2169 and downstream (if include_downstream = True) tasks 

2170 """ 

2171 from airflow.models.baseoperator import BaseOperator 

2172 from airflow.models.mappedoperator import MappedOperator 

2173 

2174 # deep-copying self.task_dict and self._task_group takes a long time, and we don't want all 

2175 # the tasks anyway, so we copy the tasks manually later 

2176 memo = {id(self.task_dict): None, id(self._task_group): None} 

2177 dag = copy.deepcopy(self, memo) # type: ignore 

2178 

2179 if isinstance(task_ids_or_regex, (str, re.Pattern)): 

2180 matched_tasks = [t for t in self.tasks if re.findall(task_ids_or_regex, t.task_id)] 

2181 else: 

2182 matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex] 

2183 

2184 also_include: list[Operator] = [] 

2185 for t in matched_tasks: 

2186 if include_downstream: 

2187 also_include.extend(t.get_flat_relatives(upstream=False)) 

2188 if include_upstream: 

2189 also_include.extend(t.get_flat_relatives(upstream=True)) 

2190 

2191 direct_upstreams: list[Operator] = [] 

2192 if include_direct_upstream: 

2193 for t in itertools.chain(matched_tasks, also_include): 

2194 upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator))) 

2195 direct_upstreams.extend(upstream) 

2196 

2197 # Compiling the unique list of tasks that made the cut 

2198 # Make sure to not recursively deepcopy the dag or task_group while copying the task. 

2199 # task_group is reset later 

2200 def _deepcopy_task(t) -> Operator: 

2201 memo.setdefault(id(t.task_group), None) 

2202 return copy.deepcopy(t, memo) 

2203 

2204 dag.task_dict = { 

2205 t.task_id: _deepcopy_task(t) 

2206 for t in itertools.chain(matched_tasks, also_include, direct_upstreams) 

2207 } 

2208 

2209 def filter_task_group(group, parent_group): 

2210 """Exclude tasks not included in the subdag from the given TaskGroup.""" 

2211 copied = copy.copy(group) 

2212 copied.used_group_ids = set(copied.used_group_ids) 

2213 copied._parent_group = parent_group 

2214 

2215 copied.children = {} 

2216 

2217 for child in group.children.values(): 

2218 if isinstance(child, AbstractOperator): 

2219 if child.task_id in dag.task_dict: 

2220 task = copied.children[child.task_id] = dag.task_dict[child.task_id] 

2221 task.task_group = weakref.proxy(copied) 

2222 else: 

2223 copied.used_group_ids.discard(child.task_id) 

2224 else: 

2225 filtered_child = filter_task_group(child, copied) 

2226 

2227 # Only include this child TaskGroup if it is non-empty. 

2228 if filtered_child.children: 

2229 copied.children[child.group_id] = filtered_child 

2230 

2231 return copied 

2232 

2233 dag._task_group = filter_task_group(self._task_group, None) 

2234 

2235 # Removing upstream/downstream references to tasks and TaskGroups that did not make 

2236 # the cut. 

2237 subdag_task_groups = dag.task_group.get_task_group_dict() 

2238 for group in subdag_task_groups.values(): 

2239 group.upstream_group_ids.intersection_update(subdag_task_groups) 

2240 group.downstream_group_ids.intersection_update(subdag_task_groups) 

2241 group.upstream_task_ids.intersection_update(dag.task_dict) 

2242 group.downstream_task_ids.intersection_update(dag.task_dict) 

2243 

2244 for t in dag.tasks: 

2245 # Removing upstream/downstream references to tasks that did not 

2246 # make the cut 

2247 t.upstream_task_ids.intersection_update(dag.task_dict) 

2248 t.downstream_task_ids.intersection_update(dag.task_dict) 

2249 

2250 if len(dag.tasks) < len(self.tasks): 

2251 dag.partial = True 

2252 

2253 return dag 

2254 
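A sketch of extracting a sub-DAG containing one task plus everything downstream of it; the task ids are illustrative:

    import pendulum
    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG(dag_id="example_subset", start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), schedule=None) as dag:
        EmptyOperator(task_id="extract") >> EmptyOperator(task_id="transform") >> EmptyOperator(task_id="load")

    subset = dag.partial_subset(task_ids_or_regex={"transform"}, include_downstream=True, include_upstream=False)
    print(sorted(subset.task_dict))  # ['load', 'transform']
    print(subset.partial)            # True, since some tasks were dropped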

2255 def has_task(self, task_id: str): 

2256 return task_id in self.task_dict 

2257 

2258 def has_task_group(self, task_group_id: str) -> bool: 

2259 return task_group_id in self.task_group_dict 

2260 

2261 @cached_property 

2262 def task_group_dict(self): 

2263 return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None} 

2264 

2265 def get_task(self, task_id: str, include_subdags: bool = False) -> Operator: 

2266 if task_id in self.task_dict: 

2267 return self.task_dict[task_id] 

2268 if include_subdags: 

2269 for dag in self.subdags: 

2270 if task_id in dag.task_dict: 

2271 return dag.task_dict[task_id] 

2272 raise TaskNotFound(f"Task {task_id} not found") 

2273 

2274 def pickle_info(self): 

2275 d = {} 

2276 d["is_picklable"] = True 

2277 try: 

2278 dttm = timezone.utcnow() 

2279 pickled = pickle.dumps(self) 

2280 d["pickle_len"] = len(pickled) 

2281 d["pickling_duration"] = str(timezone.utcnow() - dttm) 

2282 except Exception as e: 

2283 self.log.debug(e) 

2284 d["is_picklable"] = False 

2285 d["stacktrace"] = traceback.format_exc() 

2286 return d 

2287 

2288 @provide_session 

2289 def pickle(self, session=NEW_SESSION) -> DagPickle: 

2290 dag = session.query(DagModel).filter(DagModel.dag_id == self.dag_id).first() 

2291 dp = None 

2292 if dag and dag.pickle_id: 

2293 dp = session.query(DagPickle).filter(DagPickle.id == dag.pickle_id).first() 

2294 if not dp or dp.pickle != self: 

2295 dp = DagPickle(dag=self) 

2296 session.add(dp) 

2297 self.last_pickled = timezone.utcnow() 

2298 session.commit() 

2299 self.pickle_id = dp.id 

2300 

2301 return dp 

2302 

2303 def tree_view(self) -> None: 

2304 """Print an ASCII tree representation of the DAG.""" 

2305 

2306 def get_downstream(task, level=0): 

2307 print((" " * level * 4) + str(task)) 

2308 level += 1 

2309 for t in task.downstream_list: 

2310 get_downstream(t, level) 

2311 

2312 for t in self.roots: 

2313 get_downstream(t) 

2314 

2315 @property 

2316 def task(self) -> TaskDecoratorCollection: 

2317 from airflow.decorators import task 

2318 

2319 return cast("TaskDecoratorCollection", functools.partial(task, dag=self)) 

2320 
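A sketch of the ``dag.task`` property in use; it is simply ``airflow.decorators.task`` bound to this DAG, and the function name below is illustrative:

    import pendulum
    from airflow.models.dag import DAG

    dag = DAG(dag_id="example_dag_task", start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), schedule=None)

    @dag.task
    def say_hello() -> str:
        return "hello"

    say_hello()          # instantiates a "say_hello" task inside `dag`
    print(dag.task_ids)  # ['say_hello']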

2321 def add_task(self, task: Operator) -> None: 

2322 """ 

2323 Add a task to the DAG 

2324 

2325 :param task: the task you want to add 

2326 """ 

2327 from airflow.utils.task_group import TaskGroupContext 

2328 

2329 if not self.start_date and not task.start_date: 

2330 raise AirflowException("DAG is missing the start_date parameter") 

2331 # if the task has no start date, assign it the same as the DAG 

2332 elif not task.start_date: 

2333 task.start_date = self.start_date 

2334 # otherwise, the task will start on the later of its own start date and 

2335 # the DAG's start date 

2336 elif self.start_date: 

2337 task.start_date = max(task.start_date, self.start_date) 

2338 

2339 # if the task has no end date, assign it the same as the dag 

2340 if not task.end_date: 

2341 task.end_date = self.end_date 

2342 # otherwise, the task will end on the earlier of its own end date and 

2343 # the DAG's end date 

2344 elif task.end_date and self.end_date: 

2345 task.end_date = min(task.end_date, self.end_date) 

2346 

2347 task_id = task.task_id 

2348 if not task.task_group: 

2349 task_group = TaskGroupContext.get_current_task_group(self) 

2350 if task_group: 

2351 task_id = task_group.child_id(task_id) 

2352 task_group.add(task) 

2353 

2354 if ( 

2355 task_id in self.task_dict and self.task_dict[task_id] is not task 

2356 ) or task_id in self._task_group.used_group_ids: 

2357 raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG") 

2358 else: 

2359 self.task_dict[task_id] = task 

2360 task.dag = self 

2361 # Add task_id to used_group_ids to prevent group_id and task_id collisions. 

2362 self._task_group.used_group_ids.add(task_id) 

2363 

2364 self.task_count = len(self.task_dict) 

2365 

2366 def add_tasks(self, tasks: Iterable[Operator]) -> None: 

2367 """ 

2368 Add a list of tasks to the DAG 

2369 

2370 :param tasks: a list of tasks you want to add 

2371 """ 

2372 for task in tasks: 

2373 self.add_task(task) 

2374 

2375 def _remove_task(self, task_id: str) -> None: 

2376 # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this 

2377 # doesn't guard against that 

2378 task = self.task_dict.pop(task_id) 

2379 tg = getattr(task, "task_group", None) 

2380 if tg: 

2381 tg._remove(task) 

2382 

2383 self.task_count = len(self.task_dict) 

2384 

2385 def run( 

2386 self, 

2387 start_date=None, 

2388 end_date=None, 

2389 mark_success=False, 

2390 local=False, 

2391 executor=None, 

2392 donot_pickle=conf.getboolean("core", "donot_pickle"), 

2393 ignore_task_deps=False, 

2394 ignore_first_depends_on_past=True, 

2395 pool=None, 

2396 delay_on_limit_secs=1.0, 

2397 verbose=False, 

2398 conf=None, 

2399 rerun_failed_tasks=False, 

2400 run_backwards=False, 

2401 run_at_least_once=False, 

2402 continue_on_failures=False, 

2403 disable_retry=False, 

2404 ): 

2405 """ 

2406 Runs the DAG. 

2407 

2408 :param start_date: the start date of the range to run 

2409 :param end_date: the end date of the range to run 

2410 :param mark_success: True to mark jobs as succeeded without running them 

2411 :param local: True to run the tasks using the LocalExecutor 

2412 :param executor: The executor instance to run the tasks 

2413 :param donot_pickle: True to avoid pickling DAG object and send to workers 

2414 :param ignore_task_deps: True to skip upstream tasks 

2415 :param ignore_first_depends_on_past: True to ignore depends_on_past 

2416 dependencies for the first set of tasks only 

2417 :param pool: Resource pool to use 

2418 :param delay_on_limit_secs: Time in seconds to wait before next attempt to run 

2419 dag run when max_active_runs limit has been reached 

2420 :param verbose: Make logging output more verbose 

2421 :param conf: user defined dictionary passed from CLI 

2422 :param rerun_failed_tasks: 

2423 :param run_backwards: 

2424 :param run_at_least_once: If true, always run the DAG at least once even 

2425 if no logical run exists within the time range. 

2426 """ 

2427 from airflow.jobs.backfill_job import BackfillJob 

2428 

2429 if not executor and local: 

2430 from airflow.executors.local_executor import LocalExecutor 

2431 

2432 executor = LocalExecutor() 

2433 elif not executor: 

2434 from airflow.executors.executor_loader import ExecutorLoader 

2435 

2436 executor = ExecutorLoader.get_default_executor() 

2437 job = BackfillJob( 

2438 self, 

2439 start_date=start_date, 

2440 end_date=end_date, 

2441 mark_success=mark_success, 

2442 executor=executor, 

2443 donot_pickle=donot_pickle, 

2444 ignore_task_deps=ignore_task_deps, 

2445 ignore_first_depends_on_past=ignore_first_depends_on_past, 

2446 pool=pool, 

2447 delay_on_limit_secs=delay_on_limit_secs, 

2448 verbose=verbose, 

2449 conf=conf, 

2450 rerun_failed_tasks=rerun_failed_tasks, 

2451 run_backwards=run_backwards, 

2452 run_at_least_once=run_at_least_once, 

2453 continue_on_failures=continue_on_failures, 

2454 disable_retry=disable_retry, 

2455 ) 

2456 job.run() 

2457 

2458 def cli(self): 

2459 """Exposes a CLI specific to this DAG""" 

2460 check_cycle(self) 

2461 

2462 from airflow.cli import cli_parser 

2463 

2464 parser = cli_parser.get_parser(dag_parser=True) 

2465 args = parser.parse_args() 

2466 args.func(args, self) 

2467 

2468 @provide_session 

2469 def test( 

2470 self, 

2471 execution_date: datetime | None = None, 

2472 run_conf: dict[str, Any] | None = None, 

2473 conn_file_path: str | None = None, 

2474 variable_file_path: str | None = None, 

2475 session: Session = NEW_SESSION, 

2476 ) -> None: 

2477 """ 

2478 Execute a single DagRun for a given DAG and execution date. 

2479 

2480 :param execution_date: execution date for the DAG run 

2481 :param run_conf: configuration to pass to newly created dagrun 

2482 :param conn_file_path: file path to a connection file in either yaml or json 

2483 :param variable_file_path: file path to a variable file in either yaml or json 

2484 :param session: database connection (optional) 

2485 """ 

2486 

2487 def add_logger_if_needed(ti: TaskInstance): 

2488 """ 

2489 Add a formatted logger to the taskinstance so all logs are surfaced to the command line instead 

2490 of into a task file. Since this is a local test run, it is much better for the user to see logs 

2491 in the command line, rather than needing to search for a log file. 

2492 Args: 

2493 ti: The taskinstance that will receive a logger 

2494 

2495 """ 

2496 format = logging.Formatter("[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s") 

2497 handler = logging.StreamHandler(sys.stdout) 

2498 handler.level = logging.INFO 

2499 handler.setFormatter(format) 

2500 # only add log handler once 

2501 if not any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers): 

2502 self.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id) 

2503 ti.log.addHandler(handler) 

2504 

2505 if conn_file_path or variable_file_path: 

2506 local_secrets = LocalFilesystemBackend( 

2507 variables_file_path=variable_file_path, connections_file_path=conn_file_path 

2508 ) 

2509 secrets_backend_list.insert(0, local_secrets) 

2510 

2511 execution_date = execution_date or timezone.utcnow() 

2512 self.log.debug("Clearing existing task instances for execution date %s", execution_date) 

2513 self.clear( 

2514 start_date=execution_date, 

2515 end_date=execution_date, 

2516 dag_run_state=False, # type: ignore 

2517 session=session, 

2518 ) 

2519 self.log.debug("Getting dagrun for dag %s", self.dag_id) 

2520 dr: DagRun = _get_or_create_dagrun( 

2521 dag=self, 

2522 start_date=execution_date, 

2523 execution_date=execution_date, 

2524 run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date), 

2525 session=session, 

2526 conf=run_conf, 

2527 ) 

2528 

2529 tasks = self.task_dict 

2530 self.log.debug("starting dagrun") 

2531 # Instead of starting a scheduler, we run the minimal loop possible to check 

2532 # for task readiness and dependency management. This is notably faster 

2533 # than creating a BackfillJob and allows us to surface logs to the user 

2534 while dr.state == State.RUNNING: 

2535 schedulable_tis, _ = dr.update_state(session=session) 

2536 for ti in schedulable_tis: 

2537 add_logger_if_needed(ti) 

2538 ti.task = tasks[ti.task_id] 

2539 _run_task(ti, session=session) 

2540 if conn_file_path or variable_file_path: 

2541 # Remove the local variables we have added to the secrets_backend_list 

2542 secrets_backend_list.pop(0) 

2543 
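A sketch of running a DAG file directly through ``dag.test()``; it executes one manual run in-process and assumes a local metadata database (e.g. created with ``airflow db init``):

    import pendulum
    from airflow.decorators import task
    from airflow.models.dag import DAG

    with DAG(dag_id="example_dag_test", start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), schedule=None) as dag:

        @task
        def hello():
            print("hello from dag.test()")

        hello()

    if __name__ == "__main__":
        dag.test()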

2544 @provide_session 

2545 def create_dagrun( 

2546 self, 

2547 state: DagRunState, 

2548 execution_date: datetime | None = None, 

2549 run_id: str | None = None, 

2550 start_date: datetime | None = None, 

2551 external_trigger: bool | None = False, 

2552 conf: dict | None = None, 

2553 run_type: DagRunType | None = None, 

2554 session: Session = NEW_SESSION, 

2555 dag_hash: str | None = None, 

2556 creating_job_id: int | None = None, 

2557 data_interval: tuple[datetime, datetime] | None = None, 

2558 ): 

2559 """ 

2560 Creates a dag run from this dag including the tasks associated with this dag. 

2561 Returns the dag run. 

2562 

2563 :param run_id: defines the run id for this dag run 

2564 :param run_type: type of DagRun 

2565 :param execution_date: the execution date of this dag run 

2566 :param state: the state of the dag run 

2567 :param start_date: the date this dag run should be evaluated 

2568 :param external_trigger: whether this dag run is externally triggered 

2569 :param conf: Dict containing configuration/parameters to pass to the DAG 

2570 :param creating_job_id: id of the job creating this DagRun 

2571 :param session: database session 

2572 :param dag_hash: Hash of Serialized DAG 

2573 :param data_interval: Data interval of the DagRun 

2574 """ 

2575 logical_date = timezone.coerce_datetime(execution_date) 

2576 

2577 if data_interval and not isinstance(data_interval, DataInterval): 

2578 data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval)) 

2579 

2580 if data_interval is None and logical_date is not None: 

2581 warnings.warn( 

2582 "Calling `DAG.create_dagrun()` without an explicit data interval is deprecated", 

2583 RemovedInAirflow3Warning, 

2584 stacklevel=3, 

2585 ) 

2586 if run_type == DagRunType.MANUAL: 

2587 data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date) 

2588 else: 

2589 data_interval = self.infer_automated_data_interval(logical_date) 

2590 

2591 if run_type is None or isinstance(run_type, DagRunType): 

2592 pass 

2593 elif isinstance(run_type, str): # Compatibility: run_type used to be a str. 

2594 run_type = DagRunType(run_type) 

2595 else: 

2596 raise ValueError(f"`run_type` should be a DagRunType, not {type(run_type)}") 

2597 

2598 if run_id: # Infer run_type from run_id if needed. 

2599 if not isinstance(run_id, str): 

2600 raise ValueError(f"`run_id` should be a str, not {type(run_id)}") 

2601 inferred_run_type = DagRunType.from_run_id(run_id) 

2602 if run_type is None: 

2603 # No explicit type given, use the inferred type. 

2604 run_type = inferred_run_type 

2605 elif run_type == DagRunType.MANUAL and inferred_run_type != DagRunType.MANUAL: 

2606 # Prevent a manual run from using an ID that looks like a scheduled run. 

2607 raise ValueError( 

2608 f"A {run_type.value} DAG run cannot use ID {run_id!r} since it " 

2609 f"is reserved for {inferred_run_type.value} runs" 

2610 ) 

2611 elif run_type and logical_date is not None: # Generate run_id from run_type and execution_date. 

2612 run_id = self.timetable.generate_run_id( 

2613 run_type=run_type, logical_date=logical_date, data_interval=data_interval 

2614 ) 

2615 else: 

2616 raise AirflowException( 

2617 "Creating DagRun needs either `run_id` or both `run_type` and `execution_date`" 

2618 ) 

2619 

2620 if run_id and "/" in run_id: 

2621 warnings.warn( 

2622 "Using forward slash ('/') in a DAG run ID is deprecated. Note that this character " 

2623 "also makes the run impossible to retrieve via Airflow's REST API.", 

2624 RemovedInAirflow3Warning, 

2625 stacklevel=3, 

2626 ) 

2627 

2628 # create a copy of params before validating 

2629 copied_params = copy.deepcopy(self.params) 

2630 copied_params.update(conf or {}) 

2631 copied_params.validate() 

2632 

2633 run = DagRun( 

2634 dag_id=self.dag_id, 

2635 run_id=run_id, 

2636 execution_date=logical_date, 

2637 start_date=start_date, 

2638 external_trigger=external_trigger, 

2639 conf=conf, 

2640 state=state, 

2641 run_type=run_type, 

2642 dag_hash=dag_hash, 

2643 creating_job_id=creating_job_id, 

2644 data_interval=data_interval, 

2645 ) 

2646 session.add(run) 

2647 session.flush() 

2648 

2649 run.dag = self 

2650 

2651 # create the associated task instances 

2652 # state is None at the moment of creation 

2653 run.verify_integrity(session=session) 

2654 

2655 return run 

2656 
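A sketch of creating a queued manual run with an explicit data interval, which avoids the deprecation warning raised above when no interval is supplied; the dates are illustrative and a metadata database is required:

    import pendulum
    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator
    from airflow.utils.state import DagRunState
    from airflow.utils.types import DagRunType

    with DAG(dag_id="example_create_run", start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), schedule=None) as dag:
        EmptyOperator(task_id="noop")

    logical_date = pendulum.datetime(2022, 12, 24, tz="UTC")
    run = dag.create_dagrun(
        state=DagRunState.QUEUED,
        run_type=DagRunType.MANUAL,
        execution_date=logical_date,
        data_interval=(logical_date, logical_date),
        external_trigger=True,
    )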

2657 @classmethod 

2658 @provide_session 

2659 def bulk_sync_to_db( 

2660 cls, 

2661 dags: Collection[DAG], 

2662 session=NEW_SESSION, 

2663 ): 

2664 """This method is deprecated in favor of bulk_write_to_db""" 

2665 warnings.warn( 

2666 "This method is deprecated and will be removed in a future version. Please use bulk_write_to_db", 

2667 RemovedInAirflow3Warning, 

2668 stacklevel=2, 

2669 ) 

2670 return cls.bulk_write_to_db(dags=dags, session=session) 

2671 

2672 @classmethod 

2673 @provide_session 

2674 def bulk_write_to_db( 

2675 cls, 

2676 dags: Collection[DAG], 

2677 processor_subdir: str | None = None, 

2678 session=NEW_SESSION, 

2679 ): 

2680 """ 

2681 Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB, including 

2682 calculated fields. 

2683 

2684 Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a SubDagOperator. 

2685 

2686 :param dags: the DAG objects to save to the DB 

2687 :return: None 

2688 """ 

2689 if not dags: 

2690 return 

2691 

2692 log.info("Sync %s DAGs", len(dags)) 

2693 dag_by_ids = {dag.dag_id: dag for dag in dags} 

2694 

2695 dag_ids = set(dag_by_ids.keys()) 

2696 query = ( 

2697 session.query(DagModel) 

2698 .options(joinedload(DagModel.tags, innerjoin=False)) 

2699 .filter(DagModel.dag_id.in_(dag_ids)) 

2700 .options(joinedload(DagModel.schedule_dataset_references)) 

2701 .options(joinedload(DagModel.task_outlet_dataset_references)) 

2702 ) 

2703 orm_dags: list[DagModel] = with_row_locks(query, of=DagModel, session=session).all() 

2704 existing_dags = {orm_dag.dag_id: orm_dag for orm_dag in orm_dags} 

2705 missing_dag_ids = dag_ids.difference(existing_dags) 

2706 

2707 for missing_dag_id in missing_dag_ids: 

2708 orm_dag = DagModel(dag_id=missing_dag_id) 

2709 dag = dag_by_ids[missing_dag_id] 

2710 if dag.is_paused_upon_creation is not None: 

2711 orm_dag.is_paused = dag.is_paused_upon_creation 

2712 orm_dag.tags = [] 

2713 log.info("Creating ORM DAG for %s", dag.dag_id) 

2714 session.add(orm_dag) 

2715 orm_dags.append(orm_dag) 

2716 

2717 # Get the latest dag run for each existing dag as a single query (avoid n+1 query) 

2718 most_recent_subq = ( 

2719 session.query(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date")) 

2720 .filter( 

2721 DagRun.dag_id.in_(existing_dags), 

2722 or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED), 

2723 ) 

2724 .group_by(DagRun.dag_id) 

2725 .subquery() 

2726 ) 

2727 most_recent_runs_iter = session.query(DagRun).filter( 

2728 DagRun.dag_id == most_recent_subq.c.dag_id, 

2729 DagRun.execution_date == most_recent_subq.c.max_execution_date, 

2730 ) 

2731 most_recent_runs = {run.dag_id: run for run in most_recent_runs_iter} 

2732 

2733 # Get number of active dagruns for all dags we are processing as a single query. 

2734 

2735 num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session) 

2736 

2737 filelocs = [] 

2738 

2739 for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id): 

2740 dag = dag_by_ids[orm_dag.dag_id] 

2741 filelocs.append(dag.fileloc) 

2742 if dag.is_subdag: 

2743 orm_dag.is_subdag = True 

2744 orm_dag.fileloc = dag.parent_dag.fileloc # type: ignore 

2745 orm_dag.root_dag_id = dag.parent_dag.dag_id # type: ignore 

2746 orm_dag.owners = dag.parent_dag.owner # type: ignore 

2747 else: 

2748 orm_dag.is_subdag = False 

2749 orm_dag.fileloc = dag.fileloc 

2750 orm_dag.owners = dag.owner 

2751 orm_dag.is_active = True 

2752 orm_dag.has_import_errors = False 

2753 orm_dag.last_parsed_time = timezone.utcnow() 

2754 orm_dag.default_view = dag.default_view 

2755 orm_dag.description = dag.description 

2756 orm_dag.max_active_tasks = dag.max_active_tasks 

2757 orm_dag.max_active_runs = dag.max_active_runs 

2758 orm_dag.has_task_concurrency_limits = any(t.max_active_tis_per_dag is not None for t in dag.tasks) 

2759 orm_dag.schedule_interval = dag.schedule_interval 

2760 orm_dag.timetable_description = dag.timetable.description 

2761 orm_dag.processor_subdir = processor_subdir 

2762 

2763 run: DagRun | None = most_recent_runs.get(dag.dag_id) 

2764 if run is None: 

2765 data_interval = None 

2766 else: 

2767 data_interval = dag.get_run_data_interval(run) 

2768 if num_active_runs.get(dag.dag_id, 0) >= orm_dag.max_active_runs: 

2769 orm_dag.next_dagrun_create_after = None 

2770 else: 

2771 orm_dag.calculate_dagrun_date_fields(dag, data_interval) 

2772 

2773 dag_tags = set(dag.tags or {}) 

2774 orm_dag_tags = list(orm_dag.tags or []) 

2775 for orm_tag in orm_dag_tags: 

2776 if orm_tag.name not in dag_tags: 

2777 session.delete(orm_tag) 

2778 orm_dag.tags.remove(orm_tag) 

2779 orm_tag_names = {t.name for t in orm_dag_tags} 

2780 for dag_tag in dag_tags: 

2781 if dag_tag not in orm_tag_names: 

2782 dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id) 

2783 orm_dag.tags.append(dag_tag_orm) 

2784 session.add(dag_tag_orm) 

2785 

2786 orm_dag_links = orm_dag.dag_owner_links or [] 

2787 for orm_dag_link in orm_dag_links: 

2788 if orm_dag_link not in dag.owner_links: 

2789 session.delete(orm_dag_link) 

2790 for owner_name, owner_link in dag.owner_links.items(): 

2791 dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link) 

2792 session.add(dag_owner_orm) 

2793 

2794 DagCode.bulk_sync_to_db(filelocs, session=session) 

2795 

2796 from airflow.datasets import Dataset 

2797 from airflow.models.dataset import ( 

2798 DagScheduleDatasetReference, 

2799 DatasetModel, 

2800 TaskOutletDatasetReference, 

2801 ) 

2802 

2803 dag_references = collections.defaultdict(set) 

2804 outlet_references = collections.defaultdict(set) 

2805 # We can't use a set here as we want to preserve order 

2806 outlet_datasets: dict[Dataset, None] = {} 

2807 input_datasets: dict[Dataset, None] = {} 

2808 

2809 # here we go through dags and tasks to check for dataset references 

2810 # if there are now none and previously there were some, we delete them 

2811 # if there are now *any*, we add them to the above data structures and 

2812 # later we'll persist them to the database. 

2813 for dag in dags: 

2814 curr_orm_dag = existing_dags.get(dag.dag_id) 

2815 if not dag.dataset_triggers: 

2816 if curr_orm_dag and curr_orm_dag.schedule_dataset_references: 

2817 curr_orm_dag.schedule_dataset_references = [] 

2818 for dataset in dag.dataset_triggers: 

2819 dag_references[dag.dag_id].add(dataset.uri) 

2820 input_datasets[DatasetModel.from_public(dataset)] = None 

2821 curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references 

2822 for task in dag.tasks: 

2823 dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)] 

2824 if not dataset_outlets: 

2825 if curr_outlet_references: 

2826 this_task_outlet_refs = [ 

2827 x 

2828 for x in curr_outlet_references 

2829 if x.dag_id == dag.dag_id and x.task_id == task.task_id 

2830 ] 

2831 for ref in this_task_outlet_refs: 

2832 curr_outlet_references.remove(ref) 

2833 for d in dataset_outlets: 

2834 outlet_references[(task.dag_id, task.task_id)].add(d.uri) 

2835 outlet_datasets[DatasetModel.from_public(d)] = None 

2836 all_datasets = outlet_datasets 

2837 all_datasets.update(input_datasets) 

2838 

2839 # store datasets 

2840 stored_datasets = {} 

2841 for dataset in all_datasets: 

2842 stored_dataset = session.query(DatasetModel).filter(DatasetModel.uri == dataset.uri).first() 

2843 if stored_dataset: 

2844 # Some datasets may have been previously unreferenced, and therefore orphaned by the 

2845 # scheduler. But if we're here, then we have found that dataset again in our DAGs, which 

2846 # means that it is no longer an orphan, so set is_orphaned to False. 

2847 stored_dataset.is_orphaned = expression.false() 

2848 stored_datasets[stored_dataset.uri] = stored_dataset 

2849 else: 

2850 session.add(dataset) 

2851 stored_datasets[dataset.uri] = dataset 

2852 

2853 session.flush() # this is required to ensure each dataset has its PK loaded 

2854 

2855 del all_datasets 

2856 

2857 # reconcile dag-schedule-on-dataset references 

2858 for dag_id, uri_list in dag_references.items(): 

2859 dag_refs_needed = { 

2860 DagScheduleDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id) 

2861 for uri in uri_list 

2862 } 

2863 dag_refs_stored = set( 

2864 existing_dags.get(dag_id) 

2865 and existing_dags.get(dag_id).schedule_dataset_references # type: ignore 

2866 or [] 

2867 ) 

2868 dag_refs_to_add = {x for x in dag_refs_needed if x not in dag_refs_stored} 

2869 session.bulk_save_objects(dag_refs_to_add) 

2870 for obj in dag_refs_stored - dag_refs_needed: 

2871 session.delete(obj) 

2872 

2873 existing_task_outlet_refs_dict = collections.defaultdict(set) 

2874 for dag_id, orm_dag in existing_dags.items(): 

2875 for todr in orm_dag.task_outlet_dataset_references: 

2876 existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr) 

2877 

2878 # reconcile task-outlet-dataset references 

2879 for (dag_id, task_id), uri_list in outlet_references.items(): 

2880 task_refs_needed = { 

2881 TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id) 

2882 for uri in uri_list 

2883 } 

2884 task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)] 

2885 task_refs_to_add = {x for x in task_refs_needed if x not in task_refs_stored} 

2886 session.bulk_save_objects(task_refs_to_add) 

2887 for obj in task_refs_stored - task_refs_needed: 

2888 session.delete(obj) 

2889 

2890 # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller 

2891 # decide when to commit) 

2892 session.flush() 

2893 

2894 for dag in dags: 

2895 cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session) 

2896 

2897 @provide_session 

2898 def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION): 

2899 """ 

2900 Save attributes about this DAG to the DB. Note that this method 

2901 can be called for both DAGs and SubDAGs. A SubDag is actually a 

2902 SubDagOperator. 

2903 

2904 :return: None 

2905 """ 

2906 self.bulk_write_to_db([self], processor_subdir=processor_subdir, session=session) 

2907 

2908 def get_default_view(self): 

2909 """This is only there for backward compatible jinja2 templates""" 

2910 if self.default_view is None: 

2911 return conf.get("webserver", "dag_default_view").lower() 

2912 else: 

2913 return self.default_view 

2914 

2915 @staticmethod 

2916 @provide_session 

2917 def deactivate_unknown_dags(active_dag_ids, session=NEW_SESSION): 

2918 """ 

2919 Given a list of known DAGs, deactivate any other DAGs that are 

2920 marked as active in the ORM 

2921 

2922 :param active_dag_ids: list of DAG IDs that are active 

2923 :return: None 

2924 """ 

2925 if len(active_dag_ids) == 0: 

2926 return 

2927 for dag in session.query(DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all(): 

2928 dag.is_active = False 

2929 session.merge(dag) 

2930 session.commit() 

2931 

2932 @staticmethod 

2933 @provide_session 

2934 def deactivate_stale_dags(expiration_date, session=NEW_SESSION): 

2935 """ 

2936 Deactivate any DAGs that were last touched by the scheduler before 

2937 the expiration date. These DAGs were likely deleted. 

2938 

2939 :param expiration_date: mark DAGs inactive if they were last touched before this 

2940 time 

2941 :return: None 

2942 """ 

2943 for dag in ( 

2944 session.query(DagModel) 

2945 .filter(DagModel.last_parsed_time < expiration_date, DagModel.is_active) 

2946 .all() 

2947 ): 

2948 log.info( 

2949 "Deactivating DAG ID %s since it was last touched by the scheduler at %s", 

2950 dag.dag_id, 

2951 dag.last_parsed_time.isoformat(), 

2952 ) 

2953 dag.is_active = False 

2954 session.merge(dag) 

2955 session.commit() 
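# --- Illustrative sketch (editor addition, not part of dag.py) ------------------
# Deactivating DAGs that the scheduler has not re-parsed in the last week; the
# seven-day cutoff is an arbitrary example value and a metadata DB is assumed.
from datetime import timedelta

from airflow.models.dag import DAG
from airflow.utils import timezone

DAG.deactivate_stale_dags(expiration_date=timezone.utcnow() - timedelta(days=7))
# --- end sketch ------------------------------------------------------------------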

2956 

2957 @staticmethod 

2958 @provide_session 

2959 def get_num_task_instances(dag_id, task_ids=None, states=None, session=NEW_SESSION) -> int: 

2960 """ 

2961 Returns the number of task instances in the given DAG. 

2962 

2963 :param session: ORM session 

2964 :param dag_id: ID of the DAG to count task instances for 

2965 :param task_ids: A list of valid task IDs for the given DAG 

2966 :param states: A list of states to filter by if supplied 

2967 :return: The number of matching task instances 

2968 """ 

2969 qry = session.query(func.count(TaskInstance.task_id)).filter( 

2970 TaskInstance.dag_id == dag_id, 

2971 ) 

2972 if task_ids: 

2973 qry = qry.filter( 

2974 TaskInstance.task_id.in_(task_ids), 

2975 ) 

2976 

2977 if states: 

2978 if None in states: 

2979 if all(x is None for x in states): 

2980 qry = qry.filter(TaskInstance.state.is_(None)) 

2981 else: 

2982 not_none_states = [state for state in states if state] 

2983 qry = qry.filter( 

2984 or_(TaskInstance.state.in_(not_none_states), TaskInstance.state.is_(None)) 

2985 ) 

2986 else: 

2987 qry = qry.filter(TaskInstance.state.in_(states)) 

2988 return qry.scalar() 
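# --- Illustrative sketch (editor addition, not part of dag.py) ------------------
# Counting currently running task instances for a hypothetical dag_id; omitting
# ``states`` would count all task instances of the DAG.
from airflow.models.dag import DAG
from airflow.utils.state import State

running = DAG.get_num_task_instances("example_etl", states=[State.RUNNING])
print(f"{running} task instance(s) currently running")
# --- end sketch ------------------------------------------------------------------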

2989 

2990 @classmethod 

2991 def get_serialized_fields(cls): 

2992 """Stringified DAGs and operators contain exactly these fields.""" 

2993 if not cls.__serialized_fields: 

2994 exclusion_list = { 

2995 "parent_dag", 

2996 "schedule_dataset_references", 

2997 "task_outlet_dataset_references", 

2998 "_old_context_manager_dags", 

2999 "safe_dag_id", 

3000 "last_loaded", 

3001 "user_defined_filters", 

3002 "user_defined_macros", 

3003 "partial", 

3004 "params", 

3005 "_pickle_id", 

3006 "_log", 

3007 "task_dict", 

3008 "template_searchpath", 

3009 "sla_miss_callback", 

3010 "on_success_callback", 

3011 "on_failure_callback", 

3012 "template_undefined", 

3013 "jinja_environment_kwargs", 

3014 # has_on_*_callback are only stored if the value is True, as the default is False 

3015 "has_on_success_callback", 

3016 "has_on_failure_callback", 

3017 "auto_register", 

3018 } 

3019 cls.__serialized_fields = frozenset(vars(DAG(dag_id="test")).keys()) - exclusion_list 

3020 return cls.__serialized_fields 

3021 

3022 def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: 

3023 """ 

3024 Returns edge information for the given pair of tasks if present, and 

3025 an empty edge if there is no information. 

3026 """ 

3027 # Note - older serialized DAGs may not have edge_info being a dict at all 

3028 empty = cast(EdgeInfoType, {}) 

3029 if self.edge_info: 

3030 return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty) 

3031 else: 

3032 return empty 

3033 

3034 def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType): 

3035 """ 

3036 Sets the given edge information on the DAG. Note that this will overwrite, 

3037 rather than merge with, existing info. 

3038 """ 

3039 self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info 
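# --- Illustrative sketch (editor addition, not part of dag.py) ------------------
# Storing and reading back an edge label between two hypothetical task ids; "label"
# is the key the graph view reads from EdgeInfoType.
from datetime import datetime

from airflow import DAG

with DAG(dag_id="edge_info_demo", schedule=None, start_date=datetime(2022, 1, 1)) as d:
    d.set_edge_info("extract", "load", {"label": "rows > 0"})
    assert d.get_edge_info("extract", "load") == {"label": "rows > 0"}
    assert d.get_edge_info("extract", "unknown_task") == {}  # empty edge when no info stored
# --- end sketch ------------------------------------------------------------------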

3040 

3041 def validate_schedule_and_params(self): 

3042 """ 

3043 Validates & raise exception if there are any Params in the DAG which neither have a default value nor 

3044 have the null in schema['type'] list, but the DAG have a schedule_interval which is not None. 

3045 """ 

3046 if not self.timetable.can_run: 

3047 return 

3048 

3049 for k, v in self.params.items(): 

3050 # As type can be an array, check whether `null` is among the allowed types 

3051 if not v.has_value and ("type" not in v.schema or "null" not in v.schema["type"]): 

3052 raise AirflowException( 

3053 "DAG Schedule must be None, if there are any required params without default values" 

3054 ) 
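# --- Illustrative sketch (editor addition, not part of dag.py) ------------------
# A required Param without a default (and without "null" in its type) on a scheduled
# DAG fails this check. Depending on the Airflow version the check may already run at
# DAG construction time, so both the constructor and the explicit call are guarded.
from datetime import datetime

from airflow import DAG
from airflow.exceptions import AirflowException
from airflow.models.param import Param

try:
    scheduled = DAG(
        dag_id="params_demo",
        schedule="@daily",
        start_date=datetime(2022, 1, 1),
        params={"target_table": Param(type="string")},  # required: no default value
    )
    scheduled.validate_schedule_and_params()
except AirflowException as err:
    print(f"invalid DAG definition: {err}")
# --- end sketch ------------------------------------------------------------------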

3055 

3056 def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]: 

3057 """Parses a given link, and verifies if it's a valid URL, or a 'mailto' link. 

3058 Returns an iterator of invalid (owner, link) pairs. 

3059 """ 

3060 for owner, link in self.owner_links.items(): 

3061 result = urlsplit(link) 

3062 if result.scheme == "mailto": 

3063 # netloc does not exist for a 'mailto' link, so check that the path was parsed 

3064 if not result.path: 

3065 yield owner, link 

3066 elif not result.scheme or not result.netloc: 

3067 yield owner, link 
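# --- Illustrative sketch (editor addition, not part of dag.py) ------------------
# Owner links must be absolute URLs or mailto links; the owners and links below are
# hypothetical. Only the malformed entry is yielded back.
from datetime import datetime

from airflow import DAG

linked = DAG(
    dag_id="owner_links_demo",
    schedule=None,
    start_date=datetime(2022, 1, 1),
    owner_links={
        "data-team": "https://example.com/data-team",
        "oncall": "mailto:oncall@example.com",
        "broken": "not a url",
    },
)
print(list(linked.iter_invalid_owner_links()))  # [('broken', 'not a url')]
# --- end sketch ------------------------------------------------------------------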

3068 

3069 

3070class DagTag(Base): 

3071 """A tag name per dag, to allow quick filtering in the DAG view.""" 

3072 

3073 __tablename__ = "dag_tag" 

3074 name = Column(String(TAG_MAX_LEN), primary_key=True) 

3075 dag_id = Column( 

3076 StringID(), 

3077 ForeignKey("dag.dag_id", name="dag_tag_dag_id_fkey", ondelete="CASCADE"), 

3078 primary_key=True, 

3079 ) 

3080 

3081 def __repr__(self): 

3082 return self.name 

3083 

3084 

3085class DagOwnerAttributes(Base): 

3086 """ 

3087 Table defining different owner attributes. For example, a link for an owner that will be passed as 

3088 a hyperlink to the DAGs view 

3089 """ 

3090 

3091 __tablename__ = "dag_owner_attributes" 

3092 dag_id = Column( 

3093 StringID(), 

3094 ForeignKey("dag.dag_id", name="dag.dag_id", ondelete="CASCADE"), 

3095 nullable=False, 

3096 primary_key=True, 

3097 ) 

3098 owner = Column(String(500), primary_key=True, nullable=False) 

3099 link = Column(String(500), nullable=False) 

3100 

3101 def __repr__(self): 

3102 return f"<DagOwnerAttributes: dag_id={self.dag_id}, owner={self.owner}, link={self.link}>" 

3103 

3104 @classmethod 

3105 def get_all(cls, session) -> dict[str, dict[str, str]]: 

3106 dag_links: dict = collections.defaultdict(dict) 

3107 for obj in session.query(cls): 

3108 dag_links[obj.dag_id].update({obj.owner: obj.link}) 

3109 return dag_links 

3110 

3111 

3112class DagModel(Base): 

3113 """Table containing DAG properties""" 

3114 

3115 __tablename__ = "dag" 

3116 """ 

3117 These items are stored in the database for state-related information 

3118 """ 

3119 dag_id = Column(StringID(), primary_key=True) 

3120 root_dag_id = Column(StringID()) 

3121 # A DAG can be paused from the UI / DB 

3122 # Set this default value of is_paused based on a configuration value! 

3123 is_paused_at_creation = conf.getboolean("core", "dags_are_paused_at_creation") 

3124 is_paused = Column(Boolean, default=is_paused_at_creation) 

3125 # Whether the DAG is a subdag 

3126 is_subdag = Column(Boolean, default=False) 

3127 # Whether the DAG was seen on the last DagBag load 

3128 is_active = Column(Boolean, default=False) 

3129 # Last time this DAG was parsed by the scheduler / DAG processor 

3130 last_parsed_time = Column(UtcDateTime) 

3131 # Last time this DAG was pickled 

3132 last_pickled = Column(UtcDateTime) 

3133 # Time when the DAG last received a refresh signal 

3134 # (e.g. the DAG's "refresh" button was clicked in the web UI) 

3135 last_expired = Column(UtcDateTime) 

3137 # Whether one of the schedulers is scheduling this DAG at the moment 

3137 scheduler_lock = Column(Boolean) 

3138 # Foreign key to the latest pickle_id 

3139 pickle_id = Column(Integer) 

3140 # The location of the file containing the DAG object 

3141 # Note: Do not depend on fileloc pointing to a file; in the case of a 

3142 # packaged DAG, it will point to the subpath of the DAG within the 

3143 # associated zip. 

3144 fileloc = Column(String(2000)) 

3145 # The base directory used by Dag Processor that parsed this dag. 

3146 processor_subdir = Column(String(2000), nullable=True) 

3147 # String representing the owners 

3148 owners = Column(String(2000)) 

3149 # Description of the dag 

3150 description = Column(Text) 

3151 # Default view of the DAG inside the webserver 

3152 default_view = Column(String(25)) 

3153 # Schedule interval 

3154 schedule_interval = Column(Interval) 

3155 # Timetable/Schedule Interval description 

3156 timetable_description = Column(String(1000), nullable=True) 

3157 # Tags for view filter 

3158 tags = relationship("DagTag", cascade="all, delete, delete-orphan", backref=backref("dag")) 

3159 # Dag owner links for DAGs view 

3160 dag_owner_links = relationship( 

3161 "DagOwnerAttributes", cascade="all, delete, delete-orphan", backref=backref("dag") 

3162 ) 

3163 

3164 max_active_tasks = Column(Integer, nullable=False) 

3165 max_active_runs = Column(Integer, nullable=True) 

3166 

3167 has_task_concurrency_limits = Column(Boolean, nullable=False) 

3168 has_import_errors = Column(Boolean(), default=False, server_default="0") 

3169 

3170 # The logical date of the next dag run. 

3171 next_dagrun = Column(UtcDateTime) 

3172 

3173 # Must be either both NULL or both datetime. 

3174 next_dagrun_data_interval_start = Column(UtcDateTime) 

3175 next_dagrun_data_interval_end = Column(UtcDateTime) 

3176 

3177 # Earliest time at which this ``next_dagrun`` can be created. 

3178 next_dagrun_create_after = Column(UtcDateTime) 

3179 

3180 __table_args__ = ( 

3181 Index("idx_root_dag_id", root_dag_id, unique=False), 

3182 Index("idx_next_dagrun_create_after", next_dagrun_create_after, unique=False), 

3183 ) 

3184 

3185 parent_dag = relationship( 

3186 "DagModel", remote_side=[dag_id], primaryjoin=root_dag_id == dag_id, foreign_keys=[root_dag_id] 

3187 ) 

3188 schedule_dataset_references = relationship( 

3189 "DagScheduleDatasetReference", 

3190 cascade="all, delete, delete-orphan", 

3191 ) 

3192 schedule_datasets = association_proxy("schedule_dataset_references", "dataset") 

3193 task_outlet_dataset_references = relationship( 

3194 "TaskOutletDatasetReference", 

3195 cascade="all, delete, delete-orphan", 

3196 ) 

3197 NUM_DAGS_PER_DAGRUN_QUERY = conf.getint("scheduler", "max_dagruns_to_create_per_loop", fallback=10) 

3198 

3199 def __init__(self, concurrency=None, **kwargs): 

3200 super().__init__(**kwargs) 

3201 if self.max_active_tasks is None: 

3202 if concurrency: 

3203 warnings.warn( 

3204 "The 'DagModel.concurrency' parameter is deprecated. Please use 'max_active_tasks'.", 

3205 RemovedInAirflow3Warning, 

3206 stacklevel=2, 

3207 ) 

3208 self.max_active_tasks = concurrency 

3209 else: 

3210 self.max_active_tasks = conf.getint("core", "max_active_tasks_per_dag") 

3211 

3212 if self.max_active_runs is None: 

3213 self.max_active_runs = conf.getint("core", "max_active_runs_per_dag") 

3214 

3215 if self.has_task_concurrency_limits is None: 

3216 # Be safe -- this will be updated later once the DAG is parsed 

3217 self.has_task_concurrency_limits = True 

3218 

3219 def __repr__(self): 

3220 return f"<DAG: {self.dag_id}>" 

3221 

3222 @property 

3223 def next_dagrun_data_interval(self) -> DataInterval | None: 

3224 return _get_model_data_interval( 

3225 self, 

3226 "next_dagrun_data_interval_start", 

3227 "next_dagrun_data_interval_end", 

3228 ) 

3229 

3230 @next_dagrun_data_interval.setter 

3231 def next_dagrun_data_interval(self, value: tuple[datetime, datetime] | None) -> None: 

3232 if value is None: 

3233 self.next_dagrun_data_interval_start = self.next_dagrun_data_interval_end = None 

3234 else: 

3235 self.next_dagrun_data_interval_start, self.next_dagrun_data_interval_end = value 

3236 

3237 @property 

3238 def timezone(self): 

3239 return settings.TIMEZONE 

3240 

3241 @staticmethod 

3242 @provide_session 

3243 def get_dagmodel(dag_id, session=NEW_SESSION): 

3244 return session.query(DagModel).options(joinedload(DagModel.parent_dag)).get(dag_id) 

3245 

3246 @classmethod 

3247 @provide_session 

3248 def get_current(cls, dag_id, session=NEW_SESSION): 

3249 return session.query(cls).filter(cls.dag_id == dag_id).first() 

3250 

3251 @provide_session 

3252 def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False): 

3253 return get_last_dagrun( 

3254 self.dag_id, session=session, include_externally_triggered=include_externally_triggered 

3255 ) 

3256 

3257 def get_is_paused(self, *, session: Session | None = None) -> bool: 

3258 """Provide interface compatibility to 'DAG'.""" 

3259 return self.is_paused 

3260 

3261 @staticmethod 

3262 @provide_session 

3263 def get_paused_dag_ids(dag_ids: list[str], session: Session = NEW_SESSION) -> set[str]: 

3264 """ 

3265 Given a list of dag_ids, get the subset that are paused 

3266 

3267 :param dag_ids: List of Dag ids 

3268 :param session: ORM Session 

3269 :return: Paused Dag_ids 

3270 """ 

3271 paused_dag_ids = ( 

3272 session.query(DagModel.dag_id) 

3273 .filter(DagModel.is_paused == expression.true()) 

3274 .filter(DagModel.dag_id.in_(dag_ids)) 

3275 .all() 

3276 ) 

3277 

3278 paused_dag_ids = {paused_dag_id for paused_dag_id, in paused_dag_ids} 

3279 return paused_dag_ids 
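# --- Illustrative sketch (editor addition, not part of dag.py) ------------------
# Asking the metadata DB which of two hypothetical dag_ids are currently paused.
from airflow.models.dag import DagModel

paused = DagModel.get_paused_dag_ids(["daily_etl", "hourly_sync"])
print(paused)  # e.g. {'daily_etl'} if only that DAG is paused
# --- end sketch ------------------------------------------------------------------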

3280 

3281 def get_default_view(self) -> str: 

3282 """ 

3283 Get the Default DAG View; returns the default config value if DagModel does not 

3284 have a value 

3285 """ 

3286 # This is for backwards-compatibility with old dags that don't have None as default_view 

3287 return self.default_view or conf.get_mandatory_value("webserver", "dag_default_view").lower() 

3288 

3289 @property 

3290 def safe_dag_id(self): 

3291 return self.dag_id.replace(".", "__dot__") 

3292 

3293 @property 

3294 def relative_fileloc(self) -> pathlib.Path | None: 

3295 """File location of the importable dag 'file' relative to the configured DAGs folder.""" 

3296 if self.fileloc is None: 

3297 return None 

3298 path = pathlib.Path(self.fileloc) 

3299 try: 

3300 return path.relative_to(settings.DAGS_FOLDER) 

3301 except ValueError: 

3302 # Not relative to DAGS_FOLDER. 

3303 return path 

3304 

3305 @provide_session 

3306 def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=NEW_SESSION) -> None: 

3307 """ 

3308 Pause/Un-pause a DAG. 

3309 

3310 :param is_paused: Is the DAG paused 

3311 :param including_subdags: whether to include the DAG's subdags 

3312 :param session: session 

3313 """ 

3314 filter_query = [ 

3315 DagModel.dag_id == self.dag_id, 

3316 ] 

3317 if including_subdags: 

3318 filter_query.append(DagModel.root_dag_id == self.dag_id) 

3319 session.query(DagModel).filter(or_(*filter_query)).update( 

3320 {DagModel.is_paused: is_paused}, synchronize_session="fetch" 

3321 ) 

3322 session.commit() 
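# --- Illustrative sketch (editor addition, not part of dag.py) ------------------
# Pausing a hypothetical DAG (and, by default, its subdags) through its DagModel row,
# assuming the DAG has already been synced to the metadata DB.
from airflow.models.dag import DagModel

dag_model = DagModel.get_dagmodel("daily_etl")
if dag_model is not None:
    dag_model.set_is_paused(True)
# --- end sketch ------------------------------------------------------------------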

3323 

3324 @classmethod 

3325 @provide_session 

3326 def deactivate_deleted_dags(cls, alive_dag_filelocs: list[str], session=NEW_SESSION): 

3327 """ 

3328 Set ``is_active=False`` on the DAGs for which the DAG files have been removed. 

3329 

3330 :param alive_dag_filelocs: file paths of alive DAGs 

3331 :param session: ORM Session 

3332 """ 

3333 log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__) 

3334 

3335 dag_models = session.query(cls).all() 

3336 for dag_model in dag_models: 

3337 if dag_model.fileloc is not None: 

3338 if correct_maybe_zipped(dag_model.fileloc) not in alive_dag_filelocs: 

3339 dag_model.is_active = False 

3340 else: 

3341 continue 

3342 

3343 @classmethod 

3344 def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]: 

3345 """ 

3346 Return (and lock) a list of Dag objects that are due to create a new DagRun. 

3347 

3348 This will return a resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query; 

3349 you should ensure that any scheduling decisions are made in a single transaction -- as soon as the 

3350 transaction is committed it will be unlocked. 

3351 """ 

3352 from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ 

3353 

3354 # these dag ids are triggered by datasets, and they are ready to go. 

3355 dataset_triggered_dag_info = { 

3356 x.dag_id: (x.first_queued_time, x.last_queued_time) 

3357 for x in session.query( 

3358 DagScheduleDatasetReference.dag_id, 

3359 func.max(DDRQ.created_at).label("last_queued_time"), 

3360 func.min(DDRQ.created_at).label("first_queued_time"), 

3361 ) 

3362 .join(DagScheduleDatasetReference.queue_records, isouter=True) 

3363 .group_by(DagScheduleDatasetReference.dag_id) 

3364 .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0))) 

3365 .all() 

3366 } 

3367 dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys()) 

3368 if dataset_triggered_dag_ids: 

3369 exclusion_list = { 

3370 x.dag_id 

3371 for x in ( 

3372 session.query(DagModel.dag_id) 

3373 .join(DagRun.dag_model) 

3374 .filter(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING))) 

3375 .filter(DagModel.dag_id.in_(dataset_triggered_dag_ids)) 

3376 .group_by(DagModel.dag_id) 

3377 .having(func.count() >= func.max(DagModel.max_active_runs)) 

3378 .all() 

3379 ) 

3380 } 

3381 if exclusion_list: 

3382 dataset_triggered_dag_ids -= exclusion_list 

3383 dataset_triggered_dag_info = { 

3384 k: v for k, v in dataset_triggered_dag_info.items() if k not in exclusion_list 

3385 } 

3386 

3387 # We limit so that _one_ scheduler doesn't try to do all of the dag run creation itself 

3388 query = ( 

3389 session.query(cls) 

3390 .filter( 

3391 cls.is_paused == expression.false(), 

3392 cls.is_active == expression.true(), 

3393 cls.has_import_errors == expression.false(), 

3394 or_( 

3395 cls.next_dagrun_create_after <= func.now(), 

3396 cls.dag_id.in_(dataset_triggered_dag_ids), 

3397 ), 

3398 ) 

3399 .order_by(cls.next_dagrun_create_after) 

3400 .limit(cls.NUM_DAGS_PER_DAGRUN_QUERY) 

3401 ) 

3402 

3403 return ( 

3404 with_row_locks(query, of=cls, session=session, **skip_locked(session=session)), 

3405 dataset_triggered_dag_info, 

3406 ) 
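# --- Illustrative sketch (editor addition, not part of dag.py) ------------------
# The returned query is row-locked (SELECT ... FOR UPDATE on backends that support it),
# so scheduling decisions should happen inside one transaction; the locks are released
# when the session commits on exiting the block.
from airflow.models.dag import DagModel
from airflow.utils.session import create_session

with create_session() as session:
    query, dataset_triggered_info = DagModel.dags_needing_dagruns(session)
    for dag_model in query.all():
        # ... create DagRuns / update next_dagrun fields here, while still holding the lock
        pass
# --- end sketch ------------------------------------------------------------------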

3407 

3408 def calculate_dagrun_date_fields( 

3409 self, 

3410 dag: DAG, 

3411 most_recent_dag_run: None | datetime | DataInterval, 

3412 ) -> None: 

3413 """ 

3414 Calculate ``next_dagrun`` and ``next_dagrun_create_after``. 

3415 

3416 :param dag: The DAG object 

3417 :param most_recent_dag_run: DataInterval (or datetime) of the most recent run of this dag, or None 

3418 if not yet scheduled. 

3419 """ 

3420 most_recent_data_interval: DataInterval | None 

3421 if isinstance(most_recent_dag_run, datetime): 

3422 warnings.warn( 

3423 "Passing a datetime to `DagModel.calculate_dagrun_date_fields` is deprecated. " 

3424 "Provide a data interval instead.", 

3425 RemovedInAirflow3Warning, 

3426 stacklevel=2, 

3427 ) 

3428 most_recent_data_interval = dag.infer_automated_data_interval(most_recent_dag_run) 

3429 else: 

3430 most_recent_data_interval = most_recent_dag_run 

3431 next_dagrun_info = dag.next_dagrun_info(most_recent_data_interval) 

3432 if next_dagrun_info is None: 

3433 self.next_dagrun_data_interval = self.next_dagrun = self.next_dagrun_create_after = None 

3434 else: 

3435 self.next_dagrun_data_interval = next_dagrun_info.data_interval 

3436 self.next_dagrun = next_dagrun_info.logical_date 

3437 self.next_dagrun_create_after = next_dagrun_info.run_after 

3438 

3439 log.info( 

3440 "Setting next_dagrun for %s to %s, run_after=%s", 

3441 dag.dag_id, 

3442 self.next_dagrun, 

3443 self.next_dagrun_create_after, 

3444 ) 

3445 

3446 @provide_session 

3447 def get_dataset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, int | str] | None: 

3448 if self.schedule_interval != "Dataset": 

3449 return None 

3450 return get_dataset_triggered_next_run_info([self.dag_id], session=session)[self.dag_id] 

3451 

3452 

3453# NOTE: Please keep the list of arguments in sync with DAG.__init__. 

3454# Only exception: dag_id here should have a default value, but not in DAG. 

3455def dag( 

3456 dag_id: str = "", 

3457 description: str | None = None, 

3458 schedule: ScheduleArg = NOTSET, 

3459 schedule_interval: ScheduleIntervalArg = NOTSET, 

3460 timetable: Timetable | None = None, 

3461 start_date: datetime | None = None, 

3462 end_date: datetime | None = None, 

3463 full_filepath: str | None = None, 

3464 template_searchpath: str | Iterable[str] | None = None, 

3465 template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, 

3466 user_defined_macros: dict | None = None, 

3467 user_defined_filters: dict | None = None, 

3468 default_args: dict | None = None, 

3469 concurrency: int | None = None, 

3470 max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"), 

3471 max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"), 

3472 dagrun_timeout: timedelta | None = None, 

3473 sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None, 

3474 default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(), 

3475 orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"), 

3476 catchup: bool = conf.getboolean("scheduler", "catchup_by_default"), 

3477 on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

3478 on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

3479 doc_md: str | None = None, 

3480 params: dict | None = None, 

3481 access_control: dict | None = None, 

3482 is_paused_upon_creation: bool | None = None, 

3483 jinja_environment_kwargs: dict | None = None, 

3484 render_template_as_native_obj: bool = False, 

3485 tags: list[str] | None = None, 

3486 owner_links: dict[str, str] | None = None, 

3487 auto_register: bool = True, 

3488) -> Callable[[Callable], Callable[..., DAG]]: 

3489 """ 

3490 Python dag decorator. Wraps a function into an Airflow DAG. 

3491 Accepts keyword arguments that are forwarded to the DAG object; can be used to parameterize DAGs. 

3492 

3493 :param dag_args: Arguments for DAG object 

3494 :param dag_kwargs: Kwargs for DAG object. 

3495 """ 

3496 

3497 def wrapper(f: Callable) -> Callable[..., DAG]: 

3498 @functools.wraps(f) 

3499 def factory(*args, **kwargs): 

3500 # Generate signature for decorated function and bind the arguments when called 

3501 # we do this to extract parameters so we can annotate them on the DAG object. 

3502 # In addition, this fails if we are missing any args/kwargs with TypeError as expected. 

3503 f_sig = signature(f).bind(*args, **kwargs) 

3504 # Apply defaults to capture default values if set. 

3505 f_sig.apply_defaults() 

3506 

3507 # Initialize DAG with bound arguments 

3508 with DAG( 

3509 dag_id or f.__name__, 

3510 description=description, 

3511 schedule_interval=schedule_interval, 

3512 timetable=timetable, 

3513 start_date=start_date, 

3514 end_date=end_date, 

3515 full_filepath=full_filepath, 

3516 template_searchpath=template_searchpath, 

3517 template_undefined=template_undefined, 

3518 user_defined_macros=user_defined_macros, 

3519 user_defined_filters=user_defined_filters, 

3520 default_args=default_args, 

3521 concurrency=concurrency, 

3522 max_active_tasks=max_active_tasks, 

3523 max_active_runs=max_active_runs, 

3524 dagrun_timeout=dagrun_timeout, 

3525 sla_miss_callback=sla_miss_callback, 

3526 default_view=default_view, 

3527 orientation=orientation, 

3528 catchup=catchup, 

3529 on_success_callback=on_success_callback, 

3530 on_failure_callback=on_failure_callback, 

3531 doc_md=doc_md, 

3532 params=params, 

3533 access_control=access_control, 

3534 is_paused_upon_creation=is_paused_upon_creation, 

3535 jinja_environment_kwargs=jinja_environment_kwargs, 

3536 render_template_as_native_obj=render_template_as_native_obj, 

3537 tags=tags, 

3538 schedule=schedule, 

3539 owner_links=owner_links, 

3540 auto_register=auto_register, 

3541 ) as dag_obj: 

3542 # Set DAG documentation from function documentation. 

3543 if f.__doc__: 

3544 dag_obj.doc_md = f.__doc__ 

3545 

3546 # Generate DAGParam for each function arg/kwarg and replace it for calling the function. 

3547 # All args/kwargs for function will be DAGParam object and replaced on execution time. 

3548 f_kwargs = {} 

3549 for name, value in f_sig.arguments.items(): 

3550 f_kwargs[name] = dag_obj.param(name, value) 

3551 

3552 # set file location to caller source path 

3553 back = sys._getframe().f_back 

3554 dag_obj.fileloc = back.f_code.co_filename if back else "" 

3555 

3556 # Invoke function to create operators in the DAG scope. 

3557 f(**f_kwargs) 

3558 

3559 # Return dag object such that it's accessible in Globals. 

3560 return dag_obj 

3561 

3562 # Ensure that warnings from inside DAG() are emitted from the caller, not here 

3563 fixup_decorator_warning_stack(factory) 

3564 return factory 

3565 

3566 return wrapper 
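# --- Illustrative sketch (editor addition, not part of dag.py) ------------------
# Typical use of the @dag decorator (normally imported via airflow.decorators); the
# pipeline name, schedule, and task below are hypothetical.
from datetime import datetime

from airflow.decorators import dag, task


@dag(schedule=None, start_date=datetime(2022, 1, 1), catchup=False, tags=["demo"])
def greeting_pipeline(name: str = "world"):
    """The function docstring becomes the DAG's doc_md."""

    @task
    def greet(who: str) -> str:
        return f"hello, {who}"

    greet(name)


greeting_dag = greeting_pipeline()  # calling the factory instantiates the DAG
# --- end sketch ------------------------------------------------------------------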

3567 

3568 

3569STATICA_HACK = True 

3570globals()["kcah_acitats"[::-1].upper()] = False 

3571if STATICA_HACK: # pragma: no cover 

3572 

3573 from airflow.models.serialized_dag import SerializedDagModel 

3574 

3575 DagModel.serialized_dag = relationship(SerializedDagModel) 

3576 """:sphinx-autoapi-skip:""" 

3577 

3578 

3579class DagContext: 

3580 """ 

3581 DAG context is used to keep the current DAG when DAG is used as ContextManager. 

3582 

3583 You can use DAG as context: 

3584 

3585 .. code-block:: python 

3586 

3587 with DAG( 

3588 dag_id="example_dag", 

3589 default_args=default_args, 

3590 schedule="0 0 * * *", 

3591 dagrun_timeout=timedelta(minutes=60), 

3592 ) as dag: 

3593 ... 

3594 

3595 If you do this, the context stores the DAG and whenever a new task is created, it will use 

3596 the stored DAG as its parent DAG. 

3597 

3598 """ 

3599 

3600 _context_managed_dags: Deque[DAG] = deque() 

3601 autoregistered_dags: set[tuple[DAG, ModuleType]] = set() 

3602 current_autoregister_module_name: str | None = None 

3603 

3604 @classmethod 

3605 def push_context_managed_dag(cls, dag: DAG): 

3606 cls._context_managed_dags.appendleft(dag) 

3607 

3608 @classmethod 

3609 def pop_context_managed_dag(cls) -> DAG | None: 

3610 dag = cls._context_managed_dags.popleft() 

3611 

3612 # In a few cases around serialization we explicitly push None onto the stack 

3613 if cls.current_autoregister_module_name is not None and dag and dag.auto_register: 

3614 mod = sys.modules[cls.current_autoregister_module_name] 

3615 cls.autoregistered_dags.add((dag, mod)) 

3616 

3617 return dag 

3618 

3619 @classmethod 

3620 def get_current_dag(cls) -> DAG | None: 

3621 try: 

3622 return cls._context_managed_dags[0] 

3623 except IndexError: 

3624 return None 
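# --- Illustrative sketch (editor addition, not part of dag.py) ------------------
# Inside a ``with DAG(...)`` block the DAG sits on top of the context stack, which is
# how operators created in the block find their parent DAG; assumes no other DAG
# context is active.
from datetime import datetime

from airflow import DAG
from airflow.models.dag import DagContext

with DAG(dag_id="context_demo", schedule=None, start_date=datetime(2022, 1, 1)) as d:
    assert DagContext.get_current_dag() is d
assert DagContext.get_current_dag() is None
# --- end sketch ------------------------------------------------------------------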

3625 

3626 

3627def _run_task(ti: TaskInstance, session): 

3628 """ 

3629 Run a single task instance, and push result to XCom for downstream tasks. Bypasses a lot of 

3630 extra steps used in `task.run` to keep local runs as fast as possible. 

3631 This function is only meant as a helper for the `dag.test` function. 

3632 

3633 Args: 

3634 ti: TaskInstance to run 

3635 """ 

3636 log.info("*****************************************************") 

3637 if ti.map_index > 0: 

3638 log.info("Running task %s index %d", ti.task_id, ti.map_index) 

3639 else: 

3640 log.info("Running task %s", ti.task_id) 

3641 try: 

3642 ti._run_raw_task(session=session) 

3643 session.flush() 

3644 log.info("%s ran successfully!", ti.task_id) 

3645 except AirflowSkipException: 

3646 log.info("Task Skipped, continuing") 

3647 log.info("*****************************************************") 

3648 

3649 

3650def _get_or_create_dagrun( 

3651 dag: DAG, 

3652 conf: dict[Any, Any] | None, 

3653 start_date: datetime, 

3654 execution_date: datetime, 

3655 run_id: str, 

3656 session: Session, 

3657) -> DagRun: 

3658 """ 

3659 Create a DAGRun, but only after clearing the previous instance of said dagrun to prevent collisions. 

3660 This function is only meant for the `dag.test` function as a helper function. 

3661 :param dag: Dag to be used to find dagrun 

3662 :param conf: configuration to pass to newly created dagrun 

3663 :param start_date: start date of new dagrun, defaults to execution_date 

3664 :param execution_date: execution_date for finding the dagrun 

3665 :param run_id: run_id to pass to new dagrun 

3666 :param session: sqlalchemy session 

3667 :return: the newly created DagRun 

3668 """ 

3669 log.info("dagrun id: %s", dag.dag_id) 

3670 dr: DagRun = ( 

3671 session.query(DagRun) 

3672 .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date) 

3673 .first() 

3674 ) 

3675 if dr: 

3676 session.delete(dr) 

3677 session.commit() 

3678 dr = dag.create_dagrun( 

3679 state=DagRunState.RUNNING, 

3680 execution_date=execution_date, 

3681 run_id=run_id, 

3682 start_date=start_date or execution_date, 

3683 session=session, 

3684 conf=conf, # type: ignore 

3685 ) 

3686 log.info("created dagrun " + str(dr)) 

3687 return dr