1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""The entrypoint for the actual task execution process."""
19
20from __future__ import annotations
21
22import contextlib
23import contextvars
24import functools
25import os
26import sys
27import time
28from collections.abc import Callable, Iterable, Iterator, Mapping
29from contextlib import suppress
30from datetime import datetime, timezone
31from itertools import product
32from pathlib import Path
33from typing import TYPE_CHECKING, Annotated, Any, Literal
34from urllib.parse import quote
35
36import attrs
37import lazy_object_proxy
38import structlog
39from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter
40
41from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
42from airflow.dag_processing.bundles.manager import DagBundlesManager
43from airflow.sdk._shared.observability.metrics.stats import Stats
44from airflow.sdk.api.client import get_hostname, getuser
45from airflow.sdk.api.datamodels._generated import (
46 AssetProfile,
47 DagRun,
48 PreviousTIResponse,
49 TaskInstance,
50 TaskInstanceState,
51 TIRunContext,
52)
53from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard
54from airflow.sdk.bases.xcom import BaseXCom
55from airflow.sdk.configuration import conf
56from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
57from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
58from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
59from airflow.sdk.definitions.mappedoperator import MappedOperator
60from airflow.sdk.definitions.param import process_params
61from airflow.sdk.exceptions import (
62 AirflowException,
63 AirflowInactiveAssetInInletOrOutletException,
64 AirflowRuntimeError,
65 AirflowTaskTimeout,
66 ErrorType,
67 TaskDeferred,
68)
69from airflow.sdk.execution_time.callback_runner import create_executable_runner
70from airflow.sdk.execution_time.comms import (
71 AssetEventDagRunReferenceResult,
72 CommsDecoder,
73 DagRunStateResult,
74 DeferTask,
75 DRCount,
76 ErrorResponse,
77 GetDagRunState,
78 GetDRCount,
79 GetPreviousDagRun,
80 GetPreviousTI,
81 GetTaskBreadcrumbs,
82 GetTaskRescheduleStartDate,
83 GetTaskStates,
84 GetTICount,
85 InactiveAssetsResult,
86 PreviousDagRunResult,
87 PreviousTIResult,
88 RescheduleTask,
89 ResendLoggingFD,
90 RetryTask,
91 SentFDs,
92 SetRenderedFields,
93 SetRenderedMapIndex,
94 SkipDownstreamTasks,
95 StartupDetails,
96 SucceedTask,
97 TaskBreadcrumbsResult,
98 TaskRescheduleStartDate,
99 TaskState,
100 TaskStatesResult,
101 TICount,
102 ToSupervisor,
103 ToTask,
104 TriggerDagRun,
105 ValidateInletsAndOutlets,
106)
107from airflow.sdk.execution_time.context import (
108 ConnectionAccessor,
109 InletEventsAccessors,
110 MacrosAccessor,
111 OutletEventAccessors,
112 TriggeringAssetEventsAccessor,
113 VariableAccessor,
114 context_get_outlet_events,
115 context_to_airflow_vars,
116 get_previous_dagrun_success,
117 set_current_context,
118)
119from airflow.sdk.execution_time.sentry import Sentry
120from airflow.sdk.execution_time.xcom import XCom
121from airflow.sdk.listener import get_listener_manager
122from airflow.sdk.timezone import coerce_datetime
123from airflow.triggers.base import BaseEventTrigger
124from airflow.triggers.callback import CallbackTrigger
125
126if TYPE_CHECKING:
127 import jinja2
128 from pendulum.datetime import DateTime
129 from structlog.typing import FilteringBoundLogger as Logger
130
131 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
132 from airflow.sdk.definitions.context import Context
133 from airflow.sdk.exceptions import DagRunTriggerException
134 from airflow.sdk.types import OutletEventAccessorsProtocol
135
136
class TaskRunnerMarker:
    """Marker for listener hooks, to properly detect from which component they are called."""

    # Intentionally empty: instances carry no state. An instance is passed as the
    # ``component`` argument to listener hooks (see ``startup``/shutdown paths) so
    # listeners can tell the task-runner apart from other Airflow components.
139
140
141# TODO: Move this entire class into a separate file:
142# `airflow/sdk/execution_time/task_instance.py`
143# or `airflow/sdk/execution_time/runtime_ti.py`
class RuntimeTaskInstance(TaskInstance):
    """
    A Task Instance enriched with everything needed to actually execute the task in-process.

    Extends the API-server ``TaskInstance`` model with the parsed ``task`` object, the Dag
    bundle it came from, and the server-provided run context. All communication back to the
    supervisor process goes through the module-global ``SUPERVISOR_COMMS``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    task: BaseOperator
    bundle_instance: BaseDagBundle
    _cached_template_context: Context | None = None
    """The Task Instance context. This is used to cache get_template_context."""

    _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
    """The Task Instance context from the API server, if any."""

    max_tries: int = 0
    """The maximum number of retries for the task."""

    start_date: AwareDatetime
    """Start date of the task instance."""

    # End date of the task instance, if already set; None while the task is still running.
    end_date: AwareDatetime | None = None

    # Current state; ``parse()`` constructs instances with state=RUNNING.
    state: TaskInstanceState | None = None

    is_mapped: bool | None = None
    """True if the original task was mapped."""

    # Rendered map-index label for mapped tasks, if one was produced.
    rendered_map_index: str | None = None

    # Sentry integration selector passed through from StartupDetails; empty string when unset.
    sentry_integration: str = ""

    def __rich_repr__(self):
        # Rich-friendly repr: yield (key, value) pairs instead of building a string.
        yield "id", self.id
        yield "task_id", self.task_id
        yield "dag_id", self.dag_id
        yield "run_id", self.run_id
        yield "max_tries", self.max_tries
        yield "task", type(self.task)
        yield "start_date", self.start_date

    __rich_repr__.angular = True  # type: ignore[attr-defined]

    def get_template_context(self) -> Context:
        """
        Build (or return the cached) Jinja template context for this task instance.

        The context is cached on the instance so repeated calls operate on the same
        dict object; keys derived from server-provided run info are merged in on top.
        """
        # TODO: Move this to `airflow.sdk.execution_time.context`
        # once we port the entire context logic from airflow/utils/context.py ?
        from airflow.sdk.plugins_manager import integrate_macros_plugins

        integrate_macros_plugins()

        dag_run_conf: dict[str, Any] | None = None
        if from_server := self._ti_context_from_server:
            dag_run_conf = from_server.dag_run.conf or dag_run_conf

        validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)

        # Cache the context object, which ensures that all calls to get_template_context
        # are operating on the same context object.
        self._cached_template_context: Context = self._cached_template_context or {
            # From the Task Execution interface
            "dag": self.task.dag,
            "inlets": self.task.inlets,
            "map_index_template": self.task.map_index_template,
            "outlets": self.task.outlets,
            "run_id": self.run_id,
            "task": self.task,
            "task_instance": self,
            "ti": self,
            "outlet_events": OutletEventAccessors(),
            "inlet_events": InletEventsAccessors(self.task.inlets),
            "macros": MacrosAccessor(),
            "params": validated_params,
            # TODO: Make this go through Public API longer term.
            # "test_mode": task_instance.test_mode,
            "var": {
                "json": VariableAccessor(deserialize_json=True),
                "value": VariableAccessor(deserialize_json=False),
            },
            "conn": ConnectionAccessor(),
        }
        if from_server:
            dag_run = from_server.dag_run
            context_from_server: Context = {
                # TODO: Assess if we need to pass these through timezone.coerce_datetime
                "dag_run": dag_run,  # type: ignore[typeddict-item] # Removable after #46522
                "triggering_asset_events": TriggeringAssetEventsAccessor.build(
                    AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
                    for event in dag_run.consumed_asset_events
                ),
                "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
                "task_reschedule_count": from_server.task_reschedule_count or 0,
                # Proxies defer the supervisor round-trip until the value is actually used.
                "prev_start_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date)
                ),
                "prev_end_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
                ),
            }
            self._cached_template_context.update(context_from_server)

            if logical_date := coerce_datetime(dag_run.logical_date):
                if TYPE_CHECKING:
                    assert isinstance(logical_date, DateTime)
                ds = logical_date.strftime("%Y-%m-%d")
                ds_nodash = ds.replace("-", "")
                ts = logical_date.isoformat()
                ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
                ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
                # logical_date and data_interval are either both present or both None
                self._cached_template_context.update(
                    {
                        # keys that depend on logical_date
                        "logical_date": logical_date,
                        "ds": ds,
                        "ds_nodash": ds_nodash,
                        # NOTE: overrides the run_id-based key set above when a logical date exists
                        "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}",
                        "ts": ts,
                        "ts_nodash": ts_nodash,
                        "ts_nodash_with_tz": ts_nodash_with_tz,
                        # keys that depend on data_interval
                        "data_interval_end": coerce_datetime(dag_run.data_interval_end),
                        "data_interval_start": coerce_datetime(dag_run.data_interval_start),
                        "prev_data_interval_start_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start)
                        ),
                        "prev_data_interval_end_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end)
                        ),
                    }
                )

            if from_server.upstream_map_indexes is not None:
                # We stash this in here for later use, but we purposefully don't want to document its
                # existence. Should this be a private attribute on RuntimeTI instead perhaps?
                setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)

        return self._cached_template_context

    def render_templates(
        self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
    ) -> BaseOperator:
        """
        Render templates in the operator fields.

        If the task was originally mapped, this may replace ``self.task`` with
        the unmapped, fully rendered BaseOperator. The original ``self.task``
        before replacement is returned.
        """
        if not context:
            context = self.get_template_context()
        original_task = self.task

        if TYPE_CHECKING:
            assert context

        ti = context["ti"]

        if TYPE_CHECKING:
            assert original_task
            assert self.task
            assert ti.task

        # If self.task is mapped, this call replaces self.task to point to the
        # unmapped BaseOperator created by this function! This is because the
        # MappedOperator is useless for template rendering, and we need to be
        # able to access the unmapped task instead.
        self.task.render_template_fields(context, jinja_env)
        self.is_mapped = original_task.is_mapped
        return original_task

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = BaseXCom.XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
        default: Any = None,
        run_id: str | None = None,
    ) -> Any:
        """
        Pull XComs either from the API server (BaseXCom) or from the custom XCOM backend if configured.

        The pull can be filtered optionally by certain criterion.

        :param key: A key for the XCom. If provided, only XComs with matching
            keys will be returned. The default key is ``'return_value'``, also
            available as constant ``XCOM_RETURN_KEY``. This key is automatically
            given to XComs returned by tasks (as opposed to being pushed
            manually).
        :param task_ids: Only XComs from tasks with matching ids will be
            pulled. If *None* (default), the task_id of the calling task is used.
        :param dag_id: If provided, only pulls XComs from this Dag. If *None*
            (default), the Dag of the calling task is used.
        :param map_indexes: If provided, only pull XComs with matching indexes.
            If left unspecified, XComs from **all** map indexes of each matching
            task are pulled (see below for details).
        :param include_prior_dates: If False, only XComs from the current
            logical_date are returned. If *True*, XComs from previous dates
            are returned as well.
        :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id.
            If *None* (default), the run_id of the calling task is used.

        When ``map_indexes`` is not specified, XComs are pulled from every map
        index of each matching task. If this yields exactly one entry for a
        single requested task, that entry is returned directly instead of a list.

        When ``map_indexes`` is explicitly given (an int, an iterable of ints,
        or *None*), only XComs with those map indexes are pulled; ``default``
        (*None* if not specified) is used for each missing XCom.

        When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
        a non-str iterable), a list of matching XComs is returned. Elements in
        the list are ordered by item ordering in ``task_id`` and ``map_index``.
        """
        if dag_id is None:
            dag_id = self.dag_id
        if run_id is None:
            run_id = self.run_id

        single_task_requested = isinstance(task_ids, (str, type(None)))
        single_map_index_requested = isinstance(map_indexes, (int, type(None)))

        if task_ids is None:
            # default to the current task if not provided
            task_ids = [self.task_id]
        elif isinstance(task_ids, str):
            task_ids = [task_ids]

        # If map_indexes is not specified, pull xcoms from all map indexes for each task
        if not is_arg_set(map_indexes):
            xcoms: list[Any] = []
            for t_id in task_ids:
                values = XCom.get_all(
                    run_id=run_id,
                    key=key,
                    task_id=t_id,
                    dag_id=dag_id,
                    include_prior_dates=include_prior_dates,
                )

                if values is None:
                    xcoms.append(None)
                else:
                    xcoms.extend(values)
            # For single task pulling from unmapped task, return single value
            if single_task_requested and len(xcoms) == 1:
                return xcoms[0]
            return xcoms

        # Original logic when map_indexes is explicitly specified
        map_indexes_iterable: Iterable[int | None] = []
        if isinstance(map_indexes, int) or map_indexes is None:
            map_indexes_iterable = [map_indexes]
        elif isinstance(map_indexes, Iterable):
            map_indexes_iterable = map_indexes
        else:
            raise TypeError(
                f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}"
            )

        xcoms = []
        for t_id, m_idx in product(task_ids, map_indexes_iterable):
            value = XCom.get_one(
                run_id=run_id,
                key=key,
                task_id=t_id,
                dag_id=dag_id,
                map_index=m_idx,
                include_prior_dates=include_prior_dates,
            )
            if value is None:
                xcoms.append(default)
            else:
                xcoms.append(value)

        if single_task_requested and single_map_index_requested:
            return xcoms[0]

        return xcoms

    def xcom_push(self, key: str, value: Any):
        """
        Make an XCom available for tasks to pull.

        :param key: Key to store the value under.
        :param value: Value to store. Only JSON-serializable values may be used.
        """
        _xcom_push(self, key, value)

    def get_relevant_upstream_map_indexes(
        self, upstream: BaseOperator, ti_count: int | None, session: Any
    ) -> int | range | None:
        # TODO: Implement this method
        # NOTE(review): currently always returns None — callers presumably treat that
        # as "no map-index filtering"; confirm against call sites once implemented.
        return None

    def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
        """Get the first reschedule date for the task instance if found, none otherwise."""
        if context.get("task_reschedule_count", 0) == 0:
            # If the task has not been rescheduled, there is no need to ask the supervisor
            return None

        max_tries: int = self.max_tries
        retries: int = self.task.retries or 0
        # First try of the current "retry window": max_tries grows with each clear/retry cycle.
        first_try_number = max_tries - retries + 1

        log = structlog.get_logger(logger_name="task")

        log.debug("Requesting first reschedule date from supervisor")

        response = SUPERVISOR_COMMS.send(
            msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskRescheduleStartDate)

        return response.start_date

    def get_previous_dagrun(self, state: str | None = None) -> DagRun | None:
        """Return the previous Dag run before the given logical date, optionally filtered by state."""
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")

        log.debug("Getting previous Dag run", dag_run=dag_run)

        if dag_run is None:
            return None

        # A "previous" run is only meaningful relative to a logical date.
        if dag_run.logical_date is None:
            return None

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state)
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousDagRunResult)

        return response.dag_run

    def get_previous_ti(
        self,
        state: TaskInstanceState | None = None,
        logical_date: AwareDatetime | None = None,
        map_index: int = -1,
    ) -> PreviousTIResponse | None:
        """
        Return the previous task instance matching the given criteria.

        :param state: Filter by TaskInstance state
        :param logical_date: Filter by logical date (returns TI before this date)
        :param map_index: Filter by map_index (defaults to -1 for non-mapped tasks)
        :return: Previous task instance or None if not found
        """
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")
        log.debug("Getting previous task instance", task_id=self.task_id, state=state)

        # Use current dag run's logical_date if not provided
        effective_logical_date = logical_date
        if effective_logical_date is None and dag_run and dag_run.logical_date:
            effective_logical_date = dag_run.logical_date

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousTI(
                dag_id=self.dag_id,
                task_id=self.task_id,
                logical_date=effective_logical_date,
                map_index=map_index,
                state=state,
            )
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousTIResult)

        return response.task_instance

    @staticmethod
    def get_ti_count(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of task instances matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTICount(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TICount)

        return response.count

    @staticmethod
    def get_task_states(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> dict[str, Any]:
        """Return the task states matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTaskStates(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskStatesResult)

        return response.task_states

    @staticmethod
    def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]:
        """Return task breadcrumbs for the given dag run."""
        response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id))
        if TYPE_CHECKING:
            assert isinstance(response, TaskBreadcrumbsResult)
        return response.breadcrumbs

    @staticmethod
    def get_dr_count(
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of Dag runs matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetDRCount(
                dag_id=dag_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, DRCount)

        return response.count

    @staticmethod
    def get_dagrun_state(dag_id: str, run_id: str) -> str:
        """Return the state of the Dag run with the given Run ID."""
        response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id))

        if TYPE_CHECKING:
            assert isinstance(response, DagRunStateResult)

        return response.state

    @property
    def log_url(self) -> str:
        """URL of this task instance's log page in the Airflow UI."""
        run_id = quote(self.run_id)
        base_url = conf.get("api", "base_url", fallback="http://localhost:8080/")
        map_index_value = self.map_index
        # Mapped TIs (map_index >= 0) get a /mapped/<n> path segment; -1 means unmapped.
        map_index = (
            f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else ""
        )
        try_number_value = self.try_number
        try_number = (
            f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else ""
        )
        _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}"
        return _log_uri

    @property
    def mark_success_url(self) -> str:
        """URL to mark TI success."""
        return self.log_url
642
643
def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
    """
    Store an XCom via ``XCom.set``, which routes through a custom XCom backend when configured.

    Kept private so SDK consumers cannot set ``mapped_length`` manually.
    """
    # Gather the task-instance coordinates that identify where the XCom lives.
    ti_coords = {
        "dag_id": ti.dag_id,
        "task_id": ti.task_id,
        "run_id": ti.run_id,
        "map_index": ti.map_index,
    }
    XCom.set(key=key, value=value, _mapped_length=mapped_length, **ti_coords)
658
659
def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None:
    """Write an XCom straight into the metadata DB, deliberately skipping any custom xcom_backend."""
    ti_coords = {
        "dag_id": ti.dag_id,
        "task_id": ti.task_id,
        "run_id": ti.run_id,
        "map_index": ti.map_index,
    }
    XCom._set_xcom_in_db(key=key, value=value, **ti_coords)
670
671
def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
    """
    Parse the Dag file named in ``what`` and build the RuntimeTaskInstance to execute.

    Exits the process (status 1) when the Dag or the task cannot be found in the
    parsed bundle, after logging the failure.

    :param what: Startup message from the supervisor describing the TI and its bundle.
    :param log: Structured logger for start-up diagnostics.
    """
    # TODO: Task-SDK:
    # Using BundleDagBag here is about 98% wrong, but it'll do for now
    from airflow.dag_processing.dagbag import BundleDagBag

    bundle_info = what.bundle_info
    bundle_instance = DagBundlesManager().get_bundle(name=bundle_info.name, version=bundle_info.version)
    bundle_instance.initialize()
    # Verify readability after any impersonation has taken effect.
    _verify_bundle_access(bundle_instance, log)

    bag = BundleDagBag(
        dag_folder=os.fspath(Path(bundle_instance.path, what.dag_rel_path)),
        safe_mode=False,
        load_op_links=False,
        bundle_path=bundle_instance.path,
        bundle_name=bundle_info.name,
    )
    if TYPE_CHECKING:
        assert what.ti.dag_id

    dag = bag.dags.get(what.ti.dag_id)
    if dag is None:
        log.error(
            "Dag not found during start up", dag_id=what.ti.dag_id, bundle=bundle_info, path=what.dag_rel_path
        )
        sys.exit(1)

    # install_loader()

    task = dag.task_dict.get(what.ti.task_id)
    if task is None:
        log.error(
            "Task not found in Dag during start up",
            dag_id=dag.dag_id,
            task_id=what.ti.task_id,
            bundle=bundle_info,
            path=what.dag_rel_path,
        )
        sys.exit(1)

    if not isinstance(task, (BaseOperator, MappedOperator)):
        raise TypeError(
            f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
        )

    return RuntimeTaskInstance.model_construct(
        **what.ti.model_dump(exclude_unset=True),
        task=task,
        bundle_instance=bundle_instance,
        _ti_context_from_server=what.ti_context,
        max_tries=what.ti_context.max_tries,
        start_date=what.start_date,
        state=TaskInstanceState.RUNNING,
        sentry_integration=what.sentry_integration,
    )
733
734
# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution,
# to send requests back to the supervisor process.
#
# Why it needs to be a global:
# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests
#   to the parent process during task execution.
# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the
#   deeply nested execution stack.
# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily
#   accessible wherever needed during task execution without modifying every layer of the call stack.
#
# NOTE: this is only an annotation (no value is assigned here); the variable is bound at
# runtime elsewhere in this module (e.g. during process startup / comms re-initialization),
# so reading it before startup raises NameError.
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
746
747
748# State machine!
749# 1. Start up (receive details from supervisor)
750# 2. Execution (run task code, possibly send requests)
751# 3. Shutdown and report status
752
753
def _verify_bundle_access(bundle_instance: BaseDagBundle, log: Logger) -> None:
    """
    Verify the bundle path is accessible by the current user.

    Called after user impersonation (if any), so the check reflects the effective
    user. Relies on ``os.access``, which honours any permission scheme (classic
    Unix modes, ACLs, SELinux, ...).

    :param bundle_instance: The bundle instance to check
    :param log: Logger instance (currently unused; kept for interface stability)
    :raises AirflowException: if bundle is not accessible
    """
    from getpass import getuser

    from airflow.exceptions import AirflowException

    path = bundle_instance.path

    # A missing path was already warned about by initialize(); nothing to verify.
    if not path.exists():
        return

    # Directories additionally need execute permission so their contents can be listed.
    required = os.R_OK | (os.X_OK if path.is_dir() else 0)
    if os.access(path, required):
        return

    raise AirflowException(
        f"Bundle '{bundle_instance.name}' path '{path}' is not accessible "
        f"by user '{getuser()}'. When using run_as_user, ensure bundle directories "
        f"are readable by the impersonated user. "
        f"See: https://airflow.apache.org/docs/apache-airflow/stable/administration-and-deployment/dag-bundles.html"
    )
788
789
def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
    """
    Receive startup details, parse the Dag, and prepare the task for execution.

    Handles both the normal entrypoint (startup message arrives over supervisor
    comms) and the re-exec entrypoint used for ``run_as_user`` impersonation
    (startup message arrives via the ``_AIRFLOW__STARTUP_MSG`` env var). When
    impersonation is required this function replaces the current process via
    ``sudo``/``os.execvp`` and never returns normally.
    """
    # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent
    # in response to us sending a request.
    log = structlog.get_logger(logger_name="task")

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and (
        msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG")
    ):
        # Entrypoint of the re-exec (impersonated) process.
        # Clear any Kerberos credential cache reference (KRB5CCNAME) so the new process can't reuse it.
        os.environ.pop("KRB5CCNAME", None)

        # Deserialize the startup message stashed in the environment by the parent process below.
        msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson)
        reinit_supervisor_comms()

        # We delay this message until _after_ we've got the logging re-configured, otherwise it will show up
        # on stdout
        log.debug("Using serialized startup message from environment", msg=msg)
    else:
        # normal entry point
        msg = SUPERVISOR_COMMS._get_response()  # type: ignore[assignment]

        if not isinstance(msg, StartupDetails):
            raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

    # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
    os_type = sys.platform
    if os_type == "darwin":
        log.debug("Mac OS detected, skipping setproctitle")
    else:
        from setproctitle import setproctitle

        setproctitle(f"airflow worker -- {msg.ti.id}")

    # Listener failures must never prevent task startup — log and continue.
    try:
        get_listener_manager().hook.on_starting(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")

    with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
        ti = parse(msg, log)
        log.debug("Dag file parsed", file=msg.dag_rel_path)

    run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
        "core", "default_impersonation", fallback=None
    )

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
        # This branch launches the re-exec process: mark it as re-executed and hand over via sudo.
        os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
        # store startup message in environment for re-exec process
        os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
        # Keep the supervisor socket open across exec so the child can keep talking to the parent.
        os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)

        # Import main directly from the module instead of re-executing the file.
        # This ensures that when other modules import
        # airflow.sdk.execution_time.task_runner, they get the same module instance
        # with the properly initialized SUPERVISOR_COMMS global variable.
        # If we re-executed the module with `python -m`, it would load as __main__ and future
        # imports would get a fresh copy without the initialized globals.
        rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
        cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
        log.info(
            "Running command",
            command=cmd,
        )
        os.execvp("sudo", cmd)

        # ideally, we should never reach here (execvp replaces the process), but if we do,
        # we should return None, None, None
        return None, None, None

    return ti, ti.get_template_context(), log
862
863
def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float:
    """
    Return a serializable representation of the templated field.

    If ``templated_field`` contains a class or instance that requires recursive
    templating, store them as strings. Otherwise simply return the field as-is.

    Values longer than ``[core] max_templated_field_length`` are redacted via the
    SDK secrets masker and truncated.
    """
    import json

    from airflow.sdk._shared.secrets_masker import redact

    def _is_json_serializable(value: Any) -> bool:
        # True when json.dumps can handle the value as-is.
        try:
            json.dumps(value)
            return True
        except (TypeError, OverflowError):
            return False

    def _tuples_to_lists(obj: Any):
        # Recursively convert tuples to lists (lists and dicts are walked too).
        if isinstance(obj, (tuple, list)):
            return [_tuples_to_lists(item) for item in obj]
        if isinstance(obj, dict):
            return {key: _tuples_to_lists(value) for key, value in obj.items()}
        return obj

    def _sort_dicts(obj: Any) -> Any:
        # Recursively sort dict keys for a stable string representation,
        # preventing hash inconsistencies when dict ordering varies.
        if isinstance(obj, dict):
            return {k: _sort_dicts(v) for k, v in sorted(obj.items())}
        if isinstance(obj, list):
            return [_sort_dicts(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(_sort_dicts(item) for item in obj)
        return obj

    def _truncate(serialized: str) -> str:
        # Redact secrets, then emit the standard truncation notice plus a prefix of the value.
        rendered = redact(serialized, name)
        return (
            "Truncated. You can change this behaviour in [core]max_templated_field_length. "
            f"{rendered[: max_length - 79]!r}... "
        )

    max_length = conf.getint("core", "max_templated_field_length")

    if not _is_json_serializable(template_field):
        # Fall back to the object's own serialize() hook, else its str() form.
        try:
            serialized = template_field.serialize()
        except AttributeError:
            serialized = str(template_field)
        return _truncate(serialized) if len(serialized) > max_length else serialized

    if not template_field and not isinstance(template_field, tuple):
        # Avoid unnecessary serialization steps for empty fields unless they are tuples
        # and need to be converted to lists
        return template_field

    template_field = _tuples_to_lists(template_field)
    if isinstance(template_field, dict):
        template_field = _sort_dicts(template_field)
    return _truncate(str(template_field)) if len(str(template_field)) > max_length else template_field
936
937
def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]:
    """Serialize and redact every templated field of *task*, keyed by field name."""
    from airflow.sdk._shared.secrets_masker import redact

    # Redact secrets in the task process itself before sending to API server.
    # This ensures that secrets registered via mask_secret() on workers / dag
    # processors are properly masked on the UI.
    return {
        field: redact(_serialize_template_field(getattr(task, field), field), field)
        for field in task.template_fields
    }  # type: ignore[return-value] # Convince mypy that this is OK since we pass JsonValue to redact, so it will return the same
951
952
def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]:
    """Yield an ``AssetProfile`` for each recognised lineage object; other types are dropped."""
    # Lineage can contain objects other than assets, so each entry is mapped individually.
    for entry in lineage_objects or ():
        profile: AssetProfile | None = None
        if isinstance(entry, Asset):
            profile = AssetProfile(name=entry.name, uri=entry.uri, type=Asset.__name__)
        elif isinstance(entry, AssetNameRef):
            profile = AssetProfile(name=entry.name, type=AssetNameRef.__name__)
        elif isinstance(entry, AssetUriRef):
            profile = AssetProfile(uri=entry.uri, type=AssetUriRef.__name__)
        elif isinstance(entry, AssetAlias):
            profile = AssetProfile(name=entry.name, type=AssetAlias.__name__)
        if profile is not None:
            yield profile
964
965
def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]:
    """Yield every outlet event the user recorded in the accessors, as plain dicts."""
    if TYPE_CHECKING:
        assert isinstance(events, OutletEventAccessors)
    # Everything recorded by the user is forwarded as-is.
    # Further filtering will be done in the API server.
    for unique_key, accessor in events._dict.items():
        if isinstance(unique_key, AssetUniqueKey):
            yield {"dest_asset_key": attrs.asdict(unique_key), "extra": accessor.extra}
        yield from (attrs.asdict(alias_event) for alias_event in accessor.asset_alias_events)
976
977
def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None:
    """
    Prepare the task instance for execution: render templates, validate assets, notify listeners.

    Returns a message for the supervisor when execution must stop early, otherwise ``None``
    (the caller then proceeds with actual task execution).
    """
    ti.hostname = get_hostname()
    ti.task = ti.task.prepare_for_execution()
    # Since context is now cached, and calling `ti.get_template_context` will return the same dict, we want to
    # update the value of the task that is sent from there
    context["task"] = ti.task

    jinja_env = ti.task.dag.get_template_env()
    ti.render_templates(context=context, jinja_env=jinja_env)

    if rendered_fields := _serialize_rendered_fields(ti.task):
        # so that we do not call the API unnecessarily
        SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields))

    # Try to render map_index_template early with available context (will be re-rendered after execution)
    # This provides a partial label during task execution for templates using pre-execution context
    # If rendering fails here, we suppress the error since it will be re-rendered after execution
    try:
        if rendered_map_index := _render_map_index(context, ti=ti, log=log):
            ti.rendered_map_index = rendered_map_index
            log.debug("Sending early rendered map index", length=len(rendered_map_index))
            SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index))
    except Exception:
        log.debug(
            "Early rendering of map_index_template failed, will retry after task execution", exc_info=True
        )

    # Raises if any inlet/outlet asset is inactive, failing the task before execution.
    _validate_task_inlets_and_outlets(ti=ti, log=log)

    try:
        # TODO: Call pre execute etc.
        get_listener_manager().hook.on_task_instance_running(
            previous_state=TaskInstanceState.QUEUED, task_instance=ti
        )
    except Exception:
        # Listener failures are logged but must never fail the task itself.
        log.exception("error calling listener")

    # No error, carry on and execute the task
    return None
1017
1018
def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
    """Ask the supervisor to validate the task's inlets/outlets; raise if any asset is inactive."""
    if not (ti.task.inlets or ti.task.outlets):
        return

    response = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id))
    if TYPE_CHECKING:
        assert isinstance(response, InactiveAssetsResult)
    if response.inactive_assets:
        raise AirflowInactiveAssetInInletOrOutletException(
            inactive_asset_keys=[
                AssetUniqueKey.from_profile(profile) for profile in response.inactive_assets
            ]
        )
1032
1033
def _defer_task(
    defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """Build the ``DeferTask`` supervisor message and the DEFERRED state for a deferring task."""
    # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?
    log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)

    classpath, raw_trigger_kwargs = defer.trigger.serialize()

    # Currently, only task-associated BaseTrigger instances may have a non-None queue,
    # and only when triggerer.queues_enabled is True.
    trigger_queue: str | None = None
    if not isinstance(defer.trigger, (BaseEventTrigger, CallbackTrigger)) and conf.getboolean(
        "triggerer", "queues_enabled", fallback=False
    ):
        trigger_queue = ti.task.queue

    from airflow.sdk.serde import serialize as serde_serialize

    serialized_trigger_kwargs = serde_serialize(raw_trigger_kwargs)
    serialized_next_kwargs = serde_serialize(defer.kwargs or {})

    if TYPE_CHECKING:
        assert isinstance(serialized_next_kwargs, dict)
        assert isinstance(serialized_trigger_kwargs, dict)

    defer_msg = DeferTask(
        classpath=classpath,
        trigger_kwargs=serialized_trigger_kwargs,
        trigger_timeout=defer.timeout,
        queue=trigger_queue,
        next_method=defer.method_name,
        next_kwargs=serialized_next_kwargs,
    )
    return defer_msg, TaskInstanceState.DEFERRED
1069
1070
@Sentry.enrich_errors
def run(
    ti: RuntimeTaskInstance,
    context: Context,
    log: Logger,
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
    """
    Run the task in this process.

    Returns the terminal state, the message to send to the supervisor (if any) and
    the exception that caused failure (if any). Each exception type the task can
    raise is mapped to a terminal state by the except clauses below; ordering of
    the clauses matters (most specific first, ``BaseException`` last).
    """
    import signal

    from airflow.sdk.exceptions import (
        AirflowFailException,
        AirflowRescheduleException,
        AirflowSensorTimeout,
        AirflowSkipException,
        AirflowTaskTerminated,
        DagRunTriggerException,
        DownstreamTasksSkipped,
        TaskDeferred,
    )

    if TYPE_CHECKING:
        assert ti.task is not None
        assert isinstance(ti.task, BaseOperator)

    # Remember which process installed the handler so forked children don't run on_kill.
    parent_pid = os.getpid()

    def _on_term(signum, frame):
        # SIGTERM handler: only the original task process calls on_kill.
        pid = os.getpid()
        if pid != parent_pid:
            return

        ti.task.on_kill()

    signal.signal(signal.SIGTERM, _on_term)

    msg: ToSupervisor | None = None
    state: TaskInstanceState
    error: BaseException | None = None

    try:
        # First, clear the xcom data sent from server
        if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear):
            for x in keys_to_delete:
                log.debug("Clearing XCom with key", key=x)
                XCom.delete(
                    key=x,
                    dag_id=ti.dag_id,
                    task_id=ti.task_id,
                    run_id=ti.run_id,
                    map_index=ti.map_index,
                )

        with set_current_context(context):
            # This is the earliest that we can render templates -- as if it excepts for any reason we need to
            # catch it and handle it like a normal task failure
            if early_exit := _prepare(ti, log, context):
                msg = early_exit
                ti.state = state = TaskInstanceState.FAILED
                return state, msg, error

            try:
                result = _execute_task(context=context, ti=ti, log=log)
            except Exception:
                import jinja2

                # If the task failed, swallow rendering error so it doesn't mask the main error.
                with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
                    previous_rendered_map_index = ti.rendered_map_index
                    ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                    # Send update only if value changed (e.g., user set context variables during execution)
                    if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                        SUPERVISOR_COMMS.send(
                            msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)
                        )
                raise
            else:  # If the task succeeded, render normally to let rendering error bubble up.
                previous_rendered_map_index = ti.rendered_map_index
                ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                # Send update only if value changed (e.g., user set context variables during execution)
                if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                    SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index))

            _push_xcom_if_needed(result, ti, log)

            msg, state = _handle_current_task_success(context, ti)
    except DownstreamTasksSkipped as skip:
        # Raised by short-circuit style operators: skip listed downstream tasks, but this one succeeded.
        log.info("Skipping downstream tasks.")
        tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks]
        SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip))
        msg, state = _handle_current_task_success(context, ti)
    except DagRunTriggerException as drte:
        msg, state = _handle_trigger_dag_run(drte, context, ti, log)
    except TaskDeferred as defer:
        msg, state = _defer_task(defer, ti, log)
    except AirflowSkipException as e:
        if e.args:
            log.info("Skipping task.", reason=e.args[0])
        msg = TaskState(
            state=TaskInstanceState.SKIPPED,
            end_date=datetime.now(tz=timezone.utc),
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.SKIPPED
    except AirflowRescheduleException as reschedule:
        log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
        msg = RescheduleTask(
            reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
        )
        state = TaskInstanceState.UP_FOR_RESCHEDULE
    except (AirflowFailException, AirflowSensorTimeout) as e:
        # If AirflowFailException is raised, task should not retry.
        # If a sensor in reschedule mode reaches timeout, task should not retry.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
        # We should allow retries if the task has defined it.
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    except AirflowTaskTerminated as e:
        # External state updates are already handled with `ti_heartbeat` and will be
        # updated already be another UI API. So, these exceptions should ideally never be thrown.
        # If these are thrown, we should mark the TI state as failed.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except SystemExit as e:
        # SystemExit needs to be retried if they are eligible.
        log.error("Task exited", exit_code=e.code)
        msg, state = _handle_current_task_failed(ti)
        error = e
    except BaseException as e:
        # Catch-all so the process always reports a terminal state to the supervisor.
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    finally:
        if msg:
            SUPERVISOR_COMMS.send(msg=msg)

    # Return the message to make unit tests easier too
    ti.state = state
    return state, msg, error
1226
1227
def _handle_current_task_success(
    context: Context,
    ti: RuntimeTaskInstance,
) -> tuple[SucceedTask, TaskInstanceState]:
    """Record success metrics and build the ``SucceedTask`` message for the supervisor."""
    now = datetime.now(tz=timezone.utc)
    ti.end_date = now

    # Record operator and task instance success metrics
    operator_name = ti.task.__class__.__name__
    tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
    Stats.incr(f"operator_successes_{operator_name}", tags=tags)
    # Same metric with tagging
    Stats.incr("operator_successes", tags={**tags, "operator": operator_name})
    Stats.incr("ti_successes", tags=tags)

    success_msg = SucceedTask(
        end_date=now,
        task_outlets=list(_build_asset_profiles(ti.task.outlets)),
        outlet_events=list(_serialize_outlet_events(context["outlet_events"])),
        rendered_map_index=ti.rendered_map_index,
    )
    return success_msg, TaskInstanceState.SUCCESS
1253
1254
def _handle_current_task_failed(
    ti: RuntimeTaskInstance,
) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
    """Record failure metrics and build either a retry or a terminal-failure message."""
    now = datetime.now(tz=timezone.utc)
    ti.end_date = now

    # Record operator and task instance failed metrics
    operator_name = ti.task.__class__.__name__
    tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
    Stats.incr(f"operator_failures_{operator_name}", tags=tags)
    # Same metric with tagging
    Stats.incr("operator_failures", tags={**tags, "operator": operator_name})
    Stats.incr("ti_failures", tags=tags)

    # The server decides retry eligibility; honour it if present.
    if ti._ti_context_from_server and ti._ti_context_from_server.should_retry:
        return RetryTask(end_date=now), TaskInstanceState.UP_FOR_RETRY

    failed_msg = TaskState(
        state=TaskInstanceState.FAILED, end_date=now, rendered_map_index=ti.rendered_map_index
    )
    return failed_msg, TaskInstanceState.FAILED
1278
1279
def _handle_trigger_dag_run(
    drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """
    Handle exception from TriggerDagRunOperator.

    Asks the supervisor to trigger the target DAG run, then — depending on the
    operator's settings — skips/fails on an existing run, defers, polls until the
    run reaches an allowed/failed state, or returns success immediately.
    """
    log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
    comms_msg = SUPERVISOR_COMMS.send(
        TriggerDagRun(
            dag_id=drte.trigger_dag_id,
            run_id=drte.dag_run_id,
            logical_date=drte.logical_date,
            conf=drte.conf,
            reset_dag_run=drte.reset_dag_run,
        ),
    )

    # A pre-existing run is either a skip or a failure, controlled by the operator flag.
    if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
        if drte.skip_when_already_exists:
            log.info(
                "Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
                dag_id=drte.trigger_dag_id,
            )
            msg = TaskState(
                state=TaskInstanceState.SKIPPED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.SKIPPED
        else:
            log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
            msg = TaskState(
                state=TaskInstanceState.FAILED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.FAILED

        return msg, state

    log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)

    # Store the run id from the dag run (either created or found above) to
    # be used when creating the extra link on the webserver.
    ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)

    if drte.wait_for_completion:
        if drte.deferrable:
            # Deferrable path: hand the waiting over to the triggerer instead of polling here.
            from airflow.providers.standard.triggers.external_task import DagStateTrigger

            defer = TaskDeferred(
                trigger=DagStateTrigger(
                    dag_id=drte.trigger_dag_id,
                    states=drte.allowed_states + drte.failed_states,  # type: ignore[arg-type]
                    # Don't filter by execution_dates when run_ids is provided.
                    # run_id uniquely identifies a DAG run, and when reset_dag_run=True,
                    # drte.logical_date might be a newly calculated value that doesn't match
                    # the persisted logical_date in the database, causing the trigger to never find the run.
                    execution_dates=None,
                    run_ids=[drte.dag_run_id],
                    poll_interval=drte.poke_interval,
                ),
                method_name="execute_complete",
            )
            return _defer_task(defer, ti, log)
        # Synchronous path: poll the run state via the supervisor until it terminates.
        while True:
            log.info(
                "Waiting for dag run to complete execution in allowed state.",
                dag_id=drte.trigger_dag_id,
                run_id=drte.dag_run_id,
                allowed_state=drte.allowed_states,
            )
            time.sleep(drte.poke_interval)

            comms_msg = SUPERVISOR_COMMS.send(
                GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
            )
            if TYPE_CHECKING:
                assert isinstance(comms_msg, DagRunStateResult)
            if comms_msg.state in drte.failed_states:
                log.error(
                    "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                msg = TaskState(
                    state=TaskInstanceState.FAILED,
                    end_date=datetime.now(tz=timezone.utc),
                    rendered_map_index=ti.rendered_map_index,
                )
                state = TaskInstanceState.FAILED
                return msg, state
            if comms_msg.state in drte.allowed_states:
                log.info(
                    "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                break
            log.debug(
                "DagRun not yet in allowed or failed state.",
                dag_id=drte.trigger_dag_id,
                state=comms_msg.state,
            )
    else:
        # Fire-and-forget mode: wait_for_completion=False
        if drte.deferrable:
            log.info(
                "Ignoring deferrable=True because wait_for_completion=False. "
                "Task will complete immediately without waiting for the triggered DAG run.",
                trigger_dag_id=drte.trigger_dag_id,
            )

    return _handle_current_task_success(context, ti)
1388
1389
def _run_task_state_change_callbacks(
    task: BaseOperator,
    kind: Literal[
        "on_execute_callback",
        "on_failure_callback",
        "on_success_callback",
        "on_retry_callback",
        "on_skipped_callback",
    ],
    context: Context,
    log: Logger,
) -> None:
    """Run every callback registered under *kind*; a failing callback is logged, never raised."""
    cb: Callable[[Context], None]
    for idx, cb in enumerate(getattr(task, kind)):
        try:
            runner = create_executable_runner(cb, context_get_outlet_events(context), logger=log)
            runner.run(context)
        except Exception:
            log.exception("Failed to run task callback", kind=kind, index=idx, callback=cb)
1408
1409
def _send_error_email_notification(
    task: BaseOperator | MappedOperator,
    ti: RuntimeTaskInstance,
    context: Context,
    error: BaseException | str | None,
    log: Logger,
) -> None:
    """
    Send email notification for task errors using SmtpNotifier.

    Best-effort: missing SMTP provider, missing recipients, or send failures are
    logged and swallowed so they never mask the original task failure.

    :param task: the task whose ``email`` attribute provides the recipients.
    :param ti: the failing/retrying task instance (used for template context).
    :param context: execution context merged into the email templates.
    :param error: the exception (or message) that caused the notification.
    :param log: logger for reporting notification problems.
    """
    try:
        from airflow.providers.smtp.notifications.smtp import SmtpNotifier
    except ImportError:
        log.error(
            "Failed to send task failure or retry email notification: "
            "`apache-airflow-providers-smtp` is not installed. "
            "Install this provider to enable email notifications."
        )
        return

    # No recipients configured on the task -> nothing to do.
    # (This is the single recipients guard; a second identical check further down was dead code.)
    if not task.email:
        return

    subject_template_file = conf.get("email", "subject_template", fallback=None)

    # Read the template file if configured
    if subject_template_file and Path(subject_template_file).exists():
        subject = Path(subject_template_file).read_text()
    else:
        # Fallback to default
        subject = "Airflow alert: {{ti}}"

    html_content_template_file = conf.get("email", "html_content_template", fallback=None)

    # Read the template file if configured
    if html_content_template_file and Path(html_content_template_file).exists():
        html_content = Path(html_content_template_file).read_text()
    else:
        # Fallback to default
        # For reporting purposes, we report based on 1-indexed,
        # not 0-indexed lists (i.e. Try 1 instead of Try 0 for the first attempt).
        html_content = (
            "Try {{try_number}} out of {{max_tries + 1}}<br>"
            "Exception:<br>{{exception_html}}<br>"
            'Log: <a href="{{ti.log_url}}">Link</a><br>'
            "Host: {{ti.hostname}}<br>"
            'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
        )

    # Add exception_html to context for template rendering
    import html

    exception_html = html.escape(str(error)).replace("\n", "<br>")
    additional_context = {
        "exception": error,
        "exception_html": exception_html,
        "try_number": ti.try_number,
        "max_tries": ti.max_tries,
    }
    email_context = {**context, **additional_context}

    try:
        notifier = SmtpNotifier(
            to=task.email,
            subject=subject,
            html_content=html_content,
            from_email=conf.get("email", "from_email", fallback="airflow@airflow"),
        )
        notifier(email_context)
    except Exception:
        log.exception("Failed to send email notification")
1482
1483
def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
    """
    Execute Task (optionally with a Timeout) and push Xcom results.

    Resumes a deferred task via ``resume_execution`` when the server supplied a
    ``next_method``; runs pre/post execute hooks and on_execute callbacks around
    the actual ``execute`` call.
    """
    task = ti.task
    execute = task.execute

    if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method):
        # Resuming from deferral: deserialize the kwargs the trigger produced.
        from airflow.sdk.serde import deserialize

        next_kwargs_data = ti._ti_context_from_server.next_kwargs or {}
        try:
            if TYPE_CHECKING:
                assert isinstance(next_kwargs_data, dict)
            kwargs = deserialize(next_kwargs_data)
        except (ImportError, KeyError, AttributeError, TypeError):
            # Fall back to the legacy serialization format.
            from airflow.serialization.serialized_objects import BaseSerialization

            kwargs = BaseSerialization.deserialize(next_kwargs_data)

        if TYPE_CHECKING:
            assert isinstance(kwargs, dict)
        execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)

    ctx = contextvars.copy_context()
    # Populate the context var so ExecutorSafeguard doesn't complain
    ctx.run(ExecutorSafeguard.tracker.set, task)

    # Export context in os.environ to make it available for operators to use.
    airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
    os.environ.update(airflow_context_vars)

    outlet_events = context_get_outlet_events(context)

    if (pre_execute_hook := task._pre_execute_hook) is not None:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
    if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)

    _run_task_state_change_callbacks(task, "on_execute_callback", context, log)

    if task.execution_timeout:
        from airflow.sdk.execution_time.timeout import timeout

        # TODO: handle timeout in case of deferral
        timeout_seconds = task.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run task in timeout wrapper
            with timeout(timeout_seconds):
                result = ctx.run(execute, context=context)
        except AirflowTaskTimeout:
            task.on_kill()
            raise
    else:
        result = ctx.run(execute, context=context)

    if (post_execute_hook := task._post_execute_hook) is not None:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
    # BUGFIX: pass `result` to the overridden post_execute as well — `post_execute(context, result)`
    # expects it, and the `_post_execute_hook` branch above already passes it.
    if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)

    return result
1547
1548
def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None:
    """Render named map index if the Dag author defined map_index_template at the task level."""
    template = context.get("map_index_template")
    if template is None:
        return None
    log.debug("Rendering map_index_template", template_length=len(template))
    env = ti.task.dag.get_template_env()
    rendered = env.from_string(template).render(context)
    log.debug("Map index rendered", length=len(rendered))
    return rendered
1558
1559
def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
    """
    Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result.

    Also validates mapping constraints: a task feeding a mapped downstream task must
    push a mappable, non-``None`` value.
    """
    # Only the task's return value is eligible for pushing.
    xcom_value = result if ti.task.do_xcom_push else None

    has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None
    if xcom_value is None:
        if not ti.is_mapped and has_mapped_dep:
            # Uhoh, a downstream mapped task depends on us to push something to map over
            from airflow.sdk.exceptions import XComForMappingNotPushed

            raise XComForMappingNotPushed()
        return

    mapped_length: int | None = None
    if not ti.is_mapped and has_mapped_dep:
        from airflow.sdk.definitions.mappedoperator import is_mappable_value
        from airflow.sdk.exceptions import UnmappableXComTypePushed

        if not is_mappable_value(xcom_value):
            raise UnmappableXComTypePushed(xcom_value)
        mapped_length = len(xcom_value)

    log.info("Pushing xcom", ti=ti)

    # If the task has multiple outputs, push each output as a separate XCom.
    if ti.task.multiple_outputs:
        if not isinstance(xcom_value, Mapping):
            raise TypeError(
                f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
            )
        for key in xcom_value.keys():
            if not isinstance(key, str):
                raise TypeError(
                    "Returned dictionary keys must be strings when using "
                    f"multiple_outputs, found {key} ({type(key)}) instead"
                )
        # Consistently use xcom_value (identical to `result` here, since do_xcom_push is True).
        for k, v in xcom_value.items():
            ti.xcom_push(k, v)

    _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, xcom_value, mapped_length=mapped_length)
1603
1604
def finalize(
    ti: RuntimeTaskInstance,
    state: TaskInstanceState,
    context: Context,
    log: Logger,
    error: BaseException | None = None,
):
    """
    Run post-execution bookkeeping for a finished task instance.

    Emits duration metrics, pushes operator-extra-link XComs, optionally
    overwrites rendered template fields, then dispatches the state-specific
    callbacks, listener hooks and (for retry/failure) error email notifications.
    Listener and callback errors are logged, never raised.
    """
    # Record task duration metrics for all terminal states
    if ti.start_date and ti.end_date:
        duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000
        stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

        Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms)
        Stats.timing("task.duration", duration_ms, tags=stats_tags)

    task = ti.task
    # Pushing xcom for each operator extra links defined on the operator only.
    for oe in task.operator_extra_links:
        try:
            link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key  # type: ignore[arg-type]
            log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
            _xcom_push_to_db(ti, key=xcom_key, value=link)
        except Exception:
            # A broken extra link must not fail finalization of the task.
            log.exception(
                "Failed to push an xcom for task operator extra link",
                link_name=oe.name,
                xcom_key=oe.xcom_key,
                ti=ti,
            )

    if getattr(ti.task, "overwrite_rtif_after_execution", False):
        log.debug("Overwriting Rendered template fields.")
        if ti.task.template_fields:
            SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)))

    # Dispatch per-terminal-state callbacks and listener hooks.
    log.debug("Running finalizers", ti=ti)
    if state == TaskInstanceState.SUCCESS:
        _run_task_state_change_callbacks(task, "on_success_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_success(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.SKIPPED:
        _run_task_state_change_callbacks(task, "on_skipped_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_skipped(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.UP_FOR_RETRY:
        _run_task_state_change_callbacks(task, "on_retry_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_retry and task.email:
            _send_error_email_notification(task, ti, context, error, log)
    elif state == TaskInstanceState.FAILED:
        _run_task_state_change_callbacks(task, "on_failure_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_failure and task.email:
            _send_error_email_notification(task, ti, context, error, log)

    try:
        get_listener_manager().hook.before_stopping(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")
1682
1683
def main():
    """
    Entrypoint for the task-runner child process.

    Sets up supervisor comms and metrics, runs the task under its bundle version
    lock, finalizes, and always closes the comms socket on exit. Exit code 2 on
    Ctrl-C, 1 on any other top-level error.
    """
    log = structlog.get_logger(logger_name="task")

    # The comms channel is module-global so that helpers throughout this module can use it.
    global SUPERVISOR_COMMS
    SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)

    Stats.initialize(
        is_statsd_datadog_enabled=conf.getboolean("metrics", "statsd_datadog_enabled"),
        is_statsd_on=conf.getboolean("metrics", "statsd_on"),
        is_otel_on=conf.getboolean("metrics", "otel_on"),
    )

    try:
        ti, context, log = startup()
        with BundleVersionLock(
            bundle_name=ti.bundle_instance.name,
            bundle_version=ti.bundle_instance.version,
        ):
            state, _, error = run(ti, context, log)
            context["exception"] = error
            finalize(ti, state, context, log, error)
    except KeyboardInterrupt:
        log.exception("Ctrl-c hit")
        exit(2)
    except Exception:
        log.exception("Top level error")
        exit(1)
    finally:
        # Ensure the request socket is closed on the child side in all circumstances
        # before the process fully terminates.
        if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket:
            with suppress(Exception):
                SUPERVISOR_COMMS.socket.close()
1717
1718
def reinit_supervisor_comms() -> None:
    """
    Re-initialize supervisor comms and logging channel in subprocess.

    This is not needed for most cases, but is used when either we re-launch the process via sudo for
    run_as_user, or from inside the python code in a virtualenv (et al.) operator to re-connect so those tasks
    can continue to access variables etc.
    """
    import socket

    if "SUPERVISOR_COMMS" not in globals():
        global SUPERVISOR_COMMS
        log = structlog.get_logger(logger_name="task")

        # NOTE(review): presumably the supervisor exports the comms socket FD in this env var;
        # the "0" fallback means stdin is used when it is absent — confirm against supervisor code.
        fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))

        SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd))

    # Ask the supervisor to re-send the logging FD so log output keeps flowing after re-exec.
    logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
    if isinstance(logs, SentFDs):
        from airflow.sdk.log import configure_logging

        # Unbuffered binary stream over the inherited FD; structlog writes JSON lines into it.
        log_io = os.fdopen(logs.fds[0], "wb", buffering=0)
        configure_logging(json_output=True, output=log_io, sending_to_supervisor=True)
    else:
        print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)
1745
1746
if __name__ == "__main__":
    # Allow running this module directly; the supervisor normally launches it via `python -c`.
    main()