Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/taskinstance.py: 22%


1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections.abc 

21import contextlib 

22import hashlib 

23import itertools 

24import logging 

25import math 

26import operator 

27import os 

28import signal 

29import warnings 

30from collections import defaultdict 

31from contextlib import nullcontext 

32from datetime import timedelta 

33from enum import Enum 

34from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple 

35from urllib.parse import quote 

36 

37import dill 

38import jinja2 

39import lazy_object_proxy 

40import pendulum 

41from deprecated import deprecated 

42from jinja2 import TemplateAssertionError, UndefinedError 

43from sqlalchemy import ( 

44 Column, 

45 DateTime, 

46 Float, 

47 ForeignKey, 

48 ForeignKeyConstraint, 

49 Index, 

50 Integer, 

51 PrimaryKeyConstraint, 

52 String, 

53 Text, 

54 and_, 

55 delete, 

56 false, 

57 func, 

58 inspect, 

59 or_, 

60 text, 

61 update, 

62) 

63from sqlalchemy.ext.associationproxy import association_proxy 

64from sqlalchemy.ext.mutable import MutableDict 

65from sqlalchemy.orm import lazyload, reconstructor, relationship 

66from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value 

67from sqlalchemy.sql.expression import case, select 

68 

69from airflow import settings 

70from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call 

71from airflow.compat.functools import cache 

72from airflow.configuration import conf 

73from airflow.datasets import Dataset 

74from airflow.datasets.manager import dataset_manager 

75from airflow.exceptions import ( 

76 AirflowException, 

77 AirflowFailException, 

78 AirflowRescheduleException, 

79 AirflowSensorTimeout, 

80 AirflowSkipException, 

81 AirflowTaskTerminated, 

82 AirflowTaskTimeout, 

83 DagRunNotFound, 

84 RemovedInAirflow3Warning, 

85 TaskDeferred, 

86 UnmappableXComLengthPushed, 

87 UnmappableXComTypePushed, 

88 XComForMappingNotPushed, 

89) 

90from airflow.listeners.listener import get_listener_manager 

91from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel 

92from airflow.models.dagbag import DagBag 

93from airflow.models.log import Log 

94from airflow.models.mappedoperator import MappedOperator 

95from airflow.models.param import process_params 

96from airflow.models.renderedtifields import get_serialized_template_fields 

97from airflow.models.taskfail import TaskFail 

98from airflow.models.taskinstancekey import TaskInstanceKey 

99from airflow.models.taskmap import TaskMap 

100from airflow.models.taskreschedule import TaskReschedule 

101from airflow.models.xcom import LazyXComSelectSequence, XCom 

102from airflow.plugins_manager import integrate_macros_plugins 

103from airflow.sentry import Sentry 

104from airflow.settings import task_instance_mutation_hook 

105from airflow.stats import Stats 

106from airflow.templates import SandboxedEnvironment 

107from airflow.ti_deps.dep_context import DepContext 

108from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS 

109from airflow.utils import timezone 

110from airflow.utils.context import ( 

111 ConnectionAccessor, 

112 Context, 

113 InletEventsAccessors, 

114 OutletEventAccessors, 

115 VariableAccessor, 

116 context_get_outlet_events, 

117 context_merge, 

118) 

119from airflow.utils.email import send_email 

120from airflow.utils.helpers import prune_dict, render_template_to_string 

121from airflow.utils.log.logging_mixin import LoggingMixin 

122from airflow.utils.net import get_hostname 

123from airflow.utils.operator_helpers import ExecutionCallableRunner, context_to_airflow_vars 

124from airflow.utils.platform import getuser 

125from airflow.utils.retries import run_with_db_retries 

126from airflow.utils.session import NEW_SESSION, create_session, provide_session 

127from airflow.utils.sqlalchemy import ( 

128 ExecutorConfigType, 

129 ExtendedJSON, 

130 UtcDateTime, 

131 tuple_in_condition, 

132 with_row_locks, 

133) 

134from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState 

135from airflow.utils.task_group import MappedTaskGroup 

136from airflow.utils.task_instance_session import set_current_task_instance_session 

137from airflow.utils.timeout import timeout 

138from airflow.utils.xcom import XCOM_RETURN_KEY 

139 

140TR = TaskReschedule 

141 

142_CURRENT_CONTEXT: list[Context] = [] 

143log = logging.getLogger(__name__) 

144 

145 

146if TYPE_CHECKING: 

147 from datetime import datetime 

148 from pathlib import PurePath 

149 from types import TracebackType 

150 

151 from sqlalchemy.orm.session import Session 

152 from sqlalchemy.sql.elements import BooleanClauseList 

153 from sqlalchemy.sql.expression import ColumnOperators 

154 

155 from airflow.models.abstractoperator import TaskStateChangeCallback 

156 from airflow.models.baseoperator import BaseOperator 

157 from airflow.models.dag import DAG, DagModel 

158 from airflow.models.dagrun import DagRun 

159 from airflow.models.dataset import DatasetEvent 

160 from airflow.models.operator import Operator 

161 from airflow.serialization.pydantic.dag import DagModelPydantic 

162 from airflow.serialization.pydantic.dataset import DatasetEventPydantic 

163 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic 

164 from airflow.timetables.base import DataInterval 

165 from airflow.typing_compat import Literal, TypeGuard 

166 from airflow.utils.task_group import TaskGroup 

167 

168 # This is a workaround because mypy doesn't work with hybrid_property 

169 # TODO: remove this hack and move hybrid_property back to main import block 

170 # See https://github.com/python/mypy/issues/4430 

171 hybrid_property = property 

172else: 

173 from sqlalchemy.ext.hybrid import hybrid_property 

174 

175 

176PAST_DEPENDS_MET = "past_depends_met" 

177 

178 

179class TaskReturnCode(Enum): 

180 """ 

181 Enum to signal manner of exit for task run command. 

182 

183 :meta private: 

184 """ 

185 

186 DEFERRED = 100 

187 """When task exits with deferral to trigger.""" 

188 

189 
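# --- Illustrative sketch, not part of this module ---
# How a supervising process might interpret the exit code of `airflow tasks run`;
# `run_task_subprocess` is a hypothetical helper used only for illustration.
#
#   return_code = run_task_subprocess()
#   if return_code == TaskReturnCode.DEFERRED.value:  # 100
#       log.info("Task deferred itself to the triggerer; this is not a failure.")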

190@contextlib.contextmanager 

191def set_current_context(context: Context) -> Generator[Context, None, None]: 

192 """ 

193 Set the current execution context to the provided context object. 

194 

195 This method should be called once per Task execution, before calling operator.execute. 

196 """ 

197 _CURRENT_CONTEXT.append(context) 

198 try: 

199 yield context 

200 finally: 

201 expected_state = _CURRENT_CONTEXT.pop() 

202 if expected_state != context: 

203 log.warning( 

204 "Current context is not equal to the state at context stack. Expected=%s, got=%s", 

205 context, 

206 expected_state, 

207 ) 

208 

209 
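# --- Illustrative sketch, not part of this module ---
# Expected usage: push the context for exactly one task execution so that helpers
# such as airflow.operators.python.get_current_context() can read it while the
# operator code runs. The surrounding runner code is an assumption for illustration.
#
#   context = ti.get_template_context()
#   with set_current_context(context):
#       result = ti.task.execute(context=context)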

210def _stop_remaining_tasks(*, task_instance: TaskInstance | TaskInstancePydantic, session: Session): 

211 """ 

212 Stop non-teardown tasks in dag. 

213 

214 :meta private: 

215 """ 

216 if not task_instance.dag_run: 

217 raise ValueError("``task_instance`` must have ``dag_run`` set") 

218 tis = task_instance.dag_run.get_task_instances(session=session) 

219 if TYPE_CHECKING: 

220 assert task_instance.task 

221 assert isinstance(task_instance.task.dag, DAG) 

222 

223 for ti in tis: 

224 if ti.task_id == task_instance.task_id or ti.state in ( 

225 TaskInstanceState.SUCCESS, 

226 TaskInstanceState.FAILED, 

227 ): 

228 continue 

229 task = task_instance.task.dag.task_dict[ti.task_id] 

230 if not task.is_teardown: 

231 if ti.state == TaskInstanceState.RUNNING: 

232 log.info("Forcing task %s to fail due to dag's `fail_stop` setting", ti.task_id) 

233 ti.error(session) 

234 else: 

235 log.info("Setting task %s to SKIPPED due to dag's `fail_stop` setting.", ti.task_id) 

236 ti.set_state(state=TaskInstanceState.SKIPPED, session=session) 

237 else: 

238 log.info("Not skipping teardown task '%s'", ti.task_id) 

239 

240 

241def clear_task_instances( 

242 tis: list[TaskInstance], 

243 session: Session, 

244 activate_dag_runs: None = None, 

245 dag: DAG | None = None, 

246 dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED, 

247) -> None: 

248 """ 

249 Clear a set of task instances, but make sure the running ones get killed. 

250 

251 Also sets Dagrun's `state` to QUEUED and `start_date` to the time of execution. 

252 But only for finished DRs (SUCCESS and FAILED). 

253 Doesn't clear DR's `state` and `start_date` for running 

254 DRs (QUEUED and RUNNING) because clearing the state for already 

255 running DR is redundant and clearing `start_date` affects DR's duration. 

256 

257 :param tis: a list of task instances 

258 :param session: current session 

259 :param dag_run_state: state to set finished DagRuns to. 

260 If set to False, DagRuns state will not be changed. 

261 :param dag: DAG object 

262 :param activate_dag_runs: Deprecated parameter, do not pass 

263 """ 

264 job_ids = [] 

265 # Keys: dag_id -> run_id -> map_indexes -> try_numbers -> task_id 

266 task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict( 

267 lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set))) 

268 ) 

269 dag_bag = DagBag(read_dags_from_db=True) 

270 for ti in tis: 

271 if ti.state == TaskInstanceState.RUNNING: 

272 if ti.job_id: 

273 # If a task is cleared when running, set its state to RESTARTING so that 

274 # the task is terminated and becomes eligible for retry. 

275 ti.state = TaskInstanceState.RESTARTING 

276 job_ids.append(ti.job_id) 

277 else: 

278 ti_dag = dag if dag and dag.dag_id == ti.dag_id else dag_bag.get_dag(ti.dag_id, session=session) 

279 task_id = ti.task_id 

280 if ti_dag and ti_dag.has_task(task_id): 

281 task = ti_dag.get_task(task_id) 

282 ti.refresh_from_task(task) 

283 if TYPE_CHECKING: 

284 assert ti.task 

285 ti.max_tries = ti.try_number + task.retries 

286 else: 

287 # Ignore errors when updating max_tries if the DAG or 

288 # task are not found since database records could be 

289 # outdated. We make max_tries the maximum value of its 

290 # original max_tries or the last attempted try number. 

291 ti.max_tries = max(ti.max_tries, ti.try_number) 

292 ti.state = None 

293 ti.external_executor_id = None 

294 ti.clear_next_method_args() 

295 session.merge(ti) 

296 

297 task_id_by_key[ti.dag_id][ti.run_id][ti.map_index][ti.try_number].add(ti.task_id) 

298 

299 if task_id_by_key: 

300 # Clear all reschedules related to the ti to clear 

301 

302 # This is an optimization for the common case where all tis are for a small number 

303 # of dag_id, run_id, try_number, and map_index. Use a nested dict of dag_id, 

304 # run_id, try_number, map_index, and task_id to construct the where clause in a 

305 # hierarchical manner. This speeds up the delete statement by more than 40x for 

306 # large number of tis (50k+). 

307 conditions = or_( 

308 and_( 

309 TR.dag_id == dag_id, 

310 or_( 

311 and_( 

312 TR.run_id == run_id, 

313 or_( 

314 and_( 

315 TR.map_index == map_index, 

316 or_( 

317 and_(TR.try_number == try_number, TR.task_id.in_(task_ids)) 

318 for try_number, task_ids in task_tries.items() 

319 ), 

320 ) 

321 for map_index, task_tries in map_indexes.items() 

322 ), 

323 ) 

324 for run_id, map_indexes in run_ids.items() 

325 ), 

326 ) 

327 for dag_id, run_ids in task_id_by_key.items() 

328 ) 

329 

330 delete_qry = TR.__table__.delete().where(conditions) 

331 session.execute(delete_qry) 

332 

333 if job_ids: 

334 from airflow.jobs.job import Job 

335 

336 session.execute(update(Job).where(Job.id.in_(job_ids)).values(state=JobState.RESTARTING)) 

337 

338 if activate_dag_runs is not None: 

339 warnings.warn( 

340 "`activate_dag_runs` parameter to clear_task_instances function is deprecated. " 

341 "Please use `dag_run_state`", 

342 RemovedInAirflow3Warning, 

343 stacklevel=2, 

344 ) 

345 if not activate_dag_runs: 

346 dag_run_state = False 

347 

348 if dag_run_state is not False and tis: 

349 from airflow.models.dagrun import DagRun # Avoid circular import 

350 

351 run_ids_by_dag_id = defaultdict(set) 

352 for instance in tis: 

353 run_ids_by_dag_id[instance.dag_id].add(instance.run_id) 

354 

355 drs = ( 

356 session.query(DagRun) 

357 .filter( 

358 or_( 

359 and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids)) 

360 for dag_id, run_ids in run_ids_by_dag_id.items() 

361 ) 

362 ) 

363 .all() 

364 ) 

365 dag_run_state = DagRunState(dag_run_state) # Validate the state value. 

366 for dr in drs: 

367 if dr.state in State.finished_dr_states: 

368 dr.state = dag_run_state 

369 dr.start_date = timezone.utcnow() 

370 if dag_run_state == DagRunState.QUEUED: 

371 dr.last_scheduling_decision = None 

372 dr.start_date = None 

373 dr.clear_number += 1 

374 session.flush() 

375 

376 
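# --- Illustrative sketch, not part of this module ---
# Shape of the ``task_id_by_key`` dict built in clear_task_instances for two TIs
# sharing the same dag_id/run_id/map_index/try_number (all values illustrative):
#
#   {"my_dag": {"scheduled__2024-01-01": {-1: {2: {"extract", "load"}}}}}
#
# The nested or_/and_ comprehension then emits roughly this predicate, grouping
# shared keys instead of one clause per TI, which is where the ~40x speedup for
# large clears comes from:
#
#   task_reschedule.dag_id = 'my_dag'
#   AND task_reschedule.run_id = 'scheduled__2024-01-01'
#   AND task_reschedule.map_index = -1
#   AND task_reschedule.try_number = 2
#   AND task_reschedule.task_id IN ('extract', 'load')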

377def _is_mappable_value(value: Any) -> TypeGuard[Collection]: 

378 """Whether a value can be used for task mapping. 

379 

380 We only allow collections with guaranteed ordering, but exclude character 

381 sequences since that's usually not what users would expect to be mappable. 

382 """ 

383 if not isinstance(value, (collections.abc.Sequence, dict)): 

384 return False 

385 if isinstance(value, (bytearray, bytes, str)): 

386 return False 

387 return True 

388 

389 
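# --- Illustrative sketch, not part of this module ---
# What _is_mappable_value accepts and rejects:
#
#   assert _is_mappable_value([1, 2, 3]) is True     # ordered sequence
#   assert _is_mappable_value({"a": 1}) is True      # dict (insertion-ordered)
#   assert _is_mappable_value("abc") is False        # character sequence, excluded
#   assert _is_mappable_value({1, 2, 3}) is False    # set: no guaranteed ordering
#   assert _is_mappable_value(42) is False           # scalar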

390def _creator_note(val): 

391 """Creator the ``note`` association proxy.""" 

392 if isinstance(val, str): 

393 return TaskInstanceNote(content=val) 

394 elif isinstance(val, dict): 

395 return TaskInstanceNote(**val) 

396 else: 

397 return TaskInstanceNote(*val) 

398 

399 

400def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: Context, task_orig: Operator): 

401 """ 

402 Execute Task (optionally with a Timeout) and push Xcom results. 

403 

404 :param task_instance: the task instance 

405 :param context: Jinja2 context 

406 :param task_orig: origin task 

407 

408 :meta private: 

409 """ 

410 task_to_execute = task_instance.task 

411 

412 if TYPE_CHECKING: 

413 assert task_to_execute 

414 

415 if isinstance(task_to_execute, MappedOperator): 

416 raise AirflowException("MappedOperator cannot be executed.") 

417 

418 # If the task has been deferred and is being executed due to a trigger, 

419 # then we need to pick the right method to come back to, otherwise 

420 # we go for the default execute 

421 execute_callable_kwargs: dict[str, Any] = {} 

422 execute_callable: Callable 

423 if task_instance.next_method: 

424 if task_instance.next_method == "execute": 

425 if not task_instance.next_kwargs: 

426 task_instance.next_kwargs = {} 

427 task_instance.next_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel 

428 execute_callable = task_to_execute.resume_execution 

429 execute_callable_kwargs["next_method"] = task_instance.next_method 

430 execute_callable_kwargs["next_kwargs"] = task_instance.next_kwargs 

431 else: 

432 execute_callable = task_to_execute.execute 

433 if execute_callable.__name__ == "execute": 

434 execute_callable_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel 

435 

436 def _execute_callable(context: Context, **execute_callable_kwargs): 

437 try: 

438 # Print a marker for log grouping of details before task execution 

439 log.info("::endgroup::") 

440 

441 return ExecutionCallableRunner( 

442 execute_callable, 

443 context_get_outlet_events(context), 

444 logger=log, 

445 ).run(context=context, **execute_callable_kwargs) 

446 except SystemExit as e: 

447 # Handle only successful cases here. Failure cases will be handled higher 

448 # in the exception chain. 

449 if e.code is not None and e.code != 0: 

450 raise 

451 return None 

452 finally: 

453 # Print a marker post execution for internals of post task processing 

454 log.info("::group::Post task execution logs") 

455 

456 # If a timeout is specified for the task, make it fail 

457 # if it goes beyond 

458 if task_to_execute.execution_timeout: 

459 # If we are coming in with a next_method (i.e. from a deferral), 

460 # calculate the timeout from our start_date. 

461 if task_instance.next_method and task_instance.start_date: 

462 timeout_seconds = ( 

463 task_to_execute.execution_timeout - (timezone.utcnow() - task_instance.start_date) 

464 ).total_seconds() 

465 else: 

466 timeout_seconds = task_to_execute.execution_timeout.total_seconds() 

467 try: 

468 # It's possible we're already timed out, so fast-fail if true 

469 if timeout_seconds <= 0: 

470 raise AirflowTaskTimeout() 

471 # Run task in timeout wrapper 

472 with timeout(timeout_seconds): 

473 result = _execute_callable(context=context, **execute_callable_kwargs) 

474 except AirflowTaskTimeout: 

475 task_to_execute.on_kill() 

476 raise 

477 else: 

478 result = _execute_callable(context=context, **execute_callable_kwargs) 

479 cm = nullcontext() if InternalApiConfig.get_use_internal_api() else create_session() 

480 with cm as session_or_null: 

481 if task_to_execute.do_xcom_push: 

482 xcom_value = result 

483 else: 

484 xcom_value = None 

485 if xcom_value is not None: # If the task returns a result, push an XCom containing it. 

486 if task_to_execute.multiple_outputs: 

487 if not isinstance(xcom_value, Mapping): 

488 raise AirflowException( 

489 f"Returned output was type {type(xcom_value)} " 

490 "expected dictionary for multiple_outputs" 

491 ) 

492 for key in xcom_value.keys(): 

493 if not isinstance(key, str): 

494 raise AirflowException( 

495 "Returned dictionary keys must be strings when using " 

496 f"multiple_outputs, found {key} ({type(key)}) instead" 

497 ) 

498 for key, value in xcom_value.items(): 

499 task_instance.xcom_push(key=key, value=value, session=session_or_null) 

500 task_instance.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session_or_null) 

501 _record_task_map_for_downstreams( 

502 task_instance=task_instance, task=task_orig, value=xcom_value, session=session_or_null 

503 ) 

504 return result 

505 

506 
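# --- Illustrative sketch, not part of this module ---
# Effect of the XCom-push block above for a TaskFlow task declared with
# multiple_outputs=True (the task name is illustrative):
#
#   from airflow.decorators import task
#
#   @task(multiple_outputs=True)
#   def summarize() -> dict[str, int]:
#       return {"rows": 10, "errors": 0}
#
# The return value leads to three pushes: key="rows", key="errors", and
# key=XCOM_RETURN_KEY holding the whole dict. Returning a non-mapping, or a
# mapping with non-string keys, raises AirflowException instead.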

507def _refresh_from_db( 

508 *, 

509 task_instance: TaskInstance | TaskInstancePydantic, 

510 session: Session | None = None, 

511 lock_for_update: bool = False, 

512) -> None: 

513 """ 

514 Refresh the task instance from the database based on the primary key. 

515 

516 :param task_instance: the task instance 

517 :param session: SQLAlchemy ORM Session 

518 :param lock_for_update: if True, indicates that the database should 

519 lock the TaskInstance (issuing a FOR UPDATE clause) until the 

520 session is committed. 

521 

522 :meta private: 

523 """ 

524 if session and task_instance in session: 

525 session.refresh(task_instance, TaskInstance.__mapper__.column_attrs.keys()) 

526 

527 ti = TaskInstance.get_task_instance( 

528 dag_id=task_instance.dag_id, 

529 task_id=task_instance.task_id, 

530 run_id=task_instance.run_id, 

531 map_index=task_instance.map_index, 

532 lock_for_update=lock_for_update, 

533 session=session, 

534 ) 

535 

536 if ti: 

537 # Fields ordered per model definition 

538 task_instance.start_date = ti.start_date 

539 task_instance.end_date = ti.end_date 

540 task_instance.duration = ti.duration 

541 task_instance.state = ti.state 

542 task_instance.try_number = ti.try_number 

543 task_instance.max_tries = ti.max_tries 

544 task_instance.hostname = ti.hostname 

545 task_instance.unixname = ti.unixname 

546 task_instance.job_id = ti.job_id 

547 task_instance.pool = ti.pool 

548 task_instance.pool_slots = ti.pool_slots or 1 

549 task_instance.queue = ti.queue 

550 task_instance.priority_weight = ti.priority_weight 

551 task_instance.operator = ti.operator 

552 task_instance.custom_operator_name = ti.custom_operator_name 

553 task_instance.queued_dttm = ti.queued_dttm 

554 task_instance.queued_by_job_id = ti.queued_by_job_id 

555 task_instance.pid = ti.pid 

556 task_instance.executor = ti.executor 

557 task_instance.executor_config = ti.executor_config 

558 task_instance.external_executor_id = ti.external_executor_id 

559 task_instance.trigger_id = ti.trigger_id 

560 task_instance.next_method = ti.next_method 

561 task_instance.next_kwargs = ti.next_kwargs 

562 else: 

563 task_instance.state = None 

564 

565 

566def _set_duration(*, task_instance: TaskInstance | TaskInstancePydantic) -> None: 

567 """ 

568 Set task instance duration. 

569 

570 :param task_instance: the task instance 

571 

572 :meta private: 

573 """ 

574 if task_instance.end_date and task_instance.start_date: 

575 task_instance.duration = (task_instance.end_date - task_instance.start_date).total_seconds() 

576 else: 

577 task_instance.duration = None 

578 log.debug("Task Duration set to %s", task_instance.duration) 

579 

580 

581def _stats_tags(*, task_instance: TaskInstance | TaskInstancePydantic) -> dict[str, str]: 

582 """ 

583 Return task instance tags. 

584 

585 :param task_instance: the task instance 

586 

587 :meta private: 

588 """ 

589 return prune_dict({"dag_id": task_instance.dag_id, "task_id": task_instance.task_id}) 

590 

591 

592def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydantic) -> None: 

593 """ 

594 Unset next_method and next_kwargs so that any retries don't reuse them. 

595 

596 :param task_instance: the task instance 

597 

598 :meta private: 

599 """ 

600 log.debug("Clearing next_method and next_kwargs.") 

601 

602 task_instance.next_method = None 

603 task_instance.next_kwargs = None 

604 

605 

606@internal_api_call 

607def _get_template_context( 

608 *, 

609 task_instance: TaskInstance | TaskInstancePydantic, 

610 session: Session | None = None, 

611 ignore_param_exceptions: bool = True, 

612) -> Context: 

613 """ 

614 Return TI Context. 

615 

616 :param task_instance: the task instance 

617 :param session: SQLAlchemy ORM Session 

618 :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict 

619 

620 :meta private: 

621 """ 

622 # Do not use provide_session here -- it expunges everything on exit! 

623 if not session: 

624 session = settings.Session() 

625 

626 from airflow import macros 

627 from airflow.models.abstractoperator import NotMapped 

628 

629 integrate_macros_plugins() 

630 

631 task = task_instance.task 

632 if TYPE_CHECKING: 

633 assert task_instance.task 

634 assert task 

635 assert task.dag 

636 try: 

637 dag: DAG = task.dag 

638 except AirflowException: 

639 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic 

640 

641 if isinstance(task_instance, TaskInstancePydantic): 

642 ti = session.scalar( 

643 select(TaskInstance).where( 

644 TaskInstance.task_id == task_instance.task_id, 

645 TaskInstance.dag_id == task_instance.dag_id, 

646 TaskInstance.run_id == task_instance.run_id, 

647 TaskInstance.map_index == task_instance.map_index, 

648 ) 

649 ) 

650 dag = ti.dag_model.serialized_dag.dag 

651 if hasattr(task_instance.task, "_dag"): # BaseOperator 

652 task_instance.task._dag = dag 

653 else: # MappedOperator 

654 task_instance.task.dag = dag 

655 else: 

656 raise 

657 dag_run = task_instance.get_dagrun(session) 

658 data_interval = dag.get_run_data_interval(dag_run) 

659 

660 validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions) 

661 

662 logical_date: DateTime = timezone.coerce_datetime(task_instance.execution_date) 

663 ds = logical_date.strftime("%Y-%m-%d") 

664 ds_nodash = ds.replace("-", "") 

665 ts = logical_date.isoformat() 

666 ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S") 

667 ts_nodash_with_tz = ts.replace("-", "").replace(":", "") 

668 

669 @cache # Prevent multiple database access. 

670 def _get_previous_dagrun_success() -> DagRun | None: 

671 return task_instance.get_previous_dagrun(state=DagRunState.SUCCESS, session=session) 

672 

673 def _get_previous_dagrun_data_interval_success() -> DataInterval | None: 

674 dagrun = _get_previous_dagrun_success() 

675 if dagrun is None: 

676 return None 

677 return dag.get_run_data_interval(dagrun) 

678 

679 def get_prev_data_interval_start_success() -> pendulum.DateTime | None: 

680 data_interval = _get_previous_dagrun_data_interval_success() 

681 if data_interval is None: 

682 return None 

683 return data_interval.start 

684 

685 def get_prev_data_interval_end_success() -> pendulum.DateTime | None: 

686 data_interval = _get_previous_dagrun_data_interval_success() 

687 if data_interval is None: 

688 return None 

689 return data_interval.end 

690 

691 def get_prev_start_date_success() -> pendulum.DateTime | None: 

692 dagrun = _get_previous_dagrun_success() 

693 if dagrun is None: 

694 return None 

695 return timezone.coerce_datetime(dagrun.start_date) 

696 

697 def get_prev_end_date_success() -> pendulum.DateTime | None: 

698 dagrun = _get_previous_dagrun_success() 

699 if dagrun is None: 

700 return None 

701 return timezone.coerce_datetime(dagrun.end_date) 

702 

703 @cache 

704 def get_yesterday_ds() -> str: 

705 return (logical_date - timedelta(1)).strftime("%Y-%m-%d") 

706 

707 def get_yesterday_ds_nodash() -> str: 

708 return get_yesterday_ds().replace("-", "") 

709 

710 @cache 

711 def get_tomorrow_ds() -> str: 

712 return (logical_date + timedelta(1)).strftime("%Y-%m-%d") 

713 

714 def get_tomorrow_ds_nodash() -> str: 

715 return get_tomorrow_ds().replace("-", "") 

716 

717 @cache 

718 def get_next_execution_date() -> pendulum.DateTime | None: 

719 # For manually triggered dagruns that aren't run on a schedule, 

720 # the "next" execution date doesn't make sense, and should be set 

721 # to execution date for consistency with how execution_date is set 

722 # for manually triggered tasks, i.e. triggered_date == execution_date. 

723 if dag_run.external_trigger: 

724 return logical_date 

725 if dag is None: 

726 return None 

727 next_info = dag.next_dagrun_info(data_interval, restricted=False) 

728 if next_info is None: 

729 return None 

730 return timezone.coerce_datetime(next_info.logical_date) 

731 

732 def get_next_ds() -> str | None: 

733 execution_date = get_next_execution_date() 

734 if execution_date is None: 

735 return None 

736 return execution_date.strftime("%Y-%m-%d") 

737 

738 def get_next_ds_nodash() -> str | None: 

739 ds = get_next_ds() 

740 if ds is None: 

741 return ds 

742 return ds.replace("-", "") 

743 

744 @cache 

745 def get_prev_execution_date(): 

746 # For manually triggered dagruns that aren't run on a schedule, 

747 # the "previous" execution date doesn't make sense, and should be set 

748 # to execution date for consistency with how execution_date is set 

749 # for manually triggered tasks, i.e. triggered_date == execution_date. 

750 if dag_run.external_trigger: 

751 return logical_date 

752 with warnings.catch_warnings(): 

753 warnings.simplefilter("ignore", RemovedInAirflow3Warning) 

754 return dag.previous_schedule(logical_date) 

755 

756 @cache 

757 def get_prev_ds() -> str | None: 

758 execution_date = get_prev_execution_date() 

759 if execution_date is None: 

760 return None 

761 return execution_date.strftime("%Y-%m-%d") 

762 

763 def get_prev_ds_nodash() -> str | None: 

764 prev_ds = get_prev_ds() 

765 if prev_ds is None: 

766 return None 

767 return prev_ds.replace("-", "") 

768 

769 def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydantic]]: 

770 if TYPE_CHECKING: 

771 assert session is not None 

772 

773 # The dag_run may not be attached to the session anymore since the 

774 # code base is over-zealous with use of session.expunge_all(). 

775 # Re-attach it if we get called. 

776 nonlocal dag_run 

777 if dag_run not in session: 

778 dag_run = session.merge(dag_run, load=False) 

779 dataset_events = dag_run.consumed_dataset_events 

780 triggering_events: dict[str, list[DatasetEvent | DatasetEventPydantic]] = defaultdict(list) 

781 for event in dataset_events: 

782 if event.dataset: 

783 triggering_events[event.dataset.uri].append(event) 

784 

785 return triggering_events 

786 

787 try: 

788 expanded_ti_count: int | None = task.get_mapped_ti_count(task_instance.run_id, session=session) 

789 except NotMapped: 

790 expanded_ti_count = None 

791 

792 # NOTE: If you add to this dict, make sure to also update the following: 

793 # * Context in airflow/utils/context.pyi 

794 # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py 

795 # * Table in docs/apache-airflow/templates-ref.rst 

796 context: dict[str, Any] = { 

797 "conf": conf, 

798 "dag": dag, 

799 "dag_run": dag_run, 

800 "data_interval_end": timezone.coerce_datetime(data_interval.end), 

801 "data_interval_start": timezone.coerce_datetime(data_interval.start), 

802 "outlet_events": OutletEventAccessors(), 

803 "ds": ds, 

804 "ds_nodash": ds_nodash, 

805 "execution_date": logical_date, 

806 "expanded_ti_count": expanded_ti_count, 

807 "inlets": task.inlets, 

808 "inlet_events": InletEventsAccessors(task.inlets, session=session), 

809 "logical_date": logical_date, 

810 "macros": macros, 

811 "map_index_template": task.map_index_template, 

812 "next_ds": get_next_ds(), 

813 "next_ds_nodash": get_next_ds_nodash(), 

814 "next_execution_date": get_next_execution_date(), 

815 "outlets": task.outlets, 

816 "params": validated_params, 

817 "prev_data_interval_start_success": get_prev_data_interval_start_success(), 

818 "prev_data_interval_end_success": get_prev_data_interval_end_success(), 

819 "prev_ds": get_prev_ds(), 

820 "prev_ds_nodash": get_prev_ds_nodash(), 

821 "prev_execution_date": get_prev_execution_date(), 

822 "prev_execution_date_success": task_instance.get_previous_execution_date( 

823 state=DagRunState.SUCCESS, 

824 session=session, 

825 ), 

826 "prev_start_date_success": get_prev_start_date_success(), 

827 "prev_end_date_success": get_prev_end_date_success(), 

828 "run_id": task_instance.run_id, 

829 "task": task, 

830 "task_instance": task_instance, 

831 "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}", 

832 "test_mode": task_instance.test_mode, 

833 "ti": task_instance, 

834 "tomorrow_ds": get_tomorrow_ds(), 

835 "tomorrow_ds_nodash": get_tomorrow_ds_nodash(), 

836 "triggering_dataset_events": lazy_object_proxy.Proxy(get_triggering_events), 

837 "ts": ts, 

838 "ts_nodash": ts_nodash, 

839 "ts_nodash_with_tz": ts_nodash_with_tz, 

840 "var": { 

841 "json": VariableAccessor(deserialize_json=True), 

842 "value": VariableAccessor(deserialize_json=False), 

843 }, 

844 "conn": ConnectionAccessor(), 

845 "yesterday_ds": get_yesterday_ds(), 

846 "yesterday_ds_nodash": get_yesterday_ds_nodash(), 

847 } 

848 # Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it 

849 # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890 

850 return Context(context) # type: ignore 

851 

852 
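# --- Illustrative sketch, not part of this module ---
# How the date-derived context keys relate to logical_date, with a fixed
# timestamp chosen purely for illustration:
#
#   import pendulum
#   logical_date = pendulum.datetime(2024, 1, 2, 3, 4, 5, tz="UTC")
#   logical_date.strftime("%Y-%m-%d")                  # ds -> '2024-01-02'
#   logical_date.strftime("%Y-%m-%d").replace("-", "") # ds_nodash -> '20240102'
#   logical_date.isoformat()                           # ts -> '2024-01-02T03:04:05+00:00'
#   logical_date.strftime("%Y%m%dT%H%M%S")             # ts_nodash -> '20240102T030405'
#   logical_date.isoformat().replace("-", "").replace(":", "")  # ts_nodash_with_tz -> '20240102T030405+0000'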

853def _is_eligible_to_retry(*, task_instance: TaskInstance | TaskInstancePydantic): 

854 """ 

855 Whether the task instance is eligible for retry. 

856 

857 :param task_instance: the task instance 

858 

859 :meta private: 

860 """ 

861 if task_instance.state == TaskInstanceState.RESTARTING: 

862 # If a task is cleared when running, it goes into RESTARTING state and is always 

863 # eligible for retry 

864 return True 

865 if not getattr(task_instance, "task", None): 

866 # Couldn't load the task, don't know number of retries, guess: 

867 return task_instance.try_number <= task_instance.max_tries 

868 

869 if TYPE_CHECKING: 

870 assert task_instance.task 

871 

872 return task_instance.task.retries and task_instance.try_number <= task_instance.max_tries 

873 

874 
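# --- Illustrative sketch, not part of this module ---
# Worked example of the check above: a task created with retries=2 starts with
# max_tries=2. After the first failed attempt try_number is 1 (1 <= 2, retry);
# after the second it is 2 (still <= 2, retry); after the third it is 3 (> 2),
# so the failure is final -- i.e. retries=2 allows up to three attempts.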

875def _handle_failure( 

876 *, 

877 task_instance: TaskInstance | TaskInstancePydantic, 

878 error: None | str | BaseException, 

879 session: Session, 

880 test_mode: bool | None = None, 

881 context: Context | None = None, 

882 force_fail: bool = False, 

883 fail_stop: bool = False, 

884) -> None: 

885 """ 

886 Handle Failure for a task instance. 

887 

888 :param task_instance: the task instance 

889 :param error: if specified, log the specific exception if thrown 

890 :param session: SQLAlchemy ORM Session 

891 :param test_mode: doesn't record success or failure in the DB if True 

892 :param context: Jinja2 context 

893 :param force_fail: if True, task does not retry 

894 

895 :meta private: 

896 """ 

897 if test_mode is None: 

898 test_mode = task_instance.test_mode 

899 

900 failure_context = TaskInstance.fetch_handle_failure_context( 

901 ti=task_instance, 

902 error=error, 

903 test_mode=test_mode, 

904 context=context, 

905 force_fail=force_fail, 

906 session=session, 

907 fail_stop=fail_stop, 

908 ) 

909 

910 _log_state(task_instance=task_instance, lead_msg="Immediate failure requested. " if force_fail else "") 

911 if ( 

912 failure_context["task"] 

913 and failure_context["email_for_state"](failure_context["task"]) 

914 and failure_context["task"].email 

915 ): 

916 try: 

917 task_instance.email_alert(error, failure_context["task"]) 

918 except Exception: 

919 log.exception("Failed to send email to: %s", failure_context["task"].email) 

920 

921 if failure_context["callbacks"] and failure_context["context"]: 

922 _run_finished_callback( 

923 callbacks=failure_context["callbacks"], 

924 context=failure_context["context"], 

925 ) 

926 

927 if not test_mode: 

928 TaskInstance.save_to_db(failure_context["ti"], session) 

929 

930 

931def _refresh_from_task( 

932 *, task_instance: TaskInstance | TaskInstancePydantic, task: Operator, pool_override: str | None = None 

933) -> None: 

934 """ 

935 Copy common attributes from the given task. 

936 

937 :param task_instance: the task instance 

938 :param task: The task object to copy from 

939 :param pool_override: Use the pool_override instead of task's pool 

940 

941 :meta private: 

942 """ 

943 task_instance.task = task 

944 task_instance.queue = task.queue 

945 task_instance.pool = pool_override or task.pool 

946 task_instance.pool_slots = task.pool_slots 

947 with contextlib.suppress(Exception): 

948 # This method is called from different places, and sometimes the TI is not fully initialized 

949 task_instance.priority_weight = task_instance.task.weight_rule.get_weight( 

950 task_instance # type: ignore[arg-type] 

951 ) 

952 task_instance.run_as_user = task.run_as_user 

953 # Do not set max_tries to task.retries here because max_tries is a cumulative 

954 # value that needs to be stored in the db. 

955 task_instance.executor = task.executor 

956 task_instance.executor_config = task.executor_config 

957 task_instance.operator = task.task_type 

958 task_instance.custom_operator_name = getattr(task, "custom_operator_name", None) 

959 # Re-apply cluster policy here so that task defaults do not overwrite previously applied values 

960 task_instance_mutation_hook(task_instance) 

961 

962 
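# --- Illustrative sketch, not part of this module ---
# task_instance_mutation_hook, re-applied above, is the cluster-policy hook a
# deployment can define in airflow_local_settings.py. The policy below is an
# assumption used only for illustration:
#
#   # airflow_local_settings.py
#   def task_instance_mutation_hook(task_instance):
#       if task_instance.try_number >= 2:
#           task_instance.queue = "retry_queue"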

963def _record_task_map_for_downstreams( 

964 *, task_instance: TaskInstance | TaskInstancePydantic, task: Operator, value: Any, session: Session 

965) -> None: 

966 """ 

967 Record the task map for downstream tasks. 

968 

969 :param task_instance: the task instance 

970 :param task: The task object 

971 :param value: The value 

972 :param session: SQLAlchemy ORM Session 

973 

974 :meta private: 

975 """ 

976 if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate. 

977 return 

978 # TODO: We don't push TaskMap for mapped task instances because it's not 

979 # currently possible for a downstream to depend on one individual mapped 

980 # task instance. This will change when we implement task mapping inside 

981 # a mapped task group, and we'll need to further analyze the case. 

982 if isinstance(task, MappedOperator): 

983 return 

984 if value is None: 

985 raise XComForMappingNotPushed() 

986 if not _is_mappable_value(value): 

987 raise UnmappableXComTypePushed(value) 

988 task_map = TaskMap.from_task_instance_xcom(task_instance, value) 

989 max_map_length = conf.getint("core", "max_map_length", fallback=1024) 

990 if task_map.length > max_map_length: 

991 raise UnmappableXComLengthPushed(value, max_map_length) 

992 session.merge(task_map) 

993 

994 
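# --- Illustrative sketch, not part of this module ---
# The validation above protects downstream dynamic task mapping. A hedged
# TaskFlow example (function names are illustrative):
#
#   from airflow.decorators import task
#
#   @task
#   def make_batches() -> list[int]:
#       return [1, 2, 3]   # a list is mappable; returning None raises
#                          # XComForMappingNotPushed, a plain str raises
#                          # UnmappableXComTypePushed
#
#   @task
#   def process(batch: int) -> None:
#       ...
#
#   process.expand(batch=make_batches())
#
# The pushed length is also capped by [core] max_map_length (default 1024).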

995def _get_previous_dagrun( 

996 *, 

997 task_instance: TaskInstance | TaskInstancePydantic, 

998 state: DagRunState | None = None, 

999 session: Session | None = None, 

1000) -> DagRun | None: 

1001 """ 

1002 Return the DagRun that ran prior to this task instance's DagRun. 

1003 

1004 :param task_instance: the task instance 

1005 :param state: If passed, it only takes into account instances of a specific state. 

1006 :param session: SQLAlchemy ORM Session. 

1007 

1008 :meta private: 

1009 """ 

1010 if TYPE_CHECKING: 

1011 assert task_instance.task 

1012 

1013 dag = task_instance.task.dag 

1014 if dag is None: 

1015 return None 

1016 

1017 dr = task_instance.get_dagrun(session=session) 

1018 dr.dag = dag 

1019 

1020 from airflow.models.dagrun import DagRun # Avoid circular import 

1021 

1022 # We always ignore schedule in dagrun lookup when `state` is given 

1023 # or the DAG is never scheduled. For legacy reasons, when 

1024 # `catchup=True`, we use `get_previous_scheduled_dagrun` unless 

1025 # `ignore_schedule` is `True`. 

1026 ignore_schedule = state is not None or not dag.timetable.can_be_scheduled 

1027 if dag.catchup is True and not ignore_schedule: 

1028 last_dagrun = DagRun.get_previous_scheduled_dagrun(dr.id, session=session) 

1029 else: 

1030 last_dagrun = DagRun.get_previous_dagrun(dag_run=dr, session=session, state=state) 

1031 

1032 if last_dagrun: 

1033 return last_dagrun 

1034 

1035 return None 

1036 

1037 

1038def _get_previous_execution_date( 

1039 *, 

1040 task_instance: TaskInstance | TaskInstancePydantic, 

1041 state: DagRunState | None, 

1042 session: Session, 

1043) -> pendulum.DateTime | None: 

1044 """ 

1045 Get execution date from property previous_ti_success. 

1046 

1047 :param task_instance: the task instance 

1048 :param session: SQLAlchemy ORM Session 

1049 :param state: If passed, it only takes into account instances of a specific state. 

1050 

1051 :meta private: 

1052 """ 

1053 log.debug("previous_execution_date was called") 

1054 prev_ti = task_instance.get_previous_ti(state=state, session=session) 

1055 return pendulum.instance(prev_ti.execution_date) if prev_ti and prev_ti.execution_date else None 

1056 

1057 

1058def _email_alert( 

1059 *, task_instance: TaskInstance | TaskInstancePydantic, exception, task: BaseOperator 

1060) -> None: 

1061 """ 

1062 Send alert email with exception information. 

1063 

1064 :param task_instance: the task instance 

1065 :param exception: the exception 

1066 :param task: task related to the exception 

1067 

1068 :meta private: 

1069 """ 

1070 subject, html_content, html_content_err = task_instance.get_email_subject_content(exception, task=task) 

1071 if TYPE_CHECKING: 

1072 assert task.email 

1073 try: 

1074 send_email(task.email, subject, html_content) 

1075 except Exception: 

1076 send_email(task.email, subject, html_content_err) 

1077 

1078 

1079def _get_email_subject_content( 

1080 *, 

1081 task_instance: TaskInstance | TaskInstancePydantic, 

1082 exception: BaseException, 

1083 task: BaseOperator | None = None, 

1084) -> tuple[str, str, str]: 

1085 """ 

1086 Get the email subject content for exceptions. 

1087 

1088 :param task_instance: the task instance 

1089 :param exception: the exception sent in the email 

1090 :param task: 

1091 

1092 :meta private: 

1093 """ 

1094 # For a ti from DB (without ti.task), return the default value 

1095 if task is None: 

1096 task = getattr(task_instance, "task") 

1097 use_default = task is None 

1098 exception_html = str(exception).replace("\n", "<br>") 

1099 

1100 default_subject = "Airflow alert: {{ti}}" 

1101 # For reporting purposes, we report based on 1-indexed, 

1102 # not 0-indexed lists (i.e. Try 1 instead of 

1103 # Try 0 for the first attempt). 

1104 default_html_content = ( 

1105 "Try {{try_number}} out of {{max_tries + 1}}<br>" 

1106 "Exception:<br>{{exception_html}}<br>" 

1107 'Log: <a href="{{ti.log_url}}">Link</a><br>' 

1108 "Host: {{ti.hostname}}<br>" 

1109 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>' 

1110 ) 

1111 

1112 default_html_content_err = ( 

1113 "Try {{try_number}} out of {{max_tries + 1}}<br>" 

1114 "Exception:<br>Failed attempt to attach error logs<br>" 

1115 'Log: <a href="{{ti.log_url}}">Link</a><br>' 

1116 "Host: {{ti.hostname}}<br>" 

1117 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>' 

1118 ) 

1119 

1120 additional_context: dict[str, Any] = { 

1121 "exception": exception, 

1122 "exception_html": exception_html, 

1123 "try_number": task_instance.try_number, 

1124 "max_tries": task_instance.max_tries, 

1125 } 

1126 

1127 if use_default: 

1128 default_context = {"ti": task_instance, **additional_context} 

1129 jinja_env = jinja2.Environment( 

1130 loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True 

1131 ) 

1132 subject = jinja_env.from_string(default_subject).render(**default_context) 

1133 html_content = jinja_env.from_string(default_html_content).render(**default_context) 

1134 html_content_err = jinja_env.from_string(default_html_content_err).render(**default_context) 

1135 

1136 else: 

1137 if TYPE_CHECKING: 

1138 assert task_instance.task 

1139 

1140 # Use the DAG's get_template_env() to set force_sandboxed. Don't add 

1141 # the flag to the function on task object -- that function can be 

1142 # overridden, and adding a flag breaks backward compatibility. 

1143 dag = task_instance.task.get_dag() 

1144 if dag: 

1145 jinja_env = dag.get_template_env(force_sandboxed=True) 

1146 else: 

1147 jinja_env = SandboxedEnvironment(cache_size=0) 

1148 jinja_context = task_instance.get_template_context() 

1149 context_merge(jinja_context, additional_context) 

1150 

1151 def render(key: str, content: str) -> str: 

1152 if conf.has_option("email", key): 

1153 path = conf.get_mandatory_value("email", key) 

1154 try: 

1155 with open(path) as f: 

1156 content = f.read() 

1157 except FileNotFoundError: 

1158 log.warning("Could not find email template file '%s'. Using defaults...", path) 

1159 except OSError: 

1160 log.exception("Error while using email template %s. Using defaults...", path) 

1161 return render_template_to_string(jinja_env.from_string(content), jinja_context) 

1162 

1163 subject = render("subject_template", default_subject) 

1164 html_content = render("html_content_template", default_html_content) 

1165 html_content_err = render("html_content_template", default_html_content_err) 

1166 

1167 return subject, html_content, html_content_err 

1168 

1169 
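# --- Illustrative sketch, not part of this module ---
# The render() helper above lets a deployment override the alert e-mail via
# [email] subject_template and html_content_template in airflow.cfg, each
# pointing at a Jinja template file; if the file is missing or unreadable the
# defaults defined above are used instead.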

1170def _run_finished_callback( 

1171 *, 

1172 callbacks: None | TaskStateChangeCallback | list[TaskStateChangeCallback], 

1173 context: Context, 

1174) -> None: 

1175 """ 

1176 Run callback after task finishes. 

1177 

1178 :param callbacks: callbacks to run 

1179 :param context: callbacks context 

1180 

1181 :meta private: 

1182 """ 

1183 if callbacks: 

1184 callbacks = callbacks if isinstance(callbacks, list) else [callbacks] 

1185 for callback in callbacks: 

1186 log.info("Executing %s callback", callback.__name__) 

1187 try: 

1188 callback(context) 

1189 except Exception: 

1190 log.exception("Error when executing %s callback", callback.__name__) # type: ignore[attr-defined] 

1191 

1192 

1193def _log_state(*, task_instance: TaskInstance | TaskInstancePydantic, lead_msg: str = "") -> None: 

1194 """ 

1195 Log task state. 

1196 

1197 :param task_instance: the task instance 

1198 :param lead_msg: lead message 

1199 

1200 :meta private: 

1201 """ 

1202 params = [ 

1203 lead_msg, 

1204 str(task_instance.state).upper(), 

1205 task_instance.dag_id, 

1206 task_instance.task_id, 

1207 task_instance.run_id, 

1208 ] 

1209 message = "%sMarking task as %s. dag_id=%s, task_id=%s, run_id=%s, " 

1210 if task_instance.map_index >= 0: 

1211 params.append(task_instance.map_index) 

1212 message += "map_index=%d, " 

1213 message += "execution_date=%s, start_date=%s, end_date=%s" 

1214 log.info( 

1215 message, 

1216 *params, 

1217 _date_or_empty(task_instance=task_instance, attr="execution_date"), 

1218 _date_or_empty(task_instance=task_instance, attr="start_date"), 

1219 _date_or_empty(task_instance=task_instance, attr="end_date"), 

1220 stacklevel=2, 

1221 ) 

1222 

1223 

1224def _date_or_empty(*, task_instance: TaskInstance | TaskInstancePydantic, attr: str) -> str: 

1225 """ 

1226 Fetch a date attribute, or return an empty string if it does not exist. 

1227 

1228 :param task_instance: the task instance 

1229 :param attr: the attribute name 

1230 

1231 :meta private: 

1232 """ 

1233 result: datetime | None = getattr(task_instance, attr, None) 

1234 return result.strftime("%Y%m%dT%H%M%S") if result else "" 

1235 

1236 

1237def _get_previous_ti( 

1238 *, 

1239 task_instance: TaskInstance | TaskInstancePydantic, 

1240 session: Session, 

1241 state: DagRunState | None = None, 

1242) -> TaskInstance | TaskInstancePydantic | None: 

1243 """ 

1244 Get task instance for the task that ran before this task instance. 

1245 

1246 :param task_instance: the task instance 

1247 :param state: If passed, it only takes into account instances of a specific state. 

1248 :param session: SQLAlchemy ORM Session 

1249 

1250 :meta private: 

1251 """ 

1252 dagrun = task_instance.get_previous_dagrun(state, session=session) 

1253 if dagrun is None: 

1254 return None 

1255 return dagrun.get_task_instance(task_instance.task_id, session=session) 

1256 

1257 

1258@internal_api_call 

1259@provide_session 

1260def _update_rtif(ti, rendered_fields, session: Session | None = None): 

1261 from airflow.models.renderedtifields import RenderedTaskInstanceFields 

1262 

1263 rtif = RenderedTaskInstanceFields(ti=ti, render_templates=False, rendered_fields=rendered_fields) 

1264 RenderedTaskInstanceFields.write(rtif, session=session) 

1265 RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session) 

1266 

1267 

1268class TaskInstance(Base, LoggingMixin): 

1269 """ 

1270 Task instances store the state of a task instance. 

1271 

1272 This table is the authority and single source of truth around what tasks 

1273 have run and the state they are in. 

1274 

1275 The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or 

1276 dag model deliberately to have more control over transactions. 

1277 

1278 Database transactions on this table should guard against double triggers and 

1279 any confusion around what task instances are or aren't ready to run 

1280 even while multiple schedulers may be firing task instances. 

1281 

1282 A value of -1 in map_index represents any of: a TI without mapped tasks; 

1283 a TI with mapped tasks that has yet to be expanded (state=pending); 

1284 a TI with mapped tasks that expanded to an empty list (state=skipped). 

1285 """ 

1286 

1287 __tablename__ = "task_instance" 

1288 task_id = Column(StringID(), primary_key=True, nullable=False) 

1289 dag_id = Column(StringID(), primary_key=True, nullable=False) 

1290 run_id = Column(StringID(), primary_key=True, nullable=False) 

1291 map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1")) 

1292 

1293 start_date = Column(UtcDateTime) 

1294 end_date = Column(UtcDateTime) 

1295 duration = Column(Float) 

1296 state = Column(String(20)) 

1297 try_number = Column(Integer, default=0) 

1298 max_tries = Column(Integer, server_default=text("-1")) 

1299 hostname = Column(String(1000)) 

1300 unixname = Column(String(1000)) 

1301 job_id = Column(Integer) 

1302 pool = Column(String(256), nullable=False) 

1303 pool_slots = Column(Integer, default=1, nullable=False) 

1304 queue = Column(String(256)) 

1305 priority_weight = Column(Integer) 

1306 operator = Column(String(1000)) 

1307 custom_operator_name = Column(String(1000)) 

1308 queued_dttm = Column(UtcDateTime) 

1309 queued_by_job_id = Column(Integer) 

1310 pid = Column(Integer) 

1311 executor = Column(String(1000)) 

1312 executor_config = Column(ExecutorConfigType(pickler=dill)) 

1313 updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow) 

1314 rendered_map_index = Column(String(250)) 

1315 

1316 external_executor_id = Column(StringID()) 

1317 

1318 # The trigger to resume on if we are in state DEFERRED 

1319 trigger_id = Column(Integer) 

1320 

1321 # Optional timeout datetime for the trigger (past this, we'll fail) 

1322 trigger_timeout = Column(DateTime) 

1323 # The trigger_timeout should be TIMESTAMP(using UtcDateTime) but for ease of 

1324 # migration, we are keeping it as DateTime pending a change where expensive 

1325 # migration is inevitable. 

1326 

1327 # The method to call next, and any extra arguments to pass to it. 

1328 # Usually used when resuming from DEFERRED. 

1329 next_method = Column(String(1000)) 

1330 next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON)) 

1331 

1332 _task_display_property_value = Column("task_display_name", String(2000), nullable=True) 

1333 # If adding new fields here then remember to add them to 

1334 # refresh_from_db() or they won't display in the UI correctly 

1335 

1336 __table_args__ = ( 

1337 Index("ti_dag_state", dag_id, state), 

1338 Index("ti_dag_run", dag_id, run_id), 

1339 Index("ti_state", state), 

1340 Index("ti_state_lkp", dag_id, task_id, run_id, state), 

1341 Index("ti_pool", pool, state, priority_weight), 

1342 Index("ti_job_id", job_id), 

1343 Index("ti_trigger_id", trigger_id), 

1344 PrimaryKeyConstraint("dag_id", "task_id", "run_id", "map_index", name="task_instance_pkey"), 

1345 ForeignKeyConstraint( 

1346 [trigger_id], 

1347 ["trigger.id"], 

1348 name="task_instance_trigger_id_fkey", 

1349 ondelete="CASCADE", 

1350 ), 

1351 ForeignKeyConstraint( 

1352 [dag_id, run_id], 

1353 ["dag_run.dag_id", "dag_run.run_id"], 

1354 name="task_instance_dag_run_fkey", 

1355 ondelete="CASCADE", 

1356 ), 

1357 ) 

1358 

1359 dag_model: DagModel = relationship( 

1360 "DagModel", 

1361 primaryjoin="TaskInstance.dag_id == DagModel.dag_id", 

1362 foreign_keys=dag_id, 

1363 uselist=False, 

1364 innerjoin=True, 

1365 viewonly=True, 

1366 ) 

1367 

1368 trigger = relationship("Trigger", uselist=False, back_populates="task_instance") 

1369 triggerer_job = association_proxy("trigger", "triggerer_job") 

1370 dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True) 

1371 rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False) 

1372 execution_date = association_proxy("dag_run", "execution_date") 

1373 task_instance_note = relationship( 

1374 "TaskInstanceNote", 

1375 back_populates="task_instance", 

1376 uselist=False, 

1377 cascade="all, delete, delete-orphan", 

1378 ) 

1379 note = association_proxy("task_instance_note", "content", creator=_creator_note) 

1380 

1381 task: Operator | None = None 

1382 test_mode: bool = False 

1383 is_trigger_log_context: bool = False 

1384 run_as_user: str | None = None 

1385 raw: bool | None = None 

1386 """Indicate to FileTaskHandler that logging context should be set up for trigger logging. 

1387 

1388 :meta private: 

1389 """ 

1390 _logger_name = "airflow.task" 

1391 

1392 def __init__( 

1393 self, 

1394 task: Operator, 

1395 execution_date: datetime | None = None, 

1396 run_id: str | None = None, 

1397 state: str | None = None, 

1398 map_index: int = -1, 

1399 ): 

1400 super().__init__() 

1401 self.dag_id = task.dag_id 

1402 self.task_id = task.task_id 

1403 self.map_index = map_index 

1404 self.refresh_from_task(task) 

1405 if TYPE_CHECKING: 

1406 assert self.task 

1407 

1408 # init_on_load will config the log 

1409 self.init_on_load() 

1410 

1411 if run_id is None and execution_date is not None: 

1412 from airflow.models.dagrun import DagRun # Avoid circular import 

1413 

1414 warnings.warn( 

1415 "Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id", 

1416 RemovedInAirflow3Warning, 

1417 # Stack level is 4 because SQLA adds some wrappers around the constructor 

1418 stacklevel=4, 

1419 ) 

1420 # make sure we have a localized execution_date stored in UTC 

1421 if execution_date and not timezone.is_localized(execution_date): 

1422 self.log.warning( 

1423 "execution date %s has no timezone information. Using default from dag or system", 

1424 execution_date, 

1425 ) 

1426 if self.task.has_dag(): 

1427 if TYPE_CHECKING: 

1428 assert self.task.dag 

1429 execution_date = timezone.make_aware(execution_date, self.task.dag.timezone) 

1430 else: 

1431 execution_date = timezone.make_aware(execution_date) 

1432 

1433 execution_date = timezone.convert_to_utc(execution_date) 

1434 with create_session() as session: 

1435 run_id = ( 

1436 session.query(DagRun.run_id) 

1437 .filter_by(dag_id=self.dag_id, execution_date=execution_date) 

1438 .scalar() 

1439 ) 

1440 if run_id is None: 

1441 raise DagRunNotFound( 

1442 f"DagRun for {self.dag_id!r} with date {execution_date} not found" 

1443 ) from None 

1444 

1445 self.run_id = run_id 

1446 

1447 self.try_number = 0 

1448 self.max_tries = self.task.retries 

1449 self.unixname = getuser() 

1450 if state: 

1451 self.state = state 

1452 self.hostname = "" 

1453 # Is this TaskInstance currently being run within `airflow tasks run --raw`? 

1454 # Not persisted to the database so only valid for the current process 

1455 self.raw = False 

1456 # can be changed when calling 'run' 

1457 self.test_mode = False 

1458 

1459 def __hash__(self): 

1460 return hash((self.task_id, self.dag_id, self.run_id, self.map_index)) 

1461 

1462 @property 

1463 @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning) 

1464 def _try_number(self): 

1465 """ 

1466 Do not use. For semblance of backcompat. 

1467 

1468 :meta private: 

1469 """ 

1470 return self.try_number 

1471 

1472 @_try_number.setter 

1473 @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning) 

1474 def _try_number(self, val): 

1475 """ 

1476 Do not use. For semblance of backcompat. 

1477 

1478 :meta private: 

1479 """ 

1480 self.try_number = val 

1481 

1482 @property 

1483 def stats_tags(self) -> dict[str, str]: 

1484 """Returns task instance tags.""" 

1485 return _stats_tags(task_instance=self) 

1486 

1487 @staticmethod 

1488 def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any]: 

1489 """Insert mapping. 

1490 

1491 :meta private: 

1492 """ 

1493 priority_weight = task.weight_rule.get_weight( 

1494 TaskInstance(task=task, run_id=run_id, map_index=map_index) 

1495 ) 

1496 

1497 return { 

1498 "dag_id": task.dag_id, 

1499 "task_id": task.task_id, 

1500 "run_id": run_id, 

1501 "try_number": 0, 

1502 "hostname": "", 

1503 "unixname": getuser(), 

1504 "queue": task.queue, 

1505 "pool": task.pool, 

1506 "pool_slots": task.pool_slots, 

1507 "priority_weight": priority_weight, 

1508 "run_as_user": task.run_as_user, 

1509 "max_tries": task.retries, 

1510 "executor": task.executor, 

1511 "executor_config": task.executor_config, 

1512 "operator": task.task_type, 

1513 "custom_operator_name": getattr(task, "custom_operator_name", None), 

1514 "map_index": map_index, 

1515 "_task_display_property_value": task.task_display_name, 

1516 } 

1517 

1518 @reconstructor 

1519 def init_on_load(self) -> None: 

1520 """Initialize the attributes that aren't stored in the DB.""" 

1521 self.test_mode = False # can be changed when calling 'run' 

1522 

1523 @property 

1524 @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning) 

1525 def prev_attempted_tries(self) -> int: 

1526 """ 

1527 Calculate the total number of attempted tries, defaulting to 0. 

1528 

1529 This used to be necessary because try_number did not always tell the truth. 

1530 

1531 :meta private: 

1532 """ 

1533 return self.try_number 

1534 

1535 @property 

1536 def next_try_number(self) -> int: 

1537 # todo (dstandish): deprecate this property; we don't need a property that is just + 1 

1538 return self.try_number + 1 

1539 

1540 @property 

1541 def operator_name(self) -> str | None: 

1542 """@property: use a more friendly display name for the operator, if set.""" 

1543 return self.custom_operator_name or self.operator 

1544 

1545 @hybrid_property 

1546 def task_display_name(self) -> str: 

1547 return self._task_display_property_value or self.task_id 

1548 

1549 @staticmethod 

1550 def _command_as_list( 

1551 ti: TaskInstance | TaskInstancePydantic, 

1552 mark_success: bool = False, 

1553 ignore_all_deps: bool = False, 

1554 ignore_task_deps: bool = False, 

1555 ignore_depends_on_past: bool = False, 

1556 wait_for_past_depends_before_skipping: bool = False, 

1557 ignore_ti_state: bool = False, 

1558 local: bool = False, 

1559 pickle_id: int | None = None, 

1560 raw: bool = False, 

1561 job_id: str | None = None, 

1562 pool: str | None = None, 

1563 cfg_path: str | None = None, 

1564 ) -> list[str]: 

1565 dag: DAG | DagModel | DagModelPydantic | None 

1566 # Use the dag if we have it, else fall back to the ORM dag_model, which might not be loaded 

1567 if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None: 

1568 if TYPE_CHECKING: 

1569 assert ti.task 

1570 dag = ti.task.dag 

1571 else: 

1572 dag = ti.dag_model 

1573 

1574 if dag is None: 

1575 raise ValueError("DagModel is empty") 

1576 

1577 should_pass_filepath = not pickle_id and dag 

1578 path: PurePath | None = None 

1579 if should_pass_filepath: 

1580 if dag.is_subdag: 

1581 if TYPE_CHECKING: 

1582 assert dag.parent_dag is not None 

1583 path = dag.parent_dag.relative_fileloc 

1584 else: 

1585 path = dag.relative_fileloc 

1586 

1587 if path: 

1588 if not path.is_absolute(): 

1589 path = "DAGS_FOLDER" / path 

1590 

1591 return TaskInstance.generate_command( 

1592 ti.dag_id, 

1593 ti.task_id, 

1594 run_id=ti.run_id, 

1595 mark_success=mark_success, 

1596 ignore_all_deps=ignore_all_deps, 

1597 ignore_task_deps=ignore_task_deps, 

1598 ignore_depends_on_past=ignore_depends_on_past, 

1599 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, 

1600 ignore_ti_state=ignore_ti_state, 

1601 local=local, 

1602 pickle_id=pickle_id, 

1603 file_path=path, 

1604 raw=raw, 

1605 job_id=job_id, 

1606 pool=pool, 

1607 cfg_path=cfg_path, 

1608 map_index=ti.map_index, 

1609 ) 

1610 

1611 def command_as_list( 

1612 self, 

1613 mark_success: bool = False, 

1614 ignore_all_deps: bool = False, 

1615 ignore_task_deps: bool = False, 

1616 ignore_depends_on_past: bool = False, 

1617 wait_for_past_depends_before_skipping: bool = False, 

1618 ignore_ti_state: bool = False, 

1619 local: bool = False, 

1620 pickle_id: int | None = None, 

1621 raw: bool = False, 

1622 job_id: str | None = None, 

1623 pool: str | None = None, 

1624 cfg_path: str | None = None, 

1625 ) -> list[str]: 

1626 """ 

1627 Return a command that can be executed anywhere airflow is installed. 

1628 

1629 This command is part of the message sent to executors by the orchestrator. 

1630 """ 

1631 return TaskInstance._command_as_list( 

1632 ti=self, 

1633 mark_success=mark_success, 

1634 ignore_all_deps=ignore_all_deps, 

1635 ignore_task_deps=ignore_task_deps, 

1636 ignore_depends_on_past=ignore_depends_on_past, 

1637 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, 

1638 ignore_ti_state=ignore_ti_state, 

1639 local=local, 

1640 pickle_id=pickle_id, 

1641 raw=raw, 

1642 job_id=job_id, 

1643 pool=pool, 

1644 cfg_path=cfg_path, 

1645 ) 

1646 

1647 @staticmethod 

1648 def generate_command( 

1649 dag_id: str, 

1650 task_id: str, 

1651 run_id: str, 

1652 mark_success: bool = False, 

1653 ignore_all_deps: bool = False, 

1654 ignore_depends_on_past: bool = False, 

1655 wait_for_past_depends_before_skipping: bool = False, 

1656 ignore_task_deps: bool = False, 

1657 ignore_ti_state: bool = False, 

1658 local: bool = False, 

1659 pickle_id: int | None = None, 

1660 file_path: PurePath | str | None = None, 

1661 raw: bool = False, 

1662 job_id: str | None = None, 

1663 pool: str | None = None, 

1664 cfg_path: str | None = None, 

1665 map_index: int = -1, 

1666 ) -> list[str]: 

1667 """ 

1668 Generate the shell command required to execute this task instance. 

1669 

1670 :param dag_id: DAG ID 

1671 :param task_id: Task ID 

1672 :param run_id: The run_id of this task's DagRun 

1673 :param mark_success: Whether to mark the task as successful 

1674 :param ignore_all_deps: Ignore all ignorable dependencies. 

1675 Overrides the other ignore_* parameters. 

1676 :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs 

1677 (e.g. for Backfills) 

1678 :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped 

1679 :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past 

1680 and trigger rule 

1681 :param ignore_ti_state: Ignore the task instance's previous failure/success 

1682 :param local: Whether to run the task locally 

1683 :param pickle_id: If the DAG was serialized to the DB, the ID 

1684 associated with the pickled DAG 

1685 :param file_path: path to the file containing the DAG definition 

1686 :param raw: raw mode (needs more details) 

1687 :param job_id: job ID (needs more details) 

1688 :param pool: the Airflow pool that the task should run in 

1689 :param cfg_path: the Path to the configuration file 

1690 :return: shell command that can be used to run the task instance 

1691 """ 

1692 cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id] 

1693 if mark_success: 

1694 cmd.extend(["--mark-success"]) 

1695 if pickle_id: 

1696 cmd.extend(["--pickle", str(pickle_id)]) 

1697 if job_id: 

1698 cmd.extend(["--job-id", str(job_id)]) 

1699 if ignore_all_deps: 

1700 cmd.extend(["--ignore-all-dependencies"]) 

1701 if ignore_task_deps: 

1702 cmd.extend(["--ignore-dependencies"]) 

1703 if ignore_depends_on_past: 

1704 cmd.extend(["--depends-on-past", "ignore"]) 

1705 elif wait_for_past_depends_before_skipping: 

1706 cmd.extend(["--depends-on-past", "wait"]) 

1707 if ignore_ti_state: 

1708 cmd.extend(["--force"]) 

1709 if local: 

1710 cmd.extend(["--local"]) 

1711 if pool: 

1712 cmd.extend(["--pool", pool]) 

1713 if raw: 

1714 cmd.extend(["--raw"]) 

1715 if file_path: 

1716 cmd.extend(["--subdir", os.fspath(file_path)]) 

1717 if cfg_path: 

1718 cmd.extend(["--cfg-path", cfg_path]) 

1719 if map_index != -1: 

1720 cmd.extend(["--map-index", str(map_index)]) 

1721 return cmd 

1722 
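# Illustrative sketch (not part of the original source): generate_command builds the CLI
# invocation that executors hand to workers. With assumed values:
#
#     TaskInstance.generate_command(
#         "example_dag",
#         "task_1",
#         run_id="manual_run_1",
#         local=True,
#         pool="default_pool",
#         map_index=2,
#     )
#     # -> ["airflow", "tasks", "run", "example_dag", "task_1", "manual_run_1",
#     #     "--local", "--pool", "default_pool", "--map-index", "2"]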

1723 @property 

1724 def log_url(self) -> str: 

1725 """Log URL for TaskInstance.""" 

1726 run_id = quote(self.run_id) 

1727 base_url = conf.get_mandatory_value("webserver", "BASE_URL") 

1728 return ( 

1729 f"{base_url}" 

1730 f"/dags" 

1731 f"/{self.dag_id}" 

1732 f"/grid" 

1733 f"?dag_run_id={run_id}" 

1734 f"&task_id={self.task_id}" 

1735 f"&map_index={self.map_index}" 

1736 "&tab=logs" 

1737 ) 

1738 
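# Illustrative sketch (assumed values, not part of the original source): with BASE_URL
# "http://localhost:8080", dag_id "example_dag", task_id "task_1", run_id "manual_run_1"
# and an unmapped task (map_index == -1), log_url yields:
#
#     http://localhost:8080/dags/example_dag/grid?dag_run_id=manual_run_1&task_id=task_1&map_index=-1&tab=logs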

1739 @property 

1740 def mark_success_url(self) -> str: 

1741 """URL to mark TI success.""" 

1742 base_url = conf.get_mandatory_value("webserver", "BASE_URL") 

1743 return ( 

1744 f"{base_url}" 

1745 "/confirm" 

1746 f"?task_id={self.task_id}" 

1747 f"&dag_id={self.dag_id}" 

1748 f"&dag_run_id={quote(self.run_id)}" 

1749 "&upstream=false" 

1750 "&downstream=false" 

1751 "&state=success" 

1752 ) 

1753 

1754 @provide_session 

1755 def current_state(self, session: Session = NEW_SESSION) -> str: 

1756 """ 

1757 Get the very latest state from the database. 

1758 

1759 If a session is passed, we use it and looking up the state becomes part of the session; 

1760 otherwise a new session is used. 

1761 

1762 sqlalchemy.inspect is used here to get the primary key columns, so this lookup 

1763 will not regress if those columns ever change 

1764 

1765 :param session: SQLAlchemy ORM Session 

1766 """ 

1767 filters = (col == getattr(self, col.name) for col in inspect(TaskInstance).primary_key) 

1768 return session.query(TaskInstance.state).filter(*filters).scalar() 

1769 
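# Minimal sketch of what the inspect()-based filter above expands to, assuming the
# TaskInstance primary key is (dag_id, task_id, run_id, map_index):
#
#     session.query(TaskInstance.state).filter(
#         TaskInstance.dag_id == self.dag_id,
#         TaskInstance.task_id == self.task_id,
#         TaskInstance.run_id == self.run_id,
#         TaskInstance.map_index == self.map_index,
#     ).scalar()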

1770 @provide_session 

1771 def error(self, session: Session = NEW_SESSION) -> None: 

1772 """ 

1773 Force the task instance's state to FAILED in the database. 

1774 

1775 :param session: SQLAlchemy ORM Session 

1776 """ 

1777 self.log.error("Recording the task instance as FAILED") 

1778 self.state = TaskInstanceState.FAILED 

1779 session.merge(self) 

1780 session.commit() 

1781 

1782 @classmethod 

1783 @internal_api_call 

1784 @provide_session 

1785 def get_task_instance( 

1786 cls, 

1787 dag_id: str, 

1788 run_id: str, 

1789 task_id: str, 

1790 map_index: int, 

1791 lock_for_update: bool = False, 

1792 session: Session = NEW_SESSION, 

1793 ) -> TaskInstance | TaskInstancePydantic | None: 

1794 query = ( 

1795 session.query(TaskInstance) 

1796 .options(lazyload(TaskInstance.dag_run)) # lazy load dag run to avoid locking it 

1797 .filter_by( 

1798 dag_id=dag_id, 

1799 run_id=run_id, 

1800 task_id=task_id, 

1801 map_index=map_index, 

1802 ) 

1803 ) 

1804 

1805 if lock_for_update: 

1806 for attempt in run_with_db_retries(logger=cls.logger()): 

1807 with attempt: 

1808 return query.with_for_update().one_or_none() 

1809 else: 

1810 return query.one_or_none() 

1811 

1812 return None 

1813 

1814 @provide_session 

1815 def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None: 

1816 """ 

1817 Refresh the task instance from the database based on the primary key. 

1818 

1819 :param session: SQLAlchemy ORM Session 

1820 :param lock_for_update: if True, indicates that the database should 

1821 lock the TaskInstance (issuing a FOR UPDATE clause) until the 

1822 session is committed. 

1823 """ 

1824 _refresh_from_db(task_instance=self, session=session, lock_for_update=lock_for_update) 

1825 

1826 def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None: 

1827 """ 

1828 Copy common attributes from the given task. 

1829 

1830 :param task: The task object to copy from 

1831 :param pool_override: Use the pool_override instead of task's pool 

1832 """ 

1833 _refresh_from_task(task_instance=self, task=task, pool_override=pool_override) 

1834 

1835 @staticmethod 

1836 @internal_api_call 

1837 @provide_session 

1838 def _clear_xcom_data(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION) -> None: 

1839 """Clear all XCom data from the database for the task instance. 

1840 

1841 If the task is unmapped, all XComs matching this task ID in the same DAG 

1842 run are removed. If the task is mapped, only the one with matching map 

1843 index is removed. 

1844 

1845 :param ti: The TI for which we need to clear xcoms. 

1846 :param session: SQLAlchemy ORM Session 

1847 """ 

1848 ti.log.debug("Clearing XCom data") 

1849 if ti.map_index < 0: 

1850 map_index: int | None = None 

1851 else: 

1852 map_index = ti.map_index 

1853 XCom.clear( 

1854 dag_id=ti.dag_id, 

1855 task_id=ti.task_id, 

1856 run_id=ti.run_id, 

1857 map_index=map_index, 

1858 session=session, 

1859 ) 

1860 

1861 @provide_session 

1862 def clear_xcom_data(self, session: Session = NEW_SESSION): 

1863 self._clear_xcom_data(ti=self, session=session) 

1864 

1865 @property 

1866 def key(self) -> TaskInstanceKey: 

1867 """Returns a tuple that identifies the task instance uniquely.""" 

1868 return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index) 

1869 
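# Illustrative sketch (assumed values): the key is a named tuple that executors use to
# track task instances, e.g.
#
#     ti.key
#     # -> TaskInstanceKey(dag_id="example_dag", task_id="task_1",
#     #                    run_id="manual_run_1", try_number=1, map_index=-1)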

1870 @staticmethod 

1871 @internal_api_call 

1872 def _set_state(ti: TaskInstance | TaskInstancePydantic, state, session: Session) -> bool: 

1873 if not isinstance(ti, TaskInstance): 

1874 ti = session.scalars( 

1875 select(TaskInstance).where( 

1876 TaskInstance.task_id == ti.task_id, 

1877 TaskInstance.dag_id == ti.dag_id, 

1878 TaskInstance.run_id == ti.run_id, 

1879 TaskInstance.map_index == ti.map_index, 

1880 ) 

1881 ).one() 

1882 

1883 if ti.state == state: 

1884 return False 

1885 

1886 current_time = timezone.utcnow() 

1887 ti.log.debug("Setting task state for %s to %s", ti, state) 

1888 ti.state = state 

1889 ti.start_date = ti.start_date or current_time 

1890 if ti.state in State.finished or ti.state == TaskInstanceState.UP_FOR_RETRY: 

1891 ti.end_date = ti.end_date or current_time 

1892 ti.duration = (ti.end_date - ti.start_date).total_seconds() 

1893 session.merge(ti) 

1894 return True 

1895 

1896 @provide_session 

1897 def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool: 

1898 """ 

1899 Set TaskInstance state. 

1900 

1901 :param state: State to set for the TI 

1902 :param session: SQLAlchemy ORM Session 

1903 :return: Was the state changed 

1904 """ 

1905 return self._set_state(ti=self, state=state, session=session) 

1906 

1907 @property 

1908 def is_premature(self) -> bool: 

1909 """Returns whether a task is in UP_FOR_RETRY state and its retry interval has elapsed.""" 

1910 # is the task still in the retry waiting period? 

1911 return self.state == TaskInstanceState.UP_FOR_RETRY and not self.ready_for_retry() 

1912 

1913 @provide_session 

1914 def are_dependents_done(self, session: Session = NEW_SESSION) -> bool: 

1915 """ 

1916 Check whether the immediate dependents of this task instance have succeeded or have been skipped. 

1917 

1918 This is meant to be used by wait_for_downstream. 

1919 

1920 This is useful when you do not want to start processing the next 

1921 schedule of a task until the dependents are done. For instance, 

1922 if the task DROPs and recreates a table. 

1923 

1924 :param session: SQLAlchemy ORM Session 

1925 """ 

1926 task = self.task 

1927 if TYPE_CHECKING: 

1928 assert task 

1929 

1930 if not task.downstream_task_ids: 

1931 return True 

1932 

1933 ti = session.query(func.count(TaskInstance.task_id)).filter( 

1934 TaskInstance.dag_id == self.dag_id, 

1935 TaskInstance.task_id.in_(task.downstream_task_ids), 

1936 TaskInstance.run_id == self.run_id, 

1937 TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)), 

1938 ) 

1939 count = ti[0][0] 

1940 return count == len(task.downstream_task_ids) 

1941 
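# Illustrative DAG-author sketch (task names and script are assumptions): are_dependents_done()
# backs wait_for_downstream, e.g. a task that drops and recreates a table should not start its
# next run until the previous run's consumers have finished:
#
#     from airflow.operators.bash import BashOperator
#
#     recreate = BashOperator(
#         task_id="recreate_table",
#         bash_command="recreate_table.sh",  # hypothetical script
#         depends_on_past=True,
#         wait_for_downstream=True,
#     )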

1942 @provide_session 

1943 def get_previous_dagrun( 

1944 self, 

1945 state: DagRunState | None = None, 

1946 session: Session | None = None, 

1947 ) -> DagRun | None: 

1948 """ 

1949 Return the DagRun that ran before this task instance's DagRun. 

1950 

1951 :param state: If passed, it only takes into account instances of a specific state. 

1952 :param session: SQLAlchemy ORM Session. 

1953 """ 

1954 return _get_previous_dagrun(task_instance=self, state=state, session=session) 

1955 

1956 @provide_session 

1957 def get_previous_ti( 

1958 self, 

1959 state: DagRunState | None = None, 

1960 session: Session = NEW_SESSION, 

1961 ) -> TaskInstance | TaskInstancePydantic | None: 

1962 """ 

1963 Return the task instance for the task that ran before this task instance. 

1964 

1965 :param session: SQLAlchemy ORM Session 

1966 :param state: If passed, it only takes into account instances of a specific state. 

1967 """ 

1968 return _get_previous_ti(task_instance=self, state=state, session=session) 

1969 

1970 @property 

1971 def previous_ti(self) -> TaskInstance | TaskInstancePydantic | None: 

1972 """ 

1973 This attribute is deprecated. 

1974 

1975 Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_ti`. 

1976 """ 

1977 warnings.warn( 

1978 """ 

1979 This attribute is deprecated. 

1980 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method. 

1981 """, 

1982 RemovedInAirflow3Warning, 

1983 stacklevel=2, 

1984 ) 

1985 return self.get_previous_ti() 

1986 

1987 @property 

1988 def previous_ti_success(self) -> TaskInstance | TaskInstancePydantic | None: 

1989 """ 

1990 This attribute is deprecated. 

1991 

1992 Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_ti`. 

1993 """ 

1994 warnings.warn( 

1995 """ 

1996 This attribute is deprecated. 

1997 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method. 

1998 """, 

1999 RemovedInAirflow3Warning, 

2000 stacklevel=2, 

2001 ) 

2002 return self.get_previous_ti(state=DagRunState.SUCCESS) 

2003 

2004 @provide_session 

2005 def get_previous_execution_date( 

2006 self, 

2007 state: DagRunState | None = None, 

2008 session: Session = NEW_SESSION, 

2009 ) -> pendulum.DateTime | None: 

2010 """ 

2011 Return the execution date of the task instance that ran before this one, optionally filtered by state. 

2012 

2013 :param state: If passed, it only takes into account instances of a specific state. 

2014 :param session: SQLAlchemy ORM Session 

2015 """ 

2016 return _get_previous_execution_date(task_instance=self, state=state, session=session) 

2017 

2018 @provide_session 

2019 def get_previous_start_date( 

2020 self, state: DagRunState | None = None, session: Session = NEW_SESSION 

2021 ) -> pendulum.DateTime | None: 

2022 """ 

2023 Return the start date of the task instance that ran before this one, optionally filtered by state. 

2024 

2025 :param state: If passed, it only takes into account instances of a specific state. 

2026 :param session: SQLAlchemy ORM Session 

2027 """ 

2028 self.log.debug("previous_start_date was called") 

2029 prev_ti = self.get_previous_ti(state=state, session=session) 

2030 # prev_ti may not exist and prev_ti.start_date may be None. 

2031 return pendulum.instance(prev_ti.start_date) if prev_ti and prev_ti.start_date else None 

2032 
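# Illustrative sketch of in-task usage (names assumed, not part of the original source): the
# "ti" object from the runtime context exposes this lookup, e.g. to process only data newer
# than the last successful run's start:
#
#     from airflow.utils.state import DagRunState
#
#     def my_callable(*, ti, **context):
#         prev_start = ti.get_previous_start_date(state=DagRunState.SUCCESS)
#         if prev_start is not None:
#             load_records(since=prev_start)  # hypothetical helper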

2033 @property 

2034 def previous_start_date_success(self) -> pendulum.DateTime | None: 

2035 """ 

2036 This attribute is deprecated. 

2037 

2038 Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_start_date`. 

2039 """ 

2040 warnings.warn( 

2041 """ 

2042 This attribute is deprecated. 

2043 Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method. 

2044 """, 

2045 RemovedInAirflow3Warning, 

2046 stacklevel=2, 

2047 ) 

2048 return self.get_previous_start_date(state=DagRunState.SUCCESS) 

2049 

2050 @provide_session 

2051 def are_dependencies_met( 

2052 self, dep_context: DepContext | None = None, session: Session = NEW_SESSION, verbose: bool = False 

2053 ) -> bool: 

2054 """ 

2055 Check whether all conditions are met for this task instance to be run, given the dependency context. 

2056 

2057 (e.g. a task instance being force run from the UI will ignore some dependencies). 

2058 

2059 :param dep_context: The execution context that determines the dependencies that should be evaluated. 

2060 :param session: database session 

2061 :param verbose: whether to log details of failed dependencies at info or debug log level 

2062 """ 

2063 dep_context = dep_context or DepContext() 

2064 failed = False 

2065 verbose_aware_logger = self.log.info if verbose else self.log.debug 

2066 for dep_status in self.get_failed_dep_statuses(dep_context=dep_context, session=session): 

2067 failed = True 

2068 

2069 verbose_aware_logger( 

2070 "Dependencies not met for %s, dependency '%s' FAILED: %s", 

2071 self, 

2072 dep_status.dep_name, 

2073 dep_status.reason, 

2074 ) 

2075 

2076 if failed: 

2077 return False 

2078 

2079 verbose_aware_logger("Dependencies all met for dep_context=%s ti=%s", dep_context.description, self) 

2080 return True 

2081 

2082 @provide_session 

2083 def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION): 

2084 """Get failed Dependencies.""" 

2085 if TYPE_CHECKING: 

2086 assert self.task 

2087 

2088 dep_context = dep_context or DepContext() 

2089 for dep in dep_context.deps | self.task.deps: 

2090 for dep_status in dep.get_dep_statuses(self, session, dep_context): 

2091 self.log.debug( 

2092 "%s dependency '%s' PASSED: %s, %s", 

2093 self, 

2094 dep_status.dep_name, 

2095 dep_status.passed, 

2096 dep_status.reason, 

2097 ) 

2098 

2099 if not dep_status.passed: 

2100 yield dep_status 

2101 

2102 def __repr__(self) -> str: 

2103 prefix = f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} " 

2104 if self.map_index != -1: 

2105 prefix += f"map_index={self.map_index} " 

2106 return prefix + f"[{self.state}]>" 

2107 

2108 def next_retry_datetime(self): 

2109 """ 

2110 Get datetime of the next retry if the task instance fails. 

2111 

2112 For exponential backoff, retry_delay is used as base and will be converted to seconds. 

2113 """ 

2114 from airflow.models.abstractoperator import MAX_RETRY_DELAY 

2115 

2116 delay = self.task.retry_delay 

2117 if self.task.retry_exponential_backoff: 

2118 # If the min_backoff calculation is below 1, it will be converted to 0 via int. Thus, 

2119 # we must round up prior to converting to an int, otherwise a divide by zero error 

2120 # will occur in the modded_hash calculation. 

2121 # this probably gives unexpected results if a task instance has previously been cleared, 

2122 # because try_number can increase without bound 

2123 min_backoff = math.ceil(delay.total_seconds() * (2 ** (self.try_number - 1))) 

2124 

2125 # In the case when delay.total_seconds() is 0, min_backoff will not be rounded up to 1. 

2126 # To address this, we impose a lower bound of 1 on min_backoff. This effectively makes 

2127 # the ceiling function unnecessary, but the ceiling function was retained to avoid 

2128 # introducing a breaking change. 

2129 if min_backoff < 1: 

2130 min_backoff = 1 

2131 

2132 # deterministic per task instance 

2133 ti_hash = int( 

2134 hashlib.sha1( 

2135 f"{self.dag_id}#{self.task_id}#{self.execution_date}#{self.try_number}".encode() 

2136 ).hexdigest(), 

2137 16, 

2138 ) 

2139 # between min_backoff and 2 * min_backoff - 1 

2140 modded_hash = min_backoff + ti_hash % min_backoff 

2141 # timedelta has a maximum representable value. The exponentiation 

2142 # here means this value can be exceeded after a certain number 

2143 # of tries (around 50 if the initial delay is 1s, even fewer if 

2144 # the delay is larger). Cap the value here before creating a 

2145 # timedelta object so the operation doesn't fail with "OverflowError". 

2146 delay_backoff_in_seconds = min(modded_hash, MAX_RETRY_DELAY) 

2147 delay = timedelta(seconds=delay_backoff_in_seconds) 

2148 if self.task.max_retry_delay: 

2149 delay = min(self.task.max_retry_delay, delay) 

2150 return self.end_date + delay 

2151 
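# Worked example (illustrative values, derived from the code above): with
# retry_delay=timedelta(seconds=300), retry_exponential_backoff=True and try_number=3:
#
#     min_backoff = ceil(300 * 2 ** (3 - 1)) = 1200
#     modded_hash is a deterministic value in [1200, 2399] seconds
#
# so the next retry lands roughly 20-40 minutes after end_date, capped by
# max_retry_delay (if set) and MAX_RETRY_DELAY. Without exponential backoff the next
# retry is simply end_date + retry_delay.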

2152 def ready_for_retry(self) -> bool: 

2153 """Check on whether the task instance is in the right state and timeframe to be retried.""" 

2154 return self.state == TaskInstanceState.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow() 

2155 

2156 @staticmethod 

2157 @internal_api_call 

2158 def _get_dagrun(dag_id, run_id, session) -> DagRun: 

2159 from airflow.models.dagrun import DagRun # Avoid circular import 

2160 

2161 dr = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one() 

2162 return dr 

2163 

2164 @provide_session 

2165 def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun: 

2166 """ 

2167 Return the DagRun for this TaskInstance. 

2168 

2169 :param session: SQLAlchemy ORM Session 

2170 :return: DagRun 

2171 """ 

2172 info = inspect(self) 

2173 if info.attrs.dag_run.loaded_value is not NO_VALUE: 

2174 if getattr(self, "task", None) is not None: 

2175 if TYPE_CHECKING: 

2176 assert self.task 

2177 self.dag_run.dag = self.task.dag 

2178 return self.dag_run 

2179 

2180 dr = self._get_dagrun(self.dag_id, self.run_id, session) 

2181 if getattr(self, "task", None) is not None: 

2182 if TYPE_CHECKING: 

2183 assert self.task 

2184 dr.dag = self.task.dag 

2185 # Record it in the instance for next time. This means that `self.execution_date` will work correctly 

2186 set_committed_value(self, "dag_run", dr) 

2187 

2188 return dr 

2189 

2190 @classmethod 

2191 @internal_api_call 

2192 @provide_session 

2193 def _check_and_change_state_before_execution( 

2194 cls, 

2195 task_instance: TaskInstance | TaskInstancePydantic, 

2196 verbose: bool = True, 

2197 ignore_all_deps: bool = False, 

2198 ignore_depends_on_past: bool = False, 

2199 wait_for_past_depends_before_skipping: bool = False, 

2200 ignore_task_deps: bool = False, 

2201 ignore_ti_state: bool = False, 

2202 mark_success: bool = False, 

2203 test_mode: bool = False, 

2204 hostname: str = "", 

2205 job_id: str | None = None, 

2206 pool: str | None = None, 

2207 external_executor_id: str | None = None, 

2208 session: Session = NEW_SESSION, 

2209 ) -> bool: 

2210 """ 

2211 Check dependencies and then set the state to RUNNING if they are met. 

2212 

2213 Returns True if and only if state is set to RUNNING, which implies that the task should be 

2214 executed, in preparation for _run_raw_task. 

2215 

2216 :param verbose: whether to turn on more verbose logging 

2217 :param ignore_all_deps: Ignore all of the non-critical dependencies, just run the task 

2218 :param ignore_depends_on_past: Ignore depends_on_past DAG attribute 

2219 :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped 

2220 :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task 

2221 :param ignore_ti_state: Disregards previous task instance state 

2222 :param mark_success: Don't run the task, mark its state as success 

2223 :param test_mode: Doesn't record success or failure in the DB 

2224 :param hostname: The hostname of the worker running the task instance. 

2225 :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID 

2226 :param pool: specifies the pool to use to run the task instance 

2227 :param external_executor_id: The identifier of the celery executor 

2228 :param session: SQLAlchemy ORM Session 

2229 :return: whether the state was changed to running or not 

2230 """ 

2231 if TYPE_CHECKING: 

2232 assert task_instance.task 

2233 

2234 if isinstance(task_instance, TaskInstance): 

2235 ti: TaskInstance = task_instance 

2236 else: # isinstance(task_instance, TaskInstancePydantic) 

2237 filters = (col == getattr(task_instance, col.name) for col in inspect(TaskInstance).primary_key) 

2238 ti = session.query(TaskInstance).filter(*filters).scalar() 

2239 dag = ti.dag_model.serialized_dag.dag 

2240 task_instance.task = dag.task_dict[ti.task_id] 

2241 ti.task = task_instance.task 

2242 task = task_instance.task 

2243 if TYPE_CHECKING: 

2244 assert task 

2245 ti.refresh_from_task(task, pool_override=pool) 

2246 ti.test_mode = test_mode 

2247 ti.refresh_from_db(session=session, lock_for_update=True) 

2248 ti.job_id = job_id 

2249 ti.hostname = hostname 

2250 ti.pid = None 

2251 

2252 if not ignore_all_deps and not ignore_ti_state and ti.state == TaskInstanceState.SUCCESS: 

2253 Stats.incr("previously_succeeded", tags=ti.stats_tags) 

2254 

2255 if not mark_success: 

2256 # Firstly find non-runnable and non-requeueable tis. 

2257 # (These dependency checks are skipped entirely when mark_success is set.) 

2258 non_requeueable_dep_context = DepContext( 

2259 deps=RUNNING_DEPS - REQUEUEABLE_DEPS, 

2260 ignore_all_deps=ignore_all_deps, 

2261 ignore_ti_state=ignore_ti_state, 

2262 ignore_depends_on_past=ignore_depends_on_past, 

2263 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, 

2264 ignore_task_deps=ignore_task_deps, 

2265 description="non-requeueable deps", 

2266 ) 

2267 if not ti.are_dependencies_met( 

2268 dep_context=non_requeueable_dep_context, session=session, verbose=True 

2269 ): 

2270 session.commit() 

2271 return False 

2272 

2273 # For reporting purposes, we report based on 1-indexed, 

2274 # not 0-indexed lists (i.e. Attempt 1 instead of 

2275 # Attempt 0 for the first attempt). 

2276 # Set the task start date. In case it was re-scheduled use the initial 

2277 # start date that is recorded in task_reschedule table 

2278 # If the task continues after being deferred (next_method is set), use the original start_date 

2279 ti.start_date = ti.start_date if ti.next_method else timezone.utcnow() 

2280 if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE: 

2281 tr_start_date = session.scalar( 

2282 TR.stmt_for_task_instance(ti, descending=False).with_only_columns(TR.start_date).limit(1) 

2283 ) 

2284 if tr_start_date: 

2285 ti.start_date = tr_start_date 

2286 

2287 # Secondly we find non-runnable but requeueable tis. We reset its state. 

2288 # This is because we might have hit concurrency limits, 

2289 # e.g. because of backfilling. 

2290 dep_context = DepContext( 

2291 deps=REQUEUEABLE_DEPS, 

2292 ignore_all_deps=ignore_all_deps, 

2293 ignore_depends_on_past=ignore_depends_on_past, 

2294 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, 

2295 ignore_task_deps=ignore_task_deps, 

2296 ignore_ti_state=ignore_ti_state, 

2297 description="requeueable deps", 

2298 ) 

2299 if not ti.are_dependencies_met(dep_context=dep_context, session=session, verbose=True): 

2300 ti.state = None 

2301 cls.logger().warning( 

2302 "Rescheduling due to concurrency limits reached " 

2303 "at task runtime. Attempt %s of " 

2304 "%s. State set to NONE.", 

2305 ti.try_number, 

2306 ti.max_tries + 1, 

2307 ) 

2308 ti.queued_dttm = timezone.utcnow() 

2309 session.merge(ti) 

2310 session.commit() 

2311 return False 

2312 

2313 if ti.next_kwargs is not None: 

2314 cls.logger().info("Resuming after deferral") 

2315 else: 

2316 cls.logger().info("Starting attempt %s of %s", ti.try_number, ti.max_tries + 1) 

2317 

2318 if not test_mode: 

2319 session.add(Log(TaskInstanceState.RUNNING.value, ti)) 

2320 

2321 ti.state = TaskInstanceState.RUNNING 

2322 ti.emit_state_change_metric(TaskInstanceState.RUNNING) 

2323 

2324 if external_executor_id: 

2325 ti.external_executor_id = external_executor_id 

2326 

2327 ti.end_date = None 

2328 if not test_mode: 

2329 session.merge(ti).task = task 

2330 session.commit() 

2331 

2332 # Closing all pooled connections to prevent 

2333 # "max number of connections reached" 

2334 settings.engine.dispose() # type: ignore 

2335 if verbose: 

2336 if mark_success: 

2337 cls.logger().info("Marking success for %s on %s", ti.task, ti.execution_date) 

2338 else: 

2339 cls.logger().info("Executing %s on %s", ti.task, ti.execution_date) 

2340 return True 

2341 

2342 @provide_session 

2343 def check_and_change_state_before_execution( 

2344 self, 

2345 verbose: bool = True, 

2346 ignore_all_deps: bool = False, 

2347 ignore_depends_on_past: bool = False, 

2348 wait_for_past_depends_before_skipping: bool = False, 

2349 ignore_task_deps: bool = False, 

2350 ignore_ti_state: bool = False, 

2351 mark_success: bool = False, 

2352 test_mode: bool = False, 

2353 job_id: str | None = None, 

2354 pool: str | None = None, 

2355 external_executor_id: str | None = None, 

2356 session: Session = NEW_SESSION, 

2357 ) -> bool: 

2358 return TaskInstance._check_and_change_state_before_execution( 

2359 task_instance=self, 

2360 verbose=verbose, 

2361 ignore_all_deps=ignore_all_deps, 

2362 ignore_depends_on_past=ignore_depends_on_past, 

2363 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, 

2364 ignore_task_deps=ignore_task_deps, 

2365 ignore_ti_state=ignore_ti_state, 

2366 mark_success=mark_success, 

2367 test_mode=test_mode, 

2368 hostname=get_hostname(), 

2369 job_id=job_id, 

2370 pool=pool, 

2371 external_executor_id=external_executor_id, 

2372 session=session, 

2373 ) 

2374 

2375 def emit_state_change_metric(self, new_state: TaskInstanceState) -> None: 

2376 """ 

2377 Send a time metric representing how much time a given state transition took. 

2378 

2379 The previous state and metric name are deduced from the state the task was put in. 

2380 

2381 :param new_state: The state that has just been set for this task. 

2382 We do not use `self.state`, because sometimes the state is updated directly in the DB and not in 

2383 the local TaskInstance object. 

2384 Supported states: QUEUED and RUNNING 

2385 """ 

2386 if self.end_date: 

2387 # if the task has an end date, it means that this is not its first round. 

2388 # we send the state transition time metric only on the first try, otherwise it gets more complex. 

2389 return 

2390 

2391 # switch on state and deduce which metric to send 

2392 if new_state == TaskInstanceState.RUNNING: 

2393 metric_name = "queued_duration" 

2394 if self.queued_dttm is None: 

2395 # this should not really happen except in tests or rare cases, 

2396 # but we don't want to create errors just for a metric, so we just skip it 

2397 self.log.warning( 

2398 "cannot record %s for task %s because previous state change time has not been saved", 

2399 metric_name, 

2400 self.task_id, 

2401 ) 

2402 return 

2403 timing = (timezone.utcnow() - self.queued_dttm).total_seconds() 

2404 elif new_state == TaskInstanceState.QUEUED: 

2405 metric_name = "scheduled_duration" 

2406 if self.start_date is None: 

2407 # This check does not work correctly before fields like `scheduled_dttm` are implemented. 

2408 # TODO: Change the level to WARNING once it's viable. 

2409 # see #30612 #34493 and #34771 for more details 

2410 self.log.debug( 

2411 "cannot record %s for task %s because previous state change time has not been saved", 

2412 metric_name, 

2413 self.task_id, 

2414 ) 

2415 return 

2416 timing = (timezone.utcnow() - self.start_date).total_seconds() 

2417 else: 

2418 raise NotImplementedError(f"no metric emission setup for state {new_state}") 

2419 

2420 # send metric twice, once (legacy) with tags in the name and once with tags as tags 

2421 Stats.timing(f"dag.{self.dag_id}.{self.task_id}.{metric_name}", timing) 

2422 Stats.timing(f"task.{metric_name}", timing, tags={"task_id": self.task_id, "dag_id": self.dag_id}) 

2423 
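# Illustrative sketch (assumed ids): a transition to RUNNING for dag "example_dag",
# task "task_1" that spent 4.2s queued emits both metric forms:
#
#     Stats.timing("dag.example_dag.task_1.queued_duration", 4.2)
#     Stats.timing("task.queued_duration", 4.2,
#                  tags={"task_id": "task_1", "dag_id": "example_dag"})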

2424 def clear_next_method_args(self) -> None: 

2425 """Ensure we unset next_method and next_kwargs to ensure that any retries don't reuse them.""" 

2426 _clear_next_method_args(task_instance=self) 

2427 

2428 @provide_session 

2429 @Sentry.enrich_errors 

2430 def _run_raw_task( 

2431 self, 

2432 mark_success: bool = False, 

2433 test_mode: bool = False, 

2434 job_id: str | None = None, 

2435 pool: str | None = None, 

2436 raise_on_defer: bool = False, 

2437 session: Session = NEW_SESSION, 

2438 ) -> TaskReturnCode | None: 

2439 """ 

2440 Run a task, update the state upon completion, and run any appropriate callbacks. 

2441 

2442 Immediately runs the task (without checking or changing db state 

2443 before execution) and then sets the appropriate final state after 

2444 completion and runs any post-execute callbacks. Meant to be called 

2445 only after another function changes the state to running. 

2446 

2447 :param mark_success: Don't run the task, mark its state as success 

2448 :param test_mode: Doesn't record success or failure in the DB 

2449 :param pool: specifies the pool to use to run the task instance 

2450 :param session: SQLAlchemy ORM Session 

2451 """ 

2452 if TYPE_CHECKING: 

2453 assert self.task 

2454 

2455 self.test_mode = test_mode 

2456 self.refresh_from_task(self.task, pool_override=pool) 

2457 self.refresh_from_db(session=session) 

2458 

2459 self.job_id = job_id 

2460 self.hostname = get_hostname() 

2461 self.pid = os.getpid() 

2462 if not test_mode: 

2463 session.merge(self) 

2464 session.commit() 

2465 actual_start_date = timezone.utcnow() 

2466 Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}", tags=self.stats_tags) 

2467 # Same metric with tagging 

2468 Stats.incr("ti.start", tags=self.stats_tags) 

2469 # Initialize final state counters at zero 

2470 for state in State.task_states: 

2471 Stats.incr( 

2472 f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}", 

2473 count=0, 

2474 tags=self.stats_tags, 

2475 ) 

2476 # Same metric with tagging 

2477 Stats.incr( 

2478 "ti.finish", 

2479 count=0, 

2480 tags={**self.stats_tags, "state": str(state)}, 

2481 ) 

2482 with set_current_task_instance_session(session=session): 

2483 self.task = self.task.prepare_for_execution() 

2484 context = self.get_template_context(ignore_param_exceptions=False) 

2485 

2486 try: 

2487 if not mark_success: 

2488 self._execute_task_with_callbacks(context, test_mode, session=session) 

2489 if not test_mode: 

2490 self.refresh_from_db(lock_for_update=True, session=session) 

2491 self.state = TaskInstanceState.SUCCESS 

2492 except TaskDeferred as defer: 

2493 # The task has signalled it wants to defer execution based on 

2494 # a trigger. 

2495 if raise_on_defer: 

2496 raise 

2497 self.defer_task(defer=defer, session=session) 

2498 self.log.info( 

2499 "Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, execution_date=%s, start_date=%s", 

2500 self.dag_id, 

2501 self.task_id, 

2502 self.run_id, 

2503 _date_or_empty(task_instance=self, attr="execution_date"), 

2504 _date_or_empty(task_instance=self, attr="start_date"), 

2505 ) 

2506 if not test_mode: 

2507 session.add(Log(self.state, self)) 

2508 session.merge(self) 

2509 session.commit() 

2510 return TaskReturnCode.DEFERRED 

2511 except AirflowSkipException as e: 

2512 # Recording SKIP 

2513 # log only if exception has any arguments to prevent log flooding 

2514 if e.args: 

2515 self.log.info(e) 

2516 if not test_mode: 

2517 self.refresh_from_db(lock_for_update=True, session=session) 

2518 _run_finished_callback(callbacks=self.task.on_skipped_callback, context=context) 

2519 session.commit() 

2520 self.state = TaskInstanceState.SKIPPED 

2521 except AirflowRescheduleException as reschedule_exception: 

2522 self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) 

2523 session.commit() 

2524 return None 

2525 except (AirflowFailException, AirflowSensorTimeout) as e: 

2526 # If AirflowFailException is raised, task should not retry. 

2527 # If a sensor in reschedule mode reaches timeout, task should not retry. 

2528 self.handle_failure(e, test_mode, context, force_fail=True, session=session) 

2529 session.commit() 

2530 raise 

2531 except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated) as e: 

2532 if not test_mode: 

2533 self.refresh_from_db(lock_for_update=True, session=session) 

2534 # for case when task is marked as success/failed externally 

2535 # or dagrun timed out and task is marked as skipped 

2536 # current behavior doesn't hit the callbacks 

2537 if self.state in State.finished: 

2538 self.clear_next_method_args() 

2539 session.merge(self) 

2540 session.commit() 

2541 return None 

2542 else: 

2543 self.handle_failure(e, test_mode, context, session=session) 

2544 session.commit() 

2545 raise 

2546 except SystemExit as e: 

2547 # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`. 

2548 # Therefore, here we must handle only error codes. 

2549 msg = f"Task failed due to SystemExit({e.code})" 

2550 self.handle_failure(msg, test_mode, context, session=session) 

2551 session.commit() 

2552 raise AirflowException(msg) 

2553 except BaseException as e: 

2554 self.handle_failure(e, test_mode, context, session=session) 

2555 session.commit() 

2556 raise 

2557 finally: 

2558 Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags) 

2559 # Same metric with tagging 

2560 Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)}) 

2561 

2562 # Recording SKIPPED or SUCCESS 

2563 self.clear_next_method_args() 

2564 self.end_date = timezone.utcnow() 

2565 _log_state(task_instance=self) 

2566 self.set_duration() 

2567 

2568 # run on_success_callback before committing to the db; 

2569 # otherwise the LocalTaskJob sees that the state has changed to `success` while the 

2570 # task_runner is still running, and treats the state as having been set externally! 

2571 _run_finished_callback(callbacks=self.task.on_success_callback, context=context) 

2572 

2573 if not test_mode: 

2574 session.add(Log(self.state, self)) 

2575 session.merge(self).task = self.task 

2576 if self.state == TaskInstanceState.SUCCESS: 

2577 self._register_dataset_changes(events=context["outlet_events"], session=session) 

2578 

2579 session.commit() 

2580 if self.state == TaskInstanceState.SUCCESS: 

2581 get_listener_manager().hook.on_task_instance_success( 

2582 previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session 

2583 ) 

2584 

2585 return None 

2586 

2587 def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Session) -> None: 

2588 if TYPE_CHECKING: 

2589 assert self.task 

2590 

2591 for obj in self.task.outlets or []: 

2592 self.log.debug("outlet obj %s", obj) 

2593 # Lineage can have other types of objects besides datasets 

2594 if isinstance(obj, Dataset): 

2595 dataset_manager.register_dataset_change( 

2596 task_instance=self, 

2597 dataset=obj, 

2598 extra=events[obj].extra, 

2599 session=session, 

2600 ) 

2601 

2602 def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session): 

2603 """Prepare Task for Execution.""" 

2604 if TYPE_CHECKING: 

2605 assert self.task 

2606 

2607 parent_pid = os.getpid() 

2608 

2609 def signal_handler(signum, frame): 

2610 pid = os.getpid() 

2611 

2612 # If a task forks during execution (from DAG code) for whatever 

2613 # reason, we want to make sure that we react to the signal only in 

2614 # the process that we've spawned ourselves (referred to here as the 

2615 # parent process). 

2616 if pid != parent_pid: 

2617 os._exit(1) 

2618 return 

2619 self.log.error("Received SIGTERM. Terminating subprocesses.") 

2620 self.task.on_kill() 

2621 raise AirflowTaskTerminated("Task received SIGTERM signal") 

2622 

2623 signal.signal(signal.SIGTERM, signal_handler) 

2624 

2625 # Don't clear XCom until the task is certain to execute; skip clearing when resuming from deferral. 

2626 if not self.next_method: 

2627 self.clear_xcom_data() 

2628 

2629 with Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration"), Stats.timer( 

2630 "task.duration", tags=self.stats_tags 

2631 ): 

2632 # Set the validated/merged params on the task object. 

2633 self.task.params = context["params"] 

2634 

2635 with set_current_context(context): 

2636 dag = self.task.get_dag() 

2637 if dag is not None: 

2638 jinja_env = dag.get_template_env() 

2639 else: 

2640 jinja_env = None 

2641 task_orig = self.render_templates(context=context, jinja_env=jinja_env) 

2642 

2643 # The task is never MappedOperator at this point. 

2644 if TYPE_CHECKING: 

2645 assert isinstance(self.task, BaseOperator) 

2646 

2647 if not test_mode: 

2648 rendered_fields = get_serialized_template_fields(task=self.task) 

2649 _update_rtif(ti=self, rendered_fields=rendered_fields) 

2650 # Export context to make it available for operators to use. 

2651 airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) 

2652 os.environ.update(airflow_context_vars) 

2653 

2654 # Log context only for the default execution method, the assumption 

2655 # being that otherwise we're resuming a deferred task (in which 

2656 # case there's no need to log these again). 

2657 if not self.next_method: 

2658 self.log.info( 

2659 "Exporting env vars: %s", 

2660 " ".join(f"{k}={v!r}" for k, v in airflow_context_vars.items()), 

2661 ) 

2662 

2663 # Run pre_execute callback 

2664 self.task.pre_execute(context=context) 

2665 

2666 # Run on_execute callback 

2667 self._run_execute_callback(context, self.task) 

2668 

2669 # Run on_task_instance_running event 

2670 get_listener_manager().hook.on_task_instance_running( 

2671 previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session 

2672 ) 

2673 

2674 def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: 

2675 """Render named map index if the DAG author defined map_index_template at the task level.""" 

2676 if jinja_env is None or (template := context.get("map_index_template")) is None: 

2677 return None 

2678 rendered_map_index = jinja_env.from_string(template).render(context) 

2679 log.debug("Map index rendered as %s", rendered_map_index) 

2680 return rendered_map_index 

2681 

2682 # Execute the task. 

2683 with set_current_context(context): 

2684 try: 

2685 result = self._execute_task(context, task_orig) 

2686 except Exception: 

2687 # If the task failed, swallow rendering error so it doesn't mask the main error. 

2688 with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): 

2689 self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env) 

2690 raise 

2691 else: # If the task succeeded, render normally to let rendering error bubble up. 

2692 self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env) 

2693 

2694 # Run post_execute callback 

2695 self.task.post_execute(context=context, result=result) 

2696 

2697 Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags) 

2698 # Same metric with tagging 

2699 Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type}) 

2700 Stats.incr("ti_successes", tags=self.stats_tags) 

2701 
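# Illustrative sketch (assumed DAG code, not part of the original source): a DAG author can
# set map_index_template so the _render_map_index helper above gives mapped task instances a
# readable label in the UI; the exact fields available in the render context depend on the
# operator, so the template below is an assumption:
#
#     @task(map_index_template="{{ task.op_kwargs['filename'] }}")
#     def process(filename: str):
#         ...
#
#     process.expand(filename=["a.csv", "b.csv"])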

2702 def _execute_task(self, context: Context, task_orig: Operator): 

2703 """ 

2704 Execute Task (optionally with a Timeout) and push Xcom results. 

2705 

2706 :param context: Jinja2 context 

2707 :param task_orig: origin task 

2708 """ 

2709 return _execute_task(self, context, task_orig) 

2710 

2711 @provide_session 

2712 def defer_task(self, session: Session, defer: TaskDeferred) -> None: 

2713 """Mark the task as deferred and sets up the trigger that is needed to resume it. 

2714 

2715 :meta private: 

2716 """ 

2717 from airflow.models.trigger import Trigger 

2718 

2719 if TYPE_CHECKING: 

2720 assert self.task 

2721 

2722 # First, make the trigger entry 

2723 trigger_row = Trigger.from_object(defer.trigger) 

2724 session.add(trigger_row) 

2725 session.flush() 

2726 

2727 # Then, update ourselves to match the deferral request 

2728 # Keep an eye on the logic in `check_and_change_state_before_execution()` 

2729 # depending on self.next_method semantics 

2730 self.state = TaskInstanceState.DEFERRED 

2731 self.trigger_id = trigger_row.id 

2732 self.next_method = defer.method_name 

2733 self.next_kwargs = defer.kwargs or {} 

2734 

2735 # Calculate timeout too if it was passed 

2736 if defer.timeout is not None: 

2737 self.trigger_timeout = timezone.utcnow() + defer.timeout 

2738 else: 

2739 self.trigger_timeout = None 

2740 

2741 # If an execution_timeout is set, set the timeout to the minimum of 

2742 # it and the trigger timeout 

2743 execution_timeout = self.task.execution_timeout 

2744 if execution_timeout: 

2745 if self.trigger_timeout: 

2746 self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout) 

2747 else: 

2748 self.trigger_timeout = self.start_date + execution_timeout 

2749 
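# Illustrative operator sketch (assumed method bodies): deferral is normally requested from an
# operator's execute() via BaseOperator.defer(), which raises the TaskDeferred handled above:
#
#     from datetime import timedelta
#     from airflow.triggers.temporal import TimeDeltaTrigger
#
#     def execute(self, context):
#         self.defer(
#             trigger=TimeDeltaTrigger(timedelta(minutes=5)),
#             method_name="execute_complete",
#             timeout=timedelta(hours=1),
#         )
#
#     def execute_complete(self, context, event=None):
#         return event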

2750 def _run_execute_callback(self, context: Context, task: BaseOperator) -> None: 

2751 """Functions that need to be run before a Task is executed.""" 

2752 if not (callbacks := task.on_execute_callback): 

2753 return 

2754 for callback in callbacks if isinstance(callbacks, list) else [callbacks]: 

2755 try: 

2756 callback(context) 

2757 except Exception: 

2758 self.log.exception("Failed when executing execute callback") 

2759 

2760 @provide_session 

2761 def run( 

2762 self, 

2763 verbose: bool = True, 

2764 ignore_all_deps: bool = False, 

2765 ignore_depends_on_past: bool = False, 

2766 wait_for_past_depends_before_skipping: bool = False, 

2767 ignore_task_deps: bool = False, 

2768 ignore_ti_state: bool = False, 

2769 mark_success: bool = False, 

2770 test_mode: bool = False, 

2771 job_id: str | None = None, 

2772 pool: str | None = None, 

2773 session: Session = NEW_SESSION, 

2774 raise_on_defer: bool = False, 

2775 ) -> None: 

2776 """Run TaskInstance.""" 

2777 res = self.check_and_change_state_before_execution( 

2778 verbose=verbose, 

2779 ignore_all_deps=ignore_all_deps, 

2780 ignore_depends_on_past=ignore_depends_on_past, 

2781 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, 

2782 ignore_task_deps=ignore_task_deps, 

2783 ignore_ti_state=ignore_ti_state, 

2784 mark_success=mark_success, 

2785 test_mode=test_mode, 

2786 job_id=job_id, 

2787 pool=pool, 

2788 session=session, 

2789 ) 

2790 if not res: 

2791 return 

2792 

2793 self._run_raw_task( 

2794 mark_success=mark_success, 

2795 test_mode=test_mode, 

2796 job_id=job_id, 

2797 pool=pool, 

2798 session=session, 

2799 raise_on_defer=raise_on_defer, 

2800 ) 

2801 

2802 def dry_run(self) -> None: 

2803 """Only Renders Templates for the TI.""" 

2804 if TYPE_CHECKING: 

2805 assert self.task 

2806 

2807 self.task = self.task.prepare_for_execution() 

2808 self.render_templates() 

2809 if TYPE_CHECKING: 

2810 assert isinstance(self.task, BaseOperator) 

2811 self.task.dry_run() 

2812 

2813 @provide_session 

2814 def _handle_reschedule( 

2815 self, 

2816 actual_start_date: datetime, 

2817 reschedule_exception: AirflowRescheduleException, 

2818 test_mode: bool = False, 

2819 session: Session = NEW_SESSION, 

2820 ): 

2821 # Don't record reschedule request in test mode 

2822 if test_mode: 

2823 return 

2824 

2825 from airflow.models.dagrun import DagRun # Avoid circular import 

2826 

2827 self.refresh_from_db(session) 

2828 

2829 if TYPE_CHECKING: 

2830 assert self.task 

2831 

2832 self.end_date = timezone.utcnow() 

2833 self.set_duration() 

2834 

2835 # Lock DAG run to be sure not to get into a deadlock situation when trying to insert 

2836 # TaskReschedule which apparently also creates lock on corresponding DagRun entity 

2837 with_row_locks( 

2838 session.query(DagRun).filter_by( 

2839 dag_id=self.dag_id, 

2840 run_id=self.run_id, 

2841 ), 

2842 session=session, 

2843 ).one() 

2844 

2845 # Log reschedule request 

2846 session.add( 

2847 TaskReschedule( 

2848 self.task_id, 

2849 self.dag_id, 

2850 self.run_id, 

2851 self.try_number, 

2852 actual_start_date, 

2853 self.end_date, 

2854 reschedule_exception.reschedule_date, 

2855 self.map_index, 

2856 ) 

2857 ) 

2858 

2859 # set state 

2860 self.state = TaskInstanceState.UP_FOR_RESCHEDULE 

2861 

2862 self.clear_next_method_args() 

2863 

2864 session.merge(self) 

2865 session.commit() 

2866 self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") 

2867 

2868 @staticmethod 

2869 def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None: 

2870 """ 

2871 Truncate the traceback of an exception to the first frame called from within a given function. 

2872 

2873 :param error: exception to get traceback from 

2874 :param truncate_to: Function to truncate TB to. Must have a ``__code__`` attribute 

2875 

2876 :meta private: 

2877 """ 

2878 tb = error.__traceback__ 

2879 code = truncate_to.__func__.__code__ # type: ignore[attr-defined] 

2880 while tb is not None: 

2881 if tb.tb_frame.f_code is code: 

2882 return tb.tb_next 

2883 tb = tb.tb_next 

2884 return tb or error.__traceback__ 

2885 

2886 @classmethod 

2887 @internal_api_call 

2888 @provide_session 

2889 def fetch_handle_failure_context( 

2890 cls, 

2891 ti: TaskInstance | TaskInstancePydantic, 

2892 error: None | str | BaseException, 

2893 test_mode: bool | None = None, 

2894 context: Context | None = None, 

2895 force_fail: bool = False, 

2896 session: Session = NEW_SESSION, 

2897 fail_stop: bool = False, 

2898 ): 

2899 """ 

2900 Handle Failure for the TaskInstance. 

2901 

2902 :param fail_stop: if true, stop remaining tasks in dag 

2903 """ 

2904 get_listener_manager().hook.on_task_instance_failed( 

2905 previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session 

2906 ) 

2907 

2908 if error: 

2909 if isinstance(error, BaseException): 

2910 tb = TaskInstance.get_truncated_error_traceback(error, truncate_to=ti._execute_task) 

2911 cls.logger().error("Task failed with exception", exc_info=(type(error), error, tb)) 

2912 else: 

2913 cls.logger().error("%s", error) 

2914 if not test_mode: 

2915 ti.refresh_from_db(session) 

2916 

2917 ti.end_date = timezone.utcnow() 

2918 ti.set_duration() 

2919 

2920 Stats.incr(f"operator_failures_{ti.operator}", tags=ti.stats_tags) 

2921 # Same metric with tagging 

2922 Stats.incr("operator_failures", tags={**ti.stats_tags, "operator": ti.operator}) 

2923 Stats.incr("ti_failures", tags=ti.stats_tags) 

2924 

2925 if not test_mode: 

2926 session.add(Log(TaskInstanceState.FAILED.value, ti)) 

2927 

2928 # Log failure duration 

2929 session.add(TaskFail(ti=ti)) 

2930 

2931 ti.clear_next_method_args() 

2932 

2933 # In extreme cases (zombie in case of dag with parse error) we might _not_ have a Task. 

2934 if context is None and getattr(ti, "task", None): 

2935 context = ti.get_template_context(session) 

2936 

2937 if context is not None: 

2938 context["exception"] = error 

2939 

2940 # Set state correctly and figure out how to log it and decide whether 

2941 # to email 

2942 

2943 # Note, callback invocation needs to be handled by caller of 

2944 # _run_raw_task to avoid race conditions which could lead to duplicate 

2945 # invocations or miss invocation. 

2946 

2947 # Since this function is called only when the TaskInstance state is running, 

2948 # try_number contains the current try_number (not the next). We 

2949 # only mark task instance as FAILED if the next task instance 

2950 # try_number exceeds the max_tries ... or if force_fail is truthy 

2951 

2952 task: BaseOperator | None = None 

2953 try: 

2954 if getattr(ti, "task", None) and context: 

2955 if TYPE_CHECKING: 

2956 assert ti.task 

2957 task = ti.task.unmap((context, session)) 

2958 except Exception: 

2959 cls.logger().error("Unable to unmap task to determine if we need to send an alert email") 

2960 

2961 if force_fail or not ti.is_eligible_to_retry(): 

2962 ti.state = TaskInstanceState.FAILED 

2963 email_for_state = operator.attrgetter("email_on_failure") 

2964 callbacks = task.on_failure_callback if task else None 

2965 

2966 if task and fail_stop: 

2967 _stop_remaining_tasks(task_instance=ti, session=session) 

2968 else: 

2969 if ti.state == TaskInstanceState.QUEUED: 

2970 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic 

2971 

2972 if isinstance(ti, TaskInstancePydantic): 

2973 # todo: (AIP-44) we should probably "coalesce" `ti` to TaskInstance before here 

2974 # e.g. we could make refresh_from_db return a TI and replace ti with that 

2975 raise RuntimeError("Expected TaskInstance here. Further AIP-44 work required.") 

2976 # We increase the try_number to fail the task if it fails to start after some time 

2977 ti.state = State.UP_FOR_RETRY 

2978 email_for_state = operator.attrgetter("email_on_retry") 

2979 callbacks = task.on_retry_callback if task else None 

2980 

2981 return { 

2982 "ti": ti, 

2983 "email_for_state": email_for_state, 

2984 "task": task, 

2985 "callbacks": callbacks, 

2986 "context": context, 

2987 } 

2988 

2989 @staticmethod 

2990 @internal_api_call 

2991 @provide_session 

2992 def save_to_db(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION): 

2993 session.merge(ti) 

2994 session.flush() 

2995 

2996 @provide_session 

2997 def handle_failure( 

2998 self, 

2999 error: None | str | BaseException, 

3000 test_mode: bool | None = None, 

3001 context: Context | None = None, 

3002 force_fail: bool = False, 

3003 session: Session = NEW_SESSION, 

3004 ) -> None: 

3005 """ 

3006 Handle Failure for a task instance. 

3007 

3008 :param error: if specified, log the specific exception if thrown 

3009 :param session: SQLAlchemy ORM Session 

3010 :param test_mode: doesn't record success or failure in the DB if True 

3011 :param context: Jinja2 context 

3012 :param force_fail: if True, task does not retry 

3013 """ 

3014 if TYPE_CHECKING: 

3015 assert self.task 

3016 assert self.task.dag 

3017 try: 

3018 fail_stop = self.task.dag.fail_stop 

3019 except Exception: 

3020 fail_stop = False 

3021 _handle_failure( 

3022 task_instance=self, 

3023 error=error, 

3024 session=session, 

3025 test_mode=test_mode, 

3026 context=context, 

3027 force_fail=force_fail, 

3028 fail_stop=fail_stop, 

3029 ) 

3030 

3031 def is_eligible_to_retry(self): 

3032 """Is task instance is eligible for retry.""" 

3033 return _is_eligible_to_retry(task_instance=self) 

3034 

3035 def get_template_context( 

3036 self, 

3037 session: Session | None = None, 

3038 ignore_param_exceptions: bool = True, 

3039 ) -> Context: 

3040 """ 

3041 Return TI Context. 

3042 

3043 :param session: SQLAlchemy ORM Session 

3044 :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict 

3045 """ 

3046 return _get_template_context( 

3047 task_instance=self, 

3048 session=session, 

3049 ignore_param_exceptions=ignore_param_exceptions, 

3050 ) 

3051 
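# --- Illustrative sketch (not part of the covered source) ---
# The returned Context behaves like a mapping of template variables. A hedged sketch of
# inspecting it for debugging; key names such as "ds", "dag_run" and "params" are
# standard Airflow context entries:
#
#     context = ti.get_template_context()
#     print(context["ds"], context["dag_run"].run_id, context["params"])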

3052 @provide_session 

3053 def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None: 

3054 """ 

3055 Update task with rendered template fields for presentation in UI. 

3056 

3057 If the task has already run, the fields are fetched from the DB; otherwise they are rendered. 

3058 """ 

3059 from airflow.models.renderedtifields import RenderedTaskInstanceFields 

3060 

3061 if TYPE_CHECKING: 

3062 assert self.task 

3063 

3064 rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session) 

3065 if rendered_task_instance_fields: 

3066 self.task = self.task.unmap(None) 

3067 for field_name, rendered_value in rendered_task_instance_fields.items(): 

3068 setattr(self.task, field_name, rendered_value) 

3069 return 

3070 

3071 try: 

3072 # If we get here, either the task hasn't run or the RTIF record was purged. 

3073 from airflow.utils.log.secrets_masker import redact 

3074 

3075 self.render_templates() 

3076 for field_name in self.task.template_fields: 

3077 rendered_value = getattr(self.task, field_name) 

3078 setattr(self.task, field_name, redact(rendered_value, field_name)) 

3079 except (TemplateAssertionError, UndefinedError) as e: 

3080 raise AirflowException( 

3081 "Webserver does not have access to User-defined Macros or Filters " 

3082 "when Dag Serialization is enabled. Hence for the task that have not yet " 

3083 "started running, please use 'airflow tasks render' for debugging the " 

3084 "rendering of template_fields." 

3085 ) from e 

3086 

3087 def overwrite_params_with_dag_run_conf(self, params: dict, dag_run: DagRun): 

3088 """Overwrite Task Params with DagRun.conf.""" 

3089 if dag_run and dag_run.conf: 

3090 self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf) 

3091 params.update(dag_run.conf) 

3092 

3093 def render_templates( 

3094 self, context: Context | None = None, jinja_env: jinja2.Environment | None = None 

3095 ) -> Operator: 

3096 """Render templates in the operator fields. 

3097 

3098 If the task was originally mapped, this may replace ``self.task`` with 

3099 the unmapped, fully rendered BaseOperator. The original ``self.task`` 

3100 before replacement is returned. 

3101 """ 

3102 if not context: 

3103 context = self.get_template_context() 

3104 original_task = self.task 

3105 

3106 if TYPE_CHECKING: 

3107 assert original_task 

3108 

3109 # If self.task is mapped, this call replaces self.task to point to the 

3110 # unmapped BaseOperator created by this function! This is because the 

3111 # MappedOperator is useless for template rendering, and we need to be 

3112 # able to access the unmapped task instead. 

3113 original_task.render_template_fields(context, jinja_env) 

3114 

3115 return original_task 

3116 
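# --- Illustrative sketch (not part of the covered source) ---
# render_templates() fills in the operator's template_fields from the TI's context and
# returns the *original* (possibly mapped) task, while ti.task may be replaced by the
# unmapped operator. A hedged usage sketch, assuming a hypothetical "ti" bound to a
# BashOperator whose bash_command contains Jinja:
#
#     original = ti.render_templates()      # ti.task may now point to the unmapped task
#     print(ti.task.bash_command)           # rendered string, e.g. "echo 2024-01-01"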

3117 def render_k8s_pod_yaml(self) -> dict | None: 

3118 """Render the k8s pod yaml.""" 

3119 try: 

3120 from airflow.providers.cncf.kubernetes.template_rendering import ( 

3121 render_k8s_pod_yaml as render_k8s_pod_yaml_from_provider, 

3122 ) 

3123 except ImportError: 

3124 raise RuntimeError( 

3125 "You need to have the `cncf.kubernetes` provider installed to use this feature. " 

3126 "Also rather than calling it directly you should import " 

3127 "render_k8s_pod_yaml from airflow.providers.cncf.kubernetes.template_rendering " 

3128 "and call it with TaskInstance as the first argument." 

3129 ) 

3130 warnings.warn( 

3131 "You should not call `task_instance.render_k8s_pod_yaml` directly. This method will be removed" 

3132 "in Airflow 3. Rather than calling it directly you should import " 

3133 "`render_k8s_pod_yaml` from `airflow.providers.cncf.kubernetes.template_rendering` " 

3134 "and call it with `TaskInstance` as the first argument.", 

3135 DeprecationWarning, 

3136 stacklevel=2, 

3137 ) 

3138 return render_k8s_pod_yaml_from_provider(self) 

3139 

3140 @provide_session 

3141 def get_rendered_k8s_spec(self, session: Session = NEW_SESSION): 

3142 """Render the k8s pod yaml.""" 

3143 try: 

3144 from airflow.providers.cncf.kubernetes.template_rendering import ( 

3145 get_rendered_k8s_spec as get_rendered_k8s_spec_from_provider, 

3146 ) 

3147 except ImportError: 

3148 raise RuntimeError( 

3149 "You need to have the `cncf.kubernetes` provider installed to use this feature. " 

3150 "Also rather than calling it directly you should import " 

3151 "`get_rendered_k8s_spec` from `airflow.providers.cncf.kubernetes.template_rendering` " 

3152 "and call it with `TaskInstance` as the first argument." 

3153 ) 

3154 warnings.warn( 

3155 "You should not call `task_instance.render_k8s_pod_yaml` directly. This method will be removed" 

3156 "in Airflow 3. Rather than calling it directly you should import " 

3157 "`get_rendered_k8s_spec` from `airflow.providers.cncf.kubernetes.template_rendering` " 

3158 "and call it with `TaskInstance` as the first argument.", 

3159 DeprecationWarning, 

3160 stacklevel=2, 

3161 ) 

3162 return get_rendered_k8s_spec_from_provider(self, session=session) 

3163 

3164 def get_email_subject_content( 

3165 self, exception: BaseException, task: BaseOperator | None = None 

3166 ) -> tuple[str, str, str]: 

3167 """ 

3168 Get the email subject content for exceptions. 

3169 

3170 :param exception: the exception sent in the email 

3171 :param task: the task associated with the exception, if available 

3172 """ 

3173 return _get_email_subject_content(task_instance=self, exception=exception, task=task) 

3174 

3175 def email_alert(self, exception, task: BaseOperator) -> None: 

3176 """ 

3177 Send alert email with exception information. 

3178 

3179 :param exception: the exception 

3180 :param task: task related to the exception 

3181 """ 

3182 _email_alert(task_instance=self, exception=exception, task=task) 

3183 

3184 def set_duration(self) -> None: 

3185 """Set task instance duration.""" 

3186 _set_duration(task_instance=self) 

3187 

3188 @provide_session 

3189 def xcom_push( 

3190 self, 

3191 key: str, 

3192 value: Any, 

3193 execution_date: datetime | None = None, 

3194 session: Session = NEW_SESSION, 

3195 ) -> None: 

3196 """ 

3197 Make an XCom available for tasks to pull. 

3198 

3199 :param key: Key to store the value under. 

3200 :param value: Value to store. What types are possible depends on whether 

3201 ``enable_xcom_pickling`` is true or not. If so, this can be any 

3202 picklable object; only JSON-serializable values may be used otherwise. 

3203 :param execution_date: Deprecated parameter that has no effect. 

3204 """ 

3205 if execution_date is not None: 

3206 self_execution_date = self.get_dagrun(session).execution_date 

3207 if execution_date < self_execution_date: 

3208 raise ValueError( 

3209 f"execution_date can not be in the past (current execution_date is " 

3210 f"{self_execution_date}; received {execution_date})" 

3211 ) 

3212 elif execution_date is not None: 

3213 message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated." 

3214 warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) 

3215 

3216 XCom.set( 

3217 key=key, 

3218 value=value, 

3219 task_id=self.task_id, 

3220 dag_id=self.dag_id, 

3221 run_id=self.run_id, 

3222 map_index=self.map_index, 

3223 session=session, 

3224 ) 

3225 
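# --- Illustrative sketch (not part of the covered source) ---
# Typical use of xcom_push from inside a running task, via the "ti" object in the
# template context. A minimal sketch assuming a hypothetical callable in a DAG:
#
#     def produce(**context):
#         context["ti"].xcom_push(key="row_count", value=42)
#
# Values must be JSON-serializable unless enable_xcom_pickling is turned on, as noted
# in the docstring above.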

3226 @provide_session 

3227 def xcom_pull( 

3228 self, 

3229 task_ids: str | Iterable[str] | None = None, 

3230 dag_id: str | None = None, 

3231 key: str = XCOM_RETURN_KEY, 

3232 include_prior_dates: bool = False, 

3233 session: Session = NEW_SESSION, 

3234 *, 

3235 map_indexes: int | Iterable[int] | None = None, 

3236 default: Any = None, 

3237 ) -> Any: 

3238 """Pull XComs that optionally meet certain criteria. 

3239 

3240 :param key: A key for the XCom. If provided, only XComs with matching 

3241 keys will be returned. The default key is ``'return_value'``, also 

3242 available as constant ``XCOM_RETURN_KEY``. This key is automatically 

3243 given to XComs returned by tasks (as opposed to being pushed 

3244 manually). To remove the filter, pass *None*. 

3245 :param task_ids: Only XComs from tasks with matching ids will be 

3246 pulled. Pass *None* to remove the filter. 

3247 :param dag_id: If provided, only pulls XComs from this DAG. If *None* 

3248 (default), the DAG of the calling task is used. 

3249 :param map_indexes: If provided, only pull XComs with matching indexes. 

3250 If *None* (default), this is inferred from the task(s) being pulled 

3251 (see below for details). 

3252 :param include_prior_dates: If False, only XComs from the current 

3253 execution_date are returned. If *True*, XComs from previous dates 

3254 are returned as well. 

3255 

3256 When pulling one single task (``task_ids`` is *None* or a str) without 

3257 specifying ``map_indexes``, the return value is inferred from whether 

3258 the specified task is mapped. If not, value from the one single task 

3259 instance is returned. If the task to pull is mapped, an iterator (not a 

3260 list) yielding XComs from mapped task instances is returned. In either 

3261 case, ``default`` (*None* if not specified) is returned if no matching 

3262 XComs are found. 

3263 

3264 When pulling multiple tasks (i.e. either ``task_ids`` or ``map_indexes`` is 

3265 a non-str iterable), a list of matching XComs is returned. Elements in 

3266 the list are ordered by item ordering in ``task_ids`` and ``map_indexes``. 

3267 """ 

3268 if dag_id is None: 

3269 dag_id = self.dag_id 

3270 

3271 query = XCom.get_many( 

3272 key=key, 

3273 run_id=self.run_id, 

3274 dag_ids=dag_id, 

3275 task_ids=task_ids, 

3276 map_indexes=map_indexes, 

3277 include_prior_dates=include_prior_dates, 

3278 session=session, 

3279 ) 

3280 

3281 # NOTE: Since we're only fetching the value field and not the whole 

3282 # class, the @recreate annotation does not kick in. Therefore we need to 

3283 # call XCom.deserialize_value() manually. 

3284 

3285 # We are only pulling one single task. 

3286 if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable): 

3287 first = query.with_entities( 

3288 XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value 

3289 ).first() 

3290 if first is None: # No matching XCom at all. 

3291 return default 

3292 if map_indexes is not None or first.map_index < 0: 

3293 return XCom.deserialize_value(first) 

3294 return LazyXComSelectSequence.from_select( 

3295 query.with_entities(XCom.value).order_by(None).statement, 

3296 order_by=[XCom.map_index], 

3297 session=session, 

3298 ) 

3299 

3300 # At this point either task_ids or map_indexes is explicitly multi-value. 

3301 # Order return values to match task_ids and map_indexes ordering. 

3302 ordering = [] 

3303 if task_ids is None or isinstance(task_ids, str): 

3304 ordering.append(XCom.task_id) 

3305 elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}: 

3306 ordering.append(case(task_id_whens, value=XCom.task_id)) 

3307 else: 

3308 ordering.append(XCom.task_id) 

3309 if map_indexes is None or isinstance(map_indexes, int): 

3310 ordering.append(XCom.map_index) 

3311 elif isinstance(map_indexes, range): 

3312 order = XCom.map_index 

3313 if map_indexes.step < 0: 

3314 order = order.desc() 

3315 ordering.append(order) 

3316 elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}: 

3317 ordering.append(case(map_index_whens, value=XCom.map_index)) 

3318 else: 

3319 ordering.append(XCom.map_index) 

3320 return LazyXComSelectSequence.from_select( 

3321 query.with_entities(XCom.value).order_by(None).statement, 

3322 order_by=ordering, 

3323 session=session, 

3324 ) 

3325 
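# --- Illustrative sketch (not part of the covered source) ---
# xcom_pull return shapes, per the docstring above (task names are hypothetical):
#
#     value  = ti.xcom_pull(task_ids="produce")                      # single value, or default
#     values = ti.xcom_pull(task_ids=["a", "b"])                     # list ordered like task_ids
#     mapped = ti.xcom_pull(task_ids="mapped_task")                  # lazy sequence for a mapped task
#     first  = ti.xcom_pull(task_ids="mapped_task", map_indexes=0)   # one mapped TI's value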

3326 @provide_session 

3327 def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int: 

3328 """Return Number of running TIs from the DB.""" 

3329 # .count() is inefficient 

3330 num_running_task_instances_query = session.query(func.count()).filter( 

3331 TaskInstance.dag_id == self.dag_id, 

3332 TaskInstance.task_id == self.task_id, 

3333 TaskInstance.state == TaskInstanceState.RUNNING, 

3334 ) 

3335 if same_dagrun: 

3336 num_running_task_instances_query = num_running_task_instances_query.filter( 

3337 TaskInstance.run_id == self.run_id 

3338 ) 

3339 return num_running_task_instances_query.scalar() 

3340 

3341 def init_run_context(self, raw: bool = False) -> None: 

3342 """Set the log context.""" 

3343 self.raw = raw 

3344 self._set_context(self) 

3345 

3346 @staticmethod 

3347 def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None: 

3348 """Return SQLAlchemy filter to query selected task instances.""" 

3349 # DictKeys type (what we often pass here from the scheduler) is not directly indexable :( 

3350 # Or it might be a generator, but we need to be able to iterate over it more than once 

3351 tis = list(tis) 

3352 

3353 if not tis: 

3354 return None 

3355 

3356 first = tis[0] 

3357 

3358 dag_id = first.dag_id 

3359 run_id = first.run_id 

3360 map_index = first.map_index 

3361 first_task_id = first.task_id 

3362 

3363 # pre-compute the set of dag_id, run_id, map_indices and task_ids 

3364 dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set() 

3365 for t in tis: 

3366 dag_ids.add(t.dag_id) 

3367 run_ids.add(t.run_id) 

3368 map_indices.add(t.map_index) 

3369 task_ids.add(t.task_id) 

3370 

3371 # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id 

3372 # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+) 

3373 if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}: 

3374 return and_( 

3375 TaskInstance.dag_id == dag_id, 

3376 TaskInstance.run_id == run_id, 

3377 TaskInstance.map_index == map_index, 

3378 TaskInstance.task_id.in_(task_ids), 

3379 ) 

3380 if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}: 

3381 return and_( 

3382 TaskInstance.dag_id == dag_id, 

3383 TaskInstance.run_id.in_(run_ids), 

3384 TaskInstance.map_index == map_index, 

3385 TaskInstance.task_id == first_task_id, 

3386 ) 

3387 if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}: 

3388 return and_( 

3389 TaskInstance.dag_id == dag_id, 

3390 TaskInstance.run_id == run_id, 

3391 TaskInstance.map_index.in_(map_indices), 

3392 TaskInstance.task_id == first_task_id, 

3393 ) 

3394 

3395 filter_condition = [] 

3396 # create 2 nested groups, both primarily grouped by dag_id and run_id, 

3397 # and in the nested group 1 grouped by task_id the other by map_index. 

3398 task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list)) 

3399 map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list)) 

3400 for t in tis: 

3401 task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index) 

3402 map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id) 

3403 

3404 # This assumes that most DAGs have dag_id as the largest grouping, followed by run_id. Even 

3405 # if it's not, this is still a significant optimization over querying for every single tuple key 

3406 for cur_dag_id, cur_run_id in itertools.product(dag_ids, run_ids): 

3407 # we compare the group size between task_id and map_index and use the smaller group 

3408 dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)] 

3409 dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)] 

3410 

3411 if len(dag_task_id_groups) <= len(dag_map_index_groups): 

3412 for cur_task_id, cur_map_indices in dag_task_id_groups.items(): 

3413 filter_condition.append( 

3414 and_( 

3415 TaskInstance.dag_id == cur_dag_id, 

3416 TaskInstance.run_id == cur_run_id, 

3417 TaskInstance.task_id == cur_task_id, 

3418 TaskInstance.map_index.in_(cur_map_indices), 

3419 ) 

3420 ) 

3421 else: 

3422 for cur_map_index, cur_task_ids in dag_map_index_groups.items(): 

3423 filter_condition.append( 

3424 and_( 

3425 TaskInstance.dag_id == cur_dag_id, 

3426 TaskInstance.run_id == cur_run_id, 

3427 TaskInstance.task_id.in_(cur_task_ids), 

3428 TaskInstance.map_index == cur_map_index, 

3429 ) 

3430 ) 

3431 

3432 return or_(*filter_condition) 

3433 
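# --- Illustrative sketch (not part of the covered source) ---
# filter_for_tis() collapses a set of TIs into as few IN-clauses as possible. A hedged
# usage sketch, assuming "keys" is a list of TaskInstanceKey objects that all share the
# same dag_id, run_id and map_index (the first fast path above):
#
#     condition = TaskInstance.filter_for_tis(keys)
#     if condition is not None:
#         tis = session.query(TaskInstance).filter(condition).all()
#
# That fast path yields roughly:
#     dag_id = 'd' AND run_id = 'r' AND map_index = -1 AND task_id IN (...)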

3434 @classmethod 

3435 def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators: 

3436 """ 

3437 Build an SQLAlchemy filter for a list of task_ids or tuples of (task_id, map_index). 

3438 

3439 :meta private: 

3440 """ 

3441 # Compute a filter for TI.task_id and TI.map_index based on input values 

3442 # For each item, it will either be a task_id, or (task_id, map_index) 

3443 task_id_only = [v for v in vals if isinstance(v, str)] 

3444 with_map_index = [v for v in vals if not isinstance(v, str)] 

3445 

3446 filters: list[ColumnOperators] = [] 

3447 if task_id_only: 

3448 filters.append(cls.task_id.in_(task_id_only)) 

3449 if with_map_index: 

3450 filters.append(tuple_in_condition((cls.task_id, cls.map_index), with_map_index)) 

3451 

3452 if not filters: 

3453 return false() 

3454 if len(filters) == 1: 

3455 return filters[0] 

3456 return or_(*filters) 

3457 
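# --- Illustrative sketch (not part of the covered source) ---
# ti_selector_condition() accepts a mix of bare task_ids and (task_id, map_index)
# tuples. A hedged example combining both shapes (names are hypothetical):
#
#     cond = TaskInstance.ti_selector_condition(["extract", ("load", 0), ("load", 1)])
#     session.query(TaskInstance).filter(TaskInstance.run_id == run_id, cond)
#
# which is equivalent to:
#     task_id IN ('extract') OR (task_id, map_index) IN (('load', 0), ('load', 1))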

3458 @classmethod 

3459 @internal_api_call 

3460 @provide_session 

3461 def _schedule_downstream_tasks( 

3462 cls, 

3463 ti: TaskInstance | TaskInstancePydantic, 

3464 session: Session = NEW_SESSION, 

3465 max_tis_per_query: int | None = None, 

3466 ): 

3467 from sqlalchemy.exc import OperationalError 

3468 

3469 from airflow.models.dagrun import DagRun 

3470 

3471 try: 

3472 # Re-select the row with a lock 

3473 dag_run = with_row_locks( 

3474 session.query(DagRun).filter_by( 

3475 dag_id=ti.dag_id, 

3476 run_id=ti.run_id, 

3477 ), 

3478 session=session, 

3479 nowait=True, 

3480 ).one() 

3481 

3482 task = ti.task 

3483 if TYPE_CHECKING: 

3484 assert task 

3485 assert task.dag 

3486 

3487 # Get a partial DAG with just the specific tasks we want to examine. 

3488 # In order for dep checks to work correctly, we include ourself (so 

3489 # TriggerRuleDep can check the state of the task we just executed). 

3490 partial_dag = task.dag.partial_subset( 

3491 task.downstream_task_ids, 

3492 include_downstream=True, 

3493 include_upstream=False, 

3494 include_direct_upstream=True, 

3495 ) 

3496 

3497 dag_run.dag = partial_dag 

3498 info = dag_run.task_instance_scheduling_decisions(session) 

3499 

3500 skippable_task_ids = { 

3501 task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids 

3502 } 

3503 

3504 schedulable_tis = [ 

3505 ti 

3506 for ti in info.schedulable_tis 

3507 if ti.task_id not in skippable_task_ids 

3508 and not ( 

3509 ti.task.inherits_from_empty_operator 

3510 and not ti.task.on_execute_callback 

3511 and not ti.task.on_success_callback 

3512 and not ti.task.outlets 

3513 ) 

3514 ] 

3515 for schedulable_ti in schedulable_tis: 

3516 if getattr(schedulable_ti, "task", None) is None: 

3517 schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id) 

3518 

3519 num = dag_run.schedule_tis(schedulable_tis, session=session, max_tis_per_query=max_tis_per_query) 

3520 cls.logger().info("%d downstream tasks scheduled from follow-on schedule check", num) 

3521 

3522 session.flush() 

3523 

3524 except OperationalError as e: 

3525 # Any kind of DB error here is _non fatal_ as this block is just an optimisation. 

3526 cls.logger().debug( 

3527 "Skipping mini scheduling run due to exception: %s", 

3528 e.statement, 

3529 exc_info=True, 

3530 ) 

3531 session.rollback() 

3532 

3533 @provide_session 

3534 def schedule_downstream_tasks(self, session: Session = NEW_SESSION, max_tis_per_query: int | None = None): 

3535 """ 

3536 Schedule downstream tasks of this task instance. 

3537 

3538 :meta private: 

3539 """ 

3540 return TaskInstance._schedule_downstream_tasks( 

3541 ti=self, session=session, max_tis_per_query=max_tis_per_query 

3542 ) 

3543 

3544 def get_relevant_upstream_map_indexes( 

3545 self, 

3546 upstream: Operator, 

3547 ti_count: int | None, 

3548 *, 

3549 session: Session, 

3550 ) -> int | range | None: 

3551 """Infer the map indexes of an upstream "relevant" to this ti. 

3552 

3553 The bulk of the logic mainly exists to solve the problem described by 

3554 the following example, where 'val' must resolve to different values, 

3555 depending on where the reference is being used:: 

3556 

3557 @task 

3558 def this_task(v): # This is self.task. 

3559 return v * 2 

3560 

3561 

3562 @task_group 

3563 def tg1(inp): 

3564 val = upstream(inp) # This is the upstream task. 

3565 this_task(val) # When inp is 1, val here should resolve to 2. 

3566 return val 

3567 

3568 

3569 # This val is the same object returned by tg1. 

3570 val = tg1.expand(inp=[1, 2, 3]) 

3571 

3572 

3573 @task_group 

3574 def tg2(inp): 

3575 another_task(inp, val) # val here should resolve to [2, 4, 6]. 

3576 

3577 

3578 tg2.expand(inp=["a", "b"]) 

3579 

3580 The surrounding mapped task groups of ``upstream`` and ``self.task`` are 

3581 inspected to find a common "ancestor". If such an ancestor is found, 

3582 we need to return specific map indexes to pull a partial value from 

3583 upstream XCom. 

3584 

3585 :param upstream: The referenced upstream task. 

3586 :param ti_count: The total number of task instances this task was expanded 

3587 into by the scheduler, i.e. ``expanded_ti_count`` in the template context. 

3588 :return: Specific map index or map indexes to pull, or ``None`` if we 

3589 want the "whole" return value (i.e. no mapped task groups involved). 

3590 """ 

3591 if TYPE_CHECKING: 

3592 assert self.task 

3593 

3594 # This value should never be None since we already know the current task 

3595 # is in a mapped task group and should have been expanded. Despite that, 

3596 # we need to check that it is not None to satisfy Mypy. 

3597 # This value can also be 0 when we expand an empty list; in that case we 

3598 # must check that ti_count is not 0 to avoid dividing by zero. 

3599 if not ti_count: 

3600 return None 

3601 

3602 # Find the innermost common mapped task group between the current task 

3603 # and the referenced upstream task. If they do not have a common 

3604 # mapped task group, the two are in different task mapping contexts 

3605 # (like another_task above), and we should use the "whole" value. 

3606 common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream) 

3607 if common_ancestor is None: 

3608 return None 

3609 

3610 # At this point we know the two tasks share a mapped task group, and we 

3611 # should use a "partial" value. Let's break down the mapped ti count 

3612 # between the ancestor and further expansion happened inside it. 

3613 ancestor_ti_count = common_ancestor.get_mapped_ti_count(self.run_id, session=session) 

3614 ancestor_map_index = self.map_index * ancestor_ti_count // ti_count 

3615 

3616 # If the task is NOT further expanded inside the common ancestor, we 

3617 # only want to reference one single ti. We must walk the actual DAG, 

3618 # and "ti_count == ancestor_ti_count" does not work, since the further 

3619 # expansion may be of length 1. 

3620 if not _is_further_mapped_inside(upstream, common_ancestor): 

3621 return ancestor_map_index 

3622 

3623 # Otherwise we need a partial aggregation for values from selected task 

3624 # instances in the ancestor's expansion context. 

3625 further_count = ti_count // ancestor_ti_count 

3626 map_index_start = ancestor_map_index * further_count 

3627 return range(map_index_start, map_index_start + further_count) 

3628 
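# --- Illustrative worked example (not part of the covered source) ---
# The arithmetic above with concrete numbers: suppose the common ancestor group expands
# to ancestor_ti_count = 3 and the current task is further expanded so that ti_count = 6.
# For the TI with map_index = 4:
#
#     ancestor_map_index = 4 * 3 // 6 = 2       # which ancestor "slot" we are in
#
# If the upstream is not further mapped inside the ancestor, index 2 is returned
# directly. If it is:
#
#     further_count   = 6 // 3 = 2
#     map_index_start = 2 * 2  = 4
#     result          = range(4, 6)             # pull upstream map_indexes 4 and 5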

3629 def clear_db_references(self, session: Session): 

3630 """ 

3631 Clear db tables that have a reference to this instance. 

3632 

3633 :param session: ORM Session 

3634 

3635 :meta private: 

3636 """ 

3637 from airflow.models.renderedtifields import RenderedTaskInstanceFields 

3638 

3639 tables: list[type[TaskInstanceDependencies]] = [ 

3640 TaskFail, 

3641 TaskInstanceNote, 

3642 TaskReschedule, 

3643 XCom, 

3644 RenderedTaskInstanceFields, 

3645 TaskMap, 

3646 ] 

3647 for table in tables: 

3648 session.execute( 

3649 delete(table).where( 

3650 table.dag_id == self.dag_id, 

3651 table.task_id == self.task_id, 

3652 table.run_id == self.run_id, 

3653 table.map_index == self.map_index, 

3654 ) 

3655 ) 

3656 

3657 

3658def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None: 

3659 """Given two operators, find their innermost common mapped task group.""" 

3660 if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id: 

3661 return None 

3662 parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()} 

3663 common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids) 

3664 return next(common_groups, None) 

3665 

3666 

3667def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool: 

3668 """Whether given operator is *further* mapped inside a task group.""" 

3669 if isinstance(operator, MappedOperator): 

3670 return True 

3671 task_group = operator.task_group 

3672 while task_group is not None and task_group.group_id != container.group_id: 

3673 if isinstance(task_group, MappedTaskGroup): 

3674 return True 

3675 task_group = task_group.parent_group 

3676 return False 

3677 

3678 

3679# State of the task instance. 

3680# A (TaskInstanceKey, TaskInstanceState) pair describing a task instance's state. 

3681TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState] 

3682 

3683 

3684class SimpleTaskInstance: 

3685 """ 

3686 Simplified Task Instance. 

3687 

3688 Used to send data between processes via Queues. 

3689 """ 

3690 

3691 def __init__( 

3692 self, 

3693 dag_id: str, 

3694 task_id: str, 

3695 run_id: str, 

3696 start_date: datetime | None, 

3697 end_date: datetime | None, 

3698 try_number: int, 

3699 map_index: int, 

3700 state: str, 

3701 executor: str | None, 

3702 executor_config: Any, 

3703 pool: str, 

3704 queue: str, 

3705 key: TaskInstanceKey, 

3706 run_as_user: str | None = None, 

3707 priority_weight: int | None = None, 

3708 ): 

3709 self.dag_id = dag_id 

3710 self.task_id = task_id 

3711 self.run_id = run_id 

3712 self.map_index = map_index 

3713 self.start_date = start_date 

3714 self.end_date = end_date 

3715 self.try_number = try_number 

3716 self.state = state 

3717 self.executor = executor 

3718 self.executor_config = executor_config 

3719 self.run_as_user = run_as_user 

3720 self.pool = pool 

3721 self.priority_weight = priority_weight 

3722 self.queue = queue 

3723 self.key = key 

3724 

3725 def __eq__(self, other): 

3726 if isinstance(other, self.__class__): 

3727 return self.__dict__ == other.__dict__ 

3728 return NotImplemented 

3729 

3730 def as_dict(self): 

3731 warnings.warn( 

3732 "This method is deprecated. Use BaseSerialization.serialize.", 

3733 RemovedInAirflow3Warning, 

3734 stacklevel=2, 

3735 ) 

3736 new_dict = dict(self.__dict__) 

3737 for key in new_dict: 

3738 if key in ["start_date", "end_date"]: 

3739 val = new_dict[key] 

3740 if not val or isinstance(val, str): 

3741 continue 

3742 new_dict.update({key: val.isoformat()}) 

3743 return new_dict 

3744 

3745 @classmethod 

3746 def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance: 

3747 return cls( 

3748 dag_id=ti.dag_id, 

3749 task_id=ti.task_id, 

3750 run_id=ti.run_id, 

3751 map_index=ti.map_index, 

3752 start_date=ti.start_date, 

3753 end_date=ti.end_date, 

3754 try_number=ti.try_number, 

3755 state=ti.state, 

3756 executor=ti.executor, 

3757 executor_config=ti.executor_config, 

3758 pool=ti.pool, 

3759 queue=ti.queue, 

3760 key=ti.key, 

3761 run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None, 

3762 priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None, 

3763 ) 

3764 

3765 @classmethod 

3766 def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance: 

3767 warnings.warn( 

3768 "This method is deprecated. Use BaseSerialization.deserialize.", 

3769 RemovedInAirflow3Warning, 

3770 stacklevel=2, 

3771 ) 

3772 ti_key = TaskInstanceKey(*obj_dict.pop("key")) 

3773 start_date = None 

3774 end_date = None 

3775 start_date_str: str | None = obj_dict.pop("start_date") 

3776 end_date_str: str | None = obj_dict.pop("end_date") 

3777 if start_date_str: 

3778 start_date = timezone.parse(start_date_str) 

3779 if end_date_str: 

3780 end_date = timezone.parse(end_date_str) 

3781 return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key) 

3782 
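# --- Illustrative sketch (not part of the covered source) ---
# SimpleTaskInstance is a plain-attribute snapshot of a TaskInstance, cheap to pickle
# onto a multiprocessing queue. A hedged sketch of the round trip (queue is hypothetical):
#
#     simple = SimpleTaskInstance.from_ti(ti)   # detach a snapshot from the ORM object
#     queue.put(simple)                         # safe to send between processes
#     other = queue.get()
#     assert other == simple                    # __eq__ compares __dict__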

3783 

3784class TaskInstanceNote(TaskInstanceDependencies): 

3785 """For storage of arbitrary notes concerning the task instance.""" 

3786 

3787 __tablename__ = "task_instance_note" 

3788 

3789 user_id = Column(Integer, ForeignKey("ab_user.id", name="task_instance_note_user_fkey"), nullable=True) 

3790 task_id = Column(StringID(), primary_key=True, nullable=False) 

3791 dag_id = Column(StringID(), primary_key=True, nullable=False) 

3792 run_id = Column(StringID(), primary_key=True, nullable=False) 

3793 map_index = Column(Integer, primary_key=True, nullable=False) 

3794 content = Column(String(1000).with_variant(Text(1000), "mysql")) 

3795 created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) 

3796 updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) 

3797 

3798 task_instance = relationship("TaskInstance", back_populates="task_instance_note") 

3799 

3800 __table_args__ = ( 

3801 PrimaryKeyConstraint("task_id", "dag_id", "run_id", "map_index", name="task_instance_note_pkey"), 

3802 ForeignKeyConstraint( 

3803 (dag_id, task_id, run_id, map_index), 

3804 [ 

3805 "task_instance.dag_id", 

3806 "task_instance.task_id", 

3807 "task_instance.run_id", 

3808 "task_instance.map_index", 

3809 ], 

3810 name="task_instance_note_ti_fkey", 

3811 ondelete="CASCADE", 

3812 ), 

3813 ) 

3814 

3815 def __init__(self, content, user_id=None): 

3816 self.content = content 

3817 self.user_id = user_id 

3818 

3819 def __repr__(self): 

3820 prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}" 

3821 if self.map_index != -1: 

3822 prefix += f" map_index={self.map_index}" 

3823 return prefix + ">" 

3824 

3825 

3826STATICA_HACK = True 

3827globals()["kcah_acitats"[::-1].upper()] = False 

3828if STATICA_HACK: # pragma: no cover 

3829 from airflow.jobs.job import Job 

3830 

3831 TaskInstance.queued_by_job = relationship(Job)