Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/skipmixin.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

88 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import warnings 

21from typing import TYPE_CHECKING, Iterable, Sequence 

22 

23from sqlalchemy import select, update 

24 

25from airflow.exceptions import AirflowException, RemovedInAirflow3Warning 

26from airflow.models.taskinstance import TaskInstance 

27from airflow.utils import timezone 

28from airflow.utils.log.logging_mixin import LoggingMixin 

29from airflow.utils.session import NEW_SESSION, create_session, provide_session 

30from airflow.utils.sqlalchemy import tuple_in_condition 

31from airflow.utils.state import TaskInstanceState 

32 

33if TYPE_CHECKING: 

34 from pendulum import DateTime 

35 from sqlalchemy import Session 

36 

37 from airflow.models.dagrun import DagRun 

38 from airflow.models.operator import Operator 

39 from airflow.models.taskmixin import DAGNode 

40 from airflow.serialization.pydantic.dag_run import DagRunPydantic 

41 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic 

42 

# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"

# The dictionary key used to denote task IDs that are skipped
XCOM_SKIPMIXIN_SKIPPED = "skipped"

# The dictionary key used to denote task IDs that are followed
XCOM_SKIPMIXIN_FOLLOWED = "followed"

51 

52 

def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
    """Filter *nodes* down to the concrete operators (plain or mapped), dropping everything else."""
    # Imported lazily to avoid circular imports at module load time.
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.mappedoperator import MappedOperator

    operators: list[Operator] = []
    for node in nodes:
        if isinstance(node, (BaseOperator, MappedOperator)):
            operators.append(node)
    return operators

58 

59 

class SkipMixin(LoggingMixin):
    """A Mixin to skip Tasks Instances."""

    def _set_state_to_skipped(
        self,
        dag_run: DagRun | DagRunPydantic,
        tasks: Sequence[str] | Sequence[tuple[str, int]],
        session: Session,
    ) -> None:
        """Set state of task instances to skipped from the same dag run."""
        if not tasks:
            return
        now = timezone.utcnow()

        # The caller hands us either plain task IDs or (task_id, map_index) pairs;
        # pick the matching filter, then issue a single bulk UPDATE.
        if isinstance(tasks[0], tuple):
            task_filter = tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), tasks)
        else:
            task_filter = TaskInstance.task_id.in_(tasks)

        session.execute(
            update(TaskInstance)
            .where(
                TaskInstance.dag_id == dag_run.dag_id,
                TaskInstance.run_id == dag_run.run_id,
                task_filter,
            )
            .values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
            .execution_options(synchronize_session=False)
        )

    @provide_session
    def skip(
        self,
        dag_run: DagRun | DagRunPydantic,
        execution_date: DateTime,
        tasks: Iterable[DAGNode],
        session: Session = NEW_SESSION,
        map_index: int = -1,
    ):
        """
        Set tasks instances to skipped from the same dag run.

        If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
        so that NotPreviouslySkippedDep knows these tasks should be skipped when they
        are cleared.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param execution_date: execution_date
        :param tasks: tasks to skip (not task_ids)
        :param session: db session to use
        :param map_index: map_index of the current task instance
        """
        to_skip = _ensure_tasks(tasks)
        if not to_skip:
            return

        if execution_date and not dag_run:
            from airflow.models.dagrun import DagRun

            warnings.warn(
                "Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run",
                RemovedInAirflow3Warning,
                stacklevel=2,
            )

            # Deprecated path: resolve the DagRun from the execution_date instead.
            dag_run = session.scalars(
                select(DagRun).where(
                    DagRun.dag_id == to_skip[0].dag_id, DagRun.execution_date == execution_date
                )
            ).one()

        elif execution_date and dag_run and execution_date != dag_run.execution_date:
            raise ValueError(
                "execution_date has a different value to dag_run.execution_date -- please only pass dag_run"
            )

        if dag_run is None:
            raise ValueError("dag_run is required")

        skipped_task_ids = [task.task_id for task in to_skip]
        self._set_state_to_skipped(dag_run, skipped_task_ids, session)
        session.commit()

        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
        my_task_id: str | None = getattr(self, "task_id", None)
        if my_task_id is not None:
            from airflow.models.xcom import XCom

            XCom.set(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_SKIPPED: skipped_task_ids},
                task_id=my_task_id,
                dag_id=dag_run.dag_id,
                run_id=dag_run.run_id,
                map_index=map_index,
                session=session,
            )

    def skip_all_except(
        self,
        ti: TaskInstance | TaskInstancePydantic,
        branch_task_ids: None | str | Iterable[str],
    ):
        """
        Implement the logic for a branching operator.

        Given a single task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
        newly added tasks should be skipped when they are cleared.
        """
        self.log.info("Following branch %s", branch_task_ids)

        # Normalize branch_task_ids into a set of followed task IDs, validating as we go.
        if isinstance(branch_task_ids, str):
            followed_ids = {branch_task_ids}
        elif branch_task_ids is None:
            followed_ids = set()
        elif isinstance(branch_task_ids, Iterable):
            followed_ids = set(branch_task_ids)
            non_string_members = {
                (bti, type(bti).__name__) for bti in branch_task_ids if not isinstance(bti, str)
            }
            if non_string_members:
                raise AirflowException(
                    f"'branch_task_ids' expected all task IDs are strings. "
                    f"Invalid tasks found: {non_string_members}."
                )
        else:
            raise AirflowException(
                "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
                f"but got {type(branch_task_ids).__name__!r}."
            )

        dag_run = ti.get_dagrun()
        if TYPE_CHECKING:
            assert isinstance(dag_run, DagRun)
            assert ti.task

        # TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
        # pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
        # does not attempt to serialize the field from/to ORM
        task = ti.task
        dag = task.dag
        if TYPE_CHECKING:
            assert dag

        unknown_ids = followed_ids - set(dag.task_ids)
        if unknown_ids:
            raise AirflowException(
                "'branch_task_ids' must contain only valid task_ids. "
                f"Invalid tasks found: {unknown_ids}."
            )

        downstream_tasks = _ensure_tasks(task.downstream_list)
        if not downstream_tasks:
            return

        # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
        # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
        # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
        # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
        # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
        # exclude it from skipping.
        #
        # branch  ----->  join
        #   \              ^
        #    v            /
        #     task1 ----->
        #
        for followed_id in list(followed_ids):
            followed_ids.update(dag.get_task(followed_id).get_flat_relative_ids(upstream=False))

        # Only skip downstream tasks that have a task instance in this run and are not followed.
        skip_tasks = [
            (t.task_id, downstream_ti.map_index)
            for t in downstream_tasks
            if (downstream_ti := dag_run.get_task_instance(t.task_id, map_index=ti.map_index))
            and t.task_id not in followed_ids
        ]
        follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in followed_ids]

        self.log.info("Skipping tasks %s", skip_tasks)
        with create_session() as session:
            self._set_state_to_skipped(dag_run, skip_tasks, session=session)
            # For some reason, session.commit() needs to happen before xcom_push.
            # Otherwise the session is not committed.
            session.commit()
            ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})