Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/taskinstance.py: 21%
1261 statements
coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections.abc
21import contextlib
22import hashlib
23import logging
24import math
25import operator
26import os
27import signal
28import warnings
29from collections import defaultdict
30from datetime import datetime, timedelta
31from functools import partial
32from types import TracebackType
33from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, NamedTuple, Tuple
34from urllib.parse import quote
36import dill
37import jinja2
38import lazy_object_proxy
39import pendulum
40from jinja2 import TemplateAssertionError, UndefinedError
41from sqlalchemy import (
42 Column,
43 DateTime,
44 Float,
45 ForeignKeyConstraint,
46 Index,
47 Integer,
48 PrimaryKeyConstraint,
49 String,
50 Text,
51 and_,
52 false,
53 func,
54 inspect,
55 or_,
56 text,
57)
58from sqlalchemy.ext.associationproxy import association_proxy
59from sqlalchemy.ext.mutable import MutableDict
60from sqlalchemy.orm import reconstructor, relationship
61from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
62from sqlalchemy.orm.session import Session
63from sqlalchemy.sql.elements import BooleanClauseList
64from sqlalchemy.sql.expression import ColumnOperators, case
66from airflow import settings
67from airflow.compat.functools import cache
68from airflow.configuration import conf
69from airflow.datasets import Dataset
70from airflow.datasets.manager import dataset_manager
71from airflow.exceptions import (
72 AirflowException,
73 AirflowFailException,
74 AirflowRescheduleException,
75 AirflowSensorTimeout,
76 AirflowSkipException,
77 AirflowTaskTimeout,
78 DagRunNotFound,
79 RemovedInAirflow3Warning,
80 TaskDeferralError,
81 TaskDeferred,
82 UnmappableXComLengthPushed,
83 UnmappableXComTypePushed,
84 XComForMappingNotPushed,
85)
86from airflow.models.base import Base, StringID
87from airflow.models.log import Log
88from airflow.models.mappedoperator import MappedOperator
89from airflow.models.param import process_params
90from airflow.models.taskfail import TaskFail
91from airflow.models.taskmap import TaskMap
92from airflow.models.taskreschedule import TaskReschedule
93from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom
94from airflow.plugins_manager import integrate_macros_plugins
95from airflow.sentry import Sentry
96from airflow.stats import Stats
97from airflow.templates import SandboxedEnvironment
98from airflow.ti_deps.dep_context import DepContext
99from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
100from airflow.timetables.base import DataInterval
101from airflow.typing_compat import Literal, TypeGuard
102from airflow.utils import timezone
103from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor, context_merge
104from airflow.utils.email import send_email
105from airflow.utils.helpers import render_template_to_string
106from airflow.utils.log.logging_mixin import LoggingMixin
107from airflow.utils.module_loading import qualname
108from airflow.utils.net import get_hostname
109from airflow.utils.operator_helpers import context_to_airflow_vars
110from airflow.utils.platform import getuser
111from airflow.utils.retries import run_with_db_retries
112from airflow.utils.session import NEW_SESSION, create_session, provide_session
113from airflow.utils.sqlalchemy import (
114 ExecutorConfigType,
115 ExtendedJSON,
116 UtcDateTime,
117 tuple_in_condition,
118 with_row_locks,
119)
120from airflow.utils.state import DagRunState, State, TaskInstanceState
121from airflow.utils.timeout import timeout
123TR = TaskReschedule
125_CURRENT_CONTEXT: list[Context] = []
126log = logging.getLogger(__name__)
129if TYPE_CHECKING:
130 from airflow.models.abstractoperator import TaskStateChangeCallback
131 from airflow.models.baseoperator import BaseOperator
132 from airflow.models.dag import DAG, DagModel
133 from airflow.models.dagrun import DagRun
134 from airflow.models.dataset import DatasetEvent
135 from airflow.models.operator import Operator
136 from airflow.utils.task_group import MappedTaskGroup, TaskGroup
139@contextlib.contextmanager
140def set_current_context(context: Context) -> Generator[Context, None, None]:
141 """
142 Sets the current execution context to the provided context object.
143 This method should be called once per Task execution, before calling operator.execute.
144 """
145 _CURRENT_CONTEXT.append(context)
146 try:
147 yield context
148 finally:
149 expected_state = _CURRENT_CONTEXT.pop()
150 if expected_state != context:
151 log.warning(
152 "Current context is not equal to the state at context stack. Expected=%s, got=%s",
153 context,
154 expected_state,
155 )
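# Illustrative sketch (not part of the original module): set_current_context is what
# makes ``airflow.operators.python.get_current_context()`` work while a task executes.
# The typical (and assumed) usage wraps the call to the operator's execute method:
#
#     with set_current_context(context):
#         result = task.execute(context=context)
#
# Entering the block pushes ``context`` onto _CURRENT_CONTEXT; leaving it pops the stack
# and logs a warning if the popped value is not the object that was pushed.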
158def clear_task_instances(
159 tis: list[TaskInstance],
160 session: Session,
161 activate_dag_runs: None = None,
162 dag: DAG | None = None,
163 dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
164) -> None:
165 """
166 Clears a set of task instances, but makes sure the running ones
167 get killed.
169 :param tis: a list of task instances
170 :param session: current session
171 :param dag_run_state: state to set DagRun to. If set to False, dagrun state will not
172 be changed.
173 :param dag: DAG object
174 :param activate_dag_runs: Deprecated parameter, do not pass
175 """
176 job_ids = []
177 # Keys: dag_id -> run_id -> map_indexes -> try_numbers -> task_id
178 task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict(
179 lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
180 )
181 for ti in tis:
182 if ti.state == TaskInstanceState.RUNNING:
183 if ti.job_id:
184 # If a task is cleared when running, set its state to RESTARTING so that
185 # the task is terminated and becomes eligible for retry.
186 ti.state = TaskInstanceState.RESTARTING
187 job_ids.append(ti.job_id)
188 else:
189 task_id = ti.task_id
190 if dag and dag.has_task(task_id):
191 task = dag.get_task(task_id)
192 ti.refresh_from_task(task)
193 task_retries = task.retries
194 ti.max_tries = ti.try_number + task_retries - 1
195 else:
196 # Ignore errors when updating max_tries if dag is None or
197 # task not found in dag since database records could be
198 # outdated. We make max_tries the maximum value of its
199 # original max_tries or the last attempted try number.
200 ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries)
201 ti.state = None
202 ti.external_executor_id = None
203 ti.clear_next_method_args()
204 session.merge(ti)
206 task_id_by_key[ti.dag_id][ti.run_id][ti.map_index][ti.try_number].add(ti.task_id)
208 if task_id_by_key:
209 # Clear all reschedules related to the ti to clear
211 # This is an optimization for the common case where all tis are for a small number
212 # of dag_id, run_id, try_number, and map_index. Use a nested dict of dag_id,
213 # run_id, try_number, map_index, and task_id to construct the where clause in a
214 # hierarchical manner. This speeds up the delete statement by more than 40x for
215 # large number of tis (50k+).
216 conditions = or_(
217 and_(
218 TR.dag_id == dag_id,
219 or_(
220 and_(
221 TR.run_id == run_id,
222 or_(
223 and_(
224 TR.map_index == map_index,
225 or_(
226 and_(TR.try_number == try_number, TR.task_id.in_(task_ids))
227 for try_number, task_ids in task_tries.items()
228 ),
229 )
230 for map_index, task_tries in map_indexes.items()
231 ),
232 )
233 for run_id, map_indexes in run_ids.items()
234 ),
235 )
236 for dag_id, run_ids in task_id_by_key.items()
237 )
239 delete_qry = TR.__table__.delete().where(conditions)
240 session.execute(delete_qry)
242 if job_ids:
243 from airflow.jobs.base_job import BaseJob
245 for job in session.query(BaseJob).filter(BaseJob.id.in_(job_ids)).all():
246 job.state = TaskInstanceState.RESTARTING
248 if activate_dag_runs is not None:
249 warnings.warn(
250 "`activate_dag_runs` parameter to clear_task_instances function is deprecated. "
251 "Please use `dag_run_state`",
252 RemovedInAirflow3Warning,
253 stacklevel=2,
254 )
255 if not activate_dag_runs:
256 dag_run_state = False
258 if dag_run_state is not False and tis:
259 from airflow.models.dagrun import DagRun # Avoid circular import
261 run_ids_by_dag_id = defaultdict(set)
262 for instance in tis:
263 run_ids_by_dag_id[instance.dag_id].add(instance.run_id)
265 drs = (
266 session.query(DagRun)
267 .filter(
268 or_(
269 and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
270 for dag_id, run_ids in run_ids_by_dag_id.items()
271 )
272 )
273 .all()
274 )
275 dag_run_state = DagRunState(dag_run_state) # Validate the state value.
276 for dr in drs:
277 dr.state = dag_run_state
278 dr.start_date = timezone.utcnow()
279 if dag_run_state == DagRunState.QUEUED:
280 dr.last_scheduling_decision = None
281 dr.start_date = None
282 session.flush()
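# Illustrative sketch (caller and query are assumed, not part of the original module):
# clear_task_instances is typically handed a list of TIs collected by DAG.clear(), e.g.
#
#     with create_session() as session:
#         tis = session.query(TaskInstance).filter(TaskInstance.dag_id == dag.dag_id).all()
#         clear_task_instances(tis, session, dag=dag, dag_run_state=DagRunState.QUEUED)
#
# Running instances are flipped to RESTARTING, all others are reset to no state with
# their reschedules deleted, and the affected DagRuns are set back to QUEUED.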
285def _is_mappable_value(value: Any) -> TypeGuard[Collection]:
286 """Whether a value can be used for task mapping.
288 We only allow collections with guaranteed ordering, but exclude character
289 sequences since that's usually not what users would expect to be mappable.
290 """
291 if not isinstance(value, (collections.abc.Sequence, dict)):
292 return False
293 if isinstance(value, (bytearray, bytes, str)):
294 return False
295 return True
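# Illustrative examples (not part of the original module) of what _is_mappable_value
# accepts: ordered collections are mappable, character sequences and sets are not.
#
#     _is_mappable_value([1, 2, 3])   # True  (list)
#     _is_mappable_value({"a": 1})    # True  (dict)
#     _is_mappable_value("abc")       # False (str is a character sequence)
#     _is_mappable_value({1, 2})      # False (set has no guaranteed ordering)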
298class TaskInstanceKey(NamedTuple):
299 """Key used to identify task instance."""
301 dag_id: str
302 task_id: str
303 run_id: str
304 try_number: int = 1
305 map_index: int = -1
307 @property
308 def primary(self) -> tuple[str, str, str, int]:
309 """Return task instance primary key part of the key"""
310 return self.dag_id, self.task_id, self.run_id, self.map_index
312 @property
313 def reduced(self) -> TaskInstanceKey:
314 """Remake the key by subtracting 1 from try number to match in memory information"""
315 return TaskInstanceKey(
316 self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1), self.map_index
317 )
319 def with_try_number(self, try_number: int) -> TaskInstanceKey:
320 """Returns TaskInstanceKey with provided ``try_number``"""
321 return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number, self.map_index)
323 @property
324 def key(self) -> TaskInstanceKey:
325 """For API-compatibly with TaskInstance.
327 Returns self
328 """
329 return self
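# Illustrative sketch (values assumed, not part of the original module): TaskInstanceKey
# is a hashable tuple that executors and schedulers use to track work, e.g.
#
#     key = TaskInstanceKey(dag_id="example", task_id="extract", run_id="manual__1")
#     key.primary             # ("example", "extract", "manual__1", -1)
#     key.with_try_number(3)  # same identity, try_number replaced with 3
#     key.reduced             # try_number lowered by one (never below 1)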
332def _creator_note(val):
333 """Custom creator for the ``note`` association proxy."""
334 if isinstance(val, str):
335 return TaskInstanceNote(content=val)
336 elif isinstance(val, dict):
337 return TaskInstanceNote(**val)
338 else:
339 return TaskInstanceNote(*val)
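# Illustrative sketch (values assumed, not part of the original module): the creator lets
# the ``note`` association proxy accept plain values when assigned on a TaskInstance:
#
#     ti.note = "reprocessed after upstream fix"   # becomes TaskInstanceNote(content=...)
#     ti.note = {"content": "triaged"}             # expanded as keyword arguments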
342class TaskInstance(Base, LoggingMixin):
343 """
344 Task instances store the state of a task instance. This table is the
345 authority and single source of truth around what tasks have run and the
346 state they are in.
348 The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or
349 dag model deliberately to have more control over transactions.
 351 Database transactions on this table should guard against double triggers and
 352 any confusion around which task instances are or aren't ready to run,
353 even while multiple schedulers may be firing task instances.
355 A value of -1 in map_index represents any of: a TI without mapped tasks;
356 a TI with mapped tasks that has yet to be expanded (state=pending);
357 a TI with mapped tasks that expanded to an empty list (state=skipped).
358 """
360 __tablename__ = "task_instance"
361 task_id = Column(StringID(), primary_key=True, nullable=False)
362 dag_id = Column(StringID(), primary_key=True, nullable=False)
363 run_id = Column(StringID(), primary_key=True, nullable=False)
364 map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
366 start_date = Column(UtcDateTime)
367 end_date = Column(UtcDateTime)
368 duration = Column(Float)
369 state = Column(String(20))
370 _try_number = Column("try_number", Integer, default=0)
371 max_tries = Column(Integer, server_default=text("-1"))
372 hostname = Column(String(1000))
373 unixname = Column(String(1000))
374 job_id = Column(Integer)
375 pool = Column(String(256), nullable=False)
376 pool_slots = Column(Integer, default=1, nullable=False)
377 queue = Column(String(256))
378 priority_weight = Column(Integer)
379 operator = Column(String(1000))
380 queued_dttm = Column(UtcDateTime)
381 queued_by_job_id = Column(Integer)
382 pid = Column(Integer)
383 executor_config = Column(ExecutorConfigType(pickler=dill))
384 updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
386 external_executor_id = Column(StringID())
388 # The trigger to resume on if we are in state DEFERRED
389 trigger_id = Column(Integer)
391 # Optional timeout datetime for the trigger (past this, we'll fail)
392 trigger_timeout = Column(DateTime)
393 # The trigger_timeout should be TIMESTAMP(using UtcDateTime) but for ease of
394 # migration, we are keeping it as DateTime pending a change where expensive
395 # migration is inevitable.
397 # The method to call next, and any extra arguments to pass to it.
398 # Usually used when resuming from DEFERRED.
399 next_method = Column(String(1000))
400 next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))
402 # If adding new fields here then remember to add them to
403 # refresh_from_db() or they won't display in the UI correctly
405 __table_args__ = (
406 Index("ti_dag_state", dag_id, state),
407 Index("ti_dag_run", dag_id, run_id),
408 Index("ti_state", state),
409 Index("ti_state_lkp", dag_id, task_id, run_id, state),
410 Index("ti_pool", pool, state, priority_weight),
411 Index("ti_job_id", job_id),
412 Index("ti_trigger_id", trigger_id),
413 PrimaryKeyConstraint(
414 "dag_id", "task_id", "run_id", "map_index", name="task_instance_pkey", mssql_clustered=True
415 ),
416 ForeignKeyConstraint(
417 [trigger_id],
418 ["trigger.id"],
419 name="task_instance_trigger_id_fkey",
420 ondelete="CASCADE",
421 ),
422 ForeignKeyConstraint(
423 [dag_id, run_id],
424 ["dag_run.dag_id", "dag_run.run_id"],
425 name="task_instance_dag_run_fkey",
426 ondelete="CASCADE",
427 ),
428 )
430 dag_model = relationship(
431 "DagModel",
432 primaryjoin="TaskInstance.dag_id == DagModel.dag_id",
433 foreign_keys=dag_id,
434 uselist=False,
435 innerjoin=True,
436 viewonly=True,
437 )
439 trigger = relationship("Trigger", uselist=False)
440 triggerer_job = association_proxy("trigger", "triggerer_job")
441 dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True)
442 rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False)
443 execution_date = association_proxy("dag_run", "execution_date")
444 task_instance_note = relationship("TaskInstanceNote", back_populates="task_instance", uselist=False)
445 note = association_proxy("task_instance_note", "content", creator=_creator_note)
446 task: Operator # Not always set...
448 def __init__(
449 self,
450 task: Operator,
451 execution_date: datetime | None = None,
452 run_id: str | None = None,
453 state: str | None = None,
454 map_index: int = -1,
455 ):
456 super().__init__()
457 self.dag_id = task.dag_id
458 self.task_id = task.task_id
459 self.map_index = map_index
460 self.refresh_from_task(task)
461 # init_on_load will config the log
462 self.init_on_load()
464 if run_id is None and execution_date is not None:
465 from airflow.models.dagrun import DagRun # Avoid circular import
467 warnings.warn(
468 "Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id",
469 RemovedInAirflow3Warning,
470 # Stack level is 4 because SQLA adds some wrappers around the constructor
471 stacklevel=4,
472 )
473 # make sure we have a localized execution_date stored in UTC
474 if execution_date and not timezone.is_localized(execution_date):
475 self.log.warning(
476 "execution date %s has no timezone information. Using default from dag or system",
477 execution_date,
478 )
479 if self.task.has_dag():
480 if TYPE_CHECKING:
481 assert self.task.dag
482 execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
483 else:
484 execution_date = timezone.make_aware(execution_date)
486 execution_date = timezone.convert_to_utc(execution_date)
487 with create_session() as session:
488 run_id = (
489 session.query(DagRun.run_id)
490 .filter_by(dag_id=self.dag_id, execution_date=execution_date)
491 .scalar()
492 )
493 if run_id is None:
494 raise DagRunNotFound(
495 f"DagRun for {self.dag_id!r} with date {execution_date} not found"
496 ) from None
498 self.run_id = run_id
500 self.try_number = 0
501 self.max_tries = self.task.retries
502 self.unixname = getuser()
503 if state:
504 self.state = state
505 self.hostname = ""
506 # Is this TaskInstance being currently running within `airflow tasks run --raw`.
507 # Not persisted to the database so only valid for the current process
508 self.raw = False
509 # can be changed when calling 'run'
510 self.test_mode = False
512 @staticmethod
513 def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any]:
514 """:meta private:"""
515 return {
516 "dag_id": task.dag_id,
517 "task_id": task.task_id,
518 "run_id": run_id,
519 "_try_number": 0,
520 "hostname": "",
521 "unixname": getuser(),
522 "queue": task.queue,
523 "pool": task.pool,
524 "pool_slots": task.pool_slots,
525 "priority_weight": task.priority_weight_total,
526 "run_as_user": task.run_as_user,
527 "max_tries": task.retries,
528 "executor_config": task.executor_config,
529 "operator": task.task_type,
530 "map_index": map_index,
531 }
533 @reconstructor
534 def init_on_load(self) -> None:
535 """Initialize the attributes that aren't stored in the DB"""
536 # correctly config the ti log
537 self._log = logging.getLogger("airflow.task")
538 self.test_mode = False # can be changed when calling 'run'
540 @property
541 def try_number(self):
542 """
 543 Return the try number that this task instance will have when it is actually
544 run.
546 If the TaskInstance is currently running, this will match the column in the
547 database, in all other cases this will be incremented.
548 """
549 # This is designed so that task logs end up in the right file.
550 if self.state in State.running:
551 return self._try_number
552 return self._try_number + 1
554 @try_number.setter
555 def try_number(self, value: int) -> None:
556 self._try_number = value
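# Illustrative walk-through (not part of the original module): the stored column is
# ``_try_number``; the public property anticipates the next attempt unless the task is
# currently running, so log files line up with the attempt actually being made.
#
#     _try_number == 1 and state is RUNNING   ->  try_number == 1
#     _try_number == 1 and state is SUCCESS   ->  try_number == 2
#     next_try_number                         ->  always _try_number + 1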
558 @property
559 def prev_attempted_tries(self) -> int:
560 """
561 Based on this instance's try_number, this will calculate
562 the number of previously attempted tries, defaulting to 0.
563 """
564 # Expose this for the Task Tries and Gantt graph views.
565 # Using `try_number` throws off the counts for non-running tasks.
566 # Also useful in error logging contexts to get
567 # the try number for the last try that was attempted.
568 # https://issues.apache.org/jira/browse/AIRFLOW-2143
570 return self._try_number
572 @property
573 def next_try_number(self) -> int:
574 return self._try_number + 1
576 def command_as_list(
577 self,
578 mark_success=False,
579 ignore_all_deps=False,
580 ignore_task_deps=False,
581 ignore_depends_on_past=False,
582 ignore_ti_state=False,
583 local=False,
584 pickle_id=None,
585 raw=False,
586 job_id=None,
587 pool=None,
588 cfg_path=None,
589 ):
590 """
 591 Returns a command that can be executed anywhere airflow is
592 installed. This command is part of the message sent to executors by
593 the orchestrator.
594 """
595 dag: DAG | DagModel
596 # Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
597 if hasattr(self, "task") and hasattr(self.task, "dag"):
598 dag = self.task.dag
599 else:
600 dag = self.dag_model
602 should_pass_filepath = not pickle_id and dag
603 path = None
604 if should_pass_filepath:
605 if dag.is_subdag:
606 path = dag.parent_dag.relative_fileloc
607 else:
608 path = dag.relative_fileloc
610 if path:
611 if not path.is_absolute():
612 path = "DAGS_FOLDER" / path
613 path = str(path)
615 return TaskInstance.generate_command(
616 self.dag_id,
617 self.task_id,
618 run_id=self.run_id,
619 mark_success=mark_success,
620 ignore_all_deps=ignore_all_deps,
621 ignore_task_deps=ignore_task_deps,
622 ignore_depends_on_past=ignore_depends_on_past,
623 ignore_ti_state=ignore_ti_state,
624 local=local,
625 pickle_id=pickle_id,
626 file_path=path,
627 raw=raw,
628 job_id=job_id,
629 pool=pool,
630 cfg_path=cfg_path,
631 map_index=self.map_index,
632 )
634 @staticmethod
635 def generate_command(
636 dag_id: str,
637 task_id: str,
638 run_id: str,
639 mark_success: bool = False,
640 ignore_all_deps: bool = False,
641 ignore_depends_on_past: bool = False,
642 ignore_task_deps: bool = False,
643 ignore_ti_state: bool = False,
644 local: bool = False,
645 pickle_id: int | None = None,
646 file_path: str | None = None,
647 raw: bool = False,
648 job_id: str | None = None,
649 pool: str | None = None,
650 cfg_path: str | None = None,
651 map_index: int = -1,
652 ) -> list[str]:
653 """
654 Generates the shell command required to execute this task instance.
656 :param dag_id: DAG ID
657 :param task_id: Task ID
658 :param run_id: The run_id of this task's DagRun
659 :param mark_success: Whether to mark the task as successful
660 :param ignore_all_deps: Ignore all ignorable dependencies.
661 Overrides the other ignore_* parameters.
662 :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs
663 (e.g. for Backfills)
664 :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past
665 and trigger rule
666 :param ignore_ti_state: Ignore the task instance's previous failure/success
667 :param local: Whether to run the task locally
668 :param pickle_id: If the DAG was serialized to the DB, the ID
669 associated with the pickled DAG
670 :param file_path: path to the file containing the DAG definition
671 :param raw: raw mode (needs more details)
672 :param job_id: job ID (needs more details)
673 :param pool: the Airflow pool that the task should run in
674 :param cfg_path: the Path to the configuration file
675 :return: shell command that can be used to run the task instance
676 """
677 cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id]
678 if mark_success:
679 cmd.extend(["--mark-success"])
680 if pickle_id:
681 cmd.extend(["--pickle", str(pickle_id)])
682 if job_id:
683 cmd.extend(["--job-id", str(job_id)])
684 if ignore_all_deps:
685 cmd.extend(["--ignore-all-dependencies"])
686 if ignore_task_deps:
687 cmd.extend(["--ignore-dependencies"])
688 if ignore_depends_on_past:
689 cmd.extend(["--ignore-depends-on-past"])
690 if ignore_ti_state:
691 cmd.extend(["--force"])
692 if local:
693 cmd.extend(["--local"])
694 if pool:
695 cmd.extend(["--pool", pool])
696 if raw:
697 cmd.extend(["--raw"])
698 if file_path:
699 cmd.extend(["--subdir", file_path])
700 if cfg_path:
701 cmd.extend(["--cfg-path", cfg_path])
702 if map_index != -1:
703 cmd.extend(["--map-index", str(map_index)])
704 return cmd
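# Illustrative sketch (values assumed, not part of the original module): for a mapped
# task instance run locally in a specific pool, generate_command builds something like
#
#     TaskInstance.generate_command("example", "extract", "manual__1",
#                                   local=True, pool="default_pool", map_index=2)
#     # -> ["airflow", "tasks", "run", "example", "extract", "manual__1",
#     #     "--local", "--pool", "default_pool", "--map-index", "2"]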
706 @property
707 def log_url(self) -> str:
708 """Log URL for TaskInstance"""
709 iso = quote(self.execution_date.isoformat())
710 base_url = conf.get_mandatory_value("webserver", "BASE_URL")
711 return (
712 f"{base_url}/log"
713 f"?execution_date={iso}"
714 f"&task_id={self.task_id}"
715 f"&dag_id={self.dag_id}"
716 f"&map_index={self.map_index}"
717 )
719 @property
720 def mark_success_url(self) -> str:
721 """URL to mark TI success"""
722 base_url = conf.get_mandatory_value("webserver", "BASE_URL")
723 return (
724 f"{base_url}/confirm"
725 f"?task_id={self.task_id}"
726 f"&dag_id={self.dag_id}"
727 f"&dag_run_id={quote(self.run_id)}"
728 "&upstream=false"
729 "&downstream=false"
730 "&state=success"
731 )
733 @provide_session
734 def current_state(self, session: Session = NEW_SESSION) -> str:
735 """
 736 Get the very latest state from the database. If a session is passed, we
 737 use it, and looking up the state becomes part of that session; otherwise
 738 a new session is used.
 740 sqlalchemy.inspect is used here to get the primary keys, ensuring that if
 741 they change this lookup does not regress.
743 :param session: SQLAlchemy ORM Session
744 """
745 filters = (col == getattr(self, col.name) for col in inspect(TaskInstance).primary_key)
746 return session.query(TaskInstance.state).filter(*filters).scalar()
748 @provide_session
749 def error(self, session: Session = NEW_SESSION) -> None:
750 """
751 Forces the task instance's state to FAILED in the database.
753 :param session: SQLAlchemy ORM Session
754 """
755 self.log.error("Recording the task instance as FAILED")
756 self.state = State.FAILED
757 session.merge(self)
758 session.commit()
760 @provide_session
761 def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
762 """
763 Refreshes the task instance from the database based on the primary key
765 :param session: SQLAlchemy ORM Session
766 :param lock_for_update: if True, indicates that the database should
767 lock the TaskInstance (issuing a FOR UPDATE clause) until the
768 session is committed.
769 """
770 self.log.debug("Refreshing TaskInstance %s from DB", self)
772 if self in session:
773 session.refresh(self, TaskInstance.__mapper__.column_attrs.keys())
775 qry = (
776 # To avoid joining any relationships, by default select all
777 # columns, not the object. This also means we get (effectively) a
778 # namedtuple back, not a TI object
779 session.query(*TaskInstance.__table__.columns).filter(
780 TaskInstance.dag_id == self.dag_id,
781 TaskInstance.task_id == self.task_id,
782 TaskInstance.run_id == self.run_id,
783 TaskInstance.map_index == self.map_index,
784 )
785 )
787 if lock_for_update:
788 for attempt in run_with_db_retries(logger=self.log):
789 with attempt:
790 ti: TaskInstance | None = qry.with_for_update().one_or_none()
791 else:
792 ti = qry.one_or_none()
793 if ti:
794 # Fields ordered per model definition
795 self.start_date = ti.start_date
796 self.end_date = ti.end_date
797 self.duration = ti.duration
798 self.state = ti.state
799 # Since we selected columns, not the object, this is the raw value
800 self.try_number = ti.try_number
801 self.max_tries = ti.max_tries
802 self.hostname = ti.hostname
803 self.unixname = ti.unixname
804 self.job_id = ti.job_id
805 self.pool = ti.pool
806 self.pool_slots = ti.pool_slots or 1
807 self.queue = ti.queue
808 self.priority_weight = ti.priority_weight
809 self.operator = ti.operator
810 self.queued_dttm = ti.queued_dttm
811 self.queued_by_job_id = ti.queued_by_job_id
812 self.pid = ti.pid
813 self.executor_config = ti.executor_config
814 self.external_executor_id = ti.external_executor_id
815 self.trigger_id = ti.trigger_id
816 self.next_method = ti.next_method
817 self.next_kwargs = ti.next_kwargs
818 else:
819 self.state = None
821 def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
822 """
823 Copy common attributes from the given task.
825 :param task: The task object to copy from
826 :param pool_override: Use the pool_override instead of task's pool
827 """
828 self.task = task
829 self.queue = task.queue
830 self.pool = pool_override or task.pool
831 self.pool_slots = task.pool_slots
832 self.priority_weight = task.priority_weight_total
833 self.run_as_user = task.run_as_user
834 # Do not set max_tries to task.retries here because max_tries is a cumulative
835 # value that needs to be stored in the db.
836 self.executor_config = task.executor_config
837 self.operator = task.task_type
839 @provide_session
840 def clear_xcom_data(self, session: Session = NEW_SESSION) -> None:
841 """Clear all XCom data from the database for the task instance.
843 If the task is unmapped, all XComs matching this task ID in the same DAG
844 run are removed. If the task is mapped, only the one with matching map
845 index is removed.
847 :param session: SQLAlchemy ORM Session
848 """
849 self.log.debug("Clearing XCom data")
850 if self.map_index < 0:
851 map_index: int | None = None
852 else:
853 map_index = self.map_index
854 XCom.clear(
855 dag_id=self.dag_id,
856 task_id=self.task_id,
857 run_id=self.run_id,
858 map_index=map_index,
859 session=session,
860 )
862 @property
863 def key(self) -> TaskInstanceKey:
864 """Returns a tuple that identifies the task instance uniquely"""
865 return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
867 @provide_session
868 def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
869 """
870 Set TaskInstance state.
872 :param state: State to set for the TI
873 :param session: SQLAlchemy ORM Session
874 :return: Was the state changed
875 """
876 if self.state == state:
877 return False
879 current_time = timezone.utcnow()
880 self.log.debug("Setting task state for %s to %s", self, state)
881 self.state = state
882 self.start_date = self.start_date or current_time
883 if self.state in State.finished or self.state == State.UP_FOR_RETRY:
884 self.end_date = self.end_date or current_time
885 self.duration = (self.end_date - self.start_date).total_seconds()
886 session.merge(self)
887 return True
889 @property
890 def is_premature(self) -> bool:
891 """
 892 Returns whether a task is in UP_FOR_RETRY state and its retry interval
 893 has not yet elapsed, i.e. it is still waiting out its retry delay.
894 """
895 # is the task still in the retry waiting period?
896 return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()
898 @provide_session
899 def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
900 """
901 Checks whether the immediate dependents of this task instance have succeeded or have been skipped.
902 This is meant to be used by wait_for_downstream.
904 This is useful when you do not want to start processing the next
905 schedule of a task until the dependents are done. For instance,
906 if the task DROPs and recreates a table.
908 :param session: SQLAlchemy ORM Session
909 """
910 task = self.task
912 if not task.downstream_task_ids:
913 return True
915 ti = session.query(func.count(TaskInstance.task_id)).filter(
916 TaskInstance.dag_id == self.dag_id,
917 TaskInstance.task_id.in_(task.downstream_task_ids),
918 TaskInstance.run_id == self.run_id,
919 TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]),
920 )
921 count = ti[0][0]
922 return count == len(task.downstream_task_ids)
924 @provide_session
925 def get_previous_dagrun(
926 self,
927 state: DagRunState | None = None,
928 session: Session | None = None,
929 ) -> DagRun | None:
930 """The DagRun that ran before this task instance's DagRun.
 932 :param state: If passed, it only takes into account instances of a specific state.
933 :param session: SQLAlchemy ORM Session.
934 """
935 dag = self.task.dag
936 if dag is None:
937 return None
939 dr = self.get_dagrun(session=session)
940 dr.dag = dag
942 # We always ignore schedule in dagrun lookup when `state` is given
943 # or the DAG is never scheduled. For legacy reasons, when
944 # `catchup=True`, we use `get_previous_scheduled_dagrun` unless
945 # `ignore_schedule` is `True`.
946 ignore_schedule = state is not None or not dag.timetable.can_run
947 if dag.catchup is True and not ignore_schedule:
948 last_dagrun = dr.get_previous_scheduled_dagrun(session=session)
949 else:
950 last_dagrun = dr.get_previous_dagrun(session=session, state=state)
952 if last_dagrun:
953 return last_dagrun
955 return None
957 @provide_session
958 def get_previous_ti(
959 self,
960 state: DagRunState | None = None,
961 session: Session = NEW_SESSION,
962 ) -> TaskInstance | None:
963 """
964 The task instance for the task that ran before this task instance.
 966 :param state: If passed, it only takes into account instances of a specific state.
967 :param session: SQLAlchemy ORM Session
968 """
969 dagrun = self.get_previous_dagrun(state, session=session)
970 if dagrun is None:
971 return None
972 return dagrun.get_task_instance(self.task_id, session=session)
974 @property
975 def previous_ti(self) -> TaskInstance | None:
976 """
977 This attribute is deprecated.
978 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
979 """
980 warnings.warn(
981 """
982 This attribute is deprecated.
983 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
984 """,
985 RemovedInAirflow3Warning,
986 stacklevel=2,
987 )
988 return self.get_previous_ti()
990 @property
991 def previous_ti_success(self) -> TaskInstance | None:
992 """
993 This attribute is deprecated.
994 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
995 """
996 warnings.warn(
997 """
998 This attribute is deprecated.
999 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
1000 """,
1001 RemovedInAirflow3Warning,
1002 stacklevel=2,
1003 )
1004 return self.get_previous_ti(state=DagRunState.SUCCESS)
1006 @provide_session
1007 def get_previous_execution_date(
1008 self,
1009 state: DagRunState | None = None,
1010 session: Session = NEW_SESSION,
1011 ) -> pendulum.DateTime | None:
1012 """
1013 The execution date from property previous_ti_success.
 1015 :param state: If passed, it only takes into account instances of a specific state.
1016 :param session: SQLAlchemy ORM Session
1017 """
1018 self.log.debug("previous_execution_date was called")
1019 prev_ti = self.get_previous_ti(state=state, session=session)
1020 return prev_ti and pendulum.instance(prev_ti.execution_date)
1022 @provide_session
1023 def get_previous_start_date(
1024 self, state: DagRunState | None = None, session: Session = NEW_SESSION
1025 ) -> pendulum.DateTime | None:
1026 """
1027 The start date from property previous_ti_success.
 1029 :param state: If passed, it only takes into account instances of a specific state.
1030 :param session: SQLAlchemy ORM Session
1031 """
1032 self.log.debug("previous_start_date was called")
1033 prev_ti = self.get_previous_ti(state=state, session=session)
1034 # prev_ti may not exist and prev_ti.start_date may be None.
1035 return prev_ti and prev_ti.start_date and pendulum.instance(prev_ti.start_date)
1037 @property
1038 def previous_start_date_success(self) -> pendulum.DateTime | None:
1039 """
1040 This attribute is deprecated.
1041 Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method.
1042 """
1043 warnings.warn(
1044 """
1045 This attribute is deprecated.
1046 Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method.
1047 """,
1048 RemovedInAirflow3Warning,
1049 stacklevel=2,
1050 )
1051 return self.get_previous_start_date(state=DagRunState.SUCCESS)
1053 @provide_session
1054 def are_dependencies_met(
1055 self, dep_context: DepContext | None = None, session: Session = NEW_SESSION, verbose: bool = False
1056 ) -> bool:
1057 """
1058 Returns whether or not all the conditions are met for this task instance to be run
1059 given the context for the dependencies (e.g. a task instance being force run from
1060 the UI will ignore some dependencies).
1062 :param dep_context: The execution context that determines the dependencies that
1063 should be evaluated.
1064 :param session: database session
 1065 :param verbose: whether to log details of failed dependencies at
 1066 info or debug log level
1067 """
1068 dep_context = dep_context or DepContext()
1069 failed = False
1070 verbose_aware_logger = self.log.info if verbose else self.log.debug
1071 for dep_status in self.get_failed_dep_statuses(dep_context=dep_context, session=session):
1072 failed = True
1074 verbose_aware_logger(
1075 "Dependencies not met for %s, dependency '%s' FAILED: %s",
1076 self,
1077 dep_status.dep_name,
1078 dep_status.reason,
1079 )
1081 if failed:
1082 return False
1084 verbose_aware_logger("Dependencies all met for %s", self)
1085 return True
1087 @provide_session
1088 def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION):
1089 """Get failed Dependencies"""
1090 dep_context = dep_context or DepContext()
1091 for dep in dep_context.deps | self.task.deps:
1092 for dep_status in dep.get_dep_statuses(self, session, dep_context):
1094 self.log.debug(
1095 "%s dependency '%s' PASSED: %s, %s",
1096 self,
1097 dep_status.dep_name,
1098 dep_status.passed,
1099 dep_status.reason,
1100 )
1102 if not dep_status.passed:
1103 yield dep_status
1105 def __repr__(self) -> str:
1106 prefix = f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} "
1107 if self.map_index != -1:
1108 prefix += f"map_index={self.map_index} "
1109 return prefix + f"[{self.state}]>"
1111 def next_retry_datetime(self):
1112 """
1113 Get datetime of the next retry if the task instance fails. For exponential
1114 backoff, retry_delay is used as base and will be converted to seconds.
1115 """
1116 from airflow.models.abstractoperator import MAX_RETRY_DELAY
1118 delay = self.task.retry_delay
1119 if self.task.retry_exponential_backoff:
1120 # If the min_backoff calculation is below 1, it will be converted to 0 via int. Thus,
1121 # we must round up prior to converting to an int, otherwise a divide by zero error
1122 # will occur in the modded_hash calculation.
1123 min_backoff = int(math.ceil(delay.total_seconds() * (2 ** (self.try_number - 2))))
1125 # In the case when delay.total_seconds() is 0, min_backoff will not be rounded up to 1.
1126 # To address this, we impose a lower bound of 1 on min_backoff. This effectively makes
1127 # the ceiling function unnecessary, but the ceiling function was retained to avoid
1128 # introducing a breaking change.
1129 if min_backoff < 1:
1130 min_backoff = 1
1132 # deterministic per task instance
1133 ti_hash = int(
1134 hashlib.sha1(
1135 f"{self.dag_id}#{self.task_id}#{self.execution_date}#{self.try_number}".encode()
1136 ).hexdigest(),
1137 16,
1138 )
 1139 # deterministic jitter: result is between min_backoff and 2 * min_backoff - 1
1140 modded_hash = min_backoff + ti_hash % min_backoff
1141 # timedelta has a maximum representable value. The exponentiation
1142 # here means this value can be exceeded after a certain number
1143 # of tries (around 50 if the initial delay is 1s, even fewer if
1144 # the delay is larger). Cap the value here before creating a
1145 # timedelta object so the operation doesn't fail with "OverflowError".
1146 delay_backoff_in_seconds = min(modded_hash, MAX_RETRY_DELAY)
1147 delay = timedelta(seconds=delay_backoff_in_seconds)
1148 if self.task.max_retry_delay:
1149 delay = min(self.task.max_retry_delay, delay)
1150 return self.end_date + delay
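# Illustrative arithmetic (values assumed, not part of the original module): with a
# retry_delay of 300 seconds, retry_exponential_backoff enabled and try_number == 3,
#
#     min_backoff = ceil(300 * 2 ** (3 - 2)) = 600 seconds
#     delay       = min_backoff + (ti_hash % min_backoff)   # 600..1199 s, deterministic per TI
#
# and the result is further capped by MAX_RETRY_DELAY and, when set, task.max_retry_delay.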
1152 def ready_for_retry(self) -> bool:
1153 """
 1154 Checks whether the task instance is in the right state and timeframe
1155 to be retried.
1156 """
1157 return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()
1159 @provide_session
1160 def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
1161 """
1162 Returns the DagRun for this TaskInstance
1164 :param session: SQLAlchemy ORM Session
1165 :return: DagRun
1166 """
1167 info = inspect(self)
1168 if info.attrs.dag_run.loaded_value is not NO_VALUE:
1169 return self.dag_run
1171 from airflow.models.dagrun import DagRun # Avoid circular import
1173 dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one()
1175 # Record it in the instance for next time. This means that `self.execution_date` will work correctly
1176 set_committed_value(self, "dag_run", dr)
1178 return dr
1180 @provide_session
1181 def check_and_change_state_before_execution(
1182 self,
1183 verbose: bool = True,
1184 ignore_all_deps: bool = False,
1185 ignore_depends_on_past: bool = False,
1186 ignore_task_deps: bool = False,
1187 ignore_ti_state: bool = False,
1188 mark_success: bool = False,
1189 test_mode: bool = False,
1190 job_id: str | None = None,
1191 pool: str | None = None,
1192 external_executor_id: str | None = None,
1193 session: Session = NEW_SESSION,
1194 ) -> bool:
1195 """
1196 Checks dependencies and then sets state to RUNNING if they are met. Returns
 1197 True if and only if state is set to RUNNING, which implies that the task should be
 1198 executed, in preparation for _run_raw_task.
1200 :param verbose: whether to turn on more verbose logging
1201 :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs
1202 :param ignore_depends_on_past: Ignore depends_on_past DAG attribute
1203 :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task
1204 :param ignore_ti_state: Disregards previous task instance state
1205 :param mark_success: Don't run the task, mark its state as success
1206 :param test_mode: Doesn't record success or failure in the DB
1207 :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID
1208 :param pool: specifies the pool to use to run the task instance
1209 :param external_executor_id: The identifier of the celery executor
1210 :param session: SQLAlchemy ORM Session
1211 :return: whether the state was changed to running or not
1212 """
1213 task = self.task
1214 self.refresh_from_task(task, pool_override=pool)
1215 self.test_mode = test_mode
1216 self.refresh_from_db(session=session, lock_for_update=True)
1217 self.job_id = job_id
1218 self.hostname = get_hostname()
1219 self.pid = None
1221 if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
1222 Stats.incr("previously_succeeded", 1, 1)
1224 # TODO: Logging needs cleanup, not clear what is being printed
1225 hr_line_break = "\n" + ("-" * 80) # Line break
1227 if not mark_success:
1228 # Firstly find non-runnable and non-requeueable tis.
1229 # Since mark_success is not set, we do nothing.
1230 non_requeueable_dep_context = DepContext(
1231 deps=RUNNING_DEPS - REQUEUEABLE_DEPS,
1232 ignore_all_deps=ignore_all_deps,
1233 ignore_ti_state=ignore_ti_state,
1234 ignore_depends_on_past=ignore_depends_on_past,
1235 ignore_task_deps=ignore_task_deps,
1236 )
1237 if not self.are_dependencies_met(
1238 dep_context=non_requeueable_dep_context, session=session, verbose=True
1239 ):
1240 session.commit()
1241 return False
1243 # For reporting purposes, we report based on 1-indexed,
1244 # not 0-indexed lists (i.e. Attempt 1 instead of
1245 # Attempt 0 for the first attempt).
1246 # Set the task start date. In case it was re-scheduled use the initial
1247 # start date that is recorded in task_reschedule table
1248 # If the task continues after being deferred (next_method is set), use the original start_date
1249 self.start_date = self.start_date if self.next_method else timezone.utcnow()
1250 if self.state == State.UP_FOR_RESCHEDULE:
1251 task_reschedule: TR = TR.query_for_task_instance(self, session=session).first()
1252 if task_reschedule:
1253 self.start_date = task_reschedule.start_date
1255 # Secondly we find non-runnable but requeueable tis. We reset its state.
1256 # This is because we might have hit concurrency limits,
1257 # e.g. because of backfilling.
1258 dep_context = DepContext(
1259 deps=REQUEUEABLE_DEPS,
1260 ignore_all_deps=ignore_all_deps,
1261 ignore_depends_on_past=ignore_depends_on_past,
1262 ignore_task_deps=ignore_task_deps,
1263 ignore_ti_state=ignore_ti_state,
1264 )
1265 if not self.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
1266 self.state = State.NONE
1267 self.log.warning(hr_line_break)
1268 self.log.warning(
1269 "Rescheduling due to concurrency limits reached "
1270 "at task runtime. Attempt %s of "
1271 "%s. State set to NONE.",
1272 self.try_number,
1273 self.max_tries + 1,
1274 )
1275 self.log.warning(hr_line_break)
1276 self.queued_dttm = timezone.utcnow()
1277 session.merge(self)
1278 session.commit()
1279 return False
1281 # print status message
1282 self.log.info(hr_line_break)
1283 self.log.info("Starting attempt %s of %s", self.try_number, self.max_tries + 1)
1284 self.log.info(hr_line_break)
1285 self._try_number += 1
1287 if not test_mode:
1288 session.add(Log(State.RUNNING, self))
1289 self.state = State.RUNNING
1290 self.external_executor_id = external_executor_id
1291 self.end_date = None
1292 if not test_mode:
1293 session.merge(self).task = task
1294 session.commit()
1296 # Closing all pooled connections to prevent
1297 # "max number of connections reached"
1298 settings.engine.dispose() # type: ignore
1299 if verbose:
1300 if mark_success:
1301 self.log.info("Marking success for %s on %s", self.task, self.execution_date)
1302 else:
1303 self.log.info("Executing %s on %s", self.task, self.execution_date)
1304 return True
1306 def _date_or_empty(self, attr: str) -> str:
1307 result: datetime | None = getattr(self, attr, None)
1308 return result.strftime("%Y%m%dT%H%M%S") if result else ""
1310 def _log_state(self, lead_msg: str = "") -> None:
1311 params = [
1312 lead_msg,
1313 str(self.state).upper(),
1314 self.dag_id,
1315 self.task_id,
1316 ]
1317 message = "%sMarking task as %s. dag_id=%s, task_id=%s, "
1318 if self.map_index >= 0:
1319 params.append(self.map_index)
1320 message += "map_index=%d, "
1321 self.log.info(
1322 message + "execution_date=%s, start_date=%s, end_date=%s",
1323 *params,
1324 self._date_or_empty("execution_date"),
1325 self._date_or_empty("start_date"),
1326 self._date_or_empty("end_date"),
1327 )
 1329 # Unset next_method and next_kwargs to make sure that any
1330 # retries don't re-use them.
1331 def clear_next_method_args(self) -> None:
1332 self.log.debug("Clearing next_method and next_kwargs.")
1334 self.next_method = None
1335 self.next_kwargs = None
1337 @provide_session
1338 @Sentry.enrich_errors
1339 def _run_raw_task(
1340 self,
1341 mark_success: bool = False,
1342 test_mode: bool = False,
1343 job_id: str | None = None,
1344 pool: str | None = None,
1345 session: Session = NEW_SESSION,
1346 ) -> None:
1347 """
1348 Immediately runs the task (without checking or changing db state
1349 before execution) and then sets the appropriate final state after
1350 completion and runs any post-execute callbacks. Meant to be called
1351 only after another function changes the state to running.
1353 :param mark_success: Don't run the task, mark its state as success
1354 :param test_mode: Doesn't record success or failure in the DB
1355 :param pool: specifies the pool to use to run the task instance
1356 :param session: SQLAlchemy ORM Session
1357 """
1358 self.test_mode = test_mode
1359 self.refresh_from_task(self.task, pool_override=pool)
1360 self.refresh_from_db(session=session)
1361 self.job_id = job_id
1362 self.hostname = get_hostname()
1363 self.pid = os.getpid()
1364 if not test_mode:
1365 session.merge(self)
1366 session.commit()
1367 actual_start_date = timezone.utcnow()
1368 Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}")
1369 # Initialize final state counters at zero
1370 for state in State.task_states:
1371 Stats.incr(f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}", count=0)
1373 self.task = self.task.prepare_for_execution()
1374 context = self.get_template_context(ignore_param_exceptions=False)
1375 try:
1376 if not mark_success:
1377 self._execute_task_with_callbacks(context, test_mode)
1378 if not test_mode:
1379 self.refresh_from_db(lock_for_update=True, session=session)
1380 self.state = State.SUCCESS
1381 except TaskDeferred as defer:
1382 # The task has signalled it wants to defer execution based on
1383 # a trigger.
1384 self._defer_task(defer=defer, session=session)
1385 self.log.info(
1386 "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
1387 self.dag_id,
1388 self.task_id,
1389 self._date_or_empty("execution_date"),
1390 self._date_or_empty("start_date"),
1391 )
1392 if not test_mode:
1393 session.add(Log(self.state, self))
1394 session.merge(self)
1395 session.commit()
1396 return
1397 except AirflowSkipException as e:
1398 # Recording SKIP
1399 # log only if exception has any arguments to prevent log flooding
1400 if e.args:
1401 self.log.info(e)
1402 if not test_mode:
1403 self.refresh_from_db(lock_for_update=True, session=session)
1404 self.state = State.SKIPPED
1405 except AirflowRescheduleException as reschedule_exception:
1406 self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
1407 session.commit()
1408 return
1409 except (AirflowFailException, AirflowSensorTimeout) as e:
1410 # If AirflowFailException is raised, task should not retry.
1411 # If a sensor in reschedule mode reaches timeout, task should not retry.
1412 self.handle_failure(e, test_mode, context, force_fail=True, session=session)
1413 session.commit()
1414 raise
1415 except AirflowException as e:
1416 if not test_mode:
1417 self.refresh_from_db(lock_for_update=True, session=session)
1418 # for case when task is marked as success/failed externally
1419 # or dagrun timed out and task is marked as skipped
1420 # current behavior doesn't hit the callbacks
1421 if self.state in State.finished:
1422 self.clear_next_method_args()
1423 session.merge(self)
1424 session.commit()
1425 return
1426 else:
1427 self.handle_failure(e, test_mode, context, session=session)
1428 session.commit()
1429 raise
1430 except (Exception, KeyboardInterrupt) as e:
1431 self.handle_failure(e, test_mode, context, session=session)
1432 session.commit()
1433 raise
1434 finally:
1435 Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}")
1437 # Recording SKIPPED or SUCCESS
1438 self.clear_next_method_args()
1439 self.end_date = timezone.utcnow()
1440 self._log_state()
1441 self.set_duration()
1443 # run on_success_callback before db committing
1444 # otherwise, the LocalTaskJob sees the state is changed to `success`,
 1445 # but the task_runner is still running, and LocalTaskJob then treats the state as set externally!
1446 self._run_finished_callback(self.task.on_success_callback, context, "on_success")
1448 if not test_mode:
1449 session.add(Log(self.state, self))
1450 session.merge(self).task = self.task
1451 if self.state == TaskInstanceState.SUCCESS:
1452 self._register_dataset_changes(session=session)
1453 session.commit()
1455 def _register_dataset_changes(self, *, session: Session) -> None:
1456 for obj in self.task.outlets or []:
1457 self.log.debug("outlet obj %s", obj)
1458 # Lineage can have other types of objects besides datasets
1459 if isinstance(obj, Dataset):
1460 dataset_manager.register_dataset_change(
1461 task_instance=self,
1462 dataset=obj,
1463 session=session,
1464 )
1466 def _execute_task_with_callbacks(self, context, test_mode=False):
1467 """Prepare Task for Execution"""
1468 from airflow.models.renderedtifields import RenderedTaskInstanceFields
1470 parent_pid = os.getpid()
1472 def signal_handler(signum, frame):
1473 pid = os.getpid()
1475 # If a task forks during execution (from DAG code) for whatever
1476 # reason, we want to make sure that we react to the signal only in
1477 # the process that we've spawned ourselves (referred to here as the
1478 # parent process).
1479 if pid != parent_pid:
1480 os._exit(1)
1481 return
1482 self.log.error("Received SIGTERM. Terminating subprocesses.")
1483 self.task.on_kill()
1484 raise AirflowException("Task received SIGTERM signal")
1486 signal.signal(signal.SIGTERM, signal_handler)
1488 # Don't clear Xcom until the task is certain to execute, and check if we are resuming from deferral.
1489 if not self.next_method:
1490 self.clear_xcom_data()
1492 with Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration"):
1493 # Set the validated/merged params on the task object.
1494 self.task.params = context["params"]
1496 task_orig = self.render_templates(context=context)
1497 if not test_mode:
1498 rtif = RenderedTaskInstanceFields(ti=self, render_templates=False)
1499 RenderedTaskInstanceFields.write(rtif)
1500 RenderedTaskInstanceFields.delete_old_records(self.task_id, self.dag_id)
1502 # Export context to make it available for operators to use.
1503 airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
1504 os.environ.update(airflow_context_vars)
1506 # Log context only for the default execution method, the assumption
1507 # being that otherwise we're resuming a deferred task (in which
1508 # case there's no need to log these again).
1509 if not self.next_method:
1510 self.log.info(
1511 "Exporting the following env vars:\n%s",
1512 "\n".join(f"{k}={v}" for k, v in airflow_context_vars.items()),
1513 )
1515 # Run pre_execute callback
1516 self.task.pre_execute(context=context)
1518 # Run on_execute callback
1519 self._run_execute_callback(context, self.task)
1521 # Execute the task
1522 with set_current_context(context):
1523 result = self._execute_task(context, task_orig)
1525 # Run post_execute callback
1526 self.task.post_execute(context=context, result=result)
1528 Stats.incr(f"operator_successes_{self.task.task_type}", 1, 1)
1529 Stats.incr("ti_successes")
1531 def _run_finished_callback(
1532 self,
1533 callbacks: None | TaskStateChangeCallback | list[TaskStateChangeCallback],
1534 context: Context,
1535 callback_type: str,
1536 ) -> None:
1537 """Run callback after task finishes"""
1538 if callbacks:
1539 callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
1540 for callback in callbacks:
1541 try:
1542 callback(context)
1543 except Exception:
1544 callback_name = qualname(callback).split(".")[-1]
1545 self.log.exception(
1546 f"Error when executing {callback_name} callback" # type: ignore[attr-defined]
1547 )
1549 def _execute_task(self, context, task_orig):
1550 """Executes Task (optionally with a Timeout) and pushes Xcom results"""
1551 task_to_execute = self.task
1552 # If the task has been deferred and is being executed due to a trigger,
1553 # then we need to pick the right method to come back to, otherwise
1554 # we go for the default execute
1555 if self.next_method:
1556 # __fail__ is a special signal value for next_method that indicates
1557 # this task was scheduled specifically to fail.
1558 if self.next_method == "__fail__":
1559 next_kwargs = self.next_kwargs or {}
1560 traceback = self.next_kwargs.get("traceback")
1561 if traceback is not None:
1562 self.log.error("Trigger failed:\n%s", "\n".join(traceback))
1563 raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
1564 # Grab the callable off the Operator/Task and add in any kwargs
1565 execute_callable = getattr(task_to_execute, self.next_method)
1566 if self.next_kwargs:
1567 execute_callable = partial(execute_callable, **self.next_kwargs)
1568 else:
1569 execute_callable = task_to_execute.execute
1570 # If a timeout is specified for the task, make it fail
1571 # if it goes beyond
1572 if task_to_execute.execution_timeout:
1573 # If we are coming in with a next_method (i.e. from a deferral),
1574 # calculate the timeout from our start_date.
1575 if self.next_method:
1576 timeout_seconds = (
1577 task_to_execute.execution_timeout - (timezone.utcnow() - self.start_date)
1578 ).total_seconds()
1579 else:
1580 timeout_seconds = task_to_execute.execution_timeout.total_seconds()
1581 try:
1582 # It's possible we're already timed out, so fast-fail if true
1583 if timeout_seconds <= 0:
1584 raise AirflowTaskTimeout()
1585 # Run task in timeout wrapper
1586 with timeout(timeout_seconds):
1587 result = execute_callable(context=context)
1588 except AirflowTaskTimeout:
1589 task_to_execute.on_kill()
1590 raise
1591 else:
1592 result = execute_callable(context=context)
1593 with create_session() as session:
1594 if task_to_execute.do_xcom_push:
1595 xcom_value = result
1596 else:
1597 xcom_value = None
1598 if xcom_value is not None: # If the task returns a result, push an XCom containing it.
1599 self.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session)
1600 self._record_task_map_for_downstreams(task_orig, xcom_value, session=session)
1601 return result
1603 @provide_session
1604 def _defer_task(self, session: Session, defer: TaskDeferred) -> None:
1605 """
1606 Marks the task as deferred and sets up the trigger that is needed
1607 to resume it.
1608 """
1609 from airflow.models.trigger import Trigger
1611 # First, make the trigger entry
1612 trigger_row = Trigger.from_object(defer.trigger)
1613 session.add(trigger_row)
1614 session.flush()
1616 # Then, update ourselves so it matches the deferral request
1617 # Keep an eye on the logic in `check_and_change_state_before_execution()`
1618 # depending on self.next_method semantics
1619 self.state = State.DEFERRED
1620 self.trigger_id = trigger_row.id
1621 self.next_method = defer.method_name
1622 self.next_kwargs = defer.kwargs or {}
1624 # Decrement try number so the next one is the same try
1625 self._try_number -= 1
1627 # Calculate timeout too if it was passed
1628 if defer.timeout is not None:
1629 self.trigger_timeout = timezone.utcnow() + defer.timeout
1630 else:
1631 self.trigger_timeout = None
1633 # If an execution_timeout is set, set the timeout to the minimum of
1634 # it and the trigger timeout
1635 execution_timeout = self.task.execution_timeout
1636 if execution_timeout:
1637 if self.trigger_timeout:
1638 self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout)
1639 else:
1640 self.trigger_timeout = self.start_date + execution_timeout
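# Illustrative example (assumed values, not from the original source): if the
# task defers at 10:05 with defer.timeout=timedelta(hours=1), having started
# at 10:00 with execution_timeout=timedelta(minutes=30), the first branch sets
# trigger_timeout to roughly 11:05 and the execution_timeout clamp then lowers
# it to min(10:00 + 0:30, 11:05) == 10:30, so the trigger cannot keep the task
# deferred past its overall execution budget.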
1642 def _run_execute_callback(self, context: Context, task: Operator) -> None:
1643 """Functions that need to be run before a Task is executed"""
1644 callbacks = task.on_execute_callback
1645 if callbacks:
1646 callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
1647 for callback in callbacks:
1648 try:
1649 callback(context)
1650 except Exception:
1651 self.log.exception("Failed when executing execute callback")
1653 @provide_session
1654 def run(
1655 self,
1656 verbose: bool = True,
1657 ignore_all_deps: bool = False,
1658 ignore_depends_on_past: bool = False,
1659 ignore_task_deps: bool = False,
1660 ignore_ti_state: bool = False,
1661 mark_success: bool = False,
1662 test_mode: bool = False,
1663 job_id: str | None = None,
1664 pool: str | None = None,
1665 session: Session = NEW_SESSION,
1666 ) -> None:
1667 """Run TaskInstance"""
1668 res = self.check_and_change_state_before_execution(
1669 verbose=verbose,
1670 ignore_all_deps=ignore_all_deps,
1671 ignore_depends_on_past=ignore_depends_on_past,
1672 ignore_task_deps=ignore_task_deps,
1673 ignore_ti_state=ignore_ti_state,
1674 mark_success=mark_success,
1675 test_mode=test_mode,
1676 job_id=job_id,
1677 pool=pool,
1678 session=session,
1679 )
1680 if not res:
1681 return
1683 self._run_raw_task(
1684 mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session
1685 )
1687 def dry_run(self) -> None:
1688 """Only Renders Templates for the TI"""
1689 from airflow.models.baseoperator import BaseOperator
1691 self.task = self.task.prepare_for_execution()
1692 self.render_templates()
1693 if TYPE_CHECKING:
1694 assert isinstance(self.task, BaseOperator)
1695 self.task.dry_run()
1697 @provide_session
1698 def _handle_reschedule(
1699 self, actual_start_date, reschedule_exception, test_mode=False, session=NEW_SESSION
1700 ):
1701 # Don't record reschedule request in test mode
1702 if test_mode:
1703 return
1705 from airflow.models.dagrun import DagRun # Avoid circular import
1707 self.refresh_from_db(session)
1709 self.end_date = timezone.utcnow()
1710 self.set_duration()
1712 # Lock the DAG run to avoid a deadlock when inserting the TaskReschedule,
1713 # which apparently also takes a lock on the corresponding DagRun entity
1714 with_row_locks(
1715 session.query(DagRun).filter_by(
1716 dag_id=self.dag_id,
1717 run_id=self.run_id,
1718 ),
1719 session=session,
1720 ).one()
1722 # Log reschedule request
1723 session.add(
1724 TaskReschedule(
1725 self.task,
1726 self.run_id,
1727 self._try_number,
1728 actual_start_date,
1729 self.end_date,
1730 reschedule_exception.reschedule_date,
1731 self.map_index,
1732 )
1733 )
1735 # set state
1736 self.state = State.UP_FOR_RESCHEDULE
1738 # Decrement try_number so subsequent runs will use the same try number and write
1739 # to the same log file.
1740 self._try_number -= 1
1742 self.clear_next_method_args()
1744 session.merge(self)
1745 session.commit()
1746 self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
1748 @staticmethod
1749 def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None:
1750 """
1751 Truncates the traceback of an exception to the first frame called from within a given function
1753 :param error: exception to get traceback from
1754 :param truncate_to: Function to truncate TB to. Must have a ``__code__`` attribute
1756 :meta private:
1757 """
1758 tb = error.__traceback__
1759 code = truncate_to.__func__.__code__ # type: ignore[attr-defined]
1760 while tb is not None:
1761 if tb.tb_frame.f_code is code:
1762 return tb.tb_next
1763 tb = tb.tb_next
1764 return tb or error.__traceback__
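# Illustrative note (not part of the original source): for an exception raised
# inside execute(), the frames look roughly like
#     _execute_task -> execute -> helper -> raise
# and with truncate_to=self._execute_task everything up to and including the
# _execute_task frame is dropped, so the logged traceback starts at execute().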
1766 @provide_session
1767 def handle_failure(
1768 self,
1769 error: None | str | Exception | KeyboardInterrupt,
1770 test_mode: bool | None = None,
1771 context: Context | None = None,
1772 force_fail: bool = False,
1773 session: Session = NEW_SESSION,
1774 ) -> None:
1775 """Handle Failure for the TaskInstance"""
1776 if test_mode is None:
1777 test_mode = self.test_mode
1779 if error:
1780 if isinstance(error, BaseException):
1781 tb = self.get_truncated_error_traceback(error, truncate_to=self._execute_task)
1782 self.log.error("Task failed with exception", exc_info=(type(error), error, tb))
1783 else:
1784 self.log.error("%s", error)
1785 if not test_mode:
1786 self.refresh_from_db(session)
1788 self.end_date = timezone.utcnow()
1789 self.set_duration()
1790 Stats.incr(f"operator_failures_{self.operator}")
1791 Stats.incr("ti_failures")
1792 if not test_mode:
1793 session.add(Log(State.FAILED, self))
1795 # Log failure duration
1796 session.add(TaskFail(ti=self))
1798 self.clear_next_method_args()
1800 # In extreme cases (e.g. a zombie task from a DAG with a parse error) we might _not_ have a Task.
1801 if context is None and getattr(self, "task", None):
1802 context = self.get_template_context(session)
1804 if context is not None:
1805 context["exception"] = error
1807 # Set state correctly and figure out how to log it and decide whether
1808 # to email
1810 # Note, callback invocation needs to be handled by caller of
1811 # _run_raw_task to avoid race conditions which could lead to duplicate
1812 # invocations or missed invocations.
1814 # Since this function is called only when the TaskInstance state is running,
1815 # try_number contains the current try_number (not the next). We
1816 # only mark the task instance as FAILED if its current try_number
1817 # exceeds max_tries ... or if force_fail is truthy
1819 task: BaseOperator | None = None
1820 try:
1821 if getattr(self, "task", None) and context:
1822 task = self.task.unmap((context, session))
1823 except Exception:
1824 self.log.error("Unable to unmap task to determine if we need to send an alert email")
1826 if force_fail or not self.is_eligible_to_retry():
1827 self.state = State.FAILED
1828 email_for_state = operator.attrgetter("email_on_failure")
1829 callbacks = task.on_failure_callback if task else None
1830 callback_type = "on_failure"
1831 else:
1832 if self.state == State.QUEUED:
1833 # We increase the try_number so as to fail the task if it fails to start after some time
1834 self._try_number += 1
1835 self.state = State.UP_FOR_RETRY
1836 email_for_state = operator.attrgetter("email_on_retry")
1837 callbacks = task.on_retry_callback if task else None
1838 callback_type = "on_retry"
1840 self._log_state("Immediate failure requested. " if force_fail else "")
1841 if task and email_for_state(task) and task.email:
1842 try:
1843 self.email_alert(error, task)
1844 except Exception:
1845 self.log.exception("Failed to send email to: %s", task.email)
1847 if callbacks and context:
1848 self._run_finished_callback(callbacks, context, callback_type)
1850 if not test_mode:
1851 session.merge(self)
1852 session.flush()
1854 def is_eligible_to_retry(self):
1855 """Is task instance is eligible for retry"""
1856 if self.state == State.RESTARTING:
1857 # If a task is cleared when running, it goes into RESTARTING state and is always
1858 # eligible for retry
1859 return True
1860 if not getattr(self, "task", None):
1861 # Couldn't load the task, don't know number of retries, guess:
1862 return self.try_number <= self.max_tries
1864 return self.task.retries and self.try_number <= self.max_tries
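# Illustrative example (assumed values, not from the original source): for an
# operator created with retries=2, max_tries normally starts at 2, so attempt 1
# (try_number=1) and attempt 2 (try_number=2) both satisfy
# try_number <= max_tries and go to UP_FOR_RETRY on failure, while attempt 3
# (try_number=3) does not and the task is marked FAILED. With retries=0 the
# first operand is falsy and the task never retries. Clearing a task instance
# raises max_tries, which is why the comparison uses max_tries rather than
# the raw retries value.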
1866 def get_template_context(
1867 self,
1868 session: Session | None = None,
1869 ignore_param_exceptions: bool = True,
1870 ) -> Context:
1871 """Return TI Context"""
1872 # Do not use provide_session here -- it expunges everything on exit!
1873 if not session:
1874 session = settings.Session()
1876 from airflow import macros
1877 from airflow.models.abstractoperator import NotMapped
1879 integrate_macros_plugins()
1881 task = self.task
1882 if TYPE_CHECKING:
1883 assert task.dag
1884 dag: DAG = task.dag
1886 dag_run = self.get_dagrun(session)
1887 data_interval = dag.get_run_data_interval(dag_run)
1889 validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)
1891 logical_date = timezone.coerce_datetime(self.execution_date)
1892 ds = logical_date.strftime("%Y-%m-%d")
1893 ds_nodash = ds.replace("-", "")
1894 ts = logical_date.isoformat()
1895 ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
1896 ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
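# Worked example (assumed logical date, not from the original source): for
# logical_date = pendulum.datetime(2022, 12, 25, 6, 11, tz="UTC") the derived
# strings are:
#   ds                = "2022-12-25"
#   ds_nodash         = "20221225"
#   ts                = "2022-12-25T06:11:00+00:00"
#   ts_nodash         = "20221225T061100"
#   ts_nodash_with_tz = "20221225T061100+0000"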
1898 @cache # Prevent multiple database access.
1899 def _get_previous_dagrun_success() -> DagRun | None:
1900 return self.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)
1902 def _get_previous_dagrun_data_interval_success() -> DataInterval | None:
1903 dagrun = _get_previous_dagrun_success()
1904 if dagrun is None:
1905 return None
1906 return dag.get_run_data_interval(dagrun)
1908 def get_prev_data_interval_start_success() -> pendulum.DateTime | None:
1909 data_interval = _get_previous_dagrun_data_interval_success()
1910 if data_interval is None:
1911 return None
1912 return data_interval.start
1914 def get_prev_data_interval_end_success() -> pendulum.DateTime | None:
1915 data_interval = _get_previous_dagrun_data_interval_success()
1916 if data_interval is None:
1917 return None
1918 return data_interval.end
1920 def get_prev_start_date_success() -> pendulum.DateTime | None:
1921 dagrun = _get_previous_dagrun_success()
1922 if dagrun is None:
1923 return None
1924 return timezone.coerce_datetime(dagrun.start_date)
1926 @cache
1927 def get_yesterday_ds() -> str:
1928 return (logical_date - timedelta(1)).strftime("%Y-%m-%d")
1930 def get_yesterday_ds_nodash() -> str:
1931 return get_yesterday_ds().replace("-", "")
1933 @cache
1934 def get_tomorrow_ds() -> str:
1935 return (logical_date + timedelta(1)).strftime("%Y-%m-%d")
1937 def get_tomorrow_ds_nodash() -> str:
1938 return get_tomorrow_ds().replace("-", "")
1940 @cache
1941 def get_next_execution_date() -> pendulum.DateTime | None:
1942 # For manually triggered dagruns that aren't run on a schedule,
1943 # the "next" execution date doesn't make sense, and should be set
1944 # to execution date for consistency with how execution_date is set
1945 # for manually triggered tasks, i.e. triggered_date == execution_date.
1946 if dag_run.external_trigger:
1947 return logical_date
1948 if dag is None:
1949 return None
1950 next_info = dag.next_dagrun_info(data_interval, restricted=False)
1951 if next_info is None:
1952 return None
1953 return timezone.coerce_datetime(next_info.logical_date)
1955 def get_next_ds() -> str | None:
1956 execution_date = get_next_execution_date()
1957 if execution_date is None:
1958 return None
1959 return execution_date.strftime("%Y-%m-%d")
1961 def get_next_ds_nodash() -> str | None:
1962 ds = get_next_ds()
1963 if ds is None:
1964 return ds
1965 return ds.replace("-", "")
1967 @cache
1968 def get_prev_execution_date():
1969 # For manually triggered dagruns that aren't run on a schedule,
1970 # the "previous" execution date doesn't make sense, and should be set
1971 # to execution date for consistency with how execution_date is set
1972 # for manually triggered tasks, i.e. triggered_date == execution_date.
1973 if dag_run.external_trigger:
1974 return logical_date
1975 with warnings.catch_warnings():
1976 warnings.simplefilter("ignore", RemovedInAirflow3Warning)
1977 return dag.previous_schedule(logical_date)
1979 @cache
1980 def get_prev_ds() -> str | None:
1981 execution_date = get_prev_execution_date()
1982 if execution_date is None:
1983 return None
1984 return execution_date.strftime(r"%Y-%m-%d")
1986 def get_prev_ds_nodash() -> str | None:
1987 prev_ds = get_prev_ds()
1988 if prev_ds is None:
1989 return None
1990 return prev_ds.replace("-", "")
1992 def get_triggering_events() -> dict[str, list[DatasetEvent]]:
1993 if TYPE_CHECKING:
1994 assert session is not None
1996 # The dag_run may not be attached to the session anymore since the
1997 # code base is over-zealous with use of session.expunge_all().
1998 # Re-attach it if we get called.
1999 nonlocal dag_run
2000 if dag_run not in session:
2001 dag_run = session.merge(dag_run, load=False)
2003 dataset_events = dag_run.consumed_dataset_events
2004 triggering_events: dict[str, list[DatasetEvent]] = defaultdict(list)
2005 for event in dataset_events:
2006 triggering_events[event.dataset.uri].append(event)
2008 return triggering_events
2010 try:
2011 expanded_ti_count: int | None = task.get_mapped_ti_count(self.run_id, session=session)
2012 except NotMapped:
2013 expanded_ti_count = None
2015 # NOTE: If you add anything to this dict, make sure to also update the
2016 # definition in airflow/utils/context.pyi, and KNOWN_CONTEXT_KEYS in
2017 # airflow/utils/context.py!
2018 context = {
2019 "conf": conf,
2020 "dag": dag,
2021 "dag_run": dag_run,
2022 "data_interval_end": timezone.coerce_datetime(data_interval.end),
2023 "data_interval_start": timezone.coerce_datetime(data_interval.start),
2024 "ds": ds,
2025 "ds_nodash": ds_nodash,
2026 "execution_date": logical_date,
2027 "expanded_ti_count": expanded_ti_count,
2028 "inlets": task.inlets,
2029 "logical_date": logical_date,
2030 "macros": macros,
2031 "next_ds": get_next_ds(),
2032 "next_ds_nodash": get_next_ds_nodash(),
2033 "next_execution_date": get_next_execution_date(),
2034 "outlets": task.outlets,
2035 "params": validated_params,
2036 "prev_data_interval_start_success": get_prev_data_interval_start_success(),
2037 "prev_data_interval_end_success": get_prev_data_interval_end_success(),
2038 "prev_ds": get_prev_ds(),
2039 "prev_ds_nodash": get_prev_ds_nodash(),
2040 "prev_execution_date": get_prev_execution_date(),
2041 "prev_execution_date_success": self.get_previous_execution_date(
2042 state=DagRunState.SUCCESS,
2043 session=session,
2044 ),
2045 "prev_start_date_success": get_prev_start_date_success(),
2046 "run_id": self.run_id,
2047 "task": task,
2048 "task_instance": self,
2049 "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
2050 "test_mode": self.test_mode,
2051 "ti": self,
2052 "tomorrow_ds": get_tomorrow_ds(),
2053 "tomorrow_ds_nodash": get_tomorrow_ds_nodash(),
2054 "triggering_dataset_events": lazy_object_proxy.Proxy(get_triggering_events),
2055 "ts": ts,
2056 "ts_nodash": ts_nodash,
2057 "ts_nodash_with_tz": ts_nodash_with_tz,
2058 "var": {
2059 "json": VariableAccessor(deserialize_json=True),
2060 "value": VariableAccessor(deserialize_json=False),
2061 },
2062 "conn": ConnectionAccessor(),
2063 "yesterday_ds": get_yesterday_ds(),
2064 "yesterday_ds_nodash": get_yesterday_ds_nodash(),
2065 }
2066 # Mypy doesn't like turning existing dicts into a TypedDict -- and we "lie" in the type stub to say it
2067 # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
2068 return Context(context) # type: ignore
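# Usage sketch (not part of the original source; callable names are assumed):
# these keys are what templated fields and task callables see, e.g.
#     "{{ ds }} try {{ ti.try_number }}"
# in a Jinja-templated operator argument, or a PythonOperator callable that
# accepts the context entries as keyword arguments:
#     def my_callable(*, ds, ti, **context): ...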
2070 @provide_session
2071 def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None:
2072 """
2073 Update task with rendered template fields for presentation in UI.
2074 If task has already run, will fetch from DB; otherwise will render.
2075 """
2076 from airflow.models.renderedtifields import RenderedTaskInstanceFields
2078 rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
2079 if rendered_task_instance_fields:
2080 self.task = self.task.unmap(None)
2081 for field_name, rendered_value in rendered_task_instance_fields.items():
2082 setattr(self.task, field_name, rendered_value)
2083 return
2085 try:
2086 # If we get here, either the task hasn't run or the RTIF record was purged.
2087 from airflow.utils.log.secrets_masker import redact
2089 self.render_templates()
2090 for field_name in self.task.template_fields:
2091 rendered_value = getattr(self.task, field_name)
2092 setattr(self.task, field_name, redact(rendered_value, field_name))
2093 except (TemplateAssertionError, UndefinedError) as e:
2094 raise AirflowException(
2095 "Webserver does not have access to User-defined Macros or Filters "
2096 "when Dag Serialization is enabled. Hence for the task that have not yet "
2097 "started running, please use 'airflow tasks render' for debugging the "
2098 "rendering of template_fields."
2099 ) from e
2101 @provide_session
2102 def get_rendered_k8s_spec(self, session: Session = NEW_SESSION):
2103 """Fetch rendered template fields from DB"""
2104 from airflow.models.renderedtifields import RenderedTaskInstanceFields
2106 rendered_k8s_spec = RenderedTaskInstanceFields.get_k8s_pod_yaml(self, session=session)
2107 if not rendered_k8s_spec:
2108 try:
2109 rendered_k8s_spec = self.render_k8s_pod_yaml()
2110 except (TemplateAssertionError, UndefinedError) as e:
2111 raise AirflowException(f"Unable to render a k8s spec for this taskinstance: {e}") from e
2112 return rendered_k8s_spec
2114 def overwrite_params_with_dag_run_conf(self, params, dag_run):
2115 """Overwrite Task Params with DagRun.conf"""
2116 if dag_run and dag_run.conf:
2117 self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
2118 params.update(dag_run.conf)
2120 def render_templates(self, context: Context | None = None) -> Operator:
2121 """Render templates in the operator fields.
2123 If the task was originally mapped, this may replace ``self.task`` with
2124 the unmapped, fully rendered BaseOperator. The original ``self.task``
2125 before replacement is returned.
2126 """
2127 if not context:
2128 context = self.get_template_context()
2129 original_task = self.task
2131 # If self.task is mapped, this call replaces self.task to point to the
2132 # unmapped BaseOperator created by this function! This is because the
2133 # MappedOperator is useless for template rendering, and we need to be
2134 # able to access the unmapped task instead.
2135 original_task.render_template_fields(context)
2137 return original_task
2139 def render_k8s_pod_yaml(self) -> dict | None:
2140 """Render k8s pod yaml"""
2141 from kubernetes.client.api_client import ApiClient
2143 from airflow.kubernetes.kube_config import KubeConfig
2144 from airflow.kubernetes.kubernetes_helper_functions import create_pod_id # Circular import
2145 from airflow.kubernetes.pod_generator import PodGenerator
2147 kube_config = KubeConfig()
2148 pod = PodGenerator.construct_pod(
2149 dag_id=self.dag_id,
2150 run_id=self.run_id,
2151 task_id=self.task_id,
2152 map_index=self.map_index,
2153 date=None,
2154 pod_id=create_pod_id(self.dag_id, self.task_id),
2155 try_number=self.try_number,
2156 kube_image=kube_config.kube_image,
2157 args=self.command_as_list(),
2158 pod_override_object=PodGenerator.from_obj(self.executor_config),
2159 scheduler_job_id="0",
2160 namespace=kube_config.executor_namespace,
2161 base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file),
2162 with_mutation_hook=True,
2163 )
2164 sanitized_pod = ApiClient().sanitize_for_serialization(pod)
2165 return sanitized_pod
2167 def get_email_subject_content(
2168 self, exception: BaseException, task: BaseOperator | None = None
2169 ) -> tuple[str, str, str]:
2170 """Get the email subject content for exceptions."""
2171 # For a ti from DB (without ti.task), return the default value
2172 if task is None:
2173 task = getattr(self, "task")
2174 use_default = task is None
2175 exception_html = str(exception).replace("\n", "<br>")
2177 default_subject = "Airflow alert: {{ti}}"
2178 # For reporting purposes, we report based on 1-indexed,
2179 # not 0-indexed lists (i.e. Try 1 instead of
2180 # Try 0 for the first attempt).
2181 default_html_content = (
2182 "Try {{try_number}} out of {{max_tries + 1}}<br>"
2183 "Exception:<br>{{exception_html}}<br>"
2184 'Log: <a href="{{ti.log_url}}">Link</a><br>'
2185 "Host: {{ti.hostname}}<br>"
2186 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
2187 )
2189 default_html_content_err = (
2190 "Try {{try_number}} out of {{max_tries + 1}}<br>"
2191 "Exception:<br>Failed attempt to attach error logs<br>"
2192 'Log: <a href="{{ti.log_url}}">Link</a><br>'
2193 "Host: {{ti.hostname}}<br>"
2194 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
2195 )
2197 # This function is called after changing the state from State.RUNNING,
2198 # so we need to subtract 1 from self.try_number here.
2199 current_try_number = self.try_number - 1
2200 additional_context: dict[str, Any] = {
2201 "exception": exception,
2202 "exception_html": exception_html,
2203 "try_number": current_try_number,
2204 "max_tries": self.max_tries,
2205 }
2207 if use_default:
2208 default_context = {"ti": self, **additional_context}
2209 jinja_env = jinja2.Environment(
2210 loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True
2211 )
2212 subject = jinja_env.from_string(default_subject).render(**default_context)
2213 html_content = jinja_env.from_string(default_html_content).render(**default_context)
2214 html_content_err = jinja_env.from_string(default_html_content_err).render(**default_context)
2216 else:
2217 # Use the DAG's get_template_env() to set force_sandboxed. Don't add
2218 # the flag to the function on task object -- that function can be
2219 # overridden, and adding a flag breaks backward compatibility.
2220 dag = self.task.get_dag()
2221 if dag:
2222 jinja_env = dag.get_template_env(force_sandboxed=True)
2223 else:
2224 jinja_env = SandboxedEnvironment(cache_size=0)
2225 jinja_context = self.get_template_context()
2226 context_merge(jinja_context, additional_context)
2228 def render(key: str, content: str) -> str:
2229 if conf.has_option("email", key):
2230 path = conf.get_mandatory_value("email", key)
2231 try:
2232 with open(path) as f:
2233 content = f.read()
2234 except FileNotFoundError:
2235 self.log.warning(f"Could not find email template file {path!r}. Using defaults...")
2236 except OSError:
2237 self.log.exception(f"Error while using email template {path!r}. Using defaults...")
2238 return render_template_to_string(jinja_env.from_string(content), jinja_context)
2240 subject = render("subject_template", default_subject)
2241 html_content = render("html_content_template", default_html_content)
2242 html_content_err = render("html_content_template", default_html_content_err)
2244 return subject, html_content, html_content_err
2246 def email_alert(self, exception, task: BaseOperator) -> None:
2247 """Send alert email with exception information."""
2248 subject, html_content, html_content_err = self.get_email_subject_content(exception, task=task)
2249 assert task.email
2250 try:
2251 send_email(task.email, subject, html_content)
2252 except Exception:
2253 send_email(task.email, subject, html_content_err)
2255 def set_duration(self) -> None:
2256 """Set TI duration"""
2257 if self.end_date and self.start_date:
2258 self.duration = (self.end_date - self.start_date).total_seconds()
2259 else:
2260 self.duration = None
2261 self.log.debug("Task Duration set to %s", self.duration)
2263 def _record_task_map_for_downstreams(self, task: Operator, value: Any, *, session: Session) -> None:
2264 if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
2265 return
2266 # TODO: We don't push TaskMap for mapped task instances because it's not
2267 # currently possible for a downstream to depend on one individual mapped
2268 # task instance. This will change when we implement task mapping inside
2269 # a mapped task group, and we'll need to further analyze the case.
2270 if isinstance(task, MappedOperator):
2271 return
2272 if value is None:
2273 raise XComForMappingNotPushed()
2274 if not _is_mappable_value(value):
2275 raise UnmappableXComTypePushed(value)
2276 task_map = TaskMap.from_task_instance_xcom(self, value)
2277 max_map_length = conf.getint("core", "max_map_length", fallback=1024)
2278 if task_map.length > max_map_length:
2279 raise UnmappableXComLengthPushed(value, max_map_length)
2280 session.merge(task_map)
2282 @provide_session
2283 def xcom_push(
2284 self,
2285 key: str,
2286 value: Any,
2287 execution_date: datetime | None = None,
2288 session: Session = NEW_SESSION,
2289 ) -> None:
2290 """
2291 Make an XCom available for tasks to pull.
2293 :param key: Key to store the value under.
2294 :param value: Value to store. What types are possible depends on whether
2295 ``enable_xcom_pickling`` is true or not. If so, this can be any
2296 picklable object; only JSON-serializable values may be used otherwise.
2297 :param execution_date: Deprecated parameter that has no effect.
2298 """
2299 if execution_date is not None:
2300 self_execution_date = self.get_dagrun(session).execution_date
2301 if execution_date < self_execution_date:
2302 raise ValueError(
2303 f"execution_date can not be in the past (current execution_date is "
2304 f"{self_execution_date}; received {execution_date})"
2305 )
2306 elif execution_date is not None:
2307 message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated."
2308 warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
2310 XCom.set(
2311 key=key,
2312 value=value,
2313 task_id=self.task_id,
2314 dag_id=self.dag_id,
2315 run_id=self.run_id,
2316 map_index=self.map_index,
2317 session=session,
2318 )
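# Usage sketch (not part of the original source; key and task ids are assumed):
# inside a running task, push an extra value alongside the implicit
# return_value:
#
#     context["ti"].xcom_push(key="rows_processed", value=42)
#
# A downstream task can then read it back with
# context["ti"].xcom_pull(task_ids="load", key="rows_processed").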
2320 @provide_session
2321 def xcom_pull(
2322 self,
2323 task_ids: str | Iterable[str] | None = None,
2324 dag_id: str | None = None,
2325 key: str = XCOM_RETURN_KEY,
2326 include_prior_dates: bool = False,
2327 session: Session = NEW_SESSION,
2328 *,
2329 map_indexes: int | Iterable[int] | None = None,
2330 default: Any = None,
2331 ) -> Any:
2332 """Pull XComs that optionally meet certain criteria.
2334 :param key: A key for the XCom. If provided, only XComs with matching
2335 keys will be returned. The default key is ``'return_value'``, also
2336 available as constant ``XCOM_RETURN_KEY``. This key is automatically
2337 given to XComs returned by tasks (as opposed to being pushed
2338 manually). To remove the filter, pass *None*.
2339 :param task_ids: Only XComs from tasks with matching ids will be
2340 pulled. Pass *None* to remove the filter.
2341 :param dag_id: If provided, only pulls XComs from this DAG. If *None*
2342 (default), the DAG of the calling task is used.
2343 :param map_indexes: If provided, only pull XComs with matching indexes.
2344 If *None* (default), this is inferred from the task(s) being pulled
2345 (see below for details).
2346 :param include_prior_dates: If False, only XComs from the current
2347 execution_date are returned. If *True*, XComs from previous dates
2348 are returned as well.
2350 When pulling one single task (``task_ids`` is *None* or a str) without
2351 specifying ``map_indexes``, the return value is inferred from whether
2352 the specified task is mapped. If not, the value from that single task
2353 instance is returned. If the task to pull is mapped, an iterator (not a
2354 list) yielding XComs from mapped task instances is returned. In either
2355 case, ``default`` (*None* if not specified) is returned if no matching
2356 XComs are found.
2358 When pulling multiple tasks (i.e. either ``task_ids`` or ``map_indexes`` is
2359 a non-str iterable), a list of matching XComs is returned. Elements in
2360 the list are ordered by item ordering in ``task_ids`` and ``map_indexes``.
2361 """
2362 if dag_id is None:
2363 dag_id = self.dag_id
2365 query = XCom.get_many(
2366 key=key,
2367 run_id=self.run_id,
2368 dag_ids=dag_id,
2369 task_ids=task_ids,
2370 map_indexes=map_indexes,
2371 include_prior_dates=include_prior_dates,
2372 session=session,
2373 )
2375 # NOTE: Since we're only fetching the value field and not the whole
2376 # class, the @recreate annotation does not kick in. Therefore we need to
2377 # call XCom.deserialize_value() manually.
2379 # We are only pulling one single task.
2380 if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable):
2381 first = query.with_entities(
2382 XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value
2383 ).first()
2384 if first is None: # No matching XCom at all.
2385 return default
2386 if map_indexes is not None or first.map_index < 0:
2387 return XCom.deserialize_value(first)
2388 query = query.order_by(None).order_by(XCom.map_index.asc())
2389 return LazyXComAccess.build_from_xcom_query(query)
2391 # At this point either task_ids or map_indexes is explicitly multi-value.
2392 # Order return values to match task_ids and map_indexes ordering.
2393 query = query.order_by(None)
2394 if task_ids is None or isinstance(task_ids, str):
2395 query = query.order_by(XCom.task_id)
2396 else:
2397 task_id_whens = {tid: i for i, tid in enumerate(task_ids)}
2398 if task_id_whens:
2399 query = query.order_by(case(task_id_whens, value=XCom.task_id))
2400 else:
2401 query = query.order_by(XCom.task_id)
2402 if map_indexes is None or isinstance(map_indexes, int):
2403 query = query.order_by(XCom.map_index)
2404 elif isinstance(map_indexes, range):
2405 order = XCom.map_index
2406 if map_indexes.step < 0:
2407 order = order.desc()
2408 query = query.order_by(order)
2409 else:
2410 map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)}
2411 if map_index_whens:
2412 query = query.order_by(case(map_index_whens, value=XCom.map_index))
2413 else:
2414 query = query.order_by(XCom.map_index)
2415 return LazyXComAccess.build_from_xcom_query(query)
2417 @provide_session
2418 def get_num_running_task_instances(self, session: Session) -> int:
2419 """Return Number of running TIs from the DB"""
2420 # .count() is inefficient
2421 return (
2422 session.query(func.count())
2423 .filter(
2424 TaskInstance.dag_id == self.dag_id,
2425 TaskInstance.task_id == self.task_id,
2426 TaskInstance.state == State.RUNNING,
2427 )
2428 .scalar()
2429 )
2431 def init_run_context(self, raw: bool = False) -> None:
2432 """Sets the log context."""
2433 self.raw = raw
2434 self._set_context(self)
2436 @staticmethod
2437 def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None:
2438 """Returns SQLAlchemy filter to query selected task instances"""
2439 # The DictKeys type (what we often pass here from the scheduler) is not directly indexable :(
2440 # Or it might be a generator, but we need to be able to iterate over it more than once
2441 tis = list(tis)
2443 if not tis:
2444 return None
2446 first = tis[0]
2448 dag_id = first.dag_id
2449 run_id = first.run_id
2450 map_index = first.map_index
2451 first_task_id = first.task_id
2453 # pre-compute the set of dag_id, run_id, map_indices and task_ids
2454 dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set()
2455 for t in tis:
2456 dag_ids.add(t.dag_id)
2457 run_ids.add(t.run_id)
2458 map_indices.add(t.map_index)
2459 task_ids.add(t.task_id)
2461 # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id
2462 # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+)
2463 if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}:
2464 return and_(
2465 TaskInstance.dag_id == dag_id,
2466 TaskInstance.run_id == run_id,
2467 TaskInstance.map_index == map_index,
2468 TaskInstance.task_id.in_(task_ids),
2469 )
2470 if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}:
2471 return and_(
2472 TaskInstance.dag_id == dag_id,
2473 TaskInstance.run_id.in_(run_ids),
2474 TaskInstance.map_index == map_index,
2475 TaskInstance.task_id == first_task_id,
2476 )
2477 if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}:
2478 return and_(
2479 TaskInstance.dag_id == dag_id,
2480 TaskInstance.run_id == run_id,
2481 TaskInstance.map_index.in_(map_indices),
2482 TaskInstance.task_id == first_task_id,
2483 )
2485 filter_condition = []
2486 # create 2 nested groups, both primarily grouped by dag_id and run_id,
2487 # and in the nested group 1 grouped by task_id the other by map_index.
2488 task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
2489 map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
2490 for t in tis:
2491 task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index)
2492 map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id)
2494 # This assumes that most dags have dag_id as the largest grouping, followed by run_id. Even
2495 # if it's not, this is still a significant optimization over querying for every single tuple key
2496 for cur_dag_id in dag_ids:
2497 for cur_run_id in run_ids:
2498 # we compare the group size between task_id and map_index and use the smaller group
2499 dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
2500 dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
2502 if len(dag_task_id_groups) <= len(dag_map_index_groups):
2503 for cur_task_id, cur_map_indices in dag_task_id_groups.items():
2504 filter_condition.append(
2505 and_(
2506 TaskInstance.dag_id == cur_dag_id,
2507 TaskInstance.run_id == cur_run_id,
2508 TaskInstance.task_id == cur_task_id,
2509 TaskInstance.map_index.in_(cur_map_indices),
2510 )
2511 )
2512 else:
2513 for cur_map_index, cur_task_ids in dag_map_index_groups.items():
2514 filter_condition.append(
2515 and_(
2516 TaskInstance.dag_id == cur_dag_id,
2517 TaskInstance.run_id == cur_run_id,
2518 TaskInstance.task_id.in_(cur_task_ids),
2519 TaskInstance.map_index == cur_map_index,
2520 )
2521 )
2523 return or_(*filter_condition)
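# Illustrative example (assumed values, not from the original source): for TIs
# (dag_id="dag", run_id="run1", task_id="a", map_index=-1) and
# (dag_id="dag", run_id="run1", task_id="b", map_index=-1) the first fast path
# applies and the returned filter is equivalent to
#     dag_id == "dag" AND run_id == "run1" AND map_index == -1
#     AND task_id IN ("a", "b")
# instead of one OR'd clause per task instance.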
2525 @classmethod
2526 def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators:
2527 """
2528 Build an SQLAlchemy filter for a list where each element is either
2529 a task_id, or a tuple of (task_id, map_index)
2531 :meta private:
2532 """
2533 # Compute a filter for TI.task_id and TI.map_index based on input values
2534 # For each item, it will either be a task_id, or (task_id, map_index)
2535 task_id_only = [v for v in vals if isinstance(v, str)]
2536 with_map_index = [v for v in vals if not isinstance(v, str)]
2538 filters: list[ColumnOperators] = []
2539 if task_id_only:
2540 filters.append(cls.task_id.in_(task_id_only))
2541 if with_map_index:
2542 filters.append(tuple_in_condition((cls.task_id, cls.map_index), with_map_index))
2544 if not filters:
2545 return false()
2546 if len(filters) == 1:
2547 return filters[0]
2548 return or_(*filters)
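# Illustrative example (assumed values, not from the original source):
# vals=["t1", ("t2", 0), ("t2", 1)] produces a filter equivalent to
#     task_id IN ("t1") OR (task_id, map_index) IN (("t2", 0), ("t2", 1))
# while an empty vals collection yields false(), matching nothing.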
2550 @Sentry.enrich_errors
2551 @provide_session
2552 def schedule_downstream_tasks(self, session=None):
2553 """
2554 The mini-scheduler for scheduling downstream tasks of this task instance
2555 :meta private:
2556 """
2557 from sqlalchemy.exc import OperationalError
2559 from airflow.models import DagRun
2561 try:
2562 # Re-select the row with a lock
2563 dag_run = with_row_locks(
2564 session.query(DagRun).filter_by(
2565 dag_id=self.dag_id,
2566 run_id=self.run_id,
2567 ),
2568 session=session,
2569 ).one()
2571 task = self.task
2572 if TYPE_CHECKING:
2573 assert task.dag
2575 # Get a partial DAG with just the specific tasks we want to examine.
2576 # In order for dep checks to work correctly, we include ourself (so
2577 # TriggerRuleDep can check the state of the task we just executed).
2578 partial_dag = task.dag.partial_subset(
2579 task.downstream_task_ids,
2580 include_downstream=True,
2581 include_upstream=False,
2582 include_direct_upstream=True,
2583 )
2585 dag_run.dag = partial_dag
2586 info = dag_run.task_instance_scheduling_decisions(session)
2588 skippable_task_ids = {
2589 task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
2590 }
2592 schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
2593 for schedulable_ti in schedulable_tis:
2594 if not hasattr(schedulable_ti, "task"):
2595 schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
2597 num = dag_run.schedule_tis(schedulable_tis, session=session)
2598 self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
2600 session.flush()
2602 except OperationalError as e:
2603 # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
2604 self.log.info(
2605 "Skipping mini scheduling run due to exception: %s",
2606 e.statement,
2607 exc_info=True,
2608 )
2609 session.rollback()
2611 def get_relevant_upstream_map_indexes(
2612 self,
2613 upstream: Operator,
2614 ti_count: int | None,
2615 *,
2616 session: Session,
2617 ) -> int | range | None:
2618 """Infer the map indexes of an upstream "relevant" to this ti.
2620 The bulk of the logic mainly exists to solve the problem described by
2621 the following example, where 'val' must resolve to different values,
2622 depending on where the reference is being used::
2624 @task
2625 def this_task(v): # This is self.task.
2626 return v * 2
2628 @task_group
2629 def tg1(inp):
2630 val = upstream(inp) # This is the upstream task.
2631 this_task(val) # When inp is 1, val here should resolve to 2.
2632 return val
2634 # This val is the same object returned by tg1.
2635 val = tg1.expand(inp=[1, 2, 3])
2637 @task_group
2638 def tg2(inp):
2639 another_task(inp, val) # val here should resolve to [2, 4, 6].
2641 tg2.expand(inp=["a", "b"])
2643 The surrounding mapped task groups of ``upstream`` and ``self.task`` are
2644 inspected to find a common "ancestor". If such an ancestor is found,
2645 we need to return specific map indexes to pull a partial value from
2646 upstream XCom.
2648 :param upstream: The referenced upstream task.
2649 :param ti_count: The total number of task instances this task was expanded into
2650 by the scheduler, i.e. ``expanded_ti_count`` in the template context.
2651 :return: Specific map index or map indexes to pull, or ``None`` if we
2652 want the "whole" return value (i.e. no mapped task groups involved).
2653 """
2654 # Find the innermost common mapped task group between the current task
2655 # and the referenced upstream. If they do not have a common mapped task
2656 # group, the two are in different task mapping contexts (like
2657 # another_task above), and we should use the "whole" value.
2658 common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream)
2659 if common_ancestor is None:
2660 return None
2662 # This value should never be None since we already know the current task
2663 # is in a mapped task group, and should have been expanded. The check
2664 # exists mainly to satisfy Mypy.
2665 if ti_count is None:
2666 return None
2668 # At this point we know the two tasks share a mapped task group, and we
2669 # should use a "partial" value. Let's break down the mapped ti count
2670 # between the ancestor and any further expansion that happened inside it.
2671 ancestor_ti_count = common_ancestor.get_mapped_ti_count(self.run_id, session=session)
2672 ancestor_map_index = self.map_index * ancestor_ti_count // ti_count
2674 # If the task is NOT further expanded inside the common ancestor, we
2675 # only want to reference one single ti. We must walk the actual DAG,
2676 # and "ti_count == ancestor_ti_count" does not work, since the further
2677 # expansion may be of length 1.
2678 if not _is_further_mapped_inside(upstream, common_ancestor):
2679 return ancestor_map_index
2681 # Otherwise we need a partial aggregation for values from selected task
2682 # instances in the ancestor's expansion context.
2683 further_count = ti_count // ancestor_ti_count
2684 map_index_start = ancestor_map_index * further_count
2685 return range(map_index_start, map_index_start + further_count)
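# Worked example (assumed counts, not from the original source): if the common
# mapped task group expands to ancestor_ti_count=3 and this task is expanded
# again by 2 inside it (ti_count=6), then for self.map_index=4 the ancestor
# index is 4 * 3 // 6 == 2. An upstream that is not further mapped inside the
# group resolves to that single index 2; an upstream that is further mapped
# resolves to range(4, 6), i.e. a slice of map indexes within the same
# ancestor expansion.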
2688def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
2689 """Given two operators, find their innermost common mapped task group."""
2690 if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
2691 return None
2692 parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()}
2693 common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids)
2694 return next(common_groups, None)
2697def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
2698 """Whether given operator is *further* mapped inside a task group."""
2699 if isinstance(operator, MappedOperator):
2700 return True
2701 task_group = operator.task_group
2702 while task_group is not None and task_group.group_id != container.group_id:
2703 if isinstance(task_group, MappedTaskGroup):
2704 return True
2705 task_group = task_group.parent_group
2706 return False
2709# State of the task instance.
2710# Stores string version of the task state.
2711TaskInstanceStateType = Tuple[TaskInstanceKey, str]
2714class SimpleTaskInstance:
2715 """
2716 Simplified Task Instance.
2718 Used to send data between processes via Queues.
2719 """
2721 def __init__(
2722 self,
2723 dag_id: str,
2724 task_id: str,
2725 run_id: str,
2726 start_date: datetime | None,
2727 end_date: datetime | None,
2728 try_number: int,
2729 map_index: int,
2730 state: str,
2731 executor_config: Any,
2732 pool: str,
2733 queue: str,
2734 key: TaskInstanceKey,
2735 run_as_user: str | None = None,
2736 priority_weight: int | None = None,
2737 ):
2738 self.dag_id = dag_id
2739 self.task_id = task_id
2740 self.run_id = run_id
2741 self.map_index = map_index
2742 self.start_date = start_date
2743 self.end_date = end_date
2744 self.try_number = try_number
2745 self.state = state
2746 self.executor_config = executor_config
2747 self.run_as_user = run_as_user
2748 self.pool = pool
2749 self.priority_weight = priority_weight
2750 self.queue = queue
2751 self.key = key
2753 def __eq__(self, other):
2754 if isinstance(other, self.__class__):
2755 return self.__dict__ == other.__dict__
2756 return NotImplemented
2758 def as_dict(self):
2759 warnings.warn(
2760 "This method is deprecated. Use BaseSerialization.serialize.",
2761 RemovedInAirflow3Warning,
2762 stacklevel=2,
2763 )
2764 new_dict = dict(self.__dict__)
2765 for key in new_dict:
2766 if key in ["start_date", "end_date"]:
2767 val = new_dict[key]
2768 if not val or isinstance(val, str):
2769 continue
2770 new_dict.update({key: val.isoformat()})
2771 return new_dict
2773 @classmethod
2774 def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance:
2775 return cls(
2776 dag_id=ti.dag_id,
2777 task_id=ti.task_id,
2778 run_id=ti.run_id,
2779 map_index=ti.map_index,
2780 start_date=ti.start_date,
2781 end_date=ti.end_date,
2782 try_number=ti.try_number,
2783 state=ti.state,
2784 executor_config=ti.executor_config,
2785 pool=ti.pool,
2786 queue=ti.queue,
2787 key=ti.key,
2788 run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None,
2789 priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None,
2790 )
2792 @classmethod
2793 def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance:
2794 warnings.warn(
2795 "This method is deprecated. Use BaseSerialization.deserialize.",
2796 RemovedInAirflow3Warning,
2797 stacklevel=2,
2798 )
2799 ti_key = TaskInstanceKey(*obj_dict.pop("key"))
2800 start_date = None
2801 end_date = None
2802 start_date_str: str | None = obj_dict.pop("start_date")
2803 end_date_str: str | None = obj_dict.pop("end_date")
2804 if start_date_str:
2805 start_date = timezone.parse(start_date_str)
2806 if end_date_str:
2807 end_date = timezone.parse(end_date_str)
2808 return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key)
2811class TaskInstanceNote(Base):
2812 """For storage of arbitrary notes concerning the task instance."""
2814 __tablename__ = "task_instance_note"
2816 user_id = Column(Integer, nullable=True)
2817 task_id = Column(StringID(), primary_key=True, nullable=False)
2818 dag_id = Column(StringID(), primary_key=True, nullable=False)
2819 run_id = Column(StringID(), primary_key=True, nullable=False)
2820 map_index = Column(Integer, primary_key=True, nullable=False)
2821 content = Column(String(1000).with_variant(Text(1000), "mysql"))
2822 created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
2823 updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
2825 task_instance = relationship("TaskInstance", back_populates="task_instance_note")
2827 __table_args__ = (
2828 PrimaryKeyConstraint(
2829 "task_id", "dag_id", "run_id", "map_index", name="task_instance_note_pkey", mssql_clustered=True
2830 ),
2831 ForeignKeyConstraint(
2832 (dag_id, task_id, run_id, map_index),
2833 [
2834 "task_instance.dag_id",
2835 "task_instance.task_id",
2836 "task_instance.run_id",
2837 "task_instance.map_index",
2838 ],
2839 name="task_instance_note_ti_fkey",
2840 ondelete="CASCADE",
2841 ),
2842 ForeignKeyConstraint(
2843 (user_id,),
2844 ["ab_user.id"],
2845 name="task_instance_note_user_fkey",
2846 ),
2847 )
2849 def __init__(self, content, user_id=None):
2850 self.content = content
2851 self.user_id = user_id
2853 def __repr__(self):
2854 prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
2855 if self.map_index != -1:
2856 prefix += f" map_index={self.map_index}"
2857 return prefix + ">"
2860STATICA_HACK = True
2861globals()["kcah_acitats"[::-1].upper()] = False
2862if STATICA_HACK: # pragma: no cover
2863 from airflow.jobs.base_job import BaseJob
2865 TaskInstance.queued_by_job = relationship(BaseJob)