# Source: /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/execution_time/task_runner.py

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The entrypoint for the actual task execution process."""

from __future__ import annotations

import contextlib
import contextvars
import functools
import os
import sys
import time
from collections.abc import Callable, Iterable, Iterator, Mapping
from contextlib import suppress
from datetime import datetime, timezone
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal
from urllib.parse import quote

import attrs
import lazy_object_proxy
import structlog
from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter

from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.sdk.api.client import get_hostname, getuser
from airflow.sdk.api.datamodels._generated import (
    AssetProfile,
    DagRun,
    PreviousTIResponse,
    TaskInstance,
    TaskInstanceState,
    TIRunContext,
)
from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.param import process_params
from airflow.sdk.exceptions import (
    AirflowException,
    AirflowInactiveAssetInInletOrOutletException,
    AirflowRuntimeError,
    AirflowTaskTimeout,
    ErrorType,
    TaskDeferred,
)
from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.comms import (
    AssetEventDagRunReferenceResult,
    CommsDecoder,
    DagRunStateResult,
    DeferTask,
    DRCount,
    ErrorResponse,
    GetDagRunState,
    GetDRCount,
    GetPreviousDagRun,
    GetPreviousTI,
    GetTaskBreadcrumbs,
    GetTaskRescheduleStartDate,
    GetTaskStates,
    GetTICount,
    InactiveAssetsResult,
    PreviousDagRunResult,
    PreviousTIResult,
    RescheduleTask,
    ResendLoggingFD,
    RetryTask,
    SentFDs,
    SetRenderedFields,
    SetRenderedMapIndex,
    SkipDownstreamTasks,
    StartupDetails,
    SucceedTask,
    TaskBreadcrumbsResult,
    TaskRescheduleStartDate,
    TaskState,
    TaskStatesResult,
    TICount,
    ToSupervisor,
    ToTask,
    TriggerDagRun,
    ValidateInletsAndOutlets,
)
from airflow.sdk.execution_time.context import (
    ConnectionAccessor,
    InletEventsAccessors,
    MacrosAccessor,
    OutletEventAccessors,
    TriggeringAssetEventsAccessor,
    VariableAccessor,
    context_get_outlet_events,
    context_to_airflow_vars,
    get_previous_dagrun_success,
    set_current_context,
)
from airflow.sdk.execution_time.sentry import Sentry
from airflow.sdk.execution_time.xcom import XCom
from airflow.sdk.listener import get_listener_manager
from airflow.sdk.observability.stats import Stats
from airflow.sdk.timezone import coerce_datetime
from airflow.triggers.base import BaseEventTrigger
from airflow.triggers.callback import CallbackTrigger

if TYPE_CHECKING:
    import jinja2
    from pendulum.datetime import DateTime
    from structlog.typing import FilteringBoundLogger as Logger

    from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
    from airflow.sdk.definitions.context import Context
    from airflow.sdk.exceptions import DagRunTriggerException
    from airflow.sdk.types import OutletEventAccessorsProtocol


class TaskRunnerMarker:
    """Marker for listener hooks, to properly detect from which component they are called."""


# TODO: Move this entire class into a separate file:
#   `airflow/sdk/execution_time/task_instance.py`
#   or `airflow/sdk/execution_time/runtime_ti.py`
class RuntimeTaskInstance(TaskInstance):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    task: BaseOperator
    bundle_instance: BaseDagBundle
    _cached_template_context: Context | None = None
    """The Task Instance context. This is used to cache get_template_context."""

    _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
    """The Task Instance context from the API server, if any."""

    max_tries: int = 0
    """The maximum number of retries for the task."""

    start_date: AwareDatetime
    """Start date of the task instance."""

    end_date: AwareDatetime | None = None

    state: TaskInstanceState | None = None

    is_mapped: bool | None = None
    """True if the original task was mapped."""

    rendered_map_index: str | None = None

    sentry_integration: str = ""

    def __rich_repr__(self):
        yield "id", self.id
        yield "task_id", self.task_id
        yield "dag_id", self.dag_id
        yield "run_id", self.run_id
        yield "max_tries", self.max_tries
        yield "task", type(self.task)
        yield "start_date", self.start_date

    __rich_repr__.angular = True  # type: ignore[attr-defined]

    def get_template_context(self) -> Context:
        # TODO: Move this to `airflow.sdk.execution_time.context`
        #   once we port the entire context logic from airflow/utils/context.py?
        from airflow.sdk.plugins_manager import integrate_macros_plugins

        integrate_macros_plugins()

        dag_run_conf: dict[str, Any] | None = None
        if from_server := self._ti_context_from_server:
            dag_run_conf = from_server.dag_run.conf or dag_run_conf

        validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)

        # Cache the context object, which ensures that all calls to get_template_context
        # are operating on the same context object.
        self._cached_template_context: Context = self._cached_template_context or {
            # From the Task Execution interface
            "dag": self.task.dag,
            "inlets": self.task.inlets,
            "map_index_template": self.task.map_index_template,
            "outlets": self.task.outlets,
            "run_id": self.run_id,
            "task": self.task,
            "task_instance": self,
            "ti": self,
            "outlet_events": OutletEventAccessors(),
            "inlet_events": InletEventsAccessors(self.task.inlets),
            "macros": MacrosAccessor(),
            "params": validated_params,
            # TODO: Make this go through Public API longer term.
            # "test_mode": task_instance.test_mode,
            "var": {
                "json": VariableAccessor(deserialize_json=True),
                "value": VariableAccessor(deserialize_json=False),
            },
            "conn": ConnectionAccessor(),
        }
        if from_server:
            dag_run = from_server.dag_run
            context_from_server: Context = {
                # TODO: Assess if we need to pass these through timezone.coerce_datetime
                "dag_run": dag_run,  # type: ignore[typeddict-item]  # Removable after #46522
                "triggering_asset_events": TriggeringAssetEventsAccessor.build(
                    AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
                    for event in dag_run.consumed_asset_events
                ),
                "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
                "task_reschedule_count": from_server.task_reschedule_count or 0,
                "prev_start_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date)
                ),
                "prev_end_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
                ),
            }
            self._cached_template_context.update(context_from_server)

            if logical_date := coerce_datetime(dag_run.logical_date):
                if TYPE_CHECKING:
                    assert isinstance(logical_date, DateTime)
                ds = logical_date.strftime("%Y-%m-%d")
                ds_nodash = ds.replace("-", "")
                ts = logical_date.isoformat()
                ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
                ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
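                # Worked example (illustrative): for logical_date 2024-01-02T03:04:05+00:00 the
                # derived values are ds="2024-01-02", ds_nodash="20240102",
                # ts="2024-01-02T03:04:05+00:00", ts_nodash="20240102T030405",
                # ts_nodash_with_tz="20240102T030405+0000".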

                # logical_date and data_interval either coexist or are None together
                self._cached_template_context.update(
                    {
                        # keys that depend on logical_date
                        "logical_date": logical_date,
                        "ds": ds,
                        "ds_nodash": ds_nodash,
                        "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}",
                        "ts": ts,
                        "ts_nodash": ts_nodash,
                        "ts_nodash_with_tz": ts_nodash_with_tz,
                        # keys that depend on data_interval
                        "data_interval_end": coerce_datetime(dag_run.data_interval_end),
                        "data_interval_start": coerce_datetime(dag_run.data_interval_start),
                        "prev_data_interval_start_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start)
                        ),
                        "prev_data_interval_end_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end)
                        ),
                    }
                )

            if from_server.upstream_map_indexes is not None:
                # We stash this in here for later use, but we purposefully don't want to document its
                # existence. Should this be a private attribute on RuntimeTI instead perhaps?
                setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)

        return self._cached_template_context

    def render_templates(
        self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
    ) -> BaseOperator:
        """
        Render templates in the operator fields.

        If the task was originally mapped, this may replace ``self.task`` with
        the unmapped, fully rendered BaseOperator. The original ``self.task``
        before replacement is returned.
        """
        if not context:
            context = self.get_template_context()
        original_task = self.task

        if TYPE_CHECKING:
            assert context

        ti = context["ti"]

        if TYPE_CHECKING:
            assert original_task
            assert self.task
            assert ti.task

        # If self.task is mapped, this call replaces self.task to point to the
        # unmapped BaseOperator created by this function! This is because the
        # MappedOperator is useless for template rendering, and we need to be
        # able to access the unmapped task instead.
        self.task.render_template_fields(context, jinja_env)
        self.is_mapped = original_task.is_mapped
        return original_task

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = BaseXCom.XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
        default: Any = None,
        run_id: str | None = None,
    ) -> Any:
        """
        Pull XComs either from the API server (BaseXCom) or from the custom XCom backend if configured.

        The pull can optionally be filtered by certain criteria.

        :param key: A key for the XCom. If provided, only XComs with matching
            keys will be returned. The default key is ``'return_value'``, also
            available as constant ``XCOM_RETURN_KEY``. This key is automatically
            given to XComs returned by tasks (as opposed to being pushed
            manually).
        :param task_ids: Only XComs from tasks with matching ids will be
            pulled. If *None* (default), the task_id of the calling task is used.
        :param dag_id: If provided, only pulls XComs from this Dag. If *None*
            (default), the Dag of the calling task is used.
        :param map_indexes: If provided, only pull XComs with matching indexes.
            If *None* (default), this is inferred from the task(s) being pulled
            (see below for details).
        :param include_prior_dates: If *False*, only XComs from the current
            logical_date are returned. If *True*, XComs from previous dates
            are returned as well.
        :param run_id: If provided, only pulls XComs from a DagRun with a matching run_id.
            If *None* (default), the run_id of the calling task is used.

        When pulling one single task (``task_id`` is *None* or a str) without
        specifying ``map_indexes``, the return value is a single XCom entry
        (``map_indexes`` is set to the map_index of the calling task instance).

        When the pulled task is mapped, the specified ``map_index`` is used, so by
        default pulling from a mapped task yields no matching XComs unless the task
        instance making the call is itself mapped, in which case the map_index of
        the calling task instance is used. Setting ``map_indexes`` to *None* pulls
        the XCom as it would from a non-mapped task.

        In either case, ``default`` (*None* if not specified) is returned if no
        matching XComs are found.

        When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
        a non-str iterable), a list of matching XComs is returned. Elements in
        the list are ordered by item ordering in ``task_id`` and ``map_index``.
        """

        if dag_id is None:
            dag_id = self.dag_id
        if run_id is None:
            run_id = self.run_id

        single_task_requested = isinstance(task_ids, (str, type(None)))
        single_map_index_requested = isinstance(map_indexes, (int, type(None)))

        if task_ids is None:
            # default to the current task if not provided
            task_ids = [self.task_id]
        elif isinstance(task_ids, str):
            task_ids = [task_ids]

        # If map_indexes is not specified, pull xcoms from all map indexes for each task
        if not is_arg_set(map_indexes):
            xcoms: list[Any] = []
            for t_id in task_ids:
                values = XCom.get_all(
                    run_id=run_id,
                    key=key,
                    task_id=t_id,
                    dag_id=dag_id,
                    include_prior_dates=include_prior_dates,
                )

                if values is None:
                    xcoms.append(None)
                else:
                    xcoms.extend(values)
            # For a single task pulling from an unmapped task, return a single value
            if single_task_requested and len(xcoms) == 1:
                return xcoms[0]
            return xcoms

        # Original logic when map_indexes is explicitly specified
        map_indexes_iterable: Iterable[int | None] = []
        if isinstance(map_indexes, int) or map_indexes is None:
            map_indexes_iterable = [map_indexes]
        elif isinstance(map_indexes, Iterable):
            map_indexes_iterable = map_indexes
        else:
            raise TypeError(
                f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}"
            )

        xcoms = []
        for t_id, m_idx in product(task_ids, map_indexes_iterable):
            value = XCom.get_one(
                run_id=run_id,
                key=key,
                task_id=t_id,
                dag_id=dag_id,
                map_index=m_idx,
                include_prior_dates=include_prior_dates,
            )
            if value is None:
                xcoms.append(default)
            else:
                xcoms.append(value)

        if single_task_requested and single_map_index_requested:
            return xcoms[0]

        return xcoms

    def xcom_push(self, key: str, value: Any):
        """
        Make an XCom available for tasks to pull.

        :param key: Key to store the value under.
        :param value: Value to store. Only JSON-serializable values may be used.
        """
        _xcom_push(self, key, value)

    def get_relevant_upstream_map_indexes(
        self, upstream: BaseOperator, ti_count: int | None, session: Any
    ) -> int | range | None:
        # TODO: Implement this method
        return None

    def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
        """Get the first reschedule date for the task instance if found, *None* otherwise."""
        if context.get("task_reschedule_count", 0) == 0:
            # If the task has not been rescheduled, there is no need to ask the supervisor
            return None

        max_tries: int = self.max_tries
        retries: int = self.task.retries or 0
        first_try_number = max_tries - retries + 1
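        # E.g. (illustrative) a task with retries=2 running its first try has max_tries=2,
        # so first_try_number = 2 - 2 + 1 = 1; if the task was cleared, max_tries grows and
        # this then points at the first try of the current run rather than try 1 overall.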

        log = structlog.get_logger(logger_name="task")

        log.debug("Requesting first reschedule date from supervisor")

        response = SUPERVISOR_COMMS.send(
            msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskRescheduleStartDate)

        return response.start_date

    def get_previous_dagrun(self, state: str | None = None) -> DagRun | None:
        """Return the previous Dag run before the given logical date, optionally filtered by state."""
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")

        log.debug("Getting previous Dag run", dag_run=dag_run)

        if dag_run is None:
            return None

        if dag_run.logical_date is None:
            return None

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state)
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousDagRunResult)

        return response.dag_run

    def get_previous_ti(
        self,
        state: TaskInstanceState | None = None,
        logical_date: AwareDatetime | None = None,
        map_index: int = -1,
    ) -> PreviousTIResponse | None:
        """
        Return the previous task instance matching the given criteria.

        :param state: Filter by TaskInstance state
        :param logical_date: Filter by logical date (returns TI before this date)
        :param map_index: Filter by map_index (defaults to -1 for non-mapped tasks)
        :return: Previous task instance or None if not found
        """
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")
        log.debug("Getting previous task instance", task_id=self.task_id, state=state)

        # Use current dag run's logical_date if not provided
        effective_logical_date = logical_date
        if effective_logical_date is None and dag_run and dag_run.logical_date:
            effective_logical_date = dag_run.logical_date

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousTI(
                dag_id=self.dag_id,
                task_id=self.task_id,
                logical_date=effective_logical_date,
                map_index=map_index,
                state=state,
            )
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousTIResult)

        return response.task_instance

    @staticmethod
    def get_ti_count(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of task instances matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTICount(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TICount)

        return response.count
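        # Illustrative call (hypothetical ids):
        #   RuntimeTaskInstance.get_ti_count(dag_id="my_dag", task_ids=["extract"], states=["success"])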

    @staticmethod
    def get_task_states(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> dict[str, Any]:
        """Return the task states matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTaskStates(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskStatesResult)

        return response.task_states

    @staticmethod
    def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]:
        """Return task breadcrumbs for the given dag run."""
        response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id))
        if TYPE_CHECKING:
            assert isinstance(response, TaskBreadcrumbsResult)
        return response.breadcrumbs

    @staticmethod
    def get_dr_count(
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of Dag runs matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetDRCount(
                dag_id=dag_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, DRCount)

        return response.count

    @staticmethod
    def get_dagrun_state(dag_id: str, run_id: str) -> str:
        """Return the state of the Dag run with the given Run ID."""
        response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id))

        if TYPE_CHECKING:
            assert isinstance(response, DagRunStateResult)

        return response.state

    @property
    def log_url(self) -> str:
        run_id = quote(self.run_id)
        base_url = conf.get("api", "base_url", fallback="http://localhost:8080/")
        map_index_value = self.map_index
        map_index = (
            f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else ""
        )
        try_number_value = self.try_number
        try_number = (
            f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else ""
        )
        _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}"
        return _log_uri
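        # Example of the resulting URL shape (hypothetical identifiers):
        #   http://localhost:8080/dags/my_dag/runs/manual__2024-01-02/tasks/my_task/mapped/0?try_number=2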

    @property
    def mark_success_url(self) -> str:
        """URL to mark TI success."""
        return self.log_url


def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
    """Push an XCom through XCom.set, which pushes to the XCom backend if configured."""
    # Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK
    # consumers

    XCom.set(
        key=key,
        value=value,
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=ti.map_index,
        _mapped_length=mapped_length,
    )


def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None:
    """Push an XCom directly to the metadata DB, bypassing any custom xcom_backend."""
    XCom._set_xcom_in_db(
        key=key,
        value=value,
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=ti.map_index,
    )

def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
    # TODO: Task-SDK:
    #   Using BundleDagBag here is about 98% wrong, but it'll do for now
    from airflow.dag_processing.dagbag import BundleDagBag

    bundle_info = what.bundle_info
    bundle_instance = DagBundlesManager().get_bundle(
        name=bundle_info.name,
        version=bundle_info.version,
    )
    bundle_instance.initialize()

    dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
    bag = BundleDagBag(
        dag_folder=dag_absolute_path,
        safe_mode=False,
        load_op_links=False,
        bundle_path=bundle_instance.path,
        bundle_name=bundle_info.name,
    )
    if TYPE_CHECKING:
        assert what.ti.dag_id

    try:
        dag = bag.dags[what.ti.dag_id]
    except KeyError:
        log.error(
            "Dag not found during start up", dag_id=what.ti.dag_id, bundle=bundle_info, path=what.dag_rel_path
        )
        sys.exit(1)

    # install_loader()

    try:
        task = dag.task_dict[what.ti.task_id]
    except KeyError:
        log.error(
            "Task not found in Dag during start up",
            dag_id=dag.dag_id,
            task_id=what.ti.task_id,
            bundle=bundle_info,
            path=what.dag_rel_path,
        )
        sys.exit(1)

    if not isinstance(task, (BaseOperator, MappedOperator)):
        raise TypeError(
            f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
        )

    return RuntimeTaskInstance.model_construct(
        **what.ti.model_dump(exclude_unset=True),
        task=task,
        bundle_instance=bundle_instance,
        _ti_context_from_server=what.ti_context,
        max_tries=what.ti_context.max_tries,
        start_date=what.start_date,
        state=TaskInstanceState.RUNNING,
        sentry_integration=what.sentry_integration,
    )


# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's
# execution, to send requests back to the supervisor process.
#
# Why it needs to be a global:
# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic
#   requests to the parent process during task execution.
# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the
#   deeply nested execution stack.
# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily
#   accessible wherever needed during task execution without modifying every layer of the call stack.
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
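# A typical round-trip through this channel looks like (illustrative; see e.g.
# `get_dagrun_state` above):
#   response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id="my_dag", run_id="my_run"))
#   # the supervisor replies with a typed message such as DagRunStateResult(state="running")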

# State machine!
# 1. Start up (receive details from supervisor)
# 2. Execution (run task code, possibly send requests)
# 3. Shutdown and report status


def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
    # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent
    # in response to us sending a request.
    log = structlog.get_logger(logger_name="task")

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and (
        msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG")
    ):
        # Clear any Kerberos credential cache if there is one, so the new process can't reuse it.
        os.environ.pop("KRB5CCNAME", None)
        # entrypoint of the re-exec process

        msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson)
        reinit_supervisor_comms()

        # We delay this message until _after_ we've got the logging re-configured, otherwise it will show up
        # on stdout
        log.debug("Using serialized startup message from environment", msg=msg)
    else:
        # normal entry point
        msg = SUPERVISOR_COMMS._get_response()  # type: ignore[assignment]

        if not isinstance(msg, StartupDetails):
            raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

    # setproctitle causes issues on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
    os_type = sys.platform
    if os_type == "darwin":
        log.debug("Mac OS detected, skipping setproctitle")
    else:
        from setproctitle import setproctitle

        setproctitle(f"airflow worker -- {msg.ti.id}")

    try:
        get_listener_manager().hook.on_starting(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")

    with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
        ti = parse(msg, log)
    log.debug("Dag file parsed", file=msg.dag_rel_path)

    run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
        "core", "default_impersonation", fallback=None
    )

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
        # Not yet re-executed and impersonation is required: re-exec this process as `run_as_user`
        os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
        # store the startup message in the environment for the re-exec process
        os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
        os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)

        # Import main directly from the module instead of re-executing the file.
        # This ensures that when other modules import
        # airflow.sdk.execution_time.task_runner, they get the same module instance
        # with the properly initialized SUPERVISOR_COMMS global variable.
        # If we re-executed the module with `python -m`, it would load as __main__ and future
        # imports would get a fresh copy without the initialized globals.
        rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
        cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
        log.info(
            "Running command",
            command=cmd,
        )
        os.execvp("sudo", cmd)

        # os.execvp never returns on success; if we somehow reach here, return None, None, None
        return None, None, None

    return ti, ti.get_template_context(), log


def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float:
    """
    Return a serializable representation of the templated field.

    If ``templated_field`` contains a class or instance that requires recursive
    templating, store them as strings. Otherwise simply return the field as-is.

    Uses the SDK secrets masker to redact secrets in the serialized output.
    """
    import json

    from airflow.sdk._shared.secrets_masker import redact

    def is_jsonable(x):
        try:
            json.dumps(x)
        except (TypeError, OverflowError):
            return False
        else:
            return True

    def translate_tuples_to_lists(obj: Any):
        """Recursively convert tuples to lists."""
        if isinstance(obj, tuple):
            return [translate_tuples_to_lists(item) for item in obj]
        if isinstance(obj, list):
            return [translate_tuples_to_lists(item) for item in obj]
        if isinstance(obj, dict):
            return {key: translate_tuples_to_lists(value) for key, value in obj.items()}
        return obj

    def sort_dict_recursively(obj: Any) -> Any:
        """Recursively sort dictionaries to ensure consistent ordering."""
        if isinstance(obj, dict):
            return {k: sort_dict_recursively(v) for k, v in sorted(obj.items())}
        if isinstance(obj, list):
            return [sort_dict_recursively(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(sort_dict_recursively(item) for item in obj)
        return obj
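    # For example (illustrative):
    #   translate_tuples_to_lists(("a", {"z": 1, "b": (2, 3)}))  ->  ["a", {"z": 1, "b": [2, 3]}]
    #   sort_dict_recursively({"z": 1, "b": 2})                  ->  {"b": 2, "z": 1}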

    max_length = conf.getint("core", "max_templated_field_length")

    if not is_jsonable(template_field):
        try:
            serialized = template_field.serialize()
        except AttributeError:
            serialized = str(template_field)
        if len(serialized) > max_length:
            rendered = redact(serialized, name)
            return (
                "Truncated. You can change this behaviour in [core]max_templated_field_length. "
                f"{rendered[: max_length - 79]!r}... "
            )
        return serialized
    if not template_field and not isinstance(template_field, tuple):
        # Avoid unnecessary serialization steps for empty fields unless they are tuples
        # and need to be converted to lists
        return template_field
    template_field = translate_tuples_to_lists(template_field)
    # Sort dictionaries recursively to ensure consistent string representation
    # This prevents hash inconsistencies when dict ordering varies
    if isinstance(template_field, dict):
        template_field = sort_dict_recursively(template_field)
    serialized = str(template_field)
    if len(serialized) > max_length:
        rendered = redact(serialized, name)
        return (
            "Truncated. You can change this behaviour in [core]max_templated_field_length. "
            f"{rendered[: max_length - 79]!r}... "
        )
    return template_field
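# Note on the magic number 79 above: the fixed "Truncated. ..." notice is itself about 79
# characters long, so (illustratively) with max_templated_field_length=100 the redacted
# preview keeps roughly the first 100 - 79 = 21 characters of the rendered value.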


def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]:
    from airflow.sdk._shared.secrets_masker import redact

    rendered_fields = {}
    for field in task.template_fields:
        value = getattr(task, field)
        serialized = _serialize_template_field(value, field)
        # Redact secrets in the task process itself before sending to the API server.
        # This ensures that secrets registered via mask_secret() on workers / the dag processor are
        # properly masked on the UI.
        rendered_fields[field] = redact(serialized, field)

    return rendered_fields  # type: ignore[return-value]  # Convince mypy that this is OK since we pass JsonValue to redact, so it will return the same


def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]:
    # Lineage can have other types of objects besides assets, so we need to process them a bit.
    for obj in lineage_objects or ():
        if isinstance(obj, Asset):
            yield AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__)
        elif isinstance(obj, AssetNameRef):
            yield AssetProfile(name=obj.name, type=AssetNameRef.__name__)
        elif isinstance(obj, AssetUriRef):
            yield AssetProfile(uri=obj.uri, type=AssetUriRef.__name__)
        elif isinstance(obj, AssetAlias):
            yield AssetProfile(name=obj.name, type=AssetAlias.__name__)
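# For example (illustrative), an outlet Asset(name="my_ds", uri="s3://bucket/key") is
# reported to the supervisor as AssetProfile(name="my_ds", uri="s3://bucket/key", type="Asset").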


def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]:
    if TYPE_CHECKING:
        assert isinstance(events, OutletEventAccessors)
    # We just collect everything the user recorded in the accessors.
    # Further filtering will be done in the API server.
    for key, accessor in events._dict.items():
        if isinstance(key, AssetUniqueKey):
            yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra}
        for alias_event in accessor.asset_alias_events:
            yield attrs.asdict(alias_event)


def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None:
    ti.hostname = get_hostname()
    ti.task = ti.task.prepare_for_execution()
    # Since context is now cached, and calling `ti.get_template_context` will return the same dict, we want
    # to update the value of the task that is sent from there
    context["task"] = ti.task

    jinja_env = ti.task.dag.get_template_env()
    ti.render_templates(context=context, jinja_env=jinja_env)

    if rendered_fields := _serialize_rendered_fields(ti.task):
        # so that we do not call the API unnecessarily
        SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields))

    # Try to render map_index_template early with available context (will be re-rendered after execution).
    # This provides a partial label during task execution for templates using pre-execution context.
    # If rendering fails here, we suppress the error since it will be re-rendered after execution.
    try:
        if rendered_map_index := _render_map_index(context, ti=ti, log=log):
            ti.rendered_map_index = rendered_map_index
            log.debug("Sending early rendered map index", length=len(rendered_map_index))
            SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index))
    except Exception:
        log.debug(
            "Early rendering of map_index_template failed, will retry after task execution", exc_info=True
        )

    _validate_task_inlets_and_outlets(ti=ti, log=log)

    try:
        # TODO: Call pre execute etc.
        get_listener_manager().hook.on_task_instance_running(
            previous_state=TaskInstanceState.QUEUED, task_instance=ti
        )
    except Exception:
        log.exception("error calling listener")

    # No error, carry on and execute the task
    return None


def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
    if not ti.task.inlets and not ti.task.outlets:
        return

    inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id))
    if TYPE_CHECKING:
        assert isinstance(inactive_assets_resp, InactiveAssetsResult)
    if inactive_assets := inactive_assets_resp.inactive_assets:
        raise AirflowInactiveAssetInInletOrOutletException(
            inactive_asset_keys=[
                AssetUniqueKey.from_profile(asset_profile) for asset_profile in inactive_assets
            ]
        )


def _defer_task(
    defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?

    log.info("Pausing task as DEFERRED.", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
    classpath, trigger_kwargs = defer.trigger.serialize()
    queue: str | None = None
    # Currently, only task-associated BaseTrigger instances may have a non-None queue,
    # and only when triggerer.queues_enabled is True.
    if not isinstance(defer.trigger, (BaseEventTrigger, CallbackTrigger)) and conf.getboolean(
        "triggerer", "queues_enabled", fallback=False
    ):
        queue = ti.task.queue

    from airflow.sdk.serde import serialize as serde_serialize

    trigger_kwargs = serde_serialize(trigger_kwargs)
    next_kwargs = serde_serialize(defer.kwargs or {})

    if TYPE_CHECKING:
        assert isinstance(next_kwargs, dict)
        assert isinstance(trigger_kwargs, dict)

    msg = DeferTask(
        classpath=classpath,
        trigger_kwargs=trigger_kwargs,
        trigger_timeout=defer.timeout,
        queue=queue,
        next_method=defer.method_name,
        next_kwargs=next_kwargs,
    )
    state = TaskInstanceState.DEFERRED

    return msg, state
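# For example (illustrative), deferring on a datetime trigger serializes into a classpath
# string plus JSON-safe kwargs, roughly:
#   DeferTask(classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
#             trigger_kwargs={"moment": ...}, next_method="execute_complete", ...)
# which the triggerer later uses to re-instantiate and run the trigger.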


@Sentry.enrich_errors
def run(
    ti: RuntimeTaskInstance,
    context: Context,
    log: Logger,
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
    """Run the task in this process."""
    import signal

    from airflow.sdk.exceptions import (
        AirflowFailException,
        AirflowRescheduleException,
        AirflowSensorTimeout,
        AirflowSkipException,
        AirflowTaskTerminated,
        DagRunTriggerException,
        DownstreamTasksSkipped,
        TaskDeferred,
    )

    if TYPE_CHECKING:
        assert ti.task is not None
        assert isinstance(ti.task, BaseOperator)

    parent_pid = os.getpid()

    def _on_term(signum, frame):
        pid = os.getpid()
        if pid != parent_pid:
            return

        ti.task.on_kill()

    signal.signal(signal.SIGTERM, _on_term)

    msg: ToSupervisor | None = None
    state: TaskInstanceState
    error: BaseException | None = None

    try:
        # First, clear the xcom data sent from the server
        if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear):
            for x in keys_to_delete:
                log.debug("Clearing XCom with key", key=x)
                XCom.delete(
                    key=x,
                    dag_id=ti.dag_id,
                    task_id=ti.task_id,
                    run_id=ti.run_id,
                    map_index=ti.map_index,
                )

        with set_current_context(context):
            # This is the earliest that we can render templates -- as, if it raises for any reason, we need
            # to catch it and handle it like a normal task failure
            if early_exit := _prepare(ti, log, context):
                msg = early_exit
                ti.state = state = TaskInstanceState.FAILED
                return state, msg, error

            try:
                result = _execute_task(context=context, ti=ti, log=log)
            except Exception:
                import jinja2

                # If the task failed, swallow the rendering error so it doesn't mask the main error.
                with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
                    previous_rendered_map_index = ti.rendered_map_index
                    ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                    # Send the update only if the value changed (e.g., the user set context variables
                    # during execution)
                    if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                        SUPERVISOR_COMMS.send(
                            msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)
                        )
                raise
            else:  # If the task succeeded, render normally to let any rendering error bubble up.
                previous_rendered_map_index = ti.rendered_map_index
                ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                # Send the update only if the value changed (e.g., the user set context variables during
                # execution)
                if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                    SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index))

            _push_xcom_if_needed(result, ti, log)

            msg, state = _handle_current_task_success(context, ti)
    except DownstreamTasksSkipped as skip:
        log.info("Skipping downstream tasks.")
        tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks]
        SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip))
        msg, state = _handle_current_task_success(context, ti)
    except DagRunTriggerException as drte:
        msg, state = _handle_trigger_dag_run(drte, context, ti, log)
    except TaskDeferred as defer:
        msg, state = _defer_task(defer, ti, log)
    except AirflowSkipException as e:
        if e.args:
            log.info("Skipping task.", reason=e.args[0])
        msg = TaskState(
            state=TaskInstanceState.SKIPPED,
            end_date=datetime.now(tz=timezone.utc),
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.SKIPPED
    except AirflowRescheduleException as reschedule:
        log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
        msg = RescheduleTask(
            reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
        )
        state = TaskInstanceState.UP_FOR_RESCHEDULE
    except (AirflowFailException, AirflowSensorTimeout) as e:
        # If AirflowFailException is raised, the task should not retry.
        # If a sensor in reschedule mode reaches its timeout, the task should not retry.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
        # We should allow retries if the task has defined it.
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    except AirflowTaskTerminated as e:
        # External state updates are already handled with `ti_heartbeat` and will already be updated by
        # another UI API. So, these exceptions should ideally never be thrown.
        # If they are thrown, we should mark the TI state as failed.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except SystemExit as e:
        # A SystemExit needs to be retried if the task is eligible.
        log.error("Task exited", exit_code=e.code)
        msg, state = _handle_current_task_failed(ti)
        error = e
    except BaseException as e:
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    finally:
        if msg:
            SUPERVISOR_COMMS.send(msg=msg)

    # Return the message to make unit tests easier too
    ti.state = state
    return state, msg, error


def _handle_current_task_success(
    context: Context,
    ti: RuntimeTaskInstance,
) -> tuple[SucceedTask, TaskInstanceState]:
    end_date = datetime.now(tz=timezone.utc)
    ti.end_date = end_date

    # Record operator and task instance success metrics
    operator = ti.task.__class__.__name__
    stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

    Stats.incr(f"operator_successes_{operator}", tags=stats_tags)
    # Same metric with tagging
    Stats.incr("operator_successes", tags={**stats_tags, "operator": operator})
    Stats.incr("ti_successes", tags=stats_tags)
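    # E.g. (illustrative) a successful PythonOperator emits both the legacy per-operator counter
    # "operator_successes_PythonOperator" and the tagged "operator_successes" / "ti_successes" counters.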

    task_outlets = list(_build_asset_profiles(ti.task.outlets))
    outlet_events = list(_serialize_outlet_events(context["outlet_events"]))
    msg = SucceedTask(
        end_date=end_date,
        task_outlets=task_outlets,
        outlet_events=outlet_events,
        rendered_map_index=ti.rendered_map_index,
    )
    return msg, TaskInstanceState.SUCCESS


def _handle_current_task_failed(
    ti: RuntimeTaskInstance,
) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
    end_date = datetime.now(tz=timezone.utc)
    ti.end_date = end_date

    # Record operator and task instance failure metrics
    operator = ti.task.__class__.__name__
    stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

    Stats.incr(f"operator_failures_{operator}", tags=stats_tags)
    # Same metric with tagging
    Stats.incr("operator_failures", tags={**stats_tags, "operator": operator})
    Stats.incr("ti_failures", tags=stats_tags)

    if ti._ti_context_from_server and ti._ti_context_from_server.should_retry:
        return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY
    return (
        TaskState(
            state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index
        ),
        TaskInstanceState.FAILED,
    )

def _handle_trigger_dag_run(
    drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """Handle exception from TriggerDagRunOperator."""
    log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
    comms_msg = SUPERVISOR_COMMS.send(
        TriggerDagRun(
            dag_id=drte.trigger_dag_id,
            run_id=drte.dag_run_id,
            logical_date=drte.logical_date,
            conf=drte.conf,
            reset_dag_run=drte.reset_dag_run,
        ),
    )

    if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
        if drte.skip_when_already_exists:
            log.info(
                "Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
                dag_id=drte.trigger_dag_id,
            )
            msg = TaskState(
                state=TaskInstanceState.SKIPPED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.SKIPPED
        else:
            log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
            msg = TaskState(
                state=TaskInstanceState.FAILED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.FAILED

        return msg, state

    log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)

    # Store the run id from the dag run (either created or found above) to
    # be used when creating the extra link on the webserver.
    ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)

    if drte.wait_for_completion:
        if drte.deferrable:
            from airflow.providers.standard.triggers.external_task import DagStateTrigger

            defer = TaskDeferred(
                trigger=DagStateTrigger(
                    dag_id=drte.trigger_dag_id,
                    states=drte.allowed_states + drte.failed_states,  # type: ignore[arg-type]
                    # Don't filter by execution_dates when run_ids is provided.
                    # run_id uniquely identifies a DAG run, and when reset_dag_run=True,
                    # drte.logical_date might be a newly calculated value that doesn't match
                    # the persisted logical_date in the database, causing the trigger to never find the run.
                    execution_dates=None,
                    run_ids=[drte.dag_run_id],
                    poll_interval=drte.poke_interval,
                ),
                method_name="execute_complete",
            )
            return _defer_task(defer, ti, log)
        while True:
            log.info(
                "Waiting for dag run to complete execution in allowed state.",
                dag_id=drte.trigger_dag_id,
                run_id=drte.dag_run_id,
                allowed_state=drte.allowed_states,
            )
            time.sleep(drte.poke_interval)

            comms_msg = SUPERVISOR_COMMS.send(
                GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
            )
            if TYPE_CHECKING:
                assert isinstance(comms_msg, DagRunStateResult)
            if comms_msg.state in drte.failed_states:
                log.error(
                    "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                msg = TaskState(
                    state=TaskInstanceState.FAILED,
                    end_date=datetime.now(tz=timezone.utc),
                    rendered_map_index=ti.rendered_map_index,
                )
                state = TaskInstanceState.FAILED
                return msg, state
            if comms_msg.state in drte.allowed_states:
                log.info(
                    "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                break
            log.debug(
                "DagRun not yet in allowed or failed state.",
                dag_id=drte.trigger_dag_id,
                state=comms_msg.state,
            )
    else:
        # Fire-and-forget mode: wait_for_completion=False
        if drte.deferrable:
            log.info(
                "Ignoring deferrable=True because wait_for_completion=False. "
                "Task will complete immediately without waiting for the triggered DAG run.",
                trigger_dag_id=drte.trigger_dag_id,
            )

    return _handle_current_task_success(context, ti)

def _run_task_state_change_callbacks(
    task: BaseOperator,
    kind: Literal[
        "on_execute_callback",
        "on_failure_callback",
        "on_success_callback",
        "on_retry_callback",
        "on_skipped_callback",
    ],
    context: Context,
    log: Logger,
) -> None:
    callback: Callable[[Context], None]
    for i, callback in enumerate(getattr(task, kind)):
        try:
            create_executable_runner(callback, context_get_outlet_events(context), logger=log).run(context)
        except Exception:
            log.exception("Failed to run task callback", kind=kind, index=i, callback=callback)


def _send_error_email_notification(
    task: BaseOperator | MappedOperator,
    ti: RuntimeTaskInstance,
    context: Context,
    error: BaseException | str | None,
    log: Logger,
) -> None:
    """Send email notification for task errors using SmtpNotifier."""
    try:
        from airflow.providers.smtp.notifications.smtp import SmtpNotifier
    except ImportError:
        log.error(
            "Failed to send task failure or retry email notification: "
            "`apache-airflow-providers-smtp` is not installed. "
            "Install this provider to enable email notifications."
        )
        return

    if not task.email:
        return

    subject_template_file = conf.get("email", "subject_template", fallback=None)

    # Read the template file if configured
    if subject_template_file and Path(subject_template_file).exists():
        subject = Path(subject_template_file).read_text()
    else:
        # Fallback to default
        subject = "Airflow alert: {{ti}}"

    html_content_template_file = conf.get("email", "html_content_template", fallback=None)

    # Read the template file if configured
    if html_content_template_file and Path(html_content_template_file).exists():
        html_content = Path(html_content_template_file).read_text()
    else:
        # Fallback to default
        # For reporting purposes, we report based on 1-indexed,
        # not 0-indexed lists (i.e. Try 1 instead of Try 0 for the first attempt).
        html_content = (
            "Try {{try_number}} out of {{max_tries + 1}}<br>"
            "Exception:<br>{{exception_html}}<br>"
            'Log: <a href="{{ti.log_url}}">Link</a><br>'
            "Host: {{ti.hostname}}<br>"
            'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
        )

    # Add exception_html to the context for template rendering
    import html

    exception_html = html.escape(str(error)).replace("\n", "<br>")
    additional_context = {
        "exception": error,
        "exception_html": exception_html,
        "try_number": ti.try_number,
        "max_tries": ti.max_tries,
    }
    email_context = {**context, **additional_context}
    to_emails = task.email
    if not to_emails:
        return

    try:
        notifier = SmtpNotifier(
            to=to_emails,
            subject=subject,
            html_content=html_content,
            from_email=conf.get("email", "from_email", fallback="airflow@airflow"),
        )
        notifier(email_context)
    except Exception:
        log.exception("Failed to send email notification")


1447def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): 

1448 """Execute Task (optionally with a Timeout) and push Xcom results.""" 

1449 task = ti.task 

1450 execute = task.execute 

1451 

1452 if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method): 

1453 from airflow.sdk.serde import deserialize 

1454 

1455 next_kwargs_data = ti._ti_context_from_server.next_kwargs or {} 

1456 try: 

1457 if TYPE_CHECKING: 

1458 assert isinstance(next_kwargs_data, dict) 

1459 kwargs = deserialize(next_kwargs_data) 

1460 except (ImportError, KeyError, AttributeError, TypeError): 

1461 from airflow.serialization.serialized_objects import BaseSerialization 

1462 

1463 kwargs = BaseSerialization.deserialize(next_kwargs_data) 

1464 

1465 if TYPE_CHECKING: 

1466 assert isinstance(kwargs, dict) 

1467 execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs) 

1468 

1469 ctx = contextvars.copy_context() 

1470 # Populate the context var so ExecutorSafeguard doesn't complain 

1471 ctx.run(ExecutorSafeguard.tracker.set, task) 

1472 


    # Export the context to os.environ so it is available for operators to use.
    airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
    os.environ.update(airflow_context_vars)

    outlet_events = context_get_outlet_events(context)
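
    # Note (editor): the ``getattr(..., "__func__") is not BaseOperator.pre_execute``
    # comparison below detects whether a subclass overrode the hook: a bound method's
    # ``__func__`` is the underlying function, so it differs from the base-class
    # attribute exactly when the operator defines its own pre_execute/post_execute.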

    if (pre_execute_hook := task._pre_execute_hook) is not None:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
    if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)

    _run_task_state_change_callbacks(task, "on_execute_callback", context, log)

    if task.execution_timeout:
        from airflow.sdk.execution_time.timeout import timeout

        # TODO: handle timeout in case of deferral
        timeout_seconds = task.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run task in timeout wrapper
            with timeout(timeout_seconds):
                result = ctx.run(execute, context=context)
        except AirflowTaskTimeout:
            task.on_kill()
            raise
    else:
        result = ctx.run(execute, context=context)


    if (post_execute_hook := task._post_execute_hook) is not None:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
    if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute:
        # Forward the result as well, matching BaseOperator.post_execute(context, result).
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)

    return result
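
# --- Illustrative sketch (editor addition; the real helper is
# airflow.sdk.execution_time.timeout.timeout) ---
# A generic, Unix-only SIGALRM-based context manager of the same shape as the
# timeout wrapper used in _execute_task above:
def _example_timeout_sketch():
    import signal
    from contextlib import contextmanager

    @contextmanager
    def timeout(seconds: float):
        def _handler(signum, frame):
            raise TimeoutError(f"timed out after {seconds}s")

        old_handler = signal.signal(signal.SIGALRM, _handler)
        signal.setitimer(signal.ITIMER_REAL, seconds)
        try:
            yield
        finally:
            # Always disarm the timer and restore the previous handler.
            signal.setitimer(signal.ITIMER_REAL, 0)
            signal.signal(signal.SIGALRM, old_handler)

    with timeout(1.0):
        pass  # work that must complete within one second
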

def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None:
    """Render named map index if the Dag author defined map_index_template at the task level."""
    if (template := context.get("map_index_template")) is None:
        return None
    log.debug("Rendering map_index_template", template_length=len(template))
    jinja_env = ti.task.dag.get_template_env()
    rendered_map_index = jinja_env.from_string(template).render(context)
    log.debug("Map index rendered", length=len(rendered_map_index))
    return rendered_map_index
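
# --- Illustrative sketch (editor addition) ---
# What _render_map_index does, with plain jinja2 standing in for the DAG's
# template environment and a hypothetical map_index_template:
def _example_render_map_index() -> str:
    import jinja2

    template = "{{ region }}-{{ ds }}"  # hypothetical task-level map_index_template
    context = {"region": "eu-west-1", "ds": "2024-01-01"}
    return jinja2.Environment().from_string(template).render(context)
    # -> "eu-west-1-2024-01-01"
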

def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
    """Push XCom values when the task has ``do_xcom_push`` set to ``True`` and returns a result."""
    if ti.task.do_xcom_push:
        xcom_value = result
    else:
        xcom_value = None

    has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None
    if xcom_value is None:
        if not ti.is_mapped and has_mapped_dep:
            # Uh-oh, a downstream mapped task depends on us pushing something to map over
            from airflow.sdk.exceptions import XComForMappingNotPushed

            raise XComForMappingNotPushed()
        return

    mapped_length: int | None = None
    if not ti.is_mapped and has_mapped_dep:
        from airflow.sdk.definitions.mappedoperator import is_mappable_value
        from airflow.sdk.exceptions import UnmappableXComTypePushed

        if not is_mappable_value(xcom_value):
            raise UnmappableXComTypePushed(xcom_value)
        mapped_length = len(xcom_value)

    log.info("Pushing xcom", ti=ti)

    # If the task has multiple outputs, push each output as a separate XCom.
    if ti.task.multiple_outputs:
        if not isinstance(xcom_value, Mapping):
            raise TypeError(
                f"Returned output was type {type(xcom_value)}, expected dictionary for multiple_outputs"
            )
        for key in xcom_value.keys():
            if not isinstance(key, str):
                raise TypeError(
                    "Returned dictionary keys must be strings when using "
                    f"multiple_outputs, found {key} ({type(key)}) instead"
                )
        for k, v in result.items():
            ti.xcom_push(k, v)

    _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length)
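
# --- Illustrative sketch (editor addition) ---
# The multiple_outputs branch above pushes one XCom per dictionary key *and*
# the whole dict under the return key. A self-contained imitation with a fake
# xcom_push and hypothetical result values:
def _example_multiple_outputs_push() -> dict[str, object]:
    pushed: dict[str, object] = {}

    def xcom_push(key: str, value: object) -> None:
        pushed[key] = value

    result = {"rows": 42, "path": "s3://bucket/out.parquet"}
    for k, v in result.items():  # one XCom per key
        xcom_push(k, v)
    xcom_push("return_value", result)  # plus the full dict under the return key
    return pushed  # {"rows": 42, "path": "s3://...", "return_value": {...}}
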

def finalize(
    ti: RuntimeTaskInstance,
    state: TaskInstanceState,
    context: Context,
    log: Logger,
    error: BaseException | None = None,
):
    # Record task duration metrics for all terminal states
    if ti.start_date and ti.end_date:
        duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000
        stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

        Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms)
        Stats.timing("task.duration", duration_ms, tags=stats_tags)
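
    # Note (editor): the duration appears to be emitted twice by design: once under
    # the dotted per-task name ``dag.<dag_id>.<task_id>.duration`` and once as the
    # tagged ``task.duration`` metric, so both name-based and tag-aware statsd
    # setups can consume it.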


    task = ti.task
    # Push an XCom for each operator extra link defined on the operator only.
    for oe in task.operator_extra_links:
        try:
            link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key  # type: ignore[arg-type]
            log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
            _xcom_push_to_db(ti, key=xcom_key, value=link)
        except Exception:
            log.exception(
                "Failed to push an xcom for task operator extra link",
                link_name=oe.name,
                xcom_key=oe.xcom_key,
                ti=ti,
            )

    if getattr(ti.task, "overwrite_rtif_after_execution", False):
        log.debug("Overwriting rendered template fields.")
        if ti.task.template_fields:
            SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)))

    log.debug("Running finalizers", ti=ti)
    if state == TaskInstanceState.SUCCESS:
        _run_task_state_change_callbacks(task, "on_success_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_success(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.SKIPPED:
        _run_task_state_change_callbacks(task, "on_skipped_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_skipped(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.UP_FOR_RETRY:
        _run_task_state_change_callbacks(task, "on_retry_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_retry and task.email:
            _send_error_email_notification(task, ti, context, error, log)
    elif state == TaskInstanceState.FAILED:
        _run_task_state_change_callbacks(task, "on_failure_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_failure and task.email:
            _send_error_email_notification(task, ti, context, error, log)

    try:
        get_listener_manager().hook.before_stopping(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")


def main():
    log = structlog.get_logger(logger_name="task")

    global SUPERVISOR_COMMS
    SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)

    try:
        ti, context, log = startup()
        with BundleVersionLock(
            bundle_name=ti.bundle_instance.name,
            bundle_version=ti.bundle_instance.version,
        ):
            state, _, error = run(ti, context, log)
            context["exception"] = error
            finalize(ti, state, context, log, error)
    except KeyboardInterrupt:
        log.exception("Ctrl-c hit")
        exit(2)
    except Exception:
        log.exception("Top level error")
        exit(1)
    finally:
        # Ensure the request socket is closed on the child side in all circumstances
        # before the process fully terminates.
        if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket:
            with suppress(Exception):
                SUPERVISOR_COMMS.socket.close()


def reinit_supervisor_comms() -> None:
    """
    Re-initialize supervisor comms and the logging channel in a subprocess.

    This is not needed in most cases. It is used when we re-launch the process via sudo for
    run_as_user, or from inside the Python code of a virtualenv (et al.) operator, to re-connect
    so those tasks can continue to access variables etc.
    """
    import socket

    if "SUPERVISOR_COMMS" not in globals():
        global SUPERVISOR_COMMS
        log = structlog.get_logger(logger_name="task")

        fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))

        SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd))

    logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
    if isinstance(logs, SentFDs):
        from airflow.sdk.log import configure_logging

        log_io = os.fdopen(logs.fds[0], "wb", buffering=0)
        configure_logging(json_output=True, output=log_io, sending_to_supervisor=True)
    else:
        print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)
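
# --- Illustrative sketch (editor addition) ---
# How adopting an inherited file descriptor works: ``socket.socket(fileno=...)``
# wraps an already-open FD, which is what reinit_supervisor_comms does with the
# FD passed via __AIRFLOW_SUPERVISOR_FD. Here a socketpair stands in for the
# supervisor end:
def _example_adopt_inherited_fd() -> bytes:
    import os
    import socket

    parent, child = socket.socketpair()
    adopted = socket.socket(fileno=os.dup(child.fileno()))  # adopt a duplicate of the FD
    child.close()
    adopted.sendall(b"ping")
    data = parent.recv(4)  # -> b"ping"
    adopted.close()
    parent.close()
    return data
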

if __name__ == "__main__":
    main()