Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/dag.py: 32%

1581 statements  

coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections 

21import collections.abc 

22import copy 

23import functools 

24import itertools 

25import logging 

26import os 

27import pathlib 

28import pickle 

29import re 

30import sys 

31import traceback 

32import warnings 

33import weakref 

34from collections import deque 

35from datetime import datetime, timedelta 

36from inspect import signature 

37from typing import ( 

38 TYPE_CHECKING, 

39 Any, 

40 Callable, 

41 Collection, 

42 Deque, 

43 Iterable, 

44 Iterator, 

45 List, 

46 Sequence, 

47 Union, 

48 cast, 

49 overload, 

50) 

51from urllib.parse import urlsplit 

52 

53import jinja2 

54import pendulum 

55from dateutil.relativedelta import relativedelta 

56from pendulum.tz.timezone import Timezone 

57from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, and_, case, func, not_, or_ 

58from sqlalchemy.ext.associationproxy import association_proxy 

59from sqlalchemy.orm import backref, joinedload, relationship 

60from sqlalchemy.orm.query import Query 

61from sqlalchemy.orm.session import Session 

62from sqlalchemy.sql import expression 

63 

64import airflow.templates 

65from airflow import settings, utils 

66from airflow.api_internal.internal_api_call import internal_api_call 

67from airflow.configuration import conf, secrets_backend_list 

68from airflow.exceptions import ( 

69 AirflowDagInconsistent, 

70 AirflowException, 

71 AirflowSkipException, 

72 DagInvalidTriggerRule, 

73 DuplicateTaskIdFound, 

74 RemovedInAirflow3Warning, 

75 TaskNotFound, 

76) 

77from airflow.jobs.job import run_job 

78from airflow.models.abstractoperator import AbstractOperator 

79from airflow.models.base import Base, StringID 

80from airflow.models.baseoperator import BaseOperator 

81from airflow.models.dagcode import DagCode 

82from airflow.models.dagpickle import DagPickle 

83from airflow.models.dagrun import DagRun 

84from airflow.models.operator import Operator 

85from airflow.models.param import DagParam, ParamsDict 

86from airflow.models.taskinstance import Context, TaskInstance, TaskInstanceKey, clear_task_instances 

87from airflow.secrets.local_filesystem import LocalFilesystemBackend 

88from airflow.security import permissions 

89from airflow.stats import Stats 

90from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable 

91from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable 

92from airflow.timetables.simple import ( 

93 ContinuousTimetable, 

94 DatasetTriggeredTimetable, 

95 NullTimetable, 

96 OnceTimetable, 

97) 

98from airflow.typing_compat import Literal 

99from airflow.utils import timezone 

100from airflow.utils.dag_cycle_tester import check_cycle 

101from airflow.utils.dates import cron_presets, date_range as utils_date_range 

102from airflow.utils.decorators import fixup_decorator_warning_stack 

103from airflow.utils.helpers import at_most_one, exactly_one, validate_key 

104from airflow.utils.log.logging_mixin import LoggingMixin 

105from airflow.utils.session import NEW_SESSION, provide_session 

106from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks 

107from airflow.utils.state import DagRunState, State, TaskInstanceState 

108from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType 

109 

110if TYPE_CHECKING: 

111 from types import ModuleType 

112 

113 from airflow.datasets import Dataset 

114 from airflow.decorators import TaskDecoratorCollection 

115 from airflow.models.dagbag import DagBag 

116 from airflow.models.slamiss import SlaMiss 

117 from airflow.utils.task_group import TaskGroup 

118 

119log = logging.getLogger(__name__) 

120 

121DEFAULT_VIEW_PRESETS = ["grid", "graph", "duration", "gantt", "landing_times"] 

122ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"] 

123 

124TAG_MAX_LEN = 100 

125 

126DagStateChangeCallback = Callable[[Context], None] 

127ScheduleInterval = Union[None, str, timedelta, relativedelta] 

128 

129# FIXME: Ideally this should be Union[Literal[NOTSET], ScheduleInterval], 

130# but Mypy cannot handle that right now. Track progress of PEP 661 for progress. 

131# See also: https://discuss.python.org/t/9126/7 

132ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval] 

133ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, Collection["Dataset"]] 

134 

135SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None] 

136 

137# Backward compatibility: If neither schedule_interval nor timetable is 

138# *provided by the user*, default to a one-day interval. 

139DEFAULT_SCHEDULE_INTERVAL = timedelta(days=1) 

140 

141 

142class InconsistentDataInterval(AirflowException): 

143 """Exception raised when a model populates data interval fields incorrectly. 

144 

145 The data interval fields should either both be None (for runs scheduled 

146 prior to AIP-39), or both be datetime (for runs scheduled after AIP-39 is 

147 implemented). This is raised if exactly one of the fields is None. 

148 """ 

149 

150 _template = ( 

151 "Inconsistent {cls}: {start[0]}={start[1]!r}, {end[0]}={end[1]!r}, " 

152 "they must be either both None or both datetime" 

153 ) 

154 

155 def __init__(self, instance: Any, start_field_name: str, end_field_name: str) -> None: 

156 self._class_name = type(instance).__name__ 

157 self._start_field = (start_field_name, getattr(instance, start_field_name)) 

158 self._end_field = (end_field_name, getattr(instance, end_field_name)) 

159 

160 def __str__(self) -> str: 

161 return self._template.format(cls=self._class_name, start=self._start_field, end=self._end_field) 

162 

163 

164def _get_model_data_interval( 

165 instance: Any, 

166 start_field_name: str, 

167 end_field_name: str, 

168) -> DataInterval | None: 

169 start = timezone.coerce_datetime(getattr(instance, start_field_name)) 

170 end = timezone.coerce_datetime(getattr(instance, end_field_name)) 

171 if start is None: 

172 if end is not None: 

173 raise InconsistentDataInterval(instance, start_field_name, end_field_name) 

174 return None 

175 elif end is None: 

176 raise InconsistentDataInterval(instance, start_field_name, end_field_name) 

177 return DataInterval(start, end) 

178 

179 

180def create_timetable(interval: ScheduleIntervalArg, timezone: Timezone) -> Timetable: 

181 """Create a Timetable instance from a ``schedule_interval`` argument.""" 

182 if interval is NOTSET: 

183 return DeltaDataIntervalTimetable(DEFAULT_SCHEDULE_INTERVAL) 

184 if interval is None: 

185 return NullTimetable() 

186 if interval == "@once": 

187 return OnceTimetable() 

188 if interval == "@continuous": 

189 return ContinuousTimetable() 

190 if isinstance(interval, (timedelta, relativedelta)): 

191 return DeltaDataIntervalTimetable(interval) 

192 if isinstance(interval, str): 

193 return CronDataIntervalTimetable(interval, timezone) 

194 raise ValueError(f"{interval!r} is not a valid schedule_interval.") 
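# Editorial sketch (not part of the original module): how the shorthand
# ``schedule_interval`` values handled above resolve to timetables, assuming the
# imports already present in this file.
#
#   create_timetable(NOTSET, tz)              # DeltaDataIntervalTimetable(timedelta(days=1))
#   create_timetable(None, tz)                # NullTimetable()
#   create_timetable("@once", tz)             # OnceTimetable()
#   create_timetable("@continuous", tz)       # ContinuousTimetable()
#   create_timetable(timedelta(hours=6), tz)  # DeltaDataIntervalTimetable(timedelta(hours=6))
#   create_timetable("0 0 * * *", tz)         # CronDataIntervalTimetable("0 0 * * *", tz)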

195 

196 

197def get_last_dagrun(dag_id, session, include_externally_triggered=False): 

198 """ 

199 Returns the last dag run for a dag, None if there was none. 

200 Last dag run can be any type of run e.g. scheduled or backfilled. 

201 Overridden DagRuns are ignored. 

202 """ 

203 DR = DagRun 

204 query = session.query(DR).filter(DR.dag_id == dag_id) 

205 if not include_externally_triggered: 

206 query = query.filter(DR.external_trigger == expression.false()) 

207 query = query.order_by(DR.execution_date.desc()) 

208 return query.first() 
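# Usage sketch (editorial addition; the dag id and the open SQLAlchemy session are illustrative):
#
#   last_run = get_last_dagrun("example_dag", session=session)
#   if last_run is not None:
#       print(last_run.run_id, last_run.execution_date)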

209 

210 

211def get_dataset_triggered_next_run_info( 

212 dag_ids: list[str], *, session: Session 

213) -> dict[str, dict[str, int | str]]: 

214 """ 

215 Given a list of dag_ids, get a string representing how close any that are dataset-triggered are 

216 to their next run, e.g. "1 of 2 datasets updated". 

217 """ 

218 from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ, DatasetModel 

219 

220 return { 

221 x.dag_id: { 

222 "uri": x.uri, 

223 "ready": x.ready, 

224 "total": x.total, 

225 } 

226 for x in session.query( 

227 DagScheduleDatasetReference.dag_id, 

228 # This is a dirty hack to work around GROUP BY requiring an aggregate, since grouping by dataset 

229 # is not what we want to do here...but it works 

230 case((func.count() == 1, func.max(DatasetModel.uri)), else_="").label("uri"), 

231 func.count().label("total"), 

232 func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)).label("ready"), 

233 ) 

234 .join( 

235 DDRQ, 

236 and_( 

237 DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id, 

238 DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id, 

239 ), 

240 isouter=True, 

241 ) 

242 .join( 

243 DatasetModel, 

244 DatasetModel.id == DagScheduleDatasetReference.dataset_id, 

245 ) 

246 .group_by( 

247 DagScheduleDatasetReference.dag_id, 

248 ) 

249 .filter(DagScheduleDatasetReference.dag_id.in_(dag_ids)) 

250 .all() 

251 } 
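# Shape of the mapping returned above (editorial sketch; the dag id and URI are made up):
#
#   {"dataset_consumer_dag": {"uri": "s3://bucket/key", "ready": 1, "total": 2}}
#
# "uri" is only populated when the DAG waits on exactly one dataset; otherwise it is an
# empty string and the UI renders e.g. "1 of 2 datasets updated".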

252 

253 

254@functools.total_ordering 

255class DAG(LoggingMixin): 

256 """ 

257 A dag (directed acyclic graph) is a collection of tasks with directional 

258 dependencies. A dag also has a schedule, a start date and an end date 

259 (optional). For each schedule (say daily or hourly), the DAG needs to run 

260 each individual task as its dependencies are met. Certain tasks have 

261 the property of depending on their own past, meaning that they can't run 

262 until their previous schedule (and upstream tasks) are completed. 

263 

264 DAGs essentially act as namespaces for tasks. A task_id can only be 

265 added once to a DAG. 

266 

267 Note that if you plan to use time zones, all the dates provided should be pendulum 

268 dates. See :ref:`timezone_aware_dags`. 

269 

270 .. versionadded:: 2.4 

271 The *schedule* argument to specify either time-based scheduling logic 

272 (timetable), or dataset-driven triggers. 

273 

274 .. deprecated:: 2.4 

275 The arguments *schedule_interval* and *timetable*. Their functionalities 

276 are merged into the new *schedule* argument. 

277 

278 :param dag_id: The id of the DAG; must consist exclusively of alphanumeric 

279 characters, dashes, dots and underscores (all ASCII) 

280 :param description: The description for the DAG to e.g. be shown on the webserver 

281 :param schedule: Defines the rules according to which DAG runs are scheduled. Can 

282 accept cron string, timedelta object, Timetable, or list of Dataset objects. 

283 See also :doc:`/howto/timetable`. 

284 :param start_date: The timestamp from which the scheduler will 

285 attempt to backfill 

286 :param end_date: A date beyond which your DAG won't run, leave as None 

287 for open-ended scheduling 

288 :param template_searchpath: This list of folders (non-relative) 

289 defines where jinja will look for your templates. Order matters. 

290 Note that jinja/airflow includes the path of your DAG file by 

291 default 

292 :param template_undefined: Template undefined type. 

293 :param user_defined_macros: a dictionary of macros that will be exposed 

294 in your jinja templates. For example, passing ``dict(foo='bar')`` 

295 to this argument allows you to ``{{ foo }}`` in all jinja 

296 templates related to this DAG. Note that you can pass any 

297 type of object here. 

298 :param user_defined_filters: a dictionary of filters that will be exposed 

299 in your jinja templates. For example, passing 

300 ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows 

301 you to ``{{ 'world' | hello }}`` in all jinja templates related to 

302 this DAG. 

303 :param default_args: A dictionary of default parameters to be used 

304 as constructor keyword parameters when initialising operators. 

305 Note that operators have the same hook, and arguments passed directly 

306 to an operator take precedence, meaning that if your dict contains 

307 `'depends_on_past': True` here and `'depends_on_past': False` in the 

308 operator's call `default_args`, the actual value will be `False`. 

309 :param params: a dictionary of DAG level parameters that are made 

310 accessible in templates, namespaced under `params`. These 

311 params can be overridden at the task level. 

312 :param max_active_tasks: the number of task instances allowed to run 

313 concurrently 

314 :param max_active_runs: maximum number of active DAG runs, beyond this 

315 number of DAG runs in a running state, the scheduler won't create 

316 new active DAG runs 

317 :param dagrun_timeout: specify how long a DagRun should be up before 

318 timing out / failing, so that new DagRuns can be created. 

319 :param sla_miss_callback: specify a function or list of functions to call when reporting SLA 

320 timeouts. See :ref:`sla_miss_callback<concepts:sla_miss_callback>` for 

321 more information about the function signature and parameters that are 

322 passed to the callback. 

323 :param default_view: Specify DAG default view (grid, graph, duration, 

324 gantt, landing_times), default grid 

325 :param orientation: Specify DAG orientation in graph view (LR, TB, RL, BT), default LR 

326 :param catchup: Perform scheduler catchup (or only run latest)? Defaults to True 

327 :param on_failure_callback: A function or list of functions to be called when a DagRun of this dag fails. 

328 A context dictionary is passed as a single parameter to this function. 

329 :param on_success_callback: Much like the ``on_failure_callback`` except 

330 that it is executed when the dag succeeds. 

331 :param access_control: Specify optional DAG-level actions, e.g., 

332 "{'role1': {'can_read'}, 'role2': {'can_read', 'can_edit', 'can_delete'}}" 

333 :param is_paused_upon_creation: Specifies if the dag is paused when created for the first time. 

334 If the dag exists already, this flag will be ignored. If this optional parameter 

335 is not specified, the global config setting will be used. 

336 :param jinja_environment_kwargs: additional configuration options to be passed to Jinja 

337 ``Environment`` for template rendering 

338 

339 **Example**: to prevent Jinja from removing a trailing newline from template strings :: 

340 

341 DAG(dag_id='my-dag', 

342 jinja_environment_kwargs={ 

343 'keep_trailing_newline': True, 

344 # some other jinja2 Environment options here 

345 } 

346 ) 

347 

348 **See**: `Jinja Environment documentation 

349 <https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.Environment>`_ 

350 

351 :param render_template_as_native_obj: If True, uses a Jinja ``NativeEnvironment`` 

352 to render templates as native Python types. If False, a Jinja 

353 ``Environment`` is used to render templates as string values. 

354 :param tags: List of tags to help filtering DAGs in the UI. 

355 :param owner_links: Dict of owners and their links, that will be clickable on the DAGs view UI. 

356 Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link. 

357 e.g: {"dag_owner": "https://airflow.apache.org/"} 

358 :param auto_register: Automatically register this DAG when it is used in a ``with`` block 

359 :param fail_stop: Fails currently running tasks when a task in the DAG fails. 

360 **Warning**: A fail-stop dag can only have tasks with the default trigger rule ("all_success"). 

361 An exception will be thrown if any task in a fail-stop dag has a non-default trigger rule. 

362 """ 

363 

364 _comps = { 

365 "dag_id", 

366 "task_ids", 

367 "parent_dag", 

368 "start_date", 

369 "end_date", 

370 "schedule_interval", 

371 "fileloc", 

372 "template_searchpath", 

373 "last_loaded", 

374 } 

375 

376 __serialized_fields: frozenset[str] | None = None 

377 

378 fileloc: str 

379 """ 

380 File path that needs to be imported to load this DAG or subdag. 

381 

382 This may not be an actual file on disk in the case when this DAG is loaded 

383 from a ZIP file or other DAG distribution format. 

384 """ 

385 

386 parent_dag: DAG | None = None # Gets set when DAGs are loaded 

387 

388 # NOTE: When updating arguments here, please also keep arguments in @dag() 

389 # below in sync. (Search for 'def dag(' in this file.) 

390 def __init__( 

391 self, 

392 dag_id: str, 

393 description: str | None = None, 

394 schedule: ScheduleArg = NOTSET, 

395 schedule_interval: ScheduleIntervalArg = NOTSET, 

396 timetable: Timetable | None = None, 

397 start_date: datetime | None = None, 

398 end_date: datetime | None = None, 

399 full_filepath: str | None = None, 

400 template_searchpath: str | Iterable[str] | None = None, 

401 template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, 

402 user_defined_macros: dict | None = None, 

403 user_defined_filters: dict | None = None, 

404 default_args: dict | None = None, 

405 concurrency: int | None = None, 

406 max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"), 

407 max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"), 

408 dagrun_timeout: timedelta | None = None, 

409 sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None, 

410 default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(), 

411 orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"), 

412 catchup: bool = conf.getboolean("scheduler", "catchup_by_default"), 

413 on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

414 on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

415 doc_md: str | None = None, 

416 params: collections.abc.MutableMapping | None = None, 

417 access_control: dict | None = None, 

418 is_paused_upon_creation: bool | None = None, 

419 jinja_environment_kwargs: dict | None = None, 

420 render_template_as_native_obj: bool = False, 

421 tags: list[str] | None = None, 

422 owner_links: dict[str, str] | None = None, 

423 auto_register: bool = True, 

424 fail_stop: bool = False, 

425 ): 

426 from airflow.utils.task_group import TaskGroup 

427 

428 if tags and any(len(tag) > TAG_MAX_LEN for tag in tags): 

429 raise AirflowException(f"tag cannot be longer than {TAG_MAX_LEN} characters") 

430 

431 self.owner_links = owner_links if owner_links else {} 

432 self.user_defined_macros = user_defined_macros 

433 self.user_defined_filters = user_defined_filters 

434 if default_args and not isinstance(default_args, dict): 

435 raise TypeError("default_args must be a dict") 

436 self.default_args = copy.deepcopy(default_args or {}) 

437 params = params or {} 

438 

439 # merging potentially conflicting default_args['params'] into params 

440 if "params" in self.default_args: 

441 params.update(self.default_args["params"]) 

442 del self.default_args["params"] 

443 

444 # check self.params and convert them into ParamsDict 

445 self.params = ParamsDict(params) 

446 

447 if full_filepath: 

448 warnings.warn( 

449 "Passing full_filepath to DAG() is deprecated and has no effect", 

450 RemovedInAirflow3Warning, 

451 stacklevel=2, 

452 ) 

453 

454 validate_key(dag_id) 

455 

456 self._dag_id = dag_id 

457 if concurrency: 

458 # TODO: Remove in Airflow 3.0 

459 warnings.warn( 

460 "The 'concurrency' parameter is deprecated. Please use 'max_active_tasks'.", 

461 RemovedInAirflow3Warning, 

462 stacklevel=2, 

463 ) 

464 max_active_tasks = concurrency 

465 self._max_active_tasks = max_active_tasks 

466 self._pickle_id: int | None = None 

467 

468 self._description = description 

469 # set file location to caller source path 

470 back = sys._getframe().f_back 

471 self.fileloc = back.f_code.co_filename if back else "" 

472 self.task_dict: dict[str, Operator] = {} 

473 

474 # set timezone from start_date 

475 tz = None 

476 if start_date and start_date.tzinfo: 

477 tzinfo = None if start_date.tzinfo else settings.TIMEZONE 

478 tz = pendulum.instance(start_date, tz=tzinfo).timezone 

479 elif "start_date" in self.default_args and self.default_args["start_date"]: 

480 date = self.default_args["start_date"] 

481 if not isinstance(date, datetime): 

482 date = timezone.parse(date) 

483 self.default_args["start_date"] = date 

484 start_date = date 

485 

486 tzinfo = None if date.tzinfo else settings.TIMEZONE 

487 tz = pendulum.instance(date, tz=tzinfo).timezone 

488 self.timezone = tz or settings.TIMEZONE 

489 

490 # Apply the timezone we settled on to end_date if it wasn't supplied 

491 if "end_date" in self.default_args and self.default_args["end_date"]: 

492 if isinstance(self.default_args["end_date"], str): 

493 self.default_args["end_date"] = timezone.parse( 

494 self.default_args["end_date"], timezone=self.timezone 

495 ) 

496 

497 self.start_date = timezone.convert_to_utc(start_date) 

498 self.end_date = timezone.convert_to_utc(end_date) 

499 

500 # also convert tasks 

501 if "start_date" in self.default_args: 

502 self.default_args["start_date"] = timezone.convert_to_utc(self.default_args["start_date"]) 

503 if "end_date" in self.default_args: 

504 self.default_args["end_date"] = timezone.convert_to_utc(self.default_args["end_date"]) 

505 

506 # sort out DAG's scheduling behavior 

507 scheduling_args = [schedule_interval, timetable, schedule] 

508 if not at_most_one(*scheduling_args): 

509 raise ValueError("At most one allowed for args 'schedule_interval', 'timetable', and 'schedule'.") 

510 if schedule_interval is not NOTSET: 

511 warnings.warn( 

512 "Param `schedule_interval` is deprecated and will be removed in a future release. " 

513 "Please use `schedule` instead. ", 

514 RemovedInAirflow3Warning, 

515 stacklevel=2, 

516 ) 

517 if timetable is not None: 

518 warnings.warn( 

519 "Param `timetable` is deprecated and will be removed in a future release. " 

520 "Please use `schedule` instead. ", 

521 RemovedInAirflow3Warning, 

522 stacklevel=2, 

523 ) 

524 

525 self.timetable: Timetable 

526 self.schedule_interval: ScheduleInterval 

527 self.dataset_triggers: Collection[Dataset] = [] 

528 

529 if isinstance(schedule, Collection) and not isinstance(schedule, str): 

530 from airflow.datasets import Dataset 

531 

532 if not all(isinstance(x, Dataset) for x in schedule): 

533 raise ValueError("All elements in 'schedule' should be datasets") 

534 self.dataset_triggers = list(schedule) 

535 elif isinstance(schedule, Timetable): 

536 timetable = schedule 

537 elif schedule is not NOTSET: 

538 schedule_interval = schedule 

539 

540 if self.dataset_triggers: 

541 self.timetable = DatasetTriggeredTimetable() 

542 self.schedule_interval = self.timetable.summary 

543 elif timetable: 

544 self.timetable = timetable 

545 self.schedule_interval = self.timetable.summary 

546 else: 

547 if isinstance(schedule_interval, ArgNotSet): 

548 schedule_interval = DEFAULT_SCHEDULE_INTERVAL 

549 self.schedule_interval = schedule_interval 

550 self.timetable = create_timetable(schedule_interval, self.timezone) 
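# Editorial sketch of how the branches above resolve (MyTimetable is a hypothetical
# user-defined Timetable subclass):
#
#   DAG("d", schedule=[Dataset("s3://bucket/key")])  # -> DatasetTriggeredTimetable
#   DAG("d", schedule=MyTimetable())                 # -> that timetable, used as-is
#   DAG("d", schedule="@daily")                      # -> CronDataIntervalTimetable
#   DAG("d")                                         # -> DeltaDataIntervalTimetable(timedelta(days=1))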

551 

552 if isinstance(template_searchpath, str): 

553 template_searchpath = [template_searchpath] 

554 self.template_searchpath = template_searchpath 

555 self.template_undefined = template_undefined 

556 self.last_loaded = timezone.utcnow() 

557 self.safe_dag_id = dag_id.replace(".", "__dot__") 

558 self.max_active_runs = max_active_runs 

559 if self.timetable.active_runs_limit is not None: 

560 if self.timetable.active_runs_limit < self.max_active_runs: 

561 raise AirflowException( 

562 f"Invalid max_active_runs: {type(self.timetable)} " 

563 f"requires max_active_runs <= {self.timetable.active_runs_limit}" 

564 ) 

565 self.dagrun_timeout = dagrun_timeout 

566 self.sla_miss_callback = sla_miss_callback 

567 if default_view in DEFAULT_VIEW_PRESETS: 

568 self._default_view: str = default_view 

569 elif default_view == "tree": 

570 warnings.warn( 

571 "`default_view` of 'tree' has been renamed to 'grid' -- please update your DAG", 

572 RemovedInAirflow3Warning, 

573 stacklevel=2, 

574 ) 

575 self._default_view = "grid" 

576 else: 

577 raise AirflowException( 

578 f"Invalid values of dag.default_view: only support " 

579 f"{DEFAULT_VIEW_PRESETS}, but get {default_view}" 

580 ) 

581 if orientation in ORIENTATION_PRESETS: 

582 self.orientation = orientation 

583 else: 

584 raise AirflowException( 

585 f"Invalid values of dag.orientation: only support " 

586 f"{ORIENTATION_PRESETS}, but get {orientation}" 

587 ) 

588 self.catchup = catchup 

589 

590 self.partial = False 

591 self.on_success_callback = on_success_callback 

592 self.on_failure_callback = on_failure_callback 

593 

594 # Keeps track of any extra edge metadata (sparse; will not contain all 

595 # edges, so do not iterate over it for that). Outer key is upstream 

596 # task ID, inner key is downstream task ID. 

597 self.edge_info: dict[str, dict[str, EdgeInfoType]] = {} 

598 

599 # To keep it in parity with Serialized DAGs 

600 # and identify if DAG has on_*_callback without actually storing them in Serialized JSON 

601 self.has_on_success_callback = self.on_success_callback is not None 

602 self.has_on_failure_callback = self.on_failure_callback is not None 

603 

604 self._access_control = DAG._upgrade_outdated_dag_access_control(access_control) 

605 self.is_paused_upon_creation = is_paused_upon_creation 

606 self.auto_register = auto_register 

607 

608 self.fail_stop = fail_stop 

609 

610 self.jinja_environment_kwargs = jinja_environment_kwargs 

611 self.render_template_as_native_obj = render_template_as_native_obj 

612 

613 self.doc_md = self.get_doc_md(doc_md) 

614 

615 self.tags = tags or [] 

616 self._task_group = TaskGroup.create_root(self) 

617 self.validate_schedule_and_params() 

618 wrong_links = dict(self.iter_invalid_owner_links()) 

619 if wrong_links: 

620 raise AirflowException( 

621 "Wrong link format was used for the owner. Use a valid link \n" 

622 f"Bad formatted links are: {wrong_links}" 

623 ) 

624 

625 # this will only be set at serialization time 

626 # its only use is for determining the relative 

627 # fileloc based only on the serialized dag 

628 self._processor_dags_folder = None 

629 

630 def get_doc_md(self, doc_md: str | None) -> str | None: 

631 if doc_md is None: 

632 return doc_md 

633 

634 env = self.get_template_env(force_sandboxed=True) 

635 

636 if not doc_md.endswith(".md"): 

637 template = jinja2.Template(doc_md) 

638 else: 

639 try: 

640 template = env.get_template(doc_md) 

641 except jinja2.exceptions.TemplateNotFound: 

642 return f""" 

643 # Templating Error! 

644 Not able to find the template file: `{doc_md}`. 

645 """ 

646 

647 return template.render() 

648 

649 def _check_schedule_interval_matches_timetable(self) -> bool: 

650 """Check ``schedule_interval`` and ``timetable`` match. 

651 

652 This is done as a part of the DAG validation done before it's bagged, to 

653 guard against the DAG's ``timetable`` (or ``schedule_interval``) being 

654 changed after it's created, e.g. 

655 

656 .. code-block:: python 

657 

658 dag1 = DAG("d1", timetable=MyTimetable()) 

659 dag1.schedule_interval = "@once" 

660 

661 dag2 = DAG("d2", schedule="@once") 

662 dag2.timetable = MyTimetable() 

663 

664 Validation is done by creating a timetable and checking its summary matches 

665 ``schedule_interval``. The logic is not bullet-proof, especially if a 

666 custom timetable does not provide a useful ``summary``. But this is the 

667 best we can do. 

668 """ 

669 if self.schedule_interval == self.timetable.summary: 

670 return True 

671 try: 

672 timetable = create_timetable(self.schedule_interval, self.timezone) 

673 except ValueError: 

674 return False 

675 return timetable.summary == self.timetable.summary 

676 

677 def validate(self): 

678 """Validate the DAG has a coherent setup. 

679 

680 This is called by the DAG bag before bagging the DAG. 

681 """ 

682 if not self._check_schedule_interval_matches_timetable(): 

683 raise AirflowDagInconsistent( 

684 f"inconsistent schedule: timetable {self.timetable.summary!r} " 

685 f"does not match schedule_interval {self.schedule_interval!r}", 

686 ) 

687 self.params.validate() 

688 self.timetable.validate() 

689 

690 def __repr__(self): 

691 return f"<DAG: {self.dag_id}>" 

692 

693 def __eq__(self, other): 

694 if type(self) == type(other): 

695 # Use getattr() instead of __dict__ as __dict__ doesn't return 

696 # correct values for properties. 

697 return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps) 

698 return False 

699 

700 def __ne__(self, other): 

701 return not self == other 

702 

703 def __lt__(self, other): 

704 return self.dag_id < other.dag_id 

705 

706 def __hash__(self): 

707 hash_components = [type(self)] 

708 for c in self._comps: 

709 # task_ids returns a list and lists can't be hashed 

710 if c == "task_ids": 

711 val = tuple(self.task_dict.keys()) 

712 else: 

713 val = getattr(self, c, None) 

714 try: 

715 hash(val) 

716 hash_components.append(val) 

717 except TypeError: 

718 hash_components.append(repr(val)) 

719 return hash(tuple(hash_components)) 

720 

721 # Context Manager ----------------------------------------------- 

722 def __enter__(self): 

723 DagContext.push_context_managed_dag(self) 

724 return self 

725 

726 def __exit__(self, _type, _value, _tb): 

727 DagContext.pop_context_managed_dag() 

728 

729 # /Context Manager ---------------------------------------------- 

730 

731 @staticmethod 

732 def _upgrade_outdated_dag_access_control(access_control=None): 

733 """ 

734 Looks for outdated dag level actions (can_dag_read and can_dag_edit) in DAG 

735 access_controls (for example, {'role1': {'can_dag_read'}, 'role2': {'can_dag_read', 'can_dag_edit'}}) 

736 and replaces them with updated actions (can_read and can_edit). 

737 """ 

738 if not access_control: 

739 return None 

740 new_perm_mapping = { 

741 permissions.DEPRECATED_ACTION_CAN_DAG_READ: permissions.ACTION_CAN_READ, 

742 permissions.DEPRECATED_ACTION_CAN_DAG_EDIT: permissions.ACTION_CAN_EDIT, 

743 } 

744 updated_access_control = {} 

745 for role, perms in access_control.items(): 

746 updated_access_control[role] = {new_perm_mapping.get(perm, perm) for perm in perms} 

747 

748 if access_control != updated_access_control: 

749 warnings.warn( 

750 "The 'can_dag_read' and 'can_dag_edit' permissions are deprecated. " 

751 "Please use 'can_read' and 'can_edit', respectively.", 

752 RemovedInAirflow3Warning, 

753 stacklevel=3, 

754 ) 

755 

756 return updated_access_control 
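# Example of the upgrade performed above (editorial sketch; role names are made up):
#
#   {"viewer": {"can_dag_read"}, "ops": {"can_dag_read", "can_dag_edit"}}
#       becomes
#   {"viewer": {"can_read"}, "ops": {"can_read", "can_edit"}}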

757 

758 def date_range( 

759 self, 

760 start_date: pendulum.DateTime, 

761 num: int | None = None, 

762 end_date: datetime | None = None, 

763 ) -> list[datetime]: 

764 message = "`DAG.date_range()` is deprecated." 

765 if num is not None: 

766 warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2) 

767 with warnings.catch_warnings(): 

768 warnings.simplefilter("ignore", RemovedInAirflow3Warning) 

769 return utils_date_range( 

770 start_date=start_date, num=num, delta=self.normalized_schedule_interval 

771 ) 

772 message += " Please use `DAG.iter_dagrun_infos_between(..., align=False)` instead." 

773 warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2) 

774 if end_date is None: 

775 coerced_end_date = timezone.utcnow() 

776 else: 

777 coerced_end_date = end_date 

778 it = self.iter_dagrun_infos_between(start_date, pendulum.instance(coerced_end_date), align=False) 

779 return [info.logical_date for info in it] 

780 

781 def is_fixed_time_schedule(self): 

782 warnings.warn( 

783 "`DAG.is_fixed_time_schedule()` is deprecated.", 

784 category=RemovedInAirflow3Warning, 

785 stacklevel=2, 

786 ) 

787 try: 

788 return not self.timetable._should_fix_dst 

789 except AttributeError: 

790 return True 

791 

792 def following_schedule(self, dttm): 

793 """ 

794 Calculates the following schedule for this dag in UTC. 

795 

796 :param dttm: utc datetime 

797 :return: utc datetime 

798 """ 

799 warnings.warn( 

800 "`DAG.following_schedule()` is deprecated. Use `DAG.next_dagrun_info(restricted=False)` instead.", 

801 category=RemovedInAirflow3Warning, 

802 stacklevel=2, 

803 ) 

804 data_interval = self.infer_automated_data_interval(timezone.coerce_datetime(dttm)) 

805 next_info = self.next_dagrun_info(data_interval, restricted=False) 

806 if next_info is None: 

807 return None 

808 return next_info.data_interval.start 

809 

810 def previous_schedule(self, dttm): 

811 from airflow.timetables.interval import _DataIntervalTimetable 

812 

813 warnings.warn( 

814 "`DAG.previous_schedule()` is deprecated.", 

815 category=RemovedInAirflow3Warning, 

816 stacklevel=2, 

817 ) 

818 if not isinstance(self.timetable, _DataIntervalTimetable): 

819 return None 

820 return self.timetable._get_prev(timezone.coerce_datetime(dttm)) 

821 

822 def get_next_data_interval(self, dag_model: DagModel) -> DataInterval | None: 

823 """Get the data interval of the next scheduled run. 

824 

825 For compatibility, this method infers the data interval from the DAG's 

826 schedule if the run does not have an explicit one set, which is possible 

827 for runs created prior to AIP-39. 

828 

829 This function is private to Airflow core and should not be depended on as a 

830 part of the Python API. 

831 

832 :meta private: 

833 """ 

834 if self.dag_id != dag_model.dag_id: 

835 raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {dag_model.dag_id}") 

836 if dag_model.next_dagrun is None: # Next run not scheduled. 

837 return None 

838 data_interval = dag_model.next_dagrun_data_interval 

839 if data_interval is not None: 

840 return data_interval 

841 

842 # Compatibility: A run was scheduled without an explicit data interval. 

843 # This means the run was scheduled before AIP-39 implementation. Try to 

844 # infer from the logical date. 

845 return self.infer_automated_data_interval(dag_model.next_dagrun) 

846 

847 def get_run_data_interval(self, run: DagRun) -> DataInterval: 

848 """Get the data interval of this run. 

849 

850 For compatibility, this method infers the data interval from the DAG's 

851 schedule if the run does not have an explicit one set, which is possible for 

852 runs created prior to AIP-39. 

853 

854 This function is private to Airflow core and should not be depended on as a 

855 part of the Python API. 

856 

857 :meta private: 

858 """ 

859 if run.dag_id is not None and run.dag_id != self.dag_id: 

860 raise ValueError(f"Arguments refer to different DAGs: {self.dag_id} != {run.dag_id}") 

861 data_interval = _get_model_data_interval(run, "data_interval_start", "data_interval_end") 

862 if data_interval is not None: 

863 return data_interval 

864 # Compatibility: runs created before AIP-39 implementation don't have an 

865 # explicit data interval. Try to infer from the logical date. 

866 return self.infer_automated_data_interval(run.execution_date) 

867 

868 def infer_automated_data_interval(self, logical_date: datetime) -> DataInterval: 

869 """Infer a data interval for a run against this DAG. 

870 

871 This method is used to bridge runs created prior to AIP-39 

872 implementation, which do not have an explicit data interval. Therefore, 

873 this method only considers ``schedule_interval`` values valid prior to 

874 Airflow 2.2. 

875 

876 DO NOT use this method if there is a known data interval. 

877 """ 

878 timetable_type = type(self.timetable) 

879 if issubclass(timetable_type, (NullTimetable, OnceTimetable, DatasetTriggeredTimetable)): 

880 return DataInterval.exact(timezone.coerce_datetime(logical_date)) 

881 start = timezone.coerce_datetime(logical_date) 

882 if issubclass(timetable_type, CronDataIntervalTimetable): 

883 end = cast(CronDataIntervalTimetable, self.timetable)._get_next(start) 

884 elif issubclass(timetable_type, DeltaDataIntervalTimetable): 

885 end = cast(DeltaDataIntervalTimetable, self.timetable)._get_next(start) 

886 else: 

887 raise ValueError(f"Not a valid timetable: {self.timetable!r}") 

888 return DataInterval(start, end) 
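# Editorial sketch of the inference above: with a daily cron timetable, a pre-AIP-39
# run with logical date 2021-06-03T00:00:00+00:00 is assigned the interval
# (2021-06-03T00:00, 2021-06-04T00:00); @once, None, and dataset-triggered schedules
# get a zero-length "exact" interval at the logical date.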

889 

890 def next_dagrun_info( 

891 self, 

892 last_automated_dagrun: None | datetime | DataInterval, 

893 *, 

894 restricted: bool = True, 

895 ) -> DagRunInfo | None: 

896 """Get information about the next DagRun of this dag after ``date_last_automated_dagrun``. 

897 

898 This calculates what time interval the next DagRun should operate on 

899 (its execution date) and when it can be scheduled, according to the 

900 dag's timetable, start_date, end_date, etc. This doesn't check max 

901 active run or any other "max_active_tasks" type limits, but only 

902 performs calculations based on the various date and interval fields of 

903 this dag and its tasks. 

904 

905 :param last_automated_dagrun: The ``max(execution_date)`` of 

906 existing "automated" DagRuns for this dag (scheduled or backfill, 

907 but not manual). 

908 :param restricted: If set to *False* (default is *True*), ignore 

909 ``start_date``, ``end_date``, and ``catchup`` specified on the DAG 

910 or tasks. 

911 :return: DagRunInfo of the next dagrun, or None if a dagrun is not 

912 going to be scheduled. 

913 """ 

914 # Never schedule a subdag. It will be scheduled by its parent dag. 

915 if self.is_subdag: 

916 return None 

917 

918 data_interval = None 

919 if isinstance(last_automated_dagrun, datetime): 

920 warnings.warn( 

921 "Passing a datetime to DAG.next_dagrun_info is deprecated. Use a DataInterval instead.", 

922 RemovedInAirflow3Warning, 

923 stacklevel=2, 

924 ) 

925 data_interval = self.infer_automated_data_interval( 

926 timezone.coerce_datetime(last_automated_dagrun) 

927 ) 

928 else: 

929 data_interval = last_automated_dagrun 

930 if restricted: 

931 restriction = self._time_restriction 

932 else: 

933 restriction = TimeRestriction(earliest=None, latest=None, catchup=True) 

934 try: 

935 info = self.timetable.next_dagrun_info( 

936 last_automated_data_interval=data_interval, 

937 restriction=restriction, 

938 ) 

939 except Exception: 

940 self.log.exception( 

941 "Failed to fetch run info after data interval %s for DAG %r", 

942 data_interval, 

943 self.dag_id, 

944 ) 

945 info = None 

946 return info 
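# Usage sketch (editorial addition): the scheduler-style call pattern for the method above.
#
#   info = dag.next_dagrun_info(last_automated_dagrun=None)
#   if info is not None:
#       info.logical_date, info.data_interval, info.run_after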

947 

948 def next_dagrun_after_date(self, date_last_automated_dagrun: pendulum.DateTime | None): 

949 warnings.warn( 

950 "`DAG.next_dagrun_after_date()` is deprecated. Please use `DAG.next_dagrun_info()` instead.", 

951 category=RemovedInAirflow3Warning, 

952 stacklevel=2, 

953 ) 

954 if date_last_automated_dagrun is None: 

955 data_interval = None 

956 else: 

957 data_interval = self.infer_automated_data_interval(date_last_automated_dagrun) 

958 info = self.next_dagrun_info(data_interval) 

959 if info is None: 

960 return None 

961 return info.run_after 

962 

963 @functools.cached_property 

964 def _time_restriction(self) -> TimeRestriction: 

965 start_dates = [t.start_date for t in self.tasks if t.start_date] 

966 if self.start_date is not None: 

967 start_dates.append(self.start_date) 

968 earliest = None 

969 if start_dates: 

970 earliest = timezone.coerce_datetime(min(start_dates)) 

971 latest = self.end_date 

972 end_dates = [t.end_date for t in self.tasks if t.end_date] 

973 if len(end_dates) == len(self.tasks): # no task has a null end_date 

974 if self.end_date is not None: 

975 end_dates.append(self.end_date) 

976 if end_dates: 

977 latest = timezone.coerce_datetime(max(end_dates)) 

978 return TimeRestriction(earliest, latest, self.catchup) 

979 

980 def iter_dagrun_infos_between( 

981 self, 

982 earliest: pendulum.DateTime | None, 

983 latest: pendulum.DateTime, 

984 *, 

985 align: bool = True, 

986 ) -> Iterable[DagRunInfo]: 

987 """Yield DagRunInfo using this DAG's timetable between given interval. 

988 

989 DagRunInfo instances are yielded if their ``logical_date`` is not earlier 

990 than ``earliest``, nor later than ``latest``. The instances are ordered 

991 by their ``logical_date`` from earliest to latest. 

992 

993 If ``align`` is ``False``, the first run will happen immediately on 

994 ``earliest``, even if it does not fall on the logical timetable schedule. 

995 The default is ``True``, but subdags will ignore this value and always 

996 behave as if this is set to ``False`` for backward compatibility. 

997 

998 Example: A DAG is scheduled to run every midnight (``0 0 * * *``). If 

999 ``earliest`` is ``2021-06-03 23:00:00``, the first DagRunInfo would be 

1000 ``2021-06-03 23:00:00`` if ``align=False``, and ``2021-06-04 00:00:00`` 

1001 if ``align=True``. 

1002 """ 

1003 if earliest is None: 

1004 earliest = self._time_restriction.earliest 

1005 if earliest is None: 

1006 raise ValueError("earliest was None and we had no value in time_restriction to fallback on") 

1007 earliest = timezone.coerce_datetime(earliest) 

1008 latest = timezone.coerce_datetime(latest) 

1009 

1010 restriction = TimeRestriction(earliest, latest, catchup=True) 

1011 

1012 # HACK: Sub-DAGs are currently scheduled differently. For example, say 

1013 # the schedule is @daily and start is 2021-06-03 22:16:00, a top-level 

1014 # DAG should be first scheduled to run on midnight 2021-06-04, but a 

1015 # sub-DAG should be first scheduled to run RIGHT NOW. We can change 

1016 # this, but since sub-DAGs are going away in 3.0 anyway, let's keep 

1017 # compatibility for now and remove this entirely later. 

1018 if self.is_subdag: 

1019 align = False 

1020 

1021 try: 

1022 info = self.timetable.next_dagrun_info( 

1023 last_automated_data_interval=None, 

1024 restriction=restriction, 

1025 ) 

1026 except Exception: 

1027 self.log.exception( 

1028 "Failed to fetch run info after data interval %s for DAG %r", 

1029 None, 

1030 self.dag_id, 

1031 ) 

1032 info = None 

1033 

1034 if info is None: 

1035 # No runs to be scheduled within the user-supplied timeframe. But 

1036 # if align=False, "invent" a data interval for the timeframe itself. 

1037 if not align: 

1038 yield DagRunInfo.interval(earliest, latest) 

1039 return 

1040 

1041 # If align=False and earliest does not fall on the timetable's logical 

1042 # schedule, "invent" a data interval for it. 

1043 if not align and info.logical_date != earliest: 

1044 yield DagRunInfo.interval(earliest, info.data_interval.start) 

1045 

1046 # Generate naturally according to schedule. 

1047 while info is not None: 

1048 yield info 

1049 try: 

1050 info = self.timetable.next_dagrun_info( 

1051 last_automated_data_interval=info.data_interval, 

1052 restriction=restriction, 

1053 ) 

1054 except Exception: 

1055 self.log.exception( 

1056 "Failed to fetch run info after data interval %s for DAG %r", 

1057 info.data_interval if info else "<NONE>", 

1058 self.dag_id, 

1059 ) 

1060 break 

1061 

1062 def get_run_dates(self, start_date, end_date=None) -> list: 

1063 """ 

1064 Returns a list of dates within the interval received as parameters, using this 

1065 dag's schedule interval. Returned dates can be used as execution dates. 

1066 

1067 :param start_date: The start date of the interval. 

1068 :param end_date: The end date of the interval. Defaults to ``timezone.utcnow()``. 

1069 :return: A list of dates within the interval following the dag's schedule. 

1070 """ 

1071 warnings.warn( 

1072 "`DAG.get_run_dates()` is deprecated. Please use `DAG.iter_dagrun_infos_between()` instead.", 

1073 category=RemovedInAirflow3Warning, 

1074 stacklevel=2, 

1075 ) 

1076 earliest = timezone.coerce_datetime(start_date) 

1077 if end_date is None: 

1078 latest = pendulum.now(timezone.utc) 

1079 else: 

1080 latest = timezone.coerce_datetime(end_date) 

1081 return [info.logical_date for info in self.iter_dagrun_infos_between(earliest, latest)] 

1082 

1083 def normalize_schedule(self, dttm): 

1084 warnings.warn( 

1085 "`DAG.normalize_schedule()` is deprecated.", 

1086 category=RemovedInAirflow3Warning, 

1087 stacklevel=2, 

1088 ) 

1089 with warnings.catch_warnings(): 

1090 warnings.simplefilter("ignore", RemovedInAirflow3Warning) 

1091 following = self.following_schedule(dttm) 

1092 if not following: # in case of @once 

1093 return dttm 

1094 with warnings.catch_warnings(): 

1095 warnings.simplefilter("ignore", RemovedInAirflow3Warning) 

1096 previous_of_following = self.previous_schedule(following) 

1097 if previous_of_following != dttm: 

1098 return following 

1099 return dttm 

1100 

1101 @provide_session 

1102 def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False): 

1103 return get_last_dagrun( 

1104 self.dag_id, session=session, include_externally_triggered=include_externally_triggered 

1105 ) 

1106 

1107 @provide_session 

1108 def has_dag_runs(self, session=NEW_SESSION, include_externally_triggered=True) -> bool: 

1109 return ( 

1110 get_last_dagrun( 

1111 self.dag_id, session=session, include_externally_triggered=include_externally_triggered 

1112 ) 

1113 is not None 

1114 ) 

1115 

1116 @property 

1117 def dag_id(self) -> str: 

1118 return self._dag_id 

1119 

1120 @dag_id.setter 

1121 def dag_id(self, value: str) -> None: 

1122 self._dag_id = value 

1123 

1124 @property 

1125 def is_subdag(self) -> bool: 

1126 return self.parent_dag is not None 

1127 

1128 @property 

1129 def full_filepath(self) -> str: 

1130 """Full file path to the DAG. 

1131 

1132 :meta private: 

1133 """ 

1134 warnings.warn( 

1135 "DAG.full_filepath is deprecated in favour of fileloc", 

1136 RemovedInAirflow3Warning, 

1137 stacklevel=2, 

1138 ) 

1139 return self.fileloc 

1140 

1141 @full_filepath.setter 

1142 def full_filepath(self, value) -> None: 

1143 warnings.warn( 

1144 "DAG.full_filepath is deprecated in favour of fileloc", 

1145 RemovedInAirflow3Warning, 

1146 stacklevel=2, 

1147 ) 

1148 self.fileloc = value 

1149 

1150 @property 

1151 def concurrency(self) -> int: 

1152 # TODO: Remove in Airflow 3.0 

1153 warnings.warn( 

1154 "The 'DAG.concurrency' attribute is deprecated. Please use 'DAG.max_active_tasks'.", 

1155 RemovedInAirflow3Warning, 

1156 stacklevel=2, 

1157 ) 

1158 return self._max_active_tasks 

1159 

1160 @concurrency.setter 

1161 def concurrency(self, value: int): 

1162 self._max_active_tasks = value 

1163 

1164 @property 

1165 def max_active_tasks(self) -> int: 

1166 return self._max_active_tasks 

1167 

1168 @max_active_tasks.setter 

1169 def max_active_tasks(self, value: int): 

1170 self._max_active_tasks = value 

1171 

1172 @property 

1173 def access_control(self): 

1174 return self._access_control 

1175 

1176 @access_control.setter 

1177 def access_control(self, value): 

1178 self._access_control = DAG._upgrade_outdated_dag_access_control(value) 

1179 

1180 @property 

1181 def description(self) -> str | None: 

1182 return self._description 

1183 

1184 @property 

1185 def default_view(self) -> str: 

1186 return self._default_view 

1187 

1188 @property 

1189 def pickle_id(self) -> int | None: 

1190 return self._pickle_id 

1191 

1192 @pickle_id.setter 

1193 def pickle_id(self, value: int) -> None: 

1194 self._pickle_id = value 

1195 

1196 def param(self, name: str, default: Any = NOTSET) -> DagParam: 

1197 """ 

1198 Return a DagParam object for current dag. 

1199 

1200 :param name: dag parameter name. 

1201 :param default: fallback value for dag parameter. 

1202 :return: DagParam instance for specified name and current dag. 

1203 """ 

1204 return DagParam(current_dag=self, name=name, default=default) 

1205 

1206 @property 

1207 def tasks(self) -> list[Operator]: 

1208 return list(self.task_dict.values()) 

1209 

1210 @tasks.setter 

1211 def tasks(self, val): 

1212 raise AttributeError("DAG.tasks can not be modified. Use dag.add_task() instead.") 

1213 

1214 @property 

1215 def task_ids(self) -> list[str]: 

1216 return list(self.task_dict.keys()) 

1217 

1218 @property 

1219 def teardowns(self) -> list[Operator]: 

1220 return [task for task in self.tasks if getattr(task, "is_teardown", None)] 

1221 

1222 @property 

1223 def tasks_upstream_of_teardowns(self) -> list[Operator]: 

1224 upstream_tasks = [t.upstream_list for t in self.teardowns] 

1225 return [val for sublist in upstream_tasks for val in sublist if not getattr(val, "is_teardown", None)] 

1226 

1227 @property 

1228 def task_group(self) -> TaskGroup: 

1229 return self._task_group 

1230 

1231 @property 

1232 def filepath(self) -> str: 

1233 """Relative file path to the DAG. 

1234 

1235 :meta private: 

1236 """ 

1237 warnings.warn( 

1238 "filepath is deprecated, use relative_fileloc instead", 

1239 RemovedInAirflow3Warning, 

1240 stacklevel=2, 

1241 ) 

1242 return str(self.relative_fileloc) 

1243 

1244 @property 

1245 def relative_fileloc(self) -> pathlib.Path: 

1246 """File location of the importable dag 'file' relative to the configured DAGs folder.""" 

1247 path = pathlib.Path(self.fileloc) 

1248 try: 

1249 rel_path = path.relative_to(self._processor_dags_folder or settings.DAGS_FOLDER) 

1250 if rel_path == pathlib.Path("."): 

1251 return path 

1252 else: 

1253 return rel_path 

1254 except ValueError: 

1255 # Not relative to DAGS_FOLDER. 

1256 return path 

1257 

1258 @property 

1259 def folder(self) -> str: 

1260 """Folder location of where the DAG object is instantiated.""" 

1261 return os.path.dirname(self.fileloc) 

1262 

1263 @property 

1264 def owner(self) -> str: 

1265 """ 

1266 Return a comma-separated string of all owners found in DAG tasks. 

1267 

1268 :return: Comma separated list of owners in DAG tasks 

1269 """ 

1270 return ", ".join({t.owner for t in self.tasks}) 

1271 

1272 @property 

1273 def allow_future_exec_dates(self) -> bool: 

1274 return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_be_scheduled 

1275 

1276 @provide_session 

1277 def get_concurrency_reached(self, session=NEW_SESSION) -> bool: 

1278 """ 

1279 Returns a boolean indicating whether the max_active_tasks limit for this DAG 

1280 has been reached. 

1281 """ 

1282 TI = TaskInstance 

1283 qry = session.query(func.count(TI.task_id)).filter( 

1284 TI.dag_id == self.dag_id, 

1285 TI.state == State.RUNNING, 

1286 ) 

1287 return qry.scalar() >= self.max_active_tasks 

1288 

1289 @property 

1290 def concurrency_reached(self): 

1291 """This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method.""" 

1292 warnings.warn( 

1293 "This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method.", 

1294 RemovedInAirflow3Warning, 

1295 stacklevel=2, 

1296 ) 

1297 return self.get_concurrency_reached() 

1298 

1299 @provide_session 

1300 def get_is_active(self, session=NEW_SESSION) -> bool | None: 

1301 """Returns a boolean indicating whether this DAG is active.""" 

1302 return session.query(DagModel.is_active).filter(DagModel.dag_id == self.dag_id).scalar() 

1303 

1304 @provide_session 

1305 def get_is_paused(self, session=NEW_SESSION) -> bool | None: 

1306 """Returns a boolean indicating whether this DAG is paused.""" 

1307 return session.query(DagModel.is_paused).filter(DagModel.dag_id == self.dag_id).scalar() 

1308 

1309 @property 

1310 def is_paused(self): 

1311 """This attribute is deprecated. Please use `airflow.models.DAG.get_is_paused` method.""" 

1312 warnings.warn( 

1313 "This attribute is deprecated. Please use `airflow.models.DAG.get_is_paused` method.", 

1314 RemovedInAirflow3Warning, 

1315 stacklevel=2, 

1316 ) 

1317 return self.get_is_paused() 

1318 

1319 @property 

1320 def normalized_schedule_interval(self) -> ScheduleInterval: 

1321 warnings.warn( 

1322 "DAG.normalized_schedule_interval() is deprecated.", 

1323 category=RemovedInAirflow3Warning, 

1324 stacklevel=2, 

1325 ) 

1326 if isinstance(self.schedule_interval, str) and self.schedule_interval in cron_presets: 

1327 _schedule_interval: ScheduleInterval = cron_presets.get(self.schedule_interval) 

1328 elif self.schedule_interval == "@once": 

1329 _schedule_interval = None 

1330 else: 

1331 _schedule_interval = self.schedule_interval 

1332 return _schedule_interval 

1333 

1334 @provide_session 

1335 def handle_callback(self, dagrun, success=True, reason=None, session=NEW_SESSION): 

1336 """ 

1337 Triggers the appropriate callback depending on the value of success, namely the 

1338 on_failure_callback or on_success_callback. This method gets the context of a 

1339 single TaskInstance that is part of this DagRun and passes that to the callable along 

1340 with a 'reason', primarily to differentiate DagRun failures. 

1341 

1342 .. note: The logs end up in 

1343 ``$AIRFLOW_HOME/logs/scheduler/latest/PROJECT/DAG_FILE.py.log`` 

1344 

1345 :param dagrun: DagRun object 

1346 :param success: Flag to specify if failure or success callback should be called 

1347 :param reason: Completion reason 

1348 :param session: Database session 

1349 """ 

1350 callbacks = self.on_success_callback if success else self.on_failure_callback 

1351 if callbacks: 

1352 callbacks = callbacks if isinstance(callbacks, list) else [callbacks] 

1353 tis = dagrun.get_task_instances(session=session) 

1354 ti = tis[-1] # get the last TaskInstance of the DagRun 

1355 ti.task = self.get_task(ti.task_id) 

1356 context = ti.get_template_context(session=session) 

1357 context.update({"reason": reason}) 

1358 for callback in callbacks: 

1359 self.log.info("Executing dag callback function: %s", callback) 

1360 try: 

1361 callback(context) 

1362 except Exception: 

1363 self.log.exception("failed to invoke dag state update callback") 

1364 Stats.incr("dag.callback_exceptions", tags={"dag_id": dagrun.dag_id}) 
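# Editorial sketch of wiring a callback consumed by handle_callback above
# (the function and dag names are illustrative):
#
#   def notify_failure(context):
#       print(f"DAG run failed: {context['reason']}")
#
#   dag = DAG("example_dag", schedule=None, on_failure_callback=notify_failure)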

1365 

1366 def get_active_runs(self): 

1367 """ 

1368 Returns a list of execution dates for the currently running dag runs. 

1369 

1370 :return: List of execution dates 

1371 """ 

1372 runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING) 

1373 

1374 active_dates = [] 

1375 for run in runs: 

1376 active_dates.append(run.execution_date) 

1377 

1378 return active_dates 

1379 

1380 @provide_session 

1381 def get_num_active_runs(self, external_trigger=None, only_running=True, session=NEW_SESSION): 

1382 """ 

1383 Returns the number of active "running" dag runs. 

1384 

1385 :param external_trigger: True for externally triggered active dag runs 

1386 :param session: 

1387 :return: the number of active dag runs 

1388 """ 

1389 # .count() is inefficient 

1390 query = session.query(func.count()).filter(DagRun.dag_id == self.dag_id) 

1391 if only_running: 

1392 query = query.filter(DagRun.state == State.RUNNING) 

1393 else: 

1394 query = query.filter(DagRun.state.in_({State.RUNNING, State.QUEUED})) 

1395 

1396 if external_trigger is not None: 

1397 query = query.filter( 

1398 DagRun.external_trigger == (expression.true() if external_trigger else expression.false()) 

1399 ) 

1400 

1401 return query.scalar() 

1402 

1403 @provide_session 

1404 def get_dagrun( 

1405 self, 

1406 execution_date: datetime | None = None, 

1407 run_id: str | None = None, 

1408 session: Session = NEW_SESSION, 

1409 ): 

1410 """ 

1411 Returns the dag run for a given execution date or run_id if it exists, otherwise 

1412 None. 

1413 

1414 :param execution_date: The execution date of the DagRun to find. 

1415 :param run_id: The run_id of the DagRun to find. 

1416 :param session: 

1417 :return: The DagRun if found, otherwise None. 

1418 """ 

1419 if not (execution_date or run_id): 

1420 raise TypeError("You must provide either the execution_date or the run_id") 

1421 query = session.query(DagRun) 

1422 if execution_date: 

1423 query = query.filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == execution_date) 

1424 if run_id: 

1425 query = query.filter(DagRun.dag_id == self.dag_id, DagRun.run_id == run_id) 

1426 return query.first() 

1427 
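A small sketch of get_dagrun(), assuming an existing `dag` object and an invented run_id; exactly one of execution_date or run_id must be supplied, otherwise the TypeError above is raised.

# `dag` and the run_id below are assumptions for this sketch.
run = dag.get_dagrun(run_id="scheduled__2023-01-01T00:00:00+00:00")
if run is None:
    print("no such run")
else:
    print(run.state, run.execution_date)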

1428 @provide_session 

1429 def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION): 

1430 """ 

1431 Returns the list of dag runs between start_date (inclusive) and end_date (inclusive). 

1432 

1433 :param start_date: The starting execution date of the DagRun to find. 

1434 :param end_date: The ending execution date of the DagRun to find. 

1435 :param session: 

1436 :return: The list of DagRuns found. 

1437 """ 

1438 dagruns = ( 

1439 session.query(DagRun) 

1440 .filter( 

1441 DagRun.dag_id == self.dag_id, 

1442 DagRun.execution_date >= start_date, 

1443 DagRun.execution_date <= end_date, 

1444 ) 

1445 .all() 

1446 ) 

1447 

1448 return dagruns 

1449 

1450 @provide_session 

1451 def get_latest_execution_date(self, session: Session = NEW_SESSION) -> pendulum.DateTime | None: 

1452 """Returns the latest date for which at least one dag run exists.""" 

1453 return session.query(func.max(DagRun.execution_date)).filter(DagRun.dag_id == self.dag_id).scalar() 

1454 

1455 @property 

1456 def latest_execution_date(self): 

1457 """This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.""" 

1458 warnings.warn( 

1459 "This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.", 

1460 RemovedInAirflow3Warning, 

1461 stacklevel=2, 

1462 ) 

1463 return self.get_latest_execution_date() 

1464 

1465 @property 

1466 def subdags(self): 

1467 """Returns a list of the subdag objects associated to this DAG.""" 

1468 # Detect SubDagOperator tasks by class or by name, not by isinstance alone 

1469 from airflow.operators.subdag import SubDagOperator 

1470 

1471 subdag_lst = [] 

1472 for task in self.tasks: 

1473 if ( 

1474 isinstance(task, SubDagOperator) 

1475 or 

1476 # TODO remove in Airflow 2.0 

1477 type(task).__name__ == "SubDagOperator" 

1478 or task.task_type == "SubDagOperator" 

1479 ): 

1480 subdag_lst.append(task.subdag) 

1481 subdag_lst += task.subdag.subdags 

1482 return subdag_lst 

1483 

1484 def resolve_template_files(self): 

1485 for t in self.tasks: 

1486 t.resolve_template_files() 

1487 

1488 def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment: 

1489 """Build a Jinja2 environment.""" 

1490 # Collect directories to search for template files 

1491 searchpath = [self.folder] 

1492 if self.template_searchpath: 

1493 searchpath += self.template_searchpath 

1494 

1495 # Default values (for backward compatibility) 

1496 jinja_env_options = { 

1497 "loader": jinja2.FileSystemLoader(searchpath), 

1498 "undefined": self.template_undefined, 

1499 "extensions": ["jinja2.ext.do"], 

1500 "cache_size": 0, 

1501 } 

1502 if self.jinja_environment_kwargs: 

1503 jinja_env_options.update(self.jinja_environment_kwargs) 

1504 env: jinja2.Environment 

1505 if self.render_template_as_native_obj and not force_sandboxed: 

1506 env = airflow.templates.NativeEnvironment(**jinja_env_options) 

1507 else: 

1508 env = airflow.templates.SandboxedEnvironment(**jinja_env_options) 

1509 

1510 # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. 

1511 # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals 

1512 if self.user_defined_macros: 

1513 env.globals.update(self.user_defined_macros) 

1514 if self.user_defined_filters: 

1515 env.filters.update(self.user_defined_filters) 

1516 

1517 return env 

1518 
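A hedged sketch of how the environment built by get_template_env() exposes user_defined_macros as globals and user_defined_filters as Jinja filters; the dag_id, macro and filter here are invented.

from datetime import datetime

from airflow.models.dag import DAG

demo_dag = DAG(
    dag_id="template_env_demo",                             # hypothetical
    start_date=datetime(2023, 1, 1),
    user_defined_macros={"project": "analytics"},           # becomes a template global
    user_defined_filters={"shout": lambda s: s.upper()},    # becomes a Jinja filter
)

env = demo_dag.get_template_env()
print(env.from_string("{{ project | shout }}").render())    # prints "ANALYTICS"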

1519 def set_dependency(self, upstream_task_id, downstream_task_id): 

1520 """ 

1521 Simple utility method to set dependency between two tasks that 

1522 already have been added to the DAG using add_task(). 

1523 """ 

1524 self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id)) 

1525 
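For illustration, assuming tasks with ids "extract" and "load" have already been added to `dag`, set_dependency() is just a by-id shortcut for the usual bitshift wiring.

dag.set_dependency("extract", "load")
# equivalent to: dag.get_task("extract") >> dag.get_task("load")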

1526 @provide_session 

1527 def get_task_instances_before( 

1528 self, 

1529 base_date: datetime, 

1530 num: int, 

1531 *, 

1532 session: Session = NEW_SESSION, 

1533 ) -> list[TaskInstance]: 

1534 """Get ``num`` task instances before (including) ``base_date``. 

1535 

1536 The returned list may contain exactly ``num`` task instances 

1537 corresponding to any DagRunType. It can have fewer if there are 

1538 fewer than ``num`` scheduled DAG runs before ``base_date``. 

1539 """ 

1540 execution_dates: list[Any] = ( 

1541 session.query(DagRun.execution_date) 

1542 .filter( 

1543 DagRun.dag_id == self.dag_id, 

1544 DagRun.execution_date <= base_date, 

1545 ) 

1546 .order_by(DagRun.execution_date.desc()) 

1547 .limit(num) 

1548 .all() 

1549 ) 

1550 

1551 if len(execution_dates) == 0: 

1552 return self.get_task_instances(start_date=base_date, end_date=base_date, session=session) 

1553 

1554 min_date: datetime | None = execution_dates[-1]._mapping.get( 

1555 "execution_date" 

1556 ) # getting the last value from the list 

1557 

1558 return self.get_task_instances(start_date=min_date, end_date=base_date, session=session) 

1559 

1560 @provide_session 

1561 def get_task_instances( 

1562 self, 

1563 start_date: datetime | None = None, 

1564 end_date: datetime | None = None, 

1565 state: list[TaskInstanceState] | None = None, 

1566 session: Session = NEW_SESSION, 

1567 ) -> list[TaskInstance]: 

1568 if not start_date: 

1569 start_date = (timezone.utcnow() - timedelta(30)).replace( 

1570 hour=0, minute=0, second=0, microsecond=0 

1571 ) 

1572 

1573 query = self._get_task_instances( 

1574 task_ids=None, 

1575 start_date=start_date, 

1576 end_date=end_date, 

1577 run_id=None, 

1578 state=state or (), 

1579 include_subdags=False, 

1580 include_parentdag=False, 

1581 include_dependent_dags=False, 

1582 exclude_task_ids=(), 

1583 session=session, 

1584 ) 

1585 return cast(Query, query).order_by(DagRun.execution_date).all() 

1586 

1587 @overload 

1588 def _get_task_instances( 

1589 self, 

1590 *, 

1591 task_ids: Collection[str | tuple[str, int]] | None, 

1592 start_date: datetime | None, 

1593 end_date: datetime | None, 

1594 run_id: str | None, 

1595 state: TaskInstanceState | Sequence[TaskInstanceState], 

1596 include_subdags: bool, 

1597 include_parentdag: bool, 

1598 include_dependent_dags: bool, 

1599 exclude_task_ids: Collection[str | tuple[str, int]] | None, 

1600 session: Session, 

1601 dag_bag: DagBag | None = ..., 

1602 ) -> Iterable[TaskInstance]: 

1603 ... # pragma: no cover 

1604 

1605 @overload 

1606 def _get_task_instances( 

1607 self, 

1608 *, 

1609 task_ids: Collection[str | tuple[str, int]] | None, 

1610 as_pk_tuple: Literal[True], 

1611 start_date: datetime | None, 

1612 end_date: datetime | None, 

1613 run_id: str | None, 

1614 state: TaskInstanceState | Sequence[TaskInstanceState], 

1615 include_subdags: bool, 

1616 include_parentdag: bool, 

1617 include_dependent_dags: bool, 

1618 exclude_task_ids: Collection[str | tuple[str, int]] | None, 

1619 session: Session, 

1620 dag_bag: DagBag | None = ..., 

1621 recursion_depth: int = ..., 

1622 max_recursion_depth: int = ..., 

1623 visited_external_tis: set[TaskInstanceKey] = ..., 

1624 ) -> set[TaskInstanceKey]: 

1625 ... # pragma: no cover 

1626 

1627 def _get_task_instances( 

1628 self, 

1629 *, 

1630 task_ids: Collection[str | tuple[str, int]] | None, 

1631 as_pk_tuple: Literal[True, None] = None, 

1632 start_date: datetime | None, 

1633 end_date: datetime | None, 

1634 run_id: str | None, 

1635 state: TaskInstanceState | Sequence[TaskInstanceState], 

1636 include_subdags: bool, 

1637 include_parentdag: bool, 

1638 include_dependent_dags: bool, 

1639 exclude_task_ids: Collection[str | tuple[str, int]] | None, 

1640 session: Session, 

1641 dag_bag: DagBag | None = None, 

1642 recursion_depth: int = 0, 

1643 max_recursion_depth: int | None = None, 

1644 visited_external_tis: set[TaskInstanceKey] | None = None, 

1645 ) -> Iterable[TaskInstance] | set[TaskInstanceKey]: 

1646 TI = TaskInstance 

1647 

1648 # If we are looking at subdags/dependent dags we want to avoid UNION calls 

1649 # in SQL (it doesn't play nicely with fields that have no equality operator, 

1650 # like JSON types), so we instead build our result set separately. 

1651 # 

1652 # This will be empty if we are only looking at one dag, in which case 

1653 # we can return the filtered TI query object directly. 

1654 result: set[TaskInstanceKey] = set() 

1655 

1656 # Do we want full objects, or just the primary columns? 

1657 if as_pk_tuple: 

1658 tis = session.query(TI.dag_id, TI.task_id, TI.run_id, TI.map_index) 

1659 else: 

1660 tis = session.query(TaskInstance) 

1661 tis = tis.join(TaskInstance.dag_run) 

1662 

1663 if include_subdags: 

1664 # Crafting the right filter for dag_id and task_ids combo 

1665 conditions = [] 

1666 for dag in self.subdags + [self]: 

1667 conditions.append( 

1668 (TaskInstance.dag_id == dag.dag_id) & TaskInstance.task_id.in_(dag.task_ids) 

1669 ) 

1670 tis = tis.filter(or_(*conditions)) 

1671 elif self.partial: 

1672 tis = tis.filter(TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids)) 

1673 else: 

1674 tis = tis.filter(TaskInstance.dag_id == self.dag_id) 

1675 if run_id: 

1676 tis = tis.filter(TaskInstance.run_id == run_id) 

1677 if start_date: 

1678 tis = tis.filter(DagRun.execution_date >= start_date) 

1679 if task_ids is not None: 

1680 tis = tis.filter(TaskInstance.ti_selector_condition(task_ids)) 

1681 

1682 # This allows the allow_trigger_in_future config to take effect, rather than mandating exec_date <= UTC 

1683 if end_date or not self.allow_future_exec_dates: 

1684 end_date = end_date or timezone.utcnow() 

1685 tis = tis.filter(DagRun.execution_date <= end_date) 

1686 

1687 if state: 

1688 if isinstance(state, (str, TaskInstanceState)): 

1689 tis = tis.filter(TaskInstance.state == state) 

1690 elif len(state) == 1: 

1691 tis = tis.filter(TaskInstance.state == state[0]) 

1692 else: 

1693 # this is required to deal with NULL values 

1694 if None in state: 

1695 if all(x is None for x in state): 

1696 tis = tis.filter(TaskInstance.state.is_(None)) 

1697 else: 

1698 not_none_state = [s for s in state if s] 

1699 tis = tis.filter( 

1700 or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None)) 

1701 ) 

1702 else: 

1703 tis = tis.filter(TaskInstance.state.in_(state)) 

1704 

1705 # Next, get any of them from our parent DAG (if there is one) 

1706 if include_parentdag and self.parent_dag is not None: 

1707 

1708 if visited_external_tis is None: 

1709 visited_external_tis = set() 

1710 

1711 p_dag = self.parent_dag.partial_subset( 

1712 task_ids_or_regex=r"^{}$".format(self.dag_id.split(".")[1]), 

1713 include_upstream=False, 

1714 include_downstream=True, 

1715 ) 

1716 result.update( 

1717 p_dag._get_task_instances( 

1718 task_ids=task_ids, 

1719 start_date=start_date, 

1720 end_date=end_date, 

1721 run_id=None, 

1722 state=state, 

1723 include_subdags=include_subdags, 

1724 include_parentdag=False, 

1725 include_dependent_dags=include_dependent_dags, 

1726 as_pk_tuple=True, 

1727 exclude_task_ids=exclude_task_ids, 

1728 session=session, 

1729 dag_bag=dag_bag, 

1730 recursion_depth=recursion_depth, 

1731 max_recursion_depth=max_recursion_depth, 

1732 visited_external_tis=visited_external_tis, 

1733 ) 

1734 ) 

1735 

1736 if include_dependent_dags: 

1737 # Recursively find external tasks indicated by ExternalTaskMarker 

1738 from airflow.sensors.external_task import ExternalTaskMarker 

1739 

1740 query = tis 

1741 if as_pk_tuple: 

1742 condition = TI.filter_for_tis(TaskInstanceKey(*cols) for cols in tis.all()) 

1743 if condition is not None: 

1744 query = session.query(TI).filter(condition) 

1745 

1746 if visited_external_tis is None: 

1747 visited_external_tis = set() 

1748 

1749 for ti in query.filter(TI.operator == ExternalTaskMarker.__name__): 

1750 ti_key = ti.key.primary 

1751 if ti_key in visited_external_tis: 

1752 continue 

1753 

1754 visited_external_tis.add(ti_key) 

1755 

1756 task: ExternalTaskMarker = cast(ExternalTaskMarker, copy.copy(self.get_task(ti.task_id))) 

1757 ti.task = task 

1758 

1759 if max_recursion_depth is None: 

1760 # Maximum recursion depth allowed is the recursion_depth of the first 

1761 # ExternalTaskMarker in the tasks to be visited. 

1762 max_recursion_depth = task.recursion_depth 

1763 

1764 if recursion_depth + 1 > max_recursion_depth: 

1765 # Prevent cycles or accidents. 

1766 raise AirflowException( 

1767 f"Maximum recursion depth {max_recursion_depth} reached for " 

1768 f"{ExternalTaskMarker.__name__} {ti.task_id}. " 

1769 f"Attempted to clear too many tasks or there may be a cyclic dependency." 

1770 ) 

1771 ti.render_templates() 

1772 external_tis = ( 

1773 session.query(TI) 

1774 .join(TI.dag_run) 

1775 .filter( 

1776 TI.dag_id == task.external_dag_id, 

1777 TI.task_id == task.external_task_id, 

1778 DagRun.execution_date == pendulum.parse(task.execution_date), 

1779 ) 

1780 ) 

1781 

1782 for tii in external_tis: 

1783 if not dag_bag: 

1784 from airflow.models.dagbag import DagBag 

1785 

1786 dag_bag = DagBag(read_dags_from_db=True) 

1787 external_dag = dag_bag.get_dag(tii.dag_id, session=session) 

1788 if not external_dag: 

1789 raise AirflowException(f"Could not find dag {tii.dag_id}") 

1790 downstream = external_dag.partial_subset( 

1791 task_ids_or_regex=[tii.task_id], 

1792 include_upstream=False, 

1793 include_downstream=True, 

1794 ) 

1795 result.update( 

1796 downstream._get_task_instances( 

1797 task_ids=None, 

1798 run_id=tii.run_id, 

1799 start_date=None, 

1800 end_date=None, 

1801 state=state, 

1802 include_subdags=include_subdags, 

1803 include_dependent_dags=include_dependent_dags, 

1804 include_parentdag=False, 

1805 as_pk_tuple=True, 

1806 exclude_task_ids=exclude_task_ids, 

1807 dag_bag=dag_bag, 

1808 session=session, 

1809 recursion_depth=recursion_depth + 1, 

1810 max_recursion_depth=max_recursion_depth, 

1811 visited_external_tis=visited_external_tis, 

1812 ) 

1813 ) 

1814 

1815 if result or as_pk_tuple: 

1816 # Only execute the `ti` query if we have also collected some other results (i.e. subdags etc.) 

1817 if as_pk_tuple: 

1818 result.update(TaskInstanceKey(**cols._mapping) for cols in tis.all()) 

1819 else: 

1820 result.update(ti.key for ti in tis) 

1821 

1822 if exclude_task_ids is not None: 

1823 result = { 

1824 task 

1825 for task in result 

1826 if task.task_id not in exclude_task_ids 

1827 and (task.task_id, task.map_index) not in exclude_task_ids 

1828 } 

1829 

1830 if as_pk_tuple: 

1831 return result 

1832 if result: 

1833 # We've been asked for objects, so let's combine it all back into a result set 

1834 ti_filters = TI.filter_for_tis(result) 

1835 if ti_filters is not None: 

1836 tis = session.query(TI).filter(ti_filters) 

1837 elif exclude_task_ids is None: 

1838 pass # Disable filter if not set. 

1839 elif isinstance(next(iter(exclude_task_ids), None), str): 

1840 tis = tis.filter(TI.task_id.notin_(exclude_task_ids)) 

1841 else: 

1842 tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids))) 

1843 

1844 return tis 

1845 

1846 @provide_session 

1847 def set_task_instance_state( 

1848 self, 

1849 *, 

1850 task_id: str, 

1851 map_indexes: Collection[int] | None = None, 

1852 execution_date: datetime | None = None, 

1853 run_id: str | None = None, 

1854 state: TaskInstanceState, 

1855 upstream: bool = False, 

1856 downstream: bool = False, 

1857 future: bool = False, 

1858 past: bool = False, 

1859 commit: bool = True, 

1860 session=NEW_SESSION, 

1861 ) -> list[TaskInstance]: 

1862 """ 

1863 Set the state of a TaskInstance to the given state, and clear its downstream tasks that are 

1864 in failed or upstream_failed state. 

1865 

1866 :param task_id: Task ID of the TaskInstance 

1867 :param map_indexes: Only set TaskInstance if its map_index matches. 

1868 If None (default), all mapped TaskInstances of the task are set. 

1869 :param execution_date: Execution date of the TaskInstance 

1870 :param run_id: The run_id of the TaskInstance 

1871 :param state: State to set the TaskInstance to 

1872 :param upstream: Include all upstream tasks of the given task_id 

1873 :param downstream: Include all downstream tasks of the given task_id 

1874 :param future: Include all future TaskInstances of the given task_id 

1875 :param commit: Commit changes 

1876 :param past: Include all past TaskInstances of the given task_id 

1877 """ 

1878 from airflow.api.common.mark_tasks import set_state 

1879 

1880 if not exactly_one(execution_date, run_id): 

1881 raise ValueError("Exactly one of execution_date or run_id must be provided") 

1882 

1883 task = self.get_task(task_id) 

1884 task.dag = self 

1885 

1886 tasks_to_set_state: list[Operator | tuple[Operator, int]] 

1887 if map_indexes is None: 

1888 tasks_to_set_state = [task] 

1889 else: 

1890 tasks_to_set_state = [(task, map_index) for map_index in map_indexes] 

1891 

1892 altered = set_state( 

1893 tasks=tasks_to_set_state, 

1894 execution_date=execution_date, 

1895 run_id=run_id, 

1896 upstream=upstream, 

1897 downstream=downstream, 

1898 future=future, 

1899 past=past, 

1900 state=state, 

1901 commit=commit, 

1902 session=session, 

1903 ) 

1904 

1905 if not commit: 

1906 return altered 

1907 

1908 # Clear downstream tasks that are in failed/upstream_failed state to resume them. 

1909 # Flush the session so that the tasks marked success are reflected in the db. 

1910 session.flush() 

1911 subdag = self.partial_subset( 

1912 task_ids_or_regex={task_id}, 

1913 include_downstream=True, 

1914 include_upstream=False, 

1915 ) 

1916 

1917 if execution_date is None: 

1918 dag_run = ( 

1919 session.query(DagRun).filter(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id).one() 

1920 ) # Raises an error if not found 

1921 resolve_execution_date = dag_run.execution_date 

1922 else: 

1923 resolve_execution_date = execution_date 

1924 

1925 end_date = resolve_execution_date if not future else None 

1926 start_date = resolve_execution_date if not past else None 

1927 

1928 subdag.clear( 

1929 start_date=start_date, 

1930 end_date=end_date, 

1931 include_subdags=True, 

1932 include_parentdag=True, 

1933 only_failed=True, 

1934 session=session, 

1935 # Exclude the task itself from being cleared 

1936 exclude_task_ids=frozenset({task_id}), 

1937 ) 

1938 

1939 return altered 

1940 
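A sketch of set_task_instance_state(), assuming an existing run_id and a task id "load": mark one task instance success and let the method clear its failed/upstream_failed downstream tasks, as documented above.

from airflow.utils.state import TaskInstanceState

altered = dag.set_task_instance_state(
    task_id="load",                                 # hypothetical task id
    run_id="manual__2023-01-01T00:00:00+00:00",     # exactly one of run_id / execution_date
    state=TaskInstanceState.SUCCESS,
    downstream=True,
    commit=True,
)
print(f"{len(altered)} task instance(s) altered")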

1941 @provide_session 

1942 def set_task_group_state( 

1943 self, 

1944 *, 

1945 group_id: str, 

1946 execution_date: datetime | None = None, 

1947 run_id: str | None = None, 

1948 state: TaskInstanceState, 

1949 upstream: bool = False, 

1950 downstream: bool = False, 

1951 future: bool = False, 

1952 past: bool = False, 

1953 commit: bool = True, 

1954 session: Session = NEW_SESSION, 

1955 ) -> list[TaskInstance]: 

1956 """ 

1957 Set the state of the TaskGroup to the given state, and clear its downstream tasks that are 

1958 in failed or upstream_failed state. 

1959 

1960 :param group_id: The group_id of the TaskGroup 

1961 :param execution_date: Execution date of the TaskInstance 

1962 :param run_id: The run_id of the TaskInstance 

1963 :param state: State to set the TaskInstance to 

1964 :param upstream: Include all upstream tasks of the given task_id 

1965 :param downstream: Include all downstream tasks of the given task_id 

1966 :param future: Include all future TaskInstances of the given task_id 

1967 :param commit: Commit changes 

1968 :param past: Include all past TaskInstances of the given task_id 

1969 :param session: new session 

1970 """ 

1971 from airflow.api.common.mark_tasks import set_state 

1972 

1973 if not exactly_one(execution_date, run_id): 

1974 raise ValueError("Exactly one of execution_date or run_id must be provided") 

1975 

1976 tasks_to_set_state: list[BaseOperator | tuple[BaseOperator, int]] = [] 

1977 task_ids: list[str] = [] 

1978 locked_dag_run_ids: list[int] = [] 

1979 

1980 if execution_date is None: 

1981 dag_run = ( 

1982 session.query(DagRun).filter(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id).one() 

1983 ) # Raises an error if not found 

1984 resolve_execution_date = dag_run.execution_date 

1985 else: 

1986 resolve_execution_date = execution_date 

1987 

1988 end_date = resolve_execution_date if not future else None 

1989 start_date = resolve_execution_date if not past else None 

1990 

1991 task_group_dict = self.task_group.get_task_group_dict() 

1992 task_group = task_group_dict.get(group_id) 

1993 if task_group is None: 

1994 raise ValueError(f"TaskGroup {group_id} could not be found") 

1995 tasks_to_set_state = [task for task in task_group.iter_tasks() if isinstance(task, BaseOperator)] 

1996 task_ids = [task.task_id for task in task_group.iter_tasks()] 

1997 dag_runs_query = session.query(DagRun.id).filter(DagRun.dag_id == self.dag_id).with_for_update() 

1998 

1999 if start_date is None and end_date is None: 

2000 dag_runs_query = dag_runs_query.filter(DagRun.execution_date == start_date) 

2001 else: 

2002 if start_date is not None: 

2003 dag_runs_query = dag_runs_query.filter(DagRun.execution_date >= start_date) 

2004 

2005 if end_date is not None: 

2006 dag_runs_query = dag_runs_query.filter(DagRun.execution_date <= end_date) 

2007 

2008 locked_dag_run_ids = dag_runs_query.all() 

2009 

2010 altered = set_state( 

2011 tasks=tasks_to_set_state, 

2012 execution_date=execution_date, 

2013 run_id=run_id, 

2014 upstream=upstream, 

2015 downstream=downstream, 

2016 future=future, 

2017 past=past, 

2018 state=state, 

2019 commit=commit, 

2020 session=session, 

2021 ) 

2022 

2023 if not commit: 

2024 del locked_dag_run_ids 

2025 return altered 

2026 

2027 # Clear downstream tasks that are in failed/upstream_failed state to resume them. 

2028 # Flush the session so that the tasks marked success are reflected in the db. 

2029 session.flush() 

2030 task_subset = self.partial_subset( 

2031 task_ids_or_regex=task_ids, 

2032 include_downstream=True, 

2033 include_upstream=False, 

2034 ) 

2035 

2036 task_subset.clear( 

2037 start_date=start_date, 

2038 end_date=end_date, 

2039 include_subdags=True, 

2040 include_parentdag=True, 

2041 only_failed=True, 

2042 session=session, 

2043 # Exclude the task from the current group from being cleared 

2044 exclude_task_ids=frozenset(task_ids), 

2045 ) 

2046 

2047 del locked_dag_run_ids 

2048 return altered 

2049 

2050 @property 

2051 def roots(self) -> list[Operator]: 

2052 """Return nodes with no parents. These are first to execute and are called roots or root nodes.""" 

2053 return [task for task in self.tasks if not task.upstream_list] 

2054 

2055 @property 

2056 def leaves(self) -> list[Operator]: 

2057 """Return nodes with no children. These are last to execute and are called leaves or leaf nodes.""" 

2058 return [task for task in self.tasks if not task.downstream_list] 

2059 

2060 def topological_sort(self, include_subdag_tasks: bool = False): 

2061 """ 

2062 Sorts tasks in topological order, such that a task comes after any of its 

2063 upstream dependencies. 

2064 

2065 Deprecated in place of ``task_group.topological_sort`` 

2066 """ 

2067 from airflow.utils.task_group import TaskGroup 

2068 

2069 def nested_topo(group): 

2070 for node in group.topological_sort(_include_subdag_tasks=include_subdag_tasks): 

2071 if isinstance(node, TaskGroup): 

2072 yield from nested_topo(node) 

2073 else: 

2074 yield node 

2075 

2076 return tuple(nested_topo(self.task_group)) 

2077 

2078 @provide_session 

2079 def set_dag_runs_state( 

2080 self, 

2081 state: str = State.RUNNING, 

2082 session: Session = NEW_SESSION, 

2083 start_date: datetime | None = None, 

2084 end_date: datetime | None = None, 

2085 dag_ids: list[str] = [], 

2086 ) -> None: 

2087 warnings.warn( 

2088 "This method is deprecated and will be removed in a future version.", 

2089 RemovedInAirflow3Warning, 

2090 stacklevel=3, 

2091 ) 

2092 dag_ids = dag_ids or [self.dag_id] 

2093 query = session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids)) 

2094 if start_date: 

2095 query = query.filter(DagRun.execution_date >= start_date) 

2096 if end_date: 

2097 query = query.filter(DagRun.execution_date <= end_date) 

2098 query.update({DagRun.state: state}, synchronize_session="fetch") 

2099 

2100 @provide_session 

2101 def clear( 

2102 self, 

2103 task_ids: Collection[str | tuple[str, int]] | None = None, 

2104 start_date: datetime | None = None, 

2105 end_date: datetime | None = None, 

2106 only_failed: bool = False, 

2107 only_running: bool = False, 

2108 confirm_prompt: bool = False, 

2109 include_subdags: bool = True, 

2110 include_parentdag: bool = True, 

2111 dag_run_state: DagRunState = DagRunState.QUEUED, 

2112 dry_run: bool = False, 

2113 session: Session = NEW_SESSION, 

2114 get_tis: bool = False, 

2115 recursion_depth: int = 0, 

2116 max_recursion_depth: int | None = None, 

2117 dag_bag: DagBag | None = None, 

2118 exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(), 

2119 ) -> int | Iterable[TaskInstance]: 

2120 """ 

2121 Clears a set of task instances associated with the current dag for 

2122 a specified date range. 

2123 

2124 :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear 

2125 :param start_date: The minimum execution_date to clear 

2126 :param end_date: The maximum execution_date to clear 

2127 :param only_failed: Only clear failed tasks 

2128 :param only_running: Only clear running tasks. 

2129 :param confirm_prompt: Ask for confirmation 

2130 :param include_subdags: Clear tasks in subdags and clear external tasks 

2131 indicated by ExternalTaskMarker 

2132 :param include_parentdag: Clear tasks in the parent dag of the subdag. 

2133 :param dag_run_state: state to set DagRun to. If set to False, dagrun state will not 

2134 be changed. 

2135 :param dry_run: Find the tasks to clear but don't clear them. 

2136 :param session: The sqlalchemy session to use 

2137 :param dag_bag: The DagBag used to find the dag's subdags (Optional) 

2138 :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``) 

2139 tuples that should not be cleared 

2140 """ 

2141 if get_tis: 

2142 warnings.warn( 

2143 "Passing `get_tis` to dag.clear() is deprecated. Use `dry_run` parameter instead.", 

2144 RemovedInAirflow3Warning, 

2145 stacklevel=2, 

2146 ) 

2147 dry_run = True 

2148 

2149 if recursion_depth: 

2150 warnings.warn( 

2151 "Passing `recursion_depth` to dag.clear() is deprecated.", 

2152 RemovedInAirflow3Warning, 

2153 stacklevel=2, 

2154 ) 

2155 if max_recursion_depth: 

2156 warnings.warn( 

2157 "Passing `max_recursion_depth` to dag.clear() is deprecated.", 

2158 RemovedInAirflow3Warning, 

2159 stacklevel=2, 

2160 ) 

2161 

2162 state = [] 

2163 if only_failed: 

2164 state += [State.FAILED, State.UPSTREAM_FAILED] 

2165 if only_running: 

2166 # Yes, having `+=` doesn't make sense, but this was the existing behaviour 

2167 state += [State.RUNNING] 

2168 

2169 tis = self._get_task_instances( 

2170 task_ids=task_ids, 

2171 start_date=start_date, 

2172 end_date=end_date, 

2173 run_id=None, 

2174 state=state, 

2175 include_subdags=include_subdags, 

2176 include_parentdag=include_parentdag, 

2177 include_dependent_dags=include_subdags, # compat, yes this is not a typo 

2178 session=session, 

2179 dag_bag=dag_bag, 

2180 exclude_task_ids=exclude_task_ids, 

2181 ) 

2182 

2183 if dry_run: 

2184 return tis 

2185 

2186 tis = list(tis) 

2187 

2188 count = len(tis) 

2189 do_it = True 

2190 if count == 0: 

2191 return 0 

2192 if confirm_prompt: 

2193 ti_list = "\n".join(str(t) for t in tis) 

2194 question = ( 

2195 "You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]" 

2196 ).format(count=count, ti_list=ti_list) 

2197 do_it = utils.helpers.ask_yesno(question) 

2198 

2199 if do_it: 

2200 clear_task_instances( 

2201 tis, 

2202 session, 

2203 dag=self, 

2204 dag_run_state=dag_run_state, 

2205 ) 

2206 else: 

2207 count = 0 

2208 print("Cancelled, nothing was cleared.") 

2209 

2210 session.flush() 

2211 return count 

2212 
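A short usage sketch for clear(); the dry_run flag returns the matching task instances without modifying them (mirroring the early return above), and the dates are placeholders.

from datetime import datetime

# Preview what would be cleared for one logical date.
tis = dag.clear(
    start_date=datetime(2023, 1, 1),
    end_date=datetime(2023, 1, 1),
    dry_run=True,
)
print(f"{len(list(tis))} task instance(s) would be cleared")

# Actually clear them and re-queue the affected dag runs.
dag.clear(start_date=datetime(2023, 1, 1), end_date=datetime(2023, 1, 1))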

2213 @classmethod 

2214 def clear_dags( 

2215 cls, 

2216 dags, 

2217 start_date=None, 

2218 end_date=None, 

2219 only_failed=False, 

2220 only_running=False, 

2221 confirm_prompt=False, 

2222 include_subdags=True, 

2223 include_parentdag=False, 

2224 dag_run_state=DagRunState.QUEUED, 

2225 dry_run=False, 

2226 ): 

2227 all_tis = [] 

2228 for dag in dags: 

2229 tis = dag.clear( 

2230 start_date=start_date, 

2231 end_date=end_date, 

2232 only_failed=only_failed, 

2233 only_running=only_running, 

2234 confirm_prompt=False, 

2235 include_subdags=include_subdags, 

2236 include_parentdag=include_parentdag, 

2237 dag_run_state=dag_run_state, 

2238 dry_run=True, 

2239 ) 

2240 all_tis.extend(tis) 

2241 

2242 if dry_run: 

2243 return all_tis 

2244 

2245 count = len(all_tis) 

2246 do_it = True 

2247 if count == 0: 

2248 print("Nothing to clear.") 

2249 return 0 

2250 if confirm_prompt: 

2251 ti_list = "\n".join(str(t) for t in all_tis) 

2252 question = f"You are about to delete these {count} tasks:\n{ti_list}\n\nAre you sure? [y/n]" 

2253 do_it = utils.helpers.ask_yesno(question) 

2254 

2255 if do_it: 

2256 for dag in dags: 

2257 dag.clear( 

2258 start_date=start_date, 

2259 end_date=end_date, 

2260 only_failed=only_failed, 

2261 only_running=only_running, 

2262 confirm_prompt=False, 

2263 include_subdags=include_subdags, 

2264 dag_run_state=dag_run_state, 

2265 dry_run=False, 

2266 ) 

2267 else: 

2268 count = 0 

2269 print("Cancelled, nothing was cleared.") 

2270 return count 

2271 

2272 def __deepcopy__(self, memo): 

2273 # Switcharoo to go around deepcopying objects coming through the 

2274 # backdoor 

2275 cls = self.__class__ 

2276 result = cls.__new__(cls) 

2277 memo[id(self)] = result 

2278 for k, v in self.__dict__.items(): 

2279 if k not in ("user_defined_macros", "user_defined_filters", "_log"): 

2280 setattr(result, k, copy.deepcopy(v, memo)) 

2281 

2282 result.user_defined_macros = self.user_defined_macros 

2283 result.user_defined_filters = self.user_defined_filters 

2284 if hasattr(self, "_log"): 

2285 result._log = self._log 

2286 return result 

2287 

2288 def sub_dag(self, *args, **kwargs): 

2289 """This method is deprecated in favor of partial_subset.""" 

2290 warnings.warn( 

2291 "This method is deprecated and will be removed in a future version. Please use partial_subset", 

2292 RemovedInAirflow3Warning, 

2293 stacklevel=2, 

2294 ) 

2295 return self.partial_subset(*args, **kwargs) 

2296 

2297 def partial_subset( 

2298 self, 

2299 task_ids_or_regex: str | re.Pattern | Iterable[str], 

2300 include_downstream=False, 

2301 include_upstream=True, 

2302 include_direct_upstream=False, 

2303 ): 

2304 """ 

2305 Returns a subset of the current dag as a deep copy of the current dag 

2306 based on a regex that should match one or many tasks, and includes 

2307 upstream and downstream neighbours based on the flag passed. 

2308 

2309 :param task_ids_or_regex: Either a list of task_ids, or a regex to 

2310 match against task ids (as a string, or compiled regex pattern). 

2311 :param include_downstream: Include all downstream tasks of matched 

2312 tasks, in addition to matched tasks. 

2313 :param include_upstream: Include all upstream tasks of matched tasks, 

2314 in addition to matched tasks. 

2315 :param include_direct_upstream: Include all tasks directly upstream of matched 

2316 and downstream (if include_downstream = True) tasks 

2317 """ 

2318 from airflow.models.baseoperator import BaseOperator 

2319 from airflow.models.mappedoperator import MappedOperator 

2320 

2321 # deep-copying self.task_dict and self._task_group takes a long time, and we don't want all 

2322 # the tasks anyway, so we copy the tasks manually later 

2323 memo = {id(self.task_dict): None, id(self._task_group): None} 

2324 dag = copy.deepcopy(self, memo) # type: ignore 

2325 

2326 if isinstance(task_ids_or_regex, (str, re.Pattern)): 

2327 matched_tasks = [t for t in self.tasks if re.findall(task_ids_or_regex, t.task_id)] 

2328 else: 

2329 matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex] 

2330 

2331 also_include: list[Operator] = [] 

2332 for t in matched_tasks: 

2333 if include_downstream: 

2334 also_include.extend(t.get_flat_relatives(upstream=False)) 

2335 if include_upstream: 

2336 also_include.extend(t.get_flat_relatives(upstream=True)) 

2337 

2338 direct_upstreams: list[Operator] = [] 

2339 if include_direct_upstream: 

2340 for t in itertools.chain(matched_tasks, also_include): 

2341 upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator))) 

2342 direct_upstreams.extend(upstream) 

2343 

2344 # Compiling the unique list of tasks that made the cut 

2345 # Make sure to not recursively deepcopy the dag or task_group while copying the task. 

2346 # task_group is reset later 

2347 def _deepcopy_task(t) -> Operator: 

2348 memo.setdefault(id(t.task_group), None) 

2349 return copy.deepcopy(t, memo) 

2350 

2351 dag.task_dict = { 

2352 t.task_id: _deepcopy_task(t) 

2353 for t in itertools.chain(matched_tasks, also_include, direct_upstreams) 

2354 } 

2355 

2356 def filter_task_group(group, parent_group): 

2357 """Exclude tasks not included in the subdag from the given TaskGroup.""" 

2358 # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy 

2359 # and then manually deep copy the instances. (memo argument to deepcopy only works for instances 

2360 # of classes, not "native" properties of an instance) 

2361 copied = copy.copy(group) 

2362 

2363 memo[id(group.children)] = {} 

2364 if parent_group: 

2365 memo[id(group.parent_group)] = parent_group 

2366 for attr, value in copied.__dict__.items(): 

2367 if id(value) in memo: 

2368 value = memo[id(value)] 

2369 else: 

2370 value = copy.deepcopy(value, memo) 

2371 copied.__dict__[attr] = value 

2372 

2373 proxy = weakref.proxy(copied) 

2374 

2375 for child in group.children.values(): 

2376 if isinstance(child, AbstractOperator): 

2377 if child.task_id in dag.task_dict: 

2378 task = copied.children[child.task_id] = dag.task_dict[child.task_id] 

2379 task.task_group = proxy 

2380 else: 

2381 copied.used_group_ids.discard(child.task_id) 

2382 else: 

2383 filtered_child = filter_task_group(child, proxy) 

2384 

2385 # Only include this child TaskGroup if it is non-empty. 

2386 if filtered_child.children: 

2387 copied.children[child.group_id] = filtered_child 

2388 

2389 return copied 

2390 

2391 dag._task_group = filter_task_group(self.task_group, None) 

2392 

2393 # Removing upstream/downstream references to tasks and TaskGroups that did not make 

2394 # the cut. 

2395 subdag_task_groups = dag.task_group.get_task_group_dict() 

2396 for group in subdag_task_groups.values(): 

2397 group.upstream_group_ids.intersection_update(subdag_task_groups) 

2398 group.downstream_group_ids.intersection_update(subdag_task_groups) 

2399 group.upstream_task_ids.intersection_update(dag.task_dict) 

2400 group.downstream_task_ids.intersection_update(dag.task_dict) 

2401 

2402 for t in dag.tasks: 

2403 # Removing upstream/downstream references to tasks that did not 

2404 # make the cut 

2405 t.upstream_task_ids.intersection_update(dag.task_dict) 

2406 t.downstream_task_ids.intersection_update(dag.task_dict) 

2407 

2408 if len(dag.tasks) < len(self.tasks): 

2409 dag.partial = True 

2410 

2411 return dag 

2412 
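Illustrative only, assuming a task id "transform" exists: partial_subset() returns a pruned deep copy containing the matched tasks plus whatever the include_* flags pull in.

subset = dag.partial_subset(
    task_ids_or_regex={"transform"},   # literal task ids; a regex string also works
    include_upstream=False,
    include_downstream=True,
)
print(sorted(subset.task_dict))        # only "transform" and its downstream tasks remain
print(subset.partial)                  # True whenever tasks were dropped from the original DAG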

2413 def has_task(self, task_id: str): 

2414 return task_id in self.task_dict 

2415 

2416 def has_task_group(self, task_group_id: str) -> bool: 

2417 return task_group_id in self.task_group_dict 

2418 

2419 @functools.cached_property 

2420 def task_group_dict(self): 

2421 return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None} 

2422 

2423 def get_task(self, task_id: str, include_subdags: bool = False) -> Operator: 

2424 if task_id in self.task_dict: 

2425 return self.task_dict[task_id] 

2426 if include_subdags: 

2427 for dag in self.subdags: 

2428 if task_id in dag.task_dict: 

2429 return dag.task_dict[task_id] 

2430 raise TaskNotFound(f"Task {task_id} not found") 

2431 

2432 def pickle_info(self): 

2433 d = {} 

2434 d["is_picklable"] = True 

2435 try: 

2436 dttm = timezone.utcnow() 

2437 pickled = pickle.dumps(self) 

2438 d["pickle_len"] = len(pickled) 

2439 d["pickling_duration"] = str(timezone.utcnow() - dttm) 

2440 except Exception as e: 

2441 self.log.debug(e) 

2442 d["is_picklable"] = False 

2443 d["stacktrace"] = traceback.format_exc() 

2444 return d 

2445 

2446 @provide_session 

2447 def pickle(self, session=NEW_SESSION) -> DagPickle: 

2448 dag = session.query(DagModel).filter(DagModel.dag_id == self.dag_id).first() 

2449 dp = None 

2450 if dag and dag.pickle_id: 

2451 dp = session.query(DagPickle).filter(DagPickle.id == dag.pickle_id).first() 

2452 if not dp or dp.pickle != self: 

2453 dp = DagPickle(dag=self) 

2454 session.add(dp) 

2455 self.last_pickled = timezone.utcnow() 

2456 session.commit() 

2457 self.pickle_id = dp.id 

2458 

2459 return dp 

2460 

2461 def tree_view(self) -> None: 

2462 """Print an ASCII tree representation of the DAG.""" 

2463 

2464 def get_downstream(task, level=0): 

2465 print((" " * level * 4) + str(task)) 

2466 level += 1 

2467 for t in task.downstream_list: 

2468 get_downstream(t, level) 

2469 

2470 for t in self.roots: 

2471 get_downstream(t) 

2472 

2473 @property 

2474 def task(self) -> TaskDecoratorCollection: 

2475 from airflow.decorators import task 

2476 

2477 return cast("TaskDecoratorCollection", functools.partial(task, dag=self)) 

2478 
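A hedged example of the `task` property above, which simply pre-binds `airflow.decorators.task` to this DAG; the function name and body are invented.

@dag.task(task_id="say_hello")   # same as @task(dag=dag, task_id="say_hello")
def say_hello() -> str:
    return "hello"

say_hello()  # calling the decorated function registers the task on `dag`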

2479 def add_task(self, task: Operator) -> None: 

2480 """ 

2481 Add a task to the DAG. 

2482 

2483 :param task: the task you want to add 

2484 """ 

2485 DagInvalidTriggerRule.check(self, task.trigger_rule) 

2486 

2487 from airflow.utils.task_group import TaskGroupContext 

2488 

2489 if not self.start_date and not task.start_date: 

2490 raise AirflowException("DAG is missing the start_date parameter") 

2491 # if the task has no start date, assign it the same as the DAG 

2492 elif not task.start_date: 

2493 task.start_date = self.start_date 

2494 # otherwise, the task will start on the later of its own start date and 

2495 # the DAG's start date 

2496 elif self.start_date: 

2497 task.start_date = max(task.start_date, self.start_date) 

2498 

2499 # if the task has no end date, assign it the same as the dag 

2500 if not task.end_date: 

2501 task.end_date = self.end_date 

2502 # otherwise, the task will end on the earlier of its own end date and 

2503 # the DAG's end date 

2504 elif task.end_date and self.end_date: 

2505 task.end_date = min(task.end_date, self.end_date) 

2506 

2507 task_id = task.task_id 

2508 if not task.task_group: 

2509 task_group = TaskGroupContext.get_current_task_group(self) 

2510 if task_group: 

2511 task_id = task_group.child_id(task_id) 

2512 task_group.add(task) 

2513 

2514 if ( 

2515 task_id in self.task_dict and self.task_dict[task_id] is not task 

2516 ) or task_id in self._task_group.used_group_ids: 

2517 raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG") 

2518 else: 

2519 self.task_dict[task_id] = task 

2520 task.dag = self 

2521 # Add task_id to used_group_ids to prevent group_id and task_id collisions. 

2522 self._task_group.used_group_ids.add(task_id) 

2523 

2524 self.task_count = len(self.task_dict) 

2525 

2526 def add_tasks(self, tasks: Iterable[Operator]) -> None: 

2527 """ 

2528 Add a list of tasks to the DAG. 

2529 

2530 :param tasks: a list of tasks you want to add 

2531 """ 

2532 for task in tasks: 

2533 self.add_task(task) 

2534 

2535 def _remove_task(self, task_id: str) -> None: 

2536 # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this 

2537 # doesn't guard against that 

2538 task = self.task_dict.pop(task_id) 

2539 tg = getattr(task, "task_group", None) 

2540 if tg: 

2541 tg._remove(task) 

2542 

2543 self.task_count = len(self.task_dict) 

2544 

2545 def run( 

2546 self, 

2547 start_date=None, 

2548 end_date=None, 

2549 mark_success=False, 

2550 local=False, 

2551 executor=None, 

2552 donot_pickle=conf.getboolean("core", "donot_pickle"), 

2553 ignore_task_deps=False, 

2554 ignore_first_depends_on_past=True, 

2555 pool=None, 

2556 delay_on_limit_secs=1.0, 

2557 verbose=False, 

2558 conf=None, 

2559 rerun_failed_tasks=False, 

2560 run_backwards=False, 

2561 run_at_least_once=False, 

2562 continue_on_failures=False, 

2563 disable_retry=False, 

2564 ): 

2565 """ 

2566 Runs the DAG. 

2567 

2568 :param start_date: the start date of the range to run 

2569 :param end_date: the end date of the range to run 

2570 :param mark_success: True to mark jobs as succeeded without running them 

2571 :param local: True to run the tasks using the LocalExecutor 

2572 :param executor: The executor instance to run the tasks 

2573 :param donot_pickle: True to avoid pickling the DAG object and sending it to workers 

2574 :param ignore_task_deps: True to skip upstream tasks 

2575 :param ignore_first_depends_on_past: True to ignore depends_on_past 

2576 dependencies for the first set of tasks only 

2577 :param pool: Resource pool to use 

2578 :param delay_on_limit_secs: Time in seconds to wait before next attempt to run 

2579 dag run when max_active_runs limit has been reached 

2580 :param verbose: Make logging output more verbose 

2581 :param conf: user defined dictionary passed from CLI 

2582 :param rerun_failed_tasks: if True, rerun tasks that previously failed within the range 

2583 :param run_backwards: if True, run the date range in reverse order (latest dates first) 

2584 :param run_at_least_once: If true, always run the DAG at least once even 

2585 if no logical run exists within the time range. 

2586 """ 

2587 from airflow.jobs.backfill_job_runner import BackfillJobRunner 

2588 

2589 if not executor and local: 

2590 from airflow.executors.local_executor import LocalExecutor 

2591 

2592 executor = LocalExecutor() 

2593 elif not executor: 

2594 from airflow.executors.executor_loader import ExecutorLoader 

2595 

2596 executor = ExecutorLoader.get_default_executor() 

2597 from airflow.jobs.job import Job 

2598 

2599 job = Job(executor=executor) 

2600 job_runner = BackfillJobRunner( 

2601 job=job, 

2602 dag=self, 

2603 start_date=start_date, 

2604 end_date=end_date, 

2605 mark_success=mark_success, 

2606 donot_pickle=donot_pickle, 

2607 ignore_task_deps=ignore_task_deps, 

2608 ignore_first_depends_on_past=ignore_first_depends_on_past, 

2609 pool=pool, 

2610 delay_on_limit_secs=delay_on_limit_secs, 

2611 verbose=verbose, 

2612 conf=conf, 

2613 rerun_failed_tasks=rerun_failed_tasks, 

2614 run_backwards=run_backwards, 

2615 run_at_least_once=run_at_least_once, 

2616 continue_on_failures=continue_on_failures, 

2617 disable_retry=disable_retry, 

2618 ) 

2619 run_job(job=job, execute_callable=job_runner._execute) 

2620 
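A sketch of run() with placeholder dates: as shown above it delegates to a BackfillJobRunner, using the LocalExecutor when local=True and the configured default executor otherwise.

from datetime import datetime

dag.run(
    start_date=datetime(2023, 1, 1),
    end_date=datetime(2023, 1, 3),
    local=True,            # back-fill with the LocalExecutor
    mark_success=False,
)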

2621 def cli(self): 

2622 """Exposes a CLI specific to this DAG.""" 

2623 check_cycle(self) 

2624 

2625 from airflow.cli import cli_parser 

2626 

2627 parser = cli_parser.get_parser(dag_parser=True) 

2628 args = parser.parse_args() 

2629 args.func(args, self) 

2630 

2631 @provide_session 

2632 def test( 

2633 self, 

2634 execution_date: datetime | None = None, 

2635 run_conf: dict[str, Any] | None = None, 

2636 conn_file_path: str | None = None, 

2637 variable_file_path: str | None = None, 

2638 session: Session = NEW_SESSION, 

2639 ) -> None: 

2640 """ 

2641 Execute one single DagRun for a given DAG and execution date. 

2642 

2643 :param execution_date: execution date for the DAG run 

2644 :param run_conf: configuration to pass to newly created dagrun 

2645 :param conn_file_path: file path to a connection file in either yaml or json 

2646 :param variable_file_path: file path to a variable file in either yaml or json 

2647 :param session: database connection (optional) 

2648 """ 

2649 

2650 def add_logger_if_needed(ti: TaskInstance): 

2651 """Add a formatted logger to the task instance. 

2652 

2653 This allows all logs to surface to the command line, instead of into 

2654 a task file. Since this is a local test run, it is much better for 

2655 the user to see logs in the command line, rather than needing to 

2656 search for a log file. 

2657 

2658 :param ti: The task instance that will receive a logger. 

2659 """ 

2660 format = logging.Formatter("[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s") 

2661 handler = logging.StreamHandler(sys.stdout) 

2662 handler.level = logging.INFO 

2663 handler.setFormatter(format) 

2664 # only add log handler once 

2665 if not any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers): 

2666 self.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id) 

2667 ti.log.addHandler(handler) 

2668 

2669 if conn_file_path or variable_file_path: 

2670 local_secrets = LocalFilesystemBackend( 

2671 variables_file_path=variable_file_path, connections_file_path=conn_file_path 

2672 ) 

2673 secrets_backend_list.insert(0, local_secrets) 

2674 

2675 execution_date = execution_date or timezone.utcnow() 

2676 self.log.debug("Clearing existing task instances for execution date %s", execution_date) 

2677 self.clear( 

2678 start_date=execution_date, 

2679 end_date=execution_date, 

2680 dag_run_state=False, # type: ignore 

2681 session=session, 

2682 ) 

2683 self.log.debug("Getting dagrun for dag %s", self.dag_id) 

2684 dr: DagRun = _get_or_create_dagrun( 

2685 dag=self, 

2686 start_date=execution_date, 

2687 execution_date=execution_date, 

2688 run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date), 

2689 session=session, 

2690 conf=run_conf, 

2691 ) 

2692 

2693 tasks = self.task_dict 

2694 self.log.debug("starting dagrun") 

2695 # Instead of starting a scheduler, we run the minimal loop possible to check 

2696 # for task readiness and dependency management. This is notably faster 

2697 # than creating a BackfillJob and allows us to surface logs to the user 

2698 while dr.state == State.RUNNING: 

2699 schedulable_tis, _ = dr.update_state(session=session) 

2700 for ti in schedulable_tis: 

2701 add_logger_if_needed(ti) 

2702 ti.task = tasks[ti.task_id] 

2703 _run_task(ti, session=session) 

2704 if conn_file_path or variable_file_path: 

2705 # Remove the local variables we have added to the secrets_backend_list 

2706 secrets_backend_list.pop(0) 

2707 
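A minimal sketch of test(), which executes one manual DagRun in-process and streams task logs to stdout as described above; the conf values and the commented-out file path are invented.

import pendulum

dag.test(
    execution_date=pendulum.datetime(2023, 1, 1, tz="UTC"),
    run_conf={"sample_key": "sample_value"},        # passed through as the dag run conf
    # conn_file_path="/tmp/connections.yaml",       # optional local secrets backend (hypothetical path)
)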

2708 @provide_session 

2709 def create_dagrun( 

2710 self, 

2711 state: DagRunState, 

2712 execution_date: datetime | None = None, 

2713 run_id: str | None = None, 

2714 start_date: datetime | None = None, 

2715 external_trigger: bool | None = False, 

2716 conf: dict | None = None, 

2717 run_type: DagRunType | None = None, 

2718 session: Session = NEW_SESSION, 

2719 dag_hash: str | None = None, 

2720 creating_job_id: int | None = None, 

2721 data_interval: tuple[datetime, datetime] | None = None, 

2722 ): 

2723 """ 

2724 Creates a dag run from this dag including the tasks associated with this dag. 

2725 Returns the dag run. 

2726 

2727 :param run_id: defines the run id for this dag run 

2728 :param run_type: type of DagRun 

2729 :param execution_date: the execution date of this dag run 

2730 :param state: the state of the dag run 

2731 :param start_date: the date this dag run should be evaluated 

2732 :param external_trigger: whether this dag run is externally triggered 

2733 :param conf: Dict containing configuration/parameters to pass to the DAG 

2734 :param creating_job_id: id of the job creating this DagRun 

2735 :param session: database session 

2736 :param dag_hash: Hash of Serialized DAG 

2737 :param data_interval: Data interval of the DagRun 

2738 """ 

2739 logical_date = timezone.coerce_datetime(execution_date) 

2740 

2741 if data_interval and not isinstance(data_interval, DataInterval): 

2742 data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval)) 

2743 

2744 if data_interval is None and logical_date is not None: 

2745 warnings.warn( 

2746 "Calling `DAG.create_dagrun()` without an explicit data interval is deprecated", 

2747 RemovedInAirflow3Warning, 

2748 stacklevel=3, 

2749 ) 

2750 if run_type == DagRunType.MANUAL: 

2751 data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date) 

2752 else: 

2753 data_interval = self.infer_automated_data_interval(logical_date) 

2754 

2755 if run_type is None or isinstance(run_type, DagRunType): 

2756 pass 

2757 elif isinstance(run_type, str): # Compatibility: run_type used to be a str. 

2758 run_type = DagRunType(run_type) 

2759 else: 

2760 raise ValueError(f"`run_type` should be a DagRunType, not {type(run_type)}") 

2761 

2762 if run_id: # Infer run_type from run_id if needed. 

2763 if not isinstance(run_id, str): 

2764 raise ValueError(f"`run_id` should be a str, not {type(run_id)}") 

2765 inferred_run_type = DagRunType.from_run_id(run_id) 

2766 if run_type is None: 

2767 # No explicit type given, use the inferred type. 

2768 run_type = inferred_run_type 

2769 elif run_type == DagRunType.MANUAL and inferred_run_type != DagRunType.MANUAL: 

2770 # Prevent a manual run from using an ID that looks like a scheduled run. 

2771 raise ValueError( 

2772 f"A {run_type.value} DAG run cannot use ID {run_id!r} since it " 

2773 f"is reserved for {inferred_run_type.value} runs" 

2774 ) 

2775 elif run_type and logical_date is not None: # Generate run_id from run_type and execution_date. 

2776 run_id = self.timetable.generate_run_id( 

2777 run_type=run_type, logical_date=logical_date, data_interval=data_interval 

2778 ) 

2779 else: 

2780 raise AirflowException( 

2781 "Creating DagRun needs either `run_id` or both `run_type` and `execution_date`" 

2782 ) 

2783 

2784 if run_id and "/" in run_id: 

2785 warnings.warn( 

2786 "Using forward slash ('/') in a DAG run ID is deprecated. Note that this character " 

2787 "also makes the run impossible to retrieve via Airflow's REST API.", 

2788 RemovedInAirflow3Warning, 

2789 stacklevel=3, 

2790 ) 

2791 

2792 # create a copy of params before validating 

2793 copied_params = copy.deepcopy(self.params) 

2794 copied_params.update(conf or {}) 

2795 copied_params.validate() 

2796 

2797 run = DagRun( 

2798 dag_id=self.dag_id, 

2799 run_id=run_id, 

2800 execution_date=logical_date, 

2801 start_date=start_date, 

2802 external_trigger=external_trigger, 

2803 conf=conf, 

2804 state=state, 

2805 run_type=run_type, 

2806 dag_hash=dag_hash, 

2807 creating_job_id=creating_job_id, 

2808 data_interval=data_interval, 

2809 ) 

2810 session.add(run) 

2811 session.flush() 

2812 

2813 run.dag = self 

2814 

2815 # create the associated task instances 

2816 # state is None at the moment of creation 

2817 run.verify_integrity(session=session) 

2818 

2819 return run 

2820 
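A hedged sketch of create_dagrun(): supplying an explicit data_interval avoids the deprecation warning above, and omitting run_id lets the timetable generate one from run_type and execution_date; all values are placeholders.

import pendulum

from airflow.timetables.base import DataInterval
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

start = pendulum.datetime(2023, 1, 1, tz="UTC")
end = pendulum.datetime(2023, 1, 2, tz="UTC")

run = dag.create_dagrun(
    state=DagRunState.QUEUED,
    execution_date=start,
    run_type=DagRunType.MANUAL,
    data_interval=DataInterval(start, end),
    external_trigger=True,
)
print(run.run_id)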

2821 @classmethod 

2822 @provide_session 

2823 def bulk_sync_to_db( 

2824 cls, 

2825 dags: Collection[DAG], 

2826 session=NEW_SESSION, 

2827 ): 

2828 """This method is deprecated in favor of bulk_write_to_db.""" 

2829 warnings.warn( 

2830 "This method is deprecated and will be removed in a future version. Please use bulk_write_to_db", 

2831 RemovedInAirflow3Warning, 

2832 stacklevel=2, 

2833 ) 

2834 return cls.bulk_write_to_db(dags=dags, session=session) 

2835 

2836 @classmethod 

2837 @provide_session 

2838 def bulk_write_to_db( 

2839 cls, 

2840 dags: Collection[DAG], 

2841 processor_subdir: str | None = None, 

2842 session=NEW_SESSION, 

2843 ): 

2844 """ 

2845 Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB, including 

2846 calculated fields. 

2847 

2848 Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a SubDagOperator. 

2849 

2850 :param dags: the DAG objects to save to the DB 

2851 :return: None 

2852 """ 

2853 if not dags: 

2854 return 

2855 

2856 log.info("Sync %s DAGs", len(dags)) 

2857 dag_by_ids = {dag.dag_id: dag for dag in dags} 

2858 

2859 dag_ids = set(dag_by_ids.keys()) 

2860 query = ( 

2861 session.query(DagModel) 

2862 .options(joinedload(DagModel.tags, innerjoin=False)) 

2863 .filter(DagModel.dag_id.in_(dag_ids)) 

2864 .options(joinedload(DagModel.schedule_dataset_references)) 

2865 .options(joinedload(DagModel.task_outlet_dataset_references)) 

2866 ) 

2867 orm_dags: list[DagModel] = with_row_locks(query, of=DagModel, session=session).all() 

2868 existing_dags = {orm_dag.dag_id: orm_dag for orm_dag in orm_dags} 

2869 missing_dag_ids = dag_ids.difference(existing_dags) 

2870 

2871 for missing_dag_id in missing_dag_ids: 

2872 orm_dag = DagModel(dag_id=missing_dag_id) 

2873 dag = dag_by_ids[missing_dag_id] 

2874 if dag.is_paused_upon_creation is not None: 

2875 orm_dag.is_paused = dag.is_paused_upon_creation 

2876 orm_dag.tags = [] 

2877 log.info("Creating ORM DAG for %s", dag.dag_id) 

2878 session.add(orm_dag) 

2879 orm_dags.append(orm_dag) 

2880 

2881 # Get the latest dag run for each existing dag as a single query (avoid n+1 query) 

2882 most_recent_subq = ( 

2883 session.query(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date")) 

2884 .filter( 

2885 DagRun.dag_id.in_(existing_dags), 

2886 or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED), 

2887 ) 

2888 .group_by(DagRun.dag_id) 

2889 .subquery() 

2890 ) 

2891 most_recent_runs_iter = session.query(DagRun).filter( 

2892 DagRun.dag_id == most_recent_subq.c.dag_id, 

2893 DagRun.execution_date == most_recent_subq.c.max_execution_date, 

2894 ) 

2895 most_recent_runs = {run.dag_id: run for run in most_recent_runs_iter} 

2896 

2897 # Get number of active dagruns for all dags we are processing as a single query. 

2898 

2899 num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session) 

2900 

2901 filelocs = [] 

2902 

2903 for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id): 

2904 dag = dag_by_ids[orm_dag.dag_id] 

2905 filelocs.append(dag.fileloc) 

2906 if dag.is_subdag: 

2907 orm_dag.is_subdag = True 

2908 orm_dag.fileloc = dag.parent_dag.fileloc # type: ignore 

2909 orm_dag.root_dag_id = dag.parent_dag.dag_id # type: ignore 

2910 orm_dag.owners = dag.parent_dag.owner # type: ignore 

2911 else: 

2912 orm_dag.is_subdag = False 

2913 orm_dag.fileloc = dag.fileloc 

2914 orm_dag.owners = dag.owner 

2915 orm_dag.is_active = True 

2916 orm_dag.has_import_errors = False 

2917 orm_dag.last_parsed_time = timezone.utcnow() 

2918 orm_dag.default_view = dag.default_view 

2919 orm_dag.description = dag.description 

2920 orm_dag.max_active_tasks = dag.max_active_tasks 

2921 orm_dag.max_active_runs = dag.max_active_runs 

2922 orm_dag.has_task_concurrency_limits = any( 

2923 t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None 

2924 for t in dag.tasks 

2925 ) 

2926 orm_dag.schedule_interval = dag.schedule_interval 

2927 orm_dag.timetable_description = dag.timetable.description 

2928 orm_dag.processor_subdir = processor_subdir 

2929 

2930 run: DagRun | None = most_recent_runs.get(dag.dag_id) 

2931 if run is None: 

2932 data_interval = None 

2933 else: 

2934 data_interval = dag.get_run_data_interval(run) 

2935 if num_active_runs.get(dag.dag_id, 0) >= orm_dag.max_active_runs: 

2936 orm_dag.next_dagrun_create_after = None 

2937 else: 

2938 orm_dag.calculate_dagrun_date_fields(dag, data_interval) 

2939 

2940 dag_tags = set(dag.tags or {}) 

2941 orm_dag_tags = list(orm_dag.tags or []) 

2942 for orm_tag in orm_dag_tags: 

2943 if orm_tag.name not in dag_tags: 

2944 session.delete(orm_tag) 

2945 orm_dag.tags.remove(orm_tag) 

2946 orm_tag_names = {t.name for t in orm_dag_tags} 

2947 for dag_tag in dag_tags: 

2948 if dag_tag not in orm_tag_names: 

2949 dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id) 

2950 orm_dag.tags.append(dag_tag_orm) 

2951 session.add(dag_tag_orm) 

2952 

2953 orm_dag_links = orm_dag.dag_owner_links or [] 

2954 for orm_dag_link in orm_dag_links: 

2955 if orm_dag_link not in dag.owner_links: 

2956 session.delete(orm_dag_link) 

2957 for owner_name, owner_link in dag.owner_links.items(): 

2958 dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link) 

2959 session.add(dag_owner_orm) 

2960 

2961 DagCode.bulk_sync_to_db(filelocs, session=session) 

2962 

2963 from airflow.datasets import Dataset 

2964 from airflow.models.dataset import ( 

2965 DagScheduleDatasetReference, 

2966 DatasetModel, 

2967 TaskOutletDatasetReference, 

2968 ) 

2969 

2970 dag_references = collections.defaultdict(set) 

2971 outlet_references = collections.defaultdict(set) 

2972 # We can't use a set here as we want to preserve order 

2973 outlet_datasets: dict[Dataset, None] = {} 

2974 input_datasets: dict[Dataset, None] = {} 

2975 

2976 # here we go through dags and tasks to check for dataset references 

2977 # if there are now none and previously there were some, we delete them 

2978 # if there are now *any*, we add them to the above data structures, and 

2979 # later we'll persist them to the database. 

2980 for dag in dags: 

2981 curr_orm_dag = existing_dags.get(dag.dag_id) 

2982 if not dag.dataset_triggers: 

2983 if curr_orm_dag and curr_orm_dag.schedule_dataset_references: 

2984 curr_orm_dag.schedule_dataset_references = [] 

2985 for dataset in dag.dataset_triggers: 

2986 dag_references[dag.dag_id].add(dataset.uri) 

2987 input_datasets[DatasetModel.from_public(dataset)] = None 

2988 curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references 

2989 for task in dag.tasks: 

2990 dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)] 

2991 if not dataset_outlets: 

2992 if curr_outlet_references: 

2993 this_task_outlet_refs = [ 

2994 x 

2995 for x in curr_outlet_references 

2996 if x.dag_id == dag.dag_id and x.task_id == task.task_id 

2997 ] 

2998 for ref in this_task_outlet_refs: 

2999 curr_outlet_references.remove(ref) 

3000 for d in dataset_outlets: 

3001 outlet_references[(task.dag_id, task.task_id)].add(d.uri) 

3002 outlet_datasets[DatasetModel.from_public(d)] = None 

3003 all_datasets = outlet_datasets 

3004 all_datasets.update(input_datasets) 

3005 

3006 # store datasets 

3007 stored_datasets = {} 

3008 for dataset in all_datasets: 

3009 stored_dataset = session.query(DatasetModel).filter(DatasetModel.uri == dataset.uri).first() 

3010 if stored_dataset: 

3011 # Some datasets may have been previously unreferenced, and therefore orphaned by the 

3012 # scheduler. But if we're here, then we have found that dataset again in our DAGs, which 

3013 # means that it is no longer an orphan, so set is_orphaned to False. 

3014 stored_dataset.is_orphaned = expression.false() 

3015 stored_datasets[stored_dataset.uri] = stored_dataset 

3016 else: 

3017 session.add(dataset) 

3018 stored_datasets[dataset.uri] = dataset 

3019 

3020 session.flush() # this is required to ensure each dataset has its PK loaded 

3021 

3022 del all_datasets 

3023 

3024 # reconcile dag-schedule-on-dataset references 

3025 for dag_id, uri_list in dag_references.items(): 

3026 dag_refs_needed = { 

3027 DagScheduleDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id) 

3028 for uri in uri_list 

3029 } 

3030 dag_refs_stored = set( 

3031 existing_dags.get(dag_id) 

3032 and existing_dags.get(dag_id).schedule_dataset_references # type: ignore 

3033 or [] 

3034 ) 

3035 dag_refs_to_add = {x for x in dag_refs_needed if x not in dag_refs_stored} 

3036 session.bulk_save_objects(dag_refs_to_add) 

3037 for obj in dag_refs_stored - dag_refs_needed: 

3038 session.delete(obj) 

3039 

3040 existing_task_outlet_refs_dict = collections.defaultdict(set) 

3041 for dag_id, orm_dag in existing_dags.items(): 

3042 for todr in orm_dag.task_outlet_dataset_references: 

3043 existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr) 

3044 

3045 # reconcile task-outlet-dataset references 

3046 for (dag_id, task_id), uri_list in outlet_references.items(): 

3047 task_refs_needed = { 

3048 TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id) 

3049 for uri in uri_list 

3050 } 

3051 task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)] 

3052 task_refs_to_add = {x for x in task_refs_needed if x not in task_refs_stored} 

3053 session.bulk_save_objects(task_refs_to_add) 

3054 for obj in task_refs_stored - task_refs_needed: 

3055 session.delete(obj) 

3056 

3057 # Issue SQL/finish "Unit of Work", but let @provide_session commit (or, if passed a session, let the caller 

3058 # decide when to commit). 

3059 session.flush() 

3060 

3061 for dag in dags: 

3062 cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session) 

3063 

3064 @provide_session 

3065 def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION): 

3066 """ 

3067 Save attributes about this DAG to the DB. Note that this method 

3068 can be called for both DAGs and SubDAGs. A SubDag is actually a 

3069 SubDagOperator. 

3070 

3071 :return: None 

3072 """ 

3073 self.bulk_write_to_db([self], processor_subdir=processor_subdir, session=session) 

3074 
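# A minimal usage sketch of sync_to_db() above; the DAG id and schedule are illustrative,
# and the default @provide_session-managed session is used.
#
#   from airflow.models.dag import DAG
#
#   my_dag = DAG(dag_id="example_sync", schedule=None)
#   my_dag.sync_to_db()  # creates/updates the corresponding DagModel row via bulk_write_to_db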

3075 def get_default_view(self): 

3076 """This is only there for backward compatible jinja2 templates.""" 

3077 if self.default_view is None: 

3078 return conf.get("webserver", "dag_default_view").lower() 

3079 else: 

3080 return self.default_view 

3081 

3082 @staticmethod 

3083 @provide_session 

3084 def deactivate_unknown_dags(active_dag_ids, session=NEW_SESSION): 

3085 """ 

3086 Given a list of known DAGs, deactivate any other DAGs that are 

3087 marked as active in the ORM. 

3088 

3089 :param active_dag_ids: list of DAG IDs that are active 

3090 :return: None 

3091 """ 

3092 if len(active_dag_ids) == 0: 

3093 return 

3094 for dag in session.query(DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all(): 

3095 dag.is_active = False 

3096 session.merge(dag) 

3097 session.commit() 

3098 
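# A hedged sketch of deactivate_unknown_dags(); the dag ids are illustrative. Every active
# DagModel row whose dag_id is not in the list is marked inactive.
#
#   DAG.deactivate_unknown_dags(["etl_daily", "reporting_hourly"])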

3099 @staticmethod 

3100 @provide_session 

3101 def deactivate_stale_dags(expiration_date, session=NEW_SESSION): 

3102 """ 

3103 Deactivate any DAGs that were last touched by the scheduler before 

3104 the expiration date. These DAGs were likely deleted. 

3105 

3106 :param expiration_date: mark as inactive any DAGs that were last touched before this 

3107 time 

3108 :return: None 

3109 """ 

3110 for dag in ( 

3111 session.query(DagModel) 

3112 .filter(DagModel.last_parsed_time < expiration_date, DagModel.is_active) 

3113 .all() 

3114 ): 

3115 log.info( 

3116 "Deactivating DAG ID %s since it was last touched by the scheduler at %s", 

3117 dag.dag_id, 

3118 dag.last_parsed_time.isoformat(), 

3119 ) 

3120 dag.is_active = False 

3121 session.merge(dag) 

3122 session.commit() 

3123 
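# A hedged sketch of deactivate_stale_dags(); the one-hour cutoff is an arbitrary
# illustrative value.
#
#   from datetime import timedelta
#   from airflow.utils import timezone
#
#   DAG.deactivate_stale_dags(expiration_date=timezone.utcnow() - timedelta(hours=1))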

3124 @staticmethod 

3125 @provide_session 

3126 def get_num_task_instances(dag_id, run_id=None, task_ids=None, states=None, session=NEW_SESSION) -> int: 

3127 """ 

3128 Returns the number of task instances in the given DAG. 

3129 

3130 :param session: ORM session 

3131 :param dag_id: ID of the DAG to get the task concurrency of 

3132 :param run_id: ID of the DAG run to get the task concurrency of 

3133 :param task_ids: A list of valid task IDs for the given DAG 

3134 :param states: A list of states to filter by if supplied 

3135 :return: The number of task instances matching the given filters 

3136 """ 

3137 qry = session.query(func.count(TaskInstance.task_id)).filter( 

3138 TaskInstance.dag_id == dag_id, 

3139 ) 

3140 if run_id: 

3141 qry = qry.filter( 

3142 TaskInstance.run_id == run_id, 

3143 ) 

3144 if task_ids: 

3145 qry = qry.filter( 

3146 TaskInstance.task_id.in_(task_ids), 

3147 ) 

3148 

3149 if states: 

3150 if None in states: 

3151 if all(x is None for x in states): 

3152 qry = qry.filter(TaskInstance.state.is_(None)) 

3153 else: 

3154 not_none_states = [state for state in states if state] 

3155 qry = qry.filter( 

3156 or_(TaskInstance.state.in_(not_none_states), TaskInstance.state.is_(None)) 

3157 ) 

3158 else: 

3159 qry = qry.filter(TaskInstance.state.in_(states)) 

3160 return qry.scalar() 

3161 
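# A hedged sketch of get_num_task_instances(); the dag_id, run_id and state filter are
# illustrative only.
#
#   from airflow.utils.state import TaskInstanceState
#
#   running = DAG.get_num_task_instances(
#       dag_id="etl_daily",
#       run_id="scheduled__2023-06-01T00:00:00+00:00",
#       states=[TaskInstanceState.RUNNING],
#   )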

3162 @classmethod 

3163 def get_serialized_fields(cls): 

3164 """Stringified DAGs and operators contain exactly these fields.""" 

3165 if not cls.__serialized_fields: 

3166 exclusion_list = { 

3167 "parent_dag", 

3168 "schedule_dataset_references", 

3169 "task_outlet_dataset_references", 

3170 "_old_context_manager_dags", 

3171 "safe_dag_id", 

3172 "last_loaded", 

3173 "user_defined_filters", 

3174 "user_defined_macros", 

3175 "partial", 

3176 "params", 

3177 "_pickle_id", 

3178 "_log", 

3179 "task_dict", 

3180 "template_searchpath", 

3181 "sla_miss_callback", 

3182 "on_success_callback", 

3183 "on_failure_callback", 

3184 "template_undefined", 

3185 "jinja_environment_kwargs", 

3186 # has_on_*_callback are only stored if the value is True, as the default is False 

3187 "has_on_success_callback", 

3188 "has_on_failure_callback", 

3189 "auto_register", 

3190 "fail_stop", 

3191 } 

3192 cls.__serialized_fields = frozenset(vars(DAG(dag_id="test")).keys()) - exclusion_list 

3193 return cls.__serialized_fields 

3194 
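# A small illustration of get_serialized_fields(): the returned frozenset holds the DAG
# attribute names kept during serialization, i.e. everything in vars() minus the exclusion
# list above.
#
#   fields = DAG.get_serialized_fields()
#   assert "on_failure_callback" not in fields  # callbacks themselves are never serialized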

3195 def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: 

3196 """ 

3197 Returns edge information for the given pair of tasks if present, and 

3198 an empty edge if there is no information. 

3199 """ 

3200 # Note - older serialized DAGs may not have edge_info being a dict at all 

3201 empty = cast(EdgeInfoType, {}) 

3202 if self.edge_info: 

3203 return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty) 

3204 else: 

3205 return empty 

3206 

3207 def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType): 

3208 """ 

3209 Sets the given edge information on the DAG. Note that this will overwrite, 

3210 rather than merge with, existing info. 

3211 """ 

3212 self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info 

3213 
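# A hedged sketch of the edge-info API above, assuming a DAG instance named `my_dag`;
# the task ids and label are illustrative, and the dict shape follows EdgeInfoType
# (a mapping with a "label" key).
#
#   my_dag.set_edge_info("branch", "join", {"label": "fast path"})
#   my_dag.get_edge_info("branch", "join")   # -> {"label": "fast path"}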

3214 def validate_schedule_and_params(self): 

3215 """ 

3216 Validates the DAG and raises an exception if any Params neither have a default value nor 

3217 allow null in their schema['type'] list while the DAG has a schedule_interval that is not None. 

3218 """ 

3219 if not self.timetable.can_be_scheduled: 

3220 return 

3221 

3222 for k, v in self.params.items(): 

3223 # As type can be an array, we would check if `null` is an allowed type or not 

3224 if not v.has_value and ("type" not in v.schema or "null" not in v.schema["type"]): 

3225 raise AirflowException( 

3226 "DAG Schedule must be None, if there are any required params without default values" 

3227 ) 

3228 
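# A hedged sketch of the rule enforced above: a schedulable DAG whose params lack defaults
# (and whose schema does not allow null) is rejected. The same check may already fire while
# the DAG is being constructed, so the explicit call is shown only for illustration.
#
#   from airflow.models.param import Param
#
#   strict_dag = DAG(
#       "needs_params",
#       schedule="@daily",
#       params={"size": Param(type="integer")},   # required: no default, null not allowed
#   )
#   strict_dag.validate_schedule_and_params()     # raises AirflowException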

3229 def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]: 

3230 """Parses a given link, and verifies if it's a valid URL, or a 'mailto' link. 

3231 Returns an iterator of invalid (owner, link) pairs. 

3232 """ 

3233 for owner, link in self.owner_links.items(): 

3234 result = urlsplit(link) 

3235 if result.scheme == "mailto": 

3236 # netloc does not exist for a 'mailto' link, so we check that the path was parsed 

3237 if not result.path: 

3238 yield owner, link 

3239 elif not result.scheme or not result.netloc: 

3240 yield owner, link 

3241 
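# A hedged sketch of iter_invalid_owner_links(); `linked_dag` is hypothetical and assumed to
# have been created with owner_links={"alice": "www.example.com", "bob": "mailto:bob@example.com"}.
# The scheme-less link is reported as invalid, the mailto link is not.
#
#   list(linked_dag.iter_invalid_owner_links())   # -> [("alice", "www.example.com")]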

3242 

3243class DagTag(Base): 

3244 """A tag name per dag, to allow quick filtering in the DAG view.""" 

3245 

3246 __tablename__ = "dag_tag" 

3247 name = Column(String(TAG_MAX_LEN), primary_key=True) 

3248 dag_id = Column( 

3249 StringID(), 

3250 ForeignKey("dag.dag_id", name="dag_tag_dag_id_fkey", ondelete="CASCADE"), 

3251 primary_key=True, 

3252 ) 

3253 

3254 def __repr__(self): 

3255 return self.name 

3256 

3257 

3258class DagOwnerAttributes(Base): 

3259 """Table defining different owner attributes. 

3260 

3261 For example, a link for an owner that will be passed as a hyperlink to the 

3262 "DAGs" view. 

3263 """ 

3264 

3265 __tablename__ = "dag_owner_attributes" 

3266 dag_id = Column( 

3267 StringID(), 

3268 ForeignKey("dag.dag_id", name="dag.dag_id", ondelete="CASCADE"), 

3269 nullable=False, 

3270 primary_key=True, 

3271 ) 

3272 owner = Column(String(500), primary_key=True, nullable=False) 

3273 link = Column(String(500), nullable=False) 

3274 

3275 def __repr__(self): 

3276 return f"<DagOwnerAttributes: dag_id={self.dag_id}, owner={self.owner}, link={self.link}>" 

3277 

3278 @classmethod 

3279 def get_all(cls, session) -> dict[str, dict[str, str]]: 

3280 dag_links: dict = collections.defaultdict(dict) 

3281 for obj in session.query(cls): 

3282 dag_links[obj.dag_id].update({obj.owner: obj.link}) 

3283 return dag_links 

3284 
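# A hedged sketch of DagOwnerAttributes.get_all(); the result is a {dag_id: {owner: link}}
# mapping and the values shown are illustrative.
#
#   from airflow.utils.session import create_session
#
#   with create_session() as session:
#       owner_links = DagOwnerAttributes.get_all(session)
#       # e.g. {"etl_daily": {"alice": "https://example.com/alice"}}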

3285 

3286class DagModel(Base): 

3287 """Table containing DAG properties.""" 

3288 

3289 __tablename__ = "dag" 

3290 """ 

3291 These items are stored in the database for state related information 

3292 """ 

3293 dag_id = Column(StringID(), primary_key=True) 

3294 root_dag_id = Column(StringID()) 

3295 # A DAG can be paused from the UI / DB 

3296 # Set this default value of is_paused based on a configuration value! 

3297 is_paused_at_creation = conf.getboolean("core", "dags_are_paused_at_creation") 

3298 is_paused = Column(Boolean, default=is_paused_at_creation) 

3299 # Whether the DAG is a subdag 

3300 is_subdag = Column(Boolean, default=False) 

3301 # Whether that DAG was seen on the last DagBag load 

3302 is_active = Column(Boolean, default=False) 

3303 # Last time the DAG was parsed by the DAG processor 

3304 last_parsed_time = Column(UtcDateTime) 

3305 # Last time this DAG was pickled 

3306 last_pickled = Column(UtcDateTime) 

3307 # Time when the DAG last received a refresh signal 

3308 # (e.g. the DAG's "refresh" button was clicked in the web UI) 

3309 last_expired = Column(UtcDateTime) 

3310 # Whether (one of) the scheduler is scheduling this DAG at the moment 

3311 scheduler_lock = Column(Boolean) 

3312 # Foreign key to the latest pickle_id 

3313 pickle_id = Column(Integer) 

3314 # The location of the file containing the DAG object 

3315 # Note: Do not depend on fileloc pointing to a file; in the case of a 

3316 # packaged DAG, it will point to the subpath of the DAG within the 

3317 # associated zip. 

3318 fileloc = Column(String(2000)) 

3319 # The base directory used by Dag Processor that parsed this dag. 

3320 processor_subdir = Column(String(2000), nullable=True) 

3321 # String representing the owners 

3322 owners = Column(String(2000)) 

3323 # Description of the dag 

3324 description = Column(Text) 

3325 # Default view of the DAG inside the webserver 

3326 default_view = Column(String(25)) 

3327 # Schedule interval 

3328 schedule_interval = Column(Interval) 

3329 # Timetable/Schedule Interval description 

3330 timetable_description = Column(String(1000), nullable=True) 

3331 # Tags for view filter 

3332 tags = relationship("DagTag", cascade="all, delete, delete-orphan", backref=backref("dag")) 

3333 # Dag owner links for DAGs view 

3334 dag_owner_links = relationship( 

3335 "DagOwnerAttributes", cascade="all, delete, delete-orphan", backref=backref("dag") 

3336 ) 

3337 

3338 max_active_tasks = Column(Integer, nullable=False) 

3339 max_active_runs = Column(Integer, nullable=True) 

3340 

3341 has_task_concurrency_limits = Column(Boolean, nullable=False) 

3342 has_import_errors = Column(Boolean(), default=False, server_default="0") 

3343 

3344 # The logical date of the next dag run. 

3345 next_dagrun = Column(UtcDateTime) 

3346 

3347 # Must be either both NULL or both datetime. 

3348 next_dagrun_data_interval_start = Column(UtcDateTime) 

3349 next_dagrun_data_interval_end = Column(UtcDateTime) 

3350 

3351 # Earliest time at which this ``next_dagrun`` can be created. 

3352 next_dagrun_create_after = Column(UtcDateTime) 

3353 

3354 __table_args__ = ( 

3355 Index("idx_root_dag_id", root_dag_id, unique=False), 

3356 Index("idx_next_dagrun_create_after", next_dagrun_create_after, unique=False), 

3357 ) 

3358 

3359 parent_dag = relationship( 

3360 "DagModel", remote_side=[dag_id], primaryjoin=root_dag_id == dag_id, foreign_keys=[root_dag_id] 

3361 ) 

3362 schedule_dataset_references = relationship( 

3363 "DagScheduleDatasetReference", 

3364 cascade="all, delete, delete-orphan", 

3365 ) 

3366 schedule_datasets = association_proxy("schedule_dataset_references", "dataset") 

3367 task_outlet_dataset_references = relationship( 

3368 "TaskOutletDatasetReference", 

3369 cascade="all, delete, delete-orphan", 

3370 ) 

3371 NUM_DAGS_PER_DAGRUN_QUERY = conf.getint("scheduler", "max_dagruns_to_create_per_loop", fallback=10) 

3372 

3373 def __init__(self, concurrency=None, **kwargs): 

3374 super().__init__(**kwargs) 

3375 if self.max_active_tasks is None: 

3376 if concurrency: 

3377 warnings.warn( 

3378 "The 'DagModel.concurrency' parameter is deprecated. Please use 'max_active_tasks'.", 

3379 RemovedInAirflow3Warning, 

3380 stacklevel=2, 

3381 ) 

3382 self.max_active_tasks = concurrency 

3383 else: 

3384 self.max_active_tasks = conf.getint("core", "max_active_tasks_per_dag") 

3385 

3386 if self.max_active_runs is None: 

3387 self.max_active_runs = conf.getint("core", "max_active_runs_per_dag") 

3388 

3389 if self.has_task_concurrency_limits is None: 

3390 # Be safe -- this will be updated later once the DAG is parsed 

3391 self.has_task_concurrency_limits = True 

3392 

3393 def __repr__(self): 

3394 return f"<DAG: {self.dag_id}>" 

3395 

3396 @property 

3397 def next_dagrun_data_interval(self) -> DataInterval | None: 

3398 return _get_model_data_interval( 

3399 self, 

3400 "next_dagrun_data_interval_start", 

3401 "next_dagrun_data_interval_end", 

3402 ) 

3403 

3404 @next_dagrun_data_interval.setter 

3405 def next_dagrun_data_interval(self, value: tuple[datetime, datetime] | None) -> None: 

3406 if value is None: 

3407 self.next_dagrun_data_interval_start = self.next_dagrun_data_interval_end = None 

3408 else: 

3409 self.next_dagrun_data_interval_start, self.next_dagrun_data_interval_end = value 

3410 

3411 @property 

3412 def timezone(self): 

3413 return settings.TIMEZONE 

3414 

3415 @staticmethod 

3416 @provide_session 

3417 def get_dagmodel(dag_id: str, session: Session = NEW_SESSION) -> DagModel | None: 

3418 return session.get( 

3419 DagModel, 

3420 dag_id, 

3421 options=[joinedload(DagModel.parent_dag)], 

3422 ) 

3423 

3424 @classmethod 

3425 @provide_session 

3426 def get_current(cls, dag_id, session=NEW_SESSION): 

3427 return session.query(cls).filter(cls.dag_id == dag_id).first() 

3428 

3429 @provide_session 

3430 def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False): 

3431 return get_last_dagrun( 

3432 self.dag_id, session=session, include_externally_triggered=include_externally_triggered 

3433 ) 

3434 

3435 def get_is_paused(self, *, session: Session | None = None) -> bool: 

3436 """Provide interface compatibility to 'DAG'.""" 

3437 return self.is_paused 

3438 

3439 @staticmethod 

3440 @internal_api_call 

3441 @provide_session 

3442 def get_paused_dag_ids(dag_ids: list[str], session: Session = NEW_SESSION) -> set[str]: 

3443 """ 

3444 Given a list of dag_ids, get a set of Paused Dag Ids. 

3445 

3446 :param dag_ids: List of Dag ids 

3447 :param session: ORM Session 

3448 :return: Paused Dag_ids 

3449 """ 

3450 paused_dag_ids = ( 

3451 session.query(DagModel.dag_id) 

3452 .filter(DagModel.is_paused == expression.true()) 

3453 .filter(DagModel.dag_id.in_(dag_ids)) 

3454 .all() 

3455 ) 

3456 

3457 paused_dag_ids = {paused_dag_id for paused_dag_id, in paused_dag_ids} 

3458 return paused_dag_ids 

3459 
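# A hedged sketch of get_paused_dag_ids(); the dag ids are illustrative. The result is the
# subset of the input whose DagModel rows are paused.
#
#   paused = DagModel.get_paused_dag_ids(["etl_daily", "reporting_hourly"])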

3460 def get_default_view(self) -> str: 

3461 """ 

3462 Get the default DAG view; returns the default config value if DagModel does not 

3463 have a value. 

3464 """ 

3465 # This is for backwards-compatibility with old dags that don't have None as default_view 

3466 return self.default_view or conf.get_mandatory_value("webserver", "dag_default_view").lower() 

3467 

3468 @property 

3469 def safe_dag_id(self): 

3470 return self.dag_id.replace(".", "__dot__") 

3471 

3472 @property 

3473 def relative_fileloc(self) -> pathlib.Path | None: 

3474 """File location of the importable dag 'file' relative to the configured DAGs folder.""" 

3475 if self.fileloc is None: 

3476 return None 

3477 path = pathlib.Path(self.fileloc) 

3478 try: 

3479 return path.relative_to(settings.DAGS_FOLDER) 

3480 except ValueError: 

3481 # Not relative to DAGS_FOLDER. 

3482 return path 

3483 

3484 @provide_session 

3485 def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session=NEW_SESSION) -> None: 

3486 """ 

3487 Pause/Un-pause a DAG. 

3488 

3489 :param is_paused: Is the DAG paused 

3490 :param including_subdags: whether to include the DAG's subdags 

3491 :param session: session 

3492 """ 

3493 filter_query = [ 

3494 DagModel.dag_id == self.dag_id, 

3495 ] 

3496 if including_subdags: 

3497 filter_query.append(DagModel.root_dag_id == self.dag_id) 

3498 session.query(DagModel).filter(or_(*filter_query)).update( 

3499 {DagModel.is_paused: is_paused}, synchronize_session="fetch" 

3500 ) 

3501 session.commit() 

3502 
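# A hedged sketch of set_is_paused(); assumes a DagModel row exists for the illustrative dag id.
#
#   dm = DagModel.get_dagmodel("etl_daily")
#   if dm is not None:
#       dm.set_is_paused(True)   # pauses the DAG (and, by default, its subdags)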

3503 @classmethod 

3504 @internal_api_call 

3505 @provide_session 

3506 def deactivate_deleted_dags(cls, alive_dag_filelocs: list[str], session=NEW_SESSION): 

3507 """ 

3508 Set ``is_active=False`` on the DAGs for which the DAG files have been removed. 

3509 

3510 :param alive_dag_filelocs: file paths of alive DAGs 

3511 :param session: ORM Session 

3512 """ 

3513 log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__) 

3514 

3515 dag_models = session.query(cls).all() 

3516 for dag_model in dag_models: 

3517 if dag_model.fileloc is not None and dag_model.fileloc not in alive_dag_filelocs: 

3518 dag_model.is_active = False 

3519 else: 

3520 continue 

3521 

3522 @classmethod 

3523 def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]: 

3524 """ 

3525 Return (and lock) a list of Dag objects that are due to create a new DagRun. 

3526 

3527 This will return a resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query; 

3528 you should ensure that any scheduling decisions are made in a single transaction -- as soon as the 

3529 transaction is committed it will be unlocked. 

3530 """ 

3531 from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ 

3532 

3533 # these dag ids are triggered by datasets, and they are ready to go. 

3534 dataset_triggered_dag_info = { 

3535 x.dag_id: (x.first_queued_time, x.last_queued_time) 

3536 for x in session.query( 

3537 DagScheduleDatasetReference.dag_id, 

3538 func.max(DDRQ.created_at).label("last_queued_time"), 

3539 func.min(DDRQ.created_at).label("first_queued_time"), 

3540 ) 

3541 .join(DagScheduleDatasetReference.queue_records, isouter=True) 

3542 .group_by(DagScheduleDatasetReference.dag_id) 

3543 .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0))) 

3544 .all() 

3545 } 

3546 dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys()) 

3547 if dataset_triggered_dag_ids: 

3548 exclusion_list = { 

3549 x.dag_id 

3550 for x in ( 

3551 session.query(DagModel.dag_id) 

3552 .join(DagRun.dag_model) 

3553 .filter(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING))) 

3554 .filter(DagModel.dag_id.in_(dataset_triggered_dag_ids)) 

3555 .group_by(DagModel.dag_id) 

3556 .having(func.count() >= func.max(DagModel.max_active_runs)) 

3557 .all() 

3558 ) 

3559 } 

3560 if exclusion_list: 

3561 dataset_triggered_dag_ids -= exclusion_list 

3562 dataset_triggered_dag_info = { 

3563 k: v for k, v in dataset_triggered_dag_info.items() if k not in exclusion_list 

3564 } 

3565 

3566 # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs 

3567 query = ( 

3568 session.query(cls) 

3569 .filter( 

3570 cls.is_paused == expression.false(), 

3571 cls.is_active == expression.true(), 

3572 cls.has_import_errors == expression.false(), 

3573 or_( 

3574 cls.next_dagrun_create_after <= func.now(), 

3575 cls.dag_id.in_(dataset_triggered_dag_ids), 

3576 ), 

3577 ) 

3578 .order_by(cls.next_dagrun_create_after) 

3579 .limit(cls.NUM_DAGS_PER_DAGRUN_QUERY) 

3580 ) 

3581 

3582 return ( 

3583 with_row_locks(query, of=cls, session=session, **skip_locked(session=session)), 

3584 dataset_triggered_dag_info, 

3585 ) 

3586 
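# A hedged sketch of how a scheduler loop might consume dags_needing_dagruns(); this mirrors,
# but does not reproduce, the real scheduler code.
#
#   from airflow.utils.session import create_session
#
#   with create_session() as session:
#       query, dataset_info = DagModel.dags_needing_dagruns(session)
#       for dag_model in query:
#           ...  # create DagRuns while the rows stay locked; committing releases the locks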

3587 def calculate_dagrun_date_fields( 

3588 self, 

3589 dag: DAG, 

3590 most_recent_dag_run: None | datetime | DataInterval, 

3591 ) -> None: 

3592 """ 

3593 Calculate ``next_dagrun`` and ``next_dagrun_create_after``. 

3594 

3595 :param dag: The DAG object 

3596 :param most_recent_dag_run: DataInterval (or datetime) of most recent run of this dag, or none 

3597 if not yet scheduled. 

3598 """ 

3599 most_recent_data_interval: DataInterval | None 

3600 if isinstance(most_recent_dag_run, datetime): 

3601 warnings.warn( 

3602 "Passing a datetime to `DagModel.calculate_dagrun_date_fields` is deprecated. " 

3603 "Provide a data interval instead.", 

3604 RemovedInAirflow3Warning, 

3605 stacklevel=2, 

3606 ) 

3607 most_recent_data_interval = dag.infer_automated_data_interval(most_recent_dag_run) 

3608 else: 

3609 most_recent_data_interval = most_recent_dag_run 

3610 next_dagrun_info = dag.next_dagrun_info(most_recent_data_interval) 

3611 if next_dagrun_info is None: 

3612 self.next_dagrun_data_interval = self.next_dagrun = self.next_dagrun_create_after = None 

3613 else: 

3614 self.next_dagrun_data_interval = next_dagrun_info.data_interval 

3615 self.next_dagrun = next_dagrun_info.logical_date 

3616 self.next_dagrun_create_after = next_dagrun_info.run_after 

3617 

3618 log.info( 

3619 "Setting next_dagrun for %s to %s, run_after=%s", 

3620 dag.dag_id, 

3621 self.next_dagrun, 

3622 self.next_dagrun_create_after, 

3623 ) 

3624 

3625 @provide_session 

3626 def get_dataset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, int | str] | None: 

3627 if self.schedule_interval != "Dataset": 

3628 return None 

3629 return get_dataset_triggered_next_run_info([self.dag_id], session=session)[self.dag_id] 

3630 

3631 

3632# NOTE: Please keep the list of arguments in sync with DAG.__init__. 

3633# Only exception: dag_id here should have a default value, but not in DAG. 

3634def dag( 

3635 dag_id: str = "", 

3636 description: str | None = None, 

3637 schedule: ScheduleArg = NOTSET, 

3638 schedule_interval: ScheduleIntervalArg = NOTSET, 

3639 timetable: Timetable | None = None, 

3640 start_date: datetime | None = None, 

3641 end_date: datetime | None = None, 

3642 full_filepath: str | None = None, 

3643 template_searchpath: str | Iterable[str] | None = None, 

3644 template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, 

3645 user_defined_macros: dict | None = None, 

3646 user_defined_filters: dict | None = None, 

3647 default_args: dict | None = None, 

3648 concurrency: int | None = None, 

3649 max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"), 

3650 max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"), 

3651 dagrun_timeout: timedelta | None = None, 

3652 sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None, 

3653 default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(), 

3654 orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"), 

3655 catchup: bool = conf.getboolean("scheduler", "catchup_by_default"), 

3656 on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

3657 on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, 

3658 doc_md: str | None = None, 

3659 params: collections.abc.MutableMapping | None = None, 

3660 access_control: dict | None = None, 

3661 is_paused_upon_creation: bool | None = None, 

3662 jinja_environment_kwargs: dict | None = None, 

3663 render_template_as_native_obj: bool = False, 

3664 tags: list[str] | None = None, 

3665 owner_links: dict[str, str] | None = None, 

3666 auto_register: bool = True, 

3667 fail_stop: bool = False, 

3668) -> Callable[[Callable], Callable[..., DAG]]: 

3669 """ 

3670 Python dag decorator which wraps a function into an Airflow DAG. 

3671 The decorated function's arguments become DAG params, so it can be used to parameterize DAGs. 

3672 

3673 :param dag_args: Arguments for DAG object 

3674 :param dag_kwargs: Kwargs for DAG object. 

3675 """ 

3676 

3677 def wrapper(f: Callable) -> Callable[..., DAG]: 

3678 @functools.wraps(f) 

3679 def factory(*args, **kwargs): 

3680 # Generate signature for decorated function and bind the arguments when called 

3681 # we do this to extract parameters, so we can annotate them on the DAG object. 

3682 # In addition, this fails if we are missing any args/kwargs with TypeError as expected. 

3683 f_sig = signature(f).bind(*args, **kwargs) 

3684 # Apply defaults to capture default values if set. 

3685 f_sig.apply_defaults() 

3686 

3687 # Initialize DAG with bound arguments 

3688 with DAG( 

3689 dag_id or f.__name__, 

3690 description=description, 

3691 schedule_interval=schedule_interval, 

3692 timetable=timetable, 

3693 start_date=start_date, 

3694 end_date=end_date, 

3695 full_filepath=full_filepath, 

3696 template_searchpath=template_searchpath, 

3697 template_undefined=template_undefined, 

3698 user_defined_macros=user_defined_macros, 

3699 user_defined_filters=user_defined_filters, 

3700 default_args=default_args, 

3701 concurrency=concurrency, 

3702 max_active_tasks=max_active_tasks, 

3703 max_active_runs=max_active_runs, 

3704 dagrun_timeout=dagrun_timeout, 

3705 sla_miss_callback=sla_miss_callback, 

3706 default_view=default_view, 

3707 orientation=orientation, 

3708 catchup=catchup, 

3709 on_success_callback=on_success_callback, 

3710 on_failure_callback=on_failure_callback, 

3711 doc_md=doc_md, 

3712 params=params, 

3713 access_control=access_control, 

3714 is_paused_upon_creation=is_paused_upon_creation, 

3715 jinja_environment_kwargs=jinja_environment_kwargs, 

3716 render_template_as_native_obj=render_template_as_native_obj, 

3717 tags=tags, 

3718 schedule=schedule, 

3719 owner_links=owner_links, 

3720 auto_register=auto_register, 

3721 fail_stop=fail_stop, 

3722 ) as dag_obj: 

3723 # Set DAG documentation from function documentation if it exists and doc_md is not set. 

3724 if f.__doc__ and not dag_obj.doc_md: 

3725 dag_obj.doc_md = f.__doc__ 

3726 

3727 # Generate DAGParam for each function arg/kwarg and replace it for calling the function. 

3728 # All args/kwargs for function will be DAGParam object and replaced on execution time. 

3729 f_kwargs = {} 

3730 for name, value in f_sig.arguments.items(): 

3731 f_kwargs[name] = dag_obj.param(name, value) 

3732 

3733 # set file location to caller source path 

3734 back = sys._getframe().f_back 

3735 dag_obj.fileloc = back.f_code.co_filename if back else "" 

3736 

3737 # Invoke function to create operators in the DAG scope. 

3738 f(**f_kwargs) 

3739 

3740 # Return dag object such that it's accessible in Globals. 

3741 return dag_obj 

3742 

3743 # Ensure that warnings from inside DAG() are emitted from the caller, not here 

3744 fixup_decorator_warning_stack(factory) 

3745 return factory 

3746 

3747 return wrapper 

3748 
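# A hedged sketch of the decorator defined above; in user code it is usually imported as
# `from airflow.decorators import dag, task`. Names, dates, and the default path are illustrative.
#
#   from datetime import datetime
#   from airflow.decorators import task
#
#   @dag(schedule=None, start_date=datetime(2023, 1, 1), catchup=False)
#   def example_pipeline(source: str = "s3://bucket/key"):
#       @task
#       def extract(path: str) -> str:
#           return path
#
#       extract(source)
#
#   example_pipeline_dag = example_pipeline()   # the factory returns the DAG object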

3749 

3750STATICA_HACK = True 

3751globals()["kcah_acitats"[::-1].upper()] = False 

3752if STATICA_HACK: # pragma: no cover 

3753 

3754 from airflow.models.serialized_dag import SerializedDagModel 

3755 

3756 DagModel.serialized_dag = relationship(SerializedDagModel) 

3757 """:sphinx-autoapi-skip:""" 

3758 

3759 

3760class DagContext: 

3761 """ 

3762 DAG context is used to keep the current DAG when DAG is used as ContextManager. 

3763 

3764 You can use DAG as context: 

3765 

3766 .. code-block:: python 

3767 

3768 with DAG( 

3769 dag_id="example_dag", 

3770 default_args=default_args, 

3771 schedule="0 0 * * *", 

3772 dagrun_timeout=timedelta(minutes=60), 

3773 ) as dag: 

3774 ... 

3775 

3776 If you do this, the context stores the DAG, and whenever a new task is created, it will use 

3777 the stored DAG as its parent DAG. 

3778 

3779 """ 

3780 

3781 _context_managed_dags: Deque[DAG] = deque() 

3782 autoregistered_dags: set[tuple[DAG, ModuleType]] = set() 

3783 current_autoregister_module_name: str | None = None 

3784 

3785 @classmethod 

3786 def push_context_managed_dag(cls, dag: DAG): 

3787 cls._context_managed_dags.appendleft(dag) 

3788 

3789 @classmethod 

3790 def pop_context_managed_dag(cls) -> DAG | None: 

3791 dag = cls._context_managed_dags.popleft() 

3792 

3793 # In a few cases around serialization we explicitly push None into the stack 

3794 if cls.current_autoregister_module_name is not None and dag and dag.auto_register: 

3795 mod = sys.modules[cls.current_autoregister_module_name] 

3796 cls.autoregistered_dags.add((dag, mod)) 

3797 

3798 return dag 

3799 

3800 @classmethod 

3801 def get_current_dag(cls) -> DAG | None: 

3802 try: 

3803 return cls._context_managed_dags[0] 

3804 except IndexError: 

3805 return None 

3806 
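# A hedged illustration of DagContext, assuming no other DAG context is active:
#
#   with DAG("ctx_demo", schedule=None) as demo_dag:
#       assert DagContext.get_current_dag() is demo_dag
#   assert DagContext.get_current_dag() is None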

3807 

3808def _run_task(ti: TaskInstance, session): 

3809 """ 

3810 Run a single task instance, and push the result to XCom for downstream tasks. Bypasses a lot of 

3811 extra steps used in `task.run` to keep local execution as fast as possible. 

3812 This function is only meant to be used by `dag.test` as a helper function. 

3813 

3814 Args: 

3815 ti: TaskInstance to run 

3816 """ 

3817 log.info("*****************************************************") 

3818 if ti.map_index > 0: 

3819 log.info("Running task %s index %d", ti.task_id, ti.map_index) 

3820 else: 

3821 log.info("Running task %s", ti.task_id) 

3822 try: 

3823 ti._run_raw_task(session=session) 

3824 session.flush() 

3825 log.info("%s ran successfully!", ti.task_id) 

3826 except AirflowSkipException: 

3827 log.info("Task Skipped, continuing") 

3828 log.info("*****************************************************") 

3829 

3830 

3831def _get_or_create_dagrun( 

3832 dag: DAG, 

3833 conf: dict[Any, Any] | None, 

3834 start_date: datetime, 

3835 execution_date: datetime, 

3836 run_id: str, 

3837 session: Session, 

3838) -> DagRun: 

3839 """Create a DAG run, replacing an existing instance if needed to prevent collisions. 

3840 

3841 This function is only meant to be used by :meth:`DAG.test` as a helper function. 

3842 

3843 :param dag: DAG to be used to find run. 

3844 :param conf: Configuration to pass to newly created run. 

3845 :param start_date: Start date of new run. 

3846 :param execution_date: Logical date for finding an existing run. 

3847 :param run_id: Run ID for the new DAG run. 

3848 

3849 :return: The newly created DAG run. 

3850 """ 

3851 log.info("dag_id: %s", dag.dag_id) 

3852 dr: DagRun = ( 

3853 session.query(DagRun) 

3854 .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date) 

3855 .first() 

3856 ) 

3857 if dr: 

3858 session.delete(dr) 

3859 session.commit() 

3860 dr = dag.create_dagrun( 

3861 state=DagRunState.RUNNING, 

3862 execution_date=execution_date, 

3863 run_id=run_id, 

3864 start_date=start_date or execution_date, 

3865 session=session, 

3866 conf=conf, 

3867 ) 

3868 log.info("created dagrun %s", dr) 

3869 return dr
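# A hedged sketch of _get_or_create_dagrun() as dag.test() might drive it; `my_dag` is a
# hypothetical DAG instance and the run id generation mirrors the manual-run convention.
#
#   from airflow.utils import timezone
#   from airflow.utils.session import create_session
#   from airflow.utils.types import DagRunType
#
#   logical_date = timezone.utcnow()
#   with create_session() as session:
#       dr = _get_or_create_dagrun(
#           dag=my_dag,
#           conf=None,
#           start_date=logical_date,
#           execution_date=logical_date,
#           run_id=DagRun.generate_run_id(DagRunType.MANUAL, logical_date),
#           session=session,
#       )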