1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""The entrypoint for the actual task execution process."""
19
20from __future__ import annotations
21
22import contextlib
23import contextvars
24import functools
25import os
26import sys
27import time
28from collections.abc import Callable, Iterable, Iterator, Mapping
29from contextlib import suppress
30from datetime import datetime, timezone
31from itertools import product
32from pathlib import Path
33from typing import TYPE_CHECKING, Annotated, Any, Literal
34from urllib.parse import quote
35
36import attrs
37import lazy_object_proxy
38import structlog
39from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter
40
41from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
42from airflow.dag_processing.bundles.manager import DagBundlesManager
43from airflow.sdk.api.client import get_hostname, getuser
44from airflow.sdk.api.datamodels._generated import (
45 AssetProfile,
46 DagRun,
47 PreviousTIResponse,
48 TaskInstance,
49 TaskInstanceState,
50 TIRunContext,
51)
52from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard
53from airflow.sdk.bases.xcom import BaseXCom
54from airflow.sdk.configuration import conf
55from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
56from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
57from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
58from airflow.sdk.definitions.mappedoperator import MappedOperator
59from airflow.sdk.definitions.param import process_params
60from airflow.sdk.exceptions import (
61 AirflowException,
62 AirflowInactiveAssetInInletOrOutletException,
63 AirflowRuntimeError,
64 AirflowTaskTimeout,
65 ErrorType,
66 TaskDeferred,
67)
68from airflow.sdk.execution_time.callback_runner import create_executable_runner
69from airflow.sdk.execution_time.comms import (
70 AssetEventDagRunReferenceResult,
71 CommsDecoder,
72 DagRunStateResult,
73 DeferTask,
74 DRCount,
75 ErrorResponse,
76 GetDagRunState,
77 GetDRCount,
78 GetPreviousDagRun,
79 GetPreviousTI,
80 GetTaskBreadcrumbs,
81 GetTaskRescheduleStartDate,
82 GetTaskStates,
83 GetTICount,
84 InactiveAssetsResult,
85 PreviousDagRunResult,
86 PreviousTIResult,
87 RescheduleTask,
88 ResendLoggingFD,
89 RetryTask,
90 SentFDs,
91 SetRenderedFields,
92 SetRenderedMapIndex,
93 SkipDownstreamTasks,
94 StartupDetails,
95 SucceedTask,
96 TaskBreadcrumbsResult,
97 TaskRescheduleStartDate,
98 TaskState,
99 TaskStatesResult,
100 TICount,
101 ToSupervisor,
102 ToTask,
103 TriggerDagRun,
104 ValidateInletsAndOutlets,
105)
106from airflow.sdk.execution_time.context import (
107 ConnectionAccessor,
108 InletEventsAccessors,
109 MacrosAccessor,
110 OutletEventAccessors,
111 TriggeringAssetEventsAccessor,
112 VariableAccessor,
113 context_get_outlet_events,
114 context_to_airflow_vars,
115 get_previous_dagrun_success,
116 set_current_context,
117)
118from airflow.sdk.execution_time.sentry import Sentry
119from airflow.sdk.execution_time.xcom import XCom
120from airflow.sdk.listener import get_listener_manager
121from airflow.sdk.observability.stats import Stats
122from airflow.sdk.timezone import coerce_datetime
123from airflow.triggers.base import BaseEventTrigger
124from airflow.triggers.callback import CallbackTrigger
125
126if TYPE_CHECKING:
127 import jinja2
128 from pendulum.datetime import DateTime
129 from structlog.typing import FilteringBoundLogger as Logger
130
131 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
132 from airflow.sdk.definitions.context import Context
133 from airflow.sdk.exceptions import DagRunTriggerException
134 from airflow.sdk.types import OutletEventAccessorsProtocol
135
136
class TaskRunnerMarker:
    """
    Sentinel object handed to listener hooks as their ``component`` argument.

    Passing an instance of this class lets listener implementations tell that the hook
    was invoked by the task runner process rather than by some other Airflow component.
    """
139
140
141# TODO: Move this entire class into a separate file:
142# `airflow/sdk/execution_time/task_instance.py`
143# or `airflow/sdk/execution_time/runtime_ti.py`
class RuntimeTaskInstance(TaskInstance):
    """
    Task instance used inside the task execution process.

    Extends the ``TaskInstance`` API datamodel with the parsed task object, the Dag bundle it
    was loaded from and the run context received from the API server. It also provides helpers
    (template context, XCom access, queries about other runs / task instances), most of which
    talk to the supervisor process through the module-level ``SUPERVISOR_COMMS`` channel.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    task: BaseOperator
    bundle_instance: BaseDagBundle
    _cached_template_context: Context | None = None
    """The Task Instance context. This is used to cache get_template_context."""

    _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
    """The Task Instance context from the API server, if any."""

    max_tries: int = 0
    """The maximum number of retries for the task."""

    start_date: AwareDatetime
    """Start date of the task instance."""

    end_date: AwareDatetime | None = None

    state: TaskInstanceState | None = None

    is_mapped: bool | None = None
    """True if the original task was mapped."""

    rendered_map_index: str | None = None

    sentry_integration: str = ""

    def __rich_repr__(self):
        yield "id", self.id
        yield "task_id", self.task_id
        yield "dag_id", self.dag_id
        yield "run_id", self.run_id
        yield "max_tries", self.max_tries
        yield "task", type(self.task)
        yield "start_date", self.start_date

    __rich_repr__.angular = True  # type: ignore[attr-defined]

    def get_template_context(self) -> Context:
        """
        Build, cache and return the Jinja template context for this task instance.

        The first call creates the context dict and stores it in ``_cached_template_context``.
        Later calls reuse that same dict and update it in place, so all callers share one
        context object. When a run context from the API server is available, keys derived
        from the Dag run (``dag_run``, ``triggering_asset_events``, logical-date and
        data-interval keys, ...) are merged in.

        :return: The (shared) template context dict.
        """
        # TODO: Move this to `airflow.sdk.execution_time.context`
        # once we port the entire context logic from airflow/utils/context.py ?
        from airflow.sdk.plugins_manager import integrate_macros_plugins

        integrate_macros_plugins()

        dag_run_conf: dict[str, Any] | None = None
        if from_server := self._ti_context_from_server:
            dag_run_conf = from_server.dag_run.conf or dag_run_conf

        # Params are processed on every call (not only when the context is first built), so
        # invalid params raise even when a cached context already exists.
        validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)

        # Cache the context object, which ensures that all calls to get_template_context
        # are operating on the same context object.
        self._cached_template_context: Context = self._cached_template_context or {
            # From the Task Execution interface
            "dag": self.task.dag,
            "inlets": self.task.inlets,
            "map_index_template": self.task.map_index_template,
            "outlets": self.task.outlets,
            "run_id": self.run_id,
            "task": self.task,
            "task_instance": self,
            "ti": self,
            "outlet_events": OutletEventAccessors(),
            "inlet_events": InletEventsAccessors(self.task.inlets),
            "macros": MacrosAccessor(),
            "params": validated_params,
            # TODO: Make this go through Public API longer term.
            # "test_mode": task_instance.test_mode,
            "var": {
                "json": VariableAccessor(deserialize_json=True),
                "value": VariableAccessor(deserialize_json=False),
            },
            "conn": ConnectionAccessor(),
        }
        if from_server:
            dag_run = from_server.dag_run
            context_from_server: Context = {
                # TODO: Assess if we need to pass these through timezone.coerce_datetime
                "dag_run": dag_run,  # type: ignore[typeddict-item]  # Removable after #46522
                "triggering_asset_events": TriggeringAssetEventsAccessor.build(
                    AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
                    for event in dag_run.consumed_asset_events
                ),
                "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
                "task_reschedule_count": from_server.task_reschedule_count or 0,
                # Lazy proxies: the supervisor is only queried if a template actually uses these.
                "prev_start_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date)
                ),
                "prev_end_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
                ),
            }
            self._cached_template_context.update(context_from_server)

            if logical_date := coerce_datetime(dag_run.logical_date):
                if TYPE_CHECKING:
                    assert isinstance(logical_date, DateTime)
                ds = logical_date.strftime("%Y-%m-%d")
                ds_nodash = ds.replace("-", "")
                ts = logical_date.isoformat()
                ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
                ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
                # logical_date and data_interval either coexist or be None together
                self._cached_template_context.update(
                    {
                        # keys that depend on logical_date
                        "logical_date": logical_date,
                        "ds": ds,
                        "ds_nodash": ds_nodash,
                        # Overrides the run_id-based key set above when a logical date exists.
                        "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}",
                        "ts": ts,
                        "ts_nodash": ts_nodash,
                        "ts_nodash_with_tz": ts_nodash_with_tz,
                        # keys that depend on data_interval
                        "data_interval_end": coerce_datetime(dag_run.data_interval_end),
                        "data_interval_start": coerce_datetime(dag_run.data_interval_start),
                        "prev_data_interval_start_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start)
                        ),
                        "prev_data_interval_end_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end)
                        ),
                    }
                )

            if from_server.upstream_map_indexes is not None:
                # We stash this in here for later use, but we purposefully don't want to document its
                # existence. Should this be a private attribute on RuntimeTI instead perhaps?
                setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)

        return self._cached_template_context

    def render_templates(
        self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
    ) -> BaseOperator:
        """
        Render templates in the operator fields.

        If the task was originally mapped, this may replace ``self.task`` with
        the unmapped, fully rendered BaseOperator. The original ``self.task``
        before replacement is returned.

        :param context: Template context to render with; when falsy, ``get_template_context()`` is used.
        :param jinja_env: Optional Jinja environment forwarded to ``render_template_fields``.
        :return: The task object as it was before rendering.
        """
        if not context:
            context = self.get_template_context()
        original_task = self.task

        if TYPE_CHECKING:
            assert context

        ti = context["ti"]

        if TYPE_CHECKING:
            assert original_task
            assert self.task
            assert ti.task

        # If self.task is mapped, this call replaces self.task to point to the
        # unmapped BaseOperator created by this function! This is because the
        # MappedOperator is useless for template rendering, and we need to be
        # able to access the unmapped task instead.
        self.task.render_template_fields(context, jinja_env)
        self.is_mapped = original_task.is_mapped
        return original_task

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = BaseXCom.XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
        default: Any = None,
        run_id: str | None = None,
    ) -> Any:
        """
        Pull XComs either from the API server (BaseXCom) or from the custom XCOM backend if configured.

        The pull can be filtered optionally by certain criteria.

        :param key: A key for the XCom. If provided, only XComs with matching
            keys will be returned. The default key is ``'return_value'``, also
            available as constant ``XCOM_RETURN_KEY``. This key is automatically
            given to XComs returned by tasks (as opposed to being pushed
            manually).
        :param task_ids: Only XComs from tasks with matching ids will be
            pulled. If *None* (default), the task_id of the calling task is used.
        :param dag_id: If provided, only pulls XComs from this Dag. If *None*
            (default), the Dag of the calling task is used.
        :param map_indexes: If provided, only pull XComs with matching indexes.
            If not specified, XComs from all map indexes of each task are pulled
            (see below for details).
        :param include_prior_dates: If False, only XComs from the current
            logical_date are returned. If *True*, XComs from previous dates
            are returned as well.
        :param default: Value used in place of any task / map index for which no
            XCom is found.
        :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id.
            If *None* (default), the run_id of the calling task is used.

        When pulling one single task (``task_id`` is *None* or a str) without
        specifying ``map_indexes``, all XComs of that task are fetched; if there is
        exactly one entry it is returned as a single value, otherwise a list is returned.

        When ``map_indexes`` is given, only the matching map indexes are pulled.
        Setting ``map_indexes`` to *None* will pull XCom as it would from a non
        mapped task.

        In either case, ``default`` (*None* if not specified) is returned if no
        matching XComs are found.

        When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
        a non-str iterable), a list of matching XComs is returned. Elements in
        the list are ordered by item ordering in ``task_id`` and ``map_index``.
        """
        if dag_id is None:
            dag_id = self.dag_id
        if run_id is None:
            run_id = self.run_id

        single_task_requested = isinstance(task_ids, (str, type(None)))
        single_map_index_requested = isinstance(map_indexes, (int, type(None)))

        if task_ids is None:
            # default to the current task if not provided
            task_ids = [self.task_id]
        elif isinstance(task_ids, str):
            task_ids = [task_ids]

        # If map_indexes is not specified, pull xcoms from all map indexes for each task
        if not is_arg_set(map_indexes):
            xcoms: list[Any] = []
            for t_id in task_ids:
                values = XCom.get_all(
                    run_id=run_id,
                    key=key,
                    task_id=t_id,
                    dag_id=dag_id,
                    include_prior_dates=include_prior_dates,
                )

                if values is None:
                    # Honour ``default`` here too, matching the explicit-map_indexes branch below
                    # and the documented contract (previously ``None`` was always used).
                    xcoms.append(default)
                else:
                    xcoms.extend(values)
            # For single task pulling from unmapped task, return single value
            if single_task_requested and len(xcoms) == 1:
                return xcoms[0]
            return xcoms

        # Original logic when map_indexes is explicitly specified
        map_indexes_iterable: Iterable[int | None] = []
        if isinstance(map_indexes, int) or map_indexes is None:
            map_indexes_iterable = [map_indexes]
        elif isinstance(map_indexes, Iterable):
            map_indexes_iterable = map_indexes
        else:
            raise TypeError(
                f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}"
            )

        xcoms = []
        for t_id, m_idx in product(task_ids, map_indexes_iterable):
            value = XCom.get_one(
                run_id=run_id,
                key=key,
                task_id=t_id,
                dag_id=dag_id,
                map_index=m_idx,
                include_prior_dates=include_prior_dates,
            )
            if value is None:
                xcoms.append(default)
            else:
                xcoms.append(value)

        if single_task_requested and single_map_index_requested:
            return xcoms[0]

        return xcoms

    def xcom_push(self, key: str, value: Any):
        """
        Make an XCom available for tasks to pull.

        :param key: Key to store the value under.
        :param value: Value to store. Only JSON-serializable values may be used.
        """
        _xcom_push(self, key, value)

    def get_relevant_upstream_map_indexes(
        self, upstream: BaseOperator, ti_count: int | None, session: Any
    ) -> int | range | None:
        """Not implemented yet in the Task SDK; always returns ``None``."""
        # TODO: Implement this method
        return None

    def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
        """Get the first reschedule date for the task instance if found, none otherwise."""
        if context.get("task_reschedule_count", 0) == 0:
            # If the task has not been rescheduled, there is no need to ask the supervisor
            return None

        max_tries: int = self.max_tries
        retries: int = self.task.retries or 0
        # NOTE(review): presumably the try number at which the current series of attempts
        # started (max_tries grows by `retries` on each clear) — confirm.
        first_try_number = max_tries - retries + 1

        log = structlog.get_logger(logger_name="task")

        log.debug("Requesting first reschedule date from supervisor")

        response = SUPERVISOR_COMMS.send(
            msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskRescheduleStartDate)

        return response.start_date

    def get_previous_dagrun(self, state: str | None = None) -> DagRun | None:
        """
        Return the previous Dag run before the given logical date, optionally filtered by state.

        Returns ``None`` without asking the supervisor when the template context has no
        ``dag_run`` or the current run has no logical date.
        """
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")

        log.debug("Getting previous Dag run", dag_run=dag_run)

        if dag_run is None:
            return None

        if dag_run.logical_date is None:
            return None

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state)
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousDagRunResult)

        return response.dag_run

    def get_previous_ti(
        self,
        state: TaskInstanceState | None = None,
        logical_date: AwareDatetime | None = None,
        map_index: int = -1,
    ) -> PreviousTIResponse | None:
        """
        Return the previous task instance matching the given criteria.

        :param state: Filter by TaskInstance state
        :param logical_date: Filter by logical date (returns TI before this date)
        :param map_index: Filter by map_index (defaults to -1 for non-mapped tasks)
        :return: Previous task instance or None if not found
        """
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")
        log.debug("Getting previous task instance", task_id=self.task_id, state=state)

        # Use current dag run's logical_date if not provided
        effective_logical_date = logical_date
        if effective_logical_date is None and dag_run and dag_run.logical_date:
            effective_logical_date = dag_run.logical_date

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousTI(
                dag_id=self.dag_id,
                task_id=self.task_id,
                logical_date=effective_logical_date,
                map_index=map_index,
                state=state,
            )
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousTIResult)

        return response.task_instance

    @staticmethod
    def get_ti_count(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of task instances matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTICount(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TICount)

        return response.count

    @staticmethod
    def get_task_states(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> dict[str, Any]:
        """Return the task states matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTaskStates(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskStatesResult)

        return response.task_states

    @staticmethod
    def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]:
        """Return task breadcrumbs for the given dag run."""
        response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id))
        if TYPE_CHECKING:
            assert isinstance(response, TaskBreadcrumbsResult)
        return response.breadcrumbs

    @staticmethod
    def get_dr_count(
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of Dag runs matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetDRCount(
                dag_id=dag_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, DRCount)

        return response.count

    @staticmethod
    def get_dagrun_state(dag_id: str, run_id: str) -> str:
        """Return the state of the Dag run with the given Run ID."""
        response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id))

        if TYPE_CHECKING:
            assert isinstance(response, DagRunStateResult)

        return response.state

    @property
    def log_url(self) -> str:
        """
        UI URL for this task instance, built from ``[api] base_url``.

        The run id is URL-quoted. ``/mapped/<map_index>`` is appended for non-negative map
        indexes and ``?try_number=<n>`` for positive try numbers.
        """
        run_id = quote(self.run_id)
        base_url = conf.get("api", "base_url", fallback="http://localhost:8080/")
        map_index_value = self.map_index
        map_index = (
            f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else ""
        )
        try_number_value = self.try_number
        try_number = (
            f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else ""
        )
        _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}"
        return _log_uri

    @property
    def mark_success_url(self) -> str:
        """URL to mark TI success (currently the same URL as ``log_url``)."""
        return self.log_url
642
643
def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
    """
    Store an XCom for ``ti`` via ``XCom.set`` (which uses the custom XCom backend when configured).

    :param ti: Task instance whose dag_id / task_id / run_id / map_index identify the XCom.
    :param key: XCom key.
    :param value: Value to store.
    :param mapped_length: Forwarded to ``XCom.set`` as ``_mapped_length``.
    """
    # Intentionally private: SDK consumers must not be able to set `mapped_length` themselves.
    identity = {
        "dag_id": ti.dag_id,
        "task_id": ti.task_id,
        "run_id": ti.run_id,
        "map_index": ti.map_index,
    }
    XCom.set(key=key, value=value, _mapped_length=mapped_length, **identity)
658
659
def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None:
    """
    Store an XCom for ``ti`` directly in the metadata DB, skipping any custom XCom backend.

    :param ti: Task instance whose dag_id / task_id / run_id / map_index identify the XCom.
    :param key: XCom key.
    :param value: Value to store.
    """
    identity = {
        "dag_id": ti.dag_id,
        "task_id": ti.task_id,
        "run_id": ti.run_id,
        "map_index": ti.map_index,
    }
    XCom._set_xcom_in_db(key=key, value=value, **identity)
670
671
def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
    """
    Load the Dag for the task described by ``what`` and wrap the task in a ``RuntimeTaskInstance``.

    Initializes the Dag bundle named in the startup message, makes its root importable, parses
    the Dag file with ``DagBag`` and looks up the Dag and the task. The process exits with
    status 1 if either cannot be found.

    :param what: Startup message received from the supervisor.
    :param log: Logger used to report lookup failures.
    :raises TypeError: If the task found is neither a ``BaseOperator`` nor a ``MappedOperator``.
    """
    # TODO: Task-SDK:
    # Using DagBag here is about 98% wrong, but it'll do for now
    from airflow.dag_processing.dagbag import DagBag

    info = what.bundle_info
    bundle = DagBundlesManager().get_bundle(name=info.name, version=info.version)
    bundle.initialize()

    # Make the bundle root importable so files in the same bundle can share util modules.
    root = os.fspath(bundle.path)
    if root not in sys.path:
        sys.path.append(root)

    dag_file = os.fspath(Path(bundle.path, what.dag_rel_path))
    dag_bag = DagBag(
        dag_folder=dag_file,
        include_examples=False,
        safe_mode=False,
        load_op_links=False,
        bundle_name=info.name,
    )
    if TYPE_CHECKING:
        assert what.ti.dag_id

    dag = dag_bag.dags.get(what.ti.dag_id)
    if dag is None:
        log.error("Dag not found during start up", dag_id=what.ti.dag_id, bundle=info, path=what.dag_rel_path)
        sys.exit(1)

    # install_loader()

    task = dag.task_dict.get(what.ti.task_id)
    if task is None:
        log.error(
            "Task not found in Dag during start up",
            dag_id=dag.dag_id,
            task_id=what.ti.task_id,
            bundle=info,
            path=what.dag_rel_path,
        )
        sys.exit(1)

    if isinstance(task, (BaseOperator, MappedOperator)):
        return RuntimeTaskInstance.model_construct(
            **what.ti.model_dump(exclude_unset=True),
            task=task,
            bundle_instance=bundle,
            _ti_context_from_server=what.ti_context,
            max_tries=what.ti_context.max_tries,
            start_date=what.start_date,
            state=TaskInstanceState.RUNNING,
            sentry_integration=what.sentry_integration,
        )

    raise TypeError(f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}")
737
738
# This global variable is used by the Connection/Variable/XCom classes, and by other parts of the task's
# execution, to send requests back to the supervisor process.
#
# Why it needs to be a global:
# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may need to make dynamic
#   requests to the parent process during task execution.
# - These calls occur in various locations, and the `CommsDecoder` instance cannot easily be passed through
#   the deeply nested execution stack.
# - Defining `SUPERVISOR_COMMS` as a global makes this communication channel readily accessible wherever it
#   is needed during task execution, without modifying every layer of the call stack.
#
# Only the type is declared here; no value is assigned at this point, so the name is bound elsewhere at
# process start-up (NOTE(review): confirm the initialisation site, e.g. alongside reinit_supervisor_comms).
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
750
751
752# State machine!
753# 1. Start up (receive details from supervisor)
754# 2. Execution (run task code, possibly send requests)
755# 3. Shutdown and report status
756
757
def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
    """
    Obtain the startup message, parse the Dag and build the runtime task instance and its context.

    The startup message either comes from the supervisor (normal start) or, in a process that
    re-executed itself under ``sudo``, from the ``_AIRFLOW__STARTUP_MSG`` environment variable.
    If the task must run as a different user and this is not already the re-executed process,
    the startup message is stored in the environment and the process is replaced via
    ``os.execvp("sudo", ...)``.

    :return: ``(ti, context, log)``. The trailing ``(None, None, None)`` is only reached if
        ``os.execvp`` unexpectedly returns.
    :raises RuntimeError: If the startup message is not a ``StartupDetails``.
    """
    # The parent sends us a StartupDetails message un-prompted; every later message is only sent
    # in response to a request from us.
    log = structlog.get_logger(logger_name="task")

    serialized_startup = (
        os.environ.get("_AIRFLOW__STARTUP_MSG")
        if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1"
        else None
    )

    if not serialized_startup:
        # Normal entry point: the supervisor pushes the startup message to us.
        msg = SUPERVISOR_COMMS._get_response()  # type: ignore[assignment]
    else:
        # Entry point of the re-executed process. Drop any Kerberos credential cache pointer so the
        # new process cannot reuse it.
        os.environ.pop("KRB5CCNAME", None)

        msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(serialized_startup)
        reinit_supervisor_comms()

        # Logged only now, after logging has been re-configured; otherwise it would end up on stdout.
        log.debug("Using serialized startup message from environment", msg=msg)

    if not isinstance(msg, StartupDetails):
        raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

    # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
    if sys.platform != "darwin":
        from setproctitle import setproctitle

        setproctitle(f"airflow worker -- {msg.ti.id}")
    else:
        log.debug("Mac OS detected, skipping setproctitle")

    try:
        get_listener_manager().hook.on_starting(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")

    with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
        ti = parse(msg, log)
        log.debug("Dag file parsed", file=msg.dag_rel_path)

    run_as_user = getattr(ti.task, "run_as_user", None) or conf.get("core", "default_impersonation", fallback=None)

    # Stay in this process unless impersonation is required and we have not re-executed yet.
    already_reexecuted = os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1"
    if already_reexecuted or not run_as_user or run_as_user == getuser():
        return ti, ti.get_template_context(), log

    # Prepare the environment for the re-executed process and hand it the supervisor socket.
    os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
    os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
    os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)

    # Import `main` from the module rather than running it with `python -m`: a `-m` run would load
    # the module as __main__, and later imports of airflow.sdk.execution_time.task_runner would get a
    # fresh copy without the initialized SUPERVISOR_COMMS global.
    reexec_code = "from airflow.sdk.execution_time.task_runner import main; main()"
    sudo_cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", reexec_code]
    log.info("Running command", command=sudo_cmd)
    os.execvp("sudo", sudo_cmd)

    # Not expected to be reached: os.execvp replaces the current process.
    return None, None, None
830
831
def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float:
    """
    Return a serializable representation of the templated field.

    If ``template_field`` is not JSON-serializable, its ``serialize()`` result (or ``str()`` of it
    when it has no ``serialize`` method) is returned. JSON-serializable values have tuples converted
    to lists and, when the top-level value is a dict, their dict keys sorted recursively so the
    representation is stable. In both cases, if the string form is longer than
    ``[core] max_templated_field_length``, a redacted, truncated string prefixed with an
    explanatory notice is returned instead.

    Uses the sdk secrets masker to redact secrets in the truncated output.

    :param template_field: The rendered value of the templated field.
    :param name: Name of the templated field; passed to ``redact`` along with the value.
    """
    import json

    from airflow.sdk._shared.secrets_masker import redact

    def is_jsonable(x):
        try:
            json.dumps(x)
        except (TypeError, OverflowError):
            return False
        else:
            return True

    def translate_tuples_to_lists(obj: Any):
        """Recursively convert tuples to lists."""
        if isinstance(obj, tuple):
            return [translate_tuples_to_lists(item) for item in obj]
        if isinstance(obj, list):
            return [translate_tuples_to_lists(item) for item in obj]
        if isinstance(obj, dict):
            return {key: translate_tuples_to_lists(value) for key, value in obj.items()}
        return obj

    def sort_dict_recursively(obj: Any) -> Any:
        """Recursively sort dictionaries to ensure consistent ordering."""
        if isinstance(obj, dict):
            try:
                items = sorted(obj.items())
            except TypeError:
                # JSON-serializable dicts may mix key types (e.g. ``{1: ..., "a": ...}``), which are
                # not mutually orderable. Fall back to a deterministic order by key type, then repr,
                # instead of crashing.
                items = sorted(obj.items(), key=lambda kv: (type(kv[0]).__name__, repr(kv[0])))
            return {k: sort_dict_recursively(v) for k, v in items}
        if isinstance(obj, list):
            return [sort_dict_recursively(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(sort_dict_recursively(item) for item in obj)
        return obj

    max_length = conf.getint("core", "max_templated_field_length")

    def truncated(serialized: Any) -> str:
        """Redact ``serialized`` and return it cut down behind the truncation notice."""
        rendered = redact(serialized, name)
        # 79 roughly accounts for the length of the notice prefix.
        # NOTE(review): for max_length < 79 this slice bound is negative and keeps most of the value.
        return (
            "Truncated. You can change this behaviour in [core]max_templated_field_length. "
            f"{rendered[: max_length - 79]!r}... "
        )

    if not is_jsonable(template_field):
        try:
            serialized = template_field.serialize()
        except AttributeError:
            serialized = str(template_field)
        if len(serialized) > max_length:
            return truncated(serialized)
        return serialized
    if not template_field and not isinstance(template_field, tuple):
        # Avoid unnecessary serialization steps for empty fields unless they are tuples
        # and need to be converted to lists
        return template_field
    template_field = translate_tuples_to_lists(template_field)
    # Sort dictionaries recursively to ensure consistent string representation
    # This prevents hash inconsistencies when dict ordering varies
    if isinstance(template_field, dict):
        template_field = sort_dict_recursively(template_field)
    serialized = str(template_field)
    if len(serialized) > max_length:
        return truncated(serialized)
    return template_field
904
905
def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]:
    """
    Serialize and redact every templated field of *task* before it is sent to the API server.

    Each name in ``task.template_fields`` is read from the task, serialized with
    ``_serialize_template_field`` and then passed through ``redact``. Redacting here, inside the task
    process, ensures secrets registered via ``mask_secret()`` on workers / dag processor are masked in the UI.

    :param task: The operator whose rendered template fields are collected.
    :return: Mapping of template field name to its serialized, redacted value.
    """
    from airflow.sdk._shared.secrets_masker import redact

    # redact() hands back the same JsonValue shape it receives, which mypy cannot infer on its own.
    return {  # type: ignore[return-value]
        field_name: redact(_serialize_template_field(getattr(task, field_name), field_name), field_name)
        for field_name in task.template_fields
    }
919
920
921def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]:
922 # Lineage can have other types of objects besides assets, so we need to process them a bit.
923 for obj in lineage_objects or ():
924 if isinstance(obj, Asset):
925 yield AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__)
926 elif isinstance(obj, AssetNameRef):
927 yield AssetProfile(name=obj.name, type=AssetNameRef.__name__)
928 elif isinstance(obj, AssetUriRef):
929 yield AssetProfile(uri=obj.uri, type=AssetUriRef.__name__)
930 elif isinstance(obj, AssetAlias):
931 yield AssetProfile(name=obj.name, type=AssetAlias.__name__)
932
933
934def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]:
935 if TYPE_CHECKING:
936 assert isinstance(events, OutletEventAccessors)
937 # We just collect everything the user recorded in the accessors.
938 # Further filtering will be done in the API server.
939 for key, accessor in events._dict.items():
940 if isinstance(key, AssetUniqueKey):
941 yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra}
942 for alias_event in accessor.asset_alias_events:
943 yield attrs.asdict(alias_event)
944
945
def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None:
    """
    Prepare the task instance for execution, right before ``execute`` runs.

    Records the hostname and replaces ``ti.task`` with the result of ``prepare_for_execution()``. It then
    renders templates and sends the rendered fields to the supervisor, and attempts an early render of
    ``map_index_template``. Finally it validates inlets/outlets and notifies ``on_task_instance_running``
    listeners.

    Exceptions from template rendering or inlet/outlet validation propagate to the caller. Failures of the
    early map-index render and of listeners are only logged.

    :param ti: The runtime task instance being prepared; mutated in place.
    :param log: Task logger.
    :param context: The (cached) execution context; its ``"task"`` entry is updated.
    :return: A message to send to the supervisor if the task must stop early; currently always ``None``.
    """
    ti.hostname = get_hostname()
    ti.task = ti.task.prepare_for_execution()
    # Since context is now cached, and calling `ti.get_template_context` will return the same dict, we want to
    # update the value of the task that is sent from there
    context["task"] = ti.task

    jinja_env = ti.task.dag.get_template_env()
    ti.render_templates(context=context, jinja_env=jinja_env)

    if rendered_fields := _serialize_rendered_fields(ti.task):
        # so that we do not call the API unnecessarily
        SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields))

    # Try to render map_index_template early with available context (will be re-rendered after execution)
    # This provides a partial label during task execution for templates using pre-execution context
    # If rendering fails here, we suppress the error since it will be re-rendered after execution
    try:
        if rendered_map_index := _render_map_index(context, ti=ti, log=log):
            ti.rendered_map_index = rendered_map_index
            log.debug("Sending early rendered map index", length=len(rendered_map_index))
            SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index))
    except Exception:
        log.debug(
            "Early rendering of map_index_template failed, will retry after task execution", exc_info=True
        )

    # Raises if any declared inlet/outlet asset is inactive.
    _validate_task_inlets_and_outlets(ti=ti, log=log)

    try:
        # TODO: Call pre execute etc.
        get_listener_manager().hook.on_task_instance_running(
            previous_state=TaskInstanceState.QUEUED, task_instance=ti
        )
    except Exception:
        # Listener failures must not prevent the task from running.
        log.exception("error calling listener")

    # No error, carry on and execute the task
    return None
985
986
987def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
988 if not ti.task.inlets and not ti.task.outlets:
989 return
990
991 inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id))
992 if TYPE_CHECKING:
993 assert isinstance(inactive_assets_resp, InactiveAssetsResult)
994 if inactive_assets := inactive_assets_resp.inactive_assets:
995 raise AirflowInactiveAssetInInletOrOutletException(
996 inactive_asset_keys=[
997 AssetUniqueKey.from_profile(asset_profile) for asset_profile in inactive_assets
998 ]
999 )
1000
1001
def _defer_task(
    defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """
    Build the ``DeferTask`` message that hands this task over to the triggerer.

    :param defer: The ``TaskDeferred`` raised by the task, carrying the trigger and how to resume.
    :param ti: The running task instance.
    :param log: Task logger.
    :return: The ``DeferTask`` message and ``TaskInstanceState.DEFERRED``.
    """
    # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?
    log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
    trigger = defer.trigger
    classpath, raw_trigger_kwargs = trigger.serialize()

    # Currently, only task-associated BaseTrigger instances may have a non-None queue,
    # and only when triggerer.queues_enabled is True.
    is_task_trigger = not isinstance(trigger, (BaseEventTrigger, CallbackTrigger))
    use_task_queue = is_task_trigger and conf.getboolean("triggerer", "queues_enabled", fallback=False)
    queue: str | None = ti.task.queue if use_task_queue else None

    from airflow.sdk.serde import serialize as serde_serialize

    serialized_trigger_kwargs = serde_serialize(raw_trigger_kwargs)
    serialized_next_kwargs = serde_serialize(defer.kwargs or {})

    if TYPE_CHECKING:
        assert isinstance(serialized_next_kwargs, dict)
        assert isinstance(serialized_trigger_kwargs, dict)

    defer_msg = DeferTask(
        classpath=classpath,
        trigger_kwargs=serialized_trigger_kwargs,
        trigger_timeout=defer.timeout,
        queue=queue,
        next_method=defer.method_name,
        next_kwargs=serialized_next_kwargs,
    )
    return defer_msg, TaskInstanceState.DEFERRED
1037
1038
@Sentry.enrich_errors
def run(
    ti: RuntimeTaskInstance,
    context: Context,
    log: Logger,
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
    """
    Run the task in this process.

    The steps are:

    * clear any XCom keys the server listed;
    * prepare the task (``_prepare``) and execute it;
    * re-render the map index and push the return value as XCom when needed;
    * map the outcome (success or any raised exception) to a terminal state plus a message for the
      supervisor.

    Whatever message was built is sent to the supervisor in the ``finally`` clause, so it is also delivered
    on the early-return path.

    :param ti: The runtime task instance to execute; its ``state`` is set before returning.
    :param context: The execution context for this task run.
    :param log: Task logger.
    :return: ``(state, msg, error)``, where:

        * ``state`` is the resulting state;
        * ``msg`` is the message sent to the supervisor, or ``None``;
        * ``error`` is the exception recorded for failure outcomes, or ``None`` otherwise.
    """
    import signal

    from airflow.sdk.exceptions import (
        AirflowFailException,
        AirflowRescheduleException,
        AirflowSensorTimeout,
        AirflowSkipException,
        AirflowTaskTerminated,
        DagRunTriggerException,
        DownstreamTasksSkipped,
        TaskDeferred,
    )

    if TYPE_CHECKING:
        assert ti.task is not None
        assert isinstance(ti.task, BaseOperator)

    # pid of the process installing the SIGTERM handler below.
    parent_pid = os.getpid()

    def _on_term(signum, frame):
        pid = os.getpid()
        # Only the process that installed the handler calls on_kill; any other process that inherited
        # the handler (e.g. via fork) ignores the signal here.
        if pid != parent_pid:
            return

        ti.task.on_kill()

    signal.signal(signal.SIGTERM, _on_term)

    msg: ToSupervisor | None = None
    state: TaskInstanceState
    error: BaseException | None = None

    try:
        # First, clear the xcom data sent from server
        if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear):
            for x in keys_to_delete:
                log.debug("Clearing XCom with key", key=x)
                XCom.delete(
                    key=x,
                    dag_id=ti.dag_id,
                    task_id=ti.task_id,
                    run_id=ti.run_id,
                    map_index=ti.map_index,
                )

        with set_current_context(context):
            # This is the earliest that we can render templates -- as if it excepts for any reason we need to
            # catch it and handle it like a normal task failure
            if early_exit := _prepare(ti, log, context):
                msg = early_exit
                ti.state = state = TaskInstanceState.FAILED
                # The finally clause below still sends `msg` to the supervisor.
                return state, msg, error

            try:
                result = _execute_task(context=context, ti=ti, log=log)
            except Exception:
                import jinja2

                # If the task failed, swallow rendering error so it doesn't mask the main error.
                with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
                    previous_rendered_map_index = ti.rendered_map_index
                    ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                    # Send update only if value changed (e.g., user set context variables during execution)
                    if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                        SUPERVISOR_COMMS.send(
                            msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)
                        )
                raise
            else:  # If the task succeeded, render normally to let rendering error bubble up.
                previous_rendered_map_index = ti.rendered_map_index
                ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                # Send update only if value changed (e.g., user set context variables during execution)
                if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                    SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index))

            _push_xcom_if_needed(result, ti, log)

            msg, state = _handle_current_task_success(context, ti)
    # The handlers below are ordered from most to least specific; the final BaseException handler
    # catches everything else (including KeyboardInterrupt) as a retryable failure.
    except DownstreamTasksSkipped as skip:
        log.info("Skipping downstream tasks.")
        tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks]
        SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip))
        msg, state = _handle_current_task_success(context, ti)
    except DagRunTriggerException as drte:
        msg, state = _handle_trigger_dag_run(drte, context, ti, log)
    except TaskDeferred as defer:
        msg, state = _defer_task(defer, ti, log)
    except AirflowSkipException as e:
        if e.args:
            log.info("Skipping task.", reason=e.args[0])
        msg = TaskState(
            state=TaskInstanceState.SKIPPED,
            end_date=datetime.now(tz=timezone.utc),
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.SKIPPED
    except AirflowRescheduleException as reschedule:
        log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
        msg = RescheduleTask(
            reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
        )
        state = TaskInstanceState.UP_FOR_RESCHEDULE
    except (AirflowFailException, AirflowSensorTimeout) as e:
        # If AirflowFailException is raised, task should not retry.
        # If a sensor in reschedule mode reaches timeout, task should not retry.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
        # We should allow retries if the task has defined it.
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    except AirflowTaskTerminated as e:
        # External state updates are already handled with `ti_heartbeat` and will be
        # updated already be another UI API. So, these exceptions should ideally never be thrown.
        # If these are thrown, we should mark the TI state as failed.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except SystemExit as e:
        # SystemExit needs to be retried if they are eligible.
        log.error("Task exited", exit_code=e.code)
        msg, state = _handle_current_task_failed(ti)
        error = e
    except BaseException as e:
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    finally:
        # Always report the outcome to the supervisor, including on the early-return path above.
        if msg:
            SUPERVISOR_COMMS.send(msg=msg)

    # Return the message to make unit tests easier too
    ti.state = state
    return state, msg, error
1194
1195
def _handle_current_task_success(
    context: Context,
    ti: RuntimeTaskInstance,
) -> tuple[SucceedTask, TaskInstanceState]:
    """
    Record success metrics and build the ``SucceedTask`` message for *ti*.

    Sets ``ti.end_date`` to the current UTC time and increments the success counters. It also attaches the
    task's outlet asset profiles and the outlet events recorded in *context*.

    :param context: Execution context; its ``"outlet_events"`` entry is serialized into the message.
    :param ti: The task instance that just succeeded; ``end_date`` is updated in place.
    :return: The ``SucceedTask`` message and ``TaskInstanceState.SUCCESS``.
    """
    finished_at = datetime.now(tz=timezone.utc)
    ti.end_date = finished_at

    operator_name = ti.task.__class__.__name__
    base_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

    # Per-operator counter, then the same count as a tagged metric, then the task-instance counter.
    Stats.incr(f"operator_successes_{operator_name}", tags=base_tags)
    Stats.incr("operator_successes", tags={**base_tags, "operator": operator_name})
    Stats.incr("ti_successes", tags=base_tags)

    succeed_msg = SucceedTask(
        end_date=finished_at,
        task_outlets=list(_build_asset_profiles(ti.task.outlets)),
        outlet_events=list(_serialize_outlet_events(context["outlet_events"])),
        rendered_map_index=ti.rendered_map_index,
    )
    return succeed_msg, TaskInstanceState.SUCCESS
1221
1222
def _handle_current_task_failed(
    ti: RuntimeTaskInstance,
) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
    """
    Record failure metrics and build the retry-or-fail message for *ti*.

    Sets ``ti.end_date`` to the current UTC time and increments the failure counters. If the server's
    run context says the task should retry, a ``RetryTask`` is returned; otherwise a FAILED ``TaskState``.

    :param ti: The task instance that just failed; ``end_date`` is updated in place.
    :return: ``(RetryTask, UP_FOR_RETRY)`` or ``(TaskState, FAILED)``.
    """
    finished_at = datetime.now(tz=timezone.utc)
    ti.end_date = finished_at

    operator_name = ti.task.__class__.__name__
    base_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

    # Per-operator counter, then the same count as a tagged metric, then the task-instance counter.
    Stats.incr(f"operator_failures_{operator_name}", tags=base_tags)
    Stats.incr("operator_failures", tags={**base_tags, "operator": operator_name})
    Stats.incr("ti_failures", tags=base_tags)

    server_context = ti._ti_context_from_server
    if server_context and server_context.should_retry:
        return RetryTask(end_date=finished_at), TaskInstanceState.UP_FOR_RETRY

    failed_msg = TaskState(
        state=TaskInstanceState.FAILED, end_date=finished_at, rendered_map_index=ti.rendered_map_index
    )
    return failed_msg, TaskInstanceState.FAILED
1246
1247
def _handle_trigger_dag_run(
    drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """
    Handle exception from TriggerDagRunOperator.

    Asks the supervisor to create the target Dag run. If the run already exists, the task is skipped or
    failed depending on ``skip_when_already_exists``.

    Otherwise the triggered run id is pushed to XCom under ``trigger_run_id``. When ``wait_for_completion``
    is set, the task then either defers on a ``DagStateTrigger`` (``deferrable=True``) or polls the run
    state every ``poke_interval`` seconds. Polling stops on an allowed state (task succeeds) or a failed
    state (task fails).

    :param drte: The exception carrying the trigger request and wait options.
    :param context: Execution context, used when reporting success.
    :param ti: The running task instance.
    :param log: Task logger.
    :return: The supervisor message and the resulting task instance state.
    """
    log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
    comms_msg = SUPERVISOR_COMMS.send(
        TriggerDagRun(
            dag_id=drte.trigger_dag_id,
            run_id=drte.dag_run_id,
            logical_date=drte.logical_date,
            conf=drte.conf,
            reset_dag_run=drte.reset_dag_run,
        ),
    )

    # NOTE(review): only DAGRUN_ALREADY_EXISTS is handled here; any other ErrorResponse falls through to
    # the "triggered successfully" path below -- confirm this is intended.
    if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
        if drte.skip_when_already_exists:
            log.info(
                "Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
                dag_id=drte.trigger_dag_id,
            )
            msg = TaskState(
                state=TaskInstanceState.SKIPPED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.SKIPPED
        else:
            log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
            msg = TaskState(
                state=TaskInstanceState.FAILED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.FAILED

        return msg, state

    log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)

    # Store the run id from the dag run (either created or found above) to
    # be used when creating the extra link on the webserver.
    ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)

    if drte.wait_for_completion:
        if drte.deferrable:
            from airflow.providers.standard.triggers.external_task import DagStateTrigger

            defer = TaskDeferred(
                trigger=DagStateTrigger(
                    dag_id=drte.trigger_dag_id,
                    states=drte.allowed_states + drte.failed_states,  # type: ignore[arg-type]
                    # Don't filter by execution_dates when run_ids is provided.
                    # run_id uniquely identifies a DAG run, and when reset_dag_run=True,
                    # drte.logical_date might be a newly calculated value that doesn't match
                    # the persisted logical_date in the database, causing the trigger to never find the run.
                    execution_dates=None,
                    run_ids=[drte.dag_run_id],
                    poll_interval=drte.poke_interval,
                ),
                method_name="execute_complete",
            )
            return _defer_task(defer, ti, log)
        # Blocking poll with no overall timeout: the loop only exits on a failed state (return) or an
        # allowed state (break). Failed states are checked first.
        while True:
            log.info(
                "Waiting for dag run to complete execution in allowed state.",
                dag_id=drte.trigger_dag_id,
                run_id=drte.dag_run_id,
                allowed_state=drte.allowed_states,
            )
            time.sleep(drte.poke_interval)

            comms_msg = SUPERVISOR_COMMS.send(
                GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
            )
            if TYPE_CHECKING:
                assert isinstance(comms_msg, DagRunStateResult)
            if comms_msg.state in drte.failed_states:
                log.error(
                    "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                msg = TaskState(
                    state=TaskInstanceState.FAILED,
                    end_date=datetime.now(tz=timezone.utc),
                    rendered_map_index=ti.rendered_map_index,
                )
                state = TaskInstanceState.FAILED
                return msg, state
            if comms_msg.state in drte.allowed_states:
                log.info(
                    "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                break
            log.debug(
                "DagRun not yet in allowed or failed state.",
                dag_id=drte.trigger_dag_id,
                state=comms_msg.state,
            )
    else:
        # Fire-and-forget mode: wait_for_completion=False
        if drte.deferrable:
            log.info(
                "Ignoring deferrable=True because wait_for_completion=False. "
                "Task will complete immediately without waiting for the triggered DAG run.",
                trigger_dag_id=drte.trigger_dag_id,
            )

    return _handle_current_task_success(context, ti)
1356
1357
1358def _run_task_state_change_callbacks(
1359 task: BaseOperator,
1360 kind: Literal[
1361 "on_execute_callback",
1362 "on_failure_callback",
1363 "on_success_callback",
1364 "on_retry_callback",
1365 "on_skipped_callback",
1366 ],
1367 context: Context,
1368 log: Logger,
1369) -> None:
1370 callback: Callable[[Context], None]
1371 for i, callback in enumerate(getattr(task, kind)):
1372 try:
1373 create_executable_runner(callback, context_get_outlet_events(context), logger=log).run(context)
1374 except Exception:
1375 log.exception("Failed to run task callback", kind=kind, index=i, callback=callback)
1376
1377
1378def _send_error_email_notification(
1379 task: BaseOperator | MappedOperator,
1380 ti: RuntimeTaskInstance,
1381 context: Context,
1382 error: BaseException | str | None,
1383 log: Logger,
1384) -> None:
1385 """Send email notification for task errors using SmtpNotifier."""
1386 try:
1387 from airflow.providers.smtp.notifications.smtp import SmtpNotifier
1388 except ImportError:
1389 log.error(
1390 "Failed to send task failure or retry email notification: "
1391 "`apache-airflow-providers-smtp` is not installed. "
1392 "Install this provider to enable email notifications."
1393 )
1394 return
1395
1396 if not task.email:
1397 return
1398
1399 subject_template_file = conf.get("email", "subject_template", fallback=None)
1400
1401 # Read the template file if configured
1402 if subject_template_file and Path(subject_template_file).exists():
1403 subject = Path(subject_template_file).read_text()
1404 else:
1405 # Fallback to default
1406 subject = "Airflow alert: {{ti}}"
1407
1408 html_content_template_file = conf.get("email", "html_content_template", fallback=None)
1409
1410 # Read the template file if configured
1411 if html_content_template_file and Path(html_content_template_file).exists():
1412 html_content = Path(html_content_template_file).read_text()
1413 else:
1414 # Fallback to default
1415 # For reporting purposes, we report based on 1-indexed,
1416 # not 0-indexed lists (i.e. Try 1 instead of Try 0 for the first attempt).
1417 html_content = (
1418 "Try {{try_number}} out of {{max_tries + 1}}<br>"
1419 "Exception:<br>{{exception_html}}<br>"
1420 'Log: <a href="{{ti.log_url}}">Link</a><br>'
1421 "Host: {{ti.hostname}}<br>"
1422 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
1423 )
1424
1425 # Add exception_html to context for template rendering
1426 import html
1427
1428 exception_html = html.escape(str(error)).replace("\n", "<br>")
1429 additional_context = {
1430 "exception": error,
1431 "exception_html": exception_html,
1432 "try_number": ti.try_number,
1433 "max_tries": ti.max_tries,
1434 }
1435 email_context = {**context, **additional_context}
1436 to_emails = task.email
1437 if not to_emails:
1438 return
1439
1440 try:
1441 notifier = SmtpNotifier(
1442 to=to_emails,
1443 subject=subject,
1444 html_content=html_content,
1445 from_email=conf.get("email", "from_email", fallback="airflow@airflow"),
1446 )
1447 notifier(email_context)
1448 except Exception:
1449 log.exception("Failed to send email notification")
1450
1451
def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
    """
    Execute the task (optionally under ``execution_timeout``) and return its result.

    The steps, in order, are:

    * If the server's run context names a ``next_method``, execution resumes through
      ``task.resume_execution`` with the deserialized ``next_kwargs`` instead of calling ``execute``.
    * The context is exported to ``os.environ``.
    * The pre-execute hooks and ``on_execute_callback`` callbacks run.
    * The task runs inside a copied ``contextvars`` context (so ``ExecutorSafeguard`` sees it), under a
      timeout when one is set.
    * The post-execute hooks run. Both the ``_post_execute_hook`` and an overridden ``post_execute`` receive
      ``(context, result)``.

    Pushing the returned value to XCom is left to the caller.

    :param context: The execution context passed to ``execute`` and the hooks.
    :param ti: The runtime task instance whose task is executed.
    :param log: Task logger, forwarded to the executable runners.
    :return: Whatever the task's ``execute`` (or resume method) returned.
    :raises AirflowTaskTimeout: If ``execution_timeout`` is non-positive or elapses; ``on_kill`` is called
        first.
    """
    task = ti.task
    execute = task.execute

    if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method):
        from airflow.sdk.serde import deserialize

        next_kwargs_data = ti._ti_context_from_server.next_kwargs or {}
        try:
            if TYPE_CHECKING:
                assert isinstance(next_kwargs_data, dict)
            kwargs = deserialize(next_kwargs_data)
        except (ImportError, KeyError, AttributeError, TypeError):
            # Fall back to the legacy serialization format for kwargs stored by older versions.
            from airflow.serialization.serialized_objects import BaseSerialization

            kwargs = BaseSerialization.deserialize(next_kwargs_data)

        if TYPE_CHECKING:
            assert isinstance(kwargs, dict)
        execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)

    ctx = contextvars.copy_context()
    # Populate the context var so ExecutorSafeguard doesn't complain
    ctx.run(ExecutorSafeguard.tracker.set, task)

    # Export context in os.environ to make it available for operators to use.
    airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
    os.environ.update(airflow_context_vars)

    outlet_events = context_get_outlet_events(context)

    if (pre_execute_hook := task._pre_execute_hook) is not None:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
    # Only call pre_execute when a subclass actually overrides the BaseOperator no-op.
    if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)

    _run_task_state_change_callbacks(task, "on_execute_callback", context, log)

    if task.execution_timeout:
        from airflow.sdk.execution_time.timeout import timeout

        # TODO: handle timeout in case of deferral
        timeout_seconds = task.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run task in timeout wrapper
            with timeout(timeout_seconds):
                result = ctx.run(execute, context=context)
        except AirflowTaskTimeout:
            task.on_kill()
            raise
    else:
        result = ctx.run(execute, context=context)

    if (post_execute_hook := task._post_execute_hook) is not None:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
    # Only call post_execute when a subclass overrides the BaseOperator no-op. Its signature is
    # ``post_execute(context, result=None)``, so pass the result through like the hook above;
    # otherwise overrides would always see ``result=None``.
    if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)

    return result
1515
1516
1517def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None:
1518 """Render named map index if the Dag author defined map_index_template at the task level."""
1519 if (template := context.get("map_index_template")) is None:
1520 return None
1521 log.debug("Rendering map_index_template", template_length=len(template))
1522 jinja_env = ti.task.dag.get_template_env()
1523 rendered_map_index = jinja_env.from_string(template).render(context)
1524 log.debug("Map index rendered", length=len(rendered_map_index))
1525 return rendered_map_index
1526
1527
1528def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
1529 """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result."""
1530 if ti.task.do_xcom_push:
1531 xcom_value = result
1532 else:
1533 xcom_value = None
1534
1535 has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None
1536 if xcom_value is None:
1537 if not ti.is_mapped and has_mapped_dep:
1538 # Uhoh, a downstream mapped task depends on us to push something to map over
1539 from airflow.sdk.exceptions import XComForMappingNotPushed
1540
1541 raise XComForMappingNotPushed()
1542 return
1543
1544 mapped_length: int | None = None
1545 if not ti.is_mapped and has_mapped_dep:
1546 from airflow.sdk.definitions.mappedoperator import is_mappable_value
1547 from airflow.sdk.exceptions import UnmappableXComTypePushed
1548
1549 if not is_mappable_value(xcom_value):
1550 raise UnmappableXComTypePushed(xcom_value)
1551 mapped_length = len(xcom_value)
1552
1553 log.info("Pushing xcom", ti=ti)
1554
1555 # If the task has multiple outputs, push each output as a separate XCom.
1556 if ti.task.multiple_outputs:
1557 if not isinstance(xcom_value, Mapping):
1558 raise TypeError(
1559 f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
1560 )
1561 for key in xcom_value.keys():
1562 if not isinstance(key, str):
1563 raise TypeError(
1564 "Returned dictionary keys must be strings when using "
1565 f"multiple_outputs, found {key} ({type(key)}) instead"
1566 )
1567 for k, v in result.items():
1568 ti.xcom_push(k, v)
1569
1570 _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length)
1571
1572
def finalize(
    ti: RuntimeTaskInstance,
    state: TaskInstanceState,
    context: Context,
    log: Logger,
    error: BaseException | None = None,
):
    """
    Run post-execution bookkeeping for a task that has reached *state*.

    The steps, in order, are:

    * Emit duration metrics when both start and end dates are set.
    * Push operator-extra-link XComs.
    * Optionally re-send the rendered template fields.
    * Run the state-specific callbacks and listener hooks, plus email notifications for retry/failure.
    * Call the ``before_stopping`` listener hook.

    Errors from extra links and listener hooks are logged rather than raised.

    :param ti: The runtime task instance that was run.
    :param state: The state produced by ``run``.
    :param context: The execution context, passed to callbacks and email rendering.
    :param log: Task logger.
    :param error: The exception that caused a failure/retry, if any.
    """
    # Record task duration metrics for all terminal states
    if ti.start_date and ti.end_date:
        duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000
        stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

        Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms)
        Stats.timing("task.duration", duration_ms, tags=stats_tags)

    task = ti.task
    # Pushing xcom for each operator extra links defined on the operator only.
    for oe in task.operator_extra_links:
        try:
            link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key  # type: ignore[arg-type]
            log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
            _xcom_push_to_db(ti, key=xcom_key, value=link)
        except Exception:
            log.exception(
                "Failed to push an xcom for task operator extra link",
                link_name=oe.name,
                xcom_key=oe.xcom_key,
                ti=ti,
            )

    # Operators may opt in to replacing the rendered template fields recorded before execution.
    if getattr(ti.task, "overwrite_rtif_after_execution", False):
        log.debug("Overwriting Rendered template fields.")
        if ti.task.template_fields:
            SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)))

    log.debug("Running finalizers", ti=ti)
    if state == TaskInstanceState.SUCCESS:
        _run_task_state_change_callbacks(task, "on_success_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_success(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.SKIPPED:
        _run_task_state_change_callbacks(task, "on_skipped_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_skipped(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.UP_FOR_RETRY:
        _run_task_state_change_callbacks(task, "on_retry_callback", context, log)
        try:
            # Retries are reported to listeners through the same "failed" hook as terminal failures.
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_retry and task.email:
            _send_error_email_notification(task, ti, context, error, log)
    elif state == TaskInstanceState.FAILED:
        _run_task_state_change_callbacks(task, "on_failure_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_failure and task.email:
            _send_error_email_notification(task, ti, context, error, log)

    try:
        get_listener_manager().hook.before_stopping(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")
1650
1651
def main():
    """
    Task process entrypoint: start up, run the task, finalize it, and always close the supervisor socket.

    The task runs while holding the ``BundleVersionLock`` for its bundle. The process exits with status 2
    on ``KeyboardInterrupt`` and with status 1 on any other uncaught exception.
    """
    log = structlog.get_logger(logger_name="task")

    global SUPERVISOR_COMMS
    SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)

    try:
        ti, context, log = startup()
        with BundleVersionLock(
            bundle_name=ti.bundle_instance.name,
            bundle_version=ti.bundle_instance.version,
        ):
            state, _, error = run(ti, context, log)
            context["exception"] = error
            finalize(ti, state, context, log, error)
    except KeyboardInterrupt:
        log.exception("Ctrl-c hit")
        # Use sys.exit rather than the site-injected ``exit`` helper. That helper is meant for the
        # interactive shell, is missing under ``python -S``, and closes stdin as a side effect.
        sys.exit(2)
    except Exception:
        log.exception("Top level error")
        sys.exit(1)
    finally:
        # Ensure the request socket is closed on the child side in all circumstances
        # before the process fully terminates.
        if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket:
            with suppress(Exception):
                SUPERVISOR_COMMS.socket.close()
1679
1680
def reinit_supervisor_comms() -> None:
    """
    Re-initialize supervisor comms and logging channel in subprocess.

    This is not needed for most cases. It is used when the process is re-launched via sudo for
    ``run_as_user``, or from inside the Python code of a virtualenv (et al.) operator. Reconnecting lets
    those tasks keep accessing variables etc.

    A new comms channel is created only if none exists yet in this module. It is built on the socket fd
    named by ``__AIRFLOW_SUPERVISOR_FD``, defaulting to fd 0. The supervisor is then asked to resend the
    logging fd; if it does, logging is reconfigured to write JSON to it, otherwise a warning goes to
    stderr.
    """
    import socket

    global SUPERVISOR_COMMS

    if "SUPERVISOR_COMMS" not in globals():
        task_log = structlog.get_logger(logger_name="task")
        supervisor_fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))
        SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](
            log=task_log, socket=socket.socket(fileno=supervisor_fd)
        )

    response = SUPERVISOR_COMMS.send(ResendLoggingFD())
    if not isinstance(response, SentFDs):
        print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)
        return

    from airflow.sdk.log import configure_logging

    log_stream = os.fdopen(response.fds[0], "wb", buffering=0)
    configure_logging(json_output=True, output=log_stream, sending_to_supervisor=True)
1707
1708
# Allow this module to be executed directly as the task process entrypoint.
if __name__ == "__main__":
    main()