1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""The entrypoint for the actual task execution process."""
19
20from __future__ import annotations
21
22import contextlib
23import contextvars
24import functools
25import os
26import sys
27import time
28from collections.abc import Callable, Iterable, Iterator, Mapping
29from contextlib import suppress
30from datetime import datetime, timedelta, timezone
31from itertools import product
32from pathlib import Path
33from typing import TYPE_CHECKING, Annotated, Any, Literal
34from urllib.parse import quote
35
36import attrs
37import lazy_object_proxy
38import structlog
39from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter
40
41from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
42from airflow.dag_processing.bundles.manager import DagBundlesManager
43from airflow.sdk._shared.observability.metrics.stats import Stats
44from airflow.sdk.api.client import get_hostname, getuser
45from airflow.sdk.api.datamodels._generated import (
46 AssetProfile,
47 DagRun,
48 PreviousTIResponse,
49 TaskInstance,
50 TaskInstanceState,
51 TIRunContext,
52)
53from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard
54from airflow.sdk.bases.xcom import BaseXCom
55from airflow.sdk.configuration import conf
56from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
57from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
58from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
59from airflow.sdk.definitions.mappedoperator import MappedOperator
60from airflow.sdk.definitions.param import process_params
61from airflow.sdk.exceptions import (
62 AirflowException,
63 AirflowInactiveAssetInInletOrOutletException,
64 AirflowRescheduleException,
65 AirflowRuntimeError,
66 AirflowTaskTimeout,
67 ErrorType,
68 TaskDeferred,
69)
70from airflow.sdk.execution_time.callback_runner import create_executable_runner
71from airflow.sdk.execution_time.comms import (
72 AssetEventDagRunReferenceResult,
73 CommsDecoder,
74 DagRunStateResult,
75 DeferTask,
76 DRCount,
77 ErrorResponse,
78 GetDagRunState,
79 GetDRCount,
80 GetPreviousDagRun,
81 GetPreviousTI,
82 GetTaskBreadcrumbs,
83 GetTaskRescheduleStartDate,
84 GetTaskStates,
85 GetTICount,
86 InactiveAssetsResult,
87 PreviousDagRunResult,
88 PreviousTIResult,
89 RescheduleTask,
90 ResendLoggingFD,
91 RetryTask,
92 SentFDs,
93 SetRenderedFields,
94 SetRenderedMapIndex,
95 SkipDownstreamTasks,
96 StartupDetails,
97 SucceedTask,
98 TaskBreadcrumbsResult,
99 TaskRescheduleStartDate,
100 TaskState,
101 TaskStatesResult,
102 TICount,
103 ToSupervisor,
104 ToTask,
105 TriggerDagRun,
106 ValidateInletsAndOutlets,
107)
108from airflow.sdk.execution_time.context import (
109 ConnectionAccessor,
110 InletEventsAccessors,
111 MacrosAccessor,
112 OutletEventAccessors,
113 TriggeringAssetEventsAccessor,
114 VariableAccessor,
115 context_get_outlet_events,
116 context_to_airflow_vars,
117 get_previous_dagrun_success,
118 set_current_context,
119)
120from airflow.sdk.execution_time.sentry import Sentry
121from airflow.sdk.execution_time.xcom import XCom
122from airflow.sdk.listener import get_listener_manager
123from airflow.sdk.timezone import coerce_datetime
124
125if TYPE_CHECKING:
126 import jinja2
127 from pendulum.datetime import DateTime
128 from structlog.typing import FilteringBoundLogger as Logger
129
130 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
131 from airflow.sdk.definitions.context import Context
132 from airflow.sdk.exceptions import DagRunTriggerException
133 from airflow.sdk.types import OutletEventAccessorsProtocol
134
135
class TaskRunnerMarker:
    """Sentinel type handed to listener hooks so they can detect that the call came from the task runner."""
138
139
140# TODO: Move this entire class into a separate file:
141# `airflow/sdk/execution_time/task_instance.py`
142# or `airflow/sdk/execution_time/runtime_ti.py`
class RuntimeTaskInstance(TaskInstance):
    """
    In-process representation of the task instance currently being executed.

    Extends the API-server ``TaskInstance`` model with the resolved ``task``
    operator, the Dag bundle it was loaded from, and execution-time helpers:
    template-context construction, XCom push/pull, and queries answered by the
    supervisor process over ``SUPERVISOR_COMMS``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    task: BaseOperator
    bundle_instance: BaseDagBundle
    _cached_template_context: Context | None = None
    """The Task Instance context. This is used to cache get_template_context."""

    _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
    """The Task Instance context from the API server, if any."""

    max_tries: int = 0
    """The maximum number of retries for the task."""

    start_date: AwareDatetime
    """Start date of the task instance."""

    end_date: AwareDatetime | None = None

    state: TaskInstanceState | None = None

    is_mapped: bool | None = None
    """True if the original task was mapped."""

    rendered_map_index: str | None = None

    sentry_integration: str = ""

    def __rich_repr__(self):
        # Compact repr for rich-based output; the task is shown by type only to
        # avoid dumping the full operator.
        yield "id", self.id
        yield "task_id", self.task_id
        yield "dag_id", self.dag_id
        yield "run_id", self.run_id
        yield "max_tries", self.max_tries
        yield "task", type(self.task)
        yield "start_date", self.start_date

    __rich_repr__.angular = True  # type: ignore[attr-defined]

    def get_template_context(self) -> Context:
        """
        Build (or return the cached) Jinja template context for this task instance.

        The context is built once and cached on ``_cached_template_context``;
        server-supplied fields (dag_run, logical_date-derived keys, etc.) are
        merged in when a ``TIRunContext`` from the API server is available.
        """
        # TODO: Move this to `airflow.sdk.execution_time.context`
        #   once we port the entire context logic from airflow/utils/context.py ?
        from airflow.sdk.plugins_manager import integrate_macros_plugins

        integrate_macros_plugins()

        dag_run_conf: dict[str, Any] | None = None
        if from_server := self._ti_context_from_server:
            dag_run_conf = from_server.dag_run.conf or dag_run_conf

        validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)

        # Cache the context object, which ensures that all calls to get_template_context
        # are operating on the same context object.
        self._cached_template_context: Context = self._cached_template_context or {
            # From the Task Execution interface
            "dag": self.task.dag,
            "inlets": self.task.inlets,
            "map_index_template": self.task.map_index_template,
            "outlets": self.task.outlets,
            "run_id": self.run_id,
            "task": self.task,
            "task_instance": self,
            "ti": self,
            "outlet_events": OutletEventAccessors(),
            "inlet_events": InletEventsAccessors(self.task.inlets),
            "macros": MacrosAccessor(),
            "params": validated_params,
            # TODO: Make this go through Public API longer term.
            # "test_mode": task_instance.test_mode,
            "var": {
                "json": VariableAccessor(deserialize_json=True),
                "value": VariableAccessor(deserialize_json=False),
            },
            "conn": ConnectionAccessor(),
        }
        if from_server:
            dag_run = from_server.dag_run
            context_from_server: Context = {
                # TODO: Assess if we need to pass these through timezone.coerce_datetime
                "dag_run": dag_run,  # type: ignore[typeddict-item]  # Removable after #46522
                "triggering_asset_events": TriggeringAssetEventsAccessor.build(
                    AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
                    for event in dag_run.consumed_asset_events
                ),
                "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
                "task_reschedule_count": from_server.task_reschedule_count or 0,
                # Lazy proxies: the supervisor round-trip only happens if the
                # template actually reads these keys.
                "prev_start_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date)
                ),
                "prev_end_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
                ),
            }
            self._cached_template_context.update(context_from_server)

            if logical_date := coerce_datetime(dag_run.logical_date):
                if TYPE_CHECKING:
                    assert isinstance(logical_date, DateTime)
                ds = logical_date.strftime("%Y-%m-%d")
                ds_nodash = ds.replace("-", "")
                ts = logical_date.isoformat()
                ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
                ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
                # logical_date and data_interval either coexist or be None together
                self._cached_template_context.update(
                    {
                        # keys that depend on logical_date
                        "logical_date": logical_date,
                        "ds": ds,
                        "ds_nodash": ds_nodash,
                        "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}",
                        "ts": ts,
                        "ts_nodash": ts_nodash,
                        "ts_nodash_with_tz": ts_nodash_with_tz,
                        # keys that depend on data_interval
                        "data_interval_end": coerce_datetime(dag_run.data_interval_end),
                        "data_interval_start": coerce_datetime(dag_run.data_interval_start),
                        "prev_data_interval_start_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start)
                        ),
                        "prev_data_interval_end_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end)
                        ),
                    }
                )

            # Backward compatibility: old servers may still send upstream_map_indexes
            upstream_map_indexes = getattr(from_server, "upstream_map_indexes", None)
            if upstream_map_indexes is not None:
                setattr(self, "_upstream_map_indexes", upstream_map_indexes)

        return self._cached_template_context

    def render_templates(
        self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
    ) -> BaseOperator:
        """
        Render templates in the operator fields.

        If the task was originally mapped, this may replace ``self.task`` with
        the unmapped, fully rendered BaseOperator. The original ``self.task``
        before replacement is returned.
        """
        if not context:
            context = self.get_template_context()
        original_task = self.task

        if TYPE_CHECKING:
            assert context

        ti = context["ti"]

        if TYPE_CHECKING:
            assert original_task
            assert self.task
            assert ti.task

        # If self.task is mapped, this call replaces self.task to point to the
        # unmapped BaseOperator created by this function! This is because the
        # MappedOperator is useless for template rendering, and we need to be
        # able to access the unmapped task instead.
        self.task.render_template_fields(context, jinja_env)
        # Record mapped-ness from the pre-render task, since self.task may have
        # been swapped for the unmapped operator above.
        self.is_mapped = original_task.is_mapped
        return original_task

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = BaseXCom.XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
        default: Any = None,
        run_id: str | None = None,
    ) -> Any:
        """
        Pull XComs either from the API server (BaseXCom) or from the custom XCOM backend if configured.

        The pull can be filtered optionally by certain criterion.

        :param key: A key for the XCom. If provided, only XComs with matching
            keys will be returned. The default key is ``'return_value'``, also
            available as constant ``XCOM_RETURN_KEY``. This key is automatically
            given to XComs returned by tasks (as opposed to being pushed
            manually).
        :param task_ids: Only XComs from tasks with matching ids will be
            pulled. If *None* (default), the task_id of the calling task is used.
        :param dag_id: If provided, only pulls XComs from this Dag. If *None*
            (default), the Dag of the calling task is used.
        :param map_indexes: If provided, only pull XComs with matching indexes.
            If *None* (default), this is inferred from the task(s) being pulled
            (see below for details).
        :param include_prior_dates: If False, only XComs from the current
            logical_date are returned. If *True*, XComs from previous dates
            are returned as well.
        :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id.
            If *None* (default), the run_id of the calling task is used.

        When pulling one single task (``task_id`` is *None* or a str) without
        specifying ``map_indexes``, the return value is a single XCom entry
        (map_indexes is set to map_index of the calling task instance).

        When pulling task is mapped the specified ``map_index`` is used, so by default
        pulling on mapped task will result in no matching XComs if the task instance
        of the method call is not mapped. Otherwise, the map_index of the calling task
        instance is used. Setting ``map_indexes`` to *None* will pull XCom as it would
        from a non mapped task.

        In either case, ``default`` (*None* if not specified) is returned if no
        matching XComs are found.

        When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
        a non-str iterable), a list of matching XComs is returned. Elements in
        the list is ordered by item ordering in ``task_id`` and ``map_index``.
        """
        if dag_id is None:
            dag_id = self.dag_id
        if run_id is None:
            run_id = self.run_id

        # Remember the caller's argument shapes so we can decide whether to
        # return a single value or a list at the end.
        single_task_requested = isinstance(task_ids, (str, type(None)))
        single_map_index_requested = isinstance(map_indexes, (int, type(None)))

        if task_ids is None:
            # default to the current task if not provided
            task_ids = [self.task_id]
        elif isinstance(task_ids, str):
            task_ids = [task_ids]

        # If map_indexes is not specified, pull xcoms from all map indexes for each task
        if not is_arg_set(map_indexes):
            xcoms: list[Any] = []
            for t_id in task_ids:
                values = XCom.get_all(
                    run_id=run_id,
                    key=key,
                    task_id=t_id,
                    dag_id=dag_id,
                    include_prior_dates=include_prior_dates,
                )

                if values is None:
                    xcoms.append(None)
                else:
                    xcoms.extend(values)
            # For single task pulling from unmapped task, return single value
            if single_task_requested and len(xcoms) == 1:
                return xcoms[0]
            return xcoms

        # Original logic when map_indexes is explicitly specified
        map_indexes_iterable: Iterable[int | None] = []
        if isinstance(map_indexes, int) or map_indexes is None:
            map_indexes_iterable = [map_indexes]
        elif isinstance(map_indexes, Iterable):
            map_indexes_iterable = map_indexes
        else:
            raise TypeError(
                f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}"
            )

        # One lookup per (task_id, map_index) pair, preserving argument order.
        xcoms = []
        for t_id, m_idx in product(task_ids, map_indexes_iterable):
            value = XCom.get_one(
                run_id=run_id,
                key=key,
                task_id=t_id,
                dag_id=dag_id,
                map_index=m_idx,
                include_prior_dates=include_prior_dates,
            )
            if value is None:
                xcoms.append(default)
            else:
                xcoms.append(value)

        if single_task_requested and single_map_index_requested:
            return xcoms[0]

        return xcoms

    def xcom_push(self, key: str, value: Any):
        """
        Make an XCom available for tasks to pull.

        :param key: Key to store the value under.
        :param value: Value to store. Only be JSON-serializable values may be used.
        """
        _xcom_push(self, key, value)

    def get_relevant_upstream_map_indexes(
        self, upstream: BaseOperator, ti_count: int | None, session: Any
    ) -> int | range | None:
        """
        Compute the relevant upstream map indexes for XCom resolution.

        :param upstream: The upstream operator
        :param ti_count: The total count of task instances for this task's expansion
        :param session: Not used (kept for API compatibility)
        :return: None (use entire value), int (single index), or range (subset of indexes)
        """
        from airflow.sdk.execution_time.task_mapping import get_relevant_map_indexes, get_ti_count_for_task

        # Unmapped (or unknown) map index: the whole upstream value is relevant.
        map_index = self.map_index
        if map_index is None or map_index < 0:
            return None

        # If ti_count not provided, we need to query it
        if ti_count is None:
            ti_count = get_ti_count_for_task(self.task_id, self.dag_id, self.run_id)

        if not ti_count:
            return None

        return get_relevant_map_indexes(
            task=self.task,
            run_id=self.run_id,
            map_index=map_index,
            ti_count=ti_count,
            relative=upstream,
            dag_id=self.dag_id,
        )

    def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
        """Get the first reschedule date for the task instance if found, none otherwise."""
        if context.get("task_reschedule_count", 0) == 0:
            # If the task has not been rescheduled, there is no need to ask the supervisor
            return None

        # The first try of the current retry block; reschedules keep the same
        # try window, so we look up the start date of that first try.
        max_tries: int = self.max_tries
        retries: int = self.task.retries or 0
        first_try_number = max_tries - retries + 1

        log = structlog.get_logger(logger_name="task")

        log.debug("Requesting first reschedule date from supervisor")

        response = SUPERVISOR_COMMS.send(
            msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskRescheduleStartDate)

        return response.start_date

    def get_previous_dagrun(self, state: str | None = None) -> DagRun | None:
        """Return the previous Dag run before the given logical date, optionally filtered by state."""
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")

        log.debug("Getting previous Dag run", dag_run=dag_run)

        if dag_run is None:
            return None

        # Without a logical date there is no "previous" ordering to query against.
        if dag_run.logical_date is None:
            return None

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state)
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousDagRunResult)

        return response.dag_run

    def get_previous_ti(
        self,
        state: TaskInstanceState | None = None,
        logical_date: AwareDatetime | None = None,
        map_index: int = -1,
    ) -> PreviousTIResponse | None:
        """
        Return the previous task instance matching the given criteria.

        :param state: Filter by TaskInstance state
        :param logical_date: Filter by logical date (returns TI before this date)
        :param map_index: Filter by map_index (defaults to -1 for non-mapped tasks)
        :return: Previous task instance or None if not found
        """
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")
        log.debug("Getting previous task instance", task_id=self.task_id, state=state)

        # Use current dag run's logical_date if not provided
        effective_logical_date = logical_date
        if effective_logical_date is None and dag_run and dag_run.logical_date:
            effective_logical_date = dag_run.logical_date

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousTI(
                dag_id=self.dag_id,
                task_id=self.task_id,
                logical_date=effective_logical_date,
                map_index=map_index,
                state=state,
            )
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousTIResult)

        return response.task_instance

    @staticmethod
    def get_ti_count(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of task instances matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTICount(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TICount)

        return response.count

    @staticmethod
    def get_task_states(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> dict[str, Any]:
        """Return the task states matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTaskStates(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskStatesResult)

        return response.task_states

    @staticmethod
    def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]:
        """Return task breadcrumbs for the given dag run."""
        response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id))
        if TYPE_CHECKING:
            assert isinstance(response, TaskBreadcrumbsResult)
        return response.breadcrumbs

    @staticmethod
    def get_dr_count(
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of Dag runs matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetDRCount(
                dag_id=dag_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, DRCount)

        return response.count

    @staticmethod
    def get_dagrun_state(dag_id: str, run_id: str) -> str:
        """Return the state of the Dag run with the given Run ID."""
        response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id))

        if TYPE_CHECKING:
            assert isinstance(response, DagRunStateResult)

        return response.state

    @property
    def log_url(self) -> str:
        """URL for this task instance (and try) in the UI, built from the configured API ``base_url``."""
        run_id = quote(self.run_id)
        base_url = conf.get("api", "base_url", fallback="http://localhost:8080/")
        # Only include the mapped segment / try_number query when they are meaningful.
        map_index_value = self.map_index
        map_index = (
            f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else ""
        )
        try_number_value = self.try_number
        try_number = (
            f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else ""
        )
        _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}"
        return _log_uri

    @property
    def mark_success_url(self) -> str:
        """URL to mark TI success."""
        return self.log_url
668
669
def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
    """Store an XCom via ``XCom.set``, which routes to the custom XCom backend when one is configured.

    Kept private so SDK consumers cannot set ``mapped_length`` manually.
    """
    set_kwargs: dict[str, Any] = {
        "key": key,
        "value": value,
        "dag_id": ti.dag_id,
        "task_id": ti.task_id,
        "run_id": ti.run_id,
        "map_index": ti.map_index,
        "_mapped_length": mapped_length,
    }
    XCom.set(**set_kwargs)
684
685
def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None:
    """Write an XCom row straight to the metadata DB, skipping any custom XCom backend."""
    dag_id, task_id, run_id, map_index = ti.dag_id, ti.task_id, ti.run_id, ti.map_index
    XCom._set_xcom_in_db(
        key=key,
        value=value,
        dag_id=dag_id,
        task_id=task_id,
        run_id=run_id,
        map_index=map_index,
    )
696
697
def _maybe_reschedule_startup_failure(
    *,
    ti_context: TIRunContext,
    log: Logger,
) -> None:
    """
    Attempt to reschedule the task when a startup failure occurs.

    Rescheduling here does not consume a retry. Once the configured limit is
    reached this function logs and returns, and the caller should fail the task.
    """
    # NOTE: "retires" is the actual config key name, so it must be spelled this way.
    max_reschedules = conf.getint("workers", "missing_dag_retires", fallback=3)
    delay_seconds = conf.getint("workers", "missing_dag_retry_delay", fallback=60)

    attempts_so_far = int(getattr(ti_context, "task_reschedule_count", 0) or 0)
    if max_reschedules > 0 and attempts_so_far < max_reschedules:
        retry_at = datetime.now(tz=timezone.utc) + timedelta(seconds=delay_seconds)
        raise AirflowRescheduleException(reschedule_date=retry_at)

    log.error(
        "Startup reschedule limit exceeded",
        reschedule_count=attempts_so_far,
        max_reschedules=max_reschedules,
    )
723
724
def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
    """
    Materialize the Dag bundle named in ``what``, parse the Dag file, and build the RuntimeTaskInstance.

    Exits the process (after optionally requesting a non-retry reschedule via
    ``_maybe_reschedule_startup_failure``) when the Dag or the task cannot be
    found in the parsed bundle.

    :param what: Startup message from the supervisor (bundle info, TI identity, server context)
    :param log: Logger for startup diagnostics
    """
    # TODO: Task-SDK:
    # Using BundleDagBag here is about 98% wrong, but it'll do for now
    from airflow.dag_processing.dagbag import BundleDagBag

    bundle_info = what.bundle_info
    bundle_instance = DagBundlesManager().get_bundle(
        name=bundle_info.name,
        version=bundle_info.version,
    )
    bundle_instance.initialize()
    # Confirm the (possibly impersonated) user can actually read the bundle path.
    _verify_bundle_access(bundle_instance, log)

    dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
    bag = BundleDagBag(
        dag_folder=dag_absolute_path,
        safe_mode=False,
        load_op_links=False,
        bundle_path=bundle_instance.path,
        bundle_name=bundle_info.name,
    )
    if TYPE_CHECKING:
        assert what.ti.dag_id

    try:
        dag = bag.dags[what.ti.dag_id]
    except KeyError:
        log.error(
            "Dag not found during start up", dag_id=what.ti.dag_id, bundle=bundle_info, path=what.dag_rel_path
        )
        # May raise AirflowRescheduleException instead of falling through to exit.
        _maybe_reschedule_startup_failure(ti_context=what.ti_context, log=log)
        sys.exit(1)

    # install_loader()

    try:
        task = dag.task_dict[what.ti.task_id]
    except KeyError:
        log.error(
            "Task not found in Dag during start up",
            dag_id=dag.dag_id,
            task_id=what.ti.task_id,
            bundle=bundle_info,
            path=what.dag_rel_path,
        )
        _maybe_reschedule_startup_failure(ti_context=what.ti_context, log=log)
        sys.exit(1)

    if not isinstance(task, (BaseOperator, MappedOperator)):
        raise TypeError(
            f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
        )

    # model_construct skips validation; the server-provided TI payload is trusted here.
    return RuntimeTaskInstance.model_construct(
        **what.ti.model_dump(exclude_unset=True),
        task=task,
        bundle_instance=bundle_instance,
        _ti_context_from_server=what.ti_context,
        max_tries=what.ti_context.max_tries,
        start_date=what.start_date,
        state=TaskInstanceState.RUNNING,
        sentry_integration=what.sentry_integration,
    )
788
789
790# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution,
791# to send requests back to the supervisor process.
792#
793# Why it needs to be a global:
794# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests
795# to the parent process during task execution.
796# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the
797# deeply nested execution stack.
798# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily
799# accessible wherever needed during task execution without modifying every layer of the call stack.
# Declaration only (no value assigned here): the name is bound during process
# initialization (see `reinit_supervisor_comms` usage in `startup`); reading it
# before that raises NameError.
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
801
802
803# State machine!
804# 1. Start up (receive details from supervisor)
805# 2. Execution (run task code, possibly send requests)
806# 3. Shutdown and report status
807
808
def _verify_bundle_access(bundle_instance: BaseDagBundle, log: Logger) -> None:
    """
    Verify bundle is accessible by the current user.

    This is called after user impersonation (if any) to ensure the bundle
    is actually accessible. Uses os.access() which works with any permission
    scheme (standard Unix permissions, ACLs, SELinux, etc.).

    :param bundle_instance: The bundle instance to check
    :param log: Logger instance (currently unused; kept for interface stability)
    :raises AirflowException: if bundle is not accessible
    """
    # Deliberately use getpass.getuser here (the OS-level user after any
    # impersonation) rather than the module-level ``getuser`` imported from
    # the API client. -- NOTE(review): presumed intent; confirm the two differ.
    from getpass import getuser

    bundle_path = bundle_instance.path

    if not bundle_path.exists():
        # Already handled by initialize() with a warning
        return

    # Check read permission (and execute for directories to list contents)
    access_mode = os.R_OK
    if bundle_path.is_dir():
        access_mode |= os.X_OK

    if not os.access(bundle_path, access_mode):
        # AirflowException is imported at module level from airflow.sdk.exceptions;
        # the previous local re-import was redundant and has been removed.
        raise AirflowException(
            f"Bundle '{bundle_instance.name}' path '{bundle_path}' is not accessible "
            f"by user '{getuser()}'. When using run_as_user, ensure bundle directories "
            f"are readable by the impersonated user. "
            f"See: https://airflow.apache.org/docs/apache-airflow/stable/administration-and-deployment/dag-bundles.html"
        )
843
844
def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
    """
    Receive startup details, parse the Dag, and prepare the task for execution.

    Handles both the normal entrypoint (StartupDetails arrives over
    ``SUPERVISOR_COMMS``) and the re-executed / impersonated entrypoint (details
    come from the ``_AIRFLOW__STARTUP_MSG`` environment variable). When
    ``run_as_user`` impersonation is required, this replaces the current process
    via ``os.execvp`` and does not return normally.

    :return: The runtime task instance, its template context, and the task logger.
    """
    # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent
    # in response to us sending a request.
    log = structlog.get_logger(logger_name="task")

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and (
        msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG")
    ):
        # Clear any Kerberos replace cache if there is one, so new process can't reuse it.
        os.environ.pop("KRB5CCNAME", None)
        # entrypoint of re-exec process

        msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson)
        reinit_supervisor_comms()

        # We delay this message until _after_ we've got the logging re-configured, otherwise it will show up
        # on stdout
        log.debug("Using serialized startup message from environment", msg=msg)
    else:
        # normal entry point
        msg = SUPERVISOR_COMMS._get_response()  # type: ignore[assignment]

        if not isinstance(msg, StartupDetails):
            raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

    # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
    os_type = sys.platform
    if os_type == "darwin":
        log.debug("Mac OS detected, skipping setproctitle")
    else:
        from setproctitle import setproctitle

        setproctitle(f"airflow worker -- {msg.ti.id}")

    try:
        get_listener_manager().hook.on_starting(component=TaskRunnerMarker())
    except Exception:
        # Listener failures must not prevent the task from starting.
        log.exception("error calling listener")

    with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
        ti = parse(msg, log)
    log.debug("Dag file parsed", file=msg.dag_rel_path)

    run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
        "core", "default_impersonation", fallback=None
    )

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
        # enters here for re-exec process
        os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
        # store startup message in environment for re-exec process
        os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
        # Let the supervisor socket survive the exec so the child can keep talking to it.
        os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)

        # Import main directly from the module instead of re-executing the file.
        # This ensures that when other parts modules import
        # airflow.sdk.execution_time.task_runner, they get the same module instance
        # with the properly initialized SUPERVISOR_COMMS global variable.
        # If we re-executed the module with `python -m`, it would load as __main__ and future
        # imports would get a fresh copy without the initialized globals.
        rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
        cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
        log.info(
            "Running command",
            command=cmd,
        )
        os.execvp("sudo", cmd)

        # ideally, we should never reach here, but if we do, we should return None, None, None
        return None, None, None

    return ti, ti.get_template_context(), log
917
918
def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float:
    """
    Return a serializable representation of the templated field.

    If ``template_field`` contains a class or instance that requires recursive
    templating, store them as strings. Otherwise simply return the field as-is.

    Uses the sdk secrets masker to redact secrets in the serialized output.

    :param template_field: rendered value of a single templated field.
    :param name: field name, forwarded to the secrets masker as the redaction hint.
    """
    import json

    from airflow.sdk._shared.secrets_masker import redact

    def is_jsonable(x):
        """Return True when ``json.dumps`` can handle ``x`` directly."""
        try:
            json.dumps(x)
        except (TypeError, OverflowError):
            return False
        else:
            return True

    def translate_tuples_to_lists(obj: Any):
        """Recursively convert tuples to lists (tuples are not JSON-representable)."""
        # Tuples and lists get identical treatment: both become plain lists.
        if isinstance(obj, (tuple, list)):
            return [translate_tuples_to_lists(item) for item in obj]
        if isinstance(obj, dict):
            return {key: translate_tuples_to_lists(value) for key, value in obj.items()}
        return obj

    def sort_dict_recursively(obj: Any) -> Any:
        """Recursively sort dictionaries to ensure consistent ordering."""
        if isinstance(obj, dict):
            return {k: sort_dict_recursively(v) for k, v in sorted(obj.items())}
        if isinstance(obj, list):
            return [sort_dict_recursively(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(sort_dict_recursively(item) for item in obj)
        return obj

    def _fallback_serialization(obj):
        """Serialize objects with to_dict() method (eg: k8s objects) for json.dumps() default parameter."""
        if hasattr(obj, "to_dict"):
            return obj.to_dict()
        raise TypeError(f"cannot serialize {obj}")

    max_length = conf.getint("core", "max_templated_field_length")

    def truncate_and_redact(serialized):
        """Redact secrets then truncate an over-long value (shared by both return paths)."""
        rendered = redact(serialized, name)
        # The 79 offset presumably matches the length of the notice prefix below
        # so the total stays near max_length — TODO confirm.
        return (
            "Truncated. You can change this behaviour in [core]max_templated_field_length. "
            f"{rendered[: max_length - 79]!r}... "
        )

    if not is_jsonable(template_field):
        try:
            serialized = template_field.serialize()
        except AttributeError:
            # check if these objects can be converted to JSON serializable types
            try:
                serialized = json.dumps(template_field, default=_fallback_serialization)
            except (TypeError, ValueError):
                # fall back to string representation if not
                serialized = str(template_field)
        if len(serialized) > max_length:
            return truncate_and_redact(serialized)
        return serialized
    if not template_field and not isinstance(template_field, tuple):
        # Avoid unnecessary serialization steps for empty fields unless they are tuples
        # and need to be converted to lists
        return template_field
    template_field = translate_tuples_to_lists(template_field)
    # Sort dictionaries recursively to ensure consistent string representation
    # This prevents hash inconsistencies when dict ordering varies
    if isinstance(template_field, dict):
        template_field = sort_dict_recursively(template_field)
    serialized = str(template_field)
    if len(serialized) > max_length:
        return truncate_and_redact(serialized)
    return template_field
1002
1003
def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]:
    """Serialize and redact every templated field of *task*, keyed by field name."""
    from airflow.sdk._shared.secrets_masker import redact

    # Redact secrets in the task process itself before sending to the API server.
    # This ensures that secrets registered via mask_secret() on workers / dag
    # processor are properly masked on the UI.
    return {
        field: redact(_serialize_template_field(getattr(task, field), field), field)
        for field in task.template_fields
    }  # type: ignore[return-value] # Convince mypy that this is OK since we pass JsonValue to redact, so it will return the same
1017
1018
def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]:
    """Translate lineage objects into ``AssetProfile`` items, skipping unknown kinds."""
    # Lineage can have other types of objects besides assets; the checks below run
    # in the same order as before so subclass matching behaves identically.
    builders = (
        (Asset, lambda o: AssetProfile(name=o.name, uri=o.uri, type=Asset.__name__)),
        (AssetNameRef, lambda o: AssetProfile(name=o.name, type=AssetNameRef.__name__)),
        (AssetUriRef, lambda o: AssetProfile(uri=o.uri, type=AssetUriRef.__name__)),
        (AssetAlias, lambda o: AssetProfile(name=o.name, type=AssetAlias.__name__)),
    )
    for entry in lineage_objects or ():
        for kind, build in builders:
            if isinstance(entry, kind):
                yield build(entry)
                break
1030
1031
def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]:
    """Yield JSON-ready payloads for everything the user recorded in the accessors."""
    if TYPE_CHECKING:
        assert isinstance(events, OutletEventAccessors)
    # Everything recorded is collected verbatim; further filtering happens in the
    # API server, not here.
    for unique_key, accessor in events._dict.items():
        if isinstance(unique_key, AssetUniqueKey):
            yield {"dest_asset_key": attrs.asdict(unique_key), "extra": accessor.extra}
        yield from (attrs.asdict(alias_event) for alias_event in accessor.asset_alias_events)
1042
1043
def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None:
    """
    Prepare the task for execution: render templates, report rendered fields,
    validate inlets/outlets and fire the pre-run listener.

    Returns a message for the supervisor when preparation decides execution must
    stop early, or ``None`` when the task may proceed. (The code shown always
    returns ``None``; failures raise instead.)
    """
    ti.hostname = get_hostname()
    ti.task = ti.task.prepare_for_execution()
    # Since context is now cached, and calling `ti.get_template_context` will return the same dict, we want to
    # update the value of the task that is sent from there
    context["task"] = ti.task

    jinja_env = ti.task.dag.get_template_env()
    ti.render_templates(context=context, jinja_env=jinja_env)

    if rendered_fields := _serialize_rendered_fields(ti.task):
        # so that we do not call the API unnecessarily
        SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields))

    # Try to render map_index_template early with available context (will be re-rendered after execution)
    # This provides a partial label during task execution for templates using pre-execution context
    # If rendering fails here, we suppress the error since it will be re-rendered after execution
    try:
        if rendered_map_index := _render_map_index(context, ti=ti, log=log):
            ti.rendered_map_index = rendered_map_index
            log.debug("Sending early rendered map index", length=len(rendered_map_index))
            SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index))
    except Exception:
        log.debug(
            "Early rendering of map_index_template failed, will retry after task execution", exc_info=True
        )

    # Raises if any inlet/outlet references an inactive asset.
    _validate_task_inlets_and_outlets(ti=ti, log=log)

    try:
        # TODO: Call pre execute etc.
        get_listener_manager().hook.on_task_instance_running(
            previous_state=TaskInstanceState.QUEUED, task_instance=ti
        )
    except Exception:
        # Listener failures are logged but never allowed to fail the task.
        log.exception("error calling listener")

    # No error, carry on and execute the task
    return None
1083
1084
def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
    """Ask the supervisor to validate this TI's inlets/outlets; raise if any asset is inactive."""
    task = ti.task
    # Nothing to validate when the task declares neither inlets nor outlets.
    if not (task.inlets or task.outlets):
        return

    response = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id))
    if TYPE_CHECKING:
        assert isinstance(response, InactiveAssetsResult)
    inactive = response.inactive_assets
    if inactive:
        raise AirflowInactiveAssetInInletOrOutletException(
            inactive_asset_keys=[AssetUniqueKey.from_profile(profile) for profile in inactive]
        )
1098
1099
def _defer_task(
    defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """Build the DeferTask message that hands this task over to the triggerer."""
    # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?
    log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)

    classpath, raw_trigger_kwargs = defer.trigger.serialize()

    # Currently, only task-associated BaseTrigger instances may have a non-None queue,
    # and only when triggerer.queues_enabled conf is True.
    trigger_queue: str | None = None
    queues_enabled = conf.getboolean("triggerer", "queues_enabled", fallback=False)
    if queues_enabled and getattr(defer.trigger, "supports_triggerer_queue", True):
        trigger_queue = ti.task.queue

    from airflow.sdk.serde import serialize as serde_serialize

    serialized_trigger_kwargs = serde_serialize(raw_trigger_kwargs)
    serialized_next_kwargs = serde_serialize(defer.kwargs or {})

    if TYPE_CHECKING:
        assert isinstance(serialized_next_kwargs, dict)
        assert isinstance(serialized_trigger_kwargs, dict)

    return (
        DeferTask(
            classpath=classpath,
            trigger_kwargs=serialized_trigger_kwargs,
            trigger_timeout=defer.timeout,
            queue=trigger_queue,
            next_method=defer.method_name,
            next_kwargs=serialized_next_kwargs,
        ),
        TaskInstanceState.DEFERRED,
    )
1135
1136
@Sentry.enrich_errors
def run(
    ti: RuntimeTaskInstance,
    context: Context,
    log: Logger,
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
    """
    Run the task in this process.

    Routes every task outcome (success, skip, reschedule, deferral, trigger-dag-run,
    failure, termination) to a terminal state plus a supervisor message, and returns
    ``(state, msg, error)``. The message, if any, is sent in the ``finally`` block.
    """
    import signal

    from airflow.sdk.exceptions import (
        AirflowFailException,
        AirflowRescheduleException,
        AirflowSensorTimeout,
        AirflowSkipException,
        AirflowTaskTerminated,
        DagRunTriggerException,
        DownstreamTasksSkipped,
        TaskDeferred,
    )

    if TYPE_CHECKING:
        assert ti.task is not None
        assert isinstance(ti.task, BaseOperator)

    parent_pid = os.getpid()

    def _on_term(signum, frame):
        # Only the original task process reacts to SIGTERM; forked children
        # (pid differs from parent_pid) ignore it so on_kill runs once.
        pid = os.getpid()
        if pid != parent_pid:
            return

        ti.task.on_kill()

    signal.signal(signal.SIGTERM, _on_term)

    msg: ToSupervisor | None = None
    state: TaskInstanceState
    error: BaseException | None = None

    try:
        # First, clear the xcom data sent from server
        if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear):
            for x in keys_to_delete:
                log.debug("Clearing XCom with key", key=x)
                XCom.delete(
                    key=x,
                    dag_id=ti.dag_id,
                    task_id=ti.task_id,
                    run_id=ti.run_id,
                    map_index=ti.map_index,
                )

        with set_current_context(context):
            # This is the earliest that we can render templates -- as if it excepts for any reason we need to
            # catch it and handle it like a normal task failure
            if early_exit := _prepare(ti, log, context):
                msg = early_exit
                ti.state = state = TaskInstanceState.FAILED
                return state, msg, error

            try:
                result = _execute_task(context=context, ti=ti, log=log)
            except Exception:
                import jinja2

                # If the task failed, swallow rendering error so it doesn't mask the main error.
                with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
                    previous_rendered_map_index = ti.rendered_map_index
                    ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                    # Send update only if value changed (e.g., user set context variables during execution)
                    if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                        SUPERVISOR_COMMS.send(
                            msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)
                        )
                raise
            else:  # If the task succeeded, render normally to let rendering error bubble up.
                previous_rendered_map_index = ti.rendered_map_index
                ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                # Send update only if value changed (e.g., user set context variables during execution)
                if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                    SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index))

            _push_xcom_if_needed(result, ti, log)

            msg, state = _handle_current_task_success(context, ti)
    except DownstreamTasksSkipped as skip:
        log.info("Skipping downstream tasks.")
        tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks]
        SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip))
        msg, state = _handle_current_task_success(context, ti)
    except DagRunTriggerException as drte:
        msg, state = _handle_trigger_dag_run(drte, context, ti, log)
    except TaskDeferred as defer:
        msg, state = _defer_task(defer, ti, log)
    except AirflowSkipException as e:
        if e.args:
            log.info("Skipping task.", reason=e.args[0])
        msg = TaskState(
            state=TaskInstanceState.SKIPPED,
            end_date=datetime.now(tz=timezone.utc),
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.SKIPPED
    except AirflowRescheduleException as reschedule:
        log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
        msg = RescheduleTask(
            reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
        )
        state = TaskInstanceState.UP_FOR_RESCHEDULE
    except (AirflowFailException, AirflowSensorTimeout) as e:
        # If AirflowFailException is raised, task should not retry.
        # If a sensor in reschedule mode reaches timeout, task should not retry.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
        # We should allow retries if the task has defined it.
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    except AirflowTaskTerminated as e:
        # External state updates are already handled with `ti_heartbeat` and will be
        # updated already be another UI API. So, these exceptions should ideally never be thrown.
        # If these are thrown, we should mark the TI state as failed.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except SystemExit as e:
        # SystemExit needs to be retried if they are eligible.
        log.error("Task exited", exit_code=e.code)
        msg, state = _handle_current_task_failed(ti)
        error = e
    except BaseException as e:
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    finally:
        # Whatever the outcome, the resulting message (if any) is reported
        # to the supervisor exactly once, here.
        if msg:
            SUPERVISOR_COMMS.send(msg=msg)

    # Return the message to make unit tests easier too
    ti.state = state
    return state, msg, error
1292
1293
def _handle_current_task_success(
    context: Context,
    ti: RuntimeTaskInstance,
) -> tuple[SucceedTask, TaskInstanceState]:
    """Record success metrics and build the SucceedTask message for the supervisor."""
    ti.end_date = datetime.now(tz=timezone.utc)

    # Record operator and task instance success metrics.
    operator_name = ti.task.__class__.__name__
    tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
    Stats.incr(f"operator_successes_{operator_name}", tags=tags)
    # Same metric with tagging
    Stats.incr("operator_successes", tags={**tags, "operator": operator_name})
    Stats.incr("ti_successes", tags=tags)

    msg = SucceedTask(
        end_date=ti.end_date,
        task_outlets=list(_build_asset_profiles(ti.task.outlets)),
        outlet_events=list(_serialize_outlet_events(context["outlet_events"])),
        rendered_map_index=ti.rendered_map_index,
    )
    return msg, TaskInstanceState.SUCCESS
1319
1320
def _handle_current_task_failed(
    ti: RuntimeTaskInstance,
) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
    """Record failure metrics and decide between a retry and a terminal failure."""
    now = datetime.now(tz=timezone.utc)
    ti.end_date = now

    # Record operator and task instance failed metrics.
    operator_name = ti.task.__class__.__name__
    tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
    Stats.incr(f"operator_failures_{operator_name}", tags=tags)
    # Same metric with tagging
    Stats.incr("operator_failures", tags={**tags, "operator": operator_name})
    Stats.incr("ti_failures", tags=tags)

    # The server decides retry eligibility; without server context, fail terminally.
    server_ctx = ti._ti_context_from_server
    if server_ctx and server_ctx.should_retry:
        return RetryTask(end_date=now), TaskInstanceState.UP_FOR_RETRY

    failed_msg = TaskState(
        state=TaskInstanceState.FAILED, end_date=now, rendered_map_index=ti.rendered_map_index
    )
    return failed_msg, TaskInstanceState.FAILED
1344
1345
def _handle_trigger_dag_run(
    drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """
    Handle exception from TriggerDagRunOperator.

    Asks the supervisor to create the target dag run, then — depending on the
    operator's settings — skips/fails on duplicates, defers, polls until the run
    reaches an allowed or failed state, or returns success immediately.
    Returns the supervisor message and the resulting task state.
    """
    log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
    comms_msg = SUPERVISOR_COMMS.send(
        TriggerDagRun(
            dag_id=drte.trigger_dag_id,
            run_id=drte.dag_run_id,
            logical_date=drte.logical_date,
            conf=drte.conf,
            reset_dag_run=drte.reset_dag_run,
        ),
    )

    # Duplicate run: either skip (if configured) or fail the task immediately.
    if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
        if drte.skip_when_already_exists:
            log.info(
                "Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
                dag_id=drte.trigger_dag_id,
            )
            msg = TaskState(
                state=TaskInstanceState.SKIPPED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.SKIPPED
        else:
            log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
            msg = TaskState(
                state=TaskInstanceState.FAILED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.FAILED

        return msg, state

    log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)

    # Store the run id from the dag run (either created or found above) to
    # be used when creating the extra link on the webserver.
    ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)

    if drte.wait_for_completion:
        if drte.deferrable:
            # Hand the wait over to the triggerer instead of blocking this worker.
            from airflow.providers.standard.triggers.external_task import DagStateTrigger

            defer = TaskDeferred(
                trigger=DagStateTrigger(
                    dag_id=drte.trigger_dag_id,
                    states=drte.allowed_states + drte.failed_states,  # type: ignore[arg-type]
                    # Don't filter by execution_dates when run_ids is provided.
                    # run_id uniquely identifies a DAG run, and when reset_dag_run=True,
                    # drte.logical_date might be a newly calculated value that doesn't match
                    # the persisted logical_date in the database, causing the trigger to never find the run.
                    execution_dates=None,
                    run_ids=[drte.dag_run_id],
                    poll_interval=drte.poke_interval,
                ),
                method_name="execute_complete",
            )
            return _defer_task(defer, ti, log)
        # Synchronous wait: poll the dag run state every poke_interval seconds.
        while True:
            log.info(
                "Waiting for dag run to complete execution in allowed state.",
                dag_id=drte.trigger_dag_id,
                run_id=drte.dag_run_id,
                allowed_state=drte.allowed_states,
            )
            time.sleep(drte.poke_interval)

            comms_msg = SUPERVISOR_COMMS.send(
                GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
            )
            if TYPE_CHECKING:
                assert isinstance(comms_msg, DagRunStateResult)
            if comms_msg.state in drte.failed_states:
                log.error(
                    "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                msg = TaskState(
                    state=TaskInstanceState.FAILED,
                    end_date=datetime.now(tz=timezone.utc),
                    rendered_map_index=ti.rendered_map_index,
                )
                state = TaskInstanceState.FAILED
                return msg, state
            if comms_msg.state in drte.allowed_states:
                log.info(
                    "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                break
            log.debug(
                "DagRun not yet in allowed or failed state.",
                dag_id=drte.trigger_dag_id,
                state=comms_msg.state,
            )
    else:
        # Fire-and-forget mode: wait_for_completion=False
        if drte.deferrable:
            log.info(
                "Ignoring deferrable=True because wait_for_completion=False. "
                "Task will complete immediately without waiting for the triggered DAG run.",
                trigger_dag_id=drte.trigger_dag_id,
            )

    return _handle_current_task_success(context, ti)
1454
1455
def _run_task_state_change_callbacks(
    task: BaseOperator,
    kind: Literal[
        "on_execute_callback",
        "on_failure_callback",
        "on_success_callback",
        "on_retry_callback",
        "on_skipped_callback",
    ],
    context: Context,
    log: Logger,
) -> None:
    """Invoke each callback registered under *kind*; failures are logged, never raised."""
    callbacks: Iterable[Callable[[Context], None]] = getattr(task, kind)
    for index, cb in enumerate(callbacks):
        try:
            runner = create_executable_runner(cb, context_get_outlet_events(context), logger=log)
            runner.run(context)
        except Exception:
            log.exception("Failed to run task callback", kind=kind, index=index, callback=cb)
1474
1475
def _send_error_email_notification(
    task: BaseOperator | MappedOperator,
    ti: RuntimeTaskInstance,
    context: Context,
    error: BaseException | str | None,
    log: Logger,
) -> None:
    """
    Send email notification for task errors using SmtpNotifier.

    No-op (after logging, where applicable) when the SMTP provider is not
    installed or the task has no ``email`` recipients configured. Subject and
    HTML body come from the ``[email]`` template files when configured,
    otherwise from built-in defaults. SMTP send failures are logged, not raised.
    """
    try:
        from airflow.providers.smtp.notifications.smtp import SmtpNotifier
    except ImportError:
        log.error(
            "Failed to send task failure or retry email notification: "
            "`apache-airflow-providers-smtp` is not installed. "
            "Install this provider to enable email notifications."
        )
        return

    # Nothing to do when the task has no recipients configured. (This is the
    # single recipient check; the former second check just before sending was
    # redundant since task.email is not mutated in between.)
    if not task.email:
        return

    subject_template_file = conf.get("email", "subject_template", fallback=None)

    # Read the subject template file if configured, else fall back to default.
    if subject_template_file and Path(subject_template_file).exists():
        subject = Path(subject_template_file).read_text()
    else:
        subject = "Airflow alert: {{ti}}"

    html_content_template_file = conf.get("email", "html_content_template", fallback=None)

    # Read the body template file if configured, else fall back to default.
    if html_content_template_file and Path(html_content_template_file).exists():
        html_content = Path(html_content_template_file).read_text()
    else:
        # For reporting purposes, we report based on 1-indexed,
        # not 0-indexed lists (i.e. Try 1 instead of Try 0 for the first attempt).
        html_content = (
            "Try {{try_number}} out of {{max_tries + 1}}<br>"
            "Exception:<br>{{exception_html}}<br>"
            'Log: <a href="{{ti.log_url}}">Link</a><br>'
            "Host: {{ti.hostname}}<br>"
            'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
        )

    # Add exception_html to context for template rendering.
    import html

    exception_html = html.escape(str(error)).replace("\n", "<br>")
    email_context = {
        **context,
        "exception": error,
        "exception_html": exception_html,
        "try_number": ti.try_number,
        "max_tries": ti.max_tries,
    }

    try:
        notifier = SmtpNotifier(
            to=task.email,
            subject=subject,
            html_content=html_content,
            from_email=conf.get("email", "from_email", fallback="airflow@airflow"),
        )
        notifier(email_context)
    except Exception:
        log.exception("Failed to send email notification")
1548
1549
def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
    """
    Execute Task (optionally with a Timeout) and push Xcom results.

    Runs pre-execute hooks, the on_execute callbacks, the task's ``execute``
    (or ``resume_execution`` when resuming from deferral), then post-execute
    hooks, and returns the task's result.
    """
    task = ti.task
    execute = task.execute

    # When resuming from deferral, swap the entry point to `resume_execution`
    # with the trigger-supplied method name and kwargs.
    if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method):
        from airflow.sdk.serde import deserialize

        next_kwargs_data = ti._ti_context_from_server.next_kwargs or {}
        try:
            if TYPE_CHECKING:
                assert isinstance(next_kwargs_data, dict)
            kwargs = deserialize(next_kwargs_data)
        except (ImportError, KeyError, AttributeError, TypeError):
            # Fall back to the legacy serialization format.
            from airflow.serialization.serialized_objects import BaseSerialization

            kwargs = BaseSerialization.deserialize(next_kwargs_data)

        if TYPE_CHECKING:
            assert isinstance(kwargs, dict)
        execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)

    ctx = contextvars.copy_context()
    # Populate the context var so ExecutorSafeguard doesn't complain
    ctx.run(ExecutorSafeguard.tracker.set, task)

    # Export context in os.environ to make it available for operators to use.
    airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
    os.environ.update(airflow_context_vars)

    outlet_events = context_get_outlet_events(context)

    if (pre_execute_hook := task._pre_execute_hook) is not None:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
    if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)

    _run_task_state_change_callbacks(task, "on_execute_callback", context, log)

    if task.execution_timeout:
        from airflow.sdk.execution_time.timeout import timeout

        # TODO: handle timeout in case of deferral
        timeout_seconds = task.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run task in timeout wrapper
            with timeout(timeout_seconds):
                result = ctx.run(execute, context=context)
        except AirflowTaskTimeout:
            task.on_kill()
            raise
    else:
        result = ctx.run(execute, context=context)

    if (post_execute_hook := task._post_execute_hook) is not None:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
    # Pass `result` to an overridden `post_execute` as well, mirroring the
    # `_post_execute_hook` invocation above; the previous call dropped the
    # result argument even though post_execute accepts (context, result).
    if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)

    return result
1613
1614
def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None:
    """Render named map index if the Dag author defined map_index_template at the task level."""
    template = context.get("map_index_template")
    if template is None:
        return None
    log.debug("Rendering map_index_template", template_length=len(template))
    env = ti.task.dag.get_template_env()
    rendered = env.from_string(template).render(context)
    log.debug("Map index rendered", length=len(rendered))
    return rendered
1624
1625
def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
    """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result."""
    xcom_value = result if ti.task.do_xcom_push else None

    has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None
    unmapped_with_mapped_dep = not ti.is_mapped and has_mapped_dep

    if xcom_value is None:
        if unmapped_with_mapped_dep:
            # Uhoh, a downstream mapped task depends on us to push something to map over
            from airflow.sdk.exceptions import XComForMappingNotPushed

            raise XComForMappingNotPushed()
        return

    mapped_length: int | None = None
    if unmapped_with_mapped_dep:
        from airflow.sdk.definitions.mappedoperator import is_mappable_value
        from airflow.sdk.exceptions import UnmappableXComTypePushed

        if not is_mappable_value(xcom_value):
            raise UnmappableXComTypePushed(xcom_value)
        mapped_length = len(xcom_value)

    log.info("Pushing xcom", ti=ti)

    # If the task has multiple outputs, push each output as a separate XCom.
    if ti.task.multiple_outputs:
        if not isinstance(xcom_value, Mapping):
            raise TypeError(
                f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
            )
        # Validate every key first, then push — no partial pushes on bad keys.
        for key in xcom_value:
            if not isinstance(key, str):
                raise TypeError(
                    "Returned dictionary keys must be strings when using "
                    f"multiple_outputs, found {key} ({type(key)}) instead"
                )
        for k, v in result.items():
            ti.xcom_push(k, v)

    _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length)
1669
1670
def finalize(
    ti: RuntimeTaskInstance,
    state: TaskInstanceState,
    context: Context,
    log: Logger,
    error: BaseException | None = None,
) -> None:
    """
    Run post-terminal-state work: duration metrics, operator-extra-link xcoms,
    overwritten rendered fields, state-change callbacks, listeners, and (for
    retry/failure with configured recipients) error emails.

    :param state: the terminal state the task reached; selects which callbacks
        and listener hooks fire.
    :param error: the exception that failed the task, forwarded to failure
        listeners and email notification.
    """
    # Record task duration metrics for all terminal states
    if ti.start_date and ti.end_date:
        duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000
        stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

        Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms)
        Stats.timing("task.duration", duration_ms, tags=stats_tags)

    task = ti.task
    # Pushing xcom for each operator extra links defined on the operator only.
    for oe in task.operator_extra_links:
        try:
            link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key  # type: ignore[arg-type]
            log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
            _xcom_push_to_db(ti, key=xcom_key, value=link)
        except Exception:
            # A broken extra link must not fail finalization; log and continue.
            log.exception(
                "Failed to push an xcom for task operator extra link",
                link_name=oe.name,
                xcom_key=oe.xcom_key,
                ti=ti,
            )

    if getattr(ti.task, "overwrite_rtif_after_execution", False):
        log.debug("Overwriting Rendered template fields.")
        if ti.task.template_fields:
            SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)))

    log.debug("Running finalizers", ti=ti)
    # Each terminal state gets its matching callbacks + listener hook; listener
    # errors are logged, never raised.
    if state == TaskInstanceState.SUCCESS:
        _run_task_state_change_callbacks(task, "on_success_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_success(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.SKIPPED:
        _run_task_state_change_callbacks(task, "on_skipped_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_skipped(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.UP_FOR_RETRY:
        _run_task_state_change_callbacks(task, "on_retry_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_retry and task.email:
            _send_error_email_notification(task, ti, context, error, log)
    elif state == TaskInstanceState.FAILED:
        _run_task_state_change_callbacks(task, "on_failure_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_failure and task.email:
            _send_error_email_notification(task, ti, context, error, log)

    try:
        get_listener_manager().hook.before_stopping(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")
1748
1749
def main() -> None:
    """
    Entrypoint for the forked task-execution child process.

    Establishes the comms channel back to the supervisor, initializes metrics,
    then drives the task lifecycle: startup -> run -> finalize. Exit codes:
    0 on success or reschedule-at-startup, 1 on any unexpected top-level
    error, 2 on Ctrl-C.
    """
    log = structlog.get_logger(logger_name="task")

    # The comms channel is process-global so helpers anywhere in this module
    # (xcom push, variable fetch, etc.) can talk to the supervisor.
    global SUPERVISOR_COMMS
    SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)

    Stats.initialize(
        is_statsd_datadog_enabled=conf.getboolean("metrics", "statsd_datadog_enabled"),
        is_statsd_on=conf.getboolean("metrics", "statsd_on"),
        is_otel_on=conf.getboolean("metrics", "otel_on"),
    )

    try:
        try:
            # startup() parses the supervisor's StartupDetails and returns the
            # task instance, its execution context, and a TI-bound logger.
            ti, context, log = startup()
        except AirflowRescheduleException as reschedule:
            # A sensor (et al.) can reschedule before execution even begins;
            # report it to the supervisor and exit cleanly.
            log.warning("Rescheduling task during startup, marking task as UP_FOR_RESCHEDULE")
            SUPERVISOR_COMMS.send(
                msg=RescheduleTask(
                    reschedule_date=reschedule.reschedule_date,
                    end_date=datetime.now(tz=timezone.utc),
                )
            )
            sys.exit(0)
        # Hold the bundle-version lock for the whole run so the DAG bundle
        # cannot be swapped out from under the executing task.
        with BundleVersionLock(
            bundle_name=ti.bundle_instance.name,
            bundle_version=ti.bundle_instance.version,
        ):
            state, _, error = run(ti, context, log)
            # Expose the failure (if any) to templates/callbacks via context.
            context["exception"] = error
            finalize(ti, state, context, log, error)
    except KeyboardInterrupt:
        log.exception("Ctrl-c hit")
        sys.exit(2)
    except Exception:
        log.exception("Top level error")
        sys.exit(1)
    finally:
        # Ensure the request socket is closed on the child side in all circumstances
        # before the process fully terminates.
        if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket:
            with suppress(Exception):
                SUPERVISOR_COMMS.socket.close()
1793
1794
def reinit_supervisor_comms() -> None:
    """
    Restore the supervisor comms channel and logging stream in a subprocess.

    Usually unnecessary, but needed when the process is re-launched via sudo
    for run_as_user, or from python code inside a virtualenv (et al.) operator,
    so those tasks can reconnect and keep accessing variables etc.
    """
    import socket

    if "SUPERVISOR_COMMS" not in globals():
        global SUPERVISOR_COMMS

        # The supervisor passes its end of the request socket down as an
        # inherited file descriptor; rebuild a socket object around it.
        comms_fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))
        SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](
            log=structlog.get_logger(logger_name="task"),
            socket=socket.socket(fileno=comms_fd),
        )

    # Ask the supervisor to resend the logging FD so structured logs keep
    # flowing to it instead of the inherited stderr.
    response = SUPERVISOR_COMMS.send(ResendLoggingFD())
    if not isinstance(response, SentFDs):
        print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)
        return

    from airflow.sdk.log import configure_logging

    log_stream = os.fdopen(response.fds[0], "wb", buffering=0)
    configure_logging(json_output=True, output=log_stream, sending_to_supervisor=True)
1821
1822
# Script entrypoint: this module is exec'd directly by the supervisor as the
# forked task-runner child process.
if __name__ == "__main__":
    main()