Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/taskinstance.py: 21%

1261 statements  

coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections.abc 

21import contextlib 

22import hashlib 

23import logging 

24import math 

25import operator 

26import os 

27import signal 

28import warnings 

29from collections import defaultdict 

30from datetime import datetime, timedelta 

31from functools import partial 

32from types import TracebackType 

33from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, NamedTuple, Tuple 

34from urllib.parse import quote 

35 

36import dill 

37import jinja2 

38import lazy_object_proxy 

39import pendulum 

40from jinja2 import TemplateAssertionError, UndefinedError 

41from sqlalchemy import ( 

42 Column, 

43 DateTime, 

44 Float, 

45 ForeignKeyConstraint, 

46 Index, 

47 Integer, 

48 PrimaryKeyConstraint, 

49 String, 

50 Text, 

51 and_, 

52 false, 

53 func, 

54 inspect, 

55 or_, 

56 text, 

57) 

58from sqlalchemy.ext.associationproxy import association_proxy 

59from sqlalchemy.ext.mutable import MutableDict 

60from sqlalchemy.orm import reconstructor, relationship 

61from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value 

62from sqlalchemy.orm.session import Session 

63from sqlalchemy.sql.elements import BooleanClauseList 

64from sqlalchemy.sql.expression import ColumnOperators, case 

65 

66from airflow import settings 

67from airflow.compat.functools import cache 

68from airflow.configuration import conf 

69from airflow.datasets import Dataset 

70from airflow.datasets.manager import dataset_manager 

71from airflow.exceptions import ( 

72 AirflowException, 

73 AirflowFailException, 

74 AirflowRescheduleException, 

75 AirflowSensorTimeout, 

76 AirflowSkipException, 

77 AirflowTaskTimeout, 

78 DagRunNotFound, 

79 RemovedInAirflow3Warning, 

80 TaskDeferralError, 

81 TaskDeferred, 

82 UnmappableXComLengthPushed, 

83 UnmappableXComTypePushed, 

84 XComForMappingNotPushed, 

85) 

86from airflow.models.base import Base, StringID 

87from airflow.models.log import Log 

88from airflow.models.mappedoperator import MappedOperator 

89from airflow.models.param import process_params 

90from airflow.models.taskfail import TaskFail 

91from airflow.models.taskmap import TaskMap 

92from airflow.models.taskreschedule import TaskReschedule 

93from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom 

94from airflow.plugins_manager import integrate_macros_plugins 

95from airflow.sentry import Sentry 

96from airflow.stats import Stats 

97from airflow.templates import SandboxedEnvironment 

98from airflow.ti_deps.dep_context import DepContext 

99from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS 

100from airflow.timetables.base import DataInterval 

101from airflow.typing_compat import Literal, TypeGuard 

102from airflow.utils import timezone 

103from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor, context_merge 

104from airflow.utils.email import send_email 

105from airflow.utils.helpers import render_template_to_string 

106from airflow.utils.log.logging_mixin import LoggingMixin 

107from airflow.utils.module_loading import qualname 

108from airflow.utils.net import get_hostname 

109from airflow.utils.operator_helpers import context_to_airflow_vars 

110from airflow.utils.platform import getuser 

111from airflow.utils.retries import run_with_db_retries 

112from airflow.utils.session import NEW_SESSION, create_session, provide_session 

113from airflow.utils.sqlalchemy import ( 

114 ExecutorConfigType, 

115 ExtendedJSON, 

116 UtcDateTime, 

117 tuple_in_condition, 

118 with_row_locks, 

119) 

120from airflow.utils.state import DagRunState, State, TaskInstanceState 

121from airflow.utils.timeout import timeout 

122 

123TR = TaskReschedule 

124 

125_CURRENT_CONTEXT: list[Context] = [] 

126log = logging.getLogger(__name__) 

127 

128 

129if TYPE_CHECKING: 

130 from airflow.models.abstractoperator import TaskStateChangeCallback 

131 from airflow.models.baseoperator import BaseOperator 

132 from airflow.models.dag import DAG, DagModel 

133 from airflow.models.dagrun import DagRun 

134 from airflow.models.dataset import DatasetEvent 

135 from airflow.models.operator import Operator 

136 from airflow.utils.task_group import MappedTaskGroup, TaskGroup 

137 

138 

139@contextlib.contextmanager 

140def set_current_context(context: Context) -> Generator[Context, None, None]: 

141 """ 

142 Sets the current execution context to the provided context object. 

143 This method should be called once per Task execution, before calling operator.execute. 

144 """ 

145 _CURRENT_CONTEXT.append(context) 

146 try: 

147 yield context 

148 finally: 

149 expected_state = _CURRENT_CONTEXT.pop() 

150 if expected_state != context: 

151 log.warning( 

152 "Current context is not equal to the state at context stack. Expected=%s, got=%s", 

153 context, 

154 expected_state, 

155 ) 

156 
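# Illustrative usage sketch (not part of the original module): the task runner
# wraps operator execution in this context manager so helpers that read the
# current Context can see it. Variable names here are hypothetical.
#
#     with set_current_context(context):
#         result = task.execute(context=context)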

157 

158def clear_task_instances( 

159 tis: list[TaskInstance], 

160 session: Session, 

161 activate_dag_runs: None = None, 

162 dag: DAG | None = None, 

163 dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED, 

164) -> None: 

165 """ 

166 Clears a set of task instances, but makes sure the running ones 

167 get killed. 

168 

169 :param tis: a list of task instances 

170 :param session: current session 

171 :param dag_run_state: state to set DagRun to. If set to False, dagrun state will not 

172 be changed. 

173 :param dag: DAG object 

174 :param activate_dag_runs: Deprecated parameter, do not pass 

175 """ 

176 job_ids = [] 

177 # Keys: dag_id -> run_id -> map_indexes -> try_numbers -> task_id 

178 task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict( 

179 lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set))) 

180 ) 

181 for ti in tis: 

182 if ti.state == TaskInstanceState.RUNNING: 

183 if ti.job_id: 

184 # If a task is cleared when running, set its state to RESTARTING so that 

185 # the task is terminated and becomes eligible for retry. 

186 ti.state = TaskInstanceState.RESTARTING 

187 job_ids.append(ti.job_id) 

188 else: 

189 task_id = ti.task_id 

190 if dag and dag.has_task(task_id): 

191 task = dag.get_task(task_id) 

192 ti.refresh_from_task(task) 

193 task_retries = task.retries 

194 ti.max_tries = ti.try_number + task_retries - 1 

195 else: 

196 # Ignore errors when updating max_tries if dag is None or 

197 # task not found in dag since database records could be 

198 # outdated. We make max_tries the maximum value of its 

199 # original max_tries or the last attempted try number. 

200 ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries) 

201 ti.state = None 

202 ti.external_executor_id = None 

203 ti.clear_next_method_args() 

204 session.merge(ti) 

205 

206 task_id_by_key[ti.dag_id][ti.run_id][ti.map_index][ti.try_number].add(ti.task_id) 

207 

208 if task_id_by_key: 

209 # Clear all reschedules related to the ti to clear 

210 

211 # This is an optimization for the common case where all tis are for a small number 

212 # of dag_id, run_id, try_number, and map_index. Use a nested dict of dag_id, 

213 # run_id, try_number, map_index, and task_id to construct the where clause in a 

214 # hierarchical manner. This speeds up the delete statement by more than 40x for 

215 # large number of tis (50k+). 

216 conditions = or_( 

217 and_( 

218 TR.dag_id == dag_id, 

219 or_( 

220 and_( 

221 TR.run_id == run_id, 

222 or_( 

223 and_( 

224 TR.map_index == map_index, 

225 or_( 

226 and_(TR.try_number == try_number, TR.task_id.in_(task_ids)) 

227 for try_number, task_ids in task_tries.items() 

228 ), 

229 ) 

230 for map_index, task_tries in map_indexes.items() 

231 ), 

232 ) 

233 for run_id, map_indexes in run_ids.items() 

234 ), 

235 ) 

236 for dag_id, run_ids in task_id_by_key.items() 

237 ) 

238 

239 delete_qry = TR.__table__.delete().where(conditions) 

240 session.execute(delete_qry) 

241 

242 if job_ids: 

243 from airflow.jobs.base_job import BaseJob 

244 

245 for job in session.query(BaseJob).filter(BaseJob.id.in_(job_ids)).all(): 

246 job.state = TaskInstanceState.RESTARTING 

247 

248 if activate_dag_runs is not None: 

249 warnings.warn( 

250 "`activate_dag_runs` parameter to clear_task_instances function is deprecated. " 

251 "Please use `dag_run_state`", 

252 RemovedInAirflow3Warning, 

253 stacklevel=2, 

254 ) 

255 if not activate_dag_runs: 

256 dag_run_state = False 

257 

258 if dag_run_state is not False and tis: 

259 from airflow.models.dagrun import DagRun # Avoid circular import 

260 

261 run_ids_by_dag_id = defaultdict(set) 

262 for instance in tis: 

263 run_ids_by_dag_id[instance.dag_id].add(instance.run_id) 

264 

265 drs = ( 

266 session.query(DagRun) 

267 .filter( 

268 or_( 

269 and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids)) 

270 for dag_id, run_ids in run_ids_by_dag_id.items() 

271 ) 

272 ) 

273 .all() 

274 ) 

275 dag_run_state = DagRunState(dag_run_state) # Validate the state value. 

276 for dr in drs: 

277 dr.state = dag_run_state 

278 dr.start_date = timezone.utcnow() 

279 if dag_run_state == DagRunState.QUEUED: 

280 dr.last_scheduling_decision = None 

281 dr.start_date = None 

282 session.flush() 
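# Hedged usage sketch (the `dag` object and the selection of task instances are
# illustrative; names are not from the original source):
#
#     with create_session() as session:
#         tis = session.query(TaskInstance).filter(TaskInstance.dag_id == dag.dag_id).all()
#         clear_task_instances(tis, session, dag=dag, dag_run_state=DagRunState.QUEUED)
#
# Running task instances are moved to RESTARTING, the rest are reset to no state,
# and the affected DagRuns are re-queued unless dag_run_state=False is passed.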

283 

284 

285def _is_mappable_value(value: Any) -> TypeGuard[Collection]: 

286 """Whether a value can be used for task mapping. 

287 

288 We only allow collections with guaranteed ordering, but exclude character 

289 sequences since that's usually not what users would expect to be mappable. 

290 """ 

291 if not isinstance(value, (collections.abc.Sequence, dict)): 

292 return False 

293 if isinstance(value, (bytearray, bytes, str)): 

294 return False 

295 return True 
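# Illustrative examples of the predicate above (comments only, not executed):
#
#     _is_mappable_value([1, 2, 3])   -> True   (ordered sequence)
#     _is_mappable_value({"a": 1})    -> True   (dict is allowed)
#     _is_mappable_value("abc")       -> False  (character sequences excluded)
#     _is_mappable_value(b"abc")      -> False  (bytes excluded)
#     _is_mappable_value({1, 2, 3})   -> False  (sets have no guaranteed ordering)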

296 

297 

298class TaskInstanceKey(NamedTuple): 

299 """Key used to identify task instance.""" 

300 

301 dag_id: str 

302 task_id: str 

303 run_id: str 

304 try_number: int = 1 

305 map_index: int = -1 

306 

307 @property 

308 def primary(self) -> tuple[str, str, str, int]: 

309 """Return task instance primary key part of the key""" 

310 return self.dag_id, self.task_id, self.run_id, self.map_index 

311 

312 @property 

313 def reduced(self) -> TaskInstanceKey: 

314 """Remake the key by subtracting 1 from try number to match in memory information""" 

315 return TaskInstanceKey( 

316 self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1), self.map_index 

317 ) 

318 

319 def with_try_number(self, try_number: int) -> TaskInstanceKey: 

320 """Returns TaskInstanceKey with provided ``try_number``""" 

321 return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number, self.map_index) 

322 

323 @property 

324 def key(self) -> TaskInstanceKey: 

325 """For API-compatibly with TaskInstance. 

326 

327 Returns self 

328 """ 

329 return self 
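# Doctest-style sketch (dag/task/run identifiers below are hypothetical):
#
#     >>> key = TaskInstanceKey("example_dag", "extract", "manual__2022-01-01", try_number=2)
#     >>> key.primary
#     ('example_dag', 'extract', 'manual__2022-01-01', -1)
#     >>> key.reduced.try_number   # matches the in-memory (pre-increment) try number
#     1
#     >>> key.with_try_number(5).try_number
#     5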

330 

331 

332def _creator_note(val): 

333 """Custom creator for the ``note`` association proxy.""" 

334 if isinstance(val, str): 

335 return TaskInstanceNote(content=val) 

336 elif isinstance(val, dict): 

337 return TaskInstanceNote(**val) 

338 else: 

339 return TaskInstanceNote(*val) 
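# The association proxy on TaskInstance.note calls this creator, so all of the
# following assignment styles work (sketch; `ti` is a hypothetical TaskInstance
# and the keyword names assume TaskInstanceNote's constructor):
#
#     ti.note = "needs manual review"                      # -> TaskInstanceNote(content=...)
#     ti.note = {"content": "needs review", "user_id": 1}  # -> TaskInstanceNote(**kwargs)
#     ti.note = ("needs review", 1)                        # -> TaskInstanceNote(*args)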

340 

341 

342class TaskInstance(Base, LoggingMixin): 

343 """ 

344 Task instances store the state of a task instance. This table is the 

345 authority and single source of truth around what tasks have run and the 

346 state they are in. 

347 

348 The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or 

349 dag model deliberately to have more control over transactions. 

350 

351 Database transactions on this table should guard against double triggers and

352 any confusion around which task instances are or aren't ready to run,

353 even while multiple schedulers may be firing task instances.

354 

355 A value of -1 in map_index represents any of: a TI without mapped tasks; 

356 a TI with mapped tasks that has yet to be expanded (state=pending); 

357 a TI with mapped tasks that expanded to an empty list (state=skipped). 

358 """ 

359 

360 __tablename__ = "task_instance" 

361 task_id = Column(StringID(), primary_key=True, nullable=False) 

362 dag_id = Column(StringID(), primary_key=True, nullable=False) 

363 run_id = Column(StringID(), primary_key=True, nullable=False) 

364 map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1")) 

365 

366 start_date = Column(UtcDateTime) 

367 end_date = Column(UtcDateTime) 

368 duration = Column(Float) 

369 state = Column(String(20)) 

370 _try_number = Column("try_number", Integer, default=0) 

371 max_tries = Column(Integer, server_default=text("-1")) 

372 hostname = Column(String(1000)) 

373 unixname = Column(String(1000)) 

374 job_id = Column(Integer) 

375 pool = Column(String(256), nullable=False) 

376 pool_slots = Column(Integer, default=1, nullable=False) 

377 queue = Column(String(256)) 

378 priority_weight = Column(Integer) 

379 operator = Column(String(1000)) 

380 queued_dttm = Column(UtcDateTime) 

381 queued_by_job_id = Column(Integer) 

382 pid = Column(Integer) 

383 executor_config = Column(ExecutorConfigType(pickler=dill)) 

384 updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow) 

385 

386 external_executor_id = Column(StringID()) 

387 

388 # The trigger to resume on if we are in state DEFERRED 

389 trigger_id = Column(Integer) 

390 

391 # Optional timeout datetime for the trigger (past this, we'll fail) 

392 trigger_timeout = Column(DateTime) 

393 # The trigger_timeout should be TIMESTAMP(using UtcDateTime) but for ease of 

394 # migration, we are keeping it as DateTime pending a change where expensive 

395 # migration is inevitable. 

396 

397 # The method to call next, and any extra arguments to pass to it. 

398 # Usually used when resuming from DEFERRED. 

399 next_method = Column(String(1000)) 

400 next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON)) 

401 

402 # If adding new fields here then remember to add them to 

403 # refresh_from_db() or they won't display in the UI correctly 

404 

405 __table_args__ = ( 

406 Index("ti_dag_state", dag_id, state), 

407 Index("ti_dag_run", dag_id, run_id), 

408 Index("ti_state", state), 

409 Index("ti_state_lkp", dag_id, task_id, run_id, state), 

410 Index("ti_pool", pool, state, priority_weight), 

411 Index("ti_job_id", job_id), 

412 Index("ti_trigger_id", trigger_id), 

413 PrimaryKeyConstraint( 

414 "dag_id", "task_id", "run_id", "map_index", name="task_instance_pkey", mssql_clustered=True 

415 ), 

416 ForeignKeyConstraint( 

417 [trigger_id], 

418 ["trigger.id"], 

419 name="task_instance_trigger_id_fkey", 

420 ondelete="CASCADE", 

421 ), 

422 ForeignKeyConstraint( 

423 [dag_id, run_id], 

424 ["dag_run.dag_id", "dag_run.run_id"], 

425 name="task_instance_dag_run_fkey", 

426 ondelete="CASCADE", 

427 ), 

428 ) 

429 

430 dag_model = relationship( 

431 "DagModel", 

432 primaryjoin="TaskInstance.dag_id == DagModel.dag_id", 

433 foreign_keys=dag_id, 

434 uselist=False, 

435 innerjoin=True, 

436 viewonly=True, 

437 ) 

438 

439 trigger = relationship("Trigger", uselist=False) 

440 triggerer_job = association_proxy("trigger", "triggerer_job") 

441 dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True) 

442 rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False) 

443 execution_date = association_proxy("dag_run", "execution_date") 

444 task_instance_note = relationship("TaskInstanceNote", back_populates="task_instance", uselist=False) 

445 note = association_proxy("task_instance_note", "content", creator=_creator_note) 

446 task: Operator # Not always set... 

447 

448 def __init__( 

449 self, 

450 task: Operator, 

451 execution_date: datetime | None = None, 

452 run_id: str | None = None, 

453 state: str | None = None, 

454 map_index: int = -1, 

455 ): 

456 super().__init__() 

457 self.dag_id = task.dag_id 

458 self.task_id = task.task_id 

459 self.map_index = map_index 

460 self.refresh_from_task(task) 

461 # init_on_load will config the log 

462 self.init_on_load() 

463 

464 if run_id is None and execution_date is not None: 

465 from airflow.models.dagrun import DagRun # Avoid circular import 

466 

467 warnings.warn( 

468 "Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id", 

469 RemovedInAirflow3Warning, 

470 # Stack level is 4 because SQLA adds some wrappers around the constructor 

471 stacklevel=4, 

472 ) 

473 # make sure we have a localized execution_date stored in UTC 

474 if execution_date and not timezone.is_localized(execution_date): 

475 self.log.warning( 

476 "execution date %s has no timezone information. Using default from dag or system", 

477 execution_date, 

478 ) 

479 if self.task.has_dag(): 

480 if TYPE_CHECKING: 

481 assert self.task.dag 

482 execution_date = timezone.make_aware(execution_date, self.task.dag.timezone) 

483 else: 

484 execution_date = timezone.make_aware(execution_date) 

485 

486 execution_date = timezone.convert_to_utc(execution_date) 

487 with create_session() as session: 

488 run_id = ( 

489 session.query(DagRun.run_id) 

490 .filter_by(dag_id=self.dag_id, execution_date=execution_date) 

491 .scalar() 

492 ) 

493 if run_id is None: 

494 raise DagRunNotFound( 

495 f"DagRun for {self.dag_id!r} with date {execution_date} not found" 

496 ) from None 

497 

498 self.run_id = run_id 

499 

500 self.try_number = 0 

501 self.max_tries = self.task.retries 

502 self.unixname = getuser() 

503 if state: 

504 self.state = state 

505 self.hostname = "" 

506 # Is this TaskInstance being currently running within `airflow tasks run --raw`. 

507 # Not persisted to the database so only valid for the current process 

508 self.raw = False 

509 # can be changed when calling 'run' 

510 self.test_mode = False 

511 

512 @staticmethod 

513 def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any]: 

514 """:meta private:""" 

515 return { 

516 "dag_id": task.dag_id, 

517 "task_id": task.task_id, 

518 "run_id": run_id, 

519 "_try_number": 0, 

520 "hostname": "", 

521 "unixname": getuser(), 

522 "queue": task.queue, 

523 "pool": task.pool, 

524 "pool_slots": task.pool_slots, 

525 "priority_weight": task.priority_weight_total, 

526 "run_as_user": task.run_as_user, 

527 "max_tries": task.retries, 

528 "executor_config": task.executor_config, 

529 "operator": task.task_type, 

530 "map_index": map_index, 

531 } 

532 

533 @reconstructor 

534 def init_on_load(self) -> None: 

535 """Initialize the attributes that aren't stored in the DB""" 

536 # correctly config the ti log 

537 self._log = logging.getLogger("airflow.task") 

538 self.test_mode = False # can be changed when calling 'run' 

539 

540 @property 

541 def try_number(self): 

542 """ 

543 Return the try number that this task instance will have when it is

544 actually run.

545 

546 If the TaskInstance is currently running, this will match the column in the 

547 database, in all other cases this will be incremented. 

548 """ 

549 # This is designed so that task logs end up in the right file. 

550 if self.state in State.running: 

551 return self._try_number 

552 return self._try_number + 1 
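# Behaviour sketch for the property above (values are illustrative):
#
#     _try_number (DB column) == 2, state == RUNNING  -> try_number == 2
#     _try_number (DB column) == 2, state == SUCCESS  -> try_number == 3
#
# i.e. outside of a running attempt the property reports the attempt that would
# happen next, which keeps task log filenames aligned with attempts.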

553 

554 @try_number.setter 

555 def try_number(self, value: int) -> None: 

556 self._try_number = value 

557 

558 @property 

559 def prev_attempted_tries(self) -> int: 

560 """ 

561 Based on this instance's try_number, this will calculate 

562 the number of previously attempted tries, defaulting to 0. 

563 """ 

564 # Expose this for the Task Tries and Gantt graph views. 

565 # Using `try_number` throws off the counts for non-running tasks. 

566 # Also useful in error logging contexts to get 

567 # the try number for the last try that was attempted. 

568 # https://issues.apache.org/jira/browse/AIRFLOW-2143 

569 

570 return self._try_number 

571 

572 @property 

573 def next_try_number(self) -> int: 

574 return self._try_number + 1 

575 

576 def command_as_list( 

577 self, 

578 mark_success=False, 

579 ignore_all_deps=False, 

580 ignore_task_deps=False, 

581 ignore_depends_on_past=False, 

582 ignore_ti_state=False, 

583 local=False, 

584 pickle_id=None, 

585 raw=False, 

586 job_id=None, 

587 pool=None, 

588 cfg_path=None, 

589 ): 

590 """ 

591 Returns a command that can be executed anywhere Airflow is installed.

592 This command is part of the message sent to executors by

593 the orchestrator.

594 """ 

595 dag: DAG | DagModel 

596 # Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded 

597 if hasattr(self, "task") and hasattr(self.task, "dag"): 

598 dag = self.task.dag 

599 else: 

600 dag = self.dag_model 

601 

602 should_pass_filepath = not pickle_id and dag 

603 path = None 

604 if should_pass_filepath: 

605 if dag.is_subdag: 

606 path = dag.parent_dag.relative_fileloc 

607 else: 

608 path = dag.relative_fileloc 

609 

610 if path: 

611 if not path.is_absolute(): 

612 path = "DAGS_FOLDER" / path 

613 path = str(path) 

614 

615 return TaskInstance.generate_command( 

616 self.dag_id, 

617 self.task_id, 

618 run_id=self.run_id, 

619 mark_success=mark_success, 

620 ignore_all_deps=ignore_all_deps, 

621 ignore_task_deps=ignore_task_deps, 

622 ignore_depends_on_past=ignore_depends_on_past, 

623 ignore_ti_state=ignore_ti_state, 

624 local=local, 

625 pickle_id=pickle_id, 

626 file_path=path, 

627 raw=raw, 

628 job_id=job_id, 

629 pool=pool, 

630 cfg_path=cfg_path, 

631 map_index=self.map_index, 

632 ) 

633 

634 @staticmethod 

635 def generate_command( 

636 dag_id: str, 

637 task_id: str, 

638 run_id: str, 

639 mark_success: bool = False, 

640 ignore_all_deps: bool = False, 

641 ignore_depends_on_past: bool = False, 

642 ignore_task_deps: bool = False, 

643 ignore_ti_state: bool = False, 

644 local: bool = False, 

645 pickle_id: int | None = None, 

646 file_path: str | None = None, 

647 raw: bool = False, 

648 job_id: str | None = None, 

649 pool: str | None = None, 

650 cfg_path: str | None = None, 

651 map_index: int = -1, 

652 ) -> list[str]: 

653 """ 

654 Generates the shell command required to execute this task instance. 

655 

656 :param dag_id: DAG ID 

657 :param task_id: Task ID 

658 :param run_id: The run_id of this task's DagRun 

659 :param mark_success: Whether to mark the task as successful 

660 :param ignore_all_deps: Ignore all ignorable dependencies. 

661 Overrides the other ignore_* parameters. 

662 :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs 

663 (e.g. for Backfills) 

664 :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past 

665 and trigger rule 

666 :param ignore_ti_state: Ignore the task instance's previous failure/success 

667 :param local: Whether to run the task locally 

668 :param pickle_id: If the DAG was serialized to the DB, the ID 

669 associated with the pickled DAG 

670 :param file_path: path to the file containing the DAG definition 

671 :param raw: raw mode (needs more details) 

672 :param job_id: job ID (needs more details) 

673 :param pool: the Airflow pool that the task should run in 

674 :param cfg_path: the Path to the configuration file 

675 :return: shell command that can be used to run the task instance 

676 """ 

677 cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id] 

678 if mark_success: 

679 cmd.extend(["--mark-success"]) 

680 if pickle_id: 

681 cmd.extend(["--pickle", str(pickle_id)]) 

682 if job_id: 

683 cmd.extend(["--job-id", str(job_id)]) 

684 if ignore_all_deps: 

685 cmd.extend(["--ignore-all-dependencies"]) 

686 if ignore_task_deps: 

687 cmd.extend(["--ignore-dependencies"]) 

688 if ignore_depends_on_past: 

689 cmd.extend(["--ignore-depends-on-past"]) 

690 if ignore_ti_state: 

691 cmd.extend(["--force"]) 

692 if local: 

693 cmd.extend(["--local"]) 

694 if pool: 

695 cmd.extend(["--pool", pool]) 

696 if raw: 

697 cmd.extend(["--raw"]) 

698 if file_path: 

699 cmd.extend(["--subdir", file_path]) 

700 if cfg_path: 

701 cmd.extend(["--cfg-path", cfg_path]) 

702 if map_index != -1: 

703 cmd.extend(["--map-index", str(map_index)]) 

704 return cmd 
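# Illustrative call and result (identifiers are made up for the example):
#
#     >>> TaskInstance.generate_command(
#     ...     "example_dag", "extract", run_id="manual__2022-01-01",
#     ...     local=True, pool="default_pool", map_index=2,
#     ... )
#     ['airflow', 'tasks', 'run', 'example_dag', 'extract', 'manual__2022-01-01',
#      '--local', '--pool', 'default_pool', '--map-index', '2']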

705 

706 @property 

707 def log_url(self) -> str: 

708 """Log URL for TaskInstance""" 

709 iso = quote(self.execution_date.isoformat()) 

710 base_url = conf.get_mandatory_value("webserver", "BASE_URL") 

711 return ( 

712 f"{base_url}/log" 

713 f"?execution_date={iso}" 

714 f"&task_id={self.task_id}" 

715 f"&dag_id={self.dag_id}" 

716 f"&map_index={self.map_index}" 

717 ) 
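# Example of the resulting URL shape (base URL and ids are hypothetical; the
# execution date is percent-encoded by quote()):
#
#     http://localhost:8080/log?execution_date=2022-01-01T00%3A00%3A00%2B00%3A00
#         &task_id=extract&dag_id=example_dag&map_index=-1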

718 

719 @property 

720 def mark_success_url(self) -> str: 

721 """URL to mark TI success""" 

722 base_url = conf.get_mandatory_value("webserver", "BASE_URL") 

723 return ( 

724 f"{base_url}/confirm" 

725 f"?task_id={self.task_id}" 

726 f"&dag_id={self.dag_id}" 

727 f"&dag_run_id={quote(self.run_id)}" 

728 "&upstream=false" 

729 "&downstream=false" 

730 "&state=success" 

731 ) 

732 

733 @provide_session 

734 def current_state(self, session: Session = NEW_SESSION) -> str: 

735 """ 

736 Get the very latest state from the database. If a session is passed, we

737 use it and looking up the state becomes part of that session; otherwise

738 a new session is used.

739 

740 sqlalchemy.inspect is used here to get the primary keys ensuring that if they change 

741 it will not regress 

742 

743 :param session: SQLAlchemy ORM Session 

744 """ 

745 filters = (col == getattr(self, col.name) for col in inspect(TaskInstance).primary_key) 

746 return session.query(TaskInstance.state).filter(*filters).scalar() 

747 

748 @provide_session 

749 def error(self, session: Session = NEW_SESSION) -> None: 

750 """ 

751 Forces the task instance's state to FAILED in the database. 

752 

753 :param session: SQLAlchemy ORM Session 

754 """ 

755 self.log.error("Recording the task instance as FAILED") 

756 self.state = State.FAILED 

757 session.merge(self) 

758 session.commit() 

759 

760 @provide_session 

761 def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None: 

762 """ 

763 Refreshes the task instance from the database based on the primary key 

764 

765 :param session: SQLAlchemy ORM Session 

766 :param lock_for_update: if True, indicates that the database should 

767 lock the TaskInstance (issuing a FOR UPDATE clause) until the 

768 session is committed. 

769 """ 

770 self.log.debug("Refreshing TaskInstance %s from DB", self) 

771 

772 if self in session: 

773 session.refresh(self, TaskInstance.__mapper__.column_attrs.keys()) 

774 

775 qry = ( 

776 # To avoid joining any relationships, by default select all 

777 # columns, not the object. This also means we get (effectively) a 

778 # namedtuple back, not a TI object 

779 session.query(*TaskInstance.__table__.columns).filter( 

780 TaskInstance.dag_id == self.dag_id, 

781 TaskInstance.task_id == self.task_id, 

782 TaskInstance.run_id == self.run_id, 

783 TaskInstance.map_index == self.map_index, 

784 ) 

785 ) 

786 

787 if lock_for_update: 

788 for attempt in run_with_db_retries(logger=self.log): 

789 with attempt: 

790 ti: TaskInstance | None = qry.with_for_update().one_or_none() 

791 else: 

792 ti = qry.one_or_none() 

793 if ti: 

794 # Fields ordered per model definition 

795 self.start_date = ti.start_date 

796 self.end_date = ti.end_date 

797 self.duration = ti.duration 

798 self.state = ti.state 

799 # Since we selected columns, not the object, this is the raw value 

800 self.try_number = ti.try_number 

801 self.max_tries = ti.max_tries 

802 self.hostname = ti.hostname 

803 self.unixname = ti.unixname 

804 self.job_id = ti.job_id 

805 self.pool = ti.pool 

806 self.pool_slots = ti.pool_slots or 1 

807 self.queue = ti.queue 

808 self.priority_weight = ti.priority_weight 

809 self.operator = ti.operator 

810 self.queued_dttm = ti.queued_dttm 

811 self.queued_by_job_id = ti.queued_by_job_id 

812 self.pid = ti.pid 

813 self.executor_config = ti.executor_config 

814 self.external_executor_id = ti.external_executor_id 

815 self.trigger_id = ti.trigger_id 

816 self.next_method = ti.next_method 

817 self.next_kwargs = ti.next_kwargs 

818 else: 

819 self.state = None 

820 

821 def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None: 

822 """ 

823 Copy common attributes from the given task. 

824 

825 :param task: The task object to copy from 

826 :param pool_override: Use the pool_override instead of task's pool 

827 """ 

828 self.task = task 

829 self.queue = task.queue 

830 self.pool = pool_override or task.pool 

831 self.pool_slots = task.pool_slots 

832 self.priority_weight = task.priority_weight_total 

833 self.run_as_user = task.run_as_user 

834 # Do not set max_tries to task.retries here because max_tries is a cumulative 

835 # value that needs to be stored in the db. 

836 self.executor_config = task.executor_config 

837 self.operator = task.task_type 

838 

839 @provide_session 

840 def clear_xcom_data(self, session: Session = NEW_SESSION) -> None: 

841 """Clear all XCom data from the database for the task instance. 

842 

843 If the task is unmapped, all XComs matching this task ID in the same DAG 

844 run are removed. If the task is mapped, only the one with matching map 

845 index is removed. 

846 

847 :param session: SQLAlchemy ORM Session 

848 """ 

849 self.log.debug("Clearing XCom data") 

850 if self.map_index < 0: 

851 map_index: int | None = None 

852 else: 

853 map_index = self.map_index 

854 XCom.clear( 

855 dag_id=self.dag_id, 

856 task_id=self.task_id, 

857 run_id=self.run_id, 

858 map_index=map_index, 

859 session=session, 

860 ) 

861 

862 @property 

863 def key(self) -> TaskInstanceKey: 

864 """Returns a tuple that identifies the task instance uniquely""" 

865 return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index) 

866 

867 @provide_session 

868 def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool: 

869 """ 

870 Set TaskInstance state. 

871 

872 :param state: State to set for the TI 

873 :param session: SQLAlchemy ORM Session 

874 :return: Was the state changed 

875 """ 

876 if self.state == state: 

877 return False 

878 

879 current_time = timezone.utcnow() 

880 self.log.debug("Setting task state for %s to %s", self, state) 

881 self.state = state 

882 self.start_date = self.start_date or current_time 

883 if self.state in State.finished or self.state == State.UP_FOR_RETRY: 

884 self.end_date = self.end_date or current_time 

885 self.duration = (self.end_date - self.start_date).total_seconds() 

886 session.merge(self) 

887 return True 

888 

889 @property 

890 def is_premature(self) -> bool: 

891 """ 

892 Returns whether a task is in UP_FOR_RETRY state and its retry interval 

893 has elapsed. 

894 """ 

895 # is the task still in the retry waiting period? 

896 return self.state == State.UP_FOR_RETRY and not self.ready_for_retry() 

897 

898 @provide_session 

899 def are_dependents_done(self, session: Session = NEW_SESSION) -> bool: 

900 """ 

901 Checks whether the immediate dependents of this task instance have succeeded or have been skipped. 

902 This is meant to be used by wait_for_downstream. 

903 

904 This is useful when you do not want to start processing the next 

905 schedule of a task until the dependents are done. For instance, 

906 if the task DROPs and recreates a table. 

907 

908 :param session: SQLAlchemy ORM Session 

909 """ 

910 task = self.task 

911 

912 if not task.downstream_task_ids: 

913 return True 

914 

915 ti = session.query(func.count(TaskInstance.task_id)).filter( 

916 TaskInstance.dag_id == self.dag_id, 

917 TaskInstance.task_id.in_(task.downstream_task_ids), 

918 TaskInstance.run_id == self.run_id, 

919 TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]), 

920 ) 

921 count = ti[0][0] 

922 return count == len(task.downstream_task_ids) 

923 

924 @provide_session 

925 def get_previous_dagrun( 

926 self, 

927 state: DagRunState | None = None, 

928 session: Session | None = None, 

929 ) -> DagRun | None: 

930 """The DagRun that ran before this task instance's DagRun. 

931 

932 :param state: If passed, it only takes into account instances of a specific state.

933 :param session: SQLAlchemy ORM Session. 

934 """ 

935 dag = self.task.dag 

936 if dag is None: 

937 return None 

938 

939 dr = self.get_dagrun(session=session) 

940 dr.dag = dag 

941 

942 # We always ignore schedule in dagrun lookup when `state` is given 

943 # or the DAG is never scheduled. For legacy reasons, when 

944 # `catchup=True`, we use `get_previous_scheduled_dagrun` unless 

945 # `ignore_schedule` is `True`. 

946 ignore_schedule = state is not None or not dag.timetable.can_run 

947 if dag.catchup is True and not ignore_schedule: 

948 last_dagrun = dr.get_previous_scheduled_dagrun(session=session) 

949 else: 

950 last_dagrun = dr.get_previous_dagrun(session=session, state=state) 

951 

952 if last_dagrun: 

953 return last_dagrun 

954 

955 return None 

956 

957 @provide_session 

958 def get_previous_ti( 

959 self, 

960 state: DagRunState | None = None, 

961 session: Session = NEW_SESSION, 

962 ) -> TaskInstance | None: 

963 """ 

964 The task instance for the task that ran before this task instance. 

965 

966 :param state: If passed, it only takes into account instances of a specific state.

967 :param session: SQLAlchemy ORM Session 

968 """ 

969 dagrun = self.get_previous_dagrun(state, session=session) 

970 if dagrun is None: 

971 return None 

972 return dagrun.get_task_instance(self.task_id, session=session) 

973 

974 @property 

975 def previous_ti(self) -> TaskInstance | None: 

976 """ 

977 This attribute is deprecated. 

978 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method. 

979 """ 

980 warnings.warn( 

981 """ 

982 This attribute is deprecated. 

983 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method. 

984 """, 

985 RemovedInAirflow3Warning, 

986 stacklevel=2, 

987 ) 

988 return self.get_previous_ti() 

989 

990 @property 

991 def previous_ti_success(self) -> TaskInstance | None: 

992 """ 

993 This attribute is deprecated. 

994 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method. 

995 """ 

996 warnings.warn( 

997 """ 

998 This attribute is deprecated. 

999 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method. 

1000 """, 

1001 RemovedInAirflow3Warning, 

1002 stacklevel=2, 

1003 ) 

1004 return self.get_previous_ti(state=DagRunState.SUCCESS) 

1005 

1006 @provide_session 

1007 def get_previous_execution_date( 

1008 self, 

1009 state: DagRunState | None = None, 

1010 session: Session = NEW_SESSION, 

1011 ) -> pendulum.DateTime | None: 

1012 """ 

1013 The execution date from property previous_ti_success. 

1014 

1015 :param state: If passed, it only takes into account instances of a specific state.

1016 :param session: SQLAlchemy ORM Session 

1017 """ 

1018 self.log.debug("previous_execution_date was called") 

1019 prev_ti = self.get_previous_ti(state=state, session=session) 

1020 return prev_ti and pendulum.instance(prev_ti.execution_date) 

1021 

1022 @provide_session 

1023 def get_previous_start_date( 

1024 self, state: DagRunState | None = None, session: Session = NEW_SESSION 

1025 ) -> pendulum.DateTime | None: 

1026 """ 

1027 The start date from property previous_ti_success. 

1028 

1029 :param state: If passed, it only takes into account instances of a specific state.

1030 :param session: SQLAlchemy ORM Session 

1031 """ 

1032 self.log.debug("previous_start_date was called") 

1033 prev_ti = self.get_previous_ti(state=state, session=session) 

1034 # prev_ti may not exist and prev_ti.start_date may be None. 

1035 return prev_ti and prev_ti.start_date and pendulum.instance(prev_ti.start_date) 

1036 

1037 @property 

1038 def previous_start_date_success(self) -> pendulum.DateTime | None: 

1039 """ 

1040 This attribute is deprecated. 

1041 Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method. 

1042 """ 

1043 warnings.warn( 

1044 """ 

1045 This attribute is deprecated. 

1046 Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method. 

1047 """, 

1048 RemovedInAirflow3Warning, 

1049 stacklevel=2, 

1050 ) 

1051 return self.get_previous_start_date(state=DagRunState.SUCCESS) 

1052 

1053 @provide_session 

1054 def are_dependencies_met( 

1055 self, dep_context: DepContext | None = None, session: Session = NEW_SESSION, verbose: bool = False 

1056 ) -> bool: 

1057 """ 

1058 Returns whether or not all the conditions are met for this task instance to be run 

1059 given the context for the dependencies (e.g. a task instance being force run from 

1060 the UI will ignore some dependencies). 

1061 

1062 :param dep_context: The execution context that determines the dependencies that 

1063 should be evaluated. 

1064 :param session: database session 

1065 :param verbose: whether log details on failed dependencies on 

1066 info or debug log level 

1067 """ 

1068 dep_context = dep_context or DepContext() 

1069 failed = False 

1070 verbose_aware_logger = self.log.info if verbose else self.log.debug 

1071 for dep_status in self.get_failed_dep_statuses(dep_context=dep_context, session=session): 

1072 failed = True 

1073 

1074 verbose_aware_logger( 

1075 "Dependencies not met for %s, dependency '%s' FAILED: %s", 

1076 self, 

1077 dep_status.dep_name, 

1078 dep_status.reason, 

1079 ) 

1080 

1081 if failed: 

1082 return False 

1083 

1084 verbose_aware_logger("Dependencies all met for %s", self) 

1085 return True 

1086 

1087 @provide_session 

1088 def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION): 

1089 """Get failed Dependencies""" 

1090 dep_context = dep_context or DepContext() 

1091 for dep in dep_context.deps | self.task.deps: 

1092 for dep_status in dep.get_dep_statuses(self, session, dep_context): 

1093 

1094 self.log.debug( 

1095 "%s dependency '%s' PASSED: %s, %s", 

1096 self, 

1097 dep_status.dep_name, 

1098 dep_status.passed, 

1099 dep_status.reason, 

1100 ) 

1101 

1102 if not dep_status.passed: 

1103 yield dep_status 

1104 

1105 def __repr__(self) -> str: 

1106 prefix = f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} " 

1107 if self.map_index != -1: 

1108 prefix += f"map_index={self.map_index} " 

1109 return prefix + f"[{self.state}]>" 

1110 

1111 def next_retry_datetime(self): 

1112 """ 

1113 Get datetime of the next retry if the task instance fails. For exponential 

1114 backoff, retry_delay is used as base and will be converted to seconds. 

1115 """ 

1116 from airflow.models.abstractoperator import MAX_RETRY_DELAY 

1117 

1118 delay = self.task.retry_delay 

1119 if self.task.retry_exponential_backoff: 

1120 # If the min_backoff calculation is below 1, it will be converted to 0 via int. Thus, 

1121 # we must round up prior to converting to an int, otherwise a divide by zero error 

1122 # will occur in the modded_hash calculation. 

1123 min_backoff = int(math.ceil(delay.total_seconds() * (2 ** (self.try_number - 2)))) 

1124 

1125 # In the case when delay.total_seconds() is 0, min_backoff will not be rounded up to 1. 

1126 # To address this, we impose a lower bound of 1 on min_backoff. This effectively makes 

1127 # the ceiling function unnecessary, but the ceiling function was retained to avoid 

1128 # introducing a breaking change. 

1129 if min_backoff < 1: 

1130 min_backoff = 1 

1131 

1132 # deterministic per task instance 

1133 ti_hash = int( 

1134 hashlib.sha1( 

1135 f"{self.dag_id}#{self.task_id}#{self.execution_date}#{self.try_number}".encode() 

1136 ).hexdigest(), 

1137 16, 

1138 ) 

1139 # between 1 and 1.0 * delay * (2^retry_number) 

1140 modded_hash = min_backoff + ti_hash % min_backoff 

1141 # timedelta has a maximum representable value. The exponentiation 

1142 # here means this value can be exceeded after a certain number 

1143 # of tries (around 50 if the initial delay is 1s, even fewer if 

1144 # the delay is larger). Cap the value here before creating a 

1145 # timedelta object so the operation doesn't fail with "OverflowError". 

1146 delay_backoff_in_seconds = min(modded_hash, MAX_RETRY_DELAY) 

1147 delay = timedelta(seconds=delay_backoff_in_seconds) 

1148 if self.task.max_retry_delay: 

1149 delay = min(self.task.max_retry_delay, delay) 

1150 return self.end_date + delay 
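# Worked example of the exponential backoff above (numbers are illustrative):
# with retry_delay=300s and try_number=3, min_backoff = ceil(300 * 2**(3-2)) = 600,
# and modded_hash = 600 + (ti_hash % 600), i.e. somewhere in [600, 1199] seconds.
# The next retry is therefore scheduled roughly 10-20 minutes after end_date,
# subject to the MAX_RETRY_DELAY cap and the task's optional max_retry_delay.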

1151 

1152 def ready_for_retry(self) -> bool: 

1153 """ 

1154 Checks on whether the task instance is in the right state and timeframe 

1155 to be retried. 

1156 """ 

1157 return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow() 

1158 

1159 @provide_session 

1160 def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun: 

1161 """ 

1162 Returns the DagRun for this TaskInstance 

1163 

1164 :param session: SQLAlchemy ORM Session 

1165 :return: DagRun 

1166 """ 

1167 info = inspect(self) 

1168 if info.attrs.dag_run.loaded_value is not NO_VALUE: 

1169 return self.dag_run 

1170 

1171 from airflow.models.dagrun import DagRun # Avoid circular import 

1172 

1173 dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one() 

1174 

1175 # Record it in the instance for next time. This means that `self.execution_date` will work correctly 

1176 set_committed_value(self, "dag_run", dr) 

1177 

1178 return dr 

1179 

1180 @provide_session 

1181 def check_and_change_state_before_execution( 

1182 self, 

1183 verbose: bool = True, 

1184 ignore_all_deps: bool = False, 

1185 ignore_depends_on_past: bool = False, 

1186 ignore_task_deps: bool = False, 

1187 ignore_ti_state: bool = False, 

1188 mark_success: bool = False, 

1189 test_mode: bool = False, 

1190 job_id: str | None = None, 

1191 pool: str | None = None, 

1192 external_executor_id: str | None = None, 

1193 session: Session = NEW_SESSION, 

1194 ) -> bool: 

1195 """ 

1196 Checks dependencies and then sets state to RUNNING if they are met. Returns

1197 True if and only if state is set to RUNNING, which implies that the task

1198 should be executed, in preparation for _run_raw_task.

1199 

1200 :param verbose: whether to turn on more verbose logging 

1201 :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs 

1202 :param ignore_depends_on_past: Ignore depends_on_past DAG attribute 

1203 :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task 

1204 :param ignore_ti_state: Disregards previous task instance state 

1205 :param mark_success: Don't run the task, mark its state as success 

1206 :param test_mode: Doesn't record success or failure in the DB 

1207 :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID 

1208 :param pool: specifies the pool to use to run the task instance 

1209 :param external_executor_id: The identifier of the celery executor 

1210 :param session: SQLAlchemy ORM Session 

1211 :return: whether the state was changed to running or not 

1212 """ 

1213 task = self.task 

1214 self.refresh_from_task(task, pool_override=pool) 

1215 self.test_mode = test_mode 

1216 self.refresh_from_db(session=session, lock_for_update=True) 

1217 self.job_id = job_id 

1218 self.hostname = get_hostname() 

1219 self.pid = None 

1220 

1221 if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS: 

1222 Stats.incr("previously_succeeded", 1, 1) 

1223 

1224 # TODO: Logging needs cleanup, not clear what is being printed 

1225 hr_line_break = "\n" + ("-" * 80) # Line break 

1226 

1227 if not mark_success: 

1228 # Firstly find non-runnable and non-requeueable tis. 

1229 # Since mark_success is not set, we do nothing. 

1230 non_requeueable_dep_context = DepContext( 

1231 deps=RUNNING_DEPS - REQUEUEABLE_DEPS, 

1232 ignore_all_deps=ignore_all_deps, 

1233 ignore_ti_state=ignore_ti_state, 

1234 ignore_depends_on_past=ignore_depends_on_past, 

1235 ignore_task_deps=ignore_task_deps, 

1236 ) 

1237 if not self.are_dependencies_met( 

1238 dep_context=non_requeueable_dep_context, session=session, verbose=True 

1239 ): 

1240 session.commit() 

1241 return False 

1242 

1243 # For reporting purposes, we report based on 1-indexed, 

1244 # not 0-indexed lists (i.e. Attempt 1 instead of 

1245 # Attempt 0 for the first attempt). 

1246 # Set the task start date. In case it was re-scheduled use the initial 

1247 # start date that is recorded in task_reschedule table 

1248 # If the task continues after being deferred (next_method is set), use the original start_date 

1249 self.start_date = self.start_date if self.next_method else timezone.utcnow() 

1250 if self.state == State.UP_FOR_RESCHEDULE: 

1251 task_reschedule: TR = TR.query_for_task_instance(self, session=session).first() 

1252 if task_reschedule: 

1253 self.start_date = task_reschedule.start_date 

1254 

1255 # Secondly we find non-runnable but requeueable tis. We reset its state. 

1256 # This is because we might have hit concurrency limits, 

1257 # e.g. because of backfilling. 

1258 dep_context = DepContext( 

1259 deps=REQUEUEABLE_DEPS, 

1260 ignore_all_deps=ignore_all_deps, 

1261 ignore_depends_on_past=ignore_depends_on_past, 

1262 ignore_task_deps=ignore_task_deps, 

1263 ignore_ti_state=ignore_ti_state, 

1264 ) 

1265 if not self.are_dependencies_met(dep_context=dep_context, session=session, verbose=True): 

1266 self.state = State.NONE 

1267 self.log.warning(hr_line_break) 

1268 self.log.warning( 

1269 "Rescheduling due to concurrency limits reached " 

1270 "at task runtime. Attempt %s of " 

1271 "%s. State set to NONE.", 

1272 self.try_number, 

1273 self.max_tries + 1, 

1274 ) 

1275 self.log.warning(hr_line_break) 

1276 self.queued_dttm = timezone.utcnow() 

1277 session.merge(self) 

1278 session.commit() 

1279 return False 

1280 

1281 # print status message 

1282 self.log.info(hr_line_break) 

1283 self.log.info("Starting attempt %s of %s", self.try_number, self.max_tries + 1) 

1284 self.log.info(hr_line_break) 

1285 self._try_number += 1 

1286 

1287 if not test_mode: 

1288 session.add(Log(State.RUNNING, self)) 

1289 self.state = State.RUNNING 

1290 self.external_executor_id = external_executor_id 

1291 self.end_date = None 

1292 if not test_mode: 

1293 session.merge(self).task = task 

1294 session.commit() 

1295 

1296 # Closing all pooled connections to prevent 

1297 # "max number of connections reached" 

1298 settings.engine.dispose() # type: ignore 

1299 if verbose: 

1300 if mark_success: 

1301 self.log.info("Marking success for %s on %s", self.task, self.execution_date) 

1302 else: 

1303 self.log.info("Executing %s on %s", self.task, self.execution_date) 

1304 return True 

1305 

1306 def _date_or_empty(self, attr: str) -> str: 

1307 result: datetime | None = getattr(self, attr, None) 

1308 return result.strftime("%Y%m%dT%H%M%S") if result else "" 

1309 

1310 def _log_state(self, lead_msg: str = "") -> None: 

1311 params = [ 

1312 lead_msg, 

1313 str(self.state).upper(), 

1314 self.dag_id, 

1315 self.task_id, 

1316 ] 

1317 message = "%sMarking task as %s. dag_id=%s, task_id=%s, " 

1318 if self.map_index >= 0: 

1319 params.append(self.map_index) 

1320 message += "map_index=%d, " 

1321 self.log.info( 

1322 message + "execution_date=%s, start_date=%s, end_date=%s", 

1323 *params, 

1324 self._date_or_empty("execution_date"), 

1325 self._date_or_empty("start_date"), 

1326 self._date_or_empty("end_date"), 

1327 ) 

1328 

1329 # Ensure we unset next_method and next_kwargs to ensure that any 

1330 # retries don't re-use them. 

1331 def clear_next_method_args(self) -> None: 

1332 self.log.debug("Clearing next_method and next_kwargs.") 

1333 

1334 self.next_method = None 

1335 self.next_kwargs = None 

1336 

1337 @provide_session 

1338 @Sentry.enrich_errors 

1339 def _run_raw_task( 

1340 self, 

1341 mark_success: bool = False, 

1342 test_mode: bool = False, 

1343 job_id: str | None = None, 

1344 pool: str | None = None, 

1345 session: Session = NEW_SESSION, 

1346 ) -> None: 

1347 """ 

1348 Immediately runs the task (without checking or changing db state 

1349 before execution) and then sets the appropriate final state after 

1350 completion and runs any post-execute callbacks. Meant to be called 

1351 only after another function changes the state to running. 

1352 

1353 :param mark_success: Don't run the task, mark its state as success 

1354 :param test_mode: Doesn't record success or failure in the DB 

1355 :param pool: specifies the pool to use to run the task instance 

1356 :param session: SQLAlchemy ORM Session 

1357 """ 

1358 self.test_mode = test_mode 

1359 self.refresh_from_task(self.task, pool_override=pool) 

1360 self.refresh_from_db(session=session) 

1361 self.job_id = job_id 

1362 self.hostname = get_hostname() 

1363 self.pid = os.getpid() 

1364 if not test_mode: 

1365 session.merge(self) 

1366 session.commit() 

1367 actual_start_date = timezone.utcnow() 

1368 Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}") 

1369 # Initialize final state counters at zero 

1370 for state in State.task_states: 

1371 Stats.incr(f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}", count=0) 

1372 

1373 self.task = self.task.prepare_for_execution() 

1374 context = self.get_template_context(ignore_param_exceptions=False) 

1375 try: 

1376 if not mark_success: 

1377 self._execute_task_with_callbacks(context, test_mode) 

1378 if not test_mode: 

1379 self.refresh_from_db(lock_for_update=True, session=session) 

1380 self.state = State.SUCCESS 

1381 except TaskDeferred as defer: 

1382 # The task has signalled it wants to defer execution based on 

1383 # a trigger. 

1384 self._defer_task(defer=defer, session=session) 

1385 self.log.info( 

1386 "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s", 

1387 self.dag_id, 

1388 self.task_id, 

1389 self._date_or_empty("execution_date"), 

1390 self._date_or_empty("start_date"), 

1391 ) 

1392 if not test_mode: 

1393 session.add(Log(self.state, self)) 

1394 session.merge(self) 

1395 session.commit() 

1396 return 

1397 except AirflowSkipException as e: 

1398 # Recording SKIP 

1399 # log only if exception has any arguments to prevent log flooding 

1400 if e.args: 

1401 self.log.info(e) 

1402 if not test_mode: 

1403 self.refresh_from_db(lock_for_update=True, session=session) 

1404 self.state = State.SKIPPED 

1405 except AirflowRescheduleException as reschedule_exception: 

1406 self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) 

1407 session.commit() 

1408 return 

1409 except (AirflowFailException, AirflowSensorTimeout) as e: 

1410 # If AirflowFailException is raised, task should not retry. 

1411 # If a sensor in reschedule mode reaches timeout, task should not retry. 

1412 self.handle_failure(e, test_mode, context, force_fail=True, session=session) 

1413 session.commit() 

1414 raise 

1415 except AirflowException as e: 

1416 if not test_mode: 

1417 self.refresh_from_db(lock_for_update=True, session=session) 

1418 # for case when task is marked as success/failed externally 

1419 # or dagrun timed out and task is marked as skipped 

1420 # current behavior doesn't hit the callbacks 

1421 if self.state in State.finished: 

1422 self.clear_next_method_args() 

1423 session.merge(self) 

1424 session.commit() 

1425 return 

1426 else: 

1427 self.handle_failure(e, test_mode, context, session=session) 

1428 session.commit() 

1429 raise 

1430 except (Exception, KeyboardInterrupt) as e: 

1431 self.handle_failure(e, test_mode, context, session=session) 

1432 session.commit() 

1433 raise 

1434 finally: 

1435 Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}") 

1436 

1437 # Recording SKIPPED or SUCCESS 

1438 self.clear_next_method_args() 

1439 self.end_date = timezone.utcnow() 

1440 self._log_state() 

1441 self.set_duration() 

1442 

1443 # run on_success_callback before db committing 

1444 # otherwise, the LocalTaskJob sees the state change to `success` while

1445 # the task_runner is still running, and treats the state as having been set externally!

1446 self._run_finished_callback(self.task.on_success_callback, context, "on_success") 

1447 

1448 if not test_mode: 

1449 session.add(Log(self.state, self)) 

1450 session.merge(self).task = self.task 

1451 if self.state == TaskInstanceState.SUCCESS: 

1452 self._register_dataset_changes(session=session) 

1453 session.commit() 

1454 

1455 def _register_dataset_changes(self, *, session: Session) -> None: 

1456 for obj in self.task.outlets or []: 

1457 self.log.debug("outlet obj %s", obj) 

1458 # Lineage can have other types of objects besides datasets 

1459 if isinstance(obj, Dataset): 

1460 dataset_manager.register_dataset_change( 

1461 task_instance=self, 

1462 dataset=obj, 

1463 session=session, 

1464 ) 

1465 

1466 def _execute_task_with_callbacks(self, context, test_mode=False): 

1467 """Prepare Task for Execution""" 

1468 from airflow.models.renderedtifields import RenderedTaskInstanceFields 

1469 

1470 parent_pid = os.getpid() 

1471 

1472 def signal_handler(signum, frame): 

1473 pid = os.getpid() 

1474 

1475 # If a task forks during execution (from DAG code) for whatever 

1476 # reason, we want to make sure that we react to the signal only in 

1477 # the process that we've spawned ourselves (referred to here as the 

1478 # parent process). 

1479 if pid != parent_pid: 

1480 os._exit(1) 

1481 return 

1482 self.log.error("Received SIGTERM. Terminating subprocesses.") 

1483 self.task.on_kill() 

1484 raise AirflowException("Task received SIGTERM signal") 

1485 

1486 signal.signal(signal.SIGTERM, signal_handler) 

1487 

1488 # Don't clear Xcom until the task is certain to execute, and check if we are resuming from deferral. 

1489 if not self.next_method: 

1490 self.clear_xcom_data() 

1491 

1492 with Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration"): 

1493 # Set the validated/merged params on the task object. 

1494 self.task.params = context["params"] 

1495 

1496 task_orig = self.render_templates(context=context) 

1497 if not test_mode: 

1498 rtif = RenderedTaskInstanceFields(ti=self, render_templates=False) 

1499 RenderedTaskInstanceFields.write(rtif) 

1500 RenderedTaskInstanceFields.delete_old_records(self.task_id, self.dag_id) 

1501 

1502 # Export context to make it available for operators to use. 

1503 airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) 

1504 os.environ.update(airflow_context_vars) 

1505 

1506 # Log context only for the default execution method, the assumption 

1507 # being that otherwise we're resuming a deferred task (in which 

1508 # case there's no need to log these again). 

1509 if not self.next_method: 

1510 self.log.info( 

1511 "Exporting the following env vars:\n%s", 

1512 "\n".join(f"{k}={v}" for k, v in airflow_context_vars.items()), 

1513 ) 

1514 

1515 # Run pre_execute callback 

1516 self.task.pre_execute(context=context) 

1517 

1518 # Run on_execute callback 

1519 self._run_execute_callback(context, self.task) 

1520 

1521 # Execute the task 

1522 with set_current_context(context): 

1523 result = self._execute_task(context, task_orig) 

1524 

1525 # Run post_execute callback 

1526 self.task.post_execute(context=context, result=result) 

1527 

1528 Stats.incr(f"operator_successes_{self.task.task_type}", 1, 1) 

1529 Stats.incr("ti_successes") 

1530 
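# --- Illustrative sketch (not part of taskinstance.py): the os.environ.update() above
# exposes the context as AIRFLOW_CTX_* variables to the running task. The exact set
# comes from context_to_airflow_vars(); only a few commonly used keys are read here.
import os

from airflow.decorators import task


@task
def show_airflow_ctx():
    # Exported by _execute_task_with_callbacks() before execute() is called.
    print(os.environ.get("AIRFLOW_CTX_DAG_ID"))
    print(os.environ.get("AIRFLOW_CTX_TASK_ID"))
    print(os.environ.get("AIRFLOW_CTX_DAG_RUN_ID"))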

1531 def _run_finished_callback( 

1532 self, 

1533 callbacks: None | TaskStateChangeCallback | list[TaskStateChangeCallback], 

1534 context: Context, 

1535 callback_type: str, 

1536 ) -> None: 

1537 """Run callback after task finishes""" 

1538 if callbacks: 

1539 callbacks = callbacks if isinstance(callbacks, list) else [callbacks] 

1540 for callback in callbacks: 

1541 try: 

1542 callback(context) 

1543 except Exception: 

1544 callback_name = qualname(callback).split(".")[-1] 

1545 self.log.exception( 

1546 f"Error when executing {callback_name} callback" # type: ignore[attr-defined] 

1547 ) 

1548 
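# --- Illustrative sketch (not part of taskinstance.py): the callables invoked by
# _run_finished_callback() are plain functions that receive the template context.
# The handler names and dag/task ids are made up.
from datetime import datetime

from airflow import DAG
from airflow.operators.empty import EmptyOperator


def _notify_success(context):
    # "context" is the same mapping built by get_template_context() further down.
    print(f"{context['ti'].task_id} finished successfully")


def _notify_failure(context):
    print(f"task failed: {context.get('exception')}")


with DAG(dag_id="callback_example", start_date=datetime(2022, 1, 1), schedule=None):
    EmptyOperator(
        task_id="noop",
        on_success_callback=_notify_success,
        on_failure_callback=_notify_failure,  # a list of callables is also accepted here
    )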

1549 def _execute_task(self, context, task_orig): 

1550 """Executes Task (optionally with a Timeout) and pushes Xcom results""" 

1551 task_to_execute = self.task 

1552 # If the task has been deferred and is being executed due to a trigger, 

1553 # then we need to pick the right method to come back to, otherwise 

1554 # we go for the default execute 

1555 if self.next_method: 

1556 # __fail__ is a special signal value for next_method that indicates 

1557 # this task was scheduled specifically to fail. 

1558 if self.next_method == "__fail__": 

1559 next_kwargs = self.next_kwargs or {} 

1560 traceback = self.next_kwargs.get("traceback") 

1561 if traceback is not None: 

1562 self.log.error("Trigger failed:\n%s", "\n".join(traceback)) 

1563 raise TaskDeferralError(next_kwargs.get("error", "Unknown")) 

1564 # Grab the callable off the Operator/Task and add in any kwargs 

1565 execute_callable = getattr(task_to_execute, self.next_method) 

1566 if self.next_kwargs: 

1567 execute_callable = partial(execute_callable, **self.next_kwargs) 

1568 else: 

1569 execute_callable = task_to_execute.execute 

1570 # If a timeout is specified for the task, make it fail 

1571 # if it runs beyond that limit

1572 if task_to_execute.execution_timeout: 

1573 # If we are coming in with a next_method (i.e. from a deferral), 

1574 # calculate the timeout from our start_date. 

1575 if self.next_method: 

1576 timeout_seconds = ( 

1577 task_to_execute.execution_timeout - (timezone.utcnow() - self.start_date) 

1578 ).total_seconds() 

1579 else: 

1580 timeout_seconds = task_to_execute.execution_timeout.total_seconds() 

1581 try: 

1582 # It's possible we're already timed out, so fast-fail if true 

1583 if timeout_seconds <= 0: 

1584 raise AirflowTaskTimeout() 

1585 # Run task in timeout wrapper 

1586 with timeout(timeout_seconds): 

1587 result = execute_callable(context=context) 

1588 except AirflowTaskTimeout: 

1589 task_to_execute.on_kill() 

1590 raise 

1591 else: 

1592 result = execute_callable(context=context) 

1593 with create_session() as session: 

1594 if task_to_execute.do_xcom_push: 

1595 xcom_value = result 

1596 else: 

1597 xcom_value = None 

1598 if xcom_value is not None: # If the task returns a result, push an XCom containing it. 

1599 self.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session) 

1600 self._record_task_map_for_downstreams(task_orig, xcom_value, session=session) 

1601 return result 

1602 
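# --- Illustrative sketch (not part of taskinstance.py): when do_xcom_push is left enabled,
# the value returned by execute() is pushed under XCOM_RETURN_KEY by _execute_task() above.
# Dag and task ids are made up.
from datetime import datetime

from airflow import DAG
from airflow.decorators import task

with DAG(dag_id="xcom_return_example", start_date=datetime(2022, 1, 1), schedule=None):

    @task
    def produce_value():
        return {"rows": 42}  # stored as the "return_value" XCom by the code above

    @task
    def consume(payload):
        # The TaskFlow argument is resolved from the upstream XCom at runtime.
        print(payload["rows"])

    consume(produce_value())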

1603 @provide_session 

1604 def _defer_task(self, session: Session, defer: TaskDeferred) -> None: 

1605 """ 

1606 Marks the task as deferred and sets up the trigger that is needed 

1607 to resume it. 

1608 """ 

1609 from airflow.models.trigger import Trigger 

1610 

1611 # First, make the trigger entry 

1612 trigger_row = Trigger.from_object(defer.trigger) 

1613 session.add(trigger_row) 

1614 session.flush() 

1615 

1616 # Then, update ourselves so it matches the deferral request 

1617 # Keep an eye on the logic in `check_and_change_state_before_execution()` 

1618 # depending on self.next_method semantics 

1619 self.state = State.DEFERRED 

1620 self.trigger_id = trigger_row.id 

1621 self.next_method = defer.method_name 

1622 self.next_kwargs = defer.kwargs or {} 

1623 

1624 # Decrement try number so the next one is the same try 

1625 self._try_number -= 1 

1626 

1627 # Calculate timeout too if it was passed 

1628 if defer.timeout is not None: 

1629 self.trigger_timeout = timezone.utcnow() + defer.timeout 

1630 else: 

1631 self.trigger_timeout = None 

1632 

1633 # If an execution_timeout is set, set the timeout to the minimum of 

1634 # it and the trigger timeout 

1635 execution_timeout = self.task.execution_timeout 

1636 if execution_timeout: 

1637 if self.trigger_timeout: 

1638 self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout) 

1639 else: 

1640 self.trigger_timeout = self.start_date + execution_timeout 

1641 
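# --- Illustrative sketch (not part of taskinstance.py): the kind of operator that raises
# TaskDeferred, which _defer_task() above persists. The operator class is made up;
# DateTimeTrigger is a built-in trigger.
from datetime import timedelta

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import DateTimeTrigger
from airflow.utils import timezone


class WaitOneHourOperator(BaseOperator):
    def execute(self, context):
        # BaseOperator.defer() raises TaskDeferred; _defer_task() then stores the trigger,
        # sets state=DEFERRED and records next_method="execute_complete".
        self.defer(
            trigger=DateTimeTrigger(moment=timezone.utcnow() + timedelta(hours=1)),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event=None):
        # Re-entered via self.next_method once the trigger fires.
        return "resumed"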

1642 def _run_execute_callback(self, context: Context, task: Operator) -> None: 

1643 """Functions that need to be run before a Task is executed""" 

1644 callbacks = task.on_execute_callback 

1645 if callbacks: 

1646 callbacks = callbacks if isinstance(callbacks, list) else [callbacks] 

1647 for callback in callbacks: 

1648 try: 

1649 callback(context) 

1650 except Exception: 

1651 self.log.exception("Failed when executing execute callback") 

1652 

1653 @provide_session 

1654 def run( 

1655 self, 

1656 verbose: bool = True, 

1657 ignore_all_deps: bool = False, 

1658 ignore_depends_on_past: bool = False, 

1659 ignore_task_deps: bool = False, 

1660 ignore_ti_state: bool = False, 

1661 mark_success: bool = False, 

1662 test_mode: bool = False, 

1663 job_id: str | None = None, 

1664 pool: str | None = None, 

1665 session: Session = NEW_SESSION, 

1666 ) -> None: 

1667 """Run TaskInstance""" 

1668 res = self.check_and_change_state_before_execution( 

1669 verbose=verbose, 

1670 ignore_all_deps=ignore_all_deps, 

1671 ignore_depends_on_past=ignore_depends_on_past, 

1672 ignore_task_deps=ignore_task_deps, 

1673 ignore_ti_state=ignore_ti_state, 

1674 mark_success=mark_success, 

1675 test_mode=test_mode, 

1676 job_id=job_id, 

1677 pool=pool, 

1678 session=session, 

1679 ) 

1680 if not res: 

1681 return 

1682 

1683 self._run_raw_task( 

1684 mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session 

1685 ) 

1686 

1687 def dry_run(self) -> None: 

1688 """Only Renders Templates for the TI""" 

1689 from airflow.models.baseoperator import BaseOperator 

1690 

1691 self.task = self.task.prepare_for_execution() 

1692 self.render_templates() 

1693 if TYPE_CHECKING: 

1694 assert isinstance(self.task, BaseOperator) 

1695 self.task.dry_run() 

1696 

1697 @provide_session 

1698 def _handle_reschedule( 

1699 self, actual_start_date, reschedule_exception, test_mode=False, session=NEW_SESSION 

1700 ): 

1701 # Don't record reschedule request in test mode 

1702 if test_mode: 

1703 return 

1704 

1705 from airflow.models.dagrun import DagRun # Avoid circular import 

1706 

1707 self.refresh_from_db(session) 

1708 

1709 self.end_date = timezone.utcnow() 

1710 self.set_duration() 

1711 

1712 # Lock DAG run to be sure not to get into a deadlock situation when trying to insert 

1713 # a TaskReschedule, which apparently also creates a lock on the corresponding DagRun entity

1714 with_row_locks( 

1715 session.query(DagRun).filter_by( 

1716 dag_id=self.dag_id, 

1717 run_id=self.run_id, 

1718 ), 

1719 session=session, 

1720 ).one() 

1721 

1722 # Log reschedule request 

1723 session.add( 

1724 TaskReschedule( 

1725 self.task, 

1726 self.run_id, 

1727 self._try_number, 

1728 actual_start_date, 

1729 self.end_date, 

1730 reschedule_exception.reschedule_date, 

1731 self.map_index, 

1732 ) 

1733 ) 

1734 

1735 # set state 

1736 self.state = State.UP_FOR_RESCHEDULE 

1737 

1738 # Decrement try_number so subsequent runs will use the same try number and write 

1739 # to same log file. 

1740 self._try_number -= 1 

1741 

1742 self.clear_next_method_args() 

1743 

1744 session.merge(self) 

1745 session.commit() 

1746 self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") 

1747 
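# --- Illustrative sketch (not part of taskinstance.py): sensors running in "reschedule"
# mode raise AirflowRescheduleException, which is handled by _handle_reschedule() above.
# Dag/task ids and the interval are made up.
from datetime import datetime

from airflow import DAG
from airflow.sensors.python import PythonSensor

with DAG(dag_id="reschedule_example", start_date=datetime(2022, 1, 1), schedule=None):
    PythonSensor(
        task_id="wait_for_flag",
        python_callable=lambda: False,  # keeps rescheduling until this returns True
        mode="reschedule",              # frees the worker slot between pokes
        poke_interval=300,
    )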

1748 @staticmethod 

1749 def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None: 

1750 """ 

1751 Truncates the traceback of an exception to the first frame called from within a given function 

1752 

1753 :param error: exception to get traceback from 

1754 :param truncate_to: Function to truncate TB to. Must have a ``__code__`` attribute 

1755 

1756 :meta private: 

1757 """ 

1758 tb = error.__traceback__ 

1759 code = truncate_to.__func__.__code__ # type: ignore[attr-defined] 

1760 while tb is not None: 

1761 if tb.tb_frame.f_code is code: 

1762 return tb.tb_next 

1763 tb = tb.tb_next 

1764 return tb or error.__traceback__ 

1765 

1766 @provide_session 

1767 def handle_failure( 

1768 self, 

1769 error: None | str | Exception | KeyboardInterrupt, 

1770 test_mode: bool | None = None, 

1771 context: Context | None = None, 

1772 force_fail: bool = False, 

1773 session: Session = NEW_SESSION, 

1774 ) -> None: 

1775 """Handle Failure for the TaskInstance""" 

1776 if test_mode is None: 

1777 test_mode = self.test_mode 

1778 

1779 if error: 

1780 if isinstance(error, BaseException): 

1781 tb = self.get_truncated_error_traceback(error, truncate_to=self._execute_task) 

1782 self.log.error("Task failed with exception", exc_info=(type(error), error, tb)) 

1783 else: 

1784 self.log.error("%s", error) 

1785 if not test_mode: 

1786 self.refresh_from_db(session) 

1787 

1788 self.end_date = timezone.utcnow() 

1789 self.set_duration() 

1790 Stats.incr(f"operator_failures_{self.operator}") 

1791 Stats.incr("ti_failures") 

1792 if not test_mode: 

1793 session.add(Log(State.FAILED, self)) 

1794 

1795 # Log failure duration 

1796 session.add(TaskFail(ti=self)) 

1797 

1798 self.clear_next_method_args() 

1799 

1800 # In extreme cases (e.g. a zombie task from a DAG with a parse error) we might _not_ have a Task.

1801 if context is None and getattr(self, "task", None): 

1802 context = self.get_template_context(session) 

1803 

1804 if context is not None: 

1805 context["exception"] = error 

1806 

1807 # Set state correctly and figure out how to log it and decide whether 

1808 # to email 

1809 

1810 # Note, callback invocation needs to be handled by caller of 

1811 # _run_raw_task to avoid race conditions which could lead to duplicate 

1812 # invocations or miss invocation. 

1813 

1814 # Since this function is called only when the TaskInstance state is running, 

1815 # try_number contains the current try_number (not the next). We 

1816 # only mark task instance as FAILED if the next task instance 

1817 # try_number exceeds the max_tries ... or if force_fail is truthy 

1818 

1819 task: BaseOperator | None = None 

1820 try: 

1821 if getattr(self, "task", None) and context: 

1822 task = self.task.unmap((context, session)) 

1823 except Exception: 

1824 self.log.error("Unable to unmap task to determine if we need to send an alert email") 

1825 

1826 if force_fail or not self.is_eligible_to_retry(): 

1827 self.state = State.FAILED 

1828 email_for_state = operator.attrgetter("email_on_failure") 

1829 callbacks = task.on_failure_callback if task else None 

1830 callback_type = "on_failure" 

1831 else: 

1832 if self.state == State.QUEUED: 

1833 # We increase the try_number so as to fail the task if it fails to start after some time

1834 self._try_number += 1 

1835 self.state = State.UP_FOR_RETRY 

1836 email_for_state = operator.attrgetter("email_on_retry") 

1837 callbacks = task.on_retry_callback if task else None 

1838 callback_type = "on_retry" 

1839 

1840 self._log_state("Immediate failure requested. " if force_fail else "") 

1841 if task and email_for_state(task) and task.email: 

1842 try: 

1843 self.email_alert(error, task) 

1844 except Exception: 

1845 self.log.exception("Failed to send email to: %s", task.email) 

1846 

1847 if callbacks and context: 

1848 self._run_finished_callback(callbacks, context, callback_type) 

1849 

1850 if not test_mode: 

1851 session.merge(self) 

1852 session.flush() 

1853 

1854 def is_eligible_to_retry(self): 

1855 """Is task instance is eligible for retry""" 

1856 if self.state == State.RESTARTING: 

1857 # If a task is cleared when running, it goes into RESTARTING state and is always 

1858 # eligible for retry 

1859 return True 

1860 if not getattr(self, "task", None): 

1861 # Couldn't load the task, don't know number of retries, guess: 

1862 return self.try_number <= self.max_tries 

1863 

1864 return self.task.retries and self.try_number <= self.max_tries 

1865 
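# --- Illustrative sketch (not part of taskinstance.py): the per-task retry settings that
# handle_failure() and is_eligible_to_retry() consult. Values here are arbitrary.
from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(dag_id="retry_example", start_date=datetime(2022, 1, 1), schedule=None):
    BashOperator(
        task_id="flaky",
        bash_command="exit 1",
        retries=2,                         # feeds max_tries used in the checks above
        retry_delay=timedelta(minutes=5),  # delay before the next UP_FOR_RETRY attempt
        email_on_retry=False,
    )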

1866 def get_template_context( 

1867 self, 

1868 session: Session | None = None, 

1869 ignore_param_exceptions: bool = True, 

1870 ) -> Context: 

1871 """Return TI Context""" 

1872 # Do not use provide_session here -- it expunges everything on exit! 

1873 if not session: 

1874 session = settings.Session() 

1875 

1876 from airflow import macros 

1877 from airflow.models.abstractoperator import NotMapped 

1878 

1879 integrate_macros_plugins() 

1880 

1881 task = self.task 

1882 if TYPE_CHECKING: 

1883 assert task.dag 

1884 dag: DAG = task.dag 

1885 

1886 dag_run = self.get_dagrun(session) 

1887 data_interval = dag.get_run_data_interval(dag_run) 

1888 

1889 validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions) 

1890 

1891 logical_date = timezone.coerce_datetime(self.execution_date) 

1892 ds = logical_date.strftime("%Y-%m-%d") 

1893 ds_nodash = ds.replace("-", "") 

1894 ts = logical_date.isoformat() 

1895 ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S") 

1896 ts_nodash_with_tz = ts.replace("-", "").replace(":", "") 

1897 

1898 @cache # Prevent multiple database access. 

1899 def _get_previous_dagrun_success() -> DagRun | None: 

1900 return self.get_previous_dagrun(state=DagRunState.SUCCESS, session=session) 

1901 

1902 def _get_previous_dagrun_data_interval_success() -> DataInterval | None: 

1903 dagrun = _get_previous_dagrun_success() 

1904 if dagrun is None: 

1905 return None 

1906 return dag.get_run_data_interval(dagrun) 

1907 

1908 def get_prev_data_interval_start_success() -> pendulum.DateTime | None: 

1909 data_interval = _get_previous_dagrun_data_interval_success() 

1910 if data_interval is None: 

1911 return None 

1912 return data_interval.start 

1913 

1914 def get_prev_data_interval_end_success() -> pendulum.DateTime | None: 

1915 data_interval = _get_previous_dagrun_data_interval_success() 

1916 if data_interval is None: 

1917 return None 

1918 return data_interval.end 

1919 

1920 def get_prev_start_date_success() -> pendulum.DateTime | None: 

1921 dagrun = _get_previous_dagrun_success() 

1922 if dagrun is None: 

1923 return None 

1924 return timezone.coerce_datetime(dagrun.start_date) 

1925 

1926 @cache 

1927 def get_yesterday_ds() -> str: 

1928 return (logical_date - timedelta(1)).strftime("%Y-%m-%d") 

1929 

1930 def get_yesterday_ds_nodash() -> str: 

1931 return get_yesterday_ds().replace("-", "") 

1932 

1933 @cache 

1934 def get_tomorrow_ds() -> str: 

1935 return (logical_date + timedelta(1)).strftime("%Y-%m-%d") 

1936 

1937 def get_tomorrow_ds_nodash() -> str: 

1938 return get_tomorrow_ds().replace("-", "") 

1939 

1940 @cache 

1941 def get_next_execution_date() -> pendulum.DateTime | None: 

1942 # For manually triggered dagruns that aren't run on a schedule, 

1943 # the "next" execution date doesn't make sense, and should be set 

1944 # to execution date for consistency with how execution_date is set 

1945 # for manually triggered tasks, i.e. triggered_date == execution_date. 

1946 if dag_run.external_trigger: 

1947 return logical_date 

1948 if dag is None: 

1949 return None 

1950 next_info = dag.next_dagrun_info(data_interval, restricted=False) 

1951 if next_info is None: 

1952 return None 

1953 return timezone.coerce_datetime(next_info.logical_date) 

1954 

1955 def get_next_ds() -> str | None: 

1956 execution_date = get_next_execution_date() 

1957 if execution_date is None: 

1958 return None 

1959 return execution_date.strftime("%Y-%m-%d") 

1960 

1961 def get_next_ds_nodash() -> str | None: 

1962 ds = get_next_ds() 

1963 if ds is None: 

1964 return ds 

1965 return ds.replace("-", "") 

1966 

1967 @cache 

1968 def get_prev_execution_date(): 

1969 # For manually triggered dagruns that aren't run on a schedule, 

1970 # the "previous" execution date doesn't make sense, and should be set 

1971 # to execution date for consistency with how execution_date is set 

1972 # for manually triggered tasks, i.e. triggered_date == execution_date. 

1973 if dag_run.external_trigger: 

1974 return logical_date 

1975 with warnings.catch_warnings(): 

1976 warnings.simplefilter("ignore", RemovedInAirflow3Warning) 

1977 return dag.previous_schedule(logical_date) 

1978 

1979 @cache 

1980 def get_prev_ds() -> str | None: 

1981 execution_date = get_prev_execution_date() 

1982 if execution_date is None: 

1983 return None 

1984 return execution_date.strftime(r"%Y-%m-%d") 

1985 

1986 def get_prev_ds_nodash() -> str | None: 

1987 prev_ds = get_prev_ds() 

1988 if prev_ds is None: 

1989 return None 

1990 return prev_ds.replace("-", "") 

1991 

1992 def get_triggering_events() -> dict[str, list[DatasetEvent]]: 

1993 if TYPE_CHECKING: 

1994 assert session is not None 

1995 

1996 # The dag_run may not be attached to the session anymore since the 

1997 # code base is over-zealous with use of session.expunge_all(). 

1998 # Re-attach it if we get called. 

1999 nonlocal dag_run 

2000 if dag_run not in session: 

2001 dag_run = session.merge(dag_run, load=False) 

2002 

2003 dataset_events = dag_run.consumed_dataset_events 

2004 triggering_events: dict[str, list[DatasetEvent]] = defaultdict(list) 

2005 for event in dataset_events: 

2006 triggering_events[event.dataset.uri].append(event) 

2007 

2008 return triggering_events 

2009 

2010 try: 

2011 expanded_ti_count: int | None = task.get_mapped_ti_count(self.run_id, session=session) 

2012 except NotMapped: 

2013 expanded_ti_count = None 

2014 

2015 # NOTE: If you add anything to this dict, make sure to also update the 

2016 # definition in airflow/utils/context.pyi, and KNOWN_CONTEXT_KEYS in 

2017 # airflow/utils/context.py! 

2018 context = { 

2019 "conf": conf, 

2020 "dag": dag, 

2021 "dag_run": dag_run, 

2022 "data_interval_end": timezone.coerce_datetime(data_interval.end), 

2023 "data_interval_start": timezone.coerce_datetime(data_interval.start), 

2024 "ds": ds, 

2025 "ds_nodash": ds_nodash, 

2026 "execution_date": logical_date, 

2027 "expanded_ti_count": expanded_ti_count, 

2028 "inlets": task.inlets, 

2029 "logical_date": logical_date, 

2030 "macros": macros, 

2031 "next_ds": get_next_ds(), 

2032 "next_ds_nodash": get_next_ds_nodash(), 

2033 "next_execution_date": get_next_execution_date(), 

2034 "outlets": task.outlets, 

2035 "params": validated_params, 

2036 "prev_data_interval_start_success": get_prev_data_interval_start_success(), 

2037 "prev_data_interval_end_success": get_prev_data_interval_end_success(), 

2038 "prev_ds": get_prev_ds(), 

2039 "prev_ds_nodash": get_prev_ds_nodash(), 

2040 "prev_execution_date": get_prev_execution_date(), 

2041 "prev_execution_date_success": self.get_previous_execution_date( 

2042 state=DagRunState.SUCCESS, 

2043 session=session, 

2044 ), 

2045 "prev_start_date_success": get_prev_start_date_success(), 

2046 "run_id": self.run_id, 

2047 "task": task, 

2048 "task_instance": self, 

2049 "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}", 

2050 "test_mode": self.test_mode, 

2051 "ti": self, 

2052 "tomorrow_ds": get_tomorrow_ds(), 

2053 "tomorrow_ds_nodash": get_tomorrow_ds_nodash(), 

2054 "triggering_dataset_events": lazy_object_proxy.Proxy(get_triggering_events), 

2055 "ts": ts, 

2056 "ts_nodash": ts_nodash, 

2057 "ts_nodash_with_tz": ts_nodash_with_tz, 

2058 "var": { 

2059 "json": VariableAccessor(deserialize_json=True), 

2060 "value": VariableAccessor(deserialize_json=False), 

2061 }, 

2062 "conn": ConnectionAccessor(), 

2063 "yesterday_ds": get_yesterday_ds(), 

2064 "yesterday_ds_nodash": get_yesterday_ds_nodash(), 

2065 } 

2066 # Mypy doesn't like turning existing dicts into a TypedDict -- and we "lie" in the type stub to say it

2067 # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890

2068 return Context(context) # type: ignore 

2069 
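# --- Illustrative sketch (not part of taskinstance.py): consuming the context assembled by
# get_template_context(), via Jinja templating and via TaskFlow keyword arguments.
# Dag/task ids are made up.
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.operators.bash import BashOperator

with DAG(dag_id="context_example", start_date=datetime(2022, 1, 1), schedule="@daily"):
    # Jinja: every key in the dict built above is available as a template variable.
    BashOperator(task_id="echo_dates", bash_command="echo {{ ds }} {{ data_interval_start }}")

    @task
    def python_side(ds=None, logical_date=None, ti=None):
        # Context keys matching the signature are injected as keyword arguments.
        print(ds, logical_date, ti.try_number)

    python_side()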

2070 @provide_session 

2071 def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None: 

2072 """ 

2073 Update task with rendered template fields for presentation in UI. 

2074 If task has already run, will fetch from DB; otherwise will render. 

2075 """ 

2076 from airflow.models.renderedtifields import RenderedTaskInstanceFields 

2077 

2078 rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session) 

2079 if rendered_task_instance_fields: 

2080 self.task = self.task.unmap(None) 

2081 for field_name, rendered_value in rendered_task_instance_fields.items(): 

2082 setattr(self.task, field_name, rendered_value) 

2083 return 

2084 

2085 try: 

2086 # If we get here, either the task hasn't run or the RTIF record was purged. 

2087 from airflow.utils.log.secrets_masker import redact 

2088 

2089 self.render_templates() 

2090 for field_name in self.task.template_fields: 

2091 rendered_value = getattr(self.task, field_name) 

2092 setattr(self.task, field_name, redact(rendered_value, field_name)) 

2093 except (TemplateAssertionError, UndefinedError) as e: 

2094 raise AirflowException( 

2095 "Webserver does not have access to User-defined Macros or Filters " 

2096 "when Dag Serialization is enabled. Hence for the task that have not yet " 

2097 "started running, please use 'airflow tasks render' for debugging the " 

2098 "rendering of template_fields." 

2099 ) from e 

2100 

2101 @provide_session 

2102 def get_rendered_k8s_spec(self, session: Session = NEW_SESSION): 

2103 """Fetch rendered template fields from DB""" 

2104 from airflow.models.renderedtifields import RenderedTaskInstanceFields 

2105 

2106 rendered_k8s_spec = RenderedTaskInstanceFields.get_k8s_pod_yaml(self, session=session) 

2107 if not rendered_k8s_spec: 

2108 try: 

2109 rendered_k8s_spec = self.render_k8s_pod_yaml() 

2110 except (TemplateAssertionError, UndefinedError) as e: 

2111 raise AirflowException(f"Unable to render a k8s spec for this taskinstance: {e}") from e 

2112 return rendered_k8s_spec 

2113 

2114 def overwrite_params_with_dag_run_conf(self, params, dag_run): 

2115 """Overwrite Task Params with DagRun.conf""" 

2116 if dag_run and dag_run.conf: 

2117 self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf) 

2118 params.update(dag_run.conf) 

2119 
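# --- Illustrative sketch (not part of taskinstance.py): the params that the method above
# merges DagRun.conf into, e.g. for a run triggered with
# `airflow dags trigger --conf '{"target": "prod"}' params_example`. Names are made up.
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(
    dag_id="params_example",
    start_date=datetime(2022, 1, 1),
    schedule=None,
    params={"target": "staging"},  # default; a run's dag_run.conf can override it
):
    BashOperator(task_id="deploy", bash_command="echo deploying to {{ params.target }}")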

2120 def render_templates(self, context: Context | None = None) -> Operator: 

2121 """Render templates in the operator fields. 

2122 

2123 If the task was originally mapped, this may replace ``self.task`` with 

2124 the unmapped, fully rendered BaseOperator. The original ``self.task`` 

2125 before replacement is returned. 

2126 """ 

2127 if not context: 

2128 context = self.get_template_context() 

2129 original_task = self.task 

2130 

2131 # If self.task is mapped, this call replaces self.task to point to the 

2132 # unmapped BaseOperator created by this function! This is because the 

2133 # MappedOperator is useless for template rendering, and we need to be 

2134 # able to access the unmapped task instead. 

2135 original_task.render_template_fields(context) 

2136 

2137 return original_task 

2138 
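# --- Illustrative sketch (not part of taskinstance.py): render_templates() above renders
# whatever an operator lists in template_fields. The operator below is made up.
from airflow.models.baseoperator import BaseOperator


class GreetOperator(BaseOperator):
    template_fields = ("greeting",)  # rendered in place by render_template_fields()

    def __init__(self, greeting: str, **kwargs):
        super().__init__(**kwargs)
        self.greeting = greeting

    def execute(self, context):
        # By execute() time, Jinja such as "Hello {{ ds }}" has already been rendered.
        self.log.info(self.greeting)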

2139 def render_k8s_pod_yaml(self) -> dict | None: 

2140 """Render k8s pod yaml""" 

2141 from kubernetes.client.api_client import ApiClient 

2142 

2143 from airflow.kubernetes.kube_config import KubeConfig 

2144 from airflow.kubernetes.kubernetes_helper_functions import create_pod_id # Circular import 

2145 from airflow.kubernetes.pod_generator import PodGenerator 

2146 

2147 kube_config = KubeConfig() 

2148 pod = PodGenerator.construct_pod( 

2149 dag_id=self.dag_id, 

2150 run_id=self.run_id, 

2151 task_id=self.task_id, 

2152 map_index=self.map_index, 

2153 date=None, 

2154 pod_id=create_pod_id(self.dag_id, self.task_id), 

2155 try_number=self.try_number, 

2156 kube_image=kube_config.kube_image, 

2157 args=self.command_as_list(), 

2158 pod_override_object=PodGenerator.from_obj(self.executor_config), 

2159 scheduler_job_id="0", 

2160 namespace=kube_config.executor_namespace, 

2161 base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file), 

2162 with_mutation_hook=True, 

2163 ) 

2164 sanitized_pod = ApiClient().sanitize_for_serialization(pod) 

2165 return sanitized_pod 

2166 
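# --- Illustrative sketch (not part of taskinstance.py): an executor_config of the shape
# consumed by PodGenerator.from_obj() in render_k8s_pod_yaml() above. Resource numbers
# are arbitrary and assume the Kubernetes executor.
from kubernetes.client import models as k8s

from airflow.decorators import task


@task(
    executor_config={
        "pod_override": k8s.V1Pod(
            spec=k8s.V1PodSpec(
                containers=[
                    k8s.V1Container(
                        name="base",  # the worker container the override is merged into
                        resources=k8s.V1ResourceRequirements(requests={"memory": "1Gi"}),
                    )
                ]
            )
        )
    }
)
def heavy_task():
    ...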

2167 def get_email_subject_content( 

2168 self, exception: BaseException, task: BaseOperator | None = None 

2169 ) -> tuple[str, str, str]: 

2170 """Get the email subject content for exceptions.""" 

2171 # For a ti from DB (without ti.task), return the default value 

2172 if task is None: 

2173 task = getattr(self, "task") 

2174 use_default = task is None 

2175 exception_html = str(exception).replace("\n", "<br>") 

2176 

2177 default_subject = "Airflow alert: {{ti}}" 

2178 # For reporting purposes, we report based on 1-indexed, 

2179 # not 0-indexed lists (i.e. Try 1 instead of 

2180 # Try 0 for the first attempt). 

2181 default_html_content = ( 

2182 "Try {{try_number}} out of {{max_tries + 1}}<br>" 

2183 "Exception:<br>{{exception_html}}<br>" 

2184 'Log: <a href="{{ti.log_url}}">Link</a><br>' 

2185 "Host: {{ti.hostname}}<br>" 

2186 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>' 

2187 ) 

2188 

2189 default_html_content_err = ( 

2190 "Try {{try_number}} out of {{max_tries + 1}}<br>" 

2191 "Exception:<br>Failed attempt to attach error logs<br>" 

2192 'Log: <a href="{{ti.log_url}}">Link</a><br>' 

2193 "Host: {{ti.hostname}}<br>" 

2194 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>' 

2195 ) 

2196 

2197 # This function is called after changing the state from State.RUNNING, 

2198 # so we need to subtract 1 from self.try_number here. 

2199 current_try_number = self.try_number - 1 

2200 additional_context: dict[str, Any] = { 

2201 "exception": exception, 

2202 "exception_html": exception_html, 

2203 "try_number": current_try_number, 

2204 "max_tries": self.max_tries, 

2205 } 

2206 

2207 if use_default: 

2208 default_context = {"ti": self, **additional_context} 

2209 jinja_env = jinja2.Environment( 

2210 loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True 

2211 ) 

2212 subject = jinja_env.from_string(default_subject).render(**default_context) 

2213 html_content = jinja_env.from_string(default_html_content).render(**default_context) 

2214 html_content_err = jinja_env.from_string(default_html_content_err).render(**default_context) 

2215 

2216 else: 

2217 # Use the DAG's get_template_env() to set force_sandboxed. Don't add 

2218 # the flag to the function on task object -- that function can be 

2219 # overridden, and adding a flag breaks backward compatibility. 

2220 dag = self.task.get_dag() 

2221 if dag: 

2222 jinja_env = dag.get_template_env(force_sandboxed=True) 

2223 else: 

2224 jinja_env = SandboxedEnvironment(cache_size=0) 

2225 jinja_context = self.get_template_context() 

2226 context_merge(jinja_context, additional_context) 

2227 

2228 def render(key: str, content: str) -> str: 

2229 if conf.has_option("email", key): 

2230 path = conf.get_mandatory_value("email", key) 

2231 try: 

2232 with open(path) as f: 

2233 content = f.read() 

2234 except FileNotFoundError: 

2235 self.log.warning(f"Could not find email template file '{path!r}'. Using defaults...") 

2236 except OSError: 

2237 self.log.exception(f"Error while using email template '{path!r}'. Using defaults...") 

2238 return render_template_to_string(jinja_env.from_string(content), jinja_context) 

2239 

2240 subject = render("subject_template", default_subject) 

2241 html_content = render("html_content_template", default_html_content) 

2242 html_content_err = render("html_content_template", default_html_content_err) 

2243 

2244 return subject, html_content, html_content_err 

2245 
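# --- Illustrative sketch (not part of taskinstance.py): the task-level settings that decide
# whether email_alert() fires, and which defaults the [email] config keys checked in
# render() above (subject_template / html_content_template) can replace. The address is made up.
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(dag_id="email_alert_example", start_date=datetime(2022, 1, 1), schedule=None):
    BashOperator(
        task_id="might_fail",
        bash_command="exit 1",
        email=["oncall@example.com"],  # handle_failure() only emails when this is set
        email_on_failure=True,
        email_on_retry=False,
    )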

2246 def email_alert(self, exception, task: BaseOperator) -> None: 

2247 """Send alert email with exception information.""" 

2248 subject, html_content, html_content_err = self.get_email_subject_content(exception, task=task) 

2249 assert task.email 

2250 try: 

2251 send_email(task.email, subject, html_content) 

2252 except Exception: 

2253 send_email(task.email, subject, html_content_err) 

2254 

2255 def set_duration(self) -> None: 

2256 """Set TI duration""" 

2257 if self.end_date and self.start_date: 

2258 self.duration = (self.end_date - self.start_date).total_seconds() 

2259 else: 

2260 self.duration = None 

2261 self.log.debug("Task Duration set to %s", self.duration) 

2262 

2263 def _record_task_map_for_downstreams(self, task: Operator, value: Any, *, session: Session) -> None: 

2264 if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate. 

2265 return 

2266 # TODO: We don't push TaskMap for mapped task instances because it's not 

2267 # currently possible for a downstream to depend on one individual mapped 

2268 # task instance. This will change when we implement task mapping inside 

2269 # a mapped task group, and we'll need to further analyze the case. 

2270 if isinstance(task, MappedOperator): 

2271 return 

2272 if value is None: 

2273 raise XComForMappingNotPushed() 

2274 if not _is_mappable_value(value): 

2275 raise UnmappableXComTypePushed(value) 

2276 task_map = TaskMap.from_task_instance_xcom(self, value) 

2277 max_map_length = conf.getint("core", "max_map_length", fallback=1024) 

2278 if task_map.length > max_map_length: 

2279 raise UnmappableXComLengthPushed(value, max_map_length) 

2280 session.merge(task_map) 

2281 
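# --- Illustrative sketch (not part of taskinstance.py): the expand() pattern whose upstream
# return value is validated by _record_task_map_for_downstreams() above. Returning something
# that is not a list/dict, or longer than [core] max_map_length, raises the exceptions above.
from datetime import datetime

from airflow import DAG
from airflow.decorators import task

with DAG(dag_id="mapping_example", start_date=datetime(2022, 1, 1), schedule=None):

    @task
    def make_batches():
        return [1, 2, 3]  # must be a mappable XCom, hence the checks above

    @task
    def process(batch):
        print(batch)

    process.expand(batch=make_batches())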

2282 @provide_session 

2283 def xcom_push( 

2284 self, 

2285 key: str, 

2286 value: Any, 

2287 execution_date: datetime | None = None, 

2288 session: Session = NEW_SESSION, 

2289 ) -> None: 

2290 """ 

2291 Make an XCom available for tasks to pull. 

2292 

2293 :param key: Key to store the value under. 

2294 :param value: Value to store. What types are possible depends on whether 

2295 ``enable_xcom_pickling`` is true or not. If so, this can be any 

2296 picklable object; otherwise only JSON-serializable values may be used.

2297 :param execution_date: Deprecated parameter that has no effect. 

2298 """ 

2299 if execution_date is not None: 

2300 self_execution_date = self.get_dagrun(session).execution_date 

2301 if execution_date < self_execution_date: 

2302 raise ValueError( 

2303 f"execution_date can not be in the past (current execution_date is " 

2304 f"{self_execution_date}; received {execution_date})" 

2305 ) 

2306 elif execution_date is not None: 

2307 message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated." 

2308 warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) 

2309 

2310 XCom.set( 

2311 key=key, 

2312 value=value, 

2313 task_id=self.task_id, 

2314 dag_id=self.dag_id, 

2315 run_id=self.run_id, 

2316 map_index=self.map_index, 

2317 session=session, 

2318 ) 

2319 

2320 @provide_session 

2321 def xcom_pull( 

2322 self, 

2323 task_ids: str | Iterable[str] | None = None, 

2324 dag_id: str | None = None, 

2325 key: str = XCOM_RETURN_KEY, 

2326 include_prior_dates: bool = False, 

2327 session: Session = NEW_SESSION, 

2328 *, 

2329 map_indexes: int | Iterable[int] | None = None, 

2330 default: Any = None, 

2331 ) -> Any: 

2332 """Pull XComs that optionally meet certain criteria. 

2333 

2334 :param key: A key for the XCom. If provided, only XComs with matching 

2335 keys will be returned. The default key is ``'return_value'``, also 

2336 available as constant ``XCOM_RETURN_KEY``. This key is automatically 

2337 given to XComs returned by tasks (as opposed to being pushed 

2338 manually). To remove the filter, pass *None*. 

2339 :param task_ids: Only XComs from tasks with matching ids will be 

2340 pulled. Pass *None* to remove the filter. 

2341 :param dag_id: If provided, only pulls XComs from this DAG. If *None* 

2342 (default), the DAG of the calling task is used. 

2343 :param map_indexes: If provided, only pull XComs with matching indexes. 

2344 If *None* (default), this is inferred from the task(s) being pulled 

2345 (see below for details). 

2346 :param include_prior_dates: If False, only XComs from the current 

2347 execution_date are returned. If *True*, XComs from previous dates 

2348 are returned as well. 

2349 

2350 When pulling one single task (``task_id`` is *None* or a str) without 

2351 specifying ``map_indexes``, the return value is inferred from whether 

2352 the specified task is mapped. If not, the value from the one single task

2353 instance is returned. If the task to pull is mapped, an iterator (not a 

2354 list) yielding XComs from mapped task instances is returned. In either 

2355 case, ``default`` (*None* if not specified) is returned if no matching 

2356 XComs are found. 

2357 

2358 When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is 

2359 a non-str iterable), a list of matching XComs is returned. Elements in 

2360 the list are ordered by the item ordering in ``task_id`` and ``map_index``.

2361 """ 

2362 if dag_id is None: 

2363 dag_id = self.dag_id 

2364 

2365 query = XCom.get_many( 

2366 key=key, 

2367 run_id=self.run_id, 

2368 dag_ids=dag_id, 

2369 task_ids=task_ids, 

2370 map_indexes=map_indexes, 

2371 include_prior_dates=include_prior_dates, 

2372 session=session, 

2373 ) 

2374 

2375 # NOTE: Since we're only fetching the value field and not the whole 

2376 # class, the @recreate annotation does not kick in. Therefore we need to 

2377 # call XCom.deserialize_value() manually. 

2378 

2379 # We are only pulling one single task. 

2380 if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable): 

2381 first = query.with_entities( 

2382 XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value 

2383 ).first() 

2384 if first is None: # No matching XCom at all. 

2385 return default 

2386 if map_indexes is not None or first.map_index < 0: 

2387 return XCom.deserialize_value(first) 

2388 query = query.order_by(None).order_by(XCom.map_index.asc()) 

2389 return LazyXComAccess.build_from_xcom_query(query) 

2390 

2391 # At this point either task_ids or map_indexes is explicitly multi-value. 

2392 # Order return values to match task_ids and map_indexes ordering. 

2393 query = query.order_by(None) 

2394 if task_ids is None or isinstance(task_ids, str): 

2395 query = query.order_by(XCom.task_id) 

2396 else: 

2397 task_id_whens = {tid: i for i, tid in enumerate(task_ids)} 

2398 if task_id_whens: 

2399 query = query.order_by(case(task_id_whens, value=XCom.task_id)) 

2400 else: 

2401 query = query.order_by(XCom.task_id) 

2402 if map_indexes is None or isinstance(map_indexes, int): 

2403 query = query.order_by(XCom.map_index) 

2404 elif isinstance(map_indexes, range): 

2405 order = XCom.map_index 

2406 if map_indexes.step < 0: 

2407 order = order.desc() 

2408 query = query.order_by(order) 

2409 else: 

2410 map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)} 

2411 if map_index_whens: 

2412 query = query.order_by(case(map_index_whens, value=XCom.map_index)) 

2413 else: 

2414 query = query.order_by(XCom.map_index) 

2415 return LazyXComAccess.build_from_xcom_query(query) 

2416 
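# --- Illustrative sketch (not part of taskinstance.py): the pull shapes described in the
# xcom_pull() docstring above. The task ids referenced here are made up.
from airflow.decorators import task


@task
def gather(ti=None):
    one = ti.xcom_pull(task_ids="produce_value")            # single value, or None
    many = ti.xcom_pull(task_ids=["extract", "transform"])  # list ordered like task_ids
    mapped = ti.xcom_pull(task_ids="process")               # mapped upstream: lazy iterable
    print(one, many, list(mapped or []))                    # ordered by map_index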

2417 @provide_session 

2418 def get_num_running_task_instances(self, session: Session) -> int: 

2419 """Return Number of running TIs from the DB""" 

2420 # .count() is inefficient 

2421 return ( 

2422 session.query(func.count()) 

2423 .filter( 

2424 TaskInstance.dag_id == self.dag_id, 

2425 TaskInstance.task_id == self.task_id, 

2426 TaskInstance.state == State.RUNNING, 

2427 ) 

2428 .scalar() 

2429 ) 

2430 

2431 def init_run_context(self, raw: bool = False) -> None: 

2432 """Sets the log context.""" 

2433 self.raw = raw 

2434 self._set_context(self) 

2435 

2436 @staticmethod 

2437 def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None: 

2438 """Returns SQLAlchemy filter to query selected task instances""" 

2439 # DictKeys type (what we often pass here from the scheduler) is not directly indexable :(

2440 # Or it might be a generator, but we need to be able to iterate over it more than once 

2441 tis = list(tis) 

2442 

2443 if not tis: 

2444 return None 

2445 

2446 first = tis[0] 

2447 

2448 dag_id = first.dag_id 

2449 run_id = first.run_id 

2450 map_index = first.map_index 

2451 first_task_id = first.task_id 

2452 

2453 # pre-compute the set of dag_id, run_id, map_indices and task_ids 

2454 dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set() 

2455 for t in tis: 

2456 dag_ids.add(t.dag_id) 

2457 run_ids.add(t.run_id) 

2458 map_indices.add(t.map_index) 

2459 task_ids.add(t.task_id) 

2460 

2461 # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id 

2462 # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+) 

2463 if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}: 

2464 return and_( 

2465 TaskInstance.dag_id == dag_id, 

2466 TaskInstance.run_id == run_id, 

2467 TaskInstance.map_index == map_index, 

2468 TaskInstance.task_id.in_(task_ids), 

2469 ) 

2470 if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}: 

2471 return and_( 

2472 TaskInstance.dag_id == dag_id, 

2473 TaskInstance.run_id.in_(run_ids), 

2474 TaskInstance.map_index == map_index, 

2475 TaskInstance.task_id == first_task_id, 

2476 ) 

2477 if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}: 

2478 return and_( 

2479 TaskInstance.dag_id == dag_id, 

2480 TaskInstance.run_id == run_id, 

2481 TaskInstance.map_index.in_(map_indices), 

2482 TaskInstance.task_id == first_task_id, 

2483 ) 

2484 

2485 filter_condition = [] 

2486 # create 2 nested groups, both primarily grouped by dag_id and run_id, 

2487 # and within them one grouped by task_id and the other by map_index.

2488 task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list)) 

2489 map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list)) 

2490 for t in tis: 

2491 task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index) 

2492 map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id) 

2493 

2494 # this assumes that most dags have dag_id as the largest grouping, followed by run_id. even 

2495 # if it's not, this is still a significant optimization over querying for every single tuple key

2496 for cur_dag_id in dag_ids: 

2497 for cur_run_id in run_ids: 

2498 # we compare the group size between task_id and map_index and use the smaller group 

2499 dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)] 

2500 dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)] 

2501 

2502 if len(dag_task_id_groups) <= len(dag_map_index_groups): 

2503 for cur_task_id, cur_map_indices in dag_task_id_groups.items(): 

2504 filter_condition.append( 

2505 and_( 

2506 TaskInstance.dag_id == cur_dag_id, 

2507 TaskInstance.run_id == cur_run_id, 

2508 TaskInstance.task_id == cur_task_id, 

2509 TaskInstance.map_index.in_(cur_map_indices), 

2510 ) 

2511 ) 

2512 else: 

2513 for cur_map_index, cur_task_ids in dag_map_index_groups.items(): 

2514 filter_condition.append( 

2515 and_( 

2516 TaskInstance.dag_id == cur_dag_id, 

2517 TaskInstance.run_id == cur_run_id, 

2518 TaskInstance.task_id.in_(cur_task_ids), 

2519 TaskInstance.map_index == cur_map_index, 

2520 ) 

2521 ) 

2522 

2523 return or_(*filter_condition) 

2524 

2525 @classmethod 

2526 def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators: 

2527 """ 

2528 Build an SQLAlchemy filter for a list where each element is either

2529 a task_id, or a tuple of (task_id, map_index)

2530 

2531 :meta private: 

2532 """ 

2533 # Compute a filter for TI.task_id and TI.map_index based on input values 

2534 # For each item, it will either be a task_id, or (task_id, map_index) 

2535 task_id_only = [v for v in vals if isinstance(v, str)] 

2536 with_map_index = [v for v in vals if not isinstance(v, str)] 

2537 

2538 filters: list[ColumnOperators] = [] 

2539 if task_id_only: 

2540 filters.append(cls.task_id.in_(task_id_only)) 

2541 if with_map_index: 

2542 filters.append(tuple_in_condition((cls.task_id, cls.map_index), with_map_index)) 

2543 

2544 if not filters: 

2545 return false() 

2546 if len(filters) == 1: 

2547 return filters[0] 

2548 return or_(*filters) 

2549 

2550 @Sentry.enrich_errors 

2551 @provide_session 

2552 def schedule_downstream_tasks(self, session=None): 

2553 """ 

2554 The mini-scheduler for scheduling downstream tasks of this task instance.

2555 :meta private:

2556 """ 

2557 from sqlalchemy.exc import OperationalError 

2558 

2559 from airflow.models import DagRun 

2560 

2561 try: 

2562 # Re-select the row with a lock 

2563 dag_run = with_row_locks( 

2564 session.query(DagRun).filter_by( 

2565 dag_id=self.dag_id, 

2566 run_id=self.run_id, 

2567 ), 

2568 session=session, 

2569 ).one() 

2570 

2571 task = self.task 

2572 if TYPE_CHECKING: 

2573 assert task.dag 

2574 

2575 # Get a partial DAG with just the specific tasks we want to examine. 

2576 # In order for dep checks to work correctly, we include ourself (so 

2577 # TriggerRuleDep can check the state of the task we just executed). 

2578 partial_dag = task.dag.partial_subset( 

2579 task.downstream_task_ids, 

2580 include_downstream=True, 

2581 include_upstream=False, 

2582 include_direct_upstream=True, 

2583 ) 

2584 

2585 dag_run.dag = partial_dag 

2586 info = dag_run.task_instance_scheduling_decisions(session) 

2587 

2588 skippable_task_ids = { 

2589 task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids 

2590 } 

2591 

2592 schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids] 

2593 for schedulable_ti in schedulable_tis: 

2594 if not hasattr(schedulable_ti, "task"): 

2595 schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id) 

2596 

2597 num = dag_run.schedule_tis(schedulable_tis, session=session) 

2598 self.log.info("%d downstream tasks scheduled from follow-on schedule check", num) 

2599 

2600 session.flush() 

2601 

2602 except OperationalError as e: 

2603 # Any kind of DB error here is _non fatal_ as this block is just an optimisation. 

2604 self.log.info( 

2605 "Skipping mini scheduling run due to exception: %s", 

2606 e.statement, 

2607 exc_info=True, 

2608 ) 

2609 session.rollback() 

2610 

2611 def get_relevant_upstream_map_indexes( 

2612 self, 

2613 upstream: Operator, 

2614 ti_count: int | None, 

2615 *, 

2616 session: Session, 

2617 ) -> int | range | None: 

2618 """Infer the map indexes of an upstream "relevant" to this ti. 

2619 

2620 The bulk of the logic mainly exists to solve the problem described by 

2621 the following example, where 'val' must resolve to different values, 

2622 depending on where the reference is being used:: 

2623 

2624 @task 

2625 def this_task(v): # This is self.task. 

2626 return v * 2 

2627 

2628 @task_group 

2629 def tg1(inp): 

2630 val = upstream(inp) # This is the upstream task. 

2631 this_task(val) # When inp is 1, val here should resolve to 2. 

2632 return val 

2633 

2634 # This val is the same object returned by tg1. 

2635 val = tg1.expand(inp=[1, 2, 3]) 

2636 

2637 @task_group 

2638 def tg2(inp): 

2639 another_task(inp, val) # val here should resolve to [2, 4, 6]. 

2640 

2641 tg2.expand(inp=["a", "b"]) 

2642 

2643 The surrounding mapped task groups of ``upstream`` and ``self.task`` are 

2644 inspected to find a common "ancestor". If such an ancestor is found, 

2645 we need to return specific map indexes to pull a partial value from 

2646 upstream XCom. 

2647 

2648 :param upstream: The referenced upstream task. 

2649 :param ti_count: The total count of task instance this task was expanded 

2650 by the scheduler, i.e. ``expanded_ti_count`` in the template context. 

2651 :return: Specific map index or map indexes to pull, or ``None`` if we 

2652 want to "whole" return value (i.e. no mapped task groups involved). 

2653 """ 

2654 # Find the innermost common mapped task group between the current task

2655 # and the referenced task. If they do not have a common mapped task

2656 # group, the two are in different task mapping contexts

2657 # (like another_task above), and we should use the "whole" value.

2658 common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream) 

2659 if common_ancestor is None: 

2660 return None 

2661 

2662 # This value should never be None since we already know the current task 

2663 # is in a mapped task group, and should have been expanded. The check 

2664 # exists mainly to satisfy Mypy. 

2665 if ti_count is None: 

2666 return None 

2667 

2668 # At this point we know the two tasks share a mapped task group, and we 

2669 # should use a "partial" value. Let's break down the mapped ti count 

2670 # between the ancestor and further expansion happened inside it. 

2671 ancestor_ti_count = common_ancestor.get_mapped_ti_count(self.run_id, session=session) 

2672 ancestor_map_index = self.map_index * ancestor_ti_count // ti_count 

2673 

2674 # If the task is NOT further expanded inside the common ancestor, we 

2675 # only want to reference one single ti. We must walk the actual DAG, 

2676 # and "ti_count == ancestor_ti_count" does not work, since the further 

2677 # expansion may be of length 1. 

2678 if not _is_further_mapped_inside(upstream, common_ancestor): 

2679 return ancestor_map_index 

2680 

2681 # Otherwise we need a partial aggregation for values from selected task 

2682 # instances in the ancestor's expansion context. 

2683 further_count = ti_count // ancestor_ti_count 

2684 map_index_start = ancestor_map_index * further_count 

2685 return range(map_index_start, map_index_start + further_count) 

2686 

2687 

2688def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None: 

2689 """Given two operators, find their innermost common mapped task group.""" 

2690 if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id: 

2691 return None 

2692 parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()} 

2693 common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids) 

2694 return next(common_groups, None) 

2695 

2696 

2697def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool: 

2698 """Whether given operator is *further* mapped inside a task group.""" 

2699 if isinstance(operator, MappedOperator): 

2700 return True 

2701 task_group = operator.task_group 

2702 while task_group is not None and task_group.group_id != container.group_id: 

2703 if isinstance(task_group, MappedTaskGroup): 

2704 return True 

2705 task_group = task_group.parent_group 

2706 return False 

2707 

2708 

2709# State of the task instance. 

2710# Stores string version of the task state. 

2711TaskInstanceStateType = Tuple[TaskInstanceKey, str] 

2712 

2713 

2714class SimpleTaskInstance: 

2715 """ 

2716 Simplified Task Instance. 

2717 

2718 Used to send data between processes via Queues. 

2719 """ 

2720 

2721 def __init__( 

2722 self, 

2723 dag_id: str, 

2724 task_id: str, 

2725 run_id: str, 

2726 start_date: datetime | None, 

2727 end_date: datetime | None, 

2728 try_number: int, 

2729 map_index: int, 

2730 state: str, 

2731 executor_config: Any, 

2732 pool: str, 

2733 queue: str, 

2734 key: TaskInstanceKey, 

2735 run_as_user: str | None = None, 

2736 priority_weight: int | None = None, 

2737 ): 

2738 self.dag_id = dag_id 

2739 self.task_id = task_id 

2740 self.run_id = run_id 

2741 self.map_index = map_index 

2742 self.start_date = start_date 

2743 self.end_date = end_date 

2744 self.try_number = try_number 

2745 self.state = state 

2746 self.executor_config = executor_config 

2747 self.run_as_user = run_as_user 

2748 self.pool = pool 

2749 self.priority_weight = priority_weight 

2750 self.queue = queue 

2751 self.key = key 

2752 

2753 def __eq__(self, other): 

2754 if isinstance(other, self.__class__): 

2755 return self.__dict__ == other.__dict__ 

2756 return NotImplemented 

2757 

2758 def as_dict(self): 

2759 warnings.warn( 

2760 "This method is deprecated. Use BaseSerialization.serialize.", 

2761 RemovedInAirflow3Warning, 

2762 stacklevel=2, 

2763 ) 

2764 new_dict = dict(self.__dict__) 

2765 for key in new_dict: 

2766 if key in ["start_date", "end_date"]: 

2767 val = new_dict[key] 

2768 if not val or isinstance(val, str): 

2769 continue 

2770 new_dict.update({key: val.isoformat()}) 

2771 return new_dict 

2772 

2773 @classmethod 

2774 def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance: 

2775 return cls( 

2776 dag_id=ti.dag_id, 

2777 task_id=ti.task_id, 

2778 run_id=ti.run_id, 

2779 map_index=ti.map_index, 

2780 start_date=ti.start_date, 

2781 end_date=ti.end_date, 

2782 try_number=ti.try_number, 

2783 state=ti.state, 

2784 executor_config=ti.executor_config, 

2785 pool=ti.pool, 

2786 queue=ti.queue, 

2787 key=ti.key, 

2788 run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None, 

2789 priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None, 

2790 ) 

2791 

2792 @classmethod 

2793 def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance: 

2794 warnings.warn( 

2795 "This method is deprecated. Use BaseSerialization.deserialize.", 

2796 RemovedInAirflow3Warning, 

2797 stacklevel=2, 

2798 ) 

2799 ti_key = TaskInstanceKey(*obj_dict.pop("key")) 

2800 start_date = None 

2801 end_date = None 

2802 start_date_str: str | None = obj_dict.pop("start_date") 

2803 end_date_str: str | None = obj_dict.pop("end_date") 

2804 if start_date_str: 

2805 start_date = timezone.parse(start_date_str) 

2806 if end_date_str: 

2807 end_date = timezone.parse(end_date_str) 

2808 return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key) 

2809 

2810 

2811class TaskInstanceNote(Base): 

2812 """For storage of arbitrary notes concerning the task instance.""" 

2813 

2814 __tablename__ = "task_instance_note" 

2815 

2816 user_id = Column(Integer, nullable=True) 

2817 task_id = Column(StringID(), primary_key=True, nullable=False) 

2818 dag_id = Column(StringID(), primary_key=True, nullable=False) 

2819 run_id = Column(StringID(), primary_key=True, nullable=False) 

2820 map_index = Column(Integer, primary_key=True, nullable=False) 

2821 content = Column(String(1000).with_variant(Text(1000), "mysql")) 

2822 created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) 

2823 updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) 

2824 

2825 task_instance = relationship("TaskInstance", back_populates="task_instance_note") 

2826 

2827 __table_args__ = ( 

2828 PrimaryKeyConstraint( 

2829 "task_id", "dag_id", "run_id", "map_index", name="task_instance_note_pkey", mssql_clustered=True 

2830 ), 

2831 ForeignKeyConstraint( 

2832 (dag_id, task_id, run_id, map_index), 

2833 [ 

2834 "task_instance.dag_id", 

2835 "task_instance.task_id", 

2836 "task_instance.run_id", 

2837 "task_instance.map_index", 

2838 ], 

2839 name="task_instance_note_ti_fkey", 

2840 ondelete="CASCADE", 

2841 ), 

2842 ForeignKeyConstraint( 

2843 (user_id,), 

2844 ["ab_user.id"], 

2845 name="task_instance_note_user_fkey", 

2846 ), 

2847 ) 

2848 

2849 def __init__(self, content, user_id=None): 

2850 self.content = content 

2851 self.user_id = user_id 

2852 

2853 def __repr__(self): 

2854 prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}" 

2855 if self.map_index != -1: 

2856 prefix += f" map_index={self.map_index}" 

2857 return prefix + ">" 

2858 

2859 

2860STATICA_HACK = True 

2861globals()["kcah_acitats"[::-1].upper()] = False 

2862if STATICA_HACK: # pragma: no cover 

2863 from airflow.jobs.base_job import BaseJob 

2864 

2865 TaskInstance.queued_by_job = relationship(BaseJob)