#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import collections.abc
import functools
from collections import Counter
from typing import TYPE_CHECKING, Iterator, KeysView, NamedTuple

from sqlalchemy import and_, func, or_, select

from airflow.models.taskinstance import PAST_DEPENDS_MET
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.state import TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule as TR

if TYPE_CHECKING:
    from sqlalchemy.orm import Session
    from sqlalchemy.sql.expression import ColumnOperators

    from airflow import DAG
    from airflow.models.taskinstance import TaskInstance
    from airflow.ti_deps.dep_context import DepContext
    from airflow.ti_deps.deps.base_ti_dep import TIDepStatus


class _UpstreamTIStates(NamedTuple):
43 """States of the upstream tis for a specific ti.
44
45 This is used to determine whether the specific ti can run in this iteration.
46 """
47
48 success: int
49 skipped: int
50 failed: int
51 upstream_failed: int
52 removed: int
53 done: int
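    # The *_setup counts below only include upstream tasks that are setup tasks.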
    success_setup: int
    skipped_setup: int

    @classmethod
    def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTIStates:
        """Calculate upstream states for a task instance.

        ``counter`` is inclusive of ``setup_counter`` -- e.g. if there are 2 skipped upstreams, one
        of which is a setup, then counter will show 2 skipped and setup counter will show 1.

        :param finished_upstreams: all the finished upstream task instances of the dag run
        """
        counter: dict[str, int] = Counter()
        setup_counter: dict[str, int] = Counter()
        for ti in finished_upstreams:
            if TYPE_CHECKING:
                assert ti.task
            curr_state = {ti.state: 1}
            counter.update(curr_state)
            if ti.task.is_setup:
                setup_counter.update(curr_state)
        return _UpstreamTIStates(
            success=counter.get(TaskInstanceState.SUCCESS, 0),
            skipped=counter.get(TaskInstanceState.SKIPPED, 0),
            failed=counter.get(TaskInstanceState.FAILED, 0),
            upstream_failed=counter.get(TaskInstanceState.UPSTREAM_FAILED, 0),
            removed=counter.get(TaskInstanceState.REMOVED, 0),
            done=sum(counter.values()),
            success_setup=setup_counter.get(TaskInstanceState.SUCCESS, 0),
            skipped_setup=setup_counter.get(TaskInstanceState.SKIPPED, 0),
        )


class TriggerRuleDep(BaseTIDep):
    """Determines if a task's upstream tasks are in a state that allows a given task instance to run."""

    NAME = "Trigger Rule"
    IGNORABLE = True
    IS_TASK_DEP = True

    def _get_dep_statuses(
        self,
        ti: TaskInstance,
        session: Session,
        dep_context: DepContext,
    ) -> Iterator[TIDepStatus]:
        if TYPE_CHECKING:
            assert ti.task

        # Checking that all upstream dependencies have succeeded.
        if not ti.task.upstream_task_ids:
            yield self._passing_status(reason="The task instance did not have any upstream tasks.")
            return
        if ti.task.trigger_rule == TR.ALWAYS:
            yield self._passing_status(reason="The task had an 'always' trigger rule set.")
            return
        yield from self._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session)

    def _evaluate_trigger_rule(
        self,
        *,
        ti: TaskInstance,
        dep_context: DepContext,
        session: Session,
    ) -> Iterator[TIDepStatus]:
        """Evaluate whether ``ti``'s trigger rule was met.

        :param ti: Task instance to evaluate the trigger rule of.
        :param dep_context: The current dependency context.
        :param session: Database session.
        """
        from airflow.models.abstractoperator import NotMapped
        from airflow.models.expandinput import NotFullyPopulated
        from airflow.models.taskinstance import TaskInstance

        @functools.lru_cache
        def _get_expanded_ti_count() -> int:
            """Get how many tis the current task is supposed to be expanded into.

            This extra closure allows us to query the database only when needed,
            and at most once.
            """
            if TYPE_CHECKING:
                assert ti.task

            return ti.task.get_mapped_ti_count(ti.run_id, session=session)

        @functools.lru_cache
        def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
            """Get the given task's map indexes relevant to the current ti.

            This extra closure allows us to query the database only when needed,
            and at most once for each task (instead of once for each expanded
            task instance of the same task).
            """
            if TYPE_CHECKING:
                assert ti.task
                assert isinstance(ti.task.dag, DAG)

            try:
                expanded_ti_count = _get_expanded_ti_count()
            except (NotFullyPopulated, NotMapped):
                return None
            return ti.get_relevant_upstream_map_indexes(
                upstream=ti.task.dag.task_dict[upstream_id],
                ti_count=expanded_ti_count,
                session=session,
            )

        def _is_relevant_upstream(upstream: TaskInstance, relevant_ids: set[str] | KeysView[str]) -> bool:
165 """
166 Whether a task instance is a "relevant upstream" of the current task.
167
168 This will return false if upstream.task_id is not in relevant_ids,
169 or if both of the following are true:
170 1. upstream.task_id in relevant_ids is True
171 2. ti is in a mapped task group and upstream has a map index
172 that ti does not depend on.
173 """
            if TYPE_CHECKING:
                assert ti.task

            # Not actually an upstream task.
            if upstream.task_id not in relevant_ids:
                return False
            # The current task is not in a mapped task group. All tis from an
            # upstream task are relevant.
            if ti.task.get_closest_mapped_task_group() is None:
                return True
            # The upstream ti is not expanded. The upstream may be mapped or
            # not, but the ti is relevant either way.
            if upstream.map_index < 0:
                return True
            # Now we need to perform fine-grained check on whether this specific
            # upstream ti's map index is relevant.
            relevant = _get_relevant_upstream_map_indexes(upstream_id=upstream.task_id)
            if relevant is None:
                return True
            if relevant == upstream.map_index:
                return True
            if isinstance(relevant, collections.abc.Container) and upstream.map_index in relevant:
                return True
            return False

        def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators]:
            # Optimization: If the current task is not in a mapped task group,
            # it depends on all upstream task instances.
            from airflow.models.taskinstance import TaskInstance

            if TYPE_CHECKING:
                assert ti.task

            if ti.task.get_closest_mapped_task_group() is None:
                yield TaskInstance.task_id.in_(relevant_tasks.keys())
                return
            # Otherwise we need to figure out which map indexes are depended on
            # for each upstream by the current task instance.
            for upstream_id in relevant_tasks:
                map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
                if map_indexes is None:  # All tis of this upstream are dependencies.
                    yield (TaskInstance.task_id == upstream_id)
                    continue
                # At this point we know we want to depend on only selected tis
                # of this upstream task. Since the upstream may not have been
                # expanded at this point, we also depend on the non-expanded ti
                # to ensure at least one ti is included for the task.
                yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index < 0)
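                # A contiguous range of map indexes (step 1) is expressed as two
                # inequalities; any other container falls back to an IN clause, and a
                # single index to a plain equality check.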
                if isinstance(map_indexes, range) and map_indexes.step == 1:
                    yield and_(
                        TaskInstance.task_id == upstream_id,
                        TaskInstance.map_index >= map_indexes.start,
                        TaskInstance.map_index < map_indexes.stop,
                    )
                elif isinstance(map_indexes, collections.abc.Container):
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index.in_(map_indexes))
                else:
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)

        def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus, bool]]:
234 """Evaluate whether ``ti``'s trigger rule was met.
235
236 :param ti: Task instance to evaluate the trigger rule of.
237 :param dep_context: The current dependency context.
238 :param session: Database session.
239 """
            if TYPE_CHECKING:
                assert ti.task

            task = ti.task

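            # Only setups that are *not* direct upstreams of this task are checked here;
            # direct upstream setups are covered by the regular trigger rule evaluation
            # in ``_evaluate_direct_relatives``.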
            indirect_setups = {k: v for k, v in relevant_setups.items() if k not in task.upstream_task_ids}
            finished_upstream_tis = (
                x
                for x in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
                if _is_relevant_upstream(upstream=x, relevant_ids=indirect_setups.keys())
            )
            upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)

            # all of these counts reflect indirect setups which are relevant for this ti
            success = upstream_states.success
            skipped = upstream_states.skipped
            failed = upstream_states.failed
            upstream_failed = upstream_states.upstream_failed
            removed = upstream_states.removed

            # Optimization: Don't need to hit the database if all upstreams are
            # "simple" tasks (no task or task group mapping involved).
            if not any(t.get_needs_expansion() for t in indirect_setups.values()):
                upstream = len(indirect_setups)
            else:
                task_id_counts = session.execute(
                    select(TaskInstance.task_id, func.count(TaskInstance.task_id))
                    .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
                    .where(or_(*_iter_upstream_conditions(relevant_tasks=indirect_setups)))
                    .group_by(TaskInstance.task_id)
                ).all()
                upstream = sum(count for _, count in task_id_counts)

            new_state = None
            changed = False

            # if there's a failure, we mark upstream_failed; if there's a skip, we mark skipped
            # in either case, we don't wait for all relevant setups to complete
            if dep_context.flag_upstream_failed:
                if upstream_failed or failed:
                    new_state = TaskInstanceState.UPSTREAM_FAILED
                elif skipped:
                    new_state = TaskInstanceState.SKIPPED
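                # If some upstreams were removed and this mapped ti's map index is not
                # covered by the successful upstreams, the ti itself is marked as removed.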
                elif removed and success and ti.map_index > -1:
                    if ti.map_index >= success:
                        new_state = TaskInstanceState.REMOVED

            if new_state is not None:
                if (
                    new_state == TaskInstanceState.SKIPPED
                    and dep_context.wait_for_past_depends_before_skipping
                ):
                    past_depends_met = ti.xcom_pull(
                        task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
                    )
                    if not past_depends_met:
                        yield (
                            self._failing_status(
                                reason="Task should be skipped but the past depends are not met"
                            ),
                            changed,
                        )
                        return
                changed = ti.set_state(new_state, session)

            if changed:
                dep_context.have_changed_ti_states = True

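            # Any relevant setup that did not succeed fails this constraint; for mapped
            # tis, removed upstream tis are not counted against it.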
            non_successes = upstream - success
            if ti.map_index > -1:
                non_successes -= removed
            if non_successes > 0:
                yield (
                    self._failing_status(
                        reason=(
                            f"All setup tasks must complete successfully. Relevant setups: {relevant_setups}: "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        ),
                    ),
                    changed,
                )

        def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
324 """Evaluate whether ``ti``'s trigger rule was met.
325
326 :param ti: Task instance to evaluate the trigger rule of.
327 :param dep_context: The current dependency context.
328 :param session: Database session.
329 """
            if TYPE_CHECKING:
                assert ti.task

            task = ti.task
            upstream_tasks = {t.task_id: t for t in task.upstream_list}
            trigger_rule = task.trigger_rule

            finished_upstream_tis = (
                finished_ti
                for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
                if _is_relevant_upstream(upstream=finished_ti, relevant_ids=ti.task.upstream_task_ids)
            )
            upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)

            success = upstream_states.success
            skipped = upstream_states.skipped
            failed = upstream_states.failed
            upstream_failed = upstream_states.upstream_failed
            removed = upstream_states.removed
            done = upstream_states.done
            success_setup = upstream_states.success_setup
            skipped_setup = upstream_states.skipped_setup

            # Optimization: Don't need to hit the database if all upstreams are
            # "simple" tasks (no task or task group mapping involved).
            if not any(t.get_needs_expansion() for t in upstream_tasks.values()):
                upstream = len(upstream_tasks)
                upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
            else:
                task_id_counts = session.execute(
                    select(TaskInstance.task_id, func.count(TaskInstance.task_id))
                    .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
                    .where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks)))
                    .group_by(TaskInstance.task_id)
                ).all()
                upstream = sum(count for _, count in task_id_counts)
                upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup)

            upstream_done = done >= upstream

            changed = False
            new_state = None
            if dep_context.flag_upstream_failed:
                if trigger_rule == TR.ALL_SUCCESS:
                    if upstream_failed or failed:
                        new_state = TaskInstanceState.UPSTREAM_FAILED
                    elif skipped:
                        new_state = TaskInstanceState.SKIPPED
                    elif removed and success and ti.map_index > -1:
                        if ti.map_index >= success:
                            new_state = TaskInstanceState.REMOVED
                elif trigger_rule == TR.ALL_FAILED:
                    if success or skipped:
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.ONE_SUCCESS:
                    if upstream_done and done == skipped:
                        # if upstream is done and all are skipped mark as skipped
                        new_state = TaskInstanceState.SKIPPED
                    elif upstream_done and success <= 0:
                        # if upstream is done and there are no successes, mark as upstream failed
                        new_state = TaskInstanceState.UPSTREAM_FAILED
                elif trigger_rule == TR.ONE_FAILED:
                    if upstream_done and not (failed or upstream_failed):
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.ONE_DONE:
                    if upstream_done and not (failed or success):
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.NONE_FAILED:
                    if upstream_failed or failed:
                        new_state = TaskInstanceState.UPSTREAM_FAILED
                elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
                    if upstream_failed or failed:
                        new_state = TaskInstanceState.UPSTREAM_FAILED
                    elif skipped == upstream:
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.NONE_SKIPPED:
                    if skipped:
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.ALL_SKIPPED:
                    if success or failed or upstream_failed:
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
                    if upstream_done and upstream_setup and skipped_setup >= upstream_setup:
                        # when there is an upstream setup and they have all skipped, then skip
                        new_state = TaskInstanceState.SKIPPED
                    elif upstream_done and upstream_setup and success_setup == 0:
                        # when there is an upstream setup, if none succeeded, mark upstream failed
                        # if at least one setup ran, we'll let it run
                        new_state = TaskInstanceState.UPSTREAM_FAILED
                if new_state is not None:
                    if (
                        new_state == TaskInstanceState.SKIPPED
                        and dep_context.wait_for_past_depends_before_skipping
                    ):
                        past_depends_met = ti.xcom_pull(
                            task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
                        )
                        if not past_depends_met:
                            yield self._failing_status(
                                reason=("Task should be skipped but the past depends are not met")
                            )
                            return
                    changed = ti.set_state(new_state, session)

            if changed:
                dep_context.have_changed_ti_states = True

            if trigger_rule == TR.ONE_SUCCESS:
                if success <= 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires one upstream task success, "
                            f"but none were found. upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ONE_FAILED:
                if not failed and not upstream_failed:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, "
                            f"but none were found. upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ONE_DONE:
                if success + failed <= 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' "
                            "requires at least one upstream task failure or success, "
                            f"but none were found. upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ALL_SUCCESS:
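                # Every upstream ti that is not a success counts as a failure here;
                # removed upstream tis are excluded for mapped task instances.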
                num_failures = upstream - success
                if ti.map_index > -1:
                    num_failures -= removed
                if num_failures > 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                            f"succeeded, but found {num_failures} non-success(es). "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ALL_FAILED:
                num_success = upstream - failed - upstream_failed
                if ti.map_index > -1:
                    num_success -= removed
                if num_success > 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks "
                            f"to have failed, but found {num_success} non-failure(s). "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ALL_DONE:
                if not upstream_done:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                            f"completed, but found {len(upstream_tasks) - done} task(s) that were "
                            f"not done. upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.NONE_FAILED or trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
                num_failures = upstream - success - skipped
                if ti.map_index > -1:
                    num_failures -= removed
                if num_failures > 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                            f"succeeded or been skipped, but found {num_failures} non-success(es). "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.NONE_SKIPPED:
                if not upstream_done or (skipped > 0):
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not "
                            f"have been skipped, but found {skipped} task(s) skipped. "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ALL_SKIPPED:
                num_non_skipped = upstream - skipped
                if num_non_skipped > 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been "
                            f"skipped, but found {num_non_skipped} task(s) in a non-skipped state. "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
                if not upstream_done:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                            f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
                elif upstream_setup and not success_setup:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires at least one upstream setup task "
                            f"to be successful, but found {upstream_setup - success_setup} task(s) that were "
                            f"not. upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            else:
                yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.")

        if TYPE_CHECKING:
            assert ti.task

        if not ti.task.is_teardown:
            # a teardown cannot have any indirect setups
            relevant_setups = {t.task_id: t for t in ti.task.get_upstreams_only_setups()}
            if relevant_setups:
                for status, changed in _evaluate_setup_constraint(relevant_setups=relevant_setups):
                    yield status
                    if not status.passed and changed:
                        # no need to evaluate trigger rule; we've already marked as skipped or failed
                        return

        yield from _evaluate_direct_relatives()