#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from collections.abc import Iterable, Sequence
from types import GeneratorType
from typing import TYPE_CHECKING

from airflow.providers.common.compat.sdk import AirflowException
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
    from airflow.sdk.definitions._internal.node import DAGNode
    from airflow.sdk.types import Operator, RuntimeTaskInstanceProtocol

# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"

# The dictionary key used to denote task IDs that are skipped
XCOM_SKIPMIXIN_SKIPPED = "skipped"

# The dictionary key used to denote task IDs that are followed
XCOM_SKIPMIXIN_FOLLOWED = "followed"
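
# For illustration, the value stored under XCOM_SKIPMIXIN_KEY takes one of two shapes
# (the task IDs below are examples):
#   {"skipped": ["task_a", "task_b"]}   # pushed by SkipMixin.skip()
#   {"followed": ["task_c"]}            # pushed by SkipMixin.skip_all_except()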


def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
    """Filter the given nodes down to operators, i.e. BaseOperator or MappedOperator instances."""
    from airflow.providers.common.compat.sdk import BaseOperator, MappedOperator

    return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))]


# This class should only be used in Airflow 3.0 and later.
class SkipMixin(LoggingMixin):
    """A mixin to skip task instances."""

    @staticmethod
    def _set_state_to_skipped(
        tasks: Sequence[str | tuple[str, int]],
        map_index: int | None,
    ) -> None:
        """
        Set the state of task instances from the same dag run to skipped.

        Raises
        ------
        DownstreamTasksSkipped
            If there are tasks to skip and the current task instance is not mapped.
        """
        # The import is kept local for backward compatibility when importing PythonOperator
        # from airflow.providers.common.compat.standard.operators
        from airflow.providers.common.compat.sdk import DownstreamTasksSkipped

        # The following can be applied only to non-mapped tasks,
        # as future mapped tasks have not been expanded yet. Such tasks
        # have to be handled by NotPreviouslySkippedDep.
        if tasks and map_index == -1:
            raise DownstreamTasksSkipped(tasks=tasks)

    def skip(
        self,
        ti: RuntimeTaskInstanceProtocol,
        tasks: Iterable[DAGNode],
    ):
        """
        Set task instances from the same dag run to skipped.

        If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
        so that NotPreviouslySkippedDep knows these tasks should be skipped when they
        are cleared.

        :param ti: the task instance for which to set the tasks to skipped
        :param tasks: the tasks to skip (not task IDs)
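
        For example (an illustrative call), an operator can skip every task immediately
        downstream of itself with::

            self.skip(ti, ti.task.downstream_list)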
88 """
        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
        task_id: str | None = getattr(self, "task_id", None)
        task_list = _ensure_tasks(tasks)
        if not task_list:
            return

        task_ids_list = [d.task_id for d in task_list]

        if task_id is not None:
            ti.xcom_push(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
            )

        self._set_state_to_skipped(task_ids_list, ti.map_index)

    def skip_all_except(
        self,
        ti: RuntimeTaskInstanceProtocol,
        branch_task_ids: None | str | Iterable[str],
    ):
        """
        Implement the logic for a branching operator.

        Given a single task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
        newly added tasks should be skipped when they are cleared.

        :param ti: the task instance of this branching operator
        :param branch_task_ids: the task ID(s) to follow; if None, every task immediately
            downstream of this operator is skipped
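
        For example (illustrative task IDs), a branching operator whose callable chooses
        ``"task_a"`` ends up calling::

            self.skip_all_except(ti, "task_a")

        which skips every other task immediately downstream of this operator.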
118 """
        # Ensure we don't serialize a generator object
        if branch_task_ids and isinstance(branch_task_ids, GeneratorType):
            branch_task_ids = list(branch_task_ids)
        log = self.log  # Note: fetch the logger from the instance; a static logger breaks pytest
        if isinstance(branch_task_ids, str):
            branch_task_id_set = {branch_task_ids}
        elif isinstance(branch_task_ids, Iterable):
            # Handle the case where invalid values are passed as elements of an Iterable.
            # Non-string values are considered invalid elements.
            branch_task_id_set = set(branch_task_ids)
            invalid_task_ids_type = {
                (bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str)
            }
            if invalid_task_ids_type:
                raise AirflowException(
                    f"Unable to branch to the specified tasks. "
                    f"The branching function returned invalid 'branch_task_ids': {invalid_task_ids_type}. "
                    f"Please check that your function returns an Iterable of valid task IDs that exist in your DAG."
                )
        elif branch_task_ids is None:
            branch_task_id_set = set()
        else:
            raise AirflowException(
                "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
                f"but got {type(branch_task_ids).__name__!r}."
            )

        log.info("Following branch %s", branch_task_id_set)

        if TYPE_CHECKING:
            assert ti.task

        task = ti.task
        dag = ti.task.dag

        valid_task_ids = set(dag.task_ids)
        invalid_task_ids = branch_task_id_set - valid_task_ids
        if invalid_task_ids:
            raise AirflowException(
                "'branch_task_ids' must contain only valid task_ids. "
                f"Invalid tasks found: {invalid_task_ids}."
            )

        downstream_tasks = _ensure_tasks(task.downstream_list)

        if downstream_tasks:
            # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
            # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
            # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
            # we need a special case here for such empty branches: check the downstream tasks of branch_task_ids.
            # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
            # exclude it from skipping.
            #
            # branch  ----->  join
            #   \            ^
            #     v        /
            #       task1
            #
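            # Expand branch_task_id_set with every task downstream of the chosen branches
            # (see the diagram above), so shared join tasks are followed rather than skipped.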
            for branch_task_id in list(branch_task_id_set):
                branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

            skip_tasks = [
                (t.task_id, ti.map_index) for t in downstream_tasks if t.task_id not in branch_task_id_set
            ]

            follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
            log.info("Skipping tasks %s", skip_tasks)
            ti.xcom_push(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids},
            )
            # The following can be applied only to non-mapped tasks,
            # as future mapped tasks have not been expanded yet. Such tasks
            # have to be handled by NotPreviouslySkippedDep.
            self._set_state_to_skipped(skip_tasks, ti.map_index)  # type: ignore[arg-type]