Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/_internal/expandinput.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

135 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20from collections.abc import Iterable, Mapping, Sequence, Sized 

21from typing import TYPE_CHECKING, Any, ClassVar, Union 

22 

23import attrs 

24 

25from airflow.sdk.definitions._internal.mixins import ResolveMixin 

26 

27if TYPE_CHECKING: 

28 from typing import TypeGuard 

29 

30 from airflow.sdk.definitions.xcom_arg import XComArg 

31 from airflow.sdk.types import Operator 

32 

# Union of the two storage types for a mapped operator's mapped kwargs:
# DictOfListsExpandInput (created from ``expand(**kwargs)``) and
# ListOfDictsExpandInput (created from ``expand_kwargs(...)``).
ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

# Each keyword argument to expand() can be an XComArg, sequence, or dict (not
# any mapping since we need the value to be ordered). A MappedArgument (the
# stand-in for task-group-mapping arguments) is also accepted.
OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]]

# The single argument of expand_kwargs() can be an XComArg, or a list with each
# element being either an XComArg or a dict.
OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]]

42 

43 

44class _NotFullyPopulated(RuntimeError): 

45 """ 

46 Raise when an expand input cannot be resolved due to incomplete metadata. 

47 

48 This generally should not happen. The scheduler should have made sure that 

49 a not-yet-ready-to-expand task should not be executed. In the off chance 

50 this gets raised, it will fail the task instance. 

51 """ 

52 

53 def __init__(self, missing: set[str]) -> None: 

54 self.missing = missing 

55 

56 def __str__(self) -> str: 

57 keys = ", ".join(repr(k) for k in sorted(self.missing)) 

58 return f"Failed to populate all mapping metadata; missing: {keys}" 

59 

60 

# Shorthand predicate that replaces repeated isinstance() checks at call sites.
def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
    """
    Return whether *v* is a value that can be mapped over.

    ``MappedArgument``, ``XComArg``, any ``Mapping`` and any ``Sequence``
    qualify, except ``str``, which is a Sequence but must not be expanded
    character by character.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(v, str):
        return False
    return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence))

66 

67 

68# To replace tedious isinstance() checks. 

69def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: 

70 from airflow.sdk.definitions.xcom_arg import XComArg 

71 

72 return not isinstance(v, (MappedArgument, XComArg)) 

73 

74 

75# To replace tedious isinstance() checks. 

76def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: 

77 from airflow.sdk.definitions.xcom_arg import XComArg 

78 

79 return isinstance(v, (MappedArgument, XComArg)) 

80 

81 

@attrs.define(kw_only=True)
class MappedArgument(ResolveMixin):
    """
    Placeholder for an argument of a mapped task group.

    It plays a role much like an ``XComArg`` but is resolved differently: the
    wrapped expand input is resolved and the value stored under ``_key`` is
    returned. It is declared here, rather than in the task group module, to
    avoid import cycles.
    """

    _input: ExpandInput = attrs.field()
    _key: str

    @_input.validator
    def _validate_input(self, _attribute, value):
        # Only dict-of-lists inputs are inspected. A MappedArgument among their
        # values would mean one mapped task group nested inside another.
        if not isinstance(value, DictOfListsExpandInput):
            return
        if any(isinstance(arg, MappedArgument) for arg in value.value.values()):
            raise ValueError("Nested Mapped TaskGroups are not yet supported")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the upstream references of the wrapped expand input."""
        yield from self._input.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> Any:
        """Resolve the wrapped expand input and return the entry for this argument's key."""
        resolved_kwargs, _resolved_oids = self._input.resolve(context)
        return resolved_kwargs[self._key]

107 

108 

@attrs.define()
class DictOfListsExpandInput(ResolveMixin):
    """
    Storage type of a mapped operator's mapped kwargs.

    Instances are created from ``expand(**kwargs)``. ``value`` maps each keyword
    argument name to the thing being mapped over: a literal sequence or dict, or
    an ``XComArg`` / ``MappedArgument`` that is resolved at run time.
    """

    value: dict[str, OperatorExpandArgument]

    EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"

    def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
        """Return a lazy iterable of ``(name, value)`` pairs whose values are literals known at parse time."""
        return (pair for pair in self.value.items() if _is_parse_time_mappable(pair[1]))

    def _get_map_lengths(
        self, resolved_vals: dict[str, Sized], upstream_map_indexes: dict[str, int]
    ) -> dict[str, int]:
        """
        Return a dict of argument name to map length.

        XComArg lengths come from ``get_task_map_length``, which is fed the
        already-resolved value from ``resolved_vals``. Every other non-None
        value is measured with ``len()``.

        :raises _NotFullyPopulated: if the length of any argument cannot be
            determined (a ``None`` value, or an XComArg whose length is ``None``).
        """

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?
        def _length_of(name: str, arg: OperatorExpandArgument) -> int | None:
            from airflow.sdk.definitions.xcom_arg import XComArg, get_task_map_length

            if not isinstance(arg, XComArg):
                # A user-defined TypeGuard cannot apply negative type narrowing,
                # so spell it out for type checkers.
                # https://github.com/python/typing/discussions/1013
                if TYPE_CHECKING:
                    assert isinstance(arg, Sized)
                return len(arg)
            return get_task_map_length(arg, resolved_vals[name], upstream_map_indexes)

        lengths: dict[str, int] = {}
        for name, arg in self.value.items():
            if arg is None:
                continue
            length = _length_of(name, arg)
            if length is not None:
                lengths[name] = length

        unknown = set(self.value).difference(lengths)
        if unknown:
            raise _NotFullyPopulated(unknown)
        return lengths

    def _expand_mapped_field(self, key: str, value: Any, map_index: int, all_lengths: dict[str, int]) -> Any:
        """
        Pick the element of *value* that belongs to *map_index*.

        ``map_index`` is decomposed as a mixed-radix number whose digits are the
        per-argument map lengths, taken in the user's original argument order
        with the last argument varying fastest. If *key* is not one of the mapped
        arguments, *value* is returned unchanged. Sequences are indexed directly.
        For a dict, the ``(key, value)`` pair at that position is returned.

        :raises RuntimeError: if any mapped argument has a length below 1.
        :raises TypeError: if *value* is neither a Sequence nor a dict.
        :raises IndexError: if the computed position is past the end of the dict.
        """
        # The original user input is used to retain argument order.
        remaining = map_index
        position = -1
        for mapped_key in reversed(self.value):
            mapped_length = all_lengths[mapped_key]
            if mapped_length < 1:
                raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
            if mapped_key == key:
                position = remaining % mapped_length
                break
            remaining //= mapped_length

        if position < 0:
            return value
        if isinstance(value, Sequence):
            return value[position]
        if not isinstance(value, dict):
            raise TypeError(f"can't map over value of type {type(value)}")
        for offset, (item_key, item_value) in enumerate(value.items()):
            if offset == position:
                return item_key, item_value
        raise IndexError(f"index {map_index} is over mapped length")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield upstream references from every XComArg among the mapped kwargs."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        for arg in self.value.values():
            if not isinstance(arg, XComArg):
                continue
            yield from arg.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Resolve the kwargs for the task instance's map index.

        :param context: Execution context; ``context["ti"].map_index`` selects the element.
        :return: The per-index kwargs, and the ``id()`` of each value whose
            argument was not a parse-time literal.
        :raises RuntimeError: if the map index is ``None`` or negative.
        """
        map_index: int | None = context["ti"].map_index
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        # Use pre-computed upstream map indexes when the TI carries them. An
        # empty dict lets each XComArg compute its map indexes lazily instead.
        upstream_map_indexes = getattr(context["ti"], "_upstream_map_indexes", None) or {}

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?
        resolved: dict[str, Any] = {}
        for name, arg in self.value.items():
            resolved[name] = arg.resolve(context) if _needs_run_time_resolution(arg) else arg

        sized_resolved = {name: val for name, val in resolved.items() if isinstance(val, Sized)}
        all_lengths = self._get_map_lengths(sized_resolved, upstream_map_indexes)

        data = {
            name: self._expand_mapped_field(name, val, map_index, all_lengths)
            for name, val in resolved.items()
        }
        literal_names = {name for name, _ in self._iter_parse_time_resolved_kwargs()}
        resolved_oids = {id(val) for name, val in data.items() if name not in literal_names}
        return data, resolved_oids

211 

212 

213def _describe_type(value: Any) -> str: 

214 if value is None: 

215 return "None" 

216 return type(value).__name__ 

217 

218 

@attrs.define()
class ListOfDictsExpandInput(ResolveMixin):
    """
    Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand_kwargs(xcom_arg)``. ``value`` is either a
    single ``XComArg`` that resolves to a list of dicts, or a literal sequence
    whose elements are dicts or objects resolved at run time into dicts.
    """

    value: OperatorExpandKwargsArgument

    EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield upstream references from the XComArg input, or from each XComArg element of a literal list."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(self.value, XComArg):
            yield from self.value.iter_references()
        else:
            for x in self.value:
                if isinstance(x, XComArg):
                    yield from x.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Resolve the kwargs dict selected by the task instance's map index.

        :param context: Execution context; ``context["ti"].map_index`` selects the element.
        :return: The selected kwargs mapping, and the ``id()`` of each of its
            values that is a MappedArgument/XComArg (not a parse-time literal).
        :raises RuntimeError: if the map index is ``None`` or negative.
        :raises ValueError: if the input is not a list of dicts, or a dict key is not ``str``.
        """
        map_index: int | None = context["ti"].map_index
        # Guard ``None`` explicitly: ``None < 0`` would raise TypeError instead of
        # the intended error. This matches DictOfListsExpandInput.resolve.
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        mapping: Any = None
        if isinstance(self.value, Sized):
            # Literal list: pick this index's element, resolving it if it is not already a mapping.
            mapping = self.value[map_index]
            if not isinstance(mapping, Mapping):
                mapping = mapping.resolve(context)
        else:
            # A single XComArg: resolve the whole list, then pick this index's element.
            mappings = self.value.resolve(context)
            if not isinstance(mappings, Sequence):
                raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
            mapping = mappings[map_index]

        if not isinstance(mapping, Mapping):
            raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

        for key in mapping:
            if not isinstance(key, str):
                raise ValueError(
                    f"expand_kwargs() input dict keys must all be str, "
                    f"but {key!r} is of type {_describe_type(key)}"
                )
        # Filter out parse-time resolved values from the resolved_oids.
        resolved_oids = {id(v) for v in mapping.values() if not _is_parse_time_mappable(v)}

        return mapping, resolved_oids

270 

271 

# Sentinel value: a dict-of-lists expand input with no mapped kwargs at all.
EXPAND_INPUT_EMPTY = DictOfListsExpandInput(value={})