#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from urllib.parse import urlsplit

import sqlalchemy_jsonfield
from sqlalchemy import (
    Boolean,
    Column,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PrimaryKeyConstraint,
    String,
    Table,
    text,
)
from sqlalchemy.orm import relationship

from airflow.datasets import Dataset
from airflow.models.base import Base, StringID
from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime


class DatasetModel(Base):
    """
    A table to store datasets.

    :param uri: a string that uniquely identifies the dataset
    :param extra: JSON field for arbitrary extra info
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    uri = Column(
        String(length=3000).with_variant(
            String(
                length=3000,
                # latin1 allows for more indexed length in mysql
                # and this field should only be ascii chars
                collation="latin1_general_cs",
            ),
            "mysql",
        ),
        nullable=False,
    )
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
    is_orphaned = Column(Boolean, default=False, nullable=False, server_default="0")

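    # Reverse sides of the reference tables below: which DAGs consume this dataset in
    # their schedule, and which task outlets produce updates to it.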
    consuming_dags = relationship("DagScheduleDatasetReference", back_populates="dataset")
    producing_tasks = relationship("TaskOutletDatasetReference", back_populates="dataset")

    __tablename__ = "dataset"
    __table_args__ = (
        Index("idx_uri_unique", uri, unique=True),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

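    # Build the ORM row from the user-facing airflow.datasets.Dataset object; only the
    # uri and extra fields carry over.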
    @classmethod
    def from_public(cls, obj: Dataset) -> DatasetModel:
        return cls(uri=obj.uri, extra=obj.extra)

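    # URIs are validated at construction time: they must be plain ASCII, and the
    # ``airflow`` scheme is rejected because it is reserved.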
    def __init__(self, uri: str, **kwargs):
        try:
            uri.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("URI must be ascii")
        parsed = urlsplit(uri)
        if parsed.scheme and parsed.scheme.lower() == "airflow":
            raise ValueError("Scheme `airflow` is reserved.")
        super().__init__(uri=uri, **kwargs)

    def __eq__(self, other):
        if isinstance(other, (self.__class__, Dataset)):
            return self.uri == other.uri
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.uri)

    def __repr__(self):
        return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"

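# A minimal usage sketch for DatasetModel (illustrative only; assumes an open SQLAlchemy
# session named ``session`` and a made-up URI):
#
#     session.add(DatasetModel(uri="s3://bucket/my-key", extra={"owner": "data-team"}))
#     session.flush()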

class DagScheduleDatasetReference(Base):
    """References from a DAG to a dataset of which it is a consumer."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="consuming_dags")
    dag = relationship("DagModel", back_populates="schedule_dataset_references")

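    # Queue entries for this (dataset, dag) pair. There is no direct foreign key from
    # DatasetDagRunQueue back to this table, so the join is spelled out in primaryjoin,
    # and the ORM cascade removes queued entries together with the reference row.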
    queue_records = relationship(
        "DatasetDagRunQueue",
        primaryjoin="""and_(
            DagScheduleDatasetReference.dataset_id == foreign(DatasetDagRunQueue.dataset_id),
            DagScheduleDatasetReference.dag_id == foreign(DatasetDagRunQueue.target_dag_id),
        )""",
        cascade="all, delete, delete-orphan",
    )

    __tablename__ = "dag_schedule_dataset_reference"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, dag_id, name="dsdr_pkey"),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="dsdr_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="dsdr_dag_id_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_dag_schedule_dataset_reference_dag_id", dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"


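# TaskOutletDatasetReference mirrors DagScheduleDatasetReference on the producer side:
# one row per (dataset_id, dag_id, task_id) outlet that updates the dataset.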
class TaskOutletDatasetReference(Base):
    """References from a task to a dataset that it updates / produces."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    task_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="producing_tasks")

    __tablename__ = "task_outlet_dataset_reference"
    __table_args__ = (
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="todr_dataset_fkey",
            ondelete="CASCADE",
        ),
        PrimaryKeyConstraint(dataset_id, dag_id, task_id, name="todr_pkey"),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="todr_dag_id_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_task_outlet_dataset_reference_dag_id", dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self.dataset_id == other.dataset_id
                and self.dag_id == other.dag_id
                and self.task_id == other.task_id
            )
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"


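# Rows below are keyed by (dataset_id, target_dag_id) and represent dataset updates
# still awaiting processing for the target DAG (see the class docstring).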
class DatasetDagRunQueue(Base):
    """Model for storing dataset events that need processing."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    dataset = relationship("DatasetModel", viewonly=True)
    __tablename__ = "dataset_dag_run_queue"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey"),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="ddrq_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (target_dag_id,),
            ["dag.dag_id"],
            name="ddrq_dag_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_dataset_dag_run_queue_target_dag_id", target_dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"


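# Many-to-many association between DagRun and DatasetEvent; it backs the
# DatasetEvent.created_dagruns relationship and the DagRun.consumed_dataset_events backref.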
association_table = Table(
    "dagrun_dataset_event",
    Base.metadata,
    Column("dag_run_id", ForeignKey("dag_run.id", ondelete="CASCADE"), primary_key=True),
    Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True),
    Index("idx_dagrun_dataset_events_dag_run_id", "dag_run_id"),
    Index("idx_dagrun_dataset_events_event_id", "event_id"),
)


class DatasetEvent(Base):
    """
    A table to store dataset events.

    :param dataset_id: reference to DatasetModel record
    :param extra: JSON field for arbitrary extra info
    :param source_task_id: the task_id of the TI which updated the dataset
    :param source_dag_id: the dag_id of the TI which updated the dataset
    :param source_run_id: the run_id of the TI which updated the dataset
    :param source_map_index: the map_index of the TI which updated the dataset
    :param timestamp: the time the event was logged

    We use relationships instead of foreign keys so that dataset events are not deleted even
    if the foreign key object is.
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    dataset_id = Column(Integer, nullable=False)
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    source_task_id = Column(StringID(), nullable=True)
    source_dag_id = Column(StringID(), nullable=True)
    source_run_id = Column(StringID(), nullable=True)
    source_map_index = Column(Integer, nullable=True, server_default=text("-1"))
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_event"
    __table_args__ = (
        Index("idx_dataset_id_timestamp", dataset_id, timestamp),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

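    # Dag runs linked to this event through the dagrun_dataset_event table; the reverse
    # accessor created by the backref is DagRun.consumed_dataset_events.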
    created_dagruns = relationship(
        "DagRun",
        secondary=association_table,
        backref="consumed_dataset_events",
    )

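    # The source_* and dataset relationships below join on plain columns instead of
    # foreign keys (see the class docstring), so an event outlives the task instance,
    # dag run, or dataset row that produced it.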
    source_task_instance = relationship(
        "TaskInstance",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    source_dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
            DatasetEvent.source_run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    dataset = relationship(
        DatasetModel,
        primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)",
        viewonly=True,
        lazy="select",
        uselist=False,
    )

    @property
    def uri(self):
        return self.dataset.uri

    def __repr__(self) -> str:
        args = []
        for attr in [
            "id",
            "dataset_id",
            "extra",
            "source_task_id",
            "source_dag_id",
            "source_run_id",
            "source_map_index",
        ]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"