Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/xcom.py: 44%
310 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections.abc
21import contextlib
22import datetime
23import inspect
24import itertools
25import json
26import logging
27import pickle
28import warnings
29from functools import wraps
30from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload
32import attr
33import pendulum
34from sqlalchemy import (
35 Column,
36 ForeignKeyConstraint,
37 Index,
38 Integer,
39 LargeBinary,
40 PrimaryKeyConstraint,
41 String,
42 text,
43)
44from sqlalchemy.ext.associationproxy import association_proxy
45from sqlalchemy.orm import Query, Session, reconstructor, relationship
46from sqlalchemy.orm.exc import NoResultFound
48from airflow import settings
49from airflow.compat.functools import cached_property
50from airflow.configuration import conf
51from airflow.exceptions import RemovedInAirflow3Warning
52from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
53from airflow.utils import timezone
54from airflow.utils.helpers import exactly_one, is_container
55from airflow.utils.json import XComDecoder, XComEncoder
56from airflow.utils.log.logging_mixin import LoggingMixin
57from airflow.utils.session import NEW_SESSION, provide_session
58from airflow.utils.sqlalchemy import UtcDateTime
log: logging.Logger = logging.getLogger(__name__)

# Maximum XCom size in bytes: 49,344 bytes, i.e. roughly 48 KB.
# https://github.com/apache/airflow/pull/1618#discussion_r68249677
MAX_XCOM_SIZE: int = 49_344
# Key under which a task's return value is stored as an XCom.
XCOM_RETURN_KEY: str = "return_value"
67if TYPE_CHECKING:
68 from airflow.models.taskinstance import TaskInstanceKey
class BaseXCom(Base, LoggingMixin):
    """Base class for XCom objects.

    Each row is identified by ``(dag_run_id, task_id, map_index, key)`` and
    stores the serialized payload in ``value``. Custom XCom backends must
    subclass this class and may override :meth:`serialize_value`,
    :meth:`deserialize_value` and :meth:`orm_deserialize_value`.
    """

    __tablename__ = "xcom"

    dag_run_id = Column(Integer(), nullable=False, primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
    map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
    key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)

    # Denormalized for easier lookup.
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)

    value = Column(LargeBinary)
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __table_args__ = (
        # Ideally we should create a unique index over (key, dag_id, task_id, run_id),
        # but it goes over MySQL's index length limit. So we instead index 'key'
        # separately, and enforce uniqueness with DagRun.id instead.
        Index("idx_xcom_key", key),
        Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
        PrimaryKeyConstraint(
            "dag_run_id", "task_id", "map_index", "key", name="xcom_pkey", mssql_clustered=True
        ),
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="xcom_task_instance_fkey",
            ondelete="CASCADE",
        ),
    )

    dag_run = relationship(
        "DagRun",
        primaryjoin="BaseXCom.dag_run_id == foreign(DagRun.id)",
        uselist=False,
        lazy="joined",
        passive_deletes="all",
    )
    execution_date = association_proxy("dag_run", "execution_date")

    @reconstructor
    def init_on_load(self):
        """
        Called by the ORM after the instance has been loaded from the DB or otherwise reconstituted
        i.e automatically deserialize Xcom value when loading from DB.
        """
        self.value = self.orm_deserialize_value()

    def __repr__(self):
        if self.map_index < 0:
            return f'<XCom "{self.key}" ({self.task_id} @ {self.run_id})>'
        return f'<XCom "{self.key}" ({self.task_id}[{self.map_index}] @ {self.run_id})>'

    @overload
    @classmethod
    def set(
        cls,
        key: str,
        value: Any,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int = -1,
        session: Session = NEW_SESSION,
    ) -> None:
        """Store an XCom value.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        Any existing XCom with the same key, run, task and map index is deleted
        before the new one is inserted, so the call behaves as an upsert.

        :param key: Key to store the XCom.
        :param value: XCom value to store.
        :param dag_id: DAG ID.
        :param task_id: Task ID.
        :param run_id: DAG run ID for the task.
        :param map_index: Optional map index to assign XCom for a mapped task.
            The default is ``-1`` (set for a non-mapped task).
        :param session: Database session. If not given, a new session will be
            created for this function.
        :raises ValueError: If not exactly one of ``run_id``/``execution_date``
            is given, or if the referenced DAG run does not exist.
        """

    @overload
    @classmethod
    def set(
        cls,
        key: str,
        value: Any,
        task_id: str,
        dag_id: str,
        execution_date: datetime.datetime,
        session: Session = NEW_SESSION,
    ) -> None:
        """:sphinx-autoapi-skip:"""

    @classmethod
    @provide_session
    def set(
        cls,
        key: str,
        value: Any,
        task_id: str,
        dag_id: str,
        execution_date: datetime.datetime | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int = -1,
    ) -> None:
        """:sphinx-autoapi-skip:"""
        from airflow.models.dagrun import DagRun

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )

        # Resolve the DagRun primary key (and, on the deprecated path, the run_id too).
        if run_id is None:
            message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
            try:
                dag_run_id, run_id = (
                    session.query(DagRun.id, DagRun.run_id)
                    .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
                    .one()
                )
            except NoResultFound:
                raise ValueError(f"DAG run not found on DAG {dag_id!r} at {execution_date}") from None
        else:
            dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
            if dag_run_id is None:
                raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

        # Seamlessly resolve LazyXComAccess to a list. This is intended to work
        # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
        # it's pushed into XCom, the user should be aware of the performance
        # implications, and this avoids leaking the implementation detail.
        if isinstance(value, LazyXComAccess):
            warning_message = (
                "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
                "to list, which may degrade performance. Review resource "
                "requirements for this operation, and call list() to suppress "
                "this message. See Dynamic Task Mapping documentation for "
                "more information about lazy proxy objects."
            )
            log.warning(
                warning_message,
                "return value" if key == XCOM_RETURN_KEY else f"value {key}",
                task_id,
                dag_id,
                run_id or execution_date,
            )
            value = list(value)

        value = cls.serialize_value(
            value=value,
            key=key,
            task_id=task_id,
            dag_id=dag_id,
            run_id=run_id,
            map_index=map_index,
        )

        # Remove duplicate XComs and insert a new one.
        session.query(cls).filter(
            cls.key == key,
            cls.run_id == run_id,
            cls.task_id == task_id,
            cls.dag_id == dag_id,
            cls.map_index == map_index,
        ).delete()
        new = cast(Any, cls)(  # Work around Mypy complaining model not defining '__init__'.
            dag_run_id=dag_run_id,
            key=key,
            value=value,
            run_id=run_id,
            task_id=task_id,
            dag_id=dag_id,
            map_index=map_index,
        )
        session.add(new)
        session.flush()

    @classmethod
    @provide_session
    def get_value(
        cls,
        *,
        ti_key: TaskInstanceKey,
        key: str | None = None,
        session: Session = NEW_SESSION,
    ) -> Any:
        """Retrieve an XCom value for a task instance.

        This method returns "full" XCom values (i.e. uses ``deserialize_value``
        from the XCom backend). Use :meth:`get_many` if you want the "shortened"
        value via ``orm_deserialize_value``.

        If there are no results, *None* is returned. If multiple XCom entries
        match the criteria, an arbitrary one is returned.

        :param ti_key: The TaskInstanceKey to look up the XCom for.
        :param key: A key for the XCom. If provided, only XCom with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """
        return cls.get_one(
            key=key,
            task_id=ti_key.task_id,
            dag_id=ti_key.dag_id,
            run_id=ti_key.run_id,
            map_index=ti_key.map_index,
            session=session,
        )

    @overload
    @classmethod
    def get_one(
        cls,
        *,
        key: str | None = None,
        dag_id: str | None = None,
        task_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
    ) -> Any | None:
        """Retrieve an XCom value, optionally meeting certain criteria.

        This method returns "full" XCom values (i.e. uses ``deserialize_value``
        from the XCom backend). Use :meth:`get_many` if you want the "shortened"
        value via ``orm_deserialize_value``.

        If there are no results, *None* is returned. If multiple XCom entries
        match the criteria, the one from the latest DAG run (then the most
        recently written) is returned.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        .. seealso:: ``get_value()`` is a convenience function if you already
            have a structured TaskInstance or TaskInstanceKey object available.

        :param run_id: DAG run ID for the task.
        :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to
            remove the filter.
        :param task_id: Only XCom from task with matching ID will be pulled.
            Pass *None* (default) to remove the filter.
        :param map_index: Only XCom from the task with matching map index will
            be pulled. Pass *None* (default) to remove the filter.
        :param key: A key for the XCom. If provided, only XCom with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param include_prior_dates: If *False* (default), only XCom from the
            specified DAG run is returned. If *True*, the latest matching XCom is
            returned regardless of the run it belongs to.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :raises ValueError: If not exactly one of ``run_id``/``execution_date``
            is given.
        """

    @overload
    @classmethod
    def get_one(
        cls,
        execution_date: datetime.datetime,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
    ) -> Any | None:
        """:sphinx-autoapi-skip:"""

    @classmethod
    @provide_session
    def get_one(
        cls,
        execution_date: datetime.datetime | None = None,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> Any | None:
        """:sphinx-autoapi-skip:"""
        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )

        # Test for None explicitly: an empty-string run_id passes the check above
        # and must not fall through to the execution_date branch.
        if run_id is not None:
            query = cls.get_many(
                run_id=run_id,
                key=key,
                task_ids=task_id,
                dag_ids=dag_id,
                map_indexes=map_index,
                include_prior_dates=include_prior_dates,
                limit=1,
                session=session,
            )
        else:
            # exactly_one() above guarantees execution_date is set on this branch.
            message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)

            # get_many() would emit its own copy of the same deprecation warning; silence it.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RemovedInAirflow3Warning)
                query = cls.get_many(
                    execution_date=execution_date,
                    key=key,
                    task_ids=task_id,
                    dag_ids=dag_id,
                    map_indexes=map_index,
                    include_prior_dates=include_prior_dates,
                    limit=1,
                    session=session,
                )

        # Only the serialized value column of the first row (get_many orders newest first) is needed.
        result = query.with_entities(cls.value).first()
        if result:
            return cls.deserialize_value(result)
        return None

    @overload
    @classmethod
    def get_many(
        cls,
        *,
        run_id: str,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
    ) -> Query:
        """Composes a query to get one or more XCom entries.

        This function returns an SQLAlchemy query of full XCom objects. If you
        just want one stored value, use :meth:`get_one` instead.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        :param run_id: DAG run ID for the task.
        :param key: A key for the XComs. If provided, only XComs with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param task_ids: Only XComs from task with matching IDs will be pulled.
            Pass *None* (default) to remove the filter.
        :param dag_ids: Only pulls XComs from specified DAGs. Pass *None*
            (default) to remove the filter.
        :param map_indexes: Only XComs from matching map indexes will be pulled.
            Pass *None* (default) to remove the filter.
        :param include_prior_dates: If *False* (default), only XComs from the
            specified DAG run are returned. If *True*, all matching XComs are
            returned regardless of the run it belongs to.
        :param limit: If given (and non-zero), cap the number of rows returned.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :return: A query ordered by DAG run execution date, then XCom timestamp,
            both descending.
        """

    @overload
    @classmethod
    def get_many(
        cls,
        execution_date: datetime.datetime,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
    ) -> Query:
        """:sphinx-autoapi-skip:"""

    @classmethod
    @provide_session
    def get_many(
        cls,
        execution_date: datetime.datetime | None = None,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
    ) -> Query:
        """:sphinx-autoapi-skip:"""
        from airflow.models.dagrun import DagRun

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )
        if execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)

        query = session.query(cls).join(cls.dag_run)

        if key:
            query = query.filter(cls.key == key)

        if is_container(task_ids):
            query = query.filter(cls.task_id.in_(task_ids))
        elif task_ids is not None:
            query = query.filter(cls.task_id == task_ids)

        if is_container(dag_ids):
            query = query.filter(cls.dag_id.in_(dag_ids))
        elif dag_ids is not None:
            query = query.filter(cls.dag_id == dag_ids)

        # A contiguous range becomes a BETWEEN-style filter instead of a (possibly huge) IN list.
        if isinstance(map_indexes, range) and map_indexes.step == 1:
            query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
        elif is_container(map_indexes):
            query = query.filter(cls.map_index.in_(map_indexes))
        elif map_indexes is not None:
            query = query.filter(cls.map_index == map_indexes)

        if include_prior_dates:
            if execution_date is not None:
                query = query.filter(DagRun.execution_date <= execution_date)
            else:
                # NOTE(review): this subquery selects DagRuns by run_id only, not by
                # dag_id; if run_ids repeat across DAGs it may match several rows.
                dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery()
                query = query.filter(cls.execution_date <= dr.c.execution_date)
        elif execution_date is not None:
            query = query.filter(DagRun.execution_date == execution_date)
        else:
            query = query.filter(cls.run_id == run_id)

        query = query.order_by(DagRun.execution_date.desc(), cls.timestamp.desc())
        if limit:
            return query.limit(limit)
        return query

    @classmethod
    @provide_session
    def delete(cls, xcoms: XCom | Iterable[XCom], session: Session = NEW_SESSION) -> None:
        """Delete one or multiple XCom entries and commit the session.

        :param xcoms: A single XCom row, or an iterable of XCom rows.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :raises TypeError: If any element is not an ``XCom`` instance.
        """
        if isinstance(xcoms, XCom):
            xcoms = [xcoms]
        for xcom in xcoms:
            if not isinstance(xcom, XCom):
                raise TypeError(f"Expected XCom; received {xcom.__class__.__name__}")
            session.delete(xcom)
        session.commit()

    @overload
    @classmethod
    def clear(
        cls,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int | None = None,
        session: Session = NEW_SESSION,
    ) -> None:
        """Clear all XCom data from the database for the given task instance.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        :param dag_id: ID of DAG to clear the XCom for.
        :param task_id: ID of task to clear the XCom for.
        :param run_id: ID of DAG run to clear the XCom for.
        :param map_index: If given, only clear XCom from this particular mapped
            task. The default ``None`` clears *all* XComs from the task.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """

    @overload
    @classmethod
    def clear(
        cls,
        execution_date: pendulum.DateTime,
        dag_id: str,
        task_id: str,
        session: Session = NEW_SESSION,
    ) -> None:
        """:sphinx-autoapi-skip:"""

    @classmethod
    @provide_session
    def clear(
        cls,
        execution_date: pendulum.DateTime | None = None,
        dag_id: str | None = None,
        task_id: str | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> None:
        """:sphinx-autoapi-skip:"""
        from airflow.models import DagRun

        # Given the historic order of this function (execution_date was first argument) to add a new optional
        # param we need to add default values for everything :(
        if dag_id is None:
            raise TypeError("clear() missing required argument: dag_id")
        if task_id is None:
            raise TypeError("clear() missing required argument: task_id")

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )

        if execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
            run_id = (
                session.query(DagRun.run_id)
                .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
                .scalar()
            )
            if run_id is None:
                # No DagRun at that execution date; run_id is non-nullable, so no XCom can match.
                return

        query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
        if map_index is not None:
            query = query.filter_by(map_index=map_index)
        query.delete()

    @staticmethod
    def serialize_value(
        value: Any,
        *,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> Any:
        """Serialize XCom value to str or pickled object.

        Uses pickle when ``[core] enable_xcom_pickling`` is on, otherwise JSON
        (via ``XComEncoder``) encoded as UTF-8 bytes. The keyword arguments are
        accepted so custom backends can use them; this implementation ignores them.

        :raises ValueError: or TypeError if the value is not JSON-serializable
            (the error is logged before being re-raised).
        """
        if conf.getboolean("core", "enable_xcom_pickling"):
            return pickle.dumps(value)
        try:
            return json.dumps(value, cls=XComEncoder).encode("UTF-8")
        except (ValueError, TypeError) as ex:
            log.error(
                "%s."
                " If you are using pickle instead of JSON for XCom,"
                " then you need to enable pickle support for XCom"
                " in your airflow config or make sure to decorate your"
                " object with attr.",
                ex,
            )
            raise

    @staticmethod
    def _deserialize_value(result: XCom, orm: bool) -> Any:
        """Decode ``result.value`` using pickle and/or JSON.

        :param result: Anything with a ``value`` attribute holding the stored bytes.
        :param orm: If true, JSON objects are decoded with ``XComDecoder.orm_object_hook``.
        """
        object_hook = None
        if orm:
            object_hook = XComDecoder.orm_object_hook

        if result.value is None:
            return None
        # Each branch falls back to the other format so rows written under either
        # setting of enable_xcom_pickling stay readable.
        # SECURITY NOTE(review): the JSON branch falls back to pickle.loads() even
        # when pickling is disabled; pickle.loads on untrusted bytes can execute code.
        if conf.getboolean("core", "enable_xcom_pickling"):
            try:
                return pickle.loads(result.value)
            except pickle.UnpicklingError:
                return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
        else:
            try:
                return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
            except (json.JSONDecodeError, UnicodeDecodeError):
                return pickle.loads(result.value)

    @staticmethod
    def deserialize_value(result: XCom) -> Any:
        """Deserialize XCom value from str or pickle object"""
        return BaseXCom._deserialize_value(result, False)

    def orm_deserialize_value(self) -> Any:
        """
        Deserialize method which is used to reconstruct ORM XCom object.

        This method should be overridden in custom XCom backends to avoid
        unnecessary request or other resource consuming operations when
        creating XCom orm model. This is used when viewing XCom listing
        in the webserver, for example.
        """
        return BaseXCom._deserialize_value(self, True)
675class _LazyXComAccessIterator(collections.abc.Iterator):
676 def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
677 self._cm = cm
678 self._entered = False
680 def __del__(self) -> None:
681 if self._entered:
682 self._cm.__exit__(None, None, None)
684 def __iter__(self) -> collections.abc.Iterator:
685 return self
687 def __next__(self) -> Any:
688 return XCom.deserialize_value(next(self._it))
690 @cached_property
691 def _it(self) -> collections.abc.Iterator:
692 self._entered = True
693 return iter(self._cm.__enter__())
@attr.define(slots=True)
class LazyXComAccess(collections.abc.Sequence):
    """Wrapper to lazily pull XCom with a sequence-like interface.

    Note that since the session bound to the parent query may have died when we
    actually access the sequence's content, we must create a new session
    for every function call with ``with_session()``.

    :meta private:
    """

    _query: Query
    # Cached row count; filled by __len__ or restored by __setstate__.
    _len: int | None = attr.ib(init=False, default=None)

    @classmethod
    def build_from_xcom_query(cls, query: Query) -> LazyXComAccess:
        """Wrap an XCom query, narrowing it to the ``value`` column only."""
        return cls(query=query.with_entities(XCom.value))

    def __repr__(self) -> str:
        return f"LazyXComAccess([{len(self)} items])"

    def __str__(self) -> str:
        return str(list(self))

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, (list, LazyXComAccess)):
            # A unique fill value makes sequences of different lengths compare unequal.
            z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
            return all(x == y for x, y in z)
        return NotImplemented

    def __getstate__(self) -> Any:
        # We don't want to go to the trouble of serializing the entire Query
        # object, including its filters, hints, etc. (plus SQLAlchemy does not
        # provide a public API to inspect a query's contents). Converting the
        # query into a SQL string is the best we can get. Theoretically we can
        # do the same for count(), but I think it should be performant enough to
        # calculate only that eagerly.
        with self._get_bound_query() as query:
            statement = query.statement.compile(query.session.get_bind())
            return (str(statement), query.count())

    def __setstate__(self, state: Any) -> None:
        statement, self._len = state
        self._query = Query(XCom.value).from_statement(text(statement))

    def __len__(self):
        if self._len is None:
            with self._get_bound_query() as query:
                self._len = query.count()
        return self._len

    def __iter__(self):
        return _LazyXComAccessIterator(self._get_bound_query())

    def __getitem__(self, key):
        """Return the deserialized value at integer index ``key``.

        Negative indexes count from the end, as for ``list``.

        :raises ValueError: If ``key`` is not an int (slices are unsupported).
        :raises IndexError: If ``key`` is out of range.
        """
        if not isinstance(key, int):
            raise ValueError("only support index access for now")
        index = key
        if index < 0:
            # Normalise like a list: a negative OFFSET is rejected by some
            # databases (e.g. PostgreSQL) and silently treated as 0 by others.
            index += len(self)
            if index < 0:
                raise IndexError(key)
        try:
            with self._get_bound_query() as query:
                r = query.offset(index).limit(1).one()
        except NoResultFound:
            raise IndexError(key) from None
        return XCom.deserialize_value(r)

    @contextlib.contextmanager
    def _get_bound_query(self) -> Generator[Query, None, None]:
        """Yield the query bound to a live session, opening a temporary one if needed."""
        # Do we have a valid session already?
        if self._query.session and self._query.session.is_active:
            yield self._query
            return

        session = settings.Session()
        try:
            yield self._query.with_session(session)
        finally:
            session.close()
def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
    """Patch a custom ``serialize_value`` to accept the modern signature.

    To give custom XCom backends more flexibility with how they store values, we
    now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``.
    In order to maintain compatibility with custom XCom backends written with
    the old signature, we check the signature and, if necessary, patch with a
    method that ignores kwargs the backend does not accept.

    :param clazz: The XCom backend class whose ``serialize_value`` is replaced.
    :param params: Parameter names the backend's ``serialize_value`` accepts.
        Each is forwarded from the caller's kwargs (``None`` when absent).
    """
    old_serializer = clazz.serialize_value
    # Materialise once so a one-shot iterator still works on every call.
    accepted_params = list(params)

    @wraps(old_serializer)
    def _shim(**kwargs):
        kwargs = {k: kwargs.get(k) for k in accepted_params}
        warnings.warn(
            f"Method `serialize_value` in XCom backend {clazz.__name__} is using outdated signature and "
            f"must be updated to accept all params in `BaseXCom.set` except `session`. Support will be "
            f"removed in a future release.",
            RemovedInAirflow3Warning,
        )
        return old_serializer(**kwargs)

    # staticmethod keeps the shim callable identically through the class and its
    # instances, matching BaseXCom.serialize_value.
    clazz.serialize_value = staticmethod(_shim)  # type: ignore[assignment]
799def _get_function_params(function) -> list[str]:
800 """
801 Returns the list of variables names of a function
803 :param function: The function to inspect
804 """
805 parameters = inspect.signature(function).parameters
806 bound_arguments = [
807 name for name, p in parameters.items() if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
808 ]
809 return bound_arguments
def resolve_xcom_backend() -> type[BaseXCom]:
    """Resolve the XCom class configured by ``[core] xcom_backend``.

    Falls back to :class:`BaseXCom` when the option is unset or empty. The
    resolved class must subclass ``BaseXCom``. If its ``serialize_value`` does
    not accept the same parameter names as ``BaseXCom.serialize_value``, it is
    wrapped by ``_patch_outdated_serializer``.

    :raises TypeError: If the configured class does not extend ``BaseXCom``.
    """
    default_path = f"airflow.models.xcom.{BaseXCom.__name__}"
    clazz = conf.getimport("core", "xcom_backend", fallback=default_path)
    if not clazz:
        return BaseXCom
    if not issubclass(clazz, BaseXCom):
        raise TypeError(
            f"Your custom XCom class `{clazz.__name__}` is not a subclass of `{BaseXCom.__name__}`."
        )

    expected_params = set(_get_function_params(BaseXCom.serialize_value))
    backend_params = _get_function_params(clazz.serialize_value)
    if set(backend_params) != expected_params:
        _patch_outdated_serializer(clazz=clazz, params=backend_params)
    return clazz
# Public alias for the active XCom class. Under type checking it is the base
# class itself; at runtime it is whatever resolve_xcom_backend() returns for
# the configured [core] xcom_backend. The explicit `if TYPE_CHECKING` form is
# deliberate: mypy special-cases it, which a conditional expression would not.
if TYPE_CHECKING:
    XCom = BaseXCom  # Hack to avoid Mypy "Variable 'XCom' is not valid as a type".
else:
    XCom = resolve_xcom_backend()