#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from sqlalchemy import exc, select
from sqlalchemy.orm import joinedload

from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.listeners.listener import get_listener_manager
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
    from sqlalchemy.orm.session import Session

    from airflow.models.dag import DagModel
    from airflow.models.taskinstance import TaskInstance


class DatasetManager(LoggingMixin):
    """
    A pluggable class that manages operations for datasets.

    The intent is to have one place to handle all Dataset-related operations, so different
    Airflow deployments can use plugins that broadcast dataset events to each other.
    """
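
    # Illustrative sketch only (not part of this module): a deployment could broadcast
    # dataset events to an external system by subclassing this manager, overriding the
    # notify_* hooks, and pointing ``[core] dataset_manager_class`` at the subclass
    # (see ``resolve_dataset_manager`` below). For example (hypothetical names):
    #
    #     class BroadcastingDatasetManager(DatasetManager):
    #         def notify_dataset_changed(self, dataset: Dataset):
    #             super().notify_dataset_changed(dataset=dataset)
    #             publish_to_bus(dataset.uri)  # hypothetical transport helper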

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None:
        """Create new datasets."""
        for dataset_model in dataset_models:
            session.add(dataset_model)
        session.flush()

        for dataset_model in dataset_models:
            self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))
    def register_dataset_change(
        self,
        *,
        task_instance: TaskInstance | None = None,
        dataset: Dataset,
        extra=None,
        session: Session,
        **kwargs,
    ) -> DatasetEvent | None:
        """
        Register dataset-related changes.

        For local datasets, look them up, record the dataset event, queue dagruns, and broadcast
        the dataset event.
        """
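        # Load the dataset row together with its consuming DAG references (and their DAGs)
        # in one query, so the queueing below does not need to lazy-load them.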
        dataset_model = session.scalar(
            select(DatasetModel)
            .where(DatasetModel.uri == dataset.uri)
            .options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag))
        )
        if not dataset_model:
            self.log.warning("DatasetModel %s not found", dataset)
            return None

        event_kwargs = {
            "dataset_id": dataset_model.id,
            "extra": extra,
        }
        if task_instance:
            event_kwargs.update(
                {
                    "source_task_id": task_instance.task_id,
                    "source_dag_id": task_instance.dag_id,
                    "source_run_id": task_instance.run_id,
                    "source_map_index": task_instance.map_index,
                }
            )
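        # Record the event; the flush below writes the row within the current transaction
        # so later queries in this session can see it.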
        dataset_event = DatasetEvent(**event_kwargs)
        session.add(dataset_event)
        session.flush()

        self.notify_dataset_changed(dataset=dataset)

        Stats.incr("dataset.updates")
        self._queue_dagruns(dataset_model, session)
        session.flush()
        return dataset_event

    def notify_dataset_created(self, dataset: Dataset):
        """Run applicable notification actions when a dataset is created."""
        get_listener_manager().hook.on_dataset_created(dataset=dataset)

    def notify_dataset_changed(self, dataset: Dataset):
        """Run applicable notification actions when a dataset is changed."""
        get_listener_manager().hook.on_dataset_changed(dataset=dataset)

    def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
        # Possible race condition: if multiple dags or multiple (usually
        # mapped) tasks update the same dataset, this can fail with a unique
        # constraint violation.
        #
        # If we support it, use ON CONFLICT to do nothing, otherwise
        # "fallback" to running this in a nested transaction. This is needed
        # so that the adding of these rows happens in the same transaction
        # where `ti.state` is changed.

        if session.bind.dialect.name == "postgresql":
            return self._postgres_queue_dagruns(dataset, session)
        return self._slow_path_queue_dagruns(dataset, session)

    def _slow_path_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
        def _queue_dagrun_if_needed(dag: DagModel) -> str | None:
            if not dag.is_active or dag.is_paused:
                return None
            item = DatasetDagRunQueue(target_dag_id=dag.dag_id, dataset_id=dataset.id)
            # Don't error whole transaction when a single RunQueue item conflicts.
            # https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#using-savepoint
            try:
                with session.begin_nested():
                    session.merge(item)
            except exc.IntegrityError:
                self.log.debug("Skipping record %s", item, exc_info=True)
            return dag.dag_id

        queued_results = (_queue_dagrun_if_needed(ref.dag) for ref in dataset.consuming_dags)
        if queued_dag_ids := [r for r in queued_results if r is not None]:
            self.log.debug("consuming dag ids %s", queued_dag_ids)

    def _postgres_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
        from sqlalchemy.dialects.postgresql import insert

        values = [
            {"target_dag_id": dag.dag_id}
            for dag in (r.dag for r in dataset.consuming_dags)
            if dag.is_active and not dag.is_paused
        ]
        if not values:
            return
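        # INSERT ... ON CONFLICT DO NOTHING: queue rows that already exist are skipped
        # silently instead of raising a unique-constraint error (one parameter set per
        # active, unpaused consuming DAG).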
        stmt = insert(DatasetDagRunQueue).values(dataset_id=dataset.id).on_conflict_do_nothing()
        session.execute(stmt, values)


def resolve_dataset_manager() -> DatasetManager:
    """Retrieve the dataset manager."""
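    # The manager class and its keyword arguments are configurable; an illustrative
    # (hypothetical values) airflow.cfg entry would be:
    #   [core]
    #   dataset_manager_class = my_plugin.manager.CustomDatasetManager
    #   dataset_manager_kwargs = {"some_option": "some_value"}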
    _dataset_manager_class = conf.getimport(
        section="core",
        key="dataset_manager_class",
        fallback="airflow.datasets.manager.DatasetManager",
    )
    _dataset_manager_kwargs = conf.getjson(
        section="core",
        key="dataset_manager_kwargs",
        fallback={},
    )
    return _dataset_manager_class(**_dataset_manager_kwargs)


dataset_manager = resolve_dataset_manager()