1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import collections
21from typing import Any, Protocol
22
23import structlog
24
25from airflow.sdk.execution_time.comms import (
26 DeleteXCom,
27 GetXCom,
28 GetXComSequenceSlice,
29 SetXCom,
30 XComResult,
31 XComSequenceSliceResult,
32)
33
# Minimal stand-in that exposes a single ``.value`` attribute. ``get_all`` wraps
# each raw item of a sequence-slice result in it so the item can be handed to
# ``deserialize_value``, which reads ``result.value``.
_XComValueWrapper = collections.namedtuple("_XComValueWrapper", ["value"])

# Structured logger obtained with ``logger_name="task"``.
log = structlog.get_logger(logger_name="task")
38
39
class TIKeyProtocol(Protocol):
    """
    Structural type for an object that identifies one task instance.

    Any object exposing these four attributes can be passed as ``ti_key`` to
    ``BaseXCom.get_value``, which forwards them to ``BaseXCom.get_one``.
    """

    # ID of the Dag the task belongs to.
    dag_id: str
    # ID of the task.
    task_id: str
    # ID of the Dag run.
    run_id: str
    # Map index of the task instance (``BaseXCom.set`` uses ``-1`` for a
    # non-mapped task).
    map_index: int
45
46
class BaseXCom:
    """
    BaseXcom is an interface to interact with XCom backends.

    Every operation is carried out by sending a request message through
    ``SUPERVISOR_COMMS``. ``serialize_value``, ``deserialize_value`` and
    ``purge`` are the hooks a custom backend is expected to override.
    """

    XCOM_RETURN_KEY = "return_value"

    @classmethod
    def set(
        cls,
        key: str,
        value: Any,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int = -1,
        _mapped_length: int | None = None,
    ) -> None:
        """
        Serialize a value with ``serialize_value`` and send it off as an XCom.

        :param key: Key to store the XCom.
        :param value: XCom value to store; it is serialized before sending.
        :param dag_id: Dag ID.
        :param task_id: Task ID.
        :param run_id: Dag run ID for the task.
        :param map_index: Optional map index to assign XCom for a mapped task.
            The default is ``-1`` (set for a non-mapped task).
        :param _mapped_length: Forwarded unchanged as ``mapped_length`` on the
            ``SetXCom`` message.
        """
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        serialized = cls.serialize_value(
            value=value,
            key=key,
            task_id=task_id,
            dag_id=dag_id,
            run_id=run_id,
            map_index=map_index,
        )
        request = SetXCom(
            key=key,
            value=serialized,
            dag_id=dag_id,
            task_id=task_id,
            run_id=run_id,
            map_index=map_index,
            mapped_length=_mapped_length,
        )
        SUPERVISOR_COMMS.send(request)

    @classmethod
    def _set_xcom_in_db(
        cls,
        key: str,
        value: Any,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int = -1,
    ) -> None:
        """
        Store an XCom value directly, bypassing ``serialize_value``.

        Unlike ``set``, the value is placed on the ``SetXCom`` message exactly
        as given.

        :param key: Key to store the XCom.
        :param value: XCom value to store; sent as-is.
        :param dag_id: Dag ID.
        :param task_id: Task ID.
        :param run_id: Dag run ID for the task.
        :param map_index: Optional map index to assign XCom for a mapped task.
            The default is ``-1`` (set for a non-mapped task).
        """
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        request = SetXCom(
            key=key,
            value=value,
            dag_id=dag_id,
            task_id=task_id,
            run_id=run_id,
            map_index=map_index,
        )
        SUPERVISOR_COMMS.send(request)

    @classmethod
    def get_value(
        cls,
        *,
        ti_key: TIKeyProtocol,
        key: str,
    ) -> Any:
        """
        Retrieve an XCom value for a task instance.

        Convenience wrapper around ``get_one``: the Dag ID, task ID, run ID and
        map index are read from ``ti_key`` and the result of ``get_one`` is
        returned unchanged (so *None* when no value was found).

        :param ti_key: The TaskInstanceKey to look up the XCom for.
        :param key: Key of the XCom to retrieve.
        """
        task_id = ti_key.task_id
        dag_id = ti_key.dag_id
        run_id = ti_key.run_id
        map_index = ti_key.map_index
        return cls.get_one(
            key=key,
            task_id=task_id,
            dag_id=dag_id,
            run_id=run_id,
            map_index=map_index,
        )

    @classmethod
    def _get_xcom_db_ref(
        cls,
        *,
        key: str,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int | None = None,
    ) -> XComResult:
        """
        Fetch the raw ``XComResult`` for an XCom without deserializing it.

        The reply to the ``GetXCom`` request is returned as-is;
        ``deserialize_value`` is *not* applied. ``delete`` uses this to obtain
        the object it hands to ``purge``.

        :param key: Key of the XCom to look up.
        :param dag_id: Dag ID of the XCom.
        :param task_id: Task ID of the XCom.
        :param run_id: Dag run ID for the task.
        :param map_index: Map index of the XCom; forwarded unchanged on the
            request (default *None*).
        :raises TypeError: If the reply is not an ``XComResult``.
        """
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        request = GetXCom(
            key=key,
            dag_id=dag_id,
            task_id=task_id,
            run_id=run_id,
            map_index=map_index,
        )
        response = SUPERVISOR_COMMS.send(request)

        if isinstance(response, XComResult):
            return response
        raise TypeError(f"Expected XComResult, received: {type(response)} {response}")

    @classmethod
    def get_one(
        cls,
        *,
        key: str,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int | None = None,
        include_prior_dates: bool = False,
    ) -> Any | None:
        """
        Retrieve a single, deserialized XCom value.

        This method returns "full" XCom values (i.e. uses ``deserialize_value``
        from the XCom backend). If the reply carries no value, a warning is
        logged and *None* is returned.

        .. seealso:: ``get_value()`` is a convenience function if you already
            have a structured TaskInstance or TaskInstanceKey object available.

        :param key: Key of the XCom to retrieve.
        :param dag_id: Dag ID to pull the XCom from.
        :param task_id: Task ID to pull the XCom from.
        :param run_id: Dag run ID for the task.
        :param map_index: Map index of the XCom; forwarded unchanged on the
            request (default *None*).
        :param include_prior_dates: Forwarded on the request. If *False*
            (default), only XCom from the specified Dag run is returned. If
            *True*, the latest matching XCom is returned regardless of the run
            it belongs to.
        :raises TypeError: If the reply is not an ``XComResult``.
        """
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        request = GetXCom(
            key=key,
            dag_id=dag_id,
            task_id=task_id,
            run_id=run_id,
            map_index=map_index,
            include_prior_dates=include_prior_dates,
        )
        response = SUPERVISOR_COMMS.send(request)

        if not isinstance(response, XComResult):
            raise TypeError(f"Expected XComResult, received: {type(response)} {response}")

        if response.value is None:
            # Nothing stored under this key: report it and fall back to None.
            log.warning(
                "No XCom value found; defaulting to None.",
                key=key,
                dag_id=dag_id,
                task_id=task_id,
                run_id=run_id,
                map_index=map_index,
            )
            return None

        return cls.deserialize_value(response)

    @classmethod
    def get_all(
        cls,
        *,
        key: str,
        dag_id: str,
        task_id: str,
        run_id: str,
        include_prior_dates: bool = False,
    ) -> Any:
        """
        Retrieve all XCom values for a task, typically from all map indexes.

        The whole sequence is requested (``start``, ``stop`` and ``step`` are
        all *None*). Each returned item is passed through
        ``deserialize_value``. When the result is empty, *None* is returned
        rather than an empty list.

        :param key: A key for the XCom. Only XComs with this key will be returned.
        :param dag_id: Dag ID to pull XComs from.
        :param task_id: Task ID to pull XComs from.
        :param run_id: Dag run ID for the task.
        :param include_prior_dates: Forwarded on the request. If *False*
            (default), only XComs from the specified Dag run are returned. If
            *True*, the latest matching XComs are returned regardless of the
            run they belong to.
        :return: List of deserialized XCom values, or *None* if there were none.
        :raises TypeError: If the reply is not an ``XComSequenceSliceResult``.
        """
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        request = GetXComSequenceSlice(
            key=key,
            dag_id=dag_id,
            task_id=task_id,
            run_id=run_id,
            start=None,
            stop=None,
            step=None,
            include_prior_dates=include_prior_dates,
        )
        response = SUPERVISOR_COMMS.send(msg=request)

        if not isinstance(response, XComSequenceSliceResult):
            raise TypeError(f"Expected XComSequenceSliceResult, received: {type(response)} {response}")

        raw_values = response.root
        if not raw_values:
            return None

        # Wrap each raw item so ``deserialize_value`` can read ``.value`` from it.
        return [cls.deserialize_value(_XComValueWrapper(raw)) for raw in raw_values]

    @staticmethod
    def serialize_value(
        value: Any,
        *,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> str:
        """
        Serialize an XCom value with ``airflow.serialization.serde.serialize``.

        The identifying keyword arguments are not used by this base
        implementation; ``set`` passes them so that overriding backends can
        make use of them.
        """
        from airflow.serialization.serde import serialize

        # The base class only runs the generic serializer; custom backends
        # are expected to provide their own implementation.
        return serialize(value)  # type: ignore[return-value]

    @staticmethod
    def deserialize_value(result) -> Any:
        """
        Deserialize the ``value`` attribute of an XCom result.

        :param result: Any object exposing the stored payload as ``.value``.
        """
        from airflow.serialization.serde import deserialize

        payload = result.value
        return deserialize(payload)

    @classmethod
    def purge(cls, xcom: XComResult, *args) -> None:
        """
        Purge an XCom entry from underlying storage implementations.

        The base implementation does nothing.
        """
        return None

    @classmethod
    def delete(
        cls,
        key: str,
        task_id: str,
        dag_id: str,
        run_id: str,
        map_index: int | None = None,
    ) -> None:
        """
        Delete an XCom entry.

        The raw stored reference is looked up first and handed to ``purge`` so
        that a custom backend can remove the data it holds for it; afterwards
        a ``DeleteXCom`` request is sent.

        :param key: Key of the XCom to delete.
        :param task_id: Task ID of the XCom.
        :param dag_id: Dag ID of the XCom.
        :param run_id: Dag run ID for the task.
        :param map_index: Map index of the XCom; forwarded unchanged (default
            *None*).
        """
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        stored_ref = cls._get_xcom_db_ref(
            key=key,
            dag_id=dag_id,
            task_id=task_id,
            run_id=run_id,
            map_index=map_index,
        )
        cls.purge(stored_ref)

        request = DeleteXCom(
            key=key,
            dag_id=dag_id,
            task_id=task_id,
            run_id=run_id,
            map_index=map_index,
        )
        SUPERVISOR_COMMS.send(request)