1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""The entrypoint for the actual task execution process."""
19
20from __future__ import annotations
21
22import contextlib
23import contextvars
24import functools
25import os
26import sys
27import time
28from collections.abc import Callable, Iterable, Iterator, Mapping
29from contextlib import suppress
30from datetime import datetime, timezone
31from itertools import product
32from pathlib import Path
33from typing import TYPE_CHECKING, Annotated, Any, Literal
34from urllib.parse import quote
35
36import attrs
37import lazy_object_proxy
38import structlog
39from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter
40
41from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
42from airflow.dag_processing.bundles.manager import DagBundlesManager
43from airflow.sdk.api.client import get_hostname, getuser
44from airflow.sdk.api.datamodels._generated import (
45 AssetProfile,
46 DagRun,
47 PreviousTIResponse,
48 TaskInstance,
49 TaskInstanceState,
50 TIRunContext,
51)
52from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard
53from airflow.sdk.bases.xcom import BaseXCom
54from airflow.sdk.configuration import conf
55from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
56from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
57from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
58from airflow.sdk.definitions.mappedoperator import MappedOperator
59from airflow.sdk.definitions.param import process_params
60from airflow.sdk.exceptions import (
61 AirflowException,
62 AirflowInactiveAssetInInletOrOutletException,
63 AirflowRuntimeError,
64 AirflowTaskTimeout,
65 ErrorType,
66 TaskDeferred,
67)
68from airflow.sdk.execution_time.callback_runner import create_executable_runner
69from airflow.sdk.execution_time.comms import (
70 AssetEventDagRunReferenceResult,
71 CommsDecoder,
72 DagRunStateResult,
73 DeferTask,
74 DRCount,
75 ErrorResponse,
76 GetDagRunState,
77 GetDRCount,
78 GetPreviousDagRun,
79 GetPreviousTI,
80 GetTaskBreadcrumbs,
81 GetTaskRescheduleStartDate,
82 GetTaskStates,
83 GetTICount,
84 InactiveAssetsResult,
85 PreviousDagRunResult,
86 PreviousTIResult,
87 RescheduleTask,
88 ResendLoggingFD,
89 RetryTask,
90 SentFDs,
91 SetRenderedFields,
92 SetRenderedMapIndex,
93 SkipDownstreamTasks,
94 StartupDetails,
95 SucceedTask,
96 TaskBreadcrumbsResult,
97 TaskRescheduleStartDate,
98 TaskState,
99 TaskStatesResult,
100 TICount,
101 ToSupervisor,
102 ToTask,
103 TriggerDagRun,
104 ValidateInletsAndOutlets,
105)
106from airflow.sdk.execution_time.context import (
107 ConnectionAccessor,
108 InletEventsAccessors,
109 MacrosAccessor,
110 OutletEventAccessors,
111 TriggeringAssetEventsAccessor,
112 VariableAccessor,
113 context_get_outlet_events,
114 context_to_airflow_vars,
115 get_previous_dagrun_success,
116 set_current_context,
117)
118from airflow.sdk.execution_time.sentry import Sentry
119from airflow.sdk.execution_time.xcom import XCom
120from airflow.sdk.listener import get_listener_manager
121from airflow.sdk.observability.stats import Stats
122from airflow.sdk.timezone import coerce_datetime
123from airflow.triggers.base import BaseEventTrigger
124from airflow.triggers.callback import CallbackTrigger
125
126if TYPE_CHECKING:
127 import jinja2
128 from pendulum.datetime import DateTime
129 from structlog.typing import FilteringBoundLogger as Logger
130
131 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
132 from airflow.sdk.definitions.context import Context
133 from airflow.sdk.exceptions import DagRunTriggerException
134 from airflow.sdk.types import OutletEventAccessorsProtocol
135
136
class TaskRunnerMarker:
    """
    Sentinel object handed to listener hooks as their ``component`` argument.

    Listeners can check for an instance of this class to tell that a hook invocation
    originates from the task runner rather than from some other Airflow component.
    """
139
140
141# TODO: Move this entire class into a separate file:
142# `airflow/sdk/execution_time/task_instance.py`
143# or `airflow/sdk/execution_time/runtime_ti.py`
class RuntimeTaskInstance(TaskInstance):
    """Task instance as seen from inside the task process, bound to its operator and Dag bundle."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    task: BaseOperator
    bundle_instance: BaseDagBundle
    _cached_template_context: Context | None = None
    """The Task Instance context. This is used to cache get_template_context."""

    _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
    """The Task Instance context from the API server, if any."""

    max_tries: int = 0
    """The maximum number of retries for the task."""

    start_date: AwareDatetime
    """Start date of the task instance."""

    end_date: AwareDatetime | None = None

    state: TaskInstanceState | None = None

    is_mapped: bool | None = None
    """True if the original task was mapped."""

    rendered_map_index: str | None = None

    sentry_integration: str = ""

    def __rich_repr__(self):
        yield "id", self.id
        yield "task_id", self.task_id
        yield "dag_id", self.dag_id
        yield "run_id", self.run_id
        yield "max_tries", self.max_tries
        yield "task", type(self.task)
        yield "start_date", self.start_date

    __rich_repr__.angular = True  # type: ignore[attr-defined]

    def get_template_context(self) -> Context:
        """
        Return the template context for this task instance, building and caching it on first use.

        The same dict object is returned on every call. Keys derived from the server-provided run
        context are (re)applied to it each time the server context is present.
        """
        # TODO: Move this to `airflow.sdk.execution_time.context`
        #   once we port the entire context logic from airflow/utils/context.py ?
        from airflow.sdk.plugins_manager import integrate_macros_plugins

        integrate_macros_plugins()

        dag_run_conf: dict[str, Any] | None = None
        if from_server := self._ti_context_from_server:
            dag_run_conf = from_server.dag_run.conf or dag_run_conf

        validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)

        # Cache the context object, which ensures that all calls to get_template_context
        # are operating on the same context object.
        self._cached_template_context: Context = self._cached_template_context or {
            # From the Task Execution interface
            "dag": self.task.dag,
            "inlets": self.task.inlets,
            "map_index_template": self.task.map_index_template,
            "outlets": self.task.outlets,
            "run_id": self.run_id,
            "task": self.task,
            "task_instance": self,
            "ti": self,
            "outlet_events": OutletEventAccessors(),
            "inlet_events": InletEventsAccessors(self.task.inlets),
            "macros": MacrosAccessor(),
            "params": validated_params,
            # TODO: Make this go through Public API longer term.
            # "test_mode": task_instance.test_mode,
            "var": {
                "json": VariableAccessor(deserialize_json=True),
                "value": VariableAccessor(deserialize_json=False),
            },
            "conn": ConnectionAccessor(),
        }
        if from_server:
            dag_run = from_server.dag_run
            context_from_server: Context = {
                # TODO: Assess if we need to pass these through timezone.coerce_datetime
                "dag_run": dag_run,  # type: ignore[typeddict-item]  # Removable after #46522
                "triggering_asset_events": TriggeringAssetEventsAccessor.build(
                    AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
                    for event in dag_run.consumed_asset_events
                ),
                "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
                "task_reschedule_count": from_server.task_reschedule_count or 0,
                # Lazy proxies: the previous successful Dag run is only fetched if a template uses it.
                "prev_start_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date)
                ),
                "prev_end_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
                ),
            }
            self._cached_template_context.update(context_from_server)

            if logical_date := coerce_datetime(dag_run.logical_date):
                if TYPE_CHECKING:
                    assert isinstance(logical_date, DateTime)
                ds = logical_date.strftime("%Y-%m-%d")
                ds_nodash = ds.replace("-", "")
                ts = logical_date.isoformat()
                ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
                ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
                # logical_date and data_interval either coexist or be None together
                self._cached_template_context.update(
                    {
                        # keys that depend on logical_date
                        "logical_date": logical_date,
                        "ds": ds,
                        "ds_nodash": ds_nodash,
                        "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}",
                        "ts": ts,
                        "ts_nodash": ts_nodash,
                        "ts_nodash_with_tz": ts_nodash_with_tz,
                        # keys that depend on data_interval
                        "data_interval_end": coerce_datetime(dag_run.data_interval_end),
                        "data_interval_start": coerce_datetime(dag_run.data_interval_start),
                        "prev_data_interval_start_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start)
                        ),
                        "prev_data_interval_end_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end)
                        ),
                    }
                )

            if from_server.upstream_map_indexes is not None:
                # We stash this in here for later use, but we purposefully don't want to document it's
                # existence. Should this be a private attribute on RuntimeTI instead perhaps?
                setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)

        return self._cached_template_context

    def render_templates(
        self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
    ) -> BaseOperator:
        """
        Render templates in the operator fields.

        If the task was originally mapped, this may replace ``self.task`` with
        the unmapped, fully rendered BaseOperator. The original ``self.task``
        before replacement is returned.
        """
        if not context:
            context = self.get_template_context()
        original_task = self.task

        if TYPE_CHECKING:
            assert context

        ti = context["ti"]

        if TYPE_CHECKING:
            assert original_task
            assert self.task
            assert ti.task

        # If self.task is mapped, this call replaces self.task to point to the
        # unmapped BaseOperator created by this function! This is because the
        # MappedOperator is useless for template rendering, and we need to be
        # able to access the unmapped task instead.
        self.task.render_template_fields(context, jinja_env)
        self.is_mapped = original_task.is_mapped
        return original_task

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = BaseXCom.XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
        default: Any = None,
        run_id: str | None = None,
    ) -> Any:
        """
        Pull XComs either from the API server (BaseXCom) or from the custom XCOM backend if configured.

        The pull can be filtered optionally by certain criterion.

        :param key: A key for the XCom. If provided, only XComs with matching
            keys will be returned. The default key is ``'return_value'``, also
            available as constant ``XCOM_RETURN_KEY``. This key is automatically
            given to XComs returned by tasks (as opposed to being pushed
            manually).
        :param task_ids: Only XComs from tasks with matching ids will be
            pulled. If *None* (default), the task_id of the calling task is used.
        :param dag_id: If provided, only pulls XComs from this Dag. If *None*
            (default), the Dag of the calling task is used.
        :param map_indexes: If provided, only pull XComs with matching indexes.
            If not provided (default), XComs from every map index of each
            requested task are pulled (see below for details).
        :param include_prior_dates: If False, only XComs from the current
            logical_date are returned. If *True*, XComs from previous dates
            are returned as well.
        :param default: Value used in place of each XCom that is not found.
        :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id.
            If *None* (default), the run_id of the calling task is used.

        When ``map_indexes`` is not specified, all XComs (across all map indexes)
        of each requested task are collected. If a single task was requested
        (``task_ids`` is *None* or a str) and exactly one value results, that
        value is returned directly; otherwise a list is returned.

        When ``map_indexes`` is specified, one XCom is looked up for every
        ``(task_id, map_index)`` pair. Setting ``map_indexes`` to *None* will
        pull XCom as it would from a non mapped task. If both a single task
        and a single map index (an int or *None*) were requested, a single
        value is returned; otherwise a list ordered by ``task_ids`` and then
        ``map_indexes``.

        In either case, ``default`` (*None* if not specified) is used for
        every requested task / map index for which no XCom is found.
        """
        if dag_id is None:
            dag_id = self.dag_id
        if run_id is None:
            run_id = self.run_id

        single_task_requested = isinstance(task_ids, (str, type(None)))
        single_map_index_requested = isinstance(map_indexes, (int, type(None)))

        if task_ids is None:
            # default to the current task if not provided
            task_ids = [self.task_id]
        elif isinstance(task_ids, str):
            task_ids = [task_ids]

        # If map_indexes is not specified, pull xcoms from all map indexes for each task
        if not is_arg_set(map_indexes):
            xcoms: list[Any] = []
            for t_id in task_ids:
                values = XCom.get_all(
                    run_id=run_id,
                    key=key,
                    task_id=t_id,
                    dag_id=dag_id,
                    include_prior_dates=include_prior_dates,
                )

                if values is None:
                    # Nothing stored for this task: honour ``default`` like the explicit branch does.
                    xcoms.append(default)
                else:
                    xcoms.extend(values)
            # For single task pulling from unmapped task, return single value
            if single_task_requested and len(xcoms) == 1:
                return xcoms[0]
            return xcoms

        # Original logic when map_indexes is explicitly specified
        map_indexes_iterable: Iterable[int | None] = []
        if isinstance(map_indexes, int) or map_indexes is None:
            map_indexes_iterable = [map_indexes]
        elif isinstance(map_indexes, Iterable):
            map_indexes_iterable = map_indexes
        else:
            raise TypeError(
                f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}"
            )

        xcoms = []
        for t_id, m_idx in product(task_ids, map_indexes_iterable):
            value = XCom.get_one(
                run_id=run_id,
                key=key,
                task_id=t_id,
                dag_id=dag_id,
                map_index=m_idx,
                include_prior_dates=include_prior_dates,
            )
            if value is None:
                xcoms.append(default)
            else:
                xcoms.append(value)

        if single_task_requested and single_map_index_requested:
            return xcoms[0]

        return xcoms

    def xcom_push(self, key: str, value: Any):
        """
        Make an XCom available for tasks to pull.

        :param key: Key to store the value under.
        :param value: Value to store. Only be JSON-serializable values may be used.
        """
        _xcom_push(self, key, value)

    def get_relevant_upstream_map_indexes(
        self, upstream: BaseOperator, ti_count: int | None, session: Any
    ) -> int | range | None:
        # TODO: Implement this method
        return None

    def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
        """Get the first reschedule date for the task instance if found, none otherwise."""
        if context.get("task_reschedule_count", 0) == 0:
            # If the task has not been rescheduled, there is no need to ask the supervisor
            return None

        max_tries: int = self.max_tries
        retries: int = self.task.retries or 0
        first_try_number = max_tries - retries + 1

        log = structlog.get_logger(logger_name="task")

        log.debug("Requesting first reschedule date from supervisor")

        response = SUPERVISOR_COMMS.send(
            msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskRescheduleStartDate)

        return response.start_date

    def get_previous_dagrun(self, state: str | None = None) -> DagRun | None:
        """Return the previous Dag run before the given logical date, optionally filtered by state."""
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")

        log.debug("Getting previous Dag run", dag_run=dag_run)

        if dag_run is None:
            return None

        if dag_run.logical_date is None:
            return None

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state)
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousDagRunResult)

        return response.dag_run

    def get_previous_ti(
        self,
        state: TaskInstanceState | None = None,
        logical_date: AwareDatetime | None = None,
        map_index: int = -1,
    ) -> PreviousTIResponse | None:
        """
        Return the previous task instance matching the given criteria.

        :param state: Filter by TaskInstance state
        :param logical_date: Filter by logical date (returns TI before this date)
        :param map_index: Filter by map_index (defaults to -1 for non-mapped tasks)
        :return: Previous task instance or None if not found
        """
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")
        log.debug("Getting previous task instance", task_id=self.task_id, state=state)

        # Use current dag run's logical_date if not provided
        effective_logical_date = logical_date
        if effective_logical_date is None and dag_run and dag_run.logical_date:
            effective_logical_date = dag_run.logical_date

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousTI(
                dag_id=self.dag_id,
                task_id=self.task_id,
                logical_date=effective_logical_date,
                map_index=map_index,
                state=state,
            )
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousTIResult)

        return response.task_instance

    @staticmethod
    def get_ti_count(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of task instances matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTICount(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TICount)

        return response.count

    @staticmethod
    def get_task_states(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> dict[str, Any]:
        """Return the task states matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTaskStates(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskStatesResult)

        return response.task_states

    @staticmethod
    def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]:
        """Return task breadcrumbs for the given dag run."""
        response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id))
        if TYPE_CHECKING:
            assert isinstance(response, TaskBreadcrumbsResult)
        return response.breadcrumbs

    @staticmethod
    def get_dr_count(
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of Dag runs matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetDRCount(
                dag_id=dag_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, DRCount)

        return response.count

    @staticmethod
    def get_dagrun_state(dag_id: str, run_id: str) -> str:
        """Return the state of the Dag run with the given Run ID."""
        response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id))

        if TYPE_CHECKING:
            assert isinstance(response, DagRunStateResult)

        return response.state

    @property
    def log_url(self) -> str:
        """UI URL of this task instance, including map index and try number when they are set."""
        run_id = quote(self.run_id)
        base_url = conf.get("api", "base_url", fallback="http://localhost:8080/")
        map_index_value = self.map_index
        map_index = (
            f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else ""
        )
        try_number_value = self.try_number
        try_number = (
            f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else ""
        )
        _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}"
        return _log_uri

    @property
    def mark_success_url(self) -> str:
        """URL to mark TI success."""
        return self.log_url
642
643
def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
    """
    Store ``value`` under ``key`` for ``ti`` through ``XCom.set`` (custom XCom backend aware).

    This helper is intentionally private: SDK consumers should not be able to pass
    ``mapped_length`` by hand, so the public ``xcom_push`` does not expose it.

    :param ti: Task instance whose Dag / task / run / map index identify the XCom.
    :param key: XCom key.
    :param value: Value to store.
    :param mapped_length: Forwarded to ``XCom.set`` as ``_mapped_length``.
    """
    ti_coordinates = {
        "dag_id": ti.dag_id,
        "task_id": ti.task_id,
        "run_id": ti.run_id,
        "map_index": ti.map_index,
    }
    XCom.set(key=key, value=value, _mapped_length=mapped_length, **ti_coordinates)
658
659
def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None:
    """
    Write an XCom straight to the metadata DB via ``XCom._set_xcom_in_db``.

    Unlike ``_xcom_push``, this skips any custom XCom backend.

    :param ti: Task instance whose Dag / task / run / map index identify the XCom.
    :param key: XCom key.
    :param value: Value to store.
    """
    identity = dict(
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=ti.map_index,
    )
    XCom._set_xcom_in_db(key=key, value=value, **identity)
670
671
def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
    """
    Load the task described by a startup message and wrap it in a ``RuntimeTaskInstance``.

    The Dag bundle named in ``what.bundle_info`` is fetched and initialized, the Dag file at
    ``what.dag_rel_path`` inside that bundle is parsed, and the task is looked up by its Dag id
    and task id. The process exits with status 1 if the Dag or the task cannot be found.

    :param what: Startup details sent by the supervisor.
    :param log: Logger used to report lookup failures.
    :return: A ``RuntimeTaskInstance`` in the ``RUNNING`` state.
    :raises TypeError: If the resolved task is neither a ``BaseOperator`` nor a ``MappedOperator``.
    """
    # TODO: Task-SDK:
    # Using BundleDagBag here is about 98% wrong, but it'll do for now
    from airflow.dag_processing.dagbag import BundleDagBag

    info = what.bundle_info
    bundle = DagBundlesManager().get_bundle(name=info.name, version=info.version)
    bundle.initialize()

    dag_bag = BundleDagBag(
        dag_folder=os.fspath(Path(bundle.path, what.dag_rel_path)),
        safe_mode=False,
        load_op_links=False,
        bundle_path=bundle.path,
        bundle_name=info.name,
    )
    if TYPE_CHECKING:
        assert what.ti.dag_id

    wanted_dag_id = what.ti.dag_id
    try:
        dag = dag_bag.dags[wanted_dag_id]
    except KeyError:
        log.error("Dag not found during start up", dag_id=wanted_dag_id, bundle=info, path=what.dag_rel_path)
        sys.exit(1)

    wanted_task_id = what.ti.task_id
    try:
        task = dag.task_dict[wanted_task_id]
    except KeyError:
        log.error(
            "Task not found in Dag during start up",
            dag_id=dag.dag_id,
            task_id=wanted_task_id,
            bundle=info,
            path=what.dag_rel_path,
        )
        sys.exit(1)

    if isinstance(task, (BaseOperator, MappedOperator)):
        return RuntimeTaskInstance.model_construct(
            **what.ti.model_dump(exclude_unset=True),
            task=task,
            bundle_instance=bundle,
            _ti_context_from_server=what.ti_context,
            max_tries=what.ti_context.max_tries,
            start_date=what.start_date,
            state=TaskInstanceState.RUNNING,
            sentry_integration=what.sentry_integration,
        )

    raise TypeError(
        f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
    )
732
733
# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution,
# to send requests back to the supervisor process.
#
# Why it needs to be a global:
# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests
#   to the parent process during task execution.
# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the
#   deeply nested execution stack.
# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily
#   accessible wherever needed during task execution without modifying every layer of the call stack.
#
# NOTE: this line only annotates the name; it does not assign a value. It must be bound before first use
# (presumably by `reinit_supervisor_comms()` or the process entrypoint, neither of which is visible here).
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]


# Lifecycle of the task process:
# 1. Start up (receive details from supervisor)
# 2. Execution (run task code, possibly send requests)
# 3. Shutdown and report status
751
752
def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
    """
    Receive the startup message, parse the Dag and build the runtime task instance.

    The ``StartupDetails`` message is either read from the supervisor connection (normal start) or,
    in a process that was re-executed under another user, deserialized from the
    ``_AIRFLOW__STARTUP_MSG`` environment variable.

    If the task must run as a different user (``run_as_user`` on the task, or
    ``[core] default_impersonation``) and this is not already the re-executed process, the current
    process is replaced via ``sudo`` and this function does not return normally.

    :return: The runtime task instance, its template context, and the task logger.
    :raises RuntimeError: If the first message received is not a ``StartupDetails``.
    """
    # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent
    # in response to us sending a request.
    log = structlog.get_logger(logger_name="task")

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and (
        msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG")
    ):
        # Drop any inherited Kerberos cache setting (KRB5CCNAME) so the new process can't reuse it.
        os.environ.pop("KRB5CCNAME", None)
        # entrypoint of re-exec process

        msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson)
        reinit_supervisor_comms()

        # We delay this message until _after_ we've got the logging re-configured, otherwise it will show up
        # on stdout
        log.debug("Using serialized startup message from environment", msg=msg)
    else:
        # normal entry point
        msg = SUPERVISOR_COMMS._get_response()  # type: ignore[assignment]

    if not isinstance(msg, StartupDetails):
        raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

    # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
    os_type = sys.platform
    if os_type == "darwin":
        log.debug("Mac OS detected, skipping setproctitle")
    else:
        from setproctitle import setproctitle

        setproctitle(f"airflow worker -- {msg.ti.id}")

    # Listener failures are logged but must not prevent the task from starting.
    try:
        get_listener_manager().hook.on_starting(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")

    with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
        ti = parse(msg, log)
        log.debug("Dag file parsed", file=msg.dag_rel_path)

    run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
        "core", "default_impersonation", fallback=None
    )

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
        # Original (not yet re-executed) process: the task must run as a different user, so hand the startup
        # message over via the environment and replace this process with `sudo -u <run_as_user> ...`.
        os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
        # store startup message in environment for re-exec process
        os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
        # Keep the supervisor socket open across exec so the re-executed process can keep talking to it.
        os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)

        # Import main directly from the module instead of re-executing the file.
        # This ensures that when other parts modules import
        # airflow.sdk.execution_time.task_runner, they get the same module instance
        # with the properly initialized SUPERVISOR_COMMS global variable.
        # If we re-executed the module with `python -m`, it would load as __main__ and future
        # imports would get a fresh copy without the initialized globals.
        rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
        cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
        log.info(
            "Running command",
            command=cmd,
        )
        os.execvp("sudo", cmd)

        # os.execvp replaces the current process on success (and raises on failure), so this is not expected
        # to run; it is kept as a defensive fallback.
        return None, None, None

    return ti, ti.get_template_context(), log
825
826
def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float:
    """
    Return a serializable representation of the templated field.

    If ``template_field`` is not JSON-serializable, it is converted with its own ``serialize()``
    method when it has one, or with ``str()`` otherwise. JSON-serializable values have tuples
    converted to lists and dicts recursively key-sorted, so that equal values always produce the
    same string representation.

    Whenever the string representation is longer than ``[core] max_templated_field_length``, a
    truncated string is returned instead; the sdk secrets masker redacts it before truncation.

    :param template_field: The rendered value of the templated field.
    :param name: Name of the field; passed to the secrets masker along with the value.
    :return: The (possibly truncated) serializable representation.
    """
    import json

    from airflow.sdk._shared.secrets_masker import redact

    def is_jsonable(x):
        try:
            json.dumps(x)
        except (TypeError, OverflowError, ValueError):
            # ValueError is what json.dumps raises for circular references.
            return False
        else:
            return True

    def translate_tuples_to_lists(obj: Any):
        """Recursively convert tuples to lists."""
        if isinstance(obj, (tuple, list)):
            return [translate_tuples_to_lists(item) for item in obj]
        if isinstance(obj, dict):
            return {key: translate_tuples_to_lists(value) for key, value in obj.items()}
        return obj

    def sort_dict_recursively(obj: Any) -> Any:
        """Recursively sort dictionaries to ensure consistent ordering."""
        if isinstance(obj, dict):
            try:
                items = sorted(obj.items())
            except TypeError:
                # JSON-serializable dicts may mix key types (e.g. ``{1: ..., "a": ...}``) that cannot be
                # compared with each other; fall back to a deterministic order by key type and text.
                items = sorted(obj.items(), key=lambda kv: (type(kv[0]).__name__, str(kv[0])))
            return {k: sort_dict_recursively(v) for k, v in items}
        if isinstance(obj, list):
            return [sort_dict_recursively(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(sort_dict_recursively(item) for item in obj)
        return obj

    max_length = conf.getint("core", "max_templated_field_length")

    def truncate(serialized: Any) -> str:
        """Redact ``serialized`` and cut it down so the whole message stays around ``max_length``."""
        rendered = redact(serialized, name)
        # 79 accounts for the fixed prefix. Clamp at 0: a small max_length would otherwise give a
        # negative slice bound, which keeps almost the whole value instead of truncating it.
        keep = max(max_length - 79, 0)
        return (
            "Truncated. You can change this behaviour in [core]max_templated_field_length. "
            f"{rendered[:keep]!r}... "
        )

    if not is_jsonable(template_field):
        try:
            serialized = template_field.serialize()
        except AttributeError:
            serialized = str(template_field)
        if len(serialized) > max_length:
            return truncate(serialized)
        return serialized
    if not template_field and not isinstance(template_field, tuple):
        # Avoid unnecessary serialization steps for empty fields unless they are tuples
        # and need to be converted to lists
        return template_field
    template_field = translate_tuples_to_lists(template_field)
    # Sort dictionaries recursively to ensure consistent string representation
    # This prevents hash inconsistencies when dict ordering varies
    if isinstance(template_field, dict):
        template_field = sort_dict_recursively(template_field)
    serialized = str(template_field)
    if len(serialized) > max_length:
        return truncate(serialized)
    return template_field
899
900
def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]:
    """
    Build the rendered-template-field payload for ``task``.

    Every name in ``task.template_fields`` is read from the task, serialized with
    ``_serialize_template_field`` and then passed through the secrets masker.

    :param task: Operator whose template fields have already been rendered.
    :return: Mapping of template field name to its serialized, redacted value.
    """
    from airflow.sdk._shared.secrets_masker import redact

    # Redaction happens here, in the task process, before anything is sent to the API server, so that
    # secrets registered via mask_secret() on workers / dag processor stay masked on the UI.
    rendered_fields = {
        field_name: redact(_serialize_template_field(getattr(task, field_name), field_name), field_name)
        for field_name in task.template_fields
    }

    return rendered_fields  # type: ignore[return-value] # Convince mypy that this is OK since we pass JsonValue to redact, so it will return the same
914
915
def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]:
    """
    Yield an ``AssetProfile`` for every asset-like object in ``lineage_objects``.

    Lineage may contain objects other than assets. Anything that is not an ``Asset``,
    ``AssetNameRef``, ``AssetUriRef`` or ``AssetAlias`` is silently skipped. A falsy
    ``lineage_objects`` (for example ``None`` or an empty list) yields nothing.
    """
    if not lineage_objects:
        return
    for item in lineage_objects:
        if isinstance(item, Asset):
            profile = AssetProfile(name=item.name, uri=item.uri, type=Asset.__name__)
        elif isinstance(item, AssetNameRef):
            profile = AssetProfile(name=item.name, type=AssetNameRef.__name__)
        elif isinstance(item, AssetUriRef):
            profile = AssetProfile(uri=item.uri, type=AssetUriRef.__name__)
        elif isinstance(item, AssetAlias):
            profile = AssetProfile(name=item.name, type=AssetAlias.__name__)
        else:
            continue
        yield profile
927
928
def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]:
    """
    Yield JSON-ready records for everything the user recorded in the outlet event accessors.

    For each accessor keyed by an ``AssetUniqueKey``, a ``dest_asset_key``/``extra`` record is
    yielded first. Then, for every accessor, one record is yielded per asset alias event. No
    filtering happens here; that is left to the API server.
    """
    if TYPE_CHECKING:
        assert isinstance(events, OutletEventAccessors)
    for accessor_key, accessor in events._dict.items():
        if isinstance(accessor_key, AssetUniqueKey):
            yield {"dest_asset_key": attrs.asdict(accessor_key), "extra": accessor.extra}
        yield from (attrs.asdict(alias_event) for alias_event in accessor.asset_alias_events)
939
940
def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None:
    """
    Get the task ready to execute.

    In order, this:

    * records the hostname on ``ti``;
    * replaces ``ti.task`` with the result of ``prepare_for_execution()`` and points
      ``context["task"]`` at it;
    * renders the task's templates with the Dag's Jinja environment;
    * sends the serialized rendered fields to the supervisor (only when there are any);
    * makes a best-effort early render of ``map_index_template``;
    * validates inlets/outlets;
    * notifies listeners that the task is running.

    Listener errors and early map-index rendering errors are logged, not raised.

    :return: A message for the supervisor when the task must stop before executing, otherwise
        ``None``. Every path here currently returns ``None``; ``run()`` treats a non-``None``
        value as an early failure.
    :raises AirflowInactiveAssetInInletOrOutletException: propagated from
        ``_validate_task_inlets_and_outlets`` when inactive assets are reported.
    """
    ti.hostname = get_hostname()
    ti.task = ti.task.prepare_for_execution()
    # Since context is now cached, and calling `ti.get_template_context` will return the same dict, we want to
    # update the value of the task that is sent from there
    context["task"] = ti.task

    jinja_env = ti.task.dag.get_template_env()
    # Template rendering errors are not caught here; they propagate to run() as a task failure.
    ti.render_templates(context=context, jinja_env=jinja_env)

    if rendered_fields := _serialize_rendered_fields(ti.task):
        # so that we do not call the API unnecessarily
        SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields))

    # Try to render map_index_template early with available context (will be re-rendered after execution)
    # This provides a partial label during task execution for templates using pre-execution context
    # If rendering fails here, we suppress the error since it will be re-rendered after execution
    try:
        if rendered_map_index := _render_map_index(context, ti=ti, log=log):
            ti.rendered_map_index = rendered_map_index
            log.debug("Sending early rendered map index", length=len(rendered_map_index))
            SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index))
    except Exception:
        log.debug(
            "Early rendering of map_index_template failed, will retry after task execution", exc_info=True
        )

    _validate_task_inlets_and_outlets(ti=ti, log=log)

    try:
        # TODO: Call pre execute etc.
        get_listener_manager().hook.on_task_instance_running(
            previous_state=TaskInstanceState.QUEUED, task_instance=ti
        )
    except Exception:
        log.exception("error calling listener")

    # No error, carry on and execute the task
    return None
980
981
def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
    """
    Ask the supervisor whether any of the task's inlet or outlet assets are inactive.

    The round trip is skipped entirely when the task declares neither inlets nor outlets.

    :raises AirflowInactiveAssetInInletOrOutletException: when the supervisor reports one or
        more inactive assets.
    """
    task = ti.task
    if task.inlets or task.outlets:
        response = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id))
        if TYPE_CHECKING:
            assert isinstance(response, InactiveAssetsResult)
        inactive_profiles = response.inactive_assets
        if inactive_profiles:
            offending_keys = [AssetUniqueKey.from_profile(profile) for profile in inactive_profiles]
            raise AirflowInactiveAssetInInletOrOutletException(inactive_asset_keys=offending_keys)
995
996
def _defer_task(
    defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """
    Build the ``DeferTask`` message that hands this task over to the triggerer.

    The trigger is serialized to its classpath plus kwargs. Both the trigger kwargs and the
    resume kwargs are then encoded with ``airflow.sdk.serde``.

    :return: The ``DeferTask`` message and ``TaskInstanceState.DEFERRED``.
    """
    # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?
    log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
    trigger = defer.trigger
    classpath, raw_trigger_kwargs = trigger.serialize()

    # Only task-associated BaseTrigger instances may carry a queue, and only when
    # [triggerer] queues_enabled is True; event and callback triggers never do.
    is_task_trigger = not isinstance(trigger, (BaseEventTrigger, CallbackTrigger))
    use_task_queue = is_task_trigger and conf.getboolean("triggerer", "queues_enabled", fallback=False)
    queue: str | None = ti.task.queue if use_task_queue else None

    from airflow.sdk.serde import serialize as serde_serialize

    encoded_trigger_kwargs = serde_serialize(raw_trigger_kwargs)
    encoded_next_kwargs = serde_serialize(defer.kwargs or {})

    if TYPE_CHECKING:
        assert isinstance(encoded_next_kwargs, dict)
        assert isinstance(encoded_trigger_kwargs, dict)

    deferral = DeferTask(
        classpath=classpath,
        trigger_kwargs=encoded_trigger_kwargs,
        trigger_timeout=defer.timeout,
        queue=queue,
        next_method=defer.method_name,
        next_kwargs=encoded_next_kwargs,
    )
    return deferral, TaskInstanceState.DEFERRED
1032
1033
@Sentry.enrich_errors
def run(
    ti: RuntimeTaskInstance,
    context: Context,
    log: Logger,
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
    """
    Run the task in this process.

    Steps:

    * delete any XCom keys the server asked to be cleared;
    * prepare and execute the task inside ``set_current_context``;
    * translate the outcome into a terminal-state message.

    Possible outcomes are success, downstream skip, trigger-dag-run, deferral, skip,
    reschedule, or one of several kinds of failure. The resulting message (if any) is sent
    to the supervisor in the ``finally`` block, and ``ti.state`` is set before returning.

    A SIGTERM handler calling ``ti.task.on_kill()`` is installed. It does nothing in child
    processes that inherited it, because their PID differs from the one recorded here.

    :return: ``(state, msg, error)``:
        * ``state`` is the final state;
        * ``msg`` is the message sent to the supervisor, if any;
        * ``error`` is the exception recorded for failure outcomes, if any.
    """
    import signal

    from airflow.sdk.exceptions import (
        AirflowFailException,
        AirflowRescheduleException,
        AirflowSensorTimeout,
        AirflowSkipException,
        AirflowTaskTerminated,
        DagRunTriggerException,
        DownstreamTasksSkipped,
        TaskDeferred,
    )

    if TYPE_CHECKING:
        assert ti.task is not None
        assert isinstance(ti.task, BaseOperator)

    # Remember our own PID so the SIGTERM handler can ignore signals in forked children.
    parent_pid = os.getpid()

    def _on_term(signum, frame):
        pid = os.getpid()
        if pid != parent_pid:
            return

        ti.task.on_kill()

    signal.signal(signal.SIGTERM, _on_term)

    msg: ToSupervisor | None = None
    state: TaskInstanceState
    error: BaseException | None = None

    try:
        # First, clear the xcom data sent from server
        if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear):
            for x in keys_to_delete:
                log.debug("Clearing XCom with key", key=x)
                XCom.delete(
                    key=x,
                    dag_id=ti.dag_id,
                    task_id=ti.task_id,
                    run_id=ti.run_id,
                    map_index=ti.map_index,
                )

        with set_current_context(context):
            # This is the earliest that we can render templates -- as if it excepts for any reason we need to
            # catch it and handle it like a normal task failure
            if early_exit := _prepare(ti, log, context):
                msg = early_exit
                ti.state = state = TaskInstanceState.FAILED
                return state, msg, error

            try:
                result = _execute_task(context=context, ti=ti, log=log)
            except Exception:
                import jinja2

                # If the task failed, swallow rendering error so it doesn't mask the main error.
                with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
                    previous_rendered_map_index = ti.rendered_map_index
                    ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                    # Send update only if value changed (e.g., user set context variables during execution)
                    if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                        SUPERVISOR_COMMS.send(
                            msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)
                        )
                # Re-raise so the outer handlers below classify the original task exception.
                raise
            else:  # If the task succeeded, render normally to let rendering error bubble up.
                previous_rendered_map_index = ti.rendered_map_index
                ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                # Send update only if value changed (e.g., user set context variables during execution)
                if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                    SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index))

            _push_xcom_if_needed(result, ti, log)

            msg, state = _handle_current_task_success(context, ti)
    # Handler order matters: the more specific Airflow exceptions must be matched before the
    # catch-all ``BaseException`` at the end.
    except DownstreamTasksSkipped as skip:
        log.info("Skipping downstream tasks.")
        tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks]
        SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip))
        msg, state = _handle_current_task_success(context, ti)
    except DagRunTriggerException as drte:
        msg, state = _handle_trigger_dag_run(drte, context, ti, log)
    except TaskDeferred as defer:
        msg, state = _defer_task(defer, ti, log)
    except AirflowSkipException as e:
        if e.args:
            log.info("Skipping task.", reason=e.args[0])
        msg = TaskState(
            state=TaskInstanceState.SKIPPED,
            end_date=datetime.now(tz=timezone.utc),
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.SKIPPED
    except AirflowRescheduleException as reschedule:
        log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
        msg = RescheduleTask(
            reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
        )
        state = TaskInstanceState.UP_FOR_RESCHEDULE
    except (AirflowFailException, AirflowSensorTimeout) as e:
        # If AirflowFailException is raised, task should not retry.
        # If a sensor in reschedule mode reaches timeout, task should not retry.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
        # We should allow retries if the task has defined it.
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    except AirflowTaskTerminated as e:
        # External state updates are already handled with `ti_heartbeat` and will be
        # updated already be another UI API. So, these exceptions should ideally never be thrown.
        # If these are thrown, we should mark the TI state as failed.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except SystemExit as e:
        # SystemExit needs to be retried if they are eligible.
        log.error("Task exited", exit_code=e.code)
        msg, state = _handle_current_task_failed(ti)
        error = e
    except BaseException as e:
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    finally:
        # Runs on every path, including the early ``return`` above, so the supervisor always
        # receives the terminal-state message when one was produced.
        if msg:
            SUPERVISOR_COMMS.send(msg=msg)

    # Return the message to make unit tests easier too
    ti.state = state
    return state, msg, error
1189
1190
def _handle_current_task_success(
    context: Context,
    ti: RuntimeTaskInstance,
) -> tuple[SucceedTask, TaskInstanceState]:
    """
    Stamp the end date, emit success metrics and build the ``SucceedTask`` message.

    The message carries the task's outlet asset profiles, plus every outlet event recorded in
    ``context["outlet_events"]`` during execution.

    :return: The ``SucceedTask`` message and ``TaskInstanceState.SUCCESS``.
    """
    finished_at = datetime.now(tz=timezone.utc)
    ti.end_date = finished_at

    operator_name = ti.task.__class__.__name__
    metric_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
    # Per-operator metric name first, then the same count as a single tagged metric.
    Stats.incr(f"operator_successes_{operator_name}", tags=metric_tags)
    Stats.incr("operator_successes", tags={**metric_tags, "operator": operator_name})
    Stats.incr("ti_successes", tags=metric_tags)

    success_message = SucceedTask(
        end_date=finished_at,
        task_outlets=list(_build_asset_profiles(ti.task.outlets)),
        outlet_events=list(_serialize_outlet_events(context["outlet_events"])),
        rendered_map_index=ti.rendered_map_index,
    )
    return success_message, TaskInstanceState.SUCCESS
1216
1217
def _handle_current_task_failed(
    ti: RuntimeTaskInstance,
) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
    """
    Stamp the end date, emit failure metrics and choose between retrying and failing.

    :return: ``(RetryTask, UP_FOR_RETRY)`` when the server-provided context says the task
        should retry, otherwise ``(TaskState(FAILED), FAILED)``.
    """
    finished_at = datetime.now(tz=timezone.utc)
    ti.end_date = finished_at

    operator_name = ti.task.__class__.__name__
    metric_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
    # Per-operator metric name first, then the same count as a single tagged metric.
    Stats.incr(f"operator_failures_{operator_name}", tags=metric_tags)
    Stats.incr("operator_failures", tags={**metric_tags, "operator": operator_name})
    Stats.incr("ti_failures", tags=metric_tags)

    server_context = ti._ti_context_from_server
    if not (server_context and server_context.should_retry):
        failed_message = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=finished_at,
            rendered_map_index=ti.rendered_map_index,
        )
        return failed_message, TaskInstanceState.FAILED
    return RetryTask(end_date=finished_at), TaskInstanceState.UP_FOR_RETRY
1241
1242
def _handle_trigger_dag_run(
    drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """
    Handle exception from TriggerDagRunOperator.

    Asks the supervisor to trigger the target Dag run, then:

    * if the run already exists, the task is skipped when ``skip_when_already_exists`` is set
      and failed otherwise;
    * on success, the triggered run id is pushed to XCom under ``trigger_run_id``;
    * with ``wait_for_completion`` and ``deferrable``, the task is deferred on a
      ``DagStateTrigger``;
    * with ``wait_for_completion`` only, the run state is polled every ``poke_interval``
      seconds until it reaches one of ``failed_states`` (task fails) or ``allowed_states``;
    * otherwise, or once an allowed state is reached, the task is marked successful.

    :return: The message for the supervisor and the resulting task instance state.
    """
    log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
    comms_msg = SUPERVISOR_COMMS.send(
        TriggerDagRun(
            dag_id=drte.trigger_dag_id,
            run_id=drte.dag_run_id,
            logical_date=drte.logical_date,
            conf=drte.conf,
            reset_dag_run=drte.reset_dag_run,
        ),
    )

    # NOTE(review): only DAGRUN_ALREADY_EXISTS is handled here; an ErrorResponse with any other
    # error type falls through to the "triggered successfully" path below. Confirm the comms
    # layer raises for those instead.
    if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
        if drte.skip_when_already_exists:
            log.info(
                "Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
                dag_id=drte.trigger_dag_id,
            )
            msg = TaskState(
                state=TaskInstanceState.SKIPPED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.SKIPPED
        else:
            log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
            msg = TaskState(
                state=TaskInstanceState.FAILED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.FAILED

        return msg, state

    log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)

    # Store the run id from the dag run (either created or found above) to
    # be used when creating the extra link on the webserver.
    ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)

    if drte.wait_for_completion:
        if drte.deferrable:
            from airflow.providers.standard.triggers.external_task import DagStateTrigger

            defer = TaskDeferred(
                trigger=DagStateTrigger(
                    dag_id=drte.trigger_dag_id,
                    states=drte.allowed_states + drte.failed_states,  # type: ignore[arg-type]
                    # Don't filter by execution_dates when run_ids is provided.
                    # run_id uniquely identifies a DAG run, and when reset_dag_run=True,
                    # drte.logical_date might be a newly calculated value that doesn't match
                    # the persisted logical_date in the database, causing the trigger to never find the run.
                    execution_dates=None,
                    run_ids=[drte.dag_run_id],
                    poll_interval=drte.poke_interval,
                ),
                method_name="execute_complete",
            )
            return _defer_task(defer, ti, log)
        # NOTE(review): there is no overall timeout here; if the run never reaches an allowed or
        # failed state, this loop keeps polling (sleeping poke_interval seconds each time).
        while True:
            log.info(
                "Waiting for dag run to complete execution in allowed state.",
                dag_id=drte.trigger_dag_id,
                run_id=drte.dag_run_id,
                allowed_state=drte.allowed_states,
            )
            time.sleep(drte.poke_interval)

            comms_msg = SUPERVISOR_COMMS.send(
                GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
            )
            if TYPE_CHECKING:
                assert isinstance(comms_msg, DagRunStateResult)
            # Failed states are checked first, so a state listed in both lists counts as a failure.
            if comms_msg.state in drte.failed_states:
                log.error(
                    "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                msg = TaskState(
                    state=TaskInstanceState.FAILED,
                    end_date=datetime.now(tz=timezone.utc),
                    rendered_map_index=ti.rendered_map_index,
                )
                state = TaskInstanceState.FAILED
                return msg, state
            if comms_msg.state in drte.allowed_states:
                log.info(
                    "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                break
            log.debug(
                "DagRun not yet in allowed or failed state.",
                dag_id=drte.trigger_dag_id,
                state=comms_msg.state,
            )
    else:
        # Fire-and-forget mode: wait_for_completion=False
        if drte.deferrable:
            log.info(
                "Ignoring deferrable=True because wait_for_completion=False. "
                "Task will complete immediately without waiting for the triggered DAG run.",
                trigger_dag_id=drte.trigger_dag_id,
            )

    return _handle_current_task_success(context, ti)
1351
1352
def _run_task_state_change_callbacks(
    task: BaseOperator,
    kind: Literal[
        "on_execute_callback",
        "on_failure_callback",
        "on_success_callback",
        "on_retry_callback",
        "on_skipped_callback",
    ],
    context: Context,
    log: Logger,
) -> None:
    """
    Invoke, in order, every callback stored on ``task`` under the attribute named ``kind``.

    Each callback is run through ``create_executable_runner``. If one raises, the error is
    logged together with its position and identity, and the remaining callbacks still run.
    """
    registered_callbacks: Iterable[Callable[[Context], None]] = getattr(task, kind)
    for position, registered in enumerate(registered_callbacks):
        try:
            runner = create_executable_runner(registered, context_get_outlet_events(context), logger=log)
            runner.run(context)
        except Exception:
            log.exception("Failed to run task callback", kind=kind, index=position, callback=registered)
1371
1372
1373def _send_error_email_notification(
1374 task: BaseOperator | MappedOperator,
1375 ti: RuntimeTaskInstance,
1376 context: Context,
1377 error: BaseException | str | None,
1378 log: Logger,
1379) -> None:
1380 """Send email notification for task errors using SmtpNotifier."""
1381 try:
1382 from airflow.providers.smtp.notifications.smtp import SmtpNotifier
1383 except ImportError:
1384 log.error(
1385 "Failed to send task failure or retry email notification: "
1386 "`apache-airflow-providers-smtp` is not installed. "
1387 "Install this provider to enable email notifications."
1388 )
1389 return
1390
1391 if not task.email:
1392 return
1393
1394 subject_template_file = conf.get("email", "subject_template", fallback=None)
1395
1396 # Read the template file if configured
1397 if subject_template_file and Path(subject_template_file).exists():
1398 subject = Path(subject_template_file).read_text()
1399 else:
1400 # Fallback to default
1401 subject = "Airflow alert: {{ti}}"
1402
1403 html_content_template_file = conf.get("email", "html_content_template", fallback=None)
1404
1405 # Read the template file if configured
1406 if html_content_template_file and Path(html_content_template_file).exists():
1407 html_content = Path(html_content_template_file).read_text()
1408 else:
1409 # Fallback to default
1410 # For reporting purposes, we report based on 1-indexed,
1411 # not 0-indexed lists (i.e. Try 1 instead of Try 0 for the first attempt).
1412 html_content = (
1413 "Try {{try_number}} out of {{max_tries + 1}}<br>"
1414 "Exception:<br>{{exception_html}}<br>"
1415 'Log: <a href="{{ti.log_url}}">Link</a><br>'
1416 "Host: {{ti.hostname}}<br>"
1417 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
1418 )
1419
1420 # Add exception_html to context for template rendering
1421 import html
1422
1423 exception_html = html.escape(str(error)).replace("\n", "<br>")
1424 additional_context = {
1425 "exception": error,
1426 "exception_html": exception_html,
1427 "try_number": ti.try_number,
1428 "max_tries": ti.max_tries,
1429 }
1430 email_context = {**context, **additional_context}
1431 to_emails = task.email
1432 if not to_emails:
1433 return
1434
1435 try:
1436 notifier = SmtpNotifier(
1437 to=to_emails,
1438 subject=subject,
1439 html_content=html_content,
1440 from_email=conf.get("email", "from_email", fallback="airflow@airflow"),
1441 )
1442 notifier(email_context)
1443 except Exception:
1444 log.exception("Failed to send email notification")
1445
1446
def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
    """
    Execute the task (optionally with a timeout) and return whatever it returned.

    When the server context carries a ``next_method`` (the task is resuming after a deferral),
    ``task.resume_execution`` is called with the deserialized ``next_kwargs`` instead of
    ``task.execute``. Around that call this function:

    * exports Airflow context variables to ``os.environ``;
    * runs the pre-execute hooks and ``on_execute_callback`` callbacks before it;
    * runs the post-execute hooks after it, each receiving the result.

    Pushing the result to XCom is left to the caller.

    :return: The value returned by ``execute`` / ``resume_execution``.
    :raises AirflowTaskTimeout: when ``execution_timeout`` is non-positive or is exceeded.
        ``task.on_kill()`` is called before re-raising.
    """
    task = ti.task
    execute = task.execute

    if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method):
        from airflow.sdk.serde import deserialize

        next_kwargs_data = ti._ti_context_from_server.next_kwargs or {}
        try:
            if TYPE_CHECKING:
                assert isinstance(next_kwargs_data, dict)
            kwargs = deserialize(next_kwargs_data)
        except (ImportError, KeyError, AttributeError, TypeError):
            # Fall back to the legacy serializer for kwargs that serde cannot decode.
            from airflow.serialization.serialized_objects import BaseSerialization

            kwargs = BaseSerialization.deserialize(next_kwargs_data)

        if TYPE_CHECKING:
            assert isinstance(kwargs, dict)
        execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)

    ctx = contextvars.copy_context()
    # Populate the context var so ExecutorSafeguard doesn't complain
    ctx.run(ExecutorSafeguard.tracker.set, task)

    # Export context in os.environ to make it available for operators to use.
    airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
    os.environ.update(airflow_context_vars)

    outlet_events = context_get_outlet_events(context)

    if (pre_execute_hook := task._pre_execute_hook) is not None:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
    # Only call pre_execute when the operator overrides the BaseOperator no-op.
    if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)

    _run_task_state_change_callbacks(task, "on_execute_callback", context, log)

    if task.execution_timeout:
        from airflow.sdk.execution_time.timeout import timeout

        # TODO: handle timeout in case of deferral
        timeout_seconds = task.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run task in timeout wrapper
            with timeout(timeout_seconds):
                result = ctx.run(execute, context=context)
        except AirflowTaskTimeout:
            task.on_kill()
            raise
    else:
        result = ctx.run(execute, context=context)

    if (post_execute_hook := task._post_execute_hook) is not None:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
    if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute:
        # ``BaseOperator.post_execute(context, result=None)`` is documented to receive the value
        # returned by execute(). Pass it through, as the ``_post_execute_hook`` call above does;
        # previously overriding operators always saw ``result=None``.
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)

    return result
1510
1511
1512def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None:
1513 """Render named map index if the Dag author defined map_index_template at the task level."""
1514 if (template := context.get("map_index_template")) is None:
1515 return None
1516 log.debug("Rendering map_index_template", template_length=len(template))
1517 jinja_env = ti.task.dag.get_template_env()
1518 rendered_map_index = jinja_env.from_string(template).render(context)
1519 log.debug("Map index rendered", length=len(rendered_map_index))
1520 return rendered_map_index
1521
1522
1523def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
1524 """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result."""
1525 if ti.task.do_xcom_push:
1526 xcom_value = result
1527 else:
1528 xcom_value = None
1529
1530 has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None
1531 if xcom_value is None:
1532 if not ti.is_mapped and has_mapped_dep:
1533 # Uhoh, a downstream mapped task depends on us to push something to map over
1534 from airflow.sdk.exceptions import XComForMappingNotPushed
1535
1536 raise XComForMappingNotPushed()
1537 return
1538
1539 mapped_length: int | None = None
1540 if not ti.is_mapped and has_mapped_dep:
1541 from airflow.sdk.definitions.mappedoperator import is_mappable_value
1542 from airflow.sdk.exceptions import UnmappableXComTypePushed
1543
1544 if not is_mappable_value(xcom_value):
1545 raise UnmappableXComTypePushed(xcom_value)
1546 mapped_length = len(xcom_value)
1547
1548 log.info("Pushing xcom", ti=ti)
1549
1550 # If the task has multiple outputs, push each output as a separate XCom.
1551 if ti.task.multiple_outputs:
1552 if not isinstance(xcom_value, Mapping):
1553 raise TypeError(
1554 f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
1555 )
1556 for key in xcom_value.keys():
1557 if not isinstance(key, str):
1558 raise TypeError(
1559 "Returned dictionary keys must be strings when using "
1560 f"multiple_outputs, found {key} ({type(key)}) instead"
1561 )
1562 for k, v in result.items():
1563 ti.xcom_push(k, v)
1564
1565 _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length)
1566
1567
def finalize(
    ti: RuntimeTaskInstance,
    state: TaskInstanceState,
    context: Context,
    log: Logger,
    error: BaseException | None = None,
):
    """
    Run post-execution bookkeeping for a task that ended in ``state``.

    In order, this:

    * records duration metrics (only when both start and end dates are set);
    * pushes an XCom for each operator extra link;
    * re-sends rendered template fields when ``overwrite_rtif_after_execution`` is set;
    * runs the state-specific task callbacks and listener hooks;
    * sends a failure/retry email when the task is configured for it;
    * notifies listeners that the task runner is stopping.

    Errors from extra-link pushes, callbacks and listener hooks are logged, not raised.

    :param ti: The task instance that just ran.
    :param state: The state ``run()`` ended in; selects which callbacks/listeners fire.
    :param context: The template context, passed to callbacks and the email notifier.
    :param log: Logger for this task.
    :param error: The exception behind a failure/retry, if any.
    """
    # Record task duration metrics for all terminal states
    if ti.start_date and ti.end_date:
        duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000
        stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

        Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms)
        Stats.timing("task.duration", duration_ms, tags=stats_tags)

    task = ti.task
    # Pushing xcom for each operator extra links defined on the operator only.
    for oe in task.operator_extra_links:
        try:
            link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key  # type: ignore[arg-type]
            log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
            _xcom_push_to_db(ti, key=xcom_key, value=link)
        except Exception:
            log.exception(
                "Failed to push an xcom for task operator extra link",
                link_name=oe.name,
                xcom_key=oe.xcom_key,
                ti=ti,
            )

    # Presumably for operators whose templated fields change during execution: re-serialize them
    # so the stored rendered fields reflect post-execution values.
    if getattr(ti.task, "overwrite_rtif_after_execution", False):
        log.debug("Overwriting Rendered template fields.")
        if ti.task.template_fields:
            SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)))

    log.debug("Running finalizers", ti=ti)
    if state == TaskInstanceState.SUCCESS:
        _run_task_state_change_callbacks(task, "on_success_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_success(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.SKIPPED:
        _run_task_state_change_callbacks(task, "on_skipped_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_skipped(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.UP_FOR_RETRY:
        _run_task_state_change_callbacks(task, "on_retry_callback", context, log)
        # Retries are reported to listeners through the same on_task_instance_failed hook.
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_retry and task.email:
            _send_error_email_notification(task, ti, context, error, log)
    elif state == TaskInstanceState.FAILED:
        _run_task_state_change_callbacks(task, "on_failure_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_failure and task.email:
            _send_error_email_notification(task, ti, context, error, log)

    # Fired for every state, including ones not matched above (e.g. DEFERRED).
    try:
        get_listener_manager().hook.before_stopping(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")
1645
1646
def main():
    """
    Entrypoint of the task runner process.

    Steps:

    * creates the supervisor comms channel;
    * runs ``startup()``;
    * while holding the Dag bundle version lock, runs the task and finalizes it.

    Exits with status 2 on Ctrl-C and 1 on any other unhandled exception. The comms socket
    is closed on the way out in every case.
    """
    log = structlog.get_logger(logger_name="task")

    global SUPERVISOR_COMMS
    SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)

    try:
        ti, context, log = startup()
        with BundleVersionLock(
            bundle_name=ti.bundle_instance.name,
            bundle_version=ti.bundle_instance.version,
        ):
            state, _, error = run(ti, context, log)
            context["exception"] = error
            finalize(ti, state, context, log, error)
    except KeyboardInterrupt:
        log.exception("Ctrl-c hit")
        # sys.exit, not the ``exit`` builtin: the latter is injected by the ``site`` module and
        # is missing under ``python -S`` and in some embedded/frozen interpreters.
        sys.exit(2)
    except Exception:
        log.exception("Top level error")
        sys.exit(1)
    finally:
        # Ensure the request socket is closed on the child side in all circumstances
        # before the process fully terminates.
        if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket:
            with suppress(Exception):
                SUPERVISOR_COMMS.socket.close()
1674
1675
def reinit_supervisor_comms() -> None:
    """
    Re-initialize supervisor comms and logging channel in subprocess.

    This is not needed for most cases, but is used when either we re-launch the process via sudo for
    run_as_user, or from inside the python code in a virtualenv (et al.) operator to re-connect so those tasks
    can continue to access variables etc.

    If no comms object exists yet, one is built around the socket whose file descriptor is read
    from ``__AIRFLOW_SUPERVISOR_FD``. The supervisor is then asked to resend the logging FD. If
    it does, logging is reconfigured to write JSON to that FD; otherwise a warning goes to stderr.
    """
    import socket

    global SUPERVISOR_COMMS

    if "SUPERVISOR_COMMS" not in globals():
        SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](
            log=structlog.get_logger(logger_name="task"),
            socket=socket.socket(fileno=int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))),
        )

    fd_response = SUPERVISOR_COMMS.send(ResendLoggingFD())
    if not isinstance(fd_response, SentFDs):
        print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)
        return

    from airflow.sdk.log import configure_logging

    log_stream = os.fdopen(fd_response.fds[0], "wb", buffering=0)
    configure_logging(json_output=True, output=log_stream, sending_to_supervisor=True)
1702
1703
# When this module is executed directly as a script, run the task process entrypoint.
if __name__ == "__main__":
    main()