Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/_internal/expandinput.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

135 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20from collections.abc import Iterable, Mapping, Sequence, Sized 

21from typing import TYPE_CHECKING, Any, ClassVar, Union 

22 

23import attrs 

24 

25from airflow.sdk.definitions._internal.mixins import ResolveMixin 

26 

27if TYPE_CHECKING: 

28 from typing import TypeGuard 

29 

30 from airflow.sdk.definitions.xcom_arg import XComArg 

31 from airflow.sdk.types import Operator 

32 

# Union of the two expand-input storage classes declared in this module. The
# members are quoted because both classes are defined further down.
ExpandInput = Union[
    "DictOfListsExpandInput",
    "ListOfDictsExpandInput",
]

# A single keyword argument handed to expand(): an XComArg, a sequence, or a
# dict. Arbitrary mappings are not accepted because the value must be ordered.
OperatorExpandArgument = Union[
    "MappedArgument",
    "XComArg",
    Sequence,
    dict[str, Any],
]

# The sole argument handed to expand_kwargs(): either an XComArg, or a list
# whose elements are each an XComArg or a dict.
OperatorExpandKwargsArgument = Union[
    "XComArg",
    Sequence[Union["XComArg", Mapping[str, Any]]],
]

42 

43 

44class _NotFullyPopulated(RuntimeError): 

45 """ 

46 Raise when an expand input cannot be resolved due to incomplete metadata. 

47 

48 This generally should not happen. The scheduler should have made sure that 

49 a not-yet-ready-to-expand task should not be executed. In the off chance 

50 this gets raised, it will fail the task instance. 

51 """ 

52 

53 def __init__(self, missing: set[str]) -> None: 

54 self.missing = missing 

55 

56 def __str__(self) -> str: 

57 keys = ", ".join(repr(k) for k in sorted(self.missing)) 

58 return f"Failed to populate all mapping metadata; missing: {keys}" 

59 

60 

61# To replace tedious isinstance() checks. 

def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
    """
    Tell whether *v* is a value kind that can be mapped over.

    Accepts ``MappedArgument``, ``XComArg``, mappings and sequences. Strings
    are rejected even though they are sequences.
    """
    from airflow.sdk.definitions.xcom_arg import XComArg

    if isinstance(v, str):
        return False
    return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence))

66 

67 

68# To replace tedious isinstance() checks. 

def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
    """Return True unless *v* is a ``MappedArgument`` or an ``XComArg``."""
    from airflow.sdk.definitions.xcom_arg import XComArg

    run_time_kinds = (MappedArgument, XComArg)
    if isinstance(v, run_time_kinds):
        return False
    return True

73 

74 

75# To replace tedious isinstance() checks. 

def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
    """Return True when *v* is a ``MappedArgument`` or an ``XComArg``."""
    from airflow.sdk.definitions.xcom_arg import XComArg

    run_time_kinds = (MappedArgument, XComArg)
    return isinstance(v, run_time_kinds)

80 

81 

@attrs.define(kw_only=True)
class MappedArgument(ResolveMixin):
    """
    Placeholder for an argument of a mapped task group.

    It closely resembles an XComArg but is resolved through the expand input it
    wraps. It is declared in this module, rather than in the task group module,
    to avoid import cycles.
    """

    _input: ExpandInput = attrs.field()
    _key: str

    @_input.validator
    def _validate_input(self, _, expand_input):
        # Only dict-of-lists inputs are inspected; anything else passes as-is.
        if not isinstance(expand_input, DictOfListsExpandInput):
            return
        if any(isinstance(item, MappedArgument) for item in expand_input.value.values()):
            raise ValueError("Nested Mapped TaskGroups are not yet supported")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield every reference reported by the wrapped expand input."""
        yield from self._input.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> Any:
        """Resolve the wrapped expand input and return the entry stored under this argument's key."""
        resolved, _ = self._input.resolve(context)
        return resolved[self._key]

107 

108 

@attrs.define()
class DictOfListsExpandInput(ResolveMixin):
    """
    Hold the mapped keyword arguments of a mapped operator.

    Instances are built from ``expand(**kwargs)``.
    """

    value: dict[str, OperatorExpandArgument]

    EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"

    def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
        """Generate the (name, value) pairs whose values are available at parse time."""
        pairs = self.value.items()
        return (pair for pair in pairs if _is_parse_time_mappable(pair[1]))

    def _get_map_lengths(
        self, resolved_vals: dict[str, Sized], upstream_map_indexes: dict[str, int]
    ) -> dict[str, int]:
        """
        Build a dict from argument name to map length.

        :param resolved_vals: Resolved values, looked up by name for XComArg arguments.
        :param upstream_map_indexes: Passed through to ``get_task_map_length``.
        :raises _NotFullyPopulated: If any argument is ``None`` or has no length yet.
        """

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?
        def _length_of(name: str, arg: OperatorExpandArgument) -> int | None:
            from airflow.sdk.definitions.xcom_arg import XComArg, get_task_map_length

            if not isinstance(arg, XComArg):
                # Unfortunately a user-defined TypeGuard cannot apply negative type
                # narrowing. https://github.com/python/typing/discussions/1013
                if TYPE_CHECKING:
                    assert isinstance(arg, Sized)
                return len(arg)
            return get_task_map_length(arg, resolved_vals[name], upstream_map_indexes)

        lengths: dict[str, int] = {}
        for name, arg in self.value.items():
            if arg is None:
                continue
            length = _length_of(name, arg)
            if length is not None:
                lengths[name] = length

        # Every declared argument must have produced a length.
        unpopulated = set(self.value).difference(lengths)
        if unpopulated:
            raise _NotFullyPopulated(unpopulated)
        return lengths

    def _expand_mapped_field(self, key: str, value: Any, map_index: int, all_lengths: dict[str, int]) -> Any:
        """
        Select the element of *value* that corresponds to *map_index*.

        The mapped kwargs are walked in reverse declaration order. At each one
        the running index is reduced modulo that kwarg's length to obtain its
        position, then floor-divided by the length before moving on.

        :raises RuntimeError: If a kwarg visited during the walk has a length below 1.
        :raises TypeError: If *value* is neither a sequence nor a dict.
        :raises IndexError: If *value* is a dict with too few entries.
        """
        remaining = map_index
        # Need to use the original user input to retain argument order.
        for mapped_key in reversed(self.value):
            mapped_length = all_lengths[mapped_key]
            if mapped_length < 1:
                raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
            if mapped_key == key:
                found_index = remaining % mapped_length
                break
            remaining //= mapped_length
        else:
            # *key* is not among the mapped kwargs; hand the value back untouched.
            return value

        if isinstance(value, Sequence):
            return value[found_index]
        if not isinstance(value, dict):
            raise TypeError(f"can't map over value of type {type(value)}")
        # A dict is mapped over its items, selected by position.
        for position, pair in enumerate(value.items()):
            if position == found_index:
                return pair
        raise IndexError(f"index {map_index} is over mapped length")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the references of every XComArg among the stored values."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        xcom_args = (arg for arg in self.value.values() if isinstance(arg, XComArg))
        for xcom_arg in xcom_args:
            yield from xcom_arg.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Resolve the kwargs for the task instance found in *context*.

        :param context: Mapping whose ``"ti"`` entry exposes ``map_index``.
        :return: The expanded kwargs, and the ``id()`` of each expanded value
            whose kwarg is not parse-time mappable.
        :raises RuntimeError: If ``map_index`` is ``None`` or negative.
        """
        ti = context["ti"]
        map_index: int | None = ti.map_index
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        upstream_map_indexes = getattr(ti, "_upstream_map_indexes", {})

        # TODO: This initiates one API call for each XComArg. Would it be
        # more efficient to do one single call and unpack the value here?
        resolved: dict[str, Any] = {}
        for name, arg in self.value.items():
            resolved[name] = arg.resolve(context) if _needs_run_time_resolution(arg) else arg

        # Only values that have a length are offered to the length lookup.
        sized_resolved: dict[str, Sized] = {}
        for name, item in resolved.items():
            if isinstance(item, Sized):
                sized_resolved[name] = item

        all_lengths = self._get_map_lengths(sized_resolved, upstream_map_indexes)

        data: dict[str, Any] = {}
        for name, item in resolved.items():
            data[name] = self._expand_mapped_field(name, item, map_index, all_lengths)

        literal_keys = {name for name, _ in self._iter_parse_time_resolved_kwargs()}
        resolved_oids = {id(item) for name, item in data.items() if name not in literal_keys}
        return data, resolved_oids

209 

210 

211def _describe_type(value: Any) -> str: 

212 if value is None: 

213 return "None" 

214 return type(value).__name__ 

215 

216 

@attrs.define()
class ListOfDictsExpandInput(ResolveMixin):
    """
    Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand_kwargs(xcom_arg)``.
    """

    value: OperatorExpandKwargsArgument

    EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield the references of the stored XComArg, or of each XComArg element in the stored list."""
        from airflow.sdk.definitions.xcom_arg import XComArg

        if isinstance(self.value, XComArg):
            yield from self.value.iter_references()
        else:
            for x in self.value:
                if isinstance(x, XComArg):
                    yield from x.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
        """
        Resolve the kwargs dict for the task instance found in *context*.

        :param context: Mapping whose ``"ti"`` entry exposes ``map_index``.
        :return: The mapping selected by ``map_index``, and the ``id()`` of
            each of its values that is not parse-time mappable.
        :raises RuntimeError: If ``map_index`` is ``None`` or negative.
        :raises ValueError: If the resolved input is not a sequence, the
            selected element is not a mapping, or a key is not a ``str``.
        """
        map_index: int | None = context["ti"].map_index
        # Guard against None as well as negative values, so an unexpanded task
        # instance gets the descriptive error instead of a TypeError from the
        # comparison. This matches DictOfListsExpandInput.resolve.
        if map_index is None or map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        mapping: Any = None
        if isinstance(self.value, Sized):
            # A literal list: pick the element, resolving it if it is not already a mapping.
            mapping = self.value[map_index]
            if not isinstance(mapping, Mapping):
                mapping = mapping.resolve(context)
        else:
            # Not sized: resolve the whole value first, then pick the element.
            mappings = self.value.resolve(context)
            if not isinstance(mappings, Sequence):
                raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
            mapping = mappings[map_index]

        if not isinstance(mapping, Mapping):
            raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

        for key in mapping:
            if not isinstance(key, str):
                raise ValueError(
                    f"expand_kwargs() input dict keys must all be str, "
                    f"but {key!r} is of type {_describe_type(key)}"
                )
        # filter out parse time resolved values from the resolved_oids
        resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)}

        return mapping, resolved_oids

268 

269 

# Sentinel value: a dict-of-lists expand input holding no mapped kwargs.
EXPAND_INPUT_EMPTY = DictOfListsExpandInput(value={})