Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/ti_deps/deps/trigger_rule_dep.py: 15%
175 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections
21import collections.abc
22import functools
23from typing import TYPE_CHECKING, Iterator, NamedTuple
25from sqlalchemy import and_, func, or_
27from airflow.ti_deps.dep_context import DepContext
28from airflow.ti_deps.deps.base_ti_dep import BaseTIDep, TIDepStatus
29from airflow.utils.state import TaskInstanceState
30from airflow.utils.trigger_rule import TriggerRule as TR
32if TYPE_CHECKING:
33 from sqlalchemy.orm import Session
34 from sqlalchemy.sql.expression import ColumnOperators
36 from airflow.models.taskinstance import TaskInstance
class _UpstreamTIStates(NamedTuple):
    """States of the upstream tis for a specific ti.

    This is used to determine whether the specific ti can run in this iteration.
    """

    success: int
    skipped: int
    failed: int
    upstream_failed: int
    removed: int
    done: int

    @classmethod
    def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTIStates:
        """Calculate the state counts of a ti's finished upstream task instances.

        :param finished_upstreams: the finished, relevant upstream task
            instances of the ti we want to calculate deps for
        :return: a tuple of per-state counts plus the total ``done`` count
        """
        # Counter.__getitem__ returns 0 for a missing key without inserting
        # it, so no .get(..., 0) is needed and sum(values()) stays correct.
        counter = collections.Counter(ti.state for ti in finished_upstreams)
        return _UpstreamTIStates(
            success=counter[TaskInstanceState.SUCCESS],
            skipped=counter[TaskInstanceState.SKIPPED],
            failed=counter[TaskInstanceState.FAILED],
            upstream_failed=counter[TaskInstanceState.UPSTREAM_FAILED],
            removed=counter[TaskInstanceState.REMOVED],
            done=sum(counter.values()),
        )
class TriggerRuleDep(BaseTIDep):
    """
    Determines if a task's upstream tasks are in a state that allows a given task instance
    to run.
    """

    NAME = "Trigger Rule"
    IGNORABLE = True
    IS_TASK_DEP = True

    def _get_dep_statuses(
        self,
        ti: TaskInstance,
        session: Session,
        dep_context: DepContext,
    ) -> Iterator[TIDepStatus]:
        """Yield dep statuses describing whether ``ti``'s trigger rule is satisfied.

        :param ti: the task instance to check
        :param session: database session
        :param dep_context: the current dependency context
        """
        # A task with no upstreams trivially satisfies any trigger rule.
        if not ti.task.upstream_task_ids:
            yield self._passing_status(reason="The task instance did not have any upstream tasks.")
            return
        # ALWAYS runs regardless of upstream states; skip the (potentially
        # expensive) evaluation entirely.
        if ti.task.trigger_rule == TR.ALWAYS:
            yield self._passing_status(reason="The task had an always trigger rule set.")
            return
        yield from self._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session)

    def _evaluate_trigger_rule(
        self,
        *,
        ti: TaskInstance,
        dep_context: DepContext,
        session: Session,
    ) -> Iterator[TIDepStatus]:
        """Evaluate whether ``ti``'s trigger rule was met.

        :param ti: Task instance to evaluate the trigger rule of.
        :param dep_context: The current dependency context.
        :param session: Database session.
        """
        from airflow.models.operator import needs_expansion
        from airflow.models.taskinstance import TaskInstance

        task = ti.task
        upstream_tasks = {t.task_id: t for t in task.upstream_list}
        trigger_rule = task.trigger_rule

        @functools.lru_cache()
        def _get_expanded_ti_count() -> int:
            """Get how many tis the current task is supposed to be expanded into.

            This extra closure allows us to query the database only when needed,
            and at most once.
            """
            return task.get_mapped_ti_count(ti.run_id, session=session)

        @functools.lru_cache()
        def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
            """Get the given task's map indexes relevant to the current ti.

            This extra closure allows us to query the database only when needed,
            and at most once for each task (instead of once for each expanded
            task instance of the same task).
            """
            return ti.get_relevant_upstream_map_indexes(
                upstream_tasks[upstream_id],
                _get_expanded_ti_count(),
                session=session,
            )

        def _is_relevant_upstream(upstream: TaskInstance) -> bool:
            """Whether a task instance is a "relevant upstream" of the current task."""
            # Not actually an upstream task.
            if upstream.task_id not in task.upstream_task_ids:
                return False
            # The current task is not in a mapped task group. All tis from an
            # upstream task are relevant.
            if task.get_closest_mapped_task_group() is None:
                return True
            # The upstream ti is not expanded. The upstream may be mapped or
            # not, but the ti is relevant either way.
            if upstream.map_index < 0:
                return True
            # Now we need to perform fine-grained check on whether this specific
            # upstream ti's map index is relevant.
            relevant = _get_relevant_upstream_map_indexes(upstream.task_id)
            if relevant is None:
                return True
            if relevant == upstream.map_index:
                return True
            if isinstance(relevant, collections.abc.Container) and upstream.map_index in relevant:
                return True
            return False

        # Lazily filter the dag run's finished tis down to the relevant
        # upstreams before counting their states.
        finished_upstream_tis = (
            finished_ti
            for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
            if _is_relevant_upstream(finished_ti)
        )
        upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)

        success = upstream_states.success
        skipped = upstream_states.skipped
        failed = upstream_states.failed
        upstream_failed = upstream_states.upstream_failed
        removed = upstream_states.removed
        done = upstream_states.done

        def _iter_upstream_conditions() -> Iterator[ColumnOperators]:
            """Yield SQL filters selecting the upstream tis this ti depends on."""
            # Optimization: If the current task is not in a mapped task group,
            # it depends on all upstream task instances.
            if task.get_closest_mapped_task_group() is None:
                yield TaskInstance.task_id.in_(upstream_tasks)
                return
            # Otherwise we need to figure out which map indexes are depended on
            # for each upstream by the current task instance.
            for upstream_id in upstream_tasks:
                map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
                if map_indexes is None:  # All tis of this upstream are dependencies.
                    yield (TaskInstance.task_id == upstream_id)
                    continue
                # At this point we know we want to depend on only selected tis
                # of this upstream task. Since the upstream may not have been
                # expanded at this point, we also depend on the non-expanded ti
                # to ensure at least one ti is included for the task.
                yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index < 0)
                if isinstance(map_indexes, range) and map_indexes.step == 1:
                    # A contiguous range maps to a simple BETWEEN-style filter.
                    yield and_(
                        TaskInstance.task_id == upstream_id,
                        TaskInstance.map_index >= map_indexes.start,
                        TaskInstance.map_index < map_indexes.stop,
                    )
                elif isinstance(map_indexes, collections.abc.Container):
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index.in_(map_indexes))
                else:
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)

        # Optimization: Don't need to hit the database if all upstreams are
        # "simple" tasks (no task or task group mapping involved).
        if not any(needs_expansion(t) for t in upstream_tasks.values()):
            upstream = len(upstream_tasks)
        else:
            upstream = (
                session.query(func.count())
                .filter(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
                .filter(or_(*_iter_upstream_conditions()))
                .scalar()
            )
        upstream_done = done >= upstream

        # When asked to, propagate upstream failures/skips onto this ti's
        # state before evaluating whether its trigger rule is met.
        changed = False
        if dep_context.flag_upstream_failed:
            if trigger_rule == TR.ALL_SUCCESS:
                if upstream_failed or failed:
                    changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session)
                elif skipped:
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
                elif removed and success and ti.map_index > -1:
                    # An expanded ti whose map index falls beyond the number of
                    # remaining successes has been mapped away.
                    if ti.map_index >= success:
                        changed = ti.set_state(TaskInstanceState.REMOVED, session)
            elif trigger_rule == TR.ALL_FAILED:
                if success or skipped:
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
            elif trigger_rule == TR.ONE_SUCCESS:
                if upstream_done and done == skipped:
                    # if upstream is done and all are skipped mark as skipped
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
                elif upstream_done and success <= 0:
                    # if upstream is done and there are no success mark as upstream failed
                    changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session)
            elif trigger_rule == TR.ONE_FAILED:
                if upstream_done and not (failed or upstream_failed):
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
            elif trigger_rule == TR.ONE_DONE:
                if upstream_done and not (failed or success):
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
            elif trigger_rule == TR.NONE_FAILED:
                if upstream_failed or failed:
                    changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session)
            elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
                if upstream_failed or failed:
                    changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session)
                elif skipped == upstream:
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
            elif trigger_rule == TR.NONE_SKIPPED:
                if skipped:
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
            elif trigger_rule == TR.ALL_SKIPPED:
                if success or failed:
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)

        if changed:
            dep_context.have_changed_ti_states = True

        # Finally, decide pass/fail for the trigger rule itself. No yield for
        # a rule means the dep passes (see BaseTIDep).
        if trigger_rule == TR.ONE_SUCCESS:
            if success <= 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires one upstream task success, "
                        f"but none were found. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ONE_FAILED:
            if not failed and not upstream_failed:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, "
                        f"but none were found. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ONE_DONE:
            if success + failed <= 0:
                # BUGFIX: the original concatenated these fragments without
                # separating spaces/punctuation, yielding a garbled message.
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' "
                        "requires at least one upstream task failure or success, "
                        f"but none were failed or success. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_SUCCESS:
            num_failures = upstream - success
            if ti.map_index > -1:
                # Removed mapped tis don't count as failures for expanded tis.
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_FAILED:
            num_success = upstream - failed - upstream_failed
            if ti.map_index > -1:
                num_success -= removed
            if num_success > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have failed, "
                        f"but found {num_success} non-failure(s). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_DONE:
            if not upstream_done:
                # BUGFIX: the original interpolated the *bool* upstream_done
                # here; report the actual count of unfinished upstream tis.
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"completed, but found {upstream - done} task(s) that were not done. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_FAILED:
            num_failures = upstream - success - skipped
            if ti.map_index > -1:
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded or been skipped, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
            num_failures = upstream - success - skipped
            if ti.map_index > -1:
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded or been skipped, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_SKIPPED:
            if not upstream_done or (skipped > 0):
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not have been "
                        f"skipped, but found {skipped} task(s) skipped. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_SKIPPED:
            num_non_skipped = upstream - skipped
            if num_non_skipped > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been "
                        f"skipped, but found {num_non_skipped} task(s) in non skipped state. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        else:
            yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.")