Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/ti_deps/deps/trigger_rule_dep.py: 14%
215 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections
21import collections.abc
22import functools
23from typing import TYPE_CHECKING, Iterator, NamedTuple
25from sqlalchemy import and_, func, or_
27from airflow.models import MappedOperator
28from airflow.models.taskinstance import PAST_DEPENDS_MET
29from airflow.ti_deps.dep_context import DepContext
30from airflow.ti_deps.deps.base_ti_dep import BaseTIDep, TIDepStatus
31from airflow.utils.state import TaskInstanceState
32from airflow.utils.trigger_rule import TriggerRule as TR
34if TYPE_CHECKING:
35 from sqlalchemy.orm import Session
36 from sqlalchemy.sql.expression import ColumnOperators
38 from airflow.models.taskinstance import TaskInstance
class _UpstreamTIStates(NamedTuple):
    """Aggregated states of the upstream tis for a specific ti.

    Used to decide whether the specific ti can run in the current iteration.
    """

    success: int
    skipped: int
    failed: int
    upstream_failed: int
    removed: int
    done: int
    success_setup: int
    skipped_setup: int

    @classmethod
    def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTIStates:
        """Tally upstream task-instance states for a task instance.

        The overall counters are inclusive of the setup counters -- e.g. with
        2 skipped upstreams, one of which is a setup, ``skipped`` is 2 and
        ``skipped_setup`` is 1.

        :param finished_upstreams: the finished, relevant upstream tis.
        """
        all_states: dict[str, int] = collections.Counter()
        setup_states: dict[str, int] = collections.Counter()
        for upstream_ti in finished_upstreams:
            all_states[upstream_ti.state] += 1
            # A setup task cannot be mapped, so only unmapped operators are
            # checked for the is_setup flag.
            if not isinstance(upstream_ti.task, MappedOperator) and upstream_ti.task.is_setup:
                setup_states[upstream_ti.state] += 1
        return _UpstreamTIStates(
            success=all_states.get(TaskInstanceState.SUCCESS, 0),
            skipped=all_states.get(TaskInstanceState.SKIPPED, 0),
            failed=all_states.get(TaskInstanceState.FAILED, 0),
            upstream_failed=all_states.get(TaskInstanceState.UPSTREAM_FAILED, 0),
            removed=all_states.get(TaskInstanceState.REMOVED, 0),
            done=sum(all_states.values()),
            success_setup=setup_states.get(TaskInstanceState.SUCCESS, 0),
            skipped_setup=setup_states.get(TaskInstanceState.SKIPPED, 0),
        )
class TriggerRuleDep(BaseTIDep):
    """
    Determines if a task's upstream tasks are in a state that allows a given task instance
    to run.
    """

    NAME = "Trigger Rule"
    IGNORABLE = True
    IS_TASK_DEP = True

    def _get_dep_statuses(
        self,
        ti: TaskInstance,
        session: Session,
        dep_context: DepContext,
    ) -> Iterator[TIDepStatus]:
        # Checking that all upstream dependencies have succeeded.
        if not ti.task.upstream_task_ids:
            yield self._passing_status(reason="The task instance did not have any upstream tasks.")
            return
        if ti.task.trigger_rule == TR.ALWAYS:
            yield self._passing_status(reason="The task had an always trigger rule set.")
            return
        yield from self._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session)

    def _evaluate_trigger_rule(
        self,
        *,
        ti: TaskInstance,
        dep_context: DepContext,
        session: Session,
    ) -> Iterator[TIDepStatus]:
        """Evaluate whether ``ti``'s trigger rule was met.

        :param ti: Task instance to evaluate the trigger rule of.
        :param dep_context: The current dependency context.
        :param session: Database session.
        """
        from airflow.models.abstractoperator import NotMapped
        from airflow.models.expandinput import NotFullyPopulated
        from airflow.models.operator import needs_expansion
        from airflow.models.taskinstance import TaskInstance

        task = ti.task
        upstream_tasks = {t.task_id: t for t in task.upstream_list}
        trigger_rule = task.trigger_rule

        @functools.lru_cache
        def _get_expanded_ti_count() -> int:
            """Get how many tis the current task is supposed to be expanded into.

            This extra closure allows us to query the database only when needed,
            and at most once.
            """
            return task.get_mapped_ti_count(ti.run_id, session=session)

        @functools.lru_cache
        def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
            """Get the given task's map indexes relevant to the current ti.

            This extra closure allows us to query the database only when needed,
            and at most once for each task (instead of once for each expanded
            task instance of the same task).
            """
            try:
                expanded_ti_count = _get_expanded_ti_count()
            except (NotFullyPopulated, NotMapped):
                return None
            return ti.get_relevant_upstream_map_indexes(
                upstream_tasks[upstream_id],
                expanded_ti_count,
                session=session,
            )

        def _is_relevant_upstream(upstream: TaskInstance) -> bool:
            """Whether a task instance is a "relevant upstream" of the current task."""
            # Not actually an upstream task.
            if upstream.task_id not in task.upstream_task_ids:
                return False
            # The current task is not in a mapped task group. All tis from an
            # upstream task are relevant.
            if task.get_closest_mapped_task_group() is None:
                return True
            # The upstream ti is not expanded. The upstream may be mapped or
            # not, but the ti is relevant either way.
            if upstream.map_index < 0:
                return True
            # Now we need to perform fine-grained check on whether this specific
            # upstream ti's map index is relevant.
            relevant = _get_relevant_upstream_map_indexes(upstream.task_id)
            if relevant is None:
                return True
            if relevant == upstream.map_index:
                return True
            if isinstance(relevant, collections.abc.Container) and upstream.map_index in relevant:
                return True
            return False

        finished_upstream_tis = (
            finished_ti
            for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
            if _is_relevant_upstream(finished_ti)
        )
        upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)

        success = upstream_states.success
        skipped = upstream_states.skipped
        failed = upstream_states.failed
        upstream_failed = upstream_states.upstream_failed
        removed = upstream_states.removed
        done = upstream_states.done
        success_setup = upstream_states.success_setup
        skipped_setup = upstream_states.skipped_setup

        def _iter_upstream_conditions() -> Iterator[ColumnOperators]:
            # Optimization: If the current task is not in a mapped task group,
            # it depends on all upstream task instances.
            if task.get_closest_mapped_task_group() is None:
                yield TaskInstance.task_id.in_(upstream_tasks)
                return
            # Otherwise we need to figure out which map indexes are depended on
            # for each upstream by the current task instance.
            for upstream_id in upstream_tasks:
                map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
                if map_indexes is None:  # All tis of this upstream are dependencies.
                    yield (TaskInstance.task_id == upstream_id)
                    continue
                # At this point we know we want to depend on only selected tis
                # of this upstream task. Since the upstream may not have been
                # expanded at this point, we also depend on the non-expanded ti
                # to ensure at least one ti is included for the task.
                yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index < 0)
                if isinstance(map_indexes, range) and map_indexes.step == 1:
                    yield and_(
                        TaskInstance.task_id == upstream_id,
                        TaskInstance.map_index >= map_indexes.start,
                        TaskInstance.map_index < map_indexes.stop,
                    )
                elif isinstance(map_indexes, collections.abc.Container):
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index.in_(map_indexes))
                else:
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)

        # Optimization: Don't need to hit the database if all upstreams are
        # "simple" tasks (no task or task group mapping involved).
        if not any(needs_expansion(t) for t in upstream_tasks.values()):
            upstream = len(upstream_tasks)
            upstream_setup = len(
                [x for x in upstream_tasks.values() if not isinstance(x, MappedOperator) and x.is_setup]
            )
        else:
            upstream = (
                session.query(func.count())
                .filter(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
                .filter(or_(*_iter_upstream_conditions()))
                .scalar()
            )
            # todo: add support for mapped setup?
            upstream_setup = None
        upstream_done = done >= upstream

        changed = False
        new_state = None
        if dep_context.flag_upstream_failed:
            if trigger_rule == TR.ALL_SUCCESS:
                if upstream_failed or failed:
                    new_state = TaskInstanceState.UPSTREAM_FAILED
                elif skipped:
                    new_state = TaskInstanceState.SKIPPED
                elif removed and success and ti.map_index > -1:
                    if ti.map_index >= success:
                        new_state = TaskInstanceState.REMOVED
            elif trigger_rule == TR.ALL_FAILED:
                if success or skipped:
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.ONE_SUCCESS:
                if upstream_done and done == skipped:
                    # if upstream is done and all are skipped mark as skipped
                    new_state = TaskInstanceState.SKIPPED
                elif upstream_done and success <= 0:
                    # if upstream is done and there are no success mark as upstream failed
                    new_state = TaskInstanceState.UPSTREAM_FAILED
            elif trigger_rule == TR.ONE_FAILED:
                if upstream_done and not (failed or upstream_failed):
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.ONE_DONE:
                if upstream_done and not (failed or success):
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.NONE_FAILED:
                if upstream_failed or failed:
                    new_state = TaskInstanceState.UPSTREAM_FAILED
            elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
                if upstream_failed or failed:
                    new_state = TaskInstanceState.UPSTREAM_FAILED
                elif skipped == upstream:
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.NONE_SKIPPED:
                if skipped:
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.ALL_SKIPPED:
                if success or failed:
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
                if upstream_done and upstream_setup and skipped_setup >= upstream_setup:
                    # when there is an upstream setup and they have all skipped, then skip
                    new_state = TaskInstanceState.SKIPPED
                elif upstream_done and upstream_setup and success_setup == 0:
                    # when there is an upstream setup, if none succeeded, mark upstream failed
                    # if at least one setup ran, we'll let it run
                    new_state = TaskInstanceState.UPSTREAM_FAILED
        if new_state is not None:
            if new_state == TaskInstanceState.SKIPPED and dep_context.wait_for_past_depends_before_skipping:
                past_depends_met = ti.xcom_pull(
                    task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
                )
                if not past_depends_met:
                    yield self._failing_status(
                        reason=("Task should be skipped but the past depends are not met")
                    )
                    return
            changed = ti.set_state(new_state, session)

        if changed:
            dep_context.have_changed_ti_states = True

        if trigger_rule == TR.ONE_SUCCESS:
            if success <= 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires one upstream task success, "
                        f"but none were found. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ONE_FAILED:
            if not failed and not upstream_failed:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, "
                        f"but none were found. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ONE_DONE:
            if success + failed <= 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' "
                        "requires at least one upstream task failure or success "
                        f"but none were failed or success. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_SUCCESS:
            num_failures = upstream - success
            if ti.map_index > -1:
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_FAILED:
            num_success = upstream - failed - upstream_failed
            if ti.map_index > -1:
                num_success -= removed
            if num_success > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have failed, "
                        f"but found {num_success} non-failure(s). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_DONE:
            if not upstream_done:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_FAILED:
            num_failures = upstream - success - skipped
            if ti.map_index > -1:
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded or been skipped, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
            num_failures = upstream - success - skipped
            if ti.map_index > -1:
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded or been skipped, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_SKIPPED:
            if not upstream_done or (skipped > 0):
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not have been "
                        f"skipped, but found {skipped} task(s) skipped. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_SKIPPED:
            num_non_skipped = upstream - skipped
            if num_non_skipped > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been "
                        f"skipped, but found {num_non_skipped} task(s) in non skipped state. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
            if not upstream_done:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
            elif upstream_setup is None:  # for now, None only happens in mapped case
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' cannot have mapped tasks as upstream. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
            elif upstream_setup and not success_setup >= 1:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires at least one upstream setup task be "
                        f"successful, but found {upstream_setup - success_setup} task(s) that were not. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        else:
            yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.")