Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/dataset.py: 65%


124 statements  

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from urllib.parse import urlsplit

import sqlalchemy_jsonfield
from sqlalchemy import (
    Boolean,
    Column,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PrimaryKeyConstraint,
    String,
    Table,
    text,
)
from sqlalchemy.orm import relationship

from airflow.datasets import Dataset
from airflow.models.base import Base, StringID
from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime


class DatasetModel(Base):
    """
    A table to store datasets.

    :param uri: a string that uniquely identifies the dataset
    :param extra: JSON field for arbitrary extra info
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    uri = Column(
        String(length=3000).with_variant(
            String(
                length=3000,
                # latin1 allows for more indexed length in mysql
                # and this field should only be ascii chars
                collation="latin1_general_cs",
            ),
            "mysql",
        ),
        nullable=False,
    )
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
    is_orphaned = Column(Boolean, default=False, nullable=False, server_default="0")

    consuming_dags = relationship("DagScheduleDatasetReference", back_populates="dataset")
    producing_tasks = relationship("TaskOutletDatasetReference", back_populates="dataset")

    __tablename__ = "dataset"
    __table_args__ = (
        Index("idx_uri_unique", uri, unique=True),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    @classmethod
    def from_public(cls, obj: Dataset) -> DatasetModel:
        return cls(uri=obj.uri, extra=obj.extra)

    def __init__(self, uri: str, **kwargs):
        try:
            uri.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("URI must be ascii")
        parsed = urlsplit(uri)
        if parsed.scheme and parsed.scheme.lower() == "airflow":
            raise ValueError("Scheme `airflow` is reserved.")
        super().__init__(uri=uri, **kwargs)

    def __eq__(self, other):
        if isinstance(other, (self.__class__, Dataset)):
            return self.uri == other.uri
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.uri)

    def __repr__(self):
        return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
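

# Usage sketch (illustrative, not part of the original module): DatasetModel
# validates its URI eagerly in __init__, so bad values fail before anything
# touches the database.
#
#     DatasetModel(uri="s3://bucket/key")      # accepted
#     DatasetModel(uri="s3://bucket/café")     # ValueError: URI must be ascii
#     DatasetModel(uri="airflow://anything")   # ValueError: Scheme `airflow` is reserved.
#
# Equality is URI-based and also accepts the public Dataset class, so
# DatasetModel(uri="s3://x") == Dataset("s3://x") evaluates to True.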



class DagScheduleDatasetReference(Base):
    """References from a DAG to a dataset of which it is a consumer."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="consuming_dags")
    dag = relationship("DagModel", back_populates="schedule_dataset_references")

    queue_records = relationship(
        "DatasetDagRunQueue",
        primaryjoin="""and_(
            DagScheduleDatasetReference.dataset_id == foreign(DatasetDagRunQueue.dataset_id),
            DagScheduleDatasetReference.dag_id == foreign(DatasetDagRunQueue.target_dag_id),
        )""",
        cascade="all, delete, delete-orphan",
    )

    __tablename__ = "dag_schedule_dataset_reference"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, dag_id, name="dsdr_pkey"),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="dsdr_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="dsdr_dag_id_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_dag_schedule_dataset_reference_dag_id", dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
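

# Deletion-cascade sketch for the queue_records relationship above (assumes an
# active SQLAlchemy session named `session`): dataset_dag_run_queue has no real
# foreign key back to this table, so the join is spelled out with foreign() in
# a string primaryjoin, and cleanup happens through the ORM-level cascade.
#
#     ref = session.get(DagScheduleDatasetReference, (1, "consumer_dag"))
#     session.delete(ref)  # matching DatasetDagRunQueue rows are deleted too
#     session.commit()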



class TaskOutletDatasetReference(Base):
    """References from a task to a dataset that it updates / produces."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    task_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="producing_tasks")

    __tablename__ = "task_outlet_dataset_reference"
    __table_args__ = (
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="todr_dataset_fkey",
            ondelete="CASCADE",
        ),
        PrimaryKeyConstraint(dataset_id, dag_id, task_id, name="todr_pkey"),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="todr_dag_id_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_task_outlet_dataset_reference_dag_id", dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self.dataset_id == other.dataset_id
                and self.dag_id == other.dag_id
                and self.task_id == other.task_id
            )
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"



class DatasetDagRunQueue(Base):
    """Model for storing dataset events that need processing."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    dataset = relationship("DatasetModel", viewonly=True)
    __tablename__ = "dataset_dag_run_queue"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey"),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="ddrq_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (target_dag_id,),
            ["dag.dag_id"],
            name="ddrq_dag_fkey",
            ondelete="CASCADE",
        ),
        Index("idx_dataset_dag_run_queue_target_dag_id", target_dag_id),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
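

# Enqueue sketch (illustrative, assumes a `session` and an existing
# DatasetModel with id=1): a pending run request for a consuming DAG is one row
# in this queue table; the composite primary key enforces at most one pending
# row per (dataset, target DAG) pair.
#
#     session.add(DatasetDagRunQueue(dataset_id=1, target_dag_id="consumer_dag"))
#     session.commit()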



association_table = Table(
    "dagrun_dataset_event",
    Base.metadata,
    Column("dag_run_id", ForeignKey("dag_run.id", ondelete="CASCADE"), primary_key=True),
    Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True),
    Index("idx_dagrun_dataset_events_dag_run_id", "dag_run_id"),
    Index("idx_dagrun_dataset_events_event_id", "event_id"),
)
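

# The dagrun_dataset_event table above is the many-to-many link between dag
# runs and the dataset events that triggered them; the ORM surfaces it through
# DatasetEvent.created_dagruns below and, via that relationship's backref, as
# DagRun.consumed_dataset_events.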



class DatasetEvent(Base):
    """
    A table to store dataset events.

    :param dataset_id: reference to DatasetModel record
    :param extra: JSON field for arbitrary extra info
    :param source_task_id: the task_id of the TI which updated the dataset
    :param source_dag_id: the dag_id of the TI which updated the dataset
    :param source_run_id: the run_id of the TI which updated the dataset
    :param source_map_index: the map_index of the TI which updated the dataset
    :param timestamp: the time the event was logged

    We use relationships instead of foreign keys so that dataset events are not deleted even
    if the foreign key object is.
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    dataset_id = Column(Integer, nullable=False)
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    source_task_id = Column(StringID(), nullable=True)
    source_dag_id = Column(StringID(), nullable=True)
    source_run_id = Column(StringID(), nullable=True)
    source_map_index = Column(Integer, nullable=True, server_default=text("-1"))
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_event"
    __table_args__ = (
        Index("idx_dataset_id_timestamp", dataset_id, timestamp),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    created_dagruns = relationship(
        "DagRun",
        secondary=association_table,
        backref="consumed_dataset_events",
    )

    source_task_instance = relationship(
        "TaskInstance",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    source_dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
            DatasetEvent.source_run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    dataset = relationship(
        DatasetModel,
        primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)",
        viewonly=True,
        lazy="select",
        uselist=False,
    )

    @property
    def uri(self):
        return self.dataset.uri

    def __repr__(self) -> str:
        args = []
        for attr in [
            "id",
            "dataset_id",
            "extra",
            "source_task_id",
            "source_dag_id",
            "source_run_id",
            "source_map_index",
        ]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
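

# Sketch of the deliberately FK-free design described in the DatasetEvent
# docstring (illustrative, assumes a `session`; dag/run/task ids are made up):
# the source_* columns are plain columns, so an event survives deletion of the
# task instance or dag run that produced it, and the viewonly relationships
# simply load as None once the source rows are gone.
#
#     event = DatasetEvent(
#         dataset_id=1,
#         source_dag_id="producer",
#         source_run_id="run_1",
#         source_task_id="emit",
#     )
#     session.add(event)
#     session.commit()
#     # Deleting the producing DagRun later leaves `event` intact;
#     # event.source_dag_run then resolves to None.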