Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/xcom_arg.py: 40%

249 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import contextlib 

21import inspect 

22from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload 

23 

24from sqlalchemy import func 

25from sqlalchemy.orm import Session 

26 

27from airflow.exceptions import XComNotFound 

28from airflow.models.abstractoperator import AbstractOperator 

29from airflow.models.mappedoperator import MappedOperator 

30from airflow.models.taskmixin import DAGNode, DependencyMixin 

31from airflow.models.xcom import XCOM_RETURN_KEY 

32from airflow.utils.context import Context 

33from airflow.utils.edgemodifier import EdgeModifier 

34from airflow.utils.mixins import ResolveMixin 

35from airflow.utils.session import NEW_SESSION, provide_session 

36from airflow.utils.types import NOTSET, ArgNotSet 

37 

38if TYPE_CHECKING: 

39 from airflow.models.dag import DAG 

40 from airflow.models.operator import Operator 

41 

# Callable objects contained by MapXComArg. Users may only pass real
# callables; when a MapXComArg is serialized, each callable is stored as its
# source text, and deserialization deliberately leaves that text as a plain
# string instead of turning it back into a function (those callables are
# arbitrary user code). Hence an element is either a callable or a str.
MapCallables = Sequence[Union[Callable[[Any], Any], str]]

46 

47 

class XComArg(ResolveMixin, DependencyMixin):
    """Handle pointing at an XCom value that another operator pushes.

    Dependency operators are supported in both directions::

        xcomarg >> op
        xcomarg << op
        op >> xcomarg   # By BaseOperator code
        op << xcomarg   # By BaseOperator code

    **Example**: wrap any operator (decorated or regular) to refer to its result::

        any_op = AnyOperator()
        xcomarg = XComArg(any_op)
        # or equivalently
        xcomarg = any_op.output
        my_op = MyOperator()
        my_op >> xcomarg

    The object can also be used in legacy operators through Jinja.

    **Example**: interpolate the reference into any generated string::

        any_op = AnyOperator()
        xcomarg = any_op.output
        op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
        op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")

    :param operator: Operator instance the XComArg refers to.
    :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
        i.e. the referenced operator's return value.
    """

    @overload
    def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg:
        """Called when the user writes ``XComArg(...)`` directly."""

    @overload
    def __new__(cls: type[XComArg]) -> XComArg:
        """Called by Python internals from subclasses."""

    def __new__(cls, *args, **kwargs) -> XComArg:
        # Subclasses are allocated normally; only a direct ``XComArg(...)``
        # call is redirected to the plain implementation.
        if cls is not XComArg:
            return super().__new__(cls)
        return PlainXComArg(*args, **kwargs)

    @staticmethod
    def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]:
        """Yield every XCom reference found inside an arbitrary value.

        A ``ResolveMixin`` reports its own references. Tuples, sets, lists,
        dict values and the ``template_fields`` attributes of an
        ``AbstractOperator`` are searched recursively. Any other value yields
        nothing.
        """
        if isinstance(arg, ResolveMixin):
            yield from arg.iter_references()
            return

        if isinstance(arg, (tuple, set, list)):
            children = arg
        elif isinstance(arg, dict):
            children = arg.values()
        elif isinstance(arg, AbstractOperator):
            # Lazy on purpose: each attribute is read right before it is searched.
            children = (getattr(arg, field) for field in arg.template_fields)
        else:
            return

        for child in children:
            yield from XComArg.iter_xcom_references(child)

    @staticmethod
    def apply_upstream_relationship(op: Operator, arg: Any):
        """Make ``op`` depend on every operator referenced inside ``arg``.

        ``arg`` is searched "deeply" with :meth:`iter_xcom_references`; each
        operator found is set as an upstream of ``op``.
        """
        for upstream, _key in XComArg.iter_xcom_references(arg):
            op.set_upstream(upstream)

    @property
    def roots(self) -> list[DAGNode]:
        """Required by TaskMixin"""
        return [node for node, _key in self.iter_references()]

    @property
    def leaves(self) -> list[DAGNode]:
        """Required by TaskMixin"""
        return [node for node, _key in self.iter_references()]

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Forward to ``set_upstream`` of each referenced operator. Required by TaskMixin."""
        for referenced, _key in self.iter_references():
            referenced.set_upstream(task_or_task_list, edge_modifier)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Forward to ``set_downstream`` of each referenced operator. Required by TaskMixin."""
        for referenced, _key in self.iter_references():
            referenced.set_downstream(task_or_task_list, edge_modifier)

    def _serialize(self) -> dict[str, Any]:
        """Hook for DAG serialization; subclasses must override.

        Implementations return a data dict describing this XComArg and must be
        the inverse of ``_deserialize``. DAG serialization goes through
        ``serialize_xcom_arg`` rather than calling this directly; that helper
        adds the information used to dispatch deserialization to the right class.
        """
        raise NotImplementedError

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        """Hook for DAG deserialization; subclasses must override.

        Implementations rebuild the XComArg from the data dict produced by
        ``_serialize``. Dispatch to the correct class relies on information
        added by ``serialize_xcom_arg``, so implementations do not need to
        validate that the incoming dict has the right keys.
        """
        raise NotImplementedError

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Return a ``MapXComArg`` that applies ``f`` on top of this reference."""
        transforms = [f]
        return MapXComArg(self, transforms)

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Return a ``ZipXComArg`` combining this reference with ``others``."""
        sources = [self, *others]
        return ZipXComArg(sources, fillvalue=fillvalue)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Report the length of the pushed value for task-mapping; subclasses must override.

        The scheduler uses this to decide how many task instances to create
        for a downstream task mapped over this XComArg.

        *None* may be returned if the depended XCom has not been pushed.
        """
        raise NotImplementedError

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Pull the XCom value; subclasses must override.

        This should only be called during ``op.execute()`` with an appropriate
        context (e.g. generated from ``TaskInstance.get_template_context()``).
        ``ResolveMixin`` also has a ``resolve`` protocol; this variant adds the
        optional ``session`` argument that some subclasses need.

        :meta private:
        """
        raise NotImplementedError

205 

206 

class PlainXComArg(XComArg):
    """Reference to exactly one XCom, with no extra semantics layered on top.

    Do not use this class directly; go through XComArg instead. The
    inheritance chain and ``__new__`` are arranged in this slightly convoluted
    way in order to

    a. Let users keep calling XComArg directly for the simple semantics (see
       the base class documentation).
    b. Let ``isinstance(thing, XComArg)`` detect every kind of XCom reference.
    c. Keep properties that only make sense here (including ``__getitem__``
       and ``__str__``) off the other XComArg implementations.

    :meta private:
    """

    def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY):
        self.operator = operator
        self.key = key

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, PlainXComArg):
            return self.operator == other.operator and self.key == other.key
        return NotImplemented

    def __getitem__(self, item: str) -> XComArg:
        """Implements xcomresult['some_result_key']"""
        if isinstance(item, str):
            return PlainXComArg(operator=self.operator, key=item)
        raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")

    def __iter__(self):
        """Refuse iteration explicitly.

        Without this override Python would fall back to calling
        ``__getitem__`` with 0, 1, 2, ... until an ``IndexError`` appears. Our
        ``__getitem__`` never raises that, so a misplaced ``*`` expansion in a
        DAG file would loop forever inside the DAG parser. Raising here makes
        such a DAG fail fast instead.
        """
        raise TypeError("'XComArg' object is not iterable")

    def __repr__(self) -> str:
        if self.key != XCOM_RETURN_KEY:
            return f"XComArg({self.operator!r}, {self.key!r})"
        return f"XComArg({self.operator!r})"

    def __str__(self) -> str:
        """
        Render as a Jinja ``xcom_pull`` expression, for old-style templating.

        **Example**: to use XComArg at BashOperator::

            BashOperator(cmd=f"... { xcomarg } ...")

        :return: the ``{{ task_instance.xcom_pull(...) }}`` template string
        """
        pull_args = [
            f"task_ids='{self.operator.task_id}'",
            f"dag_id='{self.operator.dag_id}'",
        ]
        if self.key is not None:
            pull_args.append(f"key='{self.key}'")
        joined = ", ".join(pull_args)
        # Quadruple braces render as literal "{{" / "}}" inside an f-string.
        return f"{{{{ task_instance.xcom_pull({joined}) }}}}"

    def _serialize(self) -> dict[str, Any]:
        return {"task_id": self.operator.task_id, "key": self.key}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        referenced = dag.get_task(data["task_id"])
        return cls(referenced, data["key"])

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        yield self.operator, self.key

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        if self.key == XCOM_RETURN_KEY:
            return super().map(f)
        raise ValueError("cannot map against non-return XCom")

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        if self.key == XCOM_RETURN_KEY:
            return super().zip(*others, fillvalue=fillvalue)
        raise ValueError("cannot map against non-return XCom")

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        from airflow.models.taskmap import TaskMap
        from airflow.models.xcom import XCom

        upstream = self.operator
        if not isinstance(upstream, MappedOperator):
            # Not a mapped operator: read the length recorded in TaskMap.
            recorded = session.query(TaskMap.length).filter(
                TaskMap.dag_id == upstream.dag_id,
                TaskMap.run_id == run_id,
                TaskMap.task_id == upstream.task_id,
                TaskMap.map_index < 0,
            )
            return recorded.scalar()

        # Mapped operator: count the return-value XCom rows with map_index >= 0.
        pushed = session.query(func.count(XCom.map_index)).filter(
            XCom.dag_id == upstream.dag_id,
            XCom.run_id == run_id,
            XCom.task_id == upstream.task_id,
            XCom.map_index >= 0,
            XCom.key == XCOM_RETURN_KEY,
        )
        return pushed.scalar()

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        ti = context["ti"]
        upstream_task_id = self.operator.task_id
        relevant_indexes = ti.get_relevant_upstream_map_indexes(
            self.operator,
            context["expanded_ti_count"],
            session=session,
        )
        pulled = ti.xcom_pull(
            task_ids=upstream_task_id,
            map_indexes=relevant_indexes,
            key=self.key,
            default=NOTSET,
            session=session,
        )
        if isinstance(pulled, ArgNotSet):
            # Nothing was pushed: a missing return value resolves to None,
            # a missing custom key is an error.
            if self.key == XCOM_RETURN_KEY:
                return None
            raise XComNotFound(ti.dag_id, upstream_task_id, self.key)
        return pulled

344 

345 

def _get_callable_name(f: Callable | str) -> str:
    """Try to "describe" a callable by getting its name.

    :param f: Either a real callable, or the source text of a function (the
        form callables take after DAG serialization).
    :return: The function name when it can be determined, otherwise the
        placeholder ``"<function>"``. This function never raises.
    """
    if callable(f):
        # Not every callable has ``__name__`` (e.g. ``functools.partial``
        # objects or instances that define ``__call__``); use the placeholder
        # for those instead of failing with AttributeError.
        return getattr(f, "__name__", "<function>")
    # Parse the source to find whatever is behind "def". For safety, we don't
    # want to evaluate the code in any meaningful way!
    with contextlib.suppress(Exception):
        kw, name, *_ = f.split(None, 2)
        # The token after "def" still carries the parameter list when there
        # is no space before "(" (e.g. "foo(x):"); keep only the identifier.
        name = name.partition("(")[0]
        if kw == "def" and name:
            return name
    return "<function>"

357 

358 

class _MapResult(Sequence):
    """Read-only view over a pulled XCom value with the map transforms applied per item.

    ``value`` is the underlying sequence or dict; ``callables`` is the chain of
    transforms, each either a real callable or a source-code string.
    """

    def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
        self.value = value
        self.callables = callables

    def __getitem__(self, index: Any) -> Any:
        item = self.value[index]

        # In the worker every entry is a real callable, so run the chain.
        if all(callable(transform) for transform in self.callables):
            for transform in self.callables:
                item = transform(item)
            return item

        # In the scheduler the callables are not available as functions, and
        # running arbitrary code there is not wanted anyway. Build a string
        # describing the call chain for the UI or logs instead.
        for transform in self.callables:
            item = f"{_get_callable_name(transform)}({item})"
        return item

    def __len__(self) -> int:
        return len(self.value)

383 

384 

class MapXComArg(XComArg):
    """An XCom reference with ``map()`` call(s) applied.

    Wraps another XComArg together with a series of "transforms" that convert
    the pulled XCom value.

    :meta private:
    """

    def __init__(self, arg: XComArg, callables: MapCallables) -> None:
        if any(getattr(candidate, "_airflow_is_task_decorator", False) for candidate in callables):
            raise ValueError("map() argument must be a plain function, not a @task operator")
        self.arg = arg
        self.callables = callables

    def __repr__(self) -> str:
        calls = [f".map({_get_callable_name(transform)})" for transform in self.callables]
        suffix = "".join(calls)
        return f"{self.arg!r}{suffix}"

    def _serialize(self) -> dict[str, Any]:
        serialized_arg = serialize_xcom_arg(self.arg)
        sources = []
        for transform in self.callables:
            # Real callables are stored as source text; strings are kept as they are.
            sources.append(inspect.getsource(transform) if callable(transform) else transform)
        return {"arg": serialized_arg, "callables": sources}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        # We are deliberately NOT deserializing the callables. These are shown
        # in the UI, and displaying a function object is useless.
        inner = deserialize_xcom_arg(data["arg"], dag)
        return cls(inner, data["callables"])

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        yield from self.arg.iter_references()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        # Flatten arg.map(f1).map(f2) into one MapXComArg.
        extended = [*self.callables, f]
        return MapXComArg(self.arg, extended)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        return self.arg.get_task_map_length(run_id, session=session)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        pulled = self.arg.resolve(context, session=session)
        if isinstance(pulled, (Sequence, dict)):
            return _MapResult(pulled, self.callables)
        raise ValueError(f"XCom map expects sequence or dict, not {type(pulled).__name__}")

433 

434 

class _ZipResult(Sequence):
    """Read-only view that "zips" several pulled XCom values position by position.

    The length is the shortest input when no ``fillvalue`` is given, and the
    longest input otherwise; inputs that are too short are padded with
    ``fillvalue`` in the latter case.
    """

    def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None:
        self.values = values
        self.fillvalue = fillvalue

    @staticmethod
    def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any:
        """Return ``container[index]``, or ``fillvalue`` if that lookup fails."""
        try:
            return container[index]
        except (IndexError, KeyError):
            return fillvalue

    def __getitem__(self, index: Any) -> Any:
        """Return the tuple of items at ``index`` across all inputs.

        :raises IndexError: if ``index`` is outside the zipped length.
        """
        if isinstance(index, int):
            # Resolve negative indexes against the *zipped* length. Passing
            # them straight to each container would count from the end of
            # every input separately (mixing unrelated positions when the
            # inputs differ in length) and would turn an out-of-range
            # negative index into a tuple of fill values instead of an error.
            length = len(self)
            position = index + length if index < 0 else index
            if not 0 <= position < length:
                raise IndexError(index)
            index = position
        elif index >= len(self):
            raise IndexError(index)
        return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values)

    def __len__(self) -> int:
        lengths = (len(v) for v in self.values)
        if isinstance(self.fillvalue, ArgNotSet):
            return min(lengths)
        return max(lengths)

457 

458 

class ZipXComArg(XComArg):
    """An XCom reference with ``zip()`` applied.

    Built from several XComArg instances; resolves to an iterable that "zips"
    them together like the built-in ``zip()`` (or like
    ``itertools.zip_longest()`` when ``fillvalue`` is provided).
    """

    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
        if not args:
            raise ValueError("At least one input is required")
        self.args = args
        self.fillvalue = fillvalue

    def __repr__(self) -> str:
        reprs = [repr(arg) for arg in self.args]
        head = reprs[0]
        tail = ", ".join(reprs[1:])
        if not isinstance(self.fillvalue, ArgNotSet):
            return f"{head}.zip({tail}, fillvalue={self.fillvalue!r})"
        return f"{head}.zip({tail})"

    def _serialize(self) -> dict[str, Any]:
        payload: dict[str, Any] = {"args": [serialize_xcom_arg(arg) for arg in self.args]}
        # The fill value is only written out when the user supplied one.
        if not isinstance(self.fillvalue, ArgNotSet):
            payload["fillvalue"] = self.fillvalue
        return payload

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        members = [deserialize_xcom_arg(arg, dag) for arg in data["args"]]
        return cls(members, fillvalue=data.get("fillvalue", NOTSET))

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        for member in self.args:
            yield from member.iter_references()

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        lengths = [arg.get_task_map_length(run_id, session=session) for arg in self.args]
        if any(length is None for length in lengths):
            return None  # If any of the referenced XComs is not ready, we are not ready either.
        pick = min if isinstance(self.fillvalue, ArgNotSet) else max
        return pick(lengths)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        resolved = [arg.resolve(context, session=session) for arg in self.args]
        for candidate in resolved:
            if isinstance(candidate, (Sequence, dict)):
                continue
            raise ValueError(f"XCom zip expects sequence or dict, not {type(candidate).__name__}")
        return _ZipResult(resolved, fillvalue=self.fillvalue)

514 

515 

# Registry used by serialize_xcom_arg / deserialize_xcom_arg: maps the value
# of the "type" entry in a serialized XComArg dict to the class handling it.
# The empty-string key stands for PlainXComArg, whose serialized form carries
# no "type" entry at all (deserialization falls back to "" when it is absent).
_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = {
    "": PlainXComArg,
    "map": MapXComArg,
    "zip": ZipXComArg,
}

521 

522 

def serialize_xcom_arg(value: XComArg) -> dict[str, Any]:
    """DAG serialization interface.

    Looks up the exact type of ``value`` in ``_XCOM_ARG_TYPES`` and returns
    the dict from ``value._serialize()``, with a ``"type"`` entry placed first
    unless the registered key is the empty string.

    :param value: The XComArg to serialize.
    :return: A data dict that ``deserialize_xcom_arg`` can dispatch on.
    :raises TypeError: if the type of ``value`` is not registered in
        ``_XCOM_ARG_TYPES``.
    """
    value_type = type(value)
    for key, klass in _XCOM_ARG_TYPES.items():
        if klass == value_type:
            break
    else:
        # A bare next() over a generator would leak StopIteration here, which
        # is uninformative and turns into RuntimeError if it passes through a
        # generator frame. Raise an explicit error instead.
        raise TypeError(f"cannot serialize unregistered XComArg type {value_type.__name__!r}")
    if key:
        return {"type": key, **value._serialize()}
    return value._serialize()

529 

530 

def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg:
    """DAG serialization interface.

    Picks the class registered in ``_XCOM_ARG_TYPES`` under ``data["type"]``
    (the empty-string entry when no ``"type"`` is present) and hands ``data``
    and ``dag`` to that class's ``_deserialize``.
    """
    type_key = data.get("type", "")
    handler = _XCOM_ARG_TYPES[type_key]
    return handler._deserialize(data, dag)