#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""The entrypoint for the actual task execution process."""

from __future__ import annotations

import contextlib
import contextvars
import functools
import os
import sys
import time
from collections.abc import Callable, Iterable, Iterator, Mapping
from contextlib import suppress
from datetime import datetime, timezone
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal
from urllib.parse import quote

import attrs
import lazy_object_proxy
import structlog
from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter

from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.sdk.api.client import get_hostname, getuser
from airflow.sdk.api.datamodels._generated import (
    AssetProfile,
    DagRun,
    PreviousTIResponse,
    TaskInstance,
    TaskInstanceState,
    TIRunContext,
)
from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.param import process_params
from airflow.sdk.exceptions import (
    AirflowException,
    AirflowInactiveAssetInInletOrOutletException,
    AirflowRuntimeError,
    AirflowTaskTimeout,
    ErrorType,
    TaskDeferred,
)
from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.comms import (
    AssetEventDagRunReferenceResult,
    CommsDecoder,
    DagRunStateResult,
    DeferTask,
    DRCount,
    ErrorResponse,
    GetDagRunState,
    GetDRCount,
    GetPreviousDagRun,
    GetPreviousTI,
    GetTaskBreadcrumbs,
    GetTaskRescheduleStartDate,
    GetTaskStates,
    GetTICount,
    InactiveAssetsResult,
    PreviousDagRunResult,
    PreviousTIResult,
    RescheduleTask,
    ResendLoggingFD,
    RetryTask,
    SentFDs,
    SetRenderedFields,
    SetRenderedMapIndex,
    SkipDownstreamTasks,
    StartupDetails,
    SucceedTask,
    TaskBreadcrumbsResult,
    TaskRescheduleStartDate,
    TaskState,
    TaskStatesResult,
    TICount,
    ToSupervisor,
    ToTask,
    TriggerDagRun,
    ValidateInletsAndOutlets,
)
from airflow.sdk.execution_time.context import (
    ConnectionAccessor,
    InletEventsAccessors,
    MacrosAccessor,
    OutletEventAccessors,
    TriggeringAssetEventsAccessor,
    VariableAccessor,
    context_get_outlet_events,
    context_to_airflow_vars,
    get_previous_dagrun_success,
    set_current_context,
)
from airflow.sdk.execution_time.sentry import Sentry
from airflow.sdk.execution_time.xcom import XCom
from airflow.sdk.listener import get_listener_manager
from airflow.sdk.observability.stats import Stats
from airflow.sdk.timezone import coerce_datetime
from airflow.triggers.base import BaseEventTrigger
from airflow.triggers.callback import CallbackTrigger

if TYPE_CHECKING:
    import jinja2
    from pendulum.datetime import DateTime
    from structlog.typing import FilteringBoundLogger as Logger

    from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
    from airflow.sdk.definitions.context import Context
    from airflow.sdk.exceptions import DagRunTriggerException
    from airflow.sdk.types import OutletEventAccessorsProtocol


class TaskRunnerMarker:
    """Marker for listener hooks, to properly detect from which component they are called."""


# TODO: Move this entire class into a separate file:
#   `airflow/sdk/execution_time/task_instance.py`
#   or `airflow/sdk/execution_time/runtime_ti.py`
class RuntimeTaskInstance(TaskInstance):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    task: BaseOperator
    bundle_instance: BaseDagBundle
    _cached_template_context: Context | None = None
    """The Task Instance context. This is used to cache get_template_context."""

    _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
    """The Task Instance context from the API server, if any."""

    max_tries: int = 0
    """The maximum number of retries for the task."""

    start_date: AwareDatetime
    """Start date of the task instance."""

    end_date: AwareDatetime | None = None

    state: TaskInstanceState | None = None

    is_mapped: bool | None = None
    """True if the original task was mapped."""

    rendered_map_index: str | None = None

    sentry_integration: str = ""

    def __rich_repr__(self):
        yield "id", self.id
        yield "task_id", self.task_id
        yield "dag_id", self.dag_id
        yield "run_id", self.run_id
        yield "max_tries", self.max_tries
        yield "task", type(self.task)
        yield "start_date", self.start_date

    __rich_repr__.angular = True  # type: ignore[attr-defined]
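    # ``angular = True`` makes rich render this object in the angle-bracket form,
    # ``<RuntimeTaskInstance id=... task_id=...>``, rather than the call-style
    # ``RuntimeTaskInstance(id=..., ...)`` form.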

    def get_template_context(self) -> Context:
        # TODO: Move this to `airflow.sdk.execution_time.context`
        #   once we port the entire context logic from airflow/utils/context.py?
        from airflow.sdk.plugins_manager import integrate_macros_plugins

        integrate_macros_plugins()

        dag_run_conf: dict[str, Any] | None = None
        if from_server := self._ti_context_from_server:
            dag_run_conf = from_server.dag_run.conf or dag_run_conf

        validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)

        # Cache the context object, which ensures that all calls to get_template_context
        # operate on the same context object.
        self._cached_template_context: Context = self._cached_template_context or {
            # From the Task Execution interface
            "dag": self.task.dag,
            "inlets": self.task.inlets,
            "map_index_template": self.task.map_index_template,
            "outlets": self.task.outlets,
            "run_id": self.run_id,
            "task": self.task,
            "task_instance": self,
            "ti": self,
            "outlet_events": OutletEventAccessors(),
            "inlet_events": InletEventsAccessors(self.task.inlets),
            "macros": MacrosAccessor(),
            "params": validated_params,
            # TODO: Make this go through the Public API longer term.
            # "test_mode": task_instance.test_mode,
            "var": {
                "json": VariableAccessor(deserialize_json=True),
                "value": VariableAccessor(deserialize_json=False),
            },
            "conn": ConnectionAccessor(),
        }
        if from_server:
            dag_run = from_server.dag_run
            context_from_server: Context = {
                # TODO: Assess if we need to pass these through timezone.coerce_datetime
                "dag_run": dag_run,  # type: ignore[typeddict-item]  # Removable after #46522
                "triggering_asset_events": TriggeringAssetEventsAccessor.build(
                    AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
                    for event in dag_run.consumed_asset_events
                ),
                "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
                "task_reschedule_count": from_server.task_reschedule_count or 0,
                "prev_start_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date)
                ),
                "prev_end_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
                ),
            }
            self._cached_template_context.update(context_from_server)

            if logical_date := coerce_datetime(dag_run.logical_date):
                if TYPE_CHECKING:
                    assert isinstance(logical_date, DateTime)
                ds = logical_date.strftime("%Y-%m-%d")
                ds_nodash = ds.replace("-", "")
                ts = logical_date.isoformat()
                ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
                ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
                # logical_date and data_interval are either both set or both None
                self._cached_template_context.update(
                    {
                        # keys that depend on logical_date
                        "logical_date": logical_date,
                        "ds": ds,
                        "ds_nodash": ds_nodash,
                        "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}",
                        "ts": ts,
                        "ts_nodash": ts_nodash,
                        "ts_nodash_with_tz": ts_nodash_with_tz,
                        # keys that depend on data_interval
                        "data_interval_end": coerce_datetime(dag_run.data_interval_end),
                        "data_interval_start": coerce_datetime(dag_run.data_interval_start),
                        "prev_data_interval_start_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start)
                        ),
                        "prev_data_interval_end_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end)
                        ),
                    }
                )

            if from_server.upstream_map_indexes is not None:
                # We stash this here for later use, but we purposefully don't want to document its
                # existence. Should this be a private attribute on RuntimeTI instead, perhaps?
                setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)

        return self._cached_template_context

    def render_templates(
        self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
    ) -> BaseOperator:
        """
        Render templates in the operator fields.

        If the task was originally mapped, this may replace ``self.task`` with
        the unmapped, fully rendered BaseOperator. The original ``self.task``
        before replacement is returned.
        """
        if not context:
            context = self.get_template_context()
        original_task = self.task

        if TYPE_CHECKING:
            assert context

        ti = context["ti"]

        if TYPE_CHECKING:
            assert original_task
            assert self.task
            assert ti.task

        # If self.task is mapped, this call replaces self.task to point to the
        # unmapped BaseOperator created by this function! This is because the
        # MappedOperator is useless for template rendering, and we need to be
        # able to access the unmapped task instead.
        self.task.render_template_fields(context, jinja_env)
        self.is_mapped = original_task.is_mapped
        return original_task

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = BaseXCom.XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
        default: Any = None,
        run_id: str | None = None,
    ) -> Any:
        """
        Pull XComs either from the API server (BaseXCom) or from the custom XCom backend if configured.

        The pull can optionally be filtered by certain criteria.

        :param key: A key for the XCom. If provided, only XComs with matching
            keys will be returned. The default key is ``'return_value'``, also
            available as constant ``XCOM_RETURN_KEY``. This key is automatically
            given to XComs returned by tasks (as opposed to being pushed
            manually).
        :param task_ids: Only XComs from tasks with matching ids will be
            pulled. If *None* (default), the task_id of the calling task is used.
        :param dag_id: If provided, only pulls XComs from this Dag. If *None*
            (default), the Dag of the calling task is used.
        :param map_indexes: If provided, only pull XComs with matching indexes.
            If *None* (default), this is inferred from the task(s) being pulled
            (see below for details).
        :param include_prior_dates: If *False*, only XComs from the current
            logical_date are returned. If *True*, XComs from previous dates
            are returned as well.
        :param run_id: If provided, only pulls XComs from a DagRun with a matching run_id.
            If *None* (default), the run_id of the calling task is used.

        When pulling one single task (``task_id`` is *None* or a str) without
        specifying ``map_indexes``, the return value is a single XCom entry
        (``map_indexes`` is set to the map_index of the calling task instance).

        When the pulled task is mapped, the specified ``map_index`` is used, so by
        default pulling from a mapped task will yield no matching XComs if the task
        instance making the call is not itself mapped; otherwise, the map_index of
        the calling task instance is used. Setting ``map_indexes`` to *None* will
        pull the XCom as it would from a non-mapped task.

        In either case, ``default`` (*None* if not specified) is returned if no
        matching XComs are found.

        When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
        a non-str iterable), a list of matching XComs is returned. Elements in
        the list are ordered by item ordering in ``task_id`` and ``map_index``.
        """
        if dag_id is None:
            dag_id = self.dag_id
        if run_id is None:
            run_id = self.run_id

        single_task_requested = isinstance(task_ids, (str, type(None)))
        single_map_index_requested = isinstance(map_indexes, (int, type(None)))

        if task_ids is None:
            # default to the current task if not provided
            task_ids = [self.task_id]
        elif isinstance(task_ids, str):
            task_ids = [task_ids]

        # If map_indexes is not specified, pull xcoms from all map indexes for each task
        if not is_arg_set(map_indexes):
            xcoms: list[Any] = []
            for t_id in task_ids:
                values = XCom.get_all(
                    run_id=run_id,
                    key=key,
                    task_id=t_id,
                    dag_id=dag_id,
                    include_prior_dates=include_prior_dates,
                )

                if values is None:
                    xcoms.append(None)
                else:
                    xcoms.extend(values)
            # For a single task pulling from an unmapped task, return a single value
            if single_task_requested and len(xcoms) == 1:
                return xcoms[0]
            return xcoms

        # Original logic when map_indexes is explicitly specified
        map_indexes_iterable: Iterable[int | None] = []
        if isinstance(map_indexes, int) or map_indexes is None:
            map_indexes_iterable = [map_indexes]
        elif isinstance(map_indexes, Iterable):
            map_indexes_iterable = map_indexes
        else:
            raise TypeError(
                f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}"
            )

        xcoms = []
        for t_id, m_idx in product(task_ids, map_indexes_iterable):
            value = XCom.get_one(
                run_id=run_id,
                key=key,
                task_id=t_id,
                dag_id=dag_id,
                map_index=m_idx,
                include_prior_dates=include_prior_dates,
            )
            if value is None:
                xcoms.append(default)
            else:
                xcoms.append(value)

        if single_task_requested and single_map_index_requested:
            return xcoms[0]

        return xcoms
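
    # Illustrative usage (hypothetical task ids, not part of this module):
    #   ti.xcom_pull(task_ids="extract") returns the single matching entry (typically
    #   the ``return_value`` XCom of ``extract``), while
    #   ti.xcom_pull(task_ids=["extract", "load"], map_indexes=None) returns a list
    #   ordered by task id, substituting ``default`` for any miss.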

    def xcom_push(self, key: str, value: Any):
        """
        Make an XCom available for tasks to pull.

        :param key: Key to store the value under.
        :param value: Value to store. Only JSON-serializable values may be used.
        """
        _xcom_push(self, key, value)

    def get_relevant_upstream_map_indexes(
        self, upstream: BaseOperator, ti_count: int | None, session: Any
    ) -> int | range | None:
        # TODO: Implement this method
        return None

    def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
        """Get the first reschedule date for the task instance if found, *None* otherwise."""
        if context.get("task_reschedule_count", 0) == 0:
            # If the task has not been rescheduled, there is no need to ask the supervisor
            return None

        max_tries: int = self.max_tries
        retries: int = self.task.retries or 0
        first_try_number = max_tries - retries + 1

        log = structlog.get_logger(logger_name="task")

        log.debug("Requesting first reschedule date from supervisor")

        response = SUPERVISOR_COMMS.send(
            msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskRescheduleStartDate)

        return response.start_date

    def get_previous_dagrun(self, state: str | None = None) -> DagRun | None:
        """Return the previous Dag run before the given logical date, optionally filtered by state."""
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")

        log.debug("Getting previous Dag run", dag_run=dag_run)

        if dag_run is None:
            return None

        if dag_run.logical_date is None:
            return None

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state)
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousDagRunResult)

        return response.dag_run

    def get_previous_ti(
        self,
        state: TaskInstanceState | None = None,
        logical_date: AwareDatetime | None = None,
        map_index: int = -1,
    ) -> PreviousTIResponse | None:
        """
        Return the previous task instance matching the given criteria.

        :param state: Filter by TaskInstance state
        :param logical_date: Filter by logical date (returns the TI before this date)
        :param map_index: Filter by map_index (defaults to -1 for non-mapped tasks)
        :return: Previous task instance, or None if not found
        """
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")
        log.debug("Getting previous task instance", task_id=self.task_id, state=state)

        # Use the current dag run's logical_date if not provided
        effective_logical_date = logical_date
        if effective_logical_date is None and dag_run and dag_run.logical_date:
            effective_logical_date = dag_run.logical_date

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousTI(
                dag_id=self.dag_id,
                task_id=self.task_id,
                logical_date=effective_logical_date,
                map_index=map_index,
                state=state,
            )
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousTIResult)

        return response.task_instance

    @staticmethod
    def get_ti_count(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of task instances matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTICount(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TICount)

        return response.count

    @staticmethod
    def get_task_states(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> dict[str, Any]:
        """Return the task states matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTaskStates(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskStatesResult)

        return response.task_states

    @staticmethod
    def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]:
        """Return task breadcrumbs for the given dag run."""
        response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id))
        if TYPE_CHECKING:
            assert isinstance(response, TaskBreadcrumbsResult)
        return response.breadcrumbs

    @staticmethod
    def get_dr_count(
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of Dag runs matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetDRCount(
                dag_id=dag_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, DRCount)

        return response.count

    @staticmethod
    def get_dagrun_state(dag_id: str, run_id: str) -> str:
        """Return the state of the Dag run with the given Run ID."""
        response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id))

        if TYPE_CHECKING:
            assert isinstance(response, DagRunStateResult)

        return response.state

    @property
    def log_url(self) -> str:
        run_id = quote(self.run_id)
        base_url = conf.get("api", "base_url", fallback="http://localhost:8080/")
        map_index_value = self.map_index
        map_index = (
            f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else ""
        )
        try_number_value = self.try_number
        try_number = (
            f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else ""
        )
        _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}"
        return _log_uri

    @property
    def mark_success_url(self) -> str:
        """URL to mark TI success."""
        return self.log_url
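
    # Example shape (hypothetical ids):
    #   http://localhost:8080/dags/my_dag/runs/manual__2024-01-01/tasks/my_task/mapped/3?try_number=2
    # The ``/mapped/{n}`` and ``?try_number=`` parts are added only for mapped TIs and
    # positive try numbers, respectively.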


def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
    """Push an XCom through XCom.set, which pushes to the XCom backend if configured."""
    # Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK
    # consumers
    XCom.set(
        key=key,
        value=value,
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=ti.map_index,
        _mapped_length=mapped_length,
    )


def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None:
    """Push an XCom directly to the metadata DB, bypassing any custom xcom_backend."""
    XCom._set_xcom_in_db(
        key=key,
        value=value,
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=ti.map_index,
    )


def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
    # TODO: Task-SDK:
    #   Using DagBag here is about 98% wrong, but it'll do for now
    from airflow.dag_processing.dagbag import DagBag

    bundle_info = what.bundle_info
    bundle_instance = DagBundlesManager().get_bundle(
        name=bundle_info.name,
        version=bundle_info.version,
    )
    bundle_instance.initialize()

    # Put the bundle root on sys.path if needed. This allows code in util modules
    # within the bundle to be shared between files in the same bundle.
    if (bundle_root := os.fspath(bundle_instance.path)) not in sys.path:
        sys.path.append(bundle_root)

    dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
    bag = DagBag(
        dag_folder=dag_absolute_path,
        include_examples=False,
        safe_mode=False,
        load_op_links=False,
        bundle_name=bundle_info.name,
    )
    if TYPE_CHECKING:
        assert what.ti.dag_id

    try:
        dag = bag.dags[what.ti.dag_id]
    except KeyError:
        log.error(
            "Dag not found during start up", dag_id=what.ti.dag_id, bundle=bundle_info, path=what.dag_rel_path
        )
        sys.exit(1)

    # install_loader()

    try:
        task = dag.task_dict[what.ti.task_id]
    except KeyError:
        log.error(
            "Task not found in Dag during start up",
            dag_id=dag.dag_id,
            task_id=what.ti.task_id,
            bundle=bundle_info,
            path=what.dag_rel_path,
        )
        sys.exit(1)

    if not isinstance(task, (BaseOperator, MappedOperator)):
        raise TypeError(
            f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
        )

    return RuntimeTaskInstance.model_construct(
        **what.ti.model_dump(exclude_unset=True),
        task=task,
        bundle_instance=bundle_instance,
        _ti_context_from_server=what.ti_context,
        max_tries=what.ti_context.max_tries,
        start_date=what.start_date,
        state=TaskInstanceState.RUNNING,
        sentry_integration=what.sentry_integration,
    )


# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's
# execution, to send requests back to the supervisor process.
#
# Why it needs to be a global:
# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic
#   requests to the parent process during task execution.
# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the
#   deeply nested execution stack.
# - Defining `SUPERVISOR_COMMS` as a global ensures that this communication mechanism is readily
#   accessible wherever needed during task execution, without modifying every layer of the call stack.
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]


# State machine!
# 1. Start up (receive details from supervisor)
# 2. Execution (run task code, possibly send requests)
# 3. Shutdown and report status
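
# A typical round-trip through the supervisor comms looks like the sketch below
# (mirroring the call sites later in this module): send a request model, then
# narrow the response to the expected result type.
#
#     response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id="d", run_id="r"))
#     if TYPE_CHECKING:
#         assert isinstance(response, DagRunStateResult)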


def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
    # The parent sends us a StartupDetails message un-prompted. After this, every single message is only
    # sent in response to us sending a request.
    log = structlog.get_logger(logger_name="task")

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and (
        msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG")
    ):
        # Clear any Kerberos credential cache if there is one, so the new process can't reuse it.
        os.environ.pop("KRB5CCNAME", None)
        # entrypoint of the re-exec'd process

        msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson)
        reinit_supervisor_comms()

        # We delay this message until _after_ we've got the logging re-configured, otherwise it will show
        # up on stdout
        log.debug("Using serialized startup message from environment", msg=msg)
    else:
        # normal entry point
        msg = SUPERVISOR_COMMS._get_response()  # type: ignore[assignment]

        if not isinstance(msg, StartupDetails):
            raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

    # setproctitle causes issues on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
    os_type = sys.platform
    if os_type == "darwin":
        log.debug("Mac OS detected, skipping setproctitle")
    else:
        from setproctitle import setproctitle

        setproctitle(f"airflow worker -- {msg.ti.id}")

    try:
        get_listener_manager().hook.on_starting(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")

    with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
        ti = parse(msg, log)
    log.debug("Dag file parsed", file=msg.dag_rel_path)

    run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
        "core", "default_impersonation", fallback=None
    )

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
        # Not yet re-executed and impersonation is required: re-exec this process via sudo
        os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
        # store the startup message in the environment for the re-exec'd process
        os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
        os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)

        # Import main directly from the module instead of re-executing the file.
        # This ensures that when other modules import
        # airflow.sdk.execution_time.task_runner, they get the same module instance
        # with the properly initialized SUPERVISOR_COMMS global variable.
        # If we re-executed the module with `python -m`, it would load as __main__ and future
        # imports would get a fresh copy without the initialized globals.
        rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
        cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
        log.info(
            "Running command",
            command=cmd,
        )
        os.execvp("sudo", cmd)

        # ideally, we should never reach here, but if we do, we should return None, None, None
        return None, None, None

    return ti, ti.get_template_context(), log
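
# Impersonation flow: when ``run_as_user`` is set and differs from the current user,
# startup() re-execs itself under ``sudo`` with _AIRFLOW__REEXECUTED_PROCESS=1 and the
# startup message serialized into _AIRFLOW__STARTUP_MSG; the re-exec'd process then takes
# the first branch above and rebuilds supervisor comms over the inherited socket.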


def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float:
    """
    Return a serializable representation of the templated field.

    If ``template_field`` contains a class or instance that requires recursive
    templating, store them as strings. Otherwise simply return the field as-is.

    Uses the SDK secrets masker to redact secrets in the serialized output.
    """
    import json

    from airflow.sdk._shared.secrets_masker import redact

    def is_jsonable(x):
        try:
            json.dumps(x)
        except (TypeError, OverflowError):
            return False
        else:
            return True

    def translate_tuples_to_lists(obj: Any):
        """Recursively convert tuples to lists."""
        if isinstance(obj, tuple):
            return [translate_tuples_to_lists(item) for item in obj]
        if isinstance(obj, list):
            return [translate_tuples_to_lists(item) for item in obj]
        if isinstance(obj, dict):
            return {key: translate_tuples_to_lists(value) for key, value in obj.items()}
        return obj

    def sort_dict_recursively(obj: Any) -> Any:
        """Recursively sort dictionaries to ensure consistent ordering."""
        if isinstance(obj, dict):
            return {k: sort_dict_recursively(v) for k, v in sorted(obj.items())}
        if isinstance(obj, list):
            return [sort_dict_recursively(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(sort_dict_recursively(item) for item in obj)
        return obj
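
    # Illustrative (hypothetical values): translate_tuples_to_lists(("a", ("b",))) returns
    # ["a", ["b"]], and sort_dict_recursively({"b": 1, "a": 2}) returns {"a": 2, "b": 1},
    # so the str() below yields a stable representation regardless of dict ordering.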

    max_length = conf.getint("core", "max_templated_field_length")

    if not is_jsonable(template_field):
        try:
            serialized = template_field.serialize()
        except AttributeError:
            serialized = str(template_field)
        if len(serialized) > max_length:
            rendered = redact(serialized, name)
            return (
                "Truncated. You can change this behaviour in [core]max_templated_field_length. "
                f"{rendered[: max_length - 79]!r}... "
            )
        return serialized
    if not template_field and not isinstance(template_field, tuple):
        # Avoid unnecessary serialization steps for empty fields, unless they are tuples
        # and need to be converted to lists
        return template_field
    template_field = translate_tuples_to_lists(template_field)
    # Sort dictionaries recursively to ensure a consistent string representation.
    # This prevents hash inconsistencies when dict ordering varies.
    if isinstance(template_field, dict):
        template_field = sort_dict_recursively(template_field)
    serialized = str(template_field)
    if len(serialized) > max_length:
        rendered = redact(serialized, name)
        return (
            "Truncated. You can change this behaviour in [core]max_templated_field_length. "
            f"{rendered[: max_length - 79]!r}... "
        )
    return template_field


def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]:
    from airflow.sdk._shared.secrets_masker import redact

    rendered_fields = {}
    for field in task.template_fields:
        value = getattr(task, field)
        serialized = _serialize_template_field(value, field)
        # Redact secrets in the task process itself before sending to the API server.
        # This ensures that secrets registered via mask_secret() on workers / the dag processor
        # are properly masked in the UI.
        rendered_fields[field] = redact(serialized, field)

    return rendered_fields  # type: ignore[return-value]  # Convince mypy this is OK: we pass JsonValue to redact, so it returns the same


def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]:
    # Lineage can have other types of objects besides assets, so we need to process them a bit.
    for obj in lineage_objects or ():
        if isinstance(obj, Asset):
            yield AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__)
        elif isinstance(obj, AssetNameRef):
            yield AssetProfile(name=obj.name, type=AssetNameRef.__name__)
        elif isinstance(obj, AssetUriRef):
            yield AssetProfile(uri=obj.uri, type=AssetUriRef.__name__)
        elif isinstance(obj, AssetAlias):
            yield AssetProfile(name=obj.name, type=AssetAlias.__name__)


def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]:
    if TYPE_CHECKING:
        assert isinstance(events, OutletEventAccessors)
    # We just collect everything the user recorded in the accessors.
    # Further filtering will be done in the API server.
    for key, accessor in events._dict.items():
        if isinstance(key, AssetUniqueKey):
            yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra}
        for alias_event in accessor.asset_alias_events:
            yield attrs.asdict(alias_event)


def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None:
    ti.hostname = get_hostname()
    ti.task = ti.task.prepare_for_execution()
    # Since the context is cached and calling `ti.get_template_context` returns the same dict, we want
    # to update the value of the task that is served from there
    context["task"] = ti.task

    jinja_env = ti.task.dag.get_template_env()
    ti.render_templates(context=context, jinja_env=jinja_env)

    if rendered_fields := _serialize_rendered_fields(ti.task):
        # so that we do not call the API unnecessarily
        SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields))

    # Try to render map_index_template early with the available context (it will be re-rendered after
    # execution). This provides a partial label during task execution for templates that use
    # pre-execution context. If rendering fails here, we suppress the error since it will be
    # re-rendered after execution.
    try:
        if rendered_map_index := _render_map_index(context, ti=ti, log=log):
            ti.rendered_map_index = rendered_map_index
            log.debug("Sending early rendered map index", length=len(rendered_map_index))
            SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index))
    except Exception:
        log.debug(
            "Early rendering of map_index_template failed, will retry after task execution", exc_info=True
        )

    _validate_task_inlets_and_outlets(ti=ti, log=log)

    try:
        # TODO: Call pre execute etc.
        get_listener_manager().hook.on_task_instance_running(
            previous_state=TaskInstanceState.QUEUED, task_instance=ti
        )
    except Exception:
        log.exception("error calling listener")

    # No error, carry on and execute the task
    return None


def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
    if not ti.task.inlets and not ti.task.outlets:
        return

    inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id))
    if TYPE_CHECKING:
        assert isinstance(inactive_assets_resp, InactiveAssetsResult)
    if inactive_assets := inactive_assets_resp.inactive_assets:
        raise AirflowInactiveAssetInInletOrOutletException(
            inactive_asset_keys=[
                AssetUniqueKey.from_profile(asset_profile) for asset_profile in inactive_assets
            ]
        )


def _defer_task(
    defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?

    log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
    classpath, trigger_kwargs = defer.trigger.serialize()
    queue: str | None = None
    # Currently, only task-associated BaseTrigger instances may have a non-None queue,
    # and only when triggerer.queues_enabled is True.
    if not isinstance(defer.trigger, (BaseEventTrigger, CallbackTrigger)) and conf.getboolean(
        "triggerer", "queues_enabled", fallback=False
    ):
        queue = ti.task.queue

    from airflow.sdk.serde import serialize as serde_serialize

    trigger_kwargs = serde_serialize(trigger_kwargs)
    next_kwargs = serde_serialize(defer.kwargs or {})

    if TYPE_CHECKING:
        assert isinstance(next_kwargs, dict)
        assert isinstance(trigger_kwargs, dict)

    msg = DeferTask(
        classpath=classpath,
        trigger_kwargs=trigger_kwargs,
        trigger_timeout=defer.timeout,
        queue=queue,
        next_method=defer.method_name,
        next_kwargs=next_kwargs,
    )
    state = TaskInstanceState.DEFERRED

    return msg, state


@Sentry.enrich_errors
def run(
    ti: RuntimeTaskInstance,
    context: Context,
    log: Logger,
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
    """Run the task in this process."""
    import signal

    from airflow.sdk.exceptions import (
        AirflowFailException,
        AirflowRescheduleException,
        AirflowSensorTimeout,
        AirflowSkipException,
        AirflowTaskTerminated,
        DagRunTriggerException,
        DownstreamTasksSkipped,
        TaskDeferred,
    )

    if TYPE_CHECKING:
        assert ti.task is not None
        assert isinstance(ti.task, BaseOperator)

    parent_pid = os.getpid()

    def _on_term(signum, frame):
        pid = os.getpid()
        if pid != parent_pid:
            return

        ti.task.on_kill()

    signal.signal(signal.SIGTERM, _on_term)

    msg: ToSupervisor | None = None
    state: TaskInstanceState
    error: BaseException | None = None

    try:
        # First, clear the xcom data sent from the server
        if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear):
            for x in keys_to_delete:
                log.debug("Clearing XCom with key", key=x)
                XCom.delete(
                    key=x,
                    dag_id=ti.dag_id,
                    task_id=ti.task_id,
                    run_id=ti.run_id,
                    map_index=ti.map_index,
                )

        with set_current_context(context):
            # This is the earliest that we can render templates -- if it raises for any reason, we need
            # to catch it and handle it like a normal task failure
            if early_exit := _prepare(ti, log, context):
                msg = early_exit
                ti.state = state = TaskInstanceState.FAILED
                return state, msg, error

            try:
                result = _execute_task(context=context, ti=ti, log=log)
            except Exception:
                import jinja2

                # If the task failed, swallow rendering errors so they don't mask the main error.
                with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
                    previous_rendered_map_index = ti.rendered_map_index
                    ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                    # Send an update only if the value changed (e.g., the user set context variables
                    # during execution)
                    if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                        SUPERVISOR_COMMS.send(
                            msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)
                        )
                raise
            else:  # If the task succeeded, render normally to let rendering errors bubble up.
                previous_rendered_map_index = ti.rendered_map_index
                ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                # Send an update only if the value changed (e.g., the user set context variables during
                # execution)
                if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                    SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index))

            _push_xcom_if_needed(result, ti, log)

            msg, state = _handle_current_task_success(context, ti)
    except DownstreamTasksSkipped as skip:
        log.info("Skipping downstream tasks.")
        tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks]
        SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip))
        msg, state = _handle_current_task_success(context, ti)
    except DagRunTriggerException as drte:
        msg, state = _handle_trigger_dag_run(drte, context, ti, log)
    except TaskDeferred as defer:
        msg, state = _defer_task(defer, ti, log)
    except AirflowSkipException as e:
        if e.args:
            log.info("Skipping task.", reason=e.args[0])
        msg = TaskState(
            state=TaskInstanceState.SKIPPED,
            end_date=datetime.now(tz=timezone.utc),
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.SKIPPED
    except AirflowRescheduleException as reschedule:
        log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
        msg = RescheduleTask(
            reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
        )
        state = TaskInstanceState.UP_FOR_RESCHEDULE
    except (AirflowFailException, AirflowSensorTimeout) as e:
        # If AirflowFailException is raised, the task should not retry.
        # If a sensor in reschedule mode reaches its timeout, the task should not retry.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
        # We should allow retries if the task has defined them.
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    except AirflowTaskTerminated as e:
        # External state updates are already handled with `ti_heartbeat` and will already be
        # applied by another API (e.g. from the UI). So, these exceptions should ideally never be
        # thrown. If they are thrown, we should mark the TI state as failed.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except SystemExit as e:
        # SystemExit needs to be retried if the task is eligible.
        log.error("Task exited", exit_code=e.code)
        msg, state = _handle_current_task_failed(ti)
        error = e
    except BaseException as e:
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    finally:
        if msg:
            SUPERVISOR_COMMS.send(msg=msg)

    # Return the message to make unit tests easier too
    ti.state = state
    return state, msg, error
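
# How run() maps exceptions to terminal states (derived from the handlers above):
#   AirflowSkipException -> SKIPPED; AirflowRescheduleException -> UP_FOR_RESCHEDULE;
#   TaskDeferred -> DEFERRED; AirflowFailException / AirflowSensorTimeout -> FAILED (no retry);
#   most other exceptions (including SystemExit) -> UP_FOR_RETRY while retries remain, else FAILED.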


def _handle_current_task_success(
    context: Context,
    ti: RuntimeTaskInstance,
) -> tuple[SucceedTask, TaskInstanceState]:
    end_date = datetime.now(tz=timezone.utc)
    ti.end_date = end_date

    # Record operator and task instance success metrics
    operator = ti.task.__class__.__name__
    stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

    Stats.incr(f"operator_successes_{operator}", tags=stats_tags)
    # Same metric with tagging
    Stats.incr("operator_successes", tags={**stats_tags, "operator": operator})
    Stats.incr("ti_successes", tags=stats_tags)

    task_outlets = list(_build_asset_profiles(ti.task.outlets))
    outlet_events = list(_serialize_outlet_events(context["outlet_events"]))
    msg = SucceedTask(
        end_date=end_date,
        task_outlets=task_outlets,
        outlet_events=outlet_events,
        rendered_map_index=ti.rendered_map_index,
    )
    return msg, TaskInstanceState.SUCCESS


def _handle_current_task_failed(
    ti: RuntimeTaskInstance,
) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
    end_date = datetime.now(tz=timezone.utc)
    ti.end_date = end_date

    # Record operator and task instance failure metrics
    operator = ti.task.__class__.__name__
    stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

    Stats.incr(f"operator_failures_{operator}", tags=stats_tags)
    # Same metric with tagging
    Stats.incr("operator_failures", tags={**stats_tags, "operator": operator})
    Stats.incr("ti_failures", tags=stats_tags)

    if ti._ti_context_from_server and ti._ti_context_from_server.should_retry:
        return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY
    return (
        TaskState(
            state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index
        ),
        TaskInstanceState.FAILED,
    )


def _handle_trigger_dag_run(
    drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """Handle the exception from TriggerDagRunOperator."""
    log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
    comms_msg = SUPERVISOR_COMMS.send(
        TriggerDagRun(
            dag_id=drte.trigger_dag_id,
            run_id=drte.dag_run_id,
            logical_date=drte.logical_date,
            conf=drte.conf,
            reset_dag_run=drte.reset_dag_run,
        ),
    )

    if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
        if drte.skip_when_already_exists:
            log.info(
                "Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
                dag_id=drte.trigger_dag_id,
            )
            msg = TaskState(
                state=TaskInstanceState.SKIPPED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.SKIPPED
        else:
            log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
            msg = TaskState(
                state=TaskInstanceState.FAILED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.FAILED

        return msg, state

    log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)

    # Store the run id from the dag run (either created or found above) to
    # be used when creating the extra link on the webserver.
    ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)

    if drte.wait_for_completion:
        if drte.deferrable:
            from airflow.providers.standard.triggers.external_task import DagStateTrigger

            defer = TaskDeferred(
                trigger=DagStateTrigger(
                    dag_id=drte.trigger_dag_id,
                    states=drte.allowed_states + drte.failed_states,  # type: ignore[arg-type]
                    # Don't filter by execution_dates when run_ids is provided.
                    # run_id uniquely identifies a DAG run, and when reset_dag_run=True,
                    # drte.logical_date might be a newly calculated value that doesn't match
                    # the persisted logical_date in the database, causing the trigger to never find the run.
                    execution_dates=None,
                    run_ids=[drte.dag_run_id],
                    poll_interval=drte.poke_interval,
                ),
                method_name="execute_complete",
            )
            return _defer_task(defer, ti, log)
        while True:
            log.info(
                "Waiting for dag run to complete execution in allowed state.",
                dag_id=drte.trigger_dag_id,
                run_id=drte.dag_run_id,
                allowed_state=drte.allowed_states,
            )
            time.sleep(drte.poke_interval)

            comms_msg = SUPERVISOR_COMMS.send(
                GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
            )
            if TYPE_CHECKING:
                assert isinstance(comms_msg, DagRunStateResult)
            if comms_msg.state in drte.failed_states:
                log.error(
                    "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                msg = TaskState(
                    state=TaskInstanceState.FAILED,
                    end_date=datetime.now(tz=timezone.utc),
                    rendered_map_index=ti.rendered_map_index,
                )
                state = TaskInstanceState.FAILED
                return msg, state
            if comms_msg.state in drte.allowed_states:
                log.info(
                    "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                break
            log.debug(
                "DagRun not yet in allowed or failed state.",
                dag_id=drte.trigger_dag_id,
                state=comms_msg.state,
            )
    else:
        # Fire-and-forget mode: wait_for_completion=False
        if drte.deferrable:
            log.info(
                "Ignoring deferrable=True because wait_for_completion=False. "
                "Task will complete immediately without waiting for the triggered DAG run.",
                trigger_dag_id=drte.trigger_dag_id,
            )

    return _handle_current_task_success(context, ti)
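
# Behaviour matrix for the TriggerDagRun handling above:
#   wait_for_completion and deferrable: defer on DagStateTrigger until the run finishes;
#   wait_for_completion only: poll GetDagRunState every ``poke_interval`` seconds;
#   neither: fire-and-forget, the task succeeds immediately after triggering.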


def _run_task_state_change_callbacks(
    task: BaseOperator,
    kind: Literal[
        "on_execute_callback",
        "on_failure_callback",
        "on_success_callback",
        "on_retry_callback",
        "on_skipped_callback",
    ],
    context: Context,
    log: Logger,
) -> None:
    callback: Callable[[Context], None]
    for i, callback in enumerate(getattr(task, kind)):
        try:
            create_executable_runner(callback, context_get_outlet_events(context), logger=log).run(context)
        except Exception:
            log.exception("Failed to run task callback", kind=kind, index=i, callback=callback)


def _send_error_email_notification(
    task: BaseOperator | MappedOperator,
    ti: RuntimeTaskInstance,
    context: Context,
    error: BaseException | str | None,
    log: Logger,
) -> None:
    """Send an email notification for task errors using SmtpNotifier."""
    try:
        from airflow.providers.smtp.notifications.smtp import SmtpNotifier
    except ImportError:
        log.error(
            "Failed to send task failure or retry email notification: "
            "`apache-airflow-providers-smtp` is not installed. "
            "Install this provider to enable email notifications."
        )
        return

    if not task.email:
        return

    subject_template_file = conf.get("email", "subject_template", fallback=None)

    # Read the template file if configured
    if subject_template_file and Path(subject_template_file).exists():
        subject = Path(subject_template_file).read_text()
    else:
        # Fall back to the default
        subject = "Airflow alert: {{ti}}"

    html_content_template_file = conf.get("email", "html_content_template", fallback=None)

    # Read the template file if configured
    if html_content_template_file and Path(html_content_template_file).exists():
        html_content = Path(html_content_template_file).read_text()
    else:
        # Fall back to the default.
        # For reporting purposes, we report based on 1-indexed,
        # not 0-indexed lists (i.e. Try 1 instead of Try 0 for the first attempt).
        html_content = (
            "Try {{try_number}} out of {{max_tries + 1}}<br>"
            "Exception:<br>{{exception_html}}<br>"
            'Log: <a href="{{ti.log_url}}">Link</a><br>'
            "Host: {{ti.hostname}}<br>"
            'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
        )

    # Add exception_html to the context for template rendering
    import html

    exception_html = html.escape(str(error)).replace("\n", "<br>")
    additional_context = {
        "exception": error,
        "exception_html": exception_html,
        "try_number": ti.try_number,
        "max_tries": ti.max_tries,
    }
    email_context = {**context, **additional_context}
    to_emails = task.email
    if not to_emails:
        return

    try:
        notifier = SmtpNotifier(
            to=to_emails,
            subject=subject,
            html_content=html_content,
            from_email=conf.get("email", "from_email", fallback="airflow@airflow"),
        )
        notifier(email_context)
    except Exception:
        log.exception("Failed to send email notification")

1452def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): 

1453 """Execute Task (optionally with a Timeout) and push Xcom results.""" 

1454 task = ti.task 

1455 execute = task.execute 

1456 

1457 if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method): 

1458 from airflow.sdk.serde import deserialize 

1459 

1460 next_kwargs_data = ti._ti_context_from_server.next_kwargs or {} 

1461 try: 

1462 if TYPE_CHECKING: 

1463 assert isinstance(next_kwargs_data, dict) 

1464 kwargs = deserialize(next_kwargs_data) 

1465 except (ImportError, KeyError, AttributeError, TypeError): 

1466 from airflow.serialization.serialized_objects import BaseSerialization 

1467 

1468 kwargs = BaseSerialization.deserialize(next_kwargs_data) 

1469 

1470 if TYPE_CHECKING: 

1471 assert isinstance(kwargs, dict) 

1472 execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs) 

1473 

1474 ctx = contextvars.copy_context() 

1475 # Populate the context var so ExecutorSafeguard doesn't complain 

1476 ctx.run(ExecutorSafeguard.tracker.set, task) 

1477 

1478 # Export context in os.environ to make it available for operators to use. 

1479 airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) 

1480 os.environ.update(airflow_context_vars) 

1481 

1482 outlet_events = context_get_outlet_events(context) 

1483 

1484 if (pre_execute_hook := task._pre_execute_hook) is not None: 

1485 create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context) 

1486 if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute: 

1487 create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context) 

1488 

1489 _run_task_state_change_callbacks(task, "on_execute_callback", context, log) 

1490 

1491 if task.execution_timeout: 

1492 from airflow.sdk.execution_time.timeout import timeout 

1493 

1494 # TODO: handle timeout in case of deferral 

1495 timeout_seconds = task.execution_timeout.total_seconds() 

1496 try: 

1497 # It's possible we're already timed out, so fast-fail if true 

1498 if timeout_seconds <= 0: 

1499 raise AirflowTaskTimeout() 

1500 # Run task in timeout wrapper 

1501 with timeout(timeout_seconds): 

1502 result = ctx.run(execute, context=context) 

1503 except AirflowTaskTimeout: 

1504 task.on_kill() 

1505 raise 

1506 else: 

1507 result = ctx.run(execute, context=context) 

1508 

1509 if (post_execute_hook := task._post_execute_hook) is not None: 

1510 create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) 

1511 if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute: 

1512 create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) 

1513 

1514 return result 

1515 

1516 
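
# --- Editor's note: a minimal sketch of two execution mechanics used in _execute_task
# above. (1) contextvars.copy_context().run(...) executes a callable against an isolated
# copy of the current context variables, which is how the ExecutorSafeguard tracker is
# scoped to the task call; (2) a SIGALRM-based context manager approximates the
# execution_timeout wrapper (an illustration, not airflow.sdk.execution_time.timeout;
# POSIX, main thread only). The resumed-after-deferral path above simply swaps `execute`
# for functools.partial(task.resume_execution, ...).
import contextlib as _contextlib
import contextvars as _contextvars
import signal as _signal


@_contextlib.contextmanager
def _example_timeout(seconds: float):
    def _handler(signum, frame):
        raise TimeoutError(f"task timed out after {seconds}s")

    old_handler = _signal.signal(_signal.SIGALRM, _handler)
    _signal.setitimer(_signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        _signal.setitimer(_signal.ITIMER_REAL, 0)  # cancel any pending alarm
        _signal.signal(_signal.SIGALRM, old_handler)


def _example_run_with_timeout(fn, seconds: float):
    ctx = _contextvars.copy_context()
    with _example_timeout(seconds):
        # ContextVar mutations made inside fn stay inside the copied context.
        return ctx.run(fn)
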

1517def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None: 

1518 """Render named map index if the Dag author defined map_index_template at the task level.""" 

1519 if (template := context.get("map_index_template")) is None: 

1520 return None 

1521 log.debug("Rendering map_index_template", template_length=len(template)) 

1522 jinja_env = ti.task.dag.get_template_env() 

1523 rendered_map_index = jinja_env.from_string(template).render(context) 

1524 log.debug("Map index rendered", length=len(rendered_map_index)) 

1525 return rendered_map_index 

1526 

1527 
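
# --- Editor's note: usage sketch for the map_index_template rendering above; a plain
# jinja2.Environment stands in for ti.task.dag.get_template_env(), and the template and
# context values are hypothetical.
def _example_render_map_index() -> str:
    import jinja2

    template = "{{ task_instance.map_index }}: {{ region }}"
    context = {"task_instance": type("TI", (), {"map_index": 3})(), "region": "eu-west-1"}
    return jinja2.Environment().from_string(template).render(context)  # -> "3: eu-west-1"
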

1528def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger): 

1529 """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result.""" 

1530 if ti.task.do_xcom_push: 

1531 xcom_value = result 

1532 else: 

1533 xcom_value = None 

1534 

1535 has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None 

1536 if xcom_value is None: 

1537 if not ti.is_mapped and has_mapped_dep: 

1538 # Uh-oh: a downstream mapped task depends on us to push something to map over 

1539 from airflow.sdk.exceptions import XComForMappingNotPushed 

1540 

1541 raise XComForMappingNotPushed() 

1542 return 

1543 

1544 mapped_length: int | None = None 

1545 if not ti.is_mapped and has_mapped_dep: 

1546 from airflow.sdk.definitions.mappedoperator import is_mappable_value 

1547 from airflow.sdk.exceptions import UnmappableXComTypePushed 

1548 

1549 if not is_mappable_value(xcom_value): 

1550 raise UnmappableXComTypePushed(xcom_value) 

1551 mapped_length = len(xcom_value) 

1552 

1553 log.info("Pushing xcom", ti=ti) 

1554 

1555 # If the task has multiple outputs, push each output as a separate XCom. 

1556 if ti.task.multiple_outputs: 

1557 if not isinstance(xcom_value, Mapping): 

1558 raise TypeError( 

1559 f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs" 

1560 ) 

1561 for key in xcom_value.keys(): 

1562 if not isinstance(key, str): 

1563 raise TypeError( 

1564 "Returned dictionary keys must be strings when using " 

1565 f"multiple_outputs, found {key} ({type(key)}) instead" 

1566 ) 

1567 for k, v in xcom_value.items(): 

1568 ti.xcom_push(k, v) 

1569 

1570 _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length) 

1571 

1572 
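
# --- Editor's note: a condensed sketch of the multiple_outputs contract enforced above.
# `push` is a hypothetical stand-in for ti.xcom_push, and "return_value" mirrors
# BaseXCom.XCOM_RETURN_KEY; the validation steps follow _push_xcom_if_needed.
from collections.abc import Mapping as _Mapping
from typing import Any as _Any


def _example_push_multiple_outputs(result: _Any, push) -> None:
    if not isinstance(result, _Mapping):
        raise TypeError(f"expected a dict for multiple_outputs, got {type(result)}")
    for key, value in result.items():
        if not isinstance(key, str):
            raise TypeError(f"multiple_outputs keys must be strings, found {key!r}")
        push(key, value)  # one XCom per dictionary key
    push("return_value", result)  # the full mapping is still pushed under the return key
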

1573def finalize( 

1574 ti: RuntimeTaskInstance, 

1575 state: TaskInstanceState, 

1576 context: Context, 

1577 log: Logger, 

1578 error: BaseException | None = None, 

1579): 

1580 # Record task duration metrics for all terminal states 

1581 if ti.start_date and ti.end_date: 

1582 duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000 

1583 stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id} 

1584 

1585 Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms) 

1586 Stats.timing("task.duration", duration_ms, tags=stats_tags) 

1587 

1588 task = ti.task 

1589 # Push an XCom for each operator extra link defined on the operator. 

1590 for oe in task.operator_extra_links: 

1591 try: 

1592 link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key # type: ignore[arg-type] 

1593 log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key) 

1594 _xcom_push_to_db(ti, key=xcom_key, value=link) 

1595 except Exception: 

1596 log.exception( 

1597 "Failed to push an xcom for task operator extra link", 

1598 link_name=oe.name, 

1599 xcom_key=oe.xcom_key, 

1600 ti=ti, 

1601 ) 

1602 

1603 if getattr(ti.task, "overwrite_rtif_after_execution", False): 

1604 log.debug("Overwriting Rendered template fields.") 

1605 if ti.task.template_fields: 

1606 SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) 

1607 

1608 log.debug("Running finalizers", ti=ti) 

1609 if state == TaskInstanceState.SUCCESS: 

1610 _run_task_state_change_callbacks(task, "on_success_callback", context, log) 

1611 try: 

1612 get_listener_manager().hook.on_task_instance_success( 

1613 previous_state=TaskInstanceState.RUNNING, task_instance=ti 

1614 ) 

1615 except Exception: 

1616 log.exception("error calling listener") 

1617 elif state == TaskInstanceState.SKIPPED: 

1618 _run_task_state_change_callbacks(task, "on_skipped_callback", context, log) 

1619 try: 

1620 get_listener_manager().hook.on_task_instance_skipped( 

1621 previous_state=TaskInstanceState.RUNNING, task_instance=ti 

1622 ) 

1623 except Exception: 

1624 log.exception("error calling listener") 

1625 elif state == TaskInstanceState.UP_FOR_RETRY: 

1626 _run_task_state_change_callbacks(task, "on_retry_callback", context, log) 

1627 try: 

1628 get_listener_manager().hook.on_task_instance_failed( 

1629 previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error 

1630 ) 

1631 except Exception: 

1632 log.exception("error calling listener") 

1633 if error and task.email_on_retry and task.email: 

1634 _send_error_email_notification(task, ti, context, error, log) 

1635 elif state == TaskInstanceState.FAILED: 

1636 _run_task_state_change_callbacks(task, "on_failure_callback", context, log) 

1637 try: 

1638 get_listener_manager().hook.on_task_instance_failed( 

1639 previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error 

1640 ) 

1641 except Exception: 

1642 log.exception("error calling listener") 

1643 if error and task.email_on_failure and task.email: 

1644 _send_error_email_notification(task, ti, context, error, log) 

1645 

1646 try: 

1647 get_listener_manager().hook.before_stopping(component=TaskRunnerMarker()) 

1648 except Exception: 

1649 log.exception("error calling listener") 

1650 

1651 
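
# --- Editor's note: the success/skipped/retry/failed branches in finalize() amount to a
# state -> callback-name dispatch; this table-driven sketch shows the shape. The callback
# names mirror the code above; the state keys and `run_callback` hook are illustrative.
_EXAMPLE_STATE_CALLBACKS = {
    "success": "on_success_callback",
    "skipped": "on_skipped_callback",
    "up_for_retry": "on_retry_callback",
    "failed": "on_failure_callback",
}


def _example_dispatch_state_callback(task, state: str, run_callback) -> None:
    if (callback_name := _EXAMPLE_STATE_CALLBACKS.get(state)) is not None:
        run_callback(task, callback_name)
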

1652def main(): 

1653 log = structlog.get_logger(logger_name="task") 

1654 

1655 global SUPERVISOR_COMMS 

1656 SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) 

1657 

1658 try: 

1659 ti, context, log = startup() 

1660 with BundleVersionLock( 

1661 bundle_name=ti.bundle_instance.name, 

1662 bundle_version=ti.bundle_instance.version, 

1663 ): 

1664 state, _, error = run(ti, context, log) 

1665 context["exception"] = error 

1666 finalize(ti, state, context, log, error) 

1667 except KeyboardInterrupt: 

1668 log.exception("Ctrl-c hit") 

1669 exit(2) 

1670 except Exception: 

1671 log.exception("Top level error") 

1672 exit(1) 

1673 finally: 

1674 # Ensure the request socket is closed on the child side in all circumstances 

1675 # before the process fully terminates. 

1676 if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket: 

1677 with suppress(Exception): 

1678 SUPERVISOR_COMMS.socket.close() 

1679 

1680 
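
# --- Editor's note: a skeleton of main()'s error-handling contract above: exit code 2 for
# Ctrl-C, 1 for any other top-level error, and best-effort cleanup in `finally` guarded by
# contextlib.suppress so a failing close can never mask the real exit path. Names are
# illustrative.
def _example_entrypoint(run_fn, cleanup_fn) -> None:
    from contextlib import suppress

    try:
        run_fn()
    except KeyboardInterrupt:
        raise SystemExit(2)
    except Exception:
        raise SystemExit(1)
    finally:
        with suppress(Exception):
            cleanup_fn()
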

1681def reinit_supervisor_comms() -> None: 

1682 """ 

1683 Re-initialize supervisor comms and the logging channel in a subprocess. 

1684 

1685 This is not needed in most cases, but is used when we re-launch the process via sudo for 

1686 run_as_user, or from inside the Python code of a virtualenv (or similar) operator, to re-connect 

1687 so those tasks can continue to access Variables etc. 

1688 """ 

1689 import socket 

1690 

1691 if "SUPERVISOR_COMMS" not in globals(): 

1692 global SUPERVISOR_COMMS 

1693 log = structlog.get_logger(logger_name="task") 

1694 

1695 fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0")) 

1696 

1697 SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd)) 

1698 

1699 logs = SUPERVISOR_COMMS.send(ResendLoggingFD()) 

1700 if isinstance(logs, SentFDs): 

1701 from airflow.sdk.log import configure_logging 

1702 

1703 log_io = os.fdopen(logs.fds[0], "wb", buffering=0) 

1704 configure_logging(json_output=True, output=log_io, sending_to_supervisor=True) 

1705 else: 

1706 print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr) 

1707 

1708 
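
# --- Editor's note: a minimal sketch of the FD-inheritance handshake used above: the
# parent keeps a socketpair end open across exec (e.g. through sudo), publishes its
# descriptor number in __AIRFLOW_SUPERVISOR_FD (the variable the code above reads), and
# the child adopts the raw descriptor back into a socket object.
def _example_socket_from_inherited_fd():
    import os
    import socket

    fd = int(os.environ["__AIRFLOW_SUPERVISOR_FD"])  # set by the parent before exec
    # fileno=... wraps an existing OS-level descriptor instead of creating a new socket.
    return socket.socket(fileno=fd)
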

1709if __name__ == "__main__": 

1710 main()