Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/execution_time/task_runner.py: 16% (670 statements)

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The entrypoint for the actual task execution process."""

from __future__ import annotations

import contextlib
import contextvars
import functools
import os
import sys
import time
from collections.abc import Callable, Iterable, Iterator, Mapping
from contextlib import suppress
from datetime import datetime, timezone
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal
from urllib.parse import quote

import attrs
import lazy_object_proxy
import structlog
from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter

from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.listeners.listener import get_listener_manager
from airflow.sdk.api.client import get_hostname, getuser
from airflow.sdk.api.datamodels._generated import (
    AssetProfile,
    DagRun,
    TaskInstance,
    TaskInstanceState,
    TIRunContext,
)
from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.param import process_params
from airflow.sdk.exceptions import (
    AirflowException,
    AirflowInactiveAssetInInletOrOutletException,
    AirflowRuntimeError,
    AirflowTaskTimeout,
    ErrorType,
    TaskDeferred,
)
from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.comms import (
    AssetEventDagRunReferenceResult,
    CommsDecoder,
    DagRunStateResult,
    DeferTask,
    DRCount,
    ErrorResponse,
    GetDagRunState,
    GetDRCount,
    GetPreviousDagRun,
    GetTaskBreadcrumbs,
    GetTaskRescheduleStartDate,
    GetTaskStates,
    GetTICount,
    InactiveAssetsResult,
    PreviousDagRunResult,
    RescheduleTask,
    ResendLoggingFD,
    RetryTask,
    SentFDs,
    SetRenderedFields,
    SetRenderedMapIndex,
    SkipDownstreamTasks,
    StartupDetails,
    SucceedTask,
    TaskBreadcrumbsResult,
    TaskRescheduleStartDate,
    TaskState,
    TaskStatesResult,
    TICount,
    ToSupervisor,
    ToTask,
    TriggerDagRun,
    ValidateInletsAndOutlets,
)
from airflow.sdk.execution_time.context import (
    ConnectionAccessor,
    InletEventsAccessors,
    MacrosAccessor,
    OutletEventAccessors,
    TriggeringAssetEventsAccessor,
    VariableAccessor,
    context_get_outlet_events,
    context_to_airflow_vars,
    get_previous_dagrun_success,
    set_current_context,
)
from airflow.sdk.execution_time.sentry import Sentry
from airflow.sdk.execution_time.xcom import XCom
from airflow.sdk.observability.stats import Stats
from airflow.sdk.timezone import coerce_datetime

if TYPE_CHECKING:
    import jinja2
    from pendulum.datetime import DateTime
    from structlog.typing import FilteringBoundLogger as Logger

    from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
    from airflow.sdk.definitions.context import Context
    from airflow.sdk.exceptions import DagRunTriggerException
    from airflow.sdk.types import OutletEventAccessorsProtocol


class TaskRunnerMarker:
    """Marker for listener hooks, to properly detect from which component they are called."""


# TODO: Move this entire class into a separate file:
#   `airflow/sdk/execution_time/task_instance.py`
#   or `airflow/sdk/execution_time/runtime_ti.py`
class RuntimeTaskInstance(TaskInstance):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    task: BaseOperator
    bundle_instance: BaseDagBundle
    _cached_template_context: Context | None = None
    """The Task Instance context. This is used to cache get_template_context."""

    _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
    """The Task Instance context from the API server, if any."""

    max_tries: int = 0
    """The maximum number of retries for the task."""

    start_date: AwareDatetime
    """Start date of the task instance."""

    end_date: AwareDatetime | None = None

    state: TaskInstanceState | None = None

    is_mapped: bool | None = None
    """True if the original task was mapped."""

    rendered_map_index: str | None = None

    sentry_integration: str = ""

    def __rich_repr__(self):
        yield "id", self.id
        yield "task_id", self.task_id
        yield "dag_id", self.dag_id
        yield "run_id", self.run_id
        yield "max_tries", self.max_tries
        yield "task", type(self.task)
        yield "start_date", self.start_date

    __rich_repr__.angular = True  # type: ignore[attr-defined]

    def get_template_context(self) -> Context:
        # TODO: Move this to `airflow.sdk.execution_time.context`
        #   once we port the entire context logic from airflow/utils/context.py?
        from airflow.plugins_manager import integrate_macros_plugins

        integrate_macros_plugins()

        dag_run_conf: dict[str, Any] | None = None
        if from_server := self._ti_context_from_server:
            dag_run_conf = from_server.dag_run.conf or dag_run_conf

        validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)

        # Cache the context object, which ensures that all calls to get_template_context
        # are operating on the same context object.
        self._cached_template_context: Context = self._cached_template_context or {
            # From the Task Execution interface
            "dag": self.task.dag,
            "inlets": self.task.inlets,
            "map_index_template": self.task.map_index_template,
            "outlets": self.task.outlets,
            "run_id": self.run_id,
            "task": self.task,
            "task_instance": self,
            "ti": self,
            "outlet_events": OutletEventAccessors(),
            "inlet_events": InletEventsAccessors(self.task.inlets),
            "macros": MacrosAccessor(),
            "params": validated_params,
            # TODO: Make this go through Public API longer term.
            # "test_mode": task_instance.test_mode,
            "var": {
                "json": VariableAccessor(deserialize_json=True),
                "value": VariableAccessor(deserialize_json=False),
            },
            "conn": ConnectionAccessor(),
        }
        if from_server:
            dag_run = from_server.dag_run
            context_from_server: Context = {
                # TODO: Assess if we need to pass these through timezone.coerce_datetime
                "dag_run": dag_run,  # type: ignore[typeddict-item]  # Removable after #46522
                "triggering_asset_events": TriggeringAssetEventsAccessor.build(
                    AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
                    for event in dag_run.consumed_asset_events
                ),
                "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
                "task_reschedule_count": from_server.task_reschedule_count or 0,
                "prev_start_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date)
                ),
                "prev_end_date_success": lazy_object_proxy.Proxy(
                    lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
                ),
            }
            self._cached_template_context.update(context_from_server)

            if logical_date := coerce_datetime(dag_run.logical_date):
                if TYPE_CHECKING:
                    assert isinstance(logical_date, DateTime)
                ds = logical_date.strftime("%Y-%m-%d")
                ds_nodash = ds.replace("-", "")
                ts = logical_date.isoformat()
                ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
                ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
                # logical_date and data_interval either coexist or are None together
                self._cached_template_context.update(
                    {
                        # keys that depend on logical_date
                        "logical_date": logical_date,
                        "ds": ds,
                        "ds_nodash": ds_nodash,
                        "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}",
                        "ts": ts,
                        "ts_nodash": ts_nodash,
                        "ts_nodash_with_tz": ts_nodash_with_tz,
                        # keys that depend on data_interval
                        "data_interval_end": coerce_datetime(dag_run.data_interval_end),
                        "data_interval_start": coerce_datetime(dag_run.data_interval_start),
                        "prev_data_interval_start_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start)
                        ),
                        "prev_data_interval_end_success": lazy_object_proxy.Proxy(
                            lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end)
                        ),
                    }
                )

            if from_server.upstream_map_indexes is not None:
                # We stash this in here for later use, but we purposefully don't want to document its
                # existence. Should this be a private attribute on RuntimeTI instead perhaps?
                setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)

        return self._cached_template_context
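
    # A minimal usage sketch (illustrative, not part of this module): template and
    # operator code read the keys assembled above from the cached context, e.g.
    #
    #     context = ti.get_template_context()
    #     context["ds"]                   # "2025-01-01"-style date when a logical_date exists
    #     context["var"]["value"].my_var  # hypothetical Variable lookup via the accessor
    #     context["conn"].my_conn_id      # hypothetical Connection lookup via the accessor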


    def render_templates(
        self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
    ) -> BaseOperator:
        """
        Render templates in the operator fields.

        If the task was originally mapped, this may replace ``self.task`` with
        the unmapped, fully rendered BaseOperator. The original ``self.task``
        before replacement is returned.
        """
        if not context:
            context = self.get_template_context()
        original_task = self.task

        if TYPE_CHECKING:
            assert context

        ti = context["ti"]

        if TYPE_CHECKING:
            assert original_task
            assert self.task
            assert ti.task

        # If self.task is mapped, this call replaces self.task to point to the
        # unmapped BaseOperator created by this function! This is because the
        # MappedOperator is useless for template rendering, and we need to be
        # able to access the unmapped task instead.
        self.task.render_template_fields(context, jinja_env)
        self.is_mapped = original_task.is_mapped
        return original_task

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = BaseXCom.XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
        default: Any = None,
        run_id: str | None = None,
    ) -> Any:
        """
        Pull XComs either from the API server (BaseXCom) or from the custom XCom backend if configured.

        The pull can optionally be filtered by certain criteria.

        :param key: A key for the XCom. If provided, only XComs with matching
            keys will be returned. The default key is ``'return_value'``, also
            available as constant ``XCOM_RETURN_KEY``. This key is automatically
            given to XComs returned by tasks (as opposed to being pushed
            manually). To remove the filter, pass *None*.
        :param task_ids: Only XComs from tasks with matching ids will be
            pulled. Pass *None* to remove the filter.
        :param dag_id: If provided, only pulls XComs from this Dag. If *None*
            (default), the Dag of the calling task is used.
        :param map_indexes: If provided, only pull XComs with matching indexes.
            If *None* (default), this is inferred from the task(s) being pulled
            (see below for details).
        :param include_prior_dates: If *False*, only XComs from the current
            logical_date are returned. If *True*, XComs from previous dates
            are returned as well.
        :param run_id: If provided, only pulls XComs from a DagRun with a matching run_id.
            If *None* (default), the run_id of the calling task is used.

        When pulling one single task (``task_id`` is *None* or a str) without
        specifying ``map_indexes``, the return value is a single XCom entry
        (``map_indexes`` is set to the map_index of the calling task instance).

        When the pulled task is mapped, the specified ``map_index`` is used, so by default
        pulling from a mapped task will result in no matching XComs if the task instance
        of the method call is not mapped. Otherwise, the map_index of the calling task
        instance is used. Setting ``map_indexes`` to *None* will pull the XCom as it would
        from a non-mapped task.

        In either case, ``default`` (*None* if not specified) is returned if no
        matching XComs are found.

        When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
        a non-str iterable), a list of matching XComs is returned. Elements in
        the list are ordered by item ordering in ``task_id`` and ``map_index``.
        """
        if dag_id is None:
            dag_id = self.dag_id
        if run_id is None:
            run_id = self.run_id

        single_task_requested = isinstance(task_ids, (str, type(None)))
        single_map_index_requested = isinstance(map_indexes, (int, type(None)))

        if task_ids is None:
            # default to the current task if not provided
            task_ids = [self.task_id]
        elif isinstance(task_ids, str):
            task_ids = [task_ids]

        # If map_indexes is not specified, pull xcoms from all map indexes for each task
        if not is_arg_set(map_indexes):
            xcoms: list[Any] = []
            for t_id in task_ids:
                values = XCom.get_all(
                    run_id=run_id,
                    key=key,
                    task_id=t_id,
                    dag_id=dag_id,
                    include_prior_dates=include_prior_dates,
                )

                if values is None:
                    xcoms.append(None)
                else:
                    xcoms.extend(values)
            # For a single task pulling from an unmapped task, return a single value
            if single_task_requested and len(xcoms) == 1:
                return xcoms[0]
            return xcoms

        # Original logic when map_indexes is explicitly specified
        map_indexes_iterable: Iterable[int | None] = []
        if isinstance(map_indexes, int) or map_indexes is None:
            map_indexes_iterable = [map_indexes]
        elif isinstance(map_indexes, Iterable):
            map_indexes_iterable = map_indexes
        else:
            raise TypeError(
                f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}"
            )

        xcoms = []
        for t_id, m_idx in product(task_ids, map_indexes_iterable):
            value = XCom.get_one(
                run_id=run_id,
                key=key,
                task_id=t_id,
                dag_id=dag_id,
                map_index=m_idx,
                include_prior_dates=include_prior_dates,
            )
            if value is None:
                xcoms.append(default)
            else:
                xcoms.append(value)

        if single_task_requested and single_map_index_requested:
            return xcoms[0]

        return xcoms
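
    # A usage sketch (hypothetical task ids):
    #
    #     ti.xcom_pull(task_ids="extract")                        # single value
    #     ti.xcom_pull(task_ids=["extract", "load"])              # list, ordered by task_ids
    #     ti.xcom_pull(task_ids="mapped_task", map_indexes=None)  # as from a non-mapped task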


    def xcom_push(self, key: str, value: Any):
        """
        Make an XCom available for tasks to pull.

        :param key: Key to store the value under.
        :param value: Value to store. Only JSON-serializable values may be used.
        """
        _xcom_push(self, key, value)

    def get_relevant_upstream_map_indexes(
        self, upstream: BaseOperator, ti_count: int | None, session: Any
    ) -> int | range | None:
        # TODO: Implement this method
        return None

    def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
        """Get the first reschedule date for the task instance if found, *None* otherwise."""
        if context.get("task_reschedule_count", 0) == 0:
            # If the task has not been rescheduled, there is no need to ask the supervisor
            return None

        max_tries: int = self.max_tries
        retries: int = self.task.retries or 0
        first_try_number = max_tries - retries + 1
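        # Worked example (assuming, as is the default, that a never-cleared task has
        # max_tries equal to its retries): with retries=2, first_try_number is
        # 2 - 2 + 1 = 1, i.e. the very first attempt of this task instance.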


        log = structlog.get_logger(logger_name="task")

        log.debug("Requesting first reschedule date from supervisor")

        response = SUPERVISOR_COMMS.send(
            msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskRescheduleStartDate)

        return response.start_date

    def get_previous_dagrun(self, state: str | None = None) -> DagRun | None:
        """Return the previous Dag run before the given logical date, optionally filtered by state."""
        context = self.get_template_context()
        dag_run = context.get("dag_run")

        log = structlog.get_logger(logger_name="task")

        log.debug("Getting previous Dag run", dag_run=dag_run)

        if dag_run is None:
            return None

        if dag_run.logical_date is None:
            return None

        response = SUPERVISOR_COMMS.send(
            msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state)
        )

        if TYPE_CHECKING:
            assert isinstance(response, PreviousDagRunResult)

        return response.dag_run

    @staticmethod
    def get_ti_count(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of task instances matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTICount(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TICount)

        return response.count

    @staticmethod
    def get_task_states(
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> dict[str, Any]:
        """Return the task states matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetTaskStates(
                dag_id=dag_id,
                map_index=map_index,
                task_ids=task_ids,
                task_group_id=task_group_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, TaskStatesResult)

        return response.task_states

    @staticmethod
    def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]:
        """Return task breadcrumbs for the given dag run."""
        response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id))
        if TYPE_CHECKING:
            assert isinstance(response, TaskBreadcrumbsResult)
        return response.breadcrumbs

    @staticmethod
    def get_dr_count(
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> int:
        """Return the number of Dag runs matching the given criteria."""
        response = SUPERVISOR_COMMS.send(
            GetDRCount(
                dag_id=dag_id,
                logical_dates=logical_dates,
                run_ids=run_ids,
                states=states,
            ),
        )

        if TYPE_CHECKING:
            assert isinstance(response, DRCount)

        return response.count

    @staticmethod
    def get_dagrun_state(dag_id: str, run_id: str) -> str:
        """Return the state of the Dag run with the given Run ID."""
        response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id))

        if TYPE_CHECKING:
            assert isinstance(response, DagRunStateResult)

        return response.state

    @property
    def log_url(self) -> str:
        run_id = quote(self.run_id)
        base_url = conf.get("api", "base_url", fallback="http://localhost:8080/")
        map_index_value = self.map_index
        map_index = (
            f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else ""
        )
        try_number_value = self.try_number
        try_number = (
            f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else ""
        )
        _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}"
        return _log_uri
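
    # Example of a resulting URL (hypothetical ids, default base_url):
    #   http://localhost:8080/dags/my_dag/runs/manual__2025-01-01/tasks/my_task/mapped/3?try_number=2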


    @property
    def mark_success_url(self) -> str:
        """URL to mark TI success."""
        return self.log_url


def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
    """Push an XCom through XCom.set, which pushes to the XCom backend if configured."""
    # Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK
    # consumers

    XCom.set(
        key=key,
        value=value,
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=ti.map_index,
        _mapped_length=mapped_length,
    )


def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None:
    """Push an XCom directly to the metadata DB, bypassing any custom xcom_backend."""
    XCom._set_xcom_in_db(
        key=key,
        value=value,
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=ti.map_index,
    )


def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
    # TODO: Task-SDK:
    #   Using DagBag here is about 98% wrong, but it'll do for now
    from airflow.dag_processing.dagbag import DagBag

    bundle_info = what.bundle_info
    bundle_instance = DagBundlesManager().get_bundle(
        name=bundle_info.name,
        version=bundle_info.version,
    )
    bundle_instance.initialize()

    # Put the bundle root on sys.path if needed. This allows the dag bundle to add
    # code in util modules to be shared between files within the same bundle.
    if (bundle_root := os.fspath(bundle_instance.path)) not in sys.path:
        sys.path.append(bundle_root)
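    # For example (illustrative bundle layout): a bundle with `dags/my_dag.py` and
    # `common/helpers.py` lets the dag do `from common import helpers` once the
    # bundle root is importable.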


    dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
    bag = DagBag(
        dag_folder=dag_absolute_path,
        include_examples=False,
        safe_mode=False,
        load_op_links=False,
        bundle_name=bundle_info.name,
    )
    if TYPE_CHECKING:
        assert what.ti.dag_id

    try:
        dag = bag.dags[what.ti.dag_id]
    except KeyError:
        log.error(
            "Dag not found during start up", dag_id=what.ti.dag_id, bundle=bundle_info, path=what.dag_rel_path
        )
        sys.exit(1)

    # install_loader()

    try:
        task = dag.task_dict[what.ti.task_id]
    except KeyError:
        log.error(
            "Task not found in Dag during start up",
            dag_id=dag.dag_id,
            task_id=what.ti.task_id,
            bundle=bundle_info,
            path=what.dag_rel_path,
        )
        sys.exit(1)

    if not isinstance(task, (BaseOperator, MappedOperator)):
        raise TypeError(
            f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
        )

    return RuntimeTaskInstance.model_construct(
        **what.ti.model_dump(exclude_unset=True),
        task=task,
        bundle_instance=bundle_instance,
        _ti_context_from_server=what.ti_context,
        max_tries=what.ti_context.max_tries,
        start_date=what.start_date,
        state=TaskInstanceState.RUNNING,
        sentry_integration=what.sentry_integration,
    )


# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution,
# to send requests back to the supervisor process.
#
# Why it needs to be a global:
# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests
#   to the parent process during task execution.
# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the
#   deeply nested execution stack.
# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily
#   accessible wherever needed during task execution without modifying every layer of the call stack.
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
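
# A minimal usage sketch: any code in this process can reach the supervisor through
# the global once `main()` (or `reinit_supervisor_comms()`) has initialized it, e.g.
# (the ids are hypothetical; the request/response models are imported above):
#
#     response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id="my_dag", run_id="run_1"))
#     # `response` is a DagRunStateResult carrying the run's state string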



# State machine!
# 1. Start up (receive details from supervisor)
# 2. Execution (run task code, possibly send requests)
# 3. Shutdown and report status


def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
    # The parent sends us a StartupDetails message unprompted. After this, every single message is only sent
    # in response to us sending a request.
    log = structlog.get_logger(logger_name="task")

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and (
        msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG")
    ):
        # Clear any Kerberos credential cache if there is one, so the new process can't reuse it.
        os.environ.pop("KRB5CCNAME", None)
        # entrypoint of the re-exec process

        msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson)
        reinit_supervisor_comms()

        # We delay this message until _after_ we've got the logging re-configured, otherwise it will show up
        # on stdout
        log.debug("Using serialized startup message from environment", msg=msg)
    else:
        # normal entry point
        msg = SUPERVISOR_COMMS._get_response()  # type: ignore[assignment]

        if not isinstance(msg, StartupDetails):
            raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

    # setproctitle causes issues on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
    os_type = sys.platform
    if os_type == "darwin":
        log.debug("Mac OS detected, skipping setproctitle")
    else:
        from setproctitle import setproctitle

        setproctitle(f"airflow worker -- {msg.ti.id}")

    try:
        get_listener_manager().hook.on_starting(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")

    with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
        ti = parse(msg, log)
        log.debug("Dag file parsed", file=msg.dag_rel_path)

    run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
        "core", "default_impersonation", fallback=None
    )

    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser():
        # enters here to re-exec this process as run_as_user
        os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1"
        # store the startup message in the environment for the re-exec process
        os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json()
        os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True)

        # Import main directly from the module instead of re-executing the file.
        # This ensures that when other modules import
        # airflow.sdk.execution_time.task_runner, they get the same module instance
        # with the properly initialized SUPERVISOR_COMMS global variable.
        # If we re-executed the module with `python -m`, it would load as __main__ and future
        # imports would get a fresh copy without the initialized globals.
        rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()"
        cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code]
        log.info(
            "Running command",
            command=cmd,
        )
        os.execvp("sudo", cmd)

        # ideally, we should never reach here, but if we do, we should return None, None, None
        return None, None, None

    return ti, ti.get_template_context(), log


def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]:
    # TODO: Port one of the following to the Task SDK:
    #   airflow.serialization.helpers.serialize_template_field or
    #   airflow.models.renderedtifields.get_serialized_template_fields
    from airflow.sdk._shared.secrets_masker import redact
    from airflow.serialization.helpers import serialize_template_field

    rendered_fields = {}
    for field in task.template_fields:
        value = getattr(task, field)
        serialized = serialize_template_field(value, field)
        # Redact secrets in the task process itself before sending to the API server.
        # This ensures that secrets registered via mask_secret() on workers / the dag processor are
        # properly masked on the UI.
        rendered_fields[field] = redact(serialized, field)

    return rendered_fields  # type: ignore[return-value]  # Convince mypy that this is OK since we pass JsonValue to redact, so it will return the same


def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]:
    # Lineage can have other types of objects besides assets, so we need to process them a bit.
    for obj in lineage_objects or ():
        if isinstance(obj, Asset):
            yield AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__)
        elif isinstance(obj, AssetNameRef):
            yield AssetProfile(name=obj.name, type=AssetNameRef.__name__)
        elif isinstance(obj, AssetUriRef):
            yield AssetProfile(uri=obj.uri, type=AssetUriRef.__name__)
        elif isinstance(obj, AssetAlias):
            yield AssetProfile(name=obj.name, type=AssetAlias.__name__)


def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]:
    if TYPE_CHECKING:
        assert isinstance(events, OutletEventAccessors)
    # We just collect everything the user recorded in the accessors.
    # Further filtering will be done in the API server.
    for key, accessor in events._dict.items():
        if isinstance(key, AssetUniqueKey):
            yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra}
        for alias_event in accessor.asset_alias_events:
            yield attrs.asdict(alias_event)


def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None:
    ti.hostname = get_hostname()
    ti.task = ti.task.prepare_for_execution()
    # Since the context is cached and `ti.get_template_context` returns the same dict, we also need to
    # update the task value it hands out
    context["task"] = ti.task

    jinja_env = ti.task.dag.get_template_env()
    ti.render_templates(context=context, jinja_env=jinja_env)

    if rendered_fields := _serialize_rendered_fields(ti.task):
        # so that we do not call the API unnecessarily
        SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields))

    # Try to render map_index_template early with the available context (it will be re-rendered after
    # execution). This provides a partial label during task execution for templates using pre-execution
    # context. If rendering fails here, we suppress the error since it will be re-rendered after execution.
    try:
        if rendered_map_index := _render_map_index(context, ti=ti, log=log):
            ti.rendered_map_index = rendered_map_index
            log.debug("Sending early rendered map index", length=len(rendered_map_index))
            SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index))
    except Exception:
        log.debug(
            "Early rendering of map_index_template failed, will retry after task execution", exc_info=True
        )

    _validate_task_inlets_and_outlets(ti=ti, log=log)

    try:
        # TODO: Call pre execute etc.
        get_listener_manager().hook.on_task_instance_running(
            previous_state=TaskInstanceState.QUEUED, task_instance=ti
        )
    except Exception:
        log.exception("error calling listener")

    # No error, carry on and execute the task
    return None


def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
    if not ti.task.inlets and not ti.task.outlets:
        return

    inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id))
    if TYPE_CHECKING:
        assert isinstance(inactive_assets_resp, InactiveAssetsResult)
    if inactive_assets := inactive_assets_resp.inactive_assets:
        raise AirflowInactiveAssetInInletOrOutletException(
            inactive_asset_keys=[
                AssetUniqueKey.from_profile(asset_profile) for asset_profile in inactive_assets
            ]
        )


def _defer_task(
    defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?

    log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
    classpath, trigger_kwargs = defer.trigger.serialize()

    msg = DeferTask(
        classpath=classpath,
        trigger_kwargs=trigger_kwargs,
        trigger_timeout=defer.timeout,
        next_method=defer.method_name,
        next_kwargs=defer.kwargs or {},
    )
    state = TaskInstanceState.DEFERRED

    return msg, state
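
# Sketch of the operator-side code that produces the TaskDeferred handled above
# (the trigger instance and timeout value are hypothetical):
#
#     from datetime import timedelta
#
#     raise TaskDeferred(
#         trigger=my_trigger,              # any serializable trigger instance
#         method_name="execute_complete",  # method invoked when the trigger fires
#         kwargs={"note": "resumed"},
#         timeout=timedelta(hours=1),
#     )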



@Sentry.enrich_errors
def run(
    ti: RuntimeTaskInstance,
    context: Context,
    log: Logger,
) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]:
    """Run the task in this process."""
    import signal

    from airflow.sdk.exceptions import (
        AirflowFailException,
        AirflowRescheduleException,
        AirflowSensorTimeout,
        AirflowSkipException,
        AirflowTaskTerminated,
        DagRunTriggerException,
        DownstreamTasksSkipped,
        TaskDeferred,
    )

    if TYPE_CHECKING:
        assert ti.task is not None
        assert isinstance(ti.task, BaseOperator)

    parent_pid = os.getpid()

    def _on_term(signum, frame):
        pid = os.getpid()
        if pid != parent_pid:
            return

        ti.task.on_kill()

    signal.signal(signal.SIGTERM, _on_term)

    msg: ToSupervisor | None = None
    state: TaskInstanceState
    error: BaseException | None = None

    try:
        # First, clear the xcom data sent from the server
        if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear):
            for x in keys_to_delete:
                log.debug("Clearing XCom with key", key=x)
                XCom.delete(
                    key=x,
                    dag_id=ti.dag_id,
                    task_id=ti.task_id,
                    run_id=ti.run_id,
                    map_index=ti.map_index,
                )

        with set_current_context(context):
            # This is the earliest that we can render templates -- if it raises for any reason we need to
            # catch it and handle it like a normal task failure
            if early_exit := _prepare(ti, log, context):
                msg = early_exit
                ti.state = state = TaskInstanceState.FAILED
                return state, msg, error

            try:
                result = _execute_task(context=context, ti=ti, log=log)
            except Exception:
                import jinja2

                # If the task failed, swallow rendering errors so they don't mask the main error.
                with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
                    previous_rendered_map_index = ti.rendered_map_index
                    ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                    # Send an update only if the value changed (e.g., user set context variables during execution)
                    if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                        SUPERVISOR_COMMS.send(
                            msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)
                        )
                raise
            else:  # If the task succeeded, render normally to let rendering errors bubble up.
                previous_rendered_map_index = ti.rendered_map_index
                ti.rendered_map_index = _render_map_index(context, ti=ti, log=log)
                # Send an update only if the value changed (e.g., user set context variables during execution)
                if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index:
                    SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index))

            _push_xcom_if_needed(result, ti, log)

            msg, state = _handle_current_task_success(context, ti)
    except DownstreamTasksSkipped as skip:
        log.info("Skipping downstream tasks.")
        tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks]
        SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip))
        msg, state = _handle_current_task_success(context, ti)
    except DagRunTriggerException as drte:
        msg, state = _handle_trigger_dag_run(drte, context, ti, log)
    except TaskDeferred as defer:
        msg, state = _defer_task(defer, ti, log)
    except AirflowSkipException as e:
        if e.args:
            log.info("Skipping task.", reason=e.args[0])
        msg = TaskState(
            state=TaskInstanceState.SKIPPED,
            end_date=datetime.now(tz=timezone.utc),
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.SKIPPED
    except AirflowRescheduleException as reschedule:
        log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
        msg = RescheduleTask(
            reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
        )
        state = TaskInstanceState.UP_FOR_RESCHEDULE
    except (AirflowFailException, AirflowSensorTimeout) as e:
        # If AirflowFailException is raised, the task should not retry.
        # If a sensor in reschedule mode reaches its timeout, the task should not retry.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
        # We should allow retries if the task has defined them.
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    except AirflowTaskTerminated as e:
        # External state updates are already handled with `ti_heartbeat` and the state will already have
        # been updated by another UI API. So, these exceptions should ideally never be thrown.
        # If they are thrown, we should mark the TI state as failed.
        log.exception("Task failed with exception")
        ti.end_date = datetime.now(tz=timezone.utc)
        msg = TaskState(
            state=TaskInstanceState.FAILED,
            end_date=ti.end_date,
            rendered_map_index=ti.rendered_map_index,
        )
        state = TaskInstanceState.FAILED
        error = e
    except SystemExit as e:
        # SystemExit needs to be retried if the task is eligible.
        log.error("Task exited", exit_code=e.code)
        msg, state = _handle_current_task_failed(ti)
        error = e
    except BaseException as e:
        log.exception("Task failed with exception")
        msg, state = _handle_current_task_failed(ti)
        error = e
    finally:
        if msg:
            SUPERVISOR_COMMS.send(msg=msg)

    # Return the message to make unit tests easier too
    ti.state = state
    return state, msg, error


def _handle_current_task_success(
    context: Context,
    ti: RuntimeTaskInstance,
) -> tuple[SucceedTask, TaskInstanceState]:
    end_date = datetime.now(tz=timezone.utc)
    ti.end_date = end_date

    # Record operator and task instance success metrics
    operator = ti.task.__class__.__name__
    stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

    Stats.incr(f"operator_successes_{operator}", tags=stats_tags)
    # Same metric with tagging
    Stats.incr("operator_successes", tags={**stats_tags, "operator": operator})
    Stats.incr("ti_successes", tags=stats_tags)

    task_outlets = list(_build_asset_profiles(ti.task.outlets))
    outlet_events = list(_serialize_outlet_events(context["outlet_events"]))
    msg = SucceedTask(
        end_date=end_date,
        task_outlets=task_outlets,
        outlet_events=outlet_events,
        rendered_map_index=ti.rendered_map_index,
    )
    return msg, TaskInstanceState.SUCCESS


def _handle_current_task_failed(
    ti: RuntimeTaskInstance,
) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
    end_date = datetime.now(tz=timezone.utc)
    ti.end_date = end_date
    if ti._ti_context_from_server and ti._ti_context_from_server.should_retry:
        return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY
    return (
        TaskState(
            state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index
        ),
        TaskInstanceState.FAILED,
    )


def _handle_trigger_dag_run(
    drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
    """Handle the exception raised by TriggerDagRunOperator."""
    log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
    comms_msg = SUPERVISOR_COMMS.send(
        TriggerDagRun(
            dag_id=drte.trigger_dag_id,
            run_id=drte.dag_run_id,
            logical_date=drte.logical_date,
            conf=drte.conf,
            reset_dag_run=drte.reset_dag_run,
        ),
    )

    if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
        if drte.skip_when_already_exists:
            log.info(
                "Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
                dag_id=drte.trigger_dag_id,
            )
            msg = TaskState(
                state=TaskInstanceState.SKIPPED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.SKIPPED
        else:
            log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
            msg = TaskState(
                state=TaskInstanceState.FAILED,
                end_date=datetime.now(tz=timezone.utc),
                rendered_map_index=ti.rendered_map_index,
            )
            state = TaskInstanceState.FAILED

        return msg, state

    log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)

    # Store the run id from the dag run (either created or found above) to
    # be used when creating the extra link on the webserver.
    ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)

    if drte.deferrable:
        from airflow.providers.standard.triggers.external_task import DagStateTrigger

        defer = TaskDeferred(
            trigger=DagStateTrigger(
                dag_id=drte.trigger_dag_id,
                states=drte.allowed_states + drte.failed_states,  # type: ignore[arg-type]
                # Don't filter by execution_dates when run_ids is provided.
                # run_id uniquely identifies a DAG run, and when reset_dag_run=True,
                # drte.logical_date might be a newly calculated value that doesn't match
                # the persisted logical_date in the database, causing the trigger to never find the run.
                execution_dates=None,
                run_ids=[drte.dag_run_id],
                poll_interval=drte.poke_interval,
            ),
            method_name="execute_complete",
        )
        return _defer_task(defer, ti, log)
    if drte.wait_for_completion:
        while True:
            log.info(
                "Waiting for dag run to complete execution in allowed state.",
                dag_id=drte.trigger_dag_id,
                run_id=drte.dag_run_id,
                allowed_state=drte.allowed_states,
            )
            time.sleep(drte.poke_interval)

            comms_msg = SUPERVISOR_COMMS.send(
                GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
            )
            if TYPE_CHECKING:
                assert isinstance(comms_msg, DagRunStateResult)
            if comms_msg.state in drte.failed_states:
                log.error(
                    "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                msg = TaskState(
                    state=TaskInstanceState.FAILED,
                    end_date=datetime.now(tz=timezone.utc),
                    rendered_map_index=ti.rendered_map_index,
                )
                state = TaskInstanceState.FAILED
                return msg, state
            if comms_msg.state in drte.allowed_states:
                log.info(
                    "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state
                )
                break
            log.debug(
                "DagRun not yet in allowed or failed state.",
                dag_id=drte.trigger_dag_id,
                state=comms_msg.state,
            )

    return _handle_current_task_success(context, ti)


def _run_task_state_change_callbacks(
    task: BaseOperator,
    kind: Literal[
        "on_execute_callback",
        "on_failure_callback",
        "on_success_callback",
        "on_retry_callback",
        "on_skipped_callback",
    ],
    context: Context,
    log: Logger,
) -> None:
    callback: Callable[[Context], None]
    for i, callback in enumerate(getattr(task, kind)):
        try:
            create_executable_runner(callback, context_get_outlet_events(context), logger=log).run(context)
        except Exception:
            log.exception("Failed to run task callback", kind=kind, index=i, callback=callback)
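
# Sketch of the Dag-author side (hypothetical callback and operator): each entry in
# a callback list such as `on_failure_callback` is a callable taking the context,
# and the loop above runs them one by one, logging (not propagating) failures:
#
#     def notify(context):
#         print("task failed:", context["ti"].task_id)
#
#     my_task = SomeOperator(task_id="t", on_failure_callback=[notify])  # illustrative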



def _send_error_email_notification(
    task: BaseOperator | MappedOperator,
    ti: RuntimeTaskInstance,
    context: Context,
    error: BaseException | str | None,
    log: Logger,
) -> None:
    """Send an email notification for task errors using SmtpNotifier."""
    try:
        from airflow.providers.smtp.notifications.smtp import SmtpNotifier
    except ImportError:
        log.error(
            "Failed to send task failure or retry email notification: "
            "`apache-airflow-providers-smtp` is not installed. "
            "Install this provider to enable email notifications."
        )
        return

    if not task.email:
        return

    subject_template_file = conf.get("email", "subject_template", fallback=None)

    # Read the template file if configured
    if subject_template_file and Path(subject_template_file).exists():
        subject = Path(subject_template_file).read_text()
    else:
        # Fall back to the default
        subject = "Airflow alert: {{ti}}"

    html_content_template_file = conf.get("email", "html_content_template", fallback=None)

    # Read the template file if configured
    if html_content_template_file and Path(html_content_template_file).exists():
        html_content = Path(html_content_template_file).read_text()
    else:
        # Fall back to the default.
        # For reporting purposes, we report based on 1-indexed,
        # not 0-indexed lists (i.e. Try 1 instead of Try 0 for the first attempt).
        html_content = (
            "Try {{try_number}} out of {{max_tries + 1}}<br>"
            "Exception:<br>{{exception_html}}<br>"
            'Log: <a href="{{ti.log_url}}">Link</a><br>'
            "Host: {{ti.hostname}}<br>"
            'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
        )

    # Add exception_html to the context for template rendering
    import html

    exception_html = html.escape(str(error)).replace("\n", "<br>")
    additional_context = {
        "exception": error,
        "exception_html": exception_html,
        "try_number": ti.try_number,
        "max_tries": ti.max_tries,
    }
    email_context = {**context, **additional_context}
    to_emails = task.email
    if not to_emails:
        return

    try:
        notifier = SmtpNotifier(
            to=to_emails,
            subject=subject,
            html_content=html_content,
            from_email=conf.get("email", "from_email", fallback="airflow@airflow"),
        )
        notifier(email_context)
    except Exception:
        log.exception("Failed to send email notification")


def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
    """Execute the task (optionally with a timeout) and push XCom results."""
    task = ti.task
    execute = task.execute

    if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method):
        from airflow.serialization.serialized_objects import BaseSerialization

        kwargs = BaseSerialization.deserialize(ti._ti_context_from_server.next_kwargs or {})

        execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)

    ctx = contextvars.copy_context()
    # Populate the context var so ExecutorSafeguard doesn't complain
    ctx.run(ExecutorSafeguard.tracker.set, task)

    # Export the context to os.environ to make it available for operators to use.
    airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
    os.environ.update(airflow_context_vars)

    outlet_events = context_get_outlet_events(context)

    if (pre_execute_hook := task._pre_execute_hook) is not None:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
    if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)

    _run_task_state_change_callbacks(task, "on_execute_callback", context, log)

    if task.execution_timeout:
        from airflow.sdk.execution_time.timeout import timeout

        # TODO: handle timeout in case of deferral
        timeout_seconds = task.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run the task in the timeout wrapper
            with timeout(timeout_seconds):
                result = ctx.run(execute, context=context)
        except AirflowTaskTimeout:
            task.on_kill()
            raise
    else:
        result = ctx.run(execute, context=context)

    if (post_execute_hook := task._post_execute_hook) is not None:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
    if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context)

    return result
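
# A usage sketch of the execution_timeout path above (illustrative operator and
# values): exceeding the timeout raises AirflowTaskTimeout and calls task.on_kill().
#
#     from datetime import timedelta
#
#     BashOperator(
#         task_id="slow_step",
#         bash_command="sleep 120",
#         execution_timeout=timedelta(seconds=60),
#     )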



def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None:
    """Render the named map index if the Dag author defined map_index_template at the task level."""
    if (template := context.get("map_index_template")) is None:
        return None
    log.debug("Rendering map_index_template", template_length=len(template))
    jinja_env = ti.task.dag.get_template_env()
    rendered_map_index = jinja_env.from_string(template).render(context)
    log.debug("Map index rendered", length=len(rendered_map_index))
    return rendered_map_index
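
# Dag-author side sketch (hypothetical operator, template, and values): a mapped
# task can label each expanded instance with map_index_template, which the function
# above renders against the task context:
#
#     MyOperator.partial(
#         task_id="process", map_index_template="{{ task.op_kwargs['filename'] }}"
#     ).expand(op_kwargs=[{"filename": "a.csv"}, {"filename": "b.csv"}])
#     # the UI then shows "a.csv" / "b.csv" instead of map indexes 0 / 1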



def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
    """Push XCom values when the task has ``do_xcom_push`` set to ``True`` and returns a result."""
    if ti.task.do_xcom_push:
        xcom_value = result
    else:
        xcom_value = None

    has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None
    if xcom_value is None:
        if not ti.is_mapped and has_mapped_dep:
            # Uhoh, a downstream mapped task depends on us to push something to map over
            from airflow.sdk.exceptions import XComForMappingNotPushed

            raise XComForMappingNotPushed()
        return

    mapped_length: int | None = None
    if not ti.is_mapped and has_mapped_dep:
        from airflow.sdk.definitions.mappedoperator import is_mappable_value
        from airflow.sdk.exceptions import UnmappableXComTypePushed

        if not is_mappable_value(xcom_value):
            raise UnmappableXComTypePushed(xcom_value)
        mapped_length = len(xcom_value)

    log.info("Pushing xcom", ti=ti)

    # If the task has multiple outputs, push each output as a separate XCom.
    if ti.task.multiple_outputs:
        if not isinstance(xcom_value, Mapping):
            raise TypeError(
                f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
            )
        for key in xcom_value.keys():
            if not isinstance(key, str):
                raise TypeError(
                    "Returned dictionary keys must be strings when using "
                    f"multiple_outputs, found {key} ({type(key)}) instead"
                )
        for k, v in result.items():
            ti.xcom_push(k, v)

    _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length)
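
# Example of the multiple_outputs branch above (illustrative return value): a task
# returning
#
#     {"rows": 10, "path": "/tmp/out.csv"}
#
# with multiple_outputs=True pushes XComs "rows" and "path" individually, and the
# whole dict is still pushed under the "return_value" key by the final call.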



def finalize(
    ti: RuntimeTaskInstance,
    state: TaskInstanceState,
    context: Context,
    log: Logger,
    error: BaseException | None = None,
):
    # Record task duration metrics for all terminal states
    if ti.start_date and ti.end_date:
        duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000
        stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}

        Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms)
        Stats.timing("task.duration", duration_ms, tags=stats_tags)

    task = ti.task
    # Push an XCom for each operator extra link defined on the operator.
    for oe in task.operator_extra_links:
        try:
            link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key  # type: ignore[arg-type]
            log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
            _xcom_push_to_db(ti, key=xcom_key, value=link)
        except Exception:
            log.exception(
                "Failed to push an xcom for task operator extra link",
                link_name=oe.name,
                xcom_key=oe.xcom_key,
                ti=ti,
            )

    if getattr(ti.task, "overwrite_rtif_after_execution", False):
        log.debug("Overwriting Rendered template fields.")
        if ti.task.template_fields:
            SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)))

    log.debug("Running finalizers", ti=ti)
    if state == TaskInstanceState.SUCCESS:
        _run_task_state_change_callbacks(task, "on_success_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_success(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti
            )
        except Exception:
            log.exception("error calling listener")
    elif state == TaskInstanceState.SKIPPED:
        _run_task_state_change_callbacks(task, "on_skipped_callback", context, log)
    elif state == TaskInstanceState.UP_FOR_RETRY:
        _run_task_state_change_callbacks(task, "on_retry_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_retry and task.email:
            _send_error_email_notification(task, ti, context, error, log)
    elif state == TaskInstanceState.FAILED:
        _run_task_state_change_callbacks(task, "on_failure_callback", context, log)
        try:
            get_listener_manager().hook.on_task_instance_failed(
                previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
            )
        except Exception:
            log.exception("error calling listener")
        if error and task.email_on_failure and task.email:
            _send_error_email_notification(task, ti, context, error, log)

    try:
        get_listener_manager().hook.before_stopping(component=TaskRunnerMarker())
    except Exception:
        log.exception("error calling listener")


def main():
    log = structlog.get_logger(logger_name="task")

    global SUPERVISOR_COMMS
    SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)

    try:
        ti, context, log = startup()
        with BundleVersionLock(
            bundle_name=ti.bundle_instance.name,
            bundle_version=ti.bundle_instance.version,
        ):
            state, _, error = run(ti, context, log)
            context["exception"] = error
            finalize(ti, state, context, log, error)
    except KeyboardInterrupt:
        log.exception("Ctrl-c hit")
        exit(2)
    except Exception:
        log.exception("Top level error")
        exit(1)
    finally:
        # Ensure the request socket is closed on the child side in all circumstances
        # before the process fully terminates.
        if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket:
            with suppress(Exception):
                SUPERVISOR_COMMS.socket.close()


def reinit_supervisor_comms() -> None:
    """
    Re-initialize supervisor comms and the logging channel in a subprocess.

    This is not needed in most cases, but is used when we either re-launch the process via sudo for
    run_as_user, or re-connect from inside the python code of a virtualenv (et al.) operator so those
    tasks can continue to access variables etc.
    """
    import socket

    if "SUPERVISOR_COMMS" not in globals():
        global SUPERVISOR_COMMS
        log = structlog.get_logger(logger_name="task")

        fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0"))

        SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd))

    logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
    if isinstance(logs, SentFDs):
        from airflow.sdk.log import configure_logging

        log_io = os.fdopen(logs.fds[0], "wb", buffering=0)
        configure_logging(json_output=True, output=log_io, sending_to_supervisor=True)
    else:
        print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)


if __name__ == "__main__":
    main()