Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/xcom_arg.py: 40%
249 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
18from __future__ import annotations
20import contextlib
21import inspect
22from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload
24from sqlalchemy import func
25from sqlalchemy.orm import Session
27from airflow.exceptions import XComNotFound
28from airflow.models.abstractoperator import AbstractOperator
29from airflow.models.mappedoperator import MappedOperator
30from airflow.models.taskmixin import DAGNode, DependencyMixin
31from airflow.models.xcom import XCOM_RETURN_KEY
32from airflow.utils.context import Context
33from airflow.utils.edgemodifier import EdgeModifier
34from airflow.utils.mixins import ResolveMixin
35from airflow.utils.session import NEW_SESSION, provide_session
36from airflow.utils.types import NOTSET, ArgNotSet
38if TYPE_CHECKING:
39 from airflow.models.dag import DAG
40 from airflow.models.operator import Operator
# The sequence of callables held by a MapXComArg. Users may only supply real
# callables; the serialized form of an XComArg stores them as source strings,
# and they stay strings after deserialization because they are arbitrary user
# code that must not be evaluated.
MapCallables = Sequence[
    Union[
        Callable[[Any], Any],
        str,
    ]
]
class XComArg(ResolveMixin, DependencyMixin):
    """Handle on an XCom value that another operator pushes.

    Instances can take part in dependency declarations::

        xcomarg >> op
        xcomarg << op
        op >> xcomarg   # By BaseOperator code
        op << xcomarg   # By BaseOperator code

    **Example**: as soon as an operator (decorated or regular) exists, its
    result can be referenced::

        any_op = AnyOperator()
        xcomarg = XComArg(any_op)
        # or equivalently
        xcomarg = any_op.output
        my_op = MyOperator()
        my_op >> xcomarg

    The object also works in legacy operators through Jinja.

    **Example**: the reference can be embedded in any generated string::

        any_op = AnyOperator()
        xcomarg = any_op.output
        op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
        op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")

    :param operator: Operator instance the XComArg refers to.
    :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
        i.e. the referenced operator's return value.
    """

    @overload
    def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg:
        """Signature used when the user writes ``XComArg(...)`` directly."""

    @overload
    def __new__(cls: type[XComArg]) -> XComArg:
        """Signature used by Python internals when a subclass is instantiated."""

    def __new__(cls, *args, **kwargs) -> XComArg:
        # Subclasses are allocated the normal way; only a direct ``XComArg(...)``
        # call is redirected to the plain implementation.
        if cls is not XComArg:
            return super().__new__(cls)
        return PlainXComArg(*args, **kwargs)

    @staticmethod
    def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]:
        """Yield every XCom reference reachable from an arbitrary value.

        ``arg`` is walked recursively: resolvable objects report their own
        references, tuples/sets/lists and dict values are descended into, and
        operators are inspected through the attributes named in their
        ``template_fields``. Any other value yields nothing.
        """
        if isinstance(arg, ResolveMixin):
            yield from arg.iter_references()
            return

        if isinstance(arg, (tuple, set, list)):
            children = arg
        elif isinstance(arg, dict):
            children = arg.values()
        elif isinstance(arg, AbstractOperator):
            children = (getattr(arg, field_name) for field_name in arg.template_fields)
        else:
            return

        for child in children:
            yield from XComArg.iter_xcom_references(child)

    @staticmethod
    def apply_upstream_relationship(op: Operator, arg: Any):
        """Make ``op`` downstream of every operator referenced inside ``arg``.

        References are discovered with :meth:`iter_xcom_references`, so XComArg
        objects nested in collections or in templated operator fields are
        found as well.
        """
        for upstream, _key in XComArg.iter_xcom_references(arg):
            op.set_upstream(upstream)

    @property
    def roots(self) -> list[DAGNode]:
        """Operators referenced by this object. Required by TaskMixin."""
        return [operator for operator, _key in self.iter_references()]

    @property
    def leaves(self) -> list[DAGNode]:
        """Operators referenced by this object. Required by TaskMixin."""
        return [operator for operator, _key in self.iter_references()]

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Forward to ``set_upstream`` of each referenced operator. Required by TaskMixin."""
        for referenced, _key in self.iter_references():
            referenced.set_upstream(task_or_task_list, edge_modifier)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Forward to ``set_downstream`` of each referenced operator. Required by TaskMixin."""
        for referenced, _key in self.iter_references():
            referenced.set_downstream(task_or_task_list, edge_modifier)

    def _serialize(self) -> dict[str, Any]:
        """Convert this XComArg derivative to a data dict for DAG serialization.

        Subclasses implement this as the inverse of ``_deserialize``. DAG
        serialization does not call it directly; it goes through
        ``serialize_xcom_arg``, which adds the information needed to dispatch
        deserialization to the right class.
        """
        raise NotImplementedError()

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        """Rebuild an XComArg derivative from its data dict while deserializing a DAG.

        Subclasses implement this as the inverse of ``_serialize``. Dispatch to
        the correct class relies on what ``serialize_xcom_arg`` recorded, so an
        implementation does not need to validate that ``data`` has the right
        keys.
        """
        raise NotImplementedError()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Return a reference that applies ``f`` to each item of this one."""
        return MapXComArg(self, [f])

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Return a reference that zips this one together with ``others``."""
        members = [self]
        members.extend(others)
        return ZipXComArg(members, fillvalue=fillvalue)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Report the length of the pushed value for task-mapping.

        The scheduler uses this to decide how many task instances to create
        for a downstream task mapped over this XComArg.

        *None* may be returned if the XCom depended on has not been pushed.
        """
        raise NotImplementedError()

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Pull the XCom value.

        Only call this during ``op.execute()`` with a suitable context (e.g.
        one produced by ``TaskInstance.get_template_context()``). The
        ``ResolveMixin`` parent also declares a ``resolve`` protocol; this
        variant adds the optional ``session`` argument some subclasses need.

        :meta private:
        """
        raise NotImplementedError()
class PlainXComArg(XComArg):
    """Reference to exactly one XCom, with no extra semantics attached.

    Do not use this class directly; go through XComArg instead. The slightly
    convoluted inheritance chain and ``__new__`` arrangement exist so that we

    a. Let users keep calling XComArg directly for the simple semantics (see
       the base class documentation).
    b. Let ``isinstance(thing, XComArg)`` recognise every kind of XCom
       reference.
    c. Keep members that only make sense here (including ``__getitem__`` and
       ``__str__``) off the other XComArg implementations.

    :meta private:
    """

    def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY):
        self.key = key
        self.operator = operator

    def __eq__(self, other: Any) -> bool:
        """Two plain references are equal when operator and key both match."""
        if isinstance(other, PlainXComArg):
            return self.operator == other.operator and self.key == other.key
        return NotImplemented

    def __getitem__(self, item: str) -> XComArg:
        """Implements xcomresult['some_result_key']"""
        if isinstance(item, str):
            return PlainXComArg(operator=self.operator, key=item)
        raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")

    def __iter__(self):
        """Refuse iteration explicitly.

        Without this override Python would fall back to calling
        ``__getitem__`` with 0, 1, 2, ... until an ``IndexError`` appears.
        Our ``__getitem__`` never raises one, so a stray ``*`` expansion in a
        DAG file would loop forever inside the DAG parser.

        Raising here makes such a DAG fail fast instead.
        """
        raise TypeError("'XComArg' object is not iterable")

    def __repr__(self) -> str:
        # The key is only shown when it differs from the default return key.
        return (
            f"XComArg({self.operator!r})"
            if self.key == XCOM_RETURN_KEY
            else f"XComArg({self.operator!r}, {self.key!r})"
        )

    def __str__(self) -> str:
        """
        Render as a Jinja ``xcom_pull`` expression for old-style templated fields.

        **Example**: to use XComArg at BashOperator::

            BashOperator(cmd=f"... { xcomarg } ...")

        :return: the template string that pulls this XCom at render time.
        """
        task = self.operator
        pull_args = [f"task_ids='{task.task_id}'", f"dag_id='{task.dag_id}'"]
        if self.key is not None:
            pull_args.append(f"key='{self.key}'")
        joined = ", ".join(pull_args)
        # Quadruple braces produce a literal ``{{`` / ``}}`` from the f-string.
        return f"{{{{ task_instance.xcom_pull({joined}) }}}}"

    def _serialize(self) -> dict[str, Any]:
        """Return the task id and key identifying this reference."""
        return {"task_id": self.operator.task_id, "key": self.key}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        """Rebuild the reference by looking the task up in ``dag``."""
        operator = dag.get_task(data["task_id"])
        return cls(operator, data["key"])

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        """Yield the single (operator, key) pair this object points at."""
        yield self.operator, self.key

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Like the base ``map``, but only allowed on the return-value XCom."""
        if self.key == XCOM_RETURN_KEY:
            return super().map(f)
        raise ValueError("cannot map against non-return XCom")

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Like the base ``zip``, but only allowed on the return-value XCom."""
        if self.key == XCOM_RETURN_KEY:
            return super().zip(*others, fillvalue=fillvalue)
        raise ValueError("cannot map against non-return XCom")

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Query the database for the length to use when mapping over this XCom.

        For a mapped upstream the return-key XCom rows with a non-negative
        ``map_index`` are counted; otherwise the ``length`` stored on the
        TaskMap row with a negative ``map_index`` is read.
        """
        from airflow.models.taskmap import TaskMap
        from airflow.models.xcom import XCom

        upstream = self.operator
        if not isinstance(upstream, MappedOperator):
            length_query = session.query(TaskMap.length).filter(
                TaskMap.dag_id == upstream.dag_id,
                TaskMap.run_id == run_id,
                TaskMap.task_id == upstream.task_id,
                TaskMap.map_index < 0,
            )
            return length_query.scalar()

        count_query = session.query(func.count(XCom.map_index)).filter(
            XCom.dag_id == upstream.dag_id,
            XCom.run_id == run_id,
            XCom.task_id == upstream.task_id,
            XCom.map_index >= 0,
            XCom.key == XCOM_RETURN_KEY,
        )
        return count_query.scalar()

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Pull the referenced XCom through the task instance in ``context``.

        A missing value resolves to *None* for the return key and raises
        ``XComNotFound`` for any other key.
        """
        ti = context["ti"]
        task_id = self.operator.task_id
        relevant_indexes = ti.get_relevant_upstream_map_indexes(
            self.operator,
            context["expanded_ti_count"],
            session=session,
        )
        pulled = ti.xcom_pull(
            task_ids=task_id,
            map_indexes=relevant_indexes,
            key=self.key,
            default=NOTSET,
            session=session,
        )
        # NOTSET as the default lets us tell "nothing pushed" apart from a
        # pushed value of None.
        if isinstance(pulled, ArgNotSet):
            if self.key != XCOM_RETURN_KEY:
                raise XComNotFound(ti.dag_id, task_id, self.key)
            return None
        return pulled
def _get_callable_name(f: Callable | str) -> str:
    """Try to "describe" a callable by getting its name.

    :param f: Either a callable object, or the source text of a function.
    :return: The callable's ``__name__``, or the identifier that follows
        ``def`` in the source text. ``"<function>"`` is returned whenever no
        name can be determined.
    """
    if callable(f):
        # Not every callable carries ``__name__`` (``functools.partial``
        # objects and instances defining ``__call__`` do not).
        return getattr(f, "__name__", "<function>")
    # Parse the source to find whatever is behind "def". For safety, we don't
    # want to evaluate the code in any meaningful way!
    with contextlib.suppress(Exception):
        kw, rest = f.lstrip().split(None, 1)
        if kw == "def":
            # Cut at the parameter list so "foo(x, y):" is reported as "foo".
            name = rest.partition("(")[0].strip()
            if name:
                return name
    return "<function>"
class _MapResult(Sequence):
    """Lazy view over a pulled value with the ``map()`` callables applied per item."""

    def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
        self.callables = callables
        self.value = value

    def __getitem__(self, index: Any) -> Any:
        """Return the item at ``index`` after running it through the call chain."""
        item = self.value[index]

        if all(callable(candidate) for candidate in self.callables):
            # In the worker every entry is a real callable, so apply them in order.
            for transform in self.callables:
                item = transform(item)
            return item

        # In the scheduler we have no access to the actual callables, and would
        # not want to run arbitrary code anyway. Build a string describing the
        # call chain for the UI or logs instead.
        for entry in self.callables:
            item = f"{_get_callable_name(entry)}({item})"
        return item

    def __len__(self) -> int:
        return len(self.value)
class MapXComArg(XComArg):
    """An XCom reference with ``map()`` call(s) applied.

    Wraps another XComArg together with a chain of "transforms" that are
    applied to the pulled XCom value.

    :meta private:
    """

    def __init__(self, arg: XComArg, callables: MapCallables) -> None:
        if any(getattr(candidate, "_airflow_is_task_decorator", False) for candidate in callables):
            raise ValueError("map() argument must be a plain function, not a @task operator")
        self.arg = arg
        self.callables = callables

    def __repr__(self) -> str:
        suffixes = [f".map({_get_callable_name(entry)})" for entry in self.callables]
        return f"{self.arg!r}{''.join(suffixes)}"

    def _serialize(self) -> dict[str, Any]:
        """Serialize the wrapped reference, and each callable as its source text."""
        sources = []
        for entry in self.callables:
            # Entries that are already strings are stored unchanged.
            sources.append(inspect.getsource(entry) if callable(entry) else entry)
        return {"arg": serialize_xcom_arg(self.arg), "callables": sources}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        # We are deliberately NOT deserializing the callables. These are shown
        # in the UI, and displaying a function object is useless.
        inner = deserialize_xcom_arg(data["arg"], dag)
        return cls(inner, data["callables"])

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        """Yield whatever the wrapped reference refers to."""
        yield from self.arg.iter_references()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        # Flatten arg.map(f1).map(f2) into one MapXComArg.
        chain = [*self.callables, f]
        return MapXComArg(self.arg, chain)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Delegate to the wrapped reference."""
        return self.arg.get_task_map_length(run_id, session=session)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Resolve the wrapped reference and wrap the result for lazy mapping."""
        pulled = self.arg.resolve(context, session=session)
        if isinstance(pulled, (Sequence, dict)):
            return _MapResult(pulled, self.callables)
        raise ValueError(f"XCom map expects sequence or dict, not {type(pulled).__name__}")
class _ZipResult(Sequence):
    """Lazy view that zips several pulled values together index by index."""

    def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None:
        self.fillvalue = fillvalue
        self.values = values

    @staticmethod
    def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any:
        """Return ``container[index]``, or ``fillvalue`` when the lookup misses."""
        try:
            item = container[index]
        except (IndexError, KeyError):
            return fillvalue
        return item

    def __getitem__(self, index: Any) -> Any:
        """Return a tuple holding the ``index``-th item of every zipped value."""
        if index >= len(self):
            raise IndexError(index)
        row = [self._get_or_fill(member, index, self.fillvalue) for member in self.values]
        return tuple(row)

    def __len__(self) -> int:
        # Without a fill value the shortest input bounds the result; with one,
        # the longest input does.
        lengths = (len(member) for member in self.values)
        pick = min if isinstance(self.fillvalue, ArgNotSet) else max
        return pick(lengths)
class ZipXComArg(XComArg):
    """An XCom reference with ``zip()`` applied.

    Built from several XComArg instances; resolving it gives an iterable that
    "zips" them like the built-in ``zip()`` (or like
    ``itertools.zip_longest()`` when ``fillvalue`` is provided).
    """

    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
        if not args:
            raise ValueError("At least one input is required")
        self.fillvalue = fillvalue
        self.args = args

    def __repr__(self) -> str:
        members = iter(self.args)
        head = repr(next(members))
        tail = ", ".join(map(repr, members))
        if not isinstance(self.fillvalue, ArgNotSet):
            return f"{head}.zip({tail}, fillvalue={self.fillvalue!r})"
        return f"{head}.zip({tail})"

    def _serialize(self) -> dict[str, Any]:
        """Serialize every member; ``fillvalue`` is included only when it was set."""
        data: dict[str, Any] = {"args": [serialize_xcom_arg(member) for member in self.args]}
        if not isinstance(self.fillvalue, ArgNotSet):
            data["fillvalue"] = self.fillvalue
        return data

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        """Rebuild from serialized members; a missing ``fillvalue`` means NOTSET."""
        members = [deserialize_xcom_arg(member, dag) for member in data["args"]]
        return cls(members, fillvalue=data.get("fillvalue", NOTSET))

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        """Yield the references of every member, in order."""
        for member in self.args:
            yield from member.iter_references()

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Combine the members' lengths: shortest without a fill value, longest with one."""
        lengths = [member.get_task_map_length(run_id, session=session) for member in self.args]
        if any(length is None for length in lengths):
            return None  # If any of the referenced XComs is not ready, we are not ready either.
        pick = min if isinstance(self.fillvalue, ArgNotSet) else max
        return pick(lengths)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Resolve every member and wrap the results for lazy zipping."""
        pulled = [member.resolve(context, session=session) for member in self.args]
        for item in pulled:
            if isinstance(item, (Sequence, dict)):
                continue
            raise ValueError(f"XCom zip expects sequence or dict, not {type(item).__name__}")
        return _ZipResult(pulled, fillvalue=self.fillvalue)
# Registry mapping the "type" tag stored in a serialized XComArg to the class
# that handles it. The empty tag stands for the plain reference.
_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = {
    "": PlainXComArg,
    "map": MapXComArg,
    "zip": ZipXComArg,
}
def serialize_xcom_arg(value: XComArg) -> dict[str, Any]:
    """DAG serialization interface.

    Looks up the exact class of ``value`` in ``_XCOM_ARG_TYPES`` and returns
    the dict produced by ``value._serialize()``. When the registered tag is
    non-empty it is added under the ``"type"`` key so deserialization can
    dispatch to the right class.

    :param value: The XComArg instance to serialize.
    :return: The serialized data dict.
    :raises TypeError: If the class of ``value`` is not registered.
    """
    value_type = type(value)
    for key, klass in _XCOM_ARG_TYPES.items():
        # Exact class match on purpose: subclasses are not implicitly
        # serialized as their parent.
        if klass is value_type:
            break
    else:
        # A bare ``next()`` here would leak StopIteration with no hint about
        # which type was missing from the registry.
        raise TypeError(f"cannot serialize unregistered XComArg type {value_type.__name__!r}")

    data = value._serialize()
    if key:
        return {"type": key, **data}
    return data
def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg:
    """DAG serialization interface.

    Picks the class registered in ``_XCOM_ARG_TYPES`` for the ``"type"`` tag
    of ``data`` (the empty tag when absent) and lets it rebuild the object.
    """
    type_tag = data.get("type", "")
    handler = _XCOM_ARG_TYPES[type_tag]
    return handler._deserialize(data, dag)