1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import contextlib
21import inspect
22import itertools
23from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence, Sized
24from functools import singledispatch
25from typing import TYPE_CHECKING, Any, overload
26
27import attrs
28
29from airflow.sdk import TriggerRule
30from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
31from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin
32from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
33from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
34from airflow.sdk.exceptions import AirflowException, XComNotFound
35from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence
36from airflow.sdk.execution_time.xcom import BaseXCom
37
38if TYPE_CHECKING:
39 from airflow.sdk.bases.operator import BaseOperator
40 from airflow.sdk.definitions.edges import EdgeModifier
41 from airflow.sdk.types import Operator
42
# Callable objects contained by MapXComArg. We only accept callables from
# the user, but a serialized MapXComArg stores them as source-code strings
# (``MapXComArg._serialize`` calls ``inspect.getsource`` on each callable), so
# an entry may be a ``str`` instead of a callable. Such strings are arbitrary
# user code; ``_get_callable_name`` only inspects them textually and never
# evaluates them.
MapCallables = Sequence[Callable[[Any], Any]]
47
48
class XComArg(ResolveMixin, DependencyMixin):
    """
    Reference to an XCom value pushed from another operator.

    The implementation supports::

        xcomarg >> op
        xcomarg << op
        op >> xcomarg  # By BaseOperator code
        op << xcomarg  # By BaseOperator code

    **Example**: The moment you get a result from any operator (decorated or regular) you can ::

        any_op = AnyOperator()
        xcomarg = XComArg(any_op)
        # or equivalently
        xcomarg = any_op.output
        my_op = MyOperator()
        my_op >> xcomarg

    This object can be used in legacy Operators via Jinja.

    **Example**: You can make this result part of any generated string::

        any_op = AnyOperator()
        xcomarg = any_op.output
        op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
        op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")

    Calling ``XComArg(...)`` directly actually constructs a ``PlainXComArg``
    (see ``__new__``); other subclasses represent ``map()``, ``zip()`` and
    ``concat()`` results.

    :param operator: Operator instance to which the XComArg references.
    :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
        i.e. the referenced operator's return value.
    """

    @overload
    def __new__(cls: type[XComArg], operator: Operator, key: str = BaseXCom.XCOM_RETURN_KEY) -> XComArg:
        """Execute when the user writes ``XComArg(...)`` directly."""

    @overload
    def __new__(cls: type[XComArg]) -> XComArg:
        """Execute by Python internals from subclasses."""

    def __new__(cls, *args, **kwargs) -> XComArg:
        # ``XComArg(...)`` called on the base class is redirected to
        # PlainXComArg, so the base name keeps working as a constructor.
        # Subclasses (PlainXComArg included) get normal object allocation.
        # NOTE(review): the returned object is an XComArg instance, so Python
        # also runs its ``__init__`` again with the same arguments; presumably
        # harmless for the attrs-generated PlainXComArg ``__init__`` — confirm.
        if cls is XComArg:
            return PlainXComArg(*args, **kwargs)
        return super().__new__(cls)

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        """Yield ``(operator, xcom_key)`` pairs this reference depends on; implemented by subclasses."""
        raise NotImplementedError()

    @staticmethod
    def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]:
        """
        Return XCom references in an arbitrary value.

        Recursively traverse ``arg`` and look for XComArg instances in any
        collection objects, and instances with ``template_fields`` set.

        Only tuples, sets, lists, dict *values* and AbstractOperator template
        fields are descended into; any other value yields nothing.
        """
        if isinstance(arg, ResolveMixin):
            yield from arg.iter_references()
        elif isinstance(arg, (tuple, set, list)):
            for elem in arg:
                yield from XComArg.iter_xcom_references(elem)
        elif isinstance(arg, dict):
            # Dict keys are not inspected, only values.
            for elem in arg.values():
                yield from XComArg.iter_xcom_references(elem)
        elif isinstance(arg, AbstractOperator):
            for attr in arg.template_fields:
                yield from XComArg.iter_xcom_references(getattr(arg, attr))

    @staticmethod
    def apply_upstream_relationship(op: DependencyMixin, arg: Any):
        """
        Set dependency for XComArgs.

        This looks for XComArg objects in ``arg`` "deeply" (looking inside
        collections objects and classes decorated with ``template_fields``), and
        sets the relationship to ``op`` on any found.
        """
        for operator, _ in XComArg.iter_xcom_references(arg):
            op.set_upstream(operator)

    @property
    def roots(self) -> list[Operator]:
        """Required by DependencyMixin."""
        return [op for op, _ in self.iter_references()]

    @property
    def leaves(self) -> list[Operator]:
        """Required by DependencyMixin."""
        # Same as ``roots``: every referenced operator is both a root and a leaf.
        return [op for op, _ in self.iter_references()]

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Proxy to underlying operator set_upstream method. Required by DependencyMixin."""
        for operator, _ in self.iter_references():
            operator.set_upstream(task_or_task_list, edge_modifier)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ):
        """Proxy to underlying operator set_downstream method. Required by DependencyMixin."""
        for operator, _ in self.iter_references():
            operator.set_downstream(task_or_task_list, edge_modifier)

    def _serialize(self) -> dict[str, Any]:
        """
        Serialize an XComArg.

        The implementation should be the inverse function to ``deserialize``,
        returning a data dict converted from this XComArg derivative. DAG
        serialization does not call this directly, but ``serialize_xcom_arg``
        instead, which adds additional information to dispatch deserialization
        to the correct class.
        """
        raise NotImplementedError()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        """Return a lazy reference that applies ``f`` to each item of this XCom."""
        return MapXComArg(self, [f])

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        """Return a lazy reference zipping this XCom with ``others`` (padded with ``fillvalue`` if given)."""
        return ZipXComArg([self, *others], fillvalue=fillvalue)

    def concat(self, *others: XComArg) -> ConcatXComArg:
        """Return a lazy reference concatenating this XCom with ``others``."""
        return ConcatXComArg([self, *others])

    def resolve(self, context: Mapping[str, Any]) -> Any:
        """Pull the referenced value(s) at runtime; implemented by subclasses."""
        raise NotImplementedError()

    def __enter__(self):
        """
        Use the referenced setup/teardown task as a context manager.

        Pushes ``self.operator`` onto ``SetupTeardownContext`` and returns the
        ``SetupTeardownContext`` class itself.

        :raises AirflowException: If the operator is neither a setup nor a teardown task.
        """
        # NOTE(review): among the classes in this module only PlainXComArg
        # defines ``operator``; other XComArg kinds presumably fail here with
        # AttributeError instead of the AirflowException below — confirm.
        if not self.operator.is_setup and not self.operator.is_teardown:
            raise AirflowException("Only setup/teardown tasks can be used as context managers.")
        SetupTeardownContext.push_setup_teardown_task(self.operator)
        return SetupTeardownContext

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so exceptions raised inside the ``with`` body propagate.
        SetupTeardownContext.set_work_task_roots_and_leaves()
191
192
@attrs.define
class PlainXComArg(XComArg):
    """
    Reference to one single XCom without any additional semantics.

    This class should not be accessed directly, but only through XComArg. The
    class inheritance chain and ``__new__`` is implemented in this slightly
    convoluted way because we want to

    a. Allow the user to continue using XComArg directly for the simple
       semantics (see documentation of the base class for details).
    b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom
       references.
    c. Not allow many properties of PlainXComArg (including ``__getitem__`` and
       ``__str__``) to exist on other kinds of XComArg implementations since
       they don't make sense.

    :meta private:
    """

    # The task whose XCom is referenced.
    operator: Operator
    # XCom key to pull; the default refers to the operator's return value.
    key: str = BaseXCom.XCOM_RETURN_KEY

    def __getitem__(self, item: str) -> XComArg:
        """
        Implement xcomresult['some_result_key'].

        :param item: XCom key to reference on the same operator.
        :return: A new PlainXComArg for ``item`` on the same operator.
        :raises ValueError: If ``item`` is not a ``str``.
        """
        if not isinstance(item, str):
            raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")
        return PlainXComArg(operator=self.operator, key=item)

    def __iter__(self):
        """
        Override iterable protocol to raise error explicitly.

        The default ``__iter__`` implementation in Python calls ``__getitem__``
        with 0, 1, 2, etc. until it hits an ``IndexError``. This does not work
        well with our custom ``__getitem__`` implementation, and results in poor
        Dag-writing experience since a misplaced ``*`` expansion would create an
        infinite loop consuming the entire Dag parser.

        This override catches the error eagerly, so an incorrectly implemented
        Dag fails fast and avoids wasting resources on nonsensical iterating.
        """
        raise TypeError("'XComArg' object is not iterable")

    def __repr__(self) -> str:
        # The default return-value key is omitted, mirroring ``XComArg(op)``.
        if self.key == BaseXCom.XCOM_RETURN_KEY:
            return f"XComArg({self.operator!r})"
        return f"XComArg({self.operator!r}, {self.key!r})"

    def __str__(self) -> str:
        """
        Backward compatibility for old-style jinja used in Airflow Operators.

        **Example**: to use XComArg at BashOperator::

            BashOperator(cmd=f"... {xcomarg} ...")

        :return: A Jinja expression string that calls
            ``task_instance.xcom_pull`` with this reference's task id, dag id
            and key.
        """
        # NOTE(review): values are placed between single quotes without
        # escaping; a quote character in task_id/dag_id/key would break the template.
        xcom_pull_kwargs = [
            f"task_ids='{self.operator.task_id}'",
            f"dag_id='{self.operator.dag_id}'",
        ]
        if self.key is not None:
            xcom_pull_kwargs.append(f"key='{self.key}'")

        xcom_pull_str = ", ".join(xcom_pull_kwargs)
        # {{{{ are required for escape {{ in f-string
        xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}"
        return xcom_pull

    def _serialize(self) -> dict[str, Any]:
        # Only the task id is stored; the operator object itself is not serialized.
        return {"task_id": self.operator.task_id, "key": self.key}

    # The following properties proxy setup/teardown flags to the referenced operator.

    @property
    def is_setup(self) -> bool:
        return self.operator.is_setup

    @is_setup.setter
    def is_setup(self, val: bool):
        self.operator.is_setup = val

    @property
    def is_teardown(self) -> bool:
        return self.operator.is_teardown

    @is_teardown.setter
    def is_teardown(self, val: bool):
        self.operator.is_teardown = val

    @property
    def on_failure_fail_dagrun(self) -> bool:
        return self.operator.on_failure_fail_dagrun

    @on_failure_fail_dagrun.setter
    def on_failure_fail_dagrun(self, val: bool):
        self.operator.on_failure_fail_dagrun = val

    def as_setup(self) -> DependencyMixin:
        """Mark the referenced operator as a setup task and return ``self``."""
        for operator, _ in self.iter_references():
            operator.is_setup = True
        return self

    def as_teardown(
        self,
        *,
        setups: BaseOperator | Iterable[BaseOperator] | None = None,
        on_failure_fail_dagrun: bool | None = None,
    ):
        """
        Mark the referenced operator as a teardown task and return ``self``.

        Sets its trigger rule to ``ALL_DONE_SETUP_SUCCESS``.

        :param setups: Task(s) to mark as setup and place upstream of the teardown.
        :param on_failure_fail_dagrun: If not ``None``, copied onto the operator.
        """
        for operator, _ in self.iter_references():
            operator.is_teardown = True
            operator.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS
            if on_failure_fail_dagrun is not None:
                operator.on_failure_fail_dagrun = on_failure_fail_dagrun
            if setups is not None:
                # A single task is wrapped so both forms can be iterated uniformly.
                setups = [setups] if isinstance(setups, DependencyMixin) else setups
                for s in setups:
                    s.is_setup = True
                    s >> operator
        return self

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        yield self.operator, self.key

    # map/zip/concat are only allowed on the return-value XCom.

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        if self.key != BaseXCom.XCOM_RETURN_KEY:
            raise ValueError("cannot map against non-return XCom")
        return super().map(f)

    def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
        if self.key != BaseXCom.XCOM_RETURN_KEY:
            raise ValueError("cannot map against non-return XCom")
        return super().zip(*others, fillvalue=fillvalue)

    def concat(self, *others: XComArg) -> ConcatXComArg:
        if self.key != BaseXCom.XCOM_RETURN_KEY:
            raise ValueError("cannot concatenate non-return XCom")
        return super().concat(*others)

    def resolve(self, context: Mapping[str, Any]) -> Any:
        """
        Pull the referenced XCom for the task instance found in ``context``.

        :param context: Runtime context; ``context["ti"]`` is used to pull XComs.
        :return: A ``LazyXComSequence`` if the referenced operator is mapped.
            Otherwise, the pulled value. ``None`` is returned when nothing was
            found and either the key is the return-value key or the operator
            has ``multiple_outputs`` set.
        :raises XComNotFound: If nothing was found for any other key.
        """
        ti = context["ti"]
        task_id = self.operator.task_id

        # A mapped upstream is returned as a lazy sequence rather than pulled eagerly.
        if self.operator.is_mapped:
            return LazyXComSequence(xcom_arg=self, ti=ti)
        tg = self.operator.get_closest_mapped_task_group()
        if tg is None:
            # No mapped task group - pull from unmapped instance
            map_indexes: int | range | None | ArgNotSet = None
        else:
            # Check for pre-computed value from server (backward compatibility)
            upstream_map_indexes = getattr(ti, "_upstream_map_indexes", None)
            if upstream_map_indexes is not None:
                # Use None as default to match original behavior (filter for unmapped XCom)
                map_indexes = upstream_map_indexes.get(task_id, None)
            else:
                # Compute lazily - ti_count will be queried if needed
                cached_context = getattr(ti, "_cached_template_context", None)
                ti_count = cached_context.get("expanded_ti_count") if cached_context else None
                computed = ti.get_relevant_upstream_map_indexes(
                    upstream=self.operator,
                    ti_count=ti_count,
                    session=None,  # Not used in SDK implementation
                )
                # None means "no filtering needed" -> use NOTSET to pull all values
                map_indexes = NOTSET if computed is None else computed
        # default=NOTSET lets "nothing found" be told apart from a stored None.
        result = ti.xcom_pull(
            task_ids=task_id,
            key=self.key,
            default=NOTSET,
            map_indexes=map_indexes,
        )
        if is_arg_set(result):
            return result
        if self.key == BaseXCom.XCOM_RETURN_KEY:
            return None
        if getattr(self.operator, "multiple_outputs", False):
            # If the operator is set to have multiple outputs and it was not executed,
            # we should return "None" instead of showing an error. This is because when
            # multiple outputs XComs are created, the XCom keys associated with them will have
            # different names than the predefined "XCOM_RETURN_KEY" and won't be found.
            # Therefore, it's better to return "None" like we did above where self.key==XCOM_RETURN_KEY.
            return None
        raise XComNotFound(ti.dag_id, task_id, self.key)
377
378
379def _get_callable_name(f: Callable | str) -> str:
380 """Try to "describe" a callable by getting its name."""
381 if callable(f):
382 return f.__name__
383 # Parse the source to find whatever is behind "def". For safety, we don't
384 # want to evaluate the code in any meaningful way!
385 with contextlib.suppress(Exception):
386 kw, name, _ = f.lstrip().split(None, 2)
387 if kw == "def":
388 return name
389 return "<function>"
390
391
@attrs.define
class _MapResult(Sequence):
    """
    Lazy sequence view over ``value`` with ``callables`` applied per item.

    Nothing is transformed up front. Each item is passed through every
    callable, in order, only when it is accessed by index.
    """

    value: Sequence | dict
    callables: MapCallables

    def __getitem__(self, index: Any) -> Any:
        item = self.value[index]
        for transform in self.callables:
            item = transform(item)
        return item

    def __len__(self) -> int:
        # Mapping is one-to-one, so the length is that of the underlying value.
        return len(self.value)
406
407
@attrs.define
class MapXComArg(XComArg):
    """
    An XCom reference with ``map()`` call(s) applied.

    This is based on an XComArg, but also applies a series of "transforms" that
    convert the pulled XCom value.

    :meta private:
    """

    # The XComArg whose resolved value is transformed.
    arg: XComArg
    # Transforms applied in order to each item of the resolved value.
    callables: MapCallables

    def __attrs_post_init__(self) -> None:
        # A @task-decorated function creates an operator; map() needs a plain callable.
        if any(getattr(fn, "_airflow_is_task_decorator", False) for fn in self.callables):
            raise ValueError("map() argument must be a plain function, not a @task operator")

    def __repr__(self) -> str:
        chain = "".join(f".map({_get_callable_name(fn)})" for fn in self.callables)
        return repr(self.arg) + chain

    def _serialize(self) -> dict[str, Any]:
        serialized_arg = serialize_xcom_arg(self.arg)
        # Real callables are stored as their source text; strings pass through as-is.
        sources = [c if not callable(c) else inspect.getsource(c) for c in self.callables]
        return {"arg": serialized_arg, "callables": sources}

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        yield from self.arg.iter_references()

    def map(self, f: Callable[[Any], Any]) -> MapXComArg:
        # Chained maps collapse into one MapXComArg over the same base arg.
        chained = [*self.callables, f]
        return MapXComArg(self.arg, chained)

    def resolve(self, context: Mapping[str, Any]) -> Any:
        resolved = self.arg.resolve(context)
        if isinstance(resolved, (Sequence, dict)):
            return _MapResult(resolved, self.callables)
        raise ValueError(f"XCom map expects sequence or dict, not {type(resolved).__name__}")
449
450
@attrs.define
class _ZipResult(Sequence):
    """
    Lazy sequence view that zips several resolved XCom values together.

    Without ``fillvalue`` it behaves like the built-in ``zip()`` (length of the
    shortest input). With ``fillvalue`` set it behaves like
    ``itertools.zip_longest()``: its length is that of the longest input, and
    missing positions are filled with ``fillvalue``.
    """

    values: Sequence[Sequence | dict]
    fillvalue: Any = attrs.field(default=NOTSET, kw_only=True)

    @staticmethod
    def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any:
        """Return ``container[index]``, or ``fillvalue`` if that position/key is absent."""
        try:
            return container[index]
        except (IndexError, KeyError):
            return fillvalue

    def __getitem__(self, index: Any) -> Any:
        length = len(self)
        # Normalize negative indexes against the *zipped* length, like a list.
        # Passing them straight through would index each input from its own
        # end and misalign inputs of different lengths.
        position = index + length if index < 0 else index
        if not 0 <= position < length:
            raise IndexError(index)
        return tuple(self._get_or_fill(value, position, self.fillvalue) for value in self.values)

    def __len__(self) -> int:
        lengths = (len(v) for v in self.values)
        if is_arg_set(self.fillvalue):
            return max(lengths)
        return min(lengths)
473
474
@attrs.define
class ZipXComArg(XComArg):
    """
    An XCom reference with ``zip()`` applied.

    This is constructed from multiple XComArg instances, and presents an
    iterable that "zips" them together like the built-in ``zip()`` (and
    ``itertools.zip_longest()`` if ``fillvalue`` is provided).
    """

    # The XComArgs being zipped; at least one is required.
    args: Sequence[XComArg] = attrs.field(validator=attrs.validators.min_len(1))
    # Padding for shorter inputs; NOTSET means the result stops at the shortest input.
    fillvalue: Any = attrs.field(default=NOTSET, kw_only=True)

    def __repr__(self) -> str:
        """
        Render as the ``first.zip(...)`` call that would build this object.

        The argument list is assembled before joining, so a zip with no other
        inputs renders as ``first.zip()`` or ``first.zip(fillvalue=...)``
        instead of with a dangling leading comma.
        """
        first, *others = self.args
        head = repr(first)
        call_args = [repr(arg) for arg in others]
        if is_arg_set(self.fillvalue):
            call_args.append(f"fillvalue={self.fillvalue!r}")
        return f"{head}.zip({', '.join(call_args)})"

    def _serialize(self) -> dict[str, Any]:
        args = [serialize_xcom_arg(arg) for arg in self.args]
        # fillvalue is only stored when explicitly set.
        if is_arg_set(self.fillvalue):
            return {"args": args, "fillvalue": self.fillvalue}
        return {"args": args}

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        for arg in self.args:
            yield from arg.iter_references()

    def resolve(self, context: Mapping[str, Any]) -> Any:
        """
        Resolve every zipped XComArg and return a lazy zipped view.

        :raises ValueError: If any resolved value is neither a sequence nor a dict.
        """
        values = [arg.resolve(context) for arg in self.args]
        for value in values:
            if not isinstance(value, (Sequence, dict)):
                raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
        return _ZipResult(values, fillvalue=self.fillvalue)
512
513
@attrs.define
class _ConcatResult(Sequence):
    """
    Lazy sequence view chaining several resolved XCom values end to end.

    Negative indexes count from the end of the combined sequence. For a dict
    input, the item at a position is the dict key at that position.
    """

    values: Sequence[Sequence | dict]

    def __getitem__(self, index: Any) -> Any:
        offset = index if index >= 0 else len(self) + index
        # A still-negative offset is out of range; skip straight to the error.
        if offset >= 0:
            for container in self.values:
                size = len(container)
                if offset < size:
                    if isinstance(container, Sequence):
                        return container[offset]
                    return next(itertools.islice(iter(container), offset, None))
                offset -= size
        raise IndexError("list index out of range")

    def __len__(self) -> int:
        return sum(len(v) for v in self.values)
536
537
@attrs.define
class ConcatXComArg(XComArg):
    """
    Concatenating multiple XCom references into one.

    This is done by calling ``concat()`` on an XComArg to combine it with
    others. The effect is similar to Python's :func:`itertools.chain`, but the
    return value also supports index access.
    """

    # The XComArgs being concatenated, in order; at least one is required.
    args: Sequence[XComArg] = attrs.field(validator=attrs.validators.min_len(1))

    def __repr__(self) -> str:
        rendered = [repr(arg) for arg in self.args]
        return f"{rendered[0]}.concat({', '.join(rendered[1:])})"

    def _serialize(self) -> dict[str, Any]:
        serialized_args = [serialize_xcom_arg(arg) for arg in self.args]
        return {"args": serialized_args}

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        for component in self.args:
            yield from component.iter_references()

    def concat(self, *others: XComArg) -> ConcatXComArg:
        # Chained concats collapse into a single flat ConcatXComArg.
        combined = [*self.args, *others]
        return ConcatXComArg(combined)

    def resolve(self, context: Mapping[str, Any]) -> Any:
        resolved = [arg.resolve(context) for arg in self.args]
        for item in resolved:
            if isinstance(item, (Sequence, dict)):
                continue
            raise ValueError(f"XCom concat expects sequence or dict, not {type(item).__name__}")
        return _ConcatResult(resolved)
573
574
# Registry mapping the "type" tag used in serialized XComArg dicts to the
# XComArg subclass it denotes. The empty key stands for PlainXComArg, for which
# serialize_xcom_arg omits the "type" entry entirely. serialize_xcom_arg picks
# the first entry whose class matches via isinstance, so entry order matters.
_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = {
    "": PlainXComArg,
    "concat": ConcatXComArg,
    "map": MapXComArg,
    "zip": ZipXComArg,
}
581
582
def serialize_xcom_arg(value: XComArg) -> dict[str, Any]:
    """
    Dag serialization interface.

    Serialize ``value`` via its ``_serialize()`` and tag the result with the
    ``"type"`` key registered for its class in ``_XCOM_ARG_TYPES``. A plain
    XComArg (registered under the empty key) is emitted without a ``"type"``
    entry.

    :param value: The XComArg to serialize.
    :return: The serialized dict.
    :raises TypeError: If ``value`` is not an instance of any registered XComArg class.
    """
    key = next((k for k, v in _XCOM_ARG_TYPES.items() if isinstance(value, v)), None)
    if key is None:
        # Without a default, next() would leak a bare StopIteration, which can
        # silently end an enclosing iteration (e.g. map()) instead of failing loudly.
        raise TypeError(f"Cannot serialize unsupported XComArg type: {type(value).__name__}")
    if key:
        return {"type": key, **value._serialize()}
    return value._serialize()
589
590
@singledispatch
def get_task_map_length(
    xcom_arg: XComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]
) -> int | None:
    """
    Compute the task-map length contributed by ``xcom_arg``.

    Dispatches on the concrete XComArg class. Implementations are registered
    separately for each kind of XComArg, and ``None`` signals that a referenced
    XCom is not ready yet.

    :raises NotImplementedError: For types without a registered implementation.
    """
    # Generic fallback: reached only when no overload matches the argument's type.
    raise NotImplementedError
597
598
@get_task_map_length.register
def _(xcom_arg: PlainXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]):
    """
    Length for a plain XCom reference.

    Returns ``len(resolved_val)`` multiplied by the entry for the referenced
    task in ``upstream_map_indexes``. A missing or falsy entry counts as 1.
    """
    operator = xcom_arg.operator
    task_id = operator.task_id
    if operator.is_mapped:
        # TODO: How to tell if all the upstream TIs finished?
        pass
    factor = upstream_map_indexes.get(task_id) or 1
    return factor * len(resolved_val)
607
608
@get_task_map_length.register
def _(xcom_arg: MapXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]):
    """Length for a mapped XCom: ``map()`` is one-to-one, so defer to the wrapped arg."""
    inner = xcom_arg.arg
    return get_task_map_length(inner, resolved_val, upstream_map_indexes)
612
613
@get_task_map_length.register
def _(xcom_arg: ZipXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]):
    """
    Length for a zipped XCom.

    The result is the shortest component length, or the longest one when a
    fillvalue is set. Returns ``None`` if any component is not ready.
    """
    lengths = [get_task_map_length(arg, resolved_val, upstream_map_indexes) for arg in xcom_arg.args]
    if any(length is None for length in lengths):
        return None  # If any of the referenced XComs is not ready, we are not ready either.
    pick = max if is_arg_set(xcom_arg.fillvalue) else min
    return pick(lengths)
623
624
@get_task_map_length.register
def _(xcom_arg: ConcatXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]):
    """Length for a concatenated XCom: the sum of its components, or ``None`` if any is not ready."""
    lengths = [get_task_map_length(arg, resolved_val, upstream_map_indexes) for arg in xcom_arg.args]
    if any(length is None for length in lengths):
        return None  # If any of the referenced XComs is not ready, we are not ready either.
    return sum(lengths)