# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import collections
import contextlib
import functools
from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
from datetime import datetime
from functools import cache
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload

import attrs
import structlog

from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT
from airflow.sdk.definitions._internal.types import NOTSET
from airflow.sdk.definitions.asset import (
    Asset,
    AssetAlias,
    AssetAliasEvent,
    AssetAliasUniqueKey,
    AssetNameRef,
    AssetRef,
    AssetUniqueKey,
    AssetUriRef,
    BaseAssetUniqueKey,
)
from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError, ErrorType
from airflow.sdk.log import mask_secret

if TYPE_CHECKING:
    from uuid import UUID

    from pydantic.types import JsonValue
    from typing_extensions import Self

    from airflow.sdk import Variable
    from airflow.sdk.bases.operator import BaseOperator
    from airflow.sdk.definitions.connection import Connection
    from airflow.sdk.definitions.context import Context
    from airflow.sdk.execution_time.comms import (
        AssetEventDagRunReferenceResult,
        AssetEventResult,
        AssetEventsResult,
        AssetResult,
        ConnectionResult,
        OKResponse,
        PrevSuccessfulDagRunResponse,
        ReceiveMsgType,
        VariableResult,
    )
    from airflow.sdk.types import OutletEventAccessorsProtocol


DEFAULT_FORMAT_PREFIX = "airflow.ctx."
ENV_VAR_FORMAT_PREFIX = "AIRFLOW_CTX_"

AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
    "AIRFLOW_CONTEXT_DAG_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_ID",
    },
    "AIRFLOW_CONTEXT_TASK_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}task_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TASK_ID",
    },
    "AIRFLOW_CONTEXT_LOGICAL_DATE": {
        "default": f"{DEFAULT_FORMAT_PREFIX}logical_date",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}LOGICAL_DATE",
    },
    "AIRFLOW_CONTEXT_TRY_NUMBER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}try_number",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TRY_NUMBER",
    },
    "AIRFLOW_CONTEXT_DAG_RUN_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_run_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_RUN_ID",
    },
    "AIRFLOW_CONTEXT_DAG_OWNER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_owner",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_OWNER",
    },
    "AIRFLOW_CONTEXT_DAG_EMAIL": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_email",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_EMAIL",
    },
}
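# Illustration (comment only, not executed): the mapping above is consumed by
# context_to_airflow_vars below, where name_format selects one of the two keys.
#
#     AIRFLOW_VAR_NAME_FORMAT_MAPPING["AIRFLOW_CONTEXT_DAG_ID"]["default"]
#     # -> "airflow.ctx.dag_id"
#     AIRFLOW_VAR_NAME_FORMAT_MAPPING["AIRFLOW_CONTEXT_DAG_ID"]["env_var_format"]
#     # -> "AIRFLOW_CTX_DAG_ID"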



log = structlog.get_logger(logger_name="task")

T = TypeVar("T")


def _process_connection_result_conn(conn_result: ReceiveMsgType | None) -> Connection:
    from airflow.sdk.definitions.connection import Connection
    from airflow.sdk.execution_time.comms import ErrorResponse

    if isinstance(conn_result, ErrorResponse):
        raise AirflowRuntimeError(conn_result)

    if TYPE_CHECKING:
        assert isinstance(conn_result, ConnectionResult)

    # `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model
    return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))


def _mask_connection_secrets(conn: Connection) -> None:
    """Mask sensitive connection fields from logs."""
    if conn.password:
        mask_secret(conn.password)
    if conn.extra:
        mask_secret(conn.extra)


def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable:
    from airflow.sdk.definitions.variable import Variable

    if deserialize_json:
        import json

        var_result.value = json.loads(var_result.value)  # type: ignore
    return Variable(**var_result.model_dump(exclude={"type"}))


def _get_connection(conn_id: str) -> Connection:
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    # Check cache first (optional; only on dag processor)
    try:
        uri = SecretCache.get_connection_uri(conn_id)
        from airflow.sdk.definitions.connection import Connection

        conn = Connection.from_uri(uri, conn_id=conn_id)
        _mask_connection_secrets(conn)
        return conn
    except SecretCache.NotPresentException:
        pass  # continue to backends

    # Iterate over configured backends (which may include SupervisorCommsSecretsBackend
    # in worker contexts or MetastoreBackend in API server contexts)
    backends = ensure_secrets_backend_loaded()
    for secrets_backend in backends:
        try:
            conn = secrets_backend.get_connection(conn_id=conn_id)  # type: ignore[assignment]
            if conn:
                SecretCache.save_connection_uri(conn_id, conn.get_uri())
                _mask_connection_secrets(conn)
                return conn
        except Exception:
            log.debug(
                "Unable to retrieve connection from secrets backend (%s). "
                "Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # If no backend found the connection, raise an error
    raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")
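
# Resolution-order sketch for _get_connection (comment only, not executed):
#
#     conn = _get_connection("my_db")
#     # 1. SecretCache hit       -> Connection rebuilt from the cached URI
#     # 2. each secrets backend  -> first non-None result wins and is cached
#     # 3. nothing found         -> AirflowNotFoundException
#
# "my_db" is a hypothetical conn_id used only for illustration.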



async def _async_get_connection(conn_id: str) -> Connection:
    from asgiref.sync import sync_to_async

    from airflow.sdk.execution_time.cache import SecretCache

    # Check cache first
    try:
        uri = SecretCache.get_connection_uri(conn_id)
        from airflow.sdk.definitions.connection import Connection

        conn = Connection.from_uri(uri, conn_id=conn_id)
        _mask_connection_secrets(conn)
        return conn
    except SecretCache.NotPresentException:
        pass  # continue to backends

    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    # Try secrets backends
    backends = ensure_secrets_backend_loaded()
    for secrets_backend in backends:
        try:
            # Use async method if available, otherwise wrap sync method
            if hasattr(secrets_backend, "aget_connection"):
                conn = await secrets_backend.aget_connection(conn_id)  # type: ignore[assignment]
            else:
                conn = await sync_to_async(secrets_backend.get_connection)(conn_id)  # type: ignore[assignment]

            if conn:
                SecretCache.save_connection_uri(conn_id, conn.get_uri())
                _mask_connection_secrets(conn)
                return conn
        except Exception:
            # If one backend fails, try the next one
            log.debug(
                "Unable to retrieve connection from secrets backend (%s). "
                "Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # If no backend found the connection, raise an error
    raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")


def _get_variable(key: str, deserialize_json: bool) -> Any:
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    # Check cache first
    try:
        var_val = SecretCache.get_variable(key)
        if var_val is not None:
            if deserialize_json:
                import json

                var_val = json.loads(var_val)
            if isinstance(var_val, str):
                mask_secret(var_val, key)
            return var_val
    except SecretCache.NotPresentException:
        pass  # Continue to check backends

    backends = ensure_secrets_backend_loaded()

    # Iterate over backends if not in cache (or expired)
    for secrets_backend in backends:
        try:
            var_val = secrets_backend.get_variable(key=key)
            if var_val is not None:
                # Save raw value before deserialization to maintain cache consistency
                SecretCache.save_variable(key, var_val)
                if deserialize_json:
                    import json

                    var_val = json.loads(var_val)
                if isinstance(var_val, str):
                    mask_secret(var_val, key)
                return var_val
        except Exception:
            log.exception(
                "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # If no backend found the variable, raise a not found error (mirrors _get_connection)
    from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
    from airflow.sdk.execution_time.comms import ErrorResponse

    raise AirflowRuntimeError(
        ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"message": f"Variable {key} not found"})
    )
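
# Usage sketch (comment only, not executed): values resolve cache-first, then
# through the secrets backends; a miss raises AirflowRuntimeError(VARIABLE_NOT_FOUND).
#
#     raw = _get_variable("my_key", deserialize_json=False)  # raw string value
#     obj = _get_variable("my_key", deserialize_json=True)   # parsed JSON value
#
# "my_key" is a hypothetical variable key used only for illustration.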



def _set_variable(key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
    # or `airflow.sdk.execution_time.variable`
    # A reason to not move it to `airflow.sdk.execution_time.comms` is that it
    # will make that module depend on the Task SDK, which is not ideal because we intend to
    # keep the Task SDK as a separate package from the execution-time modules.
    import json

    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.comms import PutVariable
    from airflow.sdk.execution_time.secrets.execution_api import ExecutionAPISecretsBackend
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    # Check for write conflicts on the worker: a variable defined in another
    # secrets backend would shadow the value we are about to write.
    for secrets_backend in ensure_secrets_backend_loaded():
        if isinstance(secrets_backend, ExecutionAPISecretsBackend):
            continue
        try:
            var_val = secrets_backend.get_variable(key=key)
            if var_val is not None:
                _backend_name = type(secrets_backend).__name__
                log.warning(
                    "The variable %s is defined in the %s secrets backend, which takes "
                    "precedence over reading from the API Server. The value from the API Server will be "
                    "updated, but to read it you have to delete the conflicting variable "
                    "from %s",
                    key,
                    _backend_name,
                    _backend_name,
                )
        except Exception:
            log.exception(
                "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    try:
        if serialize_json:
            value = json.dumps(value, indent=2)
    except Exception as e:
        log.exception(e)

    SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description))

    # Invalidate cache after setting the variable
    SecretCache.invalidate_variable(key)
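
# Flow sketch for _set_variable (comment only, not executed):
#
#     _set_variable("my_key", {"a": 1}, serialize_json=True)
#     # 1. warn if "my_key" also exists in a non-API secrets backend
#     # 2. json.dumps the value (because serialize_json=True)
#     # 3. send PutVariable to the supervisor
#     # 4. invalidate the local SecretCache entry
#
# "my_key" is a hypothetical variable key used only for illustration.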



def _delete_variable(key: str) -> None:
    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
    # or `airflow.sdk.execution_time.variable`
    # A reason to not move it to `airflow.sdk.execution_time.comms` is that it
    # will make that module depend on the Task SDK, which is not ideal because we intend to
    # keep the Task SDK as a separate package from the execution-time modules.
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.comms import DeleteVariable
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    msg = SUPERVISOR_COMMS.send(DeleteVariable(key=key))
    if TYPE_CHECKING:
        assert isinstance(msg, OKResponse)

    # Invalidate cache after deleting the variable
    SecretCache.invalidate_variable(key)


class ConnectionAccessor:
    """Wrapper to access Connection entries in template."""

    def __getattr__(self, conn_id: str) -> Any:
        from airflow.sdk.definitions.connection import Connection

        return Connection.get(conn_id)

    def __repr__(self) -> str:
        return "<ConnectionAccessor (dynamic access)>"

    def __eq__(self, other):
        if not isinstance(other, ConnectionAccessor):
            return False
        # All instances of ConnectionAccessor are equal since it is a stateless dynamic accessor
        return True

    def __hash__(self):
        return hash(self.__class__.__name__)

    def get(self, conn_id: str, default_conn: Any = None) -> Any:
        try:
            return _get_connection(conn_id)
        except AirflowRuntimeError as e:
            if e.error.error == ErrorType.CONNECTION_NOT_FOUND:
                return default_conn
            raise
        except AirflowNotFoundException:
            return default_conn
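
# Template usage sketch (comment only, illustrative): the accessor backs the
# `conn` template namespace, so Jinja templates can do attribute-style lookups.
#
#     {{ conn.my_db.host }}                        # raises if "my_db" is undefined
#     {{ conn.get("my_db", default_conn=None) }}   # returns the default instead
#
# "my_db" is a hypothetical conn_id used only for illustration.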



class VariableAccessor:
    """Wrapper to access Variable values in template."""

    def __init__(self, deserialize_json: bool) -> None:
        self._deserialize_json = deserialize_json

    def __eq__(self, other):
        if not isinstance(other, VariableAccessor):
            return False
        # All instances of VariableAccessor are equal since it is a stateless dynamic accessor
        return True

    def __hash__(self):
        return hash(self.__class__.__name__)

    def __repr__(self) -> str:
        return "<VariableAccessor (dynamic access)>"

    def __getattr__(self, key: str) -> Any:
        return _get_variable(key, self._deserialize_json)

    def get(self, key, default: Any = NOTSET) -> Any:
        try:
            return _get_variable(key, self._deserialize_json)
        except AirflowRuntimeError as e:
            if e.error.error == ErrorType.VARIABLE_NOT_FOUND:
                return default
            raise
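
# Template usage sketch (comment only, illustrative): two accessor instances
# back the `var` template namespace, one plain and one JSON-deserializing.
#
#     {{ var.value.my_key }}   # VariableAccessor(deserialize_json=False)
#     {{ var.json.my_key }}    # VariableAccessor(deserialize_json=True)
#
# "my_key" is a hypothetical variable key used only for illustration.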



class MacrosAccessor:
    """Wrapper to access Macros module lazily."""

    _macros_module = None

    def __getattr__(self, item: str) -> Any:
        # Lazily load Macros module
        if not self._macros_module:
            import airflow.sdk.execution_time.macros

            self._macros_module = airflow.sdk.execution_time.macros
        return getattr(self._macros_module, item)

    def __repr__(self) -> str:
        return "<MacrosAccessor (dynamic access to macros)>"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, MacrosAccessor):
            return False
        return True

    def __hash__(self):
        return hash(self.__class__.__name__)
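
# Template usage sketch (comment only, illustrative): the accessor defers the
# macros import until the first attribute lookup, e.g.
#
#     {{ macros.ds_add(ds, 5) }}   # resolves airflow.sdk.execution_time.macros.ds_add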



class _AssetRefResolutionMixin:
    _asset_ref_cache: dict[AssetRef, tuple[AssetUniqueKey, dict[str, JsonValue]]] = {}

    def _resolve_asset_ref(self, ref: AssetRef) -> tuple[AssetUniqueKey, dict[str, JsonValue]]:
        with contextlib.suppress(KeyError):
            return self._asset_ref_cache[ref]

        refs_to_cache: list[AssetRef]
        if isinstance(ref, AssetNameRef):
            asset = self._get_asset_from_db(name=ref.name)
            refs_to_cache = [ref, AssetUriRef(asset.uri)]
        elif isinstance(ref, AssetUriRef):
            asset = self._get_asset_from_db(uri=ref.uri)
            refs_to_cache = [ref, AssetNameRef(asset.name)]
        else:
            raise TypeError(f"Unimplemented asset ref: {type(ref)}")
        unique_key = AssetUniqueKey.from_asset(asset)
        for ref in refs_to_cache:
            self._asset_ref_cache[ref] = (unique_key, asset.extra)
        return (unique_key, asset.extra)

    # TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py
    @staticmethod
    def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset:
        from airflow.sdk.definitions.asset import Asset
        from airflow.sdk.execution_time.comms import (
            ErrorResponse,
            GetAssetByName,
            GetAssetByUri,
            ToSupervisor,
        )
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        msg: ToSupervisor
        if name:
            msg = GetAssetByName(name=name)
        elif uri:
            msg = GetAssetByUri(uri=uri)
        else:
            raise ValueError("Either name or uri must be provided")

        resp = SUPERVISOR_COMMS.send(msg)
        if isinstance(resp, ErrorResponse):
            raise AirflowRuntimeError(resp)

        if TYPE_CHECKING:
            assert isinstance(resp, AssetResult)
        return Asset(**resp.model_dump(exclude={"type"}))


@attrs.define
class OutletEventAccessor(_AssetRefResolutionMixin):
    """Wrapper to access an outlet asset event in template."""

    key: BaseAssetUniqueKey
    extra: dict[str, JsonValue] = attrs.Factory(dict)
    asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)

    def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None) -> None:
        """Add an AssetEvent to an existing Asset."""
        if not isinstance(self.key, AssetAliasUniqueKey):
            return

        if isinstance(asset, AssetRef):
            asset_key, asset_extra = self._resolve_asset_ref(asset)
        else:
            asset_key = AssetUniqueKey.from_asset(asset)
            asset_extra = asset.extra

        asset_alias_name = self.key.name
        event = AssetAliasEvent(
            source_alias_name=asset_alias_name,
            dest_asset_key=asset_key,
            dest_asset_extra=asset_extra,
            extra=extra or {},
        )
        self.asset_alias_events.append(event)
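
# Usage sketch (comment only, illustrative): `add` only has an effect when the
# accessor is keyed by an asset alias; it then records which concrete asset
# the alias resolved to for this task run.
#
#     outlet_events[AssetAlias("my_alias")].add(Asset(name="my_asset"))
#
# "my_alias" and "my_asset" are hypothetical names used only for illustration.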



class _AssetEventAccessorsMixin(Generic[T]):
    @overload
    def for_asset(self, *, name: str, uri: str) -> T: ...

    @overload
    def for_asset(self, *, name: str) -> T: ...

    @overload
    def for_asset(self, *, uri: str) -> T: ...

    def for_asset(self, *, name: str | None = None, uri: str | None = None) -> T:
        if name and uri:
            return self[Asset(name=name, uri=uri)]
        if name:
            return self[Asset.ref(name=name)]
        if uri:
            return self[Asset.ref(uri=uri)]

        raise ValueError("name and uri cannot both be None")

    def for_asset_alias(self, *, name: str) -> T:
        return self[AssetAlias(name=name)]

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> T:
        raise NotImplementedError
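
# Dispatch sketch (comment only, illustrative): for_asset maps keyword
# arguments onto the subclass's __getitem__, so these selections are equivalent:
#
#     accessors.for_asset(name="my_asset")       == accessors[Asset.ref(name="my_asset")]
#     accessors.for_asset(uri="s3://my-bucket")  == accessors[Asset.ref(uri="s3://my-bucket")]
#     accessors.for_asset_alias(name="my_alias") == accessors[AssetAlias(name="my_alias")]
#
# Names and URIs here are hypothetical, used only for illustration.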



class OutletEventAccessors(
    _AssetRefResolutionMixin,
    Mapping["Asset | AssetAlias", OutletEventAccessor],
    _AssetEventAccessorsMixin[OutletEventAccessor],
):
    """Lazy mapping of outlet asset event accessors."""

    def __init__(self) -> None:
        self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}

    def __str__(self) -> str:
        return f"OutletEventAccessors(_dict={self._dict})"

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        return (
            key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict
        )

    def __len__(self) -> int:
        return len(self._dict)

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> OutletEventAccessor:
        hashable_key: BaseAssetUniqueKey
        if isinstance(key, Asset):
            hashable_key = AssetUniqueKey.from_asset(key)
        elif isinstance(key, AssetAlias):
            hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
        elif isinstance(key, AssetRef):
            hashable_key, _ = self._resolve_asset_ref(key)
        else:
            raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")

        if hashable_key not in self._dict:
            self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key)
        return self._dict[hashable_key]


@attrs.define(init=False)
class InletEventsAccessor(Sequence["AssetEventResult"]):
    _after: str | datetime | None
    _before: str | datetime | None
    _ascending: bool
    _limit: int | None
    _asset_name: str | None
    _asset_uri: str | None
    _alias_name: str | None

    def __init__(
        self, asset_name: str | None = None, asset_uri: str | None = None, alias_name: str | None = None
    ):
        self._asset_name = asset_name
        self._asset_uri = asset_uri
        self._alias_name = alias_name
        self._after = None
        self._before = None
        self._ascending = True
        self._limit = None

    def after(self, after: str) -> Self:
        self._after = after
        self._reset_cache()
        return self

    def before(self, before: str) -> Self:
        self._before = before
        self._reset_cache()
        return self

    def ascending(self, ascending: bool = True) -> Self:
        self._ascending = ascending
        self._reset_cache()
        return self

    def limit(self, limit: int) -> Self:
        self._limit = limit
        self._reset_cache()
        return self

    @functools.cached_property
    def _asset_events(self) -> list[AssetEventResult]:
        from airflow.sdk.execution_time.comms import (
            ErrorResponse,
            GetAssetEventByAsset,
            GetAssetEventByAssetAlias,
            ToSupervisor,
        )
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        query_dict: dict[str, Any] = {
            "after": self._after,
            "before": self._before,
            "ascending": self._ascending,
            "limit": self._limit,
        }

        msg: ToSupervisor
        if self._alias_name is not None:
            msg = GetAssetEventByAssetAlias(alias_name=self._alias_name, **query_dict)
        else:
            if self._asset_name is None and self._asset_uri is None:
                raise ValueError("Either asset_name or asset_uri must be provided")
            msg = GetAssetEventByAsset(name=self._asset_name, uri=self._asset_uri, **query_dict)
        resp = SUPERVISOR_COMMS.send(msg)
        if isinstance(resp, ErrorResponse):
            raise AirflowRuntimeError(resp)

        if TYPE_CHECKING:
            assert isinstance(resp, AssetEventsResult)

        return list(resp.iter_asset_event_results())

    def _reset_cache(self) -> None:
        try:
            del self._asset_events
        except AttributeError:
            pass

    def __iter__(self) -> Iterator[AssetEventResult]:
        return iter(self._asset_events)

    def __len__(self) -> int:
        return len(self._asset_events)

    @overload
    def __getitem__(self, key: int) -> AssetEventResult: ...

    @overload
    def __getitem__(self, key: slice) -> Sequence[AssetEventResult]: ...

    def __getitem__(self, key: int | slice) -> AssetEventResult | Sequence[AssetEventResult]:
        return self._asset_events[key]
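
# Query sketch (comment only, not executed): the builder methods mutate the
# accessor and drop the cached result, so the supervisor is queried again
# lazily on the next read.
#
#     events = InletEventsAccessor(asset_name="my_asset")
#     recent = events.after("2024-01-01T00:00:00+00:00").limit(10)
#     first = recent[0]   # triggers the GetAssetEventByAsset round trip
#
# "my_asset" and the timestamp are hypothetical values for illustration.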



@attrs.define(init=False)
class InletEventsAccessors(
    Mapping["int | Asset | AssetAlias | AssetRef", Any],
    _AssetEventAccessorsMixin[Sequence["AssetEventResult"]],
):
    """Lazy mapping of inlet asset event accessors."""

    _inlets: list[Any]
    _assets: dict[AssetUniqueKey, Asset]
    _asset_aliases: dict[AssetAliasUniqueKey, AssetAlias]

    def __init__(self, inlets: list) -> None:
        self._inlets = inlets
        self._assets = {}
        self._asset_aliases = {}

        for inlet in inlets:
            if isinstance(inlet, Asset):
                self._assets[AssetUniqueKey.from_asset(inlet)] = inlet
            elif isinstance(inlet, AssetAlias):
                self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(inlet)] = inlet
            elif isinstance(inlet, AssetNameRef):
                asset = OutletEventAccessors._get_asset_from_db(name=inlet.name)
                self._assets[AssetUniqueKey.from_asset(asset)] = asset
            elif isinstance(inlet, AssetUriRef):
                asset = OutletEventAccessors._get_asset_from_db(uri=inlet.uri)
                self._assets[AssetUniqueKey.from_asset(asset)] = asset

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        return iter(self._inlets)

    def __len__(self) -> int:
        return len(self._inlets)

    def __getitem__(
        self,
        key: int | Asset | AssetAlias | AssetRef,
    ) -> InletEventsAccessor:
        from airflow.sdk.definitions.asset import Asset

        if isinstance(key, int):  # Support index access; it's easier for trivial cases.
            obj = self._inlets[key]
            if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
                raise IndexError(key)
        else:
            obj = key

        if isinstance(obj, Asset):
            asset = self._assets[AssetUniqueKey.from_asset(obj)]
            return InletEventsAccessor(asset_name=asset.name, asset_uri=asset.uri)
        if isinstance(obj, AssetNameRef):
            try:
                asset = next(a for k, a in self._assets.items() if k.name == obj.name)
            except StopIteration:
                raise KeyError(obj) from None
            return InletEventsAccessor(asset_name=asset.name)
        if isinstance(obj, AssetUriRef):
            try:
                asset = next(a for k, a in self._assets.items() if k.uri == obj.uri)
            except StopIteration:
                raise KeyError(obj) from None
            return InletEventsAccessor(asset_uri=asset.uri)
        if isinstance(obj, AssetAlias):
            asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
            return InletEventsAccessor(alias_name=asset_alias.name)
        raise TypeError(f"`key` is of unknown type ({type(key).__name__})")


@attrs.define
class TriggeringAssetEventsAccessor(
    _AssetRefResolutionMixin,
    Mapping["Asset | AssetAlias | AssetRef", Sequence["AssetEventDagRunReferenceResult"]],
    _AssetEventAccessorsMixin[Sequence["AssetEventDagRunReferenceResult"]],
):
    """Lazy mapping of triggering asset events."""

    _events: Mapping[BaseAssetUniqueKey, Sequence[AssetEventDagRunReferenceResult]]

    @classmethod
    def build(cls, events: Iterable[AssetEventDagRunReferenceResult]) -> TriggeringAssetEventsAccessor:
        coll: dict[BaseAssetUniqueKey, list[AssetEventDagRunReferenceResult]] = collections.defaultdict(list)
        for event in events:
            coll[AssetUniqueKey(name=event.asset.name, uri=event.asset.uri)].append(event)
            for alias in event.source_aliases:
                coll[AssetAliasUniqueKey(name=alias.name)].append(event)
        return cls(coll)

    def __str__(self) -> str:
        return f"TriggeringAssetEventsAccessor(_events={self._events})"

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        return (
            key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias()
            for key in self._events
        )

    def __len__(self) -> int:
        return len(self._events)

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEventDagRunReferenceResult]:
        hashable_key: BaseAssetUniqueKey
        if isinstance(key, Asset):
            hashable_key = AssetUniqueKey.from_asset(key)
        elif isinstance(key, AssetRef):
            hashable_key, _ = self._resolve_asset_ref(key)
        elif isinstance(key, AssetAlias):
            hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
        else:
            raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")

        return self._events[hashable_key]


@cache  # Prevent multiple API access.
def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse:
    from airflow.sdk.execution_time import task_runner
    from airflow.sdk.execution_time.comms import (
        GetPrevSuccessfulDagRun,
        PrevSuccessfulDagRunResponse,
        PrevSuccessfulDagRunResult,
    )

    msg = task_runner.SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id))

    if TYPE_CHECKING:
        assert isinstance(msg, PrevSuccessfulDagRunResult)
    return PrevSuccessfulDagRunResponse(**msg.model_dump(exclude={"type"}))


@contextlib.contextmanager
def set_current_context(context: Context) -> Generator[Context, None, None]:
    """
    Set the current execution context to the provided context object.

    This method should be called once per Task execution, before calling operator.execute.
    """
    _CURRENT_CONTEXT.append(context)
    try:
        yield context
    finally:
        expected_state = _CURRENT_CONTEXT.pop()
        if expected_state != context:
            log.warning(
                "Current context is not equal to the state at the top of the context stack.",
                expected=context,
                got=expected_state,
            )
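
# Usage sketch (comment only, not executed): the context manager keeps
# _CURRENT_CONTEXT as a stack, so nested invocations unwind correctly.
#
#     with set_current_context(context):
#         operator.execute(context)   # code here can look up the active context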



def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
    """
    Update context after task unmapping.

    Since ``get_template_context()`` is called before unmapping, the context
    contains information about the mapped task. We need to do some in-place
    updates to ensure the template context reflects the unmapped task instead.

    :meta private:
    """
    from airflow.sdk.definitions.param import process_params

    context["task"] = context["ti"].task = task
    context["params"] = process_params(
        context["dag"], task, context["dag_run"].conf, suppress_exception=False
    )


def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool = False) -> dict[str, str]:
    """
    Return values used to externally reconstruct relations between dags, dag_runs, tasks and task_instances.

    Given a context, this function provides a dictionary of values that can be used to
    externally reconstruct relations between dags, dag_runs, tasks and task_instances.
    Defaults to the abc.def.ghi format, and can be made to use the ABC_DEF_GHI format
    if in_env_var_format is set to True.

    :param context: The context for the task_instance of interest.
    :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format.
    :return: task_instance context as dict.
    """
    from datetime import datetime

    from airflow import settings

    params = {}
    if in_env_var_format:
        name_format = "env_var_format"
    else:
        name_format = "default"

    task = context.get("task")
    task_instance = context.get("task_instance")
    dag_run = context.get("dag_run")

    ops = [
        (task, "email", "AIRFLOW_CONTEXT_DAG_EMAIL"),
        (task, "owner", "AIRFLOW_CONTEXT_DAG_OWNER"),
        (task_instance, "dag_id", "AIRFLOW_CONTEXT_DAG_ID"),
        (task_instance, "task_id", "AIRFLOW_CONTEXT_TASK_ID"),
        (dag_run, "logical_date", "AIRFLOW_CONTEXT_LOGICAL_DATE"),
        (task_instance, "try_number", "AIRFLOW_CONTEXT_TRY_NUMBER"),
        (dag_run, "run_id", "AIRFLOW_CONTEXT_DAG_RUN_ID"),
    ]

    context_params = settings.get_airflow_context_vars(context)
    for key_raw, value in context_params.items():
        if not isinstance(key_raw, str):
            raise TypeError(f"key <{key_raw}> must be string")
        if not isinstance(value, str):
            raise TypeError(f"value of key <{key_raw}> must be string, not {type(value)}")

        if in_env_var_format and not key_raw.startswith(ENV_VAR_FORMAT_PREFIX):
            key = ENV_VAR_FORMAT_PREFIX + key_raw.upper()
        elif not key_raw.startswith(DEFAULT_FORMAT_PREFIX):
            key = DEFAULT_FORMAT_PREFIX + key_raw
        else:
            key = key_raw
        params[key] = value

    for subject, attr, mapping_key in ops:
        _attr = getattr(subject, attr, None)
        if subject and _attr:
            mapping_value = AIRFLOW_VAR_NAME_FORMAT_MAPPING[mapping_key][name_format]
            if isinstance(_attr, str):
                params[mapping_value] = _attr
            elif isinstance(_attr, datetime):
                params[mapping_value] = _attr.isoformat()
            elif isinstance(_attr, list):
                # os env variable value needs to be string
                params[mapping_value] = ",".join(_attr)
            else:
                params[mapping_value] = str(_attr)

    return params
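
# Output sketch (comment only, illustrative): for a task "my_task" in dag
# "my_dag" the function yields string-valued entries such as
#
#     context_to_airflow_vars(context)
#     # {"airflow.ctx.dag_id": "my_dag", "airflow.ctx.task_id": "my_task", ...}
#     context_to_airflow_vars(context, in_env_var_format=True)
#     # {"AIRFLOW_CTX_DAG_ID": "my_dag", "AIRFLOW_CTX_TASK_ID": "my_task", ...}
#
# "my_dag" and "my_task" are hypothetical identifiers for illustration.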



def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol:
    try:
        outlet_events = context["outlet_events"]
    except KeyError:
        outlet_events = context["outlet_events"] = OutletEventAccessors()
    return outlet_events