Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/providers/standard/utils/skipmixin.py: 25%

63 statements  

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from collections.abc import Iterable, Sequence
from types import GeneratorType
from typing import TYPE_CHECKING

from airflow.providers.common.compat.sdk import AirflowException
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
    from airflow.sdk.definitions._internal.node import DAGNode
    from airflow.sdk.types import Operator, RuntimeTaskInstanceProtocol

# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"

# The dictionary key used to denote task IDs that are skipped
XCOM_SKIPMIXIN_SKIPPED = "skipped"

# The dictionary key used to denote task IDs that are followed
XCOM_SKIPMIXIN_FOLLOWED = "followed"


def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
    from airflow.providers.common.compat.sdk import BaseOperator, MappedOperator

    return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))]


# This class should only be used in Airflow 3.0 and later.
class SkipMixin(LoggingMixin):
49 """A Mixin to skip Tasks Instances.""" 

50 

    @staticmethod
    def _set_state_to_skipped(
        tasks: Sequence[str | tuple[str, int]],
        map_index: int | None,
    ) -> None:
        """
        Set state of task instances to skipped from the same dag run.

        Raises
        ------
        SkipDownstreamTaskInstances
            If the task instances are not in the same dag run.
        """
        # Import is internal for backward compatibility when importing PythonOperator
        # from airflow.providers.common.compat.standard.operators
        from airflow.providers.common.compat.sdk import DownstreamTasksSkipped

        # The following could be applied only for non-mapped tasks,
        # as future mapped tasks have not been expanded yet. Such tasks
        # have to be handled by NotPreviouslySkippedDep.
        if tasks and map_index == -1:
            raise DownstreamTasksSkipped(tasks=tasks)

    def skip(
        self,
        ti: RuntimeTaskInstanceProtocol,
        tasks: Iterable[DAGNode],
    ):
79 """ 

80 Set tasks instances to skipped from the same dag run. 

81 

82 If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom 

83 so that NotPreviouslySkippedDep knows these tasks should be skipped when they 

84 are cleared. 

85 

86 :param ti: the task instance for which to set the tasks to skipped 

87 :param tasks: tasks to skip (not task_ids) 

88 """ 

        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
        task_id: str | None = getattr(self, "task_id", None)
        task_list = _ensure_tasks(tasks)
        if not task_list:
            return

        task_ids_list = [d.task_id for d in task_list]

        if task_id is not None:
            ti.xcom_push(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
            )

        self._set_state_to_skipped(task_ids_list, ti.map_index)

    def skip_all_except(
        self,
        ti: RuntimeTaskInstanceProtocol,
        branch_task_ids: None | str | Iterable[str],
    ):
        """
        Implement the logic for a branching operator.

        Given a single task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
        newly added tasks should be skipped when they are cleared.
        """
        # Ensure we don't serialize a generator object
        if branch_task_ids and isinstance(branch_task_ids, GeneratorType):
            branch_task_ids = list(branch_task_ids)
        log = self.log  # Note: use the logger from the instance; a static logger breaks pytest
        if isinstance(branch_task_ids, str):
            branch_task_id_set = {branch_task_ids}
        elif isinstance(branch_task_ids, Iterable):
            # Handle the case where invalid values are passed as elements of an Iterable.
            # Non-string values are considered invalid elements.
            branch_task_id_set = set(branch_task_ids)
            invalid_task_ids_type = {
                (bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str)
            }
            if invalid_task_ids_type:
                raise AirflowException(
                    f"Unable to branch to the specified tasks. "
                    f"The branching function returned invalid 'branch_task_ids': {invalid_task_ids_type}. "
                    f"Please check that your function returns an Iterable of valid task IDs that exist in your DAG."
                )
        elif branch_task_ids is None:
            branch_task_id_set = set()
        else:
            raise AirflowException(
                "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
                f"but got {type(branch_task_ids).__name__!r}."
            )

        log.info("Following branch %s", branch_task_id_set)

        if TYPE_CHECKING:
            assert ti.task

        task = ti.task
        dag = ti.task.dag

        valid_task_ids = set(dag.task_ids)
        invalid_task_ids = branch_task_id_set - valid_task_ids
        if invalid_task_ids:
            raise AirflowException(
                "'branch_task_ids' must contain only valid task_ids. "
                f"Invalid tasks found: {invalid_task_ids}."
            )

        downstream_tasks = _ensure_tasks(task.downstream_list)

        if downstream_tasks:
            # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
            # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
            # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
            # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
            # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
            # exclude it from skipping.
            #
            # branch  ----->  join
            #   \            ^
            #     v        /
            #       task1
            #
            for branch_task_id in list(branch_task_id_set):
                branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

            skip_tasks = [
                (t.task_id, ti.map_index) for t in downstream_tasks if t.task_id not in branch_task_id_set
            ]

            follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
            log.info("Skipping tasks %s", skip_tasks)
            ti.xcom_push(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids},
            )
            # The following could be applied only for non-mapped tasks,
            # as future mapped tasks have not been expanded yet. Such tasks
            # have to be handled by NotPreviouslySkippedDep.
            self._set_state_to_skipped(skip_tasks, ti.map_index)  # type: ignore[arg-type]
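
The listing ends here. As a rough, self-contained illustration of the branching behavior documented in skip_all_except, the sketch below (not part of the Airflow source; the task IDs "branch", "task1", "join" and the flat_downstream mapping are hypothetical stand-ins for a real DAG) mirrors how immediate downstream tasks are partitioned into followed and skipped sets, and the shape of the value pushed to XCom under XCOM_SKIPMIXIN_KEY.

# Illustrative sketch only -- not part of skipmixin.py. It re-implements the
# partitioning step of skip_all_except with plain dictionaries instead of a DAG.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"
XCOM_SKIPMIXIN_FOLLOWED = "followed"

# Tasks immediately downstream of the hypothetical branching task.
downstream_of_branch = ["task1", "join"]

# Flat downstream relatives of each candidate branch target, i.e. what
# dag.get_task(...).get_flat_relative_ids(upstream=False) would return.
flat_downstream = {"task1": {"join"}, "join": set()}

# The branching callable chose to follow "task1".
branch_task_id_set = {"task1"}

# Expand the followed set with everything downstream of the chosen branches,
# so the join task of an otherwise empty branch is not skipped.
for task_id in list(branch_task_id_set):
    branch_task_id_set |= flat_downstream.get(task_id, set())

skip_tasks = [t for t in downstream_of_branch if t not in branch_task_id_set]
follow_task_ids = [t for t in downstream_of_branch if t in branch_task_id_set]

# Shape of the value pushed to XCom for NotPreviouslySkippedDep. By contrast,
# SkipMixin.skip stores the skipped IDs directly, e.g. {"skipped": [...]}.
xcom_value = {XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}

print("skip:", skip_tasks)   # skip: []
print("xcom:", xcom_value)   # xcom: {'followed': ['task1', 'join']}

Because "join" is reachable downstream of the followed "task1", it is pulled into the followed set and nothing is skipped, which matches the empty-branch special case described in the comment block inside skip_all_except.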