Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/ti_deps/deps/trigger_rule_dep.py: 10%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

275 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections.abc 

21import functools 

22from collections import Counter 

23from typing import TYPE_CHECKING, Iterator, KeysView, NamedTuple 

24 

25from sqlalchemy import and_, func, or_, select 

26 

27from airflow.models.taskinstance import PAST_DEPENDS_MET 

28from airflow.ti_deps.deps.base_ti_dep import BaseTIDep 

29from airflow.utils.state import TaskInstanceState 

30from airflow.utils.trigger_rule import TriggerRule as TR 

31 

32if TYPE_CHECKING: 

33 from sqlalchemy.orm import Session 

34 from sqlalchemy.sql.expression import ColumnOperators 

35 

36 from airflow import DAG 

37 from airflow.models.taskinstance import TaskInstance 

38 from airflow.ti_deps.dep_context import DepContext 

39 from airflow.ti_deps.deps.base_ti_dep import TIDepStatus 

40 

41 

class _UpstreamTIStates(NamedTuple):
    """States of the upstream tis for a specific ti.

    This is used to determine whether the specific ti can run in this iteration.
    """

    success: int
    skipped: int
    failed: int
    upstream_failed: int
    removed: int
    done: int
    success_setup: int
    skipped_setup: int

    @classmethod
    def calculate(cls, finished_upstreams: Iterator[TaskInstance]) -> _UpstreamTIStates:
        """Tally the states of the given finished upstream task instances.

        The overall tallies are inclusive of the setup tallies -- e.g. if there
        are 2 skipped upstreams, one of which is a setup, then ``skipped`` will
        be 2 and ``skipped_setup`` will be 1.

        :param finished_upstreams: all the finished upstream tis of the dag_run
        """
        all_states: Counter = Counter()
        setup_states: Counter = Counter()
        for upstream_ti in finished_upstreams:
            if TYPE_CHECKING:
                assert upstream_ti.task
            all_states[upstream_ti.state] += 1
            if upstream_ti.task.is_setup:
                setup_states[upstream_ti.state] += 1
        # Counter returns 0 for missing keys, so no explicit defaults are needed.
        return cls(
            success=all_states[TaskInstanceState.SUCCESS],
            skipped=all_states[TaskInstanceState.SKIPPED],
            failed=all_states[TaskInstanceState.FAILED],
            upstream_failed=all_states[TaskInstanceState.UPSTREAM_FAILED],
            removed=all_states[TaskInstanceState.REMOVED],
            done=sum(all_states.values()),
            success_setup=setup_states[TaskInstanceState.SUCCESS],
            skipped_setup=setup_states[TaskInstanceState.SKIPPED],
        )

86 

87 

class TriggerRuleDep(BaseTIDep):
    """Determines if a task's upstream tasks are in a state that allows a given task instance to run."""

    NAME = "Trigger Rule"
    IGNORABLE = True
    IS_TASK_DEP = True

    def _get_dep_statuses(
        self,
        ti: TaskInstance,
        session: Session,
        dep_context: DepContext,
    ) -> Iterator[TIDepStatus]:
        """Yield dep statuses for ``ti``'s trigger rule.

        Fast paths: a task with no upstreams, or with the ALWAYS trigger rule,
        passes unconditionally; anything else goes through full evaluation.
        """
        if TYPE_CHECKING:
            assert ti.task

        # Checking that all upstream dependencies have succeeded.
        if not ti.task.upstream_task_ids:
            yield self._passing_status(reason="The task instance did not have any upstream tasks.")
            return
        if ti.task.trigger_rule == TR.ALWAYS:
            yield self._passing_status(reason="The task had a always trigger rule set.")
            return
        yield from self._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session)

    def _evaluate_trigger_rule(
        self,
        *,
        ti: TaskInstance,
        dep_context: DepContext,
        session: Session,
    ) -> Iterator[TIDepStatus]:
        """Evaluate whether ``ti``'s trigger rule was met.

        :param ti: Task instance to evaluate the trigger rule of.
        :param dep_context: The current dependency context.
        :param session: Database session.
        """
        # Imported locally, presumably to avoid import cycles at module load
        # time -- TODO confirm.
        from airflow.models.abstractoperator import NotMapped
        from airflow.models.expandinput import NotFullyPopulated
        from airflow.models.taskinstance import TaskInstance

        @functools.lru_cache
        def _get_expanded_ti_count() -> int:
            """Get how many tis the current task is supposed to be expanded into.

            This extra closure allows us to query the database only when needed,
            and at most once.
            """
            if TYPE_CHECKING:
                assert ti.task

            return ti.task.get_mapped_ti_count(ti.run_id, session=session)

        @functools.lru_cache
        def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
            """Get the given task's map indexes relevant to the current ti.

            This extra closure allows us to query the database only when needed,
            and at most once for each task (instead of once for each expanded
            task instance of the same task).

            Returns ``None`` when all map indexes of the upstream are relevant.
            """
            if TYPE_CHECKING:
                assert ti.task
                assert isinstance(ti.task.dag, DAG)

            try:
                expanded_ti_count = _get_expanded_ti_count()
            except (NotFullyPopulated, NotMapped):
                # Expansion information is unavailable; callers treat None as
                # "every map index of this upstream is relevant".
                return None
            return ti.get_relevant_upstream_map_indexes(
                upstream=ti.task.dag.task_dict[upstream_id],
                ti_count=expanded_ti_count,
                session=session,
            )

        def _is_relevant_upstream(upstream: TaskInstance, relevant_ids: set[str] | KeysView[str]) -> bool:
            """
            Whether a task instance is a "relevant upstream" of the current task.

            This will return false if upstream.task_id is not in relevant_ids,
            or if both of the following are true:
            1. upstream.task_id in relevant_ids is True
            2. ti is in a mapped task group and upstream has a map index
            that ti does not depend on.
            """
            if TYPE_CHECKING:
                assert ti.task

            # Not actually an upstream task.
            if upstream.task_id not in relevant_ids:
                return False
            # The current task is not in a mapped task group. All tis from an
            # upstream task are relevant.
            if ti.task.get_closest_mapped_task_group() is None:
                return True
            # The upstream ti is not expanded. The upstream may be mapped or
            # not, but the ti is relevant either way.
            if upstream.map_index < 0:
                return True
            # Now we need to perform fine-grained check on whether this specific
            # upstream ti's map index is relevant.
            relevant = _get_relevant_upstream_map_indexes(upstream_id=upstream.task_id)
            if relevant is None:
                return True
            if relevant == upstream.map_index:
                return True
            if isinstance(relevant, collections.abc.Container) and upstream.map_index in relevant:
                return True
            return False

        def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators]:
            """Yield SQL filter clauses selecting the upstream tis relevant to ``ti``."""
            # Optimization: If the current task is not in a mapped task group,
            # it depends on all upstream task instances.
            from airflow.models.taskinstance import TaskInstance

            if TYPE_CHECKING:
                assert ti.task

            if ti.task.get_closest_mapped_task_group() is None:
                yield TaskInstance.task_id.in_(relevant_tasks.keys())
                return
            # Otherwise we need to figure out which map indexes are depended on
            # for each upstream by the current task instance.
            for upstream_id in relevant_tasks:
                map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
                if map_indexes is None:  # All tis of this upstream are dependencies.
                    yield (TaskInstance.task_id == upstream_id)
                    continue
                # At this point we know we want to depend on only selected tis
                # of this upstream task. Since the upstream may not have been
                # expanded at this point, we also depend on the non-expanded ti
                # to ensure at least one ti is included for the task.
                yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index < 0)
                # A contiguous range can be expressed as a cheap BETWEEN-style
                # pair of comparisons instead of a potentially huge IN list.
                if isinstance(map_indexes, range) and map_indexes.step == 1:
                    yield and_(
                        TaskInstance.task_id == upstream_id,
                        TaskInstance.map_index >= map_indexes.start,
                        TaskInstance.map_index < map_indexes.stop,
                    )
                elif isinstance(map_indexes, collections.abc.Container):
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index.in_(map_indexes))
                else:
                    yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)

        def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus, bool]]:
            """Evaluate whether ``ti``'s *indirect* setup constraint is met.

            Only setups that are not direct upstreams of ``ti`` are considered
            here; direct upstreams are handled by the trigger rule itself.

            Yields ``(status, changed)`` pairs, where ``changed`` reports whether
            this evaluation has mutated the task instance's state.
            """
            if TYPE_CHECKING:
                assert ti.task

            task = ti.task

            indirect_setups = {k: v for k, v in relevant_setups.items() if k not in task.upstream_task_ids}
            finished_upstream_tis = (
                x
                for x in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
                if _is_relevant_upstream(upstream=x, relevant_ids=indirect_setups.keys())
            )
            upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)

            # all of these counts reflect indirect setups which are relevant for this ti
            success = upstream_states.success
            skipped = upstream_states.skipped
            failed = upstream_states.failed
            upstream_failed = upstream_states.upstream_failed
            removed = upstream_states.removed

            # Optimization: Don't need to hit the database if all upstreams are
            # "simple" tasks (no task or task group mapping involved).
            if not any(t.get_needs_expansion() for t in indirect_setups.values()):
                upstream = len(indirect_setups)
            else:
                task_id_counts = session.execute(
                    select(TaskInstance.task_id, func.count(TaskInstance.task_id))
                    .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
                    .where(or_(*_iter_upstream_conditions(relevant_tasks=indirect_setups)))
                    .group_by(TaskInstance.task_id)
                ).all()
                upstream = sum(count for _, count in task_id_counts)

            new_state = None
            changed = False

            # if there's a failure, we mark upstream_failed; if there's a skip, we mark skipped
            # in either case, we don't wait for all relevant setups to complete
            if dep_context.flag_upstream_failed:
                if upstream_failed or failed:
                    new_state = TaskInstanceState.UPSTREAM_FAILED
                elif skipped:
                    new_state = TaskInstanceState.SKIPPED
                elif removed and success and ti.map_index > -1:
                    # Expanded ti whose map index exceeds the surviving success
                    # count: its upstream slot was removed.
                    if ti.map_index >= success:
                        new_state = TaskInstanceState.REMOVED

            if new_state is not None:
                if (
                    new_state == TaskInstanceState.SKIPPED
                    and dep_context.wait_for_past_depends_before_skipping
                ):
                    # Don't skip until the past-depends flag has been recorded
                    # (it is communicated via XCom by the depends-on-past dep).
                    past_depends_met = ti.xcom_pull(
                        task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
                    )
                    if not past_depends_met:
                        yield (
                            self._failing_status(
                                reason="Task should be skipped but the past depends are not met"
                            ),
                            changed,
                        )
                        return
                changed = ti.set_state(new_state, session)

            if changed:
                dep_context.have_changed_ti_states = True

            non_successes = upstream - success
            if ti.map_index > -1:
                # Removed upstream tis don't count against an expanded ti.
                non_successes -= removed
            if non_successes > 0:
                yield (
                    self._failing_status(
                        reason=(
                            f"All setup tasks must complete successfully. Relevant setups: {relevant_setups}: "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        ),
                    ),
                    changed,
                )

        def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
            """Evaluate ``ti``'s trigger rule against its direct upstream tasks.

            Yields failing statuses when the trigger rule is not (or cannot be)
            satisfied; yields nothing when the rule passes. May also flip the
            ti's state (skipped/upstream_failed/removed) when the dep context
            has ``flag_upstream_failed`` set.
            """
            if TYPE_CHECKING:
                assert ti.task

            task = ti.task
            upstream_tasks = {t.task_id: t for t in task.upstream_list}
            trigger_rule = task.trigger_rule

            finished_upstream_tis = (
                finished_ti
                for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
                if _is_relevant_upstream(upstream=finished_ti, relevant_ids=ti.task.upstream_task_ids)
            )
            upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)

            success = upstream_states.success
            skipped = upstream_states.skipped
            failed = upstream_states.failed
            upstream_failed = upstream_states.upstream_failed
            removed = upstream_states.removed
            done = upstream_states.done
            success_setup = upstream_states.success_setup
            skipped_setup = upstream_states.skipped_setup

            # Optimization: Don't need to hit the database if all upstreams are
            # "simple" tasks (no task or task group mapping involved).
            if not any(t.get_needs_expansion() for t in upstream_tasks.values()):
                upstream = len(upstream_tasks)
                upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
            else:
                task_id_counts = session.execute(
                    select(TaskInstance.task_id, func.count(TaskInstance.task_id))
                    .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
                    .where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks)))
                    .group_by(TaskInstance.task_id)
                ).all()
                upstream = sum(count for _, count in task_id_counts)
                upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup)

            upstream_done = done >= upstream

            # First pass: decide whether the ti's state should be flipped
            # outright (skipped / upstream_failed / removed) per trigger rule.
            changed = False
            new_state = None
            if dep_context.flag_upstream_failed:
                if trigger_rule == TR.ALL_SUCCESS:
                    if upstream_failed or failed:
                        new_state = TaskInstanceState.UPSTREAM_FAILED
                    elif skipped:
                        new_state = TaskInstanceState.SKIPPED
                    elif removed and success and ti.map_index > -1:
                        if ti.map_index >= success:
                            new_state = TaskInstanceState.REMOVED
                elif trigger_rule == TR.ALL_FAILED:
                    if success or skipped:
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.ONE_SUCCESS:
                    if upstream_done and done == skipped:
                        # if upstream is done and all are skipped mark as skipped
                        new_state = TaskInstanceState.SKIPPED
                    elif upstream_done and success <= 0:
                        # if upstream is done and there are no success mark as upstream failed
                        new_state = TaskInstanceState.UPSTREAM_FAILED
                elif trigger_rule == TR.ONE_FAILED:
                    if upstream_done and not (failed or upstream_failed):
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.ONE_DONE:
                    if upstream_done and not (failed or success):
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.NONE_FAILED:
                    if upstream_failed or failed:
                        new_state = TaskInstanceState.UPSTREAM_FAILED
                elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
                    if upstream_failed or failed:
                        new_state = TaskInstanceState.UPSTREAM_FAILED
                    elif skipped == upstream:
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.NONE_SKIPPED:
                    if skipped:
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.ALL_SKIPPED:
                    if success or failed or upstream_failed:
                        new_state = TaskInstanceState.SKIPPED
                elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
                    if upstream_done and upstream_setup and skipped_setup >= upstream_setup:
                        # when there is an upstream setup and they have all skipped, then skip
                        new_state = TaskInstanceState.SKIPPED
                    elif upstream_done and upstream_setup and success_setup == 0:
                        # when there is an upstream setup, if none succeeded, mark upstream failed
                        # if at least one setup ran, we'll let it run
                        new_state = TaskInstanceState.UPSTREAM_FAILED
            if new_state is not None:
                if (
                    new_state == TaskInstanceState.SKIPPED
                    and dep_context.wait_for_past_depends_before_skipping
                ):
                    # Don't skip until the past-depends flag has been recorded
                    # (it is communicated via XCom by the depends-on-past dep).
                    past_depends_met = ti.xcom_pull(
                        task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
                    )
                    if not past_depends_met:
                        yield self._failing_status(
                            reason=("Task should be skipped but the past depends are not met")
                        )
                        return
                changed = ti.set_state(new_state, session)

            if changed:
                dep_context.have_changed_ti_states = True

            # Second pass: report a failing dep status when the rule's
            # requirement is not met by the current upstream tallies.
            if trigger_rule == TR.ONE_SUCCESS:
                if success <= 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires one upstream task success, "
                            f"but none were found. upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ONE_FAILED:
                if not failed and not upstream_failed:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, "
                            f"but none were found. upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ONE_DONE:
                if success + failed <= 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' "
                            "requires at least one upstream task failure or success "
                            f"but none were failed or success. upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ALL_SUCCESS:
                num_failures = upstream - success
                if ti.map_index > -1:
                    # Removed upstream tis don't count against an expanded ti.
                    num_failures -= removed
                if num_failures > 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                            f"succeeded, but found {num_failures} non-success(es). "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ALL_FAILED:
                num_success = upstream - failed - upstream_failed
                if ti.map_index > -1:
                    num_success -= removed
                if num_success > 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks "
                            f"to have failed, but found {num_success} non-failure(s). "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ALL_DONE:
                if not upstream_done:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                            f"completed, but found {len(upstream_tasks) - done} task(s) that were "
                            f"not done. upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.NONE_FAILED or trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
                num_failures = upstream - success - skipped
                if ti.map_index > -1:
                    num_failures -= removed
                if num_failures > 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                            f"succeeded or been skipped, but found {num_failures} non-success(es). "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.NONE_SKIPPED:
                if not upstream_done or (skipped > 0):
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not "
                            f"have been skipped, but found {skipped} task(s) skipped. "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ALL_SKIPPED:
                num_non_skipped = upstream - skipped
                if num_non_skipped > 0:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been "
                            f"skipped, but found {num_non_skipped} task(s) in non skipped state. "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
                if not upstream_done:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
                            f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. "
                            f"upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
                elif upstream_setup and not success_setup:
                    yield self._failing_status(
                        reason=(
                            f"Task's trigger rule '{trigger_rule}' requires at least one upstream setup task "
                            f"be successful, but found {upstream_setup - success_setup} task(s) that were "
                            f"not. upstream_states={upstream_states}, "
                            f"upstream_task_ids={task.upstream_task_ids}"
                        )
                    )
            else:
                yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.")

        if TYPE_CHECKING:
            assert ti.task

        if not ti.task.is_teardown:
            # a teardown cannot have any indirect setups
            relevant_setups = {t.task_id: t for t in ti.task.get_upstreams_only_setups()}
            if relevant_setups:
                for status, changed in _evaluate_setup_constraint(relevant_setups=relevant_setups):
                    yield status
                    if not status.passed and changed:
                        # no need to evaluate trigger rule; we've already marked as skipped or failed
                        return

        yield from _evaluate_direct_relatives()