1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""The entrypoint for the actual task execution process."""
19
20from __future__ import annotations
21
22import contextlib
23import contextvars
24import functools
25import os
26import sys
27import time
28from collections.abc import Callable, Iterable, Iterator, Mapping
29from contextlib import suppress
30from datetime import datetime, timezone
31from itertools import product
32from pathlib import Path
33from typing import TYPE_CHECKING, Annotated, Any, Literal
34from urllib.parse import quote
35
36import attrs
37import lazy_object_proxy
38import structlog
39from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter
40
41from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
42from airflow.dag_processing.bundles.manager import DagBundlesManager
43from airflow.listeners.listener import get_listener_manager
44from airflow.sdk.api.client import get_hostname, getuser
45from airflow.sdk.api.datamodels._generated import (
46 AssetProfile,
47 DagRun,
48 TaskInstance,
49 TaskInstanceState,
50 TIRunContext,
51)
52from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard
53from airflow.sdk.bases.xcom import BaseXCom
54from airflow.sdk.configuration import conf
55from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
56from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
57from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
58from airflow.sdk.definitions.mappedoperator import MappedOperator
59from airflow.sdk.definitions.param import process_params
60from airflow.sdk.exceptions import (
61 AirflowException,
62 AirflowInactiveAssetInInletOrOutletException,
63 AirflowRuntimeError,
64 AirflowTaskTimeout,
65 ErrorType,
66 TaskDeferred,
67)
68from airflow.sdk.execution_time.callback_runner import create_executable_runner
69from airflow.sdk.execution_time.comms import (
70 AssetEventDagRunReferenceResult,
71 CommsDecoder,
72 DagRunStateResult,
73 DeferTask,
74 DRCount,
75 ErrorResponse,
76 GetDagRunState,
77 GetDRCount,
78 GetPreviousDagRun,
79 GetTaskBreadcrumbs,
80 GetTaskRescheduleStartDate,
81 GetTaskStates,
82 GetTICount,
83 InactiveAssetsResult,
84 PreviousDagRunResult,
85 RescheduleTask,
86 ResendLoggingFD,
87 RetryTask,
88 SentFDs,
89 SetRenderedFields,
90 SetRenderedMapIndex,
91 SkipDownstreamTasks,
92 StartupDetails,
93 SucceedTask,
94 TaskBreadcrumbsResult,
95 TaskRescheduleStartDate,
96 TaskState,
97 TaskStatesResult,
98 TICount,
99 ToSupervisor,
100 ToTask,
101 TriggerDagRun,
102 ValidateInletsAndOutlets,
103)
104from airflow.sdk.execution_time.context import (
105 ConnectionAccessor,
106 InletEventsAccessors,
107 MacrosAccessor,
108 OutletEventAccessors,
109 TriggeringAssetEventsAccessor,
110 VariableAccessor,
111 context_get_outlet_events,
112 context_to_airflow_vars,
113 get_previous_dagrun_success,
114 set_current_context,
115)
116from airflow.sdk.execution_time.sentry import Sentry
117from airflow.sdk.execution_time.xcom import XCom
118from airflow.sdk.observability.stats import Stats
119from airflow.sdk.timezone import coerce_datetime
120
121if TYPE_CHECKING:
122 import jinja2
123 from pendulum.datetime import DateTime
124 from structlog.typing import FilteringBoundLogger as Logger
125
126 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
127 from airflow.sdk.definitions.context import Context
128 from airflow.sdk.exceptions import DagRunTriggerException
129 from airflow.sdk.types import OutletEventAccessorsProtocol
130
131
class TaskRunnerMarker(object):
    """
    Sentinel passed as the ``component`` argument to listener hooks.

    Listener implementations can check for an instance of this class to tell that the hook
    was fired from the task runner rather than from another Airflow component.
    """
134
135
136# TODO: Move this entire class into a separate file:
137# `airflow/sdk/execution_time/task_instance.py`
138# or `airflow/sdk/execution_time/runtime_ti.py`
class RuntimeTaskInstance(TaskInstance):
    """
    Task instance as used inside the task-runner process.

    Extends the generated ``TaskInstance`` model with the parsed operator (``task``), the Dag
    bundle it was loaded from, and the run context sent by the API server. It also provides
    the runtime helpers operator code relies on:

    * the template context;
    * XCom pull/push;
    * queries that are forwarded to the supervisor over ``SUPERVISOR_COMMS``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    task: BaseOperator
    bundle_instance: BaseDagBundle
    _cached_template_context: Context | None = None
    """The Task Instance context. This is used to cache get_template_context."""

    _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
    """The Task Instance context from the API server, if any."""

    max_tries: int = 0
    """The maximum number of retries for the task."""

    start_date: AwareDatetime
    """Start date of the task instance."""

    end_date: AwareDatetime | None = None

    state: TaskInstanceState | None = None

    is_mapped: bool | None = None
    """True if the original task was mapped."""

    rendered_map_index: str | None = None

    sentry_integration: str = ""

    def __rich_repr__(self):
        yield "id", self.id
        yield "task_id", self.task_id
        yield "dag_id", self.dag_id
        yield "run_id", self.run_id
        yield "max_tries", self.max_tries
        yield "task", type(self.task)
        yield "start_date", self.start_date

    __rich_repr__.angular = True  # type: ignore[attr-defined]

    def get_template_context(self) -> Context:
        """
        Return the Jinja template context for this task instance.

        The base dict (Dag, task, inlets/outlets, accessors, validated params, ...) is built
        on the first call. It is cached on ``_cached_template_context``, so every caller
        shares, and may mutate, the same dict.

        When a run context from the API server is present, the run-dependent keys are
        (re)applied to that dict on every call:

        * ``dag_run`` and ``task_instance_key_str``;
        * the lazy ``prev_*_success`` proxies;
        * if the run has a logical date, also ``ds``/``ts``/``data_interval_*``.

        ``integrate_macros_plugins()`` and ``process_params(...)`` also run on every call.
        ``process_params`` is called with ``suppress_exception=False``, so any exception it
        raises propagates to the caller.
        """
        # TODO: Move this to `airflow.sdk.execution_time.context`
        #   once we port the entire context logic from airflow/utils/context.py ?
        from airflow.plugins_manager import integrate_macros_plugins

        integrate_macros_plugins()

        dag_run_conf: dict[str, Any] | None = None
        if from_server := self._ti_context_from_server:
            dag_run_conf = from_server.dag_run.conf or dag_run_conf

        validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)

        # Cache the context object, which ensures that all calls to get_template_context
        # are operating on the same context object.
        self._cached_template_context: Context = self._cached_template_context or {
            # From the Task Execution interface
            "dag": self.task.dag,
            "inlets": self.task.inlets,
            "map_index_template": self.task.map_index_template,
            "outlets": self.task.outlets,
            "run_id": self.run_id,
            "task": self.task,
            "task_instance": self,
            "ti": self,
            "outlet_events": OutletEventAccessors(),
            "inlet_events": InletEventsAccessors(self.task.inlets),
            "macros": MacrosAccessor(),
            "params": validated_params,
            # TODO: Make this go through Public API longer term.
            # "test_mode": task_instance.test_mode,
            "var": {
                "json": VariableAccessor(deserialize_json=True),
                "value": VariableAccessor(deserialize_json=False),
            },
            "conn": ConnectionAccessor(),
        }
        if from_server:
            dag_run = from_server.dag_run
            context_from_server: Context = {
                # TODO: Assess if we need to pass these through timezone.coerce_datetime
                "dag_run": dag_run,  # type: ignore[typeddict-item]  # Removable after #46522
                "triggering_asset_events": TriggeringAssetEventsAccessor.build(
                    AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
                    for event in dag_run.consumed_asset_events
                ),
                "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
                "task_reschedule_count": from_server.task_reschedule_count or 0,
                # Lazy proxies: the supervisor is only queried if a template actually uses them.
                "prev_start_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date)
                ),
                "prev_end_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
                ),
            }
            self._cached_template_context.update(context_from_server)

            if logical_date := coerce_datetime(dag_run.logical_date):
                if TYPE_CHECKING:
                    assert isinstance(logical_date, DateTime)
                ds = logical_date.strftime("%Y-%m-%d")
                ds_nodash = ds.replace("-", "")
                ts = logical_date.isoformat()
                ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
                ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
                # logical_date and data_interval either coexist or be None together
                self._cached_template_context.update(
                    {
                        # keys that depend on logical_date
                        "logical_date": logical_date,
                        "ds": ds,
                        "ds_nodash": ds_nodash,
                        # Overrides the run_id-based key set above when a logical date exists.
                        "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}",
                        "ts": ts,
                        "ts_nodash": ts_nodash,
                        "ts_nodash_with_tz": ts_nodash_with_tz,
                        # keys that depend on data_interval
                        "data_interval_end": coerce_datetime(dag_run.data_interval_end),
                        "data_interval_start": coerce_datetime(dag_run.data_interval_start),
                        "prev_data_interval_start_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start)
                        ),
                        "prev_data_interval_end_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end)
                        ),
                    }
                )

            if from_server.upstream_map_indexes is not None:
                # We stash this in here for later use, but we purposefully don't want to document its
                # existence. Should this be a private attribute on RuntimeTI instead perhaps?
                setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)

        return self._cached_template_context

    def render_templates(
        self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
    ) -> BaseOperator:
        """
        Render templates in the operator fields.

        If the task was originally mapped, this may replace ``self.task`` with
        the unmapped, fully rendered BaseOperator. The original ``self.task``
        before replacement is returned.

        :param context: Template context to render with. Falls back to
            ``get_template_context()`` when not given (or empty).
        :param jinja_env: Optional Jinja environment forwarded to ``render_template_fields``.
        """
        if not context:
            context = self.get_template_context()
        original_task = self.task

        if TYPE_CHECKING:
            assert context

        ti = context["ti"]

        if TYPE_CHECKING:
            assert original_task
            assert self.task
            assert ti.task

        # If self.task is mapped, this call replaces self.task to point to the
        # unmapped BaseOperator created by this function! This is because the
        # MappedOperator is useless for template rendering, and we need to be
        # able to access the unmapped task instead.
        self.task.render_template_fields(context, jinja_env)
        self.is_mapped = original_task.is_mapped
        return original_task

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = BaseXCom.XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
        default: Any = None,
        run_id: str | None = None,
    ) -> Any:
        """
        Pull XComs either from the API server (BaseXCom) or from the custom XCOM backend if configured.

        The pull can optionally be filtered by certain criteria.

        :param task_ids: Only XComs from tasks with matching ids will be pulled.
            If *None* (default), the task_id of the calling task is used.
        :param dag_id: If provided, only pulls XComs from this Dag. If *None*
            (default), the Dag of the calling task is used.
        :param key: Only XComs with a matching key are returned. The default key is
            ``'return_value'``, also available as constant ``XCOM_RETURN_KEY``. This key is
            automatically given to XComs returned by tasks (as opposed to being pushed manually).
        :param include_prior_dates: If False, only XComs from the current
            logical_date are returned. If *True*, XComs from previous dates
            are returned as well.
        :param map_indexes: Map index (or iterable of map indexes) to pull. If not
            provided, all XComs for each requested task are pulled regardless of map index.
            Passing *None* pulls the XCom as it would from a non-mapped task.
        :param default: Value used in place of an XCom that could not be found.
        :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id.
            If *None* (default), the run_id of the calling task is used.

        When ``map_indexes`` is not provided, all values found for each task are collected
        (``default`` is used for a task with no values). If a single task was requested
        (``task_ids`` is *None* or a str) and exactly one value results, that value is
        returned directly; otherwise a list is returned.

        When ``map_indexes`` is provided, one entry is produced per combination of task id and
        map index, in the order of ``task_ids`` then ``map_indexes``. ``default`` is used for
        combinations with no XCom. If both a single task and a single map index (int or
        *None*) were requested, that one entry is returned instead of a list.
        """
        if dag_id is None:
            dag_id = self.dag_id
        if run_id is None:
            run_id = self.run_id

        single_task_requested = isinstance(task_ids, (str, type(None)))
        single_map_index_requested = isinstance(map_indexes, (int, type(None)))

        if task_ids is None:
            # default to the current task if not provided
            task_ids = [self.task_id]
        elif isinstance(task_ids, str):
            task_ids = [task_ids]

        # If map_indexes is not specified, pull xcoms from all map indexes for each task
        if not is_arg_set(map_indexes):
            xcoms: list[Any] = []
            for t_id in task_ids:
                values = XCom.get_all(
                    run_id=run_id,
                    key=key,
                    task_id=t_id,
                    dag_id=dag_id,
                    include_prior_dates=include_prior_dates,
                )

                if not values:
                    # Nothing found for this task: honour ``default`` exactly like the
                    # explicit-map_indexes branch below does (previously ``None`` was used).
                    xcoms.append(default)
                else:
                    xcoms.extend(values)
            # For single task pulling from unmapped task, return single value
            if single_task_requested and len(xcoms) == 1:
                return xcoms[0]
            return xcoms

        # Original logic when map_indexes is explicitly specified
        map_indexes_iterable: Iterable[int | None] = []
        if isinstance(map_indexes, int) or map_indexes is None:
            map_indexes_iterable = [map_indexes]
        elif isinstance(map_indexes, Iterable):
            map_indexes_iterable = map_indexes
        else:
            raise TypeError(
                f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}"
            )

        xcoms = []
        for t_id, m_idx in product(task_ids, map_indexes_iterable):
            value = XCom.get_one(
                run_id=run_id,
                key=key,
                task_id=t_id,
                dag_id=dag_id,
                map_index=m_idx,
                include_prior_dates=include_prior_dates,
            )
            if value is None:
                xcoms.append(default)
            else:
                xcoms.append(value)

        if single_task_requested and single_map_index_requested:
            return xcoms[0]

        return xcoms

    def xcom_push(self, key: str, value: Any):
        """
        Make an XCom available for tasks to pull.

        :param key: Key to store the value under.
        :param value: Value to store. Only JSON-serializable values may be used.
        """
        _xcom_push(self, key, value)

    def get_relevant_upstream_map_indexes(
        self, upstream: BaseOperator, ti_count: int | None, session: Any
    ) -> int | range | None:
        """Not implemented yet in the Task SDK; always returns ``None``."""
        # TODO: Implement this method
        return None

    def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
        """
        Get the first reschedule date for the task instance if found, none otherwise.

        Skips the supervisor round-trip when ``task_reschedule_count`` in ``context`` is 0.
        Otherwise it asks for the reschedule start date of the first try, computed as
        ``max_tries - task.retries + 1``.
        """
        if context.get("task_reschedule_count", 0) == 0:
            # If the task has not been rescheduled, there is no need to ask the supervisor
            return None

        max_tries: int = self.max_tries
        retries: int = self.task.retries or 0
        first_try_number = max_tries - retries + 1

        log = structlog.get_logger(logger_name="task")

        log.debug("Requesting first reschedule date from supervisor")

        response = SUPERVISOR_COMMS.send(
            msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskRescheduleStartDate)

        return response.start_date

    def get_previous_dagrun(self, state: str | None = None) -> DagRun | None:
        """
        Return the previous Dag run before this run's logical date, optionally filtered by state.

        Returns ``None`` without querying the supervisor when the template context has no
        ``dag_run`` or the run has no logical date.
        """
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")

        log.debug("Getting previous Dag run", dag_run=dag_run)

        if dag_run is None:
            return None

        if dag_run.logical_date is None:
            return None

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state)
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousDagRunResult)

        return response.dag_run

    @staticmethod
    def get_ti_count(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of task instances matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTICount(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TICount)

        return response.count

    @staticmethod
    def get_task_states(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> dict[str, Any]:
        """Return the task states matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTaskStates(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskStatesResult)

        return response.task_states

    @staticmethod
    def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]:
        """Return task breadcrumbs for the given dag run."""
        response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id))
        if TYPE_CHECKING:
            assert isinstance(response, TaskBreadcrumbsResult)
        return response.breadcrumbs

    @staticmethod
    def get_dr_count(
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of Dag runs matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetDRCount(
                dag_id=dag_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, DRCount)

        return response.count

    @staticmethod
    def get_dagrun_state(dag_id: str, run_id: str) -> str:
        """Return the state of the Dag run with the given Run ID."""
        response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id))

        if TYPE_CHECKING:
            assert isinstance(response, DagRunStateResult)

        return response.state

    @property
    def log_url(self) -> str:
        """
        UI URL for this task instance, built from ``[api] base_url``.

        The run id is URL-quoted. ``/mapped/<index>`` is appended only for a non-negative
        map index, and ``?try_number=<n>`` only for a positive try number.
        """
        run_id = quote(self.run_id)
        base_url = conf.get("api", "base_url", fallback="http://localhost:8080/")
        map_index_value = self.map_index
        map_index = (
            f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else ""
        )
        try_number_value = self.try_number
        try_number = (
            f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else ""
        )
        _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}"
        return _log_uri

    @property
    def mark_success_url(self) -> str:
        """URL to mark TI success (currently identical to ``log_url``)."""
        return self.log_url
597
598
def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
    """
    Store ``value`` under ``key`` for ``ti`` through ``XCom.set``.

    ``XCom.set`` sends the write to the XCom backend if one is configured. This helper is kept
    private on purpose: it lets internal callers pass ``mapped_length`` (forwarded as
    ``_mapped_length``), which SDK consumers are not meant to set themselves.
    """
    ti_identity = {
        "dag_id": ti.dag_id,
        "task_id": ti.task_id,
        "run_id": ti.run_id,
        "map_index": ti.map_index,
    }
    XCom.set(key=key, value=value, _mapped_length=mapped_length, **ti_identity)
613
614
def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None:
    """
    Write an XCom straight into the metadata DB via ``XCom._set_xcom_in_db``.

    Unlike ``_xcom_push``, this bypasses any custom XCom backend.
    """
    ti_identity = {
        "dag_id": ti.dag_id,
        "task_id": ti.task_id,
        "run_id": ti.run_id,
        "map_index": ti.map_index,
    }
    XCom._set_xcom_in_db(key=key, value=value, **ti_identity)
625
626
def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
    """
    Load the Dag bundle from the startup message, parse the Dag file and build the runtime TI.

    The bundle root is appended to ``sys.path`` (if missing) before parsing. If the Dag or the
    task cannot be found, an error is logged and the process exits with status 1.

    :raises TypeError: if the task found is neither a ``BaseOperator`` nor a ``MappedOperator``.
    """
    # TODO: Task-SDK:
    # Using DagBag here is about 98% wrong, but it'll do for now
    from airflow.dag_processing.dagbag import DagBag

    info = what.bundle_info
    bundle = DagBundlesManager().get_bundle(
        name=info.name,
        version=info.version,
    )
    bundle.initialize()

    # Make util modules shipped inside the bundle importable, so files in the same bundle can
    # share code.
    root_dir = os.fspath(bundle.path)
    if root_dir not in sys.path:
        sys.path.append(root_dir)

    dagbag = DagBag(
        dag_folder=os.fspath(Path(bundle.path, what.dag_rel_path)),
        include_examples=False,
        safe_mode=False,
        load_op_links=False,
        bundle_name=info.name,
    )
    if TYPE_CHECKING:
        assert what.ti.dag_id

    wanted_dag_id = what.ti.dag_id
    try:
        dag = dagbag.dags[wanted_dag_id]
    except KeyError:
        log.error("Dag not found during start up", dag_id=wanted_dag_id, bundle=info, path=what.dag_rel_path)
        sys.exit(1)

    wanted_task_id = what.ti.task_id
    try:
        task = dag.task_dict[wanted_task_id]
    except KeyError:
        log.error(
            "Task not found in Dag during start up",
            dag_id=dag.dag_id,
            task_id=wanted_task_id,
            bundle=info,
            path=what.dag_rel_path,
        )
        sys.exit(1)

    if not isinstance(task, (BaseOperator, MappedOperator)):
        raise TypeError(
            f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
        )

    return RuntimeTaskInstance.model_construct(
        **what.ti.model_dump(exclude_unset=True),
        task=task,
        bundle_instance=bundle,
        _ti_context_from_server=what.ti_context,
        max_tries=what.ti_context.max_tries,
        start_date=what.start_date,
        state=TaskInstanceState.RUNNING,
        sentry_integration=what.sentry_integration,
    )
692
693
# Module-global channel to the supervisor process. Connection/Variable/XCom classes and other
# parts of the task's execution use it to send requests back to the supervisor.
#
# Why it needs to be a global:
# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may make dynamic
#   requests to the parent process during task execution.
# - These calls occur in many places, and the `CommsDecoder` instance cannot easily be passed
#   down through the deeply nested execution stack.
# - As a global, the channel is reachable wherever it is needed during task execution without
#   threading it through every layer of the call stack.
#
# NOTE: this line only declares the type; it does not bind the name. It must be assigned
# elsewhere (not visible in this section) before first use, or accessing it raises NameError.
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
705
706
707# State machine!
708# 1. Start up (receive details from supervisor)
709# 2. Execution (run task code, possibly send requests)
710# 3. Shutdown and report status
711
712
def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
    """
    Receive the startup details, parse the Dag and build the task instance and its context.

    There are two entry paths:

    * In a normal start, the ``StartupDetails`` message is read from ``SUPERVISOR_COMMS``.
    * In a process re-executed for impersonation (``_AIRFLOW__REEXECUTED_PROCESS == "1"``),
      the message is deserialized from the ``_AIRFLOW__STARTUP_MSG`` environment variable
      and the supervisor comms are re-initialised.

    If the task must run as a different user (``run_as_user`` or
    ``[core] default_impersonation``), the current process is replaced via ``sudo`` running
    ``task_runner.main()``. The ``(None, None, None)`` return after ``os.execvp`` is
    therefore only a fallback.

    :return: ``(ti, ti.get_template_context(), log)``.
    :raises RuntimeError: if the first message received is not a ``StartupDetails``.
    """
    # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent
    # in response to us sending a request.
    log = structlog.get_logger(logger_name="task")

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and (
        msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG")
    ):
        # Drop the Kerberos credential-cache location (KRB5CCNAME) from the environment so the
        # re-executed process can't reuse the parent's cache.
        os.environ.pop("KRB5CCNAME", None)
        # entrypoint of re-exec process

        msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson)
        reinit_supervisor_comms()

        # We delay this message until _after_ we've got the logging re-configured, otherwise it will show up
        # on stdout
        log.debug("Using serialized startup message from environment", msg=msg)
    else:
        # normal entry point
        msg = SUPERVISOR_COMMS._get_response()  # type: ignore[assignment]

    if not isinstance(msg, StartupDetails):
        raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

    # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
    os_type = sys.platform
    if os_type == "darwin":
        log.debug("Mac OS detected, skipping setproctitle")
    else:
        from setproctitle import setproctitle

        setproctitle(f"airflow worker -- {msg.ti.id}")

    try:
        get_listener_manager().hook.on_starting(component=TaskRunnerMarker())
    except Exception:
        # Listener failures are logged but must not prevent the task from starting.
        log.exception("error calling listener")

    with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
        ti = parse(msg, log)
        log.debug("Dag file parsed", file=msg.dag_rel_path)

    run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
        "core", "default_impersonation", fallback=None
    )

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
        # Not re-executed yet, and the task must run as another user: this (original) process
        # hands over to a re-executed copy of itself running under `run_as_user`.
        os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
        # store startup message in environment for re-exec process
        os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
        # Keep the supervisor socket open across exec so the new process can keep talking to it.
        os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)

        # Import main directly from the module instead of re-executing the file.
        # This ensures that when other parts modules import
        # airflow.sdk.execution_time.task_runner, they get the same module instance
        # with the properly initialized SUPERVISOR_COMMS global variable.
        # If we re-executed the module with `python -m`, it would load as __main__ and future
        # imports would get a fresh copy without the initialized globals.
        rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
        cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
        log.info(
            "Running command",
            command=cmd,
        )
        os.execvp("sudo", cmd)

        # os.execvp replaces the current process and does not return on success (it raises
        # OSError on failure), so this fallback is only reached if execvp is stubbed out.
        return None, None, None

    return ti, ti.get_template_context(), log
785
786
def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]:
    """
    Serialize and redact every templated field of ``task``, keyed by field name.

    Redaction happens here, in the task process, before anything is sent to the API server.
    That way secrets registered via ``mask_secret()`` on workers / the dag processor are
    already masked when the rendered fields are shown in the UI.
    """
    # TODO: Port one of the following to Task SDK
    #   airflow.serialization.helpers.serialize_template_field or
    #   airflow.models.renderedtifields.get_serialized_template_fields
    from airflow.sdk._shared.secrets_masker import redact
    from airflow.serialization.helpers import serialize_template_field

    # mypy can't see that redact() preserves the JsonValue it is given, hence the ignore.
    return {  # type: ignore[return-value]
        name: redact(serialize_template_field(getattr(task, name), name), name)
        for name in task.template_fields
    }
804
805
806def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]:
807 # Lineage can have other types of objects besides assets, so we need to process them a bit.
808 for obj in lineage_objects or ():
809 if isinstance(obj, Asset):
810 yield AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__)
811 elif isinstance(obj, AssetNameRef):
812 yield AssetProfile(name=obj.name, type=AssetNameRef.__name__)
813 elif isinstance(obj, AssetUriRef):
814 yield AssetProfile(uri=obj.uri, type=AssetUriRef.__name__)
815 elif isinstance(obj, AssetAlias):
816 yield AssetProfile(name=obj.name, type=AssetAlias.__name__)
817
818
819def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]:
820 if TYPE_CHECKING:
821 assert isinstance(events, OutletEventAccessors)
822 # We just collect everything the user recorded in the accessors.
823 # Further filtering will be done in the API server.
824 for key, accessor in events._dict.items():
825 if isinstance(key, AssetUniqueKey):
826 yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra}
827 for alias_event in accessor.asset_alias_events:
828 yield attrs.asdict(alias_event)
829
830
def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None:
    """
    Get the task ready to execute.

    Steps, in this order:

    1. Record the hostname.
    2. Replace ``ti.task`` with ``prepare_for_execution()``'s result.
    3. Render template fields.
    4. Send the rendered fields (if any) to the supervisor.
    5. Try an early render of ``map_index_template``.
    6. Validate inlets/outlets.
    7. Fire the ``on_task_instance_running`` listener hook.

    :return: always ``None`` here, meaning "no error, carry on".
    :raises AirflowInactiveAssetInInletOrOutletException: propagated from
        ``_validate_task_inlets_and_outlets`` when the supervisor reports inactive assets.
    """
    ti.hostname = get_hostname()
    ti.task = ti.task.prepare_for_execution()
    # Since context is now cached, and calling `ti.get_template_context` will return the same dict, we want to
    # update the value of the task that is sent from there
    context["task"] = ti.task

    jinja_env = ti.task.dag.get_template_env()
    # NOTE: for a mapped task this may replace `ti.task` with the unmapped operator, which is why
    # `ti.task` is re-read (not cached in a local) in the steps below.
    ti.render_templates(context=context, jinja_env=jinja_env)

    if rendered_fields := _serialize_rendered_fields(ti.task):
        # so that we do not call the API unnecessarily
        SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields))

    # Try to render map_index_template early with available context (will be re-rendered after execution)
    # This provides a partial label during task execution for templates using pre-execution context
    # If rendering fails here, we suppress the error since it will be re-rendered after execution
    try:
        if rendered_map_index := _render_map_index(context, ti=ti, log=log):
            ti.rendered_map_index = rendered_map_index
            log.debug("Sending early rendered map index", length=len(rendered_map_index))
            SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index))
    except Exception:
        log.debug(
            "Early rendering of map_index_template failed, will retry after task execution", exc_info=True
        )

    _validate_task_inlets_and_outlets(ti=ti, log=log)

    try:
        # TODO: Call pre execute etc.
        get_listener_manager().hook.on_task_instance_running(
            previous_state=TaskInstanceState.QUEUED, task_instance=ti
        )
    except Exception:
        # Listener failures are logged but do not stop the task.
        log.exception("error calling listener")

    # No error, carry on and execute the task
    return None
870
871
def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
    """
    Ask the supervisor whether any of the task's inlet/outlet assets are inactive.

    No request is sent when the task declares neither inlets nor outlets.

    :param ti: The task instance whose ``task.inlets`` / ``task.outlets`` are checked.
    :param log: Accepted for signature consistency; not used here.
    :raises AirflowInactiveAssetInInletOrOutletException: If the supervisor reports inactive assets.
    """
    task = ti.task
    if task.inlets or task.outlets:
        response = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id))
        if TYPE_CHECKING:
            assert isinstance(response, InactiveAssetsResult)
        inactive = response.inactive_assets
        if inactive:
            raise AirflowInactiveAssetInInletOrOutletException(
                inactive_asset_keys=[AssetUniqueKey.from_profile(profile) for profile in inactive]
            )
885
886
def _defer_task(
    defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """
    Build the supervisor message for a task that raised ``TaskDeferred``.

    :param defer: The raised exception. It carries the trigger, the timeout and the method/kwargs to
        resume with.
    :param ti: The task instance being deferred (used for log fields).
    :param log: Logger for the "pausing" message.
    :return: A ``DeferTask`` message and ``TaskInstanceState.DEFERRED``.
    """
    # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?
    log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
    trigger_classpath, serialized_trigger_kwargs = defer.trigger.serialize()

    # A missing/empty kwargs value is normalised to an empty dict for the supervisor.
    return (
        DeferTask(
            classpath=trigger_classpath,
            trigger_kwargs=serialized_trigger_kwargs,
            trigger_timeout=defer.timeout,
            next_method=defer.method_name,
            next_kwargs=defer.kwargs or {},
        ),
        TaskInstanceState.DEFERRED,
    )
905
906
@Sentry.enrich_errors
def run(
    ti: RuntimeTaskInstance,
    context: Context,
    log: Logger,
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
    """
    Run the task in this process.

    The function proceeds as follows:

    * Install a SIGTERM handler that calls ``ti.task.on_kill()``.
    * Clear any XCom keys the server asked to be cleared.
    * Prepare the task (template rendering, validation, listeners).
    * Execute the task.
    * Map the outcome to a state and a supervisor message. The outcome is either success or one of the
      handled exception types.

    Any message is sent to the supervisor in the ``finally`` clause, and ``ti.state`` is set to the
    resulting state.

    :param ti: The task instance to run.
    :param context: Template context for the run. It is installed as the current context while the
        task is prepared, executed and its result handled.
    :param log: Logger for task-level messages.
    :return: ``(state, msg, error)`` with these parts:

        * ``state`` is the resulting state.
        * ``msg`` is the message sent to the supervisor, or ``None`` if nothing was sent.
        * ``error`` is the exception the task failed with, or ``None`` for non-failure outcomes
          (success, skip, defer, reschedule, triggered dag run).
    """
    import signal

    from airflow.sdk.exceptions import (
        AirflowFailException,
        AirflowRescheduleException,
        AirflowSensorTimeout,
        AirflowSkipException,
        AirflowTaskTerminated,
        DagRunTriggerException,
        DownstreamTasksSkipped,
        TaskDeferred,
    )

    if TYPE_CHECKING:
        assert ti.task is not None
        assert isinstance(ti.task, BaseOperator)

    # Remember which process installed the handler, so that a process which merely inherited it
    # (e.g. a forked child) does not call on_kill().
    parent_pid = os.getpid()

    def _on_term(signum, frame):
        pid = os.getpid()
        if pid != parent_pid:
            return

        ti.task.on_kill()

    signal.signal(signal.SIGTERM, _on_term)

    msg: ToSupervisor | None = None
    state: TaskInstanceState
    error: BaseException | None = None

    try:
        # First, clear the xcom data sent from server
        if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear):
            for x in keys_to_delete:
                log.debug("Clearing XCom with key", key=x)
                XCom.delete(
                    key=x,
                    dag_id=ti.dag_id,
                    task_id=ti.task_id,
                    run_id=ti.run_id,
                    map_index=ti.map_index,
                )

        with set_current_context(context):
            # This is the earliest that we can render templates -- as if it excepts for any reason we need to
            # catch it and handle it like a normal task failure
            if early_exit := _prepare(ti, log, context):
                msg = early_exit
                ti.state = state = TaskInstanceState.FAILED
                return state, msg, error

            try:
                result = _execute_task(context=context, ti=ti, log=log)
            except Exception:
                import jinja2

                # If the task failed, swallow rendering error so it doesn't mask the main error.
                with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
                    previous_rendered_map_index = ti.rendered_map_index
                    ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                    # Send update only if value changed (e.g., user set context variables during execution)
                    if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                        SUPERVISOR_COMMS.send(
                            msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)
                        )
                # Re-raise so the outer handlers below map the exception to a state.
                raise
            else:  # If the task succeeded, render normally to let rendering error bubble up.
                previous_rendered_map_index = ti.rendered_map_index
                ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                # Send update only if value changed (e.g., user set context variables during execution)
                if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                    SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index))

            _push_xcom_if_needed(result, ti, log)

            msg, state = _handle_current_task_success(context, ti)
    except DownstreamTasksSkipped as skip:
        # Tell the supervisor which downstream tasks to skip; this task itself succeeds.
        log.info("Skipping downstream tasks.")
        tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks]
        SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip))
        msg, state = _handle_current_task_success(context, ti)
    except DagRunTriggerException as drte:
        msg, state = _handle_trigger_dag_run(drte, context, ti, log)
    except TaskDeferred as defer:
        msg, state = _defer_task(defer, ti, log)
    except AirflowSkipException as e:
        if e.args:
            log.info("Skipping task.", reason=e.args[0])
        msg = TaskState(
            state=TaskInstanceState.SKIPPED,
            end_date=datetime.now(tz=timezone.utc),
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.SKIPPED
    except AirflowRescheduleException as reschedule:
        log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
        msg = RescheduleTask(
            reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
        )
        state = TaskInstanceState.UP_FOR_RESCHEDULE
    except (AirflowFailException, AirflowSensorTimeout) as e:
        # If AirflowFailException is raised, task should not retry.
        # If a sensor in reschedule mode reaches timeout, task should not retry.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
        # We should allow retries if the task has defined it.
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    except AirflowTaskTerminated as e:
        # External state updates are already handled with `ti_heartbeat` and will be
        # updated already be another UI API. So, these exceptions should ideally never be thrown.
        # If these are thrown, we should mark the TI state as failed.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except SystemExit as e:
        # SystemExit needs to be retried if they are eligible.
        log.error("Task exited", exit_code=e.code)
        msg, state = _handle_current_task_failed(ti)
        error = e
    except BaseException as e:
        # Anything else (including non-Exception BaseExceptions) is a failure, retried if eligible.
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    finally:
        # Report the outcome to the supervisor on every path, including the early return above.
        if msg:
            SUPERVISOR_COMMS.send(msg=msg)

    # Return the message to make unit tests easier too
    ti.state = state
    return state, msg, error
1062
1063
def _handle_current_task_success(
    context: Context,
    ti: RuntimeTaskInstance,
) -> tuple[SucceedTask, TaskInstanceState]:
    """
    Stamp ``ti.end_date``, emit success metrics, and build the ``SucceedTask`` message.

    The message carries the following:

    * the asset profiles built from the task's outlets;
    * the serialized ``outlet_events`` taken from ``context``;
    * the current rendered map index.

    :param context: Template context for the run; must contain ``"outlet_events"``.
    :param ti: The task instance that succeeded.
    :return: The ``SucceedTask`` message and ``TaskInstanceState.SUCCESS``.
    """
    finished_at = datetime.now(tz=timezone.utc)
    ti.end_date = finished_at

    # Record operator and task instance success metrics
    operator_name = ti.task.__class__.__name__
    common_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
    Stats.incr(f"operator_successes_{operator_name}", tags=common_tags)
    # Same metric with tagging
    Stats.incr("operator_successes", tags={**common_tags, "operator": operator_name})
    Stats.incr("ti_successes", tags=common_tags)

    succeed_msg = SucceedTask(
        end_date=finished_at,
        task_outlets=list(_build_asset_profiles(ti.task.outlets)),
        outlet_events=list(_serialize_outlet_events(context["outlet_events"])),
        rendered_map_index=ti.rendered_map_index,
    )
    return succeed_msg, TaskInstanceState.SUCCESS
1089
1090
def _handle_current_task_failed(
    ti: RuntimeTaskInstance,
) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
    """
    Stamp ``ti.end_date`` and build the supervisor message for a failed attempt.

    :param ti: The task instance whose attempt failed.
    :return: One of two pairs:

        * ``(RetryTask, UP_FOR_RETRY)`` when the server-provided context says the task should retry;
        * ``(TaskState(FAILED), FAILED)`` otherwise.
    """
    finished_at = datetime.now(tz=timezone.utc)
    ti.end_date = finished_at

    server_ctx = ti._ti_context_from_server
    if not (server_ctx and server_ctx.should_retry):
        failed_msg = TaskState(
            state=TaskInstanceState.FAILED, end_date=finished_at, rendered_map_index=ti.rendered_map_index
        )
        return failed_msg, TaskInstanceState.FAILED

    return RetryTask(end_date=finished_at), TaskInstanceState.UP_FOR_RETRY
1104
1105
def _handle_trigger_dag_run(
    drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """
    Handle exception from TriggerDagRunOperator.

    The handler works as follows:

    * Ask the supervisor to trigger the target Dag run.
    * If the run already exists, return a SKIPPED message when ``skip_when_already_exists`` is set,
      otherwise a FAILED message.
    * Otherwise push the run id to XCom under ``trigger_run_id``.
    * Then either defer on a ``DagStateTrigger`` (``deferrable``) or poll until the run reaches an
      allowed or failed state (``wait_for_completion``).
    * If neither option is set, succeed immediately.

    :param drte: The exception carrying the trigger parameters.
    :param context: Template context, needed to build the success message.
    :param ti: The task instance that raised ``drte``.
    :param log: Logger for progress messages.
    :return: The message for the supervisor and the resulting task state.
    """
    log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
    comms_msg = SUPERVISOR_COMMS.send(
        TriggerDagRun(
            dag_id=drte.trigger_dag_id,
            run_id=drte.dag_run_id,
            logical_date=drte.logical_date,
            conf=drte.conf,
            reset_dag_run=drte.reset_dag_run,
        ),
    )

    # NOTE(review): only DAGRUN_ALREADY_EXISTS is handled here. An ErrorResponse with any other error
    # type falls through to the "triggered successfully" path below. Confirm whether
    # SUPERVISOR_COMMS.send raises for other errors.
    if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
        if drte.skip_when_already_exists:
            log.info(
                "Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
                dag_id=drte.trigger_dag_id,
            )
            msg = TaskState(
                state=TaskInstanceState.SKIPPED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.SKIPPED
        else:
            log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
            msg = TaskState(
                state=TaskInstanceState.FAILED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.FAILED

        return msg, state

    log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)

    # Store the run id from the dag run (either created or found above) to
    # be used when creating the extra link on the webserver.
    ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)

    # Deferrable mode takes precedence over wait_for_completion: hand waiting off to a trigger.
    if drte.deferrable:
        from airflow.providers.standard.triggers.external_task import DagStateTrigger

        defer = TaskDeferred(
            trigger=DagStateTrigger(
                dag_id=drte.trigger_dag_id,
                states=drte.allowed_states + drte.failed_states,  # type: ignore[arg-type]
                # Don't filter by execution_dates when run_ids is provided.
                # run_id uniquely identifies a DAG run, and when reset_dag_run=True,
                # drte.logical_date might be a newly calculated value that doesn't match
                # the persisted logical_date in the database, causing the trigger to never find the run.
                execution_dates=None,
                run_ids=[drte.dag_run_id],
                poll_interval=drte.poke_interval,
            ),
            method_name="execute_complete",
        )
        return _defer_task(defer, ti, log)
    if drte.wait_for_completion:
        # Blocking poll: sleep poke_interval before each state check. The loop ends only when the run
        # reaches a failed state (task FAILED) or an allowed state (task SUCCESS).
        while True:
            log.info(
                "Waiting for dag run to complete execution in allowed state.",
                dag_id=drte.trigger_dag_id,
                run_id=drte.dag_run_id,
                allowed_state=drte.allowed_states,
            )
            time.sleep(drte.poke_interval)

            comms_msg = SUPERVISOR_COMMS.send(
                GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
            )
            if TYPE_CHECKING:
                assert isinstance(comms_msg, DagRunStateResult)
            if comms_msg.state in drte.failed_states:
                log.error(
                    "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                msg = TaskState(
                    state=TaskInstanceState.FAILED,
                    end_date=datetime.now(tz=timezone.utc),
                    rendered_map_index=ti.rendered_map_index,
                )
                state = TaskInstanceState.FAILED
                return msg, state
            if comms_msg.state in drte.allowed_states:
                log.info(
                    "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                break
            log.debug(
                "DagRun not yet in allowed or failed state.",
                dag_id=drte.trigger_dag_id,
                state=comms_msg.state,
            )

    return _handle_current_task_success(context, ti)
1206
1207
def _run_task_state_change_callbacks(
    task: BaseOperator,
    kind: Literal[
        "on_execute_callback",
        "on_failure_callback",
        "on_success_callback",
        "on_retry_callback",
        "on_skipped_callback",
    ],
    context: Context,
    log: Logger,
) -> None:
    """
    Invoke every callback stored on ``task`` under the attribute named ``kind``.

    Each callback is wrapped in an executable runner and called with ``context``. If a callback raises,
    the error is logged together with its position in the list, and the remaining callbacks still run.

    :param task: The task holding the callbacks.
    :param kind: Name of the callback attribute to read from ``task``.
    :param context: Template context passed to each callback.
    :param log: Logger used by the runner and for reporting failures.
    """
    for idx, cb in enumerate(getattr(task, kind)):
        try:
            runner = create_executable_runner(cb, context_get_outlet_events(context), logger=log)
            runner.run(context)
        except Exception:
            log.exception("Failed to run task callback", kind=kind, index=idx, callback=cb)
1226
1227
def _send_error_email_notification(
    task: BaseOperator | MappedOperator,
    ti: RuntimeTaskInstance,
    context: Context,
    error: BaseException | str | None,
    log: Logger,
) -> None:
    """
    Send email notification for task errors using SmtpNotifier.

    Nothing is sent, and the SMTP provider is not imported, when ``task.email`` is empty.

    Otherwise the subject and HTML body templates are read from the files named by the
    ``[email] subject_template`` / ``[email] html_content_template`` options when those files exist;
    built-in defaults are used when they don't. The templates are passed to ``SmtpNotifier`` together
    with ``context`` extended by ``exception``, ``exception_html``, ``try_number`` and ``max_tries``.
    Errors raised while building or sending the notification are logged, not raised.

    :param task: The task whose ``email`` attribute lists the recipients.
    :param ti: The task instance; supplies ``try_number`` and ``max_tries`` for the message.
    :param context: Template context for the run.
    :param error: The error to report.
    :param log: Logger for diagnostics.
    """
    to_emails = task.email
    # Check recipients before importing the provider. With nobody to notify there is nothing to send,
    # and a missing SMTP provider must not be reported as an error for tasks that never asked for email.
    if not to_emails:
        return

    try:
        from airflow.providers.smtp.notifications.smtp import SmtpNotifier
    except ImportError:
        log.error(
            "Failed to send task failure or retry email notification: "
            "`apache-airflow-providers-smtp` is not installed. "
            "Install this provider to enable email notifications."
        )
        return

    import html

    # Use the configured subject template file when it exists; otherwise fall back to the default.
    subject_template_file = conf.get("email", "subject_template", fallback=None)
    if subject_template_file and Path(subject_template_file).exists():
        subject = Path(subject_template_file).read_text()
    else:
        subject = "Airflow alert: {{ti}}"

    # Same for the HTML body template.
    html_content_template_file = conf.get("email", "html_content_template", fallback=None)
    if html_content_template_file and Path(html_content_template_file).exists():
        html_content = Path(html_content_template_file).read_text()
    else:
        # For reporting purposes, we report based on 1-indexed,
        # not 0-indexed lists (i.e. Try 1 instead of Try 0 for the first attempt).
        html_content = (
            "Try {{try_number}} out of {{max_tries + 1}}<br>"
            "Exception:<br>{{exception_html}}<br>"
            'Log: <a href="{{ti.log_url}}">Link</a><br>'
            "Host: {{ti.hostname}}<br>"
            'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
        )

    # Escape the error text so it can be embedded safely in the HTML body.
    exception_html = html.escape(str(error)).replace("\n", "<br>")
    email_context = {
        **context,
        "exception": error,
        "exception_html": exception_html,
        "try_number": ti.try_number,
        "max_tries": ti.max_tries,
    }

    try:
        notifier = SmtpNotifier(
            to=to_emails,
            subject=subject,
            html_content=html_content,
            from_email=conf.get("email", "from_email", fallback="airflow@airflow"),
        )
        notifier(email_context)
    except Exception:
        log.exception("Failed to send email notification")
1300
1301
def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
    """
    Execute the task (optionally under its ``execution_timeout``) and return its result.

    Runs the following, in order:

    * the pre-execute hooks;
    * the ``on_execute_callback`` callbacks;
    * ``execute``, or ``resume_execution`` when the server context names a ``next_method``
      (resuming after a deferral);
    * the post-execute hooks, which receive both the context and the result.

    The result is returned to the caller; it is not pushed to XCom here.

    :param context: Template context for the run. It is also exported to ``os.environ``.
    :param ti: The task instance being executed.
    :param log: Logger passed to the hook/callback runners.
    :return: Whatever the task's execute (or resume) method returned.
    :raises AirflowTaskTimeout: If ``execution_timeout`` is exceeded (or is already non-positive);
        ``task.on_kill()`` is called before re-raising.
    """
    task = ti.task
    execute = task.execute

    if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method):
        # Resume via the method recorded by the server, with its serialized kwargs.
        from airflow.serialization.serialized_objects import BaseSerialization

        kwargs = BaseSerialization.deserialize(ti._ti_context_from_server.next_kwargs or {})

        execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)

    ctx = contextvars.copy_context()
    # Populate the context var so ExecutorSafeguard doesn't complain
    ctx.run(ExecutorSafeguard.tracker.set, task)

    # Export context in os.environ to make it available for operators to use.
    airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
    os.environ.update(airflow_context_vars)

    outlet_events = context_get_outlet_events(context)

    if (pre_execute_hook := task._pre_execute_hook) is not None:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
    # Only call pre_execute when it is overridden, i.e. not BaseOperator's own implementation.
    if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)

    _run_task_state_change_callbacks(task, "on_execute_callback", context, log)

    if task.execution_timeout:
        from airflow.sdk.execution_time.timeout import timeout

        # TODO: handle timeout in case of deferral
        timeout_seconds = task.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run task in timeout wrapper
            with timeout(timeout_seconds):
                result = ctx.run(execute, context=context)
        except AirflowTaskTimeout:
            # Let the operator react to the timeout before the error propagates.
            task.on_kill()
            raise
    else:
        result = ctx.run(execute, context=context)

    if (post_execute_hook := task._post_execute_hook) is not None:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
    # Only call post_execute when it is overridden. Pass the result as well, matching the
    # ``post_execute(context, result)`` hook contract and the ``_post_execute_hook`` call above.
    if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)

    return result
1355
1356
def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None:
    """
    Render the ``map_index_template`` found in ``context`` into a display label.

    The template is rendered with the Dag's template environment, using ``context`` as variables.

    :param context: Template context; the template is looked up under ``"map_index_template"``.
    :param ti: Task instance whose ``task.dag`` supplies the template environment.
    :param log: Logger for debug messages.
    :return: The rendered label (possibly an empty string), or ``None`` when ``context`` has no
        ``map_index_template``.
    """
    template = context.get("map_index_template")
    if template is not None:
        log.debug("Rendering map_index_template", template_length=len(template))
        environment = ti.task.dag.get_template_env()
        label = environment.from_string(template).render(context)
        log.debug("Map index rendered", length=len(label))
        return label
    return None
1366
1367
def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
    """
    Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result.

    With ``multiple_outputs``, each key of the returned mapping is also pushed as its own XCom. The
    whole value is always pushed under the return-value key.

    :param result: The value returned by the task's execute method.
    :param ti: The task instance that produced ``result``.
    :param log: Logger for the "Pushing xcom" message.
    :raises XComForMappingNotPushed: If nothing is pushed but this unmapped task has mapped dependants.
    :raises UnmappableXComTypePushed: If mapped dependants exist and the value is not mappable.
    :raises TypeError: If ``multiple_outputs`` is set and the value is not a mapping with ``str`` keys.
    """
    # From here on only the value that would actually be pushed matters, so validate and push the same
    # variable rather than mixing ``result`` and ``xcom_value``.
    xcom_value = result if ti.task.do_xcom_push else None

    has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None
    if xcom_value is None:
        if not ti.is_mapped and has_mapped_dep:
            # Uhoh, a downstream mapped task depends on us to push something to map over
            from airflow.sdk.exceptions import XComForMappingNotPushed

            raise XComForMappingNotPushed()
        return

    mapped_length: int | None = None
    if not ti.is_mapped and has_mapped_dep:
        from airflow.sdk.definitions.mappedoperator import is_mappable_value
        from airflow.sdk.exceptions import UnmappableXComTypePushed

        if not is_mappable_value(xcom_value):
            raise UnmappableXComTypePushed(xcom_value)
        mapped_length = len(xcom_value)

    log.info("Pushing xcom", ti=ti)

    # If the task has multiple outputs, push each output as a separate XCom.
    if ti.task.multiple_outputs:
        if not isinstance(xcom_value, Mapping):
            raise TypeError(
                f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
            )
        # Validate every key before pushing any, so a bad key doesn't leave a partial set of XComs.
        for key in xcom_value:
            if not isinstance(key, str):
                raise TypeError(
                    "Returned dictionary keys must be strings when using "
                    f"multiple_outputs, found {key} ({type(key)}) instead"
                )
        for key, value in xcom_value.items():
            ti.xcom_push(key, value)

    _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, xcom_value, mapped_length=mapped_length)
1411
1412
def finalize(
    ti: RuntimeTaskInstance,
    state: TaskInstanceState,
    context: Context,
    log: Logger,
    error: BaseException | None = None,
):
    """
    Run post-execution bookkeeping for a task instance that ended in ``state``.

    Steps, in order:

    * emit duration metrics, when both start and end dates are set;
    * push an XCom for each operator extra link;
    * optionally re-send the rendered template fields;
    * run the state-specific callbacks, listener hooks and error emails;
    * notify listeners via ``before_stopping``.

    Failures while pushing extra-link XComs or calling listener hooks are logged and not raised.

    :param ti: The task instance that ran.
    :param state: The state the attempt ended in (e.g. as returned by ``run``).
    :param context: Template context, passed to callbacks and email rendering.
    :param log: Logger for diagnostics.
    :param error: The exception the task failed with, if any. It is forwarded to the failure
        listeners and used to decide whether to send retry/failure emails.
    """
    # Record task duration metrics for all terminal states
    if ti.start_date and ti.end_date:
        duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000
        stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

        # Emitted both with the ids embedded in the metric name and as a tagged metric.
        Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms)
        Stats.timing("task.duration", duration_ms, tags=stats_tags)

    task = ti.task
    # Pushing xcom for each operator extra links defined on the operator only.
    for oe in task.operator_extra_links:
        try:
            link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key  # type: ignore[arg-type]
            log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
            _xcom_push_to_db(ti, key=xcom_key, value=link)
        except Exception:
            log.exception(
                "Failed to push an xcom for task operator extra link",
                link_name=oe.name,
                xcom_key=oe.xcom_key,
                ti=ti,
            )

    # Opt-in: re-serialize template fields after execution and send them again, replacing the values
    # sent before execution.
    if getattr(ti.task, "overwrite_rtif_after_execution", False):
        log.debug("Overwriting Rendered template fields.")
        if ti.task.template_fields:
            SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)))

    log.debug("Running finalizers", ti=ti)
    if state == TaskInstanceState.SUCCESS:
        # Success: success callbacks + on_task_instance_success listeners.
        _run_task_state_change_callbacks(task, "on_success_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_success(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.SKIPPED:
        # Skipped: only the skipped callbacks; no listener hook is called here.
        _run_task_state_change_callbacks(task, "on_skipped_callback", context, log)
    elif state == TaskInstanceState.UP_FOR_RETRY:
        # Retry: retry callbacks, on_task_instance_failed listeners, and an email if configured.
        _run_task_state_change_callbacks(task, "on_retry_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_retry and task.email:
            _send_error_email_notification(task, ti, context, error, log)
    elif state == TaskInstanceState.FAILED:
        # Failure: failure callbacks, on_task_instance_failed listeners, and an email if configured.
        _run_task_state_change_callbacks(task, "on_failure_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_failure and task.email:
            _send_error_email_notification(task, ti, context, error, log)

    # Called for every state, after the state-specific handling above.
    try:
        get_listener_manager().hook.before_stopping(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")
1484
1485
def main():
    """
    Entrypoint of the task-runner process.

    Steps:

    * Create the supervisor comms channel.
    * Call ``startup``.
    * Run the task and then ``finalize`` it while holding the bundle version lock.

    Exit codes: 2 on ``KeyboardInterrupt``, 1 on any other uncaught ``Exception``. The supervisor socket
    is closed on every path.
    """
    log = structlog.get_logger(logger_name="task")

    global SUPERVISOR_COMMS
    SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)

    try:
        ti, context, log = startup()
        with BundleVersionLock(
            bundle_name=ti.bundle_instance.name,
            bundle_version=ti.bundle_instance.version,
        ):
            state, _, error = run(ti, context, log)
            context["exception"] = error
            finalize(ti, state, context, log, error)
    except KeyboardInterrupt:
        log.exception("Ctrl-c hit")
        # Use sys.exit: the ``exit`` builtin is added by the ``site`` module for interactive use and is
        # not guaranteed to exist (e.g. under ``python -S``).
        sys.exit(2)
    except Exception:
        log.exception("Top level error")
        sys.exit(1)
    finally:
        # Ensure the request socket is closed on the child side in all circumstances
        # before the process fully terminates.
        if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket:
            with suppress(Exception):
                SUPERVISOR_COMMS.socket.close()
1513
1514
def reinit_supervisor_comms() -> None:
    """
    Re-initialize supervisor comms and logging channel in subprocess.

    This is not needed for most cases. It is used in two situations:

    * when the process is re-launched via sudo for ``run_as_user``;
    * from inside the python code of a virtualenv (et al.) operator.

    In both cases it lets those tasks re-connect so they can continue to access variables etc.

    If no ``SUPERVISOR_COMMS`` global exists yet, one is created on the socket whose file descriptor
    is given by ``__AIRFLOW_SUPERVISOR_FD`` (default ``0``). The supervisor is then asked to resend
    the logging FD, and logging is reconfigured to write JSON to it. If no FD comes back, a message
    is printed to stderr instead.
    """
    import socket

    # Declared up front: the name is (re)bound below and read unconditionally afterwards.
    global SUPERVISOR_COMMS

    if "SUPERVISOR_COMMS" not in globals():
        comms_log = structlog.get_logger(logger_name="task")
        supervisor_fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))
        SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](
            log=comms_log, socket=socket.socket(fileno=supervisor_fd)
        )

    reply = SUPERVISOR_COMMS.send(ResendLoggingFD())
    if not isinstance(reply, SentFDs):
        print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)
        return

    from airflow.sdk.log import configure_logging

    log_io = os.fdopen(reply.fds[0], "wb", buffering=0)
    configure_logging(json_output=True, output=log_io, sending_to_supervisor=True)
1541
1542
1543if __name__ == "__main__":
1544 main()