Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/execution_time/task_runner.py: 14%


796 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""The entrypoint for the actual task execution process.""" 

19 

20from __future__ import annotations 

21 

22import contextlib 

23import contextvars 

24import functools 

25import os 

26import sys 

27import time 

28from collections.abc import Callable, Iterable, Iterator, Mapping 

29from contextlib import suppress 

30from datetime import datetime, timedelta, timezone 

31from itertools import product 

32from pathlib import Path 

33from typing import TYPE_CHECKING, Annotated, Any, Literal 

34from urllib.parse import quote 

35 

36import attrs 

37import lazy_object_proxy 

38import structlog 

39from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter 

40 

41from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock 

42from airflow.dag_processing.bundles.manager import DagBundlesManager 

43from airflow.sdk._shared.observability.metrics.stats import Stats 

44from airflow.sdk.api.client import get_hostname, getuser 

45from airflow.sdk.api.datamodels._generated import ( 

46 AssetProfile, 

47 DagRun, 

48 PreviousTIResponse, 

49 TaskInstance, 

50 TaskInstanceState, 

51 TIRunContext, 

52) 

53from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard 

54from airflow.sdk.bases.xcom import BaseXCom 

55from airflow.sdk.configuration import conf 

56from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager 

57from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set 

58from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef 

59from airflow.sdk.definitions.mappedoperator import MappedOperator 

60from airflow.sdk.definitions.param import process_params 

61from airflow.sdk.exceptions import ( 

62 AirflowException, 

63 AirflowInactiveAssetInInletOrOutletException, 

64 AirflowRescheduleException, 

65 AirflowRuntimeError, 

66 AirflowTaskTimeout, 

67 ErrorType, 

68 TaskDeferred, 

69) 

70from airflow.sdk.execution_time.callback_runner import create_executable_runner 

71from airflow.sdk.execution_time.comms import ( 

72 AssetEventDagRunReferenceResult, 

73 CommsDecoder, 

74 DagRunStateResult, 

75 DeferTask, 

76 DRCount, 

77 ErrorResponse, 

78 GetDagRunState, 

79 GetDRCount, 

80 GetPreviousDagRun, 

81 GetPreviousTI, 

82 GetTaskBreadcrumbs, 

83 GetTaskRescheduleStartDate, 

84 GetTaskStates, 

85 GetTICount, 

86 InactiveAssetsResult, 

87 PreviousDagRunResult, 

88 PreviousTIResult, 

89 RescheduleTask, 

90 ResendLoggingFD, 

91 RetryTask, 

92 SentFDs, 

93 SetRenderedFields, 

94 SetRenderedMapIndex, 

95 SkipDownstreamTasks, 

96 StartupDetails, 

97 SucceedTask, 

98 TaskBreadcrumbsResult, 

99 TaskRescheduleStartDate, 

100 TaskState, 

101 TaskStatesResult, 

102 TICount, 

103 ToSupervisor, 

104 ToTask, 

105 TriggerDagRun, 

106 ValidateInletsAndOutlets, 

107) 

108from airflow.sdk.execution_time.context import ( 

109 ConnectionAccessor, 

110 InletEventsAccessors, 

111 MacrosAccessor, 

112 OutletEventAccessors, 

113 TriggeringAssetEventsAccessor, 

114 VariableAccessor, 

115 context_get_outlet_events, 

116 context_to_airflow_vars, 

117 get_previous_dagrun_success, 

118 set_current_context, 

119) 

120from airflow.sdk.execution_time.sentry import Sentry 

121from airflow.sdk.execution_time.xcom import XCom 

122from airflow.sdk.listener import get_listener_manager 

123from airflow.sdk.timezone import coerce_datetime 

124 

125if TYPE_CHECKING: 

126 import jinja2 

127 from pendulum.datetime import DateTime 

128 from structlog.typing import FilteringBoundLogger as Logger 

129 

130 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator 

131 from airflow.sdk.definitions.context import Context 

132 from airflow.sdk.exceptions import DagRunTriggerException 

133 from airflow.sdk.types import OutletEventAccessorsProtocol 

134 

135 

136class TaskRunnerMarker: 

137 """Marker for listener hooks, to properly detect from which component they are called.""" 

138 

139 

140# TODO: Move this entire class into a separate file: 

141# `airflow/sdk/execution_time/task_instance.py` 

142# or `airflow/sdk/execution_time/runtime_ti.py` 

143class RuntimeTaskInstance(TaskInstance): 

144 model_config = ConfigDict(arbitrary_types_allowed=True) 

145 

146 task: BaseOperator 

147 bundle_instance: BaseDagBundle 

148 _cached_template_context: Context | None = None 

149 """The Task Instance context. This is used to cache get_template_context.""" 

150 

151 _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None 

152 """The Task Instance context from the API server, if any.""" 

153 

154 max_tries: int = 0 

155 """The maximum number of retries for the task.""" 

156 

157 start_date: AwareDatetime 

158 """Start date of the task instance.""" 

159 

160 end_date: AwareDatetime | None = None 

161 

162 state: TaskInstanceState | None = None 

163 

164 is_mapped: bool | None = None 

165 """True if the original task was mapped.""" 

166 

167 rendered_map_index: str | None = None 

168 

169 sentry_integration: str = "" 

170 

171 def __rich_repr__(self): 

172 yield "id", self.id 

173 yield "task_id", self.task_id 

174 yield "dag_id", self.dag_id 

175 yield "run_id", self.run_id 

176 yield "max_tries", self.max_tries 

177 yield "task", type(self.task) 

178 yield "start_date", self.start_date 

179 

180 __rich_repr__.angular = True # type: ignore[attr-defined] 

181 

182 def get_template_context(self) -> Context: 

183 # TODO: Move this to `airflow.sdk.execution_time.context` 

184 # once we port the entire context logic from airflow/utils/context.py ? 

185 from airflow.sdk.plugins_manager import integrate_macros_plugins 

186 

187 integrate_macros_plugins() 

188 

189 dag_run_conf: dict[str, Any] | None = None 

190 if from_server := self._ti_context_from_server: 

191 dag_run_conf = from_server.dag_run.conf or dag_run_conf 

192 

193 validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False) 

194 

195 # Cache the context object, which ensures that all calls to get_template_context 

196 # are operating on the same context object. 

197 self._cached_template_context: Context = self._cached_template_context or { 

198 # From the Task Execution interface 

199 "dag": self.task.dag, 

200 "inlets": self.task.inlets, 

201 "map_index_template": self.task.map_index_template, 

202 "outlets": self.task.outlets, 

203 "run_id": self.run_id, 

204 "task": self.task, 

205 "task_instance": self, 

206 "ti": self, 

207 "outlet_events": OutletEventAccessors(), 

208 "inlet_events": InletEventsAccessors(self.task.inlets), 

209 "macros": MacrosAccessor(), 

210 "params": validated_params, 

211 # TODO: Make this go through Public API longer term. 

212 # "test_mode": task_instance.test_mode, 

213 "var": { 

214 "json": VariableAccessor(deserialize_json=True), 

215 "value": VariableAccessor(deserialize_json=False), 

216 }, 

217 "conn": ConnectionAccessor(), 

218 } 

219 if from_server: 

220 dag_run = from_server.dag_run 

221 context_from_server: Context = { 

222 # TODO: Assess if we need to pass these through timezone.coerce_datetime 

223 "dag_run": dag_run, # type: ignore[typeddict-item] # Removable after #46522 

224 "triggering_asset_events": TriggeringAssetEventsAccessor.build( 

225 AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event) 

226 for event in dag_run.consumed_asset_events 

227 ), 

228 "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}", 

229 "task_reschedule_count": from_server.task_reschedule_count or 0, 

230 "prev_start_date_success": lazy_object_proxy.Proxy( 

231 lambda: coerce_datetime(get_previous_dagrun_success(self.id).start_date) 

232 ), 

233 "prev_end_date_success": lazy_object_proxy.Proxy( 

234 lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date) 

235 ), 

236 } 

237 self._cached_template_context.update(context_from_server) 

238 

239 if logical_date := coerce_datetime(dag_run.logical_date): 

240 if TYPE_CHECKING: 

241 assert isinstance(logical_date, DateTime) 

242 ds = logical_date.strftime("%Y-%m-%d") 

243 ds_nodash = ds.replace("-", "") 

244 ts = logical_date.isoformat() 

245 ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S") 

246 ts_nodash_with_tz = ts.replace("-", "").replace(":", "") 

247 # logical_date and data_interval are either both set or both None 

248 self._cached_template_context.update( 

249 { 

250 # keys that depend on logical_date 

251 "logical_date": logical_date, 

252 "ds": ds, 

253 "ds_nodash": ds_nodash, 

254 "task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{ds_nodash}", 

255 "ts": ts, 

256 "ts_nodash": ts_nodash, 

257 "ts_nodash_with_tz": ts_nodash_with_tz, 

258 # keys that depend on data_interval 

259 "data_interval_end": coerce_datetime(dag_run.data_interval_end), 

260 "data_interval_start": coerce_datetime(dag_run.data_interval_start), 

261 "prev_data_interval_start_success": lazy_object_proxy.Proxy( 

262 lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_start) 

263 ), 

264 "prev_data_interval_end_success": lazy_object_proxy.Proxy( 

265 lambda: coerce_datetime(get_previous_dagrun_success(self.id).data_interval_end) 

266 ), 

267 } 

268 ) 

269 

270 # Backward compatibility: old servers may still send upstream_map_indexes 

271 upstream_map_indexes = getattr(from_server, "upstream_map_indexes", None) 

272 if upstream_map_indexes is not None: 

273 setattr(self, "_upstream_map_indexes", upstream_map_indexes) 

274 

275 return self._cached_template_context 

276 

277 def render_templates( 

278 self, context: Context | None = None, jinja_env: jinja2.Environment | None = None 

279 ) -> BaseOperator: 

280 """ 

281 Render templates in the operator fields. 

282 

283 If the task was originally mapped, this may replace ``self.task`` with 

284 the unmapped, fully rendered BaseOperator. The original ``self.task`` 

285 before replacement is returned. 

286 """ 

287 if not context: 

288 context = self.get_template_context() 

289 original_task = self.task 

290 

291 if TYPE_CHECKING: 

292 assert context 

293 

294 ti = context["ti"] 

295 

296 if TYPE_CHECKING: 

297 assert original_task 

298 assert self.task 

299 assert ti.task 

300 

301 # If self.task is mapped, this call replaces self.task to point to the 

302 # unmapped BaseOperator created by this function! This is because the 

303 # MappedOperator is useless for template rendering, and we need to be 

304 # able to access the unmapped task instead. 

305 self.task.render_template_fields(context, jinja_env) 

306 self.is_mapped = original_task.is_mapped 

307 return original_task 

308 

309 def xcom_pull( 

310 self, 

311 task_ids: str | Iterable[str] | None = None, 

312 dag_id: str | None = None, 

313 key: str = BaseXCom.XCOM_RETURN_KEY, 

314 include_prior_dates: bool = False, 

315 *, 

316 map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET, 

317 default: Any = None, 

318 run_id: str | None = None, 

319 ) -> Any: 

320 """ 

321 Pull XComs either from the API server (BaseXCom) or from the custom XCOM backend if configured. 

322 

323 The pull can optionally be filtered by certain criteria. 

324 

325 :param key: A key for the XCom. If provided, only XComs with matching 

326 keys will be returned. The default key is ``'return_value'``, also 

327 available as constant ``XCOM_RETURN_KEY``. This key is automatically 

328 given to XComs returned by tasks (as opposed to being pushed 

329 manually). 

330 :param task_ids: Only XComs from tasks with matching ids will be 

331 pulled. If *None* (default), the task_id of the calling task is used. 

332 :param dag_id: If provided, only pulls XComs from this Dag. If *None* 

333 (default), the Dag of the calling task is used. 

334 :param map_indexes: If provided, only pull XComs with matching indexes. 

335 If *None* (default), this is inferred from the task(s) being pulled 

336 (see below for details). 

337 :param include_prior_dates: If False, only XComs from the current 

338 logical_date are returned. If *True*, XComs from previous dates 

339 are returned as well. 

340 :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id. 

341 If *None* (default), the run_id of the calling task is used. 

342 

343 When pulling a single task (``task_id`` is *None* or a str) without 

344 specifying ``map_indexes``, the return value is a single XCom entry 

345 (map_indexes is set to map_index of the calling task instance). 

346 

347 When the task being pulled from is mapped, the specified ``map_index`` is used, so by default 

348 pulling from a mapped task will result in no matching XComs if the task instance 

349 calling this method is not mapped. Otherwise, the map_index of the calling task 

350 instance is used. Setting ``map_indexes`` to *None* will pull XComs as it would 

351 from a non-mapped task. 

352 

353 In either case, ``default`` (*None* if not specified) is returned if no 

354 matching XComs are found. 

355 

356 When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is 

357 a non-str iterable), a list of matching XComs is returned. Elements in 

358 the list are ordered by item ordering in ``task_id`` and ``map_index``. 

359 """ 

360 if dag_id is None: 

361 dag_id = self.dag_id 

362 if run_id is None: 

363 run_id = self.run_id 

364 

365 single_task_requested = isinstance(task_ids, (str, type(None))) 

366 single_map_index_requested = isinstance(map_indexes, (int, type(None))) 

367 

368 if task_ids is None: 

369 # default to the current task if not provided 

370 task_ids = [self.task_id] 

371 elif isinstance(task_ids, str): 

372 task_ids = [task_ids] 

373 

374 # If map_indexes is not specified, pull xcoms from all map indexes for each task 

375 if not is_arg_set(map_indexes): 

376 xcoms: list[Any] = [] 

377 for t_id in task_ids: 

378 values = XCom.get_all( 

379 run_id=run_id, 

380 key=key, 

381 task_id=t_id, 

382 dag_id=dag_id, 

383 include_prior_dates=include_prior_dates, 

384 ) 

385 

386 if values is None: 

387 xcoms.append(None) 

388 else: 

389 xcoms.extend(values) 

390 # For single task pulling from unmapped task, return single value 

391 if single_task_requested and len(xcoms) == 1: 

392 return xcoms[0] 

393 return xcoms 

394 

395 # Original logic when map_indexes is explicitly specified 

396 map_indexes_iterable: Iterable[int | None] = [] 

397 if isinstance(map_indexes, int) or map_indexes is None: 

398 map_indexes_iterable = [map_indexes] 

399 elif isinstance(map_indexes, Iterable): 

400 map_indexes_iterable = map_indexes 

401 else: 

402 raise TypeError( 

403 f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}" 

404 ) 

405 

406 xcoms = [] 

407 for t_id, m_idx in product(task_ids, map_indexes_iterable): 

408 value = XCom.get_one( 

409 run_id=run_id, 

410 key=key, 

411 task_id=t_id, 

412 dag_id=dag_id, 

413 map_index=m_idx, 

414 include_prior_dates=include_prior_dates, 

415 ) 

416 if value is None: 

417 xcoms.append(default) 

418 else: 

419 xcoms.append(value) 

420 

421 if single_task_requested and single_map_index_requested: 

422 return xcoms[0] 

423 

424 return xcoms 

425 

426 def xcom_push(self, key: str, value: Any): 

427 """ 

428 Make an XCom available for tasks to pull. 

429 

430 :param key: Key to store the value under. 

431 :param value: Value to store. Only JSON-serializable values may be used. 

432 """ 

433 _xcom_push(self, key, value) 

434 

435 def get_relevant_upstream_map_indexes( 

436 self, upstream: BaseOperator, ti_count: int | None, session: Any 

437 ) -> int | range | None: 

438 """ 

439 Compute the relevant upstream map indexes for XCom resolution. 

440 

441 :param upstream: The upstream operator 

442 :param ti_count: The total count of task instances for this task's expansion 

443 :param session: Not used (kept for API compatibility) 

444 :return: None (use entire value), int (single index), or range (subset of indexes) 

445 """ 

446 from airflow.sdk.execution_time.task_mapping import get_relevant_map_indexes, get_ti_count_for_task 

447 

448 map_index = self.map_index 

449 if map_index is None or map_index < 0: 

450 return None 

451 

452 # If ti_count not provided, we need to query it 

453 if ti_count is None: 

454 ti_count = get_ti_count_for_task(self.task_id, self.dag_id, self.run_id) 

455 

456 if not ti_count: 

457 return None 

458 

459 return get_relevant_map_indexes( 

460 task=self.task, 

461 run_id=self.run_id, 

462 map_index=map_index, 

463 ti_count=ti_count, 

464 relative=upstream, 

465 dag_id=self.dag_id, 

466 ) 

467 

468 def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: 

469 """Get the first reschedule date for the task instance if found, none otherwise.""" 

470 if context.get("task_reschedule_count", 0) == 0: 

471 # If the task has not been rescheduled, there is no need to ask the supervisor 

472 return None 

473 

474 max_tries: int = self.max_tries 

475 retries: int = self.task.retries or 0 

476 first_try_number = max_tries - retries + 1 

477 

478 log = structlog.get_logger(logger_name="task") 

479 

480 log.debug("Requesting first reschedule date from supervisor") 

481 

482 response = SUPERVISOR_COMMS.send( 

483 msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) 

484 ) 

485 

486 if TYPE_CHECKING: 

487 assert isinstance(response, TaskRescheduleStartDate) 

488 

489 return response.start_date 

490 

491 def get_previous_dagrun(self, state: str | None = None) -> DagRun | None: 

492 """Return the previous Dag run before the given logical date, optionally filtered by state.""" 

493 context = self.get_template_context() 

494 dag_run = context.get("dag_run") 

495 

496 log = structlog.get_logger(logger_name="task") 

497 

498 log.debug("Getting previous Dag run", dag_run=dag_run) 

499 

500 if dag_run is None: 

501 return None 

502 

503 if dag_run.logical_date is None: 

504 return None 

505 

506 response = SUPERVISOR_COMMS.send( 

507 msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state) 

508 ) 

509 

510 if TYPE_CHECKING: 

511 assert isinstance(response, PreviousDagRunResult) 

512 

513 return response.dag_run 

514 

515 def get_previous_ti( 

516 self, 

517 state: TaskInstanceState | None = None, 

518 logical_date: AwareDatetime | None = None, 

519 map_index: int = -1, 

520 ) -> PreviousTIResponse | None: 

521 """ 

522 Return the previous task instance matching the given criteria. 

523 

524 :param state: Filter by TaskInstance state 

525 :param logical_date: Filter by logical date (returns TI before this date) 

526 :param map_index: Filter by map_index (defaults to -1 for non-mapped tasks) 

527 :return: Previous task instance or None if not found 

528 """ 

529 context = self.get_template_context() 

530 dag_run = context.get("dag_run") 

531 

532 log = structlog.get_logger(logger_name="task") 

533 log.debug("Getting previous task instance", task_id=self.task_id, state=state) 

534 

535 # Use current dag run's logical_date if not provided 

536 effective_logical_date = logical_date 

537 if effective_logical_date is None and dag_run and dag_run.logical_date: 

538 effective_logical_date = dag_run.logical_date 

539 

540 response = SUPERVISOR_COMMS.send( 

541 msg=GetPreviousTI( 

542 dag_id=self.dag_id, 

543 task_id=self.task_id, 

544 logical_date=effective_logical_date, 

545 map_index=map_index, 

546 state=state, 

547 ) 

548 ) 

549 

550 if TYPE_CHECKING: 

551 assert isinstance(response, PreviousTIResult) 

552 

553 return response.task_instance 

554 

555 @staticmethod 

556 def get_ti_count( 

557 dag_id: str, 

558 map_index: int | None = None, 

559 task_ids: list[str] | None = None, 

560 task_group_id: str | None = None, 

561 logical_dates: list[datetime] | None = None, 

562 run_ids: list[str] | None = None, 

563 states: list[str] | None = None, 

564 ) -> int: 

565 """Return the number of task instances matching the given criteria.""" 

566 response = SUPERVISOR_COMMS.send( 

567 GetTICount( 

568 dag_id=dag_id, 

569 map_index=map_index, 

570 task_ids=task_ids, 

571 task_group_id=task_group_id, 

572 logical_dates=logical_dates, 

573 run_ids=run_ids, 

574 states=states, 

575 ), 

576 ) 

577 

578 if TYPE_CHECKING: 

579 assert isinstance(response, TICount) 

580 

581 return response.count 

582 

583 @staticmethod 

584 def get_task_states( 

585 dag_id: str, 

586 map_index: int | None = None, 

587 task_ids: list[str] | None = None, 

588 task_group_id: str | None = None, 

589 logical_dates: list[datetime] | None = None, 

590 run_ids: list[str] | None = None, 

591 ) -> dict[str, Any]: 

592 """Return the task states matching the given criteria.""" 

593 response = SUPERVISOR_COMMS.send( 

594 GetTaskStates( 

595 dag_id=dag_id, 

596 map_index=map_index, 

597 task_ids=task_ids, 

598 task_group_id=task_group_id, 

599 logical_dates=logical_dates, 

600 run_ids=run_ids, 

601 ), 

602 ) 

603 

604 if TYPE_CHECKING: 

605 assert isinstance(response, TaskStatesResult) 

606 

607 return response.task_states 

608 

609 @staticmethod 

610 def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]: 

611 """Return task breadcrumbs for the given dag run.""" 

612 response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id)) 

613 if TYPE_CHECKING: 

614 assert isinstance(response, TaskBreadcrumbsResult) 

615 return response.breadcrumbs 

616 

617 @staticmethod 

618 def get_dr_count( 

619 dag_id: str, 

620 logical_dates: list[datetime] | None = None, 

621 run_ids: list[str] | None = None, 

622 states: list[str] | None = None, 

623 ) -> int: 

624 """Return the number of Dag runs matching the given criteria.""" 

625 response = SUPERVISOR_COMMS.send( 

626 GetDRCount( 

627 dag_id=dag_id, 

628 logical_dates=logical_dates, 

629 run_ids=run_ids, 

630 states=states, 

631 ), 

632 ) 

633 

634 if TYPE_CHECKING: 

635 assert isinstance(response, DRCount) 

636 

637 return response.count 

638 

639 @staticmethod 

640 def get_dagrun_state(dag_id: str, run_id: str) -> str: 

641 """Return the state of the Dag run with the given Run ID.""" 

642 response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) 

643 

644 if TYPE_CHECKING: 

645 assert isinstance(response, DagRunStateResult) 

646 

647 return response.state 

648 

649 @property 

650 def log_url(self) -> str: 

651 run_id = quote(self.run_id) 

652 base_url = conf.get("api", "base_url", fallback="http://localhost:8080/") 

653 map_index_value = self.map_index 

654 map_index = ( 

655 f"/mapped/{map_index_value}" if map_index_value is not None and map_index_value >= 0 else "" 

656 ) 

657 try_number_value = self.try_number 

658 try_number = ( 

659 f"?try_number={try_number_value}" if try_number_value is not None and try_number_value > 0 else "" 

660 ) 

661 _log_uri = f"{base_url.rstrip('/')}/dags/{self.dag_id}/runs/{run_id}/tasks/{self.task_id}{map_index}{try_number}" 

662 return _log_uri 

663 

664 @property 

665 def mark_success_url(self) -> str: 

666 """URL to mark TI success.""" 

667 return self.log_url 

668 

669 
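# --- Illustrative example (not part of the original module) ---
# A minimal sketch of how a task author might use the xcom_push/xcom_pull API
# documented above from inside a task callable. The task ids and keys are
# hypothetical; `ti` is the RuntimeTaskInstance available in the task context.
def _example_xcom_usage(ti: RuntimeTaskInstance) -> None:
    # Explicit pushes use a named key; return values are pushed automatically
    # under BaseXCom.XCOM_RETURN_KEY by _push_xcom_if_needed further below.
    ti.xcom_push(key="row_count", value=42)

    # Pulling a single task id returns a single value (or `default` if nothing
    # matches); a list of task ids returns a list of values.
    count = ti.xcom_pull(task_ids="extract", key="row_count", default=0)
    counts = ti.xcom_pull(task_ids=["extract", "transform"], key="row_count")
    print(count, counts)
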

670def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None: 

671 """Push a XCom through XCom.set, which pushes to XCom Backend if configured.""" 

672 # Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK 

673 # consumers 

674 

675 XCom.set( 

676 key=key, 

677 value=value, 

678 dag_id=ti.dag_id, 

679 task_id=ti.task_id, 

680 run_id=ti.run_id, 

681 map_index=ti.map_index, 

682 _mapped_length=mapped_length, 

683 ) 

684 

685 

686def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None: 

687 """Push a XCom directly to metadata DB, bypassing custom xcom_backend.""" 

688 XCom._set_xcom_in_db( 

689 key=key, 

690 value=value, 

691 dag_id=ti.dag_id, 

692 task_id=ti.task_id, 

693 run_id=ti.run_id, 

694 map_index=ti.map_index, 

695 ) 

696 

697 

698def _maybe_reschedule_startup_failure( 

699 *, 

700 ti_context: TIRunContext, 

701 log: Logger, 

702) -> None: 

703 """ 

704 Attempt to reschedule the task when a startup failure occurs. 

705 

706 This does not count as a retry. If the reschedule limit is exceeded, this function 

707 returns and the caller should fail the task. 

708 """ 

709 missing_dag_retires = conf.getint("workers", "missing_dag_retires", fallback=3) 

710 missing_dag_retry_delay = conf.getint("workers", "missing_dag_retry_delay", fallback=60) 

711 

712 reschedule_count = int(getattr(ti_context, "task_reschedule_count", 0) or 0) 

713 if missing_dag_retires > 0 and reschedule_count < missing_dag_retires: 

714 raise AirflowRescheduleException( 

715 reschedule_date=datetime.now(tz=timezone.utc) + timedelta(seconds=missing_dag_retry_delay) 

716 ) 

717 

718 log.error( 

719 "Startup reschedule limit exceeded", 

720 reschedule_count=reschedule_count, 

721 max_reschedules=missing_dag_retires, 

722 ) 

723 

724 

725def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: 

726 # TODO: Task-SDK: 

727 # Using BundleDagBag here is about 98% wrong, but it'll do for now 

728 from airflow.dag_processing.dagbag import BundleDagBag 

729 

730 bundle_info = what.bundle_info 

731 bundle_instance = DagBundlesManager().get_bundle( 

732 name=bundle_info.name, 

733 version=bundle_info.version, 

734 ) 

735 bundle_instance.initialize() 

736 _verify_bundle_access(bundle_instance, log) 

737 

738 dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path)) 

739 bag = BundleDagBag( 

740 dag_folder=dag_absolute_path, 

741 safe_mode=False, 

742 load_op_links=False, 

743 bundle_path=bundle_instance.path, 

744 bundle_name=bundle_info.name, 

745 ) 

746 if TYPE_CHECKING: 

747 assert what.ti.dag_id 

748 

749 try: 

750 dag = bag.dags[what.ti.dag_id] 

751 except KeyError: 

752 log.error( 

753 "Dag not found during start up", dag_id=what.ti.dag_id, bundle=bundle_info, path=what.dag_rel_path 

754 ) 

755 _maybe_reschedule_startup_failure(ti_context=what.ti_context, log=log) 

756 sys.exit(1) 

757 

758 # install_loader() 

759 

760 try: 

761 task = dag.task_dict[what.ti.task_id] 

762 except KeyError: 

763 log.error( 

764 "Task not found in Dag during start up", 

765 dag_id=dag.dag_id, 

766 task_id=what.ti.task_id, 

767 bundle=bundle_info, 

768 path=what.dag_rel_path, 

769 ) 

770 _maybe_reschedule_startup_failure(ti_context=what.ti_context, log=log) 

771 sys.exit(1) 

772 

773 if not isinstance(task, (BaseOperator, MappedOperator)): 

774 raise TypeError( 

775 f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}" 

776 ) 

777 

778 return RuntimeTaskInstance.model_construct( 

779 **what.ti.model_dump(exclude_unset=True), 

780 task=task, 

781 bundle_instance=bundle_instance, 

782 _ti_context_from_server=what.ti_context, 

783 max_tries=what.ti_context.max_tries, 

784 start_date=what.start_date, 

785 state=TaskInstanceState.RUNNING, 

786 sentry_integration=what.sentry_integration, 

787 ) 

788 

789 

790# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution, 

791# to send requests back to the supervisor process. 

792# 

793# Why it needs to be a global: 

794# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests 

795# to the parent process during task execution. 

796# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the 

797# deeply nested execution stack. 

798# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily 

799# accessible wherever needed during task execution without modifying every layer of the call stack. 

800SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] 

801 
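# --- Illustrative example (not part of the original module) ---
# A minimal sketch of the request/response pattern described above, assuming
# SUPERVISOR_COMMS has already been initialised during startup. The dag_id and
# run_id values are hypothetical; GetDagRunState/DagRunStateResult are the same
# message types used by get_dagrun_state() above.
def _example_supervisor_roundtrip() -> str:
    response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id="example_dag", run_id="manual__2025-01-01"))
    if TYPE_CHECKING:
        assert isinstance(response, DagRunStateResult)
    return response.state
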

802 

803# State machine! 

804# 1. Start up (receive details from supervisor) 

805# 2. Execution (run task code, possibly send requests) 

806# 3. Shutdown and report status 

807 
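# --- Illustrative example (not part of the original module) ---
# A rough sketch of the three phases above as driven by an entrypoint such as
# main() (imported by the re-exec path later in this file). Finalization and
# error handling are intentionally omitted; this is not the real entrypoint.
def _example_lifecycle() -> None:
    ti, context, log = startup()                # 1. receive StartupDetails and parse the dag
    state, msg, error = run(ti, context, log)   # 2. execute the task, exchanging messages with the supervisor
    log.info("Task finished", state=state, error=error)  # 3. report status and shut down
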

808 

809def _verify_bundle_access(bundle_instance: BaseDagBundle, log: Logger) -> None: 

810 """ 

811 Verify bundle is accessible by the current user. 

812 

813 This is called after user impersonation (if any) to ensure the bundle 

814 is actually accessible. Uses os.access() which works with any permission 

815 scheme (standard Unix permissions, ACLs, SELinux, etc.). 

816 

817 :param bundle_instance: The bundle instance to check 

818 :param log: Logger instance 

819 :raises AirflowException: if bundle is not accessible 

820 """ 

821 from getpass import getuser 

822 

823 from airflow.sdk.exceptions import AirflowException 

824 

825 bundle_path = bundle_instance.path 

826 

827 if not bundle_path.exists(): 

828 # Already handled by initialize() with a warning 

829 return 

830 

831 # Check read permission (and execute for directories to list contents) 

832 access_mode = os.R_OK 

833 if bundle_path.is_dir(): 

834 access_mode |= os.X_OK 

835 

836 if not os.access(bundle_path, access_mode): 

837 raise AirflowException( 

838 f"Bundle '{bundle_instance.name}' path '{bundle_path}' is not accessible " 

839 f"by user '{getuser()}'. When using run_as_user, ensure bundle directories " 

840 f"are readable by the impersonated user. " 

841 f"See: https://airflow.apache.org/docs/apache-airflow/stable/administration-and-deployment/dag-bundles.html" 

842 ) 

843 

844 

845def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: 

846 # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent 

847 # in response to us sending a request. 

848 log = structlog.get_logger(logger_name="task") 

849 

850 if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and ( 

851 msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG") 

852 ): 

853 # Clear any Kerberos credential cache reference if there is one, so the new process can't reuse it. 

854 os.environ.pop("KRB5CCNAME", None) 

855 # entrypoint of re-exec process 

856 

857 msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson) 

858 reinit_supervisor_comms() 

859 

860 # We delay this message until _after_ we've got the logging re-configured, otherwise it will show up 

861 # on stdout 

862 log.debug("Using serialized startup message from environment", msg=msg) 

863 else: 

864 # normal entry point 

865 msg = SUPERVISOR_COMMS._get_response() # type: ignore[assignment] 

866 

867 if not isinstance(msg, StartupDetails): 

868 raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") 

869 

870 # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 

871 os_type = sys.platform 

872 if os_type == "darwin": 

873 log.debug("Mac OS detected, skipping setproctitle") 

874 else: 

875 from setproctitle import setproctitle 

876 

877 setproctitle(f"airflow worker -- {msg.ti.id}") 

878 

879 try: 

880 get_listener_manager().hook.on_starting(component=TaskRunnerMarker()) 

881 except Exception: 

882 log.exception("error calling listener") 

883 

884 with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): 

885 ti = parse(msg, log) 

886 log.debug("Dag file parsed", file=msg.dag_rel_path) 

887 

888 run_as_user = getattr(ti.task, "run_as_user", None) or conf.get( 

889 "core", "default_impersonation", fallback=None 

890 ) 

891 

892 if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser(): 

893 # this branch re-executes the current process as run_as_user 

894 os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1" 

895 # store startup message in environment for re-exec process 

896 os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json() 

897 os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True) 

898 

899 # Import main directly from the module instead of re-executing the file. 

900 # This ensures that when other modules import 

901 # airflow.sdk.execution_time.task_runner, they get the same module instance 

902 # with the properly initialized SUPERVISOR_COMMS global variable. 

903 # If we re-executed the module with `python -m`, it would load as __main__ and future 

904 # imports would get a fresh copy without the initialized globals. 

905 rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()" 

906 cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code] 

907 log.info( 

908 "Running command", 

909 command=cmd, 

910 ) 

911 os.execvp("sudo", cmd) 

912 

913 # ideally, we should never reach here, but if we do, we should return None, None, None 

914 return None, None, None 

915 

916 return ti, ti.get_template_context(), log 

917 

918 

919def _serialize_template_field(template_field: Any, name: str) -> str | dict | list | int | float: 

920 """ 

921 Return a serializable representation of the templated field. 

922 

923 If ``template_field`` contains a class or instance that requires recursive 

924 templating, store them as strings. Otherwise simply return the field as-is. 

925 

926 Uses the SDK secrets masker to redact secrets in the serialized output. 

927 """ 

928 import json 

929 

930 from airflow.sdk._shared.secrets_masker import redact 

931 

932 def is_jsonable(x): 

933 try: 

934 json.dumps(x) 

935 except (TypeError, OverflowError): 

936 return False 

937 else: 

938 return True 

939 

940 def translate_tuples_to_lists(obj: Any): 

941 """Recursively convert tuples to lists.""" 

942 if isinstance(obj, tuple): 

943 return [translate_tuples_to_lists(item) for item in obj] 

944 if isinstance(obj, list): 

945 return [translate_tuples_to_lists(item) for item in obj] 

946 if isinstance(obj, dict): 

947 return {key: translate_tuples_to_lists(value) for key, value in obj.items()} 

948 return obj 

949 

950 def sort_dict_recursively(obj: Any) -> Any: 

951 """Recursively sort dictionaries to ensure consistent ordering.""" 

952 if isinstance(obj, dict): 

953 return {k: sort_dict_recursively(v) for k, v in sorted(obj.items())} 

954 if isinstance(obj, list): 

955 return [sort_dict_recursively(item) for item in obj] 

956 if isinstance(obj, tuple): 

957 return tuple(sort_dict_recursively(item) for item in obj) 

958 return obj 

959 

960 def _fallback_serialization(obj): 

961 """Serialize objects with to_dict() method (eg: k8s objects) for json.dumps() default parameter.""" 

962 if hasattr(obj, "to_dict"): 

963 return obj.to_dict() 

964 raise TypeError(f"cannot serialize {obj}") 

965 

966 max_length = conf.getint("core", "max_templated_field_length") 

967 

968 if not is_jsonable(template_field): 

969 try: 

970 serialized = template_field.serialize() 

971 except AttributeError: 

972 # check if these objects can be converted to JSON serializable types 

973 try: 

974 serialized = json.dumps(template_field, default=_fallback_serialization) 

975 except (TypeError, ValueError): 

976 # fall back to string representation if not 

977 serialized = str(template_field) 

978 if len(serialized) > max_length: 

979 rendered = redact(serialized, name) 

980 return ( 

981 "Truncated. You can change this behaviour in [core]max_templated_field_length. " 

982 f"{rendered[: max_length - 79]!r}... " 

983 ) 

984 return serialized 

985 if not template_field and not isinstance(template_field, tuple): 

986 # Avoid unnecessary serialization steps for empty fields unless they are tuples 

987 # and need to be converted to lists 

988 return template_field 

989 template_field = translate_tuples_to_lists(template_field) 

990 # Sort dictionaries recursively to ensure consistent string representation 

991 # This prevents hash inconsistencies when dict ordering varies 

992 if isinstance(template_field, dict): 

993 template_field = sort_dict_recursively(template_field) 

994 serialized = str(template_field) 

995 if len(serialized) > max_length: 

996 rendered = redact(serialized, name) 

997 return ( 

998 "Truncated. You can change this behaviour in [core]max_templated_field_length. " 

999 f"{rendered[: max_length - 79]!r}... " 

1000 ) 

1001 return template_field 

1002 
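# --- Illustrative example (not part of the original module) ---
# A small sketch of what _serialize_template_field above produces for typical
# values (field names are hypothetical): tuples become lists, dicts are returned
# with recursively sorted keys, and oversized values are truncated and redacted.
def _example_serialized_template_fields() -> None:
    print(_serialize_template_field(("echo", "hello"), "bash_command"))  # -> ['echo', 'hello']
    print(_serialize_template_field({"b": 1, "a": 2}, "op_kwargs"))      # -> {'a': 2, 'b': 1}
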

1003 

1004def _serialize_rendered_fields(task: AbstractOperator) -> dict[str, JsonValue]: 

1005 from airflow.sdk._shared.secrets_masker import redact 

1006 

1007 rendered_fields = {} 

1008 for field in task.template_fields: 

1009 value = getattr(task, field) 

1010 serialized = _serialize_template_field(value, field) 

1011 # Redact secrets in the task process itself before sending to API server 

1012 # This ensures that secrets registered via mask_secret() on workers / dag processors are properly masked 

1013 # on the UI. 

1014 rendered_fields[field] = redact(serialized, field) 

1015 

1016 return rendered_fields # type: ignore[return-value] # Convince mypy that this is OK since we pass JsonValue to redact, so it will return the same 

1017 

1018 

1019def _build_asset_profiles(lineage_objects: list) -> Iterator[AssetProfile]: 

1020 # Lineage can have other types of objects besides assets, so we need to process them a bit. 

1021 for obj in lineage_objects or (): 

1022 if isinstance(obj, Asset): 

1023 yield AssetProfile(name=obj.name, uri=obj.uri, type=Asset.__name__) 

1024 elif isinstance(obj, AssetNameRef): 

1025 yield AssetProfile(name=obj.name, type=AssetNameRef.__name__) 

1026 elif isinstance(obj, AssetUriRef): 

1027 yield AssetProfile(uri=obj.uri, type=AssetUriRef.__name__) 

1028 elif isinstance(obj, AssetAlias): 

1029 yield AssetProfile(name=obj.name, type=AssetAlias.__name__) 

1030 
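# --- Illustrative example (not part of the original module) ---
# A sketch of how task outlets are condensed into AssetProfile records by
# _build_asset_profiles above. The asset name, URI and alias are hypothetical.
def _example_asset_profiles() -> list[AssetProfile]:
    outlets = [
        Asset(name="orders", uri="postgres://db/public/orders"),
        AssetAlias(name="latest_orders"),
    ]
    return list(_build_asset_profiles(outlets))
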

1031 

1032def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[dict[str, JsonValue]]: 

1033 if TYPE_CHECKING: 

1034 assert isinstance(events, OutletEventAccessors) 

1035 # We just collect everything the user recorded in the accessors. 

1036 # Further filtering will be done in the API server. 

1037 for key, accessor in events._dict.items(): 

1038 if isinstance(key, AssetUniqueKey): 

1039 yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra} 

1040 for alias_event in accessor.asset_alias_events: 

1041 yield attrs.asdict(alias_event) 

1042 
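# --- Illustrative example (not part of the original module) ---
# A sketch of how a task can attach extras to an outlet asset event; these are
# the values collected by _serialize_outlet_events above. The asset and extra
# payload are hypothetical, and item access on the accessor is assumed to follow
# the documented outlet_events usage.
def _example_record_outlet_extra(context: Context) -> None:
    outlet_events = context["outlet_events"]
    outlet_events[Asset(name="orders", uri="postgres://db/public/orders")].extra = {"rows": 42}
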

1043 

1044def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSupervisor | None: 

1045 ti.hostname = get_hostname() 

1046 ti.task = ti.task.prepare_for_execution() 

1047 # Since the context is cached and `ti.get_template_context` returns the same dict, we need to 

1048 # update the task value stored in it 

1049 context["task"] = ti.task 

1050 

1051 jinja_env = ti.task.dag.get_template_env() 

1052 ti.render_templates(context=context, jinja_env=jinja_env) 

1053 

1054 if rendered_fields := _serialize_rendered_fields(ti.task): 

1055 # so that we do not call the API unnecessarily 

1056 SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields)) 

1057 

1058 # Try to render map_index_template early with available context (will be re-rendered after execution) 

1059 # This provides a partial label during task execution for templates using pre-execution context 

1060 # If rendering fails here, we suppress the error since it will be re-rendered after execution 

1061 try: 

1062 if rendered_map_index := _render_map_index(context, ti=ti, log=log): 

1063 ti.rendered_map_index = rendered_map_index 

1064 log.debug("Sending early rendered map index", length=len(rendered_map_index)) 

1065 SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index)) 

1066 except Exception: 

1067 log.debug( 

1068 "Early rendering of map_index_template failed, will retry after task execution", exc_info=True 

1069 ) 

1070 

1071 _validate_task_inlets_and_outlets(ti=ti, log=log) 

1072 

1073 try: 

1074 # TODO: Call pre execute etc. 

1075 get_listener_manager().hook.on_task_instance_running( 

1076 previous_state=TaskInstanceState.QUEUED, task_instance=ti 

1077 ) 

1078 except Exception: 

1079 log.exception("error calling listener") 

1080 

1081 # No error, carry on and execute the task 

1082 return None 

1083 

1084 

1085def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None: 

1086 if not ti.task.inlets and not ti.task.outlets: 

1087 return 

1088 

1089 inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) 

1090 if TYPE_CHECKING: 

1091 assert isinstance(inactive_assets_resp, InactiveAssetsResult) 

1092 if inactive_assets := inactive_assets_resp.inactive_assets: 

1093 raise AirflowInactiveAssetInInletOrOutletException( 

1094 inactive_asset_keys=[ 

1095 AssetUniqueKey.from_profile(asset_profile) for asset_profile in inactive_assets 

1096 ] 

1097 ) 

1098 
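# --- Illustrative example (not part of the original module) ---
# A sketch of a task declaring inlets/outlets, which is what the validation above
# checks against the API server for inactive assets. The EmptyOperator import path
# is assumed to come from the standard provider; the assets are hypothetical.
def _example_task_with_assets():
    from airflow.providers.standard.operators.empty import EmptyOperator

    return EmptyOperator(
        task_id="publish_orders",
        inlets=[Asset(name="raw_orders", uri="s3://example-bucket/raw/orders")],
        outlets=[Asset(name="orders", uri="postgres://db/public/orders")],
    )
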

1099 

1100def _defer_task( 

1101 defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger 

1102) -> tuple[ToSupervisor, TaskInstanceState]: 

1103 # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id? 

1104 

1105 log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id) 

1106 classpath, trigger_kwargs = defer.trigger.serialize() 

1107 queue: str | None = None 

1108 # Currently, only task-associated BaseTrigger instances may have a non-None queue, 

1109 # and only when triggerer.queues_enabled conf is True. 

1110 if conf.getboolean("triggerer", "queues_enabled", fallback=False) and getattr( 

1111 defer.trigger, "supports_triggerer_queue", True 

1112 ): 

1113 queue = ti.task.queue 

1114 

1115 from airflow.sdk.serde import serialize as serde_serialize 

1116 

1117 trigger_kwargs = serde_serialize(trigger_kwargs) 

1118 next_kwargs = serde_serialize(defer.kwargs or {}) 

1119 

1120 if TYPE_CHECKING: 

1121 assert isinstance(next_kwargs, dict) 

1122 assert isinstance(trigger_kwargs, dict) 

1123 

1124 msg = DeferTask( 

1125 classpath=classpath, 

1126 trigger_kwargs=trigger_kwargs, 

1127 trigger_timeout=defer.timeout, 

1128 queue=queue, 

1129 next_method=defer.method_name, 

1130 next_kwargs=next_kwargs, 

1131 ) 

1132 state = TaskInstanceState.DEFERRED 

1133 

1134 return msg, state 

1135 
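# --- Illustrative example (not part of the original module) ---
# A sketch of what an operator raises to hand execution over to the triggerer;
# _defer_task above turns it into a DeferTask message. It mirrors the
# DagStateTrigger usage in _handle_trigger_dag_run below; ids and the poll
# interval are hypothetical.
def _example_defer() -> None:
    from airflow.providers.standard.triggers.external_task import DagStateTrigger

    raise TaskDeferred(
        trigger=DagStateTrigger(
            dag_id="example_dag",
            states=["success", "failed"],
            execution_dates=None,
            run_ids=["manual__2025-01-01"],
            poll_interval=30,
        ),
        method_name="execute_complete",
    )
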

1136 

1137@Sentry.enrich_errors 

1138def run( 

1139 ti: RuntimeTaskInstance, 

1140 context: Context, 

1141 log: Logger, 

1142) -> tuple[TaskInstanceState, ToSupervisor | None, BaseException | None]: 

1143 """Run the task in this process.""" 

1144 import signal 

1145 

1146 from airflow.sdk.exceptions import ( 

1147 AirflowFailException, 

1148 AirflowRescheduleException, 

1149 AirflowSensorTimeout, 

1150 AirflowSkipException, 

1151 AirflowTaskTerminated, 

1152 DagRunTriggerException, 

1153 DownstreamTasksSkipped, 

1154 TaskDeferred, 

1155 ) 

1156 

1157 if TYPE_CHECKING: 

1158 assert ti.task is not None 

1159 assert isinstance(ti.task, BaseOperator) 

1160 

1161 parent_pid = os.getpid() 

1162 

1163 def _on_term(signum, frame): 

1164 pid = os.getpid() 

1165 if pid != parent_pid: 

1166 return 

1167 

1168 ti.task.on_kill() 

1169 

1170 signal.signal(signal.SIGTERM, _on_term) 

1171 

1172 msg: ToSupervisor | None = None 

1173 state: TaskInstanceState 

1174 error: BaseException | None = None 

1175 

1176 try: 

1177 # First, clear the xcom data sent from server 

1178 if ti._ti_context_from_server and (keys_to_delete := ti._ti_context_from_server.xcom_keys_to_clear): 

1179 for x in keys_to_delete: 

1180 log.debug("Clearing XCom with key", key=x) 

1181 XCom.delete( 

1182 key=x, 

1183 dag_id=ti.dag_id, 

1184 task_id=ti.task_id, 

1185 run_id=ti.run_id, 

1186 map_index=ti.map_index, 

1187 ) 

1188 

1189 with set_current_context(context): 

1190 # This is the earliest that we can render templates -- if it raises for any reason we need to 

1191 # catch it and handle it like a normal task failure 

1192 if early_exit := _prepare(ti, log, context): 

1193 msg = early_exit 

1194 ti.state = state = TaskInstanceState.FAILED 

1195 return state, msg, error 

1196 

1197 try: 

1198 result = _execute_task(context=context, ti=ti, log=log) 

1199 except Exception: 

1200 import jinja2 

1201 

1202 # If the task failed, swallow rendering error so it doesn't mask the main error. 

1203 with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): 

1204 previous_rendered_map_index = ti.rendered_map_index 

1205 ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) 

1206 # Send update only if value changed (e.g., user set context variables during execution) 

1207 if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: 

1208 SUPERVISOR_COMMS.send( 

1209 msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index) 

1210 ) 

1211 raise 

1212 else: # If the task succeeded, render normally to let rendering error bubble up. 

1213 previous_rendered_map_index = ti.rendered_map_index 

1214 ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) 

1215 # Send update only if value changed (e.g., user set context variables during execution) 

1216 if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: 

1217 SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)) 

1218 

1219 _push_xcom_if_needed(result, ti, log) 

1220 

1221 msg, state = _handle_current_task_success(context, ti) 

1222 except DownstreamTasksSkipped as skip: 

1223 log.info("Skipping downstream tasks.") 

1224 tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks] 

1225 SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) 

1226 msg, state = _handle_current_task_success(context, ti) 

1227 except DagRunTriggerException as drte: 

1228 msg, state = _handle_trigger_dag_run(drte, context, ti, log) 

1229 except TaskDeferred as defer: 

1230 msg, state = _defer_task(defer, ti, log) 

1231 except AirflowSkipException as e: 

1232 if e.args: 

1233 log.info("Skipping task.", reason=e.args[0]) 

1234 msg = TaskState( 

1235 state=TaskInstanceState.SKIPPED, 

1236 end_date=datetime.now(tz=timezone.utc), 

1237 rendered_map_index=ti.rendered_map_index, 

1238 ) 

1239 state = TaskInstanceState.SKIPPED 

1240 except AirflowRescheduleException as reschedule: 

1241 log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") 

1242 msg = RescheduleTask( 

1243 reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc) 

1244 ) 

1245 state = TaskInstanceState.UP_FOR_RESCHEDULE 

1246 except (AirflowFailException, AirflowSensorTimeout) as e: 

1247 # If AirflowFailException is raised, task should not retry. 

1248 # If a sensor in reschedule mode reaches timeout, task should not retry. 

1249 log.exception("Task failed with exception") 

1250 ti.end_date = datetime.now(tz=timezone.utc) 

1251 msg = TaskState( 

1252 state=TaskInstanceState.FAILED, 

1253 end_date=ti.end_date, 

1254 rendered_map_index=ti.rendered_map_index, 

1255 ) 

1256 state = TaskInstanceState.FAILED 

1257 error = e 

1258 except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e: 

1259 # We should allow retries if the task has defined it. 

1260 log.exception("Task failed with exception") 

1261 msg, state = _handle_current_task_failed(ti) 

1262 error = e 

1263 except AirflowTaskTerminated as e: 

1264 # External state updates are already handled via `ti_heartbeat` and will already 

1265 # have been applied by another UI API call. So, these exceptions should ideally never be thrown. 

1266 # If these are thrown, we should mark the TI state as failed. 

1267 log.exception("Task failed with exception") 

1268 ti.end_date = datetime.now(tz=timezone.utc) 

1269 msg = TaskState( 

1270 state=TaskInstanceState.FAILED, 

1271 end_date=ti.end_date, 

1272 rendered_map_index=ti.rendered_map_index, 

1273 ) 

1274 state = TaskInstanceState.FAILED 

1275 error = e 

1276 except SystemExit as e: 

1277 # SystemExit needs to be retried if they are eligible. 

1278 log.error("Task exited", exit_code=e.code) 

1279 msg, state = _handle_current_task_failed(ti) 

1280 error = e 

1281 except BaseException as e: 

1282 log.exception("Task failed with exception") 

1283 msg, state = _handle_current_task_failed(ti) 

1284 error = e 

1285 finally: 

1286 if msg: 

1287 SUPERVISOR_COMMS.send(msg=msg) 

1288 

1289 # Return the message to make unit tests easier too 

1290 ti.state = state 

1291 return state, msg, error 

1292 

1293 

1294def _handle_current_task_success( 

1295 context: Context, 

1296 ti: RuntimeTaskInstance, 

1297) -> tuple[SucceedTask, TaskInstanceState]: 

1298 end_date = datetime.now(tz=timezone.utc) 

1299 ti.end_date = end_date 

1300 

1301 # Record operator and task instance success metrics 

1302 operator = ti.task.__class__.__name__ 

1303 stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id} 

1304 

1305 Stats.incr(f"operator_successes_{operator}", tags=stats_tags) 

1306 # Same metric with tagging 

1307 Stats.incr("operator_successes", tags={**stats_tags, "operator": operator}) 

1308 Stats.incr("ti_successes", tags=stats_tags) 

1309 

1310 task_outlets = list(_build_asset_profiles(ti.task.outlets)) 

1311 outlet_events = list(_serialize_outlet_events(context["outlet_events"])) 

1312 msg = SucceedTask( 

1313 end_date=end_date, 

1314 task_outlets=task_outlets, 

1315 outlet_events=outlet_events, 

1316 rendered_map_index=ti.rendered_map_index, 

1317 ) 

1318 return msg, TaskInstanceState.SUCCESS 

1319 

1320 

1321def _handle_current_task_failed( 

1322 ti: RuntimeTaskInstance, 

1323) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]: 

1324 end_date = datetime.now(tz=timezone.utc) 

1325 ti.end_date = end_date 

1326 

1327 # Record operator and task instance failed metrics 

1328 operator = ti.task.__class__.__name__ 

1329 stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id} 

1330 

1331 Stats.incr(f"operator_failures_{operator}", tags=stats_tags) 

1332 # Same metric with tagging 

1333 Stats.incr("operator_failures", tags={**stats_tags, "operator": operator}) 

1334 Stats.incr("ti_failures", tags=stats_tags) 

1335 

1336 if ti._ti_context_from_server and ti._ti_context_from_server.should_retry: 

1337 return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY 

1338 return ( 

1339 TaskState( 

1340 state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index 

1341 ), 

1342 TaskInstanceState.FAILED, 

1343 ) 

1344 

1345 

1346def _handle_trigger_dag_run( 

1347 drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger 

1348) -> tuple[ToSupervisor, TaskInstanceState]: 

1349 """Handle exception from TriggerDagRunOperator.""" 

1350 log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) 

1351 comms_msg = SUPERVISOR_COMMS.send( 

1352 TriggerDagRun( 

1353 dag_id=drte.trigger_dag_id, 

1354 run_id=drte.dag_run_id, 

1355 logical_date=drte.logical_date, 

1356 conf=drte.conf, 

1357 reset_dag_run=drte.reset_dag_run, 

1358 ), 

1359 ) 

1360 

1361 if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS: 

1362 if drte.skip_when_already_exists: 

1363 log.info( 

1364 "Dag Run already exists, skipping task as skip_when_already_exists is set to True.", 

1365 dag_id=drte.trigger_dag_id, 

1366 ) 

1367 msg = TaskState( 

1368 state=TaskInstanceState.SKIPPED, 

1369 end_date=datetime.now(tz=timezone.utc), 

1370 rendered_map_index=ti.rendered_map_index, 

1371 ) 

1372 state = TaskInstanceState.SKIPPED 

1373 else: 

1374 log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id) 

1375 msg = TaskState( 

1376 state=TaskInstanceState.FAILED, 

1377 end_date=datetime.now(tz=timezone.utc), 

1378 rendered_map_index=ti.rendered_map_index, 

1379 ) 

1380 state = TaskInstanceState.FAILED 

1381 

1382 return msg, state 

1383 

1384 log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id) 

1385 

1386 # Store the run id from the dag run (either created or found above) to 

1387 # be used when creating the extra link on the webserver. 

1388 ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id) 

1389 

1390 if drte.wait_for_completion: 

1391 if drte.deferrable: 

1392 from airflow.providers.standard.triggers.external_task import DagStateTrigger 

1393 

1394 defer = TaskDeferred( 

1395 trigger=DagStateTrigger( 

1396 dag_id=drte.trigger_dag_id, 

1397 states=drte.allowed_states + drte.failed_states, # type: ignore[arg-type] 

1398 # Don't filter by execution_dates when run_ids is provided. 

1399 # run_id uniquely identifies a DAG run, and when reset_dag_run=True, 

1400 # drte.logical_date might be a newly calculated value that doesn't match 

1401 # the persisted logical_date in the database, causing the trigger to never find the run. 

1402 execution_dates=None, 

1403 run_ids=[drte.dag_run_id], 

1404 poll_interval=drte.poke_interval, 

1405 ), 

1406 method_name="execute_complete", 

1407 ) 

1408 return _defer_task(defer, ti, log) 

1409 while True: 

1410 log.info( 

1411 "Waiting for dag run to complete execution in allowed state.", 

1412 dag_id=drte.trigger_dag_id, 

1413 run_id=drte.dag_run_id, 

1414 allowed_state=drte.allowed_states, 

1415 ) 

1416 time.sleep(drte.poke_interval) 

1417 

1418 comms_msg = SUPERVISOR_COMMS.send( 

1419 GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) 

1420 ) 

1421 if TYPE_CHECKING: 

1422 assert isinstance(comms_msg, DagRunStateResult) 

1423 if comms_msg.state in drte.failed_states: 

1424 log.error( 

1425 "DagRun finished with failed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state 

1426 ) 

1427 msg = TaskState( 

1428 state=TaskInstanceState.FAILED, 

1429 end_date=datetime.now(tz=timezone.utc), 

1430 rendered_map_index=ti.rendered_map_index, 

1431 ) 

1432 state = TaskInstanceState.FAILED 

1433 return msg, state 

1434 if comms_msg.state in drte.allowed_states: 

1435 log.info( 

1436 "DagRun finished with allowed state.", dag_id=drte.trigger_dag_id, state=comms_msg.state 

1437 ) 

1438 break 

1439 log.debug( 

1440 "DagRun not yet in allowed or failed state.", 

1441 dag_id=drte.trigger_dag_id, 

1442 state=comms_msg.state, 

1443 ) 

1444 else: 

1445 # Fire-and-forget mode: wait_for_completion=False 

1446 if drte.deferrable: 

1447 log.info( 

1448 "Ignoring deferrable=True because wait_for_completion=False. " 

1449 "Task will complete immediately without waiting for the triggered DAG run.", 

1450 trigger_dag_id=drte.trigger_dag_id, 

1451 ) 

1452 

1453 return _handle_current_task_success(context, ti) 

1454 
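# --- Illustrative example (not part of the original module) ---
# A sketch of the operator whose DagRunTriggerException the handler above services,
# including the wait_for_completion/deferrable combination it special-cases. The
# import path is assumed to be the standard provider; values are hypothetical.
def _example_trigger_dag_run():
    from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator

    return TriggerDagRunOperator(
        task_id="trigger_downstream",
        trigger_dag_id="downstream_dag",
        wait_for_completion=True,
        deferrable=True,
        poke_interval=30,
        skip_when_already_exists=True,
    )
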

1455 

1456def _run_task_state_change_callbacks( 

1457 task: BaseOperator, 

1458 kind: Literal[ 

1459 "on_execute_callback", 

1460 "on_failure_callback", 

1461 "on_success_callback", 

1462 "on_retry_callback", 

1463 "on_skipped_callback", 

1464 ], 

1465 context: Context, 

1466 log: Logger, 

1467) -> None: 

1468 callback: Callable[[Context], None] 

1469 for i, callback in enumerate(getattr(task, kind)): 

1470 try: 

1471 create_executable_runner(callback, context_get_outlet_events(context), logger=log).run(context) 

1472 except Exception: 

1473 log.exception("Failed to run task callback", kind=kind, index=i, callback=callback) 

1474 

1475 
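# --- Editor's note (illustrative sketch, not part of task_runner.py) ---
# _run_task_state_change_callbacks runs each configured callback in isolation: a failure
# in one callback is logged and swallowed so the remaining callbacks (and the task's own
# state handling) still run. The same best-effort pattern in miniature, using plain
# logging instead of structlog:

import logging


def run_callbacks_best_effort(callbacks, context, log=logging.getLogger(__name__)):
    """Invoke every callback with ``context``; log and continue on any exception."""
    for i, callback in enumerate(callbacks):
        try:
            callback(context)
        except Exception:
            log.exception("Callback %d (%r) failed; continuing with the rest", i, callback)
# --- end editor's note ---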

1476def _send_error_email_notification( 

1477 task: BaseOperator | MappedOperator, 

1478 ti: RuntimeTaskInstance, 

1479 context: Context, 

1480 error: BaseException | str | None, 

1481 log: Logger, 

1482) -> None: 

1483 """Send email notification for task errors using SmtpNotifier.""" 

1484 try: 

1485 from airflow.providers.smtp.notifications.smtp import SmtpNotifier 

1486 except ImportError: 

1487 log.error( 

1488 "Failed to send task failure or retry email notification: " 

1489 "`apache-airflow-providers-smtp` is not installed. " 

1490 "Install this provider to enable email notifications." 

1491 ) 

1492 return 

1493 

1494 if not task.email: 

1495 return 

1496 

1497 subject_template_file = conf.get("email", "subject_template", fallback=None) 

1498 

1499 # Read the template file if configured 

1500 if subject_template_file and Path(subject_template_file).exists(): 

1501 subject = Path(subject_template_file).read_text() 

1502 else: 

1503 # Fallback to default 

1504 subject = "Airflow alert: {{ti}}" 

1505 

1506 html_content_template_file = conf.get("email", "html_content_template", fallback=None) 

1507 

1508 # Read the template file if configured 

1509 if html_content_template_file and Path(html_content_template_file).exists(): 

1510 html_content = Path(html_content_template_file).read_text() 

1511 else: 

1512 # Fallback to default 

1513 # For reporting purposes, try numbers are 1-indexed rather than 0-indexed 

1514 # (i.e. Try 1 instead of Try 0 for the first attempt). 

1515 html_content = ( 

1516 "Try {{try_number}} out of {{max_tries + 1}}<br>" 

1517 "Exception:<br>{{exception_html}}<br>" 

1518 'Log: <a href="{{ti.log_url}}">Link</a><br>' 

1519 "Host: {{ti.hostname}}<br>" 

1520 'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>' 

1521 ) 

1522 

1523 # Add exception_html to context for template rendering 

1524 import html 

1525 

1526 exception_html = html.escape(str(error)).replace("\n", "<br>") 

1527 additional_context = { 

1528 "exception": error, 

1529 "exception_html": exception_html, 

1530 "try_number": ti.try_number, 

1531 "max_tries": ti.max_tries, 

1532 } 

1533 email_context = {**context, **additional_context} 

1534 to_emails = task.email 

1535 if not to_emails: 

1536 return 

1537 

1538 try: 

1539 notifier = SmtpNotifier( 

1540 to=to_emails, 

1541 subject=subject, 

1542 html_content=html_content, 

1543 from_email=conf.get("email", "from_email", fallback="airflow@airflow"), 

1544 ) 

1545 notifier(email_context) 

1546 except Exception: 

1547 log.exception("Failed to send email notification") 

1548 

1549 
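# --- Editor's note (illustrative sketch, not part of task_runner.py) ---
# The helper above builds the e-mail content in two steps: escape the exception into
# HTML-safe text, then merge it (plus try counters) into the template context before
# rendering the Jinja templates. A small sketch of that composition, assuming jinja2 is
# importable (it is a core Airflow dependency) and using made-up sample values:

import html

from jinja2 import Template

error = ValueError("boom\nsecond line")
context = {"ti": "example_dag.example_task run_id=manual__2024-01-01"}

additional_context = {
    "exception": error,
    # Escape and convert newlines so the traceback renders inside HTML content.
    "exception_html": html.escape(str(error)).replace("\n", "<br>"),
    "try_number": 1,
    "max_tries": 2,
}
email_context = {**context, **additional_context}

subject = Template("Airflow alert: {{ti}}").render(email_context)
body = Template("Try {{try_number}} out of {{max_tries + 1}}<br>{{exception_html}}").render(email_context)
# --- end editor's note ---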

1550def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): 

1551 """Execute Task (optionally with a Timeout) and push Xcom results.""" 

1552 task = ti.task 

1553 execute = task.execute 

1554 

1555 if ti._ti_context_from_server and (next_method := ti._ti_context_from_server.next_method): 

1556 from airflow.sdk.serde import deserialize 

1557 

1558 next_kwargs_data = ti._ti_context_from_server.next_kwargs or {} 

1559 try: 

1560 if TYPE_CHECKING: 

1561 assert isinstance(next_kwargs_data, dict) 

1562 kwargs = deserialize(next_kwargs_data) 

1563 except (ImportError, KeyError, AttributeError, TypeError): 

1564 from airflow.serialization.serialized_objects import BaseSerialization 

1565 

1566 kwargs = BaseSerialization.deserialize(next_kwargs_data) 

1567 

1568 if TYPE_CHECKING: 

1569 assert isinstance(kwargs, dict) 

1570 execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs) 

1571 

1572 ctx = contextvars.copy_context() 

1573 # Populate the context var so ExecutorSafeguard doesn't complain 

1574 ctx.run(ExecutorSafeguard.tracker.set, task) 

1575 

1576 # Export the context to os.environ to make it available for operators to use. 

1577 airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) 

1578 os.environ.update(airflow_context_vars) 

1579 

1580 outlet_events = context_get_outlet_events(context) 

1581 

1582 if (pre_execute_hook := task._pre_execute_hook) is not None: 

1583 create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context) 

1584 if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute: 

1585 create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context) 

1586 

1587 _run_task_state_change_callbacks(task, "on_execute_callback", context, log) 

1588 

1589 if task.execution_timeout: 

1590 from airflow.sdk.execution_time.timeout import timeout 

1591 

1592 # TODO: handle timeout in case of deferral 

1593 timeout_seconds = task.execution_timeout.total_seconds() 

1594 try: 

1595 # It's possible we're already timed out, so fast-fail if true 

1596 if timeout_seconds <= 0: 

1597 raise AirflowTaskTimeout() 

1598 # Run task in timeout wrapper 

1599 with timeout(timeout_seconds): 

1600 result = ctx.run(execute, context=context) 

1601 except AirflowTaskTimeout: 

1602 task.on_kill() 

1603 raise 

1604 else: 

1605 result = ctx.run(execute, context=context) 

1606 

1607 if (post_execute_hook := task._post_execute_hook) is not None: 

1608 create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) 

1609 if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute: 

1610 create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context) 

1611 

1612 return result 

1613 

1614 
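# --- Editor's note (illustrative sketch, not part of task_runner.py) ---
# _execute_task wraps the call in the SDK's ``timeout`` context manager and calls
# ``task.on_kill()`` if AirflowTaskTimeout fires. The SDK implementation is not shown
# here; the sketch below is a generic, Unix-only signal.SIGALRM based equivalent of the
# same "run under a deadline, clean up on timeout" pattern:

import signal
from contextlib import contextmanager


@contextmanager
def alarm_timeout(seconds: float):
    """Raise TimeoutError if the wrapped block runs longer than ``seconds`` (Unix only)."""

    def _handler(signum, frame):
        raise TimeoutError(f"Timed out after {seconds}s")

    previous = signal.signal(signal.SIGALRM, _handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)  # cancel the pending alarm
        signal.signal(signal.SIGALRM, previous)


def run_with_deadline(fn, seconds, on_kill=lambda: None):
    try:
        with alarm_timeout(seconds):
            return fn()
    except TimeoutError:
        on_kill()  # mirror of task.on_kill() in the code above
        raise
# --- end editor's note ---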

1615def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) -> str | None: 

1616 """Render named map index if the Dag author defined map_index_template at the task level.""" 

1617 if (template := context.get("map_index_template")) is None: 

1618 return None 

1619 log.debug("Rendering map_index_template", template_length=len(template)) 

1620 jinja_env = ti.task.dag.get_template_env() 

1621 rendered_map_index = jinja_env.from_string(template).render(context) 

1622 log.debug("Map index rendered", length=len(rendered_map_index)) 

1623 return rendered_map_index 

1624 

1625 

1626def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger): 

1627 """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result.""" 

1628 if ti.task.do_xcom_push: 

1629 xcom_value = result 

1630 else: 

1631 xcom_value = None 

1632 

1633 has_mapped_dep = next(ti.task.iter_mapped_dependants(), None) is not None 

1634 if xcom_value is None: 

1635 if not ti.is_mapped and has_mapped_dep: 

1636 # Uh-oh: a downstream mapped task depends on us to push something to map over. 

1637 from airflow.sdk.exceptions import XComForMappingNotPushed 

1638 

1639 raise XComForMappingNotPushed() 

1640 return 

1641 

1642 mapped_length: int | None = None 

1643 if not ti.is_mapped and has_mapped_dep: 

1644 from airflow.sdk.definitions.mappedoperator import is_mappable_value 

1645 from airflow.sdk.exceptions import UnmappableXComTypePushed 

1646 

1647 if not is_mappable_value(xcom_value): 

1648 raise UnmappableXComTypePushed(xcom_value) 

1649 mapped_length = len(xcom_value) 

1650 

1651 log.info("Pushing xcom", ti=ti) 

1652 

1653 # If the task has multiple outputs, push each output as a separate XCom. 

1654 if ti.task.multiple_outputs: 

1655 if not isinstance(xcom_value, Mapping): 

1656 raise TypeError( 

1657 f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs" 

1658 ) 

1659 for key in xcom_value.keys(): 

1660 if not isinstance(key, str): 

1661 raise TypeError( 

1662 "Returned dictionary keys must be strings when using " 

1663 f"multiple_outputs, found {key} ({type(key)}) instead" 

1664 ) 

1665 for k, v in result.items(): 

1666 ti.xcom_push(k, v) 

1667 

1668 _xcom_push(ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=mapped_length) 

1669 

1670 
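# --- Editor's note (illustrative sketch, not part of task_runner.py) ---
# For ``multiple_outputs`` the return value must be a mapping with string keys; each
# entry is pushed under its own key and the whole value is still pushed under the
# return-value key. A stripped-down sketch with a hypothetical ``push(key, value)``
# callable in place of ti.xcom_push / _xcom_push:

from collections.abc import Callable, Mapping
from typing import Any


def push_multiple_outputs(result: Any, push: Callable[[str, Any], None]) -> None:
    if not isinstance(result, Mapping):
        raise TypeError(f"Returned output was type {type(result)}, expected a dictionary")
    for key in result:
        if not isinstance(key, str):
            raise TypeError(f"Returned dictionary keys must be strings, found {key!r} ({type(key)})")
    for key, value in result.items():
        push(key, value)          # one XCom per entry
    push("return_value", result)  # plus the full mapping under the default key
# --- end editor's note ---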

1671def finalize( 

1672 ti: RuntimeTaskInstance, 

1673 state: TaskInstanceState, 

1674 context: Context, 

1675 log: Logger, 

1676 error: BaseException | None = None, 

1677): 

1678 # Record task duration metrics for all terminal states 

1679 if ti.start_date and ti.end_date: 

1680 duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000 

1681 stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id} 

1682 

1683 Stats.timing(f"dag.{ti.dag_id}.{ti.task_id}.duration", duration_ms) 

1684 Stats.timing("task.duration", duration_ms, tags=stats_tags) 

1685 

1686 task = ti.task 

1687 # Push an XCom for each operator extra link defined on the operator. 

1688 for oe in task.operator_extra_links: 

1689 try: 

1690 link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key # type: ignore[arg-type] 

1691 log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key) 

1692 _xcom_push_to_db(ti, key=xcom_key, value=link) 

1693 except Exception: 

1694 log.exception( 

1695 "Failed to push an xcom for task operator extra link", 

1696 link_name=oe.name, 

1697 xcom_key=oe.xcom_key, 

1698 ti=ti, 

1699 ) 

1700 

1701 if getattr(ti.task, "overwrite_rtif_after_execution", False): 

1702 log.debug("Overwriting Rendered template fields.") 

1703 if ti.task.template_fields: 

1704 SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) 

1705 

1706 log.debug("Running finalizers", ti=ti) 

1707 if state == TaskInstanceState.SUCCESS: 

1708 _run_task_state_change_callbacks(task, "on_success_callback", context, log) 

1709 try: 

1710 get_listener_manager().hook.on_task_instance_success( 

1711 previous_state=TaskInstanceState.RUNNING, task_instance=ti 

1712 ) 

1713 except Exception: 

1714 log.exception("error calling listener") 

1715 elif state == TaskInstanceState.SKIPPED: 

1716 _run_task_state_change_callbacks(task, "on_skipped_callback", context, log) 

1717 try: 

1718 get_listener_manager().hook.on_task_instance_skipped( 

1719 previous_state=TaskInstanceState.RUNNING, task_instance=ti 

1720 ) 

1721 except Exception: 

1722 log.exception("error calling listener") 

1723 elif state == TaskInstanceState.UP_FOR_RETRY: 

1724 _run_task_state_change_callbacks(task, "on_retry_callback", context, log) 

1725 try: 

1726 get_listener_manager().hook.on_task_instance_failed( 

1727 previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error 

1728 ) 

1729 except Exception: 

1730 log.exception("error calling listener") 

1731 if error and task.email_on_retry and task.email: 

1732 _send_error_email_notification(task, ti, context, error, log) 

1733 elif state == TaskInstanceState.FAILED: 

1734 _run_task_state_change_callbacks(task, "on_failure_callback", context, log) 

1735 try: 

1736 get_listener_manager().hook.on_task_instance_failed( 

1737 previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error 

1738 ) 

1739 except Exception: 

1740 log.exception("error calling listener") 

1741 if error and task.email_on_failure and task.email: 

1742 _send_error_email_notification(task, ti, context, error, log) 

1743 

1744 try: 

1745 get_listener_manager().hook.before_stopping(component=TaskRunnerMarker()) 

1746 except Exception: 

1747 log.exception("error calling listener") 

1748 

1749 
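# --- Editor's note (illustrative sketch, not part of task_runner.py) ---
# finalize() branches on the terminal state to pick which callback list to run (and, for
# retry/failure, whether to send the e-mail notification). Conceptually it is a
# state -> callback-kind dispatch; a compact sketch of that mapping, using the
# TaskInstanceState values from this module written as plain strings:

TERMINAL_STATE_CALLBACKS = {
    "success": "on_success_callback",
    "skipped": "on_skipped_callback",
    "up_for_retry": "on_retry_callback",   # may also trigger email_on_retry
    "failed": "on_failure_callback",       # may also trigger email_on_failure
}


def callback_kind_for(state: str) -> str | None:
    """Return the callback attribute to run for a terminal state, if any."""
    return TERMINAL_STATE_CALLBACKS.get(state)
# --- end editor's note ---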

1750def main(): 

1751 log = structlog.get_logger(logger_name="task") 

1752 

1753 global SUPERVISOR_COMMS 

1754 SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) 

1755 

1756 Stats.initialize( 

1757 is_statsd_datadog_enabled=conf.getboolean("metrics", "statsd_datadog_enabled"), 

1758 is_statsd_on=conf.getboolean("metrics", "statsd_on"), 

1759 is_otel_on=conf.getboolean("metrics", "otel_on"), 

1760 ) 

1761 

1762 try: 

1763 try: 

1764 ti, context, log = startup() 

1765 except AirflowRescheduleException as reschedule: 

1766 log.warning("Rescheduling task during startup, marking task as UP_FOR_RESCHEDULE") 

1767 SUPERVISOR_COMMS.send( 

1768 msg=RescheduleTask( 

1769 reschedule_date=reschedule.reschedule_date, 

1770 end_date=datetime.now(tz=timezone.utc), 

1771 ) 

1772 ) 

1773 sys.exit(0) 

1774 with BundleVersionLock( 

1775 bundle_name=ti.bundle_instance.name, 

1776 bundle_version=ti.bundle_instance.version, 

1777 ): 

1778 state, _, error = run(ti, context, log) 

1779 context["exception"] = error 

1780 finalize(ti, state, context, log, error) 

1781 except KeyboardInterrupt: 

1782 log.exception("Ctrl-c hit") 

1783 sys.exit(2) 

1784 except Exception: 

1785 log.exception("Top level error") 

1786 sys.exit(1) 

1787 finally: 

1788 # Ensure the request socket is closed on the child side in all circumstances 

1789 # before the process fully terminates. 

1790 if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket: 

1791 with suppress(Exception): 

1792 SUPERVISOR_COMMS.socket.close() 

1793 

1794 

1795def reinit_supervisor_comms() -> None: 

1796 """ 

1797 Re-initialize supervisor comms and logging channel in subprocess. 

1798 

1799 This is not needed in most cases, but it is used when we re-launch the process via sudo for 

1800 run_as_user, or from inside the Python code of a virtualenv (and similar) operator, to re-connect so 

1801 those tasks can continue to access variables etc. 

1802 """ 

1803 import socket 

1804 

1805 if "SUPERVISOR_COMMS" not in globals(): 

1806 global SUPERVISOR_COMMS 

1807 log = structlog.get_logger(logger_name="task") 

1808 

1809 fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0")) 

1810 

1811 SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd)) 

1812 

1813 logs = SUPERVISOR_COMMS.send(ResendLoggingFD()) 

1814 if isinstance(logs, SentFDs): 

1815 from airflow.sdk.log import configure_logging 

1816 

1817 log_io = os.fdopen(logs.fds[0], "wb", buffering=0) 

1818 configure_logging(json_output=True, output=log_io, sending_to_supervisor=True) 

1819 else: 

1820 print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr) 

1821 

1822 
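# --- Editor's note (illustrative sketch, not part of task_runner.py) ---
# reinit_supervisor_comms rebuilds its communication channel from file descriptors the
# parent process left open: the socket FD number arrives via an environment variable and
# is wrapped with socket.socket(fileno=...), and a log FD is wrapped with os.fdopen(...).
# A minimal standalone demonstration of wrapping inherited FDs, using a socketpair and a
# pipe created locally (and a made-up env var name) so the snippet runs on its own:

import os
import socket

parent_sock, child_sock = socket.socketpair()

# Pretend the child inherited this FD and learned its number from the environment.
os.environ["EXAMPLE_INHERITED_FD"] = str(child_sock.fileno())

fd = int(os.environ["EXAMPLE_INHERITED_FD"])
rebuilt = socket.socket(fileno=fd)  # same underlying socket, new Python object

parent_sock.sendall(b"hello from supervisor")
print(rebuilt.recv(1024))

# A write-side FD can similarly be wrapped as an unbuffered binary stream for logging:
read_fd, write_fd = os.pipe()
log_io = os.fdopen(write_fd, "wb", buffering=0)
log_io.write(b"log line\n")
print(os.read(read_fd, 1024))
# --- end editor's note ---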

1823if __name__ == "__main__": 

1824 main()