Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/expandinput.py: 29%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

160 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections.abc 

21import functools 

22import operator 

23from collections.abc import Sized 

24from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, NamedTuple, Sequence, Union 

25 

26import attr 

27 

28from airflow.utils.mixins import ResolveMixin 

29from airflow.utils.session import NEW_SESSION, provide_session 

30 

31if TYPE_CHECKING: 

32 from sqlalchemy.orm import Session 

33 

34 from airflow.models.operator import Operator 

35 from airflow.models.xcom_arg import XComArg 

36 from airflow.typing_compat import TypeGuard 

37 from airflow.utils.context import Context 

38 

# Either of the two expand-input storage classes defined below. They are named
# as strings because the classes are defined later in this module.
ExpandInput = Union[
    "DictOfListsExpandInput",
    "ListOfDictsExpandInput",
]

# Accepted value for each keyword argument of expand(): a MappedArgument, an
# XComArg, a sequence, or a dict. Arbitrary mappings are not accepted because
# the value needs to be ordered.
OperatorExpandArgument = Union[
    "MappedArgument",
    "XComArg",
    Sequence,
    Dict[str, Any],
]

# Accepted value for the single argument of expand_kwargs(): an XComArg, or a
# sequence whose items are each an XComArg or a str-keyed mapping.
OperatorExpandKwargsArgument = Union[
    "XComArg",
    Sequence[Union["XComArg", Mapping[str, Any]]],
]

48 

49 

@attr.define(kw_only=True)
class MappedArgument(ResolveMixin):
    """Placeholder for one argument of a mapped task group.

    It plays a role much like an XComArg, but its value is obtained by
    resolving the whole expand input it belongs to and picking out its own
    key. It is declared here, not in the task group module, to avoid import
    cycles.
    """

    _input: ExpandInput  # expand input this argument is taken from
    _key: str  # name of this argument inside ``_input``

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        # TODO (AIP-42): Implement run-time task map length inspection. This is
        # needed when we implement task mapping inside a mapped task group.
        raise NotImplementedError

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield whatever references the underlying expand input yields."""
        yield from self._input.iter_references()

    @provide_session
    def resolve(self, context: Context, *, session: Session = NEW_SESSION) -> Any:
        """Resolve the expand input for ``context`` and return this argument's value."""
        resolved_kwargs, _resolved_oids = self._input.resolve(context, session=session)
        return resolved_kwargs[self._key]

73 

74 

# Centralizes the isinstance() checks deciding what expand() can map over.
def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
    """Return True if ``v`` can be expanded over.

    MappedArgument, XComArg, any Mapping, and any Sequence qualify, except
    ``str``, which is rejected even though it is a Sequence.
    """
    from airflow.models.xcom_arg import XComArg

    if isinstance(v, str):
        return False
    return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence))

80 

81 

# Centralizes the isinstance() check for values that are known at parse time.
def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
    """Return True unless ``v`` is a MappedArgument or an XComArg.

    Those two only obtain a value at run time; everything else is a literal.
    """
    from airflow.models.xcom_arg import XComArg

    if isinstance(v, (MappedArgument, XComArg)):
        return False
    return True

87 

88 

89# To replace tedious isinstance() checks. 

90def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: 

91 from airflow.models.xcom_arg import XComArg 

92 

93 return isinstance(v, (MappedArgument, XComArg)) 

94 

95 

class NotFullyPopulated(RuntimeError):
    """Raised when not all mapping metadata (e.g. map lengths) can be populated.

    This is generally because not all upstream tasks have finished when the
    lookup is attempted.

    :param missing: Names of the arguments whose metadata could not be
        populated; stored as ``self.missing``.
    """

    def __init__(self, missing: set[str]) -> None:
        self.missing = missing

    def __str__(self) -> str:
        # Sorted so the message is deterministic regardless of set ordering.
        quoted_names = [repr(name) for name in sorted(self.missing)]
        return "Failed to populate all mapping metadata; missing: " + ", ".join(quoted_names)

109 

110 

class DictOfListsExpandInput(NamedTuple):
    """Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand(**kwargs)``. ``value`` maps each keyword
    argument name to the value it is expanded over. One task instance exists
    per combination of elements, so the total count is the product of all the
    lengths.
    """

    value: dict[str, OperatorExpandArgument]

    def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
        """Generate kwargs with values available on parse-time."""
        return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v))

    def get_parse_time_mapped_ti_count(self) -> int:
        """Return the mapped task instance count computed from literal values only.

        Returns 0 if there are no mapped kwargs.

        :raises NotFullyPopulated: if any kwarg is an XComArg or MappedArgument,
            whose length is only known at run time; ``missing`` holds those names.
        """
        if not self.value:
            return 0
        literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()]
        if len(literal_values) != len(self.value):
            literal_keys = (k for k, _ in self._iter_parse_time_resolved_kwargs())
            raise NotFullyPopulated(set(self.value).difference(literal_keys))
        return functools.reduce(operator.mul, literal_values, 1)

    def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]:
        """Return dict of argument name to map length.

        Literal values contribute ``len(value)``. XComArg and MappedArgument
        values are asked via ``get_task_map_length(run_id, session=session)``.

        :raises NotFullyPopulated: if any argument's length is not known right
            now (``get_task_map_length`` returned None, e.g. because the upstream
            task has not finished). ``missing`` holds those argument names, so a
            returned dict always contains every argument.
        """

        # TODO: This initiates one database call for each XComArg. Would it be
        # more efficient to do one single db call and unpack the value here?
        def _get_length(v: OperatorExpandArgument) -> int | None:
            if _needs_run_time_resolution(v):
                return v.get_task_map_length(run_id, session=session)
            # Unfortunately a user-defined TypeGuard cannot apply negative type
            # narrowing. https://github.com/python/typing/discussions/1013
            if TYPE_CHECKING:
                assert isinstance(v, Sized)
            return len(v)

        map_lengths_iterator = ((k, _get_length(v)) for k, v in self.value.items())

        map_lengths = {k: v for k, v in map_lengths_iterator if v is not None}
        if len(map_lengths) < len(self.value):
            raise NotFullyPopulated(set(self.value).difference(map_lengths))
        return map_lengths

    def get_total_map_length(self, run_id: str, *, session: Session) -> int:
        """Return the number of mapped task instances for ``run_id``.

        This is the product of all argument lengths, or 0 if there are no mapped
        kwargs.

        :raises NotFullyPopulated: if any argument's length is not known yet.
        """
        if not self.value:
            return 0
        lengths = self._get_map_lengths(run_id, session=session)
        return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)

    def _expand_mapped_field(
        self,
        key: str,
        value: Any,
        context: Context,
        *,
        session: Session,
        map_lengths_cache: dict[str, int] | None = None,
    ) -> Any:
        """Return the element of ``value`` that the current task instance maps over.

        :param key: Name of the kwarg being expanded.
        :param value: The kwarg's value. It is resolved first if it is an XComArg
            or MappedArgument.
        :param context: Run context; ``context["ti"].map_index`` and
            ``context["run_id"]`` are read.
        :param session: Database session passed to resolution and length lookups.
        :param map_lengths_cache: Optional dict shared across calls within one
            ``resolve()``. If given, it is filled from ``_get_map_lengths`` on
            first use and reused afterwards, so lengths are looked up once
            rather than once per kwarg. If None, lengths are looked up on every
            call (the previous behavior).
        :return: ``value[i]`` for a sequence, or the ``(key, value)`` item at
            position ``i`` for a dict. ``value`` is returned unchanged if ``key``
            is not one of this input's kwargs.
        :raises RuntimeError: if ``map_index`` is negative (not expanded), or if a
            kwarg examined before ``key`` has a length below 1.
        :raises TypeError: if the value is neither a sequence nor a dict.
        :raises IndexError: if a dict value is shorter than the computed index.
        """
        if _needs_run_time_resolution(value):
            value = value.resolve(context, session=session)
        map_index = context["ti"].map_index
        if map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")
        if map_lengths_cache is None:
            all_lengths = self._get_map_lengths(context["run_id"], session=session)
        else:
            # Fill the shared cache only on first use; later kwargs reuse it
            # instead of re-querying every XComArg's length.
            if not map_lengths_cache:
                map_lengths_cache.update(self._get_map_lengths(context["run_id"], session=session))
            all_lengths = map_lengths_cache

        def _find_index_for_this_field(index: int) -> int:
            # Need to use the original user input to retain argument order.
            # The last kwarg varies fastest: walking kwargs from last to first,
            # strip off each one's position from ``index`` until ``key`` is hit.
            for mapped_key in reversed(self.value):
                mapped_length = all_lengths[mapped_key]
                if mapped_length < 1:
                    raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
                if mapped_key == key:
                    return index % mapped_length
                index //= mapped_length
            return -1

        found_index = _find_index_for_this_field(map_index)
        if found_index < 0:
            return value
        if isinstance(value, collections.abc.Sequence):
            return value[found_index]
        if not isinstance(value, dict):
            raise TypeError(f"can't map over value of type {type(value)}")
        for i, (k, v) in enumerate(value.items()):
            if i == found_index:
                return k, v
        raise IndexError(f"index {map_index} is over mapped length")

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield references from every kwarg whose value is an XComArg."""
        from airflow.models.xcom_arg import XComArg

        for x in self.value.values():
            if isinstance(x, XComArg):
                yield from x.iter_references()

    def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Resolve every mapped kwarg to its value for the current task instance.

        :return: ``(data, resolved_oids)``. ``data`` maps each kwarg name to its
            expanded value. ``resolved_oids`` holds the ``id()`` of each value
            whose kwarg was not a parse-time literal.
        """
        # Shared by all kwargs of this call, so map lengths (one DB query per
        # XComArg kwarg) are fetched once instead of once per kwarg.
        map_lengths_cache: dict[str, int] = {}
        data = {
            k: self._expand_mapped_field(k, v, context, session=session, map_lengths_cache=map_lengths_cache)
            for k, v in self.value.items()
        }
        literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
        resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys}
        return data, resolved_oids

206 

207 

208def _describe_type(value: Any) -> str: 

209 if value is None: 

210 return "None" 

211 return type(value).__name__ 

212 

213 

class ListOfDictsExpandInput(NamedTuple):
    """Storage type of a mapped operator's mapped kwargs.

    This is created from ``expand_kwargs(xcom_arg)``. ``value`` is either a
    single XComArg, expected to resolve to a list of dicts, or a literal
    sequence whose items are each a mapping or an XComArg resolving to one.
    Item ``i`` supplies the kwargs of the task instance with map index ``i``.
    """

    value: OperatorExpandKwargsArgument

    def get_parse_time_mapped_ti_count(self) -> int:
        """Return the number of mapped task instances known at parse time.

        :raises NotFullyPopulated: if ``value`` has no length, i.e. it is not a
            literal sequence and can only be sized at run time.
        """
        if not isinstance(self.value, collections.abc.Sized):
            raise NotFullyPopulated({"expand_kwargs() argument"})
        return len(self.value)

    def get_total_map_length(self, run_id: str, *, session: Session) -> int:
        """Return the number of mapped task instances for ``run_id``.

        A literal sequence reports its own length. Otherwise the length comes
        from ``value.get_task_map_length``.

        :raises NotFullyPopulated: if that run-time length is not known (None).
        """
        if not isinstance(self.value, collections.abc.Sized):
            run_time_length = self.value.get_task_map_length(run_id, session=session)
            if run_time_length is None:
                raise NotFullyPopulated({"expand_kwargs() argument"})
            return run_time_length
        return len(self.value)

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Yield references of the XComArg argument, or of each XComArg item."""
        from airflow.models.xcom_arg import XComArg

        # Treat a lone XComArg as a one-item sequence so both shapes share a loop.
        candidates = (self.value,) if isinstance(self.value, XComArg) else self.value
        for candidate in candidates:
            if isinstance(candidate, XComArg):
                yield from candidate.iter_references()

    def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Return the kwargs mapping for the current task instance.

        :return: ``(mapping, resolved_oids)``, where ``resolved_oids`` holds the
            ``id()`` of each value in ``mapping`` that is not a parse-time
            literal (i.e. is an XComArg or MappedArgument).
        :raises RuntimeError: if the task instance's ``map_index`` is negative.
        :raises ValueError: if the resolved data is not a sequence of mappings,
            or a mapping has a non-str key.
        """
        map_index = context["ti"].map_index
        if map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        mapping: Any
        if not isinstance(self.value, collections.abc.Sized):
            # The whole argument resolves at run time into the list of dicts.
            resolved_list = self.value.resolve(context, session)
            if not isinstance(resolved_list, collections.abc.Sequence):
                raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(resolved_list)}")
            mapping = resolved_list[map_index]
        else:
            mapping = self.value[map_index]
            if not isinstance(mapping, collections.abc.Mapping):
                # This item is itself resolvable into a single dict.
                mapping = mapping.resolve(context, session)

        if not isinstance(mapping, collections.abc.Mapping):
            raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

        non_str_keys = [key for key in mapping if not isinstance(key, str)]
        if non_str_keys:
            first_bad_key = non_str_keys[0]
            raise ValueError(
                "expand_kwargs() input dict keys must all be str, "
                f"but {first_bad_key!r} is of type {_describe_type(first_bad_key)}"
            )

        # Only values still needing run-time resolution count as resolved;
        # parse-time literals are left out.
        resolved_oids = {id(val) for val in mapping.values() if not _is_parse_time_mappable(val)}
        return mapping, resolved_oids

274 

275 

# Sentinel value: an expand input with no mapped kwargs at all.
EXPAND_INPUT_EMPTY = DictOfListsExpandInput({})

# Kind string <-> expand-input class; get_map_type_key() looks up the string
# for a class and create_expand_input() builds an instance from a string.
_EXPAND_INPUT_TYPES: dict[str, type] = {
    "dict-of-lists": DictOfListsExpandInput,
    "list-of-dicts": ListOfDictsExpandInput,
}

282 

283 

def get_map_type_key(expand_input: ExpandInput) -> str:
    """Return the kind string registered for ``expand_input``'s exact class.

    Raises StopIteration if the class is not registered in ``_EXPAND_INPUT_TYPES``.
    """
    wanted_cls = type(expand_input)
    matching_kinds = (kind for kind, cls in _EXPAND_INPUT_TYPES.items() if cls == wanted_cls)
    return next(matching_kinds)

286 

287 

def create_expand_input(kind: str, value: Any) -> ExpandInput:
    """Build the expand input registered under ``kind`` wrapping ``value``.

    Raises KeyError if ``kind`` is not a key of ``_EXPAND_INPUT_TYPES``.
    """
    input_cls = _EXPAND_INPUT_TYPES[kind]
    return input_cls(value)

289 return _EXPAND_INPUT_TYPES[kind](value)