1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19import collections
20import contextlib
21import functools
22import inspect
23from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
24from datetime import datetime
25from functools import cache
26from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
27
28import attrs
29import structlog
30
31from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT
32from airflow.sdk.definitions._internal.types import NOTSET
33from airflow.sdk.definitions.asset import (
34 Asset,
35 AssetAlias,
36 AssetAliasEvent,
37 AssetAliasUniqueKey,
38 AssetNameRef,
39 AssetRef,
40 AssetUniqueKey,
41 AssetUriRef,
42 BaseAssetUniqueKey,
43)
44from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError, ErrorType
45from airflow.sdk.log import mask_secret
46
47if TYPE_CHECKING:
48 from uuid import UUID
49
50 from pydantic.types import JsonValue
51 from typing_extensions import Self
52
53 from airflow.sdk import Variable
54 from airflow.sdk.bases.operator import BaseOperator
55 from airflow.sdk.definitions.connection import Connection
56 from airflow.sdk.definitions.context import Context
57 from airflow.sdk.execution_time.comms import (
58 AssetEventDagRunReferenceResult,
59 AssetEventResult,
60 AssetEventsResult,
61 AssetResult,
62 ConnectionResult,
63 OKResponse,
64 PrevSuccessfulDagRunResponse,
65 ReceiveMsgType,
66 VariableResult,
67 )
68 from airflow.sdk.types import OutletEventAccessorsProtocol
69
70
# Prefixes applied to user/context keys exported by context_to_airflow_vars():
# dotted form for templating, upper-snake form for environment variables.
DEFAULT_FORMAT_PREFIX = "airflow.ctx."
ENV_VAR_FORMAT_PREFIX = "AIRFLOW_CTX_"

# Maps each canonical context-var identifier to both of its rendered
# spellings: "default" (dotted, e.g. airflow.ctx.dag_id) and
# "env_var_format" (e.g. AIRFLOW_CTX_DAG_ID). Consumed by
# context_to_airflow_vars() below.
AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
    "AIRFLOW_CONTEXT_DAG_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_ID",
    },
    "AIRFLOW_CONTEXT_TASK_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}task_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TASK_ID",
    },
    "AIRFLOW_CONTEXT_LOGICAL_DATE": {
        "default": f"{DEFAULT_FORMAT_PREFIX}logical_date",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}LOGICAL_DATE",
    },
    "AIRFLOW_CONTEXT_TRY_NUMBER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}try_number",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TRY_NUMBER",
    },
    "AIRFLOW_CONTEXT_DAG_RUN_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_run_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_RUN_ID",
    },
    "AIRFLOW_CONTEXT_DAG_OWNER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_owner",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_OWNER",
    },
    "AIRFLOW_CONTEXT_DAG_EMAIL": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_email",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_EMAIL",
    },
}


# Structured logger shared by the helpers in this module.
log = structlog.get_logger(logger_name="task")

# Generic payload type for the accessor mixins below.
T = TypeVar("T")
109
110
def _process_connection_result_conn(conn_result: ReceiveMsgType | None) -> Connection:
    """Convert a supervisor response into a ``Connection``.

    :raises AirflowRuntimeError: if the supervisor returned an ``ErrorResponse``.
    """
    from airflow.sdk.definitions.connection import Connection
    from airflow.sdk.execution_time.comms import ErrorResponse

    if isinstance(conn_result, ErrorResponse):
        raise AirflowRuntimeError(conn_result)

    if TYPE_CHECKING:
        assert isinstance(conn_result, ConnectionResult)

    # `by_alias=True` maps the result's `schema` field onto the Connection
    # model's `schema_` attribute.
    payload = conn_result.model_dump(exclude={"type"}, by_alias=True)
    return Connection(**payload)
123
124
def _mask_connection_secrets(conn: Connection) -> None:
    """Register a connection's sensitive fields with the secret masker."""
    # Password first, then extra — skip empty/None values.
    for sensitive in (conn.password, conn.extra):
        if sensitive:
            mask_secret(sensitive)
131
132
def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable:
    """Build a ``Variable`` from a supervisor ``VariableResult``.

    When *deserialize_json* is true, the raw string value is JSON-decoded in
    place on the result before conversion.
    """
    from airflow.sdk.definitions.variable import Variable

    if deserialize_json:
        import json

        var_result.value = json.loads(var_result.value)  # type: ignore
    fields = var_result.model_dump(exclude={"type"})
    return Variable(**fields)
141
142
def _get_connection(conn_id: str) -> Connection:
    """Fetch a connection by id: secret cache first, then each configured backend.

    :raises AirflowNotFoundException: if neither the cache nor any backend
        yields the connection.
    """
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    # Fast path: the cache (optional; only populated on the dag processor).
    try:
        cached_uri = SecretCache.get_connection_uri(conn_id)
        from airflow.sdk.definitions.connection import Connection

        cached_conn = Connection.from_uri(cached_uri, conn_id=conn_id)
        _mask_connection_secrets(cached_conn)
        return cached_conn
    except SecretCache.NotPresentException:
        pass  # fall through to the backends

    # Configured backends may include SupervisorCommsSecretsBackend in worker
    # contexts or MetastoreBackend in API server contexts.
    for backend in ensure_secrets_backend_loaded():
        try:
            found = backend.get_connection(conn_id=conn_id)  # type: ignore[assignment]
            if found:
                SecretCache.save_connection_uri(conn_id, found.get_uri())
                _mask_connection_secrets(found)
                return found
        except Exception:
            # A failing backend is not fatal; move on to the next one.
            log.debug(
                "Unable to retrieve connection from secrets backend (%s). "
                "Checking subsequent secrets backend.",
                type(backend).__name__,
            )

    raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")
178
179
async def _async_get_connection(conn_id: str) -> Connection:
    """Async variant of :func:`_get_connection`.

    Checks the secret cache first, then each configured secrets backend,
    preferring a backend's native ``aget_connection`` coroutine when one
    exists and falling back to running the sync ``get_connection`` in a
    thread via ``sync_to_async``.

    :raises AirflowNotFoundException: if no cache entry or backend yields
        the connection.
    """
    from asgiref.sync import sync_to_async

    from airflow.sdk.execution_time.cache import SecretCache

    # Check cache first
    try:
        uri = SecretCache.get_connection_uri(conn_id)
        from airflow.sdk.definitions.connection import Connection

        conn = Connection.from_uri(uri, conn_id=conn_id)
        _mask_connection_secrets(conn)
        return conn
    except SecretCache.NotPresentException:
        pass  # continue to backends

    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    # Try secrets backends
    backends = ensure_secrets_backend_loaded()
    for secrets_backend in backends:
        try:
            # Use async method if available, otherwise wrap sync method
            # getattr avoids triggering AsyncMock coroutine creation under Python 3.13
            async_method = getattr(secrets_backend, "aget_connection", None)
            if async_method is not None:
                # The attribute may be a plain callable in tests, so only
                # await the result when it is actually awaitable.
                maybe_awaitable = async_method(conn_id)
                conn = await maybe_awaitable if inspect.isawaitable(maybe_awaitable) else maybe_awaitable
            else:
                conn = await sync_to_async(secrets_backend.get_connection)(conn_id)  # type: ignore[assignment]

            if conn:
                # Cache the URI for subsequent lookups, then mask secrets.
                SecretCache.save_connection_uri(conn_id, conn.get_uri())
                _mask_connection_secrets(conn)
                return conn
        except Exception:
            # If one backend fails, try the next one
            log.debug(
                "Unable to retrieve connection from secrets backend (%s). "
                "Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # If no backend found the connection, raise an error

    raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")
226
227
def _get_variable(key: str, deserialize_json: bool) -> Any:
    """Fetch a variable value by key: secret cache first, then each backend.

    :param key: variable key to look up.
    :param deserialize_json: if True, JSON-decode the stored string value.
    :raises AirflowRuntimeError: with ``VARIABLE_NOT_FOUND`` when no cache
        entry or backend yields the variable.
    """
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

    def _postprocess(raw: Any) -> Any:
        # Shared tail for the cache and backend paths (previously duplicated):
        # optionally decode JSON, and register string values with the masker.
        if deserialize_json:
            import json

            raw = json.loads(raw)
        if isinstance(raw, str):
            mask_secret(raw, key)
        return raw

    # Check cache first
    try:
        var_val = SecretCache.get_variable(key)
        if var_val is not None:
            return _postprocess(var_val)
    except SecretCache.NotPresentException:
        pass  # Continue to check backends

    # Iterate over backends if not in cache (or expired)
    for secrets_backend in ensure_secrets_backend_loaded():
        try:
            var_val = secrets_backend.get_variable(key=key)
            if var_val is not None:
                # Save raw value before deserialization to maintain cache consistency
                SecretCache.save_variable(key, var_val)
                return _postprocess(var_val)
        except Exception:
            log.exception(
                "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # If no backend found the variable, raise a not found error (mirrors _get_connection).
    # AirflowRuntimeError and ErrorType come from the module-level import.
    from airflow.sdk.execution_time.comms import ErrorResponse

    raise AirflowRuntimeError(
        ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"message": f"Variable {key} not found"})
    )
275
276
def _set_variable(key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
    """Write a variable through the supervisor, warning on backend shadowing.

    Before writing, every non-API secrets backend is probed for the same key;
    if one defines it, a warning explains that reads will keep hitting the
    backend even though the API Server value gets updated.
    """
    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
    # or `airflow.sdk.execution_time.variable`
    # A reason to not move it to `airflow.sdk.execution_time.comms` is that it
    # will make that module depend on Task SDK, which is not ideal because we intend to
    # keep Task SDK as a separate package than execution time mods.
    import json

    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.comms import PutVariable
    from airflow.sdk.execution_time.secrets.execution_api import ExecutionAPISecretsBackend
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    # check for write conflicts on the worker
    for secrets_backend in ensure_secrets_backend_loaded():
        # The API backend is the write target itself, so it cannot "shadow".
        if isinstance(secrets_backend, ExecutionAPISecretsBackend):
            continue
        try:
            var_val = secrets_backend.get_variable(key=key)
            if var_val is not None:
                _backend_name = type(secrets_backend).__name__
                log.warning(
                    "The variable %s is defined in the %s secrets backend, which takes "
                    "precedence over reading from the API Server. The value from the API Server will be "
                    "updated, but to read it you have to delete the conflicting variable "
                    "from %s",
                    key,
                    _backend_name,
                    _backend_name,
                )
        except Exception:
            log.exception(
                "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
                type(secrets_backend).__name__,
            )

    # NOTE(review): a failing json.dumps is logged but the *unserialized*
    # value is still sent below — confirm this best-effort behavior is intended.
    try:
        if serialize_json:
            value = json.dumps(value, indent=2)
    except Exception as e:
        log.exception(e)

    SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description))

    # Invalidate cache after setting the variable
    SecretCache.invalidate_variable(key)
324
325
def _delete_variable(key: str) -> None:
    """Ask the supervisor to delete a variable and drop it from the local cache."""
    # TODO: Consider relocating this helper (e.g. to an
    # `airflow.sdk.execution_time.variable` module); it lives here so that
    # `comms` does not grow a dependency on the Task SDK, which we intend to
    # keep as a separate package from execution-time modules.
    from airflow.sdk.execution_time.cache import SecretCache
    from airflow.sdk.execution_time.comms import DeleteVariable
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    response = SUPERVISOR_COMMS.send(DeleteVariable(key=key))
    if TYPE_CHECKING:
        assert isinstance(response, OKResponse)

    # The supervisor is the source of truth; invalidate our cached copy.
    SecretCache.invalidate_variable(key)
342
343
class ConnectionAccessor:
    """Dynamic template accessor for Connection entries (``conn.<conn_id>``)."""

    def __getattr__(self, conn_id: str) -> Any:
        # Attribute access resolves the name as a connection id.
        from airflow.sdk.definitions.connection import Connection

        return Connection.get(conn_id)

    def __repr__(self) -> str:
        return "<ConnectionAccessor (dynamic access)>"

    def __eq__(self, other):
        # The accessor is stateless, so any two instances are interchangeable.
        return isinstance(other, ConnectionAccessor)

    def __hash__(self):
        return hash(self.__class__.__name__)

    def get(self, conn_id: str, default_conn: Any = None) -> Any:
        """Return the connection for *conn_id*, or *default_conn* if undefined."""
        try:
            return _get_connection(conn_id)
        except AirflowRuntimeError as e:
            if e.error.error == ErrorType.CONNECTION_NOT_FOUND:
                return default_conn
            raise
        except AirflowNotFoundException:
            return default_conn
373
374
class VariableAccessor:
    """Dynamic template accessor for Variable values (``var.value.<key>``)."""

    def __init__(self, deserialize_json: bool) -> None:
        # Whether looked-up values should be JSON-decoded.
        self._deserialize_json = deserialize_json

    def __eq__(self, other):
        # Accessors carry no distinguishing state; type alone decides equality.
        return isinstance(other, VariableAccessor)

    def __hash__(self):
        return hash(self.__class__.__name__)

    def __repr__(self) -> str:
        return "<VariableAccessor (dynamic access)>"

    def __getattr__(self, key: str) -> Any:
        # Attribute access resolves the name as a variable key.
        return _get_variable(key, self._deserialize_json)

    def get(self, key, default: Any = NOTSET) -> Any:
        """Return the variable value for *key*, or *default* when not found."""
        try:
            return _get_variable(key, self._deserialize_json)
        except AirflowRuntimeError as e:
            if e.error.error == ErrorType.VARIABLE_NOT_FOUND:
                return default
            raise
403
404
class MacrosAccessor:
    """Proxy that lazily imports the macros module on first attribute access."""

    # Cached module object; None until the first lookup triggers the import.
    _macros_module = None

    def __getattr__(self, item: str) -> Any:
        if not self._macros_module:
            import airflow.sdk.execution_time.macros

            # Assigning here creates an instance attribute that shadows the
            # class-level None, so the import happens at most once per accessor.
            self._macros_module = airflow.sdk.execution_time.macros
        return getattr(self._macros_module, item)

    def __repr__(self) -> str:
        return "<MacrosAccessor (dynamic access to macros)>"

    def __eq__(self, other: object) -> bool:
        # Stateless accessor: all instances are equivalent.
        return isinstance(other, MacrosAccessor)

    def __hash__(self):
        return hash(self.__class__.__name__)
428
429
class _AssetRefResolutionMixin:
    """Mixin resolving ``AssetRef``s (name/URI refs) to concrete asset keys."""

    # Memoized resolutions, keyed by ref. NOTE(review): this is a *class*
    # attribute, shared by every instance and subclass for the lifetime of the
    # process, and never invalidated — confirm that is intended.
    _asset_ref_cache: dict[AssetRef, tuple[AssetUniqueKey, dict[str, JsonValue]]] = {}

    def _resolve_asset_ref(self, ref: AssetRef) -> tuple[AssetUniqueKey, dict[str, JsonValue]]:
        """Return ``(unique_key, extra)`` for *ref*, querying the DB on a cache miss."""
        with contextlib.suppress(KeyError):
            return self._asset_ref_cache[ref]

        refs_to_cache: list[AssetRef]
        if isinstance(ref, AssetNameRef):
            asset = self._get_asset_from_db(name=ref.name)
            # Cache under both the name ref and the equivalent URI ref.
            refs_to_cache = [ref, AssetUriRef(asset.uri)]
        elif isinstance(ref, AssetUriRef):
            asset = self._get_asset_from_db(uri=ref.uri)
            refs_to_cache = [ref, AssetNameRef(asset.name)]
        else:
            raise TypeError(f"Unimplemented asset ref: {type(ref)}")
        unique_key = AssetUniqueKey.from_asset(asset)
        for ref in refs_to_cache:
            self._asset_ref_cache[ref] = (unique_key, asset.extra)
        return (unique_key, asset.extra)

    # TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py
    @staticmethod
    def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset:
        """Ask the supervisor for the asset identified by *name* or *uri*.

        :raises ValueError: if neither identifier is given.
        :raises AirflowRuntimeError: if the supervisor reports an error.
        """
        from airflow.sdk.definitions.asset import Asset
        from airflow.sdk.execution_time.comms import (
            ErrorResponse,
            GetAssetByName,
            GetAssetByUri,
            ToSupervisor,
        )
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        msg: ToSupervisor
        if name:
            msg = GetAssetByName(name=name)
        elif uri:
            msg = GetAssetByUri(uri=uri)
        else:
            raise ValueError("Either name or uri must be provided")

        resp = SUPERVISOR_COMMS.send(msg)
        if isinstance(resp, ErrorResponse):
            raise AirflowRuntimeError(resp)

        if TYPE_CHECKING:
            assert isinstance(resp, AssetResult)
        return Asset(**resp.model_dump(exclude={"type"}))
478
479
@attrs.define
class OutletEventAccessor(_AssetRefResolutionMixin):
    """Wrapper to access an outlet asset event in template."""

    # Unique key of the asset or asset alias this accessor wraps.
    key: BaseAssetUniqueKey
    # Extra payload attached to the event.
    extra: dict[str, JsonValue] = attrs.Factory(dict)
    # Alias events recorded via add(); only populated for alias-keyed accessors.
    asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)

    def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None) -> None:
        """Add an AssetEvent to an existing Asset."""
        # Only accessors keyed by an asset *alias* collect events; plain
        # asset-keyed accessors silently ignore the call.
        if not isinstance(self.key, AssetAliasUniqueKey):
            return

        if isinstance(asset, AssetRef):
            # Resolve name/URI refs to a concrete key + extra via the DB.
            asset_key, asset_extra = self._resolve_asset_ref(asset)
        else:
            asset_key = AssetUniqueKey.from_asset(asset)
            asset_extra = asset.extra

        asset_alias_name = self.key.name
        event = AssetAliasEvent(
            source_alias_name=asset_alias_name,
            dest_asset_key=asset_key,
            dest_asset_extra=asset_extra,
            extra=extra or {},
        )
        self.asset_alias_events.append(event)
507
508
class _AssetEventAccessorsMixin(Generic[T]):
    """Convenience lookups shared by the event-accessor mappings."""

    @overload
    def for_asset(self, *, name: str, uri: str) -> T: ...

    @overload
    def for_asset(self, *, name: str) -> T: ...

    @overload
    def for_asset(self, *, uri: str) -> T: ...

    def for_asset(self, *, name: str | None = None, uri: str | None = None) -> T:
        """Look up events by asset name and/or URI; at least one must be truthy."""
        if name and uri:
            lookup_key: Asset | AssetRef = Asset(name=name, uri=uri)
        elif name:
            lookup_key = Asset.ref(name=name)
        elif uri:
            lookup_key = Asset.ref(uri=uri)
        else:
            raise ValueError("name and uri cannot both be None")
        return self[lookup_key]

    def for_asset_alias(self, *, name: str) -> T:
        """Look up events recorded for the named asset alias."""
        return self[AssetAlias(name=name)]

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> T:
        # Concrete subclasses define the actual lookup.
        raise NotImplementedError
534
535
class OutletEventAccessors(
    _AssetRefResolutionMixin,
    Mapping["Asset | AssetAlias", OutletEventAccessor],
    _AssetEventAccessorsMixin[OutletEventAccessor],
):
    """Lazy mapping of outlet asset event accessors.

    Accessors are created on first ``[]`` access and reused afterwards.
    """

    def __init__(self) -> None:
        # Accessors keyed by the hashable unique key of the asset/alias.
        self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}

    def __str__(self) -> str:
        return f"OutletEventAccessors(_dict={self._dict})"

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        # Surface internal keys as user-facing Asset / AssetAlias objects.
        return (
            key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias() for key in self._dict
        )

    def __len__(self) -> int:
        return len(self._dict)

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> OutletEventAccessor:
        """Return (creating on first access) the accessor for *key*."""
        hashable_key: BaseAssetUniqueKey
        if isinstance(key, Asset):
            hashable_key = AssetUniqueKey.from_asset(key)
        elif isinstance(key, AssetAlias):
            hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
        elif isinstance(key, AssetRef):
            # Name/URI refs are resolved (and memoized) against the database.
            hashable_key, _ = self._resolve_asset_ref(key)
        else:
            raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")

        if hashable_key not in self._dict:
            # Lazily create the accessor the first time the key is seen.
            self._dict[hashable_key] = OutletEventAccessor(extra={}, key=hashable_key)
        return self._dict[hashable_key]
571
572
@attrs.define(init=False)
class InletEventsAccessor(Sequence["AssetEventResult"]):
    """Lazy, fluent query over asset events for a single inlet.

    Filters (``after``/``before``/``ascending``/``limit``) can be chained;
    the supervisor is only contacted when the events are first iterated,
    indexed, or measured, and the response is cached until a filter changes.
    """

    # Query filters; None means "no constraint".
    _after: str | datetime | None
    _before: str | datetime | None
    _ascending: bool
    _limit: int | None
    # Identity of the inlet: either asset name/uri, or an alias name.
    _asset_name: str | None
    _asset_uri: str | None
    _alias_name: str | None

    def __init__(
        self, asset_name: str | None = None, asset_uri: str | None = None, alias_name: str | None = None
    ):
        self._asset_name = asset_name
        self._asset_uri = asset_uri
        self._alias_name = alias_name
        self._after = None
        self._before = None
        self._ascending = True
        self._limit = None

    def after(self, after: str) -> Self:
        """Set the ``after`` filter and invalidate cached results."""
        self._after = after
        self._reset_cache()
        return self

    def before(self, before: str) -> Self:
        """Set the ``before`` filter and invalidate cached results."""
        self._before = before
        self._reset_cache()
        return self

    def ascending(self, ascending: bool = True) -> Self:
        """Set the sort direction and invalidate cached results."""
        self._ascending = ascending
        self._reset_cache()
        return self

    def limit(self, limit: int) -> Self:
        """Set the ``limit`` filter and invalidate cached results."""
        self._limit = limit
        self._reset_cache()
        return self

    @functools.cached_property
    def _asset_events(self) -> list[AssetEventResult]:
        """Fetch the matching events from the supervisor (computed once).

        :raises ValueError: if no asset name/uri nor alias name was given.
        :raises AirflowRuntimeError: if the supervisor reports an error.
        """
        from airflow.sdk.execution_time.comms import (
            ErrorResponse,
            GetAssetEventByAsset,
            GetAssetEventByAssetAlias,
            ToSupervisor,
        )
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        query_dict: dict[str, Any] = {
            "after": self._after,
            "before": self._before,
            "ascending": self._ascending,
            "limit": self._limit,
        }

        msg: ToSupervisor
        if self._alias_name is not None:
            msg = GetAssetEventByAssetAlias(alias_name=self._alias_name, **query_dict)
        else:
            if self._asset_name is None and self._asset_uri is None:
                raise ValueError("Either asset_name or asset_uri must be provided")
            msg = GetAssetEventByAsset(name=self._asset_name, uri=self._asset_uri, **query_dict)
        resp = SUPERVISOR_COMMS.send(msg)
        if isinstance(resp, ErrorResponse):
            raise AirflowRuntimeError(resp)

        if TYPE_CHECKING:
            assert isinstance(resp, AssetEventsResult)

        return list(resp.iter_asset_event_results())

    def _reset_cache(self) -> None:
        # Deleting the cached_property value forces a re-query on next access.
        try:
            del self._asset_events
        except AttributeError:
            pass  # nothing fetched yet

    def __iter__(self) -> Iterator[AssetEventResult]:
        return iter(self._asset_events)

    def __len__(self) -> int:
        return len(self._asset_events)

    @overload
    def __getitem__(self, key: int) -> AssetEventResult: ...

    @overload
    def __getitem__(self, key: slice) -> Sequence[AssetEventResult]: ...

    def __getitem__(self, key: int | slice) -> AssetEventResult | Sequence[AssetEventResult]:
        return self._asset_events[key]
667
668
@attrs.define(init=False)
class InletEventsAccessors(
    Mapping["int | Asset | AssetAlias | AssetRef", Any],
    _AssetEventAccessorsMixin[Sequence["AssetEventResult"]],
):
    """Lazy mapping of inlet asset event accessors.

    Keys may be an inlet index, an :class:`Asset`, an :class:`AssetAlias`,
    or an :class:`AssetRef`; values are :class:`InletEventsAccessor` query
    objects over the corresponding events.
    """

    # The task's inlets, in declaration order (supports index access).
    _inlets: list[Any]
    # Concrete assets by unique key (refs are resolved up front in __init__).
    _assets: dict[AssetUniqueKey, Asset]
    # Asset aliases by unique key.
    _asset_aliases: dict[AssetAliasUniqueKey, AssetAlias]

    def __init__(self, inlets: list) -> None:
        self._inlets = inlets
        self._assets = {}
        self._asset_aliases = {}

        for inlet in inlets:
            if isinstance(inlet, Asset):
                self._assets[AssetUniqueKey.from_asset(inlet)] = inlet
            elif isinstance(inlet, AssetAlias):
                self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(inlet)] = inlet
            elif isinstance(inlet, AssetNameRef):
                # Refs are resolved eagerly so later lookups are local.
                asset = OutletEventAccessors._get_asset_from_db(name=inlet.name)
                self._assets[AssetUniqueKey.from_asset(asset)] = asset
            elif isinstance(inlet, AssetUriRef):
                asset = OutletEventAccessors._get_asset_from_db(uri=inlet.uri)
                self._assets[AssetUniqueKey.from_asset(asset)] = asset

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        return iter(self._inlets)

    def __len__(self) -> int:
        return len(self._inlets)

    def __getitem__(
        self,
        key: int | Asset | AssetAlias | AssetRef,
    ) -> InletEventsAccessor:
        """Return a query accessor for the inlet identified by *key*.

        :raises IndexError: if an integer key points at a non-asset inlet.
        :raises KeyError: if a ref does not match any known asset.
        :raises TypeError: for unsupported key types.
        """
        from airflow.sdk.definitions.asset import Asset

        if isinstance(key, int):  # Support index access; it's easier for trivial cases.
            obj = self._inlets[key]
            if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
                raise IndexError(key)
        else:
            obj = key

        if isinstance(obj, Asset):
            asset = self._assets[AssetUniqueKey.from_asset(obj)]
            return InletEventsAccessor(asset_name=asset.name, asset_uri=asset.uri)
        if isinstance(obj, AssetNameRef):
            # Match the ref against resolved assets by name.
            try:
                asset = next(a for k, a in self._assets.items() if k.name == obj.name)
            except StopIteration:
                raise KeyError(obj) from None
            return InletEventsAccessor(asset_name=asset.name)
        if isinstance(obj, AssetUriRef):
            # Match the ref against resolved assets by URI.
            try:
                asset = next(a for k, a in self._assets.items() if k.uri == obj.uri)
            except StopIteration:
                raise KeyError(obj) from None
            return InletEventsAccessor(asset_uri=asset.uri)
        if isinstance(obj, AssetAlias):
            asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
            return InletEventsAccessor(alias_name=asset_alias.name)
        raise TypeError(f"`key` is of unknown type ({type(key).__name__})")
735
736
@attrs.define
class TriggeringAssetEventsAccessor(
    _AssetRefResolutionMixin,
    Mapping["Asset | AssetAlias | AssetRef", Sequence["AssetEventDagRunReferenceResult"]],
    _AssetEventAccessorsMixin[Sequence["AssetEventDagRunReferenceResult"]],
):
    """Lazy mapping of triggering asset events."""

    # Events grouped by the unique key of the asset (or source alias) that
    # triggered them; built once via build().
    _events: Mapping[BaseAssetUniqueKey, Sequence[AssetEventDagRunReferenceResult]]

    @classmethod
    def build(cls, events: Iterable[AssetEventDagRunReferenceResult]) -> TriggeringAssetEventsAccessor:
        """Group *events* by asset key and by each source alias key."""
        coll: dict[BaseAssetUniqueKey, list[AssetEventDagRunReferenceResult]] = collections.defaultdict(list)
        for event in events:
            coll[AssetUniqueKey(name=event.asset.name, uri=event.asset.uri)].append(event)
            # Each event is also reachable through every alias it came from.
            for alias in event.source_aliases:
                coll[AssetAliasUniqueKey(name=alias.name)].append(event)
        return cls(coll)

    def __str__(self) -> str:
        # Fixed: previously printed the stale name "TriggeringAssetEventAccessor".
        return f"TriggeringAssetEventsAccessor(_events={self._events})"

    def __iter__(self) -> Iterator[Asset | AssetAlias]:
        # Surface internal keys as user-facing Asset / AssetAlias objects.
        return (
            key.to_asset() if isinstance(key, AssetUniqueKey) else key.to_asset_alias()
            for key in self._events
        )

    def __len__(self) -> int:
        return len(self._events)

    def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEventDagRunReferenceResult]:
        """Return the triggering events recorded under *key*.

        :raises KeyError: if no events were recorded for the key.
        :raises TypeError: for unsupported key types.
        """
        hashable_key: BaseAssetUniqueKey
        if isinstance(key, Asset):
            hashable_key = AssetUniqueKey.from_asset(key)
        elif isinstance(key, AssetRef):
            # Name/URI refs are resolved (and memoized) against the database.
            hashable_key, _ = self._resolve_asset_ref(key)
        elif isinstance(key, AssetAlias):
            hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
        else:
            raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")

        return self._events[hashable_key]
780
781
@cache  # Prevent multiple API access.
def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse:
    """Fetch the previous successful dag run for the given task-instance id.

    The result is memoized per ``ti_id`` (see ``@cache`` above) so repeated
    template lookups do not re-query the supervisor. NOTE(review): the cache
    is unbounded — fine for a short-lived task process; confirm this module
    is not reused in long-lived workers.
    """
    from airflow.sdk.execution_time import task_runner
    from airflow.sdk.execution_time.comms import (
        GetPrevSuccessfulDagRun,
        PrevSuccessfulDagRunResponse,
        PrevSuccessfulDagRunResult,
    )

    msg = task_runner.SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id))

    if TYPE_CHECKING:
        assert isinstance(msg, PrevSuccessfulDagRunResult)
    # Re-shape the result message into the public response model.
    return PrevSuccessfulDagRunResponse(**msg.model_dump(exclude={"type"}))
796
797
@contextlib.contextmanager
def set_current_context(context: Context) -> Generator[Context, None, None]:
    """
    Set the current execution context to the provided context object.

    This method should be called once per Task execution, before calling operator.execute.
    """
    _CURRENT_CONTEXT.append(context)
    try:
        yield context
    finally:
        # Pop unconditionally so the stack never leaks an entry, then sanity
        # check that we removed what we pushed.
        popped = _CURRENT_CONTEXT.pop()
        if popped != context:
            log.warning(
                "Current context is not equal to the state at context stack.",
                expected=context,
                got=popped,
            )
816
817
def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
    """
    Update context after task unmapping.

    Since ``get_template_context()`` is called before unmapping, the context
    contains information about the mapped task. We need to do some in-place
    updates to ensure the template context reflects the unmapped task instead.

    :meta private:
    """
    from airflow.sdk.definitions.param import process_params

    # Point both the context and the task instance at the unmapped task.
    context["task"] = context["ti"].task = task
    # Re-resolve params against the unmapped task; failures propagate.
    context["params"] = process_params(
        context["dag"], task, context["dag_run"].conf, suppress_exception=False
    )
834
835
def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool = False) -> dict[str, str]:
    """
    Return values used to externally reconstruct relations between dags, dag_runs, tasks and task_instances.

    Given a context, this function provides a dictionary of values that can be used to
    externally reconstruct relations between dags, dag_runs, tasks and task_instances.
    Default to abc.def.ghi format and can be made to ABC_DEF_GHI format if
    in_env_var_format is set to True.

    :param context: The context for the task_instance of interest.
    :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format.
    :return: task_instance context as dict.
    :raises TypeError: if a user-supplied context var key or value is not a string.
    """
    # NOTE: `datetime` comes from the module-level import; the previous
    # function-local re-import was redundant and has been removed.
    from airflow import settings

    params: dict[str, str] = {}
    name_format = "env_var_format" if in_env_var_format else "default"

    task = context.get("task")
    task_instance = context.get("task_instance")
    dag_run = context.get("dag_run")

    # (source object, attribute name, mapping key) triples exported when the
    # source is present and the attribute is truthy.
    ops = [
        (task, "email", "AIRFLOW_CONTEXT_DAG_EMAIL"),
        (task, "owner", "AIRFLOW_CONTEXT_DAG_OWNER"),
        (task_instance, "dag_id", "AIRFLOW_CONTEXT_DAG_ID"),
        (task_instance, "task_id", "AIRFLOW_CONTEXT_TASK_ID"),
        (dag_run, "logical_date", "AIRFLOW_CONTEXT_LOGICAL_DATE"),
        (task_instance, "try_number", "AIRFLOW_CONTEXT_TRY_NUMBER"),
        (dag_run, "run_id", "AIRFLOW_CONTEXT_DAG_RUN_ID"),
    ]

    # User-supplied context vars from settings, validated and prefixed.
    context_params = settings.get_airflow_context_vars(context)
    for key_raw, value in context_params.items():
        if not isinstance(key_raw, str):
            raise TypeError(f"key <{key_raw}> must be string")
        if not isinstance(value, str):
            raise TypeError(f"value of key <{key_raw}> must be string, not {type(value)}")

        # NOTE(review): in env-var format, a key already carrying the env
        # prefix falls through to the dotted-prefix branch — confirm that
        # asymmetry is intended before changing it.
        if in_env_var_format and not key_raw.startswith(ENV_VAR_FORMAT_PREFIX):
            key = ENV_VAR_FORMAT_PREFIX + key_raw.upper()
        elif not key_raw.startswith(DEFAULT_FORMAT_PREFIX):
            key = DEFAULT_FORMAT_PREFIX + key_raw
        else:
            key = key_raw
        params[key] = value

    for subject, attr, mapping_key in ops:
        _attr = getattr(subject, attr, None)
        if subject and _attr:
            mapping_value = AIRFLOW_VAR_NAME_FORMAT_MAPPING[mapping_key][name_format]
            if isinstance(_attr, str):
                params[mapping_value] = _attr
            elif isinstance(_attr, datetime):
                params[mapping_value] = _attr.isoformat()
            elif isinstance(_attr, list):
                # os env variable value needs to be string
                params[mapping_value] = ",".join(_attr)
            else:
                params[mapping_value] = str(_attr)

    return params
903
904
def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol:
    """Return the context's outlet-events accessor, creating and storing one if absent."""
    if "outlet_events" not in context:
        # First access for this context: install a fresh accessor mapping.
        context["outlet_events"] = OutletEventAccessors()
    return context["outlet_events"]