Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/skipmixin.py: 25%

77 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import warnings 

21from typing import TYPE_CHECKING, Iterable, Sequence 

22 

23from airflow.exceptions import AirflowException, RemovedInAirflow3Warning 

24from airflow.models.taskinstance import TaskInstance 

25from airflow.utils import timezone 

26from airflow.utils.log.logging_mixin import LoggingMixin 

27from airflow.utils.session import NEW_SESSION, create_session, provide_session 

28from airflow.utils.state import State 

29 

30if TYPE_CHECKING: 

31 from pendulum import DateTime 

32 from sqlalchemy import Session 

33 

34 from airflow.models.dagrun import DagRun 

35 from airflow.models.operator import Operator 

36 from airflow.models.taskmixin import DAGNode 

37 

# The key used by SkipMixin to store XCom data. The value stored under this key
# is a dict keyed by one of the two constants defined right after it.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"

# The dictionary key used to denote task IDs that are skipped
# (written by SkipMixin.skip).
XCOM_SKIPMIXIN_SKIPPED = "skipped"

# The dictionary key used to denote task IDs that are followed
# (written by SkipMixin.skip_all_except).
XCOM_SKIPMIXIN_FOLLOWED = "followed"

46 

47 

def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
    """
    Filter a collection of DAG nodes down to the ones that are operators.

    :param nodes: DAG nodes to inspect
    :return: a new list, in input order, holding only the nodes that are instances
        of ``BaseOperator`` or ``MappedOperator``; every other node is dropped
    """
    # NOTE(review): these imports are function-local, presumably to avoid an import
    # cycle with the operator modules -- confirm before hoisting them to module level.
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.mappedoperator import MappedOperator

    operator_types = (BaseOperator, MappedOperator)
    operators = []
    for node in nodes:
        if isinstance(node, operator_types):
            operators.append(node)
    return operators

53 

54 

class SkipMixin(LoggingMixin):
    """
    A mixin that lets an operator mark task instances of its own dag run as skipped.

    Besides updating the task instance rows, the public methods record what was
    skipped (or followed) in XCom under ``XCOM_SKIPMIXIN_KEY`` so that
    NotPreviouslySkippedDep can re-apply the decision when tasks are cleared.
    """

    def _set_state_to_skipped(
        self,
        dag_run: DagRun,
        tasks: Iterable[Operator],
        session: Session,
    ) -> None:
        """
        Used internally to set state of task instances to skipped from the same dag run.

        Issues a single bulk UPDATE; it does not commit, that is left to the caller.

        :param dag_run: the DagRun whose task instances are updated
        :param tasks: operators whose task instances should be marked skipped
        :param session: db session to use
        """
        now = timezone.utcnow()
        # Materialise the IDs: ``in_()`` is documented to take a list of values, and a
        # list stays valid however many times the expression is inspected or compiled.
        task_ids = [d.task_id for d in tasks]

        session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_run.dag_id,
            TaskInstance.run_id == dag_run.run_id,
            TaskInstance.task_id.in_(task_ids),
        ).update(
            {
                TaskInstance.state: State.SKIPPED,
                # A skipped task instance gets identical start and end timestamps.
                TaskInstance.start_date: now,
                TaskInstance.end_date: now,
            },
            synchronize_session=False,
        )

    @provide_session
    def skip(
        self,
        dag_run: DagRun,
        execution_date: DateTime,
        tasks: Iterable[DAGNode],
        session: Session = NEW_SESSION,
    ):
        """
        Sets tasks instances to skipped from the same dag run.

        If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom
        so that NotPreviouslySkippedDep knows these tasks should be skipped when they
        are cleared.

        Nodes in ``tasks`` that are not operators are ignored; if nothing remains the
        method returns without touching the database.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param execution_date: execution_date (deprecated; only used to look up the
            DagRun when ``dag_run`` is not given)
        :param tasks: tasks to skip (not task_ids)
        :param session: db session to use
        :raises ValueError: if ``execution_date`` disagrees with ``dag_run.execution_date``,
            or if no dag run is available
        """
        task_list = _ensure_tasks(tasks)
        if not task_list:
            return

        if execution_date and not dag_run:
            from airflow.models.dagrun import DagRun

            warnings.warn(
                "Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run",
                RemovedInAirflow3Warning,
                stacklevel=2,
            )

            # Legacy path: resolve the dag run from the first task's dag and the date.
            dag_run = (
                session.query(DagRun)
                .filter(
                    DagRun.dag_id == task_list[0].dag_id,
                    DagRun.execution_date == execution_date,
                )
                .one()
            )
        elif execution_date and dag_run and execution_date != dag_run.execution_date:
            raise ValueError(
                "execution_date has a different value to dag_run.execution_date -- please only pass dag_run"
            )

        if dag_run is None:
            raise ValueError("dag_run is required")

        self._set_state_to_skipped(dag_run, task_list, session)
        session.commit()

        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
        task_id: str | None = getattr(self, "task_id", None)
        if task_id is not None:
            from airflow.models.xcom import XCom

            XCom.set(
                key=XCOM_SKIPMIXIN_KEY,
                value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in task_list]},
                task_id=task_id,
                dag_id=dag_run.dag_id,
                run_id=dag_run.run_id,
                session=session,
            )

    def skip_all_except(self, ti: TaskInstance, branch_task_ids: None | str | Iterable[str]):
        """
        This method implements the logic for a branching operator; given a single
        task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
        newly added tasks should be skipped when they are cleared.

        :param ti: the task instance of the branching operator
        :param branch_task_ids: task ID(s) to follow; ``None`` follows nothing, so every
            immediately downstream task is skipped
        :raises AirflowException: if ``branch_task_ids`` is of an unsupported type, contains
            non-string items, or names task IDs that do not exist in the DAG
        """
        self.log.info("Following branch %s", branch_task_ids)
        if isinstance(branch_task_ids, str):
            branch_task_id_set = {branch_task_ids}
        elif isinstance(branch_task_ids, Iterable):
            # Consume the iterable exactly once. A one-shot iterable (e.g. a generator)
            # would be exhausted after building the set, and a second pass for the type
            # check would then silently see nothing.
            branch_task_id_list = list(branch_task_ids)
            invalid_task_ids_type = {
                (bti, type(bti).__name__) for bti in branch_task_id_list if not isinstance(bti, str)
            }
            if invalid_task_ids_type:
                raise AirflowException(
                    f"'branch_task_ids' expected all task IDs are strings. "
                    f"Invalid tasks found: {invalid_task_ids_type}."
                )
            branch_task_id_set = set(branch_task_id_list)
        elif branch_task_ids is None:
            branch_task_id_set = set()
        else:
            raise AirflowException(
                "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
                f"but got {type(branch_task_ids).__name__!r}."
            )

        dag_run = ti.get_dagrun()
        task = ti.task
        dag = task.dag
        if TYPE_CHECKING:
            assert dag

        valid_task_ids = set(dag.task_ids)
        invalid_task_ids = branch_task_id_set - valid_task_ids
        if invalid_task_ids:
            raise AirflowException(
                "'branch_task_ids' must contain only valid task_ids. "
                f"Invalid tasks found: {invalid_task_ids}."
            )

        downstream_tasks = _ensure_tasks(task.downstream_list)

        if downstream_tasks:
            # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
            # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
            # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
            # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
            # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
            # exclude it from skipping.
            #
            # branch  ----->  join
            #   \            ^
            #    v          /
            #     task1
            #
            # Iterate over a snapshot because the set is extended inside the loop.
            for branch_task_id in list(branch_task_id_set):
                branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

            skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_id_set]
            follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]

            self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
            with create_session() as session:
                self._set_state_to_skipped(dag_run, skip_tasks, session=session)
                # For some reason, session.commit() needs to happen before xcom_push.
                # Otherwise the session is not committed.
                session.commit()
                ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})

217 ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})