# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import collections
import contextlib
import functools
from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
from datetime import datetime
from functools import cache
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload

import attrs
import structlog

from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT
from airflow.sdk.definitions._internal.types import NOTSET
from airflow.sdk.definitions.asset import (
    Asset,
    AssetAlias,
    AssetAliasEvent,
    AssetAliasUniqueKey,
    AssetNameRef,
    AssetRef,
    AssetUniqueKey,
    AssetUriRef,
    BaseAssetUniqueKey,
)
from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError, ErrorType
from airflow.sdk.log import mask_secret

if TYPE_CHECKING:
    from uuid import UUID

    from pydantic.types import JsonValue
    from typing_extensions import Self

    from airflow.sdk import Variable
    from airflow.sdk.bases.operator import BaseOperator
    from airflow.sdk.definitions.connection import Connection
    from airflow.sdk.definitions.context import Context
    from airflow.sdk.execution_time.comms import (
        AssetEventDagRunReferenceResult,
        AssetEventResult,
        AssetEventsResult,
        AssetResult,
        ConnectionResult,
        OKResponse,
        PrevSuccessfulDagRunResponse,
        ReceiveMsgType,
        VariableResult,
    )
    from airflow.sdk.types import OutletEventAccessorsProtocol


DEFAULT_FORMAT_PREFIX = "airflow.ctx."
ENV_VAR_FORMAT_PREFIX = "AIRFLOW_CTX_"

AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
    "AIRFLOW_CONTEXT_DAG_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_ID",
    },
    "AIRFLOW_CONTEXT_TASK_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}task_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TASK_ID",
    },
    "AIRFLOW_CONTEXT_LOGICAL_DATE": {
        "default": f"{DEFAULT_FORMAT_PREFIX}logical_date",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}LOGICAL_DATE",
    },
    "AIRFLOW_CONTEXT_TRY_NUMBER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}try_number",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TRY_NUMBER",
    },
    "AIRFLOW_CONTEXT_DAG_RUN_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_run_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_RUN_ID",
    },
    "AIRFLOW_CONTEXT_DAG_OWNER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_owner",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_OWNER",
    },
    "AIRFLOW_CONTEXT_DAG_EMAIL": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_email",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_EMAIL",
    },
}


log = structlog.get_logger(logger_name="task")

T = TypeVar("T")


def _process_connection_result_conn(conn_result: ReceiveMsgType | None) -> Connection:
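    """Build a Connection from a supervisor response, raising AirflowRuntimeError on an error response."""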
    from airflow.sdk.definitions.connection import Connection
    from airflow.sdk.execution_time.comms import ErrorResponse

    if isinstance(conn_result, ErrorResponse):
        raise AirflowRuntimeError(conn_result)

    if TYPE_CHECKING:
        assert isinstance(conn_result, ConnectionResult)

    # `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model
    return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))


def _mask_connection_secrets(conn: Connection) -> None:
    """Mask sensitive connection fields from logs."""
    if conn.password:
        mask_secret(conn.password)
    if conn.extra:
        mask_secret(conn.extra)


def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable:
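    """Convert a VariableResult message into a Variable, JSON-decoding the value if requested."""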
    from airflow.sdk.definitions.variable import Variable

    if deserialize_json:
        import json

        var_result.value = json.loads(var_result.value)  # type: ignore
    return Variable(**var_result.model_dump(exclude={"type"}))


def _get_connection(conn_id: str) -> Connection:
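    """Fetch a connection by ID from the secrets cache or the configured secrets backends.

    :raises AirflowNotFoundException: if no backend defines the connection.
    """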
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    # Check cache first (optional; only on dag processor)
    try:
        uri = SecretCache.get_connection_uri(conn_id)
        from airflow.sdk.definitions.connection import Connection

        conn = Connection.from_uri(uri, conn_id=conn_id)
        _mask_connection_secrets(conn)
        return conn
    except SecretCache.NotPresentException:
        pass  # continue to backends

    # Iterate over configured backends (which may include SupervisorCommsSecretsBackend
    # in worker contexts or MetastoreBackend in API server contexts)
    backends = ensure_secrets_backend_loaded()
    for secrets_backend in backends:
        try:
            conn = secrets_backend.get_connection(conn_id=conn_id)  # type: ignore[assignment]
            if conn:
                SecretCache.save_connection_uri(conn_id, conn.get_uri())
                _mask_connection_secrets(conn)
                return conn
        except Exception:
            log.debug(
                "Unable to retrieve connection from secrets backend (%s). "
                "Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # If no backend found the connection, raise an error
    raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")


async def _async_get_connection(conn_id: str) -> Connection:
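    """Async variant of ``_get_connection``.

    Uses a backend's native ``aget_connection`` when available, otherwise wraps the
    sync ``get_connection`` with ``sync_to_async``.
    """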
    from asgiref.sync import sync_to_async

    from airflow.sdk.execution_time.cache import SecretCache

    # Check cache first
    try:
        uri = SecretCache.get_connection_uri(conn_id)
        from airflow.sdk.definitions.connection import Connection

        conn = Connection.from_uri(uri, conn_id=conn_id)
        _mask_connection_secrets(conn)
        return conn
    except SecretCache.NotPresentException:
        pass  # continue to backends

    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    # Try secrets backends
    backends = ensure_secrets_backend_loaded()
    for secrets_backend in backends:
        try:
            # Use async method if available, otherwise wrap sync method
            if hasattr(secrets_backend, "aget_connection"):
                conn = await secrets_backend.aget_connection(conn_id)  # type: ignore[assignment]
            else:
                conn = await sync_to_async(secrets_backend.get_connection)(conn_id)  # type: ignore[assignment]

            if conn:
                SecretCache.save_connection_uri(conn_id, conn.get_uri())
                _mask_connection_secrets(conn)
                return conn
        except Exception:
            # If one backend fails, try the next one
            log.debug(
                "Unable to retrieve connection from secrets backend (%s). "
                "Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # If no backend found the connection, raise an error
    raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")


def _get_variable(key: str, deserialize_json: bool) -> Any:
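    """Fetch a variable by key from the secrets cache or the configured secrets backends.

    :raises AirflowRuntimeError: with ``VARIABLE_NOT_FOUND`` if no backend defines the variable.
    """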
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    # Check cache first
    try:
        var_val = SecretCache.get_variable(key)
        if var_val is not None:
            if deserialize_json:
                import json

                var_val = json.loads(var_val)
            if isinstance(var_val, str):
                mask_secret(var_val, key)
            return var_val
    except SecretCache.NotPresentException:
        pass  # Continue to check backends

    backends = ensure_secrets_backend_loaded()

    # Iterate over backends if not in cache (or expired)
    for secrets_backend in backends:
        try:
            var_val = secrets_backend.get_variable(key=key)
            if var_val is not None:
                # Save raw value before deserialization to maintain cache consistency
                SecretCache.save_variable(key, var_val)
                if deserialize_json:
                    import json

                    var_val = json.loads(var_val)
                if isinstance(var_val, str):
                    mask_secret(var_val, key)
                return var_val
        except Exception:
            log.exception(
                "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # If no backend found the variable, raise a not-found error (mirrors _get_connection)
    from airflow.sdk.execution_time.comms import ErrorResponse

    raise AirflowRuntimeError(
        ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"message": f"Variable {key} not found"})
    )


def _set_variable(key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
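    """Set or update a variable via the supervisor, warning when another secrets backend defines the same key and would shadow the new value."""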
    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
    # or `airflow.sdk.execution_time.variable`
    # A reason to not move it to `airflow.sdk.execution_time.comms` is that it
    # will make that module depend on the Task SDK, which is not ideal because we intend to
    # keep the Task SDK as a separate package from the execution-time modules.
    import json

    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.comms import PutVariable
    from airflow.sdk.execution_time.secrets.execution_api import ExecutionAPISecretsBackend
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    # Check for write conflicts on the worker
    for secrets_backend in ensure_secrets_backend_loaded():
        if isinstance(secrets_backend, ExecutionAPISecretsBackend):
            continue
        try:
            var_val = secrets_backend.get_variable(key=key)
            if var_val is not None:
                _backend_name = type(secrets_backend).__name__
                log.warning(
                    "The variable %s is defined in the %s secrets backend, which takes "
                    "precedence over reading from the API Server. The value from the API Server will be "
                    "updated, but to read it you have to delete the conflicting variable "
                    "from %s",
                    key,
                    _backend_name,
                    _backend_name,
                )
        except Exception:
            log.exception(
                "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    try:
        if serialize_json:
            value = json.dumps(value, indent=2)
    except Exception:
        log.exception("Unable to serialize the variable value as JSON", key=key)

    SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description))

    # Invalidate cache after setting the variable
    SecretCache.invalidate_variable(key)


def _delete_variable(key: str) -> None:
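    """Delete a variable via the supervisor and invalidate its cache entry."""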
    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
    # or `airflow.sdk.execution_time.variable`
    # A reason to not move it to `airflow.sdk.execution_time.comms` is that it
    # will make that module depend on the Task SDK, which is not ideal because we intend to
    # keep the Task SDK as a separate package from the execution-time modules.
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.comms import DeleteVariable
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    msg = SUPERVISOR_COMMS.send(DeleteVariable(key=key))
    if TYPE_CHECKING:
        assert isinstance(msg, OKResponse)

    # Invalidate cache after deleting the variable
    SecretCache.invalidate_variable(key)


class ConnectionAccessor:
    """Wrapper to access Connection entries in template."""

    def __getattr__(self, conn_id: str) -> Any:
        from airflow.sdk.definitions.connection import Connection

        return Connection.get(conn_id)

    def __repr__(self) -> str:
        return "<ConnectionAccessor (dynamic access)>"

    def __eq__(self, other):
        # All instances of ConnectionAccessor are equal since it is a stateless dynamic accessor
        return isinstance(other, ConnectionAccessor)

    def __hash__(self):
        return hash(self.__class__.__name__)

    def get(self, conn_id: str, default_conn: Any = None) -> Any:
        try:
            return _get_connection(conn_id)
        except AirflowRuntimeError as e:
            if e.error.error == ErrorType.CONNECTION_NOT_FOUND:
                return default_conn
            raise
        except AirflowNotFoundException:
            return default_conn


class VariableAccessor:
    """Wrapper to access Variable values in template."""

    def __init__(self, deserialize_json: bool) -> None:
        self._deserialize_json = deserialize_json

    def __eq__(self, other):
        # All instances of VariableAccessor are treated as equal, regardless of deserialize_json
        return isinstance(other, VariableAccessor)

    def __hash__(self):
        return hash(self.__class__.__name__)

    def __repr__(self) -> str:
        return "<VariableAccessor (dynamic access)>"

    def __getattr__(self, key: str) -> Any:
        return _get_variable(key, self._deserialize_json)

    def get(self, key, default: Any = NOTSET) -> Any:
        try:
            return _get_variable(key, self._deserialize_json)
        except AirflowRuntimeError as e:
            if e.error.error == ErrorType.VARIABLE_NOT_FOUND:
                return default
            raise


class MacrosAccessor:
    """Wrapper to access Macros module lazily."""

    _macros_module = None

    def __getattr__(self, item: str) -> Any:
        # Lazily load Macros module
        if not self._macros_module:
            import airflow.sdk.execution_time.macros

            self._macros_module = airflow.sdk.execution_time.macros
        return getattr(self._macros_module, item)

    def __repr__(self) -> str:
        return "<MacrosAccessor (dynamic access to macros)>"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, MacrosAccessor)

    def __hash__(self):
        return hash(self.__class__.__name__)


class _AssetRefResolutionMixin:
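    """Mixin resolving asset name/URI references to unique keys.

    Resolved refs are memoized in a class-level cache shared by all instances.
    """
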
    _asset_ref_cache: dict[AssetRef, tuple[AssetUniqueKey, dict[str, JsonValue]]] = {}

    def _resolve_asset_ref(self, ref: AssetRef) -> tuple[AssetUniqueKey, dict[str, JsonValue]]:
        with contextlib.suppress(KeyError):
            return self._asset_ref_cache[ref]

        refs_to_cache: list[AssetRef]
        if isinstance(ref, AssetNameRef):
            asset = self._get_asset_from_db(name=ref.name)
            refs_to_cache = [ref, AssetUriRef(asset.uri)]
        elif isinstance(ref, AssetUriRef):
            asset = self._get_asset_from_db(uri=ref.uri)
            refs_to_cache = [ref, AssetNameRef(asset.name)]
        else:
            raise TypeError(f"Unimplemented asset ref: {type(ref)}")
        unique_key = AssetUniqueKey.from_asset(asset)
        # Cache under both the name and URI refs so a later lookup by either form hits the cache
        for ref_to_cache in refs_to_cache:
            self._asset_ref_cache[ref_to_cache] = (unique_key, asset.extra)
        return (unique_key, asset.extra)

    # TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py
    @staticmethod
    def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset:
        from airflow.sdk.definitions.asset import Asset
        from airflow.sdk.execution_time.comms import (
            ErrorResponse,
            GetAssetByName,
            GetAssetByUri,
            ToSupervisor,
        )
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        msg: ToSupervisor
        if name:
            msg = GetAssetByName(name=name)
        elif uri:
            msg = GetAssetByUri(uri=uri)
        else:
            raise ValueError("Either name or uri must be provided")

        resp = SUPERVISOR_COMMS.send(msg)
        if isinstance(resp, ErrorResponse):
            raise AirflowRuntimeError(resp)

        if TYPE_CHECKING:
            assert isinstance(resp, AssetResult)
        return Asset(**resp.model_dump(exclude={"type"}))


@attrs.define
class OutletEventAccessor(_AssetRefResolutionMixin):
    """Wrapper to access an outlet asset event in template."""

    key: BaseAssetUniqueKey
    extra: dict[str, JsonValue] = attrs.Factory(dict)
    asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)

    def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None) -> None:
        """Add an AssetEvent to an existing Asset.

        This is a no-op unless this accessor is keyed by an asset alias.
        """
        if not isinstance(self.key, AssetAliasUniqueKey):
            return

        if isinstance(asset, AssetRef):
            asset_key, asset_extra = self._resolve_asset_ref(asset)
        else:
            asset_key = AssetUniqueKey.from_asset(asset)
            asset_extra = asset.extra

        asset_alias_name = self.key.name
        event = AssetAliasEvent(
            source_alias_name=asset_alias_name,
            dest_asset_key=asset_key,
            dest_asset_extra=asset_extra,
            extra=extra or {},
        )
        self.asset_alias_events.append(event)


class _AssetEventAccessorsMixin(Generic[T]):
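    """Mixin adding keyword-based lookup helpers on top of ``__getitem__``.

    Example (illustrative)::

        accessor.for_asset(name="my-asset")  # same as accessor[Asset.ref(name="my-asset")]
    """
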
    @overload
    def for_asset(self, *, name: str, uri: str) -> T: ...

    @overload
    def for_asset(self, *, name: str) -> T: ...

    @overload
    def for_asset(self, *, uri: str) -> T: ...

    def for_asset(self, *, name: str | None = None, uri: str | None = None) -> T:
        if name and uri:
            return self[Asset(name=name, uri=uri)]
        if name:
            return self[Asset.ref(name=name)]
        if uri:
            return self[Asset.ref(uri=uri)]

        raise ValueError("name and uri cannot both be None")

    def for_asset_alias(self, *, name: str) -> T:
        return self[AssetAlias(name=name)]

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> T:
        raise NotImplementedError


class OutletEventAccessors(
    _AssetRefResolutionMixin,
    Mapping["Asset | AssetAlias", OutletEventAccessor],
    _AssetEventAccessorsMixin[OutletEventAccessor],
):
    """Lazy mapping of outlet asset event accessors."""

    def __init__(self) -> None:
        self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}

    def __str__(self) -> str:
        return f"OutletEventAccessors(_dict={self._dict})"

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        return (
            key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict
        )

    def __len__(self) -> int:
        return len(self._dict)

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> OutletEventAccessor:
        hashable_key: BaseAssetUniqueKey
        if isinstance(key, Asset):
            hashable_key = AssetUniqueKey.from_asset(key)
        elif isinstance(key, AssetAlias):
            hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
        elif isinstance(key, AssetRef):
            hashable_key, _ = self._resolve_asset_ref(key)
        else:
            raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")

        if hashable_key not in self._dict:
            self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key)
        return self._dict[hashable_key]


@attrs.define(init=False)
class InletEventsAccessor(Sequence["AssetEventResult"]):
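    """Lazy sequence of asset events for one asset or asset alias.

    Query filters can be chained fluently; events are fetched from the supervisor
    on first access and cached until a filter changes, e.g. (illustrative)::

        events = inlet_events[asset].after("2025-01-01").limit(10)
    """
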
    _after: str | datetime | None
    _before: str | datetime | None
    _ascending: bool
    _limit: int | None
    _asset_name: str | None
    _asset_uri: str | None
    _alias_name: str | None

    def __init__(
        self, asset_name: str | None = None, asset_uri: str | None = None, alias_name: str | None = None
    ):
        self._asset_name = asset_name
        self._asset_uri = asset_uri
        self._alias_name = alias_name
        self._after = None
        self._before = None
        self._ascending = True
        self._limit = None

    def after(self, after: str | datetime) -> Self:
        self._after = after
        self._reset_cache()
        return self

    def before(self, before: str | datetime) -> Self:
        self._before = before
        self._reset_cache()
        return self

    def ascending(self, ascending: bool = True) -> Self:
        self._ascending = ascending
        self._reset_cache()
        return self

    def limit(self, limit: int) -> Self:
        self._limit = limit
        self._reset_cache()
        return self

    @functools.cached_property
    def _asset_events(self) -> list[AssetEventResult]:
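        """Fetch and cache the asset events matching the current filters."""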
        from airflow.sdk.execution_time.comms import (
            ErrorResponse,
            GetAssetEventByAsset,
            GetAssetEventByAssetAlias,
            ToSupervisor,
        )
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        query_dict: dict[str, Any] = {
            "after": self._after,
            "before": self._before,
            "ascending": self._ascending,
            "limit": self._limit,
        }

        msg: ToSupervisor
        if self._alias_name is not None:
            msg = GetAssetEventByAssetAlias(alias_name=self._alias_name, **query_dict)
        else:
            if self._asset_name is None and self._asset_uri is None:
                raise ValueError("Either asset_name or asset_uri must be provided")
            msg = GetAssetEventByAsset(name=self._asset_name, uri=self._asset_uri, **query_dict)
        resp = SUPERVISOR_COMMS.send(msg)
        if isinstance(resp, ErrorResponse):
            raise AirflowRuntimeError(resp)

        if TYPE_CHECKING:
            assert isinstance(resp, AssetEventsResult)

        return list(resp.iter_asset_event_results())

    def _reset_cache(self) -> None:
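        """Invalidate the cached events so the next access re-queries with the updated filters."""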
        try:
            del self._asset_events
        except AttributeError:
            pass

    def __iter__(self) -> Iterator[AssetEventResult]:
        return iter(self._asset_events)

    def __len__(self) -> int:
        return len(self._asset_events)

    @overload
    def __getitem__(self, key: int) -> AssetEventResult: ...

    @overload
    def __getitem__(self, key: slice) -> Sequence[AssetEventResult]: ...

    def __getitem__(self, key: int | slice) -> AssetEventResult | Sequence[AssetEventResult]:
        return self._asset_events[key]


@attrs.define(init=False)
class InletEventsAccessors(
    Mapping["int | Asset | AssetAlias | AssetRef", Any],
    _AssetEventAccessorsMixin[Sequence["AssetEventResult"]],
):
    """Lazy mapping of inlet asset event accessors."""

    _inlets: list[Any]
    _assets: dict[AssetUniqueKey, Asset]
    _asset_aliases: dict[AssetAliasUniqueKey, AssetAlias]

    def __init__(self, inlets: list) -> None:
        self._inlets = inlets
        self._assets = {}
        self._asset_aliases = {}

        for inlet in inlets:
            if isinstance(inlet, Asset):
                self._assets[AssetUniqueKey.from_asset(inlet)] = inlet
            elif isinstance(inlet, AssetAlias):
                self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(inlet)] = inlet
            elif isinstance(inlet, AssetNameRef):
                asset = OutletEventAccessors._get_asset_from_db(name=inlet.name)
                self._assets[AssetUniqueKey.from_asset(asset)] = asset
            elif isinstance(inlet, AssetUriRef):
                asset = OutletEventAccessors._get_asset_from_db(uri=inlet.uri)
                self._assets[AssetUniqueKey.from_asset(asset)] = asset

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        return iter(self._inlets)

    def __len__(self) -> int:
        return len(self._inlets)

    def __getitem__(
        self,
        key: int | Asset | AssetAlias | AssetRef,
    ) -> InletEventsAccessor:
        from airflow.sdk.definitions.asset import Asset

        if isinstance(key, int):  # Support index access; it's easier for trivial cases.
            obj = self._inlets[key]
            if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
                raise IndexError(key)
        else:
            obj = key

        if isinstance(obj, Asset):
            asset = self._assets[AssetUniqueKey.from_asset(obj)]
            return InletEventsAccessor(asset_name=asset.name, asset_uri=asset.uri)
        if isinstance(obj, AssetNameRef):
            try:
                asset = next(a for k, a in self._assets.items() if k.name == obj.name)
            except StopIteration:
                raise KeyError(obj) from None
            return InletEventsAccessor(asset_name=asset.name)
        if isinstance(obj, AssetUriRef):
            try:
                asset = next(a for k, a in self._assets.items() if k.uri == obj.uri)
            except StopIteration:
                raise KeyError(obj) from None
            return InletEventsAccessor(asset_uri=asset.uri)
        if isinstance(obj, AssetAlias):
            asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
            return InletEventsAccessor(alias_name=asset_alias.name)
        raise TypeError(f"`key` is of unknown type ({type(key).__name__})")


@attrs.define
class TriggeringAssetEventsAccessor(
    _AssetRefResolutionMixin,
    Mapping["Asset | AssetAlias | AssetRef", Sequence["AssetEventDagRunReferenceResult"]],
    _AssetEventAccessorsMixin[Sequence["AssetEventDagRunReferenceResult"]],
):
    """Lazy mapping of triggering asset events."""

    _events: Mapping[BaseAssetUniqueKey, Sequence[AssetEventDagRunReferenceResult]]

    @classmethod
    def build(cls, events: Iterable[AssetEventDagRunReferenceResult]) -> TriggeringAssetEventsAccessor:
        coll: dict[BaseAssetUniqueKey, list[AssetEventDagRunReferenceResult]] = collections.defaultdict(list)
        for event in events:
            coll[AssetUniqueKey(name=event.asset.name, uri=event.asset.uri)].append(event)
            for alias in event.source_aliases:
                coll[AssetAliasUniqueKey(name=alias.name)].append(event)
        return cls(coll)

    def __str__(self) -> str:
        return f"TriggeringAssetEventsAccessor(_events={self._events})"

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        return (
            key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias()
            for key in self._events
        )

    def __len__(self) -> int:
        return len(self._events)

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEventDagRunReferenceResult]:
        hashable_key: BaseAssetUniqueKey
        if isinstance(key, Asset):
            hashable_key = AssetUniqueKey.from_asset(key)
        elif isinstance(key, AssetRef):
            hashable_key, _ = self._resolve_asset_ref(key)
        elif isinstance(key, AssetAlias):
            hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
        else:
            raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")

        return self._events[hashable_key]


@cache  # Prevent multiple API access.
def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse:
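    """Fetch details of the previous successful DAG run for the given task instance.

    Cached so the supervisor is only queried once per task instance.
    """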
    from airflow.sdk.execution_time import task_runner
    from airflow.sdk.execution_time.comms import (
        GetPrevSuccessfulDagRun,
        PrevSuccessfulDagRunResponse,
        PrevSuccessfulDagRunResult,
    )

    msg = task_runner.SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id))

    if TYPE_CHECKING:
        assert isinstance(msg, PrevSuccessfulDagRunResult)
    return PrevSuccessfulDagRunResponse(**msg.model_dump(exclude={"type"}))


@contextlib.contextmanager
def set_current_context(context: Context) -> Generator[Context, None, None]:
    """
    Set the current execution context to the provided context object.

    This method should be called once per task execution, before calling operator.execute.
    """
    _CURRENT_CONTEXT.append(context)
    try:
        yield context
    finally:
        popped_context = _CURRENT_CONTEXT.pop()
        if popped_context != context:
            log.warning(
                "Current context is not equal to the state at context stack.",
                expected=context,
                got=popped_context,
            )


def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
    """
    Update context after task unmapping.

    Since ``get_template_context()`` is called before unmapping, the context
    contains information about the mapped task. We need to do some in-place
    updates to ensure the template context reflects the unmapped task instead.

    :meta private:
    """
    from airflow.sdk.definitions.param import process_params

    context["task"] = context["ti"].task = task
    context["params"] = process_params(
        context["dag"], task, context["dag_run"].conf, suppress_exception=False
    )


def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool = False) -> dict[str, str]:
    """
    Return values used to externally reconstruct relations between dags, dag_runs, tasks and task_instances.

    Given a context, this function provides a dictionary of values that can be
    used to externally reconstruct relations between dags, dag_runs, tasks, and
    task_instances. Keys default to the ``abc.def.ghi`` format; set
    ``in_env_var_format=True`` to get the ``ABC_DEF_GHI`` format instead.

    :param context: The context for the task_instance of interest.
    :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format.
    :return: task_instance context as dict.
    """
    from datetime import datetime

    from airflow import settings

    params = {}
    if in_env_var_format:
        name_format = "env_var_format"
    else:
        name_format = "default"

    task = context.get("task")
    task_instance = context.get("task_instance")
    dag_run = context.get("dag_run")

    ops = [
        (task, "email", "AIRFLOW_CONTEXT_DAG_EMAIL"),
        (task, "owner", "AIRFLOW_CONTEXT_DAG_OWNER"),
        (task_instance, "dag_id", "AIRFLOW_CONTEXT_DAG_ID"),
        (task_instance, "task_id", "AIRFLOW_CONTEXT_TASK_ID"),
        (dag_run, "logical_date", "AIRFLOW_CONTEXT_LOGICAL_DATE"),
        (task_instance, "try_number", "AIRFLOW_CONTEXT_TRY_NUMBER"),
        (dag_run, "run_id", "AIRFLOW_CONTEXT_DAG_RUN_ID"),
    ]

    context_params = settings.get_airflow_context_vars(context)
    for key_raw, value in context_params.items():
        if not isinstance(key_raw, str):
            raise TypeError(f"key <{key_raw}> must be a string")
        if not isinstance(value, str):
            raise TypeError(f"value of key <{key_raw}> must be a string, not {type(value)}")

        if in_env_var_format and not key_raw.startswith(ENV_VAR_FORMAT_PREFIX):
            key = ENV_VAR_FORMAT_PREFIX + key_raw.upper()
        elif not key_raw.startswith(DEFAULT_FORMAT_PREFIX):
            key = DEFAULT_FORMAT_PREFIX + key_raw
        else:
            key = key_raw
        params[key] = value

    for subject, attr, mapping_key in ops:
        _attr = getattr(subject, attr, None)
        if subject and _attr:
            mapping_value = AIRFLOW_VAR_NAME_FORMAT_MAPPING[mapping_key][name_format]
            if isinstance(_attr, str):
                params[mapping_value] = _attr
            elif isinstance(_attr, datetime):
                params[mapping_value] = _attr.isoformat()
            elif isinstance(_attr, list):
                # OS environment variable values must be strings
                params[mapping_value] = ",".join(_attr)
            else:
                params[mapping_value] = str(_attr)

    return params


def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol:
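    """Return the outlet event accessors from the context, creating and storing them on first access."""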
    try:
        outlet_events = context["outlet_events"]
    except KeyError:
        outlet_events = context["outlet_events"] = OutletEventAccessors()
    return outlet_events