Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/xcom.py: 45%

321 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections.abc 

21import contextlib 

22import datetime 

23import inspect 

24import itertools 

25import json 

26import logging 

27import pickle 

28import warnings 

29from functools import cached_property, wraps 

30from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload 

31 

32import attr 

33import pendulum 

34from sqlalchemy import ( 

35 Column, 

36 ForeignKeyConstraint, 

37 Index, 

38 Integer, 

39 LargeBinary, 

40 PrimaryKeyConstraint, 

41 String, 

42 delete, 

43 text, 

44) 

45from sqlalchemy.ext.associationproxy import association_proxy 

46from sqlalchemy.orm import Query, Session, reconstructor, relationship 

47from sqlalchemy.orm.exc import NoResultFound 

48 

49from airflow import settings 

50from airflow.api_internal.internal_api_call import internal_api_call 

51from airflow.configuration import conf 

52from airflow.exceptions import RemovedInAirflow3Warning 

53from airflow.models.base import COLLATION_ARGS, ID_LEN, Base 

54from airflow.utils import timezone 

55from airflow.utils.helpers import exactly_one, is_container 

56from airflow.utils.json import XComDecoder, XComEncoder 

57from airflow.utils.log.logging_mixin import LoggingMixin 

58from airflow.utils.session import NEW_SESSION, provide_session 

59from airflow.utils.sqlalchemy import UtcDateTime 

60 

61# XCom constants below are needed for providers backward compatibility, 

62# which should import the constants directly after apache-airflow>=2.6.0 

63from airflow.utils.xcom import ( 

64 MAX_XCOM_SIZE, # noqa: F401 

65 XCOM_RETURN_KEY, 

66) 

67 

# Module-level logger; used below by BaseXCom.set() (lazy-proxy coercion warning)
# and BaseXCom.serialize_value() (serialization failure message).
log = logging.getLogger(__name__)

69 

70if TYPE_CHECKING: 

71 from airflow.models.taskinstancekey import TaskInstanceKey 

72 

73 

class BaseXCom(Base, LoggingMixin):
    """Base class for XCom objects.

    A row is identified by ``(dag_run_id, task_id, map_index, key)``; the
    ``dag_id`` and ``run_id`` columns are denormalized copies kept for easier
    lookup. ``value`` holds the payload produced by :meth:`serialize_value`.
    Custom XCom backends subclass this class and may override
    :meth:`serialize_value`, :meth:`deserialize_value` and
    :meth:`orm_deserialize_value`.
    """

    __tablename__ = "xcom"

    dag_run_id = Column(Integer(), nullable=False, primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
    map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
    key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)

    # Denormalized for easier lookup.
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)

    value = Column(LargeBinary)
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __table_args__ = (
        # Ideally we should create a unique index over (key, dag_id, task_id, run_id),
        # but it goes over MySQL's index length limit. So we instead index 'key'
        # separately, and enforce uniqueness with DagRun.id instead.
        Index("idx_xcom_key", key),
        Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
        PrimaryKeyConstraint(
            "dag_run_id", "task_id", "map_index", "key", name="xcom_pkey", mssql_clustered=True
        ),
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="xcom_task_instance_fkey",
            ondelete="CASCADE",
        ),
    )

    dag_run = relationship(
        "DagRun",
        primaryjoin="BaseXCom.dag_run_id == foreign(DagRun.id)",
        uselist=False,
        lazy="joined",
        passive_deletes="all",
    )
    execution_date = association_proxy("dag_run", "execution_date")

    @reconstructor
    def init_on_load(self):
        """
        Called by the ORM after the instance has been loaded from the DB or otherwise reconstituted.

        i.e. automatically deserialize the XCom value (via ``orm_deserialize_value``)
        when loading from the DB.
        """
        self.value = self.orm_deserialize_value()

    def __repr__(self):
        # Unmapped tasks use a negative map_index (-1 by default), so omit it.
        if self.map_index < 0:
            return f'<XCom "{self.key}" ({self.task_id} @ {self.run_id})>'
        return f'<XCom "{self.key}" ({self.task_id}[{self.map_index}] @ {self.run_id})>'

    @overload
    @classmethod
    def set(
        cls,
        key: str,
        value: Any,
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int = -1,
        session: Session = NEW_SESSION,
    ) -> None:
        """Store an XCom value.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        :param key: Key to store the XCom.
        :param value: XCom value to store.
        :param dag_id: DAG ID.
        :param task_id: Task ID.
        :param run_id: DAG run ID for the task.
        :param map_index: Optional map index to assign XCom for a mapped task.
            The default is ``-1`` (set for a non-mapped task).
        :param session: Database session. If not given, a new session will be
            created for this function.
        """

    @overload
    @classmethod
    def set(
        cls,
        key: str,
        value: Any,
        task_id: str,
        dag_id: str,
        execution_date: datetime.datetime,
        session: Session = NEW_SESSION,
    ) -> None:
        """Store an XCom value.

        :sphinx-autoapi-skip:
        """

    @classmethod
    @provide_session
    def set(
        cls,
        key: str,
        value: Any,
        task_id: str,
        dag_id: str,
        execution_date: datetime.datetime | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int = -1,
    ) -> None:
        """Store an XCom value.

        Any existing XCom with the same ``(key, run_id, task_id, dag_id, map_index)``
        is deleted first, then the new row is added and the session flushed.

        :raises ValueError: If not exactly one of ``run_id`` / ``execution_date``
            is given, or if the referenced DAG run does not exist.

        :sphinx-autoapi-skip:
        """
        from airflow.models.dagrun import DagRun

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )

        if run_id is None:
            message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
            try:
                dag_run_id, run_id = (
                    session.query(DagRun.id, DagRun.run_id)
                    .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
                    .one()
                )
            except NoResultFound:
                raise ValueError(f"DAG run not found on DAG {dag_id!r} at {execution_date}") from None
        else:
            dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
            if dag_run_id is None:
                raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

        # Seamlessly resolve LazyXComAccess to a list. This is intended to work
        # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
        # it's pushed into XCom, the user should be aware of the performance
        # implications, and this avoids leaking the implementation detail.
        if isinstance(value, LazyXComAccess):
            warning_message = (
                "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
                "to list, which may degrade performance. Review resource "
                "requirements for this operation, and call list() to suppress "
                "this message. See Dynamic Task Mapping documentation for "
                "more information about lazy proxy objects."
            )
            log.warning(
                warning_message,
                "return value" if key == XCOM_RETURN_KEY else f"value {key}",
                task_id,
                dag_id,
                run_id or execution_date,
            )
            value = list(value)

        value = cls.serialize_value(
            value=value,
            key=key,
            task_id=task_id,
            dag_id=dag_id,
            run_id=run_id,
            map_index=map_index,
        )

        # Remove duplicate XComs and insert a new one.
        session.execute(
            delete(cls).where(
                cls.key == key,
                cls.run_id == run_id,
                cls.task_id == task_id,
                cls.dag_id == dag_id,
                cls.map_index == map_index,
            )
        )
        new = cast(Any, cls)(  # Work around Mypy complaining model not defining '__init__'.
            dag_run_id=dag_run_id,
            key=key,
            value=value,
            run_id=run_id,
            task_id=task_id,
            dag_id=dag_id,
            map_index=map_index,
        )
        session.add(new)
        session.flush()

    @staticmethod
    @provide_session
    @internal_api_call
    def get_value(
        *,
        ti_key: TaskInstanceKey,
        key: str | None = None,
        session: Session = NEW_SESSION,
    ) -> Any:
        """Retrieve an XCom value for a task instance.

        This method returns "full" XCom values (i.e. uses ``deserialize_value``
        from the XCom backend). Use :meth:`get_many` if you want the "shortened"
        value via ``orm_deserialize_value``.

        If there are no results, *None* is returned. If multiple XCom entries
        match the criteria, an arbitrary one is returned.

        :param ti_key: The TaskInstanceKey to look up the XCom for.
        :param key: A key for the XCom. If provided, only XCom with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """
        return BaseXCom.get_one(
            key=key,
            task_id=ti_key.task_id,
            dag_id=ti_key.dag_id,
            run_id=ti_key.run_id,
            map_index=ti_key.map_index,
            session=session,
        )

    @overload
    @staticmethod
    @internal_api_call
    def get_one(
        *,
        key: str | None = None,
        dag_id: str | None = None,
        task_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
    ) -> Any | None:
        """Retrieve an XCom value, optionally meeting certain criteria.

        This method returns "full" XCom values (i.e. uses ``deserialize_value``
        from the XCom backend). Use :meth:`get_many` if you want the "shortened"
        value via ``orm_deserialize_value``.

        If there are no results, *None* is returned. If multiple XCom entries
        match the criteria, an arbitrary one is returned.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        .. seealso:: ``get_value()`` is a convenience function if you already
            have a structured TaskInstance or TaskInstanceKey object available.

        :param run_id: DAG run ID for the task.
        :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to
            remove the filter.
        :param task_id: Only XCom from task with matching ID will be pulled.
            Pass *None* (default) to remove the filter.
        :param map_index: Only XCom from a task with matching map index will be
            pulled. Pass *None* (default) to remove the filter.
        :param key: A key for the XCom. If provided, only XCom with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param include_prior_dates: If *False* (default), only XCom from the
            specified DAG run is returned. If *True*, the latest matching XCom is
            returned regardless of the run it belongs to.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """

    @overload
    @staticmethod
    @internal_api_call
    def get_one(
        execution_date: datetime.datetime,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
    ) -> Any | None:
        """Retrieve an XCom value, optionally meeting certain criteria.

        :sphinx-autoapi-skip:
        """

    @staticmethod
    @provide_session
    @internal_api_call
    def get_one(
        execution_date: datetime.datetime | None = None,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        include_prior_dates: bool = False,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> Any | None:
        """Retrieve an XCom value, optionally meeting certain criteria.

        :raises ValueError: If not exactly one of ``run_id`` / ``execution_date``
            is given.

        :sphinx-autoapi-skip:
        """
        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError("Exactly one of run_id or execution_date must be passed")

        # Dispatch on ``is not None`` (matching the check above) rather than on
        # truthiness, so an empty-string run_id is treated as a run_id instead of
        # falling through to an "impossible" branch.
        if run_id is not None:
            query = BaseXCom.get_many(
                run_id=run_id,
                key=key,
                task_ids=task_id,
                dag_ids=dag_id,
                map_indexes=map_index,
                include_prior_dates=include_prior_dates,
                limit=1,
                session=session,
            )
        else:
            # exactly_one() above guarantees execution_date is set on this branch.
            message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)

            # get_many() would emit its own deprecation warning; one is enough.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RemovedInAirflow3Warning)
                query = BaseXCom.get_many(
                    execution_date=execution_date,
                    key=key,
                    task_ids=task_id,
                    dag_ids=dag_id,
                    map_indexes=map_index,
                    include_prior_dates=include_prior_dates,
                    limit=1,
                    session=session,
                )

        result = query.with_entities(BaseXCom.value).first()
        if result:
            return BaseXCom.deserialize_value(result)
        return None

    @overload
    @staticmethod
    def get_many(
        *,
        run_id: str,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
    ) -> Query:
        """Composes a query to get one or more XCom entries.

        This function returns an SQLAlchemy query of full XCom objects. If you
        just want one stored value, use :meth:`get_one` instead.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        :param run_id: DAG run ID for the task.
        :param key: A key for the XComs. If provided, only XComs with matching
            keys will be returned. Pass *None* (default) to remove the filter.
        :param task_ids: Only XComs from task with matching IDs will be pulled.
            Pass *None* (default) to remove the filter.
        :param dag_ids: Only pulls XComs from specified DAGs. Pass *None*
            (default) to remove the filter.
        :param map_indexes: Only XComs from matching map indexes will be pulled.
            Pass *None* (default) to remove the filter.
        :param include_prior_dates: If *False* (default), only XComs from the
            specified DAG run are returned. If *True*, all matching XComs are
            returned regardless of the run it belongs to.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :param limit: Limiting returning XComs
        """

    @overload
    @staticmethod
    @internal_api_call
    def get_many(
        execution_date: datetime.datetime,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
    ) -> Query:
        """Composes a query to get one or more XCom entries.

        :sphinx-autoapi-skip:
        """

    @staticmethod
    @provide_session
    @internal_api_call
    def get_many(
        execution_date: datetime.datetime | None = None,
        key: str | None = None,
        task_ids: str | Iterable[str] | None = None,
        dag_ids: str | Iterable[str] | None = None,
        map_indexes: int | Iterable[int] | None = None,
        include_prior_dates: bool = False,
        limit: int | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
    ) -> Query:
        """Composes a query to get one or more XCom entries.

        Results are ordered newest first: by the DAG run's execution date, then
        by XCom timestamp, both descending.

        :raises ValueError: If not exactly one of ``run_id`` / ``execution_date``
            is given.

        :sphinx-autoapi-skip:
        """
        from airflow.models.dagrun import DagRun

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )
        if execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)

        query = session.query(BaseXCom).join(BaseXCom.dag_run)

        if key:
            query = query.filter(BaseXCom.key == key)

        if is_container(task_ids):
            query = query.filter(BaseXCom.task_id.in_(task_ids))
        elif task_ids is not None:
            query = query.filter(BaseXCom.task_id == task_ids)

        if is_container(dag_ids):
            query = query.filter(BaseXCom.dag_id.in_(dag_ids))
        elif dag_ids is not None:
            query = query.filter(BaseXCom.dag_id == dag_ids)

        # A contiguous range can be expressed as a cheap BETWEEN-style filter
        # instead of enumerating every index in an IN clause.
        if isinstance(map_indexes, range) and map_indexes.step == 1:
            query = query.filter(
                BaseXCom.map_index >= map_indexes.start, BaseXCom.map_index < map_indexes.stop
            )
        elif is_container(map_indexes):
            query = query.filter(BaseXCom.map_index.in_(map_indexes))
        elif map_indexes is not None:
            query = query.filter(BaseXCom.map_index == map_indexes)

        if include_prior_dates:
            if execution_date is not None:
                query = query.filter(DagRun.execution_date <= execution_date)
            else:
                dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery()
                query = query.filter(BaseXCom.execution_date <= dr.c.execution_date)
        elif execution_date is not None:
            query = query.filter(DagRun.execution_date == execution_date)
        else:
            query = query.filter(BaseXCom.run_id == run_id)

        query = query.order_by(DagRun.execution_date.desc(), BaseXCom.timestamp.desc())
        if limit:
            return query.limit(limit)
        return query

    @classmethod
    @provide_session
    def delete(cls, xcoms: XCom | Iterable[XCom], session: Session = NEW_SESSION) -> None:
        """Delete one or multiple XCom entries and commit the session.

        :param xcoms: A single XCom row or an iterable of XCom rows to delete.
        :param session: Database session. If not given, a new session will be
            created for this function.
        :raises TypeError: If any element is not an XCom instance.
        """
        if isinstance(xcoms, XCom):
            xcoms = [xcoms]
        for xcom in xcoms:
            if not isinstance(xcom, XCom):
                raise TypeError(f"Expected XCom; received {xcom.__class__.__name__}")
            session.delete(xcom)
        session.commit()

    @overload
    @staticmethod
    @internal_api_call
    def clear(
        *,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int | None = None,
        session: Session = NEW_SESSION,
    ) -> None:
        """Clear all XCom data from the database for the given task instance.

        A deprecated form of this function accepts ``execution_date`` instead of
        ``run_id``. The two arguments are mutually exclusive.

        :param dag_id: ID of DAG to clear the XCom for.
        :param task_id: ID of task to clear the XCom for.
        :param run_id: ID of DAG run to clear the XCom for.
        :param map_index: If given, only clear XCom from this particular mapped
            task. The default ``None`` clears *all* XComs from the task.
        :param session: Database session. If not given, a new session will be
            created for this function.
        """

    @overload
    @staticmethod
    @internal_api_call
    def clear(
        execution_date: pendulum.DateTime,
        dag_id: str,
        task_id: str,
        session: Session = NEW_SESSION,
    ) -> None:
        """Clear all XCom data from the database for the given task instance.

        :sphinx-autoapi-skip:
        """

    @staticmethod
    @provide_session
    @internal_api_call
    def clear(
        execution_date: pendulum.DateTime | None = None,
        dag_id: str | None = None,
        task_id: str | None = None,
        session: Session = NEW_SESSION,
        *,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> None:
        """Clear all XCom data from the database for the given task instance.

        :raises TypeError: If ``dag_id`` or ``task_id`` is missing.
        :raises ValueError: If not exactly one of ``run_id`` / ``execution_date``
            is given.

        :sphinx-autoapi-skip:
        """
        from airflow.models import DagRun

        # Given the historic order of this function (execution_date was first argument) to add a new optional
        # param we need to add default values for everything :(
        if dag_id is None:
            raise TypeError("clear() missing required argument: dag_id")
        if task_id is None:
            raise TypeError("clear() missing required argument: task_id")

        if not exactly_one(execution_date is not None, run_id is not None):
            raise ValueError(
                f"Exactly one of run_id or execution_date must be passed. "
                f"Passed execution_date={execution_date}, run_id={run_id}"
            )

        if execution_date is not None:
            message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead."
            warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
            run_id = (
                session.query(DagRun.run_id)
                .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
                .scalar()
            )

        query = session.query(BaseXCom).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
        if map_index is not None:
            query = query.filter_by(map_index=map_index)
        query.delete()

    @staticmethod
    def serialize_value(
        value: Any,
        *,
        key: str | None = None,
        task_id: str | None = None,
        dag_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> Any:
        """Serialize XCom value to str or pickled object.

        Pickles when ``[core] enable_xcom_pickling`` is on; otherwise encodes
        with ``XComEncoder`` to UTF-8 JSON bytes. The keyword-only arguments are
        unused here but are forwarded by :meth:`set` for custom backends.

        :raises ValueError | TypeError: If JSON encoding fails (logged, then re-raised).
        """
        if conf.getboolean("core", "enable_xcom_pickling"):
            return pickle.dumps(value)
        try:
            return json.dumps(value, cls=XComEncoder).encode("UTF-8")
        except (ValueError, TypeError) as ex:
            log.error(
                "%s."
                " If you are using pickle instead of JSON for XCom,"
                " then you need to enable pickle support for XCom"
                " in your airflow config or make sure to decorate your"
                " object with attr.",
                ex,
            )
            raise

    @staticmethod
    def _deserialize_value(result: XCom, orm: bool) -> Any:
        """Decode ``result.value`` using pickle or JSON, whichever succeeds.

        :param result: Object exposing a ``value`` attribute holding the stored bytes.
        :param orm: If True, decode JSON with ``XComDecoder.orm_object_hook``
            (the lightweight representation used when loading ORM rows).
        """
        object_hook = None
        if orm:
            object_hook = XComDecoder.orm_object_hook

        if result.value is None:
            return None
        if conf.getboolean("core", "enable_xcom_pickling"):
            try:
                return pickle.loads(result.value)
            except pickle.UnpicklingError:
                return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
        else:
            try:
                return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
            except (json.JSONDecodeError, UnicodeDecodeError):
                # NOTE(review): this falls back to pickle.loads even when XCom
                # pickling is disabled; unpickling executes arbitrary code if the
                # stored bytes are attacker-controlled. Kept for compatibility with
                # values pickled before the setting was turned off.
                return pickle.loads(result.value)

    @staticmethod
    def deserialize_value(result: XCom) -> Any:
        """Deserialize XCom value from str or pickle object."""
        return BaseXCom._deserialize_value(result, False)

    def orm_deserialize_value(self) -> Any:
        """
        Deserialize method which is used to reconstruct ORM XCom object.

        This method should be overridden in custom XCom backends to avoid
        unnecessary request or other resource consuming operations when
        creating XCom orm model. This is used when viewing XCom listing
        in the webserver, for example.
        """
        return BaseXCom._deserialize_value(self, True)

704 

705 

class _LazyXComAccessIterator(collections.abc.Iterator):
    """Iterator that enters a query context manager lazily and deserializes each row.

    The wrapped context manager is only entered on the first ``next()`` call and
    is exited when the iterator is garbage collected.
    """

    def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
        self._cm = cm
        # Becomes True once ``_it`` has been computed; ``__del__`` uses it to
        # decide whether the context manager needs exiting.
        self._entered = False

    def __del__(self) -> None:
        if not self._entered:
            return
        self._cm.__exit__(None, None, None)

    def __iter__(self) -> collections.abc.Iterator:
        return self

    def __next__(self) -> Any:
        row = next(self._it)
        return XCom.deserialize_value(row)

    @cached_property
    def _it(self) -> collections.abc.Iterator:
        # The flag is set *before* entering so cleanup is still attempted if
        # ``__enter__`` itself raises.
        self._entered = True
        bound = self._cm.__enter__()
        return iter(bound)

725 

726 

@attr.define(slots=True)
class LazyXComAccess(collections.abc.Sequence):
    """Wrapper to lazily pull XCom with a sequence-like interface.

    Note that since the session bound to the parent query may have died when we
    actually access the sequence's content, we must create a new session
    for every function call with ``with_session()``.

    :meta private:
    """

    # Query selecting only ``XCom.value`` (see build_from_xcom_query).
    _query: Query
    # Cached row count; filled on first len() or restored by __setstate__.
    _len: int | None = attr.ib(init=False, default=None)

    @classmethod
    def build_from_xcom_query(cls, query: Query) -> LazyXComAccess:
        """Wrap an XCom query, narrowing it to select only the ``value`` column."""
        return cls(query=query.with_entities(XCom.value))

    def __repr__(self) -> str:
        return f"LazyXComAccess([{len(self)} items])"

    def __str__(self) -> str:
        return str(list(self))

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, (list, LazyXComAccess)):
            # A shared sentinel pads the shorter side, so sequences of
            # different lengths compare unequal.
            z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
            return all(x == y for x, y in z)
        return NotImplemented

    def __getstate__(self) -> Any:
        # We don't want to go to the trouble of serializing the entire Query
        # object, including its filters, hints, etc. (plus SQLAlchemy does not
        # provide a public API to inspect a query's contents). Converting the
        # query into a SQL string is the best we can get. Theoretically we can
        # do the same for count(), but I think it should be performant enough to
        # calculate only that eagerly.
        with self._get_bound_query() as query:
            statement = query.statement.compile(
                query.session.get_bind(),
                # This inlines all the values into the SQL string to simplify
                # cross-process communication as much as possible.
                compile_kwargs={"literal_binds": True},
            )
            return (str(statement), query.count())

    def __setstate__(self, state: Any) -> None:
        statement, self._len = state
        self._query = Query(XCom.value).from_statement(text(statement))

    def __len__(self):
        if self._len is None:
            with self._get_bound_query() as query:
                self._len = query.count()
        return self._len

    def __iter__(self):
        return _LazyXComAccessIterator(self._get_bound_query())

    def __getitem__(self, key):
        """Return the deserialized value at integer position *key*.

        Negative indexes count from the end, like a regular sequence.

        :raises ValueError: If *key* is not an int (slices are not supported).
        :raises IndexError: If *key* is out of range.
        """
        if not isinstance(key, int):
            raise ValueError("only support index access for now")
        index = key
        if index < 0:
            # SQL OFFSET only accepts non-negative values, so translate
            # Python-style negative indexes using the (cached) length.
            index += len(self)
            if index < 0:
                raise IndexError(key)
        try:
            with self._get_bound_query() as query:
                r = query.offset(index).limit(1).one()
        except NoResultFound:
            raise IndexError(key) from None
        return XCom.deserialize_value(r)

    @contextlib.contextmanager
    def _get_bound_query(self) -> Generator[Query, None, None]:
        """Yield the query bound to a live session, opening a temporary one if needed.

        A temporary session created here is closed when the context exits.

        :raises RuntimeError: If no live session exists and ``settings.Session``
            has not been configured.
        """
        # Do we have a valid session already?
        if self._query.session and self._query.session.is_active:
            yield self._query
            return

        session_factory = getattr(settings, "Session", None)
        if session_factory is None:
            raise RuntimeError("Session must be set before!")
        session = session_factory()
        try:
            yield self._query.with_session(session)
        finally:
            session.close()

811 

812 

def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
    """Patch a custom ``serialize_value`` to accept the modern signature.

    To give custom XCom backends more flexibility with how they store values, we
    now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``.
    In order to maintain compatibility with custom XCom backends written with
    the old signature, we check the signature and, if necessary, patch with a
    method that ignores kwargs the backend does not accept.

    :param clazz: XCom backend class whose ``serialize_value`` is replaced in place.
    :param params: Names of the parameters the backend's ``serialize_value``
        accepts. Only these are forwarded; any that the caller omits are
        passed as ``None``.
    """
    old_serializer = clazz.serialize_value
    # Materialize once: the shim runs on every serialization, so a one-shot
    # iterator would otherwise be exhausted after the first call.
    accepted_params = tuple(params)
    # Name the class actually being patched rather than the module-level XCom.
    backend_name = clazz.__name__

    @wraps(old_serializer)
    def _shim(**kwargs):
        kwargs = {k: kwargs.get(k) for k in accepted_params}
        warnings.warn(
            f"Method `serialize_value` in XCom backend {backend_name} is using outdated signature and "
            "must be updated to accept all params in `BaseXCom.set` except `session`. Support will be "
            "removed in a future release.",
            RemovedInAirflow3Warning,
        )
        return old_serializer(**kwargs)

    clazz.serialize_value = _shim  # type: ignore[assignment]

836 

837 

def _get_function_params(function) -> list[str]:
    """
    Return the names of a function's parameters in declaration order.

    Catch-all ``*args`` / ``**kwargs`` parameters are skipped; every other
    kind (positional-only, positional-or-keyword, keyword-only) is kept.

    :param function: The function to inspect
    """
    catch_all_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    names = []
    for param in inspect.signature(function).parameters.values():
        if param.kind in catch_all_kinds:
            continue
        names.append(param.name)
    return names

849 

850 

def resolve_xcom_backend() -> type[BaseXCom]:
    """Resolve the XCom class configured via ``[core] xcom_backend``.

    Falls back to :class:`BaseXCom` when nothing usable is configured. A
    configured class must subclass :class:`BaseXCom`. If its
    ``serialize_value`` parameter names differ from the base implementation's,
    the method is wrapped by ``_patch_outdated_serializer`` for compatibility.

    :raises TypeError: If the configured class is not a BaseXCom subclass.
    """
    default_path = f"airflow.models.xcom.{BaseXCom.__name__}"
    backend = conf.getimport("core", "xcom_backend", fallback=default_path)
    if not backend:
        return BaseXCom

    if not issubclass(backend, BaseXCom):
        raise TypeError(
            f"Your custom XCom class `{backend.__name__}` is not a subclass of `{BaseXCom.__name__}`."
        )

    expected_params = set(_get_function_params(BaseXCom.serialize_value))
    backend_params = _get_function_params(backend.serialize_value)
    if expected_params != set(backend_params):
        _patch_outdated_serializer(clazz=backend, params=backend_params)
    return backend

869 

870 

# Module-level alias for the XCom class the rest of the code should use.
# Kept as an ``if TYPE_CHECKING:`` statement (not a conditional expression) so
# that static type checkers treat XCom as an alias of BaseXCom.
if TYPE_CHECKING:
    XCom = BaseXCom  # Hack to avoid Mypy "Variable 'XCom' is not valid as a type".
else:
    # At runtime, use the class configured via ``[core] xcom_backend`` (falling
    # back to BaseXCom); this is resolved once, when the module is imported.
    XCom = resolve_xcom_backend()