Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/xcom.py: 44%

310 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections.abc 

21import contextlib 

22import datetime 

23import inspect 

24import itertools 

25import json 

26import logging 

27import pickle 

28import warnings 

29from functools import wraps 

30from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload 

31 

32import attr 

33import pendulum 

34from sqlalchemy import ( 

35 Column, 

36 ForeignKeyConstraint, 

37 Index, 

38 Integer, 

39 LargeBinary, 

40 PrimaryKeyConstraint, 

41 String, 

42 text, 

43) 

44from sqlalchemy.ext.associationproxy import association_proxy 

45from sqlalchemy.orm import Query, Session, reconstructor, relationship 

46from sqlalchemy.orm.exc import NoResultFound 

47 

48from airflow import settings 

49from airflow.compat.functools import cached_property 

50from airflow.configuration import conf 

51from airflow.exceptions import RemovedInAirflow3Warning 

52from airflow.models.base import COLLATION_ARGS, ID_LEN, Base 

53from airflow.utils import timezone 

54from airflow.utils.helpers import exactly_one, is_container 

55from airflow.utils.json import XComDecoder, XComEncoder 

56from airflow.utils.log.logging_mixin import LoggingMixin 

57from airflow.utils.session import NEW_SESSION, provide_session 

58from airflow.utils.sqlalchemy import UtcDateTime 

59 

# Module-level logger, used by ``BaseXCom.set`` and ``BaseXCom.serialize_value``.
log = logging.getLogger(__name__)

# MAX XCOM Size is 48KB
# https://github.com/apache/airflow/pull/1618#discussion_r68249677
# NOTE(review): 49344 bytes is slightly more than 48 KiB (49152). The constant is not
# referenced in the visible part of this module -- confirm where it is enforced.
MAX_XCOM_SIZE = 49344
# XCom key that ``BaseXCom.set`` labels as the task's "return value" in log messages.
XCOM_RETURN_KEY = "return_value"

66 

67if TYPE_CHECKING: 

68 from airflow.models.taskinstance import TaskInstanceKey 

69 

70 

class BaseXCom(Base, LoggingMixin):
    """Base class for XCom objects.

    Maps the ``xcom`` table. A row is identified by
    ``(dag_run_id, task_id, map_index, key)``; ``dag_id`` and ``run_id`` are
    stored redundantly so lookups do not need to join against the DAG run.
    """

    __tablename__ = "xcom"

    dag_run_id = Column(Integer(), nullable=False, primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
    map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
    key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)

    # Denormalized for easier lookup.
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)

    value = Column(LargeBinary)
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __table_args__ = (
        # Ideally we should create a unique index over (key, dag_id, task_id, run_id),
        # but it goes over MySQL's index length limit. So we instead index 'key'
        # separately, and enforce uniqueness with DagRun.id instead.
        Index("idx_xcom_key", key),
        Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
        PrimaryKeyConstraint(
            "dag_run_id", "task_id", "map_index", "key", name="xcom_pkey", mssql_clustered=True
        ),
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="xcom_task_instance_fkey",
            ondelete="CASCADE",
        ),
    )

    dag_run = relationship(
        "DagRun",
        primaryjoin="BaseXCom.dag_run_id == foreign(DagRun.id)",
        uselist=False,
        lazy="joined",
        passive_deletes="all",
    )
    # Exposes the owning DAG run's ``execution_date`` directly on the XCom row.
    execution_date = association_proxy("dag_run", "execution_date")

    @reconstructor
    def init_on_load(self):
        """
        Called by the ORM after the instance has been loaded from the DB or otherwise reconstituted
        i.e automatically deserialize Xcom value when loading from DB.
        """
        self.value = self.orm_deserialize_value()

    def __repr__(self):
        # Non-mapped task instances carry a negative map index and omit it from the repr.
        if self.map_index < 0:
            return f'<XCom "{self.key}" ({self.task_id} @ {self.run_id})>'
        return f'<XCom "{self.key}" ({self.task_id}[{self.map_index}] @ {self.run_id})>'

    @overload
    @classmethod
    def set(
        cls,
        key: str,
        value: Any,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int = -1,
        session: Session = NEW_SESSION,
    ) -> None:
        """Store an XCom value.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        :param key: Key to store the XCom.
        :param value: XCom value to store.
        :param dag_id: DAG ID.
        :param task_id: Task ID.
        :param run_id: DAG run ID for the task.
        :param map_index: Optional map index to assign XCom for a mapped task.
            The default is ``-1`` (set for a non-mapped task).
        :param session: Database session. If not given, a new session will be
            created for this function.
        """

    @overload
    @classmethod
    def set(
        cls,
        key: str,
        value: Any,
        task_id: str,
        dag_id: str,
        execution_date: datetime.datetime,
        session: Session = NEW_SESSION,
    ) -> None:
        """:sphinx-autoapi-skip:"""

    @classmethod
    @provide_session
    def set(
        cls,
        key: str,
        value: Any,
        task_id: str,
        dag_id: str,
        execution_date: datetime.datetime | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int = -1,
    ) -> None:
        """:sphinx-autoapi-skip:"""
        from airflow.models.dagrun import DagRun

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )

        # Resolve the DAG run's primary key (and run_id, for the deprecated
        # execution_date form); both are needed to build the new row.
        if run_id is None:
            message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
            try:
                dag_run_id, run_id = (
                    session.query(DagRun.id, DagRun.run_id)
                    .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
                    .one()
                )
            except NoResultFound:
                raise ValueError(f"DAG run not found on DAG {dag_id!r} at {execution_date}") from None
        else:
            dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
            if dag_run_id is None:
                raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

        # Seamlessly resolve LazyXComAccess to a list. This is intended to work
        # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
        # it's pushed into XCom, the user should be aware of the performance
        # implications, and this avoids leaking the implementation detail.
        if isinstance(value, LazyXComAccess):
            warning_message = (
                "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
                "to list, which may degrade performance. Review resource "
                "requirements for this operation, and call list() to suppress "
                "this message. See Dynamic Task Mapping documentation for "
                "more information about lazy proxy objects."
            )
            log.warning(
                warning_message,
                "return value" if key == XCOM_RETURN_KEY else f"value {key}",
                task_id,
                dag_id,
                run_id or execution_date,
            )
            value = list(value)

        value = cls.serialize_value(
            value=value,
            key=key,
            task_id=task_id,
            dag_id=dag_id,
            run_id=run_id,
            map_index=map_index,
        )

        # Remove duplicate XComs and insert a new one.
        session.query(cls).filter(
            cls.key == key,
            cls.run_id == run_id,
            cls.task_id == task_id,
            cls.dag_id == dag_id,
            cls.map_index == map_index,
        ).delete()
        new = cast(Any, cls)(  # Work around Mypy complaining model not defining '__init__'.
            dag_run_id=dag_run_id,
            key=key,
            value=value,
            run_id=run_id,
            task_id=task_id,
            dag_id=dag_id,
            map_index=map_index,
        )
        session.add(new)
        session.flush()

    @classmethod
    @provide_session
    def get_value(
        cls,
        *,
        ti_key: TaskInstanceKey,
        key: str | None = None,
        session: Session = NEW_SESSION,
    ) -> Any:
        """Retrieve an XCom value for a task instance.

        This method returns "full" XCom values (i.e. uses ``deserialize_value``
        from the XCom backend). Use :meth:`get_many` if you want the "shortened"
        value via ``orm_deserialize_value``.

        If there are no results, *None* is returned. If multiple XCom entries
        match the criteria, an arbitrary one is returned.

        :param ti_key: The TaskInstanceKey to look up the XCom for.
        :param key: A key for the XCom. If provided, only XCom with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """
        return cls.get_one(
            key=key,
            task_id=ti_key.task_id,
            dag_id=ti_key.dag_id,
            run_id=ti_key.run_id,
            map_index=ti_key.map_index,
            session=session,
        )

    @overload
    @classmethod
    def get_one(
        cls,
        *,
        key: str | None = None,
        dag_id: str | None = None,
        task_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
    ) -> Any | None:
        """Retrieve an XCom value, optionally meeting certain criteria.

        This method returns "full" XCom values (i.e. uses ``deserialize_value``
        from the XCom backend). Use :meth:`get_many` if you want the "shortened"
        value via ``orm_deserialize_value``.

        If there are no results, *None* is returned. If multiple XCom entries
        match the criteria, an arbitrary one is returned.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        .. seealso:: ``get_value()`` is a convenience function if you already
            have a structured TaskInstance or TaskInstanceKey object available.

        :param run_id: DAG run ID for the task.
        :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to
            remove the filter.
        :param task_id: Only XCom from task with matching ID will be pulled.
            Pass *None* (default) to remove the filter.
        :param map_index: Only XCom from task with matching map index will be
            pulled. Pass *None* (default) to remove the filter.
        :param key: A key for the XCom. If provided, only XCom with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param include_prior_dates: If *False* (default), only XCom from the
            specified DAG run is returned. If *True*, the latest matching XCom is
            returned regardless of the run it belongs to.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :raises ValueError: If not exactly one of ``run_id`` and
            ``execution_date`` is given.
        """

    @overload
    @classmethod
    def get_one(
        cls,
        execution_date: datetime.datetime,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
    ) -> Any | None:
        """:sphinx-autoapi-skip:"""

    @classmethod
    @provide_session
    def get_one(
        cls,
        execution_date: datetime.datetime | None = None,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> Any | None:
        """:sphinx-autoapi-skip:"""
        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )

        # Test identity, not truthiness: the validation above is based on
        # ``is not None``, so an empty-string run_id must take this branch too.
        if run_id is not None:
            query = cls.get_many(
                run_id=run_id,
                key=key,
                task_ids=task_id,
                dag_ids=dag_id,
                map_indexes=map_index,
                include_prior_dates=include_prior_dates,
                limit=1,
                session=session,
            )
        else:
            message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)

            # get_many() would emit the same deprecation again; silence the duplicate.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RemovedInAirflow3Warning)
                query = cls.get_many(
                    execution_date=execution_date,
                    key=key,
                    task_ids=task_id,
                    dag_ids=dag_id,
                    map_indexes=map_index,
                    include_prior_dates=include_prior_dates,
                    limit=1,
                    session=session,
                )

        # Only the value column is fetched; the row exposes it as ``.value``,
        # which is all ``deserialize_value`` reads.
        result = query.with_entities(cls.value).first()
        if result:
            return cls.deserialize_value(result)
        return None

    @overload
    @classmethod
    def get_many(
        cls,
        *,
        run_id: str,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
    ) -> Query:
        """Composes a query to get one or more XCom entries.

        This function returns an SQLAlchemy query of full XCom objects. If you
        just want one stored value, use :meth:`get_one` instead.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        :param run_id: DAG run ID for the task.
        :param key: A key for the XComs. If provided, only XComs with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param task_ids: Only XComs from task with matching IDs will be pulled.
            Pass *None* (default) to remove the filter.
        :param dag_ids: Only pulls XComs from specified DAGs. Pass *None*
            (default) to remove the filter.
        :param map_indexes: Only XComs from matching map indexes will be pulled.
            Pass *None* (default) to remove the filter.
        :param include_prior_dates: If *False* (default), only XComs from the
            specified DAG run are returned. If *True*, all matching XComs are
            returned regardless of the run it belongs to.
        :param limit: If given (and non-zero), the query is limited to this
            many rows.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :raises ValueError: If not exactly one of ``run_id`` and
            ``execution_date`` is given.
        """

    @overload
    @classmethod
    def get_many(
        cls,
        execution_date: datetime.datetime,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
    ) -> Query:
        """:sphinx-autoapi-skip:"""

    @classmethod
    @provide_session
    def get_many(
        cls,
        execution_date: datetime.datetime | None = None,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
    ) -> Query:
        """:sphinx-autoapi-skip:"""
        from airflow.models.dagrun import DagRun

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )
        if execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)

        query = session.query(cls).join(cls.dag_run)

        if key:
            query = query.filter(cls.key == key)

        # Each of task_ids / dag_ids / map_indexes may be a single value or a container.
        if is_container(task_ids):
            query = query.filter(cls.task_id.in_(task_ids))
        elif task_ids is not None:
            query = query.filter(cls.task_id == task_ids)

        if is_container(dag_ids):
            query = query.filter(cls.dag_id.in_(dag_ids))
        elif dag_ids is not None:
            query = query.filter(cls.dag_id == dag_ids)

        # A contiguous range becomes a bounds check, not an IN list.
        if isinstance(map_indexes, range) and map_indexes.step == 1:
            query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
        elif is_container(map_indexes):
            query = query.filter(cls.map_index.in_(map_indexes))
        elif map_indexes is not None:
            query = query.filter(cls.map_index == map_indexes)

        if include_prior_dates:
            if execution_date is not None:
                query = query.filter(DagRun.execution_date <= execution_date)
            else:
                # NOTE(review): the subquery filters on run_id only, not dag_id --
                # confirm this is intended when run IDs repeat across DAGs.
                dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery()
                query = query.filter(cls.execution_date <= dr.c.execution_date)
        elif execution_date is not None:
            query = query.filter(DagRun.execution_date == execution_date)
        else:
            query = query.filter(cls.run_id == run_id)

        # Newest DAG run first, then newest XCom within a run.
        query = query.order_by(DagRun.execution_date.desc(), cls.timestamp.desc())
        if limit:
            return query.limit(limit)
        return query

    @classmethod
    @provide_session
    def delete(cls, xcoms: XCom | Iterable[XCom], session: Session = NEW_SESSION) -> None:
        """Delete one or multiple XCom entries.

        :param xcoms: A single XCom object or an iterable of them.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :raises TypeError: If any item is not an ``XCom`` instance.
        """
        if isinstance(xcoms, XCom):
            xcoms = [xcoms]
        for xcom in xcoms:
            if not isinstance(xcom, XCom):
                raise TypeError(f"Expected XCom; received {xcom.__class__.__name__}")
            session.delete(xcom)
        session.commit()

    @overload
    @classmethod
    def clear(
        cls,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int | None = None,
        session: Session = NEW_SESSION,
    ) -> None:
        """Clear all XCom data from the database for the given task instance.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        :param dag_id: ID of DAG to clear the XCom for.
        :param task_id: ID of task to clear the XCom for.
        :param run_id: ID of DAG run to clear the XCom for.
        :param map_index: If given, only clear XCom from this particular mapped
            task. The default ``None`` clears *all* XComs from the task.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """

    @overload
    @classmethod
    def clear(
        cls,
        execution_date: pendulum.DateTime,
        dag_id: str,
        task_id: str,
        session: Session = NEW_SESSION,
    ) -> None:
        """:sphinx-autoapi-skip:"""

    @classmethod
    @provide_session
    def clear(
        cls,
        execution_date: pendulum.DateTime | None = None,
        dag_id: str | None = None,
        task_id: str | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> None:
        """:sphinx-autoapi-skip:"""
        from airflow.models import DagRun

        # Given the historic order of this function (execution_date was first argument) to add a new optional
        # param we need to add default values for everything :(
        if dag_id is None:
            raise TypeError("clear() missing required argument: dag_id")
        if task_id is None:
            raise TypeError("clear() missing required argument: task_id")

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )

        if execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
            # If no DAG run matches, ``scalar()`` yields None and the filter
            # below is applied with ``run_id=None``.
            run_id = (
                session.query(DagRun.run_id)
                .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
                .scalar()
            )

        query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
        if map_index is not None:
            query = query.filter_by(map_index=map_index)
        query.delete()

    @staticmethod
    def serialize_value(
        value: Any,
        *,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> Any:
        """Serialize XCom value to pickled bytes or UTF-8 encoded JSON bytes.

        Pickle is used when ``[core] enable_xcom_pickling`` is true; otherwise
        the value is JSON-encoded with ``XComEncoder``. The keyword-only
        parameters are accepted but not used by this base implementation.

        :param value: The value to serialize.
        :raises ValueError: If JSON encoding fails (logged, then re-raised).
        :raises TypeError: If JSON encoding fails (logged, then re-raised).
        """
        if conf.getboolean("core", "enable_xcom_pickling"):
            return pickle.dumps(value)
        try:
            return json.dumps(value, cls=XComEncoder).encode("UTF-8")
        except (ValueError, TypeError) as ex:
            log.error(
                "%s."
                " If you are using pickle instead of JSON for XCom,"
                " then you need to enable pickle support for XCom"
                " in your airflow config or make sure to decorate your"
                " object with attr.",
                ex,
            )
            raise

    @staticmethod
    def _deserialize_value(result: XCom, orm: bool) -> Any:
        """Decode ``result.value``, preferring the configured format.

        :param result: Object exposing the stored bytes as ``.value``.
        :param orm: If true, JSON is decoded with ``XComDecoder.orm_object_hook``.
        :return: The decoded value, or *None* if nothing is stored.
        """
        object_hook = None
        if orm:
            object_hook = XComDecoder.orm_object_hook

        if result.value is None:
            return None
        # SECURITY NOTE(review): ``pickle.loads`` executes arbitrary code from the
        # stored bytes. Both branches can reach it -- including the fallback taken
        # when pickling is *disabled* -- so anyone able to write XCom rows can
        # trigger unpickling. Left unchanged here to keep previously pickled
        # values readable; consider dropping the fallback.
        if conf.getboolean("core", "enable_xcom_pickling"):
            try:
                return pickle.loads(result.value)
            except pickle.UnpicklingError:
                return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
        else:
            try:
                return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
            except (json.JSONDecodeError, UnicodeDecodeError):
                return pickle.loads(result.value)

    @staticmethod
    def deserialize_value(result: XCom) -> Any:
        """Deserialize XCom value from str or pickle object"""
        return BaseXCom._deserialize_value(result, False)

    def orm_deserialize_value(self) -> Any:
        """
        Deserialize method which is used to reconstruct ORM XCom object.

        This method should be overridden in custom XCom backends to avoid
        unnecessary request or other resource consuming operations when
        creating XCom orm model. This is used when viewing XCom listing
        in the webserver, for example.
        """
        return BaseXCom._deserialize_value(self, True)

673 

674 

class _LazyXComAccessIterator(collections.abc.Iterator):
    """Iterator yielding deserialized XCom values from a query context manager.

    The context manager is entered on first use and exited when the iterator
    is finalized.
    """

    def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
        # Flag first, so ``__del__`` can always read it.
        self._entered = False
        self._cm = cm

    def __del__(self) -> None:
        if not self._entered:
            return
        self._cm.__exit__(None, None, None)

    def __iter__(self) -> collections.abc.Iterator:
        return self

    def __next__(self) -> Any:
        row = next(self._it)
        return XCom.deserialize_value(row)

    @cached_property
    def _it(self) -> collections.abc.Iterator:
        # Entered at most once: the resulting row iterator is cached on the instance.
        self._entered = True
        bound_query = self._cm.__enter__()
        return iter(bound_query)

694 

695 

@attr.define(slots=True)
class LazyXComAccess(collections.abc.Sequence):
    """Wrapper to lazily pull XCom with a sequence-like interface.

    Note that since the session bound to the parent query may have died when we
    actually access the sequence's content, each access goes through
    ``_get_bound_query()``, which reuses the query's session while it is still
    active and otherwise binds the query to a fresh session with
    ``with_session()``.

    :meta private:
    """

    _query: Query
    # Cached row count; filled lazily by ``__len__`` or restored by ``__setstate__``.
    _len: int | None = attr.ib(init=False, default=None)

    @classmethod
    def build_from_xcom_query(cls, query: Query) -> LazyXComAccess:
        """Build a lazy sequence from an XCom query, selecting only the value column."""
        return cls(query=query.with_entities(XCom.value))

    def __repr__(self) -> str:
        return f"LazyXComAccess([{len(self)} items])"

    def __str__(self) -> str:
        return str(list(self))

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, (list, LazyXComAccess)):
            # A fresh ``object()`` fill value never equals a real item, so
            # sequences of different lengths compare unequal.
            z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
            return all(x == y for x, y in z)
        return NotImplemented

    def __getstate__(self) -> Any:
        # We don't want to go to the trouble of serializing the entire Query
        # object, including its filters, hints, etc. (plus SQLAlchemy does not
        # provide a public API to inspect a query's contents). Converting the
        # query into a SQL string is the best we can get. Theoretically we can
        # do the same for count(), but I think it should be performant enough to
        # calculate only that eagerly.
        with self._get_bound_query() as query:
            statement = query.statement.compile(query.session.get_bind())
            return (str(statement), query.count())

    def __setstate__(self, state: Any) -> None:
        statement, self._len = state
        self._query = Query(XCom.value).from_statement(text(statement))

    def __len__(self):
        if self._len is None:
            with self._get_bound_query() as query:
                self._len = query.count()
        return self._len

    def __getitem__(self, key):
        """Return the deserialized value at integer position ``key``.

        Negative positions count from the end, as for built-in sequences.

        :raises ValueError: If ``key`` is not an integer (slices are unsupported).
        :raises IndexError: If ``key`` is out of range.
        """
        if not isinstance(key, int):
            raise ValueError("only support index access for now")
        if key < 0:
            # SQL OFFSET cannot count from the end; translate via the length.
            offset = key + len(self)
            if offset < 0:
                raise IndexError(key)
        else:
            offset = key
        try:
            with self._get_bound_query() as query:
                r = query.offset(offset).limit(1).one()
        except NoResultFound:
            raise IndexError(key) from None
        return XCom.deserialize_value(r)

    @contextlib.contextmanager
    def _get_bound_query(self) -> Generator[Query, None, None]:
        """Yield the query bound to a usable session, closing any session created here."""
        # Do we have a valid session already?
        if self._query.session and self._query.session.is_active:
            yield self._query
            return

        session = settings.Session()
        try:
            yield self._query.with_session(session)
        finally:
            session.close()

772 

773 

def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
    """Patch a custom ``serialize_value`` to accept the modern signature.

    To give custom XCom backends more flexibility with how they store values, we
    now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``.
    In order to maintain compatibility with custom XCom backends written with
    the old signature, we check the signature and, if necessary, patch with a
    method that ignores kwargs the backend does not accept.

    :param clazz: The XCom backend class whose ``serialize_value`` is replaced
        in place.
    :param params: Names of the parameters the backend's ``serialize_value``
        accepts; only these are forwarded.
    """
    old_serializer = clazz.serialize_value
    # ``params`` is only promised to be iterable; snapshot it so the shim sees
    # the same names on every call.
    accepted_params = tuple(params)

    @wraps(old_serializer)
    def _shim(**kwargs):
        # Forward only what the old signature accepts; missing names become None.
        kwargs = {k: kwargs.get(k) for k in accepted_params}
        warnings.warn(
            f"Method `serialize_value` in XCom backend {clazz.__name__} is using outdated signature and "
            f"must be updated to accept all params in `BaseXCom.set` except `session`. Support will be "
            f"removed in a future release.",
            RemovedInAirflow3Warning,
        )
        return old_serializer(**kwargs)

    clazz.serialize_value = _shim  # type: ignore[assignment]

797 

798 

799def _get_function_params(function) -> list[str]: 

800 """ 

801 Returns the list of variables names of a function 

802 

803 :param function: The function to inspect 

804 """ 

805 parameters = inspect.signature(function).parameters 

806 bound_arguments = [ 

807 name for name, p in parameters.items() if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD) 

808 ] 

809 return bound_arguments 

810 

811 

def resolve_xcom_backend() -> type[BaseXCom]:
    """Resolve the XCom class configured under ``[core] xcom_backend``.

    Verifies that the configured class extends ``BaseXCom``. If its
    ``serialize_value`` parameter names differ from those of
    ``BaseXCom.serialize_value``, the method is patched with a compatibility shim.

    :raises TypeError: If the configured class is not a ``BaseXCom`` subclass.
    """
    default_path = f"airflow.models.xcom.{BaseXCom.__name__}"
    backend = conf.getimport("core", "xcom_backend", fallback=default_path)
    if not backend:
        return BaseXCom
    if not issubclass(backend, BaseXCom):
        raise TypeError(
            f"Your custom XCom class `{backend.__name__}` is not a subclass of `{BaseXCom.__name__}`."
        )

    # Compare parameter names as sets: order does not matter, only the names.
    expected_names = set(_get_function_params(BaseXCom.serialize_value))
    backend_params = _get_function_params(backend.serialize_value)
    if expected_names != set(backend_params):
        _patch_outdated_serializer(clazz=backend, params=backend_params)
    return backend

830 

831 

# ``XCom`` is the name the rest of this module uses (e.g. ``XCom.deserialize_value``).
# At runtime it is whatever ``resolve_xcom_backend()`` returns; for static type
# checking it is aliased to ``BaseXCom``.
if TYPE_CHECKING:
    XCom = BaseXCom  # Hack to avoid Mypy "Variable 'XCom' is not valid as a type".
else:
    XCom = resolve_xcom_backend()