Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/xcom_arg.py: 39%
264 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
18from __future__ import annotations
20import contextlib
21import inspect
22from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload
24from sqlalchemy import func, or_
25from sqlalchemy.orm import Session
27from airflow.exceptions import AirflowException, XComNotFound
28from airflow.models.abstractoperator import AbstractOperator
29from airflow.models.mappedoperator import MappedOperator
30from airflow.models.taskmixin import DAGNode, DependencyMixin
31from airflow.utils.context import Context
32from airflow.utils.edgemodifier import EdgeModifier
33from airflow.utils.mixins import ResolveMixin
34from airflow.utils.session import NEW_SESSION, provide_session
35from airflow.utils.setup_teardown import SetupTeardownContext
36from airflow.utils.state import State
37from airflow.utils.types import NOTSET, ArgNotSet
38from airflow.utils.xcom import XCOM_RETURN_KEY
40if TYPE_CHECKING:
41 from airflow.models.dag import DAG
42 from airflow.models.operator import Operator
# Transforms held by a MapXComArg. Users always hand us real callables; a
# serialized XComArg stores their source text instead, and deserialization
# keeps them as plain strings on purpose -- they are arbitrary user code, so
# they are never turned back into live functions.
MapCallables = Sequence[Union[Callable[[Any], Any], str]]
class XComArg(ResolveMixin, DependencyMixin):
    """Reference to an XCom value pushed from another operator.

    The implementation supports::

        xcomarg >> op
        xcomarg << op
        op >> xcomarg   # By BaseOperator code
        op << xcomarg   # By BaseOperator code

    **Example**: The moment you get a result from any operator (decorated or regular) you can ::

        any_op = AnyOperator()
        xcomarg = XComArg(any_op)
        # or equivalently
        xcomarg = any_op.output
        my_op = MyOperator()
        my_op >> xcomarg

    This object can be used in legacy Operators via Jinja.

    **Example**: You can make this result to be part of any generated string::

        any_op = AnyOperator()
        xcomarg = any_op.output
        op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
        op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")

    :param operator: Operator instance to which the XComArg references.
    :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
        i.e. the referenced operator's return value.
    """

    @overload
    def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg:
        """Called when the user writes ``XComArg(...)`` directly."""

    @overload
    def __new__(cls: type[XComArg]) -> XComArg:
        """Called by Python internals from subclasses."""

    def __new__(cls, *args, **kwargs) -> XComArg:
        # Instantiating the base class directly hands back a PlainXComArg;
        # subclasses go through the regular object construction path.
        if cls is XComArg:
            return PlainXComArg(*args, **kwargs)
        return super().__new__(cls)

    @staticmethod
    def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]:
        """Return XCom references in an arbitrary value.

        Recursively traverse ``arg`` and look for XComArg instances in any
        collection objects, and instances with ``template_fields`` set.

        :param arg: Value to inspect. ``ResolveMixin`` instances report their
            own references; tuples, sets, lists and dict *values* are walked
            element by element; operators are walked through each attribute
            named in ``template_fields``. Anything else yields nothing.
        :return: Iterator of ``(operator, key)`` pairs.
        """
        if isinstance(arg, ResolveMixin):
            yield from arg.iter_references()
        elif isinstance(arg, (tuple, set, list)):
            for elem in arg:
                yield from XComArg.iter_xcom_references(elem)
        elif isinstance(arg, dict):
            for elem in arg.values():
                yield from XComArg.iter_xcom_references(elem)
        elif isinstance(arg, AbstractOperator):
            for attr in arg.template_fields:
                yield from XComArg.iter_xcom_references(getattr(arg, attr))

    @staticmethod
    def apply_upstream_relationship(op: Operator, arg: Any):
        """Set dependency for XComArgs.

        This looks for XComArg objects in ``arg`` "deeply" (looking inside
        collections objects and classes decorated with ``template_fields``), and
        sets the relationship to ``op`` on any found.

        :param op: Operator that should run downstream of every referenced operator.
        :param arg: Value searched via :meth:`iter_xcom_references`.
        """
        for operator, _ in XComArg.iter_xcom_references(arg):
            op.set_upstream(operator)

    @property
    def roots(self) -> list[DAGNode]:
        """Required by TaskMixin."""
        return [op for op, _ in self.iter_references()]

    @property
    def leaves(self) -> list[DAGNode]:
        """Required by TaskMixin."""
        return [op for op, _ in self.iter_references()]

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
        for operator, _ in self.iter_references():
            operator.set_upstream(task_or_task_list, edge_modifier)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
        for operator, _ in self.iter_references():
            operator.set_downstream(task_or_task_list, edge_modifier)

    def _serialize(self) -> dict[str, Any]:
        """Called by DAG serialization.

        The implementation should be the inverse function to ``deserialize``,
        returning a data dict converted from this XComArg derivative. DAG
        serialization does not call this directly, but ``serialize_xcom_arg``
        instead, which adds additional information to dispatch deserialization
        to the correct class.
        """
        raise NotImplementedError()

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        """Called when deserializing a DAG.

        The implementation should be the inverse function to ``serialize``,
        implementing given a data dict converted from this XComArg derivative,
        how the original XComArg should be created. DAG serialization relies on
        additional information added in ``serialize_xcom_arg`` to dispatch data
        dicts to the correct ``_deserialize`` information, so this function does
        not need to validate whether the incoming data contains correct keys.
        """
        raise NotImplementedError()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Return a MapXComArg wrapping this reference with ``f`` as its only transform."""
        return MapXComArg(self, [f])

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Return a ZipXComArg combining this reference with ``others``.

        :param others: Further references to zip with this one.
        :param fillvalue: Forwarded unchanged to ZipXComArg.
        """
        return ZipXComArg([self, *others], fillvalue=fillvalue)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Inspect length of pushed value for task-mapping.

        This is used to determine how many task instances the scheduler should
        create for a downstream using this XComArg for task-mapping.

        *None* may be returned if the depended XCom has not been pushed.
        """
        raise NotImplementedError()

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Pull XCom value.

        This should only be called during ``op.execute()`` with an appropriate
        context (e.g. generated from ``TaskInstance.get_template_context()``).
        Although the ``ResolveMixin`` parent mixin also has a ``resolve``
        protocol, this adds the optional ``session`` argument that some of the
        subclasses need.

        :meta private:
        """
        raise NotImplementedError()

    def __enter__(self):
        """Register the referenced setup/teardown task as the active context.

        :raises AirflowException: If this reference has no single ``operator``
            or that operator is neither a setup nor a teardown task.
        """
        # Only PlainXComArg stores an ``operator`` attribute; map/zip
        # references do not, so look it up defensively and reject them with
        # the same error as any other non-setup/teardown task rather than
        # letting an AttributeError escape.
        operator = getattr(self, "operator", None)
        if operator is None or not (operator.is_setup or operator.is_teardown):
            raise AirflowException("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(operator)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finalize the setup/teardown context; exceptions are not suppressed."""
        SetupTeardownContext.set_work_task_roots_and_leaves()
class PlainXComArg(XComArg):
    """Reference to exactly one XCom entry, with nothing layered on top.

    Do not use this class directly; go through XComArg instead. The slightly
    roundabout inheritance chain and ``__new__`` hook exist so that we can

    a. Keep letting users write ``XComArg(...)`` for the simple semantics
       (see the base class documentation for details).
    b. Have ``isinstance(thing, XComArg)`` recognise every kind of XCom
       reference.
    c. Keep PlainXComArg-only behaviour (such as ``__getitem__`` and
       ``__str__``) off the other XComArg implementations, where it would
       make no sense.

    :meta private:
    """

    def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY):
        self.operator = operator
        self.key = key

    def __eq__(self, other: Any) -> bool:
        # Two plain references match when both operator and key match; any
        # other type is left for Python's fallback comparison.
        if isinstance(other, PlainXComArg):
            return self.operator == other.operator and self.key == other.key
        return NotImplemented

    def __getitem__(self, item: str) -> XComArg:
        """Implements xcomresult['some_result_key']."""
        if isinstance(item, str):
            return PlainXComArg(operator=self.operator, key=item)
        raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")

    def __iter__(self):
        """Refuse iteration explicitly.

        Without this, Python would fall back to calling ``__getitem__`` with
        0, 1, 2, ... until an ``IndexError`` shows up. Our ``__getitem__``
        never raises that, so a stray ``*`` expansion in a DAG file would loop
        forever and tie up the DAG parser.

        Raising here makes such a mistake fail immediately instead.
        """
        raise TypeError("'XComArg' object is not iterable")

    def __repr__(self) -> str:
        if self.key != XCOM_RETURN_KEY:
            return f"XComArg({self.operator!r}, {self.key!r})"
        return f"XComArg({self.operator!r})"

    def __str__(self) -> str:
        """
        Render the Jinja expression that pulls this XCom (legacy operator support).

        **Example**: to use XComArg at BashOperator::

            BashOperator(cmd=f"... { xcomarg } ...")

        :return: A ``{{ task_instance.xcom_pull(...) }}`` template string.
        """
        pull_args = [
            f"task_ids='{self.operator.task_id}'",
            f"dag_id='{self.operator.dag_id}'",
        ]
        if self.key is not None:
            pull_args.append(f"key='{self.key}'")
        xcom_pull_str = ", ".join(pull_args)
        # Quadrupled braces collapse to a literal "{{" / "}}" in an f-string.
        return f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}"

    def _serialize(self) -> dict[str, Any]:
        task_id = self.operator.task_id
        return {"task_id": task_id, "key": self.key}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        operator = dag.get_task(data["task_id"])
        return cls(operator, data["key"])

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        # A plain reference points at exactly one (operator, key) pair.
        yield (self.operator, self.key)

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        if self.key == XCOM_RETURN_KEY:
            return super().map(f)
        raise ValueError("cannot map against non-return XCom")

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        if self.key == XCOM_RETURN_KEY:
            return super().zip(*others, fillvalue=fillvalue)
        raise ValueError("cannot map against non-return XCom")

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        from airflow.models.taskinstance import TaskInstance
        from airflow.models.taskmap import TaskMap
        from airflow.models.xcom import XCom

        task = self.operator

        # Non-mapped upstream: the length is read from its TaskMap row.
        if not isinstance(task, MappedOperator):
            return (
                session.query(TaskMap.length)
                .filter(
                    TaskMap.dag_id == task.dag_id,
                    TaskMap.run_id == run_id,
                    TaskMap.task_id == task.task_id,
                    TaskMap.map_index < 0,
                )
                .scalar()
            )

        # Mapped upstream: first make sure no expanded task instance is still
        # unfinished.
        pending_count = (
            session.query(func.count(TaskInstance.map_index))
            .filter(
                TaskInstance.dag_id == task.dag_id,
                TaskInstance.run_id == run_id,
                TaskInstance.task_id == task.task_id,
                # 'state' can be NULL, and NULL never matches an "IN" list in
                # SQL, so the NULL case has to be spelled out separately.
                or_(
                    TaskInstance.state.is_(None),
                    TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
                ),
            )
            .scalar()
        )
        if pending_count:
            return None  # Not all of the expanded tis are done yet.

        # All done: count the return-value XComs pushed by the expanded tis.
        return (
            session.query(func.count(XCom.map_index))
            .filter(
                XCom.dag_id == task.dag_id,
                XCom.run_id == run_id,
                XCom.task_id == task.task_id,
                XCom.map_index >= 0,
                XCom.key == XCOM_RETURN_KEY,
            )
            .scalar()
        )

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        from airflow.models.taskinstance import TaskInstance

        ti = context["ti"]
        assert isinstance(ti, TaskInstance), "Wait for AIP-44 implementation to complete"

        upstream = self.operator
        task_id = upstream.task_id
        relevant_indexes = ti.get_relevant_upstream_map_indexes(
            upstream,
            context["expanded_ti_count"],
            session=session,
        )
        pulled = ti.xcom_pull(
            task_ids=task_id,
            map_indexes=relevant_indexes,
            key=self.key,
            default=NOTSET,
            session=session,
        )
        if isinstance(pulled, ArgNotSet):
            # Nothing was pushed: a missing return value resolves to None,
            # while any other missing key is an error.
            if self.key == XCOM_RETURN_KEY:
                return None
            raise XComNotFound(ti.dag_id, task_id, self.key)
        return pulled
def _get_callable_name(f: Callable | str) -> str:
    """Try to "describe" a callable by getting its name.

    :param f: Either a live callable, or the source text of a function.
    :return: The callable's ``__name__``; for source text, the identifier that
        follows a leading ``def``; ``"<function>"`` when neither is available.
    """
    if callable(f):
        # Not every callable carries ``__name__`` (``functools.partial``
        # objects, instances defining ``__call__``), so fall back to the
        # placeholder instead of raising AttributeError.
        return getattr(f, "__name__", "<function>")
    # Parse the source to find whatever is behind "def". For safety, we don't
    # want to evaluate the code in any meaningful way!
    with contextlib.suppress(Exception):
        kw, rest = f.lstrip().split(None, 1)
        if kw == "def":
            # Cut at the opening parenthesis so "foo(x):" is reported as
            # "foo" rather than dragging the signature along.
            name = rest.partition("(")[0].strip()
            if name:
                return name
    return "<function>"
class _MapResult(Sequence):
    """Sequence view over a value that applies the map transforms on item access."""

    def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
        self.value = value
        self.callables = callables

    def __getitem__(self, index: Any) -> Any:
        item = self.value[index]

        # Worker side: every transform is a real callable, so run the chain.
        if all(callable(transform) for transform in self.callables):
            for transform in self.callables:
                item = transform(item)
            return item

        # Scheduler side: the transforms are not live functions here (and we
        # would not want to run arbitrary code anyway), so build a string that
        # describes the call chain for the UI or logs instead.
        for described in self.callables:
            item = f"{_get_callable_name(described)}({item})"
        return item

    def __len__(self) -> int:
        return len(self.value)
class MapXComArg(XComArg):
    """XCom reference carrying one or more ``map()`` transforms.

    Wraps another XComArg and a list of "transforms" that are applied to the
    value pulled from it.

    :meta private:
    """

    def __init__(self, arg: XComArg, callables: MapCallables) -> None:
        # Objects flagged as @task-decorated are rejected up front.
        if any(getattr(candidate, "_airflow_is_task_decorator", False) for candidate in callables):
            raise ValueError("map() argument must be a plain function, not a @task operator")
        self.arg = arg
        self.callables = callables

    def __repr__(self) -> str:
        calls = [f".map({_get_callable_name(f)})" for f in self.callables]
        suffix = "".join(calls)
        return f"{self.arg!r}{suffix}"

    def _serialize(self) -> dict[str, Any]:
        serialized_arg = serialize_xcom_arg(self.arg)
        # Live callables are stored as their source text; entries that are
        # already strings pass through untouched.
        sources = []
        for c in self.callables:
            sources.append(inspect.getsource(c) if callable(c) else c)
        return {"arg": serialized_arg, "callables": sources}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        # The callables deliberately stay as strings. They only get shown in
        # the UI, where a function object would be useless.
        inner = deserialize_xcom_arg(data["arg"], dag)
        return cls(inner, data["callables"])

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        # Mapping adds no references of its own; defer to the wrapped arg.
        yield from self.arg.iter_references()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        # arg.map(f1).map(f2) collapses into a single MapXComArg.
        chained = [*self.callables, f]
        return MapXComArg(self.arg, chained)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        return self.arg.get_task_map_length(run_id, session=session)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        resolved = self.arg.resolve(context, session=session)
        if isinstance(resolved, (Sequence, dict)):
            return _MapResult(resolved, self.callables)
        raise ValueError(f"XCom map expects sequence or dict, not {type(resolved).__name__}")
class _ZipResult(Sequence):
    """Sequence view that zips several containers together on item access."""

    def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None:
        self.values = values
        self.fillvalue = fillvalue

    @staticmethod
    def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any:
        # A missing position/key in one container is replaced by the fill value.
        try:
            item = container[index]
        except (IndexError, KeyError):
            return fillvalue
        return item

    def __getitem__(self, index: Any) -> Any:
        if index >= len(self):
            raise IndexError(index)
        row = [self._get_or_fill(container, index, self.fillvalue) for container in self.values]
        return tuple(row)

    def __len__(self) -> int:
        # No fill value: stop at the shortest input. With a fill value: run to
        # the longest input.
        pick = min if isinstance(self.fillvalue, ArgNotSet) else max
        return pick(len(container) for container in self.values)
class ZipXComArg(XComArg):
    """An XCom reference with ``zip()`` applied.

    This is constructed from multiple XComArg instances, and presents an
    iterable that "zips" them together like the built-in ``zip()`` (and
    ``itertools.zip_longest()`` if ``fillvalue`` is provided).

    :param args: References to zip together; at least one is required.
    :param fillvalue: Value substituted for missing items. When left as
        ``NOTSET`` the result is as long as the shortest input; otherwise it
        is as long as the longest one.
    """

    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
        if not args:
            raise ValueError("At least one input is required")
        self.args = args
        self.fillvalue = fillvalue

    def __repr__(self) -> str:
        first, *call_args = (repr(arg) for arg in self.args)
        if not isinstance(self.fillvalue, ArgNotSet):
            call_args.append(f"fillvalue={self.fillvalue!r}")
        # Join the positional reprs and the optional fillvalue together so a
        # single-input zip with a fillvalue renders as "a.zip(fillvalue=...)"
        # instead of "a.zip(, fillvalue=...)".
        return f"{first}.zip({', '.join(call_args)})"

    def _serialize(self) -> dict[str, Any]:
        args = [serialize_xcom_arg(arg) for arg in self.args]
        # The "fillvalue" key is only emitted when one was actually given.
        if isinstance(self.fillvalue, ArgNotSet):
            return {"args": args}
        return {"args": args, "fillvalue": self.fillvalue}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        return cls(
            [deserialize_xcom_arg(arg, dag) for arg in data["args"]],
            fillvalue=data.get("fillvalue", NOTSET),
        )

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        for arg in self.args:
            yield from arg.iter_references()

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args)
        ready_lengths = [length for length in all_lengths if length is not None]
        if len(ready_lengths) != len(self.args):
            return None  # If any of the referenced XComs is not ready, we are not ready either.
        if isinstance(self.fillvalue, ArgNotSet):
            return min(ready_lengths)
        return max(ready_lengths)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        values = [arg.resolve(context, session=session) for arg in self.args]
        for value in values:
            if not isinstance(value, (Sequence, dict)):
                raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
        return _ZipResult(values, fillvalue=self.fillvalue)
# Registry used to dispatch (de)serialization: maps the "type" tag stored in a
# serialized XComArg to the class that handles it. The empty tag stands for
# the plain reference, whose serialized form carries no "type" entry at all.
_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = {
    "": PlainXComArg,
    "map": MapXComArg,
    "zip": ZipXComArg,
}
def serialize_xcom_arg(value: XComArg) -> dict[str, Any]:
    """DAG serialization interface.

    :param value: The XComArg to serialize; its exact type must be registered
        in ``_XCOM_ARG_TYPES``.
    :return: The dict produced by ``value._serialize()``, with a ``"type"``
        entry added when the registered tag is non-empty.
    :raises TypeError: If the type of ``value`` is not registered.
    """
    value_type = type(value)
    for key, klass in _XCOM_ARG_TYPES.items():
        if klass is value_type:
            break
    else:
        # Fail with a descriptive error rather than leaking the StopIteration
        # a bare ``next()`` would raise for an unregistered type.
        raise TypeError(f"cannot serialize XComArg of unregistered type {value_type.__name__!r}")
    if key:
        return {"type": key, **value._serialize()}
    return value._serialize()
def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg:
    """DAG serialization interface.

    Looks up the handler class by the ``"type"`` entry of ``data`` (a missing
    entry is treated as the empty tag) and delegates to its ``_deserialize``.
    """
    type_tag = data.get("type", "")
    handler = _XCOM_ARG_TYPES[type_tag]
    return handler._deserialize(data, dag)