Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/ti_deps/deps/trigger_rule_dep.py: 15%

175 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections 

21import collections.abc 

22import functools 

23from typing import TYPE_CHECKING, Iterator, NamedTuple 

24 

25from sqlalchemy import and_, func, or_ 

26 

27from airflow.ti_deps.dep_context import DepContext 

28from airflow.ti_deps.deps.base_ti_dep import BaseTIDep, TIDepStatus 

29from airflow.utils.state import TaskInstanceState 

30from airflow.utils.trigger_rule import TriggerRule as TR 

31 

32if TYPE_CHECKING: 

33 from sqlalchemy.orm import Session 

34 from sqlalchemy.sql.expression import ColumnOperators 

35 

36 from airflow.models.taskinstance import TaskInstance 

37 

38 

class _UpstreamTIStates(NamedTuple):
    """States of the upstream tis for a specific ti.

    This is used to determine whether the specific ti can run in this iteration.
    """

    # Counts of finished upstream task instances, bucketed by terminal state.
    success: int
    skipped: int
    failed: int
    upstream_failed: int
    removed: int
    # Total number of finished upstream task instances (sum of all buckets).
    done: int

    @classmethod
    def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTIStates:
        """Calculate states for a task instance.

        :param finished_upstreams: The finished upstream task instances of the
            ti we want to calculate deps for. Consumed exactly once.
        :return: A tuple of per-state counts plus the total ``done`` count.
        """
        # Counter.__getitem__ returns 0 for missing keys, so no .get() default
        # is needed here.
        counter = collections.Counter(ti.state for ti in finished_upstreams)
        return cls(
            success=counter[TaskInstanceState.SUCCESS],
            skipped=counter[TaskInstanceState.SKIPPED],
            failed=counter[TaskInstanceState.FAILED],
            upstream_failed=counter[TaskInstanceState.UPSTREAM_FAILED],
            removed=counter[TaskInstanceState.REMOVED],
            done=sum(counter.values()),
        )

68 

69 

class TriggerRuleDep(BaseTIDep):
    """
    Determines if a task's upstream tasks are in a state that allows a given task instance
    to run.
    """

    NAME = "Trigger Rule"
    IGNORABLE = True
    IS_TASK_DEP = True

    def _get_dep_statuses(
        self,
        ti: TaskInstance,
        session: Session,
        dep_context: DepContext,
    ) -> Iterator[TIDepStatus]:
        """Yield passing/failing statuses for ``ti``'s trigger rule.

        Tasks without upstreams, or with the ``always`` trigger rule, pass
        trivially; everything else is delegated to ``_evaluate_trigger_rule``.
        """
        # Checking that all upstream dependencies have succeeded.
        if not ti.task.upstream_task_ids:
            yield self._passing_status(reason="The task instance did not have any upstream tasks.")
            return
        if ti.task.trigger_rule == TR.ALWAYS:
            yield self._passing_status(reason="The task had an always trigger rule set.")
            return
        yield from self._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session)

    def _evaluate_trigger_rule(
        self,
        *,
        ti: TaskInstance,
        dep_context: DepContext,
        session: Session,
    ) -> Iterator[TIDepStatus]:
        """Evaluate whether ``ti``'s trigger rule was met.

        :param ti: Task instance to evaluate the trigger rule of.
        :param dep_context: The current dependency context.
        :param session: Database session.
        """
        from airflow.models.operator import needs_expansion
        from airflow.models.taskinstance import TaskInstance

        task = ti.task
        upstream_tasks = {t.task_id: t for t in task.upstream_list}
        trigger_rule = task.trigger_rule

        @functools.lru_cache()
        def _get_expanded_ti_count() -> int:
            """Get how many tis the current task is supposed to be expanded into.

            This extra closure allows us to query the database only when needed,
            and at most once.
            """
            return task.get_mapped_ti_count(ti.run_id, session=session)

        @functools.lru_cache()
        def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
            """Get the given task's map indexes relevant to the current ti.

            This extra closure allows us to query the database only when needed,
            and at most once for each task (instead of once for each expanded
            task instance of the same task).
            """
            return ti.get_relevant_upstream_map_indexes(
                upstream_tasks[upstream_id],
                _get_expanded_ti_count(),
                session=session,
            )

        def _is_relevant_upstream(upstream: TaskInstance) -> bool:
            """Whether a task instance is a "relevant upstream" of the current task."""
            # Not actually an upstream task.
            if upstream.task_id not in task.upstream_task_ids:
                return False
            # The current task is not in a mapped task group. All tis from an
            # upstream task are relevant.
            if task.get_closest_mapped_task_group() is None:
                return True
            # The upstream ti is not expanded. The upstream may be mapped or
            # not, but the ti is relevant either way.
            if upstream.map_index < 0:
                return True
            # Now we need to perform fine-grained check on whether this specific
            # upstream ti's map index is relevant.
            relevant = _get_relevant_upstream_map_indexes(upstream.task_id)
            if relevant is None:
                return True
            if relevant == upstream.map_index:
                return True
            if isinstance(relevant, collections.abc.Container) and upstream.map_index in relevant:
                return True
            return False

        # Lazily filter the finished tis of this dag run down to the relevant
        # upstreams, then bucket their states.
        finished_upstream_tis = (
            finished_ti
            for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
            if _is_relevant_upstream(finished_ti)
        )
        upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)

        success = upstream_states.success
        skipped = upstream_states.skipped
        failed = upstream_states.failed
        upstream_failed = upstream_states.upstream_failed
        removed = upstream_states.removed
        done = upstream_states.done

        def _iter_upstream_conditions() -> Iterator[ColumnOperators]:
            """Yield SQL filters selecting the upstream tis this ti depends on."""
            # Optimization: If the current task is not in a mapped task group,
            # it depends on all upstream task instances.
            if task.get_closest_mapped_task_group() is None:
                yield TaskInstance.task_id.in_(upstream_tasks)
                return
            # Otherwise we need to figure out which map indexes are depended on
            # for each upstream by the current task instance.
            for upstream_id in upstream_tasks:
                map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
                if map_indexes is None:  # All tis of this upstream are dependencies.
                    yield (TaskInstance.task_id == upstream_id)
                    continue
                # At this point we know we want to depend on only selected tis
                # of this upstream task. Since the upstream may not have been
                # expanded at this point, we also depend on the non-expanded ti
                # to ensure at least one ti is included for the task.
                yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index < 0)
                if isinstance(map_indexes, range) and map_indexes.step == 1:
                    yield and_(
                        TaskInstance.task_id == upstream_id,
                        TaskInstance.map_index >= map_indexes.start,
                        TaskInstance.map_index < map_indexes.stop,
                    )
                elif isinstance(map_indexes, collections.abc.Container):
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index.in_(map_indexes))
                else:
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)

        # Optimization: Don't need to hit the database if all upstreams are
        # "simple" tasks (no task or task group mapping involved).
        if not any(needs_expansion(t) for t in upstream_tasks.values()):
            upstream = len(upstream_tasks)
        else:
            upstream = (
                session.query(func.count())
                .filter(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
                .filter(or_(*_iter_upstream_conditions()))
                .scalar()
            )
        upstream_done = done >= upstream

        # When requested by the dep context, proactively flip this ti into a
        # terminal state (UPSTREAM_FAILED/SKIPPED/REMOVED) if the upstream
        # outcome already makes the trigger rule unsatisfiable.
        changed = False
        if dep_context.flag_upstream_failed:
            if trigger_rule == TR.ALL_SUCCESS:
                if upstream_failed or failed:
                    changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session)
                elif skipped:
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
                elif removed and success and ti.map_index > -1:
                    if ti.map_index >= success:
                        changed = ti.set_state(TaskInstanceState.REMOVED, session)
            elif trigger_rule == TR.ALL_FAILED:
                if success or skipped:
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
            elif trigger_rule == TR.ONE_SUCCESS:
                if upstream_done and done == skipped:
                    # if upstream is done and all are skipped mark as skipped
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
                elif upstream_done and success <= 0:
                    # if upstream is done and there are no success mark as upstream failed
                    changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session)
            elif trigger_rule == TR.ONE_FAILED:
                if upstream_done and not (failed or upstream_failed):
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
            elif trigger_rule == TR.ONE_DONE:
                if upstream_done and not (failed or success):
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
            elif trigger_rule == TR.NONE_FAILED:
                if upstream_failed or failed:
                    changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session)
            elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
                if upstream_failed or failed:
                    changed = ti.set_state(TaskInstanceState.UPSTREAM_FAILED, session)
                elif skipped == upstream:
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
            elif trigger_rule == TR.NONE_SKIPPED:
                if skipped:
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)
            elif trigger_rule == TR.ALL_SKIPPED:
                if success or failed:
                    changed = ti.set_state(TaskInstanceState.SKIPPED, session)

        if changed:
            dep_context.have_changed_ti_states = True

        # Emit a failing status when the trigger rule is not (yet) satisfied.
        if trigger_rule == TR.ONE_SUCCESS:
            if success <= 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires one upstream task success, "
                        f"but none were found. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ONE_FAILED:
            if not failed and not upstream_failed:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, "
                        f"but none were found. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ONE_DONE:
            if success + failed <= 0:
                # NOTE: the adjacent fragments previously lacked separating
                # spaces, producing a run-together message.
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' "
                        "requires at least one upstream task failure or success, "
                        f"but none were failed or success. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_SUCCESS:
            num_failures = upstream - success
            if ti.map_index > -1:
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_FAILED:
            num_success = upstream - failed - upstream_failed
            if ti.map_index > -1:
                num_success -= removed
            if num_success > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have failed, "
                        f"but found {num_success} non-failure(s). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_DONE:
            if not upstream_done:
                # Report the count of not-yet-done upstreams; previously the
                # boolean ``upstream_done`` was interpolated here by mistake.
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"completed, but found {upstream - done} task(s) that were not done. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_FAILED:
            num_failures = upstream - success - skipped
            if ti.map_index > -1:
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded or been skipped, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
            num_failures = upstream - success - skipped
            if ti.map_index > -1:
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded or been skipped, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_SKIPPED:
            if not upstream_done or (skipped > 0):
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not have been "
                        f"skipped, but found {skipped} task(s) skipped. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_SKIPPED:
            num_non_skipped = upstream - skipped
            if num_non_skipped > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been "
                        f"skipped, but found {num_non_skipped} task(s) in non skipped state. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        else:
            yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.")