Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/xcom.py: 45%
321 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import collections.abc
21import contextlib
22import datetime
23import inspect
24import itertools
25import json
26import logging
27import pickle
28import warnings
29from functools import cached_property, wraps
30from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload
32import attr
33import pendulum
34from sqlalchemy import (
35 Column,
36 ForeignKeyConstraint,
37 Index,
38 Integer,
39 LargeBinary,
40 PrimaryKeyConstraint,
41 String,
42 delete,
43 text,
44)
45from sqlalchemy.ext.associationproxy import association_proxy
46from sqlalchemy.orm import Query, Session, reconstructor, relationship
47from sqlalchemy.orm.exc import NoResultFound
49from airflow import settings
50from airflow.api_internal.internal_api_call import internal_api_call
51from airflow.configuration import conf
52from airflow.exceptions import RemovedInAirflow3Warning
53from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
54from airflow.utils import timezone
55from airflow.utils.helpers import exactly_one, is_container
56from airflow.utils.json import XComDecoder, XComEncoder
57from airflow.utils.log.logging_mixin import LoggingMixin
58from airflow.utils.session import NEW_SESSION, provide_session
59from airflow.utils.sqlalchemy import UtcDateTime
61# XCom constants below are needed for providers backward compatibility,
62# which should import the constants directly after apache-airflow>=2.6.0
63from airflow.utils.xcom import (
64 MAX_XCOM_SIZE, # noqa: F401
65 XCOM_RETURN_KEY,
66)
# Module-level logger, used below for the lazy-proxy coercion warning in
# ``BaseXCom.set`` and for JSON serialization errors in ``serialize_value``.
log = logging.getLogger(__name__)
70if TYPE_CHECKING:
71 from airflow.models.taskinstancekey import TaskInstanceKey
class BaseXCom(Base, LoggingMixin):
    """Base class for XCom objects.

    Rows are keyed by ``(dag_run_id, task_id, map_index, key)``; ``dag_id`` and
    ``run_id`` are denormalized copies kept for easier lookup. ``value`` holds
    the bytes produced by :meth:`serialize_value`.
    """

    __tablename__ = "xcom"

    dag_run_id = Column(Integer(), nullable=False, primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
    map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
    key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)

    # Denormalized for easier lookup.
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)

    value = Column(LargeBinary)
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __table_args__ = (
        # Ideally we should create a unique index over (key, dag_id, task_id, run_id),
        # but it goes over MySQL's index length limit. So we instead index 'key'
        # separately, and enforce uniqueness with DagRun.id instead.
        Index("idx_xcom_key", key),
        Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
        PrimaryKeyConstraint(
            "dag_run_id", "task_id", "map_index", "key", name="xcom_pkey", mssql_clustered=True
        ),
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="xcom_task_instance_fkey",
            ondelete="CASCADE",
        ),
    )

    dag_run = relationship(
        "DagRun",
        primaryjoin="BaseXCom.dag_run_id == foreign(DagRun.id)",
        uselist=False,
        lazy="joined",
        passive_deletes="all",
    )
    execution_date = association_proxy("dag_run", "execution_date")

    @reconstructor
    def init_on_load(self):
        """
        Called by the ORM after the instance has been loaded from the DB or otherwise reconstituted
        i.e automatically deserialize Xcom value when loading from DB.
        """
        self.value = self.orm_deserialize_value()

    def __repr__(self):
        if self.map_index < 0:
            return f'<XCom "{self.key}" ({self.task_id} @ {self.run_id})>'
        return f'<XCom "{self.key}" ({self.task_id}[{self.map_index}] @ {self.run_id})>'

    @overload
    @classmethod
    def set(
        cls,
        key: str,
        value: Any,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int = -1,
        session: Session = NEW_SESSION,
    ) -> None:
        """Store an XCom value.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        :param key: Key to store the XCom.
        :param value: XCom value to store.
        :param dag_id: DAG ID.
        :param task_id: Task ID.
        :param run_id: DAG run ID for the task.
        :param map_index: Optional map index to assign XCom for a mapped task.
            The default is ``-1`` (set for a non-mapped task).
        :param session: Database session. If not given, a new session will be
            created for this function.
        """

    @overload
    @classmethod
    def set(
        cls,
        key: str,
        value: Any,
        task_id: str,
        dag_id: str,
        execution_date: datetime.datetime,
        session: Session = NEW_SESSION,
    ) -> None:
        """Store an XCom value.

        :sphinx-autoapi-skip:
        """

    @classmethod
    @provide_session
    def set(
        cls,
        key: str,
        value: Any,
        task_id: str,
        dag_id: str,
        execution_date: datetime.datetime | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int = -1,
    ) -> None:
        """Store an XCom value.

        :sphinx-autoapi-skip:
        """
        from airflow.models.dagrun import DagRun

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )

        # Resolve both the DagRun primary key (needed for the row) and the
        # run_id (needed for the denormalized column) from whichever was given.
        if run_id is None:
            message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
            try:
                dag_run_id, run_id = (
                    session.query(DagRun.id, DagRun.run_id)
                    .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
                    .one()
                )
            except NoResultFound:
                raise ValueError(f"DAG run not found on DAG {dag_id!r} at {execution_date}") from None
        else:
            dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
            if dag_run_id is None:
                raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

        # Seamlessly resolve LazyXComAccess to a list. This is intended to work
        # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
        # it's pushed into XCom, the user should be aware of the performance
        # implications, and this avoids leaking the implementation detail.
        if isinstance(value, LazyXComAccess):
            warning_message = (
                "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
                "to list, which may degrade performance. Review resource "
                "requirements for this operation, and call list() to suppress "
                "this message. See Dynamic Task Mapping documentation for "
                "more information about lazy proxy objects."
            )
            log.warning(
                warning_message,
                "return value" if key == XCOM_RETURN_KEY else f"value {key}",
                task_id,
                dag_id,
                run_id or execution_date,
            )
            value = list(value)

        value = cls.serialize_value(
            value=value,
            key=key,
            task_id=task_id,
            dag_id=dag_id,
            run_id=run_id,
            map_index=map_index,
        )

        # Remove duplicate XComs and insert a new one.
        session.execute(
            delete(cls).where(
                cls.key == key,
                cls.run_id == run_id,
                cls.task_id == task_id,
                cls.dag_id == dag_id,
                cls.map_index == map_index,
            )
        )
        new = cast(Any, cls)(  # Work around Mypy complaining model not defining '__init__'.
            dag_run_id=dag_run_id,
            key=key,
            value=value,
            run_id=run_id,
            task_id=task_id,
            dag_id=dag_id,
            map_index=map_index,
        )
        session.add(new)
        session.flush()

    @staticmethod
    @provide_session
    @internal_api_call
    def get_value(
        *,
        ti_key: TaskInstanceKey,
        key: str | None = None,
        session: Session = NEW_SESSION,
    ) -> Any:
        """Retrieve an XCom value for a task instance.

        This method returns "full" XCom values (i.e. uses ``deserialize_value``
        from the XCom backend). Use :meth:`get_many` if you want the "shortened"
        value via ``orm_deserialize_value``.

        If there are no results, *None* is returned. If multiple XCom entries
        match the criteria, an arbitrary one is returned.

        :param ti_key: The TaskInstanceKey to look up the XCom for.
        :param key: A key for the XCom. If provided, only XCom with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """
        return BaseXCom.get_one(
            key=key,
            task_id=ti_key.task_id,
            dag_id=ti_key.dag_id,
            run_id=ti_key.run_id,
            map_index=ti_key.map_index,
            session=session,
        )

    @overload
    @staticmethod
    @internal_api_call
    def get_one(
        *,
        key: str | None = None,
        dag_id: str | None = None,
        task_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
    ) -> Any | None:
        """Retrieve an XCom value, optionally meeting certain criteria.

        This method returns "full" XCom values (i.e. uses ``deserialize_value``
        from the XCom backend). Use :meth:`get_many` if you want the "shortened"
        value via ``orm_deserialize_value``.

        If there are no results, *None* is returned. If multiple XCom entries
        match the criteria, an arbitrary one is returned.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        .. seealso:: ``get_value()`` is a convenience function if you already
            have a structured TaskInstance or TaskInstanceKey object available.

        :param run_id: DAG run ID for the task.
        :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to
            remove the filter.
        :param task_id: Only XCom from task with matching ID will be pulled.
            Pass *None* (default) to remove the filter.
        :param map_index: Only XCom from task with matching map index will be
            pulled. Pass *None* (default) to remove the filter.
        :param key: A key for the XCom. If provided, only XCom with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param include_prior_dates: If *False* (default), only XCom from the
            specified DAG run is returned. If *True*, the latest matching XCom is
            returned regardless of the run it belongs to.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """

    @overload
    @staticmethod
    @internal_api_call
    def get_one(
        execution_date: datetime.datetime,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
    ) -> Any | None:
        """Retrieve an XCom value, optionally meeting certain criteria.

        :sphinx-autoapi-skip:
        """

    @staticmethod
    @provide_session
    @internal_api_call
    def get_one(
        execution_date: datetime.datetime | None = None,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> Any | None:
        """Retrieve an XCom value, optionally meeting certain criteria.

        :sphinx-autoapi-skip:
        """
        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError("Exactly one of run_id or execution_date must be passed")

        # Test against None (not truthiness): exactly_one() above only checked
        # for None, so a falsy-but-present run_id must still take this branch.
        if run_id is not None:
            query = BaseXCom.get_many(
                run_id=run_id,
                key=key,
                task_ids=task_id,
                dag_ids=dag_id,
                map_indexes=map_index,
                include_prior_dates=include_prior_dates,
                limit=1,
                session=session,
            )
        else:
            message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)

            # get_many() would emit the same deprecation warning again; silence it.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RemovedInAirflow3Warning)
                query = BaseXCom.get_many(
                    execution_date=execution_date,
                    key=key,
                    task_ids=task_id,
                    dag_ids=dag_id,
                    map_indexes=map_index,
                    include_prior_dates=include_prior_dates,
                    limit=1,
                    session=session,
                )

        result = query.with_entities(BaseXCom.value).first()
        if result:
            # Deserialize through the configured backend (``XCom``) so that a
            # custom backend's ``deserialize_value`` is honored, as documented.
            return XCom.deserialize_value(result)
        return None

    @overload
    @staticmethod
    def get_many(
        *,
        run_id: str,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
    ) -> Query:
        """Composes a query to get one or more XCom entries.

        This function returns an SQLAlchemy query of full XCom objects. If you
        just want one stored value, use :meth:`get_one` instead.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        :param run_id: DAG run ID for the task.
        :param key: A key for the XComs. If provided, only XComs with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param task_ids: Only XComs from task with matching IDs will be pulled.
            Pass *None* (default) to remove the filter.
        :param dag_ids: Only pulls XComs from specified DAGs. Pass *None*
            (default) to remove the filter.
        :param map_indexes: Only XComs from matching map indexes will be pulled.
            Pass *None* (default) to remove the filter.
        :param include_prior_dates: If *False* (default), only XComs from the
            specified DAG run are returned. If *True*, all matching XComs are
            returned regardless of the run it belongs to.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :param limit: Limiting returning XComs
        """

    @overload
    @staticmethod
    @internal_api_call
    def get_many(
        execution_date: datetime.datetime,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
    ) -> Query:
        """Composes a query to get one or more XCom entries.

        :sphinx-autoapi-skip:
        """

    @staticmethod
    @provide_session
    @internal_api_call
    def get_many(
        execution_date: datetime.datetime | None = None,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
    ) -> Query:
        """Composes a query to get one or more XCom entries.

        :sphinx-autoapi-skip:
        """
        from airflow.models.dagrun import DagRun

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )
        if execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)

        query = session.query(BaseXCom).join(BaseXCom.dag_run)

        if key:
            query = query.filter(BaseXCom.key == key)

        if is_container(task_ids):
            query = query.filter(BaseXCom.task_id.in_(task_ids))
        elif task_ids is not None:
            query = query.filter(BaseXCom.task_id == task_ids)

        if is_container(dag_ids):
            query = query.filter(BaseXCom.dag_id.in_(dag_ids))
        elif dag_ids is not None:
            query = query.filter(BaseXCom.dag_id == dag_ids)

        # A contiguous range becomes a cheap BETWEEN-style filter instead of IN.
        if isinstance(map_indexes, range) and map_indexes.step == 1:
            query = query.filter(
                BaseXCom.map_index >= map_indexes.start, BaseXCom.map_index < map_indexes.stop
            )
        elif is_container(map_indexes):
            query = query.filter(BaseXCom.map_index.in_(map_indexes))
        elif map_indexes is not None:
            query = query.filter(BaseXCom.map_index == map_indexes)

        if include_prior_dates:
            if execution_date is not None:
                query = query.filter(DagRun.execution_date <= execution_date)
            else:
                # NOTE(review): this subquery matches DagRuns by run_id only (no
                # dag_id filter); confirm run_ids cannot collide across DAGs here.
                dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery()
                query = query.filter(BaseXCom.execution_date <= dr.c.execution_date)
        elif execution_date is not None:
            query = query.filter(DagRun.execution_date == execution_date)
        else:
            query = query.filter(BaseXCom.run_id == run_id)

        # Newest run first, then newest XCom within a run.
        query = query.order_by(DagRun.execution_date.desc(), BaseXCom.timestamp.desc())
        if limit:
            return query.limit(limit)
        return query

    @classmethod
    @provide_session
    def delete(cls, xcoms: XCom | Iterable[XCom], session: Session = NEW_SESSION) -> None:
        """Delete one or multiple XCom entries and commit the session.

        :param xcoms: A single XCom object or an iterable of XCom objects.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :raises TypeError: if any item is not an XCom instance.
        """
        if isinstance(xcoms, XCom):
            xcoms = [xcoms]
        for xcom in xcoms:
            if not isinstance(xcom, XCom):
                raise TypeError(f"Expected XCom; received {xcom.__class__.__name__}")
            session.delete(xcom)
        session.commit()

    @overload
    @staticmethod
    @internal_api_call
    def clear(
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int | None = None,
        session: Session = NEW_SESSION,
    ) -> None:
        """Clear all XCom data from the database for the given task instance.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        :param dag_id: ID of DAG to clear the XCom for.
        :param task_id: ID of task to clear the XCom for.
        :param run_id: ID of DAG run to clear the XCom for.
        :param map_index: If given, only clear XCom from this particular mapped
            task. The default ``None`` clears *all* XComs from the task.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """

    @overload
    @staticmethod
    @internal_api_call
    def clear(
        execution_date: pendulum.DateTime,
        dag_id: str,
        task_id: str,
        session: Session = NEW_SESSION,
    ) -> None:
        """Clear all XCom data from the database for the given task instance.

        :sphinx-autoapi-skip:
        """

    @staticmethod
    @provide_session
    @internal_api_call
    def clear(
        execution_date: pendulum.DateTime | None = None,
        dag_id: str | None = None,
        task_id: str | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> None:
        """Clear all XCom data from the database for the given task instance.

        :sphinx-autoapi-skip:
        """
        from airflow.models import DagRun

        # Given the historic order of this function (execution_date was first argument) to add a new optional
        # param we need to add default values for everything :(
        if dag_id is None:
            raise TypeError("clear() missing required argument: dag_id")
        if task_id is None:
            raise TypeError("clear() missing required argument: task_id")

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )

        if execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
            # If no DagRun matches, run_id stays None; since the run_id column is
            # NOT NULL, the delete below presumably matches nothing.
            run_id = (
                session.query(DagRun.run_id)
                .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
                .scalar()
            )

        query = session.query(BaseXCom).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
        if map_index is not None:
            query = query.filter_by(map_index=map_index)
        query.delete()

    @staticmethod
    def serialize_value(
        value: Any,
        *,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> Any:
        """Serialize XCom value to str or pickled object.

        Uses pickle when ``[core] enable_xcom_pickling`` is set, otherwise
        UTF-8 encoded JSON via ``XComEncoder``. The keyword-only context
        arguments are accepted for custom backends and unused here.
        """
        if conf.getboolean("core", "enable_xcom_pickling"):
            return pickle.dumps(value)
        try:
            return json.dumps(value, cls=XComEncoder).encode("UTF-8")
        except (ValueError, TypeError) as ex:
            log.error(
                "%s."
                " If you are using pickle instead of JSON for XCom,"
                " then you need to enable pickle support for XCom"
                " in your airflow config or make sure to decorate your"
                " object with attr.",
                ex,
            )
            raise

    @staticmethod
    def _deserialize_value(result: XCom, orm: bool) -> Any:
        """Decode ``result.value``, trying the configured format first.

        :param result: Any object exposing a ``value`` attribute with the stored bytes.
        :param orm: If True, use ``XComDecoder.orm_object_hook`` for JSON objects.
        """
        object_hook = None
        if orm:
            object_hook = XComDecoder.orm_object_hook

        if result.value is None:
            return None
        if conf.getboolean("core", "enable_xcom_pickling"):
            try:
                return pickle.loads(result.value)
            except pickle.UnpicklingError:
                return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
        else:
            try:
                return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
            except (json.JSONDecodeError, UnicodeDecodeError):
                # NOTE(review): falls back to unpickling even when pickling is
                # disabled; this is only safe if stored bytes are trusted.
                return pickle.loads(result.value)

    @staticmethod
    def deserialize_value(result: XCom) -> Any:
        """Deserialize XCom value from str or pickle object."""
        return BaseXCom._deserialize_value(result, False)

    def orm_deserialize_value(self) -> Any:
        """
        Deserialize method which is used to reconstruct ORM XCom object.

        This method should be overridden in custom XCom backends to avoid
        unnecessary request or other resource consuming operations when
        creating XCom orm model. This is used when viewing XCom listing
        in the webserver, for example.
        """
        return BaseXCom._deserialize_value(self, True)
class _LazyXComAccessIterator(collections.abc.Iterator):
    """Iterator that enters a query context manager on first use and deserializes rows.

    The wrapped context manager is entered lazily (on first access of ``_it``)
    and exited when this iterator is garbage-collected.
    """

    def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
        self._cm = cm
        self._entered = False

    def __del__(self) -> None:
        # Only exit the context manager if its __enter__ actually succeeded.
        if self._entered:
            self._cm.__exit__(None, None, None)

    def __iter__(self) -> collections.abc.Iterator:
        return self

    def __next__(self) -> Any:
        return XCom.deserialize_value(next(self._it))

    @cached_property
    def _it(self) -> collections.abc.Iterator:
        query = self._cm.__enter__()
        # Mark as entered only after __enter__ returns, so a failed __enter__
        # does not lead __del__ to call __exit__ on a never-entered manager.
        self._entered = True
        return iter(query)
@attr.define(slots=True)
class LazyXComAccess(collections.abc.Sequence):
    """Wrapper to lazily pull XCom with a sequence-like interface.

    Note that since the session bound to the parent query may have died when we
    actually access the sequence's content, we must create a new session
    for every function call with ``with_session()``.

    :meta private:
    """

    _query: Query
    _len: int | None = attr.ib(init=False, default=None)

    @classmethod
    def build_from_xcom_query(cls, query: Query) -> LazyXComAccess:
        """Build a lazy sequence that selects only the ``value`` column of *query*."""
        return cls(query=query.with_entities(XCom.value))

    def __repr__(self) -> str:
        return f"LazyXComAccess([{len(self)} items])"

    def __str__(self) -> str:
        return str(list(self))

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, (list, LazyXComAccess)):
            # A fresh sentinel pads the shorter side so length mismatches compare unequal.
            z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
            return all(x == y for x, y in z)
        return NotImplemented

    def __getstate__(self) -> Any:
        # We don't want to go to the trouble of serializing the entire Query
        # object, including its filters, hints, etc. (plus SQLAlchemy does not
        # provide a public API to inspect a query's contents). Converting the
        # query into a SQL string is the best we can get. Theoretically we can
        # do the same for count(), but I think it should be performant enough to
        # calculate only that eagerly.
        with self._get_bound_query() as query:
            statement = query.statement.compile(
                query.session.get_bind(),
                # This inlines all the values into the SQL string to simplify
                # cross-process communication as much as possible.
                compile_kwargs={"literal_binds": True},
            )
            return (str(statement), query.count())

    def __setstate__(self, state: Any) -> None:
        statement, self._len = state
        self._query = Query(XCom.value).from_statement(text(statement))

    def __len__(self):
        # The count is cached after the first query.
        if self._len is None:
            with self._get_bound_query() as query:
                self._len = query.count()
        return self._len

    def __iter__(self):
        return _LazyXComAccessIterator(self._get_bound_query())

    def __getitem__(self, key):
        """Return the deserialized value at integer index *key*.

        Negative indexes count from the end, as for built-in sequences.

        :raises ValueError: if *key* is not an int (slices are not supported).
        :raises IndexError: if *key* is out of range.
        """
        if not isinstance(key, int):
            raise ValueError("only support index access for now")
        index = key
        if index < 0:
            # SQL OFFSET cannot take negative values; translate to a forward index.
            index += len(self)
            if index < 0:
                raise IndexError(key)
        try:
            with self._get_bound_query() as query:
                r = query.offset(index).limit(1).one()
        except NoResultFound:
            raise IndexError(key) from None
        return XCom.deserialize_value(r)

    @contextlib.contextmanager
    def _get_bound_query(self) -> Generator[Query, None, None]:
        """Yield the query bound to a live session, opening a temporary one if needed."""
        # Do we have a valid session already?
        if self._query.session and self._query.session.is_active:
            yield self._query
            return

        session_factory = getattr(settings, "Session", None)
        if session_factory is None:
            raise RuntimeError("Session must be set before!")
        session = session_factory()
        try:
            yield self._query.with_session(session)
        finally:
            session.close()
def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
    """Patch a custom ``serialize_value`` to accept the modern signature.

    To give custom XCom backends more flexibility with how they store values, we
    now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``.
    In order to maintain compatibility with custom XCom backends written with
    the old signature, we check the signature and, if necessary, patch with a
    method that ignores kwargs the backend does not accept.

    :param clazz: The custom XCom backend class to patch in place.
    :param params: Names of the parameters the backend's ``serialize_value`` accepts.
    """
    old_serializer = clazz.serialize_value
    # Materialize once so a one-shot iterable is not exhausted after the first call.
    accepted_params = list(params)

    @wraps(old_serializer)
    def _shim(**kwargs):
        # Forward only the kwargs the outdated serializer knows about.
        kwargs = {k: kwargs.get(k) for k in accepted_params}
        warnings.warn(
            f"Method `serialize_value` in XCom backend {clazz.__name__} is using outdated signature and "
            f"must be updated to accept all params in `BaseXCom.set` except `session`. Support will be "
            f"removed in a future release.",
            RemovedInAirflow3Warning,
        )
        return old_serializer(**kwargs)

    # Wrap in staticmethod so the shim also works when accessed via an instance.
    clazz.serialize_value = staticmethod(_shim)  # type: ignore[assignment]
def _get_function_params(function) -> list[str]:
    """
    Return the names of a callable's named parameters, in declaration order.

    ``*args``-style and ``**kwargs``-style parameters are left out.

    :param function: The callable to inspect.
    """
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    names: list[str] = []
    for param_name, param in inspect.signature(function).parameters.items():
        if param.kind in variadic_kinds:
            continue
        names.append(param_name)
    return names
def resolve_xcom_backend() -> type[BaseXCom]:
    """Resolve the XCom class configured by ``[core] xcom_backend``.

    Falls back to :class:`BaseXCom` when nothing is configured. If the custom
    class's ``serialize_value`` has a different set of parameters than
    ``BaseXCom.serialize_value``, it is patched to accept the modern signature.

    :raises TypeError: if the configured object is not a subclass of BaseXCom.
    """
    clazz = conf.getimport("core", "xcom_backend", fallback=f"airflow.models.xcom.{BaseXCom.__name__}")
    if not clazz:
        return BaseXCom
    # Guard against non-class objects, which would otherwise make issubclass()
    # raise a less helpful TypeError of its own.
    if not isinstance(clazz, type) or not issubclass(clazz, BaseXCom):
        raise TypeError(
            f"Your custom XCom class `{getattr(clazz, '__name__', clazz)}` is not a subclass of `{BaseXCom.__name__}`."
        )
    base_xcom_params = _get_function_params(BaseXCom.serialize_value)
    xcom_params = _get_function_params(clazz.serialize_value)
    if set(base_xcom_params) != set(xcom_params):
        _patch_outdated_serializer(clazz=clazz, params=xcom_params)
    return clazz
# The XCom class used throughout Airflow. For static type checking it is
# aliased to BaseXCom; at runtime it is whatever resolve_xcom_backend() returns
# for the ``[core] xcom_backend`` setting (BaseXCom if unset).
if TYPE_CHECKING:
    XCom = BaseXCom  # Hack to avoid Mypy "Variable 'XCom' is not valid as a type".
else:
    XCom = resolve_xcom_backend()