Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/dataset.py: 65%

122 statements  

coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from urllib.parse import urlsplit

import sqlalchemy_jsonfield
from sqlalchemy import (
    Boolean,
    Column,
    ForeignKey,
    ForeignKeyConstraint,
    Index,
    Integer,
    PrimaryKeyConstraint,
    String,
    Table,
    text,
)
from sqlalchemy.orm import relationship

from airflow.datasets import Dataset
from airflow.models.base import Base, StringID
from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime


class DatasetModel(Base):
    """
    A table to store datasets.

    :param uri: a string that uniquely identifies the dataset
    :param extra: JSON field for arbitrary extra info
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    uri = Column(
        String(length=3000).with_variant(
            String(
                length=3000,
                # latin1 allows for more indexed length in mysql
                # and this field should only be ascii chars
                collation="latin1_general_cs",
            ),
            "mysql",
        ),
        nullable=False,
    )
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
    is_orphaned = Column(Boolean, default=False, nullable=False, server_default="0")

    consuming_dags = relationship("DagScheduleDatasetReference", back_populates="dataset")
    producing_tasks = relationship("TaskOutletDatasetReference", back_populates="dataset")

    __tablename__ = "dataset"
    __table_args__ = (
        Index("idx_uri_unique", uri, unique=True),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    @classmethod
    def from_public(cls, obj: Dataset) -> DatasetModel:
        return cls(uri=obj.uri, extra=obj.extra)

    def __init__(self, uri: str, **kwargs):
        try:
            uri.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("URI must be ascii")
        parsed = urlsplit(uri)
        if parsed.scheme and parsed.scheme.lower() == "airflow":
            raise ValueError("Scheme `airflow` is reserved.")
        super().__init__(uri=uri, **kwargs)

    def __eq__(self, other):
        if isinstance(other, (self.__class__, Dataset)):
            return self.uri == other.uri
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.uri)

    def __repr__(self):
        return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
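
# Illustrative sketch (editor's addition, not part of the original module): how a DatasetModel
# row is typically built from the public ``airflow.datasets.Dataset`` object, and what the URI
# validation in ``__init__`` above rejects. ``session`` is assumed to be a SQLAlchemy ORM
# session bound to the Airflow metadata database.
#
#     from airflow.datasets import Dataset
#
#     dataset = Dataset(uri="s3://my-bucket/daily-report.csv", extra={"owner": "data-eng"})
#     session.add(DatasetModel.from_public(dataset))
#
#     DatasetModel(uri="s3://bücket/key")     # raises ValueError: URI must be ascii
#     DatasetModel(uri="airflow://anything")  # raises ValueError: Scheme `airflow` is reserved.
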

class DagScheduleDatasetReference(Base):
    """References from a DAG to a dataset of which it is a consumer."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="consuming_dags")
    queue_records = relationship(
        "DatasetDagRunQueue",
        primaryjoin="""and_(
            DagScheduleDatasetReference.dataset_id == foreign(DatasetDagRunQueue.dataset_id),
            DagScheduleDatasetReference.dag_id == foreign(DatasetDagRunQueue.target_dag_id),
        )""",
        cascade="all, delete, delete-orphan",
    )

    __tablename__ = "dag_schedule_dataset_reference"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, dag_id, name="dsdr_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="dsdr_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="dsdr_dag_id_fkey",
            ondelete="CASCADE",
        ),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
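
# Illustrative sketch (editor's addition, an assumption rather than Airflow's own scheduler
# code): listing the DAGs that consume a given dataset URI through the reference table above.
# ``session`` is again assumed to be an ORM session on the metadata database.
#
#     consumer_dag_ids = [
#         ref.dag_id
#         for ref in (
#             session.query(DagScheduleDatasetReference)
#             .join(DatasetModel, DagScheduleDatasetReference.dataset_id == DatasetModel.id)
#             .filter(DatasetModel.uri == "s3://my-bucket/daily-report.csv")
#         )
#     ]
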

class TaskOutletDatasetReference(Base):
    """References from a task to a dataset that it updates / produces."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    task_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    dataset = relationship("DatasetModel", back_populates="producing_tasks")

    __tablename__ = "task_outlet_dataset_reference"
    __table_args__ = (
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="todr_dataset_fkey",
            ondelete="CASCADE",
        ),
        PrimaryKeyConstraint(dataset_id, dag_id, task_id, name="todr_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            columns=(dag_id,),
            refcolumns=["dag.dag_id"],
            name="todr_dag_id_fkey",
            ondelete="CASCADE",
        ),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return (
                self.dataset_id == other.dataset_id
                and self.dag_id == other.dag_id
                and self.task_id == other.task_id
            )
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
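
# Illustrative sketch (editor's addition): the producer side of the same lookup, walking from a
# dataset to the (dag_id, task_id) pairs that update it via the ``producing_tasks`` relationship
# declared on DatasetModel. ``dataset_model`` is assumed to be a DatasetModel instance already
# loaded in a session.
#
#     producers = [(ref.dag_id, ref.task_id) for ref in dataset_model.producing_tasks]
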

class DatasetDagRunQueue(Base):
    """Model for storing dataset events that need processing."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_dag_run_queue"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey", mssql_clustered=True),
        ForeignKeyConstraint(
            (dataset_id,),
            ["dataset.id"],
            name="ddrq_dataset_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (target_dag_id,),
            ["dag.dag_id"],
            name="ddrq_dag_fkey",
            ondelete="CASCADE",
        ),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
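
# Illustrative sketch (editor's addition, a simplified assumption about how this queue table is
# fed, not the scheduler's actual code): when a dataset is updated, one queue row per consuming
# DAG can be recorded; the composite primary key (dataset_id, target_dag_id) keeps the enqueue
# idempotent per dataset/DAG pair.
#
#     for ref in dataset_model.consuming_dags:
#         session.merge(DatasetDagRunQueue(dataset_id=dataset_model.id, target_dag_id=ref.dag_id))
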

association_table = Table(
    "dagrun_dataset_event",
    Base.metadata,
    Column("dag_run_id", ForeignKey("dag_run.id", ondelete="CASCADE"), primary_key=True),
    Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True),
    Index("idx_dagrun_dataset_events_dag_run_id", "dag_run_id"),
    Index("idx_dagrun_dataset_events_event_id", "event_id"),
)


class DatasetEvent(Base):
    """
    A table to store dataset events.

    :param dataset_id: reference to DatasetModel record
    :param extra: JSON field for arbitrary extra info
    :param source_task_id: the task_id of the TI which updated the dataset
    :param source_dag_id: the dag_id of the TI which updated the dataset
    :param source_run_id: the run_id of the TI which updated the dataset
    :param source_map_index: the map_index of the TI which updated the dataset
    :param timestamp: the time the event was logged

    We use relationships instead of foreign keys so that dataset events are not deleted even
    if the foreign key object is.
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    dataset_id = Column(Integer, nullable=False)
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
    source_task_id = Column(StringID(), nullable=True)
    source_dag_id = Column(StringID(), nullable=True)
    source_run_id = Column(StringID(), nullable=True)
    source_map_index = Column(Integer, nullable=True, server_default=text("-1"))
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_event"
    __table_args__ = (
        Index("idx_dataset_id_timestamp", dataset_id, timestamp),
        {"sqlite_autoincrement": True},  # ensures PK values not reused
    )

    created_dagruns = relationship(
        "DagRun",
        secondary=association_table,
        backref="consumed_dataset_events",
    )

    source_task_instance = relationship(
        "TaskInstance",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    source_dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
            DatasetEvent.source_run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    dataset = relationship(
        DatasetModel,
        primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)",
        viewonly=True,
        lazy="select",
        uselist=False,
    )

    @property
    def uri(self):
        return self.dataset.uri

    def __repr__(self) -> str:
        args = []
        for attr in [
            "id",
            "dataset_id",
            "extra",
            "source_task_id",
            "source_dag_id",
            "source_run_id",
            "source_map_index",
        ]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
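
# Illustrative sketch (editor's addition, not the module's own code): recording a dataset event
# for a producing task instance and attaching it to a triggered run via the
# ``dagrun_dataset_event`` association table declared above. ``dataset_model``, ``ti`` and
# ``dag_run`` are assumed to be existing ORM objects in ``session``.
#
#     event = DatasetEvent(
#         dataset_id=dataset_model.id,
#         extra={"rows_written": 42},
#         source_task_id=ti.task_id,
#         source_dag_id=ti.dag_id,
#         source_run_id=ti.run_id,
#         source_map_index=ti.map_index,
#     )
#     session.add(event)
#     event.created_dagruns.append(dag_run)  # writes a row to dagrun_dataset_event on flush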