Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/execution_time/task_runner.py: 15%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

768 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""The entrypoint for the actual task execution process.""" 

19 

20from __future__ import annotations 

21 

22import contextlib 

23import contextvars 

24import functools 

25import os 

26import sys 

27import time 

28from collections.abc import Callable, Iterable, Iterator, Mapping 

29from contextlib import suppress 

30from datetime import datetime, timezone 

31from itertools import product 

32from pathlib import Path 

33from typing import TYPE_CHECKING, Annotated, Any, Literal 

34from urllib.parse import quote 

35 

36import attrs 

37import lazy_object_proxy 

38import structlog 

39from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter 

40 

41from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock 

42from airflow.dag_processing.bundles.manager import DagBundlesManager 

43from airflow.sdk._shared.observability.metrics.stats import Stats 

44from airflow.sdk.api.client import get_hostname, getuser 

45from airflow.sdk.api.datamodels._generated import ( 

46 AssetProfile, 

47 DagRun, 

48 PreviousTIResponse, 

49 TaskInstance, 

50 TaskInstanceState, 

51 TIRunContext, 

52) 

53from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard 

54from airflow.sdk.bases.xcom import BaseXCom 

55from airflow.sdk.configuration import conf 

56from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager 

57from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set 

58from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef 

59from airflow.sdk.definitions.mappedoperator import MappedOperator 

60from airflow.sdk.definitions.param import process_params 

61from airflow.sdk.exceptions import ( 

62 AirflowException, 

63 AirflowInactiveAssetInInletOrOutletException, 

64 AirflowRuntimeError, 

65 AirflowTaskTimeout, 

66 ErrorType, 

67 TaskDeferred, 

68) 

69from airflow.sdk.execution_time.callback_runner import create_executable_runner 

70from airflow.sdk.execution_time.comms import ( 

71 AssetEventDagRunReferenceResult, 

72 CommsDecoder, 

73 DagRunStateResult, 

74 DeferTask, 

75 DRCount, 

76 ErrorResponse, 

77 GetDagRunState, 

78 GetDRCount, 

79 GetPreviousDagRun, 

80 GetPreviousTI, 

81 GetTaskBreadcrumbs, 

82 GetTaskRescheduleStartDate, 

83 GetTaskStates, 

84 GetTICount, 

85 InactiveAssetsResult, 

86 PreviousDagRunResult, 

87 PreviousTIResult, 

88 RescheduleTask, 

89 ResendLoggingFD, 

90 RetryTask, 

91 SentFDs, 

92 SetRenderedFields, 

93 SetRenderedMapIndex, 

94 SkipDownstreamTasks, 

95 StartupDetails, 

96 SucceedTask, 

97 TaskBreadcrumbsResult, 

98 TaskRescheduleStartDate, 

99 TaskState, 

100 TaskStatesResult, 

101 TICount, 

102 ToSupervisor, 

103 ToTask, 

104 TriggerDagRun, 

105 ValidateInletsAndOutlets, 

106) 

107from airflow.sdk.execution_time.context import ( 

108 ConnectionAccessor, 

109 InletEventsAccessors, 

110 MacrosAccessor, 

111 OutletEventAccessors, 

112 TriggeringAssetEventsAccessor, 

113 VariableAccessor, 

114 context_get_outlet_events, 

115 context_to_airflow_vars, 

116 get_previous_dagrun_success, 

117 set_current_context, 

118) 

119from airflow.sdk.execution_time.sentry import Sentry 

120from airflow.sdk.execution_time.xcom import XCom 

121from airflow.sdk.listener import get_listener_manager 

122from airflow.sdk.timezone import coerce_datetime 

123from airflow.triggers.base import BaseEventTrigger 

124from airflow.triggers.callback import CallbackTrigger 

125 

126if TYPE_CHECKING: 

127 import jinja2 

128 from pendulum.datetime import DateTime 

129 from structlog.typing import FilteringBoundLogger as Logger 

130 

131 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator 

132 from airflow.sdk.definitions.context import Context 

133 from airflow.sdk.exceptions import DagRunTriggerException 

134 from airflow.sdk.types import OutletEventAccessorsProtocol 

135 

136 

class TaskRunnerMarker:
    """Marker for listener hooks, to properly detect from which component they are called."""

139 

140 

# TODO: Move this entire class into a separate file:
# `airflow/sdk/execution_time/task_instance.py`
# or `airflow/sdk/execution_time/runtime_ti.py`
class RuntimeTaskInstance(TaskInstance):
    """In-process representation of the task instance currently being executed."""

    # Operator instances are arbitrary (non-pydantic) types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    task: BaseOperator
    bundle_instance: BaseDagBundle
    _cached_template_context: Context | None = None
    """The Task Instance context. This is used to cache get_template_context."""

    _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
    """The Task Instance context from the API server, if any."""

    max_tries: int = 0
    """The maximum number of retries for the task."""

    start_date: AwareDatetime
    """Start date of the task instance."""

    end_date: AwareDatetime | None = None

    state: TaskInstanceState | None = None

    is_mapped: bool | None = None
    """True if the original task was mapped."""

    rendered_map_index: str | None = None

    sentry_integration: str = ""

171 

172 def __rich_repr__(self): 

173 yield "id", self.id 

174 yield "task_id", self.task_id 

175 yield "dag_id", self.dag_id 

176 yield "run_id", self.run_id 

177 yield "max_tries", self.max_tries 

178 yield "task", type(self.task) 

179 yield "start_date", self.start_date 

180 

181 __rich_repr__.angular = True # type: ignore[attr-defined] 

182 

183 def get_template_context(self) -> Context: 

184 # TODO: Move this to `airflow.sdk.execution_time.context` 

185 # once we port the entire context logic from airflow/utils/context.py ? 

186 from airflow.sdk.plugins_manager import integrate_macros_plugins 

187 

188 integrate_macros_plugins() 

189 

190 dag_run_conf: dict[str, Any] | None = None 

191 if from_server := self._ti_context_from_server: 

192 dag_run_conf = from_server.dag_run.conf or dag_run_conf 

193 

194 validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False) 

195 

196 # Cache the context object, which ensures that all calls to get_template_context 

197 # are operating on the same context object. 

198 self._cached_template_context: Context = self._cached_template_context or { 

199 # From the Task Execution interface 

200 "dag": self.task.dag, 

201 "inlets": self.task.inlets, 

202 "map_index_template": self.task.map_index_template, 

203 "outlets": self.task.outlets, 

204 "run_id": self.run_id, 

205 "task": self.task, 

206 "task_instance": self, 

207 "ti": self, 

208 "outlet_events": OutletEventAccessors(), 

209 "inlet_events": InletEventsAccessors(self.task.inlets), 

210 "macros": MacrosAccessor(), 

211 "params": validated_params, 

212 # TODO: Make this go through Public API longer term. 

213 # "test_mode": task_instance.test_mode, 

214 "var": { 

215 "json": VariableAccessor(deserialize_json=True), 

216 "value": VariableAccessor(deserialize_json=False), 

217 }, 

218 "conn": ConnectionAccessor(), 

219 } 

220 if from_server: 

221 dag_run = from_server.dag_run 

222 context_from_server: Context = { 

223 # TODO: Assess if we need to pass these through timezone.coerce_datetime 

224 "dag_run": dag_run, # type: ignore[typeddict-item] # Removable after #46522 

225 "triggering_asset_events": TriggeringAssetEventsAccessor.build( 

226 AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event) 

227 for event in dag_run.consumed_asset_events 

228 ), 

229 "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}", 

230 "task_reschedule_count": from_server.task_reschedule_count or 0, 

231 "prev_start_date_success": lazy_object_proxy.Proxy( 

232 lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date) 

233 ), 

234 "prev_end_date_success": lazy_object_proxy.Proxy( 

235 lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date) 

236 ), 

237 } 

238 self._cached_template_context.update(context_from_server) 

239 

240 if logical_date := coerce_datetime(dag_run.logical_date): 

241 if TYPE_CHECKING: 

242 assert isinstance(logical_date, DateTime) 

243 ds = logical_date.strftime("%Y-%m-%d") 

244 ds_nodash = ds.replace("-", "") 

245 ts = logical_date.isoformat() 

246 ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S") 

247 ts_nodash_with_tz = ts.replace("-", "").replace(":", "") 

248 # logical_date and data_interval either coexist or be None together 

249 self._cached_template_context.update( 

250 { 

251 # keys that depend on logical_date 

252 "logical_date": logical_date, 

253 "ds": ds, 

254 "ds_nodash": ds_nodash, 

255 "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}", 

256 "ts": ts, 

257 "ts_nodash": ts_nodash, 

258 "ts_nodash_with_tz": ts_nodash_with_tz, 

259 # keys that depend on data_interval 

260 "data_interval_end": coerce_datetime(dag_run.data_interval_end), 

261 "data_interval_start": coerce_datetime(dag_run.data_interval_start), 

262 "prev_data_interval_start_success": lazy_object_proxy.Proxy( 

263 lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start) 

264 ), 

265 "prev_data_interval_end_success": lazy_object_proxy.Proxy( 

266 lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end) 

267 ), 

268 } 

269 ) 

270 

271 if from_server.upstream_map_indexes is not None: 

272 # We stash this in here for later use, but we purposefully don't want to document it's 

273 # existence. Should this be a private attribute on RuntimeTI instead perhaps? 

274 setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes) 

275 

276 return self._cached_template_context 

277 

278 def render_templates( 

279 self, context: Context | None = None, jinja_env: jinja2.Environment | None = None 

280 ) -> BaseOperator: 

281 """ 

282 Render templates in the operator fields. 

283 

284 If the task was originally mapped, this may replace ``self.task`` with 

285 the unmapped, fully rendered BaseOperator. The original ``self.task`` 

286 before replacement is returned. 

287 """ 

288 if not context: 

289 context = self.get_template_context() 

290 original_task = self.task 

291 

292 if TYPE_CHECKING: 

293 assert context 

294 

295 ti = context["ti"] 

296 

297 if TYPE_CHECKING: 

298 assert original_task 

299 assert self.task 

300 assert ti.task 

301 

302 # If self.task is mapped, this call replaces self.task to point to the 

303 # unmapped BaseOperator created by this function! This is because the 

304 # MappedOperator is useless for template rendering, and we need to be 

305 # able to access the unmapped task instead. 

306 self.task.render_template_fields(context, jinja_env) 

307 self.is_mapped = original_task.is_mapped 

308 return original_task 

309 

310 def xcom_pull( 

311 self, 

312 task_ids: str | Iterable[str] | None = None, 

313 dag_id: str | None = None, 

314 key: str = BaseXCom.XCOM_RETURN_KEY, 

315 include_prior_dates: bool = False, 

316 *, 

317 map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET, 

318 default: Any = None, 

319 run_id: str | None = None, 

320 ) -> Any: 

321 """ 

322 Pull XComs either from the API server (BaseXCom) or from the custom XCOM backend if configured. 

323 

324 The pull can be filtered optionally by certain criterion. 

325 

326 :param key: A key for the XCom. If provided, only XComs with matching 

327 keys will be returned. The default key is ``'return_value'``, also 

328 available as constant ``XCOM_RETURN_KEY``. This key is automatically 

329 given to XComs returned by tasks (as opposed to being pushed 

330 manually). 

331 :param task_ids: Only XComs from tasks with matching ids will be 

332 pulled. If *None* (default), the task_id of the calling task is used. 

333 :param dag_id: If provided, only pulls XComs from this Dag. If *None* 

334 (default), the Dag of the calling task is used. 

335 :param map_indexes: If provided, only pull XComs with matching indexes. 

336 If *None* (default), this is inferred from the task(s) being pulled 

337 (see below for details). 

338 :param include_prior_dates: If False, only XComs from the current 

339 logical_date are returned. If *True*, XComs from previous dates 

340 are returned as well. 

341 :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id. 

342 If *None* (default), the run_id of the calling task is used. 

343 

344 When pulling one single task (``task_id`` is *None* or a str) without 

345 specifying ``map_indexes``, the return value is a single XCom entry 

346 (map_indexes is set to map_index of the calling task instance). 

347 

348 When pulling task is mapped the specified ``map_index`` is used, so by default 

349 pulling on mapped task will result in no matching XComs if the task instance 

350 of the method call is not mapped. Otherwise, the map_index of the calling task 

351 instance is used. Setting ``map_indexes`` to *None* will pull XCom as it would 

352 from a non mapped task. 

353 

354 In either case, ``default`` (*None* if not specified) is returned if no 

355 matching XComs are found. 

356 

357 When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is 

358 a non-str iterable), a list of matching XComs is returned. Elements in 

359 the list is ordered by item ordering in ``task_id`` and ``map_index``. 

360 """ 

361 if dag_id is None: 

362 dag_id = self.dag_id 

363 if run_id is None: 

364 run_id = self.run_id 

365 

366 single_task_requested = isinstance(task_ids, (str, type(None))) 

367 single_map_index_requested = isinstance(map_indexes, (int, type(None))) 

368 

369 if task_ids is None: 

370 # default to the current task if not provided 

371 task_ids = [self.task_id] 

372 elif isinstance(task_ids, str): 

373 task_ids = [task_ids] 

374 

375 # If map_indexes is not specified, pull xcoms from all map indexes for each task 

376 if not is_arg_set(map_indexes): 

377 xcoms: list[Any] = [] 

378 for t_id in task_ids: 

379 values = XCom.get_all( 

380 run_id=run_id, 

381 key=key, 

382 task_id=t_id, 

383 dag_id=dag_id, 

384 include_prior_dates=include_prior_dates, 

385 ) 

386 

387 if values is None: 

388 xcoms.append(None) 

389 else: 

390 xcoms.extend(values) 

391 # For single task pulling from unmapped task, return single value 

392 if single_task_requested and len(xcoms) == 1: 

393 return xcoms[0] 

394 return xcoms 

395 

396 # Original logic when map_indexes is explicitly specified 

397 map_indexes_iterable: Iterable[int | None] = [] 

398 if isinstance(map_indexes, int) or map_indexes is None: 

399 map_indexes_iterable = [map_indexes] 

400 elif isinstance(map_indexes, Iterable): 

401 map_indexes_iterable = map_indexes 

402 else: 

403 raise TypeError( 

404 f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}" 

405 ) 

406 

407 xcoms = [] 

408 for t_id, m_idx in product(task_ids, map_indexes_iterable): 

409 value = XCom.get_one( 

410 run_id=run_id, 

411 key=key, 

412 task_id=t_id, 

413 dag_id=dag_id, 

414 map_index=m_idx, 

415 include_prior_dates=include_prior_dates, 

416 ) 

417 if value is None: 

418 xcoms.append(default) 

419 else: 

420 xcoms.append(value) 

421 

422 if single_task_requested and single_map_index_requested: 

423 return xcoms[0] 

424 

425 return xcoms 

426 

427 def xcom_push(self, key: str, value: Any): 

428 """ 

429 Make an XCom available for tasks to pull. 

430 

431 :param key: Key to store the value under. 

432 :param value: Value to store. Only be JSON-serializable values may be used. 

433 """ 

434 _xcom_push(self, key, value) 

435 

436 def get_relevant_upstream_map_indexes( 

437 self, upstream: BaseOperator, ti_count: int | None, session: Any 

438 ) -> int | range | None: 

439 # TODO: Implement this method 

440 return None 

441 

442 def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: 

443 """Get the first reschedule date for the task instance if found, none otherwise.""" 

444 if context.get("task_reschedule_count", 0) == 0: 

445 # If the task has not been rescheduled, there is no need to ask the supervisor 

446 return None 

447 

448 max_tries: int = self.max_tries 

449 retries: int = self.task.retries or 0 

450 first_try_number = max_tries - retries + 1 

451 

452 log = structlog.get_logger(logger_name="task") 

453 

454 log.debug("Requesting first reschedule date from supervisor") 

455 

456 response = SUPERVISOR_COMMS.send( 

457 msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) 

458 ) 

459 

460 if TYPE_CHECKING: 

461 assert isinstance(response, TaskRescheduleStartDate) 

462 

463 return response.start_date 

464 

465 def get_previous_dagrun(self, state: str | None = None) -> DagRun | None: 

466 """Return the previous Dag run before the given logical date, optionally filtered by state.""" 

467 context = self.get_template_context() 

468 dag_run = context.get("dag_run") 

469 

470 log = structlog.get_logger(logger_name="task") 

471 

472 log.debug("Getting previous Dag run", dag_run=dag_run) 

473 

474 if dag_run is None: 

475 return None 

476 

477 if dag_run.logical_date is None: 

478 return None 

479 

480 response = SUPERVISOR_COMMS.send( 

481 msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state) 

482 ) 

483 

484 if TYPE_CHECKING: 

485 assert isinstance(response, PreviousDagRunResult) 

486 

487 return response.dag_run 

488 

489 def get_previous_ti( 

490 self, 

491 state: TaskInstanceState | None = None, 

492 logical_date: AwareDatetime | None = None, 

493 map_index: int = -1, 

494 ) -> PreviousTIResponse | None: 

495 """ 

496 Return the previous task instance matching the given criteria. 

497 

498 :param state: Filter by TaskInstance state 

499 :param logical_date: Filter by logical date (returns TI before this date) 

500 :param map_index: Filter by map_index (defaults to -1 for non-mapped tasks) 

501 :return: Previous task instance or None if not found 

502 """ 

503 context = self.get_template_context() 

504 dag_run = context.get("dag_run") 

505 

506 log = structlog.get_logger(logger_name="task") 

507 log.debug("Getting previous task instance", task_id=self.task_id, state=state) 

508 

509 # Use current dag run's logical_date if not provided 

510 effective_logical_date = logical_date 

511 if effective_logical_date is None and dag_run and dag_run.logical_date: 

512 effective_logical_date = dag_run.logical_date 

513 

514 response = SUPERVISOR_COMMS.send( 

515 msg=GetPreviousTI( 

516 dag_id=self.dag_id, 

517 task_id=self.task_id, 

518 logical_date=effective_logical_date, 

519 map_index=map_index, 

520 state=state, 

521 ) 

522 ) 

523 

524 if TYPE_CHECKING: 

525 assert isinstance(response, PreviousTIResult) 

526 

527 return response.task_instance 

528 

529 @staticmethod 

530 def get_ti_count( 

531 dag_id: str, 

532 map_index: int | None = None, 

533 task_ids: list[str] | None = None, 

534 task_group_id: str | None = None, 

535 logical_dates: list[datetime] | None = None, 

536 run_ids: list[str] | None = None, 

537 states: list[str] | None = None, 

538 ) -> int: 

539 """Return the number of task instances matching the given criteria.""" 

540 response = SUPERVISOR_COMMS.send( 

541 GetTICount( 

542 dag_id=dag_id, 

543 map_index=map_index, 

544 task_ids=task_ids, 

545 task_group_id=task_group_id, 

546 logical_dates=logical_dates, 

547 run_ids=run_ids, 

548 states=states, 

549 ), 

550 ) 

551 

552 if TYPE_CHECKING: 

553 assert isinstance(response, TICount) 

554 

555 return response.count 

556 

557 @staticmethod 

558 def get_task_states( 

559 dag_id: str, 

560 map_index: int | None = None, 

561 task_ids: list[str] | None = None, 

562 task_group_id: str | None = None, 

563 logical_dates: list[datetime] | None = None, 

564 run_ids: list[str] | None = None, 

565 ) -> dict[str, Any]: 

566 """Return the task states matching the given criteria.""" 

567 response = SUPERVISOR_COMMS.send( 

568 GetTaskStates( 

569 dag_id=dag_id, 

570 map_index=map_index, 

571 task_ids=task_ids, 

572 task_group_id=task_group_id, 

573 logical_dates=logical_dates, 

574 run_ids=run_ids, 

575 ), 

576 ) 

577 

578 if TYPE_CHECKING: 

579 assert isinstance(response, TaskStatesResult) 

580 

581 return response.task_states 

582 

583 @staticmethod 

584 def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]: 

585 """Return task breadcrumbs for the given dag run.""" 

586 response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id)) 

587 if TYPE_CHECKING: 

588 assert isinstance(response, TaskBreadcrumbsResult) 

589 return response.breadcrumbs 

590 

591 @staticmethod 

592 def get_dr_count( 

593 dag_id: str, 

594 logical_dates: list[datetime] | None = None, 

595 run_ids: list[str] | None = None, 

596 states: list[str] | None = None, 

597 ) -> int: 

598 """Return the number of Dag runs matching the given criteria.""" 

599 response = SUPERVISOR_COMMS.send( 

600 GetDRCount( 

601 dag_id=dag_id, 

602 logical_dates=logical_dates, 

603 run_ids=run_ids, 

604 states=states, 

605 ), 

606 ) 

607 

608 if TYPE_CHECKING: 

609 assert isinstance(response, DRCount) 

610 

611 return response.count 

612 

613 @staticmethod 

614 def get_dagrun_state(dag_id: str, run_id: str) -> str: 

615 """Return the state of the Dag run with the given Run ID.""" 

616 response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) 

617 

618 if TYPE_CHECKING: 

619 assert isinstance(response, DagRunStateResult) 

620 

621 return response.state 

622 

623 @property 

624 def log_url(self) -> str: 

625 run_id = quote(self.run_id) 

626 base_url = conf.get("api", "base_url", fallback="http://localhost:8080/") 

627 map_index_value = self.map_index 

628 map_index = ( 

629 f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else "" 

630 ) 

631 try_number_value = self.try_number 

632 try_number = ( 

633 f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else "" 

634 ) 

635 _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}" 

636 return _log_uri 

637 

638 @property 

639 def mark_success_url(self) -> str: 

640 """URL to mark TI success.""" 

641 return self.log_url 

642 

643 

def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
    """Push a XCom through XCom.set, which pushes to XCom Backend if configured."""
    # Private function, as we don't want to expose the ability to manually set
    # `mapped_length` to SDK consumers.
    XCom.set(
        key=key,
        value=value,
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=ti.map_index,
        _mapped_length=mapped_length,
    )

658 

659 

def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None:
    """Push a XCom directly to metadata DB, bypassing custom xcom_backend."""
    XCom._set_xcom_in_db(
        key=key,
        value=value,
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=ti.map_index,
    )

670 

671 

def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
    """Locate the dag and task named in the startup message and build the runtime TI."""
    # TODO: Task-SDK:
    # Using BundleDagBag here is about 98% wrong, but it'll do for now
    from airflow.dag_processing.dagbag import BundleDagBag

    bundle_info = what.bundle_info
    bundle_instance = DagBundlesManager().get_bundle(
        name=bundle_info.name,
        version=bundle_info.version,
    )
    bundle_instance.initialize()
    # After impersonation the bundle path may not be readable anymore; fail early.
    _verify_bundle_access(bundle_instance, log)

    dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
    bag = BundleDagBag(
        dag_folder=dag_absolute_path,
        safe_mode=False,
        load_op_links=False,
        bundle_path=bundle_instance.path,
        bundle_name=bundle_info.name,
    )

    if TYPE_CHECKING:
        assert what.ti.dag_id

    try:
        dag = bag.dags[what.ti.dag_id]
    except KeyError:
        log.error(
            "Dag not found during start up", dag_id=what.ti.dag_id, bundle=bundle_info, path=what.dag_rel_path
        )
        sys.exit(1)

    # install_loader()

    try:
        task = dag.task_dict[what.ti.task_id]
    except KeyError:
        log.error(
            "Task not found in Dag during start up",
            dag_id=dag.dag_id,
            task_id=what.ti.task_id,
            bundle=bundle_info,
            path=what.dag_rel_path,
        )
        sys.exit(1)

    if not isinstance(task, (BaseOperator, MappedOperator)):
        raise TypeError(
            f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
        )

    # model_construct skips validation: the server payload is already validated.
    return RuntimeTaskInstance.model_construct(
        **what.ti.model_dump(exclude_unset=True),
        task=task,
        bundle_instance=bundle_instance,
        _ti_context_from_server=what.ti_context,
        max_tries=what.ti_context.max_tries,
        start_date=what.start_date,
        state=TaskInstanceState.RUNNING,
        sentry_integration=what.sentry_integration,
    )

733 

734 

# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's
# execution, to send requests back to the supervisor process.
#
# It must be a global: many parts of Airflow's codebase (e.g. connections, variables, XComs) make
# dynamic requests to the parent process from deep inside the execution stack, and threading a
# `CommsDecoder` instance through every layer of those calls is not practical. A module-level global
# keeps the communication channel reachable wherever it is needed during task execution.
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]


# State machine!
# 1. Start up (receive details from supervisor)
# 2. Execution (run task code, possibly send requests)
# 3. Shutdown and report status

752 

753 

def _verify_bundle_access(bundle_instance: BaseDagBundle, log: Logger) -> None:
    """
    Verify bundle is accessible by the current user.

    This is called after user impersonation (if any) to ensure the bundle
    is actually accessible. Uses os.access() which works with any permission
    scheme (standard Unix permissions, ACLs, SELinux, etc.).

    :param bundle_instance: The bundle instance to check
    :param log: Logger instance
    :raises AirflowException: if bundle is not accessible
    """
    from getpass import getuser

    from airflow.exceptions import AirflowException

    bundle_path = bundle_instance.path

    if not bundle_path.exists():
        # Already handled by initialize() with a warning
        return

    # Need read access; directories additionally need execute to list contents.
    required = os.R_OK | (os.X_OK if bundle_path.is_dir() else 0)

    if not os.access(bundle_path, required):
        raise AirflowException(
            f"Bundle '{bundle_instance.name}' path '{bundle_path}' is not accessible "
            f"by user '{getuser()}'. When using run_as_user, ensure bundle directories "
            f"are readable by the impersonated user. "
            f"See: https://airflow.apache.org/docs/apache-airflow/stable/administration-and-deployment/dag-bundles.html"
        )

788 

789 

def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
    """Receive (or reload) the startup message, parse the dag, and return the TI, its context, and a logger."""
    # The parent sends us a StartupDetails message un-prompted. After this, every single message is only
    # sent in response to us sending a request.
    log = structlog.get_logger(logger_name="task")

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and (
        msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG")
    ):
        # Entrypoint of a re-exec'd (impersonated) process.
        # Clear any Kerberos replace cache if there is one, so new process can't reuse it.
        os.environ.pop("KRB5CCNAME", None)

        msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson)
        reinit_supervisor_comms()

        # We delay this message until _after_ we've got the logging re-configured, otherwise it will
        # show up on stdout
        log.debug("Using serialized startup message from environment", msg=msg)
    else:
        # Normal entry point: first message arrives over the supervisor socket.
        msg = SUPERVISOR_COMMS._get_response()  # type: ignore[assignment]
        if not isinstance(msg, StartupDetails):
            raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

    # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
    if sys.platform == "darwin":
        log.debug("Mac OS detected, skipping setproctitle")
    else:
        from setproctitle import setproctitle

        setproctitle(f"airflow worker -- {msg.ti.id}")

    try:
        get_listener_manager().hook.on_starting(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")

    with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
        ti = parse(msg, log)
        log.debug("Dag file parsed", file=msg.dag_rel_path)

    run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
        "core", "default_impersonation", fallback=None
    )

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
        # Impersonation requested: re-exec ourselves under sudo as the target user.
        os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
        # store startup message in environment for re-exec process
        os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
        os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)

        # Import main directly from the module instead of re-executing the file.
        # This ensures that when other parts modules import
        # airflow.sdk.execution_time.task_runner, they get the same module instance
        # with the properly initialized SUPERVISOR_COMMS global variable.
        # If we re-executed the module with `python -m`, it would load as __main__ and future
        # imports would get a fresh copy without the initialized globals.
        rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
        cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
        log.info(
            "Running command",
            command=cmd,
        )
        os.execvp("sudo", cmd)

        # ideally, we should never reach here, but if we do, we should return None, None, None
        return None, None, None

    return ti, ti.get_template_context(), log

862 

863 

def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float:
    """
    Return a serializable representation of the templated field.

    If ``templated_field`` contains a class or instance that requires recursive
    templating, store them as strings. Otherwise simply return the field as-is.

    Used sdk secrets masker to redact secrets in the serialized output.
    """
    import json

    from airflow.sdk._shared.secrets_masker import redact

    max_length = conf.getint("core", "max_templated_field_length")

    def _json_safe(value: Any) -> bool:
        # A value is "JSON safe" when the stdlib encoder accepts it as-is.
        try:
            json.dumps(value)
        except (TypeError, OverflowError):
            return False
        return True

    def _tuples_to_lists(value: Any):
        """Recursively replace tuples with lists (lists and dicts are walked too)."""
        if isinstance(value, (tuple, list)):
            return [_tuples_to_lists(item) for item in value]
        if isinstance(value, dict):
            return {key: _tuples_to_lists(inner) for key, inner in value.items()}
        return value

    def _sorted_deep(value: Any) -> Any:
        """Recursively sort dict keys so the string form is deterministic."""
        if isinstance(value, dict):
            return {key: _sorted_deep(inner) for key, inner in sorted(value.items())}
        if isinstance(value, list):
            return [_sorted_deep(item) for item in value]
        if isinstance(value, tuple):
            return tuple(_sorted_deep(item) for item in value)
        return value

    def _truncate(text: str) -> str:
        # Redact before truncating so no secret fragment survives in the preview.
        redacted = redact(text, name)
        return (
            "Truncated. You can change this behaviour in [core]max_templated_field_length. "
            f"{redacted[: max_length - 79]!r}... "
        )

    if not _json_safe(template_field):
        # Non-JSON values: prefer the object's own serialize(), fall back to str().
        try:
            serialized = template_field.serialize()
        except AttributeError:
            serialized = str(template_field)
        if len(serialized) > max_length:
            return _truncate(serialized)
        return serialized

    if not template_field and not isinstance(template_field, tuple):
        # Avoid unnecessary serialization steps for empty fields unless they are tuples
        # and need to be converted to lists
        return template_field

    template_field = _tuples_to_lists(template_field)
    # Sort dictionaries recursively so the string representation (and any hash of
    # it) does not depend on insertion order.
    if isinstance(template_field, dict):
        template_field = _sorted_deep(template_field)
    rendered_text = str(template_field)
    if len(rendered_text) > max_length:
        return _truncate(rendered_text)
    return template_field

936 

937 

def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]:
    """Serialize and redact every templated field of *task* for display."""
    from airflow.sdk._shared.secrets_masker import redact

    # Redact secrets in the task process itself before sending to the API server.
    # This ensures that secrets registered via mask_secret() on workers / dag
    # processor are properly masked on the UI.
    return {  # type: ignore[return-value] # Convince mypy that this is OK since we pass JsonValue to redact, so it will return the same
        field: redact(_serialize_template_field(getattr(task, field), field), field)
        for field in task.template_fields
    }

951 

952 

def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]:
    """Yield an ``AssetProfile`` for each asset-like entry in *lineage_objects*."""
    # Lineage can contain objects other than assets; anything unrecognised is
    # silently dropped. The isinstance order matters and is kept as-is.
    for entry in lineage_objects or ():
        if isinstance(entry, Asset):
            yield AssetProfile(name=entry.name, uri=entry.uri, type=Asset.__name__)
        elif isinstance(entry, AssetNameRef):
            yield AssetProfile(name=entry.name, type=AssetNameRef.__name__)
        elif isinstance(entry, AssetUriRef):
            yield AssetProfile(uri=entry.uri, type=AssetUriRef.__name__)
        elif isinstance(entry, AssetAlias):
            yield AssetProfile(name=entry.name, type=AssetAlias.__name__)

964 

965 

def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]:
    """Flatten everything recorded in the outlet event accessors into plain dicts."""
    if TYPE_CHECKING:
        assert isinstance(events, OutletEventAccessors)
    # We just collect everything the user recorded in the accessors;
    # any further filtering happens on the API server.
    for unique_key, accessor in events._dict.items():
        if isinstance(unique_key, AssetUniqueKey):
            yield {"dest_asset_key": attrs.asdict(unique_key), "extra": accessor.extra}
        yield from map(attrs.asdict, accessor.asset_alias_events)

976 

977 

def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None:
    """
    Get the task ready to execute.

    Resolves templates, publishes rendered fields, validates inlet/outlet
    assets and notifies listeners. Returns ``None`` on success (a non-None
    return would signal an early failure to the caller).
    """
    ti.hostname = get_hostname()
    ti.task = ti.task.prepare_for_execution()
    # The context dict is cached, so the dict handed out by `ti.get_template_context`
    # must be patched with the prepared task object as well.
    context["task"] = ti.task

    template_env = ti.task.dag.get_template_env()
    ti.render_templates(context=context, jinja_env=template_env)

    # Only talk to the API server when there is actually something to publish.
    if fields := _serialize_rendered_fields(ti.task):
        SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=fields))

    # Best-effort early rendering of map_index_template with pre-execution context:
    # it provides a partial label while the task runs and is re-rendered after
    # execution anyway, so any failure here is logged and suppressed.
    try:
        if early_map_index := _render_map_index(context, ti=ti, log=log):
            ti.rendered_map_index = early_map_index
            log.debug("Sending early rendered map index", length=len(early_map_index))
            SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=early_map_index))
    except Exception:
        log.debug(
            "Early rendering of map_index_template failed, will retry after task execution", exc_info=True
        )

    _validate_task_inlets_and_outlets(ti=ti, log=log)

    try:
        # TODO: Call pre execute etc.
        get_listener_manager().hook.on_task_instance_running(
            previous_state=TaskInstanceState.QUEUED, task_instance=ti
        )
    except Exception:
        log.exception("error calling listener")

    # No error, carry on and execute the task
    return None

1017 

1018 

def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
    """Ask the supervisor whether any inlet/outlet asset is inactive and raise if so."""
    if not (ti.task.inlets or ti.task.outlets):
        # Nothing to validate; skip the round-trip to the supervisor.
        return

    response = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id))
    if TYPE_CHECKING:
        assert isinstance(response, InactiveAssetsResult)
    if response.inactive_assets:
        raise AirflowInactiveAssetInInletOrOutletException(
            inactive_asset_keys=[
                AssetUniqueKey.from_profile(profile) for profile in response.inactive_assets
            ]
        )

1032 

1033 

def _defer_task(
    defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """Convert a ``TaskDeferred`` exception into a ``DeferTask`` message plus the DEFERRED state."""
    # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?

    log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
    classpath, raw_trigger_kwargs = defer.trigger.serialize()

    # Currently, only task-associated BaseTrigger instances may have a non-None queue,
    # and only when triggerer.queues_enabled is True.
    trigger_queue: str | None = None
    if not isinstance(defer.trigger, (BaseEventTrigger, CallbackTrigger)) and conf.getboolean(
        "triggerer", "queues_enabled", fallback=False
    ):
        trigger_queue = ti.task.queue

    from airflow.sdk.serde import serialize as serde_serialize

    serialized_trigger_kwargs = serde_serialize(raw_trigger_kwargs)
    serialized_next_kwargs = serde_serialize(defer.kwargs or {})

    if TYPE_CHECKING:
        assert isinstance(serialized_next_kwargs, dict)
        assert isinstance(serialized_trigger_kwargs, dict)

    return (
        DeferTask(
            classpath=classpath,
            trigger_kwargs=serialized_trigger_kwargs,
            trigger_timeout=defer.timeout,
            queue=trigger_queue,
            next_method=defer.method_name,
            next_kwargs=serialized_next_kwargs,
        ),
        TaskInstanceState.DEFERRED,
    )

1069 

1070 

@Sentry.enrich_errors
def run(
    ti: RuntimeTaskInstance,
    context: Context,
    log: Logger,
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
    """
    Run the task in this process.

    Executes the task body and maps every way it can end (success, skipped,
    rescheduled, deferred, dag-run trigger, retry-eligible failure, terminal
    failure) onto a ``TaskInstanceState`` plus the message reported to the
    supervisor. The ``finally`` block always flushes that message.

    :param ti: Runtime task instance to execute.
    :param context: The (cached) template context for this task instance.
    :param log: Task logger.
    :return: Tuple of (final state, message sent to the supervisor or ``None``,
        exception that caused the failure or ``None``).
    """
    import signal

    from airflow.sdk.exceptions import (
        AirflowFailException,
        AirflowRescheduleException,
        AirflowSensorTimeout,
        AirflowSkipException,
        AirflowTaskTerminated,
        DagRunTriggerException,
        DownstreamTasksSkipped,
        TaskDeferred,
    )

    if TYPE_CHECKING:
        assert ti.task is not None
        assert isinstance(ti.task, BaseOperator)

    parent_pid = os.getpid()

    def _on_term(signum, frame):
        # Forked children inherit this handler; the pid guard ensures
        # on_kill only ever runs in the original task process.
        pid = os.getpid()
        if pid != parent_pid:
            return

        ti.task.on_kill()

    signal.signal(signal.SIGTERM, _on_term)

    msg: ToSupervisor | None = None
    state: TaskInstanceState
    error: BaseException | None = None

    try:
        # First, clear the xcom data sent from server
        if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear):
            for x in keys_to_delete:
                log.debug("Clearing XCom with key", key=x)
                XCom.delete(
                    key=x,
                    dag_id=ti.dag_id,
                    task_id=ti.task_id,
                    run_id=ti.run_id,
                    map_index=ti.map_index,
                )

        with set_current_context(context):
            # This is the earliest that we can render templates -- as if it excepts for any reason we need to
            # catch it and handle it like a normal task failure
            if early_exit := _prepare(ti, log, context):
                msg = early_exit
                ti.state = state = TaskInstanceState.FAILED
                return state, msg, error

            try:
                result = _execute_task(context=context, ti=ti, log=log)
            except Exception:
                import jinja2

                # If the task failed, swallow rendering error so it doesn't mask the main error.
                with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
                    previous_rendered_map_index = ti.rendered_map_index
                    ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                    # Send update only if value changed (e.g., user set context variables during execution)
                    if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                        SUPERVISOR_COMMS.send(
                            msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)
                        )
                raise
            else:  # If the task succeeded, render normally to let rendering error bubble up.
                previous_rendered_map_index = ti.rendered_map_index
                ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                # Send update only if value changed (e.g., user set context variables during execution)
                if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                    SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index))

            _push_xcom_if_needed(result, ti, log)

            msg, state = _handle_current_task_success(context, ti)
    except DownstreamTasksSkipped as skip:
        log.info("Skipping downstream tasks.")
        tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks]
        SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip))
        # The current task itself still counts as a success.
        msg, state = _handle_current_task_success(context, ti)
    except DagRunTriggerException as drte:
        msg, state = _handle_trigger_dag_run(drte, context, ti, log)
    except TaskDeferred as defer:
        msg, state = _defer_task(defer, ti, log)
    except AirflowSkipException as e:
        if e.args:
            log.info("Skipping task.", reason=e.args[0])
        msg = TaskState(
            state=TaskInstanceState.SKIPPED,
            end_date=datetime.now(tz=timezone.utc),
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.SKIPPED
    except AirflowRescheduleException as reschedule:
        log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
        msg = RescheduleTask(
            reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
        )
        state = TaskInstanceState.UP_FOR_RESCHEDULE
    except (AirflowFailException, AirflowSensorTimeout) as e:
        # If AirflowFailException is raised, task should not retry.
        # If a sensor in reschedule mode reaches timeout, task should not retry.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
        # We should allow retries if the task has defined it.
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    except AirflowTaskTerminated as e:
        # External state updates are already handled with `ti_heartbeat` and will be
        # updated already be another UI API. So, these exceptions should ideally never be thrown.
        # If these are thrown, we should mark the TI state as failed.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except SystemExit as e:
        # SystemExit needs to be retried if they are eligible.
        log.error("Task exited", exit_code=e.code)
        msg, state = _handle_current_task_failed(ti)
        error = e
    except BaseException as e:
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    finally:
        # Always flush the final status message to the supervisor, whichever
        # branch produced it.
        if msg:
            SUPERVISOR_COMMS.send(msg=msg)

    # Return the message to make unit tests easier too
    ti.state = state
    return state, msg, error

1226 

1227 

def _handle_current_task_success(
    context: Context,
    ti: RuntimeTaskInstance,
) -> tuple[SucceedTask, TaskInstanceState]:
    """Record success metrics and build the SucceedTask message for the supervisor."""
    finished_at = datetime.now(tz=timezone.utc)
    ti.end_date = finished_at

    # Record operator and task instance success metrics
    operator_name = ti.task.__class__.__name__
    tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

    Stats.incr(f"operator_successes_{operator_name}", tags=tags)
    # Same metric with tagging
    Stats.incr("operator_successes", tags={**tags, "operator": operator_name})
    Stats.incr("ti_successes", tags=tags)

    success_msg = SucceedTask(
        end_date=finished_at,
        task_outlets=list(_build_asset_profiles(ti.task.outlets)),
        outlet_events=list(_serialize_outlet_events(context["outlet_events"])),
        rendered_map_index=ti.rendered_map_index,
    )
    return success_msg, TaskInstanceState.SUCCESS

1253 

1254 

def _handle_current_task_failed(
    ti: RuntimeTaskInstance,
) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
    """Record failure metrics and decide between UP_FOR_RETRY and a terminal FAILED state."""
    finished_at = datetime.now(tz=timezone.utc)
    ti.end_date = finished_at

    # Record operator and task instance failed metrics
    operator_name = ti.task.__class__.__name__
    tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

    Stats.incr(f"operator_failures_{operator_name}", tags=tags)
    # Same metric with tagging
    Stats.incr("operator_failures", tags={**tags, "operator": operator_name})
    Stats.incr("ti_failures", tags=tags)

    # The server decides retry eligibility; without that context, fail terminally.
    server_context = ti._ti_context_from_server
    if server_context and server_context.should_retry:
        return RetryTask(end_date=finished_at), TaskInstanceState.UP_FOR_RETRY

    failed_msg = TaskState(
        state=TaskInstanceState.FAILED, end_date=finished_at, rendered_map_index=ti.rendered_map_index
    )
    return failed_msg, TaskInstanceState.FAILED

1278 

1279 

def _handle_trigger_dag_run(
    drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """
    Handle exception from TriggerDagRunOperator.

    Asks the supervisor to trigger the target dag run, then — depending on the
    exception's flags — skips or fails on an already-existing run, defers or
    synchronously polls until the triggered run reaches a terminal state, or
    returns success immediately (fire-and-forget).

    :return: The message for the supervisor and the resulting task state.
    """
    log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
    comms_msg = SUPERVISOR_COMMS.send(
        TriggerDagRun(
            dag_id=drte.trigger_dag_id,
            run_id=drte.dag_run_id,
            logical_date=drte.logical_date,
            conf=drte.conf,
            reset_dag_run=drte.reset_dag_run,
        ),
    )

    if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
        if drte.skip_when_already_exists:
            log.info(
                "Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
                dag_id=drte.trigger_dag_id,
            )
            msg = TaskState(
                state=TaskInstanceState.SKIPPED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.SKIPPED
        else:
            log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
            msg = TaskState(
                state=TaskInstanceState.FAILED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.FAILED

        return msg, state

    log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)

    # Store the run id from the dag run (either created or found above) to
    # be used when creating the extra link on the webserver.
    ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)

    if drte.wait_for_completion:
        if drte.deferrable:
            from airflow.providers.standard.triggers.external_task import DagStateTrigger

            defer = TaskDeferred(
                trigger=DagStateTrigger(
                    dag_id=drte.trigger_dag_id,
                    states=drte.allowed_states + drte.failed_states,  # type: ignore[arg-type]
                    # Don't filter by execution_dates when run_ids is provided.
                    # run_id uniquely identifies a DAG run, and when reset_dag_run=True,
                    # drte.logical_date might be a newly calculated value that doesn't match
                    # the persisted logical_date in the database, causing the trigger to never find the run.
                    execution_dates=None,
                    run_ids=[drte.dag_run_id],
                    poll_interval=drte.poke_interval,
                ),
                method_name="execute_complete",
            )
            return _defer_task(defer, ti, log)
        # Synchronous wait: poll the triggered run's state via the supervisor
        # every poke_interval seconds until it hits an allowed or failed state.
        while True:
            log.info(
                "Waiting for dag run to complete execution in allowed state.",
                dag_id=drte.trigger_dag_id,
                run_id=drte.dag_run_id,
                allowed_state=drte.allowed_states,
            )
            time.sleep(drte.poke_interval)

            comms_msg = SUPERVISOR_COMMS.send(
                GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
            )
            if TYPE_CHECKING:
                assert isinstance(comms_msg, DagRunStateResult)
            if comms_msg.state in drte.failed_states:
                log.error(
                    "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                msg = TaskState(
                    state=TaskInstanceState.FAILED,
                    end_date=datetime.now(tz=timezone.utc),
                    rendered_map_index=ti.rendered_map_index,
                )
                state = TaskInstanceState.FAILED
                return msg, state
            if comms_msg.state in drte.allowed_states:
                log.info(
                    "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                break
            log.debug(
                "DagRun not yet in allowed or failed state.",
                dag_id=drte.trigger_dag_id,
                state=comms_msg.state,
            )
    else:
        # Fire-and-forget mode: wait_for_completion=False
        if drte.deferrable:
            log.info(
                "Ignoring deferrable=True because wait_for_completion=False. "
                "Task will complete immediately without waiting for the triggered DAG run.",
                trigger_dag_id=drte.trigger_dag_id,
            )

    # Either the triggered run finished in an allowed state, or we did not wait.
    return _handle_current_task_success(context, ti)

1388 

1389 

def _run_task_state_change_callbacks(
    task: BaseOperator,
    kind: Literal[
        "on_execute_callback",
        "on_failure_callback",
        "on_success_callback",
        "on_retry_callback",
        "on_skipped_callback",
    ],
    context: Context,
    log: Logger,
) -> None:
    """Invoke every callback registered under *kind* on *task*; failures are logged, never raised."""
    cb: Callable[[Context], None]
    for index, cb in enumerate(getattr(task, kind)):
        try:
            runner = create_executable_runner(cb, context_get_outlet_events(context), logger=log)
            runner.run(context)
        except Exception:
            log.exception("Failed to run task callback", kind=kind, index=index, callback=cb)

1408 

1409 

def _send_error_email_notification(
    task: BaseOperator | MappedOperator,
    ti: RuntimeTaskInstance,
    context: Context,
    error: BaseException | str | None,
    log: Logger,
) -> None:
    """
    Send email notification for task errors using SmtpNotifier.

    Best-effort: a missing SMTP provider, absent recipients, or a send error is
    logged and swallowed so notification problems never fail the task process.

    :param task: Task whose ``email`` attribute lists the recipients.
    :param ti: Runtime task instance, used for template context values.
    :param context: Task execution context, extended with exception details.
    :param error: The error being reported (rendered into the email body).
    :param log: Logger for reporting notification problems.
    """
    try:
        from airflow.providers.smtp.notifications.smtp import SmtpNotifier
    except ImportError:
        log.error(
            "Failed to send task failure or retry email notification: "
            "`apache-airflow-providers-smtp` is not installed. "
            "Install this provider to enable email notifications."
        )
        return

    # Single recipient check; the previous second check further down was
    # unreachable because task.email cannot change between the two.
    to_emails = task.email
    if not to_emails:
        return

    subject_template_file = conf.get("email", "subject_template", fallback=None)

    # Read the subject template file if configured, else fall back to the default.
    if subject_template_file and Path(subject_template_file).exists():
        subject = Path(subject_template_file).read_text()
    else:
        subject = "Airflow alert: {{ti}}"

    html_content_template_file = conf.get("email", "html_content_template", fallback=None)

    # Read the body template file if configured, else fall back to the default.
    if html_content_template_file and Path(html_content_template_file).exists():
        html_content = Path(html_content_template_file).read_text()
    else:
        # For reporting purposes, we report based on 1-indexed,
        # not 0-indexed lists (i.e. Try 1 instead of Try 0 for the first attempt).
        html_content = (
            "Try {{try_number}} out of {{max_tries + 1}}<br>"
            "Exception:<br>{{exception_html}}<br>"
            'Log: <a href="{{ti.log_url}}">Link</a><br>'
            "Host: {{ti.hostname}}<br>"
            'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
        )

    # Add exception details to the context for template rendering.
    import html

    exception_html = html.escape(str(error)).replace("\n", "<br>")
    email_context = {
        **context,
        "exception": error,
        "exception_html": exception_html,
        "try_number": ti.try_number,
        "max_tries": ti.max_tries,
    }

    try:
        notifier = SmtpNotifier(
            to=to_emails,
            subject=subject,
            html_content=html_content,
            from_email=conf.get("email", "from_email", fallback="airflow@airflow"),
        )
        notifier(email_context)
    except Exception:
        log.exception("Failed to send email notification")

1482 

1483 

def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
    """
    Execute Task (optionally with a Timeout) and push Xcom results.

    Runs pre-execute hooks, on_execute callbacks, the task's ``execute`` (or
    ``resume_execution`` when resuming after a deferral), and post-execute
    hooks, returning the task's result.

    :param context: Execution context passed to the task and its hooks.
    :param ti: The runtime task instance being executed.
    :param log: Logger handed to hook runners.
    :return: Whatever the task's execute/resume_execution call returned.
    :raises AirflowTaskTimeout: When execution exceeds ``execution_timeout``.
    """
    task = ti.task
    execute = task.execute

    # When resuming after a deferral, swap the entry point for resume_execution
    # with the kwargs the trigger produced.
    if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method):
        from airflow.sdk.serde import deserialize

        next_kwargs_data = ti._ti_context_from_server.next_kwargs or {}
        try:
            if TYPE_CHECKING:
                assert isinstance(next_kwargs_data, dict)
            kwargs = deserialize(next_kwargs_data)
        except (ImportError, KeyError, AttributeError, TypeError):
            # Fall back to the legacy serialization format.
            from airflow.serialization.serialized_objects import BaseSerialization

            kwargs = BaseSerialization.deserialize(next_kwargs_data)

        if TYPE_CHECKING:
            assert isinstance(kwargs, dict)
        execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)

    ctx = contextvars.copy_context()
    # Populate the context var so ExecutorSafeguard doesn't complain
    ctx.run(ExecutorSafeguard.tracker.set, task)

    # Export context in os.environ to make it available for operators to use.
    airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
    os.environ.update(airflow_context_vars)

    outlet_events = context_get_outlet_events(context)

    # Run the legacy hook attribute first, then any overridden pre_execute method.
    if (pre_execute_hook := task._pre_execute_hook) is not None:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
    if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)

    _run_task_state_change_callbacks(task, "on_execute_callback", context, log)

    if task.execution_timeout:
        from airflow.sdk.execution_time.timeout import timeout

        # TODO: handle timeout in case of deferral
        timeout_seconds = task.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run task in timeout wrapper
            with timeout(timeout_seconds):
                result = ctx.run(execute, context=context)
        except AirflowTaskTimeout:
            task.on_kill()
            raise
    else:
        result = ctx.run(execute, context=context)

    if (post_execute_hook := task._post_execute_hook) is not None:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
    # Fix: an overridden post_execute takes (context, result) just like the
    # legacy hook above; previously the result was dropped for this branch.
    if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)

    return result

1547 

1548 

def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None:
    """Render named map index if the Dag author defined map_index_template at the task level."""
    template = context.get("map_index_template")
    if template is None:
        return None
    log.debug("Rendering map_index_template", template_length=len(template))
    env = ti.task.dag.get_template_env()
    rendered = env.from_string(template).render(context)
    log.debug("Map index rendered", length=len(rendered))
    return rendered

1558 

1559 

def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
    """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result."""
    xcom_value = result if ti.task.do_xcom_push else None

    feeds_mapped_task = next(ti.task.iter_mapped_dependants(), None) is not None
    if xcom_value is None:
        if not ti.is_mapped and feeds_mapped_task:
            # Uhoh, a downstream mapped task depends on us to push something to map over
            from airflow.sdk.exceptions import XComForMappingNotPushed

            raise XComForMappingNotPushed()
        return

    mapped_length: int | None = None
    if not ti.is_mapped and feeds_mapped_task:
        from airflow.sdk.definitions.mappedoperator import is_mappable_value
        from airflow.sdk.exceptions import UnmappableXComTypePushed

        if not is_mappable_value(xcom_value):
            raise UnmappableXComTypePushed(xcom_value)
        mapped_length = len(xcom_value)

    log.info("Pushing xcom", ti=ti)

    # With multiple_outputs, every key of the returned mapping becomes its own XCom.
    if ti.task.multiple_outputs:
        if not isinstance(xcom_value, Mapping):
            raise TypeError(
                f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
            )
        # Validate all keys before pushing anything, so a bad key late in the
        # mapping doesn't leave a partial set of XComs behind.
        for key in xcom_value:
            if not isinstance(key, str):
                raise TypeError(
                    "Returned dictionary keys must be strings when using "
                    f"multiple_outputs, found {key} ({type(key)}) instead"
                )
        for key, value in result.items():
            ti.xcom_push(key, value)

    _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length)

1603 

1604 

def finalize(
    ti: RuntimeTaskInstance,
    state: TaskInstanceState,
    context: Context,
    log: Logger,
    error: BaseException | None = None,
):
    """
    Run post-execution bookkeeping for a finished task instance.

    Emits duration metrics, pushes operator-extra-link XComs, optionally
    re-publishes rendered template fields, and dispatches the state-specific
    callbacks and listener hooks (success / skipped / retry / failed), plus
    error email notifications where configured.

    :param ti: The finished runtime task instance.
    :param state: Terminal state the task ended in.
    :param context: Template context used by callbacks and notifications.
    :param log: Logger for finalization messages.
    :param error: The exception that failed the task, if any.
    """
    # Record task duration metrics for all terminal states
    if ti.start_date and ti.end_date:
        duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000
        stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

        Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms)
        Stats.timing("task.duration", duration_ms, tags=stats_tags)

    task = ti.task
    # Pushing xcom for each operator extra links defined on the operator only.
    for oe in task.operator_extra_links:
        try:
            link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key  # type: ignore[arg-type]
            log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
            _xcom_push_to_db(ti, key=xcom_key, value=link)
        except Exception:
            # A broken extra link must not fail finalization for the task.
            log.exception(
                "Failed to push an xcom for task operator extra link",
                link_name=oe.name,
                xcom_key=oe.xcom_key,
                ti=ti,
            )

    if getattr(ti.task, "overwrite_rtif_after_execution", False):
        log.debug("Overwriting Rendered template fields.")
        if ti.task.template_fields:
            SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)))

    log.debug("Running finalizers", ti=ti)
    # Dispatch the state-specific callbacks and listener hooks; listener errors
    # are logged but never propagated.
    if state == TaskInstanceState.SUCCESS:
        _run_task_state_change_callbacks(task, "on_success_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_success(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.SKIPPED:
        _run_task_state_change_callbacks(task, "on_skipped_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_skipped(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.UP_FOR_RETRY:
        _run_task_state_change_callbacks(task, "on_retry_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_retry and task.email:
            _send_error_email_notification(task, ti, context, error, log)
    elif state == TaskInstanceState.FAILED:
        _run_task_state_change_callbacks(task, "on_failure_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_failure and task.email:
            _send_error_email_notification(task, ti, context, error, log)

    try:
        get_listener_manager().hook.before_stopping(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")

1682 

1683 

def main():
    """
    Entry point for the task-runner subprocess.

    Connects supervisor comms, initializes metrics, then runs the task through
    ``startup() -> run() -> finalize()`` while holding the DAG bundle version
    lock. Fatal outcomes map to process exit codes: 2 for Ctrl-C, 1 for any
    other top-level error. The supervisor request socket is always closed on
    the way out.
    """
    log = structlog.get_logger(logger_name="task")

    global SUPERVISOR_COMMS
    SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)

    Stats.initialize(
        is_statsd_datadog_enabled=conf.getboolean("metrics", "statsd_datadog_enabled"),
        is_statsd_on=conf.getboolean("metrics", "statsd_on"),
        is_otel_on=conf.getboolean("metrics", "otel_on"),
    )

    try:
        ti, context, log = startup()
        # Hold the bundle version lock for the whole run so the DAG bundle
        # cannot change underneath the executing task.
        with BundleVersionLock(
            bundle_name=ti.bundle_instance.name,
            bundle_version=ti.bundle_instance.version,
        ):
            state, _, error = run(ti, context, log)
            context["exception"] = error
            finalize(ti, state, context, log, error)
    except KeyboardInterrupt:
        log.exception("Ctrl-c hit")
        # Use sys.exit, not the site-module `exit()` helper: the latter is only
        # defined when the `site` module is loaded (absent under `python -S`
        # and in embedded interpreters).
        sys.exit(2)
    except Exception:
        log.exception("Top level error")
        sys.exit(1)
    finally:
        # Ensure the request socket is closed on the child side in all circumstances
        # before the process fully terminates.
        if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket:
            with suppress(Exception):
                SUPERVISOR_COMMS.socket.close()

1717 

1718 

def reinit_supervisor_comms() -> None:
    """
    Re-initialize supervisor comms and logging channel in subprocess.

    This is not needed for most cases, but is used when either we re-launch the process via sudo for
    run_as_user, or from inside the python code in a virtualenv (et al.) operator to re-connect so those tasks
    can continue to access variables etc.
    """
    import socket

    # Guard clause: comms already established in this interpreter, nothing to do.
    if "SUPERVISOR_COMMS" in globals():
        return

    global SUPERVISOR_COMMS
    log = structlog.get_logger(logger_name="task")

    # The supervisor hands us its socket via an inherited file descriptor.
    fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))

    SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd))

    # Ask the supervisor to re-send the logging FD so structured logs keep
    # flowing to it from this re-spawned process.
    logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
    if not isinstance(logs, SentFDs):
        print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)
        return

    from airflow.sdk.log import configure_logging

    log_io = os.fdopen(logs.fds[0], "wb", buffering=0)
    configure_logging(json_output=True, output=log_io, sending_to_supervisor=True)

1745 

1746 

# Script entry point: only runs when this module is executed directly.
if __name__ == "__main__":
    main()