Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/skipmixin.py: 24%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import warnings
21from typing import TYPE_CHECKING, Iterable, Sequence
23from sqlalchemy import select, update
25from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
26from airflow.models.taskinstance import TaskInstance
27from airflow.utils import timezone
28from airflow.utils.log.logging_mixin import LoggingMixin
29from airflow.utils.session import NEW_SESSION, create_session, provide_session
30from airflow.utils.sqlalchemy import tuple_in_condition
31from airflow.utils.state import TaskInstanceState
33if TYPE_CHECKING:
34 from pendulum import DateTime
35 from sqlalchemy import Session
37 from airflow.models.dagrun import DagRun
38 from airflow.models.operator import Operator
39 from airflow.models.taskmixin import DAGNode
40 from airflow.serialization.pydantic.dag_run import DagRunPydantic
41 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
43# The key used by SkipMixin to store XCom data.
44XCOM_SKIPMIXIN_KEY = "skipmixin_key"
46# The dictionary key used to denote task IDs that are skipped
47XCOM_SKIPMIXIN_SKIPPED = "skipped"
49# The dictionary key used to denote task IDs that are followed
50XCOM_SKIPMIXIN_FOLLOWED = "followed"
def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
    """Filter *nodes* down to concrete operators (plain or mapped), dropping anything else."""
    # Imported lazily to avoid import cycles at module load time.
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.mappedoperator import MappedOperator

    operators: list[Operator] = []
    for node in nodes:
        if isinstance(node, (BaseOperator, MappedOperator)):
            operators.append(node)
    return operators
class SkipMixin(LoggingMixin):
    """Mixin providing helpers to mark task instances of the current DAG run as skipped."""

    def _set_state_to_skipped(
        self,
        dag_run: DagRun | DagRunPydantic,
        tasks: Sequence[str] | Sequence[tuple[str, int]],
        session: Session,
    ) -> None:
        """Set state of task instances to skipped from the same dag run.

        :param dag_run: run whose task instances are updated
        :param tasks: either plain task IDs, or ``(task_id, map_index)`` pairs
        :param session: ORM session used for the bulk UPDATE
        """
        if not tasks:
            return
        now = timezone.utcnow()
        # Pick the matching predicate based on whether map indexes were supplied.
        if isinstance(tasks[0], tuple):
            match_clause = tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), tasks)
        else:
            match_clause = TaskInstance.task_id.in_(tasks)
        session.execute(
            update(TaskInstance)
            .where(
                TaskInstance.dag_id == dag_run.dag_id,
                TaskInstance.run_id == dag_run.run_id,
                match_clause,
            )
            .values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
            .execution_options(synchronize_session=False)
        )

    @provide_session
    def skip(
        self,
        dag_run: DagRun | DagRunPydantic,
        execution_date: DateTime,
        tasks: Iterable[DAGNode],
        session: Session = NEW_SESSION,
        map_index: int = -1,
    ):
        """
        Set tasks instances to skipped from the same dag run.

        If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
        so that NotPreviouslySkippedDep knows these tasks should be skipped when they
        are cleared.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param execution_date: execution_date
        :param tasks: tasks to skip (not task_ids)
        :param session: db session to use
        :param map_index: map_index of the current task instance
        """
        to_skip = _ensure_tasks(tasks)
        if not to_skip:
            return

        if execution_date and not dag_run:
            # Deprecated path: resolve the DagRun from the execution date.
            from airflow.models.dagrun import DagRun

            warnings.warn(
                "Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run",
                RemovedInAirflow3Warning,
                stacklevel=2,
            )
            dag_run = session.scalars(
                select(DagRun).where(
                    DagRun.dag_id == to_skip[0].dag_id, DagRun.execution_date == execution_date
                )
            ).one()
        elif execution_date and dag_run and execution_date != dag_run.execution_date:
            raise ValueError(
                "execution_date has a different value to dag_run.execution_date -- please only pass dag_run"
            )

        if dag_run is None:
            raise ValueError("dag_run is required")

        skipped_ids = [task.task_id for task in to_skip]
        self._set_state_to_skipped(dag_run, skipped_ids, session)
        session.commit()

        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
        my_task_id: str | None = getattr(self, "task_id", None)
        if my_task_id is not None:
            from airflow.models.xcom import XCom

            XCom.set(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_SKIPPED: skipped_ids},
                task_id=my_task_id,
                dag_id=dag_run.dag_id,
                run_id=dag_run.run_id,
                map_index=map_index,
                session=session,
            )

    def skip_all_except(
        self,
        ti: TaskInstance | TaskInstancePydantic,
        branch_task_ids: None | str | Iterable[str],
    ):
        """
        Implement the logic for a branching operator.

        Given a single task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
        newly added tasks should be skipped when they are cleared.
        """
        self.log.info("Following branch %s", branch_task_ids)

        # Normalise branch_task_ids into a set of task IDs, validating the input type.
        # The str check must come first: a str is itself an Iterable of characters.
        if isinstance(branch_task_ids, str):
            chosen_ids = {branch_task_ids}
        elif isinstance(branch_task_ids, Iterable):
            chosen_ids = set(branch_task_ids)
            badly_typed = {
                (bti, type(bti).__name__) for bti in branch_task_ids if not isinstance(bti, str)
            }
            if badly_typed:
                raise AirflowException(
                    f"'branch_task_ids' expected all task IDs are strings. "
                    f"Invalid tasks found: {badly_typed}."
                )
        elif branch_task_ids is None:
            chosen_ids = set()
        else:
            raise AirflowException(
                "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
                f"but got {type(branch_task_ids).__name__!r}."
            )

        dag_run = ti.get_dagrun()
        if TYPE_CHECKING:
            assert isinstance(dag_run, DagRun)
            assert ti.task

        # TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
        # pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
        # does not attempt to serialize the field from/to ORM
        task = ti.task
        dag = task.dag
        if TYPE_CHECKING:
            assert dag

        unknown_ids = chosen_ids - set(dag.task_ids)
        if unknown_ids:
            raise AirflowException(
                "'branch_task_ids' must contain only valid task_ids. "
                f"Invalid tasks found: {unknown_ids}."
            )

        immediate_downstream = _ensure_tasks(task.downstream_list)
        if not immediate_downstream:
            return

        # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
        # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
        # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
        # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
        # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
        # exclude it from skipping.
        #
        # branch  ----->  join
        #   \               ^
        #    v             /
        #     task1  ----->
        #
        for chosen_id in list(chosen_ids):
            chosen_ids.update(dag.get_task(chosen_id).get_flat_relative_ids(upstream=False))

        to_skip: list[tuple[str, int]] = []
        for downstream in immediate_downstream:
            downstream_ti = dag_run.get_task_instance(downstream.task_id, map_index=ti.map_index)
            if downstream_ti and downstream.task_id not in chosen_ids:
                to_skip.append((downstream.task_id, downstream_ti.map_index))

        to_follow = [t.task_id for t in immediate_downstream if t.task_id in chosen_ids]

        self.log.info("Skipping tasks %s", to_skip)
        with create_session() as session:
            self._set_state_to_skipped(dag_run, to_skip, session=session)
            # For some reason, session.commit() needs to happen before xcom_push.
            # Otherwise the session is not committed.
            session.commit()
            ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: to_follow})