Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/dataset.py: 65%

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from urllib.parse import urlsplit

import sqlalchemy_jsonfield
from sqlalchemy import (
    Boolean,
    Column,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PrimaryKeyConstraint,
    String,
    Table,
    text,
)
from sqlalchemy.orm import relationship

from airflow.datasets import Dataset
from airflow.models.base import Base, StringID
from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime


class DatasetModel(Base):
    """
    A table to store datasets.

    :param uri: a string that uniquely identifies the dataset
    :param extra: JSON field for arbitrary extra info
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    uri = Column(
        String(length=3000).with_variant(
            String(
                length=3000,
                # latin1 allows for more indexed length in mysql
                # and this field should only be ascii chars
                collation="latin1_general_cs",
            ),
            "mysql",
        ),
        nullable=False,
    )
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
    is_orphaned = Column(Boolean, default=False, nullable=False, server_default="0")

    consuming_dags = relationship("DagScheduleDatasetReference", back_populates="dataset")
    producing_tasks = relationship("TaskOutletDatasetReference", back_populates="dataset")

    __tablename__ = "dataset"
    __table_args__ = (
        Index("idx_uri_unique", uri, unique=True),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    @classmethod
    def from_public(cls, obj: Dataset) -> DatasetModel:
        return cls(uri=obj.uri, extra=obj.extra)

    def __init__(self, uri: str, **kwargs):
        try:
            uri.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("URI must be ascii")
        parsed = urlsplit(uri)
        if parsed.scheme and parsed.scheme.lower() == "airflow":
            raise ValueError("Scheme `airflow` is reserved.")
        super().__init__(uri=uri, **kwargs)

    def __eq__(self, other):
        if isinstance(other, (self.__class__, Dataset)):
            return self.uri == other.uri
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.uri)

    def __repr__(self):
        return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
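

# Illustrative sketch, not part of the upstream module: what the URI handling in
# DatasetModel.__init__ above accepts and rejects, plus from_public(). All
# concrete values are made up; only behaviour visible in the class body is assumed.
def _example_dataset_model_validation():
    model = DatasetModel(uri="s3://bucket/key", extra={"team": "data"})
    assert model.uri == "s3://bucket/key"

    # from_public() copies uri/extra from the public airflow.datasets.Dataset object.
    model_from_public = DatasetModel.from_public(Dataset("s3://bucket/key"))
    assert model_from_public == model  # __eq__ compares on uri only

    try:
        DatasetModel(uri="airflow://internal")  # reserved scheme
    except ValueError as err:
        print(err)  # Scheme `airflow` is reserved.

    try:
        DatasetModel(uri="s3://bücket/key")  # non-ascii characters
    except ValueError as err:
        print(err)  # URI must be ascii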


class DagScheduleDatasetReference(Base):
    """References from a DAG to a dataset of which it is a consumer."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="consuming_dags")
    dag = relationship("DagModel", back_populates="schedule_dataset_references")

    queue_records = relationship(
        "DatasetDagRunQueue",
        primaryjoin="""and_(
            DagScheduleDatasetReference.dataset_id == foreign(DatasetDagRunQueue.dataset_id),
            DagScheduleDatasetReference.dag_id == foreign(DatasetDagRunQueue.target_dag_id),
        )""",
        cascade="all, delete, delete-orphan",
    )

    __tablename__ = "dag_schedule_dataset_reference"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, dag_id, name="dsdr_pkey"),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="dsdr_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="dsdr_dag_id_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_dag_schedule_dataset_reference_dag_id", dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
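

# Illustrative sketch, not part of the upstream module: the table above holds one
# (dataset_id, dag_id) pair per dataset a DAG consumes, as produced when a DAG is
# declared with `schedule=[Dataset(...)]`. The sketch below simply writes such a
# pair by hand; `session` is a plain SQLAlchemy session and the identifiers are made up.
def _example_dag_schedule_reference(session):
    dataset = session.query(DatasetModel).filter_by(uri="s3://bucket/key").one()
    ref = DagScheduleDatasetReference(dataset_id=dataset.id, dag_id="waiting_dag")
    session.add(ref)
    return ref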


class TaskOutletDatasetReference(Base):
    """References from a task to a dataset that it updates / produces."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    task_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="producing_tasks")

    __tablename__ = "task_outlet_dataset_reference"
    __table_args__ = (
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="todr_dataset_fkey",
            ondelete="CASCADE",
        ),
        PrimaryKeyConstraint(dataset_id, dag_id, task_id, name="todr_pkey"),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="todr_dag_id_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_task_outlet_dataset_reference_dag_id", dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self.dataset_id == other.dataset_id
                and self.dag_id == other.dag_id
                and self.task_id == other.task_id
            )
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"


class DatasetDagRunQueue(Base):
    """Model for storing dataset events that need processing."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    dataset = relationship("DatasetModel", viewonly=True)
    __tablename__ = "dataset_dag_run_queue"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey"),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="ddrq_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (target_dag_id,),
            ["dag.dag_id"],
            name="ddrq_dag_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_dataset_dag_run_queue_target_dag_id", target_dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
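

# Illustrative sketch, not part of the upstream module: the kind of lookup a
# consumer of the queue table above could run -- which target dag_ids currently
# have unprocessed dataset events. `session` is assumed to be an ordinary
# SQLAlchemy session; no Airflow-internal API is implied.
def _example_queued_target_dag_ids(session):
    rows = session.query(DatasetDagRunQueue.target_dag_id).distinct().all()
    return [dag_id for (dag_id,) in rows]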


association_table = Table(
    "dagrun_dataset_event",
    Base.metadata,
    Column("dag_run_id", ForeignKey("dag_run.id", ondelete="CASCADE"), primary_key=True),
    Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True),
    Index("idx_dagrun_dataset_events_dag_run_id", "dag_run_id"),
    Index("idx_dagrun_dataset_events_event_id", "event_id"),
)
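

# Illustrative sketch, not part of the upstream module: the association table above
# backs the many-to-many link between DagRun and DatasetEvent that is declared
# further down via `DatasetEvent.created_dagruns` (with a `consumed_dataset_events`
# backref on DagRun). `dag_run` is assumed to be a DagRun instance loaded elsewhere.
def _example_consumed_event_uris(dag_run):
    return [event.uri for event in dag_run.consumed_dataset_events]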


class DatasetEvent(Base):
    """
    A table to store dataset events.

    :param dataset_id: reference to DatasetModel record
    :param extra: JSON field for arbitrary extra info
    :param source_task_id: the task_id of the TI which updated the dataset
    :param source_dag_id: the dag_id of the TI which updated the dataset
    :param source_run_id: the run_id of the TI which updated the dataset
    :param source_map_index: the map_index of the TI which updated the dataset
    :param timestamp: the time the event was logged

    We use relationships instead of foreign keys so that dataset events are not deleted even
    if the foreign key object is.
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    dataset_id = Column(Integer, nullable=False)
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    source_task_id = Column(StringID(), nullable=True)
    source_dag_id = Column(StringID(), nullable=True)
    source_run_id = Column(StringID(), nullable=True)
    source_map_index = Column(Integer, nullable=True, server_default=text("-1"))
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_event"
    __table_args__ = (
        Index("idx_dataset_id_timestamp", dataset_id, timestamp),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    created_dagruns = relationship(
        "DagRun",
        secondary=association_table,
        backref="consumed_dataset_events",
    )

    source_task_instance = relationship(
        "TaskInstance",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    source_dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
            DatasetEvent.source_run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    dataset = relationship(
        DatasetModel,
        primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)",
        viewonly=True,
        lazy="select",
        uselist=False,
    )

    @property
    def uri(self):
        return self.dataset.uri

    def __repr__(self) -> str:
        args = []
        for attr in [
            "id",
            "dataset_id",
            "extra",
            "source_task_id",
            "source_dag_id",
            "source_run_id",
            "source_map_index",
        ]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
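

# Illustrative sketch, not part of the upstream module: constructing an event row
# for a dataset that a task instance just updated, using only the columns defined
# on DatasetEvent above. `session`, `dataset_model`, and `task_instance` are
# assumed inputs; the extra payload is made up.
def _example_record_dataset_event(session, dataset_model, task_instance):
    event = DatasetEvent(
        dataset_id=dataset_model.id,
        extra={"rows_written": 42},
        source_task_id=task_instance.task_id,
        source_dag_id=task_instance.dag_id,
        source_run_id=task_instance.run_id,
        source_map_index=task_instance.map_index,
    )
    session.add(event)
    return event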