Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/ti_deps/deps/trigger_rule_dep.py: 14%

215 statements  

coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import collections
import collections.abc
import functools
from typing import TYPE_CHECKING, Iterator, NamedTuple

from sqlalchemy import and_, func, or_

from airflow.models import MappedOperator
from airflow.models.taskinstance import PAST_DEPENDS_MET
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep, TIDepStatus
from airflow.utils.state import TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule as TR

if TYPE_CHECKING:
    from sqlalchemy.orm import Session
    from sqlalchemy.sql.expression import ColumnOperators

    from airflow.models.taskinstance import TaskInstance


class _UpstreamTIStates(NamedTuple):
    """States of the upstream tis for a specific ti.

    This is used to determine whether the specific ti can run in this iteration.
    """

    success: int
    skipped: int
    failed: int
    upstream_failed: int
    removed: int
    done: int
    success_setup: int
    skipped_setup: int

    @classmethod
    def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTIStates:
        """Calculate states for a task instance.

        ``counter`` is inclusive of ``setup_counter`` -- e.g. if there are 2 skipped upstreams, one
        of which is a setup, then counter will show 2 skipped and setup counter will show 1.

        :param finished_upstreams: all the finished upstream task instances of the dag_run
        """
        counter: dict[str, int] = collections.Counter()
        setup_counter: dict[str, int] = collections.Counter()
        for ti in finished_upstreams:
            curr_state = {ti.state: 1}
            counter.update(curr_state)
            # setup task cannot be mapped
            if not isinstance(ti.task, MappedOperator) and ti.task.is_setup:
                setup_counter.update(curr_state)
        return _UpstreamTIStates(
            success=counter.get(TaskInstanceState.SUCCESS, 0),
            skipped=counter.get(TaskInstanceState.SKIPPED, 0),
            failed=counter.get(TaskInstanceState.FAILED, 0),
            upstream_failed=counter.get(TaskInstanceState.UPSTREAM_FAILED, 0),
            removed=counter.get(TaskInstanceState.REMOVED, 0),
            done=sum(counter.values()),
            success_setup=setup_counter.get(TaskInstanceState.SUCCESS, 0),
            skipped_setup=setup_counter.get(TaskInstanceState.SKIPPED, 0),
        )
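
# A minimal sketch of the counting semantics described in the docstring above,
# assuming skipped_work_ti and skipped_setup_ti are finished upstream task
# instances in SKIPPED state, the latter belonging to a setup task (both names
# are hypothetical, for illustration only):
#
#   states = _UpstreamTIStates.calculate(iter([skipped_work_ti, skipped_setup_ti]))
#   assert (states.skipped, states.skipped_setup, states.done) == (2, 1, 2)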


class TriggerRuleDep(BaseTIDep):
    """
    Determines if a task's upstream tasks are in a state that allows a given task instance
    to run.
    """

    NAME = "Trigger Rule"
    IGNORABLE = True
    IS_TASK_DEP = True

    def _get_dep_statuses(
        self,
        ti: TaskInstance,
        session: Session,
        dep_context: DepContext,
    ) -> Iterator[TIDepStatus]:
        # Checking that all upstream dependencies have succeeded.
        if not ti.task.upstream_task_ids:
            yield self._passing_status(reason="The task instance did not have any upstream tasks.")
            return
        if ti.task.trigger_rule == TR.ALWAYS:
            yield self._passing_status(reason="The task had an 'always' trigger rule set.")
            return
        yield from self._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session)

    def _evaluate_trigger_rule(
        self,
        *,
        ti: TaskInstance,
        dep_context: DepContext,
        session: Session,
    ) -> Iterator[TIDepStatus]:
        """Evaluate whether ``ti``'s trigger rule was met.

        :param ti: Task instance to evaluate the trigger rule of.
        :param dep_context: The current dependency context.
        :param session: Database session.
        """
        from airflow.models.abstractoperator import NotMapped
        from airflow.models.expandinput import NotFullyPopulated
        from airflow.models.operator import needs_expansion
        from airflow.models.taskinstance import TaskInstance

        task = ti.task
        upstream_tasks = {t.task_id: t for t in task.upstream_list}
        trigger_rule = task.trigger_rule

        @functools.lru_cache
        def _get_expanded_ti_count() -> int:
            """Get how many tis the current task is supposed to be expanded into.

            This extra closure allows us to query the database only when needed,
            and at most once.
            """
            return task.get_mapped_ti_count(ti.run_id, session=session)

        @functools.lru_cache
        def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
            """Get the given task's map indexes relevant to the current ti.

            This extra closure allows us to query the database only when needed,
            and at most once for each task (instead of once for each expanded
            task instance of the same task).
            """
            try:
                expanded_ti_count = _get_expanded_ti_count()
            except (NotFullyPopulated, NotMapped):
                return None
            return ti.get_relevant_upstream_map_indexes(
                upstream_tasks[upstream_id],
                expanded_ti_count,
                session=session,
            )

        def _is_relevant_upstream(upstream: TaskInstance) -> bool:
            """Whether a task instance is a "relevant upstream" of the current task."""
            # Not actually an upstream task.
            if upstream.task_id not in task.upstream_task_ids:
                return False
            # The current task is not in a mapped task group. All tis from an
            # upstream task are relevant.
            if task.get_closest_mapped_task_group() is None:
                return True
            # The upstream ti is not expanded. The upstream may be mapped or
            # not, but the ti is relevant either way.
            if upstream.map_index < 0:
                return True
            # Now we need to perform fine-grained check on whether this specific
            # upstream ti's map index is relevant.
            relevant = _get_relevant_upstream_map_indexes(upstream.task_id)
            if relevant is None:
                return True
            if relevant == upstream.map_index:
                return True
            if isinstance(relevant, collections.abc.Container) and upstream.map_index in relevant:
                return True
            return False

        finished_upstream_tis = (
            finished_ti
            for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
            if _is_relevant_upstream(finished_ti)
        )
        upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)

        success = upstream_states.success
        skipped = upstream_states.skipped
        failed = upstream_states.failed
        upstream_failed = upstream_states.upstream_failed
        removed = upstream_states.removed
        done = upstream_states.done
        success_setup = upstream_states.success_setup
        skipped_setup = upstream_states.skipped_setup

        def _iter_upstream_conditions() -> Iterator[ColumnOperators]:
            # Optimization: If the current task is not in a mapped task group,
            # it depends on all upstream task instances.
            if task.get_closest_mapped_task_group() is None:
                yield TaskInstance.task_id.in_(upstream_tasks)
                return
            # Otherwise we need to figure out which map indexes are depended on
            # for each upstream by the current task instance.
            for upstream_id in upstream_tasks:
                map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
                if map_indexes is None:  # All tis of this upstream are dependencies.
                    yield (TaskInstance.task_id == upstream_id)
                    continue
                # At this point we know we want to depend on only selected tis
                # of this upstream task. Since the upstream may not have been
                # expanded at this point, we also depend on the non-expanded ti
                # to ensure at least one ti is included for the task.
                yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index < 0)
                if isinstance(map_indexes, range) and map_indexes.step == 1:
                    yield and_(
                        TaskInstance.task_id == upstream_id,
                        TaskInstance.map_index >= map_indexes.start,
                        TaskInstance.map_index < map_indexes.stop,
                    )
                elif isinstance(map_indexes, collections.abc.Container):
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index.in_(map_indexes))
                else:
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)
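
        # Note: when the database is consulted below, the conditions yielded by
        # this generator are OR-ed together into a single COUNT query over the
        # task_instance table, scoped to this dag_id and run_id.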

        # Optimization: Don't need to hit the database if all upstreams are
        # "simple" tasks (no task or task group mapping involved).
        if not any(needs_expansion(t) for t in upstream_tasks.values()):
            upstream = len(upstream_tasks)
            upstream_setup = len(
                [x for x in upstream_tasks.values() if not isinstance(x, MappedOperator) and x.is_setup]
            )
        else:
            upstream = (
                session.query(func.count())
                .filter(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
                .filter(or_(*_iter_upstream_conditions()))
                .scalar()
            )
            # todo: add support for mapped setup?
            upstream_setup = None
        upstream_done = done >= upstream

        changed = False
        new_state = None
        if dep_context.flag_upstream_failed:
            if trigger_rule == TR.ALL_SUCCESS:
                if upstream_failed or failed:
                    new_state = TaskInstanceState.UPSTREAM_FAILED
                elif skipped:
                    new_state = TaskInstanceState.SKIPPED
                elif removed and success and ti.map_index > -1:
                    if ti.map_index >= success:
                        new_state = TaskInstanceState.REMOVED
            elif trigger_rule == TR.ALL_FAILED:
                if success or skipped:
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.ONE_SUCCESS:
                if upstream_done and done == skipped:
                    # if upstream is done and all are skipped, mark as skipped
                    new_state = TaskInstanceState.SKIPPED
                elif upstream_done and success <= 0:
                    # if upstream is done and there are no successes, mark as upstream failed
                    new_state = TaskInstanceState.UPSTREAM_FAILED
            elif trigger_rule == TR.ONE_FAILED:
                if upstream_done and not (failed or upstream_failed):
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.ONE_DONE:
                if upstream_done and not (failed or success):
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.NONE_FAILED:
                if upstream_failed or failed:
                    new_state = TaskInstanceState.UPSTREAM_FAILED
            elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
                if upstream_failed or failed:
                    new_state = TaskInstanceState.UPSTREAM_FAILED
                elif skipped == upstream:
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.NONE_SKIPPED:
                if skipped:
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.ALL_SKIPPED:
                if success or failed:
                    new_state = TaskInstanceState.SKIPPED
            elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
                if upstream_done and upstream_setup and skipped_setup >= upstream_setup:
                    # when there is an upstream setup and they have all skipped, then skip
                    new_state = TaskInstanceState.SKIPPED
                elif upstream_done and upstream_setup and success_setup == 0:
                    # when there is an upstream setup, if none succeeded, mark upstream failed
                    # if at least one setup ran, we'll let it run
                    new_state = TaskInstanceState.UPSTREAM_FAILED
            if new_state is not None:
                if new_state == TaskInstanceState.SKIPPED and dep_context.wait_for_past_depends_before_skipping:
                    past_depends_met = ti.xcom_pull(
                        task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
                    )
                    if not past_depends_met:
                        yield self._failing_status(
                            reason="Task should be skipped but the past depends are not met"
                        )
                        return
                changed = ti.set_state(new_state, session)

        if changed:
            dep_context.have_changed_ti_states = True
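
        # Nothing below mutates task instance state; the remaining checks only
        # yield failing statuses when the trigger rule is not (yet) satisfied,
        # and yielding nothing leaves this dep as met.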

        if trigger_rule == TR.ONE_SUCCESS:
            if success <= 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires one upstream task success, "
                        f"but none were found. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ONE_FAILED:
            if not failed and not upstream_failed:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, "
                        f"but none were found. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ONE_DONE:
            if success + failed <= 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires at least one upstream task "
                        f"failure or success, but none were found. upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_SUCCESS:
            num_failures = upstream - success
            if ti.map_index > -1:
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_FAILED:
            num_success = upstream - failed - upstream_failed
            if ti.map_index > -1:
                num_success -= removed
            if num_success > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have failed, "
                        f"but found {num_success} non-failure(s). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_DONE:
            if not upstream_done:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_FAILED:
            num_failures = upstream - success - skipped
            if ti.map_index > -1:
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded or been skipped, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
            num_failures = upstream - success - skipped
            if ti.map_index > -1:
                num_failures -= removed
            if num_failures > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"succeeded or been skipped, but found {num_failures} non-success(es). "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.NONE_SKIPPED:
            if not upstream_done or (skipped > 0):
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not have been "
                        f"skipped, but found {skipped} task(s) skipped. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_SKIPPED:
            num_non_skipped = upstream - skipped
            if num_non_skipped > 0:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been "
                        f"skipped, but found {num_non_skipped} task(s) in a non-skipped state. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
            if not upstream_done:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                        f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
            elif upstream_setup is None:  # for now, None only happens in mapped case
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' cannot have mapped tasks as upstream. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
            elif upstream_setup and not success_setup >= 1:
                yield self._failing_status(
                    reason=(
                        f"Task's trigger rule '{trigger_rule}' requires at least one upstream setup task be "
                        f"successful, but found {upstream_setup - success_setup} task(s) that were not. "
                        f"upstream_states={upstream_states}, "
                        f"upstream_task_ids={task.upstream_task_ids}"
                    )
                )
        else:
            yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.")
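
For context, a minimal sketch of how a trigger rule is declared on a task so that this dep
is evaluated for it; the DAG and task names below are illustrative assumptions, not part of
this module:

    import pendulum

    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator
    from airflow.utils.trigger_rule import TriggerRule

    with DAG(dag_id="trigger_rule_example", start_date=pendulum.datetime(2023, 1, 1), schedule=None):
        branch_a = EmptyOperator(task_id="branch_a")
        branch_b = EmptyOperator(task_id="branch_b")
        # TriggerRuleDep lets "join" run once no upstream has failed and at
        # least one has succeeded, per NONE_FAILED_MIN_ONE_SUCCESS.
        join = EmptyOperator(task_id="join", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
        [branch_a, branch_b] >> join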