Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/taskinstance.py: 21%
1316 statements
coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections.abc
21import contextlib
22import hashlib
23import logging
24import math
25import operator
26import os
27import signal
28import warnings
29from collections import defaultdict
30from datetime import datetime, timedelta
31from enum import Enum
32from functools import partial
33from pathlib import PurePath
34from types import TracebackType
35from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Tuple
36from urllib.parse import quote
38import dill
39import jinja2
40import lazy_object_proxy
41import pendulum
42from jinja2 import TemplateAssertionError, UndefinedError
43from sqlalchemy import (
44 Column,
45 DateTime,
46 Float,
47 ForeignKeyConstraint,
48 Index,
49 Integer,
50 PrimaryKeyConstraint,
51 String,
52 Text,
53 and_,
54 delete,
55 false,
56 func,
57 inspect,
58 or_,
59 text,
60)
61from sqlalchemy.ext.associationproxy import association_proxy
62from sqlalchemy.ext.mutable import MutableDict
63from sqlalchemy.orm import reconstructor, relationship
64from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
65from sqlalchemy.orm.session import Session
66from sqlalchemy.sql.elements import BooleanClauseList
67from sqlalchemy.sql.expression import ColumnOperators, case
69from airflow import settings
70from airflow.compat.functools import cache
71from airflow.configuration import conf
72from airflow.datasets import Dataset
73from airflow.datasets.manager import dataset_manager
74from airflow.exceptions import (
75 AirflowException,
76 AirflowFailException,
77 AirflowRescheduleException,
78 AirflowSensorTimeout,
79 AirflowSkipException,
80 AirflowTaskTimeout,
81 DagRunNotFound,
82 RemovedInAirflow3Warning,
83 TaskDeferralError,
84 TaskDeferred,
85 UnmappableXComLengthPushed,
86 UnmappableXComTypePushed,
87 XComForMappingNotPushed,
88)
89from airflow.listeners.listener import get_listener_manager
90from airflow.models.base import Base, StringID
91from airflow.models.dagbag import DagBag
92from airflow.models.log import Log
93from airflow.models.mappedoperator import MappedOperator
94from airflow.models.param import process_params
95from airflow.models.taskfail import TaskFail
96from airflow.models.taskinstancekey import TaskInstanceKey
97from airflow.models.taskmap import TaskMap
98from airflow.models.taskreschedule import TaskReschedule
99from airflow.models.xcom import LazyXComAccess, XCom
100from airflow.plugins_manager import integrate_macros_plugins
101from airflow.sentry import Sentry
102from airflow.stats import Stats
103from airflow.templates import SandboxedEnvironment
104from airflow.ti_deps.dep_context import DepContext
105from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
106from airflow.timetables.base import DataInterval
107from airflow.typing_compat import Literal, TypeGuard
108from airflow.utils import timezone
109from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor, context_merge
110from airflow.utils.email import send_email
111from airflow.utils.helpers import prune_dict, render_template_to_string
112from airflow.utils.log.logging_mixin import LoggingMixin
113from airflow.utils.module_loading import qualname
114from airflow.utils.net import get_hostname
115from airflow.utils.operator_helpers import context_to_airflow_vars
116from airflow.utils.platform import getuser
117from airflow.utils.retries import run_with_db_retries
118from airflow.utils.session import NEW_SESSION, create_session, provide_session
119from airflow.utils.sqlalchemy import (
120 ExecutorConfigType,
121 ExtendedJSON,
122 UtcDateTime,
123 tuple_in_condition,
124 with_row_locks,
125)
126from airflow.utils.state import DagRunState, State, TaskInstanceState
127from airflow.utils.task_group import MappedTaskGroup
128from airflow.utils.timeout import timeout
129from airflow.utils.xcom import XCOM_RETURN_KEY
131TR = TaskReschedule
133_CURRENT_CONTEXT: list[Context] = []
134log = logging.getLogger(__name__)
137if TYPE_CHECKING:
138 from airflow.models.abstractoperator import TaskStateChangeCallback
139 from airflow.models.baseoperator import BaseOperator
140 from airflow.models.dag import DAG, DagModel
141 from airflow.models.dagrun import DagRun
142 from airflow.models.dataset import DatasetEvent
143 from airflow.models.operator import Operator
144 from airflow.utils.task_group import TaskGroup
146 # This is a workaround because mypy doesn't work with hybrid_property
147 # TODO: remove this hack and move hybrid_property back to main import block
148 # See https://github.com/python/mypy/issues/4430
149 hybrid_property = property
150else:
151 from sqlalchemy.ext.hybrid import hybrid_property
154PAST_DEPENDS_MET = "past_depends_met"
157class TaskReturnCode(Enum):
158 """
159 Enum to signal manner of exit for task run command.
161 :meta private:
162 """
164 DEFERRED = 100
165 """When task exits with deferral to trigger."""
168@contextlib.contextmanager
169def set_current_context(context: Context) -> Generator[Context, None, None]:
170 """
171 Sets the current execution context to the provided context object.
172 This method should be called once per Task execution, before calling operator.execute.
173 """
174 _CURRENT_CONTEXT.append(context)
175 try:
176 yield context
177 finally:
178 expected_state = _CURRENT_CONTEXT.pop()
179 if expected_state != context:
180 log.warning(
181 "Current context is not equal to the state at context stack. Expected=%s, got=%s",
182 context,
183 expected_state,
184 )
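# --- Illustrative sketch (editorial addition, not part of the original module) ---
# Minimal use of set_current_context(): build the template context once per task
# execution and keep it on the stack while the operator runs. `ti` and
# `my_operator` below are hypothetical.
def _example_run_with_current_context(ti: TaskInstance, my_operator: BaseOperator) -> Any:
    context = ti.get_template_context()
    with set_current_context(context):
        # Inside this block _CURRENT_CONTEXT[-1] is `context`, so helpers such as
        # airflow.operators.python.get_current_context() can resolve it.
        return my_operator.execute(context=context)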
187def stop_all_tasks_in_dag(tis: list[TaskInstance], session: Session, task_id_to_ignore: int):
188 for ti in tis:
189 if ti.task_id == task_id_to_ignore or ti.state in (
190 TaskInstanceState.SUCCESS,
191 TaskInstanceState.FAILED,
192 ):
193 continue
194 if ti.state == TaskInstanceState.RUNNING:
195 log.info("Forcing task %s to fail", ti.task_id)
196 ti.error(session)
197 else:
198 log.info("Setting task %s to SKIPPED", ti.task_id)
199 ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
202def clear_task_instances(
203 tis: list[TaskInstance],
204 session: Session,
205 activate_dag_runs: None = None,
206 dag: DAG | None = None,
207 dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
208) -> None:
209 """
210 Clears a set of task instances, but makes sure the running ones
211 get killed. Also sets Dagrun's `state` to QUEUED and `start_date`
212 to the time of execution. But only for finished DRs (SUCCESS and FAILED).
213 Doesn't clear DR's `state` and `start_date` for running
214 DRs (QUEUED and RUNNING) because clearing the state for already
215 running DR is redundant and clearing `start_date` affects DR's duration.
217 :param tis: a list of task instances
218 :param session: current session
219 :param dag_run_state: state to set finished DagRuns to.
220 If set to False, DagRuns state will not be changed.
221 :param dag: DAG object
222 :param activate_dag_runs: Deprecated parameter, do not pass
223 """
224 job_ids = []
225 # Keys: dag_id -> run_id -> map_indexes -> try_numbers -> task_id
226 task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict(
227 lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
228 )
229 dag_bag = DagBag(read_dags_from_db=True)
230 for ti in tis:
231 if ti.state == TaskInstanceState.RUNNING:
232 if ti.job_id:
233 # If a task is cleared when running, set its state to RESTARTING so that
234 # the task is terminated and becomes eligible for retry.
235 ti.state = TaskInstanceState.RESTARTING
236 job_ids.append(ti.job_id)
237 else:
238 ti_dag = dag if dag and dag.dag_id == ti.dag_id else dag_bag.get_dag(ti.dag_id, session=session)
239 task_id = ti.task_id
240 if ti_dag and ti_dag.has_task(task_id):
241 task = ti_dag.get_task(task_id)
242 ti.refresh_from_task(task)
243 task_retries = task.retries
244 ti.max_tries = ti.try_number + task_retries - 1
245 else:
246 # Ignore errors when updating max_tries if the DAG or
247 # task are not found since database records could be
248 # outdated. We make max_tries the maximum value of its
249 # original max_tries or the last attempted try number.
250 ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries)
251 ti.state = None
252 ti.external_executor_id = None
253 ti.clear_next_method_args()
254 session.merge(ti)
256 task_id_by_key[ti.dag_id][ti.run_id][ti.map_index][ti.try_number].add(ti.task_id)
258 if task_id_by_key:
259 # Clear all reschedules related to the ti to clear
261 # This is an optimization for the common case where all tis are for a small number
262 # of dag_id, run_id, try_number, and map_index. Use a nested dict of dag_id,
263 # run_id, try_number, map_index, and task_id to construct the where clause in a
264 # hierarchical manner. This speeds up the delete statement by more than 40x for
265 # large number of tis (50k+).
266 conditions = or_(
267 and_(
268 TR.dag_id == dag_id,
269 or_(
270 and_(
271 TR.run_id == run_id,
272 or_(
273 and_(
274 TR.map_index == map_index,
275 or_(
276 and_(TR.try_number == try_number, TR.task_id.in_(task_ids))
277 for try_number, task_ids in task_tries.items()
278 ),
279 )
280 for map_index, task_tries in map_indexes.items()
281 ),
282 )
283 for run_id, map_indexes in run_ids.items()
284 ),
285 )
286 for dag_id, run_ids in task_id_by_key.items()
287 )
289 delete_qry = TR.__table__.delete().where(conditions)
290 session.execute(delete_qry)
292 if job_ids:
293 from airflow.jobs.job import Job
295 for job in session.query(Job).filter(Job.id.in_(job_ids)).all():
296 job.state = TaskInstanceState.RESTARTING
298 if activate_dag_runs is not None:
299 warnings.warn(
300 "`activate_dag_runs` parameter to clear_task_instances function is deprecated. "
301 "Please use `dag_run_state`",
302 RemovedInAirflow3Warning,
303 stacklevel=2,
304 )
305 if not activate_dag_runs:
306 dag_run_state = False
308 if dag_run_state is not False and tis:
309 from airflow.models.dagrun import DagRun # Avoid circular import
311 run_ids_by_dag_id = defaultdict(set)
312 for instance in tis:
313 run_ids_by_dag_id[instance.dag_id].add(instance.run_id)
315 drs = (
316 session.query(DagRun)
317 .filter(
318 or_(
319 and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
320 for dag_id, run_ids in run_ids_by_dag_id.items()
321 )
322 )
323 .all()
324 )
325 dag_run_state = DagRunState(dag_run_state) # Validate the state value.
326 for dr in drs:
327 if dr.state in State.finished_dr_states:
328 dr.state = dag_run_state
329 dr.start_date = timezone.utcnow()
330 if dag_run_state == DagRunState.QUEUED:
331 dr.last_scheduling_decision = None
332 dr.start_date = None
333 session.flush()
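# --- Illustrative sketch (editorial addition, not part of the original module) ---
# One way to call clear_task_instances(): clear every TI of a single run while
# leaving DagRun state untouched. `my_dag` and the run_id are hypothetical.
def _example_clear_one_run(my_dag: DAG, run_id: str) -> None:
    with create_session() as session:
        tis = (
            session.query(TaskInstance)
            .filter(TaskInstance.dag_id == my_dag.dag_id, TaskInstance.run_id == run_id)
            .all()
        )
        # dag_run_state=False keeps the DagRun state as-is; passing a DagRunState
        # (the default is QUEUED) would reset finished runs instead.
        clear_task_instances(tis, session, dag=my_dag, dag_run_state=False)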
336def _is_mappable_value(value: Any) -> TypeGuard[Collection]:
337 """Whether a value can be used for task mapping.
339 We only allow collections with guaranteed ordering, but exclude character
340 sequences since that's usually not what users would expect to be mappable.
341 """
342 if not isinstance(value, (collections.abc.Sequence, dict)):
343 return False
344 if isinstance(value, (bytearray, bytes, str)):
345 return False
346 return True
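# --- Illustrative check (editorial addition, not part of the original module) ---
# Ordered collections (lists, dicts) are considered mappable; character
# sequences and plain scalars are not.
def _example_mappable_values() -> None:
    assert _is_mappable_value([1, 2, 3])
    assert _is_mappable_value({"a": 1})
    assert not _is_mappable_value("abc")
    assert not _is_mappable_value(42)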
349def _creator_note(val):
350 """Custom creator for the ``note`` association proxy."""
351 if isinstance(val, str):
352 return TaskInstanceNote(content=val)
353 elif isinstance(val, dict):
354 return TaskInstanceNote(**val)
355 else:
356 return TaskInstanceNote(*val)
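# --- Illustrative check (editorial addition, not part of the original module) ---
# The `note` association proxy below accepts a plain string, a dict of column
# values, or an iterable of positional arguments, each turned into a TaskInstanceNote.
def _example_note_creator() -> None:
    assert isinstance(_creator_note("looks good"), TaskInstanceNote)
    assert isinstance(_creator_note({"content": "looks good"}), TaskInstanceNote)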
359class TaskInstance(Base, LoggingMixin):
360 """
361 Task instances store the state of a task instance. This table is the
362 authority and single source of truth around what tasks have run and the
363 state they are in.
365 The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or
366 dag model deliberately to have more control over transactions.
368 Database transactions on this table should guard against double triggers and
369 any confusion around what task instances are or aren't ready to run,
370 even while multiple schedulers may be firing task instances.
372 A value of -1 in map_index represents any of: a TI without mapped tasks;
373 a TI with mapped tasks that has yet to be expanded (state=pending);
374 a TI with mapped tasks that expanded to an empty list (state=skipped).
375 """
377 __tablename__ = "task_instance"
378 task_id = Column(StringID(), primary_key=True, nullable=False)
379 dag_id = Column(StringID(), primary_key=True, nullable=False)
380 run_id = Column(StringID(), primary_key=True, nullable=False)
381 map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
383 start_date = Column(UtcDateTime)
384 end_date = Column(UtcDateTime)
385 duration = Column(Float)
386 state = Column(String(20))
387 _try_number = Column("try_number", Integer, default=0)
388 max_tries = Column(Integer, server_default=text("-1"))
389 hostname = Column(String(1000))
390 unixname = Column(String(1000))
391 job_id = Column(Integer)
392 pool = Column(String(256), nullable=False)
393 pool_slots = Column(Integer, default=1, nullable=False)
394 queue = Column(String(256))
395 priority_weight = Column(Integer)
396 operator = Column(String(1000))
397 queued_dttm = Column(UtcDateTime)
398 queued_by_job_id = Column(Integer)
399 pid = Column(Integer)
400 executor_config = Column(ExecutorConfigType(pickler=dill))
401 updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
403 external_executor_id = Column(StringID())
405 # The trigger to resume on if we are in state DEFERRED
406 trigger_id = Column(Integer)
408 # Optional timeout datetime for the trigger (past this, we'll fail)
409 trigger_timeout = Column(DateTime)
410 # The trigger_timeout should be TIMESTAMP(using UtcDateTime) but for ease of
411 # migration, we are keeping it as DateTime pending a change where expensive
412 # migration is inevitable.
414 # The method to call next, and any extra arguments to pass to it.
415 # Usually used when resuming from DEFERRED.
416 next_method = Column(String(1000))
417 next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))
419 # If adding new fields here then remember to add them to
420 # refresh_from_db() or they won't display in the UI correctly
422 __table_args__ = (
423 Index("ti_dag_state", dag_id, state),
424 Index("ti_dag_run", dag_id, run_id),
425 Index("ti_state", state),
426 Index("ti_state_lkp", dag_id, task_id, run_id, state),
427 # The below index has been added to improve performance on postgres setups with tens of millions of
428 # taskinstance rows. Aim is to improve the below query (it can be used to find the last successful
429 # execution date of a task instance):
430 # SELECT start_date FROM task_instance WHERE dag_id = 'xx' AND task_id = 'yy' AND state = 'success'
431 # ORDER BY start_date DESC NULLS LAST LIMIT 1;
432 # Existing "ti_state_lkp" is not enough for such query when this table has millions of rows, since
433 # rows have to be fetched in order to retrieve the start_date column. With this index, INDEX ONLY SCAN
434 # is performed and that query runs within milliseconds.
435 Index("ti_state_incl_start_date", dag_id, task_id, state, postgresql_include=["start_date"]),
436 Index("ti_pool", pool, state, priority_weight),
437 Index("ti_job_id", job_id),
438 Index("ti_trigger_id", trigger_id),
439 PrimaryKeyConstraint(
440 "dag_id", "task_id", "run_id", "map_index", name="task_instance_pkey", mssql_clustered=True
441 ),
442 ForeignKeyConstraint(
443 [trigger_id],
444 ["trigger.id"],
445 name="task_instance_trigger_id_fkey",
446 ondelete="CASCADE",
447 ),
448 ForeignKeyConstraint(
449 [dag_id, run_id],
450 ["dag_run.dag_id", "dag_run.run_id"],
451 name="task_instance_dag_run_fkey",
452 ondelete="CASCADE",
453 ),
454 )
456 dag_model = relationship(
457 "DagModel",
458 primaryjoin="TaskInstance.dag_id == DagModel.dag_id",
459 foreign_keys=dag_id,
460 uselist=False,
461 innerjoin=True,
462 viewonly=True,
463 )
465 trigger = relationship("Trigger", uselist=False, back_populates="task_instance")
466 triggerer_job = association_proxy("trigger", "triggerer_job")
467 dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True)
468 rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False)
469 execution_date = association_proxy("dag_run", "execution_date")
470 task_instance_note = relationship(
471 "TaskInstanceNote",
472 back_populates="task_instance",
473 uselist=False,
474 cascade="all, delete, delete-orphan",
475 )
476 note = association_proxy("task_instance_note", "content", creator=_creator_note)
477 task: Operator # Not always set...
479 is_trigger_log_context: bool = False
480 """Indicate to FileTaskHandler that logging context should be set up for trigger logging.
482 :meta private:
483 """
485 def __init__(
486 self,
487 task: Operator,
488 execution_date: datetime | None = None,
489 run_id: str | None = None,
490 state: str | None = None,
491 map_index: int = -1,
492 ):
493 super().__init__()
494 self.dag_id = task.dag_id
495 self.task_id = task.task_id
496 self.map_index = map_index
497 self.refresh_from_task(task)
498 # init_on_load will config the log
499 self.init_on_load()
501 if run_id is None and execution_date is not None:
502 from airflow.models.dagrun import DagRun # Avoid circular import
504 warnings.warn(
505 "Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id",
506 RemovedInAirflow3Warning,
507 # Stack level is 4 because SQLA adds some wrappers around the constructor
508 stacklevel=4,
509 )
510 # make sure we have a localized execution_date stored in UTC
511 if execution_date and not timezone.is_localized(execution_date):
512 self.log.warning(
513 "execution date %s has no timezone information. Using default from dag or system",
514 execution_date,
515 )
516 if self.task.has_dag():
517 if TYPE_CHECKING:
518 assert self.task.dag
519 execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
520 else:
521 execution_date = timezone.make_aware(execution_date)
523 execution_date = timezone.convert_to_utc(execution_date)
524 with create_session() as session:
525 run_id = (
526 session.query(DagRun.run_id)
527 .filter_by(dag_id=self.dag_id, execution_date=execution_date)
528 .scalar()
529 )
530 if run_id is None:
531 raise DagRunNotFound(
532 f"DagRun for {self.dag_id!r} with date {execution_date} not found"
533 ) from None
535 self.run_id = run_id
537 self.try_number = 0
538 self.max_tries = self.task.retries
539 self.unixname = getuser()
540 if state:
541 self.state = state
542 self.hostname = ""
543 # Is this TaskInstance being currently running within `airflow tasks run --raw`.
544 # Not persisted to the database so only valid for the current process
545 self.raw = False
546 # can be changed when calling 'run'
547 self.test_mode = False
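    # --- Illustrative note (editorial addition, not part of the original class) ---
    # Preferred construction passes a run_id; passing execution_date takes the
    # deprecated DagRun-lookup path above. `my_task` is hypothetical:
    #   ti = TaskInstance(task=my_task, run_id="manual__2023-06-01T00:00:00+00:00")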
549 @property
550 def stats_tags(self) -> dict[str, str]:
551 return prune_dict({"dag_id": self.dag_id, "task_id": self.task_id})
553 @staticmethod
554 def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any]:
555 """Insert mapping.
557 :meta private:
558 """
559 return {
560 "dag_id": task.dag_id,
561 "task_id": task.task_id,
562 "run_id": run_id,
563 "_try_number": 0,
564 "hostname": "",
565 "unixname": getuser(),
566 "queue": task.queue,
567 "pool": task.pool,
568 "pool_slots": task.pool_slots,
569 "priority_weight": task.priority_weight_total,
570 "run_as_user": task.run_as_user,
571 "max_tries": task.retries,
572 "executor_config": task.executor_config,
573 "operator": task.task_type,
574 "map_index": map_index,
575 }
577 @reconstructor
578 def init_on_load(self) -> None:
579 """Initialize the attributes that aren't stored in the DB."""
580 # correctly config the ti log
581 self._log = logging.getLogger("airflow.task")
582 self.test_mode = False # can be changed when calling 'run'
584 @hybrid_property
585 def try_number(self):
586 """
587 Return the try number that this task instance will have when it is actually
588 run.
590 If the TaskInstance is currently running, this will match the column in the
591 database; in all other cases this will be incremented.
592 """
593 # This is designed so that task logs end up in the right file.
594 if self.state == State.RUNNING:
595 return self._try_number
596 return self._try_number + 1
598 @try_number.setter
599 def try_number(self, value: int) -> None:
600 self._try_number = value
602 @property
603 def prev_attempted_tries(self) -> int:
604 """
605 Based on this instance's try_number, this will calculate
606 the number of previously attempted tries, defaulting to 0.
607 """
608 # Expose this for the Task Tries and Gantt graph views.
609 # Using `try_number` throws off the counts for non-running tasks.
610 # Also useful in error logging contexts to get
611 # the try number for the last try that was attempted.
612 # https://issues.apache.org/jira/browse/AIRFLOW-2143
614 return self._try_number
616 @property
617 def next_try_number(self) -> int:
618 return self._try_number + 1
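    # --- Illustrative sketch (editorial addition, not part of the original class) ---
    @staticmethod
    def _example_try_number_semantics(ti: TaskInstance) -> None:
        """How the accessors above relate to the stored _try_number column."""
        ti._try_number = 1          # one attempt already recorded in the DB column
        ti.state = None
        assert ti.try_number == 2   # next attempt number when not RUNNING
        assert ti.prev_attempted_tries == 1
        assert ti.next_try_number == 2
        ti.state = State.RUNNING
        assert ti.try_number == 1   # matches the raw column while RUNNING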
620 def command_as_list(
621 self,
622 mark_success=False,
623 ignore_all_deps=False,
624 ignore_task_deps=False,
625 ignore_depends_on_past=False,
626 wait_for_past_depends_before_skipping=False,
627 ignore_ti_state=False,
628 local=False,
629 pickle_id: int | None = None,
630 raw=False,
631 job_id=None,
632 pool=None,
633 cfg_path=None,
634 ) -> list[str]:
635 """
636 Returns a command that can be executed anywhere Airflow is
637 installed. This command is part of the message sent to executors by
638 the orchestrator.
639 """
640 dag: DAG | DagModel
641 # Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
642 if hasattr(self, "task") and hasattr(self.task, "dag") and self.task.dag is not None:
643 dag = self.task.dag
644 else:
645 dag = self.dag_model
647 should_pass_filepath = not pickle_id and dag
648 path: PurePath | None = None
649 if should_pass_filepath:
650 if dag.is_subdag:
651 if TYPE_CHECKING:
652 assert dag.parent_dag is not None
653 path = dag.parent_dag.relative_fileloc
654 else:
655 path = dag.relative_fileloc
657 if path:
658 if not path.is_absolute():
659 path = "DAGS_FOLDER" / path
661 return TaskInstance.generate_command(
662 self.dag_id,
663 self.task_id,
664 run_id=self.run_id,
665 mark_success=mark_success,
666 ignore_all_deps=ignore_all_deps,
667 ignore_task_deps=ignore_task_deps,
668 ignore_depends_on_past=ignore_depends_on_past,
669 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
670 ignore_ti_state=ignore_ti_state,
671 local=local,
672 pickle_id=pickle_id,
673 file_path=path,
674 raw=raw,
675 job_id=job_id,
676 pool=pool,
677 cfg_path=cfg_path,
678 map_index=self.map_index,
679 )
681 @staticmethod
682 def generate_command(
683 dag_id: str,
684 task_id: str,
685 run_id: str,
686 mark_success: bool = False,
687 ignore_all_deps: bool = False,
688 ignore_depends_on_past: bool = False,
689 wait_for_past_depends_before_skipping: bool = False,
690 ignore_task_deps: bool = False,
691 ignore_ti_state: bool = False,
692 local: bool = False,
693 pickle_id: int | None = None,
694 file_path: PurePath | str | None = None,
695 raw: bool = False,
696 job_id: str | None = None,
697 pool: str | None = None,
698 cfg_path: str | None = None,
699 map_index: int = -1,
700 ) -> list[str]:
701 """
702 Generates the shell command required to execute this task instance.
704 :param dag_id: DAG ID
705 :param task_id: Task ID
706 :param run_id: The run_id of this task's DagRun
707 :param mark_success: Whether to mark the task as successful
708 :param ignore_all_deps: Ignore all ignorable dependencies.
709 Overrides the other ignore_* parameters.
710 :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs
711 (e.g. for Backfills)
712 :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped
713 :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past
714 and trigger rule
715 :param ignore_ti_state: Ignore the task instance's previous failure/success
716 :param local: Whether to run the task locally
717 :param pickle_id: If the DAG was serialized to the DB, the ID
718 associated with the pickled DAG
719 :param file_path: path to the file containing the DAG definition
720 :param raw: raw mode (needs more details)
721 :param job_id: job ID (needs more details)
722 :param pool: the Airflow pool that the task should run in
723 :param cfg_path: the Path to the configuration file
724 :return: shell command that can be used to run the task instance
725 """
726 cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id]
727 if mark_success:
728 cmd.extend(["--mark-success"])
729 if pickle_id:
730 cmd.extend(["--pickle", str(pickle_id)])
731 if job_id:
732 cmd.extend(["--job-id", str(job_id)])
733 if ignore_all_deps:
734 cmd.extend(["--ignore-all-dependencies"])
735 if ignore_task_deps:
736 cmd.extend(["--ignore-dependencies"])
737 if ignore_depends_on_past:
738 cmd.extend(["--depends-on-past", "ignore"])
739 elif wait_for_past_depends_before_skipping:
740 cmd.extend(["--depends-on-past", "wait"])
741 if ignore_ti_state:
742 cmd.extend(["--force"])
743 if local:
744 cmd.extend(["--local"])
745 if pool:
746 cmd.extend(["--pool", pool])
747 if raw:
748 cmd.extend(["--raw"])
749 if file_path:
750 cmd.extend(["--subdir", os.fspath(file_path)])
751 if cfg_path:
752 cmd.extend(["--cfg-path", cfg_path])
753 if map_index != -1:
754 cmd.extend(["--map-index", str(map_index)])
755 return cmd
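    # --- Illustrative sketch (editorial addition, not part of the original class) ---
    @staticmethod
    def _example_generate_command() -> list[str]:
        """Command list produced for a hypothetical mapped task instance.

        Expected result:
        ["airflow", "tasks", "run", "example_dag", "example_task",
         "manual__2023-06-01T00:00:00+00:00", "--local",
         "--pool", "default_pool", "--map-index", "3"]
        """
        return TaskInstance.generate_command(
            "example_dag",
            "example_task",
            run_id="manual__2023-06-01T00:00:00+00:00",
            local=True,
            pool="default_pool",
            map_index=3,
        )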
757 @property
758 def log_url(self) -> str:
759 """Log URL for TaskInstance."""
760 iso = quote(self.execution_date.isoformat())
761 base_url = conf.get_mandatory_value("webserver", "BASE_URL")
762 return (
763 f"{base_url}/log"
764 f"?execution_date={iso}"
765 f"&task_id={self.task_id}"
766 f"&dag_id={self.dag_id}"
767 f"&map_index={self.map_index}"
768 )
770 @property
771 def mark_success_url(self) -> str:
772 """URL to mark TI success."""
773 base_url = conf.get_mandatory_value("webserver", "BASE_URL")
774 return (
775 f"{base_url}/confirm"
776 f"?task_id={self.task_id}"
777 f"&dag_id={self.dag_id}"
778 f"&dag_run_id={quote(self.run_id)}"
779 "&upstream=false"
780 "&downstream=false"
781 "&state=success"
782 )
784 @provide_session
785 def current_state(self, session: Session = NEW_SESSION) -> str:
786 """
787 Get the very latest state from the database. If a session is passed, we use
788 it and looking up the state becomes part of that session; otherwise a new
789 session is used.
791 sqlalchemy.inspect is used here to get the primary keys, ensuring that if they
792 change this method will not regress.
794 :param session: SQLAlchemy ORM Session
795 """
796 filters = (col == getattr(self, col.name) for col in inspect(TaskInstance).primary_key)
797 return session.query(TaskInstance.state).filter(*filters).scalar()
799 @provide_session
800 def error(self, session: Session = NEW_SESSION) -> None:
801 """
802 Forces the task instance's state to FAILED in the database.
804 :param session: SQLAlchemy ORM Session
805 """
806 self.log.error("Recording the task instance as FAILED")
807 self.state = State.FAILED
808 session.merge(self)
809 session.commit()
811 @provide_session
812 def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
813 """
814 Refreshes the task instance from the database based on the primary key.
816 :param session: SQLAlchemy ORM Session
817 :param lock_for_update: if True, indicates that the database should
818 lock the TaskInstance (issuing a FOR UPDATE clause) until the
819 session is committed.
820 """
821 self.log.debug("Refreshing TaskInstance %s from DB", self)
823 if self in session:
824 session.refresh(self, TaskInstance.__mapper__.column_attrs.keys())
826 qry = (
827 # To avoid joining any relationships, by default select all
828 # columns, not the object. This also means we get (effectively) a
829 # namedtuple back, not a TI object
830 session.query(*TaskInstance.__table__.columns).filter(
831 TaskInstance.dag_id == self.dag_id,
832 TaskInstance.task_id == self.task_id,
833 TaskInstance.run_id == self.run_id,
834 TaskInstance.map_index == self.map_index,
835 )
836 )
838 if lock_for_update:
839 for attempt in run_with_db_retries(logger=self.log):
840 with attempt:
841 ti: TaskInstance | None = qry.with_for_update().one_or_none()
842 else:
843 ti = qry.one_or_none()
844 if ti:
845 # Fields ordered per model definition
846 self.start_date = ti.start_date
847 self.end_date = ti.end_date
848 self.duration = ti.duration
849 self.state = ti.state
850 # Since we selected columns, not the object, this is the raw value
851 self.try_number = ti.try_number
852 self.max_tries = ti.max_tries
853 self.hostname = ti.hostname
854 self.unixname = ti.unixname
855 self.job_id = ti.job_id
856 self.pool = ti.pool
857 self.pool_slots = ti.pool_slots or 1
858 self.queue = ti.queue
859 self.priority_weight = ti.priority_weight
860 self.operator = ti.operator
861 self.queued_dttm = ti.queued_dttm
862 self.queued_by_job_id = ti.queued_by_job_id
863 self.pid = ti.pid
864 self.executor_config = ti.executor_config
865 self.external_executor_id = ti.external_executor_id
866 self.trigger_id = ti.trigger_id
867 self.next_method = ti.next_method
868 self.next_kwargs = ti.next_kwargs
869 else:
870 self.state = None
872 def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
873 """
874 Copy common attributes from the given task.
876 :param task: The task object to copy from
877 :param pool_override: Use the pool_override instead of task's pool
878 """
879 self.task = task
880 self.queue = task.queue
881 self.pool = pool_override or task.pool
882 self.pool_slots = task.pool_slots
883 self.priority_weight = task.priority_weight_total
884 self.run_as_user = task.run_as_user
885 # Do not set max_tries to task.retries here because max_tries is a cumulative
886 # value that needs to be stored in the db.
887 self.executor_config = task.executor_config
888 self.operator = task.task_type
890 @provide_session
891 def clear_xcom_data(self, session: Session = NEW_SESSION) -> None:
892 """Clear all XCom data from the database for the task instance.
894 If the task is unmapped, all XComs matching this task ID in the same DAG
895 run are removed. If the task is mapped, only the one with matching map
896 index is removed.
898 :param session: SQLAlchemy ORM Session
899 """
900 self.log.debug("Clearing XCom data")
901 if self.map_index < 0:
902 map_index: int | None = None
903 else:
904 map_index = self.map_index
905 XCom.clear(
906 dag_id=self.dag_id,
907 task_id=self.task_id,
908 run_id=self.run_id,
909 map_index=map_index,
910 session=session,
911 )
913 @property
914 def key(self) -> TaskInstanceKey:
915 """Returns a tuple that identifies the task instance uniquely."""
916 return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
918 @provide_session
919 def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
920 """
921 Set TaskInstance state.
923 :param state: State to set for the TI
924 :param session: SQLAlchemy ORM Session
925 :return: Was the state changed
926 """
927 if self.state == state:
928 return False
930 current_time = timezone.utcnow()
931 self.log.debug("Setting task state for %s to %s", self, state)
932 self.state = state
933 self.start_date = self.start_date or current_time
934 if self.state in State.finished or self.state == State.UP_FOR_RETRY:
935 self.end_date = self.end_date or current_time
936 self.duration = (self.end_date - self.start_date).total_seconds()
937 session.merge(self)
938 return True
940 @property
941 def is_premature(self) -> bool:
942 """
943 Returns whether a task is in UP_FOR_RETRY state and its retry interval
944 has not yet elapsed (i.e. the task is still waiting before it can be retried).
945 """
946 # is the task still in the retry waiting period?
947 return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()
949 @provide_session
950 def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
951 """
952 Checks whether the immediate dependents of this task instance have succeeded or have been skipped.
953 This is meant to be used by wait_for_downstream.
955 This is useful when you do not want to start processing the next
956 schedule of a task until the dependents are done. For instance,
957 if the task DROPs and recreates a table.
959 :param session: SQLAlchemy ORM Session
960 """
961 task = self.task
963 if not task.downstream_task_ids:
964 return True
966 ti = session.query(func.count(TaskInstance.task_id)).filter(
967 TaskInstance.dag_id == self.dag_id,
968 TaskInstance.task_id.in_(task.downstream_task_ids),
969 TaskInstance.run_id == self.run_id,
970 TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]),
971 )
972 count = ti[0][0]
973 return count == len(task.downstream_task_ids)
975 @provide_session
976 def get_previous_dagrun(
977 self,
978 state: DagRunState | None = None,
979 session: Session | None = None,
980 ) -> DagRun | None:
981 """The DagRun that ran before this task instance's DagRun.
983 :param state: If passed, it only takes into account instances of a specific state.
984 :param session: SQLAlchemy ORM Session.
985 """
986 dag = self.task.dag
987 if dag is None:
988 return None
990 dr = self.get_dagrun(session=session)
991 dr.dag = dag
993 # We always ignore schedule in dagrun lookup when `state` is given
994 # or the DAG is never scheduled. For legacy reasons, when
995 # `catchup=True`, we use `get_previous_scheduled_dagrun` unless
996 # `ignore_schedule` is `True`.
997 ignore_schedule = state is not None or not dag.timetable.can_be_scheduled
998 if dag.catchup is True and not ignore_schedule:
999 last_dagrun = dr.get_previous_scheduled_dagrun(session=session)
1000 else:
1001 last_dagrun = dr.get_previous_dagrun(session=session, state=state)
1003 if last_dagrun:
1004 return last_dagrun
1006 return None
1008 @provide_session
1009 def get_previous_ti(
1010 self,
1011 state: DagRunState | None = None,
1012 session: Session = NEW_SESSION,
1013 ) -> TaskInstance | None:
1014 """
1015 The task instance for the task that ran before this task instance.
1017 :param state: If passed, it only takes into account instances of a specific state.
1018 :param session: SQLAlchemy ORM Session
1019 """
1020 dagrun = self.get_previous_dagrun(state, session=session)
1021 if dagrun is None:
1022 return None
1023 return dagrun.get_task_instance(self.task_id, session=session)
1025 @property
1026 def previous_ti(self) -> TaskInstance | None:
1027 """
1028 This attribute is deprecated.
1029 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
1030 """
1031 warnings.warn(
1032 """
1033 This attribute is deprecated.
1034 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
1035 """,
1036 RemovedInAirflow3Warning,
1037 stacklevel=2,
1038 )
1039 return self.get_previous_ti()
1041 @property
1042 def previous_ti_success(self) -> TaskInstance | None:
1043 """
1044 This attribute is deprecated.
1045 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
1046 """
1047 warnings.warn(
1048 """
1049 This attribute is deprecated.
1050 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
1051 """,
1052 RemovedInAirflow3Warning,
1053 stacklevel=2,
1054 )
1055 return self.get_previous_ti(state=DagRunState.SUCCESS)
1057 @provide_session
1058 def get_previous_execution_date(
1059 self,
1060 state: DagRunState | None = None,
1061 session: Session = NEW_SESSION,
1062 ) -> pendulum.DateTime | None:
1063 """
1064 The execution date from property previous_ti_success.
1066 :param state: If passed, it only takes into account instances of a specific state.
1067 :param session: SQLAlchemy ORM Session
1068 """
1069 self.log.debug("previous_execution_date was called")
1070 prev_ti = self.get_previous_ti(state=state, session=session)
1071 return prev_ti and pendulum.instance(prev_ti.execution_date)
1073 @provide_session
1074 def get_previous_start_date(
1075 self, state: DagRunState | None = None, session: Session = NEW_SESSION
1076 ) -> pendulum.DateTime | None:
1077 """
1078 The start date from property previous_ti_success.
1080 :param state: If passed, it only takes into account instances of a specific state.
1081 :param session: SQLAlchemy ORM Session
1082 """
1083 self.log.debug("previous_start_date was called")
1084 prev_ti = self.get_previous_ti(state=state, session=session)
1085 # prev_ti may not exist and prev_ti.start_date may be None.
1086 return prev_ti and prev_ti.start_date and pendulum.instance(prev_ti.start_date)
1088 @property
1089 def previous_start_date_success(self) -> pendulum.DateTime | None:
1090 """
1091 This attribute is deprecated.
1092 Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method.
1093 """
1094 warnings.warn(
1095 """
1096 This attribute is deprecated.
1097 Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method.
1098 """,
1099 RemovedInAirflow3Warning,
1100 stacklevel=2,
1101 )
1102 return self.get_previous_start_date(state=DagRunState.SUCCESS)
1104 @provide_session
1105 def are_dependencies_met(
1106 self, dep_context: DepContext | None = None, session: Session = NEW_SESSION, verbose: bool = False
1107 ) -> bool:
1108 """
1109 Returns whether or not all the conditions are met for this task instance to be run
1110 given the context for the dependencies (e.g. a task instance being force run from
1111 the UI will ignore some dependencies).
1113 :param dep_context: The execution context that determines the dependencies that
1114 should be evaluated.
1115 :param session: database session
1116 :param verbose: whether log details on failed dependencies on
1117 info or debug log level
1118 """
1119 dep_context = dep_context or DepContext()
1120 failed = False
1121 verbose_aware_logger = self.log.info if verbose else self.log.debug
1122 for dep_status in self.get_failed_dep_statuses(dep_context=dep_context, session=session):
1123 failed = True
1125 verbose_aware_logger(
1126 "Dependencies not met for %s, dependency '%s' FAILED: %s",
1127 self,
1128 dep_status.dep_name,
1129 dep_status.reason,
1130 )
1132 if failed:
1133 return False
1135 verbose_aware_logger("Dependencies all met for dep_context=%s ti=%s", dep_context.description, self)
1136 return True
1138 @provide_session
1139 def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION):
1140 """Get failed Dependencies."""
1141 dep_context = dep_context or DepContext()
1142 for dep in dep_context.deps | self.task.deps:
1143 for dep_status in dep.get_dep_statuses(self, session, dep_context):
1144 self.log.debug(
1145 "%s dependency '%s' PASSED: %s, %s",
1146 self,
1147 dep_status.dep_name,
1148 dep_status.passed,
1149 dep_status.reason,
1150 )
1152 if not dep_status.passed:
1153 yield dep_status
1155 def __repr__(self) -> str:
1156 prefix = f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} "
1157 if self.map_index != -1:
1158 prefix += f"map_index={self.map_index} "
1159 return prefix + f"[{self.state}]>"
1161 def next_retry_datetime(self):
1162 """
1163 Get datetime of the next retry if the task instance fails. For exponential
1164 backoff, retry_delay is used as base and will be converted to seconds.
1165 """
1166 from airflow.models.abstractoperator import MAX_RETRY_DELAY
1168 delay = self.task.retry_delay
1169 if self.task.retry_exponential_backoff:
1170 # If the min_backoff calculation is below 1, it will be converted to 0 via int. Thus,
1171 # we must round up prior to converting to an int, otherwise a divide by zero error
1172 # will occur in the modded_hash calculation.
1173 min_backoff = int(math.ceil(delay.total_seconds() * (2 ** (self.try_number - 2))))
1175 # In the case when delay.total_seconds() is 0, min_backoff will not be rounded up to 1.
1176 # To address this, we impose a lower bound of 1 on min_backoff. This effectively makes
1177 # the ceiling function unnecessary, but the ceiling function was retained to avoid
1178 # introducing a breaking change.
1179 if min_backoff < 1:
1180 min_backoff = 1
1182 # deterministic per task instance
1183 ti_hash = int(
1184 hashlib.sha1(
1185 f"{self.dag_id}#{self.task_id}#{self.execution_date}#{self.try_number}".encode()
1186 ).hexdigest(),
1187 16,
1188 )
1189 # jittered backoff: modded_hash lies in [min_backoff, 2 * min_backoff)
1190 modded_hash = min_backoff + ti_hash % min_backoff
1191 # timedelta has a maximum representable value. The exponentiation
1192 # here means this value can be exceeded after a certain number
1193 # of tries (around 50 if the initial delay is 1s, even fewer if
1194 # the delay is larger). Cap the value here before creating a
1195 # timedelta object so the operation doesn't fail with "OverflowError".
1196 delay_backoff_in_seconds = min(modded_hash, MAX_RETRY_DELAY)
1197 delay = timedelta(seconds=delay_backoff_in_seconds)
1198 if self.task.max_retry_delay:
1199 delay = min(self.task.max_retry_delay, delay)
1200 return self.end_date + delay
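    # --- Illustrative worked example (editorial addition, not part of the original class) ---
    @staticmethod
    def _example_backoff_bounds(retry_delay_seconds: int = 60, try_number: int = 3) -> tuple[int, int]:
        """Bounds of the jittered exponential backoff computed above.

        For a 60s retry_delay on the third try: min_backoff = ceil(60 * 2**(3 - 2)) = 120,
        and modded_hash falls in [min_backoff, 2 * min_backoff), i.e. 120-239 seconds,
        with the exact value deterministic per task instance via the SHA-1 hash.
        """
        min_backoff = max(1, int(math.ceil(retry_delay_seconds * (2 ** (try_number - 2)))))
        return min_backoff, 2 * min_backoff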
1202 def ready_for_retry(self) -> bool:
1203 """
1204 Checks whether the task instance is in the right state and timeframe
1205 to be retried.
1206 """
1207 return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()
1209 @provide_session
1210 def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
1211 """
1212 Returns the DagRun for this TaskInstance.
1214 :param session: SQLAlchemy ORM Session
1215 :return: DagRun
1216 """
1217 info = inspect(self)
1218 if info.attrs.dag_run.loaded_value is not NO_VALUE:
1219 if hasattr(self, "task"):
1220 self.dag_run.dag = self.task.dag
1221 return self.dag_run
1223 from airflow.models.dagrun import DagRun # Avoid circular import
1225 dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one()
1226 if hasattr(self, "task"):
1227 dr.dag = self.task.dag
1228 # Record it in the instance for next time. This means that `self.execution_date` will work correctly
1229 set_committed_value(self, "dag_run", dr)
1231 return dr
1233 @provide_session
1234 def check_and_change_state_before_execution(
1235 self,
1236 verbose: bool = True,
1237 ignore_all_deps: bool = False,
1238 ignore_depends_on_past: bool = False,
1239 wait_for_past_depends_before_skipping: bool = False,
1240 ignore_task_deps: bool = False,
1241 ignore_ti_state: bool = False,
1242 mark_success: bool = False,
1243 test_mode: bool = False,
1244 job_id: str | None = None,
1245 pool: str | None = None,
1246 external_executor_id: str | None = None,
1247 session: Session = NEW_SESSION,
1248 ) -> bool:
1249 """
1250 Checks dependencies and then sets state to RUNNING if they are met. Returns
1251 True if and only if state is set to RUNNING, which implies that the task should be
1252 executed, in preparation for _run_raw_task.
1254 :param verbose: whether to turn on more verbose logging
1255 :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs
1256 :param ignore_depends_on_past: Ignore depends_on_past DAG attribute
1257 :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped
1258 :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task
1259 :param ignore_ti_state: Disregards previous task instance state
1260 :param mark_success: Don't run the task, mark its state as success
1261 :param test_mode: Doesn't record success or failure in the DB
1262 :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID
1263 :param pool: specifies the pool to use to run the task instance
1264 :param external_executor_id: The identifier of the celery executor
1265 :param session: SQLAlchemy ORM Session
1266 :return: whether the state was changed to running or not
1267 """
1268 task = self.task
1269 self.refresh_from_task(task, pool_override=pool)
1270 self.test_mode = test_mode
1271 self.refresh_from_db(session=session, lock_for_update=True)
1272 self.job_id = job_id
1273 self.hostname = get_hostname()
1274 self.pid = None
1276 if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
1277 Stats.incr("previously_succeeded", tags=self.stats_tags)
1279 if not mark_success:
1280 # Firstly find non-runnable and non-requeueable tis.
1281 # Since mark_success is not set, we do nothing.
1282 non_requeueable_dep_context = DepContext(
1283 deps=RUNNING_DEPS - REQUEUEABLE_DEPS,
1284 ignore_all_deps=ignore_all_deps,
1285 ignore_ti_state=ignore_ti_state,
1286 ignore_depends_on_past=ignore_depends_on_past,
1287 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
1288 ignore_task_deps=ignore_task_deps,
1289 description="non-requeueable deps",
1290 )
1291 if not self.are_dependencies_met(
1292 dep_context=non_requeueable_dep_context, session=session, verbose=True
1293 ):
1294 session.commit()
1295 return False
1297 # For reporting purposes, we report based on 1-indexed,
1298 # not 0-indexed lists (i.e. Attempt 1 instead of
1299 # Attempt 0 for the first attempt).
1300 # Set the task start date. In case it was re-scheduled use the initial
1301 # start date that is recorded in task_reschedule table
1302 # If the task continues after being deferred (next_method is set), use the original start_date
1303 self.start_date = self.start_date if self.next_method else timezone.utcnow()
1304 if self.state == State.UP_FOR_RESCHEDULE:
1305 task_reschedule: TR = TR.query_for_task_instance(self, session=session).first()
1306 if task_reschedule:
1307 self.start_date = task_reschedule.start_date
1309 # Secondly we find non-runnable but requeueable tis. We reset its state.
1310 # This is because we might have hit concurrency limits,
1311 # e.g. because of backfilling.
1312 dep_context = DepContext(
1313 deps=REQUEUEABLE_DEPS,
1314 ignore_all_deps=ignore_all_deps,
1315 ignore_depends_on_past=ignore_depends_on_past,
1316 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
1317 ignore_task_deps=ignore_task_deps,
1318 ignore_ti_state=ignore_ti_state,
1319 description="requeueable deps",
1320 )
1321 if not self.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
1322 self.state = State.NONE
1323 self.log.warning(
1324 "Rescheduling due to concurrency limits reached "
1325 "at task runtime. Attempt %s of "
1326 "%s. State set to NONE.",
1327 self.try_number,
1328 self.max_tries + 1,
1329 )
1330 self.queued_dttm = timezone.utcnow()
1331 session.merge(self)
1332 session.commit()
1333 return False
1335 if self.next_kwargs is not None:
1336 self.log.info("Resuming after deferral")
1337 else:
1338 self.log.info("Starting attempt %s of %s", self.try_number, self.max_tries + 1)
1339 self._try_number += 1
1341 if not test_mode:
1342 session.add(Log(State.RUNNING, self))
1344 self.state = State.RUNNING
1345 self.emit_state_change_metric(State.RUNNING)
1346 self.external_executor_id = external_executor_id
1347 self.end_date = None
1348 if not test_mode:
1349 session.merge(self).task = task
1350 session.commit()
1352 # Closing all pooled connections to prevent
1353 # "max number of connections reached"
1354 settings.engine.dispose() # type: ignore
1355 if verbose:
1356 if mark_success:
1357 self.log.info("Marking success for %s on %s", self.task, self.execution_date)
1358 else:
1359 self.log.info("Executing %s on %s", self.task, self.execution_date)
1360 return True
1362 def _date_or_empty(self, attr: str) -> str:
1363 result: datetime | None = getattr(self, attr, None)
1364 return result.strftime("%Y%m%dT%H%M%S") if result else ""
1366 def _log_state(self, lead_msg: str = "") -> None:
1367 params = [
1368 lead_msg,
1369 str(self.state).upper(),
1370 self.dag_id,
1371 self.task_id,
1372 ]
1373 message = "%sMarking task as %s. dag_id=%s, task_id=%s, "
1374 if self.map_index >= 0:
1375 params.append(self.map_index)
1376 message += "map_index=%d, "
1377 self.log.info(
1378 message + "execution_date=%s, start_date=%s, end_date=%s",
1379 *params,
1380 self._date_or_empty("execution_date"),
1381 self._date_or_empty("start_date"),
1382 self._date_or_empty("end_date"),
1383 )
1385 def emit_state_change_metric(self, new_state: TaskInstanceState):
1386 """
1387 Sends a time metric representing how much time a given state transition took.
1388 The previous state and metric name are deduced from the state the task was put in.
1390 :param new_state: The state that has just been set for this task.
1391 We do not use `self.state`, because sometimes the state is updated directly in the DB and not in
1392 the local TaskInstance object.
1393 Supported states: QUEUED and RUNNING
1394 """
1395 if self.end_date:
1396 # if the task has an end date, it means that this is not its first round.
1397 # we send the state transition time metric only on the first try, otherwise it gets more complex.
1398 return
1400 # switch on state and deduce which metric to send
1401 if new_state == State.RUNNING:
1402 metric_name = "queued_duration"
1403 if self.queued_dttm is None:
1404 # this should not really happen except in tests or rare cases,
1405 # but we don't want to create errors just for a metric, so we just skip it
1406 self.log.warning(
1407 "cannot record %s for task %s because previous state change time has not been saved",
1408 metric_name,
1409 self.task_id,
1410 )
1411 return
1412 timing = (timezone.utcnow() - self.queued_dttm).total_seconds()
1413 elif new_state == State.QUEUED:
1414 metric_name = "scheduled_duration"
1415 if self.start_date is None:
1416 # same comment as above
1417 self.log.warning(
1418 "cannot record %s for task %s because previous state change time has not been saved",
1419 metric_name,
1420 self.task_id,
1421 )
1422 return
1423 timing = (timezone.utcnow() - self.start_date).total_seconds()
1424 else:
1425 raise NotImplementedError(f"no metric emission setup for state {new_state}")
1427 # send metric twice, once (legacy) with tags in the name and once with tags as tags
1428 Stats.timing(f"dag.{self.dag_id}.{self.task_id}.{metric_name}", timing)
1429 Stats.timing(f"task.{metric_name}", timing, tags={"task_id": self.task_id, "dag_id": self.dag_id})
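    # --- Illustrative note (editorial addition, not part of the original class) ---
    # For dag_id "example_dag" and task_id "example_task" moving to RUNNING, the
    # calls above emit:
    #   Stats.timing("dag.example_dag.example_task.queued_duration", <seconds>)
    #   Stats.timing("task.queued_duration", <seconds>,
    #                tags={"task_id": "example_task", "dag_id": "example_dag"})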
1431 # Ensure we unset next_method and next_kwargs to ensure that any
1432 # retries don't re-use them.
1433 def clear_next_method_args(self) -> None:
1434 self.log.debug("Clearing next_method and next_kwargs.")
1436 self.next_method = None
1437 self.next_kwargs = None
1439 @provide_session
1440 @Sentry.enrich_errors
1441 def _run_raw_task(
1442 self,
1443 mark_success: bool = False,
1444 test_mode: bool = False,
1445 job_id: str | None = None,
1446 pool: str | None = None,
1447 session: Session = NEW_SESSION,
1448 ) -> TaskReturnCode | None:
1449 """
1450 Immediately runs the task (without checking or changing db state
1451 before execution) and then sets the appropriate final state after
1452 completion and runs any post-execute callbacks. Meant to be called
1453 only after another function changes the state to running.
1455 :param mark_success: Don't run the task, mark its state as success
1456 :param test_mode: Doesn't record success or failure in the DB
1457 :param pool: specifies the pool to use to run the task instance
1458 :param session: SQLAlchemy ORM Session
1459 """
1460 self.test_mode = test_mode
1461 self.refresh_from_task(self.task, pool_override=pool)
1462 self.refresh_from_db(session=session)
1463 self.job_id = job_id
1464 self.hostname = get_hostname()
1465 self.pid = os.getpid()
1466 if not test_mode:
1467 session.merge(self)
1468 session.commit()
1469 actual_start_date = timezone.utcnow()
1470 Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}", tags=self.stats_tags)
1471 # Same metric with tagging
1472 Stats.incr("ti.start", tags=self.stats_tags)
1473 # Initialize final state counters at zero
1474 for state in State.task_states:
1475 Stats.incr(
1476 f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}",
1477 count=0,
1478 tags=self.stats_tags,
1479 )
1480 # Same metric with tagging
1481 Stats.incr(
1482 "ti.finish",
1483 count=0,
1484 tags={**self.stats_tags, "state": str(state)},
1485 )
1487 self.task = self.task.prepare_for_execution()
1488 context = self.get_template_context(ignore_param_exceptions=False)
1490 # We lose the previous state because it's changed in another process by LocalTaskJob.
1491 # We could probably pass it through here though...
1492 get_listener_manager().hook.on_task_instance_running(
1493 previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
1494 )
1495 try:
1496 if not mark_success:
1497 self._execute_task_with_callbacks(context, test_mode)
1498 if not test_mode:
1499 self.refresh_from_db(lock_for_update=True, session=session)
1500 self.state = State.SUCCESS
1501 except TaskDeferred as defer:
1502 # The task has signalled it wants to defer execution based on
1503 # a trigger.
1504 self._defer_task(defer=defer, session=session)
1505 self.log.info(
1506 "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
1507 self.dag_id,
1508 self.task_id,
1509 self._date_or_empty("execution_date"),
1510 self._date_or_empty("start_date"),
1511 )
1512 if not test_mode:
1513 session.add(Log(self.state, self))
1514 session.merge(self)
1515 session.commit()
1516 return TaskReturnCode.DEFERRED
1517 except AirflowSkipException as e:
1518 # Recording SKIP
1519 # log only if exception has any arguments to prevent log flooding
1520 if e.args:
1521 self.log.info(e)
1522 if not test_mode:
1523 self.refresh_from_db(lock_for_update=True, session=session)
1524 self.state = State.SKIPPED
1525 except AirflowRescheduleException as reschedule_exception:
1526 self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
1527 session.commit()
1528 return None
1529 except (AirflowFailException, AirflowSensorTimeout) as e:
1530 # If AirflowFailException is raised, task should not retry.
1531 # If a sensor in reschedule mode reaches timeout, task should not retry.
1532 self.handle_failure(e, test_mode, context, force_fail=True, session=session)
1533 session.commit()
1534 raise
1535 except AirflowException as e:
1536 if not test_mode:
1537 self.refresh_from_db(lock_for_update=True, session=session)
1538 # for case when task is marked as success/failed externally
1539 # or dagrun timed out and task is marked as skipped
1540 # current behavior doesn't hit the callbacks
1541 if self.state in State.finished:
1542 self.clear_next_method_args()
1543 session.merge(self)
1544 session.commit()
1545 return None
1546 else:
1547 self.handle_failure(e, test_mode, context, session=session)
1548 session.commit()
1549 raise
1550 except (Exception, KeyboardInterrupt) as e:
1551 self.handle_failure(e, test_mode, context, session=session)
1552 session.commit()
1553 raise
1554 finally:
1555 Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags)
1556 # Same metric with tagging
1557 Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)})
1559 # Recording SKIPPED or SUCCESS
1560 self.clear_next_method_args()
1561 self.end_date = timezone.utcnow()
1562 self._log_state()
1563 self.set_duration()
1565 # run on_success_callback before db committing
1566 # otherwise, the LocalTaskJob sees the state is changed to `success`,
1567 # but the task_runner is still running, and LocalTaskJob then treats the state as set externally!
1568 self._run_finished_callback(self.task.on_success_callback, context, "on_success")
1570 if not test_mode:
1571 session.add(Log(self.state, self))
1572 session.merge(self).task = self.task
1573 if self.state == TaskInstanceState.SUCCESS:
1574 self._register_dataset_changes(session=session)
1575 get_listener_manager().hook.on_task_instance_success(
1576 previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session
1577 )
1579 session.commit()
1580 return None
1582 def _register_dataset_changes(self, *, session: Session) -> None:
1583 for obj in self.task.outlets or []:
1584 self.log.debug("outlet obj %s", obj)
1585 # Lineage can have other types of objects besides datasets
1586 if isinstance(obj, Dataset):
1587 dataset_manager.register_dataset_change(
1588 task_instance=self,
1589 dataset=obj,
1590 session=session,
1591 )
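# Illustrative sketch (not part of the original module): the loop above only
# registers changes for Dataset outlets, so a producing task declares them on
# the operator. The URI, task_id and callable below are hypothetical examples.
#
#     from airflow.datasets import Dataset
#     from airflow.operators.python import PythonOperator
#
#     example_dataset = Dataset("s3://example-bucket/orders.csv")  # hypothetical URI
#     produce = PythonOperator(
#         task_id="produce_orders",
#         python_callable=write_orders,  # assumed user-defined callable
#         outlets=[example_dataset],     # registered on SUCCESS by _register_dataset_changes()
#     )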
1593 def _execute_task_with_callbacks(self, context, test_mode=False):
1594 """Prepare Task for Execution."""
1595 from airflow.models.renderedtifields import RenderedTaskInstanceFields
1597 parent_pid = os.getpid()
1599 def signal_handler(signum, frame):
1600 pid = os.getpid()
1602 # If a task forks during execution (from DAG code) for whatever
1603 # reason, we want to make sure that we react to the signal only in
1604 # the process that we've spawned ourselves (referred to here as the
1605 # parent process).
1606 if pid != parent_pid:
1607 os._exit(1)
1608 return
1609 self.log.error("Received SIGTERM. Terminating subprocesses.")
1610 self.task.on_kill()
1611 raise AirflowException("Task received SIGTERM signal")
1613 signal.signal(signal.SIGTERM, signal_handler)
1615 # Don't clear Xcom until the task is certain to execute, and check if we are resuming from deferral.
1616 if not self.next_method:
1617 self.clear_xcom_data()
1619 with Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration", tags=self.stats_tags):
1620 # Set the validated/merged params on the task object.
1621 self.task.params = context["params"]
1623 task_orig = self.render_templates(context=context)
1624 if not test_mode:
1625 rtif = RenderedTaskInstanceFields(ti=self, render_templates=False)
1626 RenderedTaskInstanceFields.write(rtif)
1627 RenderedTaskInstanceFields.delete_old_records(self.task_id, self.dag_id)
1629 # Export context to make it available for operators to use.
1630 airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
1631 os.environ.update(airflow_context_vars)
1633 # Log context only for the default execution method, the assumption
1634 # being that otherwise we're resuming a deferred task (in which
1635 # case there's no need to log these again).
1636 if not self.next_method:
1637 self.log.info(
1638 "Exporting env vars: %s",
1639 " ".join(f"{k}={v!r}" for k, v in airflow_context_vars.items()),
1640 )
1642 # Run pre_execute callback
1643 self.task.pre_execute(context=context)
1645 # Run on_execute callback
1646 self._run_execute_callback(context, self.task)
1648 # Execute the task
1649 with set_current_context(context):
1650 result = self._execute_task(context, task_orig)
1651 # Run post_execute callback
1652 self.task.post_execute(context=context, result=result)
1654 Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
1655 # Same metric with tagging
1656 Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
1657 Stats.incr("ti_successes", tags=self.stats_tags)
1659 def _run_finished_callback(
1660 self,
1661 callbacks: None | TaskStateChangeCallback | list[TaskStateChangeCallback],
1662 context: Context,
1663 callback_type: str,
1664 ) -> None:
1665 """Run callback after task finishes."""
1666 if callbacks:
1667 callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
1668 for callback in callbacks:
1669 try:
1670 callback(context)
1671 except Exception:
1672 callback_name = qualname(callback).split(".")[-1]
1673 self.log.exception(
1674 f"Error when executing {callback_name} callback" # type: ignore[attr-defined]
1675 )
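# Illustrative sketch (hypothetical names, not part of the original module):
# callbacks passed to _run_finished_callback receive the template context as
# their only argument and may be a single callable or a list of callables.
#
#     def notify_success(context):
#         ti = context["ti"]
#         print(f"{ti.dag_id}.{ti.task_id} finished in state {ti.state}")
#
#     task = PythonOperator(
#         task_id="work",
#         python_callable=do_work,             # assumed user-defined callable
#         on_success_callback=notify_success,  # or a list: [notify_success, ...]
#     )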
1677 def _execute_task(self, context, task_orig):
1678 """Execute the task (optionally with a timeout) and push XCom results."""
1679 task_to_execute = self.task
1680 # If the task has been deferred and is being executed due to a trigger,
1681 # then we need to pick the right method to come back to, otherwise
1682 # we go for the default execute
1683 if self.next_method:
1684 # __fail__ is a special signal value for next_method that indicates
1685 # this task was scheduled specifically to fail.
1686 if self.next_method == "__fail__":
1687 next_kwargs = self.next_kwargs or {}
1688 traceback = self.next_kwargs.get("traceback")
1689 if traceback is not None:
1690 self.log.error("Trigger failed:\n%s", "\n".join(traceback))
1691 raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
1692 # Grab the callable off the Operator/Task and add in any kwargs
1693 execute_callable = getattr(task_to_execute, self.next_method)
1694 if self.next_kwargs:
1695 execute_callable = partial(execute_callable, **self.next_kwargs)
1696 else:
1697 execute_callable = task_to_execute.execute
1698 # If a timeout is specified for the task, make it fail
1699 # if it goes beyond
1700 if task_to_execute.execution_timeout:
1701 # If we are coming in with a next_method (i.e. from a deferral),
1702 # calculate the timeout from our start_date.
1703 if self.next_method:
1704 timeout_seconds = (
1705 task_to_execute.execution_timeout - (timezone.utcnow() - self.start_date)
1706 ).total_seconds()
1707 else:
1708 timeout_seconds = task_to_execute.execution_timeout.total_seconds()
1709 try:
1710 # It's possible we're already timed out, so fast-fail if true
1711 if timeout_seconds <= 0:
1712 raise AirflowTaskTimeout()
1713 # Run task in timeout wrapper
1714 with timeout(timeout_seconds):
1715 result = execute_callable(context=context)
1716 except AirflowTaskTimeout:
1717 task_to_execute.on_kill()
1718 raise
1719 else:
1720 result = execute_callable(context=context)
1721 with create_session() as session:
1722 if task_to_execute.do_xcom_push:
1723 xcom_value = result
1724 else:
1725 xcom_value = None
1726 if xcom_value is not None: # If the task returns a result, push an XCom containing it.
1727 self.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session)
1728 self._record_task_map_for_downstreams(task_orig, xcom_value, session=session)
1729 return result
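# Illustrative sketch (hypothetical names): how the branches above are usually
# exercised from DAG code. A task with execution_timeout runs inside the
# timeout() wrapper, and with do_xcom_push=True (the default) its return value
# is pushed under XCOM_RETURN_KEY.
#
#     from datetime import timedelta
#
#     slow = PythonOperator(
#         task_id="slow_job",
#         python_callable=long_running_job,         # assumed user-defined callable
#         execution_timeout=timedelta(minutes=10),  # AirflowTaskTimeout past 10 minutes
#         do_xcom_push=True,                        # return value becomes an XCom
#     )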
1731 @provide_session
1732 def _defer_task(self, session: Session, defer: TaskDeferred) -> None:
1733 """
1734 Marks the task as deferred and sets up the trigger that is needed
1735 to resume it.
1736 """
1737 from airflow.models.trigger import Trigger
1739 # First, make the trigger entry
1740 trigger_row = Trigger.from_object(defer.trigger)
1741 session.add(trigger_row)
1742 session.flush()
1744 # Then, update ourselves so it matches the deferral request
1745 # Keep an eye on the logic in `check_and_change_state_before_execution()`
1746 # depending on self.next_method semantics
1747 self.state = State.DEFERRED
1748 self.trigger_id = trigger_row.id
1749 self.next_method = defer.method_name
1750 self.next_kwargs = defer.kwargs or {}
1752 # Decrement try number so the next one is the same try
1753 self._try_number -= 1
1755 # Calculate timeout too if it was passed
1756 if defer.timeout is not None:
1757 self.trigger_timeout = timezone.utcnow() + defer.timeout
1758 else:
1759 self.trigger_timeout = None
1761 # If an execution_timeout is set, set the timeout to the minimum of
1762 # it and the trigger timeout
1763 execution_timeout = self.task.execution_timeout
1764 if execution_timeout:
1765 if self.trigger_timeout:
1766 self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout)
1767 else:
1768 self.trigger_timeout = self.start_date + execution_timeout
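# Illustrative sketch (hypothetical operator, not part of the original module):
# the TaskDeferred handled above is normally raised via BaseOperator.defer()
# from a deferrable operator; method_name becomes self.next_method once the
# trigger fires.
#
#     from datetime import timedelta
#     from airflow.models.baseoperator import BaseOperator
#     from airflow.triggers.temporal import TimeDeltaTrigger
#
#     class WaitOneHourOperator(BaseOperator):
#         def execute(self, context):
#             self.defer(
#                 trigger=TimeDeltaTrigger(timedelta(hours=1)),
#                 method_name="execute_complete",
#             )
#
#         def execute_complete(self, context, event=None):
#             return "done"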
1770 def _run_execute_callback(self, context: Context, task: Operator) -> None:
1771 """Run the on_execute callbacks that need to run before a Task is executed."""
1772 callbacks = task.on_execute_callback
1773 if callbacks:
1774 callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
1775 for callback in callbacks:
1776 try:
1777 callback(context)
1778 except Exception:
1779 self.log.exception("Failed when executing execute callback")
1781 @provide_session
1782 def run(
1783 self,
1784 verbose: bool = True,
1785 ignore_all_deps: bool = False,
1786 ignore_depends_on_past: bool = False,
1787 wait_for_past_depends_before_skipping: bool = False,
1788 ignore_task_deps: bool = False,
1789 ignore_ti_state: bool = False,
1790 mark_success: bool = False,
1791 test_mode: bool = False,
1792 job_id: str | None = None,
1793 pool: str | None = None,
1794 session: Session = NEW_SESSION,
1795 ) -> None:
1796 """Run TaskInstance."""
1797 res = self.check_and_change_state_before_execution(
1798 verbose=verbose,
1799 ignore_all_deps=ignore_all_deps,
1800 ignore_depends_on_past=ignore_depends_on_past,
1801 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
1802 ignore_task_deps=ignore_task_deps,
1803 ignore_ti_state=ignore_ti_state,
1804 mark_success=mark_success,
1805 test_mode=test_mode,
1806 job_id=job_id,
1807 pool=pool,
1808 session=session,
1809 )
1810 if not res:
1811 return
1813 self._run_raw_task(
1814 mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session
1815 )
1817 def dry_run(self) -> None:
1818 """Only Renders Templates for the TI."""
1819 from airflow.models.baseoperator import BaseOperator
1821 self.task = self.task.prepare_for_execution()
1822 self.render_templates()
1823 if TYPE_CHECKING:
1824 assert isinstance(self.task, BaseOperator)
1825 self.task.dry_run()
1827 @provide_session
1828 def _handle_reschedule(
1829 self, actual_start_date, reschedule_exception, test_mode=False, session=NEW_SESSION
1830 ):
1831 # Don't record reschedule request in test mode
1832 if test_mode:
1833 return
1835 from airflow.models.dagrun import DagRun # Avoid circular import
1837 self.refresh_from_db(session)
1839 self.end_date = timezone.utcnow()
1840 self.set_duration()
1842 # Lock DAG run to be sure not to get into a deadlock situation when trying to insert
1843 # TaskReschedule which apparently also creates lock on corresponding DagRun entity
1844 with_row_locks(
1845 session.query(DagRun).filter_by(
1846 dag_id=self.dag_id,
1847 run_id=self.run_id,
1848 ),
1849 session=session,
1850 ).one()
1852 # Log reschedule request
1853 session.add(
1854 TaskReschedule(
1855 self.task,
1856 self.run_id,
1857 self._try_number,
1858 actual_start_date,
1859 self.end_date,
1860 reschedule_exception.reschedule_date,
1861 self.map_index,
1862 )
1863 )
1865 # set state
1866 self.state = State.UP_FOR_RESCHEDULE
1868 # Decrement try_number so subsequent runs will use the same try number and write
1869 # to same log file.
1870 self._try_number -= 1
1872 self.clear_next_method_args()
1874 session.merge(self)
1875 session.commit()
1876 self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
1878 @staticmethod
1879 def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None:
1880 """
1881 Truncates the traceback of an exception to the first frame called from within a given function.
1883 :param error: exception to get traceback from
1884 :param truncate_to: Function to truncate TB to. Must have a ``__code__`` attribute
1886 :meta private:
1887 """
1888 tb = error.__traceback__
1889 code = truncate_to.__func__.__code__ # type: ignore[attr-defined]
1890 while tb is not None:
1891 if tb.tb_frame.f_code is code:
1892 return tb.tb_next
1893 tb = tb.tb_next
1894 return tb or error.__traceback__
1896 @provide_session
1897 def handle_failure(
1898 self,
1899 error: None | str | Exception | KeyboardInterrupt,
1900 test_mode: bool | None = None,
1901 context: Context | None = None,
1902 force_fail: bool = False,
1903 session: Session = NEW_SESSION,
1904 ) -> None:
1905 """Handle Failure for the TaskInstance."""
1906 if test_mode is None:
1907 test_mode = self.test_mode
1909 get_listener_manager().hook.on_task_instance_failed(
1910 previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session
1911 )
1913 if error:
1914 if isinstance(error, BaseException):
1915 tb = self.get_truncated_error_traceback(error, truncate_to=self._execute_task)
1916 self.log.error("Task failed with exception", exc_info=(type(error), error, tb))
1917 else:
1918 self.log.error("%s", error)
1919 if not test_mode:
1920 self.refresh_from_db(session)
1922 self.end_date = timezone.utcnow()
1923 self.set_duration()
1925 Stats.incr(f"operator_failures_{self.operator}", tags=self.stats_tags)
1926 # Same metric with tagging
1927 Stats.incr("operator_failures", tags={**self.stats_tags, "operator": self.operator})
1928 Stats.incr("ti_failures", tags=self.stats_tags)
1930 if not test_mode:
1931 session.add(Log(State.FAILED, self))
1933 # Log failure duration
1934 session.add(TaskFail(ti=self))
1936 self.clear_next_method_args()
1938 # In extreme cases (zombie in case of dag with parse error) we might _not_ have a Task.
1939 if context is None and getattr(self, "task", None):
1940 context = self.get_template_context(session)
1942 if context is not None:
1943 context["exception"] = error
1945 # Set state correctly and figure out how to log it and decide whether
1946 # to email
1948 # Note, callback invocation needs to be handled by caller of
1949 # _run_raw_task to avoid race conditions which could lead to duplicate
1950 # invocations or miss invocation.
1952 # Since this function is called only when the TaskInstance state is running,
1953 # try_number contains the current try_number (not the next). We
1954 # only mark task instance as FAILED if the next task instance
1955 # try_number exceeds the max_tries ... or if force_fail is truthy
1957 task: BaseOperator | None = None
1958 try:
1959 if getattr(self, "task", None) and context:
1960 task = self.task.unmap((context, session))
1961 except Exception:
1962 self.log.error("Unable to unmap task to determine if we need to send an alert email")
1964 if force_fail or not self.is_eligible_to_retry():
1965 self.state = State.FAILED
1966 email_for_state = operator.attrgetter("email_on_failure")
1967 callbacks = task.on_failure_callback if task else None
1968 callback_type = "on_failure"
1970 if task and task.dag and task.dag.fail_stop:
1971 tis = self.get_dagrun(session).get_task_instances()
1972 stop_all_tasks_in_dag(tis, session, self.task_id)
1973 else:
1974 if self.state == State.QUEUED:
1975 # We increase the try_number so as to fail the task if it fails to start after some time
1976 self._try_number += 1
1977 self.state = State.UP_FOR_RETRY
1978 email_for_state = operator.attrgetter("email_on_retry")
1979 callbacks = task.on_retry_callback if task else None
1980 callback_type = "on_retry"
1982 self._log_state("Immediate failure requested. " if force_fail else "")
1983 if task and email_for_state(task) and task.email:
1984 try:
1985 self.email_alert(error, task)
1986 except Exception:
1987 self.log.exception("Failed to send email to: %s", task.email)
1989 if callbacks and context:
1990 self._run_finished_callback(callbacks, context, callback_type)
1992 if not test_mode:
1993 session.merge(self)
1994 session.flush()
1996 def is_eligible_to_retry(self):
1997 """Whether the task instance is eligible for retry."""
1998 if self.state == State.RESTARTING:
1999 # If a task is cleared when running, it goes into RESTARTING state and is always
2000 # eligible for retry
2001 return True
2002 if not getattr(self, "task", None):
2003 # Couldn't load the task, don't know number of retries, guess:
2004 return self.try_number <= self.max_tries
2006 return self.task.retries and self.try_number <= self.max_tries
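# Worked example of the rule above (illustrative, assuming max_tries was
# initialised from the task's retries): with retries=2, attempts with
# try_number 1 and 2 are eligible for retry while try_number 3 is not,
# giving at most three attempts in total; with retries=0 the expression is
# falsy and the task fails on its first failure.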
2008 def get_template_context(
2009 self,
2010 session: Session | None = None,
2011 ignore_param_exceptions: bool = True,
2012 ) -> Context:
2013 """Return TI Context."""
2014 # Do not use provide_session here -- it expunges everything on exit!
2015 if not session:
2016 session = settings.Session()
2018 from airflow import macros
2019 from airflow.models.abstractoperator import NotMapped
2021 integrate_macros_plugins()
2023 task = self.task
2024 if TYPE_CHECKING:
2025 assert task.dag
2026 dag: DAG = task.dag
2028 dag_run = self.get_dagrun(session)
2029 data_interval = dag.get_run_data_interval(dag_run)
2031 validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)
2033 logical_date = timezone.coerce_datetime(self.execution_date)
2034 ds = logical_date.strftime("%Y-%m-%d")
2035 ds_nodash = ds.replace("-", "")
2036 ts = logical_date.isoformat()
2037 ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
2038 ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
2040 @cache # Prevent multiple database access.
2041 def _get_previous_dagrun_success() -> DagRun | None:
2042 return self.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)
2044 def _get_previous_dagrun_data_interval_success() -> DataInterval | None:
2045 dagrun = _get_previous_dagrun_success()
2046 if dagrun is None:
2047 return None
2048 return dag.get_run_data_interval(dagrun)
2050 def get_prev_data_interval_start_success() -> pendulum.DateTime | None:
2051 data_interval = _get_previous_dagrun_data_interval_success()
2052 if data_interval is None:
2053 return None
2054 return data_interval.start
2056 def get_prev_data_interval_end_success() -> pendulum.DateTime | None:
2057 data_interval = _get_previous_dagrun_data_interval_success()
2058 if data_interval is None:
2059 return None
2060 return data_interval.end
2062 def get_prev_start_date_success() -> pendulum.DateTime | None:
2063 dagrun = _get_previous_dagrun_success()
2064 if dagrun is None:
2065 return None
2066 return timezone.coerce_datetime(dagrun.start_date)
2068 @cache
2069 def get_yesterday_ds() -> str:
2070 return (logical_date - timedelta(1)).strftime("%Y-%m-%d")
2072 def get_yesterday_ds_nodash() -> str:
2073 return get_yesterday_ds().replace("-", "")
2075 @cache
2076 def get_tomorrow_ds() -> str:
2077 return (logical_date + timedelta(1)).strftime("%Y-%m-%d")
2079 def get_tomorrow_ds_nodash() -> str:
2080 return get_tomorrow_ds().replace("-", "")
2082 @cache
2083 def get_next_execution_date() -> pendulum.DateTime | None:
2084 # For manually triggered dagruns that aren't run on a schedule,
2085 # the "next" execution date doesn't make sense, and should be set
2086 # to execution date for consistency with how execution_date is set
2087 # for manually triggered tasks, i.e. triggered_date == execution_date.
2088 if dag_run.external_trigger:
2089 return logical_date
2090 if dag is None:
2091 return None
2092 next_info = dag.next_dagrun_info(data_interval, restricted=False)
2093 if next_info is None:
2094 return None
2095 return timezone.coerce_datetime(next_info.logical_date)
2097 def get_next_ds() -> str | None:
2098 execution_date = get_next_execution_date()
2099 if execution_date is None:
2100 return None
2101 return execution_date.strftime("%Y-%m-%d")
2103 def get_next_ds_nodash() -> str | None:
2104 ds = get_next_ds()
2105 if ds is None:
2106 return ds
2107 return ds.replace("-", "")
2109 @cache
2110 def get_prev_execution_date():
2111 # For manually triggered dagruns that aren't run on a schedule,
2112 # the "previous" execution date doesn't make sense, and should be set
2113 # to execution date for consistency with how execution_date is set
2114 # for manually triggered tasks, i.e. triggered_date == execution_date.
2115 if dag_run.external_trigger:
2116 return logical_date
2117 with warnings.catch_warnings():
2118 warnings.simplefilter("ignore", RemovedInAirflow3Warning)
2119 return dag.previous_schedule(logical_date)
2121 @cache
2122 def get_prev_ds() -> str | None:
2123 execution_date = get_prev_execution_date()
2124 if execution_date is None:
2125 return None
2126 return execution_date.strftime(r"%Y-%m-%d")
2128 def get_prev_ds_nodash() -> str | None:
2129 prev_ds = get_prev_ds()
2130 if prev_ds is None:
2131 return None
2132 return prev_ds.replace("-", "")
2134 def get_triggering_events() -> dict[str, list[DatasetEvent]]:
2135 if TYPE_CHECKING:
2136 assert session is not None
2138 # The dag_run may not be attached to the session anymore since the
2139 # code base is over-zealous with use of session.expunge_all().
2140 # Re-attach it if we get called.
2141 nonlocal dag_run
2142 if dag_run not in session:
2143 dag_run = session.merge(dag_run, load=False)
2145 dataset_events = dag_run.consumed_dataset_events
2146 triggering_events: dict[str, list[DatasetEvent]] = defaultdict(list)
2147 for event in dataset_events:
2148 triggering_events[event.dataset.uri].append(event)
2150 return triggering_events
2152 try:
2153 expanded_ti_count: int | None = task.get_mapped_ti_count(self.run_id, session=session)
2154 except NotMapped:
2155 expanded_ti_count = None
2157 # NOTE: If you add anything to this dict, make sure to also update the
2158 # definition in airflow/utils/context.pyi, and KNOWN_CONTEXT_KEYS in
2159 # airflow/utils/context.py!
2160 context = {
2161 "conf": conf,
2162 "dag": dag,
2163 "dag_run": dag_run,
2164 "data_interval_end": timezone.coerce_datetime(data_interval.end),
2165 "data_interval_start": timezone.coerce_datetime(data_interval.start),
2166 "ds": ds,
2167 "ds_nodash": ds_nodash,
2168 "execution_date": logical_date,
2169 "expanded_ti_count": expanded_ti_count,
2170 "inlets": task.inlets,
2171 "logical_date": logical_date,
2172 "macros": macros,
2173 "next_ds": get_next_ds(),
2174 "next_ds_nodash": get_next_ds_nodash(),
2175 "next_execution_date": get_next_execution_date(),
2176 "outlets": task.outlets,
2177 "params": validated_params,
2178 "prev_data_interval_start_success": get_prev_data_interval_start_success(),
2179 "prev_data_interval_end_success": get_prev_data_interval_end_success(),
2180 "prev_ds": get_prev_ds(),
2181 "prev_ds_nodash": get_prev_ds_nodash(),
2182 "prev_execution_date": get_prev_execution_date(),
2183 "prev_execution_date_success": self.get_previous_execution_date(
2184 state=DagRunState.SUCCESS,
2185 session=session,
2186 ),
2187 "prev_start_date_success": get_prev_start_date_success(),
2188 "run_id": self.run_id,
2189 "task": task,
2190 "task_instance": self,
2191 "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
2192 "test_mode": self.test_mode,
2193 "ti": self,
2194 "tomorrow_ds": get_tomorrow_ds(),
2195 "tomorrow_ds_nodash": get_tomorrow_ds_nodash(),
2196 "triggering_dataset_events": lazy_object_proxy.Proxy(get_triggering_events),
2197 "ts": ts,
2198 "ts_nodash": ts_nodash,
2199 "ts_nodash_with_tz": ts_nodash_with_tz,
2200 "var": {
2201 "json": VariableAccessor(deserialize_json=True),
2202 "value": VariableAccessor(deserialize_json=False),
2203 },
2204 "conn": ConnectionAccessor(),
2205 "yesterday_ds": get_yesterday_ds(),
2206 "yesterday_ds_nodash": get_yesterday_ds_nodash(),
2207 }
2208 # Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it
2209 # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
2210 return Context(context) # type: ignore
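# Illustrative sketch (hypothetical task, not part of the original module):
# how the context assembled above is typically consumed from task code.
#
#     from airflow.decorators import task
#
#     @task
#     def report(**context):
#         print("run_id:", context["run_id"])
#         print("interval:", context["data_interval_start"], "-", context["data_interval_end"])
#         print("params:", context["params"])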
2212 @provide_session
2213 def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None:
2214 """
2215 Update task with rendered template fields for presentation in the UI.
2216 If the task has already run, the fields are fetched from the DB; otherwise they are rendered.
2217 """
2218 from airflow.models.renderedtifields import RenderedTaskInstanceFields
2220 rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
2221 if rendered_task_instance_fields:
2222 self.task = self.task.unmap(None)
2223 for field_name, rendered_value in rendered_task_instance_fields.items():
2224 setattr(self.task, field_name, rendered_value)
2225 return
2227 try:
2228 # If we get here, either the task hasn't run or the RTIF record was purged.
2229 from airflow.utils.log.secrets_masker import redact
2231 self.render_templates()
2232 for field_name in self.task.template_fields:
2233 rendered_value = getattr(self.task, field_name)
2234 setattr(self.task, field_name, redact(rendered_value, field_name))
2235 except (TemplateAssertionError, UndefinedError) as e:
2236 raise AirflowException(
2237 "Webserver does not have access to User-defined Macros or Filters "
2238 "when Dag Serialization is enabled. Hence, for tasks that have not yet "
2239 "started running, please use 'airflow tasks render' to debug the "
2240 "rendering of template_fields."
2241 ) from e
2243 @provide_session
2244 def get_rendered_k8s_spec(self, session: Session = NEW_SESSION):
2245 """Fetch rendered template fields from DB."""
2246 from airflow.models.renderedtifields import RenderedTaskInstanceFields
2248 rendered_k8s_spec = RenderedTaskInstanceFields.get_k8s_pod_yaml(self, session=session)
2249 if not rendered_k8s_spec:
2250 try:
2251 rendered_k8s_spec = self.render_k8s_pod_yaml()
2252 except (TemplateAssertionError, UndefinedError) as e:
2253 raise AirflowException(f"Unable to render a k8s spec for this taskinstance: {e}") from e
2254 return rendered_k8s_spec
2256 def overwrite_params_with_dag_run_conf(self, params, dag_run):
2257 """Overwrite Task Params with DagRun.conf."""
2258 if dag_run and dag_run.conf:
2259 self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
2260 params.update(dag_run.conf)
2262 def render_templates(self, context: Context | None = None) -> Operator:
2263 """Render templates in the operator fields.
2265 If the task was originally mapped, this may replace ``self.task`` with
2266 the unmapped, fully rendered BaseOperator. The original ``self.task``
2267 before replacement is returned.
2268 """
2269 if not context:
2270 context = self.get_template_context()
2271 original_task = self.task
2273 # If self.task is mapped, this call replaces self.task to point to the
2274 # unmapped BaseOperator created by this function! This is because the
2275 # MappedOperator is useless for template rendering, and we need to be
2276 # able to access the unmapped task instead.
2277 original_task.render_template_fields(context)
2279 return original_task
2281 def render_k8s_pod_yaml(self) -> dict | None:
2282 """Render k8s pod yaml."""
2283 from kubernetes.client.api_client import ApiClient
2285 from airflow.kubernetes.kube_config import KubeConfig
2286 from airflow.kubernetes.kubernetes_helper_functions import create_pod_id # Circular import
2287 from airflow.kubernetes.pod_generator import PodGenerator
2289 kube_config = KubeConfig()
2290 pod = PodGenerator.construct_pod(
2291 dag_id=self.dag_id,
2292 run_id=self.run_id,
2293 task_id=self.task_id,
2294 map_index=self.map_index,
2295 date=None,
2296 pod_id=create_pod_id(self.dag_id, self.task_id),
2297 try_number=self.try_number,
2298 kube_image=kube_config.kube_image,
2299 args=self.command_as_list(),
2300 pod_override_object=PodGenerator.from_obj(self.executor_config),
2301 scheduler_job_id="0",
2302 namespace=kube_config.executor_namespace,
2303 base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file),
2304 with_mutation_hook=True,
2305 )
2306 sanitized_pod = ApiClient().sanitize_for_serialization(pod)
2307 return sanitized_pod
2309 def get_email_subject_content(
2310 self, exception: BaseException, task: BaseOperator | None = None
2311 ) -> tuple[str, str, str]:
2312 """Get the email subject content for exceptions."""
2313 # For a ti from DB (without ti.task), return the default value
2314 if task is None:
2315 task = getattr(self, "task")
2316 use_default = task is None
2317 exception_html = str(exception).replace("\n", "<br>")
2319 default_subject = "Airflow alert: {{ti}}"
2320 # For reporting purposes, we report based on 1-indexed,
2321 # not 0-indexed lists (i.e. Try 1 instead of
2322 # Try 0 for the first attempt).
2323 default_html_content = (
2324 "Try {{try_number}} out of {{max_tries + 1}}<br>"
2325 "Exception:<br>{{exception_html}}<br>"
2326 'Log: <a href="{{ti.log_url}}">Link</a><br>'
2327 "Host: {{ti.hostname}}<br>"
2328 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
2329 )
2331 default_html_content_err = (
2332 "Try {{try_number}} out of {{max_tries + 1}}<br>"
2333 "Exception:<br>Failed attempt to attach error logs<br>"
2334 'Log: <a href="{{ti.log_url}}">Link</a><br>'
2335 "Host: {{ti.hostname}}<br>"
2336 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
2337 )
2339 # This function is called after changing the state from State.RUNNING,
2340 # so we need to subtract 1 from self.try_number here.
2341 current_try_number = self.try_number - 1
2342 additional_context: dict[str, Any] = {
2343 "exception": exception,
2344 "exception_html": exception_html,
2345 "try_number": current_try_number,
2346 "max_tries": self.max_tries,
2347 }
2349 if use_default:
2350 default_context = {"ti": self, **additional_context}
2351 jinja_env = jinja2.Environment(
2352 loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True
2353 )
2354 subject = jinja_env.from_string(default_subject).render(**default_context)
2355 html_content = jinja_env.from_string(default_html_content).render(**default_context)
2356 html_content_err = jinja_env.from_string(default_html_content_err).render(**default_context)
2358 else:
2359 # Use the DAG's get_template_env() to set force_sandboxed. Don't add
2360 # the flag to the function on task object -- that function can be
2361 # overridden, and adding a flag breaks backward compatibility.
2362 dag = self.task.get_dag()
2363 if dag:
2364 jinja_env = dag.get_template_env(force_sandboxed=True)
2365 else:
2366 jinja_env = SandboxedEnvironment(cache_size=0)
2367 jinja_context = self.get_template_context()
2368 context_merge(jinja_context, additional_context)
2370 def render(key: str, content: str) -> str:
2371 if conf.has_option("email", key):
2372 path = conf.get_mandatory_value("email", key)
2373 try:
2374 with open(path) as f:
2375 content = f.read()
2376 except FileNotFoundError:
2377 self.log.warning("Could not find email template file %r. Using defaults...", path)
2378 except OSError:
2379 self.log.exception("Error while using email template %r. Using defaults...", path)
2380 return render_template_to_string(jinja_env.from_string(content), jinja_context)
2382 subject = render("subject_template", default_subject)
2383 html_content = render("html_content_template", default_html_content)
2384 html_content_err = render("html_content_template", default_html_content_err)
2386 return subject, html_content, html_content_err
2388 def email_alert(self, exception, task: BaseOperator) -> None:
2389 """Send alert email with exception information."""
2390 subject, html_content, html_content_err = self.get_email_subject_content(exception, task=task)
2391 assert task.email
2392 try:
2393 send_email(task.email, subject, html_content)
2394 except Exception:
2395 send_email(task.email, subject, html_content_err)
2397 def set_duration(self) -> None:
2398 """Set TI duration."""
2399 if self.end_date and self.start_date:
2400 self.duration = (self.end_date - self.start_date).total_seconds()
2401 else:
2402 self.duration = None
2403 self.log.debug("Task Duration set to %s", self.duration)
2405 def _record_task_map_for_downstreams(self, task: Operator, value: Any, *, session: Session) -> None:
2406 if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
2407 return
2408 # TODO: We don't push TaskMap for mapped task instances because it's not
2409 # currently possible for a downstream to depend on one individual mapped
2410 # task instance. This will change when we implement task mapping inside
2411 # a mapped task group, and we'll need to further analyze the case.
2412 if isinstance(task, MappedOperator):
2413 return
2414 if value is None:
2415 raise XComForMappingNotPushed()
2416 if not _is_mappable_value(value):
2417 raise UnmappableXComTypePushed(value)
2418 task_map = TaskMap.from_task_instance_xcom(self, value)
2419 max_map_length = conf.getint("core", "max_map_length", fallback=1024)
2420 if task_map.length > max_map_length:
2421 raise UnmappableXComLengthPushed(value, max_map_length)
2422 session.merge(task_map)
2424 @provide_session
2425 def xcom_push(
2426 self,
2427 key: str,
2428 value: Any,
2429 execution_date: datetime | None = None,
2430 session: Session = NEW_SESSION,
2431 ) -> None:
2432 """
2433 Make an XCom available for tasks to pull.
2435 :param key: Key to store the value under.
2436 :param value: Value to store. Which types are possible depends on whether
2437 ``enable_xcom_pickling`` is true or not. If so, this can be any
2438 picklable object; otherwise only JSON-serializable values may be used.
2439 :param execution_date: Deprecated parameter that has no effect.
2440 """
2441 if execution_date is not None:
2442 self_execution_date = self.get_dagrun(session).execution_date
2443 if execution_date < self_execution_date:
2444 raise ValueError(
2445 f"execution_date can not be in the past (current execution_date is "
2446 f"{self_execution_date}; received {execution_date})"
2447 )
2448 elif execution_date is not None:
2449 message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated."
2450 warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
2452 XCom.set(
2453 key=key,
2454 value=value,
2455 task_id=self.task_id,
2456 dag_id=self.dag_id,
2457 run_id=self.run_id,
2458 map_index=self.map_index,
2459 session=session,
2460 )
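# Illustrative usage sketch (hypothetical task code, values are examples only):
# an explicit push next to the implicit push of a return value under
# XCOM_RETURN_KEY.
#
#     def produce(ti, **context):
#         ti.xcom_push(key="row_count", value=42)      # explicit push
#         return {"path": "s3://example-bucket/out"}   # implicit push as 'return_value'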
2462 @provide_session
2463 def xcom_pull(
2464 self,
2465 task_ids: str | Iterable[str] | None = None,
2466 dag_id: str | None = None,
2467 key: str = XCOM_RETURN_KEY,
2468 include_prior_dates: bool = False,
2469 session: Session = NEW_SESSION,
2470 *,
2471 map_indexes: int | Iterable[int] | None = None,
2472 default: Any = None,
2473 ) -> Any:
2474 """Pull XComs that optionally meet certain criteria.
2476 :param key: A key for the XCom. If provided, only XComs with matching
2477 keys will be returned. The default key is ``'return_value'``, also
2478 available as constant ``XCOM_RETURN_KEY``. This key is automatically
2479 given to XComs returned by tasks (as opposed to being pushed
2480 manually). To remove the filter, pass *None*.
2481 :param task_ids: Only XComs from tasks with matching ids will be
2482 pulled. Pass *None* to remove the filter.
2483 :param dag_id: If provided, only pulls XComs from this DAG. If *None*
2484 (default), the DAG of the calling task is used.
2485 :param map_indexes: If provided, only pull XComs with matching indexes.
2486 If *None* (default), this is inferred from the task(s) being pulled
2487 (see below for details).
2488 :param include_prior_dates: If *False*, only XComs from the current
2489 execution_date are returned. If *True*, XComs from previous dates
2490 are returned as well.
2492 When pulling one single task (``task_ids`` is *None* or a str) without
2493 specifying ``map_indexes``, the return value is inferred from whether
2494 the specified task is mapped. If not, the value from that single task
2495 instance is returned. If the task to pull is mapped, an iterator (not a
2496 list) yielding XComs from mapped task instances is returned. In either
2497 case, ``default`` (*None* if not specified) is returned if no matching
2498 XComs are found.
2500 When pulling multiple tasks (i.e. either ``task_ids`` or ``map_indexes``
2501 is a non-str iterable), a list of matching XComs is returned. Elements in
2502 the list are ordered by item ordering in ``task_ids`` and ``map_indexes``.
2503 """
2504 if dag_id is None:
2505 dag_id = self.dag_id
2507 query = XCom.get_many(
2508 key=key,
2509 run_id=self.run_id,
2510 dag_ids=dag_id,
2511 task_ids=task_ids,
2512 map_indexes=map_indexes,
2513 include_prior_dates=include_prior_dates,
2514 session=session,
2515 )
2517 # NOTE: Since we're only fetching the value field and not the whole
2518 # class, the @recreate annotation does not kick in. Therefore we need to
2519 # call XCom.deserialize_value() manually.
2521 # We are only pulling one single task.
2522 if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable):
2523 first = query.with_entities(
2524 XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value
2525 ).first()
2526 if first is None: # No matching XCom at all.
2527 return default
2528 if map_indexes is not None or first.map_index < 0:
2529 return XCom.deserialize_value(first)
2530 query = query.order_by(None).order_by(XCom.map_index.asc())
2531 return LazyXComAccess.build_from_xcom_query(query)
2533 # At this point either task_ids or map_indexes is explicitly multi-value.
2534 # Order return values to match task_ids and map_indexes ordering.
2535 query = query.order_by(None)
2536 if task_ids is None or isinstance(task_ids, str):
2537 query = query.order_by(XCom.task_id)
2538 else:
2539 task_id_whens = {tid: i for i, tid in enumerate(task_ids)}
2540 if task_id_whens:
2541 query = query.order_by(case(task_id_whens, value=XCom.task_id))
2542 else:
2543 query = query.order_by(XCom.task_id)
2544 if map_indexes is None or isinstance(map_indexes, int):
2545 query = query.order_by(XCom.map_index)
2546 elif isinstance(map_indexes, range):
2547 order = XCom.map_index
2548 if map_indexes.step < 0:
2549 order = order.desc()
2550 query = query.order_by(order)
2551 else:
2552 map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)}
2553 if map_index_whens:
2554 query = query.order_by(case(map_index_whens, value=XCom.map_index))
2555 else:
2556 query = query.order_by(XCom.map_index)
2557 return LazyXComAccess.build_from_xcom_query(query)
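# Illustrative usage sketch (hypothetical task ids): the result shapes
# described in the docstring above.
#
#     def consume(ti, **context):
#         single = ti.xcom_pull(task_ids="produce")                   # one value (or default)
#         ordered = ti.xcom_pull(task_ids=["extract", "transform"])   # list, in the given order
#         mapped = ti.xcom_pull(task_ids="mapped_task", map_indexes=range(3))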
2559 @provide_session
2560 def get_num_running_task_instances(self, session: Session, same_dagrun=False) -> int:
2561 """Return the number of running TIs from the DB."""
2562 # .count() is inefficient
2563 num_running_task_instances_query = session.query(func.count()).filter(
2564 TaskInstance.dag_id == self.dag_id,
2565 TaskInstance.task_id == self.task_id,
2566 TaskInstance.state == State.RUNNING,
2567 )
2568 if same_dagrun:
2569 num_running_task_instances_query = num_running_task_instances_query.filter(TaskInstance.run_id == self.run_id)  # reassign; .filter() returns a new query
2570 return num_running_task_instances_query.scalar()
2572 def init_run_context(self, raw: bool = False) -> None:
2573 """Sets the log context."""
2574 self.raw = raw
2575 self._set_context(self)
2577 @staticmethod
2578 def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None:
2579 """Returns SQLAlchemy filter to query selected task instances."""
2580 # DictKeys type (what we often pass here from the scheduler) is not directly indexable :(
2581 # Or it might be a generator, but we need to be able to iterate over it more than once
2582 tis = list(tis)
2584 if not tis:
2585 return None
2587 first = tis[0]
2589 dag_id = first.dag_id
2590 run_id = first.run_id
2591 map_index = first.map_index
2592 first_task_id = first.task_id
2594 # pre-compute the set of dag_id, run_id, map_indices and task_ids
2595 dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set()
2596 for t in tis:
2597 dag_ids.add(t.dag_id)
2598 run_ids.add(t.run_id)
2599 map_indices.add(t.map_index)
2600 task_ids.add(t.task_id)
2602 # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id
2603 # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+)
2604 if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}:
2605 return and_(
2606 TaskInstance.dag_id == dag_id,
2607 TaskInstance.run_id == run_id,
2608 TaskInstance.map_index == map_index,
2609 TaskInstance.task_id.in_(task_ids),
2610 )
2611 if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}:
2612 return and_(
2613 TaskInstance.dag_id == dag_id,
2614 TaskInstance.run_id.in_(run_ids),
2615 TaskInstance.map_index == map_index,
2616 TaskInstance.task_id == first_task_id,
2617 )
2618 if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}:
2619 return and_(
2620 TaskInstance.dag_id == dag_id,
2621 TaskInstance.run_id == run_id,
2622 TaskInstance.map_index.in_(map_indices),
2623 TaskInstance.task_id == first_task_id,
2624 )
2626 filter_condition = []
2627 # Create two nested groupings, both primarily grouped by dag_id and run_id;
2628 # within each, one is further grouped by task_id and the other by map_index.
2629 task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
2630 map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
2631 for t in tis:
2632 task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index)
2633 map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id)
2635 # This assumes that most dags have dag_id as the largest grouping, followed by run_id. Even
2636 # if it's not, this is still a significant optimization over querying for every single tuple key.
2637 for cur_dag_id in dag_ids:
2638 for cur_run_id in run_ids:
2639 # we compare the group size between task_id and map_index and use the smaller group
2640 dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
2641 dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
2643 if len(dag_task_id_groups) <= len(dag_map_index_groups):
2644 for cur_task_id, cur_map_indices in dag_task_id_groups.items():
2645 filter_condition.append(
2646 and_(
2647 TaskInstance.dag_id == cur_dag_id,
2648 TaskInstance.run_id == cur_run_id,
2649 TaskInstance.task_id == cur_task_id,
2650 TaskInstance.map_index.in_(cur_map_indices),
2651 )
2652 )
2653 else:
2654 for cur_map_index, cur_task_ids in dag_map_index_groups.items():
2655 filter_condition.append(
2656 and_(
2657 TaskInstance.dag_id == cur_dag_id,
2658 TaskInstance.run_id == cur_run_id,
2659 TaskInstance.task_id.in_(cur_task_ids),
2660 TaskInstance.map_index == cur_map_index,
2661 )
2662 )
2664 return or_(*filter_condition)
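# Worked example of the first fast path above (hypothetical values): if every
# TI has dag_id="d", run_id="r" and map_index=-1, the whole filter collapses to
#     dag_id == "d" AND run_id == "r" AND map_index == -1 AND task_id IN (<task ids>)
# instead of one OR-ed clause per task instance.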
2666 @classmethod
2667 def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators:
2668 """
2669 Build an SQLAlchemy filter for a list in which each element is either
2670 a task_id or a tuple of (task_id, map_index).
2672 :meta private:
2673 """
2674 # Compute a filter for TI.task_id and TI.map_index based on input values
2675 # For each item, it will either be a task_id, or (task_id, map_index)
2676 task_id_only = [v for v in vals if isinstance(v, str)]
2677 with_map_index = [v for v in vals if not isinstance(v, str)]
2679 filters: list[ColumnOperators] = []
2680 if task_id_only:
2681 filters.append(cls.task_id.in_(task_id_only))
2682 if with_map_index:
2683 filters.append(tuple_in_condition((cls.task_id, cls.map_index), with_map_index))
2685 if not filters:
2686 return false()
2687 if len(filters) == 1:
2688 return filters[0]
2689 return or_(*filters)
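# Worked example (hypothetical values): vals=["a", ("b", 0), ("b", 1)] produces
# roughly
#     task_id IN ('a') OR (task_id, map_index) IN (('b', 0), ('b', 1))
# with the tuple membership test built by tuple_in_condition().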
2691 @Sentry.enrich_errors
2692 @provide_session
2693 def schedule_downstream_tasks(self, session: Session = NEW_SESSION, max_tis_per_query: int | None = None):
2694 """
2695 The mini-scheduler for scheduling downstream tasks of this task instance.
2697 :meta private:
2698 """
2699 from sqlalchemy.exc import OperationalError
2701 from airflow.models import DagRun
2703 try:
2704 # Re-select the row with a lock
2705 dag_run = with_row_locks(
2706 session.query(DagRun).filter_by(
2707 dag_id=self.dag_id,
2708 run_id=self.run_id,
2709 ),
2710 session=session,
2711 ).one()
2713 task = self.task
2714 if TYPE_CHECKING:
2715 assert task.dag
2717 # Get a partial DAG with just the specific tasks we want to examine.
2718 # In order for dep checks to work correctly, we include ourself (so
2719 # TriggerRuleDep can check the state of the task we just executed).
2720 partial_dag = task.dag.partial_subset(
2721 task.downstream_task_ids,
2722 include_downstream=True,
2723 include_upstream=False,
2724 include_direct_upstream=True,
2725 )
2727 dag_run.dag = partial_dag
2728 info = dag_run.task_instance_scheduling_decisions(session)
2730 skippable_task_ids = {
2731 task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
2732 }
2734 schedulable_tis = [
2735 ti
2736 for ti in info.schedulable_tis
2737 if ti.task_id not in skippable_task_ids
2738 and not (
2739 ti.task.inherits_from_empty_operator
2740 and not ti.task.on_execute_callback
2741 and not ti.task.on_success_callback
2742 and not ti.task.outlets
2743 )
2744 ]
2745 for schedulable_ti in schedulable_tis:
2746 if not hasattr(schedulable_ti, "task"):
2747 schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
2749 num = dag_run.schedule_tis(schedulable_tis, session=session, max_tis_per_query=max_tis_per_query)
2750 self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
2752 session.flush()
2754 except OperationalError as e:
2755 # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
2756 self.log.info(
2757 "Skipping mini scheduling run due to exception: %s",
2758 e.statement,
2759 exc_info=True,
2760 )
2761 session.rollback()
2763 def get_relevant_upstream_map_indexes(
2764 self,
2765 upstream: Operator,
2766 ti_count: int | None,
2767 *,
2768 session: Session,
2769 ) -> int | range | None:
2770 """Infer the map indexes of an upstream "relevant" to this ti.
2772 The bulk of the logic mainly exists to solve the problem described by
2773 the following example, where 'val' must resolve to different values,
2774 depending on where the reference is being used::
2776 @task
2777 def this_task(v): # This is self.task.
2778 return v * 2
2780 @task_group
2781 def tg1(inp):
2782 val = upstream(inp) # This is the upstream task.
2783 this_task(val) # When inp is 1, val here should resolve to 2.
2784 return val
2786 # This val is the same object returned by tg1.
2787 val = tg1.expand(inp=[1, 2, 3])
2789 @task_group
2790 def tg2(inp):
2791 another_task(inp, val) # val here should resolve to [2, 4, 6].
2793 tg2.expand(inp=["a", "b"])
2795 The surrounding mapped task groups of ``upstream`` and ``self.task`` are
2796 inspected to find a common "ancestor". If such an ancestor is found,
2797 we need to return specific map indexes to pull a partial value from
2798 upstream XCom.
2800 :param upstream: The referenced upstream task.
2801 :param ti_count: The total count of task instances this task was expanded
2802 into by the scheduler, i.e. ``expanded_ti_count`` in the template context.
2803 :return: Specific map index or map indexes to pull, or ``None`` if we
2804 want the "whole" return value (i.e. no mapped task groups involved).
2805 """
2806 # This value should never be None since we already know the current task
2807 # is in a mapped task group and should have been expanded; despite that,
2808 # we need to check that it is not None to satisfy Mypy.
2809 # It can, however, be 0 when we expand an empty list, so we also check
2810 # that ti_count is not 0 to avoid dividing by zero.
2811 if not ti_count:
2812 return None
2814 # Find the innermost common mapped task group between the current task
2815 # and the referenced upstream task. If they do not share a common
2816 # mapped task group, the two are in different task mapping contexts
2817 # (like another_task above), and we should use the "whole" value.
2818 common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream)
2819 if common_ancestor is None:
2820 return None
2822 # At this point we know the two tasks share a mapped task group, and we
2823 # should use a "partial" value. Let's break down the mapped ti count
2824 # between the ancestor and further expansion happened inside it.
2825 ancestor_ti_count = common_ancestor.get_mapped_ti_count(self.run_id, session=session)
2826 ancestor_map_index = self.map_index * ancestor_ti_count // ti_count
2828 # If the task is NOT further expanded inside the common ancestor, we
2829 # only want to reference one single ti. We must walk the actual DAG,
2830 # and "ti_count == ancestor_ti_count" does not work, since the further
2831 # expansion may be of length 1.
2832 if not _is_further_mapped_inside(upstream, common_ancestor):
2833 return ancestor_map_index
2835 # Otherwise we need a partial aggregation for values from selected task
2836 # instances in the ancestor's expansion context.
2837 further_count = ti_count // ancestor_ti_count
2838 map_index_start = ancestor_map_index * further_count
2839 return range(map_index_start, map_index_start + further_count)
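# Worked example of the arithmetic above (hypothetical counts): with
# ancestor_ti_count=3 and ti_count=6, further_count = 6 // 3 = 2, so a ti
# with map_index=4 gets ancestor_map_index = 4 * 3 // 6 = 2 and, if the
# upstream is further mapped inside the ancestor, pulls
# range(2 * 2, 2 * 2 + 2) == range(4, 6).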
2841 def clear_db_references(self, session):
2842 """
2843 Clear db tables that have a reference to this instance.
2845 :param session: ORM Session
2847 :meta private:
2848 """
2849 from airflow.models.renderedtifields import RenderedTaskInstanceFields
2851 tables = [TaskFail, TaskInstanceNote, TaskReschedule, XCom, RenderedTaskInstanceFields]
2852 for table in tables:
2853 session.execute(
2854 delete(table).where(
2855 table.dag_id == self.dag_id,
2856 table.task_id == self.task_id,
2857 table.run_id == self.run_id,
2858 table.map_index == self.map_index,
2859 )
2860 )
2863def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
2864 """Given two operators, find their innermost common mapped task group."""
2865 if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
2866 return None
2867 parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()}
2868 common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids)
2869 return next(common_groups, None)
2872def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
2873 """Whether the given operator is *further* mapped inside a task group."""
2874 if isinstance(operator, MappedOperator):
2875 return True
2876 task_group = operator.task_group
2877 while task_group is not None and task_group.group_id != container.group_id:
2878 if isinstance(task_group, MappedTaskGroup):
2879 return True
2880 task_group = task_group.parent_group
2881 return False
2884# State of the task instance.
2885# Stores string version of the task state.
2886TaskInstanceStateType = Tuple[TaskInstanceKey, str]
2889class SimpleTaskInstance:
2890 """
2891 Simplified Task Instance.
2893 Used to send data between processes via Queues.
2894 """
2896 def __init__(
2897 self,
2898 dag_id: str,
2899 task_id: str,
2900 run_id: str,
2901 start_date: datetime | None,
2902 end_date: datetime | None,
2903 try_number: int,
2904 map_index: int,
2905 state: str,
2906 executor_config: Any,
2907 pool: str,
2908 queue: str,
2909 key: TaskInstanceKey,
2910 run_as_user: str | None = None,
2911 priority_weight: int | None = None,
2912 ):
2913 self.dag_id = dag_id
2914 self.task_id = task_id
2915 self.run_id = run_id
2916 self.map_index = map_index
2917 self.start_date = start_date
2918 self.end_date = end_date
2919 self.try_number = try_number
2920 self.state = state
2921 self.executor_config = executor_config
2922 self.run_as_user = run_as_user
2923 self.pool = pool
2924 self.priority_weight = priority_weight
2925 self.queue = queue
2926 self.key = key
2928 def __eq__(self, other):
2929 if isinstance(other, self.__class__):
2930 return self.__dict__ == other.__dict__
2931 return NotImplemented
2933 def as_dict(self):
2934 warnings.warn(
2935 "This method is deprecated. Use BaseSerialization.serialize.",
2936 RemovedInAirflow3Warning,
2937 stacklevel=2,
2938 )
2939 new_dict = dict(self.__dict__)
2940 for key in new_dict:
2941 if key in ["start_date", "end_date"]:
2942 val = new_dict[key]
2943 if not val or isinstance(val, str):
2944 continue
2945 new_dict.update({key: val.isoformat()})
2946 return new_dict
2948 @classmethod
2949 def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance:
2950 return cls(
2951 dag_id=ti.dag_id,
2952 task_id=ti.task_id,
2953 run_id=ti.run_id,
2954 map_index=ti.map_index,
2955 start_date=ti.start_date,
2956 end_date=ti.end_date,
2957 try_number=ti.try_number,
2958 state=ti.state,
2959 executor_config=ti.executor_config,
2960 pool=ti.pool,
2961 queue=ti.queue,
2962 key=ti.key,
2963 run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None,
2964 priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None,
2965 )
2967 @classmethod
2968 def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance:
2969 warnings.warn(
2970 "This method is deprecated. Use BaseSerialization.deserialize.",
2971 RemovedInAirflow3Warning,
2972 stacklevel=2,
2973 )
2974 ti_key = TaskInstanceKey(*obj_dict.pop("key"))
2975 start_date = None
2976 end_date = None
2977 start_date_str: str | None = obj_dict.pop("start_date")
2978 end_date_str: str | None = obj_dict.pop("end_date")
2979 if start_date_str:
2980 start_date = timezone.parse(start_date_str)
2981 if end_date_str:
2982 end_date = timezone.parse(end_date_str)
2983 return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key)
2986class TaskInstanceNote(Base):
2987 """For storage of arbitrary notes concerning the task instance."""
2989 __tablename__ = "task_instance_note"
2991 user_id = Column(Integer, nullable=True)
2992 task_id = Column(StringID(), primary_key=True, nullable=False)
2993 dag_id = Column(StringID(), primary_key=True, nullable=False)
2994 run_id = Column(StringID(), primary_key=True, nullable=False)
2995 map_index = Column(Integer, primary_key=True, nullable=False)
2996 content = Column(String(1000).with_variant(Text(1000), "mysql"))
2997 created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
2998 updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
3000 task_instance = relationship("TaskInstance", back_populates="task_instance_note")
3002 __table_args__ = (
3003 PrimaryKeyConstraint(
3004 "task_id", "dag_id", "run_id", "map_index", name="task_instance_note_pkey", mssql_clustered=True
3005 ),
3006 ForeignKeyConstraint(
3007 (dag_id, task_id, run_id, map_index),
3008 [
3009 "task_instance.dag_id",
3010 "task_instance.task_id",
3011 "task_instance.run_id",
3012 "task_instance.map_index",
3013 ],
3014 name="task_instance_note_ti_fkey",
3015 ondelete="CASCADE",
3016 ),
3017 ForeignKeyConstraint(
3018 (user_id,),
3019 ["ab_user.id"],
3020 name="task_instance_note_user_fkey",
3021 ),
3022 )
3024 def __init__(self, content, user_id=None):
3025 self.content = content
3026 self.user_id = user_id
3028 def __repr__(self):
3029 prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
3030 if self.map_index != -1:
3031 prefix += f" map_index={self.map_index}"
3032 return prefix + ">"
3035STATICA_HACK = True
3036globals()["kcah_acitats"[::-1].upper()] = False
3037if STATICA_HACK: # pragma: no cover
3038 from airflow.jobs.job import Job
3040 TaskInstance.queued_by_job = relationship(Job)