1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import functools
21import operator
22from collections.abc import Iterable, Mapping, Sequence, Sized
23from typing import TYPE_CHECKING, Any, ClassVar, Union
24
25import attrs
26
27from airflow.sdk.definitions._internal.mixins import ResolveMixin
28
29if TYPE_CHECKING:
30 from typing import TypeGuard
31
32 from airflow.sdk.definitions.xcom_arg import XComArg
33 from airflow.sdk.types import Operator
34
# Either of the two expand-input storage classes named in the union (both are
# referenced by forward-reference strings here).
ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

# Each keyword argument to expand() can be a MappedArgument, an XComArg, a
# sequence, or a dict (not any mapping since we need the value to be ordered).
OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]]

# The single argument of expand_kwargs() can be an XComArg, or a sequence with
# each element being either an XComArg or a str-keyed mapping.
OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]]
44
45
class NotFullyPopulated(RuntimeError):
    """
    Raise when ``get_map_lengths`` cannot populate all mapping metadata.

    This generally happens because not all upstream tasks have finished when
    the function is called.
    """

    def __init__(self, missing: set[str]) -> None:
        # Names of the arguments whose mapping metadata could not be populated.
        self.missing = missing

    def __str__(self) -> str:
        # Sort the names so the rendered message is deterministic.
        rendered = ", ".join(map(repr, sorted(self.missing)))
        return f"Failed to populate all mapping metadata; missing: {rendered}"
60
61
# Shorthand that spares callers a chain of isinstance() checks.
def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
    """Return whether *v* is a MappedArgument, XComArg, Mapping, or non-str Sequence."""
    from airflow.sdk.definitions.xcom_arg import XComArg

    # A str is itself a Sequence, so it has to be ruled out explicitly.
    if isinstance(v, str):
        return False
    mappable_types = (MappedArgument, XComArg, Mapping, Sequence)
    return isinstance(v, mappable_types)
67
68
# Shorthand that spares callers a chain of isinstance() checks.
def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
    """Return True unless *v* is a MappedArgument or an XComArg."""
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(v, MappedArgument):
        return False
    return not isinstance(v, XComArg)
74
75
# Shorthand that spares callers a chain of isinstance() checks.
def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
    """Return True when *v* is a MappedArgument or an XComArg."""
    from airflow.sdk.definitions.xcom_arg import XComArg

    run_time_types = (MappedArgument, XComArg)
    return isinstance(v, run_time_types)
81
82
@attrs.define(kw_only=True)
class MappedArgument(ResolveMixin):
    """
    Placeholder for a single task-group-mapping argument.

    It behaves much like an XComArg, but it is resolved through the expand
    input it was created from. It is declared in this module (rather than in
    the task group module) to avoid import cycles.
    """

    # The expand input this argument is resolved through.
    _input: ExpandInput = attrs.field()
    # The key looked up in the data produced by resolving ``_input``.
    _key: str

    @_input.validator
    def _validate_input(self, _, expand_input):
        """Reject a dict-of-lists input that itself contains a MappedArgument."""
        # Only dict-of-lists inputs are inspected; anything else passes through.
        if not isinstance(expand_input, DictOfListsExpandInput):
            return
        if any(isinstance(arg, MappedArgument) for arg in expand_input.value.values()):
            raise ValueError("Nested Mapped TaskGroups are not yet supported")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield every reference reported by the underlying expand input."""
        yield from self._input.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> Any:
        """Resolve the underlying expand input and return the entry stored under this key."""
        # The second element of the resolved pair is not needed here.
        resolved_kwargs, _ = self._input.resolve(context)
        return resolved_kwargs[self._key]
108
109
@attrs.define()
class DictOfListsExpandInput(ResolveMixin):
    """
    Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand(**kwargs)``.
    """

    value: dict[str, OperatorExpandArgument]

    EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"

    def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
        """Generate the ``(name, value)`` pairs whose values are available at parse time."""
        return ((name, arg) for name, arg in self.value.items() if _is_parse_time_mappable(arg))

    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the product of the lengths of all kwargs, using parse-time values only.

        Returns 0 when there are no kwargs at all.

        :raises NotFullyPopulated: If at least one kwarg is not available at
            parse time; the exception carries the names of those kwargs.
        """
        if not self.value:
            return 0
        literal_lengths = {name: len(arg) for name, arg in self._iter_parse_time_resolved_kwargs()}
        if len(literal_lengths) != len(self.value):
            raise NotFullyPopulated(set(self.value).difference(literal_lengths))
        return functools.reduce(operator.mul, literal_lengths.values(), 1)

    def _get_map_lengths(
        self, resolved_vals: dict[str, Sized], upstream_map_indexes: dict[str, int]
    ) -> dict[str, int]:
        """
        Return dict of argument name to map length.

        :param resolved_vals: Resolved values by argument name; looked up for
            arguments that are XComArg instances.
        :param upstream_map_indexes: Passed through to ``get_task_map_length``.
        :raises NotFullyPopulated: If any argument is ``None`` or its length
            comes back as ``None``.
        """

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?
        def _length_of(name: str, arg: OperatorExpandArgument) -> int | None:
            from airflow.sdk.definitions.xcom_arg import XComArg, get_task_map_length

            if isinstance(arg, XComArg):
                return get_task_map_length(arg, resolved_vals[name], upstream_map_indexes)

            # Unfortunately a user-defined TypeGuard cannot apply negative type
            # narrowing. https://github.com/python/typing/discussions/1013
            if TYPE_CHECKING:
                assert isinstance(arg, Sized)
            return len(arg)

        lengths: dict[str, int] = {}
        for name, arg in self.value.items():
            if arg is None:
                continue
            length = _length_of(name, arg)
            if length is None:
                continue
            lengths[name] = length

        if len(lengths) < len(self.value):
            raise NotFullyPopulated(set(self.value).difference(lengths))
        return lengths

    def _expand_mapped_field(self, key: str, value: Any, map_index: int, all_lengths: dict[str, int]) -> Any:
        """
        Pick the element of *value* that belongs to *map_index* for the kwarg *key*.

        *map_index* is decomposed like a mixed-radix number over the kwargs in
        reverse declaration order: the last kwarg varies fastest. If *key* is
        not one of the mapped kwargs, *value* is returned unchanged.

        :raises RuntimeError: If a visited kwarg has a mapped length below 1.
        :raises TypeError: If *value* is neither a Sequence nor a dict.
        :raises IndexError: If the computed position is past the end of a dict value.
        """
        remaining = map_index
        # Need to use the original user input to retain argument order.
        for mapped_key in reversed(self.value):
            mapped_length = all_lengths[mapped_key]
            if mapped_length < 1:
                raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
            if mapped_key == key:
                position = remaining % mapped_length
                break
            remaining //= mapped_length
        else:
            # This key is not one of the mapped kwargs; nothing to pick.
            return value

        if isinstance(value, Sequence):
            return value[position]
        if not isinstance(value, dict):
            raise TypeError(f"can't map over value of type {type(value)}")
        # A dict is mapped over its items in insertion order.
        for offset, (item_key, item_value) in enumerate(value.items()):
            if offset == position:
                return item_key, item_value
        raise IndexError(f"index {map_index} is over mapped length")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the references of every XComArg found among the kwarg values."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for arg in self.value.values():
            if not isinstance(arg, XComArg):
                continue
            yield from arg.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Resolve the kwargs for the task instance found under ``context["ti"]``.

        :return: A pair of the kwargs selected for this map index, and the
            ``id()`` of every selected value whose kwarg was not available at
            parse time.
        :raises RuntimeError: If the task instance's map index is ``None`` or negative.
        """
        ti = context["ti"]
        map_index: int | None = ti.map_index
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        upstream_map_indexes = getattr(ti, "_upstream_map_indexes", {})

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?
        resolved: dict[str, Any] = {}
        for name, arg in self.value.items():
            resolved[name] = arg.resolve(context) if _needs_run_time_resolution(arg) else arg

        # Only values that have a length are handed to the length computation.
        sized_resolved = {name: val for name, val in resolved.items() if isinstance(val, Sized)}
        all_lengths = self._get_map_lengths(sized_resolved, upstream_map_indexes)

        data: dict[str, Any] = {}
        for name, val in resolved.items():
            data[name] = self._expand_mapped_field(name, val, map_index, all_lengths)

        literal_keys = {name for name, _ in self._iter_parse_time_resolved_kwargs()}
        resolved_oids = {id(val) for name, val in data.items() if name not in literal_keys}
        return data, resolved_oids
219
220
221def _describe_type(value: Any) -> str:
222 if value is None:
223 return "None"
224 return type(value).__name__
225
226
@attrs.define()
class ListOfDictsExpandInput(ResolveMixin):
    """
    Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand_kwargs(xcom_arg)``.
    """

    value: OperatorExpandKwargsArgument

    EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"

    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the length of the stored value when it has one.

        :raises NotFullyPopulated: If the stored value is not ``Sized``, so
            its length cannot be taken here.
        """
        if isinstance(self.value, Sized):
            return len(self.value)
        raise NotFullyPopulated({"expand_kwargs() argument"})

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the references of the stored XComArg, or of each XComArg element."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(self.value, XComArg):
            yield from self.value.iter_references()
        else:
            for x in self.value:
                if isinstance(x, XComArg):
                    yield from x.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Resolve the kwargs mapping for the task instance found under ``context["ti"]``.

        :return: A pair of the mapping selected for this map index, and the
            ``id()`` of every value in it that is not parse-time mappable.
        :raises RuntimeError: If the task instance's map index is ``None`` or negative.
        :raises ValueError: If the resolved input is not a sequence, the
            selected element is not a mapping, or a key is not a ``str``.
        """
        map_index: int | None = context["ti"].map_index
        # Guard against None before comparing: ``None < 0`` raises TypeError.
        # This matches the check done by DictOfListsExpandInput.resolve().
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        mapping: Any = None
        if isinstance(self.value, Sized):
            # The stored value has a length: pick the element first, then
            # resolve it if it is not already a mapping.
            mapping = self.value[map_index]
            if not isinstance(mapping, Mapping):
                mapping = mapping.resolve(context)
        else:
            # The stored value must be resolved as a whole before indexing.
            mappings = self.value.resolve(context)
            if not isinstance(mappings, Sequence):
                raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
            mapping = mappings[map_index]

        if not isinstance(mapping, Mapping):
            raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

        for key in mapping:
            if not isinstance(key, str):
                raise ValueError(
                    f"expand_kwargs() input dict keys must all be str, "
                    f"but {key!r} is of type {_describe_type(key)}"
                )
        # filter out parse time resolved values from the resolved_oids
        resolved_oids = {id(v) for v in mapping.values() if not _is_parse_time_mappable(v)}

        return mapping, resolved_oids
283
284
# Module-level DictOfListsExpandInput wrapping an empty dict.
EXPAND_INPUT_EMPTY = DictOfListsExpandInput({})  # Sentinel value.