1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Branching operators."""
19
20from __future__ import annotations
21
from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING
24
25from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
26
27if AIRFLOW_V_3_0_PLUS:
28 from airflow.providers.standard.utils.skipmixin import SkipMixin
29else:
30 from airflow.models.skipmixin import SkipMixin
31
32if TYPE_CHECKING:
33 from airflow.providers.common.compat.sdk import Context
34 from airflow.sdk.types import RuntimeTaskInstanceProtocol
35
36
class BranchMixIn(SkipMixin):
    """Mixin that skips every downstream task except the chosen branch(es) in a single call."""

    def do_branch(
        self, context: Context, branches_to_execute: str | Iterable[str] | None
    ) -> str | Iterable[str] | None:
        """
        Implement the handling of branching including logging.

        Task group ids in ``branches_to_execute`` are expanded into the ids of the
        group's root tasks. The result is then handed to ``skip_all_except`` together
        with the current task instance (``context["ti"]``).

        :param context: Context dictionary as passed to ``execute()``; only ``context["ti"]`` is read.
        :param branches_to_execute: A task_id or task_group_id, an iterable of them,
            or ``None`` to skip all downstream tasks.
        :return: ``branches_to_execute`` as given. A one-shot iterator (e.g. a
            generator) is first materialized into a ``list``, so the returned value
            still contains the chosen branches.
        """
        if isinstance(branches_to_execute, Iterator):
            # A one-shot iterator would be drained by the task-group expansion below,
            # leaving nothing useful to log or return. Materialize it exactly once.
            branches_to_execute = list(branches_to_execute)
        self.log.info("Branch into %s", branches_to_execute)
        if branches_to_execute is None:
            # When None is returned, skip all downstream tasks
            self.skip_all_except(context["ti"], None)
        else:
            branch_task_ids = self._expand_task_group_roots(context["ti"], branches_to_execute)
            self.skip_all_except(context["ti"], branch_task_ids)
        return branches_to_execute

    def _expand_task_group_roots(
        self, ti: RuntimeTaskInstanceProtocol, branches_to_execute: str | Iterable[str]
    ) -> Iterable[str]:
        """
        Expand any task group id into the task ids of that group's root tasks.

        Ids that are not keys of ``dag.task_group_dict`` are yielded unchanged.
        A single string, or any non-iterable value, is treated as one branch.

        This is a generator, so expansion (and its logging) only happens as the
        caller consumes the result.

        :param ti: Task instance whose ``task.dag`` is used to look up task groups.
        :param branches_to_execute: A single id or an iterable of task/task-group ids.
        :return: Iterator over the resulting task ids.
        """
        if TYPE_CHECKING:
            assert ti.task

        task = ti.task
        dag = task.dag
        if TYPE_CHECKING:
            assert dag

        # A str is itself iterable (over characters), so it must be wrapped explicitly.
        if isinstance(branches_to_execute, str) or not isinstance(branches_to_execute, Iterable):
            branches_to_execute = [branches_to_execute]

        for branch in branches_to_execute:
            if branch in dag.task_group_dict:
                tg = dag.task_group_dict[branch]
                root_ids = [root.task_id for root in tg.roots]
                self.log.info("Expanding task group %s into %s", tg.group_id, root_ids)
                yield from root_ids
            else:
                yield branch
76
77
class BaseBranchOperator(BaseOperator, BranchMixIn):
    """
    A base class for creating operators with branching functionality, like BranchPythonOperator.

    Subclass this operator and implement ``choose_branch(self, context)``. That method
    runs whatever business logic decides the branch and returns one of:

    - a single task_id (as a str)
    - a single task_group_id (as a str)
    - a list containing a combination of task_ids and task_group_ids
    - ``None``, in which case all downstream tasks are skipped

    The operator continues with the returned task_id(s) and/or task_group_id(s); all
    other tasks directly downstream of this operator are skipped.
    """

    # NOTE(review): presumably a marker Airflow inspects to recognise operators that
    # skip downstream tasks -- confirm against the scheduler/serialization code.
    inherits_from_skipmixin = True

    def choose_branch(self, context: Context) -> str | Iterable[str] | None:
        """
        Abstract method to choose which branch to run.

        Subclasses must override this with the logic that picks the branch. Return a
        task_id / task_group_id, an iterable of them, or ``None`` to skip all
        downstream tasks.

        :param context: Context dictionary as passed to execute()
        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def execute(self, context: Context) -> str | Iterable[str] | None:
        """Choose the branch(es) via ``choose_branch`` and skip everything else downstream."""
        selected = self.choose_branch(context)
        return self.do_branch(context, selected)