Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/execution_time/task_runner.py: 15%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

741 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""The entrypoint for the actual task execution process.""" 

19 

20from __future__ import annotations 

21 

22import contextlib 

23import contextvars 

24import functools 

25import os 

26import sys 

27import time 

28from collections.abc import Callable, Iterable, Iterator, Mapping 

29from contextlib import suppress 

30from datetime import datetime, timezone 

31from itertools import product 

32from pathlib import Path 

33from typing import TYPE_CHECKING, Annotated, Any, Literal 

34from urllib.parse import quote 

35 

36import attrs 

37import lazy_object_proxy 

38import structlog 

39from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter 

40 

41from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock 

42from airflow.dag_processing.bundles.manager import DagBundlesManager 

43from airflow.listeners.listener import get_listener_manager 

44from airflow.sdk.api.client import get_hostname, getuser 

45from airflow.sdk.api.datamodels._generated import ( 

46 AssetProfile, 

47 DagRun, 

48 PreviousTIResponse, 

49 TaskInstance, 

50 TaskInstanceState, 

51 TIRunContext, 

52) 

53from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard 

54from airflow.sdk.bases.xcom import BaseXCom 

55from airflow.sdk.configuration import conf 

56from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager 

57from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set 

58from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef 

59from airflow.sdk.definitions.mappedoperator import MappedOperator 

60from airflow.sdk.definitions.param import process_params 

61from airflow.sdk.exceptions import ( 

62 AirflowException, 

63 AirflowInactiveAssetInInletOrOutletException, 

64 AirflowRuntimeError, 

65 AirflowTaskTimeout, 

66 ErrorType, 

67 TaskDeferred, 

68) 

69from airflow.sdk.execution_time.callback_runner import create_executable_runner 

70from airflow.sdk.execution_time.comms import ( 

71 AssetEventDagRunReferenceResult, 

72 CommsDecoder, 

73 DagRunStateResult, 

74 DeferTask, 

75 DRCount, 

76 ErrorResponse, 

77 GetDagRunState, 

78 GetDRCount, 

79 GetPreviousDagRun, 

80 GetPreviousTI, 

81 GetTaskBreadcrumbs, 

82 GetTaskRescheduleStartDate, 

83 GetTaskStates, 

84 GetTICount, 

85 InactiveAssetsResult, 

86 PreviousDagRunResult, 

87 PreviousTIResult, 

88 RescheduleTask, 

89 ResendLoggingFD, 

90 RetryTask, 

91 SentFDs, 

92 SetRenderedFields, 

93 SetRenderedMapIndex, 

94 SkipDownstreamTasks, 

95 StartupDetails, 

96 SucceedTask, 

97 TaskBreadcrumbsResult, 

98 TaskRescheduleStartDate, 

99 TaskState, 

100 TaskStatesResult, 

101 TICount, 

102 ToSupervisor, 

103 ToTask, 

104 TriggerDagRun, 

105 ValidateInletsAndOutlets, 

106) 

107from airflow.sdk.execution_time.context import ( 

108 ConnectionAccessor, 

109 InletEventsAccessors, 

110 MacrosAccessor, 

111 OutletEventAccessors, 

112 TriggeringAssetEventsAccessor, 

113 VariableAccessor, 

114 context_get_outlet_events, 

115 context_to_airflow_vars, 

116 get_previous_dagrun_success, 

117 set_current_context, 

118) 

119from airflow.sdk.execution_time.sentry import Sentry 

120from airflow.sdk.execution_time.xcom import XCom 

121from airflow.sdk.observability.stats import Stats 

122from airflow.sdk.timezone import coerce_datetime 

123 

124if TYPE_CHECKING: 

125 import jinja2 

126 from pendulum.datetime import DateTime 

127 from structlog.typing import FilteringBoundLogger as Logger 

128 

129 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator 

130 from airflow.sdk.definitions.context import Context 

131 from airflow.sdk.exceptions import DagRunTriggerException 

132 from airflow.sdk.types import OutletEventAccessorsProtocol 

133 

134 

135class TaskRunnerMarker: 

136 """Marker for listener hooks, to properly detect from which component they are called.""" 

137 

138 

139# TODO: Move this entire class into a separate file: 

140# `airflow/sdk/execution_time/task_instance.py` 

141# or `airflow/sdk/execution_time/runtime_ti.py` 

142class RuntimeTaskInstance(TaskInstance): 

143 model_config = ConfigDict(arbitrary_types_allowed=True) 

144 

145 task: BaseOperator 

146 bundle_instance: BaseDagBundle 

147 _cached_template_context: Context | None = None 

148 """The Task Instance context. This is used to cache get_template_context.""" 

149 

150 _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None 

151 """The Task Instance context from the API server, if any.""" 

152 

153 max_tries: int = 0 

154 """The maximum number of retries for the task.""" 

155 

156 start_date: AwareDatetime 

157 """Start date of the task instance.""" 

158 

159 end_date: AwareDatetime | None = None 

160 

161 state: TaskInstanceState | None = None 

162 

163 is_mapped: bool | None = None 

164 """True if the original task was mapped.""" 

165 

166 rendered_map_index: str | None = None 

167 

168 sentry_integration: str = "" 

169 

170 def __rich_repr__(self): 

171 yield "id", self.id 

172 yield "task_id", self.task_id 

173 yield "dag_id", self.dag_id 

174 yield "run_id", self.run_id 

175 yield "max_tries", self.max_tries 

176 yield "task", type(self.task) 

177 yield "start_date", self.start_date 

178 

179 __rich_repr__.angular = True # type: ignore[attr-defined] 

180 

181 def get_template_context(self) -> Context: 

182 # TODO: Move this to `airflow.sdk.execution_time.context` 

183 # once we port the entire context logic from airflow/utils/context.py ? 

184 from airflow.plugins_manager import integrate_macros_plugins 

185 

186 integrate_macros_plugins() 

187 

188 dag_run_conf: dict[str, Any] | None = None 

189 if from_server := self._ti_context_from_server: 

190 dag_run_conf = from_server.dag_run.conf or dag_run_conf 

191 

192 validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False) 

193 

194 # Cache the context object, which ensures that all calls to get_template_context 

195 # are operating on the same context object. 

196 self._cached_template_context: Context = self._cached_template_context or { 

197 # From the Task Execution interface 

198 "dag": self.task.dag, 

199 "inlets": self.task.inlets, 

200 "map_index_template": self.task.map_index_template, 

201 "outlets": self.task.outlets, 

202 "run_id": self.run_id, 

203 "task": self.task, 

204 "task_instance": self, 

205 "ti": self, 

206 "outlet_events": OutletEventAccessors(), 

207 "inlet_events": InletEventsAccessors(self.task.inlets), 

208 "macros": MacrosAccessor(), 

209 "params": validated_params, 

210 # TODO: Make this go through Public API longer term. 

211 # "test_mode": task_instance.test_mode, 

212 "var": { 

213 "json": VariableAccessor(deserialize_json=True), 

214 "value": VariableAccessor(deserialize_json=False), 

215 }, 

216 "conn": ConnectionAccessor(), 

217 } 

218 if from_server: 

219 dag_run = from_server.dag_run 

220 context_from_server: Context = { 

221 # TODO: Assess if we need to pass these through timezone.coerce_datetime 

222 "dag_run": dag_run, # type: ignore[typeddict-item] # Removable after #46522 

223 "triggering_asset_events": TriggeringAssetEventsAccessor.build( 

224 AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event) 

225 for event in dag_run.consumed_asset_events 

226 ), 

227 "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}", 

228 "task_reschedule_count": from_server.task_reschedule_count or 0, 

229 "prev_start_date_success": lazy_object_proxy.Proxy( 

230 lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date) 

231 ), 

232 "prev_end_date_success": lazy_object_proxy.Proxy( 

233 lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date) 

234 ), 

235 } 

236 self._cached_template_context.update(context_from_server) 

237 

238 if logical_date := coerce_datetime(dag_run.logical_date): 

239 if TYPE_CHECKING: 

240 assert isinstance(logical_date, DateTime) 

241 ds = logical_date.strftime("%Y-%m-%d") 

242 ds_nodash = ds.replace("-", "") 

243 ts = logical_date.isoformat() 

244 ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S") 

245 ts_nodash_with_tz = ts.replace("-", "").replace(":", "") 

246 # logical_date and data_interval either coexist or be None together 

247 self._cached_template_context.update( 

248 { 

249 # keys that depend on logical_date 

250 "logical_date": logical_date, 

251 "ds": ds, 

252 "ds_nodash": ds_nodash, 

253 "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}", 

254 "ts": ts, 

255 "ts_nodash": ts_nodash, 

256 "ts_nodash_with_tz": ts_nodash_with_tz, 

257 # keys that depend on data_interval 

258 "data_interval_end": coerce_datetime(dag_run.data_interval_end), 

259 "data_interval_start": coerce_datetime(dag_run.data_interval_start), 

260 "prev_data_interval_start_success": lazy_object_proxy.Proxy( 

261 lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start) 

262 ), 

263 "prev_data_interval_end_success": lazy_object_proxy.Proxy( 

264 lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end) 

265 ), 

266 } 

267 ) 

268 

269 if from_server.upstream_map_indexes is not None: 

270 # We stash this in here for later use, but we purposefully don't want to document it's 

271 # existence. Should this be a private attribute on RuntimeTI instead perhaps? 

272 setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes) 

273 

274 return self._cached_template_context 

275 

276 def render_templates( 

277 self, context: Context | None = None, jinja_env: jinja2.Environment | None = None 

278 ) -> BaseOperator: 

279 """ 

280 Render templates in the operator fields. 

281 

282 If the task was originally mapped, this may replace ``self.task`` with 

283 the unmapped, fully rendered BaseOperator. The original ``self.task`` 

284 before replacement is returned. 

285 """ 

286 if not context: 

287 context = self.get_template_context() 

288 original_task = self.task 

289 

290 if TYPE_CHECKING: 

291 assert context 

292 

293 ti = context["ti"] 

294 

295 if TYPE_CHECKING: 

296 assert original_task 

297 assert self.task 

298 assert ti.task 

299 

300 # If self.task is mapped, this call replaces self.task to point to the 

301 # unmapped BaseOperator created by this function! This is because the 

302 # MappedOperator is useless for template rendering, and we need to be 

303 # able to access the unmapped task instead. 

304 self.task.render_template_fields(context, jinja_env) 

305 self.is_mapped = original_task.is_mapped 

306 return original_task 

307 

308 def xcom_pull( 

309 self, 

310 task_ids: str | Iterable[str] | None = None, 

311 dag_id: str | None = None, 

312 key: str = BaseXCom.XCOM_RETURN_KEY, 

313 include_prior_dates: bool = False, 

314 *, 

315 map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET, 

316 default: Any = None, 

317 run_id: str | None = None, 

318 ) -> Any: 

319 """ 

320 Pull XComs either from the API server (BaseXCom) or from the custom XCOM backend if configured. 

321 

322 The pull can be filtered optionally by certain criterion. 

323 

324 :param key: A key for the XCom. If provided, only XComs with matching 

325 keys will be returned. The default key is ``'return_value'``, also 

326 available as constant ``XCOM_RETURN_KEY``. This key is automatically 

327 given to XComs returned by tasks (as opposed to being pushed 

328 manually). To remove the filter, pass *None*. 

329 :param task_ids: Only XComs from tasks with matching ids will be 

330 pulled. Pass *None* to remove the filter. 

331 :param dag_id: If provided, only pulls XComs from this Dag. If *None* 

332 (default), the Dag of the calling task is used. 

333 :param map_indexes: If provided, only pull XComs with matching indexes. 

334 If *None* (default), this is inferred from the task(s) being pulled 

335 (see below for details). 

336 :param include_prior_dates: If False, only XComs from the current 

337 logical_date are returned. If *True*, XComs from previous dates 

338 are returned as well. 

339 :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id. 

340 If *None* (default), the run_id of the calling task is used. 

341 

342 When pulling one single task (``task_id`` is *None* or a str) without 

343 specifying ``map_indexes``, the return value is a single XCom entry 

344 (map_indexes is set to map_index of the calling task instance). 

345 

346 When pulling task is mapped the specified ``map_index`` is used, so by default 

347 pulling on mapped task will result in no matching XComs if the task instance 

348 of the method call is not mapped. Otherwise, the map_index of the calling task 

349 instance is used. Setting ``map_indexes`` to *None* will pull XCom as it would 

350 from a non mapped task. 

351 

352 In either case, ``default`` (*None* if not specified) is returned if no 

353 matching XComs are found. 

354 

355 When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is 

356 a non-str iterable), a list of matching XComs is returned. Elements in 

357 the list is ordered by item ordering in ``task_id`` and ``map_index``. 

358 """ 

359 if dag_id is None: 

360 dag_id = self.dag_id 

361 if run_id is None: 

362 run_id = self.run_id 

363 

364 single_task_requested = isinstance(task_ids, (str, type(None))) 

365 single_map_index_requested = isinstance(map_indexes, (int, type(None))) 

366 

367 if task_ids is None: 

368 # default to the current task if not provided 

369 task_ids = [self.task_id] 

370 elif isinstance(task_ids, str): 

371 task_ids = [task_ids] 

372 

373 # If map_indexes is not specified, pull xcoms from all map indexes for each task 

374 if not is_arg_set(map_indexes): 

375 xcoms: list[Any] = [] 

376 for t_id in task_ids: 

377 values = XCom.get_all( 

378 run_id=run_id, 

379 key=key, 

380 task_id=t_id, 

381 dag_id=dag_id, 

382 include_prior_dates=include_prior_dates, 

383 ) 

384 

385 if values is None: 

386 xcoms.append(None) 

387 else: 

388 xcoms.extend(values) 

389 # For single task pulling from unmapped task, return single value 

390 if single_task_requested and len(xcoms) == 1: 

391 return xcoms[0] 

392 return xcoms 

393 

394 # Original logic when map_indexes is explicitly specified 

395 map_indexes_iterable: Iterable[int | None] = [] 

396 if isinstance(map_indexes, int) or map_indexes is None: 

397 map_indexes_iterable = [map_indexes] 

398 elif isinstance(map_indexes, Iterable): 

399 map_indexes_iterable = map_indexes 

400 else: 

401 raise TypeError( 

402 f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}" 

403 ) 

404 

405 xcoms = [] 

406 for t_id, m_idx in product(task_ids, map_indexes_iterable): 

407 value = XCom.get_one( 

408 run_id=run_id, 

409 key=key, 

410 task_id=t_id, 

411 dag_id=dag_id, 

412 map_index=m_idx, 

413 include_prior_dates=include_prior_dates, 

414 ) 

415 if value is None: 

416 xcoms.append(default) 

417 else: 

418 xcoms.append(value) 

419 

420 if single_task_requested and single_map_index_requested: 

421 return xcoms[0] 

422 

423 return xcoms 

424 

425 def xcom_push(self, key: str, value: Any): 

426 """ 

427 Make an XCom available for tasks to pull. 

428 

429 :param key: Key to store the value under. 

430 :param value: Value to store. Only be JSON-serializable values may be used. 

431 """ 

432 _xcom_push(self, key, value) 

433 

434 def get_relevant_upstream_map_indexes( 

435 self, upstream: BaseOperator, ti_count: int | None, session: Any 

436 ) -> int | range | None: 

437 # TODO: Implement this method 

438 return None 

439 

440 def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: 

441 """Get the first reschedule date for the task instance if found, none otherwise.""" 

442 if context.get("task_reschedule_count", 0) == 0: 

443 # If the task has not been rescheduled, there is no need to ask the supervisor 

444 return None 

445 

446 max_tries: int = self.max_tries 

447 retries: int = self.task.retries or 0 

448 first_try_number = max_tries - retries + 1 

449 

450 log = structlog.get_logger(logger_name="task") 

451 

452 log.debug("Requesting first reschedule date from supervisor") 

453 

454 response = SUPERVISOR_COMMS.send( 

455 msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) 

456 ) 

457 

458 if TYPE_CHECKING: 

459 assert isinstance(response, TaskRescheduleStartDate) 

460 

461 return response.start_date 

462 

463 def get_previous_dagrun(self, state: str | None = None) -> DagRun | None: 

464 """Return the previous Dag run before the given logical date, optionally filtered by state.""" 

465 context = self.get_template_context() 

466 dag_run = context.get("dag_run") 

467 

468 log = structlog.get_logger(logger_name="task") 

469 

470 log.debug("Getting previous Dag run", dag_run=dag_run) 

471 

472 if dag_run is None: 

473 return None 

474 

475 if dag_run.logical_date is None: 

476 return None 

477 

478 response = SUPERVISOR_COMMS.send( 

479 msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state) 

480 ) 

481 

482 if TYPE_CHECKING: 

483 assert isinstance(response, PreviousDagRunResult) 

484 

485 return response.dag_run 

486 

487 def get_previous_ti( 

488 self, 

489 state: TaskInstanceState | None = None, 

490 logical_date: AwareDatetime | None = None, 

491 map_index: int = -1, 

492 ) -> PreviousTIResponse | None: 

493 """ 

494 Return the previous task instance matching the given criteria. 

495 

496 :param state: Filter by TaskInstance state 

497 :param logical_date: Filter by logical date (returns TI before this date) 

498 :param map_index: Filter by map_index (defaults to -1 for non-mapped tasks) 

499 :return: Previous task instance or None if not found 

500 """ 

501 context = self.get_template_context() 

502 dag_run = context.get("dag_run") 

503 

504 log = structlog.get_logger(logger_name="task") 

505 log.debug("Getting previous task instance", task_id=self.task_id, state=state) 

506 

507 # Use current dag run's logical_date if not provided 

508 effective_logical_date = logical_date 

509 if effective_logical_date is None and dag_run and dag_run.logical_date: 

510 effective_logical_date = dag_run.logical_date 

511 

512 response = SUPERVISOR_COMMS.send( 

513 msg=GetPreviousTI( 

514 dag_id=self.dag_id, 

515 task_id=self.task_id, 

516 logical_date=effective_logical_date, 

517 map_index=map_index, 

518 state=state, 

519 ) 

520 ) 

521 

522 if TYPE_CHECKING: 

523 assert isinstance(response, PreviousTIResult) 

524 

525 return response.task_instance 

526 

527 @staticmethod 

528 def get_ti_count( 

529 dag_id: str, 

530 map_index: int | None = None, 

531 task_ids: list[str] | None = None, 

532 task_group_id: str | None = None, 

533 logical_dates: list[datetime] | None = None, 

534 run_ids: list[str] | None = None, 

535 states: list[str] | None = None, 

536 ) -> int: 

537 """Return the number of task instances matching the given criteria.""" 

538 response = SUPERVISOR_COMMS.send( 

539 GetTICount( 

540 dag_id=dag_id, 

541 map_index=map_index, 

542 task_ids=task_ids, 

543 task_group_id=task_group_id, 

544 logical_dates=logical_dates, 

545 run_ids=run_ids, 

546 states=states, 

547 ), 

548 ) 

549 

550 if TYPE_CHECKING: 

551 assert isinstance(response, TICount) 

552 

553 return response.count 

554 

555 @staticmethod 

556 def get_task_states( 

557 dag_id: str, 

558 map_index: int | None = None, 

559 task_ids: list[str] | None = None, 

560 task_group_id: str | None = None, 

561 logical_dates: list[datetime] | None = None, 

562 run_ids: list[str] | None = None, 

563 ) -> dict[str, Any]: 

564 """Return the task states matching the given criteria.""" 

565 response = SUPERVISOR_COMMS.send( 

566 GetTaskStates( 

567 dag_id=dag_id, 

568 map_index=map_index, 

569 task_ids=task_ids, 

570 task_group_id=task_group_id, 

571 logical_dates=logical_dates, 

572 run_ids=run_ids, 

573 ), 

574 ) 

575 

576 if TYPE_CHECKING: 

577 assert isinstance(response, TaskStatesResult) 

578 

579 return response.task_states 

580 

581 @staticmethod 

582 def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]: 

583 """Return task breadcrumbs for the given dag run.""" 

584 response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id)) 

585 if TYPE_CHECKING: 

586 assert isinstance(response, TaskBreadcrumbsResult) 

587 return response.breadcrumbs 

588 

589 @staticmethod 

590 def get_dr_count( 

591 dag_id: str, 

592 logical_dates: list[datetime] | None = None, 

593 run_ids: list[str] | None = None, 

594 states: list[str] | None = None, 

595 ) -> int: 

596 """Return the number of Dag runs matching the given criteria.""" 

597 response = SUPERVISOR_COMMS.send( 

598 GetDRCount( 

599 dag_id=dag_id, 

600 logical_dates=logical_dates, 

601 run_ids=run_ids, 

602 states=states, 

603 ), 

604 ) 

605 

606 if TYPE_CHECKING: 

607 assert isinstance(response, DRCount) 

608 

609 return response.count 

610 

611 @staticmethod 

612 def get_dagrun_state(dag_id: str, run_id: str) -> str: 

613 """Return the state of the Dag run with the given Run ID.""" 

614 response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) 

615 

616 if TYPE_CHECKING: 

617 assert isinstance(response, DagRunStateResult) 

618 

619 return response.state 

620 

621 @property 

622 def log_url(self) -> str: 

623 run_id = quote(self.run_id) 

624 base_url = conf.get("api", "base_url", fallback="http://localhost:8080/") 

625 map_index_value = self.map_index 

626 map_index = ( 

627 f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else "" 

628 ) 

629 try_number_value = self.try_number 

630 try_number = ( 

631 f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else "" 

632 ) 

633 _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}" 

634 return _log_uri 

635 

636 @property 

637 def mark_success_url(self) -> str: 

638 """URL to mark TI success.""" 

639 return self.log_url 

640 

641 

642def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None: 

643 """Push a XCom through XCom.set, which pushes to XCom Backend if configured.""" 

644 # Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK 

645 # consumers 

646 

647 XCom.set( 

648 key=key, 

649 value=value, 

650 dag_id=ti.dag_id, 

651 task_id=ti.task_id, 

652 run_id=ti.run_id, 

653 map_index=ti.map_index, 

654 _mapped_length=mapped_length, 

655 ) 

656 

657 

658def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None: 

659 """Push a XCom directly to metadata DB, bypassing custom xcom_backend.""" 

660 XCom._set_xcom_in_db( 

661 key=key, 

662 value=value, 

663 dag_id=ti.dag_id, 

664 task_id=ti.task_id, 

665 run_id=ti.run_id, 

666 map_index=ti.map_index, 

667 ) 

668 

669 

670def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: 

671 # TODO: Task-SDK: 

672 # Using DagBag here is about 98% wrong, but it'll do for now 

673 from airflow.dag_processing.dagbag import DagBag 

674 

675 bundle_info = what.bundle_info 

676 bundle_instance = DagBundlesManager().get_bundle( 

677 name=bundle_info.name, 

678 version=bundle_info.version, 

679 ) 

680 bundle_instance.initialize() 

681 

682 # Put bundle root on sys.path if needed. This allows the dag bundle to add 

683 # code in util modules to be shared between files within the same bundle. 

684 if (bundle_root := os.fspath(bundle_instance.path)) not in sys.path: 

685 sys.path.append(bundle_root) 

686 

687 dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path)) 

688 bag = DagBag( 

689 dag_folder=dag_absolute_path, 

690 include_examples=False, 

691 safe_mode=False, 

692 load_op_links=False, 

693 bundle_name=bundle_info.name, 

694 ) 

695 if TYPE_CHECKING: 

696 assert what.ti.dag_id 

697 

698 try: 

699 dag = bag.dags[what.ti.dag_id] 

700 except KeyError: 

701 log.error( 

702 "Dag not found during start up", dag_id=what.ti.dag_id, bundle=bundle_info, path=what.dag_rel_path 

703 ) 

704 sys.exit(1) 

705 

706 # install_loader() 

707 

708 try: 

709 task = dag.task_dict[what.ti.task_id] 

710 except KeyError: 

711 log.error( 

712 "Task not found in Dag during start up", 

713 dag_id=dag.dag_id, 

714 task_id=what.ti.task_id, 

715 bundle=bundle_info, 

716 path=what.dag_rel_path, 

717 ) 

718 sys.exit(1) 

719 

720 if not isinstance(task, (BaseOperator, MappedOperator)): 

721 raise TypeError( 

722 f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}" 

723 ) 

724 

725 return RuntimeTaskInstance.model_construct( 

726 **what.ti.model_dump(exclude_unset=True), 

727 task=task, 

728 bundle_instance=bundle_instance, 

729 _ti_context_from_server=what.ti_context, 

730 max_tries=what.ti_context.max_tries, 

731 start_date=what.start_date, 

732 state=TaskInstanceState.RUNNING, 

733 sentry_integration=what.sentry_integration, 

734 ) 

735 

736 

737# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution, 

738# to send requests back to the supervisor process. 

739# 

740# Why it needs to be a global: 

741# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests 

742# to the parent process during task execution. 

743# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the 

744# deeply nested execution stack. 

745# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily 

746# accessible wherever needed during task execution without modifying every layer of the call stack. 

747SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] 

748 

749 

750# State machine! 

751# 1. Start up (receive details from supervisor) 

752# 2. Execution (run task code, possibly send requests) 

753# 3. Shutdown and report status 

754 

755 

756def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: 

757 # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent 

758 # in response to us sending a request. 

759 log = structlog.get_logger(logger_name="task") 

760 

761 if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and ( 

762 msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG") 

763 ): 

764 # Clear any Kerberos replace cache if there is one, so new process can't reuse it. 

765 os.environ.pop("KRB5CCNAME", None) 

766 # entrypoint of re-exec process 

767 

768 msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson) 

769 reinit_supervisor_comms() 

770 

771 # We delay this message until _after_ we've got the logging re-configured, otherwise it will show up 

772 # on stdout 

773 log.debug("Using serialized startup message from environment", msg=msg) 

774 else: 

775 # normal entry point 

776 msg = SUPERVISOR_COMMS._get_response() # type: ignore[assignment] 

777 

778 if not isinstance(msg, StartupDetails): 

779 raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") 

780 

781 # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 

782 os_type = sys.platform 

783 if os_type == "darwin": 

784 log.debug("Mac OS detected, skipping setproctitle") 

785 else: 

786 from setproctitle import setproctitle 

787 

788 setproctitle(f"airflow worker -- {msg.ti.id}") 

789 

790 try: 

791 get_listener_manager().hook.on_starting(component=TaskRunnerMarker()) 

792 except Exception: 

793 log.exception("error calling listener") 

794 

795 with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): 

796 ti = parse(msg, log) 

797 log.debug("Dag file parsed", file=msg.dag_rel_path) 

798 

799 run_as_user = getattr(ti.task, "run_as_user", None) or conf.get( 

800 "core", "default_impersonation", fallback=None 

801 ) 

802 

803 if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser(): 

804 # enters here for re-exec process 

805 os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1" 

806 # store startup message in environment for re-exec process 

807 os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json() 

808 os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True) 

809 

810 # Import main directly from the module instead of re-executing the file. 

811 # This ensures that when other parts modules import 

812 # airflow.sdk.execution_time.task_runner, they get the same module instance 

813 # with the properly initialized SUPERVISOR_COMMS global variable. 

814 # If we re-executed the module with `python -m`, it would load as __main__ and future 

815 # imports would get a fresh copy without the initialized globals. 

816 rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()" 

817 cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code] 

818 log.info( 

819 "Running command", 

820 command=cmd, 

821 ) 

822 os.execvp("sudo", cmd) 

823 

824 # ideally, we should never reach here, but if we do, we should return None, None, None 

825 return None, None, None 

826 

827 return ti, ti.get_template_context(), log 

828 

829 

830def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float: 

831 """ 

832 Return a serializable representation of the templated field. 

833 

834 If ``templated_field`` contains a class or instance that requires recursive 

835 templating, store them as strings. Otherwise simply return the field as-is. 

836 

837 Used sdk secrets masker to redact secrets in the serialized output. 

838 """ 

839 import json 

840 

841 from airflow.sdk._shared.secrets_masker import redact 

842 

843 def is_jsonable(x): 

844 try: 

845 json.dumps(x) 

846 except (TypeError, OverflowError): 

847 return False 

848 else: 

849 return True 

850 

851 def translate_tuples_to_lists(obj: Any): 

852 """Recursively convert tuples to lists.""" 

853 if isinstance(obj, tuple): 

854 return [translate_tuples_to_lists(item) for item in obj] 

855 if isinstance(obj, list): 

856 return [translate_tuples_to_lists(item) for item in obj] 

857 if isinstance(obj, dict): 

858 return {key: translate_tuples_to_lists(value) for key, value in obj.items()} 

859 return obj 

860 

861 def sort_dict_recursively(obj: Any) -> Any: 

862 """Recursively sort dictionaries to ensure consistent ordering.""" 

863 if isinstance(obj, dict): 

864 return {k: sort_dict_recursively(v) for k, v in sorted(obj.items())} 

865 if isinstance(obj, list): 

866 return [sort_dict_recursively(item) for item in obj] 

867 if isinstance(obj, tuple): 

868 return tuple(sort_dict_recursively(item) for item in obj) 

869 return obj 

870 

871 max_length = conf.getint("core", "max_templated_field_length") 

872 

873 if not is_jsonable(template_field): 

874 try: 

875 serialized = template_field.serialize() 

876 except AttributeError: 

877 serialized = str(template_field) 

878 if len(serialized) > max_length: 

879 rendered = redact(serialized, name) 

880 return ( 

881 "Truncated. You can change this behaviour in [core]max_templated_field_length. " 

882 f"{rendered[: max_length - 79]!r}... " 

883 ) 

884 return serialized 

885 if not template_field and not isinstance(template_field, tuple): 

886 # Avoid unnecessary serialization steps for empty fields unless they are tuples 

887 # and need to be converted to lists 

888 return template_field 

889 template_field = translate_tuples_to_lists(template_field) 

890 # Sort dictionaries recursively to ensure consistent string representation 

891 # This prevents hash inconsistencies when dict ordering varies 

892 if isinstance(template_field, dict): 

893 template_field = sort_dict_recursively(template_field) 

894 serialized = str(template_field) 

895 if len(serialized) > max_length: 

896 rendered = redact(serialized, name) 

897 return ( 

898 "Truncated. You can change this behaviour in [core]max_templated_field_length. " 

899 f"{rendered[: max_length - 79]!r}... " 

900 ) 

901 return template_field 

902 

903 

904def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]: 

905 from airflow.sdk._shared.secrets_masker import redact 

906 

907 rendered_fields = {} 

908 for field in task.template_fields: 

909 value = getattr(task, field) 

910 serialized = _serialize_template_field(value, field) 

911 # Redact secrets in the task process itself before sending to API server 

912 # This ensures that the secrets those are registered via mask_secret() on workers / dag processor are properly masked 

913 # on the UI. 

914 rendered_fields[field] = redact(serialized, field) 

915 

916 return rendered_fields # type: ignore[return-value] # Convince mypy that this is OK since we pass JsonValue to redact, so it will return the same 

917 

918 

919def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]: 

920 # Lineage can have other types of objects besides assets, so we need to process them a bit. 

921 for obj in lineage_objects or (): 

922 if isinstance(obj, Asset): 

923 yield AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__) 

924 elif isinstance(obj, AssetNameRef): 

925 yield AssetProfile(name=obj.name, type=AssetNameRef.__name__) 

926 elif isinstance(obj, AssetUriRef): 

927 yield AssetProfile(uri=obj.uri, type=AssetUriRef.__name__) 

928 elif isinstance(obj, AssetAlias): 

929 yield AssetProfile(name=obj.name, type=AssetAlias.__name__) 

930 

931 

932def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]: 

933 if TYPE_CHECKING: 

934 assert isinstance(events, OutletEventAccessors) 

935 # We just collect everything the user recorded in the accessors. 

936 # Further filtering will be done in the API server. 

937 for key, accessor in events._dict.items(): 

938 if isinstance(key, AssetUniqueKey): 

939 yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra} 

940 for alias_event in accessor.asset_alias_events: 

941 yield attrs.asdict(alias_event) 

942 

943 

944def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None: 

945 ti.hostname = get_hostname() 

946 ti.task = ti.task.prepare_for_execution() 

947 # Since context is now cached, and calling `ti.get_template_context` will return the same dict, we want to 

948 # update the value of the task that is sent from there 

949 context["task"] = ti.task 

950 

951 jinja_env = ti.task.dag.get_template_env() 

952 ti.render_templates(context=context, jinja_env=jinja_env) 

953 

954 if rendered_fields := _serialize_rendered_fields(ti.task): 

955 # so that we do not call the API unnecessarily 

956 SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields)) 

957 

958 # Try to render map_index_template early with available context (will be re-rendered after execution) 

959 # This provides a partial label during task execution for templates using pre-execution context 

960 # If rendering fails here, we suppress the error since it will be re-rendered after execution 

961 try: 

962 if rendered_map_index := _render_map_index(context, ti=ti, log=log): 

963 ti.rendered_map_index = rendered_map_index 

964 log.debug("Sending early rendered map index", length=len(rendered_map_index)) 

965 SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index)) 

966 except Exception: 

967 log.debug( 

968 "Early rendering of map_index_template failed, will retry after task execution", exc_info=True 

969 ) 

970 

971 _validate_task_inlets_and_outlets(ti=ti, log=log) 

972 

973 try: 

974 # TODO: Call pre execute etc. 

975 get_listener_manager().hook.on_task_instance_running( 

976 previous_state=TaskInstanceState.QUEUED, task_instance=ti 

977 ) 

978 except Exception: 

979 log.exception("error calling listener") 

980 

981 # No error, carry on and execute the task 

982 return None 

983 

984 

985def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None: 

986 if not ti.task.inlets and not ti.task.outlets: 

987 return 

988 

989 inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) 

990 if TYPE_CHECKING: 

991 assert isinstance(inactive_assets_resp, InactiveAssetsResult) 

992 if inactive_assets := inactive_assets_resp.inactive_assets: 

993 raise AirflowInactiveAssetInInletOrOutletException( 

994 inactive_asset_keys=[ 

995 AssetUniqueKey.from_profile(asset_profile) for asset_profile in inactive_assets 

996 ] 

997 ) 

998 

999 

1000def _defer_task( 

1001 defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger 

1002) -> tuple[ToSupervisor, TaskInstanceState]: 

1003 # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id? 

1004 

1005 log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id) 

1006 classpath, trigger_kwargs = defer.trigger.serialize() 

1007 

1008 from airflow.sdk.serde import serialize as serde_serialize 

1009 

1010 trigger_kwargs = serde_serialize(trigger_kwargs) 

1011 next_kwargs = serde_serialize(defer.kwargs or {}) 

1012 

1013 if TYPE_CHECKING: 

1014 assert isinstance(next_kwargs, dict) 

1015 assert isinstance(trigger_kwargs, dict) 

1016 

1017 msg = DeferTask( 

1018 classpath=classpath, 

1019 trigger_kwargs=trigger_kwargs, 

1020 trigger_timeout=defer.timeout, 

1021 next_method=defer.method_name, 

1022 next_kwargs=next_kwargs, 

1023 ) 

1024 state = TaskInstanceState.DEFERRED 

1025 

1026 return msg, state 

1027 

1028 

1029@Sentry.enrich_errors 

1030def run( 

1031 ti: RuntimeTaskInstance, 

1032 context: Context, 

1033 log: Logger, 

1034) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]: 

1035 """Run the task in this process.""" 

1036 import signal 

1037 

1038 from airflow.sdk.exceptions import ( 

1039 AirflowFailException, 

1040 AirflowRescheduleException, 

1041 AirflowSensorTimeout, 

1042 AirflowSkipException, 

1043 AirflowTaskTerminated, 

1044 DagRunTriggerException, 

1045 DownstreamTasksSkipped, 

1046 TaskDeferred, 

1047 ) 

1048 

1049 if TYPE_CHECKING: 

1050 assert ti.task is not None 

1051 assert isinstance(ti.task, BaseOperator) 

1052 

1053 parent_pid = os.getpid() 

1054 

1055 def _on_term(signum, frame): 

1056 pid = os.getpid() 

1057 if pid != parent_pid: 

1058 return 

1059 

1060 ti.task.on_kill() 

1061 

1062 signal.signal(signal.SIGTERM, _on_term) 

1063 

1064 msg: ToSupervisor | None = None 

1065 state: TaskInstanceState 

1066 error: BaseException | None = None 

1067 

1068 try: 

1069 # First, clear the xcom data sent from server 

1070 if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear): 

1071 for x in keys_to_delete: 

1072 log.debug("Clearing XCom with key", key=x) 

1073 XCom.delete( 

1074 key=x, 

1075 dag_id=ti.dag_id, 

1076 task_id=ti.task_id, 

1077 run_id=ti.run_id, 

1078 map_index=ti.map_index, 

1079 ) 

1080 

1081 with set_current_context(context): 

1082 # This is the earliest that we can render templates -- as if it excepts for any reason we need to 

1083 # catch it and handle it like a normal task failure 

1084 if early_exit := _prepare(ti, log, context): 

1085 msg = early_exit 

1086 ti.state = state = TaskInstanceState.FAILED 

1087 return state, msg, error 

1088 

1089 try: 

1090 result = _execute_task(context=context, ti=ti, log=log) 

1091 except Exception: 

1092 import jinja2 

1093 

1094 # If the task failed, swallow rendering error so it doesn't mask the main error. 

1095 with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): 

1096 previous_rendered_map_index = ti.rendered_map_index 

1097 ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) 

1098 # Send update only if value changed (e.g., user set context variables during execution) 

1099 if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: 

1100 SUPERVISOR_COMMS.send( 

1101 msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index) 

1102 ) 

1103 raise 

1104 else: # If the task succeeded, render normally to let rendering error bubble up. 

1105 previous_rendered_map_index = ti.rendered_map_index 

1106 ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) 

1107 # Send update only if value changed (e.g., user set context variables during execution) 

1108 if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: 

1109 SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)) 

1110 

1111 _push_xcom_if_needed(result, ti, log) 

1112 

1113 msg, state = _handle_current_task_success(context, ti) 

1114 except DownstreamTasksSkipped as skip: 

1115 log.info("Skipping downstream tasks.") 

1116 tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks] 

1117 SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) 

1118 msg, state = _handle_current_task_success(context, ti) 

1119 except DagRunTriggerException as drte: 

1120 msg, state = _handle_trigger_dag_run(drte, context, ti, log) 

1121 except TaskDeferred as defer: 

1122 msg, state = _defer_task(defer, ti, log) 

1123 except AirflowSkipException as e: 

1124 if e.args: 

1125 log.info("Skipping task.", reason=e.args[0]) 

1126 msg = TaskState( 

1127 state=TaskInstanceState.SKIPPED, 

1128 end_date=datetime.now(tz=timezone.utc), 

1129 rendered_map_index=ti.rendered_map_index, 

1130 ) 

1131 state = TaskInstanceState.SKIPPED 

1132 except AirflowRescheduleException as reschedule: 

1133 log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") 

1134 msg = RescheduleTask( 

1135 reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc) 

1136 ) 

1137 state = TaskInstanceState.UP_FOR_RESCHEDULE 

1138 except (AirflowFailException, AirflowSensorTimeout) as e: 

1139 # If AirflowFailException is raised, task should not retry. 

1140 # If a sensor in reschedule mode reaches timeout, task should not retry. 

1141 log.exception("Task failed with exception") 

1142 ti.end_date = datetime.now(tz=timezone.utc) 

1143 msg = TaskState( 

1144 state=TaskInstanceState.FAILED, 

1145 end_date=ti.end_date, 

1146 rendered_map_index=ti.rendered_map_index, 

1147 ) 

1148 state = TaskInstanceState.FAILED 

1149 error = e 

1150 except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e: 

1151 # We should allow retries if the task has defined it. 

1152 log.exception("Task failed with exception") 

1153 msg, state = _handle_current_task_failed(ti) 

1154 error = e 

1155 except AirflowTaskTerminated as e: 

1156 # External state updates are already handled with `ti_heartbeat` and will be 

1157 # updated already be another UI API. So, these exceptions should ideally never be thrown. 

1158 # If these are thrown, we should mark the TI state as failed. 

1159 log.exception("Task failed with exception") 

1160 ti.end_date = datetime.now(tz=timezone.utc) 

1161 msg = TaskState( 

1162 state=TaskInstanceState.FAILED, 

1163 end_date=ti.end_date, 

1164 rendered_map_index=ti.rendered_map_index, 

1165 ) 

1166 state = TaskInstanceState.FAILED 

1167 error = e 

1168 except SystemExit as e: 

1169 # SystemExit needs to be retried if they are eligible. 

1170 log.error("Task exited", exit_code=e.code) 

1171 msg, state = _handle_current_task_failed(ti) 

1172 error = e 

1173 except BaseException as e: 

1174 log.exception("Task failed with exception") 

1175 msg, state = _handle_current_task_failed(ti) 

1176 error = e 

1177 finally: 

1178 if msg: 

1179 SUPERVISOR_COMMS.send(msg=msg) 

1180 

1181 # Return the message to make unit tests easier too 

1182 ti.state = state 

1183 return state, msg, error 

1184 

1185 

1186def _handle_current_task_success( 

1187 context: Context, 

1188 ti: RuntimeTaskInstance, 

1189) -> tuple[SucceedTask, TaskInstanceState]: 

1190 end_date = datetime.now(tz=timezone.utc) 

1191 ti.end_date = end_date 

1192 

1193 # Record operator and task instance success metrics 

1194 operator = ti.task.__class__.__name__ 

1195 stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id} 

1196 

1197 Stats.incr(f"operator_successes_{operator}", tags=stats_tags) 

1198 # Same metric with tagging 

1199 Stats.incr("operator_successes", tags={**stats_tags, "operator": operator}) 

1200 Stats.incr("ti_successes", tags=stats_tags) 

1201 

1202 task_outlets = list(_build_asset_profiles(ti.task.outlets)) 

1203 outlet_events = list(_serialize_outlet_events(context["outlet_events"])) 

1204 msg = SucceedTask( 

1205 end_date=end_date, 

1206 task_outlets=task_outlets, 

1207 outlet_events=outlet_events, 

1208 rendered_map_index=ti.rendered_map_index, 

1209 ) 

1210 return msg, TaskInstanceState.SUCCESS 

1211 

1212 

1213def _handle_current_task_failed( 

1214 ti: RuntimeTaskInstance, 

1215) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]: 

1216 end_date = datetime.now(tz=timezone.utc) 

1217 ti.end_date = end_date 

1218 if ti._ti_context_from_server and ti._ti_context_from_server.should_retry: 

1219 return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY 

1220 return ( 

1221 TaskState( 

1222 state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index 

1223 ), 

1224 TaskInstanceState.FAILED, 

1225 ) 

1226 

1227 

1228def _handle_trigger_dag_run( 

1229 drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger 

1230) -> tuple[ToSupervisor, TaskInstanceState]: 

1231 """Handle exception from TriggerDagRunOperator.""" 

1232 log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) 

1233 comms_msg = SUPERVISOR_COMMS.send( 

1234 TriggerDagRun( 

1235 dag_id=drte.trigger_dag_id, 

1236 run_id=drte.dag_run_id, 

1237 logical_date=drte.logical_date, 

1238 conf=drte.conf, 

1239 reset_dag_run=drte.reset_dag_run, 

1240 ), 

1241 ) 

1242 

1243 if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS: 

1244 if drte.skip_when_already_exists: 

1245 log.info( 

1246 "Dag Run already exists, skipping task as skip_when_already_exists is set to True.", 

1247 dag_id=drte.trigger_dag_id, 

1248 ) 

1249 msg = TaskState( 

1250 state=TaskInstanceState.SKIPPED, 

1251 end_date=datetime.now(tz=timezone.utc), 

1252 rendered_map_index=ti.rendered_map_index, 

1253 ) 

1254 state = TaskInstanceState.SKIPPED 

1255 else: 

1256 log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id) 

1257 msg = TaskState( 

1258 state=TaskInstanceState.FAILED, 

1259 end_date=datetime.now(tz=timezone.utc), 

1260 rendered_map_index=ti.rendered_map_index, 

1261 ) 

1262 state = TaskInstanceState.FAILED 

1263 

1264 return msg, state 

1265 

1266 log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id) 

1267 

1268 # Store the run id from the dag run (either created or found above) to 

1269 # be used when creating the extra link on the webserver. 

1270 ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id) 

1271 

1272 if drte.deferrable: 

1273 from airflow.providers.standard.triggers.external_task import DagStateTrigger 

1274 

1275 defer = TaskDeferred( 

1276 trigger=DagStateTrigger( 

1277 dag_id=drte.trigger_dag_id, 

1278 states=drte.allowed_states + drte.failed_states, # type: ignore[arg-type] 

1279 # Don't filter by execution_dates when run_ids is provided. 

1280 # run_id uniquely identifies a DAG run, and when reset_dag_run=True, 

1281 # drte.logical_date might be a newly calculated value that doesn't match 

1282 # the persisted logical_date in the database, causing the trigger to never find the run. 

1283 execution_dates=None, 

1284 run_ids=[drte.dag_run_id], 

1285 poll_interval=drte.poke_interval, 

1286 ), 

1287 method_name="execute_complete", 

1288 ) 

1289 return _defer_task(defer, ti, log) 

1290 if drte.wait_for_completion: 

1291 while True: 

1292 log.info( 

1293 "Waiting for dag run to complete execution in allowed state.", 

1294 dag_id=drte.trigger_dag_id, 

1295 run_id=drte.dag_run_id, 

1296 allowed_state=drte.allowed_states, 

1297 ) 

1298 time.sleep(drte.poke_interval) 

1299 

1300 comms_msg = SUPERVISOR_COMMS.send( 

1301 GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) 

1302 ) 

1303 if TYPE_CHECKING: 

1304 assert isinstance(comms_msg, DagRunStateResult) 

1305 if comms_msg.state in drte.failed_states: 

1306 log.error( 

1307 "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state 

1308 ) 

1309 msg = TaskState( 

1310 state=TaskInstanceState.FAILED, 

1311 end_date=datetime.now(tz=timezone.utc), 

1312 rendered_map_index=ti.rendered_map_index, 

1313 ) 

1314 state = TaskInstanceState.FAILED 

1315 return msg, state 

1316 if comms_msg.state in drte.allowed_states: 

1317 log.info( 

1318 "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state 

1319 ) 

1320 break 

1321 log.debug( 

1322 "DagRun not yet in allowed or failed state.", 

1323 dag_id=drte.trigger_dag_id, 

1324 state=comms_msg.state, 

1325 ) 

1326 

1327 return _handle_current_task_success(context, ti) 

1328 

1329 

1330def _run_task_state_change_callbacks( 

1331 task: BaseOperator, 

1332 kind: Literal[ 

1333 "on_execute_callback", 

1334 "on_failure_callback", 

1335 "on_success_callback", 

1336 "on_retry_callback", 

1337 "on_skipped_callback", 

1338 ], 

1339 context: Context, 

1340 log: Logger, 

1341) -> None: 

1342 callback: Callable[[Context], None] 

1343 for i, callback in enumerate(getattr(task, kind)): 

1344 try: 

1345 create_executable_runner(callback, context_get_outlet_events(context), logger=log).run(context) 

1346 except Exception: 

1347 log.exception("Failed to run task callback", kind=kind, index=i, callback=callback) 

1348 

1349 

1350def _send_error_email_notification( 

1351 task: BaseOperator | MappedOperator, 

1352 ti: RuntimeTaskInstance, 

1353 context: Context, 

1354 error: BaseException | str | None, 

1355 log: Logger, 

1356) -> None: 

1357 """Send email notification for task errors using SmtpNotifier.""" 

1358 try: 

1359 from airflow.providers.smtp.notifications.smtp import SmtpNotifier 

1360 except ImportError: 

1361 log.error( 

1362 "Failed to send task failure or retry email notification: " 

1363 "`apache-airflow-providers-smtp` is not installed. " 

1364 "Install this provider to enable email notifications." 

1365 ) 

1366 return 

1367 

1368 if not task.email: 

1369 return 

1370 

1371 subject_template_file = conf.get("email", "subject_template", fallback=None) 

1372 

1373 # Read the template file if configured 

1374 if subject_template_file and Path(subject_template_file).exists(): 

1375 subject = Path(subject_template_file).read_text() 

1376 else: 

1377 # Fallback to default 

1378 subject = "Airflow alert: {{ti}}" 

1379 

1380 html_content_template_file = conf.get("email", "html_content_template", fallback=None) 

1381 

1382 # Read the template file if configured 

1383 if html_content_template_file and Path(html_content_template_file).exists(): 

1384 html_content = Path(html_content_template_file).read_text() 

1385 else: 

1386 # Fallback to default 

1387 # For reporting purposes, we report based on 1-indexed, 

1388 # not 0-indexed lists (i.e. Try 1 instead of Try 0 for the first attempt). 

1389 html_content = ( 

1390 "Try {{try_number}} out of {{max_tries + 1}}<br>" 

1391 "Exception:<br>{{exception_html}}<br>" 

1392 'Log: <a href="{{ti.log_url}}">Link</a><br>' 

1393 "Host: {{ti.hostname}}<br>" 

1394 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>' 

1395 ) 

1396 

1397 # Add exception_html to context for template rendering 

1398 import html 

1399 

1400 exception_html = html.escape(str(error)).replace("\n", "<br>") 

1401 additional_context = { 

1402 "exception": error, 

1403 "exception_html": exception_html, 

1404 "try_number": ti.try_number, 

1405 "max_tries": ti.max_tries, 

1406 } 

1407 email_context = {**context, **additional_context} 

1408 to_emails = task.email 

1409 if not to_emails: 

1410 return 

1411 

1412 try: 

1413 notifier = SmtpNotifier( 

1414 to=to_emails, 

1415 subject=subject, 

1416 html_content=html_content, 

1417 from_email=conf.get("email", "from_email", fallback="airflow@airflow"), 

1418 ) 

1419 notifier(email_context) 

1420 except Exception: 

1421 log.exception("Failed to send email notification") 

1422 

1423 

1424def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): 

1425 """Execute Task (optionally with a Timeout) and push Xcom results.""" 

1426 task = ti.task 

1427 execute = task.execute 

1428 

1429 if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method): 

1430 from airflow.sdk.serde import deserialize 

1431 

1432 next_kwargs_data = ti._ti_context_from_server.next_kwargs or {} 

1433 try: 

1434 if TYPE_CHECKING: 

1435 assert isinstance(next_kwargs_data, dict) 

1436 kwargs = deserialize(next_kwargs_data) 

1437 except (ImportError, KeyError, AttributeError, TypeError): 

1438 from airflow.serialization.serialized_objects import BaseSerialization 

1439 

1440 kwargs = BaseSerialization.deserialize(next_kwargs_data) 

1441 

1442 if TYPE_CHECKING: 

1443 assert isinstance(kwargs, dict) 

1444 execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs) 

1445 

1446 ctx = contextvars.copy_context() 

1447 # Populate the context var so ExecutorSafeguard doesn't complain 

1448 ctx.run(ExecutorSafeguard.tracker.set, task) 

1449 

1450 # Export context in os.environ to make it available for operators to use. 

1451 airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) 

1452 os.environ.update(airflow_context_vars) 

1453 

1454 outlet_events = context_get_outlet_events(context) 

1455 

1456 if (pre_execute_hook := task._pre_execute_hook) is not None: 

1457 create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context) 

1458 if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute: 

1459 create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context) 

1460 

1461 _run_task_state_change_callbacks(task, "on_execute_callback", context, log) 

1462 

1463 if task.execution_timeout: 

1464 from airflow.sdk.execution_time.timeout import timeout 

1465 

1466 # TODO: handle timeout in case of deferral 

1467 timeout_seconds = task.execution_timeout.total_seconds() 

1468 try: 

1469 # It's possible we're already timed out, so fast-fail if true 

1470 if timeout_seconds <= 0: 

1471 raise AirflowTaskTimeout() 

1472 # Run task in timeout wrapper 

1473 with timeout(timeout_seconds): 

1474 result = ctx.run(execute, context=context) 

1475 except AirflowTaskTimeout: 

1476 task.on_kill() 

1477 raise 

1478 else: 

1479 result = ctx.run(execute, context=context) 

1480 

1481 if (post_execute_hook := task._post_execute_hook) is not None: 

1482 create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) 

1483 if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute: 

1484 create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context) 

1485 

1486 return result 

1487 

1488 

1489def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None: 

1490 """Render named map index if the Dag author defined map_index_template at the task level.""" 

1491 if (template := context.get("map_index_template")) is None: 

1492 return None 

1493 log.debug("Rendering map_index_template", template_length=len(template)) 

1494 jinja_env = ti.task.dag.get_template_env() 

1495 rendered_map_index = jinja_env.from_string(template).render(context) 

1496 log.debug("Map index rendered", length=len(rendered_map_index)) 

1497 return rendered_map_index 

1498 

1499 

1500def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger): 

1501 """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result.""" 

1502 if ti.task.do_xcom_push: 

1503 xcom_value = result 

1504 else: 

1505 xcom_value = None 

1506 

1507 has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None 

1508 if xcom_value is None: 

1509 if not ti.is_mapped and has_mapped_dep: 

1510 # Uhoh, a downstream mapped task depends on us to push something to map over 

1511 from airflow.sdk.exceptions import XComForMappingNotPushed 

1512 

1513 raise XComForMappingNotPushed() 

1514 return 

1515 

1516 mapped_length: int | None = None 

1517 if not ti.is_mapped and has_mapped_dep: 

1518 from airflow.sdk.definitions.mappedoperator import is_mappable_value 

1519 from airflow.sdk.exceptions import UnmappableXComTypePushed 

1520 

1521 if not is_mappable_value(xcom_value): 

1522 raise UnmappableXComTypePushed(xcom_value) 

1523 mapped_length = len(xcom_value) 

1524 

1525 log.info("Pushing xcom", ti=ti) 

1526 

1527 # If the task has multiple outputs, push each output as a separate XCom. 

1528 if ti.task.multiple_outputs: 

1529 if not isinstance(xcom_value, Mapping): 

1530 raise TypeError( 

1531 f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs" 

1532 ) 

1533 for key in xcom_value.keys(): 

1534 if not isinstance(key, str): 

1535 raise TypeError( 

1536 "Returned dictionary keys must be strings when using " 

1537 f"multiple_outputs, found {key} ({type(key)}) instead" 

1538 ) 

1539 for k, v in result.items(): 

1540 ti.xcom_push(k, v) 

1541 

1542 _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length) 

1543 

1544 

1545def finalize( 

1546 ti: RuntimeTaskInstance, 

1547 state: TaskInstanceState, 

1548 context: Context, 

1549 log: Logger, 

1550 error: BaseException | None = None, 

1551): 

1552 # Record task duration metrics for all terminal states 

1553 if ti.start_date and ti.end_date: 

1554 duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000 

1555 stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id} 

1556 

1557 Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms) 

1558 Stats.timing("task.duration", duration_ms, tags=stats_tags) 

1559 

1560 task = ti.task 

1561 # Pushing xcom for each operator extra links defined on the operator only. 

1562 for oe in task.operator_extra_links: 

1563 try: 

1564 link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key # type: ignore[arg-type] 

1565 log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key) 

1566 _xcom_push_to_db(ti, key=xcom_key, value=link) 

1567 except Exception: 

1568 log.exception( 

1569 "Failed to push an xcom for task operator extra link", 

1570 link_name=oe.name, 

1571 xcom_key=oe.xcom_key, 

1572 ti=ti, 

1573 ) 

1574 

1575 if getattr(ti.task, "overwrite_rtif_after_execution", False): 

1576 log.debug("Overwriting Rendered template fields.") 

1577 if ti.task.template_fields: 

1578 SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) 

1579 

1580 log.debug("Running finalizers", ti=ti) 

1581 if state == TaskInstanceState.SUCCESS: 

1582 _run_task_state_change_callbacks(task, "on_success_callback", context, log) 

1583 try: 

1584 get_listener_manager().hook.on_task_instance_success( 

1585 previous_state=TaskInstanceState.RUNNING, task_instance=ti 

1586 ) 

1587 except Exception: 

1588 log.exception("error calling listener") 

1589 elif state == TaskInstanceState.SKIPPED: 

1590 _run_task_state_change_callbacks(task, "on_skipped_callback", context, log) 

1591 elif state == TaskInstanceState.UP_FOR_RETRY: 

1592 _run_task_state_change_callbacks(task, "on_retry_callback", context, log) 

1593 try: 

1594 get_listener_manager().hook.on_task_instance_failed( 

1595 previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error 

1596 ) 

1597 except Exception: 

1598 log.exception("error calling listener") 

1599 if error and task.email_on_retry and task.email: 

1600 _send_error_email_notification(task, ti, context, error, log) 

1601 elif state == TaskInstanceState.FAILED: 

1602 _run_task_state_change_callbacks(task, "on_failure_callback", context, log) 

1603 try: 

1604 get_listener_manager().hook.on_task_instance_failed( 

1605 previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error 

1606 ) 

1607 except Exception: 

1608 log.exception("error calling listener") 

1609 if error and task.email_on_failure and task.email: 

1610 _send_error_email_notification(task, ti, context, error, log) 

1611 

1612 try: 

1613 get_listener_manager().hook.before_stopping(component=TaskRunnerMarker()) 

1614 except Exception: 

1615 log.exception("error calling listener") 

1616 

1617 

1618def main(): 

1619 log = structlog.get_logger(logger_name="task") 

1620 

1621 global SUPERVISOR_COMMS 

1622 SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) 

1623 

1624 try: 

1625 ti, context, log = startup() 

1626 with BundleVersionLock( 

1627 bundle_name=ti.bundle_instance.name, 

1628 bundle_version=ti.bundle_instance.version, 

1629 ): 

1630 state, _, error = run(ti, context, log) 

1631 context["exception"] = error 

1632 finalize(ti, state, context, log, error) 

1633 except KeyboardInterrupt: 

1634 log.exception("Ctrl-c hit") 

1635 exit(2) 

1636 except Exception: 

1637 log.exception("Top level error") 

1638 exit(1) 

1639 finally: 

1640 # Ensure the request socket is closed on the child side in all circumstances 

1641 # before the process fully terminates. 

1642 if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket: 

1643 with suppress(Exception): 

1644 SUPERVISOR_COMMS.socket.close() 

1645 

1646 

1647def reinit_supervisor_comms() -> None: 

1648 """ 

1649 Re-initialize supervisor comms and logging channel in subprocess. 

1650 

1651 This is not needed for most cases, but is used when either we re-launch the process via sudo for 

1652 run_as_user, or from inside the python code in a virtualenv (et al.) operator to re-connect so those tasks 

1653 can continue to access variables etc. 

1654 """ 

1655 import socket 

1656 

1657 if "SUPERVISOR_COMMS" not in globals(): 

1658 global SUPERVISOR_COMMS 

1659 log = structlog.get_logger(logger_name="task") 

1660 

1661 fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0")) 

1662 

1663 SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd)) 

1664 

1665 logs = SUPERVISOR_COMMS.send(ResendLoggingFD()) 

1666 if isinstance(logs, SentFDs): 

1667 from airflow.sdk.log import configure_logging 

1668 

1669 log_io = os.fdopen(logs.fds[0], "wb", buffering=0) 

1670 configure_logging(json_output=True, output=log_io, sending_to_supervisor=True) 

1671 else: 

1672 print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr) 

1673 

1674 

1675if __name__ == "__main__": 

1676 main()