Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/skipmixin.py: 25%

79 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import warnings 

21from typing import TYPE_CHECKING, Iterable, Sequence 

22 

23from airflow.exceptions import AirflowException, RemovedInAirflow3Warning 

24from airflow.models.taskinstance import TaskInstance 

25from airflow.serialization.pydantic.dag_run import DagRunPydantic 

26from airflow.utils import timezone 

27from airflow.utils.log.logging_mixin import LoggingMixin 

28from airflow.utils.session import NEW_SESSION, create_session, provide_session 

29from airflow.utils.state import State 

30 

31if TYPE_CHECKING: 

32 from pendulum import DateTime 

33 from sqlalchemy import Session 

34 

35 from airflow.models.dagrun import DagRun 

36 from airflow.models.operator import Operator 

37 from airflow.models.taskmixin import DAGNode 

38 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic 

39 

# The XCom key under which SkipMixin stores its skip/follow bookkeeping.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"

# The dictionary key (inside the XCom value) listing task IDs that were skipped.
XCOM_SKIPMIXIN_SKIPPED = "skipped"

# The dictionary key (inside the XCom value) listing task IDs that were followed.
XCOM_SKIPMIXIN_FOLLOWED = "followed"

48 

49 

def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
    """Filter *nodes* down to concrete operator instances.

    DAG nodes that are not ``BaseOperator`` or ``MappedOperator`` instances
    (e.g. task groups) are dropped; only real operators can be skipped.
    """
    # Imported locally to avoid circular imports at module load time.
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.mappedoperator import MappedOperator

    operator_types = (BaseOperator, MappedOperator)
    return [node for node in nodes if isinstance(node, operator_types)]

55 

56 

class SkipMixin(LoggingMixin):
    """A Mixin to skip Tasks Instances."""

    def _set_state_to_skipped(
        self,
        dag_run: DagRun | DagRunPydantic,
        tasks: Iterable[Operator],
        session: Session,
    ) -> None:
        """Used internally to set state of task instances to skipped from the same dag run.

        :param dag_run: the DagRun whose task instances are updated
        :param tasks: operators whose task instances should be marked SKIPPED
        :param session: SQLAlchemy session used for the bulk update (not committed here)
        """
        now = timezone.utcnow()

        # Bulk-update all matching rows in a single statement. Both start_date
        # and end_date are set to "now", so skipped instances get zero duration.
        # synchronize_session=False skips refreshing in-memory ORM objects;
        # callers commit immediately after this call.
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_run.dag_id,
            TaskInstance.run_id == dag_run.run_id,
            TaskInstance.task_id.in_(d.task_id for d in tasks),
        ).update(
            {
                TaskInstance.state: State.SKIPPED,
                TaskInstance.start_date: now,
                TaskInstance.end_date: now,
            },
            synchronize_session=False,
        )

    @provide_session
    def skip(
        self,
        dag_run: DagRun | DagRunPydantic,
        execution_date: DateTime,
        tasks: Iterable[DAGNode],
        session: Session = NEW_SESSION,
        map_index: int = -1,
    ):
        """
        Sets tasks instances to skipped from the same dag run.

        If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
        so that NotPreviouslySkippedDep knows these tasks should be skipped when they
        are cleared.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param execution_date: execution_date (deprecated -- pass ``dag_run`` instead)
        :param tasks: tasks to skip (not task_ids)
        :param session: db session to use
        :param map_index: map_index of the current task instance
        :raises ValueError: if ``execution_date`` conflicts with ``dag_run``, or
            neither a usable ``dag_run`` nor ``execution_date`` was supplied
        """
        # Only concrete operators can be skipped; bail out if none remain.
        task_list = _ensure_tasks(tasks)
        if not task_list:
            return

        if execution_date and not dag_run:
            # Deprecated path: look up the DagRun from the first task's dag_id
            # and the given execution_date. Slated for removal in Airflow 3.
            from airflow.models.dagrun import DagRun

            warnings.warn(
                "Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run",
                RemovedInAirflow3Warning,
                stacklevel=2,
            )

            dag_run = (
                session.query(DagRun)
                .filter(
                    DagRun.dag_id == task_list[0].dag_id,
                    DagRun.execution_date == execution_date,
                )
                .one()
            )
        elif execution_date and dag_run and execution_date != dag_run.execution_date:
            # Both were passed but disagree -- refuse to guess which one wins.
            raise ValueError(
                "execution_date has a different value to dag_run.execution_date -- please only pass dag_run"
            )

        if dag_run is None:
            raise ValueError("dag_run is required")

        self._set_state_to_skipped(dag_run, task_list, session)
        session.commit()

        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
        task_id: str | None = getattr(self, "task_id", None)
        if task_id is not None:
            from airflow.models.xcom import XCom

            # Record which task IDs were skipped so NotPreviouslySkippedDep can
            # re-skip them if the tasks are later cleared.
            XCom.set(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in task_list]},
                task_id=task_id,
                dag_id=dag_run.dag_id,
                run_id=dag_run.run_id,
                map_index=map_index,
                session=session,
            )

    def skip_all_except(
        self,
        ti: TaskInstance | TaskInstancePydantic,
        branch_task_ids: None | str | Iterable[str],
    ):
        """
        This method implements the logic for a branching operator; given a single
        task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
        newly added tasks should be skipped when they are cleared.

        :param ti: the task instance of the branching operator itself
        :param branch_task_ids: task ID(s) to follow; ``None`` means follow nothing
        :raises AirflowException: if ``branch_task_ids`` has an unsupported type,
            contains non-string entries, or references task IDs not in the DAG
        """
        self.log.info("Following branch %s", branch_task_ids)
        # Normalise branch_task_ids into a set of task-ID strings, validating
        # the input type along the way.
        if isinstance(branch_task_ids, str):
            branch_task_id_set = {branch_task_ids}
        elif isinstance(branch_task_ids, Iterable):
            branch_task_id_set = set(branch_task_ids)
            invalid_task_ids_type = {
                (bti, type(bti).__name__) for bti in branch_task_ids if not isinstance(bti, str)
            }
            if invalid_task_ids_type:
                raise AirflowException(
                    f"'branch_task_ids' expected all task IDs are strings. "
                    f"Invalid tasks found: {invalid_task_ids_type}."
                )
        elif branch_task_ids is None:
            branch_task_id_set = set()
        else:
            raise AirflowException(
                "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
                f"but got {type(branch_task_ids).__name__!r}."
            )

        dag_run = ti.get_dagrun()

        # TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
        # pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
        # does not attempt to serialize the field from/to ORM
        task = ti.task  # type: ignore[union-attr]
        dag = task.dag
        if TYPE_CHECKING:
            assert dag

        # Reject any requested branch IDs that do not exist in the DAG.
        valid_task_ids = set(dag.task_ids)
        invalid_task_ids = branch_task_id_set - valid_task_ids
        if invalid_task_ids:
            raise AirflowException(
                "'branch_task_ids' must contain only valid task_ids. "
                f"Invalid tasks found: {invalid_task_ids}."
            )

        downstream_tasks = _ensure_tasks(task.downstream_list)

        if downstream_tasks:
            # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
            # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
            # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
            # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
            # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
            # exclude it from skipping.
            #
            # branch  ----->  join
            #   \            ^
            #     v        /
            #       task1
            #
            # Expand the follow-set with every descendant of each followed task.
            for branch_task_id in list(branch_task_id_set):
                branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

            # Partition immediate downstream tasks into skipped vs. followed.
            skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_id_set]
            follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]

            self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
            with create_session() as session:
                self._set_state_to_skipped(dag_run, skip_tasks, session=session)
                # For some reason, session.commit() needs to happen before xcom_push.
                # Otherwise the session is not committed.
                session.commit()
                ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})