1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20from collections.abc import Iterable, Mapping, Sequence, Sized
21from typing import TYPE_CHECKING, Any, ClassVar, Union
22
23import attrs
24
25from airflow.sdk.definitions._internal.mixins import ResolveMixin
26
27if TYPE_CHECKING:
28 from typing import TypeGuard
29
30 from airflow.sdk.definitions.xcom_arg import XComArg
31 from airflow.sdk.types import Operator
32
# Either of the two storage types for a mapped operator's mapped kwargs:
# DictOfListsExpandInput (from expand()) or ListOfDictsExpandInput (from expand_kwargs()).
ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

# Each keyword argument to expand() can be a MappedArgument, an XComArg, a
# sequence, or a dict (not any mapping, since we need the value to be ordered).
OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]]

# The single argument of expand_kwargs() can be an XComArg, or a list with each
# element being either an XComArg or a dict.
OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]]
42
43
44class _NotFullyPopulated(RuntimeError):
45 """
46 Raise when an expand input cannot be resolved due to incomplete metadata.
47
48 This generally should not happen. The scheduler should have made sure that
49 a not-yet-ready-to-expand task should not be executed. In the off chance
50 this gets raised, it will fail the task instance.
51 """
52
53 def __init__(self, missing: set[str]) -> None:
54 self.missing = missing
55
56 def __str__(self) -> str:
57 keys = ", ".join(repr(k) for k in sorted(self.missing))
58 return f"Failed to populate all mapping metadata; missing: {keys}"
59
60
def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
    """
    Return whether *v* can be mapped over.

    Accepted values are a MappedArgument, an XComArg, any Mapping, or any
    Sequence other than ``str`` (a string is a Sequence but is never mapped
    character by character). Replaces tedious ``isinstance()`` checks at call sites.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(v, str):
        return False
    return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence))
66
67
def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
    """
    Return whether *v* is a literal whose value is already known at parse time.

    Anything that is not a MappedArgument or an XComArg counts as a literal.
    Replaces tedious ``isinstance()`` checks at call sites.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    resolved_later = isinstance(v, (MappedArgument, XComArg))
    return not resolved_later
73
74
def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
    """
    Return whether *v* must be resolved at run time.

    This is true for MappedArgument and XComArg instances, whose values only
    exist once the task runs. Replaces tedious ``isinstance()`` checks at call sites.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    run_time_types = (MappedArgument, XComArg)
    return isinstance(v, run_time_types)
80
81
@attrs.define(kw_only=True)
class MappedArgument(ResolveMixin):
    """
    Stand-in placeholder for an argument of a mapped task group.

    It plays a role similar to an XComArg but is resolved differently: its
    value is looked up by ``key`` in the data produced by resolving the
    group's expand input. It is declared in this module (instead of in the
    task group module) to avoid import cycles.
    """

    # Expand input of the mapped task group this placeholder belongs to.
    _input: ExpandInput = attrs.field()
    # Name of the expand keyword this placeholder stands for.
    _key: str

    @_input.validator
    def _validate_input(self, _attribute, expand_input):
        # Only dict-of-lists inputs are inspected; a MappedArgument among their
        # values means a mapped task group is fed by another mapped task group.
        if not isinstance(expand_input, DictOfListsExpandInput):
            return
        if any(isinstance(arg, MappedArgument) for arg in expand_input.value.values()):
            raise ValueError("Nested Mapped TaskGroups are not yet supported")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the upstream references of the underlying expand input."""
        yield from self._input.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> Any:
        """Resolve the group's expand input and return the value for this placeholder's key."""
        resolved_kwargs, _resolved_oids = self._input.resolve(context)
        return resolved_kwargs[self._key]
107
108
@attrs.define()
class DictOfListsExpandInput(ResolveMixin):
    """
    Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand(**kwargs)``. ``value`` maps each keyword
    argument name to what it is expanded over: a literal sequence or dict, an
    XComArg, or a MappedArgument. At run time the task instance's ``map_index``
    picks one element per argument, treating the arguments as the axes of a
    cross product in which the last argument varies fastest.
    """

    # Keyword argument name -> value to expand over. Insertion order matters:
    # it defines how a flat map index is decomposed into per-argument indexes.
    value: dict[str, OperatorExpandArgument]

    EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"

    def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
        """Generate kwargs with values available on parse-time (i.e. not MappedArgument/XComArg)."""
        return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v))

    def _get_map_lengths(
        self, resolved_vals: dict[str, Sized], upstream_map_indexes: dict[str, int]
    ) -> dict[str, int]:
        """
        Return dict of argument name to map length.

        A literal argument's length is ``len()`` of the literal; an XComArg's
        length comes from ``get_task_map_length``. If the length of any
        argument cannot be determined right now (e.g. the upstream task has not
        finished), no partial dict is returned; ``_NotFullyPopulated`` is raised
        instead, naming the missing arguments.

        :param resolved_vals: Run-time resolved argument values keyed by
            argument name; ``resolve()`` passes only the values that are ``Sized``.
        :param upstream_map_indexes: Passed through to ``get_task_map_length``.
        :raises _NotFullyPopulated: if any argument's value is None or its
            length lookup returned None.
        """

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?
        def _get_length(k: str, v: OperatorExpandArgument) -> int | None:
            from airflow.sdk.definitions.xcom_arg import XComArg, get_task_map_length

            if isinstance(v, XComArg):
                # NOTE(review): resolve() drops non-Sized resolved values from
                # resolved_vals, so an XComArg that resolved to e.g. None or an
                # int raises KeyError here rather than a descriptive error -- confirm.
                return get_task_map_length(v, resolved_vals[k], upstream_map_indexes)

            # Unfortunately a user-defined TypeGuard cannot apply negative type
            # narrowing. https://github.com/python/typing/discussions/1013
            # NOTE(review): a MappedArgument value is not special-cased and would
            # reach len() on the placeholder itself; presumably such values never
            # get here -- verify.
            if TYPE_CHECKING:
                assert isinstance(v, Sized)
            return len(v)

        # Keep only arguments whose length is known; any gap triggers the raise below.
        map_lengths = {
            k: res for k, v in self.value.items() if v is not None if (res := _get_length(k, v)) is not None
        }
        if len(map_lengths) < len(self.value):
            raise _NotFullyPopulated(set(self.value).difference(map_lengths))
        return map_lengths

    def _expand_mapped_field(self, key: str, value: Any, map_index: int, all_lengths: dict[str, int]) -> Any:
        """
        Select the element of ``value`` that argument ``key`` takes at ``map_index``.

        :param key: Name of the argument being expanded.
        :param value: Resolved value of that argument.
        :param map_index: The task instance's flat map index.
        :param all_lengths: Map length of every argument in ``self.value``.
        :return: ``value[i]`` for a sequence, or the ``i``-th ``(key, value)``
            item for a dict, where ``i`` is this argument's component of
            ``map_index``. If ``key`` is not in ``self.value``, ``value`` is
            returned unchanged.
        :raises RuntimeError: if an argument reached while decomposing the
            index has a map length below 1.
        :raises TypeError: if ``value`` is neither a Sequence nor a dict.
        :raises IndexError: if a dict ``value`` has no item at the computed index.
        """

        def _find_index_for_this_field(index: int) -> int:
            # Need to use the original user input to retain argument order.
            # ``index`` is decomposed as a mixed-radix number whose digits are the
            # per-argument indexes; the last argument is the least significant
            # digit, hence the reversed walk.
            for mapped_key in reversed(self.value):
                mapped_length = all_lengths[mapped_key]
                if mapped_length < 1:
                    raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
                if mapped_key == key:
                    return index % mapped_length
                index //= mapped_length
            # ``key`` is not one of the expanded arguments.
            return -1

        found_index = _find_index_for_this_field(map_index)
        if found_index < 0:
            return value
        if isinstance(value, Sequence):
            return value[found_index]
        if not isinstance(value, dict):
            raise TypeError(f"can't map over value of type {type(value)}")
        # Dicts are expanded over their (key, value) items in insertion order.
        for i, (k, v) in enumerate(value.items()):
            if i == found_index:
                return k, v
        raise IndexError(f"index {map_index} is over mapped length")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the upstream references of every XComArg argument."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for x in self.value.values():
            if isinstance(x, XComArg):
                yield from x.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Resolve the kwargs for the task instance in ``context``.

        :param context: Run-time context; ``context["ti"].map_index`` selects
            the element of each argument.
        :return: A tuple of (argument name -> selected value, and the ``id()``s
            of the selected values that did not come from parse-time literals).
        :raises RuntimeError: if the task instance's map index is None or negative.
        :raises _NotFullyPopulated: if some argument's map length is unknown.
        """
        map_index: int | None = context["ti"].map_index
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        # Get pre-computed upstream_map_indexes if available, otherwise default to empty dict.
        # When empty, individual XComArgs will compute their map_indexes lazily in xcom_arg.py.
        upstream_map_indexes = getattr(context["ti"], "_upstream_map_indexes", None) or {}

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?

        # Resolve XComArg/MappedArgument values; literals pass through unchanged.
        resolved = {
            k: v.resolve(context) if _needs_run_time_resolution(v) else v for k, v in self.value.items()
        }

        sized_resolved = {k: v for k, v in resolved.items() if isinstance(v, Sized)}

        all_lengths = self._get_map_lengths(sized_resolved, upstream_map_indexes)

        data = {k: self._expand_mapped_field(k, v, map_index, all_lengths) for k, v in resolved.items()}
        # Report ids only for values that came from run-time resolution, not parse-time literals.
        literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
        resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys}
        return data, resolved_oids
211
212
213def _describe_type(value: Any) -> str:
214 if value is None:
215 return "None"
216 return type(value).__name__
217
218
@attrs.define()
class ListOfDictsExpandInput(ResolveMixin):
    """
    Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand_kwargs(xcom_arg)``. ``value`` is either an
    XComArg that resolves to a list of dicts, or a literal sequence whose
    elements are dicts or XComArgs. Each task instance takes the dict at its
    own ``map_index``.
    """

    value: OperatorExpandKwargsArgument

    EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the upstream references of every XComArg in ``value``."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(self.value, XComArg):
            yield from self.value.iter_references()
        else:
            for x in self.value:
                if isinstance(x, XComArg):
                    yield from x.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Return the kwargs dict for the task instance in ``context``.

        :param context: Run-time context; ``context["ti"].map_index`` selects the dict.
        :return: A tuple of the selected dict and the ``id()``s of its values
            that are MappedArgument or XComArg instances.
        :raises RuntimeError: if the task instance's map index is None or negative.
        :raises ValueError: if the input is not a list of dicts with str keys.
        """
        map_index: int | None = context["ti"].map_index
        # A None map index must be rejected explicitly (as DictOfListsExpandInput.resolve
        # does); otherwise ``None < 0`` fails with an unhelpful TypeError.
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        mapping: Any = None
        if isinstance(self.value, Sized):
            # Literal list: take this instance's element and resolve it only if it is
            # an XComArg/MappedArgument. Anything else that is not a dict falls through
            # to the ValueError below instead of failing on a missing ``.resolve``.
            mapping = self.value[map_index]
            if _needs_run_time_resolution(mapping):
                mapping = mapping.resolve(context)
        else:
            mappings = self.value.resolve(context)
            if not isinstance(mappings, Sequence):
                raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
            mapping = mappings[map_index]

        if not isinstance(mapping, Mapping):
            raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

        for key in mapping:
            if not isinstance(key, str):
                raise ValueError(
                    f"expand_kwargs() input dict keys must all be str, "
                    f"but {key!r} is of type {_describe_type(key)}"
                )
        # filter out parse time resolved values from the resolved_oids
        resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)}

        return mapping, resolved_oids
270
271
EXPAND_INPUT_EMPTY = DictOfListsExpandInput({})  # Sentinel value: an expand input with no mapped kwargs.