1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import collections
21import itertools
22from collections.abc import Iterator, Sequence
23from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
24
25import attrs
26import structlog
27
28if TYPE_CHECKING:
29 from airflow.sdk.definitions.xcom_arg import PlainXComArg
30 from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
31
T = TypeVar("T")  # element type of a lazy XCom sequence and its iterator

# API replies carry the bare XCom value. Wrapping each value in this one-field
# namedtuple makes it structurally compatible with ``XCom.deserialize_value``
# without having the server send a nested {"value": value} dict, which would
# only waste bandwidth.
_XComWrapper = collections.namedtuple("_XComWrapper", ["value"])

# Module-level structured logger.
log = structlog.get_logger(logger_name=__name__)
40
41
@attrs.define
class LazyXComIterator(Iterator[T]):
    """
    Iterator that walks a :class:`LazyXComSequence` one index at a time.

    Each step fetches a single item via ``seq[index]`` and then moves ``index``
    by ``dir`` (``1`` walks forward, ``-1`` walks backward). Iteration ends when
    the sequence raises ``IndexError`` or when the index drops below zero.
    """

    seq: LazyXComSequence[T]
    index: int = 0
    dir: Literal[1, -1] = 1

    def __next__(self) -> T:
        position = self.index
        # A backward walk is finished once it steps below 0; stop locally rather
        # than sending a request for a negative index.
        if position < 0:
            raise StopIteration()
        try:
            item = self.seq[position]
        except IndexError:
            raise StopIteration from None
        self.index = position + self.dir
        return item

    def __iter__(self) -> Iterator[T]:
        return self
61
62
@attrs.define
class LazyXComSequence(Sequence[T]):
    """
    Read-only sequence view over the XCom values pushed by an upstream task.

    Nothing is fetched up front. The length is requested from the supervisor
    (``GetXComCount``) on the first ``len()`` call and then cached. Every item or
    slice access sends a fresh ``GetXComSequenceItem`` / ``GetXComSequenceSlice``
    request, and the reply is deserialized with ``XCom.deserialize_value``.

    :param xcom_arg: XCom argument whose ``operator`` (dag_id / task_id) and
        ``key`` identify the values to read.
    :param ti: The running task instance; only its ``run_id`` is used here.
    """

    # Cached item count; None until __len__ has asked the supervisor once.
    _len: int | None = attrs.field(init=False, default=None)
    _xcom_arg: PlainXComArg = attrs.field(alias="xcom_arg")
    _ti: RuntimeTaskInstance = attrs.field(alias="ti")

    def __repr__(self) -> str:
        # Only show the length when it is already cached, so repr() never
        # triggers a request to the supervisor.
        if self._len is not None:
            counter = "item" if (length := len(self)) == 1 else "items"
            return f"LazyXComSequence([{length} {counter}])"
        return "LazyXComSequence(<unevaluated length>)"

    def __str__(self) -> str:
        return repr(self)

    def __eq__(self, other: Any) -> bool:
        """
        Compare element-wise with any other sequence.

        Items are fetched one at a time. The shorter side is padded with a fresh
        sentinel object, so extra items on either side normally make the
        comparison fail. Non-sequences return ``NotImplemented``.
        """
        if not isinstance(other, Sequence):
            return NotImplemented
        z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
        return all(x == y for x, y in z)

    def __hash__(self):
        # Hash by content. Pass ``iter(self)`` rather than ``self``: ``tuple()``
        # asks its argument for a length hint, which on ``self`` would call
        # ``__len__`` and send an extra GetXComCount request.
        return hash(tuple(iter(self)))

    def __iter__(self) -> Iterator[T]:
        return LazyXComIterator(seq=self)

    def __len__(self) -> int:
        """
        Return the number of values, asking the supervisor on first use only.

        :raises RuntimeError: if the supervisor answers with an ``ErrorResponse``.
        :raises TypeError: if the supervisor answers with any other unexpected message.
        """
        if self._len is None:
            from airflow.sdk.execution_time.comms import ErrorResponse, GetXComCount, XComCountResponse
            from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

            task = self._xcom_arg.operator

            msg = SUPERVISOR_COMMS.send(
                GetXComCount(
                    key=self._xcom_arg.key,
                    dag_id=task.dag_id,
                    run_id=self._ti.run_id,
                    task_id=task.task_id,
                ),
            )
            if isinstance(msg, ErrorResponse):
                raise RuntimeError(msg)
            if not isinstance(msg, XComCountResponse):
                raise TypeError(f"Got unexpected response to GetXComCount: {msg!r}")
            self._len = msg.len
        return self._len

    @overload
    def __getitem__(self, key: int) -> T: ...

    @overload
    def __getitem__(self, key: slice) -> Sequence[T]: ...

    def __getitem__(self, key: int | slice) -> T | Sequence[T]:
        """
        Fetch one item, or a slice of items, from the supervisor.

        :param key: An integer index (or any object implementing ``__index__``),
            or a ``slice``. Index and slice bounds are forwarded to the supervisor
            as-is. How negative values are resolved is decided on the server side.
        :return: The deserialized item, or a list of deserialized items for a slice.
        :raises IndexError: if the supervisor answers an item request with an
            ``ErrorResponse``.
        :raises TypeError: if *key* (or a slice component) is not an int, ``None``
            (slices only), or ``__index__``-able object, or if the supervisor
            replies with an unexpected message type.
        :raises ValueError: if a slice step is zero.
        """
        from airflow.sdk.execution_time.comms import (
            ErrorResponse,
            GetXComSequenceItem,
            GetXComSequenceSlice,
            XComSequenceIndexResult,
            XComSequenceSliceResult,
        )
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
        from airflow.sdk.execution_time.xcom import XCom

        if isinstance(key, slice):
            start, stop, step = _coerce_slice(key)
            source = (xcom_arg := self._xcom_arg).operator
            msg = SUPERVISOR_COMMS.send(
                GetXComSequenceSlice(
                    key=xcom_arg.key,
                    dag_id=source.dag_id,
                    task_id=source.task_id,
                    run_id=self._ti.run_id,
                    start=start,
                    stop=stop,
                    step=step,
                ),
            )
            if not isinstance(msg, XComSequenceSliceResult):
                raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}")
            return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root]

        if not isinstance(key, int):
            # Accept int-like objects via __index__, as built-in sequences do.
            # Only objects without __index__ are rejected; previously the
            # TypeError was raised even after a successful conversion.
            if (index := getattr(key, "__index__", None)) is None:
                raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}")
            key = index()

        source = (xcom_arg := self._xcom_arg).operator
        msg = SUPERVISOR_COMMS.send(
            GetXComSequenceItem(
                key=xcom_arg.key,
                dag_id=source.dag_id,
                task_id=source.task_id,
                run_id=self._ti.run_id,
                offset=key,
            ),
        )
        if isinstance(msg, ErrorResponse):
            raise IndexError(key)
        if not isinstance(msg, XComSequenceIndexResult):
            raise TypeError(f"Got unexpected response to GetXComSequenceItem: {msg!r}")
        return XCom.deserialize_value(_XComWrapper(msg.root))
167
168
169def _coerce_slice_index(value: Any) -> int | None:
170 """
171 Check slice attribute's type and convert it to int.
172
173 See CPython documentation on this:
174 https://docs.python.org/3/reference/datamodel.html#object.__index__
175 """
176 if value is None or isinstance(value, int):
177 return value
178 if (index := getattr(value, "__index__", None)) is not None:
179 return index()
180 raise TypeError("slice indices must be integers or None or have an __index__ method")
181
182
def _coerce_slice(key: slice) -> tuple[int | None, int | None, int | None]:
    """
    Validate a slice and return its ``(start, stop, step)`` as ints or ``None``.

    The step is checked first, so a zero step raises ``ValueError`` before any
    problem with ``start``/``stop`` is reported.

    See CPython documentation on this:
    https://docs.python.org/3/reference/datamodel.html#slice-objects

    :raises ValueError: if the step is zero.
    :raises TypeError: if any component is not ``None``, an ``int`` or ``__index__``-able.
    """
    step = _coerce_slice_index(key.step)
    if step == 0:
        raise ValueError("slice step cannot be zero")
    start = _coerce_slice_index(key.start)
    stop = _coerce_slice_index(key.stop)
    return start, stop, step