Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/expandinput.py: 30%
158 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections.abc
21import functools
22import operator
23from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, NamedTuple, Sequence, Sized, Union
25import attr
27from airflow.typing_compat import TypeGuard
28from airflow.utils.context import Context
29from airflow.utils.mixins import ResolveMixin
30from airflow.utils.session import NEW_SESSION, provide_session
32if TYPE_CHECKING:
33 from sqlalchemy.orm import Session
35 from airflow.models.operator import Operator
36 from airflow.models.xcom_arg import XComArg
# The storage types a mapped operator's expand input may take; see the
# DictOfListsExpandInput and ListOfDictsExpandInput classes in this module.
ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

# Each keyword argument to expand() can be an XComArg, a MappedArgument (a
# task-group-mapping stand-in), a sequence, or a dict (not any mapping since we
# need the value to be ordered).
OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, Dict[str, Any]]

# The single argument of expand_kwargs() can be an XComArg, or a list with each
# element being either an XComArg or a mapping of str keys.
OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]]
@attr.define(kw_only=True)
class MappedArgument(ResolveMixin):
    """Stand-in stub for task-group-mapping arguments.

    It plays a role similar to an XComArg, but is resolved differently: the
    wrapped expand input is resolved as a whole, and this argument's value is
    then picked out of the result by key. It is declared here (instead of in
    the task group module) to avoid import cycles.
    """

    # The expand input this argument is drawn from.
    _input: ExpandInput
    # The key of this argument in the mapping produced by resolving ``_input``.
    _key: str

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        # TODO (AIP-42): Implement run-time task map length inspection. This is
        # needed when we implement task mapping inside a mapped task group.
        raise NotImplementedError()

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield every reference of the underlying expand input."""
        source = self._input
        yield from source.iter_references()

    @provide_session
    def resolve(self, context: Context, *, session: Session = NEW_SESSION) -> Any:
        """Resolve the underlying expand input and return this argument's entry.

        :param context: Task context forwarded to the expand input's ``resolve``.
        :param session: Database session forwarded to the expand input.
        :return: The value stored under ``_key`` in the resolved mapping.
        """
        resolved_mapping, _unused_oids = self._input.resolve(context, session=session)
        return resolved_mapping[self._key]
# To replace tedious isinstance() checks.
def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
    """Return whether *v* is acceptable as an ``expand()`` argument.

    Strings are rejected explicitly even though they are sequences.
    """
    from airflow.models.xcom_arg import XComArg

    if isinstance(v, str):
        return False
    return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence))
# To replace tedious isinstance() checks.
def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
    """Return whether *v* is a literal value rather than a run-time reference.

    Anything that is not a ``MappedArgument`` or ``XComArg`` counts as literal.
    """
    from airflow.models.xcom_arg import XComArg

    if isinstance(v, (MappedArgument, XComArg)):
        return False
    return True
# To replace tedious isinstance() checks.
def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
    """Return whether *v* must be resolved at run time (MappedArgument or XComArg)."""
    from airflow.models.xcom_arg import XComArg

    deferred_types = (MappedArgument, XComArg)
    return isinstance(v, deferred_types)
class NotFullyPopulated(RuntimeError):
    """Raised when mapping metadata (e.g. map lengths) cannot be fully populated.

    This is generally due to not all upstream tasks having finished when the
    function is called.

    :param missing: Names of the mapped arguments whose metadata is missing.
    """

    def __init__(self, missing: set[str]) -> None:
        # Forward ``missing`` to the base class so ``self.args`` is populated.
        # BaseException.__reduce__ re-creates the instance from ``self.args``;
        # with empty args, unpickling would call ``__init__()`` without the
        # required argument and fail. It also makes ``repr()`` informative.
        super().__init__(missing)
        self.missing = missing

    def __str__(self) -> str:
        # Sort so the message is deterministic regardless of set ordering.
        keys = ", ".join(repr(k) for k in sorted(self.missing))
        return f"Failed to populate all mapping metadata; missing: {keys}"
class DictOfListsExpandInput(NamedTuple):
    """Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand(**kwargs)``. Each value is either a literal
    (a sequence or dict whose length is known at parse time) or a value that
    must be resolved at run time (``XComArg`` or ``MappedArgument``). The
    mapped task instances cover every combination of the values, so the total
    map length is the product of the individual lengths.
    """

    # Mapping of argument name to the value passed to ``expand()``.
    value: dict[str, OperatorExpandArgument]

    def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
        """Generate kwargs with values available on parse-time."""
        return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v))

    def get_parse_time_mapped_ti_count(self) -> int:
        """Return the number of mapped task instances, computed at parse time.

        :return: Product of the lengths of all values, or 0 if there are none.
        :raises NotFullyPopulated: If any value can only be resolved at run time.
        """
        if not self.value:
            return 0
        literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()]
        if len(literal_values) != len(self.value):
            literal_keys = (k for k, _ in self._iter_parse_time_resolved_kwargs())
            raise NotFullyPopulated(set(self.value).difference(literal_keys))
        return functools.reduce(operator.mul, literal_values, 1)

    def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]:
        """Return dict of argument name to map length.

        :raises NotFullyPopulated: If the length of any argument is not known
            right now (e.g. its upstream task has not finished). The exception
            carries the names of the missing arguments.
        """
        # TODO: This initiates one database call for each XComArg. Would it be
        # more efficient to do one single db call and unpack the value here?
        def _get_length(v: OperatorExpandArgument) -> int | None:
            if _needs_run_time_resolution(v):
                return v.get_task_map_length(run_id, session=session)
            # Unfortunately a user-defined TypeGuard cannot apply negative type
            # narrowing. https://github.com/python/typing/discussions/1013
            if TYPE_CHECKING:
                assert isinstance(v, Sized)
            return len(v)

        map_lengths_iterator = ((k, _get_length(v)) for k, v in self.value.items())

        # Run-time values whose length is unavailable report None; drop them so
        # all missing names can be reported together below.
        map_lengths = {k: v for k, v in map_lengths_iterator if v is not None}
        if len(map_lengths) < len(self.value):
            raise NotFullyPopulated(set(self.value).difference(map_lengths))
        return map_lengths

    def get_total_map_length(self, run_id: str, *, session: Session) -> int:
        """Return the total number of mapped task instances for *run_id*.

        :return: Product of all argument map lengths, or 0 if there are none.
        :raises NotFullyPopulated: If any map length is not known yet.
        """
        if not self.value:
            return 0
        lengths = self._get_map_lengths(run_id, session=session)
        return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)

    def _expand_mapped_field(
        self,
        key: str,
        value: Any,
        context: Context,
        *,
        session: Session,
        all_lengths: dict[str, int] | None = None,
    ) -> Any:
        """Return the part of *value* selected by the current task's map index.

        :param key: Name of the argument being expanded.
        :param value: The value passed to ``expand()`` for *key*; run-time values
            are resolved first.
        :param context: Task context; ``ti.map_index`` and ``run_id`` are read.
        :param session: Database session used for resolution.
        :param all_lengths: Pre-computed result of ``_get_map_lengths``. When
            omitted it is looked up here.
        :return: The selected item for a sequence, a ``(key, value)`` pair for a
            dict, or *value* unchanged if *key* is not one of the mapped names.
        :raises RuntimeError: If the task instance is not expanded (negative map
            index), or a mapped argument has a length below 1.
        :raises TypeError: If the value is neither a sequence nor a dict.
        """
        if _needs_run_time_resolution(value):
            value = value.resolve(context, session=session)
        map_index = context["ti"].map_index
        if map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")
        lengths: dict[str, int] = (
            self._get_map_lengths(context["run_id"], session=session) if all_lengths is None else all_lengths
        )

        def _find_index_for_this_field(index: int) -> int:
            # Decode the flat map index as a mixed-radix number whose digits are
            # the per-argument indexes, with the last argument varying fastest.
            # Need to use the original user input to retain argument order.
            for mapped_key in reversed(list(self.value)):
                mapped_length = lengths[mapped_key]
                if mapped_length < 1:
                    raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
                if mapped_key == key:
                    return index % mapped_length
                index //= mapped_length
            return -1

        found_index = _find_index_for_this_field(map_index)
        if found_index < 0:
            return value
        if isinstance(value, collections.abc.Sequence):
            return value[found_index]
        if not isinstance(value, dict):
            raise TypeError(f"can't map over value of type {type(value)}")
        for i, (k, v) in enumerate(value.items()):
            if i == found_index:
                return k, v
        raise IndexError(f"index {map_index} is over mapped length")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the references of every XComArg among the expand() values."""
        from airflow.models.xcom_arg import XComArg

        for x in self.value.values():
            if isinstance(x, XComArg):
                yield from x.iter_references()

    def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Resolve the mapped kwargs for the current task instance.

        :return: A 2-tuple of (argument name -> value for this map index) and
            the ``id()`` of every value that came from a non-literal argument.
        :raises RuntimeError: If the task instance is not expanded.
        :raises NotFullyPopulated: If any map length is not known yet.
        """
        if not self.value:
            return {}, set()
        if context["ti"].map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")
        # Map lengths do not depend on the field being expanded, so look them
        # up once instead of once per field; each lookup may issue a database
        # call per XComArg.
        all_lengths = self._get_map_lengths(context["run_id"], session=session)
        data = {
            k: self._expand_mapped_field(k, v, context, session=session, all_lengths=all_lengths)
            for k, v in self.value.items()
        }
        literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
        resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys}
        return data, resolved_oids
206def _describe_type(value: Any) -> str:
207 if value is None:
208 return "None"
209 return type(value).__name__
class ListOfDictsExpandInput(NamedTuple):
    """Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand_kwargs(xcom_arg)``. ``value`` is either a
    run-time-resolvable object that produces a list of dicts, or a literal
    sequence whose items are dicts or run-time-resolvable objects. Each mapped
    task instance receives the dict at its own map index.
    """

    # The argument passed to ``expand_kwargs()``.
    value: OperatorExpandKwargsArgument

    def get_parse_time_mapped_ti_count(self) -> int:
        """Return the number of mapped task instances, known at parse time.

        :raises NotFullyPopulated: If ``value`` has no length until run time.
        """
        if isinstance(self.value, collections.abc.Sized):
            return len(self.value)
        raise NotFullyPopulated({"expand_kwargs() argument"})

    def get_total_map_length(self, run_id: str, *, session: Session) -> int:
        """Return the number of mapped task instances for *run_id*.

        :raises NotFullyPopulated: If the run-time length is not available yet.
        """
        if isinstance(self.value, collections.abc.Sized):
            return len(self.value)
        length = self.value.get_task_map_length(run_id, session=session)
        if length is None:
            raise NotFullyPopulated({"expand_kwargs() argument"})
        return length

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the references of ``value`` itself or of its XComArg items."""
        from airflow.models.xcom_arg import XComArg

        if isinstance(self.value, XComArg):
            yield from self.value.iter_references()
        else:
            for x in self.value:
                if isinstance(x, XComArg):
                    yield from x.iter_references()

    def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Resolve the kwargs dict for the current task instance.

        :return: A 2-tuple of the kwargs mapping at this task's map index and
            the ``id()`` of each of its values.
        :raises RuntimeError: If the task instance is not expanded.
        :raises ValueError: If the resolved input is not a list of dicts with
            ``str`` keys.
        """
        map_index = context["ti"].map_index
        if map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        # ``session`` is passed by keyword (as elsewhere in this module) so that
        # resolvers declaring it keyword-only, like MappedArgument.resolve, work.
        mapping: Any
        if isinstance(self.value, collections.abc.Sized):
            mapping = self.value[map_index]
            if not isinstance(mapping, collections.abc.Mapping):
                mapping = mapping.resolve(context, session=session)
        else:
            mappings = self.value.resolve(context, session=session)
            if not isinstance(mappings, collections.abc.Sequence):
                raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
            mapping = mappings[map_index]

        if not isinstance(mapping, collections.abc.Mapping):
            raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

        # Keys become keyword-argument names, so they must all be strings.
        for key in mapping:
            if not isinstance(key, str):
                raise ValueError(
                    f"expand_kwargs() input dict keys must all be str, "
                    f"but {key!r} is of type {_describe_type(key)}"
                )
        return mapping, {id(v) for v in mapping.values()}
# An expand input with no mapped arguments.
EXPAND_INPUT_EMPTY = DictOfListsExpandInput({})  # Sentinel value.

# Registry of string keys to expand-input classes. get_map_type_key() looks a
# key up by class and create_expand_input() builds an instance from a key.
# NOTE(review): these keys are presumably persisted (e.g. in serialized DAGs);
# confirm before renaming them.
_EXPAND_INPUT_TYPES = {
    "dict-of-lists": DictOfListsExpandInput,
    "list-of-dicts": ListOfDictsExpandInput,
}
def get_map_type_key(expand_input: ExpandInput) -> str:
    """Return the registry key for the type of *expand_input*.

    This is the inverse of ``create_expand_input``.

    :raises TypeError: If the type of *expand_input* is not registered.
    """
    input_type = type(expand_input)
    for key, registered_type in _EXPAND_INPUT_TYPES.items():
        # Exact class match: subclasses are intentionally not accepted.
        if registered_type is input_type:
            return key
    # Previously a bare StopIteration leaked out of next(), which is confusing
    # and becomes a RuntimeError if raised inside a generator.
    raise TypeError(f"unknown expand input type: {input_type.__name__}")
def create_expand_input(kind: str, value: Any) -> ExpandInput:
    """Build an expand input of the registered *kind* wrapping *value*.

    :param kind: Registry key, e.g. ``"dict-of-lists"`` or ``"list-of-dicts"``.
    :param value: The payload stored in the new expand input.
    :raises KeyError: If *kind* is not registered.
    """
    input_class = _EXPAND_INPUT_TYPES[kind]
    return input_class(value)