1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20from collections.abc import Iterable, Mapping, Sequence, Sized
21from typing import TYPE_CHECKING, Any, ClassVar, Union
22
23import attrs
24
25from airflow.sdk.definitions._internal.mixins import ResolveMixin
26
27if TYPE_CHECKING:
28 from typing import TypeGuard
29
30 from airflow.sdk.definitions.xcom_arg import XComArg
31 from airflow.sdk.types import Operator
32
# An expand input is stored as one of the two classes defined later in this
# module: DictOfListsExpandInput (created from ``expand(**kwargs)``) or
# ListOfDictsExpandInput (created from ``expand_kwargs(...)``).
ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

# Each keyword argument to expand() can be a MappedArgument stand-in, an
# XComArg, a sequence, or a dict (not any mapping since we need the value to be
# ordered).
OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]]

# The single argument of expand_kwargs() can be an XComArg, or a sequence with
# each element being either an XComArg or a mapping with string keys.
OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]]
42
43
44class _NotFullyPopulated(RuntimeError):
45 """
46 Raise when an expand input cannot be resolved due to incomplete metadata.
47
48 This generally should not happen. The scheduler should have made sure that
49 a not-yet-ready-to-expand task should not be executed. In the off chance
50 this gets raised, it will fail the task instance.
51 """
52
53 def __init__(self, missing: set[str]) -> None:
54 self.missing = missing
55
56 def __str__(self) -> str:
57 keys = ", ".join(repr(k) for k in sorted(self.missing))
58 return f"Failed to populate all mapping metadata; missing: {keys}"
59
60
# To replace tedious isinstance() checks.
def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
    """
    Tell whether *v* is acceptable as an argument to map over.

    A value qualifies when it is a MappedArgument, an XComArg, a Mapping or a
    Sequence. Strings are rejected even though they are sequences.
    """
    # Imported lazily, as in the rest of this module.
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(v, str):
        return False
    accepted_types = (MappedArgument, XComArg, Mapping, Sequence)
    return isinstance(v, accepted_types)
66
67
# To replace tedious isinstance() checks.
def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
    """Return True unless *v* is a MappedArgument or an XComArg."""
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(v, (MappedArgument, XComArg)):
        return False
    return True
73
74
# To replace tedious isinstance() checks.
def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
    """Return True when *v* is a MappedArgument or an XComArg."""
    from airflow.sdk.definitions.xcom_arg import XComArg

    run_time_types = (MappedArgument, XComArg)
    return isinstance(v, run_time_types)
80
81
@attrs.define(kw_only=True)
class MappedArgument(ResolveMixin):
    """
    Placeholder for an argument of a mapped task group.

    It behaves much like an XComArg but is resolved differently: the wrapped
    expand input is resolved first, and then a single key is picked out of the
    result. It lives in this module (rather than the task group module) to
    avoid import cycles.
    """

    # The expand input this argument reads from.
    _input: ExpandInput = attrs.field()
    # The key looked up in the resolved expand input.
    _key: str

    @_input.validator
    def _validate_input(self, _attribute, expand_input):
        """Reject a dict-of-lists input that itself contains a MappedArgument."""
        if not isinstance(expand_input, DictOfListsExpandInput):
            return
        if any(isinstance(arg, MappedArgument) for arg in expand_input.value.values()):
            raise ValueError("Nested Mapped TaskGroups are not yet supported")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield every reference reported by the wrapped expand input."""
        yield from self._input.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> Any:
        """Resolve the wrapped expand input and return the entry stored under this key."""
        # The second element of the resolved pair is not needed here.
        resolved_kwargs, _unused = self._input.resolve(context)
        return resolved_kwargs[self._key]
107
108
@attrs.define()
class DictOfListsExpandInput(ResolveMixin):
    """
    Hold the keyword arguments a mapped operator is expanded over.

    Instances are built from ``expand(**kwargs)``.
    """

    # Argument name -> the value to map over, in the order the user gave them.
    value: dict[str, OperatorExpandArgument]

    EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"

    def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
        """Produce the ``(name, argument)`` pairs whose values are available at parse time."""
        return (pair for pair in self.value.items() if _is_parse_time_mappable(pair[1]))

    def _get_map_lengths(
        self, resolved_vals: dict[str, Sized], upstream_map_indexes: dict[str, int]
    ) -> dict[str, int]:
        """
        Build a dict mapping each argument name to its map length.

        :param resolved_vals: Already-resolved values, looked up by name for
            arguments that are XComArgs.
        :param upstream_map_indexes: Passed through to ``get_task_map_length``.
        :raises _NotFullyPopulated: If any argument is None or has no known
            length; the exception carries the names that are missing.
        """

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?
        def _length_of(name: str, arg: OperatorExpandArgument) -> int | None:
            from airflow.sdk.definitions.xcom_arg import XComArg, get_task_map_length

            if not isinstance(arg, XComArg):
                # Unfortunately a user-defined TypeGuard cannot apply negative type
                # narrowing. https://github.com/python/typing/discussions/1013
                if TYPE_CHECKING:
                    assert isinstance(arg, Sized)
                return len(arg)
            return get_task_map_length(arg, resolved_vals[name], upstream_map_indexes)

        lengths: dict[str, int] = {}
        for name, arg in self.value.items():
            if arg is None:
                continue
            length = _length_of(name, arg)
            if length is not None:
                lengths[name] = length

        # Any argument without a length means the metadata is incomplete.
        if len(lengths) < len(self.value):
            raise _NotFullyPopulated({name for name in self.value if name not in lengths})
        return lengths

    def _expand_mapped_field(self, key: str, value: Any, map_index: int, all_lengths: dict[str, int]) -> Any:
        """
        Pick the element of *value* that belongs to expansion number *map_index*.

        ``map_index`` is decomposed over the lengths of all mapped arguments,
        walking the arguments from last to first, so the last argument varies
        fastest. If *key* is not one of the mapped arguments, *value* is
        returned untouched.

        :raises RuntimeError: If a mapped length smaller than 1 is encountered.
        :raises TypeError: If *value* is neither a Sequence nor a dict.
        :raises IndexError: If the computed position is past the end of a dict.
        """
        remaining = map_index
        # Need to use the original user input to retain argument order.
        for mapped_key in reversed(self.value):
            mapped_length = all_lengths[mapped_key]
            if mapped_length < 1:
                raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
            remaining, position = divmod(remaining, mapped_length)
            if mapped_key == key:
                break
        else:
            # This field is not being mapped over.
            return value

        if isinstance(value, Sequence):
            return value[position]
        if not isinstance(value, dict):
            raise TypeError(f"can't map over value of type {type(value)}")
        # A dict is mapped over its items, in insertion order.
        for offset, (item_key, item_value) in enumerate(value.items()):
            if offset == position:
                return item_key, item_value
        raise IndexError(f"index {map_index} is over mapped length")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the references of every XComArg found among the values."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        xcom_args = (arg for arg in self.value.values() if isinstance(arg, XComArg))
        for xcom_arg in xcom_args:
            yield from xcom_arg.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Resolve every argument for the task instance found in *context*.

        :return: A pair of the resolved keyword arguments for this map index,
            and the ``id()`` of each resolved value whose argument was not
            available at parse time.
        :raises RuntimeError: If the task instance has no non-negative map index.
        """
        ti = context["ti"]
        map_index: int | None = ti.map_index
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        upstream_map_indexes = getattr(ti, "_upstream_map_indexes", {})

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?

        resolved: dict[str, Any] = {}
        for name, arg in self.value.items():
            if _needs_run_time_resolution(arg):
                resolved[name] = arg.resolve(context)
            else:
                resolved[name] = arg

        # Only values that have a length are handed to the length lookup.
        sized_resolved = {name: item for name, item in resolved.items() if isinstance(item, Sized)}
        all_lengths = self._get_map_lengths(sized_resolved, upstream_map_indexes)

        data: dict[str, Any] = {}
        for name, item in resolved.items():
            data[name] = self._expand_mapped_field(name, item, map_index, all_lengths)

        literal_keys = {name for name, _ in self._iter_parse_time_resolved_kwargs()}
        resolved_oids = {id(item) for name, item in data.items() if name not in literal_keys}
        return data, resolved_oids
209
210
211def _describe_type(value: Any) -> str:
212 if value is None:
213 return "None"
214 return type(value).__name__
215
216
@attrs.define()
class ListOfDictsExpandInput(ResolveMixin):
    """
    Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand_kwargs(xcom_arg)``.
    """

    # Either a single XComArg, or a sequence whose elements are XComArgs or mappings.
    value: OperatorExpandKwargsArgument

    EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the references of the XComArg value, or of each XComArg element."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(self.value, XComArg):
            yield from self.value.iter_references()
        else:
            for x in self.value:
                if isinstance(x, XComArg):
                    yield from x.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Resolve the keyword arguments for the task instance found in *context*.

        :return: A pair of the mapping selected for this map index, and the
            ``id()`` of each of its values that is a MappedArgument or XComArg.
        :raises RuntimeError: If the task instance has no non-negative map index.
        :raises ValueError: If the resolved input is not a sequence of mappings
            with string keys.
        """
        map_index: int | None = context["ti"].map_index
        # Guard against None as well, so an unexpanded task instance gets the
        # intended RuntimeError instead of a TypeError from the comparison.
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        mapping: Any = None
        if isinstance(self.value, Sized):
            # A literal sequence: pick this index, resolving the element if it
            # is not already a mapping.
            mapping = self.value[map_index]
            if not isinstance(mapping, Mapping):
                mapping = mapping.resolve(context)
        else:
            # Not sized: resolve the whole value first, then pick this index.
            mappings = self.value.resolve(context)
            if not isinstance(mappings, Sequence):
                raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
            mapping = mappings[map_index]

        if not isinstance(mapping, Mapping):
            raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

        for key in mapping:
            if not isinstance(key, str):
                raise ValueError(
                    f"expand_kwargs() input dict keys must all be str, "
                    f"but {key!r} is of type {_describe_type(key)}"
                )
        # filter out parse time resolved values from the resolved_oids
        resolved_oids = {id(v) for v in mapping.values() if not _is_parse_time_mappable(v)}

        return mapping, resolved_oids
268
269
# Sentinel value: a dict-of-lists expand input with no mapped arguments.
EXPAND_INPUT_EMPTY = DictOfListsExpandInput({})