Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/xcom_arg.py: 39%

264 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import contextlib 

21import inspect 

22from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload 

23 

24from sqlalchemy import func, or_ 

25from sqlalchemy.orm import Session 

26 

27from airflow.exceptions import AirflowException, XComNotFound 

28from airflow.models.abstractoperator import AbstractOperator 

29from airflow.models.mappedoperator import MappedOperator 

30from airflow.models.taskmixin import DAGNode, DependencyMixin 

31from airflow.utils.context import Context 

32from airflow.utils.edgemodifier import EdgeModifier 

33from airflow.utils.mixins import ResolveMixin 

34from airflow.utils.session import NEW_SESSION, provide_session 

35from airflow.utils.setup_teardown import SetupTeardownContext 

36from airflow.utils.state import State 

37from airflow.utils.types import NOTSET, ArgNotSet 

38from airflow.utils.xcom import XCOM_RETURN_KEY 

39 

40if TYPE_CHECKING: 

41 from airflow.models.dag import DAG 

42 from airflow.models.operator import Operator 

43 

# Element type of the callable chain held by a MapXComArg. Users only ever
# hand real callables to ``map()``. When a MapXComArg is serialized, each
# callable is stored as its source text, and deserialization deliberately keeps
# that text as a plain string instead of turning it back into code, since the
# callables are arbitrary user code.
MapCallables = Sequence[Union[Callable[[Any], Any], str]]

48 

49 

class XComArg(ResolveMixin, DependencyMixin):
    """Reference to an XCom value pushed from another operator.

    The implementation supports::

        xcomarg >> op
        xcomarg << op
        op >> xcomarg  # By BaseOperator code
        op << xcomarg  # By BaseOperator code

    **Example**: The moment you get a result from any operator (decorated or regular) you can ::

        any_op = AnyOperator()
        xcomarg = XComArg(any_op)
        # or equivalently
        xcomarg = any_op.output
        my_op = MyOperator()
        my_op >> xcomarg

    This object can be used in legacy Operators via Jinja.

    **Example**: You can make this result part of any generated string::

        any_op = AnyOperator()
        xcomarg = any_op.output
        op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
        op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")

    :param operator: Operator instance that the XComArg references.
    :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
        i.e. the referenced operator's return value.
    """

    @overload
    def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg:
        """Called when the user writes ``XComArg(...)`` directly."""

    @overload
    def __new__(cls: type[XComArg]) -> XComArg:
        """Called by Python internals from subclasses."""

    def __new__(cls, *args, **kwargs) -> XComArg:
        # Subclasses are allocated normally; instantiating the base class
        # itself is shorthand for building a plain single-XCom reference.
        if cls is not XComArg:
            return super().__new__(cls)
        return PlainXComArg(*args, **kwargs)

    @staticmethod
    def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]:
        """Return XCom references in an arbitrary value.

        ``arg`` is walked recursively. A ``ResolveMixin`` reports its own
        references; tuples, sets, lists and dict values are descended into;
        for an ``AbstractOperator`` every attribute named in its
        ``template_fields`` is inspected. Anything else yields nothing.
        """
        if isinstance(arg, ResolveMixin):
            yield from arg.iter_references()
            return
        if isinstance(arg, (tuple, set, list)):
            children = arg
        elif isinstance(arg, dict):
            children = arg.values()
        elif isinstance(arg, AbstractOperator):
            children = (getattr(arg, field_name) for field_name in arg.template_fields)
        else:
            return
        for child in children:
            yield from XComArg.iter_xcom_references(child)

    @staticmethod
    def apply_upstream_relationship(op: Operator, arg: Any):
        """Set dependency for XComArgs.

        Every operator referenced by an XComArg found anywhere inside ``arg``
        (searched "deeply", as :meth:`iter_xcom_references` does) is set as an
        upstream of ``op``.
        """
        for upstream, _key in XComArg.iter_xcom_references(arg):
            op.set_upstream(upstream)

    @property
    def roots(self) -> list[DAGNode]:
        """Required by TaskMixin: the operators this XComArg references."""
        return [reference[0] for reference in self.iter_references()]

    @property
    def leaves(self) -> list[DAGNode]:
        """Required by TaskMixin: the operators this XComArg references."""
        return [reference[0] for reference in self.iter_references()]

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
        for referenced_op, _key in self.iter_references():
            referenced_op.set_upstream(task_or_task_list, edge_modifier)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
        for referenced_op, _key in self.iter_references():
            referenced_op.set_downstream(task_or_task_list, edge_modifier)

    def _serialize(self) -> dict[str, Any]:
        """Called by DAG serialization.

        Implementations should be the inverse of ``_deserialize``: return a data
        dict converted from this XComArg derivative. DAG serialization does not
        call this directly but goes through ``serialize_xcom_arg``, which adds
        the extra information needed to dispatch deserialization to the
        correct class.
        """
        raise NotImplementedError()

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        """Called when deserializing a DAG.

        Implementations should be the inverse of ``_serialize``: given the data
        dict produced for this XComArg derivative, rebuild the original
        XComArg. DAG serialization relies on the extra information added by
        ``serialize_xcom_arg`` to dispatch data dicts to the correct
        ``_deserialize`` implementation, so this method does not need to
        validate whether the incoming data contains the correct keys.
        """
        raise NotImplementedError()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Return a reference that applies ``f`` to each item of this value."""
        return MapXComArg(self, [f])

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Return a reference that zips this value with ``others``."""
        return ZipXComArg([self] + list(others), fillvalue=fillvalue)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Inspect length of pushed value for task-mapping.

        This is used to determine how many task instances the scheduler should
        create for a downstream using this XComArg for task-mapping.

        *None* may be returned if the depended XCom has not been pushed.
        """
        raise NotImplementedError()

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Pull XCom value.

        This should only be called during ``op.execute()`` with an appropriate
        context (e.g. generated from ``TaskInstance.get_template_context()``).
        The ``ResolveMixin`` parent mixin also has a ``resolve`` protocol; this
        one adds the optional ``session`` argument that some subclasses need.

        :meta private:
        """
        raise NotImplementedError()

    def __enter__(self):
        # NOTE(review): only PlainXComArg sets ``self.operator`` in this module;
        # on map()/zip() references this line raises AttributeError.
        task = self.operator
        if not (task.is_setup or task.is_teardown):
            raise AirflowException("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(task)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        SetupTeardownContext.set_work_task_roots_and_leaves()

216 

217 

class PlainXComArg(XComArg):
    """Reference to one single XCom without any additional semantics.

    This class should not be accessed directly, but only through XComArg. The
    class inheritance chain and ``__new__`` is implemented in this slightly
    convoluted way because we want to

    a. Allow the user to continue using XComArg directly for the simple
       semantics (see documentation of the base class for details).
    b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom
       references.
    c. Not allow many properties of PlainXComArg (including ``__getitem__`` and
       ``__str__``) to exist on other kinds of XComArg implementations since
       they don't make sense.

    :meta private:
    """

    def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY):
        self.operator = operator
        self.key = key

    def __eq__(self, other: Any) -> bool:
        # NOTE(review): defining __eq__ without __hash__ makes instances
        # unhashable in Python; confirm that is intended.
        if not isinstance(other, PlainXComArg):
            return NotImplemented
        return self.operator == other.operator and self.key == other.key

    def __getitem__(self, item: str) -> XComArg:
        """Implements xcomresult['some_result_key']."""
        if not isinstance(item, str):
            raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")
        return PlainXComArg(operator=self.operator, key=item)

    def __iter__(self):
        """Override iterable protocol to raise error explicitly.

        The default ``__iter__`` implementation in Python calls ``__getitem__``
        with 0, 1, 2, etc. until it hits an ``IndexError``. This does not work
        well with our custom ``__getitem__`` implementation, and results in poor
        DAG-writing experience since a misplaced ``*`` expansion would create an
        infinite loop consuming the entire DAG parser.

        This override catches the error eagerly, so an incorrectly implemented
        DAG fails fast and avoids wasting resources on nonsensical iterating.
        """
        raise TypeError("'XComArg' object is not iterable")

    def __repr__(self) -> str:
        if self.key == XCOM_RETURN_KEY:
            return f"XComArg({self.operator!r})"
        return f"XComArg({self.operator!r}, {self.key!r})"

    def __str__(self) -> str:
        """Backward compatibility for old-style jinja used in Airflow Operators.

        **Example**: to use XComArg at BashOperator::

            BashOperator(bash_command=f"... { xcomarg } ...")

        :return: A Jinja expression that calls ``task_instance.xcom_pull`` with
            this reference's task id, DAG id and (if not None) key.
        """

        def _literal(value: Any) -> str:
            # repr() of a str is a quoted, backslash-escaped literal, which is
            # also valid Jinja string syntax. Hand-written '...' quoting broke
            # on values containing ' or \ (e.g. a key like "it's"). For plain
            # identifiers the output is unchanged: repr("abc") == "'abc'".
            return repr(str(value))

        xcom_pull_kwargs = [
            f"task_ids={_literal(self.operator.task_id)}",
            f"dag_id={_literal(self.operator.dag_id)}",
        ]
        if self.key is not None:
            xcom_pull_kwargs.append(f"key={_literal(self.key)}")

        xcom_pull_str = ", ".join(xcom_pull_kwargs)
        # {{{{ are required for escape {{ in f-string
        xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}"
        return xcom_pull

    def _serialize(self) -> dict[str, Any]:
        return {"task_id": self.operator.task_id, "key": self.key}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        return cls(dag.get_task(data["task_id"]), data["key"])

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        yield self.operator, self.key

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Return a MapXComArg over this reference; only the return-value key is allowed."""
        if self.key != XCOM_RETURN_KEY:
            raise ValueError("cannot map against non-return XCom")
        return super().map(f)

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Return a ZipXComArg over this reference; only the return-value key is allowed."""
        if self.key != XCOM_RETURN_KEY:
            raise ValueError("cannot map against non-return XCom")
        return super().zip(*others, fillvalue=fillvalue)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        from airflow.models.taskinstance import TaskInstance
        from airflow.models.taskmap import TaskMap
        from airflow.models.xcom import XCom

        task = self.operator
        if isinstance(task, MappedOperator):
            # Mapped upstream: only answer once none of its task instances is
            # unfinished; the length is then the number of return-value XComs
            # pushed with map_index >= 0.
            unfinished_ti_count_query = session.query(func.count(TaskInstance.map_index)).filter(
                TaskInstance.dag_id == task.dag_id,
                TaskInstance.run_id == run_id,
                TaskInstance.task_id == task.task_id,
                # Special NULL treatment is needed because 'state' can be NULL,
                # and "NULL IN (...)" never evaluates to true in SQL, so NULL
                # states must be matched explicitly with IS NULL.
                or_(
                    TaskInstance.state.is_(None),
                    TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
                ),
            )
            if unfinished_ti_count_query.scalar():
                return None  # Not all of the expanded tis are done yet.
            query = session.query(func.count(XCom.map_index)).filter(
                XCom.dag_id == task.dag_id,
                XCom.run_id == run_id,
                XCom.task_id == task.task_id,
                XCom.map_index >= 0,
                XCom.key == XCOM_RETURN_KEY,
            )
        else:
            # Unmapped upstream: read the stored length from TaskMap (rows with
            # map_index < 0). No row means scalar() returns None.
            query = session.query(TaskMap.length).filter(
                TaskMap.dag_id == task.dag_id,
                TaskMap.run_id == run_id,
                TaskMap.task_id == task.task_id,
                TaskMap.map_index < 0,
            )
        return query.scalar()

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        from airflow.models.taskinstance import TaskInstance

        ti = context["ti"]
        assert isinstance(ti, TaskInstance), "Wait for AIP-44 implementation to complete"

        task_id = self.operator.task_id
        map_indexes = ti.get_relevant_upstream_map_indexes(
            self.operator,
            context["expanded_ti_count"],
            session=session,
        )
        result = ti.xcom_pull(
            task_ids=task_id,
            map_indexes=map_indexes,
            key=self.key,
            default=NOTSET,
            session=session,
        )
        if not isinstance(result, ArgNotSet):
            return result
        # A missing return value resolves to None; a missing custom key is an error.
        if self.key == XCOM_RETURN_KEY:
            return None
        raise XComNotFound(ti.dag_id, task_id, self.key)

374 

375 

def _get_callable_name(f: Callable | str) -> str:
    """Try to "describe" a callable by getting its name.

    :param f: Either a real callable, or the source text of one (which is what
        a deserialized MapXComArg holds instead of the callable).
    :return: The callable's ``__name__``. For source text, the identifier that
        follows the first ``def`` keyword, excluding the parameter list and
        skipping any decorator lines. ``"<function>"`` if no name can be found.
    """
    import re

    if callable(f):
        # Not every callable has a __name__ (e.g. functools.partial objects or
        # instances of classes defining __call__).
        return getattr(f, "__name__", "<function>")
    if not isinstance(f, str):
        return "<function>"
    # Parse the source to find the identifier behind "def". For safety, we
    # don't want to evaluate the code in any meaningful way! MULTILINE lets the
    # pattern skip decorator lines (e.g. "@task") above the definition, and
    # (\w+) stops at "(" so parameters are not included in the name.
    match = re.search(r"^[ \t]*(?:async[ \t]+)?def[ \t]+(\w+)", f, re.MULTILINE)
    if match is not None:
        return match.group(1)
    return "<function>"

387 

388 

class _MapResult(Sequence):
    """Lazy sequence that applies a chain of map callables to a resolved value.

    Nothing is transformed up front; each item is computed when it is indexed.
    """

    def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
        self.value = value
        self.callables = callables

    def __getitem__(self, index: Any) -> Any:
        item = self.value[index]

        if all(callable(fn) for fn in self.callables):
            # In the worker every entry is a real callable: apply them in order.
            for fn in self.callables:
                item = fn(item)
            return item

        # In the scheduler the callables are unavailable (only their source
        # text), and running arbitrary user code there is undesirable anyway.
        # Build a textual call chain such as "g(f(value))" for the UI or logs.
        for fn in self.callables:
            item = f"{_get_callable_name(fn)}({item})"
        return item

    def __len__(self) -> int:
        return len(self.value)

413 

414 

class MapXComArg(XComArg):
    """An XCom reference with ``map()`` call(s) applied.

    Wraps another XComArg and records a chain of "transform" callables that
    are applied, in order, to each item of the pulled XCom value.

    :meta private:
    """

    def __init__(self, arg: XComArg, callables: MapCallables) -> None:
        # Reject @task-decorated callables: map() takes plain functions only.
        if any(getattr(fn, "_airflow_is_task_decorator", False) for fn in callables):
            raise ValueError("map() argument must be a plain function, not a @task operator")
        self.arg = arg
        self.callables = callables

    def __repr__(self) -> str:
        calls = [f".map({_get_callable_name(fn)})" for fn in self.callables]
        return repr(self.arg) + "".join(calls)

    def _serialize(self) -> dict[str, Any]:
        serialized_arg = serialize_xcom_arg(self.arg)
        # Real callables are stored as their source text; entries that are
        # already strings (e.g. after a previous deserialization) pass through.
        sources = [c if not callable(c) else inspect.getsource(c) for c in self.callables]
        return {"arg": serialized_arg, "callables": sources}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        # The callables are deliberately NOT turned back into functions: they
        # are only displayed in the UI, where a function object is useless.
        wrapped = deserialize_xcom_arg(data["arg"], dag)
        return cls(wrapped, data["callables"])

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        yield from self.arg.iter_references()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        # Extend the existing chain so arg.map(f1).map(f2) stays one flat
        # MapXComArg rather than a nested one.
        return MapXComArg(self.arg, list(self.callables) + [f])

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        # Mapping item-by-item never changes the length of the wrapped value.
        return self.arg.get_task_map_length(run_id, session=session)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        resolved = self.arg.resolve(context, session=session)
        if isinstance(resolved, (Sequence, dict)):
            return _MapResult(resolved, self.callables)
        raise ValueError(f"XCom map expects sequence or dict, not {type(resolved).__name__}")

463 

464 

class _ZipResult(Sequence):
    """Lazy sequence that zips several resolved XCom values together.

    Item ``i`` is a tuple holding the ``i``-th element of every input. Without
    a ``fillvalue`` the length is that of the shortest input (like ``zip()``).
    With one, the length is that of the longest input and missing elements
    become ``fillvalue`` (like ``itertools.zip_longest()``).
    """

    def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None:
        self.values = values
        self.fillvalue = fillvalue

    @staticmethod
    def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any:
        try:
            return container[index]
        except (IndexError, KeyError):
            return fillvalue

    def __getitem__(self, index: Any) -> Any:
        length = len(self)
        if isinstance(index, slice):
            # Support slicing like any other Sequence, e.g. result[1:3].
            return [self[i] for i in range(length)[index]]
        # Normalize negative indexes against the *zipped* length. Passing them
        # through unchanged made each input resolve e.g. -1 against its own
        # length, producing misaligned tuples (and the last real element
        # instead of fillvalue for shorter inputs).
        position = index + length if index < 0 else index
        if not 0 <= position < length:
            raise IndexError(index)
        return tuple(self._get_or_fill(value, position, self.fillvalue) for value in self.values)

    def __len__(self) -> int:
        lengths = (len(v) for v in self.values)
        # default=0: with no inputs at all the zipped result is simply empty.
        if isinstance(self.fillvalue, ArgNotSet):
            return min(lengths, default=0)
        return max(lengths, default=0)

487 

488 

class ZipXComArg(XComArg):
    """An XCom reference with ``zip()`` applied.

    Built from multiple XComArg instances, it presents an iterable that "zips"
    them together like the built-in ``zip()`` (or like
    ``itertools.zip_longest()`` when ``fillvalue`` is provided).
    """

    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
        if not args:
            raise ValueError("At least one input is required")
        self.args = args
        self.fillvalue = fillvalue

    def __repr__(self) -> str:
        head, *tail = self.args
        first = repr(head)
        rest = ", ".join(map(repr, tail))
        if not isinstance(self.fillvalue, ArgNotSet):
            rest = f"{rest}, fillvalue={self.fillvalue!r}"
        return f"{first}.zip({rest})"

    def _serialize(self) -> dict[str, Any]:
        data: dict[str, Any] = {"args": [serialize_xcom_arg(xcom_arg) for xcom_arg in self.args]}
        # "fillvalue" is only written when set; its absence means "not set".
        if not isinstance(self.fillvalue, ArgNotSet):
            data["fillvalue"] = self.fillvalue
        return data

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        inputs = [deserialize_xcom_arg(item, dag) for item in data["args"]]
        return cls(inputs, fillvalue=data.get("fillvalue", NOTSET))

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        for xcom_arg in self.args:
            yield from xcom_arg.iter_references()

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        lengths = [xcom_arg.get_task_map_length(run_id, session=session) for xcom_arg in self.args]
        if any(length is None for length in lengths):
            return None  # If any of the referenced XComs is not ready, we are not ready either.
        # Shortest input without fillvalue (zip), longest with it (zip_longest).
        aggregate = min if isinstance(self.fillvalue, ArgNotSet) else max
        return aggregate(lengths)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        resolved = [xcom_arg.resolve(context, session=session) for xcom_arg in self.args]
        for item in resolved:
            if isinstance(item, (Sequence, dict)):
                continue
            raise ValueError(f"XCom zip expects sequence or dict, not {type(item).__name__}")
        return _ZipResult(resolved, fillvalue=self.fillvalue)

544 

545 

# Registry that DAG (de)serialization uses to dispatch on the XComArg kind.
# The empty key stands for a plain reference: serialize_xcom_arg writes no
# "type" entry for it, and deserialize_xcom_arg falls back to "" when the
# entry is missing.
_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = {
    "": PlainXComArg,  # plain single-XCom reference
    "map": MapXComArg,  # reference with map() applied
    "zip": ZipXComArg,  # reference with zip() applied
}

551 

552 

def serialize_xcom_arg(value: XComArg) -> dict[str, Any]:
    """DAG serialization interface.

    Serialize ``value`` with its own ``_serialize()`` and tag the result with
    the registry key of its exact class under ``"type"``. Classes registered
    under the empty key (plain references) are left untagged.

    :raises TypeError: if ``type(value)`` is not registered in
        ``_XCOM_ARG_TYPES``.
    """
    value_type = type(value)
    # Exact-type match (identity, not isinstance): a subclass of a registered
    # class is not silently serialized as its parent.
    key = next((k for k, klass in _XCOM_ARG_TYPES.items() if klass is value_type), None)
    if key is None:
        # Previously next() without a default leaked a bare StopIteration here,
        # which is confusing and turns into RuntimeError inside generators.
        raise TypeError(f"Cannot serialize unregistered XComArg type {value_type.__name__!r}")
    if key:
        return {"type": key, **value._serialize()}
    return value._serialize()

559 

560 

def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg:
    """DAG serialization interface.

    Inverse of ``serialize_xcom_arg``: choose the XComArg class from the
    ``"type"`` entry of ``data`` (a missing entry means a plain reference) and
    let that class rebuild the object against ``dag``.
    """
    type_key = data.get("type", "")
    xcom_arg_class = _XCOM_ARG_TYPES[type_key]
    return xcom_arg_class._deserialize(data, dag)