Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/providers/standard/operators/branch.py: 36%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

42 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""Branching operators.""" 

19 

20from __future__ import annotations 

21 

22from collections.abc import Iterable 

23from typing import TYPE_CHECKING 

24 

25from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator 

26 

27if AIRFLOW_V_3_0_PLUS: 

28 from airflow.providers.standard.utils.skipmixin import SkipMixin 

29else: 

30 from airflow.models.skipmixin import SkipMixin 

31 

32if TYPE_CHECKING: 

33 from airflow.providers.common.compat.sdk import Context 

34 from airflow.sdk.types import RuntimeTaskInstanceProtocol 

35 

36 

class BranchMixIn(SkipMixin):
    """Utility helper which handles the branching as one-liner."""

    def do_branch(
        self, context: Context, branches_to_execute: str | Iterable[str] | None
    ) -> str | Iterable[str] | None:
        """
        Implement the handling of branching including logging.

        Every downstream task that is not selected is skipped via ``skip_all_except``.

        :param context: Task execution context; ``context["ti"]`` is the running task instance.
        :param branches_to_execute: A task_id or task_group_id, an iterable of them, or ``None``
            to skip all downstream tasks.
        :return: The selected branches. A single-pass iterator (e.g. a generator) is first
            materialized into a list, so the returned value is still meaningful after the
            task-group expansion below has walked it.
        """
        from collections.abc import Iterator

        # An iterator/generator can only be consumed once. Without materializing it, the
        # expansion below would exhaust it, the log line would show only a generator repr,
        # and the value returned to the caller would be empty.
        if isinstance(branches_to_execute, Iterator):
            branches_to_execute = list(branches_to_execute)

        self.log.info("Branch into %s", branches_to_execute)
        if branches_to_execute is None:
            # When None is returned, skip all downstream tasks
            self.skip_all_except(context["ti"], None)
        else:
            # Hand skip_all_except a concrete list rather than a lazy, one-shot generator.
            branch_task_ids = list(self._expand_task_group_roots(context["ti"], branches_to_execute))
            self.skip_all_except(context["ti"], branch_task_ids)
        return branches_to_execute

    def _expand_task_group_roots(
        self, ti: RuntimeTaskInstanceProtocol, branches_to_execute: str | Iterable[str]
    ) -> Iterable[str]:
        """
        Expand any task group into its root task ids.

        Each entry of ``branches_to_execute`` that names a task group in the DAG's
        ``task_group_dict`` is replaced by the task ids of that group's roots (and the
        expansion is logged). Any other entry is yielded unchanged. A lone string, or any
        non-iterable value, is treated as a single-element list.

        :param ti: The running task instance; its ``task.dag`` supplies the task groups.
        :param branches_to_execute: A task_id/task_group_id or an iterable of them.
        :return: A generator of task ids.
        """
        if TYPE_CHECKING:
            assert ti.task

        task = ti.task
        dag = task.dag
        if TYPE_CHECKING:
            assert dag

        # A str is Iterable, but here it must be treated as one id, not as its characters.
        if isinstance(branches_to_execute, str) or not isinstance(branches_to_execute, Iterable):
            branches_to_execute = [branches_to_execute]

        for branch in branches_to_execute:
            if branch in dag.task_group_dict:
                tg = dag.task_group_dict[branch]
                root_ids = [root.task_id for root in tg.roots]
                self.log.info("Expanding task group %s into %s", tg.group_id, root_ids)
                yield from root_ids
            else:
                yield branch

76 

77 

class BaseBranchOperator(BaseOperator, BranchMixIn):
    """
    A base class for creating operators with branching functionality, such as BranchPythonOperator.

    Subclass this operator and implement ``choose_branch(self, context)``. That method
    should run whatever business logic decides the branch and return one of:

    - a single task_id (as a str)
    - a single task_group_id (as a str)
    - a list containing any mix of task_ids and task_group_ids

    Execution continues with the returned task_id(s) and/or task_group_id(s); every other
    task directly downstream of this operator is skipped.
    """

    # Class-level marker; presumably inspected elsewhere in Airflow to recognise
    # SkipMixin-based operators — TODO confirm against the consumers of this flag.
    inherits_from_skipmixin = True

    def choose_branch(self, context: Context) -> str | Iterable[str] | None:
        """
        Abstract hook that decides which branch to follow.

        Subclasses must override this with the logic that picks a branch, returning a
        task_id or a list of task_ids. Returning None skips every downstream task.

        :param context: Context dictionary as passed to execute()
        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def execute(self, context: Context):
        """Ask ``choose_branch`` for the branch(es), then apply them via ``do_branch``."""
        selected = self.choose_branch(context)
        return self.do_branch(context, selected)