Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/skipmixin.py: 25%
79 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import warnings
21from typing import TYPE_CHECKING, Iterable, Sequence
23from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
24from airflow.models.taskinstance import TaskInstance
25from airflow.serialization.pydantic.dag_run import DagRunPydantic
26from airflow.utils import timezone
27from airflow.utils.log.logging_mixin import LoggingMixin
28from airflow.utils.session import NEW_SESSION, create_session, provide_session
29from airflow.utils.state import State
31if TYPE_CHECKING:
32 from pendulum import DateTime
33 from sqlalchemy import Session
35 from airflow.models.dagrun import DagRun
36 from airflow.models.operator import Operator
37 from airflow.models.taskmixin import DAGNode
38 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
40# The key used by SkipMixin to store XCom data.
41XCOM_SKIPMIXIN_KEY = "skipmixin_key"
43# The dictionary key used to denote task IDs that are skipped
44XCOM_SKIPMIXIN_SKIPPED = "skipped"
46# The dictionary key used to denote task IDs that are followed
47XCOM_SKIPMIXIN_FOLLOWED = "followed"
def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
    """
    Filter ``nodes`` down to the entries that are operators.

    Anything that is not a ``BaseOperator`` or ``MappedOperator`` instance is
    dropped; the relative order of the remaining nodes is preserved.

    :param nodes: DAG nodes to filter
    :return: a list holding only the operator instances found in ``nodes``
    """
    # NOTE(review): these imports are function-local rather than module-level,
    # presumably to avoid an import cycle -- confirm before hoisting them.
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.mappedoperator import MappedOperator

    operator_types = (BaseOperator, MappedOperator)
    return [node for node in nodes if isinstance(node, operator_types)]
class SkipMixin(LoggingMixin):
    """
    A Mixin to skip Tasks Instances.

    Provides ``skip`` (skip an explicit collection of tasks) and
    ``skip_all_except`` (branching: skip every directly-downstream task that
    is not part of the followed branch). Both record what they did in XCom
    under ``XCOM_SKIPMIXIN_KEY``.
    """

    def _set_state_to_skipped(
        self,
        dag_run: DagRun | DagRunPydantic,
        tasks: Iterable[Operator],
        session: Session,
    ) -> None:
        """
        Used internally to set state of task instances to skipped from the same dag run.

        Issues one bulk UPDATE that sets ``state`` to ``State.SKIPPED`` and both
        ``start_date`` and ``end_date`` to the current UTC time for every task
        instance of ``dag_run`` whose ``task_id`` belongs to ``tasks``. The
        session is not committed here; callers commit.

        :param dag_run: the DagRun whose task instances are updated
        :param tasks: operators whose task instances should be marked skipped
        :param session: db session to use
        """
        # Materialise once: ``tasks`` may be a one-shot iterable, and ``in_()``
        # is given a concrete list rather than a generator.
        task_ids = [task.task_id for task in tasks]
        if not task_ids:
            # Nothing to skip: an UPDATE filtered on an empty IN () matches no rows.
            return

        now = timezone.utcnow()

        session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_run.dag_id,
            TaskInstance.run_id == dag_run.run_id,
            TaskInstance.task_id.in_(task_ids),
        ).update(
            {
                TaskInstance.state: State.SKIPPED,
                TaskInstance.start_date: now,
                TaskInstance.end_date: now,
            },
            # Bulk update; objects already loaded in the session are not refreshed.
            synchronize_session=False,
        )

    @provide_session
    def skip(
        self,
        dag_run: DagRun | DagRunPydantic,
        execution_date: DateTime,
        tasks: Iterable[DAGNode],
        session: Session = NEW_SESSION,
        map_index: int = -1,
    ):
        """
        Sets tasks instances to skipped from the same dag run.

        If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
        so that NotPreviouslySkippedDep knows these tasks should be skipped when they
        are cleared.

        Nodes in ``tasks`` that are not operators are ignored; if no operator
        remains, the method returns without touching the database.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param execution_date: execution_date (deprecated; only used to look up the
            DagRun when ``dag_run`` is not given)
        :param tasks: tasks to skip (not task_ids)
        :param session: db session to use
        :param map_index: map_index of the current task instance
        :raises ValueError: if ``execution_date`` differs from ``dag_run.execution_date``,
            or if no dag run is available
        """
        task_list = _ensure_tasks(tasks)
        if not task_list:
            return

        if execution_date and not dag_run:
            from airflow.models.dagrun import DagRun

            warnings.warn(
                "Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run",
                RemovedInAirflow3Warning,
                stacklevel=2,
            )

            # Legacy path: resolve the DagRun from the first task's dag_id and the
            # execution date. ``.one()`` expects exactly one matching row.
            dag_run = (
                session.query(DagRun)
                .filter(
                    DagRun.dag_id == task_list[0].dag_id,
                    DagRun.execution_date == execution_date,
                )
                .one()
            )
        elif execution_date and dag_run and execution_date != dag_run.execution_date:
            raise ValueError(
                "execution_date has a different value to dag_run.execution_date -- please only pass dag_run"
            )

        if dag_run is None:
            raise ValueError("dag_run is required")

        self._set_state_to_skipped(dag_run, task_list, session)
        session.commit()

        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
        task_id: str | None = getattr(self, "task_id", None)
        if task_id is not None:
            from airflow.models.xcom import XCom

            XCom.set(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in task_list]},
                task_id=task_id,
                dag_id=dag_run.dag_id,
                run_id=dag_run.run_id,
                map_index=map_index,
                session=session,
            )

    @staticmethod
    def _branch_task_ids_to_set(branch_task_ids: None | str | Iterable[str]) -> set[str]:
        """
        Validate ``branch_task_ids`` and normalise it into a set of task IDs.

        :param branch_task_ids: ``None`` (follow nothing), a single task ID, or an
            iterable of task IDs
        :return: the task IDs as a new set (empty for ``None``)
        :raises AirflowException: if the argument is neither ``None``, a string nor
            an iterable, or if any element of the iterable is not a string
        """
        if branch_task_ids is None:
            return set()
        if isinstance(branch_task_ids, str):
            return {branch_task_ids}
        if not isinstance(branch_task_ids, Iterable):
            raise AirflowException(
                "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
                f"but got {type(branch_task_ids).__name__!r}."
            )

        # Materialise exactly once: a one-shot iterator would otherwise be
        # exhausted before the type check runs, and the check must happen
        # before building a set so unhashable entries get the intended error.
        candidates = list(branch_task_ids)
        invalid = [(bti, type(bti).__name__) for bti in candidates if not isinstance(bti, str)]
        if invalid:
            try:
                invalid_task_ids_type: object = set(invalid)
            except TypeError:
                # At least one offending entry is unhashable; report the list as-is.
                invalid_task_ids_type = invalid
            raise AirflowException(
                f"'branch_task_ids' expected all task IDs are strings. "
                f"Invalid tasks found: {invalid_task_ids_type}."
            )
        return set(candidates)

    def skip_all_except(
        self,
        ti: TaskInstance | TaskInstancePydantic,
        branch_task_ids: None | str | Iterable[str],
    ):
        """
        This method implements the logic for a branching operator; given a single
        task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
        newly added tasks should be skipped when they are cleared.

        :param ti: the task instance of the branching task; its ``task`` and dag run are used
        :param branch_task_ids: ``None``, a single task ID, or an iterable of task IDs to follow
        :raises AirflowException: if ``branch_task_ids`` has an unsupported type, contains
            non-string entries, or names task IDs that are not in the DAG
        """
        self.log.info("Following branch %s", branch_task_ids)
        branch_task_id_set = self._branch_task_ids_to_set(branch_task_ids)

        dag_run = ti.get_dagrun()

        # TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
        # pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
        # does not attempt to serialize the field from/to ORM
        task = ti.task  # type: ignore[union-attr]
        dag = task.dag
        if TYPE_CHECKING:
            assert dag

        valid_task_ids = set(dag.task_ids)
        invalid_task_ids = branch_task_id_set - valid_task_ids
        if invalid_task_ids:
            raise AirflowException(
                "'branch_task_ids' must contain only valid task_ids. "
                f"Invalid tasks found: {invalid_task_ids}."
            )

        downstream_tasks = _ensure_tasks(task.downstream_list)

        if downstream_tasks:
            # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
            # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
            # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
            # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
            # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
            # exclude it from skipping.
            #
            # branch  ----->  join
            #   \            ^
            #     v        /
            #       task1
            #
            # Iterate over a snapshot because the set is extended inside the loop.
            for branch_task_id in list(branch_task_id_set):
                branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

            skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_id_set]
            follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]

            self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
            with create_session() as session:
                self._set_state_to_skipped(dag_run, skip_tasks, session=session)
                # For some reason, session.commit() needs to happen before xcom_push.
                # Otherwise the session is not committed.
                session.commit()
                ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})