1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import contextlib
21import inspect
22from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Mapping, Sequence, Union, overload
23
24from sqlalchemy import func, or_, select
25
26from airflow.exceptions import AirflowException, XComNotFound
27from airflow.models.abstractoperator import AbstractOperator
28from airflow.models.mappedoperator import MappedOperator
29from airflow.models.taskmixin import DependencyMixin
30from airflow.utils.db import exists_query
31from airflow.utils.mixins import ResolveMixin
32from airflow.utils.session import NEW_SESSION, provide_session
33from airflow.utils.setup_teardown import SetupTeardownContext
34from airflow.utils.state import State
35from airflow.utils.trigger_rule import TriggerRule
36from airflow.utils.types import NOTSET, ArgNotSet
37from airflow.utils.xcom import XCOM_RETURN_KEY
38
39if TYPE_CHECKING:
40 from sqlalchemy.orm import Session
41
42 from airflow.models.baseoperator import BaseOperator
43 from airflow.models.dag import DAG
44 from airflow.models.operator import Operator
45 from airflow.models.taskmixin import DAGNode
46 from airflow.utils.context import Context
47 from airflow.utils.edgemodifier import EdgeModifier
48
# Callable objects contained by MapXComArg. We only accept callables from
# the user, but a serialized XComArg stores them as source-code strings, and
# they are deliberately kept as strings when deserialized, for safety (those
# callables are arbitrary user code). Hence the element type is "callable or
# str".
MapCallables = Sequence[Union[Callable[[Any], Any], str]]
53
54
class XComArg(ResolveMixin, DependencyMixin):
    """Lazy handle on an XCom value that another operator pushes.

    Dependency arrows are supported in both directions::

        xcomarg >> op
        xcomarg << op
        op >> xcomarg  # By BaseOperator code
        op << xcomarg  # By BaseOperator code

    **Example**: Once an operator (decorated or regular) exists, its result
    can be referenced like so::

        any_op = AnyOperator()
        xcomarg = XComArg(any_op)
        # or equivalently
        xcomarg = any_op.output
        my_op = MyOperator()
        my_op >> xcomarg

    Legacy operators can consume this object through Jinja.

    **Example**: The result can be embedded in any generated string::

        any_op = AnyOperator()
        xcomarg = any_op.output
        op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
        op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")

    :param operator: Operator instance to which the XComArg references.
    :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
        i.e. the referenced operator's return value.
    """

    @overload
    def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg:
        """Signature used when the user spells out ``XComArg(...)``."""

    @overload
    def __new__(cls: type[XComArg]) -> XComArg:
        """Signature used when Python internals instantiate a subclass."""

    def __new__(cls, *args, **kwargs) -> XComArg:
        # Subclasses are allocated the regular way; only a direct
        # ``XComArg(...)`` call is rerouted to the plain implementation.
        if cls is not XComArg:
            return super().__new__(cls)
        return PlainXComArg(*args, **kwargs)

    @staticmethod
    def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]:
        """Yield every ``(operator, key)`` XCom reference found inside ``arg``.

        ``arg`` is walked recursively: resolvable objects report their own
        references, tuples/sets/lists and dict values are descended into, and
        operators are inspected through the attributes named in their
        ``template_fields``. Anything else yields nothing.
        """
        if isinstance(arg, ResolveMixin):
            yield from arg.iter_references()
            return

        # Work out which child values (if any) need to be searched.
        if isinstance(arg, (tuple, set, list)):
            children = arg
        elif isinstance(arg, dict):
            children = arg.values()
        elif isinstance(arg, AbstractOperator):
            children = (getattr(arg, field) for field in arg.template_fields)
        else:
            return

        for child in children:
            yield from XComArg.iter_xcom_references(child)

    @staticmethod
    def apply_upstream_relationship(op: Operator, arg: Any):
        """Make every operator referenced from ``arg`` an upstream of ``op``.

        The references are discovered "deeply" with
        :meth:`iter_xcom_references`, i.e. inside collections and inside
        objects exposing ``template_fields``.
        """
        for upstream, _ in XComArg.iter_xcom_references(arg):
            op.set_upstream(upstream)

    @property
    def roots(self) -> list[DAGNode]:
        """Operators referenced by this XComArg. Required by TaskMixin."""
        return [reference[0] for reference in self.iter_references()]

    @property
    def leaves(self) -> list[DAGNode]:
        """Operators referenced by this XComArg. Required by TaskMixin."""
        return [reference[0] for reference in self.iter_references()]

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Forward ``set_upstream`` to every referenced operator. Required by TaskMixin."""
        for referenced, _ in self.iter_references():
            referenced.set_upstream(task_or_task_list, edge_modifier)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Forward ``set_downstream`` to every referenced operator. Required by TaskMixin."""
        for referenced, _ in self.iter_references():
            referenced.set_downstream(task_or_task_list, edge_modifier)

    def _serialize(self) -> dict[str, Any]:
        """
        Convert this XComArg derivative into a plain data dict.

        Subclasses must implement this as the inverse of ``_deserialize``. DAG
        serialization does not call this directly; it goes through
        ``serialize_xcom_arg``, which adds the extra information needed to
        dispatch deserialization to the correct class.
        """
        raise NotImplementedError()

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        """
        Rebuild an XComArg derivative from its data dict.

        Subclasses must implement this as the inverse of ``_serialize``. DAG
        serialization relies on the information added by ``serialize_xcom_arg``
        to route each data dict to the right ``_deserialize``, so an
        implementation does not need to validate that the incoming data
        contains the correct keys.
        """
        raise NotImplementedError()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Return a reference that applies ``f`` on top of this one."""
        transforms = [f]
        return MapXComArg(self, transforms)

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Return a reference zipping this one together with ``others``."""
        members = [self, *others]
        return ZipXComArg(members, fillvalue=fillvalue)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Report the length of the pushed value, for task-mapping.

        The scheduler uses this to decide how many task instances to create
        for a downstream task that maps over this XComArg.

        *None* may be returned if the depended XCom has not been pushed.
        """
        raise NotImplementedError()

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Pull XCom value.

        Only meant to be called during ``op.execute()`` with an appropriate
        context (e.g. generated from ``TaskInstance.get_template_context()``).
        The ``ResolveMixin`` parent also defines a ``resolve`` protocol; this
        variant adds the optional ``session`` argument that some subclasses
        need.

        :meta private:
        """
        raise NotImplementedError()

    def __enter__(self):
        # NOTE(review): ``operator`` is not assigned anywhere in this base
        # class; within this module only PlainXComArg sets it.
        task = self.operator
        if not (task.is_setup or task.is_teardown):
            raise AirflowException("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(task)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        SetupTeardownContext.set_work_task_roots_and_leaves()
223
224
class PlainXComArg(XComArg):
    """Reference to one single XCom without any additional semantics.

    This class should not be accessed directly, but only through XComArg. The
    class inheritance chain and ``__new__`` is implemented in this slightly
    convoluted way because we want to

    a. Allow the user to continue using XComArg directly for the simple
       semantics (see documentation of the base class for details).
    b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom
       references.
    c. Not allow many properties of PlainXComArg (including ``__getitem__`` and
       ``__str__``) to exist on other kinds of XComArg implementations since
       they don't make sense.

    :param operator: Operator whose XCom is referenced.
    :param key: XCom key to pull. Defaults to *XCOM_RETURN_KEY*.

    :meta private:
    """

    def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY):
        self.operator = operator
        self.key = key

    def __eq__(self, other: Any) -> bool:
        """Two plain references are equal when operator and key both match."""
        if not isinstance(other, PlainXComArg):
            return NotImplemented
        return self.operator == other.operator and self.key == other.key

    def __getitem__(self, item: str) -> XComArg:
        """Implement xcomresult['some_result_key'].

        :raises ValueError: if ``item`` is not a string.
        """
        if not isinstance(item, str):
            raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")
        return PlainXComArg(operator=self.operator, key=item)

    def __iter__(self):
        """Override iterable protocol to raise error explicitly.

        The default ``__iter__`` implementation in Python calls ``__getitem__``
        with 0, 1, 2, etc. until it hits an ``IndexError``. This does not work
        well with our custom ``__getitem__`` implementation, and results in poor
        DAG-writing experience since a misplaced ``*`` expansion would create an
        infinite loop consuming the entire DAG parser.

        This override catches the error eagerly, so an incorrectly implemented
        DAG fails fast and avoids wasting resources on nonsensical iterating.
        """
        raise TypeError("'XComArg' object is not iterable")

    def __repr__(self) -> str:
        # The key is left out when it is the default return-value key.
        if self.key == XCOM_RETURN_KEY:
            return f"XComArg({self.operator!r})"
        return f"XComArg({self.operator!r}, {self.key!r})"

    def __str__(self) -> str:
        """
        Backward compatibility for old-style jinja used in Airflow Operators.

        **Example**: to use XComArg at BashOperator::

            BashOperator(cmd=f"... { xcomarg } ...")

        :return: A Jinja expression string calling ``task_instance.xcom_pull``
            with this reference's task id, dag id and (when not *None*) key.
        """
        xcom_pull_kwargs = [
            f"task_ids='{self.operator.task_id}'",
            f"dag_id='{self.operator.dag_id}'",
        ]
        if self.key is not None:
            xcom_pull_kwargs.append(f"key='{self.key}'")

        xcom_pull_str = ", ".join(xcom_pull_kwargs)
        # {{{{ are required for escape {{ in f-string
        xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}"
        return xcom_pull

    def _serialize(self) -> dict[str, Any]:
        """Return the task id and key identifying this reference."""
        return {"task_id": self.operator.task_id, "key": self.key}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        """Rebuild the reference by looking the task up in ``dag``."""
        return cls(dag.get_task(data["task_id"]), data["key"])

    # The properties below simply proxy the attribute of the same name on the
    # referenced operator.

    @property
    def is_setup(self) -> bool:
        return self.operator.is_setup

    @is_setup.setter
    def is_setup(self, val: bool):
        self.operator.is_setup = val

    @property
    def is_teardown(self) -> bool:
        return self.operator.is_teardown

    @is_teardown.setter
    def is_teardown(self, val: bool):
        self.operator.is_teardown = val

    @property
    def on_failure_fail_dagrun(self) -> bool:
        return self.operator.on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, val: bool):
        self.operator.on_failure_fail_dagrun = val

    def as_setup(self) -> DependencyMixin:
        """Flag every referenced operator as a setup task and return ``self``."""
        for operator, _ in self.iter_references():
            operator.is_setup = True
        return self

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET,
        on_failure_fail_dagrun=NOTSET,
    ):
        """Flag every referenced operator as a teardown task and return ``self``.

        Each referenced operator gets ``is_teardown = True`` and the
        ``ALL_DONE_SETUP_SUCCESS`` trigger rule.

        :param setups: Optional task (or iterable of tasks) to flag as setup
            and wire upstream of each teardown operator.
        :param on_failure_fail_dagrun: When given, copied onto each referenced
            operator's attribute of the same name.
        """
        for operator, _ in self.iter_references():
            operator.is_teardown = True
            operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
            if on_failure_fail_dagrun is not NOTSET:
                operator.on_failure_fail_dagrun = on_failure_fail_dagrun
            if not isinstance(setups, ArgNotSet):
                # Accept a single task as well as an iterable of tasks.
                setups = [setups] if isinstance(setups, DependencyMixin) else setups
                for s in setups:
                    s.is_setup = True
                    s >> operator
        return self

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        """Yield the single ``(operator, key)`` pair this object refers to."""
        yield self.operator, self.key

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Apply ``f`` on top of this reference.

        :raises ValueError: if this reference does not use the return-value key.
        """
        if self.key != XCOM_RETURN_KEY:
            raise ValueError("cannot map against non-return XCom")
        return super().map(f)

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Zip this reference with ``others``.

        :raises ValueError: if this reference does not use the return-value key.
        """
        if self.key != XCOM_RETURN_KEY:
            # The message used to say "map", copy-pasted from ``map()`` above.
            raise ValueError("cannot zip against non-return XCom")
        return super().zip(*others, fillvalue=fillvalue)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        """Inspect length of the pushed value for task-mapping.

        For a mapped upstream operator, *None* is returned while any of its
        task instances in the run is unfinished; otherwise the number of
        return-value XCom rows with a non-negative map index is returned. For
        any other operator, the length recorded in ``TaskMap`` (row with a
        negative map index) is returned, which is *None* when no row exists.
        """
        # Imported locally rather than at module top level.
        from airflow.models.taskinstance import TaskInstance
        from airflow.models.taskmap import TaskMap
        from airflow.models.xcom import XCom

        task = self.operator
        if isinstance(task, MappedOperator):
            unfinished_ti_exists = exists_query(
                TaskInstance.dag_id == task.dag_id,
                TaskInstance.run_id == run_id,
                TaskInstance.task_id == task.task_id,
                # Special NULL treatment is needed because 'state' can be NULL.
                # The "IN" part would produce "NULL NOT IN ..." and eventually
                # "NULL = NULL", which is a big no-no in SQL.
                or_(
                    TaskInstance.state.is_(None),
                    TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
                ),
                session=session,
            )
            if unfinished_ti_exists:
                return None  # Not all of the expanded tis are done yet.
            query = select(func.count(XCom.map_index)).where(
                XCom.dag_id == task.dag_id,
                XCom.run_id == run_id,
                XCom.task_id == task.task_id,
                XCom.map_index >= 0,
                XCom.key == XCOM_RETURN_KEY,
            )
        else:
            query = select(TaskMap.length).where(
                TaskMap.dag_id == task.dag_id,
                TaskMap.run_id == run_id,
                TaskMap.task_id == task.task_id,
                TaskMap.map_index < 0,
            )
        return session.scalar(query)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        """Pull the referenced XCom value through the running task instance.

        :param context: Template context; ``context["ti"]`` and
            ``context["expanded_ti_count"]`` are read.
        :return: The pulled value, or *None* when nothing was found and either
            the key is the return-value key or the operator declares
            ``multiple_outputs``.
        :raises NotImplementedError: if ``context["ti"]`` is not a
            ``TaskInstance``.
        :raises XComNotFound: if nothing was found for an explicit key.
        """
        from airflow.models.taskinstance import TaskInstance

        ti = context["ti"]
        if not isinstance(ti, TaskInstance):
            raise NotImplementedError("Wait for AIP-44 implementation to complete")

        task_id = self.operator.task_id
        map_indexes = ti.get_relevant_upstream_map_indexes(
            self.operator,
            context["expanded_ti_count"],
            session=session,
        )
        # NOTSET as the default lets us tell "no XCom" apart from a pushed None.
        result = ti.xcom_pull(
            task_ids=task_id,
            map_indexes=map_indexes,
            key=self.key,
            default=NOTSET,
            session=session,
        )
        if not isinstance(result, ArgNotSet):
            return result
        if self.key == XCOM_RETURN_KEY:
            return None
        if getattr(self.operator, "multiple_outputs", False):
            # If the operator is set to have multiple outputs and it was not executed,
            # we should return "None" instead of showing an error. This is because when
            # multiple outputs XComs are created, the XCom keys associated with them will have
            # different names than the predefined "XCOM_RETURN_KEY" and won't be found.
            # Therefore, it's better to return "None" like we did above where self.key==XCOM_RETURN_KEY.
            return None
        raise XComNotFound(ti.dag_id, task_id, self.key)
437
438
def _get_callable_name(f: Callable | str) -> str:
    """Try to "describe" a callable by getting its name.

    :param f: Either an actual callable, or the source code of a function as
        a string.
    :return: The callable's ``__name__``; or, for source text starting with
        ``def``, the bare function name (without the parameter list); or the
        placeholder ``"<function>"`` when no name can be determined.
    """
    if callable(f):
        # Not every callable carries ``__name__`` (e.g. ``functools.partial``
        # objects or instances defining ``__call__``), so fall back to the
        # placeholder instead of raising AttributeError.
        return getattr(f, "__name__", "<function>")
    # Parse the source to find whatever is behind "def". For safety, we don't
    # want to evaluate the code in any meaningful way!
    with contextlib.suppress(Exception):
        kw, rest = f.lstrip().split(None, 1)
        if kw == "def":
            # Cut at the opening parenthesis so that only the identifier is
            # kept, not the signature ("foo(x):" -> "foo").
            name = rest.partition("(")[0].strip()
            if name.isidentifier():
                return name
    return "<function>"
450
451
class _MapResult(Sequence):
    """Sequence wrapper that runs the mapped callables on element access."""

    def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
        self.value = value
        self.callables = callables

    def __getitem__(self, index: Any) -> Any:
        item = self.value[index]

        if all(callable(entry) for entry in self.callables):
            # In the worker every entry is an actual callable, so the chain
            # is applied for real, in order.
            for transform in self.callables:
                item = transform(item)
            return item

        # In the scheduler the actual callables are not available, and we
        # would not want to run them anyway since they are arbitrary code.
        # Build a string describing the call chain for the UI or logs.
        for entry in self.callables:
            item = f"{_get_callable_name(entry)}({item})"
        return item

    def __len__(self) -> int:
        return len(self.value)
476
477
class MapXComArg(XComArg):
    """XCom reference carrying one or more ``map()`` transformations.

    Wraps another XComArg and remembers a chain of "transforms" that convert
    the pulled XCom value.

    :meta private:
    """

    def __init__(self, arg: XComArg, callables: MapCallables) -> None:
        # Reject @task-decorated functions, recognised by their marker attribute.
        if any(getattr(fn, "_airflow_is_task_decorator", False) for fn in callables):
            raise ValueError("map() argument must be a plain function, not a @task operator")
        self.arg = arg
        self.callables = callables

    def __repr__(self) -> str:
        suffix_parts = [f".map({_get_callable_name(fn)})" for fn in self.callables]
        return repr(self.arg) + "".join(suffix_parts)

    def _serialize(self) -> dict[str, Any]:
        serialized_arg = serialize_xcom_arg(self.arg)
        # Real callables are stored as their source text; entries that are
        # already strings are passed through untouched.
        sources = []
        for fn in self.callables:
            sources.append(inspect.getsource(fn) if callable(fn) else fn)
        return {"arg": serialized_arg, "callables": sources}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        # The callables are deliberately NOT deserialized. They are shown in
        # the UI, and displaying a function object is useless.
        inner = deserialize_xcom_arg(data["arg"], dag)
        return cls(inner, data["callables"])

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        for reference in self.arg.iter_references():
            yield reference

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        # arg.map(f1).map(f2) is flattened into a single MapXComArg.
        extended = [*self.callables, f]
        return MapXComArg(self.arg, extended)

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        # Delegated entirely to the wrapped reference.
        return self.arg.get_task_map_length(run_id, session=session)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        pulled = self.arg.resolve(context, session=session)
        if isinstance(pulled, (Sequence, dict)):
            return _MapResult(pulled, self.callables)
        raise ValueError(f"XCom map expects sequence or dict, not {type(pulled).__name__}")
526
527
class _ZipResult(Sequence):
    """Sequence wrapper presenting several containers as tuples per index."""

    def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None:
        self.values = values
        self.fillvalue = fillvalue

    @staticmethod
    def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any:
        # A missing position/key is replaced by the fill value.
        try:
            item = container[index]
        except (IndexError, KeyError):
            return fillvalue
        else:
            return item

    def __getitem__(self, index: Any) -> Any:
        if index >= len(self):
            raise IndexError(index)
        row = [self._get_or_fill(container, index, self.fillvalue) for container in self.values]
        return tuple(row)

    def __len__(self) -> int:
        # Without a fill value the shortest input decides the length;
        # with one, the longest does.
        pick = min if isinstance(self.fillvalue, ArgNotSet) else max
        return pick(len(container) for container in self.values)
550
551
class ZipXComArg(XComArg):
    """XCom reference that combines several references with ``zip()``.

    Built from multiple XComArg instances, it presents an iterable that "zips"
    them together like the built-in ``zip()`` (and like
    ``itertools.zip_longest()`` if ``fillvalue`` is provided).
    """

    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
        if not args:
            raise ValueError("At least one input is required")
        self.args = args
        self.fillvalue = fillvalue

    def __repr__(self) -> str:
        members = iter(self.args)
        head = repr(next(members))
        tail = ", ".join(map(repr, members))
        if not isinstance(self.fillvalue, ArgNotSet):
            return f"{head}.zip({tail}, fillvalue={self.fillvalue!r})"
        return f"{head}.zip({tail})"

    def _serialize(self) -> dict[str, Any]:
        payload: dict[str, Any] = {"args": [serialize_xcom_arg(member) for member in self.args]}
        # The fill value is only stored when one was actually supplied.
        if not isinstance(self.fillvalue, ArgNotSet):
            payload["fillvalue"] = self.fillvalue
        return payload

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        members = [deserialize_xcom_arg(item, dag) for item in data["args"]]
        return cls(members, fillvalue=data.get("fillvalue", NOTSET))

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        for member in self.args:
            for reference in member.iter_references():
                yield reference

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        # Every member is queried, even when an earlier one is not ready.
        lengths = [member.get_task_map_length(run_id, session=session) for member in self.args]
        if any(length is None for length in lengths):
            return None  # If any of the referenced XComs is not ready, we are not ready either.
        pick = min if isinstance(self.fillvalue, ArgNotSet) else max
        return pick(lengths)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        # Resolve all members first, then validate each resolved value.
        resolved = [member.resolve(context, session=session) for member in self.args]
        for item in resolved:
            if isinstance(item, (Sequence, dict)):
                continue
            raise ValueError(f"XCom zip expects sequence or dict, not {type(item).__name__}")
        return _ZipResult(resolved, fillvalue=self.fillvalue)
607
608
# Registry mapping the "type" tag of a serialized XComArg to the class that
# handles it. The empty tag stands for PlainXComArg: ``serialize_xcom_arg``
# omits the "type" key for it, and ``deserialize_xcom_arg`` falls back to ""
# when the key is absent.
_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = {
    "": PlainXComArg,
    "map": MapXComArg,
    "zip": ZipXComArg,
}
614
615
def serialize_xcom_arg(value: XComArg) -> dict[str, Any]:
    """DAG serialization interface.

    Looks up the exact type of ``value`` in ``_XCOM_ARG_TYPES`` and returns
    the dict produced by ``value._serialize()``, with a ``"type"`` entry added
    in front when the registered tag is non-empty.

    :param value: The XComArg to serialize.
    :return: The serialized data dict.
    :raises TypeError: if the exact type of ``value`` is not registered.
    """
    value_type = type(value)
    # Supplying a default avoids leaking a bare StopIteration for
    # unregistered types, which would give the caller no hint of the cause.
    key = next((k for k, v in _XCOM_ARG_TYPES.items() if v is value_type), None)
    if key is None:
        raise TypeError(f"cannot serialize XComArg of unregistered type {value_type.__name__!r}")
    if key:
        return {"type": key, **value._serialize()}
    return value._serialize()
622
623
def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg:
    """DAG serialization interface.

    Picks the class registered for the ``"type"`` tag in ``data`` (the empty
    tag when the key is absent) and lets it rebuild the XComArg.
    """
    type_tag = data.get("type", "")
    xcom_arg_class = _XCOM_ARG_TYPES[type_tag]
    return xcom_arg_class._deserialize(data, dag)