Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/taskinstance.py: 22%
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections.abc
21import contextlib
22import hashlib
23import itertools
24import logging
25import math
26import operator
27import os
28import signal
29import warnings
30from collections import defaultdict
31from contextlib import nullcontext
32from datetime import timedelta
33from enum import Enum
34from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple
35from urllib.parse import quote
37import dill
38import jinja2
39import lazy_object_proxy
40import pendulum
41from deprecated import deprecated
42from jinja2 import TemplateAssertionError, UndefinedError
43from sqlalchemy import (
44 Column,
45 DateTime,
46 Float,
47 ForeignKey,
48 ForeignKeyConstraint,
49 Index,
50 Integer,
51 PrimaryKeyConstraint,
52 String,
53 Text,
54 and_,
55 delete,
56 false,
57 func,
58 inspect,
59 or_,
60 text,
61 update,
62)
63from sqlalchemy.ext.associationproxy import association_proxy
64from sqlalchemy.ext.mutable import MutableDict
65from sqlalchemy.orm import lazyload, reconstructor, relationship
66from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
67from sqlalchemy.sql.expression import case, select
69from airflow import settings
70from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call
71from airflow.compat.functools import cache
72from airflow.configuration import conf
73from airflow.datasets import Dataset
74from airflow.datasets.manager import dataset_manager
75from airflow.exceptions import (
76 AirflowException,
77 AirflowFailException,
78 AirflowRescheduleException,
79 AirflowSensorTimeout,
80 AirflowSkipException,
81 AirflowTaskTerminated,
82 AirflowTaskTimeout,
83 DagRunNotFound,
84 RemovedInAirflow3Warning,
85 TaskDeferred,
86 UnmappableXComLengthPushed,
87 UnmappableXComTypePushed,
88 XComForMappingNotPushed,
89)
90from airflow.listeners.listener import get_listener_manager
91from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
92from airflow.models.dagbag import DagBag
93from airflow.models.log import Log
94from airflow.models.mappedoperator import MappedOperator
95from airflow.models.param import process_params
96from airflow.models.renderedtifields import get_serialized_template_fields
97from airflow.models.taskfail import TaskFail
98from airflow.models.taskinstancekey import TaskInstanceKey
99from airflow.models.taskmap import TaskMap
100from airflow.models.taskreschedule import TaskReschedule
101from airflow.models.xcom import LazyXComSelectSequence, XCom
102from airflow.plugins_manager import integrate_macros_plugins
103from airflow.sentry import Sentry
104from airflow.settings import task_instance_mutation_hook
105from airflow.stats import Stats
106from airflow.templates import SandboxedEnvironment
107from airflow.ti_deps.dep_context import DepContext
108from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
109from airflow.utils import timezone
110from airflow.utils.context import (
111 ConnectionAccessor,
112 Context,
113 InletEventsAccessors,
114 OutletEventAccessors,
115 VariableAccessor,
116 context_get_outlet_events,
117 context_merge,
118)
119from airflow.utils.email import send_email
120from airflow.utils.helpers import prune_dict, render_template_to_string
121from airflow.utils.log.logging_mixin import LoggingMixin
122from airflow.utils.net import get_hostname
123from airflow.utils.operator_helpers import ExecutionCallableRunner, context_to_airflow_vars
124from airflow.utils.platform import getuser
125from airflow.utils.retries import run_with_db_retries
126from airflow.utils.session import NEW_SESSION, create_session, provide_session
127from airflow.utils.sqlalchemy import (
128 ExecutorConfigType,
129 ExtendedJSON,
130 UtcDateTime,
131 tuple_in_condition,
132 with_row_locks,
133)
134from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState
135from airflow.utils.task_group import MappedTaskGroup
136from airflow.utils.task_instance_session import set_current_task_instance_session
137from airflow.utils.timeout import timeout
138from airflow.utils.xcom import XCOM_RETURN_KEY
140TR = TaskReschedule
142_CURRENT_CONTEXT: list[Context] = []
143log = logging.getLogger(__name__)
146if TYPE_CHECKING:
147 from datetime import datetime
148 from pathlib import PurePath
149 from types import TracebackType
151 from sqlalchemy.orm.session import Session
152 from sqlalchemy.sql.elements import BooleanClauseList
153 from sqlalchemy.sql.expression import ColumnOperators
155 from airflow.models.abstractoperator import TaskStateChangeCallback
156 from airflow.models.baseoperator import BaseOperator
157 from airflow.models.dag import DAG, DagModel
158 from airflow.models.dagrun import DagRun
159 from airflow.models.dataset import DatasetEvent
160 from airflow.models.operator import Operator
161 from airflow.serialization.pydantic.dag import DagModelPydantic
162 from airflow.serialization.pydantic.dataset import DatasetEventPydantic
163 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
164 from airflow.timetables.base import DataInterval
165 from airflow.typing_compat import Literal, TypeGuard
166 from airflow.utils.task_group import TaskGroup
168 # This is a workaround because mypy doesn't work with hybrid_property
169 # TODO: remove this hack and move hybrid_property back to main import block
170 # See https://github.com/python/mypy/issues/4430
171 hybrid_property = property
172else:
173 from sqlalchemy.ext.hybrid import hybrid_property
176PAST_DEPENDS_MET = "past_depends_met"
179class TaskReturnCode(Enum):
180 """
181 Enum to signal manner of exit for task run command.
183 :meta private:
184 """
186 DEFERRED = 100
187 """When task exits with deferral to trigger."""
190@contextlib.contextmanager
191def set_current_context(context: Context) -> Generator[Context, None, None]:
192 """
193 Set the current execution context to the provided context object.
195 This method should be called once per Task execution, before calling operator.execute.
196 """
197 _CURRENT_CONTEXT.append(context)
198 try:
199 yield context
200 finally:
201 expected_state = _CURRENT_CONTEXT.pop()
202 if expected_state != context:
203 log.warning(
204 "Current context is not equal to the state at context stack. Expected=%s, got=%s",
205 context,
206 expected_state,
207 )
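# Illustrative usage of the context manager above (not part of the original
# module): task execution is expected to be wrapped as
#
#     with set_current_context(context):
#         result = task.execute(context=context)
#
# so that helpers such as ``get_current_context()`` can find the active
# Context on ``_CURRENT_CONTEXT`` while the task runs.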
210def _stop_remaining_tasks(*, task_instance: TaskInstance | TaskInstancePydantic, session: Session):
211 """
212 Stop non-teardown tasks in dag.
214 :meta private:
215 """
216 if not task_instance.dag_run:
217 raise ValueError("``task_instance`` must have ``dag_run`` set")
218 tis = task_instance.dag_run.get_task_instances(session=session)
219 if TYPE_CHECKING:
220 assert task_instance.task
221 assert isinstance(task_instance.task.dag, DAG)
223 for ti in tis:
224 if ti.task_id == task_instance.task_id or ti.state in (
225 TaskInstanceState.SUCCESS,
226 TaskInstanceState.FAILED,
227 ):
228 continue
229 task = task_instance.task.dag.task_dict[ti.task_id]
230 if not task.is_teardown:
231 if ti.state == TaskInstanceState.RUNNING:
232 log.info("Forcing task %s to fail due to dag's `fail_stop` setting", ti.task_id)
233 ti.error(session)
234 else:
235 log.info("Setting task %s to SKIPPED due to dag's `fail_stop` setting.", ti.task_id)
236 ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
237 else:
238 log.info("Not skipping teardown task '%s'", ti.task_id)
241def clear_task_instances(
242 tis: list[TaskInstance],
243 session: Session,
244 activate_dag_runs: None = None,
245 dag: DAG | None = None,
246 dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
247) -> None:
248 """
249 Clear a set of task instances, but make sure the running ones get killed.
251 Also sets Dagrun's `state` to QUEUED and `start_date` to the time of execution.
252 But only for finished DRs (SUCCESS and FAILED).
253 Doesn't clear DR's `state` and `start_date` for running
254 DRs (QUEUED and RUNNING) because clearing the state for already
255 running DR is redundant and clearing `start_date` affects DR's duration.
257 :param tis: a list of task instances
258 :param session: current session
259 :param dag_run_state: state to set finished DagRuns to.
260 If set to False, DagRuns state will not be changed.
261 :param dag: DAG object
262 :param activate_dag_runs: Deprecated parameter, do not pass
263 """
264 job_ids = []
265 # Keys: dag_id -> run_id -> map_indexes -> try_numbers -> task_id
266 task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict(
267 lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
268 )
269 dag_bag = DagBag(read_dags_from_db=True)
270 for ti in tis:
271 if ti.state == TaskInstanceState.RUNNING:
272 if ti.job_id:
273 # If a task is cleared when running, set its state to RESTARTING so that
274 # the task is terminated and becomes eligible for retry.
275 ti.state = TaskInstanceState.RESTARTING
276 job_ids.append(ti.job_id)
277 else:
278 ti_dag = dag if dag and dag.dag_id == ti.dag_id else dag_bag.get_dag(ti.dag_id, session=session)
279 task_id = ti.task_id
280 if ti_dag and ti_dag.has_task(task_id):
281 task = ti_dag.get_task(task_id)
282 ti.refresh_from_task(task)
283 if TYPE_CHECKING:
284 assert ti.task
285 ti.max_tries = ti.try_number + task.retries
286 else:
287 # Ignore errors when updating max_tries if the DAG or
288 # task are not found since database records could be
289 # outdated. We make max_tries the maximum value of its
290 # original max_tries or the last attempted try number.
291 ti.max_tries = max(ti.max_tries, ti.try_number)
292 ti.state = None
293 ti.external_executor_id = None
294 ti.clear_next_method_args()
295 session.merge(ti)
297 task_id_by_key[ti.dag_id][ti.run_id][ti.map_index][ti.try_number].add(ti.task_id)
299 if task_id_by_key:
300 # Clear all reschedules related to the ti to clear
302 # This is an optimization for the common case where all tis are for a small number
303 # of dag_id, run_id, try_number, and map_index. Use a nested dict of dag_id,
304 # run_id, try_number, map_index, and task_id to construct the where clause in a
305 # hierarchical manner. This speeds up the delete statement by more than 40x for
306 # large number of tis (50k+).
307 conditions = or_(
308 and_(
309 TR.dag_id == dag_id,
310 or_(
311 and_(
312 TR.run_id == run_id,
313 or_(
314 and_(
315 TR.map_index == map_index,
316 or_(
317 and_(TR.try_number == try_number, TR.task_id.in_(task_ids))
318 for try_number, task_ids in task_tries.items()
319 ),
320 )
321 for map_index, task_tries in map_indexes.items()
322 ),
323 )
324 for run_id, map_indexes in run_ids.items()
325 ),
326 )
327 for dag_id, run_ids in task_id_by_key.items()
328 )
330 delete_qry = TR.__table__.delete().where(conditions)
331 session.execute(delete_qry)
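# For orientation, the nested expression above yields a WHERE clause of the
# rough shape (illustrative only):
#   (dag_id = :d AND (run_id = :r AND (map_index = :m AND
#       (try_number = :t AND task_id IN (:task_ids))))) OR ...
# i.e. one branch per (dag_id, run_id, map_index, try_number) group rather
# than one branch per individual task instance.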
333 if job_ids:
334 from airflow.jobs.job import Job
336 session.execute(update(Job).where(Job.id.in_(job_ids)).values(state=JobState.RESTARTING))
338 if activate_dag_runs is not None:
339 warnings.warn(
340 "`activate_dag_runs` parameter to clear_task_instances function is deprecated. "
341 "Please use `dag_run_state`",
342 RemovedInAirflow3Warning,
343 stacklevel=2,
344 )
345 if not activate_dag_runs:
346 dag_run_state = False
348 if dag_run_state is not False and tis:
349 from airflow.models.dagrun import DagRun # Avoid circular import
351 run_ids_by_dag_id = defaultdict(set)
352 for instance in tis:
353 run_ids_by_dag_id[instance.dag_id].add(instance.run_id)
355 drs = (
356 session.query(DagRun)
357 .filter(
358 or_(
359 and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
360 for dag_id, run_ids in run_ids_by_dag_id.items()
361 )
362 )
363 .all()
364 )
365 dag_run_state = DagRunState(dag_run_state) # Validate the state value.
366 for dr in drs:
367 if dr.state in State.finished_dr_states:
368 dr.state = dag_run_state
369 dr.start_date = timezone.utcnow()
370 if dag_run_state == DagRunState.QUEUED:
371 dr.last_scheduling_decision = None
372 dr.start_date = None
373 dr.clear_number += 1
374 session.flush()
377def _is_mappable_value(value: Any) -> TypeGuard[Collection]:
378 """Whether a value can be used for task mapping.
380 We only allow collections with guaranteed ordering, but exclude character
381 sequences since that's usually not what users would expect to be mappable.
382 """
383 if not isinstance(value, (collections.abc.Sequence, dict)):
384 return False
385 if isinstance(value, (bytearray, bytes, str)):
386 return False
387 return True
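# Illustrative results of the check above:
#   _is_mappable_value([1, 2, 3])  -> True   (ordered sequence)
#   _is_mappable_value({"a": 1})   -> True   (dict, insertion-ordered)
#   _is_mappable_value("abc")      -> False  (character sequences excluded)
#   _is_mappable_value({1, 2, 3})  -> False  (sets have no guaranteed ordering)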
390def _creator_note(val):
391 """Creator the ``note`` association proxy."""
392 if isinstance(val, str):
393 return TaskInstanceNote(content=val)
394 elif isinstance(val, dict):
395 return TaskInstanceNote(**val)
396 else:
397 return TaskInstanceNote(*val)
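# Illustrative use of the creator above via the ``note`` association proxy
# (hypothetical value): assigning ``ti.note = "needs manual review"`` builds
# ``TaskInstanceNote(content="needs manual review")`` behind the scenes, while
# a dict or tuple is unpacked into the TaskInstanceNote constructor instead.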
400def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: Context, task_orig: Operator):
401 """
402 Execute Task (optionally with a Timeout) and push Xcom results.
404 :param task_instance: the task instance
405 :param context: Jinja2 context
406 :param task_orig: origin task
408 :meta private:
409 """
410 task_to_execute = task_instance.task
412 if TYPE_CHECKING:
413 assert task_to_execute
415 if isinstance(task_to_execute, MappedOperator):
416 raise AirflowException("MappedOperator cannot be executed.")
418 # If the task has been deferred and is being executed due to a trigger,
419 # then we need to pick the right method to come back to, otherwise
420 # we go for the default execute
421 execute_callable_kwargs: dict[str, Any] = {}
422 execute_callable: Callable
423 if task_instance.next_method:
424 if task_instance.next_method == "execute":
425 if not task_instance.next_kwargs:
426 task_instance.next_kwargs = {}
427 task_instance.next_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel
428 execute_callable = task_to_execute.resume_execution
429 execute_callable_kwargs["next_method"] = task_instance.next_method
430 execute_callable_kwargs["next_kwargs"] = task_instance.next_kwargs
431 else:
432 execute_callable = task_to_execute.execute
433 if execute_callable.__name__ == "execute":
434 execute_callable_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel
436 def _execute_callable(context: Context, **execute_callable_kwargs):
437 try:
438 # Print a marker for log grouping of details before task execution
439 log.info("::endgroup::")
441 return ExecutionCallableRunner(
442 execute_callable,
443 context_get_outlet_events(context),
444 logger=log,
445 ).run(context=context, **execute_callable_kwargs)
446 except SystemExit as e:
447 # Handle only successful cases here. Failure cases will be handled further up
448 # in the exception chain.
449 if e.code is not None and e.code != 0:
450 raise
451 return None
452 finally:
453 # Print a marker post execution for internals of post task processing
454 log.info("::group::Post task execution logs")
456 # If a timeout is specified for the task, make it fail
457 # if it goes beyond
458 if task_to_execute.execution_timeout:
459 # If we are coming in with a next_method (i.e. from a deferral),
460 # calculate the timeout from our start_date.
461 if task_instance.next_method and task_instance.start_date:
462 timeout_seconds = (
463 task_to_execute.execution_timeout - (timezone.utcnow() - task_instance.start_date)
464 ).total_seconds()
465 else:
466 timeout_seconds = task_to_execute.execution_timeout.total_seconds()
467 try:
468 # It's possible we're already timed out, so fast-fail if true
469 if timeout_seconds <= 0:
470 raise AirflowTaskTimeout()
471 # Run task in timeout wrapper
472 with timeout(timeout_seconds):
473 result = _execute_callable(context=context, **execute_callable_kwargs)
474 except AirflowTaskTimeout:
475 task_to_execute.on_kill()
476 raise
477 else:
478 result = _execute_callable(context=context, **execute_callable_kwargs)
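# Worked example of the deferral-aware timeout above (hypothetical values):
# with execution_timeout=timedelta(minutes=10) and a deferred task resuming
# 7 minutes after start_date, the remaining budget is
#   (timedelta(minutes=10) - timedelta(minutes=7)).total_seconds() == 180.0
# so the resumed attempt gets only the 3 minutes left, not a fresh window.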
479 cm = nullcontext() if InternalApiConfig.get_use_internal_api() else create_session()
480 with cm as session_or_null:
481 if task_to_execute.do_xcom_push:
482 xcom_value = result
483 else:
484 xcom_value = None
485 if xcom_value is not None: # If the task returns a result, push an XCom containing it.
486 if task_to_execute.multiple_outputs:
487 if not isinstance(xcom_value, Mapping):
488 raise AirflowException(
489 f"Returned output was type {type(xcom_value)} "
490 "expected dictionary for multiple_outputs"
491 )
492 for key in xcom_value.keys():
493 if not isinstance(key, str):
494 raise AirflowException(
495 "Returned dictionary keys must be strings when using "
496 f"multiple_outputs, found {key} ({type(key)}) instead"
497 )
498 for key, value in xcom_value.items():
499 task_instance.xcom_push(key=key, value=value, session=session_or_null)
500 task_instance.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session_or_null)
501 _record_task_map_for_downstreams(
502 task_instance=task_instance, task=task_orig, value=xcom_value, session=session_or_null
503 )
504 return result
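# Illustrative XCom behaviour of the block above (hypothetical task): with
# multiple_outputs=True and a return value of {"rows": 10, "path": "/tmp/x"},
# three XComs are pushed -- "rows", "path", and XCOM_RETURN_KEY holding the
# whole dict -- whereas a plain task pushes only the XCOM_RETURN_KEY entry.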
507def _refresh_from_db(
508 *,
509 task_instance: TaskInstance | TaskInstancePydantic,
510 session: Session | None = None,
511 lock_for_update: bool = False,
512) -> None:
513 """
514 Refresh the task instance from the database based on the primary key.
516 :param task_instance: the task instance
517 :param session: SQLAlchemy ORM Session
518 :param lock_for_update: if True, indicates that the database should
519 lock the TaskInstance (issuing a FOR UPDATE clause) until the
520 session is committed.
522 :meta private:
523 """
524 if session and task_instance in session:
525 session.refresh(task_instance, TaskInstance.__mapper__.column_attrs.keys())
527 ti = TaskInstance.get_task_instance(
528 dag_id=task_instance.dag_id,
529 task_id=task_instance.task_id,
530 run_id=task_instance.run_id,
531 map_index=task_instance.map_index,
532 lock_for_update=lock_for_update,
533 session=session,
534 )
536 if ti:
537 # Fields ordered per model definition
538 task_instance.start_date = ti.start_date
539 task_instance.end_date = ti.end_date
540 task_instance.duration = ti.duration
541 task_instance.state = ti.state
542 task_instance.try_number = ti.try_number
543 task_instance.max_tries = ti.max_tries
544 task_instance.hostname = ti.hostname
545 task_instance.unixname = ti.unixname
546 task_instance.job_id = ti.job_id
547 task_instance.pool = ti.pool
548 task_instance.pool_slots = ti.pool_slots or 1
549 task_instance.queue = ti.queue
550 task_instance.priority_weight = ti.priority_weight
551 task_instance.operator = ti.operator
552 task_instance.custom_operator_name = ti.custom_operator_name
553 task_instance.queued_dttm = ti.queued_dttm
554 task_instance.queued_by_job_id = ti.queued_by_job_id
555 task_instance.pid = ti.pid
556 task_instance.executor = ti.executor
557 task_instance.executor_config = ti.executor_config
558 task_instance.external_executor_id = ti.external_executor_id
559 task_instance.trigger_id = ti.trigger_id
560 task_instance.next_method = ti.next_method
561 task_instance.next_kwargs = ti.next_kwargs
562 else:
563 task_instance.state = None
566def _set_duration(*, task_instance: TaskInstance | TaskInstancePydantic) -> None:
567 """
568 Set task instance duration.
570 :param task_instance: the task instance
572 :meta private:
573 """
574 if task_instance.end_date and task_instance.start_date:
575 task_instance.duration = (task_instance.end_date - task_instance.start_date).total_seconds()
576 else:
577 task_instance.duration = None
578 log.debug("Task Duration set to %s", task_instance.duration)
581def _stats_tags(*, task_instance: TaskInstance | TaskInstancePydantic) -> dict[str, str]:
582 """
583 Return task instance tags.
585 :param task_instance: the task instance
587 :meta private:
588 """
589 return prune_dict({"dag_id": task_instance.dag_id, "task_id": task_instance.task_id})
592def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydantic) -> None:
593 """
594 Unset next_method and next_kwargs so that any retries don't reuse them.
596 :param task_instance: the task instance
598 :meta private:
599 """
600 log.debug("Clearing next_method and next_kwargs.")
602 task_instance.next_method = None
603 task_instance.next_kwargs = None
606@internal_api_call
607def _get_template_context(
608 *,
609 task_instance: TaskInstance | TaskInstancePydantic,
610 session: Session | None = None,
611 ignore_param_exceptions: bool = True,
612) -> Context:
613 """
614 Return TI Context.
616 :param task_instance: the task instance
617 :param session: SQLAlchemy ORM Session
618 :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
620 :meta private:
621 """
622 # Do not use provide_session here -- it expunges everything on exit!
623 if not session:
624 session = settings.Session()
626 from airflow import macros
627 from airflow.models.abstractoperator import NotMapped
629 integrate_macros_plugins()
631 task = task_instance.task
632 if TYPE_CHECKING:
633 assert task_instance.task
634 assert task
635 assert task.dag
636 try:
637 dag: DAG = task.dag
638 except AirflowException:
639 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
641 if isinstance(task_instance, TaskInstancePydantic):
642 ti = session.scalar(
643 select(TaskInstance).where(
644 TaskInstance.task_id == task_instance.task_id,
645 TaskInstance.dag_id == task_instance.dag_id,
646 TaskInstance.run_id == task_instance.run_id,
647 TaskInstance.map_index == task_instance.map_index,
648 )
649 )
650 dag = ti.dag_model.serialized_dag.dag
651 if hasattr(task_instance.task, "_dag"): # BaseOperator
652 task_instance.task._dag = dag
653 else: # MappedOperator
654 task_instance.task.dag = dag
655 else:
656 raise
657 dag_run = task_instance.get_dagrun(session)
658 data_interval = dag.get_run_data_interval(dag_run)
660 validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)
662 logical_date: DateTime = timezone.coerce_datetime(task_instance.execution_date)
663 ds = logical_date.strftime("%Y-%m-%d")
664 ds_nodash = ds.replace("-", "")
665 ts = logical_date.isoformat()
666 ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
667 ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
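# Illustrative values, assuming logical_date == 2024-01-02T03:04:05+00:00:
#   ds                -> "2024-01-02"
#   ds_nodash         -> "20240102"
#   ts                -> "2024-01-02T03:04:05+00:00"
#   ts_nodash         -> "20240102T030405"
#   ts_nodash_with_tz -> "20240102T030405+0000"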
669 @cache # Prevent multiple database access.
670 def _get_previous_dagrun_success() -> DagRun | None:
671 return task_instance.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)
673 def _get_previous_dagrun_data_interval_success() -> DataInterval | None:
674 dagrun = _get_previous_dagrun_success()
675 if dagrun is None:
676 return None
677 return dag.get_run_data_interval(dagrun)
679 def get_prev_data_interval_start_success() -> pendulum.DateTime | None:
680 data_interval = _get_previous_dagrun_data_interval_success()
681 if data_interval is None:
682 return None
683 return data_interval.start
685 def get_prev_data_interval_end_success() -> pendulum.DateTime | None:
686 data_interval = _get_previous_dagrun_data_interval_success()
687 if data_interval is None:
688 return None
689 return data_interval.end
691 def get_prev_start_date_success() -> pendulum.DateTime | None:
692 dagrun = _get_previous_dagrun_success()
693 if dagrun is None:
694 return None
695 return timezone.coerce_datetime(dagrun.start_date)
697 def get_prev_end_date_success() -> pendulum.DateTime | None:
698 dagrun = _get_previous_dagrun_success()
699 if dagrun is None:
700 return None
701 return timezone.coerce_datetime(dagrun.end_date)
703 @cache
704 def get_yesterday_ds() -> str:
705 return (logical_date - timedelta(1)).strftime("%Y-%m-%d")
707 def get_yesterday_ds_nodash() -> str:
708 return get_yesterday_ds().replace("-", "")
710 @cache
711 def get_tomorrow_ds() -> str:
712 return (logical_date + timedelta(1)).strftime("%Y-%m-%d")
714 def get_tomorrow_ds_nodash() -> str:
715 return get_tomorrow_ds().replace("-", "")
717 @cache
718 def get_next_execution_date() -> pendulum.DateTime | None:
719 # For manually triggered dagruns that aren't run on a schedule,
720 # the "next" execution date doesn't make sense, and should be set
721 # to execution date for consistency with how execution_date is set
722 # for manually triggered tasks, i.e. triggered_date == execution_date.
723 if dag_run.external_trigger:
724 return logical_date
725 if dag is None:
726 return None
727 next_info = dag.next_dagrun_info(data_interval, restricted=False)
728 if next_info is None:
729 return None
730 return timezone.coerce_datetime(next_info.logical_date)
732 def get_next_ds() -> str | None:
733 execution_date = get_next_execution_date()
734 if execution_date is None:
735 return None
736 return execution_date.strftime("%Y-%m-%d")
738 def get_next_ds_nodash() -> str | None:
739 ds = get_next_ds()
740 if ds is None:
741 return ds
742 return ds.replace("-", "")
744 @cache
745 def get_prev_execution_date():
746 # For manually triggered dagruns that aren't run on a schedule,
747 # the "previous" execution date doesn't make sense, and should be set
748 # to execution date for consistency with how execution_date is set
749 # for manually triggered tasks, i.e. triggered_date == execution_date.
750 if dag_run.external_trigger:
751 return logical_date
752 with warnings.catch_warnings():
753 warnings.simplefilter("ignore", RemovedInAirflow3Warning)
754 return dag.previous_schedule(logical_date)
756 @cache
757 def get_prev_ds() -> str | None:
758 execution_date = get_prev_execution_date()
759 if execution_date is None:
760 return None
761 return execution_date.strftime("%Y-%m-%d")
763 def get_prev_ds_nodash() -> str | None:
764 prev_ds = get_prev_ds()
765 if prev_ds is None:
766 return None
767 return prev_ds.replace("-", "")
769 def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydantic]]:
770 if TYPE_CHECKING:
771 assert session is not None
773 # The dag_run may not be attached to the session anymore since the
774 # code base is over-zealous with use of session.expunge_all().
775 # Re-attach it if we get called.
776 nonlocal dag_run
777 if dag_run not in session:
778 dag_run = session.merge(dag_run, load=False)
779 dataset_events = dag_run.consumed_dataset_events
780 triggering_events: dict[str, list[DatasetEvent | DatasetEventPydantic]] = defaultdict(list)
781 for event in dataset_events:
782 if event.dataset:
783 triggering_events[event.dataset.uri].append(event)
785 return triggering_events
787 try:
788 expanded_ti_count: int | None = task.get_mapped_ti_count(task_instance.run_id, session=session)
789 except NotMapped:
790 expanded_ti_count = None
792 # NOTE: If you add to this dict, make sure to also update the following:
793 # * Context in airflow/utils/context.pyi
794 # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
795 # * Table in docs/apache-airflow/templates-ref.rst
796 context: dict[str, Any] = {
797 "conf": conf,
798 "dag": dag,
799 "dag_run": dag_run,
800 "data_interval_end": timezone.coerce_datetime(data_interval.end),
801 "data_interval_start": timezone.coerce_datetime(data_interval.start),
802 "outlet_events": OutletEventAccessors(),
803 "ds": ds,
804 "ds_nodash": ds_nodash,
805 "execution_date": logical_date,
806 "expanded_ti_count": expanded_ti_count,
807 "inlets": task.inlets,
808 "inlet_events": InletEventsAccessors(task.inlets, session=session),
809 "logical_date": logical_date,
810 "macros": macros,
811 "map_index_template": task.map_index_template,
812 "next_ds": get_next_ds(),
813 "next_ds_nodash": get_next_ds_nodash(),
814 "next_execution_date": get_next_execution_date(),
815 "outlets": task.outlets,
816 "params": validated_params,
817 "prev_data_interval_start_success": get_prev_data_interval_start_success(),
818 "prev_data_interval_end_success": get_prev_data_interval_end_success(),
819 "prev_ds": get_prev_ds(),
820 "prev_ds_nodash": get_prev_ds_nodash(),
821 "prev_execution_date": get_prev_execution_date(),
822 "prev_execution_date_success": task_instance.get_previous_execution_date(
823 state=DagRunState.SUCCESS,
824 session=session,
825 ),
826 "prev_start_date_success": get_prev_start_date_success(),
827 "prev_end_date_success": get_prev_end_date_success(),
828 "run_id": task_instance.run_id,
829 "task": task,
830 "task_instance": task_instance,
831 "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
832 "test_mode": task_instance.test_mode,
833 "ti": task_instance,
834 "tomorrow_ds": get_tomorrow_ds(),
835 "tomorrow_ds_nodash": get_tomorrow_ds_nodash(),
836 "triggering_dataset_events": lazy_object_proxy.Proxy(get_triggering_events),
837 "ts": ts,
838 "ts_nodash": ts_nodash,
839 "ts_nodash_with_tz": ts_nodash_with_tz,
840 "var": {
841 "json": VariableAccessor(deserialize_json=True),
842 "value": VariableAccessor(deserialize_json=False),
843 },
844 "conn": ConnectionAccessor(),
845 "yesterday_ds": get_yesterday_ds(),
846 "yesterday_ds_nodash": get_yesterday_ds_nodash(),
847 }
848 # Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it
849 # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
850 return Context(context) # type: ignore
853def _is_eligible_to_retry(*, task_instance: TaskInstance | TaskInstancePydantic):
854 """
855 Whether the task instance is eligible for retry.
857 :param task_instance: the task instance
859 :meta private:
860 """
861 if task_instance.state == TaskInstanceState.RESTARTING:
862 # If a task is cleared when running, it goes into RESTARTING state and is always
863 # eligible for retry
864 return True
865 if not getattr(task_instance, "task", None):
866 # Couldn't load the task, don't know number of retries, guess:
867 return task_instance.try_number <= task_instance.max_tries
869 if TYPE_CHECKING:
870 assert task_instance.task
872 return task_instance.task.retries and task_instance.try_number <= task_instance.max_tries
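# Illustrative example of the check above, assuming the task has never been
# cleared: a task declared with retries=2 gets max_tries=2, so a failure on
# try_number 1 or 2 is eligible for retry here, while a failure on try_number 3
# (the final attempt) is not.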
875def _handle_failure(
876 *,
877 task_instance: TaskInstance | TaskInstancePydantic,
878 error: None | str | BaseException,
879 session: Session,
880 test_mode: bool | None = None,
881 context: Context | None = None,
882 force_fail: bool = False,
883 fail_stop: bool = False,
884) -> None:
885 """
886 Handle Failure for a task instance.
888 :param task_instance: the task instance
889 :param error: if specified, log the specific exception if thrown
890 :param session: SQLAlchemy ORM Session
891 :param test_mode: doesn't record success or failure in the DB if True
892 :param context: Jinja2 context
893 :param force_fail: if True, task does not retry
895 :meta private:
896 """
897 if test_mode is None:
898 test_mode = task_instance.test_mode
900 failure_context = TaskInstance.fetch_handle_failure_context(
901 ti=task_instance,
902 error=error,
903 test_mode=test_mode,
904 context=context,
905 force_fail=force_fail,
906 session=session,
907 fail_stop=fail_stop,
908 )
910 _log_state(task_instance=task_instance, lead_msg="Immediate failure requested. " if force_fail else "")
911 if (
912 failure_context["task"]
913 and failure_context["email_for_state"](failure_context["task"])
914 and failure_context["task"].email
915 ):
916 try:
917 task_instance.email_alert(error, failure_context["task"])
918 except Exception:
919 log.exception("Failed to send email to: %s", failure_context["task"].email)
921 if failure_context["callbacks"] and failure_context["context"]:
922 _run_finished_callback(
923 callbacks=failure_context["callbacks"],
924 context=failure_context["context"],
925 )
927 if not test_mode:
928 TaskInstance.save_to_db(failure_context["ti"], session)
931def _refresh_from_task(
932 *, task_instance: TaskInstance | TaskInstancePydantic, task: Operator, pool_override: str | None = None
933) -> None:
934 """
935 Copy common attributes from the given task.
937 :param task_instance: the task instance
938 :param task: The task object to copy from
939 :param pool_override: Use the pool_override instead of task's pool
941 :meta private:
942 """
943 task_instance.task = task
944 task_instance.queue = task.queue
945 task_instance.pool = pool_override or task.pool
946 task_instance.pool_slots = task.pool_slots
947 with contextlib.suppress(Exception):
948 # This method is called from the different places, and sometimes the TI is not fully initialized
949 task_instance.priority_weight = task_instance.task.weight_rule.get_weight(
950 task_instance # type: ignore[arg-type]
951 )
952 task_instance.run_as_user = task.run_as_user
953 # Do not set max_tries to task.retries here because max_tries is a cumulative
954 # value that needs to be stored in the db.
955 task_instance.executor = task.executor
956 task_instance.executor_config = task.executor_config
957 task_instance.operator = task.task_type
958 task_instance.custom_operator_name = getattr(task, "custom_operator_name", None)
959 # Re-apply cluster policy here so that task defaults do not override previous data
960 task_instance_mutation_hook(task_instance)
963def _record_task_map_for_downstreams(
964 *, task_instance: TaskInstance | TaskInstancePydantic, task: Operator, value: Any, session: Session
965) -> None:
966 """
967 Record the task map for downstream tasks.
969 :param task_instance: the task instance
970 :param task: The task object
971 :param value: The value
972 :param session: SQLAlchemy ORM Session
974 :meta private:
975 """
976 if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
977 return
978 # TODO: We don't push TaskMap for mapped task instances because it's not
979 # currently possible for a downstream to depend on one individual mapped
980 # task instance. This will change when we implement task mapping inside
981 # a mapped task group, and we'll need to further analyze the case.
982 if isinstance(task, MappedOperator):
983 return
984 if value is None:
985 raise XComForMappingNotPushed()
986 if not _is_mappable_value(value):
987 raise UnmappableXComTypePushed(value)
988 task_map = TaskMap.from_task_instance_xcom(task_instance, value)
989 max_map_length = conf.getint("core", "max_map_length", fallback=1024)
990 if task_map.length > max_map_length:
991 raise UnmappableXComLengthPushed(value, max_map_length)
992 session.merge(task_map)
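# Illustrative effect of the guard above: with the default max_map_length of
# 1024, a pushed XCom that would expand into 1025 mapped entries raises
# UnmappableXComLengthPushed instead of recording the TaskMap row.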
995def _get_previous_dagrun(
996 *,
997 task_instance: TaskInstance | TaskInstancePydantic,
998 state: DagRunState | None = None,
999 session: Session | None = None,
1000) -> DagRun | None:
1001 """
1002 Return the DagRun that ran prior to this task instance's DagRun.
1004 :param task_instance: the task instance
1005 :param state: If passed, it only takes into account instances of a specific state.
1006 :param session: SQLAlchemy ORM Session.
1008 :meta private:
1009 """
1010 if TYPE_CHECKING:
1011 assert task_instance.task
1013 dag = task_instance.task.dag
1014 if dag is None:
1015 return None
1017 dr = task_instance.get_dagrun(session=session)
1018 dr.dag = dag
1020 from airflow.models.dagrun import DagRun # Avoid circular import
1022 # We always ignore schedule in dagrun lookup when `state` is given
1023 # or the DAG is never scheduled. For legacy reasons, when
1024 # `catchup=True`, we use `get_previous_scheduled_dagrun` unless
1025 # `ignore_schedule` is `True`.
1026 ignore_schedule = state is not None or not dag.timetable.can_be_scheduled
1027 if dag.catchup is True and not ignore_schedule:
1028 last_dagrun = DagRun.get_previous_scheduled_dagrun(dr.id, session=session)
1029 else:
1030 last_dagrun = DagRun.get_previous_dagrun(dag_run=dr, session=session, state=state)
1032 if last_dagrun:
1033 return last_dagrun
1035 return None
1038def _get_previous_execution_date(
1039 *,
1040 task_instance: TaskInstance | TaskInstancePydantic,
1041 state: DagRunState | None,
1042 session: Session,
1043) -> pendulum.DateTime | None:
1044 """
1045 Get execution date from property previous_ti_success.
1047 :param task_instance: the task instance
1048 :param session: SQLAlchemy ORM Session
1049 :param state: If passed, it only takes into account instances of a specific state.
1051 :meta private:
1052 """
1053 log.debug("previous_execution_date was called")
1054 prev_ti = task_instance.get_previous_ti(state=state, session=session)
1055 return pendulum.instance(prev_ti.execution_date) if prev_ti and prev_ti.execution_date else None
1058def _email_alert(
1059 *, task_instance: TaskInstance | TaskInstancePydantic, exception, task: BaseOperator
1060) -> None:
1061 """
1062 Send alert email with exception information.
1064 :param task_instance: the task instance
1065 :param exception: the exception
1066 :param task: task related to the exception
1068 :meta private:
1069 """
1070 subject, html_content, html_content_err = task_instance.get_email_subject_content(exception, task=task)
1071 if TYPE_CHECKING:
1072 assert task.email
1073 try:
1074 send_email(task.email, subject, html_content)
1075 except Exception:
1076 send_email(task.email, subject, html_content_err)
1079def _get_email_subject_content(
1080 *,
1081 task_instance: TaskInstance | TaskInstancePydantic,
1082 exception: BaseException,
1083 task: BaseOperator | None = None,
1084) -> tuple[str, str, str]:
1085 """
1086 Get the email subject content for exceptions.
1088 :param task_instance: the task instance
1089 :param exception: the exception sent in the email
1090 :param task: the task related to the exception
1092 :meta private:
1093 """
1094 # For a ti from DB (without ti.task), return the default value
1095 if task is None:
1096 task = getattr(task_instance, "task")
1097 use_default = task is None
1098 exception_html = str(exception).replace("\n", "<br>")
1100 default_subject = "Airflow alert: {{ti}}"
1101 # For reporting purposes, we report based on 1-indexed,
1102 # not 0-indexed lists (i.e. Try 1 instead of
1103 # Try 0 for the first attempt).
1104 default_html_content = (
1105 "Try {{try_number}} out of {{max_tries + 1}}<br>"
1106 "Exception:<br>{{exception_html}}<br>"
1107 'Log: <a href="{{ti.log_url}}">Link</a><br>'
1108 "Host: {{ti.hostname}}<br>"
1109 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
1110 )
1112 default_html_content_err = (
1113 "Try {{try_number}} out of {{max_tries + 1}}<br>"
1114 "Exception:<br>Failed attempt to attach error logs<br>"
1115 'Log: <a href="{{ti.log_url}}">Link</a><br>'
1116 "Host: {{ti.hostname}}<br>"
1117 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
1118 )
1120 additional_context: dict[str, Any] = {
1121 "exception": exception,
1122 "exception_html": exception_html,
1123 "try_number": task_instance.try_number,
1124 "max_tries": task_instance.max_tries,
1125 }
1127 if use_default:
1128 default_context = {"ti": task_instance, **additional_context}
1129 jinja_env = jinja2.Environment(
1130 loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True
1131 )
1132 subject = jinja_env.from_string(default_subject).render(**default_context)
1133 html_content = jinja_env.from_string(default_html_content).render(**default_context)
1134 html_content_err = jinja_env.from_string(default_html_content_err).render(**default_context)
1136 else:
1137 if TYPE_CHECKING:
1138 assert task_instance.task
1140 # Use the DAG's get_template_env() to set force_sandboxed. Don't add
1141 # the flag to the function on task object -- that function can be
1142 # overridden, and adding a flag breaks backward compatibility.
1143 dag = task_instance.task.get_dag()
1144 if dag:
1145 jinja_env = dag.get_template_env(force_sandboxed=True)
1146 else:
1147 jinja_env = SandboxedEnvironment(cache_size=0)
1148 jinja_context = task_instance.get_template_context()
1149 context_merge(jinja_context, additional_context)
1151 def render(key: str, content: str) -> str:
1152 if conf.has_option("email", key):
1153 path = conf.get_mandatory_value("email", key)
1154 try:
1155 with open(path) as f:
1156 content = f.read()
1157 except FileNotFoundError:
1158 log.warning("Could not find email template file '%s'. Using defaults...", path)
1159 except OSError:
1160 log.exception("Error while using email template %s. Using defaults...", path)
1161 return render_template_to_string(jinja_env.from_string(content), jinja_context)
1163 subject = render("subject_template", default_subject)
1164 html_content = render("html_content_template", default_html_content)
1165 html_content_err = render("html_content_template", default_html_content_err)
1167 return subject, html_content, html_content_err
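# Illustrative configuration consumed by the ``render`` helper above
# (hypothetical paths): pointing these options at template files makes Airflow
# read the alert subject/body from disk instead of using the defaults.
#
#   [email]
#   subject_template = /opt/airflow/templates/email_subject.j2
#   html_content_template = /opt/airflow/templates/email_body.j2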
1170def _run_finished_callback(
1171 *,
1172 callbacks: None | TaskStateChangeCallback | list[TaskStateChangeCallback],
1173 context: Context,
1174) -> None:
1175 """
1176 Run callback after task finishes.
1178 :param callbacks: callbacks to run
1179 :param context: callbacks context
1181 :meta private:
1182 """
1183 if callbacks:
1184 callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
1185 for callback in callbacks:
1186 log.info("Executing %s callback", callback.__name__)
1187 try:
1188 callback(context)
1189 except Exception:
1190 log.exception("Error when executing %s callback", callback.__name__) # type: ignore[attr-defined]
1193def _log_state(*, task_instance: TaskInstance | TaskInstancePydantic, lead_msg: str = "") -> None:
1194 """
1195 Log task state.
1197 :param task_instance: the task instance
1198 :param lead_msg: lead message
1200 :meta private:
1201 """
1202 params = [
1203 lead_msg,
1204 str(task_instance.state).upper(),
1205 task_instance.dag_id,
1206 task_instance.task_id,
1207 task_instance.run_id,
1208 ]
1209 message = "%sMarking task as %s. dag_id=%s, task_id=%s, run_id=%s, "
1210 if task_instance.map_index >= 0:
1211 params.append(task_instance.map_index)
1212 message += "map_index=%d, "
1213 message += "execution_date=%s, start_date=%s, end_date=%s"
1214 log.info(
1215 message,
1216 *params,
1217 _date_or_empty(task_instance=task_instance, attr="execution_date"),
1218 _date_or_empty(task_instance=task_instance, attr="start_date"),
1219 _date_or_empty(task_instance=task_instance, attr="end_date"),
1220 stacklevel=2,
1221 )
1224def _date_or_empty(*, task_instance: TaskInstance | TaskInstancePydantic, attr: str) -> str:
1225 """
1226 Fetch a date attribute, returning an empty string if it does not exist.
1228 :param task_instance: the task instance
1229 :param attr: the attribute name
1231 :meta private:
1232 """
1233 result: datetime | None = getattr(task_instance, attr, None)
1234 return result.strftime("%Y%m%dT%H%M%S") if result else ""
1237def _get_previous_ti(
1238 *,
1239 task_instance: TaskInstance | TaskInstancePydantic,
1240 session: Session,
1241 state: DagRunState | None = None,
1242) -> TaskInstance | TaskInstancePydantic | None:
1243 """
1244 Get task instance for the task that ran before this task instance.
1246 :param task_instance: the task instance
1247 :param state: If passed, it only takes into account instances of a specific state.
1248 :param session: SQLAlchemy ORM Session
1250 :meta private:
1251 """
1252 dagrun = task_instance.get_previous_dagrun(state, session=session)
1253 if dagrun is None:
1254 return None
1255 return dagrun.get_task_instance(task_instance.task_id, session=session)
1258@internal_api_call
1259@provide_session
1260def _update_rtif(ti, rendered_fields, session: Session | None = None):
1261 from airflow.models.renderedtifields import RenderedTaskInstanceFields
1263 rtif = RenderedTaskInstanceFields(ti=ti, render_templates=False, rendered_fields=rendered_fields)
1264 RenderedTaskInstanceFields.write(rtif, session=session)
1265 RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session)
1268class TaskInstance(Base, LoggingMixin):
1269 """
1270 Task instances store the state of a task instance.
1272 This table is the authority and single source of truth around what tasks
1273 have run and the state they are in.
1275 The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or
1276 dag model deliberately to have more control over transactions.
1278 Database transactions on this table should guard against double triggers and
1279 any confusion around what task instances are or aren't ready to run
1280 even while multiple schedulers may be firing task instances.
1282 A value of -1 in map_index represents any of: a TI without mapped tasks;
1283 a TI with mapped tasks that has yet to be expanded (state=pending);
1284 a TI with mapped tasks that expanded to an empty list (state=skipped).
1285 """
1287 __tablename__ = "task_instance"
1288 task_id = Column(StringID(), primary_key=True, nullable=False)
1289 dag_id = Column(StringID(), primary_key=True, nullable=False)
1290 run_id = Column(StringID(), primary_key=True, nullable=False)
1291 map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
1293 start_date = Column(UtcDateTime)
1294 end_date = Column(UtcDateTime)
1295 duration = Column(Float)
1296 state = Column(String(20))
1297 try_number = Column(Integer, default=0)
1298 max_tries = Column(Integer, server_default=text("-1"))
1299 hostname = Column(String(1000))
1300 unixname = Column(String(1000))
1301 job_id = Column(Integer)
1302 pool = Column(String(256), nullable=False)
1303 pool_slots = Column(Integer, default=1, nullable=False)
1304 queue = Column(String(256))
1305 priority_weight = Column(Integer)
1306 operator = Column(String(1000))
1307 custom_operator_name = Column(String(1000))
1308 queued_dttm = Column(UtcDateTime)
1309 queued_by_job_id = Column(Integer)
1310 pid = Column(Integer)
1311 executor = Column(String(1000))
1312 executor_config = Column(ExecutorConfigType(pickler=dill))
1313 updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
1314 rendered_map_index = Column(String(250))
1316 external_executor_id = Column(StringID())
1318 # The trigger to resume on if we are in state DEFERRED
1319 trigger_id = Column(Integer)
1321 # Optional timeout datetime for the trigger (past this, we'll fail)
1322 trigger_timeout = Column(DateTime)
1323 # The trigger_timeout should be TIMESTAMP(using UtcDateTime) but for ease of
1324 # migration, we are keeping it as DateTime pending a change where expensive
1325 # migration is inevitable.
1327 # The method to call next, and any extra arguments to pass to it.
1328 # Usually used when resuming from DEFERRED.
1329 next_method = Column(String(1000))
1330 next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))
1332 _task_display_property_value = Column("task_display_name", String(2000), nullable=True)
1333 # If adding new fields here then remember to add them to
1334 # refresh_from_db() or they won't display in the UI correctly
1336 __table_args__ = (
1337 Index("ti_dag_state", dag_id, state),
1338 Index("ti_dag_run", dag_id, run_id),
1339 Index("ti_state", state),
1340 Index("ti_state_lkp", dag_id, task_id, run_id, state),
1341 Index("ti_pool", pool, state, priority_weight),
1342 Index("ti_job_id", job_id),
1343 Index("ti_trigger_id", trigger_id),
1344 PrimaryKeyConstraint("dag_id", "task_id", "run_id", "map_index", name="task_instance_pkey"),
1345 ForeignKeyConstraint(
1346 [trigger_id],
1347 ["trigger.id"],
1348 name="task_instance_trigger_id_fkey",
1349 ondelete="CASCADE",
1350 ),
1351 ForeignKeyConstraint(
1352 [dag_id, run_id],
1353 ["dag_run.dag_id", "dag_run.run_id"],
1354 name="task_instance_dag_run_fkey",
1355 ondelete="CASCADE",
1356 ),
1357 )
1359 dag_model: DagModel = relationship(
1360 "DagModel",
1361 primaryjoin="TaskInstance.dag_id == DagModel.dag_id",
1362 foreign_keys=dag_id,
1363 uselist=False,
1364 innerjoin=True,
1365 viewonly=True,
1366 )
1368 trigger = relationship("Trigger", uselist=False, back_populates="task_instance")
1369 triggerer_job = association_proxy("trigger", "triggerer_job")
1370 dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True)
1371 rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False)
1372 execution_date = association_proxy("dag_run", "execution_date")
1373 task_instance_note = relationship(
1374 "TaskInstanceNote",
1375 back_populates="task_instance",
1376 uselist=False,
1377 cascade="all, delete, delete-orphan",
1378 )
1379 note = association_proxy("task_instance_note", "content", creator=_creator_note)
1381 task: Operator | None = None
1382 test_mode: bool = False
1383 is_trigger_log_context: bool = False
1384 run_as_user: str | None = None
1385 raw: bool | None = None
1386 """Indicate to FileTaskHandler that logging context should be set up for trigger logging.
1388 :meta private:
1389 """
1390 _logger_name = "airflow.task"
1392 def __init__(
1393 self,
1394 task: Operator,
1395 execution_date: datetime | None = None,
1396 run_id: str | None = None,
1397 state: str | None = None,
1398 map_index: int = -1,
1399 ):
1400 super().__init__()
1401 self.dag_id = task.dag_id
1402 self.task_id = task.task_id
1403 self.map_index = map_index
1404 self.refresh_from_task(task)
1405 if TYPE_CHECKING:
1406 assert self.task
1408 # init_on_load will config the log
1409 self.init_on_load()
1411 if run_id is None and execution_date is not None:
1412 from airflow.models.dagrun import DagRun # Avoid circular import
1414 warnings.warn(
1415 "Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id",
1416 RemovedInAirflow3Warning,
1417 # Stack level is 4 because SQLA adds some wrappers around the constructor
1418 stacklevel=4,
1419 )
1420 # make sure we have a localized execution_date stored in UTC
1421 if execution_date and not timezone.is_localized(execution_date):
1422 self.log.warning(
1423 "execution date %s has no timezone information. Using default from dag or system",
1424 execution_date,
1425 )
1426 if self.task.has_dag():
1427 if TYPE_CHECKING:
1428 assert self.task.dag
1429 execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
1430 else:
1431 execution_date = timezone.make_aware(execution_date)
1433 execution_date = timezone.convert_to_utc(execution_date)
1434 with create_session() as session:
1435 run_id = (
1436 session.query(DagRun.run_id)
1437 .filter_by(dag_id=self.dag_id, execution_date=execution_date)
1438 .scalar()
1439 )
1440 if run_id is None:
1441 raise DagRunNotFound(
1442 f"DagRun for {self.dag_id!r} with date {execution_date} not found"
1443 ) from None
1445 self.run_id = run_id
1447 self.try_number = 0
1448 self.max_tries = self.task.retries
1449 self.unixname = getuser()
1450 if state:
1451 self.state = state
1452 self.hostname = ""
1453 # Is this TaskInstance being currently running within `airflow tasks run --raw`.
1454 # Not persisted to the database so only valid for the current process
1455 self.raw = False
1456 # can be changed when calling 'run'
1457 self.test_mode = False
1459 def __hash__(self):
1460 return hash((self.task_id, self.dag_id, self.run_id, self.map_index))
1462 @property
1463 @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning)
1464 def _try_number(self):
1465 """
1466 Do not use. For semblance of backcompat.
1468 :meta private:
1469 """
1470 return self.try_number
1472 @_try_number.setter
1473 @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning)
1474 def _try_number(self, val):
1475 """
1476 Do not use. For semblance of backcompat.
1478 :meta private:
1479 """
1480 self.try_number = val
1482 @property
1483 def stats_tags(self) -> dict[str, str]:
1484 """Returns task instance tags."""
1485 return _stats_tags(task_instance=self)
1487 @staticmethod
1488 def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any]:
1489 """Insert mapping.
1491 :meta private:
1492 """
1493 priority_weight = task.weight_rule.get_weight(
1494 TaskInstance(task=task, run_id=run_id, map_index=map_index)
1495 )
1497 return {
1498 "dag_id": task.dag_id,
1499 "task_id": task.task_id,
1500 "run_id": run_id,
1501 "try_number": 0,
1502 "hostname": "",
1503 "unixname": getuser(),
1504 "queue": task.queue,
1505 "pool": task.pool,
1506 "pool_slots": task.pool_slots,
1507 "priority_weight": priority_weight,
1508 "run_as_user": task.run_as_user,
1509 "max_tries": task.retries,
1510 "executor": task.executor,
1511 "executor_config": task.executor_config,
1512 "operator": task.task_type,
1513 "custom_operator_name": getattr(task, "custom_operator_name", None),
1514 "map_index": map_index,
1515 "_task_display_property_value": task.task_display_name,
1516 }
1518 @reconstructor
1519 def init_on_load(self) -> None:
1520 """Initialize the attributes that aren't stored in the DB."""
1521 self.test_mode = False # can be changed when calling 'run'
1523 @property
1524 @deprecated(reason="Use try_number instead.", version="2.10.0", category=RemovedInAirflow3Warning)
1525 def prev_attempted_tries(self) -> int:
1526 """
1527 Calculate the total number of attempted tries, defaulting to 0.
1529 This used to be necessary because try_number did not always tell the truth.
1531 :meta private:
1532 """
1533 return self.try_number
1535 @property
1536 def next_try_number(self) -> int:
1537 # todo (dstandish): deprecate this property; we don't need a property that is just + 1
1538 return self.try_number + 1
1540 @property
1541 def operator_name(self) -> str | None:
1542 """@property: use a more friendly display name for the operator, if set."""
1543 return self.custom_operator_name or self.operator
1545 @hybrid_property
1546 def task_display_name(self) -> str:
1547 return self._task_display_property_value or self.task_id
1549 @staticmethod
1550 def _command_as_list(
1551 ti: TaskInstance | TaskInstancePydantic,
1552 mark_success: bool = False,
1553 ignore_all_deps: bool = False,
1554 ignore_task_deps: bool = False,
1555 ignore_depends_on_past: bool = False,
1556 wait_for_past_depends_before_skipping: bool = False,
1557 ignore_ti_state: bool = False,
1558 local: bool = False,
1559 pickle_id: int | None = None,
1560 raw: bool = False,
1561 job_id: str | None = None,
1562 pool: str | None = None,
1563 cfg_path: str | None = None,
1564 ) -> list[str]:
1565 dag: DAG | DagModel | DagModelPydantic | None
1566 # Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
1567 if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None:
1568 if TYPE_CHECKING:
1569 assert ti.task
1570 dag = ti.task.dag
1571 else:
1572 dag = ti.dag_model
1574 if dag is None:
1575 raise ValueError("DagModel is empty")
1577 should_pass_filepath = not pickle_id and dag
1578 path: PurePath | None = None
1579 if should_pass_filepath:
1580 if dag.is_subdag:
1581 if TYPE_CHECKING:
1582 assert dag.parent_dag is not None
1583 path = dag.parent_dag.relative_fileloc
1584 else:
1585 path = dag.relative_fileloc
1587 if path:
1588 if not path.is_absolute():
1589 path = "DAGS_FOLDER" / path
1591 return TaskInstance.generate_command(
1592 ti.dag_id,
1593 ti.task_id,
1594 run_id=ti.run_id,
1595 mark_success=mark_success,
1596 ignore_all_deps=ignore_all_deps,
1597 ignore_task_deps=ignore_task_deps,
1598 ignore_depends_on_past=ignore_depends_on_past,
1599 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
1600 ignore_ti_state=ignore_ti_state,
1601 local=local,
1602 pickle_id=pickle_id,
1603 file_path=path,
1604 raw=raw,
1605 job_id=job_id,
1606 pool=pool,
1607 cfg_path=cfg_path,
1608 map_index=ti.map_index,
1609 )
1611 def command_as_list(
1612 self,
1613 mark_success: bool = False,
1614 ignore_all_deps: bool = False,
1615 ignore_task_deps: bool = False,
1616 ignore_depends_on_past: bool = False,
1617 wait_for_past_depends_before_skipping: bool = False,
1618 ignore_ti_state: bool = False,
1619 local: bool = False,
1620 pickle_id: int | None = None,
1621 raw: bool = False,
1622 job_id: str | None = None,
1623 pool: str | None = None,
1624 cfg_path: str | None = None,
1625 ) -> list[str]:
1626 """
1627 Return a command that can be executed anywhere where airflow is installed.
1629 This command is part of the message sent to executors by the orchestrator.
1630 """
1631 return TaskInstance._command_as_list(
1632 ti=self,
1633 mark_success=mark_success,
1634 ignore_all_deps=ignore_all_deps,
1635 ignore_task_deps=ignore_task_deps,
1636 ignore_depends_on_past=ignore_depends_on_past,
1637 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
1638 ignore_ti_state=ignore_ti_state,
1639 local=local,
1640 pickle_id=pickle_id,
1641 raw=raw,
1642 job_id=job_id,
1643 pool=pool,
1644 cfg_path=cfg_path,
1645 )
1647 @staticmethod
1648 def generate_command(
1649 dag_id: str,
1650 task_id: str,
1651 run_id: str,
1652 mark_success: bool = False,
1653 ignore_all_deps: bool = False,
1654 ignore_depends_on_past: bool = False,
1655 wait_for_past_depends_before_skipping: bool = False,
1656 ignore_task_deps: bool = False,
1657 ignore_ti_state: bool = False,
1658 local: bool = False,
1659 pickle_id: int | None = None,
1660 file_path: PurePath | str | None = None,
1661 raw: bool = False,
1662 job_id: str | None = None,
1663 pool: str | None = None,
1664 cfg_path: str | None = None,
1665 map_index: int = -1,
1666 ) -> list[str]:
1667 """
1668 Generate the shell command required to execute this task instance.
1670 :param dag_id: DAG ID
1671 :param task_id: Task ID
1672 :param run_id: The run_id of this task's DagRun
1673 :param mark_success: Whether to mark the task as successful
1674 :param ignore_all_deps: Ignore all ignorable dependencies.
1675 Overrides the other ignore_* parameters.
1676 :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs
1677 (e.g. for Backfills)
1678 :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped
1679 :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past
1680 and trigger rule
1681 :param ignore_ti_state: Ignore the task instance's previous failure/success
1682 :param local: Whether to run the task locally
1683 :param pickle_id: If the DAG was serialized to the DB, the ID
1684 associated with the pickled DAG
1685 :param file_path: path to the file containing the DAG definition
1686 :param raw: raw mode (needs more details)
1687 :param job_id: job ID (needs more details)
1688 :param pool: the Airflow pool that the task should run in
1689 :param cfg_path: the Path to the configuration file
1690 :return: shell command that can be used to run the task instance
1691 """
1692 cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id]
1693 if mark_success:
1694 cmd.extend(["--mark-success"])
1695 if pickle_id:
1696 cmd.extend(["--pickle", str(pickle_id)])
1697 if job_id:
1698 cmd.extend(["--job-id", str(job_id)])
1699 if ignore_all_deps:
1700 cmd.extend(["--ignore-all-dependencies"])
1701 if ignore_task_deps:
1702 cmd.extend(["--ignore-dependencies"])
1703 if ignore_depends_on_past:
1704 cmd.extend(["--depends-on-past", "ignore"])
1705 elif wait_for_past_depends_before_skipping:
1706 cmd.extend(["--depends-on-past", "wait"])
1707 if ignore_ti_state:
1708 cmd.extend(["--force"])
1709 if local:
1710 cmd.extend(["--local"])
1711 if pool:
1712 cmd.extend(["--pool", pool])
1713 if raw:
1714 cmd.extend(["--raw"])
1715 if file_path:
1716 cmd.extend(["--subdir", os.fspath(file_path)])
1717 if cfg_path:
1718 cmd.extend(["--cfg-path", cfg_path])
1719 if map_index != -1:
1720 cmd.extend(["--map-index", str(map_index)])
1721 return cmd
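# Illustrative sketch (not from the original source; dag/task names are hypothetical):
# for a manually triggered run, generate_command produces a CLI invocation such as
#
#   TaskInstance.generate_command(
#       "example_dag",
#       "example_task",
#       run_id="manual__2024-01-01T00:00:00+00:00",
#       local=True,
#       pool="default_pool",
#       map_index=2,
#   )
#   # -> ["airflow", "tasks", "run", "example_dag", "example_task",
#   #     "manual__2024-01-01T00:00:00+00:00", "--local", "--pool", "default_pool",
#   #     "--map-index", "2"]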
1723 @property
1724 def log_url(self) -> str:
1725 """Log URL for TaskInstance."""
1726 run_id = quote(self.run_id)
1727 base_url = conf.get_mandatory_value("webserver", "BASE_URL")
1728 return (
1729 f"{base_url}"
1730 f"/dags"
1731 f"/{self.dag_id}"
1732 f"/grid"
1733 f"?dag_run_id={run_id}"
1734 f"&task_id={self.task_id}"
1735 f"&map_index={self.map_index}"
1736 "&tab=logs"
1737 )
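# Illustrative sketch (assumed configuration, not from the original source): with
# BASE_URL "http://localhost:8080", dag_id "example_dag", task_id "example_task",
# run_id "manual__2024-01-01T00:00:00+00:00" and map_index -1, log_url resolves to
#   http://localhost:8080/dags/example_dag/grid
#       ?dag_run_id=manual__2024-01-01T00%3A00%3A00%2B00%3A00
#       &task_id=example_task&map_index=-1&tab=logs
# (the run_id is percent-encoded by urllib.parse.quote).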
1739 @property
1740 def mark_success_url(self) -> str:
1741 """URL to mark TI success."""
1742 base_url = conf.get_mandatory_value("webserver", "BASE_URL")
1743 return (
1744 f"{base_url}"
1745 "/confirm"
1746 f"?task_id={self.task_id}"
1747 f"&dag_id={self.dag_id}"
1748 f"&dag_run_id={quote(self.run_id)}"
1749 "&upstream=false"
1750 "&downstream=false"
1751 "&state=success"
1752 )
1754 @provide_session
1755 def current_state(self, session: Session = NEW_SESSION) -> str:
1756 """
1757 Get the very latest state from the database.
1759 If a session is passed, we use it and looking up the state becomes part of the
1760 session; otherwise a new session is used.
1762 sqlalchemy.inspect is used here to get the primary keys, ensuring that this will
1763 not regress if they change.
1765 :param session: SQLAlchemy ORM Session
1766 """
1767 filters = (col == getattr(self, col.name) for col in inspect(TaskInstance).primary_key)
1768 return session.query(TaskInstance.state).filter(*filters).scalar()
1770 @provide_session
1771 def error(self, session: Session = NEW_SESSION) -> None:
1772 """
1773 Force the task instance's state to FAILED in the database.
1775 :param session: SQLAlchemy ORM Session
1776 """
1777 self.log.error("Recording the task instance as FAILED")
1778 self.state = TaskInstanceState.FAILED
1779 session.merge(self)
1780 session.commit()
1782 @classmethod
1783 @internal_api_call
1784 @provide_session
1785 def get_task_instance(
1786 cls,
1787 dag_id: str,
1788 run_id: str,
1789 task_id: str,
1790 map_index: int,
1791 lock_for_update: bool = False,
1792 session: Session = NEW_SESSION,
1793 ) -> TaskInstance | TaskInstancePydantic | None:
1794 query = (
1795 session.query(TaskInstance)
1796 .options(lazyload(TaskInstance.dag_run)) # lazy load dag run to avoid locking it
1797 .filter_by(
1798 dag_id=dag_id,
1799 run_id=run_id,
1800 task_id=task_id,
1801 map_index=map_index,
1802 )
1803 )
1805 if lock_for_update:
1806 for attempt in run_with_db_retries(logger=cls.logger()):
1807 with attempt:
1808 return query.with_for_update().one_or_none()
1809 else:
1810 return query.one_or_none()
1812 return None
1814 @provide_session
1815 def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
1816 """
1817 Refresh the task instance from the database based on the primary key.
1819 :param session: SQLAlchemy ORM Session
1820 :param lock_for_update: if True, indicates that the database should
1821 lock the TaskInstance (issuing a FOR UPDATE clause) until the
1822 session is committed.
1823 """
1824 _refresh_from_db(task_instance=self, session=session, lock_for_update=lock_for_update)
1826 def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
1827 """
1828 Copy common attributes from the given task.
1830 :param task: The task object to copy from
1831 :param pool_override: Use the pool_override instead of task's pool
1832 """
1833 _refresh_from_task(task_instance=self, task=task, pool_override=pool_override)
1835 @staticmethod
1836 @internal_api_call
1837 @provide_session
1838 def _clear_xcom_data(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION) -> None:
1839 """Clear all XCom data from the database for the task instance.
1841 If the task is unmapped, all XComs matching this task ID in the same DAG
1842 run are removed. If the task is mapped, only the one with matching map
1843 index is removed.
1845 :param ti: The TI for which we need to clear xcoms.
1846 :param session: SQLAlchemy ORM Session
1847 """
1848 ti.log.debug("Clearing XCom data")
1849 if ti.map_index < 0:
1850 map_index: int | None = None
1851 else:
1852 map_index = ti.map_index
1853 XCom.clear(
1854 dag_id=ti.dag_id,
1855 task_id=ti.task_id,
1856 run_id=ti.run_id,
1857 map_index=map_index,
1858 session=session,
1859 )
1861 @provide_session
1862 def clear_xcom_data(self, session: Session = NEW_SESSION):
1863 self._clear_xcom_data(ti=self, session=session)
1865 @property
1866 def key(self) -> TaskInstanceKey:
1867 """Returns a tuple that identifies the task instance uniquely."""
1868 return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
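# Illustrative sketch (hypothetical values): the key is a TaskInstanceKey named tuple, e.g.
#   TaskInstanceKey(dag_id="example_dag", task_id="example_task",
#                   run_id="manual__2024-01-01", try_number=1, map_index=-1)
# where map_index is -1 for unmapped tasks.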
1870 @staticmethod
1871 @internal_api_call
1872 def _set_state(ti: TaskInstance | TaskInstancePydantic, state, session: Session) -> bool:
1873 if not isinstance(ti, TaskInstance):
1874 ti = session.scalars(
1875 select(TaskInstance).where(
1876 TaskInstance.task_id == ti.task_id,
1877 TaskInstance.dag_id == ti.dag_id,
1878 TaskInstance.run_id == ti.run_id,
1879 TaskInstance.map_index == ti.map_index,
1880 )
1881 ).one()
1883 if ti.state == state:
1884 return False
1886 current_time = timezone.utcnow()
1887 ti.log.debug("Setting task state for %s to %s", ti, state)
1888 ti.state = state
1889 ti.start_date = ti.start_date or current_time
1890 if ti.state in State.finished or ti.state == TaskInstanceState.UP_FOR_RETRY:
1891 ti.end_date = ti.end_date or current_time
1892 ti.duration = (ti.end_date - ti.start_date).total_seconds()
1893 session.merge(ti)
1894 return True
1896 @provide_session
1897 def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
1898 """
1899 Set TaskInstance state.
1901 :param state: State to set for the TI
1902 :param session: SQLAlchemy ORM Session
1903 :return: Was the state changed
1904 """
1905 return self._set_state(ti=self, state=state, session=session)
1907 @property
1908 def is_premature(self) -> bool:
1909 """Returns whether a task is in UP_FOR_RETRY state and its retry interval has elapsed."""
1910 # is the task still in the retry waiting period?
1911 return self.state == TaskInstanceState.UP_FOR_RETRY and not self.ready_for_retry()
1913 @provide_session
1914 def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
1915 """
1916 Check whether the immediate dependents of this task instance have succeeded or have been skipped.
1918 This is meant to be used by wait_for_downstream.
1920 This is useful when you do not want to start processing the next
1921 schedule of a task until the dependents are done. For instance,
1922 if the task DROPs and recreates a table.
1924 :param session: SQLAlchemy ORM Session
1925 """
1926 task = self.task
1927 if TYPE_CHECKING:
1928 assert task
1930 if not task.downstream_task_ids:
1931 return True
1933 ti = session.query(func.count(TaskInstance.task_id)).filter(
1934 TaskInstance.dag_id == self.dag_id,
1935 TaskInstance.task_id.in_(task.downstream_task_ids),
1936 TaskInstance.run_id == self.run_id,
1937 TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)),
1938 )
1939 count = ti[0][0]
1940 return count == len(task.downstream_task_ids)
1942 @provide_session
1943 def get_previous_dagrun(
1944 self,
1945 state: DagRunState | None = None,
1946 session: Session | None = None,
1947 ) -> DagRun | None:
1948 """
1949 Return the DagRun that ran before this task instance's DagRun.
1951 :param state: If passed, it only takes into account instances of a specific state.
1952 :param session: SQLAlchemy ORM Session.
1953 """
1954 return _get_previous_dagrun(task_instance=self, state=state, session=session)
1956 @provide_session
1957 def get_previous_ti(
1958 self,
1959 state: DagRunState | None = None,
1960 session: Session = NEW_SESSION,
1961 ) -> TaskInstance | TaskInstancePydantic | None:
1962 """
1963 Return the task instance for the task that ran before this task instance.
1965 :param session: SQLAlchemy ORM Session
1966 :param state: If passed, it only takes into account instances of a specific state.
1967 """
1968 return _get_previous_ti(task_instance=self, state=state, session=session)
1970 @property
1971 def previous_ti(self) -> TaskInstance | TaskInstancePydantic | None:
1972 """
1973 This attribute is deprecated.
1975 Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_ti`.
1976 """
1977 warnings.warn(
1978 """
1979 This attribute is deprecated.
1980 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
1981 """,
1982 RemovedInAirflow3Warning,
1983 stacklevel=2,
1984 )
1985 return self.get_previous_ti()
1987 @property
1988 def previous_ti_success(self) -> TaskInstance | TaskInstancePydantic | None:
1989 """
1990 This attribute is deprecated.
1992 Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_ti`.
1993 """
1994 warnings.warn(
1995 """
1996 This attribute is deprecated.
1997 Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
1998 """,
1999 RemovedInAirflow3Warning,
2000 stacklevel=2,
2001 )
2002 return self.get_previous_ti(state=DagRunState.SUCCESS)
2004 @provide_session
2005 def get_previous_execution_date(
2006 self,
2007 state: DagRunState | None = None,
2008 session: Session = NEW_SESSION,
2009 ) -> pendulum.DateTime | None:
2010 """
2011 Return the execution date from property previous_ti_success.
2013 :param state: If passed, it only takes into account instances of a specific state.
2014 :param session: SQLAlchemy ORM Session
2015 """
2016 return _get_previous_execution_date(task_instance=self, state=state, session=session)
2018 @provide_session
2019 def get_previous_start_date(
2020 self, state: DagRunState | None = None, session: Session = NEW_SESSION
2021 ) -> pendulum.DateTime | None:
2022 """
2023 Return the start date from property previous_ti_success.
2025 :param state: If passed, it only takes into account instances of a specific state.
2026 :param session: SQLAlchemy ORM Session
2027 """
2028 self.log.debug("previous_start_date was called")
2029 prev_ti = self.get_previous_ti(state=state, session=session)
2030 # prev_ti may not exist and prev_ti.start_date may be None.
2031 return pendulum.instance(prev_ti.start_date) if prev_ti and prev_ti.start_date else None
2033 @property
2034 def previous_start_date_success(self) -> pendulum.DateTime | None:
2035 """
2036 This attribute is deprecated.
2038 Please use :class:`airflow.models.taskinstance.TaskInstance.get_previous_start_date`.
2039 """
2040 warnings.warn(
2041 """
2042 This attribute is deprecated.
2043 Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method.
2044 """,
2045 RemovedInAirflow3Warning,
2046 stacklevel=2,
2047 )
2048 return self.get_previous_start_date(state=DagRunState.SUCCESS)
2050 @provide_session
2051 def are_dependencies_met(
2052 self, dep_context: DepContext | None = None, session: Session = NEW_SESSION, verbose: bool = False
2053 ) -> bool:
2054 """
2055 Are all conditions met for this task instance to be run given the context for the dependencies.
2057 (e.g. a task instance being force run from the UI will ignore some dependencies).
2059 :param dep_context: The execution context that determines the dependencies that should be evaluated.
2060 :param session: database session
2061 :param verbose: whether to log details on failed dependencies at info or debug log level
2062 """
2063 dep_context = dep_context or DepContext()
2064 failed = False
2065 verbose_aware_logger = self.log.info if verbose else self.log.debug
2066 for dep_status in self.get_failed_dep_statuses(dep_context=dep_context, session=session):
2067 failed = True
2069 verbose_aware_logger(
2070 "Dependencies not met for %s, dependency '%s' FAILED: %s",
2071 self,
2072 dep_status.dep_name,
2073 dep_status.reason,
2074 )
2076 if failed:
2077 return False
2079 verbose_aware_logger("Dependencies all met for dep_context=%s ti=%s", dep_context.description, self)
2080 return True
2082 @provide_session
2083 def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION):
2084 """Get failed Dependencies."""
2085 if TYPE_CHECKING:
2086 assert self.task
2088 dep_context = dep_context or DepContext()
2089 for dep in dep_context.deps | self.task.deps:
2090 for dep_status in dep.get_dep_statuses(self, session, dep_context):
2091 self.log.debug(
2092 "%s dependency '%s' PASSED: %s, %s",
2093 self,
2094 dep_status.dep_name,
2095 dep_status.passed,
2096 dep_status.reason,
2097 )
2099 if not dep_status.passed:
2100 yield dep_status
2102 def __repr__(self) -> str:
2103 prefix = f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} "
2104 if self.map_index != -1:
2105 prefix += f"map_index={self.map_index} "
2106 return prefix + f"[{self.state}]>"
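# Illustrative sketch (hypothetical values): an unmapped running task instance renders as
#   <TaskInstance: example_dag.example_task manual__2024-01-01 [running]>
# and a mapped one additionally shows "map_index=<n>" before the state.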
2108 def next_retry_datetime(self):
2109 """
2110 Get datetime of the next retry if the task instance fails.
2112 For exponential backoff, retry_delay is used as base and will be converted to seconds.
2113 """
2114 from airflow.models.abstractoperator import MAX_RETRY_DELAY
2116 delay = self.task.retry_delay
2117 if self.task.retry_exponential_backoff:
2118 # If the min_backoff calculation is below 1, it will be converted to 0 via int. Thus,
2119 # we must round up prior to converting to an int, otherwise a divide by zero error
2120 # will occur in the modded_hash calculation.
2121 # this probably gives unexpected results if a task instance has previously been cleared,
2122 # because try_number can increase without bound
2123 min_backoff = math.ceil(delay.total_seconds() * (2 ** (self.try_number - 1)))
2125 # In the case when delay.total_seconds() is 0, min_backoff will not be rounded up to 1.
2126 # To address this, we impose a lower bound of 1 on min_backoff. This effectively makes
2127 # the ceiling function unnecessary, but the ceiling function was retained to avoid
2128 # introducing a breaking change.
2129 if min_backoff < 1:
2130 min_backoff = 1
2132 # deterministic per task instance
2133 ti_hash = int(
2134 hashlib.sha1(
2135 f"{self.dag_id}#{self.task_id}#{self.execution_date}#{self.try_number}".encode()
2136 ).hexdigest(),
2137 16,
2138 )
2139 # deterministic jitter: between min_backoff and 2 * min_backoff - 1
2140 modded_hash = min_backoff + ti_hash % min_backoff
2141 # timedelta has a maximum representable value. The exponentiation
2142 # here means this value can be exceeded after a certain number
2143 # of tries (around 50 if the initial delay is 1s, even fewer if
2144 # the delay is larger). Cap the value here before creating a
2145 # timedelta object so the operation doesn't fail with "OverflowError".
2146 delay_backoff_in_seconds = min(modded_hash, MAX_RETRY_DELAY)
2147 delay = timedelta(seconds=delay_backoff_in_seconds)
2148 if self.task.max_retry_delay:
2149 delay = min(self.task.max_retry_delay, delay)
2150 return self.end_date + delay
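# Illustrative worked example (assumed values, not from the original source): with
# retry_delay=timedelta(seconds=60), retry_exponential_backoff=True and try_number=3,
#   min_backoff = ceil(60 * 2 ** (3 - 1)) = 240 seconds
# and the hash-based jitter (min_backoff + ti_hash % min_backoff) places the delay
# in the range [240, 479] seconds, capped by MAX_RETRY_DELAY and, if configured,
# by task.max_retry_delay.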
2152 def ready_for_retry(self) -> bool:
2153 """Check on whether the task instance is in the right state and timeframe to be retried."""
2154 return self.state == TaskInstanceState.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()
2156 @staticmethod
2157 @internal_api_call
2158 def _get_dagrun(dag_id, run_id, session) -> DagRun:
2159 from airflow.models.dagrun import DagRun # Avoid circular import
2161 dr = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
2162 return dr
2164 @provide_session
2165 def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
2166 """
2167 Return the DagRun for this TaskInstance.
2169 :param session: SQLAlchemy ORM Session
2170 :return: DagRun
2171 """
2172 info = inspect(self)
2173 if info.attrs.dag_run.loaded_value is not NO_VALUE:
2174 if getattr(self, "task", None) is not None:
2175 if TYPE_CHECKING:
2176 assert self.task
2177 self.dag_run.dag = self.task.dag
2178 return self.dag_run
2180 dr = self._get_dagrun(self.dag_id, self.run_id, session)
2181 if getattr(self, "task", None) is not None:
2182 if TYPE_CHECKING:
2183 assert self.task
2184 dr.dag = self.task.dag
2185 # Record it in the instance for next time. This means that `self.execution_date` will work correctly
2186 set_committed_value(self, "dag_run", dr)
2188 return dr
2190 @classmethod
2191 @internal_api_call
2192 @provide_session
2193 def _check_and_change_state_before_execution(
2194 cls,
2195 task_instance: TaskInstance | TaskInstancePydantic,
2196 verbose: bool = True,
2197 ignore_all_deps: bool = False,
2198 ignore_depends_on_past: bool = False,
2199 wait_for_past_depends_before_skipping: bool = False,
2200 ignore_task_deps: bool = False,
2201 ignore_ti_state: bool = False,
2202 mark_success: bool = False,
2203 test_mode: bool = False,
2204 hostname: str = "",
2205 job_id: str | None = None,
2206 pool: str | None = None,
2207 external_executor_id: str | None = None,
2208 session: Session = NEW_SESSION,
2209 ) -> bool:
2210 """
2211 Check dependencies and then set the state to RUNNING if they are met.
2213 Returns True if and only if state is set to RUNNING, which implies that task should be
2214 executed, in preparation for _run_raw_task.
2216 :param verbose: whether to turn on more verbose logging
2217 :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs
2218 :param ignore_depends_on_past: Ignore depends_on_past DAG attribute
2219 :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped
2220 :param ignore_task_deps: Don't check the dependencies of this TaskInstance's task
2221 :param ignore_ti_state: Disregards previous task instance state
2222 :param mark_success: Don't run the task, mark its state as success
2223 :param test_mode: Doesn't record success or failure in the DB
2224 :param hostname: The hostname of the worker running the task instance.
2225 :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID
2226 :param pool: specifies the pool to use to run the task instance
2227 :param external_executor_id: The identifier of the celery executor
2228 :param session: SQLAlchemy ORM Session
2229 :return: whether the state was changed to running or not
2230 """
2231 if TYPE_CHECKING:
2232 assert task_instance.task
2234 if isinstance(task_instance, TaskInstance):
2235 ti: TaskInstance = task_instance
2236 else: # isinstance(task_instance, TaskInstancePydantic)
2237 filters = (col == getattr(task_instance, col.name) for col in inspect(TaskInstance).primary_key)
2238 ti = session.query(TaskInstance).filter(*filters).scalar()
2239 dag = ti.dag_model.serialized_dag.dag
2240 task_instance.task = dag.task_dict[ti.task_id]
2241 ti.task = task_instance.task
2242 task = task_instance.task
2243 if TYPE_CHECKING:
2244 assert task
2245 ti.refresh_from_task(task, pool_override=pool)
2246 ti.test_mode = test_mode
2247 ti.refresh_from_db(session=session, lock_for_update=True)
2248 ti.job_id = job_id
2249 ti.hostname = hostname
2250 ti.pid = None
2252 if not ignore_all_deps and not ignore_ti_state and ti.state == TaskInstanceState.SUCCESS:
2253 Stats.incr("previously_succeeded", tags=ti.stats_tags)
2255 if not mark_success:
2256 # Firstly find non-runnable and non-requeueable tis.
2257 # Since mark_success is not set, we do nothing.
2258 non_requeueable_dep_context = DepContext(
2259 deps=RUNNING_DEPS - REQUEUEABLE_DEPS,
2260 ignore_all_deps=ignore_all_deps,
2261 ignore_ti_state=ignore_ti_state,
2262 ignore_depends_on_past=ignore_depends_on_past,
2263 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
2264 ignore_task_deps=ignore_task_deps,
2265 description="non-requeueable deps",
2266 )
2267 if not ti.are_dependencies_met(
2268 dep_context=non_requeueable_dep_context, session=session, verbose=True
2269 ):
2270 session.commit()
2271 return False
2273 # For reporting purposes, we report based on 1-indexed,
2274 # not 0-indexed lists (i.e. Attempt 1 instead of
2275 # Attempt 0 for the first attempt).
2276 # Set the task start date. In case it was re-scheduled use the initial
2277 # start date that is recorded in task_reschedule table
2278 # If the task continues after being deferred (next_method is set), use the original start_date
2279 ti.start_date = ti.start_date if ti.next_method else timezone.utcnow()
2280 if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
2281 tr_start_date = session.scalar(
2282 TR.stmt_for_task_instance(ti, descending=False).with_only_columns(TR.start_date).limit(1)
2283 )
2284 if tr_start_date:
2285 ti.start_date = tr_start_date
2287 # Secondly we find non-runnable but requeueable tis. We reset its state.
2288 # This is because we might have hit concurrency limits,
2289 # e.g. because of backfilling.
2290 dep_context = DepContext(
2291 deps=REQUEUEABLE_DEPS,
2292 ignore_all_deps=ignore_all_deps,
2293 ignore_depends_on_past=ignore_depends_on_past,
2294 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
2295 ignore_task_deps=ignore_task_deps,
2296 ignore_ti_state=ignore_ti_state,
2297 description="requeueable deps",
2298 )
2299 if not ti.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
2300 ti.state = None
2301 cls.logger().warning(
2302 "Rescheduling due to concurrency limits reached "
2303 "at task runtime. Attempt %s of "
2304 "%s. State set to NONE.",
2305 ti.try_number,
2306 ti.max_tries + 1,
2307 )
2308 ti.queued_dttm = timezone.utcnow()
2309 session.merge(ti)
2310 session.commit()
2311 return False
2313 if ti.next_kwargs is not None:
2314 cls.logger().info("Resuming after deferral")
2315 else:
2316 cls.logger().info("Starting attempt %s of %s", ti.try_number, ti.max_tries + 1)
2318 if not test_mode:
2319 session.add(Log(TaskInstanceState.RUNNING.value, ti))
2321 ti.state = TaskInstanceState.RUNNING
2322 ti.emit_state_change_metric(TaskInstanceState.RUNNING)
2324 if external_executor_id:
2325 ti.external_executor_id = external_executor_id
2327 ti.end_date = None
2328 if not test_mode:
2329 session.merge(ti).task = task
2330 session.commit()
2332 # Closing all pooled connections to prevent
2333 # "max number of connections reached"
2334 settings.engine.dispose() # type: ignore
2335 if verbose:
2336 if mark_success:
2337 cls.logger().info("Marking success for %s on %s", ti.task, ti.execution_date)
2338 else:
2339 cls.logger().info("Executing %s on %s", ti.task, ti.execution_date)
2340 return True
2342 @provide_session
2343 def check_and_change_state_before_execution(
2344 self,
2345 verbose: bool = True,
2346 ignore_all_deps: bool = False,
2347 ignore_depends_on_past: bool = False,
2348 wait_for_past_depends_before_skipping: bool = False,
2349 ignore_task_deps: bool = False,
2350 ignore_ti_state: bool = False,
2351 mark_success: bool = False,
2352 test_mode: bool = False,
2353 job_id: str | None = None,
2354 pool: str | None = None,
2355 external_executor_id: str | None = None,
2356 session: Session = NEW_SESSION,
2357 ) -> bool:
2358 return TaskInstance._check_and_change_state_before_execution(
2359 task_instance=self,
2360 verbose=verbose,
2361 ignore_all_deps=ignore_all_deps,
2362 ignore_depends_on_past=ignore_depends_on_past,
2363 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
2364 ignore_task_deps=ignore_task_deps,
2365 ignore_ti_state=ignore_ti_state,
2366 mark_success=mark_success,
2367 test_mode=test_mode,
2368 hostname=get_hostname(),
2369 job_id=job_id,
2370 pool=pool,
2371 external_executor_id=external_executor_id,
2372 session=session,
2373 )
2375 def emit_state_change_metric(self, new_state: TaskInstanceState) -> None:
2376 """
2377 Send a time metric representing how much time a given state transition took.
2379 The previous state and metric name are deduced from the state the task was put in.
2381 :param new_state: The state that has just been set for this task.
2382 We do not use `self.state`, because sometimes the state is updated directly in the DB and not in
2383 the local TaskInstance object.
2384 Supported states: QUEUED and RUNNING
2385 """
2386 if self.end_date:
2387 # if the task has an end date, it means that this is not its first round.
2388 # we send the state transition time metric only on the first try, otherwise it gets more complex.
2389 return
2391 # switch on state and deduce which metric to send
2392 if new_state == TaskInstanceState.RUNNING:
2393 metric_name = "queued_duration"
2394 if self.queued_dttm is None:
2395 # this should not really happen except in tests or rare cases,
2396 # but we don't want to create errors just for a metric, so we just skip it
2397 self.log.warning(
2398 "cannot record %s for task %s because previous state change time has not been saved",
2399 metric_name,
2400 self.task_id,
2401 )
2402 return
2403 timing = (timezone.utcnow() - self.queued_dttm).total_seconds()
2404 elif new_state == TaskInstanceState.QUEUED:
2405 metric_name = "scheduled_duration"
2406 if self.start_date is None:
2407 # This check does not work correctly before fields like `scheduled_dttm` are implemented.
2408 # TODO: Change the level to WARNING once it's viable.
2409 # see #30612 #34493 and #34771 for more details
2410 self.log.debug(
2411 "cannot record %s for task %s because previous state change time has not been saved",
2412 metric_name,
2413 self.task_id,
2414 )
2415 return
2416 timing = (timezone.utcnow() - self.start_date).total_seconds()
2417 else:
2418 raise NotImplementedError(f"no metric emission setup for state {new_state}")
2420 # send metric twice, once (legacy) with tags in the name and once with tags as tags
2421 Stats.timing(f"dag.{self.dag_id}.{self.task_id}.{metric_name}", timing)
2422 Stats.timing(f"task.{metric_name}", timing, tags={"task_id": self.task_id, "dag_id": self.dag_id})
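# Illustrative sketch (hypothetical dag/task ids): a task entering RUNNING in
# "example_dag" emits the queued-duration timing twice, once in each naming scheme:
#   Stats.timing("dag.example_dag.example_task.queued_duration", <seconds queued>)
#   Stats.timing("task.queued_duration", <seconds queued>,
#                tags={"task_id": "example_task", "dag_id": "example_dag"})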
2424 def clear_next_method_args(self) -> None:
2425 """Ensure we unset next_method and next_kwargs to ensure that any retries don't reuse them."""
2426 _clear_next_method_args(task_instance=self)
2428 @provide_session
2429 @Sentry.enrich_errors
2430 def _run_raw_task(
2431 self,
2432 mark_success: bool = False,
2433 test_mode: bool = False,
2434 job_id: str | None = None,
2435 pool: str | None = None,
2436 raise_on_defer: bool = False,
2437 session: Session = NEW_SESSION,
2438 ) -> TaskReturnCode | None:
2439 """
2440 Run a task, update the state upon completion, and run any appropriate callbacks.
2442 Immediately runs the task (without checking or changing db state
2443 before execution) and then sets the appropriate final state after
2444 completion and runs any post-execute callbacks. Meant to be called
2445 only after another function changes the state to running.
2447 :param mark_success: Don't run the task, mark its state as success
2448 :param test_mode: Doesn't record success or failure in the DB
2449 :param pool: specifies the pool to use to run the task instance
2450 :param session: SQLAlchemy ORM Session
2451 """
2452 if TYPE_CHECKING:
2453 assert self.task
2455 self.test_mode = test_mode
2456 self.refresh_from_task(self.task, pool_override=pool)
2457 self.refresh_from_db(session=session)
2459 self.job_id = job_id
2460 self.hostname = get_hostname()
2461 self.pid = os.getpid()
2462 if not test_mode:
2463 session.merge(self)
2464 session.commit()
2465 actual_start_date = timezone.utcnow()
2466 Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}", tags=self.stats_tags)
2467 # Same metric with tagging
2468 Stats.incr("ti.start", tags=self.stats_tags)
2469 # Initialize final state counters at zero
2470 for state in State.task_states:
2471 Stats.incr(
2472 f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}",
2473 count=0,
2474 tags=self.stats_tags,
2475 )
2476 # Same metric with tagging
2477 Stats.incr(
2478 "ti.finish",
2479 count=0,
2480 tags={**self.stats_tags, "state": str(state)},
2481 )
2482 with set_current_task_instance_session(session=session):
2483 self.task = self.task.prepare_for_execution()
2484 context = self.get_template_context(ignore_param_exceptions=False)
2486 try:
2487 if not mark_success:
2488 self._execute_task_with_callbacks(context, test_mode, session=session)
2489 if not test_mode:
2490 self.refresh_from_db(lock_for_update=True, session=session)
2491 self.state = TaskInstanceState.SUCCESS
2492 except TaskDeferred as defer:
2493 # The task has signalled it wants to defer execution based on
2494 # a trigger.
2495 if raise_on_defer:
2496 raise
2497 self.defer_task(defer=defer, session=session)
2498 self.log.info(
2499 "Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, execution_date=%s, start_date=%s",
2500 self.dag_id,
2501 self.task_id,
2502 self.run_id,
2503 _date_or_empty(task_instance=self, attr="execution_date"),
2504 _date_or_empty(task_instance=self, attr="start_date"),
2505 )
2506 if not test_mode:
2507 session.add(Log(self.state, self))
2508 session.merge(self)
2509 session.commit()
2510 return TaskReturnCode.DEFERRED
2511 except AirflowSkipException as e:
2512 # Recording SKIP
2513 # log only if exception has any arguments to prevent log flooding
2514 if e.args:
2515 self.log.info(e)
2516 if not test_mode:
2517 self.refresh_from_db(lock_for_update=True, session=session)
2518 _run_finished_callback(callbacks=self.task.on_skipped_callback, context=context)
2519 session.commit()
2520 self.state = TaskInstanceState.SKIPPED
2521 except AirflowRescheduleException as reschedule_exception:
2522 self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
2523 session.commit()
2524 return None
2525 except (AirflowFailException, AirflowSensorTimeout) as e:
2526 # If AirflowFailException is raised, task should not retry.
2527 # If a sensor in reschedule mode reaches timeout, task should not retry.
2528 self.handle_failure(e, test_mode, context, force_fail=True, session=session)
2529 session.commit()
2530 raise
2531 except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated) as e:
2532 if not test_mode:
2533 self.refresh_from_db(lock_for_update=True, session=session)
2534 # for case when task is marked as success/failed externally
2535 # or dagrun timed out and task is marked as skipped
2536 # current behavior doesn't hit the callbacks
2537 if self.state in State.finished:
2538 self.clear_next_method_args()
2539 session.merge(self)
2540 session.commit()
2541 return None
2542 else:
2543 self.handle_failure(e, test_mode, context, session=session)
2544 session.commit()
2545 raise
2546 except SystemExit as e:
2547 # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`.
2548 # Therefore, here we must handle only error codes.
2549 msg = f"Task failed due to SystemExit({e.code})"
2550 self.handle_failure(msg, test_mode, context, session=session)
2551 session.commit()
2552 raise AirflowException(msg)
2553 except BaseException as e:
2554 self.handle_failure(e, test_mode, context, session=session)
2555 session.commit()
2556 raise
2557 finally:
2558 Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags)
2559 # Same metric with tagging
2560 Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)})
2562 # Recording SKIPPED or SUCCESS
2563 self.clear_next_method_args()
2564 self.end_date = timezone.utcnow()
2565 _log_state(task_instance=self)
2566 self.set_duration()
2568 # run on_success_callback before db committing
2569 # otherwise, the LocalTaskJob sees the state is changed to `success`,
2570 # but the task_runner is still running, LocalTaskJob then treats the state as set externally!
2571 _run_finished_callback(callbacks=self.task.on_success_callback, context=context)
2573 if not test_mode:
2574 session.add(Log(self.state, self))
2575 session.merge(self).task = self.task
2576 if self.state == TaskInstanceState.SUCCESS:
2577 self._register_dataset_changes(events=context["outlet_events"], session=session)
2579 session.commit()
2580 if self.state == TaskInstanceState.SUCCESS:
2581 get_listener_manager().hook.on_task_instance_success(
2582 previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session
2583 )
2585 return None
2587 def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Session) -> None:
2588 if TYPE_CHECKING:
2589 assert self.task
2591 for obj in self.task.outlets or []:
2592 self.log.debug("outlet obj %s", obj)
2593 # Lineage can have other types of objects besides datasets
2594 if isinstance(obj, Dataset):
2595 dataset_manager.register_dataset_change(
2596 task_instance=self,
2597 dataset=obj,
2598 extra=events[obj].extra,
2599 session=session,
2600 )
2602 def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session):
2603 """Prepare Task for Execution."""
2604 if TYPE_CHECKING:
2605 assert self.task
2607 parent_pid = os.getpid()
2609 def signal_handler(signum, frame):
2610 pid = os.getpid()
2612 # If a task forks during execution (from DAG code) for whatever
2613 # reason, we want to make sure that we react to the signal only in
2614 # the process that we've spawned ourselves (referred to here as the
2615 # parent process).
2616 if pid != parent_pid:
2617 os._exit(1)
2618 return
2619 self.log.error("Received SIGTERM. Terminating subprocesses.")
2620 self.task.on_kill()
2621 raise AirflowTaskTerminated("Task received SIGTERM signal")
2623 signal.signal(signal.SIGTERM, signal_handler)
2625 # Don't clear Xcom until the task is certain to execute, and check if we are resuming from deferral.
2626 if not self.next_method:
2627 self.clear_xcom_data()
2629 with Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration"), Stats.timer(
2630 "task.duration", tags=self.stats_tags
2631 ):
2632 # Set the validated/merged params on the task object.
2633 self.task.params = context["params"]
2635 with set_current_context(context):
2636 dag = self.task.get_dag()
2637 if dag is not None:
2638 jinja_env = dag.get_template_env()
2639 else:
2640 jinja_env = None
2641 task_orig = self.render_templates(context=context, jinja_env=jinja_env)
2643 # The task is never MappedOperator at this point.
2644 if TYPE_CHECKING:
2645 assert isinstance(self.task, BaseOperator)
2647 if not test_mode:
2648 rendered_fields = get_serialized_template_fields(task=self.task)
2649 _update_rtif(ti=self, rendered_fields=rendered_fields)
2650 # Export context to make it available for operators to use.
2651 airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
2652 os.environ.update(airflow_context_vars)
2654 # Log context only for the default execution method, the assumption
2655 # being that otherwise we're resuming a deferred task (in which
2656 # case there's no need to log these again).
2657 if not self.next_method:
2658 self.log.info(
2659 "Exporting env vars: %s",
2660 " ".join(f"{k}={v!r}" for k, v in airflow_context_vars.items()),
2661 )
2663 # Run pre_execute callback
2664 self.task.pre_execute(context=context)
2666 # Run on_execute callback
2667 self._run_execute_callback(context, self.task)
2669 # Run on_task_instance_running event
2670 get_listener_manager().hook.on_task_instance_running(
2671 previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
2672 )
2674 def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
2675 """Render named map index if the DAG author defined map_index_template at the task level."""
2676 if jinja_env is None or (template := context.get("map_index_template")) is None:
2677 return None
2678 rendered_map_index = jinja_env.from_string(template).render(context)
2679 log.debug("Map index rendered as %s", rendered_map_index)
2680 return rendered_map_index
2682 # Execute the task.
2683 with set_current_context(context):
2684 try:
2685 result = self._execute_task(context, task_orig)
2686 except Exception:
2687 # If the task failed, swallow rendering error so it doesn't mask the main error.
2688 with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
2689 self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
2690 raise
2691 else: # If the task succeeded, render normally to let rendering error bubble up.
2692 self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
2694 # Run post_execute callback
2695 self.task.post_execute(context=context, result=result)
2697 Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
2698 # Same metric with tagging
2699 Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
2700 Stats.incr("ti_successes", tags=self.stats_tags)
2702 def _execute_task(self, context: Context, task_orig: Operator):
2703 """
2704 Execute Task (optionally with a Timeout) and push Xcom results.
2706 :param context: Jinja2 context
2707 :param task_orig: origin task
2708 """
2709 return _execute_task(self, context, task_orig)
2711 @provide_session
2712 def defer_task(self, session: Session, defer: TaskDeferred) -> None:
2713 """Mark the task as deferred and sets up the trigger that is needed to resume it.
2715 :meta: private
2716 """
2717 from airflow.models.trigger import Trigger
2719 if TYPE_CHECKING:
2720 assert self.task
2722 # First, make the trigger entry
2723 trigger_row = Trigger.from_object(defer.trigger)
2724 session.add(trigger_row)
2725 session.flush()
2727 # Then, update ourselves so it matches the deferral request
2728 # Keep an eye on the logic in `check_and_change_state_before_execution()`
2729 # depending on self.next_method semantics
2730 self.state = TaskInstanceState.DEFERRED
2731 self.trigger_id = trigger_row.id
2732 self.next_method = defer.method_name
2733 self.next_kwargs = defer.kwargs or {}
2735 # Calculate timeout too if it was passed
2736 if defer.timeout is not None:
2737 self.trigger_timeout = timezone.utcnow() + defer.timeout
2738 else:
2739 self.trigger_timeout = None
2741 # If an execution_timeout is set, set the timeout to the minimum of
2742 # it and the trigger timeout
2743 execution_timeout = self.task.execution_timeout
2744 if execution_timeout:
2745 if self.trigger_timeout:
2746 self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout)
2747 else:
2748 self.trigger_timeout = self.start_date + execution_timeout
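# Illustrative sketch (simplified deferrable-operator usage; the operator shown is
# hypothetical): a task usually reaches defer_task() by raising TaskDeferred through
# BaseOperator.defer(), e.g.
#
#   from datetime import timedelta
#   from airflow.triggers.temporal import TimeDeltaTrigger
#
#   def execute(self, context):
#       self.defer(
#           trigger=TimeDeltaTrigger(timedelta(minutes=5)),
#           method_name="execute_complete",
#       )
#
# The raised TaskDeferred carries the trigger, method_name, kwargs and timeout that
# defer_task() persists above.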
2750 def _run_execute_callback(self, context: Context, task: BaseOperator) -> None:
2751 """Functions that need to be run before a Task is executed."""
2752 if not (callbacks := task.on_execute_callback):
2753 return
2754 for callback in callbacks if isinstance(callbacks, list) else [callbacks]:
2755 try:
2756 callback(context)
2757 except Exception:
2758 self.log.exception("Failed when executing execute callback")
2760 @provide_session
2761 def run(
2762 self,
2763 verbose: bool = True,
2764 ignore_all_deps: bool = False,
2765 ignore_depends_on_past: bool = False,
2766 wait_for_past_depends_before_skipping: bool = False,
2767 ignore_task_deps: bool = False,
2768 ignore_ti_state: bool = False,
2769 mark_success: bool = False,
2770 test_mode: bool = False,
2771 job_id: str | None = None,
2772 pool: str | None = None,
2773 session: Session = NEW_SESSION,
2774 raise_on_defer: bool = False,
2775 ) -> None:
2776 """Run TaskInstance."""
2777 res = self.check_and_change_state_before_execution(
2778 verbose=verbose,
2779 ignore_all_deps=ignore_all_deps,
2780 ignore_depends_on_past=ignore_depends_on_past,
2781 wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
2782 ignore_task_deps=ignore_task_deps,
2783 ignore_ti_state=ignore_ti_state,
2784 mark_success=mark_success,
2785 test_mode=test_mode,
2786 job_id=job_id,
2787 pool=pool,
2788 session=session,
2789 )
2790 if not res:
2791 return
2793 self._run_raw_task(
2794 mark_success=mark_success,
2795 test_mode=test_mode,
2796 job_id=job_id,
2797 pool=pool,
2798 session=session,
2799 raise_on_defer=raise_on_defer,
2800 )
2802 def dry_run(self) -> None:
2803 """Only Renders Templates for the TI."""
2804 if TYPE_CHECKING:
2805 assert self.task
2807 self.task = self.task.prepare_for_execution()
2808 self.render_templates()
2809 if TYPE_CHECKING:
2810 assert isinstance(self.task, BaseOperator)
2811 self.task.dry_run()
2813 @provide_session
2814 def _handle_reschedule(
2815 self,
2816 actual_start_date: datetime,
2817 reschedule_exception: AirflowRescheduleException,
2818 test_mode: bool = False,
2819 session: Session = NEW_SESSION,
2820 ):
2821 # Don't record reschedule request in test mode
2822 if test_mode:
2823 return
2825 from airflow.models.dagrun import DagRun # Avoid circular import
2827 self.refresh_from_db(session)
2829 if TYPE_CHECKING:
2830 assert self.task
2832 self.end_date = timezone.utcnow()
2833 self.set_duration()
2835 # Lock DAG run to be sure not to get into a deadlock situation when trying to insert
2836 # TaskReschedule which apparently also creates lock on corresponding DagRun entity
2837 with_row_locks(
2838 session.query(DagRun).filter_by(
2839 dag_id=self.dag_id,
2840 run_id=self.run_id,
2841 ),
2842 session=session,
2843 ).one()
2845 # Log reschedule request
2846 session.add(
2847 TaskReschedule(
2848 self.task_id,
2849 self.dag_id,
2850 self.run_id,
2851 self.try_number,
2852 actual_start_date,
2853 self.end_date,
2854 reschedule_exception.reschedule_date,
2855 self.map_index,
2856 )
2857 )
2859 # set state
2860 self.state = TaskInstanceState.UP_FOR_RESCHEDULE
2862 self.clear_next_method_args()
2864 session.merge(self)
2865 session.commit()
2866 self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
2868 @staticmethod
2869 def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None:
2870 """
2871 Truncate the traceback of an exception to the first frame called from within a given function.
2873 :param error: exception to get traceback from
2874 :param truncate_to: Function to truncate TB to. Must have a ``__code__`` attribute
2876 :meta private:
2877 """
2878 tb = error.__traceback__
2879 code = truncate_to.__func__.__code__ # type: ignore[attr-defined]
2880 while tb is not None:
2881 if tb.tb_frame.f_code is code:
2882 return tb.tb_next
2883 tb = tb.tb_next
2884 return tb or error.__traceback__
2886 @classmethod
2887 @internal_api_call
2888 @provide_session
2889 def fetch_handle_failure_context(
2890 cls,
2891 ti: TaskInstance | TaskInstancePydantic,
2892 error: None | str | BaseException,
2893 test_mode: bool | None = None,
2894 context: Context | None = None,
2895 force_fail: bool = False,
2896 session: Session = NEW_SESSION,
2897 fail_stop: bool = False,
2898 ):
2899 """
2900 Handle Failure for the TaskInstance.
2902 :param fail_stop: if true, stop remaining tasks in dag
2903 """
2904 get_listener_manager().hook.on_task_instance_failed(
2905 previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session
2906 )
2908 if error:
2909 if isinstance(error, BaseException):
2910 tb = TaskInstance.get_truncated_error_traceback(error, truncate_to=ti._execute_task)
2911 cls.logger().error("Task failed with exception", exc_info=(type(error), error, tb))
2912 else:
2913 cls.logger().error("%s", error)
2914 if not test_mode:
2915 ti.refresh_from_db(session)
2917 ti.end_date = timezone.utcnow()
2918 ti.set_duration()
2920 Stats.incr(f"operator_failures_{ti.operator}", tags=ti.stats_tags)
2921 # Same metric with tagging
2922 Stats.incr("operator_failures", tags={**ti.stats_tags, "operator": ti.operator})
2923 Stats.incr("ti_failures", tags=ti.stats_tags)
2925 if not test_mode:
2926 session.add(Log(TaskInstanceState.FAILED.value, ti))
2928 # Log failure duration
2929 session.add(TaskFail(ti=ti))
2931 ti.clear_next_method_args()
2933 # In extreme cases (zombie in case of dag with parse error) we might _not_ have a Task.
2934 if context is None and getattr(ti, "task", None):
2935 context = ti.get_template_context(session)
2937 if context is not None:
2938 context["exception"] = error
2940 # Set state correctly and figure out how to log it and decide whether
2941 # to email
2943 # Note, callback invocation needs to be handled by caller of
2944 # _run_raw_task to avoid race conditions which could lead to duplicate
2945 # invocations or missed invocations.
2947 # Since this function is called only when the TaskInstance state is running,
2948 # try_number contains the current try_number (not the next). We
2949 # only mark task instance as FAILED if the next task instance
2950 # try_number exceeds the max_tries ... or if force_fail is truthy
2952 task: BaseOperator | None = None
2953 try:
2954 if getattr(ti, "task", None) and context:
2955 if TYPE_CHECKING:
2956 assert ti.task
2957 task = ti.task.unmap((context, session))
2958 except Exception:
2959 cls.logger().error("Unable to unmap task to determine if we need to send an alert email")
2961 if force_fail or not ti.is_eligible_to_retry():
2962 ti.state = TaskInstanceState.FAILED
2963 email_for_state = operator.attrgetter("email_on_failure")
2964 callbacks = task.on_failure_callback if task else None
2966 if task and fail_stop:
2967 _stop_remaining_tasks(task_instance=ti, session=session)
2968 else:
2969 if ti.state == TaskInstanceState.QUEUED:
2970 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
2972 if isinstance(ti, TaskInstancePydantic):
2973 # todo: (AIP-44) we should probably "coalesce" `ti` to TaskInstance before here
2974 # e.g. we could make refresh_from_db return a TI and replace ti with that
2975 raise RuntimeError("Expected TaskInstance here. Further AIP-44 work required.")
2976 # We increase the try_number to fail the task if it fails to start after some time
2977 ti.state = State.UP_FOR_RETRY
2978 email_for_state = operator.attrgetter("email_on_retry")
2979 callbacks = task.on_retry_callback if task else None
2981 return {
2982 "ti": ti,
2983 "email_for_state": email_for_state,
2984 "task": task,
2985 "callbacks": callbacks,
2986 "context": context,
2987 }
2989 @staticmethod
2990 @internal_api_call
2991 @provide_session
2992 def save_to_db(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION):
2993 session.merge(ti)
2994 session.flush()
2996 @provide_session
2997 def handle_failure(
2998 self,
2999 error: None | str | BaseException,
3000 test_mode: bool | None = None,
3001 context: Context | None = None,
3002 force_fail: bool = False,
3003 session: Session = NEW_SESSION,
3004 ) -> None:
3005 """
3006 Handle Failure for a task instance.
3008 :param error: if specified, log the specific exception if thrown
3009 :param session: SQLAlchemy ORM Session
3010 :param test_mode: doesn't record success or failure in the DB if True
3011 :param context: Jinja2 context
3012 :param force_fail: if True, task does not retry
3013 """
3014 if TYPE_CHECKING:
3015 assert self.task
3016 assert self.task.dag
3017 try:
3018 fail_stop = self.task.dag.fail_stop
3019 except Exception:
3020 fail_stop = False
3021 _handle_failure(
3022 task_instance=self,
3023 error=error,
3024 session=session,
3025 test_mode=test_mode,
3026 context=context,
3027 force_fail=force_fail,
3028 fail_stop=fail_stop,
3029 )
3031 def is_eligible_to_retry(self):
3032 """Is task instance is eligible for retry."""
3033 return _is_eligible_to_retry(task_instance=self)
3035 def get_template_context(
3036 self,
3037 session: Session | None = None,
3038 ignore_param_exceptions: bool = True,
3039 ) -> Context:
3040 """
3041 Return TI Context.
3043 :param session: SQLAlchemy ORM Session
3044 :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
3045 """
3046 return _get_template_context(
3047 task_instance=self,
3048 session=session,
3049 ignore_param_exceptions=ignore_param_exceptions,
3050 )
3052 @provide_session
3053 def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None:
3054 """
3055 Update task with rendered template fields for presentation in UI.
3057 If task has already run, will fetch from DB; otherwise will render.
3058 """
3059 from airflow.models.renderedtifields import RenderedTaskInstanceFields
3061 if TYPE_CHECKING:
3062 assert self.task
3064 rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
3065 if rendered_task_instance_fields:
3066 self.task = self.task.unmap(None)
3067 for field_name, rendered_value in rendered_task_instance_fields.items():
3068 setattr(self.task, field_name, rendered_value)
3069 return
3071 try:
3072 # If we get here, either the task hasn't run or the RTIF record was purged.
3073 from airflow.utils.log.secrets_masker import redact
3075 self.render_templates()
3076 for field_name in self.task.template_fields:
3077 rendered_value = getattr(self.task, field_name)
3078 setattr(self.task, field_name, redact(rendered_value, field_name))
3079 except (TemplateAssertionError, UndefinedError) as e:
3080 raise AirflowException(
3081 "Webserver does not have access to User-defined Macros or Filters "
3082 "when Dag Serialization is enabled. Hence for the task that have not yet "
3083 "started running, please use 'airflow tasks render' for debugging the "
3084 "rendering of template_fields."
3085 ) from e
3087 def overwrite_params_with_dag_run_conf(self, params: dict, dag_run: DagRun):
3088 """Overwrite Task Params with DagRun.conf."""
3089 if dag_run and dag_run.conf:
3090 self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
3091 params.update(dag_run.conf)
3093 def render_templates(
3094 self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
3095 ) -> Operator:
3096 """Render templates in the operator fields.
3098 If the task was originally mapped, this may replace ``self.task`` with
3099 the unmapped, fully rendered BaseOperator. The original ``self.task``
3100 before replacement is returned.
3101 """
3102 if not context:
3103 context = self.get_template_context()
3104 original_task = self.task
3106 if TYPE_CHECKING:
3107 assert original_task
3109 # If self.task is mapped, this call replaces self.task to point to the
3110 # unmapped BaseOperator created by this function! This is because the
3111 # MappedOperator is useless for template rendering, and we need to be
3112 # able to access the unmapped task instead.
3113 original_task.render_template_fields(context, jinja_env)
3115 return original_task
3117 def render_k8s_pod_yaml(self) -> dict | None:
3118 """Render the k8s pod yaml."""
3119 try:
3120 from airflow.providers.cncf.kubernetes.template_rendering import (
3121 render_k8s_pod_yaml as render_k8s_pod_yaml_from_provider,
3122 )
3123 except ImportError:
3124 raise RuntimeError(
3125 "You need to have the `cncf.kubernetes` provider installed to use this feature. "
3126 "Also rather than calling it directly you should import "
3127 "render_k8s_pod_yaml from airflow.providers.cncf.kubernetes.template_rendering "
3128 "and call it with TaskInstance as the first argument."
3129 )
3130 warnings.warn(
3131 "You should not call `task_instance.render_k8s_pod_yaml` directly. This method will be removed"
3132 "in Airflow 3. Rather than calling it directly you should import "
3133 "`render_k8s_pod_yaml` from `airflow.providers.cncf.kubernetes.template_rendering` "
3134 "and call it with `TaskInstance` as the first argument.",
3135 DeprecationWarning,
3136 stacklevel=2,
3137 )
3138 return render_k8s_pod_yaml_from_provider(self)
3140 @provide_session
3141 def get_rendered_k8s_spec(self, session: Session = NEW_SESSION):
3142 """Render the k8s pod yaml."""
3143 try:
3144 from airflow.providers.cncf.kubernetes.template_rendering import (
3145 get_rendered_k8s_spec as get_rendered_k8s_spec_from_provider,
3146 )
3147 except ImportError:
3148 raise RuntimeError(
3149 "You need to have the `cncf.kubernetes` provider installed to use this feature. "
3150 "Also rather than calling it directly you should import "
3151 "`get_rendered_k8s_spec` from `airflow.providers.cncf.kubernetes.template_rendering` "
3152 "and call it with `TaskInstance` as the first argument."
3153 )
3154 warnings.warn(
3155 "You should not call `task_instance.render_k8s_pod_yaml` directly. This method will be removed"
3156 "in Airflow 3. Rather than calling it directly you should import "
3157 "`get_rendered_k8s_spec` from `airflow.providers.cncf.kubernetes.template_rendering` "
3158 "and call it with `TaskInstance` as the first argument.",
3159 DeprecationWarning,
3160 stacklevel=2,
3161 )
3162 return get_rendered_k8s_spec_from_provider(self, session=session)
3164 def get_email_subject_content(
3165 self, exception: BaseException, task: BaseOperator | None = None
3166 ) -> tuple[str, str, str]:
3167 """
3168 Get the email subject content for exceptions.
3170 :param exception: the exception sent in the email
3171 :param task:
3172 """
3173 return _get_email_subject_content(task_instance=self, exception=exception, task=task)
3175 def email_alert(self, exception, task: BaseOperator) -> None:
3176 """
3177 Send alert email with exception information.
3179 :param exception: the exception
3180 :param task: task related to the exception
3181 """
3182 _email_alert(task_instance=self, exception=exception, task=task)
3184 def set_duration(self) -> None:
3185 """Set task instance duration."""
3186 _set_duration(task_instance=self)
3188 @provide_session
3189 def xcom_push(
3190 self,
3191 key: str,
3192 value: Any,
3193 execution_date: datetime | None = None,
3194 session: Session = NEW_SESSION,
3195 ) -> None:
3196 """
3197 Make an XCom available for tasks to pull.
3199 :param key: Key to store the value under.
3200 :param value: Value to store. What types are possible depends on whether
3201 ``enable_xcom_pickling`` is true or not. If so, this can be any
3202 picklable object; otherwise only JSON-serializable values may be used.
3203 :param execution_date: Deprecated parameter that has no effect.
3204 """
3205 if execution_date is not None:
3206 self_execution_date = self.get_dagrun(session).execution_date
3207 if execution_date < self_execution_date:
3208 raise ValueError(
3209 f"execution_date can not be in the past (current execution_date is "
3210 f"{self_execution_date}; received {execution_date})"
3211 )
3212 elif execution_date is not None:
3213 message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated."
3214 warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
3216 XCom.set(
3217 key=key,
3218 value=value,
3219 task_id=self.task_id,
3220 dag_id=self.dag_id,
3221 run_id=self.run_id,
3222 map_index=self.map_index,
3223 session=session,
3224 )
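A hedged usage sketch of the push side. Inside a TaskFlow task, an explicit ``xcom_push`` under a custom key coexists with the implicit push of the return value; the function name and key below are made up, while ``task`` and ``get_current_context`` are standard Airflow imports:

from airflow.decorators import task
from airflow.operators.python import get_current_context

@task
def produce_metrics():
    ti = get_current_context()["ti"]
    ti.xcom_push(key="row_count", value=42)  # explicit push under a custom key
    return {"status": "ok"}                  # implicitly pushed under "return_value"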
3226 @provide_session
3227 def xcom_pull(
3228 self,
3229 task_ids: str | Iterable[str] | None = None,
3230 dag_id: str | None = None,
3231 key: str = XCOM_RETURN_KEY,
3232 include_prior_dates: bool = False,
3233 session: Session = NEW_SESSION,
3234 *,
3235 map_indexes: int | Iterable[int] | None = None,
3236 default: Any = None,
3237 ) -> Any:
3238 """Pull XComs that optionally meet certain criteria.
3240 :param key: A key for the XCom. If provided, only XComs with matching
3241 keys will be returned. The default key is ``'return_value'``, also
3242 available as constant ``XCOM_RETURN_KEY``. This key is automatically
3243 given to XComs returned by tasks (as opposed to being pushed
3244 manually). To remove the filter, pass *None*.
3245 :param task_ids: Only XComs from tasks with matching ids will be
3246 pulled. Pass *None* to remove the filter.
3247 :param dag_id: If provided, only pulls XComs from this DAG. If *None*
3248 (default), the DAG of the calling task is used.
3249 :param map_indexes: If provided, only pull XComs with matching indexes.
3250 If *None* (default), this is inferred from the task(s) being pulled
3251 (see below for details).
3252 :param include_prior_dates: If *False*, only XComs from the current
3253 execution_date are returned. If *True*, XComs from previous dates
3254 are returned as well.
3256 When pulling a single task (``task_ids`` is *None* or a str) without
3257 specifying ``map_indexes``, the return value is inferred from whether
3258 the specified task is mapped. If not, the value from that single task
3259 instance is returned. If the task to pull is mapped, an iterator (not a
3260 list) yielding XComs from mapped task instances is returned. In either
3261 case, ``default`` (*None* if not specified) is returned if no matching
3262 XComs are found.
3264 When pulling multiple tasks (i.e. either ``task_ids`` or ``map_indexes`` is
3265 a non-str iterable), a list of matching XComs is returned. Elements in
3266 the list are ordered by item ordering in ``task_ids`` and ``map_indexes``.
3267 """
3268 if dag_id is None:
3269 dag_id = self.dag_id
3271 query = XCom.get_many(
3272 key=key,
3273 run_id=self.run_id,
3274 dag_ids=dag_id,
3275 task_ids=task_ids,
3276 map_indexes=map_indexes,
3277 include_prior_dates=include_prior_dates,
3278 session=session,
3279 )
3281 # NOTE: Since we're only fetching the value field and not the whole
3282 # class, the @recreate annotation does not kick in. Therefore we need to
3283 # call XCom.deserialize_value() manually.
3285 # We are only pulling one single task.
3286 if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable):
3287 first = query.with_entities(
3288 XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value
3289 ).first()
3290 if first is None: # No matching XCom at all.
3291 return default
3292 if map_indexes is not None or first.map_index < 0:
3293 return XCom.deserialize_value(first)
3294 return LazyXComSelectSequence.from_select(
3295 query.with_entities(XCom.value).order_by(None).statement,
3296 order_by=[XCom.map_index],
3297 session=session,
3298 )
3300 # At this point either task_ids or map_indexes is explicitly multi-value.
3301 # Order return values to match task_ids and map_indexes ordering.
3302 ordering = []
3303 if task_ids is None or isinstance(task_ids, str):
3304 ordering.append(XCom.task_id)
3305 elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}:
3306 ordering.append(case(task_id_whens, value=XCom.task_id))
3307 else:
3308 ordering.append(XCom.task_id)
3309 if map_indexes is None or isinstance(map_indexes, int):
3310 ordering.append(XCom.map_index)
3311 elif isinstance(map_indexes, range):
3312 order = XCom.map_index
3313 if map_indexes.step < 0:
3314 order = order.desc()
3315 ordering.append(order)
3316 elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}:
3317 ordering.append(case(map_index_whens, value=XCom.map_index))
3318 else:
3319 ordering.append(XCom.map_index)
3320 return LazyXComSelectSequence.from_select(
3321 query.with_entities(XCom.value).order_by(None).statement,
3322 order_by=ordering,
3323 session=session,
3324 )
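A hedged sketch of the pull side, assuming hypothetical upstream task ids in the same DAG run; the return shape follows the docstring above (single value, list ordered like ``task_ids``, or the supplied default):

value = ti.xcom_pull(task_ids="extract_a")                   # single unmapped task -> plain value
values = ti.xcom_pull(task_ids=["extract_a", "extract_b"])   # list ordered like task_ids
fallback = ti.xcom_pull(task_ids="missing_task", default=0)  # no matching XCom -> default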
3326 @provide_session
3327 def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
3328 """Return Number of running TIs from the DB."""
3329 # .count() is inefficient
3330 num_running_task_instances_query = session.query(func.count()).filter(
3331 TaskInstance.dag_id == self.dag_id,
3332 TaskInstance.task_id == self.task_id,
3333 TaskInstance.state == TaskInstanceState.RUNNING,
3334 )
3335 if same_dagrun:
3336 num_running_task_instances_query = num_running_task_instances_query.filter(
3337 TaskInstance.run_id == self.run_id
3338 )
3339 return num_running_task_instances_query.scalar()
3341 def init_run_context(self, raw: bool = False) -> None:
3342 """Set the log context."""
3343 self.raw = raw
3344 self._set_context(self)
3346 @staticmethod
3347 def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None:
3348 """Return SQLAlchemy filter to query selected task instances."""
3349 # The DictKeys type (what we often pass here from the scheduler) is not directly indexable :(
3350 # It might also be a generator, and we need to be able to iterate over it more than once.
3351 tis = list(tis)
3353 if not tis:
3354 return None
3356 first = tis[0]
3358 dag_id = first.dag_id
3359 run_id = first.run_id
3360 map_index = first.map_index
3361 first_task_id = first.task_id
3363 # Pre-compute the sets of dag_ids, run_ids, map_indices and task_ids.
3364 dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set()
3365 for t in tis:
3366 dag_ids.add(t.dag_id)
3367 run_ids.add(t.run_id)
3368 map_indices.add(t.map_index)
3369 task_ids.add(t.task_id)
3371 # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id
3372 # and task_id -- this can be over 150x faster for huge numbers of TIs (20k+)
3373 if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}:
3374 return and_(
3375 TaskInstance.dag_id == dag_id,
3376 TaskInstance.run_id == run_id,
3377 TaskInstance.map_index == map_index,
3378 TaskInstance.task_id.in_(task_ids),
3379 )
3380 if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}:
3381 return and_(
3382 TaskInstance.dag_id == dag_id,
3383 TaskInstance.run_id.in_(run_ids),
3384 TaskInstance.map_index == map_index,
3385 TaskInstance.task_id == first_task_id,
3386 )
3387 if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}:
3388 return and_(
3389 TaskInstance.dag_id == dag_id,
3390 TaskInstance.run_id == run_id,
3391 TaskInstance.map_index.in_(map_indices),
3392 TaskInstance.task_id == first_task_id,
3393 )
3395 filter_condition = []
3396 # Create two nested groupings, both keyed primarily by (dag_id, run_id);
3397 # within each key, one groups by task_id and the other by map_index.
3398 task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
3399 map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
3400 for t in tis:
3401 task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index)
3402 map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id)
3404 # This assumes that dag_id is usually the largest grouping, followed by run_id. Even
3405 # if it's not, this is still a significant optimisation over querying for every single tuple key.
3406 for cur_dag_id, cur_run_id in itertools.product(dag_ids, run_ids):
3407 # we compare the group size between task_id and map_index and use the smaller group
3408 dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
3409 dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
3411 if len(dag_task_id_groups) <= len(dag_map_index_groups):
3412 for cur_task_id, cur_map_indices in dag_task_id_groups.items():
3413 filter_condition.append(
3414 and_(
3415 TaskInstance.dag_id == cur_dag_id,
3416 TaskInstance.run_id == cur_run_id,
3417 TaskInstance.task_id == cur_task_id,
3418 TaskInstance.map_index.in_(cur_map_indices),
3419 )
3420 )
3421 else:
3422 for cur_map_index, cur_task_ids in dag_map_index_groups.items():
3423 filter_condition.append(
3424 and_(
3425 TaskInstance.dag_id == cur_dag_id,
3426 TaskInstance.run_id == cur_run_id,
3427 TaskInstance.task_id.in_(cur_task_ids),
3428 TaskInstance.map_index == cur_map_index,
3429 )
3430 )
3432 return or_(*filter_condition)
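A hedged illustration of the fast path above: when every key shares dag_id, run_id and map_index, the returned clause collapses to a single IN() on task_id. The dag, run and task ids below are made up:

from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey

keys = [
    TaskInstanceKey(dag_id="example_dag", task_id=f"task_{i}", run_id="manual__2024-01-01", map_index=-1)
    for i in range(3)
]
clause = TaskInstance.filter_for_tis(keys)
# Roughly: dag_id = 'example_dag' AND run_id = 'manual__2024-01-01'
#          AND map_index = -1 AND task_id IN ('task_0', 'task_1', 'task_2')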
3434 @classmethod
3435 def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators:
3436 """
3437 Build an SQLAlchemy filter for a list of task_ids or tuples of (task_id, map_index).
3439 :meta private:
3440 """
3441 # Compute a filter for TI.task_id and TI.map_index based on input values
3442 # For each item, it will either be a task_id, or (task_id, map_index)
3443 task_id_only = [v for v in vals if isinstance(v, str)]
3444 with_map_index = [v for v in vals if not isinstance(v, str)]
3446 filters: list[ColumnOperators] = []
3447 if task_id_only:
3448 filters.append(cls.task_id.in_(task_id_only))
3449 if with_map_index:
3450 filters.append(tuple_in_condition((cls.task_id, cls.map_index), with_map_index))
3452 if not filters:
3453 return false()
3454 if len(filters) == 1:
3455 return filters[0]
3456 return or_(*filters)
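A hedged sketch of this internal helper (it is marked ``:meta private:``): mixing plain task ids with ``(task_id, map_index)`` tuples produces an OR of an IN() filter and a tuple-IN filter; the ids below are hypothetical:

clause = TaskInstance.ti_selector_condition(
    ["extract", "load", ("transform", 0), ("transform", 1)]
)
# Roughly: task_id IN ('extract', 'load')
#          OR (task_id, map_index) IN (('transform', 0), ('transform', 1))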
3458 @classmethod
3459 @internal_api_call
3460 @provide_session
3461 def _schedule_downstream_tasks(
3462 cls,
3463 ti: TaskInstance | TaskInstancePydantic,
3464 session: Session = NEW_SESSION,
3465 max_tis_per_query: int | None = None,
3466 ):
3467 from sqlalchemy.exc import OperationalError
3469 from airflow.models.dagrun import DagRun
3471 try:
3472 # Re-select the row with a lock
3473 dag_run = with_row_locks(
3474 session.query(DagRun).filter_by(
3475 dag_id=ti.dag_id,
3476 run_id=ti.run_id,
3477 ),
3478 session=session,
3479 nowait=True,
3480 ).one()
3482 task = ti.task
3483 if TYPE_CHECKING:
3484 assert task
3485 assert task.dag
3487 # Get a partial DAG with just the specific tasks we want to examine.
3488 # In order for dep checks to work correctly, we include ourself (so
3489 # TriggerRuleDep can check the state of the task we just executed).
3490 partial_dag = task.dag.partial_subset(
3491 task.downstream_task_ids,
3492 include_downstream=True,
3493 include_upstream=False,
3494 include_direct_upstream=True,
3495 )
3497 dag_run.dag = partial_dag
3498 info = dag_run.task_instance_scheduling_decisions(session)
3500 skippable_task_ids = {
3501 task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
3502 }
3504 schedulable_tis = [
3505 ti
3506 for ti in info.schedulable_tis
3507 if ti.task_id not in skippable_task_ids
3508 and not (
3509 ti.task.inherits_from_empty_operator
3510 and not ti.task.on_execute_callback
3511 and not ti.task.on_success_callback
3512 and not ti.task.outlets
3513 )
3514 ]
3515 for schedulable_ti in schedulable_tis:
3516 if getattr(schedulable_ti, "task", None) is None:
3517 schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
3519 num = dag_run.schedule_tis(schedulable_tis, session=session, max_tis_per_query=max_tis_per_query)
3520 cls.logger().info("%d downstream tasks scheduled from follow-on schedule check", num)
3522 session.flush()
3524 except OperationalError as e:
3525 # Any kind of DB error here is _non-fatal_ as this block is just an optimisation.
3526 cls.logger().debug(
3527 "Skipping mini scheduling run due to exception: %s",
3528 e.statement,
3529 exc_info=True,
3530 )
3531 session.rollback()
3533 @provide_session
3534 def schedule_downstream_tasks(self, session: Session = NEW_SESSION, max_tis_per_query: int | None = None):
3535 """
3536 Schedule downstream tasks of this task instance.
3538 :meta private:
3539 """
3540 return TaskInstance._schedule_downstream_tasks(
3541 ti=self, session=session, max_tis_per_query=max_tis_per_query
3542 )
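A hedged sketch of how the local task runner can invoke this "mini scheduler" right after a task finishes, so directly-downstream tasks are queued without waiting for the main scheduler loop; ``ti`` and an open ORM ``session`` are assumed to exist:

# Best effort only: any OperationalError inside is swallowed and the work is
# left to the main scheduler, per the exception handling above.
ti.schedule_downstream_tasks(session=session, max_tis_per_query=16)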
3544 def get_relevant_upstream_map_indexes(
3545 self,
3546 upstream: Operator,
3547 ti_count: int | None,
3548 *,
3549 session: Session,
3550 ) -> int | range | None:
3551 """Infer the map indexes of an upstream "relevant" to this ti.
3553 The bulk of the logic mainly exists to solve the problem described by
3554 the following example, where 'val' must resolve to different values,
3555 depending on where the reference is being used::
3557 @task
3558 def this_task(v): # This is self.task.
3559 return v * 2
3562 @task_group
3563 def tg1(inp):
3564 val = upstream(inp) # This is the upstream task.
3565 this_task(val) # When inp is 1, val here should resolve to 2.
3566 return val
3569 # This val is the same object returned by tg1.
3570 val = tg1.expand(inp=[1, 2, 3])
3573 @task_group
3574 def tg2(inp):
3575 another_task(inp, val) # val here should resolve to [2, 4, 6].
3578 tg2.expand(inp=["a", "b"])
3580 The surrounding mapped task groups of ``upstream`` and ``self.task`` are
3581 inspected to find a common "ancestor". If such an ancestor is found,
3582 we need to return specific map indexes to pull a partial value from
3583 upstream XCom.
3585 :param upstream: The referenced upstream task.
3586 :param ti_count: The total number of task instances this task was expanded
3587 into by the scheduler, i.e. ``expanded_ti_count`` in the template context.
3588 :return: Specific map index or map indexes to pull, or ``None`` if we
3589 want the "whole" return value (i.e. no mapped task groups involved).
3590 """
3591 if TYPE_CHECKING:
3592 assert self.task
3594 # This value should never be None since we already know the current task
3595 # is in a mapped task group and should have been expanded; despite that,
3596 # we need to check that it is not None to satisfy Mypy.
3597 # The value can also be 0 when we expand an empty list, so we must check
3598 # that ti_count is not 0 as well, to avoid dividing by zero.
3599 if not ti_count:
3600 return None
3602 # Find the innermost mapped task group common to the current task and
3603 # the referenced upstream task. If the two do not have a common mapped
3604 # task group, they are in different task mapping contexts (like
3605 # another_task above), and we should use the "whole" value.
3606 common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream)
3607 if common_ancestor is None:
3608 return None
3610 # At this point we know the two tasks share a mapped task group, and we
3611 # should use a "partial" value. Let's break down the mapped ti count
3612 # between the ancestor and any further expansion that happened inside it.
3613 ancestor_ti_count = common_ancestor.get_mapped_ti_count(self.run_id, session=session)
3614 ancestor_map_index = self.map_index * ancestor_ti_count // ti_count
3616 # If the task is NOT further expanded inside the common ancestor, we
3617 # only want to reference one single ti. We must walk the actual DAG,
3618 # and "ti_count == ancestor_ti_count" does not work, since the further
3619 # expansion may be of length 1.
3620 if not _is_further_mapped_inside(upstream, common_ancestor):
3621 return ancestor_map_index
3623 # Otherwise we need a partial aggregation for values from selected task
3624 # instances in the ancestor's expansion context.
3625 further_count = ti_count // ancestor_ti_count
3626 map_index_start = ancestor_map_index * further_count
3627 return range(map_index_start, map_index_start + further_count)
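The index arithmetic above is easiest to follow with concrete numbers; a plain-Python sketch (no ORM involved), assuming the ancestor group expanded 3 times and this task expanded to 6 task instances in total:

ti_count, ancestor_ti_count, map_index = 6, 3, 4

ancestor_map_index = map_index * ancestor_ti_count // ti_count      # 4 * 3 // 6 == 2
further_count = ti_count // ancestor_ti_count                       # 6 // 3 == 2
map_index_start = ancestor_map_index * further_count                # 2 * 2 == 4
relevant = range(map_index_start, map_index_start + further_count)  # range(4, 6)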
3629 def clear_db_references(self, session: Session):
3630 """
3631 Clear db tables that have a reference to this instance.
3633 :param session: ORM Session
3635 :meta private:
3636 """
3637 from airflow.models.renderedtifields import RenderedTaskInstanceFields
3639 tables: list[type[TaskInstanceDependencies]] = [
3640 TaskFail,
3641 TaskInstanceNote,
3642 TaskReschedule,
3643 XCom,
3644 RenderedTaskInstanceFields,
3645 TaskMap,
3646 ]
3647 for table in tables:
3648 session.execute(
3649 delete(table).where(
3650 table.dag_id == self.dag_id,
3651 table.task_id == self.task_id,
3652 table.run_id == self.run_id,
3653 table.map_index == self.map_index,
3654 )
3655 )
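A minimal hedged sketch of calling this cleanup, assuming ``ti`` and an open ``session``; rows in each listed table matching this TI's (dag_id, task_id, run_id, map_index) are deleted:

ti.clear_db_references(session=session)
session.commit()  # assumption: the caller owns the transaction and wants to persist the deletes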
3658def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
3659 """Given two operators, find their innermost common mapped task group."""
3660 if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
3661 return None
3662 parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()}
3663 common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids)
3664 return next(common_groups, None)
3667def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
3668 """Whether given operator is *further* mapped inside a task group."""
3669 if isinstance(operator, MappedOperator):
3670 return True
3671 task_group = operator.task_group
3672 while task_group is not None and task_group.group_id != container.group_id:
3673 if isinstance(task_group, MappedTaskGroup):
3674 return True
3675 task_group = task_group.parent_group
3676 return False
3679# State of the task instance.
3680# Stores the task instance key together with its state.
3681TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState]
3684class SimpleTaskInstance:
3685 """
3686 Simplified Task Instance.
3688 Used to send data between processes via Queues.
3689 """
3691 def __init__(
3692 self,
3693 dag_id: str,
3694 task_id: str,
3695 run_id: str,
3696 start_date: datetime | None,
3697 end_date: datetime | None,
3698 try_number: int,
3699 map_index: int,
3700 state: str,
3701 executor: str | None,
3702 executor_config: Any,
3703 pool: str,
3704 queue: str,
3705 key: TaskInstanceKey,
3706 run_as_user: str | None = None,
3707 priority_weight: int | None = None,
3708 ):
3709 self.dag_id = dag_id
3710 self.task_id = task_id
3711 self.run_id = run_id
3712 self.map_index = map_index
3713 self.start_date = start_date
3714 self.end_date = end_date
3715 self.try_number = try_number
3716 self.state = state
3717 self.executor = executor
3718 self.executor_config = executor_config
3719 self.run_as_user = run_as_user
3720 self.pool = pool
3721 self.priority_weight = priority_weight
3722 self.queue = queue
3723 self.key = key
3725 def __eq__(self, other):
3726 if isinstance(other, self.__class__):
3727 return self.__dict__ == other.__dict__
3728 return NotImplemented
3730 def as_dict(self):
3731 warnings.warn(
3732 "This method is deprecated. Use BaseSerialization.serialize.",
3733 RemovedInAirflow3Warning,
3734 stacklevel=2,
3735 )
3736 new_dict = dict(self.__dict__)
3737 for key in new_dict:
3738 if key in ["start_date", "end_date"]:
3739 val = new_dict[key]
3740 if not val or isinstance(val, str):
3741 continue
3742 new_dict.update({key: val.isoformat()})
3743 return new_dict
3745 @classmethod
3746 def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance:
3747 return cls(
3748 dag_id=ti.dag_id,
3749 task_id=ti.task_id,
3750 run_id=ti.run_id,
3751 map_index=ti.map_index,
3752 start_date=ti.start_date,
3753 end_date=ti.end_date,
3754 try_number=ti.try_number,
3755 state=ti.state,
3756 executor=ti.executor,
3757 executor_config=ti.executor_config,
3758 pool=ti.pool,
3759 queue=ti.queue,
3760 key=ti.key,
3761 run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None,
3762 priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None,
3763 )
3765 @classmethod
3766 def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance:
3767 warnings.warn(
3768 "This method is deprecated. Use BaseSerialization.deserialize.",
3769 RemovedInAirflow3Warning,
3770 stacklevel=2,
3771 )
3772 ti_key = TaskInstanceKey(*obj_dict.pop("key"))
3773 start_date = None
3774 end_date = None
3775 start_date_str: str | None = obj_dict.pop("start_date")
3776 end_date_str: str | None = obj_dict.pop("end_date")
3777 if start_date_str:
3778 start_date = timezone.parse(start_date_str)
3779 if end_date_str:
3780 end_date = timezone.parse(end_date_str)
3781 return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key)
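A hedged usage sketch (``ti`` is assumed to be an existing TaskInstance): the lightweight copy carries only plain attributes, and equality compares them attribute by attribute:

simple_a = SimpleTaskInstance.from_ti(ti)
simple_b = SimpleTaskInstance.from_ti(ti)
assert simple_a == simple_b     # __eq__ compares __dict__
assert simple_a.key == ti.key   # the TaskInstanceKey survives the conversion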
3784class TaskInstanceNote(TaskInstanceDependencies):
3785 """For storage of arbitrary notes concerning the task instance."""
3787 __tablename__ = "task_instance_note"
3789 user_id = Column(Integer, ForeignKey("ab_user.id", name="task_instance_note_user_fkey"), nullable=True)
3790 task_id = Column(StringID(), primary_key=True, nullable=False)
3791 dag_id = Column(StringID(), primary_key=True, nullable=False)
3792 run_id = Column(StringID(), primary_key=True, nullable=False)
3793 map_index = Column(Integer, primary_key=True, nullable=False)
3794 content = Column(String(1000).with_variant(Text(1000), "mysql"))
3795 created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
3796 updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
3798 task_instance = relationship("TaskInstance", back_populates="task_instance_note")
3800 __table_args__ = (
3801 PrimaryKeyConstraint("task_id", "dag_id", "run_id", "map_index", name="task_instance_note_pkey"),
3802 ForeignKeyConstraint(
3803 (dag_id, task_id, run_id, map_index),
3804 [
3805 "task_instance.dag_id",
3806 "task_instance.task_id",
3807 "task_instance.run_id",
3808 "task_instance.map_index",
3809 ],
3810 name="task_instance_note_ti_fkey",
3811 ondelete="CASCADE",
3812 ),
3813 )
3815 def __init__(self, content, user_id=None):
3816 self.content = content
3817 self.user_id = user_id
3819 def __repr__(self):
3820 prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
3821 if self.map_index != -1:
3822 prefix += f" map_index={self.map_index}"
3823 return prefix + ">"
3826STATICA_HACK = True
3827globals()["kcah_acitats"[::-1].upper()] = False
3828if STATICA_HACK: # pragma: no cover
3829 from airflow.jobs.job import Job
3831 TaskInstance.queued_by_job = relationship(Job)