Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/taskinstance.py: 21%

1316 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections.abc 

21import contextlib 

22import hashlib 

23import logging 

24import math 

25import operator 

26import os 

27import signal 

28import warnings 

29from collections import defaultdict 

30from datetime import datetime, timedelta 

31from enum import Enum 

32from functools import partial 

33from pathlib import PurePath 

34from types import TracebackType 

35from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Tuple 

36from urllib.parse import quote 

37 

38import dill 

39import jinja2 

40import lazy_object_proxy 

41import pendulum 

42from jinja2 import TemplateAssertionError, UndefinedError 

43from sqlalchemy import ( 

44 Column, 

45 DateTime, 

46 Float, 

47 ForeignKeyConstraint, 

48 Index, 

49 Integer, 

50 PrimaryKeyConstraint, 

51 String, 

52 Text, 

53 and_, 

54 delete, 

55 false, 

56 func, 

57 inspect, 

58 or_, 

59 text, 

60) 

61from sqlalchemy.ext.associationproxy import association_proxy 

62from sqlalchemy.ext.mutable import MutableDict 

63from sqlalchemy.orm import reconstructor, relationship 

64from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value 

65from sqlalchemy.orm.session import Session 

66from sqlalchemy.sql.elements import BooleanClauseList 

67from sqlalchemy.sql.expression import ColumnOperators, case 

68 

69from airflow import settings 

70from airflow.compat.functools import cache 

71from airflow.configuration import conf 

72from airflow.datasets import Dataset 

73from airflow.datasets.manager import dataset_manager 

74from airflow.exceptions import ( 

75 AirflowException, 

76 AirflowFailException, 

77 AirflowRescheduleException, 

78 AirflowSensorTimeout, 

79 AirflowSkipException, 

80 AirflowTaskTimeout, 

81 DagRunNotFound, 

82 RemovedInAirflow3Warning, 

83 TaskDeferralError, 

84 TaskDeferred, 

85 UnmappableXComLengthPushed, 

86 UnmappableXComTypePushed, 

87 XComForMappingNotPushed, 

88) 

89from airflow.listeners.listener import get_listener_manager 

90from airflow.models.base import Base, StringID 

91from airflow.models.dagbag import DagBag 

92from airflow.models.log import Log 

93from airflow.models.mappedoperator import MappedOperator 

94from airflow.models.param import process_params 

95from airflow.models.taskfail import TaskFail 

96from airflow.models.taskinstancekey import TaskInstanceKey 

97from airflow.models.taskmap import TaskMap 

98from airflow.models.taskreschedule import TaskReschedule 

99from airflow.models.xcom import LazyXComAccess, XCom 

100from airflow.plugins_manager import integrate_macros_plugins 

101from airflow.sentry import Sentry 

102from airflow.stats import Stats 

103from airflow.templates import SandboxedEnvironment 

104from airflow.ti_deps.dep_context import DepContext 

105from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS 

106from airflow.timetables.base import DataInterval 

107from airflow.typing_compat import Literal, TypeGuard 

108from airflow.utils import timezone 

109from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor, context_merge 

110from airflow.utils.email import send_email 

111from airflow.utils.helpers import prune_dict, render_template_to_string 

112from airflow.utils.log.logging_mixin import LoggingMixin 

113from airflow.utils.module_loading import qualname 

114from airflow.utils.net import get_hostname 

115from airflow.utils.operator_helpers import context_to_airflow_vars 

116from airflow.utils.platform import getuser 

117from airflow.utils.retries import run_with_db_retries 

118from airflow.utils.session import NEW_SESSION, create_session, provide_session 

119from airflow.utils.sqlalchemy import ( 

120 ExecutorConfigType, 

121 ExtendedJSON, 

122 UtcDateTime, 

123 tuple_in_condition, 

124 with_row_locks, 

125) 

126from airflow.utils.state import DagRunState, State, TaskInstanceState 

127from airflow.utils.task_group import MappedTaskGroup 

128from airflow.utils.timeout import timeout 

129from airflow.utils.xcom import XCOM_RETURN_KEY 

130 

# Short alias for the TaskReschedule model, used when building queries in this module.
TR = TaskReschedule

# Stack of execution contexts: set_current_context() appends on entry and pops on exit.
_CURRENT_CONTEXT: list[Context] = []
# Module-level logger, named after this module.
log = logging.getLogger(__name__)


if TYPE_CHECKING:
    # Imports needed only for type annotations.
    from airflow.models.abstractoperator import TaskStateChangeCallback
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.dag import DAG, DagModel
    from airflow.models.dagrun import DagRun
    from airflow.models.dataset import DatasetEvent
    from airflow.models.operator import Operator
    from airflow.utils.task_group import TaskGroup

    # This is a workaround because mypy doesn't work with hybrid_property
    # TODO: remove this hack and move hybrid_property back to main import block
    # See https://github.com/python/mypy/issues/4430
    hybrid_property = property
else:
    # At runtime the real SQLAlchemy hybrid_property is used.
    from sqlalchemy.ext.hybrid import hybrid_property


PAST_DEPENDS_MET = "past_depends_met"

155 

156 

class TaskReturnCode(Enum):
    """
    Enum to signal manner of exit for task run command.

    Currently has a single member, ``DEFERRED``, with value 100.

    :meta private:
    """

    DEFERRED = 100
    """When task exits with deferral to trigger."""

166 

167 

@contextlib.contextmanager
def set_current_context(context: Context) -> Generator[Context, None, None]:
    """
    Make ``context`` the current execution context for the duration of the ``with`` block.

    The context is pushed onto the module-level context stack on entry and the top of the
    stack is popped on exit. This should be entered once per Task execution, before calling
    operator.execute. A warning is logged when the popped entry is not equal to ``context``.
    """
    _CURRENT_CONTEXT.append(context)
    try:
        yield context
    finally:
        # Whatever sits on top of the stack is removed, even if it is not what we pushed.
        popped = _CURRENT_CONTEXT.pop()
        if popped != context:
            log.warning(
                "Current context is not equal to the state at context stack. Expected=%s, got=%s",
                context,
                popped,
            )

185 

186 

def stop_all_tasks_in_dag(tis: list[TaskInstance], session: Session, task_id_to_ignore: str) -> None:
    """
    Stop every task instance in ``tis`` except the ignored one and those already finished.

    Instances whose ``task_id`` equals ``task_id_to_ignore``, or whose state is SUCCESS or
    FAILED, are left untouched. RUNNING instances are forced to fail through ``ti.error``;
    every other instance is set to SKIPPED.

    :param tis: task instances to stop
    :param session: SQLAlchemy ORM Session passed on to ``ti.error`` / ``ti.set_state``
    :param task_id_to_ignore: ``task_id`` of the task instance that must not be touched
        (compared against ``ti.task_id``, which is a string)
    """
    # Built once: these states mean the task instance needs no further action here.
    finished_states = (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED)
    for ti in tis:
        if ti.task_id == task_id_to_ignore or ti.state in finished_states:
            continue
        if ti.state == TaskInstanceState.RUNNING:
            log.info("Forcing task %s to fail", ti.task_id)
            ti.error(session)
        else:
            log.info("Setting task %s to SKIPPED", ti.task_id)
            ti.set_state(state=TaskInstanceState.SKIPPED, session=session)

200 

201 

def clear_task_instances(
    tis: list[TaskInstance],
    session: Session,
    activate_dag_runs: None = None,
    dag: DAG | None = None,
    dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
) -> None:
    """
    Clears a set of task instances, but makes sure the running ones
    get killed. Also sets Dagrun's `state` to QUEUED and `start_date`
    to the time of execution. But only for finished DRs (SUCCESS and FAILED).
    Doesn't clear DR's `state` and `start_date` for running
    DRs (QUEUED and RUNNING) because clearing the state for already
    running DR is redundant and clearing `start_date` affects DR's duration.

    TaskReschedule rows matching the given task instances are deleted as well,
    and the session is flushed at the end.

    :param tis: a list of task instances
    :param session: current session
    :param dag_run_state: state to set finished DagRuns to.
        If set to False, DagRuns state will not be changed.
    :param dag: DAG object
    :param activate_dag_runs: Deprecated parameter, do not pass
    """
    # Job ids of running task instances; the matching jobs are flagged RESTARTING further down.
    job_ids = []
    # Keys: dag_id -> run_id -> map_indexes -> try_numbers -> task_id
    task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
    )
    # Used to look up DAGs that differ from the ``dag`` argument (or when none was given).
    dag_bag = DagBag(read_dags_from_db=True)
    for ti in tis:
        if ti.state == TaskInstanceState.RUNNING:
            # A RUNNING task instance without a job_id is left unchanged by this branch.
            if ti.job_id:
                # If a task is cleared when running, set its state to RESTARTING so that
                # the task is terminated and becomes eligible for retry.
                ti.state = TaskInstanceState.RESTARTING
                job_ids.append(ti.job_id)
        else:
            ti_dag = dag if dag and dag.dag_id == ti.dag_id else dag_bag.get_dag(ti.dag_id, session=session)
            task_id = ti.task_id
            if ti_dag and ti_dag.has_task(task_id):
                task = ti_dag.get_task(task_id)
                ti.refresh_from_task(task)
                task_retries = task.retries
                ti.max_tries = ti.try_number + task_retries - 1
            else:
                # Ignore errors when updating max_tries if the DAG or
                # task are not found since database records could be
                # outdated. We make max_tries the maximum value of its
                # original max_tries or the last attempted try number.
                ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries)
            ti.state = None
            ti.external_executor_id = None
            ti.clear_next_method_args()
            session.merge(ti)

        # Recorded for every task instance, running or not, to build the reschedule delete below.
        task_id_by_key[ti.dag_id][ti.run_id][ti.map_index][ti.try_number].add(ti.task_id)

    if task_id_by_key:
        # Clear all reschedules related to the ti to clear

        # This is an optimization for the common case where all tis are for a small number
        # of dag_id, run_id, try_number, and map_index. Use a nested dict of dag_id,
        # run_id, try_number, map_index, and task_id to construct the where clause in a
        # hierarchical manner. This speeds up the delete statement by more than 40x for
        # large number of tis (50k+).
        conditions = or_(
            and_(
                TR.dag_id == dag_id,
                or_(
                    and_(
                        TR.run_id == run_id,
                        or_(
                            and_(
                                TR.map_index == map_index,
                                or_(
                                    and_(TR.try_number == try_number, TR.task_id.in_(task_ids))
                                    for try_number, task_ids in task_tries.items()
                                ),
                            )
                            for map_index, task_tries in map_indexes.items()
                        ),
                    )
                    for run_id, map_indexes in run_ids.items()
                ),
            )
            for dag_id, run_ids in task_id_by_key.items()
        )

        delete_qry = TR.__table__.delete().where(conditions)
        session.execute(delete_qry)

    if job_ids:
        from airflow.jobs.job import Job

        # Flag the jobs of the running task instances collected above as RESTARTING.
        for job in session.query(Job).filter(Job.id.in_(job_ids)).all():
            job.state = TaskInstanceState.RESTARTING

    if activate_dag_runs is not None:
        warnings.warn(
            "`activate_dag_runs` parameter to clear_task_instances function is deprecated. "
            "Please use `dag_run_state`",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        # A falsy (but not None) value of the deprecated flag disables the DagRun state change.
        if not activate_dag_runs:
            dag_run_state = False

    if dag_run_state is not False and tis:
        from airflow.models.dagrun import DagRun  # Avoid circular import

        run_ids_by_dag_id = defaultdict(set)
        for instance in tis:
            run_ids_by_dag_id[instance.dag_id].add(instance.run_id)

        # Load every DagRun that owns at least one of the given task instances.
        drs = (
            session.query(DagRun)
            .filter(
                or_(
                    and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
                    for dag_id, run_ids in run_ids_by_dag_id.items()
                )
            )
            .all()
        )
        dag_run_state = DagRunState(dag_run_state)  # Validate the state value.
        for dr in drs:
            # Only finished DagRuns are touched.
            if dr.state in State.finished_dr_states:
                dr.state = dag_run_state
                dr.start_date = timezone.utcnow()
                if dag_run_state == DagRunState.QUEUED:
                    # For QUEUED the start_date set just above is reset to None again.
                    dr.last_scheduling_decision = None
                    dr.start_date = None
    session.flush()

334 

335 

336def _is_mappable_value(value: Any) -> TypeGuard[Collection]: 

337 """Whether a value can be used for task mapping. 

338 

339 We only allow collections with guaranteed ordering, but exclude character 

340 sequences since that's usually not what users would expect to be mappable. 

341 """ 

342 if not isinstance(value, (collections.abc.Sequence, dict)): 

343 return False 

344 if isinstance(value, (bytearray, bytes, str)): 

345 return False 

346 return True 

347 

348 

def _creator_note(val):
    """Build the ``TaskInstanceNote`` for the ``note`` association proxy from ``val``.

    A string becomes the note's ``content``, a dict is expanded into keyword
    arguments, and any other value is unpacked into positional arguments.
    """
    if isinstance(val, str):
        return TaskInstanceNote(content=val)
    if isinstance(val, dict):
        return TaskInstanceNote(**val)
    return TaskInstanceNote(*val)

357 

358 

class TaskInstance(Base, LoggingMixin):
    """
    Task instances store the state of a task instance. This table is the
    authority and single source of truth around what tasks have run and the
    state they are in.

    The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or
    dag model deliberately to have more control over transactions.

    Database transactions on this table should guard against double triggers
    and any confusion around what task instances are or aren't ready to run
    even while multiple schedulers may be firing task instances.

    A value of -1 in map_index represents any of: a TI without mapped tasks;
    a TI with mapped tasks that has yet to be expanded (state=pending);
    a TI with mapped tasks that expanded to an empty list (state=skipped).
    """

    __tablename__ = "task_instance"
    # Composite primary key: (dag_id, task_id, run_id, map_index); see PrimaryKeyConstraint below.
    task_id = Column(StringID(), primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    run_id = Column(StringID(), primary_key=True, nullable=False)
    map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))

    start_date = Column(UtcDateTime)
    end_date = Column(UtcDateTime)
    duration = Column(Float)
    state = Column(String(20))
    # Stored in the "try_number" DB column under the private attribute name ``_try_number``.
    _try_number = Column("try_number", Integer, default=0)
    max_tries = Column(Integer, server_default=text("-1"))
    hostname = Column(String(1000))
    unixname = Column(String(1000))
    job_id = Column(Integer)
    pool = Column(String(256), nullable=False)
    pool_slots = Column(Integer, default=1, nullable=False)
    queue = Column(String(256))
    priority_weight = Column(Integer)
    operator = Column(String(1000))
    queued_dttm = Column(UtcDateTime)
    queued_by_job_id = Column(Integer)
    pid = Column(Integer)
    executor_config = Column(ExecutorConfigType(pickler=dill))
    # Set on insert and refreshed on every update via timezone.utcnow.
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)

    external_executor_id = Column(StringID())

    # The trigger to resume on if we are in state DEFERRED
    trigger_id = Column(Integer)

    # Optional timeout datetime for the trigger (past this, we'll fail)
    trigger_timeout = Column(DateTime)
    # The trigger_timeout should be TIMESTAMP(using UtcDateTime) but for ease of
    # migration, we are keeping it as DateTime pending a change where expensive
    # migration is inevitable.

    # The method to call next, and any extra arguments to pass to it.
    # Usually used when resuming from DEFERRED.
    next_method = Column(String(1000))
    next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))

    # If adding new fields here then remember to add them to
    # refresh_from_db() or they won't display in the UI correctly

    __table_args__ = (
        Index("ti_dag_state", dag_id, state),
        Index("ti_dag_run", dag_id, run_id),
        Index("ti_state", state),
        Index("ti_state_lkp", dag_id, task_id, run_id, state),
        # The below index has been added to improve performance on postgres setups with tens of millions of
        # taskinstance rows. Aim is to improve the below query (it can be used to find the last successful
        # execution date of a task instance):
        #    SELECT start_date FROM task_instance WHERE dag_id = 'xx' AND task_id = 'yy' AND state = 'success'
        #    ORDER BY start_date DESC NULLS LAST LIMIT 1;
        # Existing "ti_state_lkp" is not enough for such query when this table has millions of rows, since
        # rows have to be fetched in order to retrieve the start_date column. With this index, INDEX ONLY SCAN
        # is performed and that query runs within milliseconds.
        Index("ti_state_incl_start_date", dag_id, task_id, state, postgresql_include=["start_date"]),
        Index("ti_pool", pool, state, priority_weight),
        Index("ti_job_id", job_id),
        Index("ti_trigger_id", trigger_id),
        PrimaryKeyConstraint(
            "dag_id", "task_id", "run_id", "map_index", name="task_instance_pkey", mssql_clustered=True
        ),
        # Deleting the referenced trigger row cascades to this table.
        ForeignKeyConstraint(
            [trigger_id],
            ["trigger.id"],
            name="task_instance_trigger_id_fkey",
            ondelete="CASCADE",
        ),
        # Deleting the owning dag_run row cascades to this table.
        ForeignKeyConstraint(
            [dag_id, run_id],
            ["dag_run.dag_id", "dag_run.run_id"],
            name="task_instance_dag_run_fkey",
            ondelete="CASCADE",
        ),
    )

    # View-only relationship to DagModel, joined on dag_id (no DB-level foreign key is declared for it).
    dag_model = relationship(
        "DagModel",
        primaryjoin="TaskInstance.dag_id == DagModel.dag_id",
        foreign_keys=dag_id,
        uselist=False,
        innerjoin=True,
        viewonly=True,
    )

    trigger = relationship("Trigger", uselist=False, back_populates="task_instance")
    triggerer_job = association_proxy("trigger", "triggerer_job")
    # The DagRun is loaded together with the task instance (lazy="joined", inner join).
    dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True)
    rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False)
    # execution_date is proxied from the related DagRun rather than stored on this table.
    execution_date = association_proxy("dag_run", "execution_date")
    task_instance_note = relationship(
        "TaskInstanceNote",
        back_populates="task_instance",
        uselist=False,
        cascade="all, delete, delete-orphan",
    )
    # ``note`` proxies the note's ``content``; new notes are built by _creator_note.
    note = association_proxy("task_instance_note", "content", creator=_creator_note)
    task: Operator  # Not always set...

    is_trigger_log_context: bool = False
    """Indicate to FileTaskHandler that logging context should be set up for trigger logging.

    :meta private:
    """

484 

def __init__(
    self,
    task: Operator,
    execution_date: datetime | None = None,
    run_id: str | None = None,
    state: str | None = None,
    map_index: int = -1,
):
    """
    Create a task instance for ``task``.

    :param task: the operator this instance belongs to; ``dag_id``, ``task_id`` and the
        attributes copied by ``refresh_from_task`` are taken from it
    :param execution_date: deprecated; only used when ``run_id`` is not given, in which
        case the ``run_id`` is looked up from the DagRun with this execution date
    :param run_id: the run_id of the DagRun this instance belongs to
    :param state: initial state; only applied when truthy
    :param map_index: map index of this instance, ``-1`` by default
    :raises DagRunNotFound: if ``run_id`` has to be looked up from ``execution_date``
        and no matching DagRun row is found
    """
    super().__init__()
    self.dag_id = task.dag_id
    self.task_id = task.task_id
    self.map_index = map_index
    self.refresh_from_task(task)
    # init_on_load will config the log
    self.init_on_load()

    if run_id is None and execution_date is not None:
        from airflow.models.dagrun import DagRun  # Avoid circular import

        warnings.warn(
            "Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id",
            RemovedInAirflow3Warning,
            # Stack level is 4 because SQLA adds some wrappers around the constructor
            stacklevel=4,
        )
        # make sure we have a localized execution_date stored in UTC
        if execution_date and not timezone.is_localized(execution_date):
            self.log.warning(
                "execution date %s has no timezone information. Using default from dag or system",
                execution_date,
            )
            if self.task.has_dag():
                if TYPE_CHECKING:
                    assert self.task.dag
                execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
            else:
                execution_date = timezone.make_aware(execution_date)

            execution_date = timezone.convert_to_utc(execution_date)
        # Resolve the run_id from the DagRun that has this dag_id and execution_date.
        with create_session() as session:
            run_id = (
                session.query(DagRun.run_id)
                .filter_by(dag_id=self.dag_id, execution_date=execution_date)
                .scalar()
            )
            if run_id is None:
                raise DagRunNotFound(
                    f"DagRun for {self.dag_id!r} with date {execution_date} not found"
                ) from None

    self.run_id = run_id

    self.try_number = 0
    self.max_tries = self.task.retries
    self.unixname = getuser()
    # A falsy ``state`` (None or "") leaves the state attribute unset here.
    if state:
        self.state = state
    self.hostname = ""
    # Is this TaskInstance being currently running within `airflow tasks run --raw`.
    # Not persisted to the database so only valid for the current process
    self.raw = False
    # can be changed when calling 'run'
    self.test_mode = False

548 

@property
def stats_tags(self) -> dict[str, str]:
    """Return this instance's ``dag_id`` and ``task_id`` as a tag dict, passed through ``prune_dict``."""
    tags = {"dag_id": self.dag_id, "task_id": self.task_id}
    return prune_dict(tags)

552 

@staticmethod
def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any]:
    """Build the column-name -> value dict for a new task instance row.

    Values are taken from ``task``, ``run_id`` and ``map_index``; the try number
    starts at 0, the hostname is empty and the unix name comes from ``getuser()``.

    :meta private:
    """
    return dict(
        dag_id=task.dag_id,
        task_id=task.task_id,
        run_id=run_id,
        _try_number=0,
        hostname="",
        unixname=getuser(),
        queue=task.queue,
        pool=task.pool,
        pool_slots=task.pool_slots,
        priority_weight=task.priority_weight_total,
        run_as_user=task.run_as_user,
        max_tries=task.retries,
        executor_config=task.executor_config,
        operator=task.task_type,
        map_index=map_index,
    )

576 

@reconstructor
def init_on_load(self) -> None:
    """Set up the attributes that are not stored in the DB."""
    # can be changed when calling 'run'
    self.test_mode = False
    # make the ti log through the "airflow.task" logger
    self._log = logging.getLogger("airflow.task")

583 

@hybrid_property
def try_number(self):
    """
    Return the try number this task instance will have when it is actually run.

    While the TaskInstance is RUNNING this equals the stored ``_try_number``;
    in every other state it is the stored value plus one.
    """
    # This is designed so that task logs end up in the right file.
    currently_running = self.state == State.RUNNING
    return self._try_number if currently_running else self._try_number + 1

@try_number.setter
def try_number(self, value: int) -> None:
    """Store ``value`` as the raw ``_try_number``, without any adjustment."""
    self._try_number = value

601 

@property
def prev_attempted_tries(self) -> int:
    """
    Return the number of tries already attempted, i.e. the raw stored ``_try_number``.

    Unlike ``try_number`` this value is never incremented for non-running tasks.
    """
    # Exposed for the Task Tries and Gantt graph views, where `try_number`
    # would throw off the counts for non-running tasks. Also handy in error
    # logging contexts to get the try number of the last attempted try.
    # https://issues.apache.org/jira/browse/AIRFLOW-2143
    attempted = self._try_number
    return attempted

615 

@property
def next_try_number(self) -> int:
    """Return the stored ``_try_number`` plus one."""
    following = self._try_number + 1
    return following

619 

def command_as_list(
    self,
    mark_success=False,
    ignore_all_deps=False,
    ignore_task_deps=False,
    ignore_depends_on_past=False,
    wait_for_past_depends_before_skipping=False,
    ignore_ti_state=False,
    local=False,
    pickle_id: int | None = None,
    raw=False,
    job_id=None,
    pool=None,
    cfg_path=None,
) -> list[str]:
    """
    Build the command, as a list of arguments, that runs this task instance.

    The command can be executed anywhere airflow is installed and is part of
    the message sent to executors by the orchestrator. All flags are passed
    through to ``TaskInstance.generate_command``; the DAG file path is only
    included when no ``pickle_id`` is given and a DAG object is available.
    """
    dag: DAG | DagModel
    # Prefer the DAG attached to the task; otherwise use the ORM dag_model, which might not be loaded
    if hasattr(self, "task") and hasattr(self.task, "dag") and self.task.dag is not None:
        dag = self.task.dag
    else:
        dag = self.dag_model

    path: PurePath | None = None
    if not pickle_id and dag:
        if dag.is_subdag:
            if TYPE_CHECKING:
                assert dag.parent_dag is not None
            # A subdag is located through the file of its parent DAG.
            path = dag.parent_dag.relative_fileloc
        else:
            path = dag.relative_fileloc

        # Relative locations are expressed relative to the DAGS_FOLDER placeholder.
        if path and not path.is_absolute():
            path = "DAGS_FOLDER" / path

    return TaskInstance.generate_command(
        self.dag_id,
        self.task_id,
        run_id=self.run_id,
        mark_success=mark_success,
        ignore_all_deps=ignore_all_deps,
        ignore_task_deps=ignore_task_deps,
        ignore_depends_on_past=ignore_depends_on_past,
        wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
        ignore_ti_state=ignore_ti_state,
        local=local,
        pickle_id=pickle_id,
        file_path=path,
        raw=raw,
        job_id=job_id,
        pool=pool,
        cfg_path=cfg_path,
        map_index=self.map_index,
    )

680 

681 @staticmethod 

682 def generate_command( 

683 dag_id: str, 

684 task_id: str, 

685 run_id: str, 

686 mark_success: bool = False, 

687 ignore_all_deps: bool = False, 

688 ignore_depends_on_past: bool = False, 

689 wait_for_past_depends_before_skipping: bool = False, 

690 ignore_task_deps: bool = False, 

691 ignore_ti_state: bool = False, 

692 local: bool = False, 

693 pickle_id: int | None = None, 

694 file_path: PurePath | str | None = None, 

695 raw: bool = False, 

696 job_id: str | None = None, 

697 pool: str | None = None, 

698 cfg_path: str | None = None, 

699 map_index: int = -1, 

700 ) -> list[str]: 

701 """ 

702 Generates the shell command required to execute this task instance. 

703 

704 :param dag_id: DAG ID 

705 :param task_id: Task ID 

706 :param run_id: The run_id of this task's DagRun 

707 :param mark_success: Whether to mark the task as successful 

708 :param ignore_all_deps: Ignore all ignorable dependencies. 

709 Overrides the other ignore_* parameters. 

710 :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs 

711 (e.g. for Backfills) 

712 :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped 

713 :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past 

714 and trigger rule 

715 :param ignore_ti_state: Ignore the task instance's previous failure/success 

716 :param local: Whether to run the task locally 

717 :param pickle_id: If the DAG was serialized to the DB, the ID 

718 associated with the pickled DAG 

719 :param file_path: path to the file containing the DAG definition 

720 :param raw: raw mode (needs more details) 

721 :param job_id: job ID (needs more details) 

722 :param pool: the Airflow pool that the task should run in 

723 :param cfg_path: the Path to the configuration file 

724 :return: shell command that can be used to run the task instance 

725 """ 

726 cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id] 

727 if mark_success: 

728 cmd.extend(["--mark-success"]) 

729 if pickle_id: 

730 cmd.extend(["--pickle", str(pickle_id)]) 

731 if job_id: 

732 cmd.extend(["--job-id", str(job_id)]) 

733 if ignore_all_deps: 

734 cmd.extend(["--ignore-all-dependencies"]) 

735 if ignore_task_deps: 

736 cmd.extend(["--ignore-dependencies"]) 

737 if ignore_depends_on_past: 

738 cmd.extend(["--depends-on-past", "ignore"]) 

739 elif wait_for_past_depends_before_skipping: 

740 cmd.extend(["--depends-on-past", "wait"]) 

741 if ignore_ti_state: 

742 cmd.extend(["--force"]) 

743 if local: 

744 cmd.extend(["--local"]) 

745 if pool: 

746 cmd.extend(["--pool", pool]) 

747 if raw: 

748 cmd.extend(["--raw"]) 

749 if file_path: 

750 cmd.extend(["--subdir", os.fspath(file_path)]) 

751 if cfg_path: 

752 cmd.extend(["--cfg-path", cfg_path]) 

753 if map_index != -1: 

754 cmd.extend(["--map-index", str(map_index)]) 

755 return cmd 

756 

@property
def log_url(self) -> str:
    """Return the webserver URL of this TaskInstance's log page.

    The URL is built from the ``[webserver] BASE_URL`` setting; the execution
    date is ISO-formatted and URL-quoted.
    """
    iso = quote(self.execution_date.isoformat())
    base_url = conf.get_mandatory_value("webserver", "BASE_URL")
    url = (
        f"{base_url}/log"
        f"?execution_date={iso}"
        f"&task_id={self.task_id}"
        f"&dag_id={self.dag_id}"
        f"&map_index={self.map_index}"
    )
    return url

769 

@property
def mark_success_url(self) -> str:
    """Return the webserver URL that marks this TI as success.

    The URL is built from the ``[webserver] BASE_URL`` setting; the run_id is
    URL-quoted, and upstream/downstream propagation is switched off.
    """
    base_url = conf.get_mandatory_value("webserver", "BASE_URL")
    url = (
        f"{base_url}/confirm"
        f"?task_id={self.task_id}"
        f"&dag_id={self.dag_id}"
        f"&dag_run_id={quote(self.run_id)}"
        "&upstream=false"
        "&downstream=false"
        "&state=success"
    )
    return url

783 

@provide_session
def current_state(self, session: Session = NEW_SESSION) -> str:
    """
    Fetch the very latest state of this task instance from the database.

    When a session is passed, the lookup becomes part of that session;
    otherwise a new session is used.

    The primary key columns are obtained through sqlalchemy.inspect, so the
    lookup keeps working if the primary key definition changes.

    :param session: SQLAlchemy ORM Session
    """
    # One equality filter per primary key column, matched against this instance's values.
    pk_filters = [col == getattr(self, col.name) for col in inspect(TaskInstance).primary_key]
    state_query = session.query(TaskInstance.state).filter(*pk_filters)
    return state_query.scalar()

798 

@provide_session
def error(self, session: Session = NEW_SESSION) -> None:
    """
    Forces the task instance's state to FAILED in the database.

    The change is merged into ``session`` and committed immediately.

    :param session: SQLAlchemy ORM Session
    """
    self.log.error("Recording the task instance as FAILED")
    self.state = State.FAILED
    # merge() copies this object's state into the session; commit() then persists it.
    session.merge(self)
    session.commit()

810 

@provide_session
def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
    """
    Refreshes the task instance from the database based on the primary key.

    If no matching row is found, only ``state`` is reset to ``None``; all
    other attributes are left as they are.

    :param session: SQLAlchemy ORM Session
    :param lock_for_update: if True, indicates that the database should
        lock the TaskInstance (issuing a FOR UPDATE clause) until the
        session is committed.
    """
    self.log.debug("Refreshing TaskInstance %s from DB", self)

    # When this object is already attached to the session, reload its column attributes first.
    if self in session:
        session.refresh(self, TaskInstance.__mapper__.column_attrs.keys())

    qry = (
        # To avoid joining any relationships, by default select all
        # columns, not the object. This also means we get (effectively) a
        # namedtuple back, not a TI object
        session.query(*TaskInstance.__table__.columns).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == self.run_id,
            TaskInstance.map_index == self.map_index,
        )
    )

    if lock_for_update:
        # The locking query is run through the DB retry helper.
        for attempt in run_with_db_retries(logger=self.log):
            with attempt:
                ti: TaskInstance | None = qry.with_for_update().one_or_none()
    else:
        ti = qry.one_or_none()
    if ti:
        # Fields ordered per model definition
        # NOTE(review): the model's updated_at and trigger_timeout columns are not copied
        # here - confirm whether that is intentional.
        self.start_date = ti.start_date
        self.end_date = ti.end_date
        self.duration = ti.duration
        self.state = ti.state
        # Since we selected columns, not the object, this is the raw value
        self.try_number = ti.try_number
        self.max_tries = ti.max_tries
        self.hostname = ti.hostname
        self.unixname = ti.unixname
        self.job_id = ti.job_id
        self.pool = ti.pool
        # A falsy pool_slots value (None or 0) from the row is replaced by 1.
        self.pool_slots = ti.pool_slots or 1
        self.queue = ti.queue
        self.priority_weight = ti.priority_weight
        self.operator = ti.operator
        self.queued_dttm = ti.queued_dttm
        self.queued_by_job_id = ti.queued_by_job_id
        self.pid = ti.pid
        self.executor_config = ti.executor_config
        self.external_executor_id = ti.external_executor_id
        self.trigger_id = ti.trigger_id
        self.next_method = ti.next_method
        self.next_kwargs = ti.next_kwargs
    else:
        self.state = None

871 

def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
    """
    Attach ``task`` to this instance and mirror its common attributes.

    :param task: The task object to copy from
    :param pool_override: Pool used in place of the task's pool when truthy
    """
    self.task = task
    self.queue = task.queue
    self.pool = pool_override if pool_override else task.pool

    # max_tries is deliberately not derived from task.retries here: it is a
    # cumulative value that needs to be stored in the db.
    mirrored = (
        ("pool_slots", "pool_slots"),
        ("priority_weight", "priority_weight_total"),
        ("run_as_user", "run_as_user"),
        ("executor_config", "executor_config"),
        ("operator", "task_type"),
    )
    for ti_attr, task_attr in mirrored:
        setattr(self, ti_attr, getattr(task, task_attr))

889 

@provide_session
def clear_xcom_data(self, session: Session = NEW_SESSION) -> None:
    """Remove this task instance's XCom data from the database.

    For an unmapped task every XCom matching this task ID in the same DAG
    run is removed; for a mapped task only the one with the matching map
    index is removed.

    :param session: SQLAlchemy ORM Session
    """
    self.log.debug("Clearing XCom data")
    # A negative map index is passed on as ``None`` (the unmapped case).
    target_map_index: int | None = self.map_index if self.map_index >= 0 else None
    XCom.clear(
        dag_id=self.dag_id,
        task_id=self.task_id,
        run_id=self.run_id,
        map_index=target_map_index,
        session=session,
    )

912 

@property
def key(self) -> TaskInstanceKey:
    """Return the ``TaskInstanceKey`` that uniquely identifies this task instance."""
    identity = (self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
    return TaskInstanceKey(*identity)

917 

@provide_session
def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
    """
    Move the TaskInstance to ``state`` and merge it into the session.

    ``start_date`` is filled in when missing. For finished states and
    UP_FOR_RETRY, ``end_date`` is filled in when missing and ``duration``
    is recomputed from the two dates.

    :param state: State to set for the TI
    :param session: SQLAlchemy ORM Session
    :return: Was the state changed
    """
    if self.state == state:
        return False

    now = timezone.utcnow()
    self.log.debug("Setting task state for %s to %s", self, state)
    self.state = state
    self.start_date = self.start_date or now

    reached_end = self.state in State.finished or self.state == State.UP_FOR_RETRY
    if reached_end:
        self.end_date = self.end_date or now
        elapsed = self.end_date - self.start_date
        self.duration = elapsed.total_seconds()

    session.merge(self)
    return True

939 

@property
def is_premature(self) -> bool:
    """
    Return whether the task is in UP_FOR_RETRY state but is not yet ready
    to be retried, i.e. it is still in its retry waiting period.
    """
    if self.state != State.UP_FOR_RETRY:
        return False
    return not self.ready_for_retry()

948 

@provide_session
def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
    """
    Check whether the immediate dependents of this task instance have succeeded or have been skipped.

    This is meant to be used by wait_for_downstream. It is useful when the
    next schedule of a task must not start processing until the dependents
    are done, for instance if the task DROPs and recreates a table.

    :param session: SQLAlchemy ORM Session
    :return: True when there are no downstream tasks, or when the number of
        SUCCESS/SKIPPED downstream task instances in this run equals the
        number of downstream task ids.
    """
    downstream_ids = self.task.downstream_task_ids
    if not downstream_ids:
        return True

    done_query = session.query(func.count(TaskInstance.task_id)).filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.task_id.in_(downstream_ids),
        TaskInstance.run_id == self.run_id,
        TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]),
    )
    # First column of the first (only) row holds the count.
    done_count = done_query[0][0]
    return done_count == len(downstream_ids)

974 

@provide_session
def get_previous_dagrun(
    self,
    state: DagRunState | None = None,
    session: Session | None = None,
) -> DagRun | None:
    """Return the DagRun that ran before this task instance's DagRun.

    :param state: If passed, it only take into account instances of a specific state.
    :param session: SQLAlchemy ORM Session.
    :return: The previous DagRun, or None when the task has no DAG or no
        previous run was found.
    """
    dag = self.task.dag
    if dag is None:
        return None

    current_run = self.get_dagrun(session=session)
    current_run.dag = dag

    # The schedule is only honoured when no `state` filter is given and the
    # DAG can actually be scheduled. For legacy reasons, with `catchup=True`
    # the previous *scheduled* run is used in that case; otherwise the plain
    # previous run (optionally filtered by state) is looked up.
    honour_schedule = state is None and dag.timetable.can_be_scheduled
    if dag.catchup is True and honour_schedule:
        previous_run = current_run.get_previous_scheduled_dagrun(session=session)
    else:
        previous_run = current_run.get_previous_dagrun(session=session, state=state)

    return previous_run if previous_run else None

1007 

@provide_session
def get_previous_ti(
    self,
    state: DagRunState | None = None,
    session: Session = NEW_SESSION,
) -> TaskInstance | None:
    """
    Return the task instance of this task in the previous DagRun.

    :param state: If passed, it only take into account instances of a specific state.
    :param session: SQLAlchemy ORM Session
    :return: None when there is no previous DagRun, otherwise whatever that
        run's ``get_task_instance`` returns for this task id.
    """
    previous_run = self.get_previous_dagrun(state, session=session)
    if previous_run is not None:
        return previous_run.get_task_instance(self.task_id, session=session)
    return None

1024 

@property
def previous_ti(self) -> TaskInstance | None:
    """
    Deprecated attribute; emits a ``RemovedInAirflow3Warning``.

    Please use the `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
    """
    # NOTE(review): the whitespace inside this message is reconstructed on the
    # assumption that the original triple-quoted literal sat at a 12-space
    # indent -- confirm against the real file.
    deprecation_message = (
        "\n"
        "            This attribute is deprecated.\n"
        "            Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.\n"
        "            "
    )
    warnings.warn(deprecation_message, RemovedInAirflow3Warning, stacklevel=2)
    return self.get_previous_ti()

1040 

@property
def previous_ti_success(self) -> TaskInstance | None:
    """
    Deprecated attribute; emits a ``RemovedInAirflow3Warning``.

    Please use the `airflow.models.taskinstance.TaskInstance.get_previous_ti` method
    (this property calls it with ``state=DagRunState.SUCCESS``).
    """
    # NOTE(review): the whitespace inside this message is reconstructed on the
    # assumption that the original triple-quoted literal sat at a 12-space
    # indent -- confirm against the real file.
    deprecation_message = (
        "\n"
        "            This attribute is deprecated.\n"
        "            Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.\n"
        "            "
    )
    warnings.warn(deprecation_message, RemovedInAirflow3Warning, stacklevel=2)
    return self.get_previous_ti(state=DagRunState.SUCCESS)

1056 

@provide_session
def get_previous_execution_date(
    self,
    state: DagRunState | None = None,
    session: Session = NEW_SESSION,
) -> pendulum.DateTime | None:
    """
    Return the execution date of the previous task instance (see ``get_previous_ti``).

    :param state: If passed, it only take into account instances of a specific state.
    :param session: SQLAlchemy ORM Session
    :return: The execution date as a pendulum instance, or the falsy
        ``get_previous_ti`` result when there is no previous task instance.
    """
    self.log.debug("previous_execution_date was called")
    previous = self.get_previous_ti(state=state, session=session)
    if not previous:
        return previous
    return pendulum.instance(previous.execution_date)

1072 

@provide_session
def get_previous_start_date(
    self, state: DagRunState | None = None, session: Session = NEW_SESSION
) -> pendulum.DateTime | None:
    """
    Return the start date of the previous task instance (see ``get_previous_ti``).

    :param state: If passed, it only take into account instances of a specific state.
    :param session: SQLAlchemy ORM Session
    :return: The start date as a pendulum instance; a falsy value when
        there is no previous task instance or it has no start date.
    """
    self.log.debug("previous_start_date was called")
    previous = self.get_previous_ti(state=state, session=session)
    # The previous TI may not exist ...
    if not previous:
        return previous
    # ... and even when it does, its start_date may be None.
    started = previous.start_date
    if not started:
        return started
    return pendulum.instance(previous.start_date)

1087 

@property
def previous_start_date_success(self) -> pendulum.DateTime | None:
    """
    Deprecated attribute; emits a ``RemovedInAirflow3Warning``.

    Please use the `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method
    (this property calls it with ``state=DagRunState.SUCCESS``).
    """
    # NOTE(review): the whitespace inside this message is reconstructed on the
    # assumption that the original triple-quoted literal sat at a 12-space
    # indent -- confirm against the real file.
    deprecation_message = (
        "\n"
        "            This attribute is deprecated.\n"
        "            Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method.\n"
        "            "
    )
    warnings.warn(deprecation_message, RemovedInAirflow3Warning, stacklevel=2)
    return self.get_previous_start_date(state=DagRunState.SUCCESS)

1103 

@provide_session
def are_dependencies_met(
    self, dep_context: DepContext | None = None, session: Session = NEW_SESSION, verbose: bool = False
) -> bool:
    """
    Return whether all the conditions are met for this task instance to be run
    given the context for the dependencies (e.g. a task instance being force run from
    the UI will ignore some dependencies).

    Every failed dependency is logged; a summary line is logged when none failed.

    :param dep_context: The execution context that determines the dependencies that
        should be evaluated.
    :param session: database session
    :param verbose: log at info level when True, at debug level otherwise
    """
    dep_context = dep_context or DepContext()
    emit = self.log.info if verbose else self.log.debug

    all_met = True
    # The failed statuses are consumed lazily so that each one is reported
    # as soon as it is produced.
    for status in self.get_failed_dep_statuses(dep_context=dep_context, session=session):
        all_met = False
        emit(
            "Dependencies not met for %s, dependency '%s' FAILED: %s",
            self,
            status.dep_name,
            status.reason,
        )

    if all_met:
        emit("Dependencies all met for dep_context=%s ti=%s", dep_context.description, self)
    return all_met

1137 

@provide_session
def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION):
    """
    Yield the status of every dependency that did not pass.

    The dependencies evaluated are the union of ``dep_context.deps`` and the
    task's own ``deps``. Every status, passed or not, is logged at debug level.

    :param dep_context: dependency context; a default ``DepContext()`` is used when falsy
    :param session: SQLAlchemy ORM Session
    """
    dep_context = dep_context or DepContext()
    combined_deps = dep_context.deps | self.task.deps
    for dep in combined_deps:
        for status in dep.get_dep_statuses(self, session, dep_context):
            self.log.debug(
                "%s dependency '%s' PASSED: %s, %s",
                self,
                status.dep_name,
                status.passed,
                status.reason,
            )
            if status.passed:
                continue
            yield status

1154 

def __repr__(self) -> str:
    """Return ``<TaskInstance: dag.task run [state]>``, with ``map_index=N`` included unless it is -1."""
    pieces = [f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} "]
    if self.map_index != -1:
        pieces.append(f"map_index={self.map_index} ")
    pieces.append(f"[{self.state}]>")
    return "".join(pieces)

1160 

def next_retry_datetime(self):
    """
    Get datetime of the next retry if the task instance fails.

    Without exponential backoff the result is ``end_date + task.retry_delay``.
    With exponential backoff, ``retry_delay`` is used as base and converted
    to seconds, jittered deterministically per task instance, and capped.
    """
    from airflow.models.abstractoperator import MAX_RETRY_DELAY

    task = self.task
    delay = task.retry_delay
    if not task.retry_exponential_backoff:
        return self.end_date + delay

    # Round up before converting to int: a value below 1 would otherwise
    # become 0 and cause a divide by zero in the modulo below. When
    # delay.total_seconds() is 0 the ceiling is still 0, hence the explicit
    # lower bound of 1 (the ceiling is retained to avoid a breaking change).
    scaled_seconds = delay.total_seconds() * (2 ** (self.try_number - 2))
    min_backoff = max(1, int(math.ceil(scaled_seconds)))

    # Deterministic per task instance and try number.
    hash_input = f"{self.dag_id}#{self.task_id}#{self.execution_date}#{self.try_number}"
    ti_hash = int(hashlib.sha1(hash_input.encode()).hexdigest(), 16)

    # min_backoff plus a hash-derived offset in [0, min_backoff).
    jittered_seconds = min_backoff + ti_hash % min_backoff

    # timedelta has a maximum representable value, which the exponentiation
    # can exceed after enough tries (around 50 if the initial delay is 1s,
    # even fewer if the delay is larger). Cap the value before building the
    # timedelta so the operation doesn't fail with "OverflowError".
    delay = timedelta(seconds=min(jittered_seconds, MAX_RETRY_DELAY))
    if task.max_retry_delay:
        delay = min(task.max_retry_delay, delay)
    return self.end_date + delay

1201 

def ready_for_retry(self) -> bool:
    """
    Check whether the task instance is in UP_FOR_RETRY state and its next
    retry datetime is already in the past.
    """
    if self.state != State.UP_FOR_RETRY:
        return False
    return self.next_retry_datetime() < timezone.utcnow()

1208 

@provide_session
def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
    """
    Return the DagRun for this TaskInstance.

    An already-loaded ``dag_run`` relationship is reused; otherwise the run
    is queried by (dag_id, run_id) and stored on the instance. In both cases
    the run's ``dag`` is set from ``self.task.dag`` when a task is attached.

    :param session: SQLAlchemy ORM Session
    :return: DagRun
    """
    if inspect(self).attrs.dag_run.loaded_value is not NO_VALUE:
        if hasattr(self, "task"):
            self.dag_run.dag = self.task.dag
        return self.dag_run

    from airflow.models.dagrun import DagRun  # Avoid circular import

    run_query = session.query(DagRun).filter(
        DagRun.dag_id == self.dag_id,
        DagRun.run_id == self.run_id,
    )
    dag_run = run_query.one()
    if hasattr(self, "task"):
        dag_run.dag = self.task.dag
    # Record it in the instance for next time. This means that `self.execution_date` will work correctly
    set_committed_value(self, "dag_run", dag_run)

    return dag_run

1232 

@provide_session
def check_and_change_state_before_execution(
    self,
    verbose: bool = True,
    ignore_all_deps: bool = False,
    ignore_depends_on_past: bool = False,
    wait_for_past_depends_before_skipping: bool = False,
    ignore_task_deps: bool = False,
    ignore_ti_state: bool = False,
    mark_success: bool = False,
    test_mode: bool = False,
    job_id: str | None = None,
    pool: str | None = None,
    external_executor_id: str | None = None,
    session: Session = NEW_SESSION,
) -> bool:
    """
    Checks dependencies and then sets state to RUNNING if they are met. Returns
    True if and only if state is set to RUNNING, which implies that task should be
    executed, in preparation for _run_raw_task.

    The dependency checks are skipped entirely when ``mark_success`` is True.
    Every path through this method ends with ``session.commit()``.

    :param verbose: whether to turn on more verbose logging
    :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs
    :param ignore_depends_on_past: Ignore depends_on_past DAG attribute
    :param wait_for_past_depends_before_skipping: Wait for past depends before mark the ti as skipped
    :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task
    :param ignore_ti_state: Disregards previous task instance state
    :param mark_success: Don't run the task, mark its state as success
    :param test_mode: Doesn't record success or failure in the DB
    :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID
    :param pool: specifies the pool to use to run the task instance
    :param external_executor_id: The identifier of the celery executor
    :param session: SQLAlchemy ORM Session
    :return: whether the state was changed to running or not
    """
    # NOTE(review): this listing had lost its indentation; the nesting below
    # is reconstructed -- confirm against the real file.
    task = self.task
    # Copy task-level attributes first, then overlay the row from the DB
    # (selected FOR UPDATE), then stamp the runner-specific fields.
    self.refresh_from_task(task, pool_override=pool)
    self.test_mode = test_mode
    self.refresh_from_db(session=session, lock_for_update=True)
    self.job_id = job_id
    self.hostname = get_hostname()
    self.pid = None

    # Only a metric is emitted here; the SUCCESS state itself is not acted upon.
    if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
        Stats.incr("previously_succeeded", tags=self.stats_tags)

    if not mark_success:
        # Firstly find non-runnable and non-requeueable tis.
        # Since mark_success is not set, we do nothing.
        non_requeueable_dep_context = DepContext(
            deps=RUNNING_DEPS - REQUEUEABLE_DEPS,
            ignore_all_deps=ignore_all_deps,
            ignore_ti_state=ignore_ti_state,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_task_deps=ignore_task_deps,
            description="non-requeueable deps",
        )
        # Note: these checks always log verbosely, independent of ``verbose``.
        if not self.are_dependencies_met(
            dep_context=non_requeueable_dep_context, session=session, verbose=True
        ):
            # State is left untouched; commit and bail out.
            session.commit()
            return False

        # For reporting purposes, we report based on 1-indexed,
        # not 0-indexed lists (i.e. Attempt 1 instead of
        # Attempt 0 for the first attempt).
        # Set the task start date. In case it was re-scheduled use the initial
        # start date that is recorded in task_reschedule table
        # If the task continues after being deferred (next_method is set), use the original start_date
        self.start_date = self.start_date if self.next_method else timezone.utcnow()
        if self.state == State.UP_FOR_RESCHEDULE:
            task_reschedule: TR = TR.query_for_task_instance(self, session=session).first()
            if task_reschedule:
                self.start_date = task_reschedule.start_date

        # Secondly we find non-runnable but requeueable tis. We reset its state.
        # This is because we might have hit concurrency limits,
        # e.g. because of backfilling.
        dep_context = DepContext(
            deps=REQUEUEABLE_DEPS,
            ignore_all_deps=ignore_all_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
            ignore_task_deps=ignore_task_deps,
            ignore_ti_state=ignore_ti_state,
            description="requeueable deps",
        )
        if not self.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
            # Reset to NONE and refresh queued_dttm, then persist.
            self.state = State.NONE
            self.log.warning(
                "Rescheduling due to concurrency limits reached "
                "at task runtime. Attempt %s of "
                "%s. State set to NONE.",
                self.try_number,
                self.max_tries + 1,
            )
            self.queued_dttm = timezone.utcnow()
            session.merge(self)
            session.commit()
            return False

    # A non-None next_kwargs is treated as "resuming after deferral".
    if self.next_kwargs is not None:
        self.log.info("Resuming after deferral")
    else:
        self.log.info("Starting attempt %s of %s", self.try_number, self.max_tries + 1)
    # The raw counter is incremented on both branches above.
    self._try_number += 1

    if not test_mode:
        session.add(Log(State.RUNNING, self))

    self.state = State.RUNNING
    self.emit_state_change_metric(State.RUNNING)
    self.external_executor_id = external_executor_id
    self.end_date = None
    if not test_mode:
        # Re-attach the task to the object returned by merge().
        session.merge(self).task = task
    session.commit()

    # Closing all pooled connections to prevent
    # "max number of connections reached"
    settings.engine.dispose()  # type: ignore
    if verbose:
        if mark_success:
            self.log.info("Marking success for %s on %s", self.task, self.execution_date)
        else:
            self.log.info("Executing %s on %s", self.task, self.execution_date)
    return True

1361 

def _date_or_empty(self, attr: str) -> str:
    """
    Format the datetime stored in attribute ``attr`` as ``YYYYMMDDTHHMMSS``.

    :param attr: name of the attribute to read from this instance
    :return: the formatted value, or "" when the attribute is missing or falsy
    """
    value: datetime | None = getattr(self, attr, None)
    if not value:
        return ""
    return value.strftime("%Y%m%dT%H%M%S")

1365 

def _log_state(self, lead_msg: str = "") -> None:
    """
    Log, at info level, the state this task instance is being marked with.

    The line carries dag_id, task_id, the map index (only when it is >= 0)
    and the execution/start/end dates formatted by ``_date_or_empty``.

    :param lead_msg: text placed verbatim in front of the message
    """
    template = "%sMarking task as %s. dag_id=%s, task_id=%s, "
    log_args = [lead_msg, str(self.state).upper(), self.dag_id, self.task_id]

    if self.map_index >= 0:
        template += "map_index=%d, "
        log_args.append(self.map_index)

    template += "execution_date=%s, start_date=%s, end_date=%s"
    for date_attr in ("execution_date", "start_date", "end_date"):
        log_args.append(self._date_or_empty(date_attr))

    self.log.info(template, *log_args)

1384 

def emit_state_change_metric(self, new_state: TaskInstanceState):
    """
    Send a time metric representing how much time a given state transition took.

    The previous state and metric name are deduced from the state the task was put in:
    ``RUNNING`` reports ``queued_duration`` (time elapsed since ``queued_dttm``) and
    ``QUEUED`` reports ``scheduled_duration`` (time elapsed since ``start_date``).
    Nothing is sent when the task already has an end date, or when the
    reference timestamp is missing (a warning is logged in that case).

    :param new_state: The state that has just been set for this task.
        We do not use `self.state`, because sometimes the state is updated directly in the DB and not in
        the local TaskInstance object.
        Supported states: QUEUED and RUNNING
    :raises NotImplementedError: if ``new_state`` is neither QUEUED nor RUNNING
    """
    if self.end_date:
        # if the task has an end date, it means that this is not its first round.
        # we send the state transition time metric only on the first try, otherwise it gets more complex.
        return

    # switch on state and deduce which metric to send and which timestamp it is measured from
    if new_state == State.RUNNING:
        metric_name = "queued_duration"
        previous_change = self.queued_dttm
    elif new_state == State.QUEUED:
        metric_name = "scheduled_duration"
        previous_change = self.start_date
    else:
        # Exceptions do not %-interpolate their arguments the way logging
        # calls do, so the state has to be formatted into the message here.
        raise NotImplementedError(f"no metric emission setup for state {new_state}")

    if previous_change is None:
        # this should not really happen except in tests or rare cases,
        # but we don't want to create errors just for a metric, so we just skip it
        self.log.warning(
            "cannot record %s for task %s because previous state change time has not been saved",
            metric_name,
            self.task_id,
        )
        return
    timing = (timezone.utcnow() - previous_change).total_seconds()

    # send metric twice, once (legacy) with tags in the name and once with tags as tags
    Stats.timing(f"dag.{self.dag_id}.{self.task_id}.{metric_name}", timing)
    Stats.timing(f"task.{metric_name}", timing, tags={"task_id": self.task_id, "dag_id": self.dag_id})

1430 

# next_method and next_kwargs are unset here so that any retries
# don't re-use them.
def clear_next_method_args(self) -> None:
    """Reset ``next_method`` and ``next_kwargs`` to ``None``."""
    self.log.debug("Clearing next_method and next_kwargs.")

    self.next_method = self.next_kwargs = None

1438 

@provide_session
@Sentry.enrich_errors
def _run_raw_task(
    self,
    mark_success: bool = False,
    test_mode: bool = False,
    job_id: str | None = None,
    pool: str | None = None,
    session: Session = NEW_SESSION,
) -> TaskReturnCode | None:
    """
    Immediately runs the task (without checking or changing db state
    before execution) and then sets the appropriate final state after
    completion and runs any post-execute callbacks. Meant to be called
    only after another function changes the state to running.

    :param mark_success: Don't run the task, mark its state as success
    :param test_mode: Doesn't record success or failure in the DB
    :param job_id: stored on the task instance as ``job_id``
    :param pool: specifies the pool to use to run the task instance
    :param session: SQLAlchemy ORM Session
    :return: ``TaskReturnCode.DEFERRED`` when the task deferred itself,
        otherwise ``None``
    :raises Exception: failures other than skip/reschedule/deferral are
        re-raised after ``handle_failure`` has been called
    """
    # NOTE(review): this listing had lost its indentation; the nesting below
    # (in particular of the final ``session.commit()``) is reconstructed --
    # confirm against the real file.
    self.test_mode = test_mode
    self.refresh_from_task(self.task, pool_override=pool)
    self.refresh_from_db(session=session)
    self.job_id = job_id
    self.hostname = get_hostname()
    self.pid = os.getpid()
    if not test_mode:
        session.merge(self)
        session.commit()
    # Captured before execution; passed to _handle_reschedule below.
    actual_start_date = timezone.utcnow()
    Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}", tags=self.stats_tags)
    # Same metric with tagging
    Stats.incr("ti.start", tags=self.stats_tags)
    # Initialize final state counters at zero
    for state in State.task_states:
        Stats.incr(
            f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}",
            count=0,
            tags=self.stats_tags,
        )
        # Same metric with tagging
        Stats.incr(
            "ti.finish",
            count=0,
            tags={**self.stats_tags, "state": str(state)},
        )

    self.task = self.task.prepare_for_execution()
    context = self.get_template_context(ignore_param_exceptions=False)

    # We lose previous state because it's changed in other process in LocalTaskJob.
    # We could probably pass it through here though...
    get_listener_manager().hook.on_task_instance_running(
        previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
    )
    try:
        if not mark_success:
            self._execute_task_with_callbacks(context, test_mode)
        if not test_mode:
            # Re-read (and lock) the row before writing the final state.
            self.refresh_from_db(lock_for_update=True, session=session)
        self.state = State.SUCCESS
    except TaskDeferred as defer:
        # The task has signalled it wants to defer execution based on
        # a trigger.
        self._defer_task(defer=defer, session=session)
        self.log.info(
            "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
            self.dag_id,
            self.task_id,
            self._date_or_empty("execution_date"),
            self._date_or_empty("start_date"),
        )
        if not test_mode:
            session.add(Log(self.state, self))
            session.merge(self)
            session.commit()
        return TaskReturnCode.DEFERRED
    except AirflowSkipException as e:
        # Recording SKIP
        # log only if exception has any arguments to prevent log flooding
        if e.args:
            self.log.info(e)
        if not test_mode:
            self.refresh_from_db(lock_for_update=True, session=session)
        self.state = State.SKIPPED
    except AirflowRescheduleException as reschedule_exception:
        self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
        session.commit()
        return None
    except (AirflowFailException, AirflowSensorTimeout) as e:
        # If AirflowFailException is raised, task should not retry.
        # If a sensor in reschedule mode reaches timeout, task should not retry.
        self.handle_failure(e, test_mode, context, force_fail=True, session=session)
        session.commit()
        raise
    except AirflowException as e:
        if not test_mode:
            self.refresh_from_db(lock_for_update=True, session=session)
        # for case when task is marked as success/failed externally
        # or dagrun timed out and task is marked as skipped
        # current behavior doesn't hit the callbacks
        if self.state in State.finished:
            self.clear_next_method_args()
            session.merge(self)
            session.commit()
            return None
        else:
            self.handle_failure(e, test_mode, context, session=session)
            session.commit()
            raise
    except (Exception, KeyboardInterrupt) as e:
        self.handle_failure(e, test_mode, context, session=session)
        session.commit()
        raise
    finally:
        # Runs on every exit path above, including the early returns and re-raises.
        Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags)
        # Same metric with tagging
        Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)})

    # Only the SUCCESS and SKIPPED paths fall through to here.
    # Recording SKIPPED or SUCCESS
    self.clear_next_method_args()
    self.end_date = timezone.utcnow()
    self._log_state()
    self.set_duration()

    # run on_success_callback before db committing
    # otherwise, the LocalTaskJob sees the state is changed to `success`,
    # but the task_runner is still running, LocalTaskJob then treats the state is set externally!
    self._run_finished_callback(self.task.on_success_callback, context, "on_success")

    if not test_mode:
        session.add(Log(self.state, self))
        session.merge(self).task = self.task
        if self.state == TaskInstanceState.SUCCESS:
            self._register_dataset_changes(session=session)
            get_listener_manager().hook.on_task_instance_success(
                previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session
            )

        session.commit()
    return None

1581 

def _register_dataset_changes(self, *, session: Session) -> None:
    """
    Register a dataset change for every ``Dataset`` among the task's outlets.

    :param session: SQLAlchemy ORM Session handed to the dataset manager
    """
    outlets = self.task.outlets or []
    for outlet in outlets:
        self.log.debug("outlet obj %s", outlet)
        # Lineage can have other types of objects besides datasets
        if not isinstance(outlet, Dataset):
            continue
        dataset_manager.register_dataset_change(
            task_instance=self,
            dataset=outlet,
            session=session,
        )

1592 

def _execute_task_with_callbacks(self, context, test_mode=False):
    """
    Prepare the task for execution, execute it, and run its surrounding callbacks.

    In order: installs a SIGTERM handler, clears XCom data (unless resuming
    via ``next_method``), renders templates, exports the context as
    environment variables, runs ``pre_execute``, the on-execute callback,
    the task itself and ``post_execute``, then increments success metrics.

    :param context: template context; its ``"params"`` entry is set on the task
    :param test_mode: when True, rendered task instance fields are not written
    """
    # NOTE(review): this listing had lost its indentation; the nesting below
    # is reconstructed -- confirm against the real file.
    from airflow.models.renderedtifields import RenderedTaskInstanceFields

    # Remembered so the signal handler can tell this process from forked children.
    parent_pid = os.getpid()

    def signal_handler(signum, frame):
        pid = os.getpid()

        # If a task forks during execution (from DAG code) for whatever
        # reason, we want to make sure that we react to the signal only in
        # the process that we've spawned ourselves (referred to here as the
        # parent process).
        if pid != parent_pid:
            os._exit(1)
            return
        self.log.error("Received SIGTERM. Terminating subprocesses.")
        self.task.on_kill()
        raise AirflowException("Task received SIGTERM signal")

    signal.signal(signal.SIGTERM, signal_handler)

    # Don't clear Xcom until the task is certain to execute, and check if we are resuming from deferral.
    if not self.next_method:
        self.clear_xcom_data()

    with Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration", tags=self.stats_tags):
        # Set the validated/merged params on the task object.
        self.task.params = context["params"]

        # The return value is handed to _execute_task further down.
        task_orig = self.render_templates(context=context)
        if not test_mode:
            rtif = RenderedTaskInstanceFields(ti=self, render_templates=False)
            RenderedTaskInstanceFields.write(rtif)
            RenderedTaskInstanceFields.delete_old_records(self.task_id, self.dag_id)

        # Export context to make it available for operators to use.
        airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
        os.environ.update(airflow_context_vars)

        # Log context only for the default execution method, the assumption
        # being that otherwise we're resuming a deferred task (in which
        # case there's no need to log these again).
        if not self.next_method:
            self.log.info(
                "Exporting env vars: %s",
                " ".join(f"{k}={v!r}" for k, v in airflow_context_vars.items()),
            )

        # Run pre_execute callback
        self.task.pre_execute(context=context)

        # Run on_execute callback
        self._run_execute_callback(context, self.task)

        # Execute the task
        with set_current_context(context):
            result = self._execute_task(context, task_orig)
        # Run post_execute callback
        self.task.post_execute(context=context, result=result)

    # Reached only when nothing above raised.
    Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
    # Same metric with tagging
    Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
    Stats.incr("ti_successes", tags=self.stats_tags)

1658 

def _run_finished_callback(
    self,
    callbacks: None | TaskStateChangeCallback | list[TaskStateChangeCallback],
    context: Context,
    callback_type: str,
) -> None:
    """
    Run the given callback(s) after the task finishes.

    Each callback is called with ``context``. An ``Exception`` raised by a
    callback is logged and does not prevent the remaining callbacks from running.

    :param callbacks: a single callable, a list of callables, or a falsy
        value (in which case nothing happens)
    :param context: value passed to every callback
    :param callback_type: not used by this method's body
    """
    if not callbacks:
        return

    callback_list = callbacks if isinstance(callbacks, list) else [callbacks]
    for callback in callback_list:
        try:
            callback(context)
        except Exception:
            # Only the last dotted component of the qualified name is reported.
            callback_name = qualname(callback).split(".")[-1]
            self.log.exception(
                f"Error when executing {callback_name} callback"  # type: ignore[attr-defined]
            )

1676 

def _execute_task(self, context, task_orig):
    """Execute the task (optionally under a timeout) and push its XCom result.

    :param context: Context passed to the executed callable as ``context``.
    :param task_orig: The task object handed to
        ``_record_task_map_for_downstreams`` when recording the task map.
    :return: Whatever the executed callable returned.
    :raises TaskDeferralError: If ``next_method`` is the special ``"__fail__"`` value.
    :raises AirflowTaskTimeout: If the remaining execution timeout is already
        used up, or if it is raised while running under the timeout wrapper
        (``on_kill()`` is called on the task before re-raising).
    """
    task_to_execute = self.task
    # If the task has been deferred and is being executed due to a trigger,
    # then we need to pick the right method to come back to, otherwise
    # we go for the default execute
    if self.next_method:
        # __fail__ is a special signal value for next_method that indicates
        # this task was scheduled specifically to fail.
        if self.next_method == "__fail__":
            # next_kwargs may be unset (None); always read from the normalised
            # dict so the intended TaskDeferralError is not masked by an
            # AttributeError on ``None.get``.
            next_kwargs = self.next_kwargs or {}
            traceback = next_kwargs.get("traceback")
            if traceback is not None:
                self.log.error("Trigger failed:\n%s", "\n".join(traceback))
            raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
        # Grab the callable off the Operator/Task and add in any kwargs
        execute_callable = getattr(task_to_execute, self.next_method)
        if self.next_kwargs:
            execute_callable = partial(execute_callable, **self.next_kwargs)
    else:
        execute_callable = task_to_execute.execute
    # If a timeout is specified for the task, make it fail
    # if it goes beyond
    if task_to_execute.execution_timeout:
        # If we are coming in with a next_method (i.e. from a deferral),
        # calculate the timeout from our start_date.
        if self.next_method:
            timeout_seconds = (
                task_to_execute.execution_timeout - (timezone.utcnow() - self.start_date)
            ).total_seconds()
        else:
            timeout_seconds = task_to_execute.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run task in timeout wrapper
            with timeout(timeout_seconds):
                result = execute_callable(context=context)
        except AirflowTaskTimeout:
            task_to_execute.on_kill()
            raise
    else:
        result = execute_callable(context=context)
    with create_session() as session:
        # Only a non-None result of a task with do_xcom_push enabled is pushed.
        xcom_value = result if task_to_execute.do_xcom_push else None
        if xcom_value is not None:  # If the task returns a result, push an XCom containing it.
            self.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session)
        self._record_task_map_for_downstreams(task_orig, xcom_value, session=session)
    return result

1730 

@provide_session
def _defer_task(self, session: Session, defer: TaskDeferred) -> None:
    """
    Put this task instance into the deferred state.

    Persists the trigger described by ``defer`` and records on the task
    instance everything needed to resume it later (trigger id, method to
    call, kwargs and the trigger timeout).
    """
    from airflow.models.trigger import Trigger

    # The trigger row has to be flushed first so that its id is available.
    trigger_row = Trigger.from_object(defer.trigger)
    session.add(trigger_row)
    session.flush()

    # Mirror the deferral request on ourselves.
    # Keep an eye on the logic in `check_and_change_state_before_execution()`
    # depending on self.next_method semantics
    self.state = State.DEFERRED
    self.trigger_id = trigger_row.id
    self.next_method = defer.method_name
    self.next_kwargs = defer.kwargs or {}

    # The resumed run must reuse the current try number.
    self._try_number -= 1

    # Deadline requested by the deferral itself, if any.
    deadline = None if defer.timeout is None else timezone.utcnow() + defer.timeout

    # An execution_timeout on the task caps the deadline as well: the
    # effective value is the earlier of the two.
    execution_timeout = self.task.execution_timeout
    if execution_timeout:
        execution_deadline = self.start_date + execution_timeout
        deadline = min(execution_deadline, deadline) if deadline else execution_deadline

    self.trigger_timeout = deadline

1769 

1770 def _run_execute_callback(self, context: Context, task: Operator) -> None: 

1771 """Functions that need to be run before a Task is executed.""" 

1772 callbacks = task.on_execute_callback 

1773 if callbacks: 

1774 callbacks = callbacks if isinstance(callbacks, list) else [callbacks] 

1775 for callback in callbacks: 

1776 try: 

1777 callback(context) 

1778 except Exception: 

1779 self.log.exception("Failed when executing execute callback") 

1780 

@provide_session
def run(
    self,
    verbose: bool = True,
    ignore_all_deps: bool = False,
    ignore_depends_on_past: bool = False,
    wait_for_past_depends_before_skipping: bool = False,
    ignore_task_deps: bool = False,
    ignore_ti_state: bool = False,
    mark_success: bool = False,
    test_mode: bool = False,
    job_id: str | None = None,
    pool: str | None = None,
    session: Session = NEW_SESSION,
) -> None:
    """Run the TaskInstance.

    First performs the pre-execution check / state change; the raw task is
    only run when that check returns a truthy value.
    """
    may_proceed = self.check_and_change_state_before_execution(
        verbose=verbose,
        ignore_all_deps=ignore_all_deps,
        ignore_depends_on_past=ignore_depends_on_past,
        wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
        ignore_task_deps=ignore_task_deps,
        ignore_ti_state=ignore_ti_state,
        mark_success=mark_success,
        test_mode=test_mode,
        job_id=job_id,
        pool=pool,
        session=session,
    )
    if may_proceed:
        self._run_raw_task(
            mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session
        )

1816 

def dry_run(self) -> None:
    """Prepare the task for execution, render its templates and call its ``dry_run()``."""
    from airflow.models.baseoperator import BaseOperator

    prepared_task = self.task.prepare_for_execution()
    self.task = prepared_task
    self.render_templates()
    if TYPE_CHECKING:
        assert isinstance(self.task, BaseOperator)
    # Deliberately go through ``self.task`` again rather than ``prepared_task``.
    self.task.dry_run()

1826 

@provide_session
def _handle_reschedule(
    self, actual_start_date, reschedule_exception, test_mode=False, session=NEW_SESSION
):
    """Record a reschedule request and mark this TI as ``UP_FOR_RESCHEDULE``.

    Nothing is recorded in test mode. Otherwise the TI is refreshed from the
    DB, its end date and duration are set, a ``TaskReschedule`` row is added
    and the TI is merged into the session, which is then committed.

    :param actual_start_date: Start date stored on the ``TaskReschedule`` row.
    :param reschedule_exception: Object whose ``reschedule_date`` attribute is
        stored on the ``TaskReschedule`` row.
    :param test_mode: When truthy, return immediately without touching the DB.
    :param session: Session used for all DB operations in this method.
    """
    # Don't record reschedule request in test mode
    if test_mode:
        return

    from airflow.models.dagrun import DagRun  # Avoid circular import

    self.refresh_from_db(session)

    self.end_date = timezone.utcnow()
    self.set_duration()

    # Lock DAG run to be sure not to get into a deadlock situation when trying to insert
    # TaskReschedule which apparently also creates lock on corresponding DagRun entity
    with_row_locks(
        session.query(DagRun).filter_by(
            dag_id=self.dag_id,
            run_id=self.run_id,
        ),
        session=session,
    ).one()

    # Log reschedule request
    session.add(
        TaskReschedule(
            self.task,
            self.run_id,
            self._try_number,
            actual_start_date,
            self.end_date,
            reschedule_exception.reschedule_date,
            self.map_index,
        )
    )

    # set state
    self.state = State.UP_FOR_RESCHEDULE

    # Decrement try_number so subsequent runs will use the same try number and write
    # to same log file.
    self._try_number -= 1

    self.clear_next_method_args()

    session.merge(self)
    session.commit()
    self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")

1877 

@staticmethod
def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None:
    """
    Truncate the traceback of an exception to the first frame called from within a given function.

    :param error: exception to get traceback from
    :param truncate_to: Function or bound method to truncate the traceback to.
        Must expose a ``__code__`` attribute, either directly (plain function)
        or through ``__func__`` (bound method).
    :return: The traceback entry following the first frame that runs
        ``truncate_to`` (``None`` if that frame is the last one). If no frame
        runs ``truncate_to``, the exception's full traceback is returned.

    :meta private:
    """
    # Bound methods keep their code object on ``__func__``; plain functions
    # carry it directly. Supporting both matches the documented contract.
    code = getattr(truncate_to, "__func__", truncate_to).__code__  # type: ignore[attr-defined]
    tb = error.__traceback__
    while tb is not None:
        if tb.tb_frame.f_code is code:
            return tb.tb_next
        tb = tb.tb_next
    # ``truncate_to`` never appeared in the traceback: nothing to truncate.
    return error.__traceback__

1895 

@provide_session
def handle_failure(
    self,
    error: None | str | Exception | KeyboardInterrupt,
    test_mode: bool | None = None,
    context: Context | None = None,
    force_fail: bool = False,
    session: Session = NEW_SESSION,
) -> None:
    """Handle a failure of this TaskInstance.

    Notifies listeners, logs the error, records end date / duration and
    failure metrics, then sets the state to ``FAILED`` or ``UP_FOR_RETRY``,
    sends the alert email if configured and runs the matching callbacks.

    :param error: The failure cause; exceptions are logged with a truncated
        traceback, any other truthy value is logged as-is.
    :param test_mode: When truthy, DB reads/writes in this method are skipped.
        Defaults to ``self.test_mode`` when ``None``.
    :param context: Context for callbacks; built from the TI when ``None`` and
        a task is available.
    :param force_fail: When truthy, fail the TI even if it is eligible to retry.
    :param session: Session used for DB operations.
    """
    if test_mode is None:
        test_mode = self.test_mode

    get_listener_manager().hook.on_task_instance_failed(
        previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session
    )

    if error:
        if isinstance(error, BaseException):
            # Drop the frames up to and including _execute_task from the logged traceback.
            tb = self.get_truncated_error_traceback(error, truncate_to=self._execute_task)
            self.log.error("Task failed with exception", exc_info=(type(error), error, tb))
        else:
            self.log.error("%s", error)
    if not test_mode:
        self.refresh_from_db(session)

    self.end_date = timezone.utcnow()
    self.set_duration()

    Stats.incr(f"operator_failures_{self.operator}", tags=self.stats_tags)
    # Same metric with tagging
    Stats.incr("operator_failures", tags={**self.stats_tags, "operator": self.operator})
    Stats.incr("ti_failures", tags=self.stats_tags)

    if not test_mode:
        session.add(Log(State.FAILED, self))

        # Log failure duration
        session.add(TaskFail(ti=self))

    self.clear_next_method_args()

    # In extreme cases (zombie in case of dag with parse error) we might _not_ have a Task.
    if context is None and getattr(self, "task", None):
        context = self.get_template_context(session)

    if context is not None:
        context["exception"] = error

    # Set state correctly and figure out how to log it and decide whether
    # to email

    # Note, callback invocation needs to be handled by caller of
    # _run_raw_task to avoid race conditions which could lead to duplicate
    # invocations or miss invocation.

    # Since this function is called only when the TaskInstance state is running,
    # try_number contains the current try_number (not the next). We
    # only mark task instance as FAILED if the next task instance
    # try_number exceeds the max_tries ... or if force_fail is truthy

    task: BaseOperator | None = None
    try:
        if getattr(self, "task", None) and context:
            task = self.task.unmap((context, session))
    except Exception:
        # Best effort: without an unmapped task, email and callbacks are skipped below.
        self.log.error("Unable to unmap task to determine if we need to send an alert email")

    if force_fail or not self.is_eligible_to_retry():
        self.state = State.FAILED
        email_for_state = operator.attrgetter("email_on_failure")
        callbacks = task.on_failure_callback if task else None
        callback_type = "on_failure"

        if task and task.dag and task.dag.fail_stop:
            tis = self.get_dagrun(session).get_task_instances()
            stop_all_tasks_in_dag(tis, session, self.task_id)
    else:
        if self.state == State.QUEUED:
            # We increase the try_number so as to fail the task if it fails to start after some time
            self._try_number += 1
        self.state = State.UP_FOR_RETRY
        email_for_state = operator.attrgetter("email_on_retry")
        callbacks = task.on_retry_callback if task else None
        callback_type = "on_retry"

    self._log_state("Immediate failure requested. " if force_fail else "")
    if task and email_for_state(task) and task.email:
        try:
            self.email_alert(error, task)
        except Exception:
            self.log.exception("Failed to send email to: %s", task.email)

    if callbacks and context:
        self._run_finished_callback(callbacks, context, callback_type)

    if not test_mode:
        session.merge(self)
        session.flush()

1995 

def is_eligible_to_retry(self) -> bool:
    """Return whether the task instance is eligible for retry.

    Always returns a real ``bool`` (never the raw ``retries`` value).
    """
    if self.state == State.RESTARTING:
        # If a task is cleared when running, it goes into RESTARTING state and is always
        # eligible for retry
        return True
    if not getattr(self, "task", None):
        # Couldn't load the task, don't know number of retries, guess:
        return self.try_number <= self.max_tries

    # ``retries`` may be 0 or None; coerce so callers get False rather than
    # a falsy non-bool value.
    return bool(self.task.retries and self.try_number <= self.max_tries)

2007 

def get_template_context(
    self,
    session: Session | None = None,
    ignore_param_exceptions: bool = True,
) -> Context:
    """Build and return the template context for this task instance.

    :param session: Session used for the DB lookups made while building the
        context. When falsy, a new ``settings.Session()`` is created.
    :param ignore_param_exceptions: Passed to ``process_params`` as
        ``suppress_exception``.
    :return: A ``Context`` wrapping the dict of template variables built below.
    """
    # Do not use provide_session here -- it expunges everything on exit!
    if not session:
        session = settings.Session()

    from airflow import macros
    from airflow.models.abstractoperator import NotMapped

    integrate_macros_plugins()

    task = self.task
    if TYPE_CHECKING:
        assert task.dag
    dag: DAG = task.dag

    dag_run = self.get_dagrun(session)
    data_interval = dag.get_run_data_interval(dag_run)

    validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)

    # Date/time strings derived from the logical date.
    logical_date = timezone.coerce_datetime(self.execution_date)
    ds = logical_date.strftime("%Y-%m-%d")
    ds_nodash = ds.replace("-", "")
    ts = logical_date.isoformat()
    ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
    ts_nodash_with_tz = ts.replace("-", "").replace(":", "")

    @cache  # Prevent multiple database access.
    def _get_previous_dagrun_success() -> DagRun | None:
        return self.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)

    def _get_previous_dagrun_data_interval_success() -> DataInterval | None:
        dagrun = _get_previous_dagrun_success()
        if dagrun is None:
            return None
        return dag.get_run_data_interval(dagrun)

    def get_prev_data_interval_start_success() -> pendulum.DateTime | None:
        data_interval = _get_previous_dagrun_data_interval_success()
        if data_interval is None:
            return None
        return data_interval.start

    def get_prev_data_interval_end_success() -> pendulum.DateTime | None:
        data_interval = _get_previous_dagrun_data_interval_success()
        if data_interval is None:
            return None
        return data_interval.end

    def get_prev_start_date_success() -> pendulum.DateTime | None:
        dagrun = _get_previous_dagrun_success()
        if dagrun is None:
            return None
        return timezone.coerce_datetime(dagrun.start_date)

    @cache
    def get_yesterday_ds() -> str:
        return (logical_date - timedelta(1)).strftime("%Y-%m-%d")

    def get_yesterday_ds_nodash() -> str:
        return get_yesterday_ds().replace("-", "")

    @cache
    def get_tomorrow_ds() -> str:
        return (logical_date + timedelta(1)).strftime("%Y-%m-%d")

    def get_tomorrow_ds_nodash() -> str:
        return get_tomorrow_ds().replace("-", "")

    @cache
    def get_next_execution_date() -> pendulum.DateTime | None:
        # For manually triggered dagruns that aren't run on a schedule,
        # the "next" execution date doesn't make sense, and should be set
        # to execution date for consistency with how execution_date is set
        # for manually triggered tasks, i.e. triggered_date == execution_date.
        if dag_run.external_trigger:
            return logical_date
        if dag is None:
            return None
        next_info = dag.next_dagrun_info(data_interval, restricted=False)
        if next_info is None:
            return None
        return timezone.coerce_datetime(next_info.logical_date)

    def get_next_ds() -> str | None:
        execution_date = get_next_execution_date()
        if execution_date is None:
            return None
        return execution_date.strftime("%Y-%m-%d")

    def get_next_ds_nodash() -> str | None:
        ds = get_next_ds()
        if ds is None:
            return ds
        return ds.replace("-", "")

    @cache
    def get_prev_execution_date():
        # For manually triggered dagruns that aren't run on a schedule,
        # the "previous" execution date doesn't make sense, and should be set
        # to execution date for consistency with how execution_date is set
        # for manually triggered tasks, i.e. triggered_date == execution_date.
        if dag_run.external_trigger:
            return logical_date
        # Silence the deprecation warning for this internal call only.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RemovedInAirflow3Warning)
            return dag.previous_schedule(logical_date)

    @cache
    def get_prev_ds() -> str | None:
        execution_date = get_prev_execution_date()
        if execution_date is None:
            return None
        return execution_date.strftime(r"%Y-%m-%d")

    def get_prev_ds_nodash() -> str | None:
        prev_ds = get_prev_ds()
        if prev_ds is None:
            return None
        return prev_ds.replace("-", "")

    def get_triggering_events() -> dict[str, list[DatasetEvent]]:
        # Groups the dag run's consumed dataset events by dataset URI.
        if TYPE_CHECKING:
            assert session is not None

        # The dag_run may not be attached to the session anymore since the
        # code base is over-zealous with use of session.expunge_all().
        # Re-attach it if we get called.
        nonlocal dag_run
        if dag_run not in session:
            dag_run = session.merge(dag_run, load=False)

        dataset_events = dag_run.consumed_dataset_events
        triggering_events: dict[str, list[DatasetEvent]] = defaultdict(list)
        for event in dataset_events:
            triggering_events[event.dataset.uri].append(event)

        return triggering_events

    try:
        expanded_ti_count: int | None = task.get_mapped_ti_count(self.run_id, session=session)
    except NotMapped:
        expanded_ti_count = None

    # NOTE: If you add anything to this dict, make sure to also update the
    # definition in airflow/utils/context.pyi, and KNOWN_CONTEXT_KEYS in
    # airflow/utils/context.py!
    context = {
        "conf": conf,
        "dag": dag,
        "dag_run": dag_run,
        "data_interval_end": timezone.coerce_datetime(data_interval.end),
        "data_interval_start": timezone.coerce_datetime(data_interval.start),
        "ds": ds,
        "ds_nodash": ds_nodash,
        "execution_date": logical_date,
        "expanded_ti_count": expanded_ti_count,
        "inlets": task.inlets,
        "logical_date": logical_date,
        "macros": macros,
        "next_ds": get_next_ds(),
        "next_ds_nodash": get_next_ds_nodash(),
        "next_execution_date": get_next_execution_date(),
        "outlets": task.outlets,
        "params": validated_params,
        "prev_data_interval_start_success": get_prev_data_interval_start_success(),
        "prev_data_interval_end_success": get_prev_data_interval_end_success(),
        "prev_ds": get_prev_ds(),
        "prev_ds_nodash": get_prev_ds_nodash(),
        "prev_execution_date": get_prev_execution_date(),
        "prev_execution_date_success": self.get_previous_execution_date(
            state=DagRunState.SUCCESS,
            session=session,
        ),
        "prev_start_date_success": get_prev_start_date_success(),
        "run_id": self.run_id,
        "task": task,
        "task_instance": self,
        "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
        "test_mode": self.test_mode,
        "ti": self,
        "tomorrow_ds": get_tomorrow_ds(),
        "tomorrow_ds_nodash": get_tomorrow_ds_nodash(),
        # Wrapped in a proxy so the events are only loaded when the value is actually used.
        "triggering_dataset_events": lazy_object_proxy.Proxy(get_triggering_events),
        "ts": ts,
        "ts_nodash": ts_nodash,
        "ts_nodash_with_tz": ts_nodash_with_tz,
        "var": {
            "json": VariableAccessor(deserialize_json=True),
            "value": VariableAccessor(deserialize_json=False),
        },
        "conn": ConnectionAccessor(),
        "yesterday_ds": get_yesterday_ds(),
        "yesterday_ds_nodash": get_yesterday_ds_nodash(),
    }
    # Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it
    # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
    return Context(context)  # type: ignore

2211 

@provide_session
def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None:
    """
    Populate the task's template fields with rendered values for presentation in UI.

    Stored rendered values are used when the DB has them; otherwise the
    templates are rendered now and passed through ``redact``.
    """
    from airflow.models.renderedtifields import RenderedTaskInstanceFields

    stored_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
    if stored_fields:
        self.task = self.task.unmap(None)
        for name, stored_value in stored_fields.items():
            setattr(self.task, name, stored_value)
        return

    try:
        # Reaching this point means nothing was stored: either the task hasn't
        # run or the RTIF record was purged.
        from airflow.utils.log.secrets_masker import redact

        self.render_templates()
        for name in self.task.template_fields:
            setattr(self.task, name, redact(getattr(self.task, name), name))
    except (TemplateAssertionError, UndefinedError) as err:
        raise AirflowException(
            "Webserver does not have access to User-defined Macros or Filters "
            "when Dag Serialization is enabled. Hence for the task that have not yet "
            "started running, please use 'airflow tasks render' for debugging the "
            "rendering of template_fields."
        ) from err

2242 

@provide_session
def get_rendered_k8s_spec(self, session: Session = NEW_SESSION):
    """Return the rendered k8s pod spec: the stored one if present, else freshly rendered."""
    from airflow.models.renderedtifields import RenderedTaskInstanceFields

    stored_spec = RenderedTaskInstanceFields.get_k8s_pod_yaml(self, session=session)
    if stored_spec:
        return stored_spec
    # Nothing usable in the DB: render the spec on the fly.
    try:
        return self.render_k8s_pod_yaml()
    except (TemplateAssertionError, UndefinedError) as err:
        raise AirflowException(f"Unable to render a k8s spec for this taskinstance: {err}") from err

2255 

def overwrite_params_with_dag_run_conf(self, params, dag_run):
    """Update ``params`` in place with ``dag_run.conf`` when the run has a non-empty conf."""
    if not dag_run or not dag_run.conf:
        return
    self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
    params.update(dag_run.conf)

2261 

def render_templates(self, context: Context | None = None) -> Operator:
    """Render templates in the operator fields.

    When no (or an empty) context is given, the TI's own template context is
    used. If the task was originally mapped, rendering may replace
    ``self.task`` with the unmapped, fully rendered BaseOperator; the value
    returned is the ``self.task`` from before that replacement.
    """
    effective_context = context or self.get_template_context()
    task_before_render = self.task

    # For a mapped self.task, the call below re-points self.task at the
    # unmapped BaseOperator it creates: a MappedOperator is useless for
    # template rendering, and the unmapped task must be accessible instead.
    task_before_render.render_template_fields(effective_context)

    return task_before_render

2280 

def render_k8s_pod_yaml(self) -> dict | None:
    """Construct the pod for this task instance and return it sanitized for serialization."""
    from kubernetes.client.api_client import ApiClient

    from airflow.kubernetes.kube_config import KubeConfig
    from airflow.kubernetes.kubernetes_helper_functions import create_pod_id  # Circular import
    from airflow.kubernetes.pod_generator import PodGenerator

    config = KubeConfig()
    constructed_pod = PodGenerator.construct_pod(
        dag_id=self.dag_id,
        run_id=self.run_id,
        task_id=self.task_id,
        map_index=self.map_index,
        date=None,
        pod_id=create_pod_id(self.dag_id, self.task_id),
        try_number=self.try_number,
        kube_image=config.kube_image,
        args=self.command_as_list(),
        pod_override_object=PodGenerator.from_obj(self.executor_config),
        scheduler_job_id="0",
        namespace=config.executor_namespace,
        base_worker_pod=PodGenerator.deserialize_model_file(config.pod_template_file),
        with_mutation_hook=True,
    )
    return ApiClient().sanitize_for_serialization(constructed_pod)

2308 

def get_email_subject_content(
    self, exception: BaseException, task: BaseOperator | None = None
) -> tuple[str, str, str]:
    """Get the email subject and HTML bodies for an exception alert.

    :param exception: The exception to report.
    :param task: Task to report for. When ``None``, ``self.task`` is used if
        the TI has one; otherwise the built-in default templates are rendered.
    :return: ``(subject, html_content, html_content_err)``, where the last
        item is the fallback body without the exception text.
    """
    # For a ti from DB (without ti.task), return the default value.
    # The ``None`` default matters: such a TI has no ``task`` attribute at all,
    # and a bare getattr() would raise AttributeError instead of falling back.
    if task is None:
        task = getattr(self, "task", None)
    use_default = task is None
    exception_html = str(exception).replace("\n", "<br>")

    default_subject = "Airflow alert: {{ti}}"
    # For reporting purposes, we report based on 1-indexed,
    # not 0-indexed lists (i.e. Try 1 instead of
    # Try 0 for the first attempt).
    default_html_content = (
        "Try {{try_number}} out of {{max_tries + 1}}<br>"
        "Exception:<br>{{exception_html}}<br>"
        'Log: <a href="{{ti.log_url}}">Link</a><br>'
        "Host: {{ti.hostname}}<br>"
        'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
    )

    default_html_content_err = (
        "Try {{try_number}} out of {{max_tries + 1}}<br>"
        "Exception:<br>Failed attempt to attach error logs<br>"
        'Log: <a href="{{ti.log_url}}">Link</a><br>'
        "Host: {{ti.hostname}}<br>"
        'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
    )

    # This function is called after changing the state from State.RUNNING,
    # so we need to subtract 1 from self.try_number here.
    current_try_number = self.try_number - 1
    additional_context: dict[str, Any] = {
        "exception": exception,
        "exception_html": exception_html,
        "try_number": current_try_number,
        "max_tries": self.max_tries,
    }

    if use_default:
        default_context = {"ti": self, **additional_context}
        jinja_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True
        )
        subject = jinja_env.from_string(default_subject).render(**default_context)
        html_content = jinja_env.from_string(default_html_content).render(**default_context)
        html_content_err = jinja_env.from_string(default_html_content_err).render(**default_context)

    else:
        # Use the DAG's get_template_env() to set force_sandboxed. Don't add
        # the flag to the function on task object -- that function can be
        # overridden, and adding a flag breaks backward compatibility.
        dag = self.task.get_dag()
        if dag:
            jinja_env = dag.get_template_env(force_sandboxed=True)
        else:
            jinja_env = SandboxedEnvironment(cache_size=0)
        jinja_context = self.get_template_context()
        context_merge(jinja_context, additional_context)

        def render(key: str, content: str) -> str:
            # Render the template file configured under [email] ``key`` if it
            # can be read; otherwise render the given default ``content``.
            if conf.has_option("email", key):
                path = conf.get_mandatory_value("email", key)
                try:
                    with open(path) as f:
                        content = f.read()
                except FileNotFoundError:
                    self.log.warning("Could not find email template file '%r'. Using defaults...", path)
                except OSError:
                    self.log.exception("Error while using email template '%r'. Using defaults...", path)
            return render_template_to_string(jinja_env.from_string(content), jinja_context)

        subject = render("subject_template", default_subject)
        html_content = render("html_content_template", default_html_content)
        html_content_err = render("html_content_template", default_html_content_err)

    return subject, html_content, html_content_err

2387 

def email_alert(self, exception, task: BaseOperator) -> None:
    """Email the task's recipients about ``exception``.

    If sending the full body fails, a second attempt is made with the
    fallback body.
    """
    subject, full_body, fallback_body = self.get_email_subject_content(exception, task=task)
    assert task.email
    try:
        send_email(task.email, subject, full_body)
    except Exception:
        send_email(task.email, subject, fallback_body)

2396 

def set_duration(self) -> None:
    """Set ``duration`` to the seconds between start and end date, or ``None`` if either is unset."""
    has_both_dates = self.end_date and self.start_date
    self.duration = (self.end_date - self.start_date).total_seconds() if has_both_dates else None
    self.log.debug("Task Duration set to %s", self.duration)

2404 

2405 def _record_task_map_for_downstreams(self, task: Operator, value: Any, *, session: Session) -> None: 

2406 if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate. 

2407 return 

2408 # TODO: We don't push TaskMap for mapped task instances because it's not 

2409 # currently possible for a downstream to depend on one individual mapped 

2410 # task instance. This will change when we implement task mapping inside 

2411 # a mapped task group, and we'll need to further analyze the case. 

2412 if isinstance(task, MappedOperator): 

2413 return 

2414 if value is None: 

2415 raise XComForMappingNotPushed() 

2416 if not _is_mappable_value(value): 

2417 raise UnmappableXComTypePushed(value) 

2418 task_map = TaskMap.from_task_instance_xcom(self, value) 

2419 max_map_length = conf.getint("core", "max_map_length", fallback=1024) 

2420 if task_map.length > max_map_length: 

2421 raise UnmappableXComLengthPushed(value, max_map_length) 

2422 session.merge(task_map) 

2423 

@provide_session
def xcom_push(
    self,
    key: str,
    value: Any,
    execution_date: datetime | None = None,
    session: Session = NEW_SESSION,
) -> None:
    """
    Make an XCom available for tasks to pull.

    :param key: Key to store the value under.
    :param value: Value to store. What types are possible depends on whether
        ``enable_xcom_pickling`` is true or not. If so, this can be any
        picklable object; only JSON-serializable values may be used otherwise.
    :param execution_date: Deprecated. It is not used when storing the XCom;
        passing it emits a ``RemovedInAirflow3Warning``.
    :raises ValueError: If ``execution_date`` is given and is earlier than the
        DAG run's execution date.
    """
    if execution_date is not None:
        self_execution_date = self.get_dagrun(session).execution_date
        if execution_date < self_execution_date:
            raise ValueError(
                f"execution_date can not be in the past (current execution_date is "
                f"{self_execution_date}; received {execution_date})"
            )
        # Reaching here means a valid execution_date was passed: warn that the
        # parameter is deprecated (no further condition is needed).
        message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated."
        warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)

    XCom.set(
        key=key,
        value=value,
        task_id=self.task_id,
        dag_id=self.dag_id,
        run_id=self.run_id,
        map_index=self.map_index,
        session=session,
    )

2461 

@provide_session
def xcom_pull(
    self,
    task_ids: str | Iterable[str] | None = None,
    dag_id: str | None = None,
    key: str = XCOM_RETURN_KEY,
    include_prior_dates: bool = False,
    session: Session = NEW_SESSION,
    *,
    map_indexes: int | Iterable[int] | None = None,
    default: Any = None,
) -> Any:
    """Pull XComs that optionally meet certain criteria.

    :param task_ids: Only XComs from tasks with matching ids will be
        pulled. Pass *None* to remove the filter.
    :param dag_id: If provided, only pulls XComs from this DAG. If *None*
        (default), the DAG of the calling task is used.
    :param key: A key for the XCom. If provided, only XComs with matching
        keys will be returned. The default key is ``'return_value'``, also
        available as constant ``XCOM_RETURN_KEY``. This key is automatically
        given to XComs returned by tasks (as opposed to being pushed
        manually). To remove the filter, pass *None*.
    :param include_prior_dates: If False, only XComs from the current
        execution_date are returned. If *True*, XComs from previous dates
        are returned as well.
    :param session: Session handed to ``XCom.get_many``.
    :param map_indexes: If provided, only pull XComs with matching indexes.
        If *None* (default), this is inferred from the task(s) being pulled
        (see below for details).
    :param default: Value returned by a single-task pull when no matching
        XCom exists.

    When pulling one single task (``task_ids`` is *None* or a str) without
    specifying ``map_indexes``, the return value is inferred from whether
    the specified task is mapped. If not, the value from the one single task
    instance is returned. If the task to pull is mapped, an iterator (not a
    list) yielding XComs from mapped task instances is returned. In either
    case, ``default`` (*None* if not specified) is returned if no matching
    XComs are found.

    When pulling multiple tasks (i.e. either ``task_ids`` or ``map_indexes``
    is a non-str iterable), a list of matching XComs is returned. Elements in
    the list are ordered by item ordering in ``task_ids`` and ``map_indexes``.
    """
    dag_id = self.dag_id if dag_id is None else dag_id

    query = XCom.get_many(
        key=key,
        run_id=self.run_id,
        dag_ids=dag_id,
        task_ids=task_ids,
        map_indexes=map_indexes,
        include_prior_dates=include_prior_dates,
        session=session,
    )

    # NOTE: Only the value field is fetched rather than the whole class, so
    # the @recreate annotation does not kick in and XCom.deserialize_value()
    # has to be called manually.

    single_task = task_ids is None or isinstance(task_ids, str)

    # Exactly one task is being pulled and map_indexes is not multi-valued.
    if single_task and not isinstance(map_indexes, Iterable):
        row = query.with_entities(
            XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value
        ).first()
        if row is None:  # Nothing matched at all.
            return default
        if map_indexes is None and row.map_index >= 0:
            # No explicit index and the first row has a non-negative
            # map_index: return lazy access over all rows, ordered by index.
            by_index = query.order_by(None).order_by(XCom.map_index.asc())
            return LazyXComAccess.build_from_xcom_query(by_index)
        return XCom.deserialize_value(row)

    # From here on, task_ids and/or map_indexes is explicitly multi-valued;
    # results are ordered to follow the ordering the caller gave.
    query = query.order_by(None)

    if single_task:
        task_ordering = XCom.task_id
    else:
        task_positions = {task_id: position for position, task_id in enumerate(task_ids)}
        task_ordering = case(task_positions, value=XCom.task_id) if task_positions else XCom.task_id
    query = query.order_by(task_ordering)

    if map_indexes is None or isinstance(map_indexes, int):
        index_ordering = XCom.map_index
    elif isinstance(map_indexes, range):
        # A range with a negative step is served in descending index order.
        index_ordering = XCom.map_index.desc() if map_indexes.step < 0 else XCom.map_index
    else:
        index_positions = {map_index: position for position, map_index in enumerate(map_indexes)}
        index_ordering = case(index_positions, value=XCom.map_index) if index_positions else XCom.map_index
    query = query.order_by(index_ordering)

    return LazyXComAccess.build_from_xcom_query(query)

2558 

@provide_session
def get_num_running_task_instances(self, session: Session, same_dagrun=False) -> int:
    """Return the number of running task instances of this task from the DB.

    :param session: Session used to run the count query.
    :param same_dagrun: If true, only count running task instances that share
        this task instance's ``run_id``; otherwise all runs of the same
        ``dag_id``/``task_id`` are counted.
    :return: Result of the count query.
    """
    # .count() is inefficient
    num_running_task_instances_query = session.query(func.count()).filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.task_id == self.task_id,
        TaskInstance.state == State.RUNNING,
    )
    if same_dagrun:
        # Query.filter() returns a new Query instead of modifying the
        # receiver, so the result must be re-bound for the criterion to apply.
        num_running_task_instances_query = num_running_task_instances_query.filter(
            TaskInstance.run_id == self.run_id
        )
    return num_running_task_instances_query.scalar()

2571 

def init_run_context(self, raw: bool = False) -> None:
    """Store the ``raw`` flag on this instance and set its log context.

    :param raw: Value assigned to ``self.raw``.
    """
    self.raw = raw
    # The instance itself is handed over as the context object.
    self._set_context(self)

2576 

2577 @staticmethod 

2578 def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None: 

2579 """Returns SQLAlchemy filter to query selected task instances.""" 

2580 # DictKeys type, (what we often pass here from the scheduler) is not directly indexable :( 

2581 # Or it might be a generator, but we need to be able to iterate over it more than once 

2582 tis = list(tis) 

2583 

2584 if not tis: 

2585 return None 

2586 

2587 first = tis[0] 

2588 

2589 dag_id = first.dag_id 

2590 run_id = first.run_id 

2591 map_index = first.map_index 

2592 first_task_id = first.task_id 

2593 

2594 # pre-compute the set of dag_id, run_id, map_indices and task_ids 

2595 dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set() 

2596 for t in tis: 

2597 dag_ids.add(t.dag_id) 

2598 run_ids.add(t.run_id) 

2599 map_indices.add(t.map_index) 

2600 task_ids.add(t.task_id) 

2601 

2602 # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id 

2603 # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+) 

2604 if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}: 

2605 return and_( 

2606 TaskInstance.dag_id == dag_id, 

2607 TaskInstance.run_id == run_id, 

2608 TaskInstance.map_index == map_index, 

2609 TaskInstance.task_id.in_(task_ids), 

2610 ) 

2611 if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}: 

2612 return and_( 

2613 TaskInstance.dag_id == dag_id, 

2614 TaskInstance.run_id.in_(run_ids), 

2615 TaskInstance.map_index == map_index, 

2616 TaskInstance.task_id == first_task_id, 

2617 ) 

2618 if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}: 

2619 return and_( 

2620 TaskInstance.dag_id == dag_id, 

2621 TaskInstance.run_id == run_id, 

2622 TaskInstance.map_index.in_(map_indices), 

2623 TaskInstance.task_id == first_task_id, 

2624 ) 

2625 

2626 filter_condition = [] 

2627 # create 2 nested groups, both primarily grouped by dag_id and run_id, 

2628 # and in the nested group 1 grouped by task_id the other by map_index. 

2629 task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list)) 

2630 map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list)) 

2631 for t in tis: 

2632 task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index) 

2633 map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id) 

2634 

2635 # this assumes that most dags have dag_id as the largest grouping, followed by run_id. even 

2636 # if its not, this is still a significant optimization over querying for every single tuple key 

2637 for cur_dag_id in dag_ids: 

2638 for cur_run_id in run_ids: 

2639 # we compare the group size between task_id and map_index and use the smaller group 

2640 dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)] 

2641 dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)] 

2642 

2643 if len(dag_task_id_groups) <= len(dag_map_index_groups): 

2644 for cur_task_id, cur_map_indices in dag_task_id_groups.items(): 

2645 filter_condition.append( 

2646 and_( 

2647 TaskInstance.dag_id == cur_dag_id, 

2648 TaskInstance.run_id == cur_run_id, 

2649 TaskInstance.task_id == cur_task_id, 

2650 TaskInstance.map_index.in_(cur_map_indices), 

2651 ) 

2652 ) 

2653 else: 

2654 for cur_map_index, cur_task_ids in dag_map_index_groups.items(): 

2655 filter_condition.append( 

2656 and_( 

2657 TaskInstance.dag_id == cur_dag_id, 

2658 TaskInstance.run_id == cur_run_id, 

2659 TaskInstance.task_id.in_(cur_task_ids), 

2660 TaskInstance.map_index == cur_map_index, 

2661 ) 

2662 ) 

2663 

2664 return or_(*filter_condition) 

2665 

@classmethod
def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators:
    """
    Build an SQLAlchemy filter matching any of the given task selectors.

    Each element of ``vals`` is either a bare ``task_id`` string or a
    ``(task_id, map_index)`` tuple.

    :param vals: The selectors to match.
    :return: ``false()`` when ``vals`` yields no selectors, the single
        condition when only one kind of selector is present, otherwise the
        OR of both conditions.

    :meta private:
    """
    # Split the selectors by kind: strings filter on task_id alone, anything
    # else is treated as a (task_id, map_index) pair.
    plain_task_ids = []
    mapped_selectors = []
    for selector in vals:
        if isinstance(selector, str):
            plain_task_ids.append(selector)
        else:
            mapped_selectors.append(selector)

    conditions: list[ColumnOperators] = []
    if plain_task_ids:
        conditions.append(cls.task_id.in_(plain_task_ids))
    if mapped_selectors:
        conditions.append(tuple_in_condition((cls.task_id, cls.map_index), mapped_selectors))

    if not conditions:
        return false()
    if len(conditions) == 1:
        return conditions[0]
    return or_(*conditions)

2690 

@Sentry.enrich_errors
@provide_session
def schedule_downstream_tasks(self, session: Session = NEW_SESSION, max_tis_per_query: int | None = None):
    """
    The mini-scheduler for scheduling downstream tasks of this task instance.

    Looks up this task instance's DAG run with a row lock, asks it for
    scheduling decisions on a partial DAG built around this task's downstream
    tasks, and schedules the resulting task instances. An
    ``OperationalError`` raised along the way is logged and the session is
    rolled back instead of propagating.

    :param session: Session used for the locked DAG run query, the
        scheduling calls, and the final flush / rollback.
    :param max_tis_per_query: Forwarded unchanged to ``DagRun.schedule_tis``.

    :meta private:
    """
    from sqlalchemy.exc import OperationalError

    from airflow.models import DagRun

    try:
        # Re-select the row with a lock
        dag_run = with_row_locks(
            session.query(DagRun).filter_by(
                dag_id=self.dag_id,
                run_id=self.run_id,
            ),
            session=session,
        ).one()

        task = self.task
        if TYPE_CHECKING:
            assert task.dag

        # Get a partial DAG with just the specific tasks we want to examine.
        # In order for dep checks to work correctly, we include ourself (so
        # TriggerRuleDep can check the state of the task we just executed).
        partial_dag = task.dag.partial_subset(
            task.downstream_task_ids,
            include_downstream=True,
            include_upstream=False,
            include_direct_upstream=True,
        )

        # Scheduling decisions are computed against the partial DAG only.
        dag_run.dag = partial_dag
        info = dag_run.task_instance_scheduling_decisions(session)

        # Tasks that are in the partial DAG but are not direct downstream
        # tasks of this task; they are never scheduled from here.
        skippable_task_ids = {
            task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
        }

        # Keep only direct downstream TIs, and drop those whose task inherits
        # from the empty operator and has no on_execute_callback,
        # on_success_callback or outlets.
        schedulable_tis = [
            ti
            for ti in info.schedulable_tis
            if ti.task_id not in skippable_task_ids
            and not (
                ti.task.inherits_from_empty_operator
                and not ti.task.on_execute_callback
                and not ti.task.on_success_callback
                and not ti.task.outlets
            )
        ]
        # Make sure every TI about to be scheduled has its task attached.
        for schedulable_ti in schedulable_tis:
            if not hasattr(schedulable_ti, "task"):
                schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)

        num = dag_run.schedule_tis(schedulable_tis, session=session, max_tis_per_query=max_tis_per_query)
        self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)

        session.flush()

    except OperationalError as e:
        # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
        self.log.info(
            "Skipping mini scheduling run due to exception: %s",
            e.statement,
            exc_info=True,
        )
        session.rollback()

2762 

def get_relevant_upstream_map_indexes(
    self,
    upstream: Operator,
    ti_count: int | None,
    *,
    session: Session,
) -> int | range | None:
    """Infer the map indexes of an upstream "relevant" to this ti.

    Most of the logic here exists to handle the situation in the example
    below, where 'val' has to resolve to different values depending on
    where the reference is used::

        @task
        def this_task(v):  # This is self.task.
            return v * 2

        @task_group
        def tg1(inp):
            val = upstream(inp)  # This is the upstream task.
            this_task(val)  # When inp is 1, val here should resolve to 2.
            return val

        # This val is the same object returned by tg1.
        val = tg1.expand(inp=[1, 2, 3])

        @task_group
        def tg2(inp):
            another_task(inp, val)  # val here should resolve to [2, 4, 6].

        tg2.expand(inp=["a", "b"])

    The mapped task groups surrounding ``upstream`` and ``self.task`` are
    inspected to find a common "ancestor". When one exists, specific map
    indexes are returned so that only a partial value is pulled from the
    upstream XCom.

    :param upstream: The referenced upstream task.
    :param ti_count: The total count of task instances this task was expanded
        into by the scheduler, i.e. ``expanded_ti_count`` in the template
        context.
    :param session: Session passed to ``get_mapped_ti_count`` on the common
        ancestor group.
    :return: Specific map index or map indexes to pull, or ``None`` if we
        want the "whole" return value (i.e. no mapped task groups involved).
    """
    # ti_count should never be None here, since the current task is already
    # known to be in a mapped task group and should have been expanded; the
    # check is still needed to satisfy Mypy. It can be 0 when an empty list
    # was expanded, and must be rejected to avoid dividing by zero below.
    if not ti_count:
        return None

    # Find the innermost mapped task group shared by the current task and
    # the referenced task. If there is none, the two live in different task
    # mapping contexts (like another_task above) and the "whole" value is
    # the right one to use.
    shared_group = _find_common_ancestor_mapped_group(self.task, upstream)
    if shared_group is None:
        return None

    # Both tasks share a mapped task group, so a "partial" value is needed.
    # Split the total ti count into the ancestor's own expansion and any
    # further expansion that happened inside it.
    shared_ti_count = shared_group.get_mapped_ti_count(self.run_id, session=session)
    shared_map_index = self.map_index * shared_ti_count // ti_count

    if _is_further_mapped_inside(upstream, shared_group):
        # Upstream is expanded again inside the ancestor: aggregate the
        # values of the task instances that belong to this ancestor index.
        per_ancestor = ti_count // shared_ti_count
        first_index = shared_map_index * per_ancestor
        return range(first_index, first_index + per_ancestor)

    # Upstream is NOT further expanded inside the ancestor, so one single ti
    # is referenced. The actual DAG must be walked to decide this: comparing
    # "ti_count == ancestor ti count" does not work, since the further
    # expansion may be of length 1.
    return shared_map_index

2840 

def clear_db_references(self, session):
    """
    Delete the rows in other tables that refer to this task instance.

    One DELETE is executed per table, matching on this instance's
    ``dag_id``, ``task_id``, ``run_id`` and ``map_index``.

    :param session: ORM Session

    :meta private:
    """
    from airflow.models.renderedtifields import RenderedTaskInstanceFields

    for model in (TaskFail, TaskInstanceNote, TaskReschedule, XCom, RenderedTaskInstanceFields):
        statement = delete(model).where(
            model.dag_id == self.dag_id,
            model.task_id == self.task_id,
            model.run_id == self.run_id,
            model.map_index == self.map_index,
        )
        session.execute(statement)

2861 

2862 

def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
    """Return the innermost mapped task group that contains both operators.

    *None* is returned when either operator has no DAG, when the two belong
    to different DAGs, or when they share no mapped task group.
    """
    if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
        return None
    group_ids_around_node1 = {group.group_id for group in node1.iter_mapped_task_groups()}
    # The first of node2's mapped groups that also surrounds node1 wins.
    for candidate in node2.iter_mapped_task_groups():
        if candidate.group_id in group_ids_around_node1:
            return candidate
    return None

2870 

2871 

def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
    """Tell whether ``operator`` is mapped *again* somewhere inside ``container``.

    That is the case when the operator itself is a mapped operator, or when
    a mapped task group sits between the operator and ``container``.
    """
    if isinstance(operator, MappedOperator):
        return True
    group = operator.task_group
    # Walk outwards through the enclosing groups until the container (or the
    # top of the hierarchy) is reached.
    while group is not None:
        if group.group_id == container.group_id:
            return False
        if isinstance(group, MappedTaskGroup):
            return True
        group = group.parent_group
    return False

2882 

2883 

# State of the task instance.
# A (task instance key, state) pair; the state is stored in its string form.
TaskInstanceStateType = Tuple[TaskInstanceKey, str]

2887 

2888 

class SimpleTaskInstance:
    """
    Simplified Task Instance.

    Used to send data between processes via Queues.

    Instances are plain attribute holders; two instances compare equal when
    all of their attributes are equal.
    """

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        run_id: str,
        start_date: datetime | None,
        end_date: datetime | None,
        try_number: int,
        map_index: int,
        state: str,
        executor_config: Any,
        pool: str,
        queue: str,
        key: TaskInstanceKey,
        run_as_user: str | None = None,
        priority_weight: int | None = None,
    ):
        self.dag_id = dag_id
        self.task_id = task_id
        self.run_id = run_id
        self.map_index = map_index
        self.start_date = start_date
        self.end_date = end_date
        self.try_number = try_number
        self.state = state
        self.executor_config = executor_config
        self.run_as_user = run_as_user
        self.pool = pool
        self.priority_weight = priority_weight
        self.queue = queue
        self.key = key

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def as_dict(self):
        """Return the attributes as a dict, with the dates in ISO format.

        Deprecated; emits ``RemovedInAirflow3Warning``. ``start_date`` and
        ``end_date`` are converted with ``isoformat()`` unless they are falsy
        or already strings; every other attribute is copied unchanged.
        """
        warnings.warn(
            "This method is deprecated. Use BaseSerialization.serialize.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        new_dict = dict(self.__dict__)
        for key in ("start_date", "end_date"):
            val = new_dict.get(key)
            if val and not isinstance(val, str):
                new_dict[key] = val.isoformat()
        return new_dict

    @classmethod
    def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance:
        """Build an instance from the attributes of ``ti``.

        ``run_as_user`` and ``priority_weight`` fall back to *None* when
        ``ti`` does not have them.
        """
        return cls(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            run_id=ti.run_id,
            map_index=ti.map_index,
            start_date=ti.start_date,
            end_date=ti.end_date,
            try_number=ti.try_number,
            state=ti.state,
            executor_config=ti.executor_config,
            pool=ti.pool,
            queue=ti.queue,
            key=ti.key,
            run_as_user=getattr(ti, "run_as_user", None),
            priority_weight=getattr(ti, "priority_weight", None),
        )

    @classmethod
    def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance:
        """Build an instance from a dict such as the one ``as_dict`` returns.

        Deprecated; emits ``RemovedInAirflow3Warning``. The ``key`` entry is
        unpacked into a ``TaskInstanceKey`` and non-empty ``start_date`` /
        ``end_date`` strings are parsed with ``timezone.parse``. The dict
        passed in is left unmodified.

        :raises KeyError: If ``key``, ``start_date`` or ``end_date`` is missing.
        """
        warnings.warn(
            "This method is deprecated. Use BaseSerialization.deserialize.",
            RemovedInAirflow3Warning,
            stacklevel=2,
        )
        # Work on a shallow copy so the pops below do not strip entries from
        # the caller's dict.
        fields = dict(obj_dict)
        ti_key = TaskInstanceKey(*fields.pop("key"))
        start_date_str: str | None = fields.pop("start_date")
        end_date_str: str | None = fields.pop("end_date")
        start_date = timezone.parse(start_date_str) if start_date_str else None
        end_date = timezone.parse(end_date_str) if end_date_str else None
        return cls(**fields, start_date=start_date, end_date=end_date, key=ti_key)

2984 

2985 

class TaskInstanceNote(Base):
    """For storage of arbitrary notes concerning the task instance.

    A row is keyed by (``task_id``, ``dag_id``, ``run_id``, ``map_index``),
    which is also a foreign key to ``task_instance`` with ``ON DELETE CASCADE``.
    """

    __tablename__ = "task_instance_note"

    # Optional reference to ``ab_user.id`` (see the foreign key below).
    user_id = Column(Integer, nullable=True)
    # Composite primary key columns, mirroring the task_instance key.
    task_id = Column(StringID(), primary_key=True, nullable=False)
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    run_id = Column(StringID(), primary_key=True, nullable=False)
    map_index = Column(Integer, primary_key=True, nullable=False)
    # Note text: String(1000), with a Text(1000) variant on MySQL.
    content = Column(String(1000).with_variant(Text(1000), "mysql"))
    # Both timestamps default to utcnow; updated_at is also refreshed on update.
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

    task_instance = relationship("TaskInstance", back_populates="task_instance_note")

    __table_args__ = (
        PrimaryKeyConstraint(
            "task_id", "dag_id", "run_id", "map_index", name="task_instance_note_pkey", mssql_clustered=True
        ),
        ForeignKeyConstraint(
            (dag_id, task_id, run_id, map_index),
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="task_instance_note_ti_fkey",
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (user_id,),
            ["ab_user.id"],
            name="task_instance_note_user_fkey",
        ),
    )

    def __init__(self, content, user_id=None):
        # Only the note text and author are set here; the key columns are
        # presumably filled in through the ``task_instance`` relationship --
        # TODO confirm.
        self.content = content
        self.user_id = user_id

    def __repr__(self):
        prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
        # map_index is only shown when it differs from -1.
        if self.map_index != -1:
            prefix += f" map_index={self.map_index}"
        return prefix + ">"

3033 

3034 

# "kcah_acitats"[::-1].upper() evaluates to "STATICA_HACK", so the globals()
# assignment overrides the literal True with False and the guarded block
# below never executes at runtime.
# NOTE(review): presumably the block exists only so that tools which read the
# source without executing it see the Job import and the ``queued_by_job``
# relationship -- confirm.
STATICA_HACK = True
globals()["kcah_acitats"[::-1].upper()] = False
if STATICA_HACK:  # pragma: no cover
    from airflow.jobs.job import Job

    TaskInstance.queued_by_job = relationship(Job)