1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""The entrypoint for the actual task execution process."""
19
20from __future__ import annotations
21
22import contextlib
23import contextvars
24import functools
25import os
26import sys
27import time
28from collections.abc import Callable, Iterable, Iterator, Mapping
29from contextlib import suppress
30from datetime import datetime, timezone
31from itertools import product
32from pathlib import Path
33from typing import TYPE_CHECKING, Annotated, Any, Literal
34from urllib.parse import quote
35
36import attrs
37import lazy_object_proxy
38import structlog
39from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter
40
41from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
42from airflow.dag_processing.bundles.manager import DagBundlesManager
43from airflow.listeners.listener import get_listener_manager
44from airflow.sdk.api.client import get_hostname, getuser
45from airflow.sdk.api.datamodels._generated import (
46 AssetProfile,
47 DagRun,
48 PreviousTIResponse,
49 TaskInstance,
50 TaskInstanceState,
51 TIRunContext,
52)
53from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard
54from airflow.sdk.bases.xcom import BaseXCom
55from airflow.sdk.configuration import conf
56from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
57from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
58from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
59from airflow.sdk.definitions.mappedoperator import MappedOperator
60from airflow.sdk.definitions.param import process_params
61from airflow.sdk.exceptions import (
62 AirflowException,
63 AirflowInactiveAssetInInletOrOutletException,
64 AirflowRuntimeError,
65 AirflowTaskTimeout,
66 ErrorType,
67 TaskDeferred,
68)
69from airflow.sdk.execution_time.callback_runner import create_executable_runner
70from airflow.sdk.execution_time.comms import (
71 AssetEventDagRunReferenceResult,
72 CommsDecoder,
73 DagRunStateResult,
74 DeferTask,
75 DRCount,
76 ErrorResponse,
77 GetDagRunState,
78 GetDRCount,
79 GetPreviousDagRun,
80 GetPreviousTI,
81 GetTaskBreadcrumbs,
82 GetTaskRescheduleStartDate,
83 GetTaskStates,
84 GetTICount,
85 InactiveAssetsResult,
86 PreviousDagRunResult,
87 PreviousTIResult,
88 RescheduleTask,
89 ResendLoggingFD,
90 RetryTask,
91 SentFDs,
92 SetRenderedFields,
93 SetRenderedMapIndex,
94 SkipDownstreamTasks,
95 StartupDetails,
96 SucceedTask,
97 TaskBreadcrumbsResult,
98 TaskRescheduleStartDate,
99 TaskState,
100 TaskStatesResult,
101 TICount,
102 ToSupervisor,
103 ToTask,
104 TriggerDagRun,
105 ValidateInletsAndOutlets,
106)
107from airflow.sdk.execution_time.context import (
108 ConnectionAccessor,
109 InletEventsAccessors,
110 MacrosAccessor,
111 OutletEventAccessors,
112 TriggeringAssetEventsAccessor,
113 VariableAccessor,
114 context_get_outlet_events,
115 context_to_airflow_vars,
116 get_previous_dagrun_success,
117 set_current_context,
118)
119from airflow.sdk.execution_time.sentry import Sentry
120from airflow.sdk.execution_time.xcom import XCom
121from airflow.sdk.observability.stats import Stats
122from airflow.sdk.timezone import coerce_datetime
123
124if TYPE_CHECKING:
125 import jinja2
126 from pendulum.datetime import DateTime
127 from structlog.typing import FilteringBoundLogger as Logger
128
129 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
130 from airflow.sdk.definitions.context import Context
131 from airflow.sdk.exceptions import DagRunTriggerException
132 from airflow.sdk.types import OutletEventAccessorsProtocol
133
134
135class TaskRunnerMarker:
136 """Marker for listener hooks, to properly detect from which component they are called."""
137
138
139# TODO: Move this entire class into a separate file:
140# `airflow/sdk/execution_time/task_instance.py`
141# or `airflow/sdk/execution_time/runtime_ti.py`
142class RuntimeTaskInstance(TaskInstance):
143 model_config = ConfigDict(arbitrary_types_allowed=True)
144
145 task: BaseOperator
146 bundle_instance: BaseDagBundle
147 _cached_template_context: Context | None = None
148 """The Task Instance context. This is used to cache get_template_context."""
149
150 _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
151 """The Task Instance context from the API server, if any."""
152
153 max_tries: int = 0
154 """The maximum number of retries for the task."""
155
156 start_date: AwareDatetime
157 """Start date of the task instance."""
158
159 end_date: AwareDatetime | None = None
160
161 state: TaskInstanceState | None = None
162
163 is_mapped: bool | None = None
164 """True if the original task was mapped."""
165
166 rendered_map_index: str | None = None
167
168 sentry_integration: str = ""
169
170 def __rich_repr__(self):
171 yield "id", self.id
172 yield "task_id", self.task_id
173 yield "dag_id", self.dag_id
174 yield "run_id", self.run_id
175 yield "max_tries", self.max_tries
176 yield "task", type(self.task)
177 yield "start_date", self.start_date
178
179 __rich_repr__.angular = True # type: ignore[attr-defined]
180
181 def get_template_context(self) -> Context:
182 # TODO: Move this to `airflow.sdk.execution_time.context`
183 # once we port the entire context logic from airflow/utils/context.py ?
184 from airflow.plugins_manager import integrate_macros_plugins
185
186 integrate_macros_plugins()
187
188 dag_run_conf: dict[str, Any] | None = None
189 if from_server := self._ti_context_from_server:
190 dag_run_conf = from_server.dag_run.conf or dag_run_conf
191
192 validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)
193
194 # Cache the context object, which ensures that all calls to get_template_context
195 # are operating on the same context object.
196 self._cached_template_context: Context = self._cached_template_context or {
197 # From the Task Execution interface
198 "dag": self.task.dag,
199 "inlets": self.task.inlets,
200 "map_index_template": self.task.map_index_template,
201 "outlets": self.task.outlets,
202 "run_id": self.run_id,
203 "task": self.task,
204 "task_instance": self,
205 "ti": self,
206 "outlet_events": OutletEventAccessors(),
207 "inlet_events": InletEventsAccessors(self.task.inlets),
208 "macros": MacrosAccessor(),
209 "params": validated_params,
210 # TODO: Make this go through Public API longer term.
211 # "test_mode": task_instance.test_mode,
212 "var": {
213 "json": VariableAccessor(deserialize_json=True),
214 "value": VariableAccessor(deserialize_json=False),
215 },
216 "conn": ConnectionAccessor(),
217 }
218 if from_server:
219 dag_run = from_server.dag_run
220 context_from_server: Context = {
221 # TODO: Assess if we need to pass these through timezone.coerce_datetime
222 "dag_run": dag_run, # type: ignore[typeddict-item] # Removable after #46522
223 "triggering_asset_events": TriggeringAssetEventsAccessor.build(
224 AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
225 for event in dag_run.consumed_asset_events
226 ),
227 "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
228 "task_reschedule_count": from_server.task_reschedule_count or 0,
229 "prev_start_date_success": lazy_object_proxy.Proxy(
230 lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date)
231 ),
232 "prev_end_date_success": lazy_object_proxy.Proxy(
233 lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
234 ),
235 }
236 self._cached_template_context.update(context_from_server)
237
238 if logical_date := coerce_datetime(dag_run.logical_date):
239 if TYPE_CHECKING:
240 assert isinstance(logical_date, DateTime)
241 ds = logical_date.strftime("%Y-%m-%d")
242 ds_nodash = ds.replace("-", "")
243 ts = logical_date.isoformat()
244 ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
245 ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
                # logical_date and data_interval either coexist or are None together
247 self._cached_template_context.update(
248 {
249 # keys that depend on logical_date
250 "logical_date": logical_date,
251 "ds": ds,
252 "ds_nodash": ds_nodash,
253 "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}",
254 "ts": ts,
255 "ts_nodash": ts_nodash,
256 "ts_nodash_with_tz": ts_nodash_with_tz,
257 # keys that depend on data_interval
258 "data_interval_end": coerce_datetime(dag_run.data_interval_end),
259 "data_interval_start": coerce_datetime(dag_run.data_interval_start),
260 "prev_data_interval_start_success": lazy_object_proxy.Proxy(
261 lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start)
262 ),
263 "prev_data_interval_end_success": lazy_object_proxy.Proxy(
264 lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end)
265 ),
266 }
267 )
268
269 if from_server.upstream_map_indexes is not None:
                # We stash this in here for later use, but we purposefully don't want to document its
                # existence. Perhaps this should be a private attribute on RuntimeTI instead?
272 setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)
273
274 return self._cached_template_context
275
276 def render_templates(
277 self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
278 ) -> BaseOperator:
279 """
280 Render templates in the operator fields.
281
282 If the task was originally mapped, this may replace ``self.task`` with
283 the unmapped, fully rendered BaseOperator. The original ``self.task``
284 before replacement is returned.
285 """
286 if not context:
287 context = self.get_template_context()
288 original_task = self.task
289
290 if TYPE_CHECKING:
291 assert context
292
293 ti = context["ti"]
294
295 if TYPE_CHECKING:
296 assert original_task
297 assert self.task
298 assert ti.task
299
300 # If self.task is mapped, this call replaces self.task to point to the
301 # unmapped BaseOperator created by this function! This is because the
302 # MappedOperator is useless for template rendering, and we need to be
303 # able to access the unmapped task instead.
304 self.task.render_template_fields(context, jinja_env)
305 self.is_mapped = original_task.is_mapped
306 return original_task
307
308 def xcom_pull(
309 self,
310 task_ids: str | Iterable[str] | None = None,
311 dag_id: str | None = None,
312 key: str = BaseXCom.XCOM_RETURN_KEY,
313 include_prior_dates: bool = False,
314 *,
315 map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
316 default: Any = None,
317 run_id: str | None = None,
318 ) -> Any:
319 """
        Pull XComs either from the API server (BaseXCom) or from the custom XCom backend if configured.

        The pull can optionally be filtered by certain criteria.
323
324 :param key: A key for the XCom. If provided, only XComs with matching
325 keys will be returned. The default key is ``'return_value'``, also
326 available as constant ``XCOM_RETURN_KEY``. This key is automatically
327 given to XComs returned by tasks (as opposed to being pushed
328 manually). To remove the filter, pass *None*.
329 :param task_ids: Only XComs from tasks with matching ids will be
330 pulled. Pass *None* to remove the filter.
331 :param dag_id: If provided, only pulls XComs from this Dag. If *None*
332 (default), the Dag of the calling task is used.
333 :param map_indexes: If provided, only pull XComs with matching indexes.
334 If *None* (default), this is inferred from the task(s) being pulled
335 (see below for details).
        :param include_prior_dates: If *False* (default), only XComs from the
            current logical_date are returned. If *True*, XComs from previous
            dates are returned as well.
        :param run_id: If provided, only pulls XComs from a DagRun with a matching run_id.
            If *None* (default), the run_id of the calling task is used.
341
        When pulling a single task (``task_ids`` is *None* or a str) without
        specifying ``map_indexes``, the return value is a single XCom entry
        (``map_indexes`` is set to the map_index of the calling task instance).

        When the pulled task is mapped, the specified ``map_index`` is used, so
        by default pulling from a mapped task yields no matching XComs if the
        calling task instance is not itself mapped; otherwise, the map_index of
        the calling task instance is used. Setting ``map_indexes`` to *None*
        pulls XComs as it would from a non-mapped task.
351
352 In either case, ``default`` (*None* if not specified) is returned if no
353 matching XComs are found.
354
        When pulling multiple tasks (i.e. either ``task_ids`` or ``map_indexes``
        is a non-str iterable), a list of matching XComs is returned. Elements
        in the list are ordered by item ordering in ``task_ids`` and ``map_indexes``.
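
        A minimal, illustrative example (task ids are hypothetical)::

            # single upstream task, default key -> a single value
            value = ti.xcom_pull(task_ids="extract")
            # several upstream tasks -> a list ordered like the inputs
            values = ti.xcom_pull(task_ids=["extract_a", "extract_b"])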
358 """
359 if dag_id is None:
360 dag_id = self.dag_id
361 if run_id is None:
362 run_id = self.run_id
363
364 single_task_requested = isinstance(task_ids, (str, type(None)))
365 single_map_index_requested = isinstance(map_indexes, (int, type(None)))
366
367 if task_ids is None:
368 # default to the current task if not provided
369 task_ids = [self.task_id]
370 elif isinstance(task_ids, str):
371 task_ids = [task_ids]
372
373 # If map_indexes is not specified, pull xcoms from all map indexes for each task
374 if not is_arg_set(map_indexes):
375 xcoms: list[Any] = []
376 for t_id in task_ids:
377 values = XCom.get_all(
378 run_id=run_id,
379 key=key,
380 task_id=t_id,
381 dag_id=dag_id,
382 include_prior_dates=include_prior_dates,
383 )
384
385 if values is None:
386 xcoms.append(None)
387 else:
388 xcoms.extend(values)
389 # For single task pulling from unmapped task, return single value
390 if single_task_requested and len(xcoms) == 1:
391 return xcoms[0]
392 return xcoms
393
394 # Original logic when map_indexes is explicitly specified
395 map_indexes_iterable: Iterable[int | None] = []
396 if isinstance(map_indexes, int) or map_indexes is None:
397 map_indexes_iterable = [map_indexes]
398 elif isinstance(map_indexes, Iterable):
399 map_indexes_iterable = map_indexes
400 else:
401 raise TypeError(
402 f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}"
403 )
404
405 xcoms = []
406 for t_id, m_idx in product(task_ids, map_indexes_iterable):
407 value = XCom.get_one(
408 run_id=run_id,
409 key=key,
410 task_id=t_id,
411 dag_id=dag_id,
412 map_index=m_idx,
413 include_prior_dates=include_prior_dates,
414 )
415 if value is None:
416 xcoms.append(default)
417 else:
418 xcoms.append(value)
419
420 if single_task_requested and single_map_index_requested:
421 return xcoms[0]
422
423 return xcoms
424
425 def xcom_push(self, key: str, value: Any):
426 """
427 Make an XCom available for tasks to pull.
428
429 :param key: Key to store the value under.
        :param value: Value to store. Only JSON-serializable values may be used.
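
        A minimal, illustrative example (key and value are hypothetical)::

            ti.xcom_push(key="row_count", value=42)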
431 """
432 _xcom_push(self, key, value)
433
434 def get_relevant_upstream_map_indexes(
435 self, upstream: BaseOperator, ti_count: int | None, session: Any
436 ) -> int | range | None:
437 # TODO: Implement this method
438 return None
439
440 def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
        """Return the first reschedule date for the task instance if found, *None* otherwise."""
442 if context.get("task_reschedule_count", 0) == 0:
443 # If the task has not been rescheduled, there is no need to ask the supervisor
444 return None
445
446 max_tries: int = self.max_tries
447 retries: int = self.task.retries or 0
448 first_try_number = max_tries - retries + 1
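        # Illustrative: a task with retries=2 that failed and was cleared once may
        # arrive here with max_tries=5, so the first attempt of the current run
        # was try_number 5 - 2 + 1 = 4.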
449
450 log = structlog.get_logger(logger_name="task")
451
452 log.debug("Requesting first reschedule date from supervisor")
453
454 response = SUPERVISOR_COMMS.send(
455 msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)
456 )
457
458 if TYPE_CHECKING:
459 assert isinstance(response, TaskRescheduleStartDate)
460
461 return response.start_date
462
463 def get_previous_dagrun(self, state: str | None = None) -> DagRun | None:
464 """Return the previous Dag run before the given logical date, optionally filtered by state."""
465 context = self.get_template_context()
466 dag_run = context.get("dag_run")
467
468 log = structlog.get_logger(logger_name="task")
469
470 log.debug("Getting previous Dag run", dag_run=dag_run)
471
472 if dag_run is None:
473 return None
474
475 if dag_run.logical_date is None:
476 return None
477
478 response = SUPERVISOR_COMMS.send(
479 msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state)
480 )
481
482 if TYPE_CHECKING:
483 assert isinstance(response, PreviousDagRunResult)
484
485 return response.dag_run
486
487 def get_previous_ti(
488 self,
489 state: TaskInstanceState | None = None,
490 logical_date: AwareDatetime | None = None,
491 map_index: int = -1,
492 ) -> PreviousTIResponse | None:
493 """
494 Return the previous task instance matching the given criteria.
495
496 :param state: Filter by TaskInstance state
497 :param logical_date: Filter by logical date (returns TI before this date)
498 :param map_index: Filter by map_index (defaults to -1 for non-mapped tasks)
499 :return: Previous task instance or None if not found
500 """
501 context = self.get_template_context()
502 dag_run = context.get("dag_run")
503
504 log = structlog.get_logger(logger_name="task")
505 log.debug("Getting previous task instance", task_id=self.task_id, state=state)
506
507 # Use current dag run's logical_date if not provided
508 effective_logical_date = logical_date
509 if effective_logical_date is None and dag_run and dag_run.logical_date:
510 effective_logical_date = dag_run.logical_date
511
512 response = SUPERVISOR_COMMS.send(
513 msg=GetPreviousTI(
514 dag_id=self.dag_id,
515 task_id=self.task_id,
516 logical_date=effective_logical_date,
517 map_index=map_index,
518 state=state,
519 )
520 )
521
522 if TYPE_CHECKING:
523 assert isinstance(response, PreviousTIResult)
524
525 return response.task_instance
526
527 @staticmethod
528 def get_ti_count(
529 dag_id: str,
530 map_index: int | None = None,
531 task_ids: list[str] | None = None,
532 task_group_id: str | None = None,
533 logical_dates: list[datetime] | None = None,
534 run_ids: list[str] | None = None,
535 states: list[str] | None = None,
536 ) -> int:
537 """Return the number of task instances matching the given criteria."""
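        # Illustrative: count successful TIs of a (hypothetical) task group in one run:
        #   RuntimeTaskInstance.get_ti_count(
        #       dag_id="my_dag", task_group_id="my_group", run_ids=["run_1"], states=["success"]
        #   )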
538 response = SUPERVISOR_COMMS.send(
539 GetTICount(
540 dag_id=dag_id,
541 map_index=map_index,
542 task_ids=task_ids,
543 task_group_id=task_group_id,
544 logical_dates=logical_dates,
545 run_ids=run_ids,
546 states=states,
547 ),
548 )
549
550 if TYPE_CHECKING:
551 assert isinstance(response, TICount)
552
553 return response.count
554
555 @staticmethod
556 def get_task_states(
557 dag_id: str,
558 map_index: int | None = None,
559 task_ids: list[str] | None = None,
560 task_group_id: str | None = None,
561 logical_dates: list[datetime] | None = None,
562 run_ids: list[str] | None = None,
563 ) -> dict[str, Any]:
564 """Return the task states matching the given criteria."""
565 response = SUPERVISOR_COMMS.send(
566 GetTaskStates(
567 dag_id=dag_id,
568 map_index=map_index,
569 task_ids=task_ids,
570 task_group_id=task_group_id,
571 logical_dates=logical_dates,
572 run_ids=run_ids,
573 ),
574 )
575
576 if TYPE_CHECKING:
577 assert isinstance(response, TaskStatesResult)
578
579 return response.task_states
580
581 @staticmethod
582 def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]:
583 """Return task breadcrumbs for the given dag run."""
584 response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id))
585 if TYPE_CHECKING:
586 assert isinstance(response, TaskBreadcrumbsResult)
587 return response.breadcrumbs
588
589 @staticmethod
590 def get_dr_count(
591 dag_id: str,
592 logical_dates: list[datetime] | None = None,
593 run_ids: list[str] | None = None,
594 states: list[str] | None = None,
595 ) -> int:
596 """Return the number of Dag runs matching the given criteria."""
597 response = SUPERVISOR_COMMS.send(
598 GetDRCount(
599 dag_id=dag_id,
600 logical_dates=logical_dates,
601 run_ids=run_ids,
602 states=states,
603 ),
604 )
605
606 if TYPE_CHECKING:
607 assert isinstance(response, DRCount)
608
609 return response.count
610
611 @staticmethod
612 def get_dagrun_state(dag_id: str, run_id: str) -> str:
613 """Return the state of the Dag run with the given Run ID."""
614 response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id))
615
616 if TYPE_CHECKING:
617 assert isinstance(response, DagRunStateResult)
618
619 return response.state
620
621 @property
622 def log_url(self) -> str:
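        # Builds a UI link of the form (illustrative values):
        #   http://localhost:8080/dags/my_dag/runs/my_run/tasks/my_task/mapped/0?try_number=2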
623 run_id = quote(self.run_id)
624 base_url = conf.get("api", "base_url", fallback="http://localhost:8080/")
625 map_index_value = self.map_index
626 map_index = (
627 f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else ""
628 )
629 try_number_value = self.try_number
630 try_number = (
631 f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else ""
632 )
633 _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}"
634 return _log_uri
635
636 @property
637 def mark_success_url(self) -> str:
638 """URL to mark TI success."""
639 return self.log_url
640
641
642def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
    """Push an XCom through XCom.set, which pushes to the XCom backend if configured."""
644 # Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK
645 # consumers
646
647 XCom.set(
648 key=key,
649 value=value,
650 dag_id=ti.dag_id,
651 task_id=ti.task_id,
652 run_id=ti.run_id,
653 map_index=ti.map_index,
654 _mapped_length=mapped_length,
655 )
656
657
658def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None:
    """Push an XCom directly to the metadata DB, bypassing any custom XCom backend."""
660 XCom._set_xcom_in_db(
661 key=key,
662 value=value,
663 dag_id=ti.dag_id,
664 task_id=ti.task_id,
665 run_id=ti.run_id,
666 map_index=ti.map_index,
667 )
668
669
670def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
671 # TODO: Task-SDK:
672 # Using DagBag here is about 98% wrong, but it'll do for now
673 from airflow.dag_processing.dagbag import DagBag
674
675 bundle_info = what.bundle_info
676 bundle_instance = DagBundlesManager().get_bundle(
677 name=bundle_info.name,
678 version=bundle_info.version,
679 )
680 bundle_instance.initialize()
681
    # Put the bundle root on sys.path if needed. This allows code in the bundle's
    # util modules to be shared between files within the same bundle.
684 if (bundle_root := os.fspath(bundle_instance.path)) not in sys.path:
685 sys.path.append(bundle_root)
686
687 dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
688 bag = DagBag(
689 dag_folder=dag_absolute_path,
690 include_examples=False,
691 safe_mode=False,
692 load_op_links=False,
693 bundle_name=bundle_info.name,
694 )
695 if TYPE_CHECKING:
696 assert what.ti.dag_id
697
698 try:
699 dag = bag.dags[what.ti.dag_id]
700 except KeyError:
701 log.error(
702 "Dag not found during start up", dag_id=what.ti.dag_id, bundle=bundle_info, path=what.dag_rel_path
703 )
704 sys.exit(1)
705
706 # install_loader()
707
708 try:
709 task = dag.task_dict[what.ti.task_id]
710 except KeyError:
711 log.error(
712 "Task not found in Dag during start up",
713 dag_id=dag.dag_id,
714 task_id=what.ti.task_id,
715 bundle=bundle_info,
716 path=what.dag_rel_path,
717 )
718 sys.exit(1)
719
720 if not isinstance(task, (BaseOperator, MappedOperator)):
721 raise TypeError(
722 f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
723 )
724
725 return RuntimeTaskInstance.model_construct(
726 **what.ti.model_dump(exclude_unset=True),
727 task=task,
728 bundle_instance=bundle_instance,
729 _ti_context_from_server=what.ti_context,
730 max_tries=what.ti_context.max_tries,
731 start_date=what.start_date,
732 state=TaskInstanceState.RUNNING,
733 sentry_integration=what.sentry_integration,
734 )
735
736
737# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution,
738# to send requests back to the supervisor process.
739#
740# Why it needs to be a global:
741# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests
742# to the parent process during task execution.
743# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the
744# deeply nested execution stack.
# - Defining `SUPERVISOR_COMMS` as a global ensures that this communication mechanism is readily
#   accessible wherever needed during task execution without modifying every layer of the call stack.
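#
# A typical request/response round-trip looks like this (illustrative message and values):
#
#   response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id="d", run_id="r"))
#   # `response` is a DagRunStateResult here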
747SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
748
749
750# State machine!
751# 1. Start up (receive details from supervisor)
752# 2. Execution (run task code, possibly send requests)
753# 3. Shutdown and report status
754
755
756def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
757 # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent
758 # in response to us sending a request.
759 log = structlog.get_logger(logger_name="task")
760
761 if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and (
762 msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG")
763 ):
        # Clear any Kerberos credential cache if there is one, so the new process can't reuse it.
765 os.environ.pop("KRB5CCNAME", None)
766 # entrypoint of re-exec process
767
768 msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson)
769 reinit_supervisor_comms()
770
771 # We delay this message until _after_ we've got the logging re-configured, otherwise it will show up
772 # on stdout
773 log.debug("Using serialized startup message from environment", msg=msg)
774 else:
775 # normal entry point
776 msg = SUPERVISOR_COMMS._get_response() # type: ignore[assignment]
777
778 if not isinstance(msg, StartupDetails):
779 raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
780
    # setproctitle causes issues on macOS: https://github.com/benoitc/gunicorn/issues/3021
782 os_type = sys.platform
783 if os_type == "darwin":
784 log.debug("Mac OS detected, skipping setproctitle")
785 else:
786 from setproctitle import setproctitle
787
788 setproctitle(f"airflow worker -- {msg.ti.id}")
789
790 try:
791 get_listener_manager().hook.on_starting(component=TaskRunnerMarker())
792 except Exception:
793 log.exception("error calling listener")
794
795 with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
796 ti = parse(msg, log)
797 log.debug("Dag file parsed", file=msg.dag_rel_path)
798
799 run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
800 "core", "default_impersonation", fallback=None
801 )
802
803 if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
        # this branch re-executes the current process under run_as_user via sudo
805 os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
806 # store startup message in environment for re-exec process
807 os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
808 os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)
809
810 # Import main directly from the module instead of re-executing the file.
        # This ensures that when other modules import
812 # airflow.sdk.execution_time.task_runner, they get the same module instance
813 # with the properly initialized SUPERVISOR_COMMS global variable.
814 # If we re-executed the module with `python -m`, it would load as __main__ and future
815 # imports would get a fresh copy without the initialized globals.
816 rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
817 cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
818 log.info(
819 "Running command",
820 command=cmd,
821 )
822 os.execvp("sudo", cmd)
823
824 # ideally, we should never reach here, but if we do, we should return None, None, None
825 return None, None, None
826
827 return ti, ti.get_template_context(), log
828
829
830def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float:
831 """
832 Return a serializable representation of the templated field.
833
    If ``template_field`` contains a class or instance that requires recursive
    templating, store it as a string. Otherwise simply return the field as-is.

    Uses the SDK secrets masker to redact secrets in the serialized output.
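
    For example (illustrative), a tuple such as ``("a", "b")`` is converted to the
    JSON-friendly list ``["a", "b"]``, and values longer than
    ``[core] max_templated_field_length`` are truncated.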
838 """
839 import json
840
841 from airflow.sdk._shared.secrets_masker import redact
842
843 def is_jsonable(x):
844 try:
845 json.dumps(x)
846 except (TypeError, OverflowError):
847 return False
848 else:
849 return True
850
851 def translate_tuples_to_lists(obj: Any):
852 """Recursively convert tuples to lists."""
853 if isinstance(obj, tuple):
854 return [translate_tuples_to_lists(item) for item in obj]
855 if isinstance(obj, list):
856 return [translate_tuples_to_lists(item) for item in obj]
857 if isinstance(obj, dict):
858 return {key: translate_tuples_to_lists(value) for key, value in obj.items()}
859 return obj
860
861 def sort_dict_recursively(obj: Any) -> Any:
862 """Recursively sort dictionaries to ensure consistent ordering."""
863 if isinstance(obj, dict):
864 return {k: sort_dict_recursively(v) for k, v in sorted(obj.items())}
865 if isinstance(obj, list):
866 return [sort_dict_recursively(item) for item in obj]
867 if isinstance(obj, tuple):
868 return tuple(sort_dict_recursively(item) for item in obj)
869 return obj
870
871 max_length = conf.getint("core", "max_templated_field_length")
872
873 if not is_jsonable(template_field):
874 try:
875 serialized = template_field.serialize()
876 except AttributeError:
877 serialized = str(template_field)
878 if len(serialized) > max_length:
879 rendered = redact(serialized, name)
880 return (
881 "Truncated. You can change this behaviour in [core]max_templated_field_length. "
882 f"{rendered[: max_length - 79]!r}... "
883 )
884 return serialized
885 if not template_field and not isinstance(template_field, tuple):
886 # Avoid unnecessary serialization steps for empty fields unless they are tuples
887 # and need to be converted to lists
888 return template_field
889 template_field = translate_tuples_to_lists(template_field)
890 # Sort dictionaries recursively to ensure consistent string representation
891 # This prevents hash inconsistencies when dict ordering varies
892 if isinstance(template_field, dict):
893 template_field = sort_dict_recursively(template_field)
894 serialized = str(template_field)
895 if len(serialized) > max_length:
896 rendered = redact(serialized, name)
897 return (
898 "Truncated. You can change this behaviour in [core]max_templated_field_length. "
899 f"{rendered[: max_length - 79]!r}... "
900 )
901 return template_field
902
903
904def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]:
905 from airflow.sdk._shared.secrets_masker import redact
906
907 rendered_fields = {}
908 for field in task.template_fields:
909 value = getattr(task, field)
910 serialized = _serialize_template_field(value, field)
        # Redact secrets in the task process itself before sending to the API server.
        # This ensures that secrets registered via mask_secret() on workers / dag
        # processors are properly masked in the UI.
914 rendered_fields[field] = redact(serialized, field)
915
916 return rendered_fields # type: ignore[return-value] # Convince mypy that this is OK since we pass JsonValue to redact, so it will return the same
917
918
919def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]:
920 # Lineage can have other types of objects besides assets, so we need to process them a bit.
921 for obj in lineage_objects or ():
922 if isinstance(obj, Asset):
923 yield AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__)
924 elif isinstance(obj, AssetNameRef):
925 yield AssetProfile(name=obj.name, type=AssetNameRef.__name__)
926 elif isinstance(obj, AssetUriRef):
927 yield AssetProfile(uri=obj.uri, type=AssetUriRef.__name__)
928 elif isinstance(obj, AssetAlias):
929 yield AssetProfile(name=obj.name, type=AssetAlias.__name__)
930
931
932def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]:
933 if TYPE_CHECKING:
934 assert isinstance(events, OutletEventAccessors)
935 # We just collect everything the user recorded in the accessors.
936 # Further filtering will be done in the API server.
937 for key, accessor in events._dict.items():
938 if isinstance(key, AssetUniqueKey):
939 yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra}
940 for alias_event in accessor.asset_alias_events:
941 yield attrs.asdict(alias_event)
942
943
944def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None:
945 ti.hostname = get_hostname()
946 ti.task = ti.task.prepare_for_execution()
    # Since the context is cached and `ti.get_template_context` returns the same dict,
    # we also update the task value it exposes.
949 context["task"] = ti.task
950
951 jinja_env = ti.task.dag.get_template_env()
952 ti.render_templates(context=context, jinja_env=jinja_env)
953
954 if rendered_fields := _serialize_rendered_fields(ti.task):
955 # so that we do not call the API unnecessarily
956 SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields))
957
958 # Try to render map_index_template early with available context (will be re-rendered after execution)
959 # This provides a partial label during task execution for templates using pre-execution context
960 # If rendering fails here, we suppress the error since it will be re-rendered after execution
961 try:
962 if rendered_map_index := _render_map_index(context, ti=ti, log=log):
963 ti.rendered_map_index = rendered_map_index
964 log.debug("Sending early rendered map index", length=len(rendered_map_index))
965 SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index))
966 except Exception:
967 log.debug(
968 "Early rendering of map_index_template failed, will retry after task execution", exc_info=True
969 )
970
971 _validate_task_inlets_and_outlets(ti=ti, log=log)
972
973 try:
974 # TODO: Call pre execute etc.
975 get_listener_manager().hook.on_task_instance_running(
976 previous_state=TaskInstanceState.QUEUED, task_instance=ti
977 )
978 except Exception:
979 log.exception("error calling listener")
980
981 # No error, carry on and execute the task
982 return None
983
984
985def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
986 if not ti.task.inlets and not ti.task.outlets:
987 return
988
989 inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id))
990 if TYPE_CHECKING:
991 assert isinstance(inactive_assets_resp, InactiveAssetsResult)
992 if inactive_assets := inactive_assets_resp.inactive_assets:
993 raise AirflowInactiveAssetInInletOrOutletException(
994 inactive_asset_keys=[
995 AssetUniqueKey.from_profile(asset_profile) for asset_profile in inactive_assets
996 ]
997 )
998
999
1000def _defer_task(
1001 defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
1002) -> tuple[ToSupervisor, TaskInstanceState]:
1003 # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?
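    # Operators reach this code path by raising, e.g. (illustrative trigger):
    #   raise TaskDeferred(trigger=SomeTrigger(...), method_name="execute_complete")
    # from execute(); the serialized trigger is then handed back to the supervisor.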
1004
    log.info("Pausing task as DEFERRED.", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
1006 classpath, trigger_kwargs = defer.trigger.serialize()
1007
1008 from airflow.sdk.serde import serialize as serde_serialize
1009
1010 trigger_kwargs = serde_serialize(trigger_kwargs)
1011 next_kwargs = serde_serialize(defer.kwargs or {})
1012
1013 if TYPE_CHECKING:
1014 assert isinstance(next_kwargs, dict)
1015 assert isinstance(trigger_kwargs, dict)
1016
1017 msg = DeferTask(
1018 classpath=classpath,
1019 trigger_kwargs=trigger_kwargs,
1020 trigger_timeout=defer.timeout,
1021 next_method=defer.method_name,
1022 next_kwargs=next_kwargs,
1023 )
1024 state = TaskInstanceState.DEFERRED
1025
1026 return msg, state
1027
1028
1029@Sentry.enrich_errors
1030def run(
1031 ti: RuntimeTaskInstance,
1032 context: Context,
1033 log: Logger,
1034) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
1035 """Run the task in this process."""
1036 import signal
1037
1038 from airflow.sdk.exceptions import (
1039 AirflowFailException,
1040 AirflowRescheduleException,
1041 AirflowSensorTimeout,
1042 AirflowSkipException,
1043 AirflowTaskTerminated,
1044 DagRunTriggerException,
1045 DownstreamTasksSkipped,
1046 TaskDeferred,
1047 )
1048
1049 if TYPE_CHECKING:
1050 assert ti.task is not None
1051 assert isinstance(ti.task, BaseOperator)
1052
1053 parent_pid = os.getpid()
1054
1055 def _on_term(signum, frame):
1056 pid = os.getpid()
1057 if pid != parent_pid:
1058 return
1059
1060 ti.task.on_kill()
1061
1062 signal.signal(signal.SIGTERM, _on_term)
1063
1064 msg: ToSupervisor | None = None
1065 state: TaskInstanceState
1066 error: BaseException | None = None
1067
1068 try:
1069 # First, clear the xcom data sent from server
1070 if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear):
1071 for x in keys_to_delete:
1072 log.debug("Clearing XCom with key", key=x)
1073 XCom.delete(
1074 key=x,
1075 dag_id=ti.dag_id,
1076 task_id=ti.task_id,
1077 run_id=ti.run_id,
1078 map_index=ti.map_index,
1079 )
1080
1081 with set_current_context(context):
            # This is the earliest that we can render templates -- if it raises for any reason we need
            # to catch it and handle it like a normal task failure
1084 if early_exit := _prepare(ti, log, context):
1085 msg = early_exit
1086 ti.state = state = TaskInstanceState.FAILED
1087 return state, msg, error
1088
1089 try:
1090 result = _execute_task(context=context, ti=ti, log=log)
1091 except Exception:
1092 import jinja2
1093
1094 # If the task failed, swallow rendering error so it doesn't mask the main error.
1095 with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
1096 previous_rendered_map_index = ti.rendered_map_index
1097 ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
1098 # Send update only if value changed (e.g., user set context variables during execution)
1099 if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
1100 SUPERVISOR_COMMS.send(
1101 msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)
1102 )
1103 raise
1104 else: # If the task succeeded, render normally to let rendering error bubble up.
1105 previous_rendered_map_index = ti.rendered_map_index
1106 ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
1107 # Send update only if value changed (e.g., user set context variables during execution)
1108 if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
1109 SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index))
1110
1111 _push_xcom_if_needed(result, ti, log)
1112
1113 msg, state = _handle_current_task_success(context, ti)
1114 except DownstreamTasksSkipped as skip:
1115 log.info("Skipping downstream tasks.")
1116 tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks]
1117 SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip))
1118 msg, state = _handle_current_task_success(context, ti)
1119 except DagRunTriggerException as drte:
1120 msg, state = _handle_trigger_dag_run(drte, context, ti, log)
1121 except TaskDeferred as defer:
1122 msg, state = _defer_task(defer, ti, log)
1123 except AirflowSkipException as e:
1124 if e.args:
1125 log.info("Skipping task.", reason=e.args[0])
1126 msg = TaskState(
1127 state=TaskInstanceState.SKIPPED,
1128 end_date=datetime.now(tz=timezone.utc),
1129 rendered_map_index=ti.rendered_map_index,
1130 )
1131 state = TaskInstanceState.SKIPPED
1132 except AirflowRescheduleException as reschedule:
1133 log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
1134 msg = RescheduleTask(
1135 reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
1136 )
1137 state = TaskInstanceState.UP_FOR_RESCHEDULE
1138 except (AirflowFailException, AirflowSensorTimeout) as e:
1139 # If AirflowFailException is raised, task should not retry.
1140 # If a sensor in reschedule mode reaches timeout, task should not retry.
1141 log.exception("Task failed with exception")
1142 ti.end_date = datetime.now(tz=timezone.utc)
1143 msg = TaskState(
1144 state=TaskInstanceState.FAILED,
1145 end_date=ti.end_date,
1146 rendered_map_index=ti.rendered_map_index,
1147 )
1148 state = TaskInstanceState.FAILED
1149 error = e
1150 except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
1151 # We should allow retries if the task has defined it.
1152 log.exception("Task failed with exception")
1153 msg, state = _handle_current_task_failed(ti)
1154 error = e
1155 except AirflowTaskTerminated as e:
        # External state updates are already handled via the TI heartbeat and will
        # already have been applied by another API (e.g. from the UI), so these
        # exceptions should ideally never be thrown. If they are, mark the TI as failed.
1159 log.exception("Task failed with exception")
1160 ti.end_date = datetime.now(tz=timezone.utc)
1161 msg = TaskState(
1162 state=TaskInstanceState.FAILED,
1163 end_date=ti.end_date,
1164 rendered_map_index=ti.rendered_map_index,
1165 )
1166 state = TaskInstanceState.FAILED
1167 error = e
1168 except SystemExit as e:
        # A SystemExit needs to be retried if the task is eligible.
1170 log.error("Task exited", exit_code=e.code)
1171 msg, state = _handle_current_task_failed(ti)
1172 error = e
1173 except BaseException as e:
1174 log.exception("Task failed with exception")
1175 msg, state = _handle_current_task_failed(ti)
1176 error = e
1177 finally:
1178 if msg:
1179 SUPERVISOR_COMMS.send(msg=msg)
1180
1181 # Return the message to make unit tests easier too
1182 ti.state = state
1183 return state, msg, error
1184
1185
1186def _handle_current_task_success(
1187 context: Context,
1188 ti: RuntimeTaskInstance,
1189) -> tuple[SucceedTask, TaskInstanceState]:
1190 end_date = datetime.now(tz=timezone.utc)
1191 ti.end_date = end_date
1192
1193 # Record operator and task instance success metrics
1194 operator = ti.task.__class__.__name__
1195 stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
1196
1197 Stats.incr(f"operator_successes_{operator}", tags=stats_tags)
1198 # Same metric with tagging
1199 Stats.incr("operator_successes", tags={**stats_tags, "operator": operator})
1200 Stats.incr("ti_successes", tags=stats_tags)
1201
1202 task_outlets = list(_build_asset_profiles(ti.task.outlets))
1203 outlet_events = list(_serialize_outlet_events(context["outlet_events"]))
1204 msg = SucceedTask(
1205 end_date=end_date,
1206 task_outlets=task_outlets,
1207 outlet_events=outlet_events,
1208 rendered_map_index=ti.rendered_map_index,
1209 )
1210 return msg, TaskInstanceState.SUCCESS
1211
1212
1213def _handle_current_task_failed(
1214 ti: RuntimeTaskInstance,
1215) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
1216 end_date = datetime.now(tz=timezone.utc)
1217 ti.end_date = end_date
1218 if ti._ti_context_from_server and ti._ti_context_from_server.should_retry:
1219 return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY
1220 return (
1221 TaskState(
1222 state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index
1223 ),
1224 TaskInstanceState.FAILED,
1225 )
1226
1227
1228def _handle_trigger_dag_run(
1229 drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger
1230) -> tuple[ToSupervisor, TaskInstanceState]:
1231 """Handle exception from TriggerDagRunOperator."""
1232 log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
1233 comms_msg = SUPERVISOR_COMMS.send(
1234 TriggerDagRun(
1235 dag_id=drte.trigger_dag_id,
1236 run_id=drte.dag_run_id,
1237 logical_date=drte.logical_date,
1238 conf=drte.conf,
1239 reset_dag_run=drte.reset_dag_run,
1240 ),
1241 )
1242
1243 if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
1244 if drte.skip_when_already_exists:
1245 log.info(
1246 "Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
1247 dag_id=drte.trigger_dag_id,
1248 )
1249 msg = TaskState(
1250 state=TaskInstanceState.SKIPPED,
1251 end_date=datetime.now(tz=timezone.utc),
1252 rendered_map_index=ti.rendered_map_index,
1253 )
1254 state = TaskInstanceState.SKIPPED
1255 else:
1256 log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
1257 msg = TaskState(
1258 state=TaskInstanceState.FAILED,
1259 end_date=datetime.now(tz=timezone.utc),
1260 rendered_map_index=ti.rendered_map_index,
1261 )
1262 state = TaskInstanceState.FAILED
1263
1264 return msg, state
1265
1266 log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)
1267
1268 # Store the run id from the dag run (either created or found above) to
1269 # be used when creating the extra link on the webserver.
1270 ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)
1271
1272 if drte.deferrable:
1273 from airflow.providers.standard.triggers.external_task import DagStateTrigger
1274
1275 defer = TaskDeferred(
1276 trigger=DagStateTrigger(
1277 dag_id=drte.trigger_dag_id,
1278 states=drte.allowed_states + drte.failed_states, # type: ignore[arg-type]
1279 # Don't filter by execution_dates when run_ids is provided.
1280 # run_id uniquely identifies a DAG run, and when reset_dag_run=True,
1281 # drte.logical_date might be a newly calculated value that doesn't match
1282 # the persisted logical_date in the database, causing the trigger to never find the run.
1283 execution_dates=None,
1284 run_ids=[drte.dag_run_id],
1285 poll_interval=drte.poke_interval,
1286 ),
1287 method_name="execute_complete",
1288 )
1289 return _defer_task(defer, ti, log)
1290 if drte.wait_for_completion:
1291 while True:
1292 log.info(
1293 "Waiting for dag run to complete execution in allowed state.",
1294 dag_id=drte.trigger_dag_id,
1295 run_id=drte.dag_run_id,
1296 allowed_state=drte.allowed_states,
1297 )
1298 time.sleep(drte.poke_interval)
1299
1300 comms_msg = SUPERVISOR_COMMS.send(
1301 GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
1302 )
1303 if TYPE_CHECKING:
1304 assert isinstance(comms_msg, DagRunStateResult)
1305 if comms_msg.state in drte.failed_states:
1306 log.error(
1307 "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
1308 )
1309 msg = TaskState(
1310 state=TaskInstanceState.FAILED,
1311 end_date=datetime.now(tz=timezone.utc),
1312 rendered_map_index=ti.rendered_map_index,
1313 )
1314 state = TaskInstanceState.FAILED
1315 return msg, state
1316 if comms_msg.state in drte.allowed_states:
1317 log.info(
1318 "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
1319 )
1320 break
1321 log.debug(
1322 "DagRun not yet in allowed or failed state.",
1323 dag_id=drte.trigger_dag_id,
1324 state=comms_msg.state,
1325 )
1326
1327 return _handle_current_task_success(context, ti)
1328
1329
1330def _run_task_state_change_callbacks(
1331 task: BaseOperator,
1332 kind: Literal[
1333 "on_execute_callback",
1334 "on_failure_callback",
1335 "on_success_callback",
1336 "on_retry_callback",
1337 "on_skipped_callback",
1338 ],
1339 context: Context,
1340 log: Logger,
1341) -> None:
1342 callback: Callable[[Context], None]
1343 for i, callback in enumerate(getattr(task, kind)):
1344 try:
1345 create_executable_runner(callback, context_get_outlet_events(context), logger=log).run(context)
1346 except Exception:
1347 log.exception("Failed to run task callback", kind=kind, index=i, callback=callback)
1348
1349
1350def _send_error_email_notification(
1351 task: BaseOperator | MappedOperator,
1352 ti: RuntimeTaskInstance,
1353 context: Context,
1354 error: BaseException | str | None,
1355 log: Logger,
1356) -> None:
1357 """Send email notification for task errors using SmtpNotifier."""
1358 try:
1359 from airflow.providers.smtp.notifications.smtp import SmtpNotifier
1360 except ImportError:
1361 log.error(
1362 "Failed to send task failure or retry email notification: "
1363 "`apache-airflow-providers-smtp` is not installed. "
1364 "Install this provider to enable email notifications."
1365 )
1366 return
1367
1368 if not task.email:
1369 return
1370
1371 subject_template_file = conf.get("email", "subject_template", fallback=None)
1372
1373 # Read the template file if configured
1374 if subject_template_file and Path(subject_template_file).exists():
1375 subject = Path(subject_template_file).read_text()
1376 else:
1377 # Fallback to default
1378 subject = "Airflow alert: {{ti}}"
1379
1380 html_content_template_file = conf.get("email", "html_content_template", fallback=None)
1381
1382 # Read the template file if configured
1383 if html_content_template_file and Path(html_content_template_file).exists():
1384 html_content = Path(html_content_template_file).read_text()
1385 else:
1386 # Fallback to default
1387 # For reporting purposes, we report based on 1-indexed,
1388 # not 0-indexed lists (i.e. Try 1 instead of Try 0 for the first attempt).
1389 html_content = (
1390 "Try {{try_number}} out of {{max_tries + 1}}<br>"
1391 "Exception:<br>{{exception_html}}<br>"
1392 'Log: <a href="{{ti.log_url}}">Link</a><br>'
1393 "Host: {{ti.hostname}}<br>"
1394 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
1395 )
1396
1397 # Add exception_html to context for template rendering
1398 import html
1399
1400 exception_html = html.escape(str(error)).replace("\n", "<br>")
1401 additional_context = {
1402 "exception": error,
1403 "exception_html": exception_html,
1404 "try_number": ti.try_number,
1405 "max_tries": ti.max_tries,
1406 }
1407 email_context = {**context, **additional_context}
1408 to_emails = task.email
1409 if not to_emails:
1410 return
1411
1412 try:
1413 notifier = SmtpNotifier(
1414 to=to_emails,
1415 subject=subject,
1416 html_content=html_content,
1417 from_email=conf.get("email", "from_email", fallback="airflow@airflow"),
1418 )
1419 notifier(email_context)
1420 except Exception:
1421 log.exception("Failed to send email notification")
1422
1423
1424def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
    """Execute the task, optionally within a timeout."""
1426 task = ti.task
1427 execute = task.execute
1428
1429 if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method):
1430 from airflow.sdk.serde import deserialize
1431
1432 next_kwargs_data = ti._ti_context_from_server.next_kwargs or {}
1433 try:
1434 if TYPE_CHECKING:
1435 assert isinstance(next_kwargs_data, dict)
1436 kwargs = deserialize(next_kwargs_data)
1437 except (ImportError, KeyError, AttributeError, TypeError):
1438 from airflow.serialization.serialized_objects import BaseSerialization
1439
1440 kwargs = BaseSerialization.deserialize(next_kwargs_data)
1441
1442 if TYPE_CHECKING:
1443 assert isinstance(kwargs, dict)
1444 execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)
1445
1446 ctx = contextvars.copy_context()
1447 # Populate the context var so ExecutorSafeguard doesn't complain
1448 ctx.run(ExecutorSafeguard.tracker.set, task)
1449
    # Export context to os.environ to make it available for operators to use.
1451 airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
1452 os.environ.update(airflow_context_vars)
1453
1454 outlet_events = context_get_outlet_events(context)
1455
1456 if (pre_execute_hook := task._pre_execute_hook) is not None:
1457 create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
1458 if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute:
1459 create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
1460
1461 _run_task_state_change_callbacks(task, "on_execute_callback", context, log)
1462
1463 if task.execution_timeout:
1464 from airflow.sdk.execution_time.timeout import timeout
1465
1466 # TODO: handle timeout in case of deferral
1467 timeout_seconds = task.execution_timeout.total_seconds()
1468 try:
1469 # It's possible we're already timed out, so fast-fail if true
1470 if timeout_seconds <= 0:
1471 raise AirflowTaskTimeout()
1472 # Run task in timeout wrapper
1473 with timeout(timeout_seconds):
1474 result = ctx.run(execute, context=context)
1475 except AirflowTaskTimeout:
1476 task.on_kill()
1477 raise
1478 else:
1479 result = ctx.run(execute, context=context)
1480
1481 if (post_execute_hook := task._post_execute_hook) is not None:
1482 create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
1483 if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute:
1484 create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context)
1485
1486 return result
1487
1488
1489def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None:
1490 """Render named map index if the Dag author defined map_index_template at the task level."""
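    # Illustrative: with map_index_template='{{ task.op_kwargs["filename"] }}',
    # each mapped task instance gets a readable label instead of a bare integer index.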
1491 if (template := context.get("map_index_template")) is None:
1492 return None
1493 log.debug("Rendering map_index_template", template_length=len(template))
1494 jinja_env = ti.task.dag.get_template_env()
1495 rendered_map_index = jinja_env.from_string(template).render(context)
1496 log.debug("Map index rendered", length=len(rendered_map_index))
1497 return rendered_map_index
1498
1499
1500def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
    """Push XCom values when the task has ``do_xcom_push`` set to ``True`` and returns a result."""
1502 if ti.task.do_xcom_push:
1503 xcom_value = result
1504 else:
1505 xcom_value = None
1506
1507 has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None
1508 if xcom_value is None:
1509 if not ti.is_mapped and has_mapped_dep:
1510 # Uhoh, a downstream mapped task depends on us to push something to map over
1511 from airflow.sdk.exceptions import XComForMappingNotPushed
1512
1513 raise XComForMappingNotPushed()
1514 return
1515
1516 mapped_length: int | None = None
1517 if not ti.is_mapped and has_mapped_dep:
1518 from airflow.sdk.definitions.mappedoperator import is_mappable_value
1519 from airflow.sdk.exceptions import UnmappableXComTypePushed
1520
1521 if not is_mappable_value(xcom_value):
1522 raise UnmappableXComTypePushed(xcom_value)
1523 mapped_length = len(xcom_value)
1524
1525 log.info("Pushing xcom", ti=ti)
1526
1527 # If the task has multiple outputs, push each output as a separate XCom.
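    # Illustrative: a @task(multiple_outputs=True) returning {"a": 1, "b": 2}
    # pushes XComs "a" and "b" in addition to the usual return_value below.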
1528 if ti.task.multiple_outputs:
1529 if not isinstance(xcom_value, Mapping):
1530 raise TypeError(
1531 f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
1532 )
1533 for key in xcom_value.keys():
1534 if not isinstance(key, str):
1535 raise TypeError(
1536 "Returned dictionary keys must be strings when using "
1537 f"multiple_outputs, found {key} ({type(key)}) instead"
1538 )
1539 for k, v in result.items():
1540 ti.xcom_push(k, v)
1541
1542 _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length)
1543
1544
1545def finalize(
1546 ti: RuntimeTaskInstance,
1547 state: TaskInstanceState,
1548 context: Context,
1549 log: Logger,
1550 error: BaseException | None = None,
1551):
1552 # Record task duration metrics for all terminal states
1553 if ti.start_date and ti.end_date:
1554 duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000
1555 stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
1556
1557 Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms)
1558 Stats.timing("task.duration", duration_ms, tags=stats_tags)
1559
1560 task = ti.task
    # Push an XCom for each operator extra link defined on the operator.
1562 for oe in task.operator_extra_links:
1563 try:
1564 link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key # type: ignore[arg-type]
1565 log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
1566 _xcom_push_to_db(ti, key=xcom_key, value=link)
1567 except Exception:
1568 log.exception(
1569 "Failed to push an xcom for task operator extra link",
1570 link_name=oe.name,
1571 xcom_key=oe.xcom_key,
1572 ti=ti,
1573 )
1574
1575 if getattr(ti.task, "overwrite_rtif_after_execution", False):
1576 log.debug("Overwriting Rendered template fields.")
1577 if ti.task.template_fields:
1578 SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)))
1579
1580 log.debug("Running finalizers", ti=ti)
1581 if state == TaskInstanceState.SUCCESS:
1582 _run_task_state_change_callbacks(task, "on_success_callback", context, log)
1583 try:
1584 get_listener_manager().hook.on_task_instance_success(
1585 previous_state=TaskInstanceState.RUNNING, task_instance=ti
1586 )
1587 except Exception:
1588 log.exception("error calling listener")
1589 elif state == TaskInstanceState.SKIPPED:
1590 _run_task_state_change_callbacks(task, "on_skipped_callback", context, log)
1591 elif state == TaskInstanceState.UP_FOR_RETRY:
1592 _run_task_state_change_callbacks(task, "on_retry_callback", context, log)
1593 try:
1594 get_listener_manager().hook.on_task_instance_failed(
1595 previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
1596 )
1597 except Exception:
1598 log.exception("error calling listener")
1599 if error and task.email_on_retry and task.email:
1600 _send_error_email_notification(task, ti, context, error, log)
1601 elif state == TaskInstanceState.FAILED:
1602 _run_task_state_change_callbacks(task, "on_failure_callback", context, log)
1603 try:
1604 get_listener_manager().hook.on_task_instance_failed(
1605 previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
1606 )
1607 except Exception:
1608 log.exception("error calling listener")
1609 if error and task.email_on_failure and task.email:
1610 _send_error_email_notification(task, ti, context, error, log)
1611
1612 try:
1613 get_listener_manager().hook.before_stopping(component=TaskRunnerMarker())
1614 except Exception:
1615 log.exception("error calling listener")
1616
1617
1618def main():
1619 log = structlog.get_logger(logger_name="task")
1620
1621 global SUPERVISOR_COMMS
1622 SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)
1623
1624 try:
1625 ti, context, log = startup()
1626 with BundleVersionLock(
1627 bundle_name=ti.bundle_instance.name,
1628 bundle_version=ti.bundle_instance.version,
1629 ):
1630 state, _, error = run(ti, context, log)
1631 context["exception"] = error
1632 finalize(ti, state, context, log, error)
1633 except KeyboardInterrupt:
1634 log.exception("Ctrl-c hit")
1635 exit(2)
1636 except Exception:
1637 log.exception("Top level error")
1638 exit(1)
1639 finally:
1640 # Ensure the request socket is closed on the child side in all circumstances
1641 # before the process fully terminates.
1642 if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket:
1643 with suppress(Exception):
1644 SUPERVISOR_COMMS.socket.close()
1645
1646
1647def reinit_supervisor_comms() -> None:
1648 """
1649 Re-initialize supervisor comms and logging channel in subprocess.
1650
    This is not needed in most cases, but is used when we re-launch the process via sudo for
    run_as_user, or from inside the Python code of a virtualenv (and similar) operator, to re-connect
    so those tasks can continue to access Variables etc.
1654 """
1655 import socket
1656
1657 if "SUPERVISOR_COMMS" not in globals():
1658 global SUPERVISOR_COMMS
1659 log = structlog.get_logger(logger_name="task")
1660
1661 fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))
1662
1663 SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd))
1664
1665 logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
1666 if isinstance(logs, SentFDs):
1667 from airflow.sdk.log import configure_logging
1668
1669 log_io = os.fdopen(logs.fds[0], "wb", buffering=0)
1670 configure_logging(json_output=True, output=log_io, sending_to_supervisor=True)
1671 else:
1672 print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)
1673
1674
1675if __name__ == "__main__":
1676 main()