Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/xcom_arg.py: 38%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
18from __future__ import annotations
20import contextlib
21import inspect
22from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Mapping, Sequence, Union, overload
24from sqlalchemy import func, or_, select
26from airflow.exceptions import AirflowException, XComNotFound
27from airflow.models.abstractoperator import AbstractOperator
28from airflow.models.mappedoperator import MappedOperator
29from airflow.models.taskmixin import DependencyMixin
30from airflow.utils.db import exists_query
31from airflow.utils.mixins import ResolveMixin
32from airflow.utils.session import NEW_SESSION, provide_session
33from airflow.utils.setup_teardown import SetupTeardownContext
34from airflow.utils.state import State
35from airflow.utils.trigger_rule import TriggerRule
36from airflow.utils.types import NOTSET, ArgNotSet
37from airflow.utils.xcom import XCOM_RETURN_KEY
39if TYPE_CHECKING:
40 from sqlalchemy.orm import Session
42 from airflow.models.baseoperator import BaseOperator
43 from airflow.models.dag import DAG
44 from airflow.models.operator import Operator
45 from airflow.models.taskmixin import DAGNode
46 from airflow.utils.context import Context
47 from airflow.utils.edgemodifier import EdgeModifier
# Element type held by MapXComArg.callables. Users may only hand real callables
# to ``map()``; when a DAG is serialized each callable is turned into its source
# code string instead, because that source is arbitrary user code and is kept
# as plain text for safety rather than being turned back into a function.
MapCallables = Sequence[Union[str, Callable[[Any], Any]]]
class XComArg(ResolveMixin, DependencyMixin):
    """Reference to an XCom value pushed from another operator.

    The implementation supports::

        xcomarg >> op
        xcomarg << op
        op >> xcomarg  # By BaseOperator code
        op << xcomarg  # By BaseOperator code

    **Example**: The moment you get a result from any operator (decorated or regular) you can ::

        any_op = AnyOperator()
        xcomarg = XComArg(any_op)
        # or equivalently
        xcomarg = any_op.output
        my_op = MyOperator()
        my_op >> xcomarg

    This object can be used in legacy Operators via Jinja.

    **Example**: You can make this result to be part of any generated string::

        any_op = AnyOperator()
        xcomarg = any_op.output
        op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
        op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")

    :param operator: Operator instance to which the XComArg references.
    :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
        i.e. the referenced operator's return value.
    """

    @overload
    def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg:
        """Execute when the user writes ``XComArg(...)`` directly."""

    @overload
    def __new__(cls: type[XComArg]) -> XComArg:
        """Execute by Python internals from subclasses."""

    def __new__(cls, *args, **kwargs) -> XComArg:
        # Calling ``XComArg(...)`` directly acts as a factory for the plain
        # single-XCom reference; subclasses are allocated normally.
        if cls is XComArg:
            return PlainXComArg(*args, **kwargs)
        return super().__new__(cls)

    @staticmethod
    def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]:
        """Return XCom references in an arbitrary value.

        Recursively traverse ``arg`` and look for XComArg instances in any
        collection objects, and instances with ``template_fields`` set.

        Tuples, sets and lists are searched element-wise, dicts by their values
        only (keys are not inspected), and operators through the attributes
        named in their ``template_fields``.

        :return: An iterator of ``(operator, xcom_key)`` pairs.
        """
        if isinstance(arg, ResolveMixin):
            yield from arg.iter_references()
        elif isinstance(arg, (tuple, set, list)):
            for elem in arg:
                yield from XComArg.iter_xcom_references(elem)
        elif isinstance(arg, dict):
            for elem in arg.values():
                yield from XComArg.iter_xcom_references(elem)
        elif isinstance(arg, AbstractOperator):
            for attr in arg.template_fields:
                yield from XComArg.iter_xcom_references(getattr(arg, attr))

    @staticmethod
    def apply_upstream_relationship(op: Operator, arg: Any):
        """Set dependency for XComArgs.

        This looks for XComArg objects in ``arg`` "deeply" (looking inside
        collections objects and classes decorated with ``template_fields``), and
        sets the relationship to ``op`` on any found.
        """
        for operator, _ in XComArg.iter_xcom_references(arg):
            op.set_upstream(operator)

    @property
    def roots(self) -> list[DAGNode]:
        """Required by TaskMixin."""
        return [op for op, _ in self.iter_references()]

    @property
    def leaves(self) -> list[DAGNode]:
        """Required by TaskMixin."""
        return [op for op, _ in self.iter_references()]

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
        for operator, _ in self.iter_references():
            operator.set_upstream(task_or_task_list, edge_modifier)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
        for operator, _ in self.iter_references():
            operator.set_downstream(task_or_task_list, edge_modifier)

    def _serialize(self) -> dict[str, Any]:
        """
        Serialize this XComArg into a data dict.

        The implementation should be the inverse function to ``_deserialize``,
        returning a data dict converted from this XComArg derivative. DAG
        serialization does not call this directly, but ``serialize_xcom_arg``
        instead, which adds additional information to dispatch deserialization
        to the correct class.
        """
        raise NotImplementedError()

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        """
        Rebuild an XComArg from a data dict produced by ``_serialize``.

        The implementation should be the inverse function to ``_serialize``,
        implementing given a data dict converted from this XComArg derivative,
        how the original XComArg should be created. DAG serialization relies on
        additional information added in ``serialize_xcom_arg`` to dispatch data
        dicts to the correct ``_deserialize`` information, so this function does
        not need to validate whether the incoming data contains correct keys.

        :param dag: The DAG used to look up referenced tasks.
        """
        raise NotImplementedError()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Return a reference that applies ``f`` to each item of this XCom value."""
        return MapXComArg(self, [f])

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Return a reference that zips this XCom value with ``others``.

        :param fillvalue: If given, shorter inputs are padded with it (like
            ``itertools.zip_longest``); otherwise the shortest input wins.
        """
        return ZipXComArg([self, *others], fillvalue=fillvalue)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Inspect length of pushed value for task-mapping.

        This is used to determine how many task instances the scheduler should
        create for a downstream using this XComArg for task-mapping.

        *None* may be returned if the depended XCom has not been pushed.
        """
        raise NotImplementedError()

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Pull XCom value.

        This should only be called during ``op.execute()`` with an appropriate
        context (e.g. generated from ``TaskInstance.get_template_context()``).
        Although the ``ResolveMixin`` parent mixin also has a ``resolve``
        protocol, this adds the optional ``session`` argument that some of the
        subclasses need.

        :meta private:
        """
        raise NotImplementedError()

    def __enter__(self):
        """Enter a setup/teardown context for the referenced operator.

        :return: The ``SetupTeardownContext`` class.
        :raises AirflowException: If this XComArg does not reference a single
            operator (e.g. a ``map()`` or ``zip()`` result), or if that operator
            is neither a setup nor a teardown task.
        """
        # Only the plain reference carries a single ``operator`` attribute;
        # map()/zip() results would otherwise fail with an opaque AttributeError.
        operator = getattr(self, "operator", None)
        if operator is None or (not operator.is_setup and not operator.is_teardown):
            raise AirflowException("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(operator)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finalize the setup/teardown context; exceptions are not suppressed."""
        SetupTeardownContext.set_work_task_roots_and_leaves()
class PlainXComArg(XComArg):
    """Reference to one single XCom without any additional semantics.

    This class should not be accessed directly, but only through XComArg. The
    class inheritance chain and ``__new__`` is implemented in this slightly
    convoluted way because we want to

    a. Allow the user to continue using XComArg directly for the simple
       semantics (see documentation of the base class for details).
    b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom
       references.
    c. Not allow many properties of PlainXComArg (including ``__getitem__`` and
       ``__str__``) to exist on other kinds of XComArg implementations since
       they don't make sense.

    :meta private:
    """

    def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY):
        # The task whose XCom is referenced, and the key pulled from it.
        self.operator = operator
        self.key = key

    def __eq__(self, other: Any) -> bool:
        """Two plain references are equal when operator and key both match."""
        # NOTE(review): this class defines __eq__ without __hash__, so Python
        # makes its instances unhashable (cannot be used in sets/dict keys).
        if not isinstance(other, PlainXComArg):
            return NotImplemented
        return self.operator == other.operator and self.key == other.key

    def __getitem__(self, item: str) -> XComArg:
        """Implement xcomresult['some_result_key'].

        :param item: The XCom key to reference on the same operator.
        :raises ValueError: If ``item`` is not a string.
        """
        if not isinstance(item, str):
            raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")
        return PlainXComArg(operator=self.operator, key=item)

    def __iter__(self):
        """Override iterable protocol to raise error explicitly.

        The default ``__iter__`` implementation in Python calls ``__getitem__``
        with 0, 1, 2, etc. until it hits an ``IndexError``. This does not work
        well with our custom ``__getitem__`` implementation, and results in poor
        DAG-writing experience since a misplaced ``*`` expansion would create an
        infinite loop consuming the entire DAG parser.

        This override catches the error eagerly, so an incorrectly implemented
        DAG fails fast and avoids wasting resources on nonsensical iterating.
        """
        raise TypeError("'XComArg' object is not iterable")

    def __repr__(self) -> str:
        # The key is omitted from the repr when it is the default return-value key.
        if self.key == XCOM_RETURN_KEY:
            return f"XComArg({self.operator!r})"
        return f"XComArg({self.operator!r}, {self.key!r})"

    def __str__(self) -> str:
        """
        Backward compatibility for old-style jinja used in Airflow Operators.

        **Example**: to use XComArg at BashOperator::

            BashOperator(cmd=f"... { xcomarg } ...")

        :return: A Jinja template string calling ``task_instance.xcom_pull``
            with this reference's task id, dag id and (when not None) key.
        """
        xcom_pull_kwargs = [
            f"task_ids='{self.operator.task_id}'",
            f"dag_id='{self.operator.dag_id}'",
        ]
        if self.key is not None:
            xcom_pull_kwargs.append(f"key='{self.key}'")

        xcom_pull_str = ", ".join(xcom_pull_kwargs)
        # {{{{ are required for escape {{ in f-string
        xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}"
        return xcom_pull

    def _serialize(self) -> dict[str, Any]:
        # Only the task id is stored; _deserialize looks the task up in the DAG.
        return {"task_id": self.operator.task_id, "key": self.key}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        return cls(dag.get_task(data["task_id"]), data["key"])

    @property
    def is_setup(self) -> bool:
        """Proxy to the referenced operator's ``is_setup`` flag."""
        return self.operator.is_setup

    @is_setup.setter
    def is_setup(self, val: bool):
        self.operator.is_setup = val

    @property
    def is_teardown(self) -> bool:
        """Proxy to the referenced operator's ``is_teardown`` flag."""
        return self.operator.is_teardown

    @is_teardown.setter
    def is_teardown(self, val: bool):
        self.operator.is_teardown = val

    @property
    def on_failure_fail_dagrun(self) -> bool:
        """Proxy to the referenced operator's ``on_failure_fail_dagrun`` flag."""
        return self.operator.on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, val: bool):
        self.operator.on_failure_fail_dagrun = val

    def as_setup(self) -> DependencyMixin:
        """Mark the referenced operator as a setup task; return self for chaining."""
        for operator, _ in self.iter_references():
            operator.is_setup = True
        return self

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
        on_failure_fail_dagrun=NOTSET,
    ):
        """Mark the referenced operator as a teardown task; return self for chaining.

        The operator's trigger rule is set to ``ALL_DONE_SETUP_SUCCESS``.

        :param setups: Optional task(s) to mark as setup tasks and set upstream
            of the teardown operator.
        :param on_failure_fail_dagrun: If given, copied onto the operator.
        """
        for operator, _ in self.iter_references():
            operator.is_teardown = True
            operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
            if on_failure_fail_dagrun is not NOTSET:
                operator.on_failure_fail_dagrun = on_failure_fail_dagrun
            if not isinstance(setups, ArgNotSet):
                # A single task is wrapped so it can be iterated like a collection.
                setups = [setups] if isinstance(setups, DependencyMixin) else setups
                for s in setups:
                    s.is_setup = True
                    s >> operator
        return self

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        # A plain reference points at exactly one (operator, key) pair.
        yield self.operator, self.key

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Map ``f`` over this XCom; only allowed for the return-value key.

        :raises ValueError: If this reference uses a non-default key.
        """
        if self.key != XCOM_RETURN_KEY:
            raise ValueError("cannot map against non-return XCom")
        return super().map(f)

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Zip this XCom with ``others``; only allowed for the return-value key.

        :raises ValueError: If this reference uses a non-default key.
        """
        if self.key != XCOM_RETURN_KEY:
            raise ValueError("cannot map against non-return XCom")
        return super().zip(*others, fillvalue=fillvalue)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Return how many items the referenced task pushed, for task-mapping.

        For a ``MappedOperator``, returns None while any of its task instances in
        this run has a NULL or unfinished state; otherwise counts its return-value
        XCom rows with ``map_index >= 0``. For other operators, reads the
        ``TaskMap.length`` row with ``map_index < 0``. ``session.scalar`` yields
        None when no row matches.
        """
        from airflow.models.taskinstance import TaskInstance
        from airflow.models.taskmap import TaskMap
        from airflow.models.xcom import XCom

        task = self.operator
        if isinstance(task, MappedOperator):
            unfinished_ti_exists = exists_query(
                TaskInstance.dag_id == task.dag_id,
                TaskInstance.run_id == run_id,
                TaskInstance.task_id == task.task_id,
                # Special NULL treatment is needed because 'state' can be NULL.
                # The "IN" part would produce "NULL NOT IN ..." and eventually
                # "NULL = NULL", which is a big no-no in SQL.
                or_(
                    TaskInstance.state.is_(None),
                    TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
                ),
                session=session,
            )
            if unfinished_ti_exists:
                return None  # Not all of the expanded tis are done yet.
            query = select(func.count(XCom.map_index)).where(
                XCom.dag_id == task.dag_id,
                XCom.run_id == run_id,
                XCom.task_id == task.task_id,
                XCom.map_index >= 0,
                XCom.key == XCOM_RETURN_KEY,
            )
        else:
            query = select(TaskMap.length).where(
                TaskMap.dag_id == task.dag_id,
                TaskMap.run_id == run_id,
                TaskMap.task_id == task.task_id,
                TaskMap.map_index < 0,
            )
        return session.scalar(query)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Pull the referenced XCom value for the running task instance.

        Map indexes come from ``TaskInstance.get_relevant_upstream_map_indexes``.
        If nothing is found, None is returned for the return-value key or for an
        operator with ``multiple_outputs``.

        :raises NotImplementedError: If ``context["ti"]`` is not a TaskInstance.
        :raises XComNotFound: If a non-return key is missing and the operator
            does not use ``multiple_outputs``.
        """
        from airflow.models.taskinstance import TaskInstance

        ti = context["ti"]
        if not isinstance(ti, TaskInstance):
            raise NotImplementedError("Wait for AIP-44 implementation to complete")

        task_id = self.operator.task_id
        map_indexes = ti.get_relevant_upstream_map_indexes(
            self.operator,
            context["expanded_ti_count"],
            session=session,
        )
        # NOTSET as default lets "not found" be told apart from a stored None.
        result = ti.xcom_pull(
            task_ids=task_id,
            map_indexes=map_indexes,
            key=self.key,
            default=NOTSET,
            session=session,
        )
        if not isinstance(result, ArgNotSet):
            return result
        if self.key == XCOM_RETURN_KEY:
            return None
        if getattr(self.operator, "multiple_outputs", False):
            # If the operator is set to have multiple outputs and it was not executed,
            # we should return "None" instead of showing an error. This is because when
            # multiple outputs XComs are created, the XCom keys associated with them will have
            # different names than the predefined "XCOM_RETURN_KEY" and won't be found.
            # Therefore, it's better to return "None" like we did above where self.key==XCOM_RETURN_KEY.
            return None
        raise XComNotFound(ti.dag_id, task_id, self.key)
439def _get_callable_name(f: Callable | str) -> str:
440 """Try to "describe" a callable by getting its name."""
441 if callable(f):
442 return f.__name__
443 # Parse the source to find whatever is behind "def". For safety, we don't
444 # want to evaluate the code in any meaningful way!
445 with contextlib.suppress(Exception):
446 kw, name, _ = f.lstrip().split(None, 2)
447 if kw == "def":
448 return name
449 return "<function>"
452class _MapResult(Sequence):
453 def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
454 self.value = value
455 self.callables = callables
457 def __getitem__(self, index: Any) -> Any:
458 value = self.value[index]
460 # In the worker, we can access all actual callables. Call them.
461 callables = [f for f in self.callables if callable(f)]
462 if len(callables) == len(self.callables):
463 for f in callables:
464 value = f(value)
465 return value
467 # In the scheduler, we don't have access to the actual callables, nor do
468 # we want to run it since it's arbitrary code. This builds a string to
469 # represent the call chain in the UI or logs instead.
470 for v in self.callables:
471 value = f"{_get_callable_name(v)}({value})"
472 return value
474 def __len__(self) -> int:
475 return len(self.value)
class MapXComArg(XComArg):
    """An XCom reference with ``map()`` call(s) applied.

    This is based on an XComArg, but also applies a series of "transforms" that
    convert the pulled XCom value.

    :meta private:
    """

    def __init__(self, arg: XComArg, callables: MapCallables) -> None:
        # map() needs plain functions; @task-decorated objects are rejected.
        if any(getattr(fn, "_airflow_is_task_decorator", False) for fn in callables):
            raise ValueError("map() argument must be a plain function, not a @task operator")
        self.arg = arg
        self.callables = callables

    def __repr__(self) -> str:
        chained = "".join(f".map({_get_callable_name(fn)})" for fn in self.callables)
        return repr(self.arg) + chained

    def _serialize(self) -> dict[str, Any]:
        serialized_arg = serialize_xcom_arg(self.arg)
        # Live functions are stored as their source text; strings pass through.
        sources = [fn if not callable(fn) else inspect.getsource(fn) for fn in self.callables]
        return {"arg": serialized_arg, "callables": sources}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        inner = deserialize_xcom_arg(data["arg"], dag)
        # The callables intentionally stay as source strings: they are only
        # shown in the UI, where a function object would be useless.
        return cls(inner, data["callables"])

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        yield from self.arg.iter_references()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        # Chained arg.map(f1).map(f2) collapses into a single MapXComArg.
        combined = list(self.callables)
        combined.append(f)
        return MapXComArg(self.arg, combined)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        # Mapping does not change the number of items.
        return self.arg.get_task_map_length(run_id, session=session)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        resolved = self.arg.resolve(context, session=session)
        if isinstance(resolved, (Sequence, dict)):
            return _MapResult(resolved, self.callables)
        raise ValueError(f"XCom map expects sequence or dict, not {type(resolved).__name__}")
class _ZipResult(Sequence):
    """Lazily "zip" several resolved XCom values into a sequence of tuples.

    Without ``fillvalue`` the length is that of the shortest input (like
    ``zip()``); with ``fillvalue`` it is that of the longest input, and entries
    missing from shorter inputs are replaced by ``fillvalue`` (like
    ``itertools.zip_longest()``).
    """

    def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None:
        self.values = values
        self.fillvalue = fillvalue

    @staticmethod
    def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any:
        """Return ``container[index]``, or ``fillvalue`` if that entry is missing."""
        try:
            return container[index]
        except (IndexError, KeyError):
            return fillvalue

    def __getitem__(self, index: Any) -> Any:
        """Return a tuple holding the *index*-th element of every input.

        Negative indexes count from the end of the zipped result, as for any
        other sequence.

        :raises IndexError: If *index* is out of range for the zipped length.
        """
        length = len(self)
        # Normalize against the *zipped* length. Passing a negative index
        # straight through would make each input count from its own end,
        # misaligning inputs of different lengths.
        position = index + length if index < 0 else index
        if not 0 <= position < length:
            raise IndexError(index)
        return tuple(self._get_or_fill(value, position, self.fillvalue) for value in self.values)

    def __len__(self) -> int:
        """Length of the shortest input, or of the longest if ``fillvalue`` is set."""
        lengths = (len(v) for v in self.values)
        if isinstance(self.fillvalue, ArgNotSet):
            return min(lengths)
        return max(lengths)
class ZipXComArg(XComArg):
    """An XCom reference with ``zip()`` applied.

    This is constructed from multiple XComArg instances, and presents an
    iterable that "zips" them together like the built-in ``zip()`` (and
    ``itertools.zip_longest()`` if ``fillvalue`` is provided).
    """

    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
        if not args:
            raise ValueError("At least one input is required")
        self.args = args
        self.fillvalue = fillvalue

    def __repr__(self) -> str:
        remaining = iter(self.args)
        head = repr(next(remaining))
        tail = ", ".join(map(repr, remaining))
        if not isinstance(self.fillvalue, ArgNotSet):
            return f"{head}.zip({tail}, fillvalue={self.fillvalue!r})"
        return f"{head}.zip({tail})"

    def _serialize(self) -> dict[str, Any]:
        payload: dict[str, Any] = {"args": [serialize_xcom_arg(arg) for arg in self.args]}
        # "fillvalue" is only written when one was explicitly provided.
        if not isinstance(self.fillvalue, ArgNotSet):
            payload["fillvalue"] = self.fillvalue
        return payload

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        inputs = [deserialize_xcom_arg(arg, dag) for arg in data["args"]]
        return cls(inputs, fillvalue=data.get("fillvalue", NOTSET))

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        for arg in self.args:
            yield from arg.iter_references()

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        lengths = [arg.get_task_map_length(run_id, session=session) for arg in self.args]
        if any(length is None for length in lengths):
            return None  # Some referenced XCom is not ready, so neither are we.
        pick = min if isinstance(self.fillvalue, ArgNotSet) else max
        return pick(lengths)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        resolved = [arg.resolve(context, session=session) for arg in self.args]
        for item in resolved:
            if isinstance(item, (Sequence, dict)):
                continue
            raise ValueError(f"XCom zip expects sequence or dict, not {type(item).__name__}")
        return _ZipResult(resolved, fillvalue=self.fillvalue)
# Registry mapping the "type" tag stored in serialized data to the XComArg
# subclass that (de)serializes it. PlainXComArg uses the empty tag: the
# serializer omits the "type" key for it, and the deserializer treats a
# missing "type" as "".
_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = {
    "": PlainXComArg,
    "map": MapXComArg,
    "zip": ZipXComArg,
}
def serialize_xcom_arg(value: XComArg) -> dict[str, Any]:
    """DAG serialization interface.

    Serialize ``value`` with its own ``_serialize`` and, unless it is a plain
    reference, tag the result with ``"type"`` so ``deserialize_xcom_arg`` can
    dispatch to the right class.

    :raises TypeError: If ``type(value)`` is not registered in ``_XCOM_ARG_TYPES``.
    """
    value_type = type(value)
    key = next((k for k, v in _XCOM_ARG_TYPES.items() if v is value_type), None)
    if key is None:
        # Without a default, next() would leak a bare StopIteration here, which
        # is easily mistaken for iterator exhaustion by callers.
        raise TypeError(f"Cannot serialize unregistered XComArg type {value_type.__name__}")
    if key:
        return {"type": key, **value._serialize()}
    return value._serialize()
def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg:
    """DAG serialization interface.

    Select the XComArg class from the ``"type"`` tag (a missing tag means the
    plain reference) and let that class rebuild the object against ``dag``.
    """
    type_tag = data.get("type", "")
    target_class = _XCOM_ARG_TYPES[type_tag]
    return target_class._deserialize(data, dag)