Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/_internal/expandinput.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

149 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import functools 

21import operator 

22from collections.abc import Iterable, Mapping, Sequence, Sized 

23from typing import TYPE_CHECKING, Any, ClassVar, Union 

24 

25import attrs 

26 

27from airflow.sdk.definitions._internal.mixins import ResolveMixin 

28 

29if TYPE_CHECKING: 

30 from typing import TypeGuard 

31 

32 from airflow.sdk.definitions.xcom_arg import XComArg 

33 from airflow.sdk.types import Operator 

34 

# Either storage kind for a mapped operator's mapped arguments.
ExpandInput = Union[
    "DictOfListsExpandInput",
    "ListOfDictsExpandInput",
]

# Each keyword argument to expand() can be a MappedArgument, an XComArg, a
# sequence, or a dict (not any mapping, since we need the value to be ordered).
OperatorExpandArgument = Union[
    "MappedArgument",
    "XComArg",
    Sequence,
    dict[str, Any],
]

# The single argument of expand_kwargs() can be an XComArg, or a sequence whose
# elements are each either an XComArg or a mapping with str keys.
OperatorExpandKwargsArgument = Union[
    "XComArg",
    Sequence[Union["XComArg", Mapping[str, Any]]],
]

44 

45 

class NotFullyPopulated(RuntimeError):
    """
    Raise when mapping metadata (such as map lengths) cannot be fully populated.

    This is generally because not all upstream tasks have finished when the
    function is called.

    :param missing: Names of the arguments whose mapping metadata is unknown.
    """

    def __init__(self, missing: set[str]) -> None:
        # Forward ``missing`` to the base class so ``self.args`` is populated.
        # Without this, pickling/copying the exception fails (it would be
        # re-created as ``NotFullyPopulated()`` with no argument) and its repr
        # hides the payload.
        super().__init__(missing)
        self.missing = missing

    def __str__(self) -> str:
        # Sort so the message is deterministic regardless of set ordering.
        keys = ", ".join(repr(k) for k in sorted(self.missing))
        return f"Failed to populate all mapping metadata; missing: {keys}"

60 

61 

# To replace tedious isinstance() checks.
def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
    """
    Tell whether ``v`` can be used as a keyword argument to ``expand()``.

    Accepted values are MappedArgument, XComArg, any Mapping, or any Sequence.
    Strings are Sequences too, but are explicitly rejected.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(v, str):
        return False
    return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence))

67 

68 

# To replace tedious isinstance() checks.
def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
    """
    Tell whether ``v`` is a literal whose contents are already known at parse time.

    Anything that is neither a MappedArgument nor an XComArg qualifies.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    run_time_only_types = (MappedArgument, XComArg)
    return not isinstance(v, run_time_only_types)

74 

75 

# To replace tedious isinstance() checks.
def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
    """Tell whether ``v`` must be resolved at run time (a MappedArgument or an XComArg)."""
    from airflow.sdk.definitions.xcom_arg import XComArg

    return isinstance(v, MappedArgument) or isinstance(v, XComArg)

81 

82 

@attrs.define(kw_only=True)
class MappedArgument(ResolveMixin):
    """
    Stand-in stub for task-group-mapping arguments.

    This is very similar to an XComArg, but resolved differently. At run time
    the whole ``_input`` is resolved, and the entry stored under ``_key`` is
    returned. Declared here (instead of in the task group module) to avoid
    import cycles.
    """

    # The expand input this argument is taken from (validated below).
    _input: ExpandInput = attrs.field()
    # Name of the keyword argument in ``_input`` that this stub stands for.
    _key: str

    @_input.validator
    def _validate_input(self, _attribute, expand_input):
        # Reject a dict-of-lists input that itself carries MappedArgument values,
        # because nested mapped task groups are not supported.
        if not isinstance(expand_input, DictOfListsExpandInput):
            return
        if any(isinstance(arg, MappedArgument) for arg in expand_input.value.values()):
            raise ValueError("Nested Mapped TaskGroups are not yet supported")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the references reported by the underlying expand input."""
        yield from self._input.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> Any:
        """Resolve the underlying expand input and return the value for ``_key``."""
        resolved_kwargs, _resolved_oids = self._input.resolve(context)
        return resolved_kwargs[self._key]

108 

109 

@attrs.define()
class DictOfListsExpandInput(ResolveMixin):
    """
    Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand(**kwargs)``.
    """

    # Keyword argument name -> value to expand over, in the order the user passed
    # them (the order matters when decoding a map index into per-field indexes).
    value: dict[str, OperatorExpandArgument]

    EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"

    def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
        """Lazily yield ``(name, value)`` pairs for kwargs whose values are known at parse time."""
        return ((name, arg) for name, arg in self.value.items() if _is_parse_time_mappable(arg))

    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the number of mapped task instances, using literal values only.

        The count is the product of the lengths of all kwargs. An empty input
        yields ``0``.

        :raises NotFullyPopulated: if any kwarg can only be resolved at run time.
        """
        if not self.value:
            return 0
        # Measure every literal first; only then check for run-time-only kwargs.
        lengths = {name: len(arg) for name, arg in self._iter_parse_time_resolved_kwargs()}
        if len(lengths) != len(self.value):
            raise NotFullyPopulated(set(self.value).difference(lengths))
        return functools.reduce(operator.mul, lengths.values(), 1)

    def _get_map_lengths(
        self, resolved_vals: dict[str, Sized], upstream_map_indexes: dict[str, int]
    ) -> dict[str, int]:
        """
        Return dict of argument name to map length.

        Literal values contribute ``len(value)``. XComArg values are measured with
        ``get_task_map_length``, using their resolved value from ``resolved_vals``.
        ``None`` values, and lengths reported as ``None``, count as unknown.

        :raises NotFullyPopulated: if the length of any argument is unknown.
        """

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?
        def _length_of(name: str, arg: OperatorExpandArgument) -> int | None:
            from airflow.sdk.definitions.xcom_arg import XComArg, get_task_map_length

            if isinstance(arg, XComArg):
                return get_task_map_length(arg, resolved_vals[name], upstream_map_indexes)

            # Unfortunately a user-defined TypeGuard cannot apply negative type
            # narrowing. https://github.com/python/typing/discussions/1013
            if TYPE_CHECKING:
                assert isinstance(arg, Sized)
            return len(arg)

        lengths: dict[str, int] = {}
        for name, arg in self.value.items():
            if arg is None:
                continue
            length = _length_of(name, arg)
            if length is not None:
                lengths[name] = length
        if len(lengths) < len(self.value):
            raise NotFullyPopulated(set(self.value).difference(lengths))
        return lengths

    def _expand_mapped_field(self, key: str, value: Any, map_index: int, all_lengths: dict[str, int]) -> Any:
        """
        Pick the element of ``value`` (the resolved value of kwarg ``key``) for ``map_index``.

        ``map_index`` is decoded as a mixed-radix number over the kwargs' map
        lengths, with the last kwarg (in the original order) varying fastest.
        A Sequence is indexed directly. A dict yields its ``(key, value)`` item
        at that position. If ``key`` is not a mapped kwarg, ``value`` is returned
        unchanged.
        """
        # Need to use the original user input to retain argument order.
        field_index = -1
        remaining = map_index
        for mapped_key in reversed(self.value):
            mapped_length = all_lengths[mapped_key]
            if mapped_length < 1:
                raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
            if mapped_key == key:
                field_index = remaining % mapped_length
                break
            remaining //= mapped_length

        if field_index < 0:
            return value
        if isinstance(value, Sequence):
            return value[field_index]
        if not isinstance(value, dict):
            raise TypeError(f"can't map over value of type {type(value)}")
        for position, (item_key, item_value) in enumerate(value.items()):
            if position == field_index:
                return item_key, item_value
        raise IndexError(f"index {map_index} is over mapped length")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the references of every XComArg among the kwarg values."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for arg in self.value.values():
            if not isinstance(arg, XComArg):
                continue
            yield from arg.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Resolve every kwarg and pick this task instance's element of each.

        :return: The expanded kwargs for ``context["ti"].map_index``, plus the
            ``id()``s of the values whose kwargs are not parse-time literals.
        :raises RuntimeError: if the task instance has no valid (non-negative) map index.
        """
        ti = context["ti"]
        map_index: int | None = ti.map_index
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        upstream_map_indexes = getattr(ti, "_upstream_map_indexes", {})

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?
        resolved: dict[str, Any] = {}
        for name, arg in self.value.items():
            resolved[name] = arg.resolve(context) if _needs_run_time_resolution(arg) else arg

        sized_resolved = {name: val for name, val in resolved.items() if isinstance(val, Sized)}
        all_lengths = self._get_map_lengths(sized_resolved, upstream_map_indexes)

        data = {
            name: self._expand_mapped_field(name, val, map_index, all_lengths)
            for name, val in resolved.items()
        }
        literal_keys = {name for name, _ in self._iter_parse_time_resolved_kwargs()}
        resolved_oids = {id(val) for name, val in data.items() if name not in literal_keys}
        return data, resolved_oids

219 

220 

221def _describe_type(value: Any) -> str: 

222 if value is None: 

223 return "None" 

224 return type(value).__name__ 

225 

226 

@attrs.define()
class ListOfDictsExpandInput(ResolveMixin):
    """
    Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand_kwargs(xcom_arg)``.
    """

    # Either an XComArg, or a literal sequence whose items are each a mapping
    # or an XComArg (see OperatorExpandKwargsArgument).
    value: OperatorExpandKwargsArgument

    EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"

    def get_parse_time_mapped_ti_count(self) -> int:
        """
        Return the number of mapped task instances if it is known at parse time.

        :raises NotFullyPopulated: if ``value`` is not a sized literal (for
            example, an XComArg that can only be resolved at run time).
        """
        if isinstance(self.value, Sized):
            return len(self.value)
        raise NotFullyPopulated({"expand_kwargs() argument"})

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the references of ``value``, or of each XComArg item inside it."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(self.value, XComArg):
            yield from self.value.iter_references()
        else:
            for x in self.value:
                if isinstance(x, XComArg):
                    yield from x.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Return the kwargs mapping for the current task instance's map index.

        :return: The mapping selected for ``context["ti"].map_index``, plus the
            ``id()``s of its values that ``_is_parse_time_mappable`` rejects.
        :raises RuntimeError: if the task instance has no valid (non-negative) map index.
        :raises ValueError: if the input is not a list of dicts with ``str`` keys.
        """
        map_index: int | None = context["ti"].map_index
        # Reject ``None`` explicitly, as DictOfListsExpandInput.resolve does;
        # otherwise ``None < 0`` escapes as an unhelpful TypeError.
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        mapping: Any = None
        if isinstance(self.value, Sized):
            # Literal list: only the item for this map index needs resolving.
            mapping = self.value[map_index]
            # Resolve only run-time arguments (XComArg / MappedArgument). Any other
            # non-Mapping item falls through to the ValueError below instead of
            # failing with an AttributeError on ``.resolve``.
            if _needs_run_time_resolution(mapping):
                mapping = mapping.resolve(context)
        else:
            mappings = self.value.resolve(context)
            if not isinstance(mappings, Sequence):
                raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
            mapping = mappings[map_index]

        if not isinstance(mapping, Mapping):
            raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

        for key in mapping:
            if not isinstance(key, str):
                raise ValueError(
                    f"expand_kwargs() input dict keys must all be str, "
                    f"but {key!r} is of type {_describe_type(key)}"
                )
        # filter out parse time resolved values from the resolved_oids
        resolved_oids = {id(v) for v in mapping.values() if not _is_parse_time_mappable(v)}

        return mapping, resolved_oids

283 

284 

# Sentinel value: a dict-of-lists expand input with no mapped keyword arguments.
EXPAND_INPUT_EMPTY = DictOfListsExpandInput(value={})