Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/renderedtifields.py: 51%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Save Rendered Template Fields."""
20from __future__ import annotations
22import os
23from typing import TYPE_CHECKING
25import sqlalchemy_jsonfield
26from sqlalchemy import (
27 Column,
28 ForeignKeyConstraint,
29 Integer,
30 PrimaryKeyConstraint,
31 delete,
32 exists,
33 select,
34 text,
35)
36from sqlalchemy.ext.associationproxy import association_proxy
37from sqlalchemy.orm import relationship
39from airflow.configuration import conf
40from airflow.models.base import StringID, TaskInstanceDependencies
41from airflow.serialization.helpers import serialize_template_field
42from airflow.settings import json
43from airflow.utils.retries import retry_db_transaction
44from airflow.utils.session import NEW_SESSION, provide_session
46if TYPE_CHECKING:
47 from sqlalchemy.orm import Session
48 from sqlalchemy.sql import FromClause
50 from airflow.models import Operator
51 from airflow.models.taskinstance import TaskInstance, TaskInstancePydantic
def get_serialized_template_fields(task: Operator):
    """
    Serialize every template field of a task into a plain mapping.

    The resulting mapping is what gets stored in the RTIF table.

    :param task: Operator instance with rendered template fields
    :return: dict mapping each name listed in ``task.template_fields`` to its serialized value

    :meta private:
    """
    serialized = {}
    for name in task.template_fields:
        value = getattr(task, name)
        serialized[name] = serialize_template_field(value, name)
    return serialized
class RenderedTaskInstanceFields(TaskInstanceDependencies):
    """
    Save Rendered Template Fields.

    Each row is keyed by ``(dag_id, task_id, run_id, map_index)`` and stores the serialized rendered
    template fields of a task instance plus, when available, its rendered Kubernetes pod YAML.
    """

    __tablename__ = "rendered_task_instance_fields"

    dag_id = Column(StringID(), primary_key=True)
    task_id = Column(StringID(), primary_key=True)
    run_id = Column(StringID(), primary_key=True)
    map_index = Column(Integer, primary_key=True, server_default=text("-1"))
    rendered_fields = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
    k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)

    __table_args__ = (
        PrimaryKeyConstraint(
            "dag_id",
            "task_id",
            "run_id",
            "map_index",
            name="rendered_task_instance_fields_pkey",
        ),
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="rtif_ti_fkey",
            ondelete="CASCADE",
        ),
    )
    task_instance = relationship(
        "TaskInstance",
        lazy="joined",
        back_populates="rendered_task_instance_fields",
    )

    # We don't need a DB level FK here, as we already have that to TI (which has one to DR) but by defining
    # the relationship we can more easily find the execution date for these rows
    dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            RenderedTaskInstanceFields.dag_id == foreign(DagRun.dag_id),
            RenderedTaskInstanceFields.run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
    )

    execution_date = association_proxy("dag_run", "execution_date")

    def __init__(self, ti: TaskInstance, render_templates=True, rendered_fields=None):
        """
        Build the record for a task instance.

        :param ti: Task Instance whose key columns are copied onto this row
        :param render_templates: when true, ``ti.render_templates()`` is called before serializing
        :param rendered_fields: already-serialized fields to store; when falsy they are computed from
            ``ti.task`` with ``get_serialized_template_fields``
        """
        self.dag_id = ti.dag_id
        self.task_id = ti.task_id
        self.run_id = ti.run_id
        self.map_index = ti.map_index
        self.ti = ti
        if render_templates:
            ti.render_templates()

        if TYPE_CHECKING:
            assert ti.task

        self.task = ti.task
        if os.environ.get("AIRFLOW_IS_K8S_EXECUTOR_POD"):
            # we can safely import it here from provider. In Airflow 2.7.0+ you need to have new version
            # of kubernetes provider installed to reach this place
            from airflow.providers.cncf.kubernetes.template_rendering import render_k8s_pod_yaml

            self.k8s_pod_yaml = render_k8s_pod_yaml(ti)
        self.rendered_fields = rendered_fields or get_serialized_template_fields(task=ti.task)

        self._redact()

    def __repr__(self):
        prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
        if self.map_index != -1:
            prefix += f" map_index={self.map_index}"
        return prefix + ">"

    def _redact(self):
        """Pass the stored pod YAML and every rendered field through the secrets masker's ``redact``."""
        from airflow.utils.log.secrets_masker import redact

        if self.k8s_pod_yaml:
            self.k8s_pod_yaml = redact(self.k8s_pod_yaml)

        # Assign a fresh dict rather than rewriting the existing one in place, so that a
        # ``rendered_fields`` dict handed to ``__init__`` by the caller is left unmodified.
        self.rendered_fields = {
            field: redact(rendered, field) for field, rendered in self.rendered_fields.items()
        }

    @classmethod
    def _get_for_ti(
        cls, ti: TaskInstance | TaskInstancePydantic, session: Session
    ) -> RenderedTaskInstanceFields | None:
        """
        Fetch the row whose primary key matches the given task instance.

        :param ti: Task Instance
        :param session: SqlAlchemy Session
        :return: the matching row, or None when there is none
        """
        return session.scalar(
            select(cls).where(
                cls.dag_id == ti.dag_id,
                cls.task_id == ti.task_id,
                cls.run_id == ti.run_id,
                cls.map_index == ti.map_index,
            )
        )

    @classmethod
    @provide_session
    def get_templated_fields(
        cls, ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
    ) -> dict | None:
        """
        Get templated field for a TaskInstance from the RenderedTaskInstanceFields table.

        :param ti: Task Instance
        :param session: SqlAlchemy Session
        :return: Rendered Templated TI field, or None when no row exists for the task instance
        """
        result = cls._get_for_ti(ti, session)
        return result.rendered_fields if result else None

    @classmethod
    @provide_session
    def get_k8s_pod_yaml(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None:
        """
        Get rendered Kubernetes Pod Yaml for a TaskInstance from the RenderedTaskInstanceFields table.

        :param ti: Task Instance
        :param session: SqlAlchemy Session
        :return: Kubernetes Pod Yaml, or None when no row exists for the task instance
        """
        result = cls._get_for_ti(ti, session)
        return result.k8s_pod_yaml if result else None

    @provide_session
    @retry_db_transaction
    def write(self, session: Session = NEW_SESSION):
        """
        Write instance to database.

        The instance is merged into the session (``session.merge``).

        :param session: SqlAlchemy Session
        """
        session.merge(self)

    @classmethod
    @provide_session
    def delete_old_records(
        cls,
        task_id: str,
        dag_id: str,
        num_to_keep: int = conf.getint("core", "max_num_rendered_ti_fields_per_task", fallback=0),
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Keep only Last X (num_to_keep) number of records for a task by deleting others.

        In the case of data for a mapped task either all of the rows or none of the rows will be deleted, so
        we don't end up with partial data for a set of mapped Task Instances left in the database.

        Nothing is deleted when ``num_to_keep`` is zero or negative. Note that the default for
        ``num_to_keep`` is read from the configuration once, when this class body is executed.

        :param task_id: Task ID
        :param dag_id: Dag ID
        :param num_to_keep: Number of Records to keep
        :param session: SqlAlchemy Session
        """
        if num_to_keep <= 0:
            return

        from airflow.models.dagrun import DagRun

        # The most recent ``num_to_keep`` distinct (dag_id, task_id, run_id) keys, newest execution date
        # first. ``map_index`` is deliberately not selected, so all rows of a run are kept or dropped
        # together.
        tis_to_keep_query = (
            select(cls.dag_id, cls.task_id, cls.run_id, DagRun.execution_date)
            .where(cls.dag_id == dag_id, cls.task_id == task_id)
            .join(cls.dag_run)
            .distinct()
            .order_by(DagRun.execution_date.desc())
            .limit(num_to_keep)
        )

        cls._do_delete_old_records(
            dag_id=dag_id,
            task_id=task_id,
            ti_clause=tis_to_keep_query.subquery(),
            session=session,
        )
        session.flush()

    @classmethod
    @retry_db_transaction
    def _do_delete_old_records(
        cls,
        *,
        task_id: str,
        dag_id: str,
        ti_clause: FromClause,
        session: Session,
    ) -> None:
        """
        Delete the rows of the given task that have no matching entry in ``ti_clause``.

        :param task_id: Task ID
        :param dag_id: Dag ID
        :param ti_clause: selectable exposing ``dag_id``, ``task_id`` and ``run_id`` of the rows to keep
        :param session: SqlAlchemy Session
        """
        # This query might deadlock occasionally and it should be retried if fails (see decorator)

        stmt = (
            delete(cls)
            .where(
                cls.dag_id == dag_id,
                cls.task_id == task_id,
                ~exists(1).where(
                    ti_clause.c.dag_id == cls.dag_id,
                    ti_clause.c.task_id == cls.task_id,
                    ti_clause.c.run_id == cls.run_id,
                ),
            )
            .execution_options(synchronize_session=False)
        )

        session.execute(stmt)