Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/renderedtifields.py: 51%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

91 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""Save Rendered Template Fields.""" 

19 

20from __future__ import annotations 

21 

22import os 

23from typing import TYPE_CHECKING 

24 

25import sqlalchemy_jsonfield 

26from sqlalchemy import ( 

27 Column, 

28 ForeignKeyConstraint, 

29 Integer, 

30 PrimaryKeyConstraint, 

31 delete, 

32 exists, 

33 select, 

34 text, 

35) 

36from sqlalchemy.ext.associationproxy import association_proxy 

37from sqlalchemy.orm import relationship 

38 

39from airflow.configuration import conf 

40from airflow.models.base import StringID, TaskInstanceDependencies 

41from airflow.serialization.helpers import serialize_template_field 

42from airflow.settings import json 

43from airflow.utils.retries import retry_db_transaction 

44from airflow.utils.session import NEW_SESSION, provide_session 

45 

46if TYPE_CHECKING: 

47 from sqlalchemy.orm import Session 

48 from sqlalchemy.sql import FromClause 

49 

50 from airflow.models import Operator 

51 from airflow.models.taskinstance import TaskInstance, TaskInstancePydantic 

52 

53 

def get_serialized_template_fields(task: Operator):
    """
    Get and serialize the template fields for a task.

    Used in preparing to store them in RTIF table.

    :param task: Operator instance with rendered template fields

    :meta private:
    """
    serialized_fields = {}
    for field_name in task.template_fields:
        rendered_value = getattr(task, field_name)
        serialized_fields[field_name] = serialize_template_field(rendered_value, field_name)
    return serialized_fields

65 

66 

class RenderedTaskInstanceFields(TaskInstanceDependencies):
    """Save Rendered Template Fields."""

    __tablename__ = "rendered_task_instance_fields"

    dag_id = Column(StringID(), primary_key=True)
    task_id = Column(StringID(), primary_key=True)
    run_id = Column(StringID(), primary_key=True)
    map_index = Column(Integer, primary_key=True, server_default=text("-1"))
    rendered_fields = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
    k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)

    __table_args__ = (
        PrimaryKeyConstraint(
            "dag_id",
            "task_id",
            "run_id",
            "map_index",
            name="rendered_task_instance_fields_pkey",
        ),
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="rtif_ti_fkey",
            ondelete="CASCADE",
        ),
    )
    task_instance = relationship(
        "TaskInstance",
        lazy="joined",
        back_populates="rendered_task_instance_fields",
    )

    # We don't need a DB level FK here, as we already have that to TI (which has one to DR) but by defining
    # the relationship we can more easily find the execution date for these rows
    dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            RenderedTaskInstanceFields.dag_id == foreign(DagRun.dag_id),
            RenderedTaskInstanceFields.run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
    )

    execution_date = association_proxy("dag_run", "execution_date")

    def __init__(self, ti: TaskInstance, render_templates=True, rendered_fields=None):
        """
        Record the rendered template fields of *ti* for later storage in the DB.

        :param ti: the task instance whose identity (dag_id/task_id/run_id/map_index)
            and rendered fields this row captures
        :param render_templates: when True, render the TI's templates before reading them
        :param rendered_fields: pre-serialized fields; when falsy they are computed
            from ``ti.task.template_fields``
        """
        self.dag_id = ti.dag_id
        self.task_id = ti.task_id
        self.run_id = ti.run_id
        self.map_index = ti.map_index
        self.ti = ti
        if render_templates:
            ti.render_templates()

        if TYPE_CHECKING:
            assert ti.task

        self.task = ti.task
        if os.environ.get("AIRFLOW_IS_K8S_EXECUTOR_POD", None):
            # we can safely import it here from provider. In Airflow 2.7.0+ you need to have new version
            # of kubernetes provider installed to reach this place
            from airflow.providers.cncf.kubernetes.template_rendering import render_k8s_pod_yaml

            self.k8s_pod_yaml = render_k8s_pod_yaml(ti)
        self.rendered_fields = rendered_fields or get_serialized_template_fields(task=ti.task)

        self._redact()

    def __repr__(self):
        prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
        if self.map_index != -1:
            prefix += f" map_index={self.map_index}"
        return prefix + ">"

    def _redact(self):
        """Mask secrets in the stored pod yaml and rendered fields in place."""
        from airflow.utils.log.secrets_masker import redact

        if self.k8s_pod_yaml:
            self.k8s_pod_yaml = redact(self.k8s_pod_yaml)

        for field, rendered in self.rendered_fields.items():
            self.rendered_fields[field] = redact(rendered, field)

    @classmethod
    @provide_session
    def get_templated_fields(
        cls, ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
    ) -> dict | None:
        """
        Get templated field for a TaskInstance from the RenderedTaskInstanceFields table.

        :param ti: Task Instance
        :param session: SqlAlchemy Session
        :return: Rendered Templated TI field
        """
        result = session.scalar(
            select(cls).where(
                cls.dag_id == ti.dag_id,
                cls.task_id == ti.task_id,
                cls.run_id == ti.run_id,
                cls.map_index == ti.map_index,
            )
        )
        # Mirrors get_k8s_pod_yaml: None when no row exists for this TI.
        return result.rendered_fields if result else None

    @classmethod
    @provide_session
    def get_k8s_pod_yaml(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None:
        """
        Get rendered Kubernetes Pod Yaml for a TaskInstance from the RenderedTaskInstanceFields table.

        :param ti: Task Instance
        :param session: SqlAlchemy Session
        :return: Kubernetes Pod Yaml
        """
        result = session.scalar(
            select(cls).where(
                cls.dag_id == ti.dag_id,
                cls.task_id == ti.task_id,
                cls.run_id == ti.run_id,
                cls.map_index == ti.map_index,
            )
        )
        return result.k8s_pod_yaml if result else None

    @provide_session
    @retry_db_transaction
    def write(self, session: Session = NEW_SESSION):
        """Write instance to database.

        :param session: SqlAlchemy Session
        """
        # merge (not add): a row for this (dag_id, task_id, run_id, map_index) may
        # already exist from a previous try and must be overwritten.
        session.merge(self)

    @classmethod
    @provide_session
    def delete_old_records(
        cls,
        task_id: str,
        dag_id: str,
        num_to_keep: int = conf.getint("core", "max_num_rendered_ti_fields_per_task", fallback=0),
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Keep only Last X (num_to_keep) number of records for a task by deleting others.

        In the case of data for a mapped task either all of the rows or none of the rows will be deleted, so
        we don't end up with partial data for a set of mapped Task Instances left in the database.

        :param task_id: Task ID
        :param dag_id: Dag ID
        :param num_to_keep: Number of Records to keep
        :param session: SqlAlchemy Session
        """
        if num_to_keep <= 0:
            # 0 (the config fallback) or negative means "keep everything".
            return

        from airflow.models.dagrun import DagRun

        # The newest num_to_keep runs for this task, by execution date; rows outside
        # this set are deleted. `distinct` collapses mapped TIs into one run entry so
        # a run's mapped rows are kept or deleted together.
        tis_to_keep_query = (
            select(cls.dag_id, cls.task_id, cls.run_id, DagRun.execution_date)
            .where(cls.dag_id == dag_id, cls.task_id == task_id)
            .join(cls.dag_run)
            .distinct()
            .order_by(DagRun.execution_date.desc())
            .limit(num_to_keep)
        )

        cls._do_delete_old_records(
            dag_id=dag_id,
            task_id=task_id,
            ti_clause=tis_to_keep_query.subquery(),
            session=session,
        )
        session.flush()

    @classmethod
    @retry_db_transaction
    def _do_delete_old_records(
        cls,
        *,
        task_id: str,
        dag_id: str,
        ti_clause: FromClause,
        session: Session,
    ) -> None:
        """Delete this task's RTIF rows whose run is not in *ti_clause* (the keep-set).

        :param task_id: Task ID
        :param dag_id: Dag ID
        :param ti_clause: subquery of the (dag_id, task_id, run_id) rows to keep
        :param session: SqlAlchemy Session
        """
        # This query might deadlock occasionally and it should be retried if fails (see decorator)

        stmt = (
            delete(cls)
            .where(
                cls.dag_id == dag_id,
                cls.task_id == task_id,
                ~exists(1).where(
                    ti_clause.c.dag_id == cls.dag_id,
                    ti_clause.c.task_id == cls.task_id,
                    ti_clause.c.run_id == cls.run_id,
                ),
            )
            .execution_options(synchronize_session=False)
        )

        session.execute(stmt)