Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/execution_time/context.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

506 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import collections 

20import contextlib 

21import functools 

22import inspect 

23from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence 

24from datetime import datetime 

25from functools import cache 

26from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload 

27 

28import attrs 

29import structlog 

30 

31from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT 

32from airflow.sdk.definitions._internal.types import NOTSET 

33from airflow.sdk.definitions.asset import ( 

34 Asset, 

35 AssetAlias, 

36 AssetAliasEvent, 

37 AssetAliasUniqueKey, 

38 AssetNameRef, 

39 AssetRef, 

40 AssetUniqueKey, 

41 AssetUriRef, 

42 BaseAssetUniqueKey, 

43) 

44from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError, ErrorType 

45from airflow.sdk.log import mask_secret 

46 

47if TYPE_CHECKING: 

48 from uuid import UUID 

49 

50 from pydantic.types import JsonValue 

51 from typing_extensions import Self 

52 

53 from airflow.sdk import Variable 

54 from airflow.sdk.bases.operator import BaseOperator 

55 from airflow.sdk.definitions.connection import Connection 

56 from airflow.sdk.definitions.context import Context 

57 from airflow.sdk.execution_time.comms import ( 

58 AssetEventDagRunReferenceResult, 

59 AssetEventResult, 

60 AssetEventsResult, 

61 AssetResult, 

62 ConnectionResult, 

63 OKResponse, 

64 PrevSuccessfulDagRunResponse, 

65 ReceiveMsgType, 

66 VariableResult, 

67 ) 

68 from airflow.sdk.types import OutletEventAccessorsProtocol 

69 

70 

# Prefixes for exported task-instance context values: dotted keys for the
# default (template) form, upper-snake for OS environment variables.
DEFAULT_FORMAT_PREFIX = "airflow.ctx."
ENV_VAR_FORMAT_PREFIX = "AIRFLOW_CTX_"

# For each exported context variable, both naming conventions.
# Consumed by context_to_airflow_vars() below.
AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
    "AIRFLOW_CONTEXT_DAG_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_ID",
    },
    "AIRFLOW_CONTEXT_TASK_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}task_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TASK_ID",
    },
    "AIRFLOW_CONTEXT_LOGICAL_DATE": {
        "default": f"{DEFAULT_FORMAT_PREFIX}logical_date",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}LOGICAL_DATE",
    },
    "AIRFLOW_CONTEXT_TRY_NUMBER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}try_number",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TRY_NUMBER",
    },
    "AIRFLOW_CONTEXT_DAG_RUN_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_run_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_RUN_ID",
    },
    "AIRFLOW_CONTEXT_DAG_OWNER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_owner",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_OWNER",
    },
    "AIRFLOW_CONTEXT_DAG_EMAIL": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_email",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_EMAIL",
    },
}


# Module-level structlog logger for task-execution-time messages.
log = structlog.get_logger(logger_name="task")

# Generic result type used to parameterize _AssetEventAccessorsMixin below.
T = TypeVar("T")

109 

110 

def _process_connection_result_conn(conn_result: ReceiveMsgType | None) -> Connection:
    """Turn a supervisor response into a Connection, raising on error responses."""
    from airflow.sdk.definitions.connection import Connection
    from airflow.sdk.execution_time.comms import ErrorResponse

    if isinstance(conn_result, ErrorResponse):
        raise AirflowRuntimeError(conn_result)

    if TYPE_CHECKING:
        assert isinstance(conn_result, ConnectionResult)

    # `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model
    payload = conn_result.model_dump(exclude={"type"}, by_alias=True)
    return Connection(**payload)

123 

124 

def _mask_connection_secrets(conn: Connection) -> None:
    """Register the connection's sensitive fields for redaction in logs."""
    for sensitive in (conn.password, conn.extra):
        if sensitive:
            mask_secret(sensitive)

131 

132 

def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable:
    """Build a Variable from a VariableResult message, optionally JSON-decoding its value."""
    from airflow.sdk.definitions.variable import Variable

    if deserialize_json:
        import json

        # NOTE: mutates the incoming result object in place, as callers expect.
        var_result.value = json.loads(var_result.value)  # type: ignore

    payload = var_result.model_dump(exclude={"type"})
    return Variable(**payload)

141 

142 

def _get_connection(conn_id: str) -> Connection:
    """
    Resolve a Connection by id.

    Looks in the in-process SecretCache first, then walks every configured
    secrets backend in order; raises AirflowNotFoundException when no source
    knows the connection.
    """
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    # Check cache first (optional; only on dag processor)
    try:
        cached_uri = SecretCache.get_connection_uri(conn_id)
    except SecretCache.NotPresentException:
        pass  # continue to backends
    else:
        from airflow.sdk.definitions.connection import Connection

        conn = Connection.from_uri(cached_uri, conn_id=conn_id)
        _mask_connection_secrets(conn)
        return conn

    # Iterate over configured backends (which may include SupervisorCommsSecretsBackend
    # in worker contexts or MetastoreBackend in API server contexts)
    for secrets_backend in ensure_secrets_backend_loaded():
        try:
            conn = secrets_backend.get_connection(conn_id=conn_id)  # type: ignore[assignment]
            if conn:
                SecretCache.save_connection_uri(conn_id, conn.get_uri())
                _mask_connection_secrets(conn)
                return conn
        except Exception:
            log.debug(
                "Unable to retrieve connection from secrets backend (%s). "
                "Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # If no backend found the connection, raise an error
    raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")

178 

179 

async def _async_get_connection(conn_id: str) -> Connection:
    """
    Async counterpart of ``_get_connection``.

    Checks the SecretCache, then each secrets backend — preferring the
    backend's native async method when available, otherwise wrapping the
    sync one. Raises AirflowNotFoundException if nothing matches.
    """
    from asgiref.sync import sync_to_async

    from airflow.sdk.execution_time.cache import SecretCache

    # Check cache first
    try:
        cached_uri = SecretCache.get_connection_uri(conn_id)
    except SecretCache.NotPresentException:
        pass  # continue to backends
    else:
        from airflow.sdk.definitions.connection import Connection

        conn = Connection.from_uri(cached_uri, conn_id=conn_id)
        _mask_connection_secrets(conn)
        return conn

    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    # Try secrets backends
    for secrets_backend in ensure_secrets_backend_loaded():
        try:
            # Use async method if available, otherwise wrap sync method.
            # getattr avoids triggering AsyncMock coroutine creation under Python 3.13
            aget = getattr(secrets_backend, "aget_connection", None)
            if aget is None:
                conn = await sync_to_async(secrets_backend.get_connection)(conn_id)  # type: ignore[assignment]
            else:
                candidate = aget(conn_id)
                conn = await candidate if inspect.isawaitable(candidate) else candidate

            if conn:
                SecretCache.save_connection_uri(conn_id, conn.get_uri())
                _mask_connection_secrets(conn)
                return conn
        except Exception:
            # If one backend fails, try the next one
            log.debug(
                "Unable to retrieve connection from secrets backend (%s). "
                "Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # If no backend found the connection, raise an error
    raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")

226 

227 

def _get_variable(key: str, deserialize_json: bool) -> Any:
    """
    Fetch a Variable value by *key*.

    Checks the in-process SecretCache first, then each configured secrets
    backend in order. Raises ``AirflowRuntimeError(VARIABLE_NOT_FOUND)`` when
    no source has the variable.

    :param key: variable name.
    :param deserialize_json: decode the stored value with ``json.loads``.
    """
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    # Check cache first
    try:
        var_val = SecretCache.get_variable(key)
        if var_val is not None:
            return _postprocess_variable_value(key, var_val, deserialize_json)
    except SecretCache.NotPresentException:
        pass  # Continue to check backends

    backends = ensure_secrets_backend_loaded()

    # Iterate over backends if not in cache (or expired)
    for secrets_backend in backends:
        try:
            var_val = secrets_backend.get_variable(key=key)
            if var_val is not None:
                # Save raw value before deserialization to maintain cache consistency
                SecretCache.save_variable(key, var_val)
                return _postprocess_variable_value(key, var_val, deserialize_json)
        except Exception:
            log.exception(
                "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # If no backend found the variable, raise a not found error (mirrors _get_connection)
    from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
    from airflow.sdk.execution_time.comms import ErrorResponse

    raise AirflowRuntimeError(
        ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"message": f"Variable {key} not found"})
    )


def _postprocess_variable_value(key: str, var_val: Any, deserialize_json: bool) -> Any:
    """Deserialize (when requested) and mask a raw variable value before it is returned."""
    if deserialize_json:
        import json

        var_val = json.loads(var_val)
    # Only string values can be registered with the secrets masker.
    if isinstance(var_val, str):
        mask_secret(var_val, key)
    return var_val

275 

276 

def _set_variable(key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
    """
    Create or update a Variable by sending a ``PutVariable`` to the supervisor.

    Before writing, warns when a non-API secrets backend also defines *key*,
    since such a backend would shadow the freshly written value on read.

    :param key: variable name.
    :param value: value to store; JSON-encoded first when ``serialize_json``.
    :param description: optional description stored alongside the variable.
    :param serialize_json: serialize ``value`` with ``json.dumps`` before sending.
    """
    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
    # or `airflow.sdk.execution_time.variable`
    # A reason to not move it to `airflow.sdk.execution_time.comms` is that it
    # will make that module depend on Task SDK, which is not ideal because we intend to
    # keep Task SDK as a separate package than execution time mods.
    import json

    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.comms import PutVariable
    from airflow.sdk.execution_time.secrets.execution_api import ExecutionAPISecretsBackend
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    # check for write conflicts on the worker
    for secrets_backend in ensure_secrets_backend_loaded():
        # The API backend is the write target itself, so it cannot conflict.
        if isinstance(secrets_backend, ExecutionAPISecretsBackend):
            continue
        try:
            var_val = secrets_backend.get_variable(key=key)
            if var_val is not None:
                _backend_name = type(secrets_backend).__name__
                log.warning(
                    "The variable %s is defined in the %s secrets backend, which takes "
                    "precedence over reading from the API Server. The value from the API Server will be "
                    "updated, but to read it you have to delete the conflicting variable "
                    "from %s",
                    key,
                    _backend_name,
                    _backend_name,
                )
        except Exception:
            log.exception(
                "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # Best-effort serialization: on failure the error is logged and the
    # original (unserialized) value is sent as-is.
    try:
        if serialize_json:
            value = json.dumps(value, indent=2)
    except Exception as e:
        log.exception(e)

    SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description))

    # Invalidate cache after setting the variable
    SecretCache.invalidate_variable(key)

324 

325 

def _delete_variable(key: str) -> None:
    """Delete a Variable via the supervisor channel and drop it from the cache."""
    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
    # or `airflow.sdk.execution_time.variable`
    # A reason to not move it to `airflow.sdk.execution_time.comms` is that it
    # will make that module depend on Task SDK, which is not ideal because we intend to
    # keep Task SDK as a separate package than execution time mods.
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.comms import DeleteVariable
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    response = SUPERVISOR_COMMS.send(DeleteVariable(key=key))
    if TYPE_CHECKING:
        assert isinstance(response, OKResponse)

    # Invalidate cache after deleting the variable
    SecretCache.invalidate_variable(key)

342 

343 

class ConnectionAccessor:
    """Stateless dynamic accessor for Connection entries in templates."""

    def __getattr__(self, conn_id: str) -> Any:
        from airflow.sdk.definitions.connection import Connection

        return Connection.get(conn_id)

    def __repr__(self) -> str:
        return "<ConnectionAccessor (dynamic access)>"

    def __eq__(self, other):
        # Every instance is interchangeable: the accessor holds no state.
        return isinstance(other, ConnectionAccessor)

    def __hash__(self):
        return hash(self.__class__.__name__)

    def get(self, conn_id: str, default_conn: Any = None) -> Any:
        """Return the connection for *conn_id*, or *default_conn* when missing."""
        try:
            return _get_connection(conn_id)
        except AirflowRuntimeError as e:
            if e.error.error != ErrorType.CONNECTION_NOT_FOUND:
                raise
            return default_conn
        except AirflowNotFoundException:
            return default_conn

373 

374 

class VariableAccessor:
    """Stateless dynamic accessor for Variable values in templates."""

    def __init__(self, deserialize_json: bool) -> None:
        self._deserialize_json = deserialize_json

    def __eq__(self, other):
        # Every instance is interchangeable: the accessor holds no lookup state.
        return isinstance(other, VariableAccessor)

    def __hash__(self):
        return hash(self.__class__.__name__)

    def __repr__(self) -> str:
        return "<VariableAccessor (dynamic access)>"

    def __getattr__(self, key: str) -> Any:
        return _get_variable(key, self._deserialize_json)

    def get(self, key, default: Any = NOTSET) -> Any:
        """Return the variable for *key*, or *default* when it does not exist."""
        try:
            return _get_variable(key, self._deserialize_json)
        except AirflowRuntimeError as e:
            if e.error.error != ErrorType.VARIABLE_NOT_FOUND:
                raise
            return default

403 

404 

class MacrosAccessor:
    """Lazy proxy to the macros module; imports it on first attribute access."""

    _macros_module = None

    def __getattr__(self, item: str) -> Any:
        # Import the macros module only when a macro is actually requested.
        if not self._macros_module:
            import airflow.sdk.execution_time.macros

            self._macros_module = airflow.sdk.execution_time.macros
        return getattr(self._macros_module, item)

    def __repr__(self) -> str:
        return "<MacrosAccessor (dynamic access to macros)>"

    def __eq__(self, other: object) -> bool:
        # All instances are equivalent; the module reference is shared lazily.
        return isinstance(other, MacrosAccessor)

    def __hash__(self):
        return hash(self.__class__.__name__)

428 

429 

class _AssetRefResolutionMixin:
    """Shared logic to resolve an AssetRef into a concrete key and extra dict."""

    # Class-level cache shared across subclasses; keyed by both the name ref
    # and the uri ref of each resolved asset.
    _asset_ref_cache: dict[AssetRef, tuple[AssetUniqueKey, dict[str, JsonValue]]] = {}

    def _resolve_asset_ref(self, ref: AssetRef) -> tuple[AssetUniqueKey, dict[str, JsonValue]]:
        """Resolve *ref* via the supervisor (with caching) to (unique_key, extra)."""
        try:
            return self._asset_ref_cache[ref]
        except KeyError:
            pass

        cache_refs: list[AssetRef]
        if isinstance(ref, AssetNameRef):
            asset = self._get_asset_from_db(name=ref.name)
            cache_refs = [ref, AssetUriRef(asset.uri)]
        elif isinstance(ref, AssetUriRef):
            asset = self._get_asset_from_db(uri=ref.uri)
            cache_refs = [ref, AssetNameRef(asset.name)]
        else:
            raise TypeError(f"Unimplemented asset ref: {type(ref)}")

        resolved = (AssetUniqueKey.from_asset(asset), asset.extra)
        # Cache under both ref flavors so a later lookup by either form hits.
        for cache_ref in cache_refs:
            self._asset_ref_cache[cache_ref] = resolved
        return resolved

    # TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py
    @staticmethod
    def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset:
        """Ask the supervisor for the Asset identified by *name* or *uri*."""
        from airflow.sdk.definitions.asset import Asset
        from airflow.sdk.execution_time.comms import (
            ErrorResponse,
            GetAssetByName,
            GetAssetByUri,
            ToSupervisor,
        )
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        request: ToSupervisor
        if name:
            request = GetAssetByName(name=name)
        elif uri:
            request = GetAssetByUri(uri=uri)
        else:
            raise ValueError("Either name or uri must be provided")

        response = SUPERVISOR_COMMS.send(request)
        if isinstance(response, ErrorResponse):
            raise AirflowRuntimeError(response)

        if TYPE_CHECKING:
            assert isinstance(response, AssetResult)
        return Asset(**response.model_dump(exclude={"type"}))

478 

479 

480@attrs.define 

481class OutletEventAccessor(_AssetRefResolutionMixin): 

482 """Wrapper to access an outlet asset event in template.""" 

483 

484 key: BaseAssetUniqueKey 

485 extra: dict[str, JsonValue] = attrs.Factory(dict) 

486 asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) 

487 

488 def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None) -> None: 

489 """Add an AssetEvent to an existing Asset.""" 

490 if not isinstance(self.key, AssetAliasUniqueKey): 

491 return 

492 

493 if isinstance(asset, AssetRef): 

494 asset_key, asset_extra = self._resolve_asset_ref(asset) 

495 else: 

496 asset_key = AssetUniqueKey.from_asset(asset) 

497 asset_extra = asset.extra 

498 

499 asset_alias_name = self.key.name 

500 event = AssetAliasEvent( 

501 source_alias_name=asset_alias_name, 

502 dest_asset_key=asset_key, 

503 dest_asset_extra=asset_extra, 

504 extra=extra or {}, 

505 ) 

506 self.asset_alias_events.append(event) 

507 

508 

class _AssetEventAccessorsMixin(Generic[T]):
    """Convenience lookups shared by the asset-event accessor mappings."""

    @overload
    def for_asset(self, *, name: str, uri: str) -> T: ...

    @overload
    def for_asset(self, *, name: str) -> T: ...

    @overload
    def for_asset(self, *, uri: str) -> T: ...

    def for_asset(self, *, name: str | None = None, uri: str | None = None) -> T:
        """Look up events for an asset identified by name, uri, or both."""
        if not name and not uri:
            raise ValueError("name and uri cannot both be None")
        if name and uri:
            return self[Asset(name=name, uri=uri)]
        if name:
            return self[Asset.ref(name=name)]
        return self[Asset.ref(uri=uri)]

    def for_asset_alias(self, *, name: str) -> T:
        """Look up events for an asset alias by name."""
        return self[AssetAlias(name=name)]

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> T:
        raise NotImplementedError

534 

535 

class OutletEventAccessors(
    _AssetRefResolutionMixin,
    Mapping["Asset | AssetAlias", OutletEventAccessor],
    _AssetEventAccessorsMixin[OutletEventAccessor],
):
    """Lazy mapping of outlet asset event accessors."""

    def __init__(self) -> None:
        # Accessors are created on demand, keyed by the asset/alias unique key.
        self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}

    def __str__(self) -> str:
        return f"OutletEventAccessors(_dict={self._dict})"

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        for key in self._dict:
            yield key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias()

    def __len__(self) -> int:
        return len(self._dict)

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> OutletEventAccessor:
        """Return (creating if needed) the accessor for *key*."""
        if isinstance(key, Asset):
            hashable_key: BaseAssetUniqueKey = AssetUniqueKey.from_asset(key)
        elif isinstance(key, AssetAlias):
            hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
        elif isinstance(key, AssetRef):
            hashable_key, _ = self._resolve_asset_ref(key)
        else:
            raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")

        try:
            return self._dict[hashable_key]
        except KeyError:
            accessor = self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key)
            return accessor

571 

572 

@attrs.define(init=False)
class InletEventsAccessor(Sequence["AssetEventResult"]):
    """
    Lazily-fetched, queryable sequence of asset events for one inlet.

    Exactly one of asset (name/uri) or alias name identifies the inlet; the
    event list is fetched from the supervisor on first access and cached
    until a filter method changes the query.
    """

    # Query filters; mutated by the fluent methods below.
    _after: str | datetime | None
    _before: str | datetime | None
    _ascending: bool
    _limit: int | None
    # Inlet identity: either asset name/uri or an alias name.
    _asset_name: str | None
    _asset_uri: str | None
    _alias_name: str | None

    def __init__(
        self, asset_name: str | None = None, asset_uri: str | None = None, alias_name: str | None = None
    ):
        self._asset_name = asset_name
        self._asset_uri = asset_uri
        self._alias_name = alias_name
        self._after = None
        self._before = None
        self._ascending = True
        self._limit = None

    def after(self, after: str) -> Self:
        """Only include events after this time; returns self for chaining."""
        self._after = after
        self._reset_cache()
        return self

    def before(self, before: str) -> Self:
        """Only include events before this time; returns self for chaining."""
        self._before = before
        self._reset_cache()
        return self

    def ascending(self, ascending: bool = True) -> Self:
        """Set result ordering; returns self for chaining."""
        self._ascending = ascending
        self._reset_cache()
        return self

    def limit(self, limit: int) -> Self:
        """Cap the number of returned events; returns self for chaining."""
        self._limit = limit
        self._reset_cache()
        return self

    @functools.cached_property
    def _asset_events(self) -> list[AssetEventResult]:
        """Fetch (once per query state) the matching events from the supervisor."""
        from airflow.sdk.execution_time.comms import (
            ErrorResponse,
            GetAssetEventByAsset,
            GetAssetEventByAssetAlias,
            ToSupervisor,
        )
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        query_dict: dict[str, Any] = {
            "after": self._after,
            "before": self._before,
            "ascending": self._ascending,
            "limit": self._limit,
        }

        msg: ToSupervisor
        if self._alias_name is not None:
            msg = GetAssetEventByAssetAlias(alias_name=self._alias_name, **query_dict)
        else:
            if self._asset_name is None and self._asset_uri is None:
                raise ValueError("Either asset_name or asset_uri must be provided")
            msg = GetAssetEventByAsset(name=self._asset_name, uri=self._asset_uri, **query_dict)
        resp = SUPERVISOR_COMMS.send(msg)
        if isinstance(resp, ErrorResponse):
            raise AirflowRuntimeError(resp)

        if TYPE_CHECKING:
            assert isinstance(resp, AssetEventsResult)

        return list(resp.iter_asset_event_results())

    def _reset_cache(self) -> None:
        # Drop the cached_property value so the next access re-queries.
        try:
            del self._asset_events
        except AttributeError:
            pass

    def __iter__(self) -> Iterator[AssetEventResult]:
        return iter(self._asset_events)

    def __len__(self) -> int:
        return len(self._asset_events)

    @overload
    def __getitem__(self, key: int) -> AssetEventResult: ...

    @overload
    def __getitem__(self, key: slice) -> Sequence[AssetEventResult]: ...

    def __getitem__(self, key: int | slice) -> AssetEventResult | Sequence[AssetEventResult]:
        return self._asset_events[key]

667 

668 

@attrs.define(init=False)
class InletEventsAccessors(
    Mapping["int | Asset | AssetAlias | AssetRef", Any],
    _AssetEventAccessorsMixin[Sequence["AssetEventResult"]],
):
    """Lazy mapping of inlet asset event accessors."""

    _inlets: list[Any]
    _assets: dict[AssetUniqueKey, Asset]
    _asset_aliases: dict[AssetAliasUniqueKey, AssetAlias]

    def __init__(self, inlets: list) -> None:
        self._inlets = inlets
        self._assets = {}
        self._asset_aliases = {}

        # Index assets and aliases up front; refs are resolved to concrete
        # assets via the supervisor so later lookups stay local.
        for entry in inlets:
            if isinstance(entry, Asset):
                self._assets[AssetUniqueKey.from_asset(entry)] = entry
            elif isinstance(entry, AssetAlias):
                self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(entry)] = entry
            elif isinstance(entry, AssetNameRef):
                resolved = OutletEventAccessors._get_asset_from_db(name=entry.name)
                self._assets[AssetUniqueKey.from_asset(resolved)] = resolved
            elif isinstance(entry, AssetUriRef):
                resolved = OutletEventAccessors._get_asset_from_db(uri=entry.uri)
                self._assets[AssetUniqueKey.from_asset(resolved)] = resolved

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        return iter(self._inlets)

    def __len__(self) -> int:
        return len(self._inlets)

    def __getitem__(
        self,
        key: int | Asset | AssetAlias | AssetRef,
    ) -> InletEventsAccessor:
        """Return a queryable event accessor for the inlet identified by *key*."""
        from airflow.sdk.definitions.asset import Asset

        obj = key
        if isinstance(key, int):  # Support index access; it's easier for trivial cases.
            obj = self._inlets[key]
            if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
                raise IndexError(key)

        if isinstance(obj, Asset):
            asset = self._assets[AssetUniqueKey.from_asset(obj)]
            return InletEventsAccessor(asset_name=asset.name, asset_uri=asset.uri)
        if isinstance(obj, AssetNameRef):
            try:
                asset = next(a for k, a in self._assets.items() if k.name == obj.name)
            except StopIteration:
                raise KeyError(obj) from None
            return InletEventsAccessor(asset_name=asset.name)
        if isinstance(obj, AssetUriRef):
            try:
                asset = next(a for k, a in self._assets.items() if k.uri == obj.uri)
            except StopIteration:
                raise KeyError(obj) from None
            return InletEventsAccessor(asset_uri=asset.uri)
        if isinstance(obj, AssetAlias):
            alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
            return InletEventsAccessor(alias_name=alias.name)
        raise TypeError(f"`key` is of unknown type ({type(key).__name__})")

735 

736 

@attrs.define
class TriggeringAssetEventsAccessor(
    _AssetRefResolutionMixin,
    Mapping["Asset | AssetAlias | AssetRef", Sequence["AssetEventDagRunReferenceResult"]],
    _AssetEventAccessorsMixin[Sequence["AssetEventDagRunReferenceResult"]],
):
    """Lazy mapping of triggering asset events."""

    _events: Mapping[BaseAssetUniqueKey, Sequence[AssetEventDagRunReferenceResult]]

    @classmethod
    def build(cls, events: Iterable[AssetEventDagRunReferenceResult]) -> TriggeringAssetEventsAccessor:
        """Group raw events by their asset key and by each source alias key."""
        grouped: dict[BaseAssetUniqueKey, list[AssetEventDagRunReferenceResult]] = collections.defaultdict(
            list
        )
        for event in events:
            grouped[AssetUniqueKey(name=event.asset.name, uri=event.asset.uri)].append(event)
            for alias in event.source_aliases:
                grouped[AssetAliasUniqueKey(name=alias.name)].append(event)
        return cls(grouped)

    def __str__(self) -> str:
        return f"TriggeringAssetEventAccessor(_events={self._events})"

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        for key in self._events:
            yield key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias()

    def __len__(self) -> int:
        return len(self._events)

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEventDagRunReferenceResult]:
        """Return the events recorded under the asset/alias identified by *key*."""
        if isinstance(key, Asset):
            hashable_key: BaseAssetUniqueKey = AssetUniqueKey.from_asset(key)
        elif isinstance(key, AssetRef):
            hashable_key, _ = self._resolve_asset_ref(key)
        elif isinstance(key, AssetAlias):
            hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
        else:
            raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")

        return self._events[hashable_key]

780 

781 

@cache  # Prevent multiple API access.
def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse:
    """Fetch (and memoize per task-instance id) the previous successful DagRun."""
    from airflow.sdk.execution_time import task_runner
    from airflow.sdk.execution_time.comms import (
        GetPrevSuccessfulDagRun,
        PrevSuccessfulDagRunResponse,
        PrevSuccessfulDagRunResult,
    )

    result = task_runner.SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id))

    if TYPE_CHECKING:
        assert isinstance(result, PrevSuccessfulDagRunResult)
    return PrevSuccessfulDagRunResponse(**result.model_dump(exclude={"type"}))

796 

797 

@contextlib.contextmanager
def set_current_context(context: Context) -> Generator[Context, None, None]:
    """
    Set the current execution context to the provided context object.

    This method should be called once per Task execution, before calling operator.execute.
    """
    _CURRENT_CONTEXT.append(context)
    try:
        yield context
    finally:
        popped = _CURRENT_CONTEXT.pop()
        # A mismatch means the stack was corrupted during execution; warn, don't raise.
        if popped != context:
            log.warning(
                "Current context is not equal to the state at context stack.",
                expected=context,
                got=popped,
            )

816 

817 

def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
    """
    Update context after task unmapping.

    Since ``get_template_context()`` is called before unmapping, the context
    contains information about the mapped task. We need to do some in-place
    updates to ensure the template context reflects the unmapped task instead.

    :meta private:
    """
    from airflow.sdk.definitions.param import process_params

    # Point both the context and its task instance at the unmapped task.
    context["ti"].task = task
    context["task"] = task
    context["params"] = process_params(
        context["dag"], task, context["dag_run"].conf, suppress_exception=False
    )

834 

835 

def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool = False) -> dict[str, str]:
    """
    Return values used to externally reconstruct relations between dags, dag_runs, tasks and task_instances.

    Given a context, this function provides a dictionary of values that can be used to
    externally reconstruct relations between dags, dag_runs, tasks and task_instances.
    Default to abc.def.ghi format and can be made to ABC_DEF_GHI format if
    in_env_var_format is set to True.

    :param context: The context for the task_instance of interest.
    :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format.
    :return: task_instance context as dict.
    """
    from datetime import datetime

    from airflow import settings

    name_format = "env_var_format" if in_env_var_format else "default"
    params: dict[str, str] = {}

    # User-supplied context vars first; they may be overridden by the
    # well-known entries below.
    for key_raw, value in settings.get_airflow_context_vars(context).items():
        if not isinstance(key_raw, str):
            raise TypeError(f"key <{key_raw}> must be string")
        if not isinstance(value, str):
            raise TypeError(f"value of key <{key_raw}> must be string, not {type(value)}")

        if in_env_var_format and not key_raw.startswith(ENV_VAR_FORMAT_PREFIX):
            key = ENV_VAR_FORMAT_PREFIX + key_raw.upper()
        elif not key_raw.startswith(DEFAULT_FORMAT_PREFIX):
            key = DEFAULT_FORMAT_PREFIX + key_raw
        else:
            key = key_raw
        params[key] = value

    task = context.get("task")
    task_instance = context.get("task_instance")
    dag_run = context.get("dag_run")

    for subject, attr, mapping_key in (
        (task, "email", "AIRFLOW_CONTEXT_DAG_EMAIL"),
        (task, "owner", "AIRFLOW_CONTEXT_DAG_OWNER"),
        (task_instance, "dag_id", "AIRFLOW_CONTEXT_DAG_ID"),
        (task_instance, "task_id", "AIRFLOW_CONTEXT_TASK_ID"),
        (dag_run, "logical_date", "AIRFLOW_CONTEXT_LOGICAL_DATE"),
        (task_instance, "try_number", "AIRFLOW_CONTEXT_TRY_NUMBER"),
        (dag_run, "run_id", "AIRFLOW_CONTEXT_DAG_RUN_ID"),
    ):
        attr_value = getattr(subject, attr, None)
        if not subject or not attr_value:
            continue
        mapping_value = AIRFLOW_VAR_NAME_FORMAT_MAPPING[mapping_key][name_format]
        if isinstance(attr_value, str):
            params[mapping_value] = attr_value
        elif isinstance(attr_value, datetime):
            params[mapping_value] = attr_value.isoformat()
        elif isinstance(attr_value, list):
            # os env variable value needs to be string
            params[mapping_value] = ",".join(attr_value)
        else:
            params[mapping_value] = str(attr_value)

    return params

903 

904 

def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol:
    """Return the context's outlet-event accessors, creating them on first use."""
    if "outlet_events" not in context:
        context["outlet_events"] = OutletEventAccessors()
    return context["outlet_events"]